diff --git a/Cargo.lock b/Cargo.lock index 1846c82d87d10174145b4bbc99d174fff27d07a7..7f8e3f77416405a4d60ee0e63bc7acd173124941 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -118,7 +118,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -129,7 +129,7 @@ checksum = "6e0c28dcc82d7c8ead5cb13beb15405b57b8546e93215673ff8ca0349a028107" dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -192,7 +192,7 @@ dependencies = [ "bytes", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "http-body-util", "hyper 1.4.1", "hyper-util", @@ -253,7 +253,7 @@ dependencies = [ "bytes", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "http-body-util", "mime", "pin-project-lite", @@ -276,7 +276,7 @@ dependencies = [ "futures-util", "headers", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "http-body-util", "mime", "pin-project-lite", @@ -296,7 +296,7 @@ dependencies = [ "bytes", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "http-body-util", "hyper 1.4.1", "hyper-util", @@ -377,7 +377,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -457,9 +457,9 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" [[package]] name = "bytes" -version = "1.6.0" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" +checksum = "a12916984aab3fa6e39d655a33e09c0071eb36d6ab3aea5c2d78551f1df6d952" [[package]] name = "bzip2-sys" @@ -474,9 +474,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.0" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eaff6f8ce506b9773fa786672d63fc7a191ffea1be33f72bbd4aeacefca9ffc8" +checksum = "18e2d530f35b40a84124146478cd16f34225306a8441998836466a2e2961c950" dependencies = [ "jobserver", "libc", @@ -562,7 +562,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -589,6 +589,7 @@ dependencies = [ "conduit_router", "conduit_service", "console-subscriber", + "const-str", "hardened_malloc-rs", "log", "opentelemetry", @@ -613,6 +614,8 @@ dependencies = [ "conduit_api", "conduit_core", "conduit_service", + "const-str", + "futures-util", "log", "ruma", "serde_json", @@ -634,6 +637,7 @@ dependencies = [ "conduit_core", "conduit_database", "conduit_service", + "const-str", "futures-util", "hmac", "http 1.1.0", @@ -664,6 +668,7 @@ dependencies = [ "bytes", "checked_ops", "chrono", + "const-str", "either", "figment", "hardened_malloc-rs", @@ -689,6 +694,7 @@ dependencies = [ "tikv-jemalloc-sys", "tikv-jemallocator", "tokio", + "tokio-metrics", "tracing", "tracing-core", "tracing-subscriber", @@ -700,6 +706,7 @@ name = "conduit_database" version = "0.4.5" dependencies = [ "conduit_core", + "const-str", "log", "rust-rocksdb-uwu", "tokio", @@ -719,6 +726,7 @@ dependencies = [ "conduit_api", "conduit_core", "conduit_service", + "const-str", "http 1.1.0", "http-body-util", "hyper 1.4.1", @@ -745,6 +753,7 @@ dependencies = [ "bytes", "conduit_core", "conduit_database", + "const-str", "cyborgtime", "futures-util", "hickory-resolver", @@ -816,6 +825,12 @@ version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" +[[package]] +name = "const-str" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3618cccc083bb987a415d85c02ca6c9994ea5b44731ec28b9ecf09658655fba9" + [[package]] name = "const_panic" version = "0.2.8" @@ -1007,7 +1022,7 @@ checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -1111,7 +1126,7 @@ dependencies = [ "heck 0.4.1", "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -1247,7 +1262,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -1528,7 +1543,7 @@ dependencies = [ "markup5ever", "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -1575,9 +1590,9 @@ dependencies = [ [[package]] name = "http-body" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", "http 1.1.0", @@ -1592,7 +1607,7 @@ dependencies = [ "bytes", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "pin-project-lite", ] @@ -1649,7 +1664,7 @@ dependencies = [ "futures-util", "h2 0.4.5", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "httparse", "httpdate", "itoa", @@ -1700,7 +1715,7 @@ dependencies = [ "futures-channel", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "hyper 1.4.1", "pin-project-lite", "socket2", @@ -1937,7 +1952,7 @@ dependencies = [ "proc-macro2", "quote", "regex", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -2446,7 +2461,7 @@ dependencies = [ "proc-macro2", "proc-macro2-diagnostics", "quote", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -2539,7 +2554,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -2627,7 +2642,7 @@ checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", "version_check", "yansi", ] @@ -2652,7 +2667,7 @@ dependencies = [ "itertools 0.12.1", "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -2842,7 +2857,7 @@ dependencies = [ "h2 0.4.5", "hickory-resolver", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "http-body-util", "hyper 1.4.1", "hyper-rustls", @@ -3061,7 +3076,7 @@ dependencies = [ "quote", "ruma-identifiers-validation 0.9.5 (git+https://github.com/girlbossceo/ruwuma?rev=c51ccb2c68d2e3557eb12b1a49036531711ec0e5)", "serde", - "syn 2.0.70", + "syn 2.0.71", "toml", ] @@ -3516,7 +3531,7 @@ checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222" dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -3809,9 +3824,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.70" +version = "2.0.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f0209b68b3613b093e0ec905354eccaedcfe83b8cb37cbdeae64026c3064c16" +checksum = "b146dcf730474b4bcd16c311627b31ede9ab149045db4d6088b3becaea046462" dependencies = [ "proc-macro2", "quote", @@ -3869,22 +3884,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.61" +version = "1.0.62" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" +checksum = "f2675633b1499176c2dff06b0856a27976a8f9d436737b4cf4f312d4d91d8bbb" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.61" +version = "1.0.62" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" +checksum = "d20468752b09f49e909e55a5d338caa8bedf615594e9d80bc4c565d30faf798c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -4033,7 +4048,7 @@ checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -4044,6 +4059,7 @@ checksum = "eace09241d62c98b7eeb1107d4c5c64ca3bd7da92e8c218c153ab3a78f9be112" dependencies = [ "futures-util", "pin-project-lite", + "tokio", "tokio-stream", ] @@ -4208,7 +4224,7 @@ dependencies = [ "futures-core", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "http-body-util", "pin-project-lite", "tokio", @@ -4249,7 +4265,7 @@ source = "git+https://github.com/girlbossceo/tracing?rev=b348dca742af641c47bc390 dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", ] [[package]] @@ -4519,7 +4535,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", "wasm-bindgen-shared", ] @@ -4553,7 +4569,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.70", + "syn 2.0.71", "wasm-bindgen-backend", "wasm-bindgen-shared", ] diff --git a/Cargo.toml b/Cargo.toml index fcbe53d566cf5f08f3b55860e05d094f48c6b98f..f9696252bd2cbec71a7f89a781698ab38026e7cc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,9 @@ version = "0.4.5" [workspace.metadata.crane] name = "conduit" +[workspace.dependencies.const-str] +version = "0.5.7" + [workspace.dependencies.sanitize-filename] version = "0.5.0" @@ -50,7 +53,7 @@ version = "0.8.5" # Used for the http request / response body type for Ruma endpoints used with reqwest [workspace.dependencies.bytes] -version = "1.6.0" +version = "1.6.1" [workspace.dependencies.http-body-util] version = "0.1.1" @@ -197,6 +200,9 @@ features = [ "io-util", ] +[workspace.dependencies.tokio-metrics] +version = "0.3.1" + [workspace.dependencies.libloading] version = "0.8.3" @@ -245,7 +251,7 @@ default-features = false # Used for conduit::Error type [workspace.dependencies.thiserror] -version = "1.0.61" +version = "1.0.62" # Used when hashing the state [workspace.dependencies.ring] @@ -379,10 +385,6 @@ version = "0.5.4" default-features = false features = ["use_std"] -[workspace.dependencies.tokio-metrics] -version = "0.3.1" -default-features = false - [workspace.dependencies.console-subscriber] version = "0.3" diff --git a/conduwuit-example.toml b/conduwuit-example.toml index 7e5c9710bcec9564e61a5eb95cd092ccf0b8d59f..ea095bfa11451c0b5457cb336a86e2488fb13915 100644 --- a/conduwuit-example.toml +++ b/conduwuit-example.toml @@ -57,6 +57,16 @@ # Defaults to 0.15 #sentry_traces_sample_rate = 0.15 +# Whether to attach a stacktrace to Sentry reports. +#sentry_attach_stacktrace = false + +# Send panics to sentry. This is true by default, but sentry has to be enabled. +#sentry_send_panic = true + +# Send errors to sentry. This is true by default, but sentry has to be enabled. This option is +# only effective in release-mode; forced to false in debug-mode. +#sentry_send_error = true + ### Database configuration diff --git a/src/admin/Cargo.toml b/src/admin/Cargo.toml index f84fbb928fff2439c56690378dfca6389bd89545..1e13fb7a97cb2041dbfb8d67913494ddd30c97a4 100644 --- a/src/admin/Cargo.toml +++ b/src/admin/Cargo.toml @@ -30,6 +30,8 @@ clap.workspace = true conduit-api.workspace = true conduit-core.workspace = true conduit-service.workspace = true +const-str.workspace = true +futures-util.workspace = true log.workspace = true ruma.workspace = true serde_json.workspace = true diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index 53009566df5ca64696b83749fedc09a68ce2a7eb..46f71622137cb745dc52b3265a11404e356cecce 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -314,6 +314,8 @@ pub(super) async fn force_device_list_updates(_body: Vec<&str>) -> Result<RoomMe pub(super) async fn change_log_level( _body: Vec<&str>, filter: Option<String>, reset: bool, ) -> Result<RoomMessageEventContent> { + let handles = &["console"]; + if reset { let old_filter_layer = match EnvFilter::try_new(&services().globals.config.log) { Ok(s) => s, @@ -324,7 +326,12 @@ pub(super) async fn change_log_level( }, }; - match services().server.log.reload.reload(&old_filter_layer) { + match services() + .server + .log + .reload + .reload(&old_filter_layer, Some(handles)) + { Ok(()) => { return Ok(RoomMessageEventContent::text_plain(format!( "Successfully changed log level back to config value {}", @@ -349,7 +356,12 @@ pub(super) async fn change_log_level( }, }; - match services().server.log.reload.reload(&new_filter_layer) { + match services() + .server + .log + .reload + .reload(&new_filter_layer, Some(handles)) + { Ok(()) => { return Ok(RoomMessageEventContent::text_plain("Successfully changed log level")); }, @@ -570,7 +582,7 @@ pub(super) async fn force_set_room_state_from_server( .state_compressor .save_state(room_id.clone().as_ref(), new_room_state)?; - let state_lock = services().globals.roomid_mutex_state.lock(&room_id).await; + let state_lock = services().rooms.state.mutex.lock(&room_id).await; services() .rooms .state @@ -632,12 +644,46 @@ pub(super) async fn resolve_true_destination( pub(super) fn memory_stats() -> RoomMessageEventContent { let html_body = conduit::alloc::memory_stats(); - if html_body.is_empty() { + if html_body.is_none() { return RoomMessageEventContent::text_plain("malloc stats are not supported on your compiled malloc."); } RoomMessageEventContent::text_html( "This command's output can only be viewed by clients that render HTML.".to_owned(), - html_body, + html_body.expect("string result"), ) } + +#[cfg(tokio_unstable)] +pub(super) async fn runtime_metrics(_body: Vec<&str>) -> Result<RoomMessageEventContent> { + let out = services().server.metrics.runtime_metrics().map_or_else( + || "Runtime metrics are not available.".to_owned(), + |metrics| format!("```rs\n{metrics:#?}\n```"), + ); + + Ok(RoomMessageEventContent::text_markdown(out)) +} + +#[cfg(not(tokio_unstable))] +pub(super) async fn runtime_metrics(_body: Vec<&str>) -> Result<RoomMessageEventContent> { + Ok(RoomMessageEventContent::text_markdown( + "Runtime metrics require building with `tokio_unstable`.", + )) +} + +#[cfg(tokio_unstable)] +pub(super) async fn runtime_interval(_body: Vec<&str>) -> Result<RoomMessageEventContent> { + let out = services().server.metrics.runtime_interval().map_or_else( + || "Runtime metrics are not available.".to_owned(), + |metrics| format!("```rs\n{metrics:#?}\n```"), + ); + + Ok(RoomMessageEventContent::text_markdown(out)) +} + +#[cfg(not(tokio_unstable))] +pub(super) async fn runtime_interval(_body: Vec<&str>) -> Result<RoomMessageEventContent> { + Ok(RoomMessageEventContent::text_markdown( + "Runtime metrics require building with `tokio_unstable`.", + )) +} diff --git a/src/admin/debug/mod.rs b/src/admin/debug/mod.rs index eed3b633f35018211bbdf7f1a0ed96aa0d914e18..7d6cafa77a61ce0e2e4368d595d863a3bf66b07c 100644 --- a/src/admin/debug/mod.rs +++ b/src/admin/debug/mod.rs @@ -160,6 +160,13 @@ pub(super) enum DebugCommand { /// - Print extended memory usage MemoryStats, + /// - Print general tokio runtime metric totals. + RuntimeMetrics, + + /// - Print detailed tokio runtime metrics accumulated since last command + /// invocation. + RuntimeInterval, + /// - Developer test stubs #[command(subcommand)] Tester(TesterCommand), @@ -213,6 +220,8 @@ pub(super) async fn process(command: DebugCommand, body: Vec<&str>) -> Result<Ro no_cache, } => resolve_true_destination(body, server_name, no_cache).await?, DebugCommand::MemoryStats => memory_stats(), + DebugCommand::RuntimeMetrics => runtime_metrics(body).await?, + DebugCommand::RuntimeInterval => runtime_interval(body).await?, DebugCommand::Tester(command) => tester::process(command, body).await?, }) } diff --git a/src/admin/federation/commands.rs b/src/admin/federation/commands.rs index 24f4bc233d5e9e067e463386e1ec3bb0696aa7c8..a97e7582eadca8efdb9a441cf82f4e4fe551fe30 100644 --- a/src/admin/federation/commands.rs +++ b/src/admin/federation/commands.rs @@ -16,8 +16,9 @@ pub(super) async fn enable_room(_body: Vec<&str>, room_id: Box<RoomId>) -> Resul pub(super) async fn incoming_federation(_body: Vec<&str>) -> Result<RoomMessageEventContent> { let map = services() - .globals - .roomid_federationhandletime + .rooms + .event_handler + .federation_handletime .read() .expect("locked"); let mut msg = format!("Handling {} incoming pdus:\n", map.len()); diff --git a/src/admin/handler.rs b/src/admin/handler.rs index 22ec81c52d5774690d08d787b9f3843d0dcebf42..95c7ed4182037c20eaec437a6de7a442cbea29d9 100644 --- a/src/admin/handler.rs +++ b/src/admin/handler.rs @@ -1,10 +1,14 @@ -use std::time::Instant; +use std::{panic::AssertUnwindSafe, time::Instant}; use clap::{CommandFactory, Parser}; -use conduit::trace; -use ruma::events::{ - relation::InReplyTo, - room::message::{Relation::Reply, RoomMessageEventContent}, +use conduit::{error, trace, Error}; +use futures_util::future::FutureExt; +use ruma::{ + events::{ + relation::InReplyTo, + room::message::{Relation::Reply, RoomMessageEventContent}, + }, + OwnedEventId, }; extern crate conduit_service as service; @@ -20,7 +24,6 @@ }; pub(crate) const PAGE_SIZE: usize = 100; -#[cfg_attr(test, derive(Debug))] #[derive(Parser)] #[command(name = "admin", version = env!("CARGO_PKG_VERSION"))] pub(crate) enum AdminCommand { @@ -69,21 +72,39 @@ pub(crate) fn complete(line: &str) -> String { complete_admin_command(AdminComma #[tracing::instrument(skip_all, name = "admin")] async fn handle_command(command: Command) -> CommandResult { - let Some(mut content) = process_admin_message(command.command).await else { - return Ok(None); - }; + AssertUnwindSafe(process_command(&command)) + .catch_unwind() + .await + .map_err(Error::from_panic) + .or_else(|error| handle_panic(&error, command)) +} + +async fn process_command(command: &Command) -> CommandOutput { + process_admin_message(&command.command) + .await + .and_then(|content| reply(content, command.reply_id.clone())) +} + +fn handle_panic(error: &Error, command: Command) -> CommandResult { + let link = "Please submit a [bug report](https://github.com/girlbossceo/conduwuit/issues/new). 🥺"; + let msg = format!("Panic occurred while processing command:\n```\n{error:#?}\n```\n{link}"); + let content = RoomMessageEventContent::notice_markdown(msg); + error!("Panic while processing command: {error:?}"); + Ok(reply(content, command.reply_id)) +} - content.relates_to = command.reply_id.map(|event_id| Reply { +fn reply(mut content: RoomMessageEventContent, reply_id: Option<OwnedEventId>) -> Option<RoomMessageEventContent> { + content.relates_to = reply_id.map(|event_id| Reply { in_reply_to: InReplyTo { event_id, }, }); - Ok(Some(content)) + Some(content) } // Parse and process a message from the admin room -async fn process_admin_message(msg: String) -> CommandOutput { +async fn process_admin_message(msg: &str) -> CommandOutput { let mut lines = msg.lines().filter(|l| !l.trim().is_empty()); let command = lines.next().expect("each string has at least one line"); let body = lines.collect::<Vec<_>>(); @@ -103,7 +124,7 @@ async fn process_admin_message(msg: String) -> CommandOutput { match result { Ok(reply) => Some(reply), Err(error) => Some(RoomMessageEventContent::notice_markdown(format!( - "Encountered an error while handling the command:\n```\n{error}\n```" + "Encountered an error while handling the command:\n```\n{error:#?}\n```" ))), } } diff --git a/src/admin/mod.rs b/src/admin/mod.rs index 6a47bc7450140abb9e236a2478f432e7f19682ed..14856811f086825c856f0e9c90990b4692d89d85 100644 --- a/src/admin/mod.rs +++ b/src/admin/mod.rs @@ -9,6 +9,7 @@ pub(crate) mod query; pub(crate) mod room; pub(crate) mod server; +mod tests; pub(crate) mod user; pub(crate) mod utils; @@ -53,30 +54,3 @@ pub async fn fini() { .expect("locked for writing") .take(); } - -#[cfg(test)] -mod test { - use clap::Parser; - - use crate::handler::AdminCommand; - - #[test] - fn get_help_short() { get_help_inner("-h"); } - - #[test] - fn get_help_long() { get_help_inner("--help"); } - - #[test] - fn get_help_subcommand() { get_help_inner("help"); } - - fn get_help_inner(input: &str) { - let error = AdminCommand::try_parse_from(["argv[0] doesn't matter", input]) - .unwrap_err() - .to_string(); - - // Search for a handful of keywords that suggest the help printed properly - assert!(error.contains("Usage:")); - assert!(error.contains("Commands:")); - assert!(error.contains("Options:")); - } -} diff --git a/src/admin/query/globals.rs b/src/admin/query/globals.rs index 2e22d6884f06ca4faf333f2f2b339a2ebc08f5d9..9bdd38fca1b05f146193cb31cdc5e2f098db5316 100644 --- a/src/admin/query/globals.rs +++ b/src/admin/query/globals.rs @@ -26,7 +26,7 @@ pub(super) async fn globals(subcommand: Globals) -> Result<RoomMessageEventConte }, Globals::LastCheckForUpdatesId => { let timer = tokio::time::Instant::now(); - let results = services().globals.db.last_check_for_updates_id(); + let results = services().updates.last_check_for_updates_id(); let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/server/commands.rs b/src/admin/server/commands.rs index 06007fb82de497dfc05fa37e3d3506ec757065c6..e45037365375b969ea1ce46161fb9e4f32f11b5d 100644 --- a/src/admin/server/commands.rs +++ b/src/admin/server/commands.rs @@ -1,4 +1,4 @@ -use conduit::{utils::time, warn, Error, Result}; +use conduit::{utils::time, warn, Err, Result}; use ruma::events::room::message::RoomMessageEventContent; use crate::services; @@ -20,17 +20,12 @@ pub(super) async fn show_config(_body: Vec<&str>) -> Result<RoomMessageEventCont } pub(super) async fn memory_usage(_body: Vec<&str>) -> Result<RoomMessageEventContent> { - let response0 = services().memory_usage().await?; - let response1 = services().db.db.memory_usage()?; - let response2 = conduit::alloc::memory_usage(); + let services_usage = services().memory_usage().await?; + let database_usage = services().db.db.memory_usage()?; + let allocator_usage = conduit::alloc::memory_usage().map_or(String::new(), |s| format!("\nAllocator:\n{s}")); Ok(RoomMessageEventContent::text_plain(format!( - "Services:\n{response0}\nDatabase:\n{response1}\n{}", - if !response2.is_empty() { - format!("Allocator:\n {response2}") - } else { - String::new() - } + "Services:\n{services_usage}\nDatabase:\n{database_usage}{allocator_usage}", ))) } @@ -93,11 +88,10 @@ pub(super) async fn restart(_body: Vec<&str>, force: bool) -> Result<RoomMessage use conduit::utils::sys::current_exe_deleted; if !force && current_exe_deleted() { - return Err(Error::Err( - "The server cannot be restarted because the executable was tampered with. If this is expected use --force \ - to override." - .to_owned(), - )); + return Err!( + "The server cannot be restarted because the executable changed. If this is expected use --force to \ + override." + ); } services().server.restart()?; diff --git a/src/admin/tests.rs b/src/admin/tests.rs new file mode 100644 index 0000000000000000000000000000000000000000..69ccd896c7643fc2aef79c0e969568ee37d1f5fc --- /dev/null +++ b/src/admin/tests.rs @@ -0,0 +1,26 @@ +#![cfg(test)] + +#[test] +fn get_help_short() { get_help_inner("-h"); } + +#[test] +fn get_help_long() { get_help_inner("--help"); } + +#[test] +fn get_help_subcommand() { get_help_inner("help"); } + +fn get_help_inner(input: &str) { + use clap::Parser; + + use crate::handler::AdminCommand; + + let Err(error) = AdminCommand::try_parse_from(["argv[0] doesn't matter", input]) else { + panic!("no error!"); + }; + + let error = error.to_string(); + // Search for a handful of keywords that suggest the help printed properly + assert!(error.contains("Usage:")); + assert!(error.contains("Commands:")); + assert!(error.contains("Options:")); +} diff --git a/src/admin/user/commands.rs b/src/admin/user/commands.rs index 12aa01703c596101fb2aed7497fcdae626dc71f1..9e4b348b128445faa990f256954bd3d89cb057ae 100644 --- a/src/admin/user/commands.rs +++ b/src/admin/user/commands.rs @@ -8,7 +8,7 @@ tag::{TagEvent, TagEventContent, TagInfo}, RoomAccountDataEventType, }, - OwnedRoomId, OwnedUserId, RoomId, + OwnedRoomId, OwnedRoomOrAliasId, OwnedUserId, RoomId, }; use tracing::{error, info, warn}; @@ -334,6 +334,35 @@ pub(super) async fn list_joined_rooms(_body: Vec<&str>, user_id: String) -> Resu Ok(RoomMessageEventContent::text_html(output_plain, output_html)) } +pub(super) async fn force_join_room( + _body: Vec<&str>, user_id: String, room_id: OwnedRoomOrAliasId, +) -> Result<RoomMessageEventContent> { + let user_id = parse_local_user_id(&user_id)?; + let room_id = services().rooms.alias.resolve(&room_id).await?; + + assert!(service::user_is_local(&user_id), "Parsed user_id must be a local user"); + join_room_by_id_helper(&user_id, &room_id, None, &[], None).await?; + + Ok(RoomMessageEventContent::notice_markdown(format!( + "{user_id} has been joined to {room_id}.", + ))) +} + +pub(super) async fn make_user_admin(_body: Vec<&str>, user_id: String) -> Result<RoomMessageEventContent> { + let user_id = parse_local_user_id(&user_id)?; + let displayname = services() + .users + .displayname(&user_id)? + .unwrap_or_else(|| user_id.to_string()); + + assert!(service::user_is_local(&user_id), "Parsed user_id must be a local user"); + service::admin::make_user_admin(&user_id, displayname).await?; + + Ok(RoomMessageEventContent::notice_markdown(format!( + "{user_id} has been granted admin privileges.", + ))) +} + pub(super) async fn put_room_tag( _body: Vec<&str>, user_id: String, room_id: Box<RoomId>, tag: String, ) -> Result<RoomMessageEventContent> { diff --git a/src/admin/user/mod.rs b/src/admin/user/mod.rs index 088133a5c985ea3f8f2f90e27687328d5bc0b941..cdb5fa5eaf4ab80f9c078dc1237d4bc175d121a8 100644 --- a/src/admin/user/mod.rs +++ b/src/admin/user/mod.rs @@ -2,7 +2,7 @@ use clap::Subcommand; use conduit::Result; -use ruma::{events::room::message::RoomMessageEventContent, RoomId}; +use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomOrAliasId, RoomId}; use self::commands::*; @@ -65,6 +65,17 @@ pub(super) enum UserCommand { user_id: String, }, + /// - Manually join a local user to a room. + ForceJoinRoom { + user_id: String, + room_id: OwnedRoomOrAliasId, + }, + + /// - Grant server-admin privileges to a user. + MakeUserAdmin { + user_id: String, + }, + /// - Puts a room tag for the specified user and room ID. /// /// This is primarily useful if you'd like to set your admin room @@ -113,6 +124,13 @@ pub(super) async fn process(command: UserCommand, body: Vec<&str>) -> Result<Roo UserCommand::ListJoinedRooms { user_id, } => list_joined_rooms(body, user_id).await?, + UserCommand::ForceJoinRoom { + user_id, + room_id, + } => force_join_room(body, user_id, room_id).await?, + UserCommand::MakeUserAdmin { + user_id, + } => make_user_admin(body, user_id).await?, UserCommand::PutRoomTag { user_id, room_id, diff --git a/src/admin/utils.rs b/src/admin/utils.rs index 91982fe33ed7cddd89c602dd2d549c7587e38f1c..fda42e6e259969952919a65c94d85b11eba2ab1a 100644 --- a/src/admin/utils.rs +++ b/src/admin/utils.rs @@ -1,4 +1,4 @@ -use conduit_core::Error; +use conduit_core::{err, Err}; use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; use service::user_is_local; @@ -33,7 +33,7 @@ pub(crate) fn get_room_info(id: &RoomId) -> (OwnedRoomId, u64, String) { /// Parses user ID pub(crate) fn parse_user_id(user_id: &str) -> Result<OwnedUserId> { UserId::parse_with_server_name(user_id.to_lowercase(), services().globals.server_name()) - .map_err(|e| Error::Err(format!("The supplied username is not a valid username: {e}"))) + .map_err(|e| err!("The supplied username is not a valid username: {e}")) } /// Parses user ID as our local user @@ -41,7 +41,7 @@ pub(crate) fn parse_local_user_id(user_id: &str) -> Result<OwnedUserId> { let user_id = parse_user_id(user_id)?; if !user_is_local(&user_id) { - return Err(Error::Err(String::from("User does not belong to our server."))); + return Err!("User {user_id:?} does not belong to our server."); } Ok(user_id) @@ -52,11 +52,11 @@ pub(crate) fn parse_active_local_user_id(user_id: &str) -> Result<OwnedUserId> { let user_id = parse_local_user_id(user_id)?; if !services().users.exists(&user_id)? { - return Err(Error::Err(String::from("User does not exist on this server."))); + return Err!("User {user_id:?} does not exist on this server."); } if services().users.is_deactivated(&user_id)? { - return Err(Error::Err(String::from("User is deactivated."))); + return Err!("User {user_id:?} is deactivated."); } Ok(user_id) diff --git a/src/api/Cargo.toml b/src/api/Cargo.toml index 45cae73d42e720446124f024231ad9205ff12a4f..356adc1f33443c2968edb6ffc9cdecd75400a5dd 100644 --- a/src/api/Cargo.toml +++ b/src/api/Cargo.toml @@ -41,6 +41,7 @@ bytes.workspace = true conduit-core.workspace = true conduit-database.workspace = true conduit-service.workspace = true +const-str.workspace = true futures-util.workspace = true hmac.workspace = true http.workspace = true diff --git a/src/api/client/directory.rs b/src/api/client/directory.rs index 7d2aff0d697ce01baece24fde42154e87e694d46..68bd0dffcec46ce94e7b8cbfe4b4003da39b12aa 100644 --- a/src/api/client/directory.rs +++ b/src/api/client/directory.rs @@ -1,4 +1,5 @@ use axum_client_ip::InsecureClientIp; +use conduit::{err, info, warn, Error, Result}; use ruma::{ api::{ client::{ @@ -18,9 +19,8 @@ }, uint, RoomId, ServerName, UInt, UserId, }; -use tracing::{error, info, warn}; -use crate::{service::server_is_ours, services, Error, Result, Ruma}; +use crate::{service::server_is_ours, services, Ruma}; /// # `POST /_matrix/client/v3/publicRooms` /// @@ -271,8 +271,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( _ => None, }) .map_err(|e| { - error!("Invalid room join rule event in database: {}", e); - Error::BadDatabase("Invalid room join rule event in database.") + err!(Database(error!("Invalid room join rule event in database: {e}"))) }) }) .transpose()? diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 6ebcd95bcf8412f549ed90058db654a606c39365..3adee631b61b3e0a6c74df31b8f06f9f9e7e64e0 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -7,9 +7,8 @@ use axum_client_ip::InsecureClientIp; use conduit::{ - debug, error, info, trace, utils, - utils::{math::continue_exponential_backoff_secs, mutex_map}, - warn, Error, PduEvent, Result, + debug, debug_warn, error, info, trace, utils, utils::math::continue_exponential_backoff_secs, warn, Error, + PduEvent, Result, }; use ruma::{ api::{ @@ -36,13 +35,14 @@ OwnedUserId, RoomId, RoomVersionId, ServerName, UserId, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; -use service::sending::convert_to_outgoing_federation_event; use tokio::sync::RwLock; use crate::{ client::{update_avatar_url, update_displayname}, service::{ pdu::{gen_event_id_canonical_json, PduBuilder}, + rooms::state::RoomMutexGuard, + sending::convert_to_outgoing_federation_event, server_is_ours, user_is_local, }, services, Ruma, @@ -366,6 +366,8 @@ pub(crate) async fn invite_user_route( pub(crate) async fn kick_user_route(body: Ruma<kick_user::v3::Request>) -> Result<kick_user::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; + let mut event: RoomMemberEventContent = serde_json::from_str( services() .rooms @@ -383,12 +385,6 @@ pub(crate) async fn kick_user_route(body: Ruma<kick_user::v3::Request>) -> Resul event.membership = MembershipState::Leave; event.reason.clone_from(&body.reason); - let state_lock = services() - .globals - .roomid_mutex_state - .lock(&body.room_id) - .await; - services() .rooms .timeline @@ -417,6 +413,8 @@ pub(crate) async fn kick_user_route(body: Ruma<kick_user::v3::Request>) -> Resul pub(crate) async fn ban_user_route(body: Ruma<ban_user::v3::Request>) -> Result<ban_user::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; + let event = services() .rooms .state_accessor @@ -447,12 +445,6 @@ pub(crate) async fn ban_user_route(body: Ruma<ban_user::v3::Request>) -> Result< }, )?; - let state_lock = services() - .globals - .roomid_mutex_state - .lock(&body.room_id) - .await; - services() .rooms .timeline @@ -481,6 +473,8 @@ pub(crate) async fn ban_user_route(body: Ruma<ban_user::v3::Request>) -> Result< pub(crate) async fn unban_user_route(body: Ruma<unban_user::v3::Request>) -> Result<unban_user::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; + let mut event: RoomMemberEventContent = serde_json::from_str( services() .rooms @@ -496,12 +490,6 @@ pub(crate) async fn unban_user_route(body: Ruma<unban_user::v3::Request>) -> Res event.reason.clone_from(&body.reason); event.join_authorized_via_users_server = None; - let state_lock = services() - .globals - .roomid_mutex_state - .lock(&body.room_id) - .await; - services() .rooms .timeline @@ -656,31 +644,33 @@ pub async fn join_room_by_id_helper( sender_user: &UserId, room_id: &RoomId, reason: Option<String>, servers: &[OwnedServerName], third_party_signed: Option<&ThirdPartySigned>, ) -> Result<join_room_by_id::v3::Response> { + let state_lock = services().rooms.state.mutex.lock(room_id).await; + if matches!(services().rooms.state_cache.is_joined(sender_user, room_id), Ok(true)) { - info!("{sender_user} is already joined in {room_id}"); + debug_warn!("{sender_user} is already joined in {room_id}"); return Ok(join_room_by_id::v3::Response { room_id: room_id.into(), }); } - let state_lock = services().globals.roomid_mutex_state.lock(room_id).await; - - // Ask a remote server if we are not participating in this room - if !services() + if services() .rooms .state_cache .server_in_room(services().globals.server_name(), room_id)? + || servers.is_empty() + || (servers.len() == 1 && server_is_ours(&servers[0])) { - join_room_by_id_helper_remote(sender_user, room_id, reason, servers, third_party_signed, state_lock).await - } else { join_room_by_id_helper_local(sender_user, room_id, reason, servers, third_party_signed, state_lock).await + } else { + // Ask a remote server if we are not participating in this room + join_room_by_id_helper_remote(sender_user, room_id, reason, servers, third_party_signed, state_lock).await } } #[tracing::instrument(skip_all, fields(%sender_user, %room_id), name = "join_remote")] async fn join_room_by_id_helper_remote( sender_user: &UserId, room_id: &RoomId, reason: Option<String>, servers: &[OwnedServerName], - _third_party_signed: Option<&ThirdPartySigned>, state_lock: mutex_map::Guard<()>, + _third_party_signed: Option<&ThirdPartySigned>, state_lock: RoomMutexGuard, ) -> Result<join_room_by_id::v3::Response> { info!("Joining {room_id} over federation."); @@ -790,14 +780,9 @@ async fn join_room_by_id_helper_remote( info!("send_join finished"); if join_authorized_via_users_server.is_some() { + use RoomVersionId::*; match &room_version_id { - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 => { + V1 | V2 | V3 | V4 | V5 | V6 | V7 => { warn!( "Found `join_authorised_via_users_server` but room {} is version {}. Ignoring.", room_id, &room_version_id @@ -805,7 +790,7 @@ async fn join_room_by_id_helper_remote( }, // only room versions 8 and above using `join_authorized_via_users_server` (restricted joins) need to // validate and send signatures - RoomVersionId::V8 | RoomVersionId::V9 | RoomVersionId::V10 | RoomVersionId::V11 => { + V8 | V9 | V10 | V11 => { if let Some(signed_raw) = &send_join_response.room_state.event { info!( "There is a signed event. This room is probably using restricted joins. Adding signature to \ @@ -1016,9 +1001,9 @@ async fn join_room_by_id_helper_remote( #[tracing::instrument(skip_all, fields(%sender_user, %room_id), name = "join_local")] async fn join_room_by_id_helper_local( sender_user: &UserId, room_id: &RoomId, reason: Option<String>, servers: &[OwnedServerName], - _third_party_signed: Option<&ThirdPartySigned>, state_lock: mutex_map::Guard<()>, + _third_party_signed: Option<&ThirdPartySigned>, state_lock: RoomMutexGuard, ) -> Result<join_room_by_id::v3::Response> { - info!("We can join locally"); + debug!("We can join locally"); let join_rules_event = services() @@ -1118,7 +1103,7 @@ async fn join_room_by_id_helper_local( .iter() .any(|server_name| !server_is_ours(server_name)) { - info!("We couldn't do the join locally, maybe federation can help to satisfy the restricted join requirements"); + warn!("We couldn't do the join locally, maybe federation can help to satisfy the restricted join requirements"); let (make_join_response, remote_server) = make_join_request(sender_user, room_id, servers).await?; let room_version_id = match make_join_response.room_version { @@ -1283,16 +1268,12 @@ async fn make_join_request( make_join_counter = make_join_counter.saturating_add(1); if let Err(ref e) = make_join_response { - trace!("make_join ErrorKind string: {:?}", e.error_code().to_string()); + trace!("make_join ErrorKind string: {:?}", e.kind().to_string()); // converting to a string is necessary (i think) because ruma is forcing us to // fill in the struct for M_INCOMPATIBLE_ROOM_VERSION - if e.error_code() - .to_string() - .contains("M_INCOMPATIBLE_ROOM_VERSION") - || e.error_code() - .to_string() - .contains("M_UNSUPPORTED_ROOM_VERSION") + if e.kind().to_string().contains("M_INCOMPATIBLE_ROOM_VERSION") + || e.kind().to_string().contains("M_UNSUPPORTED_ROOM_VERSION") { incompatible_room_version_count = incompatible_room_version_count.saturating_add(1); } @@ -1397,7 +1378,7 @@ pub(crate) async fn invite_helper( if !user_is_local(user_id) { let (pdu, pdu_json, invite_room_state) = { - let state_lock = services().globals.roomid_mutex_state.lock(room_id).await; + let state_lock = services().rooms.state.mutex.lock(room_id).await; let content = to_raw_value(&RoomMemberEventContent { avatar_url: services().users.avatar_url(user_id)?, displayname: None, @@ -1509,7 +1490,7 @@ pub(crate) async fn invite_helper( )); } - let state_lock = services().globals.roomid_mutex_state.lock(room_id).await; + let state_lock = services().rooms.state.mutex.lock(room_id).await; services() .rooms @@ -1603,7 +1584,7 @@ pub async fn leave_room(user_id: &UserId, room_id: &RoomId, reason: Option<Strin true, )?; } else { - let state_lock = services().globals.roomid_mutex_state.lock(room_id).await; + let state_lock = services().rooms.state.mutex.lock(room_id).await; let member_event = services() diff --git a/src/api/client/message.rs b/src/api/client/message.rs index 9548f0844e6ea2196e9b4ca641cb534543690011..c376ee52294588a52b790d880e2792c6a5e8faa2 100644 --- a/src/api/client/message.rs +++ b/src/api/client/message.rs @@ -29,11 +29,7 @@ pub(crate) async fn send_message_event_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_deref(); - let state_lock = services() - .globals - .roomid_mutex_state - .lock(&body.room_id) - .await; + let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; // Forbid m.room.encrypted if encryption is disabled if MessageLikeEventType::RoomEncrypted == body.event_type && !services().globals.allow_encryption() { diff --git a/src/api/client/profile.rs b/src/api/client/profile.rs index 2b23344ee7ef34d574e1b59c18000f288dfb9a47..b030593970386f0a2e3bae50240447de19681c0b 100644 --- a/src/api/client/profile.rs +++ b/src/api/client/profile.rs @@ -353,7 +353,7 @@ pub async fn update_avatar_url( pub async fn update_all_rooms(all_joined_rooms: Vec<(PduBuilder, &OwnedRoomId)>, user_id: OwnedUserId) { for (pdu_builder, room_id) in all_joined_rooms { - let state_lock = services().globals.roomid_mutex_state.lock(room_id).await; + let state_lock = services().rooms.state.mutex.lock(room_id).await; if let Err(e) = services() .rooms .timeline diff --git a/src/api/client/redact.rs b/src/api/client/redact.rs index 4cb24c33606c4918e2efc0a3b40cfb91b4f29454..308d12e5bf3d78a1a23bc2e54c3756fd8257b819 100644 --- a/src/api/client/redact.rs +++ b/src/api/client/redact.rs @@ -15,11 +15,7 @@ pub(crate) async fn redact_event_route(body: Ruma<redact_event::v3::Request>) -> let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let body = body.body; - let state_lock = services() - .globals - .roomid_mutex_state - .lock(&body.room_id) - .await; + let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; let event_id = services() .rooms diff --git a/src/api/client/room.rs b/src/api/client/room.rs index 7090fdc8933234fac534923da36f24afe7a0ebd9..adf58b04d60a4e12c39e0fc7f3c26ae422890f7c 100644 --- a/src/api/client/room.rs +++ b/src/api/client/room.rs @@ -90,7 +90,7 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R } let _short_id = services().rooms.short.get_or_create_shortroomid(&room_id)?; - let state_lock = services().globals.roomid_mutex_state.lock(&room_id).await; + let state_lock = services().rooms.state.mutex.lock(&room_id).await; let alias: Option<OwnedRoomAliasId> = if let Some(alias) = &body.room_alias_name { Some(room_alias_check(alias, &body.appservice_info).await?) @@ -118,6 +118,8 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R let content = match &body.creation_content { Some(content) => { + use RoomVersionId::*; + let mut content = content .deserialize_as::<CanonicalJsonObject>() .map_err(|e| { @@ -125,16 +127,7 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R Error::bad_database("Failed to deserialise content as canonical JSON.") })?; match room_version { - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 - | RoomVersionId::V8 - | RoomVersionId::V9 - | RoomVersionId::V10 => { + V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { content.insert( "creator".into(), json!(&sender_user).try_into().map_err(|e| { @@ -143,7 +136,7 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R })?, ); }, - RoomVersionId::V11 => {}, // V11 removed the "creator" key + V11 => {}, // V11 removed the "creator" key _ => { warn!("Unexpected or unsupported room version {room_version}"); return Err(Error::BadRequest( @@ -152,7 +145,6 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R )); }, } - content.insert( "room_version".into(), json!(room_version.as_str()) @@ -162,18 +154,11 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R content }, None => { + use RoomVersionId::*; + let content = match room_version { - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 - | RoomVersionId::V8 - | RoomVersionId::V9 - | RoomVersionId::V10 => RoomCreateEventContent::new_v1(sender_user.clone()), - RoomVersionId::V11 => RoomCreateEventContent::new_v11(), + V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => RoomCreateEventContent::new_v1(sender_user.clone()), + V11 => RoomCreateEventContent::new_v11(), _ => { warn!("Unexpected or unsupported room version {room_version}"); return Err(Error::BadRequest( @@ -573,11 +558,7 @@ pub(crate) async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) -> .short .get_or_create_shortroomid(&replacement_room)?; - let state_lock = services() - .globals - .roomid_mutex_state - .lock(&body.room_id) - .await; + let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; // Send a m.room.tombstone event to the old room to indicate that it is not // intended to be used any further Fail if the sender does not have the required @@ -605,11 +586,7 @@ pub(crate) async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) -> // Change lock to replacement room drop(state_lock); - let state_lock = services() - .globals - .roomid_mutex_state - .lock(&replacement_room) - .await; + let state_lock = services().rooms.state.mutex.lock(&replacement_room).await; // Get the old room creation event let mut create_event_content = serde_json::from_str::<CanonicalJsonObject>( @@ -631,36 +608,30 @@ pub(crate) async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) -> // Send a m.room.create event containing a predecessor field and the applicable // room_version - match body.new_version { - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 - | RoomVersionId::V8 - | RoomVersionId::V9 - | RoomVersionId::V10 => { - create_event_content.insert( - "creator".into(), - json!(&sender_user).try_into().map_err(|e| { - info!("Error forming creation event: {e}"); - Error::BadRequest(ErrorKind::BadJson, "Error forming creation event") - })?, - ); - }, - RoomVersionId::V11 => { - // "creator" key no longer exists in V11 rooms - create_event_content.remove("creator"); - }, - _ => { - warn!("Unexpected or unsupported room version {}", body.new_version); - return Err(Error::BadRequest( - ErrorKind::BadJson, - "Unexpected or unsupported room version found", - )); - }, + { + use RoomVersionId::*; + match body.new_version { + V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { + create_event_content.insert( + "creator".into(), + json!(&sender_user).try_into().map_err(|e| { + info!("Error forming creation event: {e}"); + Error::BadRequest(ErrorKind::BadJson, "Error forming creation event") + })?, + ); + }, + V11 => { + // "creator" key no longer exists in V11 rooms + create_event_content.remove("creator"); + }, + _ => { + warn!("Unexpected or unsupported room version {}", body.new_version); + return Err(Error::BadRequest( + ErrorKind::BadJson, + "Unexpected or unsupported room version found", + )); + }, + } } create_event_content.insert( diff --git a/src/api/client/state.rs b/src/api/client/state.rs index abff92181b57248e0f9badbe468ce7792189d5c8..25b77fe3aebae7ae2369aab7ca274a8efe81de93 100644 --- a/src/api/client/state.rs +++ b/src/api/client/state.rs @@ -170,7 +170,7 @@ async fn send_state_event_for_key_helper( sender: &UserId, room_id: &RoomId, event_type: &StateEventType, json: &Raw<AnyStateEventContent>, state_key: String, ) -> Result<Arc<EventId>> { allowed_to_send_state_event(room_id, event_type, json).await?; - let state_lock = services().globals.roomid_mutex_state.lock(room_id).await; + let state_lock = services().rooms.state.mutex.lock(room_id).await; let event_id = services() .rooms .timeline diff --git a/src/api/client/sync.rs b/src/api/client/sync.rs index 45a1c75b3eae422cb9ad3588fc34707b9e787d7e..e425616b0bcc722eab16d7646faa8f404f47e7a9 100644 --- a/src/api/client/sync.rs +++ b/src/api/client/sync.rs @@ -8,7 +8,7 @@ use conduit::{ error, utils::math::{ruma_from_u64, ruma_from_usize, usize_from_ruma, usize_from_u64_truncated}, - PduCount, + Err, PduCount, }; use ruma::{ api::client::{ @@ -199,7 +199,7 @@ pub(crate) async fn sync_events_route( let (room_id, invite_state_events) = result?; // Get and drop the lock to wait for remaining operations to finish - let insert_lock = services().globals.roomid_mutex_insert.lock(&room_id).await; + let insert_lock = services().rooms.timeline.mutex_insert.lock(&room_id).await; drop(insert_lock); let invite_count = services() @@ -317,7 +317,7 @@ async fn handle_left_room( next_batch_string: &str, full_state: bool, lazy_load_enabled: bool, ) -> Result<()> { // Get and drop the lock to wait for remaining operations to finish - let insert_lock = services().globals.roomid_mutex_insert.lock(room_id).await; + let insert_lock = services().rooms.timeline.mutex_insert.lock(room_id).await; drop(insert_lock); let left_count = services() @@ -519,7 +519,7 @@ async fn load_joined_room( ) -> Result<JoinedRoom> { // Get and drop the lock to wait for remaining operations to finish // This will make sure the we have all events until next_batch - let insert_lock = services().globals.roomid_mutex_insert.lock(room_id).await; + let insert_lock = services().rooms.timeline.mutex_insert.lock(room_id).await; drop(insert_lock); let (timeline_pdus, limited) = load_timeline(sender_user, room_id, sincecount, 10)?; @@ -545,8 +545,7 @@ async fn load_joined_room( // Database queries: let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? else { - error!("Room {} has no state", room_id); - return Err(Error::BadDatabase("Room has no state")); + return Err!(Database(error!("Room {room_id} has no state"))); }; let since_shortstatehash = services() diff --git a/src/api/router/auth.rs b/src/api/router/auth.rs index 08a08e08bff9406149709eeaca646a51aef4e3d9..6c2922b97cb40de8609596f5eb013dd82518712d 100644 --- a/src/api/router/auth.rs +++ b/src/api/router/auth.rs @@ -6,6 +6,7 @@ typed_header::TypedHeaderRejectionReason, TypedHeader, }; +use conduit::Err; use http::uri::PathAndQuery; use ruma::{ api::{client::error::ErrorKind, AuthScheme, Metadata}, @@ -183,7 +184,7 @@ fn auth_appservice(request: &Request, info: Box<RegistrationInfo>) -> Result<Aut async fn auth_server(request: &mut Request, json_body: &Option<CanonicalJsonValue>) -> Result<Auth> { if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); + return Err!(Config("allow_federation", "Federation is disabled.")); } let TypedHeader(Authorization(x_matrix)) = request diff --git a/src/api/routes.rs b/src/api/routes.rs index b22a32cb9bf2294f501125b0e4aae12493e2bffd..3a8b2c7428910c0c759c3274975b64dabab99e37 100644 --- a/src/api/routes.rs +++ b/src/api/routes.rs @@ -3,7 +3,7 @@ routing::{any, get, post}, Router, }; -use conduit::{Error, Server}; +use conduit::{err, Error, Server}; use http::Uri; use ruma::api::client::error::ErrorKind; @@ -236,4 +236,4 @@ async fn initial_sync(_uri: Uri) -> impl IntoResponse { Error::BadRequest(ErrorKind::GuestAccessForbidden, "Guest access not implemented") } -async fn federation_disabled() -> impl IntoResponse { Error::bad_config("Federation is disabled.") } +async fn federation_disabled() -> impl IntoResponse { err!(Config("allow_federation", "Federation is disabled.")) } diff --git a/src/api/server/make_join.rs b/src/api/server/make_join.rs index ca50dcbeb89e2a39ac0933c1fa09b48d3472c005..b5dadf7fdaa1abd085c3835f3d497013c68226c3 100644 --- a/src/api/server/make_join.rs +++ b/src/api/server/make_join.rs @@ -7,7 +7,7 @@ }, StateEventType, TimelineEventType, }, - RoomId, RoomVersionId, UserId, + CanonicalJsonObject, RoomId, RoomVersionId, UserId, }; use serde_json::value::to_raw_value; use tracing::warn; @@ -71,11 +71,7 @@ pub(crate) async fn create_join_event_template_route( let room_version_id = services().rooms.state.get_room_version(&body.room_id)?; - let state_lock = services() - .globals - .roomid_mutex_state - .lock(&body.room_id) - .await; + let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; let join_authorized_via_users_server = if (services() .rooms @@ -148,27 +144,7 @@ pub(crate) async fn create_join_event_template_route( drop(state_lock); // room v3 and above removed the "event_id" field from remote PDU format - match room_version_id { - RoomVersionId::V1 | RoomVersionId::V2 => {}, - RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 - | RoomVersionId::V8 - | RoomVersionId::V9 - | RoomVersionId::V10 - | RoomVersionId::V11 => { - pdu_json.remove("event_id"); - }, - _ => { - warn!("Unexpected or unsupported room version {room_version_id}"); - return Err(Error::BadRequest( - ErrorKind::BadJson, - "Unexpected or unsupported room version found", - )); - }, - }; + maybe_strip_event_id(&mut pdu_json, &room_version_id)?; Ok(prepare_join_event::v1::Response { room_version: Some(room_version_id), @@ -183,6 +159,8 @@ pub(crate) async fn create_join_event_template_route( pub(crate) fn user_can_perform_restricted_join( user_id: &UserId, room_id: &RoomId, room_version_id: &RoomVersionId, ) -> Result<bool> { + use RoomVersionId::*; + let join_rules_event = services() .rooms @@ -202,16 +180,7 @@ pub(crate) fn user_can_perform_restricted_join( return Ok(false); }; - if matches!( - room_version_id, - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 - ) { + if matches!(room_version_id, V1 | V2 | V3 | V4 | V5 | V6 | V7) { return Ok(false); } @@ -243,3 +212,23 @@ pub(crate) fn user_can_perform_restricted_join( )) } } + +pub(crate) fn maybe_strip_event_id(pdu_json: &mut CanonicalJsonObject, room_version_id: &RoomVersionId) -> Result<()> { + use RoomVersionId::*; + + match room_version_id { + V1 | V2 => {}, + V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 | V11 => { + pdu_json.remove("event_id"); + }, + _ => { + warn!("Unexpected or unsupported room version {room_version_id}"); + return Err(Error::BadRequest( + ErrorKind::BadJson, + "Unexpected or unsupported room version found", + )); + }, + }; + + Ok(()) +} diff --git a/src/api/server/make_leave.rs b/src/api/server/make_leave.rs index 62c0971712ba1ceb075304b9eeff4f5aa1d3e29c..63fc2b2eb28de30077e78281fc7c76ad196d1dee 100644 --- a/src/api/server/make_leave.rs +++ b/src/api/server/make_leave.rs @@ -1,14 +1,15 @@ +use conduit::{Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation::membership::prepare_leave_event}, events::{ room::member::{MembershipState, RoomMemberEventContent}, TimelineEventType, }, - RoomVersionId, }; use serde_json::value::to_raw_value; -use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma}; +use super::make_join::maybe_strip_event_id; +use crate::{service::pdu::PduBuilder, services, Ruma}; /// # `PUT /_matrix/federation/v1/make_leave/{roomId}/{eventId}` /// @@ -35,11 +36,7 @@ pub(crate) async fn create_leave_event_template_route( .acl_check(origin, &body.room_id)?; let room_version_id = services().rooms.state.get_room_version(&body.room_id)?; - let state_lock = services() - .globals - .roomid_mutex_state - .lock(&body.room_id) - .await; + let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; let content = to_raw_value(&RoomMemberEventContent { avatar_url: None, blurhash: None, @@ -68,26 +65,7 @@ pub(crate) async fn create_leave_event_template_route( drop(state_lock); // room v3 and above removed the "event_id" field from remote PDU format - match room_version_id { - RoomVersionId::V1 | RoomVersionId::V2 => {}, - RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 - | RoomVersionId::V8 - | RoomVersionId::V9 - | RoomVersionId::V10 - | RoomVersionId::V11 => { - pdu_json.remove("event_id"); - }, - _ => { - return Err(Error::BadRequest( - ErrorKind::BadJson, - "Unexpected or unsupported room version found", - )); - }, - }; + maybe_strip_event_id(&mut pdu_json, &room_version_id)?; Ok(prepare_leave_event::v1::Response { room_version: Some(room_version_id), diff --git a/src/api/server/send.rs b/src/api/server/send.rs index 90225a1450968eae2a79cb5c9acba17d96f6cbc1..122f564f26c1b95945f47956339234df737039e0 100644 --- a/src/api/server/send.rs +++ b/src/api/server/send.rs @@ -84,7 +84,7 @@ pub(crate) async fn send_transaction_message_route( Ok(send_transaction_message::v1::Response { pdus: resolved_map .into_iter() - .map(|(e, r)| (e, r.map_err(|e| e.sanitized_error()))) + .map(|(e, r)| (e, r.map_err(|e| e.sanitized_string()))) .collect(), }) } @@ -127,8 +127,9 @@ async fn handle_pdus( for (event_id, value, room_id) in parsed_pdus { let pdu_start_time = Instant::now(); let mutex_lock = services() - .globals - .roomid_mutex_federation + .rooms + .event_handler + .mutex_federation .lock(&room_id) .await; resolved_map.insert( diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs index ff362f64616acdbfbea06b1d1c6c2177589a3525..577833d55a23a5ae75900f7f6dd66db6a3754507 100644 --- a/src/api/server/send_join.rs +++ b/src/api/server/send_join.rs @@ -156,8 +156,9 @@ async fn create_join_event( .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "origin is not a server name."))?; let mutex_lock = services() - .globals - .roomid_mutex_federation + .rooms + .event_handler + .mutex_federation .lock(room_id) .await; let pdu_id: Vec<u8> = services() diff --git a/src/api/server/send_leave.rs b/src/api/server/send_leave.rs index 4fdde515ac6aa34a79aa6bc15160c9f98829674e..c4e17bbc32f851f3b2a786794f60f16b5a627787 100644 --- a/src/api/server/send_leave.rs +++ b/src/api/server/send_leave.rs @@ -152,8 +152,9 @@ async fn create_leave_event(origin: &ServerName, room_id: &RoomId, pdu: &RawJson .await?; let mutex_lock = services() - .globals - .roomid_mutex_federation + .rooms + .event_handler + .mutex_federation .lock(room_id) .await; let pdu_id: Vec<u8> = services() diff --git a/src/core/Cargo.toml b/src/core/Cargo.toml index e47f673e46dd4a03acf2f7d9200ea137878f23d7..453d7b13ec317d495fe380c9246dad2d7e9a809f 100644 --- a/src/core/Cargo.toml +++ b/src/core/Cargo.toml @@ -55,6 +55,7 @@ axum.workspace = true bytes.workspace = true checked_ops.workspace = true chrono.workspace = true +const-str.workspace = true either.workspace = true figment.workspace = true http-body-util.workspace = true @@ -81,6 +82,7 @@ tikv-jemalloc-ctl.workspace = true tikv-jemalloc-sys.optional = true tikv-jemalloc-sys.workspace = true tokio.workspace = true +tokio-metrics.workspace = true tracing-core.workspace = true tracing-subscriber.workspace = true tracing.workspace = true diff --git a/src/core/alloc/default.rs b/src/core/alloc/default.rs index 4e2f8d7ec6cb8dd1df0404ba7818f58da7af6bc4..83bfca7d4f8c224654cf65fee523677439df6990 100644 --- a/src/core/alloc/default.rs +++ b/src/core/alloc/default.rs @@ -1,9 +1,9 @@ //! Default allocator with no special features -/// Always returns the empty string +/// Always returns None #[must_use] -pub fn memory_stats() -> String { String::default() } +pub fn memory_stats() -> Option<String> { None } -/// Always returns the empty string +/// Always returns None #[must_use] -pub fn memory_usage() -> String { String::default() } +pub fn memory_usage() -> Option<String> { None } diff --git a/src/core/alloc/hardened.rs b/src/core/alloc/hardened.rs index 6727407ffadb2f1e48f83ae0de6cf11bec7a362b..335a33078c2ba799efeadf3227f72dff0e31191f 100644 --- a/src/core/alloc/hardened.rs +++ b/src/core/alloc/hardened.rs @@ -4,9 +4,10 @@ static HMALLOC: hardened_malloc_rs::HardenedMalloc = hardened_malloc_rs::HardenedMalloc; #[must_use] -pub fn memory_usage() -> String { - String::default() //TODO: get usage -} +//TODO: get usage +pub fn memory_usage() -> Option<string> { None } #[must_use] -pub fn memory_stats() -> String { "Extended statistics are not available from hardened_malloc.".to_owned() } +pub fn memory_stats() -> Option<String> { + Some("Extended statistics are not available from hardened_malloc.".to_owned()) +} diff --git a/src/core/alloc/je.rs b/src/core/alloc/je.rs index 5e3c361fa2730a8f7567f85786034bb79584b976..08bfc49ada7f29320fc17d5b562ad056ce994453 100644 --- a/src/core/alloc/je.rs +++ b/src/core/alloc/je.rs @@ -10,7 +10,7 @@ static JEMALLOC: jemalloc::Jemalloc = jemalloc::Jemalloc; #[must_use] -pub fn memory_usage() -> String { +pub fn memory_usage() -> Option<String> { use mallctl::stats; let mibs = |input: Result<usize, mallctl::Error>| { @@ -27,14 +27,14 @@ pub fn memory_usage() -> String { let metadata = mibs(stats::metadata::read()); let resident = mibs(stats::resident::read()); let retained = mibs(stats::retained::read()); - format!( + Some(format!( "allocated: {allocated:.2} MiB\nactive: {active:.2} MiB\nmapped: {mapped:.2} MiB\nmetadata: {metadata:.2} \ MiB\nresident: {resident:.2} MiB\nretained: {retained:.2} MiB\n" - ) + )) } #[must_use] -pub fn memory_stats() -> String { +pub fn memory_stats() -> Option<String> { const MAX_LENGTH: usize = 65536 - 4096; let opts_s = "d"; @@ -51,7 +51,7 @@ pub fn memory_stats() -> String { unsafe { ffi::malloc_stats_print(Some(malloc_stats_cb), opaque, opts_p) }; str.truncate(MAX_LENGTH); - format!("<pre><code>{str}</code></pre>") + Some(format!("<pre><code>{str}</code></pre>")) } extern "C" fn malloc_stats_cb(opaque: *mut c_void, msg: *const c_char) { diff --git a/src/core/config/check.rs b/src/core/config/check.rs index 95b4df4ea0835cf2ba96b66602a79a89fe783b9c..b36b9c5e5f6170168e9978308a04d7e769869e3e 100644 --- a/src/core/config/check.rs +++ b/src/core/config/check.rs @@ -1,50 +1,63 @@ -use crate::{error::Error, warn, Config}; - -pub fn check(config: &Config) -> Result<(), Error> { - #[cfg(all(feature = "rocksdb", not(feature = "sha256_media")))] // prevents catching this in `--all-features` - warn!( - "Note the rocksdb feature was deleted from conduwuit. SQLite support was removed and RocksDB is the only \ - supported backend now. Please update your build script to remove this feature." - ); - #[cfg(all(feature = "sha256_media", not(feature = "rocksdb")))] // prevents catching this in `--all-features` - warn!( - "Note the sha256_media feature was deleted from conduwuit, it is now fully integrated in a \ - forwards-compatible way. Please update your build script to remove this feature." - ); - - config.warn_deprecated(); - config.warn_unknown_key(); +use figment::Figment; + +use super::DEPRECATED_KEYS; +use crate::{debug, debug_info, error, info, warn, Config, Err, Result}; + +#[allow(clippy::cognitive_complexity)] +pub fn check(config: &Config) -> Result<()> { + if cfg!(debug_assertions) { + info!("Note: conduwuit was built without optimisations (i.e. debug build)"); + } + + // prevents catching this in `--all-features` + if cfg!(all(feature = "rocksdb", not(feature = "sha256_media"))) { + warn!( + "Note the rocksdb feature was deleted from conduwuit. SQLite support was removed and RocksDB is the only \ + supported backend now. Please update your build script to remove this feature." + ); + } + + // prevents catching this in `--all-features` + if cfg!(all(feature = "sha256_media", not(feature = "rocksdb"))) { + warn!( + "Note the sha256_media feature was deleted from conduwuit, it is now fully integrated in a \ + forwards-compatible way. Please update your build script to remove this feature." + ); + } + + warn_deprecated(config); + warn_unknown_key(config); if config.sentry && config.sentry_endpoint.is_none() { - return Err(Error::bad_config("Sentry cannot be enabled without an endpoint set")); + return Err!(Config("sentry_endpoint", "Sentry cannot be enabled without an endpoint set")); } - #[cfg(all(feature = "hardened_malloc", feature = "jemalloc"))] - warn!( - "hardened_malloc and jemalloc are both enabled, this causes jemalloc to be used. If using --all-features, \ - this is harmless." - ); + if cfg!(all(feature = "hardened_malloc", feature = "jemalloc")) { + warn!( + "hardened_malloc and jemalloc are both enabled, this causes jemalloc to be used. If using --all-features, \ + this is harmless." + ); + } - #[cfg(not(unix))] - if config.unix_socket_path.is_some() { - return Err(Error::bad_config( - "UNIX socket support is only available on *nix platforms. Please remove \"unix_socket_path\" from your \ - config.", + if cfg!(not(unix)) && config.unix_socket_path.is_some() { + return Err!(Config( + "unix_socket_path", + "UNIX socket support is only available on *nix platforms. Please remove 'unix_socket_path' from your \ + config." )); } - #[cfg(unix)] - if config.unix_socket_path.is_none() { + if cfg!(unix) && config.unix_socket_path.is_none() { config.get_bind_addrs().iter().for_each(|addr| { use std::path::Path; if addr.ip().is_loopback() { - crate::debug_info!("Found loopback listening address {addr}, running checks if we're in a container.",); + debug_info!("Found loopback listening address {addr}, running checks if we're in a container."); if Path::new("/proc/vz").exists() /* Guest */ && !Path::new("/proc/bz").exists() /* Host */ { - crate::error!( + error!( "You are detected using OpenVZ with a loopback/localhost listening address of {addr}. If you \ are using OpenVZ for containers and you use NAT-based networking to communicate with the \ host and guest, this will NOT work. Please change this to \"0.0.0.0\". If this is expected, \ @@ -53,7 +66,7 @@ pub fn check(config: &Config) -> Result<(), Error> { } if Path::new("/.dockerenv").exists() { - crate::error!( + error!( "You are detected using Docker with a loopback/localhost listening address of {addr}. If you \ are using a reverse proxy on the host and require communication to conduwuit in the Docker \ container via NAT-based networking, this will NOT work. Please change this to \"0.0.0.0\". \ @@ -62,7 +75,7 @@ pub fn check(config: &Config) -> Result<(), Error> { } if Path::new("/run/.containerenv").exists() { - crate::error!( + error!( "You are detected using Podman with a loopback/localhost listening address of {addr}. If you \ are using a reverse proxy on the host and require communication to conduwuit in the Podman \ container via NAT-based networking, this will NOT work. Please change this to \"0.0.0.0\". \ @@ -75,36 +88,39 @@ pub fn check(config: &Config) -> Result<(), Error> { // rocksdb does not allow max_log_files to be 0 if config.rocksdb_max_log_files == 0 { - return Err(Error::bad_config( - "rocksdb_max_log_files cannot be 0. Please set a value at least 1.", + return Err!(Config( + "max_log_files", + "rocksdb_max_log_files cannot be 0. Please set a value at least 1." )); } // yeah, unless the user built a debug build hopefully for local testing only - #[cfg(not(debug_assertions))] - if config.server_name == "your.server.name" { - return Err(Error::bad_config( - "You must specify a valid server name for production usage of conduwuit.", + if cfg!(not(debug_assertions)) && config.server_name == "your.server.name" { + return Err!(Config( + "server_name", + "You must specify a valid server name for production usage of conduwuit." )); } - #[cfg(debug_assertions)] - crate::info!("Note: conduwuit was built without optimisations (i.e. debug build)"); - // check if the user specified a registration token as `""` if config.registration_token == Some(String::new()) { - return Err(Error::bad_config("Registration token was specified but is empty (\"\")")); + return Err!(Config( + "registration_token", + "Registration token was specified but is empty (\"\")" + )); } if config.max_request_size < 5_120_000 { - return Err(Error::bad_config("Max request size is less than 5MB. Please increase it.")); + return Err!(Config( + "max_request_size", + "Max request size is less than 5MB. Please increase it." + )); } // check if user specified valid IP CIDR ranges on startup for cidr in &config.ip_range_denylist { if let Err(e) = ipaddress::IPAddress::parse(cidr) { - crate::error!("Error parsing specified IP CIDR range from string: {e}"); - return Err(Error::bad_config("Error parsing specified IP CIDR ranges from strings")); + return Err!(Config("ip_range_denylist", "Parsing specified IP CIDR range from string: {e}.")); } } @@ -112,13 +128,14 @@ pub fn check(config: &Config) -> Result<(), Error> { && !config.yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse && config.registration_token.is_none() { - return Err(Error::bad_config( + return Err!(Config( + "registration_token", "!! You have `allow_registration` enabled without a token configured in your config which means you are \ allowing ANYONE to register on your conduwuit instance without any 2nd-step (e.g. registration token).\n If this is not the intended behaviour, please set a registration token with the `registration_token` config option.\n For security and safety reasons, conduwuit will shut down. If you are extra sure this is the desired behaviour you \ want, please set the following config option to true: -`yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse`", +`yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse`" )); } @@ -135,8 +152,9 @@ pub fn check(config: &Config) -> Result<(), Error> { } if config.allow_outgoing_presence && !config.allow_local_presence { - return Err(Error::bad_config( - "Outgoing presence requires allowing local presence. Please enable \"allow_local_presence\".", + return Err!(Config( + "allow_local_presence", + "Outgoing presence requires allowing local presence. Please enable 'allow_local_presence'." )); } @@ -173,3 +191,52 @@ pub fn check(config: &Config) -> Result<(), Error> { Ok(()) } + +/// Iterates over all the keys in the config file and warns if there is a +/// deprecated key specified +fn warn_deprecated(config: &Config) { + debug!("Checking for deprecated config keys"); + let mut was_deprecated = false; + for key in config + .catchall + .keys() + .filter(|key| DEPRECATED_KEYS.iter().any(|s| s == key)) + { + warn!("Config parameter \"{}\" is deprecated, ignoring.", key); + was_deprecated = true; + } + + if was_deprecated { + warn!( + "Read conduwuit config documentation at https://conduwuit.puppyirl.gay/configuration.html and check your \ + configuration if any new configuration parameters should be adjusted" + ); + } +} + +/// iterates over all the catchall keys (unknown config options) and warns +/// if there are any. +fn warn_unknown_key(config: &Config) { + debug!("Checking for unknown config keys"); + for key in config + .catchall + .keys() + .filter(|key| "config".to_owned().ne(key.to_owned()) /* "config" is expected */) + { + warn!("Config parameter \"{}\" is unknown to conduwuit, ignoring.", key); + } +} + +/// Checks the presence of the `address` and `unix_socket_path` keys in the +/// raw_config, exiting the process if both keys were detected. +pub(super) fn is_dual_listening(raw_config: &Figment) -> Result<()> { + let contains_address = raw_config.contains("address"); + let contains_unix_socket = raw_config.contains("unix_socket_path"); + if contains_address && contains_unix_socket { + return Err!( + "TOML keys \"address\" and \"unix_socket_path\" were both defined. Please specify only one option." + ); + } + + Ok(()) +} diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index f9a5ec4c70bbd30aa7b8991211daa4b7e9be91e8..336144db52ec2b8d948f1adfd05cde9b56ef45a3 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -1,6 +1,6 @@ use std::{ collections::BTreeMap, - fmt::{self, Write as _}, + fmt, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, path::PathBuf, }; @@ -19,30 +19,15 @@ api::client::discovery::discover_support::ContactRole, OwnedRoomId, OwnedServerName, OwnedUserId, RoomVersionId, }; use serde::{de::IgnoredAny, Deserialize}; -use tracing::{debug, error, warn}; use url::Url; pub use self::check::check; use self::proxy::ProxyConfig; -use crate::error::Error; +use crate::{error::Error, Err, Result}; pub mod check; pub mod proxy; -#[derive(Deserialize, Clone, Debug)] -#[serde(transparent)] -struct ListeningPort { - #[serde(with = "either::serde_untagged")] - ports: Either<u16, Vec<u16>>, -} - -#[derive(Deserialize, Clone, Debug)] -#[serde(transparent)] -struct ListeningAddr { - #[serde(with = "either::serde_untagged")] - addrs: Either<IpAddr, Vec<IpAddr>>, -} - /// all the config options for conduwuit #[derive(Clone, Debug, Deserialize)] #[allow(clippy::struct_excessive_bools)] @@ -181,16 +166,14 @@ pub struct Config { #[serde(default)] pub well_known: WellKnownConfig, #[serde(default)] - #[cfg(feature = "perf_measurements")] pub allow_jaeger: bool, + #[serde(default = "default_jaeger_filter")] + pub jaeger_filter: String, #[serde(default)] - #[cfg(feature = "perf_measurements")] pub tracing_flame: bool, #[serde(default = "default_tracing_flame_filter")] - #[cfg(feature = "perf_measurements")] pub tracing_flame_filter: String, #[serde(default = "default_tracing_flame_output_path")] - #[cfg(feature = "perf_measurements")] pub tracing_flame_output_path: String, #[serde(default)] pub proxy: ProxyConfig, @@ -356,6 +339,14 @@ pub struct Config { pub sentry_send_server_name: bool, #[serde(default = "default_sentry_traces_sample_rate")] pub sentry_traces_sample_rate: f32, + #[serde(default)] + pub sentry_attach_stacktrace: bool, + #[serde(default = "true_fn")] + pub sentry_send_panic: bool, + #[serde(default = "true_fn")] + pub sentry_send_error: bool, + #[serde(default = "default_sentry_filter")] + pub sentry_filter: String, #[serde(default)] pub tokio_console: bool, @@ -386,6 +377,20 @@ pub struct WellKnownConfig { pub support_mxid: Option<OwnedUserId>, } +#[derive(Deserialize, Clone, Debug)] +#[serde(transparent)] +struct ListeningPort { + #[serde(with = "either::serde_untagged")] + ports: Either<u16, Vec<u16>>, +} + +#[derive(Deserialize, Clone, Debug)] +#[serde(transparent)] +struct ListeningAddr { + #[serde(with = "either::serde_untagged")] + addrs: Either<IpAddr, Vec<IpAddr>>, +} + const DEPRECATED_KEYS: &[&str] = &[ "cache_capacity", "max_concurrent_requests", @@ -399,7 +404,7 @@ pub struct WellKnownConfig { impl Config { /// Initialize config - pub fn new(path: Option<PathBuf>) -> Result<Self, Error> { + pub fn new(path: Option<PathBuf>) -> Result<Self> { let raw_config = if let Some(config_file_env) = Env::var("CONDUIT_CONFIG") { Figment::new() .merge(Toml::file(config_file_env).nested()) @@ -422,69 +427,16 @@ pub fn new(path: Option<PathBuf>) -> Result<Self, Error> { }; let config = match raw_config.extract::<Self>() { - Err(e) => return Err(Error::BadConfig(format!("{e}"))), + Err(e) => return Err!("There was a problem with your configuration file: {e}"), Ok(config) => config, }; // don't start if we're listening on both UNIX sockets and TCP at same time - if Self::is_dual_listening(&raw_config) { - return Err(Error::bad_config("dual listening on UNIX and TCP sockets not allowed.")); - }; + check::is_dual_listening(&raw_config)?; Ok(config) } - /// Iterates over all the keys in the config file and warns if there is a - /// deprecated key specified - pub(crate) fn warn_deprecated(&self) { - debug!("Checking for deprecated config keys"); - let mut was_deprecated = false; - for key in self - .catchall - .keys() - .filter(|key| DEPRECATED_KEYS.iter().any(|s| s == key)) - { - warn!("Config parameter \"{}\" is deprecated, ignoring.", key); - was_deprecated = true; - } - - if was_deprecated { - warn!( - "Read conduwuit config documentation at https://conduwuit.puppyirl.gay/configuration.html and check \ - your configuration if any new configuration parameters should be adjusted" - ); - } - } - - /// iterates over all the catchall keys (unknown config options) and warns - /// if there are any. - pub(crate) fn warn_unknown_key(&self) { - debug!("Checking for unknown config keys"); - for key in self - .catchall - .keys() - .filter(|key| "config".to_owned().ne(key.to_owned()) /* "config" is expected */) - { - warn!("Config parameter \"{}\" is unknown to conduwuit, ignoring.", key); - } - } - - /// Checks the presence of the `address` and `unix_socket_path` keys in the - /// raw_config, exiting the process if both keys were detected. - fn is_dual_listening(raw_config: &Figment) -> bool { - let check_address = raw_config.find_value("address"); - let check_unix_socket = raw_config.find_value("unix_socket_path"); - - // are the check_address and check_unix_socket keys both Ok (specified) at the - // same time? - if check_address.is_ok() && check_unix_socket.is_ok() { - error!("TOML keys \"address\" and \"unix_socket_path\" were both defined. Please specify only one option."); - return true; - } - - false - } - #[must_use] pub fn get_bind_addrs(&self) -> Vec<SocketAddr> { let mut addrs = Vec::new(); @@ -516,361 +468,358 @@ pub fn check(&self) -> Result<(), Error> { check(self) } impl fmt::Display for Config { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - // Prepare a list of config values to show - let lines = [ - ("Server name", self.server_name.host()), - ("Database backend", &self.database_backend), - ("Database path", &self.database_path.to_string_lossy()), - ( - "Database backup path", - self.database_backup_path - .as_ref() - .map_or("", |path| path.to_str().unwrap_or("")), - ), - ("Database backups to keep", &self.database_backups_to_keep.to_string()), - ("Database cache capacity (MB)", &self.db_cache_capacity_mb.to_string()), - ("Cache capacity modifier", &self.conduit_cache_capacity_modifier.to_string()), - ("PDU cache capacity", &self.pdu_cache_capacity.to_string()), - ("Auth chain cache capacity", &self.auth_chain_cache_capacity.to_string()), - ("Short eventid cache capacity", &self.shorteventid_cache_capacity.to_string()), - ("Eventid short cache capacity", &self.eventidshort_cache_capacity.to_string()), - ("Short statekey cache capacity", &self.shortstatekey_cache_capacity.to_string()), - ("Statekey short cache capacity", &self.statekeyshort_cache_capacity.to_string()), - ( - "Server visibility cache capacity", - &self.server_visibility_cache_capacity.to_string(), - ), - ( - "User visibility cache capacity", - &self.user_visibility_cache_capacity.to_string(), - ), - ("Stateinfo cache capacity", &self.stateinfo_cache_capacity.to_string()), - ( - "Roomid space hierarchy cache capacity", - &self.roomid_spacehierarchy_cache_capacity.to_string(), - ), - ("DNS cache entry limit", &self.dns_cache_entries.to_string()), - ("DNS minimum TTL", &self.dns_min_ttl.to_string()), - ("DNS minimum NXDOMAIN TTL", &self.dns_min_ttl_nxdomain.to_string()), - ("DNS attempts", &self.dns_attempts.to_string()), - ("DNS timeout", &self.dns_timeout.to_string()), - ("DNS fallback to TCP", &self.dns_tcp_fallback.to_string()), - ("DNS query over TCP only", &self.query_over_tcp_only.to_string()), - ("Query all nameservers", &self.query_all_nameservers.to_string()), - ("Maximum request size (bytes)", &self.max_request_size.to_string()), - ("Sender retry backoff limit", &self.sender_retry_backoff_limit.to_string()), - ("Request connect timeout", &self.request_conn_timeout.to_string()), - ("Request timeout", &self.request_timeout.to_string()), - ("Request total timeout", &self.request_total_timeout.to_string()), - ("Idle connections per host", &self.request_idle_per_host.to_string()), - ("Request pool idle timeout", &self.request_idle_timeout.to_string()), - ("Well_known connect timeout", &self.well_known_conn_timeout.to_string()), - ("Well_known timeout", &self.well_known_timeout.to_string()), - ("Federation timeout", &self.federation_timeout.to_string()), - ("Federation pool idle per host", &self.federation_idle_per_host.to_string()), - ("Federation pool idle timeout", &self.federation_idle_timeout.to_string()), - ("Sender timeout", &self.sender_timeout.to_string()), - ("Sender pool idle timeout", &self.sender_idle_timeout.to_string()), - ("Appservice timeout", &self.appservice_timeout.to_string()), - ("Appservice pool idle timeout", &self.appservice_idle_timeout.to_string()), - ("Pusher pool idle timeout", &self.pusher_idle_timeout.to_string()), - ("Allow registration", &self.allow_registration.to_string()), - ( - "Registration token", - if self.registration_token.is_some() { - "set" - } else { - "not set (open registration!)" - }, - ), - ( - "Allow guest registration (inherently false if allow registration is false)", - &self.allow_guest_registration.to_string(), - ), - ( - "Log guest registrations in admin room", - &self.log_guest_registrations.to_string(), - ), - ( - "Allow guests to auto join rooms", - &self.allow_guests_auto_join_rooms.to_string(), - ), - ("New user display name suffix", &self.new_user_displayname_suffix), - ("Allow encryption", &self.allow_encryption.to_string()), - ("Allow federation", &self.allow_federation.to_string()), - ( - "Allow incoming federated presence requests (updates)", - &self.allow_incoming_presence.to_string(), - ), - ( - "Allow outgoing federated presence requests (updates)", - &self.allow_outgoing_presence.to_string(), - ), - ( - "Allow local presence requests (updates)", - &self.allow_local_presence.to_string(), - ), - ( - "Allow incoming remote read receipts", - &self.allow_incoming_read_receipts.to_string(), - ), - ( - "Allow outgoing remote read receipts", - &self.allow_outgoing_read_receipts.to_string(), - ), - ( - "Block non-admin room invites (local and remote, admins can still send and receive invites)", - &self.block_non_admin_invites.to_string(), - ), - ("Enable admin escape commands", &self.admin_escape_commands.to_string()), - ("Allow outgoing federated typing", &self.allow_outgoing_typing.to_string()), - ("Allow incoming federated typing", &self.allow_incoming_typing.to_string()), - ( - "Incoming federated typing timeout", - &self.typing_federation_timeout_s.to_string(), - ), - ("Client typing timeout minimum", &self.typing_client_timeout_min_s.to_string()), - ("Client typing timeout maxmimum", &self.typing_client_timeout_max_s.to_string()), - ("Allow device name federation", &self.allow_device_name_federation.to_string()), - ( - "Allow incoming profile lookup federation requests", - &self.allow_profile_lookup_federation_requests.to_string(), - ), - ( - "Auto deactivate banned room join attempts", - &self.auto_deactivate_banned_room_attempts.to_string(), - ), - ("Notification push path", &self.notification_push_path), - ("Allow room creation", &self.allow_room_creation.to_string()), - ( - "Allow public room directory over federation", - &self.allow_public_room_directory_over_federation.to_string(), - ), - ( - "Allow public room directory without authentication", - &self.allow_public_room_directory_without_auth.to_string(), - ), - ( - "Lockdown public room directory (only allow admins to publish)", - &self.lockdown_public_room_directory.to_string(), - ), - ( - "JWT secret", - match self.jwt_secret { - Some(_) => "set", - None => "not set", - }, - ), - ( - "Trusted key servers", - &self - .trusted_servers - .iter() - .map(|server| server.host()) - .join(", "), - ), - ( - "Query Trusted Key Servers First", - &self.query_trusted_key_servers_first.to_string(), - ), - ("OpenID Token TTL", &self.openid_token_ttl.to_string()), - ( - "TURN username", - if self.turn_username.is_empty() { - "not set" - } else { - &self.turn_username - }, - ), - ("TURN password", { - if self.turn_password.is_empty() { - "not set" - } else { - "set" - } - }), - ("TURN secret", { - if self.turn_secret.is_empty() { - "not set" - } else { - "set" - } - }), - ("Turn TTL", &self.turn_ttl.to_string()), - ("Turn URIs", { - let mut lst = vec![]; - for item in self.turn_uris.iter().cloned().enumerate() { - let (_, uri): (usize, String) = item; - lst.push(uri); - } - &lst.join(", ") - }), - ("Auto Join Rooms", { - let mut lst = vec![]; - for room in &self.auto_join_rooms { - lst.push(room); - } - &lst.into_iter().join(", ") - }), - #[cfg(feature = "zstd_compression")] - ("Zstd HTTP Compression", &self.zstd_compression.to_string()), - #[cfg(feature = "gzip_compression")] - ("Gzip HTTP Compression", &self.gzip_compression.to_string()), - #[cfg(feature = "brotli_compression")] - ("Brotli HTTP Compression", &self.brotli_compression.to_string()), - ("RocksDB database LOG level", &self.rocksdb_log_level), - ("RocksDB database LOG to stderr", &self.rocksdb_log_stderr.to_string()), - ("RocksDB database LOG time-to-roll", &self.rocksdb_log_time_to_roll.to_string()), - ("RocksDB Max LOG Files", &self.rocksdb_max_log_files.to_string()), - ( - "RocksDB database max LOG file size", - &self.rocksdb_max_log_file_size.to_string(), - ), - ( - "RocksDB database optimize for spinning disks", - &self.rocksdb_optimize_for_spinning_disks.to_string(), - ), - ("RocksDB Direct-IO", &self.rocksdb_direct_io.to_string()), - ("RocksDB Parallelism Threads", &self.rocksdb_parallelism_threads.to_string()), - ("RocksDB Compression Algorithm", &self.rocksdb_compression_algo), - ("RocksDB Compression Level", &self.rocksdb_compression_level.to_string()), - ( - "RocksDB Bottommost Compression Level", - &self.rocksdb_bottommost_compression_level.to_string(), - ), - ( - "RocksDB Bottommost Level Compression", - &self.rocksdb_bottommost_compression.to_string(), - ), - ("RocksDB Recovery Mode", &self.rocksdb_recovery_mode.to_string()), - ("RocksDB Repair Mode", &self.rocksdb_repair.to_string()), - ("RocksDB Read-only Mode", &self.rocksdb_read_only.to_string()), - ( - "RocksDB Compaction Idle Priority", - &self.rocksdb_compaction_prio_idle.to_string(), - ), - ( - "RocksDB Compaction Idle IOPriority", - &self.rocksdb_compaction_ioprio_idle.to_string(), - ), - ("Media integrity checks on startup", &self.media_startup_check.to_string()), - ("Media compatibility filesystem links", &self.media_compat_file_link.to_string()), - ("Prevent Media Downloads From", { - let mut lst = vec![]; - for domain in &self.prevent_media_downloads_from { - lst.push(domain.host()); - } - &lst.join(", ") - }), - ("Forbidden Remote Server Names (\"Global\" ACLs)", { - let mut lst = vec![]; - for domain in &self.forbidden_remote_server_names { - lst.push(domain.host()); - } - &lst.join(", ") - }), - ("Forbidden Remote Room Directory Server Names", { - let mut lst = vec![]; - for domain in &self.forbidden_remote_room_directory_server_names { - lst.push(domain.host()); - } - &lst.join(", ") - }), - ("Outbound Request IP Range Denylist", { - let mut lst = vec![]; - for item in self.ip_range_denylist.iter().cloned().enumerate() { - let (_, ip): (usize, String) = item; - lst.push(ip); - } - &lst.join(", ") - }), - ("Forbidden usernames", { - &self.forbidden_usernames.patterns().iter().join(", ") - }), - ("Forbidden room aliases", { - &self.forbidden_alias_names.patterns().iter().join(", ") - }), - ( - "URL preview domain contains allowlist", - &self.url_preview_domain_contains_allowlist.join(", "), - ), - ( - "URL preview domain explicit allowlist", - &self.url_preview_domain_explicit_allowlist.join(", "), - ), - ( - "URL preview domain explicit denylist", - &self.url_preview_domain_explicit_denylist.join(", "), - ), - ( - "URL preview URL contains allowlist", - &self.url_preview_url_contains_allowlist.join(", "), - ), - ("URL preview maximum spider size", &self.url_preview_max_spider_size.to_string()), - ("URL preview check root domain", &self.url_preview_check_root_domain.to_string()), - ( - "Allow check for updates / announcements check", - &self.allow_check_for_updates.to_string(), - ), - ("Enable netburst on startup", &self.startup_netburst.to_string()), - #[cfg(feature = "sentry_telemetry")] - ("Sentry.io reporting and tracing", &self.sentry.to_string()), - #[cfg(feature = "sentry_telemetry")] - ("Sentry.io send server_name in logs", &self.sentry_send_server_name.to_string()), - #[cfg(feature = "sentry_telemetry")] - ("Sentry.io tracing sample rate", &self.sentry_traces_sample_rate.to_string()), - ( - "Well-known server name", - self.well_known - .server - .as_ref() - .map_or("", |server| server.as_str()), - ), - ( - "Well-known client URL", - self.well_known - .client - .as_ref() - .map_or("", |url| url.as_str()), - ), - ( - "Well-known support email", - self.well_known - .support_email - .as_ref() - .map_or("", |str| str.as_ref()), - ), - ( - "Well-known support Matrix ID", - self.well_known - .support_mxid - .as_ref() - .map_or("", |mxid| mxid.as_str()), - ), - ( - "Well-known support role", - self.well_known - .support_role - .as_ref() - .map_or("", |role| role.as_str()), - ), - ( - "Well-known support page/URL", - self.well_known - .support_page - .as_ref() - .map_or("", |url| url.as_str()), - ), - ("Enable the tokio-console", &self.tokio_console.to_string()), - ]; - - let mut msg: String = "Active config values:\n\n".to_owned(); - - for line in lines.into_iter().enumerate() { - writeln!(msg, "{}: {}", line.1 .0, line.1 .1).expect("should be able to write to string buffer"); - } + writeln!(f, "Active config values:\n\n").expect("wrote line to formatter stream"); + let mut line = |key: &str, val: &str| { + writeln!(f, "{key}: {val}").expect("wrote line to formatter stream"); + }; - write!(f, "{msg}") + line("Server name", self.server_name.host()); + line("Database backend", &self.database_backend); + line("Database path", &self.database_path.to_string_lossy()); + line( + "Database backup path", + self.database_backup_path + .as_ref() + .map_or("", |path| path.to_str().unwrap_or("")), + ); + line("Database backups to keep", &self.database_backups_to_keep.to_string()); + line("Database cache capacity (MB)", &self.db_cache_capacity_mb.to_string()); + line("Cache capacity modifier", &self.conduit_cache_capacity_modifier.to_string()); + line("PDU cache capacity", &self.pdu_cache_capacity.to_string()); + line("Auth chain cache capacity", &self.auth_chain_cache_capacity.to_string()); + line("Short eventid cache capacity", &self.shorteventid_cache_capacity.to_string()); + line("Eventid short cache capacity", &self.eventidshort_cache_capacity.to_string()); + line("Short statekey cache capacity", &self.shortstatekey_cache_capacity.to_string()); + line("Statekey short cache capacity", &self.statekeyshort_cache_capacity.to_string()); + line( + "Server visibility cache capacity", + &self.server_visibility_cache_capacity.to_string(), + ); + line( + "User visibility cache capacity", + &self.user_visibility_cache_capacity.to_string(), + ); + line("Stateinfo cache capacity", &self.stateinfo_cache_capacity.to_string()); + line( + "Roomid space hierarchy cache capacity", + &self.roomid_spacehierarchy_cache_capacity.to_string(), + ); + line("DNS cache entry limit", &self.dns_cache_entries.to_string()); + line("DNS minimum TTL", &self.dns_min_ttl.to_string()); + line("DNS minimum NXDOMAIN TTL", &self.dns_min_ttl_nxdomain.to_string()); + line("DNS attempts", &self.dns_attempts.to_string()); + line("DNS timeout", &self.dns_timeout.to_string()); + line("DNS fallback to TCP", &self.dns_tcp_fallback.to_string()); + line("DNS query over TCP only", &self.query_over_tcp_only.to_string()); + line("Query all nameservers", &self.query_all_nameservers.to_string()); + line("Maximum request size (bytes)", &self.max_request_size.to_string()); + line("Sender retry backoff limit", &self.sender_retry_backoff_limit.to_string()); + line("Request connect timeout", &self.request_conn_timeout.to_string()); + line("Request timeout", &self.request_timeout.to_string()); + line("Request total timeout", &self.request_total_timeout.to_string()); + line("Idle connections per host", &self.request_idle_per_host.to_string()); + line("Request pool idle timeout", &self.request_idle_timeout.to_string()); + line("Well_known connect timeout", &self.well_known_conn_timeout.to_string()); + line("Well_known timeout", &self.well_known_timeout.to_string()); + line("Federation timeout", &self.federation_timeout.to_string()); + line("Federation pool idle per host", &self.federation_idle_per_host.to_string()); + line("Federation pool idle timeout", &self.federation_idle_timeout.to_string()); + line("Sender timeout", &self.sender_timeout.to_string()); + line("Sender pool idle timeout", &self.sender_idle_timeout.to_string()); + line("Appservice timeout", &self.appservice_timeout.to_string()); + line("Appservice pool idle timeout", &self.appservice_idle_timeout.to_string()); + line("Pusher pool idle timeout", &self.pusher_idle_timeout.to_string()); + line("Allow registration", &self.allow_registration.to_string()); + line( + "Registration token", + if self.registration_token.is_some() { + "set" + } else { + "not set (open registration!)" + }, + ); + line( + "Allow guest registration (inherently false if allow registration is false)", + &self.allow_guest_registration.to_string(), + ); + line( + "Log guest registrations in admin room", + &self.log_guest_registrations.to_string(), + ); + line( + "Allow guests to auto join rooms", + &self.allow_guests_auto_join_rooms.to_string(), + ); + line("New user display name suffix", &self.new_user_displayname_suffix); + line("Allow encryption", &self.allow_encryption.to_string()); + line("Allow federation", &self.allow_federation.to_string()); + line( + "Allow incoming federated presence requests (updates)", + &self.allow_incoming_presence.to_string(), + ); + line( + "Allow outgoing federated presence requests (updates)", + &self.allow_outgoing_presence.to_string(), + ); + line( + "Allow local presence requests (updates)", + &self.allow_local_presence.to_string(), + ); + line( + "Allow incoming remote read receipts", + &self.allow_incoming_read_receipts.to_string(), + ); + line( + "Allow outgoing remote read receipts", + &self.allow_outgoing_read_receipts.to_string(), + ); + line( + "Block non-admin room invites (local and remote, admins can still send and receive invites)", + &self.block_non_admin_invites.to_string(), + ); + line("Enable admin escape commands", &self.admin_escape_commands.to_string()); + line("Allow outgoing federated typing", &self.allow_outgoing_typing.to_string()); + line("Allow incoming federated typing", &self.allow_incoming_typing.to_string()); + line( + "Incoming federated typing timeout", + &self.typing_federation_timeout_s.to_string(), + ); + line("Client typing timeout minimum", &self.typing_client_timeout_min_s.to_string()); + line("Client typing timeout maxmimum", &self.typing_client_timeout_max_s.to_string()); + line("Allow device name federation", &self.allow_device_name_federation.to_string()); + line( + "Allow incoming profile lookup federation requests", + &self.allow_profile_lookup_federation_requests.to_string(), + ); + line( + "Auto deactivate banned room join attempts", + &self.auto_deactivate_banned_room_attempts.to_string(), + ); + line("Notification push path", &self.notification_push_path); + line("Allow room creation", &self.allow_room_creation.to_string()); + line( + "Allow public room directory over federation", + &self.allow_public_room_directory_over_federation.to_string(), + ); + line( + "Allow public room directory without authentication", + &self.allow_public_room_directory_without_auth.to_string(), + ); + line( + "Lockdown public room directory (only allow admins to publish)", + &self.lockdown_public_room_directory.to_string(), + ); + line( + "JWT secret", + match self.jwt_secret { + Some(_) => "set", + None => "not set", + }, + ); + line( + "Trusted key servers", + &self + .trusted_servers + .iter() + .map(|server| server.host()) + .join(", "), + ); + line( + "Query Trusted Key Servers First", + &self.query_trusted_key_servers_first.to_string(), + ); + line("OpenID Token TTL", &self.openid_token_ttl.to_string()); + line( + "TURN username", + if self.turn_username.is_empty() { + "not set" + } else { + &self.turn_username + }, + ); + line("TURN password", { + if self.turn_password.is_empty() { + "not set" + } else { + "set" + } + }); + line("TURN secret", { + if self.turn_secret.is_empty() { + "not set" + } else { + "set" + } + }); + line("Turn TTL", &self.turn_ttl.to_string()); + line("Turn URIs", { + let mut lst = vec![]; + for item in self.turn_uris.iter().cloned().enumerate() { + let (_, uri): (usize, String) = item; + lst.push(uri); + } + &lst.join(", ") + }); + line("Auto Join Rooms", { + let mut lst = vec![]; + for room in &self.auto_join_rooms { + lst.push(room); + } + &lst.into_iter().join(", ") + }); + line("Zstd HTTP Compression", &self.zstd_compression.to_string()); + line("Gzip HTTP Compression", &self.gzip_compression.to_string()); + line("Brotli HTTP Compression", &self.brotli_compression.to_string()); + line("RocksDB database LOG level", &self.rocksdb_log_level); + line("RocksDB database LOG to stderr", &self.rocksdb_log_stderr.to_string()); + line("RocksDB database LOG time-to-roll", &self.rocksdb_log_time_to_roll.to_string()); + line("RocksDB Max LOG Files", &self.rocksdb_max_log_files.to_string()); + line( + "RocksDB database max LOG file size", + &self.rocksdb_max_log_file_size.to_string(), + ); + line( + "RocksDB database optimize for spinning disks", + &self.rocksdb_optimize_for_spinning_disks.to_string(), + ); + line("RocksDB Direct-IO", &self.rocksdb_direct_io.to_string()); + line("RocksDB Parallelism Threads", &self.rocksdb_parallelism_threads.to_string()); + line("RocksDB Compression Algorithm", &self.rocksdb_compression_algo); + line("RocksDB Compression Level", &self.rocksdb_compression_level.to_string()); + line( + "RocksDB Bottommost Compression Level", + &self.rocksdb_bottommost_compression_level.to_string(), + ); + line( + "RocksDB Bottommost Level Compression", + &self.rocksdb_bottommost_compression.to_string(), + ); + line("RocksDB Recovery Mode", &self.rocksdb_recovery_mode.to_string()); + line("RocksDB Repair Mode", &self.rocksdb_repair.to_string()); + line("RocksDB Read-only Mode", &self.rocksdb_read_only.to_string()); + line( + "RocksDB Compaction Idle Priority", + &self.rocksdb_compaction_prio_idle.to_string(), + ); + line( + "RocksDB Compaction Idle IOPriority", + &self.rocksdb_compaction_ioprio_idle.to_string(), + ); + line("Media integrity checks on startup", &self.media_startup_check.to_string()); + line("Media compatibility filesystem links", &self.media_compat_file_link.to_string()); + line("Prevent Media Downloads From", { + let mut lst = vec![]; + for domain in &self.prevent_media_downloads_from { + lst.push(domain.host()); + } + &lst.join(", ") + }); + line("Forbidden Remote Server Names (\"Global\" ACLs)", { + let mut lst = vec![]; + for domain in &self.forbidden_remote_server_names { + lst.push(domain.host()); + } + &lst.join(", ") + }); + line("Forbidden Remote Room Directory Server Names", { + let mut lst = vec![]; + for domain in &self.forbidden_remote_room_directory_server_names { + lst.push(domain.host()); + } + &lst.join(", ") + }); + line("Outbound Request IP Range Denylist", { + let mut lst = vec![]; + for item in self.ip_range_denylist.iter().cloned().enumerate() { + let (_, ip): (usize, String) = item; + lst.push(ip); + } + &lst.join(", ") + }); + line("Forbidden usernames", { + &self.forbidden_usernames.patterns().iter().join(", ") + }); + line("Forbidden room aliases", { + &self.forbidden_alias_names.patterns().iter().join(", ") + }); + line( + "URL preview domain contains allowlist", + &self.url_preview_domain_contains_allowlist.join(", "), + ); + line( + "URL preview domain explicit allowlist", + &self.url_preview_domain_explicit_allowlist.join(", "), + ); + line( + "URL preview domain explicit denylist", + &self.url_preview_domain_explicit_denylist.join(", "), + ); + line( + "URL preview URL contains allowlist", + &self.url_preview_url_contains_allowlist.join(", "), + ); + line("URL preview maximum spider size", &self.url_preview_max_spider_size.to_string()); + line("URL preview check root domain", &self.url_preview_check_root_domain.to_string()); + line( + "Allow check for updates / announcements check", + &self.allow_check_for_updates.to_string(), + ); + line("Enable netburst on startup", &self.startup_netburst.to_string()); + #[cfg(feature = "sentry_telemetry")] + line("Sentry.io reporting and tracing", &self.sentry.to_string()); + #[cfg(feature = "sentry_telemetry")] + line("Sentry.io send server_name in logs", &self.sentry_send_server_name.to_string()); + #[cfg(feature = "sentry_telemetry")] + line("Sentry.io tracing sample rate", &self.sentry_traces_sample_rate.to_string()); + line("Sentry.io attach stacktrace", &self.sentry_attach_stacktrace.to_string()); + line("Sentry.io send panics", &self.sentry_send_panic.to_string()); + line("Sentry.io send errors", &self.sentry_send_error.to_string()); + line("Sentry.io tracing filter", &self.sentry_filter); + line( + "Well-known server name", + self.well_known + .server + .as_ref() + .map_or("", |server| server.as_str()), + ); + line( + "Well-known client URL", + self.well_known + .client + .as_ref() + .map_or("", |url| url.as_str()), + ); + line( + "Well-known support email", + self.well_known + .support_email + .as_ref() + .map_or("", |str| str.as_ref()), + ); + line( + "Well-known support Matrix ID", + self.well_known + .support_mxid + .as_ref() + .map_or("", |mxid| mxid.as_str()), + ); + line( + "Well-known support role", + self.well_known + .support_role + .as_ref() + .map_or("", |role| role.as_str()), + ); + line( + "Well-known support page/URL", + self.well_known + .support_page + .as_ref() + .map_or("", |url| url.as_str()), + ); + line("Enable the tokio-console", &self.tokio_console.to_string()); + + Ok(()) } } @@ -968,10 +917,20 @@ fn default_pusher_idle_timeout() -> u64 { 15 } fn default_max_fetch_prev_events() -> u16 { 100_u16 } -#[cfg(feature = "perf_measurements")] -fn default_tracing_flame_filter() -> String { "trace,h2=off".to_owned() } +fn default_tracing_flame_filter() -> String { + cfg!(debug_assertions) + .then_some("trace,h2=off") + .unwrap_or("info") + .to_owned() +} + +fn default_jaeger_filter() -> String { + cfg!(debug_assertions) + .then_some("trace,h2=off") + .unwrap_or("info") + .to_owned() +} -#[cfg(feature = "perf_measurements")] fn default_tracing_flame_output_path() -> String { "./tracing.folded".to_owned() } fn default_trusted_servers() -> Vec<OwnedServerName> { vec![OwnedServerName::try_from("matrix.org").unwrap()] } @@ -1070,4 +1029,6 @@ fn default_sentry_endpoint() -> Option<Url> { fn default_sentry_traces_sample_rate() -> f32 { 0.15 } +fn default_sentry_filter() -> String { "info".to_owned() } + fn default_startup_netburst_keep() -> i64 { 50 } diff --git a/src/core/debug.rs b/src/core/debug.rs index 1f855e520b2262a3b7fb0df49d7789664776bec9..14d0be87a0c33f10320dbaeec33d8cabbbc2b321 100644 --- a/src/core/debug.rs +++ b/src/core/debug.rs @@ -1,6 +1,4 @@ -#![allow(dead_code)] // this is a developer's toolbox - -use std::panic; +use std::{any::Any, panic}; /// Export all of the ancillary tools from here as well. pub use crate::utils::debug::*; @@ -79,3 +77,6 @@ pub fn trap() { std::arch::asm!("int3"); } } + +#[must_use] +pub fn panic_str(p: &Box<dyn Any + Send>) -> &'static str { p.downcast_ref::<&str>().copied().unwrap_or_default() } diff --git a/src/core/error.rs b/src/core/error.rs deleted file mode 100644 index 1959081a32aa40339b43b5947b09c2eb35911065..0000000000000000000000000000000000000000 --- a/src/core/error.rs +++ /dev/null @@ -1,247 +0,0 @@ -use std::{convert::Infallible, fmt}; - -use bytes::BytesMut; -use http::StatusCode; -use http_body_util::Full; -use ruma::{ - api::{client::uiaa::UiaaResponse, OutgoingResponse}, - OwnedServerName, -}; - -use crate::{debug_error, error}; - -#[derive(thiserror::Error)] -pub enum Error { - // std - #[error("{0}")] - Fmt(#[from] fmt::Error), - #[error("I/O error: {0}")] - Io(#[from] std::io::Error), - #[error("{0}")] - Utf8Error(#[from] std::str::Utf8Error), - #[error("{0}")] - FromUtf8Error(#[from] std::string::FromUtf8Error), - #[error("{0}")] - TryFromSliceError(#[from] std::array::TryFromSliceError), - #[error("{0}")] - TryFromIntError(#[from] std::num::TryFromIntError), - #[error("{0}")] - ParseIntError(#[from] std::num::ParseIntError), - #[error("{0}")] - ParseFloatError(#[from] std::num::ParseFloatError), - - // third-party - #[error("Regex error: {0}")] - Regex(#[from] regex::Error), - #[error("Tracing filter error: {0}")] - TracingFilter(#[from] tracing_subscriber::filter::ParseError), - #[error("Image error: {0}")] - Image(#[from] image::error::ImageError), - #[error("Request error: {0}")] - Reqwest(#[from] reqwest::Error), - #[error("{0}")] - Extension(#[from] axum::extract::rejection::ExtensionRejection), - #[error("{0}")] - Path(#[from] axum::extract::rejection::PathRejection), - #[error("{0}")] - Http(#[from] http::Error), - - // ruma - #[error("{0}")] - IntoHttpError(#[from] ruma::api::error::IntoHttpError), - #[error("{0}")] - RumaError(#[from] ruma::api::client::error::Error), - #[error("uiaa")] - Uiaa(ruma::api::client::uiaa::UiaaInfo), - #[error("{0}")] - Mxid(#[from] ruma::IdParseError), - #[error("{0}: {1}")] - BadRequest(ruma::api::client::error::ErrorKind, &'static str), - #[error("from {0}: {1}")] - Redaction(OwnedServerName, ruma::canonical_json::RedactionError), - #[error("Remote server {0} responded with: {1}")] - Federation(OwnedServerName, ruma::api::client::error::Error), - #[error("{0} in {1}")] - InconsistentRoomState(&'static str, ruma::OwnedRoomId), - - // conduwuit - #[error("Arithmetic operation failed: {0}")] - Arithmetic(&'static str), - #[error("There was a problem with your configuration: {0}")] - BadConfig(String), - #[error("{0}")] - BadDatabase(&'static str), - #[error("{0}")] - Database(String), - #[error("{0}")] - BadServerResponse(&'static str), - #[error("{0}")] - Conflict(&'static str), // This is only needed for when a room alias already exists - - // unique / untyped - #[error("{0}")] - Err(String), -} - -impl Error { - pub fn bad_database(message: &'static str) -> Self { - error!("BadDatabase: {}", message); - Self::BadDatabase(message) - } - - pub fn bad_config(message: &str) -> Self { - error!("BadConfig: {}", message); - Self::BadConfig(message.to_owned()) - } - - /// Returns the Matrix error code / error kind - #[inline] - pub fn error_code(&self) -> ruma::api::client::error::ErrorKind { - use ruma::api::client::error::ErrorKind::Unknown; - - match self { - Self::Federation(_, error) => ruma_error_kind(error).clone(), - Self::BadRequest(kind, _) => kind.clone(), - _ => Unknown, - } - } - - /// Sanitizes public-facing errors that can leak sensitive information. - pub fn sanitized_error(&self) -> String { - match self { - Self::Database(..) => String::from("Database error occurred."), - Self::Io(..) => String::from("I/O error occurred."), - _ => self.to_string(), - } - } -} - -#[inline] -pub fn log(e: &Error) { - error!(?e); -} - -#[inline] -pub fn debug_log(e: &Error) { - debug_error!(?e); -} - -#[inline] -pub fn into_log(e: Error) { - error!(?e); - drop(e); -} - -#[inline] -pub fn into_debug_log(e: Error) { - debug_error!(?e); - drop(e); -} - -impl fmt::Debug for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{self}") } -} - -impl From<Infallible> for Error { - fn from(i: Infallible) -> Self { match i {} } -} - -impl axum::response::IntoResponse for Error { - fn into_response(self) -> axum::response::Response { - let response: UiaaResponse = self.into(); - response - .try_into_http_response::<BytesMut>() - .inspect_err(|e| error!(?e)) - .map_or_else( - |_| StatusCode::INTERNAL_SERVER_ERROR.into_response(), - |r| r.map(BytesMut::freeze).map(Full::new).into_response(), - ) - } -} - -impl From<Error> for UiaaResponse { - fn from(error: Error) -> Self { - if let Error::Uiaa(uiaainfo) = error { - return Self::AuthResponse(uiaainfo); - } - - let kind = match &error { - Error::Federation(_, ref error) | Error::RumaError(ref error) => ruma_error_kind(error), - Error::BadRequest(kind, _) => kind, - _ => &ruma::api::client::error::ErrorKind::Unknown, - }; - - let status_code = match &error { - Error::Federation(_, ref error) | Error::RumaError(ref error) => error.status_code, - Error::BadRequest(ref kind, _) => bad_request_code(kind), - Error::Conflict(_) => StatusCode::CONFLICT, - _ => StatusCode::INTERNAL_SERVER_ERROR, - }; - - let message = match &error { - Error::Federation(ref origin, ref error) => format!("Answer from {origin}: {error}"), - Error::RumaError(ref error) => ruma_error_message(error), - _ => format!("{error}"), - }; - - let body = ruma::api::client::error::ErrorBody::Standard { - kind: kind.clone(), - message, - }; - - Self::MatrixError(ruma::api::client::error::Error { - status_code, - body, - }) - } -} - -fn bad_request_code(kind: &ruma::api::client::error::ErrorKind) -> StatusCode { - use ruma::api::client::error::ErrorKind::*; - - match kind { - GuestAccessForbidden - | ThreepidAuthFailed - | UserDeactivated - | ThreepidDenied - | WrongRoomKeysVersion { - .. - } - | Forbidden { - .. - } => StatusCode::FORBIDDEN, - - UnknownToken { - .. - } - | MissingToken - | Unauthorized => StatusCode::UNAUTHORIZED, - - LimitExceeded { - .. - } => StatusCode::TOO_MANY_REQUESTS, - - TooLarge => StatusCode::PAYLOAD_TOO_LARGE, - - NotFound | Unrecognized => StatusCode::NOT_FOUND, - - _ => StatusCode::BAD_REQUEST, - } -} - -fn ruma_error_message(error: &ruma::api::client::error::Error) -> String { - if let ruma::api::client::error::ErrorBody::Standard { - message, - .. - } = &error.body - { - return message.to_string(); - } - - format!("{error}") -} - -fn ruma_error_kind(e: &ruma::api::client::error::Error) -> &ruma::api::client::error::ErrorKind { - e.error_kind() - .unwrap_or(&ruma::api::client::error::ErrorKind::Unknown) -} diff --git a/src/core/error/err.rs b/src/core/error/err.rs new file mode 100644 index 0000000000000000000000000000000000000000..ea596644d7d93c587c4868598fcf186c1d66d5f4 --- /dev/null +++ b/src/core/error/err.rs @@ -0,0 +1,95 @@ +//! Error construction macros +//! +//! These are specialized macros specific to this project's patterns for +//! throwing Errors; they make Error construction succinct and reduce clutter. +//! They are developed from folding existing patterns into the macro while +//! fixing several anti-patterns in the codebase. +//! +//! - The primary macros `Err!` and `err!` are provided. `Err!` simply wraps +//! `err!` in the Result variant to reduce `Err(err!(...))` boilerplate, thus +//! `err!` can be used in any case. +//! +//! 1. The macro makes the general Error construction easy: `return +//! Err!("something went wrong")` replaces the prior `return +//! Err(Error::Err("something went wrong".to_owned()))`. +//! +//! 2. The macro integrates format strings automatically: `return +//! Err!("something bad: {msg}")` replaces the prior `return +//! Err(Error::Err(format!("something bad: {msg}")))`. +//! +//! 3. The macro scopes variants of Error: `return Err!(Database("problem with +//! bad database."))` replaces the prior `return Err(Error::Database("problem +//! with bad database."))`. +//! +//! 4. The macro matches and scopes some special-case sub-variants, for example +//! with ruma ErrorKind: `return Err!(Request(MissingToken("you must provide +//! an access token")))`. +//! +//! 5. The macro fixes the anti-pattern of repeating messages in an error! log +//! and then again in an Error construction, often slightly different due to +//! the Error variant not supporting a format string. Instead `return +//! Err(Database(error!("problem with db: {msg}")))` logs the error at the +//! callsite and then returns the error with the same string. Caller has the +//! option of replacing `error!` with `debug_error!`. +#[macro_export] +macro_rules! Err { + ($($args:tt)*) => { + Err($crate::err!($($args)*)) + }; +} + +#[macro_export] +macro_rules! err { + (Config($item:literal, $($args:expr),*)) => {{ + $crate::error!(config = %$item, $($args),*); + $crate::error::Error::Config($item, $crate::format_maybe!($($args),*)) + }}; + + (Request(Forbidden($level:ident!($($args:expr),*)))) => {{ + $crate::$level!($($args),*); + $crate::error::Error::Request( + ::ruma::api::client::error::ErrorKind::forbidden(), + $crate::format_maybe!($($args),*) + ) + }}; + + (Request(Forbidden($($args:expr),*))) => { + $crate::error::Error::Request( + ::ruma::api::client::error::ErrorKind::forbidden(), + $crate::format_maybe!($($args),*) + ) + }; + + (Request($variant:ident($level:ident!($($args:expr),*)))) => {{ + $crate::$level!($($args),*); + $crate::error::Error::Request( + ::ruma::api::client::error::ErrorKind::$variant, + $crate::format_maybe!($($args),*) + ) + }}; + + (Request($variant:ident($($args:expr),*))) => { + $crate::error::Error::Request( + ::ruma::api::client::error::ErrorKind::$variant, + $crate::format_maybe!($($args),*) + ) + }; + + ($variant:ident($level:ident!($($args:expr),*))) => {{ + $crate::$level!($($args),*); + $crate::error::Error::$variant($crate::format_maybe!($($args),*)) + }}; + + ($variant:ident($($args:expr),*)) => { + $crate::error::Error::$variant($crate::format_maybe!($($args),*)) + }; + + ($level:ident!($($args:expr),*)) => {{ + $crate::$level!($($args),*); + $crate::error::Error::Err($crate::format_maybe!($($args),*)) + }}; + + ($($args:expr),*) => { + $crate::error::Error::Err($crate::format_maybe!($($args),*)) + }; +} diff --git a/src/core/error/log.rs b/src/core/error/log.rs new file mode 100644 index 0000000000000000000000000000000000000000..c272bf730c42ede2dd18d18e313546bdb18bf241 --- /dev/null +++ b/src/core/error/log.rs @@ -0,0 +1,74 @@ +use std::{convert::Infallible, fmt}; + +use super::Error; +use crate::{debug_error, error}; + +#[inline] +pub fn else_log<T, E>(error: E) -> Result<T, Infallible> +where + T: Default, + Error: From<E>, +{ + Ok(default_log(error)) +} + +#[inline] +pub fn else_debug_log<T, E>(error: E) -> Result<T, Infallible> +where + T: Default, + Error: From<E>, +{ + Ok(default_debug_log(error)) +} + +#[inline] +pub fn default_log<T, E>(error: E) -> T +where + T: Default, + Error: From<E>, +{ + let error = Error::from(error); + inspect_log(&error); + T::default() +} + +#[inline] +pub fn default_debug_log<T, E>(error: E) -> T +where + T: Default, + Error: From<E>, +{ + let error = Error::from(error); + inspect_debug_log(&error); + T::default() +} + +#[inline] +pub fn map_log<E>(error: E) -> Error +where + Error: From<E>, +{ + let error = Error::from(error); + inspect_log(&error); + error +} + +#[inline] +pub fn map_debug_log<E>(error: E) -> Error +where + Error: From<E>, +{ + let error = Error::from(error); + inspect_debug_log(&error); + error +} + +#[inline] +pub fn inspect_log<E: fmt::Display>(error: &E) { + error!("{error}"); +} + +#[inline] +pub fn inspect_debug_log<E: fmt::Debug>(error: &E) { + debug_error!("{error:?}"); +} diff --git a/src/core/error/mod.rs b/src/core/error/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..069fe60e4bacbcdcdbfc9e39ea748731eda4170f --- /dev/null +++ b/src/core/error/mod.rs @@ -0,0 +1,155 @@ +mod err; +mod log; +mod panic; +mod response; + +use std::{any::Any, borrow::Cow, convert::Infallible, fmt}; + +pub use log::*; + +use crate::error; + +#[derive(thiserror::Error)] +pub enum Error { + #[error("PANIC!")] + PanicAny(Box<dyn Any + Send>), + #[error("PANIC! {0}")] + Panic(&'static str, Box<dyn Any + Send + 'static>), + + // std + #[error("{0}")] + Fmt(#[from] fmt::Error), + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + #[error("{0}")] + Utf8Error(#[from] std::str::Utf8Error), + #[error("{0}")] + FromUtf8Error(#[from] std::string::FromUtf8Error), + #[error("{0}")] + TryFromSliceError(#[from] std::array::TryFromSliceError), + #[error("{0}")] + TryFromIntError(#[from] std::num::TryFromIntError), + #[error("{0}")] + ParseIntError(#[from] std::num::ParseIntError), + #[error("{0}")] + ParseFloatError(#[from] std::num::ParseFloatError), + + // third-party + #[error("Join error: {0}")] + JoinError(#[from] tokio::task::JoinError), + #[error("Regex error: {0}")] + Regex(#[from] regex::Error), + #[error("Tracing filter error: {0}")] + TracingFilter(#[from] tracing_subscriber::filter::ParseError), + #[error("Tracing reload error: {0}")] + TracingReload(#[from] tracing_subscriber::reload::Error), + #[error("Image error: {0}")] + Image(#[from] image::error::ImageError), + #[error("Request error: {0}")] + Reqwest(#[from] reqwest::Error), + #[error("{0}")] + Extension(#[from] axum::extract::rejection::ExtensionRejection), + #[error("{0}")] + Path(#[from] axum::extract::rejection::PathRejection), + #[error("{0}")] + Http(#[from] http::Error), + #[error("{0}")] + HttpHeader(#[from] http::header::InvalidHeaderValue), + + // ruma + #[error("{0}")] + IntoHttpError(#[from] ruma::api::error::IntoHttpError), + #[error("{0}")] + RumaError(#[from] ruma::api::client::error::Error), + #[error("uiaa")] + Uiaa(ruma::api::client::uiaa::UiaaInfo), + #[error("{0}")] + Mxid(#[from] ruma::IdParseError), + #[error("{0}: {1}")] + BadRequest(ruma::api::client::error::ErrorKind, &'static str), //TODO: remove + #[error("{0}: {1}")] + Request(ruma::api::client::error::ErrorKind, Cow<'static, str>), + #[error("from {0}: {1}")] + Redaction(ruma::OwnedServerName, ruma::canonical_json::RedactionError), + #[error("Remote server {0} responded with: {1}")] + Federation(ruma::OwnedServerName, ruma::api::client::error::Error), + #[error("{0} in {1}")] + InconsistentRoomState(&'static str, ruma::OwnedRoomId), + + // conduwuit + #[error("Arithmetic operation failed: {0}")] + Arithmetic(&'static str), + #[error("There was a problem with the '{0}' directive in your configuration: {1}")] + Config(&'static str, Cow<'static, str>), + #[error("{0}")] + Database(Cow<'static, str>), + #[error("{0}")] + BadServerResponse(&'static str), + #[error("{0}")] + Conflict(&'static str), // This is only needed for when a room alias already exists + + // unique / untyped + #[error("{0}")] + Err(Cow<'static, str>), +} + +impl Error { + pub fn bad_database(message: &'static str) -> Self { crate::err!(Database(error!("{message}"))) } + + /// Sanitizes public-facing errors that can leak sensitive information. + pub fn sanitized_string(&self) -> String { + match self { + Self::Database(..) => String::from("Database error occurred."), + Self::Io(..) => String::from("I/O error occurred."), + _ => self.to_string(), + } + } + + pub fn message(&self) -> String { + match self { + Self::Federation(ref origin, ref error) => format!("Answer from {origin}: {error}"), + Self::RumaError(ref error) => response::ruma_error_message(error), + _ => format!("{self}"), + } + } + + /// Returns the Matrix error code / error kind + #[inline] + pub fn kind(&self) -> ruma::api::client::error::ErrorKind { + use ruma::api::client::error::ErrorKind::Unknown; + + match self { + Self::Federation(_, error) => response::ruma_error_kind(error).clone(), + Self::BadRequest(kind, _) | Self::Request(kind, _) => kind.clone(), + _ => Unknown, + } + } + + pub fn status_code(&self) -> http::StatusCode { + match self { + Self::Federation(_, ref error) | Self::RumaError(ref error) => error.status_code, + Self::BadRequest(ref kind, _) | Self::Request(ref kind, _) => response::bad_request_code(kind), + Self::Conflict(_) => http::StatusCode::CONFLICT, + _ => http::StatusCode::INTERNAL_SERVER_ERROR, + } + } +} + +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{self}") } +} + +#[allow(clippy::fallible_impl_from)] +impl From<Infallible> for Error { + #[cold] + #[inline(never)] + fn from(_e: Infallible) -> Self { + panic!("infallible error should never exist"); + } +} + +#[cold] +#[inline(never)] +pub fn infallible(_e: &Infallible) { + panic!("infallible error should never exist"); +} diff --git a/src/core/error/panic.rs b/src/core/error/panic.rs new file mode 100644 index 0000000000000000000000000000000000000000..c070f78669ddbd484a5c8835e2bc345073dfe634 --- /dev/null +++ b/src/core/error/panic.rs @@ -0,0 +1,41 @@ +use std::{ + any::Any, + panic::{panic_any, RefUnwindSafe, UnwindSafe}, +}; + +use super::Error; +use crate::debug; + +impl UnwindSafe for Error {} +impl RefUnwindSafe for Error {} + +impl Error { + pub fn panic(self) -> ! { panic_any(self.into_panic()) } + + #[must_use] + pub fn from_panic(e: Box<dyn Any + Send>) -> Self { Self::Panic(debug::panic_str(&e), e) } + + pub fn into_panic(self) -> Box<dyn Any + Send + 'static> { + match self { + Self::Panic(_, e) | Self::PanicAny(e) => e, + Self::JoinError(e) => e.into_panic(), + _ => Box::new(self), + } + } + + /// Get the panic message string. + pub fn panic_str(self) -> Option<&'static str> { + self.is_panic() + .then_some(debug::panic_str(&self.into_panic())) + } + + /// Check if the Error is trafficking a panic object. + #[inline] + pub fn is_panic(&self) -> bool { + match &self { + Self::Panic(..) | Self::PanicAny(..) => true, + Self::JoinError(e) => e.is_panic(), + _ => false, + } + } +} diff --git a/src/core/error/response.rs b/src/core/error/response.rs new file mode 100644 index 0000000000000000000000000000000000000000..4ea76e264072786255307af24c832afb8165491d --- /dev/null +++ b/src/core/error/response.rs @@ -0,0 +1,88 @@ +use bytes::BytesMut; +use http::StatusCode; +use http_body_util::Full; +use ruma::api::{client::uiaa::UiaaResponse, OutgoingResponse}; + +use super::Error; +use crate::error; + +impl axum::response::IntoResponse for Error { + fn into_response(self) -> axum::response::Response { + let response: UiaaResponse = self.into(); + response + .try_into_http_response::<BytesMut>() + .inspect_err(|e| error!("error response error: {e}")) + .map_or_else( + |_| StatusCode::INTERNAL_SERVER_ERROR.into_response(), + |r| r.map(BytesMut::freeze).map(Full::new).into_response(), + ) + } +} + +impl From<Error> for UiaaResponse { + fn from(error: Error) -> Self { + if let Error::Uiaa(uiaainfo) = error { + return Self::AuthResponse(uiaainfo); + } + + let body = ruma::api::client::error::ErrorBody::Standard { + kind: error.kind(), + message: error.message(), + }; + + Self::MatrixError(ruma::api::client::error::Error { + status_code: error.status_code(), + body, + }) + } +} + +pub(super) fn bad_request_code(kind: &ruma::api::client::error::ErrorKind) -> StatusCode { + use ruma::api::client::error::ErrorKind::*; + + match kind { + GuestAccessForbidden + | ThreepidAuthFailed + | UserDeactivated + | ThreepidDenied + | WrongRoomKeysVersion { + .. + } + | Forbidden { + .. + } => StatusCode::FORBIDDEN, + + UnknownToken { + .. + } + | MissingToken + | Unauthorized => StatusCode::UNAUTHORIZED, + + LimitExceeded { + .. + } => StatusCode::TOO_MANY_REQUESTS, + + TooLarge => StatusCode::PAYLOAD_TOO_LARGE, + + NotFound | Unrecognized => StatusCode::NOT_FOUND, + + _ => StatusCode::BAD_REQUEST, + } +} + +pub(super) fn ruma_error_message(error: &ruma::api::client::error::Error) -> String { + if let ruma::api::client::error::ErrorBody::Standard { + message, + .. + } = &error.body + { + return message.to_string(); + } + + format!("{error}") +} + +pub(super) fn ruma_error_kind(e: &ruma::api::client::error::Error) -> &ruma::api::client::error::ErrorKind { + e.error_kind() + .unwrap_or(&ruma::api::client::error::ErrorKind::Unknown) +} diff --git a/src/core/log/reload.rs b/src/core/log/reload.rs index 7646254edb996666121e41241067de938517c900..6d6510653b4c61f2803e6de588c35b92c918eee4 100644 --- a/src/core/log/reload.rs +++ b/src/core/log/reload.rs @@ -1,7 +1,12 @@ -use std::sync::Arc; +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; use tracing_subscriber::{reload, EnvFilter}; +use crate::{error, Result}; + /// We need to store a reload::Handle value, but can't name it's type explicitly /// because the S type parameter depends on the subscriber's previous layers. In /// our case, this includes unnameable 'impl Trait' types. @@ -17,39 +22,60 @@ /// /// [1]: <https://github.com/tokio-rs/tracing/pull/1035/commits/8a87ea52425098d3ef8f56d92358c2f6c144a28f> pub trait ReloadHandle<L> { + fn current(&self) -> Option<L>; + fn reload(&self, new_value: L) -> Result<(), reload::Error>; } -impl<L, S> ReloadHandle<L> for reload::Handle<L, S> { - fn reload(&self, new_value: L) -> Result<(), reload::Error> { Self::reload(self, new_value) } -} +impl<L: Clone, S> ReloadHandle<L> for reload::Handle<L, S> { + fn current(&self) -> Option<L> { Self::clone_current(self) } -struct LogLevelReloadHandlesInner { - handles: Vec<Box<dyn ReloadHandle<EnvFilter> + Send + Sync>>, + fn reload(&self, new_value: L) -> Result<(), reload::Error> { Self::reload(self, new_value) } } -/// Wrapper to allow reloading the filter on several several -/// [`tracing_subscriber::reload::Handle`]s at once, with the same value. #[derive(Clone)] pub struct LogLevelReloadHandles { - inner: Arc<LogLevelReloadHandlesInner>, + handles: Arc<Mutex<HandleMap>>, } +type HandleMap = HashMap<String, Handle>; +type Handle = Box<dyn ReloadHandle<EnvFilter> + Send + Sync>; + impl LogLevelReloadHandles { - #[must_use] - pub fn new(handles: Vec<Box<dyn ReloadHandle<EnvFilter> + Send + Sync>>) -> Self { - Self { - inner: Arc::new(LogLevelReloadHandlesInner { - handles, - }), - } + pub fn add(&self, name: &str, handle: Handle) { + self.handles + .lock() + .expect("locked") + .insert(name.into(), handle); } - pub fn reload(&self, new_value: &EnvFilter) -> Result<(), reload::Error> { - for handle in &self.inner.handles { - handle.reload(new_value.clone())?; - } + pub fn reload(&self, new_value: &EnvFilter, names: Option<&[&str]>) -> Result<()> { + self.handles + .lock() + .expect("locked") + .iter() + .filter(|(name, _)| names.map_or(false, |names| names.contains(&name.as_str()))) + .for_each(|(_, handle)| { + _ = handle.reload(new_value.clone()).or_else(error::else_log); + }); Ok(()) } + + #[must_use] + pub fn current(&self, name: &str) -> Option<EnvFilter> { + self.handles + .lock() + .expect("locked") + .get(name) + .map(|handle| handle.current())? + } +} + +impl Default for LogLevelReloadHandles { + fn default() -> Self { + Self { + handles: Arc::new(HandleMap::new().into()), + } + } } diff --git a/src/core/log/suppress.rs b/src/core/log/suppress.rs index 6e883086dd2638a440627e8521f7396d93cf2da5..b13ee99ef2c28795e04af3de1f35cb2418472bb9 100644 --- a/src/core/log/suppress.rs +++ b/src/core/log/suppress.rs @@ -10,16 +10,21 @@ pub struct Suppress { impl Suppress { pub fn new(server: &Arc<Server>) -> Self { + let handle = "console"; let config = &server.config.log; - Self::from_filters(server, EnvFilter::try_new(config).unwrap_or_default(), &EnvFilter::default()) - } + let suppress = EnvFilter::default(); + let restore = server + .log + .reload + .current(handle) + .unwrap_or_else(|| EnvFilter::try_new(config).unwrap_or_default()); - fn from_filters(server: &Arc<Server>, restore: EnvFilter, suppress: &EnvFilter) -> Self { server .log .reload - .reload(suppress) + .reload(&suppress, Some(&[handle])) .expect("log filter reloaded"); + Self { server: server.clone(), restore, @@ -32,7 +37,7 @@ fn drop(&mut self) { self.server .log .reload - .reload(&self.restore) + .reload(&self.restore, Some(&["console"])) .expect("log filter reloaded"); } } diff --git a/src/core/metrics/mod.rs b/src/core/metrics/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..3ae139a8694e53bf9e6226d924ee35fe11bbc54b --- /dev/null +++ b/src/core/metrics/mod.rs @@ -0,0 +1,72 @@ +use std::sync::atomic::AtomicU32; + +use tokio::runtime; +use tokio_metrics::TaskMonitor; +#[cfg(tokio_unstable)] +use tokio_metrics::{RuntimeIntervals, RuntimeMonitor}; + +pub struct Metrics { + _runtime: Option<runtime::Handle>, + + runtime_metrics: Option<runtime::RuntimeMetrics>, + + task_monitor: Option<TaskMonitor>, + + #[cfg(tokio_unstable)] + _runtime_monitor: Option<RuntimeMonitor>, + + #[cfg(tokio_unstable)] + runtime_intervals: std::sync::Mutex<Option<RuntimeIntervals>>, + + // TODO: move stats + pub requests_spawn_active: AtomicU32, + pub requests_spawn_finished: AtomicU32, + pub requests_handle_active: AtomicU32, + pub requests_handle_finished: AtomicU32, + pub requests_panic: AtomicU32, +} + +impl Metrics { + #[must_use] + pub fn new(runtime: Option<runtime::Handle>) -> Self { + #[cfg(tokio_unstable)] + let runtime_monitor = runtime.as_ref().map(RuntimeMonitor::new); + + #[cfg(tokio_unstable)] + let runtime_intervals = runtime_monitor.as_ref().map(RuntimeMonitor::intervals); + + Self { + _runtime: runtime.clone(), + + runtime_metrics: runtime.as_ref().map(runtime::Handle::metrics), + + task_monitor: runtime.map(|_| TaskMonitor::new()), + + #[cfg(tokio_unstable)] + _runtime_monitor: runtime_monitor, + + #[cfg(tokio_unstable)] + runtime_intervals: std::sync::Mutex::new(runtime_intervals), + + requests_spawn_active: AtomicU32::new(0), + requests_spawn_finished: AtomicU32::new(0), + requests_handle_active: AtomicU32::new(0), + requests_handle_finished: AtomicU32::new(0), + requests_panic: AtomicU32::new(0), + } + } + + #[cfg(tokio_unstable)] + pub fn runtime_interval(&self) -> Option<tokio_metrics::RuntimeMetrics> { + self.runtime_intervals + .lock() + .expect("locked") + .as_mut() + .map(Iterator::next) + .expect("next interval") + } + + pub fn task_root(&self) -> Option<&TaskMonitor> { self.task_monitor.as_ref() } + + pub fn runtime_metrics(&self) -> Option<&runtime::RuntimeMetrics> { self.runtime_metrics.as_ref() } +} diff --git a/src/core/mod.rs b/src/core/mod.rs index de8057fadcd355aea6ee92cd833ac63a516c0007..9716b46e87463f7457c85feeddcd541ed54d487b 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -3,6 +3,7 @@ pub mod debug; pub mod error; pub mod log; +pub mod metrics; pub mod mods; pub mod pdu; pub mod server; diff --git a/src/core/server.rs b/src/core/server.rs index 575924d3b86984361fc0097027cd2f86f580af45..bf0ab99d12e4885ac2e659148e3f8e6c6a309243 100644 --- a/src/core/server.rs +++ b/src/core/server.rs @@ -1,11 +1,11 @@ use std::{ - sync::atomic::{AtomicBool, AtomicU32, Ordering}, + sync::atomic::{AtomicBool, Ordering}, time::SystemTime, }; use tokio::{runtime, sync::broadcast}; -use crate::{config::Config, log, Error, Result}; +use crate::{config::Config, log::Log, metrics::Metrics, Err, Result}; /// Server runtime state; public portion pub struct Server { @@ -33,47 +33,39 @@ pub struct Server { pub signal: broadcast::Sender<&'static str>, /// Logging subsystem state - pub log: log::Log, - - /// TODO: move stats - pub requests_spawn_active: AtomicU32, - pub requests_spawn_finished: AtomicU32, - pub requests_handle_active: AtomicU32, - pub requests_handle_finished: AtomicU32, - pub requests_panic: AtomicU32, + pub log: Log, + + /// Metrics subsystem state + pub metrics: Metrics, } impl Server { #[must_use] - pub fn new(config: Config, runtime: Option<runtime::Handle>, log: log::Log) -> Self { + pub fn new(config: Config, runtime: Option<runtime::Handle>, log: Log) -> Self { Self { config, started: SystemTime::now(), stopping: AtomicBool::new(false), reloading: AtomicBool::new(false), restarting: AtomicBool::new(false), - runtime, + runtime: runtime.clone(), signal: broadcast::channel::<&'static str>(1).0, log, - requests_spawn_active: AtomicU32::new(0), - requests_spawn_finished: AtomicU32::new(0), - requests_handle_active: AtomicU32::new(0), - requests_handle_finished: AtomicU32::new(0), - requests_panic: AtomicU32::new(0), + metrics: Metrics::new(runtime), } } pub fn reload(&self) -> Result<()> { if cfg!(not(conduit_mods)) { - return Err(Error::Err("Reloading not enabled".into())); + return Err!("Reloading not enabled"); } if self.reloading.swap(true, Ordering::AcqRel) { - return Err(Error::Err("Reloading already in progress".into())); + return Err!("Reloading already in progress"); } if self.stopping.swap(true, Ordering::AcqRel) { - return Err(Error::Err("Shutdown already in progress".into())); + return Err!("Shutdown already in progress"); } self.signal("SIGINT").inspect_err(|_| { @@ -84,7 +76,7 @@ pub fn reload(&self) -> Result<()> { pub fn restart(&self) -> Result<()> { if self.restarting.swap(true, Ordering::AcqRel) { - return Err(Error::Err("Restart already in progress".into())); + return Err!("Restart already in progress"); } self.shutdown() @@ -93,7 +85,7 @@ pub fn restart(&self) -> Result<()> { pub fn shutdown(&self) -> Result<()> { if self.stopping.swap(true, Ordering::AcqRel) { - return Err(Error::Err("Shutdown already in progress".into())); + return Err!("Shutdown already in progress"); } self.signal("SIGTERM") @@ -102,7 +94,7 @@ pub fn shutdown(&self) -> Result<()> { pub fn signal(&self, sig: &'static str) -> Result<()> { if let Err(e) = self.signal.send(sig) { - return Err(Error::Err(format!("Failed to send signal: {e}"))); + return Err!("Failed to send signal: {e}"); } Ok(()) diff --git a/src/core/utils/hash/argon.rs b/src/core/utils/hash/argon.rs index 98cef00e37b3dcb2e3172383ff6b13a7a0e4ace7..0a1e1e14fe44b626fdac0d9733d04626daca8a84 100644 --- a/src/core/utils/hash/argon.rs +++ b/src/core/utils/hash/argon.rs @@ -5,7 +5,7 @@ PasswordVerifier, Version, }; -use crate::{Error, Result}; +use crate::{err, Error, Result}; const M_COST: u32 = Params::DEFAULT_M_COST; // memory size in 1 KiB blocks const T_COST: u32 = Params::DEFAULT_T_COST; // nr of iterations @@ -44,7 +44,7 @@ pub(super) fn verify_password(password: &str, password_hash: &str) -> Result<()> .map_err(map_err) } -fn map_err(e: password_hash::Error) -> Error { Error::Err(e.to_string()) } +fn map_err(e: password_hash::Error) -> Error { err!("{e}") } #[cfg(test)] mod tests { diff --git a/src/core/utils/mod.rs b/src/core/utils/mod.rs index 2b79c3c4288c5a9fb4139ebd0201665fffc489ce..bbd528290e4c982b07a03011ea00482473f4283e 100644 --- a/src/core/utils/mod.rs +++ b/src/core/utils/mod.rs @@ -20,27 +20,14 @@ pub use hash::calculate_hash; pub use html::Escape as HtmlEscape; pub use json::{deserialize_from_str, to_canonical_object}; -pub use mutex_map::MutexMap; +pub use mutex_map::{Guard as MutexMapGuard, MutexMap}; pub use rand::string as random_string; pub use string::{str_from_bytes, string_from_bytes}; pub use sys::available_parallelism; pub use time::now_millis as millis_since_unix_epoch; -use crate::Result; - pub fn clamp<T: Ord>(val: T, min: T, max: T) -> T { cmp::min(cmp::max(val, min), max) } -/// Boilerplate for wraps which are typed to never error. -/// -/// * <https://doc.rust-lang.org/std/convert/enum.Infallible.html> -#[must_use] -pub fn unwrap_infallible<T>(result: Result<T, std::convert::Infallible>) -> T { - match result { - Ok(val) => val, - Err(err) => match err {}, - } -} - #[must_use] pub fn generate_keypair() -> Vec<u8> { let mut value = rand::string(8).as_bytes().to_vec(); diff --git a/src/core/utils/mutex_map.rs b/src/core/utils/mutex_map.rs index f102487ccc2e51a8dc557f3495955ebd520d5d1e..c3c51798c0e5e3b9bf6c5f072f1f7c089cfb2c4e 100644 --- a/src/core/utils/mutex_map.rs +++ b/src/core/utils/mutex_map.rs @@ -1,20 +1,22 @@ -use std::{hash::Hash, sync::Arc}; +use std::{fmt::Debug, hash::Hash, sync::Arc}; -type Value<Val> = tokio::sync::Mutex<Val>; -type ArcMutex<Val> = Arc<Value<Val>>; -type HashMap<Key, Val> = std::collections::HashMap<Key, ArcMutex<Val>>; -type MapMutex<Key, Val> = std::sync::Mutex<HashMap<Key, Val>>; -type Map<Key, Val> = MapMutex<Key, Val>; +use tokio::sync::OwnedMutexGuard as Omg; /// Map of Mutexes pub struct MutexMap<Key, Val> { map: Map<Key, Val>, } -pub struct Guard<Val> { - _guard: tokio::sync::OwnedMutexGuard<Val>, +pub struct Guard<Key, Val> { + map: Map<Key, Val>, + val: Omg<Val>, } +type Map<Key, Val> = Arc<MapMutex<Key, Val>>; +type MapMutex<Key, Val> = std::sync::Mutex<HashMap<Key, Val>>; +type HashMap<Key, Val> = std::collections::HashMap<Key, Value<Val>>; +type Value<Val> = Arc<tokio::sync::Mutex<Val>>; + impl<Key, Val> MutexMap<Key, Val> where Key: Send + Hash + Eq + Clone, @@ -23,28 +25,38 @@ impl<Key, Val> MutexMap<Key, Val> #[must_use] pub fn new() -> Self { Self { - map: Map::<Key, Val>::new(HashMap::<Key, Val>::new()), + map: Map::new(MapMutex::new(HashMap::new())), } } - pub async fn lock<K>(&self, k: &K) -> Guard<Val> + #[tracing::instrument(skip(self), level = "debug")] + pub async fn lock<K>(&self, k: &K) -> Guard<Key, Val> where - K: ?Sized + Send + Sync, + K: ?Sized + Send + Sync + Debug, Key: for<'a> From<&'a K>, { let val = self .map .lock() - .expect("map mutex locked") + .expect("locked") .entry(k.into()) .or_default() .clone(); - let guard = val.lock_owned().await; - Guard::<Val> { - _guard: guard, + Guard::<Key, Val> { + map: Arc::clone(&self.map), + val: val.lock_owned().await, } } + + #[must_use] + pub fn contains(&self, k: &Key) -> bool { self.map.lock().expect("locked").contains_key(k) } + + #[must_use] + pub fn is_empty(&self) -> bool { self.map.lock().expect("locked").is_empty() } + + #[must_use] + pub fn len(&self) -> usize { self.map.lock().expect("locked").len() } } impl<Key, Val> Default for MutexMap<Key, Val> @@ -54,3 +66,14 @@ impl<Key, Val> Default for MutexMap<Key, Val> { fn default() -> Self { Self::new() } } + +impl<Key, Val> Drop for Guard<Key, Val> { + fn drop(&mut self) { + if Arc::strong_count(Omg::mutex(&self.val)) <= 2 { + self.map + .lock() + .expect("locked") + .retain(|_, val| !Arc::ptr_eq(val, Omg::mutex(&self.val)) || Arc::strong_count(val) > 2); + } + } +} diff --git a/src/core/utils/string.rs b/src/core/utils/string.rs index ec423d53b3027f80451b95d92f7dcb586a220f4b..106d0cb77d387631366b2eb5d9a9db26d109099d 100644 --- a/src/core/utils/string.rs +++ b/src/core/utils/string.rs @@ -2,6 +2,30 @@ pub const EMPTY: &str = ""; +/// Constant expression to bypass format! if the argument is a string literal +/// but not a format string. If the literal is a format string then String is +/// returned otherwise the input (i.e. &'static str) is returned. If multiple +/// arguments are provided the first is assumed to be a format string. +#[macro_export] +macro_rules! format_maybe { + ($s:literal) => { + if $crate::is_format!($s) { std::format!($s).into() } else { $s.into() } + }; + + ($($args:expr),*) => { + std::format!($($args),*).into() + }; +} + +/// Constant expression to decide if a literal is a format string. Note: could +/// use some improvement. +#[macro_export] +macro_rules! is_format { + ($s:literal) => { + ::const_str::contains!($s, "{") && ::const_str::contains!($s, "}") + }; +} + /// Find the common prefix from a collection of strings and return a slice /// ``` /// use conduit_core::utils::string::common_prefix; diff --git a/src/core/utils/tests.rs b/src/core/utils/tests.rs index add15861fdb782635e1d0f369c255a729a97fcb4..4396894704a909fad8e3e02a7891b32147565e3f 100644 --- a/src/core/utils/tests.rs +++ b/src/core/utils/tests.rs @@ -81,3 +81,56 @@ fn checked_add_overflow() { let res = checked!(a + 1).expect("overflow"); assert_eq!(res, 0); } + +#[tokio::test] +async fn mutex_map_cleanup() { + use crate::utils::MutexMap; + + let map = MutexMap::<String, ()>::new(); + + let lock = map.lock("foo").await; + assert!(!map.is_empty(), "map must not be empty"); + + drop(lock); + assert!(map.is_empty(), "map must be empty"); +} + +#[tokio::test] +async fn mutex_map_contend() { + use std::sync::Arc; + + use tokio::sync::Barrier; + + use crate::utils::MutexMap; + + let map = Arc::new(MutexMap::<String, ()>::new()); + let seq = Arc::new([Barrier::new(2), Barrier::new(2)]); + let str = "foo".to_owned(); + + let seq_ = seq.clone(); + let map_ = map.clone(); + let str_ = str.clone(); + let join_a = tokio::spawn(async move { + let _lock = map_.lock(&str_).await; + assert!(!map_.is_empty(), "A0 must not be empty"); + seq_[0].wait().await; + assert!(map_.contains(&str_), "A1 must contain key"); + }); + + let seq_ = seq.clone(); + let map_ = map.clone(); + let str_ = str.clone(); + let join_b = tokio::spawn(async move { + let _lock = map_.lock(&str_).await; + assert!(!map_.is_empty(), "B0 must not be empty"); + seq_[1].wait().await; + assert!(map_.contains(&str_), "B1 must contain key"); + }); + + seq[0].wait().await; + assert!(map.contains(&str), "Must contain key"); + seq[1].wait().await; + + tokio::try_join!(join_b, join_a).expect("joined"); + assert!(map.is_empty(), "Must be empty"); +} diff --git a/src/database/Cargo.toml b/src/database/Cargo.toml index 8b0d3fc30a6512a2fcc2aa44d41eefd1fbd220cd..34d98416dad0b00de130647a2a25df7cad3d70bc 100644 --- a/src/database/Cargo.toml +++ b/src/database/Cargo.toml @@ -36,6 +36,7 @@ zstd_compression = [ [dependencies] conduit-core.workspace = true +const-str.workspace = true log.workspace = true rust-rocksdb.workspace = true tokio.workspace = true diff --git a/src/database/util.rs b/src/database/util.rs index 513cedc8f02913534fe8b13afdb64025baa3730d..f0ccbcbee045f793cddfa37ab92333ca9ded3ddd 100644 --- a/src/database/util.rs +++ b/src/database/util.rs @@ -1,4 +1,4 @@ -use conduit::Result; +use conduit::{err, Result}; #[inline] pub(crate) fn result<T>(r: std::result::Result<T, rocksdb::Error>) -> Result<T, conduit::Error> { @@ -10,4 +10,7 @@ pub(crate) fn and_then<T>(t: T) -> Result<T, conduit::Error> { Ok(t) } pub(crate) fn or_else<T>(e: rocksdb::Error) -> Result<T, conduit::Error> { Err(map_err(e)) } -pub(crate) fn map_err(e: rocksdb::Error) -> conduit::Error { conduit::Error::Database(e.into_string()) } +pub(crate) fn map_err(e: rocksdb::Error) -> conduit::Error { + let string = e.into_string(); + err!(Database(error!("{string}"))) +} diff --git a/src/main/Cargo.toml b/src/main/Cargo.toml index fa0e5874fa01f97f899424656ac1d3cb2452307b..8dc2a34dc314d066c109c5d0521cf938bb010e22 100644 --- a/src/main/Cargo.toml +++ b/src/main/Cargo.toml @@ -147,6 +147,7 @@ log.workspace = true tracing.workspace = true tracing-subscriber.workspace = true clap.workspace = true +const-str.workspace = true opentelemetry.workspace = true opentelemetry.optional = true diff --git a/src/main/main.rs b/src/main/main.rs index 23f53b4eeee8bb32bec86dff23d5d85b58737d77..959e861001f37dfff68f33bbe5e09000876086ea 100644 --- a/src/main/main.rs +++ b/src/main/main.rs @@ -20,7 +20,7 @@ const WORKER_NAME: &str = "conduwuit:worker"; const WORKER_MIN: usize = 2; -const WORKER_KEEPALIVE_MS: u64 = 2500; +const WORKER_KEEPALIVE: u64 = 36; fn main() -> Result<(), Error> { let args = clap::parse(); @@ -29,7 +29,7 @@ fn main() -> Result<(), Error> { .enable_time() .thread_name(WORKER_NAME) .worker_threads(cmp::max(WORKER_MIN, available_parallelism())) - .thread_keep_alive(Duration::from_millis(WORKER_KEEPALIVE_MS)) + .thread_keep_alive(Duration::from_secs(WORKER_KEEPALIVE)) .build() .expect("built runtime"); diff --git a/src/main/sentry.rs b/src/main/sentry.rs index 6ed4bb8a8a4feb4a703d0f1d09e5114a6010fa7d..04ad8654f982da1e564292421420d8327221d35f 100644 --- a/src/main/sentry.rs +++ b/src/main/sentry.rs @@ -1,18 +1,34 @@ #![cfg(feature = "sentry_telemetry")] -use std::{str::FromStr, sync::Arc}; +use std::{ + str::FromStr, + sync::{Arc, OnceLock}, +}; -use conduit::{config::Config, trace}; +use conduit::{config::Config, debug, trace}; use sentry::{ - types::{protocol::v7::Event, Dsn}, - Breadcrumb, ClientOptions, + types::{ + protocol::v7::{Context, Event}, + Dsn, + }, + Breadcrumb, ClientOptions, Level, }; +static SEND_PANIC: OnceLock<bool> = OnceLock::new(); +static SEND_ERROR: OnceLock<bool> = OnceLock::new(); + pub(crate) fn init(config: &Config) -> Option<sentry::ClientInitGuard> { config.sentry.then(|| sentry::init(options(config))) } fn options(config: &Config) -> ClientOptions { + SEND_PANIC + .set(config.sentry_send_panic) + .expect("SEND_PANIC was not previously set"); + SEND_ERROR + .set(config.sentry_send_error) + .expect("SEND_ERROR was not previously set"); + let dsn = config .sentry_endpoint .as_ref() @@ -28,6 +44,7 @@ fn options(config: &Config) -> ClientOptions { debug: cfg!(debug_assertions), release: sentry::release_name!(), user_agent: conduit::version::user_agent().into(), + attach_stacktrace: config.sentry_attach_stacktrace, before_send: Some(Arc::new(before_send)), before_breadcrumb: Some(Arc::new(before_breadcrumb)), ..Default::default() @@ -35,11 +52,40 @@ fn options(config: &Config) -> ClientOptions { } fn before_send(event: Event<'static>) -> Option<Event<'static>> { - trace!("Sending sentry event: {event:?}"); + if event.exception.iter().any(|e| e.ty == "panic") && !SEND_PANIC.get().unwrap_or(&true) { + return None; + } + + if event.level == Level::Error { + if !SEND_ERROR.get().unwrap_or(&true) { + return None; + } + + if cfg!(debug_assertions) { + return None; + } + + //NOTE: we can enable this to specify error!(sentry = true, ...) + if let Some(Context::Other(context)) = event.contexts.get("Rust Tracing Fields") { + if !context.contains_key("sentry") { + //return None; + } + } + } + + if event.level == Level::Fatal { + trace!("{event:#?}"); + } + + debug!("Sending sentry event: {event:?}"); Some(event) } fn before_breadcrumb(crumb: Breadcrumb) -> Option<Breadcrumb> { - trace!("Adding sentry breadcrumb: {crumb:?}"); + if crumb.ty == "log" && crumb.level == Level::Debug { + return None; + } + + trace!("Sentry breadcrumb: {crumb:?}"); Some(crumb) } diff --git a/src/main/server.rs b/src/main/server.rs index f72b3ef34ba15c1d8c2813e4f6b98c96bcb630de..73c06f0cae721de0a686969d9b7f6ed46aeab541 100644 --- a/src/main/server.rs +++ b/src/main/server.rs @@ -27,7 +27,7 @@ pub(crate) fn build(args: Args, runtime: Option<&runtime::Handle>) -> Result<Arc #[cfg(feature = "sentry_telemetry")] let sentry_guard = crate::sentry::init(&config); - let (tracing_reload_handle, tracing_flame_guard, capture) = crate::tracing::init(&config); + let (tracing_reload_handle, tracing_flame_guard, capture) = crate::tracing::init(&config)?; config.check()?; diff --git a/src/main/tracing.rs b/src/main/tracing.rs index bbfe4dc4245d6a01493590bb0c6ac93e6e214368..0217f38adc78202acbeaf6b213f788f1bf3356bc 100644 --- a/src/main/tracing.rs +++ b/src/main/tracing.rs @@ -1,10 +1,10 @@ use std::sync::Arc; use conduit::{ - config, config::Config, - debug_warn, - log::{capture, LogLevelReloadHandles, ReloadHandle}, + debug_warn, err, + log::{capture, LogLevelReloadHandles}, + Result, }; use tracing_subscriber::{layer::SubscriberExt, reload, EnvFilter, Layer, Registry}; @@ -14,47 +14,38 @@ pub(crate) type TracingFlameGuard = (); #[allow(clippy::redundant_clone)] -pub(crate) fn init(config: &Config) -> (LogLevelReloadHandles, TracingFlameGuard, Arc<capture::State>) { - let fmt_layer = tracing_subscriber::fmt::Layer::new(); - let filter_layer = match EnvFilter::try_new(&config.log) { - Ok(s) => s, - Err(e) => { - eprintln!("It looks like your config is invalid. The following error occured while parsing it: {e}"); - EnvFilter::try_new(config::default_log()).expect("failed to set default EnvFilter") - }, - }; - - let mut reload_handles = Vec::<Box<dyn ReloadHandle<EnvFilter> + Send + Sync>>::new(); - let (fmt_reload_filter, fmt_reload_handle) = reload::Layer::new(filter_layer.clone()); - reload_handles.push(Box::new(fmt_reload_handle)); +pub(crate) fn init(config: &Config) -> Result<(LogLevelReloadHandles, TracingFlameGuard, Arc<capture::State>)> { + let reload_handles = LogLevelReloadHandles::default(); - let subscriber = Registry::default().with(fmt_layer.with_filter(fmt_reload_filter)); + let console_filter = EnvFilter::try_new(&config.log).map_err(|e| err!(Config("log", "{e}.")))?; + let console_layer = tracing_subscriber::fmt::Layer::new(); + let (console_reload_filter, console_reload_handle) = reload::Layer::new(console_filter.clone()); + reload_handles.add("console", Box::new(console_reload_handle)); let cap_state = Arc::new(capture::State::new()); let cap_layer = capture::Layer::new(&cap_state); - let subscriber = subscriber.with(cap_layer); + + let subscriber = Registry::default() + .with(console_layer.with_filter(console_reload_filter)) + .with(cap_layer); #[cfg(feature = "sentry_telemetry")] let subscriber = { + let sentry_filter = + EnvFilter::try_new(&config.sentry_filter).map_err(|e| err!(Config("sentry_filter", "{e}.")))?; let sentry_layer = sentry_tracing::layer(); - let (sentry_reload_filter, sentry_reload_handle) = reload::Layer::new(filter_layer.clone()); - reload_handles.push(Box::new(sentry_reload_handle)); + let (sentry_reload_filter, sentry_reload_handle) = reload::Layer::new(sentry_filter); + reload_handles.add("sentry", Box::new(sentry_reload_handle)); subscriber.with(sentry_layer.with_filter(sentry_reload_filter)) }; #[cfg(feature = "perf_measurements")] let (subscriber, flame_guard) = { let (flame_layer, flame_guard) = if config.tracing_flame { - let flame_filter = match EnvFilter::try_new(&config.tracing_flame_filter) { - Ok(flame_filter) => flame_filter, - Err(e) => panic!("tracing_flame_filter config value is invalid: {e}"), - }; - - let (flame_layer, flame_guard) = - match tracing_flame::FlameLayer::with_file(&config.tracing_flame_output_path) { - Ok(ok) => ok, - Err(e) => panic!("failed to initialize tracing-flame: {e}"), - }; + let flame_filter = EnvFilter::try_new(&config.tracing_flame_filter) + .map_err(|e| err!(Config("tracing_flame_filter", "{e}.")))?; + let (flame_layer, flame_guard) = tracing_flame::FlameLayer::with_file(&config.tracing_flame_output_path) + .map_err(|e| err!(Config("tracing_flame_output_path", "{e}.")))?; let flame_layer = flame_layer .with_empty_samples(false) .with_filter(flame_filter); @@ -63,21 +54,20 @@ pub(crate) fn init(config: &Config) -> (LogLevelReloadHandles, TracingFlameGuard (None, None) }; - let jaeger_layer = if config.allow_jaeger { + let jaeger_filter = + EnvFilter::try_new(&config.jaeger_filter).map_err(|e| err!(Config("jaeger_filter", "{e}.")))?; + let jaeger_layer = config.allow_jaeger.then(|| { opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new()); let tracer = opentelemetry_jaeger::new_agent_pipeline() .with_auto_split_batch(true) .with_service_name("conduwuit") .install_batch(opentelemetry_sdk::runtime::Tokio) - .unwrap(); + .expect("jaeger agent pipeline"); let telemetry = tracing_opentelemetry::layer().with_tracer(tracer); - - let (jaeger_reload_filter, jaeger_reload_handle) = reload::Layer::new(filter_layer.clone()); - reload_handles.push(Box::new(jaeger_reload_handle)); + let (jaeger_reload_filter, jaeger_reload_handle) = reload::Layer::new(jaeger_filter.clone()); + reload_handles.add("jaeger", Box::new(jaeger_reload_handle)); Some(telemetry.with_filter(jaeger_reload_filter)) - } else { - None - }; + }); let subscriber = subscriber.with(flame_layer).with(jaeger_layer); (subscriber, flame_guard) @@ -87,7 +77,7 @@ pub(crate) fn init(config: &Config) -> (LogLevelReloadHandles, TracingFlameGuard #[cfg_attr(not(feature = "perf_measurements"), allow(clippy::let_unit_value))] let flame_guard = (); - let ret = (LogLevelReloadHandles::new(reload_handles), flame_guard, cap_state); + let ret = (reload_handles, flame_guard, cap_state); // Enable the tokio console. This is slightly kludgy because we're judggling // compile-time and runtime conditions to elide it, each of those changing the @@ -100,7 +90,7 @@ pub(crate) fn init(config: &Config) -> (LogLevelReloadHandles, TracingFlameGuard .spawn(); set_global_default(subscriber.with(console_layer)); - return ret; + return Ok(ret); } set_global_default(subscriber); @@ -111,7 +101,7 @@ pub(crate) fn init(config: &Config) -> (LogLevelReloadHandles, TracingFlameGuard debug_warn!("{console_disabled_reason}"); } - ret + Ok(ret) } fn tokio_console_enabled(config: &Config) -> (bool, &'static str) { diff --git a/src/router/Cargo.toml b/src/router/Cargo.toml index 5312984acaf2d2bc620a981b695a44abd90098c4..38e6adc7c333a552b0b10c49c341e44633880db6 100644 --- a/src/router/Cargo.toml +++ b/src/router/Cargo.toml @@ -55,6 +55,7 @@ conduit-admin.workspace = true conduit-api.workspace = true conduit-core.workspace = true conduit-service.workspace = true +const-str.workspace = true log.workspace = true tokio.workspace = true tower.workspace = true diff --git a/src/router/layers.rs b/src/router/layers.rs index db664b38aa4779e2109c13a8f20f4ca3cd6b14f1..073940f1696558e969c64a3fda76592704d807f2 100644 --- a/src/router/layers.rs +++ b/src/router/layers.rs @@ -1,11 +1,11 @@ -use std::{any::Any, io, sync::Arc, time::Duration}; +use std::{any::Any, sync::Arc, time::Duration}; use axum::{ extract::{DefaultBodyLimit, MatchedPath}, Router, }; use axum_client_ip::SecureClientIpSource; -use conduit::Server; +use conduit::{Result, Server}; use http::{ header::{self, HeaderName}, HeaderValue, Method, StatusCode, @@ -22,11 +22,19 @@ use crate::{request, router}; -const CONDUWUIT_CSP: &str = "sandbox; default-src 'none'; font-src 'none'; script-src 'none'; frame-ancestors 'none'; \ - form-action 'none'; base-uri 'none';"; -const CONDUWUIT_PERMISSIONS_POLICY: &str = "interest-cohort=(),browsing-topics=()"; +const CONDUWUIT_CSP: &[&str] = &[ + "sandbox", + "default-src 'none'", + "font-src 'none'", + "script-src 'none'", + "frame-ancestors 'none'", + "form-action 'none'", + "base-uri 'none'", +]; -pub(crate) fn build(server: &Arc<Server>) -> io::Result<Router> { +const CONDUWUIT_PERMISSIONS_POLICY: &[&str] = &["interest-cohort=()", "browsing-topics=()"]; + +pub(crate) fn build(server: &Arc<Server>) -> Result<Router> { let layers = ServiceBuilder::new(); #[cfg(feature = "sentry_telemetry")] @@ -65,11 +73,11 @@ pub(crate) fn build(server: &Arc<Server>) -> io::Result<Router> { )) .layer(SetResponseHeaderLayer::if_not_present( HeaderName::from_static("permissions-policy"), - HeaderValue::from_static(CONDUWUIT_PERMISSIONS_POLICY), + HeaderValue::from_str(&CONDUWUIT_PERMISSIONS_POLICY.join(","))?, )) .layer(SetResponseHeaderLayer::if_not_present( header::CONTENT_SECURITY_POLICY, - HeaderValue::from_static(CONDUWUIT_CSP), + HeaderValue::from_str(&CONDUWUIT_CSP.join("; "))?, )) .layer(cors_layer(server)) .layer(body_limit_layer(server)) @@ -145,6 +153,7 @@ fn body_limit_layer(server: &Server) -> DefaultBodyLimit { DefaultBodyLimit::max fn catch_panic(err: Box<dyn Any + Send + 'static>) -> http::Response<http_body_util::Full<bytes::Bytes>> { conduit_service::services() .server + .metrics .requests_panic .fetch_add(1, std::sync::atomic::Ordering::Release); diff --git a/src/router/request.rs b/src/router/request.rs index 9256fb9c98e7db340a4f8f7a5c740595248192bb..ae7399844ee88b18c67d606c444e8b289f74e521 100644 --- a/src/router/request.rs +++ b/src/router/request.rs @@ -1,6 +1,9 @@ use std::sync::{atomic::Ordering, Arc}; -use axum::{extract::State, response::IntoResponse}; +use axum::{ + extract::State, + response::{IntoResponse, Response}, +}; use conduit::{debug, debug_error, debug_warn, defer, error, trace, Error, Result, Server}; use http::{Method, StatusCode, Uri}; use ruma::api::client::error::{Error as RumaError, ErrorBody, ErrorKind}; @@ -8,17 +11,20 @@ #[tracing::instrument(skip_all, level = "debug")] pub(crate) async fn spawn( State(server): State<Arc<Server>>, req: http::Request<axum::body::Body>, next: axum::middleware::Next, -) -> Result<axum::response::Response, StatusCode> { +) -> Result<Response, StatusCode> { if !server.running() { debug_warn!("unavailable pending shutdown"); return Err(StatusCode::SERVICE_UNAVAILABLE); } - let active = server.requests_spawn_active.fetch_add(1, Ordering::Relaxed); + let active = server + .metrics + .requests_spawn_active + .fetch_add(1, Ordering::Relaxed); trace!(active, "enter"); defer! {{ - let active = server.requests_spawn_active.fetch_sub(1, Ordering::Relaxed); - let finished = server.requests_spawn_finished.fetch_add(1, Ordering::Relaxed); + let active = server.metrics.requests_spawn_active.fetch_sub(1, Ordering::Relaxed); + let finished = server.metrics.requests_spawn_finished.fetch_add(1, Ordering::Relaxed); trace!(active, finished, "leave"); }}; @@ -30,7 +36,7 @@ pub(crate) async fn spawn( #[tracing::instrument(skip_all, name = "handle")] pub(crate) async fn handle( State(server): State<Arc<Server>>, req: http::Request<axum::body::Body>, next: axum::middleware::Next, -) -> Result<axum::response::Response, StatusCode> { +) -> Result<Response, StatusCode> { if !server.running() { debug_warn!( method = %req.method(), @@ -42,12 +48,13 @@ pub(crate) async fn handle( } let active = server + .metrics .requests_handle_active .fetch_add(1, Ordering::Relaxed); trace!(active, "enter"); defer! {{ - let active = server.requests_handle_active.fetch_sub(1, Ordering::Relaxed); - let finished = server.requests_handle_finished.fetch_add(1, Ordering::Relaxed); + let active = server.metrics.requests_handle_active.fetch_sub(1, Ordering::Relaxed); + let finished = server.metrics.requests_handle_finished.fetch_add(1, Ordering::Relaxed); trace!(active, finished, "leave"); }}; @@ -57,9 +64,7 @@ pub(crate) async fn handle( handle_result(&method, &uri, result) } -fn handle_result( - method: &Method, uri: &Uri, result: axum::response::Response, -) -> Result<axum::response::Response, StatusCode> { +fn handle_result(method: &Method, uri: &Uri, result: Response) -> Result<Response, StatusCode> { handle_result_log(method, uri, &result); match result.status() { StatusCode::METHOD_NOT_ALLOWED => handle_result_405(method, uri, &result), @@ -67,9 +72,7 @@ fn handle_result( } } -fn handle_result_405( - _method: &Method, _uri: &Uri, result: &axum::response::Response, -) -> Result<axum::response::Response, StatusCode> { +fn handle_result_405(_method: &Method, _uri: &Uri, result: &Response) -> Result<Response, StatusCode> { let error = Error::RumaError(RumaError { status_code: result.status(), body: ErrorBody::Standard { @@ -81,7 +84,7 @@ fn handle_result_405( Ok(error.into_response()) } -fn handle_result_log(method: &Method, uri: &Uri, result: &axum::response::Response) { +fn handle_result_log(method: &Method, uri: &Uri, result: &Response) { let status = result.status(); let reason = status.canonical_reason().unwrap_or("Unknown Reason"); let code = status.as_u16(); diff --git a/src/router/run.rs b/src/router/run.rs index 3e09823ac56f453cd5a5f430c9a91e4e7aaa43ad..91507772d48edfa2ba57eca2a43e907e142e84f0 100644 --- a/src/router/run.rs +++ b/src/router/run.rs @@ -1,8 +1,10 @@ use std::{sync::Arc, time::Duration}; use axum_server::Handle as ServerHandle; -use tokio::sync::broadcast::{self, Sender}; -use tracing::{debug, error, info}; +use tokio::{ + sync::broadcast::{self, Sender}, + task::JoinHandle, +}; extern crate conduit_admin as admin; extern crate conduit_core as conduit; @@ -10,15 +12,14 @@ use std::sync::atomic::Ordering; -use conduit::{debug_info, trace, Error, Result, Server}; -use service::services; +use conduit::{debug, debug_info, error, info, trace, Error, Result, Server}; -use crate::{layers, serve}; +use crate::serve; /// Main loop base #[tracing::instrument(skip_all)] -pub(crate) async fn run(server: Arc<Server>) -> Result<(), Error> { - let app = layers::build(&server)?; +pub(crate) async fn run(server: Arc<Server>) -> Result<()> { + debug!("Start"); // Install the admin room callback here for now admin::init().await; @@ -30,8 +31,16 @@ pub(crate) async fn run(server: Arc<Server>) -> Result<(), Error> { .runtime() .spawn(signal(server.clone(), tx.clone(), handle.clone())); - // Serve clients - let res = serve::serve(&server, app, handle, tx.subscribe()).await; + let mut listener = server + .runtime() + .spawn(serve::serve(server.clone(), handle.clone(), tx.subscribe())); + + // Focal point + debug!("Running"); + let res = tokio::select! { + res = &mut listener => res.map_err(Error::from).unwrap_or_else(Err), + res = service::services().poll() => handle_services_poll(&server, res, listener).await, + }; // Join the signal handler before we leave. sigs.abort(); @@ -40,17 +49,16 @@ pub(crate) async fn run(server: Arc<Server>) -> Result<(), Error> { // Remove the admin room callback admin::fini().await; - debug_info!("Finished"); + debug_info!("Finish"); res } /// Async initializations #[tracing::instrument(skip_all)] -pub(crate) async fn start(server: Arc<Server>) -> Result<(), Error> { +pub(crate) async fn start(server: Arc<Server>) -> Result<()> { debug!("Starting..."); - service::init(&server).await?; - services().start().await?; + service::start(&server).await?; #[cfg(feature = "systemd")] sd_notify::notify(true, &[sd_notify::NotifyState::Ready]).expect("failed to notify systemd of ready state"); @@ -61,14 +69,12 @@ pub(crate) async fn start(server: Arc<Server>) -> Result<(), Error> { /// Async destructions #[tracing::instrument(skip_all)] -pub(crate) async fn stop(_server: Arc<Server>) -> Result<(), Error> { +pub(crate) async fn stop(_server: Arc<Server>) -> Result<()> { debug!("Shutting down..."); // Wait for all completions before dropping or we'll lose them to the module // unload and explode. - services().stop().await; - // Deactivate services(). Any further use will panic the caller. - service::fini(); + service::stop().await; debug!("Cleaning up..."); @@ -102,7 +108,7 @@ async fn handle_shutdown(server: &Arc<Server>, tx: &Sender<()>, handle: &axum_se error!("failed sending shutdown transaction to channel: {e}"); } - let pending = server.requests_spawn_active.load(Ordering::Relaxed); + let pending = server.metrics.requests_spawn_active.load(Ordering::Relaxed); if pending > 0 { let timeout = Duration::from_secs(36); trace!(pending, ?timeout, "Notifying for graceful shutdown"); @@ -112,3 +118,21 @@ async fn handle_shutdown(server: &Arc<Server>, tx: &Sender<()>, handle: &axum_se handle.shutdown(); } } + +async fn handle_services_poll( + server: &Arc<Server>, result: Result<()>, listener: JoinHandle<Result<()>>, +) -> Result<()> { + debug!("Service manager finished: {result:?}"); + + if server.running() { + if let Err(e) = server.shutdown() { + error!("Failed to send shutdown signal: {e}"); + } + } + + if let Err(e) = listener.await { + error!("Client listener task finished with error: {e}"); + } + + result +} diff --git a/src/router/serve/mod.rs b/src/router/serve/mod.rs index 47f2fd43f6914cee9644f1f6fd8764a7b491bd7b..4e9234443fa5297d0fb12f04c6273a9e9b68e0f4 100644 --- a/src/router/serve/mod.rs +++ b/src/router/serve/mod.rs @@ -4,23 +4,23 @@ use std::sync::Arc; -use axum::Router; use axum_server::Handle as ServerHandle; -use conduit::{Error, Result, Server}; +use conduit::{Result, Server}; use tokio::sync::broadcast; +use crate::layers; + /// Serve clients -pub(super) async fn serve( - server: &Arc<Server>, app: Router, handle: ServerHandle, shutdown: broadcast::Receiver<()>, -) -> Result<(), Error> { +pub(super) async fn serve(server: Arc<Server>, handle: ServerHandle, shutdown: broadcast::Receiver<()>) -> Result<()> { let config = &server.config; let addrs = config.get_bind_addrs(); + let app = layers::build(&server)?; if cfg!(unix) && config.unix_socket_path.is_some() { - unix::serve(server, app, shutdown).await + unix::serve(&server, app, shutdown).await } else if config.tls.is_some() { - tls::serve(server, app, handle, addrs).await + tls::serve(&server, app, handle, addrs).await } else { - plain::serve(server, app, handle, addrs).await + plain::serve(&server, app, handle, addrs).await } } diff --git a/src/router/serve/plain.rs b/src/router/serve/plain.rs index b79d342d902a2bf19ede9cf7d4969cb7ab363a18..08263353bf0adb36cc16a0abe24b3934f6980284 100644 --- a/src/router/serve/plain.rs +++ b/src/router/serve/plain.rs @@ -21,12 +21,21 @@ pub(super) async fn serve( info!("Listening on {addrs:?}"); while join_set.join_next().await.is_some() {} - let spawn_active = server.requests_spawn_active.load(Ordering::Relaxed); - let handle_active = server.requests_handle_active.load(Ordering::Relaxed); + let spawn_active = server.metrics.requests_spawn_active.load(Ordering::Relaxed); + let handle_active = server + .metrics + .requests_handle_active + .load(Ordering::Relaxed); debug_info!( - spawn_finished = server.requests_spawn_finished.load(Ordering::Relaxed), - handle_finished = server.requests_handle_finished.load(Ordering::Relaxed), - panics = server.requests_panic.load(Ordering::Relaxed), + spawn_finished = server + .metrics + .requests_spawn_finished + .load(Ordering::Relaxed), + handle_finished = server + .metrics + .requests_handle_finished + .load(Ordering::Relaxed), + panics = server.metrics.requests_panic.load(Ordering::Relaxed), spawn_active, handle_active, "Stopped listening on {addrs:?}", diff --git a/src/router/serve/unix.rs b/src/router/serve/unix.rs index 8373b74943468c9299f0f7d54a54febc470dcdeb..b5938673c5c69b73466ecd34d8324e5dcf025d05 100644 --- a/src/router/serve/unix.rs +++ b/src/router/serve/unix.rs @@ -3,14 +3,14 @@ use std::{ net::{self, IpAddr, Ipv4Addr}, path::Path, - sync::Arc, + sync::{atomic::Ordering, Arc}, }; use axum::{ extract::{connect_info::IntoMakeServiceWithConnectInfo, Request}, Router, }; -use conduit::{debug_error, trace, utils, Error, Result, Server}; +use conduit::{debug, debug_error, error::infallible, info, trace, warn, Err, Result, Server}; use hyper::{body::Incoming, service::service_fn}; use hyper_util::{ rt::{TokioExecutor, TokioIo}, @@ -21,14 +21,14 @@ net::{unix::SocketAddr, UnixListener, UnixStream}, sync::broadcast::{self}, task::JoinSet, + time::{sleep, Duration}, }; use tower::{Service, ServiceExt}; -use tracing::{debug, info, warn}; -use utils::unwrap_infallible; type MakeService = IntoMakeServiceWithConnectInfo<Router, net::SocketAddr>; -static NULL_ADDR: net::SocketAddr = net::SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0); +const NULL_ADDR: net::SocketAddr = net::SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0); +const FINI_POLL_INTERVAL: Duration = Duration::from_millis(750); #[tracing::instrument(skip_all)] pub(super) async fn serve(server: &Arc<Server>, app: Router, mut shutdown: broadcast::Receiver<()>) -> Result<()> { @@ -49,7 +49,7 @@ pub(super) async fn serve(server: &Arc<Server>, app: Router, mut shutdown: broad } } - fini(listener, tasks).await; + fini(server, listener, tasks).await; Ok(()) } @@ -62,9 +62,14 @@ async fn accept( let socket = TokioIo::new(socket); trace!(?listener, ?socket, ?remote, "accepted"); - let called = unwrap_infallible(app.call(NULL_ADDR).await); - let handler = service_fn(move |req: Request<Incoming>| called.clone().oneshot(req)); + let called = app + .call(NULL_ADDR) + .await + .inspect_err(infallible) + .expect("infallible"); + let service = move |req: Request<Incoming>| called.clone().oneshot(req); + let handler = service_fn(service); let task = async move { // bug on darwin causes all results to be errors. do not unwrap this _ = builder.serve_connection(socket, handler).await; @@ -93,30 +98,41 @@ async fn init(server: &Arc<Server>) -> Result<UnixListener> { let dir = path.parent().unwrap_or_else(|| Path::new("/")); if let Err(e) = fs::create_dir_all(dir).await { - return Err(Error::Err(format!("Failed to create {dir:?} for socket {path:?}: {e}"))); + return Err!("Failed to create {dir:?} for socket {path:?}: {e}"); } let listener = UnixListener::bind(path); if let Err(e) = listener { - return Err(Error::Err(format!("Failed to bind listener {path:?}: {e}"))); + return Err!("Failed to bind listener {path:?}: {e}"); } let socket_perms = config.unix_socket_perms.to_string(); let octal_perms = u32::from_str_radix(&socket_perms, 8).expect("failed to convert octal permissions"); let perms = std::fs::Permissions::from_mode(octal_perms); if let Err(e) = fs::set_permissions(&path, perms).await { - return Err(Error::Err(format!("Failed to set socket {path:?} permissions: {e}"))); + return Err!("Failed to set socket {path:?} permissions: {e}"); } - info!("Listening at {:?}", path); + info!("Listening at {path:?}"); Ok(listener.unwrap()) } -async fn fini(listener: UnixListener, mut tasks: JoinSet<()>) { +async fn fini(server: &Arc<Server>, listener: UnixListener, mut tasks: JoinSet<()>) { let local = listener.local_addr(); + debug!("Closing listener at {local:?} ..."); drop(listener); + + debug!("Waiting for requests to finish..."); + while server.metrics.requests_spawn_active.load(Ordering::Relaxed) > 0 { + tokio::select! { + _ = tasks.join_next() => {} + () = sleep(FINI_POLL_INTERVAL) => {} + } + } + + debug!("Shutting down..."); tasks.shutdown().await; if let Ok(local) = local { diff --git a/src/service/Cargo.toml b/src/service/Cargo.toml index f59a6036749080c821bc11b703e3432950fdd355..d2c9785ff714255c7ad3b09aed33adaf1c37467b 100644 --- a/src/service/Cargo.toml +++ b/src/service/Cargo.toml @@ -42,6 +42,7 @@ base64.workspace = true bytes.workspace = true conduit-core.workspace = true conduit-database.workspace = true +const-str.workspace = true cyborgtime.workspace = true futures-util.workspace = true hickory-resolver.workspace = true diff --git a/src/service/admin/console.rs b/src/service/admin/console.rs index 2f66b1d530fa0f23102ae3e5ef486b4ee7628392..c590b928f992ea7c2e68eba882eec38198c29ca8 100644 --- a/src/service/admin/console.rs +++ b/src/service/admin/console.rs @@ -95,7 +95,10 @@ async fn worker(self: Arc<Self>) { ReadlineEvent::Line(string) => self.clone().handle(string).await, ReadlineEvent::Interrupted => continue, ReadlineEvent::Eof => break, - ReadlineEvent::Quit => services().server.shutdown().unwrap_or_else(error::into_log), + ReadlineEvent::Quit => services() + .server + .shutdown() + .unwrap_or_else(error::default_log), }, Err(error) => match error { ReadlineError::Closed => break, diff --git a/src/service/admin/create.rs b/src/service/admin/create.rs index ad70fe0c568e0fd54b490dac4c20f302eb84a236..fbb6a078d9353e81a9e7db5cbc915c7b27e8c982 100644 --- a/src/service/admin/create.rs +++ b/src/service/admin/create.rs @@ -34,32 +34,27 @@ pub async fn create_admin_room() -> Result<()> { let _short_id = services().rooms.short.get_or_create_shortroomid(&room_id)?; - let state_lock = services().globals.roomid_mutex_state.lock(&room_id).await; + let state_lock = services().rooms.state.mutex.lock(&room_id).await; // Create a user for the server let server_user = &services().globals.server_user; services().users.create(server_user, None)?; let room_version = services().globals.default_room_version(); - let mut content = match room_version { - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 - | RoomVersionId::V8 - | RoomVersionId::V9 - | RoomVersionId::V10 => RoomCreateEventContent::new_v1(server_user.clone()), - RoomVersionId::V11 => RoomCreateEventContent::new_v11(), - _ => { - warn!("Unexpected or unsupported room version {}", room_version); - return Err(Error::BadRequest( - ErrorKind::BadJson, - "Unexpected or unsupported room version found", - )); - }, + + let mut content = { + use RoomVersionId::*; + match room_version { + V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => RoomCreateEventContent::new_v1(server_user.clone()), + V11 => RoomCreateEventContent::new_v11(), + _ => { + warn!("Unexpected or unsupported room version {}", room_version); + return Err(Error::BadRequest( + ErrorKind::BadJson, + "Unexpected or unsupported room version found", + )); + }, + } }; content.federate = true; diff --git a/src/service/admin/grant.rs b/src/service/admin/grant.rs index ca48ce0dbeb82a42e4821889d662b860b2eb13b7..9a4ef242efdbdf28dd31ef122cac819a0e85a856 100644 --- a/src/service/admin/grant.rs +++ b/src/service/admin/grant.rs @@ -22,7 +22,7 @@ /// In conduit, this is equivalent to granting admin privileges. pub async fn make_user_admin(user_id: &UserId, displayname: String) -> Result<()> { if let Some(room_id) = Service::get_admin_room()? { - let state_lock = services().globals.roomid_mutex_state.lock(&room_id).await; + let state_lock = services().rooms.state.mutex.lock(&room_id).await; // Use the server user to grant the new admin's power level let server_user = &services().globals.server_user; diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index ca0e551b984b26f7f8b465f9dc81990c1d7768b8..41019cd19153ae1b53319a30ed6401868a69bb0a 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -9,7 +9,7 @@ }; use async_trait::async_trait; -use conduit::{error, utils::mutex_map, Error, Result}; +use conduit::{debug, error, Error, Result}; pub use create::create_admin_room; pub use grant::make_user_admin; use loole::{Receiver, Sender}; @@ -21,25 +21,13 @@ OwnedEventId, OwnedRoomId, RoomId, UserId, }; use serde_json::value::to_raw_value; -use tokio::{ - sync::{Mutex, RwLock}, - task::JoinHandle, -}; - -use crate::{pdu::PduBuilder, services, user_is_local, PduEvent}; - -const COMMAND_QUEUE_LIMIT: usize = 512; +use tokio::sync::{Mutex, RwLock}; -pub type CommandOutput = Option<RoomMessageEventContent>; -pub type CommandResult = Result<CommandOutput, Error>; -pub type HandlerResult = Pin<Box<dyn Future<Output = CommandResult> + Send>>; -pub type Handler = fn(Command) -> HandlerResult; -pub type Completer = fn(&str) -> String; +use crate::{pdu::PduBuilder, rooms::state::RoomMutexGuard, services, user_is_local, PduEvent}; pub struct Service { sender: Sender<Command>, receiver: Mutex<Receiver<Command>>, - handler_join: Mutex<Option<JoinHandle<()>>>, pub handle: RwLock<Option<Handler>>, pub complete: StdRwLock<Option<Completer>>, #[cfg(feature = "console")] @@ -52,6 +40,14 @@ pub struct Command { pub reply_id: Option<OwnedEventId>, } +pub type Completer = fn(&str) -> String; +pub type Handler = fn(Command) -> HandlerResult; +pub type HandlerResult = Pin<Box<dyn Future<Output = CommandResult> + Send>>; +pub type CommandResult = Result<CommandOutput, Error>; +pub type CommandOutput = Option<RoomMessageEventContent>; + +const COMMAND_QUEUE_LIMIT: usize = 512; + #[async_trait] impl crate::Service for Service { fn build(_args: crate::Args<'_>) -> Result<Arc<Self>> { @@ -59,7 +55,6 @@ fn build(_args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { sender, receiver: Mutex::new(receiver), - handler_join: Mutex::new(None), handle: RwLock::new(None), complete: StdRwLock::new(None), #[cfg(feature = "console")] @@ -67,16 +62,25 @@ fn build(_args: crate::Args<'_>) -> Result<Arc<Self>> { })) } - async fn start(self: Arc<Self>) -> Result<()> { - let self_ = Arc::clone(&self); - let handle = services().server.runtime().spawn(async move { - self_ - .handler() - .await - .expect("Failed to initialize admin room handler"); - }); + async fn worker(self: Arc<Self>) -> Result<()> { + let receiver = self.receiver.lock().await; + let mut signals = services().server.signal.subscribe(); + loop { + tokio::select! { + command = receiver.recv_async() => match command { + Ok(command) => self.handle_command(command).await, + Err(_) => break, + }, + sig = signals.recv() => match sig { + Ok(sig) => self.handle_signal(sig).await, + Err(_) => continue, + }, + } + } - _ = self.handler_join.lock().await.insert(handle); + //TODO: not unwind safe + #[cfg(feature = "console")] + self.console.close().await; Ok(()) } @@ -90,19 +94,6 @@ fn interrupt(&self) { } } - async fn stop(&self) { - self.interrupt(); - - #[cfg(feature = "console")] - self.console.close().await; - - if let Some(handler_join) = self.handler_join.lock().await.take() { - if let Err(e) = handler_join.await { - error!("Failed to shutdown: {e:?}"); - } - } - } - fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } @@ -149,31 +140,16 @@ async fn send(&self, message: Command) { self.sender.send_async(message).await.expect("message sent"); } - async fn handler(self: &Arc<Self>) -> Result<()> { - let receiver = self.receiver.lock().await; - let mut signals = services().server.signal.subscribe(); - loop { - tokio::select! { - command = receiver.recv_async() => match command { - Ok(command) => self.handle_command(command).await, - Err(_) => return Ok(()), - }, - sig = signals.recv() => match sig { - Ok(sig) => self.handle_signal(sig).await, - Err(_) => continue, - }, - } - } - } - async fn handle_signal(&self, #[allow(unused_variables)] sig: &'static str) { #[cfg(feature = "console")] self.console.handle_signal(sig).await; } async fn handle_command(&self, command: Command) { - if let Ok(Some(output)) = self.process_command(command).await { - handle_response(output).await; + match self.process_command(command).await { + Ok(Some(output)) => handle_response(output).await, + Ok(None) => debug!("Command successful with no response"), + Err(e) => error!("Command processing error: {e}"), } } @@ -248,7 +224,7 @@ async fn respond_to_room(content: RoomMessageEventContent, room_id: &RoomId, use "sender is not admin" ); - let state_lock = services().globals.roomid_mutex_state.lock(room_id).await; + let state_lock = services().rooms.state.mutex.lock(room_id).await; let response_pdu = PduBuilder { event_type: TimelineEventType::RoomMessage, content: to_raw_value(&content).expect("event is valid, we just created it"), @@ -263,14 +239,14 @@ async fn respond_to_room(content: RoomMessageEventContent, room_id: &RoomId, use .build_and_append_pdu(response_pdu, user_id, room_id, &state_lock) .await { - if let Err(e) = handle_response_error(&e, room_id, user_id, &state_lock).await { + if let Err(e) = handle_response_error(e, room_id, user_id, &state_lock).await { error!("{e}"); } } } async fn handle_response_error( - e: &Error, room_id: &RoomId, user_id: &UserId, state_lock: &mutex_map::Guard<()>, + e: Error, room_id: &RoomId, user_id: &UserId, state_lock: &RoomMutexGuard, ) -> Result<()> { error!("Failed to build and append admin room response PDU: \"{e}\""); let error_room_message = RoomMessageEventContent::text_plain(format!( diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index f7cf0b6bb789a622901a615c438307fd082a78c5..24c9b8b076737bd5a40903df45641b2e599673b3 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -2,7 +2,7 @@ use std::{collections::BTreeMap, sync::Arc}; -use conduit::Result; +use conduit::{err, Result}; use data::Data; use futures_util::Future; use regex::RegexSet; @@ -171,7 +171,7 @@ pub async fn unregister_appservice(&self, service_name: &str) -> Result<()> { .write() .await .remove(service_name) - .ok_or_else(|| crate::Error::Err("Appservice not found".to_owned()))?; + .ok_or(err!("Appservice not found"))?; // remove the appservice from the database self.db.unregister_appservice(service_name)?; diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 254a3d9cd4fba50cac892068d7242b62b30bc82a..281c2a94a3eca600396fc215974af528fa32ead6 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -14,9 +14,6 @@ use crate::services; -const COUNTER: &[u8] = b"c"; -const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u"; - pub struct Data { global: Arc<Map>, todeviceid_events: Arc<Map>, @@ -35,6 +32,8 @@ pub struct Data { counter: RwLock<u64>, } +const COUNTER: &[u8] = b"c"; + impl Data { pub(super) fn new(db: &Arc<Database>) -> Self { Self { @@ -93,23 +92,6 @@ fn stored_count(global: &Arc<Map>) -> Result<u64> { .map_or(Ok(0_u64), utils::u64_from_bytes) } - pub fn last_check_for_updates_id(&self) -> Result<u64> { - self.global - .get(LAST_CHECK_FOR_UPDATES_COUNT)? - .map_or(Ok(0_u64), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("last check for updates count has invalid bytes.")) - }) - } - - #[inline] - pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> { - self.global - .insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?; - - Ok(()) - } - #[tracing::instrument(skip(self), level = "debug")] pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { let userid_bytes = user_id.as_bytes().to_vec(); diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 11bfc88c393c4975e0128c6da1dd13bcb803ac75..0a0d0d8eb32aadb95211353168e87ebc211a684c 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -1,9 +1,8 @@ mod client; mod data; -pub(super) mod emerg_access; +mod emerg_access; pub(super) mod migrations; pub(crate) mod resolver; -pub(super) mod updates; use std::{ collections::{BTreeMap, HashMap}, @@ -12,23 +11,22 @@ time::Instant, }; -use conduit::{error, trace, utils::MutexMap, Config, Result}; +use async_trait::async_trait; +use conduit::{error, trace, Config, Result}; use data::Data; use ipaddress::IPAddress; use regex::RegexSet; use ruma::{ api::{client::discovery::discover_support::ContactRole, federation::discovery::VerifyKey}, serde::Base64, - DeviceId, OwnedEventId, OwnedRoomAliasId, OwnedRoomId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, - RoomAliasId, RoomVersionId, ServerName, UserId, + DeviceId, OwnedEventId, OwnedRoomAliasId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, RoomAliasId, + RoomVersionId, ServerName, UserId, }; -use tokio::{sync::Mutex, task::JoinHandle}; +use tokio::sync::Mutex; use url::Url; use crate::services; -type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries - pub struct Service { pub db: Data, @@ -43,16 +41,14 @@ pub struct Service { pub bad_event_ratelimiter: Arc<RwLock<HashMap<OwnedEventId, RateLimitState>>>, pub bad_signature_ratelimiter: Arc<RwLock<HashMap<Vec<String>, RateLimitState>>>, pub bad_query_ratelimiter: Arc<RwLock<HashMap<OwnedServerName, RateLimitState>>>, - pub roomid_mutex_insert: MutexMap<OwnedRoomId, ()>, - pub roomid_mutex_state: MutexMap<OwnedRoomId, ()>, - pub roomid_mutex_federation: MutexMap<OwnedRoomId, ()>, - pub roomid_federationhandletime: RwLock<HashMap<OwnedRoomId, (OwnedEventId, Instant)>>, - pub updates_handle: Mutex<Option<JoinHandle<()>>>, pub stateres_mutex: Arc<Mutex<()>>, pub server_user: OwnedUserId, pub admin_alias: OwnedRoomAliasId, } +type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries + +#[async_trait] impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { let config = &args.server.config; @@ -107,11 +103,6 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { bad_event_ratelimiter: Arc::new(RwLock::new(HashMap::new())), bad_signature_ratelimiter: Arc::new(RwLock::new(HashMap::new())), bad_query_ratelimiter: Arc::new(RwLock::new(HashMap::new())), - roomid_mutex_state: MutexMap::<OwnedRoomId, ()>::new(), - roomid_mutex_insert: MutexMap::<OwnedRoomId, ()>::new(), - roomid_mutex_federation: MutexMap::<OwnedRoomId, ()>::new(), - roomid_federationhandletime: RwLock::new(HashMap::new()), - updates_handle: Mutex::new(None), stateres_mutex: Arc::new(Mutex::new(())), admin_alias: RoomAliasId::parse(format!("#admins:{}", &config.server_name)) .expect("#admins:server_name is valid alias name"), @@ -130,6 +121,12 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(s)) } + async fn worker(self: Arc<Self>) -> Result<()> { + emerg_access::init_emergency_access(); + + Ok(()) + } + fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { self.resolver.memory_usage(out)?; @@ -189,12 +186,6 @@ pub fn next_count(&self) -> Result<u64> { self.db.next_count() } #[inline] pub fn current_count(&self) -> Result<u64> { Ok(self.db.current_count()) } - #[tracing::instrument(skip(self), level = "debug")] - pub fn last_check_for_updates_id(&self) -> Result<u64> { self.db.last_check_for_updates_id() } - - #[tracing::instrument(skip(self), level = "debug")] - pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> { self.db.update_check_for_updates_id(id) } - pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { self.db.watch(user_id, device_id).await } diff --git a/src/service/globals/resolver.rs b/src/service/globals/resolver.rs index 3082f2fda3123e3d233a4d00fb0d2c161816e539..3002decf4515f02ac05821c13469747a6ea05e4c 100644 --- a/src/service/globals/resolver.rs +++ b/src/service/globals/resolver.rs @@ -7,7 +7,7 @@ time::Duration, }; -use conduit::{error, Config, Error, Result}; +use conduit::{error, Config, Result}; use hickory_resolver::TokioAsyncResolver; use reqwest::dns::{Addrs, Name, Resolve, Resolving}; use ruma::OwnedServerName; @@ -33,10 +33,7 @@ impl Resolver { #[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)] pub(super) fn new(config: &Config) -> Self { let (sys_conf, mut opts) = hickory_resolver::system_conf::read_system_conf() - .map_err(|e| { - error!("Failed to set up hickory dns resolver with system config: {e}"); - Error::bad_config("Failed to set up hickory dns resolver with system config.") - }) + .inspect_err(|e| error!("Failed to set up hickory dns resolver with system config: {e}")) .expect("DNS system config must be valid"); let mut conf = hickory_resolver::config::ResolverConfig::new(); diff --git a/src/service/globals/updates.rs b/src/service/globals/updates.rs deleted file mode 100644 index c6ac9fffdd05f0b22eacb07dc165f50fc78789e0..0000000000000000000000000000000000000000 --- a/src/service/globals/updates.rs +++ /dev/null @@ -1,76 +0,0 @@ -use std::time::Duration; - -use ruma::events::room::message::RoomMessageEventContent; -use serde::Deserialize; -use tokio::{task::JoinHandle, time::interval}; -use tracing::{error, warn}; - -use crate::{ - conduit::{Error, Result}, - services, -}; - -const CHECK_FOR_UPDATES_URL: &str = "https://pupbrain.dev/check-for-updates/stable"; -const CHECK_FOR_UPDATES_INTERVAL: u64 = 7200; // 2 hours - -#[derive(Deserialize)] -struct CheckForUpdatesResponseEntry { - id: u64, - date: String, - message: String, -} -#[derive(Deserialize)] -struct CheckForUpdatesResponse { - updates: Vec<CheckForUpdatesResponseEntry>, -} - -#[tracing::instrument] -pub fn start_check_for_updates_task() -> JoinHandle<()> { - let timer_interval = Duration::from_secs(CHECK_FOR_UPDATES_INTERVAL); - - services().server.runtime().spawn(async move { - let mut i = interval(timer_interval); - - loop { - i.tick().await; - - if let Err(e) = try_handle_updates().await { - warn!(%e, "Failed to check for updates"); - } - } - }) -} - -#[tracing::instrument(skip_all)] -async fn try_handle_updates() -> Result<()> { - let response = services() - .globals - .client - .default - .get(CHECK_FOR_UPDATES_URL) - .send() - .await?; - - let response = serde_json::from_str::<CheckForUpdatesResponse>(&response.text().await?) - .map_err(|e| Error::Err(format!("Bad check for updates response: {e}")))?; - - let mut last_update_id = services().globals.last_check_for_updates_id()?; - for update in response.updates { - last_update_id = last_update_id.max(update.id); - if update.id > services().globals.last_check_for_updates_id()? { - error!("{}", update.message); - services() - .admin - .send_message(RoomMessageEventContent::text_plain(format!( - "@room: the following is a message from the conduwuit puppy. it was sent on '{}':\n\n{}", - update.date, update.message - ))) - .await; - } - } - services() - .globals - .update_check_for_updates_id(last_update_id)?; - - Ok(()) -} diff --git a/src/service/manager.rs b/src/service/manager.rs new file mode 100644 index 0000000000000000000000000000000000000000..af59b4a432afbeb46aae6b050fa08ba7bb462780 --- /dev/null +++ b/src/service/manager.rs @@ -0,0 +1,156 @@ +use std::{panic::AssertUnwindSafe, sync::Arc, time::Duration}; + +use conduit::{debug, debug_warn, error, trace, utils::time, warn, Err, Error, Result, Server}; +use futures_util::FutureExt; +use tokio::{ + sync::{Mutex, MutexGuard}, + task::{JoinHandle, JoinSet}, + time::sleep, +}; + +use crate::{service::Service, Services}; + +pub(crate) struct Manager { + manager: Mutex<Option<JoinHandle<Result<()>>>>, + workers: Mutex<Workers>, + server: Arc<Server>, + services: &'static Services, +} + +type Workers = JoinSet<WorkerResult>; +type WorkerResult = (Arc<dyn Service>, Result<()>); +type WorkersLocked<'a> = MutexGuard<'a, Workers>; + +const RESTART_DELAY_MS: u64 = 2500; + +impl Manager { + pub(super) fn new(services: &Services) -> Arc<Self> { + Arc::new(Self { + manager: Mutex::new(None), + workers: Mutex::new(JoinSet::new()), + server: services.server.clone(), + services: crate::services(), + }) + } + + pub(super) async fn poll(&self) -> Result<()> { + if let Some(manager) = &mut *self.manager.lock().await { + trace!("Polling service manager..."); + return manager.await?; + } + + Ok(()) + } + + pub(super) async fn start(self: Arc<Self>) -> Result<()> { + let mut workers = self.workers.lock().await; + + debug!("Starting service manager..."); + let self_ = self.clone(); + _ = self.manager.lock().await.insert( + self.server + .runtime() + .spawn(async move { self_.worker().await }), + ); + + debug!("Starting service workers..."); + for service in self.services.service.values() { + self.start_worker(&mut workers, service).await?; + } + + Ok(()) + } + + pub(super) async fn stop(&self) { + if let Some(manager) = self.manager.lock().await.take() { + debug!("Waiting for service manager..."); + if let Err(e) = manager.await { + error!("Manager shutdown error: {e:?}"); + } + } + } + + async fn worker(&self) -> Result<()> { + loop { + let mut workers = self.workers.lock().await; + tokio::select! { + result = workers.join_next() => match result { + Some(Ok(result)) => self.handle_result(&mut workers, result).await?, + Some(Err(error)) => self.handle_abort(&mut workers, Error::from(error)).await?, + None => break, + } + } + } + + debug!("Worker manager finished"); + Ok(()) + } + + async fn handle_abort(&self, _workers: &mut WorkersLocked<'_>, error: Error) -> Result<()> { + // not supported until service can be associated with abort + unimplemented!("unexpected worker task abort {error:?}"); + } + + async fn handle_result(&self, workers: &mut WorkersLocked<'_>, result: WorkerResult) -> Result<()> { + let (service, result) = result; + match result { + Ok(()) => self.handle_finished(workers, &service).await, + Err(error) => self.handle_error(workers, &service, error).await, + } + } + + async fn handle_finished(&self, _workers: &mut WorkersLocked<'_>, service: &Arc<dyn Service>) -> Result<()> { + debug!("service {:?} worker finished", service.name()); + Ok(()) + } + + async fn handle_error( + &self, workers: &mut WorkersLocked<'_>, service: &Arc<dyn Service>, error: Error, + ) -> Result<()> { + let name = service.name(); + error!("service {name:?} aborted: {error}"); + + if !self.server.running() { + debug_warn!("service {name:?} error ignored on shutdown."); + return Ok(()); + } + + if !error.is_panic() { + return Err(error); + } + + let delay = Duration::from_millis(RESTART_DELAY_MS); + warn!("service {name:?} worker restarting after {} delay", time::pretty(delay)); + sleep(delay).await; + + self.start_worker(workers, service).await + } + + /// Start the worker in a task for the service. + async fn start_worker(&self, workers: &mut WorkersLocked<'_>, service: &Arc<dyn Service>) -> Result<()> { + if !self.server.running() { + return Err!("Service {:?} worker not starting during server shutdown.", service.name()); + } + + debug!("Service {:?} worker starting...", service.name()); + workers.spawn_on(worker(service.clone()), self.server.runtime()); + + Ok(()) + } +} + +/// Base frame for service worker. This runs in a tokio::task. All errors and +/// panics from the worker are caught and returned cleanly. The JoinHandle +/// should never error with a panic, and if so it should propagate, but it may +/// error with an Abort which the manager should handle along with results to +/// determine if the worker should be restarted. +async fn worker(service: Arc<dyn Service>) -> WorkerResult { + let service_ = Arc::clone(&service); + let result = AssertUnwindSafe(service_.worker()) + .catch_unwind() + .await + .map_err(Error::from_panic); + + // flattens JoinError for panic into worker's Error + (service, result.unwrap_or_else(Err)) +} diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index 3cb8fda83a33544f6a547522380de7684a1c13b7..9b1cbe066f7baa14ff52c1116048d601f2fdbf9f 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -3,6 +3,7 @@ use std::{collections::HashMap, io::Cursor, num::Saturating as Sat, path::PathBuf, sync::Arc, time::SystemTime}; +use async_trait::async_trait; use base64::{engine::general_purpose, Engine as _}; use conduit::{checked, debug, debug_error, error, utils, Error, Result, Server}; use data::Data; @@ -47,6 +48,7 @@ pub struct Service { pub url_preview_mutex: RwLock<HashMap<String, Arc<Mutex<()>>>>, } +#[async_trait] impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { @@ -56,6 +58,12 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { })) } + async fn worker(self: Arc<Self>) -> Result<()> { + self.create_media_dir().await?; + + Ok(()) + } + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } diff --git a/src/service/mod.rs b/src/service/mod.rs index 4b19073d156c7c011f2596e6bc60e458d8a03aac..81e0be3b55fff502ec8a85ffc892f38c7560b379 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -1,5 +1,6 @@ #![allow(refining_impl_trait)] +mod manager; mod service; pub mod services; @@ -15,6 +16,7 @@ pub mod sending; pub mod transaction_ids; pub mod uiaa; +pub mod updates; pub mod users; extern crate conduit_core as conduit; @@ -22,7 +24,7 @@ use std::sync::{Arc, RwLock}; -pub(crate) use conduit::{config, debug_error, debug_info, debug_warn, utils, Config, Error, Result, Server}; +pub(crate) use conduit::{config, debug_error, debug_warn, utils, Config, Error, Result, Server}; pub use conduit::{pdu, PduBuilder, PduCount, PduEvent}; use database::Database; pub(crate) use service::{Args, Service}; @@ -37,15 +39,17 @@ static SERVICES: RwLock<Option<&Services>> = RwLock::new(None); -pub async fn init(server: &Arc<Server>) -> Result<()> { +pub async fn start(server: &Arc<Server>) -> Result<()> { let d = Arc::new(Database::open(server).await?); let s = Box::new(Services::build(server.clone(), d)?); _ = SERVICES.write().expect("write locked").insert(Box::leak(s)); - Ok(()) + services().start().await } -pub fn fini() { +pub async fn stop() { + services().stop().await; + // Deactivate services(). Any further use will panic the caller. let s = SERVICES .write() diff --git a/src/service/presence/mod.rs b/src/service/presence/mod.rs index f5400379092b5f51544bde8c7220e68cfa3c48c7..254304bae2dadede7907e39518f76ede3a2bbe04 100644 --- a/src/service/presence/mod.rs +++ b/src/service/presence/mod.rs @@ -12,7 +12,7 @@ OwnedUserId, UInt, UserId, }; use serde::{Deserialize, Serialize}; -use tokio::{sync::Mutex, task::JoinHandle, time::sleep}; +use tokio::{sync::Mutex, time::sleep}; use crate::{services, user_is_local}; @@ -77,7 +77,6 @@ pub struct Service { pub db: Data, pub timer_sender: loole::Sender<(OwnedUserId, Duration)>, timer_receiver: Mutex<loole::Receiver<(OwnedUserId, Duration)>>, - handler_join: Mutex<Option<JoinHandle<()>>>, timeout_remote_users: bool, idle_timeout: u64, offline_timeout: u64, @@ -94,34 +93,26 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { db: Data::new(args.db), timer_sender, timer_receiver: Mutex::new(timer_receiver), - handler_join: Mutex::new(None), timeout_remote_users: config.presence_timeout_remote_users, idle_timeout: checked!(idle_timeout_s * 1_000)?, offline_timeout: checked!(offline_timeout_s * 1_000)?, })) } - async fn start(self: Arc<Self>) -> Result<()> { - //TODO: if self.globals.config.allow_local_presence { return; } - - let self_ = Arc::clone(&self); - let handle = services().server.runtime().spawn(async move { - self_ - .handler() - .await - .expect("Failed to start presence handler"); - }); - - _ = self.handler_join.lock().await.insert(handle); - - Ok(()) - } - - async fn stop(&self) { - self.interrupt(); - if let Some(handler_join) = self.handler_join.lock().await.take() { - if let Err(e) = handler_join.await { - error!("Failed to shutdown: {e:?}"); + async fn worker(self: Arc<Self>) -> Result<()> { + let mut presence_timers = FuturesUnordered::new(); + let receiver = self.timer_receiver.lock().await; + loop { + debug_assert!(!receiver.is_closed(), "channel error"); + tokio::select! { + Some(user_id) = presence_timers.next() => self.process_presence_timer(&user_id)?, + event = receiver.recv_async() => match event { + Err(_e) => return Ok(()), + Ok((user_id, timeout)) => { + debug!("Adding timer {}: {user_id} timeout:{timeout:?}", presence_timers.len()); + presence_timers.push(presence_timer(user_id, timeout)); + }, + }, } } } @@ -219,24 +210,6 @@ pub fn presence_since(&self, since: u64) -> Box<dyn Iterator<Item = (OwnedUserId self.db.presence_since(since) } - async fn handler(&self) -> Result<()> { - let mut presence_timers = FuturesUnordered::new(); - let receiver = self.timer_receiver.lock().await; - loop { - debug_assert!(!receiver.is_closed(), "channel error"); - tokio::select! { - Some(user_id) = presence_timers.next() => self.process_presence_timer(&user_id)?, - event = receiver.recv_async() => match event { - Err(_e) => return Ok(()), - Ok((user_id, timeout)) => { - debug!("Adding timer {}: {user_id} timeout:{timeout:?}", presence_timers.len()); - presence_timers.push(presence_timer(user_id, timeout)); - }, - }, - } - } - } - fn process_presence_timer(&self, user_id: &OwnedUserId) -> Result<()> { let mut presence_state = PresenceState::Offline; let mut last_active_ago = None; diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index 4af1035e58e2b39f226cb943a2688c38531e42cd..792f5c988c3a4b4acc5e42c1b3acdfc3020e1c94 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -3,7 +3,7 @@ use std::sync::Arc; -use conduit::{Error, Result}; +use conduit::{err, Error, Result}; use data::Data; use ruma::{ api::{appservice, client::error::ErrorKind}, @@ -171,7 +171,7 @@ async fn resolve_appservice_alias(&self, room_alias: &RoomAliasId) -> Result<Opt .rooms .alias .resolve_local_alias(room_alias)? - .ok_or_else(|| Error::bad_config("Room does not exist."))?, + .ok_or_else(|| err!(Request(NotFound("Room does not exist."))))?, )); } } diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index 2919ea7ee939e800f1f794de64094024bd3efbb9..4e8c7bb2164f94d9c896a185c82b2b4de7255514 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -5,9 +5,9 @@ sync::Arc, }; -use conduit::{debug, error, trace, validated, warn, Error, Result}; +use conduit::{debug, error, trace, validated, warn, Err, Result}; use data::Data; -use ruma::{api::client::error::ErrorKind, EventId, RoomId}; +use ruma::{EventId, RoomId}; use crate::services; @@ -143,8 +143,11 @@ fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result<H match services().rooms.timeline.get_pdu(&event_id) { Ok(Some(pdu)) => { if pdu.room_id != room_id { - error!(?event_id, ?pdu, "auth event for incorrect room_id"); - return Err(Error::BadRequest(ErrorKind::forbidden(), "Evil event in db")); + return Err!(Request(Forbidden( + "auth event {event_id:?} for incorrect room {} which is not {}", + pdu.room_id, + room_id + ))); } for auth_event in &pdu.auth_events { let sauthevent = services() diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 53ce6f8d3761eb590f4ac78da8feaa2a522f0c29..6cb23b9fa593ca4b2c8c114357f7be842477b442 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -3,14 +3,16 @@ use std::{ collections::{hash_map, BTreeMap, HashMap, HashSet}, + fmt::Write, pin::Pin, - sync::Arc, + sync::{Arc, RwLock as StdRwLock}, time::Instant, }; use conduit::{ - debug, debug_error, debug_info, error, info, trace, utils::math::continue_exponential_backoff_secs, warn, Error, - Result, + debug, debug_error, debug_info, err, error, info, trace, + utils::{math::continue_exponential_backoff_secs, MutexMap}, + warn, Error, Result, }; use futures_util::Future; pub use parse_incoming_pdu::parse_incoming_pdu; @@ -28,14 +30,21 @@ int, serde::Base64, state_res::{self, RoomVersion, StateMap}, - uint, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedUserId, RoomId, RoomVersionId, ServerName, + uint, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, + RoomVersionId, ServerName, }; use tokio::sync::RwLock; use super::state_compressor::CompressedStateEvent; use crate::{pdu, services, PduEvent}; -pub struct Service; +pub struct Service { + pub federation_handletime: StdRwLock<HandleTimeMap>, + pub mutex_federation: RoomMutexMap, +} + +type RoomMutexMap = MutexMap<OwnedRoomId, ()>; +type HandleTimeMap = HashMap<OwnedRoomId, (OwnedEventId, Instant)>; // We use some AsyncRecursiveType hacks here so we can call async funtion // recursively. @@ -46,7 +55,26 @@ AsyncRecursiveType<'a, Result<(Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)>>; impl crate::Service for Service { - fn build(_args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self {})) } + fn build(_args: crate::Args<'_>) -> Result<Arc<Self>> { + Ok(Arc::new(Self { + federation_handletime: HandleTimeMap::new().into(), + mutex_federation: RoomMutexMap::new(), + })) + } + + fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { + let mutex_federation = self.mutex_federation.len(); + writeln!(out, "federation_mutex: {mutex_federation}")?; + + let federation_handletime = self + .federation_handletime + .read() + .expect("locked for reading") + .len(); + writeln!(out, "federation_handletime: {federation_handletime}")?; + + Ok(()) + } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } @@ -200,9 +228,7 @@ pub async fn handle_incoming_pdu<'a>( // Done with prev events, now handling the incoming event let start_time = Instant::now(); - services() - .globals - .roomid_federationhandletime + self.federation_handletime .write() .expect("locked") .insert(room_id.to_owned(), (event_id.to_owned(), start_time)); @@ -211,9 +237,7 @@ pub async fn handle_incoming_pdu<'a>( .upgrade_outlier_to_timeline_pdu(incoming_pdu, val, &create_event, origin, room_id, pub_key_map) .await; - services() - .globals - .roomid_federationhandletime + self.federation_handletime .write() .expect("locked") .remove(&room_id.to_owned()); @@ -272,9 +296,7 @@ pub async fn handle_prev_pdu<'a>( } let start_time = Instant::now(); - services() - .globals - .roomid_federationhandletime + self.federation_handletime .write() .expect("locked") .insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time)); @@ -282,9 +304,7 @@ pub async fn handle_prev_pdu<'a>( self.upgrade_outlier_to_timeline_pdu(pdu, json, create_event, origin, room_id, pub_key_map) .await?; - services() - .globals - .roomid_federationhandletime + self.federation_handletime .write() .expect("locked") .remove(&room_id.to_owned()); @@ -531,55 +551,50 @@ pub async fn upgrade_outlier_to_timeline_pdu( // Soft fail check before doing state res debug!("Performing soft-fail check"); - let soft_fail = !state_res::event_auth::auth_check(&room_version, &incoming_pdu, None::<PduEvent>, |k, s| { - auth_events.get(&(k.clone(), s.to_owned())) - }) - .map_err(|_e| Error::BadRequest(ErrorKind::forbidden(), "Auth check failed."))? - || incoming_pdu.kind == TimelineEventType::RoomRedaction - && match room_version_id { - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 - | RoomVersionId::V8 - | RoomVersionId::V9 - | RoomVersionId::V10 => { - if let Some(redact_id) = &incoming_pdu.redacts { - !services().rooms.state_accessor.user_can_redact( - redact_id, - &incoming_pdu.sender, - &incoming_pdu.room_id, - true, - )? - } else { - false - } - }, - _ => { - let content = serde_json::from_str::<RoomRedactionEventContent>(incoming_pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in redaction pdu."))?; - - if let Some(redact_id) = &content.redacts { - !services().rooms.state_accessor.user_can_redact( - redact_id, - &incoming_pdu.sender, - &incoming_pdu.room_id, - true, - )? - } else { - false - } - }, - }; + let soft_fail = { + use RoomVersionId::*; + + !state_res::event_auth::auth_check(&room_version, &incoming_pdu, None::<PduEvent>, |k, s| { + auth_events.get(&(k.clone(), s.to_owned())) + }) + .map_err(|_e| Error::BadRequest(ErrorKind::forbidden(), "Auth check failed."))? + || incoming_pdu.kind == TimelineEventType::RoomRedaction + && match room_version_id { + V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { + if let Some(redact_id) = &incoming_pdu.redacts { + !services().rooms.state_accessor.user_can_redact( + redact_id, + &incoming_pdu.sender, + &incoming_pdu.room_id, + true, + )? + } else { + false + } + }, + _ => { + let content = serde_json::from_str::<RoomRedactionEventContent>(incoming_pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid content in redaction pdu."))?; + + if let Some(redact_id) = &content.redacts { + !services().rooms.state_accessor.user_can_redact( + redact_id, + &incoming_pdu.sender, + &incoming_pdu.room_id, + true, + )? + } else { + false + } + }, + } + }; // 13. Use state resolution to find new room state // We start looking at current room state now, so lets lock the room trace!("Locking the room"); - let state_lock = services().globals.roomid_mutex_state.lock(room_id).await; + let state_lock = services().rooms.state.mutex.lock(room_id).await; // Now we calculate the set of extremities this room has after the incoming // event has been applied. We start with the previous extremities (aka leaves) @@ -1367,11 +1382,8 @@ fn check_room_id(room_id: &RoomId, pdu: &PduEvent) -> Result<()> { } fn get_room_version_id(create_event: &PduEvent) -> Result<RoomVersionId> { - let create_event_content: RoomCreateEventContent = - serde_json::from_str(create_event.content.get()).map_err(|e| { - error!("Invalid create event: {}", e); - Error::BadDatabase("Invalid create event in db") - })?; + let create_event_content: RoomCreateEventContent = serde_json::from_str(create_event.content.get()) + .map_err(|e| err!(Database("Invalid create event: {e}")))?; Ok(create_event_content.room_version) } diff --git a/src/service/rooms/event_handler/parse_incoming_pdu.rs b/src/service/rooms/event_handler/parse_incoming_pdu.rs index 4c907e511ab969b5e53d5f80fe9f3c8e95e9a6d7..8fcd854969c3b025b41f9e0582f4787947d75a91 100644 --- a/src/service/rooms/event_handler/parse_incoming_pdu.rs +++ b/src/service/rooms/event_handler/parse_incoming_pdu.rs @@ -1,4 +1,4 @@ -use conduit::{Error, Result}; +use conduit::{Err, Error, Result}; use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, OwnedEventId, OwnedRoomId, RoomId}; use serde_json::value::RawValue as RawJsonValue; use tracing::warn; @@ -17,15 +17,12 @@ pub fn parse_incoming_pdu(pdu: &RawJsonValue) -> Result<(OwnedEventId, Canonical .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid room id in pdu"))?; let Ok(room_version_id) = services().rooms.state.get_room_version(&room_id) else { - return Err(Error::Err(format!("Server is not in room {room_id}"))); + return Err!("Server is not in room {room_id}"); }; let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { // Event could not be converted to canonical json - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Could not convert event to canonical json.", - )); + return Err!(Request(InvalidParam("Could not convert event to canonical json."))); }; Ok((event_id, value, room_id)) diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index 03d0d43ff2c1acb3e137587094bc52b762fac0ef..18133fc1f9565583f490f9b64d181ed81e75a48c 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -7,7 +7,7 @@ sync::Arc, }; -use conduit::{checked, debug_info, utils::math::usize_from_f64}; +use conduit::{checked, debug, debug_info, err, utils::math::usize_from_f64, warn, Error, Result}; use lru_cache::LruCache; use ruma::{ api::{ @@ -27,9 +27,8 @@ OwnedRoomId, OwnedServerName, RoomId, ServerName, UInt, UserId, }; use tokio::sync::Mutex; -use tracing::{debug, error, warn}; -use crate::{services, Error, Result}; +use crate::services; pub struct CachedSpaceHierarchySummary { summary: SpaceHierarchyParentSummary, @@ -380,10 +379,7 @@ fn get_room_summary( .map(|s| { serde_json::from_str(s.content.get()) .map(|c: RoomJoinRulesEventContent| c.join_rule) - .map_err(|e| { - error!("Invalid room join rule event in database: {}", e); - Error::BadDatabase("Invalid room join rule event in database.") - }) + .map_err(|e| err!(Database(error!("Invalid room join rule event in database: {e}")))) }) .transpose()? .unwrap_or(JoinRule::Invite); diff --git a/src/service/rooms/state/data.rs b/src/service/rooms/state/data.rs index aad3bedec08a92c627722306a512ac2e999052da..3c110afc633c39318b20405cf181ae9acbdf9b57 100644 --- a/src/service/rooms/state/data.rs +++ b/src/service/rooms/state/data.rs @@ -3,7 +3,8 @@ use conduit::{utils, Error, Result}; use database::{Database, Map}; use ruma::{EventId, OwnedEventId, RoomId}; -use utils::mutex_map; + +use super::RoomMutexGuard; pub(super) struct Data { shorteventid_shortstatehash: Arc<Map>, @@ -35,7 +36,7 @@ pub(super) fn set_room_state( &self, room_id: &RoomId, new_shortstatehash: u64, - _mutex_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex + _mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<()> { self.roomid_shortstatehash .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; @@ -68,7 +69,7 @@ pub(super) fn set_forward_extremities( &self, room_id: &RoomId, event_ids: Vec<OwnedEventId>, - _mutex_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex + _mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<()> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xFF); diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index 52ee89d1de9390b5ad78042571c96611bf4f435e..a3a317a584bb8925fab01e1aa83d630a4976e947 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -2,11 +2,12 @@ use std::{ collections::{HashMap, HashSet}, + fmt::Write, sync::Arc, }; use conduit::{ - utils::{calculate_hash, mutex_map}, + utils::{calculate_hash, MutexMap, MutexMapGuard}, warn, Error, Result, }; use data::Data; @@ -18,7 +19,7 @@ }, serde::Raw, state_res::{self, StateMap}, - EventId, OwnedEventId, RoomId, RoomVersionId, UserId, + EventId, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId, UserId, }; use super::state_compressor::CompressedStateEvent; @@ -26,15 +27,27 @@ pub struct Service { db: Data, + pub mutex: RoomMutexMap, } +type RoomMutexMap = MutexMap<OwnedRoomId, ()>; +pub type RoomMutexGuard = MutexMapGuard<OwnedRoomId, ()>; + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { db: Data::new(args.db), + mutex: RoomMutexMap::new(), })) } + fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { + let mutex = self.mutex.len(); + writeln!(out, "state_mutex: {mutex}")?; + + Ok(()) + } + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } @@ -46,7 +59,7 @@ pub async fn force_state( shortstatehash: u64, statediffnew: Arc<HashSet<CompressedStateEvent>>, _statediffremoved: Arc<HashSet<CompressedStateEvent>>, - state_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex + state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<()> { for event_id in statediffnew.iter().filter_map(|new| { services() @@ -318,7 +331,7 @@ pub fn set_room_state( &self, room_id: &RoomId, shortstatehash: u64, - mutex_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex + mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<()> { self.db.set_room_state(room_id, shortstatehash, mutex_lock) } @@ -358,7 +371,7 @@ pub fn set_forward_extremities( &self, room_id: &RoomId, event_ids: Vec<OwnedEventId>, - state_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex + state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<()> { self.db .set_forward_extremities(room_id, event_ids, state_lock) diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index a35678573f672ff50ff1a9dbf2011ddc2e1d45cc..7abe5e0fc5ff5ecedc3b5a4bb6ef9e988c85f83c 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -6,11 +6,7 @@ sync::{Arc, Mutex as StdMutex, Mutex}, }; -use conduit::{ - error, - utils::{math::usize_from_f64, mutex_map}, - warn, Error, Result, -}; +use conduit::{err, error, utils::math::usize_from_f64, warn, Error, Result}; use data::Data; use lru_cache::LruCache; use ruma::{ @@ -37,7 +33,7 @@ }; use serde_json::value::to_raw_value; -use crate::{pdu::PduBuilder, services, PduEvent}; +use crate::{pdu::PduBuilder, rooms::state::RoomMutexGuard, services, PduEvent}; pub struct Service { db: Data, @@ -333,7 +329,7 @@ pub fn get_member(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<Ro } pub fn user_can_invite( - &self, room_id: &RoomId, sender: &UserId, target_user: &UserId, state_lock: &mutex_map::Guard<()>, + &self, room_id: &RoomId, sender: &UserId, target_user: &UserId, state_lock: &RoomMutexGuard, ) -> Result<bool> { let content = to_raw_value(&RoomMemberEventContent::new(MembershipState::Invite)) .expect("Event content always serializes"); @@ -458,10 +454,7 @@ pub fn get_join_rule(&self, room_id: &RoomId) -> Result<(SpaceRoomJoinRule, Vec< .map(|c: RoomJoinRulesEventContent| { (c.join_rule.clone().into(), self.allowed_room_ids(c.join_rule)) }) - .map_err(|e| { - error!("Invalid room join rule event in database: {e}"); - Error::BadDatabase("Invalid room join rule event in database.") - }) + .map_err(|e| err!(Database(error!("Invalid room join rule event in database: {e}")))) }) .transpose()? .unwrap_or((SpaceRoomJoinRule::Invite, vec![]))) @@ -487,10 +480,8 @@ pub fn get_room_type(&self, room_id: &RoomId) -> Result<Option<RoomType>> { Ok(self .room_state_get(room_id, &StateEventType::RoomCreate, "")? .map(|s| { - serde_json::from_str::<RoomCreateEventContent>(s.content.get()).map_err(|e| { - error!("Invalid room create event in database: {e}"); - Error::BadDatabase("Invalid room create event in database.") - }) + serde_json::from_str::<RoomCreateEventContent>(s.content.get()) + .map_err(|e| err!(Database(error!("Invalid room create event in database: {e}")))) }) .transpose()? .and_then(|e| e.room_type)) @@ -503,10 +494,7 @@ pub fn get_room_encryption(&self, room_id: &RoomId) -> Result<Option<EventEncryp .map_or(Ok(None), |s| { serde_json::from_str::<RoomEncryptionEventContent>(s.content.get()) .map(|content| Some(content.algorithm)) - .map_err(|e| { - error!("Invalid room encryption event in database: {e}"); - Error::BadDatabase("Invalid room encryption event in database.") - }) + .map_err(|e| err!(Database(error!("Invalid room encryption event in database: {e}")))) }) } } diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 30fd7bf4f54e5801d9e2342d67c575ef8e6ac118..48215817e9770b1b8175052e200faf837dbbbb47 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -2,7 +2,7 @@ use std::sync::Arc; -use conduit::{error, warn, Error, Result}; +use conduit::{err, error, warn, Error, Result}; use data::Data; use itertools::Itertools; use ruma::{ @@ -128,10 +128,8 @@ pub fn update_membership( .account_data .get(Some(&predecessor.room_id), user_id, RoomAccountDataEventType::Tag)? .map(|event| { - serde_json::from_str(event.get()).map_err(|e| { - warn!("Invalid account data event in db: {e:?}"); - Error::BadDatabase("Invalid account data event in db.") - }) + serde_json::from_str(event.get()) + .map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}")))) }) { services() .account_data @@ -144,10 +142,8 @@ pub fn update_membership( .account_data .get(None, user_id, GlobalAccountDataEventType::Direct.to_string().into())? .map(|event| { - serde_json::from_str::<DirectEvent>(event.get()).map_err(|e| { - warn!("Invalid account data event in db: {e:?}"); - Error::BadDatabase("Invalid account data event in db.") - }) + serde_json::from_str::<DirectEvent>(event.get()) + .map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}")))) }) { let mut direct_event = direct_event?; let mut room_ids_updated = false; @@ -185,10 +181,8 @@ pub fn update_membership( .into(), )? .map(|event| { - serde_json::from_str::<IgnoredUserListEvent>(event.get()).map_err(|e| { - warn!("Invalid account data event in db: {e:?}"); - Error::BadDatabase("Invalid account data event in db.") - }) + serde_json::from_str::<IgnoredUserListEvent>(event.get()) + .map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}")))) }) .transpose()? .map_or(false, |ignored| { diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index c82098ba3da47de5398d3f682f74d83fd1e880e1..0bc5ade16a951df412c2d64d7ac40234f143de4a 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -6,7 +6,11 @@ sync::Arc, }; -use conduit::{debug, error, info, utils, utils::mutex_map, validated, warn, Error, Result}; +use conduit::{ + debug, error, info, utils, + utils::{MutexMap, MutexMapGuard}, + validated, warn, Error, Result, +}; use data::Data; use itertools::Itertools; use ruma::{ @@ -26,8 +30,8 @@ push::{Action, Ruleset, Tweak}, serde::Base64, state_res::{self, Event, RoomVersion}, - uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedServerName, RoomId, - RoomVersionId, ServerName, UserId, + uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId, OwnedServerName, + RoomId, RoomVersionId, ServerName, UserId, }; use serde::Deserialize; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; @@ -65,12 +69,17 @@ struct ExtractBody { pub struct Service { db: Data, + pub mutex_insert: RoomMutexMap, } +type RoomMutexMap = MutexMap<OwnedRoomId, ()>; +pub type RoomMutexGuard = MutexMapGuard<OwnedRoomId, ()>; + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { db: Data::new(args.db), + mutex_insert: RoomMutexMap::new(), })) } @@ -83,6 +92,9 @@ fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { .len(); writeln!(out, "lasttimelinecount_cache: {lasttimelinecount_cache}")?; + let mutex_insert = self.mutex_insert.len(); + writeln!(out, "insert_mutex: {mutex_insert}")?; + Ok(()) } @@ -203,7 +215,7 @@ pub async fn append_pdu( pdu: &PduEvent, mut pdu_json: CanonicalJsonObject, leaves: Vec<OwnedEventId>, - state_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex + state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<Vec<u8>> { // Coalesce database writes for the remainder of this scope. let _cork = services().db.cork_and_flush(); @@ -268,11 +280,7 @@ pub async fn append_pdu( .state .set_forward_extremities(&pdu.room_id, leaves, state_lock)?; - let insert_lock = services() - .globals - .roomid_mutex_insert - .lock(&pdu.room_id) - .await; + let insert_lock = self.mutex_insert.lock(&pdu.room_id).await; let count1 = services().globals.next_count()?; // Mark as read first so the sending client doesn't get a notification even if @@ -381,18 +389,11 @@ pub async fn append_pdu( match pdu.kind { TimelineEventType::RoomRedaction => { + use RoomVersionId::*; + let room_version_id = services().rooms.state.get_room_version(&pdu.room_id)?; match room_version_id { - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 - | RoomVersionId::V8 - | RoomVersionId::V9 - | RoomVersionId::V10 => { + V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { if let Some(redact_id) = &pdu.redacts { if services().rooms.state_accessor.user_can_redact( redact_id, @@ -404,7 +405,7 @@ pub async fn append_pdu( } } }, - RoomVersionId::V11 => { + V11 => { let content = serde_json::from_str::<RoomRedactionEventContent>(pdu.content.get()).map_err(|e| { warn!("Invalid content in redaction pdu: {e}"); @@ -593,7 +594,7 @@ pub fn create_hash_and_sign_event( pdu_builder: PduBuilder, sender: &UserId, room_id: &RoomId, - _mutex_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex + _mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<(PduEvent, CanonicalJsonObject)> { let PduBuilder { event_type, @@ -780,7 +781,7 @@ pub async fn build_and_append_pdu( pdu_builder: PduBuilder, sender: &UserId, room_id: &RoomId, - state_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex + state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<Arc<EventId>> { let (pdu, pdu_json) = self.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)?; if let Some(admin_room) = admin::Service::get_admin_room()? { @@ -860,17 +861,9 @@ pub async fn build_and_append_pdu( // If redaction event is not authorized, do not append it to the timeline if pdu.kind == TimelineEventType::RoomRedaction { + use RoomVersionId::*; match services().rooms.state.get_room_version(&pdu.room_id)? { - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 - | RoomVersionId::V8 - | RoomVersionId::V9 - | RoomVersionId::V10 => { + V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { if let Some(redact_id) = &pdu.redacts { if !services().rooms.state_accessor.user_can_redact( redact_id, @@ -963,7 +956,7 @@ pub async fn append_incoming_pdu( new_room_leaves: Vec<OwnedEventId>, state_ids_compressed: Arc<HashSet<CompressedStateEvent>>, soft_fail: bool, - state_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex + state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<Option<Vec<u8>>> { // We append to state before appending the pdu, so we don't have a moment in // time with the pdu without it's state. This is okay because append_pdu can't @@ -1153,8 +1146,9 @@ pub async fn backfill_pdu( // Lock so we cannot backfill the same pdu twice at the same time let mutex_lock = services() - .globals - .roomid_mutex_federation + .rooms + .event_handler + .mutex_federation .lock(&room_id) .await; @@ -1186,7 +1180,7 @@ pub async fn backfill_pdu( .get_shortroomid(&room_id)? .expect("room exists"); - let insert_lock = services().globals.roomid_mutex_insert.lock(&room_id).await; + let insert_lock = self.mutex_insert.lock(&room_id).await; let max = u64::MAX; let count = services().globals.next_count()?; diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index d7a9c0fc4a10c2cbca137033de546caeeaef473e..88b8b18971e123c43cacdbff53440b3ea9ecdb80 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -4,29 +4,26 @@ mod send; mod sender; -use std::{fmt::Debug, sync::Arc}; +use std::fmt::Debug; -use async_trait::async_trait; -use conduit::{Error, Result}; -use data::Data; +use conduit::{err, Result}; pub use resolve::{resolve_actual_dest, CachedDest, CachedOverride, FedDest}; use ruma::{ api::{appservice::Registration, OutgoingRequest}, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; pub use sender::convert_to_outgoing_federation_event; -use tokio::{sync::Mutex, task::JoinHandle}; -use tracing::{error, warn}; +use tokio::sync::Mutex; +use tracing::warn; use crate::{server_is_ours, services}; pub struct Service { - pub db: Data, + pub db: data::Data, /// The state for a given state hash. sender: loole::Sender<Msg>, receiver: Mutex<loole::Receiver<Msg>>, - handler_join: Mutex<Option<JoinHandle<()>>>, startup_netburst: bool, startup_netburst_keep: i64, } @@ -53,45 +50,6 @@ pub enum SendingEvent { Flush, // none } -#[async_trait] -impl crate::Service for Service { - fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { - let config = &args.server.config; - let (sender, receiver) = loole::unbounded(); - Ok(Arc::new(Self { - db: Data::new(args.db.clone()), - sender, - receiver: Mutex::new(receiver), - handler_join: Mutex::new(None), - startup_netburst: config.startup_netburst, - startup_netburst_keep: config.startup_netburst_keep, - })) - } - - async fn start(self: Arc<Self>) -> Result<()> { - self.start_handler().await; - - Ok(()) - } - - async fn stop(&self) { - self.interrupt(); - if let Some(handler_join) = self.handler_join.lock().await.take() { - if let Err(e) = handler_join.await { - error!("Failed to shutdown: {e:?}"); - } - } - } - - fn interrupt(&self) { - if !self.sender.is_closed() { - self.sender.close(); - } - } - - fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } -} - impl Service { #[tracing::instrument(skip(self, pdu_id, user, pushkey), level = "debug")] pub fn send_pdu_push(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> { @@ -266,7 +224,7 @@ pub fn cleanup_events(&self, appservice_id: String) -> Result<()> { fn dispatch(&self, msg: Msg) -> Result<()> { debug_assert!(!self.sender.is_full(), "channel full"); debug_assert!(!self.sender.is_closed(), "channel closed"); - self.sender.send(msg).map_err(|e| Error::Err(e.to_string())) + self.sender.send(msg).map_err(|e| err!("{e}")) } } diff --git a/src/service/sending/resolve.rs b/src/service/sending/resolve.rs index d38509ba8d0cdbafa1af8a10db2fe4229a5373f4..8b6bfc9599879e52db5b334bc984fc8986ce3961 100644 --- a/src/service/sending/resolve.rs +++ b/src/service/sending/resolve.rs @@ -5,12 +5,12 @@ time::SystemTime, }; +use conduit::{debug, debug_error, debug_info, debug_warn, trace, utils::rand, Err, Error, Result}; use hickory_resolver::{error::ResolveError, lookup::SrvLookup}; use ipaddress::IPAddress; use ruma::{OwnedServerName, ServerName}; -use tracing::{debug, error, trace}; -use crate::{debug_error, debug_info, debug_warn, services, utils::rand, Error, Result}; +use crate::services; /// Wraps either an literal IP address plus port, or a hostname plus complement /// (colon-plus-port if it was specified). @@ -345,16 +345,13 @@ fn handle_resolve_error(e: &ResolveError) -> Result<()> { debug!("{e}"); Ok(()) }, - _ => { - error!("DNS {e}"); - Err(Error::Err(e.to_string())) - }, + _ => Err!(error!("DNS {e}")), } } fn validate_dest(dest: &ServerName) -> Result<()> { if dest == services().globals.server_name() { - return Err(Error::bad_config("Won't send federation request to ourselves")); + return Err!("Won't send federation request to ourselves"); } if dest.is_ip_literal() || IPAddress::is_valid(dest.host()) { diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs index f4825ff392e3f3e791df7c2908789ca30057be6c..18a98828f353f2b499e5c5e0bfa02ded61fd3e25 100644 --- a/src/service/sending/send.rs +++ b/src/service/sending/send.rs @@ -1,5 +1,6 @@ use std::{fmt::Debug, mem}; +use conduit::Err; use http::{header::AUTHORIZATION, HeaderValue}; use ipaddress::IPAddress; use reqwest::{Client, Method, Request, Response, Url}; @@ -26,7 +27,7 @@ pub async fn send<T>(client: &Client, dest: &ServerName, req: T) -> Result<T::In T: OutgoingRequest + Debug + Send, { if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); + return Err!(Config("allow_federation", "Federation is disabled.")); } let actual = resolve::get_actual_dest(dest).await?; diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index 0fb0d9dc8576207f63fe0b74758dcb0879c1ded1..2f542dfe477f3f063f0b555769ae49ec18e2d7c7 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -3,11 +3,12 @@ collections::{BTreeMap, HashMap, HashSet}, fmt::Debug, sync::Arc, - time::Instant, + time::{Duration, Instant}, }; +use async_trait::async_trait; use base64::{engine::general_purpose, Engine as _}; -use conduit::{debug, error, utils::math::continue_exponential_backoff_secs, warn}; +use conduit::{debug, debug_warn, error, trace, utils::math::continue_exponential_backoff_secs, warn}; use federation::transactions::send_transaction_message; use futures_util::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; use ruma::{ @@ -23,8 +24,9 @@ ServerName, UInt, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; +use tokio::{sync::Mutex, time::sleep_until}; -use super::{appservice, send, Destination, Msg, SendingEvent, Service}; +use super::{appservice, data::Data, send, Destination, Msg, SendingEvent, Service}; use crate::{presence::Presence, services, user_is_local, utils::calculate_hash, Error, Result}; #[derive(Debug)] @@ -42,43 +44,58 @@ enum TransactionStatus { const DEQUEUE_LIMIT: usize = 48; const SELECT_EDU_LIMIT: usize = 16; - -impl Service { - pub async fn start_handler(self: &Arc<Self>) { - let self_ = Arc::clone(self); - let handle = services().server.runtime().spawn(async move { - self_ - .handler() - .await - .expect("Failed to start sending handler"); - }); - - _ = self.handler_join.lock().await.insert(handle); +const CLEANUP_TIMEOUT_MS: u64 = 3500; + +#[async_trait] +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { + let config = &args.server.config; + let (sender, receiver) = loole::unbounded(); + Ok(Arc::new(Self { + db: Data::new(args.db.clone()), + sender, + receiver: Mutex::new(receiver), + startup_netburst: config.startup_netburst, + startup_netburst_keep: config.startup_netburst_keep, + })) } #[tracing::instrument(skip_all, name = "sender")] - async fn handler(&self) -> Result<()> { + async fn worker(self: Arc<Self>) -> Result<()> { let receiver = self.receiver.lock().await; let mut futures: SendingFutures<'_> = FuturesUnordered::new(); let mut statuses: CurTransactionStatus = CurTransactionStatus::new(); - self.initial_transactions(&futures, &mut statuses); + self.initial_requests(&futures, &mut statuses); loop { debug_assert!(!receiver.is_closed(), "channel error"); tokio::select! { request = receiver.recv_async() => match request { Ok(request) => self.handle_request(request, &futures, &mut statuses), - Err(_) => return Ok(()), + Err(_) => break, }, Some(response) = futures.next() => { - self.handle_response(response, &mut futures, &mut statuses); + self.handle_response(response, &futures, &mut statuses); }, } } + self.finish_responses(&mut futures, &mut statuses).await; + + Ok(()) + } + + fn interrupt(&self) { + if !self.sender.is_closed() { + self.sender.close(); + } } + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { fn handle_response( - &self, response: SendingResult, futures: &mut SendingFutures<'_>, statuses: &mut CurTransactionStatus, + &self, response: SendingResult, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus, ) { match response { Ok(dest) => self.handle_response_ok(&dest, futures, statuses), @@ -87,7 +104,7 @@ fn handle_response( } fn handle_response_err( - dest: Destination, _futures: &mut SendingFutures<'_>, statuses: &mut CurTransactionStatus, e: &Error, + dest: Destination, _futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus, e: &Error, ) { debug!(dest = ?dest, "{e:?}"); statuses.entry(dest).and_modify(|e| { @@ -138,7 +155,25 @@ fn handle_request(&self, msg: Msg, futures: &SendingFutures<'_>, statuses: &mut } } - fn initial_transactions(&self, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus) { + async fn finish_responses(&self, futures: &mut SendingFutures<'_>, statuses: &mut CurTransactionStatus) { + let now = Instant::now(); + let timeout = Duration::from_millis(CLEANUP_TIMEOUT_MS); + let deadline = now.checked_add(timeout).unwrap_or(now); + loop { + trace!("Waiting for {} requests to complete...", futures.len()); + tokio::select! { + () = sleep_until(deadline.into()) => break, + response = futures.next() => match response { + Some(response) => self.handle_response(response, futures, statuses), + None => return, + } + } + } + + debug_warn!("Leaving with {} unfinished requests...", futures.len()); + } + + fn initial_requests(&self, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus) { let keep = usize::try_from(self.startup_netburst_keep).unwrap_or(usize::MAX); let mut txns = HashMap::<Destination, Vec<SendingEvent>>::new(); for (key, dest, event) in self.db.active_requests().filter_map(Result::ok) { diff --git a/src/service/service.rs b/src/service/service.rs index ef60f359f877beff4bb6d0ff7959659ac4303760..3b8f4231d7d84011124100ce6ef8364bdd8f025d 100644 --- a/src/service/service.rs +++ b/src/service/service.rs @@ -15,19 +15,12 @@ fn build(args: Args<'_>) -> Result<Arc<impl Service>> where Self: Sized; - /// Start the service. Implement the spawning of any service workers. This - /// is called after all other services have been constructed. Failure will - /// shutdown the server with an error. - async fn start(self: Arc<Self>) -> Result<()> { Ok(()) } - - /// Stop the service. Implement the joining of any service workers and - /// cleanup of any other state. This function is asynchronous to allow that - /// gracefully, but errors cannot propagate. - async fn stop(&self) {} + /// Implement the service's worker loop. The service manager spawns a + /// task and calls this function after all services have been built. + async fn worker(self: Arc<Self>) -> Result<()> { Ok(()) } - /// Interrupt the service. This may be sent prior to `stop()` as a - /// notification to improve the shutdown sequence. Implementations must be - /// robust to this being called multiple times. + /// Interrupt the service. This is sent to initiate a graceful shutdown. + /// The service worker should return from its work loop. fn interrupt(&self) {} /// Clear any caches or similar runtime state. diff --git a/src/service/services.rs b/src/service/services.rs index aeed82043a113ad3ad5129defe69e80dedf144e2..cc9ec2900e50eb3d6fca8182e39b3bdefa564896 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -2,11 +2,14 @@ use conduit::{debug, debug_info, info, trace, Result, Server}; use database::Database; +use tokio::sync::Mutex; use crate::{ - account_data, admin, appservice, globals, key_backups, media, presence, pusher, rooms, sending, + account_data, admin, appservice, globals, key_backups, + manager::Manager, + media, presence, pusher, rooms, sending, service::{Args, Map, Service}, - transaction_ids, uiaa, users, + transaction_ids, uiaa, updates, users, }; pub struct Services { @@ -19,11 +22,13 @@ pub struct Services { pub account_data: Arc<account_data::Service>, pub presence: Arc<presence::Service>, pub admin: Arc<admin::Service>, - pub globals: Arc<globals::Service>, pub key_backups: Arc<key_backups::Service>, pub media: Arc<media::Service>, pub sending: Arc<sending::Service>, + pub updates: Arc<updates::Service>, + pub globals: Arc<globals::Service>, + manager: Mutex<Option<Arc<Manager>>>, pub(crate) service: Map, pub server: Arc<Server>, pub db: Arc<Database>, @@ -78,30 +83,48 @@ macro_rules! build { key_backups: build!(key_backups::Service), media: build!(media::Service), sending: build!(sending::Service), + updates: build!(updates::Service), globals: build!(globals::Service), + manager: Mutex::new(None), service, server, db, }) } - pub async fn memory_usage(&self) -> Result<String> { - let mut out = String::new(); - for service in self.service.values() { - service.memory_usage(&mut out)?; - } + pub(super) async fn start(&self) -> Result<()> { + debug_info!("Starting services..."); - //TODO - let roomid_spacehierarchy_cache = self - .rooms - .spaces - .roomid_spacehierarchy_cache + globals::migrations::migrations(&self.db, &self.server.config).await?; + self.manager .lock() .await - .len(); - writeln!(out, "roomid_spacehierarchy_cache: {roomid_spacehierarchy_cache}")?; + .insert(Manager::new(self)) + .clone() + .start() + .await?; - Ok(out) + debug_info!("Services startup complete."); + Ok(()) + } + + pub(super) async fn stop(&self) { + info!("Shutting down services..."); + + self.interrupt(); + if let Some(manager) = self.manager.lock().await.as_ref() { + manager.stop().await; + } + + debug_info!("Services shutdown complete."); + } + + pub async fn poll(&self) -> Result<()> { + if let Some(manager) = self.manager.lock().await.as_ref() { + return manager.poll().await; + } + + Ok(()) } pub async fn clear_cache(&self) { @@ -118,28 +141,26 @@ pub async fn clear_cache(&self) { .clear(); } - pub async fn start(&self) -> Result<()> { - debug_info!("Starting services"); - - self.media.create_media_dir().await?; - globals::migrations::migrations(&self.db, &self.globals.config).await?; - globals::emerg_access::init_emergency_access(); - - for (name, service) in &self.service { - debug!("Starting {name}"); - service.clone().start().await?; + pub async fn memory_usage(&self) -> Result<String> { + let mut out = String::new(); + for service in self.service.values() { + service.memory_usage(&mut out)?; } - if self.globals.allow_check_for_updates() { - let handle = globals::updates::start_check_for_updates_task(); - _ = self.globals.updates_handle.lock().await.insert(handle); - } + //TODO + let roomid_spacehierarchy_cache = self + .rooms + .spaces + .roomid_spacehierarchy_cache + .lock() + .await + .len(); + writeln!(out, "roomid_spacehierarchy_cache: {roomid_spacehierarchy_cache}")?; - debug_info!("Services startup complete."); - Ok(()) + Ok(out) } - pub fn interrupt(&self) { + fn interrupt(&self) { debug!("Interrupting services..."); for (name, service) in &self.service { @@ -147,22 +168,4 @@ pub fn interrupt(&self) { service.interrupt(); } } - - pub async fn stop(&self) { - info!("Shutting down services"); - self.interrupt(); - - debug!("Waiting for update worker..."); - if let Some(updates_handle) = self.globals.updates_handle.lock().await.take() { - updates_handle.abort(); - _ = updates_handle.await; - } - - for (name, service) in &self.service { - debug!("Waiting for {name} ..."); - service.stop().await; - } - - debug_info!("Services shutdown complete."); - } } diff --git a/src/service/updates/mod.rs b/src/service/updates/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..3fb680d63179442252f9a258d913a03e803919c8 --- /dev/null +++ b/src/service/updates/mod.rs @@ -0,0 +1,112 @@ +use std::{sync::Arc, time::Duration}; + +use async_trait::async_trait; +use conduit::{err, info, utils, warn, Error, Result}; +use database::Map; +use ruma::events::room::message::RoomMessageEventContent; +use serde::Deserialize; +use tokio::{sync::Notify, time::interval}; + +use crate::services; + +pub struct Service { + db: Arc<Map>, + interrupt: Notify, + interval: Duration, +} + +#[derive(Deserialize)] +struct CheckForUpdatesResponse { + updates: Vec<CheckForUpdatesResponseEntry>, +} + +#[derive(Deserialize)] +struct CheckForUpdatesResponseEntry { + id: u64, + date: String, + message: String, +} + +const CHECK_FOR_UPDATES_URL: &str = "https://pupbrain.dev/check-for-updates/stable"; +const CHECK_FOR_UPDATES_INTERVAL: u64 = 7200; // 2 hours +const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u"; + +#[async_trait] +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { + Ok(Arc::new(Self { + db: args.db["global"].clone(), + interrupt: Notify::new(), + interval: Duration::from_secs(CHECK_FOR_UPDATES_INTERVAL), + })) + } + + async fn worker(self: Arc<Self>) -> Result<()> { + let mut i = interval(self.interval); + loop { + tokio::select! { + () = self.interrupt.notified() => return Ok(()), + _ = i.tick() => (), + } + + if let Err(e) = self.handle_updates().await { + warn!(%e, "Failed to check for updates"); + } + } + } + + fn interrupt(&self) { self.interrupt.notify_waiters(); } + + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { + #[tracing::instrument(skip_all)] + async fn handle_updates(&self) -> Result<()> { + let response = services() + .globals + .client + .default + .get(CHECK_FOR_UPDATES_URL) + .send() + .await?; + + let response = serde_json::from_str::<CheckForUpdatesResponse>(&response.text().await?) + .map_err(|e| err!("Bad check for updates response: {e}"))?; + + let mut last_update_id = self.last_check_for_updates_id()?; + for update in response.updates { + last_update_id = last_update_id.max(update.id); + if update.id > self.last_check_for_updates_id()? { + info!("{:#}", update.message); + services() + .admin + .send_message(RoomMessageEventContent::text_markdown(format!( + "### the following is a message from the conduwuit puppy\n\nit was sent on `{}`:\n\n@room: {}", + update.date, update.message + ))) + .await; + } + } + self.update_check_for_updates_id(last_update_id)?; + + Ok(()) + } + + #[inline] + pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> { + self.db + .insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?; + + Ok(()) + } + + pub fn last_check_for_updates_id(&self) -> Result<u64> { + self.db + .get(LAST_CHECK_FOR_UPDATES_COUNT)? + .map_or(Ok(0_u64), |bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("last check for updates count has invalid bytes.")) + }) + } +} diff --git a/tests/test_results/complement/test_results.jsonl b/tests/test_results/complement/test_results.jsonl index 2f3db95d12cc67f3d73870c7229836814db868d5..873397bad245c6c0e365b3ffcb970470a8a67568 100644 --- a/tests/test_results/complement/test_results.jsonl +++ b/tests/test_results/complement/test_results.jsonl @@ -138,7 +138,8 @@ {"Action":"fail","Test":"TestKnockingInMSC3787Room/Knocking_on_a_room_with_join_rule_'knock'_should_succeed#01"} {"Action":"fail","Test":"TestKnockingInMSC3787Room/Users_in_the_room_see_a_user's_membership_update_when_they_knock"} {"Action":"fail","Test":"TestKnockingInMSC3787Room/Users_in_the_room_see_a_user's_membership_update_when_they_knock#01"} -{"Action":"pass","Test":"TestLocalPngThumbnail"} +{"Action":"fail","Test":"TestLocalPngThumbnail"} +{"Action":"fail","Test":"TestLocalPngThumbnail/test_/_matrix/client/v1/media_endpoint"} {"Action":"fail","Test":"TestMediaFilenames"} {"Action":"fail","Test":"TestMediaFilenames/Parallel"} {"Action":"fail","Test":"TestMediaFilenames/Parallel/ASCII"} @@ -181,7 +182,8 @@ {"Action":"pass","Test":"TestOutboundFederationProfile/Outbound_federation_can_query_profile_data"} {"Action":"pass","Test":"TestOutboundFederationSend"} {"Action":"pass","Test":"TestRemoteAliasRequestsUnderstandUnicode"} -{"Action":"pass","Test":"TestRemotePngThumbnail"} +{"Action":"fail","Test":"TestRemotePngThumbnail"} +{"Action":"fail","Test":"TestRemotePngThumbnail/test_/_matrix/client/v1/media_endpoint"} {"Action":"fail","Test":"TestRemotePresence"} {"Action":"fail","Test":"TestRemotePresence/Presence_changes_are_also_reported_to_remote_room_members"} {"Action":"fail","Test":"TestRemotePresence/Presence_changes_to_UNAVAILABLE_are_reported_to_remote_room_members"}