Skip to content
Snippets Groups Projects
test_login.py 51.2 KiB
Newer Older
  • Learn to ignore specific revisions
  • # -*- coding: utf-8 -*-
    # Copyright 2019-2021 The Matrix.org Foundation C.I.C.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    
    import urllib.parse
    from typing import Any, Dict, List, Union
    from unittest.mock import Mock

    import pymacaroons

    from twisted.web.resource import Resource

    import synapse.rest.admin
    from synapse.appservice import ApplicationService
    from synapse.rest.client.v1 import login, logout
    from synapse.rest.client.v2_alpha import devices, register
    from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
    from synapse.rest.synapse.client import build_synapse_client_resource_tree
    from synapse.types import create_requester

    from tests import unittest
    from tests.handlers.test_oidc import HAS_OIDC
    from tests.handlers.test_saml import has_saml2
    from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
    from tests.test_utils.html_parsers import TestHtmlParser
    from tests.unittest import HomeserverTestCase, override_config, skip_unless
    
    
    # PyJWT is optional: record whether it could be imported rather than failing
    # at import time.
    # NOTE(review): HAS_JWT is not read in the part of the file visible here;
    # presumably it gates JWT tests further down - confirm.
    try:
        import jwt

        HAS_JWT = True
    except ImportError:
        HAS_JWT = False


    # value given to the `public_baseurl` config option in MultiSSOTestCase
    BASE_URL = "https://synapse/"

    # CAS server used in some tests
    CAS_SERVER = "https://fake.test"

    # just enough to tell pysaml2 where to redirect to: a minimal IdP metadata
    # document whose only SingleSignOnService location is SAML_SERVER
    SAML_SERVER = "https://test.saml.server/idp/sso"
    TEST_SAML_METADATA = """
    <md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata">
      <md:IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
          <md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="%(SAML_SERVER)s"/>
      </md:IDPSSODescriptor>
    </md:EntityDescriptor>
    """ % {
        "SAML_SERVER": SAML_SERVER,
    }


    # the client-server login endpoint
    LOGIN_URL = b"/_matrix/client/r0/login"

    # an authenticated endpoint (whoami), used to check whether an access token
    # is still accepted
    TEST_URL = b"/_matrix/client/r0/account/whoami"

    # a (valid) url with some annoying characters in.  %3D is =, %26 is &, %2B is +
    TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"'

    # the query params in TEST_CLIENT_REDIRECT_URL, as (name, value) pairs after
    # URL-decoding (so the literal "+" has become a space and the %-escapes have
    # been resolved)
    EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')]


    # (possibly experimental) login flows we expect to appear in the list after the normal
    # ones
    ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
    
    
    
    class LoginRestServletTestCase(unittest.HomeserverTestCase):
    
        # Servlets registered for these tests: the admin client resource, login and
        # logout, the devices API (used by _delete_device), and the whoami servlet
        # that answers TEST_URL. WhoamiRestServlet is registered by instantiating
        # it directly, hence the lambda adapter.
        servlets = [

            synapse.rest.admin.register_servlets_for_client_rest_resource,

            login.register_servlets,

            logout.register_servlets,

            devices.register_servlets,
            lambda hs, http_server: WhoamiRestServlet(hs).register(http_server),

        ]
    
        def make_homeserver(self, reactor, clock):
            """Create the test homeserver with registration switched on.

            The server is stored on ``self.hs``; its config is adjusted so that
            registration is enabled with no required 3pids, no auto-joined rooms
            and no captcha.

            Returns:
                The configured homeserver.
            """
            homeserver = self.setup_test_homeserver()
            self.hs = homeserver

            settings = homeserver.config
            settings.enable_registration = True
            settings.registrations_require_3pid = []
            settings.auto_join_rooms = []
            settings.enable_registration_captcha = False

            return homeserver
    
    
        @override_config(
            {
                "rc_login": {
                    # the limiter under test
                    "address": {"per_second": 0.17, "burst_count": 5},
                    # Keep the per-account limiter out of the way. The default test
                    # homeserver config normally sets these values to 10000, but
                    # overriding rc_login replaces the whole dict, so it has to be
                    # repeated here.
                    "account": {"per_second": 10000, "burst_count": 10000},
                }
            }
        )
        def test_POST_ratelimiting_per_address(self):
            """The sixth login from one address is rejected with a 429, and a
            login is accepted again once the advertised delay has passed.
            """
            # One account per attempt, so the per-user ratelimiter never fires.
            usernames = ["kermit" + str(n) for n in range(0, 6)]
            for name in usernames:
                self.register_user(name, "monkey")

            def password_login(name):
                # Issue a password login for `name` and hand back the channel.
                body = {
                    "type": "m.login.password",
                    "identifier": {"type": "m.id.user", "user": name},
                    "password": "monkey",
                }
                return self.make_request(b"POST", LOGIN_URL, body)

            # The burst allowance covers the first five requests ...
            for name in usernames[:-1]:
                channel = password_login(name)
                self.assertEquals(channel.result["code"], b"200", channel.result)

            # ... and the sixth one is throttled.
            last_user = usernames[-1]
            channel = password_login(last_user)
            self.assertEquals(channel.result["code"], b"429", channel.result)
            retry_after_ms = int(channel.json_body["retry_after_ms"])

            # 0.17 requests/second is about one request every 5.9s, so the server
            # should never ask us to wait as long as 6s.
            self.assertTrue(retry_after_ms < 6000)

            # Wait out the delay (plus a second of slack) and try again.
            self.reactor.advance(retry_after_ms / 1000.0 + 1.0)

            channel = password_login(last_user)
            self.assertEquals(channel.result["code"], b"200", channel.result)
    
    
        @override_config(
            {
                "rc_login": {
                    # the limiter under test
                    "account": {"per_second": 0.17, "burst_count": 5},
                    # Prevent the address login ratelimiter from raising first
                    #
                    # This is normally covered by the default test homeserver config
                    # which sets these values to 10000, but as we're overriding the entire
                    # rc_login dict here, we need to set this manually as well
                    "address": {"per_second": 10000, "burst_count": 10000},
                }
            }
        )
        def test_POST_ratelimiting_per_account(self):
            """The sixth login to one account is rejected with a 429, and a login
            is accepted again once the advertised delay has passed.
            """
            self.register_user("kermit", "monkey")

            params = {
                "type": "m.login.password",
                "identifier": {"type": "m.id.user", "user": "kermit"},
                "password": "monkey",
            }

            for i in range(0, 6):
                channel = self.make_request(b"POST", LOGIN_URL, params)

                if i == 5:
                    self.assertEquals(channel.result["code"], b"429", channel.result)
                    retry_after_ms = int(channel.json_body["retry_after_ms"])
                else:
                    self.assertEquals(channel.result["code"], b"200", channel.result)

            # 0.17 requests/second is about one request every 5.9s, so the server
            # should never ask us to wait as long as 6s.
            self.assertTrue(retry_after_ms < 6000)

            # retry_after_ms has been truncated to an integer, so advancing by
            # exactly that amount could stop marginally short of the moment the
            # limiter reopens. Add a second of slack, as the sibling per-address
            # and failed-attempts tests do.
            self.reactor.advance(retry_after_ms / 1000.0 + 1.0)

            channel = self.make_request(b"POST", LOGIN_URL, params)

            self.assertEquals(channel.result["code"], b"200", channel.result)
    
        @override_config(
            {
                "rc_login": {
                    # Keep the per-address limiter out of the way. The default test
                    # homeserver config normally sets these values to 10000, but
                    # overriding rc_login replaces the whole dict, so it has to be
                    # repeated here.
                    "address": {"per_second": 10000, "burst_count": 10000},
                    # the limiter under test
                    "failed_attempts": {"per_second": 0.17, "burst_count": 5},
                }
            }
        )
        def test_POST_ratelimiting_per_account_failed_attempts(self):
            """Five wrong-password logins get a 403, the sixth a 429; after the
            advertised delay a wrong password is answered with a 403 again.
            """
            self.register_user("kermit", "monkey")

            def bad_password_login():
                # Attempt a login for kermit with the wrong password.
                body = {
                    "type": "m.login.password",
                    "identifier": {"type": "m.id.user", "user": "kermit"},
                    "password": "notamonkey",
                }
                return self.make_request(b"POST", LOGIN_URL, body)

            # The burst allowance covers five failures ...
            for _ in range(5):
                channel = bad_password_login()
                self.assertEquals(channel.result["code"], b"403", channel.result)

            # ... after which the account is throttled.
            channel = bad_password_login()
            self.assertEquals(channel.result["code"], b"429", channel.result)
            retry_after_ms = int(channel.json_body["retry_after_ms"])

            # 0.17 requests/second is about one request every 5.9s, so the server
            # should never ask us to wait as long as 6s.
            self.assertTrue(retry_after_ms < 6000)

            # Wait out the delay (plus a second of slack) and try again.
            self.reactor.advance(retry_after_ms / 1000.0 + 1.0)

            channel = bad_password_login()
            self.assertEquals(channel.result["code"], b"403", channel.result)
    
    
        @override_config({"session_lifetime": "24h"})
        def test_soft_logout(self):
            """An access token older than the session lifetime is soft-logged-out,
            and becomes a hard logout once its device has been deleted.
            """
            self.register_user("kermit", "monkey")

            # Without an access token the request must be refused.
            channel = self.make_request(b"GET", TEST_URL)
            self.assertEquals(channel.result["code"], b"401", channel.result)
            self.assertEquals(channel.json_body["errcode"], "M_MISSING_TOKEN")

            # Log in through the API so that we learn the device id as well.
            login_body = {
                "type": "m.login.password",
                "identifier": {"type": "m.id.user", "user": "kermit"},
                "password": "monkey",
            }
            channel = self.make_request(b"POST", LOGIN_URL, login_body)
            self.assertEquals(channel.code, 200, channel.result)
            access_token = channel.json_body["access_token"]
            device_id = channel.json_body["device_id"]

            def assert_token_rejected(soft_logout):
                # The original token must be refused as unknown, with the given
                # value of the soft_logout flag.
                resp = self.make_request(b"GET", TEST_URL, access_token=access_token)
                self.assertEquals(resp.code, 401, resp.result)
                self.assertEquals(resp.json_body["errcode"], "M_UNKNOWN_TOKEN")
                self.assertEquals(resp.json_body["soft_logout"], soft_logout)

            # The fresh token works ...
            channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
            self.assertEquals(channel.code, 200, channel.result)

            # ... until the session lifetime has elapsed.
            self.reactor.advance(24 * 3600)
            assert_token_rejected(True)

            #
            # behaviour after deleting the expired device
            #

            # Log in again, as a different device.
            access_token_2 = self.login("kermit", "monkey")

            # Later requests with the expired token are still a soft logout ...
            self.reactor.advance(3600)
            assert_token_rejected(True)

            # ... but once that device is deleted, it is a proper logout.
            self._delete_device(access_token_2, "kermit", "monkey", device_id)
            assert_token_rejected(False)
    
        def _delete_device(self, access_token, user_id, password, device_id):
            """Perform the UI-Auth to delete a device

            Args:
                access_token: token used to authenticate both DELETE requests
                user_id: user to name in the password auth stage
                password: that user's password
                device_id: ID of the device to delete
            """
            # The first attempt carries no auth data, so it should be bounced.
            channel = self.make_request(
                b"DELETE", "devices/" + device_id, access_token=access_token
            )
            self.assertEquals(channel.code, 401, channel.result)
            # check it's a UI-Auth fail
            self.assertEqual(
                set(channel.json_body.keys()),
                {"flows", "params", "session"},
                channel.result,
            )

            auth = {
                "type": "m.login.password",
                # https://github.com/matrix-org/synapse/issues/5665
                # "identifier": {"type": "m.id.user", "user": user_id},
                "user": user_id,
                "password": password,
                "session": channel.json_body["session"],
            }

            # Retry with the password stage filled in for the session we were given.
            channel = self.make_request(
                b"DELETE",
                "devices/" + device_id,
                access_token=access_token,
                content={"auth": auth},
            )
            self.assertEquals(channel.code, 200, channel.result)
    
        @override_config({"session_lifetime": "24h"})
        def test_session_can_hard_logout_after_being_soft_logged_out(self):
            """A soft-logged-out token can still be used to call /logout."""
            self.register_user("kermit", "monkey")
            token = self.login("kermit", "monkey")

            # The fresh token is accepted.
            whoami = self.make_request(b"GET", TEST_URL, access_token=token)
            self.assertEquals(whoami.code, 200, whoami.result)

            # Let the session lifetime run out: the token is now soft-logged-out.
            self.reactor.advance(24 * 3600)
            whoami = self.make_request(b"GET", TEST_URL, access_token=token)
            self.assertEquals(whoami.code, 401, whoami.result)
            self.assertEquals(whoami.json_body["errcode"], "M_UNKNOWN_TOKEN")
            self.assertEquals(whoami.json_body["soft_logout"], True)

            # A hard logout of this session must nevertheless succeed.
            logout_channel = self.make_request(b"POST", "/logout", access_token=token)
            self.assertEquals(
                logout_channel.result["code"], b"200", logout_channel.result
            )
    
        @override_config({"session_lifetime": "24h"})
        def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self):
            """A soft-logged-out token can still be used to call /logout/all."""
            self.register_user("kermit", "monkey")
            token = self.login("kermit", "monkey")

            # The fresh token is accepted.
            whoami = self.make_request(b"GET", TEST_URL, access_token=token)
            self.assertEquals(whoami.code, 200, whoami.result)

            # Let the session lifetime run out: the token is now soft-logged-out.
            self.reactor.advance(24 * 3600)
            whoami = self.make_request(b"GET", TEST_URL, access_token=token)
            self.assertEquals(whoami.code, 401, whoami.result)
            self.assertEquals(whoami.json_body["errcode"], "M_UNKNOWN_TOKEN")
            self.assertEquals(whoami.json_body["soft_logout"], True)

            # Hard-logging-out every session of the user must nevertheless succeed.
            logout_channel = self.make_request(
                b"POST", "/logout/all", access_token=token
            )
            self.assertEquals(
                logout_channel.result["code"], b"200", logout_channel.result
            )
    
    
    @skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
    class MultiSSOTestCase(unittest.HomeserverTestCase):
        """Tests for homeservers with multiple SSO providers enabled"""
    
        # Only the login servlets are registered here; the /_synapse/client
        # resources are added separately in create_resource_dict.
        servlets = [
            login.register_servlets,
        ]
    
        def default_config(self) -> Dict[str, Any]:
            """Extend the default config with CAS, SAML and two OIDC providers.

            Returns:
                The base test config with `public_baseurl`, `cas_config`,
                `saml2_config`, `oidc_config` and `oidc_providers` set.
            """
            saml_sp_config = {
                "metadata": {"inline": [TEST_SAML_METADATA]},
                # use the XMLSecurity backend to avoid relying on xmlsec1
                "crypto_backend": "XMLSecurity",
            }

            # a second OIDC provider, in addition to the one in oidc_config
            extra_oidc_provider = {
                "idp_id": "idp1",
                "idp_name": "IDP1",
                "discover": False,
                "issuer": "https://issuer1",
                "client_id": "test-client-id",
                "client_secret": "test-client-secret",
                "scopes": ["profile"],
                "authorization_endpoint": "https://issuer1/auth",
                "token_endpoint": "https://issuer1/token",
                "userinfo_endpoint": "https://issuer1/userinfo",
                "user_mapping_provider": {
                    "config": {"localpart_template": "{{ user.sub }}"}
                },
            }

            config = super().default_config()
            config.update(
                {
                    "public_baseurl": BASE_URL,
                    "cas_config": {
                        "enabled": True,
                        "server_url": CAS_SERVER,
                        "service_url": "https://matrix.goodserver.com:8448",
                    },
                    "saml2_config": {"sp_config": saml_sp_config},
                    "oidc_config": TEST_OIDC_CONFIG,
                    "oidc_providers": [extra_oidc_provider],
                }
            )
            return config
    
        def create_resource_dict(self) -> Dict[str, Resource]:
            """Build the resource tree for the test server.

            Returns:
                The default resources, extended with those produced by
                `build_synapse_client_resource_tree` for this homeserver.
            """
            d = super().create_resource_dict()

            d.update(build_synapse_client_resource_tree(self.hs))

            # Hand the merged mapping back: without this the method returns None,
            # contradicting its annotation.
            return d
    
        def test_get_login_flows(self):
            """GET /login should list the CAS, SSO, token and password flows,
            plus the additional flows.
            """
            response = self.make_request("GET", "/_matrix/client/r0/login")
            self.assertEqual(response.code, 200, response.result)

            standard_types = (
                "m.login.cas",
                "m.login.sso",
                "m.login.token",
                "m.login.password",
            )
            wanted = [{"type": flow_type} for flow_type in standard_types]
            wanted = wanted + ADDITIONAL_LOGIN_FLOWS

            # order is irrelevant, but every flow must appear exactly once
            self.assertCountEqual(response.json_body["flows"], wanted)
    
        @override_config({"experimental_features": {"msc2858_enabled": True}})
        def test_get_msc2858_login_flows(self):
            """With MSC2858 enabled, the m.login.sso flow lists the identity
            providers; the remaining flows are unchanged.
            """
            response = self.make_request("GET", "/_matrix/client/r0/login")
            self.assertEqual(response.code, 200, response.result)

            # Index the returned flows by type, insisting that each type is unique.
            by_type = {}  # type: Dict[str, Any]
            for flow in response.json_body["flows"]:
                kind = flow["type"]
                self.assertNotIn(kind, by_type, "duplicate flow type %s" % (kind,))
                by_type[kind] = flow

            self.assertIn("m.login.sso", by_type, "m.login.sso was not returned")

            # The SSO flow is taken out of the index and checked for its IdP list.
            sso_flow = by_type.pop("m.login.sso")
            expected_idps = [
                {"id": "cas", "name": "CAS"},
                {"id": "saml", "name": "SAML"},
                {"id": "oidc-idp1", "name": "IDP1"},
                {"id": "oidc", "name": "OIDC"},
            ]
            self.assertCountEqual(
                sso_flow["org.matrix.msc2858.identity_providers"], expected_idps
            )

            # Whatever is left should be the plain flows.
            plain_flows = [
                {"type": "m.login.cas"},
                {"type": "m.login.token"},
                {"type": "m.login.password"},
            ]
            self.assertCountEqual(by_type.values(), plain_flows + ADDITIONAL_LOGIN_FLOWS)
    
    
        def test_multi_sso_redirect(self):
            """/login/sso/redirect should redirect to an identity picker"""
            # first hit the redirect url, which should redirect to our idp picker
            channel = self.make_request(
                "GET",
                "/_matrix/client/r0/login/sso/redirect?redirectUrl="
                + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
            )
            self.assertEqual(channel.code, 302, channel.result)
            uri = channel.headers.getRawHeaders("Location")[0]

            # hitting that picker should give us some HTML
            channel = self.make_request("GET", uri)
            self.assertEqual(channel.code, 200, channel.result)

            # parse the form to check it has fields assumed elsewhere in this class
            html = channel.result["body"].decode("utf-8")
            # The parser has to be created before it can be fed, and is closed
            # afterwards, as is done for the OIDC confirmation page in this class.
            p = TestHtmlParser()
            p.feed(html)
            p.close()

            # there should be a link for each href
            returned_idps = []  # type: List[str]
            for link in p.links:
                path, query = link.split("?", 1)
                self.assertEqual(path, "pick_idp")
                params = urllib.parse.parse_qs(query)
                # every link must carry the client's redirect URL through unchanged
                self.assertEqual(params["redirectUrl"], [TEST_CLIENT_REDIRECT_URL])
                returned_idps.append(params["idp"][0])

            self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"])
    
    
        def test_multi_sso_redirect_to_cas(self):
            """Picking the CAS IdP should redirect to the CAS server's login page,
            with the client redirect URL embedded in the `service` parameter.
            """
            quoted_redirect = urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
            picker_url = (
                "/_synapse/client/pick_idp?redirectUrl=" + quoted_redirect + "&idp=cas"
            )

            response = self.make_request("GET", picker_url, shorthand=False)
            self.assertEqual(response.code, 302, response.result)

            location = response.headers.getRawHeaders("Location")[0]
            location_path, location_query = location.split("?", 1)

            # we should be sent to the login page of the CAS server
            self.assertEqual(location_path, CAS_SERVER + "/login")

            # The `service` param is the place CAS will redirect back to; the
            # client's redirectUrl must be correctly encoded inside it.
            service_uri = urllib.parse.parse_qs(location_query)["service"][0]
            _, service_query = service_uri.split("?", 1)
            service_params = urllib.parse.parse_qs(service_query)

            self.assertEqual(service_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL)
    
    
        def test_multi_sso_redirect_to_saml(self):
            """If SAML is chosen, should redirect to the SAML server"""
            channel = self.make_request(
                "GET",
                "/_synapse/client/pick_idp?redirectUrl="
                + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
                + "&idp=saml",
            )
            self.assertEqual(channel.code, 302, channel.result)
            saml_uri = channel.headers.getRawHeaders("Location")[0]
            saml_uri_path, saml_uri_query = saml_uri.split("?", 1)

            # it should redirect us to the login page of the SAML server
            self.assertEqual(saml_uri_path, SAML_SERVER)

            # the RelayState is used to carry the client redirect url
            saml_uri_params = urllib.parse.parse_qs(saml_uri_query)
            relay_state_param = saml_uri_params["RelayState"][0]

            self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL)

        def test_multi_sso_redirect_to_oidc(self):
            """If OIDC is chosen, should redirect to the OIDC auth endpoint"""

            channel = self.make_request(
                "GET",
                "/_synapse/client/pick_idp?redirectUrl="
                + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
                + "&idp=oidc",
            )
            self.assertEqual(channel.code, 302, channel.result)
            oidc_uri = channel.headers.getRawHeaders("Location")[0]
            oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)

            # it should redirect us to the auth page of the OIDC server
            self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)

            # ... and should have set a cookie including the redirect url.
            # Each Set-Cookie header is reduced to its leading "name=value" pair.
            cookies = dict(
                h.split(";")[0].split("=", maxsplit=1)
                for h in channel.headers.getRawHeaders("Set-Cookie")
            )

            oidc_session_cookie = cookies["oidc_session"]
            macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie)
            # the session macaroon must remember where the client wants to end up
            self.assertEqual(
                self._get_value_from_macaroon(macaroon, "client_redirect_url"),
                TEST_CLIENT_REDIRECT_URL,
            )

            channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})

            # that should serve a confirmation page
            self.assertEqual(channel.code, 200, channel.result)
            self.assertTrue(
                channel.headers.getRawHeaders("Content-Type")[-1].startswith("text/html")
            )
            p = TestHtmlParser()
            p.feed(channel.text_body)
            p.close()

            # ... which should contain our redirect link
            self.assertEqual(len(p.links), 1)
            path, query = p.links[0].split("?", 1)
            self.assertEqual(path, "https://x")

            # it will have url-encoded the params properly, so we'll have to parse them
            params = urllib.parse.parse_qsl(
                query, keep_blank_values=True, strict_parsing=True, errors="strict"
            )
            self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
            self.assertEqual(params[2][0], "loginToken")

            # finally, submit the matrix login token to the login API, which gives us our
            # matrix access token, mxid, and device id.
            login_token = params[2][1]
            chan = self.make_request(
                "POST",
                "/login",
                content={"type": "m.login.token", "token": login_token},
            )
            self.assertEqual(chan.code, 200, chan.result)
            self.assertEqual(chan.json_body["user_id"], "@user1:test")
    
    
        def test_multi_sso_redirect_to_unknown(self):
            """Picking an IdP that is not configured should be rejected with a 400."""
            picker_url = "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz"
            response = self.make_request("GET", picker_url)
            self.assertEqual(response.code, 400, response.result)
    
    
        def test_client_idp_redirect_msc2858_disabled(self):
            """With MSC2858 disabled, the per-IdP redirect endpoint should answer
            400 / M_UNRECOGNIZED.
            """
            quoted_redirect = urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
            endpoint = (
                "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
                + quoted_redirect
            )

            response = self.make_request("GET", endpoint)

            self.assertEqual(response.code, 400, response.result)
            self.assertEqual(response.json_body["errcode"], "M_UNRECOGNIZED")
    
        @override_config({"experimental_features": {"msc2858_enabled": True}})
        def test_client_idp_redirect_to_unknown(self):
            """With MSC2858 enabled, asking for an unknown IdP should answer
            404 / M_NOT_FOUND.
            """
            quoted_redirect = urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
            endpoint = (
                "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl="
                + quoted_redirect
            )

            response = self.make_request("GET", endpoint)

            self.assertEqual(response.code, 404, response.result)
            self.assertEqual(response.json_body["errcode"], "M_NOT_FOUND")
    
        @override_config({"experimental_features": {"msc2858_enabled": True}})
        def test_client_idp_redirect_to_oidc(self):
            """With MSC2858 enabled, asking for the known `oidc` IdP should
            redirect to the OIDC auth endpoint.
            """
            quoted_redirect = urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
            endpoint = (
                "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
                + quoted_redirect
            )

            response = self.make_request("GET", endpoint)
            self.assertEqual(response.code, 302, response.result)

            # Only the part of the Location before the query string is compared.
            location = response.headers.getRawHeaders("Location")[0]
            location_path, _location_query = location.split("?", 1)
            self.assertEqual(location_path, TEST_OIDC_AUTH_ENDPOINT)
    
    
        @staticmethod
        def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
            """Extract the value of the first ``<key> = <value>`` caveat.

            Args:
                macaroon: the macaroon whose caveats are searched, in order
                key: the caveat name to look for

            Returns:
                The text following ``<key> = `` in the first matching caveat.

            Raises:
                ValueError: if no caveat starts with ``<key> = ``.
            """
            marker = key + " = "
            matching_ids = (
                caveat.caveat_id
                for caveat in macaroon.caveats
                if caveat.caveat_id.startswith(marker)
            )
            found = next(matching_ids, None)
            if found is None:
                raise ValueError("No %s caveat in macaroon" % (key,))
            return found[len(marker) :]
    
    
    
    class CASTestCase(unittest.HomeserverTestCase):
    
    
        # The CAS tests in this class only exercise the login endpoints.
        servlets = [
            login.register_servlets,
        ]
    
        def make_homeserver(self, reactor, clock):
            """Create a homeserver with CAS enabled and a stubbed-out CAS server.

            The proxied HTTP client is replaced by a mock whose `get_raw` returns a
            canned successful validation response for the user "username".

            Returns:
                The configured homeserver (also stored on ``self.hs``).
            """
            self.base_url = "https://matrix.goodserver.com/"
            self.redirect_path = "_synapse/client/login/sso/redirect/confirm"

            config = self.default_config()

            # keep a public_baseurl set by the test (e.g. via override_config)
            config["public_baseurl"] = (
                config.get("public_baseurl") or "https://matrix.goodserver.com:8448"
            )

            config["cas_config"] = {
                "enabled": True,
                "server_url": CAS_SERVER,
            }

            cas_user_id = "username"
            self.user_id = "@%s:test" % cas_user_id

            async def get_raw(uri, args):
                """Return an example response payload from a call to the `/proxyValidate`
                endpoint of a CAS server, copied from
                https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20

                This needs to be returned by an async function (as opposed to set as the
                mock's return value) because the corresponding Synapse code awaits on it.
                """
                # NOTE(review): the payload is returned as UTF-8 bytes and names
                # cas_user_id in <cas:user>, on the assumption that the CAS handler
                # parses a raw response body for the authenticated user - confirm.
                return (
                    """
                    <cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'>
                      <cas:authenticationSuccess>
                          <cas:user>%s</cas:user>
                          <cas:proxyGrantingTicket>PGTIOU-84678-8a9d...</cas:proxyGrantingTicket>
                          <cas:proxies>
                              <cas:proxy>https://proxy2/pgtUrl</cas:proxy>
                              <cas:proxy>https://proxy1/pgtUrl</cas:proxy>
                          </cas:proxies>
                      </cas:authenticationSuccess>
                    </cas:serviceResponse>
                """
                    % cas_user_id
                ).encode("utf-8")

            # spec= restricts the mock to the single method the test stubs out
            mocked_http_client = Mock(spec=["get_raw"])
            mocked_http_client.get_raw.side_effect = get_raw

            self.hs = self.setup_test_homeserver(
                config=config,
                proxied_http_client=mocked_http_client,
            )

            return self.hs
    
        def prepare(self, reactor, clock, hs):
            """Keep a reference to the homeserver's deactivate-account handler."""
            handler = hs.get_deactivate_account_handler()
            self.deactivate_account_handler = handler
    
    
        def test_cas_redirect_confirm(self):
            """The SSO login flow should serve an HTML confirmation page containing
            the redirect URL, rather than redirecting the user straight to it.
            """
            target = "https://dodgy-site.com/"

            # Assemble the CAS ticket URL, with the redirect URL and a ticket in
            # its query string.
            parts = list(
                urllib.parse.urlparse("/_matrix/client/r0/login/cas/ticket?redirectUrl")
            )
            query_args = dict(urllib.parse.parse_qsl(parts[4]))
            query_args["redirectUrl"] = target
            query_args["ticket"] = "ticket"
            parts[4] = urllib.parse.urlencode(query_args)
            ticket_url = urllib.parse.urlunparse(parts)

            # Get Synapse to call the fake CAS and serve the template.
            response = self.make_request("GET", ticket_url)
            self.assertEqual(response.code, 200, response.result)

            # The response must be HTML. If Content-Type appears more than once,
            # the last occurrence is the one that is checked.
            content_types = [
                header[1].decode("utf8")
                for header in response.result.get("headers", [])
                if header[0] == b"Content-Type"
            ]
            content_type = content_types[-1] if content_types else ""
            self.assertTrue(content_type.startswith("text/html"))

            # The body must not be empty ...
            body = response.result["body"]
            self.assertTrue(len(body) > 0)

            # ... and must contain our redirect link.
            self.assertIn(target, body.decode("UTF-8"))
    
    
        @override_config(
            {
                "sso": {
                    "client_whitelist": [
                        "https://legit-site.com/",
                        "https://other-site.com/",
                    ]
                }
            }
        )
        def test_cas_redirect_whitelisted(self):
            """Check that a URL listed in `sso.client_whitelist` gets a redirect."""
            whitelisted_url = "https://legit-site.com/"
            self._test_redirect(whitelisted_url)
    
        @override_config({"public_baseurl": "https://example.com"})
        def test_cas_redirect_login_fallback(self):
            """Check that the login page under `public_baseurl` gets a redirect.

            `public_baseurl` is overridden to `https://example.com` and the redirect
            target is the `/_matrix/static/client/login` page on that base URL.
            """
            login_fallback_url = "https://example.com/_matrix/static/client/login"
            self._test_redirect(login_fallback_url)
    
        def _test_redirect(self, redirect_url):
            """Assert that a CAS ticket request is answered with a redirect.

            Requests the CAS ticket endpoint with `redirect_url` (percent-encoded) and
            a dummy ticket, then checks for a 302 response whose first `Location`
            header begins with `redirect_url`.
            """
            quoted_redirect = urllib.parse.quote(redirect_url)
            cas_ticket_url = (
                "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket"
                % quoted_redirect
            )

            # Get Synapse to call the fake CAS and serve the template.
            channel = self.make_request("GET", cas_ticket_url)
            self.assertEqual(channel.code, 302)

            # Only the leading part of the header is compared, so anything that
            # follows the redirect URL in the Location value is ignored.
            first_location = channel.headers.getRawHeaders("Location")[0]
            prefix_length = len(redirect_url)
            self.assertEqual(first_location[:prefix_length], redirect_url)
    
    
        @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}})
        def test_deactivated_user(self):
            """Check that a CAS login for a deactivated account is rejected."""
            redirect_url = "https://legit-site.com/"

            # A first login is performed so that the user exists.
            self._test_redirect(redirect_url)

            # Deactivate the account, using a requester built from the same user ID.
            requester = create_requester(self.user_id)
            deactivation = self.deactivate_account_handler.deactivate_account(
                self.user_id, False, requester
            )
            self.get_success(deactivation)

            # Request the CAS ticket again, now that the account is deactivated.
            quoted_redirect = urllib.parse.quote(redirect_url)
            cas_ticket_url = (
                "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket"
                % quoted_redirect
            )

            # Get Synapse to call the fake CAS and serve the template.
            channel = self.make_request("GET", cas_ticket_url)

            # Because the user is deactivated they are served an error template.
            self.assertEqual(channel.code, 403)
            self.assertIn(b"SSO account deactivated", channel.result["body"])
    
    @skip_unless(HAS_JWT, "requires jwt")
    class JWTTestCase(unittest.HomeserverTestCase):
        """Tests for the `org.matrix.login.jwt` login type with a shared secret.

        Test tokens are signed with `jwt_secret` using `jwt_algorithm` and posted
        to `LOGIN_URL`.
        """

        servlets = [
            synapse.rest.admin.register_servlets_for_client_rest_resource,
            login.register_servlets,
        ]

        jwt_secret = "secret"
        # Algorithm used to sign the test tokens and configured on the homeserver.
        # It must be defined in the class body: the `override_config` decorators
        # below read it by bare name while the class is being created.
        jwt_algorithm = "HS256"

        def make_homeserver(self, reactor, clock):
            """Create a homeserver with JWT login enabled and return it."""
            self.hs = self.setup_test_homeserver()
            self.hs.config.jwt_enabled = True
            self.hs.config.jwt_secret = self.jwt_secret
            self.hs.config.jwt_algorithm = self.jwt_algorithm
            # NOTE(review): the harness is expected to use the returned homeserver;
            # without this return the method would yield None.
            return self.hs

        def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str:
            """Sign `payload` as a JWT and return it as a string.

            Args:
                payload: the claims to put in the token.
                secret: the signing secret; defaults to the class's `jwt_secret`.

            Returns:
                The encoded token as `str`.
            """
            # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
            result = jwt.encode(
                payload, secret, self.jwt_algorithm
            )  # type: Union[str, bytes]
            if isinstance(result, bytes):
                return result.decode("ascii")
            return result

        def jwt_login(self, *args):
            """Attempt a JWT login and return the resulting channel.

            Args:
                *args: forwarded unchanged to `jwt_encode` (payload and, optionally,
                    the signing secret).
            """
            params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
            channel = self.make_request(b"POST", LOGIN_URL, params)
            return channel

        def test_login_jwt_valid_registered(self):
            """A valid token for an already-registered user logs in as that user."""
            self.register_user("kermit", "monkey")
            channel = self.jwt_login({"sub": "kermit"})
            self.assertEqual(channel.result["code"], b"200", channel.result)
            self.assertEqual(channel.json_body["user_id"], "@kermit:test")

        def test_login_jwt_valid_unregistered(self):
            """A valid token for a user that was not registered first also succeeds."""
            channel = self.jwt_login({"sub": "frog"})
            self.assertEqual(channel.result["code"], b"200", channel.result)
            self.assertEqual(channel.json_body["user_id"], "@frog:test")

        def test_login_jwt_invalid_signature(self):
            """A token signed with the wrong secret is rejected."""
            channel = self.jwt_login({"sub": "frog"}, "notsecret")
            self.assertEqual(channel.result["code"], b"403", channel.result)
            self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
            self.assertEqual(
                channel.json_body["error"],
                "JWT validation failed: Signature verification failed",
            )

        def test_login_jwt_expired(self):
            """A token whose `exp` claim lies in the past is rejected."""
            channel = self.jwt_login({"sub": "frog", "exp": 864000})
            self.assertEqual(channel.result["code"], b"403", channel.result)
            self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
            self.assertEqual(
                channel.json_body["error"], "JWT validation failed: Signature has expired"
            )

        def test_login_jwt_not_before(self):
            """A token whose `nbf` claim lies an hour in the future is rejected."""
            now = int(time.time())
            channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
            self.assertEqual(channel.result["code"], b"403", channel.result)
            self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
            self.assertEqual(
                channel.json_body["error"],
                "JWT validation failed: The token is not yet valid (nbf)",
            )

        def test_login_no_sub(self):
            """A token without a `sub` claim is rejected as an invalid JWT."""
            channel = self.jwt_login({"username": "root"})
            self.assertEqual(channel.result["code"], b"403", channel.result)
            self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
            self.assertEqual(channel.json_body["error"], "Invalid JWT")

        @override_config(
            {
                "jwt_config": {
                    "jwt_enabled": True,
                    "secret": jwt_secret,
                    "algorithm": jwt_algorithm,
                    "issuer": "test-issuer",
                }
            }
        )
        def test_login_iss(self):
            """Test validating the issuer claim."""
            # A valid issuer.
            channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
            self.assertEqual(channel.result["code"], b"200", channel.result)
            self.assertEqual(channel.json_body["user_id"], "@kermit:test")

            # An invalid issuer.
            channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
            self.assertEqual(channel.result["code"], b"403", channel.result)
            self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
            self.assertEqual(
                channel.json_body["error"], "JWT validation failed: Invalid issuer"
            )

            # Not providing an issuer.
            channel = self.jwt_login({"sub": "kermit"})
            self.assertEqual(channel.result["code"], b"403", channel.result)
            self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
            self.assertEqual(
                channel.json_body["error"],
                'JWT validation failed: Token is missing the "iss" claim',
            )

        def test_login_iss_no_config(self):
            """Test providing an issuer claim without requiring it in the configuration."""
            channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
            self.assertEqual(channel.result["code"], b"200", channel.result)
            self.assertEqual(channel.json_body["user_id"], "@kermit:test")

        @override_config(
            {
                "jwt_config": {
                    "jwt_enabled": True,
                    "secret": jwt_secret,
                    "algorithm": jwt_algorithm,
                    "audiences": ["test-audience"],
                }
            }
        )
        def test_login_aud(self):
            """Test validating the audience claim."""
            # A valid audience.
            channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
            self.assertEqual(channel.result["code"], b"200", channel.result)
            self.assertEqual(channel.json_body["user_id"], "@kermit:test")

            # An invalid audience.
            channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
            self.assertEqual(channel.result["code"], b"403", channel.result)
            self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
            self.assertEqual(
                channel.json_body["error"], "JWT validation failed: Invalid audience"
            )

            # Not providing an audience.
            channel = self.jwt_login({"sub": "kermit"})
            self.assertEqual(channel.result["code"], b"403", channel.result)
            self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
            self.assertEqual(
                channel.json_body["error"],
                'JWT validation failed: Token is missing the "aud" claim',
            )

        def test_login_aud_no_config(self):
            """Test providing an audience without requiring it in the configuration."""
            channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
            self.assertEqual(channel.result["code"], b"403", channel.result)
            self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
            self.assertEqual(
                channel.json_body["error"], "JWT validation failed: Invalid audience"
            )

        def test_login_no_token(self):
            """A JWT login request without a `token` field is rejected."""
            params = {"type": "org.matrix.login.jwt"}
            channel = self.make_request(b"POST", LOGIN_URL, params)
            self.assertEqual(channel.result["code"], b"403", channel.result)
            self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
            self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
    
    
    # The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use
    # RS256, with a public key configured in synapse as "jwt_secret", and tokens
    # signed by the private key.
    
    @skip_unless(HAS_JWT, "requires jwt")
    
    class JWTPubKeyTestCase(unittest.HomeserverTestCase):
        servlets = [
            login.register_servlets,
        ]