Skip to content
Snippets Groups Projects
Unverified Commit d38c73e9 authored by Richard van der Hoff's avatar Richard van der Hoff Committed by GitHub
Browse files

Skip waiting for full state if a StateFilter does not require it (#12498)

If `StateFilter` specifies a state set which we will have regardless of
state-syncing, then we may as well return it immediately.
parent 0fce474a
No related branches found
No related tags found
No related merge requests found
Preparation for faster-room-join work: return subsets of room state which we already have, immediately.
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2022 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -15,6 +16,7 @@ import logging ...@@ -15,6 +16,7 @@ import logging
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Awaitable, Awaitable,
Callable,
Collection, Collection,
Dict, Dict,
Iterable, Iterable,
...@@ -532,6 +534,44 @@ class StateFilter: ...@@ -532,6 +534,44 @@ class StateFilter:
new_all, new_excludes, new_wildcards, new_concrete_keys new_all, new_excludes, new_wildcards, new_concrete_keys
) )
def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool:
"""Check if we need to wait for full state to complete to calculate this state
If we have a state filter which is completely satisfied even with partial
state, then we don't need to await_full_state before we can return it.
Args:
is_mine_id: a callable which confirms if a given state_key matches a mxid
of a local user
"""
# TODO(faster_joins): it's not entirely clear that this is safe. In particular,
# there may be circumstances in which we return a piece of state that, once we
# resync the state, we discover is invalid. For example: if it turns out that
# the sender of a piece of state wasn't actually in the room, then clearly that
# state shouldn't have been returned.
# We should at least add some tests around this to see what happens.
# if we haven't requested membership events, then it depends on the value of
# 'include_others'
if EventTypes.Member not in self.types:
return self.include_others
# if we're looking for *all* membership events, then we have to wait
member_state_keys = self.types[EventTypes.Member]
if member_state_keys is None:
return True
# otherwise, consider whose membership we are looking for. If it's entirely
# local users, then we don't need to wait.
for state_key in member_state_keys:
if not is_mine_id(state_key):
# remote user
return True
# local users only
return False
_ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True) _ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True)
_ALL_NON_MEMBER_STATE_FILTER = StateFilter( _ALL_NON_MEMBER_STATE_FILTER = StateFilter(
...@@ -544,6 +584,7 @@ class StateGroupStorage: ...@@ -544,6 +584,7 @@ class StateGroupStorage:
"""High level interface to fetching state for event.""" """High level interface to fetching state for event."""
def __init__(self, hs: "HomeServer", stores: "Databases"): def __init__(self, hs: "HomeServer", stores: "Databases"):
self._is_mine_id = hs.is_mine_id
self.stores = stores self.stores = stores
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main) self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
...@@ -675,7 +716,13 @@ class StateGroupStorage: ...@@ -675,7 +716,13 @@ class StateGroupStorage:
RuntimeError if we don't have a state group for one or more of the events RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown) (ie they are outliers or unknown)
""" """
event_to_groups = await self.get_state_group_for_events(event_ids) await_full_state = True
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
await_full_state = False
event_to_groups = await self.get_state_group_for_events(
event_ids, await_full_state=await_full_state
)
groups = set(event_to_groups.values()) groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups( group_to_state = await self.stores.state._get_state_for_groups(
...@@ -699,7 +746,9 @@ class StateGroupStorage: ...@@ -699,7 +746,9 @@ class StateGroupStorage:
return {event: event_to_state[event] for event in event_ids} return {event: event_to_state[event] for event in event_ids}
async def get_state_ids_for_events( async def get_state_ids_for_events(
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None self,
event_ids: Collection[str],
state_filter: Optional[StateFilter] = None,
) -> Dict[str, StateMap[str]]: ) -> Dict[str, StateMap[str]]:
""" """
Get the state dicts corresponding to a list of events, containing the event_ids Get the state dicts corresponding to a list of events, containing the event_ids
...@@ -716,7 +765,13 @@ class StateGroupStorage: ...@@ -716,7 +765,13 @@ class StateGroupStorage:
RuntimeError if we don't have a state group for one or more of the events RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown) (ie they are outliers or unknown)
""" """
event_to_groups = await self.get_state_group_for_events(event_ids) await_full_state = True
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
await_full_state = False
event_to_groups = await self.get_state_group_for_events(
event_ids, await_full_state=await_full_state
)
groups = set(event_to_groups.values()) groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups( group_to_state = await self.stores.state._get_state_for_groups(
...@@ -802,7 +857,7 @@ class StateGroupStorage: ...@@ -802,7 +857,7 @@ class StateGroupStorage:
Args: Args:
event_ids: events to get state groups for event_ids: events to get state groups for
await_full_state: if true, will block if we do not yet have complete await_full_state: if true, will block if we do not yet have complete
state at this event. state at these events.
""" """
if await_full_state: if await_full_state:
await self._partial_state_events_tracker.await_full_state(event_ids) await self._partial_state_events_tracker.await_full_state(event_ids)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment