Skip to content
Snippets Groups Projects
Commit 747aa9f8 authored by Brendan Abolivier's avatar Brendan Abolivier
Browse files

Add account expiration feature

parent 35442efb
No related branches found
No related tags found
No related merge requests found
Add time-based account expiration.
......@@ -643,6 +643,12 @@ uploads_path: "DATADIR/uploads"
#
#enable_registration: false
# Optional account validity parameter. This allows for, e.g., accounts to
# be denied any request after a given period.
#
#account_validity:
# period: 6w
# The user must provide all of the below types of 3PID when registering.
#
#registrations_require_3pid:
......
......@@ -64,6 +64,8 @@ class Auth(object):
self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000)
register_cache("cache", "token_cache", self.token_cache)
self._account_validity = hs.config.account_validity
@defer.inlineCallbacks
def check_from_context(self, room_version, event, context, do_sig_check=True):
prev_state_ids = yield context.get_prev_state_ids(self.store)
......@@ -226,6 +228,16 @@ class Auth(object):
token_id = user_info["token_id"]
is_guest = user_info["is_guest"]
# Deny the request if the user account has expired.
if self._account_validity.enabled:
expiration_ts = yield self.store.get_expiration_ts_for_user(user)
if self.clock.time_msec() >= expiration_ts:
raise AuthError(
403,
"User account has expired",
errcode=Codes.EXPIRED_ACCOUNT,
)
# device_id may not be present if get_user_by_access_token has been
# stubbed out.
device_id = user_info.get("device_id")
......
......@@ -60,6 +60,7 @@ class Codes(object):
UNSUPPORTED_ROOM_VERSION = "M_UNSUPPORTED_ROOM_VERSION"
INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION"
WRONG_ROOM_KEYS_VERSION = "M_WRONG_ROOM_KEYS_VERSION"
EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT"
class CodeMessageException(RuntimeError):
......
......@@ -20,6 +20,15 @@ from synapse.types import RoomAlias
from synapse.util.stringutils import random_string_with_symbols
class AccountValidityConfig(Config):
def __init__(self, config):
self.enabled = (len(config) > 0)
period = config.get("period", None)
if period:
self.period = self.parse_duration(period)
class RegistrationConfig(Config):
def read_config(self, config):
......@@ -31,6 +40,8 @@ class RegistrationConfig(Config):
strtobool(str(config["disable_registration"]))
)
self.account_validity = AccountValidityConfig(config.get("account_validity", {}))
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
self.registration_shared_secret = config.get("registration_shared_secret")
......@@ -75,6 +86,12 @@ class RegistrationConfig(Config):
#
#enable_registration: false
# Optional account validity parameter. This allows for, e.g., accounts to
# be denied any request after a given period.
#
#account_validity:
# period: 6w
# The user must provide all of the below types of 3PID when registering.
#
#registrations_require_3pid:
......
......@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 53
SCHEMA_VERSION = 54
dir_path = os.path.abspath(os.path.dirname(__file__))
......
......@@ -86,6 +86,26 @@ class RegistrationWorkerStore(SQLBaseStore):
token
)
@cachedInlineCallbacks()
def get_expiration_ts_for_user(self, user):
"""Get the expiration timestamp for the account bearing a given user ID.
Args:
user (str): The ID of the user.
Returns:
defer.Deferred: None, if the account has no expiration timestamp,
otherwise int representation of the timestamp (as a number of
milliseconds since epoch).
"""
res = yield self._simple_select_one_onecol(
table="account_validity",
keyvalues={"user_id": user.to_string()},
retcol="expiration_ts_ms",
allow_none=True,
desc="get_expiration_date_for_user",
)
defer.returnValue(res)
@defer.inlineCallbacks
def is_server_admin(self, user):
res = yield self._simple_select_one_onecol(
......@@ -351,6 +371,8 @@ class RegistrationStore(RegistrationWorkerStore,
columns=["creation_ts"],
)
self._account_validity = hs.config.account_validity
# we no longer use refresh tokens, but it's possible that some people
# might have a background update queued to build this index. Just
# clear the background update.
......@@ -485,6 +507,18 @@ class RegistrationStore(RegistrationWorkerStore,
"user_type": user_type,
}
)
if self._account_validity.enabled:
now_ms = self.clock.time_msec()
expiration_ts = now_ms + self._account_validity.period
self._simple_insert_txn(
txn,
"account_validity",
values={
"user_id": user_id,
"expiration_ts_ms": expiration_ts,
}
)
except self.database_engine.module.IntegrityError:
raise StoreError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE
......
/* Copyright 2019 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Track what users are in public rooms.
CREATE TABLE IF NOT EXISTS account_validity (
user_id TEXT PRIMARY KEY,
expiration_ts_ms BIGINT NOT NULL
);
import datetime
import json
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
from synapse.rest.client.v2_alpha.register import register_servlets
from synapse.rest.client.v1 import admin, login
from synapse.rest.client.v2_alpha import register, sync
from tests import unittest
class RegisterRestServletTestCase(unittest.HomeserverTestCase):
servlets = [register_servlets]
servlets = [register.register_servlets]
def make_homeserver(self, reactor, clock):
......@@ -181,3 +184,47 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
class AccountValidityTestCase(unittest.HomeserverTestCase):
servlets = [
register.register_servlets,
admin.register_servlets,
login.register_servlets,
sync.register_servlets,
]
def make_homeserver(self, reactor, clock):
config = self.default_config()
config.enable_registration = True
config.account_validity.enabled = True
config.account_validity.period = 604800000 # Time in ms for 1 week
self.hs = self.setup_test_homeserver(config=config)
return self.hs
def test_validity_period(self):
self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey")
# The specific endpoint doesn't matter, all we need is an authenticated
# endpoint.
request, channel = self.make_request(
b"GET", "/sync", access_token=tok,
)
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)
self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
request, channel = self.make_request(
b"GET", "/sync", access_token=tok,
)
self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result,
)
......@@ -24,7 +24,7 @@ from synapse.state import StateHandler, StateResolutionHandler
from tests import unittest
from .utils import MockClock
from .utils import MockClock, default_config
_next_event_id = 1000
......@@ -159,6 +159,7 @@ class StateTestCase(unittest.TestCase):
self.store = StateGroupStore()
hs = Mock(
spec_set=[
"config",
"get_datastore",
"get_auth",
"get_state_handler",
......@@ -166,6 +167,7 @@ class StateTestCase(unittest.TestCase):
"get_state_resolution_handler",
]
)
hs.config = default_config("tesths")
hs.get_datastore.return_value = self.store
hs.get_state_handler.return_value = None
hs.get_clock.return_value = MockClock()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment