Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • maunium/synapse
  • leytilera/synapse
2 results
Show changes
# Copyright 2020 Quentin Gliech
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright 2020 Quentin Gliech
# Copyright (C) 2023 New Vector, Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import binascii
import inspect
import json
......@@ -24,6 +31,7 @@ from typing import (
List,
Optional,
Type,
TypedDict,
TypeVar,
Union,
)
......@@ -45,7 +53,6 @@ from pymacaroons.exceptions import (
MacaroonInitException,
MacaroonInvalidSignatureException,
)
from typing_extensions import TypedDict
from twisted.web.client import readBody
from twisted.web.http_headers import Headers
......@@ -58,6 +65,7 @@ from synapse.http.server import finish_request
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.module_api import ModuleApi
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import Clock, json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
......@@ -374,7 +382,12 @@ class OidcProvider:
self._macaroon_generaton = macaroon_generator
self._config = provider
self._callback_url: str = hs.config.oidc.oidc_callback_url
self._callback_url: str
if provider.redirect_uri is not None:
self._callback_url = provider.redirect_uri
else:
self._callback_url = hs.config.oidc.oidc_callback_url
# Calculate the prefix for OIDC callback paths based on the public_baseurl.
# We'll insert this into the Path= parameter of any session cookies we set.
......@@ -414,9 +427,19 @@ class OidcProvider:
# from the IdP's jwks_uri, if required.
self._jwks = RetryOnExceptionCachedCall(self._load_jwks)
self._user_mapping_provider = provider.user_mapping_provider_class(
provider.user_mapping_provider_config
user_mapping_provider_init_method = (
provider.user_mapping_provider_class.__init__
)
if len(inspect.signature(user_mapping_provider_init_method).parameters) == 3:
self._user_mapping_provider = provider.user_mapping_provider_class(
provider.user_mapping_provider_config,
ModuleApi(hs, hs.get_auth_handler()),
)
else:
self._user_mapping_provider = provider.user_mapping_provider_class(
provider.user_mapping_provider_config,
)
self._skip_verification = provider.skip_verification
self._allow_existing_users = provider.allow_existing_users
......@@ -435,6 +458,10 @@ class OidcProvider:
# optional brand identifier for this auth provider
self.idp_brand = provider.idp_brand
self.additional_authorization_parameters = (
provider.additional_authorization_parameters
)
self._sso_handler = hs.get_sso_handler()
self._device_handler = hs.get_device_handler()
......@@ -618,6 +645,11 @@ class OidcProvider:
elif self._config.pkce_method == "never":
metadata.pop("code_challenge_methods_supported", None)
if self._config.id_token_signing_alg_values_supported:
metadata["id_token_signing_alg_values_supported"] = (
self._config.id_token_signing_alg_values_supported
)
self._validate_metadata(metadata)
return metadata
......@@ -811,14 +843,38 @@ class OidcProvider:
logger.debug("Using the OAuth2 access_token to request userinfo")
metadata = await self.load_metadata()
resp = await self._http_client.get_json(
resp = await self._http_client.request(
"GET",
metadata["userinfo_endpoint"],
headers={"Authorization": ["Bearer {}".format(token["access_token"])]},
headers=Headers(
{"Authorization": ["Bearer {}".format(token["access_token"])]}
),
)
logger.debug("Retrieved user info from userinfo endpoint: %r", resp)
body = await readBody(resp)
content_type_headers = resp.headers.getRawHeaders("Content-Type")
assert content_type_headers
# We use `startswith` because the header value can contain the `charset` parameter
# even if it is useless, and Twisted doesn't take care of that for us.
if content_type_headers[0].startswith("application/jwt"):
alg_values = metadata.get(
"id_token_signing_alg_values_supported", ["RS256"]
)
jwt = JsonWebToken(alg_values)
jwk_set = await self.load_jwks()
try:
decoded_resp = jwt.decode(body, key=jwk_set)
except ValueError:
logger.info("Reloading JWKS after decode error")
jwk_set = await self.load_jwks(force=True) # try reloading the jwks
decoded_resp = jwt.decode(body, key=jwk_set)
else:
decoded_resp = json_decoder.decode(body.decode("utf-8"))
logger.debug("Retrieved user info from userinfo endpoint: %r", decoded_resp)
return UserInfo(resp)
return UserInfo(decoded_resp)
async def _verify_jwt(
self,
......@@ -956,7 +1012,21 @@ class OidcProvider:
"""
state = generate_token()
nonce = generate_token()
# Generate a nonce 32 characters long. When encoded with base64url later on,
# the nonce will be 43 characters when sent to the identity provider.
#
# While RFC7636 does not specify a minimum length for the `nonce`
# parameter, the TI-Messenger IDP_FD spec v1.7.3 does require it to be
# between 43 and 128 characters. This spec concerns using Matrix for
# communication in German healthcare.
#
# As increasing the length only strengthens security, we use this length
# to allow TI-Messenger deployments using Synapse to satisfy this
# external spec.
#
# See https://github.com/element-hq/synapse/pull/18109 for more context.
nonce = generate_token(length=32)
code_verifier = ""
if not client_redirect_url:
......@@ -964,17 +1034,21 @@ class OidcProvider:
metadata = await self.load_metadata()
additional_authorization_parameters = dict(
self.additional_authorization_parameters
)
# Automatically enable PKCE if it is supported.
extra_grant_values = {}
if metadata.get("code_challenge_methods_supported"):
code_verifier = generate_token(48)
# Note that we verified the server supports S256 earlier (in
# OidcProvider._validate_metadata).
extra_grant_values = {
"code_challenge_method": "S256",
"code_challenge": create_s256_code_challenge(code_verifier),
}
additional_authorization_parameters.update(
{
"code_challenge_method": "S256",
"code_challenge": create_s256_code_challenge(code_verifier),
}
)
cookie = self._macaroon_generaton.generate_oidc_session_token(
state=state,
......@@ -1013,7 +1087,7 @@ class OidcProvider:
scope=self._scopes,
state=state,
nonce=nonce,
**extra_grant_values,
**additional_authorization_parameters,
)
async def handle_oidc_callback(
......@@ -1354,7 +1428,7 @@ class OidcProvider:
finish_request(request)
class LogoutToken(JWTClaims):
class LogoutToken(JWTClaims): # type: ignore[misc]
"""
Holds and verify claims of a logout token, as per
https://openid.net/specs/openid-connect-backchannel-1_0.html#LogoutToken
......@@ -1576,7 +1650,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
This is the default mapping provider.
"""
def __init__(self, config: JinjaOidcMappingConfig):
def __init__(self, config: JinjaOidcMappingConfig, module_api: ModuleApi):
self._config = config
@staticmethod
......