Skip to content
Snippets Groups Projects
Commit 0d43f991 authored by Neil Johnson's avatar Neil Johnson
Browse files

support admin_email config and pass through into blocking errors, return AuthError in all cases

parent 99dd975d
No related branches found
No related tags found
No related merge requests found
...@@ -781,11 +781,15 @@ class Auth(object): ...@@ -781,11 +781,15 @@ class Auth(object):
""" """
if self.hs.config.hs_disabled: if self.hs.config.hs_disabled:
raise AuthError( raise AuthError(
403, self.hs.config.hs_disabled_message, errcode=Codes.HS_DISABLED 403, self.hs.config.hs_disabled_message,
errcode=Codes.HS_DISABLED,
admin_email=self.hs.config.admin_email,
) )
if self.hs.config.limit_usage_by_mau is True: if self.hs.config.limit_usage_by_mau is True:
current_mau = yield self.store.get_monthly_active_count() current_mau = yield self.store.get_monthly_active_count()
if current_mau >= self.hs.config.max_mau_value: if current_mau >= self.hs.config.max_mau_value:
raise AuthError( raise AuthError(
403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED 403, "MAU Limit Exceeded",
admin_email=self.hs.config.admin_email,
errcode=Codes.MAU_LIMIT_EXCEEDED
) )
...@@ -225,11 +225,20 @@ class NotFoundError(SynapseError): ...@@ -225,11 +225,20 @@ class NotFoundError(SynapseError):
class AuthError(SynapseError): class AuthError(SynapseError):
"""An error raised when there was a problem authorising an event.""" """An error raised when there was a problem authorising an event."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
if "errcode" not in kwargs: if "errcode" not in kwargs:
kwargs["errcode"] = Codes.FORBIDDEN kwargs["errcode"] = Codes.FORBIDDEN
super(AuthError, self).__init__(*args, **kwargs) self.admin_email = kwargs.get('admin_email')
self.msg = kwargs.get('msg')
self.errcode = kwargs.get('errcode')
super(AuthError, self).__init__(*args, errcode=kwargs["errcode"])
def error_dict(self):
return cs_error(
self.msg,
self.errcode,
admin_email=self.admin_email,
)
class EventSizeError(SynapseError): class EventSizeError(SynapseError):
......
...@@ -82,6 +82,10 @@ class ServerConfig(Config): ...@@ -82,6 +82,10 @@ class ServerConfig(Config):
self.hs_disabled = config.get("hs_disabled", False) self.hs_disabled = config.get("hs_disabled", False)
self.hs_disabled_message = config.get("hs_disabled_message", "") self.hs_disabled_message = config.get("hs_disabled_message", "")
# Admin email to direct users at should their instance become blocked
# due to resource constraints
self.admin_email = config.get("admin_email", None)
# FIXME: federation_domain_whitelist needs sytests # FIXME: federation_domain_whitelist needs sytests
self.federation_domain_whitelist = None self.federation_domain_whitelist = None
federation_domain_whitelist = config.get( federation_domain_whitelist = config.get(
......
...@@ -144,7 +144,8 @@ class RegistrationHandler(BaseHandler): ...@@ -144,7 +144,8 @@ class RegistrationHandler(BaseHandler):
Raises: Raises:
RegistrationError if there was a problem registering. RegistrationError if there was a problem registering.
""" """
yield self._check_mau_limits()
yield self.auth.check_auth_blocking()
password_hash = None password_hash = None
if password: if password:
password_hash = yield self.auth_handler().hash(password) password_hash = yield self.auth_handler().hash(password)
...@@ -289,7 +290,7 @@ class RegistrationHandler(BaseHandler): ...@@ -289,7 +290,7 @@ class RegistrationHandler(BaseHandler):
400, 400,
"User ID can only contain characters a-z, 0-9, or '=_-./'", "User ID can only contain characters a-z, 0-9, or '=_-./'",
) )
yield self._check_mau_limits() yield self.auth.check_auth_blocking()
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
...@@ -439,7 +440,7 @@ class RegistrationHandler(BaseHandler): ...@@ -439,7 +440,7 @@ class RegistrationHandler(BaseHandler):
""" """
if localpart is None: if localpart is None:
raise SynapseError(400, "Request must include user id") raise SynapseError(400, "Request must include user id")
yield self._check_mau_limits() yield self.auth.check_auth_blocking()
need_register = True need_register = True
try: try:
...@@ -534,13 +535,13 @@ class RegistrationHandler(BaseHandler): ...@@ -534,13 +535,13 @@ class RegistrationHandler(BaseHandler):
action="join", action="join",
) )
@defer.inlineCallbacks # @defer.inlineCallbacks
def _check_mau_limits(self): # def _s(self):
""" # """
Do not accept registrations if monthly active user limits exceeded # Do not accept registrations if monthly active user limits exceeded
and limiting is enabled # and limiting is enabled
""" # """
try: # try:
yield self.auth.check_auth_blocking() # yield self.auth.check_auth_blocking()
except AuthError as e: # except AuthError as e:
raise RegistrationError(e.code, str(e), e.errcode) # raise RegistrationError(e.code, str(e), e.errcode)
...@@ -455,8 +455,11 @@ class AuthTestCase(unittest.TestCase): ...@@ -455,8 +455,11 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(lots_of_users) return_value=defer.succeed(lots_of_users)
) )
with self.assertRaises(AuthError): with self.assertRaises(AuthError) as e:
yield self.auth.check_auth_blocking() yield self.auth.check_auth_blocking()
self.assertEquals(e.exception.admin_email, self.hs.config.admin_email)
self.assertEquals(e.exception.errcode, Codes.MAU_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
# Ensure does not throw an error # Ensure does not throw an error
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
...@@ -470,5 +473,6 @@ class AuthTestCase(unittest.TestCase): ...@@ -470,5 +473,6 @@ class AuthTestCase(unittest.TestCase):
self.hs.config.hs_disabled_message = "Reason for being disabled" self.hs.config.hs_disabled_message = "Reason for being disabled"
with self.assertRaises(AuthError) as e: with self.assertRaises(AuthError) as e:
yield self.auth.check_auth_blocking() yield self.auth.check_auth_blocking()
self.assertEquals(e.exception.admin_email, self.hs.config.admin_email)
self.assertEquals(e.exception.errcode, Codes.HS_DISABLED) self.assertEquals(e.exception.errcode, Codes.HS_DISABLED)
self.assertEquals(e.exception.code, 403) self.assertEquals(e.exception.code, 403)
...@@ -17,7 +17,7 @@ from mock import Mock ...@@ -17,7 +17,7 @@ from mock import Mock
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import RegistrationError from synapse.api.errors import AuthError
from synapse.handlers.register import RegistrationHandler from synapse.handlers.register import RegistrationHandler
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
...@@ -109,7 +109,7 @@ class RegistrationTestCase(unittest.TestCase): ...@@ -109,7 +109,7 @@ class RegistrationTestCase(unittest.TestCase):
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.lots_of_users) return_value=defer.succeed(self.lots_of_users)
) )
with self.assertRaises(RegistrationError): with self.assertRaises(AuthError):
yield self.handler.get_or_create_user("requester", 'b', "display_name") yield self.handler.get_or_create_user("requester", 'b', "display_name")
@defer.inlineCallbacks @defer.inlineCallbacks
...@@ -118,7 +118,7 @@ class RegistrationTestCase(unittest.TestCase): ...@@ -118,7 +118,7 @@ class RegistrationTestCase(unittest.TestCase):
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.lots_of_users) return_value=defer.succeed(self.lots_of_users)
) )
with self.assertRaises(RegistrationError): with self.assertRaises(AuthError):
yield self.handler.register(localpart="local_part") yield self.handler.register(localpart="local_part")
@defer.inlineCallbacks @defer.inlineCallbacks
...@@ -127,5 +127,5 @@ class RegistrationTestCase(unittest.TestCase): ...@@ -127,5 +127,5 @@ class RegistrationTestCase(unittest.TestCase):
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(self.lots_of_users) return_value=defer.succeed(self.lots_of_users)
) )
with self.assertRaises(RegistrationError): with self.assertRaises(AuthError):
yield self.handler.register_saml2(localpart="local_part") yield self.handler.register_saml2(localpart="local_part")
...@@ -139,6 +139,7 @@ def setup_test_homeserver( ...@@ -139,6 +139,7 @@ def setup_test_homeserver(
config.hs_disabled_message = "" config.hs_disabled_message = ""
config.max_mau_value = 50 config.max_mau_value = 50
config.mau_limits_reserved_threepids = [] config.mau_limits_reserved_threepids = []
config.admin_email = None
# we need a sane default_room_version, otherwise attempts to create rooms will # we need a sane default_room_version, otherwise attempts to create rooms will
# fail. # fail.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment