Compare revisions: maunium/synapse and leytilera/synapse

Changes are shown as if the source revision was being merged into the target revision.

Showing with 1502 additions and 659 deletions
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
import logging
from typing import Any, List, Optional, Pattern
......
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
from typing import Any, List, Set
......
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2017 New Vector Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
from typing import Any
......@@ -35,3 +41,4 @@ class UserDirectoryConfig(Config):
self.user_directory_search_prefer_local_users = user_directory_config.get(
"prefer_local_users", False
)
self.show_locked_users = user_directory_config.get("show_locked_users", False)
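For illustration, a hedged sketch of the mapping read_config receives once the YAML is parsed (only the option names visible above; values are examples):

# Hypothetical homeserver config fragment, post-YAML parsing:
config = {
    "user_directory": {
        "prefer_local_users": False,
        "show_locked_users": True,  # the new option added in this hunk
    }
}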
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
from typing import Any
from synapse.types import JsonDict
from ._base import Config
from ._base import Config, ConfigError, read_file
CONFLICTING_SHARED_SECRET_OPTS_ERROR = """\
You have configured both `turn_shared_secret` and `turn_shared_secret_path`.
These are mutually incompatible.
"""
class VoipConfig(Config):
section = "voip"
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
def read_config(
self, config: JsonDict, allow_secrets_in_config: bool, **kwargs: Any
) -> None:
self.turn_uris = config.get("turn_uris", [])
self.turn_shared_secret = config.get("turn_shared_secret")
if self.turn_shared_secret and not allow_secrets_in_config:
raise ConfigError(
"Config options that expect an in-line secret as value are disabled",
("turn_shared_secret",),
)
turn_shared_secret_path = config.get("turn_shared_secret_path")
if turn_shared_secret_path:
if self.turn_shared_secret:
raise ConfigError(CONFLICTING_SHARED_SECRET_OPTS_ERROR)
self.turn_shared_secret = read_file(
turn_shared_secret_path, ("turn_shared_secret_path",)
).strip()
self.turn_username = config.get("turn_username")
self.turn_password = config.get("turn_password")
self.turn_user_lifetime = self.parse_duration(
......
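To make the mutual exclusion above concrete, a small sketch (option names from the code; secret and path values are made up):

# Exactly one of the following may be set; configuring both raises
# CONFLICTING_SHARED_SECRET_OPTS_ERROR.
inline_secret = {"turn_shared_secret": "n0t4real5ecret"}  # rejected if allow_secrets_in_config is False
secret_from_file = {"turn_shared_secret_path": "/run/secrets/turn_secret"}  # contents read and stripped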
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
import argparse
import logging
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union
import attr
from pydantic import BaseModel, Extra, StrictBool, StrictInt, StrictStr
from synapse._pydantic_compat import (
BaseModel,
Extra,
StrictBool,
StrictInt,
StrictStr,
)
from synapse.config._base import (
Config,
ConfigError,
RoutableShardedWorkerHandlingConfig,
ShardedWorkerHandlingConfig,
read_file,
)
from synapse.config._util import parse_and_validate_mapping
from synapse.config.server import (
......@@ -39,6 +53,30 @@ The '%s' configuration option is deprecated and will be removed in a future
Synapse version. Please use ``%s: name_of_worker`` instead.
"""
_MISSING_MAIN_PROCESS_INSTANCE_MAP_DATA = """
Missing data for a worker to connect to the main process. Please include '%s' in the
`instance_map` declared in your shared yaml configuration as defined in configuration
documentation here:
`https://element-hq.github.io/synapse/latest/usage/configuration/config_documentation.html#instance_map`
"""
WORKER_REPLICATION_SETTING_DEPRECATED_MESSAGE = """
'%s' is no longer a supported worker setting, please place '%s' onto your shared
configuration under `main` inside the `instance_map`. See workers documentation here:
`https://element-hq.github.io/synapse/latest/workers.html#worker-configuration`
"""
CONFLICTING_WORKER_REPLICATION_SECRET_OPTS_ERROR = """\
Conflicting options 'worker_replication_secret' and
'worker_replication_secret_path' are both defined in config file.
"""
# This allows for a handy knob when it's time to change from 'master' to
# something with less 'history'
MAIN_PROCESS_INSTANCE_NAME = "master"
# Use this to adjust what the main process is known as in the yaml instance_map
MAIN_PROCESS_INSTANCE_MAP_NAME = "main"
logger = logging.getLogger(__name__)
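A hedged sketch of the shared configuration this module now expects: the deprecated worker_replication_* options are replaced by a "main" entry in instance_map (hosts and ports are illustrative):

shared_config = {
    "instance_map": {
        "main": {"host": "localhost", "port": 9093},     # MAIN_PROCESS_INSTANCE_MAP_NAME
        "worker1": {"host": "localhost", "port": 9094},
    }
}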
......@@ -75,7 +113,7 @@ class ConfigModel(BaseModel):
allow_mutation = False
class InstanceLocationConfig(ConfigModel):
class InstanceTcpLocationConfig(ConfigModel):
"""The host and port to talk to an instance via HTTP replication."""
host: StrictStr
......@@ -91,6 +129,23 @@ class InstanceLocationConfig(ConfigModel):
return f"{self.host}:{self.port}"
class InstanceUnixLocationConfig(ConfigModel):
"""The socket file to talk to an instance via HTTP replication."""
path: StrictStr
def scheme(self) -> str:
"""Hardcode a retrievable scheme"""
return "unix"
def netloc(self) -> str:
"""Nicely format the address location data"""
return f"{self.path}"
InstanceLocationConfig = Union[InstanceTcpLocationConfig, InstanceUnixLocationConfig]
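Given the Union above, either shape is valid for an instance_map entry; a minimal sketch (paths and ports made up):

tcp_entry = InstanceTcpLocationConfig(host="localhost", port=9094)
unix_entry = InstanceUnixLocationConfig(path="/run/synapse/replication.sock")
assert unix_entry.scheme() == "unix"
assert tcp_entry.netloc() == "localhost:9094"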
@attr.s
class WriterLocations:
"""Specifies the instances that write various streams.
......@@ -107,6 +162,8 @@ class WriterLocations:
can only be a single instance.
presence: The instances that write to the presence stream. Currently
can only be a single instance.
push_rules: The instances that write to the push stream. Currently
can only be a single instance.
"""
events: List[str] = attr.ib(
......@@ -133,6 +190,31 @@ class WriterLocations:
default=["master"],
converter=_instance_to_list_converter,
)
push_rules: List[str] = attr.ib(
default=["master"],
converter=_instance_to_list_converter,
)
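Routing the new stream is then a matter of naming its single writer; a hedged config sketch (worker name is made up, and the stream_writers key is assumed from Synapse's usual shared configuration):

# Illustrative shared-config fragment:
config_fragment = {
    "stream_writers": {
        "push_rules": ["push_writer1"],  # must resolve to exactly one instance
    }
}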
@attr.s(auto_attribs=True)
class OutboundFederationRestrictedTo:
"""Whether we limit outbound federation to a certain set of instances.
Attributes:
instances: optional list of instances that can make outbound federation
requests. If None then all instances can make federation requests.
locations: list of instance locations to connect to proxy via.
"""
instances: Optional[List[str]]
locations: List[InstanceLocationConfig] = attr.Factory(list)
def __contains__(self, instance: str) -> bool:
# It feels a bit dirty to return `True` if `instances` is `None`, but it makes
# sense in downstream usage in the sense that if
# `outbound_federation_restricted_to` is not configured, then any instance can
# talk to federation (no restrictions so always return `True`).
return self.instances is None or instance in self.instances
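The membership test above behaves like this (a small illustrative sketch):

unrestricted = OutboundFederationRestrictedTo(instances=None)
assert "any_worker" in unrestricted  # no restriction configured

restricted = OutboundFederationRestrictedTo(instances=["federation_sender1"])
assert "federation_sender1" in restricted
assert "client_reader1" not in restricted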
class WorkerConfig(Config):
......@@ -142,7 +224,9 @@ class WorkerConfig(Config):
section = "worker"
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
def read_config(
self, config: JsonDict, allow_secrets_in_config: bool, **kwargs: Any
) -> None:
self.worker_app = config.get("worker_app")
# Canonicalise worker_app so that master always has None
......@@ -161,27 +245,31 @@ class WorkerConfig(Config):
raise ConfigError("worker_log_config must be a string")
self.worker_log_config = worker_log_config
# The host used to connect to the main synapse
self.worker_replication_host = config.get("worker_replication_host", None)
# The port on the main synapse for TCP replication
if "worker_replication_port" in config:
raise ConfigError(DIRECT_TCP_ERROR, ("worker_replication_port",))
# The port on the main synapse for HTTP replication endpoint
self.worker_replication_http_port = config.get("worker_replication_http_port")
# The tls mode on the main synapse for HTTP replication endpoint.
# For backward compatibility this defaults to False.
self.worker_replication_http_tls = config.get(
"worker_replication_http_tls", False
)
# The shared secret used for authentication when connecting to the main synapse.
self.worker_replication_secret = config.get("worker_replication_secret", None)
worker_replication_secret = config.get("worker_replication_secret", None)
if worker_replication_secret and not allow_secrets_in_config:
raise ConfigError(
"Config options that expect an in-line secret as value are disabled",
("worker_replication_secret",),
)
worker_replication_secret_path = config.get(
"worker_replication_secret_path", None
)
if worker_replication_secret_path:
if worker_replication_secret:
raise ConfigError(CONFLICTING_WORKER_REPLICATION_SECRET_OPTS_ERROR)
self.worker_replication_secret = read_file(
worker_replication_secret_path, "worker_replication_secret_path"
).strip()
else:
self.worker_replication_secret = worker_replication_secret
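As with the TURN secret earlier, the replication secret can be supplied in-line or via a file; a brief sketch (values made up):

inline = {"worker_replication_secret": "n0t4real5ecret"}  # rejected if allow_secrets_in_config is False
from_file = {"worker_replication_secret_path": "/run/secrets/replication_secret"}  # contents read and stripped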
self.worker_name = config.get("worker_name", self.worker_app)
self.instance_name = self.worker_name or "master"
self.instance_name = self.worker_name or MAIN_PROCESS_INSTANCE_NAME
# FIXME: Remove this check after a suitable amount of time.
self.worker_main_http_uri = config.get("worker_main_http_uri", None)
......@@ -215,11 +303,60 @@ class WorkerConfig(Config):
)
# A map from instance name to host/port of their HTTP replication endpoint.
self.instance_map: Dict[
str, InstanceLocationConfig
] = parse_and_validate_mapping(
config.get("instance_map", {}),
InstanceLocationConfig,
# Check if the main process is declared. The main process itself doesn't need
# this data as it would never have to talk to itself.
instance_map: Dict[str, Any] = config.get("instance_map", {})
if self.instance_name is not MAIN_PROCESS_INSTANCE_NAME:
# TODO: The next 3 condition blocks can be deleted after some time has
# passed and we're ready to stop checking for these settings.
# The host used to connect to the main synapse
main_host = config.get("worker_replication_host", None)
if main_host:
raise ConfigError(
WORKER_REPLICATION_SETTING_DEPRECATED_MESSAGE
% ("worker_replication_host", main_host)
)
# The port on the main synapse for HTTP replication endpoint
main_port = config.get("worker_replication_http_port")
if main_port:
raise ConfigError(
WORKER_REPLICATION_SETTING_DEPRECATED_MESSAGE
% ("worker_replication_http_port", main_port)
)
# The tls mode on the main synapse for HTTP replication endpoint.
# For backward compatibility this defaults to False.
main_tls = config.get("worker_replication_http_tls", False)
if main_tls:
raise ConfigError(
WORKER_REPLICATION_SETTING_DEPRECATED_MESSAGE
% ("worker_replication_http_tls", main_tls)
)
# For now, accept 'main' in the instance_map, but the replication system
# expects 'master', force that into being until it's changed later.
if MAIN_PROCESS_INSTANCE_MAP_NAME in instance_map:
instance_map[MAIN_PROCESS_INSTANCE_NAME] = instance_map[
MAIN_PROCESS_INSTANCE_MAP_NAME
]
del instance_map[MAIN_PROCESS_INSTANCE_MAP_NAME]
else:
# If we've gotten here, it means that the main process is not on the
# instance_map.
raise ConfigError(
_MISSING_MAIN_PROCESS_INSTANCE_MAP_DATA
% MAIN_PROCESS_INSTANCE_MAP_NAME
)
# type-ignore: the expression `Union[A, B]` is not a Type[Union[A, B]] currently
self.instance_map: Dict[str, InstanceLocationConfig] = (
parse_and_validate_mapping(
instance_map,
InstanceLocationConfig, # type: ignore[arg-type]
)
)
# Map from type of streams to source, c.f. WriterLocations.
......@@ -235,6 +372,7 @@ class WorkerConfig(Config):
"account_data",
"receipts",
"presence",
"push_rules",
):
instances = _instance_to_list_converter(getattr(self.writers, stream))
for instance in instances:
......@@ -259,9 +397,9 @@ class WorkerConfig(Config):
"Must only specify one instance to handle `account_data` messages."
)
if len(self.writers.receipts) != 1:
if len(self.writers.receipts) == 0:
raise ConfigError(
"Must only specify one instance to handle `receipts` messages."
"Must specify at least one instance to handle `receipts` messages."
)
if len(self.writers.events) == 0:
......@@ -272,6 +410,11 @@ class WorkerConfig(Config):
"Must only specify one instance to handle `presence` messages."
)
if len(self.writers.push_rules) != 1:
raise ConfigError(
"Must only specify one instance to handle `push` messages."
)
self.events_shard_config = RoutableShardedWorkerHandlingConfig(
self.writers.events
)
......@@ -313,6 +456,28 @@ class WorkerConfig(Config):
new_option_name="update_user_directory_from_worker",
)
outbound_federation_restricted_to = config.get(
"outbound_federation_restricted_to", None
)
self.outbound_federation_restricted_to = OutboundFederationRestrictedTo(
outbound_federation_restricted_to
)
if outbound_federation_restricted_to:
if not self.worker_replication_secret:
raise ConfigError(
"`worker_replication_secret` must be configured when using `outbound_federation_restricted_to`."
)
for instance in outbound_federation_restricted_to:
if instance not in self.instance_map:
raise ConfigError(
"Instance %r is configured in 'outbound_federation_restricted_to' but does not appear in `instance_map` config."
% (instance,)
)
self.outbound_federation_restricted_to.locations.append(
self.instance_map[instance]
)
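Pulling the validation above together, a hedged sketch of a config that would pass all three checks (instance names, hosts, and ports are illustrative):

shared_config = {
    "worker_replication_secret": "n0t4real5ecret",  # required by the first check
    "instance_map": {
        "main": {"host": "localhost", "port": 9093},
        "federation_sender1": {"host": "localhost", "port": 9094},
    },
    "outbound_federation_restricted_to": ["federation_sender1"],  # must appear in instance_map
}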
def _should_this_worker_perform_duty(
self,
config: Dict[str, Any],
......
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
import logging
......
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
import collections.abc
import hashlib
......
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
import abc
import logging
......@@ -23,12 +30,7 @@ from signedjson.key import (
get_verify_key,
is_signing_algorithm_supported,
)
from signedjson.sign import (
SignatureVerifyException,
encode_canonical_json,
signature_ids,
verify_signed_json,
)
from signedjson.sign import SignatureVerifyException, signature_ids, verify_signed_json
from signedjson.types import VerifyKey
from unpaddedbase64 import decode_base64
......@@ -173,7 +175,7 @@ class Keyring:
process_batch_callback=self._inner_fetch_key_requests,
)
self._hostname = hs.hostname
self._is_mine_server_name = hs.is_mine_server_name
# build a FetchKeyResult for each of our own keys, to shortcircuit the
# fetcher.
......@@ -277,7 +279,7 @@ class Keyring:
# If we are the originating server, short-circuit the key-fetch for any keys
# we already have
if verify_request.server_name == self._hostname:
if self._is_mine_server_name(verify_request.server_name):
for key_id in verify_request.key_ids:
if key_id in self._local_verify_keys:
found_keys[key_id] = self._local_verify_keys[key_id]
......@@ -587,7 +589,7 @@ class BaseV2KeyFetcher(KeyFetcher):
% (server_name,)
)
for key_id, key_data in response_json["old_verify_keys"].items():
for key_id, key_data in response_json.get("old_verify_keys", {}).items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
......@@ -596,24 +598,12 @@ class BaseV2KeyFetcher(KeyFetcher):
verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
)
key_json_bytes = encode_canonical_json(response_json)
await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.store.store_server_keys_json,
server_name=server_name,
key_id=key_id,
from_server=from_server,
ts_now_ms=time_added_ms,
ts_expires_ms=ts_valid_until_ms,
key_json_bytes=key_json_bytes,
)
for key_id in verify_keys
],
consumeErrors=True,
).addErrback(unwrapFirstError)
await self.store.store_server_keys_response(
server_name=server_name,
from_server=from_server,
ts_added_ms=time_added_ms,
verify_keys=verify_keys,
response_json=response_json,
)
return verify_keys
......@@ -775,10 +765,6 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
keys.setdefault(server_name, {}).update(processed_response)
await self.store.store_server_signature_keys(
perspective_name, time_now_ms, added_keys
)
return keys
def _validate_perspectives_response(
......@@ -853,11 +839,12 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
Map from server_name -> key_id -> FetchKeyResult
"""
results = {}
# We only need to do one request per server.
servers_to_fetch = {k.server_name for k in keys_to_fetch}
async def get_keys(key_to_fetch_item: _FetchKeyRequest) -> None:
server_name = key_to_fetch_item.server_name
results = {}
async def get_keys(server_name: str) -> None:
try:
keys = await self.get_server_verify_keys_v2_direct(server_name)
results[server_name] = keys
......@@ -866,7 +853,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
except Exception:
logger.exception("Error getting keys from %s", server_name)
await yieldable_gather_results(get_keys, keys_to_fetch)
await yieldable_gather_results(get_keys, servers_to_fetch)
return results
async def get_server_verify_keys_v2_direct(
......
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
import collections.abc
import logging
import typing
from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Union
from typing import (
Any,
ChainMap,
Dict,
Iterable,
List,
Mapping,
MutableMapping,
Optional,
Protocol,
Set,
Tuple,
Union,
cast,
)
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import SignatureVerifyException, verify_signed_json
from typing_extensions import Protocol
from unpaddedbase64 import decode_base64
from synapse.api.constants import (
......@@ -68,8 +88,7 @@ class _EventSourceStore(Protocol):
redact_behaviour: EventRedactBehaviour,
get_prev_content: bool = False,
allow_rejected: bool = False,
) -> Dict[str, "EventBase"]:
...
) -> Dict[str, "EventBase"]: ...
def validate_event_for_room_version(event: "EventBase") -> None:
......@@ -126,7 +145,7 @@ def validate_event_for_room_version(event: "EventBase") -> None:
raise AuthError(403, "Event not signed by sending server")
is_invite_via_allow_rule = (
event.room_version.msc3083_join_rules
event.room_version.restricted_join_rule
and event.type == EventTypes.Member
and event.membership == Membership.JOIN
and EventContentFields.AUTHORISING_USER in event.content
......@@ -168,12 +187,22 @@ async def check_state_independent_auth_rules(
return
# 2. Reject if event has auth_events that: ...
auth_events: ChainMap[str, EventBase] = ChainMap()
if batched_auth_events:
# Copy the batched auth events to avoid mutating them.
auth_events = dict(batched_auth_events)
needed_auth_event_ids = set(event.auth_event_ids()) - batched_auth_events.keys()
# batched_auth_events can become very large. To avoid repeatedly copying it, which
# would significantly impact performance, we use a ChainMap.
# batched_auth_events must be cast to MutableMapping because .new_child() requires
# this type. This casting is safe as the mapping is never mutated.
auth_events = auth_events.new_child(
cast(MutableMapping[str, "EventBase"], batched_auth_events)
)
needed_auth_event_ids = [
event_id
for event_id in event.auth_event_ids()
if event_id not in batched_auth_events
]
if needed_auth_event_ids:
auth_events.update(
auth_events = auth_events.new_child(
await store.get_events(
needed_auth_event_ids,
redact_behaviour=EventRedactBehaviour.as_is,
......@@ -181,10 +210,12 @@ async def check_state_independent_auth_rules(
)
)
else:
auth_events = await store.get_events(
event.auth_event_ids(),
redact_behaviour=EventRedactBehaviour.as_is,
allow_rejected=True,
auth_events = auth_events.new_child(
await store.get_events(
event.auth_event_ids(),
redact_behaviour=EventRedactBehaviour.as_is,
allow_rejected=True,
)
)
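The ChainMap layering used above can be illustrated in isolation; the point is that new_child() adds a lookup layer in O(1) instead of copying the potentially large batched mapping:

from collections import ChainMap

batched = {"$auth_event_a": "<EventBase>"}  # stands in for batched_auth_events
auth_events: ChainMap = ChainMap()
auth_events = auth_events.new_child(batched)  # no copy of `batched`
auth_events = auth_events.new_child({"$auth_event_b": "<EventBase>"})
assert auth_events["$auth_event_a"] == "<EventBase>"  # lookups fall through the layers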
room_id = event.room_id
......@@ -339,13 +370,6 @@ def check_state_dependent_auth_rules(
if event.type == EventTypes.Redaction:
check_redaction(event.room_version, event, auth_dict)
if (
event.type == EventTypes.MSC2716_INSERTION
or event.type == EventTypes.MSC2716_BATCH
or event.type == EventTypes.MSC2716_MARKER
):
check_historical(event.room_version, event, auth_dict)
logger.debug("Allowing! %s", event)
......@@ -359,14 +383,12 @@ LENIENT_EVENT_BYTE_LIMITS_ROOM_VERSIONS = {
RoomVersions.V4,
RoomVersions.V5,
RoomVersions.V6,
RoomVersions.MSC2176,
RoomVersions.V7,
RoomVersions.V8,
RoomVersions.V9,
RoomVersions.MSC3787,
RoomVersions.V10,
RoomVersions.MSC2716v4,
RoomVersions.MSC1767v10,
RoomVersions.MSC3757v10,
}
......@@ -457,7 +479,7 @@ def _check_create(event: "EventBase") -> None:
# 1.4 If content has no creator field, reject if the room version requires it.
if (
not event.room_version.msc2175_implicit_room_creator
not event.room_version.implicit_room_creator
and EventContentFields.ROOM_CREATOR not in event.content
):
raise AuthError(403, "Create event lacks a 'creator' property")
......@@ -494,7 +516,7 @@ def _is_membership_change_allowed(
key = (EventTypes.Create, "")
create = auth_events.get(key)
if create and event.prev_event_ids()[0] == create.event_id:
if room_version.msc2175_implicit_room_creator:
if room_version.implicit_room_creator:
creator = create.sender
else:
creator = create.content[EventContentFields.ROOM_CREATOR]
......@@ -517,7 +539,7 @@ def _is_membership_change_allowed(
caller_invited = caller and caller.membership == Membership.INVITE
caller_knocked = (
caller
and room_version.msc2403_knocking
and room_version.knock_join_rule
and caller.membership == Membership.KNOCK
)
......@@ -544,6 +566,7 @@ def _is_membership_change_allowed(
logger.debug(
"_is_membership_change_allowed: %s",
{
"caller_membership": caller.membership if caller else None,
"caller_in_room": caller_in_room,
"caller_invited": caller_invited,
"caller_knocked": caller_knocked,
......@@ -617,9 +640,9 @@ def _is_membership_change_allowed(
elif join_rule == JoinRules.PUBLIC:
pass
elif (
room_version.msc3083_join_rules and join_rule == JoinRules.RESTRICTED
room_version.restricted_join_rule and join_rule == JoinRules.RESTRICTED
) or (
room_version.msc3787_knock_restricted_join_rule
room_version.knock_restricted_join_rule
and join_rule == JoinRules.KNOCK_RESTRICTED
):
# This is the same as public, but the event must contain a reference
......@@ -649,13 +672,14 @@ def _is_membership_change_allowed(
elif (
join_rule == JoinRules.INVITE
or (room_version.msc2403_knocking and join_rule == JoinRules.KNOCK)
or (room_version.knock_join_rule and join_rule == JoinRules.KNOCK)
or (
room_version.msc3787_knock_restricted_join_rule
room_version.knock_restricted_join_rule
and join_rule == JoinRules.KNOCK_RESTRICTED
)
):
if not caller_in_room and not caller_invited:
# You can only join the room if you are invited or are already in the room.
if not (caller_in_room or caller_invited):
raise AuthError(403, "You are not invited to this room.")
else:
# TODO (erikj): may_join list
......@@ -679,15 +703,21 @@ def _is_membership_change_allowed(
errcode=Codes.INSUFFICIENT_POWER,
)
elif Membership.BAN == membership:
if user_level < ban_level or user_level <= target_level:
if user_level < ban_level:
raise UnstableSpecAuthError(
403,
"You don't have permission to ban",
errcode=Codes.INSUFFICIENT_POWER,
)
elif room_version.msc2403_knocking and Membership.KNOCK == membership:
elif user_level <= target_level:
raise UnstableSpecAuthError(
403,
"You don't have permission to ban this user",
errcode=Codes.INSUFFICIENT_POWER,
)
elif room_version.knock_join_rule and Membership.KNOCK == membership:
if join_rule != JoinRules.KNOCK and (
not room_version.msc3787_knock_restricted_join_rule
not room_version.knock_restricted_join_rule
or join_rule != JoinRules.KNOCK_RESTRICTED
):
raise AuthError(403, "You don't have permission to knock")
......@@ -763,9 +793,10 @@ def get_send_level(
def _can_send_event(event: "EventBase", auth_events: StateMap["EventBase"]) -> bool:
state_key = event.get_state_key()
power_levels_event = get_power_level_event(auth_events)
send_level = get_send_level(event.type, event.get("state_key"), power_levels_event)
send_level = get_send_level(event.type, state_key, power_levels_event)
user_level = get_user_power_level(event.user_id, auth_events)
if user_level < send_level:
......@@ -776,11 +807,34 @@ def _can_send_event(event: "EventBase", auth_events: StateMap["EventBase"]) -> b
errcode=Codes.INSUFFICIENT_POWER,
)
# Check state_key
if hasattr(event, "state_key"):
if event.state_key.startswith("@"):
if event.state_key != event.user_id:
raise AuthError(403, "You are not allowed to set others state")
if (
state_key is not None
and state_key.startswith("@")
and state_key != event.user_id
):
if event.room_version.msc3757_enabled:
try:
colon_idx = state_key.index(":", 1)
suffix_idx = state_key.find("_", colon_idx + 1)
state_key_user_id = (
state_key[:suffix_idx] if suffix_idx != -1 else state_key
)
if not UserID.is_valid(state_key_user_id):
raise ValueError
except ValueError:
raise SynapseError(
400,
"State key neither equals a valid user ID, nor starts with one plus an underscore",
errcode=Codes.BAD_JSON,
)
if (
# sender is owner of the state key
state_key_user_id == event.user_id
# sender has higher PL than the owner of the state key
or user_level > get_user_power_level(state_key_user_id, auth_events)
):
return True
raise AuthError(403, "You are not allowed to set others state")
return True
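The MSC3757 owner extraction above, restated as a standalone sketch (user IDs are examples): everything up to the first "_" after the domain colon is treated as the owning user ID.

def state_key_owner(state_key: str) -> str:
    # Mirrors the parsing in _can_send_event above.
    colon_idx = state_key.index(":", 1)
    suffix_idx = state_key.find("_", colon_idx + 1)
    return state_key[:suffix_idx] if suffix_idx != -1 else state_key

assert state_key_owner("@alice:example.com") == "@alice:example.com"
assert state_key_owner("@alice:example.com_light.bulb") == "@alice:example.com"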
......@@ -823,38 +877,6 @@ def check_redaction(
raise AuthError(403, "You don't have permission to redact events")
def check_historical(
room_version_obj: RoomVersion,
event: "EventBase",
auth_events: StateMap["EventBase"],
) -> None:
"""Check whether the event sender is allowed to send historical related
events like "insertion", "batch", and "marker".
Returns:
None
Raises:
AuthError if the event sender is not allowed to send historical related events
("insertion", "batch", and "marker").
"""
# Ignore the auth checks in room versions that do not support historical
# events
if not room_version_obj.msc2716_historical:
return
user_level = get_user_power_level(event.user_id, auth_events)
historical_level = get_named_level(auth_events, "historical", 100)
if user_level < historical_level:
raise UnstableSpecAuthError(
403,
'You don\'t have permission to send historical related events ("insertion", "batch", and "marker")',
errcode=Codes.INSUFFICIENT_POWER,
)
def _check_power_levels(
room_version_obj: RoomVersion,
event: "EventBase",
......@@ -876,7 +898,7 @@ def _check_power_levels(
# Reject events with stringy power levels if required by room version
if (
event.type == EventTypes.PowerLevels
and room_version_obj.msc3667_int_only_power_levels
and room_version_obj.enforce_int_power_levels
):
for k, v in event.content.items():
if k in {
......@@ -888,11 +910,12 @@ def _check_power_levels(
"kick",
"invite",
}:
if type(v) is not int:
if type(v) is not int: # noqa: E721
raise SynapseError(400, f"{v!r} must be an integer.")
if k in {"events", "notifications", "users"}:
if not isinstance(v, collections.abc.Mapping) or not all(
type(v) is int for v in v.values()
type(v) is int
for v in v.values() # noqa: E721
):
raise SynapseError(
400,
......@@ -1012,7 +1035,7 @@ def get_user_power_level(user_id: str, auth_events: StateMap["EventBase"]) -> in
key = (EventTypes.Create, "")
create_event = auth_events.get(key)
if create_event is not None:
if create_event.room_version.msc2175_implicit_room_creator:
if create_event.room_version.implicit_room_creator:
creator = create_event.sender
else:
creator = create_event.content[EventContentFields.ROOM_CREATOR]
......@@ -1054,10 +1077,15 @@ def _verify_third_party_invite(
"""
if "third_party_invite" not in event.content:
return False
if "signed" not in event.content["third_party_invite"]:
third_party_invite = event.content["third_party_invite"]
if not isinstance(third_party_invite, collections.abc.Mapping):
return False
if "signed" not in third_party_invite:
return False
signed = event.content["third_party_invite"]["signed"]
for key in {"mxid", "token"}:
signed = third_party_invite["signed"]
if not isinstance(signed, collections.abc.Mapping):
return False
for key in {"mxid", "token", "signatures"}:
if key not in signed:
return False
......@@ -1075,8 +1103,6 @@ def _verify_third_party_invite(
if signed["mxid"] != event.state_key:
return False
if signed["token"] != token:
return False
for public_key_object in get_public_keys(invite_event):
public_key = public_key_object["public_key"]
......@@ -1088,7 +1114,9 @@ def _verify_third_party_invite(
verify_key = decode_verify_key_bytes(
key_name, decode_base64(public_key)
)
verify_signed_json(signed, server, verify_key)
# verify_signed_json incorrectly states it wants a dict, it
# just needs a mapping.
verify_signed_json(signed, server, verify_key) # type: ignore[arg-type]
# We got the public key from the invite, so we know that the
# correct server signed the signed bundle.
......@@ -1145,7 +1173,7 @@ def auth_types_for_event(
)
auth_types.add(key)
if room_version.msc3083_join_rules and membership == Membership.JOIN:
if room_version.restricted_join_rule and membership == Membership.JOIN:
if EventContentFields.AUTHORISING_USER in event.content:
key = (
EventTypes.Member,
......
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
import abc
import collections.abc
import os
from typing import (
TYPE_CHECKING,
Any,
......@@ -24,8 +29,8 @@ from typing import (
Generic,
Iterable,
List,
Literal,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
......@@ -34,29 +39,29 @@ from typing import (
)
import attr
from typing_extensions import Literal
from unpaddedbase64 import encode_base64
from synapse.api.constants import RelationTypes
from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
from synapse.types import JsonDict, RoomStreamToken, StrCollection
from synapse.synapse_rust.events import EventInternalMetadata
from synapse.types import JsonDict, StrCollection
from synapse.util.caches import intern_dict
from synapse.util.frozenutils import freeze
from synapse.util.stringutils import strtobool
if TYPE_CHECKING:
from synapse.events.builder import EventBuilder
# Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
# bugs where we accidentally share e.g. signature dicts. However, converting a
# dict to frozen_dicts is expensive.
#
# NOTE: This is overridden by the configuration by the Synapse worker apps, but
# for the sake of tests, it is set here while it cannot be configured on the
# homeserver object itself.
USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))
USE_FROZEN_DICTS = False
"""
Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
bugs where we accidentally share e.g. signature dicts. However, converting a
dict to frozen_dicts is expensive.
NOTE: This is overridden by the configuration by the Synapse worker apps, but
for the sake of tests, it is set here because it cannot be configured on the
homeserver object itself.
"""
T = TypeVar("T")
......@@ -71,7 +76,7 @@ T = TypeVar("T")
#
# Note that DictProperty/DefaultDictProperty cannot actually be used with
# EventBuilder as it lacks a _dict property.
_DictPropertyInstance = Union["_EventInternalMetadata", "EventBase", "EventBuilder"]
_DictPropertyInstance = Union["EventBase", "EventBuilder"]
class DictProperty(Generic[T]):
......@@ -87,16 +92,14 @@ class DictProperty(Generic[T]):
self,
instance: Literal[None],
owner: Optional[Type[_DictPropertyInstance]] = None,
) -> "DictProperty":
...
) -> "DictProperty": ...
@overload
def __get__(
self,
instance: _DictPropertyInstance,
owner: Optional[Type[_DictPropertyInstance]] = None,
) -> T:
...
) -> T: ...
def __get__(
self,
......@@ -108,7 +111,7 @@ class DictProperty(Generic[T]):
if instance is None:
return self
try:
assert isinstance(instance, (EventBase, _EventInternalMetadata))
assert isinstance(instance, EventBase)
return instance._dict[self.key]
except KeyError as e1:
# We want this to look like a regular attribute error (mostly so that
......@@ -124,11 +127,11 @@ class DictProperty(Generic[T]):
) from e1.__context__
def __set__(self, instance: _DictPropertyInstance, v: T) -> None:
assert isinstance(instance, (EventBase, _EventInternalMetadata))
assert isinstance(instance, EventBase)
instance._dict[self.key] = v
def __delete__(self, instance: _DictPropertyInstance) -> None:
assert isinstance(instance, (EventBase, _EventInternalMetadata))
assert isinstance(instance, EventBase)
try:
del instance._dict[self.key]
except KeyError as e1:
......@@ -155,16 +158,14 @@ class DefaultDictProperty(DictProperty, Generic[T]):
self,
instance: Literal[None],
owner: Optional[Type[_DictPropertyInstance]] = None,
) -> "DefaultDictProperty":
...
) -> "DefaultDictProperty": ...
@overload
def __get__(
self,
instance: _DictPropertyInstance,
owner: Optional[Type[_DictPropertyInstance]] = None,
) -> T:
...
) -> T: ...
def __get__(
self,
......@@ -173,134 +174,10 @@ class DefaultDictProperty(DictProperty, Generic[T]):
) -> Union[T, "DefaultDictProperty"]:
if instance is None:
return self
assert isinstance(instance, (EventBase, _EventInternalMetadata))
assert isinstance(instance, EventBase)
return instance._dict.get(self.key, self.default)
class _EventInternalMetadata:
__slots__ = ["_dict", "stream_ordering", "outlier"]
def __init__(self, internal_metadata_dict: JsonDict):
# we have to copy the dict, because it turns out that the same dict is
# reused. TODO: fix that
self._dict = dict(internal_metadata_dict)
# the stream ordering of this event. None, until it has been persisted.
self.stream_ordering: Optional[int] = None
# whether this event is an outlier (ie, whether we have the state at that point
# in the DAG)
self.outlier = False
out_of_band_membership: DictProperty[bool] = DictProperty("out_of_band_membership")
send_on_behalf_of: DictProperty[str] = DictProperty("send_on_behalf_of")
recheck_redaction: DictProperty[bool] = DictProperty("recheck_redaction")
soft_failed: DictProperty[bool] = DictProperty("soft_failed")
proactively_send: DictProperty[bool] = DictProperty("proactively_send")
redacted: DictProperty[bool] = DictProperty("redacted")
historical: DictProperty[bool] = DictProperty("historical")
txn_id: DictProperty[str] = DictProperty("txn_id")
"""The transaction ID, if it was set when the event was created."""
token_id: DictProperty[int] = DictProperty("token_id")
"""The access token ID of the user who sent this event, if any."""
device_id: DictProperty[str] = DictProperty("device_id")
"""The device ID of the user who sent this event, if any."""
# XXX: These are set by StreamWorkerStore._set_before_and_after.
# I'm pretty sure that these are never persisted to the database, so shouldn't
# be here
before: DictProperty[RoomStreamToken] = DictProperty("before")
after: DictProperty[RoomStreamToken] = DictProperty("after")
order: DictProperty[Tuple[int, int]] = DictProperty("order")
def get_dict(self) -> JsonDict:
return dict(self._dict)
def is_outlier(self) -> bool:
return self.outlier
def is_out_of_band_membership(self) -> bool:
"""Whether this event is an out-of-band membership.
OOB memberships are a special case of outlier events: they are membership events
for federated rooms that we aren't full members of. Examples include invites
received over federation, and rejections for such invites.
The concept of an OOB membership is needed because these events need to be
processed as if they're new regular events (e.g. updating membership state in
the database, relaying to clients via /sync, etc) despite being outliers.
See also https://matrix-org.github.io/synapse/develop/development/room-dag-concepts.html#out-of-band-membership-events.
(Added in synapse 0.99.0, so may be unreliable for events received before that)
"""
return self._dict.get("out_of_band_membership", False)
def get_send_on_behalf_of(self) -> Optional[str]:
"""Whether this server should send the event on behalf of another server.
This is used by the federation "send_join" API to forward the initial join
event for a server in the room.
returns a str with the name of the server this event is sent on behalf of.
"""
return self._dict.get("send_on_behalf_of")
def need_to_check_redaction(self) -> bool:
"""Whether the redaction event needs to be rechecked when fetching
from the database.
Starting in room v3 redaction events are accepted up front, and later
checked to see if the redacter and redactee's domains match.
If the sender of the redaction event is allowed to redact any event
due to auth rules, then this will always return false.
"""
return self._dict.get("recheck_redaction", False)
def is_soft_failed(self) -> bool:
"""Whether the event has been soft failed.
Soft failed events should be handled as usual, except:
1. They should not go down sync or event streams, or generally
sent to clients.
2. They should not be added to the forward extremities (and
therefore not to current state).
"""
return self._dict.get("soft_failed", False)
def should_proactively_send(self) -> bool:
"""Whether the event, if ours, should be sent to other clients and
servers.
This is used for sending dummy events internally. Servers and clients
can still explicitly fetch the event.
"""
return self._dict.get("proactively_send", True)
def is_redacted(self) -> bool:
"""Whether the event has been redacted.
This is used for efficiently checking whether an event has been
marked as redacted without needing to make another database call.
"""
return self._dict.get("redacted", False)
def is_historical(self) -> bool:
"""Whether this is a historical message.
This is used by the batch-send historical message endpoint and
is needed to mark the event as backfilled and to skip some checks
like push notifications.
"""
return self._dict.get("historical", False)
def is_notifiable(self) -> bool:
"""Whether this event can trigger a push notification"""
return not self.is_outlier() or self.is_out_of_band_membership()
class EventBase(metaclass=abc.ABCMeta):
@property
@abc.abstractmethod
......@@ -326,7 +203,7 @@ class EventBase(metaclass=abc.ABCMeta):
self._dict = event_dict
self.internal_metadata = _EventInternalMetadata(internal_metadata_dict)
self.internal_metadata = EventInternalMetadata(internal_metadata_dict)
depth: DictProperty[int] = DictProperty("depth")
content: DictProperty[JsonDict] = DictProperty("content")
......@@ -355,7 +232,7 @@ class EventBase(metaclass=abc.ABCMeta):
@property
def redacts(self) -> Optional[str]:
"""MSC2176 moved the redacts field into the content."""
if self.room_version.msc2176_redaction_rules:
if self.room_version.updated_redaction_rules:
return self.content.get("redacts")
return self.get("redacts")
......@@ -417,7 +294,7 @@ class EventBase(metaclass=abc.ABCMeta):
def keys(self) -> Iterable[str]:
return self._dict.keys()
def prev_event_ids(self) -> Sequence[str]:
def prev_event_ids(self) -> List[str]:
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.
......@@ -447,12 +324,17 @@ class EventBase(metaclass=abc.ABCMeta):
def __repr__(self) -> str:
rejection = f"REJECTED={self.rejected_reason}, " if self.rejected_reason else ""
conditional_membership_string = ""
if self.get("type") == EventTypes.Member:
conditional_membership_string = f"membership={self.membership}, "
return (
f"<{self.__class__.__name__} "
f"{rejection}"
f"event_id={self.event_id}, "
f"type={self.get('type')}, "
f"state_key={self.get('state_key')}, "
f"{conditional_membership_string}"
f"outlier={self.internal_metadata.is_outlier()}"
">"
)
......@@ -562,7 +444,7 @@ class FrozenEventV2(EventBase):
self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
return self._event_id
def prev_event_ids(self) -> Sequence[str]:
def prev_event_ids(self) -> List[str]:
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.
......@@ -676,3 +558,22 @@ def relation_from_event(event: EventBase) -> Optional[_EventRelation]:
aggregation_key = None
return _EventRelation(parent_id, rel_type, aggregation_key)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class StrippedStateEvent:
"""
A stripped down state event. Usually used for remote invite/knocks so the user can
make an informed decision on whether they want to join.
Attributes:
type: Event `type`
state_key: Event `state_key`
sender: Event `sender`
content: Event `content`
"""
type: str
state_key: str
sender: str
content: Dict[str, Any]
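Constructing one is straightforward; an illustrative example (values made up):

stripped = StrippedStateEvent(
    type="m.room.join_rules",
    state_key="",
    sender="@alice:example.com",
    content={"join_rule": "invite"},
)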
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2021 The Matrix.org Foundation C.I.C
# Copyright (C) 2024 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
import logging
from http import HTTPStatus
from typing import Any, Dict, Tuple
from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.config.auto_accept_invites import AutoAcceptInvitesConfig
from synapse.module_api import EventBase, ModuleApi, run_as_background_process
logger = logging.getLogger(__name__)
class InviteAutoAccepter:
def __init__(self, config: AutoAcceptInvitesConfig, api: ModuleApi):
# Keep a reference to the Module API.
self._api = api
self._config = config
if not self._config.enabled:
return
should_run_on_this_worker = config.worker_to_run_on == self._api.worker_name
if not should_run_on_this_worker:
logger.info(
"Not accepting invites on this worker (configured: %r, here: %r)",
config.worker_to_run_on,
self._api.worker_name,
)
return
logger.info(
"Accepting invites on this worker (here: %r)", self._api.worker_name
)
# Register the callback.
self._api.register_third_party_rules_callbacks(
on_new_event=self.on_new_event,
)
async def on_new_event(self, event: EventBase, *args: Any) -> None:
"""Listens for new events, and if the event is an invite for a local user then
automatically accepts it.
Args:
event: The incoming event.
"""
# Check if the event is an invite for a local user.
if (
event.type != EventTypes.Member
or event.is_state() is False
or event.membership != Membership.INVITE
or self._api.is_mine(event.state_key) is False
):
return
# Only accept invites for direct messages if the configuration mandates it.
is_direct_message = event.content.get("is_direct", False)
if (
self._config.accept_invites_only_for_direct_messages
and is_direct_message is False
):
return
# Only accept invites from remote users if the configuration mandates it.
is_from_local_user = self._api.is_mine(event.sender)
if (
self._config.accept_invites_only_from_local_users
and is_from_local_user is False
):
return
# Check the user is activated.
recipient = await self._api.get_userinfo_by_id(event.state_key)
# Ignore if the user doesn't exist.
if recipient is None:
return
# Never accept invites for deactivated users.
if recipient.is_deactivated:
return
# Never accept invites for suspended users.
if recipient.suspended:
return
# Never accept invites for locked users.
if recipient.locked:
return
# Make the user join the room. We run this as a background process to circumvent a race condition
# that occurs when responding to invites over federation (see https://github.com/matrix-org/synapse-auto-accept-invite/issues/12)
run_as_background_process(
"retry_make_join",
self._retry_make_join,
event.state_key,
event.state_key,
event.room_id,
"join",
bg_start_span=False,
)
if is_direct_message:
# Mark this room as a direct message!
await self._mark_room_as_direct_message(
event.state_key, event.sender, event.room_id
)
async def _mark_room_as_direct_message(
self, user_id: str, dm_user_id: str, room_id: str
) -> None:
"""
Marks a room (`room_id`) as a direct message with the counterparty `dm_user_id`
from the perspective of the user `user_id`.
Args:
user_id: the user for whom the membership is changing
dm_user_id: the user performing the membership change
room_id: room id of the room the user is invited to
"""
# This is a dict of User IDs to tuples of Room IDs
# (get_global will return a frozendict of tuples as it freezes the data,
# but we should accept either frozen or unfrozen variants.)
# Be careful: we convert the outer frozendict into a dict here,
# but the contents of the dict are still frozen (tuples in lieu of lists,
# etc.)
dm_map: Dict[str, Tuple[str, ...]] = dict(
await self._api.account_data_manager.get_global(
user_id, AccountDataTypes.DIRECT
)
or {}
)
if dm_user_id not in dm_map:
dm_map[dm_user_id] = (room_id,)
else:
dm_rooms_for_user = dm_map[dm_user_id]
assert isinstance(dm_rooms_for_user, (tuple, list))
dm_map[dm_user_id] = tuple(dm_rooms_for_user) + (room_id,)
await self._api.account_data_manager.put_global(
user_id, AccountDataTypes.DIRECT, dm_map
)
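# Illustrative shape of the resulting ``m.direct`` account data (hypothetical
# user and room IDs): each counterparty maps to a tuple of room IDs, e.g.
#
#   {"@bob:example.org": ("!existing:example.org", "!invited:example.org")}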
async def _retry_make_join(
self, sender: str, target: str, room_id: str, new_membership: str
) -> None:
"""
A function to retry sending the `make_join` request with an increasing backoff. This is
implemented to work around a race condition when receiving invites over federation.
Args:
sender: the user performing the membership change
target: the user for whom the membership is changing
room_id: room id of the room to join to
new_membership: the type of membership event (in this case will be "join")
"""
sleep = 0
retries = 0
join_event = None
while retries < 5:
try:
await self._api.sleep(sleep)
join_event = await self._api.update_room_membership(
sender=sender,
target=target,
room_id=room_id,
new_membership=new_membership,
)
except SynapseError as e:
if e.code == HTTPStatus.FORBIDDEN:
logger.debug(
f"update_room_membership was forbidden. This can sometimes be expected for remote invites. Exception: {e}"
)
else:
logger.warning(
f"update_room_membership raised the following unexpected (SynapseError) exception: {e}"
)
except Exception as e:
logger.warning(
f"update_room_membership raised the following unexpected exception: {e}"
)
sleep = 2**retries
retries += 1
if join_event is not None:
break
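# Sketch of the backoff schedule above, assuming every attempt fails: the loop
# sleeps 0, 1, 2, 4 and 8 seconds before the five attempts respectively, i.e.
#
#   >>> [0] + [2**retries for retries in range(4)]
#   [0, 1, 2, 4, 8]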
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import attr
from signedjson.types import SigningKey
from synapse.api.constants import MAX_DEPTH
from synapse.api.constants import MAX_DEPTH, EventTypes
from synapse.api.room_versions import (
KNOWN_EVENT_FORMAT_VERSIONS,
EventFormatVersions,
......@@ -25,10 +32,11 @@ from synapse.api.room_versions import (
)
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.event_auth import auth_types_for_event
from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
from synapse.events import EventBase, make_event_from_dict
from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore
from synapse.types import EventID, JsonDict
from synapse.synapse_rust.events import EventInternalMetadata
from synapse.types import EventID, JsonDict, StrCollection
from synapse.types.state import StateFilter
from synapse.util import Clock
from synapse.util.stringutils import random_string
......@@ -87,8 +95,8 @@ class EventBuilder:
_redacts: Optional[str] = None
_origin_server_ts: Optional[int] = None
internal_metadata: _EventInternalMetadata = attr.Factory(
lambda: _EventInternalMetadata({})
internal_metadata: EventInternalMetadata = attr.Factory(
lambda: EventInternalMetadata({})
)
@property
......@@ -101,9 +109,22 @@ class EventBuilder:
def is_state(self) -> bool:
return self._state_key is not None
def is_mine_id(self, user_id: str) -> bool:
"""Determines whether a user ID or room alias originates from this homeserver.
Returns:
`True` if the hostname part of the user ID or room alias matches this
homeserver.
`False` otherwise, or if the user ID or room alias is malformed.
"""
localpart_hostname = user_id.split(":", 1)
if len(localpart_hostname) < 2:
return False
return localpart_hostname[1] == self._hostname
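# Illustrative behaviour, assuming this builder's ``_hostname`` is
# "example.org" (hypothetical values):
#
#   >>> builder.is_mine_id("@alice:example.org")
#   True
#   >>> builder.is_mine_id("@alice:elsewhere.org")
#   False
#   >>> builder.is_mine_id("malformed-id")
#   False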
async def build(
self,
prev_event_ids: Collection[str],
prev_event_ids: List[str],
auth_event_ids: Optional[List[str]],
depth: Optional[int] = None,
) -> EventBase:
......@@ -134,9 +155,49 @@ class EventBuilder:
self, state_ids
)
# Check for out-of-band membership that may have been exposed on `/sync` but
# the events have not been de-outliered yet so they won't be part of the
# room state yet.
#
# This helps in situations where a remote homeserver invites a local user to
# a room that we're already participating in; and we've persisted the invite
# as an out-of-band membership (outlier), but it hasn't been pushed to us as
# part of a `/send` transaction yet and de-outliered. This also helps for
# any of the other out-of-band membership transitions.
#
# As an optimization, we could check if the room state already includes a
# non-`leave` membership event, then we can assume the membership event has
# been de-outliered and we don't need to check for an out-of-band
# membership. But we don't have the necessary information from a
# `StateMap[str]` and we'll just have to take the hit of this extra lookup
# for any membership event for now.
if self.type == EventTypes.Member and self.is_mine_id(self.state_key):
(
_membership,
member_event_id,
) = await self._store.get_local_current_membership_for_user_in_room(
user_id=self.state_key,
room_id=self.room_id,
)
# There is no need to check if the membership is actually an
# out-of-band membership (`outlier`) as we would end up with the
# same result either way (adding the member event to the
# `auth_event_ids`).
if (
member_event_id is not None
# We only need to be careful about duplicating the event in the
# `auth_event_ids` list (duplicate `type`/`state_key` is part of the
# authorization rules)
and member_event_id not in auth_event_ids
):
auth_event_ids.append(member_event_id)
# Also make sure to point to the previous membership event that will
# allow this one to happen so the computed state works out.
prev_event_ids.append(member_event_id)
format_version = self.room_version.event_format
# The types of auth/prev events changes between event versions.
prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]]
prev_events: Union[StrCollection, List[Tuple[str, Dict[str, str]]]]
auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
if format_version == EventFormatVersions.ROOM_V1_V2:
auth_events = await self._store.add_event_hashes(auth_event_ids)
......@@ -175,7 +236,7 @@ class EventBuilder:
# MSC2174 moves the redacts property to the content, it is invalid to
# provide it as a top-level property.
if self._redacts is not None and not self.room_version.msc2176_redaction_rules:
if self._redacts is not None and not self.room_version.updated_redaction_rules:
event_dict["redacts"] = self._redacts
if self._origin_server_ts is not None:
......
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import (
TYPE_CHECKING,
......@@ -73,7 +80,7 @@ def load_legacy_presence_router(hs: "HomeServer") -> None:
# All methods that the module provides should be async, but this wasn't enforced
# in the old module system, so we wrap them if needed
def async_wrapper(
f: Optional[Callable[P, R]]
f: Optional[Callable[P, R]],
) -> Optional[Callable[P, Awaitable[R]]]:
# f might be None if the callback isn't implemented by the module. In this
# case we don't want to register a callback at all so we return None.
......
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import attr
from immutabledict import immutabledict
from synapse.appservice import ApplicationService
from synapse.events import EventBase
from synapse.logging.opentracing import tag_args, trace
from synapse.types import JsonDict, StateMap
if TYPE_CHECKING:
......@@ -54,7 +62,6 @@ class UnpersistedEventContextBase(ABC):
A method to convert an UnpersistedEventContext to an EventContext, suitable for
sending to the database with the associated event.
"""
pass
@abstractmethod
async def get_prev_state_ids(
......@@ -68,7 +75,6 @@ class UnpersistedEventContextBase(ABC):
state_filter: specifies the type of state event to fetch from DB, example:
EventTypes.JoinRules
"""
pass
@attr.s(slots=True, auto_attribs=True)
......@@ -106,33 +112,32 @@ class EventContext(UnpersistedEventContextBase):
state_delta_due_to_event: If `state_group` and `state_group_before_event` are not None
then this is the delta of the state between the two groups.
prev_group: If it is known, ``state_group``'s prev_group. Note that this being
None does not necessarily mean that ``state_group`` does not have
a prev_group!
If the event is a state event, this is normally the same as
``state_group_before_event``.
If ``state_group`` is None (ie, the event is an outlier), ``prev_group``
will always also be ``None``.
Note that this is *not* (necessarily) the state group associated with
``_prev_state_ids``.
delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group``
and ``state_group``.
state_group_deltas: If not empty, this is a dict collecting a mapping of the state
difference between state groups.
The keys are a tuple of two integers: the initial group and final state group.
The corresponding value is a state map representing the state delta between
these state groups.
The dictionary is expected to have at most two entries with state groups of:
1. The state group before the event and after the event.
2. The state group preceding the state group before the event and the
state group before the event.
This information is collected and stored as part of an optimization for persisting
events.
partial_state: if True, we may be storing this event with a temporary,
incomplete state.
"""
_storage: "StorageControllers"
state_group_deltas: Dict[Tuple[int, int], StateMap[str]]
rejected: Optional[str] = None
_state_group: Optional[int] = None
state_group_before_event: Optional[int] = None
_state_delta_due_to_event: Optional[StateMap[str]] = None
prev_group: Optional[int] = None
delta_ids: Optional[StateMap[str]] = None
app_service: Optional[ApplicationService] = None
partial_state: bool = False
......@@ -144,16 +149,14 @@ class EventContext(UnpersistedEventContextBase):
state_group_before_event: Optional[int],
state_delta_due_to_event: Optional[StateMap[str]],
partial_state: bool,
prev_group: Optional[int] = None,
delta_ids: Optional[StateMap[str]] = None,
state_group_deltas: Dict[Tuple[int, int], StateMap[str]],
) -> "EventContext":
return EventContext(
storage=storage,
state_group=state_group,
state_group_before_event=state_group_before_event,
state_delta_due_to_event=state_delta_due_to_event,
prev_group=prev_group,
delta_ids=delta_ids,
state_group_deltas=state_group_deltas,
partial_state=partial_state,
)
......@@ -162,7 +165,7 @@ class EventContext(UnpersistedEventContextBase):
storage: "StorageControllers",
) -> "EventContext":
"""Return an EventContext instance suitable for persisting an outlier event"""
return EventContext(storage=storage)
return EventContext(storage=storage, state_group_deltas={})
async def persist(self, event: EventBase) -> "EventContext":
return self
......@@ -182,11 +185,10 @@ class EventContext(UnpersistedEventContextBase):
"state_group": self._state_group,
"state_group_before_event": self.state_group_before_event,
"rejected": self.rejected,
"prev_group": self.prev_group,
"state_group_deltas": _encode_state_group_delta(self.state_group_deltas),
"state_delta_due_to_event": _encode_state_dict(
self._state_delta_due_to_event
),
"delta_ids": _encode_state_dict(self.delta_ids),
"app_service_id": self.app_service.id if self.app_service else None,
"partial_state": self.partial_state,
}
......@@ -203,17 +205,17 @@ class EventContext(UnpersistedEventContextBase):
Returns:
The event context.
"""
context = EventContext(
# We use the state_group and prev_state_id stuff to pull the
# current_state_ids out of the DB and construct prev_state_ids.
storage=storage,
state_group=input["state_group"],
state_group_before_event=input["state_group_before_event"],
prev_group=input["prev_group"],
state_group_deltas=_decode_state_group_delta(input["state_group_deltas"]),
state_delta_due_to_event=_decode_state_dict(
input["state_delta_due_to_event"]
),
delta_ids=_decode_state_dict(input["delta_ids"]),
rejected=input["rejected"],
partial_state=input.get("partial_state", False),
)
......@@ -242,9 +244,11 @@ class EventContext(UnpersistedEventContextBase):
return self._state_group
@trace
@tag_args
async def get_current_state_ids(
self, state_filter: Optional["StateFilter"] = None
) -> Optional[StateMap[str]]:
) -> StateMap[str]:
"""
Gets the room state map, including this event - ie, the state in ``state_group``
......@@ -252,13 +256,12 @@ class EventContext(UnpersistedEventContextBase):
not make it into the room state. This method will raise an exception if
``rejected`` is set.
It is also an error to access this for an outlier event.
Args:
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
Returns:
Returns None if state_group is None, which happens when the associated
event is an outlier.
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
......@@ -275,6 +278,8 @@ class EventContext(UnpersistedEventContextBase):
return prev_state_ids
@trace
@tag_args
async def get_prev_state_ids(
self, state_filter: Optional["StateFilter"] = None
) -> StateMap[str]:
......@@ -294,7 +299,8 @@ class EventContext(UnpersistedEventContextBase):
this tuple.
"""
assert self.state_group_before_event is not None
if self.state_group_before_event is None:
return {}
return await self._storage.state.get_state_ids_for_group(
self.state_group_before_event, state_filter
)
......@@ -344,7 +350,7 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
_storage: "StorageControllers"
state_group_before_event: Optional[int]
state_group_after_event: Optional[int]
state_delta_due_to_event: Optional[dict]
state_delta_due_to_event: Optional[StateMap[str]]
prev_group_for_state_group_before_event: Optional[int]
delta_ids_to_state_group_before_event: Optional[StateMap[str]]
partial_state: bool
......@@ -375,26 +381,16 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
events_and_persisted_context = []
for event, unpersisted_context in amended_events_and_context:
if event.is_state():
context = EventContext(
storage=unpersisted_context._storage,
state_group=unpersisted_context.state_group_after_event,
state_group_before_event=unpersisted_context.state_group_before_event,
state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
partial_state=unpersisted_context.partial_state,
prev_group=unpersisted_context.state_group_before_event,
delta_ids=unpersisted_context.state_delta_due_to_event,
)
else:
context = EventContext(
storage=unpersisted_context._storage,
state_group=unpersisted_context.state_group_after_event,
state_group_before_event=unpersisted_context.state_group_before_event,
state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
partial_state=unpersisted_context.partial_state,
prev_group=unpersisted_context.prev_group_for_state_group_before_event,
delta_ids=unpersisted_context.delta_ids_to_state_group_before_event,
)
state_group_deltas = unpersisted_context._build_state_group_deltas()
context = EventContext(
storage=unpersisted_context._storage,
state_group=unpersisted_context.state_group_after_event,
state_group_before_event=unpersisted_context.state_group_before_event,
state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
partial_state=unpersisted_context.partial_state,
state_group_deltas=state_group_deltas,
)
events_and_persisted_context.append((event, context))
return events_and_persisted_context
......@@ -447,11 +443,11 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
# if the event isn't a state event the state group doesn't change
if not self.state_delta_due_to_event:
state_group_after_event = self.state_group_before_event
self.state_group_after_event = self.state_group_before_event
# otherwise if it is a state event we need to get a state group for it
else:
state_group_after_event = await self._storage.state.store_state_group(
self.state_group_after_event = await self._storage.state.store_state_group(
event.event_id,
event.room_id,
prev_group=self.state_group_before_event,
......@@ -459,16 +455,81 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
current_state_ids=None,
)
state_group_deltas = self._build_state_group_deltas()
return EventContext.with_state(
storage=self._storage,
state_group=state_group_after_event,
state_group=self.state_group_after_event,
state_group_before_event=self.state_group_before_event,
state_delta_due_to_event=self.state_delta_due_to_event,
state_group_deltas=state_group_deltas,
partial_state=self.partial_state,
prev_group=self.state_group_before_event,
delta_ids=self.state_delta_due_to_event,
)
def _build_state_group_deltas(self) -> Dict[Tuple[int, int], StateMap]:
"""
Collect deltas between the state groups associated with this context
"""
state_group_deltas = {}
# if we know the state group before the event and after the event, add them and the
# state delta between them to state_group_deltas
if self.state_group_before_event and self.state_group_after_event:
# if we have the state groups we should have the delta
assert self.state_delta_due_to_event is not None
state_group_deltas[
(
self.state_group_before_event,
self.state_group_after_event,
)
] = self.state_delta_due_to_event
# the state group before the event may also have a state group which precedes it, if
# we have that and the state group before the event, add them and the state
# delta between them to state_group_deltas
if (
self.prev_group_for_state_group_before_event
and self.state_group_before_event
):
# if we have both state groups we should have the delta between them
assert self.delta_ids_to_state_group_before_event is not None
state_group_deltas[
(
self.prev_group_for_state_group_before_event,
self.state_group_before_event,
)
] = self.delta_ids_to_state_group_before_event
return state_group_deltas
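# Sketch of the resulting mapping (hypothetical state group IDs and event IDs):
# with prev_group_for_state_group_before_event=1, state_group_before_event=2
# and state_group_after_event=3, the method would return
#
#   {
#       (2, 3): {("m.room.member", "@alice:example.org"): "$event_a"},
#       (1, 2): {("m.room.member", "@bob:example.org"): "$event_b"},
#   }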
def _encode_state_group_delta(
state_group_delta: Dict[Tuple[int, int], StateMap[str]],
) -> List[Tuple[int, int, Optional[List[Tuple[str, str, str]]]]]:
if not state_group_delta:
return []
state_group_delta_encoded = []
for key, value in state_group_delta.items():
state_group_delta_encoded.append((key[0], key[1], _encode_state_dict(value)))
return state_group_delta_encoded
def _decode_state_group_delta(
input: List[Tuple[int, int, List[Tuple[str, str, str]]]],
) -> Dict[Tuple[int, int], StateMap[str]]:
if not input:
return {}
state_group_deltas = {}
for state_group_1, state_group_2, state_dict in input:
state_map = _decode_state_dict(state_dict)
assert state_map is not None
state_group_deltas[(state_group_1, state_group_2)] = state_map
return state_group_deltas
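# Round-trip sketch (hypothetical values): the encoder flattens the mapping into
# JSON-friendly lists, and the decoder restores the original structure.
#
#   >>> encoded = _encode_state_group_delta({(1, 2): {("m.room.name", ""): "$ev"}})
#   >>> encoded
#   [(1, 2, [("m.room.name", "", "$ev")])]
#   >>> _decode_state_group_delta(encoded)
#   {(1, 2): {("m.room.name", ""): "$ev"}}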
def _encode_state_dict(
state_dict: Optional[StateMap[str]],
......@@ -483,7 +544,7 @@ def _encode_state_dict(
def _decode_state_dict(
input: Optional[List[Tuple[str, str, str]]]
input: Optional[List[Tuple[str, str, str]]],
) -> Optional[StateMap[str]]:
"""Decodes a state dict encoded using `_encode_state_dict` above"""
if input is None:
......
# Copyright 2014-2016 OpenMarket Ltd
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2021 The Matrix.org Foundation C.I.C.
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections.abc
import re
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Mapping,
Match,
MutableMapping,
Optional,
Union,
......@@ -40,23 +49,29 @@ from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.types import JsonDict, Requester
from . import EventBase
from . import EventBase, StrippedStateEvent, make_event_from_dict
if TYPE_CHECKING:
from synapse.handlers.relations import BundledAggregations
from synapse.server import HomeServer
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
# (?<!stuff) matches if the current position in the string is not preceded
# by a match for 'stuff'.
# TODO: This is fast, but fails to handle "foo\\.bar" which should be treated as
# the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar"
SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.")
# Split strings on "." but not "\." (or "\\\.").
SPLIT_FIELD_REGEX = re.compile(r"\\*\.")
# Find escaped characters, e.g. those with a \ in front of them.
ESCAPE_SEQUENCE_PATTERN = re.compile(r"\\(.)")
CANONICALJSON_MAX_INT = (2**53) - 1
CANONICALJSON_MIN_INT = -CANONICALJSON_MAX_INT
# Module API callback that allows adding fields to the unsigned section of
# events that are sent to clients.
ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK = Callable[
[EventBase], Awaitable[JsonDict]
]
def prune_event(event: EventBase) -> EventBase:
"""Returns a pruned version of the given event, which removes all keys we
don't know about or think could potentially be dodgy.
......@@ -67,17 +82,15 @@ def prune_event(event: EventBase) -> EventBase:
"""
pruned_event_dict = prune_event_dict(event.room_version, event.get_dict())
from . import make_event_from_dict
pruned_event = make_event_from_dict(
pruned_event_dict, event.room_version, event.internal_metadata.get_dict()
)
# copy the internal fields
# Copy the bits of `internal_metadata` that aren't returned by `get_dict`
pruned_event.internal_metadata.stream_ordering = (
event.internal_metadata.stream_ordering
)
pruned_event.internal_metadata.instance_name = event.internal_metadata.instance_name
pruned_event.internal_metadata.outlier = event.internal_metadata.outlier
# Mark the event as redacted
......@@ -86,6 +99,30 @@ def prune_event(event: EventBase) -> EventBase:
return pruned_event
def clone_event(event: EventBase) -> EventBase:
"""Take a copy of the event.
This is mostly useful because it does a *shallow* copy of the `unsigned` data,
which means it can then be updated without corrupting the in-memory cache. Note that
other properties of the event, such as `content`, are *not* (currently) copied here.
"""
# XXX: We rely on at least one of `event.get_dict()` and `make_event_from_dict()`
# making a copy of `unsigned`. Currently, both do, though I don't really know why.
# Still, as long as they do, there's not much point doing yet another copy here.
new_event = make_event_from_dict(
event.get_dict(), event.room_version, event.internal_metadata.get_dict()
)
# Copy the bits of `internal_metadata` that aren't returned by `get_dict`.
new_event.internal_metadata.stream_ordering = (
event.internal_metadata.stream_ordering
)
new_event.internal_metadata.instance_name = event.internal_metadata.instance_name
new_event.internal_metadata.outlier = event.internal_metadata.outlier
return new_event
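# Usage sketch (hypothetical field name): the clone's `unsigned` dict can be
# mutated without corrupting the cached original, since it was shallow-copied.
#
#   >>> copied = clone_event(event)
#   >>> copied.unsigned["our_extra_field"] = 1
#   >>> "our_extra_field" in event.unsigned
#   False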
def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDict:
"""Redacts the event_dict in the same way as `prune_event`, except it
operates on dicts rather than event objects
......@@ -109,13 +146,9 @@ def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDic
"origin_server_ts",
]
# Room versions from before MSC2176 had additional allowed keys.
if not room_version.msc2176_redaction_rules:
allowed_keys.extend(["prev_state", "membership"])
# Room versions before MSC3989 kept the origin field.
if not room_version.msc3989_redaction_rules:
allowed_keys.append("origin")
# Earlier room versions from had additional allowed keys.
if not room_version.updated_redaction_rules:
allowed_keys.extend(["prev_state", "membership", "origin"])
event_type = event_dict["type"]
......@@ -128,17 +161,29 @@ def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDic
if event_type == EventTypes.Member:
add_fields("membership")
if room_version.msc3375_redaction_rules:
if room_version.restricted_join_rule_fix:
add_fields(EventContentFields.AUTHORISING_USER)
if room_version.updated_redaction_rules:
# Preserve the signed field under third_party_invite.
third_party_invite = event_dict["content"].get("third_party_invite")
if isinstance(third_party_invite, collections.abc.Mapping):
new_content["third_party_invite"] = {}
if "signed" in third_party_invite:
new_content["third_party_invite"]["signed"] = third_party_invite[
"signed"
]
elif event_type == EventTypes.Create:
# MSC2176 rules state that create events cannot be redacted.
if room_version.msc2176_redaction_rules:
return event_dict
if room_version.updated_redaction_rules:
# MSC2176 rules state that create events cannot have their `content` redacted.
new_content = event_dict["content"]
elif not room_version.implicit_room_creator:
# Some room versions give meaning to `creator`
add_fields("creator")
add_fields("creator")
elif event_type == EventTypes.JoinRules:
add_fields("join_rule")
if room_version.msc3083_join_rules:
if room_version.restricted_join_rule:
add_fields("allow")
elif event_type == EventTypes.PowerLevels:
add_fields(
......@@ -152,24 +197,27 @@ def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDic
"redact",
)
if room_version.msc2176_redaction_rules:
if room_version.updated_redaction_rules:
add_fields("invite")
if room_version.msc2716_historical:
add_fields("historical")
elif event_type == EventTypes.Aliases and room_version.special_case_aliases_auth:
add_fields("aliases")
elif event_type == EventTypes.RoomHistoryVisibility:
add_fields("history_visibility")
elif event_type == EventTypes.Redaction and room_version.msc2176_redaction_rules:
elif event_type == EventTypes.Redaction and room_version.updated_redaction_rules:
add_fields("redacts")
elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_INSERTION:
add_fields(EventContentFields.MSC2716_NEXT_BATCH_ID)
elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_BATCH:
add_fields(EventContentFields.MSC2716_BATCH_ID)
elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_MARKER:
add_fields(EventContentFields.MSC2716_INSERTION_EVENT_REFERENCE)
# Protect the rel_type and event_id fields under the m.relates_to field.
if room_version.msc3389_relation_redactions:
relates_to = event_dict["content"].get("m.relates_to")
if isinstance(relates_to, collections.abc.Mapping):
new_relates_to = {}
for field in ("rel_type", "event_id"):
if field in relates_to:
new_relates_to[field] = relates_to[field]
# Only include a non-empty relates_to field.
if new_relates_to:
new_content["m.relates_to"] = new_relates_to
allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys}
......@@ -231,6 +279,57 @@ def _copy_field(src: JsonDict, dst: JsonDict, field: List[str]) -> None:
sub_out_dict[key_to_move] = sub_dict[key_to_move]
def _escape_slash(m: Match[str]) -> str:
"""
Replacement function; replace a backslash-backslash or backslash-dot with the
second character. Leaves any other string alone.
"""
if m.group(1) in ("\\", "."):
return m.group(1)
return m.group(0)
def _split_field(field: str) -> List[str]:
"""
Splits strings on unescaped dots and removes escaping.
Args:
field: A string representing a path to a field.
Returns:
A list of nested fields to traverse.
"""
# Convert the field and remove escaping:
#
# 1. "content.body.thing\.with\.dots"
# 2. ["content", "body", "thing\.with\.dots"]
# 3. ["content", "body", "thing.with.dots"]
# Find all dots (and their preceding backslashes). If the dot is unescaped
# then emit a new field part.
result = []
prev_start = 0
for match in SPLIT_FIELD_REGEX.finditer(field):
# If the match is an *even* number of characters then the dot was escaped.
if len(match.group()) % 2 == 0:
continue
# Add a new part (up to the dot, exclusive) after escaping.
result.append(
ESCAPE_SEQUENCE_PATTERN.sub(
_escape_slash, field[prev_start : match.end() - 1]
)
)
prev_start = match.end()
# Add any part of the field after the last unescaped dot. (Note that if the
# character is a dot this correctly adds a blank string.)
result.append(re.sub(r"\\(.)", _escape_slash, field[prev_start:]))
return result
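# Worked examples of the escaping rules above:
#
#   >>> _split_field(r"content.body.thing\.with\.dots")
#   ['content', 'body', 'thing.with.dots']
#   >>> _split_field(r"content.body\\.a")  # escaped backslash followed by a dot
#   ['content', 'body\\', 'a']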
def only_fields(dictionary: JsonDict, fields: List[str]) -> JsonDict:
"""Return a new dict with only the fields in 'dictionary' which are present
in 'fields'.
......@@ -238,7 +337,7 @@ def only_fields(dictionary: JsonDict, fields: List[str]) -> JsonDict:
If there are no event fields specified then all fields are included.
The entries may include '.' characters to indicate sub-fields.
So ['content.body'] will include the 'body' field of the 'content' object.
A literal '.' character in a field name may be escaped using a '\'.
A literal '.' or '\' character in a field name may be escaped using a '\'.
Args:
dictionary: The dictionary to read from.
......@@ -253,13 +352,7 @@ def only_fields(dictionary: JsonDict, fields: List[str]) -> JsonDict:
# for each field, convert it:
# ["content.body.thing\.with\.dots"] => [["content", "body", "thing\.with\.dots"]]
split_fields = [SPLIT_FIELD_REGEX.split(f) for f in fields]
# for each element of the output array of arrays:
# remove escaping so we can use the right key names.
split_fields[:] = [
[f.replace(r"\.", r".") for f in field_array] for field_array in split_fields
]
split_fields = [_split_field(f) for f in fields]
output: JsonDict = {}
for field_array in split_fields:
......@@ -339,7 +432,6 @@ def serialize_event(
time_now_ms: int,
*,
config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
msc3970_enabled: bool = False,
) -> JsonDict:
"""Serialize event for clients
......@@ -347,8 +439,6 @@ def serialize_event(
e
time_now_ms
config: Event serialization config
msc3970_enabled: Whether MSC3970 is enabled. It changes whether we should
include the `transaction_id` in the event's `unsigned` section.
Returns:
The serialized event dictionary.
......@@ -374,38 +464,46 @@ def serialize_event(
e.unsigned["redacted_because"],
time_now_ms,
config=config,
msc3970_enabled=msc3970_enabled,
)
# If we have a txn_id saved in the internal_metadata, we should include it in the
# unsigned section of the event if it was sent by the same session as the one
# requesting the event.
txn_id: Optional[str] = getattr(e.internal_metadata, "txn_id", None)
if txn_id is not None and config.requester is not None:
# For the MSC3970 rules to be applied, we *need* to have the device ID in the
# event internal metadata. Since we were not recording them before, if it hasn't
# been recorded, we fall back to the old behaviour.
if (
txn_id is not None
and config.requester is not None
and config.requester.user.to_string() == e.sender
):
# Some events do not have the device ID stored in the internal metadata;
# this includes old events as well as those created by appservices, guests,
# or with tokens minted with the admin API. For those events, fall back
# to using the access token instead.
event_device_id: Optional[str] = getattr(e.internal_metadata, "device_id", None)
if msc3970_enabled and event_device_id is not None:
if event_device_id is not None:
if event_device_id == config.requester.device_id:
d["unsigned"]["transaction_id"] = txn_id
else:
# The pre-MSC3970 behaviour is to only include the transaction ID if the
# event was sent from the same access token. For regular users, we can use
# the access token ID to determine this. For guests, we can't, but since
# each guest only has one access token, we can just check that the event was
# sent by the same user as the one requesting the event.
# Fallback behaviour: only include the transaction ID if the event
# was sent from the same access token.
#
# For regular users, the access token ID can be used to determine this.
# This includes access tokens minted with the admin API.
#
# For guests and appservice users, we can't check the access token ID
# so assume it is the same session.
event_token_id: Optional[int] = getattr(
e.internal_metadata, "token_id", None
)
if config.requester.user.to_string() == e.sender and (
if (
(
event_token_id is not None
and config.requester.access_token_id is not None
and event_token_id == config.requester.access_token_id
)
or config.requester.is_guest
or config.requester.app_service
):
d["unsigned"]["transaction_id"] = txn_id
......@@ -420,6 +518,17 @@ def serialize_event(
if config.as_client_event:
d = config.event_format(d)
# If the event is a redaction, the field with the redacted event ID appears
# in a different location depending on the room version. e.redacts handles
# fetching from the proper location; copy it to the other location for forwards-
# and backwards-compatibility with clients.
if e.type == EventTypes.Redaction and e.redacts is not None:
if e.room_version.updated_redaction_rules:
d["redacts"] = e.redacts
else:
d["content"] = dict(d["content"])
d["content"]["redacts"] = e.redacts
only_event_fields = config.only_event_fields
if only_event_fields:
if not isinstance(only_event_fields, list) or not all(
......@@ -438,10 +547,13 @@ class EventClientSerializer:
clients.
"""
def __init__(self, *, msc3970_enabled: bool = False):
self._msc3970_enabled = msc3970_enabled
def __init__(self, hs: "HomeServer") -> None:
self._store = hs.get_datastores().main
self._add_extra_fields_to_unsigned_client_event_callbacks: List[
ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK
] = []
def serialize_event(
async def serialize_event(
self,
event: Union[JsonDict, EventBase],
time_now: int,
......@@ -465,14 +577,23 @@ class EventClientSerializer:
if not isinstance(event, EventBase):
return event
serialized_event = serialize_event(
event, time_now, config=config, msc3970_enabled=self._msc3970_enabled
)
serialized_event = serialize_event(event, time_now, config=config)
new_unsigned = {}
for callback in self._add_extra_fields_to_unsigned_client_event_callbacks:
u = await callback(event)
new_unsigned.update(u)
if new_unsigned:
# We do the `update` this way round so that modules can't clobber
# existing fields.
new_unsigned.update(serialized_event["unsigned"])
serialized_event["unsigned"] = new_unsigned
# Check if there are any bundled aggregations to include with the event.
if bundle_aggregations:
if event.event_id in bundle_aggregations:
self._inject_bundled_aggregations(
await self._inject_bundled_aggregations(
event,
time_now,
config,
......@@ -482,7 +603,7 @@ class EventClientSerializer:
return serialized_event
def _inject_bundled_aggregations(
async def _inject_bundled_aggregations(
self,
event: EventBase,
time_now: int,
......@@ -513,9 +634,9 @@ class EventClientSerializer:
serialized_aggregations = {}
if event_aggregations.references:
serialized_aggregations[
RelationTypes.REFERENCE
] = event_aggregations.references
serialized_aggregations[RelationTypes.REFERENCE] = (
event_aggregations.references
)
if event_aggregations.replace:
# Include information about it in the relations dict.
......@@ -524,7 +645,7 @@ class EventClientSerializer:
# said that we should only include the `event_id`, `origin_server_ts` and
# `sender` of the edit; however MSC3925 proposes extending it to the whole
# of the edit, which is what we do here.
serialized_aggregations[RelationTypes.REPLACE] = self.serialize_event(
serialized_aggregations[RelationTypes.REPLACE] = await self.serialize_event(
event_aggregations.replace,
time_now,
config=config,
......@@ -534,7 +655,7 @@ class EventClientSerializer:
if event_aggregations.thread:
thread = event_aggregations.thread
serialized_latest_event = self.serialize_event(
serialized_latest_event = await self.serialize_event(
thread.latest_event,
time_now,
config=config,
......@@ -557,7 +678,7 @@ class EventClientSerializer:
"m.relations", {}
).update(serialized_aggregations)
def serialize_events(
async def serialize_events(
self,
events: Iterable[Union[JsonDict, EventBase]],
time_now: int,
......@@ -579,7 +700,7 @@ class EventClientSerializer:
The list of serialized events
"""
return [
self.serialize_event(
await self.serialize_event(
event,
time_now,
config=config,
......@@ -588,6 +709,14 @@ class EventClientSerializer:
for event in events
]
def register_add_extra_fields_to_unsigned_client_event_callback(
self, callback: ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK
) -> None:
"""Register a callback that returns additions to the unsigned section of
serialized events.
"""
self._add_extra_fields_to_unsigned_client_event_callbacks.append(callback)
_PowerLevel = Union[str, int]
PowerLevelsContent = Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]]
......@@ -636,7 +765,7 @@ def _copy_power_level_value_as_integer(
:raises TypeError: if `old_value` is neither an integer nor a base-10 string
representation of an integer.
"""
if type(old_value) is int:
if type(old_value) is int: # noqa: E721
power_levels[key] = old_value
return
......@@ -664,7 +793,7 @@ def validate_canonicaljson(value: Any) -> None:
* Floats
* NaN, Infinity, -Infinity
"""
if type(value) is int:
if type(value) is int: # noqa: E721
if value < CANONICALJSON_MIN_INT or CANONICALJSON_MAX_INT < value:
raise SynapseError(400, "JSON integer out of range", Codes.BAD_JSON)
......@@ -707,3 +836,48 @@ def maybe_upsert_event_field(
del container[key]
return upsert_okay
def strip_event(event: EventBase) -> JsonDict:
"""
Used for "stripped state" events which provide a simplified view of the state of a
room intended to help a potential joiner identify the room (relevant when the user
is invited or knocked).
Stripped state events can only have the `sender`, `type`, `state_key` and `content`
properties present.
"""
return {
"type": event.type,
"state_key": event.state_key,
"content": event.content,
"sender": event.sender,
}
def parse_stripped_state_event(raw_stripped_event: Any) -> Optional[StrippedStateEvent]:
"""
Given a raw value from an event's `unsigned` field, attempt to parse it into a
`StrippedStateEvent`.
"""
if isinstance(raw_stripped_event, dict):
# All of these fields are required
type = raw_stripped_event.get("type")
state_key = raw_stripped_event.get("state_key")
sender = raw_stripped_event.get("sender")
content = raw_stripped_event.get("content")
if (
isinstance(type, str)
and isinstance(state_key, str)
and isinstance(sender, str)
and isinstance(content, dict)
):
return StrippedStateEvent(
type=type,
state_key=state_key,
sender=sender,
content=content,
)
return None
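# Parsing sketch (hypothetical values): a dict with all four required fields of
# the right types parses successfully; anything else yields None.
#
#   >>> parse_stripped_state_event(
#   ...     {"type": "m.room.name", "state_key": "", "sender": "@a:b", "content": {"name": "hi"}}
#   ... )  # returns a StrippedStateEvent
#   >>> parse_stripped_state_event("not a dict") is None
#   True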
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections.abc
from typing import Iterable, List, Type, Union, cast
from typing import List, Type, Union, cast
import jsonschema
from pydantic import Field, StrictBool, StrictStr
from synapse._pydantic_compat import Field, StrictBool, StrictStr
from synapse.api.constants import (
MAX_ALIAS_LENGTH,
EventContentFields,
......@@ -33,10 +40,10 @@ from synapse.events.utils import (
CANONICALJSON_MIN_INT,
validate_canonicaljson,
)
from synapse.federation.federation_server import server_matches_acl_event
from synapse.http.servlet import validate_json_object
from synapse.rest.models import RequestBodyModel
from synapse.types import EventID, JsonDict, RoomID, UserID
from synapse.storage.controllers.state import server_acl_evaluator_from_event
from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID
from synapse.types.rest import RequestBodyModel
class EventValidator:
......@@ -51,7 +58,7 @@ class EventValidator:
event: The event to validate.
config: The homeserver's configuration.
"""
self.validate_builder(event)
self.validate_builder(event, config)
if event.format_version == EventFormatVersions.ROOM_V1_V2:
EventID.from_string(event.event_id)
......@@ -82,6 +89,12 @@ class EventValidator:
# Note that only the client controlled portion of the event is
# checked, since we trust the portions of the event we created.
validate_canonicaljson(event.content)
if not 0 < event.origin_server_ts < 2**53:
raise SynapseError(400, "Event timestamp is out of range")
# meow: allow specific users to send potentially dangerous events.
if event.sender in config.meow.validation_override:
return
if event.type == EventTypes.Aliases:
if "aliases" in event.content:
......@@ -100,7 +113,10 @@ class EventValidator:
self._validate_retention(event)
elif event.type == EventTypes.ServerACL:
if not server_matches_acl_event(config.server.server_name, event):
server_acl_evaluator = server_acl_evaluator_from_event(event)
if not server_acl_evaluator.server_matches_acl_event(
config.server.server_name
):
raise SynapseError(
400, "Can't create an ACL event that denies the local server"
)
......@@ -134,13 +150,8 @@ class EventValidator:
)
# If the event contains a mentions key, validate it.
if (
EventContentFields.MSC3952_MENTIONS in event.content
and config.experimental.msc3952_intentional_mentions
):
validate_json_object(
event.content[EventContentFields.MSC3952_MENTIONS], Mentions
)
if EventContentFields.MENTIONS in event.content:
validate_json_object(event.content[EventContentFields.MENTIONS], Mentions)
def _validate_retention(self, event: EventBase) -> None:
"""Checks that an event that defines the retention policy for a room respects the
......@@ -156,7 +167,7 @@ class EventValidator:
max_lifetime = event.content.get("max_lifetime")
if min_lifetime is not None:
if type(min_lifetime) is not int:
if type(min_lifetime) is not int: # noqa: E721
raise SynapseError(
code=400,
msg="'min_lifetime' must be an integer",
......@@ -164,7 +175,7 @@ class EventValidator:
)
if max_lifetime is not None:
if type(max_lifetime) is not int:
if type(max_lifetime) is not int: # noqa: E721
raise SynapseError(
code=400,
msg="'max_lifetime' must be an integer",
......@@ -182,7 +193,9 @@ class EventValidator:
errcode=Codes.BAD_JSON,
)
def validate_builder(self, event: Union[EventBase, EventBuilder]) -> None:
def validate_builder(
self, event: Union[EventBase, EventBuilder], config: HomeServerConfig
) -> None:
"""Validates that the builder/event has roughly the right format. Only
checks values that we expect a proto event to have, rather than all the
fields an event would have
......@@ -200,6 +213,10 @@ class EventValidator:
RoomID.from_string(event.room_id)
UserID.from_string(event.sender)
# meow: allow specific users to send so-called invalid events
if event.sender in config.meow.validation_override:
return
if event.type == EventTypes.Message:
strings = ["body", "msgtype"]
......@@ -230,7 +247,7 @@ class EventValidator:
self._ensure_state_event(event)
def _ensure_strings(self, d: JsonDict, keys: Iterable[str]) -> None:
def _ensure_strings(self, d: JsonDict, keys: StrCollection) -> None:
for s in keys:
if s not in d:
raise SynapseError(400, "'%s' not in content" % (s,))
......
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" This package includes all the federation specific logic.
"""
"""This package includes all the federation specific logic."""
# Copyright 2015, 2016 OpenMarket Ltd
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Awaitable, Callable, Optional
......@@ -49,7 +56,7 @@ class FederationBase:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.server_name = hs.hostname
self._is_mine_server_name = hs.is_mine_server_name
self.keyring = hs.get_keyring()
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
self.store = hs.get_datastores().main
......@@ -231,7 +238,7 @@ async def _check_sigs_on_pdu(
# If this is a join event for a restricted room it may have been authorised
# via a different server from the sending server. Check those signatures.
if (
room_version.msc3083_join_rules
room_version.restricted_join_rule
and pdu.type == EventTypes.Member
and pdu.membership == Membership.JOIN
and EventContentFields.AUTHORISING_USER in pdu.content
......@@ -280,7 +287,7 @@ def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventB
_strip_unsigned_values(pdu_json)
depth = pdu_json["depth"]
if type(depth) is not int:
if type(depth) is not int: # noqa: E721
raise SynapseError(400, "Depth %r not an intger" % (depth,), Codes.BAD_JSON)
if depth < 0:
......
# Copyright 2015-2022 The Matrix.org Foundation C.I.C.
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2020 Sorunome
# Copyright 2015-2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
......@@ -21,6 +28,7 @@ from typing import (
TYPE_CHECKING,
AbstractSet,
Awaitable,
BinaryIO,
Callable,
Collection,
Container,
......@@ -48,6 +56,7 @@ from synapse.api.errors import (
SynapseError,
UnsupportedRoomVersionError,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
......@@ -64,7 +73,7 @@ from synapse.federation.transport.client import SendJoinResponse
from synapse.http.client import is_unknown_endpoint
from synapse.http.types import QueryParams
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, trace
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.types import JsonDict, StrCollection, UserID, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
......@@ -235,11 +244,16 @@ class FederationClient(FederationBase):
)
async def claim_client_keys(
self, destination: str, content: JsonDict, timeout: Optional[int]
self,
user: UserID,
destination: str,
query: Dict[str, Dict[str, Dict[str, int]]],
timeout: Optional[int],
) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server.
Args:
user: The user id of the requesting user
destination: Domain name of the remote homeserver
content: The query content.
......@@ -247,8 +261,55 @@ class FederationClient(FederationBase):
The JSON object from the response
"""
sent_queries_counter.labels("client_one_time_keys").inc()
# Convert the query with counts into a stable and unstable query and check
# if attempting to claim more than 1 OTK.
content: Dict[str, Dict[str, str]] = {}
unstable_content: Dict[str, Dict[str, List[str]]] = {}
use_unstable = False
for user_id, one_time_keys in query.items():
for device_id, algorithms in one_time_keys.items():
# If more than one algorithm is requested, attempt to use the unstable
# endpoint.
if sum(algorithms.values()) > 1:
use_unstable = True
if algorithms:
# For the stable query, choose only the first algorithm.
content.setdefault(user_id, {})[device_id] = next(iter(algorithms))
# For the unstable query, repeat each algorithm by count, then
# splat those into chain to get a flattened list of all algorithms.
#
# Converts from {"algo1": 2, "algo2": 2} to ["algo1", "algo1", "algo2"].
unstable_content.setdefault(user_id, {})[device_id] = list(
itertools.chain(
*(
itertools.repeat(algorithm, count)
for algorithm, count in algorithms.items()
)
)
)
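# For example (hypothetical IDs), a query of
#     {"@u:s": {"DEVICE": {"signed_curve25519": 2}}}
# produces content = {"@u:s": {"DEVICE": "signed_curve25519"}} and
# unstable_content = {"@u:s": {"DEVICE": ["signed_curve25519", "signed_curve25519"]}}.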
if use_unstable:
try:
return await self.transport_layer.claim_client_keys_unstable(
user, destination, unstable_content, timeout
)
except HttpResponseException as e:
# If an error is received that is due to an unrecognised endpoint,
# fallback to the v1 endpoint. Otherwise, consider it a legitimate error
# and raise.
if not is_unknown_endpoint(e):
raise
logger.debug(
"Couldn't claim client keys with the unstable API, falling back to the v1 API"
)
else:
logger.debug("Skipping unstable claim client keys API")
# TODO Potentially attempt multiple queries and combine the results?
return await self.transport_layer.claim_client_keys(
destination, content, timeout
user, destination, content, timeout
)
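# ---------------------------------------------------------------------------
# Editor's note: the following standalone sketch is not part of the diff; it
# illustrates the count-to-list expansion performed above in
# claim_client_keys. The helper name `flatten_algorithm_counts` is
# hypothetical, but the itertools pattern mirrors the code above.
import itertools
from typing import Dict, List

def flatten_algorithm_counts(algorithms: Dict[str, int]) -> List[str]:
    """Expand {"algo1": 2, "algo2": 2} into ["algo1", "algo1", "algo2", "algo2"]."""
    return list(
        itertools.chain(
            *(
                itertools.repeat(algorithm, count)
                for algorithm, count in algorithms.items()
            )
        )
    )

assert flatten_algorithm_counts({"algo1": 2, "algo2": 2}) == [
    "algo1",
    "algo1",
    "algo2",
    "algo2",
]
# This is also why the stable query keeps only `next(iter(algorithms))`: it
# can express at most one requested key per device, while the unstable form
# carries the full multiset of requested algorithms.
# ---------------------------------------------------------------------------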
@trace
@@ -807,7 +868,7 @@ class FederationClient(FederationBase):
for destination in destinations:
# We don't want to ask our own server for information we don't have
if destination == self.server_name:
if self._is_mine_server_name(destination):
continue
try:
@@ -931,7 +992,7 @@ class FederationClient(FederationBase):
if not room_version:
raise UnsupportedRoomVersionError()
if not room_version.msc2403_knocking and membership == Membership.KNOCK:
if not room_version.knock_join_rule and membership == Membership.KNOCK:
raise SynapseError(
400,
"This room version does not support knocking",
@@ -1017,7 +1078,7 @@ class FederationClient(FederationBase):
# * Ensure the signatures are good.
#
# Otherwise, fallback to the provided event.
if room_version.msc3083_join_rules and response.event:
if room_version.restricted_join_rule and response.event:
event = response.event
valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
@@ -1097,7 +1158,7 @@ class FederationClient(FederationBase):
# NB: We *need* to copy here, otherwise multiple references to the same
# internal_metadata object get passed around, and that shared mutable
# state causes subtle bugs.
for s in signed_state:
s.internal_metadata = copy.deepcopy(s.internal_metadata)
s.internal_metadata = s.internal_metadata.copy()
# double-check that the auth chain doesn't include a different create event
auth_chain_create_events = [
@@ -1143,7 +1204,7 @@ class FederationClient(FederationBase):
# MSC3083 defines additional error codes for room joins.
failover_errcodes = None
if room_version.msc3083_join_rules:
if room_version.restricted_join_rule:
failover_errcodes = (
Codes.UNABLE_AUTHORISE_JOIN,
Codes.UNABLE_TO_GRANT_JOIN,
@@ -1350,7 +1411,7 @@ class FederationClient(FederationBase):
The remote homeserver returns some state from the room. The response
dictionary is in the form:
{"knock_state_events": [<state event dict>, ...]}
{"knock_room_state": [<state event dict>, ...]}
The list of state events may be empty.
@@ -1377,7 +1438,7 @@ class FederationClient(FederationBase):
The remote homeserver can optionally return some state from the room. The response
dictionary is in the form:
{"knock_state_events": [<state event dict>, ...]}
{"knock_room_state": [<state event dict>, ...]}
The list of state events may be empty.
"""
@@ -1489,7 +1550,7 @@ class FederationClient(FederationBase):
self, destinations: Iterable[str], room_id: str, event_dict: JsonDict
) -> None:
for destination in destinations:
if destination == self.server_name:
if self._is_mine_server_name(destination):
continue
try:
@@ -1652,7 +1713,7 @@ class FederationClient(FederationBase):
async def timestamp_to_event(
self,
*,
destinations: List[str],
destinations: StrCollection,
room_id: str,
timestamp: int,
direction: Direction,
@@ -1810,6 +1871,95 @@ class FederationClient(FederationBase):
return filtered_statuses, filtered_failures
async def federation_download_media(
self,
destination: str,
media_id: str,
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Union[
Tuple[int, Dict[bytes, List[bytes]], bytes],
Tuple[int, Dict[bytes, List[bytes]]],
]:
try:
return await self.transport_layer.federation_download_media(
destination,
media_id,
output_stream=output_stream,
max_size=max_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
except HttpResponseException as e:
# If an error is received that is due to an unrecognised endpoint,
# fall back to the _matrix/media/v3/download endpoint. Otherwise,
# consider it a legitimate error and raise.
if not is_unknown_endpoint(e):
raise
logger.debug(
"Couldn't download media %s/%s over _matrix/federation/v1/media/download, falling back to _matrix/media/v3/download path",
destination,
media_id,
)
return await self.transport_layer.download_media_v3(
destination,
media_id,
output_stream=output_stream,
max_size=max_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
async def download_media(
self,
destination: str,
media_id: str,
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
try:
return await self.transport_layer.download_media_v3(
destination,
media_id,
output_stream=output_stream,
max_size=max_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
except HttpResponseException as e:
# If an error is received that is due to an unrecognised endpoint,
# fall back to the r0 endpoint. Otherwise, consider it a legitimate
# error and raise.
if not is_unknown_endpoint(e):
raise
logger.debug(
"Couldn't download media %s/%s with the v3 API, falling back to the r0 API",
destination,
media_id,
)
return await self.transport_layer.download_media_r0(
destination,
media_id,
output_stream=output_stream,
max_size=max_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
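# ---------------------------------------------------------------------------
# Editor's note: a standalone sketch, not part of the diff. Both download
# helpers above (and the unstable key-claim path earlier) share the same
# negotiation idiom: try the newer endpoint first and fall back only when the
# remote does not recognise it. `fetch_new` and `fetch_old` are hypothetical
# callables standing in for the two transport-layer calls.
from typing import Awaitable, Callable, TypeVar

from synapse.api.errors import HttpResponseException
from synapse.http.client import is_unknown_endpoint

T = TypeVar("T")

async def fetch_with_fallback(
    fetch_new: Callable[[], Awaitable[T]],
    fetch_old: Callable[[], Awaitable[T]],
) -> T:
    try:
        return await fetch_new()
    except HttpResponseException as e:
        # Only an "unknown endpoint" response (e.g. 404/405 with an
        # M_UNRECOGNIZED error) triggers the fallback; any other HTTP error
        # is legitimate and is re-raised to the caller.
        if not is_unknown_endpoint(e):
            raise
    return await fetch_old()
# ---------------------------------------------------------------------------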
@attr.s(frozen=True, slots=True, auto_attribs=True)
class TimestampToEventResponse:
@@ -1839,7 +1989,7 @@ class TimestampToEventResponse:
)
origin_server_ts = d.get("origin_server_ts")
if type(origin_server_ts) is not int:
if type(origin_server_ts) is not int: # noqa: E721
raise ValueError(
"Invalid response: 'origin_server_ts' must be a int but received %r"
% origin_server_ts
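# ---------------------------------------------------------------------------
# Editor's note, not part of the diff: the exact type() comparison above is
# deliberate (hence the suppressed E721 lint). bool is a subclass of int in
# Python, so an isinstance check would wrongly accept a JSON `true` as a
# timestamp:
assert isinstance(True, int)  # passes: bool subclasses int
assert type(True) is not int  # noqa: E721 - only the exact check rejects bools
# ---------------------------------------------------------------------------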