Compare revisions (projects: maunium/synapse, leytilera/synapse)

Changes are shown as if the source revision was being merged into the target revision.
synapse/handlers/read_marker.py

+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
 # Copyright 2017 Vector Creations Ltd
+# Copyright (C) 2023 New Vector, Ltd
 #
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
 #
-#     http://www.apache.org/licenses/LICENSE-2.0
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
 #
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
 import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional

 from synapse.api.constants import ReceiptTypes
 from synapse.api.errors import SynapseError
 from synapse.util.async_helpers import Linearizer
+from synapse.types import JsonDict

 if TYPE_CHECKING:
     from synapse.server import HomeServer

@@ -27,13 +35,16 @@ logger = logging.getLogger(__name__)

 class ReadMarkerHandler:
     def __init__(self, hs: "HomeServer"):
-        self.server_name = hs.config.server.server_name
         self.store = hs.get_datastores().main
         self.account_data_handler = hs.get_account_data_handler()
         self.read_marker_linearizer = Linearizer(name="read_marker")

     async def received_client_read_marker(
-        self, room_id: str, user_id: str, event_id: str
+        self,
+        room_id: str,
+        user_id: str,
+        event_id: str,
+        extra_content: Optional[JsonDict] = None,
     ) -> None:
         """Updates the read marker for a given user in a given room if the event ID given
         is ahead in the stream relative to the current read marker.

@@ -49,12 +60,12 @@ class ReadMarkerHandler:
             should_update = True

             # Get event ordering, this also ensures we know about the event
-            event_ordering = await self.store.get_event_ordering(event_id)
+            event_ordering = await self.store.get_event_ordering(event_id, room_id)

             if existing_read_marker:
                 try:
                     old_event_ordering = await self.store.get_event_ordering(
-                        existing_read_marker["event_id"]
+                        existing_read_marker["event_id"], room_id
                     )
                 except SynapseError:
                     # Old event no longer exists, assume new is ahead. This may

@@ -65,7 +76,7 @@ class ReadMarkerHandler:
                     should_update = event_ordering > old_event_ordering

             if should_update:
-                content = {"event_id": event_id}
+                content = {"event_id": event_id, **(extra_content or {})}
                 await self.account_data_handler.add_account_data_to_room(
                     user_id, room_id, ReceiptTypes.FULLY_READ, content
                 )
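A minimal sketch (with a hypothetical event ID and a hypothetical custom field, not taken from the diff) of how the handler above folds a client-supplied extra_content mapping into the m.fully_read account data content:

    # Hypothetical values; only the merge pattern mirrors the handler above.
    event_id = "$abc123"
    extra_content = {"com.example.marker_origin": "mobile"}
    content = {"event_id": event_id, **(extra_content or {})}
    assert content == {"event_id": "$abc123", "com.example.marker_origin": "mobile"}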
synapse/handlers/receipts.py

+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright (C) 2023 New Vector, Ltd
 #
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
 #
-#     http://www.apache.org/licenses/LICENSE-2.0
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
 #
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
 import logging
 from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Tuple

@@ -19,6 +26,8 @@ from synapse.appservice import ApplicationService
 from synapse.streams import EventSource
 from synapse.types import (
     JsonDict,
+    JsonMapping,
+    MultiWriterStreamToken,
     ReadReceipt,
     StreamKeyType,
     UserID,

@@ -37,6 +46,8 @@ class ReceiptsHandler:
         self.server_name = hs.config.server.server_name
         self.store = hs.get_datastores().main
         self.event_auth_handler = hs.get_event_auth_handler()
+        self.event_handler = hs.get_event_handler()
+        self._storage_controllers = hs.get_storage_controllers()
         self.hs = hs

@@ -81,6 +92,20 @@ class ReceiptsHandler:
                )
                continue

+            # Let's check that the origin server is in the room before accepting the receipt.
+            # We don't want to block waiting on a partial state so take an
+            # approximation if needed.
+            domains = await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation(
+                room_id
+            )
+            if origin not in domains:
+                logger.info(
+                    "Ignoring receipt for room %r from server %s as they're not in the room",
+                    room_id,
+                    origin,
+                )
+                continue
+
            for receipt_type, users in room_values.items():
                for user_id, user_values in users.items():
                    if get_domain_from_id(user_id) != origin:
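A minimal sketch of the guard added above, with hypothetical host names and a plain set standing in for the partial-state approximation returned by the storage controller:

    # Hypothetical data; mirrors only the membership gate, not the storage lookup.
    hosts_in_room = {"example.org", "matrix.org"}

    def accept_receipt_from(origin: str) -> bool:
        return origin in hosts_in_room

    assert accept_receipt_from("matrix.org")
    assert not accept_receipt_from("not-in-room.example.net")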
@@ -113,11 +138,10 @@ class ReceiptsHandler:
     async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
         """Takes a list of receipts, stores them and informs the notifier."""
-        min_batch_id: Optional[int] = None
-        max_batch_id: Optional[int] = None
+        receipts_persisted: List[ReadReceipt] = []

         for receipt in receipts:
-            res = await self.store.insert_receipt(
+            stream_id = await self.store.insert_receipt(
                 receipt.room_id,
                 receipt.receipt_type,
                 receipt.user_id,

@@ -126,30 +150,26 @@ class ReceiptsHandler:
                 receipt.data,
             )

-            if not res:
-                # res will be None if this receipt is 'old'
+            if stream_id is None:
+                # stream_id will be None if this receipt is 'old'
                 continue

-            stream_id, max_persisted_id = res
-            if min_batch_id is None or stream_id < min_batch_id:
-                min_batch_id = stream_id
-            if max_batch_id is None or max_persisted_id > max_batch_id:
-                max_batch_id = max_persisted_id
+            receipts_persisted.append(receipt)

-        # Either both of these should be None or neither.
-        if min_batch_id is None or max_batch_id is None:
+        if not receipts_persisted:
             # no new receipts
             return False

-        affected_room_ids = list({r.room_id for r in receipts})
+        max_batch_id = self.store.get_max_receipt_stream_id()
+
+        affected_room_ids = list({r.room_id for r in receipts_persisted})

         self.notifier.on_new_event(
             StreamKeyType.RECEIPT, max_batch_id, rooms=affected_room_ids
         )

-        # Note that the min here shouldn't be relied upon to be accurate.
         await self.hs.get_pusherpool().on_new_receipts(
-            min_batch_id, max_batch_id, affected_room_ids
+            {r.user_id for r in receipts_persisted}
         )

         return True
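A minimal, dependency-free sketch of the rewritten batching above (a namedtuple stands in for synapse.types.ReadReceipt; room and user IDs are hypothetical):

    from collections import namedtuple

    Receipt = namedtuple("Receipt", ["room_id", "user_id"])

    receipts_persisted = [
        Receipt("!room1:example.org", "@alice:example.org"),
        Receipt("!room1:example.org", "@bob:example.org"),
        Receipt("!room2:example.org", "@alice:example.org"),
    ]

    # Rooms to wake the notifier for, and users whose pushers should be checked.
    affected_room_ids = list({r.room_id for r in receipts_persisted})
    pusher_user_ids = {r.user_id for r in receipts_persisted}
    assert set(affected_room_ids) == {"!room1:example.org", "!room2:example.org"}
    assert pusher_user_ids == {"@alice:example.org", "@bob:example.org"}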
@@ -158,20 +178,27 @@ class ReceiptsHandler:
         self,
         room_id: str,
         receipt_type: str,
-        user_id: str,
+        user_id: UserID,
         event_id: str,
         thread_id: Optional[str],
+        extra_content: Optional[JsonDict] = None,
     ) -> None:
         """Called when a client tells us a local user has read up to the given
         event_id in the room.
         """
+        # Ensure the room/event exists, this will raise an error if the user
+        # cannot view the event.
+        if not await self.event_handler.get_event(user_id, room_id, event_id):
+            return
+
         receipt = ReadReceipt(
             room_id=room_id,
             receipt_type=receipt_type,
-            user_id=user_id,
+            user_id=user_id.to_string(),
             event_ids=[event_id],
             thread_id=thread_id,
-            data={"ts": int(self.clock.time_msec())},
+            data={"ts": int(self.clock.time_msec()), **(extra_content or {})},
         )

         is_new = await self._handle_new_receipts([receipt])

@@ -182,15 +209,15 @@ class ReceiptsHandler:
         await self.federation_sender.send_read_receipt(receipt)


-class ReceiptEventSource(EventSource[int, JsonDict]):
+class ReceiptEventSource(EventSource[MultiWriterStreamToken, JsonMapping]):
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
         self.config = hs.config

     @staticmethod
     def filter_out_private_receipts(
-        rooms: Sequence[JsonDict], user_id: str
-    ) -> List[JsonDict]:
+        rooms: Sequence[JsonMapping], user_id: str
+    ) -> List[JsonMapping]:
         """
         Filters a list of serialized receipts (as returned by /sync and /initialSync)
         and removes private read receipts of other users.

@@ -207,7 +234,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
             The same as rooms, but filtered.
         """

-        result = []
+        result: List[JsonMapping] = []

         # Iterate through each room's receipt content.
         for room in rooms:

@@ -255,14 +282,19 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
     async def get_new_events(
         self,
         user: UserID,
-        from_key: int,
+        from_key: MultiWriterStreamToken,
         limit: int,
         room_ids: Iterable[str],
         is_guest: bool,
         explicit_room_id: Optional[str] = None,
-    ) -> Tuple[List[JsonDict], int]:
-        from_key = int(from_key)
-        to_key = self.get_current_key()
+        to_key: Optional[MultiWriterStreamToken] = None,
+    ) -> Tuple[List[JsonMapping], MultiWriterStreamToken]:
+        """
+        Find read receipts for given rooms (> `from_token` and <= `to_token`)
+        """
+
+        if to_key is None:
+            to_key = self.get_current_key()

         if from_key == to_key:
             return [], to_key

@@ -278,8 +310,11 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
         return events, to_key

     async def get_new_events_as(
-        self, from_key: int, to_key: int, service: ApplicationService
-    ) -> Tuple[List[JsonDict], int]:
+        self,
+        from_key: MultiWriterStreamToken,
+        to_key: MultiWriterStreamToken,
+        service: ApplicationService,
+    ) -> Tuple[List[JsonMapping], MultiWriterStreamToken]:
         """Returns a set of new read receipt events that an appservice
         may be interested in.

@@ -294,8 +329,6 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
                  appservice may be interested in.
                * The current read receipt stream token.
         """
-        from_key = int(from_key)
-
         if from_key == to_key:
             return [], to_key

@@ -315,5 +348,5 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
         return events, to_key

-    def get_current_key(self) -> int:
+    def get_current_key(self) -> MultiWriterStreamToken:
         return self.store.get_max_receipt_stream_id()
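A minimal sketch of the half-open interval documented in the new get_new_events docstring, using plain integers as stand-ins for MultiWriterStreamToken:

    # Hypothetical stream positions; only the (> from_key, <= to_key) semantics
    # and the equal-token short-circuit mirror get_new_events above.
    def new_receipts_between(receipts_by_pos, from_key, to_key):
        if from_key == to_key:
            return []
        return [r for pos, r in receipts_by_pos if from_key < pos <= to_key]

    receipts_by_pos = [(1, "r1"), (2, "r2"), (3, "r3")]
    assert new_receipts_between(receipts_by_pos, 1, 3) == ["r2", "r3"]
    assert new_receipts_between(receipts_by_pos, 2, 2) == []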
synapse/handlers/register.py

-# Copyright 2014 - 2016 OpenMarket Ltd
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
 # Copyright 2021 The Matrix.org Foundation C.I.C.
+# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright (C) 2023 New Vector, Ltd
 #
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
 #
-#     http://www.apache.org/licenses/LICENSE-2.0
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
 #
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
"""Contains functions for registering clients.""" """Contains functions for registering clients."""
import logging import logging
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, TypedDict
from prometheus_client import Counter from prometheus_client import Counter
from typing_extensions import TypedDict
from synapse import types from synapse import types
from synapse.api.constants import ( from synapse.api.constants import (
...@@ -141,27 +147,25 @@ class RegistrationHandler: ...@@ -141,27 +147,25 @@ class RegistrationHandler:
localpart: str, localpart: str,
guest_access_token: Optional[str] = None, guest_access_token: Optional[str] = None,
assigned_user_id: Optional[str] = None, assigned_user_id: Optional[str] = None,
allow_invalid: bool = False,
inhibit_user_in_use_error: bool = False, inhibit_user_in_use_error: bool = False,
) -> None: ) -> None:
if types.contains_invalid_mxid_characters( # meow: allow admins to register invalid user ids
localpart, self.hs.config.experimental.msc4009_e164_mxids if not allow_invalid:
): if types.contains_invalid_mxid_characters(localpart):
extra_chars = ( raise SynapseError(
"=_-./+" if self.hs.config.experimental.msc4009_e164_mxids else "=_-./" 400,
) "User ID can only contain characters a-z, 0-9, or '=_-./+'",
raise SynapseError( Codes.INVALID_USERNAME,
400, )
f"User ID can only contain characters a-z, 0-9, or '{extra_chars}'",
Codes.INVALID_USERNAME,
)
if not localpart: if not localpart:
raise SynapseError(400, "User ID cannot be empty", Codes.INVALID_USERNAME) raise SynapseError(400, "User ID cannot be empty", Codes.INVALID_USERNAME)
if localpart[0] == "_": if localpart[0] == "_":
raise SynapseError( raise SynapseError(
400, "User ID may not begin with _", Codes.INVALID_USERNAME 400, "User ID may not begin with _", Codes.INVALID_USERNAME
) )
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
...@@ -175,14 +179,16 @@ class RegistrationHandler: ...@@ -175,14 +179,16 @@ class RegistrationHandler:
"A different user ID has already been registered for this session", "A different user ID has already been registered for this session",
) )
self.check_user_id_not_appservice_exclusive(user_id) # meow: allow admins to register reserved user ids and long user ids
if not allow_invalid:
self.check_user_id_not_appservice_exclusive(user_id)
if len(user_id) > MAX_USERID_LENGTH: if len(user_id) > MAX_USERID_LENGTH:
raise SynapseError( raise SynapseError(
400, 400,
"User ID may not be longer than %s characters" % (MAX_USERID_LENGTH,), "User ID may not be longer than %s characters" % (MAX_USERID_LENGTH,),
Codes.INVALID_USERNAME, Codes.INVALID_USERNAME,
) )
users = await self.store.get_users_by_id_case_insensitive(user_id) users = await self.store.get_users_by_id_case_insensitive(user_id)
if users: if users:
...@@ -288,7 +294,12 @@ class RegistrationHandler: ...@@ -288,7 +294,12 @@ class RegistrationHandler:
await self.auth_blocking.check_auth_blocking(threepid=threepid) await self.auth_blocking.check_auth_blocking(threepid=threepid)
if localpart is not None: if localpart is not None:
await self.check_username(localpart, guest_access_token=guest_access_token) allow_invalid = by_admin and self.hs.config.meow.admin_api_register_invalid
await self.check_username(
localpart,
guest_access_token=guest_access_token,
allow_invalid=allow_invalid,
)
was_guest = guest_access_token is not None was_guest = guest_access_token is not None
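A minimal sketch of the admin gate introduced above (standalone booleans stand in for the request context and the homeserver's meow config; the names mirror the diff but the function itself is hypothetical):

    def may_skip_username_checks(by_admin: bool, admin_api_register_invalid: bool) -> bool:
        # Validation is only skipped for admin-initiated registrations on a
        # homeserver that has opted in via config.
        return by_admin and admin_api_register_invalid

    assert may_skip_username_checks(True, True) is True
    assert may_skip_username_checks(True, False) is False
    assert may_skip_username_checks(False, True) is False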
@@ -315,7 +326,7 @@ class RegistrationHandler:
                 approved=approved,
             )

-            profile = await self.store.get_profileinfo(localpart)
+            profile = await self.store.get_profileinfo(user)
             await self.user_directory_handler.handle_local_profile_change(
                 user_id, profile
             )

@@ -588,7 +599,7 @@ class RegistrationHandler:
                 # moving away from bare excepts is a good thing to do.
                 logger.error("Failed to join new user to %r: %r", r, e)
             except Exception as e:
-                logger.error("Failed to join new user to %r: %r", r, e)
+                logger.error("Failed to join new user to %r: %r", r, e, exc_info=True)

     async def _auto_join_rooms(self, user_id: str) -> None:
         """Automatically joins users to auto join rooms - creating the room in the first place

@@ -628,7 +639,9 @@ class RegistrationHandler:
         """
         await self._auto_join_rooms(user_id)

-    async def appservice_register(self, user_localpart: str, as_token: str) -> str:
+    async def appservice_register(
+        self, user_localpart: str, as_token: str
+    ) -> Tuple[str, ApplicationService]:
         user = UserID(user_localpart, self.hs.hostname)
         user_id = user.to_string()
         service = self.store.get_app_service_by_token(as_token)

@@ -651,7 +664,7 @@ class RegistrationHandler:
             appservice_id=service_id,
             create_profile_with_displayname=user.localpart,
         )
-        return user_id
+        return (user_id, service)

     def check_user_id_not_appservice_exclusive(
         self, user_id: str, allowed_appservice: Optional[ApplicationService] = None
...
synapse/handlers/relations.py

+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
 # Copyright 2021 The Matrix.org Foundation C.I.C.
+# Copyright (C) 2023 New Vector, Ltd
 #
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
 #
-#     http://www.apache.org/licenses/LICENSE-2.0
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
 #
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
 import enum
 import logging
-from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, Optional
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    FrozenSet,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+)

 import attr

@@ -157,17 +174,23 @@ class RelationsHandler:
         now = self._clock.time_msec()
         serialize_options = SerializeEventConfig(requester=requester)
         return_value: JsonDict = {
-            "chunk": self._event_serializer.serialize_events(
+            "chunk": await self._event_serializer.serialize_events(
                 events,
                 now,
                 bundle_aggregations=aggregations,
                 config=serialize_options,
             ),
         }
+        if recurse:
+            return_value["recursion_depth"] = 3

         if include_original_event:
             # Do not bundle aggregations when retrieving the original event because
             # we want the content before relations are applied to it.
-            return_value["original_event"] = self._event_serializer.serialize_event(
+            return_value[
+                "original_event"
+            ] = await self._event_serializer.serialize_event(
                 event,
                 now,
                 bundle_aggregations=None,
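A minimal sketch (hypothetical response body) of the shape produced by the recursion change above: when the caller asked for recursive relations, the fixed recursion depth is advertised next to the serialized chunk.

    # Hypothetical serialized chunk; only the recursion_depth handling mirrors
    # the handler above.
    return_value = {"chunk": []}
    recurse = True
    if recurse:
        return_value["recursion_depth"] = 3
    assert return_value == {"chunk": [], "recursion_depth": 3}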
@@ -205,16 +228,22 @@ class RelationsHandler:
             event_id: The event IDs to look and redact relations of.
             initial_redaction_event: The redaction for the event referred to by
                 event_id.
-            relation_types: The types of relations to look for.
+            relation_types: The types of relations to look for. If "*" is in the list,
+                all related events will be redacted regardless of the type.

         Raises:
             ShadowBanError if the requester is shadow-banned
         """
-        related_event_ids = (
-            await self._main_store.get_all_relations_for_event_with_types(
-                event_id, relation_types
-            )
-        )
+        if "*" in relation_types:
+            related_event_ids = await self._main_store.get_all_relations_for_event(
+                event_id
+            )
+        else:
+            related_event_ids = (
+                await self._main_store.get_all_relations_for_event_with_types(
+                    event_id, relation_types
+                )
+            )

         for related_event_id in related_event_ids:
             try:
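A minimal sketch of the wildcard branching added above, with an in-memory mapping standing in for the two store calls (the relation types are real Matrix relation names, the event IDs are hypothetical):

    relations = {  # hypothetical: related event ID -> relation type
        "$edit": "m.replace",
        "$reaction": "m.annotation",
        "$reply": "m.thread",
    }

    def related_event_ids(relation_types):
        # "*" selects every related event, anything else filters by type.
        if "*" in relation_types:
            return set(relations)
        return {eid for eid, rel_type in relations.items() if rel_type in relation_types}

    assert related_event_ids(["*"]) == {"$edit", "$reaction", "$reply"}
    assert related_event_ids(["m.replace"]) == {"$edit"}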
@@ -239,7 +268,7 @@ class RelationsHandler:
     async def get_references_for_events(
         self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
-    ) -> Dict[str, List[_RelatedEvent]]:
+    ) -> Mapping[str, Sequence[_RelatedEvent]]:
         """Get a list of references to the given events.

         Args:

@@ -362,9 +391,9 @@ class RelationsHandler:
             # Attempt to find another event to use as the latest event.
             potential_events, _ = await self._main_store.get_relations_for_event(
+                room_id,
                 event_id,
                 event,
-                room_id,
                 RelationTypes.THREAD,
                 direction=Direction.FORWARDS,
             )

@@ -586,7 +615,7 @@ class RelationsHandler:
         )

         now = self._clock.time_msec()
-        serialized_events = self._event_serializer.serialize_events(
+        serialized_events = await self._event_serializer.serialize_events(
             events, now, bundle_aggregations=aggregations
         )
...