Newer
Older
Matthew Hodgson
committed
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import (
Any,
Collection,
Dict,
Generator,
Iterable,
Iterator,
List,
Optional,
Set,
Tuple,
cast,
)
from synapse.api.auth.internal import InternalAuth
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, make_event_from_dict
Richard van der Hoff
committed
from synapse.events.snapshot import EventContext
from synapse.state import StateHandler, StateResolutionHandler, _make_state_cache_entry
from synapse.types import MutableStateMap, StateMap
from synapse.types.state import StateFilter
Quentin Gliech
committed
from synapse.util import Clock
from synapse.util.macaroons import MacaroonGenerator
from .utils import MockClock, default_config
Erik Johnston
committed
name: Optional[str] = None,
type: Optional[str] = None,
state_key: Optional[str] = None,
depth: int = 2,
event_id: Optional[str] = None,
prev_events: Optional[List[Tuple[str, dict]]] = None,
**kwargs: Any,
) -> EventBase:
global _next_event_id
if not event_id:
_next_event_id += 1
event_id = "$%s:test" % (_next_event_id,)
if not name:
if state_key is not None:
d = {
"event_id": event_id,
"type": type,
"sender": "@user_id:example.com",
"room_id": "!room_id:example.com",
"depth": depth,
"prev_events": prev_events or [],
}
if state_key is not None:
d["state_key"] = state_key
d.update(kwargs)
def __init__(self) -> None:
    """Create an empty in-memory stand-in for the state/main stores."""
    # event ID -> state group that event has been assigned to
    self._event_to_state_group: Dict[str, int] = {}
    # state group ID -> full state map stored for that group
    self._group_to_state: Dict[int, MutableStateMap[str]] = {}
    # event ID -> event object, filled in by register_events
    self._event_id_to_event: Dict[str, EventBase] = {}
    # Next ID handed out by store_state_group. It was previously never
    # initialised, so the first store_state_group call raised AttributeError.
    # Starting at 1 keeps every allocated group ID non-zero (truthy).
    self._next_group = 1
async def get_state_groups_ids(
    self, room_id: str, event_ids: Collection[str]
) -> Dict[int, MutableStateMap[str]]:
    """Return the state maps of the state groups the given events belong to.

    Events that have no registered state group are skipped. ``room_id`` is
    accepted for interface compatibility and is not used.

    Returns:
        A dict from state group ID to the stored state map of that group.
    """
    groups: Dict[int, MutableStateMap[str]] = {}
    for event_id in event_ids:
        group = self._event_to_state_group.get(event_id)
        # Compare against None explicitly so a (hypothetical) group 0 is not
        # mistaken for "no group".
        if group is not None:
            groups[group] = self._group_to_state[group]
    # The result was previously built but never returned.
    return groups
async def get_state_ids_for_group(
    self, state_group: int, state_filter: Optional[StateFilter] = None
) -> MutableStateMap[str]:
    """Look up the stored state map for ``state_group``.

    ``state_filter`` is accepted for signature compatibility but is not
    applied: the whole stored map (the same object, not a copy) is returned.
    An unknown group raises ``KeyError``.
    """
    stored_state = self._group_to_state[state_group]
    return stored_state
async def store_state_group(
    self,
    event_id: str,
    room_id: str,
    prev_group: Optional[int],
    delta_ids: Optional[StateMap[str]],
    current_state_ids: Optional[StateMap[str]],
) -> int:
    """Allocate a new state group and record its full state.

    Either ``current_state_ids`` supplies the full state directly, or it is
    ``None`` and the state is built by applying ``delta_ids`` on top of the
    stored state of ``prev_group``. ``event_id`` and ``room_id`` are accepted
    for interface compatibility and are not used.

    Returns:
        The ID of the newly allocated state group.
    """
    state_group = self._next_group
    self._next_group += 1

    if current_state_ids is None:
        assert prev_group is not None
        assert delta_ids is not None
        current_state_ids = dict(self._group_to_state[prev_group])
        current_state_ids.update(delta_ids)

    # Store a copy so later mutation of the caller's map cannot leak in.
    self._group_to_state[state_group] = dict(current_state_ids)

    # The allocated ID was previously never returned, despite `-> int`.
    return state_group
async def get_events(
    self, event_ids: Collection[str], **kwargs: Any
) -> Dict[str, EventBase]:
    """Fetch whichever of ``event_ids`` have been registered.

    Unknown IDs are silently left out of the result; any extra keyword
    arguments are ignored.
    """
    known = self._event_id_to_event
    found: Dict[str, EventBase] = {}
    for wanted_id in event_ids:
        if wanted_id in known:
            found[wanted_id] = known[wanted_id]
    return found
async def get_partial_state_events(
    self, event_ids: Collection[str]
) -> Dict[str, bool]:
    """Report every requested event as having full (non-partial) state."""
    return dict.fromkeys(event_ids, False)
async def get_state_group_delta(
    self, name: str
) -> Tuple[Optional[int], Optional[StateMap[str]]]:
    """Claim that no delta is stored for any state group.

    Always returns ``(None, None)``, i.e. no prev group and no delta,
    regardless of ``name``.
    """
    prev_group = delta = None
    return prev_group, delta
def register_events(self, events: Iterable[EventBase]) -> None:
    """Make ``events`` retrievable by event ID through ``get_events``.

    Registering an ID a second time replaces the earlier event.
    """
    self._event_id_to_event.update((ev.event_id, ev) for ev in events)
def register_event_context(self, event: EventBase, context: EventContext) -> None:
    """Record ``context``'s state group as the state group of ``event``.

    The context must already carry a state group; otherwise an
    ``AssertionError`` is raised.
    """
    group_id = context.state_group
    assert group_id is not None
    self._event_to_state_group[event.event_id] = group_id
def register_event_id_state_group(self, event_id: str, state_group: int) -> None:
    """Assign ``state_group`` to ``event_id`` without needing an event object.

    An existing assignment for the same ID is overwritten.
    """
    mapping = self._event_to_state_group
    mapping[event_id] = state_group
async def get_room_version_id(self, room_id: str) -> str:
    """Return the identifier of room version V1, whatever ``room_id`` is."""
    room_version = RoomVersions.V1
    return room_version.identifier
async def get_state_group_for_events(
    self, event_ids: Collection[str], await_full_state: bool = True
) -> Dict[str, int]:
    """Map each of ``event_ids`` to its registered state group.

    ``await_full_state`` is accepted for interface compatibility and ignored.
    An event with no registered state group raises ``KeyError``.
    """
    lookup = self._event_to_state_group
    return {eid: lookup[eid] for eid in event_ids}
async def get_state_for_groups(
    self, groups: Collection[int]
) -> Dict[int, MutableStateMap[str]]:
    """Return the stored state map (not a copy) for each of ``groups``.

    A group that was never stored raises ``KeyError``.
    """
    stored = self._group_to_state
    return {group_id: stored[group_id] for group_id in groups}
def __init__(self, **kwargs: Any) -> None:
    """Build the object from its keyword arguments.

    The keyword arguments are handed, as a single mapping, to the base class
    initialiser. NOTE(review): the base class is not visible here; it is
    presumably ``dict``, since ``Graph`` unpacks instances with ``**`` —
    confirm.
    """
    super().__init__(kwargs)
def __init__(self, nodes: Dict[str, DictObj], edges: Dict[str, List[str]]):
    """Build the events of a small test DAG.

    Args:
        nodes: map from event ID to the fields passed to ``create_event``
            for that event (type, state_key, content, depth, ...).
        edges: map from event ID to the IDs of its prev events. Nodes absent
            from ``edges`` get no prev events.
    """
    events: Dict[str, EventBase] = {}
    # Every node starts as a candidate leaf; any node that another node lists
    # as a prev event is removed below. (Starting from an empty set made the
    # difference_update a no-op, so `_leaves` was always empty.)
    clobbered: Set[str] = set(nodes)

    for event_id, fields in nodes.items():
        refs = edges.get(event_id)
        if refs:
            clobbered.difference_update(refs)
            prev_events: List[Tuple[str, dict]] = [(r, {}) for r in refs]
        else:
            prev_events = []

        # Pass the node's ID, prev events and fields through; previously the
        # call had no arguments, so all of them were discarded.
        events[event_id] = create_event(
            event_id=event_id, prev_events=prev_events, **fields
        )

    self._leaves = clobbered
    # Ordered by depth so walk() yields ancestors before descendants.
    self._events = sorted(events.values(), key=lambda e: e.depth)
def walk(self) -> Iterator[EventBase]:
    """Iterate over the graph's events in the depth order fixed at build time."""
    ordered_events = self._events
    return iter(ordered_events)
self.dummy_store = _DummyStore()
storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store)
"get_auth",
"get_state_handler",
"get_clock",
"get_state_resolution_handler",
"get_account_validity_handler",
Quentin Gliech
committed
"get_macaroon_generator",
"get_instance_name",
"get_simple_http_client",
Quentin Gliech
committed
clock = cast(Clock, MockClock())
Amber Brown
committed
hs.config = default_config("tesths", True)
hs.get_datastores.return_value = Mock(main=self.dummy_store)
hs.get_state_handler.return_value = None
Quentin Gliech
committed
hs.get_clock.return_value = clock
hs.get_macaroon_generator.return_value = MacaroonGenerator(
clock, "tesths", b"verysecret"
)
hs.get_auth.return_value = InternalAuth(hs)
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
hs.get_storage_controllers.return_value = storage_controllers
def test_branch_no_conflict(self) -> Generator[defer.Deferred, Any, None]:
graph = Graph(
nodes={
"START": DictObj(
type=EventTypes.Create, state_key="", content={}, depth=1
"A": DictObj(type=EventTypes.Message, depth=2),
"B": DictObj(type=EventTypes.Message, depth=3),
"C": DictObj(type=EventTypes.Name, state_key="", depth=3),
"D": DictObj(type=EventTypes.Message, depth=4),
self.dummy_store.register_events(graph.walk())
context_store: Dict[str, EventContext] = {}
context = yield defer.ensureDeferred(
self.state.compute_event_context(event)
)
self.dummy_store.register_event_context(event, context)
context_store[event.event_id] = context
Richard van der Hoff
committed
ctx_c = context_store["C"]
ctx_d = context_store["D"]
prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
Richard van der Hoff
committed
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
def test_branch_basic_conflict(
self,
) -> Generator["defer.Deferred[object]", Any, None]:
graph = Graph(
nodes={
"START": DictObj(
type=EventTypes.Create,
state_key="",
content={"creator": "@user_id:example.com"},
depth=1,
),
"A": DictObj(
type=EventTypes.Member,
state_key="@user_id:example.com",
content={"membership": Membership.JOIN},
membership=Membership.JOIN,
depth=2,
),
"B": DictObj(type=EventTypes.Name, state_key="", depth=3),
"C": DictObj(type=EventTypes.Name, state_key="", depth=4),
"D": DictObj(type=EventTypes.Message, depth=5),
self.dummy_store.register_events(graph.walk())
context_store: Dict[str, EventContext] = {}
context = yield defer.ensureDeferred(
self.state.compute_event_context(event)
)
self.dummy_store.register_event_context(event, context)
context_store[event.event_id] = context
Richard van der Hoff
committed
# C ends up winning the resolution between B and C
ctx_c = context_store["C"]
ctx_d = context_store["D"]
prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))
Richard van der Hoff
committed
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
def test_branch_have_banned_conflict(
self,
) -> Generator["defer.Deferred[object]", Any, None]:
graph = Graph(
nodes={
"START": DictObj(
type=EventTypes.Create,
state_key="",
content={"creator": "@user_id:example.com"},
depth=1,
),
"A": DictObj(
type=EventTypes.Member,
state_key="@user_id:example.com",
content={"membership": Membership.JOIN},
membership=Membership.JOIN,
depth=2,
),
"C": DictObj(
type=EventTypes.Member,
state_key="@user_id_2:example.com",
content={"membership": Membership.BAN},
membership=Membership.BAN,
depth=4,
),
"D": DictObj(
type=EventTypes.Name,
state_key="",
depth=4,
sender="@user_id_2:example.com",
),
edges={"A": ["START"], "B": ["A"], "C": ["B"], "D": ["B"], "E": ["C", "D"]},
self.dummy_store.register_events(graph.walk())
context_store: Dict[str, EventContext] = {}
context = yield defer.ensureDeferred(
self.state.compute_event_context(event)
)
self.dummy_store.register_event_context(event, context)
context_store[event.event_id] = context
Richard van der Hoff
committed
# C ends up winning the resolution between C and D because bans win over other
# changes
ctx_c = context_store["C"]
ctx_e = context_store["E"]
prev_state_ids = yield defer.ensureDeferred(ctx_e.get_prev_state_ids())
self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values()))
Richard van der Hoff
committed
self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
@defer.inlineCallbacks
def test_branch_have_perms_conflict(
self,
) -> Generator["defer.Deferred[object]", Any, None]:
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
userid1 = "@user_id:example.com"
userid2 = "@user_id2:example.com"
nodes = {
"A1": DictObj(
type=EventTypes.Create,
state_key="",
content={"creator": userid1},
depth=1,
),
"A2": DictObj(
type=EventTypes.Member,
state_key=userid1,
content={"membership": Membership.JOIN},
membership=Membership.JOIN,
),
"A3": DictObj(
type=EventTypes.Member,
state_key=userid2,
content={"membership": Membership.JOIN},
membership=Membership.JOIN,
),
"A4": DictObj(
type=EventTypes.PowerLevels,
state_key="",
content={
"events": {"m.room.name": 50},
},
),
"B": DictObj(
type=EventTypes.PowerLevels,
state_key="",
),
"C": DictObj(type=EventTypes.Name, state_key="", sender=userid2),
"D": DictObj(type=EventTypes.Message),
}
edges = {
"A2": ["A1"],
"A3": ["A2"],
"A4": ["A3"],
"A5": ["A4"],
"B": ["A5"],
"C": ["A5"],
}
self._add_depths(nodes, edges)
graph = Graph(nodes, edges)
self.dummy_store.register_events(graph.walk())
context_store: Dict[str, EventContext] = {}
for event in graph.walk():
context = yield defer.ensureDeferred(
self.state.compute_event_context(event)
)
self.dummy_store.register_event_context(event, context)
context_store[event.event_id] = context
Richard van der Hoff
committed
# B ends up winning the resolution between B and C because power levels
# win over other changes.
ctx_b = context_store["B"]
ctx_d = context_store["D"]
prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values()))
Richard van der Hoff
committed
self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
def _add_depths(
    self, nodes: Dict[str, "DictObj"], edges: Dict[str, List[str]]
) -> None:
    """Fill in a ``depth`` for every node that does not already have one.

    A node's depth is one more than the largest depth among its prev events
    (taken from ``edges``); a node whose prev-event list is empty gets depth
    1. Nodes that already carry a ``depth`` are left untouched, which is how
    roots such as the create event anchor the computation. Computed depths
    are written back into the node, so each node is resolved only once.

    Args:
        nodes: map from event ID to event fields; updated in place.
        edges: map from event ID to the IDs of its prev events.

    Raises:
        KeyError: if a node without a ``depth`` has no entry in ``edges``, or
            an edge refers to an event that is not in ``nodes``.
        RecursionError: if the edges contain a cycle.
    """

    def _get_depth(ev: str) -> int:
        node = nodes[ev]
        # Only consult edges when the depth is unknown: roots carry an
        # explicit depth and may have no entry in `edges` at all.
        if "depth" not in node:
            prevs = edges[ev]
            node["depth"] = max((_get_depth(prev) for prev in prevs), default=0) + 1
        return node["depth"]

    for n in nodes:
        _get_depth(n)
Erik Johnston
committed
@defer.inlineCallbacks
def test_annotate_with_old_message(
self,
) -> Generator["defer.Deferred[object]", Any, None]:
event = create_event(type="test_message", name="event")
Erik Johnston
committed
create_event(type="test1", state_key="1"),
create_event(type="test1", state_key="2"),
create_event(type="test2", state_key=""),
Erik Johnston
committed
context = yield defer.ensureDeferred(
self.state.compute_event_context(
event,
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in old_state
},
Sean Quah
committed
partial_state=False,
Erik Johnston
committed
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
Richard van der Hoff
committed
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
current_state_ids: StateMap[str]
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
Richard van der Hoff
committed
self.assertCountEqual(
(e.event_id for e in old_state), current_state_ids.values()
Mark Haines
committed
)
Erik Johnston
committed
Richard van der Hoff
committed
self.assertIsNotNone(context.state_group_before_event)
self.assertEqual(context.state_group_before_event, context.state_group)
Erik Johnston
committed
def test_annotate_with_old_state(
self,
) -> Generator["defer.Deferred[object]", Any, None]:
event = create_event(type="state", state_key="", name="event")
create_event(type="test1", state_key="1"),
create_event(type="test1", state_key="2"),
create_event(type="test2", state_key=""),
context = yield defer.ensureDeferred(
self.state.compute_event_context(
event,
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in old_state
},
Sean Quah
committed
partial_state=False,
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
Richard van der Hoff
committed
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
current_state_ids: StateMap[str]
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
Richard van der Hoff
committed
self.assertCountEqual(
(e.event_id for e in old_state + [event]), current_state_ids.values()
Shay
committed
assert context.state_group_before_event is not None
assert context.state_group is not None
self.assertEqual(
context.state_group_deltas.get(
(context.state_group_before_event, context.state_group)
),
{(event.type, event.state_key): event.event_id},
)
Richard van der Hoff
committed
self.assertNotEqual(context.state_group_before_event, context.state_group)
def test_trivial_annotate_message(
self,
) -> Generator["defer.Deferred[object]", Any, None]:
prev_event_id = "prev_event_id"
event = create_event(
type="test_message", name="event2", prev_events=[(prev_event_id, {})]
create_event(type="test1", state_key="1"),
create_event(type="test1", state_key="2"),
create_event(type="test2", state_key=""),
group_name = yield defer.ensureDeferred(
self.dummy_store.store_state_group(
prev_event_id,
event.room_id,
None,
None,
{(e.type, e.state_key): e.event_id for e in old_state},
)
self.dummy_store.register_event_id_state_group(prev_event_id, group_name)
context = yield defer.ensureDeferred(self.state.compute_event_context(event))
current_state_ids: StateMap[str]
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
{e.event_id for e in old_state}, set(current_state_ids.values())
Mark Haines
committed
self.assertEqual(group_name, context.state_group)
def test_trivial_annotate_state(
self,
) -> Generator["defer.Deferred[object]", Any, None]:
prev_event_id = "prev_event_id"
event = create_event(
type="state", state_key="", name="event2", prev_events=[(prev_event_id, {})]
create_event(type="test1", state_key="1"),
create_event(type="test1", state_key="2"),
create_event(type="test2", state_key=""),
group_name = yield defer.ensureDeferred(
self.dummy_store.store_state_group(
prev_event_id,
event.room_id,
None,
None,
{(e.type, e.state_key): e.event_id for e in old_state},
)
self.dummy_store.register_event_id_state_group(prev_event_id, group_name)
Erik Johnston
committed
context = yield defer.ensureDeferred(self.state.compute_event_context(event))
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values()))
self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks
def test_resolve_message_conflict(
self,
) -> Generator["defer.Deferred[Any]", Any, None]:
prev_event_id1 = "event_id1"
prev_event_id2 = "event_id2"
event = create_event(
prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
)
create_event(type="test1", state_key="1"),
create_event(type="test1", state_key="2"),
create_event(type="test2", state_key=""),
]
create_event(type="test1", state_key="1"),
create_event(type="test3", state_key="2"),
create_event(type="test4", state_key=""),
self.dummy_store.register_events(old_state_1)
self.dummy_store.register_events(old_state_2)
context = yield self._get_context(
current_state_ids: StateMap[str]
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks
def test_resolve_state_conflict(
self,
) -> Generator["defer.Deferred[Any]", Any, None]:
prev_event_id1 = "event_id1"
prev_event_id2 = "event_id2"
event = create_event(
prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
)
create_event(type="test1", state_key="1"),
create_event(type="test1", state_key="2"),
create_event(type="test2", state_key=""),
]
create_event(type="test1", state_key="1"),
create_event(type="test3", state_key="2"),
create_event(type="test4", state_key=""),
store.register_events(old_state_1)
store.register_events(old_state_2)
self.dummy_store.get_events = store.get_events # type: ignore[method-assign]
context = yield self._get_context(
current_state_ids: StateMap[str]
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertIsNotNone(context.state_group)
def test_standard_depth_conflict(
self,
) -> Generator["defer.Deferred[Any]", Any, None]:
prev_event_id1 = "event_id1"
prev_event_id2 = "event_id2"
event = create_event(
prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
)
member_event = create_event(
type=EventTypes.Member,
state_key="@user_id:example.com",
type=EventTypes.PowerLevels,
state_key="",
content={"users": {"@foo:bar": "100", "@user_id:example.com": "100"}},
type=EventTypes.Create, state_key="", content={"creator": "@foo:bar"}
member_event,
create_event(type="test1", state_key="1", depth=1),
]
old_state_2 = [
member_event,
create_event(type="test1", state_key="1", depth=2),
]
store.register_events(old_state_1)
store.register_events(old_state_2)
self.dummy_store.get_events = store.get_events # type: ignore[method-assign]
context = yield self._get_context(
current_state_ids: StateMap[str]
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
# Reverse the depth to make sure we are actually using the depths
# during state resolution.
old_state_1 = [
member_event,
create_event(type="test1", state_key="1", depth=2),
]
old_state_2 = [
member_event,
create_event(type="test1", state_key="1", depth=1),
]
store.register_events(old_state_1)
store.register_events(old_state_2)
context = yield self._get_context(
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])
self,
event: EventBase,
prev_event_id_1: str,
old_state_1: Collection[EventBase],
prev_event_id_2: str,
old_state_2: Collection[EventBase],
) -> Generator["defer.Deferred[object]", Any, EventContext]:
sg1: int
sg1 = yield defer.ensureDeferred(
self.dummy_store.store_state_group(
prev_event_id_1,
event.room_id,
None,
None,
{(e.type, e.state_key): e.event_id for e in old_state_1},
)
self.dummy_store.register_event_id_state_group(prev_event_id_1, sg1)
Mark Haines
committed
sg2 = yield defer.ensureDeferred(
self.dummy_store.store_state_group(
prev_event_id_2,
event.room_id,
None,
None,
{(e.type, e.state_key): e.event_id for e in old_state_2},
)
self.dummy_store.register_event_id_state_group(prev_event_id_2, sg2)
result = yield defer.ensureDeferred(self.state.compute_event_context(event))
return result
def test_make_state_cache_entry(self) -> None:
    """Test that calculating a prev_group and delta is correct."""
    # Stray integer-literal statements (pasted line numbers) previously sat
    # here, which stopped the string above from being the docstring.
    new_state = {
        ("a", ""): "E",
        ("b", ""): "E",
        ("c", ""): "E",
        ("d", ""): "E",
    }

    # old_state_1 has fewer differences to new_state than old_state_2, but
    # the delta involves deleting a key, which isn't allowed in the deltas,
    # so we should pick old_state_2 as the prev_group.

    # `old_state_1` has two differences: `a` and `e`
    old_state_1 = {
        ("a", ""): "F",
        ("b", ""): "E",
        ("c", ""): "E",
        ("d", ""): "E",
        ("e", ""): "E",
    }

    # `old_state_2` has three differences: `a`, `c` and `d`
    old_state_2 = {
        ("a", ""): "F",
        ("b", ""): "E",
        ("c", ""): "F",
        ("d", ""): "F",
    }

    entry = _make_state_cache_entry(new_state, {1: old_state_1, 2: old_state_2})

    self.assertEqual(entry.prev_group, 2)

    # There are three changes from `old_state_2` to `new_state`
    self.assertEqual(
        entry.delta_ids, {("a", ""): "E", ("c", ""): "E", ("d", ""): "E"}
    )