Commit c486fa5f, authored by Patrick Cloke, committed by GitHub (signature unverified)

Add some missing type hints to cache datastore. (#12216)

parent 86965605
Changelog entry: Add missing type hints for cache storage.
@@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import (
     EventsStream,
     EventsStreamCurrentStateRow,
     EventsStreamEventRow,
+    EventsStreamRow,
 )
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
@@ -31,6 +32,7 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.engines import PostgresEngine
+from synapse.util.caches.descriptors import _CachedFunction
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
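
The newly imported names drive the annotations below: EventsStreamRow types the rows handed to _process_event_stream_row, and _CachedFunction is the type Synapse gives to methods wrapped by its @cached decorator. For reading this diff, what matters is the interface the annotation promises; a rough structural sketch of that interface (an illustration, not Synapse's actual definition) could look like:

from typing import Any, Protocol, Tuple


class CachedFunctionLike(Protocol):
    """Hypothetical stand-in for synapse.util.caches.descriptors._CachedFunction."""

    __name__: str  # streamed to other workers to name the invalidated cache

    def invalidate(self, key: Tuple[Any, ...]) -> None:
        """Drop the entry for one key."""

    def invalidate_all(self) -> None:
        """Drop every entry in the cache."""
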
@@ -82,7 +84,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_updated_caches_txn(txn):
+        def get_all_updated_caches_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
             # We purposefully don't bound by the current token, as we want to
             # send across cache invalidations as quickly as possible. Cache
             # invalidations are idempotent, so duplicates are fine.
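
The new return annotation spells out the convention replication update functions follow: a batch of (position, row) updates, the token the caller should resume from, and a flag saying whether the batch was truncated by the query limit. A sketch of a value with that shape, using made-up data:

from typing import List, Tuple

# Each update pairs a stream position with a row; for the cache stream a
# row is a variable-length tuple (hence the bare `tuple` in the
# annotation), e.g. a cache name followed by its keys.
updates: List[Tuple[int, tuple]] = [
    (1234, ("get_rooms_for_user", "@alice:example.com")),
]
result: Tuple[List[Tuple[int, tuple]], int, bool] = (updates, 1234, False)
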
@@ -107,7 +111,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             "get_all_updated_caches", get_all_updated_caches_txn
         )
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+    ) -> None:
         if stream_name == EventsStream.NAME:
             for row in rows:
                 self._process_event_stream_row(token, row)
@@ -142,10 +148,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
-    def _process_event_stream_row(self, token, row):
+    def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
         data = row.data
 
         if row.type == EventsStreamEventRow.TypeId:
+            assert isinstance(data, EventsStreamEventRow)
             self._invalidate_caches_for_event(
                 token,
                 data.event_id,
@@ -157,9 +164,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                 backfilled=False,
             )
         elif row.type == EventsStreamCurrentStateRow.TypeId:
-            self._curr_state_delta_stream_cache.entity_has_changed(
-                row.data.room_id, token
-            )
+            assert isinstance(data, EventsStreamCurrentStateRow)
+            self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)
 
             if data.type == EventTypes.Member:
                 self.get_rooms_for_user_with_stream_ordering.invalidate(
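
The two assert isinstance additions are the interesting part of the change. EventsStreamRow.data is a union of per-subtype row classes, which is also why process_replication_rows above can only accept rows: Iterable[Any]; each branch must narrow the union before touching subtype-specific attributes such as event_id or room_id. A self-contained sketch of the technique, with hypothetical row classes rather than Synapse's:

from dataclasses import dataclass
from typing import Union


@dataclass
class EventRow:
    event_id: str


@dataclass
class StateRow:
    room_id: str


def handle(data: Union[EventRow, StateRow]) -> str:
    if isinstance(data, EventRow):
        # mypy narrows `data` to EventRow inside this branch
        return data.event_id
    # assert-based narrowing works too, and fails loudly at runtime if a
    # new row type is ever routed here by mistake
    assert isinstance(data, StateRow)
    return data.room_id
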
@@ -170,15 +176,15 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
     def _invalidate_caches_for_event(
         self,
-        stream_ordering,
-        event_id,
-        room_id,
-        etype,
-        state_key,
-        redacts,
-        relates_to,
-        backfilled,
-    ):
+        stream_ordering: int,
+        event_id: str,
+        room_id: str,
+        etype: str,
+        state_key: Optional[str],
+        redacts: Optional[str],
+        relates_to: Optional[str],
+        backfilled: bool,
+    ) -> None:
         self._invalidate_get_event_cache(event_id)
         self.have_seen_event.invalidate((room_id, event_id))
@@ -207,7 +213,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             self.get_thread_summary.invalidate((relates_to,))
             self.get_thread_participated.invalidate((relates_to,))
 
-    async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
+    async def invalidate_cache_and_stream(
+        self, cache_name: str, keys: Tuple[Any, ...]
+    ) -> None:
         """Invalidates the cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
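
This awaitable method is the public entry point for code running outside a database transaction. A hypothetical call site (the helper and cache name are made up; the name must match a @cached method on the store):

async def on_user_changed(store, user_id: str) -> None:
    # `store` is any store mixing in CacheInvalidationWorkerStore
    await store.invalidate_cache_and_stream("get_user_by_id", (user_id,))
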
@@ -227,7 +235,12 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             keys,
         )
 
-    def _invalidate_cache_and_stream(self, txn, cache_func, keys):
+    def _invalidate_cache_and_stream(
+        self,
+        txn: LoggingTransaction,
+        cache_func: _CachedFunction,
+        keys: Tuple[Any, ...],
+    ) -> None:
         """Invalidates the cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
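
The transactional variant relies on txn.call_after, which defers the local invalidation until the transaction actually commits. A minimal sketch of that pattern (a toy class, not Synapse's LoggingTransaction):

from typing import Any, Callable, List, Tuple


class SketchTxn:
    """Toy transaction that queues callbacks to run after commit."""

    def __init__(self) -> None:
        self._after_callbacks: List[Tuple[Callable[..., Any], Tuple[Any, ...]]] = []

    def call_after(self, callback: Callable[..., Any], *args: Any) -> None:
        # Queue the callback: it must only run once the write is durable,
        # otherwise a rolled-back transaction would still evict cache entries.
        self._after_callbacks.append((callback, args))

    def commit(self) -> None:
        # ... flush the writes to the database, then:
        for callback, args in self._after_callbacks:
            callback(*args)
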
@@ -238,7 +251,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         txn.call_after(cache_func.invalidate, keys)
         self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
 
-    def _invalidate_all_cache_and_stream(self, txn, cache_func):
+    def _invalidate_all_cache_and_stream(
+        self, txn: LoggingTransaction, cache_func: _CachedFunction
+    ) -> None:
         """Invalidates the entire cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
         """
@@ -279,8 +294,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         )
 
     def _send_invalidation_to_replication(
-        self, txn, cache_name: str, keys: Optional[Iterable[Any]]
-    ):
+        self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
+    ) -> None:
         """Notifies replication that given cache has been invalidated.
 
         Note that this does *not* invalidate the cache locally.
@@ -315,7 +330,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                     "instance_name": self._instance_name,
                     "cache_func": cache_name,
                     "keys": keys,
-                    "invalidation_ts": self.clock.time_msec(),
+                    "invalidation_ts": self._clock.time_msec(),
                 },
             )
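
This last hunk is the payoff of the exercise: self.clock appears to be an attribute typo (the store keeps its clock as self._clock), exactly the kind of bug that surfaces once the surrounding methods are annotated and mypy checks the class. In miniature:

import time


class Clock:
    def time_msec(self) -> int:
        return int(time.time() * 1000)


class Store:
    def __init__(self, clock: Clock) -> None:
        self._clock = clock

    def now_msec(self) -> int:
        # Writing `self.clock` here would raise AttributeError at runtime;
        # with annotations in place mypy reports it as
        #   "Store" has no attribute "clock"
        return self._clock.time_msec()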