Skip to content
Snippets Groups Projects
test_password_providers.py 33.7 KiB
Newer Older
  • Learn to ignore specific revisions
  • # Copyright 2020 The Matrix.org Foundation C.I.C.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    """Tests for the password_auth_provider interface"""
    
    from typing import Any, Type, Union
    
    from unittest.mock import Mock
    
    
    from twisted.internet import defer
    
    import synapse
    
    from synapse.api.constants import LoginType
    
    from synapse.handlers.auth import load_legacy_password_auth_providers
    from synapse.module_api import ModuleApi
    
    from synapse.rest.client import devices, login, logout, register
    from synapse.types import JsonDict, UserID
    
    
    from tests import unittest
    from tests.server import FakeChannel
    from tests.unittest import override_config
    
    # (possibly experimental) login flows which we expect the homeserver to advertise
    # after the normal ones. Expected flow lists in the tests below are built as
    # `[...] + ADDITIONAL_LOGIN_FLOWS`.

    ADDITIONAL_LOGIN_FLOWS = [
        {"type": "m.login.application_service"},
        {"type": "uk.half-shot.msc2778.login.application_service"},
    ]


    # A single shared mock which all the dummy auth providers below delegate to, so that
    # tests can configure return values and inspect calls. Being global, it is reset in
    # `PasswordAuthProviderTests.setUp` before every test.
    mock_password_provider = Mock()
    
    
    
    class LegacyPasswordOnlyAuthProvider:
        """Legacy `password_providers` module that implements only `check_password`.

        Every password check is forwarded verbatim to the module-level
        `mock_password_provider`, so tests control and observe it there.
        """

        def __init__(self, config, account_handler):
            # Neither the provider config nor the account handler is needed.
            pass

        @staticmethod
        def parse_config(self):
            # Despite its name, `self` receives the caller's argument (presumably the
            # raw provider config). There is nothing to parse.
            return None

        def check_password(self, *args):
            delegate = mock_password_provider.check_password
            return delegate(*args)
    
    
    
    class LegacyCustomAuthProvider:
        """Legacy `password_providers` module that implements a custom login type.

        It advertises `test.login_type` (requiring the field `test_field`) and forwards
        every `check_auth` call to the module-level `mock_password_provider`.
        """

        def __init__(self, config, account_handler):
            # Neither the provider config nor the account handler is needed.
            pass

        @staticmethod
        def parse_config(self):
            # Nothing to parse.
            return None

        def get_supported_login_types(self):
            login_types = {"test.login_type": ["test_field"]}
            return login_types

        def check_auth(self, *args):
            delegate = mock_password_provider.check_auth
            return delegate(*args)
    
    
    
    class CustomAuthProvider:
        """Module that registers a password_auth_provider auth checker for a custom type.

        The checker handles `test.login_type` with the single field `test_field` and
        forwards to the module-level `mock_password_provider`.
        """

        @staticmethod
        def parse_config(self):
            # Nothing to parse.
            return None

        def __init__(self, config, api: ModuleApi):
            # Key is (login type, tuple of required fields).
            checker_key = ("test.login_type", ("test_field",))
            api.register_password_auth_provider_callbacks(
                auth_checkers={checker_key: self.check_auth},
            )

        def check_auth(self, *args):
            delegate = mock_password_provider.check_auth
            return delegate(*args)
    
    
    class LegacyPasswordCustomAuthProvider:
        """Legacy `password_providers` module handling two login types via `check_auth`.

        It advertises both `m.login.password` (field `password`) and the custom
        `test.login_type` (field `test_field`). All checks are forwarded to
        `mock_password_provider.check_auth`.
        """

        def __init__(self, config, account_handler):
            # Neither the provider config nor the account handler is needed.
            pass

        @staticmethod
        def parse_config(self):
            # Nothing to parse.
            return None

        def get_supported_login_types(self):
            login_types = {
                "m.login.password": ["password"],
                "test.login_type": ["test_field"],
            }
            return login_types

        def check_auth(self, *args):
            delegate = mock_password_provider.check_auth
            return delegate(*args)
    
    
    
    class PasswordCustomAuthProvider:
        """A module which registers password_auth_provider callbacks for a custom login
        type as well as for password login.

        Both `test.login_type` (field `test_field`) and `m.login.password` (field
        `password`) are routed to `check_auth`, which forwards to the module-level
        `mock_password_provider`.
        """

        @staticmethod
        def parse_config(self):
            pass

        def __init__(self, config, api: ModuleApi):
            api.register_password_auth_provider_callbacks(
                auth_checkers={
                    ("test.login_type", ("test_field",)): self.check_auth,
                    ("m.login.password", ("password",)): self.check_auth,
                },
            )

        def check_auth(self, *args):
            return mock_password_provider.check_auth(*args)

        def check_pass(self, *args):
            # NOTE(review): this is not among the callbacks registered in `__init__`, so
            # the homeserver never calls it through this module; presumably unused —
            # confirm before removing.
            return mock_password_provider.check_password(*args)
    
    
    def legacy_providers_config(*providers: Type[Any]) -> dict:
        """Return a config fragment enabling the given legacy password auth providers.

        Each provider class is referenced by its dotted path within this module
        (`<module>.<qualname>`) and given an empty config, under the
        `password_providers` key.
        """
        entries = [
            {"module": f"{__name__}.{cls.__qualname__}", "config": {}}
            for cls in providers
        ]
        return {"password_providers": entries}
    
    
    
    def providers_config(*providers: Type[Any]) -> dict:
        """Return a config fragment enabling the given (new-style) modules.

        Each module class is referenced by its dotted path within this module
        (`<module>.<qualname>`) and given an empty config, under the `modules` key.
        """
        entries = [
            {"module": f"{__name__}.{cls.__qualname__}", "config": {}}
            for cls in providers
        ]
        return {"modules": entries}
    
    
    
    class PasswordAuthProviderTests(unittest.HomeserverTestCase):
        servlets = [
            synapse.rest.admin.register_servlets,
            login.register_servlets,
            devices.register_servlets,
    
            logout.register_servlets,
    
        ]
    
        def setUp(self):
            # The dummy providers all share one global mock. Wipe any state left over
            # from a previous test before the base class sets up this one (presumably
            # building the homeserver via make_homeserver).
            mock_password_provider.reset_mock()
            super().setUp()

        def make_homeserver(self, reactor, clock):
            """Build the test homeserver and load the configured auth providers.

            Each (module class, config) pair from the homeserver's loaded module config
            is instantiated with the homeserver's ModuleApi. Legacy
            `password_providers` are then loaded via
            `load_legacy_password_auth_providers`.
            """
            homeserver = self.setup_test_homeserver()
            api = homeserver.get_module_api()
            for module_cls, module_config in homeserver.config.modules.loaded_modules:
                module_cls(config=module_config, api=api)
            load_legacy_password_auth_providers(homeserver)
            return homeserver
    
        @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
        def test_password_only_auth_progiver_login_legacy(self):
            # NOTE(review): "progiver" is a typo for "provider"; the name is left as-is.
            self.password_only_auth_provider_login_test_body()

        def password_only_auth_provider_login_test_body(self):
            """Password logins are delegated to a provider implementing check_password."""
            check_password = mock_password_provider.check_password

            def assert_login_ok(identifier, password, expected_user_id):
                # Log in with a password, then check both the response and exactly what
                # the provider was asked to verify.
                resp = self._send_password_login(identifier, password)
                self.assertEqual(resp.code, 200, resp.result)
                self.assertEqual(expected_user_id, resp.json_body["user_id"])
                check_password.assert_called_once_with(expected_user_id, password)

            # Only m.login.password should be advertised (besides the extra flows).
            self.assertEqual(
                self._get_login_flows(),
                [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS,
            )

            # The provider's result is awaited, so it has to be an awaitable.
            check_password.return_value = defer.succeed(True)

            # A bare localpart is qualified with the local server name...
            assert_login_ok("u", "p", "@u:test")
            mock_password_provider.reset_mock()

            # ...while a full MXID is passed through as-is.
            assert_login_ok("@u:bz", "p", "@u:bz")
            mock_password_provider.reset_mock()

            # Odd whitespace and emoji: it's unclear what *should* happen, but pin the
            # current behaviour so unexpected API changes are noticed.
            assert_login_ok(
                " USER🙂NAME ", " pASS\U0001F622word ", "@ USER🙂NAME :test"
            )
    
    
        @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
        def test_password_only_auth_provider_ui_auth_legacy(self):
            self.password_only_auth_provider_ui_auth_test_body()

        def password_only_auth_provider_ui_auth_test_body(self):
            """The m.login.password UI-auth stage is delegated to the password provider."""
            check_password = mock_password_provider.check_password

            # The user has to exist, otherwise its access tokens are not accepted.
            self.get_success(self.hs.get_module_api().register_user("u"))

            # Two logins give two devices: one to act from, one to delete.
            check_password.return_value = defer.succeed(True)
            access_token = self.login("u", "p")
            self.login("u", "p", device_id="dev2")
            mock_password_provider.reset_mock()

            # Start with the provider refusing the password.
            check_password.return_value = defer.succeed(False)

            # The unauthenticated request only opens a UIA session (401).
            uia_session = self._start_delete_device_session(access_token, "dev2")
            check_password.assert_not_called()

            # Completing the password stage is refused by the provider.
            resp = self._authed_delete_device(
                access_token, "dev2", uia_session, "u", "p"
            )
            self.assertEqual(resp.code, 401)  # XXX why not a 403?
            self.assertEqual(resp.json_body["errcode"], "M_FORBIDDEN")
            check_password.assert_called_once_with("@u:test", "p")
            mock_password_provider.reset_mock()

            # Once the provider accepts, the deletion goes through.
            check_password.return_value = defer.succeed(True)
            resp = self._authed_delete_device(
                access_token, "dev2", uia_session, "u", "p"
            )
            self.assertEqual(resp.code, 200)
            check_password.assert_called_once_with("@u:test", "p")
    
    
        @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
        def test_local_user_fallback_login_legacy(self):
            self.local_user_fallback_login_test_body()

        def local_user_fallback_login_test_body(self):
            """A login refused by the provider falls back to the local password db."""
            self.register_user("localuser", "localpass")

            # The provider's result is awaited, so it must be an awaitable; it refuses.
            mock_password_provider.check_password.return_value = defer.succeed(False)

            # Unknown locally and refused by the provider: forbidden.
            rejected = self._send_password_login("u", "p")
            self.assertEqual(rejected.code, 403, rejected.result)

            # Refused by the provider but valid locally: accepted.
            accepted = self._send_password_login("localuser", "localpass")
            self.assertEqual(accepted.code, 200, accepted.result)
            self.assertEqual("@localuser:test", accepted.json_body["user_id"])
    
    
        @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
        def test_local_user_fallback_ui_auth_legacy(self):
            self.local_user_fallback_ui_auth_test_body()

        def local_user_fallback_ui_auth_test_body(self):
            """UI auth falls back to the local password db when the provider refuses."""
            self.register_user("localuser", "localpass")
            check_password = mock_password_provider.check_password

            # The provider refuses everything in this test.
            check_password.return_value = defer.succeed(False)

            # Two devices, obtained via the local password.
            access_token = self.login("localuser", "localpass")
            self.login("localuser", "localpass", device_id="dev2")
            check_password.reset_mock()

            # Opening the UIA session (401) must not consult the provider.
            uia_session = self._start_delete_device_session(access_token, "dev2")
            check_password.assert_not_called()

            # A wrong password is refused.
            resp = self._authed_delete_device(
                access_token, "dev2", uia_session, "localuser", "xxx"
            )
            self.assertEqual(resp.code, 401)  # XXX why not a 403?
            self.assertEqual(resp.json_body["errcode"], "M_FORBIDDEN")
            check_password.assert_called_once_with("@localuser:test", "xxx")
            mock_password_provider.reset_mock()

            # The right local password succeeds, after the provider has been asked.
            resp = self._authed_delete_device(
                access_token, "dev2", uia_session, "localuser", "localpass"
            )
            self.assertEqual(resp.code, 200)
            check_password.assert_called_once_with("@localuser:test", "localpass")
    
        @override_config(
            dict(
                legacy_providers_config(LegacyPasswordOnlyAuthProvider),
                password_config={"localdb_enabled": False},
            )
        )
        def test_no_local_user_fallback_login_legacy(self):
            self.no_local_user_fallback_login_test_body()

        def no_local_user_fallback_login_test_body(self):
            """With localdb_enabled off, the local password alone cannot log a user in."""
            self.register_user("localuser", "localpass")

            # The provider refuses (and, being awaited, must return an awaitable).
            mock_password_provider.check_password.return_value = defer.succeed(False)

            resp = self._send_password_login("localuser", "localpass")
            self.assertEqual(resp.code, 403)
            self.assertEqual(resp.json_body["errcode"], "M_FORBIDDEN")
            # The provider was asked, and the correct local password was not accepted.
            mock_password_provider.check_password.assert_called_once_with(
                "@localuser:test", "localpass"
            )
    
        @override_config(
            dict(
                legacy_providers_config(LegacyPasswordOnlyAuthProvider),
                password_config={"localdb_enabled": False},
            )
        )
        def test_no_local_user_fallback_ui_auth_legacy(self):
            self.no_local_user_fallback_ui_auth_test_body()

        def no_local_user_fallback_ui_auth_test_body(self):
            """With localdb_enabled off, UI auth cannot be passed with the local password."""
            self.register_user("localuser", "localpass")
            check_password = mock_password_provider.check_password

            # Let the provider accept logins so we can obtain two devices.
            check_password.return_value = defer.succeed(True)
            access_token = self.login("localuser", "p")
            self.login("localuser", "p", device_id="dev2")
            check_password.reset_mock()

            # An unauthenticated delete opens a UIA session.
            resp = self._delete_device(access_token, "dev2")
            self.assertEqual(resp.code, 401)

            # Password UIA is still offered: the provider can handle it even though the
            # local db cannot.
            self.assertEqual(resp.json_body["flows"], [{"stages": ["m.login.password"]}])
            uia_session = resp.json_body["session"]
            check_password.assert_not_called()

            # Now the provider refuses, so the (correct) local password is not enough.
            check_password.return_value = defer.succeed(False)
            resp = self._authed_delete_device(
                access_token, "dev2", uia_session, "localuser", "localpass"
            )
            self.assertEqual(resp.code, 401)  # XXX why not a 403?
            self.assertEqual(resp.json_body["errcode"], "M_FORBIDDEN")
            check_password.assert_called_once_with("@localuser:test", "localpass")
    
        @override_config(
            dict(
                legacy_providers_config(LegacyPasswordOnlyAuthProvider),
                password_config={"enabled": False},
            )
        )
        def test_password_auth_disabled_legacy(self):
            self.password_auth_disabled_test_body()

        def password_auth_disabled_test_body(self):
            """Disabling password auth globally also disables provider-backed passwords."""
            # Only the extra flows remain advertised.
            self.assertEqual(self._get_login_flows(), ADDITIONAL_LOGIN_FLOWS)

            # m.login.password is now an unknown login type (400), and the provider is
            # never consulted.
            resp = self._send_password_login("u", "p")
            self.assertEqual(resp.code, 400, resp.result)
            mock_password_provider.check_password.assert_not_called()
    
    
        @override_config(legacy_providers_config(LegacyCustomAuthProvider))
        def test_custom_auth_provider_login_legacy(self):
            self.custom_auth_provider_login_test_body()

        @override_config(providers_config(CustomAuthProvider))
        def test_custom_auth_provider_login(self):
            self.custom_auth_provider_login_test_body()

        def custom_auth_provider_login_test_body(self):
            """Login via the custom login type offered by the provider.

            The provider's `check_auth` is configured to return (via a Deferred) a
            `(user_id, callback)` tuple. The login response reports that user_id.
            """
            # login flows should have the custom flow and m.login.password, since we
            # haven't disabled local password lookup.
            # (password must come first, because reasons)
            flows = self._get_login_flows()
            self.assertEqual(
                flows,
                [{"type": "m.login.password"}, {"type": "test.login_type"}]
                + ADDITIONAL_LOGIN_FLOWS,
            )

            # login with missing param should be rejected before reaching the provider
            channel = self._send_login("test.login_type", "u")
            self.assertEqual(channel.code, 400, channel.result)
            mock_password_provider.check_auth.assert_not_called()

            mock_password_provider.check_auth.return_value = defer.succeed(
                ("@user:bz", None)
            )
            channel = self._send_login("test.login_type", "u", test_field="y")
            self.assertEqual(channel.code, 200, channel.result)
            self.assertEqual("@user:bz", channel.json_body["user_id"])
            mock_password_provider.check_auth.assert_called_once_with(
                "u", "test.login_type", {"test_field": "y"}
            )
            mock_password_provider.reset_mock()

            # try a weird username. Again, it's unclear what we *expect* to happen
            # in these cases, but at least we can guard against the API changing
            # unexpectedly. The provider's (malformed) user_id is what gets reported,
            # so it must be supplied here: `defer.succeed()` requires a result.
            mock_password_provider.check_auth.return_value = defer.succeed(
                ("@ MALFORMED! :bz", None)
            )
            channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
            self.assertEqual(channel.code, 200, channel.result)
            self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"])
            mock_password_provider.check_auth.assert_called_once_with(
                " USER🙂NAME ", "test.login_type", {"test_field": " abc "}
            )
    
    
        @override_config(legacy_providers_config(LegacyCustomAuthProvider))
        def test_custom_auth_provider_ui_auth_legacy(self):
            self.custom_auth_provider_ui_auth_test_body()

        @override_config(providers_config(CustomAuthProvider))
        def test_custom_auth_provider_ui_auth(self):
            self.custom_auth_provider_ui_auth_test_body()

        def custom_auth_provider_ui_auth_test_body(self):
            """UI auth via the custom login type offered by the provider.

            Covers a missing field (400), the provider vouching for the wrong user (403)
            and, finally, the provider vouching for the device owner (200).
            """
            # register the user and log in twice, to get two devices
            self.register_user("localuser", "localpass")
            tok1 = self.login("localuser", "localpass")
            self.login("localuser", "localpass", device_id="dev2")

            # make the initial request which returns a 401
            channel = self._delete_device(tok1, "dev2")
            self.assertEqual(channel.code, 401)
            # Ensure that flows are what is expected.
            self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
            self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"])
            session = channel.json_body["session"]

            # missing param
            body = {
                "auth": {
                    "type": "test.login_type",
                    "identifier": {"type": "m.id.user", "user": "localuser"},
                    "session": session,
                },
            }

            channel = self._delete_device(tok1, "dev2", body)
            self.assertEqual(channel.code, 400)
            # there's a perfectly good M_MISSING_PARAM errcode, but heaven forfend we should
            # use it...
            self.assertIn("Missing parameters", channel.json_body["error"])
            mock_password_provider.check_auth.assert_not_called()
            mock_password_provider.reset_mock()

            # right params, but authing as the wrong user
            mock_password_provider.check_auth.return_value = defer.succeed(
                ("@user:bz", None)
            )
            body["auth"]["test_field"] = "foo"
            channel = self._delete_device(tok1, "dev2", body)
            self.assertEqual(channel.code, 403)
            self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
            mock_password_provider.check_auth.assert_called_once_with(
                "localuser", "test.login_type", {"test_field": "foo"}
            )
            mock_password_provider.reset_mock()

            # and finally, succeed: the provider now vouches for the user who owns the
            # device (`defer.succeed()` requires a result)
            mock_password_provider.check_auth.return_value = defer.succeed(
                ("@localuser:test", None)
            )
            channel = self._delete_device(tok1, "dev2", body)
            self.assertEqual(channel.code, 200)
            mock_password_provider.check_auth.assert_called_once_with(
                "localuser", "test.login_type", {"test_field": "foo"}
            )
    
    
        @override_config(legacy_providers_config(LegacyCustomAuthProvider))
        def test_custom_auth_provider_callback_legacy(self):
            self.custom_auth_provider_callback_test_body()

        @override_config(providers_config(CustomAuthProvider))
        def test_custom_auth_provider_callback(self):
            self.custom_auth_provider_callback_test_body()

        def custom_auth_provider_callback_test_body(self):
            """The callback returned alongside the user ID is invoked after login."""
            post_login_callback = Mock(return_value=defer.succeed(None))
            mock_password_provider.check_auth.return_value = defer.succeed(
                ("@user:bz", post_login_callback)
            )

            resp = self._send_login("test.login_type", "u", test_field="y")
            self.assertEqual(resp.code, 200, resp.result)
            self.assertEqual("@user:bz", resp.json_body["user_id"])
            mock_password_provider.check_auth.assert_called_once_with(
                "u", "test.login_type", {"test_field": "y"}
            )

            # The callback gets exactly one positional argument: a dict describing the
            # login.
            post_login_callback.assert_called_once()
            positional, _ = post_login_callback.call_args
            self.assertEqual(len(positional), 1)
            login_info = positional[0]
            self.assertEqual(login_info["user_id"], "@user:bz")
            for key in ("user_id", "access_token", "device_id", "home_server"):
                self.assertIn(key, login_info)
    
    
        @override_config(
            dict(
                legacy_providers_config(LegacyCustomAuthProvider),
                password_config={"enabled": False},
            )
        )
        def test_custom_auth_password_disabled_legacy(self):
            self.custom_auth_password_disabled_test_body()

        @override_config(
            dict(providers_config(CustomAuthProvider), password_config={"enabled": False})
        )
        def test_custom_auth_password_disabled(self):
            self.custom_auth_password_disabled_test_body()

        def custom_auth_password_disabled_test_body(self):
            """With password login disabled, only the custom login type is offered."""
            self.register_user("localuser", "localpass")

            expected_flows = [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS
            self.assertEqual(self._get_login_flows(), expected_flows)

            # m.login.password is an unknown login type (400) and never reaches the
            # provider.
            resp = self._send_password_login("localuser", "localpass")
            self.assertEqual(resp.code, 400, resp.result)
            mock_password_provider.check_auth.assert_not_called()
    
    
        @override_config(
            dict(
                legacy_providers_config(LegacyCustomAuthProvider),
                password_config={"enabled": False, "localdb_enabled": False},
            )
        )
        def test_custom_auth_password_disabled_localdb_enabled_legacy(self):
            self.custom_auth_password_disabled_localdb_enabled_test_body()

        @override_config(
            dict(
                providers_config(CustomAuthProvider),
                password_config={"enabled": False, "localdb_enabled": False},
            )
        )
        def test_custom_auth_password_disabled_localdb_enabled(self):
            self.custom_auth_password_disabled_localdb_enabled_test_body()

        def custom_auth_password_disabled_localdb_enabled_test_body(self):
            """Check the localdb_enabled == enabled == False case.

            Regression test for https://github.com/matrix-org/synapse/issues/8914:
            setting *both* `localdb_enabled` *and* `password: enabled` to False must not
            cause an exception.
            """
            self.register_user("localuser", "localpass")

            expected_flows = [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS
            self.assertEqual(self._get_login_flows(), expected_flows)

            # m.login.password is an unknown login type (400) and never reaches the
            # provider.
            resp = self._send_password_login("localuser", "localpass")
            self.assertEqual(resp.code, 400, resp.result)
            mock_password_provider.check_auth.assert_not_called()
    
    
        @override_config(
            dict(
                legacy_providers_config(LegacyPasswordCustomAuthProvider),
                password_config={"enabled": False},
            )
        )
        def test_password_custom_auth_password_disabled_login_legacy(self):
            self.password_custom_auth_password_disabled_login_test_body()

        @override_config(
            dict(
                providers_config(PasswordCustomAuthProvider),
                password_config={"enabled": False},
            )
        )
        def test_password_custom_auth_password_disabled_login(self):
            self.password_custom_auth_password_disabled_login_test_body()

        def password_custom_auth_password_disabled_login_test_body(self):
            """A provider that supports passwords can't re-enable disabled password login."""
            self.register_user("localuser", "localpass")

            # Only the custom type is advertised, even though the provider also
            # supports m.login.password.
            expected_flows = [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS
            self.assertEqual(self._get_login_flows(), expected_flows)

            # Password login is an unknown login type (400); neither hook runs.
            resp = self._send_password_login("localuser", "localpass")
            self.assertEqual(resp.code, 400, resp.result)
            mock_password_provider.check_auth.assert_not_called()
            mock_password_provider.check_password.assert_not_called()
    
        @override_config(
            {
                **legacy_providers_config(LegacyPasswordCustomAuthProvider),
                "password_config": {"enabled": False},
            }
        )
        def test_password_custom_auth_password_disabled_ui_auth_legacy(self):
            self.password_custom_auth_password_disabled_ui_auth_test_body()

        @override_config(
            {
                **providers_config(PasswordCustomAuthProvider),
                "password_config": {"enabled": False},
            }
        )
        def test_password_custom_auth_password_disabled_ui_auth(self):
            self.password_custom_auth_password_disabled_ui_auth_test_body()

        def password_custom_auth_password_disabled_ui_auth_test_body(self):
            """UI Auth with a custom auth provider which implements password, but password
            login is disabled"""
            # register the user and log in twice via the test login type to get two
            # devices. The provider must report localuser itself (`defer.succeed()`
            # requires a result), so that both devices belong to localuser and the final
            # UI-auth step authenticates the device owner.
            self.register_user("localuser", "localpass")
            mock_password_provider.check_auth.return_value = defer.succeed(
                ("@localuser:test", None)
            )
            channel = self._send_login("test.login_type", "localuser", test_field="")
            self.assertEqual(channel.code, 200, channel.result)
            tok1 = channel.json_body["access_token"]

            channel = self._send_login(
                "test.login_type", "localuser", test_field="", device_id="dev2"
            )
            self.assertEqual(channel.code, 200, channel.result)

            # make the initial request which returns a 401
            channel = self._delete_device(tok1, "dev2")
            self.assertEqual(channel.code, 401)
            # Ensure that flows are what is expected. In particular, "password" should *not*
            # be present.
            self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"])
            session = channel.json_body["session"]

            mock_password_provider.reset_mock()

            # check that auth with password is rejected
            body = {
                "auth": {
                    "type": "m.login.password",
                    "identifier": {"type": "m.id.user", "user": "localuser"},
                    "password": "localpass",
                    "session": session,
                },
            }

            channel = self._delete_device(tok1, "dev2", body)
            self.assertEqual(channel.code, 400)
            self.assertEqual(
                "Password login has been disabled.", channel.json_body["error"]
            )
            mock_password_provider.check_auth.assert_not_called()
            mock_password_provider.check_password.assert_not_called()
            mock_password_provider.reset_mock()

            # successful auth (reset_mock keeps the configured return_value)
            body["auth"]["type"] = "test.login_type"
            body["auth"]["test_field"] = "x"
            channel = self._delete_device(tok1, "dev2", body)
            self.assertEqual(channel.code, 200)
            mock_password_provider.check_auth.assert_called_once_with(
                "localuser", "test.login_type", {"test_field": "x"}
            )
            mock_password_provider.check_password.assert_not_called()
    
        @override_config(
            {
                **legacy_providers_config(LegacyCustomAuthProvider),
                "password_config": {"localdb_enabled": False},
            }
        )
        def test_custom_auth_no_local_user_fallback_legacy(self):
            self.custom_auth_no_local_user_fallback_test_body()

        @override_config(
            {
                **providers_config(CustomAuthProvider),
                "password_config": {"localdb_enabled": False},
            }
        )
        def test_custom_auth_no_local_user_fallback(self):
            self.custom_auth_no_local_user_fallback_test_body()

        def custom_auth_no_local_user_fallback_test_body(self):
            """Test login with a custom auth provider where the local db is disabled"""
            self.register_user("localuser", "localpass")

            # The local db is disabled and the provider offers no password support, so
            # only the custom login type is advertised.
            flows = self._get_login_flows()
            self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)

            # password login shouldn't work and should be rejected with a 400
            # ("unknown login type"), without ever reaching the provider's checker
            # (consistent with the other password-rejection bodies in this class).
            channel = self._send_password_login("localuser", "localpass")
            self.assertEqual(channel.code, 400, channel.result)
            mock_password_provider.check_auth.assert_not_called()
    
    
        def test_on_logged_out(self):
            """Tests that the on_logged_out callback is called when the user logs out."""
            self.register_user("rin", "password")
            tok = self.login("rin", "password")
    
            self.called = False
    
            async def on_logged_out(user_id, device_id, access_token):
                self.called = True
    
            on_logged_out = Mock(side_effect=on_logged_out)
            self.hs.get_password_auth_provider().on_logged_out_callbacks.append(
                on_logged_out
            )
    
            channel = self.make_request(
                "POST",
                "/_matrix/client/v3/logout",
                {},
                access_token=tok,
            )
            self.assertEqual(channel.code, 200)
            on_logged_out.assert_called_once()
            self.assertTrue(self.called)
    
    
        def test_username(self):
            """Tests that the get_username_for_registration callback can define the username
            of a user when registering.
            """
            self._setup_get_username_for_registration()
    
            username = "rin"
            channel = self.make_request(
                "POST",
                "/register",
                {
                    "username": username,
                    "password": "bar",
                    "auth": {"type": LoginType.DUMMY},
                },
            )
            self.assertEqual(channel.code, 200)
    
            # Our callback takes the username and appends "-foo" to it, check that's what we
            # have.
            mxid = channel.json_body["user_id"]
            self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
    
        def test_username_uia(self):
            """Tests that the get_username_for_registration callback is only called at the
            end of the UIA flow.
            """
            m = self._setup_get_username_for_registration()
    
            # Initiate the UIA flow.
            username = "rin"
            channel = self.make_request(
                "POST",
                "register",
                {"username": username, "type": "m.login.password", "password": "bar"},
            )
            self.assertEqual(channel.code, 401)
            self.assertIn("session", channel.json_body)
    
            # Check that the callback hasn't been called yet.
            m.assert_not_called()
    
            # Finish the UIA flow.
            session = channel.json_body["session"]
            channel = self.make_request(
                "POST",
                "register",
                {"auth": {"session": session, "type": LoginType.DUMMY}},
            )
            self.assertEqual(channel.code, 200, channel.json_body)
            mxid = channel.json_body["user_id"]
            self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo")
    
            # Check that the callback has been called.
            m.assert_called_once()
    
        def _setup_get_username_for_registration(self) -> Mock:
            """Registers a get_username_for_registration callback that appends "-foo" to the
            username the client is trying to register.
            """
    
            async def get_username_for_registration(uia_results, params):
                self.assertIn(LoginType.DUMMY, uia_results)
                username = params["username"]
                return username + "-foo"
    
            m = Mock(side_effect=get_username_for_registration)
    
            password_auth_provider = self.hs.get_password_auth_provider()
            password_auth_provider.get_username_for_registration_callbacks.append(m)
    
            return m
    
    
        def _get_login_flows(self) -> JsonDict:
    
            channel = self.make_request("GET", "/_matrix/client/r0/login")
    
            self.assertEqual(channel.code, 200, channel.result)
            return channel.json_body["flows"]
    
        def _send_password_login(self, user: str, password: str) -> FakeChannel:
            return self._send_login(type="m.login.password", user=user, password=password)
    
        def _send_login(self, type, user, **params) -> FakeChannel:
    
            params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type})
    
            channel = self.make_request("POST", "/_matrix/client/r0/login", params)
    
            return channel
    
        def _start_delete_device_session(self, access_token, device_id) -> str:
            """Make an initial delete device request, and return the UI Auth session ID"""
            channel = self._delete_device(access_token, device_id)
            self.assertEqual(channel.code, 401)
            # Ensure that flows are what is expected.
            self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
            return channel.json_body["session"]
    
        def _authed_delete_device(
            self,
            access_token: str,
            device_id: str,
            session: str,
            user_id: str,
            password: str,
        ) -> FakeChannel:
            """Make a delete device request, authenticating with the given uid/password"""
            return self._delete_device(
                access_token,
                device_id,
                {
                    "auth": {
                        "type": "m.login.password",
                        "identifier": {"type": "m.id.user", "user": user_id},
                        "password": password,
                        "session": session,
                    },
                },
            )
    
        def _delete_device(
    
            self,
            access_token: str,
            device: str,
            body: Union[JsonDict, bytes] = b"",
    
        ) -> FakeChannel:
            """Delete an individual device."""
    
                "DELETE", "devices/" + device, body, access_token=access_token
            )
            return channel