From b732d13d4cf6e6b474980b16b4920eebebe4b197 Mon Sep 17 00:00:00 2001
From: Erik Johnston <erikj@element.io>
Date: Wed, 11 Sep 2024 15:38:46 +0100
Subject: [PATCH] Sliding sync: various fixups to the background update
 (#17652)

---
 changelog.d/17652.misc                        |   1 +
 synapse/storage/databases/main/events.py      |  45 ++++-
 .../databases/main/events_bg_updates.py       | 190 +++++++++++++-----
 tests/storage/test_sliding_sync_tables.py     | 130 ------------
 4 files changed, 186 insertions(+), 180 deletions(-)
 create mode 100644 changelog.d/17652.misc

diff --git a/changelog.d/17652.misc b/changelog.d/17652.misc
new file mode 100644
index 0000000000..756918e2b2
--- /dev/null
+++ b/changelog.d/17652.misc
@@ -0,0 +1 @@
+Pre-populate room data used in experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint for quick filtering/sorting.
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index e5f63019fd..c0b7d8107d 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1980,7 +1980,12 @@ class PersistEventsStore:
             if state_key == (EventTypes.Create, ""):
                 room_type = event.content.get(EventContentFields.ROOM_TYPE)
                 # Scrutinize JSON values
-                if room_type is None or isinstance(room_type, str):
+                if room_type is None or (
+                    isinstance(room_type, str)
+                    # We ignore values with null bytes as Postgres doesn't allow them in
+                    # text columns.
+                    and "\0" not in room_type
+                ):
                     sliding_sync_insert_map["room_type"] = room_type
             elif state_key == (EventTypes.RoomEncryption, ""):
                 encryption_algorithm = event.content.get(
@@ -1990,15 +1995,26 @@ class PersistEventsStore:
                 sliding_sync_insert_map["is_encrypted"] = is_encrypted
             elif state_key == (EventTypes.Name, ""):
                 room_name = event.content.get(EventContentFields.ROOM_NAME)
-                # Scrutinize JSON values
-                if room_name is None or isinstance(room_name, str):
+                # Scrutinize JSON values. We ignore values containing null bytes
+                # as Postgres doesn't allow them in text columns.
+                if room_name is None or (
+                    isinstance(room_name, str)
+                    # We ignore values with null bytes as Postgres doesn't allow them in
+                    # text columns.
+                    and "\0" not in room_name
+                ):
                     sliding_sync_insert_map["room_name"] = room_name
             elif state_key == (EventTypes.Tombstone, ""):
                 successor_room_id = event.content.get(
                     EventContentFields.TOMBSTONE_SUCCESSOR_ROOM
                 )
                 # Scrutinize JSON values
-                if successor_room_id is None or isinstance(successor_room_id, str):
+                if successor_room_id is None or (
+                    isinstance(successor_room_id, str)
+                    # We ignore values with null bytes as Postgres doesn't allow them in
+                    # text columns.
+                    and "\0" not in successor_room_id
+                ):
                     sliding_sync_insert_map["tombstone_successor_room_id"] = (
                         successor_room_id
                     )
@@ -2081,6 +2097,21 @@ class PersistEventsStore:
                     else None
                 )
 
+                # Check for null bytes in the room name and type. We have to
+                # ignore values with null bytes as Postgres doesn't allow them
+                # in text columns.
+                if (
+                    sliding_sync_insert_map["room_name"] is not None
+                    and "\0" in sliding_sync_insert_map["room_name"]
+                ):
+                    sliding_sync_insert_map.pop("room_name")
+
+                if (
+                    sliding_sync_insert_map["room_type"] is not None
+                    and "\0" in sliding_sync_insert_map["room_type"]
+                ):
+                    sliding_sync_insert_map.pop("room_type")
+
                 # Find the tombstone_successor_room_id
                 # Note: This isn't one of the stripped state events according to the spec
                 # but seems like there is no reason not to support this kind of thing.
@@ -2095,6 +2126,12 @@ class PersistEventsStore:
                     else None
                 )
 
+                if (
+                    sliding_sync_insert_map["tombstone_successor_room_id"] is not None
+                    and "\0" in sliding_sync_insert_map["tombstone_successor_room_id"]
+                ):
+                    sliding_sync_insert_map.pop("tombstone_successor_room_id")
+
             else:
                 # No stripped state provided
                 sliding_sync_insert_map["has_known_state"] = False
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index b3244f7457..743200471b 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -47,6 +47,7 @@ from synapse.storage.databases.main.events_worker import (
 )
 from synapse.storage.databases.main.state_deltas import StateDeltasStore
 from synapse.storage.databases.main.stream import StreamWorkerStore
+from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Cursor
 from synapse.types import JsonDict, RoomStreamToken, StateMap, StrCollection
 from synapse.types.handlers import SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES
@@ -1877,9 +1878,29 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS
         def _find_memberships_to_update_txn(
             txn: LoggingTransaction,
         ) -> List[
-            Tuple[str, Optional[str], str, str, str, str, int, Optional[str], bool]
+            Tuple[
+                str,
+                Optional[str],
+                Optional[str],
+                str,
+                str,
+                str,
+                str,
+                int,
+                Optional[str],
+                bool,
+            ]
         ]:
             # Fetch the set of event IDs that we want to update
+            #
+            # We skip over rows which we've already handled, i.e. have a
+            # matching row in `sliding_sync_membership_snapshots` with the same
+            # room, user and event ID.
+            #
+            # The intent is also to ignore rooms that the user has left themselves
+            # (i.e. not kicked), to avoid porting lots of old rooms that we never
+            # send down sliding sync (they are excluded from initial syncs).
+            # NOTE(review): no such filter is visible in the queries below - confirm.
 
             if initial_phase:
                 # There are some old out-of-band memberships (before
@@ -1892,6 +1913,7 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS
                     SELECT
                         c.room_id,
                         r.room_id,
+                        r.room_version,
                         c.user_id,
                         e.sender,
                         c.event_id,
@@ -1900,9 +1922,11 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS
                         e.instance_name,
                         e.outlier
                     FROM local_current_membership AS c
+                    LEFT JOIN sliding_sync_membership_snapshots AS m USING (room_id, user_id)
                     INNER JOIN events AS e USING (event_id)
                     LEFT JOIN rooms AS r ON (c.room_id = r.room_id)
                     WHERE (c.room_id, c.user_id) > (?, ?)
+                        AND (m.user_id IS NULL OR c.event_id != m.membership_event_id)
                     ORDER BY c.room_id ASC, c.user_id ASC
                     LIMIT ?
                     """,
@@ -1922,7 +1946,8 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS
                     """
                     SELECT
                         c.room_id,
-                        c.room_id,
+                        r.room_id,
+                        r.room_version,
                         c.user_id,
                         e.sender,
                         c.event_id,
@@ -1931,9 +1956,12 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS
                         e.instance_name,
                         e.outlier
                     FROM local_current_membership AS c
+                    LEFT JOIN sliding_sync_membership_snapshots AS m USING (room_id, user_id)
                     INNER JOIN events AS e USING (event_id)
-                    WHERE event_stream_ordering > ?
-                    ORDER BY event_stream_ordering ASC
+                    LEFT JOIN rooms AS r ON (c.room_id = r.room_id)
+                    WHERE c.event_stream_ordering > ?
+                        AND (m.user_id IS NULL OR c.event_id != m.membership_event_id)
+                    ORDER BY c.event_stream_ordering ASC
                     LIMIT ?
                     """,
                     (last_event_stream_ordering, batch_size),
@@ -1944,7 +1972,16 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS
             memberships_to_update_rows = cast(
                 List[
                     Tuple[
-                        str, Optional[str], str, str, str, str, int, Optional[str], bool
+                        str,
+                        Optional[str],
+                        Optional[str],
+                        str,
+                        str,
+                        str,
+                        str,
+                        int,
+                        Optional[str],
+                        bool,
                     ]
                 ],
                 txn.fetchall(),
@@ -1977,7 +2014,7 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS
 
         def _find_previous_invite_or_knock_membership_txn(
             txn: LoggingTransaction, room_id: str, user_id: str, event_id: str
-        ) -> Tuple[str, str]:
+        ) -> Optional[Tuple[str, str]]:
             # Find the previous invite/knock event before the leave event
             #
             # Here are some notes on how we landed on this query:
@@ -2027,8 +2064,13 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS
             )
             row = txn.fetchone()
 
-            # We should see a corresponding previous invite/knock event
-            assert row is not None
+            if row is None:
+                # Generally we should have an invite or knock event for leaves
+                # that are outliers, however this may not always be the case
+                # (e.g. a local user got kicked but the kick event got pulled in
+                # as an outlier).
+                return None
+
             event_id, membership = row
 
             return event_id, membership
@@ -2043,6 +2085,7 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS
         for (
             room_id,
             room_id_from_rooms_table,
+            room_version_id,
             user_id,
             sender,
             membership_event_id,
@@ -2061,6 +2104,14 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS
                 Membership.BAN,
             )
 
+            if (
+                room_version_id is not None
+                and room_version_id not in KNOWN_ROOM_VERSIONS
+            ):
+                # Ignore rooms with unknown room versions (these were
+                # experimental rooms, that we no longer support).
+                continue
+
             # There are some old out-of-band memberships (before
             # https://github.com/matrix-org/synapse/issues/6983) where we don't have the
             # corresponding room stored in the `rooms` table`. We have a `FOREIGN KEY`
@@ -2148,14 +2199,17 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS
                     # in the events table though. We'll just say that we don't
                     # know the state for these rooms and continue on with our
                     # day.
-                    sliding_sync_membership_snapshots_insert_map["has_known_state"] = (
-                        False
-                    )
+                    sliding_sync_membership_snapshots_insert_map = {
+                        "has_known_state": False,
+                        "room_type": None,
+                        "room_name": None,
+                        "is_encrypted": False,
+                    }
             elif membership in (Membership.INVITE, Membership.KNOCK) or (
                 membership in (Membership.LEAVE, Membership.BAN) and is_outlier
             ):
-                invite_or_knock_event_id = membership_event_id
-                invite_or_knock_membership = membership
+                invite_or_knock_event_id = None
+                invite_or_knock_membership = None
 
                 # If the event is an `out_of_band_membership` (special case of
                 # `outlier`), we never had historical state so we have to pull from
@@ -2164,35 +2218,55 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS
                 # membership (i.e. the room shouldn't disappear if your using the
                 # `is_encrypted` filter and you leave).
                 if membership in (Membership.LEAVE, Membership.BAN) and is_outlier:
-                    (
-                        invite_or_knock_event_id,
-                        invite_or_knock_membership,
-                    ) = await self.db_pool.runInteraction(
+                    previous_membership = await self.db_pool.runInteraction(
                         "sliding_sync_membership_snapshots_bg_update._find_previous_invite_or_knock_membership_txn",
                         _find_previous_invite_or_knock_membership_txn,
                         room_id,
                         user_id,
                         membership_event_id,
                     )
+                    if previous_membership is not None:
+                        (
+                            invite_or_knock_event_id,
+                            invite_or_knock_membership,
+                        ) = previous_membership
+                else:
+                    invite_or_knock_event_id = membership_event_id
+                    invite_or_knock_membership = membership
 
-                # Pull from the stripped state on the invite/knock event
-                invite_or_knock_event = await self.get_event(invite_or_knock_event_id)
-
-                raw_stripped_state_events = None
-                if invite_or_knock_membership == Membership.INVITE:
-                    invite_room_state = invite_or_knock_event.unsigned.get(
-                        "invite_room_state"
-                    )
-                    raw_stripped_state_events = invite_room_state
-                elif invite_or_knock_membership == Membership.KNOCK:
-                    knock_room_state = invite_or_knock_event.unsigned.get(
-                        "knock_room_state"
+                if (
+                    invite_or_knock_event_id is not None
+                    and invite_or_knock_membership is not None
+                ):
+                    # Pull from the stripped state on the invite/knock event
+                    invite_or_knock_event = await self.get_event(
+                        invite_or_knock_event_id
                     )
-                    raw_stripped_state_events = knock_room_state
 
-                sliding_sync_membership_snapshots_insert_map = PersistEventsStore._get_sliding_sync_insert_values_from_stripped_state(
-                    raw_stripped_state_events
-                )
+                    raw_stripped_state_events = None
+                    if invite_or_knock_membership == Membership.INVITE:
+                        invite_room_state = invite_or_knock_event.unsigned.get(
+                            "invite_room_state"
+                        )
+                        raw_stripped_state_events = invite_room_state
+                    elif invite_or_knock_membership == Membership.KNOCK:
+                        knock_room_state = invite_or_knock_event.unsigned.get(
+                            "knock_room_state"
+                        )
+                        raw_stripped_state_events = knock_room_state
+
+                    sliding_sync_membership_snapshots_insert_map = PersistEventsStore._get_sliding_sync_insert_values_from_stripped_state(
+                        raw_stripped_state_events
+                    )
+                else:
+                    # We couldn't find any state for the membership, so we just have to
+                    # leave it as empty.
+                    sliding_sync_membership_snapshots_insert_map = {
+                        "has_known_state": False,
+                        "room_type": None,
+                        "room_name": None,
+                        "is_encrypted": False,
+                    }
 
                 # We should have some insert values for each room, even if no
                 # stripped state is on the event because we still want to record
@@ -2311,19 +2385,42 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS
                 )
                 # We need to find the `forgotten` value during the transaction because
                 # we can't risk inserting stale data.
-                txn.execute(
-                    """
-                    UPDATE sliding_sync_membership_snapshots
-                    SET
-                        forgotten = (SELECT forgotten FROM room_memberships WHERE event_id = ?)
-                    WHERE room_id = ? and user_id = ?
-                    """,
-                    (
-                        membership_event_id,
-                        room_id,
-                        user_id,
-                    ),
-                )
+                if isinstance(txn.database_engine, PostgresEngine):
+                    txn.execute(
+                        """
+                        UPDATE sliding_sync_membership_snapshots
+                        SET
+                            forgotten = m.forgotten
+                        FROM room_memberships AS m
+                        WHERE sliding_sync_membership_snapshots.room_id = ?
+                            AND sliding_sync_membership_snapshots.user_id = ?
+                            AND membership_event_id = ?
+                            AND membership_event_id = m.event_id
+                            AND m.event_id IS NOT NULL
+                        """,
+                        (
+                            room_id,
+                            user_id,
+                            membership_event_id,
+                        ),
+                    )
+                else:
+                    # SQLite doesn't support UPDATE FROM before 3.33.0, so we do
+                    # this via sub-selects.
+                    txn.execute(
+                        """
+                        UPDATE sliding_sync_membership_snapshots
+                        SET
+                            forgotten = (SELECT forgotten FROM room_memberships WHERE event_id = ?)
+                        WHERE room_id = ? and user_id = ? AND membership_event_id = ?
+                        """,
+                        (
+                            membership_event_id,
+                            room_id,
+                            user_id,
+                            membership_event_id,
+                        ),
+                    )
 
         await self.db_pool.runInteraction(
             "sliding_sync_membership_snapshots_bg_update", _fill_table_txn
@@ -2333,6 +2430,7 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS
         (
             room_id,
             _room_id_from_rooms_table,
+            _room_version_id,
             user_id,
             _sender,
             _membership_event_id,
diff --git a/tests/storage/test_sliding_sync_tables.py b/tests/storage/test_sliding_sync_tables.py
index de80ad53cd..61dccc8077 100644
--- a/tests/storage/test_sliding_sync_tables.py
+++ b/tests/storage/test_sliding_sync_tables.py
@@ -4416,136 +4416,6 @@ class SlidingSyncTablesBackgroundUpdatesTestCase(SlidingSyncTablesTestCaseBase):
             ),
         )
 
-    def test_membership_snapshots_background_update_forgotten_partial(self) -> None:
-        """
-        Test an existing `sliding_sync_membership_snapshots` row is updated with the
-        latest `forgotten` status after the background update passes over it.
-        """
-        user1_id = self.register_user("user1", "pass")
-        user1_tok = self.login(user1_id, "pass")
-        user2_id = self.register_user("user2", "pass")
-        user2_tok = self.login(user2_id, "pass")
-
-        room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
-
-        # User1 joins the room
-        self.helper.join(room_id, user1_id, tok=user1_tok)
-        # User1 leaves the room (we have to leave in order to forget the room)
-        self.helper.leave(room_id, user1_id, tok=user1_tok)
-
-        state_map = self.get_success(
-            self.storage_controllers.state.get_current_state(room_id)
-        )
-
-        # Forget the room
-        channel = self.make_request(
-            "POST",
-            f"/_matrix/client/r0/rooms/{room_id}/forget",
-            content={},
-            access_token=user1_tok,
-        )
-        self.assertEqual(channel.code, 200, channel.result)
-
-        # Clean-up the `sliding_sync_joined_rooms` table as if the forgotten status
-        # never made it into the table.
-        self.get_success(
-            self.store.db_pool.simple_update(
-                table="sliding_sync_membership_snapshots",
-                keyvalues={"room_id": room_id},
-                updatevalues={"forgotten": 0},
-                desc="sliding_sync_membership_snapshots.test_membership_snapshots_background_update_forgotten_partial",
-            )
-        )
-
-        # We should see the partial row that we made in preparation for the test.
-        sliding_sync_membership_snapshots_results = (
-            self._get_sliding_sync_membership_snapshots()
-        )
-        self.assertIncludes(
-            set(sliding_sync_membership_snapshots_results.keys()),
-            {
-                (room_id, user1_id),
-                (room_id, user2_id),
-            },
-            exact=True,
-        )
-        user1_snapshot = _SlidingSyncMembershipSnapshotResult(
-            room_id=room_id,
-            user_id=user1_id,
-            sender=user1_id,
-            membership_event_id=state_map[(EventTypes.Member, user1_id)].event_id,
-            membership=Membership.LEAVE,
-            event_stream_ordering=state_map[
-                (EventTypes.Member, user1_id)
-            ].internal_metadata.stream_ordering,
-            has_known_state=True,
-            room_type=None,
-            room_name=None,
-            is_encrypted=False,
-            tombstone_successor_room_id=None,
-            # Room is *not* forgotten because of our test preparation
-            forgotten=False,
-        )
-        self.assertEqual(
-            sliding_sync_membership_snapshots_results.get((room_id, user1_id)),
-            user1_snapshot,
-        )
-        user2_snapshot = _SlidingSyncMembershipSnapshotResult(
-            room_id=room_id,
-            user_id=user2_id,
-            sender=user2_id,
-            membership_event_id=state_map[(EventTypes.Member, user2_id)].event_id,
-            membership=Membership.JOIN,
-            event_stream_ordering=state_map[
-                (EventTypes.Member, user2_id)
-            ].internal_metadata.stream_ordering,
-            has_known_state=True,
-            room_type=None,
-            room_name=None,
-            is_encrypted=False,
-            tombstone_successor_room_id=None,
-        )
-        self.assertEqual(
-            sliding_sync_membership_snapshots_results.get((room_id, user2_id)),
-            user2_snapshot,
-        )
-
-        # Insert and run the background update.
-        self.get_success(
-            self.store.db_pool.simple_insert(
-                "background_updates",
-                {
-                    "update_name": _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE,
-                    "progress_json": "{}",
-                },
-            )
-        )
-        self.store.db_pool.updates._all_done = False
-        self.wait_for_background_updates()
-
-        # Make sure the table is populated
-        sliding_sync_membership_snapshots_results = (
-            self._get_sliding_sync_membership_snapshots()
-        )
-        self.assertIncludes(
-            set(sliding_sync_membership_snapshots_results.keys()),
-            {
-                (room_id, user1_id),
-                (room_id, user2_id),
-            },
-            exact=True,
-        )
-        # Forgotten status is now updated
-        self.assertEqual(
-            sliding_sync_membership_snapshots_results.get((room_id, user1_id)),
-            attr.evolve(user1_snapshot, forgotten=True),
-        )
-        # Holds the info according to the current state when the user joined
-        self.assertEqual(
-            sliding_sync_membership_snapshots_results.get((room_id, user2_id)),
-            user2_snapshot,
-        )
-
 
 class SlidingSyncTablesCatchUpBackgroundUpdatesTestCase(SlidingSyncTablesTestCaseBase):
     """
-- 
GitLab