Skip to content
Snippets Groups Projects
Commit ceb599e7 authored by Mark Haines's avatar Mark Haines
Browse files

Add tests for redactions

parent 8c82b069
Branches
Tags
No related merge requests found
...@@ -69,6 +69,7 @@ class SlavedEventStore(BaseSlavedStore): ...@@ -69,6 +69,7 @@ class SlavedEventStore(BaseSlavedStore):
"_get_current_state_for_key" "_get_current_state_for_key"
] ]
get_event = DataStore.get_event.__func__
get_current_state = DataStore.get_current_state.__func__ get_current_state = DataStore.get_current_state.__func__
get_current_state_for_key = DataStore.get_current_state_for_key.__func__ get_current_state_for_key = DataStore.get_current_state_for_key.__func__
get_rooms_for_user_where_membership_is = ( get_rooms_for_user_where_membership_is = (
...@@ -103,7 +104,7 @@ class SlavedEventStore(BaseSlavedStore): ...@@ -103,7 +104,7 @@ class SlavedEventStore(BaseSlavedStore):
def stream_positions(self): def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions() result = super(SlavedEventStore, self).stream_positions()
result["events"] = self._stream_id_gen.get_current_token() result["events"] = self._stream_id_gen.get_current_token()
result["backfilled"] = self._backfill_id_gen.get_current_token() result["backfill"] = self._backfill_id_gen.get_current_token()
return result return result
def process_replication(self, result): def process_replication(self, result):
...@@ -145,7 +146,6 @@ class SlavedEventStore(BaseSlavedStore): ...@@ -145,7 +146,6 @@ class SlavedEventStore(BaseSlavedStore):
position = row[0] position = row[0]
internal = json.loads(row[1]) internal = json.loads(row[1])
event_json = json.loads(row[2]) event_json = json.loads(row[2])
event = FrozenEvent(event_json, internal_metadata_dict=internal) event = FrozenEvent(event_json, internal_metadata_dict=internal)
self._invalidate_caches_for_event( self._invalidate_caches_for_event(
event, backfilled, reset_state=position in state_resets event, backfilled, reset_state=position in state_resets
......
...@@ -112,7 +112,7 @@ class StreamIdGenerator(object): ...@@ -112,7 +112,7 @@ class StreamIdGenerator(object):
self._current + self._step * (n + 1), self._current + self._step * (n + 1),
self._step self._step
) )
self._current += n self._current += n * self._step
for next_id in next_ids: for next_id in next_ids:
self._unfinished_ids.append(next_id) self._unfinished_ids.append(next_id)
......
...@@ -51,7 +51,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase): ...@@ -51,7 +51,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
def check(self, method, args, expected_result=None): def check(self, method, args, expected_result=None):
master_result = yield getattr(self.master_store, method)(*args) master_result = yield getattr(self.master_store, method)(*args)
slaved_result = yield getattr(self.slaved_store, method)(*args) slaved_result = yield getattr(self.slaved_store, method)(*args)
self.assertEqual(master_result, slaved_result)
if expected_result is not None: if expected_result is not None:
self.assertEqual(master_result, expected_result) self.assertEqual(master_result, expected_result)
self.assertEqual(slaved_result, expected_result) self.assertEqual(slaved_result, expected_result)
self.assertEqual(master_result, slaved_result)
...@@ -205,13 +205,59 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): ...@@ -205,13 +205,59 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
[join3] [join3]
) )
@defer.inlineCallbacks
def test_redactions(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.persist(type="m.room.member", key=USER_ID, membership="join")
msg = yield self.persist(
type="m.room.message", msgtype="m.text", body="Hello"
)
yield self.replicate()
yield self.check("get_event", [msg.event_id], msg)
redaction = yield self.persist(
type="m.room.redaction", redacts=msg.event_id
)
yield self.replicate()
msg_dict = msg.get_dict()
msg_dict["content"] = {}
msg_dict["unsigned"]["redacted_by"] = redaction.event_id
msg_dict["unsigned"]["redacted_because"] = redaction
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
yield self.check("get_event", [msg.event_id], redacted)
@defer.inlineCallbacks
def test_backfilled_redactions(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.persist(type="m.room.member", key=USER_ID, membership="join")
msg = yield self.persist(
type="m.room.message", msgtype="m.text", body="Hello"
)
yield self.replicate()
yield self.check("get_event", [msg.event_id], msg)
redaction = yield self.persist(
type="m.room.redaction", redacts=msg.event_id, backfill=True
)
yield self.replicate()
msg_dict = msg.get_dict()
msg_dict["content"] = {}
msg_dict["unsigned"]["redacted_by"] = redaction.event_id
msg_dict["unsigned"]["redacted_because"] = redaction
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
yield self.check("get_event", [msg.event_id], redacted)
event_id = 0 event_id = 0
@defer.inlineCallbacks @defer.inlineCallbacks
def persist( def persist(
self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={}, self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={},
state=None, reset_state=False, backfill=False, state=None, reset_state=False, backfill=False,
depth=None, prev_events=[], auth_events=[], prev_state=[], depth=None, prev_events=[], auth_events=[], prev_state=[], redacts=None,
**content **content
): ):
""" """
...@@ -236,6 +282,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): ...@@ -236,6 +282,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
event_dict["state_key"] = key event_dict["state_key"] = key
event_dict["prev_state"] = prev_state event_dict["prev_state"] = prev_state
if redacts is not None:
event_dict["redacts"] = redacts
event = FrozenEvent(event_dict, internal_metadata_dict=internal) event = FrozenEvent(event_dict, internal_metadata_dict=internal)
self.event_id += 1 self.event_id += 1
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment