# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import (
    TYPE_CHECKING,
    Awaitable,
    Collection,
    Dict,
    Iterable,
    List,
    Mapping,
    Optional,
    Set,
    Tuple,
    TypeVar,
)

import attr
from frozendict import frozendict

from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateKey, StateMap

if TYPE_CHECKING:
    from typing import FrozenSet  # noqa: used within quoted type hint; flake8 sad

    from synapse.server import HomeServer
    from synapse.storage.databases import Databases

logger = logging.getLogger(__name__)

# Used for generic functions below
T = TypeVar("T")


@attr.s(slots=True, frozen=True)
class StateFilter:
    """A filter used when querying for state.

    Attributes:
        types: Map from type to set of state keys (or None). This specifies
            which state_keys for the given type to fetch from the DB. If None
            then all events with that type are fetched. If the set is empty
            then no events with that type are fetched.
        include_others: Whether to fetch events with types that do not
            appear in `types`.
    """

    types = attr.ib(type="frozendict[str, Optional[FrozenSet[str]]]")
    include_others = attr.ib(default=False, type=bool)

    def __attrs_post_init__(self) -> None:
        """Canonicalise the filter once attrs has populated the fields."""
        # If `include_others` is set we canonicalise the filter by removing
        # wildcards from the types dictionary
        if self.include_others:
            # this is needed to work around the fact that StateFilter is frozen
            object.__setattr__(
                self,
                "types",
                frozendict({k: v for k, v in self.types.items() if v is not None}),
            )

    @staticmethod
    def all() -> "StateFilter":
        """Creates a filter that fetches everything.

        Returns:
            The new state filter.
        """
        return StateFilter(types=frozendict(), include_others=True)

    @staticmethod
    def none() -> "StateFilter":
        """Creates a filter that fetches nothing.

        Returns:
            The new state filter.
        """
        return StateFilter(types=frozendict(), include_others=False)

    @staticmethod
    def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
        """Creates a filter that only fetches the given types

        Args:
            types: A list of type and state keys to fetch. A state_key of None
                fetches everything for that type

        Returns:
            The new state filter.
        """
        type_dict: Dict[str, Optional[Set[str]]] = {}
        for typ, s in types:
            if typ in type_dict:
                # A wildcard already covers every state key of this type.
                if type_dict[typ] is None:
                    continue

            if s is None:
                # A wildcard supersedes any concrete keys collected so far.
                type_dict[typ] = None
                continue

            type_dict.setdefault(typ, set()).add(s)  # type: ignore

        return StateFilter(
            types=frozendict(
                (k, frozenset(v) if v is not None else None)
                for k, v in type_dict.items()
            )
        )

    @staticmethod
    def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
        """Creates a filter that returns all non-member events, plus the member
        events for the given users

        Args:
            members: Set of user IDs

        Returns:
            The new state filter
        """
        return StateFilter(
            types=frozendict({EventTypes.Member: frozenset(members)}),
            include_others=True,
        )

    @staticmethod
    def freeze(
        types: Mapping[str, Optional[Collection[str]]], include_others: bool
    ) -> "StateFilter":
        """
        Returns a (frozen) StateFilter with the same contents as the parameters
        specified here, which can be made of mutable types.

        Args:
            types: Map from state type to a collection of state keys, or None
                for a wildcard on that type.
            include_others: Whether types not listed in `types` are included.

        Returns:
            The new state filter.
        """
        types_with_frozen_values: Dict[str, Optional[FrozenSet[str]]] = {}
        for state_types, state_keys in types.items():
            if state_keys is not None:
                types_with_frozen_values[state_types] = frozenset(state_keys)
            else:
                types_with_frozen_values[state_types] = None

        return StateFilter(
            frozendict(types_with_frozen_values), include_others=include_others
        )

    def return_expanded(self) -> "StateFilter":
        """Creates a new StateFilter where type wild cards have been removed
        (except for memberships). The returned filter is a superset of the
        current one, i.e. anything that passes the current filter will pass
        the returned filter.

        This helps the caching as the DictionaryCache knows if it has *all* the
        state, but does not know if it has all of the keys of a particular type,
        which makes wildcard lookups expensive unless we have a complete cache.
        Hence, if we are doing a wildcard lookup, populate the cache fully so
        that we can do an efficient lookup next time.

        Note that since we have two caches, one for membership events and one for
        other events, we can be a bit more clever than simply returning
        `StateFilter.all()` if `has_wildcards()` is True.

        We return a StateFilter where:
            1. the list of membership events to return is the same
            2. if there is a wildcard that matches non-member events we
               return all non-member events

        Returns:
            The new state filter.
        """

        if self.is_full():
            # If we're going to return everything then there's nothing to do
            return self

        if not self.has_wildcards():
            # If there are no wild cards, there's nothing to do
            return self

        if EventTypes.Member in self.types:
            get_all_members = self.types[EventTypes.Member] is None
        else:
            get_all_members = self.include_others

        has_non_member_wildcard = self.include_others or any(
            state_keys is None
            for t, state_keys in self.types.items()
            if t != EventTypes.Member
        )

        if not has_non_member_wildcard:
            # If there are no non-member wild cards we can just return ourselves
            return self

        if get_all_members:
            # We want to return everything.
            return StateFilter.all()

        if EventTypes.Member in self.types:
            # We want to return all non-members, but only particular
            # memberships
            return StateFilter(
                types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
                include_others=True,
            )

        # There is a non-member wildcard but no membership entry at all (and
        # `include_others` is unset), e.g. `{"m.room.topic": None}`. Looking up
        # `self.types[EventTypes.Member]` here would raise a KeyError, so we
        # explicitly ask for all non-member state and no memberships.
        return StateFilter(
            types=frozendict({EventTypes.Member: frozenset()}),
            include_others=True,
        )

    def make_sql_filter_clause(self) -> Tuple[str, List[str]]:
        """Converts the filter to an SQL clause.

        For example:

            f = StateFilter.from_types([("m.room.create", "")])
            clause, args = f.make_sql_filter_clause()
            clause == "(type = ? AND state_key = ?)"
            args == ['m.room.create', '']


        Returns:
            The SQL string (may be empty) and arguments. An empty SQL string is
            returned when the filter matches everything (i.e. is "full").
        """

        where_clause = ""
        where_args: List[str] = []

        if self.is_full():
            return where_clause, where_args

        if not self.include_others and not self.types:
            # i.e. this is an empty filter, so we need to return a clause that
            # will match nothing
            return "1 = 2", []

        # First we build up a list of clauses for each type/state_key combo
        clauses = []
        for etype, state_keys in self.types.items():
            if state_keys is None:
                clauses.append("(type = ?)")
                where_args.append(etype)
                continue

            for state_key in state_keys:
                clauses.append("(type = ? AND state_key = ?)")
                where_args.extend((etype, state_key))

        # This will match anything that appears in `self.types`
        where_clause = " OR ".join(clauses)

        # If we want to include stuff that's not in the types dict then we add
        # a `OR type NOT IN (...)` clause to the end.
        if self.include_others:
            if where_clause:
                where_clause += " OR "

            where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),)
            where_args.extend(self.types)

        return where_clause, where_args

    def max_entries_returned(self) -> Optional[int]:
        """Returns the maximum number of entries this filter will return if
        known, otherwise returns None.

        For example a simple state filter asking for `("m.room.create", "")`
        will return 1, whereas the default state filter will return None.

        This is used to bail out early if the right number of entries have been
        fetched.
        """
        if self.has_wildcards():
            return None

        return len(self.concrete_types())

    def filter_state(self, state_dict: StateMap[T]) -> MutableStateMap[T]:
        """Returns the state filtered by this StateFilter.

        Args:
            state_dict: The state map to filter

        Returns:
            The filtered state map.
            This is a copy, so it's safe to mutate.
        """
        if self.is_full():
            return dict(state_dict)

        filtered_state = {}
        for k, v in state_dict.items():
            typ, state_key = k
            if typ in self.types:
                state_keys = self.types[typ]
                if state_keys is None or state_key in state_keys:
                    filtered_state[k] = v
            elif self.include_others:
                filtered_state[k] = v

        return filtered_state

    def is_full(self) -> bool:
        """Whether this filter fetches everything or not

        Returns:
            True if the filter fetches everything.
        """
        return self.include_others and not self.types

    def has_wildcards(self) -> bool:
        """Whether the filter includes wildcards or is attempting to fetch
        specific state.

        Returns:
            True if the filter includes wildcards.
        """

        return self.include_others or any(
            state_keys is None for state_keys in self.types.values()
        )

    def concrete_types(self) -> List[Tuple[str, str]]:
        """Returns a list of concrete type/state_keys (i.e. not None) that
        will be fetched. This will be a complete list if `has_wildcards`
        returns False, but otherwise will be a subset (or even empty).

        Returns:
            A list of type/state_keys tuples.
        """
        return [
            (t, s)
            for t, state_keys in self.types.items()
            if state_keys is not None
            for s in state_keys
        ]

    def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
        """Return the filter split into two: one which assumes it's exclusively
        matching against member state, and one which assumes it's matching
        against non member state.

        This is useful due to the returned filters giving correct results for
        `is_full()`, `has_wildcards()`, etc, when operating against maps that
        either exclusively contain member events or only contain non-member
        events. (Which is the case when dealing with the member vs non-member
        state caches).

        Returns:
            The member and non member filters
        """

        if EventTypes.Member in self.types:
            state_keys = self.types[EventTypes.Member]
            if state_keys is None:
                member_filter = StateFilter.all()
            else:
                member_filter = StateFilter(frozendict({EventTypes.Member: state_keys}))
        elif self.include_others:
            member_filter = StateFilter.all()
        else:
            member_filter = StateFilter.none()

        non_member_filter = StateFilter(
            types=frozendict(
                {k: v for k, v in self.types.items() if k != EventTypes.Member}
            ),
            include_others=self.include_others,
        )

        return member_filter, non_member_filter

    def _decompose_into_four_parts(
        self,
    ) -> Tuple[Tuple[bool, Set[str]], Tuple[Set[str], Set[StateKey]]]:
        """
        Decomposes this state filter into 4 constituent parts, which can be
        thought of as this:
            all? - minus_wildcards + plus_wildcards + plus_state_keys

        where
        * all represents ALL state
        * minus_wildcards represents entire state types to remove
        * plus_wildcards represents entire state types to add
        * plus_state_keys represents individual state keys to add

        See `_recompose_from_four_parts` for the other direction of this
        correspondence.
        """
        is_all = self.include_others
        # Only a filter that starts from "all" has anything to exclude.
        excluded_types: Set[str] = {t for t in self.types if is_all}
        wildcard_types: Set[str] = {t for t, s in self.types.items() if s is None}
        concrete_keys: Set[StateKey] = set(self.concrete_types())

        return (is_all, excluded_types), (wildcard_types, concrete_keys)

    @staticmethod
    def _recompose_from_four_parts(
        all_part: bool,
        minus_wildcards: Set[str],
        plus_wildcards: Set[str],
        plus_state_keys: Set[StateKey],
    ) -> "StateFilter":
        """
        Recomposes a state filter from 4 parts.

        See `_decompose_into_four_parts` (the other direction of this
        correspondence) for descriptions on each of the parts.
        """

        # {state type -> set of state keys OR None for wildcard}
        # (The same structure as that of a StateFilter.)
        new_types: Dict[str, Optional[Set[str]]] = {}

        # if we start with all, insert the excluded statetypes as empty sets
        # to prevent them from being included
        if all_part:
            new_types.update({state_type: set() for state_type in minus_wildcards})

        # insert the plus wildcards
        new_types.update({state_type: None for state_type in plus_wildcards})

        # insert the specific state keys
        for state_type, state_key in plus_state_keys:
            if state_type in new_types:
                entry = new_types[state_type]
                if entry is not None:
                    entry.add(state_key)
            elif not all_part:
                # don't insert if the entire type is already included by
                # include_others as this would actually shrink the state allowed
                # by this filter.
                new_types[state_type] = {state_key}

        return StateFilter.freeze(new_types, include_others=all_part)

    def approx_difference(self, other: "StateFilter") -> "StateFilter":
        """
        Returns a state filter which represents `self - other`.

        This is useful for determining what state remains to be pulled out of the
        database if we want the state included by `self` but already have the state
        included by `other`.

        The returned state filter
        - MUST include all state events that are included by this filter (`self`)
          unless they are included by `other`;
        - MUST NOT include state events not included by this filter (`self`); and
        - MAY be an over-approximation: the returned state filter
          MAY additionally include some state events from `other`.

        This implementation attempts to return the narrowest such state filter.
        In the case that `self` contains wildcards for state types where
        `other` contains specific state keys, an approximation must be made:
        the returned state filter keeps the wildcard, as state filters are not
        able to express 'all state keys except some given examples'.
        e.g.
            StateFilter(m.room.member -> None (wildcard))
        minus
            StateFilter(m.room.member -> {'@wombat:example.org'})
        is approximated as
            StateFilter(m.room.member -> None (wildcard))
        """

        # We first transform self and other into an alternative representation:
        #   - whether or not they include all events to begin with ('all')
        #     - if so, which event types are excluded? ('excludes')
        #   - which entire event types to include ('wildcards')
        #   - which concrete state keys to include ('concrete state keys')
        (self_all, self_excludes), (
            self_wildcards,
            self_concrete_keys,
        ) = self._decompose_into_four_parts()
        (other_all, other_excludes), (
            other_wildcards,
            other_concrete_keys,
        ) = other._decompose_into_four_parts()

        # Start with an estimate of the difference based on self
        new_all = self_all
        # Wildcards from the other can be added to the exclusion filter
        new_excludes = self_excludes | other_wildcards
        # We remove wildcards that appeared as wildcards in the other
        new_wildcards = self_wildcards - other_wildcards
        # We filter out the concrete state keys that appear in the other
        # as wildcards or concrete state keys.
        new_concrete_keys = {
            (state_type, state_key)
            for (state_type, state_key) in self_concrete_keys
            if state_type not in other_wildcards
        } - other_concrete_keys

        if other_all:
            if self_all:
                # If self starts with all, then we add as wildcards any
                # types which appear in the other's exclusion filter (but
                # aren't in the self exclusion filter). This is as the other
                # filter will return everything BUT the types in its exclusion, so
                # we need to add those excluded types that also match the self
                # filter as wildcard types in the new filter.
                new_wildcards |= other_excludes.difference(self_excludes)

            # If other is an `include_others` then the difference isn't.
            new_all = False
            # (We have no need for excludes when we don't start with all, as there
            #  is nothing to exclude.)
            new_excludes = set()

            # We also filter out all state types that aren't in the exclusion
            # list of the other.
            new_wildcards &= other_excludes
            new_concrete_keys = {
                (state_type, state_key)
                for (state_type, state_key) in new_concrete_keys
                if state_type in other_excludes
            }

        # Transform our newly-constructed state filter from the alternative
        # representation back into the normal StateFilter representation.
        return StateFilter._recompose_from_four_parts(
            new_all, new_excludes, new_wildcards, new_concrete_keys
        )


class StateGroupStorage:
    """High level interface to fetching state for event."""

    def __init__(self, hs: "HomeServer", stores: "Databases"):
        self.stores = stores

    async def get_state_group_delta(
        self, state_group: int
    ) -> Tuple[Optional[int], Optional[StateMap[str]]]:
        """Given a state group try to return a previous group and a delta between
        the old and the new.

        Args:
            state_group: The state group used to retrieve state deltas.

        Returns:
            A tuple of the previous group and a state map of the event IDs which
            make up the delta between the old and new state groups.
        """

        state_group_delta = await self.stores.state.get_state_group_delta(state_group)
        return state_group_delta.prev_group, state_group_delta.delta_ids

    async def get_state_groups_ids(
        self, _room_id: str, event_ids: Iterable[str]
    ) -> Dict[int, MutableStateMap[str]]:
        """Get the event IDs of all the state for the state groups for the given events

        Args:
            _room_id: id of the room for these events
            event_ids: ids of the events

        Returns:
            dict of state_group_id -> (dict of (type, state_key) -> event id)
        """
        if not event_ids:
            return {}

        event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)

        groups = set(event_to_groups.values())
        group_to_state = await self.stores.state._get_state_for_groups(groups)

        return group_to_state

    async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
        """Get the event IDs of all the state in the given state group

        Args:
            state_group: A state group for which we want to get the state IDs.

        Returns:
            Resolves to a map of (type, state_key) -> event_id
        """
        group_to_state = await self._get_state_for_groups((state_group,))

        return group_to_state[state_group]

    async def get_state_groups(
        self, room_id: str, event_ids: Iterable[str]
    ) -> Dict[int, List[EventBase]]:
        """Get the state groups for the given list of event_ids

        Args:
            room_id: ID of the room for these events.
            event_ids: The event IDs to retrieve state for.

        Returns:
            dict of state_group_id -> list of state events.
        """
        if not event_ids:
            return {}

        group_to_ids = await self.get_state_groups_ids(room_id, event_ids)

        state_event_map = await self.stores.main.get_events(
            [
                ev_id
                for group_ids in group_to_ids.values()
                for ev_id in group_ids.values()
            ],
            get_prev_content=False,
        )

        # Event IDs missing from `state_event_map` are silently skipped.
        return {
            group: [
                state_event_map[v]
                for v in event_id_map.values()
                if v in state_event_map
            ]
            for group, event_id_map in group_to_ids.items()
        }

    def _get_state_groups_from_groups(
        self, groups: List[int], state_filter: StateFilter
    ) -> Awaitable[Dict[int, StateMap[str]]]:
        """Returns the state groups for a given set of groups, filtering on
        types of state events.

        Args:
            groups: list of state group IDs to query
            state_filter: The state filter used to fetch state
                from the database.

        Returns:
            Dict of state group to state map.
        """

        return self.stores.state._get_state_groups_from_groups(groups, state_filter)

    async def get_state_for_events(
        self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
    ) -> Dict[str, StateMap[EventBase]]:
        """Given a list of event_ids and type tuples, return a list of state
        dicts for each event.

        Args:
            event_ids: The events to fetch the state of.
            state_filter: The state filter used to fetch state. Defaults to
                `StateFilter.all()` when not given.

        Returns:
            A dict of (event_id) -> (type, state_key) -> state_event
        """
        event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)

        groups = set(event_to_groups.values())
        group_to_state = await self.stores.state._get_state_for_groups(
            groups, state_filter or StateFilter.all()
        )

        state_event_map = await self.stores.main.get_events(
            [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
            get_prev_content=False,
        )

        # Event IDs missing from `state_event_map` are silently skipped.
        event_to_state = {
            event_id: {
                k: state_event_map[v]
                for k, v in group_to_state[group].items()
                if v in state_event_map
            }
            for event_id, group in event_to_groups.items()
        }

        return {event: event_to_state[event] for event in event_ids}

    async def get_state_ids_for_events(
        self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
    ) -> Dict[str, StateMap[str]]:
        """
        Get the state dicts corresponding to a list of events, containing the event_ids
        of the state events (as opposed to the events themselves)

        Args:
            event_ids: events whose state should be returned
            state_filter: The state filter used to fetch state from the database.
                Defaults to `StateFilter.all()` when not given.

        Returns:
            A dict from event_id -> (type, state_key) -> event_id
        """
        event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)

        groups = set(event_to_groups.values())
        group_to_state = await self.stores.state._get_state_for_groups(
            groups, state_filter or StateFilter.all()
        )

        event_to_state = {
            event_id: group_to_state[group]
            for event_id, group in event_to_groups.items()
        }

        return {event: event_to_state[event] for event in event_ids}

    async def get_state_for_event(
        self, event_id: str, state_filter: Optional[StateFilter] = None
    ) -> StateMap[EventBase]:
        """
        Get the state dict corresponding to a particular event

        Args:
            event_id: event whose state should be returned
            state_filter: The state filter used to fetch state from the database.
                Defaults to `StateFilter.all()` when not given.

        Returns:
            A dict from (type, state_key) -> state_event
        """
        state_map = await self.get_state_for_events(
            [event_id], state_filter or StateFilter.all()
        )
        return state_map[event_id]

    async def get_state_ids_for_event(
        self, event_id: str, state_filter: Optional[StateFilter] = None
    ) -> StateMap[str]:
        """
        Get the state dict corresponding to a particular event

        Args:
            event_id: event whose state should be returned
            state_filter: The state filter used to fetch state from the database.
                Defaults to `StateFilter.all()` when not given.

        Returns:
            A dict from (type, state_key) -> state_event_id
        """
        state_map = await self.get_state_ids_for_events(
            [event_id], state_filter or StateFilter.all()
        )
        return state_map[event_id]

    def _get_state_for_groups(
        self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
    ) -> Awaitable[Dict[int, MutableStateMap[str]]]:
        """Gets the state at each of a list of state groups, optionally
        filtering by type/state_key

        Args:
            groups: list of state groups for which we want to get the state.
            state_filter: The state filter used to fetch state
                from the database. Defaults to `StateFilter.all()` when not
                given.

        Returns:
            Dict of state group to state map.
        """
        return self.stores.state._get_state_for_groups(
            groups, state_filter or StateFilter.all()
        )

    async def store_state_group(
        self,
        event_id: str,
        room_id: str,
        prev_group: Optional[int],
        delta_ids: Optional[StateMap[str]],
        current_state_ids: StateMap[str],
    ) -> int:
        """Store a new set of state, returning a newly assigned state group.

        Args:
            event_id: The event ID for which the state was calculated.
            room_id: ID of the room for which the state was calculated.
            prev_group: A previous state group for the room, optional.
            delta_ids: The delta between state at `prev_group` and
                `current_state_ids`, if `prev_group` was given. Same format as
                `current_state_ids`.
            current_state_ids: The state to store. Map of (type, state_key)
                to event_id.

        Returns:
            The state group ID
        """
        return await self.stores.state.store_state_group(
            event_id, room_id, prev_group, delta_ids, current_state_ids
        )