# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Tuple
from unittest.mock import Mock

from twisted.internet import defer

from synapse.api.auth import Auth
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.state import StateHandler, StateResolutionHandler

from tests import unittest

from .utils import MockClock, default_config

# Counter used to mint unique event IDs for events created without an explicit
# ``event_id``. It is module-global, so generated IDs never repeat within a run.
_next_event_id = 1000


def create_event(
    name=None,
    type=None,
    state_key=None,
    depth=2,
    event_id=None,
    prev_events: Optional[List[Tuple[str, dict]]] = None,
    **kwargs,
):
    """Build a test event in room ``!room_id:example.com``.

    Args:
        name: Human-readable label used only to make call sites easier to
            read; it is not stored on the event.
        type: The event type.
        state_key: The state key, or ``None`` for a non-state event.
        depth: The event's depth.
        event_id: The event ID. If falsy, a fresh ``$<n>:test`` ID is minted.
        prev_events: The event's prev events. Callers in this file pass
            ``(event_id, hashes)`` pairs. Defaults to no prev events.
        **kwargs: Extra fields merged into the event dict last, so they can
            override the defaults above (e.g. ``sender``, ``content``).

    Returns:
        The event returned by ``make_event_from_dict``.
    """
    global _next_event_id

    if not event_id:
        _next_event_id += 1
        event_id = "$%s:test" % (_next_event_id,)

    d = {
        "event_id": event_id,
        "type": type,
        "sender": "@user_id:example.com",
        "room_id": "!room_id:example.com",
        "depth": depth,
        "prev_events": prev_events or [],
    }

    if state_key is not None:
        d["state_key"] = state_key

    d.update(kwargs)

    return make_event_from_dict(d)


class StateGroupStore:
    """In-memory stand-in for the main and state datastores.

    Only the methods these tests need are implemented. State groups always
    store their full state; deltas are never used.
    """

    def __init__(self):
        # event_id -> state group registered for that event
        self._event_to_state_group = {}
        # state group -> copy of the current_state_ids given to store_state_group
        self._group_to_state = {}

        # event_id -> event, for events made available via register_events
        self._event_id_to_event = {}

        # Group IDs start at 1 and increase by one per stored group.
        self._next_group = 1

    async def get_state_groups_ids(self, room_id, event_ids):
        """Map each known state group of ``event_ids`` to its stored state.

        Events with no registered state group are skipped. ``room_id`` is
        ignored.
        """
        groups = {}
        for event_id in event_ids:
            group = self._event_to_state_group.get(event_id)
            if group is not None:
                groups[group] = self._group_to_state[group]

        return groups

    async def store_state_group(
        self, event_id, room_id, prev_group, delta_ids, current_state_ids
    ):
        """Allocate a new state group holding a copy of ``current_state_ids``.

        ``event_id``, ``room_id``, ``prev_group`` and ``delta_ids`` are
        accepted for interface compatibility but ignored.

        Returns:
            The newly allocated state group ID.
        """
        state_group = self._next_group
        self._next_group += 1

        self._group_to_state[state_group] = dict(current_state_ids)

        return state_group

    async def get_events(self, event_ids, **kwargs):
        """Return ``{event_id: event}`` for the registered events among
        ``event_ids``; unknown IDs are silently omitted."""
        return {
            e_id: self._event_id_to_event[e_id]
            for e_id in event_ids
            if e_id in self._event_id_to_event
        }

    async def get_state_group_delta(self, name):
        """Always report that no delta is stored (``(None, None)``)."""
        return None, None

    def register_events(self, events):
        """Make ``events`` retrievable via ``get_events``."""
        for e in events:
            self._event_id_to_event[e.event_id] = e

    def register_event_context(self, event, context):
        """Record ``context.state_group`` as the state group for ``event``."""
        self._event_to_state_group[event.event_id] = context.state_group

    def register_event_id_state_group(self, event_id, state_group):
        """Record ``state_group`` as the state group for ``event_id``."""
        self._event_to_state_group[event_id] = state_group

    async def get_room_version_id(self, room_id):
        """Every room in these tests is a v1 room."""
        return RoomVersions.V1.identifier


class DictObj(dict):
    """A dict whose keys can also be read as attributes (``d.key``)."""

    def __init__(self, **kwargs):
        super().__init__(kwargs)
        # Aliasing __dict__ to the mapping makes attribute and item access
        # share the same storage.
        self.__dict__ = self


class Graph:
    """A small event DAG built from per-node fields and prev-event edges.

    Args:
        nodes: Mapping of event_id -> keyword arguments for ``create_event``.
        edges: Mapping of event_id -> list of that event's prev event IDs.
            Nodes missing from ``edges`` get no prev events.
    """

    def __init__(self, nodes, edges):
        events = {}
        # Leaves are events that no other event lists as a prev event: start
        # from every node and discard each one something points at.
        leaves = set(nodes)

        for event_id, fields in nodes.items():
            refs = edges.get(event_id)
            if refs:
                leaves.difference_update(refs)
                prev_events = [(r, {}) for r in refs]
            else:
                prev_events = []

            events[event_id] = create_event(
                event_id=event_id, prev_events=prev_events, **fields
            )

        self._events_by_id = events
        self._leaves = leaves
        # Sorted by depth so walk() visits lower-depth events first.
        self._events = sorted(events.values(), key=lambda e: e.depth)

    def walk(self):
        """Iterate over the events in ascending depth order."""
        return iter(self._events)

    def get_leaves(self):
        """Iterate over the events that no other event references."""
        return (self._events_by_id[i] for i in self._leaves)


class StateTestCase(unittest.TestCase):
    """Tests for StateHandler.compute_event_context against StateGroupStore."""

    def setUp(self):
        self.store = StateGroupStore()
        storage = Mock(main=self.store, state=self.store)
        hs = Mock(
            spec_set=[
                "config",
                "get_datastore",
                "get_storage",
                "get_auth",
                "get_state_handler",
                "get_clock",
                "get_state_resolution_handler",
                "get_account_validity_handler",
                "hostname",
            ]
        )
        hs.config = default_config("tesths", True)
        hs.get_datastore.return_value = self.store
        hs.get_state_handler.return_value = None
        hs.get_clock.return_value = MockClock()
        hs.get_auth.return_value = Auth(hs)
        hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
        hs.get_storage.return_value = storage

        self.state = StateHandler(hs)
        self.event_id = 0

    @defer.inlineCallbacks
    def _compute_graph_contexts(self, graph):
        """Compute and register an EventContext for every event in ``graph``.

        All events are first registered with the store. Contexts are then
        computed in depth order, and each one's state group is registered so
        get_state_groups_ids can find it for later events that reference it.

        Returns (via Deferred):
            Dict[str, EventContext]: event_id -> computed context.
        """
        self.store.register_events(graph.walk())

        context_store: Dict[str, EventContext] = {}
        for event in graph.walk():
            context = yield defer.ensureDeferred(
                self.state.compute_event_context(event)
            )
            self.store.register_event_context(event, context)
            context_store[event.event_id] = context

        return context_store

    @defer.inlineCallbacks
    def test_branch_no_conflict(self):
        graph = Graph(
            nodes={
                "START": DictObj(
                    type=EventTypes.Create, state_key="", content={}, depth=1
                ),
                "A": DictObj(type=EventTypes.Message, depth=2),
                "B": DictObj(type=EventTypes.Message, depth=3),
                "C": DictObj(type=EventTypes.Name, state_key="", depth=3),
                "D": DictObj(type=EventTypes.Message, depth=4),
            },
            edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
        )

        context_store = yield self._compute_graph_contexts(graph)

        ctx_c = context_store["C"]
        ctx_d = context_store["D"]

        prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
        self.assertEqual(2, len(prev_state_ids))

        self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
        self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)

    @defer.inlineCallbacks
    def test_branch_basic_conflict(self):
        graph = Graph(
            nodes={
                "START": DictObj(
                    type=EventTypes.Create,
                    state_key="",
                    content={"creator": "@user_id:example.com"},
                    depth=1,
                ),
                "A": DictObj(
                    type=EventTypes.Member,
                    state_key="@user_id:example.com",
                    content={"membership": Membership.JOIN},
                    membership=Membership.JOIN,
                    depth=2,
                ),
                "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
                "C": DictObj(type=EventTypes.Name, state_key="", depth=4),
                "D": DictObj(type=EventTypes.Message, depth=5),
            },
            edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
        )

        context_store = yield self._compute_graph_contexts(graph)

        # C ends up winning the resolution between B and C

        ctx_c = context_store["C"]
        ctx_d = context_store["D"]

        prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
        self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))

        self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
        self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)

    @defer.inlineCallbacks
    def test_branch_have_banned_conflict(self):
        graph = Graph(
            nodes={
                "START": DictObj(
                    type=EventTypes.Create,
                    state_key="",
                    content={"creator": "@user_id:example.com"},
                    depth=1,
                ),
                "A": DictObj(
                    type=EventTypes.Member,
                    state_key="@user_id:example.com",
                    content={"membership": Membership.JOIN},
                    membership=Membership.JOIN,
                    depth=2,
                ),
                "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
                "C": DictObj(
                    type=EventTypes.Member,
                    state_key="@user_id_2:example.com",
                    content={"membership": Membership.BAN},
                    membership=Membership.BAN,
                    depth=4,
                ),
                "D": DictObj(
                    type=EventTypes.Name,
                    state_key="",
                    depth=4,
                    sender="@user_id_2:example.com",
                ),
                "E": DictObj(type=EventTypes.Message, depth=5),
            },
            edges={"A": ["START"], "B": ["A"], "C": ["B"], "D": ["B"], "E": ["C", "D"]},
        )

        context_store = yield self._compute_graph_contexts(graph)

        # C ends up winning the resolution between C and D because bans win over other
        # changes

        ctx_c = context_store["C"]
        ctx_e = context_store["E"]

        prev_state_ids = yield defer.ensureDeferred(ctx_e.get_prev_state_ids())
        self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values()))
        self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
        self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)

    @defer.inlineCallbacks
    def test_branch_have_perms_conflict(self):
        userid1 = "@user_id:example.com"
        userid2 = "@user_id2:example.com"

        nodes = {
            "A1": DictObj(
                type=EventTypes.Create,
                state_key="",
                content={"creator": userid1},
                depth=1,
            ),
            "A2": DictObj(
                type=EventTypes.Member,
                state_key=userid1,
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
            ),
            "A3": DictObj(
                type=EventTypes.Member,
                state_key=userid2,
                content={"membership": Membership.JOIN},
                membership=Membership.JOIN,
            ),
            "A4": DictObj(
                type=EventTypes.PowerLevels,
                state_key="",
                content={
                    "events": {"m.room.name": 50},
                    "users": {userid1: 100, userid2: 60},
                },
            ),
            "A5": DictObj(type=EventTypes.Name, state_key=""),
            "B": DictObj(
                type=EventTypes.PowerLevels,
                state_key="",
                content={"events": {"m.room.name": 50}, "users": {userid2: 30}},
            ),
            "C": DictObj(type=EventTypes.Name, state_key="", sender=userid2),
            "D": DictObj(type=EventTypes.Message),
        }
        edges = {
            "A2": ["A1"],
            "A3": ["A2"],
            "A4": ["A3"],
            "A5": ["A4"],
            "B": ["A5"],
            "C": ["A5"],
            "D": ["B", "C"],
        }
        self._add_depths(nodes, edges)
        graph = Graph(nodes, edges)

        context_store = yield self._compute_graph_contexts(graph)

        # B ends up winning the resolution between B and C because power levels
        # win over other changes.

        ctx_b = context_store["B"]
        ctx_d = context_store["D"]

        prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
        self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values()))

        self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
        self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)

    def _add_depths(self, nodes, edges):
        """Fill in ``depth`` on every node that lacks one, in place.

        A missing depth is one more than the greatest depth among the node's
        prev events (from ``edges``), computed recursively. Nodes without a
        depth must therefore have an entry in ``edges``.
        """

        def _get_depth(ev):
            node = nodes[ev]
            if "depth" not in node:
                prevs = edges[ev]
                depth = max(_get_depth(prev) for prev in prevs) + 1
                node["depth"] = depth
            return node["depth"]

        for n in nodes:
            _get_depth(n)

    @defer.inlineCallbacks
    def test_annotate_with_old_message(self):
        event = create_event(type="test_message", name="event")

        old_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        context = yield defer.ensureDeferred(
            self.state.compute_event_context(event, old_state=old_state)
        )

        prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
        self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())

        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
        self.assertCountEqual(
            (e.event_id for e in old_state), current_state_ids.values()
        )

        self.assertIsNotNone(context.state_group_before_event)
        self.assertEqual(context.state_group_before_event, context.state_group)

    @defer.inlineCallbacks
    def test_annotate_with_old_state(self):
        event = create_event(type="state", state_key="", name="event")

        old_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        context = yield defer.ensureDeferred(
            self.state.compute_event_context(event, old_state=old_state)
        )

        prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
        self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())

        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
        self.assertCountEqual(
            (e.event_id for e in old_state + [event]), current_state_ids.values()
        )

        self.assertIsNotNone(context.state_group_before_event)
        self.assertNotEqual(context.state_group_before_event, context.state_group)
        self.assertEqual(context.state_group_before_event, context.prev_group)
        self.assertEqual({("state", ""): event.event_id}, context.delta_ids)

    @defer.inlineCallbacks
    def test_trivial_annotate_message(self):
        prev_event_id = "prev_event_id"
        event = create_event(
            type="test_message", name="event2", prev_events=[(prev_event_id, {})]
        )

        old_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        group_name = yield defer.ensureDeferred(
            self.store.store_state_group(
                prev_event_id,
                event.room_id,
                None,
                None,
                {(e.type, e.state_key): e.event_id for e in old_state},
            )
        )
        self.store.register_event_id_state_group(prev_event_id, group_name)

        context = yield defer.ensureDeferred(self.state.compute_event_context(event))

        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())

        self.assertEqual(
            {e.event_id for e in old_state}, set(current_state_ids.values())
        )

        self.assertEqual(group_name, context.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_state(self):
        prev_event_id = "prev_event_id"
        event = create_event(
            type="state", state_key="", name="event2", prev_events=[(prev_event_id, {})]
        )

        old_state = [
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        group_name = yield defer.ensureDeferred(
            self.store.store_state_group(
                prev_event_id,
                event.room_id,
                None,
                None,
                {(e.type, e.state_key): e.event_id for e in old_state},
            )
        )
        self.store.register_event_id_state_group(prev_event_id, group_name)

        context = yield defer.ensureDeferred(self.state.compute_event_context(event))

        prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())

        self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values()))

        self.assertIsNotNone(context.state_group)

    @defer.inlineCallbacks
    def test_resolve_message_conflict(self):
        prev_event_id1 = "event_id1"
        prev_event_id2 = "event_id2"
        event = create_event(
            type="test_message",
            name="event3",
            prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
        )

        creation = create_event(type=EventTypes.Create, state_key="")

        old_state_1 = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        old_state_2 = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test3", state_key="2"),
            create_event(type="test4", state_key=""),
        ]

        self.store.register_events(old_state_1)
        self.store.register_events(old_state_2)

        context = yield self._get_context(
            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
        )

        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())

        self.assertEqual(len(current_state_ids), 6)

        self.assertIsNotNone(context.state_group)

    @defer.inlineCallbacks
    def test_resolve_state_conflict(self):
        prev_event_id1 = "event_id1"
        prev_event_id2 = "event_id2"
        event = create_event(
            type="test4",
            state_key="",
            name="event",
            prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
        )

        creation = create_event(type=EventTypes.Create, state_key="")

        old_state_1 = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test1", state_key="2"),
            create_event(type="test2", state_key=""),
        ]

        old_state_2 = [
            creation,
            create_event(type="test1", state_key="1"),
            create_event(type="test3", state_key="2"),
            create_event(type="test4", state_key=""),
        ]

        store = StateGroupStore()
        store.register_events(old_state_1)
        store.register_events(old_state_2)
        self.store.get_events = store.get_events

        context = yield self._get_context(
            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
        )

        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())

        self.assertEqual(len(current_state_ids), 6)

        self.assertIsNotNone(context.state_group)

    @defer.inlineCallbacks
    def test_standard_depth_conflict(self):
        prev_event_id1 = "event_id1"
        prev_event_id2 = "event_id2"
        event = create_event(
            type="test4",
            name="event",
            prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
        )

        member_event = create_event(
            type=EventTypes.Member,
            state_key="@user_id:example.com",
            content={"membership": Membership.JOIN},
        )

        power_levels = create_event(
            type=EventTypes.PowerLevels,
            state_key="",
            content={"users": {"@foo:bar": "100", "@user_id:example.com": "100"}},
        )

        creation = create_event(
            type=EventTypes.Create, state_key="", content={"creator": "@foo:bar"}
        )

        old_state_1 = [
            creation,
            power_levels,
            member_event,
            create_event(type="test1", state_key="1", depth=1),
        ]

        old_state_2 = [
            creation,
            power_levels,
            member_event,
            create_event(type="test1", state_key="1", depth=2),
        ]

        store = StateGroupStore()
        store.register_events(old_state_1)
        store.register_events(old_state_2)
        self.store.get_events = store.get_events

        context = yield self._get_context(
            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
        )

        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())

        self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])

        # Reverse the depth to make sure we are actually using the depths
        # during state resolution.

        old_state_1 = [
            creation,
            power_levels,
            member_event,
            create_event(type="test1", state_key="1", depth=2),
        ]

        old_state_2 = [
            creation,
            power_levels,
            member_event,
            create_event(type="test1", state_key="1", depth=1),
        ]

        store.register_events(old_state_1)
        store.register_events(old_state_2)

        context = yield self._get_context(
            event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
        )

        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())

        self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])

    @defer.inlineCallbacks
    def _get_context(
        self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
    ):
        """Compute the context for ``event`` given two prev events with state.

        Stores one state group per prev event, holding ``old_state_1`` or
        ``old_state_2`` respectively. Each group is registered against its prev
        event ID before ``compute_event_context`` is called.

        Returns (via Deferred):
            The EventContext computed for ``event``.
        """
        sg1 = yield defer.ensureDeferred(
            self.store.store_state_group(
                prev_event_id_1,
                event.room_id,
                None,
                None,
                {(e.type, e.state_key): e.event_id for e in old_state_1},
            )
        )
        self.store.register_event_id_state_group(prev_event_id_1, sg1)

        sg2 = yield defer.ensureDeferred(
            self.store.store_state_group(
                prev_event_id_2,
                event.room_id,
                None,
                None,
                {(e.type, e.state_key): e.event_id for e in old_state_2},
            )
        )
        self.store.register_event_id_state_group(prev_event_id_2, sg2)

        result = yield defer.ensureDeferred(self.state.compute_event_context(event))
        return result