Skip to content
Snippets Groups Projects
server.py 5.14 KiB
Newer Older
  • Learn to ignore specific revisions
  • Amber Brown's avatar
    Amber Brown committed
    import json
    
    from io import BytesIO
    
    from six import text_type
    
    
    Amber Brown's avatar
    Amber Brown committed
    import attr
    
    from twisted.internet import threads
    
    from twisted.internet.defer import Deferred
    
    Amber Brown's avatar
    Amber Brown committed
    from twisted.python.failure import Failure
    
    from twisted.test.proto_helpers import MemoryReactorClock
    
    from synapse.http.site import SynapseRequest
    
    from synapse.util import Clock
    
    Amber Brown's avatar
    Amber Brown committed
    
    
    from tests.utils import setup_test_homeserver as _sth
    
    
    @attr.s
    class FakeChannel(object):
        """
        Stand-in for a Twisted Web channel that records what is written to
        it in the ``result`` dict rather than sending anything.

        ``writeHeaders`` stores the "version", "code", "reason" and "headers"
        keys, ``write`` accumulates bytes under "body", and ``requestDone``
        sets "done".
        """

        # Captured response data; a fresh empty dict for every instance.
        result = attr.ib(default=attr.Factory(dict))

        def _captured(self):
            """Return ``result``, raising if nothing has been recorded yet."""
            if self.result:
                return self.result
            raise Exception("No result yet.")

        @property
        def json_body(self):
            """The recorded body, decoded as UTF-8 and parsed as JSON."""
            raw_body = self._captured()["body"]
            return json.loads(raw_body.decode('utf8'))

        @property
        def code(self):
            """The recorded response code, converted to an int."""
            return int(self._captured()["code"])

        def writeHeaders(self, version, code, reason, headers):
            """Record the response status line parts and headers."""
            recorded = (
                ("version", version),
                ("code", code),
                ("reason", reason),
                ("headers", headers),
            )
            for key, value in recorded:
                self.result[key] = value

        def write(self, content):
            """Append ``content`` to the recorded body, starting from b""."""
            captured = self.result
            captured.setdefault("body", b"")
            captured["body"] += content

        def requestDone(self, _self):
            """Flag the response as complete; the argument is ignored."""
            self.result["done"] = True

        def getPeer(self):
            """No peer address is available on the fake channel."""
            return None

        def getHost(self):
            """No host address is available on the fake channel."""
            return None

        @property
        def transport(self):
            """The channel acts as its own transport."""
            return self
    
    
    class FakeSite:
        """
        Minimal replacement for a Twisted Web Site, carrying the extra
        attributes that Synapse adds.
        """

        site_tag = "test"
        server_version_string = b"1"

        @property
        def access_logger(self):
            """Return a new logger object whose ``info`` discards its input."""

            class FakeLogger:
                def info(self, *args, **kwargs):
                    # Accept anything and do nothing.
                    return None

            logger = FakeLogger()
            return logger
    
    
    def make_request(method, path, content=b""):
        """
        Build a SynapseRequest for ``method`` and ``path``, give it
        ``content`` as its body, and return ``(request, channel)``.

        Non-bytes ``method`` and ``path`` are encoded as ASCII; text
        ``content`` is encoded as UTF-8. A path that does not start with
        "/_matrix" is prefixed with "/_matrix/client/r0/", after which each
        "//" in it is replaced by "/".
        """
        method_bytes = method
        if not isinstance(method_bytes, bytes):
            method_bytes = method_bytes.encode('ascii')

        full_path = path
        if not isinstance(full_path, bytes):
            full_path = full_path.encode('ascii')

        if not full_path.startswith(b"/_matrix"):
            # Short paths are taken relative to the client r0 prefix.
            full_path = (b"/_matrix/client/r0/" + full_path).replace(b"//", b"/")

        body = content.encode('utf8') if isinstance(content, text_type) else content

        site = FakeSite()
        channel = FakeChannel()
        request = SynapseRequest(site, channel)

        # NOTE(review): process() is stubbed out, presumably so that
        # requestReceived() does not dispatch the request itself -- confirm.
        request.process = lambda: b""
        request.content = BytesIO(body)
        request.requestReceived(method_bytes, full_path, b"1.1")

        return request, channel
    
    
    def wait_until_result(clock, channel, timeout=100):
        """
        Run ``clock`` and keep advancing it in 0.1 steps until
        ``channel.result`` is truthy.

        Raises an Exception once more than ``timeout`` checks have found the
        result still empty.
        """
        clock.run()

        attempts = 0
        while True:
            if channel.result:
                return

            attempts += 1
            if attempts > timeout:
                raise Exception("Timed out waiting for request to finish.")

            clock.advance(0.1)
    
    
    
    def render(request, resource, clock):
        """
        Render ``resource`` for ``request``, then drive ``clock`` until the
        request's channel has a result.
        """
        request.render(resource)
        channel = request._channel
        wait_until_result(clock, channel)
    
    
    
    class ThreadedMemoryReactorClock(MemoryReactorClock):
        """
        A MemoryReactorClock that supports callFromThread.
        """

        def callFromThread(self, callback, *args, **kwargs):
            """
            Make the callback fire in the next reactor iteration.

            Args:
                callback: callable invoked as ``callback(*args, **kwargs)``.
                *args: positional arguments forwarded to ``callback``.
                **kwargs: keyword arguments forwarded to ``callback``.

            Returns:
                The Deferred whose callback chain invokes ``callback``; it is
                fired through ``self.callLater(0, ...)``.
            """
            d = Deferred()
            # The value the Deferred is fired with (True) is ignored; only
            # the stored arguments are passed on to the callback.
            d.addCallback(lambda x: callback(*args, **kwargs))
            self.callLater(0, d.callback, True)
            return d
    
    
    
    def setup_test_homeserver(cleanup_func, *args, **kwargs):
        """
        Set up a synchronous test server, driven by the reactor used by
        the homeserver.

        All arguments are passed straight through to
        ``tests.utils.setup_test_homeserver``. The ``result`` of the Deferred
        it returns is read immediately; if that result is a Failure, the
        wrapped exception is raised.

        Returns:
            The homeserver object, with its database pool and clock patched
            to use a threadless thread pool.
        """
        hs = _sth(cleanup_func, *args, **kwargs).result

        if isinstance(hs, Failure):
            hs.raiseException()

        # Make the thread pool synchronous.
        clock = hs.get_clock()
        pool = hs.get_db_pool()

        def runWithConnection(func, *args, **kwargs):
            # pool.threadpool is looked up at call time, so this uses the
            # replacement ThreadPool installed below.
            return threads.deferToThreadPool(
                pool._reactor,
                pool.threadpool,
                pool._runWithConnection,
                func,
                *args,
                **kwargs
            )

        def runInteraction(interaction, *args, **kwargs):
            return threads.deferToThreadPool(
                pool._reactor,
                pool.threadpool,
                pool._runInteraction,
                interaction,
                *args,
                **kwargs
            )

        pool.runWithConnection = runWithConnection
        pool.runInteraction = runInteraction

        class ThreadPool:
            """
            Threadless thread pool.
            """

            def stop(self):
                pass

            def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
                """
                Schedule ``function(*args, **kwargs)`` on the clock's reactor
                and report the outcome through ``onResult(success, result)``.
                """

                def _(res):
                    # A Failure means the function raised; report it as such.
                    if isinstance(res, Failure):
                        onResult(False, res)
                    else:
                        onResult(True, res)

                d = Deferred()
                d.addCallback(lambda x: function(*args, **kwargs))
                d.addBoth(_)
                clock._reactor.callLater(0, d.callback, True)
                return d

        clock.threadpool = ThreadPool()
        pool.threadpool = ThreadPool()
        return hs
    
    
    
    def get_clock():
        """
        Create a ThreadedMemoryReactorClock and a Clock wrapping it.

        Returns a ``(reactor_clock, hs_clock)`` tuple.
        """
        reactor_clock = ThreadedMemoryReactorClock()
        return (reactor_clock, Clock(reactor_clock))