1# Copyright 2016 OpenMarket Ltd
2# Copyright 2019 New Vector Ltd
3# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16import abc
17import logging
18from typing import (
19    TYPE_CHECKING,
20    Any,
21    Collection,
22    Dict,
23    Iterable,
24    List,
25    Optional,
26    Set,
27    Tuple,
28)
29
30from synapse.api.errors import Codes, StoreError
31from synapse.logging.opentracing import (
32    get_active_span_text_map,
33    set_tag,
34    trace,
35    whitelisted_homeserver,
36)
37from synapse.metrics.background_process_metrics import wrap_as_background_process
38from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
39from synapse.storage.database import (
40    DatabasePool,
41    LoggingDatabaseConnection,
42    LoggingTransaction,
43    make_tuple_comparison_clause,
44)
45from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
46from synapse.util import json_decoder, json_encoder
47from synapse.util.caches.descriptors import cached, cachedList
48from synapse.util.caches.lrucache import LruCache
49from synapse.util.iterutils import batch_iter
50from synapse.util.stringutils import shortstr
51
52if TYPE_CHECKING:
53    from synapse.server import HomeServer
54
logger = logging.getLogger(__name__)

# Name of a database background update (presumably registered further down this
# file — confirm) that drops the non-unique indexes on the device list streams.
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
    "drop_device_list_streams_non_unique_indexes"
)

# Name of the background update that removes duplicate outbound device pokes.
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
62
63
64class DeviceWorkerStore(SQLBaseStore):
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        # Only the worker designated to run background tasks prunes old
        # outbound device pokes; runs hourly (interval is in milliseconds).
        if hs.config.worker.run_background_tasks:
            self._clock.looping_call(
                self._prune_old_outbound_device_pokes, 60 * 60 * 1000
            )
77
78    async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int:
79        """Retrieve number of all devices of given users.
80        Only returns number of devices that are not marked as hidden.
81
82        Args:
83            user_ids: The IDs of the users which owns devices
84        Returns:
85            Number of devices of this users.
86        """
87
88        def count_devices_by_users_txn(txn, user_ids):
89            sql = """
90                SELECT count(*)
91                FROM devices
92                WHERE
93                    hidden = '0' AND
94            """
95
96            clause, args = make_in_list_sql_clause(
97                txn.database_engine, "user_id", user_ids
98            )
99
100            txn.execute(sql + clause, args)
101            return txn.fetchone()[0]
102
103        if not user_ids:
104            return 0
105
106        return await self.db_pool.runInteraction(
107            "count_devices_by_users", count_devices_by_users_txn, user_ids
108        )
109
110    async def get_device(
111        self, user_id: str, device_id: str
112    ) -> Optional[Dict[str, Any]]:
113        """Retrieve a device. Only returns devices that are not marked as
114        hidden.
115
116        Args:
117            user_id: The ID of the user which owns the device
118            device_id: The ID of the device to retrieve
119        Returns:
120            A dict containing the device information, or `None` if the device does not
121            exist.
122        """
123        return await self.db_pool.simple_select_one(
124            table="devices",
125            keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
126            retcols=("user_id", "device_id", "display_name"),
127            desc="get_device",
128            allow_none=True,
129        )
130
131    async def get_device_opt(
132        self, user_id: str, device_id: str
133    ) -> Optional[Dict[str, Any]]:
134        """Retrieve a device. Only returns devices that are not marked as
135        hidden.
136
137        Args:
138            user_id: The ID of the user which owns the device
139            device_id: The ID of the device to retrieve
140        Returns:
141            A dict containing the device information, or None if the device does not exist.
142        """
143        return await self.db_pool.simple_select_one(
144            table="devices",
145            keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
146            retcols=("user_id", "device_id", "display_name"),
147            desc="get_device",
148            allow_none=True,
149        )
150
151    async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
152        """Retrieve all of a user's registered devices. Only returns devices
153        that are not marked as hidden.
154
155        Args:
156            user_id:
157        Returns:
158            A mapping from device_id to a dict containing "device_id", "user_id"
159            and "display_name" for each device.
160        """
161        devices = await self.db_pool.simple_select_list(
162            table="devices",
163            keyvalues={"user_id": user_id, "hidden": False},
164            retcols=("user_id", "device_id", "display_name"),
165            desc="get_devices_by_user",
166        )
167
168        return {d["device_id"]: d for d in devices}
169
170    async def get_devices_by_auth_provider_session_id(
171        self, auth_provider_id: str, auth_provider_session_id: str
172    ) -> List[Dict[str, Any]]:
173        """Retrieve the list of devices associated with a SSO IdP session ID.
174
175        Args:
176            auth_provider_id: The SSO IdP ID as defined in the server config
177            auth_provider_session_id: The session ID within the IdP
178        Returns:
179            A list of dicts containing the device_id and the user_id of each device
180        """
181        return await self.db_pool.simple_select_list(
182            table="device_auth_providers",
183            keyvalues={
184                "auth_provider_id": auth_provider_id,
185                "auth_provider_session_id": auth_provider_session_id,
186            },
187            retcols=("user_id", "device_id"),
188            desc="get_devices_by_auth_provider_session_id",
189        )
190
    @trace
    async def get_device_updates_by_remote(
        self, destination: str, from_stream_id: int, limit: int
    ) -> Tuple[int, List[Tuple[str, JsonDict]]]:
        """Get a stream of device updates to send to the given remote server.

        Args:
            destination: The host the device updates are intended for
            from_stream_id: The minimum stream_id to filter updates by, exclusive
            limit: Maximum number of device updates to return

        Returns:
            - The current stream id (i.e. the stream id of the last update included
              in the response); and
            - The list of updates, where each update is a pair of EDU type and
              EDU contents.
        """
        now_stream_id = self.get_device_stream_token()

        # Cheap in-memory check: if nothing has changed for this destination
        # since `from_stream_id`, skip the database entirely.
        has_changed = self._device_list_federation_stream_cache.has_entity_changed(
            destination, int(from_stream_id)
        )
        if not has_changed:
            return now_stream_id, []

        updates = await self.db_pool.runInteraction(
            "get_device_updates_by_remote",
            self._get_device_updates_by_remote_txn,
            destination,
            from_stream_id,
            now_stream_id,
            limit,
        )

        # We need to ensure `updates` doesn't grow too big.
        # Currently: `len(updates) <= limit`.

        # Return an empty list if there are no updates
        if not updates:
            return now_stream_id, []

        # get the cross-signing keys of the users in the list, so that we can
        # determine which of the device changes were cross-signing keys
        users = {r[0] for r in updates}
        master_key_by_user: Dict[str, Dict[str, Any]] = {}
        self_signing_key_by_user: Dict[str, Dict[str, Any]] = {}
        for user in users:
            cross_signing_key = await self.get_e2e_cross_signing_key(user, "master")
            if cross_signing_key:
                key_id, verify_key = get_verify_key_from_cross_signing_key(
                    cross_signing_key
                )
                # verify_key is a VerifyKey from signedjson, which uses
                # .version to denote the portion of the key ID after the
                # algorithm and colon, which is the device ID
                master_key_by_user[user] = {
                    "key_info": cross_signing_key,
                    "device_id": verify_key.version,
                }

            cross_signing_key = await self.get_e2e_cross_signing_key(
                user, "self_signing"
            )
            if cross_signing_key:
                key_id, verify_key = get_verify_key_from_cross_signing_key(
                    cross_signing_key
                )
                self_signing_key_by_user[user] = {
                    "key_info": cross_signing_key,
                    "device_id": verify_key.version,
                }

        # Perform the equivalent of a GROUP BY
        #
        # Iterate through the updates list and copy non-duplicate
        # (user_id, device_id) entries into a map, with the value being
        # the max stream_id across each set of duplicate entries
        #
        # maps (user_id, device_id) -> (stream_id, opentracing_context)
        #
        # opentracing_context contains the opentracing metadata for the request
        # that created the poke
        #
        # The most recent request's opentracing_context is used as the
        # context which created the Edu.

        # This is the stream ID that we will return for the consumer to resume
        # following this stream later.
        last_processed_stream_id = from_stream_id

        query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]] = {}
        cross_signing_keys_by_user: Dict[str, Dict[str, Any]] = {}
        for user_id, device_id, update_stream_id, update_context in updates:
            # Calculate the remaining length budget.
            # Note that, for now, each entry in `cross_signing_keys_by_user`
            # gives rise to two device updates in the result, so those cost twice
            # as much (and are the whole reason we need to separately calculate
            # the budget; we know len(updates) <= limit otherwise!)
            # N.B. len() on dicts is cheap since they store their size.
            remaining_length_budget = limit - (
                len(query_map) + 2 * len(cross_signing_keys_by_user)
            )
            assert remaining_length_budget >= 0

            # An update is a cross-signing key update if its device ID matches
            # the device ID embedded in one of the user's cross-signing keys.
            is_master_key_update = (
                user_id in master_key_by_user
                and device_id == master_key_by_user[user_id]["device_id"]
            )
            is_self_signing_key_update = (
                user_id in self_signing_key_by_user
                and device_id == self_signing_key_by_user[user_id]["device_id"]
            )

            is_cross_signing_key_update = (
                is_master_key_update or is_self_signing_key_update
            )

            if (
                is_cross_signing_key_update
                and user_id not in cross_signing_keys_by_user
            ):
                # This will give rise to 2 device updates.
                # If we don't have the budget, stop here!
                if remaining_length_budget < 2:
                    break

            if is_master_key_update:
                result = cross_signing_keys_by_user.setdefault(user_id, {})
                result["master_key"] = master_key_by_user[user_id]["key_info"]
            elif is_self_signing_key_update:
                result = cross_signing_keys_by_user.setdefault(user_id, {})
                result["self_signing_key"] = self_signing_key_by_user[user_id][
                    "key_info"
                ]
            else:
                key = (user_id, device_id)

                if key not in query_map and remaining_length_budget < 1:
                    # We don't have space for a new entry
                    break

                previous_update_stream_id, _ = query_map.get(key, (0, None))

                if update_stream_id > previous_update_stream_id:
                    # FIXME If this overwrites an older update, this discards the
                    #  previous OpenTracing context.
                    #  It might make it harder to track down issues using OpenTracing.
                    #  If there's a good reason why it doesn't matter, a comment here
                    #  about that would not hurt.
                    query_map[key] = (update_stream_id, update_context)

            # As this update has been added to the response, advance the stream
            # position.
            last_processed_stream_id = update_stream_id

        # In the worst case scenario, each update is for a distinct user and is
        # added either to the query_map or to cross_signing_keys_by_user,
        # but not both:
        # len(query_map) + len(cross_signing_keys_by_user) <= len(updates) here,
        # so len(query_map) + len(cross_signing_keys_by_user) <= limit.

        results = await self._get_device_update_edus_by_remote(
            destination, from_stream_id, query_map
        )

        # len(results) <= len(query_map) here,
        # so len(results) + len(cross_signing_keys_by_user) <= limit.

        # Add the updated cross-signing keys to the results list
        for user_id, result in cross_signing_keys_by_user.items():
            result["user_id"] = user_id
            results.append(("m.signing_key_update", result))
            # also send the unstable version
            # FIXME: remove this when enough servers have upgraded
            #        and remove the length budgeting above.
            results.append(("org.matrix.signing_key_update", result))

        return last_processed_stream_id, results
369
370    def _get_device_updates_by_remote_txn(
371        self,
372        txn: LoggingTransaction,
373        destination: str,
374        from_stream_id: int,
375        now_stream_id: int,
376        limit: int,
377    ) -> List[Tuple[str, str, int, Optional[str]]]:
378        """Return device update information for a given remote destination
379
380        Args:
381            txn: The transaction to execute
382            destination: The host the device updates are intended for
383            from_stream_id: The minimum stream_id to filter updates by, exclusive
384            now_stream_id: The maximum stream_id to filter updates by, inclusive
385            limit: Maximum number of device updates to return
386
387        Returns:
388            List: List of device update tuples:
389                - user_id
390                - device_id
391                - stream_id
392                - opentracing_context
393        """
394        # get the list of device updates that need to be sent
395        sql = """
396            SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
397            WHERE destination = ? AND ? < stream_id AND stream_id <= ?
398            ORDER BY stream_id
399            LIMIT ?
400        """
401        txn.execute(sql, (destination, from_stream_id, now_stream_id, limit))
402
403        return list(txn)
404
    async def _get_device_update_edus_by_remote(
        self,
        destination: str,
        from_stream_id: int,
        query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]],
    ) -> List[Tuple[str, dict]]:
        """Returns a list of device update EDUs as well as E2EE keys

        Args:
            destination: The host the device updates are intended for
            from_stream_id: The minimum stream_id to filter updates by, exclusive
            query_map: Dictionary mapping (user_id, device_id) to
                (update stream_id, the relevant json-encoded opentracing context)

        Returns:
            List of objects representing a device update EDU.

        Postconditions:
            The returned list has a length not exceeding that of the query_map:
                len(result) <= len(query_map)
        """
        devices = (
            await self.get_e2e_device_keys_and_signatures(
                # Because these are (user_id, device_id) tuples with all
                # device_ids not being None, the returned list's length will not
                # exceed that of query_map.
                query_map.keys(),
                include_all_devices=True,
                include_deleted_devices=True,
            )
            if query_map
            else {}
        )

        results: List[Tuple[str, dict]] = []
        for user_id, user_devices in devices.items():
            # The prev_id for the first row is always the last row before
            # `from_stream_id`
            prev_id = await self._get_last_device_update_for_remote_user(
                destination, user_id, from_stream_id
            )

            # make sure we go through the devices in stream order
            # (the lambda closes over this iteration's `user_id`, which is safe
            # because sorted() consumes it immediately within the iteration)
            device_ids = sorted(
                user_devices.keys(),
                key=lambda i: query_map[(user_id, i)][0],
            )

            for device_id in device_ids:
                device = user_devices[device_id]
                stream_id, opentracing_context = query_map[(user_id, device_id)]
                result = {
                    "user_id": user_id,
                    "device_id": device_id,
                    "prev_id": [prev_id] if prev_id else [],
                    "stream_id": stream_id,
                    "org.matrix.opentracing_context": opentracing_context,
                }

                # Chain the EDUs: each subsequent update's prev_id points at
                # the stream ID of the update before it.
                prev_id = stream_id

                if device is not None:
                    keys = device.keys
                    if keys:
                        result["keys"] = keys

                    device_display_name = device.display_name
                    if device_display_name:
                        result["device_display_name"] = device_display_name
                else:
                    # A None device means it has been deleted.
                    result["deleted"] = True

                results.append(("m.device_list_update", result))

        return results
480
481    async def _get_last_device_update_for_remote_user(
482        self, destination: str, user_id: str, from_stream_id: int
483    ) -> int:
484        def f(txn):
485            prev_sent_id_sql = """
486                SELECT coalesce(max(stream_id), 0) as stream_id
487                FROM device_lists_outbound_last_success
488                WHERE destination = ? AND user_id = ? AND stream_id <= ?
489            """
490            txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
491            rows = txn.fetchall()
492            return rows[0][0]
493
494        return await self.db_pool.runInteraction(
495            "get_last_device_update_for_remote_user", f
496        )
497
498    async def mark_as_sent_devices_by_remote(
499        self, destination: str, stream_id: int
500    ) -> None:
501        """Mark that updates have successfully been sent to the destination."""
502        await self.db_pool.runInteraction(
503            "mark_as_sent_devices_by_remote",
504            self._mark_as_sent_devices_by_remote_txn,
505            destination,
506            stream_id,
507        )
508
    def _mark_as_sent_devices_by_remote_txn(
        self, txn: LoggingTransaction, destination: str, stream_id: int
    ) -> None:
        """Transaction: mark outbound pokes up to `stream_id` as sent to
        `destination`, then delete them."""
        # We update the device_lists_outbound_last_success with the successfully
        # poked users.
        sql = """
            SELECT user_id, coalesce(max(o.stream_id), 0)
            FROM device_lists_outbound_pokes as o
            WHERE destination = ? AND o.stream_id <= ?
            GROUP BY user_id
        """
        txn.execute(sql, (destination, stream_id))
        rows = txn.fetchall()

        # Record, per user, the highest stream ID now known to have been sent.
        # NB: `stream_id` in the value_values generator below is the per-row
        # maximum from `rows`, deliberately shadowing the `stream_id` argument.
        self.db_pool.simple_upsert_many_txn(
            txn=txn,
            table="device_lists_outbound_last_success",
            key_names=("destination", "user_id"),
            key_values=((destination, user_id) for user_id, _ in rows),
            value_names=("stream_id",),
            value_values=((stream_id,) for _, stream_id in rows),
        )

        # Delete all sent outbound pokes
        sql = """
            DELETE FROM device_lists_outbound_pokes
            WHERE destination = ? AND stream_id <= ?
        """
        txn.execute(sql, (destination, stream_id))
538
539    async def add_user_signature_change_to_streams(
540        self, from_user_id: str, user_ids: List[str]
541    ) -> int:
542        """Persist that a user has made new signatures
543
544        Args:
545            from_user_id: the user who made the signatures
546            user_ids: the users who were signed
547
548        Returns:
549            The new stream ID.
550        """
551
552        async with self._device_list_id_gen.get_next() as stream_id:
553            await self.db_pool.runInteraction(
554                "add_user_sig_change_to_streams",
555                self._add_user_signature_change_txn,
556                from_user_id,
557                user_ids,
558                stream_id,
559            )
560        return stream_id
561
562    def _add_user_signature_change_txn(
563        self,
564        txn: LoggingTransaction,
565        from_user_id: str,
566        user_ids: List[str],
567        stream_id: int,
568    ) -> None:
569        txn.call_after(
570            self._user_signature_stream_cache.entity_has_changed,
571            from_user_id,
572            stream_id,
573        )
574        self.db_pool.simple_insert_txn(
575            txn,
576            "user_signature_stream",
577            values={
578                "stream_id": stream_id,
579                "from_user_id": from_user_id,
580                "user_ids": json_encoder.encode(user_ids),
581            },
582        )
583
    @abc.abstractmethod
    def get_device_stream_token(self) -> int:
        """Get the current stream id from the _device_list_id_gen

        Abstract: must be provided by a concrete subclass.
        """
        ...
588
    @trace
    async def get_user_devices_from_cache(
        self, query_list: List[Tuple[str, str]]
    ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
        """Get the devices (and keys if any) for remote users from the cache.

        Args:
            query_list: List of (user_id, device_ids), if device_ids is
                falsey then return all device ids for that user.

        Returns:
            A tuple of (user_ids_not_in_cache, results_map), where
            user_ids_not_in_cache is a set of user_ids and results_map is a
            mapping of user_id -> device_id -> device_info.
        """
        user_ids = {user_id for user_id, _ in query_list}
        user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids))

        # We go and check if any of the users need to have their device lists
        # resynced. If they do then we remove them from the cached list.
        users_needing_resync = await self.get_user_ids_requiring_device_list_resync(
            user_ids
        )
        # A user counts as cached if we hold a stream ID for them and they are
        # not awaiting a resync.
        user_ids_in_cache = {
            user_id for user_id, stream_id in user_map.items() if stream_id
        } - users_needing_resync
        user_ids_not_in_cache = user_ids - user_ids_in_cache

        results: Dict[str, Dict[str, JsonDict]] = {}
        for user_id, device_id in query_list:
            if user_id not in user_ids_in_cache:
                continue

            # NOTE(review): these lookups are awaited sequentially, one per
            # query_list entry; fine while they hit the @cached layers, but a
            # long list of misses means many serial round trips — confirm.
            if device_id:
                device = await self._get_cached_user_device(user_id, device_id)
                results.setdefault(user_id, {})[device_id] = device
            else:
                results[user_id] = await self.get_cached_devices_for_user(user_id)

        set_tag("in_cache", results)
        set_tag("not_in_cache", user_ids_not_in_cache)

        return user_ids_not_in_cache, results
632
633    @cached(num_args=2, tree=True)
634    async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict:
635        content = await self.db_pool.simple_select_one_onecol(
636            table="device_lists_remote_cache",
637            keyvalues={"user_id": user_id, "device_id": device_id},
638            retcol="content",
639            desc="_get_cached_user_device",
640        )
641        return db_to_json(content)
642
643    @cached()
644    async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]:
645        devices = await self.db_pool.simple_select_list(
646            table="device_lists_remote_cache",
647            keyvalues={"user_id": user_id},
648            retcols=("device_id", "content"),
649            desc="get_cached_devices_for_user",
650        )
651        return {
652            device["device_id"]: db_to_json(device["content"]) for device in devices
653        }
654
    async def get_users_whose_devices_changed(
        self, from_key: int, user_ids: Iterable[str]
    ) -> Set[str]:
        """Get set of users whose devices have changed since `from_key` that
        are in the given list of user_ids.

        Args:
            from_key: The device lists stream token
            user_ids: The user IDs to query for devices.

        Returns:
            The set of user_ids whose devices have changed since `from_key`
        """

        # Get set of users who *may* have changed. Users not in the returned
        # list have definitely not changed.
        to_check = self._device_list_stream_cache.get_entities_changed(
            user_ids, from_key
        )

        if not to_check:
            return set()

        def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
            changes: Set[str] = set()

            sql = """
                SELECT DISTINCT user_id FROM device_lists_stream
                WHERE stream_id > ?
                AND
            """

            # Query in chunks of 100 users to bound the size of the generated
            # IN-list clause (and the number of bound parameters).
            for chunk in batch_iter(to_check, 100):
                clause, args = make_in_list_sql_clause(
                    txn.database_engine, "user_id", chunk
                )
                txn.execute(sql + clause, (from_key,) + tuple(args))
                # Each row is a 1-tuple; the trailing comma in `user_id,`
                # unpacks it.
                changes.update(user_id for user_id, in txn)

            return changes

        return await self.db_pool.runInteraction(
            "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
        )
699
700    async def get_users_whose_signatures_changed(
701        self, user_id: str, from_key: int
702    ) -> Set[str]:
703        """Get the users who have new cross-signing signatures made by `user_id` since
704        `from_key`.
705
706        Args:
707            user_id: the user who made the signatures
708            from_key: The device lists stream token
709
710        Returns:
711            A set of user IDs with updated signatures.
712        """
713
714        if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
715            sql = """
716                SELECT DISTINCT user_ids FROM user_signature_stream
717                WHERE from_user_id = ? AND stream_id > ?
718            """
719            rows = await self.db_pool.execute(
720                "get_users_whose_signatures_changed", None, sql, user_id, from_key
721            )
722            return {user for row in rows for user in db_to_json(row[0])}
723        else:
724            return set()
725
    async def get_all_device_list_changes_for_remotes(
        self, instance_name: str, last_id: int, current_id: int, limit: int
    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
        """Get updates for device lists replication stream.

        Args:
            instance_name: The writer we want to fetch updates from. Unused
                here since there is only ever one writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exists
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data
        """

        # Nothing can have happened between identical tokens.
        if last_id == current_id:
            return [], current_id, False

        def _get_all_device_list_changes_for_remotes(txn: LoggingTransaction):
            # This query Does The Right Thing where it'll correctly apply the
            # bounds to the inner queries.
            sql = """
                SELECT stream_id, entity FROM (
                    SELECT stream_id, user_id AS entity FROM device_lists_stream
                    UNION ALL
                    SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
                ) AS e
                WHERE ? < stream_id AND stream_id <= ?
                LIMIT ?
            """

            txn.execute(sql, (last_id, current_id, limit))
            # Each update pairs the stream ID with the rest of the row.
            updates = [(row[0], row[1:]) for row in txn]
            limited = False
            upto_token = current_id
            if len(updates) >= limit:
                # We hit the limit, so there may be more rows to fetch; tell
                # the caller to resume from the last stream ID we did return.
                upto_token = updates[-1][0]
                limited = True

            return updates, upto_token, limited

        return await self.db_pool.runInteraction(
            "get_all_device_list_changes_for_remotes",
            _get_all_device_list_changes_for_remotes,
        )
780
    @cached(max_entries=10000)
    async def get_device_list_last_stream_id_for_remote(
        self, user_id: str
    ) -> Optional[Any]:
        """Get the last stream_id we got for a user. May be None if we haven't
        got any information for them.

        Args:
            user_id: The remote user in question.

        Returns:
            The last stream ID cached for that user, or None if unknown.
        """
        # NB: "extremeties" is the table's actual (misspelled) name — don't fix.
        return await self.db_pool.simple_select_one_onecol(
            table="device_lists_remote_extremeties",
            keyvalues={"user_id": user_id},
            retcol="stream_id",
            desc="get_device_list_last_stream_id_for_remote",
            allow_none=True,
        )
795
    @cachedList(
        cached_method_name="get_device_list_last_stream_id_for_remote",
        list_name="user_ids",
    )
    async def get_device_list_last_stream_id_for_remotes(
        self, user_ids: Iterable[str]
    ) -> Dict[str, Optional[Any]]:
        """Batched version of get_device_list_last_stream_id_for_remote.

        Args:
            user_ids: the remote users to look up

        Returns:
            A map from user_id to the last stream_id we have for them, with
            None for users we hold no information about.
        """
        rows = await self.db_pool.simple_select_many_batch(
            table="device_lists_remote_extremeties",
            column="user_id",
            iterable=user_ids,
            retcols=("user_id", "stream_id"),
            desc="get_device_list_last_stream_id_for_remotes",
        )

        # Default every requested user to None, then overwrite the entries for
        # which we actually found a row.
        results = {user_id: None for user_id in user_ids}
        results.update({row["user_id"]: row["stream_id"] for row in rows})

        return results
813
814    async def get_user_ids_requiring_device_list_resync(
815        self,
816        user_ids: Optional[Collection[str]] = None,
817    ) -> Set[str]:
818        """Given a list of remote users return the list of users that we
819        should resync the device lists for. If None is given instead of a list,
820        return every user that we should resync the device lists for.
821
822        Returns:
823            The IDs of users whose device lists need resync.
824        """
825        if user_ids:
826            rows = await self.db_pool.simple_select_many_batch(
827                table="device_lists_remote_resync",
828                column="user_id",
829                iterable=user_ids,
830                retcols=("user_id",),
831                desc="get_user_ids_requiring_device_list_resync_with_iterable",
832            )
833        else:
834            rows = await self.db_pool.simple_select_list(
835                table="device_lists_remote_resync",
836                keyvalues=None,
837                retcols=("user_id",),
838                desc="get_user_ids_requiring_device_list_resync",
839            )
840
841        return {row["user_id"] for row in rows}
842
    async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None:
        """Records that the server has reason to believe the cache of the devices
        for the remote user is out of date.

        Args:
            user_id: the remote user whose cached device list may be stale
        """
        await self.db_pool.simple_upsert(
            table="device_lists_remote_resync",
            keyvalues={"user_id": user_id},
            values={},
            # Only set the timestamp on first insertion: re-marking a user who
            # is already flagged leaves the original added_ts in place.
            insertion_values={"added_ts": self._clock.time_msec()},
            desc="mark_remote_user_device_cache_as_stale",
        )
854
    async def mark_remote_user_device_cache_as_valid(self, user_id: str) -> None:
        """Record that the remote user's device cache is up to date again.

        Removes the database entry that says we need to resync the user's
        devices, after a resync has completed.
        """
        await self.db_pool.simple_delete(
            table="device_lists_remote_resync",
            keyvalues={"user_id": user_id},
            desc="mark_remote_user_device_cache_as_valid",
        )
862
863    async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
864        """Mark that we no longer track device lists for remote user."""
865
866        def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
867            self.db_pool.simple_delete_txn(
868                txn,
869                table="device_lists_remote_extremeties",
870                keyvalues={"user_id": user_id},
871            )
872            self._invalidate_cache_and_stream(
873                txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
874            )
875
876        await self.db_pool.runInteraction(
877            "mark_remote_user_device_list_as_unsubscribed",
878            _mark_remote_user_device_list_as_unsubscribed_txn,
879        )
880
881    async def get_dehydrated_device(
882        self, user_id: str
883    ) -> Optional[Tuple[str, JsonDict]]:
884        """Retrieve the information for a dehydrated device.
885
886        Args:
887            user_id: the user whose dehydrated device we are looking for
888        Returns:
889            a tuple whose first item is the device ID, and the second item is
890            the dehydrated device information
891        """
892        # FIXME: make sure device ID still exists in devices table
893        row = await self.db_pool.simple_select_one(
894            table="dehydrated_devices",
895            keyvalues={"user_id": user_id},
896            retcols=["device_id", "device_data"],
897            allow_none=True,
898        )
899        return (
900            (row["device_id"], json_decoder.decode(row["device_data"])) if row else None
901        )
902
903    def _store_dehydrated_device_txn(
904        self, txn, user_id: str, device_id: str, device_data: str
905    ) -> Optional[str]:
906        old_device_id = self.db_pool.simple_select_one_onecol_txn(
907            txn,
908            table="dehydrated_devices",
909            keyvalues={"user_id": user_id},
910            retcol="device_id",
911            allow_none=True,
912        )
913        self.db_pool.simple_upsert_txn(
914            txn,
915            table="dehydrated_devices",
916            keyvalues={"user_id": user_id},
917            values={"device_id": device_id, "device_data": device_data},
918        )
919        return old_device_id
920
921    async def store_dehydrated_device(
922        self, user_id: str, device_id: str, device_data: JsonDict
923    ) -> Optional[str]:
924        """Store a dehydrated device for a user.
925
926        Args:
927            user_id: the user that we are storing the device for
928            device_id: the ID of the dehydrated device
929            device_data: the dehydrated device information
930        Returns:
931            device id of the user's previous dehydrated device, if any
932        """
933        return await self.db_pool.runInteraction(
934            "store_dehydrated_device_txn",
935            self._store_dehydrated_device_txn,
936            user_id,
937            device_id,
938            json_encoder.encode(device_data),
939        )
940
    async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool:
        """Remove a dehydrated device.

        Args:
            user_id: the user that the dehydrated device belongs to
            device_id: the ID of the dehydrated device

        Returns:
            True if a matching row was deleted, False if no such device existed.
        """
        count = await self.db_pool.simple_delete(
            "dehydrated_devices",
            {"user_id": user_id, "device_id": device_id},
            desc="remove_dehydrated_device",
        )
        return count >= 1
954
    @wrap_as_background_process("prune_old_outbound_device_pokes")
    async def _prune_old_outbound_device_pokes(
        self, prune_age: int = 24 * 60 * 60 * 1000
    ) -> None:
        """Delete old entries out of the device_lists_outbound_pokes to ensure
        that we don't fill up due to dead servers.

        Normally, we try to send device updates as a delta since a previous known point:
        this is done by setting the prev_id in the m.device_list_update EDU. However,
        for that to work, we have to have a complete record of each change to
        each device, which can add up to quite a lot of data.

        An alternative mechanism is that, if the remote server sees that it has missed
        an entry in the stream_id sequence for a given user, it will request a full
        list of that user's devices. Hence, we can reduce the amount of data we have to
        store (and transmit in some future transaction), by clearing almost everything
        for a given destination out of the database, and having the remote server
        resync.

        All we need to do is make sure we keep at least one row for each
        (user, destination) pair, to remind us to send a m.device_list_update EDU for
        that user when the destination comes back. It doesn't matter which device
        we keep.

        Args:
            prune_age: minimum age in milliseconds an entry must reach before
                it may be pruned. Defaults to 24 hours.
        """
        # Entries whose ts is older than this cutoff are candidates for pruning.
        yesterday = self._clock.time_msec() - prune_age

        def _prune_txn(txn):
            # look for (user, destination) pairs which have an update older than
            # the cutoff.
            #
            # For each pair, we also need to know the most recent stream_id, and
            # an arbitrary device_id at that stream_id.
            select_sql = """
            SELECT
                dlop1.destination,
                dlop1.user_id,
                MAX(dlop1.stream_id) AS stream_id,
                (SELECT MIN(dlop2.device_id) AS device_id FROM
                    device_lists_outbound_pokes dlop2
                    WHERE dlop2.destination = dlop1.destination AND
                      dlop2.user_id=dlop1.user_id AND
                      dlop2.stream_id=MAX(dlop1.stream_id)
                )
            FROM device_lists_outbound_pokes dlop1
                GROUP BY destination, user_id
                HAVING min(ts) < ? AND count(*) > 1
            """

            txn.execute(select_sql, (yesterday,))
            rows = txn.fetchall()

            if not rows:
                # Nothing old enough (or no duplicates): nothing to prune.
                return

            logger.info(
                "Pruning old outbound device list updates for %i users/destinations: %s",
                len(rows),
                shortstr((row[0], row[1]) for row in rows),
            )

            # we want to keep the update with the highest stream_id for each user.
            #
            # there might be more than one update (with different device_ids) with the
            # same stream_id, so we also delete all but one rows with the max stream id.
            delete_sql = """
                DELETE FROM device_lists_outbound_pokes
                WHERE destination = ? AND user_id = ? AND (
                    stream_id < ? OR
                    (stream_id = ? AND device_id != ?)
                )
            """
            count = 0
            for (destination, user_id, stream_id, device_id) in rows:
                txn.execute(
                    delete_sql, (destination, user_id, stream_id, stream_id, device_id)
                )
                count += txn.rowcount

            # Since we've deleted unsent deltas, we need to remove the entry
            # of last successful sent so that the prev_ids are correctly set.
            sql = """
                DELETE FROM device_lists_outbound_last_success
                WHERE destination = ? AND user_id = ?
            """
            txn.execute_batch(sql, ((row[0], row[1]) for row in rows))

            logger.info("Pruned %d device list outbound pokes", count)

        await self.db_pool.runInteraction(
            "_prune_old_outbound_device_pokes",
            _prune_txn,
        )
1047
1048
class DeviceBackgroundUpdateStore(SQLBaseStore):
    """Background updates for the device-list tables.

    Registers index builds (including the unique indexes that replace older
    non-unique ones) and one-off data clean-ups that run asynchronously after
    an upgrade.
    """

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self.db_pool.updates.register_background_index_update(
            "device_lists_stream_idx",
            index_name="device_lists_stream_user_id",
            table="device_lists_stream",
            columns=["user_id", "device_id"],
        )

        # create a unique index on device_lists_remote_cache
        self.db_pool.updates.register_background_index_update(
            "device_lists_remote_cache_unique_idx",
            index_name="device_lists_remote_cache_unique_id",
            table="device_lists_remote_cache",
            columns=["user_id", "device_id"],
            unique=True,
        )

        # And one on device_lists_remote_extremeties
        self.db_pool.updates.register_background_index_update(
            "device_lists_remote_extremeties_unique_idx",
            index_name="device_lists_remote_extremeties_unique_idx",
            table="device_lists_remote_extremeties",
            columns=["user_id"],
            unique=True,
        )

        # once they complete, we can remove the old non-unique indexes.
        self.db_pool.updates.register_background_update_handler(
            DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES,
            self._drop_device_list_streams_non_unique_indexes,
        )

        # clear out duplicate device list outbound pokes
        self.db_pool.updates.register_background_update_handler(
            BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
            self._remove_duplicate_outbound_pokes,
        )

        # a pair of background updates that were added during the 1.14 release cycle,
        # but replaced with 58/06dlols_unique_idx.py
        self.db_pool.updates.register_noop_background_update(
            "device_lists_outbound_last_success_unique_idx",
        )
        self.db_pool.updates.register_noop_background_update(
            "drop_device_lists_outbound_last_success_non_unique_idx",
        )

    async def _drop_device_list_streams_non_unique_indexes(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """Drop the old non-unique indexes, now replaced by unique ones.

        Runs in a single step: drops both indexes, ends the background update
        and reports one item processed.
        """

        def f(conn):
            txn = conn.cursor()
            txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
            txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
            txn.close()

        await self.db_pool.runWithConnection(f)
        await self.db_pool.updates._end_background_update(
            DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES
        )
        return 1

    async def _remove_duplicate_outbound_pokes(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        # for some reason, we have accumulated duplicate entries in
        # device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less
        # efficient.
        #
        # For each duplicate, we delete all the existing rows and put one back.

        KEY_COLS = ["stream_id", "destination", "user_id", "device_id"]
        last_row = progress.get(
            "last_row",
            {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
        )

        def _txn(txn):
            # Resume from where the previous batch left off, using a tuple
            # comparison over the key columns.
            clause, args = make_tuple_comparison_clause(
                [(x, last_row[x]) for x in KEY_COLS]
            )
            sql = """
                SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
                FROM device_lists_outbound_pokes
                WHERE %s
                GROUP BY %s
                HAVING count(*) > 1
                ORDER BY %s
                LIMIT ?
                """ % (
                clause,  # WHERE
                ",".join(KEY_COLS),  # GROUP BY
                ",".join(KEY_COLS),  # ORDER BY
            )
            txn.execute(sql, args + [batch_size])
            rows = self.db_pool.cursor_to_dict(txn)

            # `row` retains the last processed duplicate after the loop, so it
            # can be recorded as the resume point for the next batch.
            row = None
            for row in rows:
                self.db_pool.simple_delete_txn(
                    txn,
                    "device_lists_outbound_pokes",
                    {x: row[x] for x in KEY_COLS},
                )

                # Re-insert a single row, marked as unsent so it gets re-sent.
                row["sent"] = False
                self.db_pool.simple_insert_txn(
                    txn,
                    "device_lists_outbound_pokes",
                    row,
                )

            if row:
                self.db_pool.updates._background_update_progress_txn(
                    txn,
                    BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
                    {"last_row": row},
                )

            return len(rows)

        rows = await self.db_pool.runInteraction(
            BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn
        )

        # An empty batch means we have processed all duplicates: finish up.
        if not rows:
            await self.db_pool.updates._end_background_update(
                BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES
            )

        return rows
1184
1185
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
    """Full device store: read methods plus the writer paths below."""

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        # Map of (user_id, device_id) -> bool. If there is an entry that implies
        # the device exists.
        self.device_id_exists_cache = LruCache(
            cache_name="device_id_exists", max_size=10000
        )
1200
    async def store_device(
        self,
        user_id: str,
        device_id: str,
        initial_device_display_name: Optional[str],
        auth_provider_id: Optional[str] = None,
        auth_provider_session_id: Optional[str] = None,
    ) -> bool:
        """Ensure the given device is known; add it to the store if not

        Args:
            user_id: id of user associated with the device
            device_id: id of device
            initial_device_display_name: initial displayname of the device.
                Ignored if device exists.
            auth_provider_id: The SSO IdP the user used, if any.
            auth_provider_session_id: The session ID (sid) got from a OIDC login.

        Returns:
            Whether the device was inserted or an existing device existed with that ID.

        Raises:
            StoreError: if the device is already in use
        """
        key = (user_id, device_id)
        # Fast path: we have previously confirmed this device exists, so skip
        # the database entirely.
        if self.device_id_exists_cache.get(key, None):
            return False

        try:
            # values={} means an existing row is left untouched; display name
            # and hidden flag are only written on first insertion.
            inserted = await self.db_pool.simple_upsert(
                "devices",
                keyvalues={
                    "user_id": user_id,
                    "device_id": device_id,
                },
                values={},
                insertion_values={
                    "display_name": initial_device_display_name,
                    "hidden": False,
                },
                desc="store_device",
            )
            if not inserted:
                # if the device already exists, check if it's a real device, or
                # if the device ID is reserved by something else
                hidden = await self.db_pool.simple_select_one_onecol(
                    "devices",
                    keyvalues={"user_id": user_id, "device_id": device_id},
                    retcol="hidden",
                )
                if hidden:
                    raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)

            # Record the SSO session that created this device, if any.
            if auth_provider_id and auth_provider_session_id:
                await self.db_pool.simple_insert(
                    "device_auth_providers",
                    values={
                        "user_id": user_id,
                        "device_id": device_id,
                        "auth_provider_id": auth_provider_id,
                        "auth_provider_session_id": auth_provider_session_id,
                    },
                    desc="store_device_auth_provider",
                )

            self.device_id_exists_cache.set(key, True)
            return inserted
        except StoreError:
            # Re-raise StoreErrors (e.g. the hidden-device case above) as-is.
            raise
        except Exception as e:
            # Log argument types alongside values to help diagnose failures
            # caused by callers passing unexpected types.
            logger.error(
                "store_device with device_id=%s(%r) user_id=%s(%r)"
                " display_name=%s(%r) failed: %s",
                type(device_id).__name__,
                device_id,
                type(user_id).__name__,
                user_id,
                type(initial_device_display_name).__name__,
                initial_device_display_name,
                e,
            )
            raise StoreError(500, "Problem storing device.")
1283
    async def delete_device(self, user_id: str, device_id: str) -> None:
        """Delete a device and its device_inbox.

        Thin wrapper around `delete_devices` for a single device.

        Args:
            user_id: The ID of the user which owns the device
            device_id: The ID of the device to delete
        """

        await self.delete_devices(user_id, [device_id])
1293
1294    async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
1295        """Deletes several devices.
1296
1297        Args:
1298            user_id: The ID of the user which owns the devices
1299            device_ids: The IDs of the devices to delete
1300        """
1301
1302        def _delete_devices_txn(txn: LoggingTransaction) -> None:
1303            self.db_pool.simple_delete_many_txn(
1304                txn,
1305                table="devices",
1306                column="device_id",
1307                values=device_ids,
1308                keyvalues={"user_id": user_id, "hidden": False},
1309            )
1310
1311            self.db_pool.simple_delete_many_txn(
1312                txn,
1313                table="device_inbox",
1314                column="device_id",
1315                values=device_ids,
1316                keyvalues={"user_id": user_id},
1317            )
1318
1319            self.db_pool.simple_delete_many_txn(
1320                txn,
1321                table="device_auth_providers",
1322                column="device_id",
1323                values=device_ids,
1324                keyvalues={"user_id": user_id},
1325            )
1326
1327        await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
1328        for device_id in device_ids:
1329            self.device_id_exists_cache.invalidate((user_id, device_id))
1330
1331    async def update_device(
1332        self, user_id: str, device_id: str, new_display_name: Optional[str] = None
1333    ) -> None:
1334        """Update a device. Only updates the device if it is not marked as
1335        hidden.
1336
1337        Args:
1338            user_id: The ID of the user which owns the device
1339            device_id: The ID of the device to update
1340            new_display_name: new displayname for device; None to leave unchanged
1341        Raises:
1342            StoreError: if the device is not found
1343        """
1344        updates = {}
1345        if new_display_name is not None:
1346            updates["display_name"] = new_display_name
1347        if not updates:
1348            return None
1349        await self.db_pool.simple_update_one(
1350            table="devices",
1351            keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
1352            updatevalues=updates,
1353            desc="update_device",
1354        )
1355
    async def update_remote_device_list_cache_entry(
        self, user_id: str, device_id: str, content: JsonDict, stream_id: str
    ) -> None:
        """Updates a single device in the cache of a remote user's devicelist.

        Note: assumes that we are the only thread that can be updating this user's
        device list.

        Args:
            user_id: User to update device list for
            device_id: ID of device being updated
            content: new data on this device
            stream_id: the version of the device list
        """
        # NOTE(review): stream_id is annotated `str` here but the sibling
        # `update_remote_device_list_cache` uses `int` — confirm which type
        # callers actually pass and align the annotations.
        await self.db_pool.runInteraction(
            "update_remote_device_list_cache_entry",
            self._update_remote_device_list_cache_entry_txn,
            user_id,
            device_id,
            content,
            stream_id,
        )
1378
    def _update_remote_device_list_cache_entry_txn(
        self,
        txn: LoggingTransaction,
        user_id: str,
        device_id: str,
        content: JsonDict,
        stream_id: str,
    ) -> None:
        """Transaction half of `update_remote_device_list_cache_entry`.

        Deletes the cached device if the update marks it deleted, otherwise
        upserts the new content; then bumps the user's extremity stream_id.
        """
        # NOTE(review): stream_id annotated `str`, but the bulk-update sibling
        # uses `int` — verify against callers.
        if content.get("deleted"):
            self.db_pool.simple_delete_txn(
                txn,
                table="device_lists_remote_cache",
                keyvalues={"user_id": user_id, "device_id": device_id},
            )

            # The device no longer exists, so drop the existence-cache entry.
            txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id))
        else:
            self.db_pool.simple_upsert_txn(
                txn,
                table="device_lists_remote_cache",
                keyvalues={"user_id": user_id, "device_id": device_id},
                values={"content": json_encoder.encode(content)},
                # we don't need to lock, because we assume we are the only thread
                # updating this user's devices.
                lock=False,
            )

        # Invalidate the per-user device caches once the transaction commits.
        txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id))
        txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,))
        txn.call_after(
            self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
        )

        self.db_pool.simple_upsert_txn(
            txn,
            table="device_lists_remote_extremeties",
            keyvalues={"user_id": user_id},
            values={"stream_id": stream_id},
            # again, we can assume we are the only thread updating this user's
            # extremity.
            lock=False,
        )
1421
    async def update_remote_device_list_cache(
        self, user_id: str, devices: List[dict], stream_id: int
    ) -> None:
        """Replace the entire cache of the remote user's devices.

        Note: assumes that we are the only thread that can be updating this user's
        device list.

        Args:
            user_id: User to update device list for
            devices: list of device objects supplied over federation
            stream_id: the version of the device list
        """
        # The delete-and-repopulate happens inside a single transaction, so
        # readers never observe a half-replaced cache.
        await self.db_pool.runInteraction(
            "update_remote_device_list_cache",
            self._update_remote_device_list_cache_txn,
            user_id,
            devices,
            stream_id,
        )
1442
    def _update_remote_device_list_cache_txn(
        self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
    ) -> None:
        """Transaction half of `update_remote_device_list_cache`.

        Wipes all cached devices for the user, inserts the new set, and
        records the new extremity stream_id.
        """
        # Drop the existing cache for this user before re-populating it.
        self.db_pool.simple_delete_txn(
            txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
        )

        self.db_pool.simple_insert_many_txn(
            txn,
            table="device_lists_remote_cache",
            values=[
                {
                    "user_id": user_id,
                    "device_id": content["device_id"],
                    "content": json_encoder.encode(content),
                }
                for content in devices
            ],
        )

        # Invalidate the per-user device caches once the transaction commits.
        txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,))
        txn.call_after(self._get_cached_user_device.invalidate, (user_id,))
        txn.call_after(
            self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
        )

        self.db_pool.simple_upsert_txn(
            txn,
            table="device_lists_remote_extremeties",
            keyvalues={"user_id": user_id},
            values={"stream_id": stream_id},
            # we don't need to lock, because we can assume we are the only thread
            # updating this user's extremity.
            lock=False,
        )
1478
    async def add_device_change_to_streams(
        self, user_id: str, device_ids: Collection[str], hosts: List[str]
    ) -> Optional[int]:
        """Persist that a user's devices have been updated, and which hosts
        (if any) should be poked.

        Args:
            user_id: the user whose devices changed
            device_ids: the devices that changed
            hosts: remote servers to be notified of the change

        Returns:
            The latest stream ID that was allocated, or None if `device_ids`
            was empty (in which case nothing is written).
        """
        if not device_ids:
            # Nothing changed: no rows to write, no stream ID to return.
            return

        # One stream ID per changed device for the local change stream.
        async with self._device_list_id_gen.get_next_mult(
            len(device_ids)
        ) as stream_ids:
            await self.db_pool.runInteraction(
                "add_device_change_to_stream",
                self._add_device_change_to_stream_txn,
                user_id,
                device_ids,
                stream_ids,
            )

        if not hosts:
            return stream_ids[-1]

        context = get_active_span_text_map()
        # One stream ID per (host, device) pair for the outbound pokes.
        async with self._device_list_id_gen.get_next_mult(
            len(hosts) * len(device_ids)
        ) as stream_ids:
            await self.db_pool.runInteraction(
                "add_device_outbound_poke_to_stream",
                self._add_device_outbound_poke_to_stream_txn,
                user_id,
                device_ids,
                hosts,
                stream_ids,
                context,
            )

        return stream_ids[-1]
1517
    def _add_device_change_to_stream_txn(
        self,
        txn: LoggingTransaction,
        user_id: str,
        device_ids: Collection[str],
        stream_ids: List[int],
    ) -> None:
        """Insert one device_lists_stream row per changed device.

        Older rows for the same (user, device) are removed first, since only
        the latest change matters for the stream.
        """
        # Notify the stream cache of the change once the transaction commits.
        txn.call_after(
            self._device_list_stream_cache.entity_has_changed,
            user_id,
            stream_ids[-1],
        )

        min_stream_id = stream_ids[0]

        # Delete older entries in the table, as we really only care about
        # when the latest change happened.
        txn.execute_batch(
            """
            DELETE FROM device_lists_stream
            WHERE user_id = ? AND device_id = ? AND stream_id < ?
            """,
            [(user_id, device_id, min_stream_id) for device_id in device_ids],
        )

        self.db_pool.simple_insert_many_txn(
            txn,
            table="device_lists_stream",
            values=[
                {"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
                for stream_id, device_id in zip(stream_ids, device_ids)
            ],
        )
1551
    def _add_device_outbound_poke_to_stream_txn(
        self,
        txn: LoggingTransaction,
        user_id: str,
        device_ids: Collection[str],
        hosts: List[str],
        stream_ids: List[int],
        context: Dict[str, str],
    ) -> None:
        """Insert one outbound poke row per (host, device) pair.

        The caller must supply exactly len(hosts) * len(device_ids) stream IDs;
        each inserted row consumes one.
        """
        # Notify the federation stream cache once the transaction commits.
        for host in hosts:
            txn.call_after(
                self._device_list_federation_stream_cache.entity_has_changed,
                host,
                stream_ids[-1],
            )

        now = self._clock.time_msec()
        # Iterator pairing each (destination, device) row with a unique stream ID.
        next_stream_id = iter(stream_ids)

        self.db_pool.simple_insert_many_txn(
            txn,
            table="device_lists_outbound_pokes",
            values=[
                {
                    "destination": destination,
                    "stream_id": next(next_stream_id),
                    "user_id": user_id,
                    "device_id": device_id,
                    "sent": False,
                    "ts": now,
                    # Only propagate tracing context to whitelisted servers.
                    "opentracing_context": json_encoder.encode(context)
                    if whitelisted_homeserver(destination)
                    else "{}",
                }
                for destination in hosts
                for device_id in device_ids
            ],
        )
1590