1# Copyright 2014-2016 OpenMarket Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14from typing import List, Optional
15from unittest.mock import Mock
16
17from twisted.internet import defer
18
19from synapse.api.auth import Auth
20from synapse.api.constants import EventTypes, Membership
21from synapse.api.room_versions import RoomVersions
22from synapse.events import make_event_from_dict
23from synapse.events.snapshot import EventContext
24from synapse.state import StateHandler, StateResolutionHandler
25
26from tests import unittest
27
28from .utils import MockClock, default_config
29
# Module-level counter that create_event() increments to mint a fresh event
# ID whenever the caller does not pass one explicitly.
_next_event_id = 1_000
31
32
def create_event(
    name=None,
    type=None,
    state_key=None,
    depth=2,
    event_id=None,
    prev_events: Optional[List[tuple]] = None,
    **kwargs,
):
    """Build a test event in the fixed test room "!room_id:example.com".

    Args:
        name: unused; accepted only so call sites can label events for
            readability. It does not influence the built event.
        type: the event type.
        state_key: if not None, it is set as the event's state key, making
            the event a state event.
        depth: the event's depth.
        event_id: the event ID. If falsy, a fresh ID of the form "$<n>:test"
            is generated from the module-level `_next_event_id` counter.
        prev_events: the event's prev_events, given as (event_id, dict)
            pairs, e.g. `[("$prev:test", {})]` as used throughout this file.
            Defaults to an empty list.
        **kwargs: extra top-level event fields (e.g. `content`, `sender`).
            They are merged last, so they override the defaults set here.

    Returns:
        The event produced by `make_event_from_dict`.
    """
    global _next_event_id

    if not event_id:
        _next_event_id += 1
        event_id = "$%s:test" % (_next_event_id,)

    d = {
        "event_id": event_id,
        "type": type,
        "sender": "@user_id:example.com",
        "room_id": "!room_id:example.com",
        "depth": depth,
        "prev_events": prev_events or [],
    }

    if state_key is not None:
        d["state_key"] = state_key

    d.update(kwargs)

    return make_event_from_dict(d)
71
72
class StateGroupStore:
    """In-memory fake of the datastore methods the state handler uses.

    The tests in this file use a single instance as both the main and the
    state store. It hands out state-group IDs, remembers the state map stored
    for each group, records which group each event belongs to, and holds
    events by ID.
    """

    def __init__(self):
        # event ID -> state group ID
        self._event_to_state_group = {}
        # state group ID -> state map, as passed to store_state_group
        # (the tests pass {(type, state_key): event_id}).
        self._group_to_state = {}

        # event ID -> event object
        self._event_id_to_event = {}

        # Next state-group ID to hand out; IDs start at 1.
        self._next_group = 1

    async def get_state_groups_ids(self, room_id, event_ids):
        """Return {state_group: state map} for the groups of `event_ids`.

        Event IDs with no registered group (or a group of None) are skipped.
        `room_id` is ignored.
        """
        groups = {}
        for event_id in event_ids:
            group = self._event_to_state_group.get(event_id)
            # Explicit None check: a group ID of 0 is still a valid group.
            if group is not None:
                groups[group] = self._group_to_state[group]

        return groups

    async def store_state_group(
        self, event_id, room_id, prev_group, delta_ids, current_state_ids
    ):
        """Allocate a new state group holding a copy of `current_state_ids`.

        `event_id`, `room_id`, `prev_group` and `delta_ids` are ignored; in
        particular the event is NOT mapped to the new group. Callers do that
        with `register_event_id_state_group`.

        Returns:
            The newly allocated state-group ID.
        """
        state_group = self._next_group
        self._next_group += 1

        self._group_to_state[state_group] = dict(current_state_ids)

        return state_group

    async def get_events(self, event_ids, **kwargs):
        """Return {event_id: event} for the requested IDs that are known.

        Unknown IDs are omitted; extra keyword arguments are ignored.
        """
        return {
            e_id: self._event_id_to_event[e_id]
            for e_id in event_ids
            if e_id in self._event_id_to_event
        }

    async def get_state_group_delta(self, name):
        """Always report that no delta is known for the group: (None, None)."""
        return None, None

    def register_events(self, events):
        """Make `events` retrievable by ID through `get_events`."""
        for e in events:
            self._event_id_to_event[e.event_id] = e

    def register_event_context(self, event, context):
        """Record `context.state_group` as the state group of `event`."""
        self._event_to_state_group[event.event_id] = context.state_group

    def register_event_id_state_group(self, event_id, state_group):
        """Record `state_group` as the state group of `event_id`."""
        self._event_to_state_group[event_id] = state_group

    async def get_room_version_id(self, room_id):
        """Report room version V1 for every room."""
        return RoomVersions.V1.identifier
123
124
class DictObj(dict):
    """A dict whose entries can also be read and written as attributes.

    Attribute access and item access share one underlying mapping, so
    `obj.key` and `obj["key"]` always agree.
    """

    def __init__(self, **kwargs):
        dict.__init__(self, **kwargs)
        # Alias the instance namespace to the dict itself.
        self.__dict__ = self
129
130
class Graph:
    """An in-memory DAG of test events built from node and edge descriptions.

    Args:
        nodes: map from event ID to the keyword arguments (e.g. a DictObj)
            passed to `create_event` for that event.
        edges: map from event ID to the list of event IDs it references as
            prev_events. Nodes without an entry get no prev_events.
    """

    def __init__(self, nodes, edges):
        events = {}
        # Every node starts as a candidate leaf; any node that another node
        # references as a prev_event is removed below. (Seeding this from the
        # still-empty `events` dict would leave the leaf set always empty.)
        leaves = set(nodes)

        for event_id, fields in nodes.items():
            refs = edges.get(event_id)
            if refs:
                leaves.difference_update(refs)
                prev_events = [(r, {}) for r in refs]
            else:
                prev_events = []

            events[event_id] = create_event(
                event_id=event_id, prev_events=prev_events, **fields
            )

        self._leaves = leaves
        self._events_by_id = events
        # Ascending depth, so walk() visits an event after its ancestors as
        # long as the supplied depths are consistent with the edges.
        self._events = sorted(events.values(), key=lambda e: e.depth)

    def walk(self):
        """Return an iterator over all events in ascending depth order."""
        return iter(self._events)

    def get_leaves(self):
        """Return a generator over the events no other event references."""
        return (self._events_by_id[i] for i in self._leaves)
156
157
class StateTestCase(unittest.TestCase):
    """Tests for StateHandler.compute_event_context backed by StateGroupStore."""

    def setUp(self):
        self.store = StateGroupStore()
        # The same in-memory store backs both the main and the state storage.
        storage = Mock(main=self.store, state=self.store)
        hs = Mock(
            spec_set=[
                "config",
                "get_datastore",
                "get_storage",
                "get_auth",
                "get_state_handler",
                "get_clock",
                "get_state_resolution_handler",
                "get_account_validity_handler",
                "hostname",
            ]
        )
        hs.config = default_config("tesths", True)
        hs.get_datastore.return_value = self.store
        hs.get_state_handler.return_value = None
        hs.get_clock.return_value = MockClock()
        hs.get_auth.return_value = Auth(hs)
        # A fresh StateResolutionHandler is built on every call.
        hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
        hs.get_storage.return_value = storage

        self.state = StateHandler(hs)
        self.event_id = 0

    @defer.inlineCallbacks
    def _compute_graph_contexts(self, graph):
        """Register every event in `graph` and compute its event context.

        Events are processed in `graph.walk()` order (ascending depth). Each
        context is registered with the store before moving on, so later
        events can find the state groups of their prev_events.

        Returns:
            A Deferred resolving to a dict mapping event ID to EventContext.
        """
        self.store.register_events(graph.walk())

        context_store: dict[str, EventContext] = {}

        for event in graph.walk():
            context = yield defer.ensureDeferred(
                self.state.compute_event_context(event)
            )
            self.store.register_event_context(event, context)
            context_store[event.event_id] = context

        return context_store

    @defer.inlineCallbacks
    def test_branch_no_conflict(self):
        # B (a message) and C (a name event) fork from A and are merged by D.
        # Only C touches state, so nothing needs resolving and D reuses C's
        # state group.
        graph = Graph(
            nodes={
                "START": DictObj(
                    type=EventTypes.Create, state_key="", content={}, depth=1
                ),
                "A": DictObj(type=EventTypes.Message, depth=2),
                "B": DictObj(type=EventTypes.Message, depth=3),
                "C": DictObj(type=EventTypes.Name, state_key="", depth=3),
                "D": DictObj(type=EventTypes.Message, depth=4),
            },
            edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
        )

        context_store = yield self._compute_graph_contexts(graph)

        ctx_c = context_store["C"]
        ctx_d = context_store["D"]

        prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
        # The create event (START) plus the name event (C).
        self.assertEqual(2, len(prev_state_ids))

        self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
        self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)

    @defer.inlineCallbacks
    def test_branch_basic_conflict(self):
        # B and C both set the room name on separate branches; D merges them.
        graph = Graph(
            nodes={
                "START": DictObj(
                    type=EventTypes.Create,
                    state_key="",
                    content={"creator": "@user_id:example.com"},
                    depth=1,
                ),
                "A": DictObj(
                    type=EventTypes.Member,
                    state_key="@user_id:example.com",
                    content={"membership": Membership.JOIN},
                    membership=Membership.JOIN,
                    depth=2,
                ),
                "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
                "C": DictObj(type=EventTypes.Name, state_key="", depth=4),
                "D": DictObj(type=EventTypes.Message, depth=5),
            },
            edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
        )

        context_store = yield self._compute_graph_contexts(graph)

        # C ends up winning the resolution between B and C

        ctx_c = context_store["C"]
        ctx_d = context_store["D"]

        prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
        self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))

        self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
        self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)

    @defer.inlineCallbacks
    def test_branch_have_banned_conflict(self):
        # After B, one branch bans @user_id_2 (C) while the other has that
        # same user change the room name (D); E merges the two branches.
        graph = Graph(
            nodes={
                "START": DictObj(
                    type=EventTypes.Create,
                    state_key="",
                    content={"creator": "@user_id:example.com"},
                    depth=1,
                ),
                "A": DictObj(
                    type=EventTypes.Member,
                    state_key="@user_id:example.com",
                    content={"membership": Membership.JOIN},
                    membership=Membership.JOIN,
                    depth=2,
                ),
                "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
                "C": DictObj(
                    type=EventTypes.Member,
                    state_key="@user_id_2:example.com",
                    content={"membership": Membership.BAN},
                    membership=Membership.BAN,
                    depth=4,
                ),
                "D": DictObj(
                    type=EventTypes.Name,
                    state_key="",
                    depth=4,
                    sender="@user_id_2:example.com",
                ),
                "E": DictObj(type=EventTypes.Message, depth=5),
            },
            edges={
                "A": ["START"],
                "B": ["A"],
                "C": ["B"],
                "D": ["B"],
                "E": ["C", "D"],
            },
        )

        context_store = yield self._compute_graph_contexts(graph)

        # C ends up winning the resolution between C and D because bans win over other
        # changes

        ctx_c = context_store["C"]
        ctx_e = context_store["E"]

        prev_state_ids = yield defer.ensureDeferred(ctx_e.get_prev_state_ids())
        self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values()))
        self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
        self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)

    @defer.inlineCallbacks
    def test_branch_have_perms_conflict(self):
        # After A5, one branch lowers userid2's power level to 30 (B) while the
        # other has userid2 set the room name (C), which needs level 50.
        userid1 = "@user_id:example.com"
        userid2 = "@user_id2:example.com"

        nodes = {
            "A1": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={"creator": userid1},
                depth=1,
            ),
            "A2": DictObj(
                type=EventTypes.Member,
                state_key=userid1,
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
            ),
            "A3": DictObj(
                type=EventTypes.Member,
                state_key=userid2,
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
            ),
            "A4": DictObj(
                type=EventTypes.PowerLevels,
                state_key="",
                content={
                    "events": {"m.room.name": 50},
                    "users": {userid1: 100, userid2: 60},
                },
            ),
            "A5": DictObj(type=EventTypes.Name, state_key=""),
            "B": DictObj(
                type=EventTypes.PowerLevels,
                state_key="",
                content={"events": {"m.room.name": 50}, "users": {userid2: 30}},
            ),
            "C": DictObj(type=EventTypes.Name, state_key="", sender=userid2),
            "D": DictObj(type=EventTypes.Message),
        }
        edges = {
            "A2": ["A1"],
            "A3": ["A2"],
            "A4": ["A3"],
            "A5": ["A4"],
            "B": ["A5"],
            "C": ["A5"],
            "D": ["B", "C"],
        }
        self._add_depths(nodes, edges)
        graph = Graph(nodes, edges)

        context_store = yield self._compute_graph_contexts(graph)

        # B ends up winning the resolution between B and C because power levels
        # win over other changes.

        ctx_b = context_store["B"]
        ctx_d = context_store["D"]

        prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
        self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values()))

        self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
        self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)

    def _add_depths(self, nodes, edges):
        """Fill in a "depth" for every node in `nodes` that lacks one.

        A missing depth is set to one more than the largest depth among the
        node's prev_events in `edges`, computed recursively. Every node
        without an explicit depth must therefore have an entry in `edges`.
        `nodes` is modified in place.
        """

        def _get_depth(ev):
            node = nodes[ev]
            if "depth" not in node:
                prevs = edges[ev]
                depth = max(_get_depth(prev) for prev in prevs) + 1
                node["depth"] = depth
            return node["depth"]

        for n in nodes:
            _get_depth(n)

    @defer.inlineCallbacks
    def test_annotate_with_old_message(self):
        # A message event with explicit old_state: its state before and after
        # is exactly old_state, and it does not get a new state group.
        event = create_event(type="test_message", name="event")

        old_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        context = yield defer.ensureDeferred(
            self.state.compute_event_context(event, old_state=old_state)
        )

        prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
        self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())

        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
        self.assertCountEqual(
            (e.event_id for e in old_state), current_state_ids.values()
        )

        self.assertIsNotNone(context.state_group_before_event)
        self.assertEqual(context.state_group_before_event, context.state_group)

    @defer.inlineCallbacks
    def test_annotate_with_old_state(self):
        # A state event with explicit old_state: the state after it adds the
        # event itself, in a new group whose prev_group is the group before it.
        event = create_event(type="state", state_key="", name="event")

        old_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        context = yield defer.ensureDeferred(
            self.state.compute_event_context(event, old_state=old_state)
        )

        prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
        self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())

        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
        self.assertCountEqual(
            (e.event_id for e in old_state + [event]), current_state_ids.values()
        )

        self.assertIsNotNone(context.state_group_before_event)
        self.assertNotEqual(context.state_group_before_event, context.state_group)
        self.assertEqual(context.state_group_before_event, context.prev_group)
        self.assertEqual({("state", ""): event.event_id}, context.delta_ids)

    @defer.inlineCallbacks
    def test_trivial_annotate_message(self):
        # A message with a single prev_event reuses that event's state group.
        prev_event_id = "prev_event_id"
        event = create_event(
            type="test_message", name="event2", prev_events=[(prev_event_id, {})]
        )

        old_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        group_name = yield defer.ensureDeferred(
            self.store.store_state_group(
                prev_event_id,
                event.room_id,
                None,
                None,
                {(e.type, e.state_key): e.event_id for e in old_state},
            )
        )
        self.store.register_event_id_state_group(prev_event_id, group_name)

        context = yield defer.ensureDeferred(self.state.compute_event_context(event))

        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())

        self.assertEqual(
            {e.event_id for e in old_state}, set(current_state_ids.values())
        )

        self.assertEqual(group_name, context.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_state(self):
        # A state event with a single prev_event sees that event's state as
        # its prev state.
        prev_event_id = "prev_event_id"
        event = create_event(
            type="state", state_key="", name="event2", prev_events=[(prev_event_id, {})]
        )

        old_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        group_name = yield defer.ensureDeferred(
            self.store.store_state_group(
                prev_event_id,
                event.room_id,
                None,
                None,
                {(e.type, e.state_key): e.event_id for e in old_state},
            )
        )
        self.store.register_event_id_state_group(prev_event_id, group_name)

        context = yield defer.ensureDeferred(self.state.compute_event_context(event))

        prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())

        self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values()))

        self.assertIsNotNone(context.state_group)

    @defer.inlineCallbacks
    def test_resolve_message_conflict(self):
        # A message whose two prev_events have partially overlapping state.
        prev_event_id1 = "event_id1"
        prev_event_id2 = "event_id2"
        event = create_event(
            type="test_message",
            name="event3",
            prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
        )

        creation = create_event(type=EventTypes.Create, state_key="")

        old_state_1 = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        old_state_2 = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test3", state_key="2"),
            create_event(type="test4", state_key=""),
        ]

        self.store.register_events(old_state_1)
        self.store.register_events(old_state_2)

        context = yield self._get_context(
            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
        )

        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())

        # Six distinct (type, state_key) pairs across the two branches.
        self.assertEqual(len(current_state_ids), 6)

        self.assertIsNotNone(context.state_group)

    @defer.inlineCallbacks
    def test_resolve_state_conflict(self):
        # Like test_resolve_message_conflict, but the new event is itself a
        # state event, and the old state events are served from a separate
        # store via a patched get_events.
        prev_event_id1 = "event_id1"
        prev_event_id2 = "event_id2"
        event = create_event(
            type="test4",
            state_key="",
            name="event",
            prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
        )

        creation = create_event(type=EventTypes.Create, state_key="")

        old_state_1 = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        old_state_2 = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test3", state_key="2"),
            create_event(type="test4", state_key=""),
        ]

        store = StateGroupStore()
        store.register_events(old_state_1)
        store.register_events(old_state_2)
        self.store.get_events = store.get_events

        context = yield self._get_context(
            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
        )

        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())

        self.assertEqual(len(current_state_ids), 6)

        self.assertIsNotNone(context.state_group)

    @defer.inlineCallbacks
    def test_standard_depth_conflict(self):
        # Two branches disagree only on ("test1", "1"); the branch whose event
        # has the greater depth is expected to win.
        prev_event_id1 = "event_id1"
        prev_event_id2 = "event_id2"
        event = create_event(
            type="test4",
            name="event",
            prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
        )

        member_event = create_event(
            type=EventTypes.Member,
            state_key="@user_id:example.com",
            content={"membership": Membership.JOIN},
        )

        power_levels = create_event(
            type=EventTypes.PowerLevels,
            state_key="",
            content={"users": {"@foo:bar": "100", "@user_id:example.com": "100"}},
        )

        creation = create_event(
            type=EventTypes.Create, state_key="", content={"creator": "@foo:bar"}
        )

        old_state_1 = [
            creation,
            power_levels,
            member_event,
            create_event(type="test1", state_key="1", depth=1),
        ]

        old_state_2 = [
            creation,
            power_levels,
            member_event,
            create_event(type="test1", state_key="1", depth=2),
        ]

        store = StateGroupStore()
        store.register_events(old_state_1)
        store.register_events(old_state_2)
        self.store.get_events = store.get_events

        context = yield self._get_context(
            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
        )

        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())

        self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])

        # Reverse the depth to make sure we are actually using the depths
        # during state resolution.

        old_state_1 = [
            creation,
            power_levels,
            member_event,
            create_event(type="test1", state_key="1", depth=2),
        ]

        old_state_2 = [
            creation,
            power_levels,
            member_event,
            create_event(type="test1", state_key="1", depth=1),
        ]

        store.register_events(old_state_1)
        store.register_events(old_state_2)

        context = yield self._get_context(
            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
        )

        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())

        self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])

    @defer.inlineCallbacks
    def _get_context(
        self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
    ):
        """Compute the context of `event` on top of two stored state groups.

        Stores `old_state_1` and `old_state_2` as new state groups, maps
        `prev_event_id_1` / `prev_event_id_2` to them, then computes the
        context for `event` (whose prev_events are expected to reference
        those two IDs).

        Returns:
            A Deferred resolving to the computed EventContext.
        """
        sg1 = yield defer.ensureDeferred(
            self.store.store_state_group(
                prev_event_id_1,
                event.room_id,
                None,
                None,
                {(e.type, e.state_key): e.event_id for e in old_state_1},
            )
        )
        self.store.register_event_id_state_group(prev_event_id_1, sg1)

        sg2 = yield defer.ensureDeferred(
            self.store.store_state_group(
                prev_event_id_2,
                event.room_id,
                None,
                None,
                {(e.type, e.state_key): e.event_id for e in old_state_2},
            )
        )
        self.store.register_event_id_state_group(prev_event_id_2, sg2)

        result = yield defer.ensureDeferred(self.state.compute_event_context(event))
        return result
721