Skip to content
Snippets Groups Projects
Unverified Commit b9325908 authored by Patrick Cloke's avatar Patrick Cloke Committed by GitHub
Browse files

Add missing type hints to non-client REST servlets. (#10817)

Including admin, consent, key, synapse, and media. All REST servlets
(the synapse.rest module) now require typed method definitions.
parent 8c7a531e
No related branches found
No related tags found
No related merge requests found
Showing
with 129 additions and 74 deletions
Convert the internal `FileInfo` class to attrs and add type hints. Add missing type hints to REST servlets.
Add missing type hints to REST servlets.
...@@ -90,7 +90,7 @@ files = ...@@ -90,7 +90,7 @@ files =
tests/util/test_itertools.py, tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py tests/util/test_stream_change_cache.py
[mypy-synapse.rest.client.*] [mypy-synapse.rest.*]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-synapse.util.batching_queue] [mypy-synapse.util.batching_queue]
......
...@@ -12,7 +12,9 @@ ...@@ -12,7 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.http.server import JsonResource from typing import TYPE_CHECKING
from synapse.http.server import HttpServer, JsonResource
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import ( from synapse.rest.client import (
account, account,
...@@ -57,6 +59,9 @@ from synapse.rest.client import ( ...@@ -57,6 +59,9 @@ from synapse.rest.client import (
voip, voip,
) )
if TYPE_CHECKING:
from synapse.server import HomeServer
class ClientRestResource(JsonResource): class ClientRestResource(JsonResource):
"""Matrix Client API REST resource. """Matrix Client API REST resource.
...@@ -68,12 +73,12 @@ class ClientRestResource(JsonResource): ...@@ -68,12 +73,12 @@ class ClientRestResource(JsonResource):
* etc * etc
""" """
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
JsonResource.__init__(self, hs, canonical_json=False) JsonResource.__init__(self, hs, canonical_json=False)
self.register_servlets(self, hs) self.register_servlets(self, hs)
@staticmethod @staticmethod
def register_servlets(client_resource, hs): def register_servlets(client_resource: HttpServer, hs: "HomeServer") -> None:
versions.register_servlets(hs, client_resource) versions.register_servlets(hs, client_resource)
# Deprecated in r0 # Deprecated in r0
......
...@@ -47,7 +47,7 @@ class DeviceRestServlet(RestServlet): ...@@ -47,7 +47,7 @@ class DeviceRestServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def on_GET( async def on_GET(
self, request: SynapseRequest, user_id, device_id: str self, request: SynapseRequest, user_id: str, device_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
......
...@@ -57,7 +57,7 @@ class SendServerNoticeServlet(RestServlet): ...@@ -57,7 +57,7 @@ class SendServerNoticeServlet(RestServlet):
self.admin_handler = hs.get_admin_handler() self.admin_handler = hs.get_admin_handler()
self.txns = HttpTransactionCache(hs) self.txns = HttpTransactionCache(hs)
def register(self, json_resource: HttpServer): def register(self, json_resource: HttpServer) -> None:
PATTERN = "/send_server_notice" PATTERN = "/send_server_notice"
json_resource.register_paths( json_resource.register_paths(
"POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__ "POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__
......
...@@ -419,7 +419,7 @@ class UserRegisterServlet(RestServlet): ...@@ -419,7 +419,7 @@ class UserRegisterServlet(RestServlet):
self.nonces: Dict[str, int] = {} self.nonces: Dict[str, int] = {}
self.hs = hs self.hs = hs
def _clear_old_nonces(self): def _clear_old_nonces(self) -> None:
""" """
Clear out old nonces that are older than NONCE_TIMEOUT. Clear out old nonces that are older than NONCE_TIMEOUT.
""" """
......
...@@ -17,17 +17,22 @@ import logging ...@@ -17,17 +17,22 @@ import logging
from hashlib import sha256 from hashlib import sha256
from http import HTTPStatus from http import HTTPStatus
from os import path from os import path
from typing import Dict, List from typing import TYPE_CHECKING, Any, Dict, List
import jinja2 import jinja2
from jinja2 import TemplateNotFound from jinja2 import TemplateNotFound
from twisted.web.server import Request
from synapse.api.errors import NotFoundError, StoreError, SynapseError from synapse.api.errors import NotFoundError, StoreError, SynapseError
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.http.server import DirectServeHtmlResource, respond_with_html from synapse.http.server import DirectServeHtmlResource, respond_with_html
from synapse.http.servlet import parse_bytes_from_args, parse_string from synapse.http.servlet import parse_bytes_from_args, parse_string
from synapse.types import UserID from synapse.types import UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
# language to use for the templates. TODO: figure this out from Accept-Language # language to use for the templates. TODO: figure this out from Accept-Language
TEMPLATE_LANGUAGE = "en" TEMPLATE_LANGUAGE = "en"
...@@ -69,11 +74,7 @@ class ConsentResource(DirectServeHtmlResource): ...@@ -69,11 +74,7 @@ class ConsentResource(DirectServeHtmlResource):
against the user. against the user.
""" """
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): homeserver
"""
super().__init__() super().__init__()
self.hs = hs self.hs = hs
...@@ -106,18 +107,14 @@ class ConsentResource(DirectServeHtmlResource): ...@@ -106,18 +107,14 @@ class ConsentResource(DirectServeHtmlResource):
self._hmac_secret = hs.config.form_secret.encode("utf-8") self._hmac_secret = hs.config.form_secret.encode("utf-8")
async def _async_render_GET(self, request): async def _async_render_GET(self, request: Request) -> None:
"""
Args:
request (twisted.web.http.Request):
"""
version = parse_string(request, "v", default=self._default_consent_version) version = parse_string(request, "v", default=self._default_consent_version)
username = parse_string(request, "u", default="") username = parse_string(request, "u", default="")
userhmac = None userhmac = None
has_consented = False has_consented = False
public_version = username == "" public_version = username == ""
if not public_version: if not public_version:
args: Dict[bytes, List[bytes]] = request.args args: Dict[bytes, List[bytes]] = request.args # type: ignore
userhmac_bytes = parse_bytes_from_args(args, "h", required=True) userhmac_bytes = parse_bytes_from_args(args, "h", required=True)
self._check_hash(username, userhmac_bytes) self._check_hash(username, userhmac_bytes)
...@@ -147,14 +144,10 @@ class ConsentResource(DirectServeHtmlResource): ...@@ -147,14 +144,10 @@ class ConsentResource(DirectServeHtmlResource):
except TemplateNotFound: except TemplateNotFound:
raise NotFoundError("Unknown policy version") raise NotFoundError("Unknown policy version")
async def _async_render_POST(self, request): async def _async_render_POST(self, request: Request) -> None:
"""
Args:
request (twisted.web.http.Request):
"""
version = parse_string(request, "v", required=True) version = parse_string(request, "v", required=True)
username = parse_string(request, "u", required=True) username = parse_string(request, "u", required=True)
args: Dict[bytes, List[bytes]] = request.args args: Dict[bytes, List[bytes]] = request.args # type: ignore
userhmac = parse_bytes_from_args(args, "h", required=True) userhmac = parse_bytes_from_args(args, "h", required=True)
self._check_hash(username, userhmac) self._check_hash(username, userhmac)
...@@ -177,7 +170,9 @@ class ConsentResource(DirectServeHtmlResource): ...@@ -177,7 +170,9 @@ class ConsentResource(DirectServeHtmlResource):
except TemplateNotFound: except TemplateNotFound:
raise NotFoundError("success.html not found") raise NotFoundError("success.html not found")
def _render_template(self, request, template_name, **template_args): def _render_template(
self, request: Request, template_name: str, **template_args: Any
) -> None:
# get_template checks for ".." so we don't need to worry too much # get_template checks for ".." so we don't need to worry too much
# about path traversal here. # about path traversal here.
template_html = self._jinja_env.get_template( template_html = self._jinja_env.get_template(
...@@ -186,11 +181,11 @@ class ConsentResource(DirectServeHtmlResource): ...@@ -186,11 +181,11 @@ class ConsentResource(DirectServeHtmlResource):
html = template_html.render(**template_args) html = template_html.render(**template_args)
respond_with_html(request, 200, html) respond_with_html(request, 200, html)
def _check_hash(self, userid, userhmac): def _check_hash(self, userid: str, userhmac: bytes) -> None:
""" """
Args: Args:
userid (unicode): userid:
userhmac (bytes): userhmac:
Raises: Raises:
SynapseError if the hash doesn't match SynapseError if the hash doesn't match
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import Request
class HealthResource(Resource): class HealthResource(Resource):
...@@ -25,6 +26,6 @@ class HealthResource(Resource): ...@@ -25,6 +26,6 @@ class HealthResource(Resource):
isLeaf = 1 isLeaf = 1
def render_GET(self, request): def render_GET(self, request: Request) -> bytes:
request.setHeader(b"Content-Type", b"text/plain") request.setHeader(b"Content-Type", b"text/plain")
return b"OK" return b"OK"
...@@ -12,14 +12,19 @@ ...@@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING
from twisted.web.resource import Resource from twisted.web.resource import Resource
from .local_key_resource import LocalKey from .local_key_resource import LocalKey
from .remote_key_resource import RemoteKey from .remote_key_resource import RemoteKey
if TYPE_CHECKING:
from synapse.server import HomeServer
class KeyApiV2Resource(Resource): class KeyApiV2Resource(Resource):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
Resource.__init__(self) Resource.__init__(self)
self.putChild(b"server", LocalKey(hs)) self.putChild(b"server", LocalKey(hs))
self.putChild(b"query", RemoteKey(hs)) self.putChild(b"query", RemoteKey(hs))
...@@ -12,16 +12,21 @@ ...@@ -12,16 +12,21 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from signedjson.sign import sign_json from signedjson.sign import sign_json
from unpaddedbase64 import encode_base64 from unpaddedbase64 import encode_base64
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import Request
from synapse.http.server import respond_with_json_bytes from synapse.http.server import respond_with_json_bytes
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -58,18 +63,18 @@ class LocalKey(Resource): ...@@ -58,18 +63,18 @@ class LocalKey(Resource):
isLeaf = True isLeaf = True
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.config = hs.config self.config = hs.config
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.update_response_body(self.clock.time_msec()) self.update_response_body(self.clock.time_msec())
Resource.__init__(self) Resource.__init__(self)
def update_response_body(self, time_now_msec): def update_response_body(self, time_now_msec: int) -> None:
refresh_interval = self.config.key_refresh_interval refresh_interval = self.config.key_refresh_interval
self.valid_until_ts = int(time_now_msec + refresh_interval) self.valid_until_ts = int(time_now_msec + refresh_interval)
self.response_body = encode_canonical_json(self.response_json_object()) self.response_body = encode_canonical_json(self.response_json_object())
def response_json_object(self): def response_json_object(self) -> JsonDict:
verify_keys = {} verify_keys = {}
for key in self.config.signing_key: for key in self.config.signing_key:
verify_key_bytes = key.verify_key.encode() verify_key_bytes = key.verify_key.encode()
...@@ -94,7 +99,7 @@ class LocalKey(Resource): ...@@ -94,7 +99,7 @@ class LocalKey(Resource):
json_object = sign_json(json_object, self.config.server.server_name, key) json_object = sign_json(json_object, self.config.server.server_name, key)
return json_object return json_object
def render_GET(self, request): def render_GET(self, request: Request) -> int:
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
# Update the expiry time if less than half the interval remains. # Update the expiry time if less than half the interval remains.
if time_now + self.config.key_refresh_interval / 2 > self.valid_until_ts: if time_now + self.config.key_refresh_interval / 2 > self.valid_until_ts:
......
...@@ -13,17 +13,23 @@ ...@@ -13,17 +13,23 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict from typing import TYPE_CHECKING, Dict
from signedjson.sign import sign_json from signedjson.sign import sign_json
from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_integer, parse_json_object_from_request from synapse.http.servlet import parse_integer, parse_json_object_from_request
from synapse.types import JsonDict
from synapse.util import json_decoder from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results from synapse.util.async_helpers import yieldable_gather_results
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -85,7 +91,7 @@ class RemoteKey(DirectServeJsonResource): ...@@ -85,7 +91,7 @@ class RemoteKey(DirectServeJsonResource):
isLeaf = True isLeaf = True
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.fetcher = ServerKeyFetcher(hs) self.fetcher = ServerKeyFetcher(hs)
...@@ -94,7 +100,8 @@ class RemoteKey(DirectServeJsonResource): ...@@ -94,7 +100,8 @@ class RemoteKey(DirectServeJsonResource):
self.federation_domain_whitelist = hs.config.federation_domain_whitelist self.federation_domain_whitelist = hs.config.federation_domain_whitelist
self.config = hs.config self.config = hs.config
async def _async_render_GET(self, request): async def _async_render_GET(self, request: Request) -> None:
assert request.postpath is not None
if len(request.postpath) == 1: if len(request.postpath) == 1:
(server,) = request.postpath (server,) = request.postpath
query: dict = {server.decode("ascii"): {}} query: dict = {server.decode("ascii"): {}}
...@@ -110,14 +117,19 @@ class RemoteKey(DirectServeJsonResource): ...@@ -110,14 +117,19 @@ class RemoteKey(DirectServeJsonResource):
await self.query_keys(request, query, query_remote_on_cache_miss=True) await self.query_keys(request, query, query_remote_on_cache_miss=True)
async def _async_render_POST(self, request): async def _async_render_POST(self, request: Request) -> None:
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
query = content["server_keys"] query = content["server_keys"]
await self.query_keys(request, query, query_remote_on_cache_miss=True) await self.query_keys(request, query, query_remote_on_cache_miss=True)
async def query_keys(self, request, query, query_remote_on_cache_miss=False): async def query_keys(
self,
request: Request,
query: JsonDict,
query_remote_on_cache_miss: bool = False,
) -> None:
logger.info("Handling query for keys %r", query) logger.info("Handling query for keys %r", query)
store_queries = [] store_queries = []
...@@ -142,8 +154,8 @@ class RemoteKey(DirectServeJsonResource): ...@@ -142,8 +154,8 @@ class RemoteKey(DirectServeJsonResource):
# Note that the value is unused. # Note that the value is unused.
cache_misses: Dict[str, Dict[str, int]] = {} cache_misses: Dict[str, Dict[str, int]] = {}
for (server_name, key_id, _), results in cached.items(): for (server_name, key_id, _), key_results in cached.items():
results = [(result["ts_added_ms"], result) for result in results] results = [(result["ts_added_ms"], result) for result in key_results]
if not results and key_id is not None: if not results and key_id is not None:
cache_misses.setdefault(server_name, {})[key_id] = 0 cache_misses.setdefault(server_name, {})[key_id] = 0
...@@ -230,6 +242,6 @@ class RemoteKey(DirectServeJsonResource): ...@@ -230,6 +242,6 @@ class RemoteKey(DirectServeJsonResource):
signed_keys.append(key_json) signed_keys.append(key_json)
results = {"server_keys": signed_keys} response = {"server_keys": signed_keys}
respond_with_json(request, 200, results, canonical_json=True) respond_with_json(request, 200, response, canonical_json=True)
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
import logging import logging
import os import os
import urllib import urllib
from typing import Awaitable, Dict, Generator, List, Optional, Tuple from types import TracebackType
from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type
import attr import attr
...@@ -122,7 +123,7 @@ def add_file_headers( ...@@ -122,7 +123,7 @@ def add_file_headers(
upload_name: The name of the requested file, if any. upload_name: The name of the requested file, if any.
""" """
def _quote(x): def _quote(x: str) -> str:
return urllib.parse.quote(x.encode("utf-8")) return urllib.parse.quote(x.encode("utf-8"))
# Default to a UTF-8 charset for text content types. # Default to a UTF-8 charset for text content types.
...@@ -282,10 +283,15 @@ class Responder: ...@@ -282,10 +283,15 @@ class Responder:
""" """
pass pass
def __enter__(self): def __enter__(self) -> None:
pass pass
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
pass pass
...@@ -317,31 +323,31 @@ class FileInfo: ...@@ -317,31 +323,31 @@ class FileInfo:
# The below properties exist to maintain compatibility with third-party modules. # The below properties exist to maintain compatibility with third-party modules.
@property @property
def thumbnail_width(self): def thumbnail_width(self) -> Optional[int]:
if not self.thumbnail: if not self.thumbnail:
return None return None
return self.thumbnail.width return self.thumbnail.width
@property @property
def thumbnail_height(self): def thumbnail_height(self) -> Optional[int]:
if not self.thumbnail: if not self.thumbnail:
return None return None
return self.thumbnail.height return self.thumbnail.height
@property @property
def thumbnail_method(self): def thumbnail_method(self) -> Optional[str]:
if not self.thumbnail: if not self.thumbnail:
return None return None
return self.thumbnail.method return self.thumbnail.method
@property @property
def thumbnail_type(self): def thumbnail_type(self) -> Optional[str]:
if not self.thumbnail: if not self.thumbnail:
return None return None
return self.thumbnail.type return self.thumbnail.type
@property @property
def thumbnail_length(self): def thumbnail_length(self) -> Optional[int]:
if not self.thumbnail: if not self.thumbnail:
return None return None
return self.thumbnail.length return self.thumbnail.length
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import functools import functools
import os import os
import re import re
from typing import Callable, List from typing import Any, Callable, List
NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
...@@ -27,7 +27,7 @@ def _wrap_in_base_path(func: Callable[..., str]) -> Callable[..., str]: ...@@ -27,7 +27,7 @@ def _wrap_in_base_path(func: Callable[..., str]) -> Callable[..., str]:
""" """
@functools.wraps(func) @functools.wraps(func)
def _wrapped(self, *args, **kwargs): def _wrapped(self: "MediaFilePaths", *args: Any, **kwargs: Any) -> str:
path = func(self, *args, **kwargs) path = func(self, *args, **kwargs)
return os.path.join(self.base_path, path) return os.path.join(self.base_path, path)
...@@ -129,7 +129,7 @@ class MediaFilePaths: ...@@ -129,7 +129,7 @@ class MediaFilePaths:
# using the new path. # using the new path.
def remote_media_thumbnail_rel_legacy( def remote_media_thumbnail_rel_legacy(
self, server_name: str, file_id: str, width: int, height: int, content_type: str self, server_name: str, file_id: str, width: int, height: int, content_type: str
): ) -> str:
top_level_type, sub_type = content_type.split("/") top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
return os.path.join( return os.path.join(
......
...@@ -21,6 +21,7 @@ from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple ...@@ -21,6 +21,7 @@ from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
import twisted.internet.error import twisted.internet.error
import twisted.web.http import twisted.web.http
from twisted.internet.defer import Deferred
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import Request from twisted.web.server import Request
...@@ -32,6 +33,7 @@ from synapse.api.errors import ( ...@@ -32,6 +33,7 @@ from synapse.api.errors import (
SynapseError, SynapseError,
) )
from synapse.config._base import ConfigError from synapse.config._base import ConfigError
from synapse.config.repository import ThumbnailRequirement
from synapse.logging.context import defer_to_thread from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID from synapse.types import UserID
...@@ -114,7 +116,7 @@ class MediaRepository: ...@@ -114,7 +116,7 @@ class MediaRepository:
self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS
) )
def _start_update_recently_accessed(self): def _start_update_recently_accessed(self) -> Deferred:
return run_as_background_process( return run_as_background_process(
"update_recently_accessed_media", self._update_recently_accessed "update_recently_accessed_media", self._update_recently_accessed
) )
...@@ -469,7 +471,9 @@ class MediaRepository: ...@@ -469,7 +471,9 @@ class MediaRepository:
return media_info return media_info
def _get_thumbnail_requirements(self, media_type): def _get_thumbnail_requirements(
self, media_type: str
) -> Tuple[ThumbnailRequirement, ...]:
scpos = media_type.find(";") scpos = media_type.find(";")
if scpos > 0: if scpos > 0:
media_type = media_type[:scpos] media_type = media_type[:scpos]
......
...@@ -15,7 +15,20 @@ import contextlib ...@@ -15,7 +15,20 @@ import contextlib
import logging import logging
import os import os
import shutil import shutil
from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Sequence from types import TracebackType
from typing import (
IO,
TYPE_CHECKING,
Any,
Awaitable,
BinaryIO,
Callable,
Generator,
Optional,
Sequence,
Tuple,
Type,
)
import attr import attr
...@@ -83,12 +96,14 @@ class MediaStorage: ...@@ -83,12 +96,14 @@ class MediaStorage:
return fname return fname
async def write_to_file(self, source: IO, output: IO): async def write_to_file(self, source: IO, output: IO) -> None:
"""Asynchronously write the `source` to `output`.""" """Asynchronously write the `source` to `output`."""
await defer_to_thread(self.reactor, _write_file_synchronously, source, output) await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
@contextlib.contextmanager @contextlib.contextmanager
def store_into_file(self, file_info: FileInfo): def store_into_file(
self, file_info: FileInfo
) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]:
"""Context manager used to get a file like object to write into, as """Context manager used to get a file like object to write into, as
described by file_info. described by file_info.
...@@ -125,7 +140,7 @@ class MediaStorage: ...@@ -125,7 +140,7 @@ class MediaStorage:
try: try:
with open(fname, "wb") as f: with open(fname, "wb") as f:
async def finish(): async def finish() -> None:
# Ensure that all writes have been flushed and close the # Ensure that all writes have been flushed and close the
# file. # file.
f.flush() f.flush()
...@@ -315,7 +330,12 @@ class FileResponder(Responder): ...@@ -315,7 +330,12 @@ class FileResponder(Responder):
FileSender().beginFileTransfer(self.open_file, consumer) FileSender().beginFileTransfer(self.open_file, consumer)
) )
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.open_file.close() self.open_file.close()
...@@ -339,7 +359,7 @@ class ReadableFileWrapper: ...@@ -339,7 +359,7 @@ class ReadableFileWrapper:
clock = attr.ib(type=Clock) clock = attr.ib(type=Clock)
path = attr.ib(type=str) path = attr.ib(type=str)
async def write_chunks_to(self, callback: Callable[[bytes], None]): async def write_chunks_to(self, callback: Callable[[bytes], None]) -> None:
"""Reads the file in chunks and calls the callback with each chunk.""" """Reads the file in chunks and calls the callback with each chunk."""
with open(self.path, "rb") as file: with open(self.path, "rb") as file:
......
...@@ -27,6 +27,7 @@ from urllib import parse as urlparse ...@@ -27,6 +27,7 @@ from urllib import parse as urlparse
import attr import attr
from twisted.internet.defer import Deferred
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.web.server import Request from twisted.web.server import Request
...@@ -473,7 +474,7 @@ class PreviewUrlResource(DirectServeJsonResource): ...@@ -473,7 +474,7 @@ class PreviewUrlResource(DirectServeJsonResource):
etag=etag, etag=etag,
) )
def _start_expire_url_cache_data(self): def _start_expire_url_cache_data(self) -> Deferred:
return run_as_background_process( return run_as_background_process(
"expire_url_cache_data", self._expire_url_cache_data "expire_url_cache_data", self._expire_url_cache_data
) )
...@@ -782,7 +783,7 @@ def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]: ...@@ -782,7 +783,7 @@ def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
def _iterate_over_text( def _iterate_over_text(
tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]] tree: "etree.Element", *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
"""Iterate over the tree returning text nodes in a depth first fashion, """Iterate over the tree returning text nodes in a depth first fashion,
skipping text nodes inside certain tags. skipping text nodes inside certain tags.
......
...@@ -99,7 +99,7 @@ class StorageProviderWrapper(StorageProvider): ...@@ -99,7 +99,7 @@ class StorageProviderWrapper(StorageProvider):
await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore
else: else:
# TODO: Handle errors. # TODO: Handle errors.
async def store(): async def store() -> None:
try: try:
return await maybe_awaitable( return await maybe_awaitable(
self.backend.store_file(path, file_info) self.backend.store_file(path, file_info)
...@@ -128,7 +128,7 @@ class FileStorageProviderBackend(StorageProvider): ...@@ -128,7 +128,7 @@ class FileStorageProviderBackend(StorageProvider):
self.cache_directory = hs.config.media_store_path self.cache_directory = hs.config.media_store_path
self.base_directory = config self.base_directory = config
def __str__(self): def __str__(self) -> str:
return "FileStorageProviderBackend[%s]" % (self.base_directory,) return "FileStorageProviderBackend[%s]" % (self.base_directory,)
async def store_file(self, path: str, file_info: FileInfo) -> None: async def store_file(self, path: str, file_info: FileInfo) -> None:
......
...@@ -41,7 +41,7 @@ class Thumbnailer: ...@@ -41,7 +41,7 @@ class Thumbnailer:
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
@staticmethod @staticmethod
def set_limits(max_image_pixels: int): def set_limits(max_image_pixels: int) -> None:
Image.MAX_IMAGE_PIXELS = max_image_pixels Image.MAX_IMAGE_PIXELS = max_image_pixels
def __init__(self, input_path: str): def __init__(self, input_path: str):
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Generator
from twisted.web.server import Request from twisted.web.server import Request
...@@ -45,7 +45,7 @@ class NewUserConsentResource(DirectServeHtmlResource): ...@@ -45,7 +45,7 @@ class NewUserConsentResource(DirectServeHtmlResource):
self._server_name = hs.hostname self._server_name = hs.hostname
self._consent_version = hs.config.consent.user_consent_version self._consent_version = hs.config.consent.user_consent_version
def template_search_dirs(): def template_search_dirs() -> Generator[str, None, None]:
if hs.config.server.custom_template_directory: if hs.config.server.custom_template_directory:
yield hs.config.server.custom_template_directory yield hs.config.server.custom_template_directory
if hs.config.sso.sso_template_dir: if hs.config.sso.sso_template_dir:
...@@ -88,7 +88,7 @@ class NewUserConsentResource(DirectServeHtmlResource): ...@@ -88,7 +88,7 @@ class NewUserConsentResource(DirectServeHtmlResource):
html = template.render(template_params) html = template.render(template_params)
respond_with_html(request, 200, html) respond_with_html(request, 200, html)
async def _async_render_POST(self, request: Request): async def _async_render_POST(self, request: Request) -> None:
try: try:
session_id = get_username_mapping_session_cookie_from_request(request) session_id = get_username_mapping_session_cookie_from_request(request)
except SynapseError as e: except SynapseError as e:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment