Skip to content
Snippets Groups Projects
Commit 09cf1308 authored by Neil Johnson's avatar Neil Johnson
Browse files

only block on sync where user is not part of the mau cohort

parent c6b28fb4
No related branches found
No related tags found
No related merge requests found
......@@ -775,17 +775,26 @@ class Auth(object):
)
@defer.inlineCallbacks
def check_auth_blocking(self):
def check_auth_blocking(self, user_id=None):
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag
Args:
user_id(str): If present, checks for presence against existing MAU cohort
"""
if self.hs.config.hs_disabled:
raise AuthError(
403, self.hs.config.hs_disabled_message, errcode=Codes.HS_DISABLED
)
if self.hs.config.limit_usage_by_mau is True:
# If the user is already part of the MAU cohort
if user_id:
timestamp = yield self.store._user_last_seen_monthly_active(user_id)
if timestamp:
return
# Else if there is no room in the MAU bucket, bail
current_mau = yield self.store.get_monthly_active_count()
if current_mau >= self.hs.config.max_mau_value:
raise AuthError(
403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
)
)
......@@ -208,7 +208,12 @@ class SyncHandler(object):
Returns:
Deferred[SyncResult]
"""
yield self.auth.check_auth_blocking()
# If the user is not part of the mau group, then check that limits have
# not been exceeded (if not part of the group by this point, almost certain
# auth_blocking will occur)
user_id = sync_config.user.to_string()
yield self.auth.check_auth_blocking(user_id)
res = yield self.response_cache.wrap(
sync_config.request_key,
self._wait_for_sync_for_user,
......
......@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import AuthError, Codes
from synapse.api.errors import AuthError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
from synapse.handlers.sync import SyncConfig, SyncHandler
from synapse.types import UserID
......@@ -31,19 +31,41 @@ class SyncTestCase(tests.unittest.TestCase):
def setUp(self):
self.hs = yield setup_test_homeserver()
self.sync_handler = SyncHandler(self.hs)
self.store = self.hs.get_datastore()
@defer.inlineCallbacks
def test_wait_for_sync_for_user_auth_blocking(self):
sync_config = SyncConfig(
user=UserID("@user", "server"),
user_id1 = "@user1:server"
user_id2 = "@user2:server"
sync_config = self._generate_sync_config(user_id1)
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 1
# Check that the happy case does not throw errors
yield self.store.upsert_monthly_active_user(user_id1)
yield self.sync_handler.wait_for_sync_for_user(sync_config)
# Test that global lock works
self.hs.config.hs_disabled = True
with self.assertRaises(AuthError) as e:
yield self.sync_handler.wait_for_sync_for_user(sync_config)
self.assertEquals(e.exception.errcode, Codes.HS_DISABLED)
self.hs.config.hs_disabled = False
sync_config = self._generate_sync_config(user_id2)
with self.assertRaises(AuthError) as e:
yield self.sync_handler.wait_for_sync_for_user(sync_config)
self.assertEquals(e.exception.errcode, Codes.MAU_LIMIT_EXCEEDED)
def _generate_sync_config(self, user_id):
return SyncConfig(
user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]),
filter_collection=DEFAULT_FILTER_COLLECTION,
is_guest=False,
request_key="request_key",
device_id="device_id",
)
# Ensure that an exception is not thrown
yield self.sync_handler.wait_for_sync_for_user(sync_config)
self.hs.config.hs_disabled = True
with self.assertRaises(AuthError):
yield self.sync_handler.wait_for_sync_for_user(sync_config)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment