Skip to content
Snippets Groups Projects
test_appservice.py 44.2 KiB
Newer Older
  • Learn to ignore specific revisions
  • 
            # Define an application service for the tests
            self._service_token = "VERYSECRET"
            self._service = ApplicationService(
                self._service_token,
                "as1",
                "@as.sender:test",
                namespaces={
                    "users": [
                        {"regex": "@_as_.*:test", "exclusive": True},
                        {"regex": "@as.sender:test", "exclusive": True},
                    ]
                },
                msc3202_transaction_extensions=True,
            )
            self.hs.get_datastores().main.services_cache = [self._service]
    
            # Register some appservice users
            self._sender_user, self._sender_device = self.register_appservice_user(
                "as.sender", self._service_token
            )
            self._namespaced_user, self._namespaced_device = self.register_appservice_user(
                "_as_user1", self._service_token
            )
    
            # Register a real user as well.
            self._real_user = self.register_user("real.user", "meow")
            self._real_user_token = self.login("real.user", "meow")
    
        async def _add_otks_for_device(
            self, user_id: str, device_id: str, otk_count: int
        ) -> None:
            """
            Upload `otk_count` placeholder one-time keys for the given device.

            Every key uses the made-up algorithm name "algo", gets an ID of the
            form k0, k1, ..., and has "{}" as its body. Real key material is not
            needed because key contents should be opaque to the server anyway.
            """
            store = self.hs.get_datastores().main
            now_ms = self.clock.time_msec()
            dummy_keys = [
                ("algo", f"k{index}", "{}") for index in range(otk_count)
            ]
            await store.add_e2e_one_time_keys(user_id, device_id, now_ms, dummy_keys)
    
        async def _add_fallback_key_for_device(
            self, user_id: str, device_id: str, used: bool
        ) -> None:
            """
            Store a placeholder fallback key (algorithm "algo", key ID "fk") for a
            device, optionally marking it as used right away.

            Args:
                user_id: owner of the device.
                device_id: the device to attach the fallback key to.
                used: if true, the freshly stored key is flagged as used by
                    updating its row in `e2e_fallback_keys_json`.
            """
            store = self.hs.get_datastores().main
            await store.set_e2e_fallback_keys(user_id, device_id, {"algo:fk": "fall back!"})
            if used:
                # Flip the `used` flag directly on the row written just above.
                await store.db_pool.simple_update_one(
                    table="e2e_fallback_keys_json",
                    keyvalues={
                        "user_id": user_id,
                        "device_id": device_id,
                        "algorithm": "algo",
                        "key_id": "fk",
                    },
                    updatevalues={"used": True},
                    desc="_get_fallback_key_set_used",
                )
    
        def _set_up_devices_and_a_room(self) -> str:
            """
            Seed device keys for every test user, then create a public room as the
            real user and have the namespaced AS user join it.

            Key material created:
                - AS sender: 42 one-time keys and a fallback key already marked used;
                - namespaced AS user: 36 one-time keys and an unused fallback key;
                - real (non-AS) user: a "REALDEV" device with 50 one-time keys.

            The AS send mock is reset before returning, so the caller only sees
            transactions that happen after setup.

            Returns:
                The ID of the created room.
            """

            async def seed_device_keys() -> None:
                # Each AS user gets its one-time keys first, then a fallback key.
                as_users = (
                    (self._sender_user, self._sender_device, 42, True),
                    (self._namespaced_user, self._namespaced_device, 36, False),
                )
                for user_id, device_id, otk_count, fallback_used in as_users:
                    await self._add_otks_for_device(user_id, device_id, otk_count)
                    await self._add_fallback_key_for_device(
                        user_id, device_id, used=fallback_used
                    )

                # Give the non-AS user a device with keys as well, so tests can
                # later check that none of its key information leaks to the AS.
                await self.hs.get_datastores().main.store_device(
                    self._real_user, "REALDEV", "UltraMatrix 3000"
                )
                await self._add_otks_for_device(self._real_user, "REALDEV", 50)

            self.get_success(seed_device_keys())

            new_room_id = self.helper.create_room_as(
                self._real_user, is_public=True, tok=self._real_user_token
            )
            self.helper.join(
                new_room_id,
                self._namespaced_user,
                tok=self._service_token,
                appservice_user_id=self._namespaced_user,
            )

            # The join should have been sent to the AS as a transaction. Check that
            # as a sanity measure, then clear the mock for the caller.
            self.send_mock.assert_called()
            self.send_mock.reset_mock()

            return new_room_id
    
        @override_config(
            {"experimental_features": {"msc3202_transaction_extensions": True}}
        )
        def test_application_services_receive_otk_counts_and_fallback_key_usages_with_pdus(
            self,
        ) -> None:
            """
            Tests that, when a PDU is sent into the room the AS user has joined,
            the resulting AS transaction carries one-time key counts and unused
            fallback key algorithms for:
                - the AS's sender user; and
                - any AS user in receipt of the PDUs.

            Because the comparisons are exact, this also checks that the non-AS
            user's device keys are not included.
            """
            room_id = self._set_up_devices_and_a_room()

            # A message from the real user should trigger an AS transaction.
            self.helper.send(room_id, "woof woof", tok=self._real_user_token)
            self.send_mock.assert_called()

            # Positional arguments of the most recent transaction send.
            sent_args = self.send_mock.call_args[0]
            otks: Optional[TransactionOneTimeKeysCount] = sent_args[self.ARG_OTK_COUNTS]
            unused_fallbacks: Optional[TransactionUnusedFallbackKeys] = sent_args[
                self.ARG_FALLBACK_KEYS
            ]

            expected_otks = {
                "@as.sender:test": {self._sender_device: {"algo": 42}},
                "@_as_user1:test": {self._namespaced_device: {"algo": 36}},
            }
            # The sender's fallback key was marked used during setup, so it lists none.
            expected_unused_fallbacks = {
                "@as.sender:test": {self._sender_device: []},
                "@_as_user1:test": {self._namespaced_device: ["algo"]},
            }
            self.assertEqual(otks, expected_otks)
            self.assertEqual(unused_fallbacks, expected_unused_fallbacks)