Commit 7036e24e authored by Erik Johnston, committed by GitHub

Add background update for add chain cover index (#9029)

parent 21a296cd
(changelog entries)
-Improve efficiency of large state resolutions for new rooms.
+Improve efficiency of large state resolutions.

+Improve efficiency of large state resolutions.
@@ -70,7 +70,7 @@ logger = logging.getLogger("synapse_port_db")

 BOOLEAN_COLUMNS = {
     "events": ["processed", "outlier", "contains_url"],
-    "rooms": ["is_public"],
+    "rooms": ["is_public", "has_auth_chain_index"],
     "event_edges": ["is_state"],
     "presence_list": ["accepted"],
     "presence_stream": ["currently_active"],
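
Aside: SQLite stores booleans as 0/1 integers while Postgres uses a real
boolean type, so synapse_port_db keeps this per-table list of boolean columns
to convert while copying rows across; any new boolean column, such as
rooms.has_auth_chain_index above, has to be registered here. A minimal sketch
of the coercion idea (the coerce_booleans helper is hypothetical, not the
script's actual code):

    from typing import Any, Dict, List

    def coerce_booleans(
        table: str, row: Dict[str, Any], boolean_columns: Dict[str, List[str]]
    ) -> Dict[str, Any]:
        # Convert SQLite's 0/1 integers to real booleans before writing the
        # row to Postgres.
        for col in boolean_columns.get(table, []):
            if row.get(col) is not None:
                row[col] = bool(row[col])
        return row

    row = coerce_booleans(
        "rooms",
        {"room_id": "!r:example.com", "is_public": 1, "has_auth_chain_index": 0},
        {"rooms": ["is_public", "has_auth_chain_index"]},
    )
    assert row == {
        "room_id": "!r:example.com",
        "is_public": True,
        "has_auth_chain_index": False,
    }
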
@@ -466,9 +466,6 @@ class PersistEventsStore:
         if not state_events:
             return

-        # Map from event ID to chain ID/sequence number.
-        chain_map = {}  # type: Dict[str, Tuple[int, int]]
-
         # We need to know the type/state_key and auth events of the events we're
         # calculating chain IDs for. We don't rely on having the full Event
         # instances as we'll potentially be pulling more events from the DB and
@@ -479,9 +476,33 @@
         event_to_auth_chain = {
             e.event_id: e.auth_event_ids() for e in state_events.values()
         }
+        event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
+
+        self._add_chain_cover_index(
+            txn, event_to_room_id, event_to_types, event_to_auth_chain
+        )
+
+    def _add_chain_cover_index(
+        self,
+        txn,
+        event_to_room_id: Dict[str, str],
+        event_to_types: Dict[str, Tuple[str, str]],
+        event_to_auth_chain: Dict[str, List[str]],
+    ) -> None:
+        """Calculate the chain cover index for the given events.
+
+        Args:
+            event_to_room_id: Event ID to the room ID of the event
+            event_to_types: Event ID to type and state_key of the event
+            event_to_auth_chain: Event ID to list of auth event IDs of the
+                event (events with no auth events can be excluded).
+        """
+
+        # Map from event ID to chain ID/sequence number.
+        chain_map = {}  # type: Dict[str, Tuple[int, int]]
+
         # Set of event IDs to calculate chain ID/seq numbers for.
-        events_to_calc_chain_id_for = set(state_events)
+        events_to_calc_chain_id_for = set(event_to_room_id)

         # We check if there are any events that need to be handled in the rooms
         # we're looking at. These should just be out of band memberships, where
@@ -491,7 +512,7 @@
             table="event_auth_chain_to_calculate",
             keyvalues={},
             column="room_id",
-            iterable={e.room_id for e in state_events.values()},
+            iterable=set(event_to_room_id.values()),
             retcols=("event_id", "type", "state_key"),
         )
         for row in rows:
@@ -582,16 +603,17 @@
                     # the list of events to calculate chain IDs for next time
                     # around. (Otherwise we will have already added it to the
                     # table).
-                    event = state_events.get(event_id)
-                    if event:
+                    room_id = event_to_room_id.get(event_id)
+                    if room_id:
+                        e_type, state_key = event_to_types[event_id]
                         self.db_pool.simple_insert_txn(
                             txn,
                             table="event_auth_chain_to_calculate",
                             values={
-                                "event_id": event.event_id,
-                                "room_id": event.room_id,
-                                "type": event.type,
-                                "state_key": event.state_key,
+                                "event_id": event_id,
+                                "room_id": room_id,
+                                "type": e_type,
+                                "state_key": state_key,
                             },
                         )
@@ -617,7 +639,7 @@
             events_to_calc_chain_id_for, event_to_auth_chain
         ):
             existing_chain_id = None
-            for auth_id in event_to_auth_chain[event_id]:
+            for auth_id in event_to_auth_chain.get(event_id, []):
                 if event_to_types.get(event_id) == event_to_types.get(auth_id):
                     existing_chain_id = chain_map[auth_id]
                     break
@@ -730,11 +752,11 @@
             # auth events (A, B) to check if B is reachable from A.
             reduction = {
                 a_id
-                for a_id in event_to_auth_chain[event_id]
+                for a_id in event_to_auth_chain.get(event_id, [])
                 if chain_map[a_id][0] != chain_id
             }

             for start_auth_id, end_auth_id in itertools.permutations(
-                event_to_auth_chain[event_id], r=2,
+                event_to_auth_chain.get(event_id, []), r=2,
             ):
                 if chain_links.exists_path_from(
                     chain_map[start_auth_id], chain_map[end_auth_id]
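
Aside: some background on the chain cover index itself (my summary, not part
of the diff). Events in a room's auth graph are grouped into chains, roughly
one per (type, state_key) pair, and each event is assigned a
(chain_id, sequence_number) position within its chain; separate link records
capture cross-chain reachability. Auth chain differences can then be answered
with ordering checks on these pairs instead of walking the full auth DAG. A
toy sketch of the numbering idea, simplified from the real scheme (the real
code only extends an existing chain when the auth events allow it):

    from typing import Dict, List, Tuple

    def assign_chains(
        event_to_type: Dict[str, Tuple[str, str]],
        topo_order: List[str],
    ) -> Dict[str, Tuple[int, int]]:
        # Give each event a (chain_id, seq) pair, one chain per
        # (type, state_key), numbering events in topological order.
        chains = {}  # type: Dict[Tuple[str, str], int]
        next_seq = {}  # type: Dict[int, int]
        out = {}  # type: Dict[str, Tuple[int, int]]
        for event_id in topo_order:
            key = event_to_type[event_id]
            chain_id = chains.setdefault(key, len(chains) + 1)
            seq = next_seq.get(chain_id, 0) + 1
            next_seq[chain_id] = seq
            out[event_id] = (chain_id, seq)
        return out

    print(
        assign_chains(
            {
                "A": ("m.room.create", ""),
                "B": ("m.room.member", "@u:example.com"),
                "C": ("m.room.member", "@u:example.com"),
            },
            ["A", "B", "C"],
        )
    )  # {'A': (1, 1), 'B': (2, 1), 'C': (2, 2)}
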
@@ -14,13 +14,13 @@
 # limitations under the License.

 import logging
-from typing import List, Tuple
+from typing import Dict, List, Optional, Tuple

 from synapse.api.constants import EventContentFields
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import make_event_from_dict
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
 from synapse.storage.types import Cursor
 from synapse.types import JsonDict

@@ -108,6 +108,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             "rejected_events_metadata", self._rejected_events_metadata,
         )

+        self.db_pool.updates.register_background_update_handler(
+            "chain_cover", self._chain_cover_index,
+        )
+
     async def _background_reindex_fields_sender(self, progress, batch_size):
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
@@ -706,3 +710,187 @@
         )

         return len(results)
+
+    async def _chain_cover_index(self, progress: dict, batch_size: int) -> int:
+        """A background update that iterates over all rooms and generates the
+        chain cover index for them.
+        """
+
+        current_room_id = progress.get("current_room_id", "")
+
+        # Have we finished processing the current room.
+        finished = progress.get("finished", True)
+
+        # Where we've processed up to in the room, defaults to the start of the
+        # room.
+        last_depth = progress.get("last_depth", -1)
+        last_stream = progress.get("last_stream", -1)
+
+        # Have we set the `has_auth_chain_index` for the room yet.
+        has_set_room_has_chain_index = progress.get(
+            "has_set_room_has_chain_index", False
+        )
+
+        if finished:
+            # If we've finished with the previous room (or it's our first
+            # iteration) we move on to the next room.
+
+            def _get_next_room(txn: Cursor) -> Optional[str]:
+                sql = """
+                    SELECT room_id FROM rooms
+                    WHERE room_id > ?
+                        AND (
+                            NOT has_auth_chain_index
+                            OR has_auth_chain_index IS NULL
+                        )
+                    ORDER BY room_id
+                    LIMIT 1
+                """
+                txn.execute(sql, (current_room_id,))
+
+                row = txn.fetchone()
+                if row:
+                    return row[0]
+
+                return None
+
+            current_room_id = await self.db_pool.runInteraction(
+                "_chain_cover_index", _get_next_room
+            )
+            if not current_room_id:
+                await self.db_pool.updates._end_background_update("chain_cover")
+                return 0
+
+            logger.debug("Adding chain cover to %s", current_room_id)
+
+        def _calculate_auth_chain(
+            txn: Cursor, last_depth: int, last_stream: int
+        ) -> Tuple[int, int, int]:
+            # Get the next set of events in the room (that we haven't already
+            # computed chain cover for). We do this in topological order.
+
+            # We want to do a `(topological_ordering, stream_ordering) > (?,?)`
+            # comparison, but that is not supported on older SQLite versions
+            tuple_clause, tuple_args = make_tuple_comparison_clause(
+                self.database_engine,
+                [
+                    ("topological_ordering", last_depth),
+                    ("stream_ordering", last_stream),
+                ],
+            )
+
+            sql = """
+                SELECT
+                    event_id, state_events.type, state_events.state_key,
+                    topological_ordering, stream_ordering
+                FROM events
+                INNER JOIN state_events USING (event_id)
+                LEFT JOIN event_auth_chains USING (event_id)
+                LEFT JOIN event_auth_chain_to_calculate USING (event_id)
+                WHERE events.room_id = ?
+                    AND event_auth_chains.event_id IS NULL
+                    AND event_auth_chain_to_calculate.event_id IS NULL
+                    AND %(tuple_cmp)s
+                ORDER BY topological_ordering, stream_ordering
+                LIMIT ?
+            """ % {
+                "tuple_cmp": tuple_clause,
+            }
+
+            args = [current_room_id]
+            args.extend(tuple_args)
+            args.append(batch_size)
+
+            txn.execute(sql, args)
+            rows = txn.fetchall()
+
+            # Put the results in the necessary format for
+            # `_add_chain_cover_index`
+            event_to_room_id = {row[0]: current_room_id for row in rows}
+            event_to_types = {row[0]: (row[1], row[2]) for row in rows}
+
+            new_last_depth = rows[-1][3] if rows else last_depth  # type: int
+            new_last_stream = rows[-1][4] if rows else last_stream  # type: int
+
+            count = len(rows)
+
+            # We also need to fetch the auth events for them.
+            auth_events = self.db_pool.simple_select_many_txn(
+                txn,
+                table="event_auth",
+                column="event_id",
+                iterable=event_to_room_id,
+                keyvalues={},
+                retcols=("event_id", "auth_id"),
+            )
+
+            event_to_auth_chain = {}  # type: Dict[str, List[str]]
+            for row in auth_events:
+                event_to_auth_chain.setdefault(row["event_id"], []).append(
+                    row["auth_id"]
+                )
+
+            # Calculate and persist the chain cover index for this set of events.
+            #
+            # Annoyingly we need to gut wrench into the persist event store so
+            # that we can reuse the function to calculate the chain cover for
+            # rooms.
+            self.hs.get_datastores().persist_events._add_chain_cover_index(
+                txn, event_to_room_id, event_to_types, event_to_auth_chain,
+            )
+
+            return new_last_depth, new_last_stream, count
+
+        last_depth, last_stream, count = await self.db_pool.runInteraction(
+            "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
+        )
+
+        total_rows_processed = count
+
+        if count < batch_size and not has_set_room_has_chain_index:
+            # If we've done all the events in the room we flip the
+            # `has_auth_chain_index` in the DB. Note that it's possible for
+            # further events to be persisted between the above and setting the
+            # flag without having the chain cover calculated for them. This is
+            # fine as a) the code gracefully handles these cases and b) we'll
+            # calculate them below.
+
+            await self.db_pool.simple_update(
+                table="rooms",
+                keyvalues={"room_id": current_room_id},
+                updatevalues={"has_auth_chain_index": True},
+                desc="_chain_cover_index",
+            )
+            has_set_room_has_chain_index = True
+
+            # Handle any events that might have raced with us flipping the
+            # bit above.
+            last_depth, last_stream, count = await self.db_pool.runInteraction(
+                "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
+            )
+
+            total_rows_processed += count
+
+        # Note that at this point it's technically possible that more events
+        # than our `batch_size` have been persisted without their chain
+        # cover, so we need to continue processing this room if the last
+        # count returned was equal to the `batch_size`.
+
+        if count < batch_size:
+            # We've finished calculating the index for this room, move on to
+            # the next room.
+            await self.db_pool.updates._background_update_progress(
+                "chain_cover", {"current_room_id": current_room_id, "finished": True},
+            )
+        else:
+            # We still have outstanding events to calculate the index for.
+            await self.db_pool.updates._background_update_progress(
+                "chain_cover",
+                {
+                    "current_room_id": current_room_id,
+                    "last_depth": last_depth,
+                    "last_stream": last_stream,
+                    # Must match the progress.get() key at the top of this
+                    # method, or the flag will never round-trip.
+                    "has_set_room_has_chain_index": has_set_room_has_chain_index,
+                    "finished": False,
+                },
+            )
+
+        return total_rows_processed
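
Aside: on the tuple comparison used in _calculate_auth_chain above. Row-value
comparisons such as (topological_ordering, stream_ordering) > (?, ?) are
supported on Postgres but only on newer SQLite, so make_tuple_comparison_clause
builds an equivalent expression. The sketch below is the standard
lexicographic expansion, written from my understanding of what the helper
produces rather than quoted from its implementation:

    from typing import Any, List, Tuple

    def tuple_gt_clause(keys: List[Tuple[str, Any]]) -> Tuple[str, List[Any]]:
        # Build SQL equivalent to (col1, col2, ...) > (?, ?, ...):
        #   (a, b) > (x, y)  <=>  a > x OR (a = x AND b > y)
        clauses = []
        args = []  # type: List[Any]
        for i, (col, _) in enumerate(keys):
            clause = "%s > ?" % col
            prefix = " AND ".join("%s = ?" % c for c, _ in keys[:i])
            if prefix:
                clause = "(%s AND %s)" % (prefix, clause)
            clauses.append(clause)
            args.extend(v for _, v in keys[: i + 1])
        return "(%s)" % " OR ".join(clauses), args

    sql, args = tuple_gt_clause(
        [("topological_ordering", 10), ("stream_ordering", 7)]
    )
    # sql == "(topological_ordering > ?
    #          OR (topological_ordering = ? AND stream_ordering > ?))"
    # args == [10, 10, 7]
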
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
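
-- Note: the row below merely schedules the background update. Synapse's
-- background updater repeatedly invokes the handler registered under
-- 'chain_cover' (see the Python changes above) until the update reports
-- itself finished, and the depends_on column delays it until the
-- 'rejected_events_metadata' update has completed.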
INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
(5906, 'chain_cover', '{}', 'rejected_events_metadata');
@@ -20,7 +20,10 @@ from twisted.trial import unittest
 from synapse.api.constants import EventTypes
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
 from synapse.storage.databases.main.events import _LinkMap
+from synapse.types import create_requester

 from tests.unittest import HomeserverTestCase

@@ -470,3 +473,114 @@
         self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
         self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])
+
+
+class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
+
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def test_background_update(self):
+        """Test that the background update to calculate auth chains for historic
+        rooms works correctly.
+        """
+
+        # Create a room
+        user_id = self.register_user("foo", "pass")
+        token = self.login("foo", "pass")
+        room_id = self.helper.create_room_as(user_id, tok=token)
+        requester = create_requester(user_id)
+
+        store = self.hs.get_datastore()
+
+        # Mark the room as not having a chain cover index
+        self.get_success(
+            store.db_pool.simple_update(
+                table="rooms",
+                keyvalues={"room_id": room_id},
+                updatevalues={"has_auth_chain_index": False},
+                desc="test",
+            )
+        )
+
+        # Create a fork in the DAG with different events.
+        event_handler = self.hs.get_event_creation_handler()
+
+        latest_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
+
+        event, context = self.get_success(
+            event_handler.create_event(
+                requester,
+                {
+                    "type": "some_state_type",
+                    "state_key": "",
+                    "content": {},
+                    "room_id": room_id,
+                    "sender": user_id,
+                },
+                prev_event_ids=latest_event_ids,
+            )
+        )
+        self.get_success(
+            event_handler.handle_new_client_event(requester, event, context)
+        )
+        state1 = list(self.get_success(context.get_current_state_ids()).values())
+
+        event, context = self.get_success(
+            event_handler.create_event(
+                requester,
+                {
+                    "type": "some_state_type",
+                    "state_key": "",
+                    "content": {},
+                    "room_id": room_id,
+                    "sender": user_id,
+                },
+                prev_event_ids=latest_event_ids,
+            )
+        )
+        self.get_success(
+            event_handler.handle_new_client_event(requester, event, context)
+        )
+        state2 = list(self.get_success(context.get_current_state_ids()).values())
+
+        # Delete the chain cover info.
+
+        def _delete_tables(txn):
+            txn.execute("DELETE FROM event_auth_chains")
+            txn.execute("DELETE FROM event_auth_chain_links")
+
+        self.get_success(store.db_pool.runInteraction("test", _delete_tables))
+
+        # Insert and run the background update.
+        self.get_success(
+            store.db_pool.simple_insert(
+                "background_updates",
+                {"update_name": "chain_cover", "progress_json": "{}"},
+            )
+        )
+
+        # Ugh, have to reset this flag
+        store.db_pool.updates._all_done = False
+
+        while not self.get_success(
+            store.db_pool.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                store.db_pool.updates.do_next_background_update(100), by=0.1
+            )
+
+        # Test that the `has_auth_chain_index` has been set
+        self.assertTrue(self.get_success(store.has_auth_chain_index(room_id)))
+
+        # Test that calculating the auth chain difference using the newly
+        # calculated chain cover works.
+        self.get_success(
+            store.db_pool.runInteraction(
+                "test",
+                store._get_auth_chain_difference_using_cover_index_txn,
+                room_id,
+                [state1, state2],
+            )
+        )
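
Aside: for context on the final assertion,
_get_auth_chain_difference_using_cover_index_txn computes the auth chain
difference of the two state sets, i.e. the union of the state sets' auth
chains minus their intersection, answered via the chain cover index rather
than by walking the auth DAG. A toy restatement of that definition in plain
Python, assuming a full in-memory map of auth edges:

    from typing import Dict, List, Set

    def auth_chain(event_id: str, auth_map: Dict[str, List[str]]) -> Set[str]:
        # All ancestors of event_id via auth edges (excluding the event
        # itself).
        seen = set()  # type: Set[str]
        stack = list(auth_map.get(event_id, []))
        while stack:
            eid = stack.pop()
            if eid not in seen:
                seen.add(eid)
                stack.extend(auth_map.get(eid, []))
        return seen

    def auth_chain_difference(
        state_sets: List[Set[str]], auth_map: Dict[str, List[str]]
    ) -> Set[str]:
        # Union of the per-state-set auth chains minus their intersection.
        chains = [
            set().union(*(auth_chain(e, auth_map) for e in s)) if s else set()
            for s in state_sets
        ]
        return set.union(*chains) - set.intersection(*chains)

    auth_map = {"B": ["A"], "C": ["A"], "D": ["B"], "E": ["C"]}
    print(auth_chain_difference([{"D"}, {"E"}], auth_map))  # {'B', 'C'}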