Skip to content
Snippets Groups Projects
Commit 4e97eb89 authored by Richard van der Hoff's avatar Richard van der Hoff
Browse files

Handle loops in redaction events

parent 448bcfd0
No related branches found
No related tags found
No related merge requests found
...@@ -483,7 +483,8 @@ class EventsWorkerStore(SQLBaseStore): ...@@ -483,7 +483,8 @@ class EventsWorkerStore(SQLBaseStore):
if events_to_fetch: if events_to_fetch:
logger.debug("Also fetching redaction events %s", events_to_fetch) logger.debug("Also fetching redaction events %s", events_to_fetch)
result_map = {} # build a map from event_id to EventBase
event_map = {}
for event_id, row in fetched_events.items(): for event_id, row in fetched_events.items():
if not row: if not row:
continue continue
...@@ -494,14 +495,37 @@ class EventsWorkerStore(SQLBaseStore): ...@@ -494,14 +495,37 @@ class EventsWorkerStore(SQLBaseStore):
if not allow_rejected and rejected_reason: if not allow_rejected and rejected_reason:
continue continue
cache_entry = yield self._get_event_from_row( d = json.loads(row["json"])
row["internal_metadata"], internal_metadata = json.loads(row["internal_metadata"])
row["json"],
row["redactions"], format_version = row["format_version"]
rejected_reason=row["rejected_reason"], if format_version is None:
format_version=row["format_version"], # This means that we stored the event before we had the concept
# of a event format version, so it must be a V1 event.
format_version = EventFormatVersions.V1
original_ev = event_type_from_format_version(format_version)(
event_dict=d,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
) )
event_map[event_id] = original_ev
# finally, we can decide whether each one nededs redacting, and build
# the cache entries.
result_map = {}
for event_id, original_ev in event_map.items():
redactions = fetched_events[event_id]["redactions"]
redacted_event = self._maybe_redact_event_row(
original_ev, redactions, event_map
)
cache_entry = _EventCacheEntry(
event=original_ev, redacted_event=redacted_event
)
self._get_event_cache.prefill((event_id,), cache_entry)
result_map[event_id] = cache_entry result_map[event_id] = cache_entry
return result_map return result_map
...@@ -615,50 +639,7 @@ class EventsWorkerStore(SQLBaseStore): ...@@ -615,50 +639,7 @@ class EventsWorkerStore(SQLBaseStore):
return event_dict return event_dict
@defer.inlineCallbacks def _maybe_redact_event_row(self, original_ev, redactions, event_map):
def _get_event_from_row(
self, internal_metadata, js, redactions, format_version, rejected_reason=None
):
"""Parse an event row which has been read from the database
Args:
internal_metadata (str): json-encoded internal_metadata column
js (str): json-encoded event body from event_json
redactions (list[str]): a list of the events which claim to have redacted
this event, from the redactions table
format_version: (str): the 'format_version' column
rejected_reason (str|None): the reason this event was rejected, if any
Returns:
_EventCacheEntry
"""
with Measure(self._clock, "_get_event_from_row"):
d = json.loads(js)
internal_metadata = json.loads(internal_metadata)
if format_version is None:
# This means that we stored the event before we had the concept
# of a event format version, so it must be a V1 event.
format_version = EventFormatVersions.V1
original_ev = event_type_from_format_version(format_version)(
event_dict=d,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
redacted_event = yield self._maybe_redact_event_row(original_ev, redactions)
cache_entry = _EventCacheEntry(
event=original_ev, redacted_event=redacted_event
)
self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
return cache_entry
@defer.inlineCallbacks
def _maybe_redact_event_row(self, original_ev, redactions):
"""Given an event object and a list of possible redacting event ids, """Given an event object and a list of possible redacting event ids,
determine whether to honour any of those redactions and if so return a redacted determine whether to honour any of those redactions and if so return a redacted
event. event.
...@@ -666,6 +647,8 @@ class EventsWorkerStore(SQLBaseStore): ...@@ -666,6 +647,8 @@ class EventsWorkerStore(SQLBaseStore):
Args: Args:
original_ev (EventBase): original_ev (EventBase):
redactions (iterable[str]): list of event ids of potential redaction events redactions (iterable[str]): list of event ids of potential redaction events
event_map (dict[str, EventBase]): other events which have been fetched, in
which we can look up the redaaction events. Map from event id to event.
Returns: Returns:
Deferred[EventBase|None]: if the event should be redacted, a pruned Deferred[EventBase|None]: if the event should be redacted, a pruned
...@@ -675,15 +658,9 @@ class EventsWorkerStore(SQLBaseStore): ...@@ -675,15 +658,9 @@ class EventsWorkerStore(SQLBaseStore):
# we choose to ignore redactions of m.room.create events. # we choose to ignore redactions of m.room.create events.
return None return None
if original_ev.type == "m.room.redaction":
# ... and redaction events
return None
redaction_map = yield self._get_events_from_cache_or_db(redactions)
for redaction_id in redactions: for redaction_id in redactions:
redaction_entry = redaction_map.get(redaction_id) redaction_event = event_map.get(redaction_id)
if not redaction_entry: if not redaction_event or redaction_event.rejected_reason:
# we don't have the redaction event, or the redaction event was not # we don't have the redaction event, or the redaction event was not
# authorized. # authorized.
logger.debug( logger.debug(
...@@ -693,7 +670,6 @@ class EventsWorkerStore(SQLBaseStore): ...@@ -693,7 +670,6 @@ class EventsWorkerStore(SQLBaseStore):
) )
continue continue
redaction_event = redaction_entry.event
if redaction_event.room_id != original_ev.room_id: if redaction_event.room_id != original_ev.room_id:
logger.debug( logger.debug(
"%s was redacted by %s but redaction was in a different room!", "%s was redacted by %s but redaction was in a different room!",
......
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
from mock import Mock from mock import Mock
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.types import RoomID, UserID from synapse.types import RoomID, UserID
...@@ -216,3 +218,71 @@ class RedactionTestCase(unittest.HomeserverTestCase): ...@@ -216,3 +218,71 @@ class RedactionTestCase(unittest.HomeserverTestCase):
}, },
event.unsigned["redacted_because"], event.unsigned["redacted_because"],
) )
def test_circular_redaction(self):
redaction_event_id1 = "$redaction1_id:test"
redaction_event_id2 = "$redaction2_id:test"
class EventIdManglingBuilder:
def __init__(self, base_builder, event_id):
self._base_builder = base_builder
self._event_id = event_id
@defer.inlineCallbacks
def build(self, prev_event_ids):
built_event = yield self._base_builder.build(prev_event_ids)
built_event.event_id = self._event_id
built_event._event_dict["event_id"] = self._event_id
return built_event
@property
def room_id(self):
return self._base_builder.room_id
event_1, context_1 = self.get_success(
self.event_creation_handler.create_new_client_event(
EventIdManglingBuilder(
self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
"type": EventTypes.Redaction,
"sender": self.u_alice.to_string(),
"room_id": self.room1.to_string(),
"content": {"reason": "test"},
"redacts": redaction_event_id2,
},
),
redaction_event_id1,
)
)
)
self.get_success(self.store.persist_event(event_1, context_1))
event_2, context_2 = self.get_success(
self.event_creation_handler.create_new_client_event(
EventIdManglingBuilder(
self.event_builder_factory.for_room_version(
RoomVersions.V1,
{
"type": EventTypes.Redaction,
"sender": self.u_alice.to_string(),
"room_id": self.room1.to_string(),
"content": {"reason": "test"},
"redacts": redaction_event_id1,
},
),
redaction_event_id2,
)
)
)
self.get_success(self.store.persist_event(event_2, context_2))
# fetch one of the redactions
fetched = self.get_success(self.store.get_event(redaction_event_id1))
# it should have been redacted
self.assertEqual(fetched.unsigned["redacted_by"], redaction_event_id2)
self.assertEqual(
fetched.unsigned["redacted_because"].event_id, redaction_event_id2
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment