From 0932c775399575bba509728dc6721d1e48a6f689 Mon Sep 17 00:00:00 2001
From: Eric Eastwood <erice@element.io>
Date: Mon, 4 Nov 2024 10:17:58 -0600
Subject: [PATCH] Sliding Sync: Lazy-loading room members on incremental sync
 (remember memberships) (#17809)

Lazy-load room members on incremental sync and remember which
memberships we've sent down the connection before (up to 100).

Fix https://github.com/element-hq/synapse/issues/17804
---
 changelog.d/17809.bugfix                      |   1 +
 synapse/handlers/sliding_sync/__init__.py     | 168 +++++--
 tests/handlers/test_sliding_sync.py           | 421 +++++++++++++++++-
 .../sliding_sync/test_rooms_required_state.py | 257 ++++++++++-
 4 files changed, 788 insertions(+), 59 deletions(-)
 create mode 100644 changelog.d/17809.bugfix

diff --git a/changelog.d/17809.bugfix b/changelog.d/17809.bugfix
new file mode 100644
index 0000000000..e244a36bd3
--- /dev/null
+++ b/changelog.d/17809.bugfix
@@ -0,0 +1 @@
+Fix bug with sliding sync where `$LAZY`-loading room members would not return `required_state` membership in incremental syncs.
diff --git a/synapse/handlers/sliding_sync/__init__.py b/synapse/handlers/sliding_sync/__init__.py
index a1a6728fb9..85cfbc6dbf 100644
--- a/synapse/handlers/sliding_sync/__init__.py
+++ b/synapse/handlers/sliding_sync/__init__.py
@@ -12,6 +12,7 @@
 # <https://www.gnu.org/licenses/agpl-3.0.html>.
 #
 
+import itertools
 import logging
 from itertools import chain
 from typing import TYPE_CHECKING, AbstractSet, Dict, List, Mapping, Optional, Set, Tuple
@@ -79,6 +80,15 @@ sync_processing_time = Histogram(
     ["initial"],
 )
 
+# Limit the number of state_keys we should remember sending down the connection for each
+# (room_id, user_id). We don't want to store and pull out too much data in the database.
+#
+# 100 is an arbitrary but small-ish number. The idea is that we probably won't send down
+# too many redundant member state events (that the client already knows about) for a
+# given ongoing conversation if we keep 100 around. Most rooms don't have 100 members
+# anyway and it takes a while to cycle through 100 members.
+MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER = 100
+
 
 class SlidingSyncHandler:
     def __init__(self, hs: "HomeServer"):
@@ -873,6 +883,14 @@ class SlidingSyncHandler:
         #
         # Calculate the `StateFilter` based on the `required_state` for the room
         required_state_filter = StateFilter.none()
+        # The requested `required_state_map` with the lazy membership expanded and
+        # `$ME` replaced with the user's ID. This allows us to see what membership we've
+        # sent down to the client in the next request.
+        #
+        # Make a copy so we can modify it. Still need to be careful to make a copy of
+        # the state key sets if we want to add/remove from them. We could make a deep
+        # copy but this saves us some work.
+        expanded_required_state_map = dict(room_sync_config.required_state_map)
         if room_membership_for_user_at_to_token.membership not in (
             Membership.INVITE,
             Membership.KNOCK,
@@ -938,21 +956,48 @@ class SlidingSyncHandler:
                         ):
                             lazy_load_room_members = True
                             # Everyone in the timeline is relevant
+                            #
+                            # FIXME: We probably also care about invite, ban, kick, targets, etc
+                            # but the spec only mentions "senders".
                             timeline_membership: Set[str] = set()
                             if timeline_events is not None:
                                 for timeline_event in timeline_events:
                                     timeline_membership.add(timeline_event.sender)
 
+                            # Update the required state filter so we pick up the new
+                            # membership
                             for user_id in timeline_membership:
                                 required_state_types.append(
                                     (EventTypes.Member, user_id)
                                 )
 
-                            # FIXME: We probably also care about invite, ban, kick, targets, etc
-                            # but the spec only mentions "senders".
+                            # Add an explicit entry for each user in the timeline
+                            #
+                            # Make a new set or copy of the state key set so we can
+                            # modify it without affecting the original
+                            # `required_state_map`
+                            expanded_required_state_map[EventTypes.Member] = (
+                                expanded_required_state_map.get(
+                                    EventTypes.Member, set()
+                                )
+                                | timeline_membership
+                            )
                         elif state_key == StateValues.ME:
                             num_others += 1
                             required_state_types.append((state_type, user.to_string()))
+                            # Record the user's ID so we can deduplicate when someone
+                            # requests the same state with `$ME` or with their user ID.
+                            # NOTE(review): this always updates the `EventTypes.Member`
+                            # entry (not `state_type`) and does not drop `$ME`; confirm.
+                            #
+                            # Build a new set rather than mutating in place so the
+                            # original `required_state_map` is not affected.
+                            expanded_required_state_map[EventTypes.Member] = (
+                                expanded_required_state_map.get(
+                                    EventTypes.Member, set()
+                                )
+                                | {user.to_string()}
+                            )
                         else:
                             num_others += 1
                             required_state_types.append((state_type, state_key))
@@ -1016,8 +1061,8 @@ class SlidingSyncHandler:
                 changed_required_state_map, added_state_filter = (
                     _required_state_changes(
                         user.to_string(),
-                        previous_room_config=prev_room_sync_config,
-                        room_sync_config=room_sync_config,
+                        prev_required_state_map=prev_room_sync_config.required_state_map,
+                        request_required_state_map=expanded_required_state_map,
                         state_deltas=room_state_delta_id_map,
                     )
                 )
@@ -1131,7 +1176,9 @@ class SlidingSyncHandler:
             # sensible order again.
             bump_stamp = 0
 
-        room_sync_required_state_map_to_persist = room_sync_config.required_state_map
+        room_sync_required_state_map_to_persist: Mapping[str, AbstractSet[str]] = (
+            expanded_required_state_map
+        )
         if changed_required_state_map:
             room_sync_required_state_map_to_persist = changed_required_state_map
 
@@ -1185,7 +1232,10 @@ class SlidingSyncHandler:
                 )
 
         else:
-            new_connection_state.room_configs[room_id] = room_sync_config
+            new_connection_state.room_configs[room_id] = RoomSyncConfig(
+                timeline_limit=room_sync_config.timeline_limit,
+                required_state_map=room_sync_required_state_map_to_persist,
+            )
 
         set_tag(SynapseTags.RESULT_PREFIX + "initial", initial)
 
@@ -1320,8 +1370,8 @@ class SlidingSyncHandler:
 def _required_state_changes(
     user_id: str,
     *,
-    previous_room_config: "RoomSyncConfig",
-    room_sync_config: RoomSyncConfig,
+    prev_required_state_map: Mapping[str, AbstractSet[str]],
+    request_required_state_map: Mapping[str, AbstractSet[str]],
     state_deltas: StateMap[str],
 ) -> Tuple[Optional[Mapping[str, AbstractSet[str]]], StateFilter]:
     """Calculates the changes between the required state room config from the
@@ -1342,10 +1392,6 @@ def _required_state_changes(
         and the state filter to use to fetch extra current state that we need to
         return.
     """
-
-    prev_required_state_map = previous_room_config.required_state_map
-    request_required_state_map = room_sync_config.required_state_map
-
     if prev_required_state_map == request_required_state_map:
         # There has been no change. Return immediately.
         return None, StateFilter.none()
@@ -1378,12 +1424,19 @@ def _required_state_changes(
     # client. Passed to `StateFilter.from_types(...)`
     added: List[Tuple[str, Optional[str]]] = []
 
+    # Convert the list of state deltas to map from type to state_keys that have
+    # changed.
+    changed_types_to_state_keys: Dict[str, Set[str]] = {}
+    for event_type, state_key in state_deltas:
+        changed_types_to_state_keys.setdefault(event_type, set()).add(state_key)
+
     # First we calculate what, if anything, has been *added*.
     for event_type in (
         prev_required_state_map.keys() | request_required_state_map.keys()
     ):
         old_state_keys = prev_required_state_map.get(event_type, set())
         request_state_keys = request_required_state_map.get(event_type, set())
+        changed_state_keys = changed_types_to_state_keys.get(event_type, set())
 
         if old_state_keys == request_state_keys:
             # No change to this type
@@ -1393,8 +1446,55 @@ def _required_state_changes(
             # Nothing *added*, so we skip. Removals happen below.
             continue
 
-        # Always update changes to include the newly added keys
-        changes[event_type] = request_state_keys
+        # We only remove state keys from the effective state if they've been
+        # removed from the request *and* the state has changed. This ensures
+        # that if a client removes and then re-adds a state key, we only send
+        # down the associated current state event if it's changed (rather than
+        # sending down the same event twice).
+        invalidated_state_keys = (
+            old_state_keys - request_state_keys
+        ) & changed_state_keys
+
+        # Figure out which state keys we should remember sending down the connection
+        inheritable_previous_state_keys = (
+            # Retain the previous state_keys that we've sent down before.
+            # Wildcard and lazy state keys are not sticky from previous requests.
+            (old_state_keys - {StateValues.WILDCARD, StateValues.LAZY})
+            - invalidated_state_keys
+        )
+
+        # Always update changes to include the newly added keys (we've expanded the set
+        # of state keys), use the new requested set with whatever hasn't been
+        # invalidated from the previous set.
+        changes[event_type] = request_state_keys | inheritable_previous_state_keys
+        # Limit the number of state_keys we should remember sending down the connection
+        # for each (room_id, user_id). We don't want to store and pull out too much data
+        # in the database. This is a happy-medium between remembering nothing and
+        # everything. We can avoid sending redundant state down the connection most of
+        # the time given that most rooms don't have 100 members anyway and it takes a
+        # while to cycle through 100 members.
+        #
+        # Only remember up to (MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER)
+        if len(changes[event_type]) > MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER:
+            # Reset back to only the requested state keys
+            changes[event_type] = request_state_keys
+
+            # Skip if there isn't any room to fill in the rest with previous state keys
+            if len(request_state_keys) < MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER:
+                # Fill the rest with previous state_keys. Ideally, we could sort
+                # these by recency but it's just a set so just pick an arbitrary
+                # subset (good enough).
+                changes[event_type] = changes[event_type] | set(
+                    itertools.islice(
+                        inheritable_previous_state_keys,
+                        # Just taking the difference isn't perfect as there could be
+                        # overlap in the keys between the requested and previous but we
+                        # will decide to just take the easy route for now and avoid
+                        # additional set operations to figure it out.
+                        MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER
+                        - len(request_state_keys),
+                    )
+                )
 
         if StateValues.WILDCARD in old_state_keys:
             # We were previously fetching everything for this type, so we don't need to
@@ -1421,12 +1521,6 @@ def _required_state_changes(
 
     added_state_filter = StateFilter.from_types(added)
 
-    # Convert the list of state deltas to map from type to state_keys that have
-    # changed.
-    changed_types_to_state_keys: Dict[str, Set[str]] = {}
-    for event_type, state_key in state_deltas:
-        changed_types_to_state_keys.setdefault(event_type, set()).add(state_key)
-
     # Figure out what changes we need to apply to the effective required state
     # config.
     for event_type, changed_state_keys in changed_types_to_state_keys.items():
@@ -1437,15 +1531,23 @@ def _required_state_changes(
             # No change.
             continue
 
+        # If we see the `user_id` as a state_key, also add "$ME" to the list of state
+        # that has changed to account for people requesting `required_state` with `$ME`
+        # or their user ID.
+        if user_id in changed_state_keys:
+            changed_state_keys.add(StateValues.ME)
+
+        # We only remove state keys from the effective state if they've been
+        # removed from the request *and* the state has changed. This ensures
+        # that if a client removes and then re-adds a state key, we only send
+        # down the associated current state event if it's changed (rather than
+        # sending down the same event twice).
+        invalidated_state_keys = (
+            old_state_keys - request_state_keys
+        ) & changed_state_keys
+
+        # We've expanded the set of state keys, which was already handled above.
         if request_state_keys - old_state_keys:
-            # We've expanded the set of state keys, so we just clobber the
-            # current set with the new set.
-            #
-            # We could also ensure that we keep entries where the state hasn't
-            # changed, but are no longer in the requested required state, but
-            # that's a sufficient edge case that we can ignore (as its only a
-            # performance optimization).
-            changes[event_type] = request_state_keys
             continue
 
         old_state_key_wildcard = StateValues.WILDCARD in old_state_keys
@@ -1467,11 +1569,6 @@ def _required_state_changes(
                 changes[event_type] = request_state_keys
                 continue
 
-        # Handle "$ME" values by adding "$ME" if the state key matches the user
-        # ID.
-        if user_id in changed_state_keys:
-            changed_state_keys.add(StateValues.ME)
-
         # At this point there are no wildcards and no additions to the set of
         # state keys requested, only deletions.
         #
@@ -1480,9 +1577,8 @@ def _required_state_changes(
         # that if a client removes and then re-adds a state key, we only send
         # down the associated current state event if its changed (rather than
         # sending down the same event twice).
-        invalidated = (old_state_keys - request_state_keys) & changed_state_keys
-        if invalidated:
-            changes[event_type] = old_state_keys - invalidated
+        if invalidated_state_keys:
+            changes[event_type] = old_state_keys - invalidated_state_keys
 
     if changes:
         # Update the required state config based on the changes.
diff --git a/tests/handlers/test_sliding_sync.py b/tests/handlers/test_sliding_sync.py
index 9a68d1dd95..5b7e2937f8 100644
--- a/tests/handlers/test_sliding_sync.py
+++ b/tests/handlers/test_sliding_sync.py
@@ -33,6 +33,7 @@ from synapse.api.constants import (
 )
 from synapse.api.room_versions import RoomVersions
 from synapse.handlers.sliding_sync import (
+    MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER,
     RoomsForUserType,
     RoomSyncConfig,
     StateValues,
@@ -3319,6 +3320,32 @@ class RequiredStateChangesTestCase(unittest.TestCase):
                     ),
                 ),
             ),
+            (
+                "simple_retain_previous_state_keys",
+                """Test adding a state key to the config and retaining a previously sent state_key""",
+                RequiredStateChangesTestParameters(
+                    previous_required_state_map={"type": {"state_key1"}},
+                    request_required_state_map={"type": {"state_key2", "state_key3"}},
+                    state_deltas={("type", "state_key2"): "$event_id"},
+                    expected_with_state_deltas=(
+                        # We've added a key so we should persist the changed required state
+                        # config.
+                        #
+                        # Retain `state_key1` from the `previous_required_state_map`
+                        {"type": {"state_key1", "state_key2", "state_key3"}},
+                        # We should see the new state_keys added
+                        StateFilter.from_types(
+                            [("type", "state_key2"), ("type", "state_key3")]
+                        ),
+                    ),
+                    expected_without_state_deltas=(
+                        {"type": {"state_key1", "state_key2", "state_key3"}},
+                        StateFilter.from_types(
+                            [("type", "state_key2"), ("type", "state_key3")]
+                        ),
+                    ),
+                ),
+            ),
             (
                 "simple_remove_type",
                 """
@@ -3724,6 +3751,249 @@ class RequiredStateChangesTestCase(unittest.TestCase):
                     ),
                 ),
             ),
+            (
+                "state_key_lazy_keep_previous_memberships_and_no_new_memberships",
+                """
+                This test mimics a request with lazy-loading room members enabled where
+                we have previously sent down user2 and user3's membership events and now
+                we're sending down another response without any timeline events.
+                """,
+                RequiredStateChangesTestParameters(
+                    previous_required_state_map={
+                        EventTypes.Member: {
+                            StateValues.LAZY,
+                            "@user2:test",
+                            "@user3:test",
+                        }
+                    },
+                    request_required_state_map={EventTypes.Member: {StateValues.LAZY}},
+                    state_deltas={(EventTypes.Member, "@user2:test"): "$event_id"},
+                    expected_with_state_deltas=(
+                        # Remove "@user2:test" since that state has changed and is no
+                        # longer being requested. Since something was removed, we should
+                        # persist the change to the required state. That way, the next
+                        # time they request "@user2:test", we see that we haven't sent
+                        # it before and send the new state. (We should still keep track
+                        # that we've sent specific `EventTypes.Member` before.)
+                        {
+                            EventTypes.Member: {
+                                StateValues.LAZY,
+                                "@user3:test",
+                            }
+                        },
+                        # We don't need to request anything more if they are requesting
+                        # less state now
+                        StateFilter.none(),
+                    ),
+                    expected_without_state_deltas=(
+                        # We're not requesting any specific `EventTypes.Member` now but
+                        # since that state hasn't changed, nothing should change (we
+                        # should still keep track that we've sent specific
+                        # `EventTypes.Member` before).
+                        None,
+                        # We don't need to request anything more if they are requesting
+                        # less state now
+                        StateFilter.none(),
+                    ),
+                ),
+            ),
+            (
+                "state_key_lazy_keep_previous_memberships_with_new_memberships",
+                """
+                This test mimics a request with lazy-loading room members enabled where
+                we have previously sent down user2 and user3's membership events and now
+                we're sending down another response with a new event from user4.
+                """,
+                RequiredStateChangesTestParameters(
+                    previous_required_state_map={
+                        EventTypes.Member: {
+                            StateValues.LAZY,
+                            "@user2:test",
+                            "@user3:test",
+                        }
+                    },
+                    request_required_state_map={
+                        EventTypes.Member: {StateValues.LAZY, "@user4:test"}
+                    },
+                    state_deltas={(EventTypes.Member, "@user2:test"): "$event_id"},
+                    expected_with_state_deltas=(
+                        # Since "@user4:test" was added, we should persist the changed
+                        # required state config.
+                        #
+                        # Also remove "@user2:test" since that state has changed and is no
+                        # longer being requested. Since something was removed, we should
+                        # also persist the change to the required state. That way, the next
+                        # time they request "@user2:test", we see that we haven't sent
+                        # it before and send the new state. (We should still keep track
+                        # that we've sent specific `EventTypes.Member` before.)
+                        {
+                            EventTypes.Member: {
+                                StateValues.LAZY,
+                                "@user3:test",
+                                "@user4:test",
+                            }
+                        },
+                        # We should see the new state_keys added
+                        StateFilter.from_types([(EventTypes.Member, "@user4:test")]),
+                    ),
+                    expected_without_state_deltas=(
+                        # Since "@user4:test" was added, we should persist the changed
+                        # required state config.
+                        {
+                            EventTypes.Member: {
+                                StateValues.LAZY,
+                                "@user2:test",
+                                "@user3:test",
+                                "@user4:test",
+                            }
+                        },
+                        # We should see the new state_keys added
+                        StateFilter.from_types([(EventTypes.Member, "@user4:test")]),
+                    ),
+                ),
+            ),
+            (
+                "state_key_expand_lazy_keep_previous_memberships",
+                """
+                Test expanding the `required_state` to lazy-loading room members.
+                """,
+                RequiredStateChangesTestParameters(
+                    previous_required_state_map={
+                        EventTypes.Member: {"@user2:test", "@user3:test"}
+                    },
+                    request_required_state_map={EventTypes.Member: {StateValues.LAZY}},
+                    state_deltas={(EventTypes.Member, "@user2:test"): "$event_id"},
+                    expected_with_state_deltas=(
+                        # Since `StateValues.LAZY` was added, we should persist the
+                        # changed required state config.
+                        #
+                        # Also remove "@user2:test" since that state has changed and is no
+                        # longer being requested. Since something was removed, we should
+                        # also persist the change to the required state. That way, the next
+                        # time they request "@user2:test", we see that we haven't sent
+                        # it before and send the new state. (We should still keep track
+                        # that we've sent specific `EventTypes.Member` before.)
+                        {
+                            EventTypes.Member: {
+                                StateValues.LAZY,
+                                "@user3:test",
+                            }
+                        },
+                        # We don't need to request anything more if they are requesting
+                        # less state now
+                        StateFilter.none(),
+                    ),
+                    expected_without_state_deltas=(
+                        # Since `StateValues.LAZY` was added, we should persist the
+                        # changed required state config.
+                        {
+                            EventTypes.Member: {
+                                StateValues.LAZY,
+                                "@user2:test",
+                                "@user3:test",
+                            }
+                        },
+                        # We don't need to request anything more if they are requesting
+                        # less state now
+                        StateFilter.none(),
+                    ),
+                ),
+            ),
+            (
+                "state_key_retract_lazy_keep_previous_memberships_no_new_memberships",
+                """
+                Test retracting the `required_state` to no longer lazy-loading room members.
+                """,
+                RequiredStateChangesTestParameters(
+                    previous_required_state_map={
+                        EventTypes.Member: {
+                            StateValues.LAZY,
+                            "@user2:test",
+                            "@user3:test",
+                        }
+                    },
+                    request_required_state_map={},
+                    state_deltas={(EventTypes.Member, "@user2:test"): "$event_id"},
+                    expected_with_state_deltas=(
+                        # Remove `EventTypes.Member` since there's been a change to that
+                        # state, (persist the change to required state). That way next
+                        # time, they request `EventTypes.Member`, we see that we haven't
+                        # sent it before and send the new state. (if we were tracking
+                        # that we sent any other state, we should still keep track
+                        # that).
+                        #
+                        # This acts the same as the `simple_remove_type` test. It's
+                        # possible that we could remember the specific `state_keys` that
+                        # we have sent down before but this currently just acts the same
+                        # as if a whole `type` was removed. Perhaps it's good that we
+                        # "garbage collect" and forget what we've sent before for a
+                        # given `type` when the client stops caring about a certain
+                        # `type`.
+                        {},
+                        # We don't need to request anything more if they are requesting
+                        # less state now
+                        StateFilter.none(),
+                    ),
+                    expected_without_state_deltas=(
+                        # `EventTypes.Member` is no longer requested but since that
+                        # state hasn't changed, nothing should change (we should still
+                        # keep track that we've sent `EventTypes.Member` before).
+                        None,
+                        # We don't need to request anything more if they are requesting
+                        # less state now
+                        StateFilter.none(),
+                    ),
+                ),
+            ),
+            (
+                "state_key_retract_lazy_keep_previous_memberships_with_new_memberships",
+                """
+                Test retracting the `required_state` to no longer lazy-loading room members.
+                """,
+                RequiredStateChangesTestParameters(
+                    previous_required_state_map={
+                        EventTypes.Member: {
+                            StateValues.LAZY,
+                            "@user2:test",
+                            "@user3:test",
+                        }
+                    },
+                    request_required_state_map={EventTypes.Member: {"@user4:test"}},
+                    state_deltas={(EventTypes.Member, "@user2:test"): "$event_id"},
+                    expected_with_state_deltas=(
+                        # Since "@user4:test" was added, we should persist the changed
+                        # required state config.
+                        #
+                        # Also remove "@user2:test" since that state has changed and is no
+                        # longer being requested. Since something was removed, we should
+                        # also persist the change to the required state. That way, the next
+                        # time they request "@user2:test", we see that we haven't sent
+                        # it before and send the new state. (We should still keep track
+                        # that we've sent specific `EventTypes.Member` before.)
+                        {
+                            EventTypes.Member: {
+                                "@user3:test",
+                                "@user4:test",
+                            }
+                        },
+                        # We should see the new state_keys added
+                        StateFilter.from_types([(EventTypes.Member, "@user4:test")]),
+                    ),
+                    expected_without_state_deltas=(
+                        # Since "@user4:test" was added, we should persist the changed
+                        # required state config.
+                        {
+                            EventTypes.Member: {
+                                "@user2:test",
+                                "@user3:test",
+                                "@user4:test",
+                            }
+                        },
+                        # We should see the new state_keys added
+                        StateFilter.from_types([(EventTypes.Member, "@user4:test")]),
+                    ),
+                ),
+            ),
             (
                 "type_wildcard_with_state_key_wildcard_to_explicit_state_keys",
                 """
@@ -3824,7 +4094,7 @@ class RequiredStateChangesTestCase(unittest.TestCase):
                 ),
             ),
             (
-                "state_key_wildcard_to_explicit_state_keys",
+                "explicit_state_keys_to_wildcard_state_key",
                 """Test switching from a wildcard to explicit state keys with a concrete type""",
                 RequiredStateChangesTestParameters(
                     previous_required_state_map={
@@ -3837,11 +4107,18 @@ class RequiredStateChangesTestCase(unittest.TestCase):
                     # request. And we need to request all of the state for that type
                     # because we previously, only sent down a few keys.
                     expected_with_state_deltas=(
-                        {"type1": {StateValues.WILDCARD}},
+                        {"type1": {StateValues.WILDCARD, "state_key2", "state_key3"}},
                         StateFilter.from_types([("type1", None)]),
                     ),
                     expected_without_state_deltas=(
-                        {"type1": {StateValues.WILDCARD}},
+                        {
+                            "type1": {
+                                StateValues.WILDCARD,
+                                "state_key1",
+                                "state_key2",
+                                "state_key3",
+                            }
+                        },
                         StateFilter.from_types([("type1", None)]),
                     ),
                 ),
@@ -3857,14 +4134,8 @@ class RequiredStateChangesTestCase(unittest.TestCase):
         # Without `state_deltas`
         changed_required_state_map, added_state_filter = _required_state_changes(
             user_id="@user:test",
-            previous_room_config=RoomSyncConfig(
-                timeline_limit=0,
-                required_state_map=test_parameters.previous_required_state_map,
-            ),
-            room_sync_config=RoomSyncConfig(
-                timeline_limit=0,
-                required_state_map=test_parameters.request_required_state_map,
-            ),
+            prev_required_state_map=test_parameters.previous_required_state_map,
+            request_required_state_map=test_parameters.request_required_state_map,
             state_deltas={},
         )
 
@@ -3882,14 +4153,8 @@ class RequiredStateChangesTestCase(unittest.TestCase):
         # With `state_deltas`
         changed_required_state_map, added_state_filter = _required_state_changes(
             user_id="@user:test",
-            previous_room_config=RoomSyncConfig(
-                timeline_limit=0,
-                required_state_map=test_parameters.previous_required_state_map,
-            ),
-            room_sync_config=RoomSyncConfig(
-                timeline_limit=0,
-                required_state_map=test_parameters.request_required_state_map,
-            ),
+            prev_required_state_map=test_parameters.previous_required_state_map,
+            request_required_state_map=test_parameters.request_required_state_map,
             state_deltas=test_parameters.state_deltas,
         )
 
@@ -3903,3 +4168,121 @@ class RequiredStateChangesTestCase(unittest.TestCase):
             test_parameters.expected_with_state_deltas[1],
             "added_state_filter does not match (with state_deltas)",
         )
+
+    @parameterized.expand(
+        [
+            # Test with a normal arbitrary type (no special meaning)
+            ("arbitrary_type", "type", set()),
+            # Test with membership
+            ("membership", EventTypes.Member, set()),
+            # Test with lazy-loading room members
+            ("lazy_loading_membership", EventTypes.Member, {StateValues.LAZY}),
+        ]
+    )
+    def test_limit_retained_previous_state_keys(
+        self,
+        _test_label: str,
+        event_type: str,
+        extra_state_keys: Set[str],
+    ) -> None:
+        """
+        Test that we limit the number of state_keys that we remember but always include
+        the state_keys that we've just requested.
+        """
+        previous_required_state_map = {
+            event_type: {
+                # Prefix the state_keys we've "prev_"iously sent so they are easier to
+                # identify in our assertions.
+                f"prev_state_key{i}"
+                for i in range(MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER - 30)
+            }
+            | extra_state_keys
+        }
+        request_required_state_map = {
+            event_type: {f"state_key{i}" for i in range(50)} | extra_state_keys
+        }
+
+        # (function under test)
+        changed_required_state_map, added_state_filter = _required_state_changes(
+            user_id="@user:test",
+            prev_required_state_map=previous_required_state_map,
+            request_required_state_map=request_required_state_map,
+            state_deltas={},
+        )
+        assert changed_required_state_map is not None
+
+        # We should only remember up to the maximum number of state keys
+        self.assertGreaterEqual(
+            len(changed_required_state_map[event_type]),
+            # Most of the time this will be `MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER` but
+            # because we are just naively selecting enough previous state_keys to fill
+            # the limit, there might be some overlap in what's added back which means we
+            # might have slightly less than the limit.
+            #
+            # `extra_state_keys` overlaps in the previous and requested
+            # `required_state_map` so we might see this scenario.
+            MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER - len(extra_state_keys),
+        )
+
+        # Should include all of the requested state
+        self.assertIncludes(
+            changed_required_state_map[event_type],
+            request_required_state_map[event_type],
+        )
+        # And the rest is filled with the previous state keys
+        #
+        # We can't assert the exact state_keys since we don't know the order so we just
+        # check that they all start with "prev_" and that we have the correct amount.
+        remaining_state_keys = (
+            changed_required_state_map[event_type]
+            - request_required_state_map[event_type]
+        )
+        self.assertGreater(
+            len(remaining_state_keys),
+            0,
+        )
+        assert all(
+            state_key.startswith("prev_") for state_key in remaining_state_keys
+        ), "Remaining state_keys should be the previous state_keys"
+
+    def test_request_more_state_keys_than_remember_limit(self) -> None:
+        """
+        Test requesting more state_keys than fit in our limit to remember from previous
+        requests.
+        """
+        previous_required_state_map = {
+            "type": {
+                # Prefix the state_keys we've "prev_"iously sent so they are easier to
+                # identify in our assertions.
+                f"prev_state_key{i}"
+                for i in range(MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER - 30)
+            }
+        }
+        request_required_state_map = {
+            "type": {
+                f"state_key{i}"
+                # Requesting more than the MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER
+                for i in range(MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER + 20)
+            }
+        }
+        # Ensure that we are requesting more than the limit
+        self.assertGreater(
+            len(request_required_state_map["type"]),
+            MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER,
+        )
+
+        # (function under test)
+        changed_required_state_map, added_state_filter = _required_state_changes(
+            user_id="@user:test",
+            prev_required_state_map=previous_required_state_map,
+            request_required_state_map=request_required_state_map,
+            state_deltas={},
+        )
+        assert changed_required_state_map is not None
+
+        # Should include all of the requested state
+        self.assertIncludes(
+            changed_required_state_map["type"],
+            request_required_state_map["type"],
+            exact=True,
+        )
diff --git a/tests/rest/client/sliding_sync/test_rooms_required_state.py b/tests/rest/client/sliding_sync/test_rooms_required_state.py
index 7da51d4954..ecea5f2d5b 100644
--- a/tests/rest/client/sliding_sync/test_rooms_required_state.py
+++ b/tests/rest/client/sliding_sync/test_rooms_required_state.py
@@ -381,10 +381,10 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
         )
         self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
 
-    def test_rooms_required_state_lazy_loading_room_members(self) -> None:
+    def test_rooms_required_state_lazy_loading_room_members_initial_sync(self) -> None:
         """
-        Test `rooms.required_state` returns people relevant to the timeline when
-        lazy-loading room members, `["m.room.member","$LAZY"]`.
+        On initial sync, test `rooms.required_state` returns people relevant to the
+        timeline when lazy-loading room members, `["m.room.member","$LAZY"]`.
         """
         user1_id = self.register_user("user1", "pass")
         user1_tok = self.login(user1_id, "pass")
@@ -432,6 +432,255 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
         )
         self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
 
+    def test_rooms_required_state_lazy_loading_room_members_incremental_sync(
+        self,
+    ) -> None:
+        """
+        On incremental sync, test `rooms.required_state` returns people relevant to the
+        timeline when lazy-loading room members, `["m.room.member","$LAZY"]`.
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+        user2_id = self.register_user("user2", "pass")
+        user2_tok = self.login(user2_id, "pass")
+        user3_id = self.register_user("user3", "pass")
+        user3_tok = self.login(user3_id, "pass")
+        user4_id = self.register_user("user4", "pass")
+        user4_tok = self.login(user4_id, "pass")
+
+        room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+        self.helper.join(room_id1, user1_id, tok=user1_tok)
+        self.helper.join(room_id1, user3_id, tok=user3_tok)
+        self.helper.join(room_id1, user4_id, tok=user4_tok)
+
+        self.helper.send(room_id1, "1", tok=user2_tok)
+        self.helper.send(room_id1, "2", tok=user2_tok)
+        self.helper.send(room_id1, "3", tok=user2_tok)
+
+        # Make the Sliding Sync request with lazy loading for the room members
+        sync_body = {
+            "lists": {
+                "foo-list": {
+                    "ranges": [[0, 1]],
+                    "required_state": [
+                        [EventTypes.Create, ""],
+                        [EventTypes.Member, StateValues.LAZY],
+                    ],
+                    "timeline_limit": 3,
+                }
+            }
+        }
+        response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+        # Send more timeline events into the room
+        self.helper.send(room_id1, "4", tok=user2_tok)
+        self.helper.send(room_id1, "5", tok=user4_tok)
+        self.helper.send(room_id1, "6", tok=user4_tok)
+
+        # Make an incremental Sliding Sync request
+        response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+        state_map = self.get_success(
+            self.storage_controllers.state.get_current_state(room_id1)
+        )
+
+        # Only user2 and user4 sent events in the last 3 events we see in the `timeline`
+        # but since we've seen user2 in the last sync (and their membership hasn't
+        # changed), we should only see user4 here.
+        self._assertRequiredStateIncludes(
+            response_body["rooms"][room_id1]["required_state"],
+            {
+                state_map[(EventTypes.Member, user4_id)],
+            },
+            exact=True,
+        )
+        self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
+
+    def test_rooms_required_state_expand_lazy_loading_room_members_incremental_sync(
+        self,
+    ) -> None:
+        """
+        Test that when we expand the `required_state` to include lazy-loading room
+        members, it returns people relevant to the timeline.
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+        user2_id = self.register_user("user2", "pass")
+        user2_tok = self.login(user2_id, "pass")
+        user3_id = self.register_user("user3", "pass")
+        user3_tok = self.login(user3_id, "pass")
+        user4_id = self.register_user("user4", "pass")
+        user4_tok = self.login(user4_id, "pass")
+
+        room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+        self.helper.join(room_id1, user1_id, tok=user1_tok)
+        self.helper.join(room_id1, user3_id, tok=user3_tok)
+        self.helper.join(room_id1, user4_id, tok=user4_tok)
+
+        self.helper.send(room_id1, "1", tok=user2_tok)
+        self.helper.send(room_id1, "2", tok=user2_tok)
+        self.helper.send(room_id1, "3", tok=user2_tok)
+
+        # Make the Sliding Sync request *without* lazy loading for the room members
+        sync_body = {
+            "lists": {
+                "foo-list": {
+                    "ranges": [[0, 1]],
+                    "required_state": [
+                        [EventTypes.Create, ""],
+                    ],
+                    "timeline_limit": 3,
+                }
+            }
+        }
+        response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+        # Send more timeline events into the room
+        self.helper.send(room_id1, "4", tok=user2_tok)
+        self.helper.send(room_id1, "5", tok=user4_tok)
+        self.helper.send(room_id1, "6", tok=user4_tok)
+
+        # Expand `required_state` and make an incremental Sliding Sync request *with*
+        # lazy-loading room members
+        sync_body["lists"]["foo-list"]["required_state"] = [
+            [EventTypes.Create, ""],
+            [EventTypes.Member, StateValues.LAZY],
+        ]
+        response_body, from_token = self.do_sync(
+            sync_body, since=from_token, tok=user1_tok
+        )
+
+        state_map = self.get_success(
+            self.storage_controllers.state.get_current_state(room_id1)
+        )
+
+        # Only user2 and user4 sent events in the last 3 events we see in the `timeline`
+        # and we haven't seen any membership before this sync so we should see both
+        # users.
+        self._assertRequiredStateIncludes(
+            response_body["rooms"][room_id1]["required_state"],
+            {
+                state_map[(EventTypes.Member, user2_id)],
+                state_map[(EventTypes.Member, user4_id)],
+            },
+            exact=True,
+        )
+        self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
+
+        # Send a message so the room comes down sync.
+        self.helper.send(room_id1, "7", tok=user2_tok)
+        self.helper.send(room_id1, "8", tok=user4_tok)
+        self.helper.send(room_id1, "9", tok=user4_tok)
+
+        # Make another incremental Sliding Sync request
+        response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+        # Only user2 and user4 sent events in the last 3 events we see in the `timeline`
+        # but since we've seen both memberships in the last sync, they shouldn't appear
+        # again.
+        self._assertRequiredStateIncludes(
+            response_body["rooms"][room_id1].get("required_state", []),
+            set(),
+            exact=True,
+        )
+        self.assertIsNone(response_body["rooms"][room_id1].get("invite_state"))
+
+    def test_rooms_required_state_expand_retract_expand_lazy_loading_room_members_incremental_sync(
+        self,
+    ) -> None:
+        """
+        Test that when we expand `required_state` to lazy-load room members and then
+        retract it to explicit memberships, already-sent memberships aren't repeated.
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+        user2_id = self.register_user("user2", "pass")
+        user2_tok = self.login(user2_id, "pass")
+        user3_id = self.register_user("user3", "pass")
+        user3_tok = self.login(user3_id, "pass")
+        user4_id = self.register_user("user4", "pass")
+        user4_tok = self.login(user4_id, "pass")
+
+        room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok)
+        self.helper.join(room_id1, user1_id, tok=user1_tok)
+        self.helper.join(room_id1, user3_id, tok=user3_tok)
+        self.helper.join(room_id1, user4_id, tok=user4_tok)
+
+        self.helper.send(room_id1, "1", tok=user2_tok)
+        self.helper.send(room_id1, "2", tok=user2_tok)
+        self.helper.send(room_id1, "3", tok=user2_tok)
+
+        # Make the Sliding Sync request *without* lazy loading for the room members
+        sync_body = {
+            "lists": {
+                "foo-list": {
+                    "ranges": [[0, 1]],
+                    "required_state": [
+                        [EventTypes.Create, ""],
+                    ],
+                    "timeline_limit": 3,
+                }
+            }
+        }
+        response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+
+        # Send more timeline events into the room
+        self.helper.send(room_id1, "4", tok=user2_tok)
+        self.helper.send(room_id1, "5", tok=user4_tok)
+        self.helper.send(room_id1, "6", tok=user4_tok)
+
+        # Expand `required_state` and make an incremental Sliding Sync request *with*
+        # lazy-loading room members
+        sync_body["lists"]["foo-list"]["required_state"] = [
+            [EventTypes.Create, ""],
+            [EventTypes.Member, StateValues.LAZY],
+        ]
+        response_body, from_token = self.do_sync(
+            sync_body, since=from_token, tok=user1_tok
+        )
+
+        state_map = self.get_success(
+            self.storage_controllers.state.get_current_state(room_id1)
+        )
+
+        # Only user2 and user4 sent events in the last 3 events we see in the `timeline`
+        # and we haven't seen any membership before this sync so we should see both
+        # users because we're lazy-loading the room members.
+        self._assertRequiredStateIncludes(
+            response_body["rooms"][room_id1]["required_state"],
+            {
+                state_map[(EventTypes.Member, user2_id)],
+                state_map[(EventTypes.Member, user4_id)],
+            },
+            exact=True,
+        )
+
+        # Send a message so the room comes down sync.
+        self.helper.send(room_id1, "msg", tok=user4_tok)
+
+        # Retract `required_state` and make an incremental Sliding Sync request
+        # requesting a few memberships
+        sync_body["lists"]["foo-list"]["required_state"] = [
+            [EventTypes.Create, ""],
+            [EventTypes.Member, StateValues.ME],
+            [EventTypes.Member, user2_id],
+        ]
+        response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+
+        state_map = self.get_success(
+            self.storage_controllers.state.get_current_state(room_id1)
+        )
+
+        # We've seen user2's membership in the last sync so we shouldn't see it here
+        # even though it's requested. We should only see user1's membership.
+        self._assertRequiredStateIncludes(
+            response_body["rooms"][room_id1]["required_state"],
+            {
+                state_map[(EventTypes.Member, user1_id)],
+            },
+            exact=True,
+        )
+
     def test_rooms_required_state_me(self) -> None:
         """
         Test `rooms.required_state` correctly handles $ME.
@@ -561,7 +810,7 @@ class SlidingSyncRoomsRequiredStateTestCase(SlidingSyncBase):
         )
         self.helper.leave(room_id1, user3_id, tok=user3_tok)
 
-        # Make the Sliding Sync request with lazy loading for the room members
+        # Make an incremental Sliding Sync request
         response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
 
         # Only user2 and user3 sent events in the 3 events we see in the `timeline`
-- 
GitLab