diff --git a/Cargo.lock b/Cargo.lock index 36a66590bf09c1d4f5377fadcc64d7d8ef5e8e2c..630d41423269657ee579f4463a4042d199bc6a7c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -103,6 +103,25 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "383d29d513d8764dcdc42ea295d979eb99c3c9f00607b3692cf68a431f7dca72" +[[package]] +name = "bindgen" +version = "0.57.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd4865004a46a0aafb2a0a5eb19d3c9fc46ee5f063a6cfc605c69ac9ecf5263d" +dependencies = [ + "bitflags", + "cexpr", + "clang-sys", + "lazy_static", + "lazycell", + "peeking_take_while", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", +] + [[package]] name = "bitflags" version = "1.2.1" @@ -162,6 +181,15 @@ dependencies = [ "jobserver", ] +[[package]] +name = "cexpr" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4aedb84272dbe89af497cf81375129abda4fc0a9e7c5d317498c15cc30c0d27" +dependencies = [ + "nom", +] + [[package]] name = "cfg-if" version = "0.1.10" @@ -187,6 +215,17 @@ dependencies = [ "winapi", ] +[[package]] +name = "clang-sys" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "853eda514c284c2287f4bf20ae614f8781f40a81d32ecda6e91449304dfe077c" +dependencies = [ + "glob", + "libc", + "libloading", +] + [[package]] name = "color_quant" version = "1.1.0" @@ -212,6 +251,7 @@ dependencies = [ "reqwest", "ring", "rocket", + "rocksdb", "ruma", "rust-argon2", "rustls", @@ -1008,12 +1048,40 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + [[package]] name = "libc" version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "789da6d93f1b866ffe175afc5322a4d76c038605a1c3319bb57b06967ca98a36" +[[package]] +name = "libloading" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f84d96438c15fcd6c3f244c8fce01d1e2b9c6b5623e9c711dc9286d8fc92d6a" +dependencies = [ + "cfg-if 1.0.0", + "winapi", +] + +[[package]] +name = "librocksdb-sys" +version = "6.17.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5da125e1c0f22c7cae785982115523a0738728498547f415c9054cb17c7e89f9" +dependencies = [ + "bindgen", + "cc", + "glob", + "libc", +] + [[package]] name = "linked-hash-map" version = "0.5.4" @@ -1158,6 +1226,16 @@ dependencies = [ "version_check", ] +[[package]] +name = "nom" +version = "5.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffb4262d26ed83a1c0a33a38fe2bb15797329c85770da05e6b828ddb782627af" +dependencies = [ + "memchr", + "version_check", +] + [[package]] name = "ntapi" version = "0.3.6" @@ -1339,6 +1417,12 @@ dependencies = [ "syn", ] +[[package]] +name = "peeking_take_while" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" + [[package]] name = "pem" version = "0.8.3" @@ -1777,6 +1861,16 @@ dependencies = [ "uncased", ] +[[package]] +name = "rocksdb" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c749134fda8bfc90d0de643d59bfc841dcb3ac8a1062e12b6754bd60235c48b3" +dependencies = [ + "libc", + "librocksdb-sys", +] + [[package]] name = "ruma" version = "0.1.2" @@ -2046,6 +2140,12 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "rustc_version" version = "0.2.3" @@ -2245,6 +2345,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shlex" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fdf1b9db47230893d76faad238fd6097fd6d6a9245cd7a4d90dbd639536bbd2" + [[package]] name = "signal-hook-registry" version = "1.3.0" diff --git a/Cargo.toml b/Cargo.toml index f36d8382bb6ce1c95dc9161aa0146ed7391adb71..e7ebadf20788995285443cc0a357f52711fb33c9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,7 +24,8 @@ ruma = { git = "https://github.com/ruma/ruma", rev = "b39537812c12caafcbf8b7bd74 # Used for long polling and federation sender, should be the same as rocket::tokio tokio = "1.2.0" # Used for storing data permanently -sled = { version = "0.34.6", features = ["compression", "no_metrics"] } +sled = { version = "0.34.6", features = ["compression", "no_metrics"], optional = true } +rocksdb = { version = "0.16.0", features = ["multi-threaded-cf"], optional = true } #sled = { git = "https://github.com/spacejam/sled.git", rev = "e4640e0773595229f398438886f19bca6f7326a2", features = ["compression"] } # Used for the http request / response body type for Ruma endpoints used with reqwest @@ -74,7 +75,9 @@ opentelemetry-jaeger = "0.11.0" pretty_env_logger = "0.4.0" [features] -default = ["conduit_bin"] +default = ["conduit_bin", "backend_sled"] +backend_sled = ["sled"] +backend_rocksdb = ["rocksdb"] conduit_bin = [] # TODO: add rocket to this when it is optional [[bin]] diff --git a/src/client_server/account.rs b/src/client_server/account.rs index 0cf30a079d8152bdf8016a90217f370b537665e2..56de5fc90b4d70f9aadf817f18e7b45568dacadf 100644 --- a/src/client_server/account.rs +++ b/src/client_server/account.rs @@ -1,4 +1,4 @@ -use std::{collections::BTreeMap, convert::TryInto}; +use std::{collections::BTreeMap, convert::TryInto, sync::Arc}; use super::{State, DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; use crate::{pdu::PduBuilder, utils, ConduitResult, Database, Error, Ruma}; @@ -42,7 +42,7 @@ )] #[tracing::instrument(skip(db, body))] pub async fn get_register_available_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_username_availability::Request<'_>>, ) -> ConduitResult<get_username_availability::Response> { // Validate user id @@ -85,7 +85,7 @@ pub async fn get_register_available_route( )] #[tracing::instrument(skip(db, body))] pub async fn register_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<register::Request<'_>>, ) -> ConduitResult<register::Response> { if !db.globals.allow_registration() { @@ -227,7 +227,7 @@ pub async fn register_route( )?; // If this is the first user on this server, create the admins room - if db.users.count() == 1 { + if db.users.count()? == 1 { // Create a user for the server let conduit_user = UserId::parse_with_server_name("conduit", db.globals.server_name()) .expect("@conduit:server_name is valid"); @@ -506,7 +506,7 @@ pub async fn register_route( )] #[tracing::instrument(skip(db, body))] pub async fn change_password_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<change_password::Request<'_>>, ) -> ConduitResult<change_password::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -598,7 +598,7 @@ pub async fn whoami_route(body: Ruma<whoami::Request>) -> ConduitResult<whoami:: )] #[tracing::instrument(skip(db, body))] pub async fn deactivate_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<deactivate::Request<'_>>, ) -> ConduitResult<deactivate::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/alias.rs b/src/client_server/alias.rs index 07b49773aec7d4270c548c4728d05e9191a56659..40252af2088567e43b514a270dd111d4aa104e4a 100644 --- a/src/client_server/alias.rs +++ b/src/client_server/alias.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use super::State; use crate::{ConduitResult, Database, Error, Ruma}; use regex::Regex; @@ -22,7 +24,7 @@ )] #[tracing::instrument(skip(db, body))] pub async fn create_alias_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<create_alias::Request<'_>>, ) -> ConduitResult<create_alias::Response> { if db.rooms.id_from_alias(&body.room_alias)?.is_some() { @@ -43,7 +45,7 @@ pub async fn create_alias_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_alias_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<delete_alias::Request<'_>>, ) -> ConduitResult<delete_alias::Response> { db.rooms.set_alias(&body.room_alias, None, &db.globals)?; @@ -59,7 +61,7 @@ pub async fn delete_alias_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_alias_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_alias::Request<'_>>, ) -> ConduitResult<get_alias::Response> { get_alias_helper(&db, &body.room_alias).await @@ -86,7 +88,8 @@ pub async fn get_alias_helper( match db.rooms.id_from_alias(&room_alias)? { Some(r) => room_id = Some(r), None => { - for (_id, registration) in db.appservice.iter_all().filter_map(|r| r.ok()) { + let iter = db.appservice.iter_all()?; + for (_id, registration) in iter.filter_map(|r| r.ok()) { let aliases = registration .get("namespaces") .and_then(|ns| ns.get("aliases")) diff --git a/src/client_server/backup.rs b/src/client_server/backup.rs index 12f3bfd36be2facc4fdf33605effcbaa62ff6c6a..fcca676f299bb8e1c18e9859cccce583c341bf00 100644 --- a/src/client_server/backup.rs +++ b/src/client_server/backup.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use super::State; use crate::{ConduitResult, Database, Error, Ruma}; use ruma::api::client::{ @@ -19,7 +21,7 @@ )] #[tracing::instrument(skip(db, body))] pub async fn create_backup_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<create_backup::Request>, ) -> ConduitResult<create_backup::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -38,7 +40,7 @@ pub async fn create_backup_route( )] #[tracing::instrument(skip(db, body))] pub async fn update_backup_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<update_backup::Request<'_>>, ) -> ConduitResult<update_backup::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -56,7 +58,7 @@ pub async fn update_backup_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_latest_backup_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_latest_backup::Request>, ) -> ConduitResult<get_latest_backup::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -84,7 +86,7 @@ pub async fn get_latest_backup_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_backup_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_backup::Request<'_>>, ) -> ConduitResult<get_backup::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -111,7 +113,7 @@ pub async fn get_backup_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_backup_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<delete_backup::Request<'_>>, ) -> ConduitResult<delete_backup::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -130,7 +132,7 @@ pub async fn delete_backup_route( )] #[tracing::instrument(skip(db, body))] pub async fn add_backup_keys_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<add_backup_keys::Request<'_>>, ) -> ConduitResult<add_backup_keys::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -164,7 +166,7 @@ pub async fn add_backup_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn add_backup_key_sessions_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<add_backup_key_sessions::Request<'_>>, ) -> ConduitResult<add_backup_key_sessions::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -196,7 +198,7 @@ pub async fn add_backup_key_sessions_route( )] #[tracing::instrument(skip(db, body))] pub async fn add_backup_key_session_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<add_backup_key_session::Request<'_>>, ) -> ConduitResult<add_backup_key_session::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -225,7 +227,7 @@ pub async fn add_backup_key_session_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_backup_keys_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_backup_keys::Request<'_>>, ) -> ConduitResult<get_backup_keys::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -241,14 +243,14 @@ pub async fn get_backup_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_backup_key_sessions_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_backup_key_sessions::Request<'_>>, ) -> ConduitResult<get_backup_key_sessions::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sessions = db .key_backups - .get_room(&sender_user, &body.version, &body.room_id); + .get_room(&sender_user, &body.version, &body.room_id)?; Ok(get_backup_key_sessions::Response { sessions }.into()) } @@ -259,7 +261,7 @@ pub async fn get_backup_key_sessions_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_backup_key_session_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_backup_key_session::Request<'_>>, ) -> ConduitResult<get_backup_key_session::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -281,7 +283,7 @@ pub async fn get_backup_key_session_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_backup_keys_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<delete_backup_keys::Request<'_>>, ) -> ConduitResult<delete_backup_keys::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -304,7 +306,7 @@ pub async fn delete_backup_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_backup_key_sessions_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<delete_backup_key_sessions::Request<'_>>, ) -> ConduitResult<delete_backup_key_sessions::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -327,7 +329,7 @@ pub async fn delete_backup_key_sessions_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_backup_key_session_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<delete_backup_key_session::Request<'_>>, ) -> ConduitResult<delete_backup_key_session::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/config.rs b/src/client_server/config.rs index ce437efdc7b8f93317a150b346cdbd576132e39e..829bf94a89be567cb9e325e17661e6969a6a958e 100644 --- a/src/client_server/config.rs +++ b/src/client_server/config.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use super::State; use crate::{ConduitResult, Database, Error, Ruma}; use ruma::{ @@ -23,7 +25,7 @@ )] #[tracing::instrument(skip(db, body))] pub async fn set_global_account_data_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<set_global_account_data::Request<'_>>, ) -> ConduitResult<set_global_account_data::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -58,7 +60,7 @@ pub async fn set_global_account_data_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_room_account_data_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<set_room_account_data::Request<'_>>, ) -> ConduitResult<set_room_account_data::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -90,7 +92,7 @@ pub async fn set_room_account_data_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_global_account_data_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_global_account_data::Request<'_>>, ) -> ConduitResult<get_global_account_data::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -117,7 +119,7 @@ pub async fn get_global_account_data_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_room_account_data_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_room_account_data::Request<'_>>, ) -> ConduitResult<get_room_account_data::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/context.rs b/src/client_server/context.rs index 1fee2f2605bd31cb08850361454f25b1461c9379..b86fd0bfed754e0f4e930ef6d0be8976bca9e8e4 100644 --- a/src/client_server/context.rs +++ b/src/client_server/context.rs @@ -1,7 +1,7 @@ use super::State; use crate::{ConduitResult, Database, Error, Ruma}; use ruma::api::client::{error::ErrorKind, r0::context::get_context}; -use std::convert::TryFrom; +use std::{convert::TryFrom, sync::Arc}; #[cfg(feature = "conduit_bin")] use rocket::get; @@ -12,7 +12,7 @@ )] #[tracing::instrument(skip(db, body))] pub async fn get_context_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_context::Request<'_>>, ) -> ConduitResult<get_context::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/device.rs b/src/client_server/device.rs index 961ba97a73b0380cbe4e67144355e5961220bc58..2441524df4e7b77988aa1226c25feea2ec18edfe 100644 --- a/src/client_server/device.rs +++ b/src/client_server/device.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use super::State; use crate::{utils, ConduitResult, Database, Error, Ruma}; use ruma::api::client::{ @@ -18,7 +20,7 @@ )] #[tracing::instrument(skip(db, body))] pub async fn get_devices_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_devices::Request>, ) -> ConduitResult<get_devices::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -38,7 +40,7 @@ pub async fn get_devices_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_device_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_device::Request<'_>>, ) -> ConduitResult<get_device::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -57,7 +59,7 @@ pub async fn get_device_route( )] #[tracing::instrument(skip(db, body))] pub async fn update_device_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<update_device::Request<'_>>, ) -> ConduitResult<update_device::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -83,7 +85,7 @@ pub async fn update_device_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_device_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<delete_device::Request<'_>>, ) -> ConduitResult<delete_device::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -137,7 +139,7 @@ pub async fn delete_device_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_devices_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<delete_devices::Request<'_>>, ) -> ConduitResult<delete_devices::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/directory.rs b/src/client_server/directory.rs index 9864a5e36770e048ac4fdb766bdfdcab69e07b97..ad609cd5e0961e4de9732d1b2bf647f7c0939a1e 100644 --- a/src/client_server/directory.rs +++ b/src/client_server/directory.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use super::State; use crate::{ConduitResult, Database, Error, Result, Ruma}; use log::info; @@ -33,7 +35,7 @@ )] #[tracing::instrument(skip(db, body))] pub async fn get_public_rooms_filtered_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_public_rooms_filtered::Request<'_>>, ) -> ConduitResult<get_public_rooms_filtered::Response> { get_public_rooms_filtered_helper( @@ -53,7 +55,7 @@ pub async fn get_public_rooms_filtered_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_public_rooms_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_public_rooms::Request<'_>>, ) -> ConduitResult<get_public_rooms::Response> { let response = get_public_rooms_filtered_helper( @@ -82,7 +84,7 @@ pub async fn get_public_rooms_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_room_visibility_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<set_room_visibility::Request<'_>>, ) -> ConduitResult<set_room_visibility::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -112,7 +114,7 @@ pub async fn set_room_visibility_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_room_visibility_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_room_visibility::Request<'_>>, ) -> ConduitResult<get_room_visibility::Response> { Ok(get_room_visibility::Response { diff --git a/src/client_server/keys.rs b/src/client_server/keys.rs index d856bf315a60489ddf9f8d1b855d67e50fddc2e1..f80a3294e2a29100b4a852006da626865e246d56 100644 --- a/src/client_server/keys.rs +++ b/src/client_server/keys.rs @@ -14,7 +14,10 @@ encryption::UnsignedDeviceInfo, DeviceId, DeviceKeyAlgorithm, UserId, }; -use std::collections::{BTreeMap, HashSet}; +use std::{ + collections::{BTreeMap, HashSet}, + sync::Arc, +}; #[cfg(feature = "conduit_bin")] use rocket::{get, post}; @@ -25,7 +28,7 @@ )] #[tracing::instrument(skip(db, body))] pub async fn upload_keys_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<upload_keys::Request>, ) -> ConduitResult<upload_keys::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -74,7 +77,7 @@ pub async fn upload_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_keys_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_keys::Request<'_>>, ) -> ConduitResult<get_keys::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -95,7 +98,7 @@ pub async fn get_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn claim_keys_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<claim_keys::Request>, ) -> ConduitResult<claim_keys::Response> { let response = claim_keys_helper(&body.one_time_keys, &db)?; @@ -111,7 +114,7 @@ pub async fn claim_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn upload_signing_keys_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<upload_signing_keys::Request<'_>>, ) -> ConduitResult<upload_signing_keys::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -174,7 +177,7 @@ pub async fn upload_signing_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn upload_signatures_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<upload_signatures::Request>, ) -> ConduitResult<upload_signatures::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -235,7 +238,7 @@ pub async fn upload_signatures_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_key_changes_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_key_changes::Request<'_>>, ) -> ConduitResult<get_key_changes::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/media.rs b/src/client_server/media.rs index 74ca6c842ce168dba04409a3c5cc6e01b30f1592..0b1fbd7bcf6cbaec264552d29a49217223bb5be9 100644 --- a/src/client_server/media.rs +++ b/src/client_server/media.rs @@ -7,14 +7,14 @@ #[cfg(feature = "conduit_bin")] use rocket::{get, post}; -use std::convert::TryInto; +use std::{convert::TryInto, sync::Arc}; const MXC_LENGTH: usize = 32; #[cfg_attr(feature = "conduit_bin", get("/_matrix/media/r0/config"))] #[tracing::instrument(skip(db))] pub async fn get_media_config_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, ) -> ConduitResult<get_media_config::Response> { Ok(get_media_config::Response { upload_size: db.globals.max_request_size().into(), @@ -28,7 +28,7 @@ pub async fn get_media_config_route( )] #[tracing::instrument(skip(db, body))] pub async fn create_content_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<create_content::Request<'_>>, ) -> ConduitResult<create_content::Response> { let mxc = format!( @@ -36,16 +36,20 @@ pub async fn create_content_route( db.globals.server_name(), utils::random_string(MXC_LENGTH) ); - db.media.create( - mxc.clone(), - &body - .filename - .as_ref() - .map(|filename| "inline; filename=".to_owned() + filename) - .as_deref(), - &body.content_type.as_deref(), - &body.file, - )?; + + db.media + .create( + mxc.clone(), + &db.globals, + &body + .filename + .as_ref() + .map(|filename| "inline; filename=".to_owned() + filename) + .as_deref(), + &body.content_type.as_deref(), + &body.file, + ) + .await?; db.flush().await?; @@ -62,7 +66,7 @@ pub async fn create_content_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_content_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_content::Request<'_>>, ) -> ConduitResult<get_content::Response> { let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); @@ -71,7 +75,7 @@ pub async fn get_content_route( content_disposition, content_type, file, - }) = db.media.get(&mxc)? + }) = db.media.get(&db.globals, &mxc).await? { Ok(get_content::Response { file, @@ -93,12 +97,15 @@ pub async fn get_content_route( ) .await?; - db.media.create( - mxc, - &get_content_response.content_disposition.as_deref(), - &get_content_response.content_type.as_deref(), - &get_content_response.file, - )?; + db.media + .create( + mxc, + &db.globals, + &get_content_response.content_disposition.as_deref(), + &get_content_response.content_type.as_deref(), + &get_content_response.file, + ) + .await?; Ok(get_content_response.into()) } else { @@ -112,22 +119,27 @@ pub async fn get_content_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_content_thumbnail_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_content_thumbnail::Request<'_>>, ) -> ConduitResult<get_content_thumbnail::Response> { let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); if let Some(FileMeta { content_type, file, .. - }) = db.media.get_thumbnail( - mxc.clone(), - body.width - .try_into() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, - body.height - .try_into() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, - )? { + }) = db + .media + .get_thumbnail( + mxc.clone(), + &db.globals, + body.width + .try_into() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, + body.height + .try_into() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, + ) + .await? + { Ok(get_content_thumbnail::Response { file, content_type }.into()) } else if &*body.server_name != db.globals.server_name() && body.allow_remote { let get_thumbnail_response = db @@ -146,14 +158,17 @@ pub async fn get_content_thumbnail_route( ) .await?; - db.media.upload_thumbnail( - mxc, - &None, - &get_thumbnail_response.content_type, - body.width.try_into().expect("all UInts are valid u32s"), - body.height.try_into().expect("all UInts are valid u32s"), - &get_thumbnail_response.file, - )?; + db.media + .upload_thumbnail( + mxc, + &db.globals, + &None, + &get_thumbnail_response.content_type, + body.width.try_into().expect("all UInts are valid u32s"), + body.height.try_into().expect("all UInts are valid u32s"), + &get_thumbnail_response.file, + ) + .await?; Ok(get_thumbnail_response.into()) } else { diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index 96fe80013880e295a14acc7657c5119b683b4253..a3f1389a9d258651fc12ed96cc65ab6dc79a52a2 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -44,7 +44,7 @@ )] #[tracing::instrument(skip(db, body))] pub async fn join_room_by_id_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<join_room_by_id::Request<'_>>, ) -> ConduitResult<join_room_by_id::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -81,7 +81,7 @@ pub async fn join_room_by_id_route( )] #[tracing::instrument(skip(db, body))] pub async fn join_room_by_id_or_alias_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<join_room_by_id_or_alias::Request<'_>>, ) -> ConduitResult<join_room_by_id_or_alias::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -135,7 +135,7 @@ pub async fn join_room_by_id_or_alias_route( )] #[tracing::instrument(skip(db, body))] pub async fn leave_room_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<leave_room::Request<'_>>, ) -> ConduitResult<leave_room::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -153,7 +153,7 @@ pub async fn leave_room_route( )] #[tracing::instrument(skip(db, body))] pub async fn invite_user_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<invite_user::Request<'_>>, ) -> ConduitResult<invite_user::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -173,7 +173,7 @@ pub async fn invite_user_route( )] #[tracing::instrument(skip(db, body))] pub async fn kick_user_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<kick_user::Request<'_>>, ) -> ConduitResult<kick_user::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -222,7 +222,7 @@ pub async fn kick_user_route( )] #[tracing::instrument(skip(db, body))] pub async fn ban_user_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<ban_user::Request<'_>>, ) -> ConduitResult<ban_user::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -279,7 +279,7 @@ pub async fn ban_user_route( )] #[tracing::instrument(skip(db, body))] pub async fn unban_user_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<unban_user::Request<'_>>, ) -> ConduitResult<unban_user::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -327,7 +327,7 @@ pub async fn unban_user_route( )] #[tracing::instrument(skip(db, body))] pub async fn forget_room_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<forget_room::Request<'_>>, ) -> ConduitResult<forget_room::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -345,7 +345,7 @@ pub async fn forget_room_route( )] #[tracing::instrument(skip(db, body))] pub async fn joined_rooms_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<joined_rooms::Request>, ) -> ConduitResult<joined_rooms::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -366,7 +366,7 @@ pub async fn joined_rooms_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_member_events_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_member_events::Request<'_>>, ) -> ConduitResult<get_member_events::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -396,7 +396,7 @@ pub async fn get_member_events_route( )] #[tracing::instrument(skip(db, body))] pub async fn joined_members_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<joined_members::Request<'_>>, ) -> ConduitResult<joined_members::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -621,7 +621,7 @@ async fn join_room_by_id_helper( &pdu, utils::to_canonical_object(&pdu).expect("Pdu is valid canonical object"), count, - pdu_id.into(), + &pdu_id, &[pdu.event_id.clone()], db, )?; diff --git a/src/client_server/message.rs b/src/client_server/message.rs index 96de93dd9dd8fb213df7b1fec2bf483d3fababf3..0d19f347c248f32dc4ad59bbf58294e682f9b4cc 100644 --- a/src/client_server/message.rs +++ b/src/client_server/message.rs @@ -11,6 +11,7 @@ use std::{ collections::BTreeMap, convert::{TryFrom, TryInto}, + sync::Arc, }; #[cfg(feature = "conduit_bin")] @@ -22,7 +23,7 @@ )] #[tracing::instrument(skip(db, body))] pub async fn send_message_event_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<send_message_event::Request<'_>>, ) -> ConduitResult<send_message_event::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -85,7 +86,7 @@ pub async fn send_message_event_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_message_events_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_message_events::Request<'_>>, ) -> ConduitResult<get_message_events::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/presence.rs b/src/client_server/presence.rs index 9f4f7a39af4fd7489a44c6135a561c6dc7744611..ce80dfd7a88a28568a676a4ac533a8d86b2ce216 100644 --- a/src/client_server/presence.rs +++ b/src/client_server/presence.rs @@ -1,7 +1,7 @@ use super::State; use crate::{utils, ConduitResult, Database, Ruma}; use ruma::api::client::r0::presence::{get_presence, set_presence}; -use std::{convert::TryInto, time::Duration}; +use std::{convert::TryInto, sync::Arc, time::Duration}; #[cfg(feature = "conduit_bin")] use rocket::{get, put}; @@ -12,7 +12,7 @@ )] #[tracing::instrument(skip(db, body))] pub async fn set_presence_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<set_presence::Request<'_>>, ) -> ConduitResult<set_presence::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -53,7 +53,7 @@ pub async fn set_presence_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_presence_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_presence::Request<'_>>, ) -> ConduitResult<get_presence::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -62,7 +62,7 @@ pub async fn get_presence_route( for room_id in db .rooms - .get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()]) + .get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? { let room_id = room_id?; diff --git a/src/client_server/profile.rs b/src/client_server/profile.rs index 882b02e7e4d9c0e2c51adf4a8c9754de7efdc695..32bb6083d3b2eb4ccbbadbf4dfcc4cb1d09837cc 100644 --- a/src/client_server/profile.rs +++ b/src/client_server/profile.rs @@ -13,7 +13,7 @@ #[cfg(feature = "conduit_bin")] use rocket::{get, put}; -use std::convert::TryInto; +use std::{convert::TryInto, sync::Arc}; #[cfg_attr( feature = "conduit_bin", @@ -21,7 +21,7 @@ )] #[tracing::instrument(skip(db, body))] pub async fn set_displayname_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<set_display_name::Request<'_>>, ) -> ConduitResult<set_display_name::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -107,7 +107,7 @@ pub async fn set_displayname_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_displayname_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_display_name::Request<'_>>, ) -> ConduitResult<get_display_name::Response> { Ok(get_display_name::Response { @@ -122,7 +122,7 @@ pub async fn get_displayname_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_avatar_url_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<set_avatar_url::Request<'_>>, ) -> ConduitResult<set_avatar_url::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -208,7 +208,7 @@ pub async fn set_avatar_url_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_avatar_url_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_avatar_url::Request<'_>>, ) -> ConduitResult<get_avatar_url::Response> { Ok(get_avatar_url::Response { @@ -223,7 +223,7 @@ pub async fn get_avatar_url_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_profile_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_profile::Request<'_>>, ) -> ConduitResult<get_profile::Response> { if !db.users.exists(&body.user_id)? { diff --git a/src/client_server/push.rs b/src/client_server/push.rs index e37e660d4f0f0b5ac0d1ecd4d7e33e480df067ea..d6f62126d0f6a62d91f2837ca2305578862ccbb4 100644 --- a/src/client_server/push.rs +++ b/src/client_server/push.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use super::State; use crate::{ConduitResult, Database, Error, Ruma}; use ruma::{ @@ -22,7 +24,7 @@ )] #[tracing::instrument(skip(db, body))] pub async fn get_pushrules_all_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_pushrules_all::Request>, ) -> ConduitResult<get_pushrules_all::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -47,7 +49,7 @@ pub async fn get_pushrules_all_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_pushrule_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_pushrule::Request<'_>>, ) -> ConduitResult<get_pushrule::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -101,7 +103,7 @@ pub async fn get_pushrule_route( )] #[tracing::instrument(skip(db, req))] pub async fn set_pushrule_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, req: Ruma<set_pushrule::Request<'_>>, ) -> ConduitResult<set_pushrule::Response> { let sender_user = req.sender_user.as_ref().expect("user is authenticated"); @@ -204,7 +206,7 @@ pub async fn set_pushrule_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_pushrule_actions_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_pushrule_actions::Request<'_>>, ) -> ConduitResult<get_pushrule_actions::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -263,7 +265,7 @@ pub async fn get_pushrule_actions_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_pushrule_actions_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<set_pushrule_actions::Request<'_>>, ) -> ConduitResult<set_pushrule_actions::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -337,7 +339,7 @@ pub async fn set_pushrule_actions_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_pushrule_enabled_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_pushrule_enabled::Request<'_>>, ) -> ConduitResult<get_pushrule_enabled::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -398,7 +400,7 @@ pub async fn get_pushrule_enabled_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_pushrule_enabled_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<set_pushrule_enabled::Request<'_>>, ) -> ConduitResult<set_pushrule_enabled::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -477,7 +479,7 @@ pub async fn set_pushrule_enabled_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_pushrule_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<delete_pushrule::Request<'_>>, ) -> ConduitResult<delete_pushrule::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -546,7 +548,7 @@ pub async fn delete_pushrule_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_pushers_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_pushers::Request>, ) -> ConduitResult<get_pushers::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -563,7 +565,7 @@ pub async fn get_pushers_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_pushers_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<set_pusher::Request>, ) -> ConduitResult<set_pusher::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/read_marker.rs b/src/client_server/read_marker.rs index 1b7ea0b57e41bf9507e3ada52e818b9fdf95b051..837170ff27353c1b37ce47034633bb2832f146e1 100644 --- a/src/client_server/read_marker.rs +++ b/src/client_server/read_marker.rs @@ -12,7 +12,7 @@ #[cfg(feature = "conduit_bin")] use rocket::post; -use std::collections::BTreeMap; +use std::{collections::BTreeMap, sync::Arc}; #[cfg_attr( feature = "conduit_bin", @@ -20,7 +20,7 @@ )] #[tracing::instrument(skip(db, body))] pub async fn set_read_marker_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<set_read_marker::Request<'_>>, ) -> ConduitResult<set_read_marker::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -87,7 +87,7 @@ pub async fn set_read_marker_route( )] #[tracing::instrument(skip(db, body))] pub async fn create_receipt_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<create_receipt::Request<'_>>, ) -> ConduitResult<create_receipt::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/redact.rs b/src/client_server/redact.rs index be5d3b1179d1634986a3ac99648f296ffb706e90..e19308239e50a3dba8e4f24e0c8c4ec0a2facb1f 100644 --- a/src/client_server/redact.rs +++ b/src/client_server/redact.rs @@ -4,6 +4,7 @@ api::client::r0::redact::redact_event, events::{room::redaction, EventType}, }; +use std::sync::Arc; #[cfg(feature = "conduit_bin")] use rocket::put; @@ -14,7 +15,7 @@ )] #[tracing::instrument(skip(db, body))] pub async fn redact_event_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<redact_event::Request<'_>>, ) -> ConduitResult<redact_event::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/room.rs b/src/client_server/room.rs index 0bc67d4d9e11c15428266f515b534b40fbb77889..7c507750c2e8e22739277c0ec28453149a4f192b 100644 --- a/src/client_server/room.rs +++ b/src/client_server/room.rs @@ -13,7 +13,7 @@ serde::Raw, RoomAliasId, RoomId, RoomVersionId, }; -use std::{cmp::max, collections::BTreeMap, convert::TryFrom}; +use std::{cmp::max, collections::BTreeMap, convert::TryFrom, sync::Arc}; #[cfg(feature = "conduit_bin")] use rocket::{get, post}; @@ -24,7 +24,7 @@ )] #[tracing::instrument(skip(db, body))] pub async fn create_room_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<create_room::Request<'_>>, ) -> ConduitResult<create_room::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -304,7 +304,7 @@ pub async fn create_room_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_room_event_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_room_event::Request<'_>>, ) -> ConduitResult<get_room_event::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -332,7 +332,7 @@ pub async fn get_room_event_route( )] #[tracing::instrument(skip(db, body))] pub async fn upgrade_room_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<upgrade_room::Request<'_>>, _room_id: String, ) -> ConduitResult<upgrade_room::Response> { diff --git a/src/client_server/search.rs b/src/client_server/search.rs index a668a0d0b667f990d85ebf751e51012d892ff276..ef5ddc2ffecf41ac45cd4a84c3a071c9d8644986 100644 --- a/src/client_server/search.rs +++ b/src/client_server/search.rs @@ -1,6 +1,7 @@ use super::State; use crate::{ConduitResult, Database, Error, Ruma}; use ruma::api::client::{error::ErrorKind, r0::search::search_events}; +use std::sync::Arc; #[cfg(feature = "conduit_bin")] use rocket::post; @@ -13,7 +14,7 @@ )] #[tracing::instrument(skip(db, body))] pub async fn search_events_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<search_events::Request<'_>>, ) -> ConduitResult<search_events::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/session.rs b/src/client_server/session.rs index 3718003af9b96e62c2adf802d7052515b253a1ff..9a75ae287647f39c35e433020f02b37f8e3fd083 100644 --- a/src/client_server/session.rs +++ b/src/client_server/session.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use super::{State, DEVICE_ID_LENGTH, TOKEN_LENGTH}; use crate::{utils, ConduitResult, Database, Error, Ruma}; use log::info; @@ -50,7 +52,7 @@ pub async fn get_login_types_route() -> ConduitResult<get_login_types::Response> )] #[tracing::instrument(skip(db, body))] pub async fn login_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<login::Request<'_>>, ) -> ConduitResult<login::Response> { // Validate login method @@ -167,7 +169,7 @@ pub async fn login_route( )] #[tracing::instrument(skip(db, body))] pub async fn logout_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<logout::Request>, ) -> ConduitResult<logout::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -195,7 +197,7 @@ pub async fn logout_route( )] #[tracing::instrument(skip(db, body))] pub async fn logout_all_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<logout_all::Request>, ) -> ConduitResult<logout_all::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/state.rs b/src/client_server/state.rs index 68e0c7f1971431723380f4634a5ac3cbb96444ae..c431ac0d18cabc8573d478543bde13003f844966 100644 --- a/src/client_server/state.rs +++ b/src/client_server/state.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use super::State; use crate::{pdu::PduBuilder, ConduitResult, Database, Error, Result, Ruma}; use ruma::{ @@ -25,7 +27,7 @@ )] #[tracing::instrument(skip(db, body))] pub async fn send_state_event_for_key_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<send_state_event::Request<'_>>, ) -> ConduitResult<send_state_event::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -51,7 +53,7 @@ pub async fn send_state_event_for_key_route( )] #[tracing::instrument(skip(db, body))] pub async fn send_state_event_for_empty_key_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<send_state_event::Request<'_>>, ) -> ConduitResult<send_state_event::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -77,7 +79,7 @@ pub async fn send_state_event_for_empty_key_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_state_events_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_state_events::Request<'_>>, ) -> ConduitResult<get_state_events::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -124,7 +126,7 @@ pub async fn get_state_events_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_state_events_for_key_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_state_events_for_key::Request<'_>>, ) -> ConduitResult<get_state_events_for_key::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -175,7 +177,7 @@ pub async fn get_state_events_for_key_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_state_events_for_empty_key_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_state_events_for_key::Request<'_>>, ) -> ConduitResult<get_state_events_for_key::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index 63ad590dc09d7f2f24d403365182d5920a54f097..2f28706a7eb4c8d1bb01ee39a97f56277e8aa049 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -1,5 +1,5 @@ use super::State; -use crate::{ConduitResult, Database, Error, Ruma}; +use crate::{ConduitResult, Database, Error, Result, Ruma}; use log::error; use ruma::{ api::client::r0::sync::sync_events, @@ -13,6 +13,7 @@ use std::{ collections::{hash_map, BTreeMap, HashMap, HashSet}, convert::{TryFrom, TryInto}, + sync::Arc, time::Duration, }; @@ -33,7 +34,7 @@ )] #[tracing::instrument(skip(db, body))] pub async fn sync_events_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<sync_events::Request<'_>>, ) -> ConduitResult<sync_events::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -71,18 +72,23 @@ pub async fn sync_events_route( let mut non_timeline_pdus = db .rooms - .pdus_since(&sender_user, &room_id, since)? + .pdus_until(&sender_user, &room_id, u64::MAX) .filter_map(|r| { + // Filter out buggy events if r.is_err() { error!("Bad pdu in pdus_since: {:?}", r); } r.ok() - }); // Filter out buggy events + }) + .take_while(|(pduid, _)| { + db.rooms + .pdu_count(pduid) + .map_or(false, |count| count > since) + }); // Take the last 10 events for the timeline let timeline_pdus = non_timeline_pdus .by_ref() - .rev() .take(10) .collect::<Vec<_>>() .into_iter() @@ -226,7 +232,7 @@ pub async fn sync_events_route( match (since_membership, current_membership) { (MembershipState::Leave, MembershipState::Join) => { // A new user joined an encrypted room - if !share_encrypted_room(&db, &sender_user, &user_id, &room_id) { + if !share_encrypted_room(&db, &sender_user, &user_id, &room_id)? { device_list_updates.insert(user_id); } } @@ -257,6 +263,7 @@ pub async fn sync_events_route( .filter(|user_id| { // Only send keys if the sender doesn't share an encrypted room with the target already !share_encrypted_room(&db, sender_user, user_id, &room_id) + .unwrap_or(false) }), ); } @@ -274,7 +281,7 @@ pub async fn sync_events_route( for hero in db .rooms - .all_pdus(&sender_user, &room_id)? + .all_pdus(&sender_user, &room_id) .filter_map(|pdu| pdu.ok()) // Ignore all broken pdus .filter(|(_, pdu)| pdu.kind == EventType::RoomMember) .map(|(_, pdu)| { @@ -411,7 +418,7 @@ pub async fn sync_events_route( let mut edus = db .rooms .edus - .readreceipts_since(&room_id, since)? + .readreceipts_since(&room_id, since) .filter_map(|r| r.ok()) // Filter out buggy events .map(|(_, _, v)| v) .collect::<Vec<_>>(); @@ -549,7 +556,7 @@ pub async fn sync_events_route( for user_id in left_encrypted_users { let still_share_encrypted_room = db .rooms - .get_shared_rooms(vec![sender_user.clone(), user_id.clone()]) + .get_shared_rooms(vec![sender_user.clone(), user_id.clone()])? .filter_map(|r| r.ok()) .filter_map(|other_room_id| { Some( @@ -639,9 +646,10 @@ fn share_encrypted_room( sender_user: &UserId, user_id: &UserId, ignore_room: &RoomId, -) -> bool { - db.rooms - .get_shared_rooms(vec![sender_user.clone(), user_id.clone()]) +) -> Result<bool> { + Ok(db + .rooms + .get_shared_rooms(vec![sender_user.clone(), user_id.clone()])? .filter_map(|r| r.ok()) .filter(|room_id| room_id != ignore_room) .filter_map(|other_room_id| { @@ -652,5 +660,5 @@ fn share_encrypted_room( .is_some(), ) }) - .any(|encrypted| encrypted) + .any(|encrypted| encrypted)) } diff --git a/src/client_server/tag.rs b/src/client_server/tag.rs index 63e70ffbbacf3cac1f2ca0b672a50e9974090e3d..2382fe0a05312489102c9b341a9ce1e0bb99439f 100644 --- a/src/client_server/tag.rs +++ b/src/client_server/tag.rs @@ -4,7 +4,7 @@ api::client::r0::tag::{create_tag, delete_tag, get_tags}, events::EventType, }; -use std::collections::BTreeMap; +use std::{collections::BTreeMap, sync::Arc}; #[cfg(feature = "conduit_bin")] use rocket::{delete, get, put}; @@ -15,7 +15,7 @@ )] #[tracing::instrument(skip(db, body))] pub async fn update_tag_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<create_tag::Request<'_>>, ) -> ConduitResult<create_tag::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -52,7 +52,7 @@ pub async fn update_tag_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_tag_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<delete_tag::Request<'_>>, ) -> ConduitResult<delete_tag::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -86,7 +86,7 @@ pub async fn delete_tag_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_tags_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_tags::Request<'_>>, ) -> ConduitResult<get_tags::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/to_device.rs b/src/client_server/to_device.rs index 460bd05720b69f1dcd3475c176a7779a7b68b209..f2a97abdbbe9a80898548db7b07ddc243ae57607 100644 --- a/src/client_server/to_device.rs +++ b/src/client_server/to_device.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use super::State; use crate::{ConduitResult, Database, Error, Ruma}; use ruma::api::client::{ @@ -14,7 +16,7 @@ )] #[tracing::instrument(skip(db, body))] pub async fn send_event_to_device_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<send_event_to_device::Request<'_>>, ) -> ConduitResult<send_event_to_device::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/typing.rs b/src/client_server/typing.rs index 4b7feb7f01af70ed5272beb633afbee4be513d44..a0a5d430db8e71f41a3f18e384dad53852281e39 100644 --- a/src/client_server/typing.rs +++ b/src/client_server/typing.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use super::State; use crate::{utils, ConduitResult, Database, Ruma}; use create_typing_event::Typing; @@ -12,7 +14,7 @@ )] #[tracing::instrument(skip(db, body))] pub fn create_typing_event_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<create_typing_event::Request<'_>>, ) -> ConduitResult<create_typing_event::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/user_directory.rs b/src/client_server/user_directory.rs index b358274626ad26d79e02d6cd12b311b556d25b88..0ddc7e85bb7a7e90507b187d3009ffdf2ca98aa5 100644 --- a/src/client_server/user_directory.rs +++ b/src/client_server/user_directory.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use super::State; use crate::{ConduitResult, Database, Ruma}; use ruma::api::client::r0::user_directory::search_users; @@ -11,7 +13,7 @@ )] #[tracing::instrument(skip(db, body))] pub async fn search_users_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<search_users::Request<'_>>, ) -> ConduitResult<search_users::Response> { let limit = u64::from(body.limit) as usize; diff --git a/src/database.rs b/src/database.rs index 7a55b030e412513d767424887319aa500f872824..e00bdcd521b1e7a36362057fa21cb9eb4260f4fc 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,3 +1,5 @@ +pub mod abstraction; + pub mod account_data; pub mod admin; pub mod appservice; @@ -12,15 +14,16 @@ pub mod users; use crate::{utils, Error, Result}; +use abstraction::DatabaseEngine; use directories::ProjectDirs; -use futures::StreamExt; -use log::{error, info}; -use rocket::futures::{self, channel::mpsc}; +use log::error; +use rocket::futures::{channel::mpsc, stream::FuturesUnordered, StreamExt}; use ruma::{DeviceId, ServerName, UserId}; use serde::Deserialize; use std::{ collections::HashMap, - fs::remove_dir_all, + fs::{self, remove_dir_all}, + io::Write, sync::{Arc, RwLock}, }; use tokio::sync::Semaphore; @@ -74,7 +77,12 @@ fn default_log() -> String { "info,state_res=warn,rocket=off,_=off,sled=off".to_owned() } -#[derive(Clone)] +#[cfg(feature = "sled")] +pub type Engine = abstraction::SledEngine; + +#[cfg(feature = "rocksdb")] +pub type Engine = abstraction::RocksDbEngine; + pub struct Database { pub globals: globals::Globals, pub users: users::Users, @@ -88,7 +96,6 @@ pub struct Database { pub admin: admin::Admin, pub appservice: appservice::Appservice, pub pusher: pusher::PushData, - pub _db: sled::Db, } impl Database { @@ -105,126 +112,126 @@ pub fn try_remove(server_name: &str) -> Result<()> { } /// Load an existing database or create a new one. - pub async fn load_or_create(config: Config) -> Result<Self> { - let db = sled::Config::default() - .path(&config.database_path) - .cache_capacity(config.cache_capacity as u64) - .use_compression(true) - .open()?; + pub async fn load_or_create(config: Config) -> Result<Arc<Self>> { + let builder = Engine::open(&config)?; if config.max_request_size < 1024 { eprintln!("ERROR: Max request size is less than 1KB. Please increase it."); } let (admin_sender, admin_receiver) = mpsc::unbounded(); + let (sending_sender, sending_receiver) = mpsc::unbounded(); - let db = Self { + let db = Arc::new(Self { users: users::Users { - userid_password: db.open_tree("userid_password")?, - userid_displayname: db.open_tree("userid_displayname")?, - userid_avatarurl: db.open_tree("userid_avatarurl")?, - userdeviceid_token: db.open_tree("userdeviceid_token")?, - userdeviceid_metadata: db.open_tree("userdeviceid_metadata")?, - userid_devicelistversion: db.open_tree("userid_devicelistversion")?, - token_userdeviceid: db.open_tree("token_userdeviceid")?, - onetimekeyid_onetimekeys: db.open_tree("onetimekeyid_onetimekeys")?, - userid_lastonetimekeyupdate: db.open_tree("userid_lastonetimekeyupdate")?, - keychangeid_userid: db.open_tree("keychangeid_userid")?, - keyid_key: db.open_tree("keyid_key")?, - userid_masterkeyid: db.open_tree("userid_masterkeyid")?, - userid_selfsigningkeyid: db.open_tree("userid_selfsigningkeyid")?, - userid_usersigningkeyid: db.open_tree("userid_usersigningkeyid")?, - todeviceid_events: db.open_tree("todeviceid_events")?, + userid_password: builder.open_tree("userid_password")?, + userid_displayname: builder.open_tree("userid_displayname")?, + userid_avatarurl: builder.open_tree("userid_avatarurl")?, + userdeviceid_token: builder.open_tree("userdeviceid_token")?, + userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?, + userid_devicelistversion: builder.open_tree("userid_devicelistversion")?, + token_userdeviceid: builder.open_tree("token_userdeviceid")?, + onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?, + userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?, + keychangeid_userid: builder.open_tree("keychangeid_userid")?, + keyid_key: builder.open_tree("keyid_key")?, + userid_masterkeyid: builder.open_tree("userid_masterkeyid")?, + userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?, + userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?, + todeviceid_events: builder.open_tree("todeviceid_events")?, }, uiaa: uiaa::Uiaa { - userdevicesessionid_uiaainfo: db.open_tree("userdevicesessionid_uiaainfo")?, - userdevicesessionid_uiaarequest: db.open_tree("userdevicesessionid_uiaarequest")?, + userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?, + userdevicesessionid_uiaarequest: builder + .open_tree("userdevicesessionid_uiaarequest")?, }, rooms: rooms::Rooms { edus: rooms::RoomEdus { - readreceiptid_readreceipt: db.open_tree("readreceiptid_readreceipt")?, - roomuserid_privateread: db.open_tree("roomuserid_privateread")?, // "Private" read receipt - roomuserid_lastprivatereadupdate: db + readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?, + roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, // "Private" read receipt + roomuserid_lastprivatereadupdate: builder .open_tree("roomuserid_lastprivatereadupdate")?, - typingid_userid: db.open_tree("typingid_userid")?, - roomid_lasttypingupdate: db.open_tree("roomid_lasttypingupdate")?, - presenceid_presence: db.open_tree("presenceid_presence")?, - userid_lastpresenceupdate: db.open_tree("userid_lastpresenceupdate")?, + typingid_userid: builder.open_tree("typingid_userid")?, + roomid_lasttypingupdate: builder.open_tree("roomid_lasttypingupdate")?, + presenceid_presence: builder.open_tree("presenceid_presence")?, + userid_lastpresenceupdate: builder.open_tree("userid_lastpresenceupdate")?, }, - pduid_pdu: db.open_tree("pduid_pdu")?, - eventid_pduid: db.open_tree("eventid_pduid")?, - roomid_pduleaves: db.open_tree("roomid_pduleaves")?, - - alias_roomid: db.open_tree("alias_roomid")?, - aliasid_alias: db.open_tree("aliasid_alias")?, - publicroomids: db.open_tree("publicroomids")?, - - tokenids: db.open_tree("tokenids")?, - - roomserverids: db.open_tree("roomserverids")?, - serverroomids: db.open_tree("serverroomids")?, - userroomid_joined: db.open_tree("userroomid_joined")?, - roomuserid_joined: db.open_tree("roomuserid_joined")?, - roomuseroncejoinedids: db.open_tree("roomuseroncejoinedids")?, - userroomid_invitestate: db.open_tree("userroomid_invitestate")?, - roomuserid_invitecount: db.open_tree("roomuserid_invitecount")?, - userroomid_leftstate: db.open_tree("userroomid_leftstate")?, - roomuserid_leftcount: db.open_tree("roomuserid_leftcount")?, - - userroomid_notificationcount: db.open_tree("userroomid_notificationcount")?, - userroomid_highlightcount: db.open_tree("userroomid_highlightcount")?, - - statekey_shortstatekey: db.open_tree("statekey_shortstatekey")?, - stateid_shorteventid: db.open_tree("stateid_shorteventid")?, - eventid_shorteventid: db.open_tree("eventid_shorteventid")?, - shorteventid_eventid: db.open_tree("shorteventid_eventid")?, - shorteventid_shortstatehash: db.open_tree("shorteventid_shortstatehash")?, - roomid_shortstatehash: db.open_tree("roomid_shortstatehash")?, - statehash_shortstatehash: db.open_tree("statehash_shortstatehash")?, - - eventid_outlierpdu: db.open_tree("eventid_outlierpdu")?, - prevevent_parent: db.open_tree("prevevent_parent")?, + pduid_pdu: builder.open_tree("pduid_pdu")?, + eventid_pduid: builder.open_tree("eventid_pduid")?, + roomid_pduleaves: builder.open_tree("roomid_pduleaves")?, + + alias_roomid: builder.open_tree("alias_roomid")?, + aliasid_alias: builder.open_tree("aliasid_alias")?, + publicroomids: builder.open_tree("publicroomids")?, + + tokenids: builder.open_tree("tokenids")?, + + roomserverids: builder.open_tree("roomserverids")?, + serverroomids: builder.open_tree("serverroomids")?, + userroomid_joined: builder.open_tree("userroomid_joined")?, + roomuserid_joined: builder.open_tree("roomuserid_joined")?, + roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?, + userroomid_invitestate: builder.open_tree("userroomid_invitestate")?, + roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?, + userroomid_leftstate: builder.open_tree("userroomid_leftstate")?, + roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?, + + userroomid_notificationcount: builder.open_tree("userroomid_notificationcount")?, + userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?, + + statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?, + stateid_shorteventid: builder.open_tree("stateid_shorteventid")?, + eventid_shorteventid: builder.open_tree("eventid_shorteventid")?, + shorteventid_eventid: builder.open_tree("shorteventid_eventid")?, + shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash")?, + roomid_shortstatehash: builder.open_tree("roomid_shortstatehash")?, + statehash_shortstatehash: builder.open_tree("statehash_shortstatehash")?, + + eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, + prevevent_parent: builder.open_tree("prevevent_parent")?, }, account_data: account_data::AccountData { - roomuserdataid_accountdata: db.open_tree("roomuserdataid_accountdata")?, + roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, }, media: media::Media { - mediaid_file: db.open_tree("mediaid_file")?, + mediaid_file: builder.open_tree("mediaid_file")?, }, key_backups: key_backups::KeyBackups { - backupid_algorithm: db.open_tree("backupid_algorithm")?, - backupid_etag: db.open_tree("backupid_etag")?, - backupkeyid_backup: db.open_tree("backupkeyid_backup")?, + backupid_algorithm: builder.open_tree("backupid_algorithm")?, + backupid_etag: builder.open_tree("backupid_etag")?, + backupkeyid_backup: builder.open_tree("backupkeyid_backup")?, }, transaction_ids: transaction_ids::TransactionIds { - userdevicetxnid_response: db.open_tree("userdevicetxnid_response")?, + userdevicetxnid_response: builder.open_tree("userdevicetxnid_response")?, }, sending: sending::Sending { - servername_educount: db.open_tree("servername_educount")?, - servernamepduids: db.open_tree("servernamepduids")?, - servercurrentevents: db.open_tree("servercurrentevents")?, + servername_educount: builder.open_tree("servername_educount")?, + servernamepduids: builder.open_tree("servernamepduids")?, + servercurrentevents: builder.open_tree("servercurrentevents")?, maximum_requests: Arc::new(Semaphore::new(config.max_concurrent_requests as usize)), + sender: sending_sender, }, admin: admin::Admin { sender: admin_sender, }, appservice: appservice::Appservice { cached_registrations: Arc::new(RwLock::new(HashMap::new())), - id_appserviceregistrations: db.open_tree("id_appserviceregistrations")?, + id_appserviceregistrations: builder.open_tree("id_appserviceregistrations")?, + }, + pusher: pusher::PushData { + senderkey_pusher: builder.open_tree("senderkey_pusher")?, }, - pusher: pusher::PushData::new(&db)?, globals: globals::Globals::load( - db.open_tree("global")?, - db.open_tree("server_signingkeys")?, + builder.open_tree("global")?, + builder.open_tree("server_signingkeys")?, config, )?, - _db: db, - }; + }); // MIGRATIONS + // TODO: database versions of new dbs should probably not be 0 if db.globals.database_version()? < 1 { - for roomserverid in db.rooms.roomserverids.iter().keys() { - let roomserverid = roomserverid?; + for (roomserverid, _) in db.rooms.roomserverids.iter() { let mut parts = roomserverid.split(|&b| b == 0xff); let room_id = parts.next().expect("split always returns one element"); let servername = match parts.next() { @@ -238,37 +245,55 @@ pub async fn load_or_create(config: Config) -> Result<Self> { serverroomid.push(0xff); serverroomid.extend_from_slice(room_id); - db.rooms.serverroomids.insert(serverroomid, &[])?; + db.rooms.serverroomids.insert(&serverroomid, &[])?; } db.globals.bump_database_version(1)?; - info!("Migration: 0 -> 1 finished"); + println!("Migration: 0 -> 1 finished"); } if db.globals.database_version()? < 2 { // We accidentally inserted hashed versions of "" into the db instead of just "" - for userid_password in db.users.userid_password.iter() { - let (userid, password) = userid_password?; - + for (userid, password) in db.users.userid_password.iter() { let password = utils::string_from_bytes(&password); - if password.map_or(false, |password| { + let empty_hashed_password = password.map_or(false, |password| { argon2::verify_encoded(&password, b"").unwrap_or(false) - }) { - db.users.userid_password.insert(userid, b"")?; + }); + + if empty_hashed_password { + db.users.userid_password.insert(&userid, b"")?; } } db.globals.bump_database_version(2)?; - info!("Migration: 1 -> 2 finished"); + println!("Migration: 1 -> 2 finished"); } + if db.globals.database_version()? < 3 { + // Move media to filesystem + for (key, content) in db.media.mediaid_file.iter() { + if content.len() == 0 { + continue; + } + + let path = db.globals.get_media_file(&key); + let mut file = fs::File::create(path)?; + file.write_all(&content)?; + db.media.mediaid_file.insert(&key, &[])?; + } + + db.globals.bump_database_version(3)?; + + println!("Migration: 2 -> 3 finished"); + } // This data is probably outdated db.rooms.edus.presenceid_presence.clear()?; - db.admin.start_handler(db.clone(), admin_receiver); + db.admin.start_handler(Arc::clone(&db), admin_receiver); + db.sending.start_handler(Arc::clone(&db), sending_receiver); Ok(db) } @@ -282,7 +307,7 @@ pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) { userdeviceid_prefix.extend_from_slice(device_id.as_bytes()); userdeviceid_prefix.push(0xff); - let mut futures = futures::stream::FuturesUnordered::new(); + let mut futures = FuturesUnordered::new(); // Return when *any* user changed his key // TODO: only send for user they share a room with diff --git a/src/database/abstraction.rs b/src/database/abstraction.rs new file mode 100644 index 0000000000000000000000000000000000000000..f81c9def2c306aeaf530182f8a45e7588b5e5052 --- /dev/null +++ b/src/database/abstraction.rs @@ -0,0 +1,329 @@ +use super::Config; +use crate::{utils, Result}; +use log::warn; +use std::{future::Future, pin::Pin, sync::Arc}; + +#[cfg(feature = "rocksdb")] +use std::{collections::BTreeMap, sync::RwLock}; + +#[cfg(feature = "sled")] +pub struct SledEngine(sled::Db); +#[cfg(feature = "sled")] +pub struct SledEngineTree(sled::Tree); + +#[cfg(feature = "rocksdb")] +pub struct RocksDbEngine(rocksdb::DBWithThreadMode<rocksdb::MultiThreaded>); +#[cfg(feature = "rocksdb")] +pub struct RocksDbEngineTree<'a> { + db: Arc<RocksDbEngine>, + name: &'a str, + watchers: RwLock<BTreeMap<Vec<u8>, Vec<tokio::sync::oneshot::Sender<()>>>>, +} + +pub trait DatabaseEngine: Sized { + fn open(config: &Config) -> Result<Arc<Self>>; + fn open_tree(self: &Arc<Self>, name: &'static str) -> Result<Arc<dyn Tree>>; +} + +pub trait Tree: Send + Sync { + fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>; + + fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>; + + fn remove(&self, key: &[u8]) -> Result<()>; + + fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + Send + Sync + 'a>; + + fn iter_from<'a>( + &'a self, + from: &[u8], + backwards: bool, + ) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + 'a>; + + fn increment(&self, key: &[u8]) -> Result<Vec<u8>>; + + fn scan_prefix<'a>( + &'a self, + prefix: Vec<u8>, + ) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + Send + 'a>; + + fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>; + + fn clear(&self) -> Result<()> { + for (key, _) in self.iter() { + self.remove(&key)?; + } + + Ok(()) + } +} + +#[cfg(feature = "sled")] +impl DatabaseEngine for SledEngine { + fn open(config: &Config) -> Result<Arc<Self>> { + Ok(Arc::new(SledEngine( + sled::Config::default() + .path(&config.database_path) + .cache_capacity(config.cache_capacity as u64) + .use_compression(true) + .open()?, + ))) + } + + fn open_tree(self: &Arc<Self>, name: &'static str) -> Result<Arc<dyn Tree>> { + Ok(Arc::new(SledEngineTree(self.0.open_tree(name)?))) + } +} + +#[cfg(feature = "sled")] +impl Tree for SledEngineTree { + fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> { + Ok(self.0.get(key)?.map(|v| v.to_vec())) + } + + fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { + self.0.insert(key, value)?; + Ok(()) + } + + fn remove(&self, key: &[u8]) -> Result<()> { + self.0.remove(key)?; + Ok(()) + } + + fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + Send + Sync + 'a> { + Box::new( + self.0 + .iter() + .filter_map(|r| { + if let Err(e) = &r { + warn!("Error: {}", e); + } + r.ok() + }) + .map(|(k, v)| (k.to_vec().into(), v.to_vec().into())), + ) + } + + fn iter_from( + &self, + from: &[u8], + backwards: bool, + ) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)>> { + let iter = if backwards { + self.0.range(..from) + } else { + self.0.range(from..) + }; + + let iter = iter + .filter_map(|r| { + if let Err(e) = &r { + warn!("Error: {}", e); + } + r.ok() + }) + .map(|(k, v)| (k.to_vec().into(), v.to_vec().into())); + + if backwards { + Box::new(iter.rev()) + } else { + Box::new(iter) + } + } + + fn increment(&self, key: &[u8]) -> Result<Vec<u8>> { + Ok(self + .0 + .update_and_fetch(key, utils::increment) + .map(|o| o.expect("increment always sets a value").to_vec())?) + } + + fn scan_prefix<'a>( + &'a self, + prefix: Vec<u8>, + ) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + Send + 'a> { + let iter = self + .0 + .scan_prefix(prefix) + .filter_map(|r| { + if let Err(e) = &r { + warn!("Error: {}", e); + } + r.ok() + }) + .map(|(k, v)| (k.to_vec().into(), v.to_vec().into())); + + Box::new(iter) + } + + fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> { + let prefix = prefix.to_vec(); + Box::pin(async move { + self.0.watch_prefix(prefix).await; + }) + } +} + +#[cfg(feature = "rocksdb")] +impl DatabaseEngine for RocksDbEngine { + fn open(config: &Config) -> Result<Arc<Self>> { + let mut db_opts = rocksdb::Options::default(); + db_opts.create_if_missing(true); + db_opts.set_max_open_files(16); + db_opts.set_compaction_style(rocksdb::DBCompactionStyle::Level); + db_opts.set_compression_type(rocksdb::DBCompressionType::Snappy); + db_opts.set_target_file_size_base(256 << 20); + db_opts.set_write_buffer_size(256 << 20); + + let mut block_based_options = rocksdb::BlockBasedOptions::default(); + block_based_options.set_block_size(512 << 10); + db_opts.set_block_based_table_factory(&block_based_options); + + let cfs = rocksdb::DBWithThreadMode::<rocksdb::MultiThreaded>::list_cf( + &db_opts, + &config.database_path, + ) + .unwrap_or_default(); + + let mut options = rocksdb::Options::default(); + options.set_merge_operator_associative("increment", utils::increment_rocksdb); + + let db = rocksdb::DBWithThreadMode::<rocksdb::MultiThreaded>::open_cf_descriptors( + &db_opts, + &config.database_path, + cfs.iter() + .map(|name| rocksdb::ColumnFamilyDescriptor::new(name, options.clone())), + )?; + + Ok(Arc::new(RocksDbEngine(db))) + } + + fn open_tree(self: &Arc<Self>, name: &'static str) -> Result<Arc<dyn Tree>> { + let mut options = rocksdb::Options::default(); + options.set_merge_operator_associative("increment", utils::increment_rocksdb); + + // Create if it doesn't exist + let _ = self.0.create_cf(name, &options); + + Ok(Arc::new(RocksDbEngineTree { + name, + db: Arc::clone(self), + watchers: RwLock::new(BTreeMap::new()), + })) + } +} + +#[cfg(feature = "rocksdb")] +impl RocksDbEngineTree<'_> { + fn cf(&self) -> rocksdb::BoundColumnFamily<'_> { + self.db.0.cf_handle(self.name).unwrap() + } +} + +#[cfg(feature = "rocksdb")] +impl Tree for RocksDbEngineTree<'_> { + fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> { + Ok(self.db.0.get_cf(self.cf(), key)?) + } + + fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { + let watchers = self.watchers.read().unwrap(); + let mut triggered = Vec::new(); + + for length in 0..=key.len() { + if watchers.contains_key(&key[..length]) { + triggered.push(&key[..length]); + } + } + + drop(watchers); + + if !triggered.is_empty() { + let mut watchers = self.watchers.write().unwrap(); + for prefix in triggered { + if let Some(txs) = watchers.remove(prefix) { + for tx in txs { + let _ = tx.send(()); + } + } + } + } + + Ok(self.db.0.put_cf(self.cf(), key, value)?) + } + + fn remove(&self, key: &[u8]) -> Result<()> { + Ok(self.db.0.delete_cf(self.cf(), key)?) + } + + fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + Send + Sync + 'a> { + Box::new( + self.db + .0 + .iterator_cf(self.cf(), rocksdb::IteratorMode::Start), + ) + } + + fn iter_from<'a>( + &'a self, + from: &[u8], + backwards: bool, + ) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + 'a> { + Box::new(self.db.0.iterator_cf( + self.cf(), + rocksdb::IteratorMode::From( + from, + if backwards { + rocksdb::Direction::Reverse + } else { + rocksdb::Direction::Forward + }, + ), + )) + } + + fn increment(&self, key: &[u8]) -> Result<Vec<u8>> { + let stats = rocksdb::perf::get_memory_usage_stats(Some(&[&self.db.0]), None).unwrap(); + dbg!(stats.mem_table_total); + dbg!(stats.mem_table_unflushed); + dbg!(stats.mem_table_readers_total); + dbg!(stats.cache_total); + // TODO: atomic? + let old = self.get(key)?; + let new = utils::increment(old.as_deref()).unwrap(); + self.insert(key, &new)?; + Ok(new) + } + + fn scan_prefix<'a>( + &'a self, + prefix: Vec<u8>, + ) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + Send + 'a> { + Box::new( + self.db + .0 + .iterator_cf( + self.cf(), + rocksdb::IteratorMode::From(&prefix, rocksdb::Direction::Forward), + ) + .take_while(move |(k, _)| k.starts_with(&prefix)), + ) + } + + fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> { + let (tx, rx) = tokio::sync::oneshot::channel(); + + self.watchers + .write() + .unwrap() + .entry(prefix.to_vec()) + .or_default() + .push(tx); + + Box::pin(async move { + // Tx is never destroyed + rx.await.unwrap(); + }) + } +} diff --git a/src/database/account_data.rs b/src/database/account_data.rs index bb970c3459e2ec859ff0120272e6e239934a8a2b..2ba7bc3dc6e4f415741549a7bed5026e6b0ee087 100644 --- a/src/database/account_data.rs +++ b/src/database/account_data.rs @@ -6,12 +6,12 @@ RoomId, UserId, }; use serde::{de::DeserializeOwned, Serialize}; -use sled::IVec; -use std::{collections::HashMap, convert::TryFrom}; +use std::{collections::HashMap, convert::TryFrom, sync::Arc}; + +use super::abstraction::Tree; -#[derive(Clone)] pub struct AccountData { - pub(super) roomuserdataid_accountdata: sled::Tree, // RoomUserDataId = Room + User + Count + Type + pub(super) roomuserdataid_accountdata: Arc<dyn Tree>, // RoomUserDataId = Room + User + Count + Type } impl AccountData { @@ -34,9 +34,8 @@ pub fn update<T: Serialize>( prefix.push(0xff); // Remove old entry - if let Some(previous) = self.find_event(room_id, user_id, &event_type) { - let (old_key, _) = previous?; - self.roomuserdataid_accountdata.remove(old_key)?; + if let Some((old_key, _)) = self.find_event(room_id, user_id, &event_type)? { + self.roomuserdataid_accountdata.remove(&old_key)?; } let mut key = prefix; @@ -52,8 +51,10 @@ pub fn update<T: Serialize>( )); } - self.roomuserdataid_accountdata - .insert(key, &*json.to_string())?; + self.roomuserdataid_accountdata.insert( + &key, + &serde_json::to_vec(&json).expect("to_vec always works on json values"), + )?; Ok(()) } @@ -65,9 +66,8 @@ pub fn get<T: DeserializeOwned>( user_id: &UserId, kind: EventType, ) -> Result<Option<T>> { - self.find_event(room_id, user_id, &kind) - .map(|r| { - let (_, v) = r?; + self.find_event(room_id, user_id, &kind)? + .map(|(_, v)| { serde_json::from_slice(&v).map_err(|_| Error::bad_database("could not deserialize")) }) .transpose() @@ -98,8 +98,7 @@ pub fn changes_since( for r in self .roomuserdataid_accountdata - .range(&*first_possible..) - .filter_map(|r| r.ok()) + .iter_from(&first_possible, false) .take_while(move |(k, _)| k.starts_with(&prefix)) .map(|(k, v)| { Ok::<_, Error>(( @@ -128,7 +127,7 @@ fn find_event( room_id: Option<&RoomId>, user_id: &UserId, kind: &EventType, - ) -> Option<Result<(IVec, IVec)>> { + ) -> Result<Option<(Box<[u8]>, Box<[u8]>)>> { let mut prefix = room_id .map(|r| r.to_string()) .unwrap_or_default() @@ -137,23 +136,21 @@ fn find_event( prefix.push(0xff); prefix.extend_from_slice(&user_id.as_bytes()); prefix.push(0xff); + + let mut last_possible_key = prefix.clone(); + last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); + let kind = kind.clone(); - self.roomuserdataid_accountdata - .scan_prefix(prefix) - .rev() - .find(move |r| { - r.as_ref() - .map(|(k, _)| { - k.rsplit(|&b| b == 0xff) - .next() - .map(|current_event_type| { - current_event_type == kind.as_ref().as_bytes() - }) - .unwrap_or(false) - }) + Ok(self + .roomuserdataid_accountdata + .iter_from(&last_possible_key, true) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .find(move |(k, _)| { + k.rsplit(|&b| b == 0xff) + .next() + .map(|current_event_type| current_event_type == kind.as_ref().as_bytes()) .unwrap_or(false) - }) - .map(|r| Ok(r?)) + })) } } diff --git a/src/database/admin.rs b/src/database/admin.rs index 30143859398d816bc28dc6c4e3168d3b007312eb..7826cfeacec86139969a851f7595b07ebc4705ff 100644 --- a/src/database/admin.rs +++ b/src/database/admin.rs @@ -1,6 +1,9 @@ -use std::convert::{TryFrom, TryInto}; +use std::{ + convert::{TryFrom, TryInto}, + sync::Arc, +}; -use crate::pdu::PduBuilder; +use crate::{pdu::PduBuilder, Database}; use log::warn; use rocket::futures::{channel::mpsc, stream::StreamExt}; use ruma::{ @@ -22,7 +25,7 @@ pub struct Admin { impl Admin { pub fn start_handler( &self, - db: super::Database, + db: Arc<Database>, mut receiver: mpsc::UnboundedReceiver<AdminCommand>, ) { tokio::spawn(async move { @@ -73,14 +76,17 @@ pub fn start_handler( db.appservice.register_appservice(yaml).unwrap(); // TODO handle error } AdminCommand::ListAppservices => { - let appservices = db.appservice.iter_ids().collect::<Vec<_>>(); - let count = appservices.len(); - let output = format!( - "Appservices ({}): {}", - count, - appservices.into_iter().filter_map(|r| r.ok()).collect::<Vec<_>>().join(", ") - ); - send_message(message::MessageEventContent::text_plain(output)); + if let Ok(appservices) = db.appservice.iter_ids().map(|ids| ids.collect::<Vec<_>>()) { + let count = appservices.len(); + let output = format!( + "Appservices ({}): {}", + count, + appservices.into_iter().filter_map(|r| r.ok()).collect::<Vec<_>>().join(", ") + ); + send_message(message::MessageEventContent::text_plain(output)); + } else { + send_message(message::MessageEventContent::text_plain("Failed to get appservices.")); + } } AdminCommand::SendMessage(message) => { send_message(message); @@ -93,6 +99,6 @@ pub fn start_handler( } pub fn send(&self, command: AdminCommand) { - self.sender.unbounded_send(command).unwrap() + self.sender.unbounded_send(command).unwrap(); } } diff --git a/src/database/appservice.rs b/src/database/appservice.rs index 222eb182831a9d2b5c250d4dd8a67e4533ad5254..21b18a765dc817b0998c021a1a8a6dddfda317cd 100644 --- a/src/database/appservice.rs +++ b/src/database/appservice.rs @@ -4,18 +4,21 @@ sync::{Arc, RwLock}, }; -#[derive(Clone)] +use super::abstraction::Tree; + pub struct Appservice { pub(super) cached_registrations: Arc<RwLock<HashMap<String, serde_yaml::Value>>>, - pub(super) id_appserviceregistrations: sled::Tree, + pub(super) id_appserviceregistrations: Arc<dyn Tree>, } impl Appservice { pub fn register_appservice(&self, yaml: serde_yaml::Value) -> Result<()> { // TODO: Rumaify let id = yaml.get("id").unwrap().as_str().unwrap(); - self.id_appserviceregistrations - .insert(id, serde_yaml::to_string(&yaml).unwrap().as_bytes())?; + self.id_appserviceregistrations.insert( + id.as_bytes(), + serde_yaml::to_string(&yaml).unwrap().as_bytes(), + )?; self.cached_registrations .write() .unwrap() @@ -33,7 +36,7 @@ pub fn get_registration(&self, id: &str) -> Result<Option<serde_yaml::Value>> { || { Ok(self .id_appserviceregistrations - .get(id)? + .get(id.as_bytes())? .map(|bytes| { Ok::<_, Error>(serde_yaml::from_slice(&bytes).map_err(|_| { Error::bad_database( @@ -47,21 +50,25 @@ pub fn get_registration(&self, id: &str) -> Result<Option<serde_yaml::Value>> { ) } - pub fn iter_ids(&self) -> impl Iterator<Item = Result<String>> { - self.id_appserviceregistrations.iter().keys().map(|id| { - Ok(utils::string_from_bytes(&id?).map_err(|_| { + pub fn iter_ids<'a>( + &'a self, + ) -> Result<impl Iterator<Item = Result<String>> + Send + Sync + 'a> { + Ok(self.id_appserviceregistrations.iter().map(|(id, _)| { + Ok(utils::string_from_bytes(&id).map_err(|_| { Error::bad_database("Invalid id bytes in id_appserviceregistrations.") })?) - }) + })) } - pub fn iter_all(&self) -> impl Iterator<Item = Result<(String, serde_yaml::Value)>> + '_ { - self.iter_ids().filter_map(|id| id.ok()).map(move |id| { + pub fn iter_all( + &self, + ) -> Result<impl Iterator<Item = Result<(String, serde_yaml::Value)>> + '_ + Send + Sync> { + Ok(self.iter_ids()?.filter_map(|id| id.ok()).map(move |id| { Ok(( id.clone(), self.get_registration(&id)? .expect("iter_ids only returns appservices that exist"), )) - }) + })) } } diff --git a/src/database/globals.rs b/src/database/globals.rs index 5d91d37413bc6f36f90c909a764f8b4e59285fbe..552564437fac79311f3e7a1492bfeea8572127f8 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -7,28 +7,31 @@ use rustls::{ServerCertVerifier, WebPKIVerifier}; use std::{ collections::{BTreeMap, HashMap}, + fs, + path::PathBuf, sync::{Arc, RwLock}, time::{Duration, Instant}, }; use tokio::sync::Semaphore; use trust_dns_resolver::TokioAsyncResolver; -pub const COUNTER: &str = "c"; +use super::abstraction::Tree; + +pub const COUNTER: &[u8] = b"c"; type WellKnownMap = HashMap<Box<ServerName>, (String, String)>; type TlsNameMap = HashMap<String, webpki::DNSName>; type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries -#[derive(Clone)] pub struct Globals { pub actual_destination_cache: Arc<RwLock<WellKnownMap>>, // actual_destination, host pub tls_name_override: Arc<RwLock<TlsNameMap>>, - pub(super) globals: sled::Tree, + pub(super) globals: Arc<dyn Tree>, config: Config, keypair: Arc<ruma::signatures::Ed25519KeyPair>, reqwest_client: reqwest::Client, dns_resolver: TokioAsyncResolver, jwt_decoding_key: Option<jsonwebtoken::DecodingKey<'static>>, - pub(super) server_signingkeys: sled::Tree, + pub(super) server_signingkeys: Arc<dyn Tree>, pub bad_event_ratelimiter: Arc<RwLock<BTreeMap<EventId, RateLimitState>>>, pub bad_signature_ratelimiter: Arc<RwLock<BTreeMap<Vec<String>, RateLimitState>>>, pub servername_ratelimiter: Arc<RwLock<BTreeMap<Box<ServerName>, Arc<Semaphore>>>>, @@ -69,15 +72,20 @@ fn verify_server_cert( impl Globals { pub fn load( - globals: sled::Tree, - server_signingkeys: sled::Tree, + globals: Arc<dyn Tree>, + server_signingkeys: Arc<dyn Tree>, config: Config, ) -> Result<Self> { - let bytes = &*globals - .update_and_fetch("keypair", utils::generate_keypair)? - .expect("utils::generate_keypair always returns Some"); + let keypair_bytes = globals.get(b"keypair")?.map_or_else( + || { + let keypair = utils::generate_keypair(); + globals.insert(b"keypair", &keypair)?; + Ok::<_, Error>(keypair) + }, + |s| Ok(s.to_vec()), + )?; - let mut parts = bytes.splitn(2, |&b| b == 0xff); + let mut parts = keypair_bytes.splitn(2, |&b| b == 0xff); let keypair = utils::string_from_bytes( // 1. version @@ -102,7 +110,7 @@ pub fn load( Ok(k) => k, Err(e) => { error!("Keypair invalid. Deleting..."); - globals.remove("keypair")?; + globals.remove(b"keypair")?; return Err(e); } }; @@ -130,7 +138,7 @@ pub fn load( .as_ref() .map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()).into_static()); - Ok(Self { + let s = Self { globals, config, keypair: Arc::new(keypair), @@ -145,7 +153,11 @@ pub fn load( bad_event_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), bad_signature_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), servername_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), - }) + }; + + fs::create_dir_all(s.get_media_folder())?; + + Ok(s) } /// Returns this server's keypair. @@ -159,13 +171,8 @@ pub fn reqwest_client(&self) -> &reqwest::Client { } pub fn next_count(&self) -> Result<u64> { - Ok(utils::u64_from_bytes( - &self - .globals - .update_and_fetch(COUNTER, utils::increment)? - .expect("utils::increment will always put in a value"), - ) - .map_err(|_| Error::bad_database("Count has invalid bytes."))?) + Ok(utils::u64_from_bytes(&self.globals.increment(COUNTER)?) + .map_err(|_| Error::bad_database("Count has invalid bytes."))?) } pub fn current_count(&self) -> Result<u64> { @@ -211,21 +218,30 @@ pub fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey<'_>> { /// Remove the outdated keys and insert the new ones. /// /// This doesn't actually check that the keys provided are newer than the old set. - pub fn add_signing_key(&self, origin: &ServerName, new_keys: &ServerSigningKeys) -> Result<()> { - self.server_signingkeys - .update_and_fetch(origin.as_bytes(), |signingkeys| { - let mut keys = signingkeys - .and_then(|keys| serde_json::from_slice(keys).ok()) - .unwrap_or_else(|| { - // Just insert "now", it doesn't matter - ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) - }); - keys.verify_keys - .extend(new_keys.verify_keys.clone().into_iter()); - keys.old_verify_keys - .extend(new_keys.old_verify_keys.clone().into_iter()); - Some(serde_json::to_vec(&keys).expect("serversigningkeys can be serialized")) - })?; + pub fn add_signing_key(&self, origin: &ServerName, new_keys: ServerSigningKeys) -> Result<()> { + // Not atomic, but this is not critical + let signingkeys = self.server_signingkeys.get(origin.as_bytes())?; + + let mut keys = signingkeys + .and_then(|keys| serde_json::from_slice(&keys).ok()) + .unwrap_or_else(|| { + // Just insert "now", it doesn't matter + ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) + }); + + let ServerSigningKeys { + verify_keys, + old_verify_keys, + .. + } = new_keys; + + keys.verify_keys.extend(verify_keys.into_iter()); + keys.old_verify_keys.extend(old_verify_keys.into_iter()); + + self.server_signingkeys.insert( + origin.as_bytes(), + &serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"), + )?; Ok(()) } @@ -254,14 +270,30 @@ pub fn signing_keys_for( } pub fn database_version(&self) -> Result<u64> { - self.globals.get("version")?.map_or(Ok(0), |version| { + self.globals.get(b"version")?.map_or(Ok(0), |version| { utils::u64_from_bytes(&version) .map_err(|_| Error::bad_database("Database version id is invalid.")) }) } pub fn bump_database_version(&self, new_version: u64) -> Result<()> { - self.globals.insert("version", &new_version.to_be_bytes())?; + self.globals + .insert(b"version", &new_version.to_be_bytes())?; Ok(()) } + + pub fn get_media_folder(&self) -> PathBuf { + let mut r = PathBuf::new(); + r.push(self.config.database_path.clone()); + r.push("media"); + r + } + + pub fn get_media_file(&self, key: &[u8]) -> PathBuf { + let mut r = PathBuf::new(); + r.push(self.config.database_path.clone()); + r.push("media"); + r.push(base64::encode_config(key, base64::URL_SAFE_NO_PAD)); + r + } } diff --git a/src/database/key_backups.rs b/src/database/key_backups.rs index 0f9af2eb9b01f45da4ae0771ef61a5f6c6216df3..0685c4820a7710e3de805530ae3eef6d0ffba906 100644 --- a/src/database/key_backups.rs +++ b/src/database/key_backups.rs @@ -6,13 +6,14 @@ }, RoomId, UserId, }; -use std::{collections::BTreeMap, convert::TryFrom}; +use std::{collections::BTreeMap, convert::TryFrom, sync::Arc}; + +use super::abstraction::Tree; -#[derive(Clone)] pub struct KeyBackups { - pub(super) backupid_algorithm: sled::Tree, // BackupId = UserId + Version(Count) - pub(super) backupid_etag: sled::Tree, // BackupId = UserId + Version(Count) - pub(super) backupkeyid_backup: sled::Tree, // BackupKeyId = UserId + Version + RoomId + SessionId + pub(super) backupid_algorithm: Arc<dyn Tree>, // BackupId = UserId + Version(Count) + pub(super) backupid_etag: Arc<dyn Tree>, // BackupId = UserId + Version(Count) + pub(super) backupkeyid_backup: Arc<dyn Tree>, // BackupKeyId = UserId + Version + RoomId + SessionId } impl KeyBackups { @@ -30,8 +31,7 @@ pub fn create_backup( self.backupid_algorithm.insert( &key, - &*serde_json::to_string(backup_metadata) - .expect("BackupAlgorithm::to_string always works"), + &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), )?; self.backupid_etag .insert(&key, &globals.next_count()?.to_be_bytes())?; @@ -48,13 +48,8 @@ pub fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { key.push(0xff); - for outdated_key in self - .backupkeyid_backup - .scan_prefix(&key) - .keys() - .filter_map(|r| r.ok()) - { - self.backupkeyid_backup.remove(outdated_key)?; + for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { + self.backupkeyid_backup.remove(&outdated_key)?; } Ok(()) @@ -80,8 +75,9 @@ pub fn update_backup( self.backupid_algorithm.insert( &key, - &*serde_json::to_string(backup_metadata) - .expect("BackupAlgorithm::to_string always works"), + &serde_json::to_string(backup_metadata) + .expect("BackupAlgorithm::to_string always works") + .as_bytes(), )?; self.backupid_etag .insert(&key, &globals.next_count()?.to_be_bytes())?; @@ -91,11 +87,14 @@ pub fn update_backup( pub fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, BackupAlgorithm)>> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xff); + let mut last_possible_key = prefix.clone(); + last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); + self.backupid_algorithm - .scan_prefix(&prefix) - .last() - .map_or(Ok(None), |r| { - let (key, value) = r?; + .iter_from(&last_possible_key, true) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .next() + .map_or(Ok(None), |(key, value)| { let version = utils::string_from_bytes( key.rsplit(|&b| b == 0xff) .next() @@ -117,10 +116,13 @@ pub fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Backu key.push(0xff); key.extend_from_slice(version.as_bytes()); - self.backupid_algorithm.get(key)?.map_or(Ok(None), |bytes| { - Ok(serde_json::from_slice(&bytes) - .map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))?) - }) + self.backupid_algorithm + .get(&key)? + .map_or(Ok(None), |bytes| { + Ok(serde_json::from_slice(&bytes).map_err(|_| { + Error::bad_database("Algorithm in backupid_algorithm is invalid.") + })?) + }) } pub fn add_key( @@ -153,7 +155,7 @@ pub fn add_key( self.backupkeyid_backup.insert( &key, - &*serde_json::to_string(&key_data).expect("KeyBackupData::to_string always works"), + &serde_json::to_vec(&key_data).expect("KeyBackupData::to_vec always works"), )?; Ok(()) @@ -164,7 +166,7 @@ pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> { prefix.push(0xff); prefix.extend_from_slice(version.as_bytes()); - Ok(self.backupkeyid_backup.scan_prefix(&prefix).count()) + Ok(self.backupkeyid_backup.scan_prefix(prefix).count()) } pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> { @@ -194,33 +196,37 @@ pub fn get_all( let mut rooms = BTreeMap::<RoomId, RoomKeyBackup>::new(); - for result in self.backupkeyid_backup.scan_prefix(&prefix).map(|r| { - let (key, value) = r?; - let mut parts = key.rsplit(|&b| b == 0xff); + for result in self + .backupkeyid_backup + .scan_prefix(prefix) + .map(|(key, value)| { + let mut parts = key.rsplit(|&b| b == 0xff); - let session_id = utils::string_from_bytes( - &parts - .next() - .ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, - ) - .map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?; + let session_id = + utils::string_from_bytes(&parts.next().ok_or_else(|| { + Error::bad_database("backupkeyid_backup key is invalid.") + })?) + .map_err(|_| { + Error::bad_database("backupkeyid_backup session_id is invalid.") + })?; - let room_id = RoomId::try_from( - utils::string_from_bytes( - &parts - .next() - .ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, + let room_id = RoomId::try_from( + utils::string_from_bytes(&parts.next().ok_or_else(|| { + Error::bad_database("backupkeyid_backup key is invalid.") + })?) + .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?, ) - .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?, - ) - .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid room id."))?; + .map_err(|_| { + Error::bad_database("backupkeyid_backup room_id is invalid room id.") + })?; - let key_data = serde_json::from_slice(&value).map_err(|_| { - Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.") - })?; + let key_data = serde_json::from_slice(&value).map_err(|_| { + Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.") + })?; - Ok::<_, Error>((room_id, session_id, key_data)) - }) { + Ok::<_, Error>((room_id, session_id, key_data)) + }) + { let (room_id, session_id, key_data) = result?; rooms .entry(room_id) @@ -239,7 +245,7 @@ pub fn get_room( user_id: &UserId, version: &str, room_id: &RoomId, - ) -> BTreeMap<String, KeyBackupData> { + ) -> Result<BTreeMap<String, KeyBackupData>> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xff); prefix.extend_from_slice(version.as_bytes()); @@ -247,10 +253,10 @@ pub fn get_room( prefix.extend_from_slice(room_id.as_bytes()); prefix.push(0xff); - self.backupkeyid_backup - .scan_prefix(&prefix) - .map(|r| { - let (key, value) = r?; + Ok(self + .backupkeyid_backup + .scan_prefix(prefix) + .map(|(key, value)| { let mut parts = key.rsplit(|&b| b == 0xff); let session_id = @@ -268,7 +274,7 @@ pub fn get_room( Ok::<_, Error>((session_id, key_data)) }) .filter_map(|r| r.ok()) - .collect() + .collect()) } pub fn get_session( @@ -302,13 +308,8 @@ pub fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { key.extend_from_slice(&version.as_bytes()); key.push(0xff); - for outdated_key in self - .backupkeyid_backup - .scan_prefix(&key) - .keys() - .filter_map(|r| r.ok()) - { - self.backupkeyid_backup.remove(outdated_key)?; + for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { + self.backupkeyid_backup.remove(&outdated_key)?; } Ok(()) @@ -327,13 +328,8 @@ pub fn delete_room_keys( key.extend_from_slice(&room_id.as_bytes()); key.push(0xff); - for outdated_key in self - .backupkeyid_backup - .scan_prefix(&key) - .keys() - .filter_map(|r| r.ok()) - { - self.backupkeyid_backup.remove(outdated_key)?; + for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { + self.backupkeyid_backup.remove(&outdated_key)?; } Ok(()) @@ -354,13 +350,8 @@ pub fn delete_room_key( key.push(0xff); key.extend_from_slice(&session_id.as_bytes()); - for outdated_key in self - .backupkeyid_backup - .scan_prefix(&key) - .keys() - .filter_map(|r| r.ok()) - { - self.backupkeyid_backup.remove(outdated_key)?; + for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { + self.backupkeyid_backup.remove(&outdated_key)?; } Ok(()) diff --git a/src/database/media.rs b/src/database/media.rs index 28ef88a2ef2fac6b2f7760265aaf60da776b634e..944c5bdce1124e1220e478a997880f43fa4a3a1e 100644 --- a/src/database/media.rs +++ b/src/database/media.rs @@ -1,7 +1,10 @@ +use crate::database::globals::Globals; use image::{imageops::FilterType, GenericImageView}; +use super::abstraction::Tree; use crate::{utils, Error, Result}; -use std::mem; +use std::{mem, sync::Arc}; +use tokio::{fs::File, io::AsyncReadExt, io::AsyncWriteExt}; pub struct FileMeta { pub content_disposition: Option<String>, @@ -9,16 +12,16 @@ pub struct FileMeta { pub file: Vec<u8>, } -#[derive(Clone)] pub struct Media { - pub(super) mediaid_file: sled::Tree, // MediaId = MXC + WidthHeight + ContentDisposition + ContentType + pub(super) mediaid_file: Arc<dyn Tree>, // MediaId = MXC + WidthHeight + ContentDisposition + ContentType } impl Media { - /// Uploads or replaces a file. - pub fn create( + /// Uploads a file. + pub async fn create( &self, mxc: String, + globals: &Globals, content_disposition: &Option<&str>, content_type: &Option<&str>, file: &[u8], @@ -42,15 +45,19 @@ pub fn create( .unwrap_or_default(), ); - self.mediaid_file.insert(key, file)?; + let path = globals.get_media_file(&key); + let mut f = File::create(path).await?; + f.write_all(file).await?; + self.mediaid_file.insert(&key, &[])?; Ok(()) } /// Uploads or replaces a file thumbnail. - pub fn upload_thumbnail( + pub async fn upload_thumbnail( &self, mxc: String, + globals: &Globals, content_disposition: &Option<String>, content_type: &Option<String>, width: u32, @@ -76,21 +83,28 @@ pub fn upload_thumbnail( .unwrap_or_default(), ); - self.mediaid_file.insert(key, file)?; + let path = globals.get_media_file(&key); + let mut f = File::create(path).await?; + f.write_all(file).await?; + + self.mediaid_file.insert(&key, &[])?; Ok(()) } /// Downloads a file. - pub fn get(&self, mxc: &str) -> Result<Option<FileMeta>> { + pub async fn get(&self, globals: &Globals, mxc: &str) -> Result<Option<FileMeta>> { let mut prefix = mxc.as_bytes().to_vec(); prefix.push(0xff); prefix.extend_from_slice(&0_u32.to_be_bytes()); // Width = 0 if it's not a thumbnail prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail prefix.push(0xff); - if let Some(r) = self.mediaid_file.scan_prefix(&prefix).next() { - let (key, file) = r?; + let mut iter = self.mediaid_file.scan_prefix(prefix); + if let Some((key, _)) = iter.next() { + let path = globals.get_media_file(&key); + let mut file = Vec::new(); + File::open(path).await?.read_to_end(&mut file).await?; let mut parts = key.rsplit(|&b| b == 0xff); let content_type = parts @@ -121,7 +135,7 @@ pub fn get(&self, mxc: &str) -> Result<Option<FileMeta>> { Ok(Some(FileMeta { content_disposition, content_type, - file: file.to_vec(), + file, })) } else { Ok(None) @@ -151,7 +165,13 @@ pub fn thumbnail_properties(&self, width: u32, height: u32) -> Option<(u32, u32, /// - Server creates the thumbnail and sends it to the user /// /// For width,height <= 96 the server uses another thumbnailing algorithm which crops the image afterwards. - pub fn get_thumbnail(&self, mxc: String, width: u32, height: u32) -> Result<Option<FileMeta>> { + pub async fn get_thumbnail( + &self, + mxc: String, + globals: &Globals, + width: u32, + height: u32, + ) -> Result<Option<FileMeta>> { let (width, height, crop) = self .thumbnail_properties(width, height) .unwrap_or((0, 0, false)); // 0, 0 because that's the original file @@ -169,9 +189,11 @@ pub fn get_thumbnail(&self, mxc: String, width: u32, height: u32) -> Result<Opti original_prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail original_prefix.push(0xff); - if let Some(r) = self.mediaid_file.scan_prefix(&thumbnail_prefix).next() { + if let Some((key, _)) = self.mediaid_file.scan_prefix(thumbnail_prefix).next() { // Using saved thumbnail - let (key, file) = r?; + let path = globals.get_media_file(&key); + let mut file = Vec::new(); + File::open(path).await?.read_to_end(&mut file).await?; let mut parts = key.rsplit(|&b| b == 0xff); let content_type = parts @@ -202,10 +224,12 @@ pub fn get_thumbnail(&self, mxc: String, width: u32, height: u32) -> Result<Opti content_type, file: file.to_vec(), })) - } else if let Some(r) = self.mediaid_file.scan_prefix(&original_prefix).next() { + } else if let Some((key, _)) = self.mediaid_file.scan_prefix(original_prefix).next() { // Generate a thumbnail + let path = globals.get_media_file(&key); + let mut file = Vec::new(); + File::open(path).await?.read_to_end(&mut file).await?; - let (key, file) = r?; let mut parts = key.rsplit(|&b| b == 0xff); let content_type = parts @@ -302,7 +326,11 @@ pub fn get_thumbnail(&self, mxc: String, width: u32, height: u32) -> Result<Opti widthheight, ); - self.mediaid_file.insert(thumbnail_key, &*thumbnail_bytes)?; + let path = globals.get_media_file(&thumbnail_key); + let mut f = File::create(path).await?; + f.write_all(&thumbnail_bytes).await?; + + self.mediaid_file.insert(&thumbnail_key, &[])?; Ok(Some(FileMeta { content_disposition, diff --git a/src/database/pusher.rs b/src/database/pusher.rs index 51f55a17c3dad73c9ab4f928a002ab5d1bf1a792..39b631dc6d13773325255cf42b993dc9561d5794 100644 --- a/src/database/pusher.rs +++ b/src/database/pusher.rs @@ -14,23 +14,17 @@ push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak}, uint, UInt, UserId, }; -use sled::IVec; -use std::{convert::TryFrom, fmt::Debug, mem}; +use std::{convert::TryFrom, fmt::Debug, mem, sync::Arc}; + +use super::abstraction::Tree; -#[derive(Debug, Clone)] pub struct PushData { /// UserId + pushkey -> Pusher - pub(super) senderkey_pusher: sled::Tree, + pub(super) senderkey_pusher: Arc<dyn Tree>, } impl PushData { - pub fn new(db: &sled::Db) -> Result<Self> { - Ok(Self { - senderkey_pusher: db.open_tree("senderkey_pusher")?, - }) - } - pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::Pusher) -> Result<()> { let mut key = sender.as_bytes().to_vec(); key.push(0xff); @@ -40,14 +34,14 @@ pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::Pusher) -> Result< if pusher.kind.is_none() { return self .senderkey_pusher - .remove(key) + .remove(&key) .map(|_| ()) .map_err(Into::into); } self.senderkey_pusher.insert( - key, - &*serde_json::to_string(&pusher).expect("Pusher is valid JSON string"), + &key, + &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"), )?; Ok(()) @@ -69,23 +63,21 @@ pub fn get_pushers(&self, sender: &UserId) -> Result<Vec<get_pushers::Pusher>> { self.senderkey_pusher .scan_prefix(prefix) - .values() - .map(|push| { - let push = push.map_err(|_| Error::bad_database("Invalid push bytes in db."))?; + .map(|(_, push)| { Ok(serde_json::from_slice(&*push) .map_err(|_| Error::bad_database("Invalid Pusher in db."))?) }) .collect() } - pub fn get_pusher_senderkeys(&self, sender: &UserId) -> impl Iterator<Item = Result<IVec>> { + pub fn get_pusher_senderkeys<'a>( + &'a self, + sender: &UserId, + ) -> impl Iterator<Item = Box<[u8]>> + 'a { let mut prefix = sender.as_bytes().to_vec(); prefix.push(0xff); - self.senderkey_pusher - .scan_prefix(prefix) - .keys() - .map(|r| Ok(r?)) + self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| k) } } diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 703314e68e2220434563f6dccc119f7fd6f5257c..736ff4d85e45eaf46dcb26f480f0038fc974263c 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -19,8 +19,6 @@ state_res::{self, Event, RoomVersion, StateMap}, uint, EventId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId, }; -use sled::IVec; - use std::{ collections::{BTreeMap, HashMap, HashSet}, convert::{TryFrom, TryInto}, @@ -28,62 +26,61 @@ sync::Arc, }; -use super::{admin::AdminCommand, pusher}; +use super::{abstraction::Tree, admin::AdminCommand, pusher}; /// The unique identifier of each state group. /// /// This is created when a state group is added to the database by /// hashing the entire state. -pub type StateHashId = IVec; +pub type StateHashId = Vec<u8>; -#[derive(Clone)] pub struct Rooms { pub edus: edus::RoomEdus, - pub(super) pduid_pdu: sled::Tree, // PduId = RoomId + Count - pub(super) eventid_pduid: sled::Tree, - pub(super) roomid_pduleaves: sled::Tree, - pub(super) alias_roomid: sled::Tree, - pub(super) aliasid_alias: sled::Tree, // AliasId = RoomId + Count - pub(super) publicroomids: sled::Tree, + pub(super) pduid_pdu: Arc<dyn Tree>, // PduId = RoomId + Count + pub(super) eventid_pduid: Arc<dyn Tree>, + pub(super) roomid_pduleaves: Arc<dyn Tree>, + pub(super) alias_roomid: Arc<dyn Tree>, + pub(super) aliasid_alias: Arc<dyn Tree>, // AliasId = RoomId + Count + pub(super) publicroomids: Arc<dyn Tree>, - pub(super) tokenids: sled::Tree, // TokenId = RoomId + Token + PduId + pub(super) tokenids: Arc<dyn Tree>, // TokenId = RoomId + Token + PduId /// Participating servers in a room. - pub(super) roomserverids: sled::Tree, // RoomServerId = RoomId + ServerName - pub(super) serverroomids: sled::Tree, // ServerRoomId = ServerName + RoomId + pub(super) roomserverids: Arc<dyn Tree>, // RoomServerId = RoomId + ServerName + pub(super) serverroomids: Arc<dyn Tree>, // ServerRoomId = ServerName + RoomId - pub(super) userroomid_joined: sled::Tree, - pub(super) roomuserid_joined: sled::Tree, - pub(super) roomuseroncejoinedids: sled::Tree, - pub(super) userroomid_invitestate: sled::Tree, // InviteState = Vec<Raw<Pdu>> - pub(super) roomuserid_invitecount: sled::Tree, // InviteCount = Count - pub(super) userroomid_leftstate: sled::Tree, - pub(super) roomuserid_leftcount: sled::Tree, + pub(super) userroomid_joined: Arc<dyn Tree>, + pub(super) roomuserid_joined: Arc<dyn Tree>, + pub(super) roomuseroncejoinedids: Arc<dyn Tree>, + pub(super) userroomid_invitestate: Arc<dyn Tree>, // InviteState = Vec<Raw<Pdu>> + pub(super) roomuserid_invitecount: Arc<dyn Tree>, // InviteCount = Count + pub(super) userroomid_leftstate: Arc<dyn Tree>, + pub(super) roomuserid_leftcount: Arc<dyn Tree>, - pub(super) userroomid_notificationcount: sled::Tree, // NotifyCount = u64 - pub(super) userroomid_highlightcount: sled::Tree, // HightlightCount = u64 + pub(super) userroomid_notificationcount: Arc<dyn Tree>, // NotifyCount = u64 + pub(super) userroomid_highlightcount: Arc<dyn Tree>, // HightlightCount = u64 /// Remember the current state hash of a room. - pub(super) roomid_shortstatehash: sled::Tree, + pub(super) roomid_shortstatehash: Arc<dyn Tree>, /// Remember the state hash at events in the past. - pub(super) shorteventid_shortstatehash: sled::Tree, + pub(super) shorteventid_shortstatehash: Arc<dyn Tree>, /// StateKey = EventType + StateKey, ShortStateKey = Count - pub(super) statekey_shortstatekey: sled::Tree, - pub(super) shorteventid_eventid: sled::Tree, + pub(super) statekey_shortstatekey: Arc<dyn Tree>, + pub(super) shorteventid_eventid: Arc<dyn Tree>, /// ShortEventId = Count - pub(super) eventid_shorteventid: sled::Tree, + pub(super) eventid_shorteventid: Arc<dyn Tree>, /// ShortEventId = Count - pub(super) statehash_shortstatehash: sled::Tree, + pub(super) statehash_shortstatehash: Arc<dyn Tree>, /// ShortStateHash = Count /// StateId = ShortStateHash + ShortStateKey - pub(super) stateid_shorteventid: sled::Tree, + pub(super) stateid_shorteventid: Arc<dyn Tree>, /// RoomId + EventId -> outlier PDU. /// Any pdu that has passed the steps 1-8 in the incoming event /federation/send/txn. - pub(super) eventid_outlierpdu: sled::Tree, + pub(super) eventid_outlierpdu: Arc<dyn Tree>, /// RoomId + EventId -> Parent PDU EventId. - pub(super) prevevent_parent: sled::Tree, + pub(super) prevevent_parent: Arc<dyn Tree>, } impl Rooms { @@ -92,10 +89,8 @@ impl Rooms { pub fn state_full_ids(&self, shortstatehash: u64) -> Result<Vec<EventId>> { Ok(self .stateid_shorteventid - .scan_prefix(&shortstatehash.to_be_bytes()) - .values() - .filter_map(|r| r.ok()) - .map(|bytes| self.shorteventid_eventid.get(&bytes).ok().flatten()) + .scan_prefix(shortstatehash.to_be_bytes().to_vec()) + .map(|(_, bytes)| self.shorteventid_eventid.get(&bytes).ok().flatten()) .flatten() .map(|bytes| { Ok::<_, Error>( @@ -117,10 +112,8 @@ pub fn state_full( ) -> Result<BTreeMap<(EventType, String), PduEvent>> { Ok(self .stateid_shorteventid - .scan_prefix(shortstatehash.to_be_bytes()) - .values() - .filter_map(|r| r.ok()) - .map(|bytes| self.shorteventid_eventid.get(&bytes).ok().flatten()) + .scan_prefix(shortstatehash.to_be_bytes().to_vec()) + .map(|(_, bytes)| self.shorteventid_eventid.get(&bytes).ok().flatten()) .flatten() .map(|bytes| { Ok::<_, Error>( @@ -211,16 +204,16 @@ pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> { self.eventid_shorteventid .get(event_id.as_bytes())? .map_or(Ok(None), |shorteventid| { - Ok(self.shorteventid_shortstatehash.get(shorteventid)?.map_or( - Ok::<_, Error>(None), - |bytes| { + Ok(self + .shorteventid_shortstatehash + .get(&shorteventid)? + .map_or(Ok::<_, Error>(None), |bytes| { Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { Error::bad_database( "Invalid shortstatehash bytes in shorteventid_shortstatehash", ) })?)) - }, - )?) + })?) }) } @@ -285,7 +278,8 @@ pub fn exists(&self, room_id: &RoomId) -> Result<bool> { // Look for PDUs in that room. Ok(self .pduid_pdu - .get_gt(&prefix)? + .iter_from(&prefix, false) + .next() .filter(|(k, _)| k.starts_with(&prefix)) .is_some()) } @@ -471,10 +465,17 @@ pub fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<u64>> { } pub fn latest_pdu_count(&self, room_id: &RoomId) -> Result<u64> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xff); + + let mut last_possible_key = prefix.clone(); + last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); + self.pduid_pdu - .scan_prefix(room_id.as_bytes()) - .last() - .map(|b| self.pdu_count(&b?.0)) + .iter_from(&last_possible_key, true) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .next() + .map(|b| self.pdu_count(&b.0)) .transpose() .map(|op| op.unwrap_or_default()) } @@ -499,7 +500,7 @@ pub fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObj } /// Returns the pdu's id. - pub fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<IVec>> { + pub fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> { self.eventid_pduid .get(event_id.as_bytes())? .map_or(Ok(None), |pdu_id| Ok(Some(pdu_id))) @@ -570,11 +571,11 @@ pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJson } /// Removes a pdu and creates a new one with the same id. - fn replace_pdu(&self, pdu_id: &IVec, pdu: &PduEvent) -> Result<()> { + fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()> { if self.pduid_pdu.get(&pdu_id)?.is_some() { self.pduid_pdu.insert( &pdu_id, - &*serde_json::to_string(pdu).expect("PduEvent::to_string always works"), + &serde_json::to_vec(pdu).expect("PduEvent::to_vec always works"), )?; Ok(()) } else { @@ -591,11 +592,11 @@ pub fn get_pdu_leaves(&self, room_id: &RoomId) -> Result<HashSet<EventId>> { prefix.push(0xff); self.roomid_pduleaves - .scan_prefix(prefix) - .values() - .map(|bytes| { + .scan_prefix(dbg!(prefix)) + .map(|(key, bytes)| { + dbg!(key); Ok::<_, Error>( - EventId::try_from(utils::string_from_bytes(&bytes?).map_err(|_| { + EventId::try_from(utils::string_from_bytes(&bytes).map_err(|_| { Error::bad_database("EventID in roomid_pduleaves is invalid unicode.") })?) .map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))?, @@ -612,8 +613,8 @@ pub fn replace_pdu_leaves(&self, room_id: &RoomId, event_ids: &[EventId]) -> Res let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); - for key in self.roomid_pduleaves.scan_prefix(&prefix).keys() { - self.roomid_pduleaves.remove(key?)?; + for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) { + self.roomid_pduleaves.remove(&key)?; } for event_id in event_ids { @@ -628,7 +629,7 @@ pub fn replace_pdu_leaves(&self, room_id: &RoomId, event_ids: &[EventId]) -> Res pub fn is_pdu_referenced(&self, pdu: &PduEvent) -> Result<bool> { let mut key = pdu.room_id().as_bytes().to_vec(); key.extend_from_slice(pdu.event_id().as_bytes()); - self.prevevent_parent.contains_key(key).map_err(Into::into) + Ok(self.prevevent_parent.get(&key)?.is_some()) } /// Returns the pdu from the outlier tree. @@ -646,7 +647,7 @@ pub fn get_pdu_outlier(&self, event_id: &EventId) -> Result<Option<PduEvent>> { pub fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { self.eventid_outlierpdu.insert( &event_id.as_bytes(), - &*serde_json::to_string(&pdu).expect("CanonicalJsonObject is valid string"), + &serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"), )?; Ok(()) @@ -662,7 +663,7 @@ pub fn append_pdu( pdu: &PduEvent, mut pdu_json: CanonicalJsonObject, count: u64, - pdu_id: IVec, + pdu_id: &[u8], leaves: &[EventId], db: &Database, ) -> Result<()> { @@ -698,7 +699,7 @@ pub fn append_pdu( let mut key = pdu.room_id().as_bytes().to_vec(); key.extend_from_slice(leaf.as_bytes()); self.prevevent_parent - .insert(key, pdu.event_id().as_bytes())?; + .insert(&key, pdu.event_id().as_bytes())?; } self.replace_pdu_leaves(&pdu.room_id, leaves)?; @@ -710,15 +711,13 @@ pub fn append_pdu( self.reset_notification_counts(&pdu.sender, &pdu.room_id)?; self.pduid_pdu.insert( - &pdu_id, - &*serde_json::to_string(&pdu_json) - .expect("CanonicalJsonObject is always a valid String"), + pdu_id, + &serde_json::to_vec(&pdu_json).expect("CanonicalJsonObject is always a valid"), )?; // This also replaces the eventid of any outliers with the correct // pduid, removing the place holder. - self.eventid_pduid - .insert(pdu.event_id.as_bytes(), &*pdu_id)?; + self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?; // See if the event matches any known pushers for user in db @@ -760,22 +759,14 @@ pub fn append_pdu( userroom_id.extend_from_slice(pdu.room_id.as_bytes()); if notify { - self.userroomid_notificationcount - .update_and_fetch(&userroom_id, utils::increment)? - .expect("utils::increment will always put in a value"); + self.userroomid_notificationcount.increment(&userroom_id)?; } if highlight { - self.userroomid_highlightcount - .update_and_fetch(&userroom_id, utils::increment)? - .expect("utils::increment will always put in a value"); + self.userroomid_highlightcount.increment(&userroom_id)?; } - for senderkey in db - .pusher - .get_pusher_senderkeys(&user) - .filter_map(|r| r.ok()) - { + for senderkey in db.pusher.get_pusher_senderkeys(&user) { db.sending.send_push_pdu(&*pdu_id, senderkey)?; } } @@ -840,7 +831,7 @@ pub fn append_pdu( key.extend_from_slice(word.as_bytes()); key.push(0xff); key.extend_from_slice(&pdu_id); - self.tokenids.insert(key, &[])?; + self.tokenids.insert(&key, &[])?; } if body.starts_with(&format!("@conduit:{}: ", db.globals.server_name())) @@ -991,7 +982,7 @@ pub fn set_event_state( Some(shortstatehash) => { // State already existed in db self.shorteventid_shortstatehash - .insert(shorteventid, &*shortstatehash)?; + .insert(&shorteventid, &*shortstatehash)?; return Ok(()); } None => { @@ -1037,7 +1028,7 @@ pub fn set_event_state( } self.shorteventid_shortstatehash - .insert(shorteventid, &*shortstatehash)?; + .insert(&shorteventid, &*shortstatehash)?; Ok(()) } @@ -1070,7 +1061,7 @@ pub fn append_to_state( }; self.shorteventid_shortstatehash - .insert(shorteventid, &old_shortstatehash)?; + .insert(&shorteventid, &old_shortstatehash)?; if new_pdu.state_key.is_none() { return utils::u64_from_bytes(&old_shortstatehash).map_err(|_| { Error::bad_database("Invalid shortstatehash in roomid_shortstatehash.") @@ -1078,17 +1069,16 @@ pub fn append_to_state( } self.stateid_shorteventid - .scan_prefix(&old_shortstatehash) - .filter_map(|pdu| pdu.map_err(|e| error!("{}", e)).ok()) + .scan_prefix(old_shortstatehash.clone()) // Chop the old_shortstatehash out leaving behind the short state key .map(|(k, v)| (k[old_shortstatehash.len()..].to_vec(), v)) - .collect::<HashMap<Vec<u8>, IVec>>() + .collect::<HashMap<Vec<u8>, Box<[u8]>>>() } else { HashMap::new() }; if let Some(state_key) = &new_pdu.state_key { - let mut new_state: HashMap<Vec<u8>, IVec> = old_state; + let mut new_state: HashMap<Vec<u8>, Box<[u8]>> = old_state; let mut new_state_key = new_pdu.kind.as_ref().as_bytes().to_vec(); new_state_key.push(0xff); @@ -1205,6 +1195,7 @@ pub fn build_and_append_pdu( room_id: &RoomId, db: &Database, ) -> Result<EventId> { + dbg!(&pdu_builder); let PduBuilder { event_type, content, @@ -1366,7 +1357,7 @@ pub fn build_and_append_pdu( &pdu, pdu_json, count, - pdu_id.clone().into(), + &pdu_id, // Since this PDU references all pdu_leaves we can update the leaves // of the room &[pdu.event_id.clone()], @@ -1385,7 +1376,7 @@ pub fn build_and_append_pdu( db.sending.send_pdu(&server, &pdu_id)?; } - for appservice in db.appservice.iter_all().filter_map(|r| r.ok()) { + for appservice in db.appservice.iter_all()?.filter_map(|r| r.ok()) { if let Some(namespaces) = appservice.1.get("namespaces") { let users = namespaces .get("users") @@ -1464,23 +1455,23 @@ pub fn build_and_append_pdu( /// Returns an iterator over all PDUs in a room. #[tracing::instrument(skip(self))] - pub fn all_pdus( - &self, + pub fn all_pdus<'a>( + &'a self, user_id: &UserId, room_id: &RoomId, - ) -> Result<impl Iterator<Item = Result<(IVec, PduEvent)>>> { + ) -> impl Iterator<Item = Result<(Box<[u8]>, PduEvent)>> + 'a { self.pdus_since(user_id, room_id, 0) } - /// Returns a double-ended iterator over all events in a room that happened after the event with id `since` + /// Returns an iterator over all events in a room that happened after the event with id `since` /// in chronological order. #[tracing::instrument(skip(self))] - pub fn pdus_since( - &self, + pub fn pdus_since<'a>( + &'a self, user_id: &UserId, room_id: &RoomId, since: u64, - ) -> Result<impl DoubleEndedIterator<Item = Result<(IVec, PduEvent)>>> { + ) -> impl Iterator<Item = Result<(Box<[u8]>, PduEvent)>> + 'a { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); @@ -1488,19 +1479,10 @@ pub fn pdus_since( let mut first_pdu_id = prefix.clone(); first_pdu_id.extend_from_slice(&(since + 1).to_be_bytes()); - let mut last_pdu_id = prefix; - last_pdu_id.extend_from_slice(&u64::MAX.to_be_bytes()); - let user_id = user_id.clone(); - Ok(self - .pduid_pdu - .range(first_pdu_id..last_pdu_id) - .filter_map(|r| { - if r.is_err() { - error!("Bad pdu in pduid_pdu: {:?}", r); - } - r.ok() - }) + self.pduid_pdu + .iter_from(&first_pdu_id, false) + .take_while(move |(k, _)| k.starts_with(&prefix)) .map(move |(pdu_id, v)| { let mut pdu = serde_json::from_slice::<PduEvent>(&v) .map_err(|_| Error::bad_database("PDU in db is invalid."))?; @@ -1508,17 +1490,17 @@ pub fn pdus_since( pdu.unsigned.remove("transaction_id"); } Ok((pdu_id, pdu)) - })) + }) } /// Returns an iterator over all events and their tokens in a room that happened before the /// event with id `until` in reverse-chronological order. - pub fn pdus_until( - &self, + pub fn pdus_until<'a>( + &'a self, user_id: &UserId, room_id: &RoomId, until: u64, - ) -> impl Iterator<Item = Result<(IVec, PduEvent)>> { + ) -> impl Iterator<Item = Result<(Box<[u8]>, PduEvent)>> + 'a { // Create the first part of the full pdu id let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); @@ -1530,9 +1512,7 @@ pub fn pdus_until( let user_id = user_id.clone(); self.pduid_pdu - .range(..current) - .rev() - .filter_map(|r| r.ok()) + .iter_from(current, true) .take_while(move |(k, _)| k.starts_with(&prefix)) .map(move |(pdu_id, v)| { let mut pdu = serde_json::from_slice::<PduEvent>(&v) @@ -1547,12 +1527,12 @@ pub fn pdus_until( /// Returns an iterator over all events and their token in a room that happened after the event /// with id `from` in chronological order. #[tracing::instrument(skip(self))] - pub fn pdus_after( - &self, + pub fn pdus_after<'a>( + &'a self, user_id: &UserId, room_id: &RoomId, from: u64, - ) -> impl Iterator<Item = Result<(IVec, PduEvent)>> { + ) -> impl Iterator<Item = Result<(Box<[u8]>, PduEvent)>> + 'a { // Create the first part of the full pdu id let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); @@ -1564,8 +1544,7 @@ pub fn pdus_after( let user_id = user_id.clone(); self.pduid_pdu - .range(current..) - .filter_map(|r| r.ok()) + .iter_from(current, false) .take_while(move |(k, _)| k.starts_with(&prefix)) .map(move |(pdu_id, v)| { let mut pdu = serde_json::from_slice::<PduEvent>(&v) @@ -1744,7 +1723,7 @@ pub fn update_membership( self.serverroomids.insert(&serverroom_id, &[])?; self.userroomid_invitestate.insert( &userroom_id, - serde_json::to_vec(&last_state.unwrap_or_default()) + &serde_json::to_vec(&last_state.unwrap_or_default()) .expect("state to bytes always works"), )?; self.roomuserid_invitecount @@ -1766,7 +1745,7 @@ pub fn update_membership( } self.userroomid_leftstate.insert( &userroom_id, - serde_json::to_vec(&Vec::<Raw<AnySyncStateEvent>>::new()).unwrap(), + &serde_json::to_vec(&Vec::<Raw<AnySyncStateEvent>>::new()).unwrap(), )?; // TODO self.roomuserid_leftcount .insert(&roomuser_id, &db.globals.next_count()?.to_be_bytes())?; @@ -1966,8 +1945,8 @@ pub fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { roomuser_id.push(0xff); roomuser_id.extend_from_slice(user_id.as_bytes()); - self.userroomid_leftstate.remove(userroom_id)?; - self.roomuserid_leftcount.remove(roomuser_id)?; + self.userroomid_leftstate.remove(&userroom_id)?; + self.roomuserid_leftcount.remove(&roomuser_id)?; Ok(()) } @@ -1981,26 +1960,26 @@ pub fn set_alias( if let Some(room_id) = room_id { // New alias self.alias_roomid - .insert(alias.alias(), room_id.as_bytes())?; + .insert(&alias.alias().as_bytes(), room_id.as_bytes())?; let mut aliasid = room_id.as_bytes().to_vec(); aliasid.push(0xff); aliasid.extend_from_slice(&globals.next_count()?.to_be_bytes()); - self.aliasid_alias.insert(aliasid, &*alias.as_bytes())?; + self.aliasid_alias.insert(&aliasid, &*alias.as_bytes())?; } else { // room_id=None means remove alias - let room_id = self - .alias_roomid - .remove(alias.alias())? - .ok_or(Error::BadRequest( + if let Some(room_id) = self.alias_roomid.get(&alias.alias().as_bytes())? { + let mut prefix = room_id.to_vec(); + prefix.push(0xff); + + for (key, _) in self.aliasid_alias.scan_prefix(prefix) { + self.aliasid_alias.remove(&key)?; + } + self.alias_roomid.remove(&alias.alias().as_bytes())?; + } else { + return Err(Error::BadRequest( ErrorKind::NotFound, "Alias does not exist.", - ))?; - - let mut prefix = room_id.to_vec(); - prefix.push(0xff); - - for key in self.aliasid_alias.scan_prefix(prefix).keys() { - self.aliasid_alias.remove(key?)?; + )); } } @@ -2009,7 +1988,7 @@ pub fn set_alias( pub fn id_from_alias(&self, alias: &RoomAliasId) -> Result<Option<RoomId>> { self.alias_roomid - .get(alias.alias())? + .get(alias.alias().as_bytes())? .map_or(Ok(None), |bytes| { Ok(Some( RoomId::try_from(utils::string_from_bytes(&bytes).map_err(|_| { @@ -2020,19 +1999,19 @@ pub fn id_from_alias(&self, alias: &RoomAliasId) -> Result<Option<RoomId>> { }) } - pub fn room_aliases(&self, room_id: &RoomId) -> impl Iterator<Item = Result<RoomAliasId>> { + pub fn room_aliases<'a>( + &'a self, + room_id: &RoomId, + ) -> impl Iterator<Item = Result<RoomAliasId>> + 'a { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); - self.aliasid_alias - .scan_prefix(prefix) - .values() - .map(|bytes| { - Ok(utils::string_from_bytes(&bytes?) - .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))? - .try_into() - .map_err(|_| Error::bad_database("Invalid alias in aliasid_alias."))?) - }) + self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| { + Ok(utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))? + .try_into() + .map_err(|_| Error::bad_database("Invalid alias in aliasid_alias."))?) + }) } pub fn set_public(&self, room_id: &RoomId, public: bool) -> Result<()> { @@ -2046,13 +2025,13 @@ pub fn set_public(&self, room_id: &RoomId, public: bool) -> Result<()> { } pub fn is_public_room(&self, room_id: &RoomId) -> Result<bool> { - Ok(self.publicroomids.contains_key(room_id.as_bytes())?) + Ok(self.publicroomids.get(room_id.as_bytes())?.is_some()) } - pub fn public_rooms(&self) -> impl Iterator<Item = Result<RoomId>> { - self.publicroomids.iter().keys().map(|bytes| { + pub fn public_rooms<'a>(&'a self) -> impl Iterator<Item = Result<RoomId>> + 'a { + self.publicroomids.iter().map(|(bytes, _)| { Ok( - RoomId::try_from(utils::string_from_bytes(&bytes?).map_err(|_| { + RoomId::try_from(utils::string_from_bytes(&bytes).map_err(|_| { Error::bad_database("Room ID in publicroomids is invalid unicode.") })?) .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid."))?, @@ -2073,31 +2052,39 @@ pub fn search_pdus<'a>( .map(str::to_lowercase) .collect::<Vec<_>>(); - let iterators = words.clone().into_iter().map(move |word| { - let mut prefix2 = prefix.clone(); - prefix2.extend_from_slice(word.as_bytes()); - prefix2.push(0xff); - self.tokenids - .scan_prefix(&prefix2) - .keys() - .rev() // Newest pdus first - .filter_map(|r| r.ok()) - .map(|key| { - let pduid_index = key - .iter() - .enumerate() - .filter(|(_, &b)| b == 0xff) - .nth(1) - .ok_or_else(|| Error::bad_database("Invalid tokenid in db."))? - .0 - + 1; // +1 because the pdu id starts AFTER the separator - - let pdu_id = key[pduid_index..].to_vec(); - - Ok::<_, Error>(pdu_id) - }) - .filter_map(|r| r.ok()) - }); + let iterators = words + .clone() + .into_iter() + .map(move |word| { + let mut prefix2 = prefix.clone(); + prefix2.extend_from_slice(word.as_bytes()); + prefix2.push(0xff); + + let mut last_possible_id = prefix2.clone(); + last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes()); + + Ok::<_, Error>( + self.tokenids + .iter_from(&last_possible_id, true) // Newest pdus first + .take_while(move |(k, _)| k.starts_with(&prefix2)) + .map(|(key, _)| { + let pduid_index = key + .iter() + .enumerate() + .filter(|(_, &b)| b == 0xff) + .nth(1) + .ok_or_else(|| Error::bad_database("Invalid tokenid in db."))? + .0 + + 1; // +1 because the pdu id starts AFTER the separator + + let pdu_id = key[pduid_index..].to_vec(); + + Ok::<_, Error>(pdu_id) + }) + .filter_map(|r| r.ok()), + ) + }) + .filter_map(|r| r.ok()); Ok(( utils::common_elements(iterators, |a, b| { @@ -2113,52 +2100,59 @@ pub fn search_pdus<'a>( pub fn get_shared_rooms<'a>( &'a self, users: Vec<UserId>, - ) -> impl Iterator<Item = Result<RoomId>> + 'a { - let iterators = users.into_iter().map(move |user_id| { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - - self.userroomid_joined - .scan_prefix(&prefix) - .keys() - .filter_map(|r| r.ok()) - .map(|key| { - let roomid_index = key - .iter() - .enumerate() - .find(|(_, &b)| b == 0xff) - .ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))? - .0 - + 1; // +1 because the room id starts AFTER the separator - - let room_id = key[roomid_index..].to_vec(); - - Ok::<_, Error>(room_id) - }) - .filter_map(|r| r.ok()) - }); + ) -> Result<impl Iterator<Item = Result<RoomId>> + 'a> { + let iterators = users + .into_iter() + .map(move |user_id| { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + + Ok::<_, Error>( + self.userroomid_joined + .scan_prefix(prefix) + .map(|(key, _)| { + let roomid_index = key + .iter() + .enumerate() + .find(|(_, &b)| b == 0xff) + .ok_or_else(|| { + Error::bad_database("Invalid userroomid_joined in db.") + })? + .0 + + 1; // +1 because the room id starts AFTER the separator + + let room_id = key[roomid_index..].to_vec(); + + Ok::<_, Error>(room_id) + }) + .filter_map(|r| r.ok()), + ) + }) + .filter_map(|r| r.ok()); // We use the default compare function because keys are sorted correctly (not reversed) - utils::common_elements(iterators, Ord::cmp) + Ok(utils::common_elements(iterators, Ord::cmp) .expect("users is not empty") .map(|bytes| { RoomId::try_from(utils::string_from_bytes(&*bytes).map_err(|_| { Error::bad_database("Invalid RoomId bytes in userroomid_joined") })?) .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) - }) + })) } /// Returns an iterator of all servers participating in this room. - pub fn room_servers(&self, room_id: &RoomId) -> impl Iterator<Item = Result<Box<ServerName>>> { + pub fn room_servers<'a>( + &'a self, + room_id: &RoomId, + ) -> impl Iterator<Item = Result<Box<ServerName>>> + 'a { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); - self.roomserverids.scan_prefix(prefix).keys().map(|key| { + self.roomserverids.scan_prefix(prefix).map(|(key, _)| { Ok(Box::<ServerName>::try_from( utils::string_from_bytes( - &key? - .rsplit(|&b| b == 0xff) + &key.rsplit(|&b| b == 0xff) .next() .expect("rsplit always returns an element"), ) @@ -2171,15 +2165,17 @@ pub fn room_servers(&self, room_id: &RoomId) -> impl Iterator<Item = Result<Box< } /// Returns an iterator of all rooms a server participates in (as far as we know). - pub fn server_rooms(&self, server: &ServerName) -> impl Iterator<Item = Result<RoomId>> { + pub fn server_rooms<'a>( + &'a self, + server: &ServerName, + ) -> impl Iterator<Item = Result<RoomId>> + 'a { let mut prefix = server.as_bytes().to_vec(); prefix.push(0xff); - self.serverroomids.scan_prefix(prefix).keys().map(|key| { + self.serverroomids.scan_prefix(prefix).map(|(key, _)| { Ok(RoomId::try_from( utils::string_from_bytes( - &key? - .rsplit(|&b| b == 0xff) + &key.rsplit(|&b| b == 0xff) .next() .expect("rsplit always returns an element"), ) @@ -2191,42 +2187,42 @@ pub fn server_rooms(&self, server: &ServerName) -> impl Iterator<Item = Result<R /// Returns an iterator over all joined members of a room. #[tracing::instrument(skip(self))] - pub fn room_members(&self, room_id: &RoomId) -> impl Iterator<Item = Result<UserId>> { + pub fn room_members<'a>( + &'a self, + room_id: &RoomId, + ) -> impl Iterator<Item = Result<UserId>> + 'a { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); - self.roomuserid_joined - .scan_prefix(prefix) - .keys() - .map(|key| { - Ok(UserId::try_from( - utils::string_from_bytes( - &key? - .rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database("User ID in roomuserid_joined is invalid unicode.") - })?, + self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| { + Ok(UserId::try_from( + utils::string_from_bytes( + &key.rsplit(|&b| b == 0xff) + .next() + .expect("rsplit always returns an element"), ) - .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid."))?) - }) + .map_err(|_| { + Error::bad_database("User ID in roomuserid_joined is invalid unicode.") + })?, + ) + .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid."))?) + }) } /// Returns an iterator over all User IDs who ever joined a room. - pub fn room_useroncejoined(&self, room_id: &RoomId) -> impl Iterator<Item = Result<UserId>> { + pub fn room_useroncejoined<'a>( + &'a self, + room_id: &RoomId, + ) -> impl Iterator<Item = Result<UserId>> + 'a { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); self.roomuseroncejoinedids .scan_prefix(prefix) - .keys() - .map(|key| { + .map(|(key, _)| { Ok(UserId::try_from( utils::string_from_bytes( - &key? - .rsplit(|&b| b == 0xff) + &key.rsplit(|&b| b == 0xff) .next() .expect("rsplit always returns an element"), ) @@ -2240,18 +2236,19 @@ pub fn room_useroncejoined(&self, room_id: &RoomId) -> impl Iterator<Item = Resu /// Returns an iterator over all invited members of a room. #[tracing::instrument(skip(self))] - pub fn room_members_invited(&self, room_id: &RoomId) -> impl Iterator<Item = Result<UserId>> { + pub fn room_members_invited<'a>( + &'a self, + room_id: &RoomId, + ) -> impl Iterator<Item = Result<UserId>> + 'a { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); self.roomuserid_invitecount .scan_prefix(prefix) - .keys() - .map(|key| { + .map(|(key, _)| { Ok(UserId::try_from( utils::string_from_bytes( - &key? - .rsplit(|&b| b == 0xff) + &key.rsplit(|&b| b == 0xff) .next() .expect("rsplit always returns an element"), ) @@ -2270,7 +2267,7 @@ pub fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Opt key.extend_from_slice(user_id.as_bytes()); self.roomuserid_invitecount - .get(key)? + .get(&key)? .map_or(Ok(None), |bytes| { Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { Error::bad_database("Invalid invitecount in db.") @@ -2285,7 +2282,7 @@ pub fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Optio key.extend_from_slice(user_id.as_bytes()); self.roomuserid_leftcount - .get(key)? + .get(&key)? .map_or(Ok(None), |bytes| { Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { Error::bad_database("Invalid leftcount in db.") @@ -2295,15 +2292,16 @@ pub fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Optio /// Returns an iterator over all rooms this user joined. #[tracing::instrument(skip(self))] - pub fn rooms_joined(&self, user_id: &UserId) -> impl Iterator<Item = Result<RoomId>> { + pub fn rooms_joined<'a>( + &'a self, + user_id: &UserId, + ) -> impl Iterator<Item = Result<RoomId>> + 'a { self.userroomid_joined - .scan_prefix(user_id.as_bytes()) - .keys() - .map(|key| { + .scan_prefix(user_id.as_bytes().to_vec()) + .map(|(key, _)| { Ok(RoomId::try_from( utils::string_from_bytes( - &key? - .rsplit(|&b| b == 0xff) + &key.rsplit(|&b| b == 0xff) .next() .expect("rsplit always returns an element"), ) @@ -2317,32 +2315,33 @@ pub fn rooms_joined(&self, user_id: &UserId) -> impl Iterator<Item = Result<Room /// Returns an iterator over all rooms a user was invited to. #[tracing::instrument(skip(self))] - pub fn rooms_invited( - &self, + pub fn rooms_invited<'a>( + &'a self, user_id: &UserId, - ) -> impl Iterator<Item = Result<(RoomId, Vec<Raw<AnyStrippedStateEvent>>)>> { + ) -> impl Iterator<Item = Result<(RoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xff); - self.userroomid_invitestate.scan_prefix(prefix).map(|r| { - let (key, state) = r?; - let room_id = RoomId::try_from( - utils::string_from_bytes( - &key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), + self.userroomid_invitestate + .scan_prefix(prefix) + .map(|(key, state)| { + let room_id = RoomId::try_from( + utils::string_from_bytes( + &key.rsplit(|&b| b == 0xff) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| { + Error::bad_database("Room ID in userroomid_invited is invalid unicode.") + })?, ) - .map_err(|_| { - Error::bad_database("Room ID in userroomid_invited is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; + .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; + let state = serde_json::from_slice(&state) + .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; - Ok((room_id, state)) - }) + Ok((room_id, state)) + }) } #[tracing::instrument(skip(self))] @@ -2356,7 +2355,7 @@ pub fn invite_state( key.extend_from_slice(&room_id.as_bytes()); self.userroomid_invitestate - .get(key)? + .get(&key)? .map(|state| { let state = serde_json::from_slice(&state) .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; @@ -2377,7 +2376,7 @@ pub fn left_state( key.extend_from_slice(&room_id.as_bytes()); self.userroomid_leftstate - .get(key)? + .get(&key)? .map(|state| { let state = serde_json::from_slice(&state) .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; @@ -2389,32 +2388,33 @@ pub fn left_state( /// Returns an iterator over all rooms a user left. #[tracing::instrument(skip(self))] - pub fn rooms_left( - &self, + pub fn rooms_left<'a>( + &'a self, user_id: &UserId, - ) -> impl Iterator<Item = Result<(RoomId, Vec<Raw<AnySyncStateEvent>>)>> { + ) -> impl Iterator<Item = Result<(RoomId, Vec<Raw<AnySyncStateEvent>>)>> + 'a { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xff); - self.userroomid_leftstate.scan_prefix(prefix).map(|r| { - let (key, state) = r?; - let room_id = RoomId::try_from( - utils::string_from_bytes( - &key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), + self.userroomid_leftstate + .scan_prefix(prefix) + .map(|(key, state)| { + let room_id = RoomId::try_from( + utils::string_from_bytes( + &key.rsplit(|&b| b == 0xff) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| { + Error::bad_database("Room ID in userroomid_invited is invalid unicode.") + })?, ) - .map_err(|_| { - Error::bad_database("Room ID in userroomid_invited is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; + .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; + let state = serde_json::from_slice(&state) + .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; - Ok((room_id, state)) - }) + Ok((room_id, state)) + }) } pub fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { @@ -2422,7 +2422,7 @@ pub fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { userroom_id.push(0xff); userroom_id.extend_from_slice(room_id.as_bytes()); - Ok(self.roomuseroncejoinedids.get(userroom_id)?.is_some()) + Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some()) } pub fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { @@ -2430,7 +2430,7 @@ pub fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { userroom_id.push(0xff); userroom_id.extend_from_slice(room_id.as_bytes()); - Ok(self.userroomid_joined.get(userroom_id)?.is_some()) + Ok(self.userroomid_joined.get(&userroom_id)?.is_some()) } pub fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { @@ -2438,7 +2438,7 @@ pub fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { userroom_id.push(0xff); userroom_id.extend_from_slice(room_id.as_bytes()); - Ok(self.userroomid_invitestate.get(userroom_id)?.is_some()) + Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some()) } pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { @@ -2446,6 +2446,6 @@ pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { userroom_id.push(0xff); userroom_id.extend_from_slice(room_id.as_bytes()); - Ok(self.userroomid_leftstate.get(userroom_id)?.is_some()) + Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) } } diff --git a/src/database/rooms/edus.rs b/src/database/rooms/edus.rs index f4c7075e49c8a899fc18f12f88f3cd1649ee3c9f..677d26eb056308c4928889d8d93fa833e7a646ab 100644 --- a/src/database/rooms/edus.rs +++ b/src/database/rooms/edus.rs @@ -1,4 +1,4 @@ -use crate::{utils, Error, Result}; +use crate::{database::abstraction::Tree, utils, Error, Result}; use ruma::{ events::{ presence::{PresenceEvent, PresenceEventContent}, @@ -13,17 +13,17 @@ collections::{HashMap, HashSet}, convert::{TryFrom, TryInto}, mem, + sync::Arc, }; -#[derive(Clone)] pub struct RoomEdus { - pub(in super::super) readreceiptid_readreceipt: sled::Tree, // ReadReceiptId = RoomId + Count + UserId - pub(in super::super) roomuserid_privateread: sled::Tree, // RoomUserId = Room + User, PrivateRead = Count - pub(in super::super) roomuserid_lastprivatereadupdate: sled::Tree, // LastPrivateReadUpdate = Count - pub(in super::super) typingid_userid: sled::Tree, // TypingId = RoomId + TimeoutTime + Count - pub(in super::super) roomid_lasttypingupdate: sled::Tree, // LastRoomTypingUpdate = Count - pub(in super::super) presenceid_presence: sled::Tree, // PresenceId = RoomId + Count + UserId - pub(in super::super) userid_lastpresenceupdate: sled::Tree, // LastPresenceUpdate = Count + pub(in super::super) readreceiptid_readreceipt: Arc<dyn Tree>, // ReadReceiptId = RoomId + Count + UserId + pub(in super::super) roomuserid_privateread: Arc<dyn Tree>, // RoomUserId = Room + User, PrivateRead = Count + pub(in super::super) roomuserid_lastprivatereadupdate: Arc<dyn Tree>, // LastPrivateReadUpdate = Count + pub(in super::super) typingid_userid: Arc<dyn Tree>, // TypingId = RoomId + TimeoutTime + Count + pub(in super::super) roomid_lasttypingupdate: Arc<dyn Tree>, // LastRoomTypingUpdate = Count + pub(in super::super) presenceid_presence: Arc<dyn Tree>, // PresenceId = RoomId + Count + UserId + pub(in super::super) userid_lastpresenceupdate: Arc<dyn Tree>, // LastPresenceUpdate = Count } impl RoomEdus { @@ -38,15 +38,15 @@ pub fn readreceipt_update( let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); + let mut last_possible_key = prefix.clone(); + last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); + // Remove old entry - if let Some(old) = self + if let Some((old, _)) = self .readreceiptid_readreceipt - .scan_prefix(&prefix) - .keys() - .rev() - .filter_map(|r| r.ok()) - .take_while(|key| key.starts_with(&prefix)) - .find(|key| { + .iter_from(&last_possible_key, true) + .take_while(|(key, _)| key.starts_with(&prefix)) + .find(|(key, _)| { key.rsplit(|&b| b == 0xff) .next() .expect("rsplit always returns an element") @@ -54,7 +54,7 @@ pub fn readreceipt_update( }) { // This is the old room_latest - self.readreceiptid_readreceipt.remove(old)?; + self.readreceiptid_readreceipt.remove(&old)?; } let mut room_latest_id = prefix; @@ -63,8 +63,8 @@ pub fn readreceipt_update( room_latest_id.extend_from_slice(&user_id.as_bytes()); self.readreceiptid_readreceipt.insert( - room_latest_id, - &*serde_json::to_string(&event).expect("EduEvent::to_string always works"), + &room_latest_id, + &serde_json::to_vec(&event).expect("EduEvent::to_string always works"), )?; Ok(()) @@ -72,13 +72,12 @@ pub fn readreceipt_update( /// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`. #[tracing::instrument(skip(self))] - pub fn readreceipts_since( - &self, + pub fn readreceipts_since<'a>( + &'a self, room_id: &RoomId, since: u64, - ) -> Result< - impl Iterator<Item = Result<(UserId, u64, Raw<ruma::events::AnySyncEphemeralRoomEvent>)>>, - > { + ) -> impl Iterator<Item = Result<(UserId, u64, Raw<ruma::events::AnySyncEphemeralRoomEvent>)>> + 'a + { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); let prefix2 = prefix.clone(); @@ -86,10 +85,8 @@ impl Iterator<Item = Result<(UserId, u64, Raw<ruma::events::AnySyncEphemeralRoom let mut first_possible_edu = prefix.clone(); first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since - Ok(self - .readreceiptid_readreceipt - .range(&*first_possible_edu..) - .filter_map(|r| r.ok()) + self.readreceiptid_readreceipt + .iter_from(&first_possible_edu, false) .take_while(move |(k, _)| k.starts_with(&prefix2)) .map(move |(k, v)| { let count = @@ -115,7 +112,7 @@ impl Iterator<Item = Result<(UserId, u64, Raw<ruma::events::AnySyncEphemeralRoom serde_json::value::to_raw_value(&json).expect("json is valid raw value"), ), )) - })) + }) } /// Sets a private read marker at `count`. @@ -146,11 +143,13 @@ pub fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Opt key.push(0xff); key.extend_from_slice(&user_id.as_bytes()); - self.roomuserid_privateread.get(key)?.map_or(Ok(None), |v| { - Ok(Some(utils::u64_from_bytes(&v).map_err(|_| { - Error::bad_database("Invalid private read marker bytes") - })?)) - }) + self.roomuserid_privateread + .get(&key)? + .map_or(Ok(None), |v| { + Ok(Some(utils::u64_from_bytes(&v).map_err(|_| { + Error::bad_database("Invalid private read marker bytes") + })?)) + }) } /// Returns the count of the last typing update in this room. @@ -215,11 +214,10 @@ pub fn typing_remove( // Maybe there are multiple ones from calling roomtyping_add multiple times for outdated_edu in self .typingid_userid - .scan_prefix(&prefix) - .filter_map(|r| r.ok()) - .filter(|(_, v)| v == user_id.as_bytes()) + .scan_prefix(prefix) + .filter(|(_, v)| &**v == user_id.as_bytes()) { - self.typingid_userid.remove(outdated_edu.0)?; + self.typingid_userid.remove(&outdated_edu.0)?; found_outdated = true; } @@ -247,10 +245,8 @@ fn typings_maintain( // Find all outdated edus before inserting a new one for outdated_edu in self .typingid_userid - .scan_prefix(&prefix) - .keys() - .map(|key| { - let key = key?; + .scan_prefix(prefix) + .map(|(key, _)| { Ok::<_, Error>(( key.clone(), utils::u64_from_bytes( @@ -265,7 +261,7 @@ fn typings_maintain( .take_while(|&(_, timestamp)| timestamp < current_timestamp) { // This is an outdated edu (time > timestamp) - self.typingid_userid.remove(outdated_edu.0)?; + self.typingid_userid.remove(&outdated_edu.0)?; found_outdated = true; } @@ -309,10 +305,9 @@ pub fn typings_all( for user_id in self .typingid_userid .scan_prefix(prefix) - .values() - .map(|user_id| { + .map(|(_, user_id)| { Ok::<_, Error>( - UserId::try_from(utils::string_from_bytes(&user_id?).map_err(|_| { + UserId::try_from(utils::string_from_bytes(&user_id).map_err(|_| { Error::bad_database("User ID in typingid_userid is invalid unicode.") })?) .map_err(|_| Error::bad_database("User ID in typingid_userid is invalid."))?, @@ -351,12 +346,12 @@ pub fn update_presence( presence_id.extend_from_slice(&presence.sender.as_bytes()); self.presenceid_presence.insert( - presence_id, - &*serde_json::to_string(&presence).expect("PresenceEvent can be serialized"), + &presence_id, + &serde_json::to_vec(&presence).expect("PresenceEvent can be serialized"), )?; self.userid_lastpresenceupdate.insert( - &user_id.as_bytes(), + user_id.as_bytes(), &utils::millis_since_unix_epoch().to_be_bytes(), )?; @@ -403,7 +398,7 @@ pub fn get_last_presence_event( presence_id.extend_from_slice(&user_id.as_bytes()); self.presenceid_presence - .get(presence_id)? + .get(&presence_id)? .map(|value| { let mut presence = serde_json::from_slice::<PresenceEvent>(&value) .map_err(|_| Error::bad_database("Invalid presence event in db."))?; @@ -438,7 +433,6 @@ pub fn presence_maintain( for (user_id_bytes, last_timestamp) in self .userid_lastpresenceupdate .iter() - .filter_map(|r| r.ok()) .filter_map(|(k, bytes)| { Some(( k, @@ -468,8 +462,8 @@ pub fn presence_maintain( presence_id.extend_from_slice(&user_id_bytes); self.presenceid_presence.insert( - presence_id, - &*serde_json::to_string(&PresenceEvent { + &presence_id, + &serde_json::to_vec(&PresenceEvent { content: PresenceEventContent { avatar_url: None, currently_active: None, @@ -515,8 +509,7 @@ pub fn presence_since( for (key, value) in self .presenceid_presence - .range(&*first_possible_edu..) - .filter_map(|r| r.ok()) + .iter_from(&*first_possible_edu, false) .take_while(|(key, _)| key.starts_with(&prefix)) { let user_id = UserId::try_from( diff --git a/src/database/sending.rs b/src/database/sending.rs index ed5b5ef872af9abbef4e17c817aefac95e3fb0be..ecf0761828ad938b14d3a2aed2d27ef0900af650 100644 --- a/src/database/sending.rs +++ b/src/database/sending.rs @@ -12,7 +12,10 @@ use federation::transactions::send_transaction_message; use log::{error, warn}; use ring::digest; -use rocket::futures::stream::{FuturesUnordered, StreamExt}; +use rocket::futures::{ + channel::mpsc, + stream::{FuturesUnordered, StreamExt}, +}; use ruma::{ api::{ appservice, @@ -27,9 +30,10 @@ receipt::ReceiptType, MilliSecondsSinceUnixEpoch, ServerName, UInt, UserId, }; -use sled::IVec; use tokio::{select, sync::Semaphore}; +use super::abstraction::Tree; + #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum OutgoingKind { Appservice(Box<ServerName>), @@ -70,13 +74,13 @@ pub enum SendingEventType { Edu(Vec<u8>), } -#[derive(Clone)] pub struct Sending { /// The state for a given state hash. - pub(super) servername_educount: sled::Tree, // EduCount: Count of last EDU sync - pub(super) servernamepduids: sled::Tree, // ServernamePduId = (+ / $)SenderKey / ServerName / UserId + PduId - pub(super) servercurrentevents: sled::Tree, // ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / (*)EduEvent + pub(super) servername_educount: Arc<dyn Tree>, // EduCount: Count of last EDU sync + pub(super) servernamepduids: Arc<dyn Tree>, // ServernamePduId = (+ / $)SenderKey / ServerName / UserId + PduId + pub(super) servercurrentevents: Arc<dyn Tree>, // ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / (*)EduEvent pub(super) maximum_requests: Arc<Semaphore>, + pub sender: mpsc::UnboundedSender<Vec<u8>>, } enum TransactionStatus { @@ -86,28 +90,23 @@ enum TransactionStatus { } impl Sending { - pub fn start_handler(&self, db: &Database) { - let servernamepduids = self.servernamepduids.clone(); - let servercurrentevents = self.servercurrentevents.clone(); - - let db = db.clone(); - + pub fn start_handler(&self, db: Arc<Database>, mut receiver: mpsc::UnboundedReceiver<Vec<u8>>) { tokio::spawn(async move { let mut futures = FuturesUnordered::new(); - // Retry requests we could not finish yet - let mut subscriber = servernamepduids.watch_prefix(b""); let mut current_transaction_status = HashMap::<Vec<u8>, TransactionStatus>::new(); + // Retry requests we could not finish yet let mut initial_transactions = HashMap::<OutgoingKind, Vec<SendingEventType>>::new(); - for (key, outgoing_kind, event) in servercurrentevents - .iter() - .filter_map(|r| r.ok()) - .filter_map(|(key, _)| { - Self::parse_servercurrentevent(&key) - .ok() - .map(|(k, e)| (key, k, e)) - }) + for (key, outgoing_kind, event) in + db.sending + .servercurrentevents + .iter() + .filter_map(|(key, _)| { + Self::parse_servercurrentevent(&key) + .ok() + .map(|(k, e)| (key, k, e)) + }) { let entry = initial_transactions .entry(outgoing_kind.clone()) @@ -118,7 +117,7 @@ pub fn start_handler(&self, db: &Database) { "Dropping some current events: {:?} {:?} {:?}", key, outgoing_kind, event ); - servercurrentevents.remove(key).unwrap(); + db.sending.servercurrentevents.remove(&key).unwrap(); continue; } @@ -137,20 +136,16 @@ pub fn start_handler(&self, db: &Database) { match response { Ok(outgoing_kind) => { let prefix = outgoing_kind.get_prefix(); - for key in servercurrentevents - .scan_prefix(&prefix) - .keys() - .filter_map(|r| r.ok()) + for (key, _) in db.sending.servercurrentevents + .scan_prefix(prefix.clone()) { - servercurrentevents.remove(key).unwrap(); + db.sending.servercurrentevents.remove(&key).unwrap(); } // Find events that have been added since starting the last request - let new_events = servernamepduids - .scan_prefix(&prefix) - .keys() - .filter_map(|r| r.ok()) - .map(|k| { + let new_events = db.sending.servernamepduids + .scan_prefix(prefix.clone()) + .map(|(k, _)| { SendingEventType::Pdu(k[prefix.len()..].to_vec()) }) .take(30) @@ -166,8 +161,8 @@ pub fn start_handler(&self, db: &Database) { SendingEventType::Pdu(b) | SendingEventType::Edu(b) => { current_key.extend_from_slice(&b); - servercurrentevents.insert(¤t_key, &[]).unwrap(); - servernamepduids.remove(¤t_key).unwrap(); + db.sending.servercurrentevents.insert(¤t_key, &[]).unwrap(); + db.sending.servernamepduids.remove(¤t_key).unwrap(); } } } @@ -195,18 +190,15 @@ pub fn start_handler(&self, db: &Database) { } }; }, - Some(event) = &mut subscriber => { - // New sled version: - //for (_tree, key, value_opt) in &event { - // if value_opt.is_none() { - // continue; - // } - - if let sled::Event::Insert { key, .. } = event { - if let Ok((outgoing_kind, event)) = Self::parse_servercurrentevent(&key) { - if let Some(events) = Self::select_events(&outgoing_kind, vec![(event, key)], &mut current_transaction_status, &servercurrentevents, &servernamepduids, &db) { - futures.push(Self::handle_events(outgoing_kind, events, &db)); - } + Some(key) = receiver.next() => { + if let Ok((outgoing_kind, event)) = Self::parse_servercurrentevent(&key) { + if let Ok(Some(events)) = Self::select_events( + &outgoing_kind, + vec![(event, key)], + &mut current_transaction_status, + &db + ) { + futures.push(Self::handle_events(outgoing_kind, events, &db)); } } } @@ -217,12 +209,10 @@ pub fn start_handler(&self, db: &Database) { fn select_events( outgoing_kind: &OutgoingKind, - new_events: Vec<(SendingEventType, IVec)>, // Events we want to send: event and full key + new_events: Vec<(SendingEventType, Vec<u8>)>, // Events we want to send: event and full key current_transaction_status: &mut HashMap<Vec<u8>, TransactionStatus>, - servercurrentevents: &sled::Tree, - servernamepduids: &sled::Tree, db: &Database, - ) -> Option<Vec<SendingEventType>> { + ) -> Result<Option<Vec<SendingEventType>>> { let mut retry = false; let mut allow = true; @@ -252,29 +242,25 @@ fn select_events( .or_insert(TransactionStatus::Running); if !allow { - return None; + return Ok(None); } let mut events = Vec::new(); if retry { // We retry the previous transaction - for key in servercurrentevents - .scan_prefix(&prefix) - .keys() - .filter_map(|r| r.ok()) - { + for (key, _) in db.sending.servercurrentevents.scan_prefix(prefix) { if let Ok((_, e)) = Self::parse_servercurrentevent(&key) { events.push(e); } } } else { for (e, full_key) in new_events { - servercurrentevents.insert(&full_key, &[]).unwrap(); + db.sending.servercurrentevents.insert(&full_key, &[])?; // If it was a PDU we have to unqueue it // TODO: don't try to unqueue EDUs - servernamepduids.remove(&full_key).unwrap(); + db.sending.servernamepduids.remove(&full_key)?; events.push(e); } @@ -284,13 +270,12 @@ fn select_events( events.extend_from_slice(&select_edus); db.sending .servername_educount - .insert(server_name.as_bytes(), &last_count.to_be_bytes()) - .unwrap(); + .insert(server_name.as_bytes(), &last_count.to_be_bytes())?; } } } - Some(events) + Ok(Some(events)) } pub fn select_edus(db: &Database, server: &ServerName) -> Result<(Vec<SendingEventType>, u64)> { @@ -307,7 +292,7 @@ pub fn select_edus(db: &Database, server: &ServerName) -> Result<(Vec<SendingEve let mut max_edu_count = since; 'outer: for room_id in db.rooms.server_rooms(server) { let room_id = room_id?; - for r in db.rooms.edus.readreceipts_since(&room_id, since)? { + for r in db.rooms.edus.readreceipts_since(&room_id, since) { let (user_id, count, read_receipt) = r?; if count > max_edu_count { @@ -372,12 +357,13 @@ pub fn select_edus(db: &Database, server: &ServerName) -> Result<(Vec<SendingEve } #[tracing::instrument(skip(self))] - pub fn send_push_pdu(&self, pdu_id: &[u8], senderkey: IVec) -> Result<()> { + pub fn send_push_pdu(&self, pdu_id: &[u8], senderkey: Box<[u8]>) -> Result<()> { let mut key = b"$".to_vec(); key.extend_from_slice(&senderkey); key.push(0xff); key.extend_from_slice(pdu_id); - self.servernamepduids.insert(key, b"")?; + self.servernamepduids.insert(&key, b"")?; + self.sender.unbounded_send(key).unwrap(); Ok(()) } @@ -387,7 +373,8 @@ pub fn send_pdu(&self, server: &ServerName, pdu_id: &[u8]) -> Result<()> { let mut key = server.as_bytes().to_vec(); key.push(0xff); key.extend_from_slice(pdu_id); - self.servernamepduids.insert(key, b"")?; + self.servernamepduids.insert(&key, b"")?; + self.sender.unbounded_send(key).unwrap(); Ok(()) } @@ -398,7 +385,8 @@ pub fn send_pdu_appservice(&self, appservice_id: &str, pdu_id: &[u8]) -> Result< key.extend_from_slice(appservice_id.as_bytes()); key.push(0xff); key.extend_from_slice(pdu_id); - self.servernamepduids.insert(key, b"")?; + self.servernamepduids.insert(&key, b"")?; + self.sender.unbounded_send(key).unwrap(); Ok(()) } @@ -641,7 +629,7 @@ async fn handle_events( } } - fn parse_servercurrentevent(key: &IVec) -> Result<(OutgoingKind, SendingEventType)> { + fn parse_servercurrentevent(key: &[u8]) -> Result<(OutgoingKind, SendingEventType)> { // Appservices start with a plus Ok::<_, Error>(if key.starts_with(b"+") { let mut parts = key[1..].splitn(2, |&b| b == 0xff); diff --git a/src/database/transaction_ids.rs b/src/database/transaction_ids.rs index 1f8ba7de128c2e74044f254562819688f28aab5a..3e3777985e844d322b9d32c016cc43450a6b2a54 100644 --- a/src/database/transaction_ids.rs +++ b/src/database/transaction_ids.rs @@ -1,10 +1,12 @@ +use std::sync::Arc; + use crate::Result; use ruma::{DeviceId, UserId}; -use sled::IVec; -#[derive(Clone)] +use super::abstraction::Tree; + pub struct TransactionIds { - pub(super) userdevicetxnid_response: sled::Tree, // Response can be empty (/sendToDevice) or the event id (/send) + pub(super) userdevicetxnid_response: Arc<dyn Tree>, // Response can be empty (/sendToDevice) or the event id (/send) } impl TransactionIds { @@ -21,7 +23,7 @@ pub fn add_txnid( key.push(0xff); key.extend_from_slice(txn_id.as_bytes()); - self.userdevicetxnid_response.insert(key, data)?; + self.userdevicetxnid_response.insert(&key, data)?; Ok(()) } @@ -31,7 +33,7 @@ pub fn existing_txnid( user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &str, - ) -> Result<Option<IVec>> { + ) -> Result<Option<Vec<u8>>> { let mut key = user_id.as_bytes().to_vec(); key.push(0xff); key.extend_from_slice(device_id.map(|d| d.as_bytes()).unwrap_or_default()); @@ -39,6 +41,6 @@ pub fn existing_txnid( key.extend_from_slice(txn_id.as_bytes()); // If there's no entry, this is a new transaction - Ok(self.userdevicetxnid_response.get(key)?) + Ok(self.userdevicetxnid_response.get(&key)?) } } diff --git a/src/database/uiaa.rs b/src/database/uiaa.rs index 3b778402ce90479e1e70ece7173d1177f5e6e218..f7f3d1f85cd4825fb5da5a5620b11c8aa1f26972 100644 --- a/src/database/uiaa.rs +++ b/src/database/uiaa.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result}; use ruma::{ api::client::{ @@ -8,10 +10,11 @@ DeviceId, UserId, }; -#[derive(Clone)] +use super::abstraction::Tree; + pub struct Uiaa { - pub(super) userdevicesessionid_uiaainfo: sled::Tree, // User-interactive authentication - pub(super) userdevicesessionid_uiaarequest: sled::Tree, // UiaaRequest = canonical json value + pub(super) userdevicesessionid_uiaainfo: Arc<dyn Tree>, // User-interactive authentication + pub(super) userdevicesessionid_uiaarequest: Arc<dyn Tree>, // UiaaRequest = canonical json value } impl Uiaa { @@ -185,7 +188,7 @@ fn set_uiaa_request( self.userdevicesessionid_uiaarequest.insert( &userdevicesessionid, - &*serde_json::to_string(request).expect("json value to string always works"), + &serde_json::to_vec(request).expect("json value to vec always works"), )?; Ok(()) @@ -233,7 +236,7 @@ fn update_uiaa_session( if let Some(uiaainfo) = uiaainfo { self.userdevicesessionid_uiaainfo.insert( &userdevicesessionid, - &*serde_json::to_string(&uiaainfo).expect("UiaaInfo::to_string always works"), + &serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), )?; } else { self.userdevicesessionid_uiaainfo diff --git a/src/database/users.rs b/src/database/users.rs index 52e6e33bdd86fb6e4e319a019ea9cb7dda943d13..b6d3b3c4af1bdd130e8184291319fba650aaa821 100644 --- a/src/database/users.rs +++ b/src/database/users.rs @@ -7,40 +7,41 @@ serde::Raw, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, UInt, UserId, }; -use std::{collections::BTreeMap, convert::TryFrom, mem}; +use std::{collections::BTreeMap, convert::TryFrom, mem, sync::Arc}; + +use super::abstraction::Tree; -#[derive(Clone)] pub struct Users { - pub(super) userid_password: sled::Tree, - pub(super) userid_displayname: sled::Tree, - pub(super) userid_avatarurl: sled::Tree, - pub(super) userdeviceid_token: sled::Tree, - pub(super) userdeviceid_metadata: sled::Tree, // This is also used to check if a device exists - pub(super) userid_devicelistversion: sled::Tree, // DevicelistVersion = u64 - pub(super) token_userdeviceid: sled::Tree, - - pub(super) onetimekeyid_onetimekeys: sled::Tree, // OneTimeKeyId = UserId + DeviceKeyId - pub(super) userid_lastonetimekeyupdate: sled::Tree, // LastOneTimeKeyUpdate = Count - pub(super) keychangeid_userid: sled::Tree, // KeyChangeId = UserId/RoomId + Count - pub(super) keyid_key: sled::Tree, // KeyId = UserId + KeyId (depends on key type) - pub(super) userid_masterkeyid: sled::Tree, - pub(super) userid_selfsigningkeyid: sled::Tree, - pub(super) userid_usersigningkeyid: sled::Tree, - - pub(super) todeviceid_events: sled::Tree, // ToDeviceId = UserId + DeviceId + Count + pub(super) userid_password: Arc<dyn Tree>, + pub(super) userid_displayname: Arc<dyn Tree>, + pub(super) userid_avatarurl: Arc<dyn Tree>, + pub(super) userdeviceid_token: Arc<dyn Tree>, + pub(super) userdeviceid_metadata: Arc<dyn Tree>, // This is also used to check if a device exists + pub(super) userid_devicelistversion: Arc<dyn Tree>, // DevicelistVersion = u64 + pub(super) token_userdeviceid: Arc<dyn Tree>, + + pub(super) onetimekeyid_onetimekeys: Arc<dyn Tree>, // OneTimeKeyId = UserId + DeviceKeyId + pub(super) userid_lastonetimekeyupdate: Arc<dyn Tree>, // LastOneTimeKeyUpdate = Count + pub(super) keychangeid_userid: Arc<dyn Tree>, // KeyChangeId = UserId/RoomId + Count + pub(super) keyid_key: Arc<dyn Tree>, // KeyId = UserId + KeyId (depends on key type) + pub(super) userid_masterkeyid: Arc<dyn Tree>, + pub(super) userid_selfsigningkeyid: Arc<dyn Tree>, + pub(super) userid_usersigningkeyid: Arc<dyn Tree>, + + pub(super) todeviceid_events: Arc<dyn Tree>, // ToDeviceId = UserId + DeviceId + Count } impl Users { /// Check if a user has an account on this homeserver. pub fn exists(&self, user_id: &UserId) -> Result<bool> { - Ok(self.userid_password.contains_key(user_id.to_string())?) + Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) } /// Check if account is deactivated pub fn is_deactivated(&self, user_id: &UserId) -> Result<bool> { Ok(self .userid_password - .get(user_id.to_string())? + .get(user_id.as_bytes())? .ok_or(Error::BadRequest( ErrorKind::InvalidParam, "User does not exist.", @@ -55,14 +56,14 @@ pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { } /// Returns the number of users registered on this server. - pub fn count(&self) -> usize { - self.userid_password.iter().count() + pub fn count(&self) -> Result<usize> { + Ok(self.userid_password.iter().count()) } /// Find out which user an access token belongs to. pub fn find_from_token(&self, token: &str) -> Result<Option<(UserId, String)>> { self.token_userdeviceid - .get(token)? + .get(token.as_bytes())? .map_or(Ok(None), |bytes| { let mut parts = bytes.split(|&b| b == 0xff); let user_bytes = parts.next().ok_or_else(|| { @@ -87,10 +88,10 @@ pub fn find_from_token(&self, token: &str) -> Result<Option<(UserId, String)>> { } /// Returns an iterator over all users on this homeserver. - pub fn iter(&self) -> impl Iterator<Item = Result<UserId>> { - self.userid_password.iter().keys().map(|bytes| { + pub fn iter<'a>(&'a self) -> impl Iterator<Item = Result<UserId>> + 'a { + self.userid_password.iter().map(|(bytes, _)| { Ok( - UserId::try_from(utils::string_from_bytes(&bytes?).map_err(|_| { + UserId::try_from(utils::string_from_bytes(&bytes).map_err(|_| { Error::bad_database("User ID in userid_password is invalid unicode.") })?) .map_err(|_| Error::bad_database("User ID in userid_password is invalid."))?, @@ -101,7 +102,7 @@ pub fn iter(&self) -> impl Iterator<Item = Result<UserId>> { /// Returns the password hash for the given user. pub fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> { self.userid_password - .get(user_id.to_string())? + .get(user_id.as_bytes())? .map_or(Ok(None), |bytes| { Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { Error::bad_database("Password hash in db is not valid string.") @@ -113,7 +114,8 @@ pub fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> { pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { if let Some(password) = password { if let Ok(hash) = utils::calculate_hash(&password) { - self.userid_password.insert(user_id.to_string(), &*hash)?; + self.userid_password + .insert(user_id.as_bytes(), hash.as_bytes())?; Ok(()) } else { Err(Error::BadRequest( @@ -122,7 +124,7 @@ pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<( )) } } else { - self.userid_password.insert(user_id.to_string(), "")?; + self.userid_password.insert(user_id.as_bytes(), b"")?; Ok(()) } } @@ -130,7 +132,7 @@ pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<( /// Returns the displayname of a user on this homeserver. pub fn displayname(&self, user_id: &UserId) -> Result<Option<String>> { self.userid_displayname - .get(user_id.to_string())? + .get(user_id.as_bytes())? .map_or(Ok(None), |bytes| { Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { Error::bad_database("Displayname in db is invalid.") @@ -142,9 +144,9 @@ pub fn displayname(&self, user_id: &UserId) -> Result<Option<String>> { pub fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> { if let Some(displayname) = displayname { self.userid_displayname - .insert(user_id.to_string(), &*displayname)?; + .insert(user_id.as_bytes(), displayname.as_bytes())?; } else { - self.userid_displayname.remove(user_id.to_string())?; + self.userid_displayname.remove(user_id.as_bytes())?; } Ok(()) @@ -153,7 +155,7 @@ pub fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> /// Get a the avatar_url of a user. pub fn avatar_url(&self, user_id: &UserId) -> Result<Option<MxcUri>> { self.userid_avatarurl - .get(user_id.to_string())? + .get(user_id.as_bytes())? .map(|bytes| { let s = utils::string_from_bytes(&bytes) .map_err(|_| Error::bad_database("Avatar URL in db is invalid."))?; @@ -166,9 +168,9 @@ pub fn avatar_url(&self, user_id: &UserId) -> Result<Option<MxcUri>> { pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<MxcUri>) -> Result<()> { if let Some(avatar_url) = avatar_url { self.userid_avatarurl - .insert(user_id.to_string(), avatar_url.to_string().as_str())?; + .insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?; } else { - self.userid_avatarurl.remove(user_id.to_string())?; + self.userid_avatarurl.remove(user_id.as_bytes())?; } Ok(()) @@ -190,19 +192,17 @@ pub fn create_device( userdeviceid.extend_from_slice(device_id.as_bytes()); self.userid_devicelistversion - .update_and_fetch(&user_id.as_bytes(), utils::increment)? - .expect("utils::increment will always put in a value"); + .increment(user_id.as_bytes())?; self.userdeviceid_metadata.insert( - userdeviceid, - serde_json::to_string(&Device { + &userdeviceid, + &serde_json::to_vec(&Device { device_id: device_id.into(), display_name: initial_device_display_name, last_seen_ip: None, // TODO last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()), }) - .expect("Device::to_string never fails.") - .as_bytes(), + .expect("Device::to_string never fails."), )?; self.set_token(user_id, &device_id, token)?; @@ -217,7 +217,8 @@ pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<() userdeviceid.extend_from_slice(device_id.as_bytes()); // Remove tokens - if let Some(old_token) = self.userdeviceid_token.remove(&userdeviceid)? { + if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { + self.userdeviceid_token.remove(&userdeviceid)?; self.token_userdeviceid.remove(&old_token)?; } @@ -225,15 +226,14 @@ pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<() let mut prefix = userdeviceid.clone(); prefix.push(0xff); - for key in self.todeviceid_events.scan_prefix(&prefix).keys() { - self.todeviceid_events.remove(key?)?; + for (key, _) in self.todeviceid_events.scan_prefix(prefix) { + self.todeviceid_events.remove(&key)?; } // TODO: Remove onetimekeys self.userid_devicelistversion - .update_and_fetch(&user_id.as_bytes(), utils::increment)? - .expect("utils::increment will always put in a value"); + .increment(user_id.as_bytes())?; self.userdeviceid_metadata.remove(&userdeviceid)?; @@ -241,16 +241,18 @@ pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<() } /// Returns an iterator over all device ids of this user. - pub fn all_device_ids(&self, user_id: &UserId) -> impl Iterator<Item = Result<Box<DeviceId>>> { + pub fn all_device_ids<'a>( + &'a self, + user_id: &UserId, + ) -> impl Iterator<Item = Result<Box<DeviceId>>> + 'a { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xff); // All devices have metadata self.userdeviceid_metadata .scan_prefix(prefix) - .keys() - .map(|bytes| { + .map(|(bytes, _)| { Ok(utils::string_from_bytes( - &*bytes? + &bytes .rsplit(|&b| b == 0xff) .next() .ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?, @@ -271,13 +273,15 @@ pub fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> // Remove old token if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { - self.token_userdeviceid.remove(old_token)?; + self.token_userdeviceid.remove(&old_token)?; // It will be removed from userdeviceid_token by the insert later } // Assign token to user device combination - self.userdeviceid_token.insert(&userdeviceid, &*token)?; - self.token_userdeviceid.insert(token, userdeviceid)?; + self.userdeviceid_token + .insert(&userdeviceid, token.as_bytes())?; + self.token_userdeviceid + .insert(token.as_bytes(), &userdeviceid)?; Ok(()) } @@ -309,8 +313,7 @@ pub fn add_one_time_key( self.onetimekeyid_onetimekeys.insert( &key, - &*serde_json::to_string(&one_time_key_value) - .expect("OneTimeKey::to_string always works"), + &serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"), )?; self.userid_lastonetimekeyupdate @@ -350,10 +353,9 @@ pub fn take_one_time_key( .insert(&user_id.as_bytes(), &globals.next_count()?.to_be_bytes())?; self.onetimekeyid_onetimekeys - .scan_prefix(&prefix) + .scan_prefix(prefix) .next() - .map(|r| { - let (key, value) = r?; + .map(|(key, value)| { self.onetimekeyid_onetimekeys.remove(&key)?; Ok(( @@ -383,21 +385,20 @@ pub fn count_one_time_keys( let mut counts = BTreeMap::new(); - for algorithm in self - .onetimekeyid_onetimekeys - .scan_prefix(&userdeviceid) - .keys() - .map(|bytes| { - Ok::<_, Error>( - serde_json::from_slice::<DeviceKeyId>( - &*bytes?.rsplit(|&b| b == 0xff).next().ok_or_else(|| { - Error::bad_database("OneTimeKey ID in db is invalid.") - })?, + for algorithm in + self.onetimekeyid_onetimekeys + .scan_prefix(userdeviceid) + .map(|(bytes, _)| { + Ok::<_, Error>( + serde_json::from_slice::<DeviceKeyId>( + &*bytes.rsplit(|&b| b == 0xff).next().ok_or_else(|| { + Error::bad_database("OneTimeKey ID in db is invalid.") + })?, + ) + .map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))? + .algorithm(), ) - .map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))? - .algorithm(), - ) - }) + }) { *counts.entry(algorithm?).or_default() += UInt::from(1_u32); } @@ -419,7 +420,7 @@ pub fn add_device_keys( self.keyid_key.insert( &userdeviceid, - &*serde_json::to_string(&device_keys).expect("DeviceKeys::to_string always works"), + &serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"), )?; self.mark_device_key_update(user_id, rooms, globals)?; @@ -460,11 +461,11 @@ pub fn add_cross_signing_keys( self.keyid_key.insert( &master_key_key, - &*serde_json::to_string(&master_key).expect("CrossSigningKey::to_string always works"), + &serde_json::to_vec(&master_key).expect("CrossSigningKey::to_vec always works"), )?; self.userid_masterkeyid - .insert(&*user_id.to_string(), master_key_key)?; + .insert(user_id.as_bytes(), &master_key_key)?; // Self-signing key if let Some(self_signing_key) = self_signing_key { @@ -486,12 +487,12 @@ pub fn add_cross_signing_keys( self.keyid_key.insert( &self_signing_key_key, - &*serde_json::to_string(&self_signing_key) - .expect("CrossSigningKey::to_string always works"), + &serde_json::to_vec(&self_signing_key) + .expect("CrossSigningKey::to_vec always works"), )?; self.userid_selfsigningkeyid - .insert(&*user_id.to_string(), self_signing_key_key)?; + .insert(user_id.as_bytes(), &self_signing_key_key)?; } // User-signing key @@ -514,12 +515,12 @@ pub fn add_cross_signing_keys( self.keyid_key.insert( &user_signing_key_key, - &*serde_json::to_string(&user_signing_key) - .expect("CrossSigningKey::to_string always works"), + &serde_json::to_vec(&user_signing_key) + .expect("CrossSigningKey::to_vec always works"), )?; self.userid_usersigningkeyid - .insert(&*user_id.to_string(), user_signing_key_key)?; + .insert(user_id.as_bytes(), &user_signing_key_key)?; } self.mark_device_key_update(user_id, rooms, globals)?; @@ -561,8 +562,7 @@ pub fn sign_key( self.keyid_key.insert( &key, - &*serde_json::to_string(&cross_signing_key) - .expect("CrossSigningKey::to_string always works"), + &serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"), )?; // TODO: Should we notify about this change? @@ -572,24 +572,20 @@ pub fn sign_key( } #[tracing::instrument(skip(self))] - pub fn keys_changed( - &self, + pub fn keys_changed<'a>( + &'a self, user_or_room_id: &str, from: u64, to: Option<u64>, - ) -> impl Iterator<Item = Result<UserId>> { + ) -> impl Iterator<Item = Result<UserId>> + 'a { let mut prefix = user_or_room_id.as_bytes().to_vec(); prefix.push(0xff); let mut start = prefix.clone(); start.extend_from_slice(&(from + 1).to_be_bytes()); - let mut end = prefix.clone(); - end.extend_from_slice(&to.unwrap_or(u64::MAX).to_be_bytes()); - self.keychangeid_userid - .range(start..end) - .filter_map(|r| r.ok()) + .iter_from(&start, false) .take_while(move |(k, _)| k.starts_with(&prefix)) .map(|(_, bytes)| { Ok( @@ -625,13 +621,13 @@ fn mark_device_key_update( key.push(0xff); key.extend_from_slice(&count); - self.keychangeid_userid.insert(key, &*user_id.to_string())?; + self.keychangeid_userid.insert(&key, user_id.as_bytes())?; } let mut key = user_id.as_bytes().to_vec(); key.push(0xff); key.extend_from_slice(&count); - self.keychangeid_userid.insert(key, &*user_id.to_string())?; + self.keychangeid_userid.insert(&key, user_id.as_bytes())?; Ok(()) } @@ -645,7 +641,7 @@ pub fn get_device_keys( key.push(0xff); key.extend_from_slice(device_id.as_bytes()); - self.keyid_key.get(key)?.map_or(Ok(None), |bytes| { + self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { Error::bad_database("DeviceKeys in db are invalid.") })?)) @@ -658,9 +654,9 @@ pub fn get_master_key<F: Fn(&UserId) -> bool>( allowed_signatures: F, ) -> Result<Option<CrossSigningKey>> { self.userid_masterkeyid - .get(user_id.to_string())? + .get(user_id.as_bytes())? .map_or(Ok(None), |key| { - self.keyid_key.get(key)?.map_or(Ok(None), |bytes| { + self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { let mut cross_signing_key = serde_json::from_slice::<CrossSigningKey>(&bytes) .map_err(|_| { Error::bad_database("CrossSigningKey in db is invalid.") @@ -685,9 +681,9 @@ pub fn get_self_signing_key<F: Fn(&UserId) -> bool>( allowed_signatures: F, ) -> Result<Option<CrossSigningKey>> { self.userid_selfsigningkeyid - .get(user_id.to_string())? + .get(user_id.as_bytes())? .map_or(Ok(None), |key| { - self.keyid_key.get(key)?.map_or(Ok(None), |bytes| { + self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { let mut cross_signing_key = serde_json::from_slice::<CrossSigningKey>(&bytes) .map_err(|_| { Error::bad_database("CrossSigningKey in db is invalid.") @@ -708,9 +704,9 @@ pub fn get_self_signing_key<F: Fn(&UserId) -> bool>( pub fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<CrossSigningKey>> { self.userid_usersigningkeyid - .get(user_id.to_string())? + .get(user_id.as_bytes())? .map_or(Ok(None), |key| { - self.keyid_key.get(key)?.map_or(Ok(None), |bytes| { + self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { Error::bad_database("CrossSigningKey in db is invalid.") })?)) @@ -740,7 +736,7 @@ pub fn add_to_device_event( self.todeviceid_events.insert( &key, - &*serde_json::to_string(&json).expect("Map::to_string always works"), + &serde_json::to_vec(&json).expect("Map::to_vec always works"), )?; Ok(()) @@ -759,9 +755,9 @@ pub fn get_to_device_events( prefix.extend_from_slice(device_id.as_bytes()); prefix.push(0xff); - for value in self.todeviceid_events.scan_prefix(&prefix).values() { + for (_, value) in self.todeviceid_events.scan_prefix(prefix) { events.push( - serde_json::from_slice(&*value?) + serde_json::from_slice(&value) .map_err(|_| Error::bad_database("Event in todeviceid_events is invalid."))?, ); } @@ -786,10 +782,9 @@ pub fn remove_to_device_events( for (key, _) in self .todeviceid_events - .range(&*prefix..=&*last) - .keys() - .map(|key| { - let key = key?; + .iter_from(&last, true) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(|(key, _)| { Ok::<_, Error>(( key.clone(), utils::u64_from_bytes(&key[key.len() - mem::size_of::<u64>()..key.len()]) @@ -799,7 +794,7 @@ pub fn remove_to_device_events( .filter_map(|r| r.ok()) .take_while(|&(_, count)| count <= until) { - self.todeviceid_events.remove(key)?; + self.todeviceid_events.remove(&key)?; } Ok(()) @@ -819,14 +814,11 @@ pub fn update_device_metadata( assert!(self.userdeviceid_metadata.get(&userdeviceid)?.is_some()); self.userid_devicelistversion - .update_and_fetch(&user_id.as_bytes(), utils::increment)? - .expect("utils::increment will always put in a value"); + .increment(user_id.as_bytes())?; self.userdeviceid_metadata.insert( - userdeviceid, - serde_json::to_string(device) - .expect("Device::to_string always works") - .as_bytes(), + &userdeviceid, + &serde_json::to_vec(device).expect("Device::to_string always works"), )?; Ok(()) @@ -861,15 +853,17 @@ pub fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> { }) } - pub fn all_devices_metadata(&self, user_id: &UserId) -> impl Iterator<Item = Result<Device>> { + pub fn all_devices_metadata<'a>( + &'a self, + user_id: &UserId, + ) -> impl Iterator<Item = Result<Device>> + 'a { let mut key = user_id.as_bytes().to_vec(); key.push(0xff); self.userdeviceid_metadata .scan_prefix(key) - .values() - .map(|bytes| { - Ok(serde_json::from_slice::<Device>(&bytes?).map_err(|_| { + .map(|(_, bytes)| { + Ok(serde_json::from_slice::<Device>(&bytes).map_err(|_| { Error::bad_database("Device in userdeviceid_metadata is invalid.") })?) }) @@ -885,7 +879,7 @@ pub fn deactivate_account(&self, user_id: &UserId) -> Result<()> { // Set the password to "" to indicate a deactivated account. Hashes will never result in an // empty string, so the user will not be able to log in again. Systems like changing the // password without logging in should check if the account is deactivated. - self.userid_password.insert(user_id.to_string(), "")?; + self.userid_password.insert(user_id.as_bytes(), &[])?; // TODO: Unhook 3PID Ok(()) diff --git a/src/error.rs b/src/error.rs index e2664e21619bcaeae6e675a9642138616e953868..4f363fff654075e284f57a32414ecd71c57b7dff 100644 --- a/src/error.rs +++ b/src/error.rs @@ -23,11 +23,18 @@ #[derive(Error, Debug)] pub enum Error { - #[error("There was a problem with the connection to the database.")] + #[cfg(feature = "sled")] + #[error("There was a problem with the connection to the sled database.")] SledError { #[from] source: sled::Error, }, + #[cfg(feature = "rocksdb")] + #[error("There was a problem with the connection to the rocksdb database: {source}")] + RocksDbError { + #[from] + source: rocksdb::Error, + }, #[error("Could not generate an image.")] ImageError { #[from] @@ -40,6 +47,11 @@ pub enum Error { }, #[error("{0}")] FederationError(Box<ServerName>, RumaError), + #[error("Could not do this io: {source}")] + IoError { + #[from] + source: std::io::Error, + }, #[error("{0}")] BadServerResponse(&'static str), #[error("{0}")] diff --git a/src/main.rs b/src/main.rs index e76cea4e159e5cccdf976e15b5be3c571be11f25..8b63d1d6066c4edad8c156bd184902ffdf5bbeb7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,6 +12,8 @@ mod ruma_wrapper; mod utils; +use std::sync::Arc; + use database::Config; pub use database::Database; pub use error::{Error, Result}; @@ -31,7 +33,7 @@ use tracing::span; use tracing_subscriber::{prelude::*, Registry}; -fn setup_rocket(config: Figment, data: Database) -> rocket::Rocket<rocket::Build> { +fn setup_rocket(config: Figment, data: Arc<Database>) -> rocket::Rocket<rocket::Build> { rocket::custom(config) .manage(data) .mount( @@ -197,8 +199,6 @@ async fn main() { .await .expect("config is valid"); - db.sending.start_handler(&db); - if config.allow_jaeger { let (tracer, _uninstall) = opentelemetry_jaeger::new_pipeline() .with_service_name("conduit") diff --git a/src/ruma_wrapper.rs b/src/ruma_wrapper.rs index 147df3c3596f862f5af39e95e2c765043c56ef83..2912a578ac5447ffd8b3fc003d375a0c7e032be1 100644 --- a/src/ruma_wrapper.rs +++ b/src/ruma_wrapper.rs @@ -2,13 +2,14 @@ use ruma::{ api::OutgoingResponse, identifiers::{DeviceId, UserId}, - Outgoing, + signatures::CanonicalJsonValue, + Outgoing, ServerName, }; use std::ops::Deref; #[cfg(feature = "conduit_bin")] use { - crate::server_server, + crate::{server_server, Database}, log::{debug, warn}, rocket::{ data::{self, ByteUnit, Data, FromData}, @@ -18,14 +19,11 @@ tokio::io::AsyncReadExt, Request, State, }, - ruma::{ - api::{AuthScheme, IncomingRequest}, - signatures::CanonicalJsonValue, - ServerName, - }, + ruma::api::{AuthScheme, IncomingRequest}, std::collections::BTreeMap, std::convert::TryFrom, std::io::Cursor, + std::sync::Arc, }; /// This struct converts rocket requests into ruma structs by converting them into http requests @@ -51,7 +49,7 @@ impl<'a, T: Outgoing> FromData<'a> for Ruma<T> async fn from_data(request: &'a Request<'_>, data: Data) -> data::Outcome<Self, Self::Error> { let metadata = T::Incoming::METADATA; let db = request - .guard::<State<'_, crate::Database>>() + .guard::<State<'_, Arc<Database>>>() .await .expect("database was loaded"); @@ -75,6 +73,7 @@ async fn from_data(request: &'a Request<'_>, data: Data) -> data::Outcome<Self, )) = db .appservice .iter_all() + .unwrap() .filter_map(|r| r.ok()) .find(|(_id, registration)| { registration diff --git a/src/server_server.rs b/src/server_server.rs index b405c1abd40eebae15b82c61e9b7829f3288678b..2a445c2ba219b5957eb348ed075aab9b35336cf1 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -433,7 +433,7 @@ pub async fn request_well_known( #[cfg_attr(feature = "conduit_bin", get("/_matrix/federation/v1/version"))] #[tracing::instrument(skip(db))] pub fn get_server_version_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, ) -> ConduitResult<get_server_version::v1::Response> { if !db.globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); @@ -451,7 +451,7 @@ pub fn get_server_version_route( // Response type for this endpoint is Json because we need to calculate a signature for the response #[cfg_attr(feature = "conduit_bin", get("/_matrix/key/v2/server"))] #[tracing::instrument(skip(db))] -pub fn get_server_keys_route(db: State<'_, Database>) -> Json<String> { +pub fn get_server_keys_route(db: State<'_, Arc<Database>>) -> Json<String> { if !db.globals.allow_federation() { // TODO: Use proper types return Json("Federation is disabled.".to_owned()); @@ -498,7 +498,7 @@ pub fn get_server_keys_route(db: State<'_, Database>) -> Json<String> { #[cfg_attr(feature = "conduit_bin", get("/_matrix/key/v2/server/<_>"))] #[tracing::instrument(skip(db))] -pub fn get_server_keys_deprecated_route(db: State<'_, Database>) -> Json<String> { +pub fn get_server_keys_deprecated_route(db: State<'_, Arc<Database>>) -> Json<String> { get_server_keys_route(db) } @@ -508,7 +508,7 @@ pub fn get_server_keys_deprecated_route(db: State<'_, Database>) -> Json<String> )] #[tracing::instrument(skip(db, body))] pub async fn get_public_rooms_filtered_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_public_rooms_filtered::v1::Request<'_>>, ) -> ConduitResult<get_public_rooms_filtered::v1::Response> { if !db.globals.allow_federation() { @@ -556,7 +556,7 @@ pub async fn get_public_rooms_filtered_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_public_rooms_route( - db: State<'_, Database>, + db: State<'_, Arc<Database>>, body: Ruma<get_public_rooms::v1::Request<'_>>, ) -> ConduitResult<get_public_rooms::v1::Response> { if !db.globals.allow_federation() { @@ -603,8 +603,8 @@ pub async fn get_public_rooms_route( put("/_matrix/federation/v1/send/<_>", data = "<body>") )] #[tracing::instrument(skip(db, body))] -pub async fn send_transaction_message_route<'a>( - db: State<'a, Database>, +pub async fn send_transaction_message_route( + db: State<'_, Arc<Database>>, body: Ruma<send_transaction_message::v1::Request<'_>>, ) -> ConduitResult<send_transaction_message::v1::Response> { if !db.globals.allow_federation() { @@ -1585,7 +1585,7 @@ pub(crate) async fn fetch_signing_keys( .await { db.globals - .add_signing_key(origin, &get_keys_response.server_key)?; + .add_signing_key(origin, get_keys_response.server_key.clone())?; result.extend( get_keys_response @@ -1628,7 +1628,7 @@ pub(crate) async fn fetch_signing_keys( { trace!("Got signing keys: {:?}", keys); for k in keys.server_keys { - db.globals.add_signing_key(origin, &k)?; + db.globals.add_signing_key(origin, k.clone())?; result.extend( k.verify_keys .into_iter() @@ -1681,12 +1681,12 @@ pub(crate) fn append_incoming_pdu( pdu, pdu_json, count, - pdu_id.clone().into(), + &pdu_id, &new_room_leaves.into_iter().collect::<Vec<_>>(), &db, )?; - for appservice in db.appservice.iter_all().filter_map(|r| r.ok()) { + for appservice in db.appservice.iter_all()?.filter_map(|r| r.ok()) { if let Some(namespaces) = appservice.1.get("namespaces") { let users = namespaces .get("users") @@ -1758,8 +1758,8 @@ pub(crate) fn append_incoming_pdu( get("/_matrix/federation/v1/event/<_>", data = "<body>") )] #[tracing::instrument(skip(db, body))] -pub fn get_event_route<'a>( - db: State<'a, Database>, +pub fn get_event_route( + db: State<'_, Arc<Database>>, body: Ruma<get_event::v1::Request<'_>>, ) -> ConduitResult<get_event::v1::Response> { if !db.globals.allow_federation() { @@ -1783,8 +1783,8 @@ pub fn get_event_route<'a>( post("/_matrix/federation/v1/get_missing_events/<_>", data = "<body>") )] #[tracing::instrument(skip(db, body))] -pub fn get_missing_events_route<'a>( - db: State<'a, Database>, +pub fn get_missing_events_route( + db: State<'_, Arc<Database>>, body: Ruma<get_missing_events::v1::Request<'_>>, ) -> ConduitResult<get_missing_events::v1::Response> { if !db.globals.allow_federation() { @@ -1832,8 +1832,8 @@ pub fn get_missing_events_route<'a>( get("/_matrix/federation/v1/state_ids/<_>", data = "<body>") )] #[tracing::instrument(skip(db, body))] -pub fn get_room_state_ids_route<'a>( - db: State<'a, Database>, +pub fn get_room_state_ids_route( + db: State<'_, Arc<Database>>, body: Ruma<get_room_state_ids::v1::Request<'_>>, ) -> ConduitResult<get_room_state_ids::v1::Response> { if !db.globals.allow_federation() { @@ -1884,8 +1884,8 @@ pub fn get_room_state_ids_route<'a>( get("/_matrix/federation/v1/make_join/<_>/<_>", data = "<body>") )] #[tracing::instrument(skip(db, body))] -pub fn create_join_event_template_route<'a>( - db: State<'a, Database>, +pub fn create_join_event_template_route( + db: State<'_, Arc<Database>>, body: Ruma<create_join_event_template::v1::Request<'_>>, ) -> ConduitResult<create_join_event_template::v1::Response> { if !db.globals.allow_federation() { @@ -2055,8 +2055,8 @@ pub fn create_join_event_template_route<'a>( put("/_matrix/federation/v2/send_join/<_>/<_>", data = "<body>") )] #[tracing::instrument(skip(db, body))] -pub async fn create_join_event_route<'a>( - db: State<'a, Database>, +pub async fn create_join_event_route( + db: State<'_, Arc<Database>>, body: Ruma<create_join_event::v2::Request<'_>>, ) -> ConduitResult<create_join_event::v2::Response> { if !db.globals.allow_federation() { @@ -2171,8 +2171,8 @@ pub async fn create_join_event_route<'a>( put("/_matrix/federation/v2/invite/<_>/<_>", data = "<body>") )] #[tracing::instrument(skip(db, body))] -pub async fn create_invite_route<'a>( - db: State<'a, Database>, +pub async fn create_invite_route( + db: State<'_, Arc<Database>>, body: Ruma<create_invite::v2::Request>, ) -> ConduitResult<create_invite::v2::Response> { if !db.globals.allow_federation() { @@ -2276,8 +2276,8 @@ pub async fn create_invite_route<'a>( get("/_matrix/federation/v1/user/devices/<_>", data = "<body>") )] #[tracing::instrument(skip(db, body))] -pub fn get_devices_route<'a>( - db: State<'a, Database>, +pub fn get_devices_route( + db: State<'_, Arc<Database>>, body: Ruma<get_devices::v1::Request<'_>>, ) -> ConduitResult<get_devices::v1::Response> { if !db.globals.allow_federation() { @@ -2316,8 +2316,8 @@ pub fn get_devices_route<'a>( get("/_matrix/federation/v1/query/directory", data = "<body>") )] #[tracing::instrument(skip(db, body))] -pub fn get_room_information_route<'a>( - db: State<'a, Database>, +pub fn get_room_information_route( + db: State<'_, Arc<Database>>, body: Ruma<get_room_information::v1::Request<'_>>, ) -> ConduitResult<get_room_information::v1::Response> { if !db.globals.allow_federation() { @@ -2344,8 +2344,8 @@ pub fn get_room_information_route<'a>( get("/_matrix/federation/v1/query/profile", data = "<body>") )] #[tracing::instrument(skip(db, body))] -pub fn get_profile_information_route<'a>( - db: State<'a, Database>, +pub fn get_profile_information_route( + db: State<'_, Arc<Database>>, body: Ruma<get_profile_information::v1::Request<'_>>, ) -> ConduitResult<get_profile_information::v1::Response> { if !db.globals.allow_federation() { @@ -2378,8 +2378,8 @@ pub fn get_profile_information_route<'a>( post("/_matrix/federation/v1/user/keys/query", data = "<body>") )] #[tracing::instrument(skip(db, body))] -pub fn get_keys_route<'a>( - db: State<'a, Database>, +pub fn get_keys_route( + db: State<'_, Arc<Database>>, body: Ruma<get_keys::v1::Request>, ) -> ConduitResult<get_keys::v1::Response> { if !db.globals.allow_federation() { @@ -2406,8 +2406,8 @@ pub fn get_keys_route<'a>( post("/_matrix/federation/v1/user/keys/claim", data = "<body>") )] #[tracing::instrument(skip(db, body))] -pub async fn claim_keys_route<'a>( - db: State<'a, Database>, +pub async fn claim_keys_route( + db: State<'_, Arc<Database>>, body: Ruma<claim_keys::v1::Request>, ) -> ConduitResult<claim_keys::v1::Response> { if !db.globals.allow_federation() { diff --git a/src/utils.rs b/src/utils.rs index 106baffdaa5c5b6fac8ade2e68501df3f9fcf3f5..0c8fb5ca3f54ea9debc2ed69abfb80ccd86546a4 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -15,6 +15,15 @@ pub fn millis_since_unix_epoch() -> u64 { .as_millis() as u64 } +#[cfg(feature = "rocksdb")] +pub fn increment_rocksdb( + _new_key: &[u8], + old: Option<&[u8]>, + _operands: &mut rocksdb::MergeOperands, +) -> Option<Vec<u8>> { + increment(old) +} + pub fn increment(old: Option<&[u8]>) -> Option<Vec<u8>> { let number = match old.map(|bytes| bytes.try_into()) { Some(Ok(bytes)) => { @@ -27,16 +36,14 @@ pub fn increment(old: Option<&[u8]>) -> Option<Vec<u8>> { Some(number.to_be_bytes().to_vec()) } -pub fn generate_keypair(old: Option<&[u8]>) -> Option<Vec<u8>> { - Some(old.map(|s| s.to_vec()).unwrap_or_else(|| { - let mut value = random_string(8).as_bytes().to_vec(); - value.push(0xff); - value.extend_from_slice( - &ruma::signatures::Ed25519KeyPair::generate() - .expect("Ed25519KeyPair generation always works (?)"), - ); - value - })) +pub fn generate_keypair() -> Vec<u8> { + let mut value = random_string(8).as_bytes().to_vec(); + value.push(0xff); + value.extend_from_slice( + &ruma::signatures::Ed25519KeyPair::generate() + .expect("Ed25519KeyPair generation always works (?)"), + ); + value } /// Parses the bytes into an u64.