1# Copyright 2016 OpenMarket Ltd
2# Copyright 2019 New Vector Ltd
3# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16import logging
17from typing import (
18    TYPE_CHECKING,
19    Any,
20    Collection,
21    Dict,
22    Iterable,
23    List,
24    Mapping,
25    Optional,
26    Set,
27    Tuple,
28)
29
30from synapse.api import errors
31from synapse.api.constants import EventTypes
32from synapse.api.errors import (
33    Codes,
34    FederationDeniedError,
35    HttpResponseException,
36    RequestSendFailed,
37    SynapseError,
38)
39from synapse.logging.opentracing import log_kv, set_tag, trace
40from synapse.metrics.background_process_metrics import run_as_background_process
41from synapse.types import (
42    JsonDict,
43    StreamToken,
44    UserID,
45    get_domain_from_id,
46    get_verify_key_from_cross_signing_key,
47)
48from synapse.util import stringutils
49from synapse.util.async_helpers import Linearizer
50from synapse.util.caches.expiringcache import ExpiringCache
51from synapse.util.metrics import measure_func
52from synapse.util.retryutils import NotRetryingDestination
53
54if TYPE_CHECKING:
55    from synapse.server import HomeServer
56
logger: logging.Logger = logging.getLogger(__name__)

# Upper bound on the length of a device display name; longer names are
# rejected by DeviceHandler._check_device_name_length.
MAX_DEVICE_DISPLAY_NAME_LEN: int = 100
60
61
class DeviceWorkerHandler:
    """Read-only queries about users' devices and device list changes.

    `DeviceHandler` extends this class with operations that register, update
    and delete devices.
    """

    def __init__(self, hs: "HomeServer"):
        self.clock = hs.get_clock()
        self.hs = hs
        self.store = hs.get_datastore()
        self.notifier = hs.get_notifier()
        self.state = hs.get_state_handler()
        self.state_store = hs.get_storage().state
        self._auth_handler = hs.get_auth_handler()
        self.server_name = hs.hostname

    @staticmethod
    def _member_state_keys(state_ids: Mapping[Tuple[str, str], str]) -> Set[str]:
        """Return the state keys of all membership entries in a state map.

        Args:
            state_ids: map from (event type, state key) to event ID.

        Returns:
            The state keys (i.e. user IDs) of every `m.room.member` entry.
        """
        return {
            state_key
            for etype, state_key in state_ids
            if etype == EventTypes.Member
        }

    @trace
    async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
        """
        Retrieve the given user's devices

        Args:
            user_id: The user ID to query for devices.
        Returns:
            info on each device, with `last_seen_ts` and `last_seen_ip` filled
            in from the client IP records (None if there is no record).
        """

        set_tag("user_id", user_id)
        device_map = await self.store.get_devices_by_user(user_id)

        ips = await self.store.get_last_client_ip_by_device(user_id, device_id=None)

        devices = list(device_map.values())
        for device in devices:
            _update_device_from_client_ips(device, ips)

        log_kv(device_map)
        return devices

    @trace
    async def get_device(self, user_id: str, device_id: str) -> JsonDict:
        """Retrieve the given device

        Args:
            user_id: The user to get the device from
            device_id: The device to fetch.

        Returns:
            info on the device, with `last_seen_ts` and `last_seen_ip` filled
            in from the client IP records (None if there is no record).
        Raises:
            errors.NotFoundError: if the device was not found
        """
        device = await self.store.get_device(user_id, device_id)
        if device is None:
            raise errors.NotFoundError()

        ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
        _update_device_from_client_ips(device, ips)

        set_tag("device", device)
        set_tag("ips", ips)

        return device

    @trace
    @measure_func("device.get_user_ids_changed")
    async def get_user_ids_changed(
        self, user_id: str, from_token: StreamToken
    ) -> JsonDict:
        """Get list of users that have had the devices updated, or have newly
        joined a room, that `user_id` may be interested in.

        Args:
            user_id: the user whose view of device list changes is computed.
            from_token: the stream token to compute changes since.

        Returns:
            A dict with two lists of user IDs:
              * "changed": users who may have changed devices or newly joined a
                shared room, and who still share a room with `user_id`;
              * "left": users flagged above, or members of rooms `user_id` is no
                longer in, who no longer share any room with `user_id`.
        """

        set_tag("user_id", user_id)
        set_tag("from_token", from_token)
        now_room_key = self.store.get_room_max_token()

        room_ids = await self.store.get_rooms_for_user(user_id)

        # First we check if any devices have changed for users that we share
        # rooms with.
        users_who_share_room = await self.store.get_users_who_share_room_with_user(
            user_id
        )

        tracked_users = set(users_who_share_room)

        # Always tell the user about their own devices
        tracked_users.add(user_id)

        changed = await self.store.get_users_whose_devices_changed(
            from_token.device_list_key, tracked_users
        )

        # Then work out if any users have since joined
        rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)

        member_events = await self.store.get_membership_changes_for_user(
            user_id, from_token.room_key, now_room_key
        )
        rooms_changed.update(event.room_id for event in member_events)

        stream_ordering = from_token.room_key.stream

        possibly_changed = set(changed)
        possibly_left: Set[str] = set()
        for room_id in rooms_changed:
            current_state_ids = await self.store.get_current_state_ids(room_id)

            # The user may have left the room
            # TODO: Check if they actually did or if we were just invited.
            if room_id not in room_ids:
                possibly_left.update(self._member_state_keys(current_state_ids))
                continue

            # Fetch the current state at the time.
            try:
                event_ids = await self.store.get_forward_extremities_for_room_at_stream_ordering(
                    room_id, stream_ordering=stream_ordering
                )
            except errors.StoreError:
                # we have purged the stream_ordering index since the stream
                # ordering: treat it the same as a new room
                event_ids = []

            # special-case for an empty prev state: include all members
            # in the changed list
            if not event_ids:
                log_kv(
                    {"event": "encountered empty previous state", "room_id": room_id}
                )
                possibly_changed.update(self._member_state_keys(current_state_ids))
                continue

            current_member_id = current_state_ids.get((EventTypes.Member, user_id))
            if not current_member_id:
                continue

            # mapping from event_id -> state_dict
            prev_state_ids = await self.state_store.get_state_ids_for_events(event_ids)

            # If our own membership event differs from the one in any of the
            # previous states, we have (re)joined the room, so blindly add all
            # the current members to the "possibly changed" users.
            # current_member_id is non-empty here, so a missing previous
            # membership also compares unequal.
            if any(
                state_dict.get((EventTypes.Member, user_id)) != current_member_id
                for state_dict in prev_state_ids.values()
            ):
                possibly_changed.update(self._member_state_keys(current_state_ids))

            # If there has been any change in membership, include them in the
            # possibly changed list. We'll check if they are joined below,
            # and we're not toooo worried about spuriously adding users.
            for key, event_id in current_state_ids.items():
                etype, state_key = key
                if etype != EventTypes.Member or state_key == user_id:
                    continue

                # check if this member has changed since any of the extremities
                # at the stream_ordering, and add them to the list if so.
                for state_dict in prev_state_ids.values():
                    prev_event_id = state_dict.get(key, None)
                    if not prev_event_id or prev_event_id != event_id:
                        possibly_changed.add(state_key)
                        break

        # Take the intersection of the users whose devices may have changed
        # and those that actually still share a room with the user; everyone
        # else we flagged has possibly left. When both input sets are empty
        # these results are empty too, so no special case is needed.
        possibly_joined = possibly_changed & users_who_share_room
        possibly_left = (possibly_changed | possibly_left) - users_who_share_room

        result = {"changed": list(possibly_joined), "left": list(possibly_left)}

        log_kv(result)

        return result

    async def on_federation_query_user_devices(self, user_id: str) -> JsonDict:
        """Build the response to a federation query for a user's devices.

        Args:
            user_id: the user whose devices are being queried.

        Returns:
            A dict with the user ID, the stream ID returned alongside the
            device keys, the devices themselves, and the user's "master" and
            "self_signing" cross-signing keys.
        """
        stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query(
            user_id
        )
        master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
        self_signing_key = await self.store.get_e2e_cross_signing_key(
            user_id, "self_signing"
        )

        return {
            "user_id": user_id,
            "stream_id": stream_id,
            "devices": devices,
            "master_key": master_key,
            "self_signing_key": self_signing_key,
        }
263
264
class DeviceHandler(DeviceWorkerHandler):
    """Extends `DeviceWorkerHandler` with operations that register, update
    and delete devices, and that announce device list changes to local
    clients and to remote servers.
    """

    def __init__(self, hs: "HomeServer"):
        super().__init__(hs)

        self.federation_sender = hs.get_federation_sender()

        self.device_list_updater = DeviceListUpdater(hs, self)

        federation_registry = hs.get_federation_registry()

        # Device list update EDUs received from remote servers are handed to
        # the DeviceListUpdater.
        federation_registry.register_edu_handler(
            "m.device_list_update", self.device_list_updater.incoming_device_list_update
        )

        hs.get_distributor().observe("user_left_room", self.user_left_room)

    def _check_device_name_length(self, name: Optional[str]) -> None:
        """
        Checks whether a device name is longer than the maximum allowed length.

        Args:
            name: The name of the device. None or empty names always pass.

        Raises:
            SynapseError: if the device name is too long.
        """
        if name and len(name) > MAX_DEVICE_DISPLAY_NAME_LEN:
            raise SynapseError(
                400,
                "Device display name is too long (max %i)"
                % (MAX_DEVICE_DISPLAY_NAME_LEN,),
                errcode=Codes.TOO_LARGE,
            )

    async def check_device_registered(
        self,
        user_id: str,
        device_id: Optional[str],
        initial_device_display_name: Optional[str] = None,
        auth_provider_id: Optional[str] = None,
        auth_provider_session_id: Optional[str] = None,
    ) -> str:
        """
        If the given device has not been registered, register it with the
        supplied display name.

        If no device_id is supplied, we make one up.

        Args:
            user_id:  @user:id
            device_id: device id supplied by client
            initial_device_display_name: device display name from client
            auth_provider_id: The SSO IdP the user used, if any.
            auth_provider_session_id: The session ID (sid) got from the SSO IdP.
        Returns:
            device id (generated if none was supplied)
        Raises:
            SynapseError: if the display name is too long.
            errors.StoreError: if no unused device ID could be generated
                after 5 attempts.
        """

        self._check_device_name_length(initial_device_display_name)

        if device_id is not None:
            new_device = await self.store.store_device(
                user_id=user_id,
                device_id=device_id,
                initial_device_display_name=initial_device_display_name,
                auth_provider_id=auth_provider_id,
                auth_provider_session_id=auth_provider_session_id,
            )
            # Presumably store_device returns a truthy value only when it
            # actually created a new device; only then is there a change to
            # announce.
            if new_device:
                await self.notify_device_update(user_id, [device_id])
            return device_id

        # if the device id is not specified, we'll autogen one, but try a few
        # times in case of a clash.
        for _attempt in range(5):
            new_device_id = stringutils.random_string(10).upper()
            new_device = await self.store.store_device(
                user_id=user_id,
                device_id=new_device_id,
                initial_device_display_name=initial_device_display_name,
                auth_provider_id=auth_provider_id,
                auth_provider_session_id=auth_provider_session_id,
            )
            if new_device:
                await self.notify_device_update(user_id, [new_device_id])
                return new_device_id

        raise errors.StoreError(500, "Couldn't generate a device ID.")

    @trace
    async def delete_device(self, user_id: str, device_id: str) -> None:
        """Delete the given device, along with its access tokens and e2e keys.

        A 404 from the store (no such device) is tolerated: the access tokens
        and e2e keys are still removed and a device list update is still sent.

        Args:
            user_id: The user to delete the device from.
            device_id: The device to delete.
        """

        try:
            await self.store.delete_device(user_id, device_id)
        except errors.StoreError as e:
            if e.code != 404:
                raise
            # no match
            set_tag("error", True)
            log_kv(
                {"reason": "User doesn't have device id.", "device_id": device_id}
            )

        await self._auth_handler.delete_access_tokens_for_user(
            user_id, device_id=device_id
        )

        await self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)

        await self.notify_device_update(user_id, [device_id])

    @trace
    async def delete_all_devices_for_user(
        self, user_id: str, except_device_id: Optional[str] = None
    ) -> None:
        """Delete all of the user's devices

        Args:
            user_id: The user to remove all devices from
            except_device_id: optional device id which should not be deleted
        """
        device_map = await self.store.get_devices_by_user(user_id)
        device_ids = list(device_map)
        if except_device_id is not None:
            device_ids = [d for d in device_ids if d != except_device_id]
        await self.delete_devices(user_id, device_ids)

    async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
        """Delete several devices, along with their access tokens and e2e keys.

        A 404 from the store is tolerated: the remaining clean-up and the
        device list notification still happen.

        Args:
            user_id: The user to delete devices from.
            device_ids: The list of device IDs to delete
        """

        try:
            await self.store.delete_devices(user_id, device_ids)
        except errors.StoreError as e:
            if e.code != 404:
                raise
            # no match
            set_tag("error", True)
            set_tag("reason", "User doesn't have that device id.")

        # Delete access tokens and e2e keys for each device. Not optimised as it is not
        # considered as part of a critical path.
        for device_id in device_ids:
            await self._auth_handler.delete_access_tokens_for_user(
                user_id, device_id=device_id
            )
            await self.store.delete_e2e_keys_by_device(
                user_id=user_id, device_id=device_id
            )

        await self.notify_device_update(user_id, device_ids)

    async def update_device(self, user_id: str, device_id: str, content: dict) -> None:
        """Update the given device

        Args:
            user_id: The user to update devices of.
            device_id: The device to update.
            content: body of update request; only "display_name" is read.

        Raises:
            SynapseError: if the new display name is too long.
            errors.NotFoundError: if the store reports the device does not exist.
        """

        # Reject a new displayname which is too long.
        new_display_name = content.get("display_name")

        self._check_device_name_length(new_display_name)

        # Only a 404 from the update itself means "no such device"; errors
        # raised while notifying about the change must not be reported to the
        # client as a missing device, so the notification is outside the try.
        try:
            await self.store.update_device(
                user_id, device_id, new_display_name=new_display_name
            )
        except errors.StoreError as e:
            if e.code == 404:
                raise errors.NotFoundError()
            raise

        await self.notify_device_update(user_id, [device_id])

    @trace
    @measure_func("notify_device_update")
    async def notify_device_update(
        self, user_id: str, device_ids: Collection[str]
    ) -> None:
        """Notify that a user's device(s) has changed. Pokes the notifier, and
        remote servers if the user is local.

        Args:
            user_id: The Matrix ID of the user whose device list has been updated.
            device_ids: The device IDs that have changed.
        """
        if not device_ids:
            # No changes to notify about, so this is a no-op.
            return

        # Only changes for local users are sent over federation, to every
        # other server that shares a room with the user. The room-sharing
        # lookup is only needed in that case.
        hosts: Set[str] = set()
        if self.hs.is_mine_id(user_id):
            users_who_share_room = await self.store.get_users_who_share_room_with_user(
                user_id
            )
            hosts.update(get_domain_from_id(u) for u in users_who_share_room)
            hosts.discard(self.server_name)

        set_tag("target_hosts", hosts)

        position = await self.store.add_device_change_to_streams(
            user_id, device_ids, list(hosts)
        )

        if not position:
            # This should only happen if there are no updates, so we bail.
            return

        if logger.isEnabledFor(logging.DEBUG):
            for device_id in device_ids:
                logger.debug(
                    "Notifying about update %r/%r, ID: %r", user_id, device_id, position
                )

        room_ids = await self.store.get_rooms_for_user(user_id)

        # specify the user ID too since the user should always get their own device list
        # updates, even if they aren't in any rooms.
        self.notifier.on_new_event(
            "device_list_key", position, users=[user_id], rooms=room_ids
        )

        if hosts:
            logger.info(
                "Sending device list update notif for %r to: %r", user_id, hosts
            )
            for host in hosts:
                self.federation_sender.send_device_messages(host)
                log_kv({"message": "sent device update to host", "host": host})

    async def notify_user_signature_update(
        self, from_user_id: str, user_ids: List[str]
    ) -> None:
        """Notify a user that they have made new signatures of other users.

        Args:
            from_user_id: the user who made the signature
            user_ids: the users IDs that have new signatures
        """

        position = await self.store.add_user_signature_change_to_streams(
            from_user_id, user_ids
        )

        self.notifier.on_new_event("device_list_key", position, users=[from_user_id])

    async def user_left_room(self, user: UserID, room_id: str) -> None:
        """Observer for the distributor's "user_left_room" signal.

        If `user` is no longer in any room, their remote device list is marked
        as unsubscribed in the database.

        Args:
            user: the user who left.
            room_id: the room they left (unused).
        """
        user_id = user.to_string()
        room_ids = await self.store.get_rooms_for_user(user_id)
        if not room_ids:
            # We no longer share rooms with this user, so we'll no longer
            # receive device updates. Mark this in DB.
            await self.store.mark_remote_user_device_list_as_unsubscribed(user_id)

    async def store_dehydrated_device(
        self,
        user_id: str,
        device_data: JsonDict,
        initial_device_display_name: Optional[str] = None,
    ) -> str:
        """Store a dehydrated device for a user.  If the user had a previous
        dehydrated device, it is removed.

        Args:
            user_id: the user that we are storing the device for
            device_data: the dehydrated device information
            initial_device_display_name: The display name to use for the device
        Returns:
            device id of the dehydrated device
        """
        device_id = await self.check_device_registered(
            user_id,
            None,
            initial_device_display_name,
        )
        old_device_id = await self.store.store_dehydrated_device(
            user_id, device_id, device_data
        )
        if old_device_id is not None:
            await self.delete_device(user_id, old_device_id)
        return device_id

    async def get_dehydrated_device(
        self, user_id: str
    ) -> Optional[Tuple[str, JsonDict]]:
        """Retrieve the information for a dehydrated device.

        Args:
            user_id: the user whose dehydrated device we are looking for
        Returns:
            a tuple whose first item is the device ID, and the second item is
            the dehydrated device information, or None if there is none.
        """
        return await self.store.get_dehydrated_device(user_id)

    async def rehydrate_device(
        self, user_id: str, access_token: str, device_id: str
    ) -> dict:
        """Process a rehydration request from the user.

        Args:
            user_id: the user who is rehydrating the device
            access_token: the access token used for the request
            device_id: the ID of the device that will be rehydrated
        Returns:
            a dict containing {"success": True}
        Raises:
            errors.NotFoundError: if `device_id` is not the user's stored
                dehydrated device, or the device currently attached to the
                access token cannot be found.
        """
        success = await self.store.remove_dehydrated_device(user_id, device_id)

        if not success:
            raise errors.NotFoundError()

        # If the dehydrated device was successfully deleted (the device ID
        # matched the stored dehydrated device), then modify the access
        # token to use the dehydrated device's ID and copy the old device
        # display name to the dehydrated device, and destroy the old device
        # ID
        old_device_id = await self.store.set_device_for_access_token(
            access_token, device_id
        )
        old_device = await self.store.get_device(user_id, old_device_id)
        if old_device is None:
            raise errors.NotFoundError()
        await self.store.update_device(
            user_id, device_id, new_display_name=old_device["display_name"]
        )
        # can't call self.delete_device because that will clobber the
        # access token so call the storage layer directly
        await self.store.delete_device(user_id, old_device_id)
        await self.store.delete_e2e_keys_by_device(
            user_id=user_id, device_id=old_device_id
        )

        # tell everyone that the old device is gone and that the dehydrated
        # device has a new display name
        await self.notify_device_update(user_id, [old_device_id, device_id])

        return {"success": True}
620
621
def _update_device_from_client_ips(
    device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
) -> None:
    """Fill in `last_seen_ts` and `last_seen_ip` on `device`, in place.

    Args:
        device: a device dict; its "user_id" and "device_id" keys select the
            matching client IP record.
        client_ips: map from (user_id, device_id) to a record whose "last_seen"
            and "ip" entries are copied across. A device with no record gets
            None for both fields.
    """
    lookup_key = (device["user_id"], device["device_id"])
    last_seen_record = client_ips.get(lookup_key, {})
    device["last_seen_ts"] = last_seen_record.get("last_seen")
    device["last_seen_ip"] = last_seen_record.get("ip")
627
628
629class DeviceListUpdater:
630    "Handles incoming device list updates from federation and updates the DB"
631
632    def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
633        self.store = hs.get_datastore()
634        self.federation = hs.get_federation_client()
635        self.clock = hs.get_clock()
636        self.device_handler = device_handler
637
638        self._remote_edu_linearizer = Linearizer(name="remote_device_list")
639
640        # user_id -> list of updates waiting to be handled.
641        self._pending_updates: Dict[
642            str, List[Tuple[str, str, Iterable[str], JsonDict]]
643        ] = {}
644
645        # Recently seen stream ids. We don't bother keeping these in the DB,
646        # but they're useful to have them about to reduce the number of spurious
647        # resyncs.
648        self._seen_updates: ExpiringCache[str, Set[str]] = ExpiringCache(
649            cache_name="device_update_edu",
650            clock=self.clock,
651            max_len=10000,
652            expiry_ms=30 * 60 * 1000,
653            iterable=True,
654        )
655
656        # Attempt to resync out of sync device lists every 30s.
657        self._resync_retry_in_progress = False
658        self.clock.looping_call(
659            run_as_background_process,
660            30 * 1000,
661            func=self._maybe_retry_device_resync,
662            desc="_maybe_retry_device_resync",
663        )
664
665    @trace
666    async def incoming_device_list_update(
667        self, origin: str, edu_content: JsonDict
668    ) -> None:
669        """Called on incoming device list update from federation. Responsible
670        for parsing the EDU and adding to pending updates list.
671        """
672
673        set_tag("origin", origin)
674        set_tag("edu_content", edu_content)
675        user_id = edu_content.pop("user_id")
676        device_id = edu_content.pop("device_id")
677        stream_id = str(edu_content.pop("stream_id"))  # They may come as ints
678        prev_ids = edu_content.pop("prev_id", [])
679        prev_ids = [str(p) for p in prev_ids]  # They may come as ints
680
681        if get_domain_from_id(user_id) != origin:
682            # TODO: Raise?
683            logger.warning(
684                "Got device list update edu for %r/%r from %r",
685                user_id,
686                device_id,
687                origin,
688            )
689
690            set_tag("error", True)
691            log_kv(
692                {
693                    "message": "Got a device list update edu from a user and "
694                    "device which does not match the origin of the request.",
695                    "user_id": user_id,
696                    "device_id": device_id,
697                }
698            )
699            return
700
701        room_ids = await self.store.get_rooms_for_user(user_id)
702        if not room_ids:
703            # We don't share any rooms with this user. Ignore update, as we
704            # probably won't get any further updates.
705            set_tag("error", True)
706            log_kv(
707                {
708                    "message": "Got an update from a user for which "
709                    "we don't share any rooms",
710                    "other user_id": user_id,
711                }
712            )
713            logger.warning(
714                "Got device list update edu for %r/%r, but don't share a room",
715                user_id,
716                device_id,
717            )
718            return
719
720        logger.debug("Received device list update for %r/%r", user_id, device_id)
721
722        self._pending_updates.setdefault(user_id, []).append(
723            (device_id, stream_id, prev_ids, edu_content)
724        )
725
726        await self._handle_device_updates(user_id)
727
728    @measure_func("_incoming_device_list_update")
729    async def _handle_device_updates(self, user_id: str) -> None:
730        "Actually handle pending updates."
731
732        with (await self._remote_edu_linearizer.queue(user_id)):
733            pending_updates = self._pending_updates.pop(user_id, [])
734            if not pending_updates:
735                # This can happen since we batch updates
736                return
737
738            for device_id, stream_id, prev_ids, _ in pending_updates:
739                logger.debug(
740                    "Handling update %r/%r, ID: %r, prev: %r ",
741                    user_id,
742                    device_id,
743                    stream_id,
744                    prev_ids,
745                )
746
747            # Given a list of updates we check if we need to resync. This
748            # happens if we've missed updates.
749            resync = await self._need_to_do_resync(user_id, pending_updates)
750
751            if logger.isEnabledFor(logging.INFO):
752                logger.info(
753                    "Received device list update for %s, requiring resync: %s. Devices: %s",
754                    user_id,
755                    resync,
756                    ", ".join(u[0] for u in pending_updates),
757                )
758
759            if resync:
760                await self.user_device_resync(user_id)
761            else:
762                # Simply update the single device, since we know that is the only
763                # change (because of the single prev_id matching the current cache)
764                for device_id, stream_id, _, content in pending_updates:
765                    await self.store.update_remote_device_list_cache_entry(
766                        user_id, device_id, content, stream_id
767                    )
768
769                await self.device_handler.notify_device_update(
770                    user_id, [device_id for device_id, _, _, _ in pending_updates]
771                )
772
773                self._seen_updates.setdefault(user_id, set()).update(
774                    stream_id for _, stream_id, _, _ in pending_updates
775                )
776
777    async def _need_to_do_resync(
778        self, user_id: str, updates: Iterable[Tuple[str, str, Iterable[str], JsonDict]]
779    ) -> bool:
780        """Given a list of updates for a user figure out if we need to do a full
781        resync, or whether we have enough data that we can just apply the delta.
782        """
783        seen_updates: Set[str] = self._seen_updates.get(user_id, set())
784
785        extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)
786
787        logger.debug("Current extremity for %r: %r", user_id, extremity)
788
789        stream_id_in_updates = set()  # stream_ids in updates list
790        for _, stream_id, prev_ids, _ in updates:
791            if not prev_ids:
792                # We always do a resync if there are no previous IDs
793                return True
794
795            for prev_id in prev_ids:
796                if prev_id == extremity:
797                    continue
798                elif prev_id in seen_updates:
799                    continue
800                elif prev_id in stream_id_in_updates:
801                    continue
802                else:
803                    return True
804
805            stream_id_in_updates.add(stream_id)
806
807        return False
808
809    @trace
810    async def _maybe_retry_device_resync(self) -> None:
811        """Retry to resync device lists that are out of sync, except if another retry is
812        in progress.
813        """
814        if self._resync_retry_in_progress:
815            return
816
817        try:
818            # Prevent another call of this function to retry resyncing device lists so
819            # we don't send too many requests.
820            self._resync_retry_in_progress = True
821            # Get all of the users that need resyncing.
822            need_resync = await self.store.get_user_ids_requiring_device_list_resync()
823            # Iterate over the set of user IDs.
824            for user_id in need_resync:
825                try:
826                    # Try to resync the current user's devices list.
827                    result = await self.user_device_resync(
828                        user_id=user_id,
829                        mark_failed_as_stale=False,
830                    )
831
832                    # user_device_resync only returns a result if it managed to
833                    # successfully resync and update the database. Updating the table
834                    # of users requiring resync isn't necessary here as
835                    # user_device_resync already does it (through
836                    # self.store.update_remote_device_list_cache).
837                    if result:
838                        logger.debug(
839                            "Successfully resynced the device list for %s",
840                            user_id,
841                        )
842                except Exception as e:
843                    # If there was an issue resyncing this user, e.g. if the remote
844                    # server sent a malformed result, just log the error instead of
845                    # aborting all the subsequent resyncs.
846                    logger.debug(
847                        "Could not resync the device list for %s: %s",
848                        user_id,
849                        e,
850                    )
851        finally:
852            # Allow future calls to retry resyncinc out of sync device lists.
853            self._resync_retry_in_progress = False
854
855    async def user_device_resync(
856        self, user_id: str, mark_failed_as_stale: bool = True
857    ) -> Optional[JsonDict]:
858        """Fetches all devices for a user and updates the device cache with them.
859
860        Args:
861            user_id: The user's id whose device_list will be updated.
862            mark_failed_as_stale: Whether to mark the user's device list as stale
863                if the attempt to resync failed.
864        Returns:
865            A dict with device info as under the "devices" in the result of this
866            request:
867            https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
868        """
869        logger.debug("Attempting to resync the device list for %s", user_id)
870        log_kv({"message": "Doing resync to update device list."})
871        # Fetch all devices for the user.
872        origin = get_domain_from_id(user_id)
873        try:
874            result = await self.federation.query_user_devices(origin, user_id)
875        except NotRetryingDestination:
876            if mark_failed_as_stale:
877                # Mark the remote user's device list as stale so we know we need to retry
878                # it later.
879                await self.store.mark_remote_user_device_cache_as_stale(user_id)
880
881            return None
882        except (RequestSendFailed, HttpResponseException) as e:
883            logger.warning(
884                "Failed to handle device list update for %s: %s",
885                user_id,
886                e,
887            )
888
889            if mark_failed_as_stale:
890                # Mark the remote user's device list as stale so we know we need to retry
891                # it later.
892                await self.store.mark_remote_user_device_cache_as_stale(user_id)
893
894            # We abort on exceptions rather than accepting the update
895            # as otherwise synapse will 'forget' that its device list
896            # is out of date. If we bail then we will retry the resync
897            # next time we get a device list update for this user_id.
898            # This makes it more likely that the device lists will
899            # eventually become consistent.
900            return None
901        except FederationDeniedError as e:
902            set_tag("error", True)
903            log_kv({"reason": "FederationDeniedError"})
904            logger.info(e)
905            return None
906        except Exception as e:
907            set_tag("error", True)
908            log_kv(
909                {"message": "Exception raised by federation request", "exception": e}
910            )
911            logger.exception("Failed to handle device list update for %s", user_id)
912
913            if mark_failed_as_stale:
914                # Mark the remote user's device list as stale so we know we need to retry
915                # it later.
916                await self.store.mark_remote_user_device_cache_as_stale(user_id)
917
918            return None
919        log_kv({"result": result})
920        stream_id = result["stream_id"]
921        devices = result["devices"]
922
923        # Get the master key and the self-signing key for this user if provided in the
924        # response (None if not in the response).
925        # The response will not contain the user signing key, as this key is only used by
926        # its owner, thus it doesn't make sense to send it over federation.
927        master_key = result.get("master_key")
928        self_signing_key = result.get("self_signing_key")
929
930        ignore_devices = False
931        # If the remote server has more than ~1000 devices for this user
932        # we assume that something is going horribly wrong (e.g. a bot
933        # that logs in and creates a new device every time it tries to
934        # send a message).  Maintaining lots of devices per user in the
935        # cache can cause serious performance issues as if this request
936        # takes more than 60s to complete, internal replication from the
937        # inbound federation worker to the synapse master may time out
938        # causing the inbound federation to fail and causing the remote
939        # server to retry, causing a DoS.  So in this scenario we give
940        # up on storing the total list of devices and only handle the
941        # delta instead.
942        if len(devices) > 1000:
943            logger.warning(
944                "Ignoring device list snapshot for %s as it has >1K devs (%d)",
945                user_id,
946                len(devices),
947            )
948            devices = []
949            ignore_devices = True
950        else:
951            cached_devices = await self.store.get_cached_devices_for_user(user_id)
952            if cached_devices == {d["device_id"]: d for d in devices}:
953                logging.info(
954                    "Skipping device list resync for %s, as our cache matches already",
955                    user_id,
956                )
957                devices = []
958                ignore_devices = True
959
960        for device in devices:
961            logger.debug(
962                "Handling resync update %r/%r, ID: %r",
963                user_id,
964                device["device_id"],
965                stream_id,
966            )
967
968        if not ignore_devices:
969            await self.store.update_remote_device_list_cache(
970                user_id, devices, stream_id
971            )
972        # mark the cache as valid, whether or not we actually processed any device
973        # list updates.
974        await self.store.mark_remote_user_device_cache_as_valid(user_id)
975        device_ids = [device["device_id"] for device in devices]
976
977        # Handle cross-signing keys.
978        cross_signing_device_ids = await self.process_cross_signing_key_update(
979            user_id,
980            master_key,
981            self_signing_key,
982        )
983        device_ids = device_ids + cross_signing_device_ids
984
985        if device_ids:
986            await self.device_handler.notify_device_update(user_id, device_ids)
987
988        # We clobber the seen updates since we've re-synced from a given
989        # point.
990        self._seen_updates[user_id] = {stream_id}
991
992        return result
993
994    async def process_cross_signing_key_update(
995        self,
996        user_id: str,
997        master_key: Optional[JsonDict],
998        self_signing_key: Optional[JsonDict],
999    ) -> List[str]:
1000        """Process the given new master and self-signing key for the given remote user.
1001
1002        Args:
1003            user_id: The ID of the user these keys are for.
1004            master_key: The dict of the cross-signing master key as returned by the
1005                remote server.
1006            self_signing_key: The dict of the cross-signing self-signing key as returned
1007                by the remote server.
1008
1009        Return:
1010            The device IDs for the given keys.
1011        """
1012        device_ids = []
1013
1014        current_keys_map = await self.store.get_e2e_cross_signing_keys_bulk([user_id])
1015        current_keys = current_keys_map.get(user_id) or {}
1016
1017        if master_key and master_key != current_keys.get("master"):
1018            await self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
1019            _, verify_key = get_verify_key_from_cross_signing_key(master_key)
1020            # verify_key is a VerifyKey from signedjson, which uses
1021            # .version to denote the portion of the key ID after the
1022            # algorithm and colon, which is the device ID
1023            device_ids.append(verify_key.version)
1024        if self_signing_key and self_signing_key != current_keys.get("self_signing"):
1025            await self.store.set_e2e_cross_signing_key(
1026                user_id, "self_signing", self_signing_key
1027            )
1028            _, verify_key = get_verify_key_from_cross_signing_key(self_signing_key)
1029            device_ids.append(verify_key.version)
1030
1031        return device_ids
1032