Skip to content
Snippets Groups Projects
test_event_chain.py 24.9 KiB
Newer Older
  • Learn to ignore specific revisions
  • # Copyright 2020 The Matrix.org Foundation C.I.C.
    #
    # Licensed under the Apache License, Version 2.0 (the 'License');
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an 'AS IS' BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    
    from typing import Dict, List, Set, Tuple
    
    
    from twisted.trial import unittest
    
    from synapse.api.constants import EventTypes
    from synapse.api.room_versions import RoomVersions
    from synapse.events import EventBase
    
    from synapse.events.snapshot import EventContext
    
    from synapse.rest import admin
    
    from synapse.rest.client import login, room
    
    from synapse.storage.databases.main.events import _LinkMap
    
    from synapse.types import create_requester
    
    
    from tests.unittest import HomeserverTestCase
    
    
    class EventChainStoreTestCase(HomeserverTestCase):
        def prepare(self, reactor, clock, hs):
            """Keep a handle on the main datastore and reset the stream counter.

            `_next_stream_ordering` is consumed by `persist`, which assigns one
            value per persisted event, starting from 1.
            """
            datastores = hs.get_datastores()
            self.store = datastores.main

            self._next_stream_ordering = 1
    
        def test_simple(self):
            """Test that the example in `docs/auth_chain_difference_algorithm.md`
            works.

            The events are built with explicit auth events, persisted in a single
            batch, and the resulting chain links and reachability are checked.
            """
            event_factory = self.hs.get_event_builder_factory()
            bob = "@creator:test"
            alice = "@alice:test"
            room_id = "!room:test"

            # A `rooms` row must exist so that the chain index gets generated.
            self.get_success(
                self.store.store_room(
                    room_id=room_id,
                    room_creator_user_id="",
                    is_public=True,
                    room_version=RoomVersions.V6,
                )
            )

            def make_event(kind, state_key, sender, tag, auth):
                """Build a V6 state event with no prev events and `auth` as its
                auth events."""
                builder = event_factory.for_room_version(
                    RoomVersions.V6,
                    {
                        "type": kind,
                        "state_key": state_key,
                        "sender": sender,
                        "room_id": room_id,
                        "content": {"tag": tag},
                    },
                )
                return self.get_success(
                    builder.build(
                        prev_event_ids=[],
                        auth_event_ids=[a.event_id for a in auth],
                    )
                )

            create = make_event(EventTypes.Create, "", bob, "create", [])
            bob_join = make_event(EventTypes.Member, bob, bob, "bob_join", [create])
            power = make_event(
                EventTypes.PowerLevels, "", bob, "power", [create, bob_join]
            )
            alice_invite = make_event(
                EventTypes.Member, alice, bob, "alice_invite", [create, bob_join, power]
            )
            alice_join = make_event(
                EventTypes.Member,
                alice,
                alice,
                "alice_join",
                [create, alice_invite, power],
            )
            power_2 = make_event(
                EventTypes.PowerLevels, "", bob, "power_2", [create, bob_join, power]
            )
            bob_join_2 = make_event(
                EventTypes.Member, bob, bob, "bob_join_2", [create, bob_join, power]
            )
            alice_join2 = make_event(
                EventTypes.Member,
                alice,
                alice,
                "alice_join2",
                [create, alice_join, power_2],
            )

            events = [
                create,
                bob_join,
                power,
                alice_invite,
                alice_join,
                bob_join_2,
                power_2,
                alice_join2,
            ]

            # (origin event, target event) pairs we expect a stored link for.
            expected_links = [
                (bob_join, create),
                (power, create),
                (power, bob_join),
                (alice_invite, create),
                (alice_invite, power),
                (alice_invite, bob_join),
                (bob_join_2, power),
                (alice_join2, power_2),
            ]

            self.persist(events)
            chain_map, link_map = self.fetch_chains(events)

            # Exactly the expected links should have been added: same count, and
            # each expected pair present between the right chains.
            additions = list(link_map.get_additions())
            self.assertEqual(len(expected_links), len(additions))

            for origin, target in expected_links:
                origin_chain, origin_seq = chain_map[origin.event_id]
                target_chain, target_seq = chain_map[target.event_id]

                links = list(link_map.get_links_between(origin_chain, target_chain))
                self.assertIn((origin_seq, target_seq), links)

            # Every other event can reach the create event, while the create
            # event cannot reach any of them.
            create_position = chain_map[create.event_id]
            for event in events[1:]:
                position = chain_map[event.event_id]

                self.assertTrue(
                    link_map.exists_path_from(position, create_position),
                )
                self.assertFalse(
                    link_map.exists_path_from(create_position, position),
                )
    
        def test_out_of_order_events(self):
            """Test that we handle persisting events that we don't have the full
            auth chain for yet (which should only happen for out of band memberships).

            The base room is persisted first; the membership events are then
            persisted one at a time with the invite they depend on arriving last.
            """
            event_factory = self.hs.get_event_builder_factory()
            bob = "@creator:test"
            alice = "@alice:test"
            room_id = "!room:test"

            # A `rooms` row must exist so that the chain index gets generated.
            self.get_success(
                self.store.store_room(
                    room_id=room_id,
                    room_creator_user_id="",
                    is_public=True,
                    room_version=RoomVersions.V6,
                )
            )

            def make_event(kind, state_key, sender, tag, auth):
                """Build a V6 state event with no prev events and `auth` as its
                auth events."""
                builder = event_factory.for_room_version(
                    RoomVersions.V6,
                    {
                        "type": kind,
                        "state_key": state_key,
                        "sender": sender,
                        "room_id": room_id,
                        "content": {"tag": tag},
                    },
                )
                return self.get_success(
                    builder.build(
                        prev_event_ids=[],
                        auth_event_ids=[a.event_id for a in auth],
                    )
                )

            # Base room: create, creator join, power levels.
            create = make_event(EventTypes.Create, "", bob, "create", [])
            bob_join = make_event(EventTypes.Member, bob, bob, "bob_join", [create])
            power = make_event(
                EventTypes.PowerLevels, "", bob, "power", [create, bob_join]
            )

            self.persist([create, bob_join, power])

            # An invite and two joins that will be persisted out of order.
            alice_invite = make_event(
                EventTypes.Member, alice, bob, "alice_invite", [create, bob_join, power]
            )
            alice_join = make_event(
                EventTypes.Member,
                alice,
                alice,
                "alice_join",
                [create, alice_invite, power],
            )
            alice_join2 = make_event(
                EventTypes.Member,
                alice,
                alice,
                "alice_join2",
                [create, alice_join, power],
            )

            for late_event in (alice_join, alice_join2, alice_invite):
                self.persist([late_event])

            # The end result should be sane.
            expected_links = [
                (bob_join, create),
                (power, create),
                (power, bob_join),
                (alice_invite, create),
                (alice_invite, power),
                (alice_invite, bob_join),
            ]

            events = [create, bob_join, power, alice_invite, alice_join]
            chain_map, link_map = self.fetch_chains(events)

            # Exactly the expected links should have been added: same count, and
            # each expected pair present between the right chains.
            additions = list(link_map.get_additions())
            self.assertEqual(len(expected_links), len(additions))

            for origin, target in expected_links:
                origin_chain, origin_seq = chain_map[origin.event_id]
                target_chain, target_seq = chain_map[target.event_id]

                links = list(link_map.get_links_between(origin_chain, target_chain))
                self.assertIn((origin_seq, target_seq), links)
    
        def persist(
            self,
            events: List[EventBase],
        ) -> None:
            """Persist the given events and compute their auth chain entries.

            Each event is first given the next value of
            `self._next_stream_ordering`. The events are then stored and the auth
            chain calculation is run for them, both inside one database
            interaction.

            Args:
                events: The events to persist, in the order they should receive
                    stream orderings.
            """

            persist_events_store = self.hs.get_datastores().persist_events

            for e in events:
                e.internal_metadata.stream_ordering = self._next_stream_ordering
                self._next_stream_ordering += 1

            def _persist(txn):
                # We need to persist the events to the events and state_events
                # tables.
                persist_events_store._store_event_txn(
                    txn,
                    [(e, EventContext(self.hs.get_storage_controllers())) for e in events],
                )

                # Actually call the function that calculates the auth chain stuff.
                persist_events_store._persist_event_auth_chain_txn(txn, events)

            self.get_success(
                persist_events_store.db_pool.runInteraction(
                    "_persist",
                    _persist,
                )
            )
    
        def fetch_chains(
            self, events: List[EventBase]
        ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
            """Load the stored chain positions and chain links for `events`.

            Returns:
                A pair of: a map from event ID to (chain ID, sequence number) read
                from `event_auth_chains`, and a `_LinkMap` filled from the
                `event_auth_chain_links` rows whose origin chain is one of those
                chains. Every row is asserted to add a new link to the map.
            """
            event_ids = [e.event_id for e in events]

            # Event ID -> (chain ID, sequence number).
            chain_rows = self.get_success(
                self.store.db_pool.simple_select_many_batch(
                    table="event_auth_chains",
                    column="event_id",
                    iterable=event_ids,
                    retcols=("event_id", "chain_id", "sequence_number"),
                    keyvalues={},
                )
            )

            chain_map = {}
            for chain_row in chain_rows:
                position = (chain_row["chain_id"], chain_row["sequence_number"])
                chain_map[chain_row["event_id"]] = position

            # All links originating from any of the chains found above.
            origin_chain_ids = [position[0] for position in chain_map.values()]
            link_rows = self.get_success(
                self.store.db_pool.simple_select_many_batch(
                    table="event_auth_chain_links",
                    column="origin_chain_id",
                    iterable=origin_chain_ids,
                    retcols=(
                        "origin_chain_id",
                        "origin_sequence_number",
                        "target_chain_id",
                        "target_sequence_number",
                    ),
                    keyvalues={},
                )
            )

            link_map = _LinkMap()
            for link_row in link_rows:
                origin = (link_row["origin_chain_id"], link_row["origin_sequence_number"])
                target = (link_row["target_chain_id"], link_row["target_sequence_number"])

                # A False result would mean a redundant link had been persisted.
                self.assertTrue(link_map.add_link(origin, target))

            return chain_map, link_map
    
    
    class LinkMapTestCase(unittest.TestCase):
        def test_simple(self):
            """Basic tests for the LinkMap.

            Exercises adding links (pre-existing, redundant and new) and the
            lookup helpers on a single map; later steps build on earlier ones.
            """
            links = _LinkMap()

            # A link flagged as not new is queryable but is not reported as an
            # addition.
            links.add_link((1, 1), (2, 1), new=False)
            self.assertCountEqual(links.get_links_between(1, 2), [(1, 1)])
            self.assertCountEqual(links.get_links_from((1, 1)), [(2, 1)])
            self.assertCountEqual(links.get_additions(), [])

            # Reachability checks against that single link.
            self.assertTrue(links.exists_path_from((1, 5), (2, 1)))
            self.assertFalse(links.exists_path_from((1, 5), (2, 2)))
            self.assertTrue(links.exists_path_from((1, 5), (1, 1)))
            self.assertFalse(links.exists_path_from((1, 1), (1, 5)))

            # A redundant link is rejected and leaves the map unchanged.
            was_added = links.add_link((1, 4), (2, 1))
            self.assertFalse(was_added)
            self.assertCountEqual(links.get_links_between(1, 2), [(1, 1)])

            # Non-redundant links are accepted.
            was_added = links.add_link((1, 3), (2, 3))
            self.assertTrue(was_added)
            self.assertCountEqual(links.get_links_between(1, 2), [(1, 1), (3, 3)])

            was_added = links.add_link((2, 5), (1, 3))
            self.assertTrue(was_added)
            self.assertCountEqual(links.get_links_between(2, 1), [(5, 3)])
            self.assertCountEqual(links.get_links_between(1, 2), [(1, 1), (3, 3)])

            # Only the two links added as new show up as additions.
            self.assertCountEqual(links.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])
    
    
    
    class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
    
        servlets = [
            admin.register_servlets,
            room.register_servlets,
            login.register_servlets,
        ]
    
    
        def prepare(self, reactor, clock, hs):
            """Keep the main datastore and set up a logged-in user "foo".

            Stores the user's ID, access token and a requester built from the
            user ID on the test case for use by the tests.
            """
            datastores = hs.get_datastores()
            self.store = datastores.main

            user_id = self.register_user("foo", "pass")
            self.user_id = user_id
            self.token = self.login("foo", "pass")
            self.requester = create_requester(user_id)
    
        def _generate_room(self) -> Tuple[str, List[Set[str]]]:
            """Insert a room without a chain cover index.

            Creates a room, flags it as having no auth chain index, sends two
            state events that share the same prev events (forking the DAG), and
            finally wipes the chain cover tables.

            Returns:
                The room ID and a two-element list holding, for each side of the
                fork, the set of event IDs in the state reported by that event's
                context.
            """

            room_id = self.helper.create_room_as(self.user_id, tok=self.token)

            # Mark the room as not having a chain cover index
            self.get_success(
                self.store.db_pool.simple_update(
                    table="rooms",
                    keyvalues={"room_id": room_id},
                    updatevalues={"has_auth_chain_index": False},
                    desc="test",
                )
            )

            # Create a fork in the DAG with different events.
            event_handler = self.hs.get_event_creation_handler()

            latest_event_ids = self.get_success(
                self.store.get_prev_events_for_room(room_id)
            )

            # Both sides of the fork are built on the same prev events, which are
            # fetched once above, before either event is persisted.
            fork_states: List[Set[str]] = []
            for _ in range(2):
                event, context = self.get_success(
                    event_handler.create_event(
                        self.requester,
                        {
                            "type": "some_state_type",
                            "state_key": "",
                            "content": {},
                            "room_id": room_id,
                            "sender": self.user_id,
                        },
                        prev_event_ids=latest_event_ids,
                    )
                )
                self.get_success(
                    event_handler.handle_new_client_event(
                        self.requester, events_and_context=[(event, context)]
                    )
                )
                fork_states.append(
                    set(self.get_success(context.get_current_state_ids()).values())
                )

            # Delete the chain cover info.
            # NOTE: this clears the tables for every room, not only this one.
            def _delete_tables(txn):
                txn.execute("DELETE FROM event_auth_chains")
                txn.execute("DELETE FROM event_auth_chain_links")

            self.get_success(self.store.db_pool.runInteraction("test", _delete_tables))

            return room_id, fork_states
    
        def test_background_update_single_room(self):
            """Test that the background update to calculate auth chains for historic
            rooms works correctly.

            Covers a single room that needs indexing.
            """

            # Create a room
            room_id, states = self._generate_room()

            # Insert and run the background update.
            self.get_success(
                self.store.db_pool.simple_insert(
                    "background_updates",
                    {"update_name": "chain_cover", "progress_json": "{}"},
                )
            )

            # Ugh, have to reset this flag
            self.store.db_pool.updates._all_done = False

            self.wait_for_background_updates()

            # Test that the `has_auth_chain_index` has been set
            self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))

            # Test that calculating the auth chain difference using the newly
            # calculated chain cover works. As in the sibling tests, the
            # interaction gets a description followed by the function and its
            # arguments (room ID, then the state sets).
            self.get_success(
                self.store.db_pool.runInteraction(
                    "test",
                    self.store._get_auth_chain_difference_using_cover_index_txn,
                    room_id,
                    states,
                )
            )
    
        def test_background_update_multiple_rooms(self):
            """Test that the background update to calculate auth chains for historic
            rooms works correctly.

            Covers three rooms that all need indexing; each room's flag and
            auth chain difference calculation is checked.
            """
            # Create the rooms, keeping each room's ID together with its states.
            rooms = [self._generate_room() for _ in range(3)]

            # Insert and run the background update.
            self.get_success(
                self.store.db_pool.simple_insert(
                    "background_updates",
                    {"update_name": "chain_cover", "progress_json": "{}"},
                )
            )

            # Ugh, have to reset this flag
            self.store.db_pool.updates._all_done = False

            self.wait_for_background_updates()

            # Test that the `has_auth_chain_index` has been set
            for room_id, _ in rooms:
                self.assertTrue(
                    self.get_success(self.store.has_auth_chain_index(room_id))
                )

            # Test that calculating the auth chain difference using the newly
            # calculated chain cover works, for every room rather than just the
            # first one.
            for room_id, states in rooms:
                self.get_success(
                    self.store.db_pool.runInteraction(
                        "test",
                        self.store._get_auth_chain_difference_using_cover_index_txn,
                        room_id,
                        states,
                    )
                )
    
    
        def test_background_update_single_large_room(self):
            """Test that the background update to calculate auth chains for historic
            rooms works correctly.

            Covers one room with enough state that the update needs more than one
            iteration to finish.
            """

            # Create a room
            room_id, states = self._generate_room()

            # Pile on state events so a single background update iteration cannot
            # cover the whole room.
            for index in range(150):
                self.helper.send_state(
                    room_id, event_type="m.test", body={"index": index}, tok=self.token
                )

            # Queue the background update.
            self.get_success(
                self.store.db_pool.simple_insert(
                    "background_updates",
                    {"update_name": "chain_cover", "progress_json": "{}"},
                )
            )

            # Ugh, have to reset this flag
            self.store.db_pool.updates._all_done = False

            # Drive the update by hand, counting how many steps it takes.
            iterations = 0
            while True:
                finished = self.get_success(
                    self.store.db_pool.updates.has_completed_background_updates()
                )
                if finished:
                    break

                iterations += 1
                self.get_success(
                    self.store.db_pool.updates.do_next_background_update(False), by=0.1
                )

            # The room must really have needed more than one iteration.
            self.assertGreater(iterations, 1)

            # Test that the `has_auth_chain_index` has been set
            has_index = self.get_success(self.store.has_auth_chain_index(room_id))
            self.assertTrue(has_index)

            # The auth chain difference calculation must run against the newly
            # built chain cover.
            self.get_success(
                self.store.db_pool.runInteraction(
                    "test",
                    self.store._get_auth_chain_difference_using_cover_index_txn,
                    room_id,
                    states,
                )
            )
    
        def test_background_update_multiple_large_room(self):
            """Test that the background update to calculate auth chains for historic
            rooms works correctly.

            Covers two rooms, each with enough state that the update needs more
            than one iteration to finish.
            """

            # Create the rooms
            room_id1, _ = self._generate_room()
            room_id2, _ = self._generate_room()

            # Pile state events onto each room in turn (all of room 1, then all
            # of room 2) so one background update iteration cannot cover them.
            for target_room in (room_id1, room_id2):
                for index in range(150):
                    self.helper.send_state(
                        target_room,
                        event_type="m.test",
                        body={"index": index},
                        tok=self.token,
                    )

            # Queue the background update.
            self.get_success(
                self.store.db_pool.simple_insert(
                    "background_updates",
                    {"update_name": "chain_cover", "progress_json": "{}"},
                )
            )

            # Ugh, have to reset this flag
            self.store.db_pool.updates._all_done = False

            # Drive the update by hand, counting how many steps it takes.
            iterations = 0
            while True:
                finished = self.get_success(
                    self.store.db_pool.updates.has_completed_background_updates()
                )
                if finished:
                    break

                iterations += 1
                self.get_success(
                    self.store.db_pool.updates.do_next_background_update(False), by=0.1
                )

            # The rooms must really have needed more than one iteration.
            self.assertGreater(iterations, 1)

            # Test that the `has_auth_chain_index` has been set
            for indexed_room in (room_id1, room_id2):
                has_index = self.get_success(
                    self.store.has_auth_chain_index(indexed_room)
                )
                self.assertTrue(has_index)