From 8a380d0fe24edd746256d652836ec27003a05e7e Mon Sep 17 00:00:00 2001
From: Erik Johnston <erik@matrix.org>
Date: Tue, 18 Feb 2020 15:39:09 +0000
Subject: [PATCH] Increase perf of `get_auth_chain_ids` used in state res v2.
 (#6937)

We do this by moving the recursive query to be fully in the DB.
---
 changelog.d/6937.misc                         |  1 +
 .../data_stores/main/event_federation.py      | 23 +++++++++++++++++++
 2 files changed, 24 insertions(+)
 create mode 100644 changelog.d/6937.misc

diff --git a/changelog.d/6937.misc b/changelog.d/6937.misc
new file mode 100644
index 0000000000..6d00e58654
--- /dev/null
+++ b/changelog.d/6937.misc
@@ -0,0 +1 @@
+Increase perf of `get_auth_chain_ids` used in state res v2.
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index 60c67457b4..1746f40adf 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -26,6 +26,7 @@ from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
 from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
 from synapse.storage.database import Database
+from synapse.storage.engines import PostgresEngine
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
@@ -61,6 +62,28 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         )
 
     def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
+        if isinstance(self.database_engine, PostgresEngine):
+            # For efficiency we make the database do this if we can.
+            sql = """
+                WITH RECURSIVE auth_chain(event_id) AS (
+                    SELECT auth_id FROM event_auth WHERE event_id = ANY(?)
+                    UNION
+                    SELECT auth_id FROM event_auth
+                    INNER JOIN auth_chain USING (event_id)
+                )
+                SELECT event_id FROM auth_chain
+            """
+            txn.execute(sql, (list(event_ids),))
+
+            results = set(event_id for event_id, in txn)
+
+            if include_given:
+                results.update(event_ids)
+
+            return list(results)
+
+        # Database doesn't necessarily support recursive CTE, so we fall
+        # back to do doing it manually.
         if include_given:
             results = set(event_ids)
         else:
-- 
GitLab