Skip to content
Snippets Groups Projects
devicemessage.py 15.5 KiB
Newer Older
  • Learn to ignore specific revisions
  • Patrick Cloke's avatar
    Patrick Cloke committed
    # This file is licensed under the Affero General Public License (AGPL) version 3.
    #
    
    # Copyright 2016 OpenMarket Ltd
    
    Patrick Cloke's avatar
    Patrick Cloke committed
    # Copyright (C) 2023 New Vector, Ltd
    #
    # This program is free software: you can redistribute it and/or modify
    # it under the terms of the GNU Affero General Public License as
    # published by the Free Software Foundation, either version 3 of the
    # License, or (at your option) any later version.
    #
    # See the GNU Affero General Public License for more details:
    # <https://www.gnu.org/licenses/agpl-3.0.html>.
    #
    # Originally licensed under the Apache License, Version 2.0:
    # <http://www.apache.org/licenses/LICENSE-2.0>.
    #
    # [This file includes modifications made by New Vector Limited]
    
    import logging
    from http import HTTPStatus
    from typing import TYPE_CHECKING, Any, Dict, Optional

    from synapse.api.constants import EduTypes, EventContentFields, ToDeviceEventTypes
    from synapse.api.errors import Codes, SynapseError
    from synapse.api.ratelimiting import Ratelimiter
    from synapse.logging.context import run_in_background
    from synapse.logging.opentracing import (
        SynapseTags,
        get_active_span_text_map,
        log_kv,
        set_tag,
    )
    from synapse.replication.http.devices import (
        ReplicationMultiUserDevicesResyncRestServlet,
    )
    from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
    from synapse.util import json_encoder
    from synapse.util.stringutils import random_string
    
    
    if TYPE_CHECKING:
    
        from synapse.server import HomeServer
    
    logger = logging.getLogger(__name__)
    
    
    
    class DeviceMessageHandler:
    
        def __init__(self, hs: "HomeServer"):
            """Wire the handler up to the homeserver's services.

            As a side effect this tells the federation registry how incoming
            direct-to-device EDUs should be routed: to this handler when the
            current instance is one of the configured to-device writers, or on to
            those writer instances otherwise.

            Args:
                hs: The homeserver this handler belongs to.
            """
            self.store = hs.get_datastores().main
            self.notifier = hs.get_notifier()
            self.is_mine = hs.is_mine

            # Only `get_events_for_dehydrated_device` reads these two attributes, so
            # they are only set up when MSC3814 support is switched on.
            if hs.config.experimental.msc3814_enabled:
                self.event_sources = hs.get_event_sources()
                self.device_handler = hs.get_device_handler()

            # We only need to poke the federation sender explicitly if its on the
            # same instance. Other federation sender instances will get notified by
            # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
            # in the to-device replication stream.
            self.federation_sender = None
            if hs.should_send_federation():
                self.federation_sender = hs.get_federation_sender()

            # If we can handle the to device EDUs we do so, otherwise we route them
            # to the appropriate worker.
            if hs.get_instance_name() in hs.config.worker.writers.to_device:
                hs.get_federation_registry().register_edu_handler(
                    EduTypes.DIRECT_TO_DEVICE, self.on_direct_to_device_edu
                )
            else:
                hs.get_federation_registry().register_instances_for_edu(
                    EduTypes.DIRECT_TO_DEVICE,
                    hs.config.worker.writers.to_device,
                )

            # The handler to call when we think a user's device list might be out of
            # sync. We do all device list resyncing on the master instance, so if
            # we're on a worker we hit the device resync replication API.
            if hs.config.worker.worker_app is None:
                self._multi_user_device_resync = (
                    hs.get_device_handler().device_list_updater.multi_user_device_resync
                )
            else:
                self._multi_user_device_resync = (
                    ReplicationMultiUserDevicesResyncRestServlet.make_client(hs)
                )

            # a rate limiter for room key requests.  The keys are
            # (sending_user_id, sending_device_id).
            self._ratelimiter = Ratelimiter(
                # NOTE(review): the `can_do_action(requester, key)` call style used
                # in this class belongs to the Ratelimiter API that also takes the
                # datastore at construction time -- confirm against
                # `synapse.api.ratelimiting`.
                store=self.store,
                clock=hs.get_clock(),
                cfg=hs.config.ratelimiting.rc_key_requests,
            )
    
        async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
            """
            Handle receiving to-device messages from remote homeservers.

            Note that any errors thrown from this method will cause the federation /send
            request to receive an error response.

            An EDU whose `sender` does not belong to `origin` is logged and dropped
            without raising.

            Args:
                origin: The remote homeserver.
                content: The JSON dictionary containing the to-device messages.

            Raises:
                SynapseError: (400) if any recipient is not a user of this homeserver.
            """
            local_messages = {}
            sender_user_id = content["sender"]
            if origin != get_domain_from_id(sender_user_id):
                logger.warning(
                    "Dropping device message from %r with spoofed sender %r",
                    origin,
                    sender_user_id,
                )
                # Actually drop it: without this return the message would still be
                # stored and delivered under the spoofed sender's name.
                return

            message_type = content["type"]
            message_id = content["message_id"]
            for user_id, by_device in content["messages"].items():
                # we use UserID.from_string to catch invalid user ids
                if not self.is_mine(UserID.from_string(user_id)):
                    logger.warning("To-device message to non-local user %s", user_id)
                    raise SynapseError(400, "Not a user here")

                # Ratelimit key requests by the sending user.
                if message_type == ToDeviceEventTypes.RoomKeyRequest:
                    # There is no local requester for federation traffic, and the
                    # sending device is not part of the key, so the limit is applied
                    # per remote user.
                    allowed, _ = await self._ratelimiter.can_do_action(
                        None, (sender_user_id, None)
                    )
                    if not allowed:
                        logger.info(
                            "Dropping room_key_request from %s to %s due to rate limit",
                            sender_user_id,
                            user_id,
                        )
                        continue

                messages_by_device = {
                    device_id: {
                        "content": message_content,
                        "type": message_type,
                        "sender": sender_user_id,
                    }
                    for device_id, message_content in by_device.items()
                }
                local_messages[user_id] = messages_by_device

                await self._check_for_unknown_devices(
                    message_type, sender_user_id, by_device
                )

            # Add messages to the database.
            # Retrieve the stream id of the last-processed to-device message.
            last_stream_id = await self.store.add_messages_from_remote_to_device_inbox(
                origin, message_id, local_messages
            )

            # Notify listeners that there are new to-device messages to process,
            # handing them the latest stream id.
            self.notifier.on_new_event(
                StreamKeyType.TO_DEVICE, last_stream_id, users=local_messages.keys()
            )
    
        async def _check_for_unknown_devices(
            self,
            message_type: str,
            sender_user_id: str,
            by_device: Dict[str, Dict[str, Any]],
        ) -> None:
            """Checks inbound device messages for unknown remote devices, and if
            found marks the remote cache for the user as stale.

            Only `m.room_key_request` messages are inspected; every other message
            type returns immediately.

            Args:
                message_type: The type of the to-device messages.
                sender_user_id: The remote user that sent the messages.
                by_device: The message contents, keyed by recipient device ID.
            """
            if message_type != "m.room_key_request":
                return

            # Get the sending device IDs. A message without the field contributes
            # `None` to the set.
            requesting_device_ids = {
                message_content.get("requesting_device_id")
                for message_content in by_device.values()
            }

            # Check if we are tracking the devices of the remote user.
            room_ids = await self.store.get_rooms_for_user(sender_user_id)
            if not room_ids:
                logger.info(
                    "Received device message from remote device we don't"
                    " share a room with: %s %s",
                    sender_user_id,
                    requesting_device_ids,
                )
                return

            # If we are tracking check that we know about the sending
            # devices.
            cached_devices = await self.store.get_cached_devices_for_user(sender_user_id)

            unknown_devices = requesting_device_ids - set(cached_devices)
            if unknown_devices:
                logger.info(
                    "Received device message from remote device not in our cache: %s %s",
                    sender_user_id,
                    unknown_devices,
                )
                await self.store.mark_remote_users_device_caches_as_stale((sender_user_id,))

                # Immediately attempt a resync in the background
                run_in_background(self._multi_user_device_resync, user_ids=[sender_user_id])
    
        async def send_device_message(
            self,
            requester: Requester,
            message_type: str,
            messages: Dict[str, Dict[str, JsonDict]],
        ) -> None:
            """
            Handle a request from a user to send to-device message(s).

            Messages for local users are written to their device inboxes; messages
            for remote users are grouped per destination server into EDUs. Invalid
            recipient user IDs and rate-limited key requests are skipped rather than
            failing the whole request.

            Args:
                requester: The user that is sending the to-device messages.
                message_type: The type of to-device messages that are being sent.
                messages: A dictionary containing recipients mapped to messages intended for them.
            """
            sender_user_id = requester.user.to_string()

            set_tag(SynapseTags.TO_DEVICE_TYPE, message_type)
            set_tag(SynapseTags.TO_DEVICE_SENDER, sender_user_id)

            local_messages = {}
            remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
            for user_id, by_device in messages.items():
                if not UserID.is_valid(user_id):
                    logger.warning(
                        "Ignoring attempt to send device message to invalid user: %r",
                        user_id,
                    )
                    continue

                # add an opentracing log entry for each message
                for device_id, message_content in by_device.items():
                    log_kv(
                        {
                            "event": "send_to_device_message",
                            "user_id": user_id,
                            "device_id": device_id,
                            EventContentFields.TO_DEVICE_MSGID: message_content.get(
                                EventContentFields.TO_DEVICE_MSGID
                            ),
                        }
                    )

                # Ratelimit local cross-user key requests by the sending device.
                if (
                    message_type == ToDeviceEventTypes.RoomKeyRequest
                    and user_id != sender_user_id
                ):
                    allowed, _ = await self._ratelimiter.can_do_action(
                        requester, (sender_user_id, requester.device_id)
                    )
                    if not allowed:
                        log_kv({"message": f"dropping key requests to {user_id}"})
                        logger.info(
                            "Dropping room_key_request from %s to %s due to rate limit",
                            sender_user_id,
                            user_id,
                        )
                        continue

                # we use UserID.from_string to catch invalid user ids
                if self.is_mine(UserID.from_string(user_id)):
                    messages_by_device = {
                        device_id: {
                            "content": message_content,
                            "type": message_type,
                            "sender": sender_user_id,
                        }
                        for device_id, message_content in by_device.items()
                    }
                    if messages_by_device:
                        local_messages[user_id] = messages_by_device
                else:
                    destination = get_domain_from_id(user_id)
                    remote_messages.setdefault(destination, {})[user_id] = by_device

            context = get_active_span_text_map()

            remote_edu_contents = {}
            # A distinct loop name keeps the `messages` parameter from being shadowed.
            for destination, destination_messages in remote_messages.items():
                # The EDU contains a "message_id" property which is used for
                # idempotence. Make up a random one.
                message_id = random_string(16)
                log_kv({"destination": destination, "message_id": message_id})

                remote_edu_contents[destination] = {
                    "messages": destination_messages,
                    "sender": sender_user_id,
                    "type": message_type,
                    "message_id": message_id,
                    "org.matrix.opentracing_context": json_encoder.encode(context),
                }

            # Add messages to the database.
            # Retrieve the stream id of the last-processed to-device message.
            last_stream_id = await self.store.add_messages_to_device_inbox(
                local_messages, remote_edu_contents
            )

            # Notify listeners that there are new to-device messages to process,
            # handing them the latest stream id.
            self.notifier.on_new_event(
                StreamKeyType.TO_DEVICE, last_stream_id, users=local_messages.keys()
            )

            if self.federation_sender:
                # Enqueue a new federation transaction to send the new
                # device messages to each remote destination.
                await self.federation_sender.send_device_messages(remote_messages.keys())
    
    
        async def get_events_for_dehydrated_device(
            self,
            requester: Requester,
            device_id: str,
            since_token: Optional[str],
            limit: int,
        ) -> JsonDict:
            """Fetches up to `limit` events sent to `device_id` starting from `since_token`
            and returns the new since token. If there are no more messages, returns an empty
            array.

            Args:
                requester: the user requesting the messages
                device_id: ID of the dehydrated device
                since_token: stream id to start from when fetching messages, in the
                    form `d<integer>`; a missing or empty token starts from 0
                limit: the number of messages to fetch
            Returns:
                A dict containing the to-device messages, as well as a token that the client
                can provide in the next call to fetch the next batch of messages
            Raises:
                SynapseError: (403) if the user has no dehydrated device or `device_id`
                    is not that device; (400) if `since_token` is malformed.
            """

            user_id = requester.user.to_string()

            # only allow fetching messages for the dehydrated device id currently associated
            # with the user
            dehydrated_device = await self.device_handler.get_dehydrated_device(user_id)
            if dehydrated_device is None:
                raise SynapseError(
                    HTTPStatus.FORBIDDEN,
                    "No dehydrated device exists",
                    Codes.FORBIDDEN,
                )

            dehydrated_device_id, _ = dehydrated_device
            if device_id != dehydrated_device_id:
                raise SynapseError(
                    HTTPStatus.FORBIDDEN,
                    "You may only fetch messages for your dehydrated device",
                    Codes.FORBIDDEN,
                )

            since_stream_id = 0
            if since_token:
                # Both ways a token can be malformed produce the same client error.
                invalid_token = SynapseError(
                    HTTPStatus.BAD_REQUEST,
                    "from parameter %r has an invalid format" % (since_token,),
                    errcode=Codes.INVALID_PARAM,
                )
                if not since_token.startswith("d"):
                    raise invalid_token

                try:
                    since_stream_id = int(since_token[1:])
                except ValueError:
                    # `int()` on a str only fails with ValueError; anything else
                    # would be a programming error and should not be masked.
                    raise invalid_token from None

            to_token = self.event_sources.get_current_token().to_device_key

            messages, stream_id = await self.store.get_messages_for_device(
                user_id, device_id, since_stream_id, to_token, limit
            )

            for message in messages:
                # Remove the message id before sending to client
                message_id = message.pop("message_id", None)
                if message_id:
                    set_tag(SynapseTags.TO_DEVICE_EDU_ID, message_id)

            logger.debug(
                "Returning %d to-device messages between %d and %d (current token: %d) for "
                "dehydrated device %s, user_id %s",
                len(messages),
                since_stream_id,
                stream_id,
                to_token,
                device_id,
                user_id,
            )

            return {
                "events": messages,
                "next_batch": f"d{stream_id}",
            }