Skip to content
Snippets Groups Projects
Unverified Commit 6b3ac3b8 authored by Patrick Cloke's avatar Patrick Cloke Committed by GitHub
Browse files

Convert device handler to async/await (#7871)

parent 00e57b75
No related branches found
No related tags found
No related merge requests found
Convert device handler to async/await.
This diff is collapsed.
...@@ -12,10 +12,12 @@ ...@@ -12,10 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect
import logging import logging
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.defer import Deferred, fail, succeed
from twisted.python import failure
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
...@@ -79,6 +81,28 @@ class Distributor(object): ...@@ -79,6 +81,28 @@ class Distributor(object):
run_as_background_process(name, self.signals[name].fire, *args, **kwargs) run_as_background_process(name, self.signals[name].fire, *args, **kwargs)
def maybeAwaitableDeferred(f, *args, **kw):
"""
Invoke a function that may or may not return a Deferred or an Awaitable.
This is a modified version of twisted.internet.defer.maybeDeferred.
"""
try:
result = f(*args, **kw)
except Exception:
return fail(failure.Failure(captureVars=Deferred.debug))
if isinstance(result, Deferred):
return result
# Handle the additional case of an awaitable being returned.
elif inspect.isawaitable(result):
return defer.ensureDeferred(result)
elif isinstance(result, failure.Failure):
return fail(result)
else:
return succeed(result)
class Signal(object): class Signal(object):
"""A Signal is a dispatch point that stores a list of callables as """A Signal is a dispatch point that stores a list of callables as
observers of it. observers of it.
...@@ -122,7 +146,7 @@ class Signal(object): ...@@ -122,7 +146,7 @@ class Signal(object):
), ),
) )
return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb) return maybeAwaitableDeferred(observer, *args, **kwargs).addErrback(eb)
deferreds = [run_in_background(do, o) for o in self.observers] deferreds = [run_in_background(do, o) for o in self.observers]
......
...@@ -142,10 +142,8 @@ class DeviceTestCase(unittest.HomeserverTestCase): ...@@ -142,10 +142,8 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.get_success(self.handler.delete_device(user1, "abc")) self.get_success(self.handler.delete_device(user1, "abc"))
# check the device was deleted # check the device was deleted
res = self.handler.get_device(user1, "abc") self.get_failure(
self.pump() self.handler.get_device(user1, "abc"), synapse.api.errors.NotFoundError
self.assertIsInstance(
self.failureResultOf(res).value, synapse.api.errors.NotFoundError
) )
# we'd like to check the access token was invalidated, but that's a # we'd like to check the access token was invalidated, but that's a
...@@ -180,10 +178,9 @@ class DeviceTestCase(unittest.HomeserverTestCase): ...@@ -180,10 +178,9 @@ class DeviceTestCase(unittest.HomeserverTestCase):
def test_update_unknown_device(self): def test_update_unknown_device(self):
update = {"display_name": "new_display"} update = {"display_name": "new_display"}
res = self.handler.update_device("user_id", "unknown_device_id", update) self.get_failure(
self.pump() self.handler.update_device("user_id", "unknown_device_id", update),
self.assertIsInstance( synapse.api.errors.NotFoundError,
self.failureResultOf(res).value, synapse.api.errors.NotFoundError
) )
def _record_users(self): def _record_users(self):
......
...@@ -334,10 +334,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase): ...@@ -334,10 +334,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
res = None res = None
try: try:
yield self.hs.get_device_handler().check_device_registered( yield defer.ensureDeferred(
user_id=local_user, self.hs.get_device_handler().check_device_registered(
device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", user_id=local_user,
initial_device_display_name="new display name", device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
initial_device_display_name="new display name",
)
) )
except errors.SynapseError as e: except errors.SynapseError as e:
res = e.code res = e.code
......
...@@ -173,7 +173,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ...@@ -173,7 +173,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register a mock on the store so that the incoming update doesn't fail because # Register a mock on the store so that the incoming update doesn't fail because
# we don't share a room with the user. # we don't share a room with the user.
store = self.homeserver.get_datastore() store = self.homeserver.get_datastore()
store.get_rooms_for_user = Mock(return_value=["!someroom:test"]) store.get_rooms_for_user = Mock(return_value=succeed(["!someroom:test"]))
# Manually inject a fake device list update. We need this update to include at # Manually inject a fake device list update. We need this update to include at
# least one prev_id so that the user's device list will need to be retried. # least one prev_id so that the user's device list will need to be retried.
...@@ -218,23 +218,26 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ...@@ -218,23 +218,26 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register mock device list retrieval on the federation client. # Register mock device list retrieval on the federation client.
federation_client = self.homeserver.get_federation_client() federation_client = self.homeserver.get_federation_client()
federation_client.query_user_devices = Mock( federation_client.query_user_devices = Mock(
return_value={ return_value=succeed(
"user_id": remote_user_id, {
"stream_id": 1,
"devices": [],
"master_key": {
"user_id": remote_user_id, "user_id": remote_user_id,
"usage": ["master"], "stream_id": 1,
"keys": {"ed25519:" + remote_master_key: remote_master_key}, "devices": [],
}, "master_key": {
"self_signing_key": { "user_id": remote_user_id,
"user_id": remote_user_id, "usage": ["master"],
"usage": ["self_signing"], "keys": {"ed25519:" + remote_master_key: remote_master_key},
"keys": {
"ed25519:" + remote_self_signing_key: remote_self_signing_key
}, },
}, "self_signing_key": {
} "user_id": remote_user_id,
"usage": ["self_signing"],
"keys": {
"ed25519:"
+ remote_self_signing_key: remote_self_signing_key
},
},
}
)
) )
# Resync the device list. # Resync the device list.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment