1# Copyright 2014-2016 OpenMarket Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14from typing import TYPE_CHECKING, List, Optional, Tuple, Union
15
16import attr
17from frozendict import frozendict
18
19from twisted.internet.defer import Deferred
20
21from synapse.appservice import ApplicationService
22from synapse.events import EventBase
23from synapse.logging.context import make_deferred_yieldable, run_in_background
24from synapse.types import JsonDict, StateMap
25
26if TYPE_CHECKING:
27    from synapse.storage import Storage
28    from synapse.storage.databases.main import DataStore
29
30
@attr.s(slots=True)
class EventContext:
    """
    Holds information relevant to persisting an event

    Attributes:
        rejected: A rejection reason if the event was rejected, else False

        _state_group: The ID of the state group for this event. Note that state events
            are persisted with a state group which includes the new event, so this is
            effectively the state *after* the event in question.

            For a *rejected* state event, where the state of the rejected event is
            ignored, this state_group should never make it into the
            event_to_state_groups table. Indeed, inspecting this value for a rejected
            state event is almost certainly incorrect.

            For an outlier, where we don't have the state at the event, this will be
            None.

            Note that this is a private attribute: it should be accessed via
            the ``state_group`` property.

        state_group_before_event: The ID of the state group representing the state
            of the room before this event.

            If this is a non-state event, this will be the same as ``state_group``. If
            it's a state event, it will be the same as ``prev_group``.

            If ``state_group`` is None (ie, the event is an outlier),
            ``state_group_before_event`` will always also be ``None``.

        prev_group: If it is known, ``state_group``'s prev_group. Note that this being
            None does not necessarily mean that ``state_group`` does not have
            a prev_group!

            If the event is a state event, this is normally the same as
            ``state_group_before_event``.

            If ``state_group`` is None (ie, the event is an outlier), ``prev_group``
            will always also be ``None``.

            Note that this is *not* (necessarily) the state group associated with
            ``_prev_state_ids``.

        delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group``
            and ``state_group``.

        app_service: If this event is being sent by a (local) application service, that
            app service.

        _current_state_ids: The room state map, including this event - ie, the state
            in ``state_group``.

            (type, state_key) -> event_id

            For an outlier, this is {}

            Note that this is a private attribute: it should be accessed via
            ``get_current_state_ids``. The ``_AsyncEventContextImpl`` subclass
            calculates this on demand: it will be None until that happens.

        _prev_state_ids: The room state map, excluding this event - ie, the state
            in ``state_group_before_event``. For a non-state
            event, this will be the same as ``_current_state_ids``.

            Note that it is a completely different thing to prev_group!

            (type, state_key) -> event_id

            For an outlier, this is {}

            As with _current_state_ids, this is a private attribute. It should be
            accessed via get_prev_state_ids.
    """

    rejected = attr.ib(default=False, type=Union[bool, str])
    _state_group = attr.ib(default=None, type=Optional[int])
    state_group_before_event = attr.ib(default=None, type=Optional[int])
    prev_group = attr.ib(default=None, type=Optional[int])
    delta_ids = attr.ib(default=None, type=Optional[StateMap[str]])
    app_service = attr.ib(default=None, type=Optional[ApplicationService])

    _current_state_ids = attr.ib(default=None, type=Optional[StateMap[str]])
    _prev_state_ids = attr.ib(default=None, type=Optional[StateMap[str]])

    @staticmethod
    def with_state(
        state_group: Optional[int],
        state_group_before_event: Optional[int],
        current_state_ids: Optional[StateMap[str]],
        prev_state_ids: Optional[StateMap[str]],
        prev_group: Optional[int] = None,
        delta_ids: Optional[StateMap[str]] = None,
    ) -> "EventContext":
        """Build an EventContext whose state maps are already known in memory.

        Args:
            state_group: The state group for the event (the state *after* it).
            state_group_before_event: The state group for the state before the event.
            current_state_ids: The room state map including this event.
            prev_state_ids: The room state map excluding this event.
            prev_group: ``state_group``'s prev_group, if known.
            delta_ids: The state delta between ``prev_group`` and ``state_group``.

        Returns:
            A new EventContext that is not rejected and has no app service set.
        """
        return EventContext(
            current_state_ids=current_state_ids,
            prev_state_ids=prev_state_ids,
            state_group=state_group,
            state_group_before_event=state_group_before_event,
            prev_group=prev_group,
            delta_ids=delta_ids,
        )

    @staticmethod
    def for_outlier() -> "EventContext":
        """Return an EventContext instance suitable for persisting an outlier event"""
        return EventContext(
            current_state_ids={},
            prev_state_ids={},
        )

    async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
        """Converts self to a type that can be serialized as JSON, and then
        deserialized by `deserialize`

        Args:
            event: The event that this context relates to
            store: Not used by this method.

        Returns:
            The serialized event context, as a JSON-compatible dict.
        """

        # We don't serialize the full state dicts, instead they get pulled out
        # of the DB on the other side. However, the other side can't figure out
        # the prev_state_ids, so if we're a state event we include the event
        # id that we replaced in the state.
        if event.is_state():
            prev_state_ids = await self.get_prev_state_ids()
            prev_state_id = prev_state_ids.get((event.type, event.state_key))
        else:
            prev_state_id = None

        return {
            "prev_state_id": prev_state_id,
            "event_type": event.type,
            "event_state_key": event.state_key if event.is_state() else None,
            "state_group": self._state_group,
            "state_group_before_event": self.state_group_before_event,
            "rejected": self.rejected,
            "prev_group": self.prev_group,
            "delta_ids": _encode_state_dict(self.delta_ids),
            "app_service_id": self.app_service.id if self.app_service else None,
        }

    @staticmethod
    def deserialize(storage: "Storage", input: JsonDict) -> "EventContext":
        """Converts a dict that was produced by `serialize` back into a
        EventContext.

        Args:
            storage: Used to convert AS ID to AS object and fetch state.
            input: A dict produced by `serialize`

        Returns:
            The event context. Its state maps are not loaded yet; they are
            fetched from the database the first time they are requested.
        """
        context = _AsyncEventContextImpl(
            # We use the state_group and prev_state_id stuff to pull the
            # current_state_ids out of the DB and construct prev_state_ids.
            storage=storage,
            prev_state_id=input["prev_state_id"],
            event_type=input["event_type"],
            event_state_key=input["event_state_key"],
            state_group=input["state_group"],
            state_group_before_event=input["state_group_before_event"],
            prev_group=input["prev_group"],
            delta_ids=_decode_state_dict(input["delta_ids"]),
            rejected=input["rejected"],
        )

        app_service_id = input["app_service_id"]
        if app_service_id:
            context.app_service = storage.main.get_app_service_by_id(app_service_id)

        return context

    @property
    def state_group(self) -> Optional[int]:
        """The ID of the state group for this event.

        Note that state events are persisted with a state group which includes the new
        event, so this is effectively the state *after* the event in question.

        For an outlier, where we don't have the state at the event, this will be None.

        It is an error to access this for a rejected event, since rejected state should
        not make it into the room state. Accessing this property will raise an exception
        if ``rejected`` is set.
        """
        if self.rejected:
            raise RuntimeError("Attempt to access state_group of rejected event")

        return self._state_group

    async def get_current_state_ids(self) -> Optional[StateMap[str]]:
        """
        Gets the room state map, including this event - ie, the state in ``state_group``

        It is an error to access this for a rejected event, since rejected state should
        not make it into the room state. This method will raise an exception if
        ``rejected`` is set.

        Returns:
            Maps a (type, state_key) to the event ID of the state event matching
            this tuple.

            For an outlier (where ``state_group`` is None) this is normally an
            empty dict. It may be None if the context was built (e.g. via
            ``with_state``) without a current state map.

        Raises:
            RuntimeError: if ``rejected`` is set.
        """
        if self.rejected:
            raise RuntimeError("Attempt to access state_ids of rejected event")

        await self._ensure_fetched()
        return self._current_state_ids

    async def get_prev_state_ids(self) -> StateMap[str]:
        """
        Gets the room state map, excluding this event.

        For a non-state event, this will be the same as get_current_state_ids().

        This method does not itself check ``rejected``.

        Returns:
            Returns {} if state_group is None, which happens when the associated
            event is an outlier.

            Maps a (type, state_key) to the event ID of the state event matching
            this tuple.
        """
        await self._ensure_fetched()
        # There *should* be previous state IDs now.
        assert self._prev_state_ids is not None
        return self._prev_state_ids

    def get_cached_current_state_ids(self) -> Optional[StateMap[str]]:
        """Gets the current state IDs if we have them already cached.

        It is an error to access this for a rejected event, since rejected state should
        not make it into the room state. This method will raise an exception if
        ``rejected`` is set.

        Returns:
            Returns None if we haven't cached the state or if state_group is None
            (which happens when the associated event is an outlier).

            Otherwise, returns the current state IDs.
        """
        if self.rejected:
            raise RuntimeError("Attempt to access state_ids of rejected event")

        return self._current_state_ids

    async def _ensure_fetched(self) -> None:
        """Make sure the state maps are populated before they are read.

        The base class is always constructed with its state maps in memory, so
        there is nothing to do here. Subclasses that load state lazily override
        this.
        """
        return None
284
285
@attr.s(slots=True)
class _AsyncEventContextImpl(EventContext):
    """
    An implementation of EventContext which fetches _current_state_ids and
    _prev_state_ids from the database on demand.

    Attributes:

        _storage: Used to load the state map for ``_state_group`` from the
            database.

        _fetching_state_deferred: Resolves when *_state_ids have been calculated.
            None if we haven't started calculating yet

        _event_type: The type of the event the context is associated with.

        _event_state_key: The state_key of the event the context is associated with,
            or None if that event is not a state event.

        _prev_state_id: If the event associated with the context is a state event,
            then `_prev_state_id` is the event_id of the state that was replaced.
    """

    # These need defaults because every attribute of the parent class has one,
    # and attrs does not allow mandatory attributes to follow defaulted ones.
    _storage: "Storage" = attr.ib(default=None)
    _prev_state_id: Optional[str] = attr.ib(default=None)
    _event_type: str = attr.ib(default=None)
    _event_state_key: Optional[str] = attr.ib(default=None)
    _fetching_state_deferred: Optional["Deferred[None]"] = attr.ib(default=None)

    async def _ensure_fetched(self) -> None:
        """Start loading the state maps (at most once) and wait for them.

        The load is kicked off in the background on the first call, and its
        Deferred is cached so later and concurrent callers wait on the same load.
        """
        if self._fetching_state_deferred is None:
            self._fetching_state_deferred = run_in_background(self._fill_out_state)

        await make_deferred_yieldable(self._fetching_state_deferred)

    async def _fill_out_state(self) -> None:
        """Called to populate the _current_state_ids and _prev_state_ids
        attributes by loading from the database.
        """
        # Read the private attribute rather than the ``state_group`` property:
        # the property raises for rejected events. That would make
        # get_prev_state_ids() (which has no rejection check) fail for a
        # deserialized rejected event, even though it works for an in-memory
        # EventContext.
        state_group = self._state_group
        if state_group is None:
            # No state group means the event is an outlier. Usually the state_ids dicts are also
            # pre-set to empty dicts, but they get reset when the context is serialized, so set
            # them to empty dicts again here.
            self._current_state_ids = {}
            self._prev_state_ids = {}
            return

        current_state_ids = await self._storage.state.get_state_ids_for_group(
            state_group
        )
        # Set this separately so mypy knows current_state_ids is not None.
        self._current_state_ids = current_state_ids
        if self._event_state_key is not None:
            # A state event: the state before it is the state after it, with
            # this event's own (type, state_key) entry rolled back to whatever
            # it replaced (or removed, if it replaced nothing).
            self._prev_state_ids = dict(current_state_ids)

            key = (self._event_type, self._event_state_key)
            if self._prev_state_id:
                self._prev_state_ids[key] = self._prev_state_id
            else:
                self._prev_state_ids.pop(key, None)
        else:
            # A non-state event does not change the state, so before == after.
            self._prev_state_ids = current_state_ids
347
348
def _encode_state_dict(
    state_dict: Optional[StateMap[str]],
) -> Optional[List[Tuple[str, str, str]]]:
    """Flatten a state map into a JSON-friendly list of triples.

    JSON objects only allow string keys, so a mapping keyed on
    ``(type, state_key)`` tuples cannot be serialized directly. Each entry is
    emitted as an ``(event_type, state_key, event_id)`` triple instead.

    Args:
        state_dict: A map of (type, state_key) -> event_id, or None.

    Returns:
        The list of triples, in the map's iteration order, or None if
        ``state_dict`` is None.
    """
    if state_dict is not None:
        triples = []
        for type_and_key, event_id in state_dict.items():
            event_type, state_key = type_and_key
            triples.append((event_type, state_key, event_id))
        return triples
    return None
359
360
def _decode_state_dict(
    input: Optional[List[Tuple[str, str, str]]]
) -> Optional[StateMap[str]]:
    """Rebuild a state map from the triples produced by `_encode_state_dict`.

    Args:
        input: A sequence of ``(event_type, state_key, event_id)`` triples, or
            None. After a JSON round trip the triples may be lists rather than
            tuples; either form unpacks the same way.

    Returns:
        An immutable map of (type, state_key) -> event_id, or None if ``input``
        is None. If a (type, state_key) pair repeats, the last triple wins.
    """
    if input is not None:
        rebuilt = {}
        for event_type, state_key, event_id in input:
            rebuilt[(event_type, state_key)] = event_id
        return frozendict(rebuilt)
    return None
369