Skip to content
Snippets Groups Projects
unittest.py 21.3 KiB
Newer Older
  • Learn to ignore specific revisions
  • Matthew Hodgson's avatar
    Matthew Hodgson committed
    # Copyright 2014-2016 OpenMarket Ltd
    
    # Copyright 2019 Matrix.org Federation C.I.C
    
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    import hashlib
    import hmac
    
    from canonicaljson import json
    
    
    from twisted.internet.defer import Deferred, succeed
    from twisted.python.threadpool import ThreadPool
    
    from synapse.api.constants import EventTypes, Membership
    from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
    
    from synapse.config.homeserver import HomeServerConfig
    
    from synapse.config.ratelimiting import FederationRateLimitConfig
    from synapse.federation.transport import server as federation_server
    
    from synapse.http.server import JsonResource
    
    from synapse.http.site import SynapseRequest
    
    from synapse.logging.context import LoggingContext
    
    from synapse.server import HomeServer
    
    from synapse.types import Requester, UserID, create_requester
    
    from synapse.util.ratelimitutils import FederationRateLimiter
    
    from tests.server import get_clock, make_request, render, setup_test_homeserver
    
    from tests.test_utils.logging_setup import setup_logging
    
    from tests.utils import default_config, setupdb
    
    setupdb()
    
    def around(target):
        """A CLOS-style 'around' modifier, which wraps the original method of the
        given instance with another piece of code.
    
        @around(self)
        def method_name(orig, *args, **kwargs):
            return orig(*args, **kwargs)
        """
    
    black's avatar
    black committed
    
    
        def _around(code):
            name = code.__name__
            orig = getattr(target, name)
    
            def new(*args, **kwargs):
                return code(orig, *args, **kwargs)
    
        """A subclass of twisted.trial's TestCase which looks for 'loglevel'
        attributes on both itself and its individual test methods, to override the
        root logger's logging level while that test (case|method) runs."""
    
    
        def __init__(self, methodName, *args, **kwargs):
            super(TestCase, self).__init__(methodName, *args, **kwargs)
    
            level = getattr(method, "loglevel", getattr(self, "loglevel", None))
    
                # if we're not starting in the sentinel logcontext, then to be honest
                # all future bets are off.
                if LoggingContext.current_context() is not LoggingContext.sentinel:
                    self.fail(
    
                        "Test starting with non-sentinel logging context %s"
                        % (LoggingContext.current_context(),)
    
                old_level = logging.getLogger().level
    
                if level is not None and old_level != level:
    
    black's avatar
    black committed
    
    
                    logging.getLogger().setLevel(level)
    
    
            @around(self)
            def tearDown(orig):
                ret = orig()
                # force a GC to workaround problems with deferreds leaking logcontexts when
                # they are GCed (see the logcontext docs)
                gc.collect()
                LoggingContext.set_current_context(LoggingContext.sentinel)
    
                return ret
    
    
        def assertObjectHasAttributes(self, attrs, obj):
            """Asserts that the given object has each of the attributes given, and
            that the value of each matches according to assertEqual.

            Args:
                attrs (dict): maps attribute name -> expected value.
                obj: the object whose attributes are checked.

            Raises:
                AssertionError (or the test case's failure exception type): if an
                    attribute is missing, or its value differs. The message names
                    the offending attribute.
            """
            for key, expected in attrs.items():
                if not hasattr(obj, key):
                    raise AssertionError("Expected obj to have a '.%s'" % key)
                try:
                    self.assertEqual(expected, getattr(obj, key))
                except AssertionError as e:
                    # Python 3 exceptions have no `.message` attribute; build the
                    # annotated message from str(e) and keep the original as cause.
                    raise (type(e))(str(e) + " for '.%s'" % key) from e
    
        def assert_dict(self, required, actual):
            """Does a partial assert of a dict.

            Args:
                required (dict): The keys and value which MUST be in 'actual'.
                actual (dict): The test result. Extra keys will not be checked.

            Raises:
                AssertionError (or the test case's failure exception type): if a
                    required key is missing from `actual`, or its value differs.
            """
            for key, expected in required.items():
                # Check presence first so that a missing key is reported as a test
                # failure, rather than escaping as a KeyError from actual[key].
                self.assertIn(key, actual, msg="%s missing. %s" % (key, actual))
                self.assertEqual(
                    expected, actual[key], msg="%s mismatch. %s" % (key, actual)
                )
    
        """A decorator to set the .loglevel attribute to logging.DEBUG.
        Can apply to either a TestCase or an individual test method."""
    
    def INFO(target):
        """Decorator that gives `target` a ``loglevel`` of ``logging.INFO``.

        `target` may be a whole TestCase class or a single test method; it is
        returned unchanged apart from the new attribute.
        """
        setattr(target, "loglevel", logging.INFO)
        return target
    
    
    
    class HomeserverTestCase(TestCase):
        """
        A base TestCase that reduces boilerplate for HomeServer-using test cases.
    
    
        Defines a setUp method which creates a mock reactor, and instantiates a homeserver
        running on that reactor.
    
        There are various hooks for modifying the way that the homeserver is instantiated:
    
        * override make_homeserver, for example by making it pass different parameters into
          setup_test_homeserver.
    
        * override default_config, to return a modified configuration dictionary for use
          by setup_test_homeserver.
    
        * On a per-test basis, you can use the @override_config decorator to give a
          dictionary containing additional configuration settings to be added to the basic
          config dict.
    
    
        Attributes:
            servlets (list[function]): List of servlet registration function.
            user_id (str): The user ID to assume if auth is hijacked.
            hijack_auth (bool): Whether to hijack auth to return the user specified
            in user_id.
        """
    
        def __init__(self, methodName, *args, **kwargs):
            super().__init__(methodName, *args, **kwargs)

            # Extra homeserver config for this particular test lives on the test
            # method itself (attached by the override_config decorator); it is
            # merged into the config dict by default_config().
            self._extra_config = getattr(
                getattr(self, methodName), "_extra_config", None
            )
    
    
        def setUp(self):
            """
            Build and wire up the homeserver used by each test.

            In order: create a reactor and clock with get_clock(), build the
            homeserver with make_homeserver (it must return a HomeServer), build
            the JsonResource holding the test's servlets, and create a RestHelper.
            If the class defines `user_id` and `hijack_auth` is true, the auth
            lookups are replaced so every request is treated as coming from that
            user. If `needs_threadpool` is true, a real ThreadPool is attached to
            the reactor, started, and stopped on cleanup. Finally `prepare` is
            called.
            """
            self.reactor, self.clock = get_clock()
            self._hs_args = {"clock": self.clock, "reactor": self.reactor}

            self.hs = self.make_homeserver(self.reactor, self.clock)
            if self.hs is None:
                raise Exception("No homeserver returned from make_homeserver.")
            if not isinstance(self.hs, HomeServer):
                raise Exception("A homeserver wasn't returned, but %r" % (self.hs,))

            # Register the resources
            self.resource = self.create_test_json_resource()

            # Imported locally rather than at module scope -- presumably to avoid
            # an import cycle (NOTE(review): confirm).
            from tests.rest.client.v1.utils import RestHelper

            self.helper = RestHelper(
                self.hs, self.resource, getattr(self, "user_id", None)
            )

            if hasattr(self, "user_id") and self.hijack_auth:
                # Both stubs read self.helper.auth_user_id each time they are
                # called, and report token_id 1 / not a guest.
                def get_user_by_access_token(token=None, allow_guest=False):
                    fake_user = UserID.from_string(self.helper.auth_user_id)
                    return succeed(
                        {"user": fake_user, "token_id": 1, "is_guest": False}
                    )

                def get_user_by_req(request, allow_guest=False, rights="access"):
                    fake_user = UserID.from_string(self.helper.auth_user_id)
                    return succeed(create_requester(fake_user, 1, False, None))

                self.hs.get_auth().get_user_by_req = get_user_by_req
                self.hs.get_auth().get_user_by_access_token = get_user_by_access_token
                self.hs.get_auth().get_access_token_from_request = Mock(
                    return_value="1234"
                )

            if self.needs_threadpool:
                pool = ThreadPool()
                self.reactor.threadpool = pool
                self.addCleanup(pool.stop)
                pool.start()

            if hasattr(self, "prepare"):
                self.prepare(self.reactor, self.clock, self.hs)
    
    
        def wait_on_thread(self, deferred, timeout=10):
            """
            Block until `deferred` has fired, for Deferreds that are waiting on
            work in a real thread.

            Each round advances self.reactor by 0.01 and sleeps 0.01s of wall-clock
            time, to give the thread a chance to make progress.

            Raises:
                ValueError: if more than `timeout` seconds of wall-clock time pass
                    before the Deferred fires.
            """
            deadline = time.time() + timeout
            while not deferred.called:
                if time.time() > deadline:
                    raise ValueError("Timed out waiting for threadpool")
                self.reactor.advance(0.01)
                time.sleep(0.01)
    
    
        def make_homeserver(self, reactor, clock):
            """
            Build the HomeServer under test. Subclasses override this to change
            how it is constructed.

            Args:
                reactor: A Twisted Reactor, or something that pretends to be one.
                clock (synapse.util.Clock): The Clock, associated with the reactor.

            Returns:
                synapse.server.HomeServer: by default, whatever
                    self.setup_test_homeserver() builds.
            """
            return self.setup_test_homeserver()
    
        def create_test_json_resource(self):
            """
            Create a test JsonResource with the relevant servlets registered to it.

            By default each registration function in `self.servlets` is called as
            `register(self.hs, resource)`.

            Returns:
                JsonResource:
            """
            json_resource = JsonResource(self.hs)
            for register_servlet in self.servlets:
                register_servlet(self.hs, json_resource)
            return json_resource
    
    
        def default_config(self, name="test"):
            """
            Get the base homeserver config dict, with any per-test overrides applied.

            Args:
                name (str): The homeserver name/domain.

            Returns:
                dict: the dict from tests.utils.default_config, updated in place
                    with the settings given via the @override_config decorator
                    (if any).
            """
            config = default_config(name)

            overrides = self._extra_config
            if overrides is not None:
                config.update(overrides)

            return config
    
        def prepare(self, reactor, clock, homeserver):
            """
            Hook called at the end of setUp, once the homeserver has been built.

            Override it to mock out parts of the homeserver, or to build test data
            shared by every test in the class. The default implementation does
            nothing.

            Args:
                reactor: A Twisted Reactor, or something that pretends to be one.
                clock (synapse.util.Clock): The Clock, associated with the reactor.
                homeserver (synapse.server.HomeServer): The HomeServer to test
                    against.
            """
    
    
            self,
            method,
            path,
            content=b"",
            access_token=None,
            request=SynapseRequest,
            shorthand=True,
    
            federation_auth_origin=None,
    
            """
            Create a SynapseRequest at the path using the method and containing the
            given content.
    
            Args:
                method (bytes/unicode): The HTTP request method ("verb").
                path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
                escaped UTF-8 & spaces and such).
    
                content (bytes or dict): The body of the request. JSON-encoded, if
                a dict.
    
                shorthand: Whether to try and be helpful and prefix the given URL
                with the usual REST API path, if it doesn't contain it.
    
                federation_auth_origin (bytes|None): if set to not-None, we will add a fake
                    Authorization header pretenting to be the given server name.
    
                Tuple[synapse.http.site.SynapseRequest, channel]
    
            if isinstance(content, dict):
    
    Amber Brown's avatar
    Amber Brown committed
                content = json.dumps(content).encode("utf8")
    
            return make_request(
    
                self.reactor,
                method,
                path,
                content,
                access_token,
                request,
                shorthand,
    
    
        def render(self, request):
            """
            Render `request` with tests.server.render, against this test case's
            JsonResource (which holds the servlets registered by the test class)
            and using the test reactor.

            Args:
                request (synapse.http.site.SynapseRequest): The request to render.
            """
            resource, reactor = self.resource, self.reactor
            render(request, resource, reactor)
    
        def setup_test_homeserver(self, *args, **kwargs):
            """
            Set up the test homeserver, meant to be called by the overridable
            make_homeserver. It automatically passes through the test class's
            clock & reactor.

            The config comes from the `config` kwarg if given, otherwise from
            self.default_config() (which includes any @override_config settings).
            Either way the dict is parsed into a HomeServerConfig, and that parsed
            object is what gets passed on as `config`.

            Args:
                See tests.utils.setup_test_homeserver.

            Returns:
                synapse.server.HomeServer
            """
            kwargs = dict(kwargs)
            kwargs.update(self._hs_args)

            if "config" not in kwargs:
                config = self.default_config()
            else:
                config = kwargs["config"]

            # Parse the config from a config dict into a HomeServerConfig, and hand
            # the parsed object on: otherwise the config built above (including
            # per-test overrides) would never reach the homeserver.
            config_obj = HomeServerConfig()
            config_obj.parse_config_dict(config, "", "")
            kwargs["config"] = config_obj

            hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
            stor = hs.get_datastore()

            # Run the database background updates.
            if hasattr(stor, "do_next_background_update"):
                while not self.get_success(stor.has_completed_background_updates()):
                    self.get_success(stor.do_next_background_update(1))

            return hs
    
        def pump(self, by=0.0):
            """
            Pump the reactor enough that Deferreds will fire.

            Args:
                by (float): the value repeated 100 times in the list passed to
                    self.reactor.pump.
            """
            self.reactor.pump([by] * 100)

        def get_success(self, d):
            """
            Run a Deferred and get the result from it, pumping the reactor first so
            that it has a chance to fire.

            Values that are not Deferreds are returned unchanged.
            """
            if not isinstance(d, Deferred):
                return d
            self.pump()
            return self.successResultOf(d)
    
        def get_failure(self, d, exc):
            """
            Run a Deferred and get a Failure from it. The failure must be of the type `exc`.

            Values that are not Deferreds are handed straight back. Otherwise the
            reactor is pumped first so the Deferred has a chance to fire.
            """
            if isinstance(d, Deferred):
                self.pump()
                return self.failureResultOf(d, exc)
            return d
    
    
        def register_user(self, username, password, admin=False):
            """
            Register a user via the shared-secret admin registration endpoint.
            Requires the Admin API be registered.

            Args:
                username (str|bytes): The user part of the new user. Bytes are
                    decoded as UTF-8.
                password (str|bytes): The password of the new user. Bytes are
                    decoded as UTF-8.
                admin (bool): Whether the user should be created as an admin
                    or not.

            Returns:
                The MXID of the new user (unicode).
            """
            # Both values are UTF-8 encoded for the MAC and JSON-serialised in the
            # request body, so normalise bytes to str up front (bytes would fail
            # on .encode() and in json.dumps).
            if isinstance(username, bytes):
                username = username.decode("utf8")
            if isinstance(password, bytes):
                password = password.decode("utf8")

            shared_secret = "shared"
            self.hs.config.registration_shared_secret = shared_secret

            # Fetch a nonce from the registration endpoint.
            request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register")
            self.render(request)
            self.assertEqual(channel.code, 200)

            nonce = channel.json_body["nonce"]

            # The MAC covers the nonce, username, password and admin flag, joined
            # with NUL bytes, keyed with the shared secret.
            mac_fields = [
                nonce.encode("ascii"),
                username.encode("utf8"),
                password.encode("utf8"),
                b"admin" if admin else b"notadmin",
            ]
            want_mac = hmac.new(
                key=shared_secret.encode("utf8"),
                msg=b"\x00".join(mac_fields),
                digestmod=hashlib.sha1,
            ).hexdigest()

            body = json.dumps(
                {
                    "nonce": nonce,
                    "username": username,
                    "password": password,
                    "admin": admin,
                    "mac": want_mac,
                }
            )
            request, channel = self.make_request(
                "POST", "/_matrix/client/r0/admin/register", body.encode("utf8")
            )
            self.render(request)
            self.assertEqual(channel.code, 200, channel.json_body)

            return channel.json_body["user_id"]
    
        def login(self, username, password, device_id=None):
            """
            Log in a user with a password, and get an access token. Requires the
            Login API be registered.

            Args:
                username: sent as the "user" field of the login request.
                password: the user's password.
                device_id: if truthy, sent as the "device_id" of the login.

            Returns:
                The "access_token" from the login response.
            """
            login_body = {
                "type": "m.login.password",
                "user": username,
                "password": password,
            }
            if device_id:
                login_body["device_id"] = device_id

            encoded = json.dumps(login_body).encode("utf8")
            request, channel = self.make_request(
                "POST", "/_matrix/client/r0/login", encoded
            )
            self.render(request)
            self.assertEqual(channel.code, 200, channel.result)

            return channel.json_body["access_token"]
    
        def create_and_send_event(
            self, room_id, user, soft_failed=False, prev_event_ids=None
        ):
            """
            Create an m.text message event from `user` in `room_id` (its body is
            produced by the homeserver's secrets.token_hex()), and send it.

            Args:
                room_id (str): the room to send the event into.
                user: the sender; must support `to_string()` (presumably a
                    synapse.types.UserID -- confirm at call sites).
                soft_failed (bool): Whether to create a soft failed event or not
                prev_event_ids (list[str]|None): Explicitly set the prev events,
                    or if None (or empty) just use the default

            Returns:
                str: The new event's ID.
            """
            creation_handler = self.hs.get_event_creation_handler()
            hs_secrets = self.hs.get_secrets()
            requester = Requester(user, None, False, None, None)

            prev_events_and_hashes = (
                [[prev_id, {}, 0] for prev_id in prev_event_ids]
                if prev_event_ids
                else None
            )

            event_dict = {
                "type": EventTypes.Message,
                "room_id": room_id,
                "sender": user.to_string(),
                "content": {"body": hs_secrets.token_hex(), "msgtype": "m.text"},
            }
            event, context = self.get_success(
                creation_handler.create_event(
                    requester, event_dict, prev_events_and_hashes=prev_events_and_hashes
                )
            )

            # Mark as soft-failed before sending, so the flag is set when the
            # event is handed to send_nonmember_event.
            if soft_failed:
                event.internal_metadata.soft_failed = True

            self.get_success(
                creation_handler.send_nonmember_event(requester, event, context)
            )

            return event.event_id
    
        def add_extremity(self, room_id, event_id):
            """
            Add the given event as an extremity to the room.

            Inserts a row into `event_forward_extremities`, then invalidates the
            datastore's get_latest_event_ids_in_room cache for the room.
            """
            insert = self.hs.get_datastore().db.simple_insert(
                table="event_forward_extremities",
                values={"room_id": room_id, "event_id": event_id},
                desc="test_add_extremity",
            )
            self.get_success(insert)

            latest_ids_cache = self.hs.get_datastore().get_latest_event_ids_in_room
            latest_ids_cache.invalidate((room_id,))
    
    
        def attempt_wrong_password_login(self, username, password):
            """Try a password login as `username` with `password`, and assert that
            the server rejects it with HTTP 403.
            """
            encoded = json.dumps(
                {"type": "m.login.password", "user": username, "password": password}
            ).encode("utf8")
            request, channel = self.make_request(
                "POST", "/_matrix/client/r0/login", encoded
            )
            self.render(request)
            self.assertEqual(channel.code, 403, channel.result)
    
        def inject_room_member(self, room: str, user: str, membership: Membership) -> None:
            """
            Inject a membership event into a room.

            The event has `user` as both sender and state_key, and is built for the
            room's stored room version. It is persisted directly via
            hs.get_storage().persistence rather than sent through a handler.

            Args:
                room: Room ID to inject the event into.
                user: MXID of the user to inject the membership for.
                membership: The membership type.
            """
            builder_factory = self.hs.get_event_builder_factory()
            creation_handler = self.hs.get_event_creation_handler()

            version_id = self.get_success(self.hs.get_datastore().get_room_version(room))
            member_event_dict = {
                "type": EventTypes.Member,
                "sender": user,
                "state_key": user,
                "room_id": room,
                "content": {"membership": membership},
            }
            builder = builder_factory.for_room_version(
                KNOWN_ROOM_VERSIONS[version_id], member_event_dict
            )

            event, context = self.get_success(
                creation_handler.create_new_client_event(builder)
            )
            persistence = self.hs.get_storage().persistence
            self.get_success(persistence.persist_event(event, context))
    
    
    class FederatingHomeserverTestCase(HomeserverTestCase):
        """
        A federating homeserver whose federation servlets treat every incoming
        request as authenticated from `other.example.com`.
        """

        def prepare(self, reactor, clock, homeserver):
            class _OtherServerAuthenticator(object):
                """Stub authenticator: every request comes from other.example.com."""

                def authenticate_request(self, request, content):
                    return succeed("other.example.com")

            limit_config = FederationRateLimitConfig(
                window_size=1,
                sleep_limit=1,
                sleep_msec=1,
                reject_limit=1000,
                concurrent_requests=1000,
            )
            limiter = FederationRateLimiter(clock, limit_config)

            federation_server.register_servlets(
                homeserver, self.resource, _OtherServerAuthenticator(), limiter
            )

            return super().prepare(reactor, clock, homeserver)
    
    
    
    def override_config(extra_config):
        """A decorator which can be applied to test functions to give additional HS config.

        The dict is stored on the decorated function as ``_extra_config``.
        HomeserverTestCase reads it in ``__init__`` and merges it into the dict
        returned by ``default_config``.

        For example:

            class MyTestCase(HomeserverTestCase):
                @override_config({"enable_registration": False, ...})
                def test_foo(self):
                    ...

        Args:
            extra_config(dict): Additional config settings to be merged into the default
                config dict before instantiating the test homeserver.
        """

        def _attach(test_func):
            setattr(test_func, "_extra_config", extra_config)
            return test_func

        return _attach