Skip to content
Snippets Groups Projects
Unverified Commit 97fd29c0 authored by Richard van der Hoff's avatar Richard van der Hoff Committed by GitHub
Browse files

Don't send IP addresses as SNI (#4452)

The problem here is that we have cut-and-pasted an impl from Twisted, and then
failed to maintain it. It was fixed in Twisted in
https://github.com/twisted/twisted/pull/1047/files; let's do the same here.
parent 6b90ae6e
No related branches found
No related tags found
No related merge requests found
Don't send IP addresses as SNI
......@@ -17,6 +17,7 @@ from zope.interface import implementer
from OpenSSL import SSL, crypto
from twisted.internet._sslverify import _defaultCurveName
from twisted.internet.abstract import isIPAddress, isIPv6Address
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
from twisted.internet.ssl import CertificateOptions, ContextFactory
from twisted.python.failure import Failure
......@@ -98,8 +99,14 @@ class ClientTLSOptions(object):
def __init__(self, hostname, ctx):
self._ctx = ctx
self._hostname = hostname
self._hostnameBytes = _idnaBytes(hostname)
if isIPAddress(hostname) or isIPv6Address(hostname):
self._hostnameBytes = hostname.encode('ascii')
self._sendSNI = False
else:
self._hostnameBytes = _idnaBytes(hostname)
self._sendSNI = True
ctx.set_info_callback(
_tolerateErrors(self._identityVerifyingInfoCallback)
)
......@@ -111,7 +118,9 @@ class ClientTLSOptions(object):
return connection
def _identityVerifyingInfoCallback(self, connection, where, ret):
if where & SSL.SSL_CB_HANDSHAKE_START:
# Literal IPv4 and IPv6 addresses are not permitted
# as host names according to the RFCs
if where & SSL.SSL_CB_HANDSHAKE_START and self._sendSNI:
connection.set_tlsext_host_name(self._hostnameBytes)
......
......@@ -46,7 +46,7 @@ class MatrixFederationAgentTests(TestCase):
_srv_resolver=self.mock_resolver,
)
def _make_connection(self, client_factory):
def _make_connection(self, client_factory, expected_sni):
"""Builds a test server, and completes the outgoing client connection
Returns:
......@@ -69,9 +69,17 @@ class MatrixFederationAgentTests(TestCase):
# tell the server tls protocol to send its stuff back to the client, too
server_tls_protocol.makeConnection(FakeTransport(client_protocol, self.reactor))
# finally, give the reactor a pump to get the TLS juices flowing.
# give the reactor a pump to get the TLS juices flowing.
self.reactor.pump((0.1,))
# check the SNI
server_name = server_tls_protocol._tlsConnection.get_servername()
self.assertEqual(
server_name,
expected_sni,
"Expected SNI %s but got %s" % (expected_sni, server_name),
)
# fish the test server back out of the server-side TLS protocol.
return server_tls_protocol.wrappedProtocol
......@@ -113,7 +121,10 @@ class MatrixFederationAgentTests(TestCase):
self.assertEqual(port, 8448)
# make a test server, and wire up the client
http_server = self._make_connection(client_factory)
http_server = self._make_connection(
client_factory,
expected_sni=b"testserv",
)
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
......@@ -150,6 +161,52 @@ class MatrixFederationAgentTests(TestCase):
json = self.successResultOf(treq.json_content(response))
self.assertEqual(json, {"a": 1})
def test_get_ip_address(self):
"""
Test the behaviour when the server name contains an explicit IP (with no port)
"""
# the SRV lookup will return an empty list (XXX: why do we even do an SRV lookup?)
self.mock_resolver.resolve_service.side_effect = lambda _: []
# then there will be a getaddrinfo on the IP
self.reactor.lookups["1.2.3.4"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://1.2.3.4/foo/bar")
# Nothing happened yet
self.assertNoResult(test_d)
self.mock_resolver.resolve_service.assert_called_once()
# Make sure treq is trying to connect
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 8448)
# make a test server, and wire up the client
http_server = self._make_connection(
client_factory,
expected_sni=None,
)
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
# XXX currently broken
# self.assertEqual(
# request.requestHeaders.getRawHeaders(b'host'),
# [b'1.2.3.4:8448']
# )
# finish the request
request.finish()
self.reactor.pump((0.1,))
self.successResultOf(test_d)
def _check_logcontext(context):
current = LoggingContext.current_context()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment