Skip to content
Snippets Groups Projects
Commit 5db03535 authored by Erik Johnston's avatar Erik Johnston
Browse files

Add StateGroupStorage interface

parent b7fe62b7
No related branches found
No related tags found
No related merge requests found
...@@ -30,6 +30,7 @@ stored in `synapse.storage.schema`. ...@@ -30,6 +30,7 @@ stored in `synapse.storage.schema`.
from synapse.storage.data_stores import DataStores from synapse.storage.data_stores import DataStores
from synapse.storage.data_stores.main import DataStore from synapse.storage.data_stores.main import DataStore
from synapse.storage.persist_events import EventsPersistenceStorage from synapse.storage.persist_events import EventsPersistenceStorage
from synapse.storage.state import StateGroupStorage
__all__ = ["DataStores", "DataStore"] __all__ = ["DataStores", "DataStore"]
...@@ -45,6 +46,7 @@ class Storage(object): ...@@ -45,6 +46,7 @@ class Storage(object):
self.main = stores.main self.main = stores.main
self.persistence = EventsPersistenceStorage(hs, stores) self.persistence = EventsPersistenceStorage(hs, stores)
self.state = StateGroupStorage(hs, stores)
def are_all_users_on_domain(txn, database_engine, domain): def are_all_users_on_domain(txn, database_engine, domain):
......
...@@ -550,7 +550,7 @@ class EventsPersistenceStorage(object): ...@@ -550,7 +550,7 @@ class EventsPersistenceStorage(object):
if missing_event_ids: if missing_event_ids:
# Now pull out the state groups for any missing events from DB # Now pull out the state groups for any missing events from DB
event_to_groups = yield self.state_store._get_state_group_for_events( event_to_groups = yield self.main_store._get_state_group_for_events(
missing_event_ids missing_event_ids
) )
event_id_to_state_group.update(event_to_groups) event_id_to_state_group.update(event_to_groups)
......
...@@ -19,6 +19,8 @@ from six import iteritems, itervalues ...@@ -19,6 +19,8 @@ from six import iteritems, itervalues
import attr import attr
from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -322,3 +324,233 @@ class StateFilter(object): ...@@ -322,3 +324,233 @@ class StateFilter(object):
) )
return member_filter, non_member_filter return member_filter, non_member_filter
class StateGroupStorage(object):
"""High level interface to fetching state for event.
"""
def __init__(self, hs, stores):
self.stores = stores
def get_state_group_delta(self, state_group):
"""Given a state group try to return a previous group and a delta between
the old and the new.
Returns:
(prev_group, delta_ids), where both may be None.
"""
return self.stores.main.get_state_group_delta(state_group)
@defer.inlineCallbacks
def get_state_groups_ids(self, _room_id, event_ids):
"""Get the event IDs of all the state for the state groups for the given events
Args:
_room_id (str): id of the room for these events
event_ids (iterable[str]): ids of the events
Returns:
Deferred[dict[int, dict[tuple[str, str], str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
if not event_ids:
return {}
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
group_to_state = yield self.stores.main._get_state_for_groups(groups)
return group_to_state
@defer.inlineCallbacks
def get_state_ids_for_group(self, state_group):
"""Get the event IDs of all the state in the given state group
Args:
state_group (int)
Returns:
Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
"""
group_to_state = yield self._get_state_for_groups((state_group,))
return group_to_state[state_group]
@defer.inlineCallbacks
def get_state_groups(self, room_id, event_ids):
""" Get the state groups for the given list of event_ids
Returns:
Deferred[dict[int, list[EventBase]]]:
dict of state_group_id -> list of state events.
"""
if not event_ids:
return {}
group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
state_event_map = yield self.stores.main.get_events(
[
ev_id
for group_ids in itervalues(group_to_ids)
for ev_id in itervalues(group_ids)
],
get_prev_content=False,
)
return {
group: [
state_event_map[v]
for v in itervalues(event_id_map)
if v in state_event_map
]
for group, event_id_map in iteritems(group_to_ids)
}
def _get_state_groups_from_groups(self, groups, state_filter):
"""Returns the state groups for a given set of groups, filtering on
types of state events.
Args:
groups(list[int]): list of state group IDs to query
state_filter (StateFilter): The state filter used to fetch state
from the database.
Returns:
Deferred[dict[int, dict[tuple[str, str], str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
return self.stores.main._get_state_groups_from_groups(groups, state_filter)
@defer.inlineCallbacks
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event.
Args:
event_ids (list[string])
state_filter (StateFilter): The state filter used to fetch state
from the database.
Returns:
deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
"""
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
group_to_state = yield self.stores.main._get_state_for_groups(
groups, state_filter
)
state_event_map = yield self.stores.main.get_events(
[ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
get_prev_content=False,
)
event_to_state = {
event_id: {
k: state_event_map[v]
for k, v in iteritems(group_to_state[group])
if v in state_event_map
}
for event_id, group in iteritems(event_to_groups)
}
return {event: event_to_state[event] for event in event_ids}
@defer.inlineCallbacks
def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
"""
Get the state dicts corresponding to a list of events, containing the event_ids
of the state events (as opposed to the events themselves)
Args:
event_ids(list(str)): events whose state should be returned
state_filter (StateFilter): The state filter used to fetch state
from the database.
Returns:
A deferred dict from event_id -> (type, state_key) -> event_id
"""
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
group_to_state = yield self.stores.main._get_state_for_groups(
groups, state_filter
)
event_to_state = {
event_id: group_to_state[group]
for event_id, group in iteritems(event_to_groups)
}
return {event: event_to_state[event] for event in event_ids}
@defer.inlineCallbacks
def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
"""
Get the state dict corresponding to a particular event
Args:
event_id(str): event whose state should be returned
state_filter (StateFilter): The state filter used to fetch state
from the database.
Returns:
A deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_for_events([event_id], state_filter)
return state_map[event_id]
@defer.inlineCallbacks
def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
"""
Get the state dict corresponding to a particular event
Args:
event_id(str): event whose state should be returned
state_filter (StateFilter): The state filter used to fetch state
from the database.
Returns:
A deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_ids_for_events([event_id], state_filter)
return state_map[event_id]
def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
Args:
groups (iterable[int]): list of state groups for which we want
to get the state.
state_filter (StateFilter): The state filter used to fetch state
from the database.
Returns:
Deferred[dict[int, dict[tuple[str, str], str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
return self.stores.main._get_state_for_groups(groups, state_filter)
def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids
):
"""Store a new set of state, returning a newly assigned state group.
Args:
event_id (str): The event ID for which the state was calculated
room_id (str)
prev_group (int|None): A previous state group for the room, optional.
delta_ids (dict|None): The delta between state at `prev_group` and
`current_state_ids`, if `prev_group` was given. Same format as
`current_state_ids`.
current_state_ids (dict): The state to store. Map of (type, state_key)
to event_id.
Returns:
Deferred[int]: The state group ID
"""
return self.stores.main.store_state_group(
event_id, room_id, prev_group, delta_ids, current_state_ids
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment