diff --git a/changelog.d/17616.misc b/changelog.d/17616.misc new file mode 100644 index 0000000000000000000000000000000000000000..8250832dcd9d4a7a7840a43030a7fa50316c40c1 --- /dev/null +++ b/changelog.d/17616.misc @@ -0,0 +1 @@ +Overload DatabasePool.simple_select_one_txn to return non-None when the allow_none parameter is False. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index cb4a5857bed41b44fdf94cd9309631f521c260fa..8272e393405e1fe3a86943f4c557c17392602f94 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -2159,10 +2159,26 @@ class DatabasePool: if rowcount > 1: raise StoreError(500, "More than one row matched (%s)" % (table,)) - # Ideally we could use the overload decorator here to specify that the - # return type is only optional if allow_none is True, but this does not work - # when you call a static method from an instance. - # See https://github.com/python/mypy/issues/7781 + @overload + @staticmethod + def simple_select_one_txn( + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + retcols: Collection[str], + allow_none: Literal[False] = False, + ) -> Tuple[Any, ...]: ... + + @overload + @staticmethod + def simple_select_one_txn( + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + retcols: Collection[str], + allow_none: Literal[True] = True, + ) -> Optional[Tuple[Any, ...]]: ... + @staticmethod def simple_select_one_txn( txn: LoggingTransaction, diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index c2c93e12d90e2d3219ba580300ba9bb8748a82ee..a618a2de69c7b101da497ab17f8c7b717cc1f4a6 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -510,19 +510,16 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore): # it isn't there. raise StoreError(404, "No backup with that version exists") - row = cast( - Tuple[int, str, str, Optional[int]], - self.db_pool.simple_select_one_txn( - txn, - table="e2e_room_keys_versions", - keyvalues={ - "user_id": user_id, - "version": this_version, - "deleted": 0, - }, - retcols=("version", "algorithm", "auth_data", "etag"), - allow_none=False, - ), + row = self.db_pool.simple_select_one_txn( + txn, + table="e2e_room_keys_versions", + keyvalues={ + "user_id": user_id, + "version": this_version, + "deleted": 0, + }, + retcols=("version", "algorithm", "auth_data", "etag"), + allow_none=False, ) return { "auth_data": db_to_json(row[2]), diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index d7cbe3341182121520a44481ddf9829838f7ffb2..8380930c70eefc255147eec5fa5290831cf974cb 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1510,15 +1510,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): # Override type because the return type is only optional if # allow_none is True, and we don't want mypy throwing errors # about None not being indexable. - pending, completed = cast( - Tuple[int, int], - self.db_pool.simple_select_one_txn( - txn, - "registration_tokens", - keyvalues={"token": token}, - retcols=["pending", "completed"], - ), + row = self.db_pool.simple_select_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcols=("pending", "completed"), ) + pending = int(row[0]) + completed = int(row[1]) # Decrement pending and increment completed self.db_pool.simple_update_one_txn( diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index b4258a44362ce103ec930b5ebd2e67f852dc9bf0..40b0bff164bf3eab1111dd9eaf2aa2eeb0cd079f 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -1837,15 +1837,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): dict """ - stream_ordering, topological_ordering = cast( - Tuple[int, int], - self.db_pool.simple_select_one_txn( - txn, - "events", - keyvalues={"event_id": event_id, "room_id": room_id}, - retcols=["stream_ordering", "topological_ordering"], - ), + row = self.db_pool.simple_select_one_txn( + txn, + "events", + keyvalues={"event_id": event_id, "room_id": room_id}, + retcols=("stream_ordering", "topological_ordering"), ) + stream_ordering = int(row[0]) + topological_ordering = int(row[1]) # Paginating backwards includes the event at the token, but paginating # forward doesn't.