Skip to content
Snippets Groups Projects
Unverified Commit 1799a54a authored by Patrick Cloke's avatar Patrick Cloke Committed by GitHub
Browse files

Batch fetch bundled annotations (#14491)

Avoid an n+1 query problem and fetch the bundled aggregations for
m.annotation relations in a single query instead of a query per event.

This applies similar logic for as was previously done for edits in
8b309adb (#11660) and threads
in b65acead (#11752).
parent da933bfc
No related branches found
No related tags found
No related merge requests found
Reduce database load of [Client-Server endpoints](https://spec.matrix.org/v1.4/client-server-api/#aggregations) which return bundled aggregations.
...@@ -13,7 +13,16 @@ ...@@ -13,7 +13,16 @@
# limitations under the License. # limitations under the License.
import enum import enum
import logging import logging
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple from typing import (
TYPE_CHECKING,
Collection,
Dict,
FrozenSet,
Iterable,
List,
Optional,
Tuple,
)
import attr import attr
...@@ -259,48 +268,64 @@ class RelationsHandler: ...@@ -259,48 +268,64 @@ class RelationsHandler:
e.msg, e.msg,
) )
async def get_annotations_for_event( async def get_annotations_for_events(
self, self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
event_id: str, ) -> Dict[str, List[JsonDict]]:
room_id: str, """Get a list of annotations to the given events, grouped by event type and
limit: int = 5,
ignored_users: FrozenSet[str] = frozenset(),
) -> List[JsonDict]:
"""Get a list of annotations on the event, grouped by event type and
aggregation key, sorted by count. aggregation key, sorted by count.
This is used e.g. to get the what and how many reactions have happend This is used e.g. to get the what and how many reactions have happened
on an event. on an event.
Args: Args:
event_id: Fetch events that relate to this event ID. event_ids: Fetch events that relate to these event IDs.
room_id: The room the event belongs to.
limit: Only fetch the `limit` groups.
ignored_users: The users ignored by the requesting user. ignored_users: The users ignored by the requesting user.
Returns: Returns:
List of groups of annotations that match. Each row is a dict with A map of event IDs to a list of groups of annotations that match.
`type`, `key` and `count` fields. Each entry is a dict with `type`, `key` and `count` fields.
""" """
# Get the base results for all users. # Get the base results for all users.
full_results = await self._main_store.get_aggregation_groups_for_event( full_results = await self._main_store.get_aggregation_groups_for_events(
event_id, room_id, limit event_ids
) )
# Avoid additional logic if there are no ignored users.
if not ignored_users:
return {
event_id: results
for event_id, results in full_results.items()
if results
}
# Then subtract off the results for any ignored users. # Then subtract off the results for any ignored users.
ignored_results = await self._main_store.get_aggregation_groups_for_users( ignored_results = await self._main_store.get_aggregation_groups_for_users(
event_id, room_id, limit, ignored_users [event_id for event_id, results in full_results.items() if results],
ignored_users,
) )
filtered_results = [] filtered_results = {}
for result in full_results: for event_id, results in full_results.items():
key = (result["type"], result["key"]) # If no annotations, skip.
if key in ignored_results: if not results:
result = result.copy() continue
result["count"] -= ignored_results[key]
if result["count"] <= 0: # If there are not ignored results for this event, copy verbatim.
continue if event_id not in ignored_results:
filtered_results.append(result) filtered_results[event_id] = results
continue
# Otherwise, subtract out the ignored results.
event_ignored_results = ignored_results[event_id]
for result in results:
key = (result["type"], result["key"])
if key in event_ignored_results:
# Ensure to not modify the cache.
result = result.copy()
result["count"] -= event_ignored_results[key]
if result["count"] <= 0:
continue
filtered_results.setdefault(event_id, []).append(result)
return filtered_results return filtered_results
...@@ -366,59 +391,62 @@ class RelationsHandler: ...@@ -366,59 +391,62 @@ class RelationsHandler:
results = {} results = {}
for event_id, summary in summaries.items(): for event_id, summary in summaries.items():
if summary: # If no thread, skip.
thread_count, latest_thread_event = summary if not summary:
continue
# Subtract off the count of any ignored users.
for ignored_user in ignored_users:
thread_count -= ignored_results.get((event_id, ignored_user), 0)
# This is gnarly, but if the latest event is from an ignored user,
# attempt to find one that isn't from an ignored user.
if latest_thread_event.sender in ignored_users:
room_id = latest_thread_event.room_id
# If the root event is not found, something went wrong, do
# not include a summary of the thread.
event = await self._event_handler.get_event(user, room_id, event_id)
if event is None:
continue
potential_events, _ = await self.get_relations_for_event( thread_count, latest_thread_event = summary
event_id,
event,
room_id,
RelationTypes.THREAD,
ignored_users,
)
# If all found events are from ignored users, do not include # Subtract off the count of any ignored users.
# a summary of the thread. for ignored_user in ignored_users:
if not potential_events: thread_count -= ignored_results.get((event_id, ignored_user), 0)
continue
# The *last* event returned is the one that is cared about. # This is gnarly, but if the latest event is from an ignored user,
event = await self._event_handler.get_event( # attempt to find one that isn't from an ignored user.
user, room_id, potential_events[-1].event_id if latest_thread_event.sender in ignored_users:
) room_id = latest_thread_event.room_id
# It is unexpected that the event will not exist.
if event is None: # If the root event is not found, something went wrong, do
logger.warning( # not include a summary of the thread.
"Unable to fetch latest event in a thread with event ID: %s", event = await self._event_handler.get_event(user, room_id, event_id)
potential_events[-1].event_id, if event is None:
) continue
continue
latest_thread_event = event potential_events, _ = await self.get_relations_for_event(
event_id,
results[event_id] = _ThreadAggregation( event,
latest_event=latest_thread_event, room_id,
count=thread_count, RelationTypes.THREAD,
# If there's a thread summary it must also exist in the ignored_users,
# participated dictionary.
current_user_participated=events_by_id[event_id].sender == user_id
or participated[event_id],
) )
# If all found events are from ignored users, do not include
# a summary of the thread.
if not potential_events:
continue
# The *last* event returned is the one that is cared about.
event = await self._event_handler.get_event(
user, room_id, potential_events[-1].event_id
)
# It is unexpected that the event will not exist.
if event is None:
logger.warning(
"Unable to fetch latest event in a thread with event ID: %s",
potential_events[-1].event_id,
)
continue
latest_thread_event = event
results[event_id] = _ThreadAggregation(
latest_event=latest_thread_event,
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
current_user_participated=events_by_id[event_id].sender == user_id
or participated[event_id],
)
return results return results
@trace @trace
...@@ -496,17 +524,18 @@ class RelationsHandler: ...@@ -496,17 +524,18 @@ class RelationsHandler:
# (as that is what makes it part of the thread). # (as that is what makes it part of the thread).
relations_by_id[latest_thread_event.event_id] = RelationTypes.THREAD relations_by_id[latest_thread_event.event_id] = RelationTypes.THREAD
# Fetch other relations per event. # Fetch any annotations (ie, reactions) to bundle with this event.
for event in events_by_id.values(): annotations_by_event_id = await self.get_annotations_for_events(
# Fetch any annotations (ie, reactions) to bundle with this event. events_by_id.keys(), ignored_users=ignored_users
annotations = await self.get_annotations_for_event( )
event.event_id, event.room_id, ignored_users=ignored_users for event_id, annotations in annotations_by_event_id.items():
)
if annotations: if annotations:
results.setdefault( results.setdefault(event_id, BundledAggregations()).annotations = {
event.event_id, BundledAggregations() "chunk": annotations
).annotations = {"chunk": annotations} }
# Fetch other relations per event.
for event in events_by_id.values():
# Fetch any references to bundle with this event. # Fetch any references to bundle with this event.
references, next_token = await self.get_relations_for_event( references, next_token = await self.get_relations_for_event(
event.event_id, event.event_id,
......
...@@ -20,6 +20,7 @@ from typing import ( ...@@ -20,6 +20,7 @@ from typing import (
FrozenSet, FrozenSet,
Iterable, Iterable,
List, List,
Mapping,
Optional, Optional,
Set, Set,
Tuple, Tuple,
...@@ -394,106 +395,136 @@ class RelationsWorkerStore(SQLBaseStore): ...@@ -394,106 +395,136 @@ class RelationsWorkerStore(SQLBaseStore):
) )
return result is not None return result is not None
@cached(tree=True) @cached()
async def get_aggregation_groups_for_event( async def get_aggregation_groups_for_event(self, event_id: str) -> List[JsonDict]:
self, event_id: str, room_id: str, limit: int = 5 raise NotImplementedError()
) -> List[JsonDict]:
"""Get a list of annotations on the event, grouped by event type and @cachedList(
cached_method_name="get_aggregation_groups_for_event", list_name="event_ids"
)
async def get_aggregation_groups_for_events(
self, event_ids: Collection[str]
) -> Mapping[str, Optional[List[JsonDict]]]:
"""Get a list of annotations on the given events, grouped by event type and
aggregation key, sorted by count. aggregation key, sorted by count.
This is used e.g. to get the what and how many reactions have happend This is used e.g. to get the what and how many reactions have happend
on an event. on an event.
Args: Args:
event_id: Fetch events that relate to this event ID. event_ids: Fetch events that relate to these event IDs.
room_id: The room the event belongs to.
limit: Only fetch the `limit` groups.
Returns: Returns:
List of groups of annotations that match. Each row is a dict with A map of event IDs to a list of groups of annotations that match.
`type`, `key` and `count` fields. Each entry is a dict with `type`, `key` and `count` fields.
""" """
# The number of entries to return per event ID.
limit = 5
args = [ clause, args = make_in_list_sql_clause(
event_id, self.database_engine, "relates_to_id", event_ids
room_id, )
RelationTypes.ANNOTATION, args.append(RelationTypes.ANNOTATION)
limit,
]
sql = """ sql = f"""
SELECT type, aggregation_key, COUNT(DISTINCT sender) SELECT
FROM event_relations relates_to_id,
INNER JOIN events USING (event_id) annotation.type,
WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? aggregation_key,
GROUP BY relation_type, type, aggregation_key COUNT(DISTINCT annotation.sender)
ORDER BY COUNT(*) DESC FROM events AS annotation
LIMIT ? INNER JOIN event_relations USING (event_id)
INNER JOIN events AS parent ON
parent.event_id = relates_to_id
AND parent.room_id = annotation.room_id
WHERE
{clause}
AND relation_type = ?
GROUP BY relates_to_id, annotation.type, aggregation_key
ORDER BY relates_to_id, COUNT(*) DESC
""" """
def _get_aggregation_groups_for_event_txn( def _get_aggregation_groups_for_events_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> List[JsonDict]: ) -> Mapping[str, List[JsonDict]]:
txn.execute(sql, args) txn.execute(sql, args)
return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn] result: Dict[str, List[JsonDict]] = {}
for event_id, type, key, count in cast(
List[Tuple[str, str, str, int]], txn
):
event_results = result.setdefault(event_id, [])
# Limit the number of results per event ID.
if len(event_results) == limit:
continue
event_results.append({"type": type, "key": key, "count": count})
return result
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn "get_aggregation_groups_for_events", _get_aggregation_groups_for_events_txn
) )
async def get_aggregation_groups_for_users( async def get_aggregation_groups_for_users(
self, self, event_ids: Collection[str], users: FrozenSet[str]
event_id: str, ) -> Dict[str, Dict[Tuple[str, str], int]]:
room_id: str,
limit: int,
users: FrozenSet[str] = frozenset(),
) -> Dict[Tuple[str, str], int]:
"""Fetch the partial aggregations for an event for specific users. """Fetch the partial aggregations for an event for specific users.
This is used, in conjunction with get_aggregation_groups_for_event, to This is used, in conjunction with get_aggregation_groups_for_event, to
remove information from the results for ignored users. remove information from the results for ignored users.
Args: Args:
event_id: Fetch events that relate to this event ID. event_ids: Fetch events that relate to these event IDs.
room_id: The room the event belongs to.
limit: Only fetch the `limit` groups.
users: The users to fetch information for. users: The users to fetch information for.
Returns: Returns:
A map of (event type, aggregation key) to a count of users. A map of event ID to a map of (event type, aggregation key) to a
count of users.
""" """
if not users: if not users:
return {} return {}
args: List[Union[str, int]] = [ events_sql, args = make_in_list_sql_clause(
event_id, self.database_engine, "relates_to_id", event_ids
room_id, )
RelationTypes.ANNOTATION,
]
users_sql, users_args = make_in_list_sql_clause( users_sql, users_args = make_in_list_sql_clause(
self.database_engine, "sender", users self.database_engine, "annotation.sender", users
) )
args.extend(users_args) args.extend(users_args)
args.append(RelationTypes.ANNOTATION)
sql = f""" sql = f"""
SELECT type, aggregation_key, COUNT(DISTINCT sender) SELECT
FROM event_relations relates_to_id,
INNER JOIN events USING (event_id) annotation.type,
WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? AND {users_sql} aggregation_key,
GROUP BY relation_type, type, aggregation_key COUNT(DISTINCT annotation.sender)
ORDER BY COUNT(*) DESC FROM events AS annotation
LIMIT ? INNER JOIN event_relations USING (event_id)
INNER JOIN events AS parent ON
parent.event_id = relates_to_id
AND parent.room_id = annotation.room_id
WHERE {events_sql} AND {users_sql} AND relation_type = ?
GROUP BY relates_to_id, annotation.type, aggregation_key
ORDER BY relates_to_id, COUNT(*) DESC
""" """
def _get_aggregation_groups_for_users_txn( def _get_aggregation_groups_for_users_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> Dict[Tuple[str, str], int]: ) -> Dict[str, Dict[Tuple[str, str], int]]:
txn.execute(sql, args + [limit]) txn.execute(sql, args)
return {(row[0], row[1]): row[2] for row in txn} result: Dict[str, Dict[Tuple[str, str], int]] = {}
for event_id, type, key, count in cast(
List[Tuple[str, str, str, int]], txn
):
result.setdefault(event_id, {})[(type, key)] = count
return result
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn
......
...@@ -503,7 +503,7 @@ def cachedList( ...@@ -503,7 +503,7 @@ def cachedList(
is specified as a list that is iterated through to lookup keys in the is specified as a list that is iterated through to lookup keys in the
original cache. A new tuple consisting of the (deduplicated) keys that weren't in original cache. A new tuple consisting of the (deduplicated) keys that weren't in
the cache gets passed to the original function, which is expected to results the cache gets passed to the original function, which is expected to results
in a map of key to value for each passed value. THe new results are stored in the in a map of key to value for each passed value. The new results are stored in the
original cache. Note that any missing values are cached as None. original cache. Note that any missing values are cached as None.
Args: Args:
......
...@@ -1108,7 +1108,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): ...@@ -1108,7 +1108,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
# The "user" sent the root event and is making queries for the bundled # The "user" sent the root event and is making queries for the bundled
# aggregations: they have participated. # aggregations: they have participated.
self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 9) self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8)
# The "user2" sent replies in the thread and is making queries for the # The "user2" sent replies in the thread and is making queries for the
# bundled aggregations: they have participated. # bundled aggregations: they have participated.
# #
...@@ -1170,7 +1170,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): ...@@ -1170,7 +1170,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations["latest_event"].get("unsigned"), bundled_aggregations["latest_event"].get("unsigned"),
) )
self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9) self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8)
def test_nested_thread(self) -> None: def test_nested_thread(self) -> None:
""" """
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment