Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • maunium/synapse
  • leytilera/synapse
2 results
Show changes
Showing
with 1312 additions and 700 deletions
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2019-2021 Matrix.org Federation C.I.C
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from typing import (
......@@ -29,10 +35,8 @@ from typing import (
Union,
)
from matrix_common.regex import glob_to_regex
from prometheus_client import Counter, Gauge, Histogram
from twisted.internet.abstract import isIPAddress
from twisted.python import failure
from synapse.api.constants import (
......@@ -63,6 +67,7 @@ from synapse.federation.federation_base import (
)
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
make_deferred_yieldable,
......@@ -85,7 +90,7 @@ from synapse.replication.http.federation import (
from synapse.storage.databases.main.lock import Lock
from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
from synapse.storage.roommember import MemberSummary
from synapse.types import JsonDict, StateMap, get_domain_from_id
from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
from synapse.util.caches.response_cache import ResponseCache
......@@ -129,6 +134,7 @@ class FederationServer(FederationBase):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.server_name = hs.hostname
self.handler = hs.get_federation_handler()
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
self._federation_event_handler = hs.get_federation_event_handler()
......@@ -136,6 +142,7 @@ class FederationServer(FederationBase):
self._event_auth_handler = hs.get_event_auth_handler()
self._room_member_handler = hs.get_room_member_handler()
self._e2e_keys_handler = hs.get_e2e_keys_handler()
self._worker_lock_handler = hs.get_worker_locks_handler()
self._state_storage_controller = hs.get_storage_controllers().state
......@@ -162,9 +169,9 @@ class FederationServer(FederationBase):
# We cache responses to state queries, as they take a while and often
# come in waves.
self._state_resp_cache: ResponseCache[
Tuple[str, Optional[str]]
] = ResponseCache(hs.get_clock(), "state_resp", timeout_ms=30000)
self._state_resp_cache: ResponseCache[Tuple[str, Optional[str]]] = (
ResponseCache(hs.get_clock(), "state_resp", timeout_ms=30000)
)
self._state_ids_resp_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
hs.get_clock(), "state_ids_resp", timeout_ms=30000
)
......@@ -514,7 +521,7 @@ class FederationServer(FederationBase):
logger.error(
"Failed to handle PDU %s",
event_id,
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
exc_info=(f.type, f.value, f.getTracebackObject()),
)
return {"error": str(e)}
......@@ -539,7 +546,25 @@ class FederationServer(FederationBase):
edu_type=edu_dict["edu_type"],
content=edu_dict["content"],
)
await self.registry.on_edu(edu.edu_type, origin, edu.content)
try:
await self.registry.on_edu(edu.edu_type, origin, edu.content)
except Exception:
# If there was an error handling the EDU, we must reject the
# transaction.
#
# Some EDU types (notably, to-device messages) are, despite their name,
# expected to be reliable; if we weren't able to do something with it,
# we have to tell the sender that, and the only way the protocol gives
# us to do so is by sending an HTTP error back on the transaction.
#
# We log the exception now, and then raise a new SynapseError to cause
# the transaction to be failed.
logger.exception("Error handling EDU of type %s", edu.edu_type)
raise SynapseError(500, f"Error handing EDU of type {edu.edu_type}")
# TODO: if the first EDU fails, we should probably abort the whole
# thing rather than carrying on with the rest of them. That would
# probably be best done inside `concurrently_execute`.
await concurrently_execute(
_process_edu,
......@@ -649,7 +674,7 @@ class FederationServer(FederationBase):
# This is in addition to the HS-level rate limiting applied by
# BaseFederationServlet.
# type-ignore: mypy doesn't seem able to deduce the type of the limiter(!?)
await self._room_member_handler._join_rate_per_room_limiter.ratelimit( # type: ignore[has-type]
await self._room_member_handler._join_rate_per_room_limiter.ratelimit(
requester=None,
key=room_id,
update=False,
......@@ -692,7 +717,7 @@ class FederationServer(FederationBase):
SynapseTags.SEND_JOIN_RESPONSE_IS_PARTIAL_STATE,
caller_supports_partial_state,
)
await self._room_member_handler._join_rate_per_room_limiter.ratelimit( # type: ignore[has-type]
await self._room_member_handler._join_rate_per_room_limiter.ratelimit(
requester=None,
key=room_id,
update=False,
......@@ -738,12 +763,10 @@ class FederationServer(FederationBase):
"event": event_json,
"state": [p.get_pdu_json(time_now) for p in state_events],
"auth_chain": [p.get_pdu_json(time_now) for p in auth_chain_events],
"org.matrix.msc3706.partial_state": caller_supports_partial_state,
"members_omitted": caller_supports_partial_state,
}
if servers_in_room is not None:
resp["org.matrix.msc3706.servers_in_room"] = list(servers_in_room)
resp["servers_in_room"] = list(servers_in_room)
return resp
......@@ -807,7 +830,7 @@ class FederationServer(FederationBase):
raise IncompatibleRoomVersionError(room_version=room_version.identifier)
# Check that this room supports knocking as defined by its room version
if not room_version.msc2403_knocking:
if not room_version.knock_join_rule:
raise SynapseError(
403,
"This room version does not support knocking",
......@@ -851,14 +874,7 @@ class FederationServer(FederationBase):
context, self._room_prejoin_state_types
)
)
return {
"knock_room_state": stripped_room_state,
# Since v1.37, Synapse incorrectly used "knock_state_events" for this field.
# Thus, we also populate a 'knock_state_events' with the same content to
# support old instances.
# See https://github.com/matrix-org/synapse/issues/14088.
"knock_state_events": stripped_room_state,
}
return {"knock_room_state": stripped_room_state}
async def _on_send_membership_event(
self, origin: str, content: JsonDict, membership_type: str, room_id: str
......@@ -910,7 +926,7 @@ class FederationServer(FederationBase):
errcode=Codes.NOT_FOUND,
)
if membership_type == Membership.KNOCK and not room_version.msc2403_knocking:
if membership_type == Membership.KNOCK and not room_version.knock_join_rule:
raise SynapseError(
403,
"This room version does not support knocking",
......@@ -934,7 +950,7 @@ class FederationServer(FederationBase):
# the event is valid to be sent into the room. Currently this is only done
# if the user is being joined via restricted join rules.
if (
room_version.msc3083_join_rules
room_version.restricted_join_rule
and event.membership == Membership.JOIN
and EventContentFields.AUTHORISING_USER in event.content
):
......@@ -942,10 +958,10 @@ class FederationServer(FederationBase):
authorising_server = get_domain_from_id(
event.content[EventContentFields.AUTHORISING_USER]
)
if authorising_server != self.server_name:
if not self._is_mine_server_name(authorising_server):
raise SynapseError(
400,
f"Cannot authorise request from resident server: {authorising_server}",
f"Cannot authorise membership event for {authorising_server}. We can only authorise requests from our own homeserver",
)
event.signatures.update(
......@@ -1005,12 +1021,13 @@ class FederationServer(FederationBase):
@trace
async def on_claim_client_keys(
self, origin: str, content: JsonDict, always_include_fallback_keys: bool
self, query: List[Tuple[str, str, str, int]], always_include_fallback_keys: bool
) -> Dict[str, Any]:
query = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items():
query.append((user_id, device_id, algorithm))
if any(
not self.hs.is_mine(UserID.from_string(user_id))
for user_id, _, _, _ in query
):
raise SynapseError(400, "User is not hosted on this homeserver")
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
results = await self._e2e_keys_handler.claim_local_one_time_keys(
......@@ -1022,7 +1039,9 @@ class FederationServer(FederationBase):
for user_id, device_keys in result.items():
for device_id, keys in device_keys.items():
for key_id, key in keys.items():
json_result.setdefault(user_id, {})[device_id] = {key_id: key}
json_result.setdefault(user_id, {}).setdefault(device_id, {})[
key_id
] = key
logger.info(
"Claimed one-time-keys: %s",
......@@ -1240,9 +1259,18 @@ class FederationServer(FederationBase):
logger.info("handling received PDU in room %s: %s", room_id, event)
try:
with nested_logging_context(event.event_id):
await self._federation_event_handler.on_receive_pdu(
origin, event
)
# We're taking out a lock within a lock, which could
# lead to deadlocks if we're not careful. However, it is
# safe on this occasion as we only ever take a write
# lock when deleting a room, which we would never do
# while holding the `_INBOUND_EVENT_HANDLING_LOCK_NAME`
# lock.
async with self._worker_lock_handler.acquire_read_write_lock(
NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
):
await self._federation_event_handler.on_receive_pdu(
origin, event
)
except FederationError as e:
# XXX: Ideally we'd inform the remote we failed to process
# the event, but we can't return an error in the transaction
......@@ -1253,7 +1281,7 @@ class FederationServer(FederationBase):
logger.error(
"Failed to handle PDU %s",
event.event_id,
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
exc_info=(f.type, f.value, f.getTracebackObject()),
)
received_ts = await self.store.remove_received_event_from_staging(
......@@ -1297,9 +1325,6 @@ class FederationServer(FederationBase):
return
lock = new_lock
def __str__(self) -> str:
return "<ReplicationLayer(%s)>" % self.server_name
async def exchange_third_party_invite(
self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict
) -> None:
......@@ -1320,75 +1345,13 @@ class FederationServer(FederationBase):
Raises:
AuthError if the server does not match the ACL
"""
acl_event = await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.ServerACL, ""
)
if not acl_event or server_matches_acl_event(server_name, acl_event):
return
raise AuthError(code=403, msg="Server is banned from room")
def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
"""Check if the given server is allowed by the ACL event
Args:
server_name: name of server, without any port part
acl_event: m.room.server_acl event
Returns:
True if this server is allowed by the ACLs
"""
logger.debug("Checking %s against acl %s", server_name, acl_event.content)
# first of all, check if literal IPs are blocked, and if so, whether the
# server name is a literal IP
allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
if not isinstance(allow_ip_literals, bool):
logger.warning("Ignoring non-bool allow_ip_literals flag")
allow_ip_literals = True
if not allow_ip_literals:
# check for ipv6 literals. These start with '['.
if server_name[0] == "[":
return False
# check for ipv4 literals. We can just lift the routine from twisted.
if isIPAddress(server_name):
return False
# next, check the deny list
deny = acl_event.content.get("deny", [])
if not isinstance(deny, (list, tuple)):
logger.warning("Ignoring non-list deny ACL %s", deny)
deny = []
for e in deny:
if _acl_entry_matches(server_name, e):
# logger.info("%s matched deny rule %s", server_name, e)
return False
# then the allow list.
allow = acl_event.content.get("allow", [])
if not isinstance(allow, (list, tuple)):
logger.warning("Ignoring non-list allow ACL %s", allow)
allow = []
for e in allow:
if _acl_entry_matches(server_name, e):
# logger.info("%s matched allow rule %s", server_name, e)
return True
# everything else should be rejected.
# logger.info("%s fell through", server_name)
return False
def _acl_entry_matches(server_name: str, acl_entry: Any) -> bool:
if not isinstance(acl_entry, str):
logger.warning(
"Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
server_acl_evaluator = (
await self._storage_controllers.state.get_server_acl_for_room(room_id)
)
return False
regex = glob_to_regex(acl_entry)
return bool(regex.match(server_name))
if server_acl_evaluator and not server_acl_evaluator.server_matches_acl_event(
server_name
):
raise AuthError(code=403, msg="Server is banned from room")
class FederationHandlerRegistry:
......@@ -1462,19 +1425,14 @@ class FederationHandlerRegistry:
self._edu_type_to_instance[edu_type] = instance_names
async def on_edu(self, edu_type: str, origin: str, content: dict) -> None:
if not self.config.server.use_presence and edu_type == EduTypes.PRESENCE:
if not self.config.server.track_presence and edu_type == EduTypes.PRESENCE:
return
# Check if we have a handler on this instance
handler = self.edu_handlers.get(edu_type)
if handler:
with start_active_span_from_edu(content, "handle_edu"):
try:
await handler(origin, content)
except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception:
logger.exception("Failed to handle edu %r", edu_type)
await handler(origin, content)
return
# Check if we can route it somewhere else that isn't us
......@@ -1483,17 +1441,12 @@ class FederationHandlerRegistry:
# Pick an instance randomly so that we don't overload one.
route_to = random.choice(instances)
try:
await self._send_edu(
instance_name=route_to,
edu_type=edu_type,
origin=origin,
content=content,
)
except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception:
logger.exception("Failed to handle edu %r", edu_type)
await self._send_edu(
instance_name=route_to,
edu_type=edu_type,
origin=origin,
content=content,
)
return
# Oh well, let's just log and move on.
......
# Copyright 2014-2016 OpenMarket Ltd
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" This module contains all the persistence actions done by the federation
"""This module contains all the persistence actions done by the federation
package.
These actions are mostly only used by the :py:mod:`.replication` module.
......
# Copyright 2014-2016 OpenMarket Ltd
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A federation sender that forwards things to be sent across replication to
a worker process.
......@@ -49,7 +56,7 @@ from synapse.api.presence import UserPresenceState
from synapse.federation.sender import AbstractFederationSender, FederationSender
from synapse.metrics import LaterGauge
from synapse.replication.tcp.streams.federation import FederationStream
from synapse.types import JsonDict, ReadReceipt, RoomStreamToken
from synapse.types import JsonDict, ReadReceipt, RoomStreamToken, StrCollection
from synapse.util.metrics import Measure
from .units import Edu
......@@ -68,6 +75,7 @@ class FederationRemoteSendQueue(AbstractFederationSender):
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
self.is_mine_server_name = hs.is_mine_server_name
# We may have multiple federation sender instances, so we need to track
# their positions separately.
......@@ -80,9 +88,9 @@ class FederationRemoteSendQueue(AbstractFederationSender):
# Stores the destinations we need to explicitly send presence to about a
# given user.
# Stream position -> (user_id, destinations)
self.presence_destinations: SortedDict[
int, Tuple[str, Iterable[str]]
] = SortedDict()
self.presence_destinations: SortedDict[int, Tuple[str, Iterable[str]]] = (
SortedDict()
)
# (destination, key) -> EDU
self.keyed_edu: Dict[Tuple[str, tuple], Edu] = {}
......@@ -198,7 +206,7 @@ class FederationRemoteSendQueue(AbstractFederationSender):
key: Optional[Hashable] = None,
) -> None:
"""As per FederationSender"""
if destination == self.server_name:
if self.is_mine_server_name(destination):
logger.info("Not sending EDU to ourselves")
return
......@@ -228,7 +236,7 @@ class FederationRemoteSendQueue(AbstractFederationSender):
"""
# nothing to do here: the replication listener will handle it.
def send_presence_to_destinations(
async def send_presence_to_destinations(
self, states: Iterable[UserPresenceState], destinations: Iterable[str]
) -> None:
"""As per FederationSender
......@@ -244,7 +252,9 @@ class FederationRemoteSendQueue(AbstractFederationSender):
self.notifier.on_new_replication_data()
def send_device_messages(self, destination: str, immediate: bool = True) -> None:
async def send_device_messages(
self, destinations: StrCollection, immediate: bool = True
) -> None:
"""As per FederationSender"""
# We don't need to replicate this as it gets sent down a different
# stream.
......@@ -392,7 +402,7 @@ class PresenceDestinationsRow(BaseFederationRow):
@staticmethod
def from_data(data: JsonDict) -> "PresenceDestinationsRow":
return PresenceDestinationsRow(
state=UserPresenceState.from_dict(data["state"]), destinations=data["dests"]
state=UserPresenceState(**data["state"]), destinations=data["dests"]
)
def to_data(self) -> JsonDict:
......@@ -462,7 +472,7 @@ class ParsedFederationStreamData:
edus: Dict[str, List[Edu]]
def process_rows_for_federation(
async def process_rows_for_federation(
transaction_queue: FederationSender,
rows: List[FederationStream.FederationStreamRow],
) -> None:
......@@ -495,7 +505,7 @@ def process_rows_for_federation(
parsed_row.add_to_buffer(buff)
for state, destinations in buff.presence_destinations:
transaction_queue.send_presence_to_destinations(
await transaction_queue.send_presence_to_destinations(
states=[state], destinations=destinations
)
......
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Federation Sender is responsible for sending Persistent Data Units (PDUs)
and Ephemeral Data Units (EDUs) to other homeservers using
......@@ -67,7 +73,7 @@ The loop continues so long as there is anything to send. At each iteration of th
When the `PerDestinationQueue` has the catch-up flag set, the *Catch-Up Transmission Loop*
(`_catch_up_transmission_loop`) is used in lieu of the regular `_transaction_transmission_loop`.
(Only once the catch-up mode has been exited can the regular tranaction transmission behaviour
(Only once the catch-up mode has been exited can the regular transaction transmission behaviour
be resumed.)
*Catch-Up Mode*, entered upon Synapse startup or once a homeserver has fallen behind due to
......@@ -109,10 +115,8 @@ was enabled*, Catch-Up Mode is exited and we return to `_transaction_transmissio
If a remote server is unreachable over federation, we back off from that server,
with an exponentially-increasing retry interval.
Whilst we don't automatically retry after the interval, we prevent making new attempts
until such time as the back-off has cleared.
Once the back-off is cleared and a new PDU or EDU arrives for transmission, the transmission
loop resumes and empties the queue by making federation requests.
We automatically retry after the retry interval expires (roughly, the logic to do so
being triggered every minute).
If the backoff grows too large (> 1 hour), the in-memory queue is emptied (to prevent
unbounded growth) and Catch-Up Mode is entered.
......@@ -135,22 +139,23 @@ from typing import (
Hashable,
Iterable,
List,
Literal,
Optional,
Set,
Tuple,
)
import attr
from prometheus_client import Counter
from typing_extensions import Literal
from twisted.internet import defer
from twisted.internet.interfaces import IDelayedCall
import synapse.metrics
from synapse.api.presence import UserPresenceState
from synapse.events import EventBase
from synapse.federation.sender.per_destination_queue import PerDestinationQueue
from synapse.federation.sender.per_destination_queue import (
CATCHUP_RETRY_INTERVAL,
PerDestinationQueue,
)
from synapse.federation.sender.transaction_manager import TransactionManager
from synapse.federation.units import Edu
from synapse.logging.context import make_deferred_yieldable, run_in_background
......@@ -164,9 +169,16 @@ from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
from synapse.types import JsonDict, ReadReceipt, RoomStreamToken
from synapse.types import (
JsonDict,
ReadReceipt,
RoomStreamToken,
StrCollection,
get_domain_from_id,
)
from synapse.util import Clock
from synapse.util.metrics import Measure
from synapse.util.retryutils import filter_destinations_by_retry_limiter
if TYPE_CHECKING:
from synapse.events.presence_router import PresenceRouter
......@@ -184,14 +196,17 @@ sent_pdus_destination_dist_total = Counter(
"Total number of PDUs queued for sending across all destinations",
)
# Time (in s) after Synapse's startup that we will begin to wake up destinations
# that have catch-up outstanding.
CATCH_UP_STARTUP_DELAY_SEC = 15
# Time (in s) to wait before trying to wake up destinations that have
# catch-up outstanding.
# Please note that rate limiting still applies, so while the loop is
# executed every X seconds the destinations may not be woken up because
# they are being rate limited following previous attempt failures.
WAKEUP_RETRY_PERIOD_SEC = 60
# Time (in s) to wait in between waking up each destination, i.e. one destination
# will be woken up every <x> seconds after Synapse's startup until we have woken
# every destination has outstanding catch-up.
CATCH_UP_STARTUP_INTERVAL_SEC = 5
# will be woken up every <x> seconds until we have woken every destination
# has outstanding catch-up.
WAKEUP_INTERVAL_BETWEEN_DESTINATIONS_SEC = 5
class AbstractFederationSender(metaclass=abc.ABCMeta):
......@@ -212,7 +227,7 @@ class AbstractFederationSender(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
def send_presence_to_destinations(
async def send_presence_to_destinations(
self, states: Iterable[UserPresenceState], destinations: Iterable[str]
) -> None:
"""Send the given presence states to the given destinations.
......@@ -241,9 +256,11 @@ class AbstractFederationSender(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
def send_device_messages(self, destination: str, immediate: bool = True) -> None:
async def send_device_messages(
self, destinations: StrCollection, immediate: bool = True
) -> None:
"""Tells the sender that a new device message is ready to be sent to the
destination. The `immediate` flag specifies whether the messages should
destinations. The `immediate` flag specifies whether the messages should
be tried to be sent immediately, or whether it can be delayed for a
short while (to aid performance).
"""
......@@ -285,12 +302,10 @@ class _DestinationWakeupQueue:
# being woken up.
_MAX_TIME_IN_QUEUE = 30.0
# The maximum duration in seconds between waking up consecutive destination
# queues.
_MAX_DELAY = 0.1
sender: "FederationSender" = attr.ib()
clock: Clock = attr.ib()
max_delay_s: int = attr.ib()
queue: "OrderedDict[str, Literal[None]]" = attr.ib(factory=OrderedDict)
processing: bool = attr.ib(default=False)
......@@ -320,7 +335,7 @@ class _DestinationWakeupQueue:
# We also add an upper bound to the delay, to gracefully handle the
# case where the queue only has a few entries in it.
current_sleep_seconds = min(
self._MAX_DELAY, self._MAX_TIME_IN_QUEUE / len(self.queue)
self.max_delay_s, self._MAX_TIME_IN_QUEUE / len(self.queue)
)
while self.queue:
......@@ -362,6 +377,7 @@ class FederationSender(AbstractFederationSender):
self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id
self.is_mine_server_name = hs.is_mine_server_name
self._presence_router: Optional["PresenceRouter"] = None
self._transaction_manager = TransactionManager(hs)
......@@ -403,31 +419,23 @@ class FederationSender(AbstractFederationSender):
self._is_processing = False
self._last_poked_id = -1
# map from room_id to a set of PerDestinationQueues which we believe are
# awaiting a call to flush_read_receipts_for_room. The presence of an entry
# here for a given room means that we are rate-limiting RR flushes to that room,
# and that there is a pending call to _flush_rrs_for_room in the system.
self._queues_awaiting_rr_flush_by_room: Dict[str, Set[PerDestinationQueue]] = {}
self._external_cache = hs.get_external_cache()
self._rr_txn_interval_per_room_ms = (
1000.0
/ hs.config.ratelimiting.federation_rr_transactions_per_room_per_second
rr_txn_interval_per_room_s = (
1.0 / hs.config.ratelimiting.federation_rr_transactions_per_room_per_second
)
self._destination_wakeup_queue = _DestinationWakeupQueue(
self, self.clock, max_delay_s=rr_txn_interval_per_room_s
)
# wake up destinations that have outstanding PDUs to be caught up
self._catchup_after_startup_timer: Optional[
IDelayedCall
] = self.clock.call_later(
CATCH_UP_STARTUP_DELAY_SEC,
# Regularly wake up destinations that have outstanding PDUs to be caught up
self.clock.looping_call_now(
run_as_background_process,
WAKEUP_RETRY_PERIOD_SEC * 1000.0,
"wake_destinations_needing_catchup",
self._wake_destinations_needing_catchup,
)
self._external_cache = hs.get_external_cache()
self._destination_wakeup_queue = _DestinationWakeupQueue(self, self.clock)
def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
"""Get or create a PerDestinationQueue for the given destination
......@@ -575,14 +583,14 @@ class FederationSender(AbstractFederationSender):
"get_joined_hosts", str(sg)
)
if destinations is None:
# Add logging to help track down #13444
# Add logging to help track down https://github.com/matrix-org/synapse/issues/13444
logger.info(
"Unexpectedly did not have cached destinations for %s / %s",
sg,
event.event_id,
)
else:
# Add logging to help track down #13444
# Add logging to help track down https://github.com/matrix-org/synapse/issues/13444
logger.info(
"Unexpectedly did not have cached prev group for %s",
event.event_id,
......@@ -716,6 +724,13 @@ class FederationSender(AbstractFederationSender):
pdu.internal_metadata.stream_ordering,
)
destinations = await filter_destinations_by_retry_limiter(
destinations,
clock=self.clock,
store=self.store,
retry_due_within_ms=CATCHUP_RETRY_INTERVAL,
)
for destination in destinations:
self._get_per_destination_queue(destination).send_pdu(pdu)
......@@ -728,102 +743,123 @@ class FederationSender(AbstractFederationSender):
# Some background on the rate-limiting going on here.
#
# It turns out that if we attempt to send out RRs as soon as we get them from
# a client, then we end up trying to do several hundred Hz of federation
# transactions. (The number of transactions scales as O(N^2) on the size of a
# room, since in a large room we have both more RRs coming in, and more servers
# to send them to.)
# It turns out that if we attempt to send out RRs as soon as we get them
# from a client, then we end up trying to do several hundred Hz of
# federation transactions. (The number of transactions scales as O(N^2)
# on the size of a room, since in a large room we have both more RRs
# coming in, and more servers to send them to.)
#
# This leads to a lot of CPU load, and we end up getting behind. The solution
# currently adopted is as follows:
# This leads to a lot of CPU load, and we end up getting behind. The
# solution currently adopted is to differentiate between receipts and
# destinations we should immediately send to, and those we can trickle
# the receipts to.
#
# The first receipt in a given room is sent out immediately, at time T0. Any
# further receipts are, in theory, batched up for N seconds, where N is calculated
# based on the number of servers in the room to achieve a transaction frequency
# of around 50Hz. So, for example, if there were 100 servers in the room, then
# N would be 100 / 50Hz = 2 seconds.
# The current logic is to send receipts out immediately if:
# - the room is "small", i.e. there's only N servers to send receipts
# to, and so sending out the receipts immediately doesn't cause too
# much load; or
# - the receipt is for an event that happened recently, as users
# notice if receipts are delayed when they know other users are
# currently reading the room; or
# - the receipt is being sent to the server that sent the event, so
# that users see receipts for their own receipts quickly.
#
# Then, after T+N, we flush out any receipts that have accumulated, and restart
# the timer to flush out more receipts at T+2N, etc. If no receipts accumulate,
# we stop the cycle and go back to the start.
# For destinations that we should delay sending the receipt to, we queue
# the receipts up to be sent in the next transaction, but don't trigger
# a new transaction to be sent. We then add the destination to the
# `DestinationWakeupQueue`, which will slowly iterate over each
# destination and trigger a new transaction to be sent.
#
# However, in practice, it is often possible to flush out receipts earlier: in
# particular, if we are sending a transaction to a given server anyway (for
# example, because we have a PDU or a RR in another room to send), then we may
# as well send out all of the pending RRs for that server. So it may be that
# by the time we get to T+N, we don't actually have any RRs left to send out.
# Nevertheless we continue to buffer up RRs for the room in question until we
# reach the point that no RRs arrive between timer ticks.
# However, in practice, it is often possible to send out delayed
# receipts earlier: in particular, if we are sending a transaction to a
# given server anyway (for example, because we have a PDU or a RR in
# another room to send), then we may as well send out all of the pending
# RRs for that server. So it may be that by the time we get to waking up
# the destination, we don't actually have any RRs left to send out.
#
# For even more background, see https://github.com/matrix-org/synapse/issues/4730.
# For even more background, see
# https://github.com/matrix-org/synapse/issues/4730.
room_id = receipt.room_id
# Local read receipts always have 1 event ID.
event_id = receipt.event_ids[0]
# Work out which remote servers should be poked and poke them.
domains_set = await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation(
room_id
)
domains = [
domains: StrCollection = [
d
for d in domains_set
if d != self.server_name
if not self.is_mine_server_name(d)
and self._federation_shard_config.should_handle(self._instance_name, d)
]
domains = await filter_destinations_by_retry_limiter(
domains,
clock=self.clock,
store=self.store,
retry_due_within_ms=CATCHUP_RETRY_INTERVAL,
)
if not domains:
return
queues_pending_flush = self._queues_awaiting_rr_flush_by_room.get(room_id)
# We now split which domains we want to wake up immediately vs which we
# want to delay waking up.
immediate_domains: StrCollection
delay_domains: StrCollection
# if there is no flush yet scheduled, we will send out these receipts with
# immediate flushes, and schedule the next flush for this room.
if queues_pending_flush is not None:
logger.debug("Queuing receipt for: %r", domains)
if len(domains) < 10:
# For "small" rooms send to all domains immediately
immediate_domains = domains
delay_domains = ()
else:
logger.debug("Sending receipt to: %r", domains)
self._schedule_rr_flush_for_room(room_id, len(domains))
metadata = await self.store.get_metadata_for_event(
receipt.room_id, event_id
)
assert metadata is not None
for domain in domains:
queue = self._get_per_destination_queue(domain)
queue.queue_read_receipt(receipt)
sender_domain = get_domain_from_id(metadata.sender)
# if there is already a RR flush pending for this room, then make sure this
# destination is registered for the flush
if queues_pending_flush is not None:
queues_pending_flush.add(queue)
if self.clock.time_msec() - metadata.received_ts < 60_000:
# We always send receipts for recent messages immediately
immediate_domains = domains
delay_domains = ()
else:
queue.flush_read_receipts_for_room(room_id)
def _schedule_rr_flush_for_room(self, room_id: str, n_domains: int) -> None:
# that is going to cause approximately len(domains) transactions, so now back
# off for that multiplied by RR_TXN_INTERVAL_PER_ROOM
backoff_ms = self._rr_txn_interval_per_room_ms * n_domains
logger.debug("Scheduling RR flush in %s in %d ms", room_id, backoff_ms)
self.clock.call_later(backoff_ms, self._flush_rrs_for_room, room_id)
self._queues_awaiting_rr_flush_by_room[room_id] = set()
def _flush_rrs_for_room(self, room_id: str) -> None:
queues = self._queues_awaiting_rr_flush_by_room.pop(room_id)
logger.debug("Flushing RRs in %s to %s", room_id, queues)
if not queues:
# no more RRs arrived for this room; we are done.
return
# Otherwise, we delay waking up all destinations except for the
# sender's domain.
immediate_domains = []
delay_domains = []
for domain in domains:
if domain == sender_domain:
immediate_domains.append(domain)
else:
delay_domains.append(domain)
for domain in immediate_domains:
# Add to destination queue and wake the destination up
queue = self._get_per_destination_queue(domain)
queue.queue_read_receipt(receipt)
queue.attempt_new_transaction()
# schedule the next flush
self._schedule_rr_flush_for_room(room_id, len(queues))
for domain in delay_domains:
# Add to destination queue...
queue = self._get_per_destination_queue(domain)
queue.queue_read_receipt(receipt)
for queue in queues:
queue.flush_read_receipts_for_room(room_id)
# ... and schedule the destination to be woken up.
self._destination_wakeup_queue.add_to_queue(domain)
def send_presence_to_destinations(
async def send_presence_to_destinations(
self, states: Iterable[UserPresenceState], destinations: Iterable[str]
) -> None:
"""Send the given presence states to the given destinations.
destinations (list[str])
"""
if not states or not self.hs.config.server.use_presence:
if not states or not self.hs.config.server.track_presence:
# No-op if presence is disabled.
return
......@@ -831,12 +867,19 @@ class FederationSender(AbstractFederationSender):
for state in states:
assert self.is_mine_id(state.user_id)
destinations = await filter_destinations_by_retry_limiter(
[
d
for d in destinations
if self._federation_shard_config.should_handle(self._instance_name, d)
],
clock=self.clock,
store=self.store,
retry_due_within_ms=CATCHUP_RETRY_INTERVAL,
)
for destination in destinations:
if destination == self.server_name:
continue
if not self._federation_shard_config.should_handle(
self._instance_name, destination
):
if self.is_mine_server_name(destination):
continue
self._get_per_destination_queue(destination).send_presence(
......@@ -860,7 +903,7 @@ class FederationSender(AbstractFederationSender):
content: content of EDU
key: clobbering key for this edu
"""
if destination == self.server_name:
if self.is_mine_server_name(destination):
logger.info("Not sending EDU to ourselves")
return
......@@ -896,21 +939,29 @@ class FederationSender(AbstractFederationSender):
else:
queue.send_edu(edu)
def send_device_messages(self, destination: str, immediate: bool = True) -> None:
if destination == self.server_name:
logger.warning("Not sending device update to ourselves")
return
if not self._federation_shard_config.should_handle(
self._instance_name, destination
):
return
async def send_device_messages(
self, destinations: StrCollection, immediate: bool = True
) -> None:
destinations = await filter_destinations_by_retry_limiter(
[
destination
for destination in destinations
if self._federation_shard_config.should_handle(
self._instance_name, destination
)
and not self.is_mine_server_name(destination)
],
clock=self.clock,
store=self.store,
retry_due_within_ms=CATCHUP_RETRY_INTERVAL,
)
if immediate:
self._get_per_destination_queue(destination).attempt_new_transaction()
else:
self._get_per_destination_queue(destination).mark_new_data()
self._destination_wakeup_queue.add_to_queue(destination)
for destination in destinations:
if immediate:
self._get_per_destination_queue(destination).attempt_new_transaction()
else:
self._get_per_destination_queue(destination).mark_new_data()
self._destination_wakeup_queue.add_to_queue(destination)
def wake_destination(self, destination: str) -> None:
"""Called when we want to retry sending transactions to a remote.
......@@ -919,7 +970,7 @@ class FederationSender(AbstractFederationSender):
might have come back.
"""
if destination == self.server_name:
if self.is_mine_server_name(destination):
logger.warning("Not waking up ourselves")
return
......@@ -965,7 +1016,6 @@ class FederationSender(AbstractFederationSender):
if not destinations_to_wake:
# finished waking all destinations!
self._catchup_after_startup_timer = None
break
last_processed = destinations_to_wake[-1]
......@@ -982,4 +1032,4 @@ class FederationSender(AbstractFederationSender):
last_processed,
)
self.wake_destination(destination)
await self.clock.sleep(CATCH_UP_STARTUP_INTERVAL_SEC)
await self.clock.sleep(WAKEUP_INTERVAL_BETWEEN_DESTINATIONS_SEC)
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import logging
from collections import OrderedDict
from types import TracebackType
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, Type
......@@ -59,6 +66,14 @@ sent_edus_by_type = Counter(
)
# If the retry interval is larger than this then we enter "catchup" mode
CATCHUP_RETRY_INTERVAL = 60 * 60 * 1000
# Limit how many presence states we add to each presence EDU, to ensure that
# they are bounded in size.
MAX_PRESENCE_STATES_PER_EDU = 50
class PerDestinationQueue:
"""
Manages the per-destination transmission queues.
......@@ -134,14 +149,13 @@ class PerDestinationQueue:
# Map of user_id -> UserPresenceState of pending presence to be sent to this
# destination
self._pending_presence: Dict[str, UserPresenceState] = {}
self._pending_presence: OrderedDict[str, UserPresenceState] = OrderedDict()
# List of room_id -> receipt_type -> user_id -> receipt_dict,
#
# Each receipt can only have a single receipt per
# (room ID, receipt type, user ID, thread ID) tuple.
self._pending_receipt_edus: List[Dict[str, Dict[str, Dict[str, dict]]]] = []
self._rrs_pending_flush = False
# stream_id of last successfully sent to-device message.
# NB: may be a long or an int.
......@@ -243,15 +257,7 @@ class PerDestinationQueue:
}
)
def flush_read_receipts_for_room(self, room_id: str) -> None:
# If there are any pending receipts for this room then force-flush them
# in a new transaction.
for edu in self._pending_receipt_edus:
if room_id in edu:
self._rrs_pending_flush = True
self.attempt_new_transaction()
# No use in checking remaining EDUs if the room was found.
break
self.mark_new_data()
def send_keyed_edu(self, edu: Edu, key: Hashable) -> None:
self._pending_edus_keyed[(edu.edu_type, key)] = edu
......@@ -323,12 +329,11 @@ class PerDestinationQueue:
# not caught up yet
return
pending_pdus = []
while True:
self._new_data_to_send = False
async with _TransactionQueueManager(self) as (
pending_pdus,
pending_pdus, # noqa: F811
pending_edus,
):
if not pending_pdus and not pending_edus:
......@@ -370,7 +375,7 @@ class PerDestinationQueue:
),
)
if e.retry_interval > 60 * 60 * 1000:
if e.retry_interval > CATCHUP_RETRY_INTERVAL:
# we won't retry for another hour!
# (this suggests a significant outage)
# We drop pending EDUs because otherwise they will
......@@ -389,7 +394,7 @@ class PerDestinationQueue:
# through another mechanism, because this is all volatile!
self._pending_edus = []
self._pending_edus_keyed = {}
self._pending_presence = {}
self._pending_presence.clear()
self._pending_receipt_edus = []
self._start_catching_up()
......@@ -589,12 +594,9 @@ class PerDestinationQueue:
self._destination, last_successful_stream_ordering
)
def _get_receipt_edus(self, force_flush: bool, limit: int) -> Iterable[Edu]:
def _get_receipt_edus(self, limit: int) -> Iterable[Edu]:
if not self._pending_receipt_edus:
return
if not force_flush and not self._rrs_pending_flush:
# not yet time for this lot
return
# Send at most limit EDUs for receipts.
for content in self._pending_receipt_edus[:limit]:
......@@ -711,25 +713,29 @@ class _TransactionQueueManager:
# Add presence EDU.
if self.queue._pending_presence:
# Only send max 50 presence entries in the EDU, to bound the amount
# of data we're sending.
presence_to_add: List[JsonDict] = []
while (
self.queue._pending_presence
and len(presence_to_add) < MAX_PRESENCE_STATES_PER_EDU
):
_, presence = self.queue._pending_presence.popitem(last=False)
presence_to_add.append(
format_user_presence_state(presence, self.queue._clock.time_msec())
)
pending_edus.append(
Edu(
origin=self.queue._server_name,
destination=self.queue._destination,
edu_type=EduTypes.PRESENCE,
content={
"push": [
format_user_presence_state(
presence, self.queue._clock.time_msec()
)
for presence in self.queue._pending_presence.values()
]
},
content={"push": presence_to_add},
)
)
self.queue._pending_presence = {}
# Add read receipt EDUs.
pending_edus.extend(self.queue._get_receipt_edus(force_flush=False, limit=5))
pending_edus.extend(self.queue._get_receipt_edus(limit=5))
edu_limit = MAX_EDUS_PER_TRANSACTION - len(pending_edus)
# Next, prioritize to-device messages so that existing encryption channels
......@@ -777,13 +783,6 @@ class _TransactionQueueManager:
if not self._pdus and not pending_edus:
return [], []
# if we've decided to send a transaction anyway, and we have room, we
# may as well send any pending RRs
if edu_limit:
pending_edus.extend(
self.queue._get_receipt_edus(force_flush=True, limit=edu_limit)
)
if self._pdus:
self._last_stream_ordering = self._pdus[
-1
......
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, List
......
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The transport layer is responsible for both sending transactions to remote
homeservers and receiving a variety of requests from other homeservers.
......
# Copyright 2014-2022 The Matrix.org Foundation C.I.C.
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2020 Sorunome
# Copyright 2014-2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import urllib
from typing import (
TYPE_CHECKING,
Any,
BinaryIO,
Callable,
Collection,
Dict,
......@@ -35,6 +43,7 @@ import ijson
from synapse.api.constants import Direction, Membership
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import RoomVersion
from synapse.api.urls import (
FEDERATION_UNSTABLE_PREFIX,
......@@ -45,7 +54,7 @@ from synapse.events import EventBase, make_event_from_dict
from synapse.federation.units import Transaction
from synapse.http.matrixfederationclient import ByteParser, LegacyJsonSendParser
from synapse.http.types import QueryParams
from synapse.types import JsonDict
from synapse.types import JsonDict, UserID
from synapse.util import ExceptionBundle
if TYPE_CHECKING:
......@@ -58,9 +67,8 @@ class TransportLayerClient:
"""Sends federation HTTP requests to other servers"""
def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname
self.client = hs.get_federation_http_client()
self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled
self._is_mine_server_name = hs.is_mine_server_name
async def get_room_state_ids(
self, destination: str, room_id: str, event_id: str
......@@ -235,7 +243,7 @@ class TransportLayerClient:
transaction.transaction_id,
)
if transaction.destination == self.server_name:
if self._is_mine_server_name(transaction.destination):
raise RuntimeError("Transport layer cannot send to itself!")
# FIXME: This is only used by the tests. The actual json sent is
......@@ -250,8 +258,10 @@ class TransportLayerClient:
data=json_data,
json_data_callback=json_data_callback,
long_retries=True,
backoff_on_404=True, # If we get a 404 the other side has gone
try_trailing_slash_on_400=True,
# Sending a transaction should always succeed, if it doesn't
# then something is wrong and we should backoff.
backoff_on_all_error_codes=True,
)
async def make_query(
......@@ -363,12 +373,8 @@ class TransportLayerClient:
) -> "SendJoinResponse":
path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
query_params: Dict[str, str] = {}
if self._faster_joins_enabled:
# lazy-load state on join
query_params["org.matrix.msc3706.partial_state"] = (
"true" if omit_members else "false"
)
query_params["omit_members"] = "true" if omit_members else "false"
# lazy-load state on join
query_params["omit_members"] = "true" if omit_members else "false"
return await self.client.put_json(
destination=destination,
......@@ -434,7 +440,7 @@ class TransportLayerClient:
The remote homeserver can optionally return some state from the room. The response
dictionary is in the form:
{"knock_state_events": [<state event dict>, ...]}
{"knock_room_state": [<state event dict>, ...]}
The list of state events may be empty.
"""
......@@ -480,13 +486,11 @@ class TransportLayerClient:
See synapse.federation.federation_client.FederationClient.get_public_rooms for
more information.
"""
path = _create_v1_path("/publicRooms")
if search_filter:
# this uses MSC2197 (Search Filtering over Federation)
path = _create_v1_path("/publicRooms")
data: Dict[str, Any] = {
"include_all_networks": "true" if include_all_networks else "false"
}
data: Dict[str, Any] = {"include_all_networks": include_all_networks}
if third_party_instance_id:
data["third_party_instance_id"] = third_party_instance_id
if limit:
......@@ -510,17 +514,15 @@ class TransportLayerClient:
)
raise
else:
path = _create_v1_path("/publicRooms")
args: Dict[str, Union[str, Iterable[str]]] = {
"include_all_networks": "true" if include_all_networks else "false"
}
if third_party_instance_id:
args["third_party_instance_id"] = (third_party_instance_id,)
args["third_party_instance_id"] = third_party_instance_id
if limit:
args["limit"] = [str(limit)]
args["limit"] = str(limit)
if since_token:
args["since"] = [since_token]
args["since"] = since_token
try:
response = await self.client.get_json(
......@@ -635,7 +637,11 @@ class TransportLayerClient:
)
async def claim_client_keys(
self, destination: str, query_content: JsonDict, timeout: Optional[int]
self,
user: UserID,
destination: str,
query_content: JsonDict,
timeout: Optional[int],
) -> JsonDict:
"""Claim one-time keys for a list of devices hosted on a remote server.
......@@ -650,16 +656,17 @@ class TransportLayerClient:
Response:
{
"device_keys": {
"one_time_keys": {
"<user_id>": {
"<device_id>": {
"<algorithm>:<key_id>": "<key_base64>"
"<algorithm>:<key_id>": <OTK JSON>
}
}
}
}
Args:
user: the user_id of the requesting user
destination: The server to query.
query_content: The user ids to query.
Returns:
......@@ -669,7 +676,55 @@ class TransportLayerClient:
path = _create_v1_path("/user/keys/claim")
return await self.client.post_json(
destination=destination, path=path, data=query_content, timeout=timeout
destination=destination,
path=path,
data={"one_time_keys": query_content},
timeout=timeout,
)
async def claim_client_keys_unstable(
self,
user: UserID,
destination: str,
query_content: JsonDict,
timeout: Optional[int],
) -> JsonDict:
"""Claim one-time keys for a list of devices hosted on a remote server.
Request:
{
"one_time_keys": {
"<user_id>": {
"<device_id>": {"<algorithm>": <count>}
}
}
}
Response:
{
"one_time_keys": {
"<user_id>": {
"<device_id>": {
"<algorithm>:<key_id>": <OTK JSON>
}
}
}
}
Args:
user: the user_id of the requesting user
destination: The server to query.
query_content: The user ids to query.
Returns:
A dict containing the one-time keys.
"""
path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/user/keys/claim")
return await self.client.post_json(
destination=destination,
path=path,
data={"one_time_keys": query_content},
timeout=timeout,
)
async def get_missing_events(
......@@ -758,6 +813,87 @@ class TransportLayerClient:
destination=destination, path=path, data={"user_ids": user_ids}
)
async def download_media_r0(
self,
destination: str,
media_id: str,
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
path = f"/_matrix/media/r0/download/{destination}/{media_id}"
return await self.client.get_file(
destination,
path,
output_stream=output_stream,
max_size=max_size,
args={
# tell the remote server to 404 if it doesn't
# recognise the server_name, to make sure we don't
# end up with a routing loop.
"allow_remote": "false",
"timeout_ms": str(max_timeout_ms),
},
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
async def download_media_v3(
self,
destination: str,
media_id: str,
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
path = f"/_matrix/media/v3/download/{destination}/{media_id}"
return await self.client.get_file(
destination,
path,
output_stream=output_stream,
max_size=max_size,
args={
# tell the remote server to 404 if it doesn't
# recognise the server_name, to make sure we don't
# end up with a routing loop.
"allow_remote": "false",
"timeout_ms": str(max_timeout_ms),
# Matrix 1.7 allows for this to redirect to another URL, this should
# just be ignored for an old homeserver, so always provide it.
"allow_redirect": "true",
},
follow_redirects=True,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
async def federation_download_media(
self,
destination: str,
media_id: str,
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
path = f"/_matrix/federation/v1/media/download/{media_id}"
return await self.client.federation_get_file(
destination,
path,
output_stream=output_stream,
max_size=max_size,
args={
"timeout_ms": str(max_timeout_ms),
},
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
def _create_path(federation_prefix: str, path: str, *args: str) -> str:
"""
......@@ -859,9 +995,7 @@ def _members_omitted_parser(response: SendJoinResponse) -> Generator[None, Any,
while True:
val = yield
if not isinstance(val, bool):
raise TypeError(
"members_omitted (formerly org.matrix.msc370c.partial_state) must be a boolean"
)
raise TypeError("members_omitted must be a boolean")
response.members_omitted = val
......@@ -921,14 +1055,6 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
]
if not v1_api:
self._coros.append(
ijson.items_coro(
_members_omitted_parser(self._response),
"org.matrix.msc3706.partial_state",
use_float="True",
)
)
# The stable field name comes last, so it "wins" if the fields disagree
self._coros.append(
ijson.items_coro(
_members_omitted_parser(self._response),
......@@ -937,14 +1063,6 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
)
)
self._coros.append(
ijson.items_coro(
_servers_in_room_parser(self._response),
"org.matrix.msc3706.servers_in_room",
use_float="True",
)
)
# Again, stable field name comes last
self._coros.append(
ijson.items_coro(
......
# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2020 Sorunome
# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Type
from typing_extensions import Literal
from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Tuple, Type
from synapse.api.errors import FederationDeniedError, SynapseError
from synapse.federation.transport.server._base import (
......@@ -25,6 +30,8 @@ from synapse.federation.transport.server._base import (
from synapse.federation.transport.server.federation import (
FEDERATION_SERVLET_CLASSES,
FederationAccountStatusServlet,
FederationMediaDownloadServlet,
FederationMediaThumbnailServlet,
FederationUnstableClientKeysClaimServlet,
)
from synapse.http.server import HttpServer, JsonResource
......@@ -149,7 +156,10 @@ class PublicRoomList(BaseFederationServlet):
limit = None
data = await self.handler.get_local_public_room_list(
limit, since_token, network_tuple=network_tuple, from_federation=True
limit,
since_token,
network_tuple=network_tuple,
from_federation_origin=origin,
)
return 200, data
......@@ -190,7 +200,7 @@ class PublicRoomList(BaseFederationServlet):
since_token=since_token,
search_filter=search_filter,
network_tuple=network_tuple,
from_federation=True,
from_federation_origin=origin,
)
return 200, data
......@@ -259,6 +269,10 @@ SERVLET_GROUPS: Dict[str, Iterable[Type[BaseFederationServlet]]] = {
"federation": FEDERATION_SERVLET_CLASSES,
"room_list": (PublicRoomList,),
"openid": (OpenIdUserInfo,),
"media": (
FederationMediaDownloadServlet,
FederationMediaThumbnailServlet,
),
}
......@@ -305,6 +319,13 @@ def register_servlets(
):
continue
if (
servletclass == FederationMediaDownloadServlet
or servletclass == FederationMediaThumbnailServlet
):
if not hs.config.media.can_load_media_repo:
continue
servletclass(
hs=hs,
authenticator=authenticator,
......
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import logging
......@@ -57,6 +64,7 @@ class Authenticator:
self._clock = hs.get_clock()
self.keyring = hs.get_keyring()
self.server_name = hs.hostname
self._is_mine_server_name = hs.is_mine_server_name
self.store = hs.get_datastores().main
self.federation_domain_whitelist = (
hs.config.federation.federation_domain_whitelist
......@@ -100,10 +108,12 @@ class Authenticator:
json_request["signatures"].setdefault(origin, {})[key] = sig
# if the origin_server sent a destination along it needs to match our own server_name
if destination is not None and destination != self.server_name:
if destination is not None and not self._is_mine_server_name(
destination
):
raise AuthenticationError(
HTTPStatus.UNAUTHORIZED,
"Destination mismatch in auth header",
f"Destination mismatch in auth header, received: {destination!r}",
Codes.UNAUTHORIZED,
)
if (
......@@ -170,7 +180,11 @@ def _parse_auth_header(header_bytes: bytes) -> Tuple[str, str, str, Optional[str
"""
try:
header_str = header_bytes.decode("utf-8")
params = re.split(" +", header_str)[1].split(",")
space_or_tab = "[ \t]"
params = re.split(
rf"{space_or_tab}*,{space_or_tab}*",
re.split(r"^X-Matrix +", header_str, maxsplit=1)[1],
)
param_dict: Dict[str, str] = {
k.lower(): v for k, v in [param.split("=", maxsplit=1) for param in params]
}
......@@ -346,13 +360,33 @@ class BaseFederationServlet:
"request"
)
return None
if (
func.__self__.__class__.__name__ # type: ignore
== "FederationMediaDownloadServlet"
or func.__self__.__class__.__name__ # type: ignore
== "FederationMediaThumbnailServlet"
):
response = await func(
origin, content, request, *args, **kwargs
)
else:
response = await func(
origin, content, request.args, *args, **kwargs
)
else:
if (
func.__self__.__class__.__name__ # type: ignore
== "FederationMediaDownloadServlet"
or func.__self__.__class__.__name__ # type: ignore
== "FederationMediaThumbnailServlet"
):
response = await func(
origin, content, request, *args, **kwargs
)
else:
response = await func(
origin, content, request.args, *args, **kwargs
)
else:
response = await func(
origin, content, request.args, *args, **kwargs
)
finally:
# if we used the origin's context as the parent, add a new span using
# the servlet span as a parent, so that we have a link
......
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections import Counter
from typing import (
TYPE_CHECKING,
Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
......@@ -24,8 +33,6 @@ from typing import (
Union,
)
from typing_extensions import Literal
from synapse.api.constants import Direction, EduTypes
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersions
......@@ -36,10 +43,15 @@ from synapse.federation.transport.server._base import (
)
from synapse.http.servlet import (
parse_boolean_from_args,
parse_integer,
parse_integer_from_args,
parse_string,
parse_string_from_args,
parse_strings_from_args,
)
from synapse.http.site import SynapseRequest
from synapse.media._base import DEFAULT_MAX_TIMEOUT_MS, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS
from synapse.media.thumbnailer import ThumbnailProvider
from synapse.types import JsonDict
from synapse.util import SYNAPSE_VERSION
from synapse.util.ratelimitutils import FederationRateLimiter
......@@ -431,16 +443,6 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet):
PREFIX = FEDERATION_V2_PREFIX
def __init__(
self,
hs: "HomeServer",
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
server_name: str,
):
super().__init__(hs, authenticator, ratelimiter, server_name)
self._read_msc3706_query_param = hs.config.experimental.msc3706_enabled
async def on_PUT(
self,
origin: str,
......@@ -452,16 +454,7 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet):
# TODO(paul): assert that event_id parsed from path actually
# match those given in content
partial_state = False
# The stable query parameter wins, if it disagrees with the unstable
# parameter for some reason.
stable_param = parse_boolean_from_args(query, "omit_members", default=None)
if stable_param is not None:
partial_state = stable_param
elif self._read_msc3706_query_param:
partial_state = parse_boolean_from_args(
query, "org.matrix.msc3706.partial_state", default=False
)
partial_state = parse_boolean_from_args(query, "omit_members", default=False)
result = await self.handler.on_send_join_request(
origin, content, room_id, caller_supports_partial_state=partial_state
......@@ -515,6 +508,9 @@ class FederationV2InviteServlet(BaseFederationServerServlet):
event = content["event"]
invite_room_state = content.get("invite_room_state", [])
if not isinstance(invite_room_state, list):
invite_room_state = []
# Synapse expects invite_room_state to be in unsigned, as it is in v1
# API
......@@ -577,16 +573,23 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
# Generate a count for each algorithm, which is hard-coded to 1.
key_query: List[Tuple[str, str, str, int]] = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items():
key_query.append((user_id, device_id, algorithm, 1))
response = await self.handler.on_claim_client_keys(
origin, content, always_include_fallback_keys=False
key_query, always_include_fallback_keys=False
)
return 200, response
class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
"""
Identical to the stable endpoint (FederationClientKeysClaimServlet) except it
always includes fallback keys in the response.
Identical to the stable endpoint (FederationClientKeysClaimServlet) except
it allows for querying for multiple OTKs at once and always includes fallback
keys in the response.
"""
PREFIX = FEDERATION_UNSTABLE_PREFIX
......@@ -596,8 +599,16 @@ class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
# Generate a count for each algorithm.
key_query: List[Tuple[str, str, str, int]] = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithms in device_keys.items():
counts = Counter(algorithms)
for algorithm, count in counts.items():
key_query.append((user_id, device_id, algorithm, count))
response = await self.handler.on_claim_client_keys(
origin, content, always_include_fallback_keys=True
key_query, always_include_fallback_keys=True
)
return 200, response
......@@ -783,6 +794,94 @@ class FederationAccountStatusServlet(BaseFederationServerServlet):
return 200, {"account_statuses": statuses, "failures": failures}
class FederationMediaDownloadServlet(BaseFederationServerServlet):
"""
Implementation of new federation media `/download` endpoint outlined in MSC3916. Returns
a multipart/mixed response consisting of a JSON object and the requested media
item. This endpoint only returns local media.
"""
PATH = "/media/download/(?P<media_id>[^/]*)"
RATELIMIT = True
def __init__(
self,
hs: "HomeServer",
ratelimiter: FederationRateLimiter,
authenticator: Authenticator,
server_name: str,
):
super().__init__(hs, authenticator, ratelimiter, server_name)
self.media_repo = self.hs.get_media_repository()
async def on_GET(
self,
origin: Optional[str],
content: Literal[None],
request: SynapseRequest,
media_id: str,
) -> None:
max_timeout_ms = parse_integer(
request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
)
max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
await self.media_repo.get_local_media(
request, media_id, None, max_timeout_ms, federation=True
)
class FederationMediaThumbnailServlet(BaseFederationServerServlet):
"""
Implementation of new federation media `/thumbnail` endpoint outlined in MSC3916. Returns
a multipart/mixed response consisting of a JSON object and the requested media
item. This endpoint only returns local media.
"""
PATH = "/media/thumbnail/(?P<media_id>[^/]*)"
RATELIMIT = True
def __init__(
self,
hs: "HomeServer",
ratelimiter: FederationRateLimiter,
authenticator: Authenticator,
server_name: str,
):
super().__init__(hs, authenticator, ratelimiter, server_name)
self.media_repo = self.hs.get_media_repository()
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
self.thumbnail_provider = ThumbnailProvider(
hs, self.media_repo, self.media_repo.media_storage
)
async def on_GET(
self,
origin: Optional[str],
content: Literal[None],
request: SynapseRequest,
media_id: str,
) -> None:
width = parse_integer(request, "width", required=True)
height = parse_integer(request, "height", required=True)
method = parse_string(request, "method", "scale")
# TODO Parse the Accept header to get an prioritised list of thumbnail types.
m_type = "image/png"
max_timeout_ms = parse_integer(
request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
)
max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
if self.dynamic_thumbnails:
await self.thumbnail_provider.select_or_generate_local_thumbnail(
request, media_id, width, height, method, m_type, max_timeout_ms, True
)
else:
await self.thumbnail_provider.respond_local_thumbnail(
request, media_id, width, height, method, m_type, max_timeout_ms, True
)
self.media_repo.mark_recently_accessed(None, media_id)
FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationSendServlet,
FederationEventServlet,
......@@ -805,6 +904,7 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationClientKeysQueryServlet,
FederationUserDevicesQueryServlet,
FederationClientKeysClaimServlet,
FederationUnstableClientKeysClaimServlet,
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,
FederationVersionServlet,
......
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Defines the JSON structure of the protocol units used by the server to
"""Defines the JSON structure of the protocol units used by the server to
server protocol.
"""
......
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict, List, Tuple
......@@ -102,7 +109,7 @@ class AccountHandler:
"""
status = {"exists": False}
userinfo = await self._main_store.get_userinfo_by_id(user_id.to_string())
userinfo = await self._main_store.get_user_by_id(user_id.to_string())
if userinfo is not None:
status = {
......
# Copyright 2015, 2016 OpenMarket Ltd
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
......@@ -26,7 +33,7 @@ from synapse.replication.http.account_data import (
ReplicationRemoveUserAccountDataRestServlet,
)
from synapse.streams import EventSource
from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID
from synapse.types import JsonDict, JsonMapping, StrCollection, StreamKeyType, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
......@@ -246,7 +253,7 @@ class AccountDataHandler:
return response["max_stream_id"]
async def add_tag_to_room(
self, user_id: str, room_id: str, tag: str, content: JsonDict
self, user_id: str, room_id: str, tag: str, content: JsonMapping
) -> int:
"""Add a tag to a room for a user.
......
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import email.mime.multipart
import email.utils
......@@ -98,6 +104,22 @@ class AccountValidityHandler:
for callback in self._module_api_callbacks.on_user_registration_callbacks:
await callback(user_id)
async def on_user_login(
self,
user_id: str,
auth_provider_type: Optional[str],
auth_provider_id: Optional[str],
) -> None:
"""Tell third-party modules about a user logins.
Args:
user_id: The mxID of the user.
auth_provider_type: The type of login.
auth_provider_id: The ID of the auth provider.
"""
for callback in self._module_api_callbacks.on_user_login_callbacks:
await callback(user_id, auth_provider_type, auth_provider_id)
@wrap_as_background_process("send_renewals")
async def _send_renewal_emails(self) -> None:
"""Gets the list of users whose account is expiring in the amount of time
......@@ -164,7 +186,7 @@ class AccountValidityHandler:
try:
user_display_name = await self.store.get_profile_displayname(
UserID.from_string(user_id).localpart
UserID.from_string(user_id)
)
if user_display_name is None:
user_display_name = user_id
......@@ -212,8 +234,8 @@ class AccountValidityHandler:
addresses = []
for threepid in threepids:
if threepid["medium"] == "email":
addresses.append(threepid["address"])
if threepid.medium == "email":
addresses.append(threepid.address)
return addresses
......
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set
from synapse.api.constants import Direction, Membership
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
)
import attr
from synapse.api.constants import Direction, EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
from synapse.types import (
JsonMapping,
Requester,
RoomStreamToken,
ScheduledTask,
StateMap,
TaskStatus,
UserID,
UserInfo,
create_requester,
)
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
......@@ -26,6 +56,8 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
REDACT_ALL_EVENTS_ACTION_NAME = "redact_all_events"
class AdminHandler:
def __init__(self, hs: "HomeServer"):
......@@ -34,8 +66,24 @@ class AdminHandler:
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
self._msc3866_enabled = hs.config.experimental.msc3866.enabled
self.event_creation_handler = hs.get_event_creation_handler()
self._task_scheduler = hs.get_task_scheduler()
async def get_whois(self, user: UserID) -> JsonDict:
self._task_scheduler.register_action(
self._redact_all_events, REDACT_ALL_EVENTS_ACTION_NAME
)
self.hs = hs
async def get_redact_task(self, redact_id: str) -> Optional[ScheduledTask]:
"""Get the current status of an active redaction process
Args:
redact_id: redact_id returned by start_redact_events.
"""
return await self._task_scheduler.get_task(redact_id)
async def get_whois(self, user: UserID) -> JsonMapping:
connections = []
sessions = await self._store.get_user_ip_and_agents(user)
......@@ -55,41 +103,36 @@ class AdminHandler:
return ret
async def get_user(self, user: UserID) -> Optional[JsonDict]:
async def get_user(self, user: UserID) -> Optional[JsonMapping]:
"""Function to get user details"""
user_info_dict = await self._store.get_user_by_id(user.to_string())
if user_info_dict is None:
user_info: Optional[UserInfo] = await self._store.get_user_by_id(
user.to_string()
)
if user_info is None:
return None
# Restrict returned information to a known set of fields. This prevents additional
# fields added to get_user_by_id from modifying Synapse's external API surface.
user_info_to_return = {
"name",
"admin",
"deactivated",
"shadow_banned",
"creation_ts",
"appservice_id",
"consent_server_notice_sent",
"consent_version",
"consent_ts",
"user_type",
"is_guest",
user_info_dict = {
"name": user.to_string(),
"admin": user_info.is_admin,
"deactivated": user_info.is_deactivated,
"locked": user_info.locked,
"shadow_banned": user_info.is_shadow_banned,
"creation_ts": user_info.creation_ts,
"appservice_id": user_info.appservice_id,
"consent_server_notice_sent": user_info.consent_server_notice_sent,
"consent_version": user_info.consent_version,
"consent_ts": user_info.consent_ts,
"user_type": user_info.user_type,
"is_guest": user_info.is_guest,
"suspended": user_info.suspended,
}
if self._msc3866_enabled:
# Only include the approved flag if support for MSC3866 is enabled.
user_info_to_return.add("approved")
# Restrict returned keys to a known set.
user_info_dict = {
key: value
for key, value in user_info_dict.items()
if key in user_info_to_return
}
user_info_dict["approved"] = user_info.approved
# Add additional user metadata
profile = await self._store.get_profileinfo(user.localpart)
profile = await self._store.get_profileinfo(user)
threepids = await self._store.user_get_threepids(user.to_string())
external_ids = [
({"auth_provider": auth_provider, "external_id": external_id})
......@@ -99,10 +142,13 @@ class AdminHandler:
]
user_info_dict["displayname"] = profile.display_name
user_info_dict["avatar_url"] = profile.avatar_url
user_info_dict["threepids"] = threepids
user_info_dict["threepids"] = [attr.asdict(t) for t in threepids]
user_info_dict["external_ids"] = external_ids
user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())
last_seen_ts = await self._store.get_last_seen_for_user_id(user.to_string())
user_info_dict["last_seen_ts"] = last_seen_ts
return user_info_dict
async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:
......@@ -119,13 +165,7 @@ class AdminHandler:
# Get all rooms the user is in or has been in
rooms = await self._store.get_rooms_for_local_user_where_membership_is(
user_id,
membership_list=(
Membership.JOIN,
Membership.LEAVE,
Membership.BAN,
Membership.INVITE,
Membership.KNOCK,
),
membership_list=Membership.LIST,
)
# We only try and fetch events for rooms the user has been in. If
......@@ -172,10 +212,10 @@ class AdminHandler:
if room.membership == Membership.JOIN:
stream_ordering = self._store.get_room_max_stream_ordering()
else:
stream_ordering = room.stream_ordering
stream_ordering = room.event_pos.stream
from_key = RoomStreamToken(0, 0)
to_key = RoomStreamToken(None, stream_ordering)
from_key = RoomStreamToken(topological=0, stream=0)
to_key = RoomStreamToken(stream=stream_ordering)
# Events that we've processed in this room
written_events: Set[str] = set()
......@@ -197,16 +237,31 @@ class AdminHandler:
# events that we have and then filtering, this isn't the most
# efficient method perhaps but it does guarantee we get everything.
while True:
events, _ = await self._store.paginate_room_events(
room_id, from_key, to_key, limit=100, direction=Direction.FORWARDS
(
events,
_,
_,
) = await self._store.paginate_room_events_by_topological_ordering(
room_id=room_id,
from_key=from_key,
to_key=to_key,
limit=100,
direction=Direction.FORWARDS,
)
if not events:
break
from_key = events[-1].internal_metadata.after
last_event = events[-1]
assert last_event.internal_metadata.stream_ordering
from_key = RoomStreamToken(
stream=last_event.internal_metadata.stream_ordering,
topological=last_event.depth,
)
events = await filter_events_for_client(
self._storage_controllers, user_id, events
self._storage_controllers,
user_id,
events,
)
writer.write_events(room_id, events)
......@@ -284,7 +339,7 @@ class AdminHandler:
start, limit, user_id
)
for media in media_ids:
writer.write_media_id(media["media_id"], media)
writer.write_media_id(media.media_id, attr.asdict(media))
logger.info(
"[%s] Written %d media_ids of %s",
......@@ -298,6 +353,155 @@ class AdminHandler:
return writer.finished()
async def start_redact_events(
self,
user_id: str,
rooms: list,
requester: JsonMapping,
reason: Optional[str],
limit: Optional[int],
) -> str:
"""
Start a task redacting the events of the given user in the given rooms
Args:
user_id: the user ID of the user whose events should be redacted
rooms: the rooms in which to redact the user's events
requester: the user requesting the events
reason: reason for requesting the redaction, ie spam, etc
limit: limit on the number of events in each room to redact
Returns:
a unique ID which can be used to query the status of the task
"""
active_tasks = await self._task_scheduler.get_tasks(
actions=[REDACT_ALL_EVENTS_ACTION_NAME],
resource_id=user_id,
statuses=[TaskStatus.ACTIVE],
)
if len(active_tasks) > 0:
raise SynapseError(
400, "Redact already in progress for user %s" % (user_id,)
)
if not limit:
limit = 1000
redact_id = await self._task_scheduler.schedule_task(
REDACT_ALL_EVENTS_ACTION_NAME,
resource_id=user_id,
params={
"rooms": rooms,
"requester": requester,
"user_id": user_id,
"reason": reason,
"limit": limit,
},
)
logger.info(
"starting redact events with redact_id %s",
redact_id,
)
return redact_id
async def _redact_all_events(
self, task: ScheduledTask
) -> Tuple[TaskStatus, Optional[Mapping[str, Any]], Optional[str]]:
"""
Task to redact all of a users events in the given rooms, tracking which, if any, events
whose redaction failed
"""
assert task.params is not None
rooms = task.params.get("rooms")
assert rooms is not None
r = task.params.get("requester")
assert r is not None
admin = Requester.deserialize(self._store, r)
user_id = task.params.get("user_id")
assert user_id is not None
# puppet the user if they're ours, otherwise use admin to redact
requester = create_requester(
user_id if self.hs.is_mine_id(user_id) else admin.user.to_string(),
authenticated_entity=admin.user.to_string(),
)
reason = task.params.get("reason")
limit = task.params.get("limit")
assert limit is not None
result: Mapping[str, Any] = (
task.result if task.result else {"failed_redactions": {}}
)
for room in rooms:
room_version = await self._store.get_room_version(room)
event_ids = await self._store.get_events_sent_by_user_in_room(
user_id,
room,
limit,
["m.room.member", "m.room.message"],
)
if not event_ids:
# nothing to redact in this room
continue
events = await self._store.get_events_as_list(event_ids)
for event in events:
# we care about join events but not other membership events
if event.type == "m.room.member":
content = event.content
if content:
if content.get("membership") == Membership.JOIN:
pass
else:
continue
relations = await self._store.get_relations_for_event(
room, event.event_id, event, event_type=EventTypes.Redaction
)
# if we've already successfully redacted this event then skip processing it
if relations[0]:
continue
event_dict = {
"type": EventTypes.Redaction,
"content": {"reason": reason} if reason else {},
"room_id": room,
"sender": requester.user.to_string(),
}
if room_version.updated_redaction_rules:
event_dict["content"]["redacts"] = event.event_id
else:
event_dict["redacts"] = event.event_id
try:
# set the prev event to the offending message to allow for redactions
# to be processed in the case where the user has been kicked/banned before
# redactions are requested
(
redaction,
_,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
event_dict,
prev_event_ids=[event.event_id],
ratelimit=False,
)
except Exception as ex:
logger.info(
f"Redaction of event {event.event_id} failed due to: {ex}"
)
result["failed_redactions"][event.event_id] = str(ex)
await self._task_scheduler.update_task(task.id, result=result)
return TaskStatus.COMPLETE, result, None
class ExfiltrationWriter(metaclass=abc.ABCMeta):
"""Interface used to specify how to write exported data."""
......@@ -347,7 +551,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
def write_profile(self, profile: JsonDict) -> None:
def write_profile(self, profile: JsonMapping) -> None:
"""Write the profile of a user.
Args:
......@@ -356,7 +560,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
def write_devices(self, devices: List[JsonDict]) -> None:
def write_devices(self, devices: Sequence[JsonMapping]) -> None:
"""Write the devices of a user.
Args:
......@@ -365,7 +569,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
def write_connections(self, connections: List[JsonDict]) -> None:
def write_connections(self, connections: Sequence[JsonMapping]) -> None:
"""Write the connections of a user.
Args:
......@@ -375,7 +579,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
@abc.abstractmethod
def write_account_data(
self, file_name: str, account_data: Mapping[str, JsonDict]
self, file_name: str, account_data: Mapping[str, JsonMapping]
) -> None:
"""Write the account data of a user.
......@@ -386,7 +590,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None:
def write_media_id(self, media_id: str, media_metadata: JsonMapping) -> None:
"""Write the media's metadata of a user.
Exports only the metadata, as this can be fetched from the database via
read only. In order to access the files, a connection to the correct
......
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import (
TYPE_CHECKING,
......@@ -46,6 +53,8 @@ from synapse.storage.databases.main.directory import RoomAliasMapping
from synapse.types import (
DeviceListUpdates,
JsonDict,
JsonMapping,
MultiWriterStreamToken,
RoomAlias,
RoomStreamToken,
StreamKeyType,
......@@ -215,8 +224,8 @@ class ApplicationServicesHandler:
def notify_interested_services_ephemeral(
self,
stream_key: str,
new_token: Union[int, RoomStreamToken],
stream_key: StreamKeyType,
new_token: Union[int, RoomStreamToken, MultiWriterStreamToken],
users: Collection[Union[str, UserID]],
) -> None:
"""
......@@ -258,19 +267,6 @@ class ApplicationServicesHandler:
):
return
# Assert that new_token is an integer (and not a RoomStreamToken).
# All of the supported streams that this function handles use an
# integer to track progress (rather than a RoomStreamToken - a
# vector clock implementation) as they don't support multiple
# stream writers.
#
# As a result, we simply assert that new_token is an integer.
# If we do end up needing to pass a RoomStreamToken down here
# in the future, using RoomStreamToken.stream (the minimum stream
# position) to convert to an ascending integer value should work.
# Additional context: https://github.com/matrix-org/synapse/pull/11137
assert isinstance(new_token, int)
# Ignore to-device messages if the feature flag is not enabled
if (
stream_key == StreamKeyType.TO_DEVICE
......@@ -285,6 +281,9 @@ class ApplicationServicesHandler:
):
return
# We know we're not a `RoomStreamToken` at this point.
assert not isinstance(new_token, RoomStreamToken)
# Check whether there are any appservices which have registered to receive
# ephemeral events.
#
......@@ -325,8 +324,8 @@ class ApplicationServicesHandler:
async def _notify_interested_services_ephemeral(
self,
services: List[ApplicationService],
stream_key: str,
new_token: int,
stream_key: StreamKeyType,
new_token: Union[int, MultiWriterStreamToken],
users: Collection[Union[str, UserID]],
) -> None:
logger.debug("Checking interested services for %s", stream_key)
......@@ -339,6 +338,7 @@ class ApplicationServicesHandler:
#
# Instead we simply grab the latest typing updates in _handle_typing
# and, if they apply to this application service, send it off.
assert isinstance(new_token, int)
events = await self._handle_typing(service, new_token)
if events:
self.scheduler.enqueue_for_appservice(service, ephemeral=events)
......@@ -349,15 +349,23 @@ class ApplicationServicesHandler:
(service.id, stream_key)
):
if stream_key == StreamKeyType.RECEIPT:
assert isinstance(new_token, MultiWriterStreamToken)
# We store appservice tokens as integers, so we ignore
# the `instance_map` components and instead simply
# follow the base stream position.
new_token = MultiWriterStreamToken(stream=new_token.stream)
events = await self._handle_receipts(service, new_token)
self.scheduler.enqueue_for_appservice(service, ephemeral=events)
# Persist the latest handled stream token for this appservice
await self.store.set_appservice_stream_type_pos(
service, "read_receipt", new_token
service, "read_receipt", new_token.stream
)
elif stream_key == StreamKeyType.PRESENCE:
assert isinstance(new_token, int)
events = await self._handle_presence(service, users, new_token)
self.scheduler.enqueue_for_appservice(service, ephemeral=events)
......@@ -367,6 +375,7 @@ class ApplicationServicesHandler:
)
elif stream_key == StreamKeyType.TO_DEVICE:
assert isinstance(new_token, int)
# Retrieve a list of to-device message events, as well as the
# maximum stream token of the messages we were able to retrieve.
to_device_messages = await self._get_to_device_messages(
......@@ -382,6 +391,7 @@ class ApplicationServicesHandler:
)
elif stream_key == StreamKeyType.DEVICE_LIST:
assert isinstance(new_token, int)
device_list_summary = await self._get_device_list_summary(
service, new_token
)
......@@ -397,7 +407,7 @@ class ApplicationServicesHandler:
async def _handle_typing(
self, service: ApplicationService, new_token: int
) -> List[JsonDict]:
) -> List[JsonMapping]:
"""
Return the typing events since the given stream token that the given application
service should receive.
......@@ -431,8 +441,8 @@ class ApplicationServicesHandler:
return typing
async def _handle_receipts(
self, service: ApplicationService, new_token: int
) -> List[JsonDict]:
self, service: ApplicationService, new_token: MultiWriterStreamToken
) -> List[JsonMapping]:
"""
Return the latest read receipts that the given application service should receive.
......@@ -454,15 +464,17 @@ class ApplicationServicesHandler:
from_key = await self.store.get_type_stream_id_for_appservice(
service, "read_receipt"
)
if new_token is not None and new_token <= from_key:
if new_token is not None and new_token.stream <= from_key:
logger.debug(
"Rejecting token lower than or equal to stored: %s" % (new_token,)
)
return []
from_token = MultiWriterStreamToken(stream=from_key)
receipts_source = self.event_sources.sources.receipt
receipts, _ = await receipts_source.get_new_events_as(
service=service, from_key=from_key, to_key=new_token
service=service, from_key=from_token, to_key=new_token
)
return receipts
......@@ -471,7 +483,7 @@ class ApplicationServicesHandler:
service: ApplicationService,
users: Collection[Union[str, UserID]],
new_token: Optional[int],
) -> List[JsonDict]:
) -> List[JsonMapping]:
"""
Return the latest presence updates that the given application service should receive.
......@@ -491,7 +503,7 @@ class ApplicationServicesHandler:
A list of json dictionaries containing data derived from the presence events
that should be sent to the given application service.
"""
events: List[JsonDict] = []
events: List[JsonMapping] = []
presence_source = self.event_sources.sources.presence
from_key = await self.store.get_type_stream_id_for_appservice(
service, "presence"
......@@ -841,8 +853,10 @@ class ApplicationServicesHandler:
return True
async def claim_e2e_one_time_keys(
self, query: Iterable[Tuple[str, str, str]]
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
self, query: Iterable[Tuple[str, str, str, int]]
) -> Tuple[
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
]:
"""Claim one time keys from application services.
Users which are exclusively owned by an application service are sent a
......@@ -863,18 +877,18 @@ class ApplicationServicesHandler:
services = self.store.get_app_services()
# Partition the users by appservice.
query_by_appservice: Dict[str, List[Tuple[str, str, str]]] = {}
query_by_appservice: Dict[str, List[Tuple[str, str, str, int]]] = {}
missing = []
for user_id, device, algorithm in query:
for user_id, device, algorithm, count in query:
if not self.store.get_if_app_services_interested_in_user(user_id):
missing.append((user_id, device, algorithm))
missing.append((user_id, device, algorithm, count))
continue
# Find the associated appservice.
for service in services:
if service.is_exclusive_user(user_id):
query_by_appservice.setdefault(service.id, []).append(
(user_id, device, algorithm)
(user_id, device, algorithm, count)
)
continue
......@@ -882,10 +896,10 @@ class ApplicationServicesHandler:
results = await make_deferred_yieldable(
defer.DeferredList(
[
run_in_background(
run_in_background( # type: ignore[call-overload]
self.appservice_api.claim_client_keys,
# We know this must be an app service.
self.store.get_app_service_by_id(service_id), # type: ignore[arg-type]
self.store.get_app_service_by_id(service_id),
service_query,
)
for service_id, service_query in query_by_appservice.items()
......@@ -938,10 +952,10 @@ class ApplicationServicesHandler:
results = await make_deferred_yieldable(
defer.DeferredList(
[
run_in_background(
run_in_background( # type: ignore[call-overload]
self.appservice_api.query_keys,
# We know this must be an app service.
self.store.get_app_service_by_id(service_id), # type: ignore[arg-type]
self.store.get_app_service_by_id(service_id),
service_query,
)
for service_id, service_query in query_by_appservice.items()
......
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2019 - 2020 The Matrix.org Foundation C.I.C.
# Copyright 2017 Vector Creations Ltd
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
import unicodedata
......@@ -52,7 +59,6 @@ from synapse.api.errors import (
NotFoundError,
StoreError,
SynapseError,
UserDeactivatedError,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.handlers.ui_auth import (
......@@ -160,8 +166,7 @@ def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]:
if "country" not in identifier or (
# The specification requires a "phone" field, while Synapse used to require a "number"
# field. Accept both for backwards compatibility.
"phone" not in identifier
and "number" not in identifier
"phone" not in identifier and "number" not in identifier
):
raise SynapseError(
400, "Invalid phone-type identifier", errcode=Codes.INVALID_PARAM
......@@ -212,26 +217,25 @@ class AuthHandler:
self._password_enabled_for_login = hs.config.auth.password_enabled_for_login
self._password_enabled_for_reauth = hs.config.auth.password_enabled_for_reauth
self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
self._third_party_rules = hs.get_third_party_event_rules()
self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules
self._account_validity_handler = hs.get_account_validity_handler()
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
self._failed_uia_attempts_ratelimiter = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=self.hs.config.ratelimiting.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.ratelimiting.rc_login_failed_attempts.burst_count,
cfg=self.hs.config.ratelimiting.rc_login_failed_attempts,
)
# The number of seconds to keep a UI auth session active.
self._ui_auth_session_timeout = hs.config.auth.ui_auth_session_timeout
# Ratelimitier for failed /login attempts
# Ratelimiter for failed /login attempts
self._failed_login_attempts_ratelimiter = Ratelimiter(
store=self.store,
clock=hs.get_clock(),
rate_hz=self.hs.config.ratelimiting.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.ratelimiting.rc_login_failed_attempts.burst_count,
cfg=self.hs.config.ratelimiting.rc_login_failed_attempts,
)
self._clock = self.hs.get_clock()
......@@ -275,6 +279,8 @@ class AuthHandler:
# response.
self._extra_attributes: Dict[str, SsoLoginExtraAttributes] = {}
self.msc3861_oauth_delegation_enabled = hs.config.experimental.msc3861.enabled
async def validate_user_via_ui_auth(
self,
requester: Requester,
......@@ -323,8 +329,12 @@ class AuthHandler:
LimitExceededError if the ratelimiter's failed request count for this
user is too high to proceed
"""
if self.msc3861_oauth_delegation_enabled:
raise SynapseError(
HTTPStatus.INTERNAL_SERVER_ERROR, "UIA shouldn't be used with MSC3861"
)
if not requester.access_token_id:
raise ValueError("Cannot validate a user without an access token")
if can_skip_ui_auth and self._ui_auth_session_timeout:
......@@ -1419,12 +1429,6 @@ class AuthHandler:
return None
(user_id, password_hash) = lookupres
# If the password hash is None, the account has likely been deactivated
if not password_hash:
deactivated = await self.store.get_user_deactivated_status(user_id)
if deactivated:
raise UserDeactivatedError("This account has been deactivated")
result = await self.validate_hash(password, password_hash)
if not result:
logger.warning("Failed password login for user %s", user_id)
......@@ -1575,7 +1579,10 @@ class AuthHandler:
# for the presence of an email address during password reset was
# case sensitive).
if medium == "email":
address = canonicalise_email(address)
try:
address = canonicalise_email(address)
except ValueError as e:
raise SynapseError(400, str(e))
await self.store.user_add_threepid(
user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
......@@ -1606,7 +1613,10 @@ class AuthHandler:
"""
# 'Canonicalise' email addresses as per above
if medium == "email":
address = canonicalise_email(address)
try:
address = canonicalise_email(address)
except ValueError as e:
raise SynapseError(400, str(e))
await self.store.user_delete_threepid(user_id, medium, address)
......@@ -1749,15 +1759,18 @@ class AuthHandler:
registered.
auth_provider_session_id: The session ID from the SSO IdP received during login.
"""
# If the account has been deactivated, do not proceed with the login
# flow.
# If the account has been deactivated, do not proceed with the login.
#
# This gets checked again when the token is submitted but this lets us
# provide an HTML error page to the user (instead of issuing a token and
# having it error later).
deactivated = await self.store.get_user_deactivated_status(registered_user_id)
if deactivated:
respond_with_html(request, 403, self._sso_account_deactivated_template)
return
user_profile_data = await self.store.get_profileinfo(
UserID.from_string(registered_user_id).localpart
UserID.from_string(registered_user_id)
)
# Store any extra attributes which will be passed in the login response.
......@@ -1783,6 +1796,13 @@ class AuthHandler:
client_redirect_url, "loginToken", login_token
)
# Run post-login module callback handlers
await self._account_validity_handler.on_user_login(
user_id=registered_user_id,
auth_provider_type=LoginType.SSO,
auth_provider_id=auth_provider_id,
)
# if the client is whitelisted, we can redirect straight to it
if client_redirect_url.startswith(self._whitelisted_sso_clients):
request.redirect(redirect_url)
......@@ -2170,7 +2190,7 @@ class PasswordAuthProvider:
# result is always the right type, but as it is 3rd party code it might not be
if not isinstance(result, tuple) or len(result) != 2:
logger.warning(
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
......@@ -2233,7 +2253,7 @@ class PasswordAuthProvider:
# result is always the right type, but as it is 3rd party code it might not be
if not isinstance(result, tuple) or len(result) != 2:
logger.warning(
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
......
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import urllib.parse
from typing import TYPE_CHECKING, Dict, List, Optional
......@@ -67,8 +74,12 @@ class CasHandler:
self._cas_server_url = hs.config.cas.cas_server_url
self._cas_service_url = hs.config.cas.cas_service_url
self._cas_protocol_version = hs.config.cas.cas_protocol_version
self._cas_displayname_attribute = hs.config.cas.cas_displayname_attribute
self._cas_required_attributes = hs.config.cas.cas_required_attributes
self._cas_enable_registration = hs.config.cas.cas_enable_registration
self._cas_allow_numeric_ids = hs.config.cas.cas_allow_numeric_ids
self._cas_numeric_ids_prefix = hs.config.cas.cas_numeric_ids_prefix
self._http_client = hs.get_proxied_http_client()
......@@ -76,12 +87,13 @@ class CasHandler:
self.idp_id = "cas"
# user-facing name of this auth provider
self.idp_name = "CAS"
self.idp_name = hs.config.cas.idp_name
# MXC URI for icon for this auth provider
self.idp_icon = hs.config.cas.idp_icon
# we do not currently support brands/icons for CAS auth, but this is required by
# the SsoIdentityProvider protocol type.
self.idp_icon = None
self.idp_brand = None
# optional brand identifier for this auth provider
self.idp_brand = hs.config.cas.idp_brand
self._sso_handler = hs.get_sso_handler()
......@@ -120,7 +132,10 @@ class CasHandler:
Returns:
The parsed CAS response.
"""
uri = self._cas_server_url + "/proxyValidate"
if self._cas_protocol_version == 3:
uri = self._cas_server_url + "/p3/proxyValidate"
else:
uri = self._cas_server_url + "/proxyValidate"
args = {
"ticket": ticket,
"service": self._build_service_param(service_args),
......@@ -175,6 +190,9 @@ class CasHandler:
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
# if numeric user IDs are allowed and username is numeric then we add the prefix so Synapse can handle it
if self._cas_allow_numeric_ids and user is not None and user.isdigit():
user = f"{self._cas_numeric_ids_prefix}{user}"
if child.tag.endswith("attributes"):
for attribute in child:
# ElementTree library expands the namespace in
......@@ -390,4 +408,5 @@ class CasHandler:
client_redirect_url,
cas_response_to_user_attributes,
grandfather_existing_users,
registration_enabled=self._cas_enable_registration,
)