Skip to content
Snippets Groups Projects
Unverified Commit 8eb7cb2e authored by reivilibre's avatar reivilibre Committed by GitHub
Browse files

Make StateFilter frozen so we can hash it (#10816)

Also enables Mypy for related tests.
parent 14b8c047
No related branches found
No related tags found
No related merge requests found
Make `StateFilter` frozen so it is hashable.
...@@ -86,6 +86,7 @@ files = ...@@ -86,6 +86,7 @@ files =
tests/handlers/test_sync.py, tests/handlers/test_sync.py,
tests/rest/client/test_login.py, tests/rest/client/test_login.py,
tests/rest/client/test_auth.py, tests/rest/client/test_auth.py,
tests/storage/test_state.py,
tests/util/test_itertools.py, tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py tests/util/test_stream_change_cache.py
......
...@@ -25,12 +25,15 @@ from typing import ( ...@@ -25,12 +25,15 @@ from typing import (
) )
import attr import attr
from frozendict import frozendict
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.events import EventBase from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap from synapse.types import MutableStateMap, StateMap
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.databases import Databases from synapse.storage.databases import Databases
...@@ -40,7 +43,7 @@ logger = logging.getLogger(__name__) ...@@ -40,7 +43,7 @@ logger = logging.getLogger(__name__)
T = TypeVar("T") T = TypeVar("T")
@attr.s(slots=True) @attr.s(slots=True, frozen=True)
class StateFilter: class StateFilter:
"""A filter used when querying for state. """A filter used when querying for state.
...@@ -53,14 +56,19 @@ class StateFilter: ...@@ -53,14 +56,19 @@ class StateFilter:
appear in `types`. appear in `types`.
""" """
types = attr.ib(type=Dict[str, Optional[Set[str]]]) types = attr.ib(type="frozendict[str, Optional[FrozenSet[str]]]")
include_others = attr.ib(default=False, type=bool) include_others = attr.ib(default=False, type=bool)
def __attrs_post_init__(self): def __attrs_post_init__(self):
# If `include_others` is set we canonicalise the filter by removing # If `include_others` is set we canonicalise the filter by removing
# wildcards from the types dictionary # wildcards from the types dictionary
if self.include_others: if self.include_others:
self.types = {k: v for k, v in self.types.items() if v is not None} # this is needed to work around the fact that StateFilter is frozen
object.__setattr__(
self,
"types",
frozendict({k: v for k, v in self.types.items() if v is not None}),
)
@staticmethod @staticmethod
def all() -> "StateFilter": def all() -> "StateFilter":
...@@ -69,7 +77,7 @@ class StateFilter: ...@@ -69,7 +77,7 @@ class StateFilter:
Returns: Returns:
The new state filter. The new state filter.
""" """
return StateFilter(types={}, include_others=True) return StateFilter(types=frozendict(), include_others=True)
@staticmethod @staticmethod
def none() -> "StateFilter": def none() -> "StateFilter":
...@@ -78,7 +86,7 @@ class StateFilter: ...@@ -78,7 +86,7 @@ class StateFilter:
Returns: Returns:
The new state filter. The new state filter.
""" """
return StateFilter(types={}, include_others=False) return StateFilter(types=frozendict(), include_others=False)
@staticmethod @staticmethod
def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter": def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
...@@ -103,7 +111,12 @@ class StateFilter: ...@@ -103,7 +111,12 @@ class StateFilter:
type_dict.setdefault(typ, set()).add(s) # type: ignore type_dict.setdefault(typ, set()).add(s) # type: ignore
return StateFilter(types=type_dict) return StateFilter(
types=frozendict(
(k, frozenset(v) if v is not None else None)
for k, v in type_dict.items()
)
)
@staticmethod @staticmethod
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
...@@ -116,7 +129,10 @@ class StateFilter: ...@@ -116,7 +129,10 @@ class StateFilter:
Returns: Returns:
The new state filter The new state filter
""" """
return StateFilter(types={EventTypes.Member: set(members)}, include_others=True) return StateFilter(
types=frozendict({EventTypes.Member: frozenset(members)}),
include_others=True,
)
def return_expanded(self) -> "StateFilter": def return_expanded(self) -> "StateFilter":
"""Creates a new StateFilter where type wild cards have been removed """Creates a new StateFilter where type wild cards have been removed
...@@ -173,7 +189,7 @@ class StateFilter: ...@@ -173,7 +189,7 @@ class StateFilter:
# We want to return all non-members, but only particular # We want to return all non-members, but only particular
# memberships # memberships
return StateFilter( return StateFilter(
types={EventTypes.Member: self.types[EventTypes.Member]}, types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
include_others=True, include_others=True,
) )
...@@ -245,14 +261,15 @@ class StateFilter: ...@@ -245,14 +261,15 @@ class StateFilter:
return len(self.concrete_types()) return len(self.concrete_types())
def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]: def filter_state(self, state_dict: StateMap[T]) -> MutableStateMap[T]:
"""Returns the state filtered with by this StateFilter """Returns the state filtered with by this StateFilter.
Args: Args:
state: The state map to filter state: The state map to filter
Returns: Returns:
The filtered state map The filtered state map.
This is a copy, so it's safe to mutate.
""" """
if self.is_full(): if self.is_full():
return dict(state_dict) return dict(state_dict)
...@@ -324,14 +341,16 @@ class StateFilter: ...@@ -324,14 +341,16 @@ class StateFilter:
if state_keys is None: if state_keys is None:
member_filter = StateFilter.all() member_filter = StateFilter.all()
else: else:
member_filter = StateFilter({EventTypes.Member: state_keys}) member_filter = StateFilter(frozendict({EventTypes.Member: state_keys}))
elif self.include_others: elif self.include_others:
member_filter = StateFilter.all() member_filter = StateFilter.all()
else: else:
member_filter = StateFilter.none() member_filter = StateFilter.none()
non_member_filter = StateFilter( non_member_filter = StateFilter(
types={k: v for k, v in self.types.items() if k != EventTypes.Member}, types=frozendict(
{k: v for k, v in self.types.items() if k != EventTypes.Member}
),
include_others=self.include_others, include_others=self.include_others,
) )
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
import logging import logging
from frozendict import frozendict
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
...@@ -183,7 +185,9 @@ class StateStoreTestCase(HomeserverTestCase): ...@@ -183,7 +185,9 @@ class StateStoreTestCase(HomeserverTestCase):
self.storage.state.get_state_for_event( self.storage.state.get_state_for_event(
e5.event_id, e5.event_id,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {self.u_alice.to_string()}}, types=frozendict(
{EventTypes.Member: frozenset({self.u_alice.to_string()})}
),
include_others=True, include_others=True,
), ),
) )
...@@ -203,7 +207,8 @@ class StateStoreTestCase(HomeserverTestCase): ...@@ -203,7 +207,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.storage.state.get_state_for_event( self.storage.state.get_state_for_event(
e5.event_id, e5.event_id,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True types=frozendict({EventTypes.Member: frozenset()}),
include_others=True,
), ),
) )
) )
...@@ -228,7 +233,7 @@ class StateStoreTestCase(HomeserverTestCase): ...@@ -228,7 +233,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True types=frozendict({EventTypes.Member: frozenset()}), include_others=True
), ),
) )
...@@ -245,7 +250,7 @@ class StateStoreTestCase(HomeserverTestCase): ...@@ -245,7 +250,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True types=frozendict({EventTypes.Member: frozenset()}), include_others=True
), ),
) )
...@@ -258,7 +263,7 @@ class StateStoreTestCase(HomeserverTestCase): ...@@ -258,7 +263,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True types=frozendict({EventTypes.Member: None}), include_others=True
), ),
) )
...@@ -275,7 +280,7 @@ class StateStoreTestCase(HomeserverTestCase): ...@@ -275,7 +280,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True types=frozendict({EventTypes.Member: None}), include_others=True
), ),
) )
...@@ -295,7 +300,8 @@ class StateStoreTestCase(HomeserverTestCase): ...@@ -295,7 +300,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=True,
), ),
) )
...@@ -312,7 +318,8 @@ class StateStoreTestCase(HomeserverTestCase): ...@@ -312,7 +318,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=True,
), ),
) )
...@@ -325,7 +332,8 @@ class StateStoreTestCase(HomeserverTestCase): ...@@ -325,7 +332,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=False,
), ),
) )
...@@ -375,7 +383,7 @@ class StateStoreTestCase(HomeserverTestCase): ...@@ -375,7 +383,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True types=frozendict({EventTypes.Member: frozenset()}), include_others=True
), ),
) )
...@@ -387,7 +395,7 @@ class StateStoreTestCase(HomeserverTestCase): ...@@ -387,7 +395,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True types=frozendict({EventTypes.Member: frozenset()}), include_others=True
), ),
) )
...@@ -400,7 +408,7 @@ class StateStoreTestCase(HomeserverTestCase): ...@@ -400,7 +408,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True types=frozendict({EventTypes.Member: None}), include_others=True
), ),
) )
...@@ -411,7 +419,7 @@ class StateStoreTestCase(HomeserverTestCase): ...@@ -411,7 +419,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True types=frozendict({EventTypes.Member: None}), include_others=True
), ),
) )
...@@ -430,7 +438,8 @@ class StateStoreTestCase(HomeserverTestCase): ...@@ -430,7 +438,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=True,
), ),
) )
...@@ -441,7 +450,8 @@ class StateStoreTestCase(HomeserverTestCase): ...@@ -441,7 +450,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=True,
), ),
) )
...@@ -454,7 +464,8 @@ class StateStoreTestCase(HomeserverTestCase): ...@@ -454,7 +464,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache, self.state_datastore._state_group_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=False,
), ),
) )
...@@ -465,7 +476,8 @@ class StateStoreTestCase(HomeserverTestCase): ...@@ -465,7 +476,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache, self.state_datastore._state_group_members_cache,
group, group,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
include_others=False,
), ),
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment