From c2e5e9e67c24264f5a12bf3ceaa9c4e195547d26 Mon Sep 17 00:00:00 2001
From: Eric Eastwood <eric.eastwood@beta.gouv.fr>
Date: Thu, 19 Sep 2024 10:07:18 -0500
Subject: [PATCH] Sliding Sync: Avoid fetching left rooms and add back
 `newly_left` rooms (#17725)

Performance optimization: We can avoid fetching rooms that the user has
left themselves (which could be a significant amount), then only add
back rooms that the user has `newly_left` (left in the token range of an
incremental sync). It's a lot faster to fetch less rooms than fetch them
all and throw them away in most cases. Since the user only leaves a room
(or is state reset out) once in a blue moon, we can avoid a lot of work.

Based on @erikjohnston's branch, erikj/ss_perf


---------

Co-authored-by: Erik Johnston <erik@matrix.org>
---
 changelog.d/17725.misc                        |   1 +
 synapse/handlers/sliding_sync/__init__.py     |  26 +-
 synapse/handlers/sliding_sync/room_lists.py   | 279 ++++++----
 synapse/storage/databases/main/roommember.py  |  46 +-
 synapse/storage/databases/main/state.py       |  20 +-
 synapse/storage/databases/main/stream.py      |   7 +
 .../client/sliding_sync/test_sliding_sync.py  | 479 +++++++++++++++++-
 tests/storage/test_stream.py                  |  85 +++-
 8 files changed, 833 insertions(+), 110 deletions(-)
 create mode 100644 changelog.d/17725.misc

diff --git a/changelog.d/17725.misc b/changelog.d/17725.misc
new file mode 100644
index 0000000000..2a53bb1491
--- /dev/null
+++ b/changelog.d/17725.misc
@@ -0,0 +1 @@
+More efficiently fetch rooms for Sliding Sync.
diff --git a/synapse/handlers/sliding_sync/__init__.py b/synapse/handlers/sliding_sync/__init__.py
index 4010f28607..5206af22ec 100644
--- a/synapse/handlers/sliding_sync/__init__.py
+++ b/synapse/handlers/sliding_sync/__init__.py
@@ -495,6 +495,24 @@ class SlidingSyncHandler:
             room_sync_config.timeline_limit,
         )
 
+        # Handle state resets. For example, if we see
+        # `room_membership_for_user_at_to_token.event_id=None and
+        # room_membership_for_user_at_to_token.membership is not None`, we should
+        # indicate to the client that a state reset happened. Perhaps we should indicate
+        # this by setting `initial: True` and empty `required_state: []`.
+        state_reset_out_of_room = False
+        if (
+            room_membership_for_user_at_to_token.event_id is None
+            and room_membership_for_user_at_to_token.membership is not None
+        ):
+            # We only expect the `event_id` to be `None` if you've been state reset out
+            # of the room (meaning you're no longer in the room). We could put this as
+            # part of the if-statement above but we want to handle every case where
+            # `event_id` is `None`.
+            assert room_membership_for_user_at_to_token.membership is Membership.LEAVE
+
+            state_reset_out_of_room = True
+
         # Determine whether we should limit the timeline to the token range.
         #
         # We should return historical messages (before token range) in the
@@ -527,7 +545,7 @@ class SlidingSyncHandler:
         from_bound = None
         initial = True
         ignore_timeline_bound = False
-        if from_token and not newly_joined:
+        if from_token and not newly_joined and not state_reset_out_of_room:
             room_status = previous_connection_state.rooms.have_sent_room(room_id)
             if room_status.status == HaveSentRoomFlag.LIVE:
                 from_bound = from_token.stream_token.room_key
@@ -732,12 +750,6 @@ class SlidingSyncHandler:
 
             stripped_state.append(strip_event(invite_or_knock_event))
 
-        # TODO: Handle state resets. For example, if we see
-        # `room_membership_for_user_at_to_token.event_id=None and
-        # room_membership_for_user_at_to_token.membership is not None`, we should
-        # indicate to the client that a state reset happened. Perhaps we should indicate
-        # this by setting `initial: True` and empty `required_state`.
-
         # Get the changes to current state in the token range from the
         # `current_state_delta_stream` table.
         #
diff --git a/synapse/handlers/sliding_sync/room_lists.py b/synapse/handlers/sliding_sync/room_lists.py
index 353c491f72..bf19eb735b 100644
--- a/synapse/handlers/sliding_sync/room_lists.py
+++ b/synapse/handlers/sliding_sync/room_lists.py
@@ -56,7 +56,6 @@ from synapse.storage.roommember import (
 )
 from synapse.types import (
     MutableStateMap,
-    PersistedEventPosition,
     RoomStreamToken,
     StateMap,
     StrCollection,
@@ -81,6 +80,12 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+class Sentinel(enum.Enum):
+    # defining a sentinel in this way allows mypy to correctly handle the
+    # type of a dictionary lookup and subsequent type narrowing.
+    UNSET_SENTINEL = object()
+
+
 # Helper definition for the types that we might return. We do this to avoid
 # copying data between types (which can be expensive for many rooms).
 RoomsForUserType = Union[RoomsForUserStateReset, RoomsForUser, RoomsForUserSlidingSync]
@@ -119,12 +124,6 @@ class SlidingSyncInterestedRooms:
     dm_room_ids: AbstractSet[str]
 
 
-class Sentinel(enum.Enum):
-    # defining a sentinel in this way allows mypy to correctly handle the
-    # type of a dictionary lookup and subsequent type narrowing.
-    UNSET_SENTINEL = object()
-
-
 def filter_membership_for_sync(
     *,
     user_id: str,
@@ -221,6 +220,9 @@ class SlidingSyncRoomLists:
         # include rooms that are outside the list ranges.
         all_rooms: Set[str] = set()
 
+        # Note: this won't include rooms the user has left themselves. We add back
+        # `newly_left` rooms below. This is more efficient than fetching all rooms and
+        # then filtering out the old left rooms.
         room_membership_for_user_map = await self.store.get_sliding_sync_rooms_for_user(
             user_id
         )
@@ -262,36 +264,11 @@ class SlidingSyncRoomLists:
                         event_id=change.event_id,
                         event_pos=change.event_pos,
                         room_version_id=change.room_version_id,
-                        # We keep the current state of the room though
+                        # We keep the state of the room though
                         has_known_state=existing_room.has_known_state,
                         room_type=existing_room.room_type,
                         is_encrypted=existing_room.is_encrypted,
                     )
-                else:
-                    # This can happen if we get "state reset" out of the room
-                    # after the `to_token`. In other words, there is no membership
-                    # for the room after the `to_token` but we see membership in
-                    # the token range.
-
-                    # Get the state at the time. Note that room type never changes,
-                    # so we can just get current room type
-                    room_type = await self.store.get_room_type(room_id)
-                    is_encrypted = await self.get_is_encrypted_for_room_at_token(
-                        room_id, to_token.room_key
-                    )
-
-                    # Add back rooms that the user was state-reset out of after `to_token`
-                    room_membership_for_user_map[room_id] = RoomsForUserSlidingSync(
-                        room_id=room_id,
-                        sender=change.sender,
-                        membership=change.membership,
-                        event_id=change.event_id,
-                        event_pos=change.event_pos,
-                        room_version_id=change.room_version_id,
-                        has_known_state=True,
-                        room_type=room_type,
-                        is_encrypted=is_encrypted,
-                    )
 
         (
             newly_joined_room_ids,
@@ -301,44 +278,88 @@ class SlidingSyncRoomLists:
         )
         dm_room_ids = await self._get_dm_rooms_for_user(user_id)
 
-        # Handle state resets in the from -> to token range.
-        state_reset_rooms = (
+        # Add back `newly_left` rooms (rooms left in the from -> to token range).
+        #
+        # We do this because `get_sliding_sync_rooms_for_user(...)` doesn't include
+        # rooms that the user left themselves as it's more efficient to add them back
+        # here than to fetch all rooms and then filter out the old left rooms. The user
+        # only leaves a room once in a blue moon so this barely needs to run.
+        #
+        missing_newly_left_rooms = (
             newly_left_room_map.keys() - room_membership_for_user_map.keys()
         )
-        if state_reset_rooms:
+        if missing_newly_left_rooms:
+            # TODO: It would be nice to avoid these copies
             room_membership_for_user_map = dict(room_membership_for_user_map)
-            for room_id in (
-                newly_left_room_map.keys() - room_membership_for_user_map.keys()
-            ):
-                # Get the state at the time. Note that room type never changes,
-                # so we can just get current room type
-                room_type = await self.store.get_room_type(room_id)
-                is_encrypted = await self.get_is_encrypted_for_room_at_token(
-                    room_id, newly_left_room_map[room_id].to_room_stream_token()
-                )
+            for room_id in missing_newly_left_rooms:
+                newly_left_room_for_user = newly_left_room_map[room_id]
+                # This should be a given
+                assert newly_left_room_for_user.membership == Membership.LEAVE
 
-                room_membership_for_user_map[room_id] = RoomsForUserSlidingSync(
-                    room_id=room_id,
-                    sender=None,
-                    membership=Membership.LEAVE,
-                    event_id=None,
-                    event_pos=newly_left_room_map[room_id],
-                    room_version_id=await self.store.get_room_version_id(room_id),
-                    has_known_state=True,
-                    room_type=room_type,
-                    is_encrypted=is_encrypted,
+                # Add back `newly_left` rooms
+                #
+                # Check for membership and state in the Sliding Sync tables as it's just
+                # another membership
+                newly_left_room_for_user_sliding_sync = (
+                    await self.store.get_sliding_sync_room_for_user(user_id, room_id)
                 )
+                # If the membership exists, it's just a normal user left the room on
+                # their own
+                if newly_left_room_for_user_sliding_sync is not None:
+                    room_membership_for_user_map[room_id] = (
+                        newly_left_room_for_user_sliding_sync
+                    )
+
+                    change = changes.get(room_id)
+                    if change is not None:
+                        # Update room membership events to the point in time of the `to_token`
+                        room_membership_for_user_map[room_id] = RoomsForUserSlidingSync(
+                            room_id=room_id,
+                            sender=change.sender,
+                            membership=change.membership,
+                            event_id=change.event_id,
+                            event_pos=change.event_pos,
+                            room_version_id=change.room_version_id,
+                            # We keep the state of the room though
+                            has_known_state=newly_left_room_for_user_sliding_sync.has_known_state,
+                            room_type=newly_left_room_for_user_sliding_sync.room_type,
+                            is_encrypted=newly_left_room_for_user_sliding_sync.is_encrypted,
+                        )
+
+                # If we are `newly_left` from the room but can't find any membership,
+                # then we have been "state reset" out of the room
+                else:
+                    # Get the state at the time. We can't read from the Sliding Sync
+                    # tables because the user has no membership in the room according to
+                    # the state (thanks to the state reset).
+                    #
+                    # Note: `room_type` never changes, so we can just get current room
+                    # type
+                    room_type = await self.store.get_room_type(room_id)
+                    has_known_state = room_type is not ROOM_UNKNOWN_SENTINEL
+                    if isinstance(room_type, StateSentinel):
+                        room_type = None
+
+                    # Get the encryption status at the time of the token
+                    is_encrypted = await self.get_is_encrypted_for_room_at_token(
+                        room_id,
+                        newly_left_room_for_user.event_pos.to_room_stream_token(),
+                    )
+
+                    room_membership_for_user_map[room_id] = RoomsForUserSlidingSync(
+                        room_id=room_id,
+                        sender=newly_left_room_for_user.sender,
+                        membership=newly_left_room_for_user.membership,
+                        event_id=newly_left_room_for_user.event_id,
+                        event_pos=newly_left_room_for_user.event_pos,
+                        room_version_id=newly_left_room_for_user.room_version_id,
+                        has_known_state=has_known_state,
+                        room_type=room_type,
+                        is_encrypted=is_encrypted,
+                    )
 
         if sync_config.lists:
-            sync_room_map = {
-                room_id: room_membership_for_user
-                for room_id, room_membership_for_user in room_membership_for_user_map.items()
-                if filter_membership_for_sync(
-                    user_id=user_id,
-                    room_membership_for_user=room_membership_for_user,
-                    newly_left=room_id in newly_left_room_map,
-                )
-            }
+            sync_room_map = room_membership_for_user_map
             with start_active_span("assemble_sliding_window_lists"):
                 for list_key, list_config in sync_config.lists.items():
                     # Apply filters
@@ -347,6 +368,7 @@ class SlidingSyncRoomLists:
                         filtered_sync_room_map = await self.filter_rooms_using_tables(
                             user_id,
                             sync_room_map,
+                            previous_connection_state,
                             list_config.filters,
                             to_token,
                             dm_room_ids,
@@ -446,6 +468,9 @@ class SlidingSyncRoomLists:
 
         if sync_config.room_subscriptions:
             with start_active_span("assemble_room_subscriptions"):
+                # TODO: It would be nice to avoid these copies
+                room_membership_for_user_map = dict(room_membership_for_user_map)
+
                 # Find which rooms are partially stated and may need to be filtered out
                 # depending on the `required_state` requested (see below).
                 partial_state_rooms = await self.store.get_partial_rooms()
@@ -454,10 +479,20 @@ class SlidingSyncRoomLists:
                     room_id,
                     room_subscription,
                 ) in sync_config.room_subscriptions.items():
-                    if room_id not in room_membership_for_user_map:
+                    # Check if we have a membership for the room, but didn't pull it out
+                    # above. This could be e.g. a leave that we don't pull out by
+                    # default.
+                    current_room_entry = (
+                        await self.store.get_sliding_sync_room_for_user(
+                            user_id, room_id
+                        )
+                    )
+                    if not current_room_entry:
                         # TODO: Handle rooms the user isn't in.
                         continue
 
+                    room_membership_for_user_map[room_id] = current_room_entry
+
                     all_rooms.add(room_id)
 
                     # Take the superset of the `RoomSyncConfig` for each room.
@@ -471,8 +506,6 @@ class SlidingSyncRoomLists:
                         if room_id in partial_state_rooms:
                             continue
 
-                    all_rooms.add(room_id)
-
                     # Update our `relevant_room_map` with the room we're going to display
                     # and need to fetch more info about.
                     existing_room_sync_config = relevant_room_map.get(room_id)
@@ -487,7 +520,7 @@ class SlidingSyncRoomLists:
 
         # Filtered subset of `relevant_room_map` for rooms that may have updates
         # (in the event stream)
-        relevant_rooms_to_send_map = await self._filter_relevant_room_to_send(
+        relevant_rooms_to_send_map = await self._filter_relevant_rooms_to_send(
             previous_connection_state, from_token, relevant_room_map
         )
 
@@ -544,6 +577,7 @@ class SlidingSyncRoomLists:
                         filtered_sync_room_map = await self.filter_rooms(
                             sync_config.user,
                             sync_room_map,
+                            previous_connection_state,
                             list_config.filters,
                             to_token,
                             dm_room_ids,
@@ -674,7 +708,7 @@ class SlidingSyncRoomLists:
 
         # Filtered subset of `relevant_room_map` for rooms that may have updates
         # (in the event stream)
-        relevant_rooms_to_send_map = await self._filter_relevant_room_to_send(
+        relevant_rooms_to_send_map = await self._filter_relevant_rooms_to_send(
             previous_connection_state, from_token, relevant_room_map
         )
 
@@ -689,7 +723,7 @@ class SlidingSyncRoomLists:
             dm_room_ids=dm_room_ids,
         )
 
-    async def _filter_relevant_room_to_send(
+    async def _filter_relevant_rooms_to_send(
         self,
         previous_connection_state: PerConnectionState,
         from_token: Optional[StreamToken],
@@ -974,8 +1008,17 @@ class SlidingSyncRoomLists:
                 )
             ]
 
-        # If the user has never joined any rooms before, we can just return an empty list
-        if not room_for_user_list:
+        (
+            newly_joined_room_ids,
+            newly_left_room_map,
+        ) = await self._get_newly_joined_and_left_rooms(
+            user_id, to_token=to_token, from_token=from_token
+        )
+
+        # If the user has never joined any rooms before, we can just return an empty
+        # list. We also have to check the `newly_left_room_map` in case someone was
+        # state reset out of all of the rooms they were in.
+        if not room_for_user_list and not newly_left_room_map:
             return {}, set(), set()
 
         # Since we fetched the users room list at some point in time after the
@@ -993,30 +1036,22 @@ class SlidingSyncRoomLists:
             else:
                 rooms_for_user[room_id] = change_room_for_user
 
-        (
-            newly_joined_room_ids,
-            newly_left_room_ids,
-        ) = await self._get_newly_joined_and_left_rooms(
-            user_id, to_token=to_token, from_token=from_token
-        )
-
         # Ensure we have entries for rooms that the user has been "state reset"
         # out of. These are rooms appear in the `newly_left_rooms` map but
         # aren't in the `rooms_for_user` map.
-        for room_id, left_event_pos in newly_left_room_ids.items():
+        for room_id, newly_left_room_for_user in newly_left_room_map.items():
+            # If we already know about the room, it's not a state reset
             if room_id in rooms_for_user:
                 continue
 
-            rooms_for_user[room_id] = RoomsForUserStateReset(
-                room_id=room_id,
-                event_id=None,
-                event_pos=left_event_pos,
-                membership=Membership.LEAVE,
-                sender=None,
-                room_version_id=await self.store.get_room_version_id(room_id),
-            )
+            # This should be true if it's a state reset
+            assert newly_left_room_for_user.membership is Membership.LEAVE
+            assert newly_left_room_for_user.event_id is None
+            assert newly_left_room_for_user.sender is None
+
+            rooms_for_user[room_id] = newly_left_room_for_user
 
-        return rooms_for_user, newly_joined_room_ids, set(newly_left_room_ids)
+        return rooms_for_user, newly_joined_room_ids, set(newly_left_room_map)
 
     @trace
     async def _get_newly_joined_and_left_rooms(
@@ -1024,7 +1059,7 @@ class SlidingSyncRoomLists:
         user_id: str,
         to_token: StreamToken,
         from_token: Optional[StreamToken],
-    ) -> Tuple[AbstractSet[str], Mapping[str, PersistedEventPosition]]:
+    ) -> Tuple[AbstractSet[str], Mapping[str, RoomsForUserStateReset]]:
         """Fetch the sets of rooms that the user newly joined or left in the
         given token range.
 
@@ -1033,11 +1068,18 @@ class SlidingSyncRoomLists:
         "current memberships" of the user.
 
         Returns:
-            A 2-tuple of newly joined room IDs and a map of newly left room
-            IDs to the event position the leave happened at.
+            A 2-tuple of newly joined room IDs and a map of newly_left room
+            IDs to the `RoomsForUserStateReset` entry.
+
+            We're using `RoomsForUserStateReset` but that doesn't necessarily mean the
+            user was state reset of the rooms. It's just that the `event_id`/`sender`
+            are optional and we can't tell the difference between the server leaving the
+            room when the user was the last person participating in the room and left or
+            was state reset out of the room. To actually check for a state reset, you
+            need to check if a membership still exists in the room.
         """
         newly_joined_room_ids: Set[str] = set()
-        newly_left_room_map: Dict[str, PersistedEventPosition] = {}
+        newly_left_room_map: Dict[str, RoomsForUserStateReset] = {}
 
         # We need to figure out the
         #
@@ -1108,8 +1150,13 @@ class SlidingSyncRoomLists:
             # 1) Figure out newly_left rooms (> `from_token` and <= `to_token`).
             if last_membership_change_in_from_to_range.membership == Membership.LEAVE:
                 # 1) Mark this room as `newly_left`
-                newly_left_room_map[room_id] = (
-                    last_membership_change_in_from_to_range.event_pos
+                newly_left_room_map[room_id] = RoomsForUserStateReset(
+                    room_id=room_id,
+                    sender=last_membership_change_in_from_to_range.sender,
+                    membership=Membership.LEAVE,
+                    event_id=last_membership_change_in_from_to_range.event_id,
+                    event_pos=last_membership_change_in_from_to_range.event_pos,
+                    room_version_id=await self.store.get_room_version_id(room_id),
                 )
 
         # 2) Figure out `newly_joined`
@@ -1553,6 +1600,7 @@ class SlidingSyncRoomLists:
         self,
         user: UserID,
         sync_room_map: Dict[str, RoomsForUserType],
+        previous_connection_state: PerConnectionState,
         filters: SlidingSyncConfig.SlidingSyncList.Filters,
         to_token: StreamToken,
         dm_room_ids: AbstractSet[str],
@@ -1738,14 +1786,33 @@ class SlidingSyncRoomLists:
                         )
                     }
 
+        # Keep rooms if the user has been state reset out of it but we previously sent
+        # down the connection before. We want to make sure that we send these down to
+        # the client regardless of filters so they find out about the state reset.
+        #
+        # We don't always have access to the state in a room after being state reset if
+        # no one else locally on the server is participating in the room so we patch
+        # these back in manually.
+        state_reset_out_of_room_id_set = {
+            room_id
+            for room_id in sync_room_map.keys()
+            if sync_room_map[room_id].event_id is None
+            and previous_connection_state.rooms.have_sent_room(room_id).status
+            != HaveSentRoomFlag.NEVER
+        }
+
         # Assemble a new sync room map but only with the `filtered_room_id_set`
-        return {room_id: sync_room_map[room_id] for room_id in filtered_room_id_set}
+        return {
+            room_id: sync_room_map[room_id]
+            for room_id in filtered_room_id_set | state_reset_out_of_room_id_set
+        }
 
     @trace
     async def filter_rooms_using_tables(
         self,
         user_id: str,
         sync_room_map: Mapping[str, RoomsForUserSlidingSync],
+        previous_connection_state: PerConnectionState,
         filters: SlidingSyncConfig.SlidingSyncList.Filters,
         to_token: StreamToken,
         dm_room_ids: AbstractSet[str],
@@ -1887,8 +1954,26 @@ class SlidingSyncRoomLists:
                         )
                     }
 
+        # Keep rooms if the user has been state reset out of it but we previously sent
+        # down the connection before. We want to make sure that we send these down to
+        # the client regardless of filters so they find out about the state reset.
+        #
+        # We don't always have access to the state in a room after being state reset if
+        # no one else locally on the server is participating in the room so we patch
+        # these back in manually.
+        state_reset_out_of_room_id_set = {
+            room_id
+            for room_id in sync_room_map.keys()
+            if sync_room_map[room_id].event_id is None
+            and previous_connection_state.rooms.have_sent_room(room_id).status
+            != HaveSentRoomFlag.NEVER
+        }
+
         # Assemble a new sync room map but only with the `filtered_room_id_set`
-        return {room_id: sync_room_map[room_id] for room_id in filtered_room_id_set}
+        return {
+            room_id: sync_room_map[room_id]
+            for room_id in filtered_room_id_set | state_reset_out_of_room_id_set
+        }
 
     @trace
     async def sort_rooms(
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index ded7948713..0a62613d34 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -1404,7 +1404,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
     ) -> Mapping[str, RoomsForUserSlidingSync]:
         """Get all the rooms for a user to handle a sliding sync request.
 
-        Ignores forgotten rooms and rooms that the user has been kicked from.
+        Ignores forgotten rooms and rooms that the user has left themselves.
 
         Returns:
             Map from room ID to membership info
@@ -1429,6 +1429,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
                 LEFT JOIN sliding_sync_joined_rooms AS j ON (j.room_id = m.room_id AND m.membership = 'join')
                 WHERE user_id = ?
                     AND m.forgotten = 0
+                    AND (m.membership != 'leave' OR m.user_id != m.sender)
             """
             txn.execute(sql, (user_id,))
             return {
@@ -1455,6 +1456,49 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
             get_sliding_sync_rooms_for_user_txn,
         )
 
+    async def get_sliding_sync_room_for_user(
+        self, user_id: str, room_id: str
+    ) -> Optional[RoomsForUserSlidingSync]:
+        """Get the sliding sync room entry for the given user and room."""
+
+        def get_sliding_sync_room_for_user_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[RoomsForUserSlidingSync]:
+            sql = """
+                SELECT m.room_id, m.sender, m.membership, m.membership_event_id,
+                    r.room_version,
+                    m.event_instance_name, m.event_stream_ordering,
+                    m.has_known_state,
+                    COALESCE(j.room_type, m.room_type),
+                    COALESCE(j.is_encrypted, m.is_encrypted)
+                FROM sliding_sync_membership_snapshots AS m
+                INNER JOIN rooms AS r USING (room_id)
+                LEFT JOIN sliding_sync_joined_rooms AS j ON (j.room_id = m.room_id AND m.membership = 'join')
+                WHERE user_id = ?
+                    AND m.forgotten = 0
+                    AND m.room_id = ?
+            """
+            txn.execute(sql, (user_id, room_id))
+            row = txn.fetchone()
+            if not row:
+                return None
+
+            return RoomsForUserSlidingSync(
+                room_id=row[0],
+                sender=row[1],
+                membership=row[2],
+                event_id=row[3],
+                room_version_id=row[4],
+                event_pos=PersistedEventPosition(row[5], row[6]),
+                has_known_state=bool(row[7]),
+                room_type=row[8],
+                is_encrypted=row[9],
+            )
+
+        return await self.db_pool.runInteraction(
+            "get_sliding_sync_room_for_user", get_sliding_sync_room_for_user_txn
+        )
+
 
 class RoomMemberBackgroundUpdateStore(SQLBaseStore):
     def __init__(
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index ca31122ad3..60312d770d 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -308,8 +308,24 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         return create_event
 
     @cached(max_entries=10000)
-    async def get_room_type(self, room_id: str) -> Optional[str]:
-        raise NotImplementedError()
+    async def get_room_type(self, room_id: str) -> Union[Optional[str], Sentinel]:
+        """Fetch room type for given room.
+
+        Since this function is cached, any missing values would be cached as
+        `None`. In order to distinguish between an unencrypted room that has
+        `None` encryption and a room that is unknown to the server where we
+        might want to omit the value (which would make it cached as `None`),
+        instead we use the sentinel value `ROOM_UNKNOWN_SENTINEL`.
+        """
+
+        try:
+            create_event = await self.get_create_event_for_room(room_id)
+            return create_event.content.get(EventContentFields.ROOM_TYPE)
+        except NotFoundError:
+            # We use the sentinel value to distinguish between `None` which is a
+            # valid room type and a room that is unknown to the server so the value
+            # is just unset.
+            return ROOM_UNKNOWN_SENTINEL
 
     @cachedList(cached_method_name="get_room_type", list_name="room_ids")
     async def bulk_get_room_type(
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 0ab7cb8dbd..964f41ca57 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -941,6 +941,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         Returns:
             All membership changes to the current state in the token range. Events are
             sorted by `stream_ordering` ascending.
+
+            `event_id`/`sender` can be `None` when the server leaves a room (meaning
+            everyone locally left) or a state reset which removed the person from the
+            room. We can't tell the difference between the two cases with what's
+            available in the `current_state_delta_stream` table. To actually check for a
+            state reset, you need to check if a membership still exists in the room.
         """
         # Start by ruling out cases where a DB query is not necessary.
         if from_key == to_key:
@@ -1052,6 +1058,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                         membership=(
                             membership if membership is not None else Membership.LEAVE
                         ),
+                        # This will also be null for the same reasons if `s.event_id = null`
                         sender=sender,
                         # Prev event
                         prev_event_id=prev_event_id,
diff --git a/tests/rest/client/sliding_sync/test_sliding_sync.py b/tests/rest/client/sliding_sync/test_sliding_sync.py
index c2cfb29866..ea3ca57957 100644
--- a/tests/rest/client/sliding_sync/test_sliding_sync.py
+++ b/tests/rest/client/sliding_sync/test_sliding_sync.py
@@ -15,7 +15,7 @@ import logging
 from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple
 from unittest.mock import AsyncMock
 
-from parameterized import parameterized_class
+from parameterized import parameterized, parameterized_class
 from typing_extensions import assert_never
 
 from twisted.test.proto_helpers import MemoryReactor
@@ -23,12 +23,16 @@ from twisted.test.proto_helpers import MemoryReactor
 import synapse.rest.admin
 from synapse.api.constants import (
     AccountDataTypes,
+    EventContentFields,
     EventTypes,
+    JoinRules,
     Membership,
+    RoomTypes,
 )
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase, StrippedStateEvent, make_event_from_dict
 from synapse.events.snapshot import EventContext
+from synapse.handlers.sliding_sync import StateValues
 from synapse.rest.client import account_data, devices, login, receipts, room, sync
 from synapse.server import HomeServer
 from synapse.types import (
@@ -43,6 +47,7 @@ from synapse.util.stringutils import random_string
 
 from tests import unittest
 from tests.server import TimedOutException
+from tests.test_utils.event_injection import create_event
 
 logger = logging.getLogger(__name__)
 
@@ -421,6 +426,9 @@ class SlidingSyncTestCase(SlidingSyncBase):
         self.event_sources = hs.get_event_sources()
         self.storage_controllers = hs.get_storage_controllers()
         self.account_data_handler = hs.get_account_data_handler()
+        persistence = self.hs.get_storage_controllers().persistence
+        assert persistence is not None
+        self.persistence = persistence
 
         super().prepare(reactor, clock, hs)
 
@@ -988,3 +996,472 @@ class SlidingSyncTestCase(SlidingSyncBase):
         # Make the Sliding Sync request
         response_body, _ = self.do_sync(sync_body, tok=user1_tok)
         self.assertEqual(response_body["rooms"][room_id1]["initial"], True)
+
+    def test_state_reset_room_comes_down_incremental_sync(self) -> None:
+        """Test that a room that we were state reset out of comes down
+        incremental sync"""
+
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+        user2_id = self.register_user("user2", "pass")
+        user2_tok = self.login(user2_id, "pass")
+
+        room_id1 = self.helper.create_room_as(
+            user2_id,
+            is_public=True,
+            tok=user2_tok,
+            extra_content={
+                "name": "my super room",
+            },
+        )
+
+        # Create an event for us to point back to for the state reset
+        event_response = self.helper.send(room_id1, "test", tok=user2_tok)
+        event_id = event_response["event_id"]
+
+        self.helper.join(room_id1, user1_id, tok=user1_tok)
+
+        sync_body = {
+            "lists": {
+                "foo-list": {
+                    "ranges": [[0, 1]],
+                    "required_state": [
+                        # Request all state just to see what we get back when we are
+                        # state reset out of the room
+                        [StateValues.WILDCARD, StateValues.WILDCARD]
+                    ],
+                    "timeline_limit": 1,
+                }
+            }
+        }
+
+        # Make the Sliding Sync request
+        response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+        # Make sure we see room1
+        self.assertIncludes(set(response_body["rooms"].keys()), {room_id1}, exact=True)
+        self.assertEqual(response_body["rooms"][room_id1]["initial"], True)
+
+        # Trigger a state reset
+        join_rule_event, join_rule_context = self.get_success(
+            create_event(
+                self.hs,
+                prev_event_ids=[event_id],
+                type=EventTypes.JoinRules,
+                state_key="",
+                content={"join_rule": JoinRules.INVITE},
+                sender=user2_id,
+                room_id=room_id1,
+                room_version=self.get_success(self.store.get_room_version_id(room_id1)),
+            )
+        )
+        _, join_rule_event_pos, _ = self.get_success(
+            self.persistence.persist_event(join_rule_event, join_rule_context)
+        )
+
+        # FIXME: We're manually busting the cache since
+        # https://github.com/element-hq/synapse/issues/17368 is not solved yet
+        self.store._membership_stream_cache.entity_has_changed(
+            user1_id, join_rule_event_pos.stream
+        )
+
+        # Ensure that the state reset worked and only user2 is in the room now
+        users_in_room = self.get_success(self.store.get_users_in_room(room_id1))
+        self.assertIncludes(set(users_in_room), {user2_id}, exact=True)
+
+        state_map_at_reset = self.get_success(
+            self.storage_controllers.state.get_current_state(room_id1)
+        )
+
+        # Update the state after user1 was state reset out of the room
+        self.helper.send_state(
+            room_id1,
+            EventTypes.Name,
+            {EventContentFields.ROOM_NAME: "my super duper room"},
+            tok=user2_tok,
+        )
+
+        # Make another Sliding Sync request (incremental)
+        response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+        # Expect to see room1 because it is `newly_left` thanks to being state reset out
+        # of it since the last time we synced. We need to let the client know that
+        # something happened and that they are no longer in the room.
+        self.assertIncludes(set(response_body["rooms"].keys()), {room_id1}, exact=True)
+        # We set `initial=True` to indicate that the client should reset the state they
+        # have about the room
+        self.assertEqual(response_body["rooms"][room_id1]["initial"], True)
+        # They shouldn't see anything past the state reset
+        self._assertRequiredStateIncludes(
+            response_body["rooms"][room_id1]["required_state"],
+            # We should see all the state events in the room
+            state_map_at_reset.values(),
+            exact=True,
+        )
+        # The position where the state reset happened
+        self.assertEqual(
+            response_body["rooms"][room_id1]["bump_stamp"],
+            join_rule_event_pos.stream,
+            response_body["rooms"][room_id1],
+        )
+
+        # Other non-important things. We just want to check what these are so we know
+        # what happens in a state reset scenario.
+        #
+        # Room name was set at the time of the state reset so we should still be able to
+        # see it.
+        self.assertEqual(response_body["rooms"][room_id1]["name"], "my super room")
+        # Could be set but there is no avatar for this room
+        self.assertIsNone(
+            response_body["rooms"][room_id1].get("avatar"),
+            response_body["rooms"][room_id1],
+        )
+        # Could be set but this room isn't marked as a DM
+        self.assertIsNone(
+            response_body["rooms"][room_id1].get("is_dm"),
+            response_body["rooms"][room_id1],
+        )
+        # Empty timeline because we are not in the room at all (they are all being
+        # filtered out)
+        self.assertIsNone(
+            response_body["rooms"][room_id1].get("timeline"),
+            response_body["rooms"][room_id1],
+        )
+        # `limited` since we're not providing any timeline events but there are some in
+        # the room.
+        self.assertEqual(response_body["rooms"][room_id1]["limited"], True)
+        # User is no longer in the room so they can't see this info
+        self.assertIsNone(
+            response_body["rooms"][room_id1].get("joined_count"),
+            response_body["rooms"][room_id1],
+        )
+        self.assertIsNone(
+            response_body["rooms"][room_id1].get("invited_count"),
+            response_body["rooms"][room_id1],
+        )
+
+    def test_state_reset_previously_room_comes_down_incremental_sync_with_filters(
+        self,
+    ) -> None:
+        """
+        Test that a room that we were state reset out of should always be sent down
+        regardless of the filters if it has been sent down the connection before.
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+        user2_id = self.register_user("user2", "pass")
+        user2_tok = self.login(user2_id, "pass")
+
+        # Create a space room
+        space_room_id = self.helper.create_room_as(
+            user2_id,
+            tok=user2_tok,
+            extra_content={
+                "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE},
+                "name": "my super space",
+            },
+        )
+
+        # Create an event for us to point back to for the state reset
+        event_response = self.helper.send(space_room_id, "test", tok=user2_tok)
+        event_id = event_response["event_id"]
+
+        self.helper.join(space_room_id, user1_id, tok=user1_tok)
+
+        sync_body = {
+            "lists": {
+                "foo-list": {
+                    "ranges": [[0, 1]],
+                    "required_state": [
+                        # Request all state just to see what we get back when we are
+                        # state reset out of the room
+                        [StateValues.WILDCARD, StateValues.WILDCARD]
+                    ],
+                    "timeline_limit": 1,
+                    "filters": {
+                        "room_types": [RoomTypes.SPACE],
+                    },
+                }
+            }
+        }
+
+        # Make the Sliding Sync request
+        response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+        # Make sure we see room1
+        self.assertIncludes(
+            set(response_body["rooms"].keys()), {space_room_id}, exact=True
+        )
+        self.assertEqual(response_body["rooms"][space_room_id]["initial"], True)
+
+        # Trigger a state reset
+        join_rule_event, join_rule_context = self.get_success(
+            create_event(
+                self.hs,
+                prev_event_ids=[event_id],
+                type=EventTypes.JoinRules,
+                state_key="",
+                content={"join_rule": JoinRules.INVITE},
+                sender=user2_id,
+                room_id=space_room_id,
+                room_version=self.get_success(
+                    self.store.get_room_version_id(space_room_id)
+                ),
+            )
+        )
+        _, join_rule_event_pos, _ = self.get_success(
+            self.persistence.persist_event(join_rule_event, join_rule_context)
+        )
+
+        # FIXME: We're manually busting the cache since
+        # https://github.com/element-hq/synapse/issues/17368 is not solved yet
+        self.store._membership_stream_cache.entity_has_changed(
+            user1_id, join_rule_event_pos.stream
+        )
+
+        # Ensure that the state reset worked and only user2 is in the room now
+        users_in_room = self.get_success(self.store.get_users_in_room(space_room_id))
+        self.assertIncludes(set(users_in_room), {user2_id}, exact=True)
+
+        state_map_at_reset = self.get_success(
+            self.storage_controllers.state.get_current_state(space_room_id)
+        )
+
+        # Update the state after user1 was state reset out of the room
+        self.helper.send_state(
+            space_room_id,
+            EventTypes.Name,
+            {EventContentFields.ROOM_NAME: "my super duper space"},
+            tok=user2_tok,
+        )
+
+        # User2 also leaves the room so the server is no longer participating in the room
+        # and we don't have access to current state
+        self.helper.leave(space_room_id, user2_id, tok=user2_tok)
+
+        # Make another Sliding Sync request (incremental)
+        response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+        # Expect to see room1 because it is `newly_left` thanks to being state reset out
+        # of it since the last time we synced. We need to let the client know that
+        # something happened and that they are no longer in the room.
+        self.assertIncludes(
+            set(response_body["rooms"].keys()), {space_room_id}, exact=True
+        )
+        # We set `initial=True` to indicate that the client should reset the state they
+        # have about the room
+        self.assertEqual(response_body["rooms"][space_room_id]["initial"], True)
+        # They shouldn't see anything past the state reset
+        self._assertRequiredStateIncludes(
+            response_body["rooms"][space_room_id]["required_state"],
+            # We should see all the state events in the room
+            state_map_at_reset.values(),
+            exact=True,
+        )
+        # The position where the state reset happened
+        self.assertEqual(
+            response_body["rooms"][space_room_id]["bump_stamp"],
+            join_rule_event_pos.stream,
+            response_body["rooms"][space_room_id],
+        )
+
+        # Other non-important things. We just want to check what these are so we know
+        # what happens in a state reset scenario.
+        #
+        # Room name was set at the time of the state reset so we should still be able to
+        # see it.
+        self.assertEqual(
+            response_body["rooms"][space_room_id]["name"], "my super space"
+        )
+        # Could be set but there is no avatar for this room
+        self.assertIsNone(
+            response_body["rooms"][space_room_id].get("avatar"),
+            response_body["rooms"][space_room_id],
+        )
+        # Could be set but this room isn't marked as a DM
+        self.assertIsNone(
+            response_body["rooms"][space_room_id].get("is_dm"),
+            response_body["rooms"][space_room_id],
+        )
+        # Empty timeline because we are not in the room at all (they are all being
+        # filtered out)
+        self.assertIsNone(
+            response_body["rooms"][space_room_id].get("timeline"),
+            response_body["rooms"][space_room_id],
+        )
+        # `limited` since we're not providing any timeline events but there are some in
+        # the room.
+        self.assertEqual(response_body["rooms"][space_room_id]["limited"], True)
+        # User is no longer in the room so they can't see this info
+        self.assertIsNone(
+            response_body["rooms"][space_room_id].get("joined_count"),
+            response_body["rooms"][space_room_id],
+        )
+        self.assertIsNone(
+            response_body["rooms"][space_room_id].get("invited_count"),
+            response_body["rooms"][space_room_id],
+        )
+
+    @parameterized.expand(
+        [
+            ("server_leaves_room", True),
+            ("server_participating_in_room", False),
+        ]
+    )
+    def test_state_reset_never_room_incremental_sync_with_filters(
+        self, test_description: str, server_leaves_room: bool
+    ) -> None:
+        """
+        Test that a room that we were state reset out of should be sent down if we can
+        figure out the state or if it was sent down the connection before.
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+        user2_id = self.register_user("user2", "pass")
+        user2_tok = self.login(user2_id, "pass")
+
+        # Create a space room
+        space_room_id = self.helper.create_room_as(
+            user2_id,
+            tok=user2_tok,
+            extra_content={
+                "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE},
+                "name": "my super space",
+            },
+        )
+
+        # Create another space room
+        space_room_id2 = self.helper.create_room_as(
+            user2_id,
+            tok=user2_tok,
+            extra_content={
+                "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE},
+            },
+        )
+
+        # Create an event for us to point back to for the state reset
+        event_response = self.helper.send(space_room_id, "test", tok=user2_tok)
+        event_id = event_response["event_id"]
+
+        # User1 joins the rooms
+        #
+        self.helper.join(space_room_id, user1_id, tok=user1_tok)
+        # Join space_room_id2 so that it is at the top of the list
+        self.helper.join(space_room_id2, user1_id, tok=user1_tok)
+
+        # Make a SS request for only the top room.
+        sync_body = {
+            "lists": {
+                "foo-list": {
+                    "ranges": [[0, 0]],
+                    "required_state": [
+                        # Request all state just to see what we get back when we are
+                        # state reset out of the room
+                        [StateValues.WILDCARD, StateValues.WILDCARD]
+                    ],
+                    "timeline_limit": 1,
+                    "filters": {
+                        "room_types": [RoomTypes.SPACE],
+                    },
+                }
+            }
+        }
+
+        # Make the Sliding Sync request
+        response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+        # Make sure we only see space_room_id2
+        self.assertIncludes(
+            set(response_body["rooms"].keys()), {space_room_id2}, exact=True
+        )
+        self.assertEqual(response_body["rooms"][space_room_id2]["initial"], True)
+
+        # Just create some activity in space_room_id2 so it appears when we incremental sync again
+        self.helper.send(space_room_id2, "test", tok=user2_tok)
+
+        # Trigger a state reset
+        join_rule_event, join_rule_context = self.get_success(
+            create_event(
+                self.hs,
+                prev_event_ids=[event_id],
+                type=EventTypes.JoinRules,
+                state_key="",
+                content={"join_rule": JoinRules.INVITE},
+                sender=user2_id,
+                room_id=space_room_id,
+                room_version=self.get_success(
+                    self.store.get_room_version_id(space_room_id)
+                ),
+            )
+        )
+        _, join_rule_event_pos, _ = self.get_success(
+            self.persistence.persist_event(join_rule_event, join_rule_context)
+        )
+
+        # FIXME: We're manually busting the cache since
+        # https://github.com/element-hq/synapse/issues/17368 is not solved yet
+        self.store._membership_stream_cache.entity_has_changed(
+            user1_id, join_rule_event_pos.stream
+        )
+
+        # Ensure that the state reset worked and only user2 is in the room now
+        users_in_room = self.get_success(self.store.get_users_in_room(space_room_id))
+        self.assertIncludes(set(users_in_room), {user2_id}, exact=True)
+
+        # Update the state after user1 was state reset out of the room.
+        # This will also bump it to the top of the list.
+        self.helper.send_state(
+            space_room_id,
+            EventTypes.Name,
+            {EventContentFields.ROOM_NAME: "my super duper space"},
+            tok=user2_tok,
+        )
+
+        if server_leaves_room:
+            # User2 also leaves the room so the server is no longer participating in the room
+            # and we don't have access to current state
+            self.helper.leave(space_room_id, user2_id, tok=user2_tok)
+
+        # Make another Sliding Sync request (incremental)
+        sync_body = {
+            "lists": {
+                "foo-list": {
+                    # Expand the range to include all rooms
+                    "ranges": [[0, 1]],
+                    "required_state": [
+                        # Request all state just to see what we get back when we are
+                        # state reset out of the room
+                        [StateValues.WILDCARD, StateValues.WILDCARD]
+                    ],
+                    "timeline_limit": 1,
+                    "filters": {
+                        "room_types": [RoomTypes.SPACE],
+                    },
+                }
+            }
+        }
+        response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+        if self.use_new_tables:
+            if server_leaves_room:
+                # We still only expect to see space_room_id2 because even though we were state
+                # reset out of space_room_id, it was never sent down the connection before so we
+                # don't need to bother the client with it.
+                self.assertIncludes(
+                    set(response_body["rooms"].keys()), {space_room_id2}, exact=True
+                )
+            else:
+                # Both rooms show up because we can figure out the state for the
+                # `filters.room_types` if someone is still in the room (we look at the
+                # current state because `room_type` never changes).
+                self.assertIncludes(
+                    set(response_body["rooms"].keys()),
+                    {space_room_id, space_room_id2},
+                    exact=True,
+                )
+        else:
+            # Both rooms show up because we can actually take the time to figure out the
+            # state for the `filters.room_types` in the fallback path (we look at
+            # historical state for `LEAVE` membership).
+            self.assertIncludes(
+                set(response_body["rooms"].keys()),
+                {space_room_id, space_room_id2},
+                exact=True,
+            )
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index 837eb434aa..ed5f286243 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -27,7 +27,13 @@ from immutabledict import immutabledict
 
 from twisted.test.proto_helpers import MemoryReactor
 
-from synapse.api.constants import Direction, EventTypes, Membership, RelationTypes
+from synapse.api.constants import (
+    Direction,
+    EventTypes,
+    JoinRules,
+    Membership,
+    RelationTypes,
+)
 from synapse.api.filtering import Filter
 from synapse.crypto.event_signing import add_hashes_and_signatures
 from synapse.events import FrozenEventV3
@@ -1154,7 +1160,7 @@ class GetCurrentStateDeltaMembershipChangesForUserTestCase(HomeserverTestCase):
                     room_id=room_id1,
                     event_id=None,
                     event_pos=dummy_state_pos,
-                    membership="leave",
+                    membership=Membership.LEAVE,
                     sender=None,  # user1_id,
                     prev_event_id=join_response1["event_id"],
                     prev_event_pos=join_pos1,
@@ -1164,6 +1170,81 @@ class GetCurrentStateDeltaMembershipChangesForUserTestCase(HomeserverTestCase):
             ],
         )
 
+    def test_state_reset2(self) -> None:
+        """
+        Test a state reset scenario where the user gets removed from the room (when
+        there is no corresponding leave event)
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+        user2_id = self.register_user("user2", "pass")
+        user2_tok = self.login(user2_id, "pass")
+
+        room_id1 = self.helper.create_room_as(user2_id, is_public=True, tok=user2_tok)
+
+        event_response = self.helper.send(room_id1, "test", tok=user2_tok)
+        event_id = event_response["event_id"]
+
+        user1_join_response = self.helper.join(room_id1, user1_id, tok=user1_tok)
+        user1_join_pos = self.get_success(
+            self.store.get_position_for_event(user1_join_response["event_id"])
+        )
+
+        before_reset_token = self.event_sources.get_current_token()
+
+        # Trigger a state reset
+        join_rule_event, join_rule_context = self.get_success(
+            create_event(
+                self.hs,
+                prev_event_ids=[event_id],
+                type=EventTypes.JoinRules,
+                state_key="",
+                content={"join_rule": JoinRules.INVITE},
+                sender=user2_id,
+                room_id=room_id1,
+                room_version=self.get_success(self.store.get_room_version_id(room_id1)),
+            )
+        )
+        _, join_rule_event_pos, _ = self.get_success(
+            self.persistence.persist_event(join_rule_event, join_rule_context)
+        )
+
+        # FIXME: We're manually busting the cache since
+        # https://github.com/element-hq/synapse/issues/17368 is not solved yet
+        self.store._membership_stream_cache.entity_has_changed(
+            user1_id, join_rule_event_pos.stream
+        )
+
+        after_reset_token = self.event_sources.get_current_token()
+
+        membership_changes = self.get_success(
+            self.store.get_current_state_delta_membership_changes_for_user(
+                user1_id,
+                from_key=before_reset_token.room_key,
+                to_key=after_reset_token.room_key,
+            )
+        )
+
+        # Let the whole diff show on failure
+        self.maxDiff = None
+        self.assertEqual(
+            membership_changes,
+            [
+                CurrentStateDeltaMembership(
+                    room_id=room_id1,
+                    event_id=None,
+                    # The position where the state reset happened
+                    event_pos=join_rule_event_pos,
+                    membership=Membership.LEAVE,
+                    sender=None,
+                    prev_event_id=user1_join_response["event_id"],
+                    prev_event_pos=user1_join_pos,
+                    prev_membership="join",
+                    prev_sender=user1_id,
+                ),
+            ],
+        )
+
     def test_excluded_room_ids(self) -> None:
         """
         Test that the `excluded_room_ids` option excludes changes from the specified rooms.
-- 
GitLab