From 27dbb1b4290b9de64e24a11f892777378810b595 Mon Sep 17 00:00:00 2001
From: Erik Johnston <erikj@element.io>
Date: Mon, 3 Feb 2025 18:58:55 +0100
Subject: [PATCH] Add locking to more safely delete state groups: Part 2
 (#18130)

This actually makes it so that deleting state groups goes via the new
mechanism.

c.f. #18107
---
 changelog.d/18130.bugfix                    |  1 +
 synapse/storage/controllers/purge_events.py | 80 +++++++++++++++++++--
 synapse/storage/databases/state/deletion.py | 65 +++++++++++++++++
 synapse/storage/databases/state/store.py    | 28 ++++++--
 tests/rest/client/utils.py                  |  2 +-
 tests/storage/test_purge.py                 | 67 +++++++++++++++++
 tests/storage/test_state_deletion.py        | 68 +++++++++++++++++-
 7 files changed, 297 insertions(+), 14 deletions(-)
 create mode 100644 changelog.d/18130.bugfix

diff --git a/changelog.d/18130.bugfix b/changelog.d/18130.bugfix
new file mode 100644
index 0000000000..4d0c19fab9
--- /dev/null
+++ b/changelog.d/18130.bugfix
@@ -0,0 +1 @@
+Fix rare edge case where state groups could be deleted while we are persisting new events that reference them.
diff --git a/synapse/storage/controllers/purge_events.py b/synapse/storage/controllers/purge_events.py
index 15c04ffef8..2d6f80f770 100644
--- a/synapse/storage/controllers/purge_events.py
+++ b/synapse/storage/controllers/purge_events.py
@@ -21,9 +21,10 @@
 
 import itertools
 import logging
-from typing import TYPE_CHECKING, Set
+from typing import TYPE_CHECKING, Collection, Mapping, Set
 
 from synapse.logging.context import nested_logging_context
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage.databases import Databases
 
 if TYPE_CHECKING:
@@ -38,6 +39,11 @@ class PurgeEventsStorageController:
     def __init__(self, hs: "HomeServer", stores: Databases):
         self.stores = stores
 
+        if hs.config.worker.run_background_tasks:
+            self._delete_state_loop_call = hs.get_clock().looping_call(
+                self._delete_state_groups_loop, 60 * 1000
+            )
+
     async def purge_room(self, room_id: str) -> None:
         """Deletes all record of a room"""
 
@@ -68,11 +74,15 @@ class PurgeEventsStorageController:
             logger.info("[purge] finding state groups that can be deleted")
             sg_to_delete = await self._find_unreferenced_groups(state_groups)
 
-            await self.stores.state.purge_unreferenced_state_groups(
-                room_id, sg_to_delete
+            # Mark these state groups as pending deletion, they will actually
+            # get deleted automatically later.
+            await self.stores.state_deletion.mark_state_groups_as_pending_deletion(
+                sg_to_delete
             )
 
-    async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]:
+    async def _find_unreferenced_groups(
+        self, state_groups: Collection[int]
+    ) -> Set[int]:
         """Used when purging history to figure out which state groups can be
         deleted.
 
@@ -121,3 +131,65 @@ class PurgeEventsStorageController:
         to_delete = state_groups_seen - referenced_groups
 
         return to_delete
+
+    @wrap_as_background_process("_delete_state_groups_loop")
+    async def _delete_state_groups_loop(self) -> None:
+        """Background task that deletes any state groups that may be pending
+        deletion."""
+
+        while True:
+            next_to_delete = await self.stores.state_deletion.get_next_state_group_collection_to_delete()
+            if next_to_delete is None:
+                break
+
+            (room_id, groups_to_sequences) = next_to_delete
+            made_progress = await self._delete_state_groups(
+                room_id, groups_to_sequences
+            )
+
+            # If no progress was made in deleting the state groups, then we
+            # break to allow a pause before trying again next time we get
+            # called.
+            if not made_progress:
+                break
+
+    async def _delete_state_groups(
+        self, room_id: str, groups_to_sequences: Mapping[int, int]
+    ) -> bool:
+        """Tries to delete the given state groups.
+
+        Returns:
+            Whether we made progress in deleting the state groups (or marking
+            them as referenced).
+        """
+
+        # We double check if any of the state groups have become referenced.
+        # This shouldn't happen, as any usages should cause the state group to
+        # be removed as pending deletion.
+        referenced_state_groups = await self.stores.main.get_referenced_state_groups(
+            groups_to_sequences
+        )
+
+        if referenced_state_groups:
+            # We mark any state groups that have become referenced as being
+            # used.
+            await self.stores.state_deletion.mark_state_groups_as_used(
+                referenced_state_groups
+            )
+
+            # Update list of state groups to remove referenced ones
+            groups_to_sequences = {
+                state_group: sequence_number
+                for state_group, sequence_number in groups_to_sequences.items()
+                if state_group not in referenced_state_groups
+            }
+
+        if not groups_to_sequences:
+            # We made progress here as long as we marked some state groups as
+            # now referenced.
+            return len(referenced_state_groups) > 0
+
+        return await self.stores.state.purge_unreferenced_state_groups(
+            room_id,
+            groups_to_sequences,
+        )
diff --git a/synapse/storage/databases/state/deletion.py b/synapse/storage/databases/state/deletion.py
index 07dbbc8e75..4853e5aa2f 100644
--- a/synapse/storage/databases/state/deletion.py
+++ b/synapse/storage/databases/state/deletion.py
@@ -20,6 +20,7 @@ from typing import (
     AsyncIterator,
     Collection,
     Mapping,
+    Optional,
     Set,
     Tuple,
 )
@@ -307,6 +308,17 @@ class StateDeletionDataStore:
             desc="mark_state_groups_as_pending_deletion",
         )
 
+    async def mark_state_groups_as_used(self, state_groups: Collection[int]) -> None:
+        """Mark the given state groups as now being referenced"""
+
+        await self.db_pool.simple_delete_many(
+            table="state_groups_pending_deletion",
+            column="state_group",
+            iterable=state_groups,
+            keyvalues={},
+            desc="mark_state_groups_as_used",
+        )
+
     async def get_pending_deletions(
         self, state_groups: Collection[int]
     ) -> Mapping[int, int]:
@@ -444,3 +456,56 @@ class StateDeletionDataStore:
             can_be_deleted.difference_update(state_group for (state_group,) in txn)
 
         return can_be_deleted
+
+    async def get_next_state_group_collection_to_delete(
+        self,
+    ) -> Optional[Tuple[str, Mapping[int, int]]]:
+        """Get the next set of state groups to try and delete
+
+        Returns:
+            2-tuple of room_id and mapping of state groups to sequence number.
+        """
+        return await self.db_pool.runInteraction(
+            "get_next_state_group_collection_to_delete",
+            self._get_next_state_group_collection_to_delete_txn,
+        )
+
+    def _get_next_state_group_collection_to_delete_txn(
+        self,
+        txn: LoggingTransaction,
+    ) -> Optional[Tuple[str, Mapping[int, int]]]:
+        """Implementation of `get_next_state_group_collection_to_delete`"""
+
+        # We want to return chunks of state groups that were marked for deletion
+        # at the same time (this isn't necessary, just more efficient). We do
+        # this by looking for the oldest insertion_ts, and then pulling out all
+        # rows that have the same insertion_ts (and room ID).
+        now = self._clock.time_msec()
+
+        sql = """
+            SELECT room_id, insertion_ts
+            FROM state_groups_pending_deletion AS sd
+            INNER JOIN state_groups AS sg ON (id = sd.state_group)
+            LEFT JOIN state_groups_persisting AS sp USING (state_group)
+            WHERE insertion_ts < ? AND sp.state_group IS NULL
+            ORDER BY insertion_ts
+            LIMIT 1
+        """
+        txn.execute(sql, (now - self.DELAY_BEFORE_DELETION_MS,))
+        row = txn.fetchone()
+        if not row:
+            return None
+
+        (room_id, insertion_ts) = row
+
+        sql = """
+            SELECT state_group, sequence_number
+            FROM state_groups_pending_deletion AS sd
+            INNER JOIN state_groups AS sg ON (id = sd.state_group)
+            LEFT JOIN state_groups_persisting AS sp USING (state_group)
+            WHERE room_id = ? AND insertion_ts = ? AND sp.state_group IS NULL
+            ORDER BY insertion_ts
+        """
+        txn.execute(sql, (room_id, insertion_ts))
+
+        return room_id, dict(txn)
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 7e986e0601..0f47642ae5 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -22,10 +22,10 @@
 import logging
 from typing import (
     TYPE_CHECKING,
-    Collection,
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
     Set,
     Tuple,
@@ -735,8 +735,10 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         )
 
     async def purge_unreferenced_state_groups(
-        self, room_id: str, state_groups_to_delete: Collection[int]
-    ) -> None:
+        self,
+        room_id: str,
+        state_groups_to_sequence_numbers: Mapping[int, int],
+    ) -> bool:
         """Deletes no longer referenced state groups and de-deltas any state
         groups that reference them.
 
@@ -744,21 +746,31 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             room_id: The room the state groups belong to (must all be in the
                 same room).
             state_groups_to_delete: Set of all state groups to delete.
+
+        Returns:
+            Whether any state groups were actually deleted.
         """
 
-        await self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "purge_unreferenced_state_groups",
             self._purge_unreferenced_state_groups,
             room_id,
-            state_groups_to_delete,
+            state_groups_to_sequence_numbers,
         )
 
     def _purge_unreferenced_state_groups(
         self,
         txn: LoggingTransaction,
         room_id: str,
-        state_groups_to_delete: Collection[int],
-    ) -> None:
+        state_groups_to_sequence_numbers: Mapping[int, int],
+    ) -> bool:
+        state_groups_to_delete = self._state_deletion_store.get_state_groups_ready_for_potential_deletion_txn(
+            txn, state_groups_to_sequence_numbers
+        )
+
+        if not state_groups_to_delete:
+            return False
+
         logger.info(
             "[purge] found %i state groups to delete", len(state_groups_to_delete)
         )
@@ -821,6 +833,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             [(sg,) for sg in state_groups_to_delete],
         )
 
+        return True
+
     @trace
     @tag_args
     async def get_previous_state_groups(
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index dbd6049f9f..e766630afb 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -548,7 +548,7 @@ class RestHelper:
         room_id: str,
         event_type: str,
         body: Dict[str, Any],
-        tok: Optional[str],
+        tok: Optional[str] = None,
         expect_code: int = HTTPStatus.OK,
         state_key: str = "",
     ) -> JsonDict:
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 080d5640a5..efd8d25bd1 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -23,6 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.rest.client import room
 from synapse.server import HomeServer
+from synapse.types.state import StateFilter
 from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase
@@ -40,6 +41,8 @@ class PurgeTests(HomeserverTestCase):
         self.room_id = self.helper.create_room_as(self.user_id)
 
         self.store = hs.get_datastores().main
+        self.state_store = hs.get_datastores().state
+        self.state_deletion_store = hs.get_datastores().state_deletion
         self._storage_controllers = self.hs.get_storage_controllers()
 
     def test_purge_history(self) -> None:
@@ -128,3 +131,67 @@ class PurgeTests(HomeserverTestCase):
         self.store._invalidate_local_get_event_cache(create_event.event_id)
         self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
         self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
+
+    def test_purge_history_deletes_state_groups(self) -> None:
+        """Test that unreferenced state groups get cleaned up after purge"""
+
+        # Send four state changes to the room.
+        first = self.helper.send_state(
+            self.room_id, event_type="m.foo", body={"test": 1}
+        )
+        second = self.helper.send_state(
+            self.room_id, event_type="m.foo", body={"test": 2}
+        )
+        third = self.helper.send_state(
+            self.room_id, event_type="m.foo", body={"test": 3}
+        )
+        last = self.helper.send_state(
+            self.room_id, event_type="m.foo", body={"test": 4}
+        )
+
+        # Get references to the state groups
+        event_to_groups = self.get_success(
+            self.store._get_state_group_for_events(
+                [
+                    first["event_id"],
+                    second["event_id"],
+                    third["event_id"],
+                    last["event_id"],
+                ]
+            )
+        )
+
+        # Get the topological token
+        token = self.get_success(
+            self.store.get_topological_token_for_event(last["event_id"])
+        )
+        token_str = self.get_success(token.to_string(self.hs.get_datastores().main))
+
+        # Purge everything before this topological token
+        self.get_success(
+            self._storage_controllers.purge_events.purge_history(
+                self.room_id, token_str, True
+            )
+        )
+
+        # Advance so that the background jobs to delete the state groups runs
+        self.reactor.advance(
+            1 + self.state_deletion_store.DELAY_BEFORE_DELETION_MS / 1000
+        )
+
+        # We expect all the state groups associated with events above, except
+        # the last one, should return no state.
+        state_groups = self.get_success(
+            self.state_store._get_state_groups_from_groups(
+                list(event_to_groups.values()), StateFilter.all()
+            )
+        )
+        first_state = state_groups[event_to_groups[first["event_id"]]]
+        second_state = state_groups[event_to_groups[second["event_id"]]]
+        third_state = state_groups[event_to_groups[third["event_id"]]]
+        last_state = state_groups[event_to_groups[last["event_id"]]]
+
+        self.assertEqual(first_state, {})
+        self.assertEqual(second_state, {})
+        self.assertEqual(third_state, {})
+        self.assertNotEqual(last_state, {})
diff --git a/tests/storage/test_state_deletion.py b/tests/storage/test_state_deletion.py
index 19b290b554..a4d318ae20 100644
--- a/tests/storage/test_state_deletion.py
+++ b/tests/storage/test_state_deletion.py
@@ -41,6 +41,11 @@ class StateDeletionStoreTestCase(HomeserverTestCase):
         self.store = hs.get_datastores().main
         self.state_store = hs.get_datastores().state
         self.state_deletion_store = hs.get_datastores().state_deletion
+        self.purge_events = hs.get_storage_controllers().purge_events
+
+        # We want to disable the automatic deletion of state groups in the
+        # background, so we can do controlled tests.
+        self.purge_events._delete_state_loop_call.stop()
 
         self.user_id = self.register_user("test", "password")
         tok = self.login("test", "password")
@@ -341,7 +346,7 @@ class StateDeletionStoreTestCase(HomeserverTestCase):
 
     def test_remove_ancestors_from_can_delete(self) -> None:
         """Test that if a state group is not ready to be deleted, we also don't
-        delete anything that is refernced by it"""
+        delete anything that is referenced by it"""
 
         event, context = self.get_success(
             create_event(
@@ -354,7 +359,7 @@ class StateDeletionStoreTestCase(HomeserverTestCase):
         )
         assert context.state_group is not None
 
-        # Create a new state group that refernces the one from the event
+        # Create a new state group that references the one from the event
         new_state_group = self.get_success(
             self.state_store.store_state_group(
                 event.event_id,
@@ -409,3 +414,62 @@ class StateDeletionStoreTestCase(HomeserverTestCase):
             )
         )
         self.assertCountEqual(can_be_deleted, [])
+
+    def test_newly_referenced_state_group_gets_removed_from_pending(self) -> None:
+        """Check that if a state group marked for deletion becomes referenced
+        (without being removed from pending deletion table), it gets removed
+        from pending deletion table."""
+
+        event, context = self.get_success(
+            create_event(
+                self.hs,
+                room_id=self.room_id,
+                type="m.test",
+                state_key="",
+                sender=self.user_id,
+            )
+        )
+        assert context.state_group is not None
+
+        # Mark a state group that we're referencing as pending deletion.
+        self.get_success(
+            self.state_deletion_store.mark_state_groups_as_pending_deletion(
+                [context.state_group]
+            )
+        )
+
+        # Advance time enough so we can delete the state group so they're both
+        # ready for deletion.
+        self.reactor.advance(
+            1 + self.state_deletion_store.DELAY_BEFORE_DELETION_MS / 1000
+        )
+
+        # Manually insert into the table to mimic the state group getting used.
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                table="event_to_state_groups",
+                values={"state_group": context.state_group, "event_id": event.event_id},
+                desc="test_newly_referenced_state_group_gets_removed_from_pending",
+            )
+        )
+
+        # Manually run the background task to delete pending state groups.
+        self.get_success(self.purge_events._delete_state_groups_loop())
+
+        # The pending deletion flag should be cleared...
+        pending_deletion = self.get_success(
+            self.state_deletion_store.db_pool.simple_select_one_onecol(
+                table="state_groups_pending_deletion",
+                keyvalues={"state_group": context.state_group},
+                retcol="1",
+                allow_none=True,
+                desc="test_newly_referenced_state_group_gets_removed_from_pending",
+            )
+        )
+        self.assertIsNone(pending_deletion)
+
+        # .. but the state should not have been deleted.
+        state = self.get_success(
+            self.state_store._get_state_for_groups([context.state_group])
+        )
+        self.assertGreater(len(state[context.state_group]), 0)
-- 
GitLab