Skip to content
Snippets Groups Projects
Commit f65e31d2 authored by Richard van der Hoff's avatar Richard van der Hoff Committed by GitHub
Browse files

Do an AAAA lookup on SRV record targets (#2462)

Support SRV records which point at AAAA records, as well as A records.

Fixes https://github.com/matrix-org/synapse/issues/2405
parent f496399a
Branches
Tags
No related merge requests found
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import socket
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
...@@ -30,7 +31,10 @@ logger = logging.getLogger(__name__) ...@@ -30,7 +31,10 @@ logger = logging.getLogger(__name__)
SERVER_CACHE = {} SERVER_CACHE = {}
# our record of an individual server which can be tried to reach a destination.
#
# "host" is actually a dotted-quad or ipv6 address string. Except when there's
# no SRV record, in which case it is the original hostname.
_Server = collections.namedtuple( _Server = collections.namedtuple(
"_Server", "priority weight host port expires" "_Server", "priority weight host port expires"
) )
...@@ -219,9 +223,10 @@ class SRVClientEndpoint(object): ...@@ -219,9 +223,10 @@ class SRVClientEndpoint(object):
return self.default_server return self.default_server
else: else:
raise ConnectError( raise ConnectError(
"Not server available for %s" % self.service_name "No server available for %s" % self.service_name
) )
# look for all servers with the same priority
min_priority = self.servers[0].priority min_priority = self.servers[0].priority
weight_indexes = list( weight_indexes = list(
(index, server.weight + 1) (index, server.weight + 1)
...@@ -231,11 +236,22 @@ class SRVClientEndpoint(object): ...@@ -231,11 +236,22 @@ class SRVClientEndpoint(object):
total_weight = sum(weight for index, weight in weight_indexes) total_weight = sum(weight for index, weight in weight_indexes)
target_weight = random.randint(0, total_weight) target_weight = random.randint(0, total_weight)
for index, weight in weight_indexes: for index, weight in weight_indexes:
target_weight -= weight target_weight -= weight
if target_weight <= 0: if target_weight <= 0:
server = self.servers[index] server = self.servers[index]
# XXX: this looks totally dubious:
#
# (a) we never reuse a server until we have been through
# all of the servers at the same priority, so if the
# weights are A: 100, B:1, we always do ABABAB instead of
# AAAA...AAAB (approximately).
#
# (b) After using all the servers at the lowest priority,
# we move onto the next priority. We should only use the
# second priority if servers at the top priority are
# unreachable.
#
del self.servers[index] del self.servers[index]
self.used_servers.append(server) self.used_servers.append(server)
return server return server
...@@ -280,26 +296,21 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t ...@@ -280,26 +296,21 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
continue continue
payload = answer.payload payload = answer.payload
host = str(payload.target)
srv_ttl = answer.ttl
try: hosts = yield _get_hosts_for_srv_record(
answers, _, _ = yield dns_client.lookupAddress(host) dns_client, str(payload.target)
except DNSNameError: )
continue
for answer in answers: for (ip, ttl) in hosts:
if answer.type == dns.A and answer.payload: host_ttl = min(answer.ttl, ttl)
ip = answer.payload.dottedQuad()
host_ttl = min(srv_ttl, answer.ttl)
servers.append(_Server( servers.append(_Server(
host=ip, host=ip,
port=int(payload.port), port=int(payload.port),
priority=int(payload.priority), priority=int(payload.priority),
weight=int(payload.weight), weight=int(payload.weight),
expires=int(clock.time()) + host_ttl, expires=int(clock.time()) + host_ttl,
)) ))
servers.sort() servers.sort()
cache[service_name] = list(servers) cache[service_name] = list(servers)
...@@ -317,3 +328,68 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t ...@@ -317,3 +328,68 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
raise e raise e
defer.returnValue(servers) defer.returnValue(servers)
@defer.inlineCallbacks
def _get_hosts_for_srv_record(dns_client, host):
"""Look up each of the hosts in a SRV record
Args:
dns_client (twisted.names.dns.IResolver):
host (basestring): host to look up
Returns:
Deferred[list[(str, int)]]: a list of (host, ttl) pairs
"""
ip4_servers = []
ip6_servers = []
def cb(res):
# lookupAddress and lookupIP6Address return a three-tuple
# giving the answer, authority, and additional sections of the
# response.
#
# we only care about the answers.
return res[0]
def eb(res):
res.trap(DNSNameError)
return []
# no logcontexts here, so we can safely fire these off and gatherResults
d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb)
d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb)
results = yield defer.gatherResults([d1, d2], consumeErrors=True)
for result in results:
for answer in result:
if not answer.payload:
continue
try:
if answer.type == dns.A:
ip = answer.payload.dottedQuad()
ip4_servers.append((ip, answer.ttl))
elif answer.type == dns.AAAA:
ip = socket.inet_ntop(
socket.AF_INET6, answer.payload.address,
)
ip6_servers.append((ip, answer.ttl))
else:
# the most likely candidate here is a CNAME record.
# rfc2782 says srvs may not point to aliases.
logger.warn(
"Ignoring unexpected DNS record type %s for %s",
answer.type, host,
)
continue
except Exception as e:
logger.warn("Ignoring invalid DNS response for %s: %s",
host, e)
continue
# keep the ipv4 results before the ipv6 results, mostly to match historical
# behaviour.
defer.returnValue(ip4_servers + ip6_servers)
...@@ -24,15 +24,17 @@ from synapse.http.endpoint import resolve_service ...@@ -24,15 +24,17 @@ from synapse.http.endpoint import resolve_service
from tests.utils import MockClock from tests.utils import MockClock
@unittest.DEBUG
class DnsTestCase(unittest.TestCase): class DnsTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_resolve(self): def test_resolve(self):
dns_client_mock = Mock() dns_client_mock = Mock()
service_name = "test_service.examle.com" service_name = "test_service.example.com"
host_name = "example.com" host_name = "example.com"
ip_address = "127.0.0.1" ip_address = "127.0.0.1"
ip6_address = "::1"
answer_srv = dns.RRHeader( answer_srv = dns.RRHeader(
type=dns.SRV, type=dns.SRV,
...@@ -48,8 +50,22 @@ class DnsTestCase(unittest.TestCase): ...@@ -48,8 +50,22 @@ class DnsTestCase(unittest.TestCase):
) )
) )
dns_client_mock.lookupService.return_value = ([answer_srv], None, None) answer_aaaa = dns.RRHeader(
dns_client_mock.lookupAddress.return_value = ([answer_a], None, None) type=dns.AAAA,
payload=dns.Record_AAAA(
address=ip6_address,
)
)
dns_client_mock.lookupService.return_value = defer.succeed(
([answer_srv], None, None),
)
dns_client_mock.lookupAddress.return_value = defer.succeed(
([answer_a], None, None),
)
dns_client_mock.lookupIPV6Address.return_value = defer.succeed(
([answer_aaaa], None, None),
)
cache = {} cache = {}
...@@ -59,10 +75,12 @@ class DnsTestCase(unittest.TestCase): ...@@ -59,10 +75,12 @@ class DnsTestCase(unittest.TestCase):
dns_client_mock.lookupService.assert_called_once_with(service_name) dns_client_mock.lookupService.assert_called_once_with(service_name)
dns_client_mock.lookupAddress.assert_called_once_with(host_name) dns_client_mock.lookupAddress.assert_called_once_with(host_name)
dns_client_mock.lookupIPV6Address.assert_called_once_with(host_name)
self.assertEquals(len(servers), 1) self.assertEquals(len(servers), 2)
self.assertEquals(servers, cache[service_name]) self.assertEquals(servers, cache[service_name])
self.assertEquals(servers[0].host, ip_address) self.assertEquals(servers[0].host, ip_address)
self.assertEquals(servers[1].host, ip6_address)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_from_cache_expired_and_dns_fail(self): def test_from_cache_expired_and_dns_fail(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment