From e80dad5fa9ccc9fb7645c043a1e1995065c4bb2a Mon Sep 17 00:00:00 2001
From: Devon Hudson <devon.dmytro@gmail.com>
Date: Thu, 14 Nov 2024 16:18:24 +0000
Subject: [PATCH] Move server event filtering logic to rust (#17928)

### Pull Request Checklist

<!-- Please read
https://element-hq.github.io/synapse/latest/development/contributing_guide.html
before submitting your pull request -->

* [X] Pull request is based on the develop branch
* [X] Pull request includes a [changelog
file](https://element-hq.github.io/synapse/latest/development/contributing_guide.html#changelog).
The entry should:
- Be a short description of your change which makes sense to users.
"Fixed a bug that prevented receiving messages from other servers."
instead of "Moved X method from `EventStore` to `EventWorkerStore`.".
  - Use markdown where necessary, mostly for `code blocks`.
  - End with either a period (.) or an exclamation mark (!).
  - Start with a capital letter.
- Feel free to credit yourself, by adding a sentence "Contributed by
@github_username." or "Contributed by [Your Name]." to the end of the
entry.
* [X] [Code
style](https://element-hq.github.io/synapse/latest/code_style.html) is
correct
(run the
[linters](https://element-hq.github.io/synapse/latest/development/contributing_guide.html#run-the-linters))
---
 changelog.d/17928.misc          |   1 +
 rust/src/events/filter.rs       | 107 ++++++++++++++++++++++++++++++++
 rust/src/events/mod.rs          |   4 +-
 rust/src/identifier.rs          |  86 +++++++++++++++++++++++++
 rust/src/lib.rs                 |   2 +
 rust/src/matrix_const.rs        |  28 +++++++++
 rust/src/push/utils.rs          |   1 -
 synapse/synapse_rust/events.pyi |  28 ++++++++-
 synapse/visibility.py           |  66 ++++----------------
 9 files changed, 265 insertions(+), 58 deletions(-)
 create mode 100644 changelog.d/17928.misc
 create mode 100644 rust/src/events/filter.rs
 create mode 100644 rust/src/identifier.rs
 create mode 100644 rust/src/matrix_const.rs

diff --git a/changelog.d/17928.misc b/changelog.d/17928.misc
new file mode 100644
index 0000000000..b5aef4457a
--- /dev/null
+++ b/changelog.d/17928.misc
@@ -0,0 +1 @@
+Move server event filtering logic to Rust.
diff --git a/rust/src/events/filter.rs b/rust/src/events/filter.rs
new file mode 100644
index 0000000000..7e39972c62
--- /dev/null
+++ b/rust/src/events/filter.rs
@@ -0,0 +1,122 @@
+/*
+ * This file is licensed under the Affero General Public License (AGPL) version 3.
+ *
+ * Copyright (C) 2024 New Vector, Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as
+ * published by the Free Software Foundation, either version 3 of the
+ * License, or (at your option) any later version.
+ *
+ * See the GNU Affero General Public License for more details:
+ * <https://www.gnu.org/licenses/agpl-3.0.html>.
+ */
+
+use std::collections::HashMap;
+
+use pyo3::{exceptions::PyValueError, pyfunction, PyResult};
+
+use crate::{
+    identifier::UserID,
+    matrix_const::{
+        HISTORY_VISIBILITY_INVITED, HISTORY_VISIBILITY_JOINED, MEMBERSHIP_INVITE, MEMBERSHIP_JOIN,
+    },
+};
+
+/// Python-facing wrapper around [`event_visible_to_server`].
+///
+/// Registered with pyo3 under the name `event_visible_to_server`. All arguments are
+/// forwarded unchanged, and any error is converted into a Python `ValueError` whose
+/// message is the error's `Display` text.
+#[pyfunction(name = "event_visible_to_server")]
+pub fn event_visible_to_server_py(
+    sender: String,
+    target_server_name: String,
+    history_visibility: String,
+    erased_senders: HashMap<String, bool>,
+    partial_state_invisible: bool,
+    memberships: Vec<(String, String)>, // (state_key, membership)
+) -> PyResult<bool> {
+    event_visible_to_server(
+        sender,
+        target_server_name,
+        history_visibility,
+        erased_senders,
+        partial_state_invisible,
+        memberships,
+    )
+    .map_err(|e| PyValueError::new_err(e.to_string()))
+}
+
+/// Return whether the target server is allowed to see the event.
+///
+/// For a fully stated room, the target server is allowed to see an event E if:
+///   - the state at E has world readable or shared history vis, OR
+///   - the state at E says that the target server is in the room.
+///
+/// For a partially stated room, the target server is allowed to see E if:
+///   - E was created by this homeserver, AND:
+///       - the partial state at E has world readable or shared history vis, OR
+///       - the partial state at E says that the target server is in the room.
+///
+/// The checks run in this order, and the first one that applies decides the result:
+///   1. `sender` is marked as erased in `erased_senders`: not visible.
+///   2. `partial_state_invisible` is set: not visible.
+///   3. `history_visibility` is neither `invited` nor `joined`: visible.
+///   4. Otherwise the event is visible only if `memberships` contains a `join`, or
+///      contains an `invite` while `history_visibility` is `invited`.
+///
+/// # Errors
+///
+/// While scanning `memberships` (step 4), returns an error if a `state_key` is not a
+/// valid user ID or if its server name differs from `target_server_name`. Entries after
+/// the first one that grants visibility are not examined.
+pub fn event_visible_to_server(
+    sender: String,
+    target_server_name: String,
+    history_visibility: String,
+    erased_senders: HashMap<String, bool>,
+    partial_state_invisible: bool,
+    memberships: Vec<(String, String)>, // (state_key, membership)
+) -> anyhow::Result<bool> {
+    // A sender that is absent from the map is treated as not erased.
+    if erased_senders.get(&sender).copied().unwrap_or(false) {
+        return Ok(false);
+    }
+
+    if partial_state_invisible {
+        return Ok(false);
+    }
+
+    if history_visibility != HISTORY_VISIBILITY_INVITED
+        && history_visibility != HISTORY_VISIBILITY_JOINED
+    {
+        return Ok(true);
+    }
+
+    // Look for a membership that lets the target server see the event. Each entry is
+    // validated before it is considered, so a malformed or foreign `state_key` that is
+    // reached is reported as an error rather than silently skipped.
+    for (state_key, membership) in memberships {
+        let user_id = UserID::try_from(state_key.as_str())
+            .map_err(|e| anyhow::anyhow!("invalid user_id ({state_key}): {e}"))?;
+        if user_id.server_name() != target_server_name {
+            return Err(anyhow::anyhow!(
+                "state_key.server_name ({}) does not match target_server_name ({target_server_name})",
+                user_id.server_name()
+            ));
+        }
+
+        match membership.as_str() {
+            MEMBERSHIP_JOIN => return Ok(true),
+            // An invite only grants visibility when history is shared with invitees.
+            MEMBERSHIP_INVITE if history_visibility == HISTORY_VISIBILITY_INVITED => {
+                return Ok(true)
+            }
+            _ => {}
+        }
+    }
+
+    // No user on the target server is in (or suitably invited to) the room.
+    Ok(false)
+}
diff --git a/rust/src/events/mod.rs b/rust/src/events/mod.rs
index a4ade1a178..0bb6cdb181 100644
--- a/rust/src/events/mod.rs
+++ b/rust/src/events/mod.rs
@@ -22,15 +22,17 @@
 
 use pyo3::{
     types::{PyAnyMethods, PyModule, PyModuleMethods},
-    Bound, PyResult, Python,
+    wrap_pyfunction, Bound, PyResult, Python,
 };
 
+pub mod filter;
 mod internal_metadata;
 
 /// Called when registering modules with python.
 pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
     let child_module = PyModule::new_bound(py, "events")?;
     child_module.add_class::<internal_metadata::EventInternalMetadata>()?;
+    child_module.add_function(wrap_pyfunction!(filter::event_visible_to_server_py, m)?)?;
 
     m.add_submodule(&child_module)?;
 
diff --git a/rust/src/identifier.rs b/rust/src/identifier.rs
new file mode 100644
index 0000000000..b199c5838e
--- /dev/null
+++ b/rust/src/identifier.rs
@@ -0,0 +1,101 @@
+/*
+ * This file is licensed under the Affero General Public License (AGPL) version 3.
+ *
+ * Copyright (C) 2024 New Vector, Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as
+ * published by the Free Software Foundation, either version 3 of the
+ * License, or (at your option) any later version.
+ *
+ * See the GNU Affero General Public License for more details:
+ * <https://www.gnu.org/licenses/agpl-3.0.html>.
+ */
+
+//! # Matrix Identifiers
+//!
+//! This module contains definitions and utilities for working with matrix identifiers.
+
+use std::{fmt, ops::Deref};
+
+/// Errors that can occur when parsing a matrix identifier.
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub enum IdentifierError {
+    /// The identifier does not start with the expected sigil (`@` for a user_id).
+    IncorrectSigil,
+    /// The identifier has no `:` separating the localpart from the server name.
+    MissingColon,
+}
+
+impl fmt::Display for IdentifierError {
+    /// Writes the variant name (the same text as the `Debug` representation).
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{:?}", self)
+    }
+}
+
+impl std::error::Error for IdentifierError {}
+
+/// A Matrix user_id.
+///
+/// Values are only created through `TryFrom<&str>`, which guarantees that the wrapped
+/// string starts with `@` and contains at least one `:`.
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct UserID(String);
+
+impl UserID {
+    /// Returns the `localpart` of the user_id: the text between the leading `@` and the
+    /// first `:`.
+    pub fn localpart(&self) -> &str {
+        &self[1..self.colon_pos()]
+    }
+
+    /// Returns the `server_name` / `domain` of the user_id: everything after the first
+    /// `:` (so a port such as `host:8448` stays part of the server name).
+    pub fn server_name(&self) -> &str {
+        &self[self.colon_pos() + 1..]
+    }
+
+    /// Returns the position of the first ':' inside of the user_id.
+    /// Used when splitting the user_id into its respective parts.
+    fn colon_pos(&self) -> usize {
+        self.find(':').expect("UserID is only constructed from strings containing ':'")
+    }
+}
+
+impl TryFrom<&str> for UserID {
+    type Error = IdentifierError;
+
+    /// Will try creating a `UserID` from the provided `&str`.
+    ///
+    /// Fails with [`IdentifierError::IncorrectSigil`] if `s` does not start with `@`, or
+    /// with [`IdentifierError::MissingColon`] if `s` contains no `:`. The sigil is checked
+    /// first.
+    fn try_from(s: &str) -> Result<Self, Self::Error> {
+        if !s.starts_with('@') {
+            return Err(IdentifierError::IncorrectSigil);
+        }
+
+        if !s.contains(':') {
+            return Err(IdentifierError::MissingColon);
+        }
+
+        Ok(UserID(s.to_string()))
+    }
+}
+
+impl Deref for UserID {
+    type Target = str;
+
+    /// Gives read-only access to the full user_id string, including the `@` sigil.
+    fn deref(&self) -> &Self::Target {
+        &self.0
+    }
+}
+
+impl fmt::Display for UserID {
+    /// Writes the full user_id string, exactly as it was parsed.
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
diff --git a/rust/src/lib.rs b/rust/src/lib.rs
index 06477880b9..5de9238326 100644
--- a/rust/src/lib.rs
+++ b/rust/src/lib.rs
@@ -6,6 +6,8 @@ pub mod acl;
 pub mod errors;
 pub mod events;
 pub mod http;
+pub mod identifier;
+pub mod matrix_const;
 pub mod push;
 pub mod rendezvous;
 
diff --git a/rust/src/matrix_const.rs b/rust/src/matrix_const.rs
new file mode 100644
index 0000000000..f75f3bd7c3
--- /dev/null
+++ b/rust/src/matrix_const.rs
@@ -0,0 +1,30 @@
+/*
+ * This file is licensed under the Affero General Public License (AGPL) version 3.
+ *
+ * Copyright (C) 2024 New Vector, Ltd
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as
+ * published by the Free Software Foundation, either version 3 of the
+ * License, or (at your option) any later version.
+ *
+ * See the GNU Affero General Public License for more details:
+ * <https://www.gnu.org/licenses/agpl-3.0.html>.
+ */
+
+//! # Matrix Constants
+//!
+//! This module contains definitions for constant values described by the matrix specification.
+
+// Values of a room's history visibility setting.
+pub const HISTORY_VISIBILITY_WORLD_READABLE: &str = "world_readable";
+pub const HISTORY_VISIBILITY_SHARED: &str = "shared";
+pub const HISTORY_VISIBILITY_INVITED: &str = "invited";
+pub const HISTORY_VISIBILITY_JOINED: &str = "joined";
+
+// Values of a user's room membership state.
+pub const MEMBERSHIP_BAN: &str = "ban";
+pub const MEMBERSHIP_LEAVE: &str = "leave";
+pub const MEMBERSHIP_KNOCK: &str = "knock";
+pub const MEMBERSHIP_INVITE: &str = "invite";
+pub const MEMBERSHIP_JOIN: &str = "join";
diff --git a/rust/src/push/utils.rs b/rust/src/push/utils.rs
index 28ebed62c8..59536c9954 100644
--- a/rust/src/push/utils.rs
+++ b/rust/src/push/utils.rs
@@ -23,7 +23,6 @@ use anyhow::bail;
 use anyhow::Context;
 use anyhow::Error;
 use lazy_static::lazy_static;
-use regex;
 use regex::Regex;
 use regex::RegexBuilder;
 
diff --git a/synapse/synapse_rust/events.pyi b/synapse/synapse_rust/events.pyi
index 1682d0d151..7d3422572d 100644
--- a/synapse/synapse_rust/events.pyi
+++ b/synapse/synapse_rust/events.pyi
@@ -10,7 +10,7 @@
 # See the GNU Affero General Public License for more details:
 # <https://www.gnu.org/licenses/agpl-3.0.html>.
 
-from typing import Optional
+from typing import List, Mapping, Optional, Tuple
 
 from synapse.types import JsonDict
 
@@ -105,3 +105,33 @@ class EventInternalMetadata:
 
     def is_notifiable(self) -> bool:
         """Whether this event can trigger a push notification"""
+
+def event_visible_to_server(
+    sender: str,
+    target_server_name: str,
+    history_visibility: str,
+    erased_senders: Mapping[str, bool],
+    partial_state_invisible: bool,
+    memberships: List[Tuple[str, str]],
+) -> bool:
+    """Determine whether the server is allowed to see the unredacted event.
+
+    Args:
+        sender: The sender of the event.
+        target_server_name: The server we want to send the event to.
+        history_visibility: The history_visibility value at the event.
+        erased_senders: A mapping of users and whether they have requested erasure. If a
+            user is not in the map, it is treated as though they haven't requested erasure.
+        partial_state_invisible: Whether the event should be treated as invisible due to
+            the partial state status of the room.
+        memberships: A list of membership state information at the event for users
+            matching the `target_server_name`. Each list item must contain a tuple of
+            (state_key, membership).
+
+    Returns:
+        Whether the server is allowed to see the unredacted event.
+
+    Raises:
+        ValueError: If `memberships` has to be checked and a `state_key` in it is not
+            a valid user ID or belongs to a server other than `target_server_name`.
+    """
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 3a2782bade..dc7b6e4065 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -27,7 +27,6 @@ from typing import (
     Final,
     FrozenSet,
     List,
-    Mapping,
     Optional,
     Sequence,
     Set,
@@ -48,6 +47,7 @@ from synapse.events.utils import clone_event, prune_event
 from synapse.logging.opentracing import trace
 from synapse.storage.controllers import StorageControllers
 from synapse.storage.databases.main import DataStore
+from synapse.synapse_rust.events import event_visible_to_server
 from synapse.types import RetentionPolicy, StateMap, StrCollection, get_domain_from_id
 from synapse.types.state import StateFilter
 from synapse.util import Clock
@@ -628,17 +628,6 @@ async def filter_events_for_server(
     """Filter a list of events based on whether the target server is allowed to
     see them.
 
-    For a fully stated room, the target server is allowed to see an event E if:
-      - the state at E has world readable or shared history vis, OR
-      - the state at E says that the target server is in the room.
-
-    For a partially stated room, the target server is allowed to see E if:
-      - E was created by this homeserver, AND:
-          - the partial state at E has world readable or shared history vis, OR
-          - the partial state at E says that the target server is in the room.
-
-    TODO: state before or state after?
-
     Args:
         storage
         target_server_name
@@ -655,35 +644,6 @@ async def filter_events_for_server(
         The filtered events.
     """
 
-    def is_sender_erased(event: EventBase, erased_senders: Mapping[str, bool]) -> bool:
-        if erased_senders and erased_senders[event.sender]:
-            logger.info("Sender of %s has been erased, redacting", event.event_id)
-            return True
-        return False
-
-    def check_event_is_visible(
-        visibility: str, memberships: StateMap[EventBase]
-    ) -> bool:
-        if visibility not in (HistoryVisibility.INVITED, HistoryVisibility.JOINED):
-            return True
-
-        # We now loop through all membership events looking for
-        # membership states for the requesting server to determine
-        # if the server is either in the room or has been invited
-        # into the room.
-        for ev in memberships.values():
-            assert get_domain_from_id(ev.state_key) == target_server_name
-
-            memtype = ev.membership
-            if memtype == Membership.JOIN:
-                return True
-            elif memtype == Membership.INVITE:
-                if visibility == HistoryVisibility.INVITED:
-                    return True
-
-        # server has no users in the room: redact
-        return False
-
     if filter_out_erased_senders:
         erased_senders = await storage.main.are_users_erased(e.sender for e in events)
     else:
@@ -726,20 +686,16 @@ async def filter_events_for_server(
         target_server_name,
     )
 
-    def include_event_in_output(e: EventBase) -> bool:
-        erased = is_sender_erased(e, erased_senders)
-        visible = check_event_is_visible(
-            event_to_history_vis[e.event_id], event_to_memberships.get(e.event_id, {})
-        )
-
-        if e.event_id in partial_state_invisible_event_ids:
-            visible = False
-
-        return visible and not erased
-
     to_return = []
     for e in events:
-        if include_event_in_output(e):
+        if event_visible_to_server(
+            sender=e.sender,
+            target_server_name=target_server_name,
+            history_visibility=event_to_history_vis[e.event_id],
+            erased_senders=erased_senders,
+            partial_state_invisible=e.event_id in partial_state_invisible_event_ids,
+            memberships=list(event_to_memberships.get(e.event_id, {}).values()),
+        ):
             to_return.append(e)
         elif redact:
             to_return.append(prune_event(e))
@@ -796,7 +752,7 @@ async def _event_to_history_vis(
 
 async def _event_to_memberships(
     storage: StorageControllers, events: Collection[EventBase], server_name: str
-) -> Dict[str, StateMap[EventBase]]:
+) -> Dict[str, StateMap[Tuple[str, str]]]:
     """Get the remote membership list at each of the given events
 
     Returns a map from event id to state map, which will contain only membership events
@@ -849,7 +805,7 @@ async def _event_to_memberships(
 
     return {
         e_id: {
-            key: event_map[inner_e_id]
+            key: (event_map[inner_e_id].state_key, event_map[inner_e_id].membership)
             for key, inner_e_id in key_to_eid.items()
             if inner_e_id in event_map
         }
-- 
GitLab