# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2020 Quentin Gliech
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
import os
from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple
from unittest.mock import ANY, AsyncMock, Mock, patch
from urllib.parse import parse_qs, urlparse

import pymacaroons

from twisted.test.proto_helpers import MemoryReactor

from synapse.handlers.sso import MappingException
from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID
from synapse.util import Clock
from synapse.util.macaroons import get_value_from_macaroon
from synapse.util.stringutils import random_string

from tests.test_utils import FakeResponse, get_awaitable_result
from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
from tests.unittest import HomeserverTestCase, override_config

try:
    import authlib  # noqa: F401
    from authlib.oidc.core import UserInfo
    from authlib.oidc.discovery import OpenIDProviderMetadata

    from synapse.handlers.oidc import Token, UserAttributeDict

    HAS_OIDC = True
except ImportError:
    HAS_OIDC = False
# These are a few constants that are used as config parameters in the tests.
ISSUER = "https://issuer/"
CLIENT_ID = "test-client-id"
CLIENT_SECRET = "test-client-secret"
BASE_URL = "https://synapse/"
CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback"
SCOPES = ["openid"]

# config for common cases
DEFAULT_CONFIG = {
    "enabled": True,
    "client_id": CLIENT_ID,
    "client_secret": CLIENT_SECRET,
    "issuer": ISSUER,
    "scopes": SCOPES,
    # The mapping provider is referenced by dotted path, so it must be a
    # module-level class of this very module.
    "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
}

# extends the default config with explicit OAuth2 endpoints instead of using discovery
EXPLICIT_ENDPOINT_CONFIG = {
    **DEFAULT_CONFIG,
    "authorization_endpoint": ISSUER + "authorize",
    "token_endpoint": ISSUER + "token",
    "jwks_uri": ISSUER + "jwks",
}
class TestMappingProvider:
    """Minimal user mapping provider referenced by ``DEFAULT_CONFIG``.

    It exercises the legacy ``map_user_attributes`` signature (no ``failures``
    argument) and intentionally has no ``get_extra_attributes`` method.
    """

    @staticmethod
    def parse_config(config: JsonDict) -> None:
        """Ignore the supplied configuration; this provider has no settings."""
        return None

    def __init__(self, config: None):
        """Accept the (always ``None``) parsed config; there is no state to set up."""

    def get_remote_user_id(self, userinfo: "UserInfo") -> str:
        """Return the ``sub`` claim as the remote user identifier."""
        return userinfo["sub"]

    async def map_user_attributes(
        self, userinfo: "UserInfo", token: "Token"
    ) -> "UserAttributeDict":
        """Map the ``username`` claim to the localpart, with no display name."""
        # This is testing not providing the full map.
        return {"localpart": userinfo["username"], "display_name": None}  # type: ignore[typeddict-item]

    # Do not include get_extra_attributes to test backwards compatibility paths.
class TestMappingProviderExtra(TestMappingProvider):
    """Mapping provider that additionally implements ``get_extra_attributes``."""

    async def get_extra_attributes(
        self, userinfo: "UserInfo", token: "Token"
    ) -> JsonDict:
        """Expose the ``phone`` claim as an extra attribute."""
        return {"phone": userinfo["phone"]}
class TestMappingProviderFailures(TestMappingProvider):
    """Mapping provider using the ``failures``-aware ``map_user_attributes``."""

    # Superclass is testing the legacy interface for map_user_attributes.
    async def map_user_attributes(  # type: ignore[override]
        self, userinfo: "UserInfo", token: "Token", failures: int
    ) -> "UserAttributeDict":
        """Map ``username`` to a localpart, suffixed with ``failures`` when non-zero."""
        return {  # type: ignore[typeddict-item]
            # A zero failure count leaves the username untouched; otherwise the
            # count is appended so that each retry yields a different localpart.
            "localpart": userinfo["username"] + (str(failures) if failures else ""),
            "display_name": None,
        }
def _key_file_path() -> str:
    """path to a file containing the private half of a test key"""

    # this key was generated with:
    #   openssl ecparam -name prime256v1 -genkey -noout |
    #       openssl pkcs8 -topk8 -nocrypt -out oidc_test_key.p8
    #
    # we use PKCS8 rather than SEC-1 (which is what openssl ecparam spits out), because
    # that's what Apple use, and we want to be sure that we work with Apple's keys.
    #
    # (For the record: both PKCS8 and SEC-1 specify (different) ways of representing
    # keys using ASN.1. Both are then typically formatted using PEM, which says: use the
    # base64-encoded DER encoding of ASN.1, with headers and footers. But we don't
    # really need to care about any of that.)
    return os.path.join(os.path.dirname(__file__), "oidc_test_key.p8")
def _public_key_file_path() -> str:
    """path to a file containing the public half of a test key"""
    # this was generated with:
    #   openssl ec -in oidc_test_key.p8 -pubout -out oidc_test_key.pub.pem
    #
    # The private key it is derived from is oidc_test_key.p8, which lives in
    # the same directory as this module.
    return os.path.join(os.path.dirname(__file__), "oidc_test_key.pub.pem")
class OidcHandlerTestCase(HomeserverTestCase):
if not HAS_OIDC:
    # HAS_OIDC is False when the optional authlib import failed.
    # NOTE(review): `skip` is presumably read by the test runner as the reason
    # for skipping the whole test case — confirm against HomeserverTestCase.
    skip = "requires OIDC"
def default_config(self) -> Dict[str, Any]:
    """Return the parent class's default config with ``public_baseurl`` set.

    Returns:
        The config dict from ``super().default_config()``, with
        ``public_baseurl`` overridden to ``BASE_URL``.
    """
    config = super().default_config()
    config["public_baseurl"] = BASE_URL
    # Without this the method would implicitly return None, contradicting the
    # declared return type.
    return config
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
    """Build the test homeserver, wired to a fake OIDC server and mocked SSO hooks.

    Side effects: sets ``fake_server``, ``hs_patcher``, ``handler``,
    ``provider``, ``render_error`` and ``complete_sso_login`` on ``self``.

    Returns:
        The homeserver created by ``setup_test_homeserver``.
    """
    self.fake_server = FakeOidcServer(clock=clock, issuer=ISSUER)

    hs = self.setup_test_homeserver()
    # The patcher is started here and stopped again in tearDown.
    self.hs_patcher = self.fake_server.patch_homeserver(hs=hs)
    self.hs_patcher.start()  # type: ignore[attr-defined]

    self.handler = hs.get_oidc_handler()
    self.provider = self.handler._providers["oidc"]
    sso_handler = hs.get_sso_handler()
    # Mock the render error method.
    self.render_error = Mock(return_value=None)
    sso_handler.render_error = self.render_error  # type: ignore[method-assign]

    # Reduce the number of attempts when generating MXIDs.
    sso_handler._MAP_USERNAME_RETRIES = 3

    auth_handler = hs.get_auth_handler()
    # Mock the complete SSO login method.
    self.complete_sso_login = AsyncMock()
    auth_handler.complete_sso_login = self.complete_sso_login  # type: ignore[method-assign]

    # The signature promises a HomeServer; return the one configured above.
    return hs
def tearDown(self) -> None:
    """Stop the homeserver patcher, then run the parent class's teardown."""
    self.hs_patcher.stop()  # type: ignore[attr-defined]
    return super().tearDown()
def reset_mocks(self) -> None:
    """Reset all the Mocks.

    Covers the fake OIDC server's mocks as well as the ``render_error`` and
    ``complete_sso_login`` mocks held on the test case.
    """
    self.fake_server.reset_mocks()
    self.render_error.reset_mock()
    self.complete_sso_login.reset_mock()
def metadata_edit(self, values: dict) -> ContextManager[Mock]:
    """Modify the result that will be returned by the well-known query

    Args:
        values: keys to add or override in the fake server's metadata.

    Returns:
        A patcher which, while active, makes ``fake_server.get_metadata``
        return the edited metadata.
    """
    metadata = self.fake_server.get_metadata()
    metadata.update(values)
    return patch.object(self.fake_server, "get_metadata", return_value=metadata)
def start_authorization(
    self,
    userinfo: dict,
    client_redirect_url: str = "http://client/redirect",
    scope: str = "openid",
    with_sid: bool = False,
) -> Tuple[SynapseRequest, FakeAuthorizationGrant]:
    """Start an authorization request, and get the callback request back.

    Args:
        userinfo: claims handed to the fake server for this grant.
        client_redirect_url: URL stored in the generated session token.
        scope: scope string passed to the fake server.
        with_sid: forwarded to the fake server's ``start_authorization``.

    Returns:
        The callback request built from the authorization code, state and
        session token, together with the grant created by the fake server.
    """
    # Fresh random values for every authorization.
    nonce = random_string(10)
    state = random_string(10)

    code, grant = self.fake_server.start_authorization(
        userinfo=userinfo,
        scope=scope,
        client_id=self.provider._client_auth.client_id,
        redirect_uri=self.provider._callback_url,
        nonce=nonce,
        with_sid=with_sid,
    )
    session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
    return _build_callback_request(code, state, session), grant
def assertRenderedError(
    self, error: str, error_description: Optional[str] = None
) -> Tuple[Any, ...]:
    """Assert that exactly one error was rendered, and that it matches.

    Args:
        error: expected error code (second positional argument of the
            ``render_error`` call).
        error_description: if given, the expected description (third
            positional argument of the ``render_error`` call).

    Returns:
        The positional arguments ``render_error`` was called with.
    """
    self.render_error.assert_called_once()
    # Positional arguments of the single recorded call.
    args = self.render_error.call_args[0]
    self.assertEqual(args[1], error)
    if error_description is not None:
        self.assertEqual(args[2], error_description)
    # Reset the render_error mock
    self.render_error.reset_mock()
    return args
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_config(self) -> None:
    """Basic config correctly sets up the callback URL and client auth correctly."""
    self.assertEqual(self.provider._callback_url, CALLBACK_URL)
    self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
    self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
@override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
def test_discovery(self) -> None:
    """The handler should discover the endpoints from OIDC discovery document."""
    # This would throw if some metadata were invalid
    metadata = self.get_success(self.provider.load_metadata())
    self.fake_server.get_metadata_handler.assert_called_once()

    self.assertEqual(metadata.issuer, self.fake_server.issuer)
    self.assertEqual(
        metadata.authorization_endpoint,
        self.fake_server.authorization_endpoint,
    )
    self.assertEqual(metadata.token_endpoint, self.fake_server.token_endpoint)
    self.assertEqual(metadata.jwks_uri, self.fake_server.jwks_uri)
    # It seems like authlib does not have that defined in its metadata models
    self.assertEqual(
        metadata.get("userinfo_endpoint"),
        self.fake_server.userinfo_endpoint,
    )

    # subsequent calls should be cached
    self.reset_mocks()
    self.get_success(self.provider.load_metadata())
    self.fake_server.get_metadata_handler.assert_not_called()
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
def test_no_discovery(self) -> None:
    """When discovery is disabled, it should not try to load from discovery document."""
    self.get_success(self.provider.load_metadata())
    self.fake_server.get_metadata_handler.assert_not_called()
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_load_jwks(self) -> None:
    """JWKS loading is done once (then cached) if used."""
    jwks = self.get_success(self.provider.load_jwks())
    self.fake_server.get_jwks_handler.assert_called_once()
    self.assertEqual(jwks, self.fake_server.get_jwks())

    # subsequent calls should be cached…
    self.reset_mocks()
    self.get_success(self.provider.load_jwks())
    self.fake_server.get_jwks_handler.assert_not_called()

    # …unless a reload is forced.
    self.reset_mocks()
    self.get_success(self.provider.load_jwks(force=True))
    self.fake_server.get_jwks_handler.assert_called_once()

    # The edited metadata must stay in effect for both forced reloads, so
    # everything below runs inside the patch.
    with self.metadata_edit({"jwks_uri": None}):
        # If we don't do this, the load_metadata call will throw because of the
        # missing jwks_uri
        self.provider._user_profile_method = "userinfo_endpoint"
        self.get_success(self.provider.load_metadata(force=True))
        self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_validate_config(self) -> None:
    """Provider metadatas are extensively validated."""
    # Short alias for the provider under test; it is captured by the closure
    # below and used directly by the userinfo checks at the end.
    h = self.provider

    def force_load_metadata() -> Awaitable[None]:
        async def force_load() -> "OpenIDProviderMetadata":
            return await h.load_metadata(force=True)

        return get_awaitable_result(force_load())

    # Default test config does not throw
    force_load_metadata()

    with self.metadata_edit({"issuer": None}):
        self.assertRaisesRegex(ValueError, "issuer", force_load_metadata)

    with self.metadata_edit({"issuer": "http://insecure/"}):
        self.assertRaisesRegex(ValueError, "issuer", force_load_metadata)

    with self.metadata_edit({"issuer": "https://invalid/?because=query"}):
        self.assertRaisesRegex(ValueError, "issuer", force_load_metadata)

    with self.metadata_edit({"authorization_endpoint": None}):
        self.assertRaisesRegex(
            ValueError, "authorization_endpoint", force_load_metadata
        )

    with self.metadata_edit({"authorization_endpoint": "http://insecure/auth"}):
        self.assertRaisesRegex(
            ValueError, "authorization_endpoint", force_load_metadata
        )

    with self.metadata_edit({"token_endpoint": None}):
        self.assertRaisesRegex(ValueError, "token_endpoint", force_load_metadata)

    with self.metadata_edit({"token_endpoint": "http://insecure/token"}):
        self.assertRaisesRegex(ValueError, "token_endpoint", force_load_metadata)

    with self.metadata_edit({"jwks_uri": None}):
        self.assertRaisesRegex(ValueError, "jwks_uri", force_load_metadata)

    with self.metadata_edit({"jwks_uri": "http://insecure/jwks.json"}):
        self.assertRaisesRegex(ValueError, "jwks_uri", force_load_metadata)

    with self.metadata_edit({"response_types_supported": ["id_token"]}):
        self.assertRaisesRegex(
            ValueError, "response_types_supported", force_load_metadata
        )

    with self.metadata_edit(
        {"token_endpoint_auth_methods_supported": ["client_secret_basic"]}
    ):
        # should not throw, as client_secret_basic is the default auth method
        force_load_metadata()

    with self.metadata_edit(
        {"token_endpoint_auth_methods_supported": ["client_secret_post"]}
    ):
        self.assertRaisesRegex(
            ValueError,
            "token_endpoint_auth_methods_supported",
            force_load_metadata,
        )

    # Tests for configs that require the userinfo endpoint
    self.assertFalse(h._uses_userinfo)
    self.assertEqual(h._user_profile_method, "auto")
    h._user_profile_method = "userinfo_endpoint"
    self.assertTrue(h._uses_userinfo)

    # Revert the profile method and do not request the "openid" scope: this should
    # mean that we check for a userinfo endpoint
    h._user_profile_method = "auto"
    h._scopes = []
    self.assertTrue(h._uses_userinfo)
    with self.metadata_edit({"userinfo_endpoint": None}):
        self.assertRaisesRegex(ValueError, "userinfo_endpoint", force_load_metadata)

    with self.metadata_edit({"jwks_uri": None}):
        # Shouldn't raise with a valid userinfo, even without jwks
        force_load_metadata()
@override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
def test_skip_verification(self) -> None:
    """Provider metadata validation can be disabled by config."""
    # An http:// issuer is rejected when validation is on (see the insecure
    # issuer case in test_validate_config).
    with self.metadata_edit({"issuer": "http://insecure"}):
        # This should not throw
        get_awaitable_result(self.provider.load_metadata())
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_redirect_request(self) -> None:
    """The redirect request has the right arguments & generates a valid session cookie."""
    req = Mock(spec=["cookies"])
    req.cookies = []

    url = urlparse(
        self.get_success(
            self.provider.handle_redirect_request(req, b"http://client/redirect")
        )
    )
    auth_endpoint = urlparse(self.fake_server.authorization_endpoint)

    # The redirect must target the provider's authorization endpoint.
    self.assertEqual(url.scheme, auth_endpoint.scheme)
    self.assertEqual(url.netloc, auth_endpoint.netloc)
    self.assertEqual(url.path, auth_endpoint.path)

    params = parse_qs(url.query)
    self.assertEqual(params["redirect_uri"], [CALLBACK_URL])
    self.assertEqual(params["response_type"], ["code"])
    self.assertEqual(params["scope"], [" ".join(SCOPES)])
    self.assertEqual(params["client_id"], [CLIENT_ID])
    self.assertEqual(len(params["state"]), 1)
    self.assertEqual(len(params["nonce"]), 1)
    self.assertNotIn("code_challenge", params)

    # Check what is in the cookies
    self.assertEqual(len(req.cookies), 2)  # two cookies
    cookie_header = req.cookies[0]

    # The cookie name and path don't really matter, just that it has to be coherent
    # between the callback & redirect handlers.
    parts = [p.strip() for p in cookie_header.split(b";")]
    self.assertIn(b"Path=/_synapse/client/oidc", parts)
    name, cookie = parts[0].split(b"=")
    self.assertEqual(name, b"oidc_session")

    # The session cookie is a macaroon carrying the values sent in the URL.
    macaroon = pymacaroons.Macaroon.deserialize(cookie)
    state = get_value_from_macaroon(macaroon, "state")
    nonce = get_value_from_macaroon(macaroon, "nonce")
    code_verifier = get_value_from_macaroon(macaroon, "code_verifier")
    redirect = get_value_from_macaroon(macaroon, "client_redirect_url")

    self.assertEqual(params["state"], [state])
    self.assertEqual(params["nonce"], [nonce])
    self.assertEqual(code_verifier, "")
    self.assertEqual(redirect, "http://client/redirect")
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_redirect_request_with_code_challenge(self) -> None:
    """A code challenge is sent when the provider metadata advertises S256."""
    req = Mock(spec=["cookies"])
    req.cookies = []

    with self.metadata_edit({"code_challenge_methods_supported": ["S256"]}):
        url = urlparse(
            self.get_success(
                self.provider.handle_redirect_request(
                    req, b"http://client/redirect"
                )
            )
        )

    # Ensure the code_challenge param is added to the redirect.
    params = parse_qs(url.query)
    self.assertEqual(len(params["code_challenge"]), 1)

    # Check what is in the cookies
    self.assertEqual(len(req.cookies), 2)  # two cookies
    cookie_header = req.cookies[0]

    # The cookie name and path don't really matter, just that it has to be coherent
    # between the callback & redirect handlers.
    parts = [p.strip() for p in cookie_header.split(b";")]
    self.assertIn(b"Path=/_synapse/client/oidc", parts)
    name, cookie = parts[0].split(b"=")
    self.assertEqual(name, b"oidc_session")

    # Ensure the code_verifier is set in the cookie.
    macaroon = pymacaroons.Macaroon.deserialize(cookie)
    code_verifier = get_value_from_macaroon(macaroon, "code_verifier")
    self.assertNotEqual(code_verifier, "")
@override_config({"oidc_config": {**DEFAULT_CONFIG, "pkce_method": "always"}})
def test_redirect_request_with_forced_code_challenge(self) -> None:
    """With ``pkce_method: always`` a code challenge is sent without editing the metadata."""
    req = Mock(spec=["cookies"])
    req.cookies = []

    url = urlparse(
        self.get_success(
            self.provider.handle_redirect_request(req, b"http://client/redirect")
        )
    )

    # Ensure the code_challenge param is added to the redirect.
    params = parse_qs(url.query)
    self.assertEqual(len(params["code_challenge"]), 1)

    # Check what is in the cookies
    self.assertEqual(len(req.cookies), 2)  # two cookies
    cookie_header = req.cookies[0]

    # The cookie name and path don't really matter, just that it has to be coherent
    # between the callback & redirect handlers.
    parts = [p.strip() for p in cookie_header.split(b";")]
    self.assertIn(b"Path=/_synapse/client/oidc", parts)
    name, cookie = parts[0].split(b"=")
    self.assertEqual(name, b"oidc_session")

    # Ensure the code_verifier is set in the cookie.
    macaroon = pymacaroons.Macaroon.deserialize(cookie)
    code_verifier = get_value_from_macaroon(macaroon, "code_verifier")
    self.assertNotEqual(code_verifier, "")
@override_config({"oidc_config": {**DEFAULT_CONFIG, "pkce_method": "never"}})
def test_redirect_request_with_disabled_code_challenge(self) -> None:
    """With ``pkce_method: never`` no code challenge is sent, even if S256 is advertised."""
    req = Mock(spec=["cookies"])
    req.cookies = []

    # The metadata should state that PKCE is enabled.
    with self.metadata_edit({"code_challenge_methods_supported": ["S256"]}):
        url = urlparse(
            self.get_success(
                self.provider.handle_redirect_request(
                    req, b"http://client/redirect"
                )
            )
        )

    # Ensure the code_challenge param is NOT added to the redirect.
    params = parse_qs(url.query)
    self.assertNotIn("code_challenge", params)

    # Check what is in the cookies
    self.assertEqual(len(req.cookies), 2)  # two cookies
    cookie_header = req.cookies[0]

    # The cookie name and path don't really matter, just that it has to be coherent
    # between the callback & redirect handlers.
    parts = [p.strip() for p in cookie_header.split(b";")]
    self.assertIn(b"Path=/_synapse/client/oidc", parts)
    name, cookie = parts[0].split(b"=")
    self.assertEqual(name, b"oidc_session")

    # Ensure the code_verifier is blank in the cookie.
    macaroon = pymacaroons.Macaroon.deserialize(cookie)
    code_verifier = get_value_from_macaroon(macaroon, "code_verifier")
    self.assertEqual(code_verifier, "")
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback_error(self) -> None:
    """Errors from the provider returned in the callback are displayed."""
    request = Mock(args={})
    request.args[b"error"] = [b"invalid_client"]
    self.get_success(self.handler.handle_oidc_callback(request))
    # No error_description was sent, so an empty description is expected.
    self.assertRenderedError("invalid_client", "")

    request.args[b"error_description"] = [b"some description"]
    self.get_success(self.handler.handle_oidc_callback(request))
    self.assertRenderedError("invalid_client", "some description")
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback(self) -> None:
    """Code callback works and display errors if something went wrong.

    A lot of scenarios are tested here:
     - when the callback works, with userinfo from ID token
     - when the user mapping fails
     - when ID token verification fails
     - when the callback works, with userinfo fetched from the userinfo endpoint
     - when the userinfo fetching fails
     - when the code exchange fails
    """

    # ensure that we are correctly testing the fallback when "get_extra_attributes"
    # is not implemented.
    mapping_provider = self.provider._user_mapping_provider
    with self.assertRaises(AttributeError):
        _ = mapping_provider.get_extra_attributes

    username = "bar"
    userinfo = {
        "sub": "foo",
        "username": username,
    }
    expected_user_id = "@%s:%s" % (username, self.hs.hostname)

    client_redirect_url = "http://client/redirect"
    request, _ = self.start_authorization(
        userinfo, client_redirect_url=client_redirect_url
    )
    self.get_success(self.handler.handle_oidc_callback(request))

    self.complete_sso_login.assert_called_once_with(
        expected_user_id,
        self.provider.idp_id,
        request,
        client_redirect_url,
        None,
        new_user=True,
        auth_provider_session_id=None,
    )
    self.fake_server.post_token_handler.assert_called_once()
    self.fake_server.get_userinfo_handler.assert_not_called()

    # Handle mapping errors
    request, _ = self.start_authorization(userinfo)
    # NOTE(review): the patched attribute is assumed to live on the provider
    # (like the other underscore-prefixed members used here) — confirm.
    with patch.object(
        self.provider,
        "_remote_id_from_userinfo",
        new=Mock(side_effect=MappingException()),
    ):
        self.get_success(self.handler.handle_oidc_callback(request))
    self.assertRenderedError("mapping_error")

    # Handle ID token errors
    request, _ = self.start_authorization(userinfo)
    with self.fake_server.id_token_override({"iss": "https://bad.issuer/"}):
        self.get_success(self.handler.handle_oidc_callback(request))
    self.assertRenderedError("invalid_token")

    self.reset_mocks()

    # With userinfo fetching
    self.provider._user_profile_method = "userinfo_endpoint"
    # Without the "openid" scope, the FakeProvider does not generate an id_token
    request, _ = self.start_authorization(userinfo, scope="")
    self.get_success(self.handler.handle_oidc_callback(request))

    self.complete_sso_login.assert_called_once_with(
        expected_user_id,
        self.provider.idp_id,
        request,
        ANY,
        None,
        new_user=False,
        auth_provider_session_id=None,
    )
    self.fake_server.post_token_handler.assert_called_once()
    self.fake_server.get_userinfo_handler.assert_called_once()

    self.reset_mocks()

    # With an ID token, userinfo fetching and sid in the ID token
    self.provider._user_profile_method = "userinfo_endpoint"
    request, grant = self.start_authorization(userinfo, with_sid=True)
    self.assertIsNotNone(grant.sid)
    self.get_success(self.handler.handle_oidc_callback(request))

    self.complete_sso_login.assert_called_once_with(
        expected_user_id,
        self.provider.idp_id,
        request,
        ANY,
        None,
        new_user=False,
        auth_provider_session_id=grant.sid,
    )
    self.fake_server.post_token_handler.assert_called_once()
    self.fake_server.get_userinfo_handler.assert_called_once()
    self.render_error.assert_not_called()

    # Handle userinfo fetching error
    request, _ = self.start_authorization(userinfo)
    with self.fake_server.buggy_endpoint(userinfo=True):
        self.get_success(self.handler.handle_oidc_callback(request))
    self.assertRenderedError("fetch_error")

    # Handle code exchange failure
    request, _ = self.start_authorization(userinfo)
    with self.fake_server.buggy_endpoint(token=True):
        self.get_success(self.handler.handle_oidc_callback(request))
    self.assertRenderedError("server_error")
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback_session(self) -> None:
    """The callback verifies the session presence and validity"""
    request = Mock(spec=["args", "getCookie", "cookies"])

    # Missing cookie
    request.args = {}
    request.getCookie.return_value = None
    self.get_success(self.handler.handle_oidc_callback(request))
    self.assertRenderedError("missing_session", "No session cookie found")

    # Missing session parameter
    request.args = {}
    request.getCookie.return_value = "session"
    self.get_success(self.handler.handle_oidc_callback(request))
    self.assertRenderedError("invalid_request", "State parameter is missing")

    # Invalid cookie
    request.args = {}
    request.args[b"state"] = [b"state"]
    request.getCookie.return_value = "session"
    self.get_success(self.handler.handle_oidc_callback(request))
    self.assertRenderedError("invalid_session")

    # Mismatching session
    session = self._generate_oidc_session_token(
        state="state",
        nonce="nonce",
        client_redirect_url="http://client/redirect",
    )
    request.args = {}
    request.args[b"state"] = [b"mismatching state"]
    request.getCookie.return_value = session
    self.get_success(self.handler.handle_oidc_callback(request))
    self.assertRenderedError("mismatching_session")

    # Valid session: the state matches, but the request still carries no
    # authorization code, so an "invalid_request" error is expected.
    request.args = {}
    request.args[b"state"] = [b"state"]
    request.getCookie.return_value = session
    self.get_success(self.handler.handle_oidc_callback(request))
    self.assertRenderedError("invalid_request")
@override_config(
    {"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
)
def test_exchange_code(self) -> None:
    """Code exchange behaves correctly and handles various error scenarios."""
    token = {
        "type": "Bearer",
        "access_token": "aabbcc",
    }

    self.fake_server.post_token_handler.side_effect = None
    self.fake_server.post_token_handler.return_value = FakeResponse.json(
        payload=token
    )
    # Arbitrary authorization code; only its round-trip into the POST body matters.
    code = "code"
    ret = self.get_success(self.provider._exchange_code(code, code_verifier=""))
    kwargs = self.fake_server.request.call_args[1]

    self.assertEqual(ret, token)
    self.assertEqual(kwargs["method"], "POST")
    self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)

    args = parse_qs(kwargs["data"].decode("utf-8"))
    self.assertEqual(args["grant_type"], ["authorization_code"])
    self.assertEqual(args["code"], [code])
    self.assertEqual(args["client_id"], [CLIENT_ID])
    self.assertEqual(args["client_secret"], [CLIENT_SECRET])
    self.assertEqual(args["redirect_uri"], [CALLBACK_URL])

    # Test providing a code verifier.
    code_verifier = "code_verifier"
    ret = self.get_success(
        self.provider._exchange_code(code, code_verifier=code_verifier)
    )
    kwargs = self.fake_server.request.call_args[1]

    self.assertEqual(ret, token)
    self.assertEqual(kwargs["method"], "POST")
    self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)

    args = parse_qs(kwargs["data"].decode("utf-8"))
    self.assertEqual(args["grant_type"], ["authorization_code"])
    self.assertEqual(args["code"], [code])
    self.assertEqual(args["client_id"], [CLIENT_ID])
    self.assertEqual(args["client_secret"], [CLIENT_SECRET])
    self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
    self.assertEqual(args["code_verifier"], [code_verifier])

    # Test error handling
    self.fake_server.post_token_handler.return_value = FakeResponse.json(
        code=400, payload={"error": "foo", "error_description": "bar"}
    )
    # Imported here rather than at module level: synapse.handlers.oidc is only
    # imported under the optional-dependency guard at the top of the file.
    from synapse.handlers.oidc import OidcError

    exc = self.get_failure(
        self.provider._exchange_code(code, code_verifier=""), OidcError
    )
    self.assertEqual(exc.value.error, "foo")
    self.assertEqual(exc.value.error_description, "bar")

    # Internal server error with no JSON body
    self.fake_server.post_token_handler.return_value = FakeResponse(
        code=500, body=b"Not JSON"
    )
    exc = self.get_failure(
        self.provider._exchange_code(code, code_verifier=""), OidcError
    )
    self.assertEqual(exc.value.error, "server_error")

    # Internal server error with JSON body
    self.fake_server.post_token_handler.return_value = FakeResponse.json(
        code=500, payload={"error": "internal_server_error"}
    )
    exc = self.get_failure(
        self.provider._exchange_code(code, code_verifier=""), OidcError
    )
    self.assertEqual(exc.value.error, "internal_server_error")

    # 4xx error without "error" field
    self.fake_server.post_token_handler.return_value = FakeResponse.json(
        code=400, payload={}
    )
    exc = self.get_failure(
        self.provider._exchange_code(code, code_verifier=""), OidcError
    )
    self.assertEqual(exc.value.error, "server_error")

    # 2xx error with "error" field
    self.fake_server.post_token_handler.return_value = FakeResponse.json(
        code=200, payload={"error": "some_error"}
    )
    exc = self.get_failure(
        self.provider._exchange_code(code, code_verifier=""), OidcError
    )
    self.assertEqual(exc.value.error, "some_error")
@override_config(
    {
        "oidc_config": {
            "enabled": True,
            "client_id": CLIENT_ID,
            "issuer": ISSUER,
            "client_auth_method": "client_secret_post",
            "client_secret_jwt_key": {
                "key_file": _key_file_path(),
                "jwt_header": {"alg": "ES256", "kid": "ABC789"},
                "jwt_payload": {"iss": "DEFGHI"},
            },
        }
    }
)
def test_exchange_code_jwt_key(self) -> None:
    """Test that code exchange works with a JWK client secret."""
    from authlib.jose import jwt

    token = {
        "type": "Bearer",
        "access_token": "aabbcc",
    }
    self.fake_server.post_token_handler.side_effect = None
    self.fake_server.post_token_handler.return_value = FakeResponse.json(
        payload=token
    )
    code = "code"

    # advance the clock a bit before we start, so we aren't working with zero
    # timestamps.
    self.reactor.advance(1000)
    start_time = self.reactor.seconds()
    ret = self.get_success(self.provider._exchange_code(code, code_verifier=""))

    self.assertEqual(ret, token)

    # the request should have hit the token endpoint
    kwargs = self.fake_server.request.call_args[1]
    self.assertEqual(kwargs["method"], "POST")
    self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)

    # the client secret provided to the token endpoint should be a jwt which can
    # be checked with the public key
    args = parse_qs(kwargs["data"].decode("utf-8"))
    secret = args["client_secret"][0]
    with open(_public_key_file_path()) as f:
        key = f.read()
    claims = jwt.decode(secret, key)
    self.assertEqual(claims.header["kid"], "ABC789")
    self.assertEqual(claims["aud"], ISSUER)
    self.assertEqual(claims["iss"], "DEFGHI")
    self.assertEqual(claims["sub"], CLIENT_ID)
    self.assertEqual(claims["iat"], start_time)
    self.assertGreater(claims["exp"], start_time)

    # check the rest of the POSTed data
    self.assertEqual(args["grant_type"], ["authorization_code"])
    self.assertEqual(args["code"], [code])
    self.assertEqual(args["client_id"], [CLIENT_ID])
    self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
@override_config(
    {
        "oidc_config": {
            "enabled": True,
            "client_id": CLIENT_ID,
            "issuer": ISSUER,
            "client_auth_method": "none",
        }
    }
)
def test_exchange_code_no_auth(self) -> None:
    """Test that code exchange works with no client secret."""
    token = {
        "type": "Bearer",
        "access_token": "aabbcc",
    }
    self.fake_server.post_token_handler.side_effect = None
    self.fake_server.post_token_handler.return_value = FakeResponse.json(
        payload=token
    )
    code = "code"
    ret = self.get_success(self.provider._exchange_code(code, code_verifier=""))

    self.assertEqual(ret, token)

    # the request should have hit the token endpoint
    kwargs = self.fake_server.request.call_args[1]
    self.assertEqual(kwargs["method"], "POST")
    self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)

    # check the POSTed data
    args = parse_qs(kwargs["data"].decode("utf-8"))
    self.assertEqual(args["grant_type"], ["authorization_code"])
    self.assertEqual(args["code"], [code])
    self.assertEqual(args["client_id"], [CLIENT_ID])
    self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
@override_config(
    {
        "oidc_config": {
            **DEFAULT_CONFIG,
            "user_mapping_provider": {
                "module": __name__ + ".TestMappingProviderExtra"
            },
        }
    }
)
def test_extra_attributes(self) -> None:
    """
    Login while using a mapping provider that implements get_extra_attributes.
    """
    userinfo = {
        "sub": "foo",
        "username": "foo",
        "phone": "1234567",
    }
    request, _ = self.start_authorization(userinfo)
    self.get_success(self.handler.handle_oidc_callback(request))

    # The extra attributes from the mapping provider are passed through as the
    # fifth positional argument.
    self.complete_sso_login.assert_called_once_with(
        "@foo:test",
        self.provider.idp_id,
        request,
        ANY,
        {"phone": "1234567"},
        new_user=True,
        auth_provider_session_id=None,
    )
@override_config({"oidc_config": {**DEFAULT_CONFIG, "enable_registration": True}})
def test_map_userinfo_to_user(self) -> None:
    """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
    # Annotated as a plain dict because "sub" is rebound to an int further down.
    userinfo: dict = {
        "sub": "test_user",
        "username": "test_user",
    }
    request, _ = self.start_authorization(userinfo)
    self.get_success(self.handler.handle_oidc_callback(request))
    self.complete_sso_login.assert_called_once_with(
        "@test_user:test",
        self.provider.idp_id,
        request,
        ANY,
        None,
        new_user=True,
        auth_provider_session_id=None,
    )
    self.reset_mocks()

    # Some providers return an integer ID.
    userinfo = {
        "sub": 1234,
        "username": "test_user_2",
    }
    request, _ = self.start_authorization(userinfo)
    self.get_success(self.handler.handle_oidc_callback(request))
    self.complete_sso_login.assert_called_once_with(
        "@test_user_2:test",
        self.provider.idp_id,
        request,
        ANY,
        None,
        new_user=True,
        auth_provider_session_id=None,
    )
    self.reset_mocks()

    # Test if the mxid is already taken
    store = self.hs.get_datastores().main
    user3 = UserID.from_string("@test_user_3:test")
    self.get_success(
        store.register_user(user_id=user3.to_string(), password_hash=None)
    )
    userinfo = {"sub": "test3", "username": "test_user_3"}
    request, _ = self.start_authorization(userinfo)
    self.get_success(self.handler.handle_oidc_callback(request))
    self.complete_sso_login.assert_not_called()
    self.assertRenderedError(
        "mapping_error",
        "Mapping provider does not support de-duplicating Matrix IDs",
    )
@override_config({"oidc_config": {**DEFAULT_CONFIG, "enable_registration": False}})
def test_map_userinfo_to_user_does_not_register_new_user(self) -> None:
    """Ensures new users are not registered if the enabled registration flag is disabled."""
    userinfo: dict = {
        "sub": "test_user",
        "username": "test_user",
    }
    request, _ = self.start_authorization(userinfo)
    self.get_success(self.handler.handle_oidc_callback(request))
    # No login may complete; a mapping error is rendered instead.
    self.complete_sso_login.assert_not_called()
    self.assertRenderedError(
        "mapping_error",
        "User does not exist and registrations are disabled",
    )
@override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})