Skip to content
Snippets Groups Projects
server.py 26.8 KiB
Newer Older
  • Learn to ignore specific revisions
  • # Copyright 2018-2021 The Matrix.org Foundation C.I.C.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    Amber Brown's avatar
    Amber Brown committed
    import json
    
    from collections import deque
    
    from io import SEEK_END, BytesIO
    
    from typing import (
        Callable,
        Dict,
        Iterable,
        MutableMapping,
        Optional,
        Tuple,
    
    Amber Brown's avatar
    Amber Brown committed
    import attr
    
    from typing_extensions import Deque
    
    from zope.interface import implementer
    
    Amber Brown's avatar
    Amber Brown committed
    
    
    from twisted.internet import address, threads, udp
    
    from twisted.internet._resolver import SimpleResolverComplexifier
    
    from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
    
    from twisted.internet.error import DNSLookupError
    
    from twisted.internet.interfaces import (
    
        IHostnameResolver,
        IProtocol,
        IPullProducer,
        IPushProducer,
    
        IReactorPluggableNameResolver,
    
    Amber Brown's avatar
    Amber Brown committed
    from twisted.python.failure import Failure
    
    from twisted.test.proto_helpers import (
        AccumulatingProtocol,
        MemoryReactor,
        MemoryReactorClock,
    )
    
    from twisted.web.http_headers import Headers
    
    from twisted.web.resource import IResource
    
    from twisted.web.server import Request, Site
    
    from synapse.config.database import DatabaseConnectionConfig
    
    from synapse.http.site import SynapseRequest
    
    from synapse.logging.context import ContextResourceUsage
    
    from synapse.server import HomeServer
    from synapse.storage import DataStore
    from synapse.storage.engines import PostgresEngine, create_engine
    
    from synapse.types import JsonDict
    
    from synapse.util import Clock
    
    Amber Brown's avatar
    Amber Brown committed
    
    
    from tests.utils import (
        LEAVE_DB,
        POSTGRES_BASE_DB,
        POSTGRES_HOST,
        POSTGRES_PASSWORD,
        POSTGRES_USER,
    
    logger = logging.getLogger(__name__)
    
    
    # the type of thing that can be passed into `make_request` in the headers list
    CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
    
    
    class TimedOutException(Exception):
        """
        Raised when a fake web request does not finish before its timeout expires.
        """
    
    
    
    class FakeChannel:
    
        """
        A fake Twisted Web Channel (the part that interfaces with the
        wire).
        """
    
    
        site: Union[Site, "FakeSite"]
        _reactor: MemoryReactor
        result: dict = attr.Factory(dict)
        _ip: str = "127.0.0.1"
    
        _producer: Optional[Union[IPullProducer, IPushProducer]] = None
    
        resource_usage: Optional[ContextResourceUsage] = None
    
            return json.loads(self.text_body)
    
        @property
        def text_body(self) -> str:
            """The body of the result, utf-8-decoded.

            Raises an exception if the request has not yet completed.
            """
            # `is_finished` is a plain method, so it has to be *called*: the bare
            # bound method is always truthy, which would turn this guard into a no-op.
            if not self.is_finished():
                raise Exception("Request not yet completed")
            return self.result["body"].decode("utf8")
    
        def is_finished(self) -> bool:
            """Report whether the response has been completely received.

            True once ``requestDone`` has set the ``done`` flag in ``self.result``.
            """
            done_flag = self.result.get("done", False)
            return done_flag
    
    
        @property
        def code(self):
            """The response status code as an int.

            Raises an exception if no response has been recorded yet.
            """
            result = self.result
            if result:
                return int(result["code"])
            raise Exception("No result yet.")
    
            if not self.result:
                raise Exception("No result yet.")
            h = Headers()
            for i in self.result["headers"]:
                h.addRawHeader(*i)
            return h
    
    
        def writeHeaders(self, version, code, reason, headers):
            """Record the response's version, status code, reason and headers."""
            self.result.update(
                version=version,
                code=code,
                reason=reason,
                headers=headers,
            )
    
        def write(self, content):
            """Append a chunk of response body bytes to ``self.result["body"]``."""
            assert isinstance(content, bytes), "Should be bytes! " + repr(content)
            body_so_far = self.result.get("body", b"")
            self.result["body"] = body_so_far + content
    
    
    Amber Brown's avatar
    Amber Brown committed
        def registerProducer(self, producer, streaming):
            """Attach a producer to this channel.

            For a pull producer (``streaming`` false), a loop is scheduled on the
            reactor that calls ``resumeProducing`` and re-arms itself every 0.1s
            for as long as a producer remains registered.
            """
            self._producer = producer
            self.producerStreaming = streaming

            def _pump():
                if not self._producer:
                    return
                self._producer.resumeProducing()
                self._reactor.callLater(0.1, _pump)

            if streaming:
                return
            self._reactor.callLater(0.0, _pump)
    
    Amber Brown's avatar
    Amber Brown committed
    
        def unregisterProducer(self):
            """Detach the current producer; harmless when none is registered."""
            self._producer = None
    
    
        def requestDone(self, _self):
            """Mark the response as complete.

            When the finished request is a ``SynapseRequest``, its logcontext's
            resource usage is also captured into ``self.resource_usage``.
            """
            self.result["done"] = True
            if not isinstance(_self, SynapseRequest):
                return
            self.resource_usage = _self.logcontext.get_resource_usage()
    
    Erik Johnston's avatar
    Erik Johnston committed
            # We give an address so that getClientIP returns a non null entry,
            # causing us to record the MAU
    
            return address.IPv4Address("TCP", self._ip, 3423)
    
            # this is called by Request.__init__ to configure Request.host.
            return address.IPv4Address("TCP", "127.0.0.1", 8888)
    
        def isSecure(self):
            """Report that this fake channel is never secure."""
            secure = False
            return secure
    
        def await_result(self, timeout_ms: int = 1000) -> None:
            """
            Wait until the request is finished.

            Pumps the fake reactor forward in 0.1s steps, nudging any registered
            producer on each iteration, until the response is complete.

            Args:
                timeout_ms: how much reactor time to allow before giving up.

            Raises:
                TimedOutException: if the request is still unfinished once
                    ``timeout_ms`` of reactor time has elapsed.
            """
            end_time = self._reactor.seconds() + timeout_ms / 1000.0

            while not self.is_finished():
                # If there's a producer, tell it to resume producing so we get content
                if self._producer:
                    self._producer.resumeProducing()

                # Without this deadline check a request that never finishes would
                # make this loop spin forever.
                if self._reactor.seconds() > end_time:
                    raise TimedOutException("Timed out waiting for request to finish.")

                self._reactor.advance(0.1)
    
    
        def extract_cookies(self, cookies: MutableMapping[str, str]) -> None:
            """Process the contents of any Set-Cookie headers in the response.

            For each header, the leading ``name=value`` pair (everything before
            the first ``;``) is stored into the given mapping; cookie attributes
            are ignored.
            """
            for raw_header in self.headers.getRawHeaders("Set-Cookie") or ():
                name_value = raw_header.split(";")[0]
                name, value = name_value.split("=", maxsplit=1)
                cookies[name] = value
    
    
    
    class FakeSite:
        """
        A fake Twisted Web Site, with mocks of the extra things that
        Synapse adds.
        """

        server_version_string = b"1"
        site_tag = "test"

        access_logger = logging.getLogger("synapse.access.http.fake")

        def __init__(self, resource: IResource, reactor: IReactorTime):
            """
            Args:
                resource: the resource to be used for rendering all requests
                reactor: the reactor associated with this site; kept as
                    ``self.reactor``.
            """
            self._resource = resource
            # Keep the reactor instead of silently discarding the argument.
            self.reactor = reactor

        def getResourceFor(self, request):
            """Return the single configured resource, regardless of the request."""
            return self._resource
    
    
    def make_request(
    
        site: Union[Site, FakeSite],
    
        method: Union[bytes, str],
        path: Union[bytes, str],
        content: Union[bytes, str, JsonDict] = b"",
        access_token: Optional[str] = None,
    
        request: Type[Request] = SynapseRequest,
    
        shorthand: bool = True,
        federation_auth_origin: Optional[bytes] = None,
        content_is_form: bool = False,
    
        await_result: bool = True,
    
        custom_headers: Optional[Iterable[CustomHeaderType]] = None,
    
        client_ip: str = "127.0.0.1",
    
        Make a web request using the given method, path and content, and render it
    
    
        Returns the fake Channel object which records the response to the request.
    
            site: The twisted Site to use to render the request
    
            method: The HTTP request method ("verb").
            path: The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces and such).
            content: The body of the request. JSON-encoded, if a str of bytes.
            access_token: The access token to add as authorization for the request.
            request: The request class to create.
    
            shorthand: Whether to try and be helpful and prefix the given URL
    
                with the usual REST API path, if it doesn't contain it.
            federation_auth_origin: if set to not-None, we will add a fake
    
                Authorization header pretenting to be the given server name.
    
            content_is_form: Whether the content is URL encoded form data. Adds the
                'Content-Type': 'application/x-www-form-urlencoded' header.
    
            await_result: whether to wait for the request to complete rendering. If true,
                 will pump the reactor until the the renderer tells the channel the request
                 is finished.
    
            custom_headers: (name, value) pairs to add as request headers
    
            client_ip: The IP to use as the requesting IP. Useful for testing
                ratelimiting.
    
    
        if not isinstance(method, bytes):
    
    Amber Brown's avatar
    Amber Brown committed
            method = method.encode("ascii")
    
    
        if not isinstance(path, bytes):
    
    Amber Brown's avatar
    Amber Brown committed
            path = path.encode("ascii")
    
        # Decorate it to be the full path, if we're using shorthand
    
        if (
            shorthand
            and not path.startswith(b"/_matrix")
            and not path.startswith(b"/_synapse")
        ):
    
            if path.startswith(b"/"):
                path = path[1:]
    
            path = b"/_matrix/client/r0/" + path
    
    
        if not path.startswith(b"/"):
            path = b"/" + path
    
    
        if isinstance(content, dict):
            content = json.dumps(content).encode("utf8")
    
    Amber Brown's avatar
    Amber Brown committed
            content = content.encode("utf8")
    
        channel = FakeChannel(site, reactor, ip=client_ip)
    
        req = request(channel, site)
    
        req.content = BytesIO(content)
    
        # Twisted expects to be at the end of the content when parsing the request.
    
        req.content.seek(0, SEEK_END)
    
    Erik Johnston's avatar
    Erik Johnston committed
    
        if access_token:
    
            req.requestHeaders.addRawHeader(
    
    Amber Brown's avatar
    Amber Brown committed
                b"Authorization", b"Bearer " + access_token.encode("ascii")
    
    Erik Johnston's avatar
    Erik Johnston committed
    
    
        if federation_auth_origin is not None:
            req.requestHeaders.addRawHeader(
    
                b"Authorization",
                b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,),
    
            if content_is_form:
                req.requestHeaders.addRawHeader(
                    b"Content-Type", b"application/x-www-form-urlencoded"
                )
            else:
                # Assume the body is JSON
                req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
    
        if custom_headers:
            for k, v in custom_headers:
                req.requestHeaders.addRawHeader(k, v)
    
    
        req.parseCookies()
    
        req.requestReceived(method, path, b"1.1")
    
    
        if await_result:
            channel.await_result()
    
    
    @implementer(IReactorPluggableNameResolver)
    
    class ThreadedMemoryReactorClock(MemoryReactorClock):
        """
        A MemoryReactorClock that supports callFromThread.
        """
    
    black's avatar
    black committed
    
    
            self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {}
    
            self.lookups: Dict[str, str] = {}
            self._thread_callbacks: Deque[Callable[[], None]] = deque()
    
            lookups = self.lookups
    
    
            @implementer(IResolverSimple)
    
            class FakeResolver:
    
                def getHostByName(self, name, timeout=None):
                    if name not in lookups:
    
                        return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
    
                    return succeed(lookups[name])
    
            self.nameResolver = SimpleResolverComplexifier(FakeResolver())
    
        def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
            """Unsupported: swapping out the name resolver is not implemented here."""
            raise NotImplementedError()
    
    
    Amber Brown's avatar
    Amber Brown committed
        def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
            """Create a ``udp.Port`` on this reactor, start it listening, and
            remember it in ``self._udp`` before returning it."""
            udp_port = udp.Port(port, protocol, interface, maxPacketSize, self)
            udp_port.startListening()
            self._udp.append(udp_port)
            return udp_port
    
    
        def callFromThread(self, callback, *args, **kwargs):
            """
            Make the callback fire in the next reactor iteration.

            The call is queued on ``_thread_callbacks`` (drained by ``advance``)
            because calling ``callLater()`` from here is not safe.
            """

            def _queued_call():
                return callback(*args, **kwargs)

            self._thread_callbacks.append(_queued_call)
    
        def add_tcp_client_callback(self, host: str, port: int, callback: Callable):
            """Register ``callback`` to run when ``connectTCP`` targets ``(host, port)``.

            The callback runs synchronously inside ``connectTCP``, before the
            connection is handed back to the client, so it cannot block waiting
            for writes. Registering again for the same address replaces the
            earlier callback.
            """
            address_key = (host, port)
            self._tcp_callbacks[address_key] = callback
    
    
        def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None):
            """Fake L{IReactorTCP.connectTCP}.

            Records the attempt via the base ``MemoryReactor.connectTCP``, then
            synchronously runs any callback registered for ``(host, port)`` with
            ``add_tcp_client_callback``.

            Returns:
                The connector returned by the base implementation.
            """
            # Forward the caller's bindAddress rather than always recording None.
            conn = super().connectTCP(
                host, port, factory, timeout=timeout, bindAddress=bindAddress
            )

            callback = self._tcp_callbacks.get((host, port))
            if callback:
                callback()

            return conn
    
    
        def advance(self, amount):
            """Advance the clock, then drain the queued ``callFromThread`` callbacks.

            After each thread callback, ``advance(0)`` is run so that any
            ``callLater`` work it scheduled happens straight away. A regular
            reactor wouldn't do this, but it lets database queries complete in a
            single ``advance`` call: the db connection pool is backed by a
            threadless ThreadPool, yet results are still fed back to the main
            thread via ``reactor.callFromThread``.
            """
            # Move time forward and fire whatever "callLater" callbacks are now due.
            super().advance(amount)

            pending = self._thread_callbacks
            while pending:
                next_callback = pending.popleft()
                next_callback()
                super().advance(0)
    
    
    
    class ThreadPool:
        """
        Threadless stand-in for a thread pool: work is scheduled on the reactor
        with ``callLater(0, ...)`` instead of being handed to worker threads.
        """

        def __init__(self, reactor):
            self._reactor = reactor

        def start(self):
            """No-op: there are no worker threads to start."""

        def stop(self):
            """No-op: there are no worker threads to stop."""

        def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
            """Run ``function(*args, **kwargs)`` on the next reactor tick.

            ``onResult`` is then called as ``onResult(True, result)`` on success,
            or ``onResult(False, failure)`` when the outcome is a ``Failure``.

            Returns:
                The Deferred driving the call.
            """

            def _report(outcome):
                onResult(not isinstance(outcome, Failure), outcome)

            d = Deferred()
            d.addCallback(lambda _ignored: function(*args, **kwargs))
            d.addBoth(_report)
            self._reactor.callLater(0, d.callback, True)
            return d
    
    
    def _make_test_homeserver_synchronous(server: HomeServer) -> None:
        """
        Make the given test homeserver's database interactions synchronous.

        For every database pool, ``runWithConnection`` and ``runInteraction`` are
        replaced with wrappers that dispatch through ``threads.deferToThreadPool``,
        and the pool's threadpool is swapped for the threadless ``ThreadPool``
        running on the homeserver clock's reactor. Dedicated DB threads for event
        fetching are then disabled on the main datastore.
        """

        def _patch_pool_methods(pool) -> None:
            # Building the wrappers in a per-pool helper makes each closure capture
            # its own `pool`. Defined directly in the loop below, they would
            # late-bind the loop variable, and every database would end up routing
            # its work to the *last* pool.
            def runWithConnection(func, *args, **kwargs):
                return threads.deferToThreadPool(
                    pool._reactor,
                    pool.threadpool,
                    pool._runWithConnection,
                    func,
                    *args,
                    **kwargs,
                )

            def runInteraction(interaction, *args, **kwargs):
                return threads.deferToThreadPool(
                    pool._reactor,
                    pool.threadpool,
                    pool._runInteraction,
                    interaction,
                    *args,
                    **kwargs,
                )

            pool.runWithConnection = runWithConnection
            pool.runInteraction = runInteraction

        clock = server.get_clock()

        for database in server.get_datastores().databases:
            pool = database._db_pool
            _patch_pool_methods(pool)

            # Replace the thread pool with a threadless 'thread' pool. The wrappers
            # read `pool.threadpool` at call time, so they pick this up.
            pool.threadpool = ThreadPool(clock._reactor)

        # We've just changed the Databases to run DB transactions on the same
        # thread, so we need to disable the dedicated thread behaviour.
        server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False
    
    
    def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
        """Create a fresh test reactor and a Synapse ``Clock`` wrapping it.

        Returns:
            A ``(reactor, clock)`` pair, as promised by the annotation.
        """
        clock = ThreadedMemoryReactorClock()
        hs_clock = Clock(clock)
        return clock, hs_clock
    
    @attr.s(cmp=False)
    
    class FakeTransport:
    
        """
        A twisted.internet.interfaces.ITransport implementation which sends all its data
        straight into an IProtocol object: it exists to connect two IProtocols together.
    
        To use it, instantiate it with the receiving IProtocol, and then pass it to the
        sending IProtocol's makeConnection method:
    
            server = HTTPChannel()
            client.makeConnection(FakeTransport(server, self.reactor))
    
        If you want bidirectional communication, you'll need two instances.
        """
    
        other = attr.ib()
        """The Protocol object which will receive any data written to this transport.
    
        :type: twisted.internet.interfaces.IProtocol
        """
    
        _reactor = attr.ib()
        """Test reactor
    
        :type: twisted.internet.interfaces.IReactorTime
        """
    
    
        _protocol = attr.ib(default=None)
        """The Protocol which is producing data for this transport. Optional, but if set
        will get called back for connectionLost() notifications etc.
        """
    
    
        _peer_address: Optional[IAddress] = attr.ib(default=None)
        """The value to be returend by getPeer"""
    
    
    Amber Brown's avatar
    Amber Brown committed
        buffer = attr.ib(default=b"")
    
        producer = attr.ib(default=None)
    
        autoflush = attr.ib(default=True)
    
            return self._peer_address
    
        def loseConnection(self, reason=None):
            """Start closing the connection; repeated calls are ignored.

            The producing protocol, if set, is told via ``connectionLost``. If
            buffered data is still waiting, the disconnect is completed later by
            ``flush`` once the buffer drains; otherwise the transport is marked
            disconnected immediately.
            """
            if not self.disconnecting:
                logger.info("FakeTransport: loseConnection(%s)", reason)

                self.disconnecting = True
                if self._protocol:
                    self._protocol.connectionLost(reason)

                # if we still have data to write, delay until that is done
                if self.buffer:
                    logger.info(
                        "FakeTransport: Delaying disconnect until buffer is flushed"
                    )
                else:
                    self.connected = False
                    # flush() only completes a disconnect when it drains a non-empty
                    # buffer (it returns early on an empty one), so do it here.
                    self.disconnected = True
    
            logger.info("FakeTransport: abortConnection()")
    
    
            if not self.disconnecting:
                self.disconnecting = True
                if self._protocol:
                    self._protocol.connectionLost(None)
    
            self.disconnected = True
    
            if not self.producer:
                return
    
    
        def resumeProducing(self):
            """Forward ``resumeProducing`` to the registered producer, if any."""
            producer = self.producer
            if producer:
                producer.resumeProducing()
    
    
        def unregisterProducer(self):
            """Forget the registered producer; a no-op when there is none."""
            if self.producer:
                self.producer = None
    
        def registerProducer(self, producer, streaming):
            """Register ``producer`` as the data source for this transport.

            For pull producers (``streaming`` false) a polling loop is scheduled:
            each round calls ``resumeProducing`` and, once that settles, schedules
            the next round 0.1s later. The loop stops once the producer is
            unregistered.
            """
            self.producer = producer
            self.producerStreaming = streaming

            def _poll():
                if not self.producer:
                    # we've been unregistered
                    return
                # maybeDeferred copes with IProducer implementations (for example,
                # FileSender) whose resumeProducing doesn't return a deferred.
                outcome = maybeDeferred(self.producer.resumeProducing)
                outcome.addCallback(lambda _: self._reactor.callLater(0.1, _poll))

            if not streaming:
                self._reactor.callLater(0.0, _poll)
    
        def write(self, byt):
            """Queue ``byt`` for delivery to ``self.other``.

            With ``autoflush`` set, a flush is scheduled on the reactor.

            Raises:
                Exception: if the transport is already disconnecting.
            """
            if self.disconnecting:
                raise Exception("Writing to disconnecting FakeTransport")

            # Buffer the data so flush() has something to deliver; otherwise the
            # written bytes would be silently dropped.
            self.buffer = self.buffer + byt

            # always actually do the write asynchronously. Some protocols (notably the
            # TLSMemoryBIOProtocol) get very confused if a read comes back while they are
            # still doing a write. Doing a callLater here breaks the cycle.
            if self.autoflush:
                self._reactor.callLater(0.0, self.flush)
    
    
        def writeSequence(self, seq):
            """Write each chunk of ``seq`` in order via ``write``."""
            for chunk in seq:
                self.write(chunk)
    
    
        def flush(self, maxbytes=None):
            """Deliver buffered bytes to ``self.other.dataReceived``.

            Args:
                maxbytes: if given, deliver at most this many bytes in this call.

            Exceptions raised by the receiving protocol are logged rather than
            propagated. Leftover data is re-scheduled when ``autoflush`` is set,
            and a pending disconnect is completed once the buffer is empty.
            """
            # Nothing to do for an empty buffer (writing empty data upsets the
            # TLSMemoryBIOProtocol), nor once we're fully disconnected.
            if not self.buffer or self.disconnected:
                return

            chunk = self.buffer if maxbytes is None else self.buffer[:maxbytes]

            logger.info("%s->%s: %s", self._protocol, self.other, chunk)

            try:
                self.other.dataReceived(chunk)
            except Exception as e:
                logger.exception("Exception writing to protocol: %s", e)

            self.buffer = self.buffer[len(chunk) :]

            if self.buffer and self.autoflush:
                self._reactor.callLater(0.0, self.flush)

            if not self.buffer and self.disconnecting:
                logger.info("FakeTransport: Buffer now empty, completing disconnect")
                self.disconnected = True
    
    
    def connect_client(
        reactor: ThreadedMemoryReactorClock, client_id: int
    ) -> Tuple[IProtocol, AccumulatingProtocol]:
        """
        Connect a client to a fake TCP transport.

        Removes the recorded connection attempt at index ``client_id`` from
        ``reactor.tcpClients`` and builds a client protocol from its factory.
        The client is then wired to an ``AccumulatingProtocol`` standing in for
        the server, through a pair of ``FakeTransport`` objects.

        Args:
            reactor: the reactor whose ``tcpClients`` holds the connection attempt.
            client_id: index of the connection attempt in ``reactor.tcpClients``.

        Returns:
            The ``(client, server)`` protocol pair.
        """
        attempt = reactor.tcpClients.pop(client_id)
        factory = attempt[2]

        client_protocol = factory.buildProtocol(None)
        server_protocol = AccumulatingProtocol()
        server_protocol.makeConnection(FakeTransport(client_protocol, reactor))
        client_protocol.makeConnection(FakeTransport(server_protocol, reactor))

        return client_protocol, server_protocol
    
    
    
    class TestHomeServer(HomeServer):
        """A HomeServer for tests whose datastore class is the plain ``DataStore``."""

        DATASTORE_CLASS = DataStore
    
    
    def setup_test_homeserver(
        cleanup_func,
        name="test",
        config=None,
        reactor=None,
        homeserver_to_use: Type[HomeServer] = TestHomeServer,
        **kwargs,
    ):
        """
        Setup a homeserver suitable for running tests against.  Keyword arguments
        are passed to the Homeserver constructor.
    
        If no datastore is supplied, one is created and given to the homeserver.
    
        Args:
            cleanup_func : The function used to register a cleanup routine for
                           after the test.
    
        Calling this method directly is deprecated: you should instead derive from
        HomeserverTestCase.
        """
        if reactor is None:
            from twisted.internet import reactor
    
        if config is None:
            config = default_config(name, parse=True)
    
        config.ldap_enabled = False
    
        if "clock" not in kwargs:
            kwargs["clock"] = MockClock()
    
        if USE_POSTGRES_FOR_TESTS:
            test_db = "synapse_test_%s" % uuid.uuid4().hex
    
            database_config = {
                "name": "psycopg2",
                "args": {
                    "database": test_db,
                    "host": POSTGRES_HOST,
                    "password": POSTGRES_PASSWORD,
                    "user": POSTGRES_USER,
                    "cp_min": 1,
                    "cp_max": 5,
                },
            }
        else:
    
            if SQLITE_PERSIST_DB:
                # The current working directory is in _trial_temp, so this gets created within that directory.
                test_db_location = os.path.abspath("test.db")
                logger.debug("Will persist db to %s", test_db_location)
                # Ensure each test gets a clean database.
                try:
                    os.remove(test_db_location)
                except FileNotFoundError:
                    pass
                else:
                    logger.debug("Removed existing DB at %s", test_db_location)
            else:
                test_db_location = ":memory:"
    
    
                "args": {"database": test_db_location, "cp_min": 1, "cp_max": 1},
    
            }
    
        if "db_txn_limit" in kwargs:
            database_config["txn_limit"] = kwargs["db_txn_limit"]
    
        database = DatabaseConnectionConfig("master", database_config)
        config.database.databases = [database]
    
        db_engine = create_engine(database.config)
    
        # Create the database before we actually try and connect to it, based off
        # the template database we generate in setupdb()
        if isinstance(db_engine, PostgresEngine):
            db_conn = db_engine.module.connect(
                database=POSTGRES_BASE_DB,
                user=POSTGRES_USER,
                host=POSTGRES_HOST,
                password=POSTGRES_PASSWORD,
            )
            db_conn.autocommit = True
            cur = db_conn.cursor()
            cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
            cur.execute(
                "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
            )
            cur.close()
            db_conn.close()
    
        hs = homeserver_to_use(
            name,
            config=config,
            version_string="Synapse/tests",
            reactor=reactor,
        )
    
        # Install @cache_in_self attributes
        for key, val in kwargs.items():
            setattr(hs, "_" + key, val)
    
        # Mock TLS
        hs.tls_server_context_factory = Mock()
        hs.tls_client_options_factory = Mock()
    
        hs.setup()
        if homeserver_to_use == TestHomeServer:
            hs.setup_background_tasks()
    
        if isinstance(db_engine, PostgresEngine):
            database = hs.get_datastores().databases[0]
    
            # We need to do cleanup on PostgreSQL
            def cleanup():
                import psycopg2
    
                # Close all the db pools
                database._db_pool.close()
    
                dropped = False
    
                # Drop the test database
                db_conn = db_engine.module.connect(
                    database=POSTGRES_BASE_DB,
                    user=POSTGRES_USER,
                    host=POSTGRES_HOST,
                    password=POSTGRES_PASSWORD,
                )
                db_conn.autocommit = True
                cur = db_conn.cursor()
    
                # Try a few times to drop the DB. Some things may hold on to the
                # database for a few more seconds due to flakiness, preventing
                # us from dropping it when the test is over. If we can't drop
                # it, warn and move on.
                for _ in range(5):
                    try:
                        cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
                        db_conn.commit()
                        dropped = True
                    except psycopg2.OperationalError as e:
                        warnings.warn(
                            "Couldn't drop old db: " + str(e), category=UserWarning
                        )
                        time.sleep(0.5)
    
                cur.close()
                db_conn.close()
    
                if not dropped:
                    warnings.warn("Failed to drop old DB.", category=UserWarning)
    
            if not LEAVE_DB:
                # Register the cleanup hook
                cleanup_func(cleanup)
    
        # bcrypt is far too slow to be doing in unit tests
        # Need to let the HS build an auth handler and then mess with it
        # because AuthHandler's constructor requires the HS, so we can't make one
        # beforehand and pass it in to the HS's constructor (chicken / egg)
        async def hash(p):
            return hashlib.md5(p.encode("utf8")).hexdigest()
    
        hs.get_auth_handler().hash = hash
    
        async def validate_hash(p, h):
            return hashlib.md5(p.encode("utf8")).hexdigest() == h
    
        hs.get_auth_handler().validate_hash = validate_hash
    
        # Make the threadpool and database transactions synchronous for testing.
        _make_test_homeserver_synchronous(hs)
    
        return hs