Skip to content
Snippets Groups Projects
Unverified Commit 8520bc31 authored by Richard van der Hoff's avatar Richard van der Hoff Committed by GitHub
Browse files

Fix Host header sent by MatrixFederationAgent (#4468)

Move the Host header logic down here so that (a) it is used if we reuse the
agent elsewhere, and (b) we can mess about with it with .well-known.
parent 0b3fd140
No related branches found
No related tags found
No related merge requests found
Move SRV logic into the Agent layer
...@@ -19,6 +19,7 @@ from zope.interface import implementer ...@@ -19,6 +19,7 @@ from zope.interface import implementer
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.web.client import URI, Agent, HTTPConnectionPool from twisted.web.client import URI, Agent, HTTPConnectionPool
from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent from twisted.web.iweb import IAgent
from synapse.http.endpoint import parse_server_name from synapse.http.endpoint import parse_server_name
...@@ -109,6 +110,15 @@ class MatrixFederationAgent(object): ...@@ -109,6 +110,15 @@ class MatrixFederationAgent(object):
else: else:
target = pick_server_from_list(server_list) target = pick_server_from_list(server_list)
# make sure that the Host header is set correctly
if headers is None:
headers = Headers()
else:
headers = headers.copy()
if not headers.hasHeader(b'host'):
headers.addRawHeader(b'host', server_name_bytes)
class EndpointFactory(object): class EndpointFactory(object):
@staticmethod @staticmethod
def endpointForURI(_uri): def endpointForURI(_uri):
......
...@@ -255,7 +255,6 @@ class MatrixFederationHttpClient(object): ...@@ -255,7 +255,6 @@ class MatrixFederationHttpClient(object):
headers_dict = { headers_dict = {
b"User-Agent": [self.version_string_bytes], b"User-Agent": [self.version_string_bytes],
b"Host": [destination_bytes],
} }
with limiter: with limiter:
......
...@@ -131,6 +131,10 @@ class MatrixFederationAgentTests(TestCase): ...@@ -131,6 +131,10 @@ class MatrixFederationAgentTests(TestCase):
request = http_server.requests[0] request = http_server.requests[0]
self.assertEqual(request.method, b'GET') self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar') self.assertEqual(request.path, b'/foo/bar')
self.assertEqual(
request.requestHeaders.getRawHeaders(b'host'),
[b'testserv:8448']
)
content = request.content.read() content = request.content.read()
self.assertEqual(content, b'') self.assertEqual(content, b'')
...@@ -195,6 +199,10 @@ class MatrixFederationAgentTests(TestCase): ...@@ -195,6 +199,10 @@ class MatrixFederationAgentTests(TestCase):
request = http_server.requests[0] request = http_server.requests[0]
self.assertEqual(request.method, b'GET') self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar') self.assertEqual(request.path, b'/foo/bar')
self.assertEqual(
request.requestHeaders.getRawHeaders(b'host'),
[b'1.2.3.4'],
)
# finish the request # finish the request
request.finish() request.finish()
...@@ -235,6 +243,10 @@ class MatrixFederationAgentTests(TestCase): ...@@ -235,6 +243,10 @@ class MatrixFederationAgentTests(TestCase):
request = http_server.requests[0] request = http_server.requests[0]
self.assertEqual(request.method, b'GET') self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar') self.assertEqual(request.path, b'/foo/bar')
self.assertEqual(
request.requestHeaders.getRawHeaders(b'host'),
[b'testserv'],
)
# finish the request # finish the request
request.finish() request.finish()
...@@ -276,6 +288,10 @@ class MatrixFederationAgentTests(TestCase): ...@@ -276,6 +288,10 @@ class MatrixFederationAgentTests(TestCase):
request = http_server.requests[0] request = http_server.requests[0]
self.assertEqual(request.method, b'GET') self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar') self.assertEqual(request.path, b'/foo/bar')
self.assertEqual(
request.requestHeaders.getRawHeaders(b'host'),
[b'testserv'],
)
# finish the request # finish the request
request.finish() request.finish()
......
...@@ -49,7 +49,6 @@ class FederationClientTests(HomeserverTestCase): ...@@ -49,7 +49,6 @@ class FederationClientTests(HomeserverTestCase):
return hs return hs
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor, clock, homeserver):
self.cl = MatrixFederationHttpClient(self.hs) self.cl = MatrixFederationHttpClient(self.hs)
self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["testserv"] = "1.2.3.4"
...@@ -95,6 +94,7 @@ class FederationClientTests(HomeserverTestCase): ...@@ -95,6 +94,7 @@ class FederationClientTests(HomeserverTestCase):
# that should have made it send the request to the transport # that should have made it send the request to the transport
self.assertRegex(transport.value(), b"^GET /foo/bar") self.assertRegex(transport.value(), b"^GET /foo/bar")
self.assertRegex(transport.value(), b"Host: testserv:8008")
# Deferred is still without a result # Deferred is still without a result
self.assertNoResult(test_d) self.assertNoResult(test_d)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment