Skip to content
Snippets Groups Projects
Unverified Commit ecbfa4fe authored by Patrick Cloke's avatar Patrick Cloke Committed by GitHub
Browse files

Additional type hints for client REST servlets (part 5) (#10736)

Additionally this enforce type hints on all function signatures inside
of the synapse.rest.client package.
parent f58d202e
No related branches found
No related tags found
No related merge requests found
Add missing type hints to REST servlets.
......@@ -98,6 +98,9 @@ files =
tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py
[mypy-synapse.rest.client.*]
disallow_untyped_defs = True
[mypy-pymacaroons.*]
ignore_missing_imports = True
......
......@@ -572,6 +572,25 @@ def parse_string_from_args(
return strings[0]
@overload
def parse_json_value_from_request(request: Request) -> JsonDict:
...
@overload
def parse_json_value_from_request(
request: Request, allow_empty_body: Literal[False]
) -> JsonDict:
...
@overload
def parse_json_value_from_request(
request: Request, allow_empty_body: bool = False
) -> Optional[JsonDict]:
...
def parse_json_value_from_request(
request: Request, allow_empty_body: bool = False
) -> Optional[JsonDict]:
......
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Awaitable, Optional, Tuple
from synapse.api.constants import EventTypes
from synapse.api.errors import NotFoundError, SynapseError
......@@ -101,7 +101,9 @@ class SendServerNoticeServlet(RestServlet):
return 200, {"event_id": event.event_id}
def on_PUT(self, request: SynapseRequest, txn_id: str) -> Tuple[int, JsonDict]:
def on_PUT(
self, request: SynapseRequest, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, txn_id
)
......@@ -16,7 +16,7 @@
"""
import logging
import re
from typing import Iterable, Pattern
from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar, cast
from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.api.urls import CLIENT_API_PREFIX
......@@ -76,7 +76,10 @@ def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int)
)
def interactive_auth_handler(orig):
C = TypeVar("C", bound=Callable[..., Awaitable[Tuple[int, JsonDict]]])
def interactive_auth_handler(orig: C) -> C:
"""Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
Takes a on_POST method which returns an Awaitable (errcode, body) response
......@@ -91,10 +94,10 @@ def interactive_auth_handler(orig):
await self.auth_handler.check_auth
"""
async def wrapped(*args, **kwargs):
async def wrapped(*args: Any, **kwargs: Any) -> Tuple[int, JsonDict]:
try:
return await orig(*args, **kwargs)
except InteractiveAuthIncompleteError as e:
return 401, e.result
return wrapped
return cast(C, wrapped)
......@@ -15,7 +15,7 @@
import logging
from functools import wraps
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple
from twisted.web.server import Request
......@@ -43,14 +43,18 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _validate_group_id(f):
def _validate_group_id(
f: Callable[..., Awaitable[Tuple[int, JsonDict]]]
) -> Callable[..., Awaitable[Tuple[int, JsonDict]]]:
"""Wrapper to validate the form of the group ID.
Can be applied to any on_FOO methods that accepts a group ID as a URL parameter.
"""
@wraps(f)
def wrapper(self, request: Request, group_id: str, *args, **kwargs):
def wrapper(
self: RestServlet, request: Request, group_id: str, *args: Any, **kwargs: Any
) -> Awaitable[Tuple[int, JsonDict]]:
if not GroupID.is_valid(group_id):
raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
......
......@@ -12,22 +12,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Union
import attr
from synapse.api.errors import (
NotFoundError,
StoreError,
SynapseError,
UnrecognizedRequestError,
)
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
parse_json_value_from_request,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.push.baserules import BASE_RULE_IDS, NEW_RULE_IDS
from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest.client._base import client_patterns
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
@attr.s(slots=True, frozen=True, auto_attribs=True)
class RuleSpec:
scope: str
template: str
rule_id: str
attr: Optional[str]
class PushRuleRestServlet(RestServlet):
......@@ -36,7 +54,7 @@ class PushRuleRestServlet(RestServlet):
"Unrecognised request: You probably wanted a trailing slash"
)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
......@@ -45,7 +63,7 @@ class PushRuleRestServlet(RestServlet):
self._users_new_default_push_rules = hs.config.users_new_default_push_rules
async def on_PUT(self, request, path):
async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
if self._is_worker:
raise Exception("Cannot handle PUT /push_rules on worker")
......@@ -57,25 +75,25 @@ class PushRuleRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
if "/" in spec["rule_id"] or "\\" in spec["rule_id"]:
if "/" in spec.rule_id or "\\" in spec.rule_id:
raise SynapseError(400, "rule_id may not contain slashes")
content = parse_json_value_from_request(request)
user_id = requester.user.to_string()
if "attr" in spec:
if spec.attr:
await self.set_rule_attr(user_id, spec, content)
self.notify_user(user_id)
return 200, {}
if spec["rule_id"].startswith("."):
if spec.rule_id.startswith("."):
# Rule ids starting with '.' are reserved for server default rules.
raise SynapseError(400, "cannot add new rule_ids that start with '.'")
try:
(conditions, actions) = _rule_tuple_from_request_object(
spec["template"], spec["rule_id"], content
spec.template, spec.rule_id, content
)
except InvalidRuleException as e:
raise SynapseError(400, str(e))
......@@ -106,7 +124,9 @@ class PushRuleRestServlet(RestServlet):
return 200, {}
async def on_DELETE(self, request, path):
async def on_DELETE(
self, request: SynapseRequest, path: str
) -> Tuple[int, JsonDict]:
if self._is_worker:
raise Exception("Cannot handle DELETE /push_rules on worker")
......@@ -127,7 +147,7 @@ class PushRuleRestServlet(RestServlet):
else:
raise
async def on_GET(self, request, path):
async def on_GET(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
......@@ -138,40 +158,42 @@ class PushRuleRestServlet(RestServlet):
rules = format_push_rules_for_user(requester.user, rules)
path = path.split("/")[1:]
path_parts = path.split("/")[1:]
if path == []:
if path_parts == []:
# we're a reference impl: pedantry is our job.
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
if path[0] == "":
if path_parts[0] == "":
return 200, rules
elif path[0] == "global":
result = _filter_ruleset_with_path(rules["global"], path[1:])
elif path_parts[0] == "global":
result = _filter_ruleset_with_path(rules["global"], path_parts[1:])
return 200, result
else:
raise UnrecognizedRequestError()
def notify_user(self, user_id):
def notify_user(self, user_id: str) -> None:
stream_id = self.store.get_max_push_rules_stream_id()
self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
async def set_rule_attr(self, user_id, spec, val):
if spec["attr"] not in ("enabled", "actions"):
async def set_rule_attr(
self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict]
) -> None:
if spec.attr not in ("enabled", "actions"):
# for the sake of potential future expansion, shouldn't report
# 404 in the case of an unknown request so check it corresponds to
# a known attribute first.
raise UnrecognizedRequestError()
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
rule_id = spec["rule_id"]
rule_id = spec.rule_id
is_default_rule = rule_id.startswith(".")
if is_default_rule:
if namespaced_rule_id not in BASE_RULE_IDS:
raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,))
if spec["attr"] == "enabled":
if spec.attr == "enabled":
if isinstance(val, dict) and "enabled" in val:
val = val["enabled"]
if not isinstance(val, bool):
......@@ -179,14 +201,18 @@ class PushRuleRestServlet(RestServlet):
# This should *actually* take a dict, but many clients pass
# bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean")
return await self.store.set_push_rule_enabled(
await self.store.set_push_rule_enabled(
user_id, namespaced_rule_id, val, is_default_rule
)
elif spec["attr"] == "actions":
elif spec.attr == "actions":
if not isinstance(val, dict):
raise SynapseError(400, "Value must be a dict")
actions = val.get("actions")
if not isinstance(actions, list):
raise SynapseError(400, "Value for 'actions' must be dict")
_check_actions(actions)
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
rule_id = spec["rule_id"]
rule_id = spec.rule_id
is_default_rule = rule_id.startswith(".")
if is_default_rule:
if user_id in self._users_new_default_push_rules:
......@@ -196,22 +222,21 @@ class PushRuleRestServlet(RestServlet):
if namespaced_rule_id not in rule_ids:
raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
return await self.store.set_push_rule_actions(
await self.store.set_push_rule_actions(
user_id, namespaced_rule_id, actions, is_default_rule
)
else:
raise UnrecognizedRequestError()
def _rule_spec_from_path(path):
def _rule_spec_from_path(path: Sequence[str]) -> RuleSpec:
"""Turn a sequence of path components into a rule spec
Args:
path (sequence[unicode]): the URL path components.
path: the URL path components.
Returns:
dict: rule spec dict, containing scope/template/rule_id entries,
and possibly attr.
rule spec, containing scope/template/rule_id entries, and possibly attr.
Raises:
UnrecognizedRequestError if the path components cannot be parsed.
......@@ -237,17 +262,18 @@ def _rule_spec_from_path(path):
rule_id = path[0]
spec = {"scope": scope, "template": template, "rule_id": rule_id}
path = path[1:]
attr = None
if len(path) > 0 and len(path[0]) > 0:
spec["attr"] = path[0]
attr = path[0]
return spec
return RuleSpec(scope, template, rule_id, attr)
def _rule_tuple_from_request_object(rule_template, rule_id, req_obj):
def _rule_tuple_from_request_object(
rule_template: str, rule_id: str, req_obj: JsonDict
) -> Tuple[List[JsonDict], List[Union[str, JsonDict]]]:
if rule_template in ["override", "underride"]:
if "conditions" not in req_obj:
raise InvalidRuleException("Missing 'conditions'")
......@@ -277,7 +303,7 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj):
return conditions, actions
def _check_actions(actions):
def _check_actions(actions: List[Union[str, JsonDict]]) -> None:
if not isinstance(actions, list):
raise InvalidRuleException("No actions found")
......@@ -290,7 +316,7 @@ def _check_actions(actions):
raise InvalidRuleException("Unrecognised action")
def _filter_ruleset_with_path(ruleset, path):
def _filter_ruleset_with_path(ruleset: JsonDict, path: List[str]) -> JsonDict:
if path == []:
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
......@@ -315,7 +341,7 @@ def _filter_ruleset_with_path(ruleset, path):
if r["rule_id"] == rule_id:
the_rule = r
if the_rule is None:
raise NotFoundError
raise NotFoundError()
path = path[1:]
if len(path) == 0:
......@@ -330,25 +356,25 @@ def _filter_ruleset_with_path(ruleset, path):
raise UnrecognizedRequestError()
def _priority_class_from_spec(spec):
if spec["template"] not in PRIORITY_CLASS_MAP.keys():
raise InvalidRuleException("Unknown template: %s" % (spec["template"]))
pc = PRIORITY_CLASS_MAP[spec["template"]]
def _priority_class_from_spec(spec: RuleSpec) -> int:
if spec.template not in PRIORITY_CLASS_MAP.keys():
raise InvalidRuleException("Unknown template: %s" % (spec.template))
pc = PRIORITY_CLASS_MAP[spec.template]
return pc
def _namespaced_rule_id_from_spec(spec):
return _namespaced_rule_id(spec, spec["rule_id"])
def _namespaced_rule_id_from_spec(spec: RuleSpec) -> str:
return _namespaced_rule_id(spec, spec.rule_id)
def _namespaced_rule_id(spec, rule_id):
return "global/%s/%s" % (spec["template"], rule_id)
def _namespaced_rule_id(spec: RuleSpec, rule_id: str) -> str:
return "global/%s/%s" % (spec.template, rule_id)
class InvalidRuleException(Exception):
pass
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
PushRuleRestServlet(hs).register(http_server)
......@@ -15,28 +15,37 @@
"""This module contains logic for storing HTTP PUT transactions. This is used
to ensure idempotency when performing PUTs using the REST API."""
import logging
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Tuple
from twisted.python.failure import Failure
from twisted.web.server import Request
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import JsonDict
from synapse.util.async_helpers import ObservableDeferred
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
class HttpTransactionCache:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = self.hs.get_auth()
self.clock = self.hs.get_clock()
self.transactions = {
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
}
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
self.transactions: Dict[
str, Tuple[ObservableDeferred[Tuple[int, JsonDict]], int]
] = {}
# Try to clean entries every 30 mins. This means entries will exist
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
def _get_transaction_key(self, request):
def _get_transaction_key(self, request: Request) -> str:
"""A helper function which returns a transaction key that can be used
with TransactionCache for idempotent requests.
......@@ -45,15 +54,21 @@ class HttpTransactionCache:
path and the access_token for the requesting user.
Args:
request (twisted.web.http.Request): The incoming request. Must
contain an access_token.
request: The incoming request. Must contain an access_token.
Returns:
str: A transaction key
A transaction key
"""
assert request.path is not None
token = self.auth.get_access_token_from_request(request)
return request.path.decode("utf8") + "/" + token
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
def fetch_or_execute_request(
self,
request: Request,
fn: Callable[..., Awaitable[Tuple[int, JsonDict]]],
*args: Any,
**kwargs: Any,
) -> Awaitable[Tuple[int, JsonDict]]:
"""A helper function for fetch_or_execute which extracts
a transaction key from the given request.
......@@ -64,15 +79,20 @@ class HttpTransactionCache:
self._get_transaction_key(request), fn, *args, **kwargs
)
def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
def fetch_or_execute(
self,
txn_key: str,
fn: Callable[..., Awaitable[Tuple[int, JsonDict]]],
*args: Any,
**kwargs: Any,
) -> Awaitable[Tuple[int, JsonDict]]:
"""Fetches the response for this transaction, or executes the given function
to produce a response for this transaction.
Args:
txn_key (str): A key to ensure idempotency should fetch_or_execute be
called again at a later point in time.
fn (function): A function which returns a tuple of
(response_code, response_dict).
txn_key: A key to ensure idempotency should fetch_or_execute be
called again at a later point in time.
fn: A function which returns a tuple of (response_code, response_dict).
*args: Arguments to pass to fn.
**kwargs: Keyword arguments to pass to fn.
Returns:
......@@ -90,7 +110,7 @@ class HttpTransactionCache:
# if the request fails with an exception, remove it
# from the transaction map. This is done to ensure that we don't
# cache transient errors like rate-limiting errors, etc.
def remove_from_map(err):
def remove_from_map(err: Failure) -> None:
self.transactions.pop(txn_key, None)
# we deliberately do not propagate the error any further, as we
# expect the observers to have reported it.
......@@ -99,7 +119,7 @@ class HttpTransactionCache:
return make_deferred_yieldable(observable.observe())
def _cleanup(self):
def _cleanup(self) -> None:
now = self.clock.time_msec()
for key in list(self.transactions):
ts = self.transactions[key][1]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment