Skip to content
Snippets Groups Projects
Commit 22112f8d authored by Steven Hammerton's avatar Steven Hammerton
Browse files

Formatting changes

parent c33f5c1a
No related branches found
No related tags found
No related merge requests found
...@@ -298,7 +298,8 @@ class AuthHandler(BaseHandler): ...@@ -298,7 +298,8 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def login_with_cas_user_id(self, user_id): def login_with_cas_user_id(self, user_id):
""" """
Authenticates the user with the given user ID, intended to have been captured from a CAS response Authenticates the user with the given user ID,
intended to have been captured from a CAS response
Args: Args:
user_id (str): User ID user_id (str): User ID
......
...@@ -77,9 +77,13 @@ class LoginRestServlet(ClientV1RestServlet): ...@@ -77,9 +77,13 @@ class LoginRestServlet(ClientV1RestServlet):
"uri": "%s%s" % (self.idp_redirect_url, relay_state) "uri": "%s%s" % (self.idp_redirect_url, relay_state)
} }
defer.returnValue((200, result)) defer.returnValue((200, result))
elif self.cas_enabled and (login_submission["type"] == LoginRestServlet.CAS_TYPE): elif self.cas_enabled and (login_submission["type"] ==
LoginRestServlet.CAS_TYPE):
url = "%s/proxyValidate" % (self.cas_server_url) url = "%s/proxyValidate" % (self.cas_server_url)
parameters = {"ticket": login_submission["ticket"], "service": login_submission["service"]} parameters = {
"ticket": login_submission["ticket"],
"service": login_submission["service"]
}
response = requests.get(url, verify=False, params=parameters) response = requests.get(url, verify=False, params=parameters)
result = yield self.do_cas_login(response.text) result = yield self.do_cas_login(response.text)
defer.returnValue(result) defer.returnValue(result)
...@@ -130,7 +134,8 @@ class LoginRestServlet(ClientV1RestServlet): ...@@ -130,7 +134,8 @@ class LoginRestServlet(ClientV1RestServlet):
auth_handler = self.handlers.auth_handler auth_handler = self.handlers.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id) user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists: if user_exists:
user_id, access_token, refresh_token = yield auth_handler.login_with_cas_user_id(user_id) user_id, access_token, refresh_token = yield
auth_handler.login_with_cas_user_id(user_id)
result = { result = {
"user_id": user_id, # may have changed "user_id": user_id, # may have changed
"access_token": access_token, "access_token": access_token,
...@@ -139,7 +144,8 @@ class LoginRestServlet(ClientV1RestServlet): ...@@ -139,7 +144,8 @@ class LoginRestServlet(ClientV1RestServlet):
} }
else: else:
user_id, access_token = yield self.handlers.registration_handler.register(localpart=user) user_id, access_token = yield
self.handlers.registration_handler.register(localpart=user)
result = { result = {
"user_id": user_id, # may have changed "user_id": user_id, # may have changed
"access_token": access_token, "access_token": access_token,
...@@ -148,7 +154,6 @@ class LoginRestServlet(ClientV1RestServlet): ...@@ -148,7 +154,6 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result)) defer.returnValue((200, result))
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
...@@ -224,6 +229,7 @@ class SAML2RestServlet(ClientV1RestServlet): ...@@ -224,6 +229,7 @@ class SAML2RestServlet(ClientV1RestServlet):
defer.returnValue(None) defer.returnValue(None)
defer.returnValue((200, {"status": "not_authenticated"})) defer.returnValue((200, {"status": "not_authenticated"}))
class CasRestServlet(ClientV1RestServlet): class CasRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/cas") PATTERN = client_path_pattern("/login/cas")
...@@ -234,6 +240,7 @@ class CasRestServlet(ClientV1RestServlet): ...@@ -234,6 +240,7 @@ class CasRestServlet(ClientV1RestServlet):
def on_GET(self, request): def on_GET(self, request):
return (200, {"serverUrl": self.cas_server_url}) return (200, {"serverUrl": self.cas_server_url})
def _parse_json(request): def _parse_json(request):
try: try:
content = json.loads(request.content.read()) content = json.loads(request.content.read())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment