Skip to content
Snippets Groups Projects
Commit 74c56f79 authored by Richard van der Hoff's avatar Richard van der Hoff
Browse files

Break dependency of auth_handler on device_handler

I'm going to need to make the device_handler depend on the auth_handler, so I
need to break this dependency to avoid a cycle.

It turns out that the auth_handler was only using the device_handler in one
place which was an edge case which we can more elegantly handle by throwing an
error rather than fixing it up.
parent a72e4e3e
No related branches found
No related tags found
No related merge requests found
...@@ -75,7 +75,6 @@ class AuthHandler(BaseHandler): ...@@ -75,7 +75,6 @@ class AuthHandler(BaseHandler):
logger.info("Extra password_providers: %r", self.password_providers) logger.info("Extra password_providers: %r", self.password_providers)
self.hs = hs # FIXME better possibility to access registrationHandler later? self.hs = hs # FIXME better possibility to access registrationHandler later?
self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator() self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.password_enabled self._password_enabled = hs.config.password_enabled
...@@ -406,8 +405,7 @@ class AuthHandler(BaseHandler): ...@@ -406,8 +405,7 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id] return self.sessions[session_id]
@defer.inlineCallbacks @defer.inlineCallbacks
def get_access_token_for_user_id(self, user_id, device_id=None, def get_access_token_for_user_id(self, user_id, device_id=None):
initial_display_name=None):
""" """
Creates a new access token for the user with the given user ID. Creates a new access token for the user with the given user ID.
...@@ -421,13 +419,10 @@ class AuthHandler(BaseHandler): ...@@ -421,13 +419,10 @@ class AuthHandler(BaseHandler):
device_id (str|None): the device ID to associate with the tokens. device_id (str|None): the device ID to associate with the tokens.
None to leave the tokens unassociated with a device (deprecated: None to leave the tokens unassociated with a device (deprecated:
we should always have a device ID) we should always have a device ID)
initial_display_name (str): display name to associate with the
device if it needs re-registering
Returns: Returns:
The access token for the user's session. The access token for the user's session.
Raises: Raises:
StoreError if there was a problem storing the token. StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
""" """
logger.info("Logging in user %s on device %s", user_id, device_id) logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id)
...@@ -437,9 +432,11 @@ class AuthHandler(BaseHandler): ...@@ -437,9 +432,11 @@ class AuthHandler(BaseHandler):
# really don't want is active access_tokens without a record of the # really don't want is active access_tokens without a record of the
# device, so we double-check it here. # device, so we double-check it here.
if device_id is not None: if device_id is not None:
yield self.device_handler.check_device_registered( try:
user_id, device_id, initial_display_name yield self.store.get_device(user_id, device_id)
) except StoreError:
yield self.store.delete_access_token(access_token)
raise StoreError(400, "Login raced against device deletion")
defer.returnValue(access_token) defer.returnValue(access_token)
......
...@@ -219,7 +219,6 @@ class LoginRestServlet(ClientV1RestServlet): ...@@ -219,7 +219,6 @@ class LoginRestServlet(ClientV1RestServlet):
) )
access_token = yield auth_handler.get_access_token_for_user_id( access_token = yield auth_handler.get_access_token_for_user_id(
canonical_user_id, device_id, canonical_user_id, device_id,
login_submission.get("initial_device_display_name"),
) )
result = { result = {
...@@ -241,7 +240,6 @@ class LoginRestServlet(ClientV1RestServlet): ...@@ -241,7 +240,6 @@ class LoginRestServlet(ClientV1RestServlet):
device_id = yield self._register_device(user_id, login_submission) device_id = yield self._register_device(user_id, login_submission)
access_token = yield auth_handler.get_access_token_for_user_id( access_token = yield auth_handler.get_access_token_for_user_id(
user_id, device_id, user_id, device_id,
login_submission.get("initial_device_display_name"),
) )
result = { result = {
"user_id": user_id, # may have changed "user_id": user_id, # may have changed
...@@ -284,7 +282,6 @@ class LoginRestServlet(ClientV1RestServlet): ...@@ -284,7 +282,6 @@ class LoginRestServlet(ClientV1RestServlet):
) )
access_token = yield auth_handler.get_access_token_for_user_id( access_token = yield auth_handler.get_access_token_for_user_id(
registered_user_id, device_id, registered_user_id, device_id,
login_submission.get("initial_device_display_name"),
) )
result = { result = {
......
...@@ -566,7 +566,6 @@ class RegisterRestServlet(RestServlet): ...@@ -566,7 +566,6 @@ class RegisterRestServlet(RestServlet):
access_token = ( access_token = (
yield self.auth_handler.get_access_token_for_user_id( yield self.auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id, user_id, device_id=device_id,
initial_display_name=params.get("initial_device_display_name")
) )
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment