Skip to content
Snippets Groups Projects
Commit 6f0c344c authored by Daniel Wagner-Hall's avatar Daniel Wagner-Hall
Browse files

Merge pull request #255 from matrix-org/mergeeriksmadness

Merge erikj/user_dedup to develop
parents 3cab86a1 d3c0e488
No related branches found
No related tags found
No related merge requests found
...@@ -163,7 +163,8 @@ class AuthHandler(BaseHandler): ...@@ -163,7 +163,8 @@ class AuthHandler(BaseHandler):
if not user_id.startswith('@'): if not user_id.startswith('@'):
user_id = UserID.create(user_id, self.hs.hostname).to_string() user_id = UserID.create(user_id, self.hs.hostname).to_string()
yield self._check_password(user_id, password) user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
self._check_password(user_id, password, password_hash)
defer.returnValue(user_id) defer.returnValue(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
...@@ -280,27 +281,49 @@ class AuthHandler(BaseHandler): ...@@ -280,27 +281,49 @@ class AuthHandler(BaseHandler):
password (str): Password password (str): Password
Returns: Returns:
A tuple of: A tuple of:
The user's ID.
The access token for the user's session. The access token for the user's session.
The refresh token for the user's session. The refresh token for the user's session.
Raises: Raises:
StoreError if there was a problem storing the token. StoreError if there was a problem storing the token.
LoginError if there was an authentication problem. LoginError if there was an authentication problem.
""" """
yield self._check_password(user_id, password) user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
self._check_password(user_id, password, password_hash)
logger.info("Logging in user %s", user_id) logger.info("Logging in user %s", user_id)
access_token = yield self.issue_access_token(user_id) access_token = yield self.issue_access_token(user_id)
refresh_token = yield self.issue_refresh_token(user_id) refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((access_token, refresh_token)) defer.returnValue((user_id, access_token, refresh_token))
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_password(self, user_id, password): def _find_user_id_and_pwd_hash(self, user_id):
"""Checks that user_id has passed password, raises LoginError if not.""" """Checks to see if a user with the given id exists. Will check case
user_info = yield self.store.get_user_by_id(user_id=user_id) insensitively, but will throw if there are multiple inexact matches.
if not user_info:
Returns:
tuple: A 2-tuple of `(canonical_user_id, password_hash)`
"""
user_infos = yield self.store.get_users_by_id_case_insensitive(user_id)
if not user_infos:
logger.warn("Attempted to login as %s but they do not exist", user_id) logger.warn("Attempted to login as %s but they do not exist", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN) raise LoginError(403, "", errcode=Codes.FORBIDDEN)
stored_hash = user_info["password_hash"] if len(user_infos) > 1:
if user_id not in user_infos:
logger.warn(
"Attempted to login as %s but it matches more than one user "
"inexactly: %r",
user_id, user_infos.keys()
)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
defer.returnValue((user_id, user_infos[user_id]))
else:
defer.returnValue(user_infos.popitem())
def _check_password(self, user_id, password, stored_hash):
"""Checks that user_id has passed password, raises LoginError if not."""
if not bcrypt.checkpw(password, stored_hash): if not bcrypt.checkpw(password, stored_hash):
logger.warn("Failed password login for user %s", user_id) logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN) raise LoginError(403, "", errcode=Codes.FORBIDDEN)
......
...@@ -56,8 +56,8 @@ class RegistrationHandler(BaseHandler): ...@@ -56,8 +56,8 @@ class RegistrationHandler(BaseHandler):
yield self.check_user_id_is_valid(user_id) yield self.check_user_id_is_valid(user_id)
u = yield self.store.get_user_by_id(user_id) users = yield self.store.get_users_by_id_case_insensitive(user_id)
if u: if users:
raise SynapseError( raise SynapseError(
400, 400,
"User ID already taken.", "User ID already taken.",
......
...@@ -83,10 +83,11 @@ class LoginRestServlet(ClientV1RestServlet): ...@@ -83,10 +83,11 @@ class LoginRestServlet(ClientV1RestServlet):
if not user_id.startswith('@'): if not user_id.startswith('@'):
user_id = UserID.create( user_id = UserID.create(
user_id, self.hs.hostname).to_string() user_id, self.hs.hostname
).to_string()
auth_handler = self.handlers.auth_handler auth_handler = self.handlers.auth_handler
access_token, refresh_token = yield auth_handler.login_with_password( user_id, access_token, refresh_token = yield auth_handler.login_with_password(
user_id=user_id, user_id=user_id,
password=login_submission["password"]) password=login_submission["password"])
......
...@@ -120,6 +120,20 @@ class RegistrationStore(SQLBaseStore): ...@@ -120,6 +120,20 @@ class RegistrationStore(SQLBaseStore):
allow_none=True, allow_none=True,
) )
def get_users_by_id_case_insensitive(self, user_id):
"""Gets users that match user_id case insensitively.
Returns a mapping of user_id -> password_hash.
"""
def f(txn):
sql = (
"SELECT name, password_hash FROM users"
" WHERE lower(name) = lower(?)"
)
txn.execute(sql, (user_id,))
return dict(txn.fetchall())
return self.runInteraction("get_users_by_id_case_insensitive", f)
@defer.inlineCallbacks @defer.inlineCallbacks
def user_set_password_hash(self, user_id, password_hash): def user_set_password_hash(self, user_id, password_hash):
""" """
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment