Skip to content
Snippets Groups Projects
Unverified Commit ee382025 authored by Patrick Cloke's avatar Patrick Cloke Committed by GitHub
Browse files

Abstract shared SSO code. (#8765)

De-duplicates code between the SAML and OIDC implementations.
parent e487d9fa
No related branches found
No related tags found
No related merge requests found
Consolidate logic between the OpenID Connect and SAML code.
...@@ -34,7 +34,8 @@ from typing_extensions import TypedDict ...@@ -34,7 +34,8 @@ from typing_extensions import TypedDict
from twisted.web.client import readBody from twisted.web.client import readBody
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.http.server import respond_with_html from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
...@@ -83,17 +84,12 @@ class OidcError(Exception): ...@@ -83,17 +84,12 @@ class OidcError(Exception):
return self.error return self.error
class MappingException(Exception): class OidcHandler(BaseHandler):
"""Used to catch errors when mapping the UserInfo object
"""
class OidcHandler:
"""Handles requests related to the OpenID Connect login flow. """Handles requests related to the OpenID Connect login flow.
""" """
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.hs = hs super().__init__(hs)
self._callback_url = hs.config.oidc_callback_url # type: str self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str] self._scopes = hs.config.oidc_scopes # type: List[str]
self._user_profile_method = hs.config.oidc_user_profile_method # type: str self._user_profile_method = hs.config.oidc_user_profile_method # type: str
...@@ -120,36 +116,13 @@ class OidcHandler: ...@@ -120,36 +116,13 @@ class OidcHandler:
self._http_client = hs.get_proxied_http_client() self._http_client = hs.get_proxied_http_client()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler() self._registration_handler = hs.get_registration_handler()
self._datastore = hs.get_datastore()
self._clock = hs.get_clock()
self._hostname = hs.hostname # type: str
self._server_name = hs.config.server_name # type: str self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key self._macaroon_secret_key = hs.config.macaroon_secret_key
self._error_template = hs.config.sso_error_template
# identifier for the external_ids table # identifier for the external_ids table
self._auth_provider_id = "oidc" self._auth_provider_id = "oidc"
def _render_error( self._sso_handler = hs.get_sso_handler()
self, request, error: str, error_description: Optional[str] = None
) -> None:
"""Render the error template and respond to the request with it.
This is used to show errors to the user. The template of this page can
be found under `synapse/res/templates/sso_error.html`.
Args:
request: The incoming request from the browser.
We'll respond with an HTML page describing the error.
error: A technical identifier for this error. Those include
well-known OAuth2/OIDC error types like invalid_request or
access_denied.
error_description: A human-readable description of the error.
"""
html = self._error_template.render(
error=error, error_description=error_description
)
respond_with_html(request, 400, html)
def _validate_metadata(self): def _validate_metadata(self):
"""Verifies the provider metadata. """Verifies the provider metadata.
...@@ -571,7 +544,7 @@ class OidcHandler: ...@@ -571,7 +544,7 @@ class OidcHandler:
Since we might want to display OIDC-related errors in a user-friendly Since we might want to display OIDC-related errors in a user-friendly
way, we don't raise SynapseError from here. Instead, we call way, we don't raise SynapseError from here. Instead, we call
``self._render_error`` which displays an HTML page for the error. ``self._sso_handler.render_error`` which displays an HTML page for the error.
Most of the OpenID Connect logic happens here: Most of the OpenID Connect logic happens here:
...@@ -609,7 +582,7 @@ class OidcHandler: ...@@ -609,7 +582,7 @@ class OidcHandler:
if error != "access_denied": if error != "access_denied":
logger.error("Error from the OIDC provider: %s %s", error, description) logger.error("Error from the OIDC provider: %s %s", error, description)
self._render_error(request, error, description) self._sso_handler.render_error(request, error, description)
return return
# otherwise, it is presumably a successful response. see: # otherwise, it is presumably a successful response. see:
...@@ -619,7 +592,9 @@ class OidcHandler: ...@@ -619,7 +592,9 @@ class OidcHandler:
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes] session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
if session is None: if session is None:
logger.info("No session cookie found") logger.info("No session cookie found")
self._render_error(request, "missing_session", "No session cookie found") self._sso_handler.render_error(
request, "missing_session", "No session cookie found"
)
return return
# Remove the cookie. There is a good chance that if the callback failed # Remove the cookie. There is a good chance that if the callback failed
...@@ -637,7 +612,9 @@ class OidcHandler: ...@@ -637,7 +612,9 @@ class OidcHandler:
# Check for the state query parameter # Check for the state query parameter
if b"state" not in request.args: if b"state" not in request.args:
logger.info("State parameter is missing") logger.info("State parameter is missing")
self._render_error(request, "invalid_request", "State parameter is missing") self._sso_handler.render_error(
request, "invalid_request", "State parameter is missing"
)
return return
state = request.args[b"state"][0].decode() state = request.args[b"state"][0].decode()
...@@ -651,17 +628,19 @@ class OidcHandler: ...@@ -651,17 +628,19 @@ class OidcHandler:
) = self._verify_oidc_session_token(session, state) ) = self._verify_oidc_session_token(session, state)
except MacaroonDeserializationException as e: except MacaroonDeserializationException as e:
logger.exception("Invalid session") logger.exception("Invalid session")
self._render_error(request, "invalid_session", str(e)) self._sso_handler.render_error(request, "invalid_session", str(e))
return return
except MacaroonInvalidSignatureException as e: except MacaroonInvalidSignatureException as e:
logger.exception("Could not verify session") logger.exception("Could not verify session")
self._render_error(request, "mismatching_session", str(e)) self._sso_handler.render_error(request, "mismatching_session", str(e))
return return
# Exchange the code with the provider # Exchange the code with the provider
if b"code" not in request.args: if b"code" not in request.args:
logger.info("Code parameter is missing") logger.info("Code parameter is missing")
self._render_error(request, "invalid_request", "Code parameter is missing") self._sso_handler.render_error(
request, "invalid_request", "Code parameter is missing"
)
return return
logger.debug("Exchanging code") logger.debug("Exchanging code")
...@@ -670,7 +649,7 @@ class OidcHandler: ...@@ -670,7 +649,7 @@ class OidcHandler:
token = await self._exchange_code(code) token = await self._exchange_code(code)
except OidcError as e: except OidcError as e:
logger.exception("Could not exchange code") logger.exception("Could not exchange code")
self._render_error(request, e.error, e.error_description) self._sso_handler.render_error(request, e.error, e.error_description)
return return
logger.debug("Successfully obtained OAuth2 access token") logger.debug("Successfully obtained OAuth2 access token")
...@@ -683,7 +662,7 @@ class OidcHandler: ...@@ -683,7 +662,7 @@ class OidcHandler:
userinfo = await self._fetch_userinfo(token) userinfo = await self._fetch_userinfo(token)
except Exception as e: except Exception as e:
logger.exception("Could not fetch userinfo") logger.exception("Could not fetch userinfo")
self._render_error(request, "fetch_error", str(e)) self._sso_handler.render_error(request, "fetch_error", str(e))
return return
else: else:
logger.debug("Extracting userinfo from id_token") logger.debug("Extracting userinfo from id_token")
...@@ -691,7 +670,7 @@ class OidcHandler: ...@@ -691,7 +670,7 @@ class OidcHandler:
userinfo = await self._parse_id_token(token, nonce=nonce) userinfo = await self._parse_id_token(token, nonce=nonce)
except Exception as e: except Exception as e:
logger.exception("Invalid id_token") logger.exception("Invalid id_token")
self._render_error(request, "invalid_token", str(e)) self._sso_handler.render_error(request, "invalid_token", str(e))
return return
# Pull out the user-agent and IP from the request. # Pull out the user-agent and IP from the request.
...@@ -705,7 +684,7 @@ class OidcHandler: ...@@ -705,7 +684,7 @@ class OidcHandler:
) )
except MappingException as e: except MappingException as e:
logger.exception("Could not map user") logger.exception("Could not map user")
self._render_error(request, "mapping_error", str(e)) self._sso_handler.render_error(request, "mapping_error", str(e))
return return
# Mapping providers might not have get_extra_attributes: only call this # Mapping providers might not have get_extra_attributes: only call this
...@@ -770,7 +749,7 @@ class OidcHandler: ...@@ -770,7 +749,7 @@ class OidcHandler:
macaroon.add_first_party_caveat( macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (ui_auth_session_id,) "ui_auth_session_id = %s" % (ui_auth_session_id,)
) )
now = self._clock.time_msec() now = self.clock.time_msec()
expiry = now + duration_in_ms expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,)) macaroon.add_first_party_caveat("time < %d" % (expiry,))
...@@ -845,7 +824,7 @@ class OidcHandler: ...@@ -845,7 +824,7 @@ class OidcHandler:
if not caveat.startswith(prefix): if not caveat.startswith(prefix):
return False return False
expiry = int(caveat[len(prefix) :]) expiry = int(caveat[len(prefix) :])
now = self._clock.time_msec() now = self.clock.time_msec()
return now < expiry return now < expiry
async def _map_userinfo_to_user( async def _map_userinfo_to_user(
...@@ -885,20 +864,14 @@ class OidcHandler: ...@@ -885,20 +864,14 @@ class OidcHandler:
# to be strings. # to be strings.
remote_user_id = str(remote_user_id) remote_user_id = str(remote_user_id)
logger.info( # first of all, check if we already have a mapping for this user
"Looking for existing mapping for user %s:%s", previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
self._auth_provider_id,
remote_user_id,
)
registered_user_id = await self._datastore.get_user_by_external_id(
self._auth_provider_id, remote_user_id, self._auth_provider_id, remote_user_id,
) )
if previously_registered_user_id:
return previously_registered_user_id
if registered_user_id is not None: # Otherwise, generate a new user.
logger.info("Found existing mapping %s", registered_user_id)
return registered_user_id
try: try:
attributes = await self._user_mapping_provider.map_user_attributes( attributes = await self._user_mapping_provider.map_user_attributes(
userinfo, token userinfo, token
...@@ -917,8 +890,8 @@ class OidcHandler: ...@@ -917,8 +890,8 @@ class OidcHandler:
localpart = map_username_to_mxid_localpart(attributes["localpart"]) localpart = map_username_to_mxid_localpart(attributes["localpart"])
user_id = UserID(localpart, self._hostname).to_string() user_id = UserID(localpart, self.server_name).to_string()
users = await self._datastore.get_users_by_id_case_insensitive(user_id) users = await self.store.get_users_by_id_case_insensitive(user_id)
if users: if users:
if self._allow_existing_users: if self._allow_existing_users:
if len(users) == 1: if len(users) == 1:
...@@ -942,7 +915,8 @@ class OidcHandler: ...@@ -942,7 +915,8 @@ class OidcHandler:
default_display_name=attributes["display_name"], default_display_name=attributes["display_name"],
user_agent_ips=(user_agent, ip_address), user_agent_ips=(user_agent, ip_address),
) )
await self._datastore.record_user_external_id(
await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id, self._auth_provider_id, remote_user_id, registered_user_id,
) )
return registered_user_id return registered_user_id
......
...@@ -24,7 +24,8 @@ from saml2.client import Saml2Client ...@@ -24,7 +24,8 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.config.saml2_config import SamlAttributeRequirement from synapse.config.saml2_config import SamlAttributeRequirement
from synapse.http.server import respond_with_html from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException
from synapse.http.servlet import parse_string from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi from synapse.module_api import ModuleApi
...@@ -42,10 +43,6 @@ if TYPE_CHECKING: ...@@ -42,10 +43,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class MappingException(Exception):
"""Used to catch errors when mapping the SAML2 response to a user."""
@attr.s(slots=True) @attr.s(slots=True)
class Saml2SessionData: class Saml2SessionData:
"""Data we track about SAML2 sessions""" """Data we track about SAML2 sessions"""
...@@ -57,17 +54,13 @@ class Saml2SessionData: ...@@ -57,17 +54,13 @@ class Saml2SessionData:
ui_auth_session_id = attr.ib(type=Optional[str], default=None) ui_auth_session_id = attr.ib(type=Optional[str], default=None)
class SamlHandler: class SamlHandler(BaseHandler):
def __init__(self, hs: "synapse.server.HomeServer"): def __init__(self, hs: "synapse.server.HomeServer"):
self.hs = hs super().__init__(hs)
self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler() self._registration_handler = hs.get_registration_handler()
self._clock = hs.get_clock()
self._datastore = hs.get_datastore()
self._hostname = hs.hostname
self._saml2_session_lifetime = hs.config.saml2_session_lifetime self._saml2_session_lifetime = hs.config.saml2_session_lifetime
self._grandfathered_mxid_source_attribute = ( self._grandfathered_mxid_source_attribute = (
hs.config.saml2_grandfathered_mxid_source_attribute hs.config.saml2_grandfathered_mxid_source_attribute
...@@ -88,26 +81,9 @@ class SamlHandler: ...@@ -88,26 +81,9 @@ class SamlHandler:
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
# a lock on the mappings # a lock on the mappings
self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock) self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
def _render_error(
self, request, error: str, error_description: Optional[str] = None
) -> None:
"""Render the error template and respond to the request with it.
This is used to show errors to the user. The template of this page can self._sso_handler = hs.get_sso_handler()
be found under `synapse/res/templates/sso_error.html`.
Args:
request: The incoming request from the browser.
We'll respond with an HTML page describing the error.
error: A technical identifier for this error.
error_description: A human-readable description of the error.
"""
html = self._error_template.render(
error=error, error_description=error_description
)
respond_with_html(request, 400, html)
def handle_redirect_request( def handle_redirect_request(
self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
...@@ -130,7 +106,7 @@ class SamlHandler: ...@@ -130,7 +106,7 @@ class SamlHandler:
# Since SAML sessions timeout it is useful to log when they were created. # Since SAML sessions timeout it is useful to log when they were created.
logger.info("Initiating a new SAML session: %s" % (reqid,)) logger.info("Initiating a new SAML session: %s" % (reqid,))
now = self._clock.time_msec() now = self.clock.time_msec()
self._outstanding_requests_dict[reqid] = Saml2SessionData( self._outstanding_requests_dict[reqid] = Saml2SessionData(
creation_time=now, ui_auth_session_id=ui_auth_session_id, creation_time=now, ui_auth_session_id=ui_auth_session_id,
) )
...@@ -171,12 +147,12 @@ class SamlHandler: ...@@ -171,12 +147,12 @@ class SamlHandler:
# in the (user-visible) exception message, so let's log the exception here # in the (user-visible) exception message, so let's log the exception here
# so we can track down the session IDs later. # so we can track down the session IDs later.
logger.warning(str(e)) logger.warning(str(e))
self._render_error( self._sso_handler.render_error(
request, "unsolicited_response", "Unexpected SAML2 login." request, "unsolicited_response", "Unexpected SAML2 login."
) )
return return
except Exception as e: except Exception as e:
self._render_error( self._sso_handler.render_error(
request, request,
"invalid_response", "invalid_response",
"Unable to parse SAML2 response: %s." % (e,), "Unable to parse SAML2 response: %s." % (e,),
...@@ -184,7 +160,7 @@ class SamlHandler: ...@@ -184,7 +160,7 @@ class SamlHandler:
return return
if saml2_auth.not_signed: if saml2_auth.not_signed:
self._render_error( self._sso_handler.render_error(
request, "unsigned_respond", "SAML2 response was not signed." request, "unsigned_respond", "SAML2 response was not signed."
) )
return return
...@@ -210,7 +186,7 @@ class SamlHandler: ...@@ -210,7 +186,7 @@ class SamlHandler:
# attributes. # attributes.
for requirement in self._saml2_attribute_requirements: for requirement in self._saml2_attribute_requirements:
if not _check_attribute_requirement(saml2_auth.ava, requirement): if not _check_attribute_requirement(saml2_auth.ava, requirement):
self._render_error( self._sso_handler.render_error(
request, "unauthorised", "You are not authorised to log in here." request, "unauthorised", "You are not authorised to log in here."
) )
return return
...@@ -226,7 +202,7 @@ class SamlHandler: ...@@ -226,7 +202,7 @@ class SamlHandler:
) )
except MappingException as e: except MappingException as e:
logger.exception("Could not map user") logger.exception("Could not map user")
self._render_error(request, "mapping_error", str(e)) self._sso_handler.render_error(request, "mapping_error", str(e))
return return
# Complete the interactive auth session or the login. # Complete the interactive auth session or the login.
...@@ -274,17 +250,11 @@ class SamlHandler: ...@@ -274,17 +250,11 @@ class SamlHandler:
with (await self._mapping_lock.queue(self._auth_provider_id)): with (await self._mapping_lock.queue(self._auth_provider_id)):
# first of all, check if we already have a mapping for this user # first of all, check if we already have a mapping for this user
logger.info( previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
"Looking for existing mapping for user %s:%s", self._auth_provider_id, remote_user_id,
self._auth_provider_id,
remote_user_id,
) )
registered_user_id = await self._datastore.get_user_by_external_id( if previously_registered_user_id:
self._auth_provider_id, remote_user_id return previously_registered_user_id
)
if registered_user_id is not None:
logger.info("Found existing mapping %s", registered_user_id)
return registered_user_id
# backwards-compatibility hack: see if there is an existing user with a # backwards-compatibility hack: see if there is an existing user with a
# suitable mapping from the uid # suitable mapping from the uid
...@@ -294,7 +264,7 @@ class SamlHandler: ...@@ -294,7 +264,7 @@ class SamlHandler:
): ):
attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0] attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
user_id = UserID( user_id = UserID(
map_username_to_mxid_localpart(attrval), self._hostname map_username_to_mxid_localpart(attrval), self.server_name
).to_string() ).to_string()
logger.info( logger.info(
"Looking for existing account based on mapped %s %s", "Looking for existing account based on mapped %s %s",
...@@ -302,11 +272,11 @@ class SamlHandler: ...@@ -302,11 +272,11 @@ class SamlHandler:
user_id, user_id,
) )
users = await self._datastore.get_users_by_id_case_insensitive(user_id) users = await self.store.get_users_by_id_case_insensitive(user_id)
if users: if users:
registered_user_id = list(users.keys())[0] registered_user_id = list(users.keys())[0]
logger.info("Grandfathering mapping to %s", registered_user_id) logger.info("Grandfathering mapping to %s", registered_user_id)
await self._datastore.record_user_external_id( await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id self._auth_provider_id, remote_user_id, registered_user_id
) )
return registered_user_id return registered_user_id
...@@ -335,8 +305,8 @@ class SamlHandler: ...@@ -335,8 +305,8 @@ class SamlHandler:
emails = attribute_dict.get("emails", []) emails = attribute_dict.get("emails", [])
# Check if this mxid already exists # Check if this mxid already exists
if not await self._datastore.get_users_by_id_case_insensitive( if not await self.store.get_users_by_id_case_insensitive(
UserID(localpart, self._hostname).to_string() UserID(localpart, self.server_name).to_string()
): ):
# This mxid is free # This mxid is free
break break
...@@ -348,7 +318,6 @@ class SamlHandler: ...@@ -348,7 +318,6 @@ class SamlHandler:
) )
logger.info("Mapped SAML user to local part %s", localpart) logger.info("Mapped SAML user to local part %s", localpart)
registered_user_id = await self._registration_handler.register_user( registered_user_id = await self._registration_handler.register_user(
localpart=localpart, localpart=localpart,
default_display_name=displayname, default_display_name=displayname,
...@@ -356,13 +325,13 @@ class SamlHandler: ...@@ -356,13 +325,13 @@ class SamlHandler:
user_agent_ips=(user_agent, ip_address), user_agent_ips=(user_agent, ip_address),
) )
await self._datastore.record_user_external_id( await self.store.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id self._auth_provider_id, remote_user_id, registered_user_id
) )
return registered_user_id return registered_user_id
def expire_sessions(self): def expire_sessions(self):
expire_before = self._clock.time_msec() - self._saml2_session_lifetime expire_before = self.clock.time_msec() - self._saml2_session_lifetime
to_expire = set() to_expire = set()
for reqid, data in self._outstanding_requests_dict.items(): for reqid, data in self._outstanding_requests_dict.items():
if data.creation_time < expire_before: if data.creation_time < expire_before:
......
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Optional
from synapse.handlers._base import BaseHandler
from synapse.http.server import respond_with_html
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class MappingException(Exception):
"""Used to catch errors when mapping the UserInfo object
"""
class SsoHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._error_template = hs.config.sso_error_template
def render_error(
self, request, error: str, error_description: Optional[str] = None
) -> None:
"""Renders the error template and responds with it.
This is used to show errors to the user. The template of this page can
be found under `synapse/res/templates/sso_error.html`.
Args:
request: The incoming request from the browser.
We'll respond with an HTML page describing the error.
error: A technical identifier for this error.
error_description: A human-readable description of the error.
"""
html = self._error_template.render(
error=error, error_description=error_description
)
respond_with_html(request, 400, html)
async def get_sso_user_by_remote_user_id(
self, auth_provider_id: str, remote_user_id: str
) -> Optional[str]:
"""
Maps the user ID of a remote IdP to a mxid for a previously seen user.
If the user has not been seen yet, this will return None.
Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
"oidc" or "saml".
remote_user_id: The user ID according to the remote IdP. This might
be an e-mail address, a GUID, or some other form. It must be
unique and immutable.
Returns:
The mxid of a previously seen user.
"""
# Check if we already have a mapping for this user.
logger.info(
"Looking for existing mapping for user %s:%s",
auth_provider_id,
remote_user_id,
)
previously_registered_user_id = await self.store.get_user_by_external_id(
auth_provider_id, remote_user_id,
)
# A match was found, return the user ID.
if previously_registered_user_id is not None:
logger.info("Found existing mapping %s", previously_registered_user_id)
return previously_registered_user_id
# No match.
return None
...@@ -89,6 +89,7 @@ from synapse.handlers.room_member import RoomMemberMasterHandler ...@@ -89,6 +89,7 @@ from synapse.handlers.room_member import RoomMemberMasterHandler
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
from synapse.handlers.search import SearchHandler from synapse.handlers.search import SearchHandler
from synapse.handlers.set_password import SetPasswordHandler from synapse.handlers.set_password import SetPasswordHandler
from synapse.handlers.sso import SsoHandler
from synapse.handlers.stats import StatsHandler from synapse.handlers.stats import StatsHandler
from synapse.handlers.sync import SyncHandler from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
...@@ -390,6 +391,10 @@ class HomeServer(metaclass=abc.ABCMeta): ...@@ -390,6 +391,10 @@ class HomeServer(metaclass=abc.ABCMeta):
else: else:
return FollowerTypingHandler(self) return FollowerTypingHandler(self)
@cache_in_self
def get_sso_handler(self) -> SsoHandler:
return SsoHandler(self)
@cache_in_self @cache_in_self
def get_sync_handler(self) -> SyncHandler: def get_sync_handler(self) -> SyncHandler:
return SyncHandler(self) return SyncHandler(self)
......
...@@ -154,6 +154,9 @@ class OidcHandlerTestCase(HomeserverTestCase): ...@@ -154,6 +154,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
) )
self.handler = OidcHandler(hs) self.handler = OidcHandler(hs)
# Mock the render error method.
self.render_error = Mock(return_value=None)
self.handler._sso_handler.render_error = self.render_error
return hs return hs
...@@ -161,12 +164,12 @@ class OidcHandlerTestCase(HomeserverTestCase): ...@@ -161,12 +164,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
return patch.dict(self.handler._provider_metadata, values) return patch.dict(self.handler._provider_metadata, values)
def assertRenderedError(self, error, error_description=None): def assertRenderedError(self, error, error_description=None):
args = self.handler._render_error.call_args[0] args = self.render_error.call_args[0]
self.assertEqual(args[1], error) self.assertEqual(args[1], error)
if error_description is not None: if error_description is not None:
self.assertEqual(args[2], error_description) self.assertEqual(args[2], error_description)
# Reset the render_error mock # Reset the render_error mock
self.handler._render_error.reset_mock() self.render_error.reset_mock()
def test_config(self): def test_config(self):
"""Basic config correctly sets up the callback URL and client auth correctly.""" """Basic config correctly sets up the callback URL and client auth correctly."""
...@@ -356,7 +359,6 @@ class OidcHandlerTestCase(HomeserverTestCase): ...@@ -356,7 +359,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
def test_callback_error(self): def test_callback_error(self):
"""Errors from the provider returned in the callback are displayed.""" """Errors from the provider returned in the callback are displayed."""
self.handler._render_error = Mock()
request = Mock(args={}) request = Mock(args={})
request.args[b"error"] = [b"invalid_client"] request.args[b"error"] = [b"invalid_client"]
self.get_success(self.handler.handle_oidc_callback(request)) self.get_success(self.handler.handle_oidc_callback(request))
...@@ -387,7 +389,6 @@ class OidcHandlerTestCase(HomeserverTestCase): ...@@ -387,7 +389,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
"preferred_username": "bar", "preferred_username": "bar",
} }
user_id = "@foo:domain.org" user_id = "@foo:domain.org"
self.handler._render_error = Mock(return_value=None)
self.handler._exchange_code = simple_async_mock(return_value=token) self.handler._exchange_code = simple_async_mock(return_value=token)
self.handler._parse_id_token = simple_async_mock(return_value=userinfo) self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
...@@ -435,7 +436,7 @@ class OidcHandlerTestCase(HomeserverTestCase): ...@@ -435,7 +436,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
userinfo, token, user_agent, ip_address userinfo, token, user_agent, ip_address
) )
self.handler._fetch_userinfo.assert_not_called() self.handler._fetch_userinfo.assert_not_called()
self.handler._render_error.assert_not_called() self.render_error.assert_not_called()
# Handle mapping errors # Handle mapping errors
self.handler._map_userinfo_to_user = simple_async_mock( self.handler._map_userinfo_to_user = simple_async_mock(
...@@ -469,7 +470,7 @@ class OidcHandlerTestCase(HomeserverTestCase): ...@@ -469,7 +470,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
userinfo, token, user_agent, ip_address userinfo, token, user_agent, ip_address
) )
self.handler._fetch_userinfo.assert_called_once_with(token) self.handler._fetch_userinfo.assert_called_once_with(token)
self.handler._render_error.assert_not_called() self.render_error.assert_not_called()
# Handle userinfo fetching error # Handle userinfo fetching error
self.handler._fetch_userinfo = simple_async_mock(raises=Exception()) self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
...@@ -485,7 +486,6 @@ class OidcHandlerTestCase(HomeserverTestCase): ...@@ -485,7 +486,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
def test_callback_session(self): def test_callback_session(self):
"""The callback verifies the session presence and validity""" """The callback verifies the session presence and validity"""
self.handler._render_error = Mock(return_value=None)
request = Mock(spec=["args", "getCookie", "addCookie"]) request = Mock(spec=["args", "getCookie", "addCookie"])
# Missing cookie # Missing cookie
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment