From e4074749d296ca703c857dedf00c6d0543a0679b Mon Sep 17 00:00:00 2001
From: Andrew Ferrazzutti <andrewf@element.io>
Date: Mon, 10 Feb 2025 10:37:05 -0500
Subject: [PATCH] Overload "allow_none" on DB pool static method (#17616)

### Pull Request Checklist

<!-- Please read
https://element-hq.github.io/synapse/latest/development/contributing_guide.html
before submitting your pull request -->

* [x] Pull request is based on the develop branch
* [x] Pull request includes a [changelog
file](https://element-hq.github.io/synapse/latest/development/contributing_guide.html#changelog).
The entry should:
- Be a short description of your change which makes sense to users.
"Fixed a bug that prevented receiving messages from other servers."
instead of "Moved X method from `EventStore` to `EventWorkerStore`.".
  - Use markdown where necessary, mostly for `code blocks`.
  - End with either a period (.) or an exclamation mark (!).
  - Start with a capital letter.
- Feel free to credit yourself, by adding a sentence "Contributed by
@github_username." or "Contributed by [Your Name]." to the end of the
entry.
* [x] [Code
style](https://element-hq.github.io/synapse/latest/code_style.html) is
correct
(run the
[linters](https://element-hq.github.io/synapse/latest/development/contributing_guide.html#run-the-linters))

---------

Co-authored-by: Quentin Gliech <quenting@element.io>
---
 changelog.d/17616.misc                        |  1 +
 synapse/storage/database.py                   | 24 +++++++++++++++----
 .../storage/databases/main/e2e_room_keys.py   | 23 ++++++++----------
 .../storage/databases/main/registration.py    | 15 ++++++------
 synapse/storage/databases/main/stream.py      | 15 ++++++------
 5 files changed, 45 insertions(+), 33 deletions(-)
 create mode 100644 changelog.d/17616.misc

diff --git a/changelog.d/17616.misc b/changelog.d/17616.misc
new file mode 100644
index 0000000000..8250832dcd
--- /dev/null
+++ b/changelog.d/17616.misc
@@ -0,0 +1 @@
+Overload DatabasePool.simple_select_one_txn to return non-None when the allow_none parameter is False.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index cb4a5857be..8272e39340 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -2159,10 +2159,26 @@ class DatabasePool:
         if rowcount > 1:
             raise StoreError(500, "More than one row matched (%s)" % (table,))
 
-    # Ideally we could use the overload decorator here to specify that the
-    # return type is only optional if allow_none is True, but this does not work
-    # when you call a static method from an instance.
-    # See https://github.com/python/mypy/issues/7781
+    @overload
+    @staticmethod
+    def simple_select_one_txn(
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcols: Collection[str],
+        allow_none: Literal[False] = False,
+    ) -> Tuple[Any, ...]: ...
+
+    @overload
+    @staticmethod
+    def simple_select_one_txn(
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcols: Collection[str],
+        allow_none: Literal[True] = True,
+    ) -> Optional[Tuple[Any, ...]]: ...
+
     @staticmethod
     def simple_select_one_txn(
         txn: LoggingTransaction,
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index c2c93e12d9..a618a2de69 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -510,19 +510,16 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
                     # it isn't there.
                     raise StoreError(404, "No backup with that version exists")
 
-            row = cast(
-                Tuple[int, str, str, Optional[int]],
-                self.db_pool.simple_select_one_txn(
-                    txn,
-                    table="e2e_room_keys_versions",
-                    keyvalues={
-                        "user_id": user_id,
-                        "version": this_version,
-                        "deleted": 0,
-                    },
-                    retcols=("version", "algorithm", "auth_data", "etag"),
-                    allow_none=False,
-                ),
+            row = self.db_pool.simple_select_one_txn(
+                txn,
+                table="e2e_room_keys_versions",
+                keyvalues={
+                    "user_id": user_id,
+                    "version": this_version,
+                    "deleted": 0,
+                },
+                retcols=("version", "algorithm", "auth_data", "etag"),
+                allow_none=False,
             )
             return {
                 "auth_data": db_to_json(row[2]),
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index d7cbe33411..8380930c70 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1510,15 +1510,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             # Override type because the return type is only optional if
             # allow_none is True, and we don't want mypy throwing errors
             # about None not being indexable.
-            pending, completed = cast(
-                Tuple[int, int],
-                self.db_pool.simple_select_one_txn(
-                    txn,
-                    "registration_tokens",
-                    keyvalues={"token": token},
-                    retcols=["pending", "completed"],
-                ),
+            row = self.db_pool.simple_select_one_txn(
+                txn,
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcols=("pending", "completed"),
             )
+            pending = int(row[0])
+            completed = int(row[1])
 
             # Decrement pending and increment completed
             self.db_pool.simple_update_one_txn(
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index b4258a4436..40b0bff164 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -1837,15 +1837,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             dict
         """
 
-        stream_ordering, topological_ordering = cast(
-            Tuple[int, int],
-            self.db_pool.simple_select_one_txn(
-                txn,
-                "events",
-                keyvalues={"event_id": event_id, "room_id": room_id},
-                retcols=["stream_ordering", "topological_ordering"],
-            ),
+        row = self.db_pool.simple_select_one_txn(
+            txn,
+            "events",
+            keyvalues={"event_id": event_id, "room_id": room_id},
+            retcols=("stream_ordering", "topological_ordering"),
         )
+        stream_ordering = int(row[0])
+        topological_ordering = int(row[1])
 
         # Paginating backwards includes the event at the token, but paginating
         # forward doesn't.
-- 
GitLab