1# Copyright 2015, 2016 OpenMarket Ltd
2# Copyright 2019 New Vector Ltd
3# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16import abc
17from typing import (
18    TYPE_CHECKING,
19    Collection,
20    Dict,
21    Iterable,
22    List,
23    Optional,
24    Tuple,
25    cast,
26)
27
28import attr
29from canonicaljson import encode_canonical_json
30
31from synapse.api.constants import DeviceKeyAlgorithms
32from synapse.logging.opentracing import log_kv, set_tag, trace
33from synapse.storage._base import SQLBaseStore, db_to_json
34from synapse.storage.database import (
35    DatabasePool,
36    LoggingDatabaseConnection,
37    LoggingTransaction,
38    make_in_list_sql_clause,
39)
40from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
41from synapse.storage.engines import PostgresEngine
42from synapse.storage.util.id_generators import StreamIdGenerator
43from synapse.types import JsonDict
44from synapse.util import json_encoder
45from synapse.util.caches.descriptors import cached, cachedList
46from synapse.util.iterutils import batch_iter
47
48if TYPE_CHECKING:
49    from synapse.handlers.e2e_keys import SignatureListItem
50    from synapse.server import HomeServer
51
52
@attr.s(slots=True, auto_attribs=True)
class DeviceKeyLookupResult:
    """One device's entry in the result of get_e2e_device_keys_and_signatures."""

    # the device's display name (from the `devices` table), if any
    display_name: Optional[str]

    # the key data from e2e_device_keys_json. Typically includes fields like
    # "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
    # key) and "signatures" (a map from (user id) to (key id/device_id) to signature.)
    # None when the device has no uploaded key data.
    keys: Optional[JsonDict]
63
64
class EndToEndKeyBackgroundStore(SQLBaseStore):
    """Registers the background database updates used by the end-to-end key store."""

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        # Build a unique index over e2e_cross_signing_keys.stream_id as a
        # background update.
        self.db_pool.updates.register_background_index_update(
            "e2e_cross_signing_keys_idx",
            table="e2e_cross_signing_keys",
            columns=["stream_id"],
            index_name="e2e_cross_signing_keys_stream_idx",
            unique=True,
        )
81
82
83class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore):
84    def __init__(
85        self,
86        database: DatabasePool,
87        db_conn: LoggingDatabaseConnection,
88        hs: "HomeServer",
89    ):
90        super().__init__(database, db_conn, hs)
91
92        self._allow_device_name_lookup_over_federation = (
93            self.hs.config.federation.allow_device_name_lookup_over_federation
94        )
95
96    async def get_e2e_device_keys_for_federation_query(
97        self, user_id: str
98    ) -> Tuple[int, List[JsonDict]]:
99        """Get all devices (with any device keys) for a user
100
101        Returns:
102            (stream_id, devices)
103        """
104        now_stream_id = self.get_device_stream_token()
105
106        devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])
107
108        if devices:
109            user_devices = devices[user_id]
110            results = []
111            for device_id, device in user_devices.items():
112                result = {"device_id": device_id}
113
114                keys = device.keys
115                if keys:
116                    result["keys"] = keys
117
118                device_display_name = None
119                if self._allow_device_name_lookup_over_federation:
120                    device_display_name = device.display_name
121                if device_display_name:
122                    result["device_display_name"] = device_display_name
123
124                results.append(result)
125
126            return now_stream_id, results
127
128        return now_stream_id, []
129
130    @trace
131    async def get_e2e_device_keys_for_cs_api(
132        self, query_list: List[Tuple[str, Optional[str]]]
133    ) -> Dict[str, Dict[str, JsonDict]]:
134        """Fetch a list of device keys, formatted suitably for the C/S API.
135        Args:
136            query_list(list): List of pairs of user_ids and device_ids.
137        Returns:
138            Dict mapping from user-id to dict mapping from device_id to
139            key data.  The key data will be a dict in the same format as the
140            DeviceKeys type returned by POST /_matrix/client/r0/keys/query.
141        """
142        set_tag("query_list", query_list)
143        if not query_list:
144            return {}
145
146        results = await self.get_e2e_device_keys_and_signatures(query_list)
147
148        # Build the result structure, un-jsonify the results, and add the
149        # "unsigned" section
150        rv: Dict[str, Dict[str, JsonDict]] = {}
151        for user_id, device_keys in results.items():
152            rv[user_id] = {}
153            for device_id, device_info in device_keys.items():
154                r = device_info.keys
155                r["unsigned"] = {}
156                display_name = device_info.display_name
157                if display_name is not None:
158                    r["unsigned"]["device_display_name"] = display_name
159                rv[user_id][device_id] = r
160
161        return rv
162
163    @trace
164    async def get_e2e_device_keys_and_signatures(
165        self,
166        query_list: List[Tuple[str, Optional[str]]],
167        include_all_devices: bool = False,
168        include_deleted_devices: bool = False,
169    ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
170        """Fetch a list of device keys
171
172        Any cross-signatures made on the keys by the owner of the device are also
173        included.
174
175        The cross-signatures are added to the `signatures` field within the `keys`
176        object in the response.
177
178        Args:
179            query_list: List of pairs of user_ids and device_ids. Device id can be None
180                to indicate "all devices for this user"
181
182            include_all_devices: whether to return devices without device keys
183
184            include_deleted_devices: whether to include null entries for
185                devices which no longer exist (but were in the query_list).
186                This option only takes effect if include_all_devices is true.
187
188        Returns:
189            Dict mapping from user-id to dict mapping from device_id to
190            key data.
191        """
192        set_tag("include_all_devices", include_all_devices)
193        set_tag("include_deleted_devices", include_deleted_devices)
194
195        result = await self.db_pool.runInteraction(
196            "get_e2e_device_keys",
197            self._get_e2e_device_keys_txn,
198            query_list,
199            include_all_devices,
200            include_deleted_devices,
201        )
202
203        # get the (user_id, device_id) tuples to look up cross-signatures for
204        signature_query = (
205            (user_id, device_id)
206            for user_id, dev in result.items()
207            for device_id, d in dev.items()
208            if d is not None and d.keys is not None
209        )
210
211        for batch in batch_iter(signature_query, 50):
212            cross_sigs_result = await self.db_pool.runInteraction(
213                "get_e2e_cross_signing_signatures",
214                self._get_e2e_cross_signing_signatures_for_devices_txn,
215                batch,
216            )
217
218            # add each cross-signing signature to the correct device in the result dict.
219            for (user_id, key_id, device_id, signature) in cross_sigs_result:
220                target_device_result = result[user_id][device_id]
221                # We've only looked up cross-signatures for non-deleted devices with key
222                # data.
223                assert target_device_result is not None
224                assert target_device_result.keys is not None
225                target_device_signatures = target_device_result.keys.setdefault(
226                    "signatures", {}
227                )
228                signing_user_signatures = target_device_signatures.setdefault(
229                    user_id, {}
230                )
231                signing_user_signatures[key_id] = signature
232
233        log_kv(result)
234        return result
235
236    def _get_e2e_device_keys_txn(
237        self,
238        txn: LoggingTransaction,
239        query_list: Collection[Tuple[str, str]],
240        include_all_devices: bool = False,
241        include_deleted_devices: bool = False,
242    ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
243        """Get information on devices from the database
244
245        The results include the device's keys and self-signatures, but *not* any
246        cross-signing signatures which have been added subsequently (for which, see
247        get_e2e_device_keys_and_signatures)
248        """
249        query_clauses = []
250        query_params = []
251
252        if include_all_devices is False:
253            include_deleted_devices = False
254
255        if include_deleted_devices:
256            deleted_devices = set(query_list)
257
258        for (user_id, device_id) in query_list:
259            query_clause = "user_id = ?"
260            query_params.append(user_id)
261
262            if device_id is not None:
263                query_clause += " AND device_id = ?"
264                query_params.append(device_id)
265
266            query_clauses.append(query_clause)
267
268        sql = (
269            "SELECT user_id, device_id, "
270            "    d.display_name, "
271            "    k.key_json"
272            " FROM devices d"
273            "    %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
274            " WHERE %s AND NOT d.hidden"
275        ) % (
276            "LEFT" if include_all_devices else "INNER",
277            " OR ".join("(" + q + ")" for q in query_clauses),
278        )
279
280        txn.execute(sql, query_params)
281
282        result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {}
283        for (user_id, device_id, display_name, key_json) in txn:
284            if include_deleted_devices:
285                deleted_devices.remove((user_id, device_id))
286            result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
287                display_name, db_to_json(key_json) if key_json else None
288            )
289
290        if include_deleted_devices:
291            for user_id, device_id in deleted_devices:
292                result.setdefault(user_id, {})[device_id] = None
293
294        return result
295
296    def _get_e2e_cross_signing_signatures_for_devices_txn(
297        self, txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
298    ) -> List[Tuple[str, str, str, str]]:
299        """Get cross-signing signatures for a given list of devices
300
301        Returns signatures made by the owners of the devices.
302
303        Returns: a list of results; each entry in the list is a tuple of
304            (user_id, key_id, target_device_id, signature).
305        """
306        signature_query_clauses = []
307        signature_query_params = []
308
309        for (user_id, device_id) in device_query:
310            signature_query_clauses.append(
311                "target_user_id = ? AND target_device_id = ? AND user_id = ?"
312            )
313            signature_query_params.extend([user_id, device_id, user_id])
314
315        signature_sql = """
316            SELECT user_id, key_id, target_device_id, signature
317            FROM e2e_cross_signing_signatures WHERE %s
318            """ % (
319            " OR ".join("(" + q + ")" for q in signature_query_clauses)
320        )
321
322        txn.execute(signature_sql, signature_query_params)
323        return cast(
324            List[
325                Tuple[
326                    str,
327                    str,
328                    str,
329                    str,
330                ]
331            ],
332            txn.fetchall(),
333        )
334
335    async def get_e2e_one_time_keys(
336        self, user_id: str, device_id: str, key_ids: List[str]
337    ) -> Dict[Tuple[str, str], str]:
338        """Retrieve a number of one-time keys for a user
339
340        Args:
341            user_id(str): id of user to get keys for
342            device_id(str): id of device to get keys for
343            key_ids(list[str]): list of key ids (excluding algorithm) to
344                retrieve
345
346        Returns:
347            A map from (algorithm, key_id) to json string for key
348        """
349
350        rows = await self.db_pool.simple_select_many_batch(
351            table="e2e_one_time_keys_json",
352            column="key_id",
353            iterable=key_ids,
354            retcols=("algorithm", "key_id", "key_json"),
355            keyvalues={"user_id": user_id, "device_id": device_id},
356            desc="add_e2e_one_time_keys_check",
357        )
358        result = {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
359        log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
360        return result
361
362    async def add_e2e_one_time_keys(
363        self,
364        user_id: str,
365        device_id: str,
366        time_now: int,
367        new_keys: Iterable[Tuple[str, str, str]],
368    ) -> None:
369        """Insert some new one time keys for a device. Errors if any of the
370        keys already exist.
371
372        Args:
373            user_id: id of user to get keys for
374            device_id: id of device to get keys for
375            time_now: insertion time to record (ms since epoch)
376            new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
377        """
378
379        def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
380            set_tag("user_id", user_id)
381            set_tag("device_id", device_id)
382            set_tag("new_keys", new_keys)
383            # We are protected from race between lookup and insertion due to
384            # a unique constraint. If there is a race of two calls to
385            # `add_e2e_one_time_keys` then they'll conflict and we will only
386            # insert one set.
387            self.db_pool.simple_insert_many_txn(
388                txn,
389                table="e2e_one_time_keys_json",
390                values=[
391                    {
392                        "user_id": user_id,
393                        "device_id": device_id,
394                        "algorithm": algorithm,
395                        "key_id": key_id,
396                        "ts_added_ms": time_now,
397                        "key_json": json_bytes,
398                    }
399                    for algorithm, key_id, json_bytes in new_keys
400                ],
401            )
402            self._invalidate_cache_and_stream(
403                txn, self.count_e2e_one_time_keys, (user_id, device_id)
404            )
405
406        await self.db_pool.runInteraction(
407            "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
408        )
409
410    @cached(max_entries=10000)
411    async def count_e2e_one_time_keys(
412        self, user_id: str, device_id: str
413    ) -> Dict[str, int]:
414        """Count the number of one time keys the server has for a device
415        Returns:
416            A mapping from algorithm to number of keys for that algorithm.
417        """
418
419        def _count_e2e_one_time_keys(txn: LoggingTransaction) -> Dict[str, int]:
420            sql = (
421                "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
422                " WHERE user_id = ? AND device_id = ?"
423                " GROUP BY algorithm"
424            )
425            txn.execute(sql, (user_id, device_id))
426
427            # Initially set the key count to 0. This ensures that the client will always
428            # receive *some count*, even if it's 0.
429            result = {DeviceKeyAlgorithms.SIGNED_CURVE25519: 0}
430
431            # Override entries with the count of any keys we pulled from the database
432            for algorithm, key_count in txn:
433                result[algorithm] = key_count
434
435            return result
436
437        return await self.db_pool.runInteraction(
438            "count_e2e_one_time_keys", _count_e2e_one_time_keys
439        )
440
441    async def set_e2e_fallback_keys(
442        self, user_id: str, device_id: str, fallback_keys: JsonDict
443    ) -> None:
444        """Set the user's e2e fallback keys.
445
446        Args:
447            user_id: the user whose keys are being set
448            device_id: the device whose keys are being set
449            fallback_keys: the keys to set.  This is a map from key ID (which is
450                of the form "algorithm:id") to key data.
451        """
452        await self.db_pool.runInteraction(
453            "set_e2e_fallback_keys_txn",
454            self._set_e2e_fallback_keys_txn,
455            user_id,
456            device_id,
457            fallback_keys,
458        )
459
460        await self.invalidate_cache_and_stream(
461            "get_e2e_unused_fallback_key_types", (user_id, device_id)
462        )
463
464    def _set_e2e_fallback_keys_txn(
465        self,
466        txn: LoggingTransaction,
467        user_id: str,
468        device_id: str,
469        fallback_keys: JsonDict,
470    ) -> None:
471        # fallback_keys will usually only have one item in it, so using a for
472        # loop (as opposed to calling simple_upsert_many_txn) won't be too bad
473        # FIXME: make sure that only one key per algorithm is uploaded
474        for key_id, fallback_key in fallback_keys.items():
475            algorithm, key_id = key_id.split(":", 1)
476            old_key_json = self.db_pool.simple_select_one_onecol_txn(
477                txn,
478                table="e2e_fallback_keys_json",
479                keyvalues={
480                    "user_id": user_id,
481                    "device_id": device_id,
482                    "algorithm": algorithm,
483                },
484                retcol="key_json",
485                allow_none=True,
486            )
487
488            new_key_json = encode_canonical_json(fallback_key).decode("utf-8")
489
490            # If the uploaded key is the same as the current fallback key,
491            # don't do anything.  This prevents marking the key as unused if it
492            # was already used.
493            if old_key_json != new_key_json:
494                self.db_pool.simple_upsert_txn(
495                    txn,
496                    table="e2e_fallback_keys_json",
497                    keyvalues={
498                        "user_id": user_id,
499                        "device_id": device_id,
500                        "algorithm": algorithm,
501                    },
502                    values={
503                        "key_id": key_id,
504                        "key_json": json_encoder.encode(fallback_key),
505                        "used": False,
506                    },
507                )
508
509    @cached(max_entries=10000)
510    async def get_e2e_unused_fallback_key_types(
511        self, user_id: str, device_id: str
512    ) -> List[str]:
513        """Returns the fallback key types that have an unused key.
514
515        Args:
516            user_id: the user whose keys are being queried
517            device_id: the device whose keys are being queried
518
519        Returns:
520            a list of key types
521        """
522        return await self.db_pool.simple_select_onecol(
523            "e2e_fallback_keys_json",
524            keyvalues={"user_id": user_id, "device_id": device_id, "used": False},
525            retcol="algorithm",
526            desc="get_e2e_unused_fallback_key_types",
527        )
528
529    async def get_e2e_cross_signing_key(
530        self, user_id: str, key_type: str, from_user_id: Optional[str] = None
531    ) -> Optional[JsonDict]:
532        """Returns a user's cross-signing key.
533
534        Args:
535            user_id: the user whose key is being requested
536            key_type: the type of key that is being requested: either 'master'
537                for a master key, 'self_signing' for a self-signing key, or
538                'user_signing' for a user-signing key
539            from_user_id: if specified, signatures made by this user on
540                the self-signing key will be included in the result
541
542        Returns:
543            dict of the key data or None if not found
544        """
545        res = await self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
546        user_keys = res.get(user_id)
547        if not user_keys:
548            return None
549        return user_keys.get(key_type)
550
551    @cached(num_args=1)
552    def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]:
553        """Dummy function.  Only used to make a cache for
554        _get_bare_e2e_cross_signing_keys_bulk.
555        """
556        raise NotImplementedError()
557
558    @cachedList(
559        cached_method_name="_get_bare_e2e_cross_signing_keys",
560        list_name="user_ids",
561        num_args=1,
562    )
563    async def _get_bare_e2e_cross_signing_keys_bulk(
564        self, user_ids: Iterable[str]
565    ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
566        """Returns the cross-signing keys for a set of users.  The output of this
567        function should be passed to _get_e2e_cross_signing_signatures_txn if
568        the signatures for the calling user need to be fetched.
569
570        Args:
571            user_ids: the users whose keys are being requested
572
573        Returns:
574            A mapping from user ID to key type to key data. If a user's cross-signing
575            keys were not found, either their user ID will not be in the dict, or
576            their user ID will map to None.
577
578        """
579        result = await self.db_pool.runInteraction(
580            "get_bare_e2e_cross_signing_keys_bulk",
581            self._get_bare_e2e_cross_signing_keys_bulk_txn,
582            user_ids,
583        )
584
585        # The `Optional` comes from the `@cachedList` decorator.
586        return cast(Dict[str, Optional[Dict[str, JsonDict]]], result)
587
588    def _get_bare_e2e_cross_signing_keys_bulk_txn(
589        self,
590        txn: LoggingTransaction,
591        user_ids: Iterable[str],
592    ) -> Dict[str, Dict[str, JsonDict]]:
593        """Returns the cross-signing keys for a set of users.  The output of this
594        function should be passed to _get_e2e_cross_signing_signatures_txn if
595        the signatures for the calling user need to be fetched.
596
597        Args:
598            txn: db connection
599            user_ids: the users whose keys are being requested
600
601        Returns:
602            Mapping from user ID to key type to key data.
603            If a user's cross-signing keys were not found, their user ID will not be in
604            the dict.
605
606        """
607        result: Dict[str, Dict[str, JsonDict]] = {}
608
609        for user_chunk in batch_iter(user_ids, 100):
610            clause, params = make_in_list_sql_clause(
611                txn.database_engine, "user_id", user_chunk
612            )
613
614            # Fetch the latest key for each type per user.
615            if isinstance(self.database_engine, PostgresEngine):
616                # The `DISTINCT ON` clause will pick the *first* row it
617                # encounters, so ordering by stream ID desc will ensure we get
618                # the latest key.
619                sql = """
620                    SELECT DISTINCT ON (user_id, keytype) user_id, keytype, keydata, stream_id
621                        FROM e2e_cross_signing_keys
622                        WHERE %(clause)s
623                        ORDER BY user_id, keytype, stream_id DESC
624                """ % {
625                    "clause": clause
626                }
627            else:
628                # SQLite has special handling for bare columns when using
629                # MIN/MAX with a `GROUP BY` clause where it picks the value from
630                # a row that matches the MIN/MAX.
631                sql = """
632                    SELECT user_id, keytype, keydata, MAX(stream_id)
633                        FROM e2e_cross_signing_keys
634                        WHERE %(clause)s
635                        GROUP BY user_id, keytype
636                """ % {
637                    "clause": clause
638                }
639
640            txn.execute(sql, params)
641            rows = self.db_pool.cursor_to_dict(txn)
642
643            for row in rows:
644                user_id = row["user_id"]
645                key_type = row["keytype"]
646                key = db_to_json(row["keydata"])
647                user_keys = result.setdefault(user_id, {})
648                user_keys[key_type] = key
649
650        return result
651
652    def _get_e2e_cross_signing_signatures_txn(
653        self,
654        txn: LoggingTransaction,
655        keys: Dict[str, Optional[Dict[str, JsonDict]]],
656        from_user_id: str,
657    ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
658        """Returns the cross-signing signatures made by a user on a set of keys.
659
660        Args:
661            txn: db connection
662            keys: a map of user ID to key type to key data.
663                This dict will be modified to add signatures.
664            from_user_id: fetch the signatures made by this user
665
666        Returns:
667            Mapping from user ID to key type to key data.
668            The return value will be the same as the keys argument, with the
669            modifications included.
670        """
671
672        # find out what cross-signing keys (a.k.a. devices) we need to get
673        # signatures for.  This is a map of (user_id, device_id) to key type
674        # (device_id is the key's public part).
675        devices: Dict[Tuple[str, str], str] = {}
676
677        for user_id, user_keys in keys.items():
678            if user_keys is None:
679                continue
680            for key_type, key in user_keys.items():
681                device_id = None
682                for k in key["keys"].values():
683                    device_id = k
684                # `key` ought to be a `CrossSigningKey`, whose .keys property is a
685                # dictionary with a single entry:
686                #     "algorithm:base64_public_key": "base64_public_key"
687                # See https://spec.matrix.org/v1.1/client-server-api/#cross-signing
688                assert isinstance(device_id, str)
689                devices[(user_id, device_id)] = key_type
690
691        for batch in batch_iter(devices.keys(), size=100):
692            sql = """
693                SELECT target_user_id, target_device_id, key_id, signature
694                  FROM e2e_cross_signing_signatures
695                 WHERE user_id = ?
696                   AND (%s)
697            """ % (
698                " OR ".join(
699                    "(target_user_id = ? AND target_device_id = ?)" for _ in batch
700                )
701            )
702            query_params = [from_user_id]
703            for item in batch:
704                # item is a (user_id, device_id) tuple
705                query_params.extend(item)
706
707            txn.execute(sql, query_params)
708            rows = self.db_pool.cursor_to_dict(txn)
709
710            # and add the signatures to the appropriate keys
711            for row in rows:
712                key_id: str = row["key_id"]
713                target_user_id: str = row["target_user_id"]
714                target_device_id: str = row["target_device_id"]
715                key_type = devices[(target_user_id, target_device_id)]
716                # We need to copy everything, because the result may have come
717                # from the cache.  dict.copy only does a shallow copy, so we
718                # need to recursively copy the dicts that will be modified.
719                user_keys = keys[target_user_id]
720                # `user_keys` cannot be `None` because we only fetched signatures for
721                # users with keys
722                assert user_keys is not None
723                user_keys = keys[target_user_id] = user_keys.copy()
724
725                target_user_key = user_keys[key_type] = user_keys[key_type].copy()
726                if "signatures" in target_user_key:
727                    signatures = target_user_key["signatures"] = target_user_key[
728                        "signatures"
729                    ].copy()
730                    if from_user_id in signatures:
731                        user_sigs = signatures[from_user_id] = signatures[from_user_id]
732                        user_sigs[key_id] = row["signature"]
733                    else:
734                        signatures[from_user_id] = {key_id: row["signature"]}
735                else:
736                    target_user_key["signatures"] = {
737                        from_user_id: {key_id: row["signature"]}
738                    }
739
740        return keys
741
742    async def get_e2e_cross_signing_keys_bulk(
743        self, user_ids: List[str], from_user_id: Optional[str] = None
744    ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
745        """Returns the cross-signing keys for a set of users.
746
747        Args:
748            user_ids: the users whose keys are being requested
749            from_user_id: if specified, signatures made by this user on
750                the self-signing keys will be included in the result
751
752        Returns:
753            A map of user ID to key type to key data.  If a user's cross-signing
754            keys were not found, either their user ID will not be in the dict,
755            or their user ID will map to None.
756        """
757
758        result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
759
760        if from_user_id:
761            result = await self.db_pool.runInteraction(
762                "get_e2e_cross_signing_signatures",
763                self._get_e2e_cross_signing_signatures_txn,
764                result,
765                from_user_id,
766            )
767
768        return result
769
770    async def get_all_user_signature_changes_for_remotes(
771        self, instance_name: str, last_id: int, current_id: int, limit: int
772    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
773        """Get updates for groups replication stream.
774
775        Note that the user signature stream represents when a user signs their
776        device with their user-signing key, which is not published to other
777        users or servers, so no `destination` is needed in the returned
778        list. However, this is needed to poke workers.
779
780        Args:
781            instance_name: The writer we want to fetch updates from. Unused
782                here since there is only ever one writer.
783            last_id: The token to fetch updates from. Exclusive.
784            current_id: The token to fetch updates up to. Inclusive.
785            limit: The requested limit for the number of rows to return. The
786                function may return more or fewer rows.
787
788        Returns:
789            A tuple consisting of: the updates, a token to use to fetch
790            subsequent updates, and whether we returned fewer rows than exists
791            between the requested tokens due to the limit.
792
793            The token returned can be used in a subsequent call to this
794            function to get further updatees.
795
796            The updates are a list of 2-tuples of stream ID and the row data
797        """
798
799        if last_id == current_id:
800            return [], current_id, False
801
802        def _get_all_user_signature_changes_for_remotes_txn(
803            txn: LoggingTransaction,
804        ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
805            sql = """
806                SELECT stream_id, from_user_id AS user_id
807                FROM user_signature_stream
808                WHERE ? < stream_id AND stream_id <= ?
809                ORDER BY stream_id ASC
810                LIMIT ?
811            """
812            txn.execute(sql, (last_id, current_id, limit))
813
814            updates = [(row[0], (row[1:])) for row in txn]
815
816            limited = False
817            upto_token = current_id
818            if len(updates) >= limit:
819                upto_token = updates[-1][0]
820                limited = True
821
822            return updates, upto_token, limited
823
824        return await self.db_pool.runInteraction(
825            "get_all_user_signature_changes_for_remotes",
826            _get_all_user_signature_changes_for_remotes_txn,
827        )
828
829    @abc.abstractmethod
830    def get_device_stream_token(self) -> int:
831        """Get the current stream id from the _device_list_id_gen"""
832        ...
833
834    async def claim_e2e_one_time_keys(
835        self, query_list: Iterable[Tuple[str, str, str]]
836    ) -> Dict[str, Dict[str, Dict[str, str]]]:
837        """Take a list of one time keys out of the database.
838
839        Args:
840            query_list: An iterable of tuples of (user ID, device ID, algorithm).
841
842        Returns:
843            A map of user ID -> a map device ID -> a map of key ID -> JSON bytes.
844        """
845
846        @trace
847        def _claim_e2e_one_time_key_simple(
848            txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
849        ) -> Optional[Tuple[str, str]]:
850            """Claim OTK for device for DBs that don't support RETURNING.
851
852            Returns:
853                A tuple of key name (algorithm + key ID) and key JSON, if an
854                OTK was found.
855            """
856
857            sql = """
858                SELECT key_id, key_json FROM e2e_one_time_keys_json
859                WHERE user_id = ? AND device_id = ? AND algorithm = ?
860                LIMIT 1
861            """
862
863            txn.execute(sql, (user_id, device_id, algorithm))
864            otk_row = txn.fetchone()
865            if otk_row is None:
866                return None
867
868            key_id, key_json = otk_row
869
870            self.db_pool.simple_delete_one_txn(
871                txn,
872                table="e2e_one_time_keys_json",
873                keyvalues={
874                    "user_id": user_id,
875                    "device_id": device_id,
876                    "algorithm": algorithm,
877                    "key_id": key_id,
878                },
879            )
880            self._invalidate_cache_and_stream(
881                txn, self.count_e2e_one_time_keys, (user_id, device_id)
882            )
883
884            return f"{algorithm}:{key_id}", key_json
885
886        @trace
887        def _claim_e2e_one_time_key_returning(
888            txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
889        ) -> Optional[Tuple[str, str]]:
890            """Claim OTK for device for DBs that support RETURNING.
891
892            Returns:
893                A tuple of key name (algorithm + key ID) and key JSON, if an
894                OTK was found.
895            """
896
897            # We can use RETURNING to do the fetch and DELETE in once step.
898            sql = """
899                DELETE FROM e2e_one_time_keys_json
900                WHERE user_id = ? AND device_id = ? AND algorithm = ?
901                    AND key_id IN (
902                        SELECT key_id FROM e2e_one_time_keys_json
903                        WHERE user_id = ? AND device_id = ? AND algorithm = ?
904                        LIMIT 1
905                    )
906                RETURNING key_id, key_json
907            """
908
909            txn.execute(
910                sql, (user_id, device_id, algorithm, user_id, device_id, algorithm)
911            )
912            otk_row = txn.fetchone()
913            if otk_row is None:
914                return None
915
916            self._invalidate_cache_and_stream(
917                txn, self.count_e2e_one_time_keys, (user_id, device_id)
918            )
919
920            key_id, key_json = otk_row
921            return f"{algorithm}:{key_id}", key_json
922
923        results: Dict[str, Dict[str, Dict[str, str]]] = {}
924        for user_id, device_id, algorithm in query_list:
925            if self.database_engine.supports_returning:
926                # If we support RETURNING clause we can use a single query that
927                # allows us to use autocommit mode.
928                _claim_e2e_one_time_key = _claim_e2e_one_time_key_returning
929                db_autocommit = True
930            else:
931                _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
932                db_autocommit = False
933
934            row = await self.db_pool.runInteraction(
935                "claim_e2e_one_time_keys",
936                _claim_e2e_one_time_key,
937                user_id,
938                device_id,
939                algorithm,
940                db_autocommit=db_autocommit,
941            )
942            if row:
943                device_results = results.setdefault(user_id, {}).setdefault(
944                    device_id, {}
945                )
946                device_results[row[0]] = row[1]
947                continue
948
949            # No one-time key available, so see if there's a fallback
950            # key
951            row = await self.db_pool.simple_select_one(
952                table="e2e_fallback_keys_json",
953                keyvalues={
954                    "user_id": user_id,
955                    "device_id": device_id,
956                    "algorithm": algorithm,
957                },
958                retcols=("key_id", "key_json", "used"),
959                desc="_get_fallback_key",
960                allow_none=True,
961            )
962            if row is None:
963                continue
964
965            key_id = row["key_id"]
966            key_json = row["key_json"]
967            used = row["used"]
968
969            # Mark fallback key as used if not already.
970            if not used:
971                await self.db_pool.simple_update_one(
972                    table="e2e_fallback_keys_json",
973                    keyvalues={
974                        "user_id": user_id,
975                        "device_id": device_id,
976                        "algorithm": algorithm,
977                        "key_id": key_id,
978                    },
979                    updatevalues={"used": True},
980                    desc="_get_fallback_key_set_used",
981                )
982                await self.invalidate_cache_and_stream(
983                    "get_e2e_unused_fallback_key_types", (user_id, device_id)
984                )
985
986            device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
987            device_results[f"{algorithm}:{key_id}"] = key_json
988
989        return results
990
991
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
    """Extends `EndToEndKeyWorkerStore` with methods that write and delete
    end-to-end encryption key data (device keys, cross-signing keys and
    cross-signing signatures)."""

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        # Allocates the `stream_id` recorded with each stored cross-signing key.
        self._cross_signing_id_gen = StreamIdGenerator(
            db_conn, "e2e_cross_signing_keys", "stream_id"
        )

    async def set_e2e_device_keys(
        self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
    ) -> bool:
        """Stores device keys for a device.

        Args:
            user_id: the owner of the device
            device_id: the device whose keys are being stored
            time_now: stored as `ts_added_ms` alongside the keys
            device_keys: the device key data to store

        Returns:
            True if the stored keys changed; False if identical keys (compared
            by canonical JSON) were already in the database.
        """

        def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
            set_tag("user_id", user_id)
            set_tag("device_id", device_id)
            set_tag("time_now", time_now)
            set_tag("device_keys", device_keys)

            old_key_json = self.db_pool.simple_select_one_onecol_txn(
                txn,
                table="e2e_device_keys_json",
                keyvalues={"user_id": user_id, "device_id": device_id},
                retcol="key_json",
                allow_none=True,
            )

            # The DB returns `str`, whereas encode_canonical_json returns
            # `bytes`; decode so that the two can be compared directly.
            new_key_json = encode_canonical_json(device_keys).decode("utf-8")

            if old_key_json == new_key_json:
                log_kv({"message": "Device key already stored."})
                return False

            self.db_pool.simple_upsert_txn(
                txn,
                table="e2e_device_keys_json",
                keyvalues={"user_id": user_id, "device_id": device_id},
                values={"ts_added_ms": time_now, "key_json": new_key_json},
            )
            log_kv({"message": "Device keys stored."})
            return True

        return await self.db_pool.runInteraction(
            "set_e2e_device_keys", _set_e2e_device_keys_txn
        )

    async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
        """Delete all end-to-end key data held for a device.

        Removes the device's rows from `e2e_device_keys_json`,
        `e2e_one_time_keys_json`, `dehydrated_devices` and
        `e2e_fallback_keys_json`. It also invalidates the one-time key count
        and unused fallback key type caches for the device.

        Args:
            user_id: the owner of the device
            device_id: the device whose keys should be deleted
        """

        def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
            log_kv(
                {
                    "message": "Deleting keys for device",
                    "device_id": device_id,
                    "user_id": user_id,
                }
            )
            self.db_pool.simple_delete_txn(
                txn,
                table="e2e_device_keys_json",
                keyvalues={"user_id": user_id, "device_id": device_id},
            )
            self.db_pool.simple_delete_txn(
                txn,
                table="e2e_one_time_keys_json",
                keyvalues={"user_id": user_id, "device_id": device_id},
            )
            self._invalidate_cache_and_stream(
                txn, self.count_e2e_one_time_keys, (user_id, device_id)
            )
            self.db_pool.simple_delete_txn(
                txn,
                table="dehydrated_devices",
                keyvalues={"user_id": user_id, "device_id": device_id},
            )
            self.db_pool.simple_delete_txn(
                txn,
                table="e2e_fallback_keys_json",
                keyvalues={"user_id": user_id, "device_id": device_id},
            )
            self._invalidate_cache_and_stream(
                txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
            )

        await self.db_pool.runInteraction(
            "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
        )

    def _set_e2e_cross_signing_key_txn(
        self,
        txn: LoggingTransaction,
        user_id: str,
        key_type: str,
        key: JsonDict,
        stream_id: int,
    ) -> None:
        """Set a user's cross-signing key.

        Args:
            txn: db connection
            user_id: the user to set the signing key for
            key_type: the type of key that is being set: either 'master'
                for a master key, 'self_signing' for a self-signing key, or
                'user_signing' for a user-signing key
            key: the key data
            stream_id: the stream ID to record alongside the stored key
        """
        # the 'key' dict will look something like:
        # {
        #   "user_id": "@alice:example.com",
        #   "usage": ["self_signing"],
        #   "keys": {
        #     "ed25519:base64+self+signing+public+key": "base64+self+signing+public+key",
        #   },
        #   "signatures": {
        #     "@alice:example.com": {
        #       "ed25519:base64+master+public+key": "base64+signature"
        #     }
        #   }
        # }
        # The "keys" property must only have one entry, which will be the public
        # key, so we just grab the first value in there
        pubkey = next(iter(key["keys"].values()))

        # The cross-signing keys need to occupy the same namespace as devices,
        # since signatures are identified by device ID.  So add an entry to the
        # device table to make sure that we don't have a collision with device
        # IDs.
        # We only need to do this for local users, since remote servers should be
        # responsible for checking this for their own users.
        if self.hs.is_mine_id(user_id):
            self.db_pool.simple_insert_txn(
                txn,
                "devices",
                values={
                    "user_id": user_id,
                    "device_id": pubkey,
                    "display_name": key_type + " signing key",
                    "hidden": True,
                },
            )

        # and finally, store the key itself
        self.db_pool.simple_insert_txn(
            txn,
            "e2e_cross_signing_keys",
            values={
                "user_id": user_id,
                "keytype": key_type,
                "keydata": json_encoder.encode(key),
                "stream_id": stream_id,
            },
        )

        self._invalidate_cache_and_stream(
            txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
        )

    async def set_e2e_cross_signing_key(
        self, user_id: str, key_type: str, key: JsonDict
    ) -> None:
        """Set a user's cross-signing key.

        Args:
            user_id: the user to set the cross-signing key for
            key_type: the type of cross-signing key to set: 'master',
                'self_signing' or 'user_signing'
            key: the key data
        """

        # Allocate a stream ID and store the key under it.
        async with self._cross_signing_id_gen.get_next() as stream_id:
            await self.db_pool.runInteraction(
                "add_e2e_cross_signing_key",
                self._set_e2e_cross_signing_key_txn,
                user_id,
                key_type,
                key,
                stream_id,
            )

    async def store_e2e_cross_signing_signatures(
        self, user_id: str, signatures: "Iterable[SignatureListItem]"
    ) -> None:
        """Stores cross-signing signatures.

        Args:
            user_id: the user who made the signatures
            signatures: signatures to add
        """
        await self.db_pool.simple_insert_many(
            "e2e_cross_signing_signatures",
            [
                {
                    "user_id": user_id,
                    "key_id": item.signing_key_id,
                    "target_user_id": item.target_user_id,
                    "target_device_id": item.target_device_id,
                    "signature": item.signature,
                }
                for item in signatures
            ],
            "add_e2e_signing_key",
        )
1201