Skip to content
Snippets Groups Projects
Unverified Commit 30da50a5 authored by Richard van der Hoff's avatar Richard van der Hoff Committed by GitHub
Browse files

Initialise user displayname from SAML2 data (#4272)

When we register a new user from SAML2 data, initialise their displayname
correctly.
parent 35e13477
No related branches found
No related tags found
No related merge requests found
SAML2 authentication: Initialise user display name from SAML2 data
...@@ -126,6 +126,7 @@ class RegistrationHandler(BaseHandler): ...@@ -126,6 +126,7 @@ class RegistrationHandler(BaseHandler):
make_guest=False, make_guest=False,
admin=False, admin=False,
threepid=None, threepid=None,
default_display_name=None,
): ):
"""Registers a new client on the server. """Registers a new client on the server.
...@@ -140,6 +141,8 @@ class RegistrationHandler(BaseHandler): ...@@ -140,6 +141,8 @@ class RegistrationHandler(BaseHandler):
since it offers no means of associating a device_id with the since it offers no means of associating a device_id with the
access_token. Instead you should call auth_handler.issue_access_token access_token. Instead you should call auth_handler.issue_access_token
after registration. after registration.
default_display_name (unicode|None): if set, the new user's displayname
will be set to this. Defaults to 'localpart'.
Returns: Returns:
A tuple of (user_id, access_token). A tuple of (user_id, access_token).
Raises: Raises:
...@@ -169,6 +172,13 @@ class RegistrationHandler(BaseHandler): ...@@ -169,6 +172,13 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
if was_guest:
# If the user was a guest then they already have a profile
default_display_name = None
elif default_display_name is None:
default_display_name = localpart
token = None token = None
if generate_token: if generate_token:
token = self.macaroon_gen.generate_access_token(user_id) token = self.macaroon_gen.generate_access_token(user_id)
...@@ -178,10 +188,7 @@ class RegistrationHandler(BaseHandler): ...@@ -178,10 +188,7 @@ class RegistrationHandler(BaseHandler):
password_hash=password_hash, password_hash=password_hash,
was_guest=was_guest, was_guest=was_guest,
make_guest=make_guest, make_guest=make_guest,
create_profile_with_localpart=( create_profile_with_displayname=default_display_name,
# If the user was a guest then they already have a profile
None if was_guest else user.localpart
),
admin=admin, admin=admin,
) )
...@@ -203,13 +210,15 @@ class RegistrationHandler(BaseHandler): ...@@ -203,13 +210,15 @@ class RegistrationHandler(BaseHandler):
yield self.check_user_id_not_appservice_exclusive(user_id) yield self.check_user_id_not_appservice_exclusive(user_id)
if generate_token: if generate_token:
token = self.macaroon_gen.generate_access_token(user_id) token = self.macaroon_gen.generate_access_token(user_id)
if default_display_name is None:
default_display_name = localpart
try: try:
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,
password_hash=password_hash, password_hash=password_hash,
make_guest=make_guest, make_guest=make_guest,
create_profile_with_localpart=user.localpart, create_profile_with_displayname=default_display_name,
) )
except SynapseError: except SynapseError:
# if user id is taken, just generate another # if user id is taken, just generate another
...@@ -300,7 +309,7 @@ class RegistrationHandler(BaseHandler): ...@@ -300,7 +309,7 @@ class RegistrationHandler(BaseHandler):
user_id=user_id, user_id=user_id,
password_hash="", password_hash="",
appservice_id=service_id, appservice_id=service_id,
create_profile_with_localpart=user.localpart, create_profile_with_displayname=user.localpart,
) )
defer.returnValue(user_id) defer.returnValue(user_id)
...@@ -478,7 +487,7 @@ class RegistrationHandler(BaseHandler): ...@@ -478,7 +487,7 @@ class RegistrationHandler(BaseHandler):
user_id=user_id, user_id=user_id,
token=token, token=token,
password_hash=password_hash, password_hash=password_hash,
create_profile_with_localpart=user.localpart, create_profile_with_displayname=user.localpart,
) )
else: else:
yield self._auth_handler.delete_access_tokens_for_user(user_id) yield self._auth_handler.delete_access_tokens_for_user(user_id)
......
...@@ -451,6 +451,7 @@ class SSOAuthHandler(object): ...@@ -451,6 +451,7 @@ class SSOAuthHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_successful_auth( def on_successful_auth(
self, username, request, client_redirect_url, self, username, request, client_redirect_url,
user_display_name=None,
): ):
"""Called once the user has successfully authenticated with the SSO. """Called once the user has successfully authenticated with the SSO.
...@@ -467,6 +468,9 @@ class SSOAuthHandler(object): ...@@ -467,6 +468,9 @@ class SSOAuthHandler(object):
client_redirect_url (unicode): the redirect_url the client gave us when client_redirect_url (unicode): the redirect_url the client gave us when
it first started the process. it first started the process.
user_display_name (unicode|None): if set, and we have to register a new user,
we will set their displayname to this.
Returns: Returns:
Deferred[none]: Completes once we have handled the request. Deferred[none]: Completes once we have handled the request.
""" """
...@@ -478,6 +482,7 @@ class SSOAuthHandler(object): ...@@ -478,6 +482,7 @@ class SSOAuthHandler(object):
yield self._registration_handler.register( yield self._registration_handler.register(
localpart=localpart, localpart=localpart,
generate_token=False, generate_token=False,
default_display_name=user_display_name,
) )
) )
......
...@@ -66,6 +66,9 @@ class SAML2ResponseResource(Resource): ...@@ -66,6 +66,9 @@ class SAML2ResponseResource(Resource):
raise CodeMessageException(400, "uid not in SAML2 response") raise CodeMessageException(400, "uid not in SAML2 response")
username = saml2_auth.ava["uid"][0] username = saml2_auth.ava["uid"][0]
displayName = saml2_auth.ava.get("displayName", [None])[0]
return self._sso_auth_handler.on_successful_auth( return self._sso_auth_handler.on_successful_auth(
username, request, relay_state, username, request, relay_state,
user_display_name=displayName,
) )
...@@ -22,6 +22,7 @@ from twisted.internet import defer ...@@ -22,6 +22,7 @@ from twisted.internet import defer
from synapse.api.errors import Codes, StoreError from synapse.api.errors import Codes, StoreError
from synapse.storage import background_updates from synapse.storage import background_updates
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.types import UserID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
...@@ -167,7 +168,7 @@ class RegistrationStore(RegistrationWorkerStore, ...@@ -167,7 +168,7 @@ class RegistrationStore(RegistrationWorkerStore,
def register(self, user_id, token=None, password_hash=None, def register(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None, was_guest=False, make_guest=False, appservice_id=None,
create_profile_with_localpart=None, admin=False): create_profile_with_displayname=None, admin=False):
"""Attempts to register an account. """Attempts to register an account.
Args: Args:
...@@ -181,8 +182,8 @@ class RegistrationStore(RegistrationWorkerStore, ...@@ -181,8 +182,8 @@ class RegistrationStore(RegistrationWorkerStore,
make_guest (boolean): True if the the new user should be guest, make_guest (boolean): True if the the new user should be guest,
false to add a regular user account. false to add a regular user account.
appservice_id (str): The ID of the appservice registering the user. appservice_id (str): The ID of the appservice registering the user.
create_profile_with_localpart (str): Optionally create a profile for create_profile_with_displayname (unicode): Optionally create a profile for
the given localpart. the user, setting their displayname to the given value
Raises: Raises:
StoreError if the user_id could not be registered. StoreError if the user_id could not be registered.
""" """
...@@ -195,7 +196,7 @@ class RegistrationStore(RegistrationWorkerStore, ...@@ -195,7 +196,7 @@ class RegistrationStore(RegistrationWorkerStore,
was_guest, was_guest,
make_guest, make_guest,
appservice_id, appservice_id,
create_profile_with_localpart, create_profile_with_displayname,
admin admin
) )
...@@ -208,9 +209,11 @@ class RegistrationStore(RegistrationWorkerStore, ...@@ -208,9 +209,11 @@ class RegistrationStore(RegistrationWorkerStore,
was_guest, was_guest,
make_guest, make_guest,
appservice_id, appservice_id,
create_profile_with_localpart, create_profile_with_displayname,
admin, admin,
): ):
user_id_obj = UserID.from_string(user_id)
now = int(self.clock.time()) now = int(self.clock.time())
next_id = self._access_tokens_id_gen.get_next() next_id = self._access_tokens_id_gen.get_next()
...@@ -273,12 +276,15 @@ class RegistrationStore(RegistrationWorkerStore, ...@@ -273,12 +276,15 @@ class RegistrationStore(RegistrationWorkerStore,
(next_id, user_id, token,) (next_id, user_id, token,)
) )
if create_profile_with_localpart: if create_profile_with_displayname:
# set a default displayname serverside to avoid ugly race # set a default displayname serverside to avoid ugly race
# between auto-joins and clients trying to set displaynames # between auto-joins and clients trying to set displaynames
#
# *obviously* the 'profiles' table uses localpart for user_id
# while everything else uses the full mxid.
txn.execute( txn.execute(
"INSERT INTO profiles(user_id, displayname) VALUES (?,?)", "INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
(create_profile_with_localpart, create_profile_with_localpart) (user_id_obj.localpart, create_profile_with_displayname)
) )
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(
......
...@@ -149,7 +149,7 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase): ...@@ -149,7 +149,7 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
def test_populate_monthly_users_is_guest(self): def test_populate_monthly_users_is_guest(self):
# Test that guest users are not added to mau list # Test that guest users are not added to mau list
user_id = "user_id" user_id = "@user_id:host"
self.store.register( self.store.register(
user_id=user_id, token="123", password_hash=None, make_guest=True user_id=user_id, token="123", password_hash=None, make_guest=True
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment