# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import (
    TYPE_CHECKING,
    Collection,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    cast,
)

import attr
from canonicaljson import encode_canonical_json

from synapse.api.constants import DeviceKeyAlgorithms
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
    DatabasePool,
    LoggingDatabaseConnection,
    LoggingTransaction,
    make_in_list_sql_clause,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter

if TYPE_CHECKING:
    from synapse.handlers.e2e_keys import SignatureListItem
    from synapse.server import HomeServer


@attr.s(slots=True)
class DeviceKeyLookupResult:
    """The type returned by get_e2e_device_keys_and_signatures"""

    # the device's display name, if any (read from the `devices` table)
    display_name = attr.ib(type=Optional[str])

    # the key data from e2e_device_keys_json. Typically includes fields like
    # "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
    # key) and "signatures" (a map from (user id) to (key id/device_id) to signature.)
    keys = attr.ib(type=Optional[JsonDict])


class EndToEndKeyBackgroundStore(SQLBaseStore):
    """Registers the background updates used by the end-to-end key tables."""

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self.db_pool.updates.register_background_index_update(
            "e2e_cross_signing_keys_idx",
            index_name="e2e_cross_signing_keys_stream_idx",
            table="e2e_cross_signing_keys",
            columns=["stream_id"],
            unique=True,
        )


class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore):
    """Read access (plus one-time/fallback key management) for end-to-end keys."""

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self._allow_device_name_lookup_over_federation = (
            self.hs.config.federation.allow_device_name_lookup_over_federation
        )

    async def get_e2e_device_keys_for_federation_query(
        self, user_id: str
    ) -> Tuple[int, List[JsonDict]]:
        """Get all devices (with any device keys) for a user

        Device display names are only included when
        `allow_device_name_lookup_over_federation` is enabled.

        Returns:
            (stream_id, devices)
        """
        # NOTE(review): the stream token is read *before* the devices are
        # fetched, presumably so that any concurrent change is covered by a
        # later stream position — confirm with the device-list code.
        now_stream_id = self.get_device_stream_token()

        devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])

        if devices:
            user_devices = devices[user_id]
            results = []
            for device_id, device in user_devices.items():
                result = {"device_id": device_id}

                keys = device.keys
                if keys:
                    result["keys"] = keys

                device_display_name = None
                if self._allow_device_name_lookup_over_federation:
                    device_display_name = device.display_name
                if device_display_name:
                    result["device_display_name"] = device_display_name

                results.append(result)

            return now_stream_id, results

        return now_stream_id, []

    @trace
    async def get_e2e_device_keys_for_cs_api(
        self, query_list: List[Tuple[str, Optional[str]]]
    ) -> Dict[str, Dict[str, JsonDict]]:
        """Fetch a list of device keys, formatted suitably for the C/S API.

        Args:
            query_list: List of pairs of user_ids and device_ids. A device_id of
                None means "all devices for this user".

        Returns:
            Dict mapping from user-id to dict mapping from device_id to
            key data. The key data will be a dict in the same format as the
            DeviceKeys type returned by POST /_matrix/client/r0/keys/query.
        """
        set_tag("query_list", query_list)
        if not query_list:
            return {}

        results = await self.get_e2e_device_keys_and_signatures(query_list)

        # Build the result structure, un-jsonify the results, and add the
        # "unsigned" section
        rv: Dict[str, Dict[str, JsonDict]] = {}
        for user_id, device_keys in results.items():
            rv[user_id] = {}
            for device_id, device_info in device_keys.items():
                r = device_info.keys
                r["unsigned"] = {}
                display_name = device_info.display_name
                if display_name is not None:
                    r["unsigned"]["device_display_name"] = display_name
                rv[user_id][device_id] = r

        return rv

    @trace
    async def get_e2e_device_keys_and_signatures(
        self,
        query_list: List[Tuple[str, Optional[str]]],
        include_all_devices: bool = False,
        include_deleted_devices: bool = False,
    ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
        """Fetch a list of device keys

        Any cross-signatures made on the keys by the owner of the device are also
        included.

        The cross-signatures are added to the `signatures` field within the `keys`
        object in the response.

        Hidden devices are never returned.

        Args:
            query_list: List of pairs of user_ids and device_ids. Device id can be None
                to indicate "all devices for this user"

            include_all_devices: whether to return devices without device keys

            include_deleted_devices: whether to include null entries for
                devices which no longer exist (but were in the query_list).
                This option only takes effect if include_all_devices is true.

        Returns:
            Dict mapping from user-id to dict mapping from device_id to
            key data.
        """
        set_tag("include_all_devices", include_all_devices)
        set_tag("include_deleted_devices", include_deleted_devices)

        result = await self.db_pool.runInteraction(
            "get_e2e_device_keys",
            self._get_e2e_device_keys_txn,
            query_list,
            include_all_devices,
            include_deleted_devices,
        )

        # get the (user_id, device_id) tuples to look up cross-signatures for
        signature_query = (
            (user_id, device_id)
            for user_id, dev in result.items()
            for device_id, d in dev.items()
            if d is not None and d.keys is not None
        )

        for batch in batch_iter(signature_query, 50):
            cross_sigs_result = await self.db_pool.runInteraction(
                "get_e2e_cross_signing_signatures",
                self._get_e2e_cross_signing_signatures_for_devices_txn,
                batch,
            )

            # add each cross-signing signature to the correct device in the result dict.
            for (user_id, key_id, device_id, signature) in cross_sigs_result:
                target_device_result = result[user_id][device_id]
                # We've only looked up cross-signatures for non-deleted devices with key
                # data.
                assert target_device_result is not None
                assert target_device_result.keys is not None
                target_device_signatures = target_device_result.keys.setdefault(
                    "signatures", {}
                )
                signing_user_signatures = target_device_signatures.setdefault(
                    user_id, {}
                )
                signing_user_signatures[key_id] = signature

        log_kv(result)
        return result

    def _get_e2e_device_keys_txn(
        self,
        txn: LoggingTransaction,
        query_list: Collection[Tuple[str, Optional[str]]],
        include_all_devices: bool = False,
        include_deleted_devices: bool = False,
    ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
        """Get information on devices from the database

        The results include the device's keys and self-signatures, but *not* any
        cross-signing signatures which have been added subsequently (for which, see
        get_e2e_device_keys_and_signatures)

        Args:
            txn: db connection
            query_list: pairs of (user_id, device_id); a device_id of None
                selects every device of that user.
            include_all_devices: if true, devices without device keys are also
                returned (with `keys` set to None).
            include_deleted_devices: if true (and include_all_devices is true),
                explicitly-queried devices that were not found (including
                hidden ones) are returned as None.

        Returns:
            Dict mapping from user-id to dict mapping from device_id to
            lookup result (or None for a missing device).
        """
        # With nothing to look up the WHERE clause below would be empty, which
        # is invalid SQL.
        if not query_list:
            return {}

        query_clauses = []
        query_params = []

        if not include_all_devices:
            include_deleted_devices = False

        # Queried pairs not (yet) seen in the results; whatever remains at the
        # end is reported as deleted.
        deleted_devices = set(query_list) if include_deleted_devices else set()

        for (user_id, device_id) in query_list:
            query_clause = "user_id = ?"
            query_params.append(user_id)

            if device_id is not None:
                query_clause += " AND device_id = ?"
                query_params.append(device_id)

            query_clauses.append(query_clause)

        # The OR-ed clauses must be parenthesised as a group: AND binds more
        # tightly than OR, so without the brackets `NOT d.hidden` would only
        # apply to the last clause.
        sql = (
            "SELECT user_id, device_id, "
            " d.display_name, "
            " k.key_json"
            " FROM devices d"
            " %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
            " WHERE (%s) AND NOT d.hidden"
        ) % (
            "LEFT" if include_all_devices else "INNER",
            " OR ".join("(" + q + ")" for q in query_clauses),
        )

        txn.execute(sql, query_params)

        result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {}
        for (user_id, device_id, display_name, key_json) in txn:
            if include_deleted_devices:
                # `discard`, not `remove`: a row matched by an "all devices"
                # entry (device_id None) has no exact pair in the set.
                deleted_devices.discard((user_id, device_id))
            result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
                display_name, db_to_json(key_json) if key_json else None
            )

        for user_id, device_id in deleted_devices:
            # An "all devices" entry names no specific device, so there is no
            # device to report as deleted for it.
            if device_id is None:
                continue
            result.setdefault(user_id, {})[device_id] = None

        return result

    def _get_e2e_cross_signing_signatures_for_devices_txn(
        self, txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
    ) -> List[Tuple[str, str, str, str]]:
        """Get cross-signing signatures for a given list of devices

        Returns signatures made by the owners of the devices.

        Returns: a list of results; each entry in the list is a tuple of
            (user_id, key_id, target_device_id, signature).
        """
        signature_query_clauses = []
        signature_query_params = []

        for (user_id, device_id) in device_query:
            signature_query_clauses.append(
                "target_user_id = ? AND target_device_id = ? AND user_id = ?"
            )
            signature_query_params.extend([user_id, device_id, user_id])

        # Nothing to look up; an empty WHERE clause would be invalid SQL.
        if not signature_query_clauses:
            return []

        signature_sql = """
            SELECT user_id, key_id, target_device_id, signature
            FROM e2e_cross_signing_signatures WHERE %s
            """ % (
            " OR ".join("(" + q + ")" for q in signature_query_clauses)
        )

        txn.execute(signature_sql, signature_query_params)
        return cast(
            List[
                Tuple[
                    str,
                    str,
                    str,
                    str,
                ]
            ],
            txn.fetchall(),
        )

    async def get_e2e_one_time_keys(
        self, user_id: str, device_id: str, key_ids: List[str]
    ) -> Dict[Tuple[str, str], str]:
        """Retrieve a number of one-time keys for a user

        Args:
            user_id: id of user to get keys for
            device_id: id of device to get keys for
            key_ids: list of key ids (excluding algorithm) to retrieve

        Returns:
            A map from (algorithm, key_id) to json string for key
        """

        rows = await self.db_pool.simple_select_many_batch(
            table="e2e_one_time_keys_json",
            column="key_id",
            iterable=key_ids,
            retcols=("algorithm", "key_id", "key_json"),
            keyvalues={"user_id": user_id, "device_id": device_id},
            desc="add_e2e_one_time_keys_check",
        )
        result = {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
        log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
        return result

    async def add_e2e_one_time_keys(
        self,
        user_id: str,
        device_id: str,
        time_now: int,
        new_keys: Iterable[Tuple[str, str, str]],
    ) -> None:
        """Insert some new one time keys for a device. Errors if any of the
        keys already exist.

        Args:
            user_id: id of user to add keys for
            device_id: id of device to add keys for
            time_now: insertion time to record (ms since epoch)
            new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
        """

        def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
            set_tag("user_id", user_id)
            set_tag("device_id", device_id)
            set_tag("new_keys", new_keys)
            # We are protected from race between lookup and insertion due to
            # a unique constraint. If there is a race of two calls to
            # `add_e2e_one_time_keys` then they'll conflict and we will only
            # insert one set.
            self.db_pool.simple_insert_many_txn(
                txn,
                table="e2e_one_time_keys_json",
                values=[
                    {
                        "user_id": user_id,
                        "device_id": device_id,
                        "algorithm": algorithm,
                        "key_id": key_id,
                        "ts_added_ms": time_now,
                        "key_json": json_bytes,
                    }
                    for algorithm, key_id, json_bytes in new_keys
                ],
            )
            self._invalidate_cache_and_stream(
                txn, self.count_e2e_one_time_keys, (user_id, device_id)
            )

        await self.db_pool.runInteraction(
            "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
        )

    @cached(max_entries=10000)
    async def count_e2e_one_time_keys(
        self, user_id: str, device_id: str
    ) -> Dict[str, int]:
        """Count the number of one time keys the server has for a device

        Returns:
            A mapping from algorithm to number of keys for that algorithm.
        """

        def _count_e2e_one_time_keys(txn: LoggingTransaction) -> Dict[str, int]:
            sql = (
                "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
                " WHERE user_id = ? AND device_id = ?"
                " GROUP BY algorithm"
            )
            txn.execute(sql, (user_id, device_id))

            # Initially set the key count to 0. This ensures that the client will always
            # receive *some count*, even if it's 0.
            result = {DeviceKeyAlgorithms.SIGNED_CURVE25519: 0}

            # Override entries with the count of any keys we pulled from the database
            for algorithm, key_count in txn:
                result[algorithm] = key_count

            return result

        return await self.db_pool.runInteraction(
            "count_e2e_one_time_keys", _count_e2e_one_time_keys
        )

    async def set_e2e_fallback_keys(
        self, user_id: str, device_id: str, fallback_keys: JsonDict
    ) -> None:
        """Set the user's e2e fallback keys.

        Args:
            user_id: the user whose keys are being set
            device_id: the device whose keys are being set
            fallback_keys: the keys to set. This is a map from key ID (which is
                of the form "algorithm:id") to key data.
        """
        await self.db_pool.runInteraction(
            "set_e2e_fallback_keys_txn",
            self._set_e2e_fallback_keys_txn,
            user_id,
            device_id,
            fallback_keys,
        )

        await self.invalidate_cache_and_stream(
            "get_e2e_unused_fallback_key_types", (user_id, device_id)
        )

    def _set_e2e_fallback_keys_txn(
        self,
        txn: LoggingTransaction,
        user_id: str,
        device_id: str,
        fallback_keys: JsonDict,
    ) -> None:
        """Store fallback keys (one row per algorithm), unless unchanged.

        Args:
            txn: db connection
            user_id: the user whose keys are being set
            device_id: the device whose keys are being set
            fallback_keys: map from "algorithm:key_id" to key data.
        """
        # fallback_keys will usually only have one item in it, so using a for
        # loop (as opposed to calling simple_upsert_many_txn) won't be too bad
        # FIXME: make sure that only one key per algorithm is uploaded
        for full_key_id, fallback_key in fallback_keys.items():
            algorithm, key_id = full_key_id.split(":", 1)
            old_key_json = self.db_pool.simple_select_one_onecol_txn(
                txn,
                table="e2e_fallback_keys_json",
                keyvalues={
                    "user_id": user_id,
                    "device_id": device_id,
                    "algorithm": algorithm,
                },
                retcol="key_json",
                allow_none=True,
            )

            new_key_json = encode_canonical_json(fallback_key).decode("utf-8")

            # If the uploaded key is the same as the current fallback key,
            # don't do anything.  This prevents marking the key as unused if it
            # was already used.
            if old_key_json != new_key_json:
                self.db_pool.simple_upsert_txn(
                    txn,
                    table="e2e_fallback_keys_json",
                    keyvalues={
                        "user_id": user_id,
                        "device_id": device_id,
                        "algorithm": algorithm,
                    },
                    values={
                        "key_id": key_id,
                        # Store exactly the (canonical) text compared above, so
                        # that re-uploading an identical key is recognised as
                        # unchanged next time.
                        "key_json": new_key_json,
                        "used": False,
                    },
                )

    @cached(max_entries=10000)
    async def get_e2e_unused_fallback_key_types(
        self, user_id: str, device_id: str
    ) -> List[str]:
        """Returns the fallback key types that have an unused key.

        Args:
            user_id: the user whose keys are being queried
            device_id: the device whose keys are being queried

        Returns:
            a list of key types
        """
        return await self.db_pool.simple_select_onecol(
            "e2e_fallback_keys_json",
            keyvalues={"user_id": user_id, "device_id": device_id, "used": False},
            retcol="algorithm",
            desc="get_e2e_unused_fallback_key_types",
        )

    async def get_e2e_cross_signing_key(
        self, user_id: str, key_type: str, from_user_id: Optional[str] = None
    ) -> Optional[JsonDict]:
        """Returns a user's cross-signing key.

        Args:
            user_id: the user whose key is being requested
            key_type: the type of key that is being requested: either 'master'
                for a master key, 'self_signing' for a self-signing key, or
                'user_signing' for a user-signing key
            from_user_id: if specified, signatures made by this user on
                the self-signing key will be included in the result

        Returns:
            dict of the key data or None if not found
        """
        res = await self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
        user_keys = res.get(user_id)
        if not user_keys:
            return None
        return user_keys.get(key_type)

    @cached(num_args=1)
    def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]:
        """Dummy function.  Only used to make a cache for
        _get_bare_e2e_cross_signing_keys_bulk.
        """
        raise NotImplementedError()

    @cachedList(
        cached_method_name="_get_bare_e2e_cross_signing_keys",
        list_name="user_ids",
        num_args=1,
    )
    async def _get_bare_e2e_cross_signing_keys_bulk(
        self, user_ids: Iterable[str]
    ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
        """Returns the cross-signing keys for a set of users.  The output of this
        function should be passed to _get_e2e_cross_signing_signatures_txn if
        the signatures for the calling user need to be fetched.

        Args:
            user_ids: the users whose keys are being requested

        Returns:
            A mapping from user ID to key type to key data. If a user's cross-signing
            keys were not found, either their user ID will not be in the dict, or
            their user ID will map to None.

        """
        result = await self.db_pool.runInteraction(
            "get_bare_e2e_cross_signing_keys_bulk",
            self._get_bare_e2e_cross_signing_keys_bulk_txn,
            user_ids,
        )

        # The `Optional` comes from the `@cachedList` decorator.
        return cast(Dict[str, Optional[Dict[str, JsonDict]]], result)

    def _get_bare_e2e_cross_signing_keys_bulk_txn(
        self,
        txn: LoggingTransaction,
        user_ids: Iterable[str],
    ) -> Dict[str, Dict[str, JsonDict]]:
        """Returns the cross-signing keys for a set of users.  The output of this
        function should be passed to _get_e2e_cross_signing_signatures_txn if
        the signatures for the calling user need to be fetched.

        Args:
            txn: db connection
            user_ids: the users whose keys are being requested

        Returns:
            Mapping from user ID to key type to key data.
            If a user's cross-signing keys were not found, their user ID will not be in
            the dict.

        """
        result: Dict[str, Dict[str, JsonDict]] = {}

        for user_chunk in batch_iter(user_ids, 100):
            clause, params = make_in_list_sql_clause(
                txn.database_engine, "user_id", user_chunk
            )

            # Fetch the latest key for each type per user.
            if isinstance(self.database_engine, PostgresEngine):
                # The `DISTINCT ON` clause will pick the *first* row it
                # encounters, so ordering by stream ID desc will ensure we get
                # the latest key.
                sql = """
                    SELECT DISTINCT ON (user_id, keytype) user_id, keytype, keydata, stream_id
                        FROM e2e_cross_signing_keys
                        WHERE %(clause)s
                        ORDER BY user_id, keytype, stream_id DESC
                """ % {
                    "clause": clause
                }
            else:
                # SQLite has special handling for bare columns when using
                # MIN/MAX with a `GROUP BY` clause where it picks the value from
                # a row that matches the MIN/MAX.
                sql = """
                    SELECT user_id, keytype, keydata, MAX(stream_id)
                        FROM e2e_cross_signing_keys
                        WHERE %(clause)s
                        GROUP BY user_id, keytype
                """ % {
                    "clause": clause
                }

            txn.execute(sql, params)
            rows = self.db_pool.cursor_to_dict(txn)

            for row in rows:
                user_id = row["user_id"]
                key_type = row["keytype"]
                key = db_to_json(row["keydata"])
                user_keys = result.setdefault(user_id, {})
                user_keys[key_type] = key

        return result

    def _get_e2e_cross_signing_signatures_txn(
        self,
        txn: LoggingTransaction,
        keys: Dict[str, Optional[Dict[str, JsonDict]]],
        from_user_id: str,
    ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
        """Returns the cross-signing signatures made by a user on a set of keys.

        Args:
            txn: db connection
            keys: a map of user ID to key type to key data.
                This dict will be modified to add signatures; the per-user and
                per-key dicts it contains are replaced by copies rather than
                mutated, since they may be shared with the cache.
            from_user_id: fetch the signatures made by this user

        Returns:
            Mapping from user ID to key type to key data.
            The return value will be the same as the keys argument, with the
            modifications included.
        """

        # find out what cross-signing keys (a.k.a. devices) we need to get
        # signatures for.  This is a map of (user_id, device_id) to key type
        # (device_id is the key's public part).
        devices: Dict[Tuple[str, str], str] = {}

        for user_id, user_keys in keys.items():
            if user_keys is None:
                continue
            for key_type, key in user_keys.items():
                device_id = None
                for k in key["keys"].values():
                    device_id = k
                # `key` ought to be a `CrossSigningKey`, whose .keys property is a
                # dictionary with a single entry:
                #     "algorithm:base64_public_key": "base64_public_key"
                # See https://spec.matrix.org/v1.1/client-server-api/#cross-signing
                assert isinstance(device_id, str)
                devices[(user_id, device_id)] = key_type

        for batch in batch_iter(devices.keys(), size=100):
            sql = """
                SELECT target_user_id, target_device_id, key_id, signature
                FROM e2e_cross_signing_signatures
                WHERE user_id = ?
                AND (%s)
            """ % (
                " OR ".join(
                    "(target_user_id = ? AND target_device_id = ?)" for _ in batch
                )
            )
            query_params = [from_user_id]
            for item in batch:
                # item is a (user_id, device_id) tuple
                query_params.extend(item)

            txn.execute(sql, query_params)
            rows = self.db_pool.cursor_to_dict(txn)

            # and add the signatures to the appropriate keys
            for row in rows:
                key_id: str = row["key_id"]
                target_user_id: str = row["target_user_id"]
                target_device_id: str = row["target_device_id"]
                key_type = devices[(target_user_id, target_device_id)]
                # We need to copy everything, because the result may have come
                # from the cache.  dict.copy only does a shallow copy, so every
                # dict on the path down to the one being modified is copied.
                user_keys = keys[target_user_id]
                # `user_keys` cannot be `None` because we only fetched signatures for
                # users with keys
                assert user_keys is not None
                user_keys = keys[target_user_id] = user_keys.copy()

                target_user_key = user_keys[key_type] = user_keys[key_type].copy()
                signatures = target_user_key["signatures"] = target_user_key.get(
                    "signatures", {}
                ).copy()
                user_sigs = signatures[from_user_id] = signatures.get(
                    from_user_id, {}
                ).copy()
                user_sigs[key_id] = row["signature"]

        return keys

    async def get_e2e_cross_signing_keys_bulk(
        self, user_ids: List[str], from_user_id: Optional[str] = None
    ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
        """Returns the cross-signing keys for a set of users.

        Args:
            user_ids: the users whose keys are being requested
            from_user_id: if specified, signatures made by this user on
                the self-signing keys will be included in the result

        Returns:
            A map of user ID to key type to key data.  If a user's cross-signing
            keys were not found, either their user ID will not be in the dict,
            or their user ID will map to None.
        """

        result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)

        if from_user_id:
            result = await self.db_pool.runInteraction(
                "get_e2e_cross_signing_signatures",
                self._get_e2e_cross_signing_signatures_txn,
                result,
                from_user_id,
            )

        return result

    async def get_all_user_signature_changes_for_remotes(
        self, instance_name: str, last_id: int, current_id: int, limit: int
    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
        """Get updates for the user signature replication stream.

        Note that the user signature stream represents when a user signs their
        device with their user-signing key, which is not published to other
        users or servers, so no `destination` is needed in the returned
        list. However, this is needed to poke workers.

        Args:
            instance_name: The writer we want to fetch updates from. Unused
                here since there is only ever one writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exists
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data
        """

        if last_id == current_id:
            return [], current_id, False

        def _get_all_user_signature_changes_for_remotes_txn(
            txn: LoggingTransaction,
        ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
            sql = """
                SELECT stream_id, from_user_id AS user_id
                FROM user_signature_stream
                WHERE ? < stream_id AND stream_id <= ?
                ORDER BY stream_id ASC
                LIMIT ?
            """
            txn.execute(sql, (last_id, current_id, limit))

            updates = [(row[0], (row[1:])) for row in txn]

            limited = False
            upto_token = current_id
            if len(updates) >= limit:
                upto_token = updates[-1][0]
                limited = True

            return updates, upto_token, limited

        return await self.db_pool.runInteraction(
            "get_all_user_signature_changes_for_remotes",
            _get_all_user_signature_changes_for_remotes_txn,
        )

    @abc.abstractmethod
    def get_device_stream_token(self) -> int:
        """Get the current stream id from the _device_list_id_gen"""
        ...

    async def claim_e2e_one_time_keys(
        self, query_list: Iterable[Tuple[str, str, str]]
    ) -> Dict[str, Dict[str, Dict[str, str]]]:
        """Take a list of one time keys out of the database.

        For each requested (user, device, algorithm) one one-time key is
        deleted and returned. If none is left, the device's fallback key for
        that algorithm (if any) is returned instead, and marked as used.

        Args:
            query_list: An iterable of tuples of (user ID, device ID, algorithm).

        Returns:
            A map of user ID -> a map device ID -> a map of key ID -> key JSON
            string, where key IDs have the form "algorithm:key_id".
        """

        @trace
        def _claim_e2e_one_time_key_simple(
            txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
        ) -> Optional[Tuple[str, str]]:
            """Claim OTK for device for DBs that don't support RETURNING.

            Returns:
                A tuple of key name (algorithm + key ID) and key JSON, if an
                OTK was found.
            """

            sql = """
                SELECT key_id, key_json FROM e2e_one_time_keys_json
                WHERE user_id = ? AND device_id = ? AND algorithm = ?
                LIMIT 1
            """

            txn.execute(sql, (user_id, device_id, algorithm))
            otk_row = txn.fetchone()
            if otk_row is None:
                return None

            key_id, key_json = otk_row

            self.db_pool.simple_delete_one_txn(
                txn,
                table="e2e_one_time_keys_json",
                keyvalues={
                    "user_id": user_id,
                    "device_id": device_id,
                    "algorithm": algorithm,
                    "key_id": key_id,
                },
            )
            self._invalidate_cache_and_stream(
                txn, self.count_e2e_one_time_keys, (user_id, device_id)
            )

            return f"{algorithm}:{key_id}", key_json

        @trace
        def _claim_e2e_one_time_key_returning(
            txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
        ) -> Optional[Tuple[str, str]]:
            """Claim OTK for device for DBs that support RETURNING.

            Returns:
                A tuple of key name (algorithm + key ID) and key JSON, if an
                OTK was found.
            """

            # We can use RETURNING to do the fetch and DELETE in one step.
            sql = """
                DELETE FROM e2e_one_time_keys_json
                WHERE user_id = ? AND device_id = ? AND algorithm = ?
                    AND key_id IN (
                        SELECT key_id FROM e2e_one_time_keys_json
                        WHERE user_id = ? AND device_id = ? AND algorithm = ?
                        LIMIT 1
                    )
                RETURNING key_id, key_json
            """

            txn.execute(
                sql, (user_id, device_id, algorithm, user_id, device_id, algorithm)
            )
            otk_row = txn.fetchone()
            if otk_row is None:
                return None

            self._invalidate_cache_and_stream(
                txn, self.count_e2e_one_time_keys, (user_id, device_id)
            )

            key_id, key_json = otk_row
            return f"{algorithm}:{key_id}", key_json

        # The database engine does not change between iterations, so choose the
        # claim strategy once rather than per query.
        if self.database_engine.supports_returning:
            # If we support RETURNING clause we can use a single query that
            # allows us to use autocommit mode.
            _claim_e2e_one_time_key = _claim_e2e_one_time_key_returning
            db_autocommit = True
        else:
            _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
            db_autocommit = False

        results: Dict[str, Dict[str, Dict[str, str]]] = {}
        for user_id, device_id, algorithm in query_list:
            otk = await self.db_pool.runInteraction(
                "claim_e2e_one_time_keys",
                _claim_e2e_one_time_key,
                user_id,
                device_id,
                algorithm,
                db_autocommit=db_autocommit,
            )
            if otk:
                device_results = results.setdefault(user_id, {}).setdefault(
                    device_id, {}
                )
                device_results[otk[0]] = otk[1]
                continue

            # No one-time key available, so see if there's a fallback
            # key
            fallback_row = await self.db_pool.simple_select_one(
                table="e2e_fallback_keys_json",
                keyvalues={
                    "user_id": user_id,
                    "device_id": device_id,
                    "algorithm": algorithm,
                },
                retcols=("key_id", "key_json", "used"),
                desc="_get_fallback_key",
                allow_none=True,
            )
            if fallback_row is None:
                continue

            key_id = fallback_row["key_id"]
            key_json = fallback_row["key_json"]
            used = fallback_row["used"]

            # Mark fallback key as used if not already.
            if not used:
                await self.db_pool.simple_update_one(
                    table="e2e_fallback_keys_json",
                    keyvalues={
                        "user_id": user_id,
                        "device_id": device_id,
                        "algorithm": algorithm,
                        "key_id": key_id,
                    },
                    updatevalues={"used": True},
                    desc="_get_fallback_key_set_used",
                )
                await self.invalidate_cache_and_stream(
                    "get_e2e_unused_fallback_key_types", (user_id, device_id)
                )

            device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
            device_results[f"{algorithm}:{key_id}"] = key_json

        return results


class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self._cross_signing_id_gen = StreamIdGenerator(
            db_conn, "e2e_cross_signing_keys", "stream_id"
        )

    async def set_e2e_device_keys(
        self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
    ) -> bool:
        """Stores device keys for a device. Returns whether there was a change
        or the keys were already in the database.
        """

        def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
            set_tag("user_id", user_id)
            set_tag("device_id", device_id)
            set_tag("time_now", time_now)
            set_tag("device_keys", device_keys)

            old_key_json = self.db_pool.simple_select_one_onecol_txn(
                txn,
                table="e2e_device_keys_json",
                keyvalues={"user_id": user_id, "device_id": device_id},
                retcol="key_json",
                allow_none=True,
            )

            # In py3 we need old_key_json to match new_key_json type. The DB
            # returns unicode while encode_canonical_json returns bytes.
1028 new_key_json = encode_canonical_json(device_keys).decode("utf-8") 1029 1030 if old_key_json == new_key_json: 1031 log_kv({"Message": "Device key already stored."}) 1032 return False 1033 1034 self.db_pool.simple_upsert_txn( 1035 txn, 1036 table="e2e_device_keys_json", 1037 keyvalues={"user_id": user_id, "device_id": device_id}, 1038 values={"ts_added_ms": time_now, "key_json": new_key_json}, 1039 ) 1040 log_kv({"message": "Device keys stored."}) 1041 return True 1042 1043 return await self.db_pool.runInteraction( 1044 "set_e2e_device_keys", _set_e2e_device_keys_txn 1045 ) 1046 1047 async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None: 1048 def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None: 1049 log_kv( 1050 { 1051 "message": "Deleting keys for device", 1052 "device_id": device_id, 1053 "user_id": user_id, 1054 } 1055 ) 1056 self.db_pool.simple_delete_txn( 1057 txn, 1058 table="e2e_device_keys_json", 1059 keyvalues={"user_id": user_id, "device_id": device_id}, 1060 ) 1061 self.db_pool.simple_delete_txn( 1062 txn, 1063 table="e2e_one_time_keys_json", 1064 keyvalues={"user_id": user_id, "device_id": device_id}, 1065 ) 1066 self._invalidate_cache_and_stream( 1067 txn, self.count_e2e_one_time_keys, (user_id, device_id) 1068 ) 1069 self.db_pool.simple_delete_txn( 1070 txn, 1071 table="dehydrated_devices", 1072 keyvalues={"user_id": user_id, "device_id": device_id}, 1073 ) 1074 self.db_pool.simple_delete_txn( 1075 txn, 1076 table="e2e_fallback_keys_json", 1077 keyvalues={"user_id": user_id, "device_id": device_id}, 1078 ) 1079 self._invalidate_cache_and_stream( 1080 txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id) 1081 ) 1082 1083 await self.db_pool.runInteraction( 1084 "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn 1085 ) 1086 1087 def _set_e2e_cross_signing_key_txn( 1088 self, 1089 txn: LoggingTransaction, 1090 user_id: str, 1091 key_type: str, 1092 key: JsonDict, 1093 stream_id: int, 1094 ) 
-> None: 1095 """Set a user's cross-signing key. 1096 1097 Args: 1098 txn: db connection 1099 user_id: the user to set the signing key for 1100 key_type: the type of key that is being set: either 'master' 1101 for a master key, 'self_signing' for a self-signing key, or 1102 'user_signing' for a user-signing key 1103 key: the key data 1104 stream_id 1105 """ 1106 # the 'key' dict will look something like: 1107 # { 1108 # "user_id": "@alice:example.com", 1109 # "usage": ["self_signing"], 1110 # "keys": { 1111 # "ed25519:base64+self+signing+public+key": "base64+self+signing+public+key", 1112 # }, 1113 # "signatures": { 1114 # "@alice:example.com": { 1115 # "ed25519:base64+master+public+key": "base64+signature" 1116 # } 1117 # } 1118 # } 1119 # The "keys" property must only have one entry, which will be the public 1120 # key, so we just grab the first value in there 1121 pubkey = next(iter(key["keys"].values())) 1122 1123 # The cross-signing keys need to occupy the same namespace as devices, 1124 # since signatures are identified by device ID. So add an entry to the 1125 # device table to make sure that we don't have a collision with device 1126 # IDs. 1127 # We only need to do this for local users, since remote servers should be 1128 # responsible for checking this for their own users. 
1129 if self.hs.is_mine_id(user_id): 1130 self.db_pool.simple_insert_txn( 1131 txn, 1132 "devices", 1133 values={ 1134 "user_id": user_id, 1135 "device_id": pubkey, 1136 "display_name": key_type + " signing key", 1137 "hidden": True, 1138 }, 1139 ) 1140 1141 # and finally, store the key itself 1142 self.db_pool.simple_insert_txn( 1143 txn, 1144 "e2e_cross_signing_keys", 1145 values={ 1146 "user_id": user_id, 1147 "keytype": key_type, 1148 "keydata": json_encoder.encode(key), 1149 "stream_id": stream_id, 1150 }, 1151 ) 1152 1153 self._invalidate_cache_and_stream( 1154 txn, self._get_bare_e2e_cross_signing_keys, (user_id,) 1155 ) 1156 1157 async def set_e2e_cross_signing_key( 1158 self, user_id: str, key_type: str, key: JsonDict 1159 ) -> None: 1160 """Set a user's cross-signing key. 1161 1162 Args: 1163 user_id: the user to set the user-signing key for 1164 key_type: the type of cross-signing key to set 1165 key: the key data 1166 """ 1167 1168 async with self._cross_signing_id_gen.get_next() as stream_id: 1169 return await self.db_pool.runInteraction( 1170 "add_e2e_cross_signing_key", 1171 self._set_e2e_cross_signing_key_txn, 1172 user_id, 1173 key_type, 1174 key, 1175 stream_id, 1176 ) 1177 1178 async def store_e2e_cross_signing_signatures( 1179 self, user_id: str, signatures: "Iterable[SignatureListItem]" 1180 ) -> None: 1181 """Stores cross-signing signatures. 1182 1183 Args: 1184 user_id: the user who made the signatures 1185 signatures: signatures to add 1186 """ 1187 await self.db_pool.simple_insert_many( 1188 "e2e_cross_signing_signatures", 1189 [ 1190 { 1191 "user_id": user_id, 1192 "key_id": item.signing_key_id, 1193 "target_user_id": item.target_user_id, 1194 "target_device_id": item.target_device_id, 1195 "signature": item.signature, 1196 } 1197 for item in signatures 1198 ], 1199 "add_e2e_signing_key", 1200 ) 1201