Skip to content
Snippets Groups Projects
Unverified Commit abedf7d7 authored by Richard van der Hoff's avatar Richard van der Hoff Committed by GitHub
Browse files

Get rid of `_auth_and_persist_event` (#10781)

This is only called in two places, and the code seems much clearer without it.
parent 03caba65
No related branches found
No related tags found
No related merge requests found
Clean up some of the federation event authentication code for clarity.
......@@ -909,12 +909,18 @@ class FederationEventHandler:
context = await self._state_handler.compute_event_context(
event, old_state=state
)
await self._auth_and_persist_event(
origin, event, context, state=state, backfilled=backfilled
context = await self._check_event_auth(
origin,
event,
context,
state=state,
backfilled=backfilled,
)
except AuthError as e:
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
await self._run_push_actions_and_persist_event(event, context, backfilled)
if backfilled:
return
......@@ -1239,51 +1245,6 @@ class FederationEventHandler:
],
)
async def _auth_and_persist_event(
self,
origin: str,
event: EventBase,
context: EventContext,
state: Optional[Iterable[EventBase]] = None,
claimed_auth_event_map: Optional[StateMap[EventBase]] = None,
backfilled: bool = False,
) -> None:
"""
Process an event by performing auth checks and then persisting to the database.
Args:
origin: The host the event originates from.
event: The event itself.
context:
The event context.
state:
The state events used to check the event for soft-fail. If this is
not provided the current state events will be used.
claimed_auth_event_map:
A map of (type, state_key) => event for the event's claimed auth_events.
Possibly incomplete, and possibly including events that are not yet
persisted, or authed, or in the right room.
Only populated when populating outliers.
backfilled: True if the event was backfilled.
"""
# claimed_auth_event_map should be given iff the event is an outlier
assert bool(claimed_auth_event_map) == event.internal_metadata.outlier
context = await self._check_event_auth(
origin,
event,
context,
state=state,
claimed_auth_event_map=claimed_auth_event_map,
backfilled=backfilled,
)
await self._run_push_actions_and_persist_event(event, context, backfilled)
async def _check_event_auth(
self,
origin: str,
......@@ -1558,39 +1519,45 @@ class FederationEventHandler:
event.room_id, [e.event_id for e in remote_auth_chain]
)
for e in remote_auth_chain:
if e.event_id in seen_remotes:
for auth_event in remote_auth_chain:
if auth_event.event_id in seen_remotes:
continue
if e.event_id == event.event_id:
if auth_event.event_id == event.event_id:
continue
try:
auth_ids = e.auth_event_ids()
auth_ids = auth_event.auth_event_ids()
auth = {
(e.type, e.state_key): e
for e in remote_auth_chain
if e.event_id in auth_ids or e.type == EventTypes.Create
}
e.internal_metadata.outlier = True
auth_event.internal_metadata.outlier = True
logger.debug(
"_check_event_auth %s missing_auth: %s",
event.event_id,
e.event_id,
auth_event.event_id,
)
missing_auth_event_context = (
await self._state_handler.compute_event_context(e)
await self._state_handler.compute_event_context(auth_event)
)
await self._auth_and_persist_event(
missing_auth_event_context = await self._check_event_auth(
origin,
e,
auth_event,
missing_auth_event_context,
claimed_auth_event_map=auth,
)
await self.persist_events_and_notify(
event.room_id, [(auth_event, missing_auth_event_context)]
)
if e.event_id in event_auth_events:
auth_events[(e.type, e.state_key)] = e
if auth_event.event_id in event_auth_events:
auth_events[
(auth_event.type, auth_event.state_key)
] = auth_event
except AuthError:
pass
......@@ -1733,10 +1700,13 @@ class FederationEventHandler:
context: The event context.
backfilled: True if the event was backfilled.
"""
# this method should not be called on outliers (those code paths call
# persist_events_and_notify directly.)
assert not event.internal_metadata.outlier
try:
if (
not event.internal_metadata.is_outlier()
and not backfilled
not backfilled
and not context.rejected
and (await self._store.get_min_depth(event.room_id)) <= event.depth
):
......
......@@ -76,9 +76,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
self.handler = self.homeserver.get_federation_handler()
federation_event_handler = self.homeserver.get_federation_event_handler()
federation_event_handler._check_event_auth = lambda origin, event, context, state, claimed_auth_event_map, backfilled: succeed(
context
)
async def _check_event_auth(
origin,
event,
context,
state=None,
claimed_auth_event_map=None,
backfilled=False,
):
return context
federation_event_handler._check_event_auth = _check_event_auth
self.client = self.homeserver.get_federation_client()
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
pdus
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment