From c46d452c7cdf7289ae7e8677b44a88c747761dce Mon Sep 17 00:00:00 2001
From: Erik Johnston <erikj@element.io>
Date: Mon, 3 Feb 2025 20:04:19 +0100
Subject: [PATCH] Fix bug where purging history could lead to increase in disk
 space usage (#18131)

When purging history, we try and delete any state groups that become
unreferenced (i.e. there are no longer any events that directly
reference them). When we delete a state group that is referenced by
another state group, we "de-delta" that state group so that it no longer
refers to the state group that is deleted.

There are two bugs with this approach that we fix here:
1. There is a common pattern where we end up storing two state groups
when persisting a state event: the state before and after the new state
event, where the latter is stored as a delta to the former. When
deleting state groups we only deleted the "new" state and left (and
potentially de-deltaed) the old state. This was due to a bug/typo when
trying to find referenced state groups.
2. There are times where we store unreferenced state groups in the DB,
during the purging of history these would not get rechecked and instead
always de-deltaed. Instead, we should check for this case and delete any
unreferenced state groups rather than de-deltaing them.

The effect of the above bugs is that when purging history we'd end up
with lots of unreferenced state groups that had been de-deltaed (i.e.
stored as the full state). This can lead to dramatic increases in
storage space used.
---
 changelog.d/18131.bugfix                    |  1 +
 synapse/storage/controllers/purge_events.py | 10 +++
 synapse/storage/databases/state/store.py    | 31 ++++++++-
 tests/storage/test_purge.py                 | 75 +++++++++++++++++++++
 4 files changed, 116 insertions(+), 1 deletion(-)
 create mode 100644 changelog.d/18131.bugfix

diff --git a/changelog.d/18131.bugfix b/changelog.d/18131.bugfix
new file mode 100644
index 0000000000..4d0c19fab9
--- /dev/null
+++ b/changelog.d/18131.bugfix
@@ -0,0 +1 @@
+Fix rare edge case where state groups could be deleted while we are persisting new events that reference them.
diff --git a/synapse/storage/controllers/purge_events.py b/synapse/storage/controllers/purge_events.py
index 2d6f80f770..47cec8c469 100644
--- a/synapse/storage/controllers/purge_events.py
+++ b/synapse/storage/controllers/purge_events.py
@@ -128,6 +128,16 @@ class PurgeEventsStorageController:
             next_to_search |= prevs
             state_groups_seen |= prevs
 
+            # We also check to see if anything referencing the state groups are
+            # also unreferenced. This helps ensure that we delete unreferenced
+            # state groups, if we don't then we will de-delta them when we
+            # delete the other state groups leading to increased DB usage.
+            next_edges = await self.stores.state.get_next_state_groups(current_search)
+            nexts = set(next_edges.keys())
+            nexts -= state_groups_seen
+            next_to_search |= nexts
+            state_groups_seen |= nexts
+
         to_delete = state_groups_seen - referenced_groups
 
         return to_delete
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 0f47642ae5..8c7980e719 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -853,7 +853,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             List[Tuple[int, int]],
             await self.db_pool.simple_select_many_batch(
                 table="state_group_edges",
-                column="prev_state_group",
+                column="state_group",
                 iterable=state_groups,
                 keyvalues={},
                 retcols=("state_group", "prev_state_group"),
@@ -863,6 +863,35 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
         return dict(rows)
 
+    @trace
+    @tag_args
+    async def get_next_state_groups(
+        self, state_groups: Iterable[int]
+    ) -> Dict[int, int]:
+        """Fetch the groups that have the given state groups as their previous
+        state groups.
+
+        Args:
+            state_groups
+
+        Returns:
+            A mapping from state group to previous state group.
+        """
+
+        rows = cast(
+            List[Tuple[int, int]],
+            await self.db_pool.simple_select_many_batch(
+                table="state_group_edges",
+                column="prev_state_group",
+                iterable=state_groups,
+                keyvalues={},
+                retcols=("state_group", "prev_state_group"),
+                desc="get_next_state_groups",
+            ),
+        )
+
+        return dict(rows)
+
     async def purge_room_state(self, room_id: str) -> None:
         return await self.db_pool.runInteraction(
             "purge_room_state",
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index efd8d25bd1..5d6a8518c0 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -195,3 +195,78 @@ class PurgeTests(HomeserverTestCase):
         self.assertEqual(second_state, {})
         self.assertEqual(third_state, {})
         self.assertNotEqual(last_state, {})
+
+    def test_purge_unreferenced_state_group(self) -> None:
+        """Test that purging a room also gets rid of unreferenced state groups
+        it encounters during the purge.
+
+        This is important, as otherwise these unreferenced state groups get
+        "de-deltaed" during the purge process, consuming lots of disk space.
+        """
+
+        self.helper.send(self.room_id, body="test1")
+        state1 = self.helper.send_state(
+            self.room_id, "org.matrix.test", body={"number": 2}
+        )
+        state2 = self.helper.send_state(
+            self.room_id, "org.matrix.test", body={"number": 3}
+        )
+        self.helper.send(self.room_id, body="test4")
+        last = self.helper.send(self.room_id, body="test5")
+
+        # Create an unreferenced state group that has a prev group of one of the
+        # to-be-purged events.
+        prev_group = self.get_success(
+            self.store._get_state_group_for_event(state1["event_id"])
+        )
+        unreferenced_state_group = self.get_success(
+            self.state_store.store_state_group(
+                event_id=last["event_id"],
+                room_id=self.room_id,
+                prev_group=prev_group,
+                delta_ids={("org.matrix.test", ""): state2["event_id"]},
+                current_state_ids=None,
+            )
+        )
+
+        # Get the topological token
+        token = self.get_success(
+            self.store.get_topological_token_for_event(last["event_id"])
+        )
+        token_str = self.get_success(token.to_string(self.hs.get_datastores().main))
+
+        # Purge everything before this topological token
+        self.get_success(
+            self._storage_controllers.purge_events.purge_history(
+                self.room_id, token_str, True
+            )
+        )
+
+        # Advance so that the background jobs to delete the state groups runs
+        self.reactor.advance(
+            1 + self.state_deletion_store.DELAY_BEFORE_DELETION_MS / 1000
+        )
+
+        # We expect that the unreferenced state group has been deleted.
+        row = self.get_success(
+            self.state_store.db_pool.simple_select_one_onecol(
+                table="state_groups",
+                keyvalues={"id": unreferenced_state_group},
+                retcol="id",
+                allow_none=True,
+                desc="test_purge_unreferenced_state_group",
+            )
+        )
+        self.assertIsNone(row)
+
+        # We expect there to now only be one state group for the room, which is
+        # the state group of the last event (as the only outlier).
+        state_groups = self.get_success(
+            self.state_store.db_pool.simple_select_onecol(
+                table="state_groups",
+                keyvalues={"room_id": self.room_id},
+                retcol="id",
+                desc="test_purge_unreferenced_state_group",
+            )
+        )
+        self.assertEqual(len(state_groups), 1)
-- 
GitLab