1# Copyright 2016 OpenMarket Ltd 2# Copyright 2019 New Vector Ltd 3# Copyright 2019,2020 The Matrix.org Foundation C.I.C. 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16import abc 17import logging 18from typing import ( 19 TYPE_CHECKING, 20 Any, 21 Collection, 22 Dict, 23 Iterable, 24 List, 25 Optional, 26 Set, 27 Tuple, 28) 29 30from synapse.api.errors import Codes, StoreError 31from synapse.logging.opentracing import ( 32 get_active_span_text_map, 33 set_tag, 34 trace, 35 whitelisted_homeserver, 36) 37from synapse.metrics.background_process_metrics import wrap_as_background_process 38from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause 39from synapse.storage.database import ( 40 DatabasePool, 41 LoggingDatabaseConnection, 42 LoggingTransaction, 43 make_tuple_comparison_clause, 44) 45from synapse.types import JsonDict, get_verify_key_from_cross_signing_key 46from synapse.util import json_decoder, json_encoder 47from synapse.util.caches.descriptors import cached, cachedList 48from synapse.util.caches.lrucache import LruCache 49from synapse.util.iterutils import batch_iter 50from synapse.util.stringutils import shortstr 51 52if TYPE_CHECKING: 53 from synapse.server import HomeServer 54 55logger = logging.getLogger(__name__) 56 57DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( 58 "drop_device_list_streams_non_unique_indexes" 59) 60 61BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" 62 63 64class DeviceWorkerStore(SQLBaseStore): 65 def __init__( 66 self, 67 database: DatabasePool, 68 db_conn: LoggingDatabaseConnection, 69 hs: "HomeServer", 70 ): 71 super().__init__(database, db_conn, hs) 72 73 if hs.config.worker.run_background_tasks: 74 self._clock.looping_call( 75 self._prune_old_outbound_device_pokes, 60 * 60 * 1000 76 ) 77 78 async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int: 79 """Retrieve number of all devices of given users. 80 Only returns number of devices that are not marked as hidden. 81 82 Args: 83 user_ids: The IDs of the users which owns devices 84 Returns: 85 Number of devices of this users. 86 """ 87 88 def count_devices_by_users_txn(txn, user_ids): 89 sql = """ 90 SELECT count(*) 91 FROM devices 92 WHERE 93 hidden = '0' AND 94 """ 95 96 clause, args = make_in_list_sql_clause( 97 txn.database_engine, "user_id", user_ids 98 ) 99 100 txn.execute(sql + clause, args) 101 return txn.fetchone()[0] 102 103 if not user_ids: 104 return 0 105 106 return await self.db_pool.runInteraction( 107 "count_devices_by_users", count_devices_by_users_txn, user_ids 108 ) 109 110 async def get_device( 111 self, user_id: str, device_id: str 112 ) -> Optional[Dict[str, Any]]: 113 """Retrieve a device. Only returns devices that are not marked as 114 hidden. 115 116 Args: 117 user_id: The ID of the user which owns the device 118 device_id: The ID of the device to retrieve 119 Returns: 120 A dict containing the device information, or `None` if the device does not 121 exist. 122 """ 123 return await self.db_pool.simple_select_one( 124 table="devices", 125 keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, 126 retcols=("user_id", "device_id", "display_name"), 127 desc="get_device", 128 allow_none=True, 129 ) 130 131 async def get_device_opt( 132 self, user_id: str, device_id: str 133 ) -> Optional[Dict[str, Any]]: 134 """Retrieve a device. Only returns devices that are not marked as 135 hidden. 136 137 Args: 138 user_id: The ID of the user which owns the device 139 device_id: The ID of the device to retrieve 140 Returns: 141 A dict containing the device information, or None if the device does not exist. 142 """ 143 return await self.db_pool.simple_select_one( 144 table="devices", 145 keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, 146 retcols=("user_id", "device_id", "display_name"), 147 desc="get_device", 148 allow_none=True, 149 ) 150 151 async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]: 152 """Retrieve all of a user's registered devices. Only returns devices 153 that are not marked as hidden. 154 155 Args: 156 user_id: 157 Returns: 158 A mapping from device_id to a dict containing "device_id", "user_id" 159 and "display_name" for each device. 160 """ 161 devices = await self.db_pool.simple_select_list( 162 table="devices", 163 keyvalues={"user_id": user_id, "hidden": False}, 164 retcols=("user_id", "device_id", "display_name"), 165 desc="get_devices_by_user", 166 ) 167 168 return {d["device_id"]: d for d in devices} 169 170 async def get_devices_by_auth_provider_session_id( 171 self, auth_provider_id: str, auth_provider_session_id: str 172 ) -> List[Dict[str, Any]]: 173 """Retrieve the list of devices associated with a SSO IdP session ID. 174 175 Args: 176 auth_provider_id: The SSO IdP ID as defined in the server config 177 auth_provider_session_id: The session ID within the IdP 178 Returns: 179 A list of dicts containing the device_id and the user_id of each device 180 """ 181 return await self.db_pool.simple_select_list( 182 table="device_auth_providers", 183 keyvalues={ 184 "auth_provider_id": auth_provider_id, 185 "auth_provider_session_id": auth_provider_session_id, 186 }, 187 retcols=("user_id", "device_id"), 188 desc="get_devices_by_auth_provider_session_id", 189 ) 190 191 @trace 192 async def get_device_updates_by_remote( 193 self, destination: str, from_stream_id: int, limit: int 194 ) -> Tuple[int, List[Tuple[str, JsonDict]]]: 195 """Get a stream of device updates to send to the given remote server. 196 197 Args: 198 destination: The host the device updates are intended for 199 from_stream_id: The minimum stream_id to filter updates by, exclusive 200 limit: Maximum number of device updates to return 201 202 Returns: 203 - The current stream id (i.e. the stream id of the last update included 204 in the response); and 205 - The list of updates, where each update is a pair of EDU type and 206 EDU contents. 207 """ 208 now_stream_id = self.get_device_stream_token() 209 210 has_changed = self._device_list_federation_stream_cache.has_entity_changed( 211 destination, int(from_stream_id) 212 ) 213 if not has_changed: 214 return now_stream_id, [] 215 216 updates = await self.db_pool.runInteraction( 217 "get_device_updates_by_remote", 218 self._get_device_updates_by_remote_txn, 219 destination, 220 from_stream_id, 221 now_stream_id, 222 limit, 223 ) 224 225 # We need to ensure `updates` doesn't grow too big. 226 # Currently: `len(updates) <= limit`. 227 228 # Return an empty list if there are no updates 229 if not updates: 230 return now_stream_id, [] 231 232 # get the cross-signing keys of the users in the list, so that we can 233 # determine which of the device changes were cross-signing keys 234 users = {r[0] for r in updates} 235 master_key_by_user = {} 236 self_signing_key_by_user = {} 237 for user in users: 238 cross_signing_key = await self.get_e2e_cross_signing_key(user, "master") 239 if cross_signing_key: 240 key_id, verify_key = get_verify_key_from_cross_signing_key( 241 cross_signing_key 242 ) 243 # verify_key is a VerifyKey from signedjson, which uses 244 # .version to denote the portion of the key ID after the 245 # algorithm and colon, which is the device ID 246 master_key_by_user[user] = { 247 "key_info": cross_signing_key, 248 "device_id": verify_key.version, 249 } 250 251 cross_signing_key = await self.get_e2e_cross_signing_key( 252 user, "self_signing" 253 ) 254 if cross_signing_key: 255 key_id, verify_key = get_verify_key_from_cross_signing_key( 256 cross_signing_key 257 ) 258 self_signing_key_by_user[user] = { 259 "key_info": cross_signing_key, 260 "device_id": verify_key.version, 261 } 262 263 # Perform the equivalent of a GROUP BY 264 # 265 # Iterate through the updates list and copy non-duplicate 266 # (user_id, device_id) entries into a map, with the value being 267 # the max stream_id across each set of duplicate entries 268 # 269 # maps (user_id, device_id) -> (stream_id, opentracing_context) 270 # 271 # opentracing_context contains the opentracing metadata for the request 272 # that created the poke 273 # 274 # The most recent request's opentracing_context is used as the 275 # context which created the Edu. 276 277 # This is the stream ID that we will return for the consumer to resume 278 # following this stream later. 279 last_processed_stream_id = from_stream_id 280 281 query_map = {} 282 cross_signing_keys_by_user = {} 283 for user_id, device_id, update_stream_id, update_context in updates: 284 # Calculate the remaining length budget. 285 # Note that, for now, each entry in `cross_signing_keys_by_user` 286 # gives rise to two device updates in the result, so those cost twice 287 # as much (and are the whole reason we need to separately calculate 288 # the budget; we know len(updates) <= limit otherwise!) 289 # N.B. len() on dicts is cheap since they store their size. 290 remaining_length_budget = limit - ( 291 len(query_map) + 2 * len(cross_signing_keys_by_user) 292 ) 293 assert remaining_length_budget >= 0 294 295 is_master_key_update = ( 296 user_id in master_key_by_user 297 and device_id == master_key_by_user[user_id]["device_id"] 298 ) 299 is_self_signing_key_update = ( 300 user_id in self_signing_key_by_user 301 and device_id == self_signing_key_by_user[user_id]["device_id"] 302 ) 303 304 is_cross_signing_key_update = ( 305 is_master_key_update or is_self_signing_key_update 306 ) 307 308 if ( 309 is_cross_signing_key_update 310 and user_id not in cross_signing_keys_by_user 311 ): 312 # This will give rise to 2 device updates. 313 # If we don't have the budget, stop here! 314 if remaining_length_budget < 2: 315 break 316 317 if is_master_key_update: 318 result = cross_signing_keys_by_user.setdefault(user_id, {}) 319 result["master_key"] = master_key_by_user[user_id]["key_info"] 320 elif is_self_signing_key_update: 321 result = cross_signing_keys_by_user.setdefault(user_id, {}) 322 result["self_signing_key"] = self_signing_key_by_user[user_id][ 323 "key_info" 324 ] 325 else: 326 key = (user_id, device_id) 327 328 if key not in query_map and remaining_length_budget < 1: 329 # We don't have space for a new entry 330 break 331 332 previous_update_stream_id, _ = query_map.get(key, (0, None)) 333 334 if update_stream_id > previous_update_stream_id: 335 # FIXME If this overwrites an older update, this discards the 336 # previous OpenTracing context. 337 # It might make it harder to track down issues using OpenTracing. 338 # If there's a good reason why it doesn't matter, a comment here 339 # about that would not hurt. 340 query_map[key] = (update_stream_id, update_context) 341 342 # As this update has been added to the response, advance the stream 343 # position. 344 last_processed_stream_id = update_stream_id 345 346 # In the worst case scenario, each update is for a distinct user and is 347 # added either to the query_map or to cross_signing_keys_by_user, 348 # but not both: 349 # len(query_map) + len(cross_signing_keys_by_user) <= len(updates) here, 350 # so len(query_map) + len(cross_signing_keys_by_user) <= limit. 351 352 results = await self._get_device_update_edus_by_remote( 353 destination, from_stream_id, query_map 354 ) 355 356 # len(results) <= len(query_map) here, 357 # so len(results) + len(cross_signing_keys_by_user) <= limit. 358 359 # Add the updated cross-signing keys to the results list 360 for user_id, result in cross_signing_keys_by_user.items(): 361 result["user_id"] = user_id 362 results.append(("m.signing_key_update", result)) 363 # also send the unstable version 364 # FIXME: remove this when enough servers have upgraded 365 # and remove the length budgeting above. 366 results.append(("org.matrix.signing_key_update", result)) 367 368 return last_processed_stream_id, results 369 370 def _get_device_updates_by_remote_txn( 371 self, 372 txn: LoggingTransaction, 373 destination: str, 374 from_stream_id: int, 375 now_stream_id: int, 376 limit: int, 377 ) -> List[Tuple[str, str, int, Optional[str]]]: 378 """Return device update information for a given remote destination 379 380 Args: 381 txn: The transaction to execute 382 destination: The host the device updates are intended for 383 from_stream_id: The minimum stream_id to filter updates by, exclusive 384 now_stream_id: The maximum stream_id to filter updates by, inclusive 385 limit: Maximum number of device updates to return 386 387 Returns: 388 List: List of device update tuples: 389 - user_id 390 - device_id 391 - stream_id 392 - opentracing_context 393 """ 394 # get the list of device updates that need to be sent 395 sql = """ 396 SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes 397 WHERE destination = ? AND ? < stream_id AND stream_id <= ? 398 ORDER BY stream_id 399 LIMIT ? 400 """ 401 txn.execute(sql, (destination, from_stream_id, now_stream_id, limit)) 402 403 return list(txn) 404 405 async def _get_device_update_edus_by_remote( 406 self, 407 destination: str, 408 from_stream_id: int, 409 query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]], 410 ) -> List[Tuple[str, dict]]: 411 """Returns a list of device update EDUs as well as E2EE keys 412 413 Args: 414 destination: The host the device updates are intended for 415 from_stream_id: The minimum stream_id to filter updates by, exclusive 416 query_map: Dictionary mapping (user_id, device_id) to 417 (update stream_id, the relevant json-encoded opentracing context) 418 419 Returns: 420 List of objects representing a device update EDU. 421 422 Postconditions: 423 The returned list has a length not exceeding that of the query_map: 424 len(result) <= len(query_map) 425 """ 426 devices = ( 427 await self.get_e2e_device_keys_and_signatures( 428 # Because these are (user_id, device_id) tuples with all 429 # device_ids not being None, the returned list's length will not 430 # exceed that of query_map. 431 query_map.keys(), 432 include_all_devices=True, 433 include_deleted_devices=True, 434 ) 435 if query_map 436 else {} 437 ) 438 439 results = [] 440 for user_id, user_devices in devices.items(): 441 # The prev_id for the first row is always the last row before 442 # `from_stream_id` 443 prev_id = await self._get_last_device_update_for_remote_user( 444 destination, user_id, from_stream_id 445 ) 446 447 # make sure we go through the devices in stream order 448 device_ids = sorted( 449 user_devices.keys(), 450 key=lambda i: query_map[(user_id, i)][0], 451 ) 452 453 for device_id in device_ids: 454 device = user_devices[device_id] 455 stream_id, opentracing_context = query_map[(user_id, device_id)] 456 result = { 457 "user_id": user_id, 458 "device_id": device_id, 459 "prev_id": [prev_id] if prev_id else [], 460 "stream_id": stream_id, 461 "org.matrix.opentracing_context": opentracing_context, 462 } 463 464 prev_id = stream_id 465 466 if device is not None: 467 keys = device.keys 468 if keys: 469 result["keys"] = keys 470 471 device_display_name = device.display_name 472 if device_display_name: 473 result["device_display_name"] = device_display_name 474 else: 475 result["deleted"] = True 476 477 results.append(("m.device_list_update", result)) 478 479 return results 480 481 async def _get_last_device_update_for_remote_user( 482 self, destination: str, user_id: str, from_stream_id: int 483 ) -> int: 484 def f(txn): 485 prev_sent_id_sql = """ 486 SELECT coalesce(max(stream_id), 0) as stream_id 487 FROM device_lists_outbound_last_success 488 WHERE destination = ? AND user_id = ? AND stream_id <= ? 489 """ 490 txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id)) 491 rows = txn.fetchall() 492 return rows[0][0] 493 494 return await self.db_pool.runInteraction( 495 "get_last_device_update_for_remote_user", f 496 ) 497 498 async def mark_as_sent_devices_by_remote( 499 self, destination: str, stream_id: int 500 ) -> None: 501 """Mark that updates have successfully been sent to the destination.""" 502 await self.db_pool.runInteraction( 503 "mark_as_sent_devices_by_remote", 504 self._mark_as_sent_devices_by_remote_txn, 505 destination, 506 stream_id, 507 ) 508 509 def _mark_as_sent_devices_by_remote_txn( 510 self, txn: LoggingTransaction, destination: str, stream_id: int 511 ) -> None: 512 # We update the device_lists_outbound_last_success with the successfully 513 # poked users. 514 sql = """ 515 SELECT user_id, coalesce(max(o.stream_id), 0) 516 FROM device_lists_outbound_pokes as o 517 WHERE destination = ? AND o.stream_id <= ? 518 GROUP BY user_id 519 """ 520 txn.execute(sql, (destination, stream_id)) 521 rows = txn.fetchall() 522 523 self.db_pool.simple_upsert_many_txn( 524 txn=txn, 525 table="device_lists_outbound_last_success", 526 key_names=("destination", "user_id"), 527 key_values=((destination, user_id) for user_id, _ in rows), 528 value_names=("stream_id",), 529 value_values=((stream_id,) for _, stream_id in rows), 530 ) 531 532 # Delete all sent outbound pokes 533 sql = """ 534 DELETE FROM device_lists_outbound_pokes 535 WHERE destination = ? AND stream_id <= ? 536 """ 537 txn.execute(sql, (destination, stream_id)) 538 539 async def add_user_signature_change_to_streams( 540 self, from_user_id: str, user_ids: List[str] 541 ) -> int: 542 """Persist that a user has made new signatures 543 544 Args: 545 from_user_id: the user who made the signatures 546 user_ids: the users who were signed 547 548 Returns: 549 The new stream ID. 550 """ 551 552 async with self._device_list_id_gen.get_next() as stream_id: 553 await self.db_pool.runInteraction( 554 "add_user_sig_change_to_streams", 555 self._add_user_signature_change_txn, 556 from_user_id, 557 user_ids, 558 stream_id, 559 ) 560 return stream_id 561 562 def _add_user_signature_change_txn( 563 self, 564 txn: LoggingTransaction, 565 from_user_id: str, 566 user_ids: List[str], 567 stream_id: int, 568 ) -> None: 569 txn.call_after( 570 self._user_signature_stream_cache.entity_has_changed, 571 from_user_id, 572 stream_id, 573 ) 574 self.db_pool.simple_insert_txn( 575 txn, 576 "user_signature_stream", 577 values={ 578 "stream_id": stream_id, 579 "from_user_id": from_user_id, 580 "user_ids": json_encoder.encode(user_ids), 581 }, 582 ) 583 584 @abc.abstractmethod 585 def get_device_stream_token(self) -> int: 586 """Get the current stream id from the _device_list_id_gen""" 587 ... 588 589 @trace 590 async def get_user_devices_from_cache( 591 self, query_list: List[Tuple[str, str]] 592 ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]: 593 """Get the devices (and keys if any) for remote users from the cache. 594 595 Args: 596 query_list: List of (user_id, device_ids), if device_ids is 597 falsey then return all device ids for that user. 598 599 Returns: 600 A tuple of (user_ids_not_in_cache, results_map), where 601 user_ids_not_in_cache is a set of user_ids and results_map is a 602 mapping of user_id -> device_id -> device_info. 603 """ 604 user_ids = {user_id for user_id, _ in query_list} 605 user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids)) 606 607 # We go and check if any of the users need to have their device lists 608 # resynced. If they do then we remove them from the cached list. 609 users_needing_resync = await self.get_user_ids_requiring_device_list_resync( 610 user_ids 611 ) 612 user_ids_in_cache = { 613 user_id for user_id, stream_id in user_map.items() if stream_id 614 } - users_needing_resync 615 user_ids_not_in_cache = user_ids - user_ids_in_cache 616 617 results = {} 618 for user_id, device_id in query_list: 619 if user_id not in user_ids_in_cache: 620 continue 621 622 if device_id: 623 device = await self._get_cached_user_device(user_id, device_id) 624 results.setdefault(user_id, {})[device_id] = device 625 else: 626 results[user_id] = await self.get_cached_devices_for_user(user_id) 627 628 set_tag("in_cache", results) 629 set_tag("not_in_cache", user_ids_not_in_cache) 630 631 return user_ids_not_in_cache, results 632 633 @cached(num_args=2, tree=True) 634 async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict: 635 content = await self.db_pool.simple_select_one_onecol( 636 table="device_lists_remote_cache", 637 keyvalues={"user_id": user_id, "device_id": device_id}, 638 retcol="content", 639 desc="_get_cached_user_device", 640 ) 641 return db_to_json(content) 642 643 @cached() 644 async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]: 645 devices = await self.db_pool.simple_select_list( 646 table="device_lists_remote_cache", 647 keyvalues={"user_id": user_id}, 648 retcols=("device_id", "content"), 649 desc="get_cached_devices_for_user", 650 ) 651 return { 652 device["device_id"]: db_to_json(device["content"]) for device in devices 653 } 654 655 async def get_users_whose_devices_changed( 656 self, from_key: int, user_ids: Iterable[str] 657 ) -> Set[str]: 658 """Get set of users whose devices have changed since `from_key` that 659 are in the given list of user_ids. 660 661 Args: 662 from_key: The device lists stream token 663 user_ids: The user IDs to query for devices. 664 665 Returns: 666 The set of user_ids whose devices have changed since `from_key` 667 """ 668 669 # Get set of users who *may* have changed. Users not in the returned 670 # list have definitely not changed. 671 to_check = self._device_list_stream_cache.get_entities_changed( 672 user_ids, from_key 673 ) 674 675 if not to_check: 676 return set() 677 678 def _get_users_whose_devices_changed_txn(txn): 679 changes = set() 680 681 sql = """ 682 SELECT DISTINCT user_id FROM device_lists_stream 683 WHERE stream_id > ? 684 AND 685 """ 686 687 for chunk in batch_iter(to_check, 100): 688 clause, args = make_in_list_sql_clause( 689 txn.database_engine, "user_id", chunk 690 ) 691 txn.execute(sql + clause, (from_key,) + tuple(args)) 692 changes.update(user_id for user_id, in txn) 693 694 return changes 695 696 return await self.db_pool.runInteraction( 697 "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn 698 ) 699 700 async def get_users_whose_signatures_changed( 701 self, user_id: str, from_key: int 702 ) -> Set[str]: 703 """Get the users who have new cross-signing signatures made by `user_id` since 704 `from_key`. 705 706 Args: 707 user_id: the user who made the signatures 708 from_key: The device lists stream token 709 710 Returns: 711 A set of user IDs with updated signatures. 712 """ 713 714 if self._user_signature_stream_cache.has_entity_changed(user_id, from_key): 715 sql = """ 716 SELECT DISTINCT user_ids FROM user_signature_stream 717 WHERE from_user_id = ? AND stream_id > ? 718 """ 719 rows = await self.db_pool.execute( 720 "get_users_whose_signatures_changed", None, sql, user_id, from_key 721 ) 722 return {user for row in rows for user in db_to_json(row[0])} 723 else: 724 return set() 725 726 async def get_all_device_list_changes_for_remotes( 727 self, instance_name: str, last_id: int, current_id: int, limit: int 728 ) -> Tuple[List[Tuple[int, tuple]], int, bool]: 729 """Get updates for device lists replication stream. 730 731 Args: 732 instance_name: The writer we want to fetch updates from. Unused 733 here since there is only ever one writer. 734 last_id: The token to fetch updates from. Exclusive. 735 current_id: The token to fetch updates up to. Inclusive. 736 limit: The requested limit for the number of rows to return. The 737 function may return more or fewer rows. 738 739 Returns: 740 A tuple consisting of: the updates, a token to use to fetch 741 subsequent updates, and whether we returned fewer rows than exists 742 between the requested tokens due to the limit. 743 744 The token returned can be used in a subsequent call to this 745 function to get further updates. 746 747 The updates are a list of 2-tuples of stream ID and the row data 748 """ 749 750 if last_id == current_id: 751 return [], current_id, False 752 753 def _get_all_device_list_changes_for_remotes(txn): 754 # This query Does The Right Thing where it'll correctly apply the 755 # bounds to the inner queries. 756 sql = """ 757 SELECT stream_id, entity FROM ( 758 SELECT stream_id, user_id AS entity FROM device_lists_stream 759 UNION ALL 760 SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes 761 ) AS e 762 WHERE ? < stream_id AND stream_id <= ? 763 LIMIT ? 764 """ 765 766 txn.execute(sql, (last_id, current_id, limit)) 767 updates = [(row[0], row[1:]) for row in txn] 768 limited = False 769 upto_token = current_id 770 if len(updates) >= limit: 771 upto_token = updates[-1][0] 772 limited = True 773 774 return updates, upto_token, limited 775 776 return await self.db_pool.runInteraction( 777 "get_all_device_list_changes_for_remotes", 778 _get_all_device_list_changes_for_remotes, 779 ) 780 781 @cached(max_entries=10000) 782 async def get_device_list_last_stream_id_for_remote( 783 self, user_id: str 784 ) -> Optional[Any]: 785 """Get the last stream_id we got for a user. May be None if we haven't 786 got any information for them. 787 """ 788 return await self.db_pool.simple_select_one_onecol( 789 table="device_lists_remote_extremeties", 790 keyvalues={"user_id": user_id}, 791 retcol="stream_id", 792 desc="get_device_list_last_stream_id_for_remote", 793 allow_none=True, 794 ) 795 796 @cachedList( 797 cached_method_name="get_device_list_last_stream_id_for_remote", 798 list_name="user_ids", 799 ) 800 async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]): 801 rows = await self.db_pool.simple_select_many_batch( 802 table="device_lists_remote_extremeties", 803 column="user_id", 804 iterable=user_ids, 805 retcols=("user_id", "stream_id"), 806 desc="get_device_list_last_stream_id_for_remotes", 807 ) 808 809 results = {user_id: None for user_id in user_ids} 810 results.update({row["user_id"]: row["stream_id"] for row in rows}) 811 812 return results 813 814 async def get_user_ids_requiring_device_list_resync( 815 self, 816 user_ids: Optional[Collection[str]] = None, 817 ) -> Set[str]: 818 """Given a list of remote users return the list of users that we 819 should resync the device lists for. If None is given instead of a list, 820 return every user that we should resync the device lists for. 821 822 Returns: 823 The IDs of users whose device lists need resync. 824 """ 825 if user_ids: 826 rows = await self.db_pool.simple_select_many_batch( 827 table="device_lists_remote_resync", 828 column="user_id", 829 iterable=user_ids, 830 retcols=("user_id",), 831 desc="get_user_ids_requiring_device_list_resync_with_iterable", 832 ) 833 else: 834 rows = await self.db_pool.simple_select_list( 835 table="device_lists_remote_resync", 836 keyvalues=None, 837 retcols=("user_id",), 838 desc="get_user_ids_requiring_device_list_resync", 839 ) 840 841 return {row["user_id"] for row in rows} 842 843 async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None: 844 """Records that the server has reason to believe the cache of the devices 845 for the remote users is out of date. 846 """ 847 await self.db_pool.simple_upsert( 848 table="device_lists_remote_resync", 849 keyvalues={"user_id": user_id}, 850 values={}, 851 insertion_values={"added_ts": self._clock.time_msec()}, 852 desc="mark_remote_user_device_cache_as_stale", 853 ) 854 855 async def mark_remote_user_device_cache_as_valid(self, user_id: str) -> None: 856 # Remove the database entry that says we need to resync devices, after a resync 857 await self.db_pool.simple_delete( 858 table="device_lists_remote_resync", 859 keyvalues={"user_id": user_id}, 860 desc="mark_remote_user_device_cache_as_valid", 861 ) 862 863 async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None: 864 """Mark that we no longer track device lists for remote user.""" 865 866 def _mark_remote_user_device_list_as_unsubscribed_txn(txn): 867 self.db_pool.simple_delete_txn( 868 txn, 869 table="device_lists_remote_extremeties", 870 keyvalues={"user_id": user_id}, 871 ) 872 self._invalidate_cache_and_stream( 873 txn, self.get_device_list_last_stream_id_for_remote, (user_id,) 874 ) 875 876 await self.db_pool.runInteraction( 877 "mark_remote_user_device_list_as_unsubscribed", 878 _mark_remote_user_device_list_as_unsubscribed_txn, 879 ) 880 881 async def get_dehydrated_device( 882 self, user_id: str 883 ) -> Optional[Tuple[str, JsonDict]]: 884 """Retrieve the information for a dehydrated device. 885 886 Args: 887 user_id: the user whose dehydrated device we are looking for 888 Returns: 889 a tuple whose first item is the device ID, and the second item is 890 the dehydrated device information 891 """ 892 # FIXME: make sure device ID still exists in devices table 893 row = await self.db_pool.simple_select_one( 894 table="dehydrated_devices", 895 keyvalues={"user_id": user_id}, 896 retcols=["device_id", "device_data"], 897 allow_none=True, 898 ) 899 return ( 900 (row["device_id"], json_decoder.decode(row["device_data"])) if row else None 901 ) 902 903 def _store_dehydrated_device_txn( 904 self, txn, user_id: str, device_id: str, device_data: str 905 ) -> Optional[str]: 906 old_device_id = self.db_pool.simple_select_one_onecol_txn( 907 txn, 908 table="dehydrated_devices", 909 keyvalues={"user_id": user_id}, 910 retcol="device_id", 911 allow_none=True, 912 ) 913 self.db_pool.simple_upsert_txn( 914 txn, 915 table="dehydrated_devices", 916 keyvalues={"user_id": user_id}, 917 values={"device_id": device_id, "device_data": device_data}, 918 ) 919 return old_device_id 920 921 async def store_dehydrated_device( 922 self, user_id: str, device_id: str, device_data: JsonDict 923 ) -> Optional[str]: 924 """Store a dehydrated device for a user. 925 926 Args: 927 user_id: the user that we are storing the device for 928 device_id: the ID of the dehydrated device 929 device_data: the dehydrated device information 930 Returns: 931 device id of the user's previous dehydrated device, if any 932 """ 933 return await self.db_pool.runInteraction( 934 "store_dehydrated_device_txn", 935 self._store_dehydrated_device_txn, 936 user_id, 937 device_id, 938 json_encoder.encode(device_data), 939 ) 940 941 async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool: 942 """Remove a dehydrated device. 943 944 Args: 945 user_id: the user that the dehydrated device belongs to 946 device_id: the ID of the dehydrated device 947 """ 948 count = await self.db_pool.simple_delete( 949 "dehydrated_devices", 950 {"user_id": user_id, "device_id": device_id}, 951 desc="remove_dehydrated_device", 952 ) 953 return count >= 1 954 955 @wrap_as_background_process("prune_old_outbound_device_pokes") 956 async def _prune_old_outbound_device_pokes( 957 self, prune_age: int = 24 * 60 * 60 * 1000 958 ) -> None: 959 """Delete old entries out of the device_lists_outbound_pokes to ensure 960 that we don't fill up due to dead servers. 961 962 Normally, we try to send device updates as a delta since a previous known point: 963 this is done by setting the prev_id in the m.device_list_update EDU. However, 964 for that to work, we have to have a complete record of each change to 965 each device, which can add up to quite a lot of data. 966 967 An alternative mechanism is that, if the remote server sees that it has missed 968 an entry in the stream_id sequence for a given user, it will request a full 969 list of that user's devices. Hence, we can reduce the amount of data we have to 970 store (and transmit in some future transaction), by clearing almost everything 971 for a given destination out of the database, and having the remote server 972 resync. 973 974 All we need to do is make sure we keep at least one row for each 975 (user, destination) pair, to remind us to send a m.device_list_update EDU for 976 that user when the destination comes back. It doesn't matter which device 977 we keep. 978 """ 979 yesterday = self._clock.time_msec() - prune_age 980 981 def _prune_txn(txn): 982 # look for (user, destination) pairs which have an update older than 983 # the cutoff. 984 # 985 # For each pair, we also need to know the most recent stream_id, and 986 # an arbitrary device_id at that stream_id. 987 select_sql = """ 988 SELECT 989 dlop1.destination, 990 dlop1.user_id, 991 MAX(dlop1.stream_id) AS stream_id, 992 (SELECT MIN(dlop2.device_id) AS device_id FROM 993 device_lists_outbound_pokes dlop2 994 WHERE dlop2.destination = dlop1.destination AND 995 dlop2.user_id=dlop1.user_id AND 996 dlop2.stream_id=MAX(dlop1.stream_id) 997 ) 998 FROM device_lists_outbound_pokes dlop1 999 GROUP BY destination, user_id 1000 HAVING min(ts) < ? AND count(*) > 1 1001 """ 1002 1003 txn.execute(select_sql, (yesterday,)) 1004 rows = txn.fetchall() 1005 1006 if not rows: 1007 return 1008 1009 logger.info( 1010 "Pruning old outbound device list updates for %i users/destinations: %s", 1011 len(rows), 1012 shortstr((row[0], row[1]) for row in rows), 1013 ) 1014 1015 # we want to keep the update with the highest stream_id for each user. 1016 # 1017 # there might be more than one update (with different device_ids) with the 1018 # same stream_id, so we also delete all but one rows with the max stream id. 1019 delete_sql = """ 1020 DELETE FROM device_lists_outbound_pokes 1021 WHERE destination = ? AND user_id = ? AND ( 1022 stream_id < ? OR 1023 (stream_id = ? AND device_id != ?) 1024 ) 1025 """ 1026 count = 0 1027 for (destination, user_id, stream_id, device_id) in rows: 1028 txn.execute( 1029 delete_sql, (destination, user_id, stream_id, stream_id, device_id) 1030 ) 1031 count += txn.rowcount 1032 1033 # Since we've deleted unsent deltas, we need to remove the entry 1034 # of last successful sent so that the prev_ids are correctly set. 1035 sql = """ 1036 DELETE FROM device_lists_outbound_last_success 1037 WHERE destination = ? AND user_id = ? 1038 """ 1039 txn.execute_batch(sql, ((row[0], row[1]) for row in rows)) 1040 1041 logger.info("Pruned %d device list outbound pokes", count) 1042 1043 await self.db_pool.runInteraction( 1044 "_prune_old_outbound_device_pokes", 1045 _prune_txn, 1046 ) 1047 1048 1049class DeviceBackgroundUpdateStore(SQLBaseStore): 1050 def __init__( 1051 self, 1052 database: DatabasePool, 1053 db_conn: LoggingDatabaseConnection, 1054 hs: "HomeServer", 1055 ): 1056 super().__init__(database, db_conn, hs) 1057 1058 self.db_pool.updates.register_background_index_update( 1059 "device_lists_stream_idx", 1060 index_name="device_lists_stream_user_id", 1061 table="device_lists_stream", 1062 columns=["user_id", "device_id"], 1063 ) 1064 1065 # create a unique index on device_lists_remote_cache 1066 self.db_pool.updates.register_background_index_update( 1067 "device_lists_remote_cache_unique_idx", 1068 index_name="device_lists_remote_cache_unique_id", 1069 table="device_lists_remote_cache", 1070 columns=["user_id", "device_id"], 1071 unique=True, 1072 ) 1073 1074 # And one on device_lists_remote_extremeties 1075 self.db_pool.updates.register_background_index_update( 1076 "device_lists_remote_extremeties_unique_idx", 1077 index_name="device_lists_remote_extremeties_unique_idx", 1078 table="device_lists_remote_extremeties", 1079 columns=["user_id"], 1080 unique=True, 1081 ) 1082 1083 # once they complete, we can remove the old non-unique indexes. 1084 self.db_pool.updates.register_background_update_handler( 1085 DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES, 1086 self._drop_device_list_streams_non_unique_indexes, 1087 ) 1088 1089 # clear out duplicate device list outbound pokes 1090 self.db_pool.updates.register_background_update_handler( 1091 BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, 1092 self._remove_duplicate_outbound_pokes, 1093 ) 1094 1095 # a pair of background updates that were added during the 1.14 release cycle, 1096 # but replaced with 58/06dlols_unique_idx.py 1097 self.db_pool.updates.register_noop_background_update( 1098 "device_lists_outbound_last_success_unique_idx", 1099 ) 1100 self.db_pool.updates.register_noop_background_update( 1101 "drop_device_lists_outbound_last_success_non_unique_idx", 1102 ) 1103 1104 async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): 1105 def f(conn): 1106 txn = conn.cursor() 1107 txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id") 1108 txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") 1109 txn.close() 1110 1111 await self.db_pool.runWithConnection(f) 1112 await self.db_pool.updates._end_background_update( 1113 DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES 1114 ) 1115 return 1 1116 1117 async def _remove_duplicate_outbound_pokes(self, progress, batch_size): 1118 # for some reason, we have accumulated duplicate entries in 1119 # device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less 1120 # efficient. 1121 # 1122 # For each duplicate, we delete all the existing rows and put one back. 1123 1124 KEY_COLS = ["stream_id", "destination", "user_id", "device_id"] 1125 last_row = progress.get( 1126 "last_row", 1127 {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""}, 1128 ) 1129 1130 def _txn(txn): 1131 clause, args = make_tuple_comparison_clause( 1132 [(x, last_row[x]) for x in KEY_COLS] 1133 ) 1134 sql = """ 1135 SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts 1136 FROM device_lists_outbound_pokes 1137 WHERE %s 1138 GROUP BY %s 1139 HAVING count(*) > 1 1140 ORDER BY %s 1141 LIMIT ? 1142 """ % ( 1143 clause, # WHERE 1144 ",".join(KEY_COLS), # GROUP BY 1145 ",".join(KEY_COLS), # ORDER BY 1146 ) 1147 txn.execute(sql, args + [batch_size]) 1148 rows = self.db_pool.cursor_to_dict(txn) 1149 1150 row = None 1151 for row in rows: 1152 self.db_pool.simple_delete_txn( 1153 txn, 1154 "device_lists_outbound_pokes", 1155 {x: row[x] for x in KEY_COLS}, 1156 ) 1157 1158 row["sent"] = False 1159 self.db_pool.simple_insert_txn( 1160 txn, 1161 "device_lists_outbound_pokes", 1162 row, 1163 ) 1164 1165 if row: 1166 self.db_pool.updates._background_update_progress_txn( 1167 txn, 1168 BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, 1169 {"last_row": row}, 1170 ) 1171 1172 return len(rows) 1173 1174 rows = await self.db_pool.runInteraction( 1175 BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn 1176 ) 1177 1178 if not rows: 1179 await self.db_pool.updates._end_background_update( 1180 BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES 1181 ) 1182 1183 return rows 1184 1185 1186class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): 1187 def __init__( 1188 self, 1189 database: DatabasePool, 1190 db_conn: LoggingDatabaseConnection, 1191 hs: "HomeServer", 1192 ): 1193 super().__init__(database, db_conn, hs) 1194 1195 # Map of (user_id, device_id) -> bool. If there is an entry that implies 1196 # the device exists. 1197 self.device_id_exists_cache = LruCache( 1198 cache_name="device_id_exists", max_size=10000 1199 ) 1200 1201 async def store_device( 1202 self, 1203 user_id: str, 1204 device_id: str, 1205 initial_device_display_name: Optional[str], 1206 auth_provider_id: Optional[str] = None, 1207 auth_provider_session_id: Optional[str] = None, 1208 ) -> bool: 1209 """Ensure the given device is known; add it to the store if not 1210 1211 Args: 1212 user_id: id of user associated with the device 1213 device_id: id of device 1214 initial_device_display_name: initial displayname of the device. 1215 Ignored if device exists. 1216 auth_provider_id: The SSO IdP the user used, if any. 1217 auth_provider_session_id: The session ID (sid) got from a OIDC login. 1218 1219 Returns: 1220 Whether the device was inserted or an existing device existed with that ID. 1221 1222 Raises: 1223 StoreError: if the device is already in use 1224 """ 1225 key = (user_id, device_id) 1226 if self.device_id_exists_cache.get(key, None): 1227 return False 1228 1229 try: 1230 inserted = await self.db_pool.simple_upsert( 1231 "devices", 1232 keyvalues={ 1233 "user_id": user_id, 1234 "device_id": device_id, 1235 }, 1236 values={}, 1237 insertion_values={ 1238 "display_name": initial_device_display_name, 1239 "hidden": False, 1240 }, 1241 desc="store_device", 1242 ) 1243 if not inserted: 1244 # if the device already exists, check if it's a real device, or 1245 # if the device ID is reserved by something else 1246 hidden = await self.db_pool.simple_select_one_onecol( 1247 "devices", 1248 keyvalues={"user_id": user_id, "device_id": device_id}, 1249 retcol="hidden", 1250 ) 1251 if hidden: 1252 raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN) 1253 1254 if auth_provider_id and auth_provider_session_id: 1255 await self.db_pool.simple_insert( 1256 "device_auth_providers", 1257 values={ 1258 "user_id": user_id, 1259 "device_id": device_id, 1260 "auth_provider_id": auth_provider_id, 1261 "auth_provider_session_id": auth_provider_session_id, 1262 }, 1263 desc="store_device_auth_provider", 1264 ) 1265 1266 self.device_id_exists_cache.set(key, True) 1267 return inserted 1268 except StoreError: 1269 raise 1270 except Exception as e: 1271 logger.error( 1272 "store_device with device_id=%s(%r) user_id=%s(%r)" 1273 " display_name=%s(%r) failed: %s", 1274 type(device_id).__name__, 1275 device_id, 1276 type(user_id).__name__, 1277 user_id, 1278 type(initial_device_display_name).__name__, 1279 initial_device_display_name, 1280 e, 1281 ) 1282 raise StoreError(500, "Problem storing device.") 1283 1284 async def delete_device(self, user_id: str, device_id: str) -> None: 1285 """Delete a device and its device_inbox. 1286 1287 Args: 1288 user_id: The ID of the user which owns the device 1289 device_id: The ID of the device to delete 1290 """ 1291 1292 await self.delete_devices(user_id, [device_id]) 1293 1294 async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: 1295 """Deletes several devices. 1296 1297 Args: 1298 user_id: The ID of the user which owns the devices 1299 device_ids: The IDs of the devices to delete 1300 """ 1301 1302 def _delete_devices_txn(txn: LoggingTransaction) -> None: 1303 self.db_pool.simple_delete_many_txn( 1304 txn, 1305 table="devices", 1306 column="device_id", 1307 values=device_ids, 1308 keyvalues={"user_id": user_id, "hidden": False}, 1309 ) 1310 1311 self.db_pool.simple_delete_many_txn( 1312 txn, 1313 table="device_inbox", 1314 column="device_id", 1315 values=device_ids, 1316 keyvalues={"user_id": user_id}, 1317 ) 1318 1319 self.db_pool.simple_delete_many_txn( 1320 txn, 1321 table="device_auth_providers", 1322 column="device_id", 1323 values=device_ids, 1324 keyvalues={"user_id": user_id}, 1325 ) 1326 1327 await self.db_pool.runInteraction("delete_devices", _delete_devices_txn) 1328 for device_id in device_ids: 1329 self.device_id_exists_cache.invalidate((user_id, device_id)) 1330 1331 async def update_device( 1332 self, user_id: str, device_id: str, new_display_name: Optional[str] = None 1333 ) -> None: 1334 """Update a device. Only updates the device if it is not marked as 1335 hidden. 1336 1337 Args: 1338 user_id: The ID of the user which owns the device 1339 device_id: The ID of the device to update 1340 new_display_name: new displayname for device; None to leave unchanged 1341 Raises: 1342 StoreError: if the device is not found 1343 """ 1344 updates = {} 1345 if new_display_name is not None: 1346 updates["display_name"] = new_display_name 1347 if not updates: 1348 return None 1349 await self.db_pool.simple_update_one( 1350 table="devices", 1351 keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, 1352 updatevalues=updates, 1353 desc="update_device", 1354 ) 1355 1356 async def update_remote_device_list_cache_entry( 1357 self, user_id: str, device_id: str, content: JsonDict, stream_id: str 1358 ) -> None: 1359 """Updates a single device in the cache of a remote user's devicelist. 1360 1361 Note: assumes that we are the only thread that can be updating this user's 1362 device list. 1363 1364 Args: 1365 user_id: User to update device list for 1366 device_id: ID of decivice being updated 1367 content: new data on this device 1368 stream_id: the version of the device list 1369 """ 1370 await self.db_pool.runInteraction( 1371 "update_remote_device_list_cache_entry", 1372 self._update_remote_device_list_cache_entry_txn, 1373 user_id, 1374 device_id, 1375 content, 1376 stream_id, 1377 ) 1378 1379 def _update_remote_device_list_cache_entry_txn( 1380 self, 1381 txn: LoggingTransaction, 1382 user_id: str, 1383 device_id: str, 1384 content: JsonDict, 1385 stream_id: str, 1386 ) -> None: 1387 if content.get("deleted"): 1388 self.db_pool.simple_delete_txn( 1389 txn, 1390 table="device_lists_remote_cache", 1391 keyvalues={"user_id": user_id, "device_id": device_id}, 1392 ) 1393 1394 txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id)) 1395 else: 1396 self.db_pool.simple_upsert_txn( 1397 txn, 1398 table="device_lists_remote_cache", 1399 keyvalues={"user_id": user_id, "device_id": device_id}, 1400 values={"content": json_encoder.encode(content)}, 1401 # we don't need to lock, because we assume we are the only thread 1402 # updating this user's devices. 1403 lock=False, 1404 ) 1405 1406 txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id)) 1407 txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,)) 1408 txn.call_after( 1409 self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) 1410 ) 1411 1412 self.db_pool.simple_upsert_txn( 1413 txn, 1414 table="device_lists_remote_extremeties", 1415 keyvalues={"user_id": user_id}, 1416 values={"stream_id": stream_id}, 1417 # again, we can assume we are the only thread updating this user's 1418 # extremity. 1419 lock=False, 1420 ) 1421 1422 async def update_remote_device_list_cache( 1423 self, user_id: str, devices: List[dict], stream_id: int 1424 ) -> None: 1425 """Replace the entire cache of the remote user's devices. 1426 1427 Note: assumes that we are the only thread that can be updating this user's 1428 device list. 1429 1430 Args: 1431 user_id: User to update device list for 1432 devices: list of device objects supplied over federation 1433 stream_id: the version of the device list 1434 """ 1435 await self.db_pool.runInteraction( 1436 "update_remote_device_list_cache", 1437 self._update_remote_device_list_cache_txn, 1438 user_id, 1439 devices, 1440 stream_id, 1441 ) 1442 1443 def _update_remote_device_list_cache_txn( 1444 self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int 1445 ) -> None: 1446 self.db_pool.simple_delete_txn( 1447 txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} 1448 ) 1449 1450 self.db_pool.simple_insert_many_txn( 1451 txn, 1452 table="device_lists_remote_cache", 1453 values=[ 1454 { 1455 "user_id": user_id, 1456 "device_id": content["device_id"], 1457 "content": json_encoder.encode(content), 1458 } 1459 for content in devices 1460 ], 1461 ) 1462 1463 txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,)) 1464 txn.call_after(self._get_cached_user_device.invalidate, (user_id,)) 1465 txn.call_after( 1466 self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) 1467 ) 1468 1469 self.db_pool.simple_upsert_txn( 1470 txn, 1471 table="device_lists_remote_extremeties", 1472 keyvalues={"user_id": user_id}, 1473 values={"stream_id": stream_id}, 1474 # we don't need to lock, because we can assume we are the only thread 1475 # updating this user's extremity. 1476 lock=False, 1477 ) 1478 1479 async def add_device_change_to_streams( 1480 self, user_id: str, device_ids: Collection[str], hosts: List[str] 1481 ) -> int: 1482 """Persist that a user's devices have been updated, and which hosts 1483 (if any) should be poked. 1484 """ 1485 if not device_ids: 1486 return 1487 1488 async with self._device_list_id_gen.get_next_mult( 1489 len(device_ids) 1490 ) as stream_ids: 1491 await self.db_pool.runInteraction( 1492 "add_device_change_to_stream", 1493 self._add_device_change_to_stream_txn, 1494 user_id, 1495 device_ids, 1496 stream_ids, 1497 ) 1498 1499 if not hosts: 1500 return stream_ids[-1] 1501 1502 context = get_active_span_text_map() 1503 async with self._device_list_id_gen.get_next_mult( 1504 len(hosts) * len(device_ids) 1505 ) as stream_ids: 1506 await self.db_pool.runInteraction( 1507 "add_device_outbound_poke_to_stream", 1508 self._add_device_outbound_poke_to_stream_txn, 1509 user_id, 1510 device_ids, 1511 hosts, 1512 stream_ids, 1513 context, 1514 ) 1515 1516 return stream_ids[-1] 1517 1518 def _add_device_change_to_stream_txn( 1519 self, 1520 txn: LoggingTransaction, 1521 user_id: str, 1522 device_ids: Collection[str], 1523 stream_ids: List[str], 1524 ): 1525 txn.call_after( 1526 self._device_list_stream_cache.entity_has_changed, 1527 user_id, 1528 stream_ids[-1], 1529 ) 1530 1531 min_stream_id = stream_ids[0] 1532 1533 # Delete older entries in the table, as we really only care about 1534 # when the latest change happened. 1535 txn.execute_batch( 1536 """ 1537 DELETE FROM device_lists_stream 1538 WHERE user_id = ? AND device_id = ? AND stream_id < ? 1539 """, 1540 [(user_id, device_id, min_stream_id) for device_id in device_ids], 1541 ) 1542 1543 self.db_pool.simple_insert_many_txn( 1544 txn, 1545 table="device_lists_stream", 1546 values=[ 1547 {"stream_id": stream_id, "user_id": user_id, "device_id": device_id} 1548 for stream_id, device_id in zip(stream_ids, device_ids) 1549 ], 1550 ) 1551 1552 def _add_device_outbound_poke_to_stream_txn( 1553 self, 1554 txn: LoggingTransaction, 1555 user_id: str, 1556 device_ids: Collection[str], 1557 hosts: List[str], 1558 stream_ids: List[str], 1559 context: Dict[str, str], 1560 ): 1561 for host in hosts: 1562 txn.call_after( 1563 self._device_list_federation_stream_cache.entity_has_changed, 1564 host, 1565 stream_ids[-1], 1566 ) 1567 1568 now = self._clock.time_msec() 1569 next_stream_id = iter(stream_ids) 1570 1571 self.db_pool.simple_insert_many_txn( 1572 txn, 1573 table="device_lists_outbound_pokes", 1574 values=[ 1575 { 1576 "destination": destination, 1577 "stream_id": next(next_stream_id), 1578 "user_id": user_id, 1579 "device_id": device_id, 1580 "sent": False, 1581 "ts": now, 1582 "opentracing_context": json_encoder.encode(context) 1583 if whitelisted_homeserver(destination) 1584 else "{}", 1585 } 1586 for destination in hosts 1587 for device_id in device_ids 1588 ], 1589 ) 1590