Skip to content
Snippets Groups Projects
Unverified Commit 2295095c authored by David Robertson's avatar David Robertson Committed by GitHub
Browse files

Use Pydantic to validate /devices endpoints (#14054)

parent 1fa2e587
No related branches found
No related tags found
No related merge requests found
Improve validation of request bodies for the [Device Management](https://spec.matrix.org/v1.4/client-server-api/#device-management) and [MSC2697 Device Dehyrdation](https://github.com/matrix-org/matrix-spec-proposals/pull/2697) client-server API endpoints.
...@@ -14,18 +14,21 @@ ...@@ -14,18 +14,21 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, List, Optional, Tuple
from pydantic import Extra, StrictStr
from synapse.api import errors from synapse.api import errors
from synapse.api.errors import NotFoundError from synapse.api.errors import NotFoundError
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
assert_params_in_dict, parse_and_validate_json_object_from_request,
parse_json_object_from_request,
) )
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns, interactive_auth_handler from synapse.rest.client._base import client_patterns, interactive_auth_handler
from synapse.rest.client.models import AuthenticationData
from synapse.rest.models import RequestBodyModel
from synapse.types import JsonDict from synapse.types import JsonDict
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -80,27 +83,29 @@ class DeleteDevicesRestServlet(RestServlet): ...@@ -80,27 +83,29 @@ class DeleteDevicesRestServlet(RestServlet):
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
class PostBody(RequestBodyModel):
auth: Optional[AuthenticationData]
devices: List[StrictStr]
@interactive_auth_handler @interactive_auth_handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
try: try:
body = parse_json_object_from_request(request) body = parse_and_validate_json_object_from_request(request, self.PostBody)
except errors.SynapseError as e: except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON: if e.errcode == errors.Codes.NOT_JSON:
# DELETE # TODO: Can/should we remove this fallback now?
# deal with older clients which didn't pass a JSON dict # deal with older clients which didn't pass a JSON dict
# the same as those that pass an empty dict # the same as those that pass an empty dict
body = {} body = self.PostBody.parse_obj({})
else: else:
raise e raise e
assert_params_in_dict(body, ["devices"])
await self.auth_handler.validate_user_via_ui_auth( await self.auth_handler.validate_user_via_ui_auth(
requester, requester,
request, request,
body, body.dict(exclude_unset=True),
"remove device(s) from your account", "remove device(s) from your account",
# Users might call this multiple times in a row while cleaning up # Users might call this multiple times in a row while cleaning up
# devices, allow a single UI auth session to be re-used. # devices, allow a single UI auth session to be re-used.
...@@ -108,7 +113,7 @@ class DeleteDevicesRestServlet(RestServlet): ...@@ -108,7 +113,7 @@ class DeleteDevicesRestServlet(RestServlet):
) )
await self.device_handler.delete_devices( await self.device_handler.delete_devices(
requester.user.to_string(), body["devices"] requester.user.to_string(), body.devices
) )
return 200, {} return 200, {}
...@@ -147,6 +152,9 @@ class DeviceRestServlet(RestServlet): ...@@ -147,6 +152,9 @@ class DeviceRestServlet(RestServlet):
return 200, device return 200, device
class DeleteBody(RequestBodyModel):
auth: Optional[AuthenticationData]
@interactive_auth_handler @interactive_auth_handler
async def on_DELETE( async def on_DELETE(
self, request: SynapseRequest, device_id: str self, request: SynapseRequest, device_id: str
...@@ -154,20 +162,21 @@ class DeviceRestServlet(RestServlet): ...@@ -154,20 +162,21 @@ class DeviceRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
try: try:
body = parse_json_object_from_request(request) body = parse_and_validate_json_object_from_request(request, self.DeleteBody)
except errors.SynapseError as e: except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON: if e.errcode == errors.Codes.NOT_JSON:
# TODO: can/should we remove this fallback now?
# deal with older clients which didn't pass a JSON dict # deal with older clients which didn't pass a JSON dict
# the same as those that pass an empty dict # the same as those that pass an empty dict
body = {} body = self.DeleteBody.parse_obj({})
else: else:
raise raise
await self.auth_handler.validate_user_via_ui_auth( await self.auth_handler.validate_user_via_ui_auth(
requester, requester,
request, request,
body, body.dict(exclude_unset=True),
"remove a device from your account", "remove a device from your account",
# Users might call this multiple times in a row while cleaning up # Users might call this multiple times in a row while cleaning up
# devices, allow a single UI auth session to be re-used. # devices, allow a single UI auth session to be re-used.
...@@ -179,18 +188,33 @@ class DeviceRestServlet(RestServlet): ...@@ -179,18 +188,33 @@ class DeviceRestServlet(RestServlet):
) )
return 200, {} return 200, {}
class PutBody(RequestBodyModel):
display_name: Optional[StrictStr]
async def on_PUT( async def on_PUT(
self, request: SynapseRequest, device_id: str self, request: SynapseRequest, device_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
body = parse_json_object_from_request(request) body = parse_and_validate_json_object_from_request(request, self.PutBody)
await self.device_handler.update_device( await self.device_handler.update_device(
requester.user.to_string(), device_id, body requester.user.to_string(), device_id, body.dict()
) )
return 200, {} return 200, {}
class DehydratedDeviceDataModel(RequestBodyModel):
"""JSON blob describing a dehydrated device to be stored.
Expects other freeform fields. Use .dict() to access them.
"""
class Config:
extra = Extra.allow
algorithm: StrictStr
class DehydratedDeviceServlet(RestServlet): class DehydratedDeviceServlet(RestServlet):
"""Retrieve or store a dehydrated device. """Retrieve or store a dehydrated device.
...@@ -246,27 +270,19 @@ class DehydratedDeviceServlet(RestServlet): ...@@ -246,27 +270,19 @@ class DehydratedDeviceServlet(RestServlet):
else: else:
raise errors.NotFoundError("No dehydrated device available") raise errors.NotFoundError("No dehydrated device available")
class PutBody(RequestBodyModel):
device_id: StrictStr
device_data: DehydratedDeviceDataModel
initial_device_display_name: Optional[StrictStr]
async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
submission = parse_json_object_from_request(request) submission = parse_and_validate_json_object_from_request(request, self.PutBody)
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if "device_data" not in submission:
raise errors.SynapseError(
400,
"device_data missing",
errcode=errors.Codes.MISSING_PARAM,
)
elif not isinstance(submission["device_data"], dict):
raise errors.SynapseError(
400,
"device_data must be an object",
errcode=errors.Codes.INVALID_PARAM,
)
device_id = await self.device_handler.store_dehydrated_device( device_id = await self.device_handler.store_dehydrated_device(
requester.user.to_string(), requester.user.to_string(),
submission["device_data"], submission.device_data,
submission.get("initial_device_display_name", None), submission.initial_device_display_name,
) )
return 200, {"device_id": device_id} return 200, {"device_id": device_id}
...@@ -300,28 +316,18 @@ class ClaimDehydratedDeviceServlet(RestServlet): ...@@ -300,28 +316,18 @@ class ClaimDehydratedDeviceServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
class PostBody(RequestBodyModel):
device_id: StrictStr
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
submission = parse_json_object_from_request(request) submission = parse_and_validate_json_object_from_request(request, self.PostBody)
if "device_id" not in submission:
raise errors.SynapseError(
400,
"device_id missing",
errcode=errors.Codes.MISSING_PARAM,
)
elif not isinstance(submission["device_id"], str):
raise errors.SynapseError(
400,
"device_id must be a string",
errcode=errors.Codes.INVALID_PARAM,
)
result = await self.device_handler.rehydrate_device( result = await self.device_handler.rehydrate_device(
requester.user.to_string(), requester.user.to_string(),
self.auth.get_access_token_from_request(request), self.auth.get_access_token_from_request(request),
submission["device_id"], submission.device_id,
) )
return 200, result return 200, result
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment