1# Copyright 2019 New Vector Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import logging
16from typing import List, Optional, Tuple, Union, cast
17
18import attr
19
20from synapse.api.constants import RelationTypes
21from synapse.events import EventBase
22from synapse.storage._base import SQLBaseStore
23from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
24from synapse.storage.databases.main.stream import generate_pagination_where_clause
25from synapse.storage.relations import (
26    AggregationPaginationToken,
27    PaginationChunk,
28    RelationPaginationToken,
29)
30from synapse.util.caches.descriptors import cached
31
32logger = logging.getLogger(__name__)
33
34
class RelationsWorkerStore(SQLBaseStore):
    """Queries over the `event_relations` table (joined with `events`).

    Covers paginating the relations of an event, grouping its annotations,
    finding its applicable edit and summarising its thread.
    """

    @staticmethod
    def _next_batch_row(rows: List[Tuple], limit: int, direction: str) -> Tuple:
        """Pick the fetched row whose sort key becomes the `next_batch` token.

        Callers fetch `limit + 1` rows and must only call this when more than
        `limit` rows came back and `limit` is not negative.

        NOTE: this relies on `generate_pagination_where_clause` treating a
        from-token as inclusive when paginating backwards (`"b"`) and as
        exclusive when paginating forwards (`"f"`). Backwards we therefore
        point at the first row that is *not* being returned. Forwards we point
        at the last row that *is* being returned; pointing at the first
        unreturned row there would make the next page silently skip it.

        Args:
            rows: The fetched rows, in the order they are handed to the client.
            limit: The page size requested by the caller.
            direction: `"b"` (backwards) or `"f"` (forwards).

        Returns:
            The row from which the next pagination token should be built.
        """
        if direction == "f" and limit > 0:
            return rows[limit - 1]
        # Backwards (or a degenerate zero-sized forwards page, where no row is
        # returned at all): point at the first row that is not returned.
        return rows[limit]

    @cached(tree=True)
    async def get_relations_for_event(
        self,
        event_id: str,
        room_id: str,
        relation_type: Optional[str] = None,
        event_type: Optional[str] = None,
        aggregation_key: Optional[str] = None,
        limit: int = 5,
        direction: str = "b",
        from_token: Optional[RelationPaginationToken] = None,
        to_token: Optional[RelationPaginationToken] = None,
    ) -> PaginationChunk:
        """Get a list of relations for an event, ordered by topological ordering.

        Args:
            event_id: Fetch events that relate to this event ID.
            room_id: The room the event belongs to.
            relation_type: Only fetch events with this relation type, if given.
            event_type: Only fetch events with this event type, if given.
            aggregation_key: Only fetch events with this aggregation key, if
                given (an empty string is treated as "not given").
            limit: The maximum number of events to return.
            direction: Whether to fetch the most recent first (`"b"`) or the
                oldest first (`"f"`).
            from_token: Fetch rows from the given token, or from the start if None.
            to_token: Fetch rows up to the given token, or up to the end if None.

        Returns:
            A PaginationChunk whose `chunk` is a list of `{"event_id": "..."}`
            dicts, whose `next_batch` is a token for the following page (None
            if there are no more rows) and whose `prev_batch` is `from_token`.
        """

        where_clause = ["relates_to_id = ?", "room_id = ?"]
        where_args: List[Union[str, int]] = [event_id, room_id]

        if relation_type is not None:
            where_clause.append("relation_type = ?")
            where_args.append(relation_type)

        if event_type is not None:
            where_clause.append("type = ?")
            where_args.append(event_type)

        if aggregation_key:
            where_clause.append("aggregation_key = ?")
            where_args.append(aggregation_key)

        pagination_clause = generate_pagination_where_clause(
            direction=direction,
            column_names=("topological_ordering", "stream_ordering"),
            from_token=attr.astuple(from_token) if from_token else None,  # type: ignore[arg-type]
            to_token=attr.astuple(to_token) if to_token else None,  # type: ignore[arg-type]
            engine=self.database_engine,
        )

        if pagination_clause:
            where_clause.append(pagination_clause)

        if direction == "b":
            order = "DESC"
        else:
            order = "ASC"

        sql = """
            SELECT event_id, topological_ordering, stream_ordering
            FROM event_relations
            INNER JOIN events USING (event_id)
            WHERE %s
            ORDER BY topological_ordering %s, stream_ordering %s
            LIMIT ?
        """ % (
            " AND ".join(where_clause),
            order,
            order,
        )

        def _get_recent_references_for_event_txn(
            txn: LoggingTransaction,
        ) -> PaginationChunk:
            # Fetch one extra row purely to learn whether another page exists.
            txn.execute(sql, where_args + [limit + 1])
            # Each row is (event_id, topological_ordering, stream_ordering).
            rows = cast(List[Tuple[str, int, int]], txn.fetchall())

            events = [{"event_id": row_event_id} for row_event_id, _, _ in rows[:limit]]

            next_batch = None
            if 0 <= limit < len(rows):
                _, topo_ordering, stream_ordering = self._next_batch_row(
                    rows, limit, direction
                )
                next_batch = RelationPaginationToken(topo_ordering, stream_ordering)

            return PaginationChunk(
                chunk=events, next_batch=next_batch, prev_batch=from_token
            )

        return await self.db_pool.runInteraction(
            "get_recent_references_for_event", _get_recent_references_for_event_txn
        )

    async def event_includes_relation(self, event_id: str) -> bool:
        """Check if the given event relates to another event.

        An event has a relation if it has a valid m.relates_to with a rel_type
        and event_id in the content:

        {
            "content": {
                "m.relates_to": {
                    "rel_type": "m.replace",
                    "event_id": "$other_event_id"
                }
            }
        }

        Args:
            event_id: The event to check.

        Returns:
            True if the event has a row in `event_relations`, i.e. it includes
            a valid relation.
        """

        result = await self.db_pool.simple_select_one_onecol(
            table="event_relations",
            keyvalues={"event_id": event_id},
            retcol="event_id",
            allow_none=True,
            desc="event_includes_relation",
        )
        return result is not None

    async def event_is_target_of_relation(self, parent_id: str) -> bool:
        """Check if the given event is the target of another event's relation.

        An event is the target of an event relation if it has a valid
        m.relates_to with a rel_type and event_id pointing to parent_id in the
        content:

        {
            "content": {
                "m.relates_to": {
                    "rel_type": "m.replace",
                    "event_id": "$parent_id"
                }
            }
        }

        Args:
            parent_id: The event to check.

        Returns:
            True if some row in `event_relations` has `relates_to_id` equal to
            `parent_id`.
        """

        result = await self.db_pool.simple_select_one_onecol(
            table="event_relations",
            keyvalues={"relates_to_id": parent_id},
            retcol="event_id",
            allow_none=True,
            desc="event_is_target_of_relation",
        )
        return result is not None

    @cached(tree=True)
    async def get_aggregation_groups_for_event(
        self,
        event_id: str,
        room_id: str,
        event_type: Optional[str] = None,
        limit: int = 5,
        direction: str = "b",
        from_token: Optional[AggregationPaginationToken] = None,
        to_token: Optional[AggregationPaginationToken] = None,
    ) -> PaginationChunk:
        """Get a list of annotations on the event, grouped by event type and
        aggregation key, sorted by count.

        This is used e.g. to find out which reactions an event has received,
        and how many of each.

        Args:
            event_id: Fetch events that relate to this event ID.
            room_id: The room the event belongs to.
            event_type: Only fetch events with this event type, if given.
            limit: Only fetch the `limit` groups.
            direction: Whether to fetch the highest count first (`"b"`) or
                the lowest count first (`"f"`).
            from_token: Fetch rows from the given token, or from the start if None.
            to_token: Fetch rows up to the given token, or up to the end if None.

        Returns:
            A PaginationChunk whose `chunk` is a list of groups of annotations.
            Each group is a dict with `type`, `key` and `count` fields, where
            `count` is the number of distinct senders in the group.
        """

        where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"]
        where_args: List[Union[str, int]] = [
            event_id,
            room_id,
            RelationTypes.ANNOTATION,
        ]

        if event_type:
            where_clause.append("type = ?")
            where_args.append(event_type)

        having_clause = generate_pagination_where_clause(
            direction=direction,
            column_names=("COUNT(*)", "MAX(stream_ordering)"),
            from_token=attr.astuple(from_token) if from_token else None,  # type: ignore[arg-type]
            to_token=attr.astuple(to_token) if to_token else None,  # type: ignore[arg-type]
            engine=self.database_engine,
        )

        if direction == "b":
            order = "DESC"
        else:
            order = "ASC"

        if having_clause:
            having_clause = "HAVING " + having_clause
        else:
            having_clause = ""

        # COUNT(*) is selected alongside COUNT(DISTINCT sender) because the
        # HAVING and ORDER BY clauses paginate on COUNT(*). The pagination
        # token must carry that same value. COUNT(DISTINCT sender) is what
        # gets reported to the client.
        sql = """
            SELECT type, aggregation_key, COUNT(DISTINCT sender), COUNT(*),
                MAX(stream_ordering)
            FROM event_relations
            INNER JOIN events USING (event_id)
            WHERE {where_clause}
            GROUP BY relation_type, type, aggregation_key
            {having_clause}
            ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
            LIMIT ?
        """.format(
            where_clause=" AND ".join(where_clause),
            order=order,
            having_clause=having_clause,
        )

        def _get_aggregation_groups_for_event_txn(
            txn: LoggingTransaction,
        ) -> PaginationChunk:
            # Fetch one extra group purely to learn whether another page exists.
            txn.execute(sql, where_args + [limit + 1])
            # Each row is (type, aggregation_key, distinct_senders,
            # total_count, max_stream_ordering).
            rows = cast(List[Tuple[str, str, int, int, int]], txn.fetchall())

            groups = [
                {"type": row_type, "key": row_key, "count": senders}
                for row_type, row_key, senders, _, _ in rows[:limit]
            ]

            next_batch = None
            if 0 <= limit < len(rows):
                _, _, _, total_count, max_stream = self._next_batch_row(
                    rows, limit, direction
                )
                next_batch = AggregationPaginationToken(total_count, max_stream)

            return PaginationChunk(
                chunk=groups, next_batch=next_batch, prev_batch=from_token
            )

        return await self.db_pool.runInteraction(
            "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
        )

    @cached()
    async def get_applicable_edit(
        self, event_id: str, room_id: str
    ) -> Optional[EventBase]:
        """Get the most recent edit (if any) that has happened for the given
        event.

        Correctly handles checking whether edits were allowed to happen.

        Args:
            event_id: The original event ID
            room_id: The original event's room ID

        Returns:
            The most recent edit, if any.
        """

        # We only allow edits for `m.room.message` events that have the same sender
        # and event type. We can't assert these things during regular event auth so
        # we have to do the checks post hoc.

        # Fetches latest edit that has the same type and sender as the
        # original, and is an `m.room.message`. Ties on origin_server_ts are
        # broken by event_id so the choice is deterministic.
        sql = """
            SELECT edit.event_id FROM events AS edit
            INNER JOIN event_relations USING (event_id)
            INNER JOIN events AS original ON
                original.event_id = relates_to_id
                AND edit.type = original.type
                AND edit.sender = original.sender
            WHERE
                relates_to_id = ?
                AND relation_type = ?
                AND edit.room_id = ?
                AND edit.type = 'm.room.message'
            ORDER by edit.origin_server_ts DESC, edit.event_id DESC
            LIMIT 1
        """

        def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
            txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id))
            row = txn.fetchone()
            if row:
                return row[0]
            return None

        edit_id = await self.db_pool.runInteraction(
            "get_applicable_edit", _get_applicable_edit_txn
        )

        if not edit_id:
            return None

        return await self.get_event(edit_id, allow_none=True)  # type: ignore[attr-defined]

    @cached()
    async def get_thread_summary(
        self, event_id: str, room_id: str
    ) -> Tuple[int, Optional[EventBase]]:
        """Get the number of threaded replies and the latest reply (if any)
        for the given event.

        Args:
            event_id: Summarize the thread related to this event ID.
            room_id: The room the event belongs to.

        Returns:
            The number of items in the thread and the most recent response, if
            any. `(0, None)` if the event has no threaded replies.
        """

        def _get_thread_summary_txn(
            txn: LoggingTransaction,
        ) -> Tuple[int, Optional[str]]:
            # Fetch the latest threaded event ID.
            # TODO Should this only allow m.room.message events.
            sql = """
                SELECT event_id
                FROM event_relations
                INNER JOIN events USING (event_id)
                WHERE
                    relates_to_id = ?
                    AND room_id = ?
                    AND relation_type = ?
                ORDER BY topological_ordering DESC, stream_ordering DESC
                LIMIT 1
            """

            txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
            row = txn.fetchone()
            if row is None:
                # No threaded replies at all, so skip the count query.
                return 0, None

            latest_event_id = row[0]

            # Count every threaded event.
            sql = """
                SELECT COUNT(event_id)
                FROM event_relations
                INNER JOIN events USING (event_id)
                WHERE
                    relates_to_id = ?
                    AND room_id = ?
                    AND relation_type = ?
            """
            txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
            count = cast(Tuple[int], txn.fetchone())[0]

            return count, latest_event_id

        count, latest_event_id = await self.db_pool.runInteraction(
            "get_thread_summary", _get_thread_summary_txn
        )

        latest_event = None
        if latest_event_id:
            latest_event = await self.get_event(latest_event_id, allow_none=True)  # type: ignore[attr-defined]

        return count, latest_event

    async def events_have_relations(
        self,
        parent_ids: List[str],
        relation_senders: Optional[List[str]],
        relation_types: Optional[List[str]],
    ) -> List[str]:
        """Check which events have a relationship from the given senders of the
        given types.

        Args:
            parent_ids: The events being annotated
            relation_senders: The relation senders to check.
            relation_types: The relation types to check.

        Returns:
            The IDs from `parent_ids` that have at least one relation from one
            of the given senders with one of the given types. Each ID appears
            at most once, in no particular order. If neither senders nor
            types are given, `parent_ids` is returned unchanged.
        """
        # If no restrictions are given then the event has the required relations.
        if not relation_senders and not relation_types:
            return parent_ids

        # Nothing to look up; avoid a pointless query.
        if not parent_ids:
            return []

        # DISTINCT: a parent with several matching relations is reported once.
        sql = """
            SELECT DISTINCT relates_to_id FROM event_relations
            INNER JOIN events USING (event_id)
            WHERE
                %s;
        """

        def _get_if_events_have_relations(txn: LoggingTransaction) -> List[str]:
            clauses: List[str] = []
            clause, args = make_in_list_sql_clause(
                txn.database_engine, "relates_to_id", parent_ids
            )
            clauses.append(clause)

            if relation_senders:
                clause, temp_args = make_in_list_sql_clause(
                    txn.database_engine, "sender", relation_senders
                )
                clauses.append(clause)
                args.extend(temp_args)
            if relation_types:
                clause, temp_args = make_in_list_sql_clause(
                    txn.database_engine, "relation_type", relation_types
                )
                clauses.append(clause)
                args.extend(temp_args)

            txn.execute(sql % " AND ".join(clauses), args)

            return [row[0] for row in txn]

        return await self.db_pool.runInteraction(
            "get_if_events_have_relations", _get_if_events_have_relations
        )

    async def has_user_annotated_event(
        self, parent_id: str, event_type: str, aggregation_key: str, sender: str
    ) -> bool:
        """Check if a user has already annotated an event with the same key
        (e.g. already liked an event).

        Args:
            parent_id: The event being annotated
            event_type: The event type of the annotation
            aggregation_key: The aggregation key of the annotation
            sender: The sender of the annotation

        Returns:
            True if the event is already annotated.
        """

        sql = """
            SELECT 1 FROM event_relations
            INNER JOIN events USING (event_id)
            WHERE
                relates_to_id = ?
                AND relation_type = ?
                AND type = ?
                AND sender = ?
                AND aggregation_key = ?
            LIMIT 1;
        """

        def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool:
            txn.execute(
                sql,
                (
                    parent_id,
                    RelationTypes.ANNOTATION,
                    event_type,
                    sender,
                    aggregation_key,
                ),
            )

            return bool(txn.fetchone())

        return await self.db_pool.runInteraction(
            "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
        )
518
class RelationsStore(RelationsWorkerStore):
    """Relations store.

    Currently identical to RelationsWorkerStore; it adds no behaviour of its
    own.
    """
521