From adda2a4613cb67d4329681e5c5eb7867a17e021d Mon Sep 17 00:00:00 2001
From: Eric Eastwood <eric.eastwood@beta.gouv.fr>
Date: Mon, 14 Oct 2024 07:47:35 -0500
Subject: [PATCH] Sliding Sync: Slight optimization when fetching state for the
 room (`get_events_as_list(...)`) (#17718)

Spawned from @kegsay [pointing
out](https://matrix.to/#/!cnVVNLKqgUzNTOFQkz:matrix.org/$ExOO7J8uPUQSyH-9Uxc_QCa8jlXX9uK4VRtkSC0EI3o?via=element.io&via=matrix.org&via=jki.re)
that the Sliding Sync endpoint doesn't handle an initial sync of a large
room with a lot of state well (requesting all state via
`required_state: [["*", "*"]]` just takes forever).

After investigating further, the slow part is simply
`get_events_as_list(...)` fetching the events for all of the current
state IDs in the room (which can be 100k+ events for rooms with a lot
of membership). This is just slow in Synapse in general; the same thing
happens in Sync v2 and the `/state` endpoint.
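
Roughly, the slow path looks like this (a simplified sketch; the names
here are illustrative, see the diff below for the real code in
`synapse/handlers/sliding_sync/__init__.py`):

```python
# `state_ids` maps (event_type, state_key) -> event_id. For a room with
# a lot of membership this can easily be 100k+ entries.
state_ids = await get_current_state_ids(room_id)

# Loading the full events for all of those IDs is the bottleneck:
events = await store.get_events_as_list(list(state_ids.values()))
```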


---

The only idea I had to improve things was to use `batch_iter` to fetch
a fixed number of events at a time instead of working with large maps,
lists, and sets (see the sketch below). This doesn't seem to have much
effect, though.
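
A minimal sketch of the batching idea (`batch_iter` comes from
`synapse.util.iterutils`; the surrounding names are illustrative):

```python
from synapse.util.iterutils import batch_iter

# Fetch events in fixed-size chunks rather than materializing one huge
# list/map for the whole room's state at once:
events = []
for event_id_batch in batch_iter(event_ids, 200):
    events.extend(await store.get_events_as_list(event_id_batch))
```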

There is already a `batch_iter(event_ids, 200)` in
`_fetch_event_rows(...)` for when we actually have to touch the
database, and that's inside a queue that deduplicates work.

I did notice one slight optimization: use `get_events_as_list(...)`
directly instead of `get_events(...)`. `get_events(...)` just turns the
result of `get_events_as_list(...)` into a dict, and since we're only
iterating over the events, we don't need the dict/map.
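
For context, `get_events(...)` is essentially a thin wrapper over
`get_events_as_list(...)` (simplified from
`synapse/storage/databases/main/events_worker.py`; the other keyword
arguments are elided here):

```python
async def get_events(
    self, event_ids: Collection[str]
) -> Dict[str, EventBase]:
    events = await self.get_events_as_list(event_ids)
    return {e.event_id: e for e in events}
```

Skipping the wrapper avoids building a dict that we'd immediately throw
away.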
---
 changelog.d/17718.misc                        |  1 +
 synapse/handlers/sliding_sync/__init__.py     |  8 ++-
 .../storage/databases/main/events_worker.py   | 41 ++++++++++++++--
 .../databases/main/test_events_worker.py      | 49 ++++++++++++++++++-
 4 files changed, 89 insertions(+), 10 deletions(-)
 create mode 100644 changelog.d/17718.misc

diff --git a/changelog.d/17718.misc b/changelog.d/17718.misc
new file mode 100644
index 0000000000..ea73a03f53
--- /dev/null
+++ b/changelog.d/17718.misc
@@ -0,0 +1 @@
+Slight optimization when fetching state/events for Sliding Sync.
diff --git a/synapse/handlers/sliding_sync/__init__.py b/synapse/handlers/sliding_sync/__init__.py
index 39dba4ff98..a1a6728fb9 100644
--- a/synapse/handlers/sliding_sync/__init__.py
+++ b/synapse/handlers/sliding_sync/__init__.py
@@ -452,13 +452,11 @@ class SlidingSyncHandler:
             to_token=to_token,
         )
 
-        event_map = await self.store.get_events(list(state_ids.values()))
+        events = await self.store.get_events_as_list(list(state_ids.values()))
 
         state_map = {}
-        for key, event_id in state_ids.items():
-            event = event_map.get(event_id)
-            if event:
-                state_map[key] = event
+        for event in events:
+            state_map[(event.type, event.state_key)] = event
 
         return state_map
 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c029228422..403407068c 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -61,7 +61,13 @@ from synapse.logging.context import (
     current_context,
     make_deferred_yieldable,
 )
-from synapse.logging.opentracing import start_active_span, tag_args, trace
+from synapse.logging.opentracing import (
+    SynapseTags,
+    set_tag,
+    start_active_span,
+    tag_args,
+    trace,
+)
 from synapse.metrics.background_process_metrics import (
     run_as_background_process,
     wrap_as_background_process,
@@ -525,6 +531,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event
 
+    @trace
     async def get_events(
         self,
         event_ids: Collection[str],
@@ -556,6 +563,11 @@ class EventsWorkerStore(SQLBaseStore):
         Returns:
             A mapping from event_id to event.
         """
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
+            str(len(event_ids)),
+        )
+
         events = await self.get_events_as_list(
             event_ids,
             redact_behaviour=redact_behaviour,
@@ -603,6 +615,10 @@ class EventsWorkerStore(SQLBaseStore):
             Note that the returned list may be smaller than the list of event
             IDs if not all events could be fetched.
         """
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
+            str(len(event_ids)),
+        )
 
         if not event_ids:
             return []
@@ -723,10 +739,11 @@ class EventsWorkerStore(SQLBaseStore):
 
         return events
 
+    @trace
     @cancellable
     async def get_unredacted_events_from_cache_or_db(
         self,
-        event_ids: Iterable[str],
+        event_ids: Collection[str],
         allow_rejected: bool = False,
     ) -> Dict[str, EventCacheEntry]:
         """Fetch a bunch of events from the cache or the database.
@@ -748,6 +765,11 @@ class EventsWorkerStore(SQLBaseStore):
         Returns:
             map from event id to result
         """
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
+            str(len(event_ids)),
+        )
+
         # Shortcut: check if we have any events in the *in memory* cache - this function
         # may be called repeatedly for the same event so at this point we cannot reach
         # out to any external cache for performance reasons. The external cache is
@@ -936,7 +958,7 @@ class EventsWorkerStore(SQLBaseStore):
             events, update_metrics=update_metrics
         )
 
-        missing_event_ids = (e for e in events if e not in event_map)
+        missing_event_ids = [e for e in events if e not in event_map]
         event_map.update(
             await self._get_events_from_external_cache(
                 events=missing_event_ids,
@@ -946,8 +968,9 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event_map
 
+    @trace
     async def _get_events_from_external_cache(
-        self, events: Iterable[str], update_metrics: bool = True
+        self, events: Collection[str], update_metrics: bool = True
     ) -> Dict[str, EventCacheEntry]:
         """Fetch events from any configured external cache.
 
@@ -957,6 +980,10 @@ class EventsWorkerStore(SQLBaseStore):
             events: list of event_ids to fetch
             update_metrics: Whether to update the cache hit ratio metrics
         """
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "events.length",
+            str(len(events)),
+        )
         event_map = {}
 
         for event_id in events:
@@ -1222,6 +1249,7 @@ class EventsWorkerStore(SQLBaseStore):
                 with PreserveLoggingContext():
                     self.hs.get_reactor().callFromThread(fire_errback, e)
 
+    @trace
     async def _get_events_from_db(
         self, event_ids: Collection[str]
     ) -> Dict[str, EventCacheEntry]:
@@ -1240,6 +1268,11 @@ class EventsWorkerStore(SQLBaseStore):
             map from event id to result. May return extra events which
             weren't asked for.
         """
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
+            str(len(event_ids)),
+        )
+
         fetched_event_ids: Set[str] = set()
         fetched_events: Dict[str, _EventRow] = {}
 
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index fd1f5e7fd5..104d141a72 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -20,7 +20,7 @@
 #
 import json
 from contextlib import contextmanager
-from typing import Generator, List, Tuple
+from typing import Generator, List, Set, Tuple
 from unittest import mock
 
 from twisted.enterprise.adbapi import ConnectionPool
@@ -295,6 +295,53 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
             self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
 
 
+class GetEventsTestCase(unittest.HomeserverTestCase):
+    """Test `get_events(...)`/`get_events_as_list(...)`"""
+
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store: EventsWorkerStore = hs.get_datastores().main
+
+    def test_get_lots_of_messages(self) -> None:
+        """Sanity check that `get_events(...)`/`get_events_as_list(...)` works"""
+        num_events = 100
+
+        user_id = self.register_user("user", "pass")
+        user_tok = self.login(user_id, "pass")
+
+        room_id = self.helper.create_room_as(user_id, tok=user_tok)
+
+        event_ids: Set[str] = set()
+        for i in range(num_events):
+            event = self.get_success(
+                inject_event(
+                    self.hs,
+                    room_id=room_id,
+                    type="m.room.message",
+                    sender=user_id,
+                    content={
+                        "body": f"foo{i}",
+                        "msgtype": "m.text",
+                    },
+                )
+            )
+            event_ids.add(event.event_id)
+
+        # Sanity check that we actually created the events
+        self.assertEqual(len(event_ids), num_events)
+
+        # This is the function under test
+        fetched_event_map = self.get_success(self.store.get_events(event_ids))
+
+        # Sanity check that we got the events back
+        self.assertIncludes(fetched_event_map.keys(), event_ids, exact=True)
+
+
 class DatabaseOutageTestCase(unittest.HomeserverTestCase):
     """Test event fetching during a database outage."""
 
-- 
GitLab