Skip to content
Snippets Groups Projects
Commit 7b079a26 authored by Mark Haines's avatar Mark Haines
Browse files

Remove get_state_for_room function from federation handler

parent bddc1d9f
No related branches found
No related tags found
No related merge requests found
...@@ -84,12 +84,6 @@ class FederationHandler(BaseHandler): ...@@ -84,12 +84,6 @@ class FederationHandler(BaseHandler):
yield self.replication_layer.send_pdu(pdu) yield self.replication_layer.send_pdu(pdu)
@log_function
def get_state_for_room(self, destination, room_id):
return self.replication_layer.get_state_for_context(
destination, room_id
)
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def on_receive_pdu(self, pdu, backfilled): def on_receive_pdu(self, pdu, backfilled):
...@@ -139,7 +133,7 @@ class FederationHandler(BaseHandler): ...@@ -139,7 +133,7 @@ class FederationHandler(BaseHandler):
yield self.hs.get_handlers().room_member_handler.change_membership( yield self.hs.get_handlers().room_member_handler.change_membership(
new_event, new_event,
True do_auth=True
) )
else: else:
...@@ -151,8 +145,8 @@ class FederationHandler(BaseHandler): ...@@ -151,8 +145,8 @@ class FederationHandler(BaseHandler):
if not room: if not room:
# Huh, let's try and get the current state # Huh, let's try and get the current state
try: try:
yield self.get_state_for_room( yield self.replication_layer.get_state_for_context(
event.origin, event.room_id origin, event.room_id
) )
hosts = yield self.store.get_joined_hosts_for_room( hosts = yield self.store.get_joined_hosts_for_room(
...@@ -161,9 +155,9 @@ class FederationHandler(BaseHandler): ...@@ -161,9 +155,9 @@ class FederationHandler(BaseHandler):
if self.hs.hostname in hosts: if self.hs.hostname in hosts:
try: try:
yield self.store.store_room( yield self.store.store_room(
event.room_id, room_id=event.room_id,
"", room_creator_user_id="",
is_public=False is_public=False,
) )
except: except:
pass pass
...@@ -209,7 +203,9 @@ class FederationHandler(BaseHandler): ...@@ -209,7 +203,9 @@ class FederationHandler(BaseHandler):
# First get current state to see if we are already joined. # First get current state to see if we are already joined.
try: try:
yield self.get_state_for_room(target_host, room_id) yield self.replication_layer.get_state_for_context(
target_host, room_id
)
hosts = yield self.store.get_joined_hosts_for_room(room_id) hosts = yield self.store.get_joined_hosts_for_room(room_id)
if self.hs.hostname in hosts: if self.hs.hostname in hosts:
...@@ -239,8 +235,8 @@ class FederationHandler(BaseHandler): ...@@ -239,8 +235,8 @@ class FederationHandler(BaseHandler):
try: try:
yield self.store.store_room( yield self.store.store_room(
room_id, room_id=room_id,
"", room_creator_user_id="",
is_public=False is_public=False
) )
except: except:
......
...@@ -28,6 +28,8 @@ from mock import NonCallableMock, ANY ...@@ -28,6 +28,8 @@ from mock import NonCallableMock, ANY
import logging import logging
from ..utils import get_mock_call_args
logging.getLogger().addHandler(logging.NullHandler()) logging.getLogger().addHandler(logging.NullHandler())
...@@ -99,9 +101,13 @@ class FederationTestCase(unittest.TestCase): ...@@ -99,9 +101,13 @@ class FederationTestCase(unittest.TestCase):
mem_handler = self.handlers.room_member_handler mem_handler = self.handlers.room_member_handler
self.assertEquals(1, mem_handler.change_membership.call_count) self.assertEquals(1, mem_handler.change_membership.call_count)
self.assertEquals(True, mem_handler.change_membership.call_args[0][1]) call_args = get_mock_call_args(
lambda event, do_auth: None,
mem_handler.change_membership
)
self.assertEquals(True, call_args["do_auth"])
new_event = mem_handler.change_membership.call_args[0][0] new_event = call_args["event"]
self.assertEquals(RoomMemberEvent.TYPE, new_event.type) self.assertEquals(RoomMemberEvent.TYPE, new_event.type)
self.assertEquals(room_id, new_event.room_id) self.assertEquals(room_id, new_event.room_id)
self.assertEquals(user_id, new_event.state_key) self.assertEquals(user_id, new_event.state_key)
......
...@@ -28,6 +28,16 @@ from mock import patch, Mock ...@@ -28,6 +28,16 @@ from mock import patch, Mock
import json import json
import urlparse import urlparse
from inspect import getcallargs
def get_mock_call_args(pattern_func, mock_func):
""" Return the arguments the mock function was called with interpreted
by the pattern functions argument list.
"""
invoked_args, invoked_kargs = mock_func.call_args
return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
# This is a mock /resource/ not an entire server # This is a mock /resource/ not an entire server
class MockHttpResource(HttpServer): class MockHttpResource(HttpServer):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment