From 7c6d25dcd165ffa3535ba103ab5ffb4acbe8c558 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Timo=20K=C3=B6sters?= <timo@koesters.xyz>
Date: Mon, 26 Jun 2023 21:57:59 +0200
Subject: [PATCH] Do state res even if the event soft fails

---
 src/service/rooms/event_handler/mod.rs | 310 ++++++++++---------------
 1 file changed, 129 insertions(+), 181 deletions(-)

diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs
index b01a28283..800d849d9 100644
--- a/src/service/rooms/event_handler/mod.rs
+++ b/src/service/rooms/event_handler/mod.rs
@@ -38,6 +38,8 @@
 
 use crate::{service::*, services, Error, PduEvent, Result};
 
+use super::state_compressor::CompressedStateEvent;
+
 pub struct Service;
 
 impl Service {
@@ -62,9 +64,8 @@ impl Service {
     /// 12. Ensure that the state is derived from the previous current state (i.e. we calculated by
     ///     doing state res where one of the inputs was a previously trusted set of state, don't just
     ///     trust a set of state we got from a remote)
-    /// 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail"
-    ///     it
-    /// 14. Use state resolution to find new room state
+    /// 13. Use state resolution to find new room state
+    /// 14. Check if the event passes auth based on the "current state" of the room, if not soft fail it
     // We use some AsyncRecursiveType hacks here so we can call this async funtion recursively
     #[tracing::instrument(skip(self, value, is_timeline_event, pub_key_map))]
     pub(crate) async fn handle_incoming_pdu<'a>(
@@ -304,7 +305,7 @@ fn handle_outlier_pdu<'a>(
             ) {
                 Err(e) => {
                     // Drop
-                    warn!("Dropping bad event {}: {}", event_id, e);
+                    warn!("Dropping bad event {}: {}", event_id, e,);
                     return Err(Error::BadRequest(
                         ErrorKind::InvalidParam,
                         "Signature verification failed",
@@ -735,8 +736,9 @@ pub async fn upgrade_outlier_to_timeline_pdu(
         }
         info!("Auth check succeeded");
 
-        // We start looking at current room state now, so lets lock the room
+        // 13. Use state resolution to find new room state
 
+        // We start looking at current room state now, so lets lock the room
         let mutex_state = Arc::clone(
             services()
                 .globals
@@ -782,7 +784,40 @@ pub async fn upgrade_outlier_to_timeline_pdu(
             })
             .collect::<Result<_>>()?;
 
-        // 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" it
+        if incoming_pdu.state_key.is_some() {
+            info!("Preparing for stateres to derive new room state");
+
+            // We also add state after incoming event to the fork states
+            let mut state_after = state_at_incoming_event.clone();
+            if let Some(state_key) = &incoming_pdu.state_key {
+                let shortstatekey = services().rooms.short.get_or_create_shortstatekey(
+                    &incoming_pdu.kind.to_string().into(),
+                    state_key,
+                )?;
+
+                state_after.insert(shortstatekey, Arc::from(&*incoming_pdu.event_id));
+            }
+
+            let new_room_state = self
+                .resolve_state(room_id, room_version_id, state_after)
+                .await?;
+
+            // Set the new room state to the resolved state
+            info!("Forcing new room state");
+
+            let (sstatehash, new, removed) = services()
+                .rooms
+                .state_compressor
+                .save_state(room_id, new_room_state)?;
+
+            services()
+                .rooms
+                .state
+                .force_state(room_id, sstatehash, new, removed, &state_lock)
+                .await?;
+        }
+
+        // 14. Check if the event passes auth based on the "current state" of the room, if not soft fail it
         info!("Starting soft fail auth check");
 
         let auth_events = services().rooms.state.get_auth_events(
@@ -823,181 +858,6 @@ pub async fn upgrade_outlier_to_timeline_pdu(
             ));
         }
 
-        if incoming_pdu.state_key.is_some() {
-            info!("Loading current room state ids");
-            let current_sstatehash = services()
-                .rooms
-                .state
-                .get_room_shortstatehash(room_id)?
-                .expect("every room has state");
-
-            let current_state_ids = services()
-                .rooms
-                .state_accessor
-                .state_full_ids(current_sstatehash)
-                .await?;
-
-            info!("Preparing for stateres to derive new room state");
-            let mut extremity_sstatehashes = HashMap::new();
-
-            info!(?extremities, "Loading extremities");
-            for id in &extremities {
-                match services().rooms.timeline.get_pdu(id)? {
-                    Some(leaf_pdu) => {
-                        extremity_sstatehashes.insert(
-                            services()
-                                .rooms
-                                .state_accessor
-                                .pdu_shortstatehash(&leaf_pdu.event_id)?
-                                .ok_or_else(|| {
-                                    error!(
-                                        "Found extremity pdu with no statehash in db: {:?}",
-                                        leaf_pdu
-                                    );
-                                    Error::bad_database("Found pdu with no statehash in db.")
-                                })?,
-                            leaf_pdu,
-                        );
-                    }
-                    _ => {
-                        error!("Missing state snapshot for {:?}", id);
-                        return Err(Error::BadDatabase("Missing state snapshot."));
-                    }
-                }
-            }
-
-            let mut fork_states = Vec::new();
-
-            // 12. Ensure that the state is derived from the previous current state (i.e. we calculated
-            //     by doing state res where one of the inputs was a previously trusted set of state,
-            //     don't just trust a set of state we got from a remote).
-
-            // We do this by adding the current state to the list of fork states
-            extremity_sstatehashes.remove(&current_sstatehash);
-            fork_states.push(current_state_ids);
-
-            // We also add state after incoming event to the fork states
-            let mut state_after = state_at_incoming_event.clone();
-            if let Some(state_key) = &incoming_pdu.state_key {
-                let shortstatekey = services().rooms.short.get_or_create_shortstatekey(
-                    &incoming_pdu.kind.to_string().into(),
-                    state_key,
-                )?;
-
-                state_after.insert(shortstatekey, Arc::from(&*incoming_pdu.event_id));
-            }
-            fork_states.push(state_after);
-
-            let mut update_state = false;
-            // 14. Use state resolution to find new room state
-            let new_room_state = if fork_states.is_empty() {
-                panic!("State is empty");
-            } else if fork_states.iter().skip(1).all(|f| &fork_states[0] == f) {
-                info!("State resolution trivial");
-                // There was only one state, so it has to be the room's current state (because that is
-                // always included)
-                fork_states[0]
-                    .iter()
-                    .map(|(k, id)| {
-                        services()
-                            .rooms
-                            .state_compressor
-                            .compress_state_event(*k, id)
-                    })
-                    .collect::<Result<_>>()?
-            } else {
-                info!("Loading auth chains");
-                // We do need to force an update to this room's state
-                update_state = true;
-
-                let mut auth_chain_sets = Vec::new();
-                for state in &fork_states {
-                    auth_chain_sets.push(
-                        services()
-                            .rooms
-                            .auth_chain
-                            .get_auth_chain(
-                                room_id,
-                                state.iter().map(|(_, id)| id.clone()).collect(),
-                            )
-                            .await?
-                            .collect(),
-                    );
-                }
-
-                info!("Loading fork states");
-
-                let fork_states: Vec<_> = fork_states
-                    .into_iter()
-                    .map(|map| {
-                        map.into_iter()
-                            .filter_map(|(k, id)| {
-                                services()
-                                    .rooms
-                                    .short
-                                    .get_statekey_from_short(k)
-                                    .map(|(ty, st_key)| ((ty.to_string().into(), st_key), id))
-                                    .ok()
-                            })
-                            .collect::<StateMap<_>>()
-                    })
-                    .collect();
-
-                info!("Resolving state");
-
-                let lock = services().globals.stateres_mutex.lock();
-                let state = match state_res::resolve(
-                    room_version_id,
-                    &fork_states,
-                    auth_chain_sets,
-                    |id| {
-                        let res = services().rooms.timeline.get_pdu(id);
-                        if let Err(e) = &res {
-                            error!("LOOK AT ME Failed to fetch event: {}", e);
-                        }
-                        res.ok().flatten()
-                    },
-                ) {
-                    Ok(new_state) => new_state,
-                    Err(_) => {
-                        return Err(Error::bad_database("State resolution failed, either an event could not be found or deserialization"));
-                    }
-                };
-
-                drop(lock);
-
-                info!("State resolution done. Compressing state");
-
-                state
-                    .into_iter()
-                    .map(|((event_type, state_key), event_id)| {
-                        let shortstatekey = services().rooms.short.get_or_create_shortstatekey(
-                            &event_type.to_string().into(),
-                            &state_key,
-                        )?;
-                        services()
-                            .rooms
-                            .state_compressor
-                            .compress_state_event(shortstatekey, &event_id)
-                    })
-                    .collect::<Result<_>>()?
-            };
-
-            // Set the new room state to the resolved state
-            if update_state {
-                info!("Forcing new room state");
-                let (sstatehash, new, removed) = services()
-                    .rooms
-                    .state_compressor
-                    .save_state(room_id, new_room_state)?;
-                services()
-                    .rooms
-                    .state
-                    .force_state(room_id, sstatehash, new, removed, &state_lock)
-                    .await?;
-            }
-        }
-
         info!("Appending pdu to timeline");
         extremities.insert(incoming_pdu.event_id.clone());
 
@@ -1021,6 +881,94 @@ pub async fn upgrade_outlier_to_timeline_pdu(
         Ok(pdu_id)
     }
 
+    async fn resolve_state(
+        &self,
+        room_id: &RoomId,
+        room_version_id: &RoomVersionId,
+        incoming_state: HashMap<u64, Arc<EventId>>,
+    ) -> Result<HashSet<CompressedStateEvent>> {
+        info!("Loading current room state ids");
+        let current_sstatehash = services()
+            .rooms
+            .state
+            .get_room_shortstatehash(room_id)?
+            .expect("every room has state");
+
+        let current_state_ids = services()
+            .rooms
+            .state_accessor
+            .state_full_ids(current_sstatehash)
+            .await?;
+
+        let fork_states = [current_state_ids, incoming_state];
+
+        let mut auth_chain_sets = Vec::new();
+        for state in &fork_states {
+            auth_chain_sets.push(
+                services()
+                    .rooms
+                    .auth_chain
+                    .get_auth_chain(room_id, state.iter().map(|(_, id)| id.clone()).collect())
+                    .await?
+                    .collect(),
+            );
+        }
+
+        info!("Loading fork states");
+
+        let fork_states: Vec<_> = fork_states
+            .into_iter()
+            .map(|map| {
+                map.into_iter()
+                    .filter_map(|(k, id)| {
+                        services()
+                            .rooms
+                            .short
+                            .get_statekey_from_short(k)
+                            .map(|(ty, st_key)| ((ty.to_string().into(), st_key), id))
+                            .ok()
+                    })
+                    .collect::<StateMap<_>>()
+            })
+            .collect();
+
+        info!("Resolving state");
+
+        let lock = services().globals.stateres_mutex.lock();
+        let state = match state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| {
+            let res = services().rooms.timeline.get_pdu(id);
+            if let Err(e) = &res {
+                error!("LOOK AT ME Failed to fetch event: {}", e);
+            }
+            res.ok().flatten()
+        }) {
+            Ok(new_state) => new_state,
+            Err(_) => {
+                return Err(Error::bad_database("State resolution failed, either an event could not be found or deserialization"));
+            }
+        };
+
+        drop(lock);
+
+        info!("State resolution done. Compressing state");
+
+        let new_room_state = state
+            .into_iter()
+            .map(|((event_type, state_key), event_id)| {
+                let shortstatekey = services()
+                    .rooms
+                    .short
+                    .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?;
+                services()
+                    .rooms
+                    .state_compressor
+                    .compress_state_event(shortstatekey, &event_id)
+            })
+            .collect::<Result<_>>()?;
+
+        Ok(new_room_state)
+    }
+
     /// Find the event and auth it. Once the event is validated (steps 1 - 8)
     /// it is appended to the outliers Tree.
     ///
-- 
GitLab