1# Copyright 2014-2016 OpenMarket Ltd
2# Copyright 2017-2018 New Vector Ltd
3# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16import logging
17import random
18import re
19from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
20
21import attr
22
23from synapse.api.constants import UserTypes
24from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
25from synapse.metrics.background_process_metrics import wrap_as_background_process
26from synapse.storage.database import (
27    DatabasePool,
28    LoggingDatabaseConnection,
29    LoggingTransaction,
30)
31from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
32from synapse.storage.databases.main.stats import StatsStore
33from synapse.storage.types import Cursor
34from synapse.storage.util.id_generators import IdGenerator
35from synapse.storage.util.sequence import build_sequence_generator
36from synapse.types import UserID, UserInfo
37from synapse.util.caches.descriptors import cached
38
39if TYPE_CHECKING:
40    from synapse.server import HomeServer
41
# Thirty minutes, expressed in milliseconds.
THIRTY_MINUTES_IN_MS = 1000 * 60 * 30

logger = logging.getLogger(__name__)
45
46
class ExternalIDReuseException(Exception):
    """Raised when an external ID cannot be recorded for a user because that
    external ID is already assigned to another user."""
52
53
@attr.s(frozen=True, slots=True)
class TokenLookupResult:
    """Result of looking up an access token.

    Instances are immutable (`frozen=True`) and use slots.

    Attributes:
        user_id: The user that this token authenticates as.
        is_guest: Whether `user_id` is a guest account.
        shadow_banned: Whether `user_id` is shadow-banned.
        token_id: The ID of the access token looked up.
        device_id: The device associated with the token, if any.
        valid_until_ms: The timestamp the token expires, if any.
        token_owner: The "owner" of the token. This is either the same as the
            user, or a server admin who is logged in as the user. When not
            given explicitly it defaults to `user_id`.
        token_used: True if this token was used at least once in a request.
            This field can be out of date since `get_user_by_access_token` is
            cached.
    """

    user_id = attr.ib(type=str)
    is_guest = attr.ib(type=bool, default=False)
    shadow_banned = attr.ib(type=bool, default=False)
    token_id = attr.ib(type=Optional[int], default=None)
    device_id = attr.ib(type=Optional[str], default=None)
    valid_until_ms = attr.ib(type=Optional[int], default=None)
    # No `default=` here: the default is supplied by `_default_token_owner`
    # below, which allows it to depend on `user_id`.
    token_owner = attr.ib(type=str)
    token_used = attr.ib(type=bool, default=False)

    # Make the token owner default to the user ID, which is the common case.
    @token_owner.default
    def _default_token_owner(self) -> str:
        """Return `user_id`, used as the default value of `token_owner`."""
        return self.user_id
85
86
@attr.s(auto_attribs=True, frozen=True, slots=True)
class RefreshTokenLookupResult:
    """Result of looking up a refresh token.

    An immutable (`frozen=True`), slotted record. All fields are required
    positional/keyword arguments; none have defaults.
    """

    user_id: str
    """The user this token belongs to."""

    device_id: str
    """The device associated with this refresh token."""

    token_id: int
    """The ID of this refresh token."""

    next_token_id: Optional[int]
    """The ID of the refresh token which replaced this one."""

    has_next_refresh_token_been_refreshed: bool
    """True if the next refresh token was used for another refresh."""

    has_next_access_token_been_used: bool
    """True if the next access token was already used at least once."""

    expiry_ts: Optional[int]
    """The time at which the refresh token expires and can not be used.
    If None, the refresh token doesn't expire."""

    ultimate_session_expiry_ts: Optional[int]
    """The time at which the session comes to an end and can no longer be
    refreshed.
    If None, the session can be refreshed indefinitely."""
117
118
119class RegistrationWorkerStore(CacheInvalidationWorkerStore):
120    def __init__(
121        self,
122        database: DatabasePool,
123        db_conn: LoggingDatabaseConnection,
124        hs: "HomeServer",
125    ):
126        super().__init__(database, db_conn, hs)
127
128        self.config = hs.config
129
130        # Note: we don't check this sequence for consistency as we'd have to
131        # call `find_max_generated_user_id_localpart` each time, which is
132        # expensive if there are many entries.
133        self._user_id_seq = build_sequence_generator(
134            db_conn,
135            database.engine,
136            find_max_generated_user_id_localpart,
137            "user_id_seq",
138            table=None,
139            id_column=None,
140        )
141
142        self._account_validity_enabled = (
143            hs.config.account_validity.account_validity_enabled
144        )
145        self._account_validity_period = None
146        self._account_validity_startup_job_max_delta = None
147        if self._account_validity_enabled:
148            self._account_validity_period = (
149                hs.config.account_validity.account_validity_period
150            )
151            self._account_validity_startup_job_max_delta = (
152                hs.config.account_validity.account_validity_startup_job_max_delta
153            )
154
155            if hs.config.worker.run_background_tasks:
156                self._clock.call_later(
157                    0.0,
158                    self._set_expiration_date_when_missing,
159                )
160
161        # Create a background job for culling expired 3PID validity tokens
162        if hs.config.worker.run_background_tasks:
163            self._clock.looping_call(
164                self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
165            )
166
167    @cached()
168    async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
169        """Deprecated: use get_userinfo_by_id instead"""
170        return await self.db_pool.simple_select_one(
171            table="users",
172            keyvalues={"name": user_id},
173            retcols=[
174                "name",
175                "password_hash",
176                "is_guest",
177                "admin",
178                "consent_version",
179                "consent_server_notice_sent",
180                "appservice_id",
181                "creation_ts",
182                "user_type",
183                "deactivated",
184                "shadow_banned",
185            ],
186            allow_none=True,
187            desc="get_user_by_id",
188        )
189
190    async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
191        """Get a UserInfo object for a user by user ID.
192
193        Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed,
194        this method should be cached.
195
196        Args:
197             user_id: The user to fetch user info for.
198        Returns:
199            `UserInfo` object if user found, otherwise `None`.
200        """
201        user_data = await self.get_user_by_id(user_id)
202        if not user_data:
203            return None
204        return UserInfo(
205            appservice_id=user_data["appservice_id"],
206            consent_server_notice_sent=user_data["consent_server_notice_sent"],
207            consent_version=user_data["consent_version"],
208            creation_ts=user_data["creation_ts"],
209            is_admin=bool(user_data["admin"]),
210            is_deactivated=bool(user_data["deactivated"]),
211            is_guest=bool(user_data["is_guest"]),
212            is_shadow_banned=bool(user_data["shadow_banned"]),
213            user_id=UserID.from_string(user_data["name"]),
214            user_type=user_data["user_type"],
215        )
216
217    async def is_trial_user(self, user_id: str) -> bool:
218        """Checks if user is in the "trial" period, i.e. within the first
219        N days of registration defined by `mau_trial_days` config
220
221        Args:
222            user_id: The user to check for trial status.
223        """
224
225        info = await self.get_user_by_id(user_id)
226        if not info:
227            return False
228
229        now = self._clock.time_msec()
230        trial_duration_ms = self.config.server.mau_trial_days * 24 * 60 * 60 * 1000
231        is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
232        return is_trial
233
234    @cached()
235    async def get_user_by_access_token(self, token: str) -> Optional[TokenLookupResult]:
236        """Get a user from the given access token.
237
238        Args:
239            token: The access token of a user.
240        Returns:
241            None, if the token did not match, otherwise a `TokenLookupResult`
242        """
243        return await self.db_pool.runInteraction(
244            "get_user_by_access_token", self._query_for_auth, token
245        )
246
247    @cached()
248    async def get_expiration_ts_for_user(self, user_id: str) -> Optional[int]:
249        """Get the expiration timestamp for the account bearing a given user ID.
250
251        Args:
252            user_id: The ID of the user.
253        Returns:
254            None, if the account has no expiration timestamp, otherwise int
255            representation of the timestamp (as a number of milliseconds since epoch).
256        """
257        return await self.db_pool.simple_select_one_onecol(
258            table="account_validity",
259            keyvalues={"user_id": user_id},
260            retcol="expiration_ts_ms",
261            allow_none=True,
262            desc="get_expiration_ts_for_user",
263        )
264
265    async def is_account_expired(self, user_id: str, current_ts: int) -> bool:
266        """
267        Returns whether an user account is expired.
268
269        Args:
270            user_id: The user's ID
271            current_ts: The current timestamp
272
273        Returns:
274            Whether the user account has expired
275        """
276        expiration_ts = await self.get_expiration_ts_for_user(user_id)
277        return expiration_ts is not None and current_ts >= expiration_ts
278
279    async def set_account_validity_for_user(
280        self,
281        user_id: str,
282        expiration_ts: int,
283        email_sent: bool,
284        renewal_token: Optional[str] = None,
285        token_used_ts: Optional[int] = None,
286    ) -> None:
287        """Updates the account validity properties of the given account, with the
288        given values.
289
290        Args:
291            user_id: ID of the account to update properties for.
292            expiration_ts: New expiration date, as a timestamp in milliseconds
293                since epoch.
294            email_sent: True means a renewal email has been sent for this account
295                and there's no need to send another one for the current validity
296                period.
297            renewal_token: Renewal token the user can use to extend the validity
298                of their account. Defaults to no token.
299            token_used_ts: A timestamp of when the current token was used to renew
300                the account.
301        """
302
303        def set_account_validity_for_user_txn(txn):
304            self.db_pool.simple_update_txn(
305                txn=txn,
306                table="account_validity",
307                keyvalues={"user_id": user_id},
308                updatevalues={
309                    "expiration_ts_ms": expiration_ts,
310                    "email_sent": email_sent,
311                    "renewal_token": renewal_token,
312                    "token_used_ts_ms": token_used_ts,
313                },
314            )
315            self._invalidate_cache_and_stream(
316                txn, self.get_expiration_ts_for_user, (user_id,)
317            )
318
319        await self.db_pool.runInteraction(
320            "set_account_validity_for_user", set_account_validity_for_user_txn
321        )
322
323    async def set_renewal_token_for_user(
324        self, user_id: str, renewal_token: str
325    ) -> None:
326        """Defines a renewal token for a given user, and clears the token_used timestamp.
327
328        Args:
329            user_id: ID of the user to set the renewal token for.
330            renewal_token: Random unique string that will be used to renew the
331                user's account.
332
333        Raises:
334            StoreError: The provided token is already set for another user.
335        """
336        await self.db_pool.simple_update_one(
337            table="account_validity",
338            keyvalues={"user_id": user_id},
339            updatevalues={"renewal_token": renewal_token, "token_used_ts_ms": None},
340            desc="set_renewal_token_for_user",
341        )
342
343    async def get_user_from_renewal_token(
344        self, renewal_token: str
345    ) -> Tuple[str, int, Optional[int]]:
346        """Get a user ID and renewal status from a renewal token.
347
348        Args:
349            renewal_token: The renewal token to perform the lookup with.
350
351        Returns:
352            A tuple of containing the following values:
353                * The ID of a user to which the token belongs.
354                * An int representing the user's expiry timestamp as milliseconds since the
355                    epoch, or 0 if the token was invalid.
356                * An optional int representing the timestamp of when the user renewed their
357                    account timestamp as milliseconds since the epoch. None if the account
358                    has not been renewed using the current token yet.
359        """
360        ret_dict = await self.db_pool.simple_select_one(
361            table="account_validity",
362            keyvalues={"renewal_token": renewal_token},
363            retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"],
364            desc="get_user_from_renewal_token",
365        )
366
367        return (
368            ret_dict["user_id"],
369            ret_dict["expiration_ts_ms"],
370            ret_dict["token_used_ts_ms"],
371        )
372
373    async def get_renewal_token_for_user(self, user_id: str) -> str:
374        """Get the renewal token associated with a given user ID.
375
376        Args:
377            user_id: The user ID to lookup a token for.
378
379        Returns:
380            The renewal token associated with this user ID.
381        """
382        return await self.db_pool.simple_select_one_onecol(
383            table="account_validity",
384            keyvalues={"user_id": user_id},
385            retcol="renewal_token",
386            desc="get_renewal_token_for_user",
387        )
388
389    async def get_users_expiring_soon(self) -> List[Dict[str, Any]]:
390        """Selects users whose account will expire in the [now, now + renew_at] time
391        window (see configuration for account_validity for information on what renew_at
392        refers to).
393
394        Returns:
395            A list of dictionaries, each with a user ID and expiration time (in milliseconds).
396        """
397
398        def select_users_txn(txn, now_ms, renew_at):
399            sql = (
400                "SELECT user_id, expiration_ts_ms FROM account_validity"
401                " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?"
402            )
403            values = [False, now_ms, renew_at]
404            txn.execute(sql, values)
405            return self.db_pool.cursor_to_dict(txn)
406
407        return await self.db_pool.runInteraction(
408            "get_users_expiring_soon",
409            select_users_txn,
410            self._clock.time_msec(),
411            self.config.account_validity.account_validity_renew_at,
412        )
413
414    async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None:
415        """Sets or unsets the flag that indicates whether a renewal email has been sent
416        to the user (and the user hasn't renewed their account yet).
417
418        Args:
419            user_id: ID of the user to set/unset the flag for.
420            email_sent: Flag which indicates whether a renewal email has been sent
421                to this user.
422        """
423        await self.db_pool.simple_update_one(
424            table="account_validity",
425            keyvalues={"user_id": user_id},
426            updatevalues={"email_sent": email_sent},
427            desc="set_renewal_mail_status",
428        )
429
430    async def delete_account_validity_for_user(self, user_id: str) -> None:
431        """Deletes the entry for the given user in the account validity table, removing
432        their expiration date and renewal token.
433
434        Args:
435            user_id: ID of the user to remove from the account validity table.
436        """
437        await self.db_pool.simple_delete_one(
438            table="account_validity",
439            keyvalues={"user_id": user_id},
440            desc="delete_account_validity_for_user",
441        )
442
443    async def is_server_admin(self, user: UserID) -> bool:
444        """Determines if a user is an admin of this homeserver.
445
446        Args:
447            user: user ID of the user to test
448
449        Returns:
450            true iff the user is a server admin, false otherwise.
451        """
452        res = await self.db_pool.simple_select_one_onecol(
453            table="users",
454            keyvalues={"name": user.to_string()},
455            retcol="admin",
456            allow_none=True,
457            desc="is_server_admin",
458        )
459
460        return bool(res) if res else False
461
462    async def set_server_admin(self, user: UserID, admin: bool) -> None:
463        """Sets whether a user is an admin of this homeserver.
464
465        Args:
466            user: user ID of the user to test
467            admin: true iff the user is to be a server admin, false otherwise.
468        """
469
470        def set_server_admin_txn(txn):
471            self.db_pool.simple_update_one_txn(
472                txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0}
473            )
474            self._invalidate_cache_and_stream(
475                txn, self.get_user_by_id, (user.to_string(),)
476            )
477
478        await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
479
480    async def set_shadow_banned(self, user: UserID, shadow_banned: bool) -> None:
481        """Sets whether a user shadow-banned.
482
483        Args:
484            user: user ID of the user to test
485            shadow_banned: true iff the user is to be shadow-banned, false otherwise.
486        """
487
488        def set_shadow_banned_txn(txn: LoggingTransaction) -> None:
489            user_id = user.to_string()
490            self.db_pool.simple_update_one_txn(
491                txn,
492                table="users",
493                keyvalues={"name": user_id},
494                updatevalues={"shadow_banned": shadow_banned},
495            )
496            # In order for this to apply immediately, clear the cache for this user.
497            tokens = self.db_pool.simple_select_onecol_txn(
498                txn,
499                table="access_tokens",
500                keyvalues={"user_id": user_id},
501                retcol="token",
502            )
503            for token in tokens:
504                self._invalidate_cache_and_stream(
505                    txn, self.get_user_by_access_token, (token,)
506                )
507            self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
508
509        await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
510
511    async def set_user_type(self, user: UserID, user_type: Optional[UserTypes]) -> None:
512        """Sets the user type.
513
514        Args:
515            user: user ID of the user.
516            user_type: type of the user or None for a user without a type.
517        """
518
519        def set_user_type_txn(txn):
520            self.db_pool.simple_update_one_txn(
521                txn, "users", {"name": user.to_string()}, {"user_type": user_type}
522            )
523            self._invalidate_cache_and_stream(
524                txn, self.get_user_by_id, (user.to_string(),)
525            )
526
527        await self.db_pool.runInteraction("set_user_type", set_user_type_txn)
528
529    def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
530        sql = """
531            SELECT users.name as user_id,
532                users.is_guest,
533                users.shadow_banned,
534                access_tokens.id as token_id,
535                access_tokens.device_id,
536                access_tokens.valid_until_ms,
537                access_tokens.user_id as token_owner,
538                access_tokens.used as token_used
539            FROM users
540            INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id)
541            WHERE token = ?
542        """
543
544        txn.execute(sql, (token,))
545        rows = self.db_pool.cursor_to_dict(txn)
546
547        if rows:
548            row = rows[0]
549
550            # This field is nullable, ensure it comes out as a boolean
551            if row["token_used"] is None:
552                row["token_used"] = False
553
554            return TokenLookupResult(**row)
555
556        return None
557
558    @cached()
559    async def is_real_user(self, user_id: str) -> bool:
560        """Determines if the user is a real user, ie does not have a 'user_type'.
561
562        Args:
563            user_id: user id to test
564
565        Returns:
566            True if user 'user_type' is null or empty string
567        """
568        return await self.db_pool.runInteraction(
569            "is_real_user", self.is_real_user_txn, user_id
570        )
571
572    @cached()
573    async def is_support_user(self, user_id: str) -> bool:
574        """Determines if the user is of type UserTypes.SUPPORT
575
576        Args:
577            user_id: user id to test
578
579        Returns:
580            True if user is of type UserTypes.SUPPORT
581        """
582        return await self.db_pool.runInteraction(
583            "is_support_user", self.is_support_user_txn, user_id
584        )
585
586    def is_real_user_txn(self, txn, user_id):
587        res = self.db_pool.simple_select_one_onecol_txn(
588            txn=txn,
589            table="users",
590            keyvalues={"name": user_id},
591            retcol="user_type",
592            allow_none=True,
593        )
594        return res is None
595
596    def is_support_user_txn(self, txn, user_id):
597        res = self.db_pool.simple_select_one_onecol_txn(
598            txn=txn,
599            table="users",
600            keyvalues={"name": user_id},
601            retcol="user_type",
602            allow_none=True,
603        )
604        return True if res == UserTypes.SUPPORT else False
605
606    async def get_users_by_id_case_insensitive(self, user_id: str) -> Dict[str, str]:
607        """Gets users that match user_id case insensitively.
608
609        Returns:
610             A mapping of user_id -> password_hash.
611        """
612
613        def f(txn):
614            sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)"
615            txn.execute(sql, (user_id,))
616            return dict(txn)
617
618        return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
619
620    async def record_user_external_id(
621        self, auth_provider: str, external_id: str, user_id: str
622    ) -> None:
623        """Record a mapping from an external user id to a mxid
624
625        Args:
626            auth_provider: identifier for the remote auth provider
627            external_id: id on that system
628            user_id: complete mxid that it is mapped to
629        Raises:
630            ExternalIDReuseException if the new external_id could not be mapped.
631        """
632
633        try:
634            await self.db_pool.runInteraction(
635                "record_user_external_id",
636                self._record_user_external_id_txn,
637                auth_provider,
638                external_id,
639                user_id,
640            )
641        except self.database_engine.module.IntegrityError:
642            raise ExternalIDReuseException()
643
644    def _record_user_external_id_txn(
645        self,
646        txn: LoggingTransaction,
647        auth_provider: str,
648        external_id: str,
649        user_id: str,
650    ) -> None:
651
652        self.db_pool.simple_insert_txn(
653            txn,
654            table="user_external_ids",
655            values={
656                "auth_provider": auth_provider,
657                "external_id": external_id,
658                "user_id": user_id,
659            },
660        )
661
662    async def remove_user_external_id(
663        self, auth_provider: str, external_id: str, user_id: str
664    ) -> None:
665        """Remove a mapping from an external user id to a mxid
666        If the mapping is not found, this method does nothing.
667        Args:
668            auth_provider: identifier for the remote auth provider
669            external_id: id on that system
670            user_id: complete mxid that it is mapped to
671        """
672        await self.db_pool.simple_delete(
673            table="user_external_ids",
674            keyvalues={
675                "auth_provider": auth_provider,
676                "external_id": external_id,
677                "user_id": user_id,
678            },
679            desc="remove_user_external_id",
680        )
681
682    async def replace_user_external_id(
683        self,
684        record_external_ids: List[Tuple[str, str]],
685        user_id: str,
686    ) -> None:
687        """Replace mappings from external user ids to a mxid in a single transaction.
688        All mappings are deleted and the new ones are created.
689
690        Args:
691            record_external_ids:
692                List with tuple of auth_provider and external_id to record
693            user_id: complete mxid that it is mapped to
694        Raises:
695            ExternalIDReuseException if the new external_id could not be mapped.
696        """
697
698        def _remove_user_external_ids_txn(
699            txn: LoggingTransaction,
700            user_id: str,
701        ) -> None:
702            """Remove all mappings from external user ids to a mxid
703            If these mappings are not found, this method does nothing.
704
705            Args:
706                user_id: complete mxid that it is mapped to
707            """
708
709            self.db_pool.simple_delete_txn(
710                txn,
711                table="user_external_ids",
712                keyvalues={"user_id": user_id},
713            )
714
715        def _replace_user_external_id_txn(
716            txn: LoggingTransaction,
717        ):
718            _remove_user_external_ids_txn(txn, user_id)
719
720            for auth_provider, external_id in record_external_ids:
721                self._record_user_external_id_txn(
722                    txn,
723                    auth_provider,
724                    external_id,
725                    user_id,
726                )
727
728        try:
729            await self.db_pool.runInteraction(
730                "replace_user_external_id",
731                _replace_user_external_id_txn,
732            )
733        except self.database_engine.module.IntegrityError:
734            raise ExternalIDReuseException()
735
736    async def get_user_by_external_id(
737        self, auth_provider: str, external_id: str
738    ) -> Optional[str]:
739        """Look up a user by their external auth id
740
741        Args:
742            auth_provider: identifier for the remote auth provider
743            external_id: id on that system
744
745        Returns:
746            the mxid of the user, or None if they are not known
747        """
748        return await self.db_pool.simple_select_one_onecol(
749            table="user_external_ids",
750            keyvalues={"auth_provider": auth_provider, "external_id": external_id},
751            retcol="user_id",
752            allow_none=True,
753            desc="get_user_by_external_id",
754        )
755
756    async def get_external_ids_by_user(self, mxid: str) -> List[Tuple[str, str]]:
757        """Look up external ids for the given user
758
759        Args:
760            mxid: the MXID to be looked up
761
762        Returns:
763            Tuples of (auth_provider, external_id)
764        """
765        res = await self.db_pool.simple_select_list(
766            table="user_external_ids",
767            keyvalues={"user_id": mxid},
768            retcols=("auth_provider", "external_id"),
769            desc="get_external_ids_by_user",
770        )
771        return [(r["auth_provider"], r["external_id"]) for r in res]
772
773    async def count_all_users(self):
774        """Counts all users registered on the homeserver."""
775
776        def _count_users(txn):
777            txn.execute("SELECT COUNT(*) AS users FROM users")
778            rows = self.db_pool.cursor_to_dict(txn)
779            if rows:
780                return rows[0]["users"]
781            return 0
782
783        return await self.db_pool.runInteraction("count_users", _count_users)
784
785    async def count_daily_user_type(self) -> Dict[str, int]:
786        """
787        Counts 1) native non guest users
788               2) native guests users
789               3) bridged users
790        who registered on the homeserver in the past 24 hours
791        """
792
793        def _count_daily_user_type(txn):
794            yesterday = int(self._clock.time()) - (60 * 60 * 24)
795
796            sql = """
797                SELECT user_type, COUNT(*) AS count FROM (
798                    SELECT
799                    CASE
800                        WHEN is_guest=0 AND appservice_id IS NULL THEN 'native'
801                        WHEN is_guest=1 AND appservice_id IS NULL THEN 'guest'
802                        WHEN is_guest=0 AND appservice_id IS NOT NULL THEN 'bridged'
803                    END AS user_type
804                    FROM users
805                    WHERE creation_ts > ?
806                ) AS t GROUP BY user_type
807            """
808            results = {"native": 0, "guest": 0, "bridged": 0}
809            txn.execute(sql, (yesterday,))
810            for row in txn:
811                results[row[0]] = row[1]
812            return results
813
814        return await self.db_pool.runInteraction(
815            "count_daily_user_type", _count_daily_user_type
816        )
817
818    async def count_nonbridged_users(self):
819        def _count_users(txn):
820            txn.execute(
821                """
822                SELECT COUNT(*) FROM users
823                WHERE appservice_id IS NULL
824            """
825            )
826            (count,) = txn.fetchone()
827            return count
828
829        return await self.db_pool.runInteraction("count_users", _count_users)
830
831    async def count_real_users(self):
832        """Counts all users without a special user_type registered on the homeserver."""
833
834        def _count_users(txn):
835            txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
836            rows = self.db_pool.cursor_to_dict(txn)
837            if rows:
838                return rows[0]["users"]
839            return 0
840
841        return await self.db_pool.runInteraction("count_real_users", _count_users)
842
843    async def generate_user_id(self) -> str:
844        """Generate a suitable localpart for a guest user
845
846        Returns: a (hopefully) free localpart
847        """
848        next_id = await self.db_pool.runInteraction(
849            "generate_user_id", self._user_id_seq.get_next_id_txn
850        )
851
852        return str(next_id)
853
854    async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]:
855        """Returns user id from threepid
856
857        Args:
858            medium: threepid medium e.g. email
859            address: threepid address e.g. me@example.com. This must already be
860                in canonical form.
861
862        Returns:
863            The user ID or None if no user id/threepid mapping exists
864        """
865        user_id = await self.db_pool.runInteraction(
866            "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address
867        )
868        return user_id
869
870    def get_user_id_by_threepid_txn(
871        self, txn, medium: str, address: str
872    ) -> Optional[str]:
873        """Returns user id from threepid
874
875        Args:
876            txn (cursor):
877            medium: threepid medium e.g. email
878            address: threepid address e.g. me@example.com
879
880        Returns:
881            user id, or None if no user id/threepid mapping exists
882        """
883        ret = self.db_pool.simple_select_one_txn(
884            txn,
885            "user_threepids",
886            {"medium": medium, "address": address},
887            ["user_id"],
888            True,
889        )
890        if ret:
891            return ret["user_id"]
892        return None
893
894    async def user_add_threepid(
895        self,
896        user_id: str,
897        medium: str,
898        address: str,
899        validated_at: int,
900        added_at: int,
901    ) -> None:
902        await self.db_pool.simple_upsert(
903            "user_threepids",
904            {"medium": medium, "address": address},
905            {"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
906        )
907
908    async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]:
909        return await self.db_pool.simple_select_list(
910            "user_threepids",
911            {"user_id": user_id},
912            ["medium", "address", "validated_at", "added_at"],
913            "user_get_threepids",
914        )
915
916    async def user_delete_threepid(
917        self, user_id: str, medium: str, address: str
918    ) -> None:
919        await self.db_pool.simple_delete(
920            "user_threepids",
921            keyvalues={"user_id": user_id, "medium": medium, "address": address},
922            desc="user_delete_threepid",
923        )
924
925    async def user_delete_threepids(self, user_id: str) -> None:
926        """Delete all threepid this user has bound
927
928        Args:
929             user_id: The user id to delete all threepids of
930
931        """
932        await self.db_pool.simple_delete(
933            "user_threepids",
934            keyvalues={"user_id": user_id},
935            desc="user_delete_threepids",
936        )
937
938    async def add_user_bound_threepid(
939        self, user_id: str, medium: str, address: str, id_server: str
940    ):
941        """The server proxied a bind request to the given identity server on
942        behalf of the given user. We need to remember this in case the user
943        asks us to unbind the threepid.
944
945        Args:
946            user_id
947            medium
948            address
949            id_server
950        """
951        # We need to use an upsert, in case they user had already bound the
952        # threepid
953        await self.db_pool.simple_upsert(
954            table="user_threepid_id_server",
955            keyvalues={
956                "user_id": user_id,
957                "medium": medium,
958                "address": address,
959                "id_server": id_server,
960            },
961            values={},
962            insertion_values={},
963            desc="add_user_bound_threepid",
964        )
965
966    async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]:
967        """Get the threepids that a user has bound to an identity server through the homeserver
968        The homeserver remembers where binds to an identity server occurred. Using this
969        method can retrieve those threepids.
970
971        Args:
972            user_id: The ID of the user to retrieve threepids for
973
974        Returns:
975            List of dictionaries containing the following keys:
976                medium (str): The medium of the threepid (e.g "email")
977                address (str): The address of the threepid (e.g "bob@example.com")
978        """
979        return await self.db_pool.simple_select_list(
980            table="user_threepid_id_server",
981            keyvalues={"user_id": user_id},
982            retcols=["medium", "address"],
983            desc="user_get_bound_threepids",
984        )
985
986    async def remove_user_bound_threepid(
987        self, user_id: str, medium: str, address: str, id_server: str
988    ) -> None:
989        """The server proxied an unbind request to the given identity server on
990        behalf of the given user, so we remove the mapping of threepid to
991        identity server.
992
993        Args:
994            user_id
995            medium
996            address
997            id_server
998        """
999        await self.db_pool.simple_delete(
1000            table="user_threepid_id_server",
1001            keyvalues={
1002                "user_id": user_id,
1003                "medium": medium,
1004                "address": address,
1005                "id_server": id_server,
1006            },
1007            desc="remove_user_bound_threepid",
1008        )
1009
1010    async def get_id_servers_user_bound(
1011        self, user_id: str, medium: str, address: str
1012    ) -> List[str]:
1013        """Get the list of identity servers that the server proxied bind
1014        requests to for given user and threepid
1015
1016        Args:
1017            user_id: The user to query for identity servers.
1018            medium: The medium to query for identity servers.
1019            address: The address to query for identity servers.
1020
1021        Returns:
1022            A list of identity servers
1023        """
1024        return await self.db_pool.simple_select_onecol(
1025            table="user_threepid_id_server",
1026            keyvalues={"user_id": user_id, "medium": medium, "address": address},
1027            retcol="id_server",
1028            desc="get_id_servers_user_bound",
1029        )
1030
1031    @cached()
1032    async def get_user_deactivated_status(self, user_id: str) -> bool:
1033        """Retrieve the value for the `deactivated` property for the provided user.
1034
1035        Args:
1036            user_id: The ID of the user to retrieve the status for.
1037
1038        Returns:
1039            True if the user was deactivated, false if the user is still active.
1040        """
1041
1042        res = await self.db_pool.simple_select_one_onecol(
1043            table="users",
1044            keyvalues={"name": user_id},
1045            retcol="deactivated",
1046            desc="get_user_deactivated_status",
1047        )
1048
1049        # Convert the integer into a boolean.
1050        return res == 1
1051
1052    async def get_threepid_validation_session(
1053        self,
1054        medium: Optional[str],
1055        client_secret: str,
1056        address: Optional[str] = None,
1057        sid: Optional[str] = None,
1058        validated: Optional[bool] = True,
1059    ) -> Optional[Dict[str, Any]]:
1060        """Gets a session_id and last_send_attempt (if available) for a
1061        combination of validation metadata
1062
1063        Args:
1064            medium: The medium of the 3PID
1065            client_secret: A unique string provided by the client to help identify this
1066                validation attempt
1067            address: The address of the 3PID
1068            sid: The ID of the validation session
1069            validated: Whether sessions should be filtered by
1070                whether they have been validated already or not. None to
1071                perform no filtering
1072
1073        Returns:
1074            A dict containing the following:
1075                * address - address of the 3pid
1076                * medium - medium of the 3pid
1077                * client_secret - a secret provided by the client for this validation session
1078                * session_id - ID of the validation session
1079                * send_attempt - a number serving to dedupe send attempts for this session
1080                * validated_at - timestamp of when this session was validated if so
1081
1082                Otherwise None if a validation session is not found
1083        """
1084        if not client_secret:
1085            raise SynapseError(
1086                400, "Missing parameter: client_secret", errcode=Codes.MISSING_PARAM
1087            )
1088
1089        keyvalues = {"client_secret": client_secret}
1090        if medium:
1091            keyvalues["medium"] = medium
1092        if address:
1093            keyvalues["address"] = address
1094        if sid:
1095            keyvalues["session_id"] = sid
1096
1097        assert address or sid
1098
1099        def get_threepid_validation_session_txn(txn):
1100            sql = """
1101                SELECT address, session_id, medium, client_secret,
1102                last_send_attempt, validated_at
1103                FROM threepid_validation_session WHERE %s
1104                """ % (
1105                " AND ".join("%s = ?" % k for k in keyvalues.keys()),
1106            )
1107
1108            if validated is not None:
1109                sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL")
1110
1111            sql += " LIMIT 1"
1112
1113            txn.execute(sql, list(keyvalues.values()))
1114            rows = self.db_pool.cursor_to_dict(txn)
1115            if not rows:
1116                return None
1117
1118            return rows[0]
1119
1120        return await self.db_pool.runInteraction(
1121            "get_threepid_validation_session", get_threepid_validation_session_txn
1122        )
1123
1124    async def delete_threepid_session(self, session_id: str) -> None:
1125        """Removes a threepid validation session from the database. This can
1126        be done after validation has been performed and whatever action was
1127        waiting on it has been carried out
1128
1129        Args:
1130            session_id: The ID of the session to delete
1131        """
1132
1133        def delete_threepid_session_txn(txn):
1134            self.db_pool.simple_delete_txn(
1135                txn,
1136                table="threepid_validation_token",
1137                keyvalues={"session_id": session_id},
1138            )
1139            self.db_pool.simple_delete_txn(
1140                txn,
1141                table="threepid_validation_session",
1142                keyvalues={"session_id": session_id},
1143            )
1144
1145        await self.db_pool.runInteraction(
1146            "delete_threepid_session", delete_threepid_session_txn
1147        )
1148
1149    @wrap_as_background_process("cull_expired_threepid_validation_tokens")
1150    async def cull_expired_threepid_validation_tokens(self) -> None:
1151        """Remove threepid validation tokens with expiry dates that have passed"""
1152
1153        def cull_expired_threepid_validation_tokens_txn(txn, ts):
1154            sql = """
1155            DELETE FROM threepid_validation_token WHERE
1156            expires < ?
1157            """
1158            txn.execute(sql, (ts,))
1159
1160        await self.db_pool.runInteraction(
1161            "cull_expired_threepid_validation_tokens",
1162            cull_expired_threepid_validation_tokens_txn,
1163            self._clock.time_msec(),
1164        )
1165
1166    @wrap_as_background_process("account_validity_set_expiration_dates")
1167    async def _set_expiration_date_when_missing(self):
1168        """
1169        Retrieves the list of registered users that don't have an expiration date, and
1170        adds an expiration date for each of them.
1171        """
1172
1173        def select_users_with_no_expiration_date_txn(txn):
1174            """Retrieves the list of registered users with no expiration date from the
1175            database, filtering out deactivated users.
1176            """
1177            sql = (
1178                "SELECT users.name FROM users"
1179                " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
1180                " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
1181            )
1182            txn.execute(sql, [])
1183
1184            res = self.db_pool.cursor_to_dict(txn)
1185            if res:
1186                for user in res:
1187                    self.set_expiration_date_for_user_txn(
1188                        txn, user["name"], use_delta=True
1189                    )
1190
1191        await self.db_pool.runInteraction(
1192            "get_users_with_no_expiration_date",
1193            select_users_with_no_expiration_date_txn,
1194        )
1195
1196    def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
1197        """Sets an expiration date to the account with the given user ID.
1198
1199        Args:
1200             user_id (str): User ID to set an expiration date for.
1201             use_delta (bool): If set to False, the expiration date for the user will be
1202                now + validity period. If set to True, this expiration date will be a
1203                random value in the [now + period - d ; now + period] range, d being a
1204                delta equal to 10% of the validity period.
1205        """
1206        now_ms = self._clock.time_msec()
1207        assert self._account_validity_period is not None
1208        expiration_ts = now_ms + self._account_validity_period
1209
1210        if use_delta:
1211            assert self._account_validity_startup_job_max_delta is not None
1212            expiration_ts = random.randrange(
1213                int(expiration_ts - self._account_validity_startup_job_max_delta),
1214                expiration_ts,
1215            )
1216
1217        self.db_pool.simple_upsert_txn(
1218            txn,
1219            "account_validity",
1220            keyvalues={"user_id": user_id},
1221            values={"expiration_ts_ms": expiration_ts, "email_sent": False},
1222        )
1223
1224    async def get_user_pending_deactivation(self) -> Optional[str]:
1225        """
1226        Gets one user from the table of users waiting to be parted from all the rooms
1227        they're in.
1228        """
1229        return await self.db_pool.simple_select_one_onecol(
1230            "users_pending_deactivation",
1231            keyvalues={},
1232            retcol="user_id",
1233            allow_none=True,
1234            desc="get_users_pending_deactivation",
1235        )
1236
1237    async def del_user_pending_deactivation(self, user_id: str) -> None:
1238        """
1239        Removes the given user to the table of users who need to be parted from all the
1240        rooms they're in, effectively marking that user as fully deactivated.
1241        """
1242        # XXX: This should be simple_delete_one but we failed to put a unique index on
1243        # the table, so somehow duplicate entries have ended up in it.
1244        await self.db_pool.simple_delete(
1245            "users_pending_deactivation",
1246            keyvalues={"user_id": user_id},
1247            desc="del_user_pending_deactivation",
1248        )
1249
1250    async def get_access_token_last_validated(self, token_id: int) -> int:
1251        """Retrieves the time (in milliseconds) of the last validation of an access token.
1252
1253        Args:
1254            token_id: The ID of the access token to update.
1255        Raises:
1256            StoreError if the access token was not found.
1257
1258        Returns:
1259            The last validation time.
1260        """
1261        result = await self.db_pool.simple_select_one_onecol(
1262            "access_tokens", {"id": token_id}, "last_validated"
1263        )
1264
1265        # If this token has not been validated (since starting to track this),
1266        # return 0 instead of None.
1267        return result or 0
1268
1269    async def update_access_token_last_validated(self, token_id: int) -> None:
1270        """Updates the last time an access token was validated.
1271
1272        Args:
1273            token_id: The ID of the access token to update.
1274        Raises:
1275            StoreError if there was a problem updating this.
1276        """
1277        now = self._clock.time_msec()
1278
1279        await self.db_pool.simple_update_one(
1280            "access_tokens",
1281            {"id": token_id},
1282            {"last_validated": now},
1283            desc="update_access_token_last_validated",
1284        )
1285
1286    async def registration_token_is_valid(self, token: str) -> bool:
1287        """Checks if a token can be used to authenticate a registration.
1288
1289        Args:
1290            token: The registration token to be checked
1291        Returns:
1292            True if the token is valid, False otherwise.
1293        """
1294        res = await self.db_pool.simple_select_one(
1295            "registration_tokens",
1296            keyvalues={"token": token},
1297            retcols=["uses_allowed", "pending", "completed", "expiry_time"],
1298            allow_none=True,
1299        )
1300
1301        # Check if the token exists
1302        if res is None:
1303            return False
1304
1305        # Check if the token has expired
1306        now = self._clock.time_msec()
1307        if res["expiry_time"] and res["expiry_time"] < now:
1308            return False
1309
1310        # Check if the token has been used up
1311        if (
1312            res["uses_allowed"]
1313            and res["pending"] + res["completed"] >= res["uses_allowed"]
1314        ):
1315            return False
1316
1317        # Otherwise, the token is valid
1318        return True
1319
1320    async def set_registration_token_pending(self, token: str) -> None:
1321        """Increment the pending registrations counter for a token.
1322
1323        Args:
1324            token: The registration token pending use
1325        """
1326
1327        def _set_registration_token_pending_txn(txn):
1328            pending = self.db_pool.simple_select_one_onecol_txn(
1329                txn,
1330                "registration_tokens",
1331                keyvalues={"token": token},
1332                retcol="pending",
1333            )
1334            self.db_pool.simple_update_one_txn(
1335                txn,
1336                "registration_tokens",
1337                keyvalues={"token": token},
1338                updatevalues={"pending": pending + 1},
1339            )
1340
1341        return await self.db_pool.runInteraction(
1342            "set_registration_token_pending", _set_registration_token_pending_txn
1343        )
1344
1345    async def use_registration_token(self, token: str) -> None:
1346        """Complete a use of the given registration token.
1347
1348        The `pending` counter will be decremented, and the `completed`
1349        counter will be incremented.
1350
1351        Args:
1352            token: The registration token to be 'used'
1353        """
1354
1355        def _use_registration_token_txn(txn):
1356            # Normally, res is Optional[Dict[str, Any]].
1357            # Override type because the return type is only optional if
1358            # allow_none is True, and we don't want mypy throwing errors
1359            # about None not being indexable.
1360            res = cast(
1361                Dict[str, Any],
1362                self.db_pool.simple_select_one_txn(
1363                    txn,
1364                    "registration_tokens",
1365                    keyvalues={"token": token},
1366                    retcols=["pending", "completed"],
1367                ),
1368            )
1369
1370            # Decrement pending and increment completed
1371            self.db_pool.simple_update_one_txn(
1372                txn,
1373                "registration_tokens",
1374                keyvalues={"token": token},
1375                updatevalues={
1376                    "completed": res["completed"] + 1,
1377                    "pending": res["pending"] - 1,
1378                },
1379            )
1380
1381        return await self.db_pool.runInteraction(
1382            "use_registration_token", _use_registration_token_txn
1383        )
1384
1385    async def get_registration_tokens(
1386        self, valid: Optional[bool] = None
1387    ) -> List[Dict[str, Any]]:
1388        """List all registration tokens. Used by the admin API.
1389
1390        Args:
1391            valid: If True, only valid tokens are returned.
1392              If False, only invalid tokens are returned.
1393              Default is None: return all tokens regardless of validity.
1394
1395        Returns:
1396            A list of dicts, each containing details of a token.
1397        """
1398
1399        def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]):
1400            if valid is None:
1401                # Return all tokens regardless of validity
1402                txn.execute("SELECT * FROM registration_tokens")
1403
1404            elif valid:
1405                # Select valid tokens only
1406                sql = (
1407                    "SELECT * FROM registration_tokens WHERE "
1408                    "(uses_allowed > pending + completed OR uses_allowed IS NULL) "
1409                    "AND (expiry_time > ? OR expiry_time IS NULL)"
1410                )
1411                txn.execute(sql, [now])
1412
1413            else:
1414                # Select invalid tokens only
1415                sql = (
1416                    "SELECT * FROM registration_tokens WHERE "
1417                    "uses_allowed <= pending + completed OR expiry_time <= ?"
1418                )
1419                txn.execute(sql, [now])
1420
1421            return self.db_pool.cursor_to_dict(txn)
1422
1423        return await self.db_pool.runInteraction(
1424            "select_registration_tokens",
1425            select_registration_tokens_txn,
1426            self._clock.time_msec(),
1427            valid,
1428        )
1429
1430    async def get_one_registration_token(self, token: str) -> Optional[Dict[str, Any]]:
1431        """Get info about the given registration token. Used by the admin API.
1432
1433        Args:
1434            token: The token to retrieve information about.
1435
1436        Returns:
1437            A dict, or None if token doesn't exist.
1438        """
1439        return await self.db_pool.simple_select_one(
1440            "registration_tokens",
1441            keyvalues={"token": token},
1442            retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"],
1443            allow_none=True,
1444            desc="get_one_registration_token",
1445        )
1446
1447    async def generate_registration_token(
1448        self, length: int, chars: str
1449    ) -> Optional[str]:
1450        """Generate a random registration token. Used by the admin API.
1451
1452        Args:
1453            length: The length of the token to generate.
1454            chars: A string of the characters allowed in the generated token.
1455
1456        Returns:
1457            The generated token.
1458
1459        Raises:
1460            SynapseError if a unique registration token could still not be
1461            generated after a few tries.
1462        """
1463        # Make a few attempts at generating a unique token of the required
1464        # length before failing.
1465        for _i in range(3):
1466            # Generate token
1467            token = "".join(random.choices(chars, k=length))
1468
1469            # Check if the token already exists
1470            existing_token = await self.db_pool.simple_select_one_onecol(
1471                "registration_tokens",
1472                keyvalues={"token": token},
1473                retcol="token",
1474                allow_none=True,
1475                desc="check_if_registration_token_exists",
1476            )
1477
1478            if existing_token is None:
1479                # The generated token doesn't exist yet, return it
1480                return token
1481
1482        raise SynapseError(
1483            500,
1484            "Unable to generate a unique registration token. Try again with a greater length",
1485            Codes.UNKNOWN,
1486        )
1487
1488    async def create_registration_token(
1489        self, token: str, uses_allowed: Optional[int], expiry_time: Optional[int]
1490    ) -> bool:
1491        """Create a new registration token. Used by the admin API.
1492
1493        Args:
1494            token: The token to create.
1495            uses_allowed: The number of times the token can be used to complete
1496              a registration before it becomes invalid. A value of None indicates
1497              unlimited uses.
1498            expiry_time: The latest time the token is valid. Given as the
1499              number of milliseconds since 1970-01-01 00:00:00 UTC. A value of
1500              None indicates that the token does not expire.
1501
1502        Returns:
1503            Whether the row was inserted or not.
1504        """
1505
1506        def _create_registration_token_txn(txn):
1507            row = self.db_pool.simple_select_one_txn(
1508                txn,
1509                "registration_tokens",
1510                keyvalues={"token": token},
1511                retcols=["token"],
1512                allow_none=True,
1513            )
1514
1515            if row is not None:
1516                # Token already exists
1517                return False
1518
1519            self.db_pool.simple_insert_txn(
1520                txn,
1521                "registration_tokens",
1522                values={
1523                    "token": token,
1524                    "uses_allowed": uses_allowed,
1525                    "pending": 0,
1526                    "completed": 0,
1527                    "expiry_time": expiry_time,
1528                },
1529            )
1530
1531            return True
1532
1533        return await self.db_pool.runInteraction(
1534            "create_registration_token", _create_registration_token_txn
1535        )
1536
1537    async def update_registration_token(
1538        self, token: str, updatevalues: Dict[str, Optional[int]]
1539    ) -> Optional[Dict[str, Any]]:
1540        """Update a registration token. Used by the admin API.
1541
1542        Args:
1543            token: The token to update.
1544            updatevalues: A dict with the fields to update. E.g.:
1545              `{"uses_allowed": 3}` to update just uses_allowed, or
1546              `{"uses_allowed": 3, "expiry_time": None}` to update both.
1547              This is passed straight to simple_update_one.
1548
1549        Returns:
1550            A dict with all info about the token, or None if token doesn't exist.
1551        """
1552
1553        def _update_registration_token_txn(txn):
1554            try:
1555                self.db_pool.simple_update_one_txn(
1556                    txn,
1557                    "registration_tokens",
1558                    keyvalues={"token": token},
1559                    updatevalues=updatevalues,
1560                )
1561            except StoreError:
1562                # Update failed because token does not exist
1563                return None
1564
1565            # Get all info about the token so it can be sent in the response
1566            return self.db_pool.simple_select_one_txn(
1567                txn,
1568                "registration_tokens",
1569                keyvalues={"token": token},
1570                retcols=[
1571                    "token",
1572                    "uses_allowed",
1573                    "pending",
1574                    "completed",
1575                    "expiry_time",
1576                ],
1577                allow_none=True,
1578            )
1579
1580        return await self.db_pool.runInteraction(
1581            "update_registration_token", _update_registration_token_txn
1582        )
1583
1584    async def delete_registration_token(self, token: str) -> bool:
1585        """Delete a registration token. Used by the admin API.
1586
1587        Args:
1588            token: The token to delete.
1589
1590        Returns:
1591            Whether the token was successfully deleted or not.
1592        """
1593        try:
1594            await self.db_pool.simple_delete_one(
1595                "registration_tokens",
1596                keyvalues={"token": token},
1597                desc="delete_registration_token",
1598            )
1599        except StoreError:
1600            # Deletion failed because token does not exist
1601            return False
1602
1603        return True
1604
1605    @cached()
1606    async def mark_access_token_as_used(self, token_id: int) -> None:
1607        """
1608        Mark the access token as used, which invalidates the refresh token used
1609        to obtain it.
1610
1611        Because get_user_by_access_token is cached, this function might be
1612        called multiple times for the same token, effectively doing unnecessary
1613        SQL updates. Because updating the `used` field only goes one way (from
1614        False to True) it is safe to cache this function as well to avoid this
1615        issue.
1616
1617        Args:
1618            token_id: The ID of the access token to update.
1619        Raises:
1620            StoreError if there was a problem updating this.
1621        """
1622        await self.db_pool.simple_update_one(
1623            "access_tokens",
1624            {"id": token_id},
1625            {"used": True},
1626            desc="mark_access_token_as_used",
1627        )
1628
1629    async def lookup_refresh_token(
1630        self, token: str
1631    ) -> Optional[RefreshTokenLookupResult]:
1632        """Lookup a refresh token with hints about its validity."""
1633
1634        def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]:
1635            txn.execute(
1636                """
1637                SELECT
1638                    rt.id token_id,
1639                    rt.user_id,
1640                    rt.device_id,
1641                    rt.next_token_id,
1642                    (nrt.next_token_id IS NOT NULL) AS has_next_refresh_token_been_refreshed,
1643                    at.used AS has_next_access_token_been_used,
1644                    rt.expiry_ts,
1645                    rt.ultimate_session_expiry_ts
1646                FROM refresh_tokens rt
1647                LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id
1648                LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id
1649                WHERE rt.token = ?
1650            """,
1651                (token,),
1652            )
1653            row = txn.fetchone()
1654
1655            if row is None:
1656                return None
1657
1658            return RefreshTokenLookupResult(
1659                token_id=row[0],
1660                user_id=row[1],
1661                device_id=row[2],
1662                next_token_id=row[3],
1663                has_next_refresh_token_been_refreshed=row[4],
1664                # This column is nullable, ensure it's a boolean
1665                has_next_access_token_been_used=(row[5] or False),
1666                expiry_ts=row[6],
1667                ultimate_session_expiry_ts=row[7],
1668            )
1669
1670        return await self.db_pool.runInteraction(
1671            "lookup_refresh_token", _lookup_refresh_token_txn
1672        )
1673
1674    async def replace_refresh_token(self, token_id: int, next_token_id: int) -> None:
1675        """
1676        Set the successor of a refresh token, removing the existing successor
1677        if any.
1678
1679        Args:
1680            token_id: ID of the refresh token to update.
1681            next_token_id: ID of its successor.
1682        """
1683
1684        def _replace_refresh_token_txn(txn) -> None:
1685            # First check if there was an existing refresh token
1686            old_next_token_id = self.db_pool.simple_select_one_onecol_txn(
1687                txn,
1688                "refresh_tokens",
1689                {"id": token_id},
1690                "next_token_id",
1691                allow_none=True,
1692            )
1693
1694            self.db_pool.simple_update_one_txn(
1695                txn,
1696                "refresh_tokens",
1697                {"id": token_id},
1698                {"next_token_id": next_token_id},
1699            )
1700
1701            # Delete the old "next" token if it exists. This should cascade and
1702            # delete the associated access_token
1703            if old_next_token_id is not None:
1704                self.db_pool.simple_delete_one_txn(
1705                    txn,
1706                    "refresh_tokens",
1707                    {"id": old_next_token_id},
1708                )
1709
1710        await self.db_pool.runInteraction(
1711            "replace_refresh_token", _replace_refresh_token_txn
1712        )
1713
1714
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
    """Extends RegistrationWorkerStore with the background updates for the
    registration-related tables, plus helpers for the `deactivated` and
    `is_guest` user flags."""

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self._clock = hs.get_clock()
        self.config = hs.config

        self.db_pool.updates.register_background_index_update(
            "access_tokens_device_index",
            index_name="access_tokens_device_id",
            table="access_tokens",
            columns=["user_id", "device_id"],
        )

        self.db_pool.updates.register_background_index_update(
            "users_creation_ts",
            index_name="users_creation_ts",
            table="users",
            columns=["creation_ts"],
        )

        # This background update dates from an earlier refresh-token
        # implementation that was dropped. Refresh tokens have since been
        # reintroduced (see the refresh_tokens queries in this module), but
        # this particular index update is obsolete. Some deployments may still
        # have it queued, so register it as a no-op to clear it.
        self.db_pool.updates.register_noop_background_update(
            "refresh_tokens_device_index"
        )

        self.db_pool.updates.register_background_update_handler(
            "users_set_deactivated_flag", self._background_update_set_deactivated_flag
        )

        self.db_pool.updates.register_noop_background_update(
            "user_threepids_grandfather"
        )

        self.db_pool.updates.register_background_index_update(
            "user_external_ids_user_id_idx",
            index_name="user_external_ids_user_id_idx",
            table="user_external_ids",
            columns=["user_id"],
            unique=False,
        )

    async def _background_update_set_deactivated_flag(
        self, progress: Dict[str, Any], batch_size: int
    ) -> int:
        """Find users that look deactivated and set their 'deactivated' flag to 1.

        A user is treated as deactivated if it has no password hash, is not an
        appservice user, is not a guest, and has neither access tokens nor
        threepids.

        Args:
            progress: Background-update progress. The "user_id" key holds the
                last user name processed.
            batch_size: Maximum number of users to examine in this run.

        Returns:
            The number of users examined in this batch.
        """

        last_user = progress.get("user_id", "")

        def _background_update_set_deactivated_flag_txn(
            txn: LoggingTransaction,
        ) -> Tuple[bool, int]:
            # The two LEFT JOINs can multiply rows, which inflates the counts.
            # Only whether each count is zero matters below, so that is harmless.
            txn.execute(
                """
                SELECT
                    users.name,
                    COUNT(access_tokens.token) AS count_tokens,
                    COUNT(user_threepids.address) AS count_threepids
                FROM users
                    LEFT JOIN access_tokens ON (access_tokens.user_id = users.name)
                    LEFT JOIN user_threepids ON (user_threepids.user_id = users.name)
                WHERE (users.password_hash IS NULL OR users.password_hash = '')
                AND (users.appservice_id IS NULL OR users.appservice_id = '')
                AND users.is_guest = 0
                AND users.name > ?
                GROUP BY users.name
                ORDER BY users.name ASC
                LIMIT ?;
                """,
                (last_user, batch_size),
            )

            rows = self.db_pool.cursor_to_dict(txn)

            if not rows:
                return True, 0

            rows_processed_nb = 0

            for user in rows:
                if not user["count_tokens"] and not user["count_threepids"]:
                    self.set_user_deactivated_status_txn(txn, user["name"], True)
                    rows_processed_nb += 1

            logger.info("Marked %d rows as deactivated", rows_processed_nb)

            # Resume after the last user in this batch next time.
            self.db_pool.updates._background_update_progress_txn(
                txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]}
            )

            # A short batch means we have reached the end of the table.
            if batch_size > len(rows):
                return True, len(rows)
            else:
                return False, len(rows)

        end, nb_processed = await self.db_pool.runInteraction(
            "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
        )

        if end:
            await self.db_pool.updates._end_background_update(
                "users_set_deactivated_flag"
            )

        return nb_processed

    async def set_user_deactivated_status(
        self, user_id: str, deactivated: bool
    ) -> None:
        """Set the `deactivated` property for the provided user to the provided value.

        Args:
            user_id: The ID of the user to set the status for.
            deactivated: The value to set for `deactivated`.
        """

        await self.db_pool.runInteraction(
            "set_user_deactivated_status",
            self.set_user_deactivated_status_txn,
            user_id,
            deactivated,
        )

    def set_user_deactivated_status_txn(
        self, txn: LoggingTransaction, user_id: str, deactivated: bool
    ) -> None:
        """Transaction body for set_user_deactivated_status.

        Updates the `deactivated` column (stored as 1/0) and invalidates the
        caches that depend on the user row.
        """
        self.db_pool.simple_update_one_txn(
            txn=txn,
            table="users",
            keyvalues={"name": user_id},
            updatevalues={"deactivated": 1 if deactivated else 0},
        )
        self._invalidate_cache_and_stream(
            txn, self.get_user_deactivated_status, (user_id,)
        )
        self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
        txn.call_after(self.is_guest.invalidate, (user_id,))

    @cached()
    async def is_guest(self, user_id: str) -> bool:
        """Check whether a user is a guest.

        Args:
            user_id: The full user ID to check.

        Returns:
            True if the user exists and is flagged as a guest, otherwise False
            (including when no such user exists).
        """
        res = await self.db_pool.simple_select_one_onecol(
            table="users",
            keyvalues={"name": user_id},
            retcol="is_guest",
            allow_none=True,
            desc="is_guest",
        )

        # The column holds an integer flag (it is written as 1/0), and the
        # value is None for an unknown user, so normalise to a real bool.
        return bool(res)
1867
1868
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
    """Registration store with the write operations: registering users,
    creating and deleting access/refresh tokens, updating per-user fields,
    and managing threepid validation sessions."""

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        # When set, validate_threepid_session does not raise for an unknown
        # session ID (see the comment in that method).
        self._ignore_unknown_session_error = (
            hs.config.server.request_token_inhibit_3pid_errors
        )

        # Generators for the primary keys of the two token tables.
        self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
        self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")

    async def add_access_token_to_user(
        self,
        user_id: str,
        token: str,
        device_id: Optional[str],
        valid_until_ms: Optional[int],
        puppets_user_id: Optional[str] = None,
        refresh_token_id: Optional[int] = None,
    ) -> int:
        """Adds an access token for the given user.

        Args:
            user_id: The user ID.
            token: The new access token to add.
            device_id: ID of the device to associate with the access token.
            valid_until_ms: when the token is valid until. None for no expiry.
            puppets_user_id: Optional user ID stored in the token's
                `puppets_user_id` column (presumably the user puppeting
                `user_id` — confirm).
            refresh_token_id: ID of the refresh token generated alongside this
                access token.
        Raises:
            StoreError if there was a problem adding this.
        Returns:
            The token ID
        """
        next_id = self._access_tokens_id_gen.get_next()
        now = self._clock.time_msec()

        await self.db_pool.simple_insert(
            "access_tokens",
            {
                "id": next_id,
                "user_id": user_id,
                "token": token,
                "device_id": device_id,
                "valid_until_ms": valid_until_ms,
                "puppets_user_id": puppets_user_id,
                "last_validated": now,
                "refresh_token_id": refresh_token_id,
                "used": False,
            },
            desc="add_access_token_to_user",
        )

        return next_id

    async def add_refresh_token_to_user(
        self,
        user_id: str,
        token: str,
        device_id: Optional[str],
        expiry_ts: Optional[int],
        ultimate_session_expiry_ts: Optional[int],
    ) -> int:
        """Adds a refresh token for the given user.

        Args:
            user_id: The user ID.
            token: The new refresh token to add.
            device_id: ID of the device to associate with the refresh token.
            expiry_ts (milliseconds since the epoch): Time after which the
                refresh token cannot be used.
                If None, the refresh token never expires until it has been used.
            ultimate_session_expiry_ts (milliseconds since the epoch):
                Time at which the session will end and can not be extended any
                further.
                If None, the session can be refreshed indefinitely.
        Raises:
            StoreError if there was a problem adding this.
        Returns:
            The token ID
        """
        next_id = self._refresh_tokens_id_gen.get_next()

        await self.db_pool.simple_insert(
            "refresh_tokens",
            {
                "id": next_id,
                "user_id": user_id,
                "device_id": device_id,
                "token": token,
                "next_token_id": None,
                "expiry_ts": expiry_ts,
                "ultimate_session_expiry_ts": ultimate_session_expiry_ts,
            },
            desc="add_refresh_token_to_user",
        )

        return next_id

    def _set_device_for_access_token_txn(
        self, txn: LoggingTransaction, token: str, device_id: str
    ) -> str:
        """Transaction body for set_device_for_access_token.

        Returns the token's previous device ID, then invalidates the
        get_user_by_access_token cache for the token.
        """
        old_device_id = self.db_pool.simple_select_one_onecol_txn(
            txn, "access_tokens", {"token": token}, "device_id"
        )

        self.db_pool.simple_update_txn(
            txn, "access_tokens", {"token": token}, {"device_id": device_id}
        )

        self._invalidate_cache_and_stream(txn, self.get_user_by_access_token, (token,))

        return old_device_id

    async def set_device_for_access_token(self, token: str, device_id: str) -> str:
        """Sets the device ID associated with an access token.

        Args:
            token: The access token to modify.
            device_id: The new device ID.
        Returns:
            The old device ID associated with the access token.
        """

        return await self.db_pool.runInteraction(
            "set_device_for_access_token",
            self._set_device_for_access_token_txn,
            token,
            device_id,
        )

    async def register_user(
        self,
        user_id: str,
        password_hash: Optional[str] = None,
        was_guest: bool = False,
        make_guest: bool = False,
        appservice_id: Optional[str] = None,
        create_profile_with_displayname: Optional[str] = None,
        admin: bool = False,
        user_type: Optional[str] = None,
        shadow_banned: bool = False,
    ) -> None:
        """Attempts to register an account.

        Args:
            user_id: The desired user ID to register.
            password_hash: Optional. The password hash for this user.
            was_guest: Whether this is a guest account being upgraded to a
                non-guest account.
            make_guest: True if the the new user should be guest, false to add a
                regular user account.
            appservice_id: The ID of the appservice registering the user.
            create_profile_with_displayname: Optionally create a profile for
                the user, setting their displayname to the given value
            admin: is an admin user?
            user_type: type of user. One of the values from api.constants.UserTypes,
                or None for a normal user.
            shadow_banned: Whether the user is shadow-banned, i.e. they may be
                told their requests succeeded but we ignore them.

        Raises:
            StoreError if the user_id could not be registered.
        """
        await self.db_pool.runInteraction(
            "register_user",
            self._register_user,
            user_id,
            password_hash,
            was_guest,
            make_guest,
            appservice_id,
            create_profile_with_displayname,
            admin,
            user_type,
            shadow_banned,
        )

    def _register_user(
        self,
        txn: LoggingTransaction,
        user_id: str,
        password_hash: Optional[str],
        was_guest: bool,
        make_guest: bool,
        appservice_id: Optional[str],
        create_profile_with_displayname: Optional[str],
        admin: bool,
        user_type: Optional[str],
        shadow_banned: bool,
    ) -> None:
        """Transaction body for register_user; see that method for arguments.

        Either upgrades an existing guest row (was_guest) or inserts a new user
        row. It then optionally sets an account-validity expiry, a profile
        displayname and a stats row, and finally invalidates the caches that
        depend on the user row.

        Raises:
            StoreError(400, USER_IN_USE) if inserting the user violates an
            integrity constraint.
        """
        user_id_obj = UserID.from_string(user_id)

        # creation_ts / upgrade_ts are stored in seconds.
        now = int(self._clock.time())

        try:
            if was_guest:
                # Ensure that the guest user actually exists
                # ``allow_none=False`` makes this raise an exception
                # if the row isn't in the database.
                self.db_pool.simple_select_one_txn(
                    txn,
                    "users",
                    keyvalues={"name": user_id, "is_guest": 1},
                    retcols=("name",),
                    allow_none=False,
                )

                self.db_pool.simple_update_one_txn(
                    txn,
                    "users",
                    keyvalues={"name": user_id, "is_guest": 1},
                    updatevalues={
                        "password_hash": password_hash,
                        "upgrade_ts": now,
                        "is_guest": 1 if make_guest else 0,
                        "appservice_id": appservice_id,
                        "admin": 1 if admin else 0,
                        "user_type": user_type,
                        "shadow_banned": shadow_banned,
                    },
                )
            else:
                self.db_pool.simple_insert_txn(
                    txn,
                    "users",
                    values={
                        "name": user_id,
                        "password_hash": password_hash,
                        "creation_ts": now,
                        "is_guest": 1 if make_guest else 0,
                        "appservice_id": appservice_id,
                        "admin": 1 if admin else 0,
                        "user_type": user_type,
                        "shadow_banned": shadow_banned,
                    },
                )

        except self.database_engine.module.IntegrityError:
            raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)

        if self._account_validity_enabled:
            self.set_expiration_date_for_user_txn(txn, user_id)

        if create_profile_with_displayname:
            # set a default displayname serverside to avoid ugly race
            # between auto-joins and clients trying to set displaynames
            #
            # *obviously* the 'profiles' table uses localpart for user_id
            # while everything else uses the full mxid.
            txn.execute(
                "INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
                (user_id_obj.localpart, create_profile_with_displayname),
            )

        if self.hs.config.stats.stats_enabled:
            # we create a new completed user statistics row

            # we don't strictly need current_token since this user really can't
            # have any state deltas before now (as it is a new user), but still,
            # we include it for completeness.
            current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn)
            self._update_stats_delta_txn(
                txn, now, "user", user_id, {}, complete_with_stream_id=current_token
            )

        self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
        # The is_guest flag may have just changed (guest upgrade), or a
        # "not a guest" result may have been cached before this user existed,
        # so drop any cached value.
        txn.call_after(self.is_guest.invalidate, (user_id,))

    async def user_set_password_hash(
        self, user_id: str, password_hash: Optional[str]
    ) -> None:
        """Set (or clear, if None) the password hash for a user.

        NB. Apart from get_user_by_id, this does *not* evict any caches (e.g.
            cached access-token lookups), because the one use for this
            removes most of the entries subsequently anyway so it would be
            pointless. Use flush_user separately.

        Args:
            user_id: full mxid of the user to update
            password_hash: the new password hash, or None
        """

        def user_set_password_hash_txn(txn: LoggingTransaction) -> None:
            self.db_pool.simple_update_one_txn(
                txn, "users", {"name": user_id}, {"password_hash": password_hash}
            )
            self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))

        await self.db_pool.runInteraction(
            "user_set_password_hash", user_set_password_hash_txn
        )

    async def user_set_consent_version(
        self, user_id: str, consent_version: str
    ) -> None:
        """Updates the user table to record privacy policy consent

        Args:
            user_id: full mxid of the user to update
            consent_version: version of the policy the user has consented to

        Raises:
            StoreError(404) if user not found
        """

        def f(txn: LoggingTransaction) -> None:
            self.db_pool.simple_update_one_txn(
                txn,
                table="users",
                keyvalues={"name": user_id},
                updatevalues={"consent_version": consent_version},
            )
            self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))

        await self.db_pool.runInteraction("user_set_consent_version", f)

    async def user_set_consent_server_notice_sent(
        self, user_id: str, consent_version: str
    ) -> None:
        """Updates the user table to record that we have sent the user a server
        notice about privacy policy consent

        Args:
            user_id: full mxid of the user to update
            consent_version: version of the policy we have notified the user about

        Raises:
            StoreError(404) if user not found
        """

        def f(txn: LoggingTransaction) -> None:
            self.db_pool.simple_update_one_txn(
                txn,
                table="users",
                keyvalues={"name": user_id},
                updatevalues={"consent_server_notice_sent": consent_version},
            )
            self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))

        await self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)

    async def user_delete_access_tokens(
        self,
        user_id: str,
        except_token_id: Optional[int] = None,
        device_id: Optional[str] = None,
    ) -> List[Tuple[str, int, Optional[str]]]:
        """
        Invalidate access and refresh tokens belonging to a user

        Args:
            user_id: ID of user the tokens belong to
            except_token_id: access_tokens ID which should *not* be deleted.
                Note that this is not applied to refresh tokens: all matching
                refresh tokens are deleted regardless.
            device_id: ID of device the tokens are associated with.
                If None, tokens associated with any device (or no device) will
                be deleted
        Returns:
            A tuple of (token, token id, device id) for each of the deleted tokens
        """

        def f(txn: LoggingTransaction) -> List[Tuple[str, int, Optional[str]]]:
            keyvalues = {"user_id": user_id}
            if device_id is not None:
                keyvalues["device_id"] = device_id

            items = keyvalues.items()
            where_clause = " AND ".join(k + " = ?" for k, _ in items)
            values: List[Union[str, int]] = [v for _, v in items]
            # Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat
            # is the `except_token_id` param that is tricky to get right, so for now we're just using the same where
            # clause and values before we handle that. This seems to be only used in the "set password" handler.
            refresh_where_clause = where_clause
            refresh_values = values.copy()
            # Compare against None explicitly: the truthiness of an integer ID
            # is not the same as whether one was supplied.
            if except_token_id is not None:
                # TODO: support that for refresh tokens
                where_clause += " AND id != ?"
                values.append(except_token_id)

            txn.execute(
                "SELECT token, id, device_id FROM access_tokens WHERE %s"
                % where_clause,
                values,
            )
            tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]

            for token, _, _ in tokens_and_devices:
                self._invalidate_cache_and_stream(
                    txn, self.get_user_by_access_token, (token,)
                )

            txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)

            txn.execute(
                "DELETE FROM refresh_tokens WHERE %s" % refresh_where_clause,
                refresh_values,
            )

            return tokens_and_devices

        return await self.db_pool.runInteraction("user_delete_access_tokens", f)

    async def delete_access_token(self, access_token: str) -> None:
        """Delete a single access token and invalidate its cached lookup.

        Args:
            access_token: The access token string to delete.
        """

        def f(txn: LoggingTransaction) -> None:
            self.db_pool.simple_delete_one_txn(
                txn, table="access_tokens", keyvalues={"token": access_token}
            )

            self._invalidate_cache_and_stream(
                txn, self.get_user_by_access_token, (access_token,)
            )

        await self.db_pool.runInteraction("delete_access_token", f)

    async def delete_refresh_token(self, refresh_token: str) -> None:
        """Delete a single refresh token.

        Args:
            refresh_token: The refresh token string to delete.
        """

        def f(txn: LoggingTransaction) -> None:
            self.db_pool.simple_delete_one_txn(
                txn, table="refresh_tokens", keyvalues={"token": refresh_token}
            )

        await self.db_pool.runInteraction("delete_refresh_token", f)

    async def add_user_pending_deactivation(self, user_id: str) -> None:
        """
        Adds a user to the table of users who need to be parted from all the rooms they're
        in
        """
        await self.db_pool.simple_insert(
            "users_pending_deactivation",
            values={"user_id": user_id},
            desc="add_user_pending_deactivation",
        )

    async def validate_threepid_session(
        self, session_id: str, client_secret: str, token: str, current_ts: int
    ) -> Optional[str]:
        """Attempt to validate a threepid session using a token

        Args:
            session_id: The id of a validation session
            client_secret: A unique string provided by the client to help identify
                this validation attempt
            token: A validation token
            current_ts: The current unix time in milliseconds. Used for checking
                token expiry status

        Raises:
            ThreepidValidationError: if a matching validation token was not found or has
                expired

        Returns:
            A str representing a link to redirect the user to if there is one.
        """

        # Insert everything into a transaction in order to run atomically
        def validate_threepid_session_txn(txn: LoggingTransaction) -> Optional[str]:
            row = self.db_pool.simple_select_one_txn(
                txn,
                table="threepid_validation_session",
                keyvalues={"session_id": session_id},
                retcols=["client_secret", "validated_at"],
                allow_none=True,
            )

            if not row:
                if self._ignore_unknown_session_error:
                    # If we need to inhibit the error caused by an incorrect session ID,
                    # use None as placeholder values for the client secret and the
                    # validation timestamp.
                    # It shouldn't be an issue because they're both only checked after
                    # the token check, which should fail. And if it doesn't for some
                    # reason, the next check is on the client secret, which is NOT NULL,
                    # so we don't have to worry about the client secret matching by
                    # accident.
                    row = {"client_secret": None, "validated_at": None}
                else:
                    raise ThreepidValidationError("Unknown session_id")

            retrieved_client_secret = row["client_secret"]
            validated_at = row["validated_at"]

            row = self.db_pool.simple_select_one_txn(
                txn,
                table="threepid_validation_token",
                keyvalues={"session_id": session_id, "token": token},
                retcols=["expires", "next_link"],
                allow_none=True,
            )

            if not row:
                raise ThreepidValidationError(
                    "Validation token not found or has expired"
                )
            expires = row["expires"]
            next_link = row["next_link"]

            if retrieved_client_secret != client_secret:
                raise ThreepidValidationError(
                    "This client_secret does not match the provided session_id"
                )

            # If the session is already validated, no need to revalidate
            if validated_at:
                return next_link

            if expires <= current_ts:
                raise ThreepidValidationError(
                    "This token has expired. Please request a new one"
                )

            # Looks good. Validate the session
            self.db_pool.simple_update_txn(
                txn,
                table="threepid_validation_session",
                keyvalues={"session_id": session_id},
                updatevalues={"validated_at": self._clock.time_msec()},
            )

            return next_link

        # Return next_link if it exists
        return await self.db_pool.runInteraction(
            "validate_threepid_session_txn", validate_threepid_session_txn
        )

    async def start_or_continue_validation_session(
        self,
        medium: str,
        address: str,
        session_id: str,
        client_secret: str,
        send_attempt: int,
        next_link: Optional[str],
        token: str,
        token_expires: int,
    ) -> None:
        """Creates a new threepid validation session if it does not already
        exist and associates a new validation token with it

        Args:
            medium: The medium of the 3PID
            address: The address of the 3PID
            session_id: The id of this validation session
            client_secret: A unique string provided by the client to help
                identify this validation attempt
            send_attempt: The latest send_attempt on this session
            next_link: The link to redirect the user to upon successful validation
            token: The validation token
            token_expires: The timestamp for which after the token will no
                longer be valid
        """

        def start_or_continue_validation_session_txn(txn: LoggingTransaction) -> None:
            # Create or update a validation session. Only last_send_attempt is
            # updated on an existing session; the other columns are set on insert.
            self.db_pool.simple_upsert_txn(
                txn,
                table="threepid_validation_session",
                keyvalues={"session_id": session_id},
                values={"last_send_attempt": send_attempt},
                insertion_values={
                    "medium": medium,
                    "address": address,
                    "client_secret": client_secret,
                },
            )

            # Create a new validation token with this session ID
            self.db_pool.simple_insert_txn(
                txn,
                table="threepid_validation_token",
                values={
                    "session_id": session_id,
                    "token": token,
                    "next_link": next_link,
                    "expires": token_expires,
                },
            )

        await self.db_pool.runInteraction(
            "start_or_continue_validation_session",
            start_or_continue_validation_session_txn,
        )
2449
2450
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
    """
    Return the largest integer localpart among existing user IDs.

    Generated user IDs have purely numeric localparts (e.g. ``@123:server``).
    This scans the candidate user names and returns the greatest such number.
    Names whose localpart is not all digits are ignored. Returns 0 when no
    numeric user ID exists.
    """
    # Bound the query between '@0' and '@a' rather than pulling every user out
    # of the table; the regex below does the precise filtering.
    cur.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'")

    numeric_localpart = re.compile(r"^@(\d+):")

    matches = (numeric_localpart.search(name) for (name,) in cur)
    return max((int(m.group(1)) for m in matches if m), default=0)
2472