Skip to content
Snippets Groups Projects
Commit 895b79ac authored by Richard van der Hoff's avatar Richard van der Hoff
Browse files

Factor out KeyFetchers from KeyRing

Rather than have three methods which have to have the same interface,
factor out a separate interface which is provided by three implementations.

I find it easier to grok the code this way.
parent b75537be
No related branches found
No related tags found
No related merge requests found
Refactor synapse.crypto.keyring to use a KeyFetcher interface.
...@@ -80,12 +80,13 @@ class KeyLookupError(ValueError): ...@@ -80,12 +80,13 @@ class KeyLookupError(ValueError):
class Keyring(object): class Keyring(object):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.client = hs.get_http_client()
self.config = hs.get_config() self._key_fetchers = (
self.perspective_servers = self.config.perspectives StoreKeyFetcher(hs),
self.hs = hs PerspectivesKeyFetcher(hs),
ServerKeyFetcher(hs),
)
# map from server name to Deferred. Has an entry for each server with # map from server name to Deferred. Has an entry for each server with
# an ongoing key download; the Deferred completes once the download # an ongoing key download; the Deferred completes once the download
...@@ -271,13 +272,6 @@ class Keyring(object): ...@@ -271,13 +272,6 @@ class Keyring(object):
verify_requests (list[VerifyKeyRequest]): list of verify requests verify_requests (list[VerifyKeyRequest]): list of verify requests
""" """
# These are functions that produce keys given a list of key ids
key_fetch_fns = (
self.get_keys_from_store, # First try the local store
self.get_keys_from_perspectives, # Then try via perspectives
self.get_keys_from_server, # Then try directly
)
@defer.inlineCallbacks @defer.inlineCallbacks
def do_iterations(): def do_iterations():
with Measure(self.clock, "get_server_verify_keys"): with Measure(self.clock, "get_server_verify_keys"):
...@@ -288,8 +282,8 @@ class Keyring(object): ...@@ -288,8 +282,8 @@ class Keyring(object):
verify_request.key_ids verify_request.key_ids
) )
for fn in key_fetch_fns: for f in self._key_fetchers:
results = yield fn(missing_keys.items()) results = yield f.get_keys(missing_keys.items())
# We now need to figure out which verify requests we have keys # We now need to figure out which verify requests we have keys
# for and which we don't # for and which we don't
...@@ -348,8 +342,9 @@ class Keyring(object): ...@@ -348,8 +342,9 @@ class Keyring(object):
run_in_background(do_iterations).addErrback(on_err) run_in_background(do_iterations).addErrback(on_err)
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids): class KeyFetcher(object):
def get_keys(self, server_name_and_key_ids):
""" """
Args: Args:
server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]): server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]):
...@@ -359,6 +354,18 @@ class Keyring(object): ...@@ -359,6 +354,18 @@ class Keyring(object):
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]: Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
map from server_name -> key_id -> FetchKeyResult map from server_name -> key_id -> FetchKeyResult
""" """
raise NotImplementedError
class StoreKeyFetcher(KeyFetcher):
"""KeyFetcher impl which fetches keys from our data store"""
def __init__(self, hs):
self.store = hs.get_datastore()
@defer.inlineCallbacks
def get_keys(self, server_name_and_key_ids):
"""see KeyFetcher.get_keys"""
keys_to_fetch = ( keys_to_fetch = (
(server_name, key_id) (server_name, key_id)
for server_name, key_ids in server_name_and_key_ids for server_name, key_ids in server_name_and_key_ids
...@@ -370,8 +377,127 @@ class Keyring(object): ...@@ -370,8 +377,127 @@ class Keyring(object):
keys.setdefault(server_name, {})[key_id] = key keys.setdefault(server_name, {})[key_id] = key
defer.returnValue(keys) defer.returnValue(keys)
class BaseV2KeyFetcher(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self.config = hs.get_config()
@defer.inlineCallbacks
def process_v2_response(
self, from_server, response_json, time_added_ms, requested_ids=[]
):
"""Parse a 'Server Keys' structure from the result of a /key request
This is used to parse either the entirety of the response from
GET /_matrix/key/v2/server, or a single entry from the list returned by
POST /_matrix/key/v2/query.
Checks that each signature in the response that claims to come from the origin
server is valid. (Does not check that there actually is such a signature, for
some reason.)
Stores the json in server_keys_json so that it can be used for future responses
to /_matrix/key/v2/query.
Args:
from_server (str): the name of the server producing this result: either
the origin server for a /_matrix/key/v2/server request, or the notary
for a /_matrix/key/v2/query.
response_json (dict): the json-decoded Server Keys response object
time_added_ms (int): the timestamp to record in server_keys_json
requested_ids (iterable[str]): a list of the key IDs that were requested.
We will store the json for these key ids as well as any that are
actually in the response
Returns:
Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
"""
ts_valid_until_ms = response_json[u"valid_until_ts"]
# start by extracting the keys from the response, since they may be required
# to validate the signature on the response.
verify_keys = {}
for key_id, key_data in response_json["verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_keys[key_id] = FetchKeyResult(
verify_key=verify_key, valid_until_ts=ts_valid_until_ms
)
# TODO: improve this signature checking
server_name = response_json["server_name"]
for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in verify_keys:
raise KeyLookupError(
"Key response must include verification keys for all signatures"
)
verify_signed_json(
response_json, server_name, verify_keys[key_id].verify_key
)
for key_id, key_data in response_json["old_verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_keys[key_id] = FetchKeyResult(
verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
)
# re-sign the json with our own key, so that it is ready if we are asked to
# give it out as a notary server
signed_key_json = sign_json(
response_json, self.config.server_name, self.config.signing_key[0]
)
signed_key_json_bytes = encode_canonical_json(signed_key_json)
# for reasons I don't quite understand, we store this json for the key ids we
# requested, as well as those we got.
updated_key_ids = set(requested_ids)
updated_key_ids.update(verify_keys)
yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.store.store_server_keys_json,
server_name=server_name,
key_id=key_id,
from_server=from_server,
ts_now_ms=time_added_ms,
ts_expires_ms=ts_valid_until_ms,
key_json_bytes=signed_key_json_bytes,
)
for key_id in updated_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
defer.returnValue(verify_keys)
class PerspectivesKeyFetcher(BaseV2KeyFetcher):
"""KeyFetcher impl which fetches keys from the "perspectives" servers"""
def __init__(self, hs):
super(PerspectivesKeyFetcher, self).__init__(hs)
self.clock = hs.get_clock()
self.client = hs.get_http_client()
self.perspective_servers = self.config.perspectives
@defer.inlineCallbacks @defer.inlineCallbacks
def get_keys_from_perspectives(self, server_name_and_key_ids): def get_keys(self, server_name_and_key_ids):
"""see KeyFetcher.get_keys"""
@defer.inlineCallbacks @defer.inlineCallbacks
def get_key(perspective_name, perspective_keys): def get_key(perspective_name, perspective_keys):
try: try:
...@@ -408,28 +534,6 @@ class Keyring(object): ...@@ -408,28 +534,6 @@ class Keyring(object):
defer.returnValue(union_of_keys) defer.returnValue(union_of_keys)
@defer.inlineCallbacks
def get_keys_from_server(self, server_name_and_key_ids):
results = yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.get_server_verify_key_v2_direct, server_name, key_ids
)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
merged = {}
for result in results:
merged.update(result)
defer.returnValue(
{server_name: keys for server_name, keys in merged.items() if keys}
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key_v2_indirect( def get_server_verify_key_v2_indirect(
self, server_names_and_key_ids, perspective_name, perspective_keys self, server_names_and_key_ids, perspective_name, perspective_keys
...@@ -520,6 +624,38 @@ class Keyring(object): ...@@ -520,6 +624,38 @@ class Keyring(object):
defer.returnValue(keys) defer.returnValue(keys)
class ServerKeyFetcher(BaseV2KeyFetcher):
"""KeyFetcher impl which fetches keys from the origin servers"""
def __init__(self, hs):
super(ServerKeyFetcher, self).__init__(hs)
self.clock = hs.get_clock()
self.client = hs.get_http_client()
@defer.inlineCallbacks
def get_keys(self, server_name_and_key_ids):
"""see KeyFetcher.get_keys"""
results = yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.get_server_verify_key_v2_direct, server_name, key_ids
)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
merged = {}
for result in results:
merged.update(result)
defer.returnValue(
{server_name: keys for server_name, keys in merged.items() if keys}
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids): def get_server_verify_key_v2_direct(self, server_name, key_ids):
keys = {} # type: dict[str, FetchKeyResult] keys = {} # type: dict[str, FetchKeyResult]
...@@ -568,107 +704,6 @@ class Keyring(object): ...@@ -568,107 +704,6 @@ class Keyring(object):
defer.returnValue({server_name: keys}) defer.returnValue({server_name: keys})
@defer.inlineCallbacks
def process_v2_response(
self, from_server, response_json, time_added_ms, requested_ids=[]
):
"""Parse a 'Server Keys' structure from the result of a /key request
This is used to parse either the entirety of the response from
GET /_matrix/key/v2/server, or a single entry from the list returned by
POST /_matrix/key/v2/query.
Checks that each signature in the response that claims to come from the origin
server is valid. (Does not check that there actually is such a signature, for
some reason.)
Stores the json in server_keys_json so that it can be used for future responses
to /_matrix/key/v2/query.
Args:
from_server (str): the name of the server producing this result: either
the origin server for a /_matrix/key/v2/server request, or the notary
for a /_matrix/key/v2/query.
response_json (dict): the json-decoded Server Keys response object
time_added_ms (int): the timestamp to record in server_keys_json
requested_ids (iterable[str]): a list of the key IDs that were requested.
We will store the json for these key ids as well as any that are
actually in the response
Returns:
Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
"""
ts_valid_until_ms = response_json[u"valid_until_ts"]
# start by extracting the keys from the response, since they may be required
# to validate the signature on the response.
verify_keys = {}
for key_id, key_data in response_json["verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_keys[key_id] = FetchKeyResult(
verify_key=verify_key, valid_until_ts=ts_valid_until_ms
)
# TODO: improve this signature checking
server_name = response_json["server_name"]
for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in verify_keys:
raise KeyLookupError(
"Key response must include verification keys for all signatures"
)
verify_signed_json(
response_json, server_name, verify_keys[key_id].verify_key
)
for key_id, key_data in response_json["old_verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_keys[key_id] = FetchKeyResult(
verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
)
# re-sign the json with our own key, so that it is ready if we are asked to
# give it out as a notary server
signed_key_json = sign_json(
response_json, self.config.server_name, self.config.signing_key[0]
)
signed_key_json_bytes = encode_canonical_json(signed_key_json)
# for reasons I don't quite understand, we store this json for the key ids we
# requested, as well as those we got.
updated_key_ids = set(requested_ids)
updated_key_ids.update(verify_keys)
yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.store.store_server_keys_json,
server_name=server_name,
key_id=key_id,
from_server=from_server,
ts_now_ms=time_added_ms,
ts_expires_ms=ts_valid_until_ms,
key_json_bytes=signed_key_json_bytes,
)
for key_id in updated_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
defer.returnValue(verify_keys)
@defer.inlineCallbacks @defer.inlineCallbacks
def _handle_key_deferred(verify_request): def _handle_key_deferred(verify_request):
......
...@@ -24,7 +24,11 @@ from twisted.internet import defer ...@@ -24,7 +24,11 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.crypto import keyring from synapse.crypto import keyring
from synapse.crypto.keyring import KeyLookupError from synapse.crypto.keyring import (
KeyLookupError,
PerspectivesKeyFetcher,
ServerKeyFetcher,
)
from synapse.storage.keys import FetchKeyResult from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext from synapse.util import logcontext
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
...@@ -218,12 +222,19 @@ class KeyringTestCase(unittest.HomeserverTestCase): ...@@ -218,12 +222,19 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.assertFalse(d.called) self.assertFalse(d.called)
self.get_success(d) self.get_success(d)
class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.http_client = Mock()
hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
return hs
def test_get_keys_from_server(self): def test_get_keys_from_server(self):
# arbitrarily advance the clock a bit # arbitrarily advance the clock a bit
self.reactor.advance(100) self.reactor.advance(100)
SERVER_NAME = "server2" SERVER_NAME = "server2"
kr = keyring.Keyring(self.hs) fetcher = ServerKeyFetcher(self.hs)
testkey = signedjson.key.generate_signing_key("ver1") testkey = signedjson.key.generate_signing_key("ver1")
testverifykey = signedjson.key.get_verify_key(testkey) testverifykey = signedjson.key.get_verify_key(testkey)
testverifykey_id = "ed25519:ver1" testverifykey_id = "ed25519:ver1"
...@@ -250,7 +261,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): ...@@ -250,7 +261,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.http_client.get_json.side_effect = get_json self.http_client.get_json.side_effect = get_json
server_name_and_key_ids = [(SERVER_NAME, ("key1",))] server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids)) keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
k = keys[SERVER_NAME][testverifykey_id] k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
self.assertEqual(k.verify_key, testverifykey) self.assertEqual(k.verify_key, testverifykey)
...@@ -278,15 +289,26 @@ class KeyringTestCase(unittest.HomeserverTestCase): ...@@ -278,15 +289,26 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# change the server name: it should cause a rejection # change the server name: it should cause a rejection
response["server_name"] = "OTHER_SERVER" response["server_name"] = "OTHER_SERVER"
self.get_failure( self.get_failure(
kr.get_keys_from_server(server_name_and_key_ids), KeyLookupError fetcher.get_keys(server_name_and_key_ids), KeyLookupError
) )
class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.mock_perspective_server = MockPerspectiveServer()
self.http_client = Mock()
hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
keys = self.mock_perspective_server.get_verify_keys()
hs.config.perspectives = {self.mock_perspective_server.server_name: keys}
return hs
def test_get_keys_from_perspectives(self): def test_get_keys_from_perspectives(self):
# arbitrarily advance the clock a bit # arbitrarily advance the clock a bit
self.reactor.advance(100) self.reactor.advance(100)
fetcher = PerspectivesKeyFetcher(self.hs)
SERVER_NAME = "server2" SERVER_NAME = "server2"
kr = keyring.Keyring(self.hs)
testkey = signedjson.key.generate_signing_key("ver1") testkey = signedjson.key.generate_signing_key("ver1")
testverifykey = signedjson.key.get_verify_key(testkey) testverifykey = signedjson.key.get_verify_key(testkey)
testverifykey_id = "ed25519:ver1" testverifykey_id = "ed25519:ver1"
...@@ -320,7 +342,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): ...@@ -320,7 +342,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.http_client.post_json.side_effect = post_json self.http_client.post_json.side_effect = post_json
server_name_and_key_ids = [(SERVER_NAME, ("key1",))] server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids)) keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
self.assertIn(SERVER_NAME, keys) self.assertIn(SERVER_NAME, keys)
k = keys[SERVER_NAME][testverifykey_id] k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment