Skip to content
Snippets Groups Projects
Unverified Commit f00c4e7a authored by Dirk Klimpel's avatar Dirk Klimpel Committed by GitHub
Browse files

Add type hints to device and event report admin API (#9519)

parent ad8589d3
Branches
Tags
No related merge requests found
Add type hints to device and event report admin API.
\ No newline at end of file
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import NotFoundError, SynapseError from synapse.api.errors import NotFoundError, SynapseError
from synapse.http.servlet import ( from synapse.http.servlet import (
...@@ -20,8 +21,12 @@ from synapse.http.servlet import ( ...@@ -20,8 +21,12 @@ from synapse.http.servlet import (
assert_params_in_dict, assert_params_in_dict,
parse_json_object_from_request, parse_json_object_from_request,
) )
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.types import UserID from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -35,14 +40,16 @@ class DeviceRestServlet(RestServlet): ...@@ -35,14 +40,16 @@ class DeviceRestServlet(RestServlet):
"/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$", "v2" "/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$", "v2"
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def on_GET(self, request, user_id, device_id): async def on_GET(
self, request: SynapseRequest, user_id, device_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
...@@ -58,7 +65,9 @@ class DeviceRestServlet(RestServlet): ...@@ -58,7 +65,9 @@ class DeviceRestServlet(RestServlet):
) )
return 200, device return 200, device
async def on_DELETE(self, request, user_id, device_id): async def on_DELETE(
self, request: SynapseRequest, user_id: str, device_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
...@@ -72,7 +81,9 @@ class DeviceRestServlet(RestServlet): ...@@ -72,7 +81,9 @@ class DeviceRestServlet(RestServlet):
await self.device_handler.delete_device(target_user.to_string(), device_id) await self.device_handler.delete_device(target_user.to_string(), device_id)
return 200, {} return 200, {}
async def on_PUT(self, request, user_id, device_id): async def on_PUT(
self, request: SynapseRequest, user_id: str, device_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
...@@ -97,7 +108,7 @@ class DevicesRestServlet(RestServlet): ...@@ -97,7 +108,7 @@ class DevicesRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2") PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
""" """
Args: Args:
hs (synapse.server.HomeServer): server hs (synapse.server.HomeServer): server
...@@ -107,7 +118,9 @@ class DevicesRestServlet(RestServlet): ...@@ -107,7 +118,9 @@ class DevicesRestServlet(RestServlet):
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def on_GET(self, request, user_id): async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
...@@ -130,13 +143,15 @@ class DeleteDevicesRestServlet(RestServlet): ...@@ -130,13 +143,15 @@ class DeleteDevicesRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2") PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def on_POST(self, request, user_id): async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
......
...@@ -14,10 +14,16 @@ ...@@ -14,10 +14,16 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -45,12 +51,12 @@ class EventReportsRestServlet(RestServlet): ...@@ -45,12 +51,12 @@ class EventReportsRestServlet(RestServlet):
PATTERNS = admin_patterns("/event_reports$") PATTERNS = admin_patterns("/event_reports$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def on_GET(self, request): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
start = parse_integer(request, "from", default=0) start = parse_integer(request, "from", default=0)
...@@ -106,26 +112,28 @@ class EventReportDetailRestServlet(RestServlet): ...@@ -106,26 +112,28 @@ class EventReportDetailRestServlet(RestServlet):
PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$") PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def on_GET(self, request, report_id): async def on_GET(
self, request: SynapseRequest, report_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
message = ( message = (
"The report_id parameter must be a string representing a positive integer." "The report_id parameter must be a string representing a positive integer."
) )
try: try:
report_id = int(report_id) resolved_report_id = int(report_id)
except ValueError: except ValueError:
raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
if report_id < 0: if resolved_report_id < 0:
raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
ret = await self.store.get_event_report(report_id) ret = await self.store.get_event_report(resolved_report_id)
if not ret: if not ret:
raise NotFoundError("Event report not found") raise NotFoundError("Event report not found")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment