Skip to content
Snippets Groups Projects
Commit 95f3fcda authored by Erik Johnston's avatar Erik Johnston
Browse files

Check that event is visible in new APIs

parent b5c62c6b
No related branches found
No related tags found
No related merge requests found
...@@ -131,6 +131,7 @@ class RelationPaginationServlet(RestServlet): ...@@ -131,6 +131,7 @@ class RelationPaginationServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None):
...@@ -140,6 +141,10 @@ class RelationPaginationServlet(RestServlet): ...@@ -140,6 +141,10 @@ class RelationPaginationServlet(RestServlet):
room_id, requester.user.to_string() room_id, requester.user.to_string()
) )
# This checks that a) the event exists and b) the user is allowed to
# view it.
yield self.event_handler.get_event(requester.user, room_id, parent_id)
limit = parse_integer(request, "limit", default=5) limit = parse_integer(request, "limit", default=5)
from_token = parse_string(request, "from") from_token = parse_string(request, "from")
to_token = parse_string(request, "to") to_token = parse_string(request, "to")
...@@ -195,6 +200,7 @@ class RelationAggregationPaginationServlet(RestServlet): ...@@ -195,6 +200,7 @@ class RelationAggregationPaginationServlet(RestServlet):
super(RelationAggregationPaginationServlet, self).__init__() super(RelationAggregationPaginationServlet, self).__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.event_handler = hs.get_event_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None):
...@@ -204,6 +210,10 @@ class RelationAggregationPaginationServlet(RestServlet): ...@@ -204,6 +210,10 @@ class RelationAggregationPaginationServlet(RestServlet):
room_id, requester.user.to_string() room_id, requester.user.to_string()
) )
# This checks that a) the event exists and b) the user is allowed to
# view it.
yield self.event_handler.get_event(requester.user, room_id, parent_id)
if relation_type not in (RelationTypes.ANNOTATION, None): if relation_type not in (RelationTypes.ANNOTATION, None):
raise SynapseError(400, "Relation type must be 'annotation'") raise SynapseError(400, "Relation type must be 'annotation'")
...@@ -258,6 +268,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): ...@@ -258,6 +268,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): def on_GET(self, request, room_id, parent_id, relation_type, event_type, key):
...@@ -267,6 +278,10 @@ class RelationAggregationGroupPaginationServlet(RestServlet): ...@@ -267,6 +278,10 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
room_id, requester.user.to_string() room_id, requester.user.to_string()
) )
# This checks that a) the event exists and b) the user is allowed to
# view it.
yield self.event_handler.get_event(requester.user, room_id, parent_id)
if relation_type != RelationTypes.ANNOTATION: if relation_type != RelationTypes.ANNOTATION:
raise SynapseError(400, "Relation type must be 'annotation'") raise SynapseError(400, "Relation type must be 'annotation'")
...@@ -296,8 +311,6 @@ class RelationAggregationGroupPaginationServlet(RestServlet): ...@@ -296,8 +311,6 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
defer.returnValue((200, return_value)) defer.returnValue((200, return_value))
defer.returnValue((200, return_value))
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
RelationSendServlet(hs).register(http_server) RelationSendServlet(hs).register(http_server)
......
...@@ -296,7 +296,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): ...@@ -296,7 +296,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request( request, channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/rooms/%s/aggregations/m.replaces/%s?limit=1" "/_matrix/client/unstable/rooms/%s/aggregations/%s/m.replace?limit=1"
% (self.room, self.parent_id), % (self.room, self.parent_id),
) )
self.render(request) self.render(request)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment