1# Copyright 2014-2016 OpenMarket Ltd 2# Copyright 2019 The Matrix.org Foundation C.I.C. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15 16import logging 17from abc import abstractmethod 18from enum import Enum 19from typing import ( 20 TYPE_CHECKING, 21 Any, 22 Awaitable, 23 Dict, 24 List, 25 Optional, 26 Tuple, 27 Union, 28 cast, 29) 30 31import attr 32 33from synapse.api.constants import EventContentFields, EventTypes, JoinRules 34from synapse.api.errors import StoreError 35from synapse.api.room_versions import RoomVersion, RoomVersions 36from synapse.events import EventBase 37from synapse.storage._base import SQLBaseStore, db_to_json 38from synapse.storage.database import ( 39 DatabasePool, 40 LoggingDatabaseConnection, 41 LoggingTransaction, 42) 43from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore 44from synapse.storage.types import Cursor 45from synapse.storage.util.id_generators import IdGenerator 46from synapse.types import JsonDict, ThirdPartyInstanceID 47from synapse.util import json_encoder 48from synapse.util.caches.descriptors import cached 49from synapse.util.stringutils import MXC_REGEX 50 51if TYPE_CHECKING: 52 from synapse.server import HomeServer 53 54logger = logging.getLogger(__name__) 55 56 57@attr.s(slots=True, frozen=True, auto_attribs=True) 58class RatelimitOverride: 59 messages_per_second: int 60 burst_count: int 61 62 63class RoomSortOrder(Enum): 64 """ 65 Enum to define the sorting method used when returning rooms with get_rooms_paginate 66 67 NAME = sort rooms alphabetically by name 68 JOINED_MEMBERS = sort rooms by membership size, highest to lowest 69 """ 70 71 # ALPHABETICAL and SIZE are deprecated. 72 # ALPHABETICAL is the same as NAME. 73 ALPHABETICAL = "alphabetical" 74 # SIZE is the same as JOINED_MEMBERS. 75 SIZE = "size" 76 NAME = "name" 77 CANONICAL_ALIAS = "canonical_alias" 78 JOINED_MEMBERS = "joined_members" 79 JOINED_LOCAL_MEMBERS = "joined_local_members" 80 VERSION = "version" 81 CREATOR = "creator" 82 ENCRYPTION = "encryption" 83 FEDERATABLE = "federatable" 84 PUBLIC = "public" 85 JOIN_RULES = "join_rules" 86 GUEST_ACCESS = "guest_access" 87 HISTORY_VISIBILITY = "history_visibility" 88 STATE_EVENTS = "state_events" 89 90 91class RoomWorkerStore(CacheInvalidationWorkerStore): 92 def __init__( 93 self, 94 database: DatabasePool, 95 db_conn: LoggingDatabaseConnection, 96 hs: "HomeServer", 97 ): 98 super().__init__(database, db_conn, hs) 99 100 self.config = hs.config 101 102 async def store_room( 103 self, 104 room_id: str, 105 room_creator_user_id: str, 106 is_public: bool, 107 room_version: RoomVersion, 108 ) -> None: 109 """Stores a room. 110 111 Args: 112 room_id: The desired room ID, can be None. 113 room_creator_user_id: The user ID of the room creator. 114 is_public: True to indicate that this room should appear in 115 public room lists. 116 room_version: The version of the room 117 Raises: 118 StoreError if the room could not be stored. 119 """ 120 try: 121 await self.db_pool.simple_insert( 122 "rooms", 123 { 124 "room_id": room_id, 125 "creator": room_creator_user_id, 126 "is_public": is_public, 127 "room_version": room_version.identifier, 128 "has_auth_chain_index": True, 129 }, 130 desc="store_room", 131 ) 132 except Exception as e: 133 logger.error("store_room with room_id=%s failed: %s", room_id, e) 134 raise StoreError(500, "Problem creating room.") 135 136 async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]: 137 """Retrieve a room. 138 139 Args: 140 room_id: The ID of the room to retrieve. 141 Returns: 142 A dict containing the room information, or None if the room is unknown. 143 """ 144 return await self.db_pool.simple_select_one( 145 table="rooms", 146 keyvalues={"room_id": room_id}, 147 retcols=("room_id", "is_public", "creator", "has_auth_chain_index"), 148 desc="get_room", 149 allow_none=True, 150 ) 151 152 async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]: 153 """Retrieve room with statistics. 154 155 Args: 156 room_id: The ID of the room to retrieve. 157 Returns: 158 A dict containing the room information, or None if the room is unknown. 159 """ 160 161 def get_room_with_stats_txn( 162 txn: LoggingTransaction, room_id: str 163 ) -> Optional[Dict[str, Any]]: 164 sql = """ 165 SELECT room_id, state.name, state.canonical_alias, curr.joined_members, 166 curr.local_users_in_room AS joined_local_members, rooms.room_version AS version, 167 rooms.creator, state.encryption, state.is_federatable AS federatable, 168 rooms.is_public AS public, state.join_rules, state.guest_access, 169 state.history_visibility, curr.current_state_events AS state_events, 170 state.avatar, state.topic 171 FROM rooms 172 LEFT JOIN room_stats_state state USING (room_id) 173 LEFT JOIN room_stats_current curr USING (room_id) 174 WHERE room_id = ? 175 """ 176 txn.execute(sql, [room_id]) 177 # Catch error if sql returns empty result to return "None" instead of an error 178 try: 179 res = self.db_pool.cursor_to_dict(txn)[0] 180 except IndexError: 181 return None 182 183 res["federatable"] = bool(res["federatable"]) 184 res["public"] = bool(res["public"]) 185 return res 186 187 return await self.db_pool.runInteraction( 188 "get_room_with_stats", get_room_with_stats_txn, room_id 189 ) 190 191 async def get_public_room_ids(self) -> List[str]: 192 return await self.db_pool.simple_select_onecol( 193 table="rooms", 194 keyvalues={"is_public": True}, 195 retcol="room_id", 196 desc="get_public_room_ids", 197 ) 198 199 async def count_public_rooms( 200 self, 201 network_tuple: Optional[ThirdPartyInstanceID], 202 ignore_non_federatable: bool, 203 ) -> int: 204 """Counts the number of public rooms as tracked in the room_stats_current 205 and room_stats_state table. 206 207 Args: 208 network_tuple 209 ignore_non_federatable: If true filters out non-federatable rooms 210 """ 211 212 def _count_public_rooms_txn(txn: LoggingTransaction) -> int: 213 query_args = [] 214 215 if network_tuple: 216 if network_tuple.appservice_id: 217 published_sql = """ 218 SELECT room_id from appservice_room_list 219 WHERE appservice_id = ? AND network_id = ? 220 """ 221 query_args.append(network_tuple.appservice_id) 222 assert network_tuple.network_id is not None 223 query_args.append(network_tuple.network_id) 224 else: 225 published_sql = """ 226 SELECT room_id FROM rooms WHERE is_public 227 """ 228 else: 229 published_sql = """ 230 SELECT room_id FROM rooms WHERE is_public 231 UNION SELECT room_id from appservice_room_list 232 """ 233 234 sql = """ 235 SELECT 236 COUNT(*) 237 FROM ( 238 %(published_sql)s 239 ) published 240 INNER JOIN room_stats_state USING (room_id) 241 INNER JOIN room_stats_current USING (room_id) 242 WHERE 243 ( 244 join_rules = 'public' OR join_rules = '%(knock_join_rule)s' 245 OR history_visibility = 'world_readable' 246 ) 247 AND joined_members > 0 248 """ % { 249 "published_sql": published_sql, 250 "knock_join_rule": JoinRules.KNOCK, 251 } 252 253 txn.execute(sql, query_args) 254 return cast(Tuple[int], txn.fetchone())[0] 255 256 return await self.db_pool.runInteraction( 257 "count_public_rooms", _count_public_rooms_txn 258 ) 259 260 async def get_room_count(self) -> int: 261 """Retrieve the total number of rooms.""" 262 263 def f(txn: LoggingTransaction) -> int: 264 sql = "SELECT count(*) FROM rooms" 265 txn.execute(sql) 266 row = cast(Tuple[int], txn.fetchone()) 267 return row[0] 268 269 return await self.db_pool.runInteraction("get_rooms", f) 270 271 async def get_largest_public_rooms( 272 self, 273 network_tuple: Optional[ThirdPartyInstanceID], 274 search_filter: Optional[dict], 275 limit: Optional[int], 276 bounds: Optional[Tuple[int, str]], 277 forwards: bool, 278 ignore_non_federatable: bool = False, 279 ) -> List[Dict[str, Any]]: 280 """Gets the largest public rooms (where largest is in terms of joined 281 members, as tracked in the statistics table). 282 283 Args: 284 network_tuple 285 search_filter 286 limit: Maxmimum number of rows to return, unlimited otherwise. 287 bounds: An uppoer or lower bound to apply to result set if given, 288 consists of a joined member count and room_id (these are 289 excluded from result set). 290 forwards: true iff going forwards, going backwards otherwise 291 ignore_non_federatable: If true filters out non-federatable rooms. 292 293 Returns: 294 Rooms in order: biggest number of joined users first. 295 We then arbitrarily use the room_id as a tie breaker. 296 297 """ 298 299 where_clauses = [] 300 query_args: List[Union[str, int]] = [] 301 302 if network_tuple: 303 if network_tuple.appservice_id: 304 published_sql = """ 305 SELECT room_id from appservice_room_list 306 WHERE appservice_id = ? AND network_id = ? 307 """ 308 query_args.append(network_tuple.appservice_id) 309 assert network_tuple.network_id is not None 310 query_args.append(network_tuple.network_id) 311 else: 312 published_sql = """ 313 SELECT room_id FROM rooms WHERE is_public 314 """ 315 else: 316 published_sql = """ 317 SELECT room_id FROM rooms WHERE is_public 318 UNION SELECT room_id from appservice_room_list 319 """ 320 321 # Work out the bounds if we're given them, these bounds look slightly 322 # odd, but are designed to help query planner use indices by pulling 323 # out a common bound. 324 if bounds: 325 last_joined_members, last_room_id = bounds 326 if forwards: 327 where_clauses.append( 328 """ 329 joined_members <= ? AND ( 330 joined_members < ? OR room_id < ? 331 ) 332 """ 333 ) 334 else: 335 where_clauses.append( 336 """ 337 joined_members >= ? AND ( 338 joined_members > ? OR room_id > ? 339 ) 340 """ 341 ) 342 343 query_args += [last_joined_members, last_joined_members, last_room_id] 344 345 if ignore_non_federatable: 346 where_clauses.append("is_federatable") 347 348 if search_filter and search_filter.get("generic_search_term", None): 349 search_term = "%" + search_filter["generic_search_term"] + "%" 350 351 where_clauses.append( 352 """ 353 ( 354 LOWER(name) LIKE ? 355 OR LOWER(topic) LIKE ? 356 OR LOWER(canonical_alias) LIKE ? 357 ) 358 """ 359 ) 360 query_args += [ 361 search_term.lower(), 362 search_term.lower(), 363 search_term.lower(), 364 ] 365 366 where_clause = "" 367 if where_clauses: 368 where_clause = " AND " + " AND ".join(where_clauses) 369 370 sql = """ 371 SELECT 372 room_id, name, topic, canonical_alias, joined_members, 373 avatar, history_visibility, guest_access, join_rules 374 FROM ( 375 %(published_sql)s 376 ) published 377 INNER JOIN room_stats_state USING (room_id) 378 INNER JOIN room_stats_current USING (room_id) 379 WHERE 380 ( 381 join_rules = 'public' OR join_rules = '%(knock_join_rule)s' 382 OR history_visibility = 'world_readable' 383 ) 384 AND joined_members > 0 385 %(where_clause)s 386 ORDER BY joined_members %(dir)s, room_id %(dir)s 387 """ % { 388 "published_sql": published_sql, 389 "where_clause": where_clause, 390 "dir": "DESC" if forwards else "ASC", 391 "knock_join_rule": JoinRules.KNOCK, 392 } 393 394 if limit is not None: 395 query_args.append(limit) 396 397 sql += """ 398 LIMIT ? 399 """ 400 401 def _get_largest_public_rooms_txn( 402 txn: LoggingTransaction, 403 ) -> List[Dict[str, Any]]: 404 txn.execute(sql, query_args) 405 406 results = self.db_pool.cursor_to_dict(txn) 407 408 if not forwards: 409 results.reverse() 410 411 return results 412 413 ret_val = await self.db_pool.runInteraction( 414 "get_largest_public_rooms", _get_largest_public_rooms_txn 415 ) 416 return ret_val 417 418 @cached(max_entries=10000) 419 async def is_room_blocked(self, room_id: str) -> Optional[bool]: 420 return await self.db_pool.simple_select_one_onecol( 421 table="blocked_rooms", 422 keyvalues={"room_id": room_id}, 423 retcol="1", 424 allow_none=True, 425 desc="is_room_blocked", 426 ) 427 428 async def room_is_blocked_by(self, room_id: str) -> Optional[str]: 429 """ 430 Function to retrieve user who has blocked the room. 431 user_id is non-nullable 432 It returns None if the room is not blocked. 433 """ 434 return await self.db_pool.simple_select_one_onecol( 435 table="blocked_rooms", 436 keyvalues={"room_id": room_id}, 437 retcol="user_id", 438 allow_none=True, 439 desc="room_is_blocked_by", 440 ) 441 442 async def get_rooms_paginate( 443 self, 444 start: int, 445 limit: int, 446 order_by: str, 447 reverse_order: bool, 448 search_term: Optional[str], 449 ) -> Tuple[List[Dict[str, Any]], int]: 450 """Function to retrieve a paginated list of rooms as json. 451 452 Args: 453 start: offset in the list 454 limit: maximum amount of rooms to retrieve 455 order_by: the sort order of the returned list 456 reverse_order: whether to reverse the room list 457 search_term: a string to filter room names, 458 canonical alias and room ids by. 459 Room ID must match exactly. Canonical alias must match a substring of the local part. 460 Returns: 461 A list of room dicts and an integer representing the total number of 462 rooms that exist given this query 463 """ 464 # Filter room names by a string 465 where_statement = "" 466 search_pattern: List[object] = [] 467 if search_term: 468 where_statement = """ 469 WHERE LOWER(state.name) LIKE ? 470 OR LOWER(state.canonical_alias) LIKE ? 471 OR state.room_id = ? 472 """ 473 474 # Our postgres db driver converts ? -> %s in SQL strings as that's the 475 # placeholder for postgres. 476 # HOWEVER, if you put a % into your SQL then everything goes wibbly. 477 # To get around this, we're going to surround search_term with %'s 478 # before giving it to the database in python instead 479 search_pattern = [ 480 "%" + search_term.lower() + "%", 481 "#%" + search_term.lower() + "%:%", 482 search_term, 483 ] 484 485 # Set ordering 486 if RoomSortOrder(order_by) == RoomSortOrder.SIZE: 487 # Deprecated in favour of RoomSortOrder.JOINED_MEMBERS 488 order_by_column = "curr.joined_members" 489 order_by_asc = False 490 elif RoomSortOrder(order_by) == RoomSortOrder.ALPHABETICAL: 491 # Deprecated in favour of RoomSortOrder.NAME 492 order_by_column = "state.name" 493 order_by_asc = True 494 elif RoomSortOrder(order_by) == RoomSortOrder.NAME: 495 order_by_column = "state.name" 496 order_by_asc = True 497 elif RoomSortOrder(order_by) == RoomSortOrder.CANONICAL_ALIAS: 498 order_by_column = "state.canonical_alias" 499 order_by_asc = True 500 elif RoomSortOrder(order_by) == RoomSortOrder.JOINED_MEMBERS: 501 order_by_column = "curr.joined_members" 502 order_by_asc = False 503 elif RoomSortOrder(order_by) == RoomSortOrder.JOINED_LOCAL_MEMBERS: 504 order_by_column = "curr.local_users_in_room" 505 order_by_asc = False 506 elif RoomSortOrder(order_by) == RoomSortOrder.VERSION: 507 order_by_column = "rooms.room_version" 508 order_by_asc = False 509 elif RoomSortOrder(order_by) == RoomSortOrder.CREATOR: 510 order_by_column = "rooms.creator" 511 order_by_asc = True 512 elif RoomSortOrder(order_by) == RoomSortOrder.ENCRYPTION: 513 order_by_column = "state.encryption" 514 order_by_asc = True 515 elif RoomSortOrder(order_by) == RoomSortOrder.FEDERATABLE: 516 order_by_column = "state.is_federatable" 517 order_by_asc = True 518 elif RoomSortOrder(order_by) == RoomSortOrder.PUBLIC: 519 order_by_column = "rooms.is_public" 520 order_by_asc = True 521 elif RoomSortOrder(order_by) == RoomSortOrder.JOIN_RULES: 522 order_by_column = "state.join_rules" 523 order_by_asc = True 524 elif RoomSortOrder(order_by) == RoomSortOrder.GUEST_ACCESS: 525 order_by_column = "state.guest_access" 526 order_by_asc = True 527 elif RoomSortOrder(order_by) == RoomSortOrder.HISTORY_VISIBILITY: 528 order_by_column = "state.history_visibility" 529 order_by_asc = True 530 elif RoomSortOrder(order_by) == RoomSortOrder.STATE_EVENTS: 531 order_by_column = "curr.current_state_events" 532 order_by_asc = False 533 else: 534 raise StoreError( 535 500, "Incorrect value for order_by provided: %s" % order_by 536 ) 537 538 # Whether to return the list in reverse order 539 if reverse_order: 540 # Flip the boolean 541 order_by_asc = not order_by_asc 542 543 # Create one query for getting the limited number of events that the user asked 544 # for, and another query for getting the total number of events that could be 545 # returned. Thus allowing us to see if there are more events to paginate through 546 info_sql = """ 547 SELECT state.room_id, state.name, state.canonical_alias, curr.joined_members, 548 curr.local_users_in_room, rooms.room_version, rooms.creator, 549 state.encryption, state.is_federatable, rooms.is_public, state.join_rules, 550 state.guest_access, state.history_visibility, curr.current_state_events 551 FROM room_stats_state state 552 INNER JOIN room_stats_current curr USING (room_id) 553 INNER JOIN rooms USING (room_id) 554 %s 555 ORDER BY %s %s 556 LIMIT ? 557 OFFSET ? 558 """ % ( 559 where_statement, 560 order_by_column, 561 "ASC" if order_by_asc else "DESC", 562 ) 563 564 # Use a nested SELECT statement as SQL can't count(*) with an OFFSET 565 count_sql = """ 566 SELECT count(*) FROM ( 567 SELECT room_id FROM room_stats_state state 568 %s 569 ) AS get_room_ids 570 """ % ( 571 where_statement, 572 ) 573 574 def _get_rooms_paginate_txn( 575 txn: LoggingTransaction, 576 ) -> Tuple[List[Dict[str, Any]], int]: 577 # Add the search term into the WHERE clause 578 # and execute the data query 579 txn.execute(info_sql, search_pattern + [limit, start]) 580 581 # Refactor room query data into a structured dictionary 582 rooms = [] 583 for room in txn: 584 rooms.append( 585 { 586 "room_id": room[0], 587 "name": room[1], 588 "canonical_alias": room[2], 589 "joined_members": room[3], 590 "joined_local_members": room[4], 591 "version": room[5], 592 "creator": room[6], 593 "encryption": room[7], 594 "federatable": room[8], 595 "public": room[9], 596 "join_rules": room[10], 597 "guest_access": room[11], 598 "history_visibility": room[12], 599 "state_events": room[13], 600 } 601 ) 602 603 # Execute the count query 604 605 # Add the search term into the WHERE clause if present 606 txn.execute(count_sql, search_pattern) 607 608 room_count = cast(Tuple[int], txn.fetchone()) 609 return rooms, room_count[0] 610 611 return await self.db_pool.runInteraction( 612 "get_rooms_paginate", 613 _get_rooms_paginate_txn, 614 ) 615 616 @cached(max_entries=10000) 617 async def get_ratelimit_for_user(self, user_id: str) -> Optional[RatelimitOverride]: 618 """Check if there are any overrides for ratelimiting for the given user 619 620 Args: 621 user_id: user ID of the user 622 Returns: 623 RatelimitOverride if there is an override, else None. If the contents 624 of RatelimitOverride are None or 0 then ratelimitng has been 625 disabled for that user entirely. 626 """ 627 row = await self.db_pool.simple_select_one( 628 table="ratelimit_override", 629 keyvalues={"user_id": user_id}, 630 retcols=("messages_per_second", "burst_count"), 631 allow_none=True, 632 desc="get_ratelimit_for_user", 633 ) 634 635 if row: 636 return RatelimitOverride( 637 messages_per_second=row["messages_per_second"], 638 burst_count=row["burst_count"], 639 ) 640 else: 641 return None 642 643 async def set_ratelimit_for_user( 644 self, user_id: str, messages_per_second: int, burst_count: int 645 ) -> None: 646 """Sets whether a user is set an overridden ratelimit. 647 Args: 648 user_id: user ID of the user 649 messages_per_second: The number of actions that can be performed in a second. 650 burst_count: How many actions that can be performed before being limited. 651 """ 652 653 def set_ratelimit_txn(txn: LoggingTransaction) -> None: 654 self.db_pool.simple_upsert_txn( 655 txn, 656 table="ratelimit_override", 657 keyvalues={"user_id": user_id}, 658 values={ 659 "messages_per_second": messages_per_second, 660 "burst_count": burst_count, 661 }, 662 ) 663 664 self._invalidate_cache_and_stream( 665 txn, self.get_ratelimit_for_user, (user_id,) 666 ) 667 668 await self.db_pool.runInteraction("set_ratelimit", set_ratelimit_txn) 669 670 async def delete_ratelimit_for_user(self, user_id: str) -> None: 671 """Delete an overridden ratelimit for a user. 672 Args: 673 user_id: user ID of the user 674 """ 675 676 def delete_ratelimit_txn(txn: LoggingTransaction) -> None: 677 row = self.db_pool.simple_select_one_txn( 678 txn, 679 table="ratelimit_override", 680 keyvalues={"user_id": user_id}, 681 retcols=["user_id"], 682 allow_none=True, 683 ) 684 685 if not row: 686 return 687 688 # They are there, delete them. 689 self.db_pool.simple_delete_one_txn( 690 txn, "ratelimit_override", keyvalues={"user_id": user_id} 691 ) 692 693 self._invalidate_cache_and_stream( 694 txn, self.get_ratelimit_for_user, (user_id,) 695 ) 696 697 await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn) 698 699 @cached() 700 async def get_retention_policy_for_room(self, room_id: str) -> Dict[str, int]: 701 """Get the retention policy for a given room. 702 703 If no retention policy has been found for this room, returns a policy defined 704 by the configured default policy (which has None as both the 'min_lifetime' and 705 the 'max_lifetime' if no default policy has been defined in the server's 706 configuration). 707 708 Args: 709 room_id: The ID of the room to get the retention policy of. 710 711 Returns: 712 A dict containing "min_lifetime" and "max_lifetime" for this room. 713 """ 714 715 def get_retention_policy_for_room_txn( 716 txn: LoggingTransaction, 717 ) -> List[Dict[str, Optional[int]]]: 718 txn.execute( 719 """ 720 SELECT min_lifetime, max_lifetime FROM room_retention 721 INNER JOIN current_state_events USING (event_id, room_id) 722 WHERE room_id = ?; 723 """, 724 (room_id,), 725 ) 726 727 return self.db_pool.cursor_to_dict(txn) 728 729 ret = await self.db_pool.runInteraction( 730 "get_retention_policy_for_room", 731 get_retention_policy_for_room_txn, 732 ) 733 734 # If we don't know this room ID, ret will be None, in this case return the default 735 # policy. 736 if not ret: 737 return { 738 "min_lifetime": self.config.retention.retention_default_min_lifetime, 739 "max_lifetime": self.config.retention.retention_default_max_lifetime, 740 } 741 742 min_lifetime = ret[0]["min_lifetime"] 743 max_lifetime = ret[0]["max_lifetime"] 744 745 # If one of the room's policy's attributes isn't defined, use the matching 746 # attribute from the default policy. 747 # The default values will be None if no default policy has been defined, or if one 748 # of the attributes is missing from the default policy. 749 if min_lifetime is None: 750 min_lifetime = self.config.retention.retention_default_min_lifetime 751 752 if max_lifetime is None: 753 max_lifetime = self.config.retention.retention_default_max_lifetime 754 755 return { 756 "min_lifetime": min_lifetime, 757 "max_lifetime": max_lifetime, 758 } 759 760 async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]: 761 """Retrieves all the local and remote media MXC URIs in a given room 762 763 Args: 764 room_id 765 766 Returns: 767 The local and remote media as a lists of the media IDs. 768 """ 769 770 def _get_media_mxcs_in_room_txn( 771 txn: LoggingTransaction, 772 ) -> Tuple[List[str], List[str]]: 773 local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) 774 local_media_mxcs = [] 775 remote_media_mxcs = [] 776 777 # Convert the IDs to MXC URIs 778 for media_id in local_mxcs: 779 local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id)) 780 for hostname, media_id in remote_mxcs: 781 remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id)) 782 783 return local_media_mxcs, remote_media_mxcs 784 785 return await self.db_pool.runInteraction( 786 "get_media_ids_in_room", _get_media_mxcs_in_room_txn 787 ) 788 789 async def quarantine_media_ids_in_room( 790 self, room_id: str, quarantined_by: str 791 ) -> int: 792 """For a room loops through all events with media and quarantines 793 the associated media 794 """ 795 796 logger.info("Quarantining media in room: %s", room_id) 797 798 def _quarantine_media_in_room_txn(txn: LoggingTransaction) -> int: 799 local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) 800 return self._quarantine_media_txn( 801 txn, local_mxcs, remote_mxcs, quarantined_by 802 ) 803 804 return await self.db_pool.runInteraction( 805 "quarantine_media_in_room", _quarantine_media_in_room_txn 806 ) 807 808 def _get_media_mxcs_in_room_txn( 809 self, txn: LoggingTransaction, room_id: str 810 ) -> Tuple[List[str], List[Tuple[str, str]]]: 811 """Retrieves all the local and remote media MXC URIs in a given room 812 813 Returns: 814 The local and remote media as a lists of tuples where the key is 815 the hostname and the value is the media ID. 816 """ 817 sql = """ 818 SELECT stream_ordering, json FROM events 819 JOIN event_json USING (room_id, event_id) 820 WHERE room_id = ? 821 %(where_clause)s 822 AND contains_url = ? AND outlier = ? 823 ORDER BY stream_ordering DESC 824 LIMIT ? 825 """ 826 txn.execute(sql % {"where_clause": ""}, (room_id, True, False, 100)) 827 828 local_media_mxcs = [] 829 remote_media_mxcs = [] 830 831 while True: 832 next_token = None 833 for stream_ordering, content_json in txn: 834 next_token = stream_ordering 835 event_json = db_to_json(content_json) 836 content = event_json["content"] 837 content_url = content.get("url") 838 thumbnail_url = content.get("info", {}).get("thumbnail_url") 839 840 for url in (content_url, thumbnail_url): 841 if not url: 842 continue 843 matches = MXC_REGEX.match(url) 844 if matches: 845 hostname = matches.group(1) 846 media_id = matches.group(2) 847 if hostname == self.hs.hostname: 848 local_media_mxcs.append(media_id) 849 else: 850 remote_media_mxcs.append((hostname, media_id)) 851 852 if next_token is None: 853 # We've gone through the whole room, so we're finished. 854 break 855 856 txn.execute( 857 sql % {"where_clause": "AND stream_ordering < ?"}, 858 (room_id, next_token, True, False, 100), 859 ) 860 861 return local_media_mxcs, remote_media_mxcs 862 863 async def quarantine_media_by_id( 864 self, 865 server_name: str, 866 media_id: str, 867 quarantined_by: Optional[str], 868 ) -> int: 869 """quarantines or unquarantines a single local or remote media id 870 871 Args: 872 server_name: The name of the server that holds this media 873 media_id: The ID of the media to be quarantined 874 quarantined_by: The user ID that initiated the quarantine request 875 If it is `None` media will be removed from quarantine 876 """ 877 logger.info("Quarantining media: %s/%s", server_name, media_id) 878 is_local = server_name == self.config.server.server_name 879 880 def _quarantine_media_by_id_txn(txn: LoggingTransaction) -> int: 881 local_mxcs = [media_id] if is_local else [] 882 remote_mxcs = [(server_name, media_id)] if not is_local else [] 883 884 return self._quarantine_media_txn( 885 txn, local_mxcs, remote_mxcs, quarantined_by 886 ) 887 888 return await self.db_pool.runInteraction( 889 "quarantine_media_by_user", _quarantine_media_by_id_txn 890 ) 891 892 async def quarantine_media_ids_by_user( 893 self, user_id: str, quarantined_by: str 894 ) -> int: 895 """quarantines all local media associated with a single user 896 897 Args: 898 user_id: The ID of the user to quarantine media of 899 quarantined_by: The ID of the user who made the quarantine request 900 """ 901 902 def _quarantine_media_by_user_txn(txn: LoggingTransaction) -> int: 903 local_media_ids = self._get_media_ids_by_user_txn(txn, user_id) 904 return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by) 905 906 return await self.db_pool.runInteraction( 907 "quarantine_media_by_user", _quarantine_media_by_user_txn 908 ) 909 910 def _get_media_ids_by_user_txn( 911 self, txn: LoggingTransaction, user_id: str, filter_quarantined: bool = True 912 ) -> List[str]: 913 """Retrieves local media IDs by a given user 914 915 Args: 916 txn (cursor) 917 user_id: The ID of the user to retrieve media IDs of 918 919 Returns: 920 The local and remote media as a lists of tuples where the key is 921 the hostname and the value is the media ID. 922 """ 923 # Local media 924 sql = """ 925 SELECT media_id 926 FROM local_media_repository 927 WHERE user_id = ? 928 """ 929 if filter_quarantined: 930 sql += "AND quarantined_by IS NULL" 931 txn.execute(sql, (user_id,)) 932 933 local_media_ids = [row[0] for row in txn] 934 935 # TODO: Figure out all remote media a user has referenced in a message 936 937 return local_media_ids 938 939 def _quarantine_media_txn( 940 self, 941 txn: LoggingTransaction, 942 local_mxcs: List[str], 943 remote_mxcs: List[Tuple[str, str]], 944 quarantined_by: Optional[str], 945 ) -> int: 946 """Quarantine and unquarantine local and remote media items 947 948 Args: 949 txn (cursor) 950 local_mxcs: A list of local mxc URLs 951 remote_mxcs: A list of (remote server, media id) tuples representing 952 remote mxc URLs 953 quarantined_by: The ID of the user who initiated the quarantine request 954 If it is `None` media will be removed from quarantine 955 Returns: 956 The total number of media items quarantined 957 """ 958 959 # Update all the tables to set the quarantined_by flag 960 sql = """ 961 UPDATE local_media_repository 962 SET quarantined_by = ? 963 WHERE media_id = ? 964 """ 965 966 # set quarantine 967 if quarantined_by is not None: 968 sql += "AND safe_from_quarantine = ?" 969 txn.executemany( 970 sql, [(quarantined_by, media_id, False) for media_id in local_mxcs] 971 ) 972 # remove from quarantine 973 else: 974 txn.executemany( 975 sql, [(quarantined_by, media_id) for media_id in local_mxcs] 976 ) 977 978 # Note that a rowcount of -1 can be used to indicate no rows were affected. 979 total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0 980 981 txn.executemany( 982 """ 983 UPDATE remote_media_cache 984 SET quarantined_by = ? 985 WHERE media_origin = ? AND media_id = ? 986 """, 987 ((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs), 988 ) 989 total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0 990 991 return total_media_quarantined 992 993 async def get_rooms_for_retention_period_in_range( 994 self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False 995 ) -> Dict[str, Dict[str, Optional[int]]]: 996 """Retrieves all of the rooms within the given retention range. 997 998 Optionally includes the rooms which don't have a retention policy. 999 1000 Args: 1001 min_ms: Duration in milliseconds that define the lower limit of 1002 the range to handle (exclusive). If None, doesn't set a lower limit. 1003 max_ms: Duration in milliseconds that define the upper limit of 1004 the range to handle (inclusive). If None, doesn't set an upper limit. 1005 include_null: Whether to include rooms which retention policy is NULL 1006 in the returned set. 1007 1008 Returns: 1009 The rooms within this range, along with their retention 1010 policy. The key is "room_id", and maps to a dict describing the retention 1011 policy associated with this room ID. The keys for this nested dict are 1012 "min_lifetime" (int|None), and "max_lifetime" (int|None). 1013 """ 1014 1015 def get_rooms_for_retention_period_in_range_txn( 1016 txn: LoggingTransaction, 1017 ) -> Dict[str, Dict[str, Optional[int]]]: 1018 range_conditions = [] 1019 args = [] 1020 1021 if min_ms is not None: 1022 range_conditions.append("max_lifetime > ?") 1023 args.append(min_ms) 1024 1025 if max_ms is not None: 1026 range_conditions.append("max_lifetime <= ?") 1027 args.append(max_ms) 1028 1029 # Do a first query which will retrieve the rooms that have a retention policy 1030 # in their current state. 1031 sql = """ 1032 SELECT room_id, min_lifetime, max_lifetime FROM room_retention 1033 INNER JOIN current_state_events USING (event_id, room_id) 1034 """ 1035 1036 if len(range_conditions): 1037 sql += " WHERE (" + " AND ".join(range_conditions) + ")" 1038 1039 if include_null: 1040 sql += " OR max_lifetime IS NULL" 1041 1042 txn.execute(sql, args) 1043 1044 rows = self.db_pool.cursor_to_dict(txn) 1045 rooms_dict = {} 1046 1047 for row in rows: 1048 rooms_dict[row["room_id"]] = { 1049 "min_lifetime": row["min_lifetime"], 1050 "max_lifetime": row["max_lifetime"], 1051 } 1052 1053 if include_null: 1054 # If required, do a second query that retrieves all of the rooms we know 1055 # of so we can handle rooms with no retention policy. 1056 sql = "SELECT DISTINCT room_id FROM current_state_events" 1057 1058 txn.execute(sql) 1059 1060 rows = self.db_pool.cursor_to_dict(txn) 1061 1062 # If a room isn't already in the dict (i.e. it doesn't have a retention 1063 # policy in its state), add it with a null policy. 1064 for row in rows: 1065 if row["room_id"] not in rooms_dict: 1066 rooms_dict[row["room_id"]] = { 1067 "min_lifetime": None, 1068 "max_lifetime": None, 1069 } 1070 1071 return rooms_dict 1072 1073 return await self.db_pool.runInteraction( 1074 "get_rooms_for_retention_period_in_range", 1075 get_rooms_for_retention_period_in_range_txn, 1076 ) 1077 1078 1079class _BackgroundUpdates: 1080 REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" 1081 ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column" 1082 POPULATE_ROOM_DEPTH_MIN_DEPTH2 = "populate_room_depth_min_depth2" 1083 REPLACE_ROOM_DEPTH_MIN_DEPTH = "replace_room_depth_min_depth" 1084 POPULATE_ROOMS_CREATOR_COLUMN = "populate_rooms_creator_column" 1085 1086 1087_REPLACE_ROOM_DEPTH_SQL_COMMANDS = ( 1088 "DROP TRIGGER populate_min_depth2_trigger ON room_depth", 1089 "DROP FUNCTION populate_min_depth2()", 1090 "ALTER TABLE room_depth DROP COLUMN min_depth", 1091 "ALTER TABLE room_depth RENAME COLUMN min_depth2 TO min_depth", 1092) 1093 1094 1095class RoomBackgroundUpdateStore(SQLBaseStore): 1096 def __init__( 1097 self, 1098 database: DatabasePool, 1099 db_conn: LoggingDatabaseConnection, 1100 hs: "HomeServer", 1101 ): 1102 super().__init__(database, db_conn, hs) 1103 1104 self.db_pool.updates.register_background_update_handler( 1105 "insert_room_retention", 1106 self._background_insert_retention, 1107 ) 1108 1109 self.db_pool.updates.register_background_update_handler( 1110 _BackgroundUpdates.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, 1111 self._remove_tombstoned_rooms_from_directory, 1112 ) 1113 1114 self.db_pool.updates.register_background_update_handler( 1115 _BackgroundUpdates.ADD_ROOMS_ROOM_VERSION_COLUMN, 1116 self._background_add_rooms_room_version_column, 1117 ) 1118 1119 # BG updates to change the type of room_depth.min_depth 1120 self.db_pool.updates.register_background_update_handler( 1121 _BackgroundUpdates.POPULATE_ROOM_DEPTH_MIN_DEPTH2, 1122 self._background_populate_room_depth_min_depth2, 1123 ) 1124 self.db_pool.updates.register_background_update_handler( 1125 _BackgroundUpdates.REPLACE_ROOM_DEPTH_MIN_DEPTH, 1126 self._background_replace_room_depth_min_depth, 1127 ) 1128 1129 self.db_pool.updates.register_background_update_handler( 1130 _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN, 1131 self._background_populate_rooms_creator_column, 1132 ) 1133 1134 async def _background_insert_retention( 1135 self, progress: JsonDict, batch_size: int 1136 ) -> int: 1137 """Retrieves a list of all rooms within a range and inserts an entry for each of 1138 them into the room_retention table. 1139 NULLs the property's columns if missing from the retention event in the room's 1140 state (or NULLs all of them if there's no retention event in the room's state), 1141 so that we fall back to the server's retention policy. 1142 """ 1143 1144 last_room = progress.get("room_id", "") 1145 1146 def _background_insert_retention_txn(txn: LoggingTransaction) -> bool: 1147 txn.execute( 1148 """ 1149 SELECT state.room_id, state.event_id, events.json 1150 FROM current_state_events as state 1151 LEFT JOIN event_json AS events ON (state.event_id = events.event_id) 1152 WHERE state.room_id > ? AND state.type = '%s' 1153 ORDER BY state.room_id ASC 1154 LIMIT ?; 1155 """ 1156 % EventTypes.Retention, 1157 (last_room, batch_size), 1158 ) 1159 1160 rows = self.db_pool.cursor_to_dict(txn) 1161 1162 if not rows: 1163 return True 1164 1165 for row in rows: 1166 if not row["json"]: 1167 retention_policy = {} 1168 else: 1169 ev = db_to_json(row["json"]) 1170 retention_policy = ev["content"] 1171 1172 self.db_pool.simple_insert_txn( 1173 txn=txn, 1174 table="room_retention", 1175 values={ 1176 "room_id": row["room_id"], 1177 "event_id": row["event_id"], 1178 "min_lifetime": retention_policy.get("min_lifetime"), 1179 "max_lifetime": retention_policy.get("max_lifetime"), 1180 }, 1181 ) 1182 1183 logger.info("Inserted %d rows into room_retention", len(rows)) 1184 1185 self.db_pool.updates._background_update_progress_txn( 1186 txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]} 1187 ) 1188 1189 if batch_size > len(rows): 1190 return True 1191 else: 1192 return False 1193 1194 end = await self.db_pool.runInteraction( 1195 "insert_room_retention", 1196 _background_insert_retention_txn, 1197 ) 1198 1199 if end: 1200 await self.db_pool.updates._end_background_update("insert_room_retention") 1201 1202 return batch_size 1203 1204 async def _background_add_rooms_room_version_column( 1205 self, progress: JsonDict, batch_size: int 1206 ) -> int: 1207 """Background update to go and add room version information to `rooms` 1208 table from `current_state_events` table. 1209 """ 1210 1211 last_room_id = progress.get("room_id", "") 1212 1213 def _background_add_rooms_room_version_column_txn( 1214 txn: LoggingTransaction, 1215 ) -> bool: 1216 sql = """ 1217 SELECT room_id, json FROM current_state_events 1218 INNER JOIN event_json USING (room_id, event_id) 1219 WHERE room_id > ? AND type = 'm.room.create' AND state_key = '' 1220 ORDER BY room_id 1221 LIMIT ? 1222 """ 1223 1224 txn.execute(sql, (last_room_id, batch_size)) 1225 1226 updates = [] 1227 for room_id, event_json in txn: 1228 event_dict = db_to_json(event_json) 1229 room_version_id = event_dict.get("content", {}).get( 1230 "room_version", RoomVersions.V1.identifier 1231 ) 1232 1233 creator = event_dict.get("content").get("creator") 1234 1235 updates.append((room_id, creator, room_version_id)) 1236 1237 if not updates: 1238 return True 1239 1240 new_last_room_id = "" 1241 for room_id, creator, room_version_id in updates: 1242 # We upsert here just in case we don't already have a row, 1243 # mainly for paranoia as much badness would happen if we don't 1244 # insert the row and then try and get the room version for the 1245 # room. 1246 self.db_pool.simple_upsert_txn( 1247 txn, 1248 table="rooms", 1249 keyvalues={"room_id": room_id}, 1250 values={"room_version": room_version_id}, 1251 insertion_values={"is_public": False, "creator": creator}, 1252 ) 1253 new_last_room_id = room_id 1254 1255 self.db_pool.updates._background_update_progress_txn( 1256 txn, 1257 _BackgroundUpdates.ADD_ROOMS_ROOM_VERSION_COLUMN, 1258 {"room_id": new_last_room_id}, 1259 ) 1260 1261 return False 1262 1263 end = await self.db_pool.runInteraction( 1264 "_background_add_rooms_room_version_column", 1265 _background_add_rooms_room_version_column_txn, 1266 ) 1267 1268 if end: 1269 await self.db_pool.updates._end_background_update( 1270 _BackgroundUpdates.ADD_ROOMS_ROOM_VERSION_COLUMN 1271 ) 1272 1273 return batch_size 1274 1275 async def _remove_tombstoned_rooms_from_directory( 1276 self, progress: JsonDict, batch_size: int 1277 ) -> int: 1278 """Removes any rooms with tombstone events from the room directory 1279 1280 Nowadays this is handled by the room upgrade handler, but we may have some 1281 that got left behind 1282 """ 1283 1284 last_room = progress.get("room_id", "") 1285 1286 def _get_rooms(txn: LoggingTransaction) -> List[str]: 1287 txn.execute( 1288 """ 1289 SELECT room_id 1290 FROM rooms r 1291 INNER JOIN current_state_events cse USING (room_id) 1292 WHERE room_id > ? AND r.is_public 1293 AND cse.type = '%s' AND cse.state_key = '' 1294 ORDER BY room_id ASC 1295 LIMIT ?; 1296 """ 1297 % EventTypes.Tombstone, 1298 (last_room, batch_size), 1299 ) 1300 1301 return [row[0] for row in txn] 1302 1303 rooms = await self.db_pool.runInteraction( 1304 "get_tombstoned_directory_rooms", _get_rooms 1305 ) 1306 1307 if not rooms: 1308 await self.db_pool.updates._end_background_update( 1309 _BackgroundUpdates.REMOVE_TOMESTONED_ROOMS_BG_UPDATE 1310 ) 1311 return 0 1312 1313 for room_id in rooms: 1314 logger.info("Removing tombstoned room %s from the directory", room_id) 1315 await self.set_room_is_public(room_id, False) 1316 1317 await self.db_pool.updates._background_update_progress( 1318 _BackgroundUpdates.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, {"room_id": rooms[-1]} 1319 ) 1320 1321 return len(rooms) 1322 1323 @abstractmethod 1324 def set_room_is_public(self, room_id: str, is_public: bool) -> Awaitable[None]: 1325 # this will need to be implemented if a background update is performed with 1326 # existing (tombstoned, public) rooms in the database. 1327 # 1328 # It's overridden by RoomStore for the synapse master. 1329 raise NotImplementedError() 1330 1331 async def has_auth_chain_index(self, room_id: str) -> bool: 1332 """Check if the room has (or can have) a chain cover index. 1333 1334 Defaults to True if we don't have an entry in `rooms` table nor any 1335 events for the room. 1336 """ 1337 1338 has_auth_chain_index = await self.db_pool.simple_select_one_onecol( 1339 table="rooms", 1340 keyvalues={"room_id": room_id}, 1341 retcol="has_auth_chain_index", 1342 desc="has_auth_chain_index", 1343 allow_none=True, 1344 ) 1345 1346 if has_auth_chain_index: 1347 return True 1348 1349 # It's possible that we already have events for the room in our DB 1350 # without a corresponding room entry. If we do then we don't want to 1351 # mark the room as having an auth chain cover index. 1352 max_ordering = await self.db_pool.simple_select_one_onecol( 1353 table="events", 1354 keyvalues={"room_id": room_id}, 1355 retcol="MAX(stream_ordering)", 1356 allow_none=True, 1357 desc="has_auth_chain_index_fallback", 1358 ) 1359 1360 return max_ordering is None 1361 1362 async def _background_populate_room_depth_min_depth2( 1363 self, progress: JsonDict, batch_size: int 1364 ) -> int: 1365 """Populate room_depth.min_depth2 1366 1367 This is to deal with the fact that min_depth was initially created as a 1368 32-bit integer field. 1369 """ 1370 1371 def process(txn: LoggingTransaction) -> int: 1372 last_room = progress.get("last_room", "") 1373 txn.execute( 1374 """ 1375 UPDATE room_depth SET min_depth2=min_depth 1376 WHERE room_id IN ( 1377 SELECT room_id FROM room_depth WHERE room_id > ? 1378 ORDER BY room_id LIMIT ? 1379 ) 1380 RETURNING room_id; 1381 """, 1382 (last_room, batch_size), 1383 ) 1384 row_count = txn.rowcount 1385 if row_count == 0: 1386 return 0 1387 last_room = max(row[0] for row in txn) 1388 logger.info("populated room_depth up to %s", last_room) 1389 1390 self.db_pool.updates._background_update_progress_txn( 1391 txn, 1392 _BackgroundUpdates.POPULATE_ROOM_DEPTH_MIN_DEPTH2, 1393 {"last_room": last_room}, 1394 ) 1395 return row_count 1396 1397 result = await self.db_pool.runInteraction( 1398 "_background_populate_min_depth2", process 1399 ) 1400 1401 if result != 0: 1402 return result 1403 1404 await self.db_pool.updates._end_background_update( 1405 _BackgroundUpdates.POPULATE_ROOM_DEPTH_MIN_DEPTH2 1406 ) 1407 return 0 1408 1409 async def _background_replace_room_depth_min_depth( 1410 self, progress: JsonDict, batch_size: int 1411 ) -> int: 1412 """Drop the old 'min_depth' column and rename 'min_depth2' into its place.""" 1413 1414 def process(txn: Cursor) -> None: 1415 for sql in _REPLACE_ROOM_DEPTH_SQL_COMMANDS: 1416 logger.info("completing room_depth migration: %s", sql) 1417 txn.execute(sql) 1418 1419 await self.db_pool.runInteraction("_background_replace_room_depth", process) 1420 1421 await self.db_pool.updates._end_background_update( 1422 _BackgroundUpdates.REPLACE_ROOM_DEPTH_MIN_DEPTH, 1423 ) 1424 1425 return 0 1426 1427 async def _background_populate_rooms_creator_column( 1428 self, progress: JsonDict, batch_size: int 1429 ) -> int: 1430 """Background update to go and add creator information to `rooms` 1431 table from `current_state_events` table. 1432 """ 1433 1434 last_room_id = progress.get("room_id", "") 1435 1436 def _background_populate_rooms_creator_column_txn( 1437 txn: LoggingTransaction, 1438 ) -> bool: 1439 sql = """ 1440 SELECT room_id, json FROM event_json 1441 INNER JOIN rooms AS room USING (room_id) 1442 INNER JOIN current_state_events AS state_event USING (room_id, event_id) 1443 WHERE room_id > ? AND (room.creator IS NULL OR room.creator = '') AND state_event.type = 'm.room.create' AND state_event.state_key = '' 1444 ORDER BY room_id 1445 LIMIT ? 1446 """ 1447 1448 txn.execute(sql, (last_room_id, batch_size)) 1449 room_id_to_create_event_results = txn.fetchall() 1450 1451 new_last_room_id = "" 1452 for room_id, event_json in room_id_to_create_event_results: 1453 event_dict = db_to_json(event_json) 1454 1455 creator = event_dict.get("content").get(EventContentFields.ROOM_CREATOR) 1456 1457 self.db_pool.simple_update_txn( 1458 txn, 1459 table="rooms", 1460 keyvalues={"room_id": room_id}, 1461 updatevalues={"creator": creator}, 1462 ) 1463 new_last_room_id = room_id 1464 1465 if new_last_room_id == "": 1466 return True 1467 1468 self.db_pool.updates._background_update_progress_txn( 1469 txn, 1470 _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN, 1471 {"room_id": new_last_room_id}, 1472 ) 1473 1474 return False 1475 1476 end = await self.db_pool.runInteraction( 1477 "_background_populate_rooms_creator_column", 1478 _background_populate_rooms_creator_column_txn, 1479 ) 1480 1481 if end: 1482 await self.db_pool.updates._end_background_update( 1483 _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN 1484 ) 1485 1486 return batch_size 1487 1488 1489class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): 1490 def __init__( 1491 self, 1492 database: DatabasePool, 1493 db_conn: LoggingDatabaseConnection, 1494 hs: "HomeServer", 1495 ): 1496 super().__init__(database, db_conn, hs) 1497 1498 self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") 1499 1500 async def upsert_room_on_join( 1501 self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase] 1502 ) -> None: 1503 """Ensure that the room is stored in the table 1504 1505 Called when we join a room over federation, and overwrites any room version 1506 currently in the table. 1507 """ 1508 # It's possible that we already have events for the room in our DB 1509 # without a corresponding room entry. If we do then we don't want to 1510 # mark the room as having an auth chain cover index. 1511 has_auth_chain_index = await self.has_auth_chain_index(room_id) 1512 1513 create_event = None 1514 for e in auth_events: 1515 if (e.type, e.state_key) == (EventTypes.Create, ""): 1516 create_event = e 1517 break 1518 1519 if create_event is None: 1520 # If the state doesn't have a create event then the room is 1521 # invalid, and it would fail auth checks anyway. 1522 raise StoreError(400, "No create event in state") 1523 1524 room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR) 1525 1526 if not isinstance(room_creator, str): 1527 # If the create event does not have a creator then the room is 1528 # invalid, and it would fail auth checks anyway. 1529 raise StoreError(400, "No creator defined on the create event") 1530 1531 await self.db_pool.simple_upsert( 1532 desc="upsert_room_on_join", 1533 table="rooms", 1534 keyvalues={"room_id": room_id}, 1535 values={"room_version": room_version.identifier}, 1536 insertion_values={ 1537 "is_public": False, 1538 "creator": room_creator, 1539 "has_auth_chain_index": has_auth_chain_index, 1540 }, 1541 # rooms has a unique constraint on room_id, so no need to lock when doing an 1542 # emulated upsert. 1543 lock=False, 1544 ) 1545 1546 async def maybe_store_room_on_outlier_membership( 1547 self, room_id: str, room_version: RoomVersion 1548 ) -> None: 1549 """ 1550 When we receive an invite or any other event over federation that may relate to a room 1551 we are not in, store the version of the room if we don't already know the room version. 1552 """ 1553 # It's possible that we already have events for the room in our DB 1554 # without a corresponding room entry. If we do then we don't want to 1555 # mark the room as having an auth chain cover index. 1556 has_auth_chain_index = await self.has_auth_chain_index(room_id) 1557 1558 await self.db_pool.simple_upsert( 1559 desc="maybe_store_room_on_outlier_membership", 1560 table="rooms", 1561 keyvalues={"room_id": room_id}, 1562 values={}, 1563 insertion_values={ 1564 "room_version": room_version.identifier, 1565 "is_public": False, 1566 # We don't worry about setting the `creator` here because 1567 # we don't process any messages in a room while a user is 1568 # invited (only after the join). 1569 "creator": "", 1570 "has_auth_chain_index": has_auth_chain_index, 1571 }, 1572 # rooms has a unique constraint on room_id, so no need to lock when doing an 1573 # emulated upsert. 1574 lock=False, 1575 ) 1576 1577 async def set_room_is_public(self, room_id: str, is_public: bool) -> None: 1578 await self.db_pool.simple_update_one( 1579 table="rooms", 1580 keyvalues={"room_id": room_id}, 1581 updatevalues={"is_public": is_public}, 1582 desc="set_room_is_public", 1583 ) 1584 1585 self.hs.get_notifier().on_new_replication_data() 1586 1587 async def set_room_is_public_appservice( 1588 self, room_id: str, appservice_id: str, network_id: str, is_public: bool 1589 ) -> None: 1590 """Edit the appservice/network specific public room list. 1591 1592 Each appservice can have a number of published room lists associated 1593 with them, keyed off of an appservice defined `network_id`, which 1594 basically represents a single instance of a bridge to a third party 1595 network. 1596 1597 Args: 1598 room_id 1599 appservice_id 1600 network_id 1601 is_public: Whether to publish or unpublish the room from the list. 1602 """ 1603 1604 if is_public: 1605 await self.db_pool.simple_upsert( 1606 table="appservice_room_list", 1607 keyvalues={ 1608 "appservice_id": appservice_id, 1609 "network_id": network_id, 1610 "room_id": room_id, 1611 }, 1612 values={}, 1613 insertion_values={ 1614 "appservice_id": appservice_id, 1615 "network_id": network_id, 1616 "room_id": room_id, 1617 }, 1618 desc="set_room_is_public_appservice_true", 1619 ) 1620 else: 1621 await self.db_pool.simple_delete( 1622 table="appservice_room_list", 1623 keyvalues={ 1624 "appservice_id": appservice_id, 1625 "network_id": network_id, 1626 "room_id": room_id, 1627 }, 1628 desc="set_room_is_public_appservice_false", 1629 ) 1630 1631 self.hs.get_notifier().on_new_replication_data() 1632 1633 async def add_event_report( 1634 self, 1635 room_id: str, 1636 event_id: str, 1637 user_id: str, 1638 reason: Optional[str], 1639 content: JsonDict, 1640 received_ts: int, 1641 ) -> None: 1642 next_id = self._event_reports_id_gen.get_next() 1643 await self.db_pool.simple_insert( 1644 table="event_reports", 1645 values={ 1646 "id": next_id, 1647 "received_ts": received_ts, 1648 "room_id": room_id, 1649 "event_id": event_id, 1650 "user_id": user_id, 1651 "reason": reason, 1652 "content": json_encoder.encode(content), 1653 }, 1654 desc="add_event_report", 1655 ) 1656 1657 async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]: 1658 """Retrieve an event report 1659 1660 Args: 1661 report_id: ID of reported event in database 1662 Returns: 1663 event_report: json list of information from event report 1664 """ 1665 1666 def _get_event_report_txn( 1667 txn: LoggingTransaction, report_id: int 1668 ) -> Optional[Dict[str, Any]]: 1669 1670 sql = """ 1671 SELECT 1672 er.id, 1673 er.received_ts, 1674 er.room_id, 1675 er.event_id, 1676 er.user_id, 1677 er.content, 1678 events.sender, 1679 room_stats_state.canonical_alias, 1680 room_stats_state.name, 1681 event_json.json AS event_json 1682 FROM event_reports AS er 1683 LEFT JOIN events 1684 ON events.event_id = er.event_id 1685 JOIN event_json 1686 ON event_json.event_id = er.event_id 1687 JOIN room_stats_state 1688 ON room_stats_state.room_id = er.room_id 1689 WHERE er.id = ? 1690 """ 1691 1692 txn.execute(sql, [report_id]) 1693 row = txn.fetchone() 1694 1695 if not row: 1696 return None 1697 1698 event_report = { 1699 "id": row[0], 1700 "received_ts": row[1], 1701 "room_id": row[2], 1702 "event_id": row[3], 1703 "user_id": row[4], 1704 "score": db_to_json(row[5]).get("score"), 1705 "reason": db_to_json(row[5]).get("reason"), 1706 "sender": row[6], 1707 "canonical_alias": row[7], 1708 "name": row[8], 1709 "event_json": db_to_json(row[9]), 1710 } 1711 1712 return event_report 1713 1714 return await self.db_pool.runInteraction( 1715 "get_event_report", _get_event_report_txn, report_id 1716 ) 1717 1718 async def get_event_reports_paginate( 1719 self, 1720 start: int, 1721 limit: int, 1722 direction: str = "b", 1723 user_id: Optional[str] = None, 1724 room_id: Optional[str] = None, 1725 ) -> Tuple[List[Dict[str, Any]], int]: 1726 """Retrieve a paginated list of event reports 1727 1728 Args: 1729 start: event offset to begin the query from 1730 limit: number of rows to retrieve 1731 direction: Whether to fetch the most recent first (`"b"`) or the 1732 oldest first (`"f"`) 1733 user_id: search for user_id. Ignored if user_id is None 1734 room_id: search for room_id. Ignored if room_id is None 1735 Returns: 1736 event_reports: json list of event reports 1737 count: total number of event reports matching the filter criteria 1738 """ 1739 1740 def _get_event_reports_paginate_txn( 1741 txn: LoggingTransaction, 1742 ) -> Tuple[List[Dict[str, Any]], int]: 1743 filters = [] 1744 args: List[object] = [] 1745 1746 if user_id: 1747 filters.append("er.user_id LIKE ?") 1748 args.extend(["%" + user_id + "%"]) 1749 if room_id: 1750 filters.append("er.room_id LIKE ?") 1751 args.extend(["%" + room_id + "%"]) 1752 1753 if direction == "b": 1754 order = "DESC" 1755 else: 1756 order = "ASC" 1757 1758 where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" 1759 1760 sql = """ 1761 SELECT COUNT(*) as total_event_reports 1762 FROM event_reports AS er 1763 {} 1764 """.format( 1765 where_clause 1766 ) 1767 txn.execute(sql, args) 1768 count = cast(Tuple[int], txn.fetchone())[0] 1769 1770 sql = """ 1771 SELECT 1772 er.id, 1773 er.received_ts, 1774 er.room_id, 1775 er.event_id, 1776 er.user_id, 1777 er.content, 1778 events.sender, 1779 room_stats_state.canonical_alias, 1780 room_stats_state.name 1781 FROM event_reports AS er 1782 LEFT JOIN events 1783 ON events.event_id = er.event_id 1784 JOIN room_stats_state 1785 ON room_stats_state.room_id = er.room_id 1786 {where_clause} 1787 ORDER BY er.received_ts {order} 1788 LIMIT ? 1789 OFFSET ? 1790 """.format( 1791 where_clause=where_clause, 1792 order=order, 1793 ) 1794 1795 args += [limit, start] 1796 txn.execute(sql, args) 1797 1798 event_reports = [] 1799 for row in txn: 1800 try: 1801 s = db_to_json(row[5]).get("score") 1802 r = db_to_json(row[5]).get("reason") 1803 except Exception: 1804 logger.error("Unable to parse json from event_reports: %s", row[0]) 1805 continue 1806 event_reports.append( 1807 { 1808 "id": row[0], 1809 "received_ts": row[1], 1810 "room_id": row[2], 1811 "event_id": row[3], 1812 "user_id": row[4], 1813 "score": s, 1814 "reason": r, 1815 "sender": row[6], 1816 "canonical_alias": row[7], 1817 "name": row[8], 1818 } 1819 ) 1820 1821 return event_reports, count 1822 1823 return await self.db_pool.runInteraction( 1824 "get_event_reports_paginate", _get_event_reports_paginate_txn 1825 ) 1826 1827 async def block_room(self, room_id: str, user_id: str) -> None: 1828 """Marks the room as blocked. 1829 1830 Can be called multiple times (though we'll only track the last user to 1831 block this room). 1832 1833 Can be called on a room unknown to this homeserver. 1834 1835 Args: 1836 room_id: Room to block 1837 user_id: Who blocked it 1838 """ 1839 await self.db_pool.simple_upsert( 1840 table="blocked_rooms", 1841 keyvalues={"room_id": room_id}, 1842 values={}, 1843 insertion_values={"user_id": user_id}, 1844 desc="block_room", 1845 ) 1846 await self.db_pool.runInteraction( 1847 "block_room_invalidation", 1848 self._invalidate_cache_and_stream, 1849 self.is_room_blocked, 1850 (room_id,), 1851 ) 1852 1853 async def unblock_room(self, room_id: str) -> None: 1854 """Remove the room from blocking list. 1855 1856 Args: 1857 room_id: Room to unblock 1858 """ 1859 await self.db_pool.simple_delete( 1860 table="blocked_rooms", 1861 keyvalues={"room_id": room_id}, 1862 desc="unblock_room", 1863 ) 1864 await self.db_pool.runInteraction( 1865 "block_room_invalidation", 1866 self._invalidate_cache_and_stream, 1867 self.is_room_blocked, 1868 (room_id,), 1869 ) 1870