Skip to content
Snippets Groups Projects
Unverified Commit f01e2ca0 authored by Richard van der Hoff's avatar Richard van der Hoff Committed by GitHub
Browse files

Use symbolic names for replication stream names (#7768)

This makes it much easier to find where streams are referenced.
parent a6eae69f
Branches
Tags
No related merge requests found
Use symbolic names for replication stream names.
......@@ -16,6 +16,7 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
from synapse.storage.data_stores.main.tags import TagsWorkerStore
from synapse.storage.database import Database
......@@ -39,12 +40,12 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
return self._account_data_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "tag_account_data":
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(token)
for row in rows:
self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
elif stream_name == "account_data":
elif stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(token)
for row in rows:
if not row.room_id:
......
......@@ -15,6 +15,7 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ToDeviceStream
from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore
from synapse.storage.database import Database
from synapse.util.caches.expiringcache import ExpiringCache
......@@ -44,7 +45,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
)
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "to_device":
if stream_name == ToDeviceStream.NAME:
self._device_inbox_id_gen.advance(token)
for row in rows:
if row.entity.startswith("@"):
......
......@@ -15,6 +15,7 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import GroupServerStream
from synapse.storage.data_stores.main.group_server import GroupServerWorkerStore
from synapse.storage.database import Database
from synapse.util.caches.stream_change_cache import StreamChangeCache
......@@ -38,7 +39,7 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
return self._group_updates_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "groups":
if stream_name == GroupServerStream.NAME:
self._group_updates_id_gen.advance(token)
for row in rows:
self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
......
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.replication.tcp.streams import PresenceStream
from synapse.storage import DataStore
from synapse.storage.data_stores.main.presence import PresenceStore
from synapse.storage.database import Database
......@@ -42,7 +43,7 @@ class SlavedPresenceStore(BaseSlavedStore):
return self._presence_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "presence":
if stream_name == PresenceStream.NAME:
self._presence_id_gen.advance(token)
for row in rows:
self.presence_stream_cache.entity_has_changed(row.user_id, token)
......
......@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
from .events import SlavedEventStore
......@@ -30,7 +31,7 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
return self._push_rules_stream_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "push_rules":
if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(token)
for row in rows:
self.get_push_rules_for_user.invalidate((row.user_id,))
......
......@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.replication.tcp.streams import PushersStream
from synapse.storage.data_stores.main.pusher import PusherWorkerStore
from synapse.storage.database import Database
......@@ -32,6 +33,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
return self._pushers_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "pushers":
if stream_name == PushersStream.NAME:
self._pushers_id_gen.advance(token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
......@@ -14,20 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
from synapse.storage.database import Database
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
# So, um, we want to borrow a load of functions intended for reading from
# a DataStore, but we don't want to take functions that either write to the
# DataStore or are cached and don't have cache invalidation logic.
#
# Rather than write duplicate versions of those functions, or lift them to
# a common base class, we going to grab the underlying __func__ object from
# the method descriptor on the DataStore and chuck them into our class.
class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def __init__(self, database: Database, db_conn, hs):
......@@ -52,7 +45,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
self.get_receipts_for_room.invalidate((room_id, receipt_type))
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "receipts":
if stream_name == ReceiptsStream.NAME:
self._receipts_id_gen.advance(token)
for row in rows:
self.invalidate_caches_for_receipt(
......
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.replication.tcp.streams import PublicRoomsStream
from synapse.storage.data_stores.main.room import RoomWorkerStore
from synapse.storage.database import Database
......@@ -31,7 +32,7 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
return self._public_room_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "public_rooms":
if stream_name == PublicRoomsStream.NAME:
self._public_room_id_gen.advance(token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
......@@ -19,7 +19,9 @@ import logging
from typing import Any, Iterable, Optional, Tuple
from synapse.api.constants import EventTypes
from synapse.replication.tcp.streams import BackfillStream, CachesStream
from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamCurrentStateRow,
EventsStreamEventRow,
)
......@@ -71,10 +73,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
)
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "events":
if stream_name == EventsStream.NAME:
for row in rows:
self._process_event_stream_row(token, row)
elif stream_name == "backfill":
elif stream_name == BackfillStream.NAME:
for row in rows:
self._invalidate_caches_for_event(
-token,
......@@ -86,7 +88,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
row.relates_to,
backfilled=True,
)
elif stream_name == "caches":
elif stream_name == CachesStream.NAME:
if self._cache_id_gen:
self._cache_id_gen.advance(instance_name, token)
......
......@@ -38,6 +38,8 @@ from synapse.events.utils import prune_event
from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.storage.util.id_generators import StreamIdGenerator
......@@ -113,9 +115,9 @@ class EventsWorkerStore(SQLBaseStore):
self._event_fetch_ongoing = 0
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "events":
if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(token)
elif stream_name == "backfill":
elif stream_name == BackfillStream.NAME:
self._backfill_id_gen.advance(-token)
super().process_replication_rows(stream_name, instance_name, token, rows)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment