Skip to content
Snippets Groups Projects
Unverified Commit 52ec6e9d authored by Amber Brown's avatar Amber Brown Committed by GitHub
Browse files

Port tests/ to Python 3 (#3808)

parent c5440b2c
No related branches found
No related tags found
No related merge requests found
tests/ is now ported to Python 3.
...@@ -471,6 +471,7 @@ class AuthTestCase(unittest.TestCase): ...@@ -471,6 +471,7 @@ class AuthTestCase(unittest.TestCase):
def test_reserved_threepid(self): def test_reserved_threepid(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 1 self.hs.config.max_mau_value = 1
self.store.get_monthly_active_count = lambda: defer.succeed(2)
threepid = {'medium': 'email', 'address': 'reserved@server.com'} threepid = {'medium': 'email', 'address': 'reserved@server.com'}
unknown_threepid = {'medium': 'email', 'address': 'unreserved@server.com'} unknown_threepid = {'medium': 'email', 'address': 'unreserved@server.com'}
self.hs.config.mau_limits_reserved_threepids = [threepid] self.hs.config.mau_limits_reserved_threepids = [threepid]
......
...@@ -47,7 +47,7 @@ class FrontendProxyTests(HomeserverTestCase): ...@@ -47,7 +47,7 @@ class FrontendProxyTests(HomeserverTestCase):
self.assertEqual(len(self.reactor.tcpServers), 1) self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1] site = self.reactor.tcpServers[0][1]
self.resource = ( self.resource = (
site.resource.children["_matrix"].children["client"].children["r0"] site.resource.children[b"_matrix"].children[b"client"].children[b"r0"]
) )
request, channel = self.make_request("PUT", "presence/a/status") request, channel = self.make_request("PUT", "presence/a/status")
...@@ -77,7 +77,7 @@ class FrontendProxyTests(HomeserverTestCase): ...@@ -77,7 +77,7 @@ class FrontendProxyTests(HomeserverTestCase):
self.assertEqual(len(self.reactor.tcpServers), 1) self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1] site = self.reactor.tcpServers[0][1]
self.resource = ( self.resource = (
site.resource.children["_matrix"].children["client"].children["r0"] site.resource.children[b"_matrix"].children[b"client"].children[b"r0"]
) )
request, channel = self.make_request("PUT", "presence/a/status") request, channel = self.make_request("PUT", "presence/a/status")
......
...@@ -43,9 +43,7 @@ def _expect_edu_transaction(edu_type, content, origin="test"): ...@@ -43,9 +43,7 @@ def _expect_edu_transaction(edu_type, content, origin="test"):
def _make_edu_transaction_json(edu_type, content): def _make_edu_transaction_json(edu_type, content):
return json.dumps(_expect_edu_transaction(edu_type, content)).encode( return json.dumps(_expect_edu_transaction(edu_type, content)).encode('utf8')
'utf8'
)
class TypingNotificationsTestCase(unittest.TestCase): class TypingNotificationsTestCase(unittest.TestCase):
......
This diff is collapsed.
...@@ -62,12 +62,6 @@ class FilterTestCase(unittest.HomeserverTestCase): ...@@ -62,12 +62,6 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
self.assertTrue( self.assertTrue(
set( set(
[ ["next_batch", "rooms", "account_data", "to_device", "device_lists"]
"next_batch",
"rooms",
"account_data",
"to_device",
"device_lists",
]
).issubset(set(channel.json_body.keys())) ).issubset(set(channel.json_body.keys()))
) )
...@@ -65,7 +65,7 @@ class FakeChannel(object): ...@@ -65,7 +65,7 @@ class FakeChannel(object):
def getPeer(self): def getPeer(self):
# We give an address so that getClientIP returns a non null entry, # We give an address so that getClientIP returns a non null entry,
# causing us to record the MAU # causing us to record the MAU
return address.IPv4Address(b"TCP", "127.0.0.1", 3423) return address.IPv4Address("TCP", "127.0.0.1", 3423)
def getHost(self): def getHost(self):
return None return None
......
...@@ -80,12 +80,11 @@ class TestResourceLimitsServerNotices(unittest.TestCase): ...@@ -80,12 +80,11 @@ class TestResourceLimitsServerNotices(unittest.TestCase):
self._rlsn._auth.check_auth_blocking = Mock() self._rlsn._auth.check_auth_blocking = Mock()
mock_event = Mock( mock_event = Mock(
type=EventTypes.Message, type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
content={"msgtype": ServerNoticeMsgType}, )
self._rlsn._store.get_events = Mock(
return_value=defer.succeed({"123": mock_event})
) )
self._rlsn._store.get_events = Mock(return_value=defer.succeed(
{"123": mock_event}
))
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
# Would be better to check the content, but once == remove blocking event # Would be better to check the content, but once == remove blocking event
...@@ -99,12 +98,11 @@ class TestResourceLimitsServerNotices(unittest.TestCase): ...@@ -99,12 +98,11 @@ class TestResourceLimitsServerNotices(unittest.TestCase):
) )
mock_event = Mock( mock_event = Mock(
type=EventTypes.Message, type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
content={"msgtype": ServerNoticeMsgType}, )
self._rlsn._store.get_events = Mock(
return_value=defer.succeed({"123": mock_event})
) )
self._rlsn._store.get_events = Mock(return_value=defer.succeed(
{"123": mock_event}
))
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
self._send_notice.assert_not_called() self._send_notice.assert_not_called()
...@@ -177,13 +175,9 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase): ...@@ -177,13 +175,9 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_server_notice_only_sent_once(self): def test_server_notice_only_sent_once(self):
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(return_value=1000)
return_value=1000,
)
self.store.user_last_seen_monthly_active = Mock( self.store.user_last_seen_monthly_active = Mock(return_value=1000)
return_value=1000,
)
# Call the function multiple times to ensure we only send the notice once # Call the function multiple times to ensure we only send the notice once
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id) yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
...@@ -193,12 +187,12 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase): ...@@ -193,12 +187,12 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase):
# Now lets get the last load of messages in the service notice room and # Now lets get the last load of messages in the service notice room and
# check that there is only one server notice # check that there is only one server notice
room_id = yield self.server_notices_manager.get_notice_room_for_user( room_id = yield self.server_notices_manager.get_notice_room_for_user(
self.user_id, self.user_id
) )
token = yield self.event_source.get_current_token() token = yield self.event_source.get_current_token()
events, _ = yield self.store.get_recent_events_for_room( events, _ = yield self.store.get_recent_events_for_room(
room_id, limit=100, end_token=token.room_key, room_id, limit=100, end_token=token.room_key
) )
count = 0 count = 0
......
...@@ -185,8 +185,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ...@@ -185,8 +185,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_some_state_from_cache correctly filters out members with types=[] # test _get_some_state_from_cache correctly filters out members with types=[]
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_cache, self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member]
group, [], filtered_types=[EventTypes.Member]
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
...@@ -200,19 +199,20 @@ class StateStoreTestCase(tests.unittest.TestCase): ...@@ -200,19 +199,20 @@ class StateStoreTestCase(tests.unittest.TestCase):
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_members_cache, self.store._state_group_members_cache,
group, [], filtered_types=[EventTypes.Member] group,
[],
filtered_types=[EventTypes.Member],
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
self.assertDictEqual( self.assertDictEqual({}, state_dict)
{},
state_dict,
)
# test _get_some_state_from_cache correctly filters in members with wildcard types # test _get_some_state_from_cache correctly filters in members with wildcard types
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_cache, self.store._state_group_cache,
group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] group,
[(EventTypes.Member, None)],
filtered_types=[EventTypes.Member],
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
...@@ -226,7 +226,9 @@ class StateStoreTestCase(tests.unittest.TestCase): ...@@ -226,7 +226,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_members_cache, self.store._state_group_members_cache,
group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] group,
[(EventTypes.Member, None)],
filtered_types=[EventTypes.Member],
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
...@@ -264,18 +266,15 @@ class StateStoreTestCase(tests.unittest.TestCase): ...@@ -264,18 +266,15 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
self.assertDictEqual( self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
{
(e5.type, e5.state_key): e5.event_id,
},
state_dict,
)
# test _get_some_state_from_cache correctly filters in members with specific types # test _get_some_state_from_cache correctly filters in members with specific types
# and no filtered_types # and no filtered_types
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_members_cache, self.store._state_group_members_cache,
group, [(EventTypes.Member, e5.state_key)], filtered_types=None group,
[(EventTypes.Member, e5.state_key)],
filtered_types=None,
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
...@@ -305,9 +304,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ...@@ -305,9 +304,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
key=group, key=group,
value=state_dict_ids, value=state_dict_ids,
# list fetched keys so it knows it's partial # list fetched keys so it knows it's partial
fetched_keys=( fetched_keys=((e1.type, e1.state_key),),
(e1.type, e1.state_key),
),
) )
(is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get( (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
...@@ -315,20 +312,8 @@ class StateStoreTestCase(tests.unittest.TestCase): ...@@ -315,20 +312,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertEqual( self.assertEqual(known_absent, set([(e1.type, e1.state_key)]))
known_absent, self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
set(
[
(e1.type, e1.state_key),
]
),
)
self.assertDictEqual(
state_dict_ids,
{
(e1.type, e1.state_key): e1.event_id,
},
)
############################################ ############################################
# test that things work with a partial cache # test that things work with a partial cache
...@@ -336,8 +321,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ...@@ -336,8 +321,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_some_state_from_cache correctly filters out members with types=[] # test _get_some_state_from_cache correctly filters out members with types=[]
room_id = self.room.to_string() room_id = self.room.to_string()
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_cache, self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member]
group, [], filtered_types=[EventTypes.Member]
) )
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
...@@ -346,7 +330,9 @@ class StateStoreTestCase(tests.unittest.TestCase): ...@@ -346,7 +330,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
room_id = self.room.to_string() room_id = self.room.to_string()
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_members_cache, self.store._state_group_members_cache,
group, [], filtered_types=[EventTypes.Member] group,
[],
filtered_types=[EventTypes.Member],
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
...@@ -355,20 +341,19 @@ class StateStoreTestCase(tests.unittest.TestCase): ...@@ -355,20 +341,19 @@ class StateStoreTestCase(tests.unittest.TestCase):
# test _get_some_state_from_cache correctly filters in members wildcard types # test _get_some_state_from_cache correctly filters in members wildcard types
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_cache, self.store._state_group_cache,
group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] group,
[(EventTypes.Member, None)],
filtered_types=[EventTypes.Member],
) )
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertDictEqual( self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
{
(e1.type, e1.state_key): e1.event_id,
},
state_dict,
)
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_members_cache, self.store._state_group_members_cache,
group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member] group,
[(EventTypes.Member, None)],
filtered_types=[EventTypes.Member],
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
...@@ -389,12 +374,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ...@@ -389,12 +374,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertDictEqual( self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
{
(e1.type, e1.state_key): e1.event_id,
},
state_dict,
)
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_members_cache, self.store._state_group_members_cache,
...@@ -404,18 +384,15 @@ class StateStoreTestCase(tests.unittest.TestCase): ...@@ -404,18 +384,15 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
self.assertDictEqual( self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
{
(e5.type, e5.state_key): e5.event_id,
},
state_dict,
)
# test _get_some_state_from_cache correctly filters in members with specific types # test _get_some_state_from_cache correctly filters in members with specific types
# and no filtered_types # and no filtered_types
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_cache, self.store._state_group_cache,
group, [(EventTypes.Member, e5.state_key)], filtered_types=None group,
[(EventTypes.Member, e5.state_key)],
filtered_types=None,
) )
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
...@@ -423,13 +400,10 @@ class StateStoreTestCase(tests.unittest.TestCase): ...@@ -423,13 +400,10 @@ class StateStoreTestCase(tests.unittest.TestCase):
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
self.store._state_group_members_cache, self.store._state_group_members_cache,
group, [(EventTypes.Member, e5.state_key)], filtered_types=None group,
[(EventTypes.Member, e5.state_key)],
filtered_types=None,
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
self.assertDictEqual( self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
{
(e5.type, e5.state_key): e5.event_id,
},
state_dict,
)
...@@ -185,20 +185,20 @@ class TestMauLimit(unittest.TestCase): ...@@ -185,20 +185,20 @@ class TestMauLimit(unittest.TestCase):
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def create_user(self, localpart): def create_user(self, localpart):
request_data = json.dumps({ request_data = json.dumps(
"username": localpart, {
"password": "monkey", "username": localpart,
"auth": {"type": LoginType.DUMMY}, "password": "monkey",
}) "auth": {"type": LoginType.DUMMY},
}
)
request, channel = make_request(b"POST", b"/register", request_data) request, channel = make_request("POST", "/register", request_data)
render(request, self.resource, self.reactor) render(request, self.resource, self.reactor)
if channel.result["code"] != b"200": if channel.code != 200:
raise HttpResponseException( raise HttpResponseException(
int(channel.result["code"]), channel.code, channel.result["reason"], channel.result["body"]
channel.result["reason"],
channel.result["body"],
).to_synapse_error() ).to_synapse_error()
access_token = channel.json_body["access_token"] access_token = channel.json_body["access_token"]
...@@ -206,12 +206,12 @@ class TestMauLimit(unittest.TestCase): ...@@ -206,12 +206,12 @@ class TestMauLimit(unittest.TestCase):
return access_token return access_token
def do_sync_for_user(self, token): def do_sync_for_user(self, token):
request, channel = make_request(b"GET", b"/sync", access_token=token) request, channel = make_request(
"GET", "/sync", access_token=token.encode('ascii')
)
render(request, self.resource, self.reactor) render(request, self.resource, self.reactor)
if channel.result["code"] != b"200": if channel.code != 200:
raise HttpResponseException( raise HttpResponseException(
int(channel.result["code"]), channel.code, channel.result["reason"], channel.result["body"]
channel.result["reason"],
channel.result["body"],
).to_synapse_error() ).to_synapse_error()
...@@ -180,7 +180,7 @@ class StateTestCase(unittest.TestCase): ...@@ -180,7 +180,7 @@ class StateTestCase(unittest.TestCase):
graph = Graph( graph = Graph(
nodes={ nodes={
"START": DictObj( "START": DictObj(
type=EventTypes.Create, state_key="", content={}, depth=1, type=EventTypes.Create, state_key="", content={}, depth=1
), ),
"A": DictObj(type=EventTypes.Message, depth=2), "A": DictObj(type=EventTypes.Message, depth=2),
"B": DictObj(type=EventTypes.Message, depth=3), "B": DictObj(type=EventTypes.Message, depth=3),
......
...@@ -100,8 +100,13 @@ class TestHomeServer(HomeServer): ...@@ -100,8 +100,13 @@ class TestHomeServer(HomeServer):
@defer.inlineCallbacks @defer.inlineCallbacks
def setup_test_homeserver( def setup_test_homeserver(
cleanup_func, name="test", datastore=None, config=None, reactor=None, cleanup_func,
homeserverToUse=TestHomeServer, **kargs name="test",
datastore=None,
config=None,
reactor=None,
homeserverToUse=TestHomeServer,
**kargs
): ):
""" """
Setup a homeserver suitable for running tests against. Keyword arguments Setup a homeserver suitable for running tests against. Keyword arguments
...@@ -323,8 +328,7 @@ class MockHttpResource(HttpServer): ...@@ -323,8 +328,7 @@ class MockHttpResource(HttpServer):
@patch('twisted.web.http.Request') @patch('twisted.web.http.Request')
@defer.inlineCallbacks @defer.inlineCallbacks
def trigger( def trigger(
self, http_method, path, content, mock_request, self, http_method, path, content, mock_request, federation_auth_origin=None
federation_auth_origin=None,
): ):
""" Fire an HTTP event. """ Fire an HTTP event.
...@@ -357,7 +361,7 @@ class MockHttpResource(HttpServer): ...@@ -357,7 +361,7 @@ class MockHttpResource(HttpServer):
headers = {} headers = {}
if federation_auth_origin is not None: if federation_auth_origin is not None:
headers[b"Authorization"] = [ headers[b"Authorization"] = [
b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin, ) b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,)
] ]
mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers) mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
...@@ -577,16 +581,16 @@ def create_room(hs, room_id, creator_id): ...@@ -577,16 +581,16 @@ def create_room(hs, room_id, creator_id):
event_builder_factory = hs.get_event_builder_factory() event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler() event_creation_handler = hs.get_event_creation_handler()
builder = event_builder_factory.new({ builder = event_builder_factory.new(
"type": EventTypes.Create, {
"state_key": "", "type": EventTypes.Create,
"sender": creator_id, "state_key": "",
"room_id": room_id, "sender": creator_id,
"content": {}, "room_id": room_id,
}) "content": {},
}
event, context = yield event_creation_handler.create_new_client_event(
builder
) )
event, context = yield event_creation_handler.create_new_client_event(builder)
yield store.persist_event(event, context) yield store.persist_event(event, context)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment