diff --git a/Cargo.lock b/Cargo.lock index 3415cac09da72b6c920c1ce66a254588686dda1d..dd0c0436af6565acb32401535113093cba61400d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -504,6 +504,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" +[[package]] +name = "checked_ops" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b491d76efc1d99d74de3c8529bee64c62312c275c7eb124f9185291de45801d5" +dependencies = [ + "num-traits", +] + [[package]] name = "chrono" version = "0.4.38" @@ -603,14 +612,9 @@ dependencies = [ "clap", "conduit_api", "conduit_core", - "conduit_database", "conduit_service", - "futures-util", "log", - "loole", - "regex", "ruma", - "serde", "serde_json", "serde_yaml", "tokio", @@ -633,6 +637,7 @@ dependencies = [ "futures-util", "hmac", "http 1.1.0", + "http-body-util", "hyper 1.4.0", "image", "ipaddress", @@ -645,7 +650,6 @@ dependencies = [ "serde_html_form", "serde_json", "sha-1", - "thiserror", "tokio", "tracing", "webpage", @@ -658,6 +662,7 @@ dependencies = [ "argon2", "axum 0.7.5", "bytes", + "checked_ops", "chrono", "either", "figment", @@ -696,7 +701,6 @@ version = "0.4.5" dependencies = [ "conduit_core", "log", - "ruma", "rust-rocksdb-uwu", "tokio", "tracing", @@ -714,14 +718,12 @@ dependencies = [ "conduit_admin", "conduit_api", "conduit_core", - "conduit_database", "conduit_service", "http 1.1.0", "http-body-util", "hyper 1.4.0", "hyper-util", "log", - "regex", "ruma", "sd-notify", "sentry", diff --git a/Cargo.toml b/Cargo.toml index 29454e606ba53fd81522ac2afdca28036dd5329e..0d3d59fdb9dce683ab601ee199120f02b245dc6d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -411,6 +411,9 @@ rev = "de26100b0db03e419a3d8e1dd26895d170d1fe50" version = "0.29.4" default-features = false +[workspace.dependencies.checked_ops] +version = "0.1" + # # Patches @@ -727,7 +730,6 @@ nursery = "warn" ## some sadness missing_const_for_fn = { level = "allow", priority = 1 } # TODO -needless_collect = { level = "allow", priority = 1 } # TODO option_if_let_else = { level = "allow", priority = 1 } # TODO redundant_pub_crate = { level = "allow", priority = 1 } # TODO significant_drop_in_scrutinee = { level = "allow", priority = 1 } # TODO @@ -737,21 +739,14 @@ significant_drop_tightening = { level = "allow", priority = 1 } # TODO pedantic = "warn" ## some sadness -cast_possible_truncation = { level = "allow", priority = 1 } -cast_precision_loss = { level = "allow", priority = 1 } -cast_sign_loss = { level = "allow", priority = 1 } doc_markdown = { level = "allow", priority = 1 } -error_impl_error = { level = "allow", priority = 1 } -expect_used = { level = "allow", priority = 1 } +enum_glob_use = { level = "allow", priority = 1 } if_not_else = { level = "allow", priority = 1 } if_then_some_else_none = { level = "allow", priority = 1 } -implicit_return = { level = "allow", priority = 1 } inline_always = { level = "allow", priority = 1 } -map_err_ignore = { level = "allow", priority = 1 } missing_docs_in_private_items = { level = "allow", priority = 1 } missing_errors_doc = { level = "allow", priority = 1 } missing_panics_doc = { level = "allow", priority = 1 } -mod_module_files = { level = "allow", priority = 1 } module_name_repetitions = { level = "allow", priority = 1 } no_effect_underscore_binding = { level = "allow", priority = 1 } similar_names = { level = "allow", priority = 1 } @@ -765,8 +760,10 @@ perf = "warn" ################### #restriction = "warn" -#arithmetic_side_effects = "warn" # TODO -#as_conversions = "warn" # TODO +allow_attributes = "warn" +arithmetic_side_effects = "warn" +as_conversions = "warn" +as_underscore = "warn" assertions_on_result_states = "warn" dbg_macro = "warn" default_union_representation = "warn" @@ -780,7 +777,6 @@ fn_to_numeric_cast_any = "warn" format_push_string = "warn" get_unwrap = "warn" impl_trait_in_params = "warn" -let_underscore_must_use = "warn" let_underscore_untyped = "warn" lossy_float_literal = "warn" mem_forget = "warn" @@ -794,6 +790,7 @@ rest_pat_in_fully_bound_structs = "warn" semicolon_outside_block = "warn" str_to_string = "warn" string_lit_chars_any = "warn" +string_slice = "warn" string_to_string = "warn" suspicious_xor_used_as_pow = "warn" tests_outside_test_module = "warn" @@ -804,6 +801,7 @@ unnecessary_safety_doc = "warn" unnecessary_self_imports = "warn" unneeded_field_pattern = "warn" unseparated_literal_suffix = "warn" +#unwrap_used = "warn" # TODO verbose_file_reads = "warn" ################### diff --git a/src/admin/Cargo.toml b/src/admin/Cargo.toml index 97de55ff415c97a6e63a73b6156d991f389151d1..f84fbb928fff2439c56690378dfca6389bd89545 100644 --- a/src/admin/Cargo.toml +++ b/src/admin/Cargo.toml @@ -29,15 +29,10 @@ release_max_log_level = [ clap.workspace = true conduit-api.workspace = true conduit-core.workspace = true -conduit-database.workspace = true conduit-service.workspace = true -futures-util.workspace = true log.workspace = true -loole.workspace = true -regex.workspace = true ruma.workspace = true serde_json.workspace = true -serde.workspace = true serde_yaml.workspace = true tokio.workspace = true tracing-subscriber.workspace = true diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index 3efe283bf2d2a00d2624d0c2fef53a0bfdc3fff7..53009566df5ca64696b83749fedc09a68ce2a7eb 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -58,7 +58,7 @@ pub(super) async fn parse_pdu(body: Vec<&str>) -> Result<RoomMessageEventContent )); } - let string = body[1..body.len() - 1].join("\n"); + let string = body[1..body.len().saturating_sub(1)].join("\n"); match serde_json::from_str(&string) { Ok(value) => match ruma::signatures::reference_hash(&value, &RoomVersionId::V6) { Ok(hash) => { diff --git a/src/admin/handler.rs b/src/admin/handler.rs index 1466c04040329821900671a8c37d5c850742a6b6..22ec81c52d5774690d08d787b9f3843d0dcebf42 100644 --- a/src/admin/handler.rs +++ b/src/admin/handler.rs @@ -132,11 +132,13 @@ fn parse_admin_command(command_line: &str) -> Result<AdminCommand, String> { } fn complete_admin_command(mut cmd: clap::Command, line: &str) -> String { - let mut ret = Vec::<String>::new(); let argv = parse_command_line(line); + let mut ret = Vec::<String>::with_capacity(argv.len().saturating_add(1)); + 'token: for token in argv.into_iter().skip(1) { - let mut choice = Vec::new(); let cmd_ = cmd.clone(); + let mut choice = Vec::new(); + for sub in cmd_.get_subcommands() { let name = sub.get_name(); if *name == token { @@ -144,20 +146,20 @@ fn complete_admin_command(mut cmd: clap::Command, line: &str) -> String { ret.push(token); cmd.clone_from(sub); continue 'token; - } - if name.starts_with(&token) { + } else if name.starts_with(&token) { // partial match; add to choices choice.push(name); } } - if choice.is_empty() { - // Nothing found, return original string - ret.push(token); - } else if choice.len() == 1 { + if choice.len() == 1 { // One choice. Add extra space because it's complete - ret.push((*choice.first().expect("only choice")).to_owned()); + let choice = *choice.first().expect("only choice"); + ret.push(choice.to_owned()); ret.push(String::new()); + } else if choice.is_empty() { + // Nothing found, return original string + ret.push(token); } else { // Find the common prefix ret.push(common_prefix(&choice).into()); @@ -168,9 +170,8 @@ fn complete_admin_command(mut cmd: clap::Command, line: &str) -> String { } // Return from no completion. Needs a space though. - let mut ret = ret.join(" "); - ret.push(' '); - ret + ret.push(String::new()); + ret.join(" ") } // Parse chat messages from the admin room into an AdminCommand object diff --git a/src/admin/mod.rs b/src/admin/mod.rs index f2e35d80f9e497b997cd46abb72c39023667b375..6a47bc7450140abb9e236a2478f432e7f19682ed 100644 --- a/src/admin/mod.rs +++ b/src/admin/mod.rs @@ -28,7 +28,6 @@ mod_dtor! {} /// Install the admin command handler -#[allow(clippy::let_underscore_must_use)] pub async fn init() { _ = services() .admin @@ -45,7 +44,6 @@ pub async fn init() { } /// Uninstall the admin command handler -#[allow(clippy::let_underscore_must_use)] pub async fn fini() { _ = services().admin.handle.write().await.take(); _ = services() diff --git a/src/admin/room/room_moderation_commands.rs b/src/admin/room/room_moderation_commands.rs index 39eb7c47574111e21f65d3f439c2e548b538553c..30c30c6e62fcec89f9f9addccc8bcb0d8032ecef 100644 --- a/src/admin/room/room_moderation_commands.rs +++ b/src/admin/room/room_moderation_commands.rs @@ -1,7 +1,5 @@ use api::client::leave_room; -use ruma::{ - events::room::message::RoomMessageEventContent, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, RoomOrAliasId, -}; +use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomAliasId, RoomId, RoomOrAliasId}; use tracing::{debug, error, info, warn}; use super::{super::Service, RoomModerationCommand}; @@ -124,9 +122,7 @@ async fn ban_room( .is_admin(local_user) .unwrap_or(true)) }) - }) - .collect::<Vec<OwnedUserId>>() - { + }) { debug!( "Attempting leave for user {} in room {} (forced, ignoring all errors, evicting admins too)", &local_user, &room_id @@ -153,9 +149,7 @@ async fn ban_room( .is_admin(local_user) .unwrap_or(false)) }) - }) - .collect::<Vec<OwnedUserId>>() - { + }) { debug!("Attempting leave for user {} in room {}", &local_user, &room_id); if let Err(e) = leave_room(&local_user, &room_id, None).await { error!( @@ -191,7 +185,10 @@ async fn ban_list_of_rooms(body: Vec<&str>, force: bool, disable_federation: boo )); } - let rooms_s = body.clone().drain(1..body.len() - 1).collect::<Vec<_>>(); + let rooms_s = body + .clone() + .drain(1..body.len().saturating_sub(1)) + .collect::<Vec<_>>(); let admin_room_alias = &services().globals.admin_alias; @@ -332,9 +329,7 @@ async fn ban_list_of_rooms(body: Vec<&str>, force: bool, disable_federation: boo .is_admin(local_user) .unwrap_or(true)) }) - }) - .collect::<Vec<OwnedUserId>>() - { + }) { debug!( "Attempting leave for user {} in room {} (forced, ignoring all errors, evicting admins too)", &local_user, room_id @@ -361,9 +356,7 @@ async fn ban_list_of_rooms(body: Vec<&str>, force: bool, disable_federation: boo .is_admin(local_user) .unwrap_or(false)) }) - }) - .collect::<Vec<OwnedUserId>>() - { + }) { debug!("Attempting leave for user {} in room {}", &local_user, &room_id); if let Err(e) = leave_room(&local_user, &room_id, None).await { error!( diff --git a/src/admin/server/commands.rs b/src/admin/server/commands.rs index 2ab710339768d254e6c7d596eef718b87e3ee739..06007fb82de497dfc05fa37e3d3506ec757065c6 100644 --- a/src/admin/server/commands.rs +++ b/src/admin/server/commands.rs @@ -1,24 +1,17 @@ -use conduit::{warn, Error, Result}; +use conduit::{utils::time, warn, Error, Result}; use ruma::events::room::message::RoomMessageEventContent; use crate::services; pub(super) async fn uptime(_body: Vec<&str>) -> Result<RoomMessageEventContent> { - let seconds = services() + let elapsed = services() .server .started .elapsed() - .expect("standard duration") - .as_secs(); - let result = format!( - "up {} days, {} hours, {} minutes, {} seconds.", - seconds / 86400, - (seconds % 86400) / 60 / 60, - (seconds % 3600) / 60, - seconds % 60, - ); - - Ok(RoomMessageEventContent::notice_plain(result)) + .expect("standard duration"); + + let result = time::pretty(elapsed); + Ok(RoomMessageEventContent::notice_plain(format!("{result}."))) } pub(super) async fn show_config(_body: Vec<&str>) -> Result<RoomMessageEventContent> { diff --git a/src/admin/user/commands.rs b/src/admin/user/commands.rs index 6dc60713a6f521d5935459db2ae602110f0b4153..884e1d29cacbfda9c37f8fd65e527b19ffc57d4c 100644 --- a/src/admin/user/commands.rs +++ b/src/admin/user/commands.rs @@ -23,7 +23,7 @@ pub(super) async fn list(_body: Vec<&str>) -> Result<RoomMessageEventContent> { match services().users.list_local_users() { Ok(users) => { let mut plain_msg = format!("Found {} local user account(s):\n```\n", users.len()); - plain_msg += &users.join("\n"); + plain_msg += users.join("\n").as_str(); plain_msg += "\n```"; Ok(RoomMessageEventContent::notice_markdown(plain_msg)) @@ -195,7 +195,10 @@ pub(super) async fn deactivate_all( )); } - let usernames = body.clone().drain(1..body.len() - 1).collect::<Vec<_>>(); + let usernames = body + .clone() + .drain(1..body.len().saturating_sub(1)) + .collect::<Vec<_>>(); let mut user_ids: Vec<OwnedUserId> = Vec::with_capacity(usernames.len()); let mut admins = Vec::new(); diff --git a/src/api/Cargo.toml b/src/api/Cargo.toml index a27924c634fc2c13356e30e52dbc2a8d39be008e..45cae73d42e720446124f024231ad9205ff12a4f 100644 --- a/src/api/Cargo.toml +++ b/src/api/Cargo.toml @@ -44,6 +44,7 @@ conduit-service.workspace = true futures-util.workspace = true hmac.workspace = true http.workspace = true +http-body-util.workspace = true hyper.workspace = true image.workspace = true ipaddress.workspace = true @@ -56,7 +57,6 @@ serde_html_form.workspace = true serde_json.workspace = true serde.workspace = true sha-1.workspace = true -thiserror.workspace = true tokio.workspace = true tracing.workspace = true webpage.workspace = true diff --git a/src/api/client/keys.rs b/src/api/client/keys.rs index 6f08987524d4f7ae459ae95fc10ad1e5aed8e66e..7bb02a606500daa54a4a35993064784e363bd9fd 100644 --- a/src/api/client/keys.rs +++ b/src/api/client/keys.rs @@ -1,9 +1,9 @@ use std::{ - cmp, collections::{hash_map, BTreeMap, HashMap, HashSet}, - time::{Duration, Instant}, + time::Instant, }; +use conduit::{utils, utils::math::continue_exponential_backoff_secs, Error, Result}; use futures_util::{stream::FuturesUnordered, StreamExt}; use ruma::{ api::{ @@ -18,15 +18,11 @@ DeviceKeyAlgorithm, OwnedDeviceId, OwnedUserId, UserId, }; use serde_json::json; +use service::user_is_local; use tracing::debug; use super::SESSION_ID_LENGTH; -use crate::{ - service::user_is_local, - services, - utils::{self}, - Error, Result, Ruma, -}; +use crate::{services, Ruma}; /// # `POST /_matrix/client/r0/keys/upload` /// @@ -357,11 +353,10 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>( .get(server) { // Exponential backoff - const MAX_DURATION: Duration = Duration::from_secs(60 * 60 * 24); - let min_elapsed_duration = cmp::min(MAX_DURATION, Duration::from_secs(5 * 60) * (*tries) * (*tries)); - - if time.elapsed() < min_elapsed_duration { - debug!("Backing off query from {:?}", server); + const MIN: u64 = 5 * 60; + const MAX: u64 = 60 * 60 * 24; + if continue_exponential_backoff_secs(MIN, MAX, time.elapsed(), *tries) { + debug!("Backing off query from {server:?}"); return (server, Err(Error::BadServerResponse("bad query, still backing off"))); } } diff --git a/src/api/client/media.rs b/src/api/client/media.rs index cfb10c3dd611f1c69393749b2768a15707f09b07..44978caf2a934de12cc533093d66ad961cd1a8af 100644 --- a/src/api/client/media.rs +++ b/src/api/client/media.rs @@ -2,6 +2,7 @@ use std::{io::Cursor, sync::Arc, time::Duration}; +use conduit::{debug, error, utils::math::ruma_from_usize, warn}; use image::io::Reader as ImgReader; use ipaddress::IPAddress; use reqwest::Url; @@ -12,7 +13,6 @@ get_media_preview, }, }; -use tracing::{debug, error, warn}; use webpage::HTML; use crate::{ @@ -44,7 +44,7 @@ pub(crate) async fn get_media_config_route( _body: Ruma<get_media_config::v3::Request>, ) -> Result<get_media_config::v3::Response> { Ok(get_media_config::v3::Response { - upload_size: services().globals.max_request_size().into(), + upload_size: ruma_from_usize(services().globals.config.max_request_size), }) } diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 07a585fd29c48919c74270e0e70c6edc9658b185..0792ed89f5c9f9ed49bcf16bb13eb315d7c2b308 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -1,13 +1,16 @@ use std::{ - cmp, collections::{hash_map::Entry, BTreeMap, HashMap, HashSet}, net::IpAddr, sync::Arc, - time::{Duration, Instant}, + time::Instant, }; use axum_client_ip::InsecureClientIp; -use conduit::utils::mutex_map; +use conduit::{ + debug, error, info, trace, utils, + utils::{math::continue_exponential_backoff_secs, mutex_map}, + warn, Error, PduEvent, Result, +}; use ruma::{ api::{ client::{ @@ -35,7 +38,6 @@ use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use service::sending::convert_to_outgoing_federation_event; use tokio::sync::RwLock; -use tracing::{debug, error, info, trace, warn}; use crate::{ client::{update_avatar_url, update_displayname}, @@ -43,7 +45,7 @@ pdu::{gen_event_id_canonical_json, PduBuilder}, server_is_ours, user_is_local, }, - services, utils, Error, PduEvent, Result, Ruma, + services, Ruma, }; /// Checks if the room is banned in any way possible and the sender user is not @@ -1363,11 +1365,10 @@ pub async fn validate_and_add_event_id( .get(&event_id) { // Exponential backoff - const MAX_DURATION: Duration = Duration::from_secs(60 * 60 * 24); - let min_elapsed_duration = cmp::min(MAX_DURATION, Duration::from_secs(5 * 60) * (*tries) * (*tries)); - - if time.elapsed() < min_elapsed_duration { - debug!("Backing off from {}", event_id); + const MIN: u64 = 60 * 5; + const MAX: u64 = 60 * 60 * 24; + if continue_exponential_backoff_secs(MIN, MAX, time.elapsed(), *tries) { + debug!("Backing off from {event_id}"); return Err(Error::BadServerResponse("bad event, still backing off")); } } @@ -1681,8 +1682,7 @@ async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> { .filter_map(|event: serde_json::Value| event.get("sender").cloned()) .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()) - .collect::<HashSet<OwnedServerName>>(), + .map(|user| user.server_name().to_owned()), ); debug!("servers in remote_leave_room: {servers:?}"); diff --git a/src/api/client/space.rs b/src/api/client/space.rs index e00171c33bc50dcc30c867c4da574bed6be970fd..0cf1b1073de1dff25ee9f656e290ff6caf003210 100644 --- a/src/api/client/space.rs +++ b/src/api/client/space.rs @@ -47,7 +47,7 @@ pub(crate) async fn get_hierarchy_route(body: Ruma<get_hierarchy::v1::Request>) &body.room_id, limit.try_into().unwrap_or(10), key.map_or(vec![], |token| token.short_room_ids), - max_depth.try_into().unwrap_or(3), + max_depth.into(), body.suggested_only, ) .await diff --git a/src/api/client/sync.rs b/src/api/client/sync.rs index 2ea766a4a2d50980b830422a3f81ccd5a96ff528..45a1c75b3eae422cb9ad3588fc34707b9e787d7e 100644 --- a/src/api/client/sync.rs +++ b/src/api/client/sync.rs @@ -1,10 +1,15 @@ use std::{ + cmp, cmp::Ordering, collections::{hash_map::Entry, BTreeMap, BTreeSet, HashMap, HashSet}, time::Duration, }; -use conduit::PduCount; +use conduit::{ + error, + utils::math::{ruma_from_u64, ruma_from_usize, usize_from_ruma, usize_from_u64_truncated}, + PduCount, +}; use ruma::{ api::client::{ filter::{FilterDefinition, LazyLoadOptions}, @@ -27,7 +32,7 @@ serde::Raw, uint, DeviceId, EventId, OwnedUserId, RoomId, UInt, UserId, }; -use tracing::{error, Instrument as _, Span}; +use tracing::{Instrument as _, Span}; use crate::{service::pdu::EventHash, services, utils, Error, PduEvent, Result, Ruma, RumaResponse}; @@ -298,15 +303,9 @@ pub(crate) async fn sync_events_route( { // Hang a few seconds so requests are not spammed // Stop hanging if new info arrives - let mut duration = body.timeout.unwrap_or_default(); - if duration.as_secs() > 30 { - duration = Duration::from_secs(30); - } - - #[allow(clippy::let_underscore_must_use)] - { - _ = tokio::time::timeout(duration, watcher).await; - } + let default = Duration::from_secs(30); + let duration = cmp::min(body.timeout.unwrap_or(default), default); + _ = tokio::time::timeout(duration, watcher).await; } Ok(response) @@ -975,8 +974,8 @@ async fn load_joined_room( }, summary: RoomSummary { heroes, - joined_member_count: joined_member_count.map(|n| (n as u32).into()), - invited_member_count: invited_member_count.map(|n| (n as u32).into()), + joined_member_count: joined_member_count.map(ruma_from_u64), + invited_member_count: invited_member_count.map(ruma_from_u64), }, unread_notifications: UnreadNotificationsCount { highlight_count, @@ -1026,7 +1025,7 @@ fn load_timeline( // Take the last events for the timeline timeline_pdus = non_timeline_pdus .by_ref() - .take(limit as usize) + .take(usize_from_u64_truncated(limit)) .collect::<Vec<_>>() .into_iter() .rev() @@ -1300,7 +1299,7 @@ pub(crate) async fn sync_events_v4_route( r.0, UInt::try_from(all_joined_rooms.len().saturating_sub(1)).unwrap_or(UInt::MAX), ); - let room_ids = all_joined_rooms[(u64::from(r.0) as usize)..=(u64::from(r.1) as usize)].to_vec(); + let room_ids = all_joined_rooms[usize_from_ruma(r.0)..=usize_from_ruma(r.1)].to_vec(); new_known_rooms.extend(room_ids.iter().cloned()); for room_id in &room_ids { let todo_room = todo_rooms @@ -1333,7 +1332,7 @@ pub(crate) async fn sync_events_v4_route( } }) .collect(), - count: UInt::from(all_joined_rooms.len() as u32), + count: ruma_from_usize(all_joined_rooms.len()), }, ); @@ -1529,20 +1528,22 @@ pub(crate) async fn sync_events_v4_route( prev_batch, limited, joined_count: Some( - (services() + services() .rooms .state_cache .room_joined_count(room_id)? - .unwrap_or(0) as u32) - .into(), + .unwrap_or(0) + .try_into() + .unwrap_or_else(|_| uint!(0)), ), invited_count: Some( - (services() + services() .rooms .state_cache .room_invited_count(room_id)? - .unwrap_or(0) as u32) - .into(), + .unwrap_or(0) + .try_into() + .unwrap_or_else(|_| uint!(0)), ), num_live: None, // Count events in timeline greater than global sync counter timestamp: None, @@ -1557,14 +1558,9 @@ pub(crate) async fn sync_events_v4_route( { // Hang a few seconds so requests are not spammed // Stop hanging if new info arrives - let mut duration = body.timeout.unwrap_or(Duration::from_secs(30)); - if duration.as_secs() > 30 { - duration = Duration::from_secs(30); - } - #[allow(clippy::let_underscore_must_use)] - { - _ = tokio::time::timeout(duration, watcher).await; - } + let default = Duration::from_secs(30); + let duration = cmp::min(body.timeout.unwrap_or(default), default); + _ = tokio::time::timeout(duration, watcher).await; } Ok(sync_events::v4::Response { diff --git a/src/api/client/unstable.rs b/src/api/client/unstable.rs index 77cac0fa58e9fe261acea4b30094acd57958946c..e39db94e864865a7391119ff3c5aa2063bac7cb1 100644 --- a/src/api/client/unstable.rs +++ b/src/api/client/unstable.rs @@ -1,12 +1,12 @@ use axum_client_ip::InsecureClientIp; -use conduit::{warn, RumaResponse}; +use conduit::warn; use ruma::{ api::client::{error::ErrorKind, membership::mutual_rooms, room::get_summary}, events::room::member::MembershipState, OwnedRoomId, }; -use crate::{services, Error, Result, Ruma}; +use crate::{services, Error, Result, Ruma, RumaResponse}; /// # `GET /_matrix/client/unstable/uk.half-shot.msc2666/user/mutual_rooms` /// diff --git a/src/api/mod.rs b/src/api/mod.rs index 8e30a518468247375e0e01b56d396108a92f61a7..6adf2d393cc3949803e8f694f5855977b07382fe 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -9,7 +9,7 @@ pub(crate) use conduit::{debug_info, debug_warn, utils, Error, Result}; pub(crate) use service::{pdu::PduEvent, services, user_is_local}; -pub(crate) use crate::router::{Ruma, RumaResponse}; +pub(crate) use self::router::{Ruma, RumaResponse}; conduit::mod_ctor! {} conduit::mod_dtor! {} diff --git a/src/api/router/mod.rs b/src/api/router/mod.rs index 2c439d655297400b1e0e6b3a6299e7642a79ba84..c3e08c5bc13a5de75d7dfbed444fe59ffe2cb606 100644 --- a/src/api/router/mod.rs +++ b/src/api/router/mod.rs @@ -1,20 +1,20 @@ mod auth; mod handler; mod request; +mod response; use std::{mem, ops::Deref}; use axum::{async_trait, body::Body, extract::FromRequest}; use bytes::{BufMut, BytesMut}; -pub(super) use conduit::error::RumaResponse; use conduit::{debug, debug_warn, trace, warn}; use ruma::{ api::{client::error::ErrorKind, IncomingRequest}, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId, }; -pub(super) use self::handler::RouterExt; use self::{auth::Auth, request::Request}; +pub(super) use self::{handler::RouterExt, response::RumaResponse}; use crate::{service::appservice::RegistrationInfo, services, Error, Result}; /// Extractor for Ruma request structs diff --git a/src/api/router/request.rs b/src/api/router/request.rs index 59639eaa87b803b4ecafa0dcb842301fe504c417..56c766192521a251560394daf4dd7d66908a2fb8 100644 --- a/src/api/router/request.rs +++ b/src/api/router/request.rs @@ -29,12 +29,7 @@ pub(super) async fn from(request: hyper::Request<axum::body::Body>) -> Result<Re let query = serde_html_form::from_str(parts.uri.query().unwrap_or_default()) .map_err(|_| Error::BadRequest(ErrorKind::Unknown, "Failed to read query parameters"))?; - let max_body_size = services() - .globals - .config - .max_request_size - .try_into() - .expect("failed to convert max request size"); + let max_body_size = services().globals.config.max_request_size; let body = axum::body::to_bytes(body, max_body_size) .await diff --git a/src/api/router/response.rs b/src/api/router/response.rs new file mode 100644 index 0000000000000000000000000000000000000000..2aaa79faa86d6238dbbcece5a580e85334fd79da --- /dev/null +++ b/src/api/router/response.rs @@ -0,0 +1,24 @@ +use axum::response::{IntoResponse, Response}; +use bytes::BytesMut; +use conduit::{error, Error}; +use http::StatusCode; +use http_body_util::Full; +use ruma::api::{client::uiaa::UiaaResponse, OutgoingResponse}; + +pub(crate) struct RumaResponse<T>(pub(crate) T); + +impl From<Error> for RumaResponse<UiaaResponse> { + fn from(t: Error) -> Self { Self(t.into()) } +} + +impl<T: OutgoingResponse> IntoResponse for RumaResponse<T> { + fn into_response(self) -> Response { + self.0 + .try_into_http_response::<BytesMut>() + .inspect_err(|e| error!("response error: {e}")) + .map_or_else( + |_| StatusCode::INTERNAL_SERVER_ERROR.into_response(), + |r| r.map(BytesMut::freeze).map(Full::new).into_response(), + ) + } +} diff --git a/src/core/Cargo.toml b/src/core/Cargo.toml index 2c0fc47d0ebc36917438c22844fec3fcba6bce73..e47f673e46dd4a03acf2f7d9200ea137878f23d7 100644 --- a/src/core/Cargo.toml +++ b/src/core/Cargo.toml @@ -53,6 +53,7 @@ sha256_media = [] argon2.workspace = true axum.workspace = true bytes.workspace = true +checked_ops.workspace = true chrono.workspace = true either.workspace = true figment.workspace = true diff --git a/src/core/alloc/je.rs b/src/core/alloc/je.rs index 966cbde3f28fbf7eefef841024002b7c8c127685..5e3c361fa2730a8f7567f85786034bb79584b976 100644 --- a/src/core/alloc/je.rs +++ b/src/core/alloc/je.rs @@ -12,15 +12,24 @@ #[must_use] pub fn memory_usage() -> String { use mallctl::stats; - let allocated = stats::allocated::read().unwrap_or_default() as f64 / 1024.0 / 1024.0; - let active = stats::active::read().unwrap_or_default() as f64 / 1024.0 / 1024.0; - let mapped = stats::mapped::read().unwrap_or_default() as f64 / 1024.0 / 1024.0; - let metadata = stats::metadata::read().unwrap_or_default() as f64 / 1024.0 / 1024.0; - let resident = stats::resident::read().unwrap_or_default() as f64 / 1024.0 / 1024.0; - let retained = stats::retained::read().unwrap_or_default() as f64 / 1024.0 / 1024.0; + + let mibs = |input: Result<usize, mallctl::Error>| { + let input = input.unwrap_or_default(); + let kibs = input / 1024; + let kibs = u32::try_from(kibs).unwrap_or_default(); + let kibs = f64::from(kibs); + kibs / 1024.0 + }; + + let allocated = mibs(stats::allocated::read()); + let active = mibs(stats::active::read()); + let mapped = mibs(stats::mapped::read()); + let metadata = mibs(stats::metadata::read()); + let resident = mibs(stats::resident::read()); + let retained = mibs(stats::retained::read()); format!( - "allocated: {allocated:.2} MiB\n active: {active:.2} MiB\n mapped: {mapped:.2} MiB\n metadata: {metadata:.2} \ - MiB\n resident: {resident:.2} MiB\n retained: {retained:.2} MiB\n " + "allocated: {allocated:.2} MiB\nactive: {active:.2} MiB\nmapped: {mapped:.2} MiB\nmetadata: {metadata:.2} \ + MiB\nresident: {resident:.2} MiB\nretained: {retained:.2} MiB\n" ) } diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index b808f196484940d956111c5420d646c3f6b6f71e..f9a5ec4c70bbd30aa7b8991211daa4b7e9be91e8 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -114,7 +114,7 @@ pub struct Config { pub ip_lookup_strategy: u8, #[serde(default = "default_max_request_size")] - pub max_request_size: u32, + pub max_request_size: usize, #[serde(default = "default_max_fetch_prev_events")] pub max_fetch_prev_events: u16, @@ -930,7 +930,7 @@ fn default_dns_timeout() -> u64 { 10 } fn default_ip_lookup_strategy() -> u8 { 5 } -fn default_max_request_size() -> u32 { +fn default_max_request_size() -> usize { 20 * 1024 * 1024 // Default to 20 MB } diff --git a/src/core/config/proxy.rs b/src/core/config/proxy.rs index d823e5e48f574ba10a73dde5c63c250af34a40cd..48f883c6a8a94c3c1e5923aa11b0c9bef1d636bf 100644 --- a/src/core/config/proxy.rs +++ b/src/core/config/proxy.rs @@ -127,6 +127,7 @@ fn more_specific_than(&self, other: &Self) -> bool { impl std::str::FromStr for WildCardedDomain { type Err = std::convert::Infallible; + #[allow(clippy::string_slice)] fn from_str(s: &str) -> Result<Self, Self::Err> { // maybe do some domain validation? Ok(if s.starts_with("*.") { diff --git a/src/core/debug.rs b/src/core/debug.rs index 7522aa7755d28f50050fa70a134ed01d6f0d9542..1f855e520b2262a3b7fb0df49d7789664776bec9 100644 --- a/src/core/debug.rs +++ b/src/core/debug.rs @@ -14,9 +14,9 @@ macro_rules! debug_event { ( $level:expr, $($x:tt)+ ) => { if cfg!(debug_assertions) && cfg!(not(feature = "dev_release_log_level")) { - ::tracing::event!( $level, $($x)+ ); + ::tracing::event!( $level, $($x)+ ) } else { - ::tracing::debug!( $($x)+ ); + ::tracing::debug!( $($x)+ ) } } } @@ -27,7 +27,7 @@ macro_rules! debug_event { #[macro_export] macro_rules! debug_error { ( $($x:tt)+ ) => { - $crate::debug_event!(::tracing::Level::ERROR, $($x)+ ); + $crate::debug_event!(::tracing::Level::ERROR, $($x)+ ) } } @@ -37,7 +37,7 @@ macro_rules! debug_error { #[macro_export] macro_rules! debug_warn { ( $($x:tt)+ ) => { - $crate::debug_event!(::tracing::Level::WARN, $($x)+ ); + $crate::debug_event!(::tracing::Level::WARN, $($x)+ ) } } @@ -47,7 +47,7 @@ macro_rules! debug_warn { #[macro_export] macro_rules! debug_info { ( $($x:tt)+ ) => { - $crate::debug_event!(::tracing::Level::INFO, $($x)+ ); + $crate::debug_event!(::tracing::Level::INFO, $($x)+ ) } } diff --git a/src/core/error.rs b/src/core/error.rs index ea7acda5c6b228e09571f8105d92aac12d5add76..1959081a32aa40339b43b5947b09c2eb35911065 100644 --- a/src/core/error.rs +++ b/src/core/error.rs @@ -1,28 +1,16 @@ use std::{convert::Infallible, fmt}; -use axum::response::{IntoResponse, Response}; use bytes::BytesMut; use http::StatusCode; use http_body_util::Full; use ruma::{ - api::{ - client::{ - error::ErrorKind::{ - Forbidden, GuestAccessForbidden, LimitExceeded, MissingToken, NotFound, ThreepidAuthFailed, - ThreepidDenied, TooLarge, Unauthorized, Unknown, UnknownToken, Unrecognized, UserDeactivated, - WrongRoomKeysVersion, - }, - uiaa::{UiaaInfo, UiaaResponse}, - }, - OutgoingResponse, - }, + api::{client::uiaa::UiaaResponse, OutgoingResponse}, OwnedServerName, }; -use thiserror::Error; use crate::{debug_error, error}; -#[derive(Error)] +#[derive(thiserror::Error)] pub enum Error { // std #[error("{0}")] @@ -35,6 +23,12 @@ pub enum Error { FromUtf8Error(#[from] std::string::FromUtf8Error), #[error("{0}")] TryFromSliceError(#[from] std::array::TryFromSliceError), + #[error("{0}")] + TryFromIntError(#[from] std::num::TryFromIntError), + #[error("{0}")] + ParseIntError(#[from] std::num::ParseIntError), + #[error("{0}")] + ParseFloatError(#[from] std::num::ParseFloatError), // third-party #[error("Regex error: {0}")] @@ -49,9 +43,17 @@ pub enum Error { Extension(#[from] axum::extract::rejection::ExtensionRejection), #[error("{0}")] Path(#[from] axum::extract::rejection::PathRejection), + #[error("{0}")] + Http(#[from] http::Error), // ruma #[error("{0}")] + IntoHttpError(#[from] ruma::api::error::IntoHttpError), + #[error("{0}")] + RumaError(#[from] ruma::api::client::error::Error), + #[error("uiaa")] + Uiaa(ruma::api::client::uiaa::UiaaInfo), + #[error("{0}")] Mxid(#[from] ruma::IdParseError), #[error("{0}: {1}")] BadRequest(ruma::api::client::error::ErrorKind, &'static str), @@ -63,6 +65,8 @@ pub enum Error { InconsistentRoomState(&'static str, ruma::OwnedRoomId), // conduwuit + #[error("Arithmetic operation failed: {0}")] + Arithmetic(&'static str), #[error("There was a problem with your configuration: {0}")] BadConfig(String), #[error("{0}")] @@ -73,8 +77,6 @@ pub enum Error { BadServerResponse(&'static str), #[error("{0}")] Conflict(&'static str), // This is only needed for when a room alias already exists - #[error("uiaa")] - Uiaa(UiaaInfo), // unique / untyped #[error("{0}")] @@ -95,11 +97,10 @@ pub fn bad_config(message: &str) -> Self { /// Returns the Matrix error code / error kind #[inline] pub fn error_code(&self) -> ruma::api::client::error::ErrorKind { - if let Self::Federation(_, error) = self { - return error.error_kind().unwrap_or_else(|| &Unknown).clone(); - } + use ruma::api::client::error::ErrorKind::Unknown; match self { + Self::Federation(_, error) => ruma_error_kind(error).clone(), Self::BadRequest(kind, _) => kind.clone(), _ => Unknown, } @@ -108,116 +109,139 @@ pub fn error_code(&self) -> ruma::api::client::error::ErrorKind { /// Sanitizes public-facing errors that can leak sensitive information. pub fn sanitized_error(&self) -> String { match self { - Self::Database { - .. - } => String::from("Database error occurred."), - Self::Io { - .. - } => String::from("I/O error occurred."), + Self::Database(..) => String::from("Database error occurred."), + Self::Io(..) => String::from("I/O error occurred."), _ => self.to_string(), } } } -impl From<Infallible> for Error { - fn from(i: Infallible) -> Self { match i {} } +#[inline] +pub fn log(e: &Error) { + error!(?e); } -impl fmt::Debug for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{self}") } +#[inline] +pub fn debug_log(e: &Error) { + debug_error!(?e); } #[inline] -pub fn log(e: Error) { - error!("{e}"); +pub fn into_log(e: Error) { + error!(?e); drop(e); } #[inline] -pub fn debug_log(e: Error) { - debug_error!("{e}"); +pub fn into_debug_log(e: Error) { + debug_error!(?e); drop(e); } -#[derive(Clone)] -pub struct RumaResponse<T>(pub T); - -impl<T> From<T> for RumaResponse<T> { - fn from(t: T) -> Self { Self(t) } +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{self}") } } -impl From<Error> for RumaResponse<UiaaResponse> { - fn from(t: Error) -> Self { t.to_response() } +impl From<Infallible> for Error { + fn from(i: Infallible) -> Self { match i {} } } -impl Error { - pub fn to_response(&self) -> RumaResponse<UiaaResponse> { - use ruma::api::client::error::{Error as RumaError, ErrorBody}; +impl axum::response::IntoResponse for Error { + fn into_response(self) -> axum::response::Response { + let response: UiaaResponse = self.into(); + response + .try_into_http_response::<BytesMut>() + .inspect_err(|e| error!(?e)) + .map_or_else( + |_| StatusCode::INTERNAL_SERVER_ERROR.into_response(), + |r| r.map(BytesMut::freeze).map(Full::new).into_response(), + ) + } +} - if let Self::Uiaa(uiaainfo) = self { - return RumaResponse(UiaaResponse::AuthResponse(uiaainfo.clone())); +impl From<Error> for UiaaResponse { + fn from(error: Error) -> Self { + if let Error::Uiaa(uiaainfo) = error { + return Self::AuthResponse(uiaainfo); } - if let Self::Federation(origin, error) = self { - let mut error = error.clone(); - error.body = ErrorBody::Standard { - kind: error.error_kind().unwrap_or_else(|| &Unknown).clone(), - message: format!("Answer from {origin}: {error}"), - }; - return RumaResponse(UiaaResponse::MatrixError(error)); - } + let kind = match &error { + Error::Federation(_, ref error) | Error::RumaError(ref error) => ruma_error_kind(error), + Error::BadRequest(kind, _) => kind, + _ => &ruma::api::client::error::ErrorKind::Unknown, + }; + + let status_code = match &error { + Error::Federation(_, ref error) | Error::RumaError(ref error) => error.status_code, + Error::BadRequest(ref kind, _) => bad_request_code(kind), + Error::Conflict(_) => StatusCode::CONFLICT, + _ => StatusCode::INTERNAL_SERVER_ERROR, + }; - let message = format!("{self}"); - let (kind, status_code) = match self { - Self::BadRequest(kind, _) => ( - kind.clone(), - match kind { - WrongRoomKeysVersion { - .. - } - | Forbidden { - .. - } - | GuestAccessForbidden - | ThreepidAuthFailed - | UserDeactivated - | ThreepidDenied => StatusCode::FORBIDDEN, - Unauthorized - | UnknownToken { - .. - } - | MissingToken => StatusCode::UNAUTHORIZED, - NotFound | Unrecognized => StatusCode::NOT_FOUND, - LimitExceeded { - .. - } => StatusCode::TOO_MANY_REQUESTS, - TooLarge => StatusCode::PAYLOAD_TOO_LARGE, - _ => StatusCode::BAD_REQUEST, - }, - ), - Self::Conflict(_) => (Unknown, StatusCode::CONFLICT), - _ => (Unknown, StatusCode::INTERNAL_SERVER_ERROR), + let message = match &error { + Error::Federation(ref origin, ref error) => format!("Answer from {origin}: {error}"), + Error::RumaError(ref error) => ruma_error_message(error), + _ => format!("{error}"), }; - RumaResponse(UiaaResponse::MatrixError(RumaError { - body: ErrorBody::Standard { - kind, - message, - }, + let body = ruma::api::client::error::ErrorBody::Standard { + kind: kind.clone(), + message, + }; + + Self::MatrixError(ruma::api::client::error::Error { status_code, - })) + body, + }) } } -impl ::axum::response::IntoResponse for Error { - fn into_response(self) -> ::axum::response::Response { self.to_response().into_response() } -} +fn bad_request_code(kind: &ruma::api::client::error::ErrorKind) -> StatusCode { + use ruma::api::client::error::ErrorKind::*; + + match kind { + GuestAccessForbidden + | ThreepidAuthFailed + | UserDeactivated + | ThreepidDenied + | WrongRoomKeysVersion { + .. + } + | Forbidden { + .. + } => StatusCode::FORBIDDEN, -impl<T: OutgoingResponse> IntoResponse for RumaResponse<T> { - fn into_response(self) -> Response { - match self.0.try_into_http_response::<BytesMut>() { - Ok(res) => res.map(BytesMut::freeze).map(Full::new).into_response(), - Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(), + UnknownToken { + .. } + | MissingToken + | Unauthorized => StatusCode::UNAUTHORIZED, + + LimitExceeded { + .. + } => StatusCode::TOO_MANY_REQUESTS, + + TooLarge => StatusCode::PAYLOAD_TOO_LARGE, + + NotFound | Unrecognized => StatusCode::NOT_FOUND, + + _ => StatusCode::BAD_REQUEST, + } +} + +fn ruma_error_message(error: &ruma::api::client::error::Error) -> String { + if let ruma::api::client::error::ErrorBody::Standard { + message, + .. + } = &error.body + { + return message.to_string(); } + + format!("{error}") +} + +fn ruma_error_kind(e: &ruma::api::client::error::Error) -> &ruma::api::client::error::ErrorKind { + e.error_kind() + .unwrap_or(&ruma::api::client::error::ErrorKind::Unknown) } diff --git a/src/core/log/mod.rs b/src/core/log/mod.rs index daa6b8e8ebd1913abccfe234ff0513c1e32cffd2..04d250a6d9701e42652e87749afde8f3c1f99817 100644 --- a/src/core/log/mod.rs +++ b/src/core/log/mod.rs @@ -29,25 +29,25 @@ pub struct Log { #[macro_export] macro_rules! error { - ( $($x:tt)+ ) => { ::tracing::error!( $($x)+ ); } + ( $($x:tt)+ ) => { ::tracing::error!( $($x)+ ) } } #[macro_export] macro_rules! warn { - ( $($x:tt)+ ) => { ::tracing::warn!( $($x)+ ); } + ( $($x:tt)+ ) => { ::tracing::warn!( $($x)+ ) } } #[macro_export] macro_rules! info { - ( $($x:tt)+ ) => { ::tracing::info!( $($x)+ ); } + ( $($x:tt)+ ) => { ::tracing::info!( $($x)+ ) } } #[macro_export] macro_rules! debug { - ( $($x:tt)+ ) => { ::tracing::debug!( $($x)+ ); } + ( $($x:tt)+ ) => { ::tracing::debug!( $($x)+ ) } } #[macro_export] macro_rules! trace { - ( $($x:tt)+ ) => { ::tracing::trace!( $($x)+ ); } + ( $($x:tt)+ ) => { ::tracing::trace!( $($x)+ ) } } diff --git a/src/core/mod.rs b/src/core/mod.rs index ec536ee2695cd0cf7416d9545927497f40d8c9f7..de8057fadcd355aea6ee92cd833ac63a516c0007 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -10,7 +10,7 @@ pub mod version; pub use config::Config; -pub use error::{Error, RumaResponse}; +pub use error::Error; pub use pdu::{PduBuilder, PduCount, PduEvent}; pub use server::Server; pub use version::version; diff --git a/src/core/pdu/mod.rs b/src/core/pdu/mod.rs index c7d9360817a9d499f6a9edd5d74a314534a071bb..0c54812e87032a516adc0ed74e202108527b525b 100644 --- a/src/core/pdu/mod.rs +++ b/src/core/pdu/mod.rs @@ -64,7 +64,7 @@ pub struct PduEvent { } impl PduEvent { - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &Self) -> crate::Result<()> { self.unsigned = None; @@ -158,7 +158,7 @@ pub fn copy_redacts(&self) -> (Option<Arc<EventId>>, Box<RawJsonValue>) { (self.redacts.clone(), self.content.clone()) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn to_sync_room_event(&self) -> Raw<AnySyncTimelineEvent> { let (redacts, content) = self.copy_redacts(); let mut json = json!({ @@ -183,7 +183,7 @@ pub fn to_sync_room_event(&self) -> Raw<AnySyncTimelineEvent> { } /// This only works for events that are also AnyRoomEvents. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn to_any_event(&self) -> Raw<AnyEphemeralRoomEvent> { let (redacts, content) = self.copy_redacts(); let mut json = json!({ @@ -208,7 +208,7 @@ pub fn to_any_event(&self) -> Raw<AnyEphemeralRoomEvent> { serde_json::from_value(json).expect("Raw::from_value always works") } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn to_room_event(&self) -> Raw<AnyTimelineEvent> { let (redacts, content) = self.copy_redacts(); let mut json = json!({ @@ -233,7 +233,7 @@ pub fn to_room_event(&self) -> Raw<AnyTimelineEvent> { serde_json::from_value(json).expect("Raw::from_value always works") } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn to_message_like_event(&self) -> Raw<AnyMessageLikeEvent> { let (redacts, content) = self.copy_redacts(); let mut json = json!({ @@ -258,7 +258,7 @@ pub fn to_message_like_event(&self) -> Raw<AnyMessageLikeEvent> { serde_json::from_value(json).expect("Raw::from_value always works") } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn to_state_event(&self) -> Raw<AnyStateEvent> { let mut json = json!({ "content": self.content, @@ -277,7 +277,7 @@ pub fn to_state_event(&self) -> Raw<AnyStateEvent> { serde_json::from_value(json).expect("Raw::from_value always works") } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn to_sync_state_event(&self) -> Raw<AnySyncStateEvent> { let mut json = json!({ "content": self.content, @@ -295,7 +295,7 @@ pub fn to_sync_state_event(&self) -> Raw<AnySyncStateEvent> { serde_json::from_value(json).expect("Raw::from_value always works") } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn to_stripped_state_event(&self) -> Raw<AnyStrippedStateEvent> { let json = json!({ "content": self.content, @@ -307,7 +307,7 @@ pub fn to_stripped_state_event(&self) -> Raw<AnyStrippedStateEvent> { serde_json::from_value(json).expect("Raw::from_value always works") } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn to_stripped_spacechild_state_event(&self) -> Raw<HierarchySpaceChildEvent> { let json = json!({ "content": self.content, @@ -320,7 +320,7 @@ pub fn to_stripped_spacechild_state_event(&self) -> Raw<HierarchySpaceChildEvent serde_json::from_value(json).expect("Raw::from_value always works") } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn to_member_event(&self) -> Raw<StateEvent<RoomMemberEventContent>> { let mut json = json!({ "content": self.content, diff --git a/src/core/server.rs b/src/core/server.rs index e76e4d57e9d58143d6ee7f5369ef0b01570fe751..575924d3b86984361fc0097027cd2f86f580af45 100644 --- a/src/core/server.rs +++ b/src/core/server.rs @@ -76,7 +76,10 @@ pub fn reload(&self) -> Result<()> { return Err(Error::Err("Shutdown already in progress".into())); } - self.signal("SIGINT") + self.signal("SIGINT").inspect_err(|_| { + self.stopping.store(false, Ordering::Release); + self.reloading.store(false, Ordering::Release); + }) } pub fn restart(&self) -> Result<()> { @@ -85,6 +88,7 @@ pub fn restart(&self) -> Result<()> { } self.shutdown() + .inspect_err(|_| self.restarting.store(false, Ordering::Release)) } pub fn shutdown(&self) -> Result<()> { @@ -93,6 +97,7 @@ pub fn shutdown(&self) -> Result<()> { } self.signal("SIGTERM") + .inspect_err(|_| self.stopping.store(false, Ordering::Release)) } pub fn signal(&self, sig: &'static str) -> Result<()> { diff --git a/src/core/utils/content_disposition.rs b/src/core/utils/content_disposition.rs index 1c2b066dd10c1e166e6cd033c46b9dbb136dd6d3..be17a731cea8c9b85c672e8a441151d017a87fbb 100644 --- a/src/core/utils/content_disposition.rs +++ b/src/core/utils/content_disposition.rs @@ -66,7 +66,7 @@ pub fn content_disposition_type(content_type: &Option<String>) -> &'static str { /// sanitises the file name for the Content-Disposition using /// `sanitize_filename` crate -#[tracing::instrument] +#[tracing::instrument(level = "debug")] pub fn sanitise_filename(filename: String) -> String { let options = sanitize_filename::Options { truncate: false, diff --git a/src/core/utils/defer.rs b/src/core/utils/defer.rs index 2762d4facdb3652c12bffaed156f8e5fe0636150..9d42e6795bf68b2a263318b8475bfbbc59c88498 100644 --- a/src/core/utils/defer.rs +++ b/src/core/utils/defer.rs @@ -1,17 +1,11 @@ #[macro_export] macro_rules! defer { ($body:block) => { - struct _Defer_<F> - where - F: FnMut(), - { + struct _Defer_<F: FnMut()> { closure: F, } - impl<F> Drop for _Defer_<F> - where - F: FnMut(), - { + impl<F: FnMut()> Drop for _Defer_<F> { fn drop(&mut self) { (self.closure)(); } } diff --git a/src/core/utils/hash/sha256.rs b/src/core/utils/hash/sha256.rs index 6a1f18793de545a4b832ef7e225ad29ac7fe328b..b2e5a94c28222c259597daaef63ac565b0c5bee7 100644 --- a/src/core/utils/hash/sha256.rs +++ b/src/core/utils/hash/sha256.rs @@ -1,6 +1,6 @@ use ring::{digest, digest::SHA256}; -#[tracing::instrument(skip_all)] +#[tracing::instrument(skip_all, level = "debug")] pub(super) fn hash(keys: &[&[u8]]) -> Vec<u8> { // We only hash the pdu's event ids, not the whole pdu let bytes = keys.join(&0xFF); diff --git a/src/core/utils/html.rs b/src/core/utils/html.rs index 3b44a31b8baf090a816d8b447f26b0c6978d2343..fe07b2dd768ab37d913a90e955af223c6b4fe32c 100644 --- a/src/core/utils/html.rs +++ b/src/core/utils/html.rs @@ -6,6 +6,7 @@ /// Copied from librustdoc: /// * <https://github.com/rust-lang/rust/blob/cbaeec14f90b59a91a6b0f17fc046c66fa811892/src/librustdoc/html/escape.rs> +#[allow(clippy::string_slice)] impl fmt::Display for Escape<'_> { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { // Because the internet is always right, turns out there's not that many @@ -26,7 +27,7 @@ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.write_str(s)?; // NOTE: we only expect single byte characters here - which is fine as long as // we only match single byte characters - last = i + 1; + last = i.saturating_add(1); } if last < s.len() { diff --git a/src/core/utils/math.rs b/src/core/utils/math.rs new file mode 100644 index 0000000000000000000000000000000000000000..155721e7884d7892d6e373fefd994a813c7381f3 --- /dev/null +++ b/src/core/utils/math.rs @@ -0,0 +1,88 @@ +use std::{cmp, time::Duration}; + +pub use checked_ops::checked_ops; + +use crate::{Error, Result}; + +/// Checked arithmetic expression. Returns a Result<R, Error::Arithmetic> +#[macro_export] +macro_rules! checked { + ($($input:tt)*) => { + $crate::utils::math::checked_ops!($($input)*) + .ok_or_else(|| $crate::Error::Arithmetic("operation overflowed or result invalid")) + } +} + +/// in release-mode. Use for performance when the expression is obviously safe. +/// The check remains in debug-mode for regression analysis. +#[cfg(not(debug_assertions))] +#[macro_export] +macro_rules! validated { + ($($input:tt)*) => { + //#[allow(clippy::arithmetic_side_effects)] { + //Some($($input)*) + // .ok_or_else(|| $crate::Error::Arithmetic("this error should never been seen")) + //} + + //NOTE: remove me when stmt_expr_attributes is stable + $crate::checked!($($input)*) + } +} + +#[cfg(debug_assertions)] +#[macro_export] +macro_rules! validated { + ($($input:tt)*) => { $crate::checked!($($input)*) } +} + +/// Returns false if the exponential backoff has expired based on the inputs +#[inline] +#[must_use] +pub fn continue_exponential_backoff_secs(min: u64, max: u64, elapsed: Duration, tries: u32) -> bool { + let min = Duration::from_secs(min); + let max = Duration::from_secs(max); + continue_exponential_backoff(min, max, elapsed, tries) +} + +/// Returns false if the exponential backoff has expired based on the inputs +#[inline] +#[must_use] +pub fn continue_exponential_backoff(min: Duration, max: Duration, elapsed: Duration, tries: u32) -> bool { + let min = min.saturating_mul(tries).saturating_mul(tries); + let min = cmp::min(min, max); + elapsed < min +} + +#[inline] +#[allow(clippy::as_conversions)] +pub fn usize_from_f64(val: f64) -> Result<usize, Error> { + if val < 0.0 { + return Err(Error::Arithmetic("Converting negative float to unsigned integer")); + } + + //SAFETY: <https://doc.rust-lang.org/std/primitive.f64.html#method.to_int_unchecked> + Ok(unsafe { val.to_int_unchecked::<usize>() }) +} + +#[inline] +#[must_use] +pub fn usize_from_ruma(val: ruma::UInt) -> usize { + usize::try_from(val).expect("failed conversion from ruma::UInt to usize") +} + +#[inline] +#[must_use] +pub fn ruma_from_u64(val: u64) -> ruma::UInt { + ruma::UInt::try_from(val).expect("failed conversion from u64 to ruma::UInt") +} + +#[inline] +#[must_use] +pub fn ruma_from_usize(val: usize) -> ruma::UInt { + ruma::UInt::try_from(val).expect("failed conversion from usize to ruma::UInt") +} + +#[inline] +#[must_use] +#[allow(clippy::as_conversions, clippy::cast_possible_truncation)] +pub fn usize_from_u64_truncated(val: u64) -> usize { val as usize } diff --git a/src/core/utils/mod.rs b/src/core/utils/mod.rs index f9f1b87ee1898f3a375b59f593761bdab220566f..2b79c3c4288c5a9fb4139ebd0201665fffc489ce 100644 --- a/src/core/utils/mod.rs +++ b/src/core/utils/mod.rs @@ -5,6 +5,7 @@ pub mod hash; pub mod html; pub mod json; +pub mod math; pub mod mutex_map; pub mod rand; pub mod string; diff --git a/src/core/utils/rand.rs b/src/core/utils/rand.rs index 1ded8a6d88d99162257694ea548509877584a378..b80671eb90c00f35019a247e3610ab32aad3c020 100644 --- a/src/core/utils/rand.rs +++ b/src/core/utils/rand.rs @@ -15,7 +15,11 @@ pub fn string(length: usize) -> String { #[inline] #[must_use] -pub fn timepoint_secs(range: Range<u64>) -> SystemTime { SystemTime::now() + secs(range) } +pub fn timepoint_secs(range: Range<u64>) -> SystemTime { + SystemTime::now() + .checked_add(secs(range)) + .expect("range does not overflow SystemTime") +} #[must_use] pub fn secs(range: Range<u64>) -> Duration { diff --git a/src/core/utils/string.rs b/src/core/utils/string.rs index 1f2a65727f5b320a88a287b6177f675ba09f310a..ec423d53b3027f80451b95d92f7dcb586a220f4b 100644 --- a/src/core/utils/string.rs +++ b/src/core/utils/string.rs @@ -9,12 +9,13 @@ /// common_prefix(&input) == "con"; /// ``` #[must_use] +#[allow(clippy::string_slice)] pub fn common_prefix<'a>(choice: &'a [&str]) -> &'a str { choice.first().map_or(EMPTY, move |best| { choice.iter().skip(1).fold(*best, |best, choice| { &best[0..choice - .chars() - .zip(best.chars()) + .char_indices() + .zip(best.char_indices()) .take_while(|&(a, b)| a == b) .count()] }) diff --git a/src/core/utils/tests.rs b/src/core/utils/tests.rs index f5cd0a077ec8853584c3908d46a65e9fd8d26896..add15861fdb782635e1d0f369c255a729a97fcb4 100644 --- a/src/core/utils/tests.rs +++ b/src/core/utils/tests.rs @@ -62,3 +62,22 @@ fn common_prefix_none() { let output = string::common_prefix(&input); assert_eq!(output, ""); } + +#[test] +fn checked_add() { + use crate::checked; + + let a = 1234; + let res = checked!(a + 1).unwrap(); + assert_eq!(res, 1235); +} + +#[test] +#[should_panic(expected = "overflow")] +fn checked_add_overflow() { + use crate::checked; + + let a = u64::MAX; + let res = checked!(a + 1).expect("overflow"); + assert_eq!(res, 0); +} diff --git a/src/core/utils/time.rs b/src/core/utils/time.rs index 7de00e9e307e71f85bee370225e7e5dce3b8833c..9a31632e252e621ea4ed0882daac44fb1d7be372 100644 --- a/src/core/utils/time.rs +++ b/src/core/utils/time.rs @@ -1,8 +1,8 @@ -use std::time::{SystemTime, UNIX_EPOCH}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; #[inline] #[must_use] -#[allow(clippy::as_conversions)] +#[allow(clippy::as_conversions, clippy::cast_possible_truncation)] pub fn now_millis() -> u64 { UNIX_EPOCH .elapsed() @@ -26,3 +26,80 @@ pub fn format(ts: SystemTime, str: &str) -> String { let dt: DateTime<Utc> = ts.into(); dt.format(str).to_string() } + +#[must_use] +#[allow(clippy::as_conversions, clippy::cast_possible_truncation, clippy::cast_sign_loss)] +pub fn pretty(d: Duration) -> String { + use Unit::*; + + let fmt = |w, f, u| format!("{w}.{f} {u}"); + let gen64 = |w, f, u| fmt(w, (f * 100.0) as u32, u); + let gen128 = |w, f, u| gen64(u64::try_from(w).expect("u128 to u64"), f, u); + match whole_and_frac(d) { + (Days(whole), frac) => gen64(whole, frac, "days"), + (Hours(whole), frac) => gen64(whole, frac, "hours"), + (Mins(whole), frac) => gen64(whole, frac, "minutes"), + (Secs(whole), frac) => gen64(whole, frac, "seconds"), + (Millis(whole), frac) => gen128(whole, frac, "milliseconds"), + (Micros(whole), frac) => gen128(whole, frac, "microseconds"), + (Nanos(whole), frac) => gen128(whole, frac, "nanoseconds"), + } +} + +/// Return a pair of (whole part, frac part) from a duration where. The whole +/// part is the largest Unit containing a non-zero value, the frac part is a +/// rational remainder left over. +#[must_use] +#[allow(clippy::as_conversions, clippy::cast_precision_loss)] +pub fn whole_and_frac(d: Duration) -> (Unit, f64) { + use Unit::*; + + let whole = whole_unit(d); + ( + whole, + match whole { + Days(_) => (d.as_secs() % 86_400) as f64 / 86_400.0, + Hours(_) => (d.as_secs() % 3_600) as f64 / 3_600.0, + Mins(_) => (d.as_secs() % 60) as f64 / 60.0, + Secs(_) => f64::from(d.subsec_millis()) / 1000.0, + Millis(_) => f64::from(d.subsec_micros()) / 1000.0, + Micros(_) => f64::from(d.subsec_nanos()) / 1000.0, + Nanos(_) => 0.0, + }, + ) +} + +/// Return the largest Unit which represents the duration. The value is +/// rounded-down, but never zero. +#[must_use] +pub fn whole_unit(d: Duration) -> Unit { + use Unit::*; + + match d.as_secs() { + 86_400.. => Days(d.as_secs() / 86_400), + 3_600..=86_399 => Hours(d.as_secs() / 3_600), + 60..=3_599 => Mins(d.as_secs() / 60), + + _ => match d.as_micros() { + 1_000_000.. => Secs(d.as_secs()), + 1_000..=999_999 => Millis(d.subsec_millis().into()), + + _ => match d.as_nanos() { + 1_000.. => Micros(d.subsec_micros().into()), + + _ => Nanos(d.subsec_nanos().into()), + }, + }, + } +} + +#[derive(Eq, PartialEq, Clone, Copy, Debug)] +pub enum Unit { + Days(u64), + Hours(u64), + Mins(u64), + Secs(u64), + Millis(u128), + Micros(u128), + Nanos(u128), +} diff --git a/src/database/Cargo.toml b/src/database/Cargo.toml index 6e95236b0a29ed468fa7a0a42694a8c9e42ae8aa..8b0d3fc30a6512a2fcc2aa44d41eefd1fbd220cd 100644 --- a/src/database/Cargo.toml +++ b/src/database/Cargo.toml @@ -37,7 +37,6 @@ zstd_compression = [ [dependencies] conduit-core.workspace = true log.workspace = true -ruma.workspace = true rust-rocksdb.workspace = true tokio.workspace = true tracing.workspace = true diff --git a/src/database/database.rs b/src/database/database.rs index 06a2d88fa4f3d70ffc768ce24aaf17ea67e936f7..44bb655ccaa9fb38f90bf2fadac9dbeb7a0402f4 100644 --- a/src/database/database.rs +++ b/src/database/database.rs @@ -6,7 +6,7 @@ pub struct Database { pub db: Arc<Engine>, - pub map: Maps, + map: Maps, } impl Database { diff --git a/src/database/engine.rs b/src/database/engine.rs index 7c9522e1eb9721f5fb8b8950bc1a086e9b680658..d02ecf58d0247465be65d96136dc6c06db2f7e3d 100644 --- a/src/database/engine.rs +++ b/src/database/engine.rs @@ -30,6 +30,7 @@ pub struct Engine { pub(crate) type Db = DBWithThreadMode<MultiThreaded>; impl Engine { + #[tracing::instrument(skip_all)] pub(crate) fn open(server: &Arc<Server>) -> Result<Arc<Self>> { let config = &server.config; let cache_capacity_bytes = config.db_cache_capacity_mb * 1024.0 * 1024.0; @@ -51,7 +52,7 @@ pub(crate) fn open(server: &Arc<Server>) -> Result<Arc<Self>> { if config.rocksdb_repair { warn!("Starting database repair. This may take a long time..."); if let Err(e) = Db::repair(&db_opts, &config.database_path) { - error!("Repair failed: {:?}", e); + error!("Repair failed: {e:?}"); } } @@ -76,9 +77,9 @@ pub(crate) fn open(server: &Arc<Server>) -> Result<Arc<Self>> { let db = res.or_else(or_else)?; info!( - "Opened database at sequence number {} in {:?}", - db.latest_sequence_number(), - load_time.elapsed() + sequence = %db.latest_sequence_number(), + time = ?load_time.elapsed(), + "Opened database." ); Ok(Arc::new(Self { @@ -93,15 +94,16 @@ pub(crate) fn open(server: &Arc<Server>) -> Result<Arc<Self>> { })) } + #[tracing::instrument(skip(self))] pub(crate) fn open_cf(&self, name: &str) -> Result<Arc<BoundColumnFamily<'_>>> { let mut cfs = self.cfs.lock().expect("locked"); if !cfs.contains(name) { - debug!("Creating new column family in database: {}", name); + debug!("Creating new column family in database: {name}"); let mut col_cache = self.col_cache.write().expect("locked"); let opts = cf_options(&self.server.config, name, self.opts.clone(), &mut col_cache); if let Err(e) = self.db.create_cf(name, &opts) { - error!("Failed to create new column family: {e}"); + error!(?name, "Failed to create new column family: {e}"); return or_else(e); } @@ -134,34 +136,34 @@ pub(crate) fn uncork(&self) { .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); } - #[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)] pub fn memory_usage(&self) -> Result<String> { let mut res = String::new(); let stats = get_memory_usage_stats(Some(&[&self.db]), Some(&[&self.row_cache])).or_else(or_else)?; + let mibs = |input| f64::from(u32::try_from(input / 1024).unwrap_or(0)) / 1024.0; writeln!( res, "Memory buffers: {:.2} MiB\nPending write: {:.2} MiB\nTable readers: {:.2} MiB\nRow cache: {:.2} MiB", - stats.mem_table_total as f64 / 1024.0 / 1024.0, - stats.mem_table_unflushed as f64 / 1024.0 / 1024.0, - stats.mem_table_readers_total as f64 / 1024.0 / 1024.0, - self.row_cache.get_usage() as f64 / 1024.0 / 1024.0, - ) - .expect("should be able to write to string buffer"); + mibs(stats.mem_table_total), + mibs(stats.mem_table_unflushed), + mibs(stats.mem_table_readers_total), + mibs(u64::try_from(self.row_cache.get_usage())?), + )?; for (name, cache) in &*self.col_cache.read().expect("locked") { - writeln!(res, "{} cache: {:.2} MiB", name, cache.get_usage() as f64 / 1024.0 / 1024.0,) - .expect("should be able to write to string buffer"); + writeln!(res, "{name} cache: {:.2} MiB", mibs(u64::try_from(cache.get_usage())?))?; } Ok(res) } + #[tracing::instrument(skip(self), level = "debug")] pub fn cleanup(&self) -> Result<()> { debug!("Running flush_opt"); let flushoptions = rocksdb::FlushOptions::default(); result(DBCommon::flush_opt(&self.db, &flushoptions)) } + #[tracing::instrument(skip(self))] pub fn backup(&self) -> Result<(), Box<dyn std::error::Error>> { let config = &self.server.config; let path = config.database_backup_path.as_ref(); @@ -214,8 +216,7 @@ pub fn backup_list(&self) -> Result<String> { rfc2822_from_seconds(info.timestamp), info.size, info.num_files, - ) - .expect("should be able to write to string buffer"); + )?; } Ok(res) @@ -226,16 +227,16 @@ pub fn file_list(&self) -> Result<String> { Err(e) => Ok(String::from(e)), Ok(files) => { let mut res = String::new(); - writeln!(res, "| lev | sst | keys | dels | size | column |").expect("written to string buffer"); - writeln!(res, "| ---: | :--- | ---: | ---: | ---: | :--- |").expect("written to string buffer"); + writeln!(res, "| lev | sst | keys | dels | size | column |")?; + writeln!(res, "| ---: | :--- | ---: | ---: | ---: | :--- |")?; for file in files { writeln!( res, "| {} | {:<13} | {:7}+ | {:4}- | {:9} | {} |", file.level, file.name, file.num_entries, file.num_deletions, file.size, file.column_family_name, - ) - .expect("should be able to writeln to string buffer"); + )?; } + Ok(res) }, } diff --git a/src/database/map.rs b/src/database/map.rs index 0b00730723b2802d7c899c5ae166a5b93b2f3237..1b35a72aa04c99e19de3a21dd7952e6327dd1faa 100644 --- a/src/database/map.rs +++ b/src/database/map.rs @@ -233,7 +233,7 @@ fn open(db: &Arc<Engine>, name: &str) -> Result<Arc<ColumnFamily>> { // closing the database (dropping `Engine`). Since `Arc<Engine>` is a sibling // member along with this handle in `Map`, that is prevented. Ok(unsafe { - Arc::decrement_strong_count(cf_ptr); + Arc::increment_strong_count(cf_ptr); Arc::from_raw(cf_ptr) }) } diff --git a/src/database/maps.rs b/src/database/maps.rs index 1e09041c06a8f957e60df376f3f5c852889753d7..de78eaed53d929eda91671113d02224920d34d64 100644 --- a/src/database/maps.rs +++ b/src/database/maps.rs @@ -8,6 +8,7 @@ pub(crate) fn open(db: &Arc<Engine>) -> Result<Maps> { open_list(db, MAPS) } +#[tracing::instrument(skip_all, level = "debug")] pub(crate) fn open_list(db: &Arc<Engine>, maps: &[&str]) -> Result<Maps> { Ok(maps .iter() diff --git a/src/router/Cargo.toml b/src/router/Cargo.toml index ecb09dd718edf9dd6c296e65b5302ed9dd080dd7..5312984acaf2d2bc620a981b695a44abd90098c4 100644 --- a/src/router/Cargo.toml +++ b/src/router/Cargo.toml @@ -54,7 +54,6 @@ axum.workspace = true conduit-admin.workspace = true conduit-api.workspace = true conduit-core.workspace = true -conduit-database.workspace = true conduit-service.workspace = true log.workspace = true tokio.workspace = true @@ -65,7 +64,6 @@ http-body-util.workspace = true http.workspace = true hyper.workspace = true hyper-util.workspace = true -regex.workspace = true ruma.workspace = true sentry.optional = true sentry-tower.optional = true diff --git a/src/router/layers.rs b/src/router/layers.rs index 4fe3516449b489f996938295b6497bf38a7ee582..db664b38aa4779e2109c13a8f20f4ca3cd6b14f1 100644 --- a/src/router/layers.rs +++ b/src/router/layers.rs @@ -138,15 +138,7 @@ fn cors_layer(_server: &Server) -> CorsLayer { .max_age(Duration::from_secs(86400)) } -fn body_limit_layer(server: &Server) -> DefaultBodyLimit { - DefaultBodyLimit::max( - server - .config - .max_request_size - .try_into() - .expect("failed to convert max request size"), - ) -} +fn body_limit_layer(server: &Server) -> DefaultBodyLimit { DefaultBodyLimit::max(server.config.max_request_size) } #[allow(clippy::needless_pass_by_value)] #[tracing::instrument(skip_all)] diff --git a/src/router/request.rs b/src/router/request.rs index 2bdd06f8eb6a626c11c4d7aaec570ae4bc632957..9256fb9c98e7db340a4f8f7a5c740595248192bb 100644 --- a/src/router/request.rs +++ b/src/router/request.rs @@ -1,15 +1,11 @@ use std::sync::{atomic::Ordering, Arc}; use axum::{extract::State, response::IntoResponse}; -use conduit::{debug_error, debug_warn, defer, Result, RumaResponse, Server}; +use conduit::{debug, debug_error, debug_warn, defer, error, trace, Error, Result, Server}; use http::{Method, StatusCode, Uri}; -use ruma::api::client::{ - error::{Error as RumaError, ErrorBody, ErrorKind}, - uiaa::UiaaResponse, -}; -use tracing::{debug, error, trace}; +use ruma::api::client::error::{Error as RumaError, ErrorBody, ErrorKind}; -#[tracing::instrument(skip_all)] +#[tracing::instrument(skip_all, level = "debug")] pub(crate) async fn spawn( State(server): State<Arc<Server>>, req: http::Request<axum::body::Body>, next: axum::middleware::Next, ) -> Result<axum::response::Response, StatusCode> { @@ -66,16 +62,15 @@ fn handle_result( ) -> Result<axum::response::Response, StatusCode> { handle_result_log(method, uri, &result); match result.status() { - StatusCode::METHOD_NOT_ALLOWED => handle_result_403(method, uri, &result), + StatusCode::METHOD_NOT_ALLOWED => handle_result_405(method, uri, &result), _ => Ok(result), } } -#[allow(clippy::unnecessary_wraps)] -fn handle_result_403( +fn handle_result_405( _method: &Method, _uri: &Uri, result: &axum::response::Response, ) -> Result<axum::response::Response, StatusCode> { - let error = UiaaResponse::MatrixError(RumaError { + let error = Error::RumaError(RumaError { status_code: result.status(), body: ErrorBody::Standard { kind: ErrorKind::Unrecognized, @@ -83,7 +78,7 @@ fn handle_result_403( }, }); - Ok(RumaResponse(error).into_response()) + Ok(error.into_response()) } fn handle_result_log(method: &Method, uri: &Uri, result: &axum::response::Response) { diff --git a/src/router/run.rs b/src/router/run.rs index 4f7853d850cd4c63a310f264a9ff1aa895a7c724..3e09823ac56f453cd5a5f430c9a91e4e7aaa43ad 100644 --- a/src/router/run.rs +++ b/src/router/run.rs @@ -17,7 +17,6 @@ /// Main loop base #[tracing::instrument(skip_all)] -#[allow(clippy::let_underscore_must_use)] // various of these are intended pub(crate) async fn run(server: Arc<Server>) -> Result<(), Error> { let app = layers::build(&server)?; diff --git a/src/router/serve/unix.rs b/src/router/serve/unix.rs index 6c406d282fa7feea1b7649ff2bae0452bc53b2d6..8373b74943468c9299f0f7d54a54febc470dcdeb 100644 --- a/src/router/serve/unix.rs +++ b/src/router/serve/unix.rs @@ -54,7 +54,6 @@ pub(super) async fn serve(server: &Arc<Server>, app: Router, mut shutdown: broad Ok(()) } -#[allow(clippy::let_underscore_must_use)] async fn accept( server: &Arc<Server>, listener: &UnixListener, tasks: &mut JoinSet<()>, mut app: MakeService, builder: server::conn::auto::Builder<TokioExecutor>, conn: (UnixStream, SocketAddr), diff --git a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs index 664743703629b15ea4166d3789fbf96d6dc4ce0b..69d2f799ad57147db8e80f563720ba9f390729d3 100644 --- a/src/service/account_data/mod.rs +++ b/src/service/account_data/mod.rs @@ -44,7 +44,7 @@ pub fn get( } /// Returns all changes to the account data that happened after `since`. - #[tracing::instrument(skip_all, name = "since")] + #[tracing::instrument(skip_all, name = "since", level = "debug")] pub fn changes_since( &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64, ) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> { diff --git a/src/service/admin/console.rs b/src/service/admin/console.rs index 0a200caebe4c76f4943f2d11056aff768589e734..2f66b1d530fa0f23102ae3e5ef486b4ee7628392 100644 --- a/src/service/admin/console.rs +++ b/src/service/admin/console.rs @@ -45,7 +45,6 @@ pub(super) async fn handle_signal(self: &Arc<Self>, sig: &'static str) { } } - #[allow(clippy::let_underscore_must_use)] pub async fn start(self: &Arc<Self>) { let mut worker_join = self.worker_join.lock().expect("locked"); if worker_join.is_none() { @@ -54,7 +53,6 @@ pub async fn start(self: &Arc<Self>) { } } - #[allow(clippy::let_underscore_must_use)] pub async fn close(self: &Arc<Self>) { self.interrupt(); let Some(worker_join) = self.worker_join.lock().expect("locked").take() else { @@ -97,7 +95,7 @@ async fn worker(self: Arc<Self>) { ReadlineEvent::Line(string) => self.clone().handle(string).await, ReadlineEvent::Interrupted => continue, ReadlineEvent::Eof => break, - ReadlineEvent::Quit => services().server.shutdown().unwrap_or_else(error::log), + ReadlineEvent::Quit => services().server.shutdown().unwrap_or_else(error::into_log), }, Err(error) => match error { ReadlineError::Closed => break, @@ -113,7 +111,6 @@ async fn worker(self: Arc<Self>) { self.worker_join.lock().expect("locked").take(); } - #[allow(clippy::let_underscore_must_use)] async fn readline(self: &Arc<Self>) -> Result<ReadlineEvent, ReadlineError> { let _suppression = log::Suppress::new(&services().server); @@ -138,7 +135,6 @@ async fn readline(self: &Arc<Self>) -> Result<ReadlineEvent, ReadlineError> { result } - #[allow(clippy::let_underscore_must_use)] async fn handle(self: Arc<Self>, line: String) { if line.trim().is_empty() { return; diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index ecb827fe9a5ae0ae05c2983d2bc4835db7451d19..254a3d9cd4fba50cac892068d7242b62b30bc82a 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -65,7 +65,10 @@ pub fn next_count(&self) -> Result<u64> { "counter mismatch" ); - *counter = counter.wrapping_add(1); + *counter = counter + .checked_add(1) + .expect("counter must not overflow u64"); + self.global.insert(COUNTER, &counter.to_be_bytes())?; Ok(*counter) @@ -107,7 +110,7 @@ pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> { Ok(()) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { let userid_bytes = user_id.as_bytes().to_vec(); let mut userid_prefix = userid_bytes.clone(); diff --git a/src/service/globals/migrations.rs b/src/service/globals/migrations.rs index e171cb9d2689bbc1258ba0de4f0d35b811dc2498..3948d1f50a96252ad09e588fc25919d846c9ffb9 100644 --- a/src/service/globals/migrations.rs +++ b/src/service/globals/migrations.rs @@ -836,11 +836,14 @@ async fn fix_bad_double_separator_in_state_cache(db: &Arc<Database>, _config: &C for (mut key, value) in roomuserid_joined.iter() { iter_count = iter_count.saturating_add(1); debug_info!(%iter_count); - let first_sep_index = key.iter().position(|&i| i == 0xFF).unwrap(); + let first_sep_index = key + .iter() + .position(|&i| i == 0xFF) + .expect("found 0xFF delim"); if key .iter() - .get(first_sep_index..=first_sep_index + 1) + .get(first_sep_index..=first_sep_index.saturating_add(1)) .copied() .collect_vec() == vec![0xFF, 0xFF] diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 3c0e66c0ad1b2b07524c145f0dda2c851717f351..11bfc88c393c4975e0128c6da1dd13bcb803ac75 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -189,10 +189,10 @@ pub fn next_count(&self) -> Result<u64> { self.db.next_count() } #[inline] pub fn current_count(&self) -> Result<u64> { Ok(self.db.current_count()) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn last_check_for_updates_id(&self) -> Result<u64> { self.db.last_check_for_updates_id() } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> { self.db.update_check_for_updates_id(id) } pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { @@ -201,8 +201,6 @@ pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> pub fn server_name(&self) -> &ServerName { self.config.server_name.as_ref() } - pub fn max_request_size(&self) -> u32 { self.config.max_request_size } - pub fn max_fetch_prev_events(&self) -> u16 { self.config.max_fetch_prev_events } pub fn allow_registration(&self) -> bool { self.config.allow_registration } diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index caa75d4e346ba8a2cb41902ccd106046c15f1d1b..3cb8fda83a33544f6a547522380de7684a1c13b7 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -1,10 +1,10 @@ mod data; mod tests; -use std::{collections::HashMap, io::Cursor, path::PathBuf, sync::Arc, time::SystemTime}; +use std::{collections::HashMap, io::Cursor, num::Saturating as Sat, path::PathBuf, sync::Arc, time::SystemTime}; use base64::{engine::general_purpose, Engine as _}; -use conduit::{debug, debug_error, error, utils, Error, Result, Server}; +use conduit::{checked, debug, debug_error, error, utils, Error, Result, Server}; use data::Data; use image::imageops::FilterType; use ruma::{OwnedMxcUri, OwnedUserId}; @@ -305,36 +305,20 @@ pub async fn get_thumbnail(&self, mxc: &str, width: u32, height: u32) -> Result< image.resize_to_fill(width, height, FilterType::CatmullRom) } else { let (exact_width, exact_height) = { - // Copied from image::dynimage::resize_dimensions - // - // https://github.com/image-rs/image/blob/6edf8ae492c4bb1dacb41da88681ea74dab1bab3/src/math/utils.rs#L5-L11 - // Calculates the width and height an image should be - // resized to. This preserves aspect ratio, and based - // on the `fill` parameter will either fill the - // dimensions to fit inside the smaller constraint - // (will overflow the specified bounds on one axis to - // preserve aspect ratio), or will shrink so that both - // dimensions are completely contained within the given - // `width` and `height`, with empty space on one axis. - let ratio = u64::from(original_width) * u64::from(height); - let nratio = u64::from(width) * u64::from(original_height); + let ratio = Sat(original_width) * Sat(height); + let nratio = Sat(width) * Sat(original_height); let use_width = nratio <= ratio; let intermediate = if use_width { - u64::from(original_height) * u64::from(width) / u64::from(original_width) + Sat(original_height) * Sat(checked!(width / original_width)?) } else { - u64::from(original_width) * u64::from(height) / u64::from(original_height) + Sat(original_width) * Sat(checked!(height / original_height)?) }; + if use_width { - if u32::try_from(intermediate).is_ok() { - (width, intermediate as u32) - } else { - ((u64::from(width) * u64::from(u32::MAX) / intermediate) as u32, u32::MAX) - } - } else if u32::try_from(intermediate).is_ok() { - (intermediate as u32, height) + (width, intermediate.0) } else { - (u32::MAX, (u64::from(height) * u64::from(u32::MAX) / intermediate) as u32) + (intermediate.0, height) } }; diff --git a/src/service/mod.rs b/src/service/mod.rs index 9c83b25c1f3833069a22c8e4398488d4331f895b..4b19073d156c7c011f2596e6bc60e458d8a03aac 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -37,7 +37,6 @@ static SERVICES: RwLock<Option<&Services>> = RwLock::new(None); -#[allow(clippy::let_underscore_must_use)] pub async fn init(server: &Arc<Server>) -> Result<()> { let d = Arc::new(Database::open(server).await?); let s = Box::new(Services::build(server.clone(), d)?); diff --git a/src/service/presence/mod.rs b/src/service/presence/mod.rs index 584f1a6d8db82081ef3a49fb052e206aa5c54ca9..f5400379092b5f51544bde8c7220e68cfa3c48c7 100644 --- a/src/service/presence/mod.rs +++ b/src/service/presence/mod.rs @@ -3,7 +3,7 @@ use std::{sync::Arc, time::Duration}; use async_trait::async_trait; -use conduit::{debug, error, utils, Error, Result}; +use conduit::{checked, debug, error, utils, Error, Result}; use data::Data; use futures_util::{stream::FuturesUnordered, StreamExt}; use ruma::{ @@ -79,12 +79,16 @@ pub struct Service { timer_receiver: Mutex<loole::Receiver<(OwnedUserId, Duration)>>, handler_join: Mutex<Option<JoinHandle<()>>>, timeout_remote_users: bool, + idle_timeout: u64, + offline_timeout: u64, } #[async_trait] impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { let config = &args.server.config; + let idle_timeout_s = config.presence_idle_timeout_s; + let offline_timeout_s = config.presence_offline_timeout_s; let (timer_sender, timer_receiver) = loole::unbounded(); Ok(Arc::new(Self { db: Data::new(args.db), @@ -92,6 +96,8 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { timer_receiver: Mutex::new(timer_receiver), handler_join: Mutex::new(None), timeout_remote_users: config.presence_timeout_remote_users, + idle_timeout: checked!(idle_timeout_s * 1_000)?, + offline_timeout: checked!(offline_timeout_s * 1_000)?, })) } @@ -219,7 +225,7 @@ async fn handler(&self) -> Result<()> { loop { debug_assert!(!receiver.is_closed(), "channel error"); tokio::select! { - Some(user_id) = presence_timers.next() => process_presence_timer(&user_id)?, + Some(user_id) = presence_timers.next() => self.process_presence_timer(&user_id)?, event = receiver.recv_async() => match event { Err(_e) => return Ok(()), Ok((user_id, timeout)) => { @@ -230,43 +236,40 @@ async fn handler(&self) -> Result<()> { } } } -} -async fn presence_timer(user_id: OwnedUserId, timeout: Duration) -> OwnedUserId { - sleep(timeout).await; + fn process_presence_timer(&self, user_id: &OwnedUserId) -> Result<()> { + let mut presence_state = PresenceState::Offline; + let mut last_active_ago = None; + let mut status_msg = None; - user_id -} + let presence_event = self.get_presence(user_id)?; -fn process_presence_timer(user_id: &OwnedUserId) -> Result<()> { - let idle_timeout = services().globals.config.presence_idle_timeout_s * 1_000; - let offline_timeout = services().globals.config.presence_offline_timeout_s * 1_000; - - let mut presence_state = PresenceState::Offline; - let mut last_active_ago = None; - let mut status_msg = None; - - let presence_event = services().presence.get_presence(user_id)?; + if let Some(presence_event) = presence_event { + presence_state = presence_event.content.presence; + last_active_ago = presence_event.content.last_active_ago; + status_msg = presence_event.content.status_msg; + } - if let Some(presence_event) = presence_event { - presence_state = presence_event.content.presence; - last_active_ago = presence_event.content.last_active_ago; - status_msg = presence_event.content.status_msg; - } + let new_state = match (&presence_state, last_active_ago.map(u64::from)) { + (PresenceState::Online, Some(ago)) if ago >= self.idle_timeout => Some(PresenceState::Unavailable), + (PresenceState::Unavailable, Some(ago)) if ago >= self.offline_timeout => Some(PresenceState::Offline), + _ => None, + }; - let new_state = match (&presence_state, last_active_ago.map(u64::from)) { - (PresenceState::Online, Some(ago)) if ago >= idle_timeout => Some(PresenceState::Unavailable), - (PresenceState::Unavailable, Some(ago)) if ago >= offline_timeout => Some(PresenceState::Offline), - _ => None, - }; + debug!( + "Processed presence timer for user '{user_id}': Old state = {presence_state}, New state = {new_state:?}" + ); - debug!("Processed presence timer for user '{user_id}': Old state = {presence_state}, New state = {new_state:?}"); + if let Some(new_state) = new_state { + self.set_presence(user_id, &new_state, Some(false), last_active_ago, status_msg)?; + } - if let Some(new_state) = new_state { - services() - .presence - .set_presence(user_id, &new_state, Some(false), last_active_ago, status_msg)?; + Ok(()) } +} - Ok(()) +async fn presence_timer(user_id: OwnedUserId, timeout: Duration) -> OwnedUserId { + sleep(timeout).await; + + user_id } diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index 27b49c281920f1824d1b9a00bac2d969f4378020..ea48ea7c98468528bbb96c605d8b070aebb8f0ce 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -186,7 +186,7 @@ pub async fn send_push_notice( Ok(()) } - #[tracing::instrument(skip(self, user, ruleset, pdu))] + #[tracing::instrument(skip(self, user, ruleset, pdu), level = "debug")] pub fn get_actions<'a>( &self, user: &UserId, ruleset: &'a Ruleset, power_levels: &RoomPowerLevelsEventContent, pdu: &Raw<AnySyncTimelineEvent>, room_id: &RoomId, diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index d375561ec8f338f5c3b7c41f6ccc48fc03d020a9..706bf2f8f94c671373ab989107bf1193e88530cb 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -89,19 +89,19 @@ pub async fn resolve_alias( ) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> { self.db.resolve_local_alias(alias) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn local_aliases_for_room<'a>( &'a self, room_id: &RoomId, ) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> { self.db.local_aliases_for_room(room_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> { self.db.all_local_aliases() } diff --git a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs index a5771f4abdad18e1a15775ba92a4d63f802ca2a2..5efb36c2a3ecb70b5382f077b8bb4c53fed4f406 100644 --- a/src/service/rooms/auth_chain/data.rs +++ b/src/service/rooms/auth_chain/data.rs @@ -3,7 +3,7 @@ sync::{Arc, Mutex}, }; -use conduit::{utils, Result, Server}; +use conduit::{utils, utils::math::usize_from_f64, Result, Server}; use database::{Database, Map}; use lru_cache::LruCache; @@ -16,7 +16,7 @@ impl Data { pub(super) fn new(server: &Arc<Server>, db: &Arc<Database>) -> Self { let config = &server.config; let cache_size = f64::from(config.auth_chain_cache_capacity); - let cache_size = (cache_size * config.conduit_cache_capacity_modifier) as usize; + let cache_size = usize_from_f64(cache_size * config.conduit_cache_capacity_modifier).expect("valid cache size"); Self { shorteventid_authchain: db["shorteventid_authchain"].clone(), auth_chain_cache: Mutex::new(LruCache::new(cache_size)), diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index ca04b1e5d58cdf469e47b812e0a2c7e3f26cdc2e..2919ea7ee939e800f1f794de64094024bd3efbb9 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -5,7 +5,7 @@ sync::Arc, }; -use conduit::{debug, error, trace, warn, Error, Result}; +use conduit::{debug, error, trace, validated, warn, Error, Result}; use data::Data; use ruma::{api::client::error::ErrorKind, EventId, RoomId}; @@ -48,15 +48,16 @@ pub async fn get_auth_chain(&self, room_id: &RoomId, starting_events: &[&EventId let started = std::time::Instant::now(); let mut buckets = [BUCKET; NUM_BUCKETS]; - for (i, short) in services() + for (i, &short) in services() .rooms .short .multi_get_or_create_shorteventid(starting_events)? .iter() .enumerate() { - let bucket = short % NUM_BUCKETS as u64; - buckets[bucket as usize].insert((*short, starting_events[i])); + let bucket: usize = short.try_into()?; + let bucket: usize = validated!(bucket % NUM_BUCKETS)?; + buckets[bucket].insert((short, starting_events[i])); } debug!( @@ -173,13 +174,13 @@ pub fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u6 self.db.get_cached_eventid_authchain(key) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: &HashSet<u64>) -> Result<()> { self.db .cache_auth_chain(key, auth_chain.iter().copied().collect::<Arc<[u64]>>()) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn cache_auth_chain_vec(&self, key: Vec<u64>, auth_chain: &Vec<u64>) -> Result<()> { self.db .cache_auth_chain(key, auth_chain.iter().copied().collect::<Arc<[u64]>>()) diff --git a/src/service/rooms/directory/mod.rs b/src/service/rooms/directory/mod.rs index 3e60831c578d8cd2078980415082d83b7d73bc91..23ec6b6b5b925e4024031ed46c6729f0c8b55738 100644 --- a/src/service/rooms/directory/mod.rs +++ b/src/service/rooms/directory/mod.rs @@ -22,15 +22,15 @@ fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } impl Service { - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn set_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_public(room_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_not_public(room_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn is_public_room(&self, room_id: &RoomId) -> Result<bool> { self.db.is_public_room(room_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn public_rooms(&self) -> impl Iterator<Item = Result<OwnedRoomId>> + '_ { self.db.public_rooms() } } diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 9f50ef5802bb84e234143615fb2b69b3a19732c6..0f7919dd5b5a803c48414a03bb8d3e083a768982 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -2,14 +2,16 @@ mod signing_keys; use std::{ - cmp, collections::{hash_map, BTreeMap, HashMap, HashSet}, pin::Pin, sync::Arc, - time::{Duration, Instant}, + time::Instant, }; -use conduit::{debug_error, debug_info, Error, Result}; +use conduit::{ + debug, debug_error, debug_info, error, info, trace, utils::math::continue_exponential_backoff_secs, warn, Error, + Result, +}; use futures_util::Future; pub use parse_incoming_pdu::parse_incoming_pdu; use ruma::{ @@ -29,7 +31,6 @@ uint, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedUserId, RoomId, RoomVersionId, ServerName, }; use tokio::sync::RwLock; -use tracing::{debug, error, info, trace, warn}; use super::state_compressor::CompressedStateEvent; use crate::{pdu, services, PduEvent}; @@ -190,7 +191,7 @@ pub async fn handle_incoming_pdu<'a>( e.insert((Instant::now(), 1)); }, hash_map::Entry::Occupied(mut e) => { - *e.get_mut() = (Instant::now(), e.get().1 + 1); + *e.get_mut() = (Instant::now(), e.get().1.saturating_add(1)); }, }; }, @@ -252,14 +253,12 @@ pub async fn handle_prev_pdu<'a>( .get(prev_id) { // Exponential backoff - const MAX_DURATION: Duration = Duration::from_secs(60 * 60 * 24); - let min_duration = cmp::min(MAX_DURATION, Duration::from_secs(5 * 60) * (*tries) * (*tries)); - let duration = time.elapsed(); - - if duration < min_duration { + const MIN_DURATION: u64 = 5 * 60; + const MAX_DURATION: u64 = 60 * 60 * 24; + if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { debug!( - duration = ?duration, - min_duration = ?min_duration, + ?tries, + duration = ?time.elapsed(), "Backing off from prev_event" ); return Ok(()); @@ -1073,7 +1072,7 @@ pub fn fetch_and_handle_outliers<'a>( let mut todo_auth_events = vec![Arc::clone(id)]; let mut events_in_reverse_order = Vec::with_capacity(todo_auth_events.len()); let mut events_all = HashSet::with_capacity(todo_auth_events.len()); - let mut i = 0; + let mut i: u64 = 0; while let Some(next_id) = todo_auth_events.pop() { if let Some((time, tries)) = services() .globals @@ -1083,12 +1082,10 @@ pub fn fetch_and_handle_outliers<'a>( .get(&*next_id) { // Exponential backoff - const MAX_DURATION: Duration = Duration::from_secs(60 * 60 * 24); - let min_elapsed_duration = - cmp::min(MAX_DURATION, Duration::from_secs(5 * 60) * (*tries) * (*tries)); - - if time.elapsed() < min_elapsed_duration { - info!("Backing off from {}", next_id); + const MIN_DURATION: u64 = 5 * 60; + const MAX_DURATION: u64 = 60 * 60 * 24; + if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { + info!("Backing off from {next_id}"); continue; } } @@ -1097,7 +1094,7 @@ pub fn fetch_and_handle_outliers<'a>( continue; } - i += 1; + i = i.saturating_add(1); if i % 100 == 0 { tokio::task::yield_now().await; } @@ -1191,12 +1188,10 @@ pub fn fetch_and_handle_outliers<'a>( .get(&**next_id) { // Exponential backoff - const MAX_DURATION: Duration = Duration::from_secs(60 * 60 * 24); - let min_elapsed_duration = - cmp::min(MAX_DURATION, Duration::from_secs(5 * 60) * (*tries) * (*tries)); - - if time.elapsed() < min_elapsed_duration { - debug!("Backing off from {}", next_id); + const MIN_DURATION: u64 = 5 * 60; + const MAX_DURATION: u64 = 60 * 60 * 24; + if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { + debug!("Backing off from {next_id}"); continue; } } diff --git a/src/service/rooms/event_handler/signing_keys.rs b/src/service/rooms/event_handler/signing_keys.rs index fdb2be0e7abdcb019c01fe378363ff327db3f5db..2fa5b0df06a6805017ce199130e91ba499e8ba3a 100644 --- a/src/service/rooms/event_handler/signing_keys.rs +++ b/src/service/rooms/event_handler/signing_keys.rs @@ -3,6 +3,7 @@ time::{Duration, SystemTime}, }; +use conduit::{debug, error, info, trace, warn}; use futures_util::{stream::FuturesUnordered, StreamExt}; use ruma::{ api::federation::{ @@ -19,7 +20,6 @@ }; use serde_json::value::RawValue as RawJsonValue; use tokio::sync::{RwLock, RwLockWriteGuard}; -use tracing::{debug, error, info, trace, warn}; use crate::{services, Error, Result}; diff --git a/src/service/rooms/lazy_loading/mod.rs b/src/service/rooms/lazy_loading/mod.rs index 185cfd8cb2db8a5b5498155bd020962254e4e1ab..96f623f2e0e7a7ed9f37019a8e12afdfe4f18898 100644 --- a/src/service/rooms/lazy_loading/mod.rs +++ b/src/service/rooms/lazy_loading/mod.rs @@ -39,7 +39,7 @@ fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } impl Service { - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn lazy_load_was_sent_before( &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId, ) -> Result<bool> { @@ -47,7 +47,7 @@ pub fn lazy_load_was_sent_before( .lazy_load_was_sent_before(user_id, device_id, room_id, ll_user) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub async fn lazy_load_mark_sent( &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, lazy_load: HashSet<OwnedUserId>, count: PduCount, @@ -58,7 +58,7 @@ pub async fn lazy_load_mark_sent( .insert((user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), count), lazy_load); } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub async fn lazy_load_confirm_delivery( &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, since: PduCount, ) -> Result<()> { @@ -77,7 +77,7 @@ pub async fn lazy_load_confirm_delivery( Ok(()) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> { self.db.lazy_load_reset(user_id, device_id, room_id) } diff --git a/src/service/rooms/outlier/mod.rs b/src/service/rooms/outlier/mod.rs index a7326f424acadb0ca73104e8e37d79dab9f11ebb..22bd2092a2a2afe91a89fadc7f138f0808651ae9 100644 --- a/src/service/rooms/outlier/mod.rs +++ b/src/service/rooms/outlier/mod.rs @@ -35,7 +35,7 @@ pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<Canonica pub fn get_pdu_outlier(&self, event_id: &EventId) -> Result<Option<PduEvent>> { self.db.get_outlier_pdu(event_id) } /// Append the PDU as an outlier. - #[tracing::instrument(skip(self, pdu))] + #[tracing::instrument(skip(self, pdu), level = "debug")] pub fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { self.db.add_pdu_outlier(event_id, pdu) } diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index e35c969ddea78ebd264995d2019fed2c38e4df43..05067aa8e596f9963a0c721c2ee17fcebdb22ee2 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -38,7 +38,7 @@ fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } impl Service { - #[tracing::instrument(skip(self, from, to))] + #[tracing::instrument(skip(self, from, to), level = "debug")] pub fn add_relation(&self, from: PduCount, to: PduCount) -> Result<()> { match (from, to) { (PduCount::Normal(f), PduCount::Normal(t)) => self.db.add_relation(f, t), @@ -205,7 +205,7 @@ pub fn relations_until<'a>( if let Ok(relations) = self.db.relations_until(user_id, room_id, target, until) { for relation in relations.flatten() { if stack_pdu.1 < max_depth { - stack.push((relation.clone(), stack_pdu.1 + 1)); + stack.push((relation.clone(), stack_pdu.1.saturating_add(1))); } pdus.push(relation); @@ -218,19 +218,19 @@ pub fn relations_until<'a>( }) } - #[tracing::instrument(skip(self, room_id, event_ids))] + #[tracing::instrument(skip_all, level = "debug")] pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> { self.db.mark_as_referenced(room_id, event_ids) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> { self.db.is_event_referenced(room_id, event_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { self.db.mark_event_soft_failed(event_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> { self.db.is_event_soft_failed(event_id) } } diff --git a/src/service/rooms/read_receipt/data.rs b/src/service/rooms/read_receipt/data.rs index 17acb0b3120acb02bf28cb384121cbb03f21f81c..06eaf65558bd1b6fa27dd028c80d460a99866c8f 100644 --- a/src/service/rooms/read_receipt/data.rs +++ b/src/service/rooms/read_receipt/data.rs @@ -76,10 +76,12 @@ pub(super) fn readreceipts_since<'a>(&'a self, room_id: &RoomId, since: u64) -> .iter_from(&first_possible_edu, false) .take_while(move |(k, _)| k.starts_with(&prefix2)) .map(move |(k, v)| { - let count = utils::u64_from_bytes(&k[prefix.len()..prefix.len() + size_of::<u64>()]) + let count_offset = prefix.len().saturating_add(size_of::<u64>()); + let count = utils::u64_from_bytes(&k[prefix.len()..count_offset]) .map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?; + let user_id_offset = count_offset.saturating_add(1); let user_id = UserId::parse( - utils::string_from_bytes(&k[prefix.len() + size_of::<u64>() + 1..]) + utils::string_from_bytes(&k[user_id_offset..]) .map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?, ) .map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?; diff --git a/src/service/rooms/read_receipt/mod.rs b/src/service/rooms/read_receipt/mod.rs index e46027b752e1302d6510820d447048cdf70f0d04..9375276eeac85119e349950bc4844554090693e3 100644 --- a/src/service/rooms/read_receipt/mod.rs +++ b/src/service/rooms/read_receipt/mod.rs @@ -33,7 +33,7 @@ pub fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &Rec /// Returns an iterator over the most recent read_receipts in a room that /// happened after the event with id `since`. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn readreceipts_since<'a>( &'a self, room_id: &RoomId, since: u64, ) -> impl Iterator<Item = Result<(OwnedUserId, u64, Raw<ruma::events::AnySyncEphemeralRoomEvent>)>> + 'a { @@ -41,13 +41,13 @@ pub fn readreceipts_since<'a>( } /// Sets a private read marker at `count`. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { self.db.private_read_set(room_id, user_id, count) } /// Returns the private read marker. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { self.db.private_read_get(room_id, user_id) } diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs index 7573d218e60660583143431c843ebdfb6b02dfff..082dd432fc4c85a5ff1d1213bce531b68426064f 100644 --- a/src/service/rooms/search/mod.rs +++ b/src/service/rooms/search/mod.rs @@ -21,17 +21,17 @@ fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } impl Service { - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { self.db.index_pdu(shortroomid, pdu_id, message_body) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { self.db.deindex_pdu(shortroomid, pdu_id, message_body) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn search_pdus<'a>( &'a self, room_id: &RoomId, search_string: &str, ) -> Result<Option<(impl Iterator<Item = Vec<u8>> + 'a, Vec<String>)>> { diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index bf6cc87335f64aa0bd1bb90f0dc1b4de0187977d..03d0d43ff2c1acb3e137587094bc52b762fac0ef 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -7,7 +7,7 @@ sync::Arc, }; -use conduit::debug_info; +use conduit::{checked, debug_info, utils::math::usize_from_f64}; use lru_cache::LruCache; use ruma::{ api::{ @@ -161,11 +161,10 @@ fn from(value: CachedSpaceHierarchySummary) -> Self { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { let config = &args.server.config; + let cache_size = f64::from(config.roomid_spacehierarchy_cache_capacity); + let cache_size = cache_size * config.conduit_cache_capacity_modifier; Ok(Arc::new(Self { - roomid_spacehierarchy_cache: Mutex::new(LruCache::new( - (f64::from(config.roomid_spacehierarchy_cache_capacity) * config.conduit_cache_capacity_modifier) - as usize, - )), + roomid_spacehierarchy_cache: Mutex::new(LruCache::new(usize_from_f64(cache_size)?)), })) } @@ -447,7 +446,7 @@ fn get_room_summary( } pub async fn get_client_hierarchy( - &self, sender_user: &UserId, room_id: &RoomId, limit: usize, short_room_ids: Vec<u64>, max_depth: usize, + &self, sender_user: &UserId, room_id: &RoomId, limit: usize, short_room_ids: Vec<u64>, max_depth: u64, suggested_only: bool, ) -> Result<client::space::get_hierarchy::v1::Response> { let mut parents = VecDeque::new(); @@ -508,12 +507,14 @@ pub async fn get_client_hierarchy( } // We have reached the room after where we last left off - if parents.len() + 1 == short_room_ids.len() { + let parents_len = parents.len(); + if checked!(parents_len + 1)? == short_room_ids.len() { populate_results = true; } } - if !children.is_empty() && parents.len() < max_depth { + let parents_len: u64 = parents.len().try_into()?; + if !children.is_empty() && parents_len < max_depth { parents.push_back(current_room.clone()); stack.push(children); } @@ -548,9 +549,8 @@ pub async fn get_client_hierarchy( Some( PaginationToken { short_room_ids, - limit: UInt::new(max_depth as u64).expect("When sent in request it must have been valid UInt"), - max_depth: UInt::new(max_depth as u64) - .expect("When sent in request it must have been valid UInt"), + limit: UInt::new(max_depth).expect("When sent in request it must have been valid UInt"), + max_depth: UInt::new(max_depth).expect("When sent in request it must have been valid UInt"), suggested_only, } .to_string(), diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index 7964a5fb8b7df399f543f763f9709cfbae0fa863..52ee89d1de9390b5ad78042571c96611bf4f435e 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -116,7 +116,7 @@ pub async fn force_state( /// /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. - #[tracing::instrument(skip(self, state_ids_compressed))] + #[tracing::instrument(skip(self, state_ids_compressed), level = "debug")] pub fn set_event_state( &self, event_id: &EventId, room_id: &RoomId, state_ids_compressed: Arc<HashSet<CompressedStateEvent>>, ) -> Result<u64> { @@ -184,7 +184,7 @@ pub fn set_event_state( /// /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. - #[tracing::instrument(skip(self, new_pdu))] + #[tracing::instrument(skip(self, new_pdu), level = "debug")] pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result<u64> { let shorteventid = services() .rooms @@ -257,7 +257,7 @@ pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result<u64> { } } - #[tracing::instrument(skip(self, invite_event))] + #[tracing::instrument(skip(self, invite_event), level = "debug")] pub fn calculate_invite_state(&self, invite_event: &PduEvent) -> Result<Vec<Raw<AnyStrippedStateEvent>>> { let mut state = Vec::new(); // Add recommended events @@ -313,7 +313,7 @@ pub fn calculate_invite_state(&self, invite_event: &PduEvent) -> Result<Vec<Raw< } /// Set the state hash to a new version, but does not update state_cache. - #[tracing::instrument(skip(self, mutex_lock))] + #[tracing::instrument(skip(self, mutex_lock), level = "debug")] pub fn set_room_state( &self, room_id: &RoomId, @@ -324,7 +324,7 @@ pub fn set_room_state( } /// Returns the room's version. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn get_room_version(&self, room_id: &RoomId) -> Result<RoomVersionId> { let create_event = services() .rooms @@ -365,7 +365,7 @@ pub fn set_forward_extremities( } /// This fetches auth events from the current state. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn get_auth_events( &self, room_id: &RoomId, kind: &TimelineEventType, sender: &UserId, state_key: Option<&str>, content: &serde_json::value::RawValue, diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index 7fe61134a86564e5bf1263937546175c3b095e5a..a35678573f672ff50ff1a9dbf2011ddc2e1d45cc 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -6,7 +6,11 @@ sync::{Arc, Mutex as StdMutex, Mutex}, }; -use conduit::{error, utils::mutex_map, warn, Error, Result}; +use conduit::{ + error, + utils::{math::usize_from_f64, mutex_map}, + warn, Error, Result, +}; use data::Data; use lru_cache::LruCache; use ruma::{ @@ -44,14 +48,15 @@ pub struct Service { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { let config = &args.server.config; + let server_visibility_cache_capacity = + f64::from(config.server_visibility_cache_capacity) * config.conduit_cache_capacity_modifier; + let user_visibility_cache_capacity = + f64::from(config.user_visibility_cache_capacity) * config.conduit_cache_capacity_modifier; + Ok(Arc::new(Self { db: Data::new(args.db), - server_visibility_cache: StdMutex::new(LruCache::new( - (f64::from(config.server_visibility_cache_capacity) * config.conduit_cache_capacity_modifier) as usize, - )), - user_visibility_cache: StdMutex::new(LruCache::new( - (f64::from(config.user_visibility_cache_capacity) * config.conduit_cache_capacity_modifier) as usize, - )), + server_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(server_visibility_cache_capacity)?)), + user_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(user_visibility_cache_capacity)?)), })) } @@ -76,7 +81,7 @@ fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } impl Service { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> { self.db.state_full_ids(shortstatehash).await } @@ -87,7 +92,7 @@ pub async fn state_full(&self, shortstatehash: u64) -> Result<HashMap<(StateEven /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn state_get_id( &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, ) -> Result<Option<Arc<EventId>>> { @@ -281,14 +286,14 @@ pub fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> R pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> { self.db.pdu_shortstatehash(event_id) } /// Returns the full room state. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub async fn room_state_full(&self, room_id: &RoomId) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { self.db.room_state_full(room_id).await } /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn room_state_get_id( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, ) -> Result<Option<Arc<EventId>>> { @@ -297,7 +302,7 @@ pub fn room_state_get_id( /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn room_state_get( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, ) -> Result<Option<Arc<PduEvent>>> { diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index f79ee6785cb6129d78896c29f47fe8f1afb24c8d..2b9fbe941602429f87fda683ef4a3f28ad29aee8 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -213,7 +213,7 @@ pub(super) fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { Ok(()) } - #[tracing::instrument(skip(self, room_id, appservice))] + #[tracing::instrument(skip(self, room_id, appservice), level = "debug")] pub(super) fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result<bool> { let maybe = self .appservice_in_room_cache @@ -249,7 +249,7 @@ pub(super) fn appservice_in_room(&self, room_id: &RoomId, appservice: &Registrat } /// Makes a user forget a room. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); @@ -266,7 +266,7 @@ pub(super) fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { } /// Returns an iterator of all servers participating in this room. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn room_servers<'a>( &'a self, room_id: &RoomId, ) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> { @@ -286,7 +286,7 @@ pub(super) fn room_servers<'a>( })) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result<bool> { let mut key = server.as_bytes().to_vec(); key.push(0xFF); @@ -297,7 +297,7 @@ pub(super) fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Re /// Returns an iterator of all rooms a server participates in (as far as we /// know). - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn server_rooms<'a>( &'a self, server: &ServerName, ) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { @@ -318,8 +318,10 @@ pub(super) fn server_rooms<'a>( } /// Returns an iterator of all joined members of a room. - #[tracing::instrument(skip(self))] - pub(super) fn room_members<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> { + #[tracing::instrument(skip(self), level = "debug")] + pub(super) fn room_members<'a>( + &'a self, room_id: &RoomId, + ) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + Send + 'a> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xFF); @@ -348,7 +350,7 @@ pub(super) fn local_users_in_room<'a>(&'a self, room_id: &RoomId) -> Box<dyn Ite /// Returns an iterator of all our local joined users in a room who are /// active (not deactivated, not guest) - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn active_local_users_in_room<'a>( &'a self, room_id: &RoomId, ) -> Box<dyn Iterator<Item = OwnedUserId> + 'a> { @@ -359,7 +361,7 @@ pub(super) fn active_local_users_in_room<'a>( } /// Returns the number of users which are currently in a room - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>> { self.roomid_joinedcount .get(room_id.as_bytes())? @@ -368,7 +370,7 @@ pub(super) fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>> } /// Returns the number of users which are currently invited to a room - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>> { self.roomid_invitedcount .get(room_id.as_bytes())? @@ -377,7 +379,7 @@ pub(super) fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>> } /// Returns an iterator over all User IDs who ever joined a room. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn room_useroncejoined<'a>( &'a self, room_id: &RoomId, ) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> { @@ -402,7 +404,7 @@ pub(super) fn room_useroncejoined<'a>( } /// Returns an iterator over all invited members of a room. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn room_members_invited<'a>( &'a self, room_id: &RoomId, ) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> { @@ -426,7 +428,7 @@ pub(super) fn room_members_invited<'a>( ) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { let mut key = room_id.as_bytes().to_vec(); key.push(0xFF); @@ -441,7 +443,7 @@ pub(super) fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Res }) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { let mut key = room_id.as_bytes().to_vec(); key.push(0xFF); @@ -454,7 +456,7 @@ pub(super) fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Resul } /// Returns an iterator over all rooms this user joined. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn rooms_joined(&self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + '_> { Box::new( self.userroomid_joined @@ -474,7 +476,7 @@ pub(super) fn rooms_joined(&self, user_id: &UserId) -> Box<dyn Iterator<Item = R } /// Returns an iterator over all rooms a user was invited to. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xFF); @@ -501,7 +503,7 @@ pub(super) fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEven ) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn invite_state( &self, user_id: &UserId, room_id: &RoomId, ) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> { @@ -520,7 +522,7 @@ pub(super) fn invite_state( .transpose() } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn left_state( &self, user_id: &UserId, room_id: &RoomId, ) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> { @@ -540,7 +542,7 @@ pub(super) fn left_state( } /// Returns an iterator over all rooms a user left. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xFF); @@ -567,7 +569,7 @@ pub(super) fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIte ) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); @@ -576,7 +578,7 @@ pub(super) fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<b Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some()) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); @@ -585,7 +587,7 @@ pub(super) fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<boo Ok(self.userroomid_joined.get(&userroom_id)?.is_some()) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); @@ -594,7 +596,7 @@ pub(super) fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bo Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some()) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); @@ -603,7 +605,7 @@ pub(super) fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn servers_invite_via<'a>( &'a self, room_id: &RoomId, ) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> { @@ -629,7 +631,7 @@ pub(super) fn servers_invite_via<'a>( ) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()> { let mut prev_servers = self .servers_invite_via(room_id) diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 45e05c295ccd301401f06553bc66d0a0814fc02f..30fd7bf4f54e5801d9e2342d67c575ef8e6ac118 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -219,10 +219,10 @@ pub fn update_membership( Ok(()) } - #[tracing::instrument(skip(self, room_id))] + #[tracing::instrument(skip(self, room_id), level = "debug")] pub fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { self.db.update_joined_count(room_id) } - #[tracing::instrument(skip(self, room_id, appservice))] + #[tracing::instrument(skip(self, room_id, appservice), level = "debug")] pub fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result<bool> { self.db.appservice_in_room(room_id, appservice) } @@ -230,7 +230,7 @@ pub fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo /// Direct DB function to directly mark a user as left. It is not /// recommended to use this directly. You most likely should use /// `update_membership` instead - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { self.db.mark_as_left(user_id, room_id) } @@ -238,35 +238,35 @@ pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { /// Direct DB function to directly mark a user as joined. It is not /// recommended to use this directly. You most likely should use /// `update_membership` instead - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { self.db.mark_as_joined(user_id, room_id) } /// Makes a user forget a room. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { self.db.forget(room_id, user_id) } /// Returns an iterator of all servers participating in this room. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn room_servers(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedServerName>> + '_ { self.db.room_servers(room_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result<bool> { self.db.server_in_room(server, room_id) } /// Returns an iterator of all rooms a server participates in (as far as we /// know). - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn server_rooms(&self, server: &ServerName) -> impl Iterator<Item = Result<OwnedRoomId>> + '_ { self.db.server_rooms(server) } /// Returns true if server can see user by sharing at least one room. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn server_sees_user(&self, server: &ServerName, user_id: &UserId) -> Result<bool> { Ok(self .server_rooms(server) @@ -275,7 +275,7 @@ pub fn server_sees_user(&self, server: &ServerName, user_id: &UserId) -> Result< } /// Returns true if user_a and user_b share at least one room. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn user_sees_user(&self, user_a: &UserId, user_b: &UserId) -> Result<bool> { // Minimize number of point-queries by iterating user with least nr rooms let (a, b) = if self.rooms_joined(user_a).count() < self.rooms_joined(user_b).count() { @@ -291,23 +291,23 @@ pub fn user_sees_user(&self, user_a: &UserId, user_b: &UserId) -> Result<bool> { } /// Returns an iterator over all joined members of a room. - #[tracing::instrument(skip(self))] - pub fn room_members(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedUserId>> + '_ { + #[tracing::instrument(skip(self), level = "debug")] + pub fn room_members(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedUserId>> + Send + '_ { self.db.room_members(room_id) } /// Returns the number of users which are currently in a room - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>> { self.db.room_joined_count(room_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] /// Returns an iterator of all our local users in the room, even if they're /// deactivated/guests pub fn local_users_in_room<'a>(&'a self, room_id: &RoomId) -> impl Iterator<Item = OwnedUserId> + 'a { self.db.local_users_in_room(room_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] /// Returns an iterator of all our local joined users in a room who are /// active (not deactivated, not guest) pub fn active_local_users_in_room<'a>(&'a self, room_id: &RoomId) -> impl Iterator<Item = OwnedUserId> + 'a { @@ -315,80 +315,80 @@ pub fn active_local_users_in_room<'a>(&'a self, room_id: &RoomId) -> impl Iterat } /// Returns the number of users which are currently invited to a room - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>> { self.db.room_invited_count(room_id) } /// Returns an iterator over all User IDs who ever joined a room. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn room_useroncejoined(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedUserId>> + '_ { self.db.room_useroncejoined(room_id) } /// Returns an iterator over all invited members of a room. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn room_members_invited(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedUserId>> + '_ { self.db.room_members_invited(room_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { self.db.get_invite_count(room_id, user_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { self.db.get_left_count(room_id, user_id) } /// Returns an iterator over all rooms this user joined. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn rooms_joined(&self, user_id: &UserId) -> impl Iterator<Item = Result<OwnedRoomId>> + '_ { self.db.rooms_joined(user_id) } /// Returns an iterator over all rooms a user was invited to. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn rooms_invited( &self, user_id: &UserId, ) -> impl Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + '_ { self.db.rooms_invited(user_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> { self.db.invite_state(user_id, room_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> { self.db.left_state(user_id, room_id) } /// Returns an iterator over all rooms a user left. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn rooms_left( &self, user_id: &UserId, ) -> impl Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + '_ { self.db.rooms_left(user_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { self.db.once_joined(user_id, room_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { self.db.is_joined(user_id, room_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { self.db.is_invited(user_id, room_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { self.db.is_left(user_id, room_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn servers_invite_via(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedServerName>> + '_ { self.db.servers_invite_via(room_id) } diff --git a/src/service/rooms/state_compressor/data.rs b/src/service/rooms/state_compressor/data.rs index 61c7d6e617e14e825a0c29075ca4ae24ee99f230..337730019ac34b6a5fdc196bb1c5a35c053e38ab 100644 --- a/src/service/rooms/state_compressor/data.rs +++ b/src/service/rooms/state_compressor/data.rs @@ -1,6 +1,6 @@ use std::{collections::HashSet, mem::size_of, sync::Arc}; -use conduit::{utils, Error, Result}; +use conduit::{checked, utils, Error, Result}; use database::{Database, Map}; use super::CompressedStateEvent; @@ -38,11 +38,12 @@ pub(super) fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> { let mut added = HashSet::new(); let mut removed = HashSet::new(); - let mut i = size_of::<u64>(); - while let Some(v) = value.get(i..i + 2 * size_of::<u64>()) { + let stride = size_of::<u64>(); + let mut i = stride; + while let Some(v) = value.get(i..checked!(i + 2 * stride)?) { if add_mode && v.starts_with(&0_u64.to_be_bytes()) { add_mode = false; - i += size_of::<u64>(); + i = checked!(i + stride)?; continue; } if add_mode { @@ -50,7 +51,7 @@ pub(super) fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> { } else { removed.insert(v.try_into().expect("we checked the size above")); } - i += 2 * size_of::<u64>(); + i = checked!(i + 2 * stride)?; } Ok(StateDiff { diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index 97f3fb8002fa1ad542b2ce8f2851bab405f64f6f..4b4ea7d4e9f01d5f67a3830a23ddb7a3e9d924eb 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -7,7 +7,7 @@ sync::{Arc, Mutex as StdMutex, Mutex}, }; -use conduit::{utils, Result}; +use conduit::{checked, utils, utils::math::usize_from_f64, Result}; use data::Data; use lru_cache::LruCache; use ruma::{EventId, RoomId}; @@ -55,11 +55,10 @@ pub struct Service { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { let config = &args.server.config; + let cache_capacity = f64::from(config.stateinfo_cache_capacity) * config.conduit_cache_capacity_modifier; Ok(Arc::new(Self { db: Data::new(args.db), - stateinfo_cache: StdMutex::new(LruCache::new( - (f64::from(config.stateinfo_cache_capacity) * config.conduit_cache_capacity_modifier) as usize, - )), + stateinfo_cache: StdMutex::new(LruCache::new(usize_from_f64(cache_capacity)?)), })) } @@ -78,7 +77,7 @@ fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } impl Service { /// Returns a stack with info on shortstatehash, full state, added diff and /// removed diff for the selected shortstatehash and each parent layer. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn load_shortstatehash_info(&self, shortstatehash: u64) -> ShortStateInfoResult { if let Some(r) = self .stateinfo_cache @@ -163,18 +162,20 @@ pub fn parse_compressed_state_event(&self, compressed_event: &CompressedStateEve /// for this layer /// * `parent_states` - A stack with info on shortstatehash, full state, /// added diff and removed diff for each parent layer - #[tracing::instrument(skip(self, statediffnew, statediffremoved, diff_to_sibling, parent_states))] + #[tracing::instrument(skip(self, statediffnew, statediffremoved, diff_to_sibling, parent_states), level = "debug")] pub fn save_state_from_diff( &self, shortstatehash: u64, statediffnew: Arc<HashSet<CompressedStateEvent>>, statediffremoved: Arc<HashSet<CompressedStateEvent>>, diff_to_sibling: usize, mut parent_states: ParentStatesVec, ) -> Result<()> { - let diffsum = statediffnew.len() + statediffremoved.len(); + let statediffnew_len = statediffnew.len(); + let statediffremoved_len = statediffremoved.len(); + let diffsum = checked!(statediffnew_len + statediffremoved_len)?; if parent_states.len() > 3 { // Number of layers // To many layers, we have to go deeper - let parent = parent_states.pop().unwrap(); + let parent = parent_states.pop().expect("parent must have a state"); let mut parent_new = (*parent.2).clone(); let mut parent_removed = (*parent.3).clone(); @@ -226,10 +227,12 @@ pub fn save_state_from_diff( // 1. We add the current diff on top of the parent layer. // 2. We replace a layer above - let parent = parent_states.pop().unwrap(); - let parent_diff = parent.2.len() + parent.3.len(); + let parent = parent_states.pop().expect("parent must have a state"); + let parent_2_len = parent.2.len(); + let parent_3_len = parent.3.len(); + let parent_diff = checked!(parent_2_len + parent_3_len)?; - if diffsum * diffsum >= 2 * diff_to_sibling * parent_diff { + if checked!(diffsum * diffsum)? >= checked!(2 * diff_to_sibling * parent_diff)? { // Diff too big, we replace above layer(s) let mut parent_new = (*parent.2).clone(); let mut parent_removed = (*parent.3).clone(); diff --git a/src/service/rooms/threads/data.rs b/src/service/rooms/threads/data.rs index 29539847235db1c13dcf13f63cfd450c11742c5f..c4a1a294534ff8854cbcfcaadb02783b1ec26d8e 100644 --- a/src/service/rooms/threads/data.rs +++ b/src/service/rooms/threads/data.rs @@ -1,6 +1,6 @@ use std::{mem::size_of, sync::Arc}; -use conduit::{utils, Error, Result}; +use conduit::{checked, utils, Error, Result}; use database::{Database, Map}; use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId}; @@ -31,7 +31,7 @@ pub(super) fn threads_until<'a>( .to_vec(); let mut current = prefix.clone(); - current.extend_from_slice(&(until - 1).to_be_bytes()); + current.extend_from_slice(&(checked!(until - 1)?).to_be_bytes()); Ok(Box::new( self.threadid_userids diff --git a/src/service/rooms/threads/mod.rs b/src/service/rooms/threads/mod.rs index f3cefe21a25abb7f0644d099d05bdc277321fe1f..dd2686b06a2391ffeea1dc7419ef5dd99c4baa90 100644 --- a/src/service/rooms/threads/mod.rs +++ b/src/service/rooms/threads/mod.rs @@ -64,7 +64,7 @@ pub fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<( .and_then(|relations| serde_json::from_value::<BundledThread>(relations.clone().into()).ok()) { // Thread already existed - relations.count += uint!(1); + relations.count = relations.count.saturating_add(uint!(1)); relations.latest_event = pdu.to_message_like_event(); let content = serde_json::to_value(relations).expect("to_value always works"); diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index 0d4d945e40284558afc41142fdfca8a6ddf4d2d0..ec975b99d2dca5e9a7787ba9fe1474e1b298ee68 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -4,7 +4,7 @@ sync::{Arc, Mutex}, }; -use conduit::{error, utils, Error, Result}; +use conduit::{checked, error, utils, Error, Result}; use database::{Database, Map}; use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId}; @@ -281,10 +281,12 @@ pub(super) fn increment_notification_counts( /// Returns the `count` of this pdu's id. pub(super) fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> { - let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..]) + let stride = size_of::<u64>(); + let pdu_id_len = pdu_id.len(); + let last_u64 = utils::u64_from_bytes(&pdu_id[checked!(pdu_id_len - stride)?..]) .map_err(|_| Error::bad_database("PDU has invalid count bytes."))?; let second_last_u64 = - utils::u64_from_bytes(&pdu_id[pdu_id.len() - 2 * size_of::<u64>()..pdu_id.len() - size_of::<u64>()]); + utils::u64_from_bytes(&pdu_id[checked!(pdu_id_len - 2 * stride)?..checked!(pdu_id_len - stride)?]); if matches!(second_last_u64, Ok(0)) { Ok(PduCount::Backfilled(u64::MAX.saturating_sub(last_u64))) diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index ba987dbdd712bce6936a57dc0fed1671a5d51996..db7af00420f6d944d2c5dcd9f7415e8faa374c75 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -6,7 +6,7 @@ sync::Arc, }; -use conduit::{debug, error, info, utils, utils::mutex_map, warn, Error, Result}; +use conduit::{debug, error, info, utils, utils::mutex_map, validated, warn, Error, Result}; use data::Data; use itertools::Itertools; use rand::prelude::SliceRandom; @@ -99,7 +99,7 @@ fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } impl Service { - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>> { self.all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id)? .next() @@ -107,7 +107,7 @@ pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent> .transpose() } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn latest_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>> { self.all_pdus(user_id!("@placeholder:conduwuit.placeholder"), room_id)? .last() @@ -115,7 +115,7 @@ pub fn latest_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent .transpose() } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> { self.db.last_timeline_count(sender_user, room_id) } @@ -213,7 +213,7 @@ pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJson } /// Removes a pdu and creates a new one with the same id. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> { self.db.replace_pdu(pdu_id, pdu_json, pdu) } @@ -670,7 +670,7 @@ pub fn create_hash_and_sign_event( .filter_map(|event_id| Some(self.get_pdu(event_id).ok()??.depth)) .max() .unwrap_or_else(|| uint!(0)) - + uint!(1); + .saturating_add(uint!(1)); let mut unsigned = unsigned.unwrap_or_default(); @@ -1030,7 +1030,7 @@ pub fn all_pdus<'a>( /// Returns an iterator over all events and their tokens in a room that /// happened before the event with id `until` in reverse-chronological /// order. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn pdus_until<'a>( &'a self, user_id: &UserId, room_id: &RoomId, until: PduCount, ) -> Result<impl Iterator<Item = Result<(PduCount, PduEvent)>> + 'a> { @@ -1039,7 +1039,7 @@ pub fn pdus_until<'a>( /// Returns an iterator over all events and their token in a room that /// happened after the event with id `from` in chronological order. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn pdus_after<'a>( &'a self, user_id: &UserId, room_id: &RoomId, from: PduCount, ) -> Result<impl Iterator<Item = Result<(PduCount, PduEvent)>> + 'a> { @@ -1240,10 +1240,11 @@ pub async fn backfill_pdu( let insert_lock = services().globals.roomid_mutex_insert.lock(&room_id).await; + let max = u64::MAX; let count = services().globals.next_count()?; let mut pdu_id = shortroomid.to_be_bytes().to_vec(); pdu_id.extend_from_slice(&0_u64.to_be_bytes()); - pdu_id.extend_from_slice(&(u64::MAX - count).to_be_bytes()); + pdu_id.extend_from_slice(&(validated!(max - count)?).to_be_bytes()); // Insert pdu self.db.prepend_backfill_pdu(&pdu_id, &event_id, &value)?; diff --git a/src/service/sending/data.rs b/src/service/sending/data.rs index ed57981710b9cff9dcc32a3910980a848e70607c..6572561857d4c358e9fb681a493c3527cdc9933d 100644 --- a/src/service/sending/data.rs +++ b/src/service/sending/data.rs @@ -136,7 +136,7 @@ pub fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64> { } } -#[tracing::instrument(skip(key))] +#[tracing::instrument(skip(key), level = "debug")] fn parse_servercurrentevent(key: &[u8], value: Vec<u8>) -> Result<(Destination, SendingEvent)> { // Appservices start with a plus Ok::<_, Error>(if key.starts_with(b"+") { diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index 0dcebcecfa82fc02e92d8289ebf24d03523e9ad2..d7a9c0fc4a10c2cbca137033de546caeeaef473e 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -93,7 +93,7 @@ fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } impl Service { - #[tracing::instrument(skip(self, pdu_id, user, pushkey))] + #[tracing::instrument(skip(self, pdu_id, user, pushkey), level = "debug")] pub fn send_pdu_push(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> { let dest = Destination::Push(user.to_owned(), pushkey); let event = SendingEvent::Pdu(pdu_id.to_owned()); @@ -106,7 +106,7 @@ pub fn send_pdu_push(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Re }) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec<u8>) -> Result<()> { let dest = Destination::Appservice(appservice_id); let event = SendingEvent::Pdu(pdu_id); @@ -119,7 +119,7 @@ pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec<u8>) -> Res }) } - #[tracing::instrument(skip(self, room_id, pdu_id))] + #[tracing::instrument(skip(self, room_id, pdu_id), level = "debug")] pub fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<()> { let servers = services() .rooms @@ -131,7 +131,7 @@ pub fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<()> { self.send_pdu_servers(servers, pdu_id) } - #[tracing::instrument(skip(self, servers, pdu_id))] + #[tracing::instrument(skip(self, servers, pdu_id), level = "debug")] pub fn send_pdu_servers<I: Iterator<Item = OwnedServerName>>(&self, servers: I, pdu_id: &[u8]) -> Result<()> { let requests = servers .into_iter() @@ -155,7 +155,7 @@ pub fn send_pdu_servers<I: Iterator<Item = OwnedServerName>>(&self, servers: I, Ok(()) } - #[tracing::instrument(skip(self, server, serialized))] + #[tracing::instrument(skip(self, server, serialized), level = "debug")] pub fn send_edu_server(&self, server: &ServerName, serialized: Vec<u8>) -> Result<()> { let dest = Destination::Normal(server.to_owned()); let event = SendingEvent::Edu(serialized); @@ -168,7 +168,7 @@ pub fn send_edu_server(&self, server: &ServerName, serialized: Vec<u8>) -> Resul }) } - #[tracing::instrument(skip(self, room_id, serialized))] + #[tracing::instrument(skip(self, room_id, serialized), level = "debug")] pub fn send_edu_room(&self, room_id: &RoomId, serialized: Vec<u8>) -> Result<()> { let servers = services() .rooms @@ -180,7 +180,7 @@ pub fn send_edu_room(&self, room_id: &RoomId, serialized: Vec<u8>) -> Result<()> self.send_edu_servers(servers, serialized) } - #[tracing::instrument(skip(self, servers, serialized))] + #[tracing::instrument(skip(self, servers, serialized), level = "debug")] pub fn send_edu_servers<I: Iterator<Item = OwnedServerName>>(&self, servers: I, serialized: Vec<u8>) -> Result<()> { let requests = servers .into_iter() @@ -205,7 +205,7 @@ pub fn send_edu_servers<I: Iterator<Item = OwnedServerName>>(&self, servers: I, Ok(()) } - #[tracing::instrument(skip(self, room_id))] + #[tracing::instrument(skip(self, room_id), level = "debug")] pub fn flush_room(&self, room_id: &RoomId) -> Result<()> { let servers = services() .rooms @@ -217,7 +217,7 @@ pub fn flush_room(&self, room_id: &RoomId) -> Result<()> { self.flush_servers(servers) } - #[tracing::instrument(skip(self, servers))] + #[tracing::instrument(skip(self, servers), level = "debug")] pub fn flush_servers<I: Iterator<Item = OwnedServerName>>(&self, servers: I) -> Result<()> { let requests = servers.into_iter().map(Destination::Normal); for dest in requests { @@ -255,7 +255,7 @@ pub async fn send_appservice_request<T>( /// Cleanup event data /// Used for instance after we remove an appservice registration - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn cleanup_events(&self, appservice_id: String) -> Result<()> { self.db .delete_all_requests_for(&Destination::Appservice(appservice_id))?; @@ -271,29 +271,47 @@ fn dispatch(&self, msg: Msg) -> Result<()> { } impl Destination { - #[tracing::instrument(skip(self))] + #[must_use] pub fn get_prefix(&self) -> Vec<u8> { - let mut prefix = match self { + match self { + Self::Normal(server) => { + let len = server.as_bytes().len().saturating_add(1); + + let mut p = Vec::with_capacity(len); + p.extend_from_slice(server.as_bytes()); + p.push(0xFF); + p + }, Self::Appservice(server) => { - let mut p = b"+".to_vec(); + let sigil = b"+"; + let len = sigil + .len() + .saturating_add(server.as_bytes().len()) + .saturating_add(1); + + let mut p = Vec::with_capacity(len); + p.extend_from_slice(sigil); p.extend_from_slice(server.as_bytes()); + p.push(0xFF); p }, Self::Push(user, pushkey) => { - let mut p = b"$".to_vec(); + let sigil = b"$"; + let len = sigil + .len() + .saturating_add(user.as_bytes().len()) + .saturating_add(1) + .saturating_add(pushkey.as_bytes().len()) + .saturating_add(1); + + let mut p = Vec::with_capacity(len); + p.extend_from_slice(sigil); p.extend_from_slice(user.as_bytes()); p.push(0xFF); p.extend_from_slice(pushkey.as_bytes()); + p.push(0xFF); p }, - Self::Normal(server) => { - let mut p = Vec::new(); - p.extend_from_slice(server.as_bytes()); - p - }, - }; - prefix.push(0xFF); - - prefix + } } } diff --git a/src/service/sending/resolve.rs b/src/service/sending/resolve.rs index 773110060ae1f5045d15d0878d4569c3248f6e38..d38509ba8d0cdbafa1af8a10db2fe4229a5373f4 100644 --- a/src/service/sending/resolve.rs +++ b/src/service/sending/resolve.rs @@ -484,6 +484,7 @@ fn hostname(&self) -> String { } #[inline] + #[allow(clippy::string_slice)] fn port(&self) -> Option<u16> { match &self { Self::Literal(addr) => Some(addr.port()), diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index 54302fd5b5299a272d7b0145f6a2dcba5b8bb018..0fb0d9dc8576207f63fe0b74758dcb0879c1ded1 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -3,10 +3,11 @@ collections::{BTreeMap, HashMap, HashSet}, fmt::Debug, sync::Arc, - time::{Duration, Instant}, + time::Instant, }; use base64::{engine::general_purpose, Engine as _}; +use conduit::{debug, error, utils::math::continue_exponential_backoff_secs, warn}; use federation::transactions::send_transaction_message; use futures_util::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; use ruma::{ @@ -22,7 +23,6 @@ ServerName, UInt, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; -use tracing::{debug, error, warn}; use super::{appservice, send, Destination, Msg, SendingEvent, Service}; use crate::{presence::Presence, services, user_is_local, utils::calculate_hash, Error, Result}; @@ -93,7 +93,7 @@ fn handle_response_err( statuses.entry(dest).and_modify(|e| { *e = match e { TransactionStatus::Running => TransactionStatus::Failed(1, Instant::now()), - TransactionStatus::Retrying(n) => TransactionStatus::Failed(*n + 1, Instant::now()), + TransactionStatus::Retrying(ref n) => TransactionStatus::Failed(n.saturating_add(1), Instant::now()), TransactionStatus::Failed(..) => panic!("Request that was not even running failed?!"), } }); @@ -216,11 +216,9 @@ fn select_events_current(&self, dest: Destination, statuses: &mut CurTransaction .and_modify(|e| match e { TransactionStatus::Failed(tries, time) => { // Fail if a request has failed recently (exponential backoff) - let max_duration = Duration::from_secs(services().globals.config.sender_retry_backoff_limit); - let min_duration = Duration::from_secs(services().globals.config.sender_timeout); - let min_elapsed_duration = min_duration * (*tries) * (*tries); - let min_elapsed_duration = cmp::min(min_elapsed_duration, max_duration); - if time.elapsed() < min_elapsed_duration { + let min = services().globals.config.sender_timeout; + let max = services().globals.config.sender_retry_backoff_limit; + if continue_exponential_backoff_secs(min, max, time.elapsed(), *tries) { allow = false; } else { retry = true; diff --git a/src/service/services.rs b/src/service/services.rs index 0ba866936bf7eadbca9add45a871599be0a1aa1a..aeed82043a113ad3ad5129defe69e80dedf144e2 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -132,11 +132,7 @@ pub async fn start(&self) -> Result<()> { if self.globals.allow_check_for_updates() { let handle = globals::updates::start_check_for_updates_task(); - - #[allow(clippy::let_underscore_must_use)] // needed for shutdown - { - _ = self.globals.updates_handle.lock().await.insert(handle); - } + _ = self.globals.updates_handle.lock().await.insert(handle); } debug_info!("Services startup complete."); @@ -159,11 +155,7 @@ pub async fn stop(&self) { debug!("Waiting for update worker..."); if let Some(updates_handle) = self.globals.updates_handle.lock().await.take() { updates_handle.abort(); - - #[allow(clippy::let_underscore_must_use)] - { - _ = updates_handle.await; - } + _ = updates_handle.await; } for (name, service) in &self.service { diff --git a/src/service/users/data.rs b/src/service/users/data.rs index 3ba54a935e2f3b9323f63bd04573b13eddb476ab..302eae9db1a6cc86f7e1fce0afa89dd96fbb6e4d 100644 --- a/src/service/users/data.rs +++ b/src/service/users/data.rs @@ -463,7 +463,8 @@ pub(super) fn count_one_time_keys( .algorithm(), ) }) { - *counts.entry(algorithm?).or_default() += uint!(1); + let count: &mut UInt = counts.entry(algorithm?).or_default(); + *count = count.saturating_add(uint!(1)); } Ok(counts) @@ -814,7 +815,7 @@ pub(super) fn remove_to_device_events(&self, user_id: &UserId, device_id: &Devic .map(|(key, _)| { Ok::<_, Error>(( key.clone(), - utils::u64_from_bytes(&key[key.len() - size_of::<u64>()..key.len()]) + utils::u64_from_bytes(&key[key.len().saturating_sub(size_of::<u64>())..key.len()]) .map_err(|_| Error::bad_database("ToDeviceId has invalid count bytes."))?, )) }) @@ -928,10 +929,12 @@ pub(super) fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<Opt /// Creates an OpenID token, which can be used to prove that a user has /// access to an account (primarily for integrations) pub(super) fn create_openid_token(&self, user_id: &UserId, token: &str) -> Result<u64> { + use std::num::Saturating as Sat; + let expires_in = services().globals.config.openid_token_ttl; - let expires_at = utils::millis_since_unix_epoch().saturating_add(expires_in * 1000); + let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in) * Sat(1000); - let mut value = expires_at.to_be_bytes().to_vec(); + let mut value = expires_at.0.to_be_bytes().to_vec(); value.extend_from_slice(user_id.as_bytes()); self.openidtoken_expiresatuserid