From 83fc225030253fc222af62fc96c9f4db8d4738e8 Mon Sep 17 00:00:00 2001
From: Eric Eastwood <eric.eastwood@beta.gouv.fr>
Date: Thu, 19 Sep 2024 06:43:26 -0500
Subject: [PATCH] Sliding Sync: Add cache to `get_tags_for_room(...)` (#17730)

Add cache to `get_tags_for_room(...)`

This helps Sliding Sync because `get_tags_for_room(...)` is going to be
used in https://github.com/element-hq/synapse/pull/17695

Essentially, we're just trying to match `get_account_data_for_room(...)`
which already has a tree cache.
---
 changelog.d/17730.misc                        |  1 +
 synapse/handlers/account_data.py              |  4 ++--
 synapse/storage/databases/main/cache.py       |  1 +
 synapse/storage/databases/main/tags.py        | 19 ++++++++++++++++---
 .../test_resource_limits_server_notices.py    |  2 +-
 5 files changed, 21 insertions(+), 6 deletions(-)
 create mode 100644 changelog.d/17730.misc

diff --git a/changelog.d/17730.misc b/changelog.d/17730.misc
new file mode 100644
index 0000000000..56da7bfd1a
--- /dev/null
+++ b/changelog.d/17730.misc
@@ -0,0 +1 @@
+Add cache to `get_tags_for_room(...)`.
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 97a463d8d0..228132db48 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -33,7 +33,7 @@ from synapse.replication.http.account_data import (
     ReplicationRemoveUserAccountDataRestServlet,
 )
 from synapse.streams import EventSource
-from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID
+from synapse.types import JsonDict, JsonMapping, StrCollection, StreamKeyType, UserID
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -253,7 +253,7 @@ class AccountDataHandler:
             return response["max_stream_id"]
 
     async def add_tag_to_room(
-        self, user_id: str, room_id: str, tag: str, content: JsonDict
+        self, user_id: str, room_id: str, tag: str, content: JsonMapping
     ) -> int:
         """Add a tag to a room for a user.
 
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 37c865a8e7..32c3472e58 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -471,6 +471,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
         self._attempt_to_invalidate_cache("get_account_data_for_room", None)
         self._attempt_to_invalidate_cache("get_account_data_for_room_and_type", None)
+        self._attempt_to_invalidate_cache("get_tags_for_room", None)
         self._attempt_to_invalidate_cache("get_aliases_for_room", (room_id,))
         self._attempt_to_invalidate_cache("get_latest_event_ids_in_room", (room_id,))
         self._attempt_to_invalidate_cache("_get_forward_extremeties_for_room", None)
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index b5af294384..b498cb9625 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -158,9 +158,10 @@ class TagsWorkerStore(AccountDataWorkerStore):
 
         return results
 
+    @cached(num_args=2, tree=True)
     async def get_tags_for_room(
         self, user_id: str, room_id: str
-    ) -> Dict[str, JsonDict]:
+    ) -> Mapping[str, JsonMapping]:
         """Get all the tags for the given room
 
         Args:
@@ -182,7 +183,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
         return {tag: db_to_json(content) for tag, content in rows}
 
     async def add_tag_to_room(
-        self, user_id: str, room_id: str, tag: str, content: JsonDict
+        self, user_id: str, room_id: str, tag: str, content: JsonMapping
     ) -> int:
         """Add a tag to a room for a user.
 
@@ -213,6 +214,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
+        self.get_tags_for_room.invalidate((user_id, room_id))
 
         return self._account_data_id_gen.get_current_token()
 
@@ -237,6 +239,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
+        self.get_tags_for_room.invalidate((user_id, room_id))
 
         return self._account_data_id_gen.get_current_token()
 
@@ -290,9 +293,19 @@ class TagsWorkerStore(AccountDataWorkerStore):
         rows: Iterable[Any],
     ) -> None:
         if stream_name == AccountDataStream.NAME:
-            for row in rows:
+            # Cast is safe because the `AccountDataStream` should only be giving us
+            # `AccountDataStreamRow`
+            account_data_stream_rows: List[AccountDataStream.AccountDataStreamRow] = (
+                cast(List[AccountDataStream.AccountDataStreamRow], rows)
+            )
+
+            for row in account_data_stream_rows:
                 if row.data_type == AccountDataTypes.TAG:
                     self.get_tags_for_user.invalidate((row.user_id,))
+                    if row.room_id:
+                        self.get_tags_for_room.invalidate((row.user_id, row.room_id))
+                    else:
+                        self.get_tags_for_room.invalidate((row.user_id,))
                     self._account_data_stream_cache.entity_has_changed(
                         row.user_id, token
                     )
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 0e3e4f7293..997ee7b91b 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -89,7 +89,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
             return_value="!something:localhost"
         )
         self._rlsn._store.add_tag_to_room = AsyncMock(return_value=None)  # type: ignore[method-assign]
-        self._rlsn._store.get_tags_for_room = AsyncMock(return_value={})  # type: ignore[method-assign]
+        self._rlsn._store.get_tags_for_room = AsyncMock(return_value={})
 
     @override_config({"hs_disabled": True})
     def test_maybe_send_server_notice_disabled_hs(self) -> None:
-- 
GitLab