Skip to content
Snippets Groups Projects
Commit 98483890 authored by Brendan Abolivier's avatar Brendan Abolivier
Browse files

Merge branch 'develop' of github.com:matrix-org/synapse into develop

parents b3b2038b ef884f6d
No related branches found
No related tags found
No related merge requests found
Convert the identity handler to async/await.
...@@ -25,7 +25,6 @@ from signedjson.key import decode_verify_key_bytes ...@@ -25,7 +25,6 @@ from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
from twisted.internet import defer
from twisted.internet.error import TimeoutError from twisted.internet.error import TimeoutError
from synapse.api.errors import ( from synapse.api.errors import (
...@@ -60,8 +59,7 @@ class IdentityHandler(BaseHandler): ...@@ -60,8 +59,7 @@ class IdentityHandler(BaseHandler):
self.federation_http_client = hs.get_http_client() self.federation_http_client = hs.get_http_client()
self.hs = hs self.hs = hs
@defer.inlineCallbacks async def threepid_from_creds(self, id_server, creds):
def threepid_from_creds(self, id_server, creds):
""" """
Retrieve and validate a threepid identifier from a "credentials" dictionary against a Retrieve and validate a threepid identifier from a "credentials" dictionary against a
given identity server given identity server
...@@ -97,7 +95,7 @@ class IdentityHandler(BaseHandler): ...@@ -97,7 +95,7 @@ class IdentityHandler(BaseHandler):
url = id_server + "/_matrix/identity/api/v1/3pid/getValidated3pid" url = id_server + "/_matrix/identity/api/v1/3pid/getValidated3pid"
try: try:
data = yield self.http_client.get_json(url, query_params) data = await self.http_client.get_json(url, query_params)
except TimeoutError: except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server") raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e: except HttpResponseException as e:
...@@ -120,8 +118,7 @@ class IdentityHandler(BaseHandler): ...@@ -120,8 +118,7 @@ class IdentityHandler(BaseHandler):
logger.info("%s reported non-validated threepid: %s", id_server, creds) logger.info("%s reported non-validated threepid: %s", id_server, creds)
return None return None
@defer.inlineCallbacks async def bind_threepid(
def bind_threepid(
self, client_secret, sid, mxid, id_server, id_access_token=None, use_v2=True self, client_secret, sid, mxid, id_server, id_access_token=None, use_v2=True
): ):
"""Bind a 3PID to an identity server """Bind a 3PID to an identity server
...@@ -161,12 +158,12 @@ class IdentityHandler(BaseHandler): ...@@ -161,12 +158,12 @@ class IdentityHandler(BaseHandler):
try: try:
# Use the blacklisting http client as this call is only to identity servers # Use the blacklisting http client as this call is only to identity servers
# provided by a client # provided by a client
data = yield self.blacklisting_http_client.post_json_get_json( data = await self.blacklisting_http_client.post_json_get_json(
bind_url, bind_data, headers=headers bind_url, bind_data, headers=headers
) )
# Remember where we bound the threepid # Remember where we bound the threepid
yield self.store.add_user_bound_threepid( await self.store.add_user_bound_threepid(
user_id=mxid, user_id=mxid,
medium=data["medium"], medium=data["medium"],
address=data["address"], address=data["address"],
...@@ -185,13 +182,12 @@ class IdentityHandler(BaseHandler): ...@@ -185,13 +182,12 @@ class IdentityHandler(BaseHandler):
return data return data
logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url) logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)
res = yield self.bind_threepid( res = await self.bind_threepid(
client_secret, sid, mxid, id_server, id_access_token, use_v2=False client_secret, sid, mxid, id_server, id_access_token, use_v2=False
) )
return res return res
@defer.inlineCallbacks async def try_unbind_threepid(self, mxid, threepid):
def try_unbind_threepid(self, mxid, threepid):
"""Attempt to remove a 3PID from an identity server, or if one is not provided, all """Attempt to remove a 3PID from an identity server, or if one is not provided, all
identity servers we're aware the binding is present on identity servers we're aware the binding is present on
...@@ -211,7 +207,7 @@ class IdentityHandler(BaseHandler): ...@@ -211,7 +207,7 @@ class IdentityHandler(BaseHandler):
if threepid.get("id_server"): if threepid.get("id_server"):
id_servers = [threepid["id_server"]] id_servers = [threepid["id_server"]]
else: else:
id_servers = yield self.store.get_id_servers_user_bound( id_servers = await self.store.get_id_servers_user_bound(
user_id=mxid, medium=threepid["medium"], address=threepid["address"] user_id=mxid, medium=threepid["medium"], address=threepid["address"]
) )
...@@ -221,14 +217,13 @@ class IdentityHandler(BaseHandler): ...@@ -221,14 +217,13 @@ class IdentityHandler(BaseHandler):
changed = True changed = True
for id_server in id_servers: for id_server in id_servers:
changed &= yield self.try_unbind_threepid_with_id_server( changed &= await self.try_unbind_threepid_with_id_server(
mxid, threepid, id_server mxid, threepid, id_server
) )
return changed return changed
@defer.inlineCallbacks async def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
"""Removes a binding from an identity server """Removes a binding from an identity server
Args: Args:
...@@ -266,7 +261,7 @@ class IdentityHandler(BaseHandler): ...@@ -266,7 +261,7 @@ class IdentityHandler(BaseHandler):
try: try:
# Use the blacklisting http client as this call is only to identity servers # Use the blacklisting http client as this call is only to identity servers
# provided by a client # provided by a client
yield self.blacklisting_http_client.post_json_get_json( await self.blacklisting_http_client.post_json_get_json(
url, content, headers url, content, headers
) )
changed = True changed = True
...@@ -281,7 +276,7 @@ class IdentityHandler(BaseHandler): ...@@ -281,7 +276,7 @@ class IdentityHandler(BaseHandler):
except TimeoutError: except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server") raise SynapseError(500, "Timed out contacting identity server")
yield self.store.remove_user_bound_threepid( await self.store.remove_user_bound_threepid(
user_id=mxid, user_id=mxid,
medium=threepid["medium"], medium=threepid["medium"],
address=threepid["address"], address=threepid["address"],
...@@ -376,8 +371,7 @@ class IdentityHandler(BaseHandler): ...@@ -376,8 +371,7 @@ class IdentityHandler(BaseHandler):
return session_id return session_id
@defer.inlineCallbacks async def requestEmailToken(
def requestEmailToken(
self, id_server, email, client_secret, send_attempt, next_link=None self, id_server, email, client_secret, send_attempt, next_link=None
): ):
""" """
...@@ -412,7 +406,7 @@ class IdentityHandler(BaseHandler): ...@@ -412,7 +406,7 @@ class IdentityHandler(BaseHandler):
) )
try: try:
data = yield self.http_client.post_json_get_json( data = await self.http_client.post_json_get_json(
id_server + "/_matrix/identity/api/v1/validate/email/requestToken", id_server + "/_matrix/identity/api/v1/validate/email/requestToken",
params, params,
) )
...@@ -423,8 +417,7 @@ class IdentityHandler(BaseHandler): ...@@ -423,8 +417,7 @@ class IdentityHandler(BaseHandler):
except TimeoutError: except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server") raise SynapseError(500, "Timed out contacting identity server")
@defer.inlineCallbacks async def requestMsisdnToken(
def requestMsisdnToken(
self, self,
id_server, id_server,
country, country,
...@@ -466,7 +459,7 @@ class IdentityHandler(BaseHandler): ...@@ -466,7 +459,7 @@ class IdentityHandler(BaseHandler):
) )
try: try:
data = yield self.http_client.post_json_get_json( data = await self.http_client.post_json_get_json(
id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken", id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken",
params, params,
) )
...@@ -487,8 +480,7 @@ class IdentityHandler(BaseHandler): ...@@ -487,8 +480,7 @@ class IdentityHandler(BaseHandler):
) )
return data return data
@defer.inlineCallbacks async def validate_threepid_session(self, client_secret, sid):
def validate_threepid_session(self, client_secret, sid):
"""Validates a threepid session with only the client secret and session ID """Validates a threepid session with only the client secret and session ID
Tries validating against any configured account_threepid_delegates as well as locally. Tries validating against any configured account_threepid_delegates as well as locally.
...@@ -510,12 +502,12 @@ class IdentityHandler(BaseHandler): ...@@ -510,12 +502,12 @@ class IdentityHandler(BaseHandler):
# Try to validate as email # Try to validate as email
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
# Ask our delegated email identity server # Ask our delegated email identity server
validation_session = yield self.threepid_from_creds( validation_session = await self.threepid_from_creds(
self.hs.config.account_threepid_delegate_email, threepid_creds self.hs.config.account_threepid_delegate_email, threepid_creds
) )
elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
# Get a validated session matching these details # Get a validated session matching these details
validation_session = yield self.store.get_threepid_validation_session( validation_session = await self.store.get_threepid_validation_session(
"email", client_secret, sid=sid, validated=True "email", client_secret, sid=sid, validated=True
) )
...@@ -525,14 +517,13 @@ class IdentityHandler(BaseHandler): ...@@ -525,14 +517,13 @@ class IdentityHandler(BaseHandler):
# Try to validate as msisdn # Try to validate as msisdn
if self.hs.config.account_threepid_delegate_msisdn: if self.hs.config.account_threepid_delegate_msisdn:
# Ask our delegated msisdn identity server # Ask our delegated msisdn identity server
validation_session = yield self.threepid_from_creds( validation_session = await self.threepid_from_creds(
self.hs.config.account_threepid_delegate_msisdn, threepid_creds self.hs.config.account_threepid_delegate_msisdn, threepid_creds
) )
return validation_session return validation_session
@defer.inlineCallbacks async def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token):
def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token):
"""Proxy a POST submitToken request to an identity server for verification purposes """Proxy a POST submitToken request to an identity server for verification purposes
Args: Args:
...@@ -553,11 +544,9 @@ class IdentityHandler(BaseHandler): ...@@ -553,11 +544,9 @@ class IdentityHandler(BaseHandler):
body = {"client_secret": client_secret, "sid": sid, "token": token} body = {"client_secret": client_secret, "sid": sid, "token": token}
try: try:
return ( return await self.http_client.post_json_get_json(
yield self.http_client.post_json_get_json( id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken",
id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken", body,
body,
)
) )
except TimeoutError: except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server") raise SynapseError(500, "Timed out contacting identity server")
...@@ -565,8 +554,7 @@ class IdentityHandler(BaseHandler): ...@@ -565,8 +554,7 @@ class IdentityHandler(BaseHandler):
logger.warning("Error contacting msisdn account_threepid_delegate: %s", e) logger.warning("Error contacting msisdn account_threepid_delegate: %s", e)
raise SynapseError(400, "Error contacting the identity server") raise SynapseError(400, "Error contacting the identity server")
@defer.inlineCallbacks async def lookup_3pid(self, id_server, medium, address, id_access_token=None):
def lookup_3pid(self, id_server, medium, address, id_access_token=None):
"""Looks up a 3pid in the passed identity server. """Looks up a 3pid in the passed identity server.
Args: Args:
...@@ -582,7 +570,7 @@ class IdentityHandler(BaseHandler): ...@@ -582,7 +570,7 @@ class IdentityHandler(BaseHandler):
""" """
if id_access_token is not None: if id_access_token is not None:
try: try:
results = yield self._lookup_3pid_v2( results = await self._lookup_3pid_v2(
id_server, id_access_token, medium, address id_server, id_access_token, medium, address
) )
return results return results
...@@ -601,10 +589,9 @@ class IdentityHandler(BaseHandler): ...@@ -601,10 +589,9 @@ class IdentityHandler(BaseHandler):
logger.warning("Error when looking up hashing details: %s", e) logger.warning("Error when looking up hashing details: %s", e)
return None return None
return (yield self._lookup_3pid_v1(id_server, medium, address)) return await self._lookup_3pid_v1(id_server, medium, address)
@defer.inlineCallbacks async def _lookup_3pid_v1(self, id_server, medium, address):
def _lookup_3pid_v1(self, id_server, medium, address):
"""Looks up a 3pid in the passed identity server using v1 lookup. """Looks up a 3pid in the passed identity server using v1 lookup.
Args: Args:
...@@ -617,7 +604,7 @@ class IdentityHandler(BaseHandler): ...@@ -617,7 +604,7 @@ class IdentityHandler(BaseHandler):
str: the matrix ID of the 3pid, or None if it is not recognized. str: the matrix ID of the 3pid, or None if it is not recognized.
""" """
try: try:
data = yield self.blacklisting_http_client.get_json( data = await self.blacklisting_http_client.get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server), "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
{"medium": medium, "address": address}, {"medium": medium, "address": address},
) )
...@@ -625,7 +612,7 @@ class IdentityHandler(BaseHandler): ...@@ -625,7 +612,7 @@ class IdentityHandler(BaseHandler):
if "mxid" in data: if "mxid" in data:
if "signatures" not in data: if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding") raise AuthError(401, "No signatures on 3pid binding")
yield self._verify_any_signature(data, id_server) await self._verify_any_signature(data, id_server)
return data["mxid"] return data["mxid"]
except TimeoutError: except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server") raise SynapseError(500, "Timed out contacting identity server")
...@@ -634,8 +621,7 @@ class IdentityHandler(BaseHandler): ...@@ -634,8 +621,7 @@ class IdentityHandler(BaseHandler):
return None return None
@defer.inlineCallbacks async def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
"""Looks up a 3pid in the passed identity server using v2 lookup. """Looks up a 3pid in the passed identity server using v2 lookup.
Args: Args:
...@@ -650,7 +636,7 @@ class IdentityHandler(BaseHandler): ...@@ -650,7 +636,7 @@ class IdentityHandler(BaseHandler):
""" """
# Check what hashing details are supported by this identity server # Check what hashing details are supported by this identity server
try: try:
hash_details = yield self.blacklisting_http_client.get_json( hash_details = await self.blacklisting_http_client.get_json(
"%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server), "%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server),
{"access_token": id_access_token}, {"access_token": id_access_token},
) )
...@@ -717,7 +703,7 @@ class IdentityHandler(BaseHandler): ...@@ -717,7 +703,7 @@ class IdentityHandler(BaseHandler):
headers = {"Authorization": create_id_access_token_header(id_access_token)} headers = {"Authorization": create_id_access_token_header(id_access_token)}
try: try:
lookup_results = yield self.blacklisting_http_client.post_json_get_json( lookup_results = await self.blacklisting_http_client.post_json_get_json(
"%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server), "%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server),
{ {
"addresses": [lookup_value], "addresses": [lookup_value],
...@@ -745,13 +731,12 @@ class IdentityHandler(BaseHandler): ...@@ -745,13 +731,12 @@ class IdentityHandler(BaseHandler):
mxid = lookup_results["mappings"].get(lookup_value) mxid = lookup_results["mappings"].get(lookup_value)
return mxid return mxid
@defer.inlineCallbacks async def _verify_any_signature(self, data, server_hostname):
def _verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]: if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,)) raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items(): for key_name, signature in data["signatures"][server_hostname].items():
try: try:
key_data = yield self.blacklisting_http_client.get_json( key_data = await self.blacklisting_http_client.get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s" "%s%s/_matrix/identity/api/v1/pubkey/%s"
% (id_server_scheme, server_hostname, key_name) % (id_server_scheme, server_hostname, key_name)
) )
...@@ -770,8 +755,7 @@ class IdentityHandler(BaseHandler): ...@@ -770,8 +755,7 @@ class IdentityHandler(BaseHandler):
) )
return return
@defer.inlineCallbacks async def ask_id_server_for_third_party_invite(
def ask_id_server_for_third_party_invite(
self, self,
requester, requester,
id_server, id_server,
...@@ -844,7 +828,7 @@ class IdentityHandler(BaseHandler): ...@@ -844,7 +828,7 @@ class IdentityHandler(BaseHandler):
# Attempt a v2 lookup # Attempt a v2 lookup
url = base_url + "/v2/store-invite" url = base_url + "/v2/store-invite"
try: try:
data = yield self.blacklisting_http_client.post_json_get_json( data = await self.blacklisting_http_client.post_json_get_json(
url, url,
invite_config, invite_config,
{"Authorization": create_id_access_token_header(id_access_token)}, {"Authorization": create_id_access_token_header(id_access_token)},
...@@ -864,7 +848,7 @@ class IdentityHandler(BaseHandler): ...@@ -864,7 +848,7 @@ class IdentityHandler(BaseHandler):
url = base_url + "/api/v1/store-invite" url = base_url + "/api/v1/store-invite"
try: try:
data = yield self.blacklisting_http_client.post_json_get_json( data = await self.blacklisting_http_client.post_json_get_json(
url, invite_config url, invite_config
) )
except TimeoutError: except TimeoutError:
...@@ -882,7 +866,7 @@ class IdentityHandler(BaseHandler): ...@@ -882,7 +866,7 @@ class IdentityHandler(BaseHandler):
# types. This is especially true with old instances of Sydent, see # types. This is especially true with old instances of Sydent, see
# https://github.com/matrix-org/sydent/pull/170 # https://github.com/matrix-org/sydent/pull/170
try: try:
data = yield self.blacklisting_http_client.post_urlencoded_get_json( data = await self.blacklisting_http_client.post_urlencoded_get_json(
url, invite_config url, invite_config
) )
except HttpResponseException as e: except HttpResponseException as e:
......
...@@ -138,8 +138,7 @@ class _BaseThreepidAuthChecker: ...@@ -138,8 +138,7 @@ class _BaseThreepidAuthChecker:
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks async def _check_threepid(self, medium, authdict):
def _check_threepid(self, medium, authdict):
if "threepid_creds" not in authdict: if "threepid_creds" not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM) raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
...@@ -155,18 +154,18 @@ class _BaseThreepidAuthChecker: ...@@ -155,18 +154,18 @@ class _BaseThreepidAuthChecker:
raise SynapseError( raise SynapseError(
400, "Phone number verification is not enabled on this homeserver" 400, "Phone number verification is not enabled on this homeserver"
) )
threepid = yield identity_handler.threepid_from_creds( threepid = await identity_handler.threepid_from_creds(
self.hs.config.account_threepid_delegate_msisdn, threepid_creds self.hs.config.account_threepid_delegate_msisdn, threepid_creds
) )
elif medium == "email": elif medium == "email":
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
assert self.hs.config.account_threepid_delegate_email assert self.hs.config.account_threepid_delegate_email
threepid = yield identity_handler.threepid_from_creds( threepid = await identity_handler.threepid_from_creds(
self.hs.config.account_threepid_delegate_email, threepid_creds self.hs.config.account_threepid_delegate_email, threepid_creds
) )
elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
threepid = None threepid = None
row = yield self.store.get_threepid_validation_session( row = await self.store.get_threepid_validation_session(
medium, medium,
threepid_creds["client_secret"], threepid_creds["client_secret"],
sid=threepid_creds["sid"], sid=threepid_creds["sid"],
...@@ -181,7 +180,7 @@ class _BaseThreepidAuthChecker: ...@@ -181,7 +180,7 @@ class _BaseThreepidAuthChecker:
} }
# Valid threepid returned, delete from the db # Valid threepid returned, delete from the db
yield self.store.delete_threepid_session(threepid_creds["sid"]) await self.store.delete_threepid_session(threepid_creds["sid"])
else: else:
raise SynapseError( raise SynapseError(
400, "Email address verification is not enabled on this homeserver" 400, "Email address verification is not enabled on this homeserver"
...@@ -220,7 +219,7 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec ...@@ -220,7 +219,7 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec
) )
def check_auth(self, authdict, clientip): def check_auth(self, authdict, clientip):
return self._check_threepid("email", authdict) return defer.ensureDeferred(self._check_threepid("email", authdict))
class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
...@@ -234,7 +233,7 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): ...@@ -234,7 +233,7 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
return bool(self.hs.config.account_threepid_delegate_msisdn) return bool(self.hs.config.account_threepid_delegate_msisdn)
def check_auth(self, authdict, clientip): def check_auth(self, authdict, clientip):
return self._check_threepid("msisdn", authdict) return defer.ensureDeferred(self._check_threepid("msisdn", authdict))
INTERACTIVE_AUTH_CHECKERS = [ INTERACTIVE_AUTH_CHECKERS = [
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment