1# Copyright 2014-2016 OpenMarket Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
16
17from synapse.api.presence import PresenceState, UserPresenceState
18from synapse.replication.tcp.streams import PresenceStream
19from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
20from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
21from synapse.storage.engines import PostgresEngine
22from synapse.storage.types import Connection
23from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
24from synapse.util.caches.descriptors import cached, cachedList
25from synapse.util.caches.stream_change_cache import StreamChangeCache
26from synapse.util.iterutils import batch_iter
27
28if TYPE_CHECKING:
29    from synapse.server import HomeServer
30
31
class PresenceBackgroundUpdateStore(SQLBaseStore):
    """Registers the background database updates that presence storage relies on."""

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        background_updates = self.db_pool.updates

        # Partial index over rows that are not offline. It backs the query in
        # `PresenceStore._get_active_presence()`.
        background_updates.register_background_index_update(
            "presence_stream_not_offline_index",
            index_name="presence_stream_state_not_offline_idx",
            table="presence_stream",
            columns=["state"],
            where_clause="state != 'offline'",
        )
49
50
class PresenceStore(PresenceBackgroundUpdateStore):
    """Storage for user presence.

    Manages the `presence_stream` table (plus its stream ID generator and
    caches) and the `users_to_send_full_presence_to` table.
    """

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        # Only instances listed as presence writers may persist presence
        # (enforced by the assert in `update_presence`).
        self._can_persist_presence = (
            hs.get_instance_name() in hs.config.worker.writers.presence
        )

        if isinstance(database.engine, PostgresEngine):
            self._presence_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                stream_name="presence_stream",
                instance_name=self._instance_name,
                tables=[("presence_stream", "instance_name", "stream_id")],
                sequence_name="presence_stream_sequence",
                writers=hs.config.worker.writers.presence,
            )
        else:
            self._presence_id_gen = StreamIdGenerator(
                db_conn, "presence_stream", "stream_id"
            )

        self.hs = hs
        # Non-offline presence read at startup; handed out once via
        # `take_presence_startup_info`.
        self._presence_on_startup = self._get_active_presence(db_conn)

        presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
            db_conn,
            "presence_stream",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self._presence_id_gen.get_current_token(),
        )
        self.presence_stream_cache = StreamChangeCache(
            "PresenceStreamChangeCache",
            min_presence_val,
            prefilled_cache=presence_cache_prefill,
        )

    async def update_presence(self, presence_states) -> Tuple[int, int]:
        """Persist a batch of presence updates.

        Must only be called on an instance configured as a presence writer.

        Args:
            presence_states: The UserPresenceState objects to persist. One
                stream ID is allocated per entry.
                NOTE(review): an empty list makes `stream_orderings[-1]` raise
                IndexError; callers presumably never pass one.

        Returns:
            A tuple of:
            - the stream ID assigned to the last entry of `presence_states`;
            - the presence ID generator's current token after persisting.
        """
        assert self._can_persist_presence

        stream_ordering_manager = self._presence_id_gen.get_next_mult(
            len(presence_states)
        )

        async with stream_ordering_manager as stream_orderings:
            await self.db_pool.runInteraction(
                "update_presence",
                self._update_presence_txn,
                stream_orderings,
                presence_states,
            )

        return stream_orderings[-1], self._presence_id_gen.get_current_token()

    def _update_presence_txn(self, txn, stream_orderings, presence_states):
        """Write `presence_states` to `presence_stream`.

        Each user's older rows are replaced, and the relevant caches are
        invalidated once the transaction completes.

        Args:
            txn: The database transaction.
            stream_orderings: Stream IDs allocated for this batch, one per
                entry of `presence_states` and in the same order.
            presence_states: The UserPresenceState objects to persist.
        """
        for stream_id, state in zip(stream_orderings, presence_states):
            txn.call_after(
                self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
            )
            txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))

        # Delete old rows to stop the database from getting really big. Every
        # existing row for these users below the newest stream ID of this batch
        # is superseded by the rows inserted below. The bound is computed
        # explicitly rather than relying on the loop variable leaking out of the
        # `for` loop above.
        if presence_states:
            max_stream_id = max(stream_orderings)
            sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "

            for states in batch_iter(presence_states, 50):
                clause, args = make_in_list_sql_clause(
                    self.database_engine, "user_id", [s.user_id for s in states]
                )
                txn.execute(sql + clause, [max_stream_id] + list(args))

        # Actually insert new rows
        self.db_pool.simple_insert_many_txn(
            txn,
            table="presence_stream",
            values=[
                {
                    "stream_id": stream_id,
                    "user_id": state.user_id,
                    "state": state.state,
                    "last_active_ts": state.last_active_ts,
                    "last_federation_update_ts": state.last_federation_update_ts,
                    "last_user_sync_ts": state.last_user_sync_ts,
                    "status_msg": state.status_msg,
                    "currently_active": state.currently_active,
                    "instance_name": self._instance_name,
                }
                for stream_id, state in zip(stream_orderings, presence_states)
            ],
        )

    async def get_all_presence_updates(
        self, instance_name: str, last_id: int, current_id: int, limit: int
    ) -> Tuple[List[Tuple[int, list]], int, bool]:
        """Get updates for presence replication stream.

        Args:
            instance_name: The writer we want to fetch updates from. Unused:
                the query below does not filter by writer, so rows written by
                any instance in the token range are returned.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of:
            - the updates;
            - a token to use to fetch subsequent updates;
            - whether we returned fewer rows than exist between the requested
              tokens, due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data.
        """

        if last_id == current_id:
            return [], current_id, False

        def get_all_presence_updates_txn(txn):
            sql = """
                SELECT stream_id, user_id, state, last_active_ts,
                    last_federation_update_ts, last_user_sync_ts,
                    status_msg, currently_active
                FROM presence_stream
                WHERE ? < stream_id AND stream_id <= ?
                ORDER BY stream_id ASC
                LIMIT ?
            """
            txn.execute(sql, (last_id, current_id, limit))
            updates = [(row[0], row[1:]) for row in txn]

            # If we hit the limit there may be more rows, so only claim to have
            # caught up to the last stream ID we actually returned.
            upper_bound = current_id
            limited = False
            if len(updates) >= limit:
                upper_bound = updates[-1][0]
                limited = True

            return updates, upper_bound, limited

        return await self.db_pool.runInteraction(
            "get_all_presence_updates", get_all_presence_updates_txn
        )

    @cached()
    def _get_presence_for_user(self, user_id):
        # Cache-only entry point: never called directly. It is populated through
        # `get_presence_for_users` (a `cachedList` over this method).
        raise NotImplementedError()

    @cachedList(
        cached_method_name="_get_presence_for_user",
        list_name="user_ids",
        num_args=1,
    )
    async def get_presence_for_users(self, user_ids):
        """Fetch the stored presence for the given users.

        Args:
            user_ids: The user IDs to look up.

        Returns:
            A dict mapping user ID to UserPresenceState. Users with no row in
            `presence_stream` are absent from the dict.
        """
        rows = await self.db_pool.simple_select_many_batch(
            table="presence_stream",
            column="user_id",
            iterable=user_ids,
            keyvalues={},
            retcols=(
                "user_id",
                "state",
                "last_active_ts",
                "last_federation_update_ts",
                "last_user_sync_ts",
                "status_msg",
                "currently_active",
            ),
            desc="get_presence_for_users",
        )

        for row in rows:
            row["currently_active"] = bool(row["currently_active"])

        return {row["user_id"]: UserPresenceState(**row) for row in rows}

    async def should_user_receive_full_presence_with_token(
        self,
        user_id: str,
        from_token: int,
    ) -> bool:
        """Check whether the given user should receive full presence using the stream token
        they're updating from.

        Args:
            user_id: The ID of the user to check.
            from_token: The stream token included in their /sync token.

        Returns:
            True if the user should have full presence sent to them, False otherwise.
        """

        def _should_user_receive_full_presence_with_token_txn(txn):
            sql = """
                SELECT 1 FROM users_to_send_full_presence_to
                WHERE user_id = ?
                AND presence_stream_id >= ?
            """
            txn.execute(sql, (user_id, from_token))
            return bool(txn.fetchone())

        return await self.db_pool.runInteraction(
            "should_user_receive_full_presence_with_token",
            _should_user_receive_full_presence_with_token_txn,
        )

    async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]):
        """Adds to the list of users who should receive a full snapshot of presence
        upon their next sync.

        Args:
            user_ids: An iterable of user IDs. May be a one-shot iterator (for
                example a generator); it is consumed exactly once.
        """
        # `user_ids` is needed twice below (keys and values). Materialise it
        # first so a generator isn't exhausted by the first pass, which would
        # leave the value list empty and mismatched with the keys.
        user_ids = list(user_ids)

        # Add user entries to the table, updating the presence_stream_id column if the user already
        # exists in the table.
        presence_stream_id = self._presence_id_gen.get_current_token()
        await self.db_pool.simple_upsert_many(
            table="users_to_send_full_presence_to",
            key_names=("user_id",),
            key_values=[(user_id,) for user_id in user_ids],
            value_names=("presence_stream_id",),
            # We save the current presence stream ID token along with the user ID
            # entry so that when a user /syncs, even if they sync multiple times
            # across separate devices at different times, each device will receive
            # full presence once: when the presence stream ID in their sync token is
            # less than the one in the table for their user ID.
            value_values=[(presence_stream_id,) for _ in user_ids],
            desc="add_users_to_send_full_presence_to",
        )

    async def get_presence_for_all_users(
        self,
        include_offline: bool = True,
    ) -> Dict[str, UserPresenceState]:
        """Retrieve the current presence state for all users.

        Note that the presence_stream table is culled frequently, so it should only
        contain the latest presence state for each user.

        Args:
            include_offline: Whether to include offline presence states

        Returns:
            A dict of user IDs to their current UserPresenceState.
        """
        users_to_state: Dict[str, UserPresenceState] = {}

        exclude_keyvalues = None
        if not include_offline:
            # Exclude offline presence state
            exclude_keyvalues = {"state": "offline"}

        # This may be a very heavy database query.
        # We paginate in order to not block a database connection.
        limit = 100
        offset = 0
        while True:
            rows = await self.db_pool.runInteraction(
                "get_presence_for_all_users",
                self.db_pool.simple_select_list_paginate_txn,
                "presence_stream",
                orderby="stream_id",
                start=offset,
                limit=limit,
                exclude_keyvalues=exclude_keyvalues,
                retcols=(
                    "user_id",
                    "state",
                    "last_active_ts",
                    "last_federation_update_ts",
                    "last_user_sync_ts",
                    "status_msg",
                    "currently_active",
                ),
                order_direction="ASC",
            )

            for row in rows:
                # Normalise to bool, as the other readers in this store do; the
                # stored value may presumably come back as an integer depending on
                # the database engine.
                row["currently_active"] = bool(row["currently_active"])
                users_to_state[row["user_id"]] = UserPresenceState(**row)

            # We've run out of updates to query
            if len(rows) < limit:
                break

            offset += limit

        return users_to_state

    def get_current_presence_token(self):
        """Return the presence ID generator's current token."""
        return self._presence_id_gen.get_current_token()

    def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]:
        """Fetch non-offline presence from the database so that we can register
        the appropriate time outs.

        Args:
            db_conn: A raw database connection (used at startup, outside the
                usual `runInteraction` machinery).

        Returns:
            A UserPresenceState for every row in `presence_stream` whose state
            is not offline.
        """

        # The `presence_stream_state_not_offline_idx` index should be used for this
        # query.
        sql = (
            "SELECT user_id, state, last_active_ts, last_federation_update_ts,"
            " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
            " WHERE state != ?"
        )

        txn = db_conn.cursor()
        try:
            txn.execute(sql, (PresenceState.OFFLINE,))
            rows = self.db_pool.cursor_to_dict(txn)
        finally:
            # Always release the cursor, even if the query fails.
            txn.close()

        for row in rows:
            row["currently_active"] = bool(row["currently_active"])

        return [UserPresenceState(**row) for row in rows]

    def take_presence_startup_info(self):
        """Return the non-offline presence loaded at startup and drop our reference.

        The first call returns the list built in `__init__`. Subsequent calls
        return None.
        """
        active_on_startup = self._presence_on_startup
        self._presence_on_startup = None
        return active_on_startup

    def process_replication_rows(self, stream_name, instance_name, token, rows):
        """Handle replicated rows.

        For the presence stream, this advances the ID generator and invalidates
        the per-user caches for each updated user.
        """
        if stream_name == PresenceStream.NAME:
            self._presence_id_gen.advance(instance_name, token)
            for row in rows:
                self.presence_stream_cache.entity_has_changed(row.user_id, token)
                self._get_presence_for_user.invalidate((row.user_id,))
        return super().process_replication_rows(stream_name, instance_name, token, rows)
383