Newer
Older
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
"""Tests for the password_auth_provider interface"""
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Type, Union
from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor
import synapse
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
from synapse.handlers.account import AccountHandler
Azrenbeth
committed
from synapse.module_api import ModuleApi
from synapse.rest.client import account, devices, login, logout, register
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID
from synapse.util import Clock
from tests import unittest
from tests.server import FakeChannel
from tests.unittest import override_config
# Login flows we expect to appear in the list after the normal ones.
ADDITIONAL_LOGIN_FLOWS = [
{"type": "m.login.application_service"},
]
# a mock instance which the dummy auth providers delegate to, so we can see what's going
# on
mock_password_provider = Mock()
Azrenbeth
committed
class LegacyPasswordOnlyAuthProvider:
"""A legacy password_provider which only implements `check_password`."""
@staticmethod
def parse_config(config: JsonDict) -> None:
def __init__(self, config: None, account_handler: AccountHandler):
def check_password(self, *args: str) -> Mock:
return mock_password_provider.check_password(*args)
Azrenbeth
committed
class LegacyCustomAuthProvider:
"""A legacy password_provider which implements a custom login type."""
@staticmethod
def parse_config(config: JsonDict) -> None:
def __init__(self, config: None, account_handler: AccountHandler):
def get_supported_login_types(self) -> Dict[str, List[str]]:
return {"test.login_type": ["test_field"]}
def check_auth(self, *args: str) -> Mock:
return mock_password_provider.check_auth(*args)
Azrenbeth
committed
class CustomAuthProvider:
"""A module which registers password_auth_provider callbacks for a custom login type."""
@staticmethod
def parse_config(config: JsonDict) -> None:
Azrenbeth
committed
pass
def __init__(self, config: None, api: ModuleApi):
Azrenbeth
committed
api.register_password_auth_provider_callbacks(
auth_checkers={("test.login_type", ("test_field",)): self.check_auth}
Azrenbeth
committed
)
def check_auth(self, *args: Any) -> Mock:
Azrenbeth
committed
return mock_password_provider.check_auth(*args)
class LegacyPasswordCustomAuthProvider:
"""A password_provider which implements password login via `check_auth`, as well
as a custom type."""
@staticmethod
def parse_config(config: JsonDict) -> None:
def __init__(self, config: None, account_handler: AccountHandler):
def get_supported_login_types(self) -> Dict[str, List[str]]:
return {"m.login.password": ["password"], "test.login_type": ["test_field"]}
def check_auth(self, *args: str) -> Mock:
return mock_password_provider.check_auth(*args)
Azrenbeth
committed
class PasswordCustomAuthProvider:
"""A module which registers password_auth_provider callbacks for a custom login type.
as well as a password login"""
@staticmethod
def parse_config(config: JsonDict) -> None:
Azrenbeth
committed
pass
def __init__(self, config: None, api: ModuleApi):
Azrenbeth
committed
api.register_password_auth_provider_callbacks(
auth_checkers={
("test.login_type", ("test_field",)): self.check_auth,
("m.login.password", ("password",)): self.check_auth,
Azrenbeth
committed
)
def check_auth(self, *args: Any) -> Mock:
Azrenbeth
committed
return mock_password_provider.check_auth(*args)
def check_pass(self, *args: str) -> Mock:
Azrenbeth
committed
return mock_password_provider.check_password(*args)
def legacy_providers_config(*providers: Type[Any]) -> dict:
"""Returns a config dict that will enable the given legacy password auth providers"""
return {
"password_providers": [
{"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
for provider in providers
]
}
Azrenbeth
committed
def providers_config(*providers: Type[Any]) -> dict:
"""Returns a config dict that will enable the given modules"""
return {
"modules": [
{"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
for provider in providers
]
}
class PasswordAuthProviderTests(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
devices.register_servlets,
logout.register_servlets,
register.register_servlets,
account.register_servlets,
CALLBACK_USERNAME = "get_username_for_registration"
CALLBACK_DISPLAYNAME = "get_displayname_for_registration"
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
# we use a global mock device, so make sure we are starting with a clean slate
mock_password_provider.reset_mock()
# The mock password provider doesn't register the users, so ensure they
# are registered first.
self.register_user("u", "not-the-tested-password")
self.register_user("user", "not-the-tested-password")
Azrenbeth
committed
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_password_only_auth_progiver_login_legacy(self) -> None:
Azrenbeth
committed
self.password_only_auth_provider_login_test_body()
def password_only_auth_provider_login_test_body(self) -> None:
# login flows should only have m.login.password
flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
# check_password must return an awaitable
mock_password_provider.check_password = AsyncMock(return_value=True)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@u:test", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
mock_password_provider.reset_mock()
# login with mxid should work too
channel = self._send_password_login("@u:test", "p")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@u:test", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
mock_password_provider.reset_mock()
Azrenbeth
committed
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_password_only_auth_provider_ui_auth_legacy(self) -> None:
Azrenbeth
committed
self.password_only_auth_provider_ui_auth_test_body()
def password_only_auth_provider_ui_auth_test_body(self) -> None:
"""UI Auth should delegate correctly to the password provider"""
# log in twice, to get two devices
mock_password_provider.check_password = AsyncMock(return_value=True)
tok1 = self.login("u", "p")
self.login("u", "p", device_id="dev2")
mock_password_provider.reset_mock()
# have the auth provider deny the request to start with
mock_password_provider.check_password = AsyncMock(return_value=False)
# make the initial request which returns a 401
session = self._start_delete_device_session(tok1, "dev2")
mock_password_provider.check_password.assert_not_called()
# Make another request providing the UI auth flow.
channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
self.assertEqual(channel.code, 401) # XXX why not a 403?
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
mock_password_provider.reset_mock()
# Finally, check the request goes through when we allow it
mock_password_provider.check_password = AsyncMock(return_value=True)
channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
self.assertEqual(channel.code, 200)
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
Azrenbeth
committed
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_local_user_fallback_login_legacy(self) -> None:
Azrenbeth
committed
self.local_user_fallback_login_test_body()
def local_user_fallback_login_test_body(self) -> None:
"""rejected login should fall back to local db"""
self.register_user("localuser", "localpass")
# check_password must return an awaitable
mock_password_provider.check_password = AsyncMock(return_value=False)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@localuser:test", channel.json_body["user_id"])
Azrenbeth
committed
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_local_user_fallback_ui_auth_legacy(self) -> None:
Azrenbeth
committed
self.local_user_fallback_ui_auth_test_body()
def local_user_fallback_ui_auth_test_body(self) -> None:
"""rejected login should fall back to local db"""
self.register_user("localuser", "localpass")
# have the auth provider deny the request
mock_password_provider.check_password = AsyncMock(return_value=False)
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
# log in twice, to get two devices
tok1 = self.login("localuser", "localpass")
self.login("localuser", "localpass", device_id="dev2")
mock_password_provider.check_password.reset_mock()
# first delete should give a 401
session = self._start_delete_device_session(tok1, "dev2")
mock_password_provider.check_password.assert_not_called()
# Wrong password
channel = self._authed_delete_device(tok1, "dev2", session, "localuser", "xxx")
self.assertEqual(channel.code, 401) # XXX why not a 403?
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
mock_password_provider.check_password.assert_called_once_with(
"@localuser:test", "xxx"
)
mock_password_provider.reset_mock()
# Right password
channel = self._authed_delete_device(
tok1, "dev2", session, "localuser", "localpass"
)
self.assertEqual(channel.code, 200)
mock_password_provider.check_password.assert_called_once_with(
"@localuser:test", "localpass"
)
@override_config(
{
Azrenbeth
committed
**legacy_providers_config(LegacyPasswordOnlyAuthProvider),
"password_config": {"localdb_enabled": False},
}
)
def test_no_local_user_fallback_login_legacy(self) -> None:
Azrenbeth
committed
self.no_local_user_fallback_login_test_body()
def no_local_user_fallback_login_test_body(self) -> None:
"""localdb_enabled can block login with the local password"""
self.register_user("localuser", "localpass")
# check_password must return an awaitable
mock_password_provider.check_password = AsyncMock(return_value=False)
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
mock_password_provider.check_password.assert_called_once_with(
"@localuser:test", "localpass"
)
@override_config(
{
Azrenbeth
committed
**legacy_providers_config(LegacyPasswordOnlyAuthProvider),
"password_config": {"localdb_enabled": False},
}
)
def test_no_local_user_fallback_ui_auth_legacy(self) -> None:
Azrenbeth
committed
self.no_local_user_fallback_ui_auth_test_body()
def no_local_user_fallback_ui_auth_test_body(self) -> None:
"""localdb_enabled can block ui auth with the local password"""
self.register_user("localuser", "localpass")
# allow login via the auth provider
mock_password_provider.check_password = AsyncMock(return_value=True)
# log in twice, to get two devices
tok1 = self.login("localuser", "p")
self.login("localuser", "p", device_id="dev2")
mock_password_provider.check_password.reset_mock()
# first delete should give a 401
channel = self._delete_device(tok1, "dev2")
self.assertEqual(channel.code, 401)
# m.login.password UIA is permitted because the auth provider allows it,
# even though the localdb does not.
self.assertEqual(channel.json_body["flows"], [{"stages": ["m.login.password"]}])
session = channel.json_body["session"]
mock_password_provider.check_password.assert_not_called()
# now try deleting with the local password
mock_password_provider.check_password = AsyncMock(return_value=False)
channel = self._authed_delete_device(
tok1, "dev2", session, "localuser", "localpass"
)
self.assertEqual(channel.code, 401) # XXX why not a 403?
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
mock_password_provider.check_password.assert_called_once_with(
"@localuser:test", "localpass"
)
@override_config(
{
Azrenbeth
committed
**legacy_providers_config(LegacyPasswordOnlyAuthProvider),
"password_config": {"enabled": False},
}
)
def test_password_auth_disabled_legacy(self) -> None:
Azrenbeth
committed
self.password_auth_disabled_test_body()
def password_auth_disabled_test_body(self) -> None:
"""password auth doesn't work if it's disabled across the board"""
# login flows should be empty
flows = self._get_login_flows()
self.assertEqual(flows, ADDITIONAL_LOGIN_FLOWS)
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_password.assert_not_called()
Azrenbeth
committed
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
def test_custom_auth_provider_login_legacy(self) -> None:
Azrenbeth
committed
self.custom_auth_provider_login_test_body()
@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_login(self) -> None:
Azrenbeth
committed
self.custom_auth_provider_login_test_body()
def custom_auth_provider_login_test_body(self) -> None:
# login flows should have the custom flow and m.login.password, since we
# haven't disabled local password lookup.
# (password must come first, because reasons)
flows = self._get_login_flows()
self.assertEqual(
flows,
[{"type": "m.login.password"}, {"type": "test.login_type"}]
+ ADDITIONAL_LOGIN_FLOWS,
)
# login with missing param should be rejected
channel = self._send_login("test.login_type", "u")
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_auth = AsyncMock(return_value=("@user:test", None))
channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@user:test", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
"u", "test.login_type", {"test_field": "y"}
)
mock_password_provider.reset_mock()
Azrenbeth
committed
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
def test_custom_auth_provider_ui_auth_legacy(self) -> None:
Azrenbeth
committed
self.custom_auth_provider_ui_auth_test_body()
@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_ui_auth(self) -> None:
Azrenbeth
committed
self.custom_auth_provider_ui_auth_test_body()
def custom_auth_provider_ui_auth_test_body(self) -> None:
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
# register the user and log in twice, to get two devices
self.register_user("localuser", "localpass")
tok1 = self.login("localuser", "localpass")
self.login("localuser", "localpass", device_id="dev2")
# make the initial request which returns a 401
channel = self._delete_device(tok1, "dev2")
self.assertEqual(channel.code, 401)
# Ensure that flows are what is expected.
self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"])
session = channel.json_body["session"]
# missing param
body = {
"auth": {
"type": "test.login_type",
"identifier": {"type": "m.id.user", "user": "localuser"},
"session": session,
},
}
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 400)
# there's a perfectly good M_MISSING_PARAM errcode, but heaven forfend we should
# use it...
self.assertIn("Missing parameters", channel.json_body["error"])
mock_password_provider.check_auth.assert_not_called()
mock_password_provider.reset_mock()
# right params, but authing as the wrong user
mock_password_provider.check_auth = AsyncMock(return_value=("@user:test", None))
body["auth"]["test_field"] = "foo"
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
mock_password_provider.check_auth.assert_called_once_with(
"localuser", "test.login_type", {"test_field": "foo"}
)
mock_password_provider.reset_mock()
# and finally, succeed
mock_password_provider.check_auth = AsyncMock(
return_value=("@localuser:test", None)
)
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 200)
mock_password_provider.check_auth.assert_called_once_with(
"localuser", "test.login_type", {"test_field": "foo"}
)
Azrenbeth
committed
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
def test_custom_auth_provider_callback_legacy(self) -> None:
Azrenbeth
committed
self.custom_auth_provider_callback_test_body()
@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_callback(self) -> None:
Azrenbeth
committed
self.custom_auth_provider_callback_test_body()
def custom_auth_provider_callback_test_body(self) -> None:
callback = AsyncMock(return_value=None)
mock_password_provider.check_auth = AsyncMock(
return_value=("@user:test", callback)
)
channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@user:test", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
"u", "test.login_type", {"test_field": "y"}
)
# check the args to the callback
callback.assert_called_once()
call_args, call_kwargs = callback.call_args
# should be one positional arg
self.assertEqual(len(call_args), 1)
self.assertEqual(call_args[0]["user_id"], "@user:test")
for p in ["user_id", "access_token", "device_id", "home_server"]:
self.assertIn(p, call_args[0])
Azrenbeth
committed
@override_config(
{
**legacy_providers_config(LegacyCustomAuthProvider),
"password_config": {"enabled": False},
}
)
def test_custom_auth_password_disabled_legacy(self) -> None:
Azrenbeth
committed
self.custom_auth_password_disabled_test_body()
@override_config(
{**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
)
def test_custom_auth_password_disabled(self) -> None:
Azrenbeth
committed
self.custom_auth_password_disabled_test_body()
def custom_auth_password_disabled_test_body(self) -> None:
"""Test login with a custom auth provider where password login is disabled"""
self.register_user("localuser", "localpass")
flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called()
Azrenbeth
committed
@override_config(
{
**legacy_providers_config(LegacyCustomAuthProvider),
"password_config": {"enabled": False, "localdb_enabled": False},
}
)
def test_custom_auth_password_disabled_localdb_enabled_legacy(self) -> None:
Azrenbeth
committed
self.custom_auth_password_disabled_localdb_enabled_test_body()
@override_config(
{
**providers_config(CustomAuthProvider),
"password_config": {"enabled": False, "localdb_enabled": False},
}
)
def test_custom_auth_password_disabled_localdb_enabled(self) -> None:
Azrenbeth
committed
self.custom_auth_password_disabled_localdb_enabled_test_body()
def custom_auth_password_disabled_localdb_enabled_test_body(self) -> None:
"""Check the localdb_enabled == enabled == False
Regression test for https://github.com/matrix-org/synapse/issues/8914: check
that setting *both* `localdb_enabled` *and* `password: enabled` to False doesn't
cause an exception.
"""
self.register_user("localuser", "localpass")
flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called()
Azrenbeth
committed
@override_config(
{
**legacy_providers_config(LegacyPasswordCustomAuthProvider),
"password_config": {"enabled": False},
}
)
def test_password_custom_auth_password_disabled_login_legacy(self) -> None:
Azrenbeth
committed
self.password_custom_auth_password_disabled_login_test_body()
@override_config(
{
**providers_config(PasswordCustomAuthProvider),
"password_config": {"enabled": False},
}
)
def test_password_custom_auth_password_disabled_login(self) -> None:
Azrenbeth
committed
self.password_custom_auth_password_disabled_login_test_body()
def password_custom_auth_password_disabled_login_test_body(self) -> None:
"""log in with a custom auth provider which implements password, but password
login is disabled"""
self.register_user("localuser", "localpass")
flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called()
Azrenbeth
committed
mock_password_provider.check_password.assert_not_called()
@override_config(
{
**legacy_providers_config(LegacyPasswordCustomAuthProvider),
"password_config": {"enabled": False},
}
)
def test_password_custom_auth_password_disabled_ui_auth_legacy(self) -> None:
Azrenbeth
committed
self.password_custom_auth_password_disabled_ui_auth_test_body()
@override_config(
{
**providers_config(PasswordCustomAuthProvider),
"password_config": {"enabled": False},
}
)
def test_password_custom_auth_password_disabled_ui_auth(self) -> None:
Azrenbeth
committed
self.password_custom_auth_password_disabled_ui_auth_test_body()
def password_custom_auth_password_disabled_ui_auth_test_body(self) -> None:
"""UI Auth with a custom auth provider which implements password, but password
login is disabled"""
# register the user and log in twice via the test login type to get two devices,
self.register_user("localuser", "localpass")
mock_password_provider.check_auth = AsyncMock(
return_value=("@localuser:test", None)
)
channel = self._send_login("test.login_type", "localuser", test_field="")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
tok1 = channel.json_body["access_token"]
channel = self._send_login(
"test.login_type", "localuser", test_field="", device_id="dev2"
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
# make the initial request which returns a 401
channel = self._delete_device(tok1, "dev2")
self.assertEqual(channel.code, 401)
# Ensure that flows are what is expected. In particular, "password" should *not*
# be present.
self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"])
session = channel.json_body["session"]
mock_password_provider.reset_mock()
# check that auth with password is rejected
body = {
"auth": {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "localuser"},
"password": "localpass",
"session": session,
},
}
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 400)
self.assertEqual(
"Password login has been disabled.", channel.json_body["error"]
)
mock_password_provider.check_auth.assert_not_called()
Azrenbeth
committed
mock_password_provider.check_password.assert_not_called()
mock_password_provider.reset_mock()
# successful auth
body["auth"]["type"] = "test.login_type"
body["auth"]["test_field"] = "x"
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 200)
mock_password_provider.check_auth.assert_called_once_with(
"localuser", "test.login_type", {"test_field": "x"}
)
Azrenbeth
committed
mock_password_provider.check_password.assert_not_called()
@override_config(
{
**legacy_providers_config(LegacyCustomAuthProvider),
"password_config": {"localdb_enabled": False},
}
)
def test_custom_auth_no_local_user_fallback_legacy(self) -> None:
Azrenbeth
committed
self.custom_auth_no_local_user_fallback_test_body()
@override_config(
{
**providers_config(CustomAuthProvider),
"password_config": {"localdb_enabled": False},
}
)
def test_custom_auth_no_local_user_fallback(self) -> None:
Azrenbeth
committed
self.custom_auth_no_local_user_fallback_test_body()
def custom_auth_no_local_user_fallback_test_body(self) -> None:
"""Test login with a custom auth provider where the local db is disabled"""
self.register_user("localuser", "localpass")
flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
# password login shouldn't work and should be rejected with a 400
# ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
def test_on_logged_out(self) -> None:
"""Tests that the on_logged_out callback is called when the user logs out."""
self.register_user("rin", "password")
tok = self.login("rin", "password")
self.called = False
async def on_logged_out(
user_id: str, device_id: Optional[str], access_token: str
) -> None:
self.called = True
on_logged_out = Mock(side_effect=on_logged_out)
self.hs.get_password_auth_provider().on_logged_out_callbacks.append(
on_logged_out
)
channel = self.make_request(
"POST",
"/_matrix/client/v3/logout",
{},
access_token=tok,
)
self.assertEqual(channel.code, 200)
on_logged_out.assert_called_once()
self.assertTrue(self.called)
def test_username(self) -> None:
"""Tests that the get_username_for_registration callback can define the username
of a user when registering.
"""
self._setup_get_name_for_registration(
callback_name=self.CALLBACK_USERNAME,
)
username = "rin"
channel = self.make_request(
"POST",
"/register",
{
"username": username,
"password": "bar",
"auth": {"type": LoginType.DUMMY},
},
)
self.assertEqual(channel.code, 200)
# Our callback takes the username and appends "-foo" to it, check that's what we
# have.
mxid = channel.json_body["user_id"]
self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
def test_username_uia(self) -> None:
"""Tests that the get_username_for_registration callback is only called at the
end of the UIA flow.
"""
m = self._setup_get_name_for_registration(
callback_name=self.CALLBACK_USERNAME,
)
username = "rin"
res = self._do_uia_assert_mock_not_called(username, m)
mxid = res["user_id"]
self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
# Check that the callback has been called.
m.assert_called_once()
# Set some email configuration so the test doesn't fail because of its absence.
@override_config({"email": {"notif_from": "noreply@test"}})
def test_3pid_allowed(self) -> None:
"""Tests that an is_3pid_allowed_callbacks forbidding a 3PID makes Synapse refuse
to bind the new 3PID, and that one allowing a 3PID makes Synapse accept to bind
the 3PID. Also checks that the module is passed a boolean indicating whether the
user to bind this 3PID to is currently registering.
"""
self._test_3pid_allowed("rin", False)
self._test_3pid_allowed("kitay", True)
def test_displayname(self) -> None:
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
"""Tests that the get_displayname_for_registration callback can define the
display name of a user when registering.
"""
self._setup_get_name_for_registration(
callback_name=self.CALLBACK_DISPLAYNAME,
)
username = "rin"
channel = self.make_request(
"POST",
"/register",
{
"username": username,
"password": "bar",
"auth": {"type": LoginType.DUMMY},
},
)
self.assertEqual(channel.code, 200)
# Our callback takes the username and appends "-foo" to it, check that's what we
# have.
user_id = UserID.from_string(channel.json_body["user_id"])
display_name = self.get_success(
self.hs.get_profile_handler().get_displayname(user_id)
)
self.assertEqual(display_name, username + "-foo")
def test_displayname_uia(self) -> None:
"""Tests that the get_displayname_for_registration callback is only called at the
end of the UIA flow.
"""
m = self._setup_get_name_for_registration(
callback_name=self.CALLBACK_DISPLAYNAME,
)
username = "rin"
res = self._do_uia_assert_mock_not_called(username, m)
user_id = UserID.from_string(res["user_id"])
display_name = self.get_success(
self.hs.get_profile_handler().get_displayname(user_id)
)
self.assertEqual(display_name, username + "-foo")
# Check that the callback has been called.
m.assert_called_once()
def _test_3pid_allowed(self, username: str, registration: bool) -> None:
"""Tests that the "is_3pid_allowed" module callback is called correctly, using
either /register or /account URLs depending on the arguments.
Args:
username: The username to use for the test.
registration: Whether to test with registration URLs.
"""
self.hs.get_identity_handler().send_threepid_validation = AsyncMock( # type: ignore[method-assign]
m = AsyncMock(return_value=False)
self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
self.register_user(username, "password")
tok = self.login(username, "password")
if registration:
url = "/register/email/requestToken"
else:
url = "/account/3pid/email/requestToken"
channel = self.make_request(
"POST",
url,
{
"client_secret": "foo",
"email": "foo@test.com",
"send_attempt": 0,
},
access_token=tok,
)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
self.assertEqual(
channel.json_body["errcode"],
Codes.THREEPID_DENIED,
channel.json_body,
)
m.assert_called_once_with("email", "foo@test.com", registration)
m = AsyncMock(return_value=True)
self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
channel = self.make_request(
"POST",
url,
{
"client_secret": "foo",
"email": "bar@test.com",
"send_attempt": 0,
},
access_token=tok,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertIn("sid", channel.json_body)
m.assert_called_once_with("email", "bar@test.com", registration)
def _setup_get_name_for_registration(self, callback_name: str) -> Mock:
"""Registers either a get_username_for_registration callback or a
get_displayname_for_registration callback that appends "-foo" to the username the
client is trying to register.
"""
async def callback(uia_results: JsonDict, params: JsonDict) -> str:
self.assertIn(LoginType.DUMMY, uia_results)
username = params["username"]
return username + "-foo"
m = Mock(side_effect=callback)
password_auth_provider = self.hs.get_password_auth_provider()
getattr(password_auth_provider, callback_name + "_callbacks").append(m)
return m
def _do_uia_assert_mock_not_called(self, username: str, m: Mock) -> JsonDict:
# Initiate the UIA flow.
channel = self.make_request(
"POST",
"register",
{"username": username, "type": "m.login.password", "password": "bar"},
)
self.assertEqual(channel.code, 401)
self.assertIn("session", channel.json_body)
# Check that the callback hasn't been called yet.
m.assert_not_called()
# Finish the UIA flow.
session = channel.json_body["session"]
channel = self.make_request(
"POST",
"register",
{"auth": {"session": session, "type": LoginType.DUMMY}},
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
return channel.json_body
def _get_login_flows(self) -> JsonDict:
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
return channel.json_body["flows"]
def _send_password_login(self, user: str, password: str) -> FakeChannel:
return self._send_login(type="m.login.password", user=user, password=password)
def _send_login(self, type: str, user: str, **extra_params: str) -> FakeChannel:
params = {"identifier": {"type": "m.id.user", "user": user}, "type": type}
params.update(extra_params)
channel = self.make_request("POST", "/_matrix/client/r0/login", params)
return channel
def _start_delete_device_session(self, access_token: str, device_id: str) -> str:
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
"""Make an initial delete device request, and return the UI Auth session ID"""
channel = self._delete_device(access_token, device_id)
self.assertEqual(channel.code, 401)
# Ensure that flows are what is expected.
self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
return channel.json_body["session"]
def _authed_delete_device(
self,
access_token: str,
device_id: str,
session: str,
user_id: str,
password: str,
) -> FakeChannel:
"""Make a delete device request, authenticating with the given uid/password"""
return self._delete_device(
access_token,
device_id,
{
"auth": {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": user_id},
"password": password,
"session": session,
},
},
)
def _delete_device(
self,
access_token: str,
device: str,
body: Union[JsonDict, bytes] = b"",
) -> FakeChannel:
"""Delete an individual device."""
channel = self.make_request(
"DELETE", "devices/" + device, body, access_token=access_token
)
return channel