From 23b626f2e68e985a3218abd0fc7d03b53bbcaf89 Mon Sep 17 00:00:00 2001
From: Quentin Gliech <quenting@element.io>
Date: Wed, 4 Dec 2024 12:04:49 +0100
Subject: [PATCH] Support for MSC4190: device management for application
 services (#17705)

This is an implementation of MSC4190, which allows appservices to manage
their user's devices without /login & /logout.

---------

Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
---
 changelog.d/17705.feature               |   1 +
 synapse/appservice/__init__.py          |   2 +
 synapse/config/appservice.py            |  13 ++
 synapse/handlers/device.py              |  34 +++++
 synapse/handlers/register.py            |   6 +-
 synapse/rest/client/devices.py          |  62 +++++---
 synapse/rest/client/register.py         |   7 +-
 tests/handlers/test_appservice.py       |  15 +-
 tests/handlers/test_oauth_delegation.py |  31 +++-
 tests/rest/client/test_devices.py       | 181 ++++++++++++++++++++++++
 tests/rest/client/test_register.py      |  28 ++++
 tests/unittest.py                       |   4 +-
 12 files changed, 351 insertions(+), 33 deletions(-)
 create mode 100644 changelog.d/17705.feature

diff --git a/changelog.d/17705.feature b/changelog.d/17705.feature
new file mode 100644
index 0000000000..e2cd7bca4f
--- /dev/null
+++ b/changelog.d/17705.feature
@@ -0,0 +1 @@
+Support for [MSC4190](https://github.com/matrix-org/matrix-spec-proposals/pull/4190): device management for Application Services.
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index a96cdbf1e7..6ee5240c4e 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -87,6 +87,7 @@ class ApplicationService:
         ip_range_whitelist: Optional[IPSet] = None,
         supports_ephemeral: bool = False,
         msc3202_transaction_extensions: bool = False,
+        msc4190_device_management: bool = False,
     ):
         self.token = token
         self.url = (
@@ -100,6 +101,7 @@ class ApplicationService:
         self.ip_range_whitelist = ip_range_whitelist
         self.supports_ephemeral = supports_ephemeral
         self.msc3202_transaction_extensions = msc3202_transaction_extensions
+        self.msc4190_device_management = msc4190_device_management
 
         if "|" in self.id:
             raise Exception("application service ID cannot contain '|' character")
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index 6ff00e1ff8..dda6bcd1b7 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -183,6 +183,18 @@ def _load_appservice(
             "The `org.matrix.msc3202` option should be true or false if specified."
         )
 
+    # Opt-in flag for the MSC4190 behaviours.
+    # When enabled, the following C-S API endpoints change for appservices:
+    # - POST /register does not return an access token
+    # - PUT /devices/{device_id} creates a new device if one does not exist
+    # - DELETE /devices/{device_id} no longer requires UIA
+    # - POST /delete_devices/{device_id} no longer requires UIA
+    msc4190_enabled = as_info.get("io.element.msc4190", False)
+    if not isinstance(msc4190_enabled, bool):
+        raise ValueError(
+            "The `io.element.msc4190` option should be true or false if specified."
+        )
+
     return ApplicationService(
         token=as_info["as_token"],
         url=as_info["url"],
@@ -195,4 +207,5 @@ def _load_appservice(
         ip_range_whitelist=ip_range_whitelist,
         supports_ephemeral=supports_ephemeral,
         msc3202_transaction_extensions=msc3202_transaction_extensions,
+        msc4190_device_management=msc4190_enabled,
     )
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index d88660e273..d9622080b4 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -729,6 +729,40 @@ class DeviceHandler(DeviceWorkerHandler):
 
         await self.notify_device_update(user_id, device_ids)
 
+    async def upsert_device(
+        self, user_id: str, device_id: str, display_name: Optional[str] = None
+    ) -> bool:
+        """Create or update a device
+
+        Args:
+            user_id: The user to update devices of.
+            device_id: The device to update.
+            display_name: The new display name for this device.
+
+        Returns:
+            True if the device was created, False if it was updated.
+
+        """
+
+        # Reject a new displayname which is too long.
+        self._check_device_name_length(display_name)
+
+        created = await self.store.store_device(
+            user_id,
+            device_id,
+            initial_device_display_name=display_name,
+        )
+
+        if not created:
+            await self.store.update_device(
+                user_id,
+                device_id,
+                new_display_name=display_name,
+            )
+
+        await self.notify_device_update(user_id, [device_id])
+        return created
+
     async def update_device(self, user_id: str, device_id: str, content: dict) -> None:
         """Update the given device
 
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index c200e29569..c49db83ce7 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -630,7 +630,9 @@ class RegistrationHandler:
         """
         await self._auto_join_rooms(user_id)
 
-    async def appservice_register(self, user_localpart: str, as_token: str) -> str:
+    async def appservice_register(
+        self, user_localpart: str, as_token: str
+    ) -> Tuple[str, ApplicationService]:
         user = UserID(user_localpart, self.hs.hostname)
         user_id = user.to_string()
         service = self.store.get_app_service_by_token(as_token)
@@ -653,7 +655,7 @@ class RegistrationHandler:
             appservice_id=service_id,
             create_profile_with_displayname=user.localpart,
         )
-        return user_id
+        return (user_id, service)
 
     def check_user_id_not_appservice_exclusive(
         self, user_id: str, allowed_appservice: Optional[ApplicationService] = None
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 6a45a5d130..4607b23494 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -114,15 +114,19 @@ class DeleteDevicesRestServlet(RestServlet):
             else:
                 raise e
 
-        await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body.dict(exclude_unset=True),
-            "remove device(s) from your account",
-            # Users might call this multiple times in a row while cleaning up
-            # devices, allow a single UI auth session to be re-used.
-            can_skip_ui_auth=True,
-        )
+        if requester.app_service and requester.app_service.msc4190_device_management:
+            # MSC4190 can skip UIA for this endpoint
+            pass
+        else:
+            await self.auth_handler.validate_user_via_ui_auth(
+                requester,
+                request,
+                body.dict(exclude_unset=True),
+                "remove device(s) from your account",
+                # Users might call this multiple times in a row while cleaning up
+                # devices, allow a single UI auth session to be re-used.
+                can_skip_ui_auth=True,
+            )
 
         await self.device_handler.delete_devices(
             requester.user.to_string(), body.devices
@@ -175,9 +179,6 @@ class DeviceRestServlet(RestServlet):
     async def on_DELETE(
         self, request: SynapseRequest, device_id: str
     ) -> Tuple[int, JsonDict]:
-        if self._msc3861_oauth_delegation_enabled:
-            raise UnrecognizedRequestError(code=404)
-
         requester = await self.auth.get_user_by_req(request)
 
         try:
@@ -192,15 +193,24 @@ class DeviceRestServlet(RestServlet):
             else:
                 raise
 
-        await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body.dict(exclude_unset=True),
-            "remove a device from your account",
-            # Users might call this multiple times in a row while cleaning up
-            # devices, allow a single UI auth session to be re-used.
-            can_skip_ui_auth=True,
-        )
+        if requester.app_service and requester.app_service.msc4190_device_management:
+            # MSC4190 allows appservices to delete devices through this endpoint without UIA
+            # It's also allowed with MSC3861 enabled
+            pass
+
+        else:
+            if self._msc3861_oauth_delegation_enabled:
+                raise UnrecognizedRequestError(code=404)
+
+            await self.auth_handler.validate_user_via_ui_auth(
+                requester,
+                request,
+                body.dict(exclude_unset=True),
+                "remove a device from your account",
+                # Users might call this multiple times in a row while cleaning up
+                # devices, allow a single UI auth session to be re-used.
+                can_skip_ui_auth=True,
+            )
 
         await self.device_handler.delete_devices(
             requester.user.to_string(), [device_id]
@@ -216,6 +226,16 @@ class DeviceRestServlet(RestServlet):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
         body = parse_and_validate_json_object_from_request(request, self.PutBody)
+
+        # MSC4190 allows appservices to create devices through this endpoint
+        if requester.app_service and requester.app_service.msc4190_device_management:
+            created = await self.device_handler.upsert_device(
+                user_id=requester.user.to_string(),
+                device_id=device_id,
+                display_name=body.display_name,
+            )
+            return 201 if created else 200, {}
+
         await self.device_handler.update_device(
             requester.user.to_string(), device_id, body.dict()
         )
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 61e1436841..ad76f188ab 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -771,9 +771,12 @@ class RegisterRestServlet(RestServlet):
         body: JsonDict,
         should_issue_refresh_token: bool = False,
     ) -> JsonDict:
-        user_id = await self.registration_handler.appservice_register(
+        user_id, appservice = await self.registration_handler.appservice_register(
             username, as_token
         )
+        if appservice.msc4190_device_management:
+            body["inhibit_login"] = True
+
         return await self._create_registration_details(
             user_id,
             body,
@@ -937,7 +940,7 @@ class RegisterAppServiceOnlyRestServlet(RestServlet):
 
         as_token = self.auth.get_access_token_from_request(request)
 
-        user_id = await self.registration_handler.appservice_register(
+        user_id, _ = await self.registration_handler.appservice_register(
             desired_username, as_token
         )
         return 200, {"user_id": user_id}
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 1eec0d43b7..1db630e9e4 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -1165,12 +1165,23 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
         self.hs.get_datastores().main.services_cache = [self._service]
 
         # Register some appservice users
-        self._sender_user, self._sender_device = self.register_appservice_user(
+        user_id, device_id = self.register_appservice_user(
             "as.sender", self._service_token
         )
-        self._namespaced_user, self._namespaced_device = self.register_appservice_user(
+        # With MSC4190 enabled, there will not be a device created
+        # during AS registration. However MSC4190 is not enabled
+        # in this test. It may become the default behaviour in the
+        # future, in which case this test will need to be updated.
+        assert device_id is not None
+        self._sender_user = user_id
+        self._sender_device = device_id
+
+        user_id, device_id = self.register_appservice_user(
             "_as_user1", self._service_token
         )
+        assert device_id is not None
+        self._namespaced_user = user_id
+        self._namespaced_device = device_id
 
         # Register a real user as well.
         self._real_user = self.register_user("real.user", "meow")
diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py
index 5b5dc713d1..5f73469daa 100644
--- a/tests/handlers/test_oauth_delegation.py
+++ b/tests/handlers/test_oauth_delegation.py
@@ -560,9 +560,15 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
         self.assertEqual(channel.code, 401, channel.json_body)
 
     def expect_unrecognized(
-        self, method: str, path: str, content: Union[bytes, str, JsonDict] = ""
+        self,
+        method: str,
+        path: str,
+        content: Union[bytes, str, JsonDict] = "",
+        auth: bool = False,
     ) -> None:
-        channel = self.make_request(method, path, content)
+        channel = self.make_request(
+            method, path, content, access_token="token" if auth else None
+        )
 
         self.assertEqual(channel.code, 404, channel.json_body)
         self.assertEqual(
@@ -648,8 +654,25 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
 
     def test_device_management_endpoints_removed(self) -> None:
         """Test that device management endpoints that were removed in MSC2964 are no longer available."""
-        self.expect_unrecognized("POST", "/_matrix/client/v3/delete_devices")
-        self.expect_unrecognized("DELETE", "/_matrix/client/v3/devices/{DEVICE}")
+
+        # Because we still support those endpoints with ASes, it checks the
+        # access token before returning 404
+        self.http_client.request = AsyncMock(
+            return_value=FakeResponse.json(
+                code=200,
+                payload={
+                    "active": True,
+                    "sub": SUBJECT,
+                    "scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
+                    "username": USERNAME,
+                },
+            )
+        )
+
+        self.expect_unrecognized("POST", "/_matrix/client/v3/delete_devices", auth=True)
+        self.expect_unrecognized(
+            "DELETE", "/_matrix/client/v3/devices/{DEVICE}", auth=True
+        )
 
     def test_openid_endpoints_removed(self) -> None:
         """Test that OpenID id_token endpoints that were removed in MSC2964 are no longer available."""
diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py
index a3ed12a38f..dd3abdebac 100644
--- a/tests/rest/client/test_devices.py
+++ b/tests/rest/client/test_devices.py
@@ -24,6 +24,7 @@ from twisted.internet.defer import ensureDeferred
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.errors import NotFoundError
+from synapse.appservice import ApplicationService
 from synapse.rest import admin, devices, sync
 from synapse.rest.client import keys, login, register
 from synapse.server import HomeServer
@@ -455,3 +456,183 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
             token,
         )
         self.assertEqual(channel.json_body["device_keys"], {"@mikey:test": {}})
+
+
+class MSC4190AppserviceDevicesTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        register.register_servlets,
+        devices.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        self.hs = self.setup_test_homeserver()
+
+        # This application service uses the new MSC4190 behaviours
+        self.msc4190_service = ApplicationService(
+            id="msc4190",
+            token="some_token",
+            hs_token="some_token",
+            sender="@as:example.com",
+            namespaces={
+                ApplicationService.NS_USERS: [{"regex": "@.*", "exclusive": False}]
+            },
+            msc4190_device_management=True,
+        )
+        # This application service doesn't use the new MSC4190 behaviours
+        self.pre_msc_service = ApplicationService(
+            id="regular",
+            token="other_token",
+            hs_token="other_token",
+            sender="@as2:example.com",
+            namespaces={
+                ApplicationService.NS_USERS: [{"regex": "@.*", "exclusive": False}]
+            },
+            msc4190_device_management=False,
+        )
+        self.hs.get_datastores().main.services_cache.append(self.msc4190_service)
+        self.hs.get_datastores().main.services_cache.append(self.pre_msc_service)
+        return self.hs
+
+    def test_PUT_device(self) -> None:
+        self.register_appservice_user("alice", self.msc4190_service.token)
+        self.register_appservice_user("bob", self.pre_msc_service.token)
+
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/v3/devices?user_id=@alice:test",
+            access_token=self.msc4190_service.token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        self.assertEqual(channel.json_body, {"devices": []})
+
+        channel = self.make_request(
+            "PUT",
+            "/_matrix/client/v3/devices/AABBCCDD?user_id=@alice:test",
+            content={"display_name": "Alice's device"},
+            access_token=self.msc4190_service.token,
+        )
+        self.assertEqual(channel.code, 201, channel.json_body)
+
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/v3/devices?user_id=@alice:test",
+            access_token=self.msc4190_service.token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        self.assertEqual(len(channel.json_body["devices"]), 1)
+        self.assertEqual(channel.json_body["devices"][0]["device_id"], "AABBCCDD")
+
+        # Doing a second time should return a 200 instead of a 201
+        channel = self.make_request(
+            "PUT",
+            "/_matrix/client/v3/devices/AABBCCDD?user_id=@alice:test",
+            content={"display_name": "Alice's device"},
+            access_token=self.msc4190_service.token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # On the regular service, that API should not allow for the
+        # creation of new devices.
+        channel = self.make_request(
+            "PUT",
+            "/_matrix/client/v3/devices/AABBCCDD?user_id=@bob:test",
+            content={"display_name": "Bob's device"},
+            access_token=self.pre_msc_service.token,
+        )
+        self.assertEqual(channel.code, 404, channel.json_body)
+
+    def test_DELETE_device(self) -> None:
+        self.register_appservice_user("alice", self.msc4190_service.token)
+
+        # There should be no device
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/v3/devices?user_id=@alice:test",
+            access_token=self.msc4190_service.token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        self.assertEqual(channel.json_body, {"devices": []})
+
+        # Create a device
+        channel = self.make_request(
+            "PUT",
+            "/_matrix/client/v3/devices/AABBCCDD?user_id=@alice:test",
+            content={},
+            access_token=self.msc4190_service.token,
+        )
+        self.assertEqual(channel.code, 201, channel.json_body)
+
+        # There should be one device
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/v3/devices?user_id=@alice:test",
+            access_token=self.msc4190_service.token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        self.assertEqual(len(channel.json_body["devices"]), 1)
+
+        # Delete the device. UIA should not be required.
+        channel = self.make_request(
+            "DELETE",
+            "/_matrix/client/v3/devices/AABBCCDD?user_id=@alice:test",
+            access_token=self.msc4190_service.token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # There should be no device again
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/v3/devices?user_id=@alice:test",
+            access_token=self.msc4190_service.token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        self.assertEqual(channel.json_body, {"devices": []})
+
+    def test_POST_delete_devices(self) -> None:
+        self.register_appservice_user("alice", self.msc4190_service.token)
+
+        # There should be no device
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/v3/devices?user_id=@alice:test",
+            access_token=self.msc4190_service.token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        self.assertEqual(channel.json_body, {"devices": []})
+
+        # Create a device
+        channel = self.make_request(
+            "PUT",
+            "/_matrix/client/v3/devices/AABBCCDD?user_id=@alice:test",
+            content={},
+            access_token=self.msc4190_service.token,
+        )
+        self.assertEqual(channel.code, 201, channel.json_body)
+
+        # There should be one device
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/v3/devices?user_id=@alice:test",
+            access_token=self.msc4190_service.token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        self.assertEqual(len(channel.json_body["devices"]), 1)
+
+        # Delete the device with delete_devices
+        # UIA should not be required.
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/delete_devices?user_id=@alice:test",
+            content={"devices": ["AABBCCDD"]},
+            access_token=self.msc4190_service.token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # There should be no device again
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/v3/devices?user_id=@alice:test",
+            access_token=self.msc4190_service.token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        self.assertEqual(channel.json_body, {"devices": []})
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index c091f403cc..b697bf6f67 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -120,6 +120,34 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(channel.code, 401, msg=channel.result)
 
+    def test_POST_appservice_msc4190_enabled(self) -> None:
+        # With MSC4190 enabled, the registration should *not* return an access token
+        user_id = "@as_user_kermit:test"
+        as_token = "i_am_an_app_service"
+
+        appservice = ApplicationService(
+            as_token,
+            id="1234",
+            namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
+            sender="@as:test",
+            msc4190_device_management=True,
+        )
+
+        self.hs.get_datastores().main.services_cache.append(appservice)
+        request_data = {
+            "username": "as_user_kermit",
+            "type": APP_SERVICE_REGISTRATION_TYPE,
+        }
+
+        channel = self.make_request(
+            b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
+        )
+
+        self.assertEqual(channel.code, 200, msg=channel.result)
+        det_data = {"user_id": user_id, "home_server": self.hs.hostname}
+        self.assertLessEqual(det_data.items(), channel.json_body.items())
+        self.assertNotIn("access_token", channel.json_body)
+
     def test_POST_bad_password(self) -> None:
         request_data = {"username": "kermit", "password": 666}
         channel = self.make_request(b"POST", self.url, request_data)
diff --git a/tests/unittest.py b/tests/unittest.py
index 614e805abd..6a32861a3e 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -781,7 +781,7 @@ class HomeserverTestCase(TestCase):
         self,
         username: str,
         appservice_token: str,
-    ) -> Tuple[str, str]:
+    ) -> Tuple[str, Optional[str]]:
         """Register an appservice user as an application service.
         Requires the client-facing registration API be registered.
 
@@ -805,7 +805,7 @@ class HomeserverTestCase(TestCase):
             access_token=appservice_token,
         )
         self.assertEqual(channel.code, 200, channel.json_body)
-        return channel.json_body["user_id"], channel.json_body["device_id"]
+        return channel.json_body["user_id"], channel.json_body.get("device_id")
 
     def login(
         self,
-- 
GitLab