Skip to content
Snippets Groups Projects
Unverified Commit c64002e1 authored by Richard van der Hoff's avatar Richard van der Hoff Committed by GitHub
Browse files

Refactor `SsoHandler.get_mxid_from_sso` (#8900)

* Factor out _call_attribute_mapper and _register_mapped_user

This is mostly an attempt to simplify `get_mxid_from_sso`.

* Move mapping_lock down into SsoHandler.
parent 1821f7cc
No related branches found
No related tags found
No related merge requests found
Add support for allowing users to pick their own user ID during a single-sign-on login.
...@@ -34,7 +34,6 @@ from synapse.types import ( ...@@ -34,7 +34,6 @@ from synapse.types import (
map_username_to_mxid_localpart, map_username_to_mxid_localpart,
mxid_localpart_allowed_characters, mxid_localpart_allowed_characters,
) )
from synapse.util.async_helpers import Linearizer
from synapse.util.iterutils import chunk_seq from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -81,9 +80,6 @@ class SamlHandler(BaseHandler): ...@@ -81,9 +80,6 @@ class SamlHandler(BaseHandler):
# a map from saml session id to Saml2SessionData object # a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
# a lock on the mappings
self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
self._sso_handler = hs.get_sso_handler() self._sso_handler = hs.get_sso_handler()
def handle_redirect_request( def handle_redirect_request(
...@@ -299,15 +295,14 @@ class SamlHandler(BaseHandler): ...@@ -299,15 +295,14 @@ class SamlHandler(BaseHandler):
return None return None
with (await self._mapping_lock.queue(self._auth_provider_id)): return await self._sso_handler.get_mxid_from_sso(
return await self._sso_handler.get_mxid_from_sso( self._auth_provider_id,
self._auth_provider_id, remote_user_id,
remote_user_id, user_agent,
user_agent, ip_address,
ip_address, saml_response_to_remapped_user_attributes,
saml_response_to_remapped_user_attributes, grandfather_existing_users,
grandfather_existing_users, )
)
def _remote_id_from_saml_response( def _remote_id_from_saml_response(
self, self,
......
...@@ -22,6 +22,7 @@ from twisted.web.http import Request ...@@ -22,6 +22,7 @@ from twisted.web.http import Request
from synapse.api.errors import RedirectException from synapse.api.errors import RedirectException
from synapse.http.server import respond_with_html from synapse.http.server import respond_with_html
from synapse.types import UserID, contains_invalid_mxid_characters from synapse.types import UserID, contains_invalid_mxid_characters
from synapse.util.async_helpers import Linearizer
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
...@@ -54,6 +55,9 @@ class SsoHandler: ...@@ -54,6 +55,9 @@ class SsoHandler:
self._error_template = hs.config.sso_error_template self._error_template = hs.config.sso_error_template
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
# a lock on the mappings
self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
def render_error( def render_error(
self, request, error: str, error_description: Optional[str] = None self, request, error: str, error_description: Optional[str] = None
) -> None: ) -> None:
...@@ -172,24 +176,38 @@ class SsoHandler: ...@@ -172,24 +176,38 @@ class SsoHandler:
to an additional page. (e.g. to prompt for more information) to an additional page. (e.g. to prompt for more information)
""" """
# first of all, check if we already have a mapping for this user # grab a lock while we try to find a mapping for this user. This seems...
previously_registered_user_id = await self.get_sso_user_by_remote_user_id( # optimistic, especially for implementations that end up redirecting to
auth_provider_id, remote_user_id, # interstitial pages.
) with await self._mapping_lock.queue(auth_provider_id):
if previously_registered_user_id: # first of all, check if we already have a mapping for this user
return previously_registered_user_id previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
auth_provider_id, remote_user_id,
# Check for grandfathering of users. )
if grandfather_existing_users:
previously_registered_user_id = await grandfather_existing_users()
if previously_registered_user_id: if previously_registered_user_id:
# Future logins should also match this user ID.
await self._store.record_user_external_id(
auth_provider_id, remote_user_id, previously_registered_user_id
)
return previously_registered_user_id return previously_registered_user_id
# Otherwise, generate a new user. # Check for grandfathering of users.
if grandfather_existing_users:
previously_registered_user_id = await grandfather_existing_users()
if previously_registered_user_id:
# Future logins should also match this user ID.
await self._store.record_user_external_id(
auth_provider_id, remote_user_id, previously_registered_user_id
)
return previously_registered_user_id
# Otherwise, generate a new user.
attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
user_id = await self._register_mapped_user(
attributes, auth_provider_id, remote_user_id, user_agent, ip_address,
)
return user_id
async def _call_attribute_mapper(
self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
) -> UserAttributes:
"""Call the attribute mapper function in a loop, until we get a unique userid"""
for i in range(self._MAP_USERNAME_RETRIES): for i in range(self._MAP_USERNAME_RETRIES):
try: try:
attributes = await sso_to_matrix_id_mapper(i) attributes = await sso_to_matrix_id_mapper(i)
...@@ -227,7 +245,16 @@ class SsoHandler: ...@@ -227,7 +245,16 @@ class SsoHandler:
raise MappingException( raise MappingException(
"Unable to generate a Matrix ID from the SSO response" "Unable to generate a Matrix ID from the SSO response"
) )
return attributes
async def _register_mapped_user(
self,
attributes: UserAttributes,
auth_provider_id: str,
remote_user_id: str,
user_agent: str,
ip_address: str,
) -> str:
# Since the localpart is provided via a potentially untrusted module, # Since the localpart is provided via a potentially untrusted module,
# ensure the MXID is valid before registering. # ensure the MXID is valid before registering.
if contains_invalid_mxid_characters(attributes.localpart): if contains_invalid_mxid_characters(attributes.localpart):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment