Skip to content
Snippets Groups Projects
Unverified Commit 7afb5e04 authored by Hanadi's avatar Hanadi Committed by GitHub
Browse files

Fix using dehydrated devices (MSC2697) & refresh tokens (#16288)

Refresh tokens were not correctly moved to the rehydrated
device (similar to how the access token is currently handled).
This resulted in invalid refresh tokens after rehydration.
parent d38d0dff
No related branches found
No related tags found
No related merge requests found
Fix bug introduced in Synapse 1.49.0 when using dehydrated devices ([MSC2697](https://github.com/matrix-org/matrix-spec-proposals/pull/2697)) and refresh tokens. Contributed by Hanadi.
...@@ -758,12 +758,13 @@ class DeviceHandler(DeviceWorkerHandler): ...@@ -758,12 +758,13 @@ class DeviceHandler(DeviceWorkerHandler):
# If the dehydrated device was successfully deleted (the device ID # If the dehydrated device was successfully deleted (the device ID
# matched the stored dehydrated device), then modify the access # matched the stored dehydrated device), then modify the access
# token to use the dehydrated device's ID and copy the old device # token and refresh token to use the dehydrated device's ID and
# display name to the dehydrated device, and destroy the old device # copy the old device display name to the dehydrated device,
# ID # and destroy the old device ID
old_device_id = await self.store.set_device_for_access_token( old_device_id = await self.store.set_device_for_access_token(
access_token, device_id access_token, device_id
) )
await self.store.set_device_for_refresh_token(user_id, old_device_id, device_id)
old_device = await self.store.get_device(user_id, old_device_id) old_device = await self.store.get_device(user_id, old_device_id)
if old_device is None: if old_device is None:
raise errors.NotFoundError() raise errors.NotFoundError()
......
...@@ -2312,6 +2312,26 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): ...@@ -2312,6 +2312,26 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
return next_id return next_id
async def set_device_for_refresh_token(
self, user_id: str, old_device_id: str, device_id: str
) -> None:
"""Moves refresh tokens from old device to current device
Args:
user_id: The user of the devices.
old_device_id: The old device.
device_id: The new device ID.
Returns:
None
"""
await self.db_pool.simple_update(
"refresh_tokens",
keyvalues={"user_id": user_id, "device_id": old_device_id},
updatevalues={"device_id": device_id},
desc="set_device_for_refresh_token",
)
def _set_device_for_access_token_txn( def _set_device_for_access_token_txn(
self, txn: LoggingTransaction, token: str, device_id: str self, txn: LoggingTransaction, token: str, device_id: str
) -> str: ) -> str:
......
...@@ -461,6 +461,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase): ...@@ -461,6 +461,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
self.message_handler = hs.get_device_message_handler() self.message_handler = hs.get_device_message_handler()
self.registration = hs.get_registration_handler() self.registration = hs.get_registration_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
return hs return hs
...@@ -487,11 +488,12 @@ class DehydrationTestCase(unittest.HomeserverTestCase): ...@@ -487,11 +488,12 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(device_data, {"device_data": {"foo": "bar"}}) self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
# Create a new login for the user and dehydrated the device # Create a new login for the user and dehydrated the device
device_id, access_token, _expiration_time, _refresh_token = self.get_success( device_id, access_token, _expiration_time, refresh_token = self.get_success(
self.registration.register_device( self.registration.register_device(
user_id=user_id, user_id=user_id,
device_id=None, device_id=None,
initial_display_name="new device", initial_display_name="new device",
should_issue_refresh_token=True,
) )
) )
...@@ -522,6 +524,12 @@ class DehydrationTestCase(unittest.HomeserverTestCase): ...@@ -522,6 +524,12 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(user_info.device_id, retrieved_device_id) self.assertEqual(user_info.device_id, retrieved_device_id)
# make sure the user device has the refresh token
assert refresh_token is not None
self.get_success(
self.auth_handler.refresh_token(refresh_token, 5 * 60 * 1000, 5 * 60 * 1000)
)
# make sure the device has the display name that was set from the login # make sure the device has the display name that was set from the login
res = self.get_success(self.handler.get_device(user_id, retrieved_device_id)) res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment