From e4a1f271b9365e2dddbb9205b10ebe36eed12e62 Mon Sep 17 00:00:00 2001
From: Eric Eastwood <eric.eastwood@beta.gouv.fr>
Date: Wed, 11 Sep 2024 12:13:54 -0500
Subject: [PATCH] Sliding Sync: Make sure we get up-to-date information from
 `get_sliding_sync_rooms_for_user(...)` (#17692)

We need to bust the `get_sliding_sync_rooms_for_user`
cache when the room encryption is updated, or when any
other field that is used in the query changes.

Follow-up to https://github.com/element-hq/synapse/pull/17630

- Bust cache for membership change (cross-reference
`get_rooms_for_user`)
- Bust cache for room `encryption` (cross-reference
`get_room_encryption`)
- Bust cache for `forgotten` (cross-reference
`did_forget`/`get_forgotten_rooms_for_user`)
---
 changelog.d/17692.bugfix                      |   1 +
 synapse/storage/_base.py                      |   1 +
 synapse/storage/databases/main/cache.py       |  14 ++
 synapse/storage/databases/main/roommember.py  |   9 +-
 .../client/sliding_sync/test_sliding_sync.py  | 191 +++++++++++++-----
 5 files changed, 160 insertions(+), 56 deletions(-)
 create mode 100644 changelog.d/17692.bugfix

diff --git a/changelog.d/17692.bugfix b/changelog.d/17692.bugfix
new file mode 100644
index 0000000000..84e0754a99
--- /dev/null
+++ b/changelog.d/17692.bugfix
@@ -0,0 +1 @@
+Make sure we get up-to-date state information when using the new Sliding Sync tables to derive room membership.
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index d6deb077c8..e14d711c76 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -136,6 +136,7 @@ class SQLBaseStore(metaclass=ABCMeta):
         self._attempt_to_invalidate_cache("get_partial_current_state_ids", (room_id,))
         self._attempt_to_invalidate_cache("get_room_type", (room_id,))
         self._attempt_to_invalidate_cache("get_room_encryption", (room_id,))
+        self._attempt_to_invalidate_cache("get_sliding_sync_rooms_for_user", None)
 
     def _invalidate_state_caches_all(self, room_id: str) -> None:
         """Invalidates caches that are based on the current state, but does
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index b0e30daee5..37c865a8e7 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -41,6 +41,7 @@ from synapse.storage.database import (
     LoggingDatabaseConnection,
     LoggingTransaction,
 )
+from synapse.storage.databases.main.events import SLIDING_SYNC_RELEVANT_STATE_SET
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.util.caches.descriptors import CachedFunction
@@ -271,12 +272,20 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                 self._attempt_to_invalidate_cache(
                     "get_rooms_for_user", (data.state_key,)
                 )
+                self._attempt_to_invalidate_cache(
+                    "get_sliding_sync_rooms_for_user", None
+                )
             elif data.type == EventTypes.RoomEncryption:
                 self._attempt_to_invalidate_cache(
                     "get_room_encryption", (data.room_id,)
                 )
             elif data.type == EventTypes.Create:
                 self._attempt_to_invalidate_cache("get_room_type", (data.room_id,))
+
+            if (data.type, data.state_key) in SLIDING_SYNC_RELEVANT_STATE_SET:
+                self._attempt_to_invalidate_cache(
+                    "get_sliding_sync_rooms_for_user", None
+                )
         elif row.type == EventsStreamAllStateRow.TypeId:
             assert isinstance(data, EventsStreamAllStateRow)
             # Similar to the above, but the entire caches are invalidated. This is
@@ -285,6 +294,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             self._attempt_to_invalidate_cache("get_rooms_for_user", None)
             self._attempt_to_invalidate_cache("get_room_type", (data.room_id,))
             self._attempt_to_invalidate_cache("get_room_encryption", (data.room_id,))
+            self._attempt_to_invalidate_cache("get_sliding_sync_rooms_for_user", None)
         else:
             raise Exception("Unknown events stream row type %s" % (row.type,))
 
@@ -365,6 +375,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         elif etype == EventTypes.RoomEncryption:
             self._attempt_to_invalidate_cache("get_room_encryption", (room_id,))
 
+        if (etype, state_key) in SLIDING_SYNC_RELEVANT_STATE_SET:
+            self._attempt_to_invalidate_cache("get_sliding_sync_rooms_for_user", None)
+
         if relates_to:
             self._attempt_to_invalidate_cache(
                 "get_relations_for_event",
@@ -477,6 +490,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         self._attempt_to_invalidate_cache(
             "get_current_hosts_in_room_ordered", (room_id,)
         )
+        self._attempt_to_invalidate_cache("get_sliding_sync_rooms_for_user", None)
         self._attempt_to_invalidate_cache("did_forget", None)
         self._attempt_to_invalidate_cache("get_forgotten_rooms_for_user", None)
         self._attempt_to_invalidate_cache("_get_membership_from_event_id", None)
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index db03729cfe..1fc2d7ba1e 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -1365,6 +1365,9 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
             self._invalidate_cache_and_stream(
                 txn, self.get_forgotten_rooms_for_user, (user_id,)
             )
+            self._invalidate_cache_and_stream(
+                txn, self.get_sliding_sync_rooms_for_user, (user_id,)
+            )
 
         await self.db_pool.runInteraction("forget_membership", f)
 
@@ -1410,6 +1413,10 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
         def get_sliding_sync_rooms_for_user_txn(
             txn: LoggingTransaction,
         ) -> Dict[str, RoomsForUserSlidingSync]:
+            # XXX: If you use any new columns that can change (like from
+            # `sliding_sync_joined_rooms` or `forgotten`), make sure to bust the
+            # `get_sliding_sync_rooms_for_user` cache in the appropriate places (and add
+            # tests).
             sql = """
                 SELECT m.room_id, m.sender, m.membership, m.membership_event_id,
                     r.room_version,
@@ -1432,7 +1439,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
                     room_version_id=row[4],
                     event_pos=PersistedEventPosition(row[5], row[6]),
                     room_type=row[7],
-                    is_encrypted=row[8],
+                    is_encrypted=bool(row[8]),
                 )
                 for row in txn
             }
diff --git a/tests/rest/client/sliding_sync/test_sliding_sync.py b/tests/rest/client/sliding_sync/test_sliding_sync.py
index 930cb5ef45..9e23dbe522 100644
--- a/tests/rest/client/sliding_sync/test_sliding_sync.py
+++ b/tests/rest/client/sliding_sync/test_sliding_sync.py
@@ -722,43 +722,37 @@ class SlidingSyncTestCase(SlidingSyncBase):
         self.helper.join(space_room_id, user1_id, tok=user1_tok)
 
         # Make an initial Sliding Sync request
-        channel = self.make_request(
-            "POST",
-            self.sync_endpoint,
-            {
-                "lists": {
-                    "all-list": {
-                        "ranges": [[0, 99]],
-                        "required_state": [],
-                        "timeline_limit": 0,
-                        "filters": {},
-                    },
-                    "foo-list": {
-                        "ranges": [[0, 99]],
-                        "required_state": [],
-                        "timeline_limit": 1,
-                        "filters": {
-                            "is_encrypted": True,
-                            "room_types": [RoomTypes.SPACE],
-                        },
+        sync_body = {
+            "lists": {
+                "all-list": {
+                    "ranges": [[0, 99]],
+                    "required_state": [],
+                    "timeline_limit": 0,
+                    "filters": {},
+                },
+                "foo-list": {
+                    "ranges": [[0, 99]],
+                    "required_state": [],
+                    "timeline_limit": 1,
+                    "filters": {
+                        "is_encrypted": True,
+                        "room_types": [RoomTypes.SPACE],
                     },
-                }
-            },
-            access_token=user1_tok,
-        )
-        self.assertEqual(channel.code, 200, channel.json_body)
-        from_token = channel.json_body["pos"]
+                },
+            }
+        }
+        response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
 
         # Make sure the response has the lists we requested
         self.assertListEqual(
-            list(channel.json_body["lists"].keys()),
+            list(response_body["lists"].keys()),
             ["all-list", "foo-list"],
-            channel.json_body["lists"].keys(),
+            response_body["lists"].keys(),
         )
 
         # Make sure the lists have the correct rooms
         self.assertListEqual(
-            list(channel.json_body["lists"]["all-list"]["ops"]),
+            list(response_body["lists"]["all-list"]["ops"]),
             [
                 {
                     "op": "SYNC",
@@ -768,7 +762,7 @@ class SlidingSyncTestCase(SlidingSyncBase):
             ],
         )
         self.assertListEqual(
-            list(channel.json_body["lists"]["foo-list"]["ops"]),
+            list(response_body["lists"]["foo-list"]["ops"]),
             [
                 {
                     "op": "SYNC",
@@ -783,35 +777,30 @@ class SlidingSyncTestCase(SlidingSyncBase):
         self.helper.leave(space_room_id, user2_id, tok=user2_tok)
 
         # Make an incremental Sliding Sync request
-        channel = self.make_request(
-            "POST",
-            self.sync_endpoint + f"?pos={from_token}",
-            {
-                "lists": {
-                    "all-list": {
-                        "ranges": [[0, 99]],
-                        "required_state": [],
-                        "timeline_limit": 0,
-                        "filters": {},
-                    },
-                    "foo-list": {
-                        "ranges": [[0, 99]],
-                        "required_state": [],
-                        "timeline_limit": 1,
-                        "filters": {
-                            "is_encrypted": True,
-                            "room_types": [RoomTypes.SPACE],
-                        },
+        sync_body = {
+            "lists": {
+                "all-list": {
+                    "ranges": [[0, 99]],
+                    "required_state": [],
+                    "timeline_limit": 0,
+                    "filters": {},
+                },
+                "foo-list": {
+                    "ranges": [[0, 99]],
+                    "required_state": [],
+                    "timeline_limit": 1,
+                    "filters": {
+                        "is_encrypted": True,
+                        "room_types": [RoomTypes.SPACE],
                     },
-                }
-            },
-            access_token=user1_tok,
-        )
-        self.assertEqual(channel.code, 200, channel.json_body)
+                },
+            }
+        }
+        response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
 
         # Make sure the lists have the correct rooms even though we `newly_left`
         self.assertListEqual(
-            list(channel.json_body["lists"]["all-list"]["ops"]),
+            list(response_body["lists"]["all-list"]["ops"]),
             [
                 {
                     "op": "SYNC",
@@ -821,7 +810,7 @@ class SlidingSyncTestCase(SlidingSyncBase):
             ],
         )
         self.assertListEqual(
-            list(channel.json_body["lists"]["foo-list"]["ops"]),
+            list(response_body["lists"]["foo-list"]["ops"]),
             [
                 {
                     "op": "SYNC",
@@ -831,6 +820,98 @@ class SlidingSyncTestCase(SlidingSyncBase):
             ],
         )
 
+    def test_filter_is_encrypted_up_to_date(self) -> None:
+        """
+        Make sure we get up-to-date `is_encrypted` status for a joined room
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+
+        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+        sync_body = {
+            "lists": {
+                "foo-list": {
+                    "ranges": [[0, 99]],
+                    "required_state": [],
+                    "timeline_limit": 0,
+                    "filters": {
+                        "is_encrypted": True,
+                    },
+                },
+            }
+        }
+        response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+        self.assertIncludes(
+            set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+            set(),
+            exact=True,
+        )
+
+        # Update the encryption status
+        self.helper.send_state(
+            room_id,
+            EventTypes.RoomEncryption,
+            {EventContentFields.ENCRYPTION_ALGORITHM: "m.megolm.v1.aes-sha2"},
+            tok=user1_tok,
+        )
+
+        # We should see the room now because it's encrypted
+        response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+        self.assertIncludes(
+            set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+            {room_id},
+            exact=True,
+        )
+
+    def test_forgotten_up_to_date(self) -> None:
+        """
+        Make sure we get up-to-date `forgotten` status for rooms
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+        user2_id = self.register_user("user2", "pass")
+        user2_tok = self.login(user2_id, "pass")
+
+        room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+
+        # User1 is banned from the room (was never in the room)
+        self.helper.ban(room_id, src=user2_id, targ=user1_id, tok=user2_tok)
+
+        sync_body = {
+            "lists": {
+                "foo-list": {
+                    "ranges": [[0, 99]],
+                    "required_state": [],
+                    "timeline_limit": 0,
+                    "filters": {},
+                },
+            }
+        }
+        response_body, from_token = self.do_sync(sync_body, tok=user1_tok)
+        self.assertIncludes(
+            set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+            {room_id},
+            exact=True,
+        )
+
+        # User1 forgets the room
+        channel = self.make_request(
+            "POST",
+            f"/_matrix/client/r0/rooms/{room_id}/forget",
+            content={},
+            access_token=user1_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+
+        # We should no longer see the forgotten room
+        response_body, _ = self.do_sync(sync_body, since=from_token, tok=user1_tok)
+        self.assertIncludes(
+            set(response_body["lists"]["foo-list"]["ops"][0]["room_ids"]),
+            set(),
+            exact=True,
+        )
+
     def test_sort_list(self) -> None:
         """
         Test that the `lists` are sorted by `stream_ordering`
-- 
GitLab