Skip to content
Snippets Groups Projects
unittest.py 37 KiB
Newer Older
  • Learn to ignore specific revisions
  • Patrick Cloke's avatar
    Patrick Cloke committed
    # This file is licensed under the Affero General Public License (AGPL) version 3.
    #
    
    # Copyright 2019 Matrix.org Federation C.I.C
    # Copyright 2014-2016 OpenMarket Ltd
    # Copyright (C) 2023 New Vector, Ltd
    #
    # This program is free software: you can redistribute it and/or modify
    # it under the terms of the GNU Affero General Public License as
    # published by the Free Software Foundation, either version 3 of the
    # License, or (at your option) any later version.
    #
    # See the GNU Affero General Public License for more details:
    # <https://www.gnu.org/licenses/agpl-3.0.html>.
    #
    # Originally licensed under the Apache License, Version 2.0:
    # <http://www.apache.org/licenses/LICENSE-2.0>.
    #
    # [This file includes modifications made by New Vector Limited]
    
    import hashlib
    import hmac
    
    from unittest.mock import Mock, patch
    
    import canonicaljson
    import signedjson.key
    import unpaddedbase64
    
    from typing_extensions import Concatenate, ParamSpec, Protocol
    
    from twisted.internet.defer import Deferred, ensureDeferred
    
    from twisted.python.failure import Failure
    
    from twisted.python.threadpool import ThreadPool
    
    from twisted.test.proto_helpers import MemoryReactor, MemoryReactorClock
    
    from twisted.web.resource import Resource
    
    from twisted.web.server import Request
    
    from synapse.api.constants import EventTypes
    
    from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
    
    from synapse.config._base import Config, RootConfig
    
    from synapse.config.homeserver import HomeServerConfig
    
    from synapse.config.server import DEFAULT_ROOM_VERSION
    from synapse.crypto.event_signing import add_hashes_and_signatures
    
    from synapse.federation.transport.server import TransportLayerServer
    
    from synapse.http.server import JsonResource, OptionsResource
    
    from synapse.http.site import SynapseRequest, SynapseSite
    
    from synapse.logging.context import (
        SENTINEL_CONTEXT,
    
        current_context,
        set_current_context,
    )
    
    from synapse.rest import RegisterServletsFunc
    
    from synapse.server import HomeServer
    
    from synapse.storage.keys import FetchKeyResult
    
    from synapse.types import JsonDict, Requester, UserID, create_requester
    
    from synapse.util import Clock
    
    from synapse.util.httpresourcetree import create_resource_tree
    
    from tests.server import (
        CustomHeaderType,
        FakeChannel,
    
        ThreadedMemoryReactorClock,
    
        get_clock,
        make_request,
        setup_test_homeserver,
    )
    
    from tests.test_utils import event_injection, setup_awaitable_errors
    
    from tests.test_utils.logging_setup import setup_logging
    
    from tests.utils import checked_cast, default_config, setupdb
    
    TV = TypeVar("TV")
    _ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True)
    
    
    P = ParamSpec("P")
    R = TypeVar("R")
    S = TypeVar("S")
    
    
    
    class _TypedFailure(Generic[_ExcType], Protocol):
        """Protocol describing a `twisted.python.failure.Failure` whose `value`
        (the wrapped exception) is known to be an instance of `_ExcType`.

        Used as a return annotation so that the type checker knows which exception
        type callers will find in `.value`.
        """

        @property
        def value(self) -> _ExcType:
            """The exception carried by the failure."""
            ...
    
    def around(target: TV) -> Callable[[Callable[Concatenate[S, P], R]], None]:
    
        """A CLOS-style 'around' modifier, which wraps the original method of the
        given instance with another piece of code.
    
        @around(self)
        def method_name(orig, *args, **kwargs):
            return orig(*args, **kwargs)
        """
    
    black's avatar
    black committed
    
    
        def _around(code: Callable[Concatenate[S, P], R]) -> None:
    
            def new(*args: P.args, **kwargs: P.kwargs) -> R:
    
    _TConfig = TypeVar("_TConfig", Config, RootConfig)


    def deepcopy_config(config: _TConfig) -> _TConfig:
        """Return a copy of `config`, recursively copying nested `Config` sections.

        A new instance of the same class is constructed: from `config_files` for a
        `RootConfig`, or from the owning `root` for any other `Config`. Every entry of
        the original's instance `__dict__` is then carried across, except `root` and
        names starting with a double underscore. Nested `Config` values are copied
        recursively; all other values are shared with the original, not copied.
        """
        copied: _TConfig
        if isinstance(config, RootConfig):
            copied = config.__class__(config.config_files)  # type: ignore[arg-type]
        else:
            copied = config.__class__(config.root)

        for name in config.__dict__:
            if name == "root" or name.startswith("__"):
                continue
            value = getattr(config, name)
            setattr(
                copied,
                name,
                deepcopy_config(value) if isinstance(value, Config) else value,
            )

        return copied
    
    
    
    @functools.lru_cache(maxsize=8)
    def _parse_config_dict(config: str) -> RootConfig:
        """Build a `HomeServerConfig` from a JSON-serialised config dict.

        The config is passed as a JSON string rather than a dict so that it is
        hashable and can act as the `lru_cache` key (up to 8 entries are kept).
        Repeated calls with an equal string return the very same object, so callers
        should copy the result before modifying it.
        """
        hs_config = HomeServerConfig()
        raw_config = json.loads(config)
        hs_config.parse_config_dict(raw_config, "", "")
        return hs_config
    
    
    
    def make_homeserver_config_obj(config: Dict[str, Any]) -> RootConfig:
        """Creates a :class:`HomeServerConfig` instance with the given configuration dict.

        This is equivalent to::

            config_obj = HomeServerConfig()
            config_obj.parse_config_dict(config, "", "")

        but it keeps a cache of `HomeServerConfig` instances and deepcopies them as needed,
        to avoid validating the whole configuration every time.

        Args:
            config: The configuration dict. It must be JSON-serialisable, since it is
                serialised (with sorted keys) to form the cache key.

        Returns:
            A copy of the cached parsed config. Attributes may be reassigned on it,
            including on nested config sections, without affecting the cached
            instance. Non-`Config` attribute values themselves are shared, not copied.
        """
        # Serialise with sorted keys so that equal dicts map to the same cache entry.
        config_obj = _parse_config_dict(json.dumps(config, sort_keys=True))
        # The cached object is shared between all callers with the same config, so
        # hand back a copy rather than the cached instance itself.
        return deepcopy_config(config_obj)
    
        """A subclass of twisted.trial's TestCase which looks for 'loglevel'
        attributes on both itself and its individual test methods, to override the
        root logger's logging level while that test (case|method) runs."""
    
    
        def __init__(self, methodName: str):
            super().__init__(methodName)
    
            level = getattr(method, "loglevel", getattr(self, "loglevel", None))
    
            def setUp(orig: Callable[[], R]) -> R:
    
                # if we're not starting in the sentinel logcontext, then to be honest
                # all future bets are off.
    
                if current_context():
    
                        "Test starting with non-sentinel logging context %s"
    
                        % (current_context(),)
    
                # Disable GC for duration of test. See below for why.
                gc.disable()
    
    
                old_level = logging.getLogger().level
    
                if level is not None and old_level != level:
    
    black's avatar
    black committed
    
    
                    def tearDown(orig: Callable[[], R]) -> R:
    
                    logging.getLogger().setLevel(level)
    
    
                # Trial messes with the warnings configuration, thus this has to be
                # done in the context of an individual TestCase.
                self.addCleanup(setup_awaitable_errors())
    
    
            # We want to force a GC to workaround problems with deferreds leaking
            # logcontexts when they are GCed (see the logcontext docs).
            #
            # The easiest way to do this would be to do a full GC after each test
            # run, but that is very expensive. Instead, we disable GC (above) for
    
            # the duration of the test and only run a gen-0 GC, which is a lot
            # quicker. This doesn't clean up everything, since the TestCase
            # instance still holds references to objects created during the test,
            # such as HomeServers, so we do a full GC every so often.
    
            def tearDown(orig: Callable[[], R]) -> R:
    
                gc.collect(0)
    
                # Run a full GC every 50 gen-0 GCs.
                gen0_stats = gc.get_stats()[0]
                gen0_collections = gen0_stats["collections"]
                if gen0_collections % 50 == 0:
                    gc.collect()
    
                set_current_context(SENTINEL_CONTEXT)
    
        def assertObjectHasAttributes(self, attrs: Dict[str, object], obj: object) -> None:
    
            """Asserts that the given object has each of the attributes given, and
    
            that the value of each matches according to assertEqual."""
    
            for key in attrs.keys():
    
                if not hasattr(obj, key):
                    raise AssertionError("Expected obj to have a '.%s'" % key)
                try:
    
                    self.assertEqual(attrs[key], getattr(obj, key))
    
                    raise (type(e))(f"Assert error for '.{key}':") from e
    
        def assert_dict(self, required: Mapping, actual: Mapping) -> None:
    
            """Does a partial assert of a dict.
    
            Args:
    
                required: The keys and value which MUST be in 'actual'.
                actual: The test result. Extra keys will not be checked.
    
    black's avatar
    black committed
                    required[key], actual[key], msg="%s mismatch. %s" % (key, actual)
                )
    
            actual_items: AbstractSet[TV],
            expected_items: AbstractSet[TV],
    
            exact: bool = False,
            message: Optional[str] = None,
        ) -> None:
            """
            Assert that all of the `expected_items` are included in the `actual_items`.
    
            This assert could also be called `assertContains`, `assertItemsInSet`
    
            Args:
                actual_items: The container
                expected_items: The items to check for in the container
                exact: Whether the actual state should be exactly equal to the expected
                    state (no extras).
                message: Optional message to include in the failure message.
            """
            # Check that each set has the same items
            if exact and actual_items == expected_items:
                return
            # Check for a superset
            elif not exact and actual_items >= expected_items:
                return
    
            expected_lines: List[str] = []
            for expected_item in expected_items:
                is_expected_in_actual = expected_item in actual_items
                expected_lines.append(
                    "{}  {}".format(" " if is_expected_in_actual else "?", expected_item)
                )
    
            actual_lines: List[str] = []
            for actual_item in actual_items:
                is_actual_in_expected = actual_item in expected_items
                actual_lines.append(
                    "{}  {}".format("+" if is_actual_in_expected else " ", actual_item)
                )
    
            newline = "\n"
            expected_string = f"Expected items to be in actual ('?' = missing expected items):\n {{\n{newline.join(expected_lines)}\n }}"
            actual_string = f"Actual ('+' = found expected items):\n {{\n{newline.join(actual_lines)}\n }}"
            first_message = (
                "Items must match exactly" if exact else "Some expected items are missing."
            )
            diff_message = f"{first_message}\n{expected_string}\n{actual_string}"
    
            self.fail(f"{diff_message}\n{message}")
    
    
    def DEBUG(target: TV) -> TV:
        """A decorator to set the .loglevel attribute to logging.DEBUG.
        Can apply to either a TestCase or an individual test method.

        Returns `target` itself, so the decorated class or method stays bound to
        its name.
        """

        target.loglevel = logging.DEBUG  # type: ignore[attr-defined]
        # A decorator's return value replaces the decorated object; without this
        # the decorated name would be rebound to None.
        return target
    
    def INFO(target: TV) -> TV:
        """A decorator to set the .loglevel attribute to logging.INFO.
        Can apply to either a TestCase or an individual test method.

        Returns `target` itself, so the decorated class or method stays bound to
        its name.
        """

        target.loglevel = logging.INFO  # type: ignore[attr-defined]
        # A decorator's return value replaces the decorated object; without this
        # the decorated name would be rebound to None.
        return target
    
    def logcontext_clean(target: TV) -> TV:
    
        """A decorator which marks the TestCase or method as 'logcontext_clean'
    
        ... ie, any logcontext errors should cause a test failure
        """
    
    
        def logcontext_error(msg: str) -> NoReturn:
    
            raise AssertionError("logcontext error: %s" % (msg))
    
        patcher = patch("synapse.logging.context.logcontext_error", new=logcontext_error)
    
        return patcher(target)  # type: ignore[call-overload]
    
    class HomeserverTestCase(TestCase):
        """
        A base TestCase that reduces boilerplate for HomeServer-using test cases.
    
    
        Defines a setUp method which creates a mock reactor, and instantiates a homeserver
        running on that reactor.
    
        There are various hooks for modifying the way that the homeserver is instantiated:
    
        * override make_homeserver, for example by making it pass different parameters into
          setup_test_homeserver.
    
        * override default_config, to return a modified configuration dictionary for use
          by setup_test_homeserver.
    
        * On a per-test basis, you can use the @override_config decorator to give a
          dictionary containing additional configuration settings to be added to the basic
          config dict.
    
    
        Attributes:
            servlets: List of servlet registration functions.
            user_id (str): The user ID to assume if auth is hijacked.
            hijack_auth: Whether to hijack auth to return the user specified
               in user_id.
        """

        hijack_auth: ClassVar[bool] = True
        needs_threadpool: ClassVar[bool] = False
    
        servlets: ClassVar[List[RegisterServletsFunc]] = []
    
        def __init__(self, methodName: str):
            super().__init__(methodName)

            # A test method decorated with @override_config carries its extra config
            # in an `_extra_config` attribute; remember it (or None) so that
            # default_config() can apply it.
            test_method = getattr(self, methodName)
            self._extra_config = getattr(test_method, "_extra_config", None)
    
    
        def setUp(self) -> None:
    
            """
            Set up the TestCase by calling the homeserver constructor, optionally
            hijacking the authentication system to return a fixed user, and then
            calling the prepare function.
            """
            self.reactor, self.clock = get_clock()
            self._hs_args = {"clock": self.clock, "reactor": self.reactor}
            self.hs = self.make_homeserver(self.reactor, self.clock)
    
    
            self.hs.get_datastores().main.tests_allow_no_chain_cover_index = False
    
    
            # Honour the `use_frozen_dicts` config option. We have to do this
            # manually because this is taken care of in the app `start` code, which
            # we don't run. Plus we want to reset it on tearDown.
    
            events.USE_FROZEN_DICTS = self.hs.config.server.use_frozen_dicts
    
            if self.hs is None:
                raise Exception("No homeserver returned from make_homeserver.")
    
            if not isinstance(self.hs, HomeServer):
                raise Exception("A homeserver wasn't returned, but %r" % (self.hs,))
    
    
            # create the root resource, and a site to wrap it.
            self.resource = self.create_test_resource()
    
            self.site = SynapseSite(
                logger_name="synapse.access.http.fake",
    
                site_tag=self.hs.config.server.server_name,
    
                config=self.hs.config.server.listeners[0],
    
                resource=self.resource,
                server_version_string="1",
    
                max_request_body_size=4096,
    
                reactor=self.reactor,
    
            from tests.rest.client.utils import RestHelper
    
            self.helper = RestHelper(
                self.hs,
                checked_cast(MemoryReactorClock, self.hs.get_reactor()),
                self.site,
                getattr(self, "user_id", None),
            )
    
            if hasattr(self, "user_id"):
    
                    assert self.helper.auth_user_id is not None
    
                    # We need a valid token ID to satisfy foreign key constraints.
                    token_id = self.get_success(
    
                        self.hs.get_datastores().main.add_access_token_to_user(
    
                    # This has to be a function and not just a Mock, because
                    # `self.helper.auth_user_id` is temporarily reassigned in some tests
    
                    async def get_requester(*args: Any, **kwargs: Any) -> Requester:
    
                        assert self.helper.auth_user_id is not None
    
                        return create_requester(
    
                            user_id=UserID.from_string(self.helper.auth_user_id),
                            access_token_id=token_id,
    
                    # Type ignore: mypy doesn't like us assigning to methods.
    
                    self.hs.get_auth().get_user_by_req = get_requester  # type: ignore[method-assign]
                    self.hs.get_auth().get_user_by_access_token = get_requester  # type: ignore[method-assign]
    
                    self.hs.get_auth().get_access_token_from_request = Mock(  # type: ignore[method-assign]
                        return_value=token
                    )
    
                self.reactor.threadpool = ThreadPool()  # type: ignore[assignment]
    
                self.addCleanup(self.reactor.threadpool.stop)
                self.reactor.threadpool.start()
    
    
            if hasattr(self, "prepare"):
                self.prepare(self.reactor, self.clock, self.hs)
    
    
        def tearDown(self) -> None:
            """Undo the global frozen-dicts setting applied during setUp.

            setUp copies the homeserver's `use_frozen_dicts` option into the
            module-level `events.USE_FROZEN_DICTS` flag; reset that flag to False so
            later tests are not affected.
            """
            events.USE_FROZEN_DICTS = False
    
    
        def wait_on_thread(self, deferred: Deferred, timeout: int = 10) -> None:
            """
            Block until `deferred` has fired, where its result is produced on a
            real thread.

            Alternately advances the fake reactor by 0.01s and sleeps for 0.01s of
            wall-clock time until the Deferred has been called.

            Raises:
                ValueError: if the Deferred has not fired within `timeout` seconds of
                    wall-clock time.
            """
            started = time.time()

            while not deferred.called:
                if time.time() > started + timeout:
                    raise ValueError("Timed out waiting for threadpool")
                self.reactor.advance(0.01)
                time.sleep(0.01)
    
    
        def wait_for_background_updates(self) -> None:
    
            """Block until all background database updates have completed."""
    
            store = self.hs.get_datastores().main
    
                store.db_pool.updates.has_completed_background_updates()
    
                    store.db_pool.updates.do_next_background_update(False), by=0.1
    
        def make_homeserver(
            self, reactor: ThreadedMemoryReactorClock, clock: Clock
        ) -> HomeServer:
            """
            Make and return a homeserver.

            Function to be overridden in subclasses. The default implementation does
            not use its arguments directly; it delegates to `setup_test_homeserver`,
            which passes through the test case's own reactor and clock.

            Args:
                reactor: A Twisted Reactor, or something that pretends to be one.
                clock: The Clock, associated with the reactor.

            Returns:
                A homeserver suitable for testing.
            """
            return self.setup_test_homeserver()
    
        def create_test_resource(self) -> Resource:
            """
            Create the root resource for the test server.

            The default calls `self.create_resource_dict` and builds the resultant dict
            into a tree, rooted at an `OptionsResource`.

            Returns:
                The root resource of the tree.
            """
            root_resource = OptionsResource()

            create_resource_tree(self.create_resource_dict(), root_resource)
            return root_resource
    
        def create_resource_dict(self) -> Dict[str, Resource]:
            """Create a resource tree for the test server

            A resource tree is a mapping from path to twisted.web.resource.

            The default implementation creates a single JsonResource and registers
            every function in `servlets` against it. It then mounts that one resource
            under both the client API prefix and the admin API prefix.
            """
            json_resource = JsonResource(self.hs)
            for register_servlets in self.servlets:
                register_servlets(self.hs, json_resource)
            return dict.fromkeys(("/_matrix/client", "/_synapse/admin"), json_resource)
    
        def default_config(self) -> JsonDict:
            """Return the config dict used to build the test homeserver.

            Starts from `tests.utils.default_config("test")`. If the test method was
            decorated with @override_config, its overrides are applied on top with a
            shallow `dict.update`, so overridden top-level keys are replaced wholesale.
            """
            config = default_config("test")

            overrides = self._extra_config
            if overrides is None:
                return config

            config.update(overrides)
            return config
    
        def prepare(
            self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
        ) -> None:
            """
            Hook for preparing the test, invoked by setUp once the homeserver exists.

            Override it in subclasses to do things like mock out parts of the
            homeserver, or build test data shared across the whole test suite. The
            default implementation does nothing.

            Args:
                reactor: A Twisted Reactor, or something that pretends to be one.
                clock: The Clock, associated with the reactor.
                homeserver: The HomeServer to test against.
            """
    
    
            method: Union[bytes, str],
            path: Union[bytes, str],
    
            content: Union[bytes, str, JsonDict] = b"",
    
            access_token: Optional[str] = None,
    
            request: Type[Request] = SynapseRequest,
    
            federation_auth_origin: Optional[bytes] = None,
    
            content_type: Optional[bytes] = None,
    
            content_is_form: bool = False,
    
            await_result: bool = True,
    
            custom_headers: Optional[Iterable[CustomHeaderType]] = None,
    
            client_ip: str = "127.0.0.1",
    
            """
            Create a SynapseRequest at the path using the method and containing the
            given content.
    
            Args:
    
                method: The HTTP request method ("verb").
                path: The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces
                    and such). content (bytes or dict): The body of the request.
                    JSON-encoded, if a dict.
    
                shorthand: Whether to try and be helpful and prefix the given URL
                with the usual REST API path, if it doesn't contain it.
    
                federation_auth_origin: if set to not-None, we will add a fake
    
                    Authorization header pretenting to be the given server name.
    
    
                content_type: The content-type to use for the request. If not set then will default to
                    application/json unless content_is_form is true.
    
                content_is_form: Whether the content is URL encoded form data. Adds the
                    'Content-Type': 'application/x-www-form-urlencoded' header.
    
                await_result: whether to wait for the request to complete rendering. If
                     true (the default), will pump the test reactor until the the renderer
                     tells the channel the request is finished.
    
    
                custom_headers: (name, value) pairs to add as request headers
    
    
                client_ip: The IP to use as the requesting IP. Useful for testing
                    ratelimiting.
    
    
                The FakeChannel object which stores the result of the request.
    
            return make_request(
    
                self.reactor,
    
                method,
                path,
                content,
                access_token,
                request,
                shorthand,
    
                content_type,
    
        def setup_test_homeserver(
            self, name: Optional[str] = None, **kwargs: Any
        ) -> HomeServer:
            """
            Set up the test homeserver, meant to be called by the overridable
            make_homeserver. It automatically passes through the test class's
            clock & reactor.

            Args:
                name: Optional server name. Takes precedence over any `server_name`
                    in the config.
                **kwargs: See tests.utils.setup_test_homeserver. If `config` is given
                    it must be a config dict; otherwise `self.default_config()` is used.

            Returns:
                synapse.server.HomeServer
            """
            kwargs = dict(kwargs)
            kwargs.update(self._hs_args)

            if "config" not in kwargs:
                config = self.default_config()
            else:
                # Copy, so that setting `server_name` below does not mutate the
                # dict the caller passed in.
                config = dict(kwargs["config"])

            # The server name can be specified using either the `name` argument or a config
            # override. The `name` argument takes precedence over any config overrides.
            if name is not None:
                config["server_name"] = name

            # Parse the config from a config dict into a HomeServerConfig, and pass the
            # parsed object on to the homeserver constructor.
            config_obj = make_homeserver_config_obj(config)
            kwargs["config"] = config_obj

            # The server name in the config is now `name`, if provided, or the `server_name`
            # from a config override, or the default of "test". Whichever it is, we
            # construct a homeserver with a matching name.
            kwargs["name"] = config_obj.server.server_name

            hs = setup_test_homeserver(self.addCleanup, **kwargs)
            store = hs.get_datastores().main

            async def run_bg_updates() -> None:
                with LoggingContext("run_bg_updates"):
                    self.get_success(store.db_pool.updates.run_background_updates(False))

            # Run the database background updates, when running against "master".
            if hs.__class__.__name__ == "TestHomeServer":
                self.get_success(run_bg_updates())

            return hs
    
        def pump(self, by: float = 0.0) -> None:
            """
            Pump the fake reactor 100 times, advancing it by `by` seconds on each
            step, so that pending Deferreds get a chance to fire.
            """
            steps = [by] * 100
            self.reactor.pump(steps)
    
        def get_success(self, d: Awaitable[TV], by: float = 0.0) -> TV:
            """Drive `d` to completion and return its result.

            The awaitable is wrapped in a Deferred and the reactor is pumped (see
            `pump`, stepping by `by` seconds) before the result is read. This gives
            work scheduled on the fake reactor a chance to run. The test fails if `d`
            has not succeeded by then.
            """
            deferred: Deferred[TV] = ensureDeferred(d)  # type: ignore[arg-type]
            # Without pumping, anything waiting on the fake clock would never complete
            # and the `by` argument would be silently ignored. This matches
            # get_success_or_raise.
            self.pump(by=by)
            return self.successResultOf(deferred)
    
            self, d: Awaitable[Any], exc: Type[_ExcType], by: float = 0.0
    
        ) -> _TypedFailure[_ExcType]:
    
            """
            Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
            """
    
            deferred: Deferred[Any] = ensureDeferred(d)  # type: ignore[arg-type]
    
            return self.failureResultOf(deferred, exc)
    
        def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV:
            """Drive `d` to completion, then return its result or re-raise its error.

            The reactor is pumped (stepping by `by` seconds) first. If `d` produced a
            Failure, the wrapped exception is raised. The test itself only fails if
            `d` has produced no result at all.
            """
            deferred: Deferred[TV] = ensureDeferred(d)  # type: ignore[arg-type]

            # `list.append` returns None, so this records the outcome (success or
            # Failure) and leaves the Deferred holding None afterwards.
            outcomes: list = []
            deferred.addBoth(outcomes.append)

            self.pump(by=by)

            if not outcomes:
                self.fail(
                    f"Success result expected on {deferred!r}, found no result instead"
                )

            outcome = outcomes[0]
            if isinstance(outcome, Failure):
                outcome.raiseException()

            return outcome
    
    
        def register_user(
            self,
            username: str,
            password: str,
            admin: Optional[bool] = False,
            displayname: Optional[str] = None,
        ) -> str:
            """
            Register a user. Requires the Admin API be registered.

            Uses the shared-secret registration endpoint
            (`/_synapse/admin/v1/register`). As a side effect, it sets the homeserver's
            `registration_shared_secret` to "shared".

            Args:
                username: The user part of the new user.
                password: The password of the new user.
                admin: Whether the user should be created as an admin or not.
                displayname: The displayname of the new user.

            Returns:
                The MXID of the new user.
            """
            self.hs.config.registration.registration_shared_secret = "shared"

            # Create the user: first fetch a nonce from the registration endpoint.
            channel = self.make_request("GET", "/_synapse/admin/v1/register")
            self.assertEqual(channel.code, 200, msg=channel.result)
            nonce = channel.json_body["nonce"]

            # The MAC is keyed with the shared secret. It covers the nonce, username,
            # password and admin flag, each separated by a NUL byte.
            want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
            nonce_str = b"\x00".join([username.encode("utf8"), password.encode("utf8")])
            if admin:
                nonce_str += b"\x00admin"
            else:
                nonce_str += b"\x00notadmin"
            want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str)
            want_mac_digest = want_mac.hexdigest()

            body = {
                "nonce": nonce,
                "username": username,
                "displayname": displayname,
                "password": password,
                "admin": admin,
                "mac": want_mac_digest,
                "inhibit_login": True,
            }
            channel = self.make_request("POST", "/_synapse/admin/v1/register", body)
            self.assertEqual(channel.code, 200, channel.json_body)

            user_id = channel.json_body["user_id"]
            return user_id
    
    
        def register_appservice_user(
            self,
            username: str,
            appservice_token: str,
    
            """Register an appservice user as an application service.
            Requires the client-facing registration API be registered.
    
            Args:
                username: the user to be registered by an application service.
    
                    Should NOT be a full username, i.e. just "localpart" as opposed to "@localpart:hostname"
    
                appservice_token: the acccess token for that application service.
    
            Raises: if the request to '/register' does not return 200 OK.
    
    
            Returns:
                The MXID of the new user, the device ID of the new user's first device.
    
            """
            channel = self.make_request(
                "POST",
                "/_matrix/client/r0/register",
                {
                    "username": username,
                    "type": "m.login.application_service",
                },
                access_token=appservice_token,
            )
            self.assertEqual(channel.code, 200, channel.json_body)
    
            return channel.json_body["user_id"], channel.json_body.get("device_id")
    
            username: str,
            password: str,
            device_id: Optional[str] = None,
    
            additional_request_fields: Optional[Dict[str, str]] = None,
    
            custom_headers: Optional[Iterable[CustomHeaderType]] = None,
        ) -> str:
    
            Log in a user, and get an access token. Requires the Login API be registered.
    
    
            Args:
                username: The localpart to assign to the new user.
                password: The password to assign to the new user.
                device_id: An optional device ID to assign to the new device created during
                    login.
                additional_request_fields: A dictionary containing any additional /login
                    request fields and their values.
                custom_headers: Custom HTTP headers and values to add to the /login request.
    
            Returns:
                The newly registered user's Matrix ID.
    
            """
            body = {"type": "m.login.password", "user": username, "password": password}
            if device_id:
                body["device_id"] = device_id
    
            if additional_request_fields:
                body.update(additional_request_fields)
    
            self.assertEqual(channel.code, 200, channel.result)
    
            access_token = channel.json_body["access_token"]
    
            return access_token
    
        def create_and_send_event(
            self,
            room_id: str,
            user: UserID,
            soft_failed: bool = False,
            prev_event_ids: Optional[List[str]] = None,
        ) -> str:
            """
            Create and send an event.

            Args:
                room_id: The ID of the room to send the event into.
                user: The user to send the event as.
                soft_failed: Whether to create a soft failed event or not
                prev_event_ids: Explicitly set the prev events,
                    or if None just use the default

            Returns:
                The ID of the newly sent event.
            """
            event_creator = self.hs.get_event_creation_handler()
            requester = create_requester(user)

            event, unpersisted_context = self.get_success(
                event_creator.create_event(
                    requester,
                    {
                        "type": EventTypes.Message,
                        "room_id": room_id,
                        "sender": user.to_string(),
                        # Random body text; presumably so repeated calls never produce
                        # identical content.
                        "content": {"body": secrets.token_hex(), "msgtype": "m.text"},
                    },
                    prev_event_ids=prev_event_ids,
                )
            )
            context = self.get_success(unpersisted_context.persist(event))

            # The soft-failed flag is set on the event after its context has been
            # persisted, but before the event is handed to the handler below.
            if soft_failed:
                event.internal_metadata.soft_failed = True

            self.get_success(
                event_creator.handle_new_client_event(
                    requester, events_and_context=[(event, context)]
                )
            )

            return event.event_id
    
        def inject_room_member(self, room: str, user: str, membership: str) -> None:
            """
            Inject a membership event into a room.

            Deprecated: use event_injection.inject_member_event directly

            Args:
                room: Room ID to inject the event into.
                user: MXID of the user to inject the membership for.
                membership: The membership type.
            """
            self.get_success(
                event_injection.inject_member_event(self.hs, room, user, membership)
            )
    
    
    
    class FederatingHomeserverTestCase(HomeserverTestCase):
        """
        A federating homeserver, set up to validate incoming federation requests
        """

        # The server name that signed federation requests from this test case claim
        # to originate from.
        OTHER_SERVER_NAME = "other.example.com"
        # Signing key used to sign those requests on behalf of OTHER_SERVER_NAME.
        OTHER_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test")
    
        def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
            """Store OTHER_SERVER_NAME's verify key so signed requests from it validate.

            Args:
                reactor: The reactor the homeserver runs on.
                clock: The homeserver's clock; used to timestamp the stored key.
                hs: The homeserver under test.
            """
            super().prepare(reactor, clock, hs)

            # poke the other server's signing key into the key store, so that we don't
            # make requests for it
            verify_key = signedjson.key.get_verify_key(self.OTHER_SERVER_SIGNATURE_KEY)
            verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version)

            self.get_success(
                hs.get_datastores().main.store_server_keys_response(
                    # The server whose keys these are.
                    self.OTHER_SERVER_NAME,
                    from_server=self.OTHER_SERVER_NAME,
                    ts_added_ms=clock.time_msec(),
                    verify_keys={
                        # NOTE(review): the key is only valid for 10s of (fake) clock
                        # time; tests that advance the clock further may need to re-fetch.
                        verify_key_id: FetchKeyResult(
                            verify_key=verify_key, valid_until_ts=clock.time_msec() + 10000
                        ),
                    },
                    response_json={
                        "verify_keys": {
                            verify_key_id: {
                                "key": signedjson.key.encode_verify_key_base64(verify_key)
                            }
                        }
                    },
                )
            )
    
        def create_resource_dict(self) -> Dict[str, Resource]:
            """Extend the parent's resource map with the federation transport layer.

            Returns:
                The resources provided by the parent class, plus a
                `TransportLayerServer` mounted at `/_matrix/federation`.
            """
            resources = super().create_resource_dict()
            resources["/_matrix/federation"] = TransportLayerServer(self.hs)
            return resources
    
        def make_signed_federation_request(
            self,
            method: str,
            path: str,
            content: Optional[JsonDict] = None,
            await_result: bool = True,
            custom_headers: Optional[Iterable[CustomHeaderType]] = None,
            client_ip: str = "127.0.0.1",
        ) -> FakeChannel:
            """Make an inbound signed federation request to this server

            The request is signed as if it came from `OTHER_SERVER_NAME`, whose keys our
            HS already has.

            Args:
                method: The HTTP method of the request.
                path: The full request path (sent with `shorthand=False`).
                content: Optional JSON body. It is covered by the signature and sent as
                    the request body; an empty body is sent when None.
                await_result: Passed through to `make_request`.
                custom_headers: Extra headers to send. The signed `Authorization` header
                    is appended to a copy, so the caller's iterable is not modified.
                client_ip: The client IP address passed through to `make_request`.

            Returns:
                The channel returned by `make_request`.
            """
            # Copy into a fresh list so appending never mutates the caller's iterable.
            if custom_headers is None:
                custom_headers = []
            else:
                custom_headers = list(custom_headers)

            custom_headers.append(
                (
                    "Authorization",
                    _auth_header_for_request(
                        origin=self.OTHER_SERVER_NAME,
                        destination=self.hs.hostname,
                        signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
                        method=method,
                        path=path,
                        content=content,
                    ),
                )
            )

            return make_request(
                self.reactor,
                self.site,
                method=method,
                path=path,
                content=content if content is not None else "",
                shorthand=False,
                await_result=await_result,
                custom_headers=custom_headers,
                client_ip=client_ip,
            )
        def add_hashes_and_signatures_from_other_server(