Skip to content
Snippets Groups Projects
Unverified Commit c9b91436 authored by Patrick Cloke's avatar Patrick Cloke Committed by GitHub
Browse files

Fix-up type hints in tests/server.py. (#15084)

This file was being ignored by mypy, we remove that
and add the missing type hints & deal with any fallout.
parent 61bfcd66
No related branches found
No related tags found
No related merge requests found
Improve type hints.
...@@ -31,8 +31,6 @@ exclude = (?x) ...@@ -31,8 +31,6 @@ exclude = (?x)
|synapse/storage/databases/__init__.py |synapse/storage/databases/__init__.py
|synapse/storage/databases/main/cache.py |synapse/storage/databases/main/cache.py
|synapse/storage/schema/ |synapse/storage/schema/
|tests/server.py
)$ )$
[mypy-synapse.federation.transport.client] [mypy-synapse.federation.transport.client]
......
...@@ -11,12 +11,13 @@ ...@@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, cast from typing import List, Optional, Sequence, Tuple, cast
from unittest.mock import Mock from unittest.mock import Mock
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
from twisted.internet import defer from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.appservice import ( from synapse.appservice import (
ApplicationService, ApplicationService,
...@@ -40,9 +41,6 @@ from tests.test_utils import simple_async_mock ...@@ -40,9 +41,6 @@ from tests.test_utils import simple_async_mock
from ..utils import MockClock from ..utils import MockClock
if TYPE_CHECKING:
from twisted.internet.testing import MemoryReactor
class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
......
...@@ -30,7 +30,7 @@ from twisted.internet.interfaces import ( ...@@ -30,7 +30,7 @@ from twisted.internet.interfaces import (
IOpenSSLClientConnectionCreator, IOpenSSLClientConnectionCreator,
IProtocolFactory, IProtocolFactory,
) )
from twisted.internet.protocol import Factory from twisted.internet.protocol import Factory, Protocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web._newclient import ResponseNeverReceived from twisted.web._newclient import ResponseNeverReceived
from twisted.web.client import Agent from twisted.web.client import Agent
...@@ -466,7 +466,8 @@ class MatrixFederationAgentTests(unittest.TestCase): ...@@ -466,7 +466,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
else: else:
assert isinstance(proxy_server_transport, FakeTransport) assert isinstance(proxy_server_transport, FakeTransport)
client_protocol = proxy_server_transport.other client_protocol = proxy_server_transport.other
c2s_transport = client_protocol.transport assert isinstance(client_protocol, Protocol)
c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
c2s_transport.other = server_ssl_protocol c2s_transport.other = server_ssl_protocol
self.reactor.advance(0) self.reactor.advance(0)
......
...@@ -28,7 +28,7 @@ from twisted.internet.endpoints import ( ...@@ -28,7 +28,7 @@ from twisted.internet.endpoints import (
_WrappingProtocol, _WrappingProtocol,
) )
from twisted.internet.interfaces import IProtocol, IProtocolFactory from twisted.internet.interfaces import IProtocol, IProtocolFactory
from twisted.internet.protocol import Factory from twisted.internet.protocol import Factory, Protocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web.http import HTTPChannel from twisted.web.http import HTTPChannel
...@@ -644,7 +644,8 @@ class MatrixFederationAgentTests(TestCase): ...@@ -644,7 +644,8 @@ class MatrixFederationAgentTests(TestCase):
else: else:
assert isinstance(proxy_server_transport, FakeTransport) assert isinstance(proxy_server_transport, FakeTransport)
client_protocol = proxy_server_transport.other client_protocol = proxy_server_transport.other
c2s_transport = client_protocol.transport assert isinstance(client_protocol, Protocol)
c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
c2s_transport.other = server_ssl_protocol c2s_transport.other = server_ssl_protocol
self.reactor.advance(0) self.reactor.advance(0)
......
...@@ -34,7 +34,7 @@ from synapse.util import Clock ...@@ -34,7 +34,7 @@ from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC from tests.handlers.test_oidc import HAS_OIDC
from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER
from tests.server import FakeChannel, make_request from tests.server import FakeChannel
from tests.unittest import override_config, skip_unless from tests.unittest import override_config, skip_unless
...@@ -1322,16 +1322,8 @@ class OidcBackchannelLogoutTests(unittest.HomeserverTestCase): ...@@ -1322,16 +1322,8 @@ class OidcBackchannelLogoutTests(unittest.HomeserverTestCase):
channel = self.submit_logout_token(logout_token) channel = self.submit_logout_token(logout_token)
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
# Now try to exchange the login token # Now try to exchange the login token, it should fail.
channel = make_request( self.helper.login_via_token(login_token, 403)
self.hs.get_reactor(),
self.site,
"POST",
"/login",
content={"type": "m.login.token", "token": login_token},
)
# It should have failed
self.assertEqual(channel.code, 403)
@override_config( @override_config(
{ {
......
...@@ -36,6 +36,7 @@ from urllib.parse import urlencode ...@@ -36,6 +36,7 @@ from urllib.parse import urlencode
import attr import attr
from typing_extensions import Literal from typing_extensions import Literal
from twisted.test.proto_helpers import MemoryReactorClock
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import Site from twisted.web.server import Site
...@@ -67,6 +68,7 @@ class RestHelper: ...@@ -67,6 +68,7 @@ class RestHelper:
""" """
hs: HomeServer hs: HomeServer
reactor: MemoryReactorClock
site: Site site: Site
auth_user_id: Optional[str] auth_user_id: Optional[str]
...@@ -142,7 +144,7 @@ class RestHelper: ...@@ -142,7 +144,7 @@ class RestHelper:
path = path + "?access_token=%s" % tok path = path + "?access_token=%s" % tok
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.reactor,
self.site, self.site,
"POST", "POST",
path, path,
...@@ -216,7 +218,7 @@ class RestHelper: ...@@ -216,7 +218,7 @@ class RestHelper:
data["reason"] = reason data["reason"] = reason
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.reactor,
self.site, self.site,
"POST", "POST",
path, path,
...@@ -313,7 +315,7 @@ class RestHelper: ...@@ -313,7 +315,7 @@ class RestHelper:
data.update(extra_data or {}) data.update(extra_data or {})
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.reactor,
self.site, self.site,
"PUT", "PUT",
path, path,
...@@ -394,7 +396,7 @@ class RestHelper: ...@@ -394,7 +396,7 @@ class RestHelper:
path = path + "?access_token=%s" % tok path = path + "?access_token=%s" % tok
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.reactor,
self.site, self.site,
"PUT", "PUT",
path, path,
...@@ -433,7 +435,7 @@ class RestHelper: ...@@ -433,7 +435,7 @@ class RestHelper:
path = path + f"?access_token={tok}" path = path + f"?access_token={tok}"
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.reactor,
self.site, self.site,
"GET", "GET",
path, path,
...@@ -488,7 +490,7 @@ class RestHelper: ...@@ -488,7 +490,7 @@ class RestHelper:
if body is not None: if body is not None:
content = json.dumps(body).encode("utf8") content = json.dumps(body).encode("utf8")
channel = make_request(self.hs.get_reactor(), self.site, method, path, content) channel = make_request(self.reactor, self.site, method, path, content)
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code, expect_code,
...@@ -573,8 +575,8 @@ class RestHelper: ...@@ -573,8 +575,8 @@ class RestHelper:
image_length = len(image_data) image_length = len(image_data)
path = "/_matrix/media/r0/upload?filename=%s" % (filename,) path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.reactor,
FakeSite(resource, self.hs.get_reactor()), FakeSite(resource, self.reactor),
"POST", "POST",
path, path,
content=image_data, content=image_data,
...@@ -603,7 +605,7 @@ class RestHelper: ...@@ -603,7 +605,7 @@ class RestHelper:
expect_code: The return code to expect from attempting the whoami request expect_code: The return code to expect from attempting the whoami request
""" """
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.reactor,
self.site, self.site,
"GET", "GET",
"account/whoami", "account/whoami",
...@@ -642,7 +644,7 @@ class RestHelper: ...@@ -642,7 +644,7 @@ class RestHelper:
) -> Tuple[JsonDict, FakeAuthorizationGrant]: ) -> Tuple[JsonDict, FakeAuthorizationGrant]:
"""Log in (as a new user) via OIDC """Log in (as a new user) via OIDC
Returns the result of the final token login. Returns the result of the final token login and the fake authorization grant.
Requires that "oidc_config" in the homeserver config be set appropriately Requires that "oidc_config" in the homeserver config be set appropriately
(TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
...@@ -672,10 +674,28 @@ class RestHelper: ...@@ -672,10 +674,28 @@ class RestHelper:
assert m, channel.text_body assert m, channel.text_body
login_token = m.group(1) login_token = m.group(1)
# finally, submit the matrix login token to the login API, which gives us our return self.login_via_token(login_token, expected_status), grant
# matrix access token and device id.
def login_via_token(
self,
login_token: str,
expected_status: int = 200,
) -> JsonDict:
"""Submit the matrix login token to the login API, which gives us our
matrix access token and device id.Log in (as a new user) via OIDC
Returns the result of the token login.
Requires that "oidc_config" in the homeserver config be set appropriately
(TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
"public_base_url".
Also requires the login servlet and the OIDC callback resource to be mounted at
the normal places.
"""
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.reactor,
self.site, self.site,
"POST", "POST",
"/login", "/login",
...@@ -684,7 +704,7 @@ class RestHelper: ...@@ -684,7 +704,7 @@ class RestHelper:
assert ( assert (
channel.code == expected_status channel.code == expected_status
), f"unexpected status in response: {channel.code}" ), f"unexpected status in response: {channel.code}"
return channel.json_body, grant return channel.json_body
def auth_via_oidc( def auth_via_oidc(
self, self,
...@@ -805,7 +825,7 @@ class RestHelper: ...@@ -805,7 +825,7 @@ class RestHelper:
with fake_serer.patch_homeserver(hs=self.hs): with fake_serer.patch_homeserver(hs=self.hs):
# now hit the callback URI with the right params and a made-up code # now hit the callback URI with the right params and a made-up code
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.reactor,
self.site, self.site,
"GET", "GET",
callback_uri, callback_uri,
...@@ -849,7 +869,7 @@ class RestHelper: ...@@ -849,7 +869,7 @@ class RestHelper:
# is the easiest way of figuring out what the Host header ought to be set to # is the easiest way of figuring out what the Host header ought to be set to
# to keep Synapse happy. # to keep Synapse happy.
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.reactor,
self.site, self.site,
"GET", "GET",
uri, uri,
...@@ -867,7 +887,7 @@ class RestHelper: ...@@ -867,7 +887,7 @@ class RestHelper:
location = get_location(channel) location = get_location(channel)
parts = urllib.parse.urlsplit(location) parts = urllib.parse.urlsplit(location)
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.reactor,
self.site, self.site,
"GET", "GET",
urllib.parse.urlunsplit(("", "") + parts[2:]), urllib.parse.urlunsplit(("", "") + parts[2:]),
...@@ -900,9 +920,7 @@ class RestHelper: ...@@ -900,9 +920,7 @@ class RestHelper:
+ urllib.parse.urlencode({"session": ui_auth_session_id}) + urllib.parse.urlencode({"session": ui_auth_session_id})
) )
# hit the redirect url (which will issue a cookie and state) # hit the redirect url (which will issue a cookie and state)
channel = make_request( channel = make_request(self.reactor, self.site, "GET", sso_redirect_endpoint)
self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint
)
# that should serve a confirmation page # that should serve a confirmation page
assert channel.code == HTTPStatus.OK, channel.text_body assert channel.code == HTTPStatus.OK, channel.text_body
channel.extract_cookies(cookies) channel.extract_cookies(cookies)
......
This diff is collapsed.
...@@ -45,7 +45,7 @@ from typing_extensions import Concatenate, ParamSpec, Protocol ...@@ -45,7 +45,7 @@ from typing_extensions import Concatenate, ParamSpec, Protocol
from twisted.internet.defer import Deferred, ensureDeferred from twisted.internet.defer import Deferred, ensureDeferred
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool from twisted.python.threadpool import ThreadPool
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor, MemoryReactorClock
from twisted.trial import unittest from twisted.trial import unittest
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import Request from twisted.web.server import Request
...@@ -82,7 +82,7 @@ from tests.server import ( ...@@ -82,7 +82,7 @@ from tests.server import (
) )
from tests.test_utils import event_injection, setup_awaitable_errors from tests.test_utils import event_injection, setup_awaitable_errors
from tests.test_utils.logging_setup import setup_logging from tests.test_utils.logging_setup import setup_logging
from tests.utils import default_config, setupdb from tests.utils import checked_cast, default_config, setupdb
setupdb() setupdb()
setup_logging() setup_logging()
...@@ -296,7 +296,12 @@ class HomeserverTestCase(TestCase): ...@@ -296,7 +296,12 @@ class HomeserverTestCase(TestCase):
from tests.rest.client.utils import RestHelper from tests.rest.client.utils import RestHelper
self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None)) self.helper = RestHelper(
self.hs,
checked_cast(MemoryReactorClock, self.hs.get_reactor()),
self.site,
getattr(self, "user_id", None),
)
if hasattr(self, "user_id"): if hasattr(self, "user_id"):
if self.hijack_auth: if self.hijack_auth:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment