Skip to content
Snippets Groups Projects
test_login.py 50.7 KiB
Newer Older
  • Learn to ignore specific revisions
  •             "tqBR7qLZGJ5CVKxWmNhJZGt1QHoUtOch8t9C4IdOZ2g=",
                "-----END RSA PRIVATE KEY-----",
            ]
        )
    
        # Public half of jwt_privatekey. Produced by placing that private key in
        # foo.key and running `openssl rsa -in foo.key -pubout`.
        jwt_pubkey = "\n".join(
            (
                "-----BEGIN PUBLIC KEY-----",
                "MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7",
                "TKO1vSEWdq7u9x8SMFiB492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQ==",
                "-----END PUBLIC KEY-----",
            )
        )
    
        # A second private key, generated the same way as jwt_privatekey. Tokens
        # signed with it shouldn't be accepted by synapse.
        bad_privatekey = "\n".join(
            (
                "-----BEGIN RSA PRIVATE KEY-----",
                "MIIBOgIBAAJBAL//SQrKpKbjCCnv/FlasJCv+t3k/MPsZfniJe4DVFhsktF2lwQv",
                "gLjmQD3jBUTz+/FndLSBvr3F4OHtGL9O/osCAwEAAQJAJqH0jZJW7Smzo9ShP02L",
                "R6HRZcLExZuUrWI+5ZSP7TaZ1uwJzGFspDrunqaVoPobndw/8VsP8HFyKtceC7vY",
                "uQIhAPdYInDDSJ8rFKGiy3Ajv5KWISBicjevWHF9dbotmNO9AiEAxrdRJVU+EI9I",
                "eB4qRZpY6n4pnwyP0p8f/A3NBaQPG+cCIFlj08aW/PbxNdqYoBdeBA0xDrXKfmbb",
                "iwYxBkwL0JCtAiBYmsi94sJn09u2Y4zpuCbJeDPKzWkbuwQh+W1fhIWQJQIhAKR0",
                "KydN6cRLvphNQ9c/vBTdlzWxzcSxREpguC7F1J1m",
                "-----END RSA PRIVATE KEY-----",
            )
        )
    
        def make_homeserver(self, reactor, clock):
            """Build the test homeserver with RS256 JWT login switched on.

            The homeserver's ``jwt_secret`` is set to the class's public key
            (``jwt_pubkey``). The created homeserver is stored on ``self.hs`` and
            returned.
            """
            homeserver = self.setup_test_homeserver()
            self.hs = homeserver

            jwt_settings = (
                ("jwt_enabled", True),
                ("jwt_secret", self.jwt_pubkey),
                ("jwt_algorithm", "RS256"),
            )
            for option, value in jwt_settings:
                setattr(homeserver.config, option, value)

            return homeserver
    
    
        def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
            """Sign ``payload`` as an RS256 JWT and return the token as text.

            Args:
                payload: the claims to put in the token.
                secret: the PEM-encoded private key to sign with. Defaults to
                    ``jwt_privatekey``.

            Returns:
                The encoded token, always as a ``str``.
            """
            # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
            result = jwt.encode(payload, secret, "RS256")  # type: Union[bytes,str]
            if isinstance(result, bytes):
                return result.decode("ascii")
            return result

        def jwt_login(self, *args):
            """Attempt a JWT login and return the resulting request channel.

            Args:
                *args: passed straight through to ``jwt_encode`` (the payload and,
                    optionally, the signing key).
            """
            params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
            channel = self.make_request(b"POST", LOGIN_URL, params)
            return channel
    
        def test_login_jwt_valid(self):
            channel = self.jwt_login({"sub": "kermit"})
            self.assertEqual(channel.result["code"], b"200", channel.result)
            self.assertEqual(channel.json_body["user_id"], "@kermit:test")
    
        def test_login_jwt_invalid_signature(self):
            channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
    
            self.assertEqual(channel.result["code"], b"403", channel.result)
            self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
    
            self.assertEqual(
                channel.json_body["error"],
                "JWT validation failed: Signature verification failed",
            )
    
    
    
    # User name registered and logged in by the appservice login tests below.
    # NOTE(review): presumably chosen to fall inside the r"@as_user.*" user
    # namespace of the first test appservice and outside the r"@as2_user.*"
    # namespace of the second -- confirm against the namespace matching code.
    AS_USER = "as_user_alice"
    
    
    class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
        """Tests for /login using the application service login type."""

        servlets = [
            login.register_servlets,
            register.register_servlets,
        ]

        def register_as_user(self, username):
            """Register ``username``, authenticating with ``self.service``'s token."""
            self.make_request(
                b"POST",
                "/_matrix/client/r0/register?access_token=%s" % (self.service.token,),
                {"username": username},
            )

        def make_homeserver(self, reactor, clock):
            """Create the homeserver and add two application services to its cache.

            ``self.service`` has the ``@as_user.*`` user namespace and
            ``self.another_service`` has the ``@as2_user.*`` user namespace.
            """
            self.hs = self.setup_test_homeserver()

            self.service = ApplicationService(
                id="unique_identifier",
                token="some_token",
                hostname="example.com",
                sender="@asbot:example.com",
                namespaces={
                    ApplicationService.NS_USERS: [
                        {"regex": r"@as_user.*", "exclusive": False}
                    ],
                    ApplicationService.NS_ROOMS: [],
                    ApplicationService.NS_ALIASES: [],
                },
            )
            self.another_service = ApplicationService(
                id="another__identifier",
                token="another_token",
                hostname="example.com",
                sender="@as2bot:example.com",
                namespaces={
                    ApplicationService.NS_USERS: [
                        {"regex": r"@as2_user.*", "exclusive": False}
                    ],
                    ApplicationService.NS_ROOMS: [],
                    ApplicationService.NS_ALIASES: [],
                },
            )

            self.hs.get_datastore().services_cache.append(self.service)
            self.hs.get_datastore().services_cache.append(self.another_service)
            return self.hs

        def test_login_appservice_user(self):
            """Test that an appservice user can use /login"""
            self.register_as_user(AS_USER)

            params = {
                "type": login.LoginRestServlet.APPSERVICE_TYPE,
                "identifier": {"type": "m.id.user", "user": AS_USER},
            }
            channel = self.make_request(
                b"POST", LOGIN_URL, params, access_token=self.service.token
            )

            self.assertEqual(channel.result["code"], b"200", channel.result)

        def test_login_appservice_user_bot(self):
            """Test that the appservice bot can use /login"""
            self.register_as_user(AS_USER)

            params = {
                "type": login.LoginRestServlet.APPSERVICE_TYPE,
                "identifier": {"type": "m.id.user", "user": self.service.sender},
            }
            channel = self.make_request(
                b"POST", LOGIN_URL, params, access_token=self.service.token
            )

            self.assertEqual(channel.result["code"], b"200", channel.result)

        def test_login_appservice_wrong_user(self):
            """Test that non-as users cannot login with the as token"""
            self.register_as_user(AS_USER)

            params = {
                "type": login.LoginRestServlet.APPSERVICE_TYPE,
                "identifier": {"type": "m.id.user", "user": "fibble_wibble"},
            }
            channel = self.make_request(
                b"POST", LOGIN_URL, params, access_token=self.service.token
            )

            self.assertEqual(channel.result["code"], b"403", channel.result)

        def test_login_appservice_wrong_as(self):
            """Test that as users cannot login with wrong as token"""
            self.register_as_user(AS_USER)

            params = {
                "type": login.LoginRestServlet.APPSERVICE_TYPE,
                "identifier": {"type": "m.id.user", "user": AS_USER},
            }
            # authenticate as the *other* appservice, whose user namespace is
            # r"@as2_user.*" rather than r"@as_user.*"
            channel = self.make_request(
                b"POST", LOGIN_URL, params, access_token=self.another_service.token
            )

            self.assertEqual(channel.result["code"], b"403", channel.result)

        def test_login_appservice_no_token(self):
            """Test that users must provide a token when using the appservice
            login method
            """
            self.register_as_user(AS_USER)

            params = {
                "type": login.LoginRestServlet.APPSERVICE_TYPE,
                "identifier": {"type": "m.id.user", "user": AS_USER},
            }
            # deliberately no access_token on this request
            channel = self.make_request(b"POST", LOGIN_URL, params)

            self.assertEqual(channel.result["code"], b"401", channel.result)
    
    
    
    @skip_unless(HAS_OIDC, "requires OIDC")
    class UsernamePickerTestCase(HomeserverTestCase):
        """Tests for the username picker flow of SSO login"""
    
        servlets = [login.register_servlets]
    
        def default_config(self):
            """Return the default test config extended with OIDC and SSO settings."""
            config = super().default_config()
            config["public_baseurl"] = BASE_URL

            # start from a copy of the shared OIDC test settings and add a mapping
            # provider config which sets a display name template
            oidc_settings = dict(TEST_OIDC_CONFIG)
            oidc_settings["user_mapping_provider"] = {
                "config": {"display_name_template": "{{ user.displayname }}"}
            }
            config["oidc_config"] = oidc_settings

            # whitelist this client URI so we redirect straight to it rather than
            # serving a confirmation page
            config["sso"] = {"client_whitelist": ["https://x"]}

            return config
    
        def create_resource_dict(self) -> Dict[str, Resource]:
            """Return the parent's resources plus the synapse client resource tree."""
            resources = super().create_resource_dict()
            client_tree = build_synapse_client_resource_tree(self.hs)
            resources.update(client_tree)
            return resources
    
        def test_username_picker(self):
            """Test the happy path of a username picker flow."""

            # do the start of the login flow
            channel = self.helper.auth_via_oidc(
                {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL
            )

            # that should redirect to the username picker
            self.assertEqual(channel.code, 302, channel.result)
            picker_url = channel.headers.getRawHeaders("Location")[0]
            self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")

            # ... with a username_mapping_session cookie
            cookies = {}  # type: Dict[str,str]
            channel.extract_cookies(cookies)
            self.assertIn("username_mapping_session", cookies)
            session_id = cookies["username_mapping_session"]

            # introspect the sso handler a bit to check that the username mapping session
            # looks ok.
            username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions
            self.assertIn(
                session_id, username_mapping_sessions, "session id not found in map",
            )
            session = username_mapping_sessions[session_id]
            self.assertEqual(session.remote_user_id, "tester")
            self.assertEqual(session.display_name, "Jonny")
            self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL)

            # the expiry time should be about 15 minutes away
            expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
            self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)

            # Now, submit a username to the username picker, which should serve a redirect
            # to the completion page
            content = urlencode({b"username": b"bobby"}).encode("utf8")
            chan = self.make_request(
                "POST",
                path=picker_url,
                content=content,
                content_is_form=True,
                custom_headers=[
                    ("Cookie", "username_mapping_session=" + session_id),
                    # old versions of twisted don't do form-parsing without a valid
                    # content-length header.
                    ("Content-Length", str(len(content))),
                ],
            )
            self.assertEqual(chan.code, 302, chan.result)
            location_headers = chan.headers.getRawHeaders("Location")
            # fail with a clear message, rather than an indexing error, if the
            # redirect carries no Location header
            self.assertTrue(location_headers, "no Location header on picker response")

            # send a request to the completion page, which should 302 to the client redirectUrl
            chan = self.make_request(
                "GET",
                path=location_headers[0],
                custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
            )
            self.assertEqual(chan.code, 302, chan.result)
            location_headers = chan.headers.getRawHeaders("Location")
            self.assertTrue(
                location_headers, "no Location header on completion page response"
            )

            # ensure that the returned location matches the requested redirect URL
            path, query = location_headers[0].split("?", 1)
            self.assertEqual(path, "https://x")

            # it will have url-encoded the params properly, so we'll have to parse them
            params = urllib.parse.parse_qsl(
                query, keep_blank_values=True, strict_parsing=True, errors="strict"
            )
            self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
            self.assertEqual(params[2][0], "loginToken")

            # fish the login token out of the returned redirect uri
            login_token = params[2][1]

            # finally, submit the matrix login token to the login API, which gives us our
            # matrix access token, mxid, and device id.
            chan = self.make_request(
                "POST", "/login", content={"type": "m.login.token", "token": login_token},
            )
            self.assertEqual(chan.code, 200, chan.result)
            self.assertEqual(chan.json_body["user_id"], "@bobby:test")