From dba6c466674d2e6dd462178a249fe45bde98d5a2 Mon Sep 17 00:00:00 2001
From: timokoesters <timo@koesters.xyz>
Date: Mon, 30 Mar 2020 13:46:18 +0200
Subject: [PATCH] Use sled::Tree::prefix_search for deviceids

---
 rustfmt.toml    |   1 +
 src/data.rs     | 133 +++++++++++++++++++++---------------------------
 src/database.rs | 117 ++++++++++++++++++++++++++++++++++++++++++
 src/main.rs     |  26 ++++++----
 src/utils.rs    |  30 ++++++-----
 5 files changed, 206 insertions(+), 101 deletions(-)
 create mode 100644 rustfmt.toml
 create mode 100644 src/database.rs

diff --git a/rustfmt.toml b/rustfmt.toml
new file mode 100644
index 000000000..7d2cf549d
--- /dev/null
+++ b/rustfmt.toml
@@ -0,0 +1 @@
+merge_imports = true
diff --git a/src/data.rs b/src/data.rs
index 0fa24d458..b7b9845b8 100644
--- a/src/data.rs
+++ b/src/data.rs
@@ -1,134 +1,115 @@
-use crate::utils;
-use directories::ProjectDirs;
-use log::debug;
+use crate::{utils, Database};
 use ruma_events::collections::all::Event;
 use ruma_identifiers::{EventId, RoomId, UserId};
 use std::convert::TryInto;
 
-const USERID_PASSWORD: &str = "userid_password";
-const USERID_DEVICEIDS: &str = "userid_deviceids";
-const DEVICEID_TOKEN: &str = "deviceid_token";
-const TOKEN_USERID: &str = "token_userid";
-
-pub struct Data(sled::Db);
+pub struct Data {
+    hostname: String,
+    db: Database,
+}
 
 impl Data {
     /// Load an existing database or create a new one.
-    pub fn load_or_create() -> Self {
-        Data(
-            sled::open(
-                ProjectDirs::from("xyz", "koesters", "matrixserver")
-                    .unwrap()
-                    .data_dir(),
-            )
-            .unwrap(),
-        )
-    }
-
-    /// Set the hostname of the server. Warning: Hostname changes will likely break things.
-    pub fn set_hostname(&self, hostname: &str) {
-        self.0.insert("hostname", hostname).unwrap();
+    pub fn load_or_create(hostname: &str) -> Self {
+        Self {
+            hostname: hostname.to_owned(),
+            db: Database::load_or_create(hostname),
+        }
     }
 
     /// Get the hostname of the server.
-    pub fn hostname(&self) -> String {
-        utils::bytes_to_string(&self.0.get("hostname").unwrap().unwrap())
+    pub fn hostname(&self) -> &str {
+        &self.hostname
     }
 
     /// Check if a user has an account by looking for an assigned password.
     pub fn user_exists(&self, user_id: &UserId) -> bool {
-        self.0
-            .open_tree(USERID_PASSWORD)
-            .unwrap()
+        self.db
+            .userid_password
             .contains_key(user_id.to_string())
             .unwrap()
     }
 
     /// Create a new user account by assigning them a password.
     pub fn user_add(&self, user_id: &UserId, password: Option<String>) {
-        self.0
-            .open_tree(USERID_PASSWORD)
-            .unwrap()
+        self.db
+            .userid_password
             .insert(user_id.to_string(), &*password.unwrap_or_default())
             .unwrap();
     }
 
     /// Find out which user an access token belongs to.
     pub fn user_from_token(&self, token: &str) -> Option<UserId> {
-        self.0
-            .open_tree(TOKEN_USERID)
-            .unwrap()
+        self.db
+            .token_userid
             .get(token)
             .unwrap()
-            .and_then(|bytes| (*utils::bytes_to_string(&bytes)).try_into().ok())
+            .and_then(|bytes| (*utils::string_from_bytes(&bytes)).try_into().ok())
     }
 
     /// Checks if the given password is equal to the one in the database.
     pub fn password_get(&self, user_id: &UserId) -> Option<String> {
-        self.0
-            .open_tree(USERID_PASSWORD)
-            .unwrap()
+        self.db
+            .userid_password
             .get(user_id.to_string())
             .unwrap()
-            .map(|bytes| utils::bytes_to_string(&bytes))
+            .map(|bytes| utils::string_from_bytes(&bytes))
     }
 
     /// Add a new device to a user.
     pub fn device_add(&self, user_id: &UserId, device_id: &str) {
-        self.0
-            .open_tree(USERID_DEVICEIDS)
-            .unwrap()
-            .insert(user_id.to_string(), device_id)
-            .unwrap();
+        if self
+            .db
+            .userid_deviceids
+            .get_iter(&user_id.to_string().as_bytes())
+            .filter_map(|item| item.ok())
+            .map(|(_key, value)| value)
+            .all(|device| device != device_id)
+        {
+            self.db
+                .userid_deviceids
+                .add(user_id.to_string().as_bytes(), device_id.into());
+        }
     }
 
     /// Replace the access token of one device.
     pub fn token_replace(&self, user_id: &UserId, device_id: &String, token: String) {
         // Make sure the device id belongs to the user
         debug_assert!(self
-            .0
-            .open_tree(USERID_DEVICEIDS)
-            .unwrap()
-            .get(&user_id.to_string()) // Does the user exist?
-            .unwrap()
-            .map(|bytes| utils::bytes_to_vec(&bytes))
-            .filter(|devices| devices.contains(device_id)) // Does the user have that device?
-            .is_some());
+            .db
+            .userid_deviceids
+            .get_iter(&user_id.to_string().as_bytes())
+            .filter_map(|item| item.ok())
+            .map(|(_key, value)| value)
+            .any(|device| device == device_id.as_bytes())); // Does the user have that device?
 
         // Remove old token
-        if let Some(old_token) = self
-            .0
-            .open_tree(DEVICEID_TOKEN)
-            .unwrap()
-            .get(device_id)
-            .unwrap()
-        {
-            self.0
-                .open_tree(TOKEN_USERID)
-                .unwrap()
-                .remove(old_token)
-                .unwrap();
-            // It will be removed from DEVICEID_TOKEN by the insert later
+        if let Some(old_token) = self.db.deviceid_token.get(device_id).unwrap() {
+            self.db.token_userid.remove(old_token).unwrap();
+            // It will be removed from deviceid_token by the insert later
         }
 
         // Assign token to device_id
-        self.0
-            .open_tree(DEVICEID_TOKEN)
-            .unwrap()
-            .insert(device_id, &*token)
-            .unwrap();
+        self.db.deviceid_token.insert(device_id, &*token).unwrap();
 
         // Assign token to user
-        self.0
-            .open_tree(TOKEN_USERID)
-            .unwrap()
+        self.db
+            .token_userid
             .insert(token, &*user_id.to_string())
             .unwrap();
     }
 
     /// Create a new room event.
-    pub fn event_add(&self, event: &Event, room_id: &RoomId, event_id: &EventId) {
-        debug!("{}", serde_json::to_string(event).unwrap());
-        todo!();
+    pub fn event_add(&self, room_id: &RoomId, event_id: &EventId, event: &Event) {
+        let mut key = room_id.to_string().as_bytes().to_vec();
+        key.extend_from_slice(event_id.to_string().as_bytes());
+        self.db
+            .roomid_eventid_event
+            .insert(&key, &*serde_json::to_string(event).unwrap())
+            .unwrap();
+    }
+
+    pub fn debug(&self) {
+        self.db.debug();
     }
 }
diff --git a/src/database.rs b/src/database.rs
new file mode 100644
index 000000000..34ed72baa
--- /dev/null
+++ b/src/database.rs
@@ -0,0 +1,117 @@
+use crate::utils;
+use directories::ProjectDirs;
+use sled::IVec;
+
+pub struct MultiValue(sled::Tree);
+
+impl MultiValue {
+    /// Get an iterator over all values.
+    pub fn iter_all(&self) -> sled::Iter {
+        self.0.iter()
+    }
+
+    /// Get an iterator over all values of this id.
+    pub fn get_iter(&self, id: &[u8]) -> sled::Iter {
+        // Data keys start with d
+        let mut key = vec![b'd'];
+        key.extend_from_slice(id.as_ref());
+        key.push(0xff); // Add delimiter so we don't find usernames starting with the same id
+
+        self.0.scan_prefix(key)
+    }
+
+    /// Add another value to the id.
+    pub fn add(&self, id: &[u8], value: IVec) {
+        // The new value will need a new index. We store the last used index in 'n' + id
+        let mut count_key: Vec<u8> = vec![b'n'];
+        count_key.extend_from_slice(id.as_ref());
+
+        // Increment the last index and use that
+        let index = self
+            .0
+            .update_and_fetch(&count_key, utils::increment)
+            .unwrap()
+            .unwrap();
+
+        // Data keys start with d
+        let mut key = vec![b'd'];
+        key.extend_from_slice(id.as_ref());
+        key.push(0xff);
+        key.extend_from_slice(&index);
+
+        self.0.insert(key, value).unwrap();
+    }
+}
+
+pub struct Database {
+    pub userid_password: sled::Tree,
+    pub userid_deviceids: MultiValue,
+    pub deviceid_token: sled::Tree,
+    pub token_userid: sled::Tree,
+    pub roomid_eventid_event: sled::Tree,
+    _db: sled::Db,
+}
+
+impl Database {
+    /// Load an existing database or create a new one.
+    pub fn load_or_create(hostname: &str) -> Self {
+        let mut path = ProjectDirs::from("xyz", "koesters", "matrixserver")
+            .unwrap()
+            .data_dir()
+            .to_path_buf();
+        path.push(hostname);
+        let db = sled::open(&path).unwrap();
+
+        Self {
+            userid_password: db.open_tree("userid_password").unwrap(),
+            userid_deviceids: MultiValue(db.open_tree("userid_deviceids").unwrap()),
+            deviceid_token: db.open_tree("deviceid_token").unwrap(),
+            token_userid: db.open_tree("token_userid").unwrap(),
+            roomid_eventid_event: db.open_tree("roomid_eventid_event").unwrap(),
+            _db: db,
+        }
+    }
+
+    pub fn debug(&self) {
+        println!("# UserId -> Password:");
+        for (k, v) in self.userid_password.iter().map(|r| r.unwrap()) {
+            println!(
+                "{} -> {}",
+                String::from_utf8_lossy(&k),
+                String::from_utf8_lossy(&v),
+            );
+        }
+        println!("# UserId -> DeviceIds:");
+        for (k, v) in self.userid_deviceids.iter_all().map(|r| r.unwrap()) {
+            println!(
+                "{} -> {}",
+                String::from_utf8_lossy(&k),
+                String::from_utf8_lossy(&v),
+            );
+        }
+        println!("# DeviceId -> Token:");
+        for (k, v) in self.deviceid_token.iter().map(|r| r.unwrap()) {
+            println!(
+                "{} -> {}",
+                String::from_utf8_lossy(&k),
+                String::from_utf8_lossy(&v),
+            );
+        }
+        println!("# Token -> UserId:");
+        for (k, v) in self.token_userid.iter().map(|r| r.unwrap()) {
+            println!(
+                "{} -> {}",
+                String::from_utf8_lossy(&k),
+                String::from_utf8_lossy(&v),
+            );
+        }
+        println!("# RoomId + EventId -> Event:");
+        for (k, v) in self.roomid_eventid_event.iter().map(|r| r.unwrap()) {
+            println!(
+                "{} -> {}",
+                String::from_utf8_lossy(&k),
+                String::from_utf8_lossy(&v),
+            );
+        }
+    }
+}
diff --git a/src/main.rs b/src/main.rs
index 06f7ca3a8..cf1f37f7e 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,9 +1,12 @@
 #![feature(proc_macro_hygiene, decl_macro)]
 mod data;
+mod database;
 mod ruma_wrapper;
 mod utils;
 
 pub use data::Data;
+pub use database::Database;
+
 use log::debug;
 use rocket::{get, post, put, routes, State};
 use ruma_client_api::{
@@ -14,13 +17,14 @@
     },
     unversioned::get_supported_versions,
 };
-use ruma_events::collections::all::Event;
-use ruma_events::room::message::MessageEvent;
+use ruma_events::{collections::all::Event, room::message::MessageEvent};
 use ruma_identifiers::{EventId, UserId};
 use ruma_wrapper::{MatrixResult, Ruma};
 use serde_json::map::Map;
-use std::convert::TryFrom;
-use std::{collections::HashMap, convert::TryInto};
+use std::{
+    collections::HashMap,
+    convert::{TryFrom, TryInto},
+};
 
 #[get("/_matrix/client/versions")]
 fn get_supported_versions_route() -> MatrixResult<get_supported_versions::Response> {
@@ -90,7 +94,7 @@ fn register_route(
 
     MatrixResult(Ok(register::Response {
         access_token: token,
-        home_server: data.hostname(),
+        home_server: data.hostname().to_owned(),
         user_id,
         device_id,
     }))
@@ -153,7 +157,7 @@ fn login_route(data: State<Data>, body: Ruma<login::Request>) -> MatrixResult<lo
         .clone()
         .unwrap_or("TODO:randomdeviceid".to_owned());
 
-    // Add device (TODO: We might not want to call it when using an existing device)
+    // Add device
     data.device_add(&user_id, &device_id);
 
     // Generate a new token for the device
@@ -163,7 +167,7 @@ fn login_route(data: State<Data>, body: Ruma<login::Request>) -> MatrixResult<lo
     return MatrixResult(Ok(login::Response {
         user_id,
         access_token: token,
-        home_server: Some(data.hostname()),
+        home_server: Some(data.hostname().to_owned()),
         device_id,
         well_known: None,
     }));
@@ -217,6 +221,8 @@ fn create_message_event_route(
     // Generate event id
     let event_id = EventId::try_from("$TODOrandomeventid:localhost").unwrap();
     data.event_add(
+        &body.room_id,
+        &event_id,
         &Event::RoomMessage(MessageEvent {
             content: body.data.clone().into_result().unwrap(),
             event_id: event_id.clone(),
@@ -225,8 +231,6 @@ fn create_message_event_route(
             sender: body.user_id.clone().expect("user is authenticated"),
             unsigned: Map::default(),
         }),
-        &body.room_id,
-        &event_id,
     );
 
     MatrixResult(Ok(create_message_event::Response { event_id }))
@@ -239,8 +243,8 @@ fn main() {
     }
     pretty_env_logger::init();
 
-    let data = Data::load_or_create();
-    data.set_hostname("localhost");
+    let data = Data::load_or_create("localhost");
+    data.debug();
 
     rocket::ignite()
         .mount(
diff --git a/src/utils.rs b/src/utils.rs
index fd7b4cb02..f2ef6c4c0 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -1,4 +1,7 @@
-use std::time::{SystemTime, UNIX_EPOCH};
+use std::{
+    convert::TryInto,
+    time::{SystemTime, UNIX_EPOCH},
+};
 
 pub fn millis_since_unix_epoch() -> js_int::UInt {
     (SystemTime::now()
@@ -8,20 +11,19 @@ pub fn millis_since_unix_epoch() -> js_int::UInt {
         .into()
 }
 
-pub fn bytes_to_string(bytes: &[u8]) -> String {
-    String::from_utf8(bytes.to_vec()).expect("convert bytes to string")
-}
+pub fn increment(old: Option<&[u8]>) -> Option<Vec<u8>> {
+    let number = match old {
+        Some(bytes) => {
+            let array: [u8; 8] = bytes.try_into().unwrap();
+            let number = u64::from_be_bytes(array);
+            number + 1
+        }
+        None => 0,
+    };
 
-pub fn vec_to_bytes(vec: Vec<String>) -> Vec<u8> {
-    vec.into_iter()
-        .map(|string| string.into_bytes())
-        .collect::<Vec<Vec<u8>>>()
-        .join(&0)
+    Some(number.to_be_bytes().to_vec())
 }
 
-pub fn bytes_to_vec(bytes: &[u8]) -> Vec<String> {
-    bytes
-        .split(|&b| b == 0)
-        .map(|bytes_string| bytes_to_string(bytes_string))
-        .collect::<Vec<String>>()
+pub fn string_from_bytes(bytes: &[u8]) -> String {
+    String::from_utf8(bytes.to_vec()).expect("bytes are valid utf8")
 }
-- 
GitLab