1# Copyright 2014-2016 OpenMarket Ltd
2# Copyright 2019 The Matrix.org Foundation C.I.C.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import logging
17from abc import abstractmethod
18from enum import Enum
19from typing import (
20    TYPE_CHECKING,
21    Any,
22    Awaitable,
23    Dict,
24    List,
25    Optional,
26    Tuple,
27    Union,
28    cast,
29)
30
31import attr
32
33from synapse.api.constants import EventContentFields, EventTypes, JoinRules
34from synapse.api.errors import StoreError
35from synapse.api.room_versions import RoomVersion, RoomVersions
36from synapse.events import EventBase
37from synapse.storage._base import SQLBaseStore, db_to_json
38from synapse.storage.database import (
39    DatabasePool,
40    LoggingDatabaseConnection,
41    LoggingTransaction,
42)
43from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
44from synapse.storage.types import Cursor
45from synapse.storage.util.id_generators import IdGenerator
46from synapse.types import JsonDict, ThirdPartyInstanceID
47from synapse.util import json_encoder
48from synapse.util.caches.descriptors import cached
49from synapse.util.stringutils import MXC_REGEX
50
51if TYPE_CHECKING:
52    from synapse.server import HomeServer
53
54logger = logging.getLogger(__name__)
55
56
@attr.s(slots=True, frozen=True, auto_attribs=True)
class RatelimitOverride:
    """A per-user override of the configured message ratelimits.

    Returned by `RoomWorkerStore.get_ratelimit_for_user`. Per that method's
    contract, contents of None or 0 mean ratelimiting is disabled entirely
    for the user.
    """

    # Maximum sustained rate of actions per second.
    messages_per_second: int
    # Number of actions that may be performed before being limited.
    burst_count: int
61
62
class RoomSortOrder(Enum):
    """
    Enum to define the sorting method used when returning rooms with get_rooms_paginate

    NAME = sort rooms alphabetically by name
    CANONICAL_ALIAS = sort rooms alphabetically by canonical alias
    JOINED_MEMBERS = sort rooms by membership size, highest to lowest
    JOINED_LOCAL_MEMBERS = sort rooms by number of local members, highest to lowest
    VERSION = sort rooms by room version, descending
    CREATOR = sort rooms alphabetically by creator user ID
    ENCRYPTION = sort rooms alphabetically by encryption algorithm
    FEDERATABLE = sort rooms by whether they are federatable, ascending
    PUBLIC = sort rooms by whether they are public, ascending
    JOIN_RULES = sort rooms alphabetically by join rule
    GUEST_ACCESS = sort rooms alphabetically by guest access setting
    HISTORY_VISIBILITY = sort rooms alphabetically by history visibility
    STATE_EVENTS = sort rooms by number of state events, highest to lowest
    """

    # ALPHABETICAL and SIZE are deprecated.
    # ALPHABETICAL is the same as NAME.
    ALPHABETICAL = "alphabetical"
    # SIZE is the same as JOINED_MEMBERS.
    SIZE = "size"
    NAME = "name"
    CANONICAL_ALIAS = "canonical_alias"
    JOINED_MEMBERS = "joined_members"
    JOINED_LOCAL_MEMBERS = "joined_local_members"
    VERSION = "version"
    CREATOR = "creator"
    ENCRYPTION = "encryption"
    FEDERATABLE = "federatable"
    PUBLIC = "public"
    JOIN_RULES = "join_rules"
    GUEST_ACCESS = "guest_access"
    HISTORY_VISIBILITY = "history_visibility"
    STATE_EVENTS = "state_events"
89
90
91class RoomWorkerStore(CacheInvalidationWorkerStore):
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        # Keep a reference to the homeserver config; used e.g. for the
        # default retention policy in get_retention_policy_for_room.
        self.config = hs.config
101
102    async def store_room(
103        self,
104        room_id: str,
105        room_creator_user_id: str,
106        is_public: bool,
107        room_version: RoomVersion,
108    ) -> None:
109        """Stores a room.
110
111        Args:
112            room_id: The desired room ID, can be None.
113            room_creator_user_id: The user ID of the room creator.
114            is_public: True to indicate that this room should appear in
115                public room lists.
116            room_version: The version of the room
117        Raises:
118            StoreError if the room could not be stored.
119        """
120        try:
121            await self.db_pool.simple_insert(
122                "rooms",
123                {
124                    "room_id": room_id,
125                    "creator": room_creator_user_id,
126                    "is_public": is_public,
127                    "room_version": room_version.identifier,
128                    "has_auth_chain_index": True,
129                },
130                desc="store_room",
131            )
132        except Exception as e:
133            logger.error("store_room with room_id=%s failed: %s", room_id, e)
134            raise StoreError(500, "Problem creating room.")
135
    async def get_room(self, room_id: str) -> Optional[Dict[str, Any]]:
        """Retrieve a room.

        Args:
            room_id: The ID of the room to retrieve.
        Returns:
            A dict with keys "room_id", "is_public", "creator" and
            "has_auth_chain_index", or None if the room is unknown.
        """
        return await self.db_pool.simple_select_one(
            table="rooms",
            keyvalues={"room_id": room_id},
            retcols=("room_id", "is_public", "creator", "has_auth_chain_index"),
            desc="get_room",
            allow_none=True,
        )
151
152    async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]:
153        """Retrieve room with statistics.
154
155        Args:
156            room_id: The ID of the room to retrieve.
157        Returns:
158            A dict containing the room information, or None if the room is unknown.
159        """
160
161        def get_room_with_stats_txn(
162            txn: LoggingTransaction, room_id: str
163        ) -> Optional[Dict[str, Any]]:
164            sql = """
165                SELECT room_id, state.name, state.canonical_alias, curr.joined_members,
166                  curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
167                  rooms.creator, state.encryption, state.is_federatable AS federatable,
168                  rooms.is_public AS public, state.join_rules, state.guest_access,
169                  state.history_visibility, curr.current_state_events AS state_events,
170                  state.avatar, state.topic
171                FROM rooms
172                LEFT JOIN room_stats_state state USING (room_id)
173                LEFT JOIN room_stats_current curr USING (room_id)
174                WHERE room_id = ?
175                """
176            txn.execute(sql, [room_id])
177            # Catch error if sql returns empty result to return "None" instead of an error
178            try:
179                res = self.db_pool.cursor_to_dict(txn)[0]
180            except IndexError:
181                return None
182
183            res["federatable"] = bool(res["federatable"])
184            res["public"] = bool(res["public"])
185            return res
186
187        return await self.db_pool.runInteraction(
188            "get_room_with_stats", get_room_with_stats_txn, room_id
189        )
190
    async def get_public_room_ids(self) -> List[str]:
        """Return the IDs of all rooms with `is_public` set in the rooms table."""
        return await self.db_pool.simple_select_onecol(
            table="rooms",
            keyvalues={"is_public": True},
            retcol="room_id",
            desc="get_public_room_ids",
        )
198
199    async def count_public_rooms(
200        self,
201        network_tuple: Optional[ThirdPartyInstanceID],
202        ignore_non_federatable: bool,
203    ) -> int:
204        """Counts the number of public rooms as tracked in the room_stats_current
205        and room_stats_state table.
206
207        Args:
208            network_tuple
209            ignore_non_federatable: If true filters out non-federatable rooms
210        """
211
212        def _count_public_rooms_txn(txn: LoggingTransaction) -> int:
213            query_args = []
214
215            if network_tuple:
216                if network_tuple.appservice_id:
217                    published_sql = """
218                        SELECT room_id from appservice_room_list
219                        WHERE appservice_id = ? AND network_id = ?
220                    """
221                    query_args.append(network_tuple.appservice_id)
222                    assert network_tuple.network_id is not None
223                    query_args.append(network_tuple.network_id)
224                else:
225                    published_sql = """
226                        SELECT room_id FROM rooms WHERE is_public
227                    """
228            else:
229                published_sql = """
230                    SELECT room_id FROM rooms WHERE is_public
231                    UNION SELECT room_id from appservice_room_list
232            """
233
234            sql = """
235                SELECT
236                    COUNT(*)
237                FROM (
238                    %(published_sql)s
239                ) published
240                INNER JOIN room_stats_state USING (room_id)
241                INNER JOIN room_stats_current USING (room_id)
242                WHERE
243                    (
244                        join_rules = 'public' OR join_rules = '%(knock_join_rule)s'
245                        OR history_visibility = 'world_readable'
246                    )
247                    AND joined_members > 0
248            """ % {
249                "published_sql": published_sql,
250                "knock_join_rule": JoinRules.KNOCK,
251            }
252
253            txn.execute(sql, query_args)
254            return cast(Tuple[int], txn.fetchone())[0]
255
256        return await self.db_pool.runInteraction(
257            "count_public_rooms", _count_public_rooms_txn
258        )
259
260    async def get_room_count(self) -> int:
261        """Retrieve the total number of rooms."""
262
263        def f(txn: LoggingTransaction) -> int:
264            sql = "SELECT count(*)  FROM rooms"
265            txn.execute(sql)
266            row = cast(Tuple[int], txn.fetchone())
267            return row[0]
268
269        return await self.db_pool.runInteraction("get_rooms", f)
270
    async def get_largest_public_rooms(
        self,
        network_tuple: Optional[ThirdPartyInstanceID],
        search_filter: Optional[dict],
        limit: Optional[int],
        bounds: Optional[Tuple[int, str]],
        forwards: bool,
        ignore_non_federatable: bool = False,
    ) -> List[Dict[str, Any]]:
        """Gets the largest public rooms (where largest is in terms of joined
        members, as tracked in the statistics table).

        Args:
            network_tuple: The appservice/network list to query, or None for
                the union of the main public list and all appservice lists.
            search_filter: Optional filter dict; a "generic_search_term" entry
                is matched case-insensitively against name, topic and
                canonical alias.
            limit: Maximum number of rows to return, unlimited otherwise.
            bounds: An upper or lower bound to apply to result set if given,
                consists of a joined member count and room_id (these are
                excluded from result set).
            forwards: true iff going forwards, going backwards otherwise
            ignore_non_federatable: If true filters out non-federatable rooms.

        Returns:
            Rooms in order: biggest number of joined users first.
            We then arbitrarily use the room_id as a tie breaker.

        """

        where_clauses = []
        query_args: List[Union[str, int]] = []

        # Pick the subquery of "published" room IDs: a specific appservice's
        # list, the main public list, or the union of all of them.
        if network_tuple:
            if network_tuple.appservice_id:
                published_sql = """
                    SELECT room_id from appservice_room_list
                    WHERE appservice_id = ? AND network_id = ?
                """
                query_args.append(network_tuple.appservice_id)
                assert network_tuple.network_id is not None
                query_args.append(network_tuple.network_id)
            else:
                published_sql = """
                    SELECT room_id FROM rooms WHERE is_public
                """
        else:
            published_sql = """
                SELECT room_id FROM rooms WHERE is_public
                UNION SELECT room_id from appservice_room_list
            """

        # Work out the bounds if we're given them, these bounds look slightly
        # odd, but are designed to help query planner use indices by pulling
        # out a common bound.
        if bounds:
            last_joined_members, last_room_id = bounds
            if forwards:
                where_clauses.append(
                    """
                        joined_members <= ? AND (
                            joined_members < ? OR room_id < ?
                        )
                    """
                )
            else:
                where_clauses.append(
                    """
                        joined_members >= ? AND (
                            joined_members > ? OR room_id > ?
                        )
                    """
                )

            # Both branches consume the same three parameters, in this order.
            query_args += [last_joined_members, last_joined_members, last_room_id]

        if ignore_non_federatable:
            where_clauses.append("is_federatable")

        if search_filter and search_filter.get("generic_search_term", None):
            search_term = "%" + search_filter["generic_search_term"] + "%"

            where_clauses.append(
                """
                    (
                        LOWER(name) LIKE ?
                        OR LOWER(topic) LIKE ?
                        OR LOWER(canonical_alias) LIKE ?
                    )
                """
            )
            # One argument per LIKE placeholder above.
            query_args += [
                search_term.lower(),
                search_term.lower(),
                search_term.lower(),
            ]

        where_clause = ""
        if where_clauses:
            where_clause = " AND " + " AND ".join(where_clauses)

        sql = """
            SELECT
                room_id, name, topic, canonical_alias, joined_members,
                avatar, history_visibility, guest_access, join_rules
            FROM (
                %(published_sql)s
            ) published
            INNER JOIN room_stats_state USING (room_id)
            INNER JOIN room_stats_current USING (room_id)
            WHERE
                (
                    join_rules = 'public' OR join_rules = '%(knock_join_rule)s'
                    OR history_visibility = 'world_readable'
                )
                AND joined_members > 0
                %(where_clause)s
            ORDER BY joined_members %(dir)s, room_id %(dir)s
        """ % {
            "published_sql": published_sql,
            "where_clause": where_clause,
            "dir": "DESC" if forwards else "ASC",
            "knock_join_rule": JoinRules.KNOCK,
        }

        if limit is not None:
            query_args.append(limit)

            sql += """
                LIMIT ?
            """

        def _get_largest_public_rooms_txn(
            txn: LoggingTransaction,
        ) -> List[Dict[str, Any]]:
            txn.execute(sql, query_args)

            results = self.db_pool.cursor_to_dict(txn)

            # Backwards pagination queries in ascending order (see `dir`
            # above), so flip the rows to keep the biggest-first contract.
            if not forwards:
                results.reverse()

            return results

        ret_val = await self.db_pool.runInteraction(
            "get_largest_public_rooms", _get_largest_public_rooms_txn
        )
        return ret_val
417
    @cached(max_entries=10000)
    async def is_room_blocked(self, room_id: str) -> Optional[bool]:
        """Check whether a room is blocked on this server.

        Args:
            room_id: The ID of the room to check.
        Returns:
            A truthy value if a `blocked_rooms` row exists for the room,
            None otherwise.
        """
        return await self.db_pool.simple_select_one_onecol(
            table="blocked_rooms",
            keyvalues={"room_id": room_id},
            retcol="1",
            allow_none=True,
            desc="is_room_blocked",
        )
427
    async def room_is_blocked_by(self, room_id: str) -> Optional[str]:
        """Retrieve the user who blocked the given room.

        Args:
            room_id: The ID of the room to check.
        Returns:
            The user ID that blocked the room (non-null in the database), or
            None if the room is not blocked.
        """
        return await self.db_pool.simple_select_one_onecol(
            table="blocked_rooms",
            keyvalues={"room_id": room_id},
            retcol="user_id",
            allow_none=True,
            desc="room_is_blocked_by",
        )
441
442    async def get_rooms_paginate(
443        self,
444        start: int,
445        limit: int,
446        order_by: str,
447        reverse_order: bool,
448        search_term: Optional[str],
449    ) -> Tuple[List[Dict[str, Any]], int]:
450        """Function to retrieve a paginated list of rooms as json.
451
452        Args:
453            start: offset in the list
454            limit: maximum amount of rooms to retrieve
455            order_by: the sort order of the returned list
456            reverse_order: whether to reverse the room list
457            search_term: a string to filter room names,
458                canonical alias and room ids by.
459                Room ID must match exactly. Canonical alias must match a substring of the local part.
460        Returns:
461            A list of room dicts and an integer representing the total number of
462            rooms that exist given this query
463        """
464        # Filter room names by a string
465        where_statement = ""
466        search_pattern: List[object] = []
467        if search_term:
468            where_statement = """
469                WHERE LOWER(state.name) LIKE ?
470                OR LOWER(state.canonical_alias) LIKE ?
471                OR state.room_id = ?
472            """
473
474            # Our postgres db driver converts ? -> %s in SQL strings as that's the
475            # placeholder for postgres.
476            # HOWEVER, if you put a % into your SQL then everything goes wibbly.
477            # To get around this, we're going to surround search_term with %'s
478            # before giving it to the database in python instead
479            search_pattern = [
480                "%" + search_term.lower() + "%",
481                "#%" + search_term.lower() + "%:%",
482                search_term,
483            ]
484
485        # Set ordering
486        if RoomSortOrder(order_by) == RoomSortOrder.SIZE:
487            # Deprecated in favour of RoomSortOrder.JOINED_MEMBERS
488            order_by_column = "curr.joined_members"
489            order_by_asc = False
490        elif RoomSortOrder(order_by) == RoomSortOrder.ALPHABETICAL:
491            # Deprecated in favour of RoomSortOrder.NAME
492            order_by_column = "state.name"
493            order_by_asc = True
494        elif RoomSortOrder(order_by) == RoomSortOrder.NAME:
495            order_by_column = "state.name"
496            order_by_asc = True
497        elif RoomSortOrder(order_by) == RoomSortOrder.CANONICAL_ALIAS:
498            order_by_column = "state.canonical_alias"
499            order_by_asc = True
500        elif RoomSortOrder(order_by) == RoomSortOrder.JOINED_MEMBERS:
501            order_by_column = "curr.joined_members"
502            order_by_asc = False
503        elif RoomSortOrder(order_by) == RoomSortOrder.JOINED_LOCAL_MEMBERS:
504            order_by_column = "curr.local_users_in_room"
505            order_by_asc = False
506        elif RoomSortOrder(order_by) == RoomSortOrder.VERSION:
507            order_by_column = "rooms.room_version"
508            order_by_asc = False
509        elif RoomSortOrder(order_by) == RoomSortOrder.CREATOR:
510            order_by_column = "rooms.creator"
511            order_by_asc = True
512        elif RoomSortOrder(order_by) == RoomSortOrder.ENCRYPTION:
513            order_by_column = "state.encryption"
514            order_by_asc = True
515        elif RoomSortOrder(order_by) == RoomSortOrder.FEDERATABLE:
516            order_by_column = "state.is_federatable"
517            order_by_asc = True
518        elif RoomSortOrder(order_by) == RoomSortOrder.PUBLIC:
519            order_by_column = "rooms.is_public"
520            order_by_asc = True
521        elif RoomSortOrder(order_by) == RoomSortOrder.JOIN_RULES:
522            order_by_column = "state.join_rules"
523            order_by_asc = True
524        elif RoomSortOrder(order_by) == RoomSortOrder.GUEST_ACCESS:
525            order_by_column = "state.guest_access"
526            order_by_asc = True
527        elif RoomSortOrder(order_by) == RoomSortOrder.HISTORY_VISIBILITY:
528            order_by_column = "state.history_visibility"
529            order_by_asc = True
530        elif RoomSortOrder(order_by) == RoomSortOrder.STATE_EVENTS:
531            order_by_column = "curr.current_state_events"
532            order_by_asc = False
533        else:
534            raise StoreError(
535                500, "Incorrect value for order_by provided: %s" % order_by
536            )
537
538        # Whether to return the list in reverse order
539        if reverse_order:
540            # Flip the boolean
541            order_by_asc = not order_by_asc
542
543        # Create one query for getting the limited number of events that the user asked
544        # for, and another query for getting the total number of events that could be
545        # returned. Thus allowing us to see if there are more events to paginate through
546        info_sql = """
547            SELECT state.room_id, state.name, state.canonical_alias, curr.joined_members,
548              curr.local_users_in_room, rooms.room_version, rooms.creator,
549              state.encryption, state.is_federatable, rooms.is_public, state.join_rules,
550              state.guest_access, state.history_visibility, curr.current_state_events
551            FROM room_stats_state state
552            INNER JOIN room_stats_current curr USING (room_id)
553            INNER JOIN rooms USING (room_id)
554            %s
555            ORDER BY %s %s
556            LIMIT ?
557            OFFSET ?
558        """ % (
559            where_statement,
560            order_by_column,
561            "ASC" if order_by_asc else "DESC",
562        )
563
564        # Use a nested SELECT statement as SQL can't count(*) with an OFFSET
565        count_sql = """
566            SELECT count(*) FROM (
567              SELECT room_id FROM room_stats_state state
568              %s
569            ) AS get_room_ids
570        """ % (
571            where_statement,
572        )
573
574        def _get_rooms_paginate_txn(
575            txn: LoggingTransaction,
576        ) -> Tuple[List[Dict[str, Any]], int]:
577            # Add the search term into the WHERE clause
578            # and execute the data query
579            txn.execute(info_sql, search_pattern + [limit, start])
580
581            # Refactor room query data into a structured dictionary
582            rooms = []
583            for room in txn:
584                rooms.append(
585                    {
586                        "room_id": room[0],
587                        "name": room[1],
588                        "canonical_alias": room[2],
589                        "joined_members": room[3],
590                        "joined_local_members": room[4],
591                        "version": room[5],
592                        "creator": room[6],
593                        "encryption": room[7],
594                        "federatable": room[8],
595                        "public": room[9],
596                        "join_rules": room[10],
597                        "guest_access": room[11],
598                        "history_visibility": room[12],
599                        "state_events": room[13],
600                    }
601                )
602
603            # Execute the count query
604
605            # Add the search term into the WHERE clause if present
606            txn.execute(count_sql, search_pattern)
607
608            room_count = cast(Tuple[int], txn.fetchone())
609            return rooms, room_count[0]
610
611        return await self.db_pool.runInteraction(
612            "get_rooms_paginate",
613            _get_rooms_paginate_txn,
614        )
615
616    @cached(max_entries=10000)
617    async def get_ratelimit_for_user(self, user_id: str) -> Optional[RatelimitOverride]:
618        """Check if there are any overrides for ratelimiting for the given user
619
620        Args:
621            user_id: user ID of the user
622        Returns:
623            RatelimitOverride if there is an override, else None. If the contents
624            of RatelimitOverride are None or 0 then ratelimitng has been
625            disabled for that user entirely.
626        """
627        row = await self.db_pool.simple_select_one(
628            table="ratelimit_override",
629            keyvalues={"user_id": user_id},
630            retcols=("messages_per_second", "burst_count"),
631            allow_none=True,
632            desc="get_ratelimit_for_user",
633        )
634
635        if row:
636            return RatelimitOverride(
637                messages_per_second=row["messages_per_second"],
638                burst_count=row["burst_count"],
639            )
640        else:
641            return None
642
643    async def set_ratelimit_for_user(
644        self, user_id: str, messages_per_second: int, burst_count: int
645    ) -> None:
646        """Sets whether a user is set an overridden ratelimit.
647        Args:
648            user_id: user ID of the user
649            messages_per_second: The number of actions that can be performed in a second.
650            burst_count: How many actions that can be performed before being limited.
651        """
652
653        def set_ratelimit_txn(txn: LoggingTransaction) -> None:
654            self.db_pool.simple_upsert_txn(
655                txn,
656                table="ratelimit_override",
657                keyvalues={"user_id": user_id},
658                values={
659                    "messages_per_second": messages_per_second,
660                    "burst_count": burst_count,
661                },
662            )
663
664            self._invalidate_cache_and_stream(
665                txn, self.get_ratelimit_for_user, (user_id,)
666            )
667
668        await self.db_pool.runInteraction("set_ratelimit", set_ratelimit_txn)
669
670    async def delete_ratelimit_for_user(self, user_id: str) -> None:
671        """Delete an overridden ratelimit for a user.
672        Args:
673            user_id: user ID of the user
674        """
675
676        def delete_ratelimit_txn(txn: LoggingTransaction) -> None:
677            row = self.db_pool.simple_select_one_txn(
678                txn,
679                table="ratelimit_override",
680                keyvalues={"user_id": user_id},
681                retcols=["user_id"],
682                allow_none=True,
683            )
684
685            if not row:
686                return
687
688            # They are there, delete them.
689            self.db_pool.simple_delete_one_txn(
690                txn, "ratelimit_override", keyvalues={"user_id": user_id}
691            )
692
693            self._invalidate_cache_and_stream(
694                txn, self.get_ratelimit_for_user, (user_id,)
695            )
696
697        await self.db_pool.runInteraction("delete_ratelimit", delete_ratelimit_txn)
698
699    @cached()
700    async def get_retention_policy_for_room(self, room_id: str) -> Dict[str, int]:
701        """Get the retention policy for a given room.
702
703        If no retention policy has been found for this room, returns a policy defined
704        by the configured default policy (which has None as both the 'min_lifetime' and
705        the 'max_lifetime' if no default policy has been defined in the server's
706        configuration).
707
708        Args:
709            room_id: The ID of the room to get the retention policy of.
710
711        Returns:
712            A dict containing "min_lifetime" and "max_lifetime" for this room.
713        """
714
715        def get_retention_policy_for_room_txn(
716            txn: LoggingTransaction,
717        ) -> List[Dict[str, Optional[int]]]:
718            txn.execute(
719                """
720                SELECT min_lifetime, max_lifetime FROM room_retention
721                INNER JOIN current_state_events USING (event_id, room_id)
722                WHERE room_id = ?;
723                """,
724                (room_id,),
725            )
726
727            return self.db_pool.cursor_to_dict(txn)
728
729        ret = await self.db_pool.runInteraction(
730            "get_retention_policy_for_room",
731            get_retention_policy_for_room_txn,
732        )
733
734        # If we don't know this room ID, ret will be None, in this case return the default
735        # policy.
736        if not ret:
737            return {
738                "min_lifetime": self.config.retention.retention_default_min_lifetime,
739                "max_lifetime": self.config.retention.retention_default_max_lifetime,
740            }
741
742        min_lifetime = ret[0]["min_lifetime"]
743        max_lifetime = ret[0]["max_lifetime"]
744
745        # If one of the room's policy's attributes isn't defined, use the matching
746        # attribute from the default policy.
747        # The default values will be None if no default policy has been defined, or if one
748        # of the attributes is missing from the default policy.
749        if min_lifetime is None:
750            min_lifetime = self.config.retention.retention_default_min_lifetime
751
752        if max_lifetime is None:
753            max_lifetime = self.config.retention.retention_default_max_lifetime
754
755        return {
756            "min_lifetime": min_lifetime,
757            "max_lifetime": max_lifetime,
758        }
759
760    async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
761        """Retrieves all the local and remote media MXC URIs in a given room
762
763        Args:
764            room_id
765
766        Returns:
767            The local and remote media as a lists of the media IDs.
768        """
769
770        def _get_media_mxcs_in_room_txn(
771            txn: LoggingTransaction,
772        ) -> Tuple[List[str], List[str]]:
773            local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
774            local_media_mxcs = []
775            remote_media_mxcs = []
776
777            # Convert the IDs to MXC URIs
778            for media_id in local_mxcs:
779                local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id))
780            for hostname, media_id in remote_mxcs:
781                remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
782
783            return local_media_mxcs, remote_media_mxcs
784
785        return await self.db_pool.runInteraction(
786            "get_media_ids_in_room", _get_media_mxcs_in_room_txn
787        )
788
789    async def quarantine_media_ids_in_room(
790        self, room_id: str, quarantined_by: str
791    ) -> int:
792        """For a room loops through all events with media and quarantines
793        the associated media
794        """
795
796        logger.info("Quarantining media in room: %s", room_id)
797
798        def _quarantine_media_in_room_txn(txn: LoggingTransaction) -> int:
799            local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
800            return self._quarantine_media_txn(
801                txn, local_mxcs, remote_mxcs, quarantined_by
802            )
803
804        return await self.db_pool.runInteraction(
805            "quarantine_media_in_room", _quarantine_media_in_room_txn
806        )
807
    def _get_media_mxcs_in_room_txn(
        self, txn: LoggingTransaction, room_id: str
    ) -> Tuple[List[str], List[Tuple[str, str]]]:
        """Retrieves all the local and remote media MXC URIs in a given room

        Scans the room's events newest-first in batches of 100 (paginating on
        stream_ordering), pulling `url` and `info.thumbnail_url` out of each
        event's content and splitting the matches by hostname.

        Returns:
            The local media as a list of media IDs, and the remote media as a
            list of (hostname, media ID) tuples.
        """
        sql = """
            SELECT stream_ordering, json FROM events
            JOIN event_json USING (room_id, event_id)
            WHERE room_id = ?
                %(where_clause)s
                AND contains_url = ? AND outlier = ?
            ORDER BY stream_ordering DESC
            LIMIT ?
        """
        # First batch: no extra bound on stream_ordering.
        txn.execute(sql % {"where_clause": ""}, (room_id, True, False, 100))

        local_media_mxcs = []
        remote_media_mxcs = []

        while True:
            next_token = None
            for stream_ordering, content_json in txn:
                # Track the oldest stream ordering seen in this batch; it
                # becomes the upper bound for the next batch.
                next_token = stream_ordering
                event_json = db_to_json(content_json)
                content = event_json["content"]
                content_url = content.get("url")
                thumbnail_url = content.get("info", {}).get("thumbnail_url")

                for url in (content_url, thumbnail_url):
                    if not url:
                        continue
                    matches = MXC_REGEX.match(url)
                    if matches:
                        hostname = matches.group(1)
                        media_id = matches.group(2)
                        if hostname == self.hs.hostname:
                            local_media_mxcs.append(media_id)
                        else:
                            remote_media_mxcs.append((hostname, media_id))

            if next_token is None:
                # We've gone through the whole room, so we're finished.
                break

            # Fetch the next batch of up to 100 events, strictly older than
            # the last one processed.
            txn.execute(
                sql % {"where_clause": "AND stream_ordering < ?"},
                (room_id, next_token, True, False, 100),
            )

        return local_media_mxcs, remote_media_mxcs
862
863    async def quarantine_media_by_id(
864        self,
865        server_name: str,
866        media_id: str,
867        quarantined_by: Optional[str],
868    ) -> int:
869        """quarantines or unquarantines a single local or remote media id
870
871        Args:
872            server_name: The name of the server that holds this media
873            media_id: The ID of the media to be quarantined
874            quarantined_by: The user ID that initiated the quarantine request
875                If it is `None` media will be removed from quarantine
876        """
877        logger.info("Quarantining media: %s/%s", server_name, media_id)
878        is_local = server_name == self.config.server.server_name
879
880        def _quarantine_media_by_id_txn(txn: LoggingTransaction) -> int:
881            local_mxcs = [media_id] if is_local else []
882            remote_mxcs = [(server_name, media_id)] if not is_local else []
883
884            return self._quarantine_media_txn(
885                txn, local_mxcs, remote_mxcs, quarantined_by
886            )
887
888        return await self.db_pool.runInteraction(
889            "quarantine_media_by_user", _quarantine_media_by_id_txn
890        )
891
892    async def quarantine_media_ids_by_user(
893        self, user_id: str, quarantined_by: str
894    ) -> int:
895        """quarantines all local media associated with a single user
896
897        Args:
898            user_id: The ID of the user to quarantine media of
899            quarantined_by: The ID of the user who made the quarantine request
900        """
901
902        def _quarantine_media_by_user_txn(txn: LoggingTransaction) -> int:
903            local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
904            return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
905
906        return await self.db_pool.runInteraction(
907            "quarantine_media_by_user", _quarantine_media_by_user_txn
908        )
909
910    def _get_media_ids_by_user_txn(
911        self, txn: LoggingTransaction, user_id: str, filter_quarantined: bool = True
912    ) -> List[str]:
913        """Retrieves local media IDs by a given user
914
915        Args:
916            txn (cursor)
917            user_id: The ID of the user to retrieve media IDs of
918
919        Returns:
920            The local and remote media as a lists of tuples where the key is
921            the hostname and the value is the media ID.
922        """
923        # Local media
924        sql = """
925            SELECT media_id
926            FROM local_media_repository
927            WHERE user_id = ?
928            """
929        if filter_quarantined:
930            sql += "AND quarantined_by IS NULL"
931        txn.execute(sql, (user_id,))
932
933        local_media_ids = [row[0] for row in txn]
934
935        # TODO: Figure out all remote media a user has referenced in a message
936
937        return local_media_ids
938
    def _quarantine_media_txn(
        self,
        txn: LoggingTransaction,
        local_mxcs: List[str],
        remote_mxcs: List[Tuple[str, str]],
        quarantined_by: Optional[str],
    ) -> int:
        """Quarantine and unquarantine local and remote media items.

        Args:
            txn: The database transaction to run the updates on.
            local_mxcs: A list of local media IDs.
            remote_mxcs: A list of (remote server, media id) tuples representing
                remote mxc URLs
            quarantined_by: The ID of the user who initiated the quarantine request
                If it is `None` media will be removed from quarantine
        Returns:
            The total number of media items quarantined
        """

        # Update all the tables to set the quarantined_by flag
        sql = """
            UPDATE local_media_repository
            SET quarantined_by = ?
            WHERE media_id = ?
        """

        # set quarantine
        if quarantined_by is not None:
            # Media explicitly marked `safe_from_quarantine` must never be
            # quarantined, so only update rows where that flag is unset.
            sql += "AND safe_from_quarantine = ?"
            txn.executemany(
                sql, [(quarantined_by, media_id, False) for media_id in local_mxcs]
            )
        # remove from quarantine
        else:
            # Unquarantining ignores `safe_from_quarantine`: clearing the
            # flag is always permitted.
            txn.executemany(
                sql, [(quarantined_by, media_id) for media_id in local_mxcs]
            )

        # Note that a rowcount of -1 can be used to indicate no rows were affected.
        # NOTE(review): per PEP 249 the value of rowcount after executemany()
        # is driver-dependent (it may reflect only the last statement rather
        # than the total) — confirm the returned count is accurate on all
        # supported databases.
        total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0

        txn.executemany(
            """
                UPDATE remote_media_cache
                SET quarantined_by = ?
                WHERE media_origin = ? AND media_id = ?
            """,
            ((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs),
        )
        total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0

        return total_media_quarantined
992
993    async def get_rooms_for_retention_period_in_range(
994        self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
995    ) -> Dict[str, Dict[str, Optional[int]]]:
996        """Retrieves all of the rooms within the given retention range.
997
998        Optionally includes the rooms which don't have a retention policy.
999
1000        Args:
1001            min_ms: Duration in milliseconds that define the lower limit of
1002                the range to handle (exclusive). If None, doesn't set a lower limit.
1003            max_ms: Duration in milliseconds that define the upper limit of
1004                the range to handle (inclusive). If None, doesn't set an upper limit.
1005            include_null: Whether to include rooms which retention policy is NULL
1006                in the returned set.
1007
1008        Returns:
1009            The rooms within this range, along with their retention
1010            policy. The key is "room_id", and maps to a dict describing the retention
1011            policy associated with this room ID. The keys for this nested dict are
1012            "min_lifetime" (int|None), and "max_lifetime" (int|None).
1013        """
1014
1015        def get_rooms_for_retention_period_in_range_txn(
1016            txn: LoggingTransaction,
1017        ) -> Dict[str, Dict[str, Optional[int]]]:
1018            range_conditions = []
1019            args = []
1020
1021            if min_ms is not None:
1022                range_conditions.append("max_lifetime > ?")
1023                args.append(min_ms)
1024
1025            if max_ms is not None:
1026                range_conditions.append("max_lifetime <= ?")
1027                args.append(max_ms)
1028
1029            # Do a first query which will retrieve the rooms that have a retention policy
1030            # in their current state.
1031            sql = """
1032                SELECT room_id, min_lifetime, max_lifetime FROM room_retention
1033                INNER JOIN current_state_events USING (event_id, room_id)
1034                """
1035
1036            if len(range_conditions):
1037                sql += " WHERE (" + " AND ".join(range_conditions) + ")"
1038
1039                if include_null:
1040                    sql += " OR max_lifetime IS NULL"
1041
1042            txn.execute(sql, args)
1043
1044            rows = self.db_pool.cursor_to_dict(txn)
1045            rooms_dict = {}
1046
1047            for row in rows:
1048                rooms_dict[row["room_id"]] = {
1049                    "min_lifetime": row["min_lifetime"],
1050                    "max_lifetime": row["max_lifetime"],
1051                }
1052
1053            if include_null:
1054                # If required, do a second query that retrieves all of the rooms we know
1055                # of so we can handle rooms with no retention policy.
1056                sql = "SELECT DISTINCT room_id FROM current_state_events"
1057
1058                txn.execute(sql)
1059
1060                rows = self.db_pool.cursor_to_dict(txn)
1061
1062                # If a room isn't already in the dict (i.e. it doesn't have a retention
1063                # policy in its state), add it with a null policy.
1064                for row in rows:
1065                    if row["room_id"] not in rooms_dict:
1066                        rooms_dict[row["room_id"]] = {
1067                            "min_lifetime": None,
1068                            "max_lifetime": None,
1069                        }
1070
1071            return rooms_dict
1072
1073        return await self.db_pool.runInteraction(
1074            "get_rooms_for_retention_period_in_range",
1075            get_rooms_for_retention_period_in_range_txn,
1076        )
1077
1078
class _BackgroundUpdates:
    """Names of the background updates registered by this module.

    Note: ``REMOVE_TOMESTONED_ROOMS_BG_UPDATE`` is misspelled ("tomestoned"),
    but the attribute name is part of this module's internal API and the
    stored update name is spelled correctly, so it is kept as-is.
    """

    REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
    ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column"
    POPULATE_ROOM_DEPTH_MIN_DEPTH2 = "populate_room_depth_min_depth2"
    REPLACE_ROOM_DEPTH_MIN_DEPTH = "replace_room_depth_min_depth"
    POPULATE_ROOMS_CREATOR_COLUMN = "populate_rooms_creator_column"
1085
1086
# DDL run in order by `_background_replace_room_depth_min_depth` to finish the
# room_depth.min_depth migration: drop the trigger/function that kept
# `min_depth2` in sync, drop the old 32-bit column, and rename the new column
# into place.
# NOTE(review): `DROP TRIGGER ... ON <table>` and `DROP FUNCTION` are
# PostgreSQL syntax — presumably the SQLite path never schedules this update;
# confirm.
_REPLACE_ROOM_DEPTH_SQL_COMMANDS = (
    "DROP TRIGGER populate_min_depth2_trigger ON room_depth",
    "DROP FUNCTION populate_min_depth2()",
    "ALTER TABLE room_depth DROP COLUMN min_depth",
    "ALTER TABLE room_depth RENAME COLUMN min_depth2 TO min_depth",
)
1093
1094
class RoomBackgroundUpdateStore(SQLBaseStore):
    """Registers and implements the room-related background updates."""

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self.db_pool.updates.register_background_update_handler(
            "insert_room_retention",
            self._background_insert_retention,
        )

        self.db_pool.updates.register_background_update_handler(
            _BackgroundUpdates.REMOVE_TOMESTONED_ROOMS_BG_UPDATE,
            self._remove_tombstoned_rooms_from_directory,
        )

        self.db_pool.updates.register_background_update_handler(
            _BackgroundUpdates.ADD_ROOMS_ROOM_VERSION_COLUMN,
            self._background_add_rooms_room_version_column,
        )

        # BG updates to change the type of room_depth.min_depth
        self.db_pool.updates.register_background_update_handler(
            _BackgroundUpdates.POPULATE_ROOM_DEPTH_MIN_DEPTH2,
            self._background_populate_room_depth_min_depth2,
        )
        self.db_pool.updates.register_background_update_handler(
            _BackgroundUpdates.REPLACE_ROOM_DEPTH_MIN_DEPTH,
            self._background_replace_room_depth_min_depth,
        )

        self.db_pool.updates.register_background_update_handler(
            _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN,
            self._background_populate_rooms_creator_column,
        )

    async def _background_insert_retention(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """Retrieves a list of all rooms within a range and inserts an entry for each of
        them into the room_retention table.
        NULLs the property's columns if missing from the retention event in the room's
        state (or NULLs all of them if there's no retention event in the room's state),
        so that we fall back to the server's retention policy.
        """

        last_room = progress.get("room_id", "")

        def _background_insert_retention_txn(txn: LoggingTransaction) -> bool:
            # Fetch the next batch of rooms (ordered by room ID) along with
            # their retention state event, if any.
            txn.execute(
                """
                SELECT state.room_id, state.event_id, events.json
                FROM current_state_events as state
                LEFT JOIN event_json AS events ON (state.event_id = events.event_id)
                WHERE state.room_id > ? AND state.type = '%s'
                ORDER BY state.room_id ASC
                LIMIT ?;
                """
                % EventTypes.Retention,
                (last_room, batch_size),
            )

            rows = self.db_pool.cursor_to_dict(txn)

            if not rows:
                return True

            for row in rows:
                if not row["json"]:
                    retention_policy = {}
                else:
                    ev = db_to_json(row["json"])
                    retention_policy = ev["content"]

                self.db_pool.simple_insert_txn(
                    txn=txn,
                    table="room_retention",
                    values={
                        "room_id": row["room_id"],
                        "event_id": row["event_id"],
                        "min_lifetime": retention_policy.get("min_lifetime"),
                        "max_lifetime": retention_policy.get("max_lifetime"),
                    },
                )

            logger.info("Inserted %d rows into room_retention", len(rows))

            self.db_pool.updates._background_update_progress_txn(
                txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]}
            )

            # A short batch means there are no more rooms to process.
            if batch_size > len(rows):
                return True
            else:
                return False

        end = await self.db_pool.runInteraction(
            "insert_room_retention",
            _background_insert_retention_txn,
        )

        if end:
            await self.db_pool.updates._end_background_update("insert_room_retention")

        return batch_size

    async def _background_add_rooms_room_version_column(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """Background update to go and add room version information to `rooms`
        table from `current_state_events` table.
        """

        last_room_id = progress.get("room_id", "")

        def _background_add_rooms_room_version_column_txn(
            txn: LoggingTransaction,
        ) -> bool:
            sql = """
                SELECT room_id, json FROM current_state_events
                INNER JOIN event_json USING (room_id, event_id)
                WHERE room_id > ? AND type = 'm.room.create' AND state_key = ''
                ORDER BY room_id
                LIMIT ?
            """

            txn.execute(sql, (last_room_id, batch_size))

            updates = []
            for room_id, event_json in txn:
                event_dict = db_to_json(event_json)
                room_version_id = event_dict.get("content", {}).get(
                    "room_version", RoomVersions.V1.identifier
                )

                # Fix: default the missing "content" key to an empty dict, as
                # the room_version lookup above already does; otherwise a
                # create event without content would raise AttributeError and
                # wedge this background update.
                creator = event_dict.get("content", {}).get("creator")

                updates.append((room_id, creator, room_version_id))

            if not updates:
                return True

            new_last_room_id = ""
            for room_id, creator, room_version_id in updates:
                # We upsert here just in case we don't already have a row,
                # mainly for paranoia as much badness would happen if we don't
                # insert the row and then try and get the room version for the
                # room.
                self.db_pool.simple_upsert_txn(
                    txn,
                    table="rooms",
                    keyvalues={"room_id": room_id},
                    values={"room_version": room_version_id},
                    insertion_values={"is_public": False, "creator": creator},
                )
                new_last_room_id = room_id

            self.db_pool.updates._background_update_progress_txn(
                txn,
                _BackgroundUpdates.ADD_ROOMS_ROOM_VERSION_COLUMN,
                {"room_id": new_last_room_id},
            )

            return False

        end = await self.db_pool.runInteraction(
            "_background_add_rooms_room_version_column",
            _background_add_rooms_room_version_column_txn,
        )

        if end:
            await self.db_pool.updates._end_background_update(
                _BackgroundUpdates.ADD_ROOMS_ROOM_VERSION_COLUMN
            )

        return batch_size

    async def _remove_tombstoned_rooms_from_directory(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """Removes any rooms with tombstone events from the room directory

        Nowadays this is handled by the room upgrade handler, but we may have some
        that got left behind
        """

        last_room = progress.get("room_id", "")

        def _get_rooms(txn: LoggingTransaction) -> List[str]:
            # Find public rooms whose current state contains a tombstone event.
            txn.execute(
                """
                SELECT room_id
                FROM rooms r
                INNER JOIN current_state_events cse USING (room_id)
                WHERE room_id > ? AND r.is_public
                AND cse.type = '%s' AND cse.state_key = ''
                ORDER BY room_id ASC
                LIMIT ?;
                """
                % EventTypes.Tombstone,
                (last_room, batch_size),
            )

            return [row[0] for row in txn]

        rooms = await self.db_pool.runInteraction(
            "get_tombstoned_directory_rooms", _get_rooms
        )

        if not rooms:
            await self.db_pool.updates._end_background_update(
                _BackgroundUpdates.REMOVE_TOMESTONED_ROOMS_BG_UPDATE
            )
            return 0

        for room_id in rooms:
            logger.info("Removing tombstoned room %s from the directory", room_id)
            await self.set_room_is_public(room_id, False)

        await self.db_pool.updates._background_update_progress(
            _BackgroundUpdates.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, {"room_id": rooms[-1]}
        )

        return len(rooms)

    @abstractmethod
    def set_room_is_public(self, room_id: str, is_public: bool) -> Awaitable[None]:
        # this will need to be implemented if a background update is performed with
        # existing (tombstoned, public) rooms in the database.
        #
        # It's overridden by RoomStore for the synapse master.
        raise NotImplementedError()

    async def has_auth_chain_index(self, room_id: str) -> bool:
        """Check if the room has (or can have) a chain cover index.

        Defaults to True if we don't have an entry in `rooms` table nor any
        events for the room.
        """

        has_auth_chain_index = await self.db_pool.simple_select_one_onecol(
            table="rooms",
            keyvalues={"room_id": room_id},
            retcol="has_auth_chain_index",
            desc="has_auth_chain_index",
            allow_none=True,
        )

        if has_auth_chain_index:
            return True

        # It's possible that we already have events for the room in our DB
        # without a corresponding room entry. If we do then we don't want to
        # mark the room as having an auth chain cover index.
        max_ordering = await self.db_pool.simple_select_one_onecol(
            table="events",
            keyvalues={"room_id": room_id},
            retcol="MAX(stream_ordering)",
            allow_none=True,
            desc="has_auth_chain_index_fallback",
        )

        return max_ordering is None

    async def _background_populate_room_depth_min_depth2(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """Populate room_depth.min_depth2

        This is to deal with the fact that min_depth was initially created as a
        32-bit integer field.
        """

        def process(txn: LoggingTransaction) -> int:
            last_room = progress.get("last_room", "")
            # Copy min_depth into min_depth2 for the next batch of rooms,
            # using RETURNING to learn which rooms were touched.
            txn.execute(
                """
                UPDATE room_depth SET min_depth2=min_depth
                WHERE room_id IN (
                   SELECT room_id FROM room_depth WHERE room_id > ?
                   ORDER BY room_id LIMIT ?
                )
                RETURNING room_id;
                """,
                (last_room, batch_size),
            )
            row_count = txn.rowcount
            if row_count == 0:
                return 0
            last_room = max(row[0] for row in txn)
            logger.info("populated room_depth up to %s", last_room)

            self.db_pool.updates._background_update_progress_txn(
                txn,
                _BackgroundUpdates.POPULATE_ROOM_DEPTH_MIN_DEPTH2,
                {"last_room": last_room},
            )
            return row_count

        result = await self.db_pool.runInteraction(
            "_background_populate_min_depth2", process
        )

        if result != 0:
            return result

        await self.db_pool.updates._end_background_update(
            _BackgroundUpdates.POPULATE_ROOM_DEPTH_MIN_DEPTH2
        )
        return 0

    async def _background_replace_room_depth_min_depth(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """Drop the old 'min_depth' column and rename 'min_depth2' into its place."""

        def process(txn: Cursor) -> None:
            for sql in _REPLACE_ROOM_DEPTH_SQL_COMMANDS:
                logger.info("completing room_depth migration: %s", sql)
                txn.execute(sql)

        await self.db_pool.runInteraction("_background_replace_room_depth", process)

        await self.db_pool.updates._end_background_update(
            _BackgroundUpdates.REPLACE_ROOM_DEPTH_MIN_DEPTH,
        )

        return 0

    async def _background_populate_rooms_creator_column(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """Background update to go and add creator information to `rooms`
        table from `current_state_events` table.
        """

        last_room_id = progress.get("room_id", "")

        def _background_populate_rooms_creator_column_txn(
            txn: LoggingTransaction,
        ) -> bool:
            sql = """
                SELECT room_id, json FROM event_json
                INNER JOIN rooms AS room USING (room_id)
                INNER JOIN current_state_events AS state_event USING (room_id, event_id)
                WHERE room_id > ? AND (room.creator IS NULL OR room.creator = '') AND state_event.type = 'm.room.create' AND state_event.state_key = ''
                ORDER BY room_id
                LIMIT ?
            """

            txn.execute(sql, (last_room_id, batch_size))
            room_id_to_create_event_results = txn.fetchall()

            new_last_room_id = ""
            for room_id, event_json in room_id_to_create_event_results:
                event_dict = db_to_json(event_json)

                # Fix: default the missing "content" key to an empty dict so a
                # malformed create event cannot raise AttributeError and wedge
                # this background update.
                creator = event_dict.get("content", {}).get(
                    EventContentFields.ROOM_CREATOR
                )

                self.db_pool.simple_update_txn(
                    txn,
                    table="rooms",
                    keyvalues={"room_id": room_id},
                    updatevalues={"creator": creator},
                )
                new_last_room_id = room_id

            # An empty batch means every room has been processed.
            if new_last_room_id == "":
                return True

            self.db_pool.updates._background_update_progress_txn(
                txn,
                _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN,
                {"room_id": new_last_room_id},
            )

            return False

        end = await self.db_pool.runInteraction(
            "_background_populate_rooms_creator_column",
            _background_populate_rooms_creator_column_txn,
        )

        if end:
            await self.db_pool.updates._end_background_update(
                _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN
            )

        return batch_size
1487
1488
1489class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        # Generates unique IDs for rows inserted into `event_reports`.
        self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
1499
1500    async def upsert_room_on_join(
1501        self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase]
1502    ) -> None:
1503        """Ensure that the room is stored in the table
1504
1505        Called when we join a room over federation, and overwrites any room version
1506        currently in the table.
1507        """
1508        # It's possible that we already have events for the room in our DB
1509        # without a corresponding room entry. If we do then we don't want to
1510        # mark the room as having an auth chain cover index.
1511        has_auth_chain_index = await self.has_auth_chain_index(room_id)
1512
1513        create_event = None
1514        for e in auth_events:
1515            if (e.type, e.state_key) == (EventTypes.Create, ""):
1516                create_event = e
1517                break
1518
1519        if create_event is None:
1520            # If the state doesn't have a create event then the room is
1521            # invalid, and it would fail auth checks anyway.
1522            raise StoreError(400, "No create event in state")
1523
1524        room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR)
1525
1526        if not isinstance(room_creator, str):
1527            # If the create event does not have a creator then the room is
1528            # invalid, and it would fail auth checks anyway.
1529            raise StoreError(400, "No creator defined on the create event")
1530
1531        await self.db_pool.simple_upsert(
1532            desc="upsert_room_on_join",
1533            table="rooms",
1534            keyvalues={"room_id": room_id},
1535            values={"room_version": room_version.identifier},
1536            insertion_values={
1537                "is_public": False,
1538                "creator": room_creator,
1539                "has_auth_chain_index": has_auth_chain_index,
1540            },
1541            # rooms has a unique constraint on room_id, so no need to lock when doing an
1542            # emulated upsert.
1543            lock=False,
1544        )
1545
1546    async def maybe_store_room_on_outlier_membership(
1547        self, room_id: str, room_version: RoomVersion
1548    ) -> None:
1549        """
1550        When we receive an invite or any other event over federation that may relate to a room
1551        we are not in, store the version of the room if we don't already know the room version.
1552        """
1553        # It's possible that we already have events for the room in our DB
1554        # without a corresponding room entry. If we do then we don't want to
1555        # mark the room as having an auth chain cover index.
1556        has_auth_chain_index = await self.has_auth_chain_index(room_id)
1557
1558        await self.db_pool.simple_upsert(
1559            desc="maybe_store_room_on_outlier_membership",
1560            table="rooms",
1561            keyvalues={"room_id": room_id},
1562            values={},
1563            insertion_values={
1564                "room_version": room_version.identifier,
1565                "is_public": False,
1566                # We don't worry about setting the `creator` here because
1567                # we don't process any messages in a room while a user is
1568                # invited (only after the join).
1569                "creator": "",
1570                "has_auth_chain_index": has_auth_chain_index,
1571            },
1572            # rooms has a unique constraint on room_id, so no need to lock when doing an
1573            # emulated upsert.
1574            lock=False,
1575        )
1576
1577    async def set_room_is_public(self, room_id: str, is_public: bool) -> None:
1578        await self.db_pool.simple_update_one(
1579            table="rooms",
1580            keyvalues={"room_id": room_id},
1581            updatevalues={"is_public": is_public},
1582            desc="set_room_is_public",
1583        )
1584
1585        self.hs.get_notifier().on_new_replication_data()
1586
1587    async def set_room_is_public_appservice(
1588        self, room_id: str, appservice_id: str, network_id: str, is_public: bool
1589    ) -> None:
1590        """Edit the appservice/network specific public room list.
1591
1592        Each appservice can have a number of published room lists associated
1593        with them, keyed off of an appservice defined `network_id`, which
1594        basically represents a single instance of a bridge to a third party
1595        network.
1596
1597        Args:
1598            room_id
1599            appservice_id
1600            network_id
1601            is_public: Whether to publish or unpublish the room from the list.
1602        """
1603
1604        if is_public:
1605            await self.db_pool.simple_upsert(
1606                table="appservice_room_list",
1607                keyvalues={
1608                    "appservice_id": appservice_id,
1609                    "network_id": network_id,
1610                    "room_id": room_id,
1611                },
1612                values={},
1613                insertion_values={
1614                    "appservice_id": appservice_id,
1615                    "network_id": network_id,
1616                    "room_id": room_id,
1617                },
1618                desc="set_room_is_public_appservice_true",
1619            )
1620        else:
1621            await self.db_pool.simple_delete(
1622                table="appservice_room_list",
1623                keyvalues={
1624                    "appservice_id": appservice_id,
1625                    "network_id": network_id,
1626                    "room_id": room_id,
1627                },
1628                desc="set_room_is_public_appservice_false",
1629            )
1630
1631        self.hs.get_notifier().on_new_replication_data()
1632
1633    async def add_event_report(
1634        self,
1635        room_id: str,
1636        event_id: str,
1637        user_id: str,
1638        reason: Optional[str],
1639        content: JsonDict,
1640        received_ts: int,
1641    ) -> None:
1642        next_id = self._event_reports_id_gen.get_next()
1643        await self.db_pool.simple_insert(
1644            table="event_reports",
1645            values={
1646                "id": next_id,
1647                "received_ts": received_ts,
1648                "room_id": room_id,
1649                "event_id": event_id,
1650                "user_id": user_id,
1651                "reason": reason,
1652                "content": json_encoder.encode(content),
1653            },
1654            desc="add_event_report",
1655        )
1656
1657    async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]:
1658        """Retrieve an event report
1659
1660        Args:
1661            report_id: ID of reported event in database
1662        Returns:
1663            event_report: json list of information from event report
1664        """
1665
1666        def _get_event_report_txn(
1667            txn: LoggingTransaction, report_id: int
1668        ) -> Optional[Dict[str, Any]]:
1669
1670            sql = """
1671                SELECT
1672                    er.id,
1673                    er.received_ts,
1674                    er.room_id,
1675                    er.event_id,
1676                    er.user_id,
1677                    er.content,
1678                    events.sender,
1679                    room_stats_state.canonical_alias,
1680                    room_stats_state.name,
1681                    event_json.json AS event_json
1682                FROM event_reports AS er
1683                LEFT JOIN events
1684                    ON events.event_id = er.event_id
1685                JOIN event_json
1686                    ON event_json.event_id = er.event_id
1687                JOIN room_stats_state
1688                    ON room_stats_state.room_id = er.room_id
1689                WHERE er.id = ?
1690            """
1691
1692            txn.execute(sql, [report_id])
1693            row = txn.fetchone()
1694
1695            if not row:
1696                return None
1697
1698            event_report = {
1699                "id": row[0],
1700                "received_ts": row[1],
1701                "room_id": row[2],
1702                "event_id": row[3],
1703                "user_id": row[4],
1704                "score": db_to_json(row[5]).get("score"),
1705                "reason": db_to_json(row[5]).get("reason"),
1706                "sender": row[6],
1707                "canonical_alias": row[7],
1708                "name": row[8],
1709                "event_json": db_to_json(row[9]),
1710            }
1711
1712            return event_report
1713
1714        return await self.db_pool.runInteraction(
1715            "get_event_report", _get_event_report_txn, report_id
1716        )
1717
    async def get_event_reports_paginate(
        self,
        start: int,
        limit: int,
        direction: str = "b",
        user_id: Optional[str] = None,
        room_id: Optional[str] = None,
    ) -> Tuple[List[Dict[str, Any]], int]:
        """Retrieve a paginated list of event reports

        Args:
            start: event offset to begin the query from
            limit: number of rows to retrieve
            direction: Whether to fetch the most recent first (`"b"`) or the
                oldest first (`"f"`). Any value other than `"b"` is treated
                as `"f"`.
            user_id: search for user_id. Ignored if user_id is None
            room_id: search for room_id. Ignored if room_id is None
        Returns:
            event_reports: json list of event reports
            count: total number of event reports matching the filter criteria
        """

        def _get_event_reports_paginate_txn(
            txn: LoggingTransaction,
        ) -> Tuple[List[Dict[str, Any]], int]:
            # Build the WHERE clause from the optional filters. All values
            # are passed as bound parameters; only the (fixed) clause text
            # is interpolated into the SQL string.
            filters = []
            args: List[object] = []

            if user_id:
                # Substring match. NOTE(review): `%`/`_` inside user_id act
                # as LIKE wildcards rather than literals — presumably
                # acceptable for an admin search; confirm with callers.
                filters.append("er.user_id LIKE ?")
                args.extend(["%" + user_id + "%"])
            if room_id:
                filters.append("er.room_id LIKE ?")
                args.extend(["%" + room_id + "%"])

            # "b" = backwards in time (newest first); anything else sorts
            # oldest first.
            if direction == "b":
                order = "DESC"
            else:
                order = "ASC"

            where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""

            # First count the total matches so the caller can paginate.
            sql = """
                SELECT COUNT(*) as total_event_reports
                FROM event_reports AS er
                {}
                """.format(
                where_clause
            )
            txn.execute(sql, args)
            count = cast(Tuple[int], txn.fetchone())[0]

            # Then fetch the requested page. LEFT JOIN on events because the
            # reported event may have been purged since the report was filed.
            sql = """
                SELECT
                    er.id,
                    er.received_ts,
                    er.room_id,
                    er.event_id,
                    er.user_id,
                    er.content,
                    events.sender,
                    room_stats_state.canonical_alias,
                    room_stats_state.name
                FROM event_reports AS er
                LEFT JOIN events
                    ON events.event_id = er.event_id
                JOIN room_stats_state
                    ON room_stats_state.room_id = er.room_id
                {where_clause}
                ORDER BY er.received_ts {order}
                LIMIT ?
                OFFSET ?
            """.format(
                where_clause=where_clause,
                order=order,
            )

            # The filter args are reused from the count query; LIMIT/OFFSET
            # params must come last to match the placeholder order.
            args += [limit, start]
            txn.execute(sql, args)

            event_reports = []
            for row in txn:
                try:
                    # "score"/"reason" live inside the JSON-encoded content
                    # column.
                    s = db_to_json(row[5]).get("score")
                    r = db_to_json(row[5]).get("reason")
                except Exception:
                    # Skip (rather than fail the whole page on) rows whose
                    # content column is not valid JSON.
                    logger.error("Unable to parse json from event_reports: %s", row[0])
                    continue
                event_reports.append(
                    {
                        "id": row[0],
                        "received_ts": row[1],
                        "room_id": row[2],
                        "event_id": row[3],
                        "user_id": row[4],
                        "score": s,
                        "reason": r,
                        "sender": row[6],
                        "canonical_alias": row[7],
                        "name": row[8],
                    }
                )

            return event_reports, count

        return await self.db_pool.runInteraction(
            "get_event_reports_paginate", _get_event_reports_paginate_txn
        )
1826
1827    async def block_room(self, room_id: str, user_id: str) -> None:
1828        """Marks the room as blocked.
1829
1830        Can be called multiple times (though we'll only track the last user to
1831        block this room).
1832
1833        Can be called on a room unknown to this homeserver.
1834
1835        Args:
1836            room_id: Room to block
1837            user_id: Who blocked it
1838        """
1839        await self.db_pool.simple_upsert(
1840            table="blocked_rooms",
1841            keyvalues={"room_id": room_id},
1842            values={},
1843            insertion_values={"user_id": user_id},
1844            desc="block_room",
1845        )
1846        await self.db_pool.runInteraction(
1847            "block_room_invalidation",
1848            self._invalidate_cache_and_stream,
1849            self.is_room_blocked,
1850            (room_id,),
1851        )
1852
1853    async def unblock_room(self, room_id: str) -> None:
1854        """Remove the room from blocking list.
1855
1856        Args:
1857            room_id: Room to unblock
1858        """
1859        await self.db_pool.simple_delete(
1860            table="blocked_rooms",
1861            keyvalues={"room_id": room_id},
1862            desc="unblock_room",
1863        )
1864        await self.db_pool.runInteraction(
1865            "block_room_invalidation",
1866            self._invalidate_cache_and_stream,
1867            self.is_room_blocked,
1868            (room_id,),
1869        )
1870