Skip to content
Snippets Groups Projects
Commit 9a4fb457 authored by Erik Johnston's avatar Erik Johnston
Browse files

Change DataStores to accept 'database' param.

parent f3ea2f5a
No related branches found
No related tags found
No related merge requests found
Showing
with 62 additions and 43 deletions
......@@ -40,6 +40,7 @@ from synapse.replication.slave.storage.transactions import SlavedTransactionStor
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.replication.tcp.streams._base import ReceiptsStream
from synapse.server import HomeServer
from synapse.storage.database import Database
from synapse.storage.engines import create_engine
from synapse.types import ReadReceipt
from synapse.util.async_helpers import Linearizer
......@@ -59,8 +60,8 @@ class FederationSenderSlaveStore(
SlavedDeviceStore,
SlavedPresenceStore,
):
def __init__(self, db_conn, hs):
super(FederationSenderSlaveStore, self).__init__(db_conn, hs)
def __init__(self, database: Database, db_conn, hs):
super(FederationSenderSlaveStore, self).__init__(database, db_conn, hs)
# We pull out the current federation stream position now so that we
# always have a known value for the federation position in memory so
......
......@@ -43,6 +43,7 @@ from synapse.replication.tcp.streams.events import (
from synapse.rest.client.v2_alpha import user_directory
from synapse.server import HomeServer
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
from synapse.storage.database import Database
from synapse.storage.engines import create_engine
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.httpresourcetree import create_resource_tree
......@@ -60,8 +61,8 @@ class UserDirectorySlaveStore(
UserDirectoryStore,
BaseSlavedStore,
):
def __init__(self, db_conn, hs):
super(UserDirectorySlaveStore, self).__init__(db_conn, hs)
def __init__(self, database: Database, db_conn, hs):
super(UserDirectorySlaveStore, self).__init__(database, db_conn, hs)
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
......
......@@ -20,6 +20,7 @@ import six
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
from ._slaved_id_tracker import SlavedIdTracker
......@@ -35,8 +36,8 @@ def __func__(inp):
class BaseSlavedStore(SQLBaseStore):
def __init__(self, db_conn, hs):
super(BaseSlavedStore, self).__init__(db_conn, hs)
def __init__(self, database: Database, db_conn, hs):
super(BaseSlavedStore, self).__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = SlavedIdTracker(
db_conn, "cache_invalidation_stream", "stream_id"
......
......@@ -18,15 +18,16 @@ from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
from synapse.storage.data_stores.main.tags import TagsWorkerStore
from synapse.storage.database import Database
class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
def __init__(self, database: Database, db_conn, hs):
self._account_data_id_gen = SlavedIdTracker(
db_conn, "account_data_max_stream_id", "stream_id"
)
super(SlavedAccountDataStore, self).__init__(db_conn, hs)
super(SlavedAccountDataStore, self).__init__(database, db_conn, hs)
def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()
......
......@@ -14,6 +14,7 @@
# limitations under the License.
from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.storage.database import Database
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.descriptors import Cache
......@@ -21,8 +22,8 @@ from ._base import BaseSlavedStore
class SlavedClientIpStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedClientIpStore, self).__init__(db_conn, hs)
def __init__(self, database: Database, db_conn, hs):
super(SlavedClientIpStore, self).__init__(database, db_conn, hs)
self.client_ip_last_seen = Cache(
name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
......
......@@ -16,13 +16,14 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore
from synapse.storage.database import Database
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
def __init__(self, database: Database, db_conn, hs):
super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
self._device_inbox_id_gen = SlavedIdTracker(
db_conn, "device_max_stream_id", "stream_id"
)
......
......@@ -18,12 +18,13 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
from synapse.storage.data_stores.main.devices import DeviceWorkerStore
from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.storage.database import Database
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedDeviceStore, self).__init__(db_conn, hs)
def __init__(self, database: Database, db_conn, hs):
super(SlavedDeviceStore, self).__init__(database, db_conn, hs)
self.hs = hs
......
......@@ -31,6 +31,7 @@ from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
from synapse.storage.data_stores.main.state import StateGroupWorkerStore
from synapse.storage.data_stores.main.stream import StreamWorkerStore
from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore
from synapse.storage.database import Database
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
......@@ -59,13 +60,13 @@ class SlavedEventStore(
RelationsWorkerStore,
BaseSlavedStore,
):
def __init__(self, db_conn, hs):
def __init__(self, database: Database, db_conn, hs):
self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
self._backfill_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", step=-1
)
super(SlavedEventStore, self).__init__(db_conn, hs)
super(SlavedEventStore, self).__init__(database, db_conn, hs)
# Cached functions can't be accessed through a class instance so we need
# to reach inside the __dict__ to extract them.
......
......@@ -14,13 +14,14 @@
# limitations under the License.
from synapse.storage.data_stores.main.filtering import FilteringStore
from synapse.storage.database import Database
from ._base import BaseSlavedStore
class SlavedFilteringStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedFilteringStore, self).__init__(db_conn, hs)
def __init__(self, database: Database, db_conn, hs):
super(SlavedFilteringStore, self).__init__(database, db_conn, hs)
# Filters are immutable so this cache doesn't need to be expired
get_user_filter = FilteringStore.__dict__["get_user_filter"]
......@@ -14,6 +14,7 @@
# limitations under the License.
from synapse.storage import DataStore
from synapse.storage.database import Database
from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore, __func__
......@@ -21,8 +22,8 @@ from ._slaved_id_tracker import SlavedIdTracker
class SlavedGroupServerStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedGroupServerStore, self).__init__(db_conn, hs)
def __init__(self, database: Database, db_conn, hs):
super(SlavedGroupServerStore, self).__init__(database, db_conn, hs)
self.hs = hs
......
......@@ -15,6 +15,7 @@
from synapse.storage import DataStore
from synapse.storage.data_stores.main.presence import PresenceStore
from synapse.storage.database import Database
from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore, __func__
......@@ -22,8 +23,8 @@ from ._slaved_id_tracker import SlavedIdTracker
class SlavedPresenceStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedPresenceStore, self).__init__(db_conn, hs)
def __init__(self, database: Database, db_conn, hs):
super(SlavedPresenceStore, self).__init__(database, db_conn, hs)
self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")
self._presence_on_startup = self._get_active_presence(db_conn)
......
......@@ -15,17 +15,18 @@
# limitations under the License.
from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
from synapse.storage.database import Database
from ._slaved_id_tracker import SlavedIdTracker
from .events import SlavedEventStore
class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def __init__(self, db_conn, hs):
def __init__(self, database: Database, db_conn, hs):
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id"
)
super(SlavedPushRuleStore, self).__init__(db_conn, hs)
super(SlavedPushRuleStore, self).__init__(database, db_conn, hs)
def get_push_rules_stream_token(self):
return (
......
......@@ -15,14 +15,15 @@
# limitations under the License.
from synapse.storage.data_stores.main.pusher import PusherWorkerStore
from synapse.storage.database import Database
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedPusherStore, self).__init__(db_conn, hs)
def __init__(self, database: Database, db_conn, hs):
super(SlavedPusherStore, self).__init__(database, db_conn, hs)
self._pushers_id_gen = SlavedIdTracker(
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
......
......@@ -15,6 +15,7 @@
# limitations under the License.
from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
from synapse.storage.database import Database
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
......@@ -29,14 +30,14 @@ from ._slaved_id_tracker import SlavedIdTracker
class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
def __init__(self, database: Database, db_conn, hs):
# We instantiate this first as the ReceiptsWorkerStore constructor
# needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = SlavedIdTracker(
db_conn, "receipts_linearized", "stream_id"
)
super(SlavedReceiptsStore, self).__init__(db_conn, hs)
super(SlavedReceiptsStore, self).__init__(database, db_conn, hs)
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
......
......@@ -14,14 +14,15 @@
# limitations under the License.
from synapse.storage.data_stores.main.room import RoomWorkerStore
from synapse.storage.database import Database
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
class RoomStore(RoomWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
super(RoomStore, self).__init__(db_conn, hs)
def __init__(self, database: Database, db_conn, hs):
super(RoomStore, self).__init__(database, db_conn, hs)
self._public_room_id_gen = SlavedIdTracker(
db_conn, "public_room_list_stream", "stream_id"
)
......
......@@ -37,7 +37,7 @@ class SQLBaseStore(object):
per data store (and not one per physical database).
"""
def __init__(self, db_conn, hs):
def __init__(self, database: Database, db_conn, hs):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = hs.database_engine
......
......@@ -20,6 +20,7 @@ import logging
import time
from synapse.api.constants import PresenceState
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
ChainedIdGenerator,
......@@ -111,7 +112,7 @@ class DataStore(
RelationsStore,
CacheInvalidationStore,
):
def __init__(self, db_conn, hs):
def __init__(self, database: Database, db_conn, hs):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = hs.database_engine
......@@ -169,7 +170,7 @@ class DataStore(
else:
self._cache_id_gen = None
super(DataStore, self).__init__(db_conn, hs)
super(DataStore, self).__init__(database, db_conn, hs)
self._presence_on_startup = self._get_active_presence(db_conn)
......
......@@ -22,6 +22,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import Database
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.caches.stream_change_cache import StreamChangeCache
......@@ -38,13 +39,13 @@ class AccountDataWorkerStore(SQLBaseStore):
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta
def __init__(self, db_conn, hs):
def __init__(self, database: Database, db_conn, hs):
account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max
)
super(AccountDataWorkerStore, self).__init__(db_conn, hs)
super(AccountDataWorkerStore, self).__init__(database, db_conn, hs)
@abc.abstractmethod
def get_max_account_data_stream_id(self):
......@@ -270,12 +271,12 @@ class AccountDataWorkerStore(SQLBaseStore):
class AccountDataStore(AccountDataWorkerStore):
def __init__(self, db_conn, hs):
def __init__(self, database: Database, db_conn, hs):
self._account_data_id_gen = StreamIdGenerator(
db_conn, "account_data_max_stream_id", "stream_id"
)
super(AccountDataStore, self).__init__(db_conn, hs)
super(AccountDataStore, self).__init__(database, db_conn, hs)
def get_max_account_data_stream_id(self):
"""Get the current max stream id for the private user data stream
......
......@@ -24,6 +24,7 @@ from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.database import Database
logger = logging.getLogger(__name__)
......@@ -48,13 +49,13 @@ def _make_exclusive_regex(services_cache):
class ApplicationServiceWorkerStore(SQLBaseStore):
def __init__(self, db_conn, hs):
def __init__(self, database: Database, db_conn, hs):
self.services_cache = load_appservices(
hs.hostname, hs.config.app_service_config_files
)
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs)
super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs)
def get_app_services(self):
return self.services_cache
......
......@@ -21,6 +21,7 @@ from twisted.internet import defer
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import Database
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.descriptors import Cache
......@@ -33,8 +34,8 @@ LAST_SEEN_GRANULARITY = 120 * 1000
class ClientIpBackgroundUpdateStore(SQLBaseStore):
def __init__(self, db_conn, hs):
super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs)
def __init__(self, database: Database, db_conn, hs):
super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.db.updates.register_background_index_update(
"user_ips_device_index",
......@@ -363,13 +364,13 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
class ClientIpStore(ClientIpBackgroundUpdateStore):
def __init__(self, db_conn, hs):
def __init__(self, database: Database, db_conn, hs):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
)
super(ClientIpStore, self).__init__(db_conn, hs)
super(ClientIpStore, self).__init__(database, db_conn, hs)
self.user_ips_max_age = hs.config.user_ips_max_age
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment