Skip to content
Snippets Groups Projects
Unverified Commit aee9130a authored by Andrew Morgan's avatar Andrew Morgan Committed by GitHub
Browse files

Stop Auth methods from polling the config on every req. (#7420)

parent b26f3e58
No related branches found
No related tags found
No related merge requests found
Prevent methods in `synapse.handlers.auth` from polling the homeserver config every request.
\ No newline at end of file
...@@ -26,16 +26,15 @@ from twisted.internet import defer ...@@ -26,16 +26,15 @@ from twisted.internet import defer
import synapse.logging.opentracing as opentracing import synapse.logging.opentracing as opentracing
import synapse.types import synapse.types
from synapse import event_auth from synapse import event_auth
from synapse.api.constants import EventTypes, LimitBlockingTypes, Membership, UserTypes from synapse.api.auth_blocking import AuthBlocking
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
Codes, Codes,
InvalidClientTokenError, InvalidClientTokenError,
MissingClientTokenError, MissingClientTokenError,
ResourceLimitError,
) )
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.config.server import is_threepid_reserved
from synapse.events import EventBase from synapse.events import EventBase
from synapse.types import StateMap, UserID from synapse.types import StateMap, UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
...@@ -77,7 +76,11 @@ class Auth(object): ...@@ -77,7 +76,11 @@ class Auth(object):
self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000) self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000)
register_cache("cache", "token_cache", self.token_cache) register_cache("cache", "token_cache", self.token_cache)
self._auth_blocking = AuthBlocking(self.hs)
self._account_validity = hs.config.account_validity self._account_validity = hs.config.account_validity
self._track_appservice_user_ips = hs.config.track_appservice_user_ips
self._macaroon_secret_key = hs.config.macaroon_secret_key
@defer.inlineCallbacks @defer.inlineCallbacks
def check_from_context(self, room_version: str, event, context, do_sig_check=True): def check_from_context(self, room_version: str, event, context, do_sig_check=True):
...@@ -191,7 +194,7 @@ class Auth(object): ...@@ -191,7 +194,7 @@ class Auth(object):
opentracing.set_tag("authenticated_entity", user_id) opentracing.set_tag("authenticated_entity", user_id)
opentracing.set_tag("appservice_id", app_service.id) opentracing.set_tag("appservice_id", app_service.id)
if ip_addr and self.hs.config.track_appservice_user_ips: if ip_addr and self._track_appservice_user_ips:
yield self.store.insert_client_ip( yield self.store.insert_client_ip(
user_id=user_id, user_id=user_id,
access_token=access_token, access_token=access_token,
...@@ -454,7 +457,7 @@ class Auth(object): ...@@ -454,7 +457,7 @@ class Auth(object):
# access_tokens include a nonce for uniqueness: any value is acceptable # access_tokens include a nonce for uniqueness: any value is acceptable
v.satisfy_general(lambda c: c.startswith("nonce = ")) v.satisfy_general(lambda c: c.startswith("nonce = "))
v.verify(macaroon, self.hs.config.macaroon_secret_key) v.verify(macaroon, self._macaroon_secret_key)
def _verify_expiry(self, caveat): def _verify_expiry(self, caveat):
prefix = "time < " prefix = "time < "
...@@ -663,71 +666,5 @@ class Auth(object): ...@@ -663,71 +666,5 @@ class Auth(object):
% (user_id, room_id), % (user_id, room_id),
) )
@defer.inlineCallbacks def check_auth_blocking(self, *args, **kwargs):
def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): return self._auth_blocking.check_auth_blocking(*args, **kwargs)
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag
Args:
user_id(str|None): If present, checks for presence against existing
MAU cohort
threepid(dict|None): If present, checks for presence against configured
reserved threepid. Used in cases where the user is trying register
with a MAU blocked server, normally they would be rejected but their
threepid is on the reserved list. user_id and
threepid should never be set at the same time.
user_type(str|None): If present, is used to decide whether to check against
certain blocking reasons like MAU.
"""
# Never fail an auth check for the server notices users or support user
# This can be a problem where event creation is prohibited due to blocking
if user_id is not None:
if user_id == self.hs.config.server_notices_mxid:
return
if (yield self.store.is_support_user(user_id)):
return
if self.hs.config.hs_disabled:
raise ResourceLimitError(
403,
self.hs.config.hs_disabled_message,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
admin_contact=self.hs.config.admin_contact,
limit_type=LimitBlockingTypes.HS_DISABLED,
)
if self.hs.config.limit_usage_by_mau is True:
assert not (user_id and threepid)
# If the user is already part of the MAU cohort or a trial user
if user_id:
timestamp = yield self.store.user_last_seen_monthly_active(user_id)
if timestamp:
return
is_trial = yield self.store.is_trial_user(user_id)
if is_trial:
return
elif threepid:
# If the user does not exist yet, but is signing up with a
# reserved threepid then pass auth check
if is_threepid_reserved(
self.hs.config.mau_limits_reserved_threepids, threepid
):
return
elif user_type == UserTypes.SUPPORT:
# If the user does not exist yet and is of type "support",
# allow registration. Support users are excluded from MAU checks.
return
# Else if there is no room in the MAU bucket, bail
current_mau = yield self.store.get_monthly_active_count()
if current_mau >= self.hs.config.max_mau_value:
raise ResourceLimitError(
403,
"Monthly Active User Limit Exceeded",
admin_contact=self.hs.config.admin_contact,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER,
)
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.constants import LimitBlockingTypes, UserTypes
from synapse.api.errors import Codes, ResourceLimitError
from synapse.config.server import is_threepid_reserved
logger = logging.getLogger(__name__)
class AuthBlocking(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self._server_notices_mxid = hs.config.server_notices_mxid
self._hs_disabled = hs.config.hs_disabled
self._hs_disabled_message = hs.config.hs_disabled_message
self._admin_contact = hs.config.admin_contact
self._max_mau_value = hs.config.max_mau_value
self._limit_usage_by_mau = hs.config.limit_usage_by_mau
self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
@defer.inlineCallbacks
def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag
Args:
user_id(str|None): If present, checks for presence against existing
MAU cohort
threepid(dict|None): If present, checks for presence against configured
reserved threepid. Used in cases where the user is trying register
with a MAU blocked server, normally they would be rejected but their
threepid is on the reserved list. user_id and
threepid should never be set at the same time.
user_type(str|None): If present, is used to decide whether to check against
certain blocking reasons like MAU.
"""
# Never fail an auth check for the server notices users or support user
# This can be a problem where event creation is prohibited due to blocking
if user_id is not None:
if user_id == self._server_notices_mxid:
return
if (yield self.store.is_support_user(user_id)):
return
if self._hs_disabled:
raise ResourceLimitError(
403,
self._hs_disabled_message,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
admin_contact=self._admin_contact,
limit_type=LimitBlockingTypes.HS_DISABLED,
)
if self._limit_usage_by_mau is True:
assert not (user_id and threepid)
# If the user is already part of the MAU cohort or a trial user
if user_id:
timestamp = yield self.store.user_last_seen_monthly_active(user_id)
if timestamp:
return
is_trial = yield self.store.is_trial_user(user_id)
if is_trial:
return
elif threepid:
# If the user does not exist yet, but is signing up with a
# reserved threepid then pass auth check
if is_threepid_reserved(self._mau_limits_reserved_threepids, threepid):
return
elif user_type == UserTypes.SUPPORT:
# If the user does not exist yet and is of type "support",
# allow registration. Support users are excluded from MAU checks.
return
# Else if there is no room in the MAU bucket, bail
current_mau = yield self.store.get_monthly_active_count()
if current_mau >= self._max_mau_value:
raise ResourceLimitError(
403,
"Monthly Active User Limit Exceeded",
admin_contact=self._admin_contact,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER,
)
...@@ -52,6 +52,10 @@ class AuthTestCase(unittest.TestCase): ...@@ -52,6 +52,10 @@ class AuthTestCase(unittest.TestCase):
self.hs.handlers = TestHandlers(self.hs) self.hs.handlers = TestHandlers(self.hs)
self.auth = Auth(self.hs) self.auth = Auth(self.hs)
# AuthBlocking reads from the hs' config on initialization. We need to
# modify its config instead of the hs'
self.auth_blocking = self.auth._auth_blocking
self.test_user = "@foo:bar" self.test_user = "@foo:bar"
self.test_token = b"_test_token_" self.test_token = b"_test_token_"
...@@ -321,15 +325,15 @@ class AuthTestCase(unittest.TestCase): ...@@ -321,15 +325,15 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_blocking_mau(self): def test_blocking_mau(self):
self.hs.config.limit_usage_by_mau = False self.auth_blocking._limit_usage_by_mau = False
self.hs.config.max_mau_value = 50 self.auth_blocking._max_mau_value = 50
lots_of_users = 100 lots_of_users = 100
small_number_of_users = 1 small_number_of_users = 1
# Ensure no error thrown # Ensure no error thrown
yield defer.ensureDeferred(self.auth.check_auth_blocking()) yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.hs.config.limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock( self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(lots_of_users) return_value=defer.succeed(lots_of_users)
...@@ -349,8 +353,8 @@ class AuthTestCase(unittest.TestCase): ...@@ -349,8 +353,8 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_blocking_mau__depending_on_user_type(self): def test_blocking_mau__depending_on_user_type(self):
self.hs.config.max_mau_value = 50 self.auth_blocking._max_mau_value = 50
self.hs.config.limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
# Support users allowed # Support users allowed
...@@ -370,12 +374,12 @@ class AuthTestCase(unittest.TestCase): ...@@ -370,12 +374,12 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_reserved_threepid(self): def test_reserved_threepid(self):
self.hs.config.limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.hs.config.max_mau_value = 1 self.auth_blocking._max_mau_value = 1
self.store.get_monthly_active_count = lambda: defer.succeed(2) self.store.get_monthly_active_count = lambda: defer.succeed(2)
threepid = {"medium": "email", "address": "reserved@server.com"} threepid = {"medium": "email", "address": "reserved@server.com"}
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"} unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
self.hs.config.mau_limits_reserved_threepids = [threepid] self.auth_blocking._mau_limits_reserved_threepids = [threepid]
with self.assertRaises(ResourceLimitError): with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(self.auth.check_auth_blocking()) yield defer.ensureDeferred(self.auth.check_auth_blocking())
...@@ -389,8 +393,8 @@ class AuthTestCase(unittest.TestCase): ...@@ -389,8 +393,8 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_hs_disabled(self): def test_hs_disabled(self):
self.hs.config.hs_disabled = True self.auth_blocking._hs_disabled = True
self.hs.config.hs_disabled_message = "Reason for being disabled" self.auth_blocking._hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e: with self.assertRaises(ResourceLimitError) as e:
yield defer.ensureDeferred(self.auth.check_auth_blocking()) yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
...@@ -404,10 +408,10 @@ class AuthTestCase(unittest.TestCase): ...@@ -404,10 +408,10 @@ class AuthTestCase(unittest.TestCase):
""" """
# this should be the default, but we had a bug where the test was doing the wrong # this should be the default, but we had a bug where the test was doing the wrong
# thing, so let's make it explicit # thing, so let's make it explicit
self.hs.config.server_notices_mxid = None self.auth_blocking._server_notices_mxid = None
self.hs.config.hs_disabled = True self.auth_blocking._hs_disabled = True
self.hs.config.hs_disabled_message = "Reason for being disabled" self.auth_blocking._hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e: with self.assertRaises(ResourceLimitError) as e:
yield defer.ensureDeferred(self.auth.check_auth_blocking()) yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
...@@ -416,8 +420,8 @@ class AuthTestCase(unittest.TestCase): ...@@ -416,8 +420,8 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_server_notices_mxid_special_cased(self): def test_server_notices_mxid_special_cased(self):
self.hs.config.hs_disabled = True self.auth_blocking._hs_disabled = True
user = "@user:server" user = "@user:server"
self.hs.config.server_notices_mxid = user self.auth_blocking._server_notices_mxid = user
self.hs.config.hs_disabled_message = "Reason for being disabled" self.auth_blocking._hs_disabled_message = "Reason for being disabled"
yield defer.ensureDeferred(self.auth.check_auth_blocking(user)) yield defer.ensureDeferred(self.auth.check_auth_blocking(user))
...@@ -39,8 +39,13 @@ class AuthTestCase(unittest.TestCase): ...@@ -39,8 +39,13 @@ class AuthTestCase(unittest.TestCase):
self.hs.handlers = AuthHandlers(self.hs) self.hs.handlers = AuthHandlers(self.hs)
self.auth_handler = self.hs.handlers.auth_handler self.auth_handler = self.hs.handlers.auth_handler
self.macaroon_generator = self.hs.get_macaroon_generator() self.macaroon_generator = self.hs.get_macaroon_generator()
# MAU tests # MAU tests
self.hs.config.max_mau_value = 50 # AuthBlocking reads from the hs' config on initialization. We need to
# modify its config instead of the hs'
self.auth_blocking = self.hs.get_auth()._auth_blocking
self.auth_blocking._max_mau_value = 50
self.small_number_of_users = 1 self.small_number_of_users = 1
self.large_number_of_users = 100 self.large_number_of_users = 100
...@@ -119,7 +124,7 @@ class AuthTestCase(unittest.TestCase): ...@@ -119,7 +124,7 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_mau_limits_disabled(self): def test_mau_limits_disabled(self):
self.hs.config.limit_usage_by_mau = False self.auth_blocking._limit_usage_by_mau = False
# Ensure does not throw exception # Ensure does not throw exception
yield defer.ensureDeferred( yield defer.ensureDeferred(
self.auth_handler.get_access_token_for_user_id( self.auth_handler.get_access_token_for_user_id(
...@@ -135,7 +140,7 @@ class AuthTestCase(unittest.TestCase): ...@@ -135,7 +140,7 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_mau_limits_exceeded_large(self): def test_mau_limits_exceeded_large(self):
self.hs.config.limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users) return_value=defer.succeed(self.large_number_of_users)
) )
...@@ -159,11 +164,11 @@ class AuthTestCase(unittest.TestCase): ...@@ -159,11 +164,11 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_mau_limits_parity(self): def test_mau_limits_parity(self):
self.hs.config.limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
# If not in monthly active cohort # If not in monthly active cohort
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value) return_value=defer.succeed(self.auth_blocking._max_mau_value)
) )
with self.assertRaises(ResourceLimitError): with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred( yield defer.ensureDeferred(
...@@ -173,7 +178,7 @@ class AuthTestCase(unittest.TestCase): ...@@ -173,7 +178,7 @@ class AuthTestCase(unittest.TestCase):
) )
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value) return_value=defer.succeed(self.auth_blocking._max_mau_value)
) )
with self.assertRaises(ResourceLimitError): with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred( yield defer.ensureDeferred(
...@@ -186,7 +191,7 @@ class AuthTestCase(unittest.TestCase): ...@@ -186,7 +191,7 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(self.hs.get_clock().time_msec()) return_value=defer.succeed(self.hs.get_clock().time_msec())
) )
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value) return_value=defer.succeed(self.auth_blocking._max_mau_value)
) )
yield defer.ensureDeferred( yield defer.ensureDeferred(
self.auth_handler.get_access_token_for_user_id( self.auth_handler.get_access_token_for_user_id(
...@@ -197,7 +202,7 @@ class AuthTestCase(unittest.TestCase): ...@@ -197,7 +202,7 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(self.hs.get_clock().time_msec()) return_value=defer.succeed(self.hs.get_clock().time_msec())
) )
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value) return_value=defer.succeed(self.auth_blocking._max_mau_value)
) )
yield defer.ensureDeferred( yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id( self.auth_handler.validate_short_term_login_token_and_get_user_id(
...@@ -207,7 +212,7 @@ class AuthTestCase(unittest.TestCase): ...@@ -207,7 +212,7 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_mau_limits_not_exceeded(self): def test_mau_limits_not_exceeded(self):
self.hs.config.limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.small_number_of_users) return_value=defer.succeed(self.small_number_of_users)
......
...@@ -30,28 +30,31 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): ...@@ -30,28 +30,31 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler = self.hs.get_sync_handler() self.sync_handler = self.hs.get_sync_handler()
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
def test_wait_for_sync_for_user_auth_blocking(self): # AuthBlocking reads from the hs' config on initialization. We need to
# modify its config instead of the hs'
self.auth_blocking = self.hs.get_auth()._auth_blocking
def test_wait_for_sync_for_user_auth_blocking(self):
user_id1 = "@user1:test" user_id1 = "@user1:test"
user_id2 = "@user2:test" user_id2 = "@user2:test"
sync_config = self._generate_sync_config(user_id1) sync_config = self._generate_sync_config(user_id1)
self.reactor.advance(100) # So we get not 0 time self.reactor.advance(100) # So we get not 0 time
self.hs.config.limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.hs.config.max_mau_value = 1 self.auth_blocking._max_mau_value = 1
# Check that the happy case does not throw errors # Check that the happy case does not throw errors
self.get_success(self.store.upsert_monthly_active_user(user_id1)) self.get_success(self.store.upsert_monthly_active_user(user_id1))
self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config)) self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config))
# Test that global lock works # Test that global lock works
self.hs.config.hs_disabled = True self.auth_blocking._hs_disabled = True
e = self.get_failure( e = self.get_failure(
self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
) )
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.hs.config.hs_disabled = False self.auth_blocking._hs_disabled = False
sync_config = self._generate_sync_config(user_id2) sync_config = self._generate_sync_config(user_id2)
......
...@@ -19,6 +19,7 @@ import json ...@@ -19,6 +19,7 @@ import json
from mock import Mock from mock import Mock
from synapse.api.auth_blocking import AuthBlocking
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.rest.client.v2_alpha import register, sync from synapse.rest.client.v2_alpha import register, sync
...@@ -45,11 +46,17 @@ class TestMauLimit(unittest.HomeserverTestCase): ...@@ -45,11 +46,17 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.hs.config.hs_disabled = False self.hs.config.hs_disabled = False
self.hs.config.max_mau_value = 2 self.hs.config.max_mau_value = 2
self.hs.config.mau_trial_days = 0
self.hs.config.server_notices_mxid = "@server:red" self.hs.config.server_notices_mxid = "@server:red"
self.hs.config.server_notices_mxid_display_name = None self.hs.config.server_notices_mxid_display_name = None
self.hs.config.server_notices_mxid_avatar_url = None self.hs.config.server_notices_mxid_avatar_url = None
self.hs.config.server_notices_room_name = "Test Server Notice Room" self.hs.config.server_notices_room_name = "Test Server Notice Room"
self.hs.config.mau_trial_days = 0
# AuthBlocking reads config options during hs creation. Recreate the
# hs' copy of AuthBlocking after we've updated config values above
self.auth_blocking = AuthBlocking(self.hs)
self.hs.get_auth()._auth_blocking = self.auth_blocking
return self.hs return self.hs
def test_simple_deny_mau(self): def test_simple_deny_mau(self):
...@@ -121,6 +128,7 @@ class TestMauLimit(unittest.HomeserverTestCase): ...@@ -121,6 +128,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def test_trial_users_cant_come_back(self): def test_trial_users_cant_come_back(self):
self.auth_blocking._mau_trial_days = 1
self.hs.config.mau_trial_days = 1 self.hs.config.mau_trial_days = 1
# We should be able to register more than the limit initially # We should be able to register more than the limit initially
...@@ -169,8 +177,8 @@ class TestMauLimit(unittest.HomeserverTestCase): ...@@ -169,8 +177,8 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def test_tracked_but_not_limited(self): def test_tracked_but_not_limited(self):
self.hs.config.max_mau_value = 1 # should not matter self.auth_blocking._max_mau_value = 1 # should not matter
self.hs.config.limit_usage_by_mau = False self.auth_blocking._limit_usage_by_mau = False
self.hs.config.mau_stats_only = True self.hs.config.mau_stats_only = True
# Simply being able to create 2 users indicates that the # Simply being able to create 2 users indicates that the
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment