Skip to content
Snippets Groups Projects
Commit 76421c49 authored by Steven Hammerton's avatar Steven Hammerton
Browse files

Allow optional config params for a required attribute and it's value, if...

Allow optional config params for a required attribute and it's value, if specified any CAS user must have the given attribute and the value must equal
parent 7845f62c
No related branches found
No related tags found
No related merge requests found
...@@ -27,13 +27,28 @@ class CasConfig(Config): ...@@ -27,13 +27,28 @@ class CasConfig(Config):
if cas_config: if cas_config:
self.cas_enabled = True self.cas_enabled = True
self.cas_server_url = cas_config["server_url"] self.cas_server_url = cas_config["server_url"]
if "required_attribute" in cas_config:
self.cas_required_attribute = cas_config["required_attribute"]
else:
self.cas_required_attribute = None
if "required_attribute_value" in cas_config:
self.cas_required_attribute_value = cas_config["required_attribute_value"]
else:
self.cas_required_attribute_value = None
else: else:
self.cas_enabled = False self.cas_enabled = False
self.cas_server_url = None self.cas_server_url = None
self.cas_required_attribute = None
self.cas_required_attribute_value = None
def default_config(self, config_dir_path, server_name, **kwargs): def default_config(self, config_dir_path, server_name, **kwargs):
return """ return """
# Enable CAS for registration and login. # Enable CAS for registration and login.
#cas_config: #cas_config:
# server_url: "https://cas-server.com" # server_url: "https://cas-server.com"
# #required_attribute: something
# #required_attribute_value: true
""" """
...@@ -45,8 +45,9 @@ class LoginRestServlet(ClientV1RestServlet): ...@@ -45,8 +45,9 @@ class LoginRestServlet(ClientV1RestServlet):
self.idp_redirect_url = hs.config.saml2_idp_redirect_url self.idp_redirect_url = hs.config.saml2_idp_redirect_url
self.saml2_enabled = hs.config.saml2_enabled self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled self.cas_enabled = hs.config.cas_enabled
self.cas_server_url = hs.config.cas_server_url self.cas_server_url = hs.config.cas_server_url
self.cas_required_attribute = hs.config.cas_required_attribute
self.cas_required_attribute_value = hs.config.cas_required_attribute_value
self.servername = hs.config.server_name self.servername = hs.config.server_name
def on_GET(self, request): def on_GET(self, request):
...@@ -126,6 +127,19 @@ class LoginRestServlet(ClientV1RestServlet): ...@@ -126,6 +127,19 @@ class LoginRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def do_cas_login(self, cas_response_body): def do_cas_login(self, cas_response_body):
(user, attributes) = self.parse_cas_response(cas_response_body) (user, attributes) = self.parse_cas_response(cas_response_body)
if self.cas_required_attribute is not None:
# If required attribute was not in CAS Response - Forbidden
if self.cas_required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
# Also need to check value
if self.cas_required_attribute_value is not None:
actualValue = attributes[self.cas_required_attribute]
# If required attribute value does not match expected - Forbidden
if self.cas_required_attribute_value != actualValue:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string() user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler auth_handler = self.handlers.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id) user_exists = yield auth_handler.does_user_exist(user_id)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment