Skip to content
Snippets Groups Projects
Unverified Commit 3c0213a2 authored by Richard van der Hoff's avatar Richard van der Hoff Committed by GitHub
Browse files

Merge pull request #3439 from vojeroen/send_sni_for_federation_requests

send SNI for federation requests
parents 0ad98e38 2e9c73e8
No related branches found
No related tags found
No related merge requests found
Add support for the SNI extension to federation TLS connections
\ No newline at end of file
...@@ -168,11 +168,13 @@ def start(config_options): ...@@ -168,11 +168,13 @@ def start(config_options):
database_engine = create_engine(config.database_config) database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config) tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ss = ClientReaderServer( ss = ClientReaderServer(
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory, tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,
......
...@@ -174,11 +174,13 @@ def start(config_options): ...@@ -174,11 +174,13 @@ def start(config_options):
database_engine = create_engine(config.database_config) database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config) tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ss = EventCreatorServer( ss = EventCreatorServer(
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory, tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,
......
...@@ -143,11 +143,13 @@ def start(config_options): ...@@ -143,11 +143,13 @@ def start(config_options):
database_engine = create_engine(config.database_config) database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config) tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ss = FederationReaderServer( ss = FederationReaderServer(
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory, tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,
......
...@@ -186,11 +186,13 @@ def start(config_options): ...@@ -186,11 +186,13 @@ def start(config_options):
config.send_federation = True config.send_federation = True
tls_server_context_factory = context_factory.ServerContextFactory(config) tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ps = FederationSenderServer( ps = FederationSenderServer(
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory, tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,
......
...@@ -208,11 +208,13 @@ def start(config_options): ...@@ -208,11 +208,13 @@ def start(config_options):
database_engine = create_engine(config.database_config) database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config) tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ss = FrontendProxyServer( ss = FrontendProxyServer(
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory, tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,
......
...@@ -338,6 +338,7 @@ def setup(config_options): ...@@ -338,6 +338,7 @@ def setup(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts events.USE_FROZEN_DICTS = config.use_frozen_dicts
tls_server_context_factory = context_factory.ServerContextFactory(config) tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
database_engine = create_engine(config.database_config) database_engine = create_engine(config.database_config)
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
...@@ -346,6 +347,7 @@ def setup(config_options): ...@@ -346,6 +347,7 @@ def setup(config_options):
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory, tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,
......
...@@ -155,11 +155,13 @@ def start(config_options): ...@@ -155,11 +155,13 @@ def start(config_options):
database_engine = create_engine(config.database_config) database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config) tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ss = MediaRepositoryServer( ss = MediaRepositoryServer(
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory, tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,
......
...@@ -214,11 +214,13 @@ def start(config_options): ...@@ -214,11 +214,13 @@ def start(config_options):
config.update_user_directory = True config.update_user_directory = True
tls_server_context_factory = context_factory.ServerContextFactory(config) tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ps = UserDirectoryServer( ps = UserDirectoryServer(
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory, tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,
......
...@@ -11,19 +11,22 @@ ...@@ -11,19 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from zope.interface import implementer
from OpenSSL import SSL, crypto from OpenSSL import SSL, crypto
from twisted.internet import ssl
from twisted.internet._sslverify import _defaultCurveName from twisted.internet._sslverify import _defaultCurveName
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
from twisted.internet.ssl import CertificateOptions, ContextFactory
from twisted.python.failure import Failure
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ServerContextFactory(ssl.ContextFactory): class ServerContextFactory(ContextFactory):
"""Factory for PyOpenSSL SSL contexts that are used to handle incoming """Factory for PyOpenSSL SSL contexts that are used to handle incoming
connections and to make connections to remote servers.""" connections."""
def __init__(self, config): def __init__(self, config):
self._context = SSL.Context(SSL.SSLv23_METHOD) self._context = SSL.Context(SSL.SSLv23_METHOD)
...@@ -48,3 +51,78 @@ class ServerContextFactory(ssl.ContextFactory): ...@@ -48,3 +51,78 @@ class ServerContextFactory(ssl.ContextFactory):
def getContext(self): def getContext(self):
return self._context return self._context
def _idnaBytes(text):
"""
Convert some text typed by a human into some ASCII bytes. This is a
copy of twisted.internet._idna._idnaBytes. For documentation, see the
twisted documentation.
"""
try:
import idna
except ImportError:
return text.encode("idna")
else:
return idna.encode(text)
def _tolerateErrors(wrapped):
"""
Wrap up an info_callback for pyOpenSSL so that if something goes wrong
the error is immediately logged and the connection is dropped if possible.
This is a copy of twisted.internet._sslverify._tolerateErrors. For
documentation, see the twisted documentation.
"""
def infoCallback(connection, where, ret):
try:
return wrapped(connection, where, ret)
except: # noqa: E722, taken from the twisted implementation
f = Failure()
logger.exception("Error during info_callback")
connection.get_app_data().failVerification(f)
return infoCallback
@implementer(IOpenSSLClientConnectionCreator)
class ClientTLSOptions(object):
"""
Client creator for TLS without certificate identity verification. This is a
copy of twisted.internet._sslverify.ClientTLSOptions with the identity
verification left out. For documentation, see the twisted documentation.
"""
def __init__(self, hostname, ctx):
self._ctx = ctx
self._hostname = hostname
self._hostnameBytes = _idnaBytes(hostname)
ctx.set_info_callback(
_tolerateErrors(self._identityVerifyingInfoCallback)
)
def clientConnectionForTLS(self, tlsProtocol):
context = self._ctx
connection = SSL.Connection(context, None)
connection.set_app_data(tlsProtocol)
return connection
def _identityVerifyingInfoCallback(self, connection, where, ret):
if where & SSL.SSL_CB_HANDSHAKE_START:
connection.set_tlsext_host_name(self._hostnameBytes)
class ClientTLSOptionsFactory(object):
"""Factory for Twisted ClientTLSOptions that are used to make connections
to remote servers for federation."""
def __init__(self, config):
# We don't use config options yet
pass
def get_options(self, host):
return ClientTLSOptions(
host.decode('utf-8'),
CertificateOptions(verify=False).getContext()
)
...@@ -30,14 +30,14 @@ KEY_API_V1 = b"/_matrix/key/v1/" ...@@ -30,14 +30,14 @@ KEY_API_V1 = b"/_matrix/key/v1/"
@defer.inlineCallbacks @defer.inlineCallbacks
def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1): def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1):
"""Fetch the keys for a remote server.""" """Fetch the keys for a remote server."""
factory = SynapseKeyClientFactory() factory = SynapseKeyClientFactory()
factory.path = path factory.path = path
factory.host = server_name factory.host = server_name
endpoint = matrix_federation_endpoint( endpoint = matrix_federation_endpoint(
reactor, server_name, ssl_context_factory, timeout=30 reactor, server_name, tls_client_options_factory, timeout=30
) )
for i in range(5): for i in range(5):
......
...@@ -512,7 +512,7 @@ class Keyring(object): ...@@ -512,7 +512,7 @@ class Keyring(object):
continue continue
(response, tls_certificate) = yield fetch_server_key( (response, tls_certificate) = yield fetch_server_key(
server_name, self.hs.tls_server_context_factory, server_name, self.hs.tls_client_options_factory,
path=(b"/_matrix/key/v2/server/%s" % ( path=(b"/_matrix/key/v2/server/%s" % (
urllib.quote(requested_key_id), urllib.quote(requested_key_id),
)).encode("ascii"), )).encode("ascii"),
...@@ -655,7 +655,7 @@ class Keyring(object): ...@@ -655,7 +655,7 @@ class Keyring(object):
# Try to fetch the key from the remote server. # Try to fetch the key from the remote server.
(response, tls_certificate) = yield fetch_server_key( (response, tls_certificate) = yield fetch_server_key(
server_name, self.hs.tls_server_context_factory server_name, self.hs.tls_client_options_factory
) )
# Check the response. # Check the response.
......
...@@ -26,7 +26,6 @@ from twisted.names.error import DNSNameError, DomainError ...@@ -26,7 +26,6 @@ from twisted.names.error import DNSNameError, DomainError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SERVER_CACHE = {} SERVER_CACHE = {}
# our record of an individual server which can be tried to reach a destination. # our record of an individual server which can be tried to reach a destination.
...@@ -103,15 +102,16 @@ def parse_and_validate_server_name(server_name): ...@@ -103,15 +102,16 @@ def parse_and_validate_server_name(server_name):
return host, port return host, port
def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, def matrix_federation_endpoint(reactor, destination, tls_client_options_factory=None,
timeout=None): timeout=None):
"""Construct an endpoint for the given matrix destination. """Construct an endpoint for the given matrix destination.
Args: Args:
reactor: Twisted reactor. reactor: Twisted reactor.
destination (bytes): The name of the server to connect to. destination (bytes): The name of the server to connect to.
ssl_context_factory (twisted.internet.ssl.ContextFactory): Factory tls_client_options_factory
which generates SSL contexts to use for TLS. (synapse.crypto.context_factory.ClientTLSOptionsFactory):
Factory which generates TLS options for client connections.
timeout (int): connection timeout in seconds timeout (int): connection timeout in seconds
""" """
...@@ -122,13 +122,13 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, ...@@ -122,13 +122,13 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
if timeout is not None: if timeout is not None:
endpoint_kw_args.update(timeout=timeout) endpoint_kw_args.update(timeout=timeout)
if ssl_context_factory is None: if tls_client_options_factory is None:
transport_endpoint = HostnameEndpoint transport_endpoint = HostnameEndpoint
default_port = 8008 default_port = 8008
else: else:
def transport_endpoint(reactor, host, port, timeout): def transport_endpoint(reactor, host, port, timeout):
return wrapClientTLS( return wrapClientTLS(
ssl_context_factory, tls_client_options_factory.get_options(host),
HostnameEndpoint(reactor, host, port, timeout=timeout)) HostnameEndpoint(reactor, host, port, timeout=timeout))
default_port = 8448 default_port = 8448
......
...@@ -61,14 +61,14 @@ MAX_SHORT_RETRIES = 3 ...@@ -61,14 +61,14 @@ MAX_SHORT_RETRIES = 3
class MatrixFederationEndpointFactory(object): class MatrixFederationEndpointFactory(object):
def __init__(self, hs): def __init__(self, hs):
self.tls_server_context_factory = hs.tls_server_context_factory self.tls_client_options_factory = hs.tls_client_options_factory
def endpointForURI(self, uri): def endpointForURI(self, uri):
destination = uri.netloc destination = uri.netloc
return matrix_federation_endpoint( return matrix_federation_endpoint(
reactor, destination, timeout=10, reactor, destination, timeout=10,
ssl_context_factory=self.tls_server_context_factory tls_client_options_factory=self.tls_client_options_factory
) )
......
...@@ -127,6 +127,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None ...@@ -127,6 +127,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
database_engine=db_engine, database_engine=db_engine,
room_list_handler=object(), room_list_handler=object(),
tls_server_context_factory=Mock(), tls_server_context_factory=Mock(),
tls_client_options_factory=Mock(),
reactor=reactor, reactor=reactor,
**kargs **kargs
) )
...@@ -147,6 +148,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None ...@@ -147,6 +148,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
database_engine=db_engine, database_engine=db_engine,
room_list_handler=object(), room_list_handler=object(),
tls_server_context_factory=Mock(), tls_server_context_factory=Mock(),
tls_client_options_factory=Mock(),
reactor=reactor, reactor=reactor,
**kargs **kargs
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment