Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • maunium/synapse
  • leytilera/synapse
2 results
Show changes
Showing
with 2492 additions and 985 deletions
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict, List, Tuple
......@@ -102,7 +109,7 @@ class AccountHandler:
"""
status = {"exists": False}
userinfo = await self._main_store.get_userinfo_by_id(user_id.to_string())
userinfo = await self._main_store.get_user_by_id(user_id.to_string())
if userinfo is not None:
status = {
......
# Copyright 2015, 2016 OpenMarket Ltd
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
......@@ -26,7 +33,7 @@ from synapse.replication.http.account_data import (
ReplicationRemoveUserAccountDataRestServlet,
)
from synapse.streams import EventSource
from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID
from synapse.types import JsonDict, JsonMapping, StrCollection, StreamKeyType, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
......@@ -246,7 +253,7 @@ class AccountDataHandler:
return response["max_stream_id"]
async def add_tag_to_room(
self, user_id: str, room_id: str, tag: str, content: JsonDict
self, user_id: str, room_id: str, tag: str, content: JsonMapping
) -> int:
"""Add a tag to a room for a user.
......
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import email.mime.multipart
import email.utils
......@@ -98,6 +104,22 @@ class AccountValidityHandler:
for callback in self._module_api_callbacks.on_user_registration_callbacks:
await callback(user_id)
async def on_user_login(
self,
user_id: str,
auth_provider_type: Optional[str],
auth_provider_id: Optional[str],
) -> None:
"""Tell third-party modules about a user logins.
Args:
user_id: The mxID of the user.
auth_provider_type: The type of login.
auth_provider_id: The ID of the auth provider.
"""
for callback in self._module_api_callbacks.on_user_login_callbacks:
await callback(user_id, auth_provider_type, auth_provider_id)
@wrap_as_background_process("send_renewals")
async def _send_renewal_emails(self) -> None:
"""Gets the list of users whose account is expiring in the amount of time
......@@ -164,7 +186,7 @@ class AccountValidityHandler:
try:
user_display_name = await self.store.get_profile_displayname(
UserID.from_string(user_id).localpart
UserID.from_string(user_id)
)
if user_display_name is None:
user_display_name = user_id
......@@ -212,8 +234,8 @@ class AccountValidityHandler:
addresses = []
for threepid in threepids:
if threepid["medium"] == "email":
addresses.append(threepid["address"])
if threepid.medium == "email":
addresses.append(threepid.address)
return addresses
......
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set
from synapse.api.constants import Direction, Membership
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
)
import attr
from synapse.api.constants import Direction, EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
from synapse.types import (
JsonMapping,
Requester,
RoomStreamToken,
ScheduledTask,
StateMap,
TaskStatus,
UserID,
UserInfo,
create_requester,
)
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
......@@ -26,6 +56,8 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
REDACT_ALL_EVENTS_ACTION_NAME = "redact_all_events"
class AdminHandler:
def __init__(self, hs: "HomeServer"):
......@@ -34,8 +66,24 @@ class AdminHandler:
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
self._msc3866_enabled = hs.config.experimental.msc3866.enabled
self.event_creation_handler = hs.get_event_creation_handler()
self._task_scheduler = hs.get_task_scheduler()
async def get_whois(self, user: UserID) -> JsonDict:
self._task_scheduler.register_action(
self._redact_all_events, REDACT_ALL_EVENTS_ACTION_NAME
)
self.hs = hs
async def get_redact_task(self, redact_id: str) -> Optional[ScheduledTask]:
"""Get the current status of an active redaction process
Args:
redact_id: redact_id returned by start_redact_events.
"""
return await self._task_scheduler.get_task(redact_id)
async def get_whois(self, user: UserID) -> JsonMapping:
connections = []
sessions = await self._store.get_user_ip_and_agents(user)
......@@ -55,41 +103,36 @@ class AdminHandler:
return ret
async def get_user(self, user: UserID) -> Optional[JsonDict]:
async def get_user(self, user: UserID) -> Optional[JsonMapping]:
"""Function to get user details"""
user_info_dict = await self._store.get_user_by_id(user.to_string())
if user_info_dict is None:
user_info: Optional[UserInfo] = await self._store.get_user_by_id(
user.to_string()
)
if user_info is None:
return None
# Restrict returned information to a known set of fields. This prevents additional
# fields added to get_user_by_id from modifying Synapse's external API surface.
user_info_to_return = {
"name",
"admin",
"deactivated",
"shadow_banned",
"creation_ts",
"appservice_id",
"consent_server_notice_sent",
"consent_version",
"consent_ts",
"user_type",
"is_guest",
user_info_dict = {
"name": user.to_string(),
"admin": user_info.is_admin,
"deactivated": user_info.is_deactivated,
"locked": user_info.locked,
"shadow_banned": user_info.is_shadow_banned,
"creation_ts": user_info.creation_ts,
"appservice_id": user_info.appservice_id,
"consent_server_notice_sent": user_info.consent_server_notice_sent,
"consent_version": user_info.consent_version,
"consent_ts": user_info.consent_ts,
"user_type": user_info.user_type,
"is_guest": user_info.is_guest,
"suspended": user_info.suspended,
}
if self._msc3866_enabled:
# Only include the approved flag if support for MSC3866 is enabled.
user_info_to_return.add("approved")
# Restrict returned keys to a known set.
user_info_dict = {
key: value
for key, value in user_info_dict.items()
if key in user_info_to_return
}
user_info_dict["approved"] = user_info.approved
# Add additional user metadata
profile = await self._store.get_profileinfo(user.localpart)
profile = await self._store.get_profileinfo(user)
threepids = await self._store.user_get_threepids(user.to_string())
external_ids = [
({"auth_provider": auth_provider, "external_id": external_id})
......@@ -99,10 +142,13 @@ class AdminHandler:
]
user_info_dict["displayname"] = profile.display_name
user_info_dict["avatar_url"] = profile.avatar_url
user_info_dict["threepids"] = threepids
user_info_dict["threepids"] = [attr.asdict(t) for t in threepids]
user_info_dict["external_ids"] = external_ids
user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())
last_seen_ts = await self._store.get_last_seen_for_user_id(user.to_string())
user_info_dict["last_seen_ts"] = last_seen_ts
return user_info_dict
async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:
......@@ -119,13 +165,7 @@ class AdminHandler:
# Get all rooms the user is in or has been in
rooms = await self._store.get_rooms_for_local_user_where_membership_is(
user_id,
membership_list=(
Membership.JOIN,
Membership.LEAVE,
Membership.BAN,
Membership.INVITE,
Membership.KNOCK,
),
membership_list=Membership.LIST,
)
# We only try and fetch events for rooms the user has been in. If
......@@ -172,10 +212,10 @@ class AdminHandler:
if room.membership == Membership.JOIN:
stream_ordering = self._store.get_room_max_stream_ordering()
else:
stream_ordering = room.stream_ordering
stream_ordering = room.event_pos.stream
from_key = RoomStreamToken(0, 0)
to_key = RoomStreamToken(None, stream_ordering)
from_key = RoomStreamToken(topological=0, stream=0)
to_key = RoomStreamToken(stream=stream_ordering)
# Events that we've processed in this room
written_events: Set[str] = set()
......@@ -197,16 +237,31 @@ class AdminHandler:
# events that we have and then filtering, this isn't the most
# efficient method perhaps but it does guarantee we get everything.
while True:
events, _ = await self._store.paginate_room_events(
room_id, from_key, to_key, limit=100, direction=Direction.FORWARDS
(
events,
_,
_,
) = await self._store.paginate_room_events_by_topological_ordering(
room_id=room_id,
from_key=from_key,
to_key=to_key,
limit=100,
direction=Direction.FORWARDS,
)
if not events:
break
from_key = events[-1].internal_metadata.after
last_event = events[-1]
assert last_event.internal_metadata.stream_ordering
from_key = RoomStreamToken(
stream=last_event.internal_metadata.stream_ordering,
topological=last_event.depth,
)
events = await filter_events_for_client(
self._storage_controllers, user_id, events
self._storage_controllers,
user_id,
events,
)
writer.write_events(room_id, events)
......@@ -284,7 +339,7 @@ class AdminHandler:
start, limit, user_id
)
for media in media_ids:
writer.write_media_id(media["media_id"], media)
writer.write_media_id(media.media_id, attr.asdict(media))
logger.info(
"[%s] Written %d media_ids of %s",
......@@ -298,6 +353,155 @@ class AdminHandler:
return writer.finished()
async def start_redact_events(
self,
user_id: str,
rooms: list,
requester: JsonMapping,
reason: Optional[str],
limit: Optional[int],
) -> str:
"""
Start a task redacting the events of the given user in the given rooms
Args:
user_id: the user ID of the user whose events should be redacted
rooms: the rooms in which to redact the user's events
requester: the user requesting the events
reason: reason for requesting the redaction, ie spam, etc
limit: limit on the number of events in each room to redact
Returns:
a unique ID which can be used to query the status of the task
"""
active_tasks = await self._task_scheduler.get_tasks(
actions=[REDACT_ALL_EVENTS_ACTION_NAME],
resource_id=user_id,
statuses=[TaskStatus.ACTIVE],
)
if len(active_tasks) > 0:
raise SynapseError(
400, "Redact already in progress for user %s" % (user_id,)
)
if not limit:
limit = 1000
redact_id = await self._task_scheduler.schedule_task(
REDACT_ALL_EVENTS_ACTION_NAME,
resource_id=user_id,
params={
"rooms": rooms,
"requester": requester,
"user_id": user_id,
"reason": reason,
"limit": limit,
},
)
logger.info(
"starting redact events with redact_id %s",
redact_id,
)
return redact_id
async def _redact_all_events(
self, task: ScheduledTask
) -> Tuple[TaskStatus, Optional[Mapping[str, Any]], Optional[str]]:
"""
Task to redact all of a users events in the given rooms, tracking which, if any, events
whose redaction failed
"""
assert task.params is not None
rooms = task.params.get("rooms")
assert rooms is not None
r = task.params.get("requester")
assert r is not None
admin = Requester.deserialize(self._store, r)
user_id = task.params.get("user_id")
assert user_id is not None
# puppet the user if they're ours, otherwise use admin to redact
requester = create_requester(
user_id if self.hs.is_mine_id(user_id) else admin.user.to_string(),
authenticated_entity=admin.user.to_string(),
)
reason = task.params.get("reason")
limit = task.params.get("limit")
assert limit is not None
result: Mapping[str, Any] = (
task.result if task.result else {"failed_redactions": {}}
)
for room in rooms:
room_version = await self._store.get_room_version(room)
event_ids = await self._store.get_events_sent_by_user_in_room(
user_id,
room,
limit,
["m.room.member", "m.room.message"],
)
if not event_ids:
# nothing to redact in this room
continue
events = await self._store.get_events_as_list(event_ids)
for event in events:
# we care about join events but not other membership events
if event.type == "m.room.member":
content = event.content
if content:
if content.get("membership") == Membership.JOIN:
pass
else:
continue
relations = await self._store.get_relations_for_event(
room, event.event_id, event, event_type=EventTypes.Redaction
)
# if we've already successfully redacted this event then skip processing it
if relations[0]:
continue
event_dict = {
"type": EventTypes.Redaction,
"content": {"reason": reason} if reason else {},
"room_id": room,
"sender": requester.user.to_string(),
}
if room_version.updated_redaction_rules:
event_dict["content"]["redacts"] = event.event_id
else:
event_dict["redacts"] = event.event_id
try:
# set the prev event to the offending message to allow for redactions
# to be processed in the case where the user has been kicked/banned before
# redactions are requested
(
redaction,
_,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
event_dict,
prev_event_ids=[event.event_id],
ratelimit=False,
)
except Exception as ex:
logger.info(
f"Redaction of event {event.event_id} failed due to: {ex}"
)
result["failed_redactions"][event.event_id] = str(ex)
await self._task_scheduler.update_task(task.id, result=result)
return TaskStatus.COMPLETE, result, None
class ExfiltrationWriter(metaclass=abc.ABCMeta):
"""Interface used to specify how to write exported data."""
......@@ -347,7 +551,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
def write_profile(self, profile: JsonDict) -> None:
def write_profile(self, profile: JsonMapping) -> None:
"""Write the profile of a user.
Args:
......@@ -356,7 +560,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
def write_devices(self, devices: List[JsonDict]) -> None:
def write_devices(self, devices: Sequence[JsonMapping]) -> None:
"""Write the devices of a user.
Args:
......@@ -365,7 +569,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
def write_connections(self, connections: List[JsonDict]) -> None:
def write_connections(self, connections: Sequence[JsonMapping]) -> None:
"""Write the connections of a user.
Args:
......@@ -375,7 +579,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
@abc.abstractmethod
def write_account_data(
self, file_name: str, account_data: Mapping[str, JsonDict]
self, file_name: str, account_data: Mapping[str, JsonMapping]
) -> None:
"""Write the account data of a user.
......@@ -386,7 +590,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None:
def write_media_id(self, media_id: str, media_metadata: JsonMapping) -> None:
"""Write the media's metadata of a user.
Exports only the metadata, as this can be fetched from the database via
read only. In order to access the files, a connection to the correct
......
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import (
TYPE_CHECKING,
......@@ -46,6 +53,8 @@ from synapse.storage.databases.main.directory import RoomAliasMapping
from synapse.types import (
DeviceListUpdates,
JsonDict,
JsonMapping,
MultiWriterStreamToken,
RoomAlias,
RoomStreamToken,
StreamKeyType,
......@@ -215,8 +224,8 @@ class ApplicationServicesHandler:
def notify_interested_services_ephemeral(
self,
stream_key: str,
new_token: Union[int, RoomStreamToken],
stream_key: StreamKeyType,
new_token: Union[int, RoomStreamToken, MultiWriterStreamToken],
users: Collection[Union[str, UserID]],
) -> None:
"""
......@@ -258,19 +267,6 @@ class ApplicationServicesHandler:
):
return
# Assert that new_token is an integer (and not a RoomStreamToken).
# All of the supported streams that this function handles use an
# integer to track progress (rather than a RoomStreamToken - a
# vector clock implementation) as they don't support multiple
# stream writers.
#
# As a result, we simply assert that new_token is an integer.
# If we do end up needing to pass a RoomStreamToken down here
# in the future, using RoomStreamToken.stream (the minimum stream
# position) to convert to an ascending integer value should work.
# Additional context: https://github.com/matrix-org/synapse/pull/11137
assert isinstance(new_token, int)
# Ignore to-device messages if the feature flag is not enabled
if (
stream_key == StreamKeyType.TO_DEVICE
......@@ -285,6 +281,9 @@ class ApplicationServicesHandler:
):
return
# We know we're not a `RoomStreamToken` at this point.
assert not isinstance(new_token, RoomStreamToken)
# Check whether there are any appservices which have registered to receive
# ephemeral events.
#
......@@ -325,8 +324,8 @@ class ApplicationServicesHandler:
async def _notify_interested_services_ephemeral(
self,
services: List[ApplicationService],
stream_key: str,
new_token: int,
stream_key: StreamKeyType,
new_token: Union[int, MultiWriterStreamToken],
users: Collection[Union[str, UserID]],
) -> None:
logger.debug("Checking interested services for %s", stream_key)
......@@ -339,6 +338,7 @@ class ApplicationServicesHandler:
#
# Instead we simply grab the latest typing updates in _handle_typing
# and, if they apply to this application service, send it off.
assert isinstance(new_token, int)
events = await self._handle_typing(service, new_token)
if events:
self.scheduler.enqueue_for_appservice(service, ephemeral=events)
......@@ -349,15 +349,23 @@ class ApplicationServicesHandler:
(service.id, stream_key)
):
if stream_key == StreamKeyType.RECEIPT:
assert isinstance(new_token, MultiWriterStreamToken)
# We store appservice tokens as integers, so we ignore
# the `instance_map` components and instead simply
# follow the base stream position.
new_token = MultiWriterStreamToken(stream=new_token.stream)
events = await self._handle_receipts(service, new_token)
self.scheduler.enqueue_for_appservice(service, ephemeral=events)
# Persist the latest handled stream token for this appservice
await self.store.set_appservice_stream_type_pos(
service, "read_receipt", new_token
service, "read_receipt", new_token.stream
)
elif stream_key == StreamKeyType.PRESENCE:
assert isinstance(new_token, int)
events = await self._handle_presence(service, users, new_token)
self.scheduler.enqueue_for_appservice(service, ephemeral=events)
......@@ -367,6 +375,7 @@ class ApplicationServicesHandler:
)
elif stream_key == StreamKeyType.TO_DEVICE:
assert isinstance(new_token, int)
# Retrieve a list of to-device message events, as well as the
# maximum stream token of the messages we were able to retrieve.
to_device_messages = await self._get_to_device_messages(
......@@ -382,6 +391,7 @@ class ApplicationServicesHandler:
)
elif stream_key == StreamKeyType.DEVICE_LIST:
assert isinstance(new_token, int)
device_list_summary = await self._get_device_list_summary(
service, new_token
)
......@@ -397,7 +407,7 @@ class ApplicationServicesHandler:
async def _handle_typing(
self, service: ApplicationService, new_token: int
) -> List[JsonDict]:
) -> List[JsonMapping]:
"""
Return the typing events since the given stream token that the given application
service should receive.
......@@ -431,8 +441,8 @@ class ApplicationServicesHandler:
return typing
async def _handle_receipts(
self, service: ApplicationService, new_token: int
) -> List[JsonDict]:
self, service: ApplicationService, new_token: MultiWriterStreamToken
) -> List[JsonMapping]:
"""
Return the latest read receipts that the given application service should receive.
......@@ -454,15 +464,17 @@ class ApplicationServicesHandler:
from_key = await self.store.get_type_stream_id_for_appservice(
service, "read_receipt"
)
if new_token is not None and new_token <= from_key:
if new_token is not None and new_token.stream <= from_key:
logger.debug(
"Rejecting token lower than or equal to stored: %s" % (new_token,)
)
return []
from_token = MultiWriterStreamToken(stream=from_key)
receipts_source = self.event_sources.sources.receipt
receipts, _ = await receipts_source.get_new_events_as(
service=service, from_key=from_key, to_key=new_token
service=service, from_key=from_token, to_key=new_token
)
return receipts
......@@ -471,7 +483,7 @@ class ApplicationServicesHandler:
service: ApplicationService,
users: Collection[Union[str, UserID]],
new_token: Optional[int],
) -> List[JsonDict]:
) -> List[JsonMapping]:
"""
Return the latest presence updates that the given application service should receive.
......@@ -491,7 +503,7 @@ class ApplicationServicesHandler:
A list of json dictionaries containing data derived from the presence events
that should be sent to the given application service.
"""
events: List[JsonDict] = []
events: List[JsonMapping] = []
presence_source = self.event_sources.sources.presence
from_key = await self.store.get_type_stream_id_for_appservice(
service, "presence"
......@@ -841,9 +853,9 @@ class ApplicationServicesHandler:
return True
async def claim_e2e_one_time_keys(
self, query: Iterable[Tuple[str, str, str]]
self, query: Iterable[Tuple[str, str, str, int]]
) -> Tuple[
Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]]
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
]:
"""Claim one time keys from application services.
......@@ -856,7 +868,7 @@ class ApplicationServicesHandler:
Returns:
A tuple of:
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
A map of user ID -> a map device ID -> a map of key ID -> JSON.
A copy of the input which has not been fulfilled (either because
they are not appservice users or the appservice does not support
......@@ -865,18 +877,18 @@ class ApplicationServicesHandler:
services = self.store.get_app_services()
# Partition the users by appservice.
query_by_appservice: Dict[str, List[Tuple[str, str, str]]] = {}
query_by_appservice: Dict[str, List[Tuple[str, str, str, int]]] = {}
missing = []
for user_id, device, algorithm in query:
for user_id, device, algorithm, count in query:
if not self.store.get_if_app_services_interested_in_user(user_id):
missing.append((user_id, device, algorithm))
missing.append((user_id, device, algorithm, count))
continue
# Find the associated appservice.
for service in services:
if service.is_exclusive_user(user_id):
query_by_appservice.setdefault(service.id, []).append(
(user_id, device, algorithm)
(user_id, device, algorithm, count)
)
continue
......@@ -884,10 +896,10 @@ class ApplicationServicesHandler:
results = await make_deferred_yieldable(
defer.DeferredList(
[
run_in_background(
run_in_background( # type: ignore[call-overload]
self.appservice_api.claim_client_keys,
# We know this must be an app service.
self.store.get_app_service_by_id(service_id), # type: ignore[arg-type]
self.store.get_app_service_by_id(service_id),
service_query,
)
for service_id, service_query in query_by_appservice.items()
......@@ -897,12 +909,11 @@ class ApplicationServicesHandler:
)
# Patch together the results -- they are all independent (since they
# require exclusive control over the users). They get returned as a list
# and the caller combines them.
claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = []
# require exclusive control over the users, which is the outermost key).
claimed_keys: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for success, result in results:
if success:
claimed_keys.append(result[0])
claimed_keys.update(result[0])
missing.extend(result[1])
return claimed_keys, missing
......@@ -941,10 +952,10 @@ class ApplicationServicesHandler:
results = await make_deferred_yieldable(
defer.DeferredList(
[
run_in_background(
run_in_background( # type: ignore[call-overload]
self.appservice_api.query_keys,
# We know this must be an app service.
self.store.get_app_service_by_id(service_id), # type: ignore[arg-type]
self.store.get_app_service_by_id(service_id),
service_query,
)
for service_id, service_query in query_by_appservice.items()
......
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2019 - 2020 The Matrix.org Foundation C.I.C.
# Copyright 2017 Vector Creations Ltd
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
import unicodedata
......@@ -52,7 +59,6 @@ from synapse.api.errors import (
NotFoundError,
StoreError,
SynapseError,
UserDeactivatedError,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.handlers.ui_auth import (
......@@ -160,8 +166,7 @@ def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]:
if "country" not in identifier or (
# The specification requires a "phone" field, while Synapse used to require a "number"
# field. Accept both for backwards compatibility.
"phone" not in identifier
and "number" not in identifier
"phone" not in identifier and "number" not in identifier
):
raise SynapseError(
400, "Invalid phone-type identifier", errcode=Codes.INVALID_PARAM
......@@ -212,26 +217,25 @@ class AuthHandler:
self._password_enabled_for_login = hs.config.auth.password_enabled_for_login
self._password_enabled_for_reauth = hs.config.auth.password_enabled_for_reauth
self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
self._third_party_rules = hs.get_third_party_event_rules()
self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules
self._account_validity_handler = hs.get_account_validity_handler()
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
self._failed_uia_attempts_ratelimiter = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=self.hs.config.ratelimiting.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.ratelimiting.rc_login_failed_attempts.burst_count,
cfg=self.hs.config.ratelimiting.rc_login_failed_attempts,
)
# The number of seconds to keep a UI auth session active.
self._ui_auth_session_timeout = hs.config.auth.ui_auth_session_timeout
# Ratelimitier for failed /login attempts
# Ratelimiter for failed /login attempts
self._failed_login_attempts_ratelimiter = Ratelimiter(
store=self.store,
clock=hs.get_clock(),
rate_hz=self.hs.config.ratelimiting.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.ratelimiting.rc_login_failed_attempts.burst_count,
cfg=self.hs.config.ratelimiting.rc_login_failed_attempts,
)
self._clock = self.hs.get_clock()
......@@ -275,6 +279,8 @@ class AuthHandler:
# response.
self._extra_attributes: Dict[str, SsoLoginExtraAttributes] = {}
self.msc3861_oauth_delegation_enabled = hs.config.experimental.msc3861.enabled
async def validate_user_via_ui_auth(
self,
requester: Requester,
......@@ -323,8 +329,12 @@ class AuthHandler:
LimitExceededError if the ratelimiter's failed request count for this
user is too high to proceed
"""
if self.msc3861_oauth_delegation_enabled:
raise SynapseError(
HTTPStatus.INTERNAL_SERVER_ERROR, "UIA shouldn't be used with MSC3861"
)
if not requester.access_token_id:
raise ValueError("Cannot validate a user without an access token")
if can_skip_ui_auth and self._ui_auth_session_timeout:
......@@ -1419,12 +1429,6 @@ class AuthHandler:
return None
(user_id, password_hash) = lookupres
# If the password hash is None, the account has likely been deactivated
if not password_hash:
deactivated = await self.store.get_user_deactivated_status(user_id)
if deactivated:
raise UserDeactivatedError("This account has been deactivated")
result = await self.validate_hash(password, password_hash)
if not result:
logger.warning("Failed password login for user %s", user_id)
......@@ -1575,7 +1579,10 @@ class AuthHandler:
# for the presence of an email address during password reset was
# case sensitive).
if medium == "email":
address = canonicalise_email(address)
try:
address = canonicalise_email(address)
except ValueError as e:
raise SynapseError(400, str(e))
await self.store.user_add_threepid(
user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
......@@ -1606,7 +1613,10 @@ class AuthHandler:
"""
# 'Canonicalise' email addresses as per above
if medium == "email":
address = canonicalise_email(address)
try:
address = canonicalise_email(address)
except ValueError as e:
raise SynapseError(400, str(e))
await self.store.user_delete_threepid(user_id, medium, address)
......@@ -1749,15 +1759,18 @@ class AuthHandler:
registered.
auth_provider_session_id: The session ID from the SSO IdP received during login.
"""
# If the account has been deactivated, do not proceed with the login
# flow.
# If the account has been deactivated, do not proceed with the login.
#
# This gets checked again when the token is submitted but this lets us
# provide an HTML error page to the user (instead of issuing a token and
# having it error later).
deactivated = await self.store.get_user_deactivated_status(registered_user_id)
if deactivated:
respond_with_html(request, 403, self._sso_account_deactivated_template)
return
user_profile_data = await self.store.get_profileinfo(
UserID.from_string(registered_user_id).localpart
UserID.from_string(registered_user_id)
)
# Store any extra attributes which will be passed in the login response.
......@@ -1783,6 +1796,13 @@ class AuthHandler:
client_redirect_url, "loginToken", login_token
)
# Run post-login module callback handlers
await self._account_validity_handler.on_user_login(
user_id=registered_user_id,
auth_provider_type=LoginType.SSO,
auth_provider_id=auth_provider_id,
)
# if the client is whitelisted, we can redirect straight to it
if client_redirect_url.startswith(self._whitelisted_sso_clients):
request.redirect(redirect_url)
......@@ -2170,7 +2190,7 @@ class PasswordAuthProvider:
# result is always the right type, but as it is 3rd party code it might not be
if not isinstance(result, tuple) or len(result) != 2:
logger.warning(
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
......@@ -2233,7 +2253,7 @@ class PasswordAuthProvider:
# result is always the right type, but as it is 3rd party code it might not be
if not isinstance(result, tuple) or len(result) != 2:
logger.warning(
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
callback,
......
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import urllib.parse
from typing import TYPE_CHECKING, Dict, List, Optional
......@@ -67,8 +74,12 @@ class CasHandler:
self._cas_server_url = hs.config.cas.cas_server_url
self._cas_service_url = hs.config.cas.cas_service_url
self._cas_protocol_version = hs.config.cas.cas_protocol_version
self._cas_displayname_attribute = hs.config.cas.cas_displayname_attribute
self._cas_required_attributes = hs.config.cas.cas_required_attributes
self._cas_enable_registration = hs.config.cas.cas_enable_registration
self._cas_allow_numeric_ids = hs.config.cas.cas_allow_numeric_ids
self._cas_numeric_ids_prefix = hs.config.cas.cas_numeric_ids_prefix
self._http_client = hs.get_proxied_http_client()
......@@ -76,12 +87,13 @@ class CasHandler:
self.idp_id = "cas"
# user-facing name of this auth provider
self.idp_name = "CAS"
self.idp_name = hs.config.cas.idp_name
# MXC URI for icon for this auth provider
self.idp_icon = hs.config.cas.idp_icon
# we do not currently support brands/icons for CAS auth, but this is required by
# the SsoIdentityProvider protocol type.
self.idp_icon = None
self.idp_brand = None
# optional brand identifier for this auth provider
self.idp_brand = hs.config.cas.idp_brand
self._sso_handler = hs.get_sso_handler()
......@@ -120,7 +132,10 @@ class CasHandler:
Returns:
The parsed CAS response.
"""
uri = self._cas_server_url + "/proxyValidate"
if self._cas_protocol_version == 3:
uri = self._cas_server_url + "/p3/proxyValidate"
else:
uri = self._cas_server_url + "/proxyValidate"
args = {
"ticket": ticket,
"service": self._build_service_param(service_args),
......@@ -175,6 +190,9 @@ class CasHandler:
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
# if numeric user IDs are allowed and username is numeric then we add the prefix so Synapse can handle it
if self._cas_allow_numeric_ids and user is not None and user.isdigit():
user = f"{self._cas_numeric_ids_prefix}{user}"
if child.tag.endswith("attributes"):
for attribute in child:
# ElementTree library expands the namespace in
......@@ -390,4 +408,5 @@ class CasHandler:
client_redirect_url,
cas_response_to_user_attributes,
grandfather_existing_users,
registration_enabled=self._cas_enable_registration,
)
# Copyright 2017, 2018 New Vector Ltd
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
from typing import TYPE_CHECKING, Optional
from synapse.api.constants import Membership
from synapse.api.errors import SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.metrics.background_process_metrics import run_as_background_process
......@@ -39,11 +47,11 @@ class DeactivateAccountHandler:
self._profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
self._server_name = hs.hostname
self._third_party_rules = hs.get_third_party_event_rules()
self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules
# Flag that indicates whether the process to part users from rooms is running
self._user_parter_running = False
self._third_party_rules = hs.get_third_party_event_rules()
self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules
# Start the user parter loop so it can resume parting users from rooms where
# it left off (if it has work left to do).
......@@ -103,10 +111,10 @@ class DeactivateAccountHandler:
# Attempt to unbind any known bound threepids to this account from identity
# server(s).
bound_threepids = await self.store.user_get_bound_threepids(user_id)
for threepid in bound_threepids:
for medium, address in bound_threepids:
try:
result = await self._identity_handler.try_unbind_threepid(
user_id, threepid["medium"], threepid["address"], id_server
user_id, medium, address, id_server
)
except Exception:
# Do we want this to be a fatal error or should we carry on?
......@@ -117,9 +125,9 @@ class DeactivateAccountHandler:
# Remove any local threepid associations for this account.
local_threepids = await self.store.user_get_threepids(user_id)
for threepid in local_threepids:
for local_threepid in local_threepids:
await self._auth_handler.delete_local_threepid(
user_id, threepid["medium"], threepid["address"]
user_id, local_threepid.medium, local_threepid.address
)
# delete any devices belonging to the user, which will also
......@@ -162,9 +170,9 @@ class DeactivateAccountHandler:
# parts users from rooms (if it isn't already running)
self._start_user_parting()
# Reject all pending invites for the user, so that the user doesn't show up in the
# "invited" section of rooms' members list.
await self._reject_pending_invites_for_user(user_id)
# Reject all pending invites and knocks for the user, so that the
# user doesn't show up in the "invited" section of rooms' members list.
await self._reject_pending_invites_and_knocks_for_user(user_id)
# Remove all information on the user from the account_validity table.
if self._account_validity_enabled:
......@@ -188,34 +196,37 @@ class DeactivateAccountHandler:
return identity_server_supports_unbinding
async def _reject_pending_invites_for_user(self, user_id: str) -> None:
"""Reject pending invites addressed to a given user ID.
async def _reject_pending_invites_and_knocks_for_user(self, user_id: str) -> None:
"""Reject pending invites and knocks addressed to a given user ID.
Args:
user_id: The user ID to reject pending invites for.
user_id: The user ID to reject pending invites and knocks for.
"""
user = UserID.from_string(user_id)
pending_invites = await self.store.get_invited_rooms_for_local_user(user_id)
pending_knocks = await self.store.get_knocked_at_rooms_for_local_user(user_id)
for room in pending_invites:
for room in itertools.chain(pending_invites, pending_knocks):
try:
await self._room_member_handler.update_membership(
create_requester(user, authenticated_entity=self._server_name),
user,
room.room_id,
"leave",
Membership.LEAVE,
ratelimit=False,
require_consent=False,
)
logger.info(
"Rejected invite for deactivated user %r in room %r",
"Rejected %r for deactivated user %r in room %r",
room.membership,
user_id,
room.room_id,
)
except Exception:
logger.exception(
"Failed to reject invite for user %r in room %r:"
"Failed to reject %r for user %r in room %r:"
" ignoring and continuing",
room.membership,
user_id,
room.room_id,
)
......@@ -250,17 +261,32 @@ class DeactivateAccountHandler:
user = UserID.from_string(user_id)
rooms_for_user = await self.store.get_rooms_for_user(user_id)
requester = create_requester(user, authenticated_entity=self._server_name)
should_erase = await self.store.is_user_erased(user_id)
for room_id in rooms_for_user:
logger.info("User parter parting %r from %r", user_id, room_id)
try:
# Before parting the user, redact all membership events if requested
if should_erase:
event_ids = await self.store.get_membership_event_ids_for_user(
user_id, room_id
)
for event_id in event_ids:
await self.store.expire_event(event_id)
await self._room_member_handler.update_membership(
create_requester(user, authenticated_entity=self._server_name),
requester,
user,
room_id,
"leave",
ratelimit=False,
require_consent=False,
)
# Mark the room forgotten too, because they won't be able to do this
# for us. This may lead to the room being purged eventually.
await self._room_member_handler.forget(user, room_id)
except Exception:
logger.exception(
"Failed to part user %r from room %r: ignoring and continuing",
......@@ -297,5 +323,5 @@ class DeactivateAccountHandler:
# Add the user to the directory, if necessary. Note that
# this must be done after the user is re-activated, because
# deactivated users are excluded from the user directory.
profile = await self.store.get_profileinfo(user.localpart)
profile = await self.store.get_profileinfo(user)
await self.user_directory_handler.handle_local_profile_change(user_id, profile)
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2024 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
import logging
from typing import TYPE_CHECKING, List, Optional, Set, Tuple
from twisted.internet.interfaces import IDelayedCall
from synapse.api.constants import EventTypes
from synapse.api.errors import ShadowBanError
from synapse.api.ratelimiting import Ratelimiter
from synapse.config.workers import MAIN_PROCESS_INSTANCE_NAME
from synapse.logging.opentracing import set_tag
from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.delayed_events import (
ReplicationAddedDelayedEventRestServlet,
)
from synapse.storage.databases.main.delayed_events import (
DelayedEventDetails,
DelayID,
EventType,
StateKey,
Timestamp,
UserLocalpart,
)
from synapse.storage.databases.main.state_deltas import StateDelta
from synapse.types import (
JsonDict,
Requester,
RoomID,
UserID,
create_requester,
)
from synapse.util.events import generate_fake_event_id
from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class DelayedEventsHandler:
def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self._config = hs.config
self._clock = hs.get_clock()
self._event_creation_handler = hs.get_event_creation_handler()
self._room_member_handler = hs.get_room_member_handler()
self._request_ratelimiter = hs.get_request_ratelimiter()
# Ratelimiter for management of existing delayed events,
# keyed by the sending user ID & device ID.
self._delayed_event_mgmt_ratelimiter = Ratelimiter(
store=self._store,
clock=self._clock,
cfg=self._config.ratelimiting.rc_delayed_event_mgmt,
)
self._next_delayed_event_call: Optional[IDelayedCall] = None
# The current position in the current_state_delta stream
self._event_pos: Optional[int] = None
# Guard to ensure we only process event deltas one at a time
self._event_processing = False
if hs.config.worker.worker_app is None:
self._repl_client = None
async def _schedule_db_events() -> None:
# We kick this off to pick up outstanding work from before the last restart.
# Block until we're up to date.
await self._unsafe_process_new_event()
hs.get_notifier().add_replication_callback(self.notify_new_event)
# Kick off again (without blocking) to catch any missed notifications
# that may have fired before the callback was added.
self._clock.call_later(0, self.notify_new_event)
# Delayed events that are already marked as processed on startup might not have been
# sent properly on the last run of the server, so unmark them to send them again.
# Caveat: this will double-send delayed events that successfully persisted, but failed
# to be removed from the DB table of delayed events.
# TODO: To avoid double-sending, scan the timeline to find which of these events were
# already sent. To do so, must store delay_ids in sent events to retrieve them later.
await self._store.unprocess_delayed_events()
events, next_send_ts = await self._store.process_timeout_delayed_events(
self._get_current_ts()
)
if next_send_ts:
self._schedule_next_at(next_send_ts)
# Can send the events in background after having awaited on marking them as processed
run_as_background_process(
"_send_events",
self._send_events,
events,
)
self._initialized_from_db = run_as_background_process(
"_schedule_db_events", _schedule_db_events
)
else:
self._repl_client = ReplicationAddedDelayedEventRestServlet.make_client(hs)
@property
def _is_master(self) -> bool:
return self._repl_client is None
def notify_new_event(self) -> None:
"""
Called when there may be more state event deltas to process,
which should cancel pending delayed events for the same state.
"""
if self._event_processing:
return
self._event_processing = True
async def process() -> None:
try:
await self._unsafe_process_new_event()
finally:
self._event_processing = False
run_as_background_process("delayed_events.notify_new_event", process)
async def _unsafe_process_new_event(self) -> None:
# If self._event_pos is None then means we haven't fetched it from the DB yet
if self._event_pos is None:
self._event_pos = await self._store.get_delayed_events_stream_pos()
room_max_stream_ordering = self._store.get_room_max_stream_ordering()
if self._event_pos > room_max_stream_ordering:
# apparently, we've processed more events than exist in the database!
# this can happen if events are removed with history purge or similar.
logger.warning(
"Event stream ordering appears to have gone backwards (%i -> %i): "
"rewinding delayed events processor",
self._event_pos,
room_max_stream_ordering,
)
self._event_pos = room_max_stream_ordering
# Loop round handling deltas until we're up to date
while True:
with Measure(self._clock, "delayed_events_delta"):
room_max_stream_ordering = self._store.get_room_max_stream_ordering()
if self._event_pos == room_max_stream_ordering:
return
logger.debug(
"Processing delayed events %s->%s",
self._event_pos,
room_max_stream_ordering,
)
(
max_pos,
deltas,
) = await self._storage_controllers.state.get_current_state_deltas(
self._event_pos, room_max_stream_ordering
)
logger.debug(
"Handling %d state deltas for delayed events processing",
len(deltas),
)
await self._handle_state_deltas(deltas)
self._event_pos = max_pos
# Expose current event processing position to prometheus
event_processing_positions.labels("delayed_events").set(max_pos)
await self._store.update_delayed_events_stream_pos(max_pos)
async def _handle_state_deltas(self, deltas: List[StateDelta]) -> None:
"""
Process current state deltas to cancel other users' pending delayed events
that target the same state.
"""
for delta in deltas:
if delta.event_id is None:
logger.debug(
"Not handling delta for deleted state: %r %r",
delta.event_type,
delta.state_key,
)
continue
logger.debug(
"Handling: %r %r, %s", delta.event_type, delta.state_key, delta.event_id
)
event = await self._store.get_event(
delta.event_id, check_room_id=delta.room_id
)
sender = UserID.from_string(event.sender)
next_send_ts = await self._store.cancel_delayed_state_events(
room_id=delta.room_id,
event_type=delta.event_type,
state_key=delta.state_key,
not_from_localpart=(
sender.localpart
if sender.domain == self._config.server.server_name
else ""
),
)
if self._next_send_ts_changed(next_send_ts):
self._schedule_next_at_or_none(next_send_ts)
async def add(
self,
requester: Requester,
*,
room_id: str,
event_type: str,
state_key: Optional[str],
origin_server_ts: Optional[int],
content: JsonDict,
delay: int,
) -> str:
"""
Creates a new delayed event and schedules its delivery.
Args:
requester: The requester of the delayed event, who will be its owner.
room_id: The ID of the room where the event should be sent to.
event_type: The type of event to be sent.
state_key: The state key of the event to be sent, or None if it is not a state event.
origin_server_ts: The custom timestamp to send the event with.
If None, the timestamp will be the actual time when the event is sent.
content: The content of the event to be sent.
delay: How long (in milliseconds) to wait before automatically sending the event.
Returns: The ID of the added delayed event.
Raises:
SynapseError: if the delayed event fails validation checks.
"""
# Use standard request limiter for scheduling new delayed events.
# TODO: Instead apply ratelimiting based on the scheduled send time.
# See https://github.com/element-hq/synapse/issues/18021
await self._request_ratelimiter.ratelimit(requester)
self._event_creation_handler.validator.validate_builder(
self._event_creation_handler.event_builder_factory.for_room_version(
await self._store.get_room_version(room_id),
{
"type": event_type,
"content": content,
"room_id": room_id,
"sender": str(requester.user),
**({"state_key": state_key} if state_key is not None else {}),
},
),
self._config,
)
creation_ts = self._get_current_ts()
delay_id, next_send_ts = await self._store.add_delayed_event(
user_localpart=requester.user.localpart,
device_id=requester.device_id,
creation_ts=creation_ts,
room_id=room_id,
event_type=event_type,
state_key=state_key,
origin_server_ts=origin_server_ts,
content=content,
delay=delay,
)
if self._repl_client is not None:
# NOTE: If this throws, the delayed event will remain in the DB and
# will be picked up once the main worker gets another delayed event.
await self._repl_client(
instance_name=MAIN_PROCESS_INSTANCE_NAME,
next_send_ts=next_send_ts,
)
elif self._next_send_ts_changed(next_send_ts):
self._schedule_next_at(next_send_ts)
return delay_id
def on_added(self, next_send_ts: int) -> None:
next_send_ts = Timestamp(next_send_ts)
if self._next_send_ts_changed(next_send_ts):
self._schedule_next_at(next_send_ts)
async def cancel(self, requester: Requester, delay_id: str) -> None:
"""
Cancels the scheduled delivery of the matching delayed event.
Args:
requester: The owner of the delayed event to act on.
delay_id: The ID of the delayed event to act on.
Raises:
NotFoundError: if no matching delayed event could be found.
"""
assert self._is_master
await self._delayed_event_mgmt_ratelimiter.ratelimit(
requester,
(requester.user.to_string(), requester.device_id),
)
await self._initialized_from_db
next_send_ts = await self._store.cancel_delayed_event(
delay_id=delay_id,
user_localpart=requester.user.localpart,
)
if self._next_send_ts_changed(next_send_ts):
self._schedule_next_at_or_none(next_send_ts)
async def restart(self, requester: Requester, delay_id: str) -> None:
"""
Restarts the scheduled delivery of the matching delayed event.
Args:
requester: The owner of the delayed event to act on.
delay_id: The ID of the delayed event to act on.
Raises:
NotFoundError: if no matching delayed event could be found.
"""
assert self._is_master
await self._delayed_event_mgmt_ratelimiter.ratelimit(
requester,
(requester.user.to_string(), requester.device_id),
)
await self._initialized_from_db
next_send_ts = await self._store.restart_delayed_event(
delay_id=delay_id,
user_localpart=requester.user.localpart,
current_ts=self._get_current_ts(),
)
if self._next_send_ts_changed(next_send_ts):
self._schedule_next_at(next_send_ts)
async def send(self, requester: Requester, delay_id: str) -> None:
"""
Immediately sends the matching delayed event, instead of waiting for its scheduled delivery.
Args:
requester: The owner of the delayed event to act on.
delay_id: The ID of the delayed event to act on.
Raises:
NotFoundError: if no matching delayed event could be found.
"""
assert self._is_master
# Use standard request limiter for sending delayed events on-demand,
# as an on-demand send is similar to sending a regular event.
await self._request_ratelimiter.ratelimit(requester)
await self._initialized_from_db
event, next_send_ts = await self._store.process_target_delayed_event(
delay_id=delay_id,
user_localpart=requester.user.localpart,
)
if self._next_send_ts_changed(next_send_ts):
self._schedule_next_at_or_none(next_send_ts)
await self._send_event(
DelayedEventDetails(
delay_id=DelayID(delay_id),
user_localpart=UserLocalpart(requester.user.localpart),
room_id=event.room_id,
type=event.type,
state_key=event.state_key,
origin_server_ts=event.origin_server_ts,
content=event.content,
device_id=event.device_id,
)
)
async def _send_on_timeout(self) -> None:
self._next_delayed_event_call = None
events, next_send_ts = await self._store.process_timeout_delayed_events(
self._get_current_ts()
)
if next_send_ts:
self._schedule_next_at(next_send_ts)
await self._send_events(events)
async def _send_events(self, events: List[DelayedEventDetails]) -> None:
sent_state: Set[Tuple[RoomID, EventType, StateKey]] = set()
for event in events:
if event.state_key is not None:
state_info = (event.room_id, event.type, event.state_key)
if state_info in sent_state:
continue
else:
state_info = None
try:
# TODO: send in background if message event or non-conflicting state event
await self._send_event(event)
if state_info is not None:
sent_state.add(state_info)
except Exception:
logger.exception("Failed to send delayed event")
for room_id, event_type, state_key in sent_state:
await self._store.delete_processed_delayed_state_events(
room_id=str(room_id),
event_type=event_type,
state_key=state_key,
)
def _schedule_next_at_or_none(self, next_send_ts: Optional[Timestamp]) -> None:
if next_send_ts is not None:
self._schedule_next_at(next_send_ts)
elif self._next_delayed_event_call is not None:
self._next_delayed_event_call.cancel()
self._next_delayed_event_call = None
def _schedule_next_at(self, next_send_ts: Timestamp) -> None:
delay = next_send_ts - self._get_current_ts()
delay_sec = delay / 1000 if delay > 0 else 0
if self._next_delayed_event_call is None:
self._next_delayed_event_call = self._clock.call_later(
delay_sec,
run_as_background_process,
"_send_on_timeout",
self._send_on_timeout,
)
else:
self._next_delayed_event_call.reset(delay_sec)
async def get_all_for_user(self, requester: Requester) -> List[JsonDict]:
"""Return all pending delayed events requested by the given user."""
await self._delayed_event_mgmt_ratelimiter.ratelimit(
requester,
(requester.user.to_string(), requester.device_id),
)
return await self._store.get_all_delayed_events_for_user(
requester.user.localpart
)
async def _send_event(
self,
event: DelayedEventDetails,
txn_id: Optional[str] = None,
) -> None:
user_id = UserID(event.user_localpart, self._config.server.server_name)
user_id_str = user_id.to_string()
# Create a new requester from what data is currently available
requester = create_requester(
user_id,
is_guest=await self._store.is_guest(user_id_str),
device_id=event.device_id,
)
try:
if event.state_key is not None and event.type == EventTypes.Member:
membership = event.content.get("membership")
assert membership is not None
event_id, _ = await self._room_member_handler.update_membership(
requester,
target=UserID.from_string(event.state_key),
room_id=event.room_id.to_string(),
action=membership,
content=event.content,
origin_server_ts=event.origin_server_ts,
)
else:
event_dict: JsonDict = {
"type": event.type,
"content": event.content,
"room_id": event.room_id.to_string(),
"sender": user_id_str,
}
if event.origin_server_ts is not None:
event_dict["origin_server_ts"] = event.origin_server_ts
if event.state_key is not None:
event_dict["state_key"] = event.state_key
(
sent_event,
_,
) = await self._event_creation_handler.create_and_send_nonmember_event(
requester,
event_dict,
txn_id=txn_id,
)
event_id = sent_event.event_id
except ShadowBanError:
event_id = generate_fake_event_id()
finally:
# TODO: If this is a temporary error, retry. Otherwise, consider notifying clients of the failure
try:
await self._store.delete_processed_delayed_event(
event.delay_id, event.user_localpart
)
except Exception:
logger.exception("Failed to delete processed delayed event")
set_tag("event_id", event_id)
def _get_current_ts(self) -> Timestamp:
return Timestamp(self._clock.time_msec())
def _next_send_ts_changed(self, next_send_ts: Optional[Timestamp]) -> bool:
# The DB alone knows if the next send time changed after adding/modifying
# a delayed event, but if we were to ever miss updating our delayed call's
# firing time, we may miss other updates. So, keep track of changes to the
# the next send time here instead of in the DB.
cached_next_send_ts = (
int(self._next_delayed_event_call.getTime() * 1000)
if self._next_delayed_event_call is not None
else None
)
return next_send_ts != cached_next_send_ts
# Copyright 2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
# Copyright 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Any,
AbstractSet,
Dict,
Iterable,
List,
......@@ -28,7 +33,7 @@ from typing import (
)
from synapse.api import errors
from synapse.api.constants import EduTypes, EventTypes
from synapse.api.constants import EduTypes, EventTypes, Membership
from synapse.api.errors import (
Codes,
FederationDeniedError,
......@@ -42,11 +47,18 @@ from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
from synapse.storage.databases.main.client_ips import DeviceLastConnectionInfo
from synapse.storage.databases.main.roommember import EventIdMembership
from synapse.storage.databases.main.state_deltas import StateDelta
from synapse.types import (
DeviceListUpdates,
JsonDict,
JsonMapping,
ScheduledTask,
StrCollection,
StreamKeyType,
StreamToken,
TaskStatus,
UserID,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
......@@ -56,13 +68,17 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.cancellation import cancellable
from synapse.util.metrics import measure_func
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.retryutils import (
NotRetryingDestination,
filter_destinations_by_retry_limiter,
)
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
DELETE_DEVICE_MSGS_TASK_NAME = "delete_device_messages"
MAX_DEVICE_DISPLAY_NAME_LEN = 100
DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000
......@@ -76,13 +92,23 @@ class DeviceWorkerHandler:
self.store = hs.get_datastores().main
self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
self._appservice_handler = hs.get_application_service_handler()
self._state_storage = hs.get_storage_controllers().state
self._auth_handler = hs.get_auth_handler()
self._event_sources = hs.get_event_sources()
self.server_name = hs.hostname
self._msc3852_enabled = hs.config.experimental.msc3852_enabled
self._query_appservices_for_keys = (
hs.config.experimental.msc3984_appservice_key_query
)
self._task_scheduler = hs.get_task_scheduler()
self.device_list_updater = DeviceListWorkerUpdater(hs)
self._task_scheduler.register_action(
self._delete_device_messages, DELETE_DEVICE_MSGS_TASK_NAME
)
@trace
async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
"""
......@@ -146,20 +172,32 @@ class DeviceWorkerHandler:
@cancellable
async def get_device_changes_in_shared_rooms(
self, user_id: str, room_ids: StrCollection, from_token: StreamToken
self,
user_id: str,
room_ids: StrCollection,
from_token: StreamToken,
now_token: Optional[StreamToken] = None,
) -> Set[str]:
"""Get the set of users whose devices have changed who share a room with
the given user.
"""
now_device_lists_key = self.store.get_device_stream_token()
if now_token:
now_device_lists_key = now_token.device_list_key
changed_users = await self.store.get_device_list_changes_in_rooms(
room_ids, from_token.device_list_key
room_ids,
from_token.device_list_key,
now_device_lists_key,
)
if changed_users is not None:
# We also check if the given user has changed their device. If
# they're in no rooms then the above query won't include them.
changed = await self.store.get_users_whose_devices_changed(
from_token.device_list_key, [user_id]
from_token.device_list_key,
[user_id],
to_key=now_device_lists_key,
)
changed_users.update(changed)
return changed_users
......@@ -177,150 +215,227 @@ class DeviceWorkerHandler:
tracked_users.add(user_id)
changed = await self.store.get_users_whose_devices_changed(
from_token.device_list_key, tracked_users
from_token.device_list_key,
tracked_users,
to_key=now_device_lists_key,
)
return changed
@trace
@measure_func("device.get_user_ids_changed")
@cancellable
async def get_user_ids_changed(
self, user_id: str, from_token: StreamToken
) -> JsonDict:
) -> DeviceListUpdates:
"""Get list of users that have had the devices updated, or have newly
joined a room, that `user_id` may be interested in.
"""
set_tag("user_id", user_id)
set_tag("from_token", str(from_token))
now_room_key = self.store.get_room_max_token()
room_ids = await self.store.get_rooms_for_user(user_id)
now_token = self._event_sources.get_current_token()
changed = await self.get_device_changes_in_shared_rooms(
user_id, room_ids, from_token
)
# We need to work out all the different membership changes for the user
# and user they share a room with, to pass to
# `generate_sync_entry_for_device_list`. See its docstring for details
# on the data required.
# Then work out if any users have since joined
rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
joined_room_ids = await self.store.get_rooms_for_user(user_id)
member_events = await self.store.get_membership_changes_for_user(
user_id, from_token.room_key, now_room_key
# Get the set of rooms that the user has joined/left
membership_changes = (
await self.store.get_current_state_delta_membership_changes_for_user(
user_id, from_key=from_token.room_key, to_key=now_token.room_key
)
)
rooms_changed.update(event.room_id for event in member_events)
stream_ordering = from_token.room_key.stream
possibly_changed = set(changed)
possibly_left = set()
for room_id in rooms_changed:
# Check if the forward extremities have changed. If not then we know
# the current state won't have changed, and so we can skip this room.
try:
if not await self.store.have_room_forward_extremities_changed_since(
room_id, stream_ordering
):
continue
except errors.StoreError:
pass
current_state_ids = await self._state_storage.get_current_state_ids(
room_id, await_full_state=False
# Check for newly joined or left rooms. We need to make sure that we add
# to newly joined in the case membership goes from join -> leave -> join
# again.
newly_joined_rooms: Set[str] = set()
newly_left_rooms: Set[str] = set()
for change in membership_changes:
# We check for changes in "joinedness", i.e. if the membership has
# changed to or from JOIN.
if change.membership == Membership.JOIN:
if change.prev_membership != Membership.JOIN:
newly_joined_rooms.add(change.room_id)
newly_left_rooms.discard(change.room_id)
elif change.prev_membership == Membership.JOIN:
newly_joined_rooms.discard(change.room_id)
newly_left_rooms.add(change.room_id)
# We now work out if any other users have since joined or left the rooms
# the user is currently in.
# List of membership changes per room
room_to_deltas: Dict[str, List[StateDelta]] = {}
# The set of event IDs of membership events (so we can fetch their
# associated membership).
memberships_to_fetch: Set[str] = set()
# TODO: Only pull out membership events?
state_changes = await self.store.get_current_state_deltas_for_rooms(
joined_room_ids, from_token=from_token.room_key, to_token=now_token.room_key
)
for delta in state_changes:
if delta.event_type != EventTypes.Member:
continue
room_to_deltas.setdefault(delta.room_id, []).append(delta)
if delta.event_id:
memberships_to_fetch.add(delta.event_id)
if delta.prev_event_id:
memberships_to_fetch.add(delta.prev_event_id)
# Fetch all the memberships for the membership events
event_id_to_memberships: Mapping[str, Optional[EventIdMembership]] = {}
if memberships_to_fetch:
event_id_to_memberships = await self.store.get_membership_from_event_ids(
memberships_to_fetch
)
# The user may have left the room
# TODO: Check if they actually did or if we were just invited.
if room_id not in room_ids:
for etype, state_key in current_state_ids.keys():
if etype != EventTypes.Member:
continue
possibly_left.add(state_key)
continue
joined_invited_knocked = (
Membership.JOIN,
Membership.INVITE,
Membership.KNOCK,
)
# Fetch the current state at the time.
try:
event_ids = await self.store.get_forward_extremities_for_room_at_stream_ordering(
room_id, stream_ordering=stream_ordering
)
except errors.StoreError:
# we have purged the stream_ordering index since the stream
# ordering: treat it the same as a new room
event_ids = []
# special-case for an empty prev state: include all members
# in the changed list
if not event_ids:
log_kv(
{"event": "encountered empty previous state", "room_id": room_id}
)
for etype, state_key in current_state_ids.keys():
if etype != EventTypes.Member:
continue
possibly_changed.add(state_key)
continue
# We now want to find any user that have newly joined/invited/knocked,
# or newly left, similarly to above.
newly_joined_or_invited_or_knocked_users: Set[str] = set()
newly_left_users: Set[str] = set()
for _, deltas in room_to_deltas.items():
for delta in deltas:
# Get the prev/new memberships for the delta
new_membership = None
prev_membership = None
if delta.event_id:
m = event_id_to_memberships.get(delta.event_id)
if m is not None:
new_membership = m.membership
if delta.prev_event_id:
m = event_id_to_memberships.get(delta.prev_event_id)
if m is not None:
prev_membership = m.membership
# Check if a user has newly joined/invited/knocked, or left.
if new_membership in joined_invited_knocked:
if prev_membership not in joined_invited_knocked:
newly_joined_or_invited_or_knocked_users.add(delta.state_key)
newly_left_users.discard(delta.state_key)
elif prev_membership in joined_invited_knocked:
newly_joined_or_invited_or_knocked_users.discard(delta.state_key)
newly_left_users.add(delta.state_key)
# Now we actually calculate the device list entry with the information
# calculated above.
device_list_updates = await self.generate_sync_entry_for_device_list(
user_id=user_id,
since_token=from_token,
now_token=now_token,
joined_room_ids=joined_room_ids,
newly_joined_rooms=newly_joined_rooms,
newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users,
newly_left_rooms=newly_left_rooms,
newly_left_users=newly_left_users,
)
current_member_id = current_state_ids.get((EventTypes.Member, user_id))
if not current_member_id:
continue
log_kv(
{
"changed": device_list_updates.changed,
"left": device_list_updates.left,
}
)
# mapping from event_id -> state_dict
prev_state_ids = await self._state_storage.get_state_ids_for_events(
event_ids,
await_full_state=False,
)
return device_list_updates
# Check if we've joined the room? If so we just blindly add all the users to
# the "possibly changed" users.
for state_dict in prev_state_ids.values():
member_event = state_dict.get((EventTypes.Member, user_id), None)
if not member_event or member_event != current_member_id:
for etype, state_key in current_state_ids.keys():
if etype != EventTypes.Member:
continue
possibly_changed.add(state_key)
break
# If there has been any change in membership, include them in the
# possibly changed list. We'll check if they are joined below,
# and we're not toooo worried about spuriously adding users.
for key, event_id in current_state_ids.items():
etype, state_key = key
if etype != EventTypes.Member:
continue
async def generate_sync_entry_for_device_list(
self,
user_id: str,
since_token: StreamToken,
now_token: StreamToken,
joined_room_ids: AbstractSet[str],
newly_joined_rooms: AbstractSet[str],
newly_joined_or_invited_or_knocked_users: AbstractSet[str],
newly_left_rooms: AbstractSet[str],
newly_left_users: AbstractSet[str],
) -> DeviceListUpdates:
"""Generate the DeviceListUpdates section of sync
# check if this member has changed since any of the extremities
# at the stream_ordering, and add them to the list if so.
for state_dict in prev_state_ids.values():
prev_event_id = state_dict.get(key, None)
if not prev_event_id or prev_event_id != event_id:
if state_key != user_id:
possibly_changed.add(state_key)
break
if possibly_changed or possibly_left:
possibly_joined = possibly_changed
possibly_left = possibly_changed | possibly_left
# Double check if we still share rooms with the given user.
users_rooms = await self.store.get_rooms_for_users(possibly_left)
for changed_user_id, entries in users_rooms.items():
if any(rid in room_ids for rid in entries):
possibly_left.discard(changed_user_id)
else:
possibly_joined.discard(changed_user_id)
Args:
sync_result_builder
newly_joined_rooms: Set of rooms user has joined since previous sync
newly_joined_or_invited_or_knocked_users: Set of users that have joined,
been invited to a room or are knocking on a room since
previous sync.
newly_left_rooms: Set of rooms user has left since previous sync
newly_left_users: Set of users that have left a room we're in since
previous sync
"""
# Take a copy since these fields will be mutated later.
newly_joined_or_invited_or_knocked_users = set(
newly_joined_or_invited_or_knocked_users
)
newly_left_users = set(newly_left_users)
# We want to figure out what user IDs the client should refetch
# device keys for, and which users we aren't going to track changes
# for anymore.
#
# For the first step we check:
# a. if any users we share a room with have updated their devices,
# and
# b. we also check if we've joined any new rooms, or if a user has
# joined a room we're in.
#
# For the second step we just find any users we no longer share a
# room with by looking at all users that have left a room plus users
# that were in a room we've left.
users_that_have_changed = set()
# Step 1a, check for changes in devices of users we share a room
# with
users_that_have_changed = await self.get_device_changes_in_shared_rooms(
user_id,
joined_room_ids,
from_token=since_token,
now_token=now_token,
)
else:
possibly_joined = set()
possibly_left = set()
# Step 1b, check for newly joined rooms
for room_id in newly_joined_rooms:
joined_users = await self.store.get_users_in_room(room_id)
newly_joined_or_invited_or_knocked_users.update(joined_users)
# TODO: Check that these users are actually new, i.e. either they
# weren't in the previous sync *or* they left and rejoined.
users_that_have_changed.update(newly_joined_or_invited_or_knocked_users)
result = {"changed": list(possibly_joined), "left": list(possibly_left)}
user_signatures_changed = await self.store.get_users_whose_signatures_changed(
user_id, since_token.device_list_key
)
users_that_have_changed.update(user_signatures_changed)
log_kv(result)
# Now find users that we no longer track
for room_id in newly_left_rooms:
left_users = await self.store.get_users_in_room(room_id)
newly_left_users.update(left_users)
return result
# Remove any users that we still share a room with.
left_users_rooms = await self.store.get_rooms_for_users(newly_left_users)
for user_id, entries in left_users_rooms.items():
if any(rid in joined_room_ids for rid in entries):
newly_left_users.discard(user_id)
return DeviceListUpdates(changed=users_that_have_changed, left=newly_left_users)
async def on_federation_query_user_devices(self, user_id: str) -> JsonDict:
if not self.hs.is_mine(UserID.from_string(user_id)):
raise SynapseError(400, "User is not hosted on this homeserver")
stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query(
user_id
)
......@@ -329,6 +444,30 @@ class DeviceWorkerHandler:
user_id, "self_signing"
)
# Check if the application services have any results.
if self._query_appservices_for_keys:
# Query the appservice for all devices for this user.
query: Dict[str, Optional[List[str]]] = {user_id: None}
# Query the appservices for any keys.
appservice_results = await self._appservice_handler.query_keys(query)
# Merge results, overriding anything from the database.
appservice_devices = appservice_results.get("device_keys", {}).get(
user_id, {}
)
# Filter the database results to only those devices that the appservice has
# *not* responded with.
devices = [d for d in devices if d["device_id"] not in appservice_devices]
# Append the appservice response by wrapping each result in another dictionary.
devices.extend(
{"device_id": device_id, "keys": device}
for device_id, device in appservice_devices.items()
)
# TODO Handle cross-signing keys.
return {
"user_id": user_id,
"stream_id": stream_id,
......@@ -348,6 +487,35 @@ class DeviceWorkerHandler:
"Trying handling device list state for partial join: not supported on workers."
)
DEVICE_MSGS_DELETE_BATCH_LIMIT = 1000
DEVICE_MSGS_DELETE_SLEEP_MS = 100
async def _delete_device_messages(
self,
task: ScheduledTask,
) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
"""Scheduler task to delete device messages in batch of `DEVICE_MSGS_DELETE_BATCH_LIMIT`."""
assert task.params is not None
user_id = task.params["user_id"]
device_id = task.params["device_id"]
up_to_stream_id = task.params["up_to_stream_id"]
# Delete the messages in batches to avoid too much DB load.
from_stream_id = None
while True:
from_stream_id, _ = await self.store.delete_messages_for_device_between(
user_id=user_id,
device_id=device_id,
from_stream_id=from_stream_id,
to_stream_id=up_to_stream_id,
limit=DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT,
)
if from_stream_id is None:
return TaskStatus.COMPLETE, None, None
await self.clock.sleep(DeviceHandler.DEVICE_MSGS_DELETE_SLEEP_MS / 1000.0)
class DeviceHandler(DeviceWorkerHandler):
device_list_updater: "DeviceListUpdater"
......@@ -358,6 +526,11 @@ class DeviceHandler(DeviceWorkerHandler):
self.federation_sender = hs.get_federation_sender()
self._account_data_handler = hs.get_account_data_handler()
self._storage_controllers = hs.get_storage_controllers()
self.db_pool = hs.get_datastores().main.db_pool
self._dont_notify_new_devices_for = (
hs.config.registration.dont_notify_new_devices_for
)
self.device_list_updater = DeviceListUpdater(hs, self)
......@@ -435,6 +608,9 @@ class DeviceHandler(DeviceWorkerHandler):
self._check_device_name_length(initial_device_display_name)
# Check if we should send out device lists updates for this new device.
notify = user_id not in self._dont_notify_new_devices_for
if device_id is not None:
new_device = await self.store.store_device(
user_id=user_id,
......@@ -444,7 +620,8 @@ class DeviceHandler(DeviceWorkerHandler):
auth_provider_session_id=auth_provider_session_id,
)
if new_device:
await self.notify_device_update(user_id, [device_id])
if notify:
await self.notify_device_update(user_id, [device_id])
return device_id
# if the device id is not specified, we'll autogen one, but loop a few
......@@ -460,7 +637,8 @@ class DeviceHandler(DeviceWorkerHandler):
auth_provider_session_id=auth_provider_session_id,
)
if new_device:
await self.notify_device_update(user_id, [new_device_id])
if notify:
await self.notify_device_update(user_id, [new_device_id])
return new_device_id
attempts += 1
......@@ -502,6 +680,7 @@ class DeviceHandler(DeviceWorkerHandler):
user_id: The user to delete devices from.
device_ids: The list of device IDs to delete
"""
to_device_stream_id = self._event_sources.get_current_token().to_device_key
try:
await self.store.delete_devices(user_id, device_ids)
......@@ -513,8 +692,6 @@ class DeviceHandler(DeviceWorkerHandler):
else:
raise
await self.hs.get_pusherpool().remove_pushers_by_devices(user_id, device_ids)
# Delete data specific to each device. Not optimised as it is not
# considered as part of a critical path.
for device_id in device_ids:
......@@ -533,8 +710,59 @@ class DeviceHandler(DeviceWorkerHandler):
f"org.matrix.msc3890.local_notification_settings.{device_id}",
)
# Delete device messages asynchronously and in batches using the task scheduler
# We specify an upper stream id to avoid deleting non delivered messages
# if an user re-uses a device ID.
await self._task_scheduler.schedule_task(
DELETE_DEVICE_MSGS_TASK_NAME,
resource_id=device_id,
params={
"user_id": user_id,
"device_id": device_id,
"up_to_stream_id": to_device_stream_id,
},
)
# Pushers are deleted after `delete_access_tokens_for_user` is called so that
# modules using `on_logged_out` hook can use them if needed.
await self.hs.get_pusherpool().remove_pushers_by_devices(user_id, device_ids)
await self.notify_device_update(user_id, device_ids)
async def upsert_device(
self, user_id: str, device_id: str, display_name: Optional[str] = None
) -> bool:
"""Create or update a device
Args:
user_id: The user to update devices of.
device_id: The device to update.
display_name: The new display name for this device.
Returns:
True if the device was created, False if it was updated.
"""
# Reject a new displayname which is too long.
self._check_device_name_length(display_name)
created = await self.store.store_device(
user_id,
device_id,
initial_device_display_name=display_name,
)
if not created:
await self.store.update_device(
user_id,
device_id,
new_display_name=display_name,
)
await self.notify_device_update(user_id, [device_id])
return created
async def update_device(self, user_id: str, device_id: str, content: dict) -> None:
"""Update the given device
......@@ -624,29 +852,38 @@ class DeviceHandler(DeviceWorkerHandler):
async def store_dehydrated_device(
self,
user_id: str,
device_id: Optional[str],
device_data: JsonDict,
initial_device_display_name: Optional[str] = None,
keys_for_device: Optional[JsonDict] = None,
) -> str:
"""Store a dehydrated device for a user. If the user had a previous
dehydrated device, it is removed.
"""Store a dehydrated device for a user, optionally storing the keys associated with
it as well. If the user had a previous dehydrated device, it is removed.
Args:
user_id: the user that we are storing the device for
device_id: device id supplied by client
device_data: the dehydrated device information
initial_device_display_name: The display name to use for the device
keys_for_device: keys for the dehydrated device
Returns:
device id of the dehydrated device
"""
device_id = await self.check_device_registered(
user_id,
None,
device_id,
initial_device_display_name,
)
time_now = self.clock.time_msec()
old_device_id = await self.store.store_dehydrated_device(
user_id, device_id, device_data
user_id, device_id, device_data, time_now, keys_for_device
)
if old_device_id is not None:
await self.delete_devices(user_id, [old_device_id])
return device_id
async def rehydrate_device(
......@@ -668,12 +905,13 @@ class DeviceHandler(DeviceWorkerHandler):
# If the dehydrated device was successfully deleted (the device ID
# matched the stored dehydrated device), then modify the access
# token to use the dehydrated device's ID and copy the old device
# display name to the dehydrated device, and destroy the old device
# ID
# token and refresh token to use the dehydrated device's ID and
# copy the old device display name to the dehydrated device,
# and destroy the old device ID
old_device_id = await self.store.set_device_for_access_token(
access_token, device_id
)
await self.store.set_device_for_refresh_token(user_id, old_device_id, device_id)
old_device = await self.store.get_device(user_id, old_device_id)
if old_device is None:
raise errors.NotFoundError()
......@@ -691,6 +929,22 @@ class DeviceHandler(DeviceWorkerHandler):
return {"success": True}
async def delete_dehydrated_device(self, user_id: str, device_id: str) -> None:
"""
Delete a stored dehydrated device.
Args:
user_id: the user_id to delete the device from
device_id: id of the dehydrated device to delete
"""
success = await self.store.remove_dehydrated_device(user_id, device_id)
if not success:
raise errors.NotFoundError()
await self.delete_devices(user_id, [device_id])
await self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
@wrap_as_background_process("_handle_new_device_update_async")
async def _handle_new_device_update_async(self) -> None:
"""Called when we have a new local device list update that we need to
......@@ -737,7 +991,6 @@ class DeviceHandler(DeviceWorkerHandler):
else:
assert max_stream_id == stream_id
# Avoid moving `room_id` backwards.
pass
if self._handle_new_device_update_new_data:
continue
......@@ -772,6 +1025,13 @@ class DeviceHandler(DeviceWorkerHandler):
context=opentracing_context,
)
await self.store.mark_redundant_device_lists_pokes(
user_id=user_id,
device_id=device_id,
room_id=room_id,
converted_upto_stream_id=stream_id,
)
# Notify replication that we've updated the device list stream.
self.notifier.notify_replication()
......@@ -781,17 +1041,16 @@ class DeviceHandler(DeviceWorkerHandler):
user_id,
hosts,
)
for host in hosts:
self.federation_sender.send_device_messages(
host, immediate=False
)
# TODO: when called, this isn't in a logging context.
# This leads to log spam, sentry event spam, and massive
# memory usage.
# See https://github.com/matrix-org/synapse/issues/12552.
# log_kv(
# {"message": "sent device update to host", "host": host}
# )
await self.federation_sender.send_device_messages(
hosts, immediate=False
)
# TODO: when called, this isn't in a logging context.
# This leads to log spam, sentry event spam, and massive
# memory usage.
# See https://github.com/matrix-org/synapse/issues/12552.
# log_kv(
# {"message": "sent device update to host", "host": host}
# )
if current_stream_id != stream_id:
# Clear the set of hosts we've already sent to as we're
......@@ -896,19 +1155,20 @@ class DeviceHandler(DeviceWorkerHandler):
# Notify things that device lists need to be sent out.
self.notifier.notify_replication()
for host in potentially_changed_hosts:
self.federation_sender.send_device_messages(host, immediate=False)
await self.federation_sender.send_device_messages(
potentially_changed_hosts, immediate=False
)
def _update_device_from_client_ips(
device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
device: JsonDict, client_ips: Mapping[Tuple[str, str], DeviceLastConnectionInfo]
) -> None:
ip = client_ips.get((device["user_id"], device["device_id"]), {})
ip = client_ips.get((device["user_id"], device["device_id"]))
device.update(
{
"last_seen_user_agent": ip.get("user_agent"),
"last_seen_ts": ip.get("last_seen"),
"last_seen_ip": ip.get("ip"),
"last_seen_user_agent": ip.user_agent if ip else None,
"last_seen_ts": ip.last_seen if ip else None,
"last_seen_ip": ip.ip if ip else None,
}
)
......@@ -919,19 +1179,15 @@ class DeviceListWorkerUpdater:
def __init__(self, hs: "HomeServer"):
from synapse.replication.http.devices import (
ReplicationMultiUserDevicesResyncRestServlet,
ReplicationUserDevicesResyncRestServlet,
)
self._user_device_resync_client = (
ReplicationUserDevicesResyncRestServlet.make_client(hs)
)
self._multi_user_device_resync_client = (
ReplicationMultiUserDevicesResyncRestServlet.make_client(hs)
)
async def multi_user_device_resync(
self, user_ids: List[str], mark_failed_as_stale: bool = True
) -> Dict[str, Optional[JsonDict]]:
) -> Dict[str, Optional[JsonMapping]]:
"""
Like `user_device_resync` but operates on multiple users **from the same origin**
at once.
......@@ -946,37 +1202,7 @@ class DeviceListWorkerUpdater:
# Shortcut empty requests
return {}
try:
return await self._multi_user_device_resync_client(user_ids=user_ids)
except SynapseError as err:
if not (
err.code == HTTPStatus.NOT_FOUND and err.errcode == Codes.UNRECOGNIZED
):
raise
# Fall back to single requests
result: Dict[str, Optional[JsonDict]] = {}
for user_id in user_ids:
result[user_id] = await self._user_device_resync_client(user_id=user_id)
return result
async def user_device_resync(
self, user_id: str, mark_failed_as_stale: bool = True
) -> Optional[JsonDict]:
"""Fetches all devices for a user and updates the device cache with them.
Args:
user_id: The user's id whose device_list will be updated.
mark_failed_as_stale: Whether to mark the user's device list as stale
if the attempt to resync failed.
Returns:
A dict with device info as under the "devices" in the result of this
request:
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
None when we weren't able to fetch the device info for some reason,
e.g. due to a connection problem.
"""
return (await self.multi_user_device_resync([user_id]))[user_id]
return await self._multi_user_device_resync_client(user_ids=user_ids)
class DeviceListUpdater(DeviceListWorkerUpdater):
......@@ -990,6 +1216,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
self._notifier = hs.get_notifier()
self._remote_edu_linearizer = Linearizer(name="remote_device_list")
self._resync_linearizer = Linearizer(name="remote_device_resync")
# user_id -> list of updates waiting to be handled.
self._pending_updates: Dict[
......@@ -1129,7 +1356,14 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
)
if resync:
await self.user_device_resync(user_id)
# We mark as stale up front in case we get restarted.
await self.store.mark_remote_users_device_caches_as_stale([user_id])
run_as_background_process(
"_maybe_retry_device_resync",
self.multi_user_device_resync,
[user_id],
False,
)
else:
# Simply update the single device, since we know that is the only
# change (because of the single prev_id matching the current cache)
......@@ -1192,14 +1426,23 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
self._resync_retry_in_progress = True
# Get all of the users that need resyncing.
need_resync = await self.store.get_user_ids_requiring_device_list_resync()
# Filter out users whose host is marked as "down" up front.
hosts = await filter_destinations_by_retry_limiter(
{get_domain_from_id(u) for u in need_resync}, self.clock, self.store
)
hosts = set(hosts)
# Iterate over the set of user IDs.
for user_id in need_resync:
if get_domain_from_id(user_id) not in hosts:
continue
try:
# Try to resync the current user's devices list.
result = await self.user_device_resync(
user_id=user_id,
mark_failed_as_stale=False,
)
result = (await self.multi_user_device_resync([user_id], False))[
user_id
]
# user_device_resync only returns a result if it managed to
# successfully resync and update the database. Updating the table
......@@ -1226,7 +1469,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
async def multi_user_device_resync(
self, user_ids: List[str], mark_failed_as_stale: bool = True
) -> Dict[str, Optional[JsonDict]]:
) -> Dict[str, Optional[JsonMapping]]:
"""
Like `user_device_resync` but operates on multiple users **from the same origin**
at once.
......@@ -1246,9 +1489,11 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
failed = set()
# TODO(Perf): Actually batch these up
for user_id in user_ids:
user_result, user_failed = await self._user_device_resync_returning_failed(
user_id
)
async with self._resync_linearizer.queue(user_id):
(
user_result,
user_failed,
) = await self._user_device_resync_returning_failed(user_id)
result[user_id] = user_result
if user_failed:
failed.add(user_id)
......@@ -1258,21 +1503,9 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
return result
async def user_device_resync(
self, user_id: str, mark_failed_as_stale: bool = True
) -> Optional[JsonDict]:
result, failed = await self._user_device_resync_returning_failed(user_id)
if failed and mark_failed_as_stale:
# Mark the remote user's device list as stale so we know we need to retry
# it later.
await self.store.mark_remote_users_device_caches_as_stale((user_id,))
return result
async def _user_device_resync_returning_failed(
self, user_id: str
) -> Tuple[Optional[JsonDict], bool]:
) -> Tuple[Optional[JsonMapping], bool]:
"""Fetches all devices for a user and updates the device cache with them.
Args:
......@@ -1285,6 +1518,12 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
e.g. due to a connection problem.
- True iff the resync failed and the device list should be marked as stale.
"""
# Check that we haven't gone and fetched the devices since we last
# checked if we needed to resync these device lists.
if await self.store.get_users_whose_devices_are_cached([user_id]):
cached = await self.store.get_cached_devices_for_user(user_id)
return cached, False
logger.debug("Attempting to resync the device list for %s", user_id)
log_kv({"message": "Doing resync to update device list."})
# Fetch all devices for the user.
......
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Dict
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, Dict, Optional
from synapse.api.constants import EduTypes, EventContentFields, ToDeviceEventTypes
from synapse.api.errors import SynapseError
from synapse.api.errors import Codes, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
......@@ -25,7 +33,9 @@ from synapse.logging.opentracing import (
log_kv,
set_tag,
)
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.replication.http.devices import (
ReplicationMultiUserDevicesResyncRestServlet,
)
from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.stringutils import random_string
......@@ -46,6 +56,9 @@ class DeviceMessageHandler:
self.store = hs.get_datastores().main
self.notifier = hs.get_notifier()
self.is_mine = hs.is_mine
if hs.config.experimental.msc3814_enabled:
self.event_sources = hs.get_event_sources()
self.device_handler = hs.get_device_handler()
# We only need to poke the federation sender explicitly if its on the
# same instance. Other federation sender instances will get notified by
......@@ -71,12 +84,12 @@ class DeviceMessageHandler:
# sync. We do all device list resyncing on the master instance, so if
# we're on a worker we hit the device resync replication API.
if hs.config.worker.worker_app is None:
self._user_device_resync = (
hs.get_device_handler().device_list_updater.user_device_resync
self._multi_user_device_resync = (
hs.get_device_handler().device_list_updater.multi_user_device_resync
)
else:
self._user_device_resync = (
ReplicationUserDevicesResyncRestServlet.make_client(hs)
self._multi_user_device_resync = (
ReplicationMultiUserDevicesResyncRestServlet.make_client(hs)
)
# a rate limiter for room key requests. The keys are
......@@ -84,14 +97,16 @@ class DeviceMessageHandler:
self._ratelimiter = Ratelimiter(
store=self.store,
clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_key_requests.per_second,
burst_count=hs.config.ratelimiting.rc_key_requests.burst_count,
cfg=hs.config.ratelimiting.rc_key_requests,
)
async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
"""
Handle receiving to-device messages from remote homeservers.
Note that any errors thrown from this method will cause the federation /send
request to receive an error response.
Args:
origin: The remote homeserver.
content: The JSON dictionary containing the to-device messages.
......@@ -198,7 +213,7 @@ class DeviceMessageHandler:
await self.store.mark_remote_users_device_caches_as_stale((sender_user_id,))
# Immediately attempt a resync in the background
run_in_background(self._user_device_resync, user_id=sender_user_id)
run_in_background(self._multi_user_device_resync, user_ids=[sender_user_id])
async def send_device_message(
self,
......@@ -221,6 +236,13 @@ class DeviceMessageHandler:
local_messages = {}
remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, by_device in messages.items():
if not UserID.is_valid(user_id):
logger.warning(
"Ignoring attempt to send device message to invalid user: %r",
user_id,
)
continue
# add an opentracing log entry for each message
for device_id, message_content in by_device.items():
log_kv(
......@@ -297,7 +319,93 @@ class DeviceMessageHandler:
)
if self.federation_sender:
for destination in remote_messages.keys():
# Enqueue a new federation transaction to send the new
# device messages to each remote destination.
self.federation_sender.send_device_messages(destination)
# Enqueue a new federation transaction to send the new
# device messages to each remote destination.
await self.federation_sender.send_device_messages(remote_messages.keys())
async def get_events_for_dehydrated_device(
self,
requester: Requester,
device_id: str,
since_token: Optional[str],
limit: int,
) -> JsonDict:
"""Fetches up to `limit` events sent to `device_id` starting from `since_token`
and returns the new since token. If there are no more messages, returns an empty
array.
Args:
requester: the user requesting the messages
device_id: ID of the dehydrated device
since_token: stream id to start from when fetching messages
limit: the number of messages to fetch
Returns:
A dict containing the to-device messages, as well as a token that the client
can provide in the next call to fetch the next batch of messages
"""
user_id = requester.user.to_string()
# only allow fetching messages for the dehydrated device id currently associated
# with the user
dehydrated_device = await self.device_handler.get_dehydrated_device(user_id)
if dehydrated_device is None:
raise SynapseError(
HTTPStatus.FORBIDDEN,
"No dehydrated device exists",
Codes.FORBIDDEN,
)
dehydrated_device_id, _ = dehydrated_device
if device_id != dehydrated_device_id:
raise SynapseError(
HTTPStatus.FORBIDDEN,
"You may only fetch messages for your dehydrated device",
Codes.FORBIDDEN,
)
since_stream_id = 0
if since_token:
if not since_token.startswith("d"):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"from parameter %r has an invalid format" % (since_token,),
errcode=Codes.INVALID_PARAM,
)
try:
since_stream_id = int(since_token[1:])
except Exception:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"from parameter %r has an invalid format" % (since_token,),
errcode=Codes.INVALID_PARAM,
)
to_token = self.event_sources.get_current_token().to_device_key
messages, stream_id = await self.store.get_messages_for_device(
user_id, device_id, since_stream_id, to_token, limit
)
for message in messages:
# Remove the message id before sending to client
message_id = message.pop("message_id", None)
if message_id:
set_tag(SynapseTags.TO_DEVICE_EDU_ID, message_id)
logger.debug(
"Returning %d to-device messages between %d and %d (current token: %d) for "
"dehydrated device %s, user_id %s",
len(messages),
since_stream_id,
stream_id,
to_token,
device_id,
user_id,
)
return {
"events": messages,
"next_batch": f"d{stream_id}",
}
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import string
from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence
from typing_extensions import Literal
from typing import TYPE_CHECKING, Iterable, List, Literal, Optional, Sequence
from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes
from synapse.api.errors import (
......@@ -52,7 +57,9 @@ class DirectoryHandler:
self.config = hs.config
self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
self.require_membership = hs.config.server.require_membership_for_aliases
self.third_party_event_rules = hs.get_third_party_event_rules()
self._third_party_event_rules = (
hs.get_module_api_callbacks().third_party_event_rules
)
self.server_name = hs.hostname
self.federation = hs.get_federation_client()
......@@ -60,7 +67,7 @@ class DirectoryHandler:
"directory", self.on_directory_query
)
self.spam_checker = hs.get_spam_checker()
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
async def _create_association(
self,
......@@ -71,9 +78,11 @@ class DirectoryHandler:
) -> None:
# general association creation for both human users and app services
for wchar in string.whitespace:
if wchar in room_alias.localpart:
raise SynapseError(400, "Invalid characters in room alias")
# meow: allow specific users to include anything in room aliases
if creator not in self.config.meow.validation_override:
for wchar in string.whitespace:
if wchar in room_alias.localpart:
raise SynapseError(400, "Invalid characters in room alias")
if ":" in room_alias.localpart:
raise SynapseError(400, "Invalid character in room alias localpart: ':'.")
......@@ -118,7 +127,10 @@ class DirectoryHandler:
user_id = requester.user.to_string()
room_alias_str = room_alias.to_string()
if len(room_alias_str) > MAX_ALIAS_LENGTH:
if (
user_id not in self.hs.config.meow.validation_override
and len(room_alias_str) > MAX_ALIAS_LENGTH
):
raise SynapseError(
400,
"Can't create aliases longer than %s characters" % MAX_ALIAS_LENGTH,
......@@ -145,10 +157,12 @@ class DirectoryHandler:
403, "You must be in the room to create an alias for it"
)
spam_check = await self.spam_checker.user_may_create_room_alias(
user_id, room_alias
spam_check = (
await self._spam_checker_module_callbacks.user_may_create_room_alias(
user_id, room_alias
)
)
if spam_check != self.spam_checker.NOT_SPAM:
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
raise AuthError(
403,
"This user is not permitted to create this alias",
......@@ -158,7 +172,7 @@ class DirectoryHandler:
if not self.config.roomdirectory.is_alias_creation_allowed(
user_id, room_id, room_alias_str
):
) and not is_admin:
# Let's just return a generic message, as there may be all sorts of
# reasons why we said no. TODO: Allow configurable error messages
# per alias creation rule?
......@@ -273,7 +287,9 @@ class DirectoryHandler:
except RequestSendFailed:
raise SynapseError(502, "Failed to fetch alias")
except CodeMessageException as e:
logging.warning("Error retrieving alias")
logging.warning(
"Error retrieving alias %s -> %s %s", room_alias, e.code, e.msg
)
if e.code == 404:
fed_result = None
else:
......@@ -444,7 +460,9 @@ class DirectoryHandler:
"""
user_id = requester.user.to_string()
spam_check = await self.spam_checker.user_may_publish_room(user_id, room_id)
spam_check = await self._spam_checker_module_callbacks.user_may_publish_room(
user_id, room_id
)
if spam_check != NOT_SPAM:
raise AuthError(
403,
......@@ -490,18 +508,16 @@ class DirectoryHandler:
if not self.config.roomdirectory.is_publishing_room_allowed(
user_id, room_id, room_aliases
):
) and not await self.auth.is_server_admin(requester):
# Let's just return a generic message, as there may be all sorts of
# reasons why we said no. TODO: Allow configurable error messages
# per alias creation rule?
raise SynapseError(403, "Not allowed to publish room")
# Check if publishing is blocked by a third party module
allowed_by_third_party_rules = (
await (
self.third_party_event_rules.check_visibility_can_be_modified(
room_id, visibility
)
allowed_by_third_party_rules = await (
self._third_party_event_rules.check_visibility_can_be_modified(
room_id, visibility
)
)
if not allowed_by_third_party_rules:
......
# Copyright 2016 OpenMarket Ltd
# Copyright 2018-2019 New Vector Ltd
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json
......@@ -29,8 +35,12 @@ from synapse.api.errors import CodeMessageException, Codes, NotFoundError, Synap
from synapse.handlers.device import DeviceHandler
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
from synapse.types import (
JsonDict,
JsonMapping,
ScheduledTask,
TaskStatus,
UserID,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
......@@ -38,7 +48,10 @@ from synapse.types import (
from synapse.util import json_decoder
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.cancellation import cancellable
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.retryutils import (
NotRetryingDestination,
filter_destinations_by_retry_limiter,
)
if TYPE_CHECKING:
from synapse.server import HomeServer
......@@ -46,6 +59,9 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
ONE_TIME_KEY_UPLOAD = "one_time_key_upload_lock"
class E2eKeysHandler:
def __init__(self, hs: "HomeServer"):
self.config = hs.config
......@@ -55,6 +71,8 @@ class E2eKeysHandler:
self._appservice_handler = hs.get_application_service_handler()
self.is_mine = hs.is_mine
self.clock = hs.get_clock()
self._worker_lock_handler = hs.get_worker_locks_handler()
self._task_scheduler = hs.get_task_scheduler()
federation_registry = hs.get_federation_registry()
......@@ -75,6 +93,12 @@ class E2eKeysHandler:
edu_updater.incoming_signing_key_update,
)
self.device_key_uploader = self.upload_device_keys_for_user
else:
self.device_key_uploader = (
ReplicationUploadKeysForUserRestServlet.make_client(hs)
)
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
# "query handler" interface.
......@@ -95,6 +119,10 @@ class E2eKeysHandler:
hs.config.experimental.msc3984_appservice_key_query
)
self._task_scheduler.register_action(
self._delete_old_one_time_keys_task, "delete_old_otks"
)
@trace
@cancellable
async def query_devices(
......@@ -138,6 +166,11 @@ class E2eKeysHandler:
remote_queries = {}
for user_id, device_ids in device_keys_query.items():
if not UserID.is_valid(user_id):
# Ignore invalid user IDs, which is the same behaviour as if
# the user existed but had no keys.
continue
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
local_query[user_id] = device_ids
......@@ -252,10 +285,8 @@ class E2eKeysHandler:
"%d destinations to query devices for", len(remote_queries_not_in_cache)
)
async def _query(
destination_queries: Tuple[str, Dict[str, Iterable[str]]]
) -> None:
destination, queries = destination_queries
async def _query(destination: str) -> None:
queries = remote_queries_not_in_cache[destination]
return await self._query_devices_for_destination(
results,
cross_signing_keys,
......@@ -265,18 +296,32 @@ class E2eKeysHandler:
timeout,
)
# Only try and fetch keys for destinations that are not marked as
# down.
unfiltered_destinations = remote_queries_not_in_cache.keys()
filtered_destinations = set(
await filter_destinations_by_retry_limiter(
unfiltered_destinations,
self.clock,
self.store,
# Let's give an arbitrary grace period for those hosts that are
# only recently down
retry_due_within_ms=60 * 1000,
)
)
failures.update(
(dest, _NOT_READY_FOR_RETRY_FAILURE)
for dest in (unfiltered_destinations - filtered_destinations)
)
await concurrently_execute(
_query,
remote_queries_not_in_cache.items(),
filtered_destinations,
10,
delay_cancellation=True,
)
ret = {"device_keys": results, "failures": failures}
ret.update(cross_signing_keys)
return ret
return {"device_keys": results, "failures": failures, **cross_signing_keys}
@trace
async def _query_devices_for_destination(
......@@ -408,7 +453,7 @@ class E2eKeysHandler:
@cancellable
async def get_cross_signing_keys_from_cache(
self, query: Iterable[str], from_user_id: Optional[str]
) -> Dict[str, Dict[str, dict]]:
) -> Dict[str, Dict[str, JsonMapping]]:
"""Get cross-signing keys for users from the database
Args:
......@@ -545,25 +590,30 @@ class E2eKeysHandler:
device_keys_query: Dict[str, Optional[List[str]]] = query_body.get(
"device_keys", {}
)
if any(
not self.is_mine(UserID.from_string(user_id))
for user_id in device_keys_query
):
raise SynapseError(400, "User is not hosted on this homeserver")
res = await self.query_local_devices(
device_keys_query,
include_displaynames=(
self.config.federation.allow_device_name_lookup_over_federation
),
)
ret = {"device_keys": res}
# add in the cross-signing keys
cross_signing_keys = await self.get_cross_signing_keys_from_cache(
device_keys_query, None
)
ret.update(cross_signing_keys)
return ret
return {"device_keys": res, **cross_signing_keys}
async def claim_local_one_time_keys(
self, local_query: List[Tuple[str, str, str]]
self,
local_query: List[Tuple[str, str, str, int]],
always_include_fallback_keys: bool,
) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
"""Claim one time keys for local users.
......@@ -572,44 +622,120 @@ class E2eKeysHandler:
3. Attempt to fetch fallback keys from the database.
Args:
local_query: An iterable of tuples of (user ID, device ID, algorithm).
local_query: An iterable of tuples of (user ID, device ID, algorithm, number of keys).
always_include_fallback_keys: True to always include fallback keys.
Returns:
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
"""
# Cap the number of OTKs that can be claimed at once to avoid abuse.
local_query = [
(user_id, device_id, algorithm, min(count, 5))
for user_id, device_id, algorithm, count in local_query
]
otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)
# If the application services have not provided any keys via the C-S
# API, query it directly for one-time keys.
if self._query_appservices_for_otks:
# TODO Should this query for fallback keys of uploaded OTKs if
# always_include_fallback_keys is True? The MSC is ambiguous.
(
appservice_results,
not_found,
) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
else:
appservice_results = []
appservice_results = {}
# Calculate which user ID / device ID / algorithm tuples to get fallback
# keys for. This can be either only missing results *or* all results
# (which don't already have a fallback key).
if always_include_fallback_keys:
# Build the fallback query as any part of the original query where
# the appservice didn't respond with a fallback key.
fallback_query = []
# Iterate each item in the original query and search the results
# from the appservice for that user ID / device ID. If it is found,
# check if any of the keys match the requested algorithm & are a
# fallback key.
for user_id, device_id, algorithm, _count in local_query:
# Check if the appservice responded for this query.
as_result = appservice_results.get(user_id, {}).get(device_id, {})
found_otk = False
for key_id, key_json in as_result.items():
if key_id.startswith(f"{algorithm}:"):
# A OTK or fallback key was found for this query.
found_otk = True
# A fallback key was found for this query, no need to
# query further.
if key_json.get("fallback", False):
break
else:
# No fallback key was found from appservices, query for it.
# Only mark the fallback key as used if no OTK was found
# (from either the database or appservices).
mark_as_used = not found_otk and not any(
key_id.startswith(f"{algorithm}:")
for key_id in otk_results.get(user_id, {})
.get(device_id, {})
.keys()
)
# Note that it doesn't make sense to request more than 1 fallback key
# per (user_id, device_id, algorithm).
fallback_query.append((user_id, device_id, algorithm, mark_as_used))
else:
# All fallback keys get marked as used.
fallback_query = [
# Note that it doesn't make sense to request more than 1 fallback key
# per (user_id, device_id, algorithm).
(user_id, device_id, algorithm, True)
for user_id, device_id, algorithm, count in not_found
]
# For each user that does not have a one-time keys available, see if
# there is a fallback key.
fallback_results = await self.store.claim_e2e_fallback_keys(not_found)
fallback_results = await self.store.claim_e2e_fallback_keys(fallback_query)
# Return the results in order, each item from the input query should
# only appear once in the combined list.
return (otk_results, *appservice_results, fallback_results)
return (otk_results, appservice_results, fallback_results)
@trace
async def claim_one_time_keys(
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
self,
query: Dict[str, Dict[str, Dict[str, int]]],
user: UserID,
timeout: Optional[int],
always_include_fallback_keys: bool,
) -> JsonDict:
local_query: List[Tuple[str, str, str]] = []
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
"""
Args:
query: A chain of maps from (user_id, device_id, algorithm) to the requested
number of keys to claim.
user: The user who is claiming these keys.
timeout: How long to wait for any federation key claim requests before
giving up.
always_include_fallback_keys: always include a fallback key for local users'
devices, even if we managed to claim a one-time-key.
Returns: a heterogeneous dict with two keys:
one_time_keys: chain of maps user ID -> device ID -> key ID -> key.
failures: map from remote destination to a JsonDict describing the error.
"""
local_query: List[Tuple[str, str, str, int]] = []
remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {}
for user_id, one_time_keys in query.get("one_time_keys", {}).items():
for user_id, one_time_keys in query.items():
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
for device_id, algorithm in one_time_keys.items():
local_query.append((user_id, device_id, algorithm))
for device_id, algorithms in one_time_keys.items():
for algorithm, count in algorithms.items():
local_query.append((user_id, device_id, algorithm, count))
else:
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = one_time_keys
......@@ -617,7 +743,9 @@ class E2eKeysHandler:
set_tag("local_key_query", str(local_query))
set_tag("remote_key_query", str(remote_queries))
results = await self.claim_local_one_time_keys(local_query)
results = await self.claim_local_one_time_keys(
local_query, always_include_fallback_keys
)
# A map of user ID -> device ID -> key ID -> key.
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
......@@ -625,7 +753,9 @@ class E2eKeysHandler:
for user_id, device_keys in result.items():
for device_id, keys in device_keys.items():
for key_id, key in keys.items():
json_result.setdefault(user_id, {})[device_id] = {key_id: key}
json_result.setdefault(user_id, {}).setdefault(
device_id, {}
).update({key_id: key})
# Remote failures.
failures: Dict[str, JsonDict] = {}
......@@ -636,7 +766,7 @@ class E2eKeysHandler:
device_keys = remote_queries[destination]
try:
remote_result = await self.federation.claim_client_keys(
destination, {"one_time_keys": device_keys}, timeout=timeout
user, destination, device_keys, timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
......@@ -677,36 +807,27 @@ class E2eKeysHandler:
async def upload_keys_for_user(
self, user_id: str, device_id: str, keys: JsonDict
) -> JsonDict:
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)
"""
Args:
user_id: user whose keys are being uploaded.
device_id: device whose keys are being uploaded.
keys: the body of a /keys/upload request.
Returns a dictionary with one field:
"one_time_keys": A mapping from algorithm to number of keys for that
algorithm, including those previously persisted.
"""
time_now = self.clock.time_msec()
# TODO: Validate the JSON to make sure it has the right keys.
device_keys = keys.get("device_keys", None)
if device_keys:
logger.info(
"Updating device_keys for device %r for user %s at %d",
device_id,
user_id,
time_now,
)
log_kv(
{
"message": "Updating device_keys for user.",
"user_id": user_id,
"device_id": device_id,
}
)
# TODO: Sign the JSON with the server key
changed = await self.store.set_e2e_device_keys(
user_id, device_id, time_now, device_keys
await self.device_key_uploader(
user_id=user_id,
device_id=device_id,
keys={"device_keys": device_keys},
)
if changed:
# Only notify about device updates *if* the keys actually changed
await self.device_handler.notify_device_update(user_id, [device_id])
else:
log_kv({"message": "Not updating device_keys for user", "user_id": user_id})
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
log_kv(
......@@ -742,60 +863,106 @@ class E2eKeysHandler:
{"message": "Did not update fallback_keys", "reason": "no keys given"}
)
# the device should have been registered already, but it may have been
# deleted due to a race with a DELETE request. Or we may be using an
# old access_token without an associated device_id. Either way, we
# need to double-check the device is registered to avoid ending up with
# keys without a corresponding device.
await self.device_handler.check_device_registered(user_id, device_id)
result = await self.store.count_e2e_one_time_keys(user_id, device_id)
set_tag("one_time_key_counts", str(result))
return {"one_time_key_counts": result}
async def _upload_one_time_keys_for_user(
self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
@tag_args
async def upload_device_keys_for_user(
self, user_id: str, device_id: str, keys: JsonDict
) -> None:
"""
Args:
user_id: user whose keys are being uploaded.
device_id: device whose keys are being uploaded.
device_keys: the `device_keys` of an /keys/upload request.
"""
# This can only be called from the main process.
assert isinstance(self.device_handler, DeviceHandler)
time_now = self.clock.time_msec()
device_keys = keys["device_keys"]
logger.info(
"Adding one_time_keys %r for device %r for user %r at %d",
one_time_keys.keys(),
"Updating device_keys for device %r for user %s at %d",
device_id,
user_id,
time_now,
)
log_kv(
{
"message": "Updating device_keys for user.",
"user_id": user_id,
"device_id": device_id,
}
)
# TODO: Sign the JSON with the server key
changed = await self.store.set_e2e_device_keys(
user_id, device_id, time_now, device_keys
)
if changed:
# Only notify about device updates *if* the keys actually changed
await self.device_handler.notify_device_update(user_id, [device_id])
# make a list of (alg, id, key) tuples
key_list = []
for key_id, key_obj in one_time_keys.items():
algorithm, key_id = key_id.split(":")
key_list.append((algorithm, key_id, key_obj))
# the device should have been registered already, but it may have been
# deleted due to a race with a DELETE request. Or we may be using an
# old access_token without an associated device_id. Either way, we
# need to double-check the device is registered to avoid ending up with
# keys without a corresponding device.
await self.device_handler.check_device_registered(user_id, device_id)
# First we check if we have already persisted any of the keys.
existing_key_map = await self.store.get_e2e_one_time_keys(
user_id, device_id, [k_id for _, k_id, _ in key_list]
)
async def _upload_one_time_keys_for_user(
self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
) -> None:
# We take out a lock so that we don't have to worry about a client
# sending duplicate requests.
lock_key = f"{user_id}_{device_id}"
async with self._worker_lock_handler.acquire_lock(
ONE_TIME_KEY_UPLOAD, lock_key
):
logger.info(
"Adding one_time_keys %r for device %r for user %r at %d",
one_time_keys.keys(),
device_id,
user_id,
time_now,
)
new_keys = [] # Keys that we need to insert. (alg, id, json) tuples.
for algorithm, key_id, key in key_list:
ex_json = existing_key_map.get((algorithm, key_id), None)
if ex_json:
if not _one_time_keys_match(ex_json, key):
raise SynapseError(
400,
(
"One time key %s:%s already exists. "
"Old key: %s; new key: %r"
# make a list of (alg, id, key) tuples
key_list = []
for key_id, key_obj in one_time_keys.items():
algorithm, key_id = key_id.split(":")
key_list.append((algorithm, key_id, key_obj))
# First we check if we have already persisted any of the keys.
existing_key_map = await self.store.get_e2e_one_time_keys(
user_id, device_id, [k_id for _, k_id, _ in key_list]
)
new_keys = [] # Keys that we need to insert. (alg, id, json) tuples.
for algorithm, key_id, key in key_list:
ex_json = existing_key_map.get((algorithm, key_id), None)
if ex_json:
if not _one_time_keys_match(ex_json, key):
raise SynapseError(
400,
(
"One time key %s:%s already exists. "
"Old key: %s; new key: %r"
)
% (algorithm, key_id, ex_json, key),
)
% (algorithm, key_id, ex_json, key),
else:
new_keys.append(
(algorithm, key_id, encode_canonical_json(key).decode("ascii"))
)
else:
new_keys.append(
(algorithm, key_id, encode_canonical_json(key).decode("ascii"))
)
log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
await self.store.add_e2e_one_time_keys(
user_id, device_id, time_now, new_keys
)
async def upload_signing_keys_for_user(
self, user_id: str, keys: JsonDict
......@@ -1059,7 +1226,7 @@ class E2eKeysHandler:
user_id: str,
master_key_id: str,
signed_master_key: JsonDict,
stored_master_key: JsonDict,
stored_master_key: JsonMapping,
devices: Dict[str, Dict[str, JsonDict]],
) -> List["SignatureListItem"]:
"""Check signatures of a user's master key made by their devices.
......@@ -1210,7 +1377,7 @@ class E2eKeysHandler:
async def _get_e2e_cross_signing_verify_key(
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
) -> Tuple[JsonDict, str, VerifyKey]:
) -> Tuple[JsonMapping, str, VerifyKey]:
"""Fetch locally or remotely query for a cross-signing public key.
First, attempt to fetch the cross-signing public key from storage.
......@@ -1265,7 +1432,7 @@ class E2eKeysHandler:
self,
user: UserID,
desired_key_type: str,
) -> Optional[Tuple[Dict[str, Any], str, VerifyKey]]:
) -> Optional[Tuple[JsonMapping, str, VerifyKey]]:
"""Queries cross-signing keys for a remote user and saves them to the database
Only the key specified by `key_type` will be returned, while all retrieved keys
......@@ -1358,19 +1525,100 @@ class E2eKeysHandler:
return desired_key_data
async def is_cross_signing_set_up_for_user(self, user_id: str) -> bool:
async def check_cross_signing_setup(self, user_id: str) -> Tuple[bool, bool]:
"""Checks if the user has cross-signing set up
Args:
user_id: The user to check
Returns: a 2-tuple of booleans
- whether the user has cross-signing set up, and
- whether the user's master cross-signing key may be replaced without UIA.
"""
(
exists,
ts_replacable_without_uia_before,
) = await self.store.get_master_cross_signing_key_updatable_before(user_id)
if ts_replacable_without_uia_before is None:
return exists, False
else:
return exists, self.clock.time_msec() < ts_replacable_without_uia_before
async def has_different_keys(self, user_id: str, body: JsonDict) -> bool:
"""
Check if a key provided in `body` differs from the same key stored in the DB. Returns
true on the first difference. If a key exists in `body` but does not exist in the DB,
returns True. If `body` has no keys, this always returns False.
Note by 'key' we mean Matrix key rather than JSON key.
The purpose of this function is to detect whether or not we need to apply UIA checks.
We must apply UIA checks if any key in the database is being overwritten. If a key is
being inserted for the first time, or if the key exactly matches what is in the database,
then no UIA check needs to be performed.
Args:
user_id: The user who sent the `body`.
body: The JSON request body from POST /keys/device_signing/upload
Returns:
True if the user has cross-signing set up, False otherwise
True if any key in `body` has a different value in the database.
"""
existing_master_key = await self.store.get_e2e_cross_signing_key(
user_id, "master"
)
return existing_master_key is not None
# Ensure that each key provided in the request body exactly matches the one we have stored.
# The first time we see the DB having a different key to the matching request key, bail.
# Note: we do not care if the DB has a key which the request does not specify, as we only
# care about *replacements* or *insertions* (i.e UPSERT)
req_body_key_to_db_key = {
"master_key": "master",
"self_signing_key": "self_signing",
"user_signing_key": "user_signing",
}
for req_body_key, db_key in req_body_key_to_db_key.items():
if req_body_key in body:
existing_key = await self.store.get_e2e_cross_signing_key(
user_id, db_key
)
if existing_key != body[req_body_key]:
return True
return False
async def _delete_old_one_time_keys_task(
self, task: ScheduledTask
) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
"""Scheduler task to delete old one time keys.
Until Synapse 1.119, Synapse used to issue one-time-keys in a random order, leading to the possibility
that it could still have old OTKs that the client has dropped. This task is scheduled exactly once
by a database schema delta file, and it clears out old one-time-keys that look like they came from libolm.
"""
last_user = task.result.get("from_user", "") if task.result else ""
while True:
# We process users in batches of 100
users, rowcount = await self.store.delete_old_otks_for_next_user_batch(
last_user, 100
)
if len(users) == 0:
# We're done!
return TaskStatus.COMPLETE, None, None
logger.debug(
"Deleted %i old one-time-keys for users '%s'..'%s'",
rowcount,
users[0],
users[-1],
)
last_user = users[-1]
# Store our progress
await self._task_scheduler.update_task(
task.id, result={"from_user": last_user}
)
# Sleep a little before doing the next user.
#
# matrix.org has about 15M users in the e2e_one_time_keys_json table
# (comprising 20M devices). We want this to take about a week, so we need
# to do about one batch of 100 users every 4 seconds.
await self.clock.sleep(4)
def _check_cross_signing_key(
......@@ -1406,7 +1654,7 @@ def _check_device_signature(
user_id: str,
verify_key: VerifyKey,
signed_device: JsonDict,
stored_device: JsonDict,
stored_device: JsonMapping,
) -> None:
"""Check that a signature on a device or cross-signing key is correct and
matches the copy of the device/key that we have stored. Throws an
......@@ -1446,6 +1694,9 @@ def _check_device_signature(
raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE)
_NOT_READY_FOR_RETRY_FAILURE = {"status": 503, "message": "Not ready for retry"}
def _exception_to_failure(e: Exception) -> JsonDict:
if isinstance(e, SynapseError):
return {"status": e.code, "errcode": e.errcode, "message": str(e)}
......@@ -1454,7 +1705,7 @@ def _exception_to_failure(e: Exception) -> JsonDict:
return {"status": e.code, "message": str(e)}
if isinstance(e, NotRetryingDestination):
return {"status": 503, "message": "Not ready for retry"}
return _NOT_READY_FOR_RETRY_FAILURE
# include ConnectionRefused and other errors
#
......
# Copyright 2017, 2018 New Vector Ltd
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2019 Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, Optional, cast
from typing_extensions import Literal
from typing import TYPE_CHECKING, Dict, Literal, Optional, cast
from synapse.api.errors import (
Codes,
......@@ -28,7 +32,7 @@ from synapse.api.errors import (
from synapse.logging.opentracing import log_kv, trace
from synapse.storage.databases.main.e2e_room_keys import RoomKey
from synapse.types import JsonDict
from synapse.util.async_helpers import Linearizer
from synapse.util.async_helpers import ReadWriteLock
if TYPE_CHECKING:
from synapse.server import HomeServer
......@@ -52,7 +56,7 @@ class E2eRoomKeysHandler:
# clients belonging to a user will receive and try to upload a new session at
# roughly the same time. Also used to lock out uploads when the key is being
# changed.
self._upload_linearizer = Linearizer("upload_room_keys_lock")
self._upload_lock = ReadWriteLock()
@trace
async def get_room_keys(
......@@ -83,7 +87,7 @@ class E2eRoomKeysHandler:
# we deliberately take the lock to get keys so that changing the version
# works atomically
async with self._upload_linearizer.queue(user_id):
async with self._upload_lock.read(user_id):
# make sure the backup version exists
try:
await self.store.get_e2e_room_keys_version_info(user_id, version)
......@@ -126,7 +130,7 @@ class E2eRoomKeysHandler:
"""
# lock for consistency with uploading
async with self._upload_linearizer.queue(user_id):
async with self._upload_lock.write(user_id):
# make sure the backup version exists
try:
version_info = await self.store.get_e2e_room_keys_version_info(
......@@ -187,7 +191,7 @@ class E2eRoomKeysHandler:
# TODO: Validate the JSON to make sure it has the right keys.
# XXX: perhaps we should use a finer grained lock here?
async with self._upload_linearizer.queue(user_id):
async with self._upload_lock.write(user_id):
# Check that the version we're trying to upload is the current version
try:
version_info = await self.store.get_e2e_room_keys_version_info(user_id)
......@@ -241,6 +245,12 @@ class E2eRoomKeysHandler:
if current_room_key:
if self._should_replace_room_key(current_room_key, room_key):
log_kv({"message": "Replacing room key."})
logger.debug(
"Replacing room key. room=%s session=%s user=%s",
room_id,
session_id,
user_id,
)
# updates are done one at a time in the DB, so send
# updates right away rather than batching them up,
# like we do with the inserts
......@@ -250,6 +260,12 @@ class E2eRoomKeysHandler:
changed = True
else:
log_kv({"message": "Not replacing room_key."})
logger.debug(
"Not replacing room key. room=%s session=%s user=%s",
room_id,
session_id,
user_id,
)
else:
log_kv(
{
......@@ -259,6 +275,12 @@ class E2eRoomKeysHandler:
}
)
log_kv({"message": "Replacing room key."})
logger.debug(
"Inserting new room key. room=%s session=%s user=%s",
room_id,
session_id,
user_id,
)
to_insert.append((room_id, session_id, room_key))
changed = True
......@@ -331,7 +353,7 @@ class E2eRoomKeysHandler:
# TODO: Validate the JSON to make sure it has the right keys.
# lock everyone out until we've switched version
async with self._upload_linearizer.queue(user_id):
async with self._upload_lock.write(user_id):
new_version = await self.store.create_e2e_room_keys_version(
user_id, version_info
)
......@@ -358,7 +380,7 @@ class E2eRoomKeysHandler:
}
"""
async with self._upload_linearizer.queue(user_id):
async with self._upload_lock.read(user_id):
try:
res = await self.store.get_e2e_room_keys_version_info(user_id, version)
except StoreError as e:
......@@ -383,7 +405,7 @@ class E2eRoomKeysHandler:
NotFoundError: if this backup version doesn't exist
"""
async with self._upload_linearizer.queue(user_id):
async with self._upload_lock.write(user_id):
try:
await self.store.delete_e2e_room_keys_version(user_id, version)
except StoreError as e:
......@@ -413,7 +435,7 @@ class E2eRoomKeysHandler:
raise SynapseError(
400, "Version in body does not match", Codes.INVALID_PARAM
)
async with self._upload_linearizer.queue(user_id):
async with self._upload_lock.write(user_id):
try:
old_info = await self.store.get_e2e_room_keys_version_info(
user_id, version
......
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, List, Mapping, Optional, Union
......@@ -29,7 +36,7 @@ from synapse.event_auth import (
)
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
from synapse.types import StateMap, StrCollection, get_domain_from_id
from synapse.types import StateMap, StrCollection
if TYPE_CHECKING:
from synapse.server import HomeServer
......@@ -47,6 +54,7 @@ class EventAuthHandler:
self._store = hs.get_datastores().main
self._state_storage_controller = hs.get_storage_controllers().state
self._server_name = hs.hostname
self._is_mine_id = hs.is_mine_id
async def check_auth_rules_from_context(
self,
......@@ -247,7 +255,7 @@ class EventAuthHandler:
if not await self.is_user_in_rooms(allowed_rooms, user_id):
# If this is a remote request, the user might be in an allowed room
# that we do not know about.
if get_domain_from_id(user_id) != self._server_name:
if not self._is_mine_id(user_id):
for room_id in allowed_rooms:
if not await self._store.is_host_joined(room_id, self._server_name):
raise SynapseError(
......@@ -276,7 +284,7 @@ class EventAuthHandler:
True if the proper room version and join rules are set for restricted access.
"""
# This only applies to room versions which support the new join rule.
if not room_version.msc3083_join_rules:
if not room_version.restricted_join_rule:
return False
# If there's no join rule, then it defaults to invite (so this doesn't apply).
......@@ -291,7 +299,7 @@ class EventAuthHandler:
return True
# also check for MSC3787 behaviour
if room_version.msc3787_knock_restricted_join_rule:
if room_version.knock_restricted_join_rule:
return content_join_rule == JoinRules.KNOCK_RESTRICTED
return False
......
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
......@@ -67,6 +74,7 @@ class EventStreamHandler:
context = await presence_handler.user_syncing(
requester.user.to_string(),
requester.device_id,
affect_presence=affect_presence,
presence_state=PresenceState.ONLINE,
)
......@@ -119,7 +127,7 @@ class EventStreamHandler:
events.extend(to_add)
chunks = self._event_serializer.serialize_events(
chunks = await self._event_serializer.serialize_events(
events,
time_now,
config=SerializeEventConfig(
......@@ -181,7 +189,10 @@ class EventHandler:
is_peeking = not is_user_in_room
filtered = await filter_events_for_client(
self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking
self._storage_controllers,
user.to_string(),
[event],
is_peeking=is_peeking,
)
if not filtered:
......
# Copyright 2014-2022 The Matrix.org Foundation C.I.C.
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2020 Sorunome
# Copyright 2014-2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains handlers for federation events."""
......@@ -60,6 +67,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.events.validator import EventValidator
from synapse.federation.federation_client import InvalidResponseError
from synapse.handlers.pagination import PURGE_PAGINATION_LOCK_NAME
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import nested_logging_context
from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace
......@@ -105,14 +113,12 @@ backfill_processing_before_timer = Histogram(
)
# TODO: We can refactor this away now that there is only one backfill point again
class _BackfillPointType(Enum):
# a regular backwards extremity (ie, an event which we don't yet have, but which
# is referred to by other events in the DAG)
BACKWARDS_EXTREMITY = enum.auto()
# an MSC2716 "insertion event"
INSERTION_PONT = enum.auto()
@attr.s(slots=True, auto_attribs=True, frozen=True)
class _BackfillPoint:
......@@ -141,18 +147,20 @@ class FederationHandler:
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
self.is_mine_id = hs.is_mine_id
self.spam_checker = hs.get_spam_checker()
self.is_mine_server_name = hs.is_mine_server_name
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
self.event_creation_handler = hs.get_event_creation_handler()
self.event_builder_factory = hs.get_event_builder_factory()
self._event_auth_handler = hs.get_event_auth_handler()
self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
self.config = hs.config
self.http_client = hs.get_proxied_blacklisted_http_client()
self.http_client = hs.get_proxied_blocklisted_http_client()
self._replication = hs.get_replication_data_handler()
self._federation_event_handler = hs.get_federation_event_handler()
self._device_handler = hs.get_device_handler()
self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator()
self._notifier = hs.get_notifier()
self._worker_locks = hs.get_worker_locks_handler()
self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client(
hs
......@@ -169,7 +177,9 @@ class FederationHandler:
self._room_backfill = Linearizer("room_backfill")
self.third_party_event_rules = hs.get_third_party_event_rules()
self._third_party_event_rules = (
hs.get_module_api_callbacks().third_party_event_rules
)
# Tracks running partial state syncs by room ID.
# Partial state syncs currently only run on the main process, so it's okay to
......@@ -197,8 +207,9 @@ class FederationHandler:
)
@trace
@tag_args
async def maybe_backfill(
self, room_id: str, current_depth: int, limit: int
self, room_id: str, current_depth: int, limit: int, record_time: bool = True
) -> bool:
"""Checks the database to see if we should backfill before paginating,
and if so do.
......@@ -211,19 +222,28 @@ class FederationHandler:
limit: The number of events that the pagination request will
return. This is used as part of the heuristic to decide if we
should back paginate.
record_time: Whether to record the time it takes to backfill.
Returns:
True if we actually tried to backfill something, otherwise False.
"""
# Starting the processing time here so we can include the room backfill
# linearizer lock queue in the timing
processing_start_time = self.clock.time_msec()
processing_start_time = self.clock.time_msec() if record_time else 0
async with self._room_backfill.queue(room_id):
return await self._maybe_backfill_inner(
room_id,
current_depth,
limit,
processing_start_time=processing_start_time,
)
async with self._worker_locks.acquire_read_write_lock(
PURGE_PAGINATION_LOCK_NAME, room_id, write=False
):
return await self._maybe_backfill_inner(
room_id,
current_depth,
limit,
processing_start_time=processing_start_time,
)
@trace
@tag_args
async def _maybe_backfill_inner(
self,
room_id: str,
......@@ -244,6 +264,9 @@ class FederationHandler:
limit: The max number of events to request from the remote federated server.
processing_start_time: The time when `maybe_backfill` started processing.
Only used for timing. If `None`, no timing observation will be made.
Returns:
True if we actually tried to backfill something, otherwise False.
"""
backwards_extremities = [
_BackfillPoint(event_id, depth, _BackfillPointType.BACKWARDS_EXTREMITY)
......@@ -261,32 +284,10 @@ class FederationHandler:
)
]
insertion_events_to_be_backfilled: List[_BackfillPoint] = []
if self.hs.config.experimental.msc2716_enabled:
insertion_events_to_be_backfilled = [
_BackfillPoint(event_id, depth, _BackfillPointType.INSERTION_PONT)
for event_id, depth in await self.store.get_insertion_event_backward_extremities_in_room(
room_id=room_id,
current_depth=current_depth,
# We only need to end up with 5 extremities combined with
# the backfill points to make the `/backfill` request ...
# (see the other comment above for more context).
limit=50,
)
]
logger.debug(
"_maybe_backfill_inner: backwards_extremities=%s insertion_events_to_be_backfilled=%s",
backwards_extremities,
insertion_events_to_be_backfilled,
)
# we now have a list of potential places to backpaginate from. We prefer to
# start with the most recent (ie, max depth), so let's sort the list.
sorted_backfill_points: List[_BackfillPoint] = sorted(
itertools.chain(
backwards_extremities,
insertion_events_to_be_backfilled,
),
backwards_extremities,
key=lambda e: -int(e.depth),
)
......@@ -299,15 +300,39 @@ class FederationHandler:
len(sorted_backfill_points),
sorted_backfill_points,
)
set_tag(
SynapseTags.RESULT_PREFIX + "sorted_backfill_points",
str(sorted_backfill_points),
)
set_tag(
SynapseTags.RESULT_PREFIX + "sorted_backfill_points.length",
str(len(sorted_backfill_points)),
)
# If we have no backfill points lower than the `current_depth` then
# either we can a) bail or b) still attempt to backfill. We opt to try
# backfilling anyway just in case we do get relevant events.
# If we have no backfill points lower than the `current_depth` then either we
# can a) bail or b) still attempt to backfill. We opt to try backfilling anyway
# just in case we do get relevant events. This is good for eventual consistency
# sake but we don't need to block the client for something that is just as
# likely not to return anything relevant so we backfill in the background. The
# only way, this could return something relevant is if we discover a new branch
# of history that extends all the way back to where we are currently paginating
# and it's within the 100 events that are returned from `/backfill`.
if not sorted_backfill_points and current_depth != MAX_DEPTH:
# Check that we actually have later backfill points, if not just return.
have_later_backfill_points = await self.store.get_backfill_points_in_room(
room_id=room_id,
current_depth=MAX_DEPTH,
limit=1,
)
if not have_later_backfill_points:
return False
logger.debug(
"_maybe_backfill_inner: all backfill points are *after* current depth. Trying again with later backfill points."
)
return await self._maybe_backfill_inner(
run_as_background_process(
"_maybe_backfill_inner_anyway_with_max_depth",
self.maybe_backfill,
room_id=room_id,
# We use `MAX_DEPTH` so that we find all backfill points next
# time (all events are below the `MAX_DEPTH`)
......@@ -316,8 +341,11 @@ class FederationHandler:
# We don't want to start another timing observation from this
# nested recursive call. The top-most call can record the time
# overall otherwise the smaller one will throw off the results.
processing_start_time=None,
record_time=False,
)
# We return `False` because we're backfilling in the background and there is
# no new events immediately for the caller to know about yet.
return False
# Even after recursing with `MAX_DEPTH`, we didn't find any
# backward extremities to backfill from.
......@@ -381,10 +409,7 @@ class FederationHandler:
# event but not anything before it. This would require looking at the
# state *before* the event, ignoring the special casing certain event
# types have.
if bp.type == _BackfillPointType.INSERTION_PONT:
event_ids_to_check = [bp.event_id]
else:
event_ids_to_check = await self.store.get_successor_events(bp.event_id)
event_ids_to_check = await self.store.get_successor_events(bp.event_id)
events_to_check = await self.store.get_events_as_list(
event_ids_to_check,
......@@ -451,7 +476,7 @@ class FederationHandler:
for dom in domains:
# We don't want to ask our own server for information we don't have
if dom == self.server_name:
if self.is_mine_server_name(dom):
continue
try:
......@@ -850,19 +875,13 @@ class FederationHandler:
# This is a bit of a hack and is cribbing off of invites. Basically we
# store the room state here and retrieve it again when this event appears
# in the invitee's sync stream. It is stripped out for all other local users.
stripped_room_state = (
knock_response.get("knock_room_state")
# Since v1.37, Synapse incorrectly used "knock_state_events" for this field.
# Thus, we also check for a 'knock_state_events' to support old instances.
# See https://github.com/matrix-org/synapse/issues/14088.
or knock_response.get("knock_state_events")
)
stripped_room_state = knock_response.get("knock_room_state")
if stripped_room_state is None:
raise KeyError(
"Missing 'knock_room_state' (or legacy 'knock_state_events') field in "
"send_knock response"
)
raise KeyError("Missing 'knock_room_state' field in send_knock response")
if not isinstance(stripped_room_state, list):
raise TypeError("'knock_room_state' has wrong type")
event.unsigned["knock_room_state"] = stripped_room_state
......@@ -954,7 +973,7 @@ class FederationHandler:
# Note that this requires the /send_join request to come back to the
# same server.
prev_event_ids = None
if room_version.msc3083_join_rules:
if room_version.restricted_join_rule:
# Note that the room's state can change out from under us and render our
# nice join rules-conformant event non-conformant by the time we build the
# event. When this happens, our validation at the end fails and we respond
......@@ -1042,7 +1061,7 @@ class FederationHandler:
if self.hs.config.server.block_non_admin_invites:
raise SynapseError(403, "This server does not accept room invites")
spam_check = await self.spam_checker.user_may_invite(
spam_check = await self._spam_checker_module_callbacks.user_may_invite(
event.sender, event.state_key, event.room_id
)
if spam_check != NOT_SPAM:
......@@ -1253,7 +1272,7 @@ class FederationHandler:
unpersisted_context,
) = await self.event_creation_handler.create_new_client_event(builder=builder)
event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
event_allowed, _ = await self._third_party_event_rules.check_event_allowed(
event, unpersisted_context
)
if not event_allowed:
......@@ -1436,7 +1455,7 @@ class FederationHandler:
room_version_obj, event_dict
)
EventValidator().validate_builder(builder)
EventValidator().validate_builder(builder, self.hs.config)
# Try several times, it could fail with PartialStateConflictError
# in send_membership_event, cf comment in except block.
......@@ -1488,7 +1507,6 @@ class FederationHandler:
# in the meantime and context needs to be recomputed, so let's do so.
if i == max_retries - 1:
raise e
pass
else:
destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)}
......@@ -1564,7 +1582,6 @@ class FederationHandler:
# in the meantime and context needs to be recomputed, so let's do so.
if i == max_retries - 1:
raise e
pass
async def add_display_name_to_third_party_invite(
self,
......@@ -1578,9 +1595,7 @@ class FederationHandler:
event.content["third_party_invite"]["signed"]["token"],
)
original_invite = None
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types([(EventTypes.ThirdPartyInvite, None)])
)
prev_state_ids = await context.get_prev_state_ids(StateFilter.from_types([key]))
original_invite_id = prev_state_ids.get(key)
if original_invite_id:
original_invite = await self.store.get_event(
......@@ -1605,7 +1620,7 @@ class FederationHandler:
builder = self.event_builder_factory.for_room_version(
room_version_obj, event_dict
)
EventValidator().validate_builder(builder)
EventValidator().validate_builder(builder, self.hs.config)
(
event,
......@@ -1633,7 +1648,7 @@ class FederationHandler:
token = signed["token"]
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types([(EventTypes.ThirdPartyInvite, None)])
StateFilter.from_types([(EventTypes.ThirdPartyInvite, token)])
)
invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))
......
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import itertools
......@@ -70,7 +77,9 @@ from synapse.logging.opentracing import (
trace,
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.replication.http.devices import (
ReplicationMultiUserDevicesResyncRestServlet,
)
from synapse.replication.http.federation import (
ReplicationFederationSendEventsRestServlet,
)
......@@ -86,7 +95,7 @@ from synapse.types import (
)
from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.iterutils import batch_iter
from synapse.util.iterutils import batch_iter, partition, sorted_topologically
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import shortstr
......@@ -142,6 +151,8 @@ class FederationEventHandler:
def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock()
self._store = hs.get_datastores().main
self._state_store = hs.get_datastores().state
self._state_deletion_store = hs.get_datastores().state_deletion
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
......@@ -155,10 +166,13 @@ class FederationEventHandler:
self._get_room_member_handler = hs.get_room_member_handler
self._federation_client = hs.get_federation_client()
self._third_party_event_rules = hs.get_third_party_event_rules()
self._third_party_event_rules = (
hs.get_module_api_callbacks().third_party_event_rules
)
self._notifier = hs.get_notifier()
self._is_mine_id = hs.is_mine_id
self._is_mine_server_name = hs.is_mine_server_name
self._server_name = hs.hostname
self._instance_name = hs.get_instance_name()
......@@ -167,8 +181,8 @@ class FederationEventHandler:
self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
if hs.config.worker.worker_app:
self._user_device_resync = (
ReplicationUserDevicesResyncRestServlet.make_client(hs)
self._multi_user_device_resync = (
ReplicationMultiUserDevicesResyncRestServlet.make_client(hs)
)
else:
self._device_list_updater = hs.get_device_handler().device_list_updater
......@@ -568,7 +582,9 @@ class FederationEventHandler:
room_version.identifier,
state_maps_to_resolve,
event_map=None,
state_res_store=StateResolutionStore(self._store),
state_res_store=StateResolutionStore(
self._store, self._state_deletion_store
),
)
)
else:
......@@ -596,18 +612,6 @@ class FederationEventHandler:
room_id, [(event, context)]
)
# If we're joining the room again, check if there is new marker
# state indicating that there is new history imported somewhere in
# the DAG. Multiple markers can exist in the current state with
# unique state_keys.
#
# Do this after the state from the remote join was persisted (via
# `persist_events_and_notify`). Otherwise we can run into a
# situation where the create event doesn't exist yet in the
# `current_state_events`
for e in state:
await self._handle_marker_event(origin, e)
return stream_id_after_persist
async def update_state_for_partial_state_event(
......@@ -684,7 +688,7 @@ class FederationEventHandler:
server from invalid events (there is probably no point in trying to
re-fetch invalid events from every other HS in the room.)
"""
if dest == self._server_name:
if self._is_mine_server_name(dest):
raise SynapseError(400, "Can't backfill from self.")
events = await self._federation_client.backfill(
......@@ -730,12 +734,11 @@ class FederationEventHandler:
if not prevs - seen:
return
latest_list = await self._store.get_latest_event_ids_in_room(room_id)
latest_frozen = await self._store.get_latest_event_ids_in_room(room_id)
# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
latest = set(latest_list)
latest |= seen
latest = seen | latest_frozen
logger.info(
"Requesting missing events between %s and %s",
......@@ -756,7 +759,7 @@ class FederationEventHandler:
# fetching fresh state for the room if the missing event
# can't be found, which slightly reduces our security.
# it may also increase our DAG extremity count for the room,
# causing additional state resolution? See #1760.
# causing additional state resolution? See https://github.com/matrix-org/synapse/issues/1760.
# However, fetching state doesn't hold the linearizer lock
# apparently.
#
......@@ -860,7 +863,7 @@ class FederationEventHandler:
[event.event_id for event in events]
)
new_events = []
new_events: List[EventBase] = []
for event in events:
event_id = event.event_id
......@@ -885,12 +888,64 @@ class FederationEventHandler:
# Continue on with the events that are new to us.
new_events.append(event)
# We want to sort these by depth so we process them and
# tell clients about them in order.
sorted_events = sorted(new_events, key=lambda x: x.depth)
for ev in sorted_events:
with nested_logging_context(ev.event_id):
await self._process_pulled_event(origin, ev, backfilled=backfilled)
set_tag(
SynapseTags.RESULT_PREFIX + "new_events.length",
str(len(new_events)),
)
@trace
async def _process_new_pulled_events(new_events: Collection[EventBase]) -> None:
# We want to sort these by depth so we process them and tell clients about
# them in order. It's also more efficient to backfill this way (`depth`
# ascending) because one backfill event is likely to be the `prev_event` of
# the next event we're going to process.
sorted_events = sorted(new_events, key=lambda x: x.depth)
for ev in sorted_events:
with nested_logging_context(ev.event_id):
await self._process_pulled_event(origin, ev, backfilled=backfilled)
# Check if we've already tried to process these events at some point in the
# past. We aren't concerned with the expontntial backoff here, just whether it
# has failed to be processed before.
event_ids_with_failed_pull_attempts = (
await self._store.get_event_ids_with_failed_pull_attempts(
[event.event_id for event in new_events]
)
)
events_with_failed_pull_attempts, fresh_events = partition(
new_events, lambda e: e.event_id in event_ids_with_failed_pull_attempts
)
set_tag(
SynapseTags.FUNC_ARG_PREFIX + "events_with_failed_pull_attempts",
str(event_ids_with_failed_pull_attempts),
)
set_tag(
SynapseTags.RESULT_PREFIX + "events_with_failed_pull_attempts.length",
str(len(events_with_failed_pull_attempts)),
)
set_tag(
SynapseTags.FUNC_ARG_PREFIX + "fresh_events",
str([event.event_id for event in fresh_events]),
)
set_tag(
SynapseTags.RESULT_PREFIX + "fresh_events.length",
str(len(fresh_events)),
)
# Process previously failed backfill events in the background to not waste
# time on something that is likely to fail again.
if len(events_with_failed_pull_attempts) > 0:
run_as_background_process(
"_process_new_pulled_events_with_failed_pull_attempts",
_process_new_pulled_events,
events_with_failed_pull_attempts,
)
# We can optimistically try to process and wait for the event to be fully
# persisted if we've never tried before.
if len(fresh_events) > 0:
await _process_new_pulled_events(fresh_events)
@trace
@tag_args
......@@ -1091,16 +1146,8 @@ class FederationEventHandler:
partial_state_flags = await self._store.get_partial_state_events(seen)
partial_state = any(partial_state_flags.values())
# Get the state of the events we know about
ours = await self._state_storage_controller.get_state_groups_ids(
room_id, seen, await_full_state=False
)
# state_maps is a list of mappings from (type, state_key) to event_id
state_maps: List[StateMap[str]] = list(ours.values())
# we don't need this any more, let's delete it.
del ours
state_maps: List[StateMap[str]] = []
# Ask the remote server for the states we don't
# know about
......@@ -1119,13 +1166,26 @@ class FederationEventHandler:
state_maps.append(remote_state_map)
# Get the state of the events we know about. We do this *after*
# trying to fetch missing state over federation as that might fail
# and then we can skip loading the local state.
ours = await self._state_storage_controller.get_state_groups_ids(
room_id, seen, await_full_state=False
)
state_maps.extend(ours.values())
# we don't need this any more, let's delete it.
del ours
room_version = await self._store.get_room_version_id(room_id)
state_map = await self._state_resolution_handler.resolve_events_with_store(
room_id,
room_version,
state_maps,
event_map={event_id: event},
state_res_store=StateResolutionStore(self._store),
state_res_store=StateResolutionStore(
self._store, self._state_deletion_store
),
)
except Exception as e:
......@@ -1313,9 +1373,9 @@ class FederationEventHandler:
)
if remote_event.is_state() and remote_event.rejected_reason is None:
state_map[
(remote_event.type, remote_event.state_key)
] = remote_event.event_id
state_map[(remote_event.type, remote_event.state_key)] = (
remote_event.event_id
)
return state_map
......@@ -1396,8 +1456,6 @@ class FederationEventHandler:
await self._run_push_actions_and_persist_event(event, context, backfilled)
await self._handle_marker_event(origin, event)
if backfilled or context.rejected:
return
......@@ -1487,99 +1545,16 @@ class FederationEventHandler:
# Immediately attempt a resync in the background
if self._config.worker.worker_app:
await self._user_device_resync(user_id=sender)
await self._multi_user_device_resync(user_ids=[sender])
else:
await self._device_list_updater.user_device_resync(sender)
await self._device_list_updater.multi_user_device_resync(
user_ids=[sender]
)
except Exception:
logger.exception("Failed to resync device for %s", sender)
@trace
async def _handle_marker_event(self, origin: str, marker_event: EventBase) -> None:
"""Handles backfilling the insertion event when we receive a marker
event that points to one.
Args:
origin: Origin of the event. Will be called to get the insertion event
marker_event: The event to process
"""
if marker_event.type != EventTypes.MSC2716_MARKER:
# Not a marker event
return
if marker_event.rejected_reason is not None:
# Rejected event
return
# Skip processing a marker event if the room version doesn't
# support it or the event is not from the room creator.
room_version = await self._store.get_room_version(marker_event.room_id)
create_event = await self._store.get_create_event_for_room(marker_event.room_id)
room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR)
if not room_version.msc2716_historical and (
not self._config.experimental.msc2716_enabled
or marker_event.sender != room_creator
):
return
logger.debug("_handle_marker_event: received %s", marker_event)
insertion_event_id = marker_event.content.get(
EventContentFields.MSC2716_INSERTION_EVENT_REFERENCE
)
if insertion_event_id is None:
# Nothing to retrieve then (invalid marker)
return
already_seen_insertion_event = await self._store.have_seen_event(
marker_event.room_id, insertion_event_id
)
if already_seen_insertion_event:
# No need to process a marker again if we have already seen the
# insertion event that it was pointing to
return
logger.debug(
"_handle_marker_event: backfilling insertion event %s", insertion_event_id
)
await self._get_events_and_persist(
origin,
marker_event.room_id,
[insertion_event_id],
)
insertion_event = await self._store.get_event(
insertion_event_id, allow_none=True
)
if insertion_event is None:
logger.warning(
"_handle_marker_event: server %s didn't return insertion event %s for marker %s",
origin,
insertion_event_id,
marker_event.event_id,
)
return
logger.debug(
"_handle_marker_event: succesfully backfilled insertion event %s from marker event %s",
insertion_event,
marker_event,
)
await self._store.insert_insertion_extremity(
insertion_event_id, marker_event.room_id
)
logger.debug(
"_handle_marker_event: insertion extremity added for %s from marker event %s",
insertion_event,
marker_event,
)
async def backfill_event_id(
self, destinations: List[str], room_id: str, event_id: str
self, destinations: StrCollection, room_id: str, event_id: str
) -> PulledPduInfo:
"""Backfill a single event and persist it as a non-outlier which means
we also pull in all of the state and auth events necessary for it.
......@@ -1710,64 +1685,39 @@ class FederationEventHandler:
# XXX: it might be possible to kick this process off in parallel with fetching
# the events.
while event_map:
# build a list of events whose auth events are not in the queue.
roots = tuple(
ev
for ev in event_map.values()
if not any(aid in event_map for aid in ev.auth_event_ids())
)
if not roots:
# if *none* of the remaining events are ready, that means
# we have a loop. This either means a bug in our logic, or that
# somebody has managed to create a loop (which requires finding a
# hash collision in room v2 and later).
logger.warning(
"Loop found in auth events while fetching missing state/auth "
"events: %s",
shortstr(event_map.keys()),
)
return
logger.info(
"Persisting %i of %i remaining outliers: %s",
len(roots),
len(event_map),
shortstr(e.event_id for e in roots),
)
await self._auth_and_persist_outliers_inner(room_id, roots)
for ev in roots:
del event_map[ev.event_id]
async def _auth_and_persist_outliers_inner(
self, room_id: str, fetched_events: Collection[EventBase]
) -> None:
"""Helper for _auth_and_persist_outliers
Persists a batch of events where we have (theoretically) already persisted all
of their auth events.
Marks the events as outliers, auths them, persists them to the database, and,
where appropriate (eg, an invite), awakes the notifier.
# We need to persist an event's auth events before the event.
auth_graph = {
ev.event_id: [e_id for e_id in ev.auth_event_ids() if e_id in event_map]
for ev in event_map.values()
}
sorted_auth_event_ids = sorted_topologically(event_map.keys(), auth_graph)
sorted_auth_events = [event_map[e_id] for e_id in sorted_auth_event_ids]
logger.info(
"Persisting %i remaining outliers: %s",
len(sorted_auth_events),
shortstr(e.event_id for e in sorted_auth_events),
)
Params:
origin: where the events came from
room_id: the room that the events are meant to be in (though this has
not yet been checked)
fetched_events: the events to persist
"""
# get all the auth events for all the events in this batch. By now, they should
# have been persisted.
auth_events = {
aid for event in fetched_events for aid in event.auth_event_ids()
auth_event_ids = {
aid for event in sorted_auth_events for aid in event.auth_event_ids()
}
persisted_events = await self._store.get_events(
auth_events,
allow_rejected=True,
)
auth_map = {
ev.event_id: ev
for ev in sorted_auth_events
if ev.event_id in auth_event_ids
}
missing_events = auth_event_ids.difference(auth_map)
if missing_events:
persisted_events = await self._store.get_events(
missing_events,
allow_rejected=True,
redact_behaviour=EventRedactBehaviour.as_is,
)
auth_map.update(persisted_events)
events_and_contexts_to_persist: List[Tuple[EventBase, EventContext]] = []
......@@ -1775,7 +1725,7 @@ class FederationEventHandler:
with nested_logging_context(suffix=event.event_id):
auth = []
for auth_event_id in event.auth_event_ids():
ae = persisted_events.get(auth_event_id)
ae = auth_map.get(auth_event_id)
if not ae:
# the fact we can't find the auth event doesn't mean it doesn't
# exist, which means it is premature to reject `event`. Instead we
......@@ -1794,7 +1744,9 @@ class FederationEventHandler:
context = EventContext.for_outlier(self._storage_controllers)
try:
validate_event_for_room_version(event)
await check_state_independent_auth_rules(self._store, event)
await check_state_independent_auth_rules(
self._store, event, batched_auth_events=auth_map
)
check_state_dependent_auth_rules(event, auth)
except AuthError as e:
logger.warning("Rejecting %r because %s", event, e)
......@@ -1811,17 +1763,25 @@ class FederationEventHandler:
events_and_contexts_to_persist.append((event, context))
for event in fetched_events:
for i, event in enumerate(sorted_auth_events):
await prep(event)
await self.persist_events_and_notify(
room_id,
events_and_contexts_to_persist,
# Mark these events backfilled as they're historic events that will
# eventually be backfilled. For example, missing events we fetch
# during backfill should be marked as backfilled as well.
backfilled=True,
)
# The above function is typically not async, and so won't yield to
# the reactor. For large rooms let's yield to the reactor
# occasionally to ensure we don't block other work.
if (i + 1) % 1000 == 0:
await self._clock.sleep(0)
# Also persist the new event in batches for similar reasons as above.
for batch in batch_iter(events_and_contexts_to_persist, 1000):
await self.persist_events_and_notify(
room_id,
batch,
# Mark these events as backfilled as they're historic events that will
# eventually be backfilled. For example, missing events we fetch
# during backfill should be marked as backfilled as well.
backfilled=True,
)
@trace
async def _check_event_auth(
......@@ -1920,7 +1880,9 @@ class FederationEventHandler:
room_version,
[local_state_id_map, claimed_auth_events_id_map],
event_map=None,
state_res_store=StateResolutionStore(self._store),
state_res_store=StateResolutionStore(
self._store, self._state_deletion_store
),
)
)
else:
......@@ -2016,8 +1978,7 @@ class FederationEventHandler:
# partial and full state and may not be accurate.
return
extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id)
extrem_ids = set(extrem_ids_list)
extrem_ids = await self._store.get_latest_event_ids_in_room(event.room_id)
prev_event_ids = set(event.prev_event_ids())
if extrem_ids == prev_event_ids:
......@@ -2061,7 +2022,9 @@ class FederationEventHandler:
room_version,
state_sets,
event_map=None,
state_res_store=StateResolutionStore(self._store),
state_res_store=StateResolutionStore(
self._store, self._state_deletion_store
),
)
)
else:
......@@ -2319,8 +2282,9 @@ class FederationEventHandler:
event_and_contexts, backfilled=backfilled
)
# After persistence we always need to notify replication there may
# be new data.
# After persistence, we never notify clients (wake up `/sync` streams) about
# backfilled events but it's important to let all the workers know about any
# new event (backfilled or not) because TODO
self._notifier.notify_replication()
if self._ephemeral_messages_enabled:
......@@ -2384,6 +2348,12 @@ class FederationEventHandler:
# TODO retrieve the previous state, and exclude join -> join transitions
self._notifier.notify_user_joined_room(event.event_id, event.room_id)
# If this is a server ACL event, clear the cache in the storage controller.
if event.type == EventTypes.ServerACL:
self._state_storage_controller.get_server_acl_for_room.invalidate(
(event.room_id,)
)
def _sanity_check_event(self, ev: EventBase) -> None:
"""
Do some early sanity checks of a received event
......
# Copyright 2015, 2016 OpenMarket Ltd
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2017 Vector Creations Ltd
# Copyright 2018 New Vector Ltd
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for interacting with Identity Servers"""
import logging
import urllib.parse
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple
import attr
from synapse.api.errors import (
CodeMessageException,
Codes,
......@@ -52,10 +61,10 @@ class IdentityHandler:
# An HTTP client for contacting trusted URLs.
self.http_client = SimpleHttpClient(hs)
# An HTTP client for contacting identity servers specified by clients.
self.blacklisting_http_client = SimpleHttpClient(
self._http_client = SimpleHttpClient(
hs,
ip_blacklist=hs.config.server.federation_ip_range_blacklist,
ip_whitelist=hs.config.server.federation_ip_range_whitelist,
ip_blocklist=hs.config.server.federation_ip_range_blocklist,
ip_allowlist=hs.config.server.federation_ip_range_allowlist,
)
self.federation_http_client = hs.get_federation_http_client()
self.hs = hs
......@@ -66,14 +75,12 @@ class IdentityHandler:
self._3pid_validation_ratelimiter_ip = Ratelimiter(
store=self.store,
clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
cfg=hs.config.ratelimiting.rc_3pid_validation,
)
self._3pid_validation_ratelimiter_address = Ratelimiter(
store=self.store,
clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
cfg=hs.config.ratelimiting.rc_3pid_validation,
)
async def ratelimit_request_token_requests(
......@@ -197,7 +204,7 @@ class IdentityHandler:
try:
# Use the blacklisting http client as this call is only to identity servers
# provided by a client
data = await self.blacklisting_http_client.post_json_get_json(
data = await self._http_client.post_json_get_json(
bind_url, bind_data, headers=headers
)
......@@ -308,9 +315,7 @@ class IdentityHandler:
try:
# Use the blacklisting http client as this call is only to identity servers
# provided by a client
await self.blacklisting_http_client.post_json_get_json(
url, content, headers
)
await self._http_client.post_json_get_json(url, content, headers)
changed = True
except HttpResponseException as e:
changed = False
......@@ -361,9 +366,9 @@ class IdentityHandler:
# Check to see if a session already exists and that it is not yet
# marked as validated
if session and session.get("validated_at") is None:
session_id = session["session_id"]
last_send_attempt = session["last_send_attempt"]
if session and session.validated_at is None:
session_id = session.session_id
last_send_attempt = session.last_send_attempt
# Check that the send_attempt is higher than previous attempts
if send_attempt <= last_send_attempt:
......@@ -484,7 +489,6 @@ class IdentityHandler:
# We don't actually know which medium this 3PID is. Thus we first assume it's email,
# and if validation fails we try msisdn
validation_session = None
# Try to validate as email
if self.hs.config.email.can_verify_email:
......@@ -492,19 +496,18 @@ class IdentityHandler:
validation_session = await self.store.get_threepid_validation_session(
"email", client_secret, sid=sid, validated=True
)
if validation_session:
return validation_session
if validation_session:
return attr.asdict(validation_session)
# Try to validate as msisdn
if self.hs.config.registration.account_threepid_delegate_msisdn:
# Ask our delegated msisdn identity server
validation_session = await self.threepid_from_creds(
return await self.threepid_from_creds(
self.hs.config.registration.account_threepid_delegate_msisdn,
threepid_creds,
)
return validation_session
return None
async def proxy_msisdn_submit_token(
self, id_server: str, client_secret: str, sid: str, token: str
......@@ -579,7 +582,7 @@ class IdentityHandler:
"""
# Check what hashing details are supported by this identity server
try:
hash_details = await self.blacklisting_http_client.get_json(
hash_details = await self._http_client.get_json(
"%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server),
{"access_token": id_access_token},
)
......@@ -646,7 +649,7 @@ class IdentityHandler:
headers = {"Authorization": create_id_access_token_header(id_access_token)}
try:
lookup_results = await self.blacklisting_http_client.post_json_get_json(
lookup_results = await self._http_client.post_json_get_json(
"%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server),
{
"addresses": [lookup_value],
......@@ -752,7 +755,7 @@ class IdentityHandler:
url = "%s%s/_matrix/identity/v2/store-invite" % (id_server_scheme, id_server)
try:
data = await self.blacklisting_http_client.post_json_get_json(
data = await self._http_client.post_json_get_json(
url,
invite_config,
{"Authorization": create_id_access_token_header(id_access_token)},
......