Skip to content
Snippets Groups Projects
Unverified Commit 287918e2 authored by Patrick Cloke's avatar Patrick Cloke Committed by GitHub
Browse files

Additional type hints for the client REST servlets (part 3). (#10707)

parent 78e590d4
No related branches found
No related tags found
No related merge requests found
Add missing type hints to REST servlets.
...@@ -13,12 +13,19 @@ ...@@ -13,12 +13,19 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError, NotFoundError, SynapseError from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from ._base import client_patterns from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -32,13 +39,15 @@ class AccountDataServlet(RestServlet): ...@@ -32,13 +39,15 @@ class AccountDataServlet(RestServlet):
"/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)" "/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)"
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.handler = hs.get_account_data_handler() self.handler = hs.get_account_data_handler()
async def on_PUT(self, request, user_id, account_data_type): async def on_PUT(
self, request: SynapseRequest, user_id: str, account_data_type: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.") raise AuthError(403, "Cannot add account data for other users.")
...@@ -49,7 +58,9 @@ class AccountDataServlet(RestServlet): ...@@ -49,7 +58,9 @@ class AccountDataServlet(RestServlet):
return 200, {} return 200, {}
async def on_GET(self, request, user_id, account_data_type): async def on_GET(
self, request: SynapseRequest, user_id: str, account_data_type: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.") raise AuthError(403, "Cannot get account data for other users.")
...@@ -76,13 +87,19 @@ class RoomAccountDataServlet(RestServlet): ...@@ -76,13 +87,19 @@ class RoomAccountDataServlet(RestServlet):
"/account_data/(?P<account_data_type>[^/]*)" "/account_data/(?P<account_data_type>[^/]*)"
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.handler = hs.get_account_data_handler() self.handler = hs.get_account_data_handler()
async def on_PUT(self, request, user_id, room_id, account_data_type): async def on_PUT(
self,
request: SynapseRequest,
user_id: str,
room_id: str,
account_data_type: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.") raise AuthError(403, "Cannot add account data for other users.")
...@@ -102,7 +119,13 @@ class RoomAccountDataServlet(RestServlet): ...@@ -102,7 +119,13 @@ class RoomAccountDataServlet(RestServlet):
return 200, {} return 200, {}
async def on_GET(self, request, user_id, room_id, account_data_type): async def on_GET(
self,
request: SynapseRequest,
user_id: str,
room_id: str,
account_data_type: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.") raise AuthError(403, "Cannot get account data for other users.")
...@@ -117,6 +140,6 @@ class RoomAccountDataServlet(RestServlet): ...@@ -117,6 +140,6 @@ class RoomAccountDataServlet(RestServlet):
return 200, event return 200, event
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
AccountDataServlet(hs).register(http_server) AccountDataServlet(hs).register(http_server)
RoomAccountDataServlet(hs).register(http_server) RoomAccountDataServlet(hs).register(http_server)
...@@ -156,7 +156,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): ...@@ -156,7 +156,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
group_id: str, group_id: str,
category_id: Optional[str], category_id: Optional[str],
room_id: str, room_id: str,
): ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
...@@ -188,7 +188,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): ...@@ -188,7 +188,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
@_validate_group_id @_validate_group_id
async def on_DELETE( async def on_DELETE(
self, request: SynapseRequest, group_id: str, category_id: str, room_id: str self, request: SynapseRequest, group_id: str, category_id: str, room_id: str
): ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
...@@ -451,7 +451,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): ...@@ -451,7 +451,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
@_validate_group_id @_validate_group_id
async def on_DELETE( async def on_DELETE(
self, request: SynapseRequest, group_id: str, role_id: str, user_id: str self, request: SynapseRequest, group_id: str, role_id: str, user_id: str
): ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
...@@ -674,7 +674,7 @@ class GroupAdminRoomsConfigServlet(RestServlet): ...@@ -674,7 +674,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
@_validate_group_id @_validate_group_id
async def on_PUT( async def on_PUT(
self, request: SynapseRequest, group_id: str, room_id: str, config_key: str self, request: SynapseRequest, group_id: str, room_id: str, config_key: str
): ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
...@@ -706,7 +706,7 @@ class GroupAdminUsersInviteServlet(RestServlet): ...@@ -706,7 +706,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
@_validate_group_id @_validate_group_id
async def on_PUT( async def on_PUT(
self, request: SynapseRequest, group_id, user_id self, request: SynapseRequest, group_id: str, user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
...@@ -738,7 +738,7 @@ class GroupAdminUsersKickServlet(RestServlet): ...@@ -738,7 +738,7 @@ class GroupAdminUsersKickServlet(RestServlet):
@_validate_group_id @_validate_group_id
async def on_PUT( async def on_PUT(
self, request: SynapseRequest, group_id, user_id self, request: SynapseRequest, group_id: str, user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
......
...@@ -13,13 +13,20 @@ ...@@ -13,13 +13,20 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import ReadReceiptEventFields from synapse.api.constants import ReadReceiptEventFields
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from ._base import client_patterns from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -30,14 +37,16 @@ class ReceiptRestServlet(RestServlet): ...@@ -30,14 +37,16 @@ class ReceiptRestServlet(RestServlet):
"/(?P<event_id>[^/]*)$" "/(?P<event_id>[^/]*)$"
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.receipts_handler = hs.get_receipts_handler() self.receipts_handler = hs.get_receipts_handler()
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
async def on_POST(self, request, room_id, receipt_type, event_id): async def on_POST(
self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if receipt_type != "m.read": if receipt_type != "m.read":
...@@ -67,5 +76,5 @@ class ReceiptRestServlet(RestServlet): ...@@ -67,5 +76,5 @@ class ReceiptRestServlet(RestServlet):
return 200, {} return 200, {}
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReceiptRestServlet(hs).register(http_server) ReceiptRestServlet(hs).register(http_server)
...@@ -14,7 +14,9 @@ ...@@ -14,7 +14,9 @@
# limitations under the License. # limitations under the License.
import logging import logging
import random import random
from typing import List, Union from typing import TYPE_CHECKING, List, Optional, Tuple
from twisted.web.server import Request
import synapse import synapse
import synapse.api.auth import synapse.api.auth
...@@ -29,15 +31,13 @@ from synapse.api.errors import ( ...@@ -29,15 +31,13 @@ from synapse.api.errors import (
) )
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.config.captcha import CaptchaConfig
from synapse.config.consent import ConsentConfig
from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.emailconfig import ThreepidBehaviour
from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.config.registration import RegistrationConfig
from synapse.config.server import is_threepid_reserved from synapse.config.server import is_threepid_reserved
from synapse.handlers.auth import AuthHandler from synapse.handlers.auth import AuthHandler
from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http.server import finish_request, respond_with_html from synapse.http.server import HttpServer, finish_request, respond_with_html
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
assert_params_in_dict, assert_params_in_dict,
...@@ -45,6 +45,7 @@ from synapse.http.servlet import ( ...@@ -45,6 +45,7 @@ from synapse.http.servlet import (
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
from synapse.http.site import SynapseRequest
from synapse.metrics import threepid_send_requests from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer from synapse.push.mailer import Mailer
from synapse.types import JsonDict from synapse.types import JsonDict
...@@ -59,17 +60,16 @@ from synapse.util.threepids import ( ...@@ -59,17 +60,16 @@ from synapse.util.threepids import (
from ._base import client_patterns, interactive_auth_handler from ._base import client_patterns, interactive_auth_handler
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EmailRegisterRequestTokenRestServlet(RestServlet): class EmailRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/register/email/requestToken$") PATTERNS = client_patterns("/register/email/requestToken$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.identity_handler = hs.get_identity_handler() self.identity_handler = hs.get_identity_handler()
...@@ -83,7 +83,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): ...@@ -83,7 +83,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
template_text=self.config.email_registration_template_text, template_text=self.config.email_registration_template_text,
) )
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.hs.config.local_threepid_handling_disabled_due_to_email_config: if self.hs.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning( logger.warning(
...@@ -171,16 +171,12 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): ...@@ -171,16 +171,12 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
class MsisdnRegisterRequestTokenRestServlet(RestServlet): class MsisdnRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/register/msisdn/requestToken$") PATTERNS = client_patterns("/register/msisdn/requestToken$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.identity_handler = hs.get_identity_handler() self.identity_handler = hs.get_identity_handler()
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_dict( assert_params_in_dict(
...@@ -255,11 +251,7 @@ class RegistrationSubmitTokenServlet(RestServlet): ...@@ -255,11 +251,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
"/registration/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True "/registration/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
...@@ -272,7 +264,7 @@ class RegistrationSubmitTokenServlet(RestServlet): ...@@ -272,7 +264,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
self.config.email_registration_template_failure_html self.config.email_registration_template_failure_html
) )
async def on_GET(self, request, medium): async def on_GET(self, request: Request, medium: str) -> None:
if medium != "email": if medium != "email":
raise SynapseError( raise SynapseError(
400, "This medium is currently not supported for registration" 400, "This medium is currently not supported for registration"
...@@ -326,11 +318,7 @@ class RegistrationSubmitTokenServlet(RestServlet): ...@@ -326,11 +318,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
class UsernameAvailabilityRestServlet(RestServlet): class UsernameAvailabilityRestServlet(RestServlet):
PATTERNS = client_patterns("/register/available") PATTERNS = client_patterns("/register/available")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.registration_handler = hs.get_registration_handler() self.registration_handler = hs.get_registration_handler()
...@@ -350,7 +338,7 @@ class UsernameAvailabilityRestServlet(RestServlet): ...@@ -350,7 +338,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
), ),
) )
async def on_GET(self, request): async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
if not self.hs.config.enable_registration: if not self.hs.config.enable_registration:
raise SynapseError( raise SynapseError(
403, "Registration has been disabled", errcode=Codes.FORBIDDEN 403, "Registration has been disabled", errcode=Codes.FORBIDDEN
...@@ -419,11 +407,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): ...@@ -419,11 +407,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
class RegisterRestServlet(RestServlet): class RegisterRestServlet(RestServlet):
PATTERNS = client_patterns("/register$") PATTERNS = client_patterns("/register$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__() super().__init__()
self.hs = hs self.hs = hs
...@@ -445,23 +429,21 @@ class RegisterRestServlet(RestServlet): ...@@ -445,23 +429,21 @@ class RegisterRestServlet(RestServlet):
) )
@interactive_auth_handler @interactive_auth_handler
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
client_addr = request.getClientIP() client_addr = request.getClientIP()
await self.ratelimiter.ratelimit(None, client_addr, update=False) await self.ratelimiter.ratelimit(None, client_addr, update=False)
kind = b"user" kind = parse_string(request, "kind", default="user")
if b"kind" in request.args:
kind = request.args[b"kind"][0]
if kind == b"guest": if kind == "guest":
ret = await self._do_guest_registration(body, address=client_addr) ret = await self._do_guest_registration(body, address=client_addr)
return ret return ret
elif kind != b"user": elif kind != "user":
raise UnrecognizedRequestError( raise UnrecognizedRequestError(
"Do not understand membership kind: %s" % (kind.decode("utf8"),) f"Do not understand membership kind: {kind}",
) )
if self._msc2918_enabled: if self._msc2918_enabled:
...@@ -749,7 +731,7 @@ class RegisterRestServlet(RestServlet): ...@@ -749,7 +731,7 @@ class RegisterRestServlet(RestServlet):
async def _do_appservice_registration( async def _do_appservice_registration(
self, username, as_token, body, should_issue_refresh_token: bool = False self, username, as_token, body, should_issue_refresh_token: bool = False
): ) -> JsonDict:
user_id = await self.registration_handler.appservice_register( user_id = await self.registration_handler.appservice_register(
username, as_token username, as_token
) )
...@@ -766,7 +748,7 @@ class RegisterRestServlet(RestServlet): ...@@ -766,7 +748,7 @@ class RegisterRestServlet(RestServlet):
params: JsonDict, params: JsonDict,
is_appservice_ghost: bool = False, is_appservice_ghost: bool = False,
should_issue_refresh_token: bool = False, should_issue_refresh_token: bool = False,
): ) -> JsonDict:
"""Complete registration of newly-registered user """Complete registration of newly-registered user
Allocates device_id if one was not given; also creates access_token. Allocates device_id if one was not given; also creates access_token.
...@@ -810,7 +792,9 @@ class RegisterRestServlet(RestServlet): ...@@ -810,7 +792,9 @@ class RegisterRestServlet(RestServlet):
return result return result
async def _do_guest_registration(self, params, address=None): async def _do_guest_registration(
self, params: JsonDict, address: Optional[str] = None
) -> Tuple[int, JsonDict]:
if not self.hs.config.allow_guest_access: if not self.hs.config.allow_guest_access:
raise SynapseError(403, "Guest access is disabled") raise SynapseError(403, "Guest access is disabled")
user_id = await self.registration_handler.register_user( user_id = await self.registration_handler.register_user(
...@@ -848,9 +832,7 @@ class RegisterRestServlet(RestServlet): ...@@ -848,9 +832,7 @@ class RegisterRestServlet(RestServlet):
def _calculate_registration_flows( def _calculate_registration_flows(
# technically `config` has to provide *all* of these interfaces, not just one config: HomeServerConfig, auth_handler: AuthHandler
config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig],
auth_handler: AuthHandler,
) -> List[List[str]]: ) -> List[List[str]]:
"""Get a suitable flows list for registration """Get a suitable flows list for registration
...@@ -929,7 +911,7 @@ def _calculate_registration_flows( ...@@ -929,7 +911,7 @@ def _calculate_registration_flows(
return flows return flows
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
EmailRegisterRequestTokenRestServlet(hs).register(http_server) EmailRegisterRequestTokenRestServlet(hs).register(http_server)
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
UsernameAvailabilityRestServlet(hs).register(http_server) UsernameAvailabilityRestServlet(hs).register(http_server)
......
...@@ -19,25 +19,32 @@ any time to reflect changes in the MSC. ...@@ -19,25 +19,32 @@ any time to reflect changes in the MSC.
""" """
import logging import logging
from typing import TYPE_CHECKING, Awaitable, Optional, Tuple
from synapse.api.constants import EventTypes, RelationTypes from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import ShadowBanError, SynapseError from synapse.api.errors import ShadowBanError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
parse_integer, parse_integer,
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
from synapse.http.site import SynapseRequest
from synapse.rest.client.transactions import HttpTransactionCache from synapse.rest.client.transactions import HttpTransactionCache
from synapse.storage.relations import ( from synapse.storage.relations import (
AggregationPaginationToken, AggregationPaginationToken,
PaginationChunk, PaginationChunk,
RelationPaginationToken, RelationPaginationToken,
) )
from synapse.types import JsonDict
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from ._base import client_patterns from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -59,13 +66,13 @@ class RelationSendServlet(RestServlet): ...@@ -59,13 +66,13 @@ class RelationSendServlet(RestServlet):
"/(?P<parent_id>[^/]*)/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)" "/(?P<parent_id>[^/]*)/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)"
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.txns = HttpTransactionCache(hs) self.txns = HttpTransactionCache(hs)
def register(self, http_server): def register(self, http_server: HttpServer) -> None:
http_server.register_paths( http_server.register_paths(
"POST", "POST",
client_patterns(self.PATTERN + "$", releases=()), client_patterns(self.PATTERN + "$", releases=()),
...@@ -79,14 +86,35 @@ class RelationSendServlet(RestServlet): ...@@ -79,14 +86,35 @@ class RelationSendServlet(RestServlet):
self.__class__.__name__, self.__class__.__name__,
) )
def on_PUT(self, request, *args, **kwargs): def on_PUT(
self,
request: SynapseRequest,
room_id: str,
parent_id: str,
relation_type: str,
event_type: str,
txn_id: Optional[str] = None,
) -> Awaitable[Tuple[int, JsonDict]]:
return self.txns.fetch_or_execute_request( return self.txns.fetch_or_execute_request(
request, self.on_PUT_or_POST, request, *args, **kwargs request,
self.on_PUT_or_POST,
request,
room_id,
parent_id,
relation_type,
event_type,
txn_id,
) )
async def on_PUT_or_POST( async def on_PUT_or_POST(
self, request, room_id, parent_id, relation_type, event_type, txn_id=None self,
): request: SynapseRequest,
room_id: str,
parent_id: str,
relation_type: str,
event_type: str,
txn_id: Optional[str] = None,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
if event_type == EventTypes.Member: if event_type == EventTypes.Member:
...@@ -136,7 +164,7 @@ class RelationPaginationServlet(RestServlet): ...@@ -136,7 +164,7 @@ class RelationPaginationServlet(RestServlet):
releases=(), releases=(),
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
...@@ -145,8 +173,13 @@ class RelationPaginationServlet(RestServlet): ...@@ -145,8 +173,13 @@ class RelationPaginationServlet(RestServlet):
self.event_handler = hs.get_event_handler() self.event_handler = hs.get_event_handler()
async def on_GET( async def on_GET(
self, request, room_id, parent_id, relation_type=None, event_type=None self,
): request: SynapseRequest,
room_id: str,
parent_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable( await self.auth.check_user_in_room_or_world_readable(
...@@ -156,6 +189,8 @@ class RelationPaginationServlet(RestServlet): ...@@ -156,6 +189,8 @@ class RelationPaginationServlet(RestServlet):
# This gets the original event and checks that a) the event exists and # This gets the original event and checks that a) the event exists and
# b) the user is allowed to view it. # b) the user is allowed to view it.
event = await self.event_handler.get_event(requester.user, room_id, parent_id) event = await self.event_handler.get_event(requester.user, room_id, parent_id)
if event is None:
raise SynapseError(404, "Unknown parent event.")
limit = parse_integer(request, "limit", default=5) limit = parse_integer(request, "limit", default=5)
from_token_str = parse_string(request, "from") from_token_str = parse_string(request, "from")
...@@ -233,15 +268,20 @@ class RelationAggregationPaginationServlet(RestServlet): ...@@ -233,15 +268,20 @@ class RelationAggregationPaginationServlet(RestServlet):
releases=(), releases=(),
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.event_handler = hs.get_event_handler() self.event_handler = hs.get_event_handler()
async def on_GET( async def on_GET(
self, request, room_id, parent_id, relation_type=None, event_type=None self,
): request: SynapseRequest,
room_id: str,
parent_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable( await self.auth.check_user_in_room_or_world_readable(
...@@ -253,6 +293,8 @@ class RelationAggregationPaginationServlet(RestServlet): ...@@ -253,6 +293,8 @@ class RelationAggregationPaginationServlet(RestServlet):
# This checks that a) the event exists and b) the user is allowed to # This checks that a) the event exists and b) the user is allowed to
# view it. # view it.
event = await self.event_handler.get_event(requester.user, room_id, parent_id) event = await self.event_handler.get_event(requester.user, room_id, parent_id)
if event is None:
raise SynapseError(404, "Unknown parent event.")
if relation_type not in (RelationTypes.ANNOTATION, None): if relation_type not in (RelationTypes.ANNOTATION, None):
raise SynapseError(400, "Relation type must be 'annotation'") raise SynapseError(400, "Relation type must be 'annotation'")
...@@ -315,7 +357,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): ...@@ -315,7 +357,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
releases=(), releases=(),
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
...@@ -323,7 +365,15 @@ class RelationAggregationGroupPaginationServlet(RestServlet): ...@@ -323,7 +365,15 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler() self.event_handler = hs.get_event_handler()
async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): async def on_GET(
self,
request: SynapseRequest,
room_id: str,
parent_id: str,
relation_type: str,
event_type: str,
key: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable( await self.auth.check_user_in_room_or_world_readable(
...@@ -374,7 +424,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): ...@@ -374,7 +424,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
return 200, return_value return 200, return_value
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RelationSendServlet(hs).register(http_server) RelationSendServlet(hs).register(http_server)
RelationPaginationServlet(hs).register(http_server) RelationPaginationServlet(hs).register(http_server)
RelationAggregationPaginationServlet(hs).register(http_server) RelationAggregationPaginationServlet(hs).register(http_server)
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment