Skip to content
Snippets Groups Projects
Unverified Commit 3f58fc84 authored by Patrick Cloke's avatar Patrick Cloke Committed by GitHub
Browse files

Type hints and validation improvements. (#9321)

* Adds type hints to the groups servlet and stringutils code.
* Assert the maximum length of some input values for spec compliance.
parent 0963d39e
No related branches found
No related tags found
No related merge requests found
Assert a maximum length for the `client_secret` parameter for spec compliance.
......@@ -18,6 +18,7 @@
import logging
from synapse.api.errors import Codes, SynapseError
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute
......@@ -32,6 +33,11 @@ logger = logging.getLogger(__name__)
# TODO: Flairs
# Note that the maximum lengths are somewhat arbitrary.
MAX_SHORT_DESC_LEN = 1000
MAX_LONG_DESC_LEN = 10000
class GroupsServerWorkerHandler:
def __init__(self, hs):
self.hs = hs
......@@ -508,11 +514,26 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
)
profile = {}
for keyname in ("name", "avatar_url", "short_description", "long_description"):
for keyname, max_length in (
("name", MAX_DISPLAYNAME_LEN),
("avatar_url", MAX_AVATAR_URL_LEN),
("short_description", MAX_SHORT_DESC_LEN),
("long_description", MAX_LONG_DESC_LEN),
):
if keyname in content:
value = content[keyname]
if not isinstance(value, str):
raise SynapseError(400, "%r value is not a string" % (keyname,))
raise SynapseError(
400,
"%r value is not a string" % (keyname,),
errcode=Codes.INVALID_PARAM,
)
if len(value) > max_length:
raise SynapseError(
400,
"Invalid %s parameter" % (keyname,),
errcode=Codes.INVALID_PARAM,
)
profile[keyname] = value
await self.store.update_group_profile(group_id, profile)
......
This diff is collapsed.
......@@ -193,6 +193,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
body, ["client_secret", "country", "phone_number", "send_attempt"]
)
client_secret = body["client_secret"]
assert_valid_client_secret(client_secret)
country = body["country"]
phone_number = body["phone_number"]
send_attempt = body["send_attempt"]
......@@ -293,6 +294,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
sid = parse_string(request, "sid", required=True)
client_secret = parse_string(request, "client_secret", required=True)
assert_valid_client_secret(client_secret)
token = parse_string(request, "token", required=True)
# Attempt to validate a 3PID session
......
......@@ -25,7 +25,17 @@ import abc
import functools
import logging
import os
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
TypeVar,
Union,
cast,
)
import twisted.internet.base
import twisted.internet.tcp
......@@ -588,7 +598,9 @@ class HomeServer(metaclass=abc.ABCMeta):
return UserDirectoryHandler(self)
@cache_in_self
def get_groups_local_handler(self):
def get_groups_local_handler(
self,
) -> Union[GroupsLocalWorkerHandler, GroupsLocalHandler]:
if self.config.worker_app:
return GroupsLocalWorkerHandler(self)
else:
......
......@@ -25,7 +25,7 @@ from synapse.api.errors import Codes, SynapseError
_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
# https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$")
# https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris,
# together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically
......@@ -42,28 +42,31 @@ MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
rand = random.SystemRandom()
def random_string(length):
def random_string(length: int) -> str:
return "".join(rand.choice(string.ascii_letters) for _ in range(length))
def random_string_with_symbols(length):
def random_string_with_symbols(length: int) -> str:
return "".join(rand.choice(_string_with_symbols) for _ in range(length))
def is_ascii(s):
if isinstance(s, bytes):
try:
s.decode("ascii").encode("ascii")
except UnicodeDecodeError:
return False
except UnicodeEncodeError:
return False
return True
def is_ascii(s: bytes) -> bool:
try:
s.decode("ascii").encode("ascii")
except UnicodeDecodeError:
return False
except UnicodeEncodeError:
return False
return True
def assert_valid_client_secret(client_secret):
"""Validate that a given string matches the client_secret regex defined by the spec"""
if client_secret_regex.match(client_secret) is None:
def assert_valid_client_secret(client_secret: str) -> None:
"""Validate that a given string matches the client_secret defined by the spec"""
if (
len(client_secret) <= 0
or len(client_secret) > 255
or CLIENT_SECRET_REGEX.match(client_secret) is None
):
raise SynapseError(
400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment