Skip to content
Snippets Groups Projects
Commit def5ea40 authored by Richard van der Hoff's avatar Richard van der Hoff
Browse files

Don't bomb out on direct key fetches as soon as one fails

parent 06a1f3e2
No related branches found
No related tags found
No related merge requests found
...@@ -46,6 +46,7 @@ from synapse.api.errors import ( ...@@ -46,6 +46,7 @@ from synapse.api.errors import (
) )
from synapse.storage.keys import FetchKeyResult from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext, unwrapFirstError from synapse.util import logcontext, unwrapFirstError
from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.logcontext import ( from synapse.util.logcontext import (
LoggingContext, LoggingContext,
PreserveLoggingContext, PreserveLoggingContext,
...@@ -169,7 +170,12 @@ class Keyring(object): ...@@ -169,7 +170,12 @@ class Keyring(object):
) )
) )
logger.debug("Verifying for %s with key_ids %s", server_name, key_ids) logger.debug(
"Verifying for %s with key_ids %s, min_validity %i",
server_name,
key_ids,
validity_time,
)
# add the key request to the queue, but don't start it off yet. # add the key request to the queue, but don't start it off yet.
verify_request = VerifyKeyRequest( verify_request = VerifyKeyRequest(
...@@ -744,34 +750,42 @@ class ServerKeyFetcher(BaseV2KeyFetcher): ...@@ -744,34 +750,42 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.client = hs.get_http_client() self.client = hs.get_http_client()
@defer.inlineCallbacks
def get_keys(self, keys_to_fetch): def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys""" """see KeyFetcher.get_keys"""
# TODO make this more resilient
results = yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.get_server_verify_key_v2_direct,
server_name,
server_keys.keys(),
)
for server_name, server_keys in keys_to_fetch.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
merged = {} results = {}
for result in results:
merged.update(result) @defer.inlineCallbacks
def get_key(key_to_fetch_item):
server_name, key_ids = key_to_fetch_item
try:
keys = yield self.get_server_verify_key_v2_direct(server_name, key_ids)
results[server_name] = keys
except KeyLookupError as e:
logger.warning(
"Error looking up keys %s from %s: %s", key_ids, server_name, e
)
except Exception:
logger.exception("Error getting keys %s from %s", key_ids, server_name)
defer.returnValue( return yieldable_gather_results(get_key, keys_to_fetch.items()).addCallback(
{server_name: keys for server_name, keys in merged.items() if keys} lambda _: results
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids): def get_server_verify_key_v2_direct(self, server_name, key_ids):
"""
Args:
server_name (str):
key_ids (iterable[str]):
Returns:
Deferred[dict[str, FetchKeyResult]]: map from key ID to lookup result
Raises:
KeyLookupError if there was a problem making the lookup
"""
keys = {} # type: dict[str, FetchKeyResult] keys = {} # type: dict[str, FetchKeyResult]
for requested_key_id in key_ids: for requested_key_id in key_ids:
...@@ -823,7 +837,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher): ...@@ -823,7 +837,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
) )
keys.update(response_keys) keys.update(response_keys)
defer.returnValue({server_name: keys}) defer.returnValue(keys)
@defer.inlineCallbacks @defer.inlineCallbacks
......
...@@ -25,11 +25,7 @@ from twisted.internet import defer ...@@ -25,11 +25,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.crypto import keyring from synapse.crypto import keyring
from synapse.crypto.keyring import ( from synapse.crypto.keyring import PerspectivesKeyFetcher, ServerKeyFetcher
KeyLookupError,
PerspectivesKeyFetcher,
ServerKeyFetcher,
)
from synapse.storage.keys import FetchKeyResult from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext from synapse.util import logcontext
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
...@@ -364,9 +360,11 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): ...@@ -364,9 +360,11 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response) bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
) )
# change the server name: it should cause a rejection # change the server name: the result should be ignored
response["server_name"] = "OTHER_SERVER" response["server_name"] = "OTHER_SERVER"
self.get_failure(fetcher.get_keys(keys_to_fetch), KeyLookupError)
keys = self.get_success(fetcher.get_keys(keys_to_fetch))
self.assertEqual(keys, {})
class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment