Skip to content
Snippets Groups Projects
test_relations.py 20.5 KiB
Newer Older
  • Learn to ignore specific revisions
  • # -*- coding: utf-8 -*-
    # Copyright 2019 New Vector Ltd
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    
    Erik Johnston's avatar
    Erik Johnston committed
    import itertools
    
    import json
    
    import six
    
    from synapse.api.constants import EventTypes, RelationTypes
    
    from synapse.rest import admin
    
    from synapse.rest.client.v1 import login, room
    
    from synapse.rest.client.v2_alpha import register, relations
    
    
    from tests import unittest
    
    
    class RelationsTestCase(unittest.HomeserverTestCase):
        servlets = [
            relations.register_servlets,
            room.register_servlets,
            login.register_servlets,
    
            register.register_servlets,
            admin.register_servlets_for_client_rest_resource,
    
    Erik Johnston's avatar
    Erik Johnston committed
        def make_homeserver(self, reactor, clock):
            # We need to enable msc1849 support for aggregations
            config = self.default_config()
            config["experimental_msc1849_support_enabled"] = True
            return self.setup_test_homeserver(config=config)
    
    
        def prepare(self, reactor, clock, hs):
    
            self.user_id, self.user_token = self._create_user("alice")
            self.user2_id, self.user2_token = self._create_user("bob")
    
            self.room = self.helper.create_room_as(self.user_id, tok=self.user_token)
            self.helper.join(self.room, user=self.user2_id, tok=self.user2_token)
            res = self.helper.send(self.room, body="Hi!", tok=self.user_token)
    
            self.parent_id = res["event_id"]
    
        def test_send_relation(self):
            """Tests that sending a relation using the new /send_relation works
            creates the right shape of event.
            """
    
    
    Amber Brown's avatar
    Amber Brown committed
            channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
    
            self.assertEquals(200, channel.code, channel.json_body)
    
            event_id = channel.json_body["event_id"]
    
            request, channel = self.make_request(
    
                "GET",
                "/rooms/%s/event/%s" % (self.room, event_id),
                access_token=self.user_token,
    
            )
            self.render(request)
            self.assertEquals(200, channel.code, channel.json_body)
    
            self.assert_dict(
                {
                    "type": "m.reaction",
                    "sender": self.user_id,
                    "content": {
                        "m.relates_to": {
                            "event_id": self.parent_id,
    
    Amber Brown's avatar
    Amber Brown committed
                            "key": "👍",
    
                            "rel_type": RelationTypes.ANNOTATION,
                        }
                    },
                },
                channel.json_body,
            )
    
        def test_deny_membership(self):
            """Test that we deny relations on membership events
            """
            channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
            self.assertEquals(400, channel.code, channel.json_body)
    
    
        def test_deny_double_react(self):
            """Test that we deny relations on membership events
            """
            channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
            self.assertEquals(200, channel.code, channel.json_body)
    
            channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
            self.assertEquals(400, channel.code, channel.json_body)
    
    
    Erik Johnston's avatar
    Erik Johnston committed
        def test_basic_paginate_relations(self):
    
            """Tests that calling pagination API corectly the latest relations.
            """
            channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
            self.assertEquals(200, channel.code, channel.json_body)
    
            channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
            self.assertEquals(200, channel.code, channel.json_body)
            annotation_id = channel.json_body["event_id"]
    
            request, channel = self.make_request(
                "GET",
                "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
                % (self.room, self.parent_id),
    
                access_token=self.user_token,
    
            )
            self.render(request)
            self.assertEquals(200, channel.code, channel.json_body)
    
            # We expect to get back a single pagination result, which is the full
            # relation event we sent above.
            self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
            self.assert_dict(
    
                {"event_id": annotation_id, "sender": self.user_id, "type": "m.reaction"},
    
                channel.json_body["chunk"][0],
            )
    
    
    Erik Johnston's avatar
    Erik Johnston committed
            # Make sure next_batch has something in it that looks like it could be a
            # valid token.
            self.assertIsInstance(
                channel.json_body.get("next_batch"), six.string_types, channel.json_body
            )
    
        def test_repeated_paginate_relations(self):
            """Test that if we paginate using a limit and tokens then we get the
            expected events.
            """
    
            expected_event_ids = []
            for _ in range(10):
                channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
                self.assertEquals(200, channel.code, channel.json_body)
                expected_event_ids.append(channel.json_body["event_id"])
    
            prev_token = None
            found_event_ids = []
            for _ in range(20):
                from_token = ""
                if prev_token:
                    from_token = "&from=" + prev_token
    
                request, channel = self.make_request(
                    "GET",
                    "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s"
                    % (self.room, self.parent_id, from_token),
    
                    access_token=self.user_token,
    
    Erik Johnston's avatar
    Erik Johnston committed
                )
                self.render(request)
                self.assertEquals(200, channel.code, channel.json_body)
    
                found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
                next_batch = channel.json_body.get("next_batch")
    
                self.assertNotEquals(prev_token, next_batch)
                prev_token = next_batch
    
                if not prev_token:
                    break
    
            # We paginated backwards, so reverse
            found_event_ids.reverse()
            self.assertEquals(found_event_ids, expected_event_ids)
    
        def test_aggregation_pagination_groups(self):
            """Test that we can paginate annotation groups correctly.
            """
    
    
            # We need to create ten separate users to send each reaction.
            access_tokens = [self.user_token, self.user2_token]
            idx = 0
            while len(access_tokens) < 10:
                user_id, token = self._create_user("test" + str(idx))
                idx += 1
    
                self.helper.join(self.room, user=user_id, tok=token)
                access_tokens.append(token)
    
            idx = 0
    
    Amber Brown's avatar
    Amber Brown committed
            sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1}
    
    Erik Johnston's avatar
    Erik Johnston committed
            for key in itertools.chain.from_iterable(
                itertools.repeat(key, num) for key, num in sent_groups.items()
            ):
                channel = self._send_relation(
    
                    RelationTypes.ANNOTATION,
                    "m.reaction",
                    key=key,
                    access_token=access_tokens[idx],
    
    Erik Johnston's avatar
    Erik Johnston committed
                )
                self.assertEquals(200, channel.code, channel.json_body)
    
    
                idx += 1
                idx %= len(access_tokens)
    
    
    Erik Johnston's avatar
    Erik Johnston committed
            prev_token = None
            found_groups = {}
            for _ in range(20):
                from_token = ""
                if prev_token:
                    from_token = "&from=" + prev_token
    
                request, channel = self.make_request(
                    "GET",
                    "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s"
                    % (self.room, self.parent_id, from_token),
    
                    access_token=self.user_token,
    
    Erik Johnston's avatar
    Erik Johnston committed
                )
                self.render(request)
                self.assertEquals(200, channel.code, channel.json_body)
    
                self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
    
                for groups in channel.json_body["chunk"]:
                    # We only expect reactions
                    self.assertEqual(groups["type"], "m.reaction", channel.json_body)
    
                    # We should only see each key once
                    self.assertNotIn(groups["key"], found_groups, channel.json_body)
    
                    found_groups[groups["key"]] = groups["count"]
    
                next_batch = channel.json_body.get("next_batch")
    
                self.assertNotEquals(prev_token, next_batch)
                prev_token = next_batch
    
                if not prev_token:
                    break
    
            self.assertEquals(sent_groups, found_groups)
    
        def test_aggregation_pagination_within_group(self):
            """Test that we can paginate within an annotation group.
            """
    
    
            # We need to create ten separate users to send each reaction.
            access_tokens = [self.user_token, self.user2_token]
            idx = 0
            while len(access_tokens) < 10:
                user_id, token = self._create_user("test" + str(idx))
                idx += 1
    
                self.helper.join(self.room, user=user_id, tok=token)
                access_tokens.append(token)
    
            idx = 0
    
    Erik Johnston's avatar
    Erik Johnston committed
            expected_event_ids = []
            for _ in range(10):
                channel = self._send_relation(
    
                    RelationTypes.ANNOTATION,
                    "m.reaction",
    
    Amber Brown's avatar
    Amber Brown committed
                    key="👍",
    
                    access_token=access_tokens[idx],
    
    Erik Johnston's avatar
    Erik Johnston committed
                )
                self.assertEquals(200, channel.code, channel.json_body)
                expected_event_ids.append(channel.json_body["event_id"])
    
    
    Erik Johnston's avatar
    Erik Johnston committed
            # Also send a different type of reaction so that we test we don't see it
            channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
            self.assertEquals(200, channel.code, channel.json_body)
    
            prev_token = None
            found_event_ids = []
    
    Amber Brown's avatar
    Amber Brown committed
            encoded_key = six.moves.urllib.parse.quote_plus("👍".encode("utf-8"))
    
    Erik Johnston's avatar
    Erik Johnston committed
            for _ in range(20):
                from_token = ""
                if prev_token:
                    from_token = "&from=" + prev_token
    
                request, channel = self.make_request(
                    "GET",
                    "/_matrix/client/unstable/rooms/%s"
                    "/aggregations/%s/%s/m.reaction/%s?limit=1%s"
                    % (
                        self.room,
                        self.parent_id,
                        RelationTypes.ANNOTATION,
                        encoded_key,
                        from_token,
                    ),
    
                    access_token=self.user_token,
    
    Erik Johnston's avatar
    Erik Johnston committed
                )
                self.render(request)
                self.assertEquals(200, channel.code, channel.json_body)
    
                self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
    
                found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
    
                next_batch = channel.json_body.get("next_batch")
    
                self.assertNotEquals(prev_token, next_batch)
                prev_token = next_batch
    
                if not prev_token:
                    break
    
            # We paginated backwards, so reverse
            found_event_ids.reverse()
            self.assertEquals(found_event_ids, expected_event_ids)
    
        def test_aggregation(self):
            """Test that annotations get correctly aggregated.
            """
    
            channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
            self.assertEquals(200, channel.code, channel.json_body)
    
    
            channel = self._send_relation(
                RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
            )
    
    Erik Johnston's avatar
    Erik Johnston committed
            self.assertEquals(200, channel.code, channel.json_body)
    
            channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
            self.assertEquals(200, channel.code, channel.json_body)
    
            request, channel = self.make_request(
                "GET",
                "/_matrix/client/unstable/rooms/%s/aggregations/%s"
                % (self.room, self.parent_id),
    
                access_token=self.user_token,
    
    Erik Johnston's avatar
    Erik Johnston committed
            )
            self.render(request)
            self.assertEquals(200, channel.code, channel.json_body)
    
            self.assertEquals(
                channel.json_body,
                {
                    "chunk": [
                        {"type": "m.reaction", "key": "a", "count": 2},
                        {"type": "m.reaction", "key": "b", "count": 1},
                    ]
                },
            )
    
    
        def test_aggregation_redactions(self):
    
    Erik Johnston's avatar
    Erik Johnston committed
            """Test that annotations get correctly aggregated after a redaction.
    
            """
    
            channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
            self.assertEquals(200, channel.code, channel.json_body)
            to_redact_event_id = channel.json_body["event_id"]
    
            channel = self._send_relation(
                RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
            )
            self.assertEquals(200, channel.code, channel.json_body)
    
    
    Erik Johnston's avatar
    Erik Johnston committed
            # Now lets redact one of the 'a' reactions
    
            request, channel = self.make_request(
                "POST",
                "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id),
                access_token=self.user_token,
                content={},
            )
            self.render(request)
            self.assertEquals(200, channel.code, channel.json_body)
    
            request, channel = self.make_request(
                "GET",
                "/_matrix/client/unstable/rooms/%s/aggregations/%s"
                % (self.room, self.parent_id),
                access_token=self.user_token,
            )
            self.render(request)
            self.assertEquals(200, channel.code, channel.json_body)
    
            self.assertEquals(
                channel.json_body,
                {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
            )
    
    
    Erik Johnston's avatar
    Erik Johnston committed
        def test_aggregation_must_be_annotation(self):
            """Test that aggregations must be annotations.
            """
    
            request, channel = self.make_request(
                "GET",
    
                "/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1"
                % (self.room, self.parent_id, RelationTypes.REPLACE),
    
                access_token=self.user_token,
    
    Erik Johnston's avatar
    Erik Johnston committed
            )
            self.render(request)
            self.assertEquals(400, channel.code, channel.json_body)
    
        def test_aggregation_get_event(self):
            """Test that annotations and references get correctly bundled when
            getting the parent event.
            """
    
            channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
            self.assertEquals(200, channel.code, channel.json_body)
    
    
            channel = self._send_relation(
                RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
            )
    
    Erik Johnston's avatar
    Erik Johnston committed
            self.assertEquals(200, channel.code, channel.json_body)
    
            channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
            self.assertEquals(200, channel.code, channel.json_body)
    
    
            channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
    
    Erik Johnston's avatar
    Erik Johnston committed
            self.assertEquals(200, channel.code, channel.json_body)
            reply_1 = channel.json_body["event_id"]
    
    
            channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
    
    Erik Johnston's avatar
    Erik Johnston committed
            self.assertEquals(200, channel.code, channel.json_body)
            reply_2 = channel.json_body["event_id"]
    
            request, channel = self.make_request(
    
                "GET",
                "/rooms/%s/event/%s" % (self.room, self.parent_id),
                access_token=self.user_token,
    
    Erik Johnston's avatar
    Erik Johnston committed
            )
            self.render(request)
            self.assertEquals(200, channel.code, channel.json_body)
    
            self.assertEquals(
                channel.json_body["unsigned"].get("m.relations"),
                {
                    RelationTypes.ANNOTATION: {
                        "chunk": [
                            {"type": "m.reaction", "key": "a", "count": 2},
                            {"type": "m.reaction", "key": "b", "count": 1},
                        ]
                    },
    
                    RelationTypes.REFERENCE: {
    
    Erik Johnston's avatar
    Erik Johnston committed
                        "chunk": [{"event_id": reply_1}, {"event_id": reply_2}]
                    },
                },
            )
    
    
        def test_edit(self):
            """Test that a simple edit works.
            """
    
            new_body = {"msgtype": "m.text", "body": "I've been edited!"}
            channel = self._send_relation(
    
                RelationTypes.REPLACE,
    
                "m.room.message",
                content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
            )
            self.assertEquals(200, channel.code, channel.json_body)
    
            edit_event_id = channel.json_body["event_id"]
    
            request, channel = self.make_request(
    
                "GET",
                "/rooms/%s/event/%s" % (self.room, self.parent_id),
                access_token=self.user_token,
    
            )
            self.render(request)
            self.assertEquals(200, channel.code, channel.json_body)
    
            self.assertEquals(channel.json_body["content"], new_body)
    
    
            relations_dict = channel.json_body["unsigned"].get("m.relations")
            self.assertIn(RelationTypes.REPLACE, relations_dict)
    
            m_replace_dict = relations_dict[RelationTypes.REPLACE]
            for key in ["event_id", "sender", "origin_server_ts"]:
                self.assertIn(key, m_replace_dict)
    
            self.assert_dict(
                {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
    
            )
    
        def test_multi_edit(self):
            """Test that multiple edits, including attempts by people who
            shouldn't be allowed, are correctly handled.
            """
    
            channel = self._send_relation(
    
                RelationTypes.REPLACE,
    
                "m.room.message",
                content={
                    "msgtype": "m.text",
                    "body": "Wibble",
                    "m.new_content": {"msgtype": "m.text", "body": "First edit"},
                },
            )
            self.assertEquals(200, channel.code, channel.json_body)
    
            new_body = {"msgtype": "m.text", "body": "I've been edited!"}
            channel = self._send_relation(
    
                RelationTypes.REPLACE,
    
                "m.room.message",
                content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
            )
            self.assertEquals(200, channel.code, channel.json_body)
    
            edit_event_id = channel.json_body["event_id"]
    
            channel = self._send_relation(
    
                RelationTypes.REPLACE,
    
                "m.room.message.WRONG_TYPE",
                content={
                    "msgtype": "m.text",
                    "body": "Wibble",
                    "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"},
                },
            )
            self.assertEquals(200, channel.code, channel.json_body)
    
            request, channel = self.make_request(
    
                "GET",
                "/rooms/%s/event/%s" % (self.room, self.parent_id),
                access_token=self.user_token,
    
            )
            self.render(request)
            self.assertEquals(200, channel.code, channel.json_body)
    
            self.assertEquals(channel.json_body["content"], new_body)
    
    
            relations_dict = channel.json_body["unsigned"].get("m.relations")
            self.assertIn(RelationTypes.REPLACE, relations_dict)
    
            m_replace_dict = relations_dict[RelationTypes.REPLACE]
            for key in ["event_id", "sender", "origin_server_ts"]:
                self.assertIn(key, m_replace_dict)
    
            self.assert_dict(
                {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
    
        def _send_relation(
            self, relation_type, event_type, key=None, content={}, access_token=None
        ):
    
            """Helper function to send a relation pointing at `self.parent_id`
    
            Args:
                relation_type (str): One of `RelationTypes`
                event_type (str): The type of the event to create
                key (str|None): The aggregation key used for m.annotation relation
                    type.
    
                content(dict|None): The content of the created event.
    
                access_token (str|None): The access token used to send the relation,
                    defaults to `self.user_token`
    
            if not access_token:
                access_token = self.user_token
    
    
    Erik Johnston's avatar
    Erik Johnston committed
                query = "?key=" + six.moves.urllib.parse.quote_plus(key.encode("utf-8"))
    
    
            request, channel = self.make_request(
                "POST",
                "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
                % (self.room, self.parent_id, relation_type, event_type, query),
    
                json.dumps(content).encode("utf-8"),
    
                access_token=access_token,
    
            )
            self.render(request)
            return channel
    
    
        def _create_user(self, localpart):
            user_id = self.register_user(localpart, "abc123")
            access_token = self.login(localpart, "abc123")
    
            return user_id, access_token