1# Copyright 2014-2016 OpenMarket Ltd 2# Copyright 2017-2018 New Vector Ltd 3# Copyright 2019,2020 The Matrix.org Foundation C.I.C. 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16import logging 17import random 18import re 19from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast 20 21import attr 22 23from synapse.api.constants import UserTypes 24from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError 25from synapse.metrics.background_process_metrics import wrap_as_background_process 26from synapse.storage.database import ( 27 DatabasePool, 28 LoggingDatabaseConnection, 29 LoggingTransaction, 30) 31from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore 32from synapse.storage.databases.main.stats import StatsStore 33from synapse.storage.types import Cursor 34from synapse.storage.util.id_generators import IdGenerator 35from synapse.storage.util.sequence import build_sequence_generator 36from synapse.types import UserID, UserInfo 37from synapse.util.caches.descriptors import cached 38 39if TYPE_CHECKING: 40 from synapse.server import HomeServer 41 42THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 43 44logger = logging.getLogger(__name__) 45 46 47class ExternalIDReuseException(Exception): 48 """Exception if writing an external id for a user fails, 49 because this external id is given to an other user.""" 50 51 pass 52 53 54@attr.s(frozen=True, slots=True) 55class TokenLookupResult: 56 """Result of looking up an access token. 57 58 Attributes: 59 user_id: The user that this token authenticates as 60 is_guest 61 shadow_banned 62 token_id: The ID of the access token looked up 63 device_id: The device associated with the token, if any. 64 valid_until_ms: The timestamp the token expires, if any. 65 token_owner: The "owner" of the token. This is either the same as the 66 user, or a server admin who is logged in as the user. 67 token_used: True if this token was used at least once in a request. 68 This field can be out of date since `get_user_by_access_token` is 69 cached. 70 """ 71 72 user_id = attr.ib(type=str) 73 is_guest = attr.ib(type=bool, default=False) 74 shadow_banned = attr.ib(type=bool, default=False) 75 token_id = attr.ib(type=Optional[int], default=None) 76 device_id = attr.ib(type=Optional[str], default=None) 77 valid_until_ms = attr.ib(type=Optional[int], default=None) 78 token_owner = attr.ib(type=str) 79 token_used = attr.ib(type=bool, default=False) 80 81 # Make the token owner default to the user ID, which is the common case. 82 @token_owner.default 83 def _default_token_owner(self): 84 return self.user_id 85 86 87@attr.s(auto_attribs=True, frozen=True, slots=True) 88class RefreshTokenLookupResult: 89 """Result of looking up a refresh token.""" 90 91 user_id: str 92 """The user this token belongs to.""" 93 94 device_id: str 95 """The device associated with this refresh token.""" 96 97 token_id: int 98 """The ID of this refresh token.""" 99 100 next_token_id: Optional[int] 101 """The ID of the refresh token which replaced this one.""" 102 103 has_next_refresh_token_been_refreshed: bool 104 """True if the next refresh token was used for another refresh.""" 105 106 has_next_access_token_been_used: bool 107 """True if the next access token was already used at least once.""" 108 109 expiry_ts: Optional[int] 110 """The time at which the refresh token expires and can not be used. 111 If None, the refresh token doesn't expire.""" 112 113 ultimate_session_expiry_ts: Optional[int] 114 """The time at which the session comes to an end and can no longer be 115 refreshed. 116 If None, the session can be refreshed indefinitely.""" 117 118 119class RegistrationWorkerStore(CacheInvalidationWorkerStore): 120 def __init__( 121 self, 122 database: DatabasePool, 123 db_conn: LoggingDatabaseConnection, 124 hs: "HomeServer", 125 ): 126 super().__init__(database, db_conn, hs) 127 128 self.config = hs.config 129 130 # Note: we don't check this sequence for consistency as we'd have to 131 # call `find_max_generated_user_id_localpart` each time, which is 132 # expensive if there are many entries. 133 self._user_id_seq = build_sequence_generator( 134 db_conn, 135 database.engine, 136 find_max_generated_user_id_localpart, 137 "user_id_seq", 138 table=None, 139 id_column=None, 140 ) 141 142 self._account_validity_enabled = ( 143 hs.config.account_validity.account_validity_enabled 144 ) 145 self._account_validity_period = None 146 self._account_validity_startup_job_max_delta = None 147 if self._account_validity_enabled: 148 self._account_validity_period = ( 149 hs.config.account_validity.account_validity_period 150 ) 151 self._account_validity_startup_job_max_delta = ( 152 hs.config.account_validity.account_validity_startup_job_max_delta 153 ) 154 155 if hs.config.worker.run_background_tasks: 156 self._clock.call_later( 157 0.0, 158 self._set_expiration_date_when_missing, 159 ) 160 161 # Create a background job for culling expired 3PID validity tokens 162 if hs.config.worker.run_background_tasks: 163 self._clock.looping_call( 164 self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS 165 ) 166 167 @cached() 168 async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: 169 """Deprecated: use get_userinfo_by_id instead""" 170 return await self.db_pool.simple_select_one( 171 table="users", 172 keyvalues={"name": user_id}, 173 retcols=[ 174 "name", 175 "password_hash", 176 "is_guest", 177 "admin", 178 "consent_version", 179 "consent_server_notice_sent", 180 "appservice_id", 181 "creation_ts", 182 "user_type", 183 "deactivated", 184 "shadow_banned", 185 ], 186 allow_none=True, 187 desc="get_user_by_id", 188 ) 189 190 async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: 191 """Get a UserInfo object for a user by user ID. 192 193 Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed, 194 this method should be cached. 195 196 Args: 197 user_id: The user to fetch user info for. 198 Returns: 199 `UserInfo` object if user found, otherwise `None`. 200 """ 201 user_data = await self.get_user_by_id(user_id) 202 if not user_data: 203 return None 204 return UserInfo( 205 appservice_id=user_data["appservice_id"], 206 consent_server_notice_sent=user_data["consent_server_notice_sent"], 207 consent_version=user_data["consent_version"], 208 creation_ts=user_data["creation_ts"], 209 is_admin=bool(user_data["admin"]), 210 is_deactivated=bool(user_data["deactivated"]), 211 is_guest=bool(user_data["is_guest"]), 212 is_shadow_banned=bool(user_data["shadow_banned"]), 213 user_id=UserID.from_string(user_data["name"]), 214 user_type=user_data["user_type"], 215 ) 216 217 async def is_trial_user(self, user_id: str) -> bool: 218 """Checks if user is in the "trial" period, i.e. within the first 219 N days of registration defined by `mau_trial_days` config 220 221 Args: 222 user_id: The user to check for trial status. 223 """ 224 225 info = await self.get_user_by_id(user_id) 226 if not info: 227 return False 228 229 now = self._clock.time_msec() 230 trial_duration_ms = self.config.server.mau_trial_days * 24 * 60 * 60 * 1000 231 is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms 232 return is_trial 233 234 @cached() 235 async def get_user_by_access_token(self, token: str) -> Optional[TokenLookupResult]: 236 """Get a user from the given access token. 237 238 Args: 239 token: The access token of a user. 240 Returns: 241 None, if the token did not match, otherwise a `TokenLookupResult` 242 """ 243 return await self.db_pool.runInteraction( 244 "get_user_by_access_token", self._query_for_auth, token 245 ) 246 247 @cached() 248 async def get_expiration_ts_for_user(self, user_id: str) -> Optional[int]: 249 """Get the expiration timestamp for the account bearing a given user ID. 250 251 Args: 252 user_id: The ID of the user. 253 Returns: 254 None, if the account has no expiration timestamp, otherwise int 255 representation of the timestamp (as a number of milliseconds since epoch). 256 """ 257 return await self.db_pool.simple_select_one_onecol( 258 table="account_validity", 259 keyvalues={"user_id": user_id}, 260 retcol="expiration_ts_ms", 261 allow_none=True, 262 desc="get_expiration_ts_for_user", 263 ) 264 265 async def is_account_expired(self, user_id: str, current_ts: int) -> bool: 266 """ 267 Returns whether an user account is expired. 268 269 Args: 270 user_id: The user's ID 271 current_ts: The current timestamp 272 273 Returns: 274 Whether the user account has expired 275 """ 276 expiration_ts = await self.get_expiration_ts_for_user(user_id) 277 return expiration_ts is not None and current_ts >= expiration_ts 278 279 async def set_account_validity_for_user( 280 self, 281 user_id: str, 282 expiration_ts: int, 283 email_sent: bool, 284 renewal_token: Optional[str] = None, 285 token_used_ts: Optional[int] = None, 286 ) -> None: 287 """Updates the account validity properties of the given account, with the 288 given values. 289 290 Args: 291 user_id: ID of the account to update properties for. 292 expiration_ts: New expiration date, as a timestamp in milliseconds 293 since epoch. 294 email_sent: True means a renewal email has been sent for this account 295 and there's no need to send another one for the current validity 296 period. 297 renewal_token: Renewal token the user can use to extend the validity 298 of their account. Defaults to no token. 299 token_used_ts: A timestamp of when the current token was used to renew 300 the account. 301 """ 302 303 def set_account_validity_for_user_txn(txn): 304 self.db_pool.simple_update_txn( 305 txn=txn, 306 table="account_validity", 307 keyvalues={"user_id": user_id}, 308 updatevalues={ 309 "expiration_ts_ms": expiration_ts, 310 "email_sent": email_sent, 311 "renewal_token": renewal_token, 312 "token_used_ts_ms": token_used_ts, 313 }, 314 ) 315 self._invalidate_cache_and_stream( 316 txn, self.get_expiration_ts_for_user, (user_id,) 317 ) 318 319 await self.db_pool.runInteraction( 320 "set_account_validity_for_user", set_account_validity_for_user_txn 321 ) 322 323 async def set_renewal_token_for_user( 324 self, user_id: str, renewal_token: str 325 ) -> None: 326 """Defines a renewal token for a given user, and clears the token_used timestamp. 327 328 Args: 329 user_id: ID of the user to set the renewal token for. 330 renewal_token: Random unique string that will be used to renew the 331 user's account. 332 333 Raises: 334 StoreError: The provided token is already set for another user. 335 """ 336 await self.db_pool.simple_update_one( 337 table="account_validity", 338 keyvalues={"user_id": user_id}, 339 updatevalues={"renewal_token": renewal_token, "token_used_ts_ms": None}, 340 desc="set_renewal_token_for_user", 341 ) 342 343 async def get_user_from_renewal_token( 344 self, renewal_token: str 345 ) -> Tuple[str, int, Optional[int]]: 346 """Get a user ID and renewal status from a renewal token. 347 348 Args: 349 renewal_token: The renewal token to perform the lookup with. 350 351 Returns: 352 A tuple of containing the following values: 353 * The ID of a user to which the token belongs. 354 * An int representing the user's expiry timestamp as milliseconds since the 355 epoch, or 0 if the token was invalid. 356 * An optional int representing the timestamp of when the user renewed their 357 account timestamp as milliseconds since the epoch. None if the account 358 has not been renewed using the current token yet. 359 """ 360 ret_dict = await self.db_pool.simple_select_one( 361 table="account_validity", 362 keyvalues={"renewal_token": renewal_token}, 363 retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"], 364 desc="get_user_from_renewal_token", 365 ) 366 367 return ( 368 ret_dict["user_id"], 369 ret_dict["expiration_ts_ms"], 370 ret_dict["token_used_ts_ms"], 371 ) 372 373 async def get_renewal_token_for_user(self, user_id: str) -> str: 374 """Get the renewal token associated with a given user ID. 375 376 Args: 377 user_id: The user ID to lookup a token for. 378 379 Returns: 380 The renewal token associated with this user ID. 381 """ 382 return await self.db_pool.simple_select_one_onecol( 383 table="account_validity", 384 keyvalues={"user_id": user_id}, 385 retcol="renewal_token", 386 desc="get_renewal_token_for_user", 387 ) 388 389 async def get_users_expiring_soon(self) -> List[Dict[str, Any]]: 390 """Selects users whose account will expire in the [now, now + renew_at] time 391 window (see configuration for account_validity for information on what renew_at 392 refers to). 393 394 Returns: 395 A list of dictionaries, each with a user ID and expiration time (in milliseconds). 396 """ 397 398 def select_users_txn(txn, now_ms, renew_at): 399 sql = ( 400 "SELECT user_id, expiration_ts_ms FROM account_validity" 401 " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?" 402 ) 403 values = [False, now_ms, renew_at] 404 txn.execute(sql, values) 405 return self.db_pool.cursor_to_dict(txn) 406 407 return await self.db_pool.runInteraction( 408 "get_users_expiring_soon", 409 select_users_txn, 410 self._clock.time_msec(), 411 self.config.account_validity.account_validity_renew_at, 412 ) 413 414 async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None: 415 """Sets or unsets the flag that indicates whether a renewal email has been sent 416 to the user (and the user hasn't renewed their account yet). 417 418 Args: 419 user_id: ID of the user to set/unset the flag for. 420 email_sent: Flag which indicates whether a renewal email has been sent 421 to this user. 422 """ 423 await self.db_pool.simple_update_one( 424 table="account_validity", 425 keyvalues={"user_id": user_id}, 426 updatevalues={"email_sent": email_sent}, 427 desc="set_renewal_mail_status", 428 ) 429 430 async def delete_account_validity_for_user(self, user_id: str) -> None: 431 """Deletes the entry for the given user in the account validity table, removing 432 their expiration date and renewal token. 433 434 Args: 435 user_id: ID of the user to remove from the account validity table. 436 """ 437 await self.db_pool.simple_delete_one( 438 table="account_validity", 439 keyvalues={"user_id": user_id}, 440 desc="delete_account_validity_for_user", 441 ) 442 443 async def is_server_admin(self, user: UserID) -> bool: 444 """Determines if a user is an admin of this homeserver. 445 446 Args: 447 user: user ID of the user to test 448 449 Returns: 450 true iff the user is a server admin, false otherwise. 451 """ 452 res = await self.db_pool.simple_select_one_onecol( 453 table="users", 454 keyvalues={"name": user.to_string()}, 455 retcol="admin", 456 allow_none=True, 457 desc="is_server_admin", 458 ) 459 460 return bool(res) if res else False 461 462 async def set_server_admin(self, user: UserID, admin: bool) -> None: 463 """Sets whether a user is an admin of this homeserver. 464 465 Args: 466 user: user ID of the user to test 467 admin: true iff the user is to be a server admin, false otherwise. 468 """ 469 470 def set_server_admin_txn(txn): 471 self.db_pool.simple_update_one_txn( 472 txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0} 473 ) 474 self._invalidate_cache_and_stream( 475 txn, self.get_user_by_id, (user.to_string(),) 476 ) 477 478 await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) 479 480 async def set_shadow_banned(self, user: UserID, shadow_banned: bool) -> None: 481 """Sets whether a user shadow-banned. 482 483 Args: 484 user: user ID of the user to test 485 shadow_banned: true iff the user is to be shadow-banned, false otherwise. 486 """ 487 488 def set_shadow_banned_txn(txn: LoggingTransaction) -> None: 489 user_id = user.to_string() 490 self.db_pool.simple_update_one_txn( 491 txn, 492 table="users", 493 keyvalues={"name": user_id}, 494 updatevalues={"shadow_banned": shadow_banned}, 495 ) 496 # In order for this to apply immediately, clear the cache for this user. 497 tokens = self.db_pool.simple_select_onecol_txn( 498 txn, 499 table="access_tokens", 500 keyvalues={"user_id": user_id}, 501 retcol="token", 502 ) 503 for token in tokens: 504 self._invalidate_cache_and_stream( 505 txn, self.get_user_by_access_token, (token,) 506 ) 507 self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) 508 509 await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn) 510 511 async def set_user_type(self, user: UserID, user_type: Optional[UserTypes]) -> None: 512 """Sets the user type. 513 514 Args: 515 user: user ID of the user. 516 user_type: type of the user or None for a user without a type. 517 """ 518 519 def set_user_type_txn(txn): 520 self.db_pool.simple_update_one_txn( 521 txn, "users", {"name": user.to_string()}, {"user_type": user_type} 522 ) 523 self._invalidate_cache_and_stream( 524 txn, self.get_user_by_id, (user.to_string(),) 525 ) 526 527 await self.db_pool.runInteraction("set_user_type", set_user_type_txn) 528 529 def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]: 530 sql = """ 531 SELECT users.name as user_id, 532 users.is_guest, 533 users.shadow_banned, 534 access_tokens.id as token_id, 535 access_tokens.device_id, 536 access_tokens.valid_until_ms, 537 access_tokens.user_id as token_owner, 538 access_tokens.used as token_used 539 FROM users 540 INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id) 541 WHERE token = ? 542 """ 543 544 txn.execute(sql, (token,)) 545 rows = self.db_pool.cursor_to_dict(txn) 546 547 if rows: 548 row = rows[0] 549 550 # This field is nullable, ensure it comes out as a boolean 551 if row["token_used"] is None: 552 row["token_used"] = False 553 554 return TokenLookupResult(**row) 555 556 return None 557 558 @cached() 559 async def is_real_user(self, user_id: str) -> bool: 560 """Determines if the user is a real user, ie does not have a 'user_type'. 561 562 Args: 563 user_id: user id to test 564 565 Returns: 566 True if user 'user_type' is null or empty string 567 """ 568 return await self.db_pool.runInteraction( 569 "is_real_user", self.is_real_user_txn, user_id 570 ) 571 572 @cached() 573 async def is_support_user(self, user_id: str) -> bool: 574 """Determines if the user is of type UserTypes.SUPPORT 575 576 Args: 577 user_id: user id to test 578 579 Returns: 580 True if user is of type UserTypes.SUPPORT 581 """ 582 return await self.db_pool.runInteraction( 583 "is_support_user", self.is_support_user_txn, user_id 584 ) 585 586 def is_real_user_txn(self, txn, user_id): 587 res = self.db_pool.simple_select_one_onecol_txn( 588 txn=txn, 589 table="users", 590 keyvalues={"name": user_id}, 591 retcol="user_type", 592 allow_none=True, 593 ) 594 return res is None 595 596 def is_support_user_txn(self, txn, user_id): 597 res = self.db_pool.simple_select_one_onecol_txn( 598 txn=txn, 599 table="users", 600 keyvalues={"name": user_id}, 601 retcol="user_type", 602 allow_none=True, 603 ) 604 return True if res == UserTypes.SUPPORT else False 605 606 async def get_users_by_id_case_insensitive(self, user_id: str) -> Dict[str, str]: 607 """Gets users that match user_id case insensitively. 608 609 Returns: 610 A mapping of user_id -> password_hash. 611 """ 612 613 def f(txn): 614 sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)" 615 txn.execute(sql, (user_id,)) 616 return dict(txn) 617 618 return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f) 619 620 async def record_user_external_id( 621 self, auth_provider: str, external_id: str, user_id: str 622 ) -> None: 623 """Record a mapping from an external user id to a mxid 624 625 Args: 626 auth_provider: identifier for the remote auth provider 627 external_id: id on that system 628 user_id: complete mxid that it is mapped to 629 Raises: 630 ExternalIDReuseException if the new external_id could not be mapped. 631 """ 632 633 try: 634 await self.db_pool.runInteraction( 635 "record_user_external_id", 636 self._record_user_external_id_txn, 637 auth_provider, 638 external_id, 639 user_id, 640 ) 641 except self.database_engine.module.IntegrityError: 642 raise ExternalIDReuseException() 643 644 def _record_user_external_id_txn( 645 self, 646 txn: LoggingTransaction, 647 auth_provider: str, 648 external_id: str, 649 user_id: str, 650 ) -> None: 651 652 self.db_pool.simple_insert_txn( 653 txn, 654 table="user_external_ids", 655 values={ 656 "auth_provider": auth_provider, 657 "external_id": external_id, 658 "user_id": user_id, 659 }, 660 ) 661 662 async def remove_user_external_id( 663 self, auth_provider: str, external_id: str, user_id: str 664 ) -> None: 665 """Remove a mapping from an external user id to a mxid 666 If the mapping is not found, this method does nothing. 667 Args: 668 auth_provider: identifier for the remote auth provider 669 external_id: id on that system 670 user_id: complete mxid that it is mapped to 671 """ 672 await self.db_pool.simple_delete( 673 table="user_external_ids", 674 keyvalues={ 675 "auth_provider": auth_provider, 676 "external_id": external_id, 677 "user_id": user_id, 678 }, 679 desc="remove_user_external_id", 680 ) 681 682 async def replace_user_external_id( 683 self, 684 record_external_ids: List[Tuple[str, str]], 685 user_id: str, 686 ) -> None: 687 """Replace mappings from external user ids to a mxid in a single transaction. 688 All mappings are deleted and the new ones are created. 689 690 Args: 691 record_external_ids: 692 List with tuple of auth_provider and external_id to record 693 user_id: complete mxid that it is mapped to 694 Raises: 695 ExternalIDReuseException if the new external_id could not be mapped. 696 """ 697 698 def _remove_user_external_ids_txn( 699 txn: LoggingTransaction, 700 user_id: str, 701 ) -> None: 702 """Remove all mappings from external user ids to a mxid 703 If these mappings are not found, this method does nothing. 704 705 Args: 706 user_id: complete mxid that it is mapped to 707 """ 708 709 self.db_pool.simple_delete_txn( 710 txn, 711 table="user_external_ids", 712 keyvalues={"user_id": user_id}, 713 ) 714 715 def _replace_user_external_id_txn( 716 txn: LoggingTransaction, 717 ): 718 _remove_user_external_ids_txn(txn, user_id) 719 720 for auth_provider, external_id in record_external_ids: 721 self._record_user_external_id_txn( 722 txn, 723 auth_provider, 724 external_id, 725 user_id, 726 ) 727 728 try: 729 await self.db_pool.runInteraction( 730 "replace_user_external_id", 731 _replace_user_external_id_txn, 732 ) 733 except self.database_engine.module.IntegrityError: 734 raise ExternalIDReuseException() 735 736 async def get_user_by_external_id( 737 self, auth_provider: str, external_id: str 738 ) -> Optional[str]: 739 """Look up a user by their external auth id 740 741 Args: 742 auth_provider: identifier for the remote auth provider 743 external_id: id on that system 744 745 Returns: 746 the mxid of the user, or None if they are not known 747 """ 748 return await self.db_pool.simple_select_one_onecol( 749 table="user_external_ids", 750 keyvalues={"auth_provider": auth_provider, "external_id": external_id}, 751 retcol="user_id", 752 allow_none=True, 753 desc="get_user_by_external_id", 754 ) 755 756 async def get_external_ids_by_user(self, mxid: str) -> List[Tuple[str, str]]: 757 """Look up external ids for the given user 758 759 Args: 760 mxid: the MXID to be looked up 761 762 Returns: 763 Tuples of (auth_provider, external_id) 764 """ 765 res = await self.db_pool.simple_select_list( 766 table="user_external_ids", 767 keyvalues={"user_id": mxid}, 768 retcols=("auth_provider", "external_id"), 769 desc="get_external_ids_by_user", 770 ) 771 return [(r["auth_provider"], r["external_id"]) for r in res] 772 773 async def count_all_users(self): 774 """Counts all users registered on the homeserver.""" 775 776 def _count_users(txn): 777 txn.execute("SELECT COUNT(*) AS users FROM users") 778 rows = self.db_pool.cursor_to_dict(txn) 779 if rows: 780 return rows[0]["users"] 781 return 0 782 783 return await self.db_pool.runInteraction("count_users", _count_users) 784 785 async def count_daily_user_type(self) -> Dict[str, int]: 786 """ 787 Counts 1) native non guest users 788 2) native guests users 789 3) bridged users 790 who registered on the homeserver in the past 24 hours 791 """ 792 793 def _count_daily_user_type(txn): 794 yesterday = int(self._clock.time()) - (60 * 60 * 24) 795 796 sql = """ 797 SELECT user_type, COUNT(*) AS count FROM ( 798 SELECT 799 CASE 800 WHEN is_guest=0 AND appservice_id IS NULL THEN 'native' 801 WHEN is_guest=1 AND appservice_id IS NULL THEN 'guest' 802 WHEN is_guest=0 AND appservice_id IS NOT NULL THEN 'bridged' 803 END AS user_type 804 FROM users 805 WHERE creation_ts > ? 806 ) AS t GROUP BY user_type 807 """ 808 results = {"native": 0, "guest": 0, "bridged": 0} 809 txn.execute(sql, (yesterday,)) 810 for row in txn: 811 results[row[0]] = row[1] 812 return results 813 814 return await self.db_pool.runInteraction( 815 "count_daily_user_type", _count_daily_user_type 816 ) 817 818 async def count_nonbridged_users(self): 819 def _count_users(txn): 820 txn.execute( 821 """ 822 SELECT COUNT(*) FROM users 823 WHERE appservice_id IS NULL 824 """ 825 ) 826 (count,) = txn.fetchone() 827 return count 828 829 return await self.db_pool.runInteraction("count_users", _count_users) 830 831 async def count_real_users(self): 832 """Counts all users without a special user_type registered on the homeserver.""" 833 834 def _count_users(txn): 835 txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null") 836 rows = self.db_pool.cursor_to_dict(txn) 837 if rows: 838 return rows[0]["users"] 839 return 0 840 841 return await self.db_pool.runInteraction("count_real_users", _count_users) 842 843 async def generate_user_id(self) -> str: 844 """Generate a suitable localpart for a guest user 845 846 Returns: a (hopefully) free localpart 847 """ 848 next_id = await self.db_pool.runInteraction( 849 "generate_user_id", self._user_id_seq.get_next_id_txn 850 ) 851 852 return str(next_id) 853 854 async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]: 855 """Returns user id from threepid 856 857 Args: 858 medium: threepid medium e.g. email 859 address: threepid address e.g. me@example.com. This must already be 860 in canonical form. 861 862 Returns: 863 The user ID or None if no user id/threepid mapping exists 864 """ 865 user_id = await self.db_pool.runInteraction( 866 "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address 867 ) 868 return user_id 869 870 def get_user_id_by_threepid_txn( 871 self, txn, medium: str, address: str 872 ) -> Optional[str]: 873 """Returns user id from threepid 874 875 Args: 876 txn (cursor): 877 medium: threepid medium e.g. email 878 address: threepid address e.g. me@example.com 879 880 Returns: 881 user id, or None if no user id/threepid mapping exists 882 """ 883 ret = self.db_pool.simple_select_one_txn( 884 txn, 885 "user_threepids", 886 {"medium": medium, "address": address}, 887 ["user_id"], 888 True, 889 ) 890 if ret: 891 return ret["user_id"] 892 return None 893 894 async def user_add_threepid( 895 self, 896 user_id: str, 897 medium: str, 898 address: str, 899 validated_at: int, 900 added_at: int, 901 ) -> None: 902 await self.db_pool.simple_upsert( 903 "user_threepids", 904 {"medium": medium, "address": address}, 905 {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, 906 ) 907 908 async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]: 909 return await self.db_pool.simple_select_list( 910 "user_threepids", 911 {"user_id": user_id}, 912 ["medium", "address", "validated_at", "added_at"], 913 "user_get_threepids", 914 ) 915 916 async def user_delete_threepid( 917 self, user_id: str, medium: str, address: str 918 ) -> None: 919 await self.db_pool.simple_delete( 920 "user_threepids", 921 keyvalues={"user_id": user_id, "medium": medium, "address": address}, 922 desc="user_delete_threepid", 923 ) 924 925 async def user_delete_threepids(self, user_id: str) -> None: 926 """Delete all threepid this user has bound 927 928 Args: 929 user_id: The user id to delete all threepids of 930 931 """ 932 await self.db_pool.simple_delete( 933 "user_threepids", 934 keyvalues={"user_id": user_id}, 935 desc="user_delete_threepids", 936 ) 937 938 async def add_user_bound_threepid( 939 self, user_id: str, medium: str, address: str, id_server: str 940 ): 941 """The server proxied a bind request to the given identity server on 942 behalf of the given user. We need to remember this in case the user 943 asks us to unbind the threepid. 944 945 Args: 946 user_id 947 medium 948 address 949 id_server 950 """ 951 # We need to use an upsert, in case they user had already bound the 952 # threepid 953 await self.db_pool.simple_upsert( 954 table="user_threepid_id_server", 955 keyvalues={ 956 "user_id": user_id, 957 "medium": medium, 958 "address": address, 959 "id_server": id_server, 960 }, 961 values={}, 962 insertion_values={}, 963 desc="add_user_bound_threepid", 964 ) 965 966 async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]: 967 """Get the threepids that a user has bound to an identity server through the homeserver 968 The homeserver remembers where binds to an identity server occurred. Using this 969 method can retrieve those threepids. 970 971 Args: 972 user_id: The ID of the user to retrieve threepids for 973 974 Returns: 975 List of dictionaries containing the following keys: 976 medium (str): The medium of the threepid (e.g "email") 977 address (str): The address of the threepid (e.g "bob@example.com") 978 """ 979 return await self.db_pool.simple_select_list( 980 table="user_threepid_id_server", 981 keyvalues={"user_id": user_id}, 982 retcols=["medium", "address"], 983 desc="user_get_bound_threepids", 984 ) 985 986 async def remove_user_bound_threepid( 987 self, user_id: str, medium: str, address: str, id_server: str 988 ) -> None: 989 """The server proxied an unbind request to the given identity server on 990 behalf of the given user, so we remove the mapping of threepid to 991 identity server. 992 993 Args: 994 user_id 995 medium 996 address 997 id_server 998 """ 999 await self.db_pool.simple_delete( 1000 table="user_threepid_id_server", 1001 keyvalues={ 1002 "user_id": user_id, 1003 "medium": medium, 1004 "address": address, 1005 "id_server": id_server, 1006 }, 1007 desc="remove_user_bound_threepid", 1008 ) 1009 1010 async def get_id_servers_user_bound( 1011 self, user_id: str, medium: str, address: str 1012 ) -> List[str]: 1013 """Get the list of identity servers that the server proxied bind 1014 requests to for given user and threepid 1015 1016 Args: 1017 user_id: The user to query for identity servers. 1018 medium: The medium to query for identity servers. 1019 address: The address to query for identity servers. 1020 1021 Returns: 1022 A list of identity servers 1023 """ 1024 return await self.db_pool.simple_select_onecol( 1025 table="user_threepid_id_server", 1026 keyvalues={"user_id": user_id, "medium": medium, "address": address}, 1027 retcol="id_server", 1028 desc="get_id_servers_user_bound", 1029 ) 1030 1031 @cached() 1032 async def get_user_deactivated_status(self, user_id: str) -> bool: 1033 """Retrieve the value for the `deactivated` property for the provided user. 1034 1035 Args: 1036 user_id: The ID of the user to retrieve the status for. 1037 1038 Returns: 1039 True if the user was deactivated, false if the user is still active. 1040 """ 1041 1042 res = await self.db_pool.simple_select_one_onecol( 1043 table="users", 1044 keyvalues={"name": user_id}, 1045 retcol="deactivated", 1046 desc="get_user_deactivated_status", 1047 ) 1048 1049 # Convert the integer into a boolean. 1050 return res == 1 1051 1052 async def get_threepid_validation_session( 1053 self, 1054 medium: Optional[str], 1055 client_secret: str, 1056 address: Optional[str] = None, 1057 sid: Optional[str] = None, 1058 validated: Optional[bool] = True, 1059 ) -> Optional[Dict[str, Any]]: 1060 """Gets a session_id and last_send_attempt (if available) for a 1061 combination of validation metadata 1062 1063 Args: 1064 medium: The medium of the 3PID 1065 client_secret: A unique string provided by the client to help identify this 1066 validation attempt 1067 address: The address of the 3PID 1068 sid: The ID of the validation session 1069 validated: Whether sessions should be filtered by 1070 whether they have been validated already or not. None to 1071 perform no filtering 1072 1073 Returns: 1074 A dict containing the following: 1075 * address - address of the 3pid 1076 * medium - medium of the 3pid 1077 * client_secret - a secret provided by the client for this validation session 1078 * session_id - ID of the validation session 1079 * send_attempt - a number serving to dedupe send attempts for this session 1080 * validated_at - timestamp of when this session was validated if so 1081 1082 Otherwise None if a validation session is not found 1083 """ 1084 if not client_secret: 1085 raise SynapseError( 1086 400, "Missing parameter: client_secret", errcode=Codes.MISSING_PARAM 1087 ) 1088 1089 keyvalues = {"client_secret": client_secret} 1090 if medium: 1091 keyvalues["medium"] = medium 1092 if address: 1093 keyvalues["address"] = address 1094 if sid: 1095 keyvalues["session_id"] = sid 1096 1097 assert address or sid 1098 1099 def get_threepid_validation_session_txn(txn): 1100 sql = """ 1101 SELECT address, session_id, medium, client_secret, 1102 last_send_attempt, validated_at 1103 FROM threepid_validation_session WHERE %s 1104 """ % ( 1105 " AND ".join("%s = ?" % k for k in keyvalues.keys()), 1106 ) 1107 1108 if validated is not None: 1109 sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL") 1110 1111 sql += " LIMIT 1" 1112 1113 txn.execute(sql, list(keyvalues.values())) 1114 rows = self.db_pool.cursor_to_dict(txn) 1115 if not rows: 1116 return None 1117 1118 return rows[0] 1119 1120 return await self.db_pool.runInteraction( 1121 "get_threepid_validation_session", get_threepid_validation_session_txn 1122 ) 1123 1124 async def delete_threepid_session(self, session_id: str) -> None: 1125 """Removes a threepid validation session from the database. This can 1126 be done after validation has been performed and whatever action was 1127 waiting on it has been carried out 1128 1129 Args: 1130 session_id: The ID of the session to delete 1131 """ 1132 1133 def delete_threepid_session_txn(txn): 1134 self.db_pool.simple_delete_txn( 1135 txn, 1136 table="threepid_validation_token", 1137 keyvalues={"session_id": session_id}, 1138 ) 1139 self.db_pool.simple_delete_txn( 1140 txn, 1141 table="threepid_validation_session", 1142 keyvalues={"session_id": session_id}, 1143 ) 1144 1145 await self.db_pool.runInteraction( 1146 "delete_threepid_session", delete_threepid_session_txn 1147 ) 1148 1149 @wrap_as_background_process("cull_expired_threepid_validation_tokens") 1150 async def cull_expired_threepid_validation_tokens(self) -> None: 1151 """Remove threepid validation tokens with expiry dates that have passed""" 1152 1153 def cull_expired_threepid_validation_tokens_txn(txn, ts): 1154 sql = """ 1155 DELETE FROM threepid_validation_token WHERE 1156 expires < ? 1157 """ 1158 txn.execute(sql, (ts,)) 1159 1160 await self.db_pool.runInteraction( 1161 "cull_expired_threepid_validation_tokens", 1162 cull_expired_threepid_validation_tokens_txn, 1163 self._clock.time_msec(), 1164 ) 1165 1166 @wrap_as_background_process("account_validity_set_expiration_dates") 1167 async def _set_expiration_date_when_missing(self): 1168 """ 1169 Retrieves the list of registered users that don't have an expiration date, and 1170 adds an expiration date for each of them. 1171 """ 1172 1173 def select_users_with_no_expiration_date_txn(txn): 1174 """Retrieves the list of registered users with no expiration date from the 1175 database, filtering out deactivated users. 1176 """ 1177 sql = ( 1178 "SELECT users.name FROM users" 1179 " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" 1180 " WHERE account_validity.user_id is NULL AND users.deactivated = 0;" 1181 ) 1182 txn.execute(sql, []) 1183 1184 res = self.db_pool.cursor_to_dict(txn) 1185 if res: 1186 for user in res: 1187 self.set_expiration_date_for_user_txn( 1188 txn, user["name"], use_delta=True 1189 ) 1190 1191 await self.db_pool.runInteraction( 1192 "get_users_with_no_expiration_date", 1193 select_users_with_no_expiration_date_txn, 1194 ) 1195 1196 def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): 1197 """Sets an expiration date to the account with the given user ID. 1198 1199 Args: 1200 user_id (str): User ID to set an expiration date for. 1201 use_delta (bool): If set to False, the expiration date for the user will be 1202 now + validity period. If set to True, this expiration date will be a 1203 random value in the [now + period - d ; now + period] range, d being a 1204 delta equal to 10% of the validity period. 1205 """ 1206 now_ms = self._clock.time_msec() 1207 assert self._account_validity_period is not None 1208 expiration_ts = now_ms + self._account_validity_period 1209 1210 if use_delta: 1211 assert self._account_validity_startup_job_max_delta is not None 1212 expiration_ts = random.randrange( 1213 int(expiration_ts - self._account_validity_startup_job_max_delta), 1214 expiration_ts, 1215 ) 1216 1217 self.db_pool.simple_upsert_txn( 1218 txn, 1219 "account_validity", 1220 keyvalues={"user_id": user_id}, 1221 values={"expiration_ts_ms": expiration_ts, "email_sent": False}, 1222 ) 1223 1224 async def get_user_pending_deactivation(self) -> Optional[str]: 1225 """ 1226 Gets one user from the table of users waiting to be parted from all the rooms 1227 they're in. 1228 """ 1229 return await self.db_pool.simple_select_one_onecol( 1230 "users_pending_deactivation", 1231 keyvalues={}, 1232 retcol="user_id", 1233 allow_none=True, 1234 desc="get_users_pending_deactivation", 1235 ) 1236 1237 async def del_user_pending_deactivation(self, user_id: str) -> None: 1238 """ 1239 Removes the given user to the table of users who need to be parted from all the 1240 rooms they're in, effectively marking that user as fully deactivated. 1241 """ 1242 # XXX: This should be simple_delete_one but we failed to put a unique index on 1243 # the table, so somehow duplicate entries have ended up in it. 1244 await self.db_pool.simple_delete( 1245 "users_pending_deactivation", 1246 keyvalues={"user_id": user_id}, 1247 desc="del_user_pending_deactivation", 1248 ) 1249 1250 async def get_access_token_last_validated(self, token_id: int) -> int: 1251 """Retrieves the time (in milliseconds) of the last validation of an access token. 1252 1253 Args: 1254 token_id: The ID of the access token to update. 1255 Raises: 1256 StoreError if the access token was not found. 1257 1258 Returns: 1259 The last validation time. 1260 """ 1261 result = await self.db_pool.simple_select_one_onecol( 1262 "access_tokens", {"id": token_id}, "last_validated" 1263 ) 1264 1265 # If this token has not been validated (since starting to track this), 1266 # return 0 instead of None. 1267 return result or 0 1268 1269 async def update_access_token_last_validated(self, token_id: int) -> None: 1270 """Updates the last time an access token was validated. 1271 1272 Args: 1273 token_id: The ID of the access token to update. 1274 Raises: 1275 StoreError if there was a problem updating this. 1276 """ 1277 now = self._clock.time_msec() 1278 1279 await self.db_pool.simple_update_one( 1280 "access_tokens", 1281 {"id": token_id}, 1282 {"last_validated": now}, 1283 desc="update_access_token_last_validated", 1284 ) 1285 1286 async def registration_token_is_valid(self, token: str) -> bool: 1287 """Checks if a token can be used to authenticate a registration. 1288 1289 Args: 1290 token: The registration token to be checked 1291 Returns: 1292 True if the token is valid, False otherwise. 1293 """ 1294 res = await self.db_pool.simple_select_one( 1295 "registration_tokens", 1296 keyvalues={"token": token}, 1297 retcols=["uses_allowed", "pending", "completed", "expiry_time"], 1298 allow_none=True, 1299 ) 1300 1301 # Check if the token exists 1302 if res is None: 1303 return False 1304 1305 # Check if the token has expired 1306 now = self._clock.time_msec() 1307 if res["expiry_time"] and res["expiry_time"] < now: 1308 return False 1309 1310 # Check if the token has been used up 1311 if ( 1312 res["uses_allowed"] 1313 and res["pending"] + res["completed"] >= res["uses_allowed"] 1314 ): 1315 return False 1316 1317 # Otherwise, the token is valid 1318 return True 1319 1320 async def set_registration_token_pending(self, token: str) -> None: 1321 """Increment the pending registrations counter for a token. 1322 1323 Args: 1324 token: The registration token pending use 1325 """ 1326 1327 def _set_registration_token_pending_txn(txn): 1328 pending = self.db_pool.simple_select_one_onecol_txn( 1329 txn, 1330 "registration_tokens", 1331 keyvalues={"token": token}, 1332 retcol="pending", 1333 ) 1334 self.db_pool.simple_update_one_txn( 1335 txn, 1336 "registration_tokens", 1337 keyvalues={"token": token}, 1338 updatevalues={"pending": pending + 1}, 1339 ) 1340 1341 return await self.db_pool.runInteraction( 1342 "set_registration_token_pending", _set_registration_token_pending_txn 1343 ) 1344 1345 async def use_registration_token(self, token: str) -> None: 1346 """Complete a use of the given registration token. 1347 1348 The `pending` counter will be decremented, and the `completed` 1349 counter will be incremented. 1350 1351 Args: 1352 token: The registration token to be 'used' 1353 """ 1354 1355 def _use_registration_token_txn(txn): 1356 # Normally, res is Optional[Dict[str, Any]]. 1357 # Override type because the return type is only optional if 1358 # allow_none is True, and we don't want mypy throwing errors 1359 # about None not being indexable. 1360 res = cast( 1361 Dict[str, Any], 1362 self.db_pool.simple_select_one_txn( 1363 txn, 1364 "registration_tokens", 1365 keyvalues={"token": token}, 1366 retcols=["pending", "completed"], 1367 ), 1368 ) 1369 1370 # Decrement pending and increment completed 1371 self.db_pool.simple_update_one_txn( 1372 txn, 1373 "registration_tokens", 1374 keyvalues={"token": token}, 1375 updatevalues={ 1376 "completed": res["completed"] + 1, 1377 "pending": res["pending"] - 1, 1378 }, 1379 ) 1380 1381 return await self.db_pool.runInteraction( 1382 "use_registration_token", _use_registration_token_txn 1383 ) 1384 1385 async def get_registration_tokens( 1386 self, valid: Optional[bool] = None 1387 ) -> List[Dict[str, Any]]: 1388 """List all registration tokens. Used by the admin API. 1389 1390 Args: 1391 valid: If True, only valid tokens are returned. 1392 If False, only invalid tokens are returned. 1393 Default is None: return all tokens regardless of validity. 1394 1395 Returns: 1396 A list of dicts, each containing details of a token. 1397 """ 1398 1399 def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]): 1400 if valid is None: 1401 # Return all tokens regardless of validity 1402 txn.execute("SELECT * FROM registration_tokens") 1403 1404 elif valid: 1405 # Select valid tokens only 1406 sql = ( 1407 "SELECT * FROM registration_tokens WHERE " 1408 "(uses_allowed > pending + completed OR uses_allowed IS NULL) " 1409 "AND (expiry_time > ? OR expiry_time IS NULL)" 1410 ) 1411 txn.execute(sql, [now]) 1412 1413 else: 1414 # Select invalid tokens only 1415 sql = ( 1416 "SELECT * FROM registration_tokens WHERE " 1417 "uses_allowed <= pending + completed OR expiry_time <= ?" 1418 ) 1419 txn.execute(sql, [now]) 1420 1421 return self.db_pool.cursor_to_dict(txn) 1422 1423 return await self.db_pool.runInteraction( 1424 "select_registration_tokens", 1425 select_registration_tokens_txn, 1426 self._clock.time_msec(), 1427 valid, 1428 ) 1429 1430 async def get_one_registration_token(self, token: str) -> Optional[Dict[str, Any]]: 1431 """Get info about the given registration token. Used by the admin API. 1432 1433 Args: 1434 token: The token to retrieve information about. 1435 1436 Returns: 1437 A dict, or None if token doesn't exist. 1438 """ 1439 return await self.db_pool.simple_select_one( 1440 "registration_tokens", 1441 keyvalues={"token": token}, 1442 retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"], 1443 allow_none=True, 1444 desc="get_one_registration_token", 1445 ) 1446 1447 async def generate_registration_token( 1448 self, length: int, chars: str 1449 ) -> Optional[str]: 1450 """Generate a random registration token. Used by the admin API. 1451 1452 Args: 1453 length: The length of the token to generate. 1454 chars: A string of the characters allowed in the generated token. 1455 1456 Returns: 1457 The generated token. 1458 1459 Raises: 1460 SynapseError if a unique registration token could still not be 1461 generated after a few tries. 1462 """ 1463 # Make a few attempts at generating a unique token of the required 1464 # length before failing. 1465 for _i in range(3): 1466 # Generate token 1467 token = "".join(random.choices(chars, k=length)) 1468 1469 # Check if the token already exists 1470 existing_token = await self.db_pool.simple_select_one_onecol( 1471 "registration_tokens", 1472 keyvalues={"token": token}, 1473 retcol="token", 1474 allow_none=True, 1475 desc="check_if_registration_token_exists", 1476 ) 1477 1478 if existing_token is None: 1479 # The generated token doesn't exist yet, return it 1480 return token 1481 1482 raise SynapseError( 1483 500, 1484 "Unable to generate a unique registration token. Try again with a greater length", 1485 Codes.UNKNOWN, 1486 ) 1487 1488 async def create_registration_token( 1489 self, token: str, uses_allowed: Optional[int], expiry_time: Optional[int] 1490 ) -> bool: 1491 """Create a new registration token. Used by the admin API. 1492 1493 Args: 1494 token: The token to create. 1495 uses_allowed: The number of times the token can be used to complete 1496 a registration before it becomes invalid. A value of None indicates 1497 unlimited uses. 1498 expiry_time: The latest time the token is valid. Given as the 1499 number of milliseconds since 1970-01-01 00:00:00 UTC. A value of 1500 None indicates that the token does not expire. 1501 1502 Returns: 1503 Whether the row was inserted or not. 1504 """ 1505 1506 def _create_registration_token_txn(txn): 1507 row = self.db_pool.simple_select_one_txn( 1508 txn, 1509 "registration_tokens", 1510 keyvalues={"token": token}, 1511 retcols=["token"], 1512 allow_none=True, 1513 ) 1514 1515 if row is not None: 1516 # Token already exists 1517 return False 1518 1519 self.db_pool.simple_insert_txn( 1520 txn, 1521 "registration_tokens", 1522 values={ 1523 "token": token, 1524 "uses_allowed": uses_allowed, 1525 "pending": 0, 1526 "completed": 0, 1527 "expiry_time": expiry_time, 1528 }, 1529 ) 1530 1531 return True 1532 1533 return await self.db_pool.runInteraction( 1534 "create_registration_token", _create_registration_token_txn 1535 ) 1536 1537 async def update_registration_token( 1538 self, token: str, updatevalues: Dict[str, Optional[int]] 1539 ) -> Optional[Dict[str, Any]]: 1540 """Update a registration token. Used by the admin API. 1541 1542 Args: 1543 token: The token to update. 1544 updatevalues: A dict with the fields to update. E.g.: 1545 `{"uses_allowed": 3}` to update just uses_allowed, or 1546 `{"uses_allowed": 3, "expiry_time": None}` to update both. 1547 This is passed straight to simple_update_one. 1548 1549 Returns: 1550 A dict with all info about the token, or None if token doesn't exist. 1551 """ 1552 1553 def _update_registration_token_txn(txn): 1554 try: 1555 self.db_pool.simple_update_one_txn( 1556 txn, 1557 "registration_tokens", 1558 keyvalues={"token": token}, 1559 updatevalues=updatevalues, 1560 ) 1561 except StoreError: 1562 # Update failed because token does not exist 1563 return None 1564 1565 # Get all info about the token so it can be sent in the response 1566 return self.db_pool.simple_select_one_txn( 1567 txn, 1568 "registration_tokens", 1569 keyvalues={"token": token}, 1570 retcols=[ 1571 "token", 1572 "uses_allowed", 1573 "pending", 1574 "completed", 1575 "expiry_time", 1576 ], 1577 allow_none=True, 1578 ) 1579 1580 return await self.db_pool.runInteraction( 1581 "update_registration_token", _update_registration_token_txn 1582 ) 1583 1584 async def delete_registration_token(self, token: str) -> bool: 1585 """Delete a registration token. Used by the admin API. 1586 1587 Args: 1588 token: The token to delete. 1589 1590 Returns: 1591 Whether the token was successfully deleted or not. 1592 """ 1593 try: 1594 await self.db_pool.simple_delete_one( 1595 "registration_tokens", 1596 keyvalues={"token": token}, 1597 desc="delete_registration_token", 1598 ) 1599 except StoreError: 1600 # Deletion failed because token does not exist 1601 return False 1602 1603 return True 1604 1605 @cached() 1606 async def mark_access_token_as_used(self, token_id: int) -> None: 1607 """ 1608 Mark the access token as used, which invalidates the refresh token used 1609 to obtain it. 1610 1611 Because get_user_by_access_token is cached, this function might be 1612 called multiple times for the same token, effectively doing unnecessary 1613 SQL updates. Because updating the `used` field only goes one way (from 1614 False to True) it is safe to cache this function as well to avoid this 1615 issue. 1616 1617 Args: 1618 token_id: The ID of the access token to update. 1619 Raises: 1620 StoreError if there was a problem updating this. 1621 """ 1622 await self.db_pool.simple_update_one( 1623 "access_tokens", 1624 {"id": token_id}, 1625 {"used": True}, 1626 desc="mark_access_token_as_used", 1627 ) 1628 1629 async def lookup_refresh_token( 1630 self, token: str 1631 ) -> Optional[RefreshTokenLookupResult]: 1632 """Lookup a refresh token with hints about its validity.""" 1633 1634 def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]: 1635 txn.execute( 1636 """ 1637 SELECT 1638 rt.id token_id, 1639 rt.user_id, 1640 rt.device_id, 1641 rt.next_token_id, 1642 (nrt.next_token_id IS NOT NULL) AS has_next_refresh_token_been_refreshed, 1643 at.used AS has_next_access_token_been_used, 1644 rt.expiry_ts, 1645 rt.ultimate_session_expiry_ts 1646 FROM refresh_tokens rt 1647 LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id 1648 LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id 1649 WHERE rt.token = ? 1650 """, 1651 (token,), 1652 ) 1653 row = txn.fetchone() 1654 1655 if row is None: 1656 return None 1657 1658 return RefreshTokenLookupResult( 1659 token_id=row[0], 1660 user_id=row[1], 1661 device_id=row[2], 1662 next_token_id=row[3], 1663 has_next_refresh_token_been_refreshed=row[4], 1664 # This column is nullable, ensure it's a boolean 1665 has_next_access_token_been_used=(row[5] or False), 1666 expiry_ts=row[6], 1667 ultimate_session_expiry_ts=row[7], 1668 ) 1669 1670 return await self.db_pool.runInteraction( 1671 "lookup_refresh_token", _lookup_refresh_token_txn 1672 ) 1673 1674 async def replace_refresh_token(self, token_id: int, next_token_id: int) -> None: 1675 """ 1676 Set the successor of a refresh token, removing the existing successor 1677 if any. 1678 1679 Args: 1680 token_id: ID of the refresh token to update. 1681 next_token_id: ID of its successor. 1682 """ 1683 1684 def _replace_refresh_token_txn(txn) -> None: 1685 # First check if there was an existing refresh token 1686 old_next_token_id = self.db_pool.simple_select_one_onecol_txn( 1687 txn, 1688 "refresh_tokens", 1689 {"id": token_id}, 1690 "next_token_id", 1691 allow_none=True, 1692 ) 1693 1694 self.db_pool.simple_update_one_txn( 1695 txn, 1696 "refresh_tokens", 1697 {"id": token_id}, 1698 {"next_token_id": next_token_id}, 1699 ) 1700 1701 # Delete the old "next" token if it exists. This should cascade and 1702 # delete the associated access_token 1703 if old_next_token_id is not None: 1704 self.db_pool.simple_delete_one_txn( 1705 txn, 1706 "refresh_tokens", 1707 {"id": old_next_token_id}, 1708 ) 1709 1710 await self.db_pool.runInteraction( 1711 "replace_refresh_token", _replace_refresh_token_txn 1712 ) 1713 1714 1715class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): 1716 def __init__( 1717 self, 1718 database: DatabasePool, 1719 db_conn: LoggingDatabaseConnection, 1720 hs: "HomeServer", 1721 ): 1722 super().__init__(database, db_conn, hs) 1723 1724 self._clock = hs.get_clock() 1725 self.config = hs.config 1726 1727 self.db_pool.updates.register_background_index_update( 1728 "access_tokens_device_index", 1729 index_name="access_tokens_device_id", 1730 table="access_tokens", 1731 columns=["user_id", "device_id"], 1732 ) 1733 1734 self.db_pool.updates.register_background_index_update( 1735 "users_creation_ts", 1736 index_name="users_creation_ts", 1737 table="users", 1738 columns=["creation_ts"], 1739 ) 1740 1741 # we no longer use refresh tokens, but it's possible that some people 1742 # might have a background update queued to build this index. Just 1743 # clear the background update. 1744 self.db_pool.updates.register_noop_background_update( 1745 "refresh_tokens_device_index" 1746 ) 1747 1748 self.db_pool.updates.register_background_update_handler( 1749 "users_set_deactivated_flag", self._background_update_set_deactivated_flag 1750 ) 1751 1752 self.db_pool.updates.register_noop_background_update( 1753 "user_threepids_grandfather" 1754 ) 1755 1756 self.db_pool.updates.register_background_index_update( 1757 "user_external_ids_user_id_idx", 1758 index_name="user_external_ids_user_id_idx", 1759 table="user_external_ids", 1760 columns=["user_id"], 1761 unique=False, 1762 ) 1763 1764 async def _background_update_set_deactivated_flag(self, progress, batch_size): 1765 """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1 1766 for each of them. 1767 """ 1768 1769 last_user = progress.get("user_id", "") 1770 1771 def _background_update_set_deactivated_flag_txn(txn): 1772 txn.execute( 1773 """ 1774 SELECT 1775 users.name, 1776 COUNT(access_tokens.token) AS count_tokens, 1777 COUNT(user_threepids.address) AS count_threepids 1778 FROM users 1779 LEFT JOIN access_tokens ON (access_tokens.user_id = users.name) 1780 LEFT JOIN user_threepids ON (user_threepids.user_id = users.name) 1781 WHERE (users.password_hash IS NULL OR users.password_hash = '') 1782 AND (users.appservice_id IS NULL OR users.appservice_id = '') 1783 AND users.is_guest = 0 1784 AND users.name > ? 1785 GROUP BY users.name 1786 ORDER BY users.name ASC 1787 LIMIT ?; 1788 """, 1789 (last_user, batch_size), 1790 ) 1791 1792 rows = self.db_pool.cursor_to_dict(txn) 1793 1794 if not rows: 1795 return True, 0 1796 1797 rows_processed_nb = 0 1798 1799 for user in rows: 1800 if not user["count_tokens"] and not user["count_threepids"]: 1801 self.set_user_deactivated_status_txn(txn, user["name"], True) 1802 rows_processed_nb += 1 1803 1804 logger.info("Marked %d rows as deactivated", rows_processed_nb) 1805 1806 self.db_pool.updates._background_update_progress_txn( 1807 txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]} 1808 ) 1809 1810 if batch_size > len(rows): 1811 return True, len(rows) 1812 else: 1813 return False, len(rows) 1814 1815 end, nb_processed = await self.db_pool.runInteraction( 1816 "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn 1817 ) 1818 1819 if end: 1820 await self.db_pool.updates._end_background_update( 1821 "users_set_deactivated_flag" 1822 ) 1823 1824 return nb_processed 1825 1826 async def set_user_deactivated_status( 1827 self, user_id: str, deactivated: bool 1828 ) -> None: 1829 """Set the `deactivated` property for the provided user to the provided value. 1830 1831 Args: 1832 user_id: The ID of the user to set the status for. 1833 deactivated: The value to set for `deactivated`. 1834 """ 1835 1836 await self.db_pool.runInteraction( 1837 "set_user_deactivated_status", 1838 self.set_user_deactivated_status_txn, 1839 user_id, 1840 deactivated, 1841 ) 1842 1843 def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool): 1844 self.db_pool.simple_update_one_txn( 1845 txn=txn, 1846 table="users", 1847 keyvalues={"name": user_id}, 1848 updatevalues={"deactivated": 1 if deactivated else 0}, 1849 ) 1850 self._invalidate_cache_and_stream( 1851 txn, self.get_user_deactivated_status, (user_id,) 1852 ) 1853 self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) 1854 txn.call_after(self.is_guest.invalidate, (user_id,)) 1855 1856 @cached() 1857 async def is_guest(self, user_id: str) -> bool: 1858 res = await self.db_pool.simple_select_one_onecol( 1859 table="users", 1860 keyvalues={"name": user_id}, 1861 retcol="is_guest", 1862 allow_none=True, 1863 desc="is_guest", 1864 ) 1865 1866 return res if res else False 1867 1868 1869class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): 1870 def __init__( 1871 self, 1872 database: DatabasePool, 1873 db_conn: LoggingDatabaseConnection, 1874 hs: "HomeServer", 1875 ): 1876 super().__init__(database, db_conn, hs) 1877 1878 self._ignore_unknown_session_error = ( 1879 hs.config.server.request_token_inhibit_3pid_errors 1880 ) 1881 1882 self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") 1883 self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") 1884 1885 async def add_access_token_to_user( 1886 self, 1887 user_id: str, 1888 token: str, 1889 device_id: Optional[str], 1890 valid_until_ms: Optional[int], 1891 puppets_user_id: Optional[str] = None, 1892 refresh_token_id: Optional[int] = None, 1893 ) -> int: 1894 """Adds an access token for the given user. 1895 1896 Args: 1897 user_id: The user ID. 1898 token: The new access token to add. 1899 device_id: ID of the device to associate with the access token. 1900 valid_until_ms: when the token is valid until. None for no expiry. 1901 puppets_user_id 1902 refresh_token_id: ID of the refresh token generated alongside this 1903 access token. 1904 Raises: 1905 StoreError if there was a problem adding this. 1906 Returns: 1907 The token ID 1908 """ 1909 next_id = self._access_tokens_id_gen.get_next() 1910 now = self._clock.time_msec() 1911 1912 await self.db_pool.simple_insert( 1913 "access_tokens", 1914 { 1915 "id": next_id, 1916 "user_id": user_id, 1917 "token": token, 1918 "device_id": device_id, 1919 "valid_until_ms": valid_until_ms, 1920 "puppets_user_id": puppets_user_id, 1921 "last_validated": now, 1922 "refresh_token_id": refresh_token_id, 1923 "used": False, 1924 }, 1925 desc="add_access_token_to_user", 1926 ) 1927 1928 return next_id 1929 1930 async def add_refresh_token_to_user( 1931 self, 1932 user_id: str, 1933 token: str, 1934 device_id: Optional[str], 1935 expiry_ts: Optional[int], 1936 ultimate_session_expiry_ts: Optional[int], 1937 ) -> int: 1938 """Adds a refresh token for the given user. 1939 1940 Args: 1941 user_id: The user ID. 1942 token: The new access token to add. 1943 device_id: ID of the device to associate with the refresh token. 1944 expiry_ts (milliseconds since the epoch): Time after which the 1945 refresh token cannot be used. 1946 If None, the refresh token never expires until it has been used. 1947 ultimate_session_expiry_ts (milliseconds since the epoch): 1948 Time at which the session will end and can not be extended any 1949 further. 1950 If None, the session can be refreshed indefinitely. 1951 Raises: 1952 StoreError if there was a problem adding this. 1953 Returns: 1954 The token ID 1955 """ 1956 next_id = self._refresh_tokens_id_gen.get_next() 1957 1958 await self.db_pool.simple_insert( 1959 "refresh_tokens", 1960 { 1961 "id": next_id, 1962 "user_id": user_id, 1963 "device_id": device_id, 1964 "token": token, 1965 "next_token_id": None, 1966 "expiry_ts": expiry_ts, 1967 "ultimate_session_expiry_ts": ultimate_session_expiry_ts, 1968 }, 1969 desc="add_refresh_token_to_user", 1970 ) 1971 1972 return next_id 1973 1974 def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str: 1975 old_device_id = self.db_pool.simple_select_one_onecol_txn( 1976 txn, "access_tokens", {"token": token}, "device_id" 1977 ) 1978 1979 self.db_pool.simple_update_txn( 1980 txn, "access_tokens", {"token": token}, {"device_id": device_id} 1981 ) 1982 1983 self._invalidate_cache_and_stream(txn, self.get_user_by_access_token, (token,)) 1984 1985 return old_device_id 1986 1987 async def set_device_for_access_token(self, token: str, device_id: str) -> str: 1988 """Sets the device ID associated with an access token. 1989 1990 Args: 1991 token: The access token to modify. 1992 device_id: The new device ID. 1993 Returns: 1994 The old device ID associated with the access token. 1995 """ 1996 1997 return await self.db_pool.runInteraction( 1998 "set_device_for_access_token", 1999 self._set_device_for_access_token_txn, 2000 token, 2001 device_id, 2002 ) 2003 2004 async def register_user( 2005 self, 2006 user_id: str, 2007 password_hash: Optional[str] = None, 2008 was_guest: bool = False, 2009 make_guest: bool = False, 2010 appservice_id: Optional[str] = None, 2011 create_profile_with_displayname: Optional[str] = None, 2012 admin: bool = False, 2013 user_type: Optional[str] = None, 2014 shadow_banned: bool = False, 2015 ) -> None: 2016 """Attempts to register an account. 2017 2018 Args: 2019 user_id: The desired user ID to register. 2020 password_hash: Optional. The password hash for this user. 2021 was_guest: Whether this is a guest account being upgraded to a 2022 non-guest account. 2023 make_guest: True if the the new user should be guest, false to add a 2024 regular user account. 2025 appservice_id: The ID of the appservice registering the user. 2026 create_profile_with_displayname: Optionally create a profile for 2027 the user, setting their displayname to the given value 2028 admin: is an admin user? 2029 user_type: type of user. One of the values from api.constants.UserTypes, 2030 or None for a normal user. 2031 shadow_banned: Whether the user is shadow-banned, i.e. they may be 2032 told their requests succeeded but we ignore them. 2033 2034 Raises: 2035 StoreError if the user_id could not be registered. 2036 """ 2037 await self.db_pool.runInteraction( 2038 "register_user", 2039 self._register_user, 2040 user_id, 2041 password_hash, 2042 was_guest, 2043 make_guest, 2044 appservice_id, 2045 create_profile_with_displayname, 2046 admin, 2047 user_type, 2048 shadow_banned, 2049 ) 2050 2051 def _register_user( 2052 self, 2053 txn, 2054 user_id: str, 2055 password_hash: Optional[str], 2056 was_guest: bool, 2057 make_guest: bool, 2058 appservice_id: Optional[str], 2059 create_profile_with_displayname: Optional[str], 2060 admin: bool, 2061 user_type: Optional[str], 2062 shadow_banned: bool, 2063 ): 2064 user_id_obj = UserID.from_string(user_id) 2065 2066 now = int(self._clock.time()) 2067 2068 try: 2069 if was_guest: 2070 # Ensure that the guest user actually exists 2071 # ``allow_none=False`` makes this raise an exception 2072 # if the row isn't in the database. 2073 self.db_pool.simple_select_one_txn( 2074 txn, 2075 "users", 2076 keyvalues={"name": user_id, "is_guest": 1}, 2077 retcols=("name",), 2078 allow_none=False, 2079 ) 2080 2081 self.db_pool.simple_update_one_txn( 2082 txn, 2083 "users", 2084 keyvalues={"name": user_id, "is_guest": 1}, 2085 updatevalues={ 2086 "password_hash": password_hash, 2087 "upgrade_ts": now, 2088 "is_guest": 1 if make_guest else 0, 2089 "appservice_id": appservice_id, 2090 "admin": 1 if admin else 0, 2091 "user_type": user_type, 2092 "shadow_banned": shadow_banned, 2093 }, 2094 ) 2095 else: 2096 self.db_pool.simple_insert_txn( 2097 txn, 2098 "users", 2099 values={ 2100 "name": user_id, 2101 "password_hash": password_hash, 2102 "creation_ts": now, 2103 "is_guest": 1 if make_guest else 0, 2104 "appservice_id": appservice_id, 2105 "admin": 1 if admin else 0, 2106 "user_type": user_type, 2107 "shadow_banned": shadow_banned, 2108 }, 2109 ) 2110 2111 except self.database_engine.module.IntegrityError: 2112 raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) 2113 2114 if self._account_validity_enabled: 2115 self.set_expiration_date_for_user_txn(txn, user_id) 2116 2117 if create_profile_with_displayname: 2118 # set a default displayname serverside to avoid ugly race 2119 # between auto-joins and clients trying to set displaynames 2120 # 2121 # *obviously* the 'profiles' table uses localpart for user_id 2122 # while everything else uses the full mxid. 2123 txn.execute( 2124 "INSERT INTO profiles(user_id, displayname) VALUES (?,?)", 2125 (user_id_obj.localpart, create_profile_with_displayname), 2126 ) 2127 2128 if self.hs.config.stats.stats_enabled: 2129 # we create a new completed user statistics row 2130 2131 # we don't strictly need current_token since this user really can't 2132 # have any state deltas before now (as it is a new user), but still, 2133 # we include it for completeness. 2134 current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn) 2135 self._update_stats_delta_txn( 2136 txn, now, "user", user_id, {}, complete_with_stream_id=current_token 2137 ) 2138 2139 self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) 2140 2141 async def user_set_password_hash( 2142 self, user_id: str, password_hash: Optional[str] 2143 ) -> None: 2144 """ 2145 NB. This does *not* evict any cache because the one use for this 2146 removes most of the entries subsequently anyway so it would be 2147 pointless. Use flush_user separately. 2148 """ 2149 2150 def user_set_password_hash_txn(txn): 2151 self.db_pool.simple_update_one_txn( 2152 txn, "users", {"name": user_id}, {"password_hash": password_hash} 2153 ) 2154 self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) 2155 2156 await self.db_pool.runInteraction( 2157 "user_set_password_hash", user_set_password_hash_txn 2158 ) 2159 2160 async def user_set_consent_version( 2161 self, user_id: str, consent_version: str 2162 ) -> None: 2163 """Updates the user table to record privacy policy consent 2164 2165 Args: 2166 user_id: full mxid of the user to update 2167 consent_version: version of the policy the user has consented to 2168 2169 Raises: 2170 StoreError(404) if user not found 2171 """ 2172 2173 def f(txn): 2174 self.db_pool.simple_update_one_txn( 2175 txn, 2176 table="users", 2177 keyvalues={"name": user_id}, 2178 updatevalues={"consent_version": consent_version}, 2179 ) 2180 self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) 2181 2182 await self.db_pool.runInteraction("user_set_consent_version", f) 2183 2184 async def user_set_consent_server_notice_sent( 2185 self, user_id: str, consent_version: str 2186 ) -> None: 2187 """Updates the user table to record that we have sent the user a server 2188 notice about privacy policy consent 2189 2190 Args: 2191 user_id: full mxid of the user to update 2192 consent_version: version of the policy we have notified the user about 2193 2194 Raises: 2195 StoreError(404) if user not found 2196 """ 2197 2198 def f(txn): 2199 self.db_pool.simple_update_one_txn( 2200 txn, 2201 table="users", 2202 keyvalues={"name": user_id}, 2203 updatevalues={"consent_server_notice_sent": consent_version}, 2204 ) 2205 self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) 2206 2207 await self.db_pool.runInteraction("user_set_consent_server_notice_sent", f) 2208 2209 async def user_delete_access_tokens( 2210 self, 2211 user_id: str, 2212 except_token_id: Optional[int] = None, 2213 device_id: Optional[str] = None, 2214 ) -> List[Tuple[str, int, Optional[str]]]: 2215 """ 2216 Invalidate access and refresh tokens belonging to a user 2217 2218 Args: 2219 user_id: ID of user the tokens belong to 2220 except_token_id: access_tokens ID which should *not* be deleted 2221 device_id: ID of device the tokens are associated with. 2222 If None, tokens associated with any device (or no device) will 2223 be deleted 2224 Returns: 2225 A tuple of (token, token id, device id) for each of the deleted tokens 2226 """ 2227 2228 def f(txn): 2229 keyvalues = {"user_id": user_id} 2230 if device_id is not None: 2231 keyvalues["device_id"] = device_id 2232 2233 items = keyvalues.items() 2234 where_clause = " AND ".join(k + " = ?" for k, _ in items) 2235 values: List[Union[str, int]] = [v for _, v in items] 2236 # Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat 2237 # is the `except_token_id` param that is tricky to get right, so for now we're just using the same where 2238 # clause and values before we handle that. This seems to be only used in the "set password" handler. 2239 refresh_where_clause = where_clause 2240 refresh_values = values.copy() 2241 if except_token_id: 2242 # TODO: support that for refresh tokens 2243 where_clause += " AND id != ?" 2244 values.append(except_token_id) 2245 2246 txn.execute( 2247 "SELECT token, id, device_id FROM access_tokens WHERE %s" 2248 % where_clause, 2249 values, 2250 ) 2251 tokens_and_devices = [(r[0], r[1], r[2]) for r in txn] 2252 2253 for token, _, _ in tokens_and_devices: 2254 self._invalidate_cache_and_stream( 2255 txn, self.get_user_by_access_token, (token,) 2256 ) 2257 2258 txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values) 2259 2260 txn.execute( 2261 "DELETE FROM refresh_tokens WHERE %s" % refresh_where_clause, 2262 refresh_values, 2263 ) 2264 2265 return tokens_and_devices 2266 2267 return await self.db_pool.runInteraction("user_delete_access_tokens", f) 2268 2269 async def delete_access_token(self, access_token: str) -> None: 2270 def f(txn): 2271 self.db_pool.simple_delete_one_txn( 2272 txn, table="access_tokens", keyvalues={"token": access_token} 2273 ) 2274 2275 self._invalidate_cache_and_stream( 2276 txn, self.get_user_by_access_token, (access_token,) 2277 ) 2278 2279 await self.db_pool.runInteraction("delete_access_token", f) 2280 2281 async def delete_refresh_token(self, refresh_token: str) -> None: 2282 def f(txn): 2283 self.db_pool.simple_delete_one_txn( 2284 txn, table="refresh_tokens", keyvalues={"token": refresh_token} 2285 ) 2286 2287 await self.db_pool.runInteraction("delete_refresh_token", f) 2288 2289 async def add_user_pending_deactivation(self, user_id: str) -> None: 2290 """ 2291 Adds a user to the table of users who need to be parted from all the rooms they're 2292 in 2293 """ 2294 await self.db_pool.simple_insert( 2295 "users_pending_deactivation", 2296 values={"user_id": user_id}, 2297 desc="add_user_pending_deactivation", 2298 ) 2299 2300 async def validate_threepid_session( 2301 self, session_id: str, client_secret: str, token: str, current_ts: int 2302 ) -> Optional[str]: 2303 """Attempt to validate a threepid session using a token 2304 2305 Args: 2306 session_id: The id of a validation session 2307 client_secret: A unique string provided by the client to help identify 2308 this validation attempt 2309 token: A validation token 2310 current_ts: The current unix time in milliseconds. Used for checking 2311 token expiry status 2312 2313 Raises: 2314 ThreepidValidationError: if a matching validation token was not found or has 2315 expired 2316 2317 Returns: 2318 A str representing a link to redirect the user to if there is one. 2319 """ 2320 2321 # Insert everything into a transaction in order to run atomically 2322 def validate_threepid_session_txn(txn): 2323 row = self.db_pool.simple_select_one_txn( 2324 txn, 2325 table="threepid_validation_session", 2326 keyvalues={"session_id": session_id}, 2327 retcols=["client_secret", "validated_at"], 2328 allow_none=True, 2329 ) 2330 2331 if not row: 2332 if self._ignore_unknown_session_error: 2333 # If we need to inhibit the error caused by an incorrect session ID, 2334 # use None as placeholder values for the client secret and the 2335 # validation timestamp. 2336 # It shouldn't be an issue because they're both only checked after 2337 # the token check, which should fail. And if it doesn't for some 2338 # reason, the next check is on the client secret, which is NOT NULL, 2339 # so we don't have to worry about the client secret matching by 2340 # accident. 2341 row = {"client_secret": None, "validated_at": None} 2342 else: 2343 raise ThreepidValidationError("Unknown session_id") 2344 2345 retrieved_client_secret = row["client_secret"] 2346 validated_at = row["validated_at"] 2347 2348 row = self.db_pool.simple_select_one_txn( 2349 txn, 2350 table="threepid_validation_token", 2351 keyvalues={"session_id": session_id, "token": token}, 2352 retcols=["expires", "next_link"], 2353 allow_none=True, 2354 ) 2355 2356 if not row: 2357 raise ThreepidValidationError( 2358 "Validation token not found or has expired" 2359 ) 2360 expires = row["expires"] 2361 next_link = row["next_link"] 2362 2363 if retrieved_client_secret != client_secret: 2364 raise ThreepidValidationError( 2365 "This client_secret does not match the provided session_id" 2366 ) 2367 2368 # If the session is already validated, no need to revalidate 2369 if validated_at: 2370 return next_link 2371 2372 if expires <= current_ts: 2373 raise ThreepidValidationError( 2374 "This token has expired. Please request a new one" 2375 ) 2376 2377 # Looks good. Validate the session 2378 self.db_pool.simple_update_txn( 2379 txn, 2380 table="threepid_validation_session", 2381 keyvalues={"session_id": session_id}, 2382 updatevalues={"validated_at": self._clock.time_msec()}, 2383 ) 2384 2385 return next_link 2386 2387 # Return next_link if it exists 2388 return await self.db_pool.runInteraction( 2389 "validate_threepid_session_txn", validate_threepid_session_txn 2390 ) 2391 2392 async def start_or_continue_validation_session( 2393 self, 2394 medium: str, 2395 address: str, 2396 session_id: str, 2397 client_secret: str, 2398 send_attempt: int, 2399 next_link: Optional[str], 2400 token: str, 2401 token_expires: int, 2402 ) -> None: 2403 """Creates a new threepid validation session if it does not already 2404 exist and associates a new validation token with it 2405 2406 Args: 2407 medium: The medium of the 3PID 2408 address: The address of the 3PID 2409 session_id: The id of this validation session 2410 client_secret: A unique string provided by the client to help 2411 identify this validation attempt 2412 send_attempt: The latest send_attempt on this session 2413 next_link: The link to redirect the user to upon successful validation 2414 token: The validation token 2415 token_expires: The timestamp for which after the token will no 2416 longer be valid 2417 """ 2418 2419 def start_or_continue_validation_session_txn(txn): 2420 # Create or update a validation session 2421 self.db_pool.simple_upsert_txn( 2422 txn, 2423 table="threepid_validation_session", 2424 keyvalues={"session_id": session_id}, 2425 values={"last_send_attempt": send_attempt}, 2426 insertion_values={ 2427 "medium": medium, 2428 "address": address, 2429 "client_secret": client_secret, 2430 }, 2431 ) 2432 2433 # Create a new validation token with this session ID 2434 self.db_pool.simple_insert_txn( 2435 txn, 2436 table="threepid_validation_token", 2437 values={ 2438 "session_id": session_id, 2439 "token": token, 2440 "next_link": next_link, 2441 "expires": token_expires, 2442 }, 2443 ) 2444 2445 await self.db_pool.runInteraction( 2446 "start_or_continue_validation_session", 2447 start_or_continue_validation_session_txn, 2448 ) 2449 2450 2451def find_max_generated_user_id_localpart(cur: Cursor) -> int: 2452 """ 2453 Gets the localpart of the max current generated user ID. 2454 2455 Generated user IDs are integers, so we find the largest integer user ID 2456 already taken and return that. 2457 """ 2458 2459 # We bound between '@0' and '@a' to avoid pulling the entire table 2460 # out. 2461 cur.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'") 2462 2463 regex = re.compile(r"^@(\d+):") 2464 2465 max_found = 0 2466 2467 for (user_id,) in cur: 2468 match = regex.search(user_id) 2469 if match: 2470 max_found = max(int(match.group(1)), max_found) 2471 return max_found 2472