1# Copyright 2014-2016 OpenMarket Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import logging
15from typing import (
16    TYPE_CHECKING,
17    Awaitable,
18    Collection,
19    Dict,
20    Iterable,
21    List,
22    Mapping,
23    Optional,
24    Set,
25    Tuple,
26    TypeVar,
27)
28
29import attr
30from frozendict import frozendict
31
32from synapse.api.constants import EventTypes
33from synapse.events import EventBase
34from synapse.types import MutableStateMap, StateKey, StateMap
35
36if TYPE_CHECKING:
37    from typing import FrozenSet  # noqa: used within quoted type hint; flake8 sad
38
39    from synapse.server import HomeServer
40    from synapse.storage.databases import Databases
41
42logger = logging.getLogger(__name__)
43
44# Used for generic functions below
45T = TypeVar("T")
46
47
48@attr.s(slots=True, frozen=True)
49class StateFilter:
50    """A filter used when querying for state.
51
52    Attributes:
53        types: Map from type to set of state keys (or None). This specifies
54            which state_keys for the given type to fetch from the DB. If None
55            then all events with that type are fetched. If the set is empty
56            then no events with that type are fetched.
57        include_others: Whether to fetch events with types that do not
58            appear in `types`.
59    """
60
61    types = attr.ib(type="frozendict[str, Optional[FrozenSet[str]]]")
62    include_others = attr.ib(default=False, type=bool)
63
64    def __attrs_post_init__(self):
65        # If `include_others` is set we canonicalise the filter by removing
66        # wildcards from the types dictionary
67        if self.include_others:
68            # this is needed to work around the fact that StateFilter is frozen
69            object.__setattr__(
70                self,
71                "types",
72                frozendict({k: v for k, v in self.types.items() if v is not None}),
73            )
74
75    @staticmethod
76    def all() -> "StateFilter":
77        """Creates a filter that fetches everything.
78
79        Returns:
80            The new state filter.
81        """
82        return StateFilter(types=frozendict(), include_others=True)
83
84    @staticmethod
85    def none() -> "StateFilter":
86        """Creates a filter that fetches nothing.
87
88        Returns:
89            The new state filter.
90        """
91        return StateFilter(types=frozendict(), include_others=False)
92
93    @staticmethod
94    def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
95        """Creates a filter that only fetches the given types
96
97        Args:
98            types: A list of type and state keys to fetch. A state_key of None
99                fetches everything for that type
100
101        Returns:
102            The new state filter.
103        """
104        type_dict: Dict[str, Optional[Set[str]]] = {}
105        for typ, s in types:
106            if typ in type_dict:
107                if type_dict[typ] is None:
108                    continue
109
110            if s is None:
111                type_dict[typ] = None
112                continue
113
114            type_dict.setdefault(typ, set()).add(s)  # type: ignore
115
116        return StateFilter(
117            types=frozendict(
118                (k, frozenset(v) if v is not None else None)
119                for k, v in type_dict.items()
120            )
121        )
122
123    @staticmethod
124    def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
125        """Creates a filter that returns all non-member events, plus the member
126        events for the given users
127
128        Args:
129            members: Set of user IDs
130
131        Returns:
132            The new state filter
133        """
134        return StateFilter(
135            types=frozendict({EventTypes.Member: frozenset(members)}),
136            include_others=True,
137        )
138
139    @staticmethod
140    def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool):
141        """
142        Returns a (frozen) StateFilter with the same contents as the parameters
143        specified here, which can be made of mutable types.
144        """
145        types_with_frozen_values: Dict[str, Optional[FrozenSet[str]]] = {}
146        for state_types, state_keys in types.items():
147            if state_keys is not None:
148                types_with_frozen_values[state_types] = frozenset(state_keys)
149            else:
150                types_with_frozen_values[state_types] = None
151
152        return StateFilter(
153            frozendict(types_with_frozen_values), include_others=include_others
154        )
155
156    def return_expanded(self) -> "StateFilter":
157        """Creates a new StateFilter where type wild cards have been removed
158        (except for memberships). The returned filter is a superset of the
159        current one, i.e. anything that passes the current filter will pass
160        the returned filter.
161
162        This helps the caching as the DictionaryCache knows if it has *all* the
163        state, but does not know if it has all of the keys of a particular type,
164        which makes wildcard lookups expensive unless we have a complete cache.
165        Hence, if we are doing a wildcard lookup, populate the cache fully so
166        that we can do an efficient lookup next time.
167
168        Note that since we have two caches, one for membership events and one for
169        other events, we can be a bit more clever than simply returning
170        `StateFilter.all()` if `has_wildcards()` is True.
171
172        We return a StateFilter where:
173            1. the list of membership events to return is the same
174            2. if there is a wildcard that matches non-member events we
175               return all non-member events
176
177        Returns:
178            The new state filter.
179        """
180
181        if self.is_full():
182            # If we're going to return everything then there's nothing to do
183            return self
184
185        if not self.has_wildcards():
186            # If there are no wild cards, there's nothing to do
187            return self
188
189        if EventTypes.Member in self.types:
190            get_all_members = self.types[EventTypes.Member] is None
191        else:
192            get_all_members = self.include_others
193
194        has_non_member_wildcard = self.include_others or any(
195            state_keys is None
196            for t, state_keys in self.types.items()
197            if t != EventTypes.Member
198        )
199
200        if not has_non_member_wildcard:
201            # If there are no non-member wild cards we can just return ourselves
202            return self
203
204        if get_all_members:
205            # We want to return everything.
206            return StateFilter.all()
207        else:
208            # We want to return all non-members, but only particular
209            # memberships
210            return StateFilter(
211                types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
212                include_others=True,
213            )
214
215    def make_sql_filter_clause(self) -> Tuple[str, List[str]]:
216        """Converts the filter to an SQL clause.
217
218        For example:
219
220            f = StateFilter.from_types([("m.room.create", "")])
221            clause, args = f.make_sql_filter_clause()
222            clause == "(type = ? AND state_key = ?)"
223            args == ['m.room.create', '']
224
225
226        Returns:
227            The SQL string (may be empty) and arguments. An empty SQL string is
228            returned when the filter matches everything (i.e. is "full").
229        """
230
231        where_clause = ""
232        where_args: List[str] = []
233
234        if self.is_full():
235            return where_clause, where_args
236
237        if not self.include_others and not self.types:
238            # i.e. this is an empty filter, so we need to return a clause that
239            # will match nothing
240            return "1 = 2", []
241
242        # First we build up a lost of clauses for each type/state_key combo
243        clauses = []
244        for etype, state_keys in self.types.items():
245            if state_keys is None:
246                clauses.append("(type = ?)")
247                where_args.append(etype)
248                continue
249
250            for state_key in state_keys:
251                clauses.append("(type = ? AND state_key = ?)")
252                where_args.extend((etype, state_key))
253
254        # This will match anything that appears in `self.types`
255        where_clause = " OR ".join(clauses)
256
257        # If we want to include stuff that's not in the types dict then we add
258        # a `OR type NOT IN (...)` clause to the end.
259        if self.include_others:
260            if where_clause:
261                where_clause += " OR "
262
263            where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),)
264            where_args.extend(self.types)
265
266        return where_clause, where_args
267
268    def max_entries_returned(self) -> Optional[int]:
269        """Returns the maximum number of entries this filter will return if
270        known, otherwise returns None.
271
272        For example a simple state filter asking for `("m.room.create", "")`
273        will return 1, whereas the default state filter will return None.
274
275        This is used to bail out early if the right number of entries have been
276        fetched.
277        """
278        if self.has_wildcards():
279            return None
280
281        return len(self.concrete_types())
282
283    def filter_state(self, state_dict: StateMap[T]) -> MutableStateMap[T]:
284        """Returns the state filtered with by this StateFilter.
285
286        Args:
287            state: The state map to filter
288
289        Returns:
290            The filtered state map.
291            This is a copy, so it's safe to mutate.
292        """
293        if self.is_full():
294            return dict(state_dict)
295
296        filtered_state = {}
297        for k, v in state_dict.items():
298            typ, state_key = k
299            if typ in self.types:
300                state_keys = self.types[typ]
301                if state_keys is None or state_key in state_keys:
302                    filtered_state[k] = v
303            elif self.include_others:
304                filtered_state[k] = v
305
306        return filtered_state
307
308    def is_full(self) -> bool:
309        """Whether this filter fetches everything or not
310
311        Returns:
312            True if the filter fetches everything.
313        """
314        return self.include_others and not self.types
315
316    def has_wildcards(self) -> bool:
317        """Whether the filter includes wildcards or is attempting to fetch
318        specific state.
319
320        Returns:
321            True if the filter includes wildcards.
322        """
323
324        return self.include_others or any(
325            state_keys is None for state_keys in self.types.values()
326        )
327
328    def concrete_types(self) -> List[Tuple[str, str]]:
329        """Returns a list of concrete type/state_keys (i.e. not None) that
330        will be fetched. This will be a complete list if `has_wildcards`
331        returns False, but otherwise will be a subset (or even empty).
332
333        Returns:
334            A list of type/state_keys tuples.
335        """
336        return [
337            (t, s)
338            for t, state_keys in self.types.items()
339            if state_keys is not None
340            for s in state_keys
341        ]
342
343    def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
344        """Return the filter split into two: one which assumes it's exclusively
345        matching against member state, and one which assumes it's matching
346        against non member state.
347
348        This is useful due to the returned filters giving correct results for
349        `is_full()`, `has_wildcards()`, etc, when operating against maps that
350        either exclusively contain member events or only contain non-member
351        events. (Which is the case when dealing with the member vs non-member
352        state caches).
353
354        Returns:
355            The member and non member filters
356        """
357
358        if EventTypes.Member in self.types:
359            state_keys = self.types[EventTypes.Member]
360            if state_keys is None:
361                member_filter = StateFilter.all()
362            else:
363                member_filter = StateFilter(frozendict({EventTypes.Member: state_keys}))
364        elif self.include_others:
365            member_filter = StateFilter.all()
366        else:
367            member_filter = StateFilter.none()
368
369        non_member_filter = StateFilter(
370            types=frozendict(
371                {k: v for k, v in self.types.items() if k != EventTypes.Member}
372            ),
373            include_others=self.include_others,
374        )
375
376        return member_filter, non_member_filter
377
378    def _decompose_into_four_parts(
379        self,
380    ) -> Tuple[Tuple[bool, Set[str]], Tuple[Set[str], Set[StateKey]]]:
381        """
382        Decomposes this state filter into 4 constituent parts, which can be
383        thought of as this:
384            all? - minus_wildcards + plus_wildcards + plus_state_keys
385
386        where
387        * all represents ALL state
388        * minus_wildcards represents entire state types to remove
389        * plus_wildcards represents entire state types to add
390        * plus_state_keys represents individual state keys to add
391
392        See `recompose_from_four_parts` for the other direction of this
393        correspondence.
394        """
395        is_all = self.include_others
396        excluded_types: Set[str] = {t for t in self.types if is_all}
397        wildcard_types: Set[str] = {t for t, s in self.types.items() if s is None}
398        concrete_keys: Set[StateKey] = set(self.concrete_types())
399
400        return (is_all, excluded_types), (wildcard_types, concrete_keys)
401
402    @staticmethod
403    def _recompose_from_four_parts(
404        all_part: bool,
405        minus_wildcards: Set[str],
406        plus_wildcards: Set[str],
407        plus_state_keys: Set[StateKey],
408    ) -> "StateFilter":
409        """
410        Recomposes a state filter from 4 parts.
411
412        See `decompose_into_four_parts` (the other direction of this
413        correspondence) for descriptions on each of the parts.
414        """
415
416        # {state type -> set of state keys OR None for wildcard}
417        # (The same structure as that of a StateFilter.)
418        new_types: Dict[str, Optional[Set[str]]] = {}
419
420        # if we start with all, insert the excluded statetypes as empty sets
421        # to prevent them from being included
422        if all_part:
423            new_types.update({state_type: set() for state_type in minus_wildcards})
424
425        # insert the plus wildcards
426        new_types.update({state_type: None for state_type in plus_wildcards})
427
428        # insert the specific state keys
429        for state_type, state_key in plus_state_keys:
430            if state_type in new_types:
431                entry = new_types[state_type]
432                if entry is not None:
433                    entry.add(state_key)
434            elif not all_part:
435                # don't insert if the entire type is already included by
436                # include_others as this would actually shrink the state allowed
437                # by this filter.
438                new_types[state_type] = {state_key}
439
440        return StateFilter.freeze(new_types, include_others=all_part)
441
    def approx_difference(self, other: "StateFilter") -> "StateFilter":
        """
        Returns a state filter which represents `self - other`.

        This is useful for determining what state remains to be pulled out of the
        database if we want the state included by `self` but already have the state
        included by `other`.

        The returned state filter
        - MUST include all state events that are included by this filter (`self`)
          unless they are included by `other`;
        - MUST NOT include state events not included by this filter (`self`); and
        - MAY be an over-approximation: the returned state filter
          MAY additionally include some state events from `other`.

        This implementation attempts to return the narrowest such state filter.
        In the case that `self` contains wildcards for state types where
        `other` contains specific state keys, an approximation must be made:
        the returned state filter keeps the wildcard, as state filters are not
        able to express 'all state keys except some given examples'.
        e.g.
            StateFilter(m.room.member -> None (wildcard))
                minus
            StateFilter(m.room.member -> {'@wombat:example.org'})
                is approximated as
            StateFilter(m.room.member -> None (wildcard))

        Args:
            other: The filter describing the state that is already available
                and so should be subtracted from `self`.

        Returns:
            A new state filter approximating `self - other`.
        """

        # We first transform self and other into an alternative representation:
        #   - whether or not they include all events to begin with ('all')
        #   - if so, which event types are excluded? ('excludes')
        #   - which entire event types to include ('wildcards')
        #   - which concrete state keys to include ('concrete state keys')
        (self_all, self_excludes), (
            self_wildcards,
            self_concrete_keys,
        ) = self._decompose_into_four_parts()
        (other_all, other_excludes), (
            other_wildcards,
            other_concrete_keys,
        ) = other._decompose_into_four_parts()

        # Start with an estimate of the difference based on self
        new_all = self_all
        # Wildcards from the other can be added to the exclusion filter
        new_excludes = self_excludes | other_wildcards
        # We remove wildcards that appeared as wildcards in the other
        new_wildcards = self_wildcards - other_wildcards
        # We filter out the concrete state keys that appear in the other
        # as wildcards or concrete state keys.
        new_concrete_keys = {
            (state_type, state_key)
            for (state_type, state_key) in self_concrete_keys
            if state_type not in other_wildcards
        } - other_concrete_keys

        if other_all:
            if self_all:
                # If self starts with all, then we add as wildcards any
                # types which appear in the other's exclusion filter (but
                # aren't in the self exclusion filter). This is as the other
                # filter will return everything BUT the types in its exclusion, so
                # we need to add those excluded types that also match the self
                # filter as wildcard types in the new filter.
                new_wildcards |= other_excludes.difference(self_excludes)

            # If other is an `include_others` then the difference isn't.
            new_all = False
            # (We have no need for excludes when we don't start with all, as there
            #  is nothing to exclude.)
            new_excludes = set()

            # We also filter out all state types that aren't in the exclusion
            # list of the other.
            new_wildcards &= other_excludes
            new_concrete_keys = {
                (state_type, state_key)
                for (state_type, state_key) in new_concrete_keys
                if state_type in other_excludes
            }

        # Transform our newly-constructed state filter from the alternative
        # representation back into the normal StateFilter representation.
        return StateFilter._recompose_from_four_parts(
            new_all, new_excludes, new_wildcards, new_concrete_keys
        )
528
529
530class StateGroupStorage:
531    """High level interface to fetching state for event."""
532
533    def __init__(self, hs: "HomeServer", stores: "Databases"):
534        self.stores = stores
535
536    async def get_state_group_delta(
537        self, state_group: int
538    ) -> Tuple[Optional[int], Optional[StateMap[str]]]:
539        """Given a state group try to return a previous group and a delta between
540        the old and the new.
541
542        Args:
543            state_group: The state group used to retrieve state deltas.
544
545        Returns:
546            A tuple of the previous group and a state map of the event IDs which
547            make up the delta between the old and new state groups.
548        """
549
550        state_group_delta = await self.stores.state.get_state_group_delta(state_group)
551        return state_group_delta.prev_group, state_group_delta.delta_ids
552
553    async def get_state_groups_ids(
554        self, _room_id: str, event_ids: Iterable[str]
555    ) -> Dict[int, MutableStateMap[str]]:
556        """Get the event IDs of all the state for the state groups for the given events
557
558        Args:
559            _room_id: id of the room for these events
560            event_ids: ids of the events
561
562        Returns:
563            dict of state_group_id -> (dict of (type, state_key) -> event id)
564        """
565        if not event_ids:
566            return {}
567
568        event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
569
570        groups = set(event_to_groups.values())
571        group_to_state = await self.stores.state._get_state_for_groups(groups)
572
573        return group_to_state
574
575    async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
576        """Get the event IDs of all the state in the given state group
577
578        Args:
579            state_group: A state group for which we want to get the state IDs.
580
581        Returns:
582            Resolves to a map of (type, state_key) -> event_id
583        """
584        group_to_state = await self._get_state_for_groups((state_group,))
585
586        return group_to_state[state_group]
587
588    async def get_state_groups(
589        self, room_id: str, event_ids: Iterable[str]
590    ) -> Dict[int, List[EventBase]]:
591        """Get the state groups for the given list of event_ids
592
593        Args:
594            room_id: ID of the room for these events.
595            event_ids: The event IDs to retrieve state for.
596
597        Returns:
598            dict of state_group_id -> list of state events.
599        """
600        if not event_ids:
601            return {}
602
603        group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
604
605        state_event_map = await self.stores.main.get_events(
606            [
607                ev_id
608                for group_ids in group_to_ids.values()
609                for ev_id in group_ids.values()
610            ],
611            get_prev_content=False,
612        )
613
614        return {
615            group: [
616                state_event_map[v]
617                for v in event_id_map.values()
618                if v in state_event_map
619            ]
620            for group, event_id_map in group_to_ids.items()
621        }
622
623    def _get_state_groups_from_groups(
624        self, groups: List[int], state_filter: StateFilter
625    ) -> Awaitable[Dict[int, StateMap[str]]]:
626        """Returns the state groups for a given set of groups, filtering on
627        types of state events.
628
629        Args:
630            groups: list of state group IDs to query
631            state_filter: The state filter used to fetch state
632                from the database.
633
634        Returns:
635            Dict of state group to state map.
636        """
637
638        return self.stores.state._get_state_groups_from_groups(groups, state_filter)
639
640    async def get_state_for_events(
641        self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
642    ) -> Dict[str, StateMap[EventBase]]:
643        """Given a list of event_ids and type tuples, return a list of state
644        dicts for each event.
645
646        Args:
647            event_ids: The events to fetch the state of.
648            state_filter: The state filter used to fetch state.
649
650        Returns:
651            A dict of (event_id) -> (type, state_key) -> [state_events]
652        """
653        event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
654
655        groups = set(event_to_groups.values())
656        group_to_state = await self.stores.state._get_state_for_groups(
657            groups, state_filter or StateFilter.all()
658        )
659
660        state_event_map = await self.stores.main.get_events(
661            [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
662            get_prev_content=False,
663        )
664
665        event_to_state = {
666            event_id: {
667                k: state_event_map[v]
668                for k, v in group_to_state[group].items()
669                if v in state_event_map
670            }
671            for event_id, group in event_to_groups.items()
672        }
673
674        return {event: event_to_state[event] for event in event_ids}
675
676    async def get_state_ids_for_events(
677        self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
678    ) -> Dict[str, StateMap[str]]:
679        """
680        Get the state dicts corresponding to a list of events, containing the event_ids
681        of the state events (as opposed to the events themselves)
682
683        Args:
684            event_ids: events whose state should be returned
685            state_filter: The state filter used to fetch state from the database.
686
687        Returns:
688            A dict from event_id -> (type, state_key) -> event_id
689        """
690        event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
691
692        groups = set(event_to_groups.values())
693        group_to_state = await self.stores.state._get_state_for_groups(
694            groups, state_filter or StateFilter.all()
695        )
696
697        event_to_state = {
698            event_id: group_to_state[group]
699            for event_id, group in event_to_groups.items()
700        }
701
702        return {event: event_to_state[event] for event in event_ids}
703
704    async def get_state_for_event(
705        self, event_id: str, state_filter: Optional[StateFilter] = None
706    ) -> StateMap[EventBase]:
707        """
708        Get the state dict corresponding to a particular event
709
710        Args:
711            event_id: event whose state should be returned
712            state_filter: The state filter used to fetch state from the database.
713
714        Returns:
715            A dict from (type, state_key) -> state_event
716        """
717        state_map = await self.get_state_for_events(
718            [event_id], state_filter or StateFilter.all()
719        )
720        return state_map[event_id]
721
722    async def get_state_ids_for_event(
723        self, event_id: str, state_filter: Optional[StateFilter] = None
724    ) -> StateMap[str]:
725        """
726        Get the state dict corresponding to a particular event
727
728        Args:
729            event_id: event whose state should be returned
730            state_filter: The state filter used to fetch state from the database.
731
732        Returns:
733            A dict from (type, state_key) -> state_event_id
734        """
735        state_map = await self.get_state_ids_for_events(
736            [event_id], state_filter or StateFilter.all()
737        )
738        return state_map[event_id]
739
740    def _get_state_for_groups(
741        self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
742    ) -> Awaitable[Dict[int, MutableStateMap[str]]]:
743        """Gets the state at each of a list of state groups, optionally
744        filtering by type/state_key
745
746        Args:
747            groups: list of state groups for which we want to get the state.
748            state_filter: The state filter used to fetch state.
749                from the database.
750
751        Returns:
752            Dict of state group to state map.
753        """
754        return self.stores.state._get_state_for_groups(
755            groups, state_filter or StateFilter.all()
756        )
757
758    async def store_state_group(
759        self,
760        event_id: str,
761        room_id: str,
762        prev_group: Optional[int],
763        delta_ids: Optional[StateMap[str]],
764        current_state_ids: StateMap[str],
765    ) -> int:
766        """Store a new set of state, returning a newly assigned state group.
767
768        Args:
769            event_id: The event ID for which the state was calculated.
770            room_id: ID of the room for which the state was calculated.
771            prev_group: A previous state group for the room, optional.
772            delta_ids: The delta between state at `prev_group` and
773                `current_state_ids`, if `prev_group` was given. Same format as
774                `current_state_ids`.
775            current_state_ids: The state to store. Map of (type, state_key)
776                to event_id.
777
778        Returns:
779            The state group ID
780        """
781        return await self.stores.state.store_state_group(
782            event_id, room_id, prev_group, delta_ids, current_state_ids
783        )
784