Skip to content
Snippets Groups Projects
Unverified Commit 508196e0 authored by Richard van der Hoff's avatar Richard van der Hoff Committed by GitHub
Browse files

Reject invalid server names (#3480)

Make sure that server_names used in auth headers are sane, and reject them with
a sensible error code, before they disappear off into the depths of the system.
parent f7416308
No related branches found
No related tags found
No related merge requests found
Reject invalid server names in federation requests
......@@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError, FederationDeniedError
from synapse.http.endpoint import parse_server_name
from synapse.http.server import JsonResource
from synapse.http.servlet import (
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
......@@ -99,26 +100,6 @@ class Authenticator(object):
origin = None
def parse_auth_header(header_str):
try:
params = auth.split(" ")[1].split(",")
param_dict = dict(kv.split("=") for kv in params)
def strip_quotes(value):
if value.startswith("\""):
return value[1:-1]
else:
return value
origin = strip_quotes(param_dict["origin"])
key = strip_quotes(param_dict["key"])
sig = strip_quotes(param_dict["sig"])
return (origin, key, sig)
except Exception:
raise AuthenticationError(
400, "Malformed Authorization header", Codes.UNAUTHORIZED
)
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
if not auth_headers:
......@@ -127,8 +108,8 @@ class Authenticator(object):
)
for auth in auth_headers:
if auth.startswith("X-Matrix"):
(origin, key, sig) = parse_auth_header(auth)
if auth.startswith(b"X-Matrix"):
(origin, key, sig) = _parse_auth_header(auth)
json_request["origin"] = origin
json_request["signatures"].setdefault(origin, {})[key] = sig
......@@ -165,6 +146,47 @@ class Authenticator(object):
logger.exception("Error resetting retry timings on %s", origin)
def _parse_auth_header(header_bytes):
"""Parse an X-Matrix auth header
Args:
header_bytes (bytes): header value
Returns:
Tuple[str, str, str]: origin, key id, signature.
Raises:
AuthenticationError if the header could not be parsed
"""
try:
header_str = header_bytes.decode('utf-8')
params = header_str.split(" ")[1].split(",")
param_dict = dict(kv.split("=") for kv in params)
def strip_quotes(value):
if value.startswith(b"\""):
return value[1:-1]
else:
return value
origin = strip_quotes(param_dict["origin"])
# ensure that the origin is a valid server name
parse_server_name(origin)
key = strip_quotes(param_dict["key"])
sig = strip_quotes(param_dict["sig"])
return origin, key, sig
except Exception as e:
logger.warn(
"Error parsing auth header '%s': %s",
header_bytes.decode('ascii', 'replace'),
e,
)
raise AuthenticationError(
400, "Malformed Authorization header", Codes.UNAUTHORIZED,
)
class BaseFederationServlet(object):
REQUIRE_AUTH = True
......
......@@ -38,6 +38,36 @@ _Server = collections.namedtuple(
)
def parse_server_name(server_name):
"""Split a server name into host/port parts.
Does some basic sanity checking of the
Args:
server_name (str): server name to parse
Returns:
Tuple[str, int|None]: host/port parts.
Raises:
ValueError if the server name could not be parsed.
"""
try:
if server_name[-1] == ']':
# ipv6 literal, hopefully
if server_name[0] != '[':
raise Exception()
return server_name, None
domain_port = server_name.rsplit(":", 1)
domain = domain_port[0]
port = int(domain_port[1]) if domain_port[1:] else None
return domain, port
except Exception:
raise ValueError("Invalid server name '%s'" % server_name)
def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
timeout=None):
"""Construct an endpoint for the given matrix destination.
......@@ -50,9 +80,7 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
timeout (int): connection timeout in seconds
"""
domain_port = destination.split(":")
domain = domain_port[0]
port = int(domain_port[1]) if domain_port[1:] else None
domain, port = parse_server_name(destination)
endpoint_kw_args = {}
......
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.endpoint import parse_server_name
from tests import unittest
class ServerNameTestCase(unittest.TestCase):
def test_parse_server_name(self):
test_data = {
'localhost': ('localhost', None),
'my-example.com:1234': ('my-example.com', 1234),
'1.2.3.4': ('1.2.3.4', None),
'[0abc:1def::1234]': ('[0abc:1def::1234]', None),
'1.2.3.4:1': ('1.2.3.4', 1),
'[0abc:1def::1234]:8080': ('[0abc:1def::1234]', 8080),
}
for i, o in test_data.items():
self.assertEqual(parse_server_name(i), o)
def test_parse_bad_server_names(self):
test_data = [
"", # empty
"localhost:http", # non-numeric port
"1234]", # smells like ipv6 literal but isn't
]
for i in test_data:
try:
parse_server_name(i)
self.fail(
"Expected parse_server_name(\"%s\") to throw" % i,
)
except ValueError:
pass
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment