Skip to content
Snippets Groups Projects
Unverified Commit 8c41c04e authored by Erik Johnston's avatar Erik Johnston Committed by GitHub
Browse files

Merge pull request #5244 from matrix-org/rav/server_keys/00-factor-out-fetchers

Factor out KeyFetchers from KeyRing
parents 6368150a ec24108c
No related branches found
No related tags found
No related merge requests found
Refactor synapse.crypto.keyring to use a KeyFetcher interface.
...@@ -80,12 +80,13 @@ class KeyLookupError(ValueError): ...@@ -80,12 +80,13 @@ class KeyLookupError(ValueError):
class Keyring(object): class Keyring(object):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.client = hs.get_http_client()
self.config = hs.get_config() self._key_fetchers = (
self.perspective_servers = self.config.perspectives StoreKeyFetcher(hs),
self.hs = hs PerspectivesKeyFetcher(hs),
ServerKeyFetcher(hs),
)
# map from server name to Deferred. Has an entry for each server with # map from server name to Deferred. Has an entry for each server with
# an ongoing key download; the Deferred completes once the download # an ongoing key download; the Deferred completes once the download
...@@ -271,13 +272,6 @@ class Keyring(object): ...@@ -271,13 +272,6 @@ class Keyring(object):
verify_requests (list[VerifyKeyRequest]): list of verify requests verify_requests (list[VerifyKeyRequest]): list of verify requests
""" """
# These are functions that produce keys given a list of key ids
key_fetch_fns = (
self.get_keys_from_store, # First try the local store
self.get_keys_from_perspectives, # Then try via perspectives
self.get_keys_from_server, # Then try directly
)
@defer.inlineCallbacks @defer.inlineCallbacks
def do_iterations(): def do_iterations():
with Measure(self.clock, "get_server_verify_keys"): with Measure(self.clock, "get_server_verify_keys"):
...@@ -288,8 +282,8 @@ class Keyring(object): ...@@ -288,8 +282,8 @@ class Keyring(object):
verify_request.key_ids verify_request.key_ids
) )
for fn in key_fetch_fns: for f in self._key_fetchers:
results = yield fn(missing_keys.items()) results = yield f.get_keys(missing_keys.items())
# We now need to figure out which verify requests we have keys # We now need to figure out which verify requests we have keys
# for and which we don't # for and which we don't
...@@ -348,8 +342,9 @@ class Keyring(object): ...@@ -348,8 +342,9 @@ class Keyring(object):
run_in_background(do_iterations).addErrback(on_err) run_in_background(do_iterations).addErrback(on_err)
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids): class KeyFetcher(object):
def get_keys(self, server_name_and_key_ids):
""" """
Args: Args:
server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]): server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]):
...@@ -359,6 +354,18 @@ class Keyring(object): ...@@ -359,6 +354,18 @@ class Keyring(object):
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]: Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
map from server_name -> key_id -> FetchKeyResult map from server_name -> key_id -> FetchKeyResult
""" """
raise NotImplementedError
class StoreKeyFetcher(KeyFetcher):
"""KeyFetcher impl which fetches keys from our data store"""
def __init__(self, hs):
self.store = hs.get_datastore()
@defer.inlineCallbacks
def get_keys(self, server_name_and_key_ids):
"""see KeyFetcher.get_keys"""
keys_to_fetch = ( keys_to_fetch = (
(server_name, key_id) (server_name, key_id)
for server_name, key_ids in server_name_and_key_ids for server_name, key_ids in server_name_and_key_ids
...@@ -370,8 +377,127 @@ class Keyring(object): ...@@ -370,8 +377,127 @@ class Keyring(object):
keys.setdefault(server_name, {})[key_id] = key keys.setdefault(server_name, {})[key_id] = key
defer.returnValue(keys) defer.returnValue(keys)
class BaseV2KeyFetcher(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self.config = hs.get_config()
@defer.inlineCallbacks
def process_v2_response(
self, from_server, response_json, time_added_ms, requested_ids=[]
):
"""Parse a 'Server Keys' structure from the result of a /key request
This is used to parse either the entirety of the response from
GET /_matrix/key/v2/server, or a single entry from the list returned by
POST /_matrix/key/v2/query.
Checks that each signature in the response that claims to come from the origin
server is valid. (Does not check that there actually is such a signature, for
some reason.)
Stores the json in server_keys_json so that it can be used for future responses
to /_matrix/key/v2/query.
Args:
from_server (str): the name of the server producing this result: either
the origin server for a /_matrix/key/v2/server request, or the notary
for a /_matrix/key/v2/query.
response_json (dict): the json-decoded Server Keys response object
time_added_ms (int): the timestamp to record in server_keys_json
requested_ids (iterable[str]): a list of the key IDs that were requested.
We will store the json for these key ids as well as any that are
actually in the response
Returns:
Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
"""
ts_valid_until_ms = response_json[u"valid_until_ts"]
# start by extracting the keys from the response, since they may be required
# to validate the signature on the response.
verify_keys = {}
for key_id, key_data in response_json["verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_keys[key_id] = FetchKeyResult(
verify_key=verify_key, valid_until_ts=ts_valid_until_ms
)
# TODO: improve this signature checking
server_name = response_json["server_name"]
for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in verify_keys:
raise KeyLookupError(
"Key response must include verification keys for all signatures"
)
verify_signed_json(
response_json, server_name, verify_keys[key_id].verify_key
)
for key_id, key_data in response_json["old_verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_keys[key_id] = FetchKeyResult(
verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
)
# re-sign the json with our own key, so that it is ready if we are asked to
# give it out as a notary server
signed_key_json = sign_json(
response_json, self.config.server_name, self.config.signing_key[0]
)
signed_key_json_bytes = encode_canonical_json(signed_key_json)
# for reasons I don't quite understand, we store this json for the key ids we
# requested, as well as those we got.
updated_key_ids = set(requested_ids)
updated_key_ids.update(verify_keys)
yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.store.store_server_keys_json,
server_name=server_name,
key_id=key_id,
from_server=from_server,
ts_now_ms=time_added_ms,
ts_expires_ms=ts_valid_until_ms,
key_json_bytes=signed_key_json_bytes,
)
for key_id in updated_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
defer.returnValue(verify_keys)
class PerspectivesKeyFetcher(BaseV2KeyFetcher):
"""KeyFetcher impl which fetches keys from the "perspectives" servers"""
def __init__(self, hs):
super(PerspectivesKeyFetcher, self).__init__(hs)
self.clock = hs.get_clock()
self.client = hs.get_http_client()
self.perspective_servers = self.config.perspectives
@defer.inlineCallbacks @defer.inlineCallbacks
def get_keys_from_perspectives(self, server_name_and_key_ids): def get_keys(self, server_name_and_key_ids):
"""see KeyFetcher.get_keys"""
@defer.inlineCallbacks @defer.inlineCallbacks
def get_key(perspective_name, perspective_keys): def get_key(perspective_name, perspective_keys):
try: try:
...@@ -408,28 +534,6 @@ class Keyring(object): ...@@ -408,28 +534,6 @@ class Keyring(object):
defer.returnValue(union_of_keys) defer.returnValue(union_of_keys)
@defer.inlineCallbacks
def get_keys_from_server(self, server_name_and_key_ids):
results = yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.get_server_verify_key_v2_direct, server_name, key_ids
)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
merged = {}
for result in results:
merged.update(result)
defer.returnValue(
{server_name: keys for server_name, keys in merged.items() if keys}
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key_v2_indirect( def get_server_verify_key_v2_indirect(
self, server_names_and_key_ids, perspective_name, perspective_keys self, server_names_and_key_ids, perspective_name, perspective_keys
...@@ -520,6 +624,38 @@ class Keyring(object): ...@@ -520,6 +624,38 @@ class Keyring(object):
defer.returnValue(keys) defer.returnValue(keys)
class ServerKeyFetcher(BaseV2KeyFetcher):
"""KeyFetcher impl which fetches keys from the origin servers"""
def __init__(self, hs):
super(ServerKeyFetcher, self).__init__(hs)
self.clock = hs.get_clock()
self.client = hs.get_http_client()
@defer.inlineCallbacks
def get_keys(self, server_name_and_key_ids):
"""see KeyFetcher.get_keys"""
results = yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.get_server_verify_key_v2_direct, server_name, key_ids
)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
merged = {}
for result in results:
merged.update(result)
defer.returnValue(
{server_name: keys for server_name, keys in merged.items() if keys}
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids): def get_server_verify_key_v2_direct(self, server_name, key_ids):
keys = {} # type: dict[str, FetchKeyResult] keys = {} # type: dict[str, FetchKeyResult]
...@@ -568,107 +704,6 @@ class Keyring(object): ...@@ -568,107 +704,6 @@ class Keyring(object):
defer.returnValue({server_name: keys}) defer.returnValue({server_name: keys})
@defer.inlineCallbacks
def process_v2_response(
self, from_server, response_json, time_added_ms, requested_ids=[]
):
"""Parse a 'Server Keys' structure from the result of a /key request
This is used to parse either the entirety of the response from
GET /_matrix/key/v2/server, or a single entry from the list returned by
POST /_matrix/key/v2/query.
Checks that each signature in the response that claims to come from the origin
server is valid. (Does not check that there actually is such a signature, for
some reason.)
Stores the json in server_keys_json so that it can be used for future responses
to /_matrix/key/v2/query.
Args:
from_server (str): the name of the server producing this result: either
the origin server for a /_matrix/key/v2/server request, or the notary
for a /_matrix/key/v2/query.
response_json (dict): the json-decoded Server Keys response object
time_added_ms (int): the timestamp to record in server_keys_json
requested_ids (iterable[str]): a list of the key IDs that were requested.
We will store the json for these key ids as well as any that are
actually in the response
Returns:
Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
"""
ts_valid_until_ms = response_json[u"valid_until_ts"]
# start by extracting the keys from the response, since they may be required
# to validate the signature on the response.
verify_keys = {}
for key_id, key_data in response_json["verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_keys[key_id] = FetchKeyResult(
verify_key=verify_key, valid_until_ts=ts_valid_until_ms
)
# TODO: improve this signature checking
server_name = response_json["server_name"]
for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in verify_keys:
raise KeyLookupError(
"Key response must include verification keys for all signatures"
)
verify_signed_json(
response_json, server_name, verify_keys[key_id].verify_key
)
for key_id, key_data in response_json["old_verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_keys[key_id] = FetchKeyResult(
verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
)
# re-sign the json with our own key, so that it is ready if we are asked to
# give it out as a notary server
signed_key_json = sign_json(
response_json, self.config.server_name, self.config.signing_key[0]
)
signed_key_json_bytes = encode_canonical_json(signed_key_json)
# for reasons I don't quite understand, we store this json for the key ids we
# requested, as well as those we got.
updated_key_ids = set(requested_ids)
updated_key_ids.update(verify_keys)
yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.store.store_server_keys_json,
server_name=server_name,
key_id=key_id,
from_server=from_server,
ts_now_ms=time_added_ms,
ts_expires_ms=ts_valid_until_ms,
key_json_bytes=signed_key_json_bytes,
)
for key_id in updated_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
defer.returnValue(verify_keys)
@defer.inlineCallbacks @defer.inlineCallbacks
def _handle_key_deferred(verify_request): def _handle_key_deferred(verify_request):
......
...@@ -20,7 +20,7 @@ from twisted.web.resource import Resource ...@@ -20,7 +20,7 @@ from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import KeyLookupError from synapse.crypto.keyring import KeyLookupError, ServerKeyFetcher
from synapse.http.server import respond_with_json_bytes, wrap_json_request_handler from synapse.http.server import respond_with_json_bytes, wrap_json_request_handler
from synapse.http.servlet import parse_integer, parse_json_object_from_request from synapse.http.servlet import parse_integer, parse_json_object_from_request
...@@ -89,7 +89,7 @@ class RemoteKey(Resource): ...@@ -89,7 +89,7 @@ class RemoteKey(Resource):
isLeaf = True isLeaf = True
def __init__(self, hs): def __init__(self, hs):
self.keyring = hs.get_keyring() self.fetcher = ServerKeyFetcher(hs)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist self.federation_domain_whitelist = hs.config.federation_domain_whitelist
...@@ -217,7 +217,7 @@ class RemoteKey(Resource): ...@@ -217,7 +217,7 @@ class RemoteKey(Resource):
if cache_misses and query_remote_on_cache_miss: if cache_misses and query_remote_on_cache_miss:
for server_name, key_ids in cache_misses.items(): for server_name, key_ids in cache_misses.items():
try: try:
yield self.keyring.get_server_verify_key_v2_direct( yield self.fetcher.get_server_verify_key_v2_direct(
server_name, key_ids server_name, key_ids
) )
except KeyLookupError as e: except KeyLookupError as e:
......
...@@ -24,7 +24,11 @@ from twisted.internet import defer ...@@ -24,7 +24,11 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.crypto import keyring from synapse.crypto import keyring
from synapse.crypto.keyring import KeyLookupError from synapse.crypto.keyring import (
KeyLookupError,
PerspectivesKeyFetcher,
ServerKeyFetcher,
)
from synapse.storage.keys import FetchKeyResult from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext from synapse.util import logcontext
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
...@@ -218,12 +222,19 @@ class KeyringTestCase(unittest.HomeserverTestCase): ...@@ -218,12 +222,19 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.assertFalse(d.called) self.assertFalse(d.called)
self.get_success(d) self.get_success(d)
class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.http_client = Mock()
hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
return hs
def test_get_keys_from_server(self): def test_get_keys_from_server(self):
# arbitrarily advance the clock a bit # arbitrarily advance the clock a bit
self.reactor.advance(100) self.reactor.advance(100)
SERVER_NAME = "server2" SERVER_NAME = "server2"
kr = keyring.Keyring(self.hs) fetcher = ServerKeyFetcher(self.hs)
testkey = signedjson.key.generate_signing_key("ver1") testkey = signedjson.key.generate_signing_key("ver1")
testverifykey = signedjson.key.get_verify_key(testkey) testverifykey = signedjson.key.get_verify_key(testkey)
testverifykey_id = "ed25519:ver1" testverifykey_id = "ed25519:ver1"
...@@ -250,7 +261,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): ...@@ -250,7 +261,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.http_client.get_json.side_effect = get_json self.http_client.get_json.side_effect = get_json
server_name_and_key_ids = [(SERVER_NAME, ("key1",))] server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids)) keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
k = keys[SERVER_NAME][testverifykey_id] k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
self.assertEqual(k.verify_key, testverifykey) self.assertEqual(k.verify_key, testverifykey)
...@@ -278,15 +289,26 @@ class KeyringTestCase(unittest.HomeserverTestCase): ...@@ -278,15 +289,26 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# change the server name: it should cause a rejection # change the server name: it should cause a rejection
response["server_name"] = "OTHER_SERVER" response["server_name"] = "OTHER_SERVER"
self.get_failure( self.get_failure(
kr.get_keys_from_server(server_name_and_key_ids), KeyLookupError fetcher.get_keys(server_name_and_key_ids), KeyLookupError
) )
class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.mock_perspective_server = MockPerspectiveServer()
self.http_client = Mock()
hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
keys = self.mock_perspective_server.get_verify_keys()
hs.config.perspectives = {self.mock_perspective_server.server_name: keys}
return hs
def test_get_keys_from_perspectives(self): def test_get_keys_from_perspectives(self):
# arbitrarily advance the clock a bit # arbitrarily advance the clock a bit
self.reactor.advance(100) self.reactor.advance(100)
fetcher = PerspectivesKeyFetcher(self.hs)
SERVER_NAME = "server2" SERVER_NAME = "server2"
kr = keyring.Keyring(self.hs)
testkey = signedjson.key.generate_signing_key("ver1") testkey = signedjson.key.generate_signing_key("ver1")
testverifykey = signedjson.key.get_verify_key(testkey) testverifykey = signedjson.key.get_verify_key(testkey)
testverifykey_id = "ed25519:ver1" testverifykey_id = "ed25519:ver1"
...@@ -320,7 +342,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): ...@@ -320,7 +342,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.http_client.post_json.side_effect = post_json self.http_client.post_json.side_effect = post_json
server_name_and_key_ids = [(SERVER_NAME, ("key1",))] server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids)) keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
self.assertIn(SERVER_NAME, keys) self.assertIn(SERVER_NAME, keys)
k = keys[SERVER_NAME][testverifykey_id] k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment