Skip to content
Snippets Groups Projects
Commit 22112f8d authored by Steven Hammerton's avatar Steven Hammerton
Browse files

Formatting changes

parent c33f5c1a
Branches
Tags
No related merge requests found
...@@ -298,7 +298,8 @@ class AuthHandler(BaseHandler): ...@@ -298,7 +298,8 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def login_with_cas_user_id(self, user_id): def login_with_cas_user_id(self, user_id):
""" """
Authenticates the user with the given user ID, intended to have been captured from a CAS response Authenticates the user with the given user ID,
intended to have been captured from a CAS response
Args: Args:
user_id (str): User ID user_id (str): User ID
......
...@@ -77,9 +77,13 @@ class LoginRestServlet(ClientV1RestServlet): ...@@ -77,9 +77,13 @@ class LoginRestServlet(ClientV1RestServlet):
"uri": "%s%s" % (self.idp_redirect_url, relay_state) "uri": "%s%s" % (self.idp_redirect_url, relay_state)
} }
defer.returnValue((200, result)) defer.returnValue((200, result))
elif self.cas_enabled and (login_submission["type"] == LoginRestServlet.CAS_TYPE): elif self.cas_enabled and (login_submission["type"] ==
LoginRestServlet.CAS_TYPE):
url = "%s/proxyValidate" % (self.cas_server_url) url = "%s/proxyValidate" % (self.cas_server_url)
parameters = {"ticket": login_submission["ticket"], "service": login_submission["service"]} parameters = {
"ticket": login_submission["ticket"],
"service": login_submission["service"]
}
response = requests.get(url, verify=False, params=parameters) response = requests.get(url, verify=False, params=parameters)
result = yield self.do_cas_login(response.text) result = yield self.do_cas_login(response.text)
defer.returnValue(result) defer.returnValue(result)
...@@ -130,7 +134,8 @@ class LoginRestServlet(ClientV1RestServlet): ...@@ -130,7 +134,8 @@ class LoginRestServlet(ClientV1RestServlet):
auth_handler = self.handlers.auth_handler auth_handler = self.handlers.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id) user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists: if user_exists:
user_id, access_token, refresh_token = yield auth_handler.login_with_cas_user_id(user_id) user_id, access_token, refresh_token = yield
auth_handler.login_with_cas_user_id(user_id)
result = { result = {
"user_id": user_id, # may have changed "user_id": user_id, # may have changed
"access_token": access_token, "access_token": access_token,
...@@ -139,7 +144,8 @@ class LoginRestServlet(ClientV1RestServlet): ...@@ -139,7 +144,8 @@ class LoginRestServlet(ClientV1RestServlet):
} }
else: else:
user_id, access_token = yield self.handlers.registration_handler.register(localpart=user) user_id, access_token = yield
self.handlers.registration_handler.register(localpart=user)
result = { result = {
"user_id": user_id, # may have changed "user_id": user_id, # may have changed
"access_token": access_token, "access_token": access_token,
...@@ -148,7 +154,6 @@ class LoginRestServlet(ClientV1RestServlet): ...@@ -148,7 +154,6 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result)) defer.returnValue((200, result))
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
...@@ -224,6 +229,7 @@ class SAML2RestServlet(ClientV1RestServlet): ...@@ -224,6 +229,7 @@ class SAML2RestServlet(ClientV1RestServlet):
defer.returnValue(None) defer.returnValue(None)
defer.returnValue((200, {"status": "not_authenticated"})) defer.returnValue((200, {"status": "not_authenticated"}))
class CasRestServlet(ClientV1RestServlet): class CasRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/cas") PATTERN = client_path_pattern("/login/cas")
...@@ -234,6 +240,7 @@ class CasRestServlet(ClientV1RestServlet): ...@@ -234,6 +240,7 @@ class CasRestServlet(ClientV1RestServlet):
def on_GET(self, request): def on_GET(self, request):
return (200, {"serverUrl": self.cas_server_url}) return (200, {"serverUrl": self.cas_server_url})
def _parse_json(request): def _parse_json(request):
try: try:
content = json.loads(request.content.read()) content = json.loads(request.content.read())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment