# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import List, Optional, Tuple, Union, cast

import attr

from synapse.api.constants import RelationTypes
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.relations import (
    AggregationPaginationToken,
    PaginationChunk,
    RelationPaginationToken,
)
from synapse.util.caches.descriptors import cached

logger = logging.getLogger(__name__)


class RelationsWorkerStore(SQLBaseStore):
    @cached(tree=True)
    async def get_relations_for_event(
        self,
        event_id: str,
        room_id: str,
        relation_type: Optional[str] = None,
        event_type: Optional[str] = None,
        aggregation_key: Optional[str] = None,
        limit: int = 5,
        direction: str = "b",
        from_token: Optional[RelationPaginationToken] = None,
        to_token: Optional[RelationPaginationToken] = None,
    ) -> PaginationChunk:
        """Get a list of relations for an event, ordered by topological ordering.

        Args:
            event_id: Fetch events that relate to this event ID.
            room_id: The room the event belongs to.
            relation_type: Only fetch events with this relation type, if given.
            event_type: Only fetch events with this event type, if given.
            aggregation_key: Only fetch events with this aggregation key, if given.
            limit: Only fetch the most recent `limit` events.
            direction: Whether to fetch the most recent first (`"b"`) or the
                oldest first (`"f"`).
            from_token: Fetch rows from the given token, or from the start if None.
            to_token: Fetch rows up to the given token, or up to the end if None.

        Returns:
            A PaginationChunk whose `chunk` is a list of dicts of the form
            `{"event_id": "..."}`, whose `next_batch` is a token for the next
            page (or None if there are no more rows), and whose `prev_batch`
            is `from_token`.
        """

        where_clause = ["relates_to_id = ?", "room_id = ?"]
        where_args: List[Union[str, int]] = [event_id, room_id]

        if relation_type is not None:
            where_clause.append("relation_type = ?")
            where_args.append(relation_type)

        if event_type is not None:
            where_clause.append("type = ?")
            where_args.append(event_type)

        if aggregation_key:
            where_clause.append("aggregation_key = ?")
            where_args.append(aggregation_key)

        pagination_clause = generate_pagination_where_clause(
            direction=direction,
            column_names=("topological_ordering", "stream_ordering"),
            from_token=attr.astuple(from_token) if from_token else None,  # type: ignore[arg-type]
            to_token=attr.astuple(to_token) if to_token else None,  # type: ignore[arg-type]
            engine=self.database_engine,
        )

        if pagination_clause:
            where_clause.append(pagination_clause)

        if direction == "b":
            order = "DESC"
        else:
            order = "ASC"

        sql = """
            SELECT event_id, topological_ordering, stream_ordering
            FROM event_relations
            INNER JOIN events USING (event_id)
            WHERE %s
            ORDER BY topological_ordering %s, stream_ordering %s
            LIMIT ?
        """ % (
            " AND ".join(where_clause),
            order,
            order,
        )

        def _get_recent_references_for_event_txn(
            txn: LoggingTransaction,
        ) -> PaginationChunk:
            # Fetch one row more than requested so we can tell whether there
            # is another page after this one.
            txn.execute(sql, where_args + [limit + 1])
            rows = txn.fetchall()

            events = [{"event_id": row[0]} for row in rows[:limit]]

            next_batch = None
            if len(rows) > limit:
                # NOTE: relies on generate_pagination_where_clause treating a
                # `from_token` as inclusive when paginating backwards and
                # exclusive when paginating forwards. Backwards, point at the
                # first row we did not return, so the next page starts with it.
                # Forwards, point at the last row we did return; using the
                # extra row here would cause it to be skipped on the next page.
                boundary_row = rows[limit] if direction == "b" else rows[limit - 1]
                topo_id, stream_id = boundary_row[1], boundary_row[2]
                if topo_id is not None and stream_id is not None:
                    next_batch = RelationPaginationToken(topo_id, stream_id)

            return PaginationChunk(
                chunk=events, next_batch=next_batch, prev_batch=from_token
            )

        return await self.db_pool.runInteraction(
            "get_recent_references_for_event", _get_recent_references_for_event_txn
        )

    async def event_includes_relation(self, event_id: str) -> bool:
        """Check if the given event relates to another event.

        An event has a relation if it has a valid m.relates_to with a rel_type
        and event_id in the content:

            {
                "content": {
                    "m.relates_to": {
                        "rel_type": "m.replace",
                        "event_id": "$other_event_id"
                    }
                }
            }

        Args:
            event_id: The event to check.

        Returns:
            True if the event includes a valid relation.
        """

        result = await self.db_pool.simple_select_one_onecol(
            table="event_relations",
            keyvalues={"event_id": event_id},
            retcol="event_id",
            allow_none=True,
            desc="event_includes_relation",
        )
        return result is not None

    async def event_is_target_of_relation(self, parent_id: str) -> bool:
        """Check if the given event is the target of another event's relation.

        An event is the target of an event relation if it has a valid
        m.relates_to with a rel_type and event_id pointing to parent_id in the
        content:

            {
                "content": {
                    "m.relates_to": {
                        "rel_type": "m.replace",
                        "event_id": "$parent_id"
                    }
                }
            }

        Args:
            parent_id: The event to check.

        Returns:
            True if the event is the target of another event's relation.
        """

        result = await self.db_pool.simple_select_one_onecol(
            table="event_relations",
            keyvalues={"relates_to_id": parent_id},
            retcol="event_id",
            allow_none=True,
            desc="event_is_target_of_relation",
        )
        return result is not None

    @cached(tree=True)
    async def get_aggregation_groups_for_event(
        self,
        event_id: str,
        room_id: str,
        event_type: Optional[str] = None,
        limit: int = 5,
        direction: str = "b",
        from_token: Optional[AggregationPaginationToken] = None,
        to_token: Optional[AggregationPaginationToken] = None,
    ) -> PaginationChunk:
        """Get a list of annotations on the event, grouped by event type and
        aggregation key, sorted by count.

        This is used e.g. to get what and how many reactions have happened
        on an event.

        Args:
            event_id: Fetch events that relate to this event ID.
            room_id: The room the event belongs to.
            event_type: Only fetch events with this event type, if given.
            limit: Only fetch the `limit` groups.
            direction: Whether to fetch the highest count first (`"b"`) or
                the lowest count first (`"f"`).
            from_token: Fetch rows from the given token, or from the start if None.
            to_token: Fetch rows up to the given token, or up to the end if None.

        Returns:
            A PaginationChunk whose `chunk` is a list of groups of annotations
            that match. Each group is a dict with `type`, `key` and `count`
            fields, where `count` is the number of distinct senders.
        """

        where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"]
        where_args: List[Union[str, int]] = [
            event_id,
            room_id,
            RelationTypes.ANNOTATION,
        ]

        if event_type:
            where_clause.append("type = ?")
            where_args.append(event_type)

        having_clause = generate_pagination_where_clause(
            direction=direction,
            column_names=("COUNT(*)", "MAX(stream_ordering)"),
            from_token=attr.astuple(from_token) if from_token else None,  # type: ignore[arg-type]
            to_token=attr.astuple(to_token) if to_token else None,  # type: ignore[arg-type]
            engine=self.database_engine,
        )

        if direction == "b":
            order = "DESC"
        else:
            order = "ASC"

        if having_clause:
            having_clause = "HAVING " + having_clause
        else:
            having_clause = ""

        # COUNT(*) is selected separately from COUNT(DISTINCT sender) so that
        # pagination tokens hold exactly the expressions used by the HAVING
        # and ORDER BY clauses; the distinct-sender count is what callers see.
        sql = """
            SELECT
                type, aggregation_key, COUNT(DISTINCT sender), COUNT(*),
                MAX(stream_ordering)
            FROM event_relations
            INNER JOIN events USING (event_id)
            WHERE {where_clause}
            GROUP BY relation_type, type, aggregation_key
            {having_clause}
            ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
            LIMIT ?
        """.format(
            where_clause=" AND ".join(where_clause),
            order=order,
            having_clause=having_clause,
        )

        def _get_aggregation_groups_for_event_txn(
            txn: LoggingTransaction,
        ) -> PaginationChunk:
            # Fetch one group more than requested to detect another page.
            txn.execute(sql, where_args + [limit + 1])
            rows = txn.fetchall()

            events = [
                {"type": row[0], "key": row[1], "count": row[2]}
                for row in rows[:limit]
            ]

            next_batch = None
            if len(rows) > limit:
                # NOTE: same inclusive/exclusive convention as in
                # get_relations_for_event. Backwards, the token points at the
                # first group not returned. Forwards, it points at the last
                # group returned, so that group is not skipped.
                boundary_row = rows[limit] if direction == "b" else rows[limit - 1]
                next_batch = AggregationPaginationToken(
                    boundary_row[3], boundary_row[4]
                )

            return PaginationChunk(
                chunk=events, next_batch=next_batch, prev_batch=from_token
            )

        return await self.db_pool.runInteraction(
            "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
        )

    @cached()
    async def get_applicable_edit(
        self, event_id: str, room_id: str
    ) -> Optional[EventBase]:
        """Get the most recent edit (if any) that has happened for the given
        event.

        Correctly handles checking whether edits were allowed to happen.

        Args:
            event_id: The original event ID
            room_id: The original event's room ID

        Returns:
            The most recent edit, if any.
        """

        # We only allow edits for `m.room.message` events that have the same sender
        # and event type. We can't assert these things during regular event auth so
        # we have to do the checks post hoc.

        # Fetches latest edit that has the same type and sender as the
        # original, and is an `m.room.message`.
        sql = """
            SELECT edit.event_id FROM events AS edit
            INNER JOIN event_relations USING (event_id)
            INNER JOIN events AS original ON
                original.event_id = relates_to_id
                AND edit.type = original.type
                AND edit.sender = original.sender
            WHERE
                relates_to_id = ?
                AND relation_type = ?
                AND edit.room_id = ?
                AND edit.type = 'm.room.message'
            ORDER by edit.origin_server_ts DESC, edit.event_id DESC
            LIMIT 1
        """

        def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
            txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id))
            row = txn.fetchone()
            if row:
                return row[0]
            return None

        edit_id = await self.db_pool.runInteraction(
            "get_applicable_edit", _get_applicable_edit_txn
        )

        if not edit_id:
            return None

        return await self.get_event(edit_id, allow_none=True)  # type: ignore[attr-defined]

    @cached()
    async def get_thread_summary(
        self, event_id: str, room_id: str
    ) -> Tuple[int, Optional[EventBase]]:
        """Get the number of threaded replies and the latest reply (if any)
        for the given event.

        Args:
            event_id: Summarize the thread related to this event ID.
            room_id: The room the event belongs to.

        Returns:
            A tuple of the number of items in the thread and the most recent
            response, if any.
        """

        def _get_thread_summary_txn(
            txn: LoggingTransaction,
        ) -> Tuple[int, Optional[str]]:
            # Fetch the latest event ID in the thread.
            # TODO Should this only allow m.room.message events.
            sql = """
                SELECT event_id
                FROM event_relations
                INNER JOIN events USING (event_id)
                WHERE
                    relates_to_id = ?
                    AND room_id = ?
                    AND relation_type = ?
                ORDER BY topological_ordering DESC, stream_ordering DESC
                LIMIT 1
            """

            txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
            row = txn.fetchone()
            if row is None:
                return 0, None

            latest_event_id = row[0]

            # Then count all the threaded events.
            sql = """
                SELECT COUNT(event_id)
                FROM event_relations
                INNER JOIN events USING (event_id)
                WHERE
                    relates_to_id = ?
                    AND room_id = ?
                    AND relation_type = ?
            """
            txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
            # An aggregate query without GROUP BY always yields one row.
            count = cast(Tuple[int], txn.fetchone())[0]

            return count, latest_event_id

        count, latest_event_id = await self.db_pool.runInteraction(
            "get_thread_summary", _get_thread_summary_txn
        )

        latest_event = None
        if latest_event_id:
            latest_event = await self.get_event(latest_event_id, allow_none=True)  # type: ignore[attr-defined]

        return count, latest_event

    async def events_have_relations(
        self,
        parent_ids: List[str],
        relation_senders: Optional[List[str]],
        relation_types: Optional[List[str]],
    ) -> List[str]:
        """Check which events have a relationship from the given senders of the
        given types.

        Args:
            parent_ids: The events being annotated
            relation_senders: The relation senders to check.
            relation_types: The relation types to check.

        Returns:
            The subset of `parent_ids` which have at least one relation from one
            of the given senders of one of the given types. If neither senders
            nor types are given, `parent_ids` is returned unchanged.
        """
        # If no restrictions are given then the event has the required relations.
        if not relation_senders and not relation_types:
            return parent_ids

        # Nothing to look up; avoid a pointless database round-trip.
        if not parent_ids:
            return []

        # DISTINCT: a parent with several matching relations is reported once.
        sql = """
            SELECT DISTINCT relates_to_id FROM event_relations
            INNER JOIN events USING (event_id)
            WHERE
                %s;
        """

        def _get_if_events_have_relations(txn: LoggingTransaction) -> List[str]:
            clauses: List[str] = []
            clause, args = make_in_list_sql_clause(
                txn.database_engine, "relates_to_id", parent_ids
            )
            clauses.append(clause)

            if relation_senders:
                clause, temp_args = make_in_list_sql_clause(
                    txn.database_engine, "sender", relation_senders
                )
                clauses.append(clause)
                args.extend(temp_args)
            if relation_types:
                clause, temp_args = make_in_list_sql_clause(
                    txn.database_engine, "relation_type", relation_types
                )
                clauses.append(clause)
                args.extend(temp_args)

            txn.execute(sql % " AND ".join(clauses), args)

            return [row[0] for row in txn]

        return await self.db_pool.runInteraction(
            "get_if_events_have_relations", _get_if_events_have_relations
        )

    async def has_user_annotated_event(
        self, parent_id: str, event_type: str, aggregation_key: str, sender: str
    ) -> bool:
        """Check if a user has already annotated an event with the same key
        (e.g. already liked an event).

        Args:
            parent_id: The event being annotated
            event_type: The event type of the annotation
            aggregation_key: The aggregation key of the annotation
            sender: The sender of the annotation

        Returns:
            True if the event is already annotated.
        """

        sql = """
            SELECT 1 FROM event_relations
            INNER JOIN events USING (event_id)
            WHERE
                relates_to_id = ?
                AND relation_type = ?
                AND type = ?
                AND sender = ?
                AND aggregation_key = ?
            LIMIT 1;
        """

        def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool:
            txn.execute(
                sql,
                (
                    parent_id,
                    RelationTypes.ANNOTATION,
                    event_type,
                    sender,
                    aggregation_key,
                ),
            )

            return bool(txn.fetchone())

        return await self.db_pool.runInteraction(
            "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
        )


class RelationsStore(RelationsWorkerStore):
    pass