Skip to content
Snippets Groups Projects
test_user.py 127 KiB
Newer Older
  • Learn to ignore specific revisions
  •             expected_media_list: The list of media_ids in the order we expect to get
                    back from the server
                order_by: The type of ordering to give the server
                dir: The direction of ordering to give the server
            """
    
            url = self.url + "?"
            if order_by is not None:
    
                url += f"order_by={order_by}&"
    
            if dir is not None and dir in ("b", "f"):
    
            channel = self.make_request(
                "GET",
    
                access_token=self.admin_user_tok,
            )
            self.assertEqual(200, channel.code, msg=channel.json_body)
            self.assertEqual(channel.json_body["total"], len(expected_media_list))
    
            returned_order = [row["media_id"] for row in channel.json_body["media"]]
            self.assertEqual(expected_media_list, returned_order)
            self._check_fields(channel.json_body["media"])
    
    
    
    class UserTokenRestTestCase(unittest.HomeserverTestCase):
        """Test for /_synapse/admin/v1/users/<user>/login"""

        servlets = [
            synapse.rest.admin.register_servlets,
            login.register_servlets,
            sync.register_servlets,
            room.register_servlets,
            devices.register_servlets,
            logout.register_servlets,
        ]

        def prepare(self, reactor, clock, hs):
            self.store = hs.get_datastore()

            self.admin_user = self.register_user("admin", "pass", admin=True)
            self.admin_user_tok = self.login("admin", "pass")

            self.other_user = self.register_user("user", "pass")
            self.other_user_tok = self.login("user", "pass")
            self.url = "/_synapse/admin/v1/users/%s/login" % urllib.parse.quote(
                self.other_user
            )

        def _get_token(self) -> str:
            """Log in as `self.other_user` via the admin login endpoint.

            Asserts that the request (made with the admin's token) returns 200.

            Returns:
                The `access_token` from the response body ("puppet" token).
            """
            channel = self.make_request(
                "POST", self.url, b"{}", access_token=self.admin_user_tok
            )
            self.assertEqual(200, channel.code, msg=channel.json_body)

            return channel.json_body["access_token"]

        def test_no_auth(self):
            """Try to login as a user without authentication."""
            channel = self.make_request("POST", self.url, b"{}")

            self.assertEqual(401, channel.code, msg=channel.json_body)
            self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])

        def test_not_admin(self):
            """Try to login as a user as a non-admin user."""
            channel = self.make_request(
                "POST", self.url, b"{}", access_token=self.other_user_tok
            )

            self.assertEqual(403, channel.code, msg=channel.json_body)

        def test_send_event(self):
            """Test that sending event as a user works."""
            # Create a room.
            room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_tok)

            # Log in as the user
            puppet_token = self._get_token()

            # Test that sending works, and generates the event as the right user.
            resp = self.helper.send_event(room_id, "com.example.test", tok=puppet_token)
            event_id = resp["event_id"]
            event = self.get_success(self.store.get_event(event_id))
            self.assertEqual(event.sender, self.other_user)

        def test_devices(self):
            """Tests that logging in as a user doesn't create a new device for them."""
            # Log in as the user
            self._get_token()

            # Check that we don't see a new device in our devices list
            channel = self.make_request(
                "GET", "devices", b"{}", access_token=self.other_user_tok
            )
            self.assertEqual(200, channel.code, msg=channel.json_body)

            # We should only see the one device (from the login in `prepare`)
            self.assertEqual(len(channel.json_body["devices"]), 1)

        def test_logout(self):
            """Test that calling `/logout` with the token works."""
            # Log in as the user
            puppet_token = self._get_token()

            # Test that we can successfully make a request
            channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
            self.assertEqual(200, channel.code, msg=channel.json_body)

            # Logout with the puppet token
            channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token)
            self.assertEqual(200, channel.code, msg=channel.json_body)

            # The puppet token should no longer work
            channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
            self.assertEqual(401, channel.code, msg=channel.json_body)

            # .. but the real user's tokens should still work
            channel = self.make_request(
                "GET", "devices", b"{}", access_token=self.other_user_tok
            )
            self.assertEqual(200, channel.code, msg=channel.json_body)

        def test_user_logout_all(self):
            """Tests that the target user calling `/logout/all` does *not* expire
            the token.
            """
            # Log in as the user
            puppet_token = self._get_token()

            # Test that we can successfully make a request
            channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
            self.assertEqual(200, channel.code, msg=channel.json_body)

            # Logout all with the real user token
            channel = self.make_request(
                "POST", "logout/all", b"{}", access_token=self.other_user_tok
            )
            self.assertEqual(200, channel.code, msg=channel.json_body)

            # The puppet token should still work
            channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
            self.assertEqual(200, channel.code, msg=channel.json_body)

            # .. but the real user's tokens shouldn't
            channel = self.make_request(
                "GET", "devices", b"{}", access_token=self.other_user_tok
            )
            self.assertEqual(401, channel.code, msg=channel.json_body)

        def test_admin_logout_all(self):
            """Tests that the admin user calling `/logout/all` does expire the
            token.
            """
            # Log in as the user
            puppet_token = self._get_token()

            # Test that we can successfully make a request
            channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
            self.assertEqual(200, channel.code, msg=channel.json_body)

            # Logout all with the admin user token
            channel = self.make_request(
                "POST", "logout/all", b"{}", access_token=self.admin_user_tok
            )
            self.assertEqual(200, channel.code, msg=channel.json_body)

            # The puppet token should no longer work
            channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
            self.assertEqual(401, channel.code, msg=channel.json_body)

            # .. but the real user's tokens should still work
            channel = self.make_request(
                "GET", "devices", b"{}", access_token=self.other_user_tok
            )
            self.assertEqual(200, channel.code, msg=channel.json_body)

        @unittest.override_config(
            {
                "public_baseurl": "https://example.org/",
                "user_consent": {
                    "version": "1.0",
                    "policy_name": "My Cool Privacy Policy",
                    "template_dir": "/",
                    "require_at_registration": True,
                    "block_events_error": "You should accept the policy",
                },
                "form_secret": "123secret",
            }
        )
        def test_consent(self):
            """Test that sending a message is not subject to the privacy policies."""
            # Have the admin user accept the terms.
            self.get_success(self.store.user_set_consent_version(self.admin_user, "1.0"))

            # First, cheekily accept the terms and create a room
            self.get_success(self.store.user_set_consent_version(self.other_user, "1.0"))
            room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_tok)
            self.helper.send_event(room_id, "com.example.test", tok=self.other_user_tok)

            # Now unaccept it and check that we can't send an event
            self.get_success(self.store.user_set_consent_version(self.other_user, "0.0"))
            self.helper.send_event(
                room_id, "com.example.test", tok=self.other_user_tok, expect_code=403
            )

            # Log in as the user
            puppet_token = self._get_token()

            # Sending an event on their behalf should work fine
            self.helper.send_event(room_id, "com.example.test", tok=puppet_token)

        @unittest.override_config(
            {"limit_usage_by_mau": True, "max_mau_value": 1, "mau_trial_days": 0}
        )
        def test_mau_limit(self):
            """Test that a puppet token can join a room even when the MAU limit
            (set to 1 here) blocks the real user from doing so.
            """
            # Create a room as the admin user. This will bump the monthly active users to 1.
            room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)

            # Trying to join as the other user should fail due to reaching MAU limit.
            self.helper.join(
                room_id, user=self.other_user, tok=self.other_user_tok, expect_code=403
            )

            # Logging in as the other user and joining a room should work, even
            # though the MAU limit would stop the user doing so.
            puppet_token = self._get_token()
            self.helper.join(room_id, user=self.other_user, tok=puppet_token)
    
    @parameterized_class(
        ("url_prefix",),
        [
            ("/_synapse/admin/v1/whois/%s",),
            ("/_matrix/client/r0/admin/whois/%s",),
        ],
    )
    class WhoisRestTestCase(unittest.HomeserverTestCase):
        """Tests for the whois admin API, run once for each `url_prefix` above."""

        servlets = [
            synapse.rest.admin.register_servlets,
            login.register_servlets,
        ]

        def prepare(self, reactor, clock, hs):
            self.admin_user = self.register_user("admin", "pass", admin=True)
            self.admin_user_tok = self.login("admin", "pass")

            self.other_user = self.register_user("user", "pass")

            # `url_prefix` is injected per-class by `parameterized_class`.
            self.url = self.url_prefix % self.other_user

        def test_no_auth(self):
            """
            Try to get information of a user without authentication.
            """
            channel = self.make_request("GET", self.url, b"{}")

            self.assertEqual(401, channel.code, msg=channel.json_body)
            self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])

        def test_requester_is_not_admin(self):
            """
            If the user is not a server admin (and is looking up somebody else),
            an error is returned.
            """
            self.register_user("user2", "pass")
            other_user2_token = self.login("user2", "pass")

            channel = self.make_request(
                "GET",
                self.url,
                access_token=other_user2_token,
            )

            self.assertEqual(403, channel.code, msg=channel.json_body)
            self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])

        def test_user_is_not_local(self):
            """
            Tests that a lookup for a user that is not a local returns a 400
            """
            url = self.url_prefix % "@unknown_person:unknown_domain"

            channel = self.make_request(
                "GET",
                url,
                access_token=self.admin_user_tok,
            )

            self.assertEqual(400, channel.code, msg=channel.json_body)
            self.assertEqual("Can only whois a local user", channel.json_body["error"])

        def test_get_whois_admin(self):
            """
            The lookup should succeed for an admin.
            """
            channel = self.make_request(
                "GET",
                self.url,
                access_token=self.admin_user_tok,
            )

            self.assertEqual(200, channel.code, msg=channel.json_body)
            self.assertEqual(self.other_user, channel.json_body["user_id"])
            self.assertIn("devices", channel.json_body)

        def test_get_whois_user(self):
            """
            The lookup should succeed for a normal user looking up their own information.
            """
            other_user_token = self.login("user", "pass")

            channel = self.make_request(
                "GET",
                self.url,
                access_token=other_user_token,
            )

            self.assertEqual(200, channel.code, msg=channel.json_body)
            self.assertEqual(self.other_user, channel.json_body["user_id"])
            self.assertIn("devices", channel.json_body)
    
    
    
    class ShadowBanRestTestCase(unittest.HomeserverTestCase):
        """Tests for POST /_synapse/admin/v1/users/<user>/shadow_ban."""

        servlets = [
            synapse.rest.admin.register_servlets,
            login.register_servlets,
        ]

        def prepare(self, reactor, clock, hs):
            self.store = hs.get_datastore()

            self.admin_user = self.register_user("admin", "pass", admin=True)
            self.admin_user_tok = self.login("admin", "pass")

            self.other_user = self.register_user("user", "pass")

            self.url = "/_synapse/admin/v1/users/%s/shadow_ban" % urllib.parse.quote(
                self.other_user
            )

        def test_no_auth(self):
            """
            Try to shadow-ban a user without authentication.
            """
            channel = self.make_request("POST", self.url)

            self.assertEqual(401, channel.code, msg=channel.json_body)
            self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])

        def test_requester_is_not_admin(self):
            """
            If the user is not a server admin, an error is returned.
            """
            other_user_token = self.login("user", "pass")

            channel = self.make_request("POST", self.url, access_token=other_user_token)

            self.assertEqual(403, channel.code, msg=channel.json_body)
            self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])

        def test_user_is_not_local(self):
            """
            Tests that shadow-banning for a user that is not a local returns a 400
            """
            # Target the shadow_ban endpoint itself. Posting to the whois endpoint
            # would also yield a 400, but without exercising this servlet's
            # local-user check.
            url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/shadow_ban"

            channel = self.make_request("POST", url, access_token=self.admin_user_tok)
            self.assertEqual(400, channel.code, msg=channel.json_body)

        def test_success(self):
            """
            Shadow-banning should succeed for an admin.
            """
            # The user starts off as not shadow-banned.
            other_user_token = self.login("user", "pass")
            result = self.get_success(self.store.get_user_by_access_token(other_user_token))
            self.assertFalse(result.shadow_banned)

            channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
            self.assertEqual(200, channel.code, msg=channel.json_body)
            self.assertEqual({}, channel.json_body)

            # Ensure the user is shadow-banned (and the cache was cleared).
            result = self.get_success(self.store.get_user_by_access_token(other_user_token))
            self.assertTrue(result.shadow_banned)
    
    
    
    class RateLimitTestCase(unittest.HomeserverTestCase):
        """Tests for /_synapse/admin/v1/users/<user>/override_ratelimit (GET/POST/DELETE)."""

        servlets = [
            synapse.rest.admin.register_servlets,
            login.register_servlets,
        ]

        def prepare(self, reactor, clock, hs):
            self.store = hs.get_datastore()

            self.admin_user = self.register_user("admin", "pass", admin=True)
            self.admin_user_tok = self.login("admin", "pass")

            self.other_user = self.register_user("user", "pass")
            self.url = (
                "/_synapse/admin/v1/users/%s/override_ratelimit"
                % urllib.parse.quote(self.other_user)
            )

        @parameterized.expand(["GET", "POST", "DELETE"])
        def test_no_auth(self, method: str):
            """
            Try to access the ratelimit override of a user without authentication.
            """
            channel = self.make_request(method, self.url, b"{}")

            self.assertEqual(401, channel.code, msg=channel.json_body)
            self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])

        @parameterized.expand(["GET", "POST", "DELETE"])
        def test_requester_is_no_admin(self, method: str):
            """
            If the user is not a server admin, an error is returned.
            """
            other_user_token = self.login("user", "pass")

            channel = self.make_request(
                method,
                self.url,
                access_token=other_user_token,
            )

            self.assertEqual(403, channel.code, msg=channel.json_body)
            self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])

        @parameterized.expand(["GET", "POST", "DELETE"])
        def test_user_does_not_exist(self, method: str):
            """
            Tests that a lookup for a user that does not exist returns a 404
            """
            url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit"

            channel = self.make_request(
                method,
                url,
                access_token=self.admin_user_tok,
            )

            self.assertEqual(404, channel.code, msg=channel.json_body)
            self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])

        @parameterized.expand(
            [
                ("GET", "Can only look up local users"),
                ("POST", "Only local users can be ratelimited"),
                ("DELETE", "Only local users can be ratelimited"),
            ]
        )
        def test_user_is_not_local(self, method: str, error_msg: str):
            """
            Tests that a lookup for a user that is not a local returns a 400
            """
            url = (
                "/_synapse/admin/v1/users/@unknown_person:unknown_domain/override_ratelimit"
            )

            channel = self.make_request(
                method,
                url,
                access_token=self.admin_user_tok,
            )

            self.assertEqual(400, channel.code, msg=channel.json_body)
            self.assertEqual(error_msg, channel.json_body["error"])

        def test_invalid_parameter(self):
            """
            If parameters are invalid, an error is returned.
            """
            # messages_per_second is a string
            channel = self.make_request(
                "POST",
                self.url,
                access_token=self.admin_user_tok,
                content={"messages_per_second": "string"},
            )

            self.assertEqual(400, channel.code, msg=channel.json_body)
            self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])

            # messages_per_second is negative
            channel = self.make_request(
                "POST",
                self.url,
                access_token=self.admin_user_tok,
                content={"messages_per_second": -1},
            )

            self.assertEqual(400, channel.code, msg=channel.json_body)
            self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])

            # burst_count is a string
            channel = self.make_request(
                "POST",
                self.url,
                access_token=self.admin_user_tok,
                content={"burst_count": "string"},
            )

            self.assertEqual(400, channel.code, msg=channel.json_body)
            self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])

            # burst_count is negative
            channel = self.make_request(
                "POST",
                self.url,
                access_token=self.admin_user_tok,
                content={"burst_count": -1},
            )

            self.assertEqual(400, channel.code, msg=channel.json_body)
            self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])

        def test_return_zero_when_null(self):
            """
            If values in database are `null` API should return an int `0`
            """
            # Write a row with NULL limits directly, bypassing the API's validation.
            self.get_success(
                self.store.db_pool.simple_upsert(
                    table="ratelimit_override",
                    keyvalues={"user_id": self.other_user},
                    values={
                        "messages_per_second": None,
                        "burst_count": None,
                    },
                )
            )

            # request status
            channel = self.make_request(
                "GET",
                self.url,
                access_token=self.admin_user_tok,
            )

            self.assertEqual(200, channel.code, msg=channel.json_body)
            self.assertEqual(0, channel.json_body["messages_per_second"])
            self.assertEqual(0, channel.json_body["burst_count"])

        def test_success(self):
            """
            Rate-limiting (set/update/delete) should succeed for an admin.
            """
            # request status
            channel = self.make_request(
                "GET",
                self.url,
                access_token=self.admin_user_tok,
            )

            self.assertEqual(200, channel.code, msg=channel.json_body)
            self.assertNotIn("messages_per_second", channel.json_body)
            self.assertNotIn("burst_count", channel.json_body)

            # set ratelimit
            channel = self.make_request(
                "POST",
                self.url,
                access_token=self.admin_user_tok,
                content={"messages_per_second": 10, "burst_count": 11},
            )

            self.assertEqual(200, channel.code, msg=channel.json_body)
            self.assertEqual(10, channel.json_body["messages_per_second"])
            self.assertEqual(11, channel.json_body["burst_count"])

            # update ratelimit
            channel = self.make_request(
                "POST",
                self.url,
                access_token=self.admin_user_tok,
                content={"messages_per_second": 20, "burst_count": 21},
            )

            self.assertEqual(200, channel.code, msg=channel.json_body)
            self.assertEqual(20, channel.json_body["messages_per_second"])
            self.assertEqual(21, channel.json_body["burst_count"])

            # request status
            channel = self.make_request(
                "GET",
                self.url,
                access_token=self.admin_user_tok,
            )

            self.assertEqual(200, channel.code, msg=channel.json_body)
            self.assertEqual(20, channel.json_body["messages_per_second"])
            self.assertEqual(21, channel.json_body["burst_count"])

            # delete ratelimit
            channel = self.make_request(
                "DELETE",
                self.url,
                access_token=self.admin_user_tok,
            )

            self.assertEqual(200, channel.code, msg=channel.json_body)
            self.assertNotIn("messages_per_second", channel.json_body)
            self.assertNotIn("burst_count", channel.json_body)

            # request status
            channel = self.make_request(
                "GET",
                self.url,
                access_token=self.admin_user_tok,
            )

            self.assertEqual(200, channel.code, msg=channel.json_body)
            self.assertNotIn("messages_per_second", channel.json_body)
            self.assertNotIn("burst_count", channel.json_body)