Skip to content
Snippets Groups Projects
Unverified Commit e9b2d047 authored by Matthew Hodgson's avatar Matthew Hodgson Committed by GitHub
Browse files

make /context lazyload & filter aware (#3567)

make /context lazyload & filter aware.
parent b0b5566f
Branches
Tags
No related merge requests found
make the /context API filter & lazy-load aware as per MSC1227
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
"""Contains functions for performing events on rooms.""" """Contains functions for performing events on rooms."""
import itertools
import logging import logging
import math import math
import string import string
...@@ -401,7 +402,7 @@ class RoomContextHandler(object): ...@@ -401,7 +402,7 @@ class RoomContextHandler(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_event_context(self, user, room_id, event_id, limit): def get_event_context(self, user, room_id, event_id, limit, event_filter):
"""Retrieves events, pagination tokens and state around a given event """Retrieves events, pagination tokens and state around a given event
in a room. in a room.
...@@ -411,6 +412,8 @@ class RoomContextHandler(object): ...@@ -411,6 +412,8 @@ class RoomContextHandler(object):
event_id (str) event_id (str)
limit (int): The maximum number of events to return in total limit (int): The maximum number of events to return in total
(excluding state). (excluding state).
event_filter (Filter|None): the filter to apply to the events returned
(excluding the target event_id)
Returns: Returns:
dict, or None if the event isn't found dict, or None if the event isn't found
...@@ -443,7 +446,7 @@ class RoomContextHandler(object): ...@@ -443,7 +446,7 @@ class RoomContextHandler(object):
) )
results = yield self.store.get_events_around( results = yield self.store.get_events_around(
room_id, event_id, before_limit, after_limit room_id, event_id, before_limit, after_limit, event_filter
) )
results["events_before"] = yield filter_evts(results["events_before"]) results["events_before"] = yield filter_evts(results["events_before"])
...@@ -455,8 +458,23 @@ class RoomContextHandler(object): ...@@ -455,8 +458,23 @@ class RoomContextHandler(object):
else: else:
last_event_id = event_id last_event_id = event_id
types = None
filtered_types = None
if event_filter and event_filter.lazy_load_members():
members = set(ev.sender for ev in itertools.chain(
results["events_before"],
(results["event"],),
results["events_after"],
))
filtered_types = [EventTypes.Member]
types = [(EventTypes.Member, member) for member in members]
# XXX: why do we return the state as of the last event rather than the
# first? Shouldn't we be consistent with /sync?
# https://github.com/matrix-org/matrix-doc/issues/687
state = yield self.store.get_state_for_events( state = yield self.store.get_state_for_events(
[last_event_id], None [last_event_id], types, filtered_types=filtered_types,
) )
results["state"] = list(state[last_event_id].values()) results["state"] = list(state[last_event_id].values())
......
...@@ -287,7 +287,7 @@ class SearchHandler(BaseHandler): ...@@ -287,7 +287,7 @@ class SearchHandler(BaseHandler):
contexts = {} contexts = {}
for event in allowed_events: for event in allowed_events:
res = yield self.store.get_events_around( res = yield self.store.get_events_around(
event.room_id, event.event_id, before_limit, after_limit event.room_id, event.event_id, before_limit, after_limit,
) )
logger.info( logger.info(
......
...@@ -531,11 +531,20 @@ class RoomEventContextServlet(ClientV1RestServlet): ...@@ -531,11 +531,20 @@ class RoomEventContextServlet(ClientV1RestServlet):
limit = parse_integer(request, "limit", default=10) limit = parse_integer(request, "limit", default=10)
# picking the API shape for symmetry with /messages
filter_bytes = parse_string(request, "filter")
if filter_bytes:
filter_json = urlparse.unquote(filter_bytes).decode("UTF-8")
event_filter = Filter(json.loads(filter_json))
else:
event_filter = None
results = yield self.room_context_handler.get_event_context( results = yield self.room_context_handler.get_event_context(
requester.user, requester.user,
room_id, room_id,
event_id, event_id,
limit, limit,
event_filter,
) )
if not results: if not results:
......
...@@ -527,7 +527,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ...@@ -527,7 +527,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def get_events_around(self, room_id, event_id, before_limit, after_limit): def get_events_around(
self, room_id, event_id, before_limit, after_limit, event_filter=None,
):
"""Retrieve events and pagination tokens around a given event in a """Retrieve events and pagination tokens around a given event in a
room. room.
...@@ -536,6 +538,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ...@@ -536,6 +538,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_id (str) event_id (str)
before_limit (int) before_limit (int)
after_limit (int) after_limit (int)
event_filter (Filter|None)
Returns: Returns:
dict dict
...@@ -543,7 +546,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ...@@ -543,7 +546,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
results = yield self.runInteraction( results = yield self.runInteraction(
"get_events_around", self._get_events_around_txn, "get_events_around", self._get_events_around_txn,
room_id, event_id, before_limit, after_limit room_id, event_id, before_limit, after_limit, event_filter,
) )
events_before = yield self._get_events( events_before = yield self._get_events(
...@@ -563,7 +566,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ...@@ -563,7 +566,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"end": results["after"]["token"], "end": results["after"]["token"],
}) })
def _get_events_around_txn(self, txn, room_id, event_id, before_limit, after_limit): def _get_events_around_txn(
self, txn, room_id, event_id, before_limit, after_limit, event_filter,
):
"""Retrieves event_ids and pagination tokens around a given event in a """Retrieves event_ids and pagination tokens around a given event in a
room. room.
...@@ -572,6 +577,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ...@@ -572,6 +577,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_id (str) event_id (str)
before_limit (int) before_limit (int)
after_limit (int) after_limit (int)
event_filter (Filter|None)
Returns: Returns:
dict dict
...@@ -601,11 +607,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ...@@ -601,11 +607,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows, start_token = self._paginate_room_events_txn( rows, start_token = self._paginate_room_events_txn(
txn, room_id, before_token, direction='b', limit=before_limit, txn, room_id, before_token, direction='b', limit=before_limit,
event_filter=event_filter,
) )
events_before = [r.event_id for r in rows] events_before = [r.event_id for r in rows]
rows, end_token = self._paginate_room_events_txn( rows, end_token = self._paginate_room_events_txn(
txn, room_id, after_token, direction='f', limit=after_limit, txn, room_id, after_token, direction='f', limit=after_limit,
event_filter=event_filter,
) )
events_after = [r.event_id for r in rows] events_after = [r.event_id for r in rows]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment