Skip to content
Snippets Groups Projects
Unverified Commit 51df675c authored by Andrew Ferrazzutti's avatar Andrew Ferrazzutti Committed by GitHub
Browse files

MSC4140: don't cancel delayed state on own state (#17810)

When a user sends a state event, do not cancel their own delayed events
for the same piece of state.

For context, see [the relevant section in the
MSC](https://github.com/matrix-org/matrix-spec-proposals/blob/a09a883d9a013ac4b6ffddebd7ea87a827d211b9/proposals/4140-delayed-events-futures.md#delayed-state-events-are-cancelled-by-a-more-recent-state-event).
parent 59a15da4
No related branches found
No related tags found
No related merge requests found
Update MSC4140 implementation to no longer cancel a user's own delayed state events with an event type & state key that match a more recent state event sent by that user.
...@@ -191,18 +191,36 @@ class DelayedEventsHandler: ...@@ -191,18 +191,36 @@ class DelayedEventsHandler:
async def _handle_state_deltas(self, deltas: List[StateDelta]) -> None: async def _handle_state_deltas(self, deltas: List[StateDelta]) -> None:
""" """
Process current state deltas to cancel pending delayed events Process current state deltas to cancel other users' pending delayed events
that target the same state. that target the same state.
""" """
for delta in deltas: for delta in deltas:
if delta.event_id is None:
logger.debug(
"Not handling delta for deleted state: %r %r",
delta.event_type,
delta.state_key,
)
continue
logger.debug( logger.debug(
"Handling: %r %r, %s", delta.event_type, delta.state_key, delta.event_id "Handling: %r %r, %s", delta.event_type, delta.state_key, delta.event_id
) )
event = await self._store.get_event(
delta.event_id, check_room_id=delta.room_id
)
sender = UserID.from_string(event.sender)
next_send_ts = await self._store.cancel_delayed_state_events( next_send_ts = await self._store.cancel_delayed_state_events(
room_id=delta.room_id, room_id=delta.room_id,
event_type=delta.event_type, event_type=delta.event_type,
state_key=delta.state_key, state_key=delta.state_key,
not_from_localpart=(
sender.localpart
if sender.domain == self._config.server.server_name
else ""
),
) )
if self._next_send_ts_changed(next_send_ts): if self._next_send_ts_changed(next_send_ts):
......
...@@ -424,25 +424,37 @@ class DelayedEventsStore(SQLBaseStore): ...@@ -424,25 +424,37 @@ class DelayedEventsStore(SQLBaseStore):
room_id: str, room_id: str,
event_type: str, event_type: str,
state_key: str, state_key: str,
not_from_localpart: str,
) -> Optional[Timestamp]: ) -> Optional[Timestamp]:
""" """
Cancels all matching delayed state events, i.e. remove them as long as they haven't been processed. Cancels all matching delayed state events, i.e. remove them as long as they haven't been processed.
Args:
room_id: The room ID to match against.
event_type: The event type to match against.
state_key: The state key to match against.
not_from_localpart: The localpart of a user whose delayed events to not cancel.
If set to the empty string, any users' delayed events may be cancelled.
Returns: The send time of the next delayed event to be sent, if any. Returns: The send time of the next delayed event to be sent, if any.
""" """
def cancel_delayed_state_events_txn( def cancel_delayed_state_events_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> Optional[Timestamp]: ) -> Optional[Timestamp]:
self.db_pool.simple_delete_txn( txn.execute(
txn, """
table="delayed_events", DELETE FROM delayed_events
keyvalues={ WHERE room_id = ? AND event_type = ? AND state_key = ?
"room_id": room_id, AND user_localpart <> ?
"event_type": event_type, AND NOT is_processed
"state_key": state_key, """,
"is_processed": False, (
}, room_id,
event_type,
state_key,
not_from_localpart,
),
) )
return self._get_next_delayed_event_send_ts_txn(txn) return self._get_next_delayed_event_send_ts_txn(txn)
......
...@@ -22,7 +22,8 @@ from parameterized import parameterized ...@@ -22,7 +22,8 @@ from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.rest.client import delayed_events, room, versions from synapse.rest import admin
from synapse.rest.client import delayed_events, login, room, versions
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
...@@ -32,7 +33,6 @@ from tests.unittest import HomeserverTestCase ...@@ -32,7 +33,6 @@ from tests.unittest import HomeserverTestCase
PATH_PREFIX = "/_matrix/client/unstable/org.matrix.msc4140/delayed_events" PATH_PREFIX = "/_matrix/client/unstable/org.matrix.msc4140/delayed_events"
_HS_NAME = "red"
_EVENT_TYPE = "com.example.test" _EVENT_TYPE = "com.example.test"
...@@ -54,23 +54,41 @@ class DelayedEventsUnstableSupportTestCase(HomeserverTestCase): ...@@ -54,23 +54,41 @@ class DelayedEventsUnstableSupportTestCase(HomeserverTestCase):
class DelayedEventsTestCase(HomeserverTestCase): class DelayedEventsTestCase(HomeserverTestCase):
"""Tests getting and managing delayed events.""" """Tests getting and managing delayed events."""
servlets = [delayed_events.register_servlets, room.register_servlets] servlets = [
user_id = f"@sid1:{_HS_NAME}" admin.register_servlets,
delayed_events.register_servlets,
login.register_servlets,
room.register_servlets,
]
def default_config(self) -> JsonDict: def default_config(self) -> JsonDict:
config = super().default_config() config = super().default_config()
config["server_name"] = _HS_NAME
config["max_event_delay_duration"] = "24h" config["max_event_delay_duration"] = "24h"
return config return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user1_user_id = self.register_user("user1", "pass")
self.user1_access_token = self.login("user1", "pass")
self.user2_user_id = self.register_user("user2", "pass")
self.user2_access_token = self.login("user2", "pass")
self.room_id = self.helper.create_room_as( self.room_id = self.helper.create_room_as(
self.user_id, self.user1_user_id,
tok=self.user1_access_token,
extra_content={ extra_content={
"preset": "trusted_private_chat", "preset": "public_chat",
"power_level_content_override": {
"events": {
_EVENT_TYPE: 0,
}
},
}, },
) )
self.helper.join(
room=self.room_id, user=self.user2_user_id, tok=self.user2_access_token
)
def test_delayed_events_empty_on_startup(self) -> None: def test_delayed_events_empty_on_startup(self) -> None:
self.assertListEqual([], self._get_delayed_events()) self.assertListEqual([], self._get_delayed_events())
...@@ -85,6 +103,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -85,6 +103,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
{ {
setter_key: setter_expected, setter_key: setter_expected,
}, },
self.user1_access_token,
) )
self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
events = self._get_delayed_events() events = self._get_delayed_events()
...@@ -94,7 +113,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -94,7 +113,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
self.helper.get_state( self.helper.get_state(
self.room_id, self.room_id,
_EVENT_TYPE, _EVENT_TYPE,
"", self.user1_access_token,
state_key=state_key, state_key=state_key,
expect_code=HTTPStatus.NOT_FOUND, expect_code=HTTPStatus.NOT_FOUND,
) )
...@@ -104,7 +123,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -104,7 +123,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
content = self.helper.get_state( content = self.helper.get_state(
self.room_id, self.room_id,
_EVENT_TYPE, _EVENT_TYPE,
"", self.user1_access_token,
state_key=state_key, state_key=state_key,
) )
self.assertEqual(setter_expected, content.get(setter_key), content) self.assertEqual(setter_expected, content.get(setter_key), content)
...@@ -113,7 +132,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -113,7 +132,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
{"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}} {"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
) )
def test_get_delayed_events_ratelimit(self) -> None: def test_get_delayed_events_ratelimit(self) -> None:
args = ("GET", PATH_PREFIX) args = ("GET", PATH_PREFIX, b"", self.user1_access_token)
channel = self.make_request(*args) channel = self.make_request(*args)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
...@@ -123,7 +142,9 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -123,7 +142,9 @@ class DelayedEventsTestCase(HomeserverTestCase):
# Add the current user to the ratelimit overrides, allowing them no ratelimiting. # Add the current user to the ratelimit overrides, allowing them no ratelimiting.
self.get_success( self.get_success(
self.hs.get_datastores().main.set_ratelimit_for_user(self.user_id, 0, 0) self.hs.get_datastores().main.set_ratelimit_for_user(
self.user1_user_id, 0, 0
)
) )
# Test that the request isn't ratelimited anymore. # Test that the request isn't ratelimited anymore.
...@@ -134,6 +155,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -134,6 +155,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"POST", "POST",
f"{PATH_PREFIX}/", f"{PATH_PREFIX}/",
access_token=self.user1_access_token,
) )
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result) self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result)
...@@ -141,6 +163,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -141,6 +163,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"POST", "POST",
f"{PATH_PREFIX}/abc", f"{PATH_PREFIX}/abc",
access_token=self.user1_access_token,
) )
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
self.assertEqual( self.assertEqual(
...@@ -153,6 +176,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -153,6 +176,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
"POST", "POST",
f"{PATH_PREFIX}/abc", f"{PATH_PREFIX}/abc",
{}, {},
self.user1_access_token,
) )
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
self.assertEqual( self.assertEqual(
...@@ -165,6 +189,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -165,6 +189,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
"POST", "POST",
f"{PATH_PREFIX}/abc", f"{PATH_PREFIX}/abc",
{"action": "oops"}, {"action": "oops"},
self.user1_access_token,
) )
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
self.assertEqual( self.assertEqual(
...@@ -178,6 +203,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -178,6 +203,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
"POST", "POST",
f"{PATH_PREFIX}/abc", f"{PATH_PREFIX}/abc",
{"action": action}, {"action": action},
self.user1_access_token,
) )
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result) self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result)
...@@ -192,6 +218,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -192,6 +218,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
{ {
setter_key: setter_expected, setter_key: setter_expected,
}, },
self.user1_access_token,
) )
self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
delay_id = channel.json_body.get("delay_id") delay_id = channel.json_body.get("delay_id")
...@@ -205,7 +232,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -205,7 +232,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
self.helper.get_state( self.helper.get_state(
self.room_id, self.room_id,
_EVENT_TYPE, _EVENT_TYPE,
"", self.user1_access_token,
state_key=state_key, state_key=state_key,
expect_code=HTTPStatus.NOT_FOUND, expect_code=HTTPStatus.NOT_FOUND,
) )
...@@ -214,6 +241,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -214,6 +241,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
"POST", "POST",
f"{PATH_PREFIX}/{delay_id}", f"{PATH_PREFIX}/{delay_id}",
{"action": "cancel"}, {"action": "cancel"},
self.user1_access_token,
) )
self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertListEqual([], self._get_delayed_events()) self.assertListEqual([], self._get_delayed_events())
...@@ -222,7 +250,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -222,7 +250,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
content = self.helper.get_state( content = self.helper.get_state(
self.room_id, self.room_id,
_EVENT_TYPE, _EVENT_TYPE,
"", self.user1_access_token,
state_key=state_key, state_key=state_key,
expect_code=HTTPStatus.NOT_FOUND, expect_code=HTTPStatus.NOT_FOUND,
) )
...@@ -237,6 +265,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -237,6 +265,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
"POST", "POST",
_get_path_for_delayed_send(self.room_id, _EVENT_TYPE, 100000), _get_path_for_delayed_send(self.room_id, _EVENT_TYPE, 100000),
{}, {},
self.user1_access_token,
) )
self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
delay_id = channel.json_body.get("delay_id") delay_id = channel.json_body.get("delay_id")
...@@ -247,6 +276,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -247,6 +276,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
"POST", "POST",
f"{PATH_PREFIX}/{delay_ids.pop(0)}", f"{PATH_PREFIX}/{delay_ids.pop(0)}",
{"action": "cancel"}, {"action": "cancel"},
self.user1_access_token,
) )
self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
...@@ -254,13 +284,16 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -254,13 +284,16 @@ class DelayedEventsTestCase(HomeserverTestCase):
"POST", "POST",
f"{PATH_PREFIX}/{delay_ids.pop(0)}", f"{PATH_PREFIX}/{delay_ids.pop(0)}",
{"action": "cancel"}, {"action": "cancel"},
self.user1_access_token,
) )
channel = self.make_request(*args) channel = self.make_request(*args)
self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result) self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
# Add the current user to the ratelimit overrides, allowing them no ratelimiting. # Add the current user to the ratelimit overrides, allowing them no ratelimiting.
self.get_success( self.get_success(
self.hs.get_datastores().main.set_ratelimit_for_user(self.user_id, 0, 0) self.hs.get_datastores().main.set_ratelimit_for_user(
self.user1_user_id, 0, 0
)
) )
# Test that the request isn't ratelimited anymore. # Test that the request isn't ratelimited anymore.
...@@ -278,6 +311,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -278,6 +311,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
{ {
setter_key: setter_expected, setter_key: setter_expected,
}, },
self.user1_access_token,
) )
self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
delay_id = channel.json_body.get("delay_id") delay_id = channel.json_body.get("delay_id")
...@@ -291,7 +325,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -291,7 +325,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
self.helper.get_state( self.helper.get_state(
self.room_id, self.room_id,
_EVENT_TYPE, _EVENT_TYPE,
"", self.user1_access_token,
state_key=state_key, state_key=state_key,
expect_code=HTTPStatus.NOT_FOUND, expect_code=HTTPStatus.NOT_FOUND,
) )
...@@ -300,13 +334,14 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -300,13 +334,14 @@ class DelayedEventsTestCase(HomeserverTestCase):
"POST", "POST",
f"{PATH_PREFIX}/{delay_id}", f"{PATH_PREFIX}/{delay_id}",
{"action": "send"}, {"action": "send"},
self.user1_access_token,
) )
self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertListEqual([], self._get_delayed_events()) self.assertListEqual([], self._get_delayed_events())
content = self.helper.get_state( content = self.helper.get_state(
self.room_id, self.room_id,
_EVENT_TYPE, _EVENT_TYPE,
"", self.user1_access_token,
state_key=state_key, state_key=state_key,
) )
self.assertEqual(setter_expected, content.get(setter_key), content) self.assertEqual(setter_expected, content.get(setter_key), content)
...@@ -319,6 +354,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -319,6 +354,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
"POST", "POST",
_get_path_for_delayed_send(self.room_id, _EVENT_TYPE, 100000), _get_path_for_delayed_send(self.room_id, _EVENT_TYPE, 100000),
{}, {},
self.user1_access_token,
) )
self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
delay_id = channel.json_body.get("delay_id") delay_id = channel.json_body.get("delay_id")
...@@ -329,6 +365,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -329,6 +365,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
"POST", "POST",
f"{PATH_PREFIX}/{delay_ids.pop(0)}", f"{PATH_PREFIX}/{delay_ids.pop(0)}",
{"action": "send"}, {"action": "send"},
self.user1_access_token,
) )
self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
...@@ -336,13 +373,16 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -336,13 +373,16 @@ class DelayedEventsTestCase(HomeserverTestCase):
"POST", "POST",
f"{PATH_PREFIX}/{delay_ids.pop(0)}", f"{PATH_PREFIX}/{delay_ids.pop(0)}",
{"action": "send"}, {"action": "send"},
self.user1_access_token,
) )
channel = self.make_request(*args) channel = self.make_request(*args)
self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result) self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
# Add the current user to the ratelimit overrides, allowing them no ratelimiting. # Add the current user to the ratelimit overrides, allowing them no ratelimiting.
self.get_success( self.get_success(
self.hs.get_datastores().main.set_ratelimit_for_user(self.user_id, 0, 0) self.hs.get_datastores().main.set_ratelimit_for_user(
self.user1_user_id, 0, 0
)
) )
# Test that the request isn't ratelimited anymore. # Test that the request isn't ratelimited anymore.
...@@ -360,6 +400,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -360,6 +400,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
{ {
setter_key: setter_expected, setter_key: setter_expected,
}, },
self.user1_access_token,
) )
self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
delay_id = channel.json_body.get("delay_id") delay_id = channel.json_body.get("delay_id")
...@@ -373,7 +414,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -373,7 +414,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
self.helper.get_state( self.helper.get_state(
self.room_id, self.room_id,
_EVENT_TYPE, _EVENT_TYPE,
"", self.user1_access_token,
state_key=state_key, state_key=state_key,
expect_code=HTTPStatus.NOT_FOUND, expect_code=HTTPStatus.NOT_FOUND,
) )
...@@ -382,6 +423,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -382,6 +423,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
"POST", "POST",
f"{PATH_PREFIX}/{delay_id}", f"{PATH_PREFIX}/{delay_id}",
{"action": "restart"}, {"action": "restart"},
self.user1_access_token,
) )
self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
...@@ -393,7 +435,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -393,7 +435,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
self.helper.get_state( self.helper.get_state(
self.room_id, self.room_id,
_EVENT_TYPE, _EVENT_TYPE,
"", self.user1_access_token,
state_key=state_key, state_key=state_key,
expect_code=HTTPStatus.NOT_FOUND, expect_code=HTTPStatus.NOT_FOUND,
) )
...@@ -403,7 +445,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -403,7 +445,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
content = self.helper.get_state( content = self.helper.get_state(
self.room_id, self.room_id,
_EVENT_TYPE, _EVENT_TYPE,
"", self.user1_access_token,
state_key=state_key, state_key=state_key,
) )
self.assertEqual(setter_expected, content.get(setter_key), content) self.assertEqual(setter_expected, content.get(setter_key), content)
...@@ -418,6 +460,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -418,6 +460,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
"POST", "POST",
_get_path_for_delayed_send(self.room_id, _EVENT_TYPE, 100000), _get_path_for_delayed_send(self.room_id, _EVENT_TYPE, 100000),
{}, {},
self.user1_access_token,
) )
self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
delay_id = channel.json_body.get("delay_id") delay_id = channel.json_body.get("delay_id")
...@@ -428,6 +471,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -428,6 +471,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
"POST", "POST",
f"{PATH_PREFIX}/{delay_ids.pop(0)}", f"{PATH_PREFIX}/{delay_ids.pop(0)}",
{"action": "restart"}, {"action": "restart"},
self.user1_access_token,
) )
self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
...@@ -435,21 +479,66 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -435,21 +479,66 @@ class DelayedEventsTestCase(HomeserverTestCase):
"POST", "POST",
f"{PATH_PREFIX}/{delay_ids.pop(0)}", f"{PATH_PREFIX}/{delay_ids.pop(0)}",
{"action": "restart"}, {"action": "restart"},
self.user1_access_token,
) )
channel = self.make_request(*args) channel = self.make_request(*args)
self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result) self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
# Add the current user to the ratelimit overrides, allowing them no ratelimiting. # Add the current user to the ratelimit overrides, allowing them no ratelimiting.
self.get_success( self.get_success(
self.hs.get_datastores().main.set_ratelimit_for_user(self.user_id, 0, 0) self.hs.get_datastores().main.set_ratelimit_for_user(
self.user1_user_id, 0, 0
)
) )
# Test that the request isn't ratelimited anymore. # Test that the request isn't ratelimited anymore.
channel = self.make_request(*args) channel = self.make_request(*args)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
def test_delayed_state_events_are_cancelled_by_more_recent_state(self) -> None: def test_delayed_state_is_not_cancelled_by_new_state_from_same_user(
state_key = "to_be_cancelled" self,
) -> None:
state_key = "to_not_be_cancelled_by_same_user"
setter_key = "setter"
setter_expected = "on_timeout"
channel = self.make_request(
"PUT",
_get_path_for_delayed_state(self.room_id, _EVENT_TYPE, state_key, 900),
{
setter_key: setter_expected,
},
self.user1_access_token,
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
events = self._get_delayed_events()
self.assertEqual(1, len(events), events)
self.helper.send_state(
self.room_id,
_EVENT_TYPE,
{
setter_key: "manual",
},
self.user1_access_token,
state_key=state_key,
)
events = self._get_delayed_events()
self.assertEqual(1, len(events), events)
self.reactor.advance(1)
content = self.helper.get_state(
self.room_id,
_EVENT_TYPE,
self.user1_access_token,
state_key=state_key,
)
self.assertEqual(setter_expected, content.get(setter_key), content)
def test_delayed_state_is_cancelled_by_new_state_from_other_user(
self,
) -> None:
state_key = "to_be_cancelled_by_other_user"
setter_key = "setter" setter_key = "setter"
channel = self.make_request( channel = self.make_request(
...@@ -458,19 +547,20 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -458,19 +547,20 @@ class DelayedEventsTestCase(HomeserverTestCase):
{ {
setter_key: "on_timeout", setter_key: "on_timeout",
}, },
self.user1_access_token,
) )
self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
events = self._get_delayed_events() events = self._get_delayed_events()
self.assertEqual(1, len(events), events) self.assertEqual(1, len(events), events)
setter_expected = "manual" setter_expected = "other_user"
self.helper.send_state( self.helper.send_state(
self.room_id, self.room_id,
_EVENT_TYPE, _EVENT_TYPE,
{ {
setter_key: setter_expected, setter_key: setter_expected,
}, },
None, self.user2_access_token,
state_key=state_key, state_key=state_key,
) )
self.assertListEqual([], self._get_delayed_events()) self.assertListEqual([], self._get_delayed_events())
...@@ -479,7 +569,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -479,7 +569,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
content = self.helper.get_state( content = self.helper.get_state(
self.room_id, self.room_id,
_EVENT_TYPE, _EVENT_TYPE,
"", self.user1_access_token,
state_key=state_key, state_key=state_key,
) )
self.assertEqual(setter_expected, content.get(setter_key), content) self.assertEqual(setter_expected, content.get(setter_key), content)
...@@ -488,6 +578,7 @@ class DelayedEventsTestCase(HomeserverTestCase): ...@@ -488,6 +578,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
PATH_PREFIX, PATH_PREFIX,
access_token=self.user1_access_token,
) )
self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment