From a8df1acdfdef93fcae3e1fea5c0057cdb5fa1bb9 Mon Sep 17 00:00:00 2001
From: timokoesters <timo@koesters.xyz>
Date: Thu, 4 Jun 2020 13:58:55 +0200
Subject: [PATCH] feat: load replies, forward pagination

---
 Cargo.lock            |  40 ++++++--
 Cargo.toml            |   7 +-
 src/client_server.rs  | 207 ++++++++++++++++++++++++++++++++++--------
 src/database/rooms.rs |  23 +++++
 src/main.rs           |   1 +
 5 files changed, 229 insertions(+), 49 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 9f69ddd6d..365781df5 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -163,7 +163,7 @@ dependencies = [
  "ruma-api",
  "ruma-client-api",
  "ruma-common",
- "ruma-events",
+ "ruma-events 0.21.3 (git+https://github.com/ruma/ruma-events?rev=7395f94)",
  "ruma-federation-api",
  "ruma-identifiers",
  "ruma-signatures",
@@ -1297,13 +1297,13 @@ dependencies = [
 [[package]]
 name = "ruma-client-api"
 version = "0.9.0"
-source = "git+https://github.com/ruma/ruma-client-api.git?rev=c725288cd099690c1d13f1a9b9e57228bc860a62#c725288cd099690c1d13f1a9b9e57228bc860a62"
+source = "git+https://github.com/ruma/ruma-client-api.git?rev=c2c5a3cea01b0544e5adb40f7ddae828627afd2c#c2c5a3cea01b0544e5adb40f7ddae828627afd2c"
 dependencies = [
  "http",
  "js_int",
  "ruma-api",
  "ruma-common",
- "ruma-events",
+ "ruma-events 0.21.3 (git+https://github.com/ruma/ruma-events?rev=7395f94)",
  "ruma-identifiers",
  "ruma-serde",
  "serde",
@@ -1326,11 +1326,11 @@ dependencies = [
 [[package]]
 name = "ruma-events"
 version = "0.21.3"
-source = "git+https://github.com/ruma/ruma-events.git?rev=4d09416cd1663d63c22153705c9e1fd77910797f#4d09416cd1663d63c22153705c9e1fd77910797f"
+source = "git+https://github.com/ruma/ruma-events?rev=7395f94#7395f940a7cf70c1598223570fb2b731a6a41707"
 dependencies = [
  "js_int",
  "ruma-common",
- "ruma-events-macros",
+ "ruma-events-macros 0.21.3 (git+https://github.com/ruma/ruma-events?rev=7395f94)",
  "ruma-identifiers",
  "ruma-serde",
  "serde",
@@ -1338,10 +1338,36 @@ dependencies = [
  "strum",
 ]
 
+[[package]]
+name = "ruma-events"
+version = "0.21.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6ddf82c2231e4c53443424df34e868e4b09c20de7a76780d47a133a3b3f8ad9c"
+dependencies = [
+ "js_int",
+ "ruma-common",
+ "ruma-events-macros 0.21.3 (registry+https://github.com/rust-lang/crates.io-index)",
+ "ruma-identifiers",
+ "ruma-serde",
+ "serde",
+ "serde_json",
+]
+
 [[package]]
 name = "ruma-events-macros"
 version = "0.21.3"
-source = "git+https://github.com/ruma/ruma-events.git?rev=4d09416cd1663d63c22153705c9e1fd77910797f#4d09416cd1663d63c22153705c9e1fd77910797f"
+source = "git+https://github.com/ruma/ruma-events?rev=7395f94#7395f940a7cf70c1598223570fb2b731a6a41707"
+dependencies = [
+ "proc-macro2 1.0.18",
+ "quote 1.0.6",
+ "syn 1.0.30",
+]
+
+[[package]]
+name = "ruma-events-macros"
+version = "0.21.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "88e5c5b242fe4ee0cc56879057353621196d0988dd359579cad8f43471e483b7"
 dependencies = [
  "proc-macro2 1.0.18",
  "quote 1.0.6",
@@ -1356,7 +1382,7 @@ dependencies = [
  "js_int",
  "matches",
  "ruma-api",
- "ruma-events",
+ "ruma-events 0.21.3 (registry+https://github.com/rust-lang/crates.io-index)",
  "ruma-identifiers",
  "ruma-serde",
  "serde",
diff --git a/Cargo.toml b/Cargo.toml
index 53aacc792..7f7ba5bf8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -14,10 +14,10 @@ edition = "2018"
 [dependencies]
 rocket = { git = "https://github.com/SergioBenitez/Rocket.git", branch = "async", features = ["tls"] }
 http = "0.2.1"
-ruma-client-api = { git = "https://github.com/ruma/ruma-client-api.git", rev = "c725288cd099690c1d13f1a9b9e57228bc860a62" }
+ruma-client-api = { git = "https://github.com/ruma/ruma-client-api.git", rev = "c2c5a3cea01b0544e5adb40f7ddae828627afd2c" }
 ruma-identifiers = { version = "0.16.2", features = ["rand"] }
 ruma-api = "0.16.1"
-ruma-events = { git = "https://github.com/ruma/ruma-events.git", rev = "4d09416cd1663d63c22153705c9e1fd77910797f" }
+ruma-events = { git = "https://github.com/ruma/ruma-events.git", rev = "7395f94" }
 ruma-signatures = { git = "https://github.com/ruma/ruma-signatures.git", rev = "1ca545cba8dfd43e0fc8e3c18e1311fb73390a97" }
 ruma-federation-api = { git = "https://github.com/ruma/ruma-federation-api.git", rev = "4cf4aa6ef74b25ad8c14d99d7774129f023df163" }
 log = "0.4.8"
@@ -34,6 +34,3 @@ base64 = "0.12.1"
 thiserror = "1.0.19"
 ruma-common = "0.1.2"
 image = { version = "0.23.4", default-features = false, features = ["jpeg", "png", "gif"] }
-
-[patch.crates-io]
-ruma-events = { git = "https://github.com/ruma/ruma-events.git", rev = "4d09416cd1663d63c22153705c9e1fd77910797f" }
diff --git a/src/client_server.rs b/src/client_server.rs
index 6d99a829a..2ef6d0204 100644
--- a/src/client_server.rs
+++ b/src/client_server.rs
@@ -14,6 +14,7 @@
         alias::{create_alias, delete_alias, get_alias},
         capabilities::get_capabilities,
         config::{get_global_account_data, set_global_account_data},
+        context::get_context,
         device::{self, delete_device, delete_devices, get_device, get_devices, update_device},
         directory::{
             self, get_public_rooms, get_public_rooms_filtered, get_room_visibility,
@@ -200,7 +201,7 @@ pub fn register_route(
                 content: ruma_events::push_rules::PushRulesEventContent {
                     global: ruma_events::push_rules::Ruleset {
                         content: vec![],
-                        override_rules: vec![ruma_events::push_rules::ConditionalPushRule {
+                        override_: vec![ruma_events::push_rules::ConditionalPushRule {
                             actions: vec![ruma_events::push_rules::Action::DontNotify],
                             default: true,
                             enabled: false,
@@ -219,12 +220,10 @@ pub fn register_route(
                             default: true,
                             enabled: true,
                             rule_id: ".m.rule.message".to_owned(),
-                            conditions: vec![ruma_events::push_rules::PushCondition::EventMatch(
-                                ruma_events::push_rules::EventMatchCondition {
-                                    key: "type".to_owned(),
-                                    pattern: "m.room.message".to_owned(),
-                                },
-                            )],
+                            conditions: vec![ruma_events::push_rules::PushCondition::EventMatch {
+                                key: "type".to_owned(),
+                                pattern: "m.room.message".to_owned(),
+                            }],
                         }],
                     },
                 },
@@ -348,11 +347,11 @@ pub fn logout_route(
 pub fn get_capabilities_route() -> MatrixResult<get_capabilities::Response> {
     let mut available = BTreeMap::new();
     available.insert(
-        "5".to_owned(),
+        RoomVersionId::version_5(),
         get_capabilities::RoomVersionStability::Stable,
     );
     available.insert(
-        "6".to_owned(),
+        RoomVersionId::version_6(),
         get_capabilities::RoomVersionStability::Stable,
     );
 
@@ -374,7 +373,6 @@ pub fn get_pushrules_all_route(
     body: Ruma<get_pushrules_all::Request>,
 ) -> MatrixResult<get_pushrules_all::Response> {
     let user_id = body.user_id.as_ref().expect("user is authenticated");
-    warn!("TODO: get_pushrules_all_route");
 
     if let Some(EduEvent::PushRules(pushrules)) = db
         .account_data
@@ -383,7 +381,7 @@ pub fn get_pushrules_all_route(
         .map(|edu| edu.deserialize().expect("PushRules event in db is valid"))
     {
         MatrixResult(Ok(get_pushrules_all::Response {
-            global: BTreeMap::new(),
+            global: pushrules.content.global
         }))
     } else {
         MatrixResult(Err(Error {
@@ -2092,9 +2090,15 @@ pub fn sync_route(
             .unwrap()
             > since
         {
-            edus.push(serde_json::from_str(&serde_json::to_string(
-                &EduEvent::Typing(db.rooms.edus.roomactives_all(&room_id).unwrap()),
-            ).unwrap()).unwrap());
+            edus.push(
+                serde_json::from_str(
+                    &serde_json::to_string(&EduEvent::Typing(
+                        db.rooms.edus.roomactives_all(&room_id).unwrap(),
+                    ))
+                    .unwrap(),
+                )
+                .unwrap(),
+            );
         }
 
         joined_rooms.insert(
@@ -2170,9 +2174,15 @@ pub fn sync_route(
             .unwrap()
             > since
         {
-            edus.push(serde_json::from_str(&serde_json::to_string(
-                &EduEvent::Typing(db.rooms.edus.roomactives_all(&room_id).unwrap()),
-            ).unwrap()).unwrap());
+            edus.push(
+                serde_json::from_str(
+                    &serde_json::to_string(&EduEvent::Typing(
+                        db.rooms.edus.roomactives_all(&room_id).unwrap(),
+                    ))
+                    .unwrap(),
+                )
+                .unwrap(),
+            );
         }
 
         left_rooms.insert(
@@ -2271,12 +2281,16 @@ pub fn sync_route(
     }))
 }
 
-#[get("/_matrix/client/r0/rooms/<_room_id>/messages", data = "<body>")]
-pub fn get_message_events_route(
+#[get(
+    "/_matrix/client/r0/rooms/<_room_id>/context/<_event_id>",
+    data = "<body>"
+)]
+pub fn get_context_route(
     db: State<'_, Database>,
-    body: Ruma<get_message_events::Request>,
+    body: Ruma<get_context::Request>,
     _room_id: String,
-) -> MatrixResult<get_message_events::Response> {
+    _event_id: String,
+) -> MatrixResult<get_context::Response> {
     let user_id = body.user_id.as_ref().expect("user is authenticated");
 
     if !db.rooms.is_joined(user_id, &body.room_id).unwrap() {
@@ -2287,41 +2301,160 @@ pub fn get_message_events_route(
         }));
     }
 
-    if let get_message_events::Direction::Forward = body.dir {
-        todo!();
-    }
+    if let Some(base_event) = db.rooms.get_pdu(&body.event_id).unwrap() {
+        let base_event = base_event
+            .to_room_event();
+
+        let base_token = db
+            .rooms
+            .get_pdu_count(&body.event_id)
+            .unwrap()
+            .expect("event exists, so count should exist too");
 
-    if let Ok(from) = body.from.clone().parse() {
-        let pdus = db
+        let events_before = db
             .rooms
-            .pdus_until(&body.room_id, from)
-            .take(body.limit.map(|l| l.try_into().unwrap()).unwrap_or(10_u32) as usize)
+            .pdus_until(&body.room_id, base_token)
+            .take(u32::try_from(body.limit).unwrap() as usize / 2)
             .map(|r| r.unwrap())
             .collect::<Vec<_>>();
-        let prev_batch = pdus
+
+        let start_token = events_before
             .last()
             .and_then(|e| db.rooms.get_pdu_count(&e.event_id).unwrap())
             .map(|c| c.to_string());
-        let room_events = pdus
+
+        let events_before = events_before
             .into_iter()
             .map(|pdu| pdu.to_room_event())
             .collect::<Vec<_>>();
 
-        MatrixResult(Ok(get_message_events::Response {
-            start: Some(body.from.clone()),
-            end: prev_batch,
-            chunk: room_events,
-            state: Vec::new(),
-        }))
+        let events_after = db
+            .rooms
+            .pdus_after(&body.room_id, base_token)
+            .take(u32::try_from(body.limit).unwrap() as usize / 2)
+            .map(|r| r.unwrap())
+            .collect::<Vec<_>>();
+
+        let end_token = events_after
+            .last()
+            .and_then(|e| db.rooms.get_pdu_count(&e.event_id).unwrap())
+            .map(|c| c.to_string());
+
+        let events_after = events_after
+            .into_iter()
+            .map(|pdu| pdu.to_room_event())
+            .collect::<Vec<_>>();
+
+        MatrixResult(Ok(get_context::Response {
+            start: start_token,
+            end: end_token,
+            events_before,
+            event: Some(base_event),
+            events_after,
+            state: db // TODO: State at event
+                .rooms
+                .room_state(&body.room_id)
+                .unwrap()
+                .values()
+                .map(|pdu| pdu.to_state_event())
+                .collect(),
+            }))
     } else {
         MatrixResult(Err(Error {
             kind: ErrorKind::Unknown,
-            message: "Invalid from.".to_owned(),
+            message: "Invalid base event.".to_owned(),
             status_code: http::StatusCode::BAD_REQUEST,
         }))
     }
 }
 
+#[get("/_matrix/client/r0/rooms/<_room_id>/messages", data = "<body>")]
+pub fn get_message_events_route(
+    db: State<'_, Database>,
+    body: Ruma<get_message_events::Request>,
+    _room_id: String,
+) -> MatrixResult<get_message_events::Response> {
+    let user_id = body.user_id.as_ref().expect("user is authenticated");
+
+    if !db.rooms.is_joined(user_id, &body.room_id).unwrap() {
+        return MatrixResult(Err(Error {
+            kind: ErrorKind::Forbidden,
+            message: "You don't have permission to view this room.".to_owned(),
+            status_code: http::StatusCode::BAD_REQUEST,
+        }));
+    }
+
+    match body.dir {
+        get_message_events::Direction::Forward => {
+            if let Ok(from) = body.from.clone().parse() {
+                let events_after = db
+                    .rooms
+                    .pdus_after(&body.room_id, from)
+                    .take(body.limit.map(|l| l.try_into().unwrap()).unwrap_or(10_u32) as usize)
+                    .map(|r| r.unwrap())
+                    .collect::<Vec<_>>();
+
+                let end_token = events_after
+                    .last()
+                    .and_then(|e| db.rooms.get_pdu_count(&e.event_id).unwrap())
+                    .map(|c| c.to_string());
+
+                let events_after = events_after
+                    .into_iter()
+                    .map(|pdu| pdu.to_room_event())
+                    .collect::<Vec<_>>();
+
+                MatrixResult(Ok(get_message_events::Response {
+                    start: Some(body.from.clone()),
+                    end: end_token,
+                    chunk: events_after,
+                    state: Vec::new(),
+                }))
+            } else {
+                MatrixResult(Err(Error {
+                    kind: ErrorKind::Unknown,
+                    message: "Invalid from.".to_owned(),
+                    status_code: http::StatusCode::BAD_REQUEST,
+                }))
+            }
+        }
+        get_message_events::Direction::Backward => {
+            if let Ok(from) = body.from.clone().parse() {
+                let events_before = db
+                    .rooms
+                    .pdus_until(&body.room_id, from)
+                    .take(body.limit.map(|l| l.try_into().unwrap()).unwrap_or(10_u32) as usize)
+                    .map(|r| r.unwrap())
+                    .collect::<Vec<_>>();
+
+                let start_token = events_before
+                    .last()
+                    .and_then(|e| db.rooms.get_pdu_count(&e.event_id).unwrap())
+                    .map(|c| c.to_string());
+
+                let events_before = events_before
+                    .into_iter()
+                    .map(|pdu| pdu.to_room_event())
+                    .collect::<Vec<_>>();
+
+                MatrixResult(Ok(get_message_events::Response {
+                    start: Some(body.from.clone()),
+                    end: start_token,
+                    chunk: events_before,
+                    state: Vec::new(),
+                }))
+            } else {
+                MatrixResult(Err(Error {
+                    kind: ErrorKind::Unknown,
+                    message: "Invalid from.".to_owned(),
+                    status_code: http::StatusCode::BAD_REQUEST,
+                }))
+            }
+        }
+    }
+
+}
+
 #[get("/_matrix/client/r0/voip/turnServer")]
 pub fn turn_server_route() -> MatrixResult<create_message_event::Response> {
     MatrixResult(Err(Error {
diff --git a/src/database/rooms.rs b/src/database/rooms.rs
index a9a93067b..44cd20270 100644
--- a/src/database/rooms.rs
+++ b/src/database/rooms.rs
@@ -553,6 +553,29 @@ pub fn pdus_until(
             .map(|(_, v)| Ok(serde_json::from_slice(&v)?))
     }
 
+    /// Returns an iterator over all events in a room that happened after the event with id
+    /// `from` in chronological order.
+    pub fn pdus_after(
+        &self,
+        room_id: &RoomId,
+        from: u64,
+    ) -> impl Iterator<Item = Result<PduEvent>> {
+        // Create the first part of the full pdu id
+        let mut prefix = room_id.to_string().as_bytes().to_vec();
+        prefix.push(0xff);
+
+        let mut current = prefix.clone();
+        current.extend_from_slice(&(from + 1).to_be_bytes()); // +1 so we don't send the base event
+
+        let current: &[u8] = &current;
+
+        self.pduid_pdu
+            .range(current..)
+            .filter_map(|r| r.ok())
+            .take_while(move |(k, _)| k.starts_with(&prefix))
+            .map(|(_, v)| Ok(serde_json::from_slice(&v)?))
+    }
+
     /// Replace a PDU with the redacted form.
     pub fn redact_pdu(&self, event_id: &EventId) -> Result<()> {
         if let Some(pdu_id) = self.get_pdu_id(event_id)? {
diff --git a/src/main.rs b/src/main.rs
index 27493d1e4..12a5195a4 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -70,6 +70,7 @@ fn setup_rocket() -> rocket::Rocket {
                 client_server::get_state_events_for_key_route,
                 client_server::get_state_events_for_empty_key_route,
                 client_server::sync_route,
+                client_server::get_context_route,
                 client_server::get_message_events_route,
                 client_server::turn_server_route,
                 client_server::publicised_groups_route,
-- 
GitLab