Skip to content
Snippets Groups Projects
Unverified Commit 319a805c authored by Richard van der Hoff's avatar Richard van der Hoff Committed by GitHub
Browse files

Raise an exception when getting state at an outlier (#12191)

It seems like calling `_get_state_group_for_events` for an event where the
state is unknown is an error. Accordingly, let's raise an exception rather than
silently returning an empty result.
parent 9b43df1f
No related branches found
No related tags found
No related merge requests found
Avoid trying to calculate the state at outlier events.
......@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, Iterable, Optional, Set, Tuple
from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
from frozendict import frozendict
......@@ -309,9 +309,13 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
num_args=1,
)
async def _get_state_group_for_events(
self, event_ids: Iterable[str]
self, event_ids: Collection[str]
) -> Dict[str, int]:
"""Returns mapping event_id -> state_group"""
"""Returns mapping event_id -> state_group.
Raises:
RuntimeError if the state is unknown at any of the given events
"""
rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
......@@ -321,7 +325,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="_get_state_group_for_events",
)
return {row["event_id"]: row["state_group"] for row in rows}
res = {row["event_id"]: row["state_group"] for row in rows}
for e in event_ids:
if e not in res:
raise RuntimeError("No state group for unknown or outlier event %s" % e)
return res
async def get_referenced_state_groups(
self, state_groups: Iterable[int]
......
......@@ -571,6 +571,10 @@ class StateGroupStorage:
Returns:
dict of state_group_id -> (dict of (type, state_key) -> event id)
Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
if not event_ids:
return {}
......@@ -659,6 +663,10 @@ class StateGroupStorage:
Returns:
A dict of (event_id) -> (type, state_key) -> [state_events]
Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
......@@ -696,6 +704,10 @@ class StateGroupStorage:
Returns:
A dict from event_id -> (type, state_key) -> event_id
Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
......@@ -723,6 +735,10 @@ class StateGroupStorage:
Returns:
A dict from (type, state_key) -> state_event
Raises:
RuntimeError if we don't have a state group for the event (ie it is an
outlier or is unknown)
"""
state_map = await self.get_state_for_events(
[event_id], state_filter or StateFilter.all()
......@@ -741,6 +757,10 @@ class StateGroupStorage:
Returns:
A dict from (type, state_key) -> state_event_id
Raises:
RuntimeError if we don't have a state group for the event (ie it is an
outlier or is unknown)
"""
state_map = await self.get_state_ids_for_events(
[event_id], state_filter or StateFilter.all()
......
......@@ -20,17 +20,17 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
from synapse.events import EventBase, make_event_from_dict
from synapse.federation.federation_base import event_from_pdu_json
from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import create_requester
from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
from tests.test_utils import event_injection
logger = logging.getLogger(__name__)
......@@ -39,7 +39,7 @@ def generate_fake_event_id() -> str:
return "$fake_" + random_string(43)
class FederationTestCase(unittest.HomeserverTestCase):
class FederationTestCase(unittest.FederatingHomeserverTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
......@@ -219,41 +219,77 @@ class FederationTestCase(unittest.HomeserverTestCase):
# create the room
user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test")
requester = create_requester(user_id)
room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
room_version = self.get_success(self.store.get_room_version(room_id))
# we need a user on the remote server to be a member, so that we can send
# extremity-causing events.
self.get_success(
event_injection.inject_member_event(
self.hs, room_id, f"@user:{self.OTHER_SERVER_NAME}", "join"
)
)
ev1 = self.helper.send(room_id, "first message", tok=tok)
send_result = self.helper.send(room_id, "first message", tok=tok)
ev1 = self.get_success(
self.store.get_event(send_result["event_id"], allow_none=False)
)
current_state = self.get_success(
self.store.get_events_as_list(
(self.get_success(self.store.get_current_state_ids(room_id))).values()
)
)
# Create "many" backward extremities. The magic number we're trying to
# create more than is 5 which corresponds to the number of backward
# extremities we slice off in `_maybe_backfill_inner`
federation_event_handler = self.hs.get_federation_event_handler()
for _ in range(0, 8):
event_handler = self.hs.get_event_creation_handler()
event, context = self.get_success(
event_handler.create_event(
requester,
event = make_event_from_dict(
self.add_hashes_and_signatures(
{
"origin_server_ts": 1,
"type": "m.room.message",
"content": {
"msgtype": "m.text",
"body": "message connected to fake event",
},
"room_id": room_id,
"sender": user_id,
"sender": f"@user:{self.OTHER_SERVER_NAME}",
"prev_events": [
ev1.event_id,
# We're creating an backward extremity each time thanks
# to this fake event
generate_fake_event_id(),
],
# lazy: *everything* is an auth event
"auth_events": [ev.event_id for ev in current_state],
"depth": ev1.depth + 1,
},
prev_event_ids=[
ev1["event_id"],
# We're creating an backward extremity each time thanks
# to this fake event
generate_fake_event_id(),
],
)
room_version,
),
room_version,
)
# we poke this directly into _process_received_pdu, to avoid the
# federation handler wanting to backfill the fake event.
self.get_success(
event_handler.handle_new_client_event(requester, event, context)
federation_event_handler._process_received_pdu(
self.OTHER_SERVER_NAME, event, state=current_state
)
)
# we should now have 8 backwards extremities.
backwards_extremities = self.get_success(
self.store.db_pool.simple_select_list(
"event_backward_extremities",
keyvalues={"room_id": room_id},
retcols=["event_id"],
)
)
self.assertEqual(len(backwards_extremities), 8)
current_depth = 1
limit = 100
with LoggingContext("receive_pdu"):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment