Skip to content
Snippets Groups Projects
Commit 651faee6 authored by Erik Johnston's avatar Erik Johnston
Browse files

Add an admin option to shared secret registration

parent caf33b2d
No related branches found
No related tags found
No related merge requests found
......@@ -42,6 +42,7 @@ def request_registration(user, password, server_location, shared_secret, admin=F
"password": password,
"mac": mac,
"type": "org.matrix.login.shared_secret",
"admin": admin,
}
server_location = server_location.rstrip("/")
......@@ -73,7 +74,7 @@ def request_registration(user, password, server_location, shared_secret, admin=F
sys.exit(1)
def register_new_user(user, password, server_location, shared_secret):
def register_new_user(user, password, server_location, shared_secret, admin):
if not user:
try:
default_user = getpass.getuser()
......@@ -104,7 +105,14 @@ def register_new_user(user, password, server_location, shared_secret):
print "Passwords do not match"
sys.exit(1)
request_registration(user, password, server_location, shared_secret)
if not admin:
admin = raw_input("Make admin [no]: ")
if admin in ("y", "yes", "true"):
admin = True
else:
admin = False
request_registration(user, password, server_location, shared_secret, bool(admin))
if __name__ == "__main__":
......@@ -124,6 +132,11 @@ if __name__ == "__main__":
default=None,
help="New password for user. Will prompt if omitted.",
)
parser.add_argument(
"-a", "--admin",
action="store_true",
help="Register new user as an admin. Will prompt if omitted.",
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
......@@ -156,4 +169,4 @@ if __name__ == "__main__":
else:
secret = args.shared_secret
register_new_user(args.user, args.password, args.server_url, secret)
register_new_user(args.user, args.password, args.server_url, secret, args.admin)
......@@ -90,7 +90,8 @@ class RegistrationHandler(BaseHandler):
password=None,
generate_token=True,
guest_access_token=None,
make_guest=False
make_guest=False,
admin=False,
):
"""Registers a new client on the server.
......@@ -141,6 +142,7 @@ class RegistrationHandler(BaseHandler):
# If the user was a guest then they already have a profile
None if was_guest else user.localpart
),
admin=admin,
)
else:
# autogen a sequential user ID
......
......@@ -345,6 +345,7 @@ class RegisterRestServlet(ClientV1RestServlet):
user_id, token = yield handler.register(
localpart=user,
password=password,
admin=bool(admin),
)
self._remove_session(session)
defer.returnValue({
......
......@@ -77,7 +77,7 @@ class RegistrationStore(SQLBaseStore):
@defer.inlineCallbacks
def register(self, user_id, token, password_hash,
was_guest=False, make_guest=False, appservice_id=None,
create_profile_with_localpart=None):
create_profile_with_localpart=None, admin=False):
"""Attempts to register an account.
Args:
......@@ -104,6 +104,7 @@ class RegistrationStore(SQLBaseStore):
make_guest,
appservice_id,
create_profile_with_localpart,
admin
)
self.get_user_by_id.invalidate((user_id,))
self.is_guest.invalidate((user_id,))
......@@ -118,6 +119,7 @@ class RegistrationStore(SQLBaseStore):
make_guest,
appservice_id,
create_profile_with_localpart,
admin,
):
now = int(self.clock.time())
......@@ -125,29 +127,42 @@ class RegistrationStore(SQLBaseStore):
try:
if was_guest:
txn.execute("UPDATE users SET"
" password_hash = ?,"
" upgrade_ts = ?,"
" is_guest = ?"
" WHERE name = ?",
[password_hash, now, 1 if make_guest else 0, user_id])
txn.execute(
"UPDATE users SET"
" password_hash = ?,"
" upgrade_ts = ?,"
" is_guest = ?,"
" admin = ?"
" WHERE name = ?",
(password_hash, now, 1 if make_guest else 0, admin, user_id,)
)
self._simple_update_one_txn(
txn,
"users",
keyvalues={
"name": user_id,
},
updatevalues={
"password_hash": password_hash,
"upgrade_ts": now,
"is_guest": 1 if make_guest else 0,
"appservice_id": appservice_id,
"admin": admin,
}
)
else:
txn.execute("INSERT INTO users "
"("
" name,"
" password_hash,"
" creation_ts,"
" is_guest,"
" appservice_id"
") "
"VALUES (?,?,?,?,?)",
[
user_id,
password_hash,
now,
1 if make_guest else 0,
appservice_id,
])
self._simple_insert_txn(
txn,
"users",
values={
"name": user_id,
"password_hash": password_hash,
"creation_ts": now,
"is_guest": 1 if make_guest else 0,
"appservice_id": appservice_id,
"admin": admin,
}
)
except self.database_engine.module.IntegrityError:
raise StoreError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment