Skip to content
Snippets Groups Projects
Unverified Commit eb609c65 authored by reivilibre's avatar reivilibre Committed by GitHub
Browse files

Fix bug in `StateFilter.return_expanded()` and add some tests. (#12016)

parent 31a298fe
No related branches found
No related tags found
No related merge requests found
Fix bug in `StateFilter.return_expanded()` and add some tests.
\ No newline at end of file
...@@ -204,13 +204,16 @@ class StateFilter: ...@@ -204,13 +204,16 @@ class StateFilter:
if get_all_members: if get_all_members:
# We want to return everything. # We want to return everything.
return StateFilter.all() return StateFilter.all()
else: elif EventTypes.Member in self.types:
# We want to return all non-members, but only particular # We want to return all non-members, but only particular
# memberships # memberships
return StateFilter( return StateFilter(
types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}), types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
include_others=True, include_others=True,
) )
else:
# We want to return all non-members
return _ALL_NON_MEMBER_STATE_FILTER
def make_sql_filter_clause(self) -> Tuple[str, List[str]]: def make_sql_filter_clause(self) -> Tuple[str, List[str]]:
"""Converts the filter to an SQL clause. """Converts the filter to an SQL clause.
...@@ -528,6 +531,9 @@ class StateFilter: ...@@ -528,6 +531,9 @@ class StateFilter:
_ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True) _ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True)
_ALL_NON_MEMBER_STATE_FILTER = StateFilter(
types=frozendict({EventTypes.Member: frozenset()}), include_others=True
)
_NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False) _NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False)
......
...@@ -992,3 +992,112 @@ class StateFilterDifferenceTestCase(TestCase): ...@@ -992,3 +992,112 @@ class StateFilterDifferenceTestCase(TestCase):
StateFilter.none(), StateFilter.none(),
StateFilter.all(), StateFilter.all(),
) )
class StateFilterTestCase(TestCase):
def test_return_expanded(self):
"""
Tests the behaviour of the return_expanded() function that expands
StateFilters to include more state types (for the sake of cache hit rate).
"""
self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all())
self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none())
# Concrete-only state filters stay the same
# (Case: mixed filter)
self.assertEqual(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
"some.other.state.type": {""},
},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
"some.other.state.type": {""},
},
include_others=False,
),
)
# Concrete-only state filters stay the same
# (Case: non-member-only filter)
self.assertEqual(
StateFilter.freeze(
{"some.other.state.type": {""}}, include_others=False
).return_expanded(),
StateFilter.freeze({"some.other.state.type": {""}}, include_others=False),
)
# Concrete-only state filters stay the same
# (Case: member-only filter)
self.assertEqual(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
},
include_others=False,
),
)
# Wildcard member-only state filters stay the same
self.assertEqual(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)
# If there is a wildcard in the non-member portion of the filter,
# it's expanded to include ALL non-member events.
# (Case: mixed filter)
self.assertEqual(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
"some.other.state.type": None,
},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{EventTypes.Member: {"@wombat:test", "@alicia:test"}},
include_others=True,
),
)
# If there is a wildcard in the non-member portion of the filter,
# it's expanded to include ALL non-member events.
# (Case: non-member-only filter)
self.assertEqual(
StateFilter.freeze(
{
"some.other.state.type": None,
},
include_others=False,
).return_expanded(),
StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
)
self.assertEqual(
StateFilter.freeze(
{
"some.other.state.type": None,
"yet.another.state.type": {"wombat"},
},
include_others=False,
).return_expanded(),
StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment