Skip to content
Snippets Groups Projects
Commit c3eae8a8 authored by Mark Haines's avatar Mark Haines
Browse files

Construct the EventContext in the state handler rather than constructing one...

Construct the EventContext in the state handler rather than constructing one and then immediately calling state_handler.annotate_context_with_state
parent 3c7857e4
No related branches found
No related tags found
No related merge requests found
...@@ -20,8 +20,6 @@ from synapse.util.async import run_on_reactor ...@@ -20,8 +20,6 @@ from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.events.snapshot import EventContext
import logging import logging
...@@ -77,15 +75,10 @@ class BaseHandler(object): ...@@ -77,15 +75,10 @@ class BaseHandler(object):
state_handler = self.state_handler state_handler = self.state_handler
context = EventContext() context = yield state_handler.compute_event_context(builder)
ret = yield state_handler.annotate_context_with_state(
builder,
context,
)
prev_state = ret
if builder.is_state(): if builder.is_state():
builder.prev_state = prev_state builder.prev_state = context.prev_state_events
yield self.auth.add_auth_events(builder, context) yield self.auth.add_auth_events(builder, context)
......
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
from ._base import BaseHandler from ._base import BaseHandler
from synapse.events.snapshot import EventContext
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, FederationError, SynapseError, StoreError, AuthError, FederationError, SynapseError, StoreError,
...@@ -260,8 +259,7 @@ class FederationHandler(BaseHandler): ...@@ -260,8 +259,7 @@ class FederationHandler(BaseHandler):
event = pdu event = pdu
# FIXME (erikj): Not sure this actually works :/ # FIXME (erikj): Not sure this actually works :/
context = EventContext() context = yield self.state_handler.compute_event_context(event)
yield self.state_handler.annotate_context_with_state(event, context)
events.append((event, context)) events.append((event, context))
...@@ -555,8 +553,7 @@ class FederationHandler(BaseHandler): ...@@ -555,8 +553,7 @@ class FederationHandler(BaseHandler):
) )
) )
context = EventContext() context = yield self.state_handler.compute_event_context(event)
yield self.state_handler.annotate_context_with_state(event, context)
yield self.store.persist_event( yield self.store.persist_event(
event, event,
...@@ -688,11 +685,8 @@ class FederationHandler(BaseHandler): ...@@ -688,11 +685,8 @@ class FederationHandler(BaseHandler):
event.event_id, event.signatures, event.event_id, event.signatures,
) )
context = EventContext() context = yield self.state_handler.compute_event_context(
yield self.state_handler.annotate_context_with_state( event, old_state=state
event,
context,
old_state=state
) )
logger.debug( logger.debug(
......
...@@ -19,6 +19,7 @@ from twisted.internet import defer ...@@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.events.snapshot import EventContext
from collections import namedtuple from collections import namedtuple
...@@ -70,7 +71,7 @@ class StateHandler(object): ...@@ -70,7 +71,7 @@ class StateHandler(object):
defer.returnValue(res[1].values()) defer.returnValue(res[1].values())
@defer.inlineCallbacks @defer.inlineCallbacks
def annotate_context_with_state(self, event, context, old_state=None): def compute_event_context(self, event, old_state=None):
""" Fills out the context with the `current state` of the graph. The """ Fills out the context with the `current state` of the graph. The
`current state` here is defined to be the state of the event graph `current state` here is defined to be the state of the event graph
just before the event - i.e. it never includes `event` just before the event - i.e. it never includes `event`
...@@ -80,8 +81,11 @@ class StateHandler(object): ...@@ -80,8 +81,11 @@ class StateHandler(object):
Args: Args:
event (EventBase) event (EventBase)
context (EventContext) Returns:
an EventContext
""" """
context = EventContext()
yield run_on_reactor() yield run_on_reactor()
if old_state: if old_state:
...@@ -107,7 +111,8 @@ class StateHandler(object): ...@@ -107,7 +111,8 @@ class StateHandler(object):
if replaces.event_id != event.event_id: # Paranoia check if replaces.event_id != event.event_id: # Paranoia check
event.unsigned["replaces_state"] = replaces.event_id event.unsigned["replaces_state"] = replaces.event_id
defer.returnValue([]) context.prev_state_events = []
defer.returnValue(context)
if event.is_state(): if event.is_state():
ret = yield self.resolve_state_groups( ret = yield self.resolve_state_groups(
...@@ -145,7 +150,8 @@ class StateHandler(object): ...@@ -145,7 +150,8 @@ class StateHandler(object):
else: else:
context.auth_events = {} context.auth_events = {}
defer.returnValue(prev_state) context.prev_state_events = prev_state
defer.returnValue(context)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
......
...@@ -34,7 +34,7 @@ class FederationTestCase(unittest.TestCase): ...@@ -34,7 +34,7 @@ class FederationTestCase(unittest.TestCase):
self.mock_config.signing_key = [MockKey()] self.mock_config.signing_key = [MockKey()]
self.state_handler = NonCallableMock(spec_set=[ self.state_handler = NonCallableMock(spec_set=[
"annotate_context_with_state", "compute_event_context",
]) ])
self.auth = NonCallableMock(spec_set=[ self.auth = NonCallableMock(spec_set=[
...@@ -91,11 +91,12 @@ class FederationTestCase(unittest.TestCase): ...@@ -91,11 +91,12 @@ class FederationTestCase(unittest.TestCase):
self.datastore.get_room.return_value = defer.succeed(True) self.datastore.get_room.return_value = defer.succeed(True)
self.auth.check_host_in_room.return_value = defer.succeed(True) self.auth.check_host_in_room.return_value = defer.succeed(True)
def annotate(ev, context, old_state=None): def annotate(ev, old_state=None):
context = Mock()
context.current_state = {} context.current_state = {}
context.auth_events = {} context.auth_events = {}
return defer.succeed(False) return defer.succeed(context)
self.state_handler.annotate_context_with_state.side_effect = annotate self.state_handler.compute_event_context.side_effect = annotate
yield self.handlers.federation_handler.on_receive_pdu( yield self.handlers.federation_handler.on_receive_pdu(
"fo", pdu, False "fo", pdu, False
...@@ -109,15 +110,12 @@ class FederationTestCase(unittest.TestCase): ...@@ -109,15 +110,12 @@ class FederationTestCase(unittest.TestCase):
context=ANY, context=ANY,
) )
self.state_handler.annotate_context_with_state.assert_called_once_with( self.state_handler.compute_event_context.assert_called_once_with(
ANY, ANY, old_state=None,
ANY,
old_state=None,
) )
self.auth.check.assert_called_once_with(ANY, auth_events={}) self.auth.check.assert_called_once_with(ANY, auth_events={})
self.notifier.on_new_room_event.assert_called_once_with( self.notifier.on_new_room_event.assert_called_once_with(
ANY, ANY, extra_users=[]
extra_users=[]
) )
...@@ -60,7 +60,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase): ...@@ -60,7 +60,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
"check_host_in_room", "check_host_in_room",
]), ]),
state_handler=NonCallableMock(spec_set=[ state_handler=NonCallableMock(spec_set=[
"annotate_context_with_state", "compute_event_context",
"get_current_state", "get_current_state",
]), ]),
config=self.mock_config, config=self.mock_config,
...@@ -110,7 +110,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase): ...@@ -110,7 +110,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
defer.succeed([]) defer.succeed([])
) )
def annotate(_, ctx): def annotate(_):
ctx = Mock()
ctx.current_state = { ctx.current_state = {
(EventTypes.Member, "@alice:green"): self._create_member( (EventTypes.Member, "@alice:green"): self._create_member(
user_id="@alice:green", user_id="@alice:green",
...@@ -121,10 +122,11 @@ class RoomMemberHandlerTestCase(unittest.TestCase): ...@@ -121,10 +122,11 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
room_id=room_id, room_id=room_id,
), ),
} }
ctx.prev_state_events = []
return defer.succeed(True) return defer.succeed(ctx)
self.state_handler.annotate_context_with_state.side_effect = annotate self.state_handler.compute_event_context.side_effect = annotate
def add_auth(_, ctx): def add_auth(_, ctx):
ctx.auth_events = ctx.current_state[ ctx.auth_events = ctx.current_state[
...@@ -146,8 +148,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase): ...@@ -146,8 +148,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
yield room_handler.change_membership(event, context) yield room_handler.change_membership(event, context)
self.state_handler.annotate_context_with_state.assert_called_once_with( self.state_handler.compute_event_context.assert_called_once_with(
builder, context builder
) )
self.auth.add_auth_events.assert_called_once_with( self.auth.add_auth_events.assert_called_once_with(
...@@ -189,7 +191,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase): ...@@ -189,7 +191,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
defer.succeed([]) defer.succeed([])
) )
def annotate(_, ctx): def annotate(_):
ctx = Mock()
ctx.current_state = { ctx.current_state = {
(EventTypes.Member, "@bob:red"): self._create_member( (EventTypes.Member, "@bob:red"): self._create_member(
user_id="@bob:red", user_id="@bob:red",
...@@ -197,10 +200,11 @@ class RoomMemberHandlerTestCase(unittest.TestCase): ...@@ -197,10 +200,11 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
membership=Membership.INVITE membership=Membership.INVITE
), ),
} }
ctx.prev_state_events = []
return defer.succeed(True) return defer.succeed(ctx)
self.state_handler.annotate_context_with_state.side_effect = annotate self.state_handler.compute_event_context.side_effect = annotate
def add_auth(_, ctx): def add_auth(_, ctx):
ctx.auth_events = ctx.current_state[ ctx.auth_events = ctx.current_state[
...@@ -262,7 +266,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase): ...@@ -262,7 +266,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
defer.succeed([]) defer.succeed([])
) )
def annotate(_, ctx): def annotate(_):
ctx = Mock()
ctx.current_state = { ctx.current_state = {
(EventTypes.Member, "@bob:red"): self._create_member( (EventTypes.Member, "@bob:red"): self._create_member(
user_id="@bob:red", user_id="@bob:red",
...@@ -270,10 +275,11 @@ class RoomMemberHandlerTestCase(unittest.TestCase): ...@@ -270,10 +275,11 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
membership=Membership.JOIN membership=Membership.JOIN
), ),
} }
ctx.prev_state_events = []
return defer.succeed(True) return defer.succeed(ctx)
self.state_handler.annotate_context_with_state.side_effect = annotate self.state_handler.compute_event_context.side_effect = annotate
def add_auth(_, ctx): def add_auth(_, ctx):
ctx.auth_events = ctx.current_state[ ctx.auth_events = ctx.current_state[
......
...@@ -38,7 +38,6 @@ class StateTestCase(unittest.TestCase): ...@@ -38,7 +38,6 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_annotate_with_old_message(self): def test_annotate_with_old_message(self):
event = self.create_event(type="test_message", name="event") event = self.create_event(type="test_message", name="event")
context = Mock()
old_state = [ old_state = [
self.create_event(type="test1", state_key="1"), self.create_event(type="test1", state_key="1"),
...@@ -46,8 +45,8 @@ class StateTestCase(unittest.TestCase): ...@@ -46,8 +45,8 @@ class StateTestCase(unittest.TestCase):
self.create_event(type="test2", state_key=""), self.create_event(type="test2", state_key=""),
] ]
yield self.state.annotate_context_with_state( context = yield self.state.compute_event_context(
event, context, old_state=old_state event, old_state=old_state
) )
for k, v in context.current_state.items(): for k, v in context.current_state.items():
...@@ -64,7 +63,6 @@ class StateTestCase(unittest.TestCase): ...@@ -64,7 +63,6 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_annotate_with_old_state(self): def test_annotate_with_old_state(self):
event = self.create_event(type="state", state_key="", name="event") event = self.create_event(type="state", state_key="", name="event")
context = Mock()
old_state = [ old_state = [
self.create_event(type="test1", state_key="1"), self.create_event(type="test1", state_key="1"),
...@@ -72,8 +70,8 @@ class StateTestCase(unittest.TestCase): ...@@ -72,8 +70,8 @@ class StateTestCase(unittest.TestCase):
self.create_event(type="test2", state_key=""), self.create_event(type="test2", state_key=""),
] ]
yield self.state.annotate_context_with_state( context = yield self.state.compute_event_context(
event, context, old_state=old_state event, old_state=old_state
) )
for k, v in context.current_state.items(): for k, v in context.current_state.items():
...@@ -92,7 +90,6 @@ class StateTestCase(unittest.TestCase): ...@@ -92,7 +90,6 @@ class StateTestCase(unittest.TestCase):
def test_trivial_annotate_message(self): def test_trivial_annotate_message(self):
event = self.create_event(type="test_message", name="event") event = self.create_event(type="test_message", name="event")
event.prev_events = [] event.prev_events = []
context = Mock()
old_state = [ old_state = [
self.create_event(type="test1", state_key="1"), self.create_event(type="test1", state_key="1"),
...@@ -106,7 +103,7 @@ class StateTestCase(unittest.TestCase): ...@@ -106,7 +103,7 @@ class StateTestCase(unittest.TestCase):
group_name: old_state, group_name: old_state,
} }
yield self.state.annotate_context_with_state(event, context) context = yield self.state.compute_event_context(event)
for k, v in context.current_state.items(): for k, v in context.current_state.items():
type, state_key = k type, state_key = k
...@@ -124,7 +121,6 @@ class StateTestCase(unittest.TestCase): ...@@ -124,7 +121,6 @@ class StateTestCase(unittest.TestCase):
def test_trivial_annotate_state(self): def test_trivial_annotate_state(self):
event = self.create_event(type="state", state_key="", name="event") event = self.create_event(type="state", state_key="", name="event")
event.prev_events = [] event.prev_events = []
context = Mock()
old_state = [ old_state = [
self.create_event(type="test1", state_key="1"), self.create_event(type="test1", state_key="1"),
...@@ -138,7 +134,7 @@ class StateTestCase(unittest.TestCase): ...@@ -138,7 +134,7 @@ class StateTestCase(unittest.TestCase):
group_name: old_state, group_name: old_state,
} }
yield self.state.annotate_context_with_state(event, context) context = yield self.state.compute_event_context(event)
for k, v in context.current_state.items(): for k, v in context.current_state.items():
type, state_key = k type, state_key = k
...@@ -156,7 +152,6 @@ class StateTestCase(unittest.TestCase): ...@@ -156,7 +152,6 @@ class StateTestCase(unittest.TestCase):
def test_resolve_message_conflict(self): def test_resolve_message_conflict(self):
event = self.create_event(type="test_message", name="event") event = self.create_event(type="test_message", name="event")
event.prev_events = [] event.prev_events = []
context = Mock()
old_state_1 = [ old_state_1 = [
self.create_event(type="test1", state_key="1"), self.create_event(type="test1", state_key="1"),
...@@ -178,7 +173,7 @@ class StateTestCase(unittest.TestCase): ...@@ -178,7 +173,7 @@ class StateTestCase(unittest.TestCase):
group_name_2: old_state_2, group_name_2: old_state_2,
} }
yield self.state.annotate_context_with_state(event, context) context = yield self.state.compute_event_context(event)
self.assertEqual(len(context.current_state), 5) self.assertEqual(len(context.current_state), 5)
...@@ -188,7 +183,6 @@ class StateTestCase(unittest.TestCase): ...@@ -188,7 +183,6 @@ class StateTestCase(unittest.TestCase):
def test_resolve_state_conflict(self): def test_resolve_state_conflict(self):
event = self.create_event(type="test4", state_key="", name="event") event = self.create_event(type="test4", state_key="", name="event")
event.prev_events = [] event.prev_events = []
context = Mock()
old_state_1 = [ old_state_1 = [
self.create_event(type="test1", state_key="1"), self.create_event(type="test1", state_key="1"),
...@@ -210,7 +204,7 @@ class StateTestCase(unittest.TestCase): ...@@ -210,7 +204,7 @@ class StateTestCase(unittest.TestCase):
group_name_2: old_state_2, group_name_2: old_state_2,
} }
yield self.state.annotate_context_with_state(event, context) context = yield self.state.compute_event_context(event)
self.assertEqual(len(context.current_state), 5) self.assertEqual(len(context.current_state), 5)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment