Skip to content
Snippets Groups Projects
Unverified Commit 98c4e35e authored by Patrick Cloke's avatar Patrick Cloke Committed by GitHub
Browse files

Convert the device message and pagination handlers to async/await. (#7678)

parent 03619324
No related branches found
No related tags found
No related merge requests found
Convert the device message and pagination handlers to async/await.
...@@ -18,8 +18,6 @@ from typing import Any, Dict ...@@ -18,8 +18,6 @@ from typing import Any, Dict
from canonicaljson import json from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.logging.context import run_in_background from synapse.logging.context import run_in_background
from synapse.logging.opentracing import ( from synapse.logging.opentracing import (
...@@ -51,8 +49,7 @@ class DeviceMessageHandler(object): ...@@ -51,8 +49,7 @@ class DeviceMessageHandler(object):
self._device_list_updater = hs.get_device_handler().device_list_updater self._device_list_updater = hs.get_device_handler().device_list_updater
@defer.inlineCallbacks async def on_direct_to_device_edu(self, origin, content):
def on_direct_to_device_edu(self, origin, content):
local_messages = {} local_messages = {}
sender_user_id = content["sender"] sender_user_id = content["sender"]
if origin != get_domain_from_id(sender_user_id): if origin != get_domain_from_id(sender_user_id):
...@@ -82,11 +79,11 @@ class DeviceMessageHandler(object): ...@@ -82,11 +79,11 @@ class DeviceMessageHandler(object):
} }
local_messages[user_id] = messages_by_device local_messages[user_id] = messages_by_device
yield self._check_for_unknown_devices( await self._check_for_unknown_devices(
message_type, sender_user_id, by_device message_type, sender_user_id, by_device
) )
stream_id = yield self.store.add_messages_from_remote_to_device_inbox( stream_id = await self.store.add_messages_from_remote_to_device_inbox(
origin, message_id, local_messages origin, message_id, local_messages
) )
...@@ -94,14 +91,13 @@ class DeviceMessageHandler(object): ...@@ -94,14 +91,13 @@ class DeviceMessageHandler(object):
"to_device_key", stream_id, users=local_messages.keys() "to_device_key", stream_id, users=local_messages.keys()
) )
@defer.inlineCallbacks async def _check_for_unknown_devices(
def _check_for_unknown_devices(
self, self,
message_type: str, message_type: str,
sender_user_id: str, sender_user_id: str,
by_device: Dict[str, Dict[str, Any]], by_device: Dict[str, Dict[str, Any]],
): ):
"""Checks inbound device messages for unkown remote devices, and if """Checks inbound device messages for unknown remote devices, and if
found marks the remote cache for the user as stale. found marks the remote cache for the user as stale.
""" """
...@@ -115,7 +111,7 @@ class DeviceMessageHandler(object): ...@@ -115,7 +111,7 @@ class DeviceMessageHandler(object):
requesting_device_ids.add(device_id) requesting_device_ids.add(device_id)
# Check if we are tracking the devices of the remote user. # Check if we are tracking the devices of the remote user.
room_ids = yield self.store.get_rooms_for_user(sender_user_id) room_ids = await self.store.get_rooms_for_user(sender_user_id)
if not room_ids: if not room_ids:
logger.info( logger.info(
"Received device message from remote device we don't" "Received device message from remote device we don't"
...@@ -127,7 +123,7 @@ class DeviceMessageHandler(object): ...@@ -127,7 +123,7 @@ class DeviceMessageHandler(object):
# If we are tracking check that we know about the sending # If we are tracking check that we know about the sending
# devices. # devices.
cached_devices = yield self.store.get_cached_devices_for_user(sender_user_id) cached_devices = await self.store.get_cached_devices_for_user(sender_user_id)
unknown_devices = requesting_device_ids - set(cached_devices) unknown_devices = requesting_device_ids - set(cached_devices)
if unknown_devices: if unknown_devices:
...@@ -136,15 +132,14 @@ class DeviceMessageHandler(object): ...@@ -136,15 +132,14 @@ class DeviceMessageHandler(object):
sender_user_id, sender_user_id,
unknown_devices, unknown_devices,
) )
yield self.store.mark_remote_user_device_cache_as_stale(sender_user_id) await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
# Immediately attempt a resync in the background # Immediately attempt a resync in the background
run_in_background( run_in_background(
self._device_list_updater.user_device_resync, sender_user_id self._device_list_updater.user_device_resync, sender_user_id
) )
@defer.inlineCallbacks async def send_device_message(self, sender_user_id, message_type, messages):
def send_device_message(self, sender_user_id, message_type, messages):
set_tag("number_of_messages", len(messages)) set_tag("number_of_messages", len(messages))
set_tag("sender", sender_user_id) set_tag("sender", sender_user_id)
local_messages = {} local_messages = {}
...@@ -183,7 +178,7 @@ class DeviceMessageHandler(object): ...@@ -183,7 +178,7 @@ class DeviceMessageHandler(object):
} }
log_kv({"local_messages": local_messages}) log_kv({"local_messages": local_messages})
stream_id = yield self.store.add_messages_to_device_inbox( stream_id = await self.store.add_messages_to_device_inbox(
local_messages, remote_edu_contents local_messages, remote_edu_contents
) )
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
# limitations under the License. # limitations under the License.
import logging import logging
from twisted.internet import defer
from twisted.python.failure import Failure from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
...@@ -97,8 +96,7 @@ class PaginationHandler(object): ...@@ -97,8 +96,7 @@ class PaginationHandler(object):
job["longest_max_lifetime"], job["longest_max_lifetime"],
) )
@defer.inlineCallbacks async def purge_history_for_rooms_in_range(self, min_ms, max_ms):
def purge_history_for_rooms_in_range(self, min_ms, max_ms):
"""Purge outdated events from rooms within the given retention range. """Purge outdated events from rooms within the given retention range.
If a default retention policy is defined in the server's configuration and its If a default retention policy is defined in the server's configuration and its
...@@ -137,7 +135,7 @@ class PaginationHandler(object): ...@@ -137,7 +135,7 @@ class PaginationHandler(object):
include_null, include_null,
) )
rooms = yield self.store.get_rooms_for_retention_period_in_range( rooms = await self.store.get_rooms_for_retention_period_in_range(
min_ms, max_ms, include_null min_ms, max_ms, include_null
) )
...@@ -165,9 +163,9 @@ class PaginationHandler(object): ...@@ -165,9 +163,9 @@ class PaginationHandler(object):
# Figure out what token we should start purging at. # Figure out what token we should start purging at.
ts = self.clock.time_msec() - max_lifetime ts = self.clock.time_msec() - max_lifetime
stream_ordering = yield self.store.find_first_stream_ordering_after_ts(ts) stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts)
r = yield self.store.get_room_event_before_stream_ordering( r = await self.store.get_room_event_before_stream_ordering(
room_id, stream_ordering, room_id, stream_ordering,
) )
if not r: if not r:
...@@ -227,8 +225,7 @@ class PaginationHandler(object): ...@@ -227,8 +225,7 @@ class PaginationHandler(object):
) )
return purge_id return purge_id
@defer.inlineCallbacks async def _purge_history(self, purge_id, room_id, token, delete_local_events):
def _purge_history(self, purge_id, room_id, token, delete_local_events):
"""Carry out a history purge on a room. """Carry out a history purge on a room.
Args: Args:
...@@ -237,14 +234,11 @@ class PaginationHandler(object): ...@@ -237,14 +234,11 @@ class PaginationHandler(object):
token (str): topological token to delete events before token (str): topological token to delete events before
delete_local_events (bool): True to delete local events as well as delete_local_events (bool): True to delete local events as well as
remote ones remote ones
Returns:
Deferred
""" """
self._purges_in_progress_by_room.add(room_id) self._purges_in_progress_by_room.add(room_id)
try: try:
with (yield self.pagination_lock.write(room_id)): with await self.pagination_lock.write(room_id):
yield self.storage.purge_events.purge_history( await self.storage.purge_events.purge_history(
room_id, token, delete_local_events room_id, token, delete_local_events
) )
logger.info("[purge] complete") logger.info("[purge] complete")
...@@ -282,9 +276,7 @@ class PaginationHandler(object): ...@@ -282,9 +276,7 @@ class PaginationHandler(object):
await self.store.get_room_version_id(room_id) await self.store.get_room_version_id(room_id)
# first check that we have no users in this room # first check that we have no users in this room
joined = await defer.maybeDeferred( joined = await self.store.is_host_joined(room_id, self._server_name)
self.store.is_host_joined, room_id, self._server_name
)
if joined: if joined:
raise SynapseError(400, "Users are still joined to this room") raise SynapseError(400, "Users are still joined to this room")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment