# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple

from synapse.api.presence import PresenceState, UserPresenceState
from synapse.replication.tcp.streams import PresenceStream
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.iterutils import batch_iter

if TYPE_CHECKING:
    from synapse.server import HomeServer


class PresenceBackgroundUpdateStore(SQLBaseStore):
    """Registers the background updates needed by the presence store."""

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        # Used by `PresenceStore._get_active_presence()`: a partial index over
        # the rows whose state is not 'offline'.
        self.db_pool.updates.register_background_index_update(
            "presence_stream_not_offline_index",
            index_name="presence_stream_state_not_offline_idx",
            table="presence_stream",
            columns=["state"],
            where_clause="state != 'offline'",
        )


class PresenceStore(PresenceBackgroundUpdateStore):
    """Storage for user presence, backed by the `presence_stream` table."""

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        # Only instances listed as presence writers in the config may persist
        # presence (checked in `update_presence`).
        self._can_persist_presence = (
            hs.get_instance_name() in hs.config.worker.writers.presence
        )

        # Postgres supports multiple presence writers sharing a DB sequence;
        # other engines fall back to a single-writer stream ID generator.
        if isinstance(database.engine, PostgresEngine):
            self._presence_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                stream_name="presence_stream",
                instance_name=self._instance_name,
                tables=[("presence_stream", "instance_name", "stream_id")],
                sequence_name="presence_stream_sequence",
                writers=hs.config.worker.writers.presence,
            )
        else:
            self._presence_id_gen = StreamIdGenerator(
                db_conn, "presence_stream", "stream_id"
            )

        self.hs = hs
        # Non-offline presence loaded at startup; handed out (once) via
        # `take_presence_startup_info`.
        self._presence_on_startup = self._get_active_presence(db_conn)

        # Prefill the change cache from the table so recent changes can be
        # answered without hitting the database.
        presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
            db_conn,
            "presence_stream",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self._presence_id_gen.get_current_token(),
        )
        self.presence_stream_cache = StreamChangeCache(
            "PresenceStreamChangeCache",
            min_presence_val,
            prefilled_cache=presence_cache_prefill,
        )

    async def update_presence(self, presence_states) -> Tuple[int, int]:
        """Persist new presence states, allocating one stream ID per state.

        Args:
            presence_states: The `UserPresenceState`s to write. Must support
                `len()`, as one stream ID is allocated per entry.

        Returns:
            A tuple of (the last stream ID allocated for these states, the
            current presence stream token after the write completed).
        """
        assert self._can_persist_presence

        stream_ordering_manager = self._presence_id_gen.get_next_mult(
            len(presence_states)
        )

        async with stream_ordering_manager as stream_orderings:
            await self.db_pool.runInteraction(
                "update_presence",
                self._update_presence_txn,
                stream_orderings,
                presence_states,
            )

        # Read the current token only after the context manager has exited.
        return stream_orderings[-1], self._presence_id_gen.get_current_token()

    def _update_presence_txn(self, txn, stream_orderings, presence_states):
        """Transaction body for `update_presence`.

        Replaces each user's rows in `presence_stream` with the new state and
        schedules cache invalidation for after the transaction commits.
        """
        for stream_id, state in zip(stream_orderings, presence_states):
            txn.call_after(
                self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
            )
            txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))

        # Delete old rows to stop database from getting really big
        sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "

        for states in batch_iter(presence_states, 50):
            clause, args = make_in_list_sql_clause(
                self.database_engine, "user_id", [s.user_id for s in states]
            )
            # `stream_id` is the value left over from the loop above, i.e. the
            # stream ID paired with the final state. This loop body only runs
            # when `presence_states` is non-empty, so it is always bound here.
            txn.execute(sql + clause, [stream_id] + list(args))

        # Actually insert new rows
        self.db_pool.simple_insert_many_txn(
            txn,
            table="presence_stream",
            values=[
                {
                    "stream_id": stream_id,
                    "user_id": state.user_id,
                    "state": state.state,
                    "last_active_ts": state.last_active_ts,
                    "last_federation_update_ts": state.last_federation_update_ts,
                    "last_user_sync_ts": state.last_user_sync_ts,
                    "status_msg": state.status_msg,
                    "currently_active": state.currently_active,
                    "instance_name": self._instance_name,
                }
                for stream_id, state in zip(stream_orderings, presence_states)
            ],
        )

    async def get_all_presence_updates(
        self, instance_name: str, last_id: int, current_id: int, limit: int
    ) -> Tuple[List[Tuple[int, list]], int, bool]:
        """Get updates for presence replication stream.

        Args:
            instance_name: The writer we want to fetch updates from. Unused:
                the query below does not filter by writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exists
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data
        """

        if last_id == current_id:
            return [], current_id, False

        def get_all_presence_updates_txn(txn):
            sql = """
                SELECT stream_id, user_id, state, last_active_ts,
                    last_federation_update_ts, last_user_sync_ts,
                    status_msg,
                currently_active
                FROM presence_stream
                WHERE ? < stream_id AND stream_id <= ?
                ORDER BY stream_id ASC
                LIMIT ?
            """
            txn.execute(sql, (last_id, current_id, limit))
            updates = [(row[0], row[1:]) for row in txn]

            upper_bound = current_id
            limited = False
            if len(updates) >= limit:
                # We may have been cut off: resume from the last row we saw.
                upper_bound = updates[-1][0]
                limited = True

            return updates, upper_bound, limited

        return await self.db_pool.runInteraction(
            "get_all_presence_updates", get_all_presence_updates_txn
        )

    @cached()
    def _get_presence_for_user(self, user_id):
        # Cache placeholder: entries are populated by `get_presence_for_users`
        # (via `cachedList`) and invalidated when presence changes.
        raise NotImplementedError()

    @cachedList(
        cached_method_name="_get_presence_for_user",
        list_name="user_ids",
        num_args=1,
    )
    async def get_presence_for_users(self, user_ids):
        """Fetch the stored presence of the given users.

        Args:
            user_ids: The users to look up.

        Returns:
            A dict of user ID to `UserPresenceState`. The dict built here only
            contains users that have a row in `presence_stream`.
        """
        rows = await self.db_pool.simple_select_many_batch(
            table="presence_stream",
            column="user_id",
            iterable=user_ids,
            keyvalues={},
            retcols=(
                "user_id",
                "state",
                "last_active_ts",
                "last_federation_update_ts",
                "last_user_sync_ts",
                "status_msg",
                "currently_active",
            ),
            desc="get_presence_for_users",
        )

        for row in rows:
            # Normalise to a real bool (e.g. SQLite returns 0/1 integers).
            row["currently_active"] = bool(row["currently_active"])

        return {row["user_id"]: UserPresenceState(**row) for row in rows}

    async def should_user_receive_full_presence_with_token(
        self,
        user_id: str,
        from_token: int,
    ) -> bool:
        """Check whether the given user should receive full presence using the stream token
        they're updating from.

        Args:
            user_id: The ID of the user to check.
            from_token: The stream token included in their /sync token.

        Returns:
            True if the user should have full presence sent to them, False otherwise.
        """

        def _should_user_receive_full_presence_with_token_txn(txn):
            sql = """
                SELECT 1 FROM users_to_send_full_presence_to
                WHERE user_id = ?
                AND presence_stream_id >= ?
            """
            txn.execute(sql, (user_id, from_token))
            return bool(txn.fetchone())

        return await self.db_pool.runInteraction(
            "should_user_receive_full_presence_with_token",
            _should_user_receive_full_presence_with_token_txn,
        )

    async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]):
        """Adds to the list of users who should receive a full snapshot of presence
        upon their next sync.

        Args:
            user_ids: An iterable of user IDs. Any iterable is accepted,
                including one-shot iterators such as generators.
        """
        # Materialise the iterable: it is walked twice below (once for the keys
        # and once for the values). A one-shot iterator would otherwise be
        # exhausted by the first pass, leaving the value list empty and out of
        # step with the key list.
        user_ids = list(user_ids)

        # Add user entries to the table, updating the presence_stream_id column if the user already
        # exists in the table.
        presence_stream_id = self._presence_id_gen.get_current_token()
        await self.db_pool.simple_upsert_many(
            table="users_to_send_full_presence_to",
            key_names=("user_id",),
            key_values=[(user_id,) for user_id in user_ids],
            value_names=("presence_stream_id",),
            # We save the current presence stream ID token along with the user ID entry so
            # that when a user /syncs, even if they sync multiple times across separate
            # devices at different times, each device will receive full presence once - when
            # the presence stream ID in their sync token is less than the one in the table
            # for their user ID.
            value_values=[(presence_stream_id,) for _ in user_ids],
            desc="add_users_to_send_full_presence_to",
        )

    async def get_presence_for_all_users(
        self,
        include_offline: bool = True,
    ) -> Dict[str, UserPresenceState]:
        """Retrieve the current presence state for all users.

        Note that the presence_stream table is culled frequently, so it should only
        contain the latest presence state for each user.

        Args:
            include_offline: Whether to include offline presence states

        Returns:
            A dict of user IDs to their current UserPresenceState.
        """
        users_to_state = {}

        exclude_keyvalues = None
        if not include_offline:
            # Exclude offline presence state
            exclude_keyvalues = {"state": "offline"}

        # This may be a very heavy database query.
        # We paginate in order to not block a database connection.
        limit = 100
        offset = 0
        while True:
            rows = await self.db_pool.runInteraction(
                "get_presence_for_all_users",
                self.db_pool.simple_select_list_paginate_txn,
                "presence_stream",
                orderby="stream_id",
                start=offset,
                limit=limit,
                exclude_keyvalues=exclude_keyvalues,
                retcols=(
                    "user_id",
                    "state",
                    "last_active_ts",
                    "last_federation_update_ts",
                    "last_user_sync_ts",
                    "status_msg",
                    "currently_active",
                ),
                order_direction="ASC",
            )

            for row in rows:
                # Normalise to a real bool, matching `get_presence_for_users`
                # and `_get_active_presence` (e.g. SQLite returns 0/1 integers).
                row["currently_active"] = bool(row["currently_active"])
                users_to_state[row["user_id"]] = UserPresenceState(**row)

            # We've run out of updates to query
            if len(rows) < limit:
                break

            offset += limit

        return users_to_state

    def get_current_presence_token(self) -> int:
        """Return the current token of the presence stream."""
        return self._presence_id_gen.get_current_token()

    def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]:
        """Fetch non-offline presence from the database so that we can register
        the appropriate time outs.

        Runs synchronously on the raw connection (it is called from `__init__`).
        """

        # The `presence_stream_state_not_offline_idx` index should be used for this
        # query.
        sql = (
            "SELECT user_id, state, last_active_ts, last_federation_update_ts,"
            " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
            " WHERE state != ?"
        )

        txn = db_conn.cursor()
        try:
            txn.execute(sql, (PresenceState.OFFLINE,))
            rows = self.db_pool.cursor_to_dict(txn)
        finally:
            # Always release the cursor, even if the query or row conversion fails.
            txn.close()

        for row in rows:
            row["currently_active"] = bool(row["currently_active"])

        return [UserPresenceState(**row) for row in rows]

    def take_presence_startup_info(self):
        """Return the presence loaded at startup and drop our reference to it.

        Subsequent calls return None.
        """
        active_on_startup = self._presence_on_startup
        self._presence_on_startup = None
        return active_on_startup

    def process_replication_rows(self, stream_name, instance_name, token, rows):
        """Advance the presence stream position and invalidate caches for
        presence rows received over replication."""
        if stream_name == PresenceStream.NAME:
            self._presence_id_gen.advance(instance_name, token)
            for row in rows:
                self.presence_stream_cache.entity_has_changed(row.user_id, token)
                self._get_presence_for_user.invalidate((row.user_id,))
        return super().process_replication_rows(stream_name, instance_name, token, rows)