Skip to content
Snippets Groups Projects
Commit 66848557 authored by Mark Haines's avatar Mark Haines
Browse files

Verify signatures for server2server requests

parent 10ef8e6e
No related branches found
No related tags found
No related merge requests found
......@@ -22,6 +22,7 @@ from .transport import TransportLayer
def initialize_http_replication(homeserver):
transport = TransportLayer(
homeserver,
homeserver.hostname,
server=homeserver.get_resource_for_federation(),
client=homeserver.get_http_client()
......
......@@ -54,7 +54,7 @@ class TransportLayer(object):
we receive data.
"""
def __init__(self, server_name, server, client):
def __init__(self, homeserver, server_name, server, client):
"""
Args:
server_name (str): Local home server host
......@@ -63,6 +63,7 @@ class TransportLayer(object):
client (synapse.protocol.http.HttpClient): the http client used to
send requests
"""
self.keyring = homeserver.get_keyring()
self.server_name = server_name
self.server = server
self.client = client
......@@ -195,6 +196,66 @@ class TransportLayer(object):
defer.returnValue(response)
@defer.inlineCallbacks
def _authenticate_request(self, request):
json_request = {
"method": request.method,
"uri": request.uri,
"destination": self.server_name,
"signatures": {},
}
content = None
origin = None
if request.method == "PUT":
#TODO: Handle other method types? other content types?
content_bytes = request.content.read()
content = json.loads(content_bytes)
json_request["content"] = content
def parse_auth_header(header_str):
params = auth.split(" ")[1].split(",")
param_dict = dict(kv.split("=") for kv in params)
def strip_quotes(value):
if value.startswith("\""):
return value[1:-1]
else:
return value
origin = strip_quotes(param_dict["origin"])
key = strip_quotes(param_dict["key"])
sig = strip_quotes(param_dict["sig"])
return (origin, key, sig)
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
if not auth_headers:
#TODO(markjh): Send a 401 response?
raise Exception("Missing auth headers")
for auth in auth_headers:
if auth.startswith("X-Matrix"):
(origin, key, sig) = parse_auth_header(auth)
json_request["origin"] = origin
json_request["signatures"].setdefault(origin,{})[key] = sig
from syutil.jsonutil import encode_canonical_json
logger.debug("Checking %s %s",
origin, encode_canonical_json(json_request))
yield self.keyring.verify_json_for_server(origin, json_request)
defer.returnValue((origin, content))
def _with_authentication(self, handler):
@defer.inlineCallbacks
def new_handler(request, *args, **kwargs):
(origin, content) = yield self._authenticate_request(request)
response = yield handler(
origin, content, request.args, *args, **kwargs
)
defer.returnValue(response)
return new_handler
@log_function
def register_received_handler(self, handler):
""" Register a handler that will be fired when we receive data.
......@@ -208,7 +269,7 @@ class TransportLayer(object):
self.server.register_path(
"PUT",
re.compile("^" + PREFIX + "/send/([^/]*)/$"),
self._on_send_request
self._with_authentication(self._on_send_request)
)
@log_function
......@@ -226,9 +287,9 @@ class TransportLayer(object):
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/pull/$"),
lambda request: handler.on_pull_request(
request.args["origin"][0],
request.args["v"]
self._with_authentication(
lambda origin, content, query:
handler.on_pull_request(query["origin"][0], query["v"])
)
)
......@@ -237,8 +298,9 @@ class TransportLayer(object):
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/pdu/([^/]*)/([^/]*)/$"),
lambda request, pdu_origin, pdu_id: handler.on_pdu_request(
pdu_origin, pdu_id
self._with_authentication(
lambda origin, content, query, pdu_origin, pdu_id:
handler.on_pdu_request(pdu_origin, pdu_id)
)
)
......@@ -246,38 +308,47 @@ class TransportLayer(object):
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/state/([^/]*)/$"),
lambda request, context: handler.on_context_state_request(
context
self._with_authentication(
lambda origin, content, query, context:
handler.on_context_state_request(context)
)
)
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/backfill/([^/]*)/$"),
lambda request, context: self._on_backfill_request(
context, request.args["v"],
request.args["limit"]
self._with_authentication(
lambda origin, content, query, context:
self._on_backfill_request(
context, query["v"], query["limit"]
)
)
)
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/context/([^/]*)/$"),
lambda request, context: handler.on_context_pdus_request(context)
self._with_authentication(
lambda origin, content, query, context:
handler.on_context_pdus_request(context)
)
)
# This is when we receive a server-server Query
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/query/([^/]*)$"),
lambda request, query_type: handler.on_query_request(
query_type, {k: v[0] for k, v in request.args.items()}
self._with_authentication(
lambda origin, content, query, query_type:
handler.on_query_request(
query_type, {k: v[0] for k, v in query.items()}
)
)
)
@defer.inlineCallbacks
@log_function
def _on_send_request(self, request, transaction_id):
def _on_send_request(self, origin, content, query, transaction_id):
""" Called on PUT /send/<transaction_id>/
Args:
......@@ -292,12 +363,7 @@ class TransportLayer(object):
"""
# Parse the request
try:
data = request.content.read()
l = data[:20].encode("string_escape")
logger.debug("Got data: \"%s\"", l)
transaction_data = json.loads(data)
transaction_data = content
logger.debug(
"Decoded %s: %s",
......
......@@ -177,16 +177,20 @@ class MatrixHttpClient(BaseHttpClient):
request = sign_json(request, self.server_name, self.signing_key)
from syutil.jsonutil import encode_canonical_json
logger.debug("Signing " + " " * 11 + "%s %s",
self.server_name, encode_canonical_json(request))
auth_headers = []
for key,sig in request["signatures"][self.server_name].items():
auth_headers.append(
auth_headers.append(bytes(
"X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
self.server_name, key, sig,
)
)
))
headers_dict["Authorization"] = auth_headers
headers_dict[b"Authorization"] = auth_headers
@defer.inlineCallbacks
def put_json(self, destination, path, data={}, json_data_callback=None):
......
......@@ -221,6 +221,7 @@ class FederationTestCase(unittest.TestCase):
json_data_callback=ANY,
)
@defer.inlineCallbacks
def test_recv_edu(self):
recv_observer = Mock()
......
......@@ -76,6 +76,9 @@ class MockHttpResource(HttpServer):
mock_content.configure_mock(**config)
mock_request.content = mock_content
mock_request.method = http_method
mock_request.uri = path
# return the right path if the event requires it
mock_request.path = path
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment