Skip to content
Snippets Groups Projects
Commit a7a91391 authored by Erik Johnston's avatar Erik Johnston
Browse files

Merge remote-tracking branch 'origin/erikj/as_mau_block' into develop

parents 70586aa6 14eab1b4
No related branches found
No related tags found
No related merge requests found
Fix bug where application services couldn't register new ghost users if the server had reached its MAU limit.
......@@ -36,6 +36,7 @@ class AuthBlocking:
self._limit_usage_by_mau = hs.config.limit_usage_by_mau
self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
self._server_name = hs.hostname
self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
async def check_auth_blocking(
self,
......@@ -76,6 +77,12 @@ class AuthBlocking:
# We never block the server from doing actions on behalf of
# users.
return
elif requester.app_service and not self._track_appservice_user_ips:
# If we're authenticated as an appservice then we only block
# auth if `track_appservice_user_ips` is set, as that option
# implicitly means that application services are part of MAU
# limits.
return
# Never fail an auth check for the server notices users or support user
# This can be a problem where event creation is prohibited due to blocking
......
......@@ -738,6 +738,7 @@ class AuthHandler(BaseHandler):
device_id: Optional[str],
valid_until_ms: Optional[int],
puppets_user_id: Optional[str] = None,
is_appservice_ghost: bool = False,
) -> str:
"""
Creates a new access token for the user with the given user ID.
......@@ -754,6 +755,7 @@ class AuthHandler(BaseHandler):
we should always have a device ID)
valid_until_ms: when the token is valid until. None for
no expiry.
is_appservice_ghost: Whether the user is an application ghost user
Returns:
The access token for the user's session.
Raises:
......@@ -774,7 +776,11 @@ class AuthHandler(BaseHandler):
"Logging in user %s on device %s%s", user_id, device_id, fmt_expiry
)
await self.auth.check_auth_blocking(user_id)
if (
not is_appservice_ghost
or self.hs.config.appservice.track_appservice_user_ips
):
await self.auth.check_auth_blocking(user_id)
access_token = self.macaroon_gen.generate_access_token(user_id)
await self.store.add_access_token_to_user(
......
......@@ -630,6 +630,7 @@ class RegistrationHandler(BaseHandler):
device_id: Optional[str],
initial_display_name: Optional[str],
is_guest: bool = False,
is_appservice_ghost: bool = False,
) -> Tuple[str, str]:
"""Register a device for a user and generate an access token.
......@@ -651,6 +652,7 @@ class RegistrationHandler(BaseHandler):
device_id=device_id,
initial_display_name=initial_display_name,
is_guest=is_guest,
is_appservice_ghost=is_appservice_ghost,
)
return r["device_id"], r["access_token"]
......@@ -672,7 +674,10 @@ class RegistrationHandler(BaseHandler):
)
else:
access_token = await self._auth_handler.get_access_token_for_user_id(
user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms
user_id,
device_id=registered_device_id,
valid_until_ms=valid_until_ms,
is_appservice_ghost=is_appservice_ghost,
)
return (registered_device_id, access_token)
......
......@@ -36,7 +36,9 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
self.registration_handler = hs.get_registration_handler()
@staticmethod
async def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
async def _serialize_payload(
user_id, device_id, initial_display_name, is_guest, is_appservice_ghost
):
"""
Args:
device_id (str|None): Device ID to use, if None a new one is
......@@ -48,6 +50,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
"device_id": device_id,
"initial_display_name": initial_display_name,
"is_guest": is_guest,
"is_appservice_ghost": is_appservice_ghost,
}
async def _handle_request(self, request, user_id):
......@@ -56,9 +59,14 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
device_id = content["device_id"]
initial_display_name = content["initial_display_name"]
is_guest = content["is_guest"]
is_appservice_ghost = content["is_appservice_ghost"]
device_id, access_token = await self.registration_handler.register_device(
user_id, device_id, initial_display_name, is_guest
user_id,
device_id,
initial_display_name,
is_guest,
is_appservice_ghost=is_appservice_ghost,
)
return 200, {"device_id": device_id, "access_token": access_token}
......
......@@ -655,9 +655,13 @@ class RegisterRestServlet(RestServlet):
user_id = await self.registration_handler.appservice_register(
username, as_token
)
return await self._create_registration_details(user_id, body)
return await self._create_registration_details(
user_id, body, is_appservice_ghost=True,
)
async def _create_registration_details(self, user_id, params):
async def _create_registration_details(
self, user_id, params, is_appservice_ghost=False
):
"""Complete registration of newly-registered user
Allocates device_id if one was not given; also creates access_token.
......@@ -674,7 +678,11 @@ class RegisterRestServlet(RestServlet):
device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name")
device_id, access_token = await self.registration_handler.register_device(
user_id, device_id, initial_display_name, is_guest=False
user_id,
device_id,
initial_display_name,
is_guest=False,
is_appservice_ghost=is_appservice_ghost,
)
result.update({"access_token": access_token, "device_id": device_id})
......
......@@ -19,6 +19,7 @@ import json
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.appservice import ApplicationService
from synapse.rest.client.v2_alpha import register, sync
from tests import unittest
......@@ -75,6 +76,45 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def test_as_ignores_mau(self):
"""Test that application services can still create users when the MAU
limit has been reached. This only works when application service
user ip tracking is disabled.
"""
# Create and sync so that the MAU counts get updated
token1 = self.create_user("kermit1")
self.do_sync_for_user(token1)
token2 = self.create_user("kermit2")
self.do_sync_for_user(token2)
# check we're testing what we think we are: there should be two active users
self.assertEqual(self.get_success(self.store.get_monthly_active_count()), 2)
# We've created and activated two users, we shouldn't be able to
# register new users
with self.assertRaises(SynapseError) as cm:
self.create_user("kermit3")
e = cm.exception
self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
# Cheekily add an application service that we use to register a new user
# with.
as_token = "foobartoken"
self.store.services_cache.append(
ApplicationService(
token=as_token,
hostname=self.hs.hostname,
id="SomeASID",
sender="@as_sender:test",
namespaces={"users": [{"regex": "@as_*", "exclusive": True}]},
)
)
self.create_user("as_kermit4", token=as_token)
def test_allowed_after_a_month_mau(self):
# Create and sync so that the MAU counts get updated
token1 = self.create_user("kermit1")
......@@ -192,7 +232,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.reactor.advance(100)
self.assertEqual(2, self.successResultOf(count))
def create_user(self, localpart):
def create_user(self, localpart, token=None):
request_data = json.dumps(
{
"username": localpart,
......@@ -201,7 +241,9 @@ class TestMauLimit(unittest.HomeserverTestCase):
}
)
channel = self.make_request("POST", "/register", request_data)
channel = self.make_request(
"POST", "/register", request_data, access_token=token,
)
if channel.code != 200:
raise HttpResponseException(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment