Skip to content
Snippets Groups Projects
Unverified Commit 22246919 authored by Patrick Cloke's avatar Patrick Cloke Committed by GitHub
Browse files

Add more type hints to SAML handler. (#7445)

parent d7983b63
No related branches found
No related tags found
No related merge requests found
Add type hints to the SAML handler.
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import re import re
from typing import Optional, Tuple from typing import Callable, Dict, Optional, Set, Tuple
import attr import attr
import saml2 import saml2
...@@ -25,6 +25,7 @@ from synapse.api.errors import SynapseError ...@@ -25,6 +25,7 @@ from synapse.api.errors import SynapseError
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.http.server import finish_request from synapse.http.server import finish_request
from synapse.http.servlet import parse_string from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi from synapse.module_api import ModuleApi
from synapse.module_api.errors import RedirectException from synapse.module_api.errors import RedirectException
from synapse.types import ( from synapse.types import (
...@@ -81,17 +82,19 @@ class SamlHandler: ...@@ -81,17 +82,19 @@ class SamlHandler:
self._error_html_content = hs.config.saml2_error_html_content self._error_html_content = hs.config.saml2_error_html_content
def handle_redirect_request(self, client_redirect_url, ui_auth_session_id=None): def handle_redirect_request(
self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
) -> bytes:
"""Handle an incoming request to /login/sso/redirect """Handle an incoming request to /login/sso/redirect
Args: Args:
client_redirect_url (bytes): the URL that we should redirect the client_redirect_url: the URL that we should redirect the
client to when everything is done client to when everything is done
ui_auth_session_id (Optional[str]): The session ID of the ongoing UI Auth (or ui_auth_session_id: The session ID of the ongoing UI Auth (or
None if this is a login). None if this is a login).
Returns: Returns:
bytes: URL to redirect to URL to redirect to
""" """
reqid, info = self._saml_client.prepare_for_authenticate( reqid, info = self._saml_client.prepare_for_authenticate(
relay_state=client_redirect_url relay_state=client_redirect_url
...@@ -109,15 +112,15 @@ class SamlHandler: ...@@ -109,15 +112,15 @@ class SamlHandler:
# this shouldn't happen! # this shouldn't happen!
raise Exception("prepare_for_authenticate didn't return a Location header") raise Exception("prepare_for_authenticate didn't return a Location header")
async def handle_saml_response(self, request): async def handle_saml_response(self, request: SynapseRequest) -> None:
"""Handle an incoming request to /_matrix/saml2/authn_response """Handle an incoming request to /_matrix/saml2/authn_response
Args: Args:
request (SynapseRequest): the incoming request from the browser. We'll request: the incoming request from the browser. We'll
respond to it with a redirect. respond to it with a redirect.
Returns: Returns:
Deferred[none]: Completes once we have handled the request. Completes once we have handled the request.
""" """
resp_bytes = parse_string(request, "SAMLResponse", required=True) resp_bytes = parse_string(request, "SAMLResponse", required=True)
relay_state = parse_string(request, "RelayState", required=True) relay_state = parse_string(request, "RelayState", required=True)
...@@ -310,6 +313,7 @@ DOT_REPLACE_PATTERN = re.compile( ...@@ -310,6 +313,7 @@ DOT_REPLACE_PATTERN = re.compile(
def dot_replace_for_mxid(username: str) -> str: def dot_replace_for_mxid(username: str) -> str:
"""Replace any characters which are not allowed in Matrix IDs with a dot."""
username = username.lower() username = username.lower()
username = DOT_REPLACE_PATTERN.sub(".", username) username = DOT_REPLACE_PATTERN.sub(".", username)
...@@ -321,7 +325,7 @@ def dot_replace_for_mxid(username: str) -> str: ...@@ -321,7 +325,7 @@ def dot_replace_for_mxid(username: str) -> str:
MXID_MAPPER_MAP = { MXID_MAPPER_MAP = {
"hexencode": map_username_to_mxid_localpart, "hexencode": map_username_to_mxid_localpart,
"dotreplace": dot_replace_for_mxid, "dotreplace": dot_replace_for_mxid,
} } # type: Dict[str, Callable[[str], str]]
@attr.s @attr.s
...@@ -349,7 +353,7 @@ class DefaultSamlMappingProvider(object): ...@@ -349,7 +353,7 @@ class DefaultSamlMappingProvider(object):
def get_remote_user_id( def get_remote_user_id(
self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str
): ) -> str:
"""Extracts the remote user id from the SAML response""" """Extracts the remote user id from the SAML response"""
try: try:
return saml_response.ava["uid"][0] return saml_response.ava["uid"][0]
...@@ -428,14 +432,14 @@ class DefaultSamlMappingProvider(object): ...@@ -428,14 +432,14 @@ class DefaultSamlMappingProvider(object):
return SamlConfig(mxid_source_attribute, mxid_mapper) return SamlConfig(mxid_source_attribute, mxid_mapper)
@staticmethod @staticmethod
def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]: def get_saml_attributes(config: SamlConfig) -> Tuple[Set[str], Set[str]]:
"""Returns the required attributes of a SAML """Returns the required attributes of a SAML
Args: Args:
config: A SamlConfig object containing configuration params for this provider config: A SamlConfig object containing configuration params for this provider
Returns: Returns:
tuple[set,set]: The first set equates to the saml auth response The first set equates to the saml auth response
attributes that are required for the module to function, whereas the attributes that are required for the module to function, whereas the
second set consists of those attributes which can be used if second set consists of those attributes which can be used if
available, but are not necessary available, but are not necessary
......
...@@ -186,6 +186,7 @@ commands = mypy \ ...@@ -186,6 +186,7 @@ commands = mypy \
synapse/handlers/cas_handler.py \ synapse/handlers/cas_handler.py \
synapse/handlers/directory.py \ synapse/handlers/directory.py \
synapse/handlers/presence.py \ synapse/handlers/presence.py \
synapse/handlers/saml_handler.py \
synapse/handlers/sync.py \ synapse/handlers/sync.py \
synapse/handlers/ui_auth \ synapse/handlers/ui_auth \
synapse/logging/ \ synapse/logging/ \
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment