Skip to content
Snippets Groups Projects
Unverified Commit 3cfac959 authored by Richard van der Hoff's avatar Richard van der Hoff Committed by GitHub
Browse files

Merge pull request #6689 from matrix-org/rav/saml_mapping_provider_updates

Updates to the SAML mapping provider API
parents 80396850 d56e95ea
No related branches found
No related tags found
No related merge requests found
Updates to the SAML mapping provider API.
...@@ -24,6 +24,7 @@ from saml2.client import Saml2Client ...@@ -24,6 +24,7 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.http.servlet import parse_string from synapse.http.servlet import parse_string
from synapse.module_api import ModuleApi
from synapse.rest.client.v1.login import SSOAuthHandler from synapse.rest.client.v1.login import SSOAuthHandler
from synapse.types import ( from synapse.types import (
UserID, UserID,
...@@ -59,7 +60,8 @@ class SamlHandler: ...@@ -59,7 +60,8 @@ class SamlHandler:
# plugin to do custom mapping from saml response to mxid # plugin to do custom mapping from saml response to mxid
self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class( self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
hs.config.saml2_user_mapping_provider_config hs.config.saml2_user_mapping_provider_config,
ModuleApi(hs, hs.get_auth_handler()),
) )
# identifier for the external_ids table # identifier for the external_ids table
...@@ -112,10 +114,10 @@ class SamlHandler: ...@@ -112,10 +114,10 @@ class SamlHandler:
# the dict. # the dict.
self.expire_sessions() self.expire_sessions()
user_id = await self._map_saml_response_to_user(resp_bytes) user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
self._sso_auth_handler.complete_sso_login(user_id, request, relay_state) self._sso_auth_handler.complete_sso_login(user_id, request, relay_state)
async def _map_saml_response_to_user(self, resp_bytes): async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url):
try: try:
saml2_auth = self._saml_client.parse_authn_request_response( saml2_auth = self._saml_client.parse_authn_request_response(
resp_bytes, resp_bytes,
...@@ -183,7 +185,7 @@ class SamlHandler: ...@@ -183,7 +185,7 @@ class SamlHandler:
# Map saml response to user attributes using the configured mapping provider # Map saml response to user attributes using the configured mapping provider
for i in range(1000): for i in range(1000):
attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes( attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
saml2_auth, i saml2_auth, i, client_redirect_url=client_redirect_url,
) )
logger.debug( logger.debug(
...@@ -216,6 +218,8 @@ class SamlHandler: ...@@ -216,6 +218,8 @@ class SamlHandler:
500, "Unable to generate a Matrix ID from the SAML response" 500, "Unable to generate a Matrix ID from the SAML response"
) )
logger.info("Mapped SAML user to local part %s", localpart)
registered_user_id = await self._registration_handler.register_user( registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=displayname localpart=localpart, default_display_name=displayname
) )
...@@ -265,17 +269,21 @@ class SamlConfig(object): ...@@ -265,17 +269,21 @@ class SamlConfig(object):
class DefaultSamlMappingProvider(object): class DefaultSamlMappingProvider(object):
__version__ = "0.0.1" __version__ = "0.0.1"
def __init__(self, parsed_config: SamlConfig): def __init__(self, parsed_config: SamlConfig, module_api: ModuleApi):
"""The default SAML user mapping provider """The default SAML user mapping provider
Args: Args:
parsed_config: Module configuration parsed_config: Module configuration
module_api: module api proxy
""" """
self._mxid_source_attribute = parsed_config.mxid_source_attribute self._mxid_source_attribute = parsed_config.mxid_source_attribute
self._mxid_mapper = parsed_config.mxid_mapper self._mxid_mapper = parsed_config.mxid_mapper
def saml_response_to_user_attributes( def saml_response_to_user_attributes(
self, saml_response: saml2.response.AuthnResponse, failures: int = 0, self,
saml_response: saml2.response.AuthnResponse,
failures: int,
client_redirect_url: str,
) -> dict: ) -> dict:
"""Maps some text from a SAML response to attributes of a new user """Maps some text from a SAML response to attributes of a new user
...@@ -285,6 +293,8 @@ class DefaultSamlMappingProvider(object): ...@@ -285,6 +293,8 @@ class DefaultSamlMappingProvider(object):
failures: How many times a call to this function with this failures: How many times a call to this function with this
saml_response has resulted in a failure saml_response has resulted in a failure
client_redirect_url: where the client wants to redirect to
Returns: Returns:
dict: A dict containing new user attributes. Possible keys: dict: A dict containing new user attributes. Possible keys:
* mxid_localpart (str): Required. The localpart of the user's mxid * mxid_localpart (str): Required. The localpart of the user's mxid
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment