From 8a035880f097d885baed6e9ee179ccbe3db16881 Mon Sep 17 00:00:00 2001
From: Devin Ragotzy <devin.ragotzy@gmail.com>
Date: Wed, 6 Jan 2021 08:52:30 -0500
Subject: [PATCH] Remove StateStore trait from state-res collect events needed

---
 Cargo.lock                      |  70 ++++++++++----------
 Cargo.toml                      |   2 +-
 src/client_server/membership.rs |   4 --
 src/database/rooms.rs           | 100 +++++++++++++++++++----------
 src/server_server.rs            | 109 +++++++++++++++++++-------------
 5 files changed, 169 insertions(+), 116 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 7ef5efbef..f439e5189 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -216,9 +216,9 @@ dependencies = [
 
 [[package]]
 name = "const_fn"
-version = "0.4.4"
+version = "0.4.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cd51eab21ab4fd6a3bf889e2d0958c0a6e3a61ad04260325e919e652a2a62826"
+checksum = "28b9d6de7f49e22cf97ad17fc4036ece69300032f45f78f30b4a4482cdc3f4a6"
 
 [[package]]
 name = "constant_time_eq"
@@ -543,7 +543,7 @@ dependencies = [
  "futures-sink",
  "futures-task",
  "memchr",
- "pin-project 1.0.2",
+ "pin-project 1.0.3",
  "pin-utils",
  "proc-macro-hack",
  "proc-macro-nested",
@@ -567,18 +567,18 @@ checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce"
 dependencies = [
  "cfg-if 1.0.0",
  "libc",
- "wasi",
+ "wasi 0.9.0+wasi-snapshot-preview1",
 ]
 
 [[package]]
 name = "getrandom"
-version = "0.2.0"
+version = "0.2.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ee8025cf36f917e6a52cce185b7c7177689b838b7ec138364e50cc2277a56cf4"
+checksum = "4060f4657be78b8e766215b02b18a2e862d83745545de804638e2b545e81aee6"
 dependencies = [
- "cfg-if 0.1.10",
+ "cfg-if 1.0.0",
  "libc",
- "wasi",
+ "wasi 0.10.0+wasi-snapshot-preview1",
 ]
 
 [[package]]
@@ -707,7 +707,7 @@ dependencies = [
  "httparse",
  "httpdate",
  "itoa",
- "pin-project 1.0.2",
+ "pin-project 1.0.3",
  "socket2",
  "tokio",
  "tower-service",
@@ -1221,11 +1221,11 @@ dependencies = [
 
 [[package]]
 name = "pin-project"
-version = "1.0.2"
+version = "1.0.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9ccc2237c2c489783abd8c4c80e5450fc0e98644555b1364da68cc29aa151ca7"
+checksum = "5a83804639aad6ba65345661744708855f9fbcb71176ea8d28d05aeb11d975e7"
 dependencies = [
- "pin-project-internal 1.0.2",
+ "pin-project-internal 1.0.3",
 ]
 
 [[package]]
@@ -1241,9 +1241,9 @@ dependencies = [
 
 [[package]]
 name = "pin-project-internal"
-version = "1.0.2"
+version = "1.0.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f8e8d2bf0b23038a4424865103a4df472855692821aab4e4f5c3312d461d9e5f"
+checksum = "b7bcc46b8f73443d15bc1c5fecbb315718491fa9187fa483f0e359323cde8b3a"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -1258,9 +1258,9 @@ checksum = "c917123afa01924fc84bb20c4c03f004d9c38e5127e3c039bbf7f4b9c76a2f6b"
 
 [[package]]
 name = "pin-project-lite"
-version = "0.2.0"
+version = "0.2.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6b063f57ec186e6140e2b8b6921e5f1bd89c7356dda5b33acc5401203ca6131c"
+checksum = "e36743d754ccdf9954c2e352ce2d4b106e024c814f6499c2dadff80da9a442d8"
 
 [[package]]
 name = "pin-utils"
@@ -1365,13 +1365,13 @@ dependencies = [
 
 [[package]]
 name = "rand"
-version = "0.8.0"
+version = "0.8.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a76330fb486679b4ace3670f117bbc9e16204005c4bde9c4bd372f45bed34f12"
+checksum = "c24fcd450d3fa2b592732565aa4f17a27a61c65ece4726353e000939b0edee34"
 dependencies = [
  "libc",
  "rand_chacha 0.3.0",
- "rand_core 0.6.0",
+ "rand_core 0.6.1",
  "rand_hc 0.3.0",
 ]
 
@@ -1392,7 +1392,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e12735cf05c9e10bf21534da50a147b924d555dc7a547c42e6bb2d5b6017ae0d"
 dependencies = [
  "ppv-lite86",
- "rand_core 0.6.0",
+ "rand_core 0.6.1",
 ]
 
 [[package]]
@@ -1406,11 +1406,11 @@ dependencies = [
 
 [[package]]
 name = "rand_core"
-version = "0.6.0"
+version = "0.6.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a8b34ba8cfb21243bd8df91854c830ff0d785fff2e82ebd4434c2644cb9ada18"
+checksum = "c026d7df8b298d90ccbbc5190bd04d85e159eaf5576caeacf8741da93ccbd2e5"
 dependencies = [
- "getrandom 0.2.0",
+ "getrandom 0.2.1",
 ]
 
 [[package]]
@@ -1428,7 +1428,7 @@ version = "0.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "3190ef7066a446f2e7f42e239d161e905420ccab01eb967c9eb27d21b2322a73"
 dependencies = [
- "rand_core 0.6.0",
+ "rand_core 0.6.1",
 ]
 
 [[package]]
@@ -1518,7 +1518,7 @@ dependencies = [
  "mime_guess",
  "native-tls",
  "percent-encoding",
- "pin-project-lite 0.2.0",
+ "pin-project-lite 0.2.1",
  "serde",
  "serde_urlencoded",
  "tokio",
@@ -1758,7 +1758,7 @@ version = "0.17.4"
 source = "git+https://github.com/ruma/ruma?rev=210b6dd823ba89c5a44c3c9d913d377c4b54c896#210b6dd823ba89c5a44c3c9d913d377c4b54c896"
 dependencies = [
  "paste",
- "rand 0.8.0",
+ "rand 0.8.1",
  "ruma-identifiers-macros",
  "ruma-identifiers-validation",
  "ruma-serde",
@@ -1977,9 +1977,9 @@ dependencies = [
 
 [[package]]
 name = "serde_yaml"
-version = "0.8.14"
+version = "0.8.15"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f7baae0a99f1a324984bcdc5f0718384c1f69775f1c7eec8b859b71b443e3fd7"
+checksum = "971be8f6e4d4a47163b405a3df70d14359186f9ab0f3a3ec37df144ca1ce089f"
 dependencies = [
  "dtoa",
  "linked-hash-map",
@@ -2065,7 +2065,7 @@ checksum = "3015a7d0a5fd5105c91c3710d42f9ccf0abfb287d62206484dcc67f9569a6483"
 [[package]]
 name = "state-res"
 version = "0.1.0"
-source = "git+https://github.com/ruma/state-res?branch=event-trait#e5d32e44adb66c5932a81d2c8a8d840abd17c870"
+source = "git+https://github.com/ruma/state-res?branch=no-db#d31c88408e7f69f5b0f18141efeaefff6b83637f"
 dependencies = [
  "itertools",
  "maplit",
@@ -2127,9 +2127,9 @@ checksum = "213701ba3370744dcd1a12960caa4843b3d68b4d1c0a5d575e0d65b2ee9d16c0"
 
 [[package]]
 name = "syn"
-version = "1.0.57"
+version = "1.0.58"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4211ce9909eb971f111059df92c45640aad50a619cf55cd76476be803c4c68e6"
+checksum = "cc60a3d73ea6594cd712d830cc1f0390fd71542d8c8cd24e70cc54cdfd5e05d5"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -2325,7 +2325,7 @@ checksum = "9f47026cdc4080c07e49b37087de021820269d996f581aac150ef9e5583eefe3"
 dependencies = [
  "cfg-if 1.0.0",
  "log",
- "pin-project-lite 0.2.0",
+ "pin-project-lite 0.2.1",
  "tracing-attributes",
  "tracing-core",
 ]
@@ -2509,6 +2509,12 @@ version = "0.9.0+wasi-snapshot-preview1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519"
 
+[[package]]
+name = "wasi"
+version = "0.10.0+wasi-snapshot-preview1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1a143597ca7c7793eff794def352d41792a93c481eb1042423ff7ff72ba2c31f"
+
 [[package]]
 name = "wasm-bindgen"
 version = "0.2.69"
diff --git a/Cargo.toml b/Cargo.toml
index 44df2543b..004cbfdf9 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -25,7 +25,7 @@ ruma = { git = "https://github.com/ruma/ruma", features = ["rand", "client-api",
 # Used when doing state resolution
 # state-res = { git = "https://github.com/timokoesters/state-res", branch = "timo-spec-comp", features = ["unstable-pre-spec"] }
 # TODO: remove the gen-eventid feature
-state-res = { git = "https://github.com/ruma/state-res", branch = "event-trait", features = ["unstable-pre-spec", "gen-eventid"] }
+state-res = { git = "https://github.com/ruma/state-res", branch = "no-db", features = ["unstable-pre-spec", "gen-eventid"] }
 # state-res = { path = "../../state-res", features = ["unstable-pre-spec", "gen-eventid"] }
 
 # Used for long polling and federation sender, should be the same as rocket::tokio
diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs
index 4e093c26f..ea14268dc 100644
--- a/src/client_server/membership.rs
+++ b/src/client_server/membership.rs
@@ -618,7 +618,6 @@ async fn join_room_by_id_helper(
             &room_id,
             &control_events,
             &mut event_map,
-            &db.rooms,
             &event_ids,
         );
 
@@ -629,7 +628,6 @@ async fn join_room_by_id_helper(
             &sorted_control_events,
             &BTreeMap::new(), // We have no "clean/resolved" events to add (these extend the `resolved_control_events`)
             &mut event_map,
-            &db.rooms,
         )
         .expect("iterative auth check failed on resolved events");
 
@@ -654,7 +652,6 @@ async fn join_room_by_id_helper(
             &events_to_sort,
             power_level,
             &mut event_map,
-            &db.rooms,
         );
 
         let resolved_events = state_res::StateResolution::iterative_auth_check(
@@ -663,7 +660,6 @@ async fn join_room_by_id_helper(
             &sorted_event_ids,
             &resolved_control_events,
             &mut event_map,
-            &db.rooms,
         )
         .expect("iterative auth check failed on resolved events");
 
diff --git a/src/database/rooms.rs b/src/database/rooms.rs
index 48e7c149f..b84d1f987 100644
--- a/src/database/rooms.rs
+++ b/src/database/rooms.rs
@@ -67,40 +67,6 @@ pub struct Rooms {
     pub(super) stateid_pduid: sled::Tree, // StateId = StateHash + Short, PduId = Count (without roomid)
 }
 
-impl StateStore<PduEvent> for Rooms {
-    fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> state_res::Result<Arc<PduEvent>> {
-        let pid = self
-            .get_pdu_id(event_id)
-            .map_err(StateError::custom)?
-            .ok_or_else(|| {
-                StateError::NotFound(format!(
-                    "PDU via room_id and event_id not found in the db: {}",
-                    event_id.as_str()
-                ))
-            })?;
-
-        serde_json::from_slice(
-            &self
-                .pduid_pdu
-                .get(pid)
-                .map_err(StateError::custom)?
-                .ok_or_else(|| StateError::NotFound("PDU via pduid not found in db.".into()))?,
-        )
-        .map_err(Into::into)
-        .and_then(|pdu: PduEvent| {
-            // conduit's PDU's always contain a room_id but some
-            // of ruma's do not so this must be an Option
-            if pdu.room_id() == room_id {
-                Ok(Arc::new(pdu))
-            } else {
-                Err(StateError::NotFound(
-                    "Found PDU for incorrect room in db.".into(),
-                ))
-            }
-        })
-    }
-}
-
 impl Rooms {
     /// Builds a StateMap by iterating over all keys that start
     /// with state_hash, this gives the full state for the given state_hash.
@@ -222,6 +188,72 @@ pub fn get_auth_events(
         Ok(events)
     }
 
+    /// Returns a Vec of the related auth events to the given `event`.
+    ///
+    /// A recursive list of all the auth_events going back to `RoomCreate` for each event in `event_ids`.
+    pub fn auth_events_full(
+        &self,
+        room_id: &RoomId,
+        event_ids: &[EventId],
+    ) -> Result<Vec<PduEvent>> {
+        let mut result = BTreeMap::new();
+        let mut stack = event_ids.to_vec();
+
+        // DFS for auth event chain
+        while !stack.is_empty() {
+            let ev_id = stack.pop().unwrap();
+            if result.contains_key(&ev_id) {
+                continue;
+            }
+
+            if let Some(ev) = self.get_pdu(&ev_id)? {
+                stack.extend(ev.auth_events());
+                result.insert(ev.event_id().clone(), ev);
+            }
+        }
+
+        Ok(result.into_iter().map(|(_, v)| v).collect())
+    }
+
+    /// Returns a Vec<EventId> representing the difference in auth chains of the given `events`.
+    ///
+    /// Each inner `Vec` of `event_ids` represents a state set (state at each forward extremity).
+    pub fn auth_chain_diff(
+        &self,
+        room_id: &RoomId,
+        event_ids: Vec<Vec<EventId>>,
+    ) -> Result<Vec<EventId>> {
+        use std::collections::BTreeSet;
+
+        let mut chains = vec![];
+        for ids in event_ids {
+            // TODO state store `auth_event_ids` returns self in the event ids list
+            // when an event returns `auth_event_ids` self is not contained
+            let chain = self
+                .auth_events_full(room_id, &ids)?
+                .into_iter()
+                .map(|pdu| pdu.event_id)
+                .collect::<BTreeSet<_>>();
+            chains.push(chain);
+        }
+
+        if let Some(chain) = chains.first() {
+            let rest = chains.iter().skip(1).flatten().cloned().collect();
+            let common = chain.intersection(&rest).collect::<Vec<_>>();
+
+            Ok(chains
+                .iter()
+                .flatten()
+                .filter(|id| !common.contains(&id))
+                .cloned()
+                .collect::<BTreeSet<_>>()
+                .into_iter()
+                .collect())
+        } else {
+            Ok(vec![])
+        }
+    }
+
     /// Generate a new StateHash.
     ///
     /// A unique hash made from hashing all PDU ids of the state joined with 0xff.
diff --git a/src/server_server.rs b/src/server_server.rs
index 3de36364a..f68475cb5 100644
--- a/src/server_server.rs
+++ b/src/server_server.rs
@@ -603,7 +603,7 @@ pub async fn send_transaction_message_route<'a>(
             };
 
         // 4. Passes authorization rules based on the event's auth events, otherwise it is rejected.
-        // TODO: To me this sounds more like the auth_events should be get the pdu.auth_events not
+        // TODO: To me this sounds more like the auth_events should be "get the pdu.auth_events" not
         // the auth events that would be correct for this pdu. Put another way we should use the auth events
         // the pdu claims are its auth events
         let auth_events = db.rooms.get_auth_events(
@@ -637,50 +637,56 @@ pub async fn send_transaction_message_route<'a>(
             );
             continue;
         }
+        // End of step 4.
 
-        let (state_at_event, incoming_auth_events): (StateMap<Arc<PduEvent>>, _) = match db
-            .sending
-            .send_federation_request(
-                &db.globals,
-                server_name.clone(),
-                get_room_state_ids::v1::Request {
-                    room_id: pdu.room_id(),
-                    event_id: pdu.event_id(),
-                },
-            )
-            .await
-        {
-            Ok(res) => {
-                let state =
-                    fetch_events(&db, server_name.clone(), &pub_key_map, &res.pdu_ids).await?;
-                // Sanity check: there are no conflicting events in the state we received
-                let mut seen = BTreeSet::new();
-                for ev in &state {
-                    // If the key is already present
-                    if !seen.insert((&ev.kind, &ev.state_key)) {
-                        todo!("Server sent us an invalid state")
+        // Step 5. event passes auth based on state at the event
+        let (state_at_event, incoming_auth_events): (StateMap<Arc<PduEvent>>, Vec<Arc<PduEvent>>) =
+            match db
+                .sending
+                .send_federation_request(
+                    &db.globals,
+                    server_name.clone(),
+                    get_room_state_ids::v1::Request {
+                        room_id: pdu.room_id(),
+                        event_id: pdu.event_id(),
+                    },
+                )
+                .await
+            {
+                Ok(res) => {
+                    let state =
+                        fetch_events(&db, server_name.clone(), &pub_key_map, &res.pdu_ids).await?;
+                    // Sanity check: there are no conflicting events in the state we received
+                    let mut seen = BTreeSet::new();
+                    for ev in &state {
+                        // If the key is already present
+                        if !seen.insert((&ev.kind, &ev.state_key)) {
+                            todo!("Server sent us an invalid state")
+                        }
                     }
-                }
-
-                let state = state
-                    .into_iter()
-                    .map(|pdu| ((pdu.kind.clone(), pdu.state_key.clone()), Arc::new(pdu)))
-                    .collect();
 
-                (
-                    state,
-                    fetch_events(&db, server_name.clone(), &pub_key_map, &res.auth_chain_ids)
-                        .await?,
-                )
-            }
-            Err(_) => {
-                resolved_map.insert(
-                    event.event_id().clone(),
-                    Err("Fetching state for event failed".into()),
-                );
-                continue;
-            }
-        };
+                    let state = state
+                        .into_iter()
+                        .map(|pdu| ((pdu.kind.clone(), pdu.state_key.clone()), Arc::new(pdu)))
+                        .collect();
+
+                    (
+                        state,
+                        fetch_events(&db, server_name.clone(), &pub_key_map, &res.auth_chain_ids)
+                            .await?
+                            .into_iter()
+                            .map(Arc::new)
+                            .collect(),
+                    )
+                }
+                Err(_) => {
+                    resolved_map.insert(
+                        event.event_id().clone(),
+                        Err("Fetching state for event failed".into()),
+                    );
+                    continue;
+                }
+            };
 
         if !state_res::event_auth::auth_check(
             &RoomVersionId::Version6,
@@ -698,6 +704,7 @@ pub async fn send_transaction_message_route<'a>(
             );
             continue;
         }
+        // End of step 5.
 
         // The event could still be soft failed
         append_state_soft(&db, &pdu)?;
@@ -724,18 +731,30 @@ pub async fn send_transaction_message_route<'a>(
             }
         }
 
-        // 6.
+        // Step 6. event passes auth based on state of all forks and current room state
         let state_at_forks = if fork_states.is_empty() {
             // State is empty
             Default::default()
         } else if fork_states.len() == 1 {
             fork_states[0].clone()
         } else {
+            let auth_events = fork_states
+                .iter()
+                .map(|map| {
+                    db.rooms.auth_events_full(
+                        pdu.room_id(),
+                        &map.values()
+                            .map(|pdu| pdu.event_id().clone())
+                            .collect::<Vec<_>>(),
+                    )
+                })
+                .collect();
+
             // Add as much as we can to the `event_map` (less DB hits)
             event_map.extend(
                 incoming_auth_events
                     .into_iter()
-                    .map(|pdu| (pdu.event_id().clone(), Arc::new(pdu))),
+                    .map(|pdu| (pdu.event_id().clone(), pdu)),
             );
             event_map.extend(
                 state_at_event
@@ -754,8 +773,8 @@ pub async fn send_transaction_message_route<'a>(
                             .collect::<StateMap<_>>()
                     })
                     .collect::<Vec<_>>(),
+                &auth_events,
                 &mut event_map,
-                &db.rooms,
             ) {
                 Ok(res) => res
                     .into_iter()
-- 
GitLab