Skip to content
Snippets Groups Projects
Unverified Commit eedab12e authored by Wilson's avatar Wilson Committed by GitHub
Browse files

forward requester id to check username for spam callbacks (#17916)

parent 483602ef
No related branches found
No related tags found
No related merge requests found
Module developers will have access to user id of requester when adding `check_username_for_spam` callbacks to `spam_checker_module_callbacks`. Contributed by Wilson@Pangea.chat.
\ No newline at end of file
......@@ -245,7 +245,7 @@ this callback.
_First introduced in Synapse v1.37.0_
```python
async def check_username_for_spam(user_profile: synapse.module_api.UserProfile) -> bool
async def check_username_for_spam(user_profile: synapse.module_api.UserProfile, requester_id: str) -> bool
```
Called when computing search results in the user directory. The module must return a
......@@ -264,6 +264,8 @@ The profile is represented as a dictionary with the following keys:
The module is given a copy of the original dictionary, so modifying it from within the
module cannot modify a user's profile when included in user directory search results.
The requester_id parameter is the ID of the user that called the user directory API.
If multiple modules implement this callback, they will be considered in order. If a
callback returns `False`, Synapse falls through to the next one. The value of the first
callback that does not return `False` will be used. If this happens, Synapse will not call
......
......@@ -72,8 +72,8 @@ class ExampleSpamChecker:
async def user_may_publish_room(self, userid, room_id):
return True # allow publishing of all rooms
async def check_username_for_spam(self, user_profile):
return False # allow all usernames
async def check_username_for_spam(self, user_profile, requester_id):
return False # allow all usernames regardless of requester
async def check_registration_for_spam(
self,
......
......@@ -161,7 +161,7 @@ class UserDirectoryHandler(StateDeltasHandler):
non_spammy_users = []
for user in results["results"]:
if not await self._spam_checker_module_callbacks.check_username_for_spam(
user
user, user_id
):
non_spammy_users.append(user)
results["results"] = non_spammy_users
......
......@@ -31,6 +31,7 @@ from typing import (
Optional,
Tuple,
Union,
cast,
)
# `Literal` appears with Python 3.8.
......@@ -168,7 +169,10 @@ USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[
]
],
]
CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[UserProfile], Awaitable[bool]]
CHECK_USERNAME_FOR_SPAM_CALLBACK = Union[
Callable[[UserProfile], Awaitable[bool]],
Callable[[UserProfile, str], Awaitable[bool]],
]
LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[
[
Optional[dict],
......@@ -716,7 +720,9 @@ class SpamCheckerModuleApiCallbacks:
return self.NOT_SPAM
async def check_username_for_spam(self, user_profile: UserProfile) -> bool:
async def check_username_for_spam(
self, user_profile: UserProfile, requester_id: str
) -> bool:
"""Checks if a user ID or display name are considered "spammy" by this server.
If the server considers a username spammy, then it will not be included in
......@@ -727,15 +733,33 @@ class SpamCheckerModuleApiCallbacks:
* user_id
* display_name
* avatar_url
requester_id: The user ID of the user making the user directory search request.
Returns:
True if the user is spammy.
"""
for callback in self._check_username_for_spam_callbacks:
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
checker_args = inspect.signature(callback)
# Make a copy of the user profile object to ensure the spam checker cannot
# modify it.
res = await delay_cancellation(callback(user_profile.copy()))
# Also ensure backwards compatibility with spam checker callbacks
# that don't expect the requester_id argument.
if len(checker_args.parameters) == 2:
callback_with_requester_id = cast(
Callable[[UserProfile, str], Awaitable[bool]], callback
)
res = await delay_cancellation(
callback_with_requester_id(user_profile.copy(), requester_id)
)
else:
callback_without_requester_id = cast(
Callable[[UserProfile], Awaitable[bool]], callback
)
res = await delay_cancellation(
callback_without_requester_id(user_profile.copy())
)
if res:
return True
......
......@@ -796,6 +796,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
# Kept old spam checker without `requester_id` tests for backwards compatibility.
async def allow_all(user_profile: UserProfile) -> bool:
# Allow all users.
return False
......@@ -809,6 +810,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
# Kept old spam checker without `requester_id` tests for backwards compatibility.
# Configure a spam checker that filters all users.
async def block_all(user_profile: UserProfile) -> bool:
# All users are spammy.
......@@ -820,6 +822,40 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 0)
async def allow_all_expects_requester_id(
user_profile: UserProfile, requester_id: str
) -> bool:
self.assertEqual(requester_id, u1)
# Allow all users.
return False
# Configure a spam checker that does not filter any users.
spam_checker = self.hs.get_module_api_callbacks().spam_checker
spam_checker._check_username_for_spam_callbacks = [
allow_all_expects_requester_id
]
# The results do not change:
# We get one search result when searching for user2 by user1.
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
# Configure a spam checker that filters all users.
async def block_all_expects_requester_id(
user_profile: UserProfile, requester_id: str
) -> bool:
self.assertEqual(requester_id, u1)
# All users are spammy.
return True
spam_checker._check_username_for_spam_callbacks = [
block_all_expects_requester_id
]
# User1 now gets no search results for any of the other users.
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 0)
@override_config(
{
"spam_checker": {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment