Skip to content
Snippets Groups Projects
Unverified Commit c588b9b9 authored by Richard van der Hoff's avatar Richard van der Hoff Committed by GitHub
Browse files

Factor SSO success handling out of CAS login (#4264)

This is mostly factoring out the post-CAS-login code to somewhere we can reuse
it for other SSO flows, but it also fixes the userid mapping while we're at it.
parent b0c24a66
No related branches found
No related tags found
No related merge requests found
Fix CAS login when username is not valid in an MXID
...@@ -563,10 +563,10 @@ class AuthHandler(BaseHandler): ...@@ -563,10 +563,10 @@ class AuthHandler(BaseHandler):
insensitively, but return None if there are multiple inexact matches. insensitively, but return None if there are multiple inexact matches.
Args: Args:
(str) user_id: complete @user:id (unicode|bytes) user_id: complete @user:id
Returns: Returns:
defer.Deferred: (str) canonical_user_id, or None if zero or defer.Deferred: (unicode) canonical_user_id, or None if zero or
multiple matches multiple matches
""" """
res = yield self._find_user_id_and_pwd_hash(user_id) res = yield self._find_user_id_and_pwd_hash(user_id)
...@@ -954,6 +954,15 @@ class MacaroonGenerator(object): ...@@ -954,6 +954,15 @@ class MacaroonGenerator(object):
return macaroon.serialize() return macaroon.serialize()
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
"""
Args:
user_id (unicode):
duration_in_ms (int):
Returns:
unicode
"""
macaroon = self._generate_base_macaroon(user_id) macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login") macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec() now = self.hs.get_clock().time_msec()
......
...@@ -23,8 +23,12 @@ from twisted.web.client import PartialDownloadError ...@@ -23,8 +23,12 @@ from twisted.web.client import PartialDownloadError
from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.http.server import finish_request from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import (
from synapse.types import UserID RestServlet,
parse_json_object_from_request,
parse_string,
)
from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
...@@ -358,17 +362,15 @@ class CasTicketServlet(ClientV1RestServlet): ...@@ -358,17 +362,15 @@ class CasTicketServlet(ClientV1RestServlet):
self.cas_server_url = hs.config.cas_server_url self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes self.cas_required_attributes = hs.config.cas_required_attributes
self.auth_handler = hs.get_auth_handler() self._sso_auth_handler = SSOAuthHandler(hs)
self.handlers = hs.get_handlers()
self.macaroon_gen = hs.get_macaroon_generator()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
client_redirect_url = request.args[b"redirectUrl"][0] client_redirect_url = parse_string(request, "redirectUrl", required=True)
http_client = self.hs.get_simple_http_client() http_client = self.hs.get_simple_http_client()
uri = self.cas_server_url + "/proxyValidate" uri = self.cas_server_url + "/proxyValidate"
args = { args = {
"ticket": request.args[b"ticket"][0].decode('ascii'), "ticket": parse_string(request, "ticket", required=True),
"service": self.cas_service_url "service": self.cas_service_url
} }
try: try:
...@@ -380,7 +382,6 @@ class CasTicketServlet(ClientV1RestServlet): ...@@ -380,7 +382,6 @@ class CasTicketServlet(ClientV1RestServlet):
result = yield self.handle_cas_response(request, body, client_redirect_url) result = yield self.handle_cas_response(request, body, client_redirect_url)
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks
def handle_cas_response(self, request, cas_response_body, client_redirect_url): def handle_cas_response(self, request, cas_response_body, client_redirect_url):
user, attributes = self.parse_cas_response(cas_response_body) user, attributes = self.parse_cas_response(cas_response_body)
...@@ -396,28 +397,9 @@ class CasTicketServlet(ClientV1RestServlet): ...@@ -396,28 +397,9 @@ class CasTicketServlet(ClientV1RestServlet):
if required_value != actual_value: if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID(user, self.hs.hostname).to_string() return self._sso_auth_handler.on_successful_auth(
auth_handler = self.auth_handler user, request, client_redirect_url,
registered_user_id = yield auth_handler.check_user_exists(user_id)
if not registered_user_id:
registered_user_id, _ = (
yield self.handlers.registration_handler.register(localpart=user)
)
login_token = self.macaroon_gen.generate_short_term_login_token(
registered_user_id
) )
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
login_token)
request.redirect(redirect_url)
finish_request(request)
def add_login_token_to_redirect_url(self, url, token):
url_parts = list(urllib.parse.urlparse(url))
query = dict(urllib.parse.parse_qsl(url_parts[4]))
query.update({"loginToken": token})
url_parts[4] = urllib.parse.urlencode(query).encode('ascii')
return urllib.parse.urlunparse(url_parts)
def parse_cas_response(self, cas_response_body): def parse_cas_response(self, cas_response_body):
user = None user = None
...@@ -452,6 +434,71 @@ class CasTicketServlet(ClientV1RestServlet): ...@@ -452,6 +434,71 @@ class CasTicketServlet(ClientV1RestServlet):
return user, attributes return user, attributes
class SSOAuthHandler(object):
"""
Utility class for Resources and Servlets which handle the response from a SSO
service
Args:
hs (synapse.server.HomeServer)
"""
def __init__(self, hs):
self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_handlers().registration_handler
self._macaroon_gen = hs.get_macaroon_generator()
@defer.inlineCallbacks
def on_successful_auth(
self, username, request, client_redirect_url,
):
"""Called once the user has successfully authenticated with the SSO.
Registers the user if necessary, and then returns a redirect (with
a login token) to the client.
Args:
username (unicode|bytes): the remote user id. We'll map this onto
something sane for a MXID localpath.
request (SynapseRequest): the incoming request from the browser. We'll
respond to it with a redirect.
client_redirect_url (unicode): the redirect_url the client gave us when
it first started the process.
Returns:
Deferred[none]: Completes once we have handled the request.
"""
localpart = map_username_to_mxid_localpart(username)
user_id = UserID(localpart, self._hostname).to_string()
registered_user_id = yield self._auth_handler.check_user_exists(user_id)
if not registered_user_id:
registered_user_id, _ = (
yield self._registration_handler.register(
localpart=localpart,
generate_token=False,
)
)
login_token = self._macaroon_gen.generate_short_term_login_token(
registered_user_id
)
redirect_url = self._add_login_token_to_redirect_url(
client_redirect_url, login_token
)
request.redirect(redirect_url)
finish_request(request)
@staticmethod
def _add_login_token_to_redirect_url(url, token):
url_parts = list(urllib.parse.urlparse(url))
query = dict(urllib.parse.parse_qsl(url_parts[4]))
query.update({"loginToken": token})
url_parts[4] = urllib.parse.urlencode(query)
return urllib.parse.urlunparse(url_parts)
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server) LoginRestServlet(hs).register(http_server)
if hs.config.cas_enabled: if hs.config.cas_enabled:
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import re
import string import string
from collections import namedtuple from collections import namedtuple
...@@ -228,6 +229,71 @@ def contains_invalid_mxid_characters(localpart): ...@@ -228,6 +229,71 @@ def contains_invalid_mxid_characters(localpart):
return any(c not in mxid_localpart_allowed_characters for c in localpart) return any(c not in mxid_localpart_allowed_characters for c in localpart)
UPPER_CASE_PATTERN = re.compile(b"[A-Z_]")
# the following is a pattern which matches '=', and bytes which are not allowed in a mxid
# localpart.
#
# It works by:
# * building a string containing the allowed characters (excluding '=')
# * escaping every special character with a backslash (to stop '-' being interpreted as a
# range operator)
# * wrapping it in a '[^...]' regex
# * converting the whole lot to a 'bytes' sequence, so that we can use it to match
# bytes rather than strings
#
NON_MXID_CHARACTER_PATTERN = re.compile(
("[^%s]" % (
re.escape("".join(mxid_localpart_allowed_characters - {"="}),),
)).encode("ascii"),
)
def map_username_to_mxid_localpart(username, case_sensitive=False):
"""Map a username onto a string suitable for a MXID
This follows the algorithm laid out at
https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets.
Args:
username (unicode|bytes): username to be mapped
case_sensitive (bool): true if TEST and test should be mapped
onto different mxids
Returns:
unicode: string suitable for a mxid localpart
"""
if not isinstance(username, bytes):
username = username.encode('utf-8')
# first we sort out upper-case characters
if case_sensitive:
def f1(m):
return b"_" + m.group().lower()
username = UPPER_CASE_PATTERN.sub(f1, username)
else:
username = username.lower()
# then we sort out non-ascii characters
def f2(m):
g = m.group()[0]
if isinstance(g, str):
# on python 2, we need to do a ord(). On python 3, the
# byte itself will do.
g = ord(g)
return b"=%02x" % (g,)
username = NON_MXID_CHARACTER_PATTERN.sub(f2, username)
# we also do the =-escaping to mxids starting with an underscore.
username = re.sub(b'^_', b'=5f', username)
# we should now only have ascii bytes left, so can decode back to a
# unicode.
return username.decode('ascii')
class StreamToken( class StreamToken(
namedtuple("Token", ( namedtuple("Token", (
"room_key", "room_key",
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.types import GroupID, RoomAlias, UserID from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart
from tests import unittest from tests import unittest
from tests.utils import TestHomeServer from tests.utils import TestHomeServer
...@@ -79,3 +79,32 @@ class GroupIDTestCase(unittest.TestCase): ...@@ -79,3 +79,32 @@ class GroupIDTestCase(unittest.TestCase):
except SynapseError as exc: except SynapseError as exc:
self.assertEqual(400, exc.code) self.assertEqual(400, exc.code)
self.assertEqual("M_UNKNOWN", exc.errcode) self.assertEqual("M_UNKNOWN", exc.errcode)
class MapUsernameTestCase(unittest.TestCase):
def testPassThrough(self):
self.assertEqual(map_username_to_mxid_localpart("test1234"), "test1234")
def testUpperCase(self):
self.assertEqual(map_username_to_mxid_localpart("tEST_1234"), "test_1234")
self.assertEqual(
map_username_to_mxid_localpart("tEST_1234", case_sensitive=True),
"t_e_s_t__1234",
)
def testSymbols(self):
self.assertEqual(
map_username_to_mxid_localpart("test=$?_1234"),
"test=3d=24=3f_1234",
)
def testLeadingUnderscore(self):
self.assertEqual(map_username_to_mxid_localpart("_test_1234"), "=5ftest_1234")
def testNonAscii(self):
# this should work with either a unicode or a bytes
self.assertEqual(map_username_to_mxid_localpart(u'têst'), "t=c3=aast")
self.assertEqual(
map_username_to_mxid_localpart(u'têst'.encode('utf-8')),
"t=c3=aast",
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment