From 463f1a12875d909afd901de9cabb038793e403e7 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Tue, 16 Jul 2024 06:49:47 +0000 Subject: [PATCH 01/47] only use graceful shutdown w/ axum-server fixed Signed-off-by: Jason Volk <jason@zemos.net> --- src/router/run.rs | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/router/run.rs b/src/router/run.rs index 91507772d..cb5d2abf3 100644 --- a/src/router/run.rs +++ b/src/router/run.rs @@ -12,7 +12,7 @@ use std::sync::atomic::Ordering; -use conduit::{debug, debug_info, error, info, trace, Error, Result, Server}; +use conduit::{debug, debug_info, error, info, Error, Result, Server}; use crate::serve; @@ -103,20 +103,20 @@ async fn signal(server: Arc<Server>, tx: Sender<()>, handle: axum_server::Handle } async fn handle_shutdown(server: &Arc<Server>, tx: &Sender<()>, handle: &axum_server::Handle, sig: &str) { - debug!("Received signal {}", sig); + debug!("Received signal {sig}"); if let Err(e) = tx.send(()) { error!("failed sending shutdown transaction to channel: {e}"); } - let pending = server.metrics.requests_spawn_active.load(Ordering::Relaxed); - if pending > 0 { - let timeout = Duration::from_secs(36); - trace!(pending, ?timeout, "Notifying for graceful shutdown"); - handle.graceful_shutdown(Some(timeout)); - } else { - debug!(pending, "Notifying for immediate shutdown"); - handle.shutdown(); - } + let timeout = Duration::from_secs(36); + debug!( + ?timeout, + spawn_active = ?server.metrics.requests_spawn_active.load(Ordering::Relaxed), + handle_active = ?server.metrics.requests_handle_active.load(Ordering::Relaxed), + "Notifying for graceful shutdown" + ); + + handle.graceful_shutdown(Some(timeout)); } async fn handle_services_poll( -- GitLab From 8b6018d77da5a0caf5800410354b07314a836a3f Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Tue, 16 Jul 2024 08:05:25 +0000 Subject: [PATCH 02/47] de-global services() from api Signed-off-by: Jason 
Volk <jason@zemos.net> --- src/admin/debug/commands.rs | 4 +- src/admin/room/room_moderation_commands.rs | 8 +- src/admin/user/commands.rs | 15 +- src/api/client/account.rs | 129 +++---- src/api/client/alias.rs | 38 +- src/api/client/backup.rs | 111 +++--- src/api/client/capabilities.rs | 11 +- src/api/client/config.rs | 29 +- src/api/client/context.rs | 39 +- src/api/client/device.rs | 43 ++- src/api/client/directory.rs | 75 ++-- src/api/client/filter.rs | 15 +- src/api/client/keys.rs | 86 +++-- src/api/client/media.rs | 143 ++++---- src/api/client/membership.rs | 398 +++++++++++---------- src/api/client/message.rs | 57 +-- src/api/client/openid.rs | 9 +- src/api/client/presence.rs | 22 +- src/api/client/profile.rs | 113 +++--- src/api/client/push.rs | 61 ++-- src/api/client/read_marker.rs | 27 +- src/api/client/redact.rs | 11 +- src/api/client/relations.rs | 90 +++-- src/api/client/report.rs | 30 +- src/api/client/room.rs | 139 +++---- src/api/client/search.rs | 31 +- src/api/client/session.rs | 45 ++- src/api/client/space.rs | 9 +- src/api/client/state.rs | 47 +-- src/api/client/sync.rs | 281 ++++++++------- src/api/client/tag.rs | 25 +- src/api/client/threads.rs | 11 +- src/api/client/to_device.rs | 19 +- src/api/client/typing.rs | 15 +- src/api/client/unstable.rs | 52 +-- src/api/client/unversioned.rs | 30 +- src/api/client/user_directory.rs | 19 +- src/api/client/voip.rs | 19 +- src/api/router/args.rs | 23 +- src/api/router/auth.rs | 36 +- src/api/router/request.rs | 9 +- src/api/server/backfill.rs | 23 +- src/api/server/event.rs | 20 +- src/api/server/event_auth.rs | 17 +- src/api/server/get_missing_events.rs | 15 +- src/api/server/hierarchy.rs | 11 +- src/api/server/invite.rs | 30 +- src/api/server/key.rs | 20 +- src/api/server/make_join.rs | 47 +-- src/api/server/make_leave.rs | 15 +- src/api/server/openid.rs | 9 +- src/api/server/publicrooms.rs | 15 +- src/api/server/query.rs | 32 +- src/api/server/send.rs | 2 +- src/api/server/send_join.rs | 53 +-- 
src/api/server/send_leave.rs | 35 +- src/api/server/state.rs | 30 +- src/api/server/state_ids.rs | 17 +- src/api/server/user.rs | 32 +- src/api/server/well_known.rs | 7 +- src/service/globals/mod.rs | 1 + 61 files changed, 1485 insertions(+), 1320 deletions(-) diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index 46f716221..7c7f93311 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -530,7 +530,7 @@ pub(super) async fn force_set_room_state_from_server( for result in remote_state_response .pdus .iter() - .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map)) + .map(|pdu| validate_and_add_event_id(services(), pdu, &room_version, &pub_key_map)) { let Ok((event_id, value)) = result.await else { continue; @@ -558,7 +558,7 @@ pub(super) async fn force_set_room_state_from_server( for result in remote_state_response .auth_chain .iter() - .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map)) + .map(|pdu| validate_and_add_event_id(services(), pdu, &room_version, &pub_key_map)) { let Ok((event_id, value)) = result.await else { continue; diff --git a/src/admin/room/room_moderation_commands.rs b/src/admin/room/room_moderation_commands.rs index 30c30c6e6..c9be44737 100644 --- a/src/admin/room/room_moderation_commands.rs +++ b/src/admin/room/room_moderation_commands.rs @@ -128,7 +128,7 @@ async fn ban_room( &local_user, &room_id ); - if let Err(e) = leave_room(&local_user, &room_id, None).await { + if let Err(e) = leave_room(services(), &local_user, &room_id, None).await { warn!(%e, "Failed to leave room"); } } @@ -151,7 +151,7 @@ async fn ban_room( }) }) { debug!("Attempting leave for user {} in room {}", &local_user, &room_id); - if let Err(e) = leave_room(&local_user, &room_id, None).await { + if let Err(e) = leave_room(services(), &local_user, &room_id, None).await { error!( "Error attempting to make local user {} leave room {} during room banning: {}", &local_user, &room_id, e @@ -334,7 
+334,7 @@ async fn ban_list_of_rooms(body: Vec<&str>, force: bool, disable_federation: boo "Attempting leave for user {} in room {} (forced, ignoring all errors, evicting admins too)", &local_user, room_id ); - if let Err(e) = leave_room(&local_user, &room_id, None).await { + if let Err(e) = leave_room(services(), &local_user, &room_id, None).await { warn!(%e, "Failed to leave room"); } } @@ -358,7 +358,7 @@ async fn ban_list_of_rooms(body: Vec<&str>, force: bool, disable_federation: boo }) }) { debug!("Attempting leave for user {} in room {}", &local_user, &room_id); - if let Err(e) = leave_room(&local_user, &room_id, None).await { + if let Err(e) = leave_room(services(), &local_user, &room_id, None).await { error!( "Error attempting to make local user {} leave room {} during bulk room banning: {}", &local_user, &room_id, e diff --git a/src/admin/user/commands.rs b/src/admin/user/commands.rs index 709e1f579..f6387ee82 100644 --- a/src/admin/user/commands.rs +++ b/src/admin/user/commands.rs @@ -101,6 +101,7 @@ pub(super) async fn create( if let Some(room_id_server_name) = room.server_name() { match join_room_by_id_helper( + services(), &user_id, room, Some("Automatically joining this room upon registration".to_owned()), @@ -158,9 +159,9 @@ pub(super) async fn deactivate( .rooms_joined(&user_id) .filter_map(Result::ok) .collect(); - update_displayname(user_id.clone(), None, all_joined_rooms.clone()).await?; - update_avatar_url(user_id.clone(), None, None, all_joined_rooms).await?; - leave_all_rooms(&user_id).await; + update_displayname(services(), user_id.clone(), None, all_joined_rooms.clone()).await?; + update_avatar_url(services(), user_id.clone(), None, None, all_joined_rooms).await?; + leave_all_rooms(services(), &user_id).await; } Ok(RoomMessageEventContent::text_plain(format!( @@ -262,9 +263,9 @@ pub(super) async fn deactivate_all( .rooms_joined(&user_id) .filter_map(Result::ok) .collect(); - update_displayname(user_id.clone(), None, 
all_joined_rooms.clone()).await?; - update_avatar_url(user_id.clone(), None, None, all_joined_rooms).await?; - leave_all_rooms(&user_id).await; + update_displayname(services(), user_id.clone(), None, all_joined_rooms.clone()).await?; + update_avatar_url(services(), user_id.clone(), None, None, all_joined_rooms).await?; + leave_all_rooms(services(), &user_id).await; } }, Err(e) => { @@ -347,7 +348,7 @@ pub(super) async fn force_join_room( let room_id = services().rooms.alias.resolve(&room_id).await?; assert!(service::user_is_local(&user_id), "Parsed user_id must be a local user"); - join_room_by_id_helper(&user_id, &room_id, None, &[], None).await?; + join_room_by_id_helper(services(), &user_id, &room_id, None, &[], None).await?; Ok(RoomMessageEventContent::notice_markdown(format!( "{user_id} has been joined to {room_id}.", diff --git a/src/api/client/account.rs b/src/api/client/account.rs index d34211bf1..19ac89a08 100644 --- a/src/api/client/account.rs +++ b/src/api/client/account.rs @@ -1,5 +1,6 @@ use std::fmt::Write; +use axum::extract::State; use axum_client_ip::InsecureClientIp; use conduit::debug_info; use register::RegistrationKind; @@ -22,7 +23,6 @@ use super::{join_room_by_id_helper, DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; use crate::{ service::user_is_local, - services, utils::{self}, Error, Result, Ruma, }; @@ -42,20 +42,21 @@ /// invalid when trying to register #[tracing::instrument(skip_all, fields(%client), name = "register_available")] pub(crate) async fn get_register_available_route( - InsecureClientIp(client): InsecureClientIp, body: Ruma<get_username_availability::v3::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: Ruma<get_username_availability::v3::Request>, ) -> Result<get_username_availability::v3::Response> { // Validate user id - let user_id = UserId::parse_with_server_name(body.username.to_lowercase(), services().globals.server_name()) + let user_id = 
UserId::parse_with_server_name(body.username.to_lowercase(), services.globals.server_name()) .ok() .filter(|user_id| !user_id.is_historical() && user_is_local(user_id)) .ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; // Check if username is creative enough - if services().users.exists(&user_id)? { + if services.users.exists(&user_id)? { return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken.")); } - if services() + if services .globals .forbidden_usernames() .is_match(user_id.localpart()) @@ -91,9 +92,9 @@ pub(crate) async fn get_register_available_route( #[allow(clippy::doc_markdown)] #[tracing::instrument(skip_all, fields(%client), name = "register")] pub(crate) async fn register_route( - InsecureClientIp(client): InsecureClientIp, body: Ruma<register::v3::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, body: Ruma<register::v3::Request>, ) -> Result<register::v3::Response> { - if !services().globals.allow_registration() && body.appservice_info.is_none() { + if !services.globals.allow_registration() && body.appservice_info.is_none() { info!( "Registration disabled and request not from known appservice, rejecting registration attempt for username \ {:?}", @@ -105,8 +106,8 @@ pub(crate) async fn register_route( let is_guest = body.kind == RegistrationKind::Guest; if is_guest - && (!services().globals.allow_guest_registration() - || (services().globals.allow_registration() && services().globals.config.registration_token.is_some())) + && (!services.globals.allow_guest_registration() + || (services.globals.allow_registration() && services.globals.config.registration_token.is_some())) { info!( "Guest registration disabled / registration enabled with token configured, rejecting guest registration \ @@ -121,7 +122,7 @@ pub(crate) async fn register_route( // forbid guests from registering if there is not a real admin user yet. give // generic user error. 
- if is_guest && services().users.count()? < 2 { + if is_guest && services.users.count()? < 2 { warn!( "Guest account attempted to register before a real admin user has been registered, rejecting \ registration. Guest's initial device name: {:?}", @@ -133,16 +134,16 @@ pub(crate) async fn register_route( let user_id = match (&body.username, is_guest) { (Some(username), false) => { let proposed_user_id = - UserId::parse_with_server_name(username.to_lowercase(), services().globals.server_name()) + UserId::parse_with_server_name(username.to_lowercase(), services.globals.server_name()) .ok() .filter(|user_id| !user_id.is_historical() && user_is_local(user_id)) .ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; - if services().users.exists(&proposed_user_id)? { + if services.users.exists(&proposed_user_id)? { return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken.")); } - if services() + if services .globals .forbidden_usernames() .is_match(proposed_user_id.localpart()) @@ -155,10 +156,10 @@ pub(crate) async fn register_route( _ => loop { let proposed_user_id = UserId::parse_with_server_name( utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(), - services().globals.server_name(), + services.globals.server_name(), ) .unwrap(); - if !services().users.exists(&proposed_user_id)? { + if !services.users.exists(&proposed_user_id)? 
{ break proposed_user_id; } }, @@ -172,13 +173,13 @@ pub(crate) async fn register_route( } else { return Err(Error::BadRequest(ErrorKind::MissingToken, "Missing appservice token.")); } - } else if services().appservice.is_exclusive_user_id(&user_id).await { + } else if services.appservice.is_exclusive_user_id(&user_id).await { return Err(Error::BadRequest(ErrorKind::Exclusive, "User ID reserved by appservice.")); } // UIAA let mut uiaainfo; - let skip_auth = if services().globals.config.registration_token.is_some() { + let skip_auth = if services.globals.config.registration_token.is_some() { // Registration token required uiaainfo = UiaaInfo { flows: vec![AuthFlow { @@ -206,8 +207,8 @@ pub(crate) async fn register_route( if !skip_auth { if let Some(auth) = &body.auth { - let (worked, uiaainfo) = services().uiaa.try_auth( - &UserId::parse_with_server_name("", services().globals.server_name()).expect("we know this is valid"), + let (worked, uiaainfo) = services.uiaa.try_auth( + &UserId::parse_with_server_name("", services.globals.server_name()).expect("we know this is valid"), "".into(), auth, &uiaainfo, @@ -218,8 +219,8 @@ pub(crate) async fn register_route( // Success! 
} else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - services().uiaa.create( - &UserId::parse_with_server_name("", services().globals.server_name()).expect("we know this is valid"), + services.uiaa.create( + &UserId::parse_with_server_name("", services.globals.server_name()).expect("we know this is valid"), "".into(), &uiaainfo, &json, @@ -237,25 +238,25 @@ pub(crate) async fn register_route( }; // Create user - services().users.create(&user_id, password)?; + services.users.create(&user_id, password)?; // Default to pretty displayname let mut displayname = user_id.localpart().to_owned(); // If `new_user_displayname_suffix` is set, registration will push whatever // content is set to the user's display name with a space before it - if !services().globals.new_user_displayname_suffix().is_empty() { - write!(displayname, " {}", services().globals.config.new_user_displayname_suffix) + if !services.globals.new_user_displayname_suffix().is_empty() { + write!(displayname, " {}", services.globals.config.new_user_displayname_suffix) .expect("should be able to write to string buffer"); } - services() + services .users .set_displayname(&user_id, Some(displayname.clone())) .await?; // Initial account data - services().account_data.update( + services.account_data.update( None, &user_id, GlobalAccountDataEventType::PushRules.to_string().into(), @@ -290,7 +291,7 @@ pub(crate) async fn register_route( let token = utils::random_string(TOKEN_LENGTH); // Create device for this account - services() + services .users .create_device(&user_id, &device_id, &token, body.initial_device_display_name.clone())?; @@ -299,7 +300,7 @@ pub(crate) async fn register_route( // log in conduit admin channel if a non-guest user registered if body.appservice_info.is_none() && !is_guest { info!("New user \"{user_id}\" registered on this server."); - services() + services .admin .send_message(RoomMessageEventContent::notice_plain(format!( "New user 
\"{user_id}\" registered on this server from IP {client}." @@ -308,7 +309,7 @@ pub(crate) async fn register_route( } // log in conduit admin channel if a guest registered - if body.appservice_info.is_none() && is_guest && services().globals.log_guest_registrations() { + if body.appservice_info.is_none() && is_guest && services.globals.log_guest_registrations() { info!("New guest user \"{user_id}\" registered on this server."); if let Some(device_display_name) = &body.initial_device_display_name { @@ -317,7 +318,7 @@ pub(crate) async fn register_route( .as_ref() .is_some_and(|device_display_name| !device_display_name.is_empty()) { - services() + services .admin .send_message(RoomMessageEventContent::notice_plain(format!( "Guest user \"{user_id}\" with device display name `{device_display_name}` registered on this \ @@ -325,7 +326,7 @@ pub(crate) async fn register_route( ))) .await; } else { - services() + services .admin .send_message(RoomMessageEventContent::notice_plain(format!( "Guest user \"{user_id}\" with no device display name registered on this server from IP \ @@ -334,7 +335,7 @@ pub(crate) async fn register_route( .await; } } else { - services() + services .admin .send_message(RoomMessageEventContent::notice_plain(format!( "Guest user \"{user_id}\" with no device display name registered on this server from IP {client}.", @@ -347,12 +348,7 @@ pub(crate) async fn register_route( // users Note: the server user, @conduit:servername, is generated first if !is_guest { if let Some(admin_room) = service::admin::Service::get_admin_room()? { - if services() - .rooms - .state_cache - .room_joined_count(&admin_room)? - == Some(1) - { + if services.rooms.state_cache.room_joined_count(&admin_room)? 
== Some(1) { service::admin::make_user_admin(&user_id, displayname).await?; warn!("Granting {user_id} admin privileges as the first user"); @@ -361,14 +357,14 @@ pub(crate) async fn register_route( } if body.appservice_info.is_none() - && !services().globals.config.auto_join_rooms.is_empty() - && (services().globals.allow_guests_auto_join_rooms() || !is_guest) + && !services.globals.config.auto_join_rooms.is_empty() + && (services.globals.allow_guests_auto_join_rooms() || !is_guest) { - for room in &services().globals.config.auto_join_rooms { - if !services() + for room in &services.globals.config.auto_join_rooms { + if !services .rooms .state_cache - .server_in_room(services().globals.server_name(), room)? + .server_in_room(services.globals.server_name(), room)? { warn!("Skipping room {room} to automatically join as we have never joined before."); continue; @@ -376,10 +372,11 @@ pub(crate) async fn register_route( if let Some(room_id_server_name) = room.server_name() { if let Err(e) = join_room_by_id_helper( + services, &user_id, room, Some("Automatically joining this room upon registration".to_owned()), - &[room_id_server_name.to_owned(), services().globals.server_name().to_owned()], + &[room_id_server_name.to_owned(), services.globals.server_name().to_owned()], None, ) .await @@ -421,7 +418,8 @@ pub(crate) async fn register_route( /// - Triggers device list updates #[tracing::instrument(skip_all, fields(%client), name = "change_password")] pub(crate) async fn change_password_route( - InsecureClientIp(client): InsecureClientIp, body: Ruma<change_password::v3::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: Ruma<change_password::v3::Request>, ) -> Result<change_password::v3::Response> { // Authentication for this endpoint was made optional, but we need // authentication currently @@ -442,7 +440,7 @@ pub(crate) async fn change_password_route( }; if let Some(auth) = &body.auth { - let (worked, uiaainfo) = 
services() + let (worked, uiaainfo) = services .uiaa .try_auth(sender_user, sender_device, auth, &uiaainfo)?; if !worked { @@ -451,7 +449,7 @@ pub(crate) async fn change_password_route( // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - services() + services .uiaa .create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); @@ -459,24 +457,24 @@ pub(crate) async fn change_password_route( return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } - services() + services .users .set_password(sender_user, Some(&body.new_password))?; if body.logout_devices { // Logout all devices except the current one - for id in services() + for id in services .users .all_device_ids(sender_user) .filter_map(Result::ok) .filter(|id| id != sender_device) { - services().users.remove_device(sender_user, &id)?; + services.users.remove_device(sender_user, &id)?; } } info!("User {sender_user} changed their password."); - services() + services .admin .send_message(RoomMessageEventContent::notice_plain(format!( "User {sender_user} changed their password." @@ -491,14 +489,16 @@ pub(crate) async fn change_password_route( /// Get `user_id` of the sender user. /// /// Note: Also works for Application Services -pub(crate) async fn whoami_route(body: Ruma<whoami::v3::Request>) -> Result<whoami::v3::Response> { +pub(crate) async fn whoami_route( + State(services): State<crate::State>, body: Ruma<whoami::v3::Request>, +) -> Result<whoami::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let device_id = body.sender_device.clone(); Ok(whoami::v3::Response { user_id: sender_user.clone(), device_id, - is_guest: services().users.is_deactivated(sender_user)? && body.appservice_info.is_none(), + is_guest: services.users.is_deactivated(sender_user)? 
&& body.appservice_info.is_none(), }) } @@ -515,7 +515,8 @@ pub(crate) async fn whoami_route(body: Ruma<whoami::v3::Request>) -> Result<whoa /// - Removes ability to log in again #[tracing::instrument(skip_all, fields(%client), name = "deactivate")] pub(crate) async fn deactivate_route( - InsecureClientIp(client): InsecureClientIp, body: Ruma<deactivate::v3::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: Ruma<deactivate::v3::Request>, ) -> Result<deactivate::v3::Response> { // Authentication for this endpoint was made optional, but we need // authentication currently @@ -536,7 +537,7 @@ pub(crate) async fn deactivate_route( }; if let Some(auth) = &body.auth { - let (worked, uiaainfo) = services() + let (worked, uiaainfo) = services .uiaa .try_auth(sender_user, sender_device, auth, &uiaainfo)?; if !worked { @@ -545,7 +546,7 @@ pub(crate) async fn deactivate_route( // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - services() + services .uiaa .create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); @@ -554,23 +555,23 @@ pub(crate) async fn deactivate_route( } // Remove devices and mark account as deactivated - services().users.deactivate_account(sender_user)?; + services.users.deactivate_account(sender_user)?; // Remove profile pictures and display name - let all_joined_rooms: Vec<OwnedRoomId> = services() + let all_joined_rooms: Vec<OwnedRoomId> = services .rooms .state_cache .rooms_joined(sender_user) .filter_map(Result::ok) .collect(); - super::update_displayname(sender_user.clone(), None, all_joined_rooms.clone()).await?; - super::update_avatar_url(sender_user.clone(), None, None, all_joined_rooms).await?; + super::update_displayname(services, sender_user.clone(), None, all_joined_rooms.clone()).await?; + super::update_avatar_url(services, sender_user.clone(), None, None, all_joined_rooms).await?; // 
Make the user leave all rooms before deactivation - super::leave_all_rooms(sender_user).await; + super::leave_all_rooms(services, sender_user).await; info!("User {sender_user} deactivated their account."); - services() + services .admin .send_message(RoomMessageEventContent::notice_plain(format!( "User {sender_user} deactivated their account." @@ -632,9 +633,9 @@ pub(crate) async fn request_3pid_management_token_via_msisdn_route( /// Currently does not have any ratelimiting, and this isn't very practical as /// there is only one registration token allowed. pub(crate) async fn check_registration_token_validity( - body: Ruma<check_registration_token_validity::v1::Request>, + State(services): State<crate::State>, body: Ruma<check_registration_token_validity::v1::Request>, ) -> Result<check_registration_token_validity::v1::Response> { - let Some(reg_token) = services().globals.config.registration_token.clone() else { + let Some(reg_token) = services.globals.config.registration_token.clone() else { return Err(Error::BadRequest( ErrorKind::forbidden(), "Server does not allow token registration.", diff --git a/src/api/client/alias.rs b/src/api/client/alias.rs index 5230c25b9..88d1a4e6b 100644 --- a/src/api/client/alias.rs +++ b/src/api/client/alias.rs @@ -1,3 +1,4 @@ +use axum::extract::State; use rand::seq::SliceRandom; use ruma::{ api::client::{ @@ -8,19 +9,24 @@ }; use tracing::debug; -use crate::{service::server_is_ours, services, Error, Result, Ruma}; +use crate::{ + service::{server_is_ours, Services}, + Error, Result, Ruma, +}; /// # `PUT /_matrix/client/v3/directory/room/{roomAlias}` /// /// Creates a new room alias on this server. 
-pub(crate) async fn create_alias_route(body: Ruma<create_alias::v3::Request>) -> Result<create_alias::v3::Response> { +pub(crate) async fn create_alias_route( + State(services): State<crate::State>, body: Ruma<create_alias::v3::Request>, +) -> Result<create_alias::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); service::rooms::alias::appservice_checks(&body.room_alias, &body.appservice_info).await?; // this isn't apart of alias_checks or delete alias route because we should // allow removing forbidden room aliases - if services() + if services .globals .forbidden_alias_names() .is_match(body.room_alias.alias()) @@ -28,7 +34,7 @@ pub(crate) async fn create_alias_route(body: Ruma<create_alias::v3::Request>) -> return Err(Error::BadRequest(ErrorKind::forbidden(), "Room alias is forbidden.")); } - if services() + if services .rooms .alias .resolve_local_alias(&body.room_alias)? @@ -37,7 +43,7 @@ pub(crate) async fn create_alias_route(body: Ruma<create_alias::v3::Request>) -> return Err(Error::Conflict("Alias already exists.")); } - services() + services .rooms .alias .set_alias(&body.room_alias, &body.room_id, sender_user)?; @@ -50,12 +56,14 @@ pub(crate) async fn create_alias_route(body: Ruma<create_alias::v3::Request>) -> /// Deletes a room alias from this server. /// /// - TODO: Update canonical alias event -pub(crate) async fn delete_alias_route(body: Ruma<delete_alias::v3::Request>) -> Result<delete_alias::v3::Response> { +pub(crate) async fn delete_alias_route( + State(services): State<crate::State>, body: Ruma<delete_alias::v3::Request>, +) -> Result<delete_alias::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); service::rooms::alias::appservice_checks(&body.room_alias, &body.appservice_info).await?; - if services() + if services .rooms .alias .resolve_local_alias(&body.room_alias)? 
@@ -64,7 +72,7 @@ pub(crate) async fn delete_alias_route(body: Ruma<delete_alias::v3::Request>) -> return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist.")); } - services() + services .rooms .alias .remove_alias(&body.room_alias, sender_user) @@ -78,11 +86,13 @@ pub(crate) async fn delete_alias_route(body: Ruma<delete_alias::v3::Request>) -> /// # `GET /_matrix/client/v3/directory/room/{roomAlias}` /// /// Resolve an alias locally or over federation. -pub(crate) async fn get_alias_route(body: Ruma<get_alias::v3::Request>) -> Result<get_alias::v3::Response> { +pub(crate) async fn get_alias_route( + State(services): State<crate::State>, body: Ruma<get_alias::v3::Request>, +) -> Result<get_alias::v3::Response> { let room_alias = body.body.room_alias; let servers = None; - let Ok((room_id, pre_servers)) = services() + let Ok((room_id, pre_servers)) = services .rooms .alias .resolve_alias(&room_alias, servers.as_ref()) @@ -91,17 +101,17 @@ pub(crate) async fn get_alias_route(body: Ruma<get_alias::v3::Request>) -> Resul return Err(Error::BadRequest(ErrorKind::NotFound, "Room with alias not found.")); }; - let servers = room_available_servers(&room_id, &room_alias, &pre_servers); + let servers = room_available_servers(services, &room_id, &room_alias, &pre_servers); debug!(?room_alias, ?room_id, "available servers: {servers:?}"); Ok(get_alias::v3::Response::new(room_id, servers)) } fn room_available_servers( - room_id: &RoomId, room_alias: &RoomAliasId, pre_servers: &Option<Vec<OwnedServerName>>, + services: &Services, room_id: &RoomId, room_alias: &RoomAliasId, pre_servers: &Option<Vec<OwnedServerName>>, ) -> Vec<OwnedServerName> { // find active servers in room state cache to suggest - let mut servers: Vec<OwnedServerName> = services() + let mut servers: Vec<OwnedServerName> = services .rooms .state_cache .room_servers(room_id) @@ -127,7 +137,7 @@ fn room_available_servers( .position(|server_name| server_is_ours(server_name)) { 
servers.swap_remove(server_index); - servers.insert(0, services().globals.server_name().to_owned()); + servers.insert(0, services.globals.server_name().to_owned()); } else if let Some(alias_server_index) = servers .iter() .position(|server| server == room_alias.server_name()) diff --git a/src/api/client/backup.rs b/src/api/client/backup.rs index fb7e7f31c..4ead87776 100644 --- a/src/api/client/backup.rs +++ b/src/api/client/backup.rs @@ -1,3 +1,4 @@ +use axum::extract::State; use ruma::{ api::client::{ backup::{ @@ -11,16 +12,16 @@ UInt, }; -use crate::{services, Error, Result, Ruma}; +use crate::{Error, Result, Ruma}; /// # `POST /_matrix/client/r0/room_keys/version` /// /// Creates a new backup. pub(crate) async fn create_backup_version_route( - body: Ruma<create_backup_version::v3::Request>, + State(services): State<crate::State>, body: Ruma<create_backup_version::v3::Request>, ) -> Result<create_backup_version::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let version = services() + let version = services .key_backups .create_backup(sender_user, &body.algorithm)?; @@ -34,10 +35,10 @@ pub(crate) async fn create_backup_version_route( /// Update information about an existing backup. Only `auth_data` can be /// modified. pub(crate) async fn update_backup_version_route( - body: Ruma<update_backup_version::v3::Request>, + State(services): State<crate::State>, body: Ruma<update_backup_version::v3::Request>, ) -> Result<update_backup_version::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services() + services .key_backups .update_backup(sender_user, &body.version, &body.algorithm)?; @@ -48,20 +49,20 @@ pub(crate) async fn update_backup_version_route( /// /// Get information about the latest backup version. 
pub(crate) async fn get_latest_backup_info_route( - body: Ruma<get_latest_backup_info::v3::Request>, + State(services): State<crate::State>, body: Ruma<get_latest_backup_info::v3::Request>, ) -> Result<get_latest_backup_info::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let (version, algorithm) = services() + let (version, algorithm) = services .key_backups .get_latest_backup(sender_user)? .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?; Ok(get_latest_backup_info::v3::Response { algorithm, - count: (UInt::try_from(services().key_backups.count_keys(sender_user, &version)?) + count: (UInt::try_from(services.key_backups.count_keys(sender_user, &version)?) .expect("user backup keys count should not be that high")), - etag: services().key_backups.get_etag(sender_user, &version)?, + etag: services.key_backups.get_etag(sender_user, &version)?, version, }) } @@ -70,10 +71,10 @@ pub(crate) async fn get_latest_backup_info_route( /// /// Get information about an existing backup. pub(crate) async fn get_backup_info_route( - body: Ruma<get_backup_info::v3::Request>, + State(services): State<crate::State>, body: Ruma<get_backup_info::v3::Request>, ) -> Result<get_backup_info::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let algorithm = services() + let algorithm = services .key_backups .get_backup(sender_user, &body.version)? 
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?; @@ -81,14 +82,12 @@ pub(crate) async fn get_backup_info_route( Ok(get_backup_info::v3::Response { algorithm, count: (UInt::try_from( - services() + services .key_backups .count_keys(sender_user, &body.version)?, ) .expect("user backup keys count should not be that high")), - etag: services() - .key_backups - .get_etag(sender_user, &body.version)?, + etag: services.key_backups.get_etag(sender_user, &body.version)?, version: body.version.clone(), }) } @@ -100,11 +99,11 @@ pub(crate) async fn get_backup_info_route( /// - Deletes both information about the backup, as well as all key data related /// to the backup pub(crate) async fn delete_backup_version_route( - body: Ruma<delete_backup_version::v3::Request>, + State(services): State<crate::State>, body: Ruma<delete_backup_version::v3::Request>, ) -> Result<delete_backup_version::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services() + services .key_backups .delete_backup(sender_user, &body.version)?; @@ -120,12 +119,12 @@ pub(crate) async fn delete_backup_version_route( /// - Adds the keys to the backup /// - Returns the new number of keys in this backup and the etag pub(crate) async fn add_backup_keys_route( - body: Ruma<add_backup_keys::v3::Request>, + State(services): State<crate::State>, body: Ruma<add_backup_keys::v3::Request>, ) -> Result<add_backup_keys::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); if Some(&body.version) - != services() + != services .key_backups .get_latest_backup_version(sender_user)? 
.as_ref() @@ -138,7 +137,7 @@ pub(crate) async fn add_backup_keys_route( for (room_id, room) in &body.rooms { for (session_id, key_data) in &room.sessions { - services() + services .key_backups .add_key(sender_user, &body.version, room_id, session_id, key_data)?; } @@ -146,14 +145,12 @@ pub(crate) async fn add_backup_keys_route( Ok(add_backup_keys::v3::Response { count: (UInt::try_from( - services() + services .key_backups .count_keys(sender_user, &body.version)?, ) .expect("user backup keys count should not be that high")), - etag: services() - .key_backups - .get_etag(sender_user, &body.version)?, + etag: services.key_backups.get_etag(sender_user, &body.version)?, }) } @@ -166,12 +163,12 @@ pub(crate) async fn add_backup_keys_route( /// - Adds the keys to the backup /// - Returns the new number of keys in this backup and the etag pub(crate) async fn add_backup_keys_for_room_route( - body: Ruma<add_backup_keys_for_room::v3::Request>, + State(services): State<crate::State>, body: Ruma<add_backup_keys_for_room::v3::Request>, ) -> Result<add_backup_keys_for_room::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); if Some(&body.version) - != services() + != services .key_backups .get_latest_backup_version(sender_user)? 
.as_ref() @@ -183,21 +180,19 @@ pub(crate) async fn add_backup_keys_for_room_route( } for (session_id, key_data) in &body.sessions { - services() + services .key_backups .add_key(sender_user, &body.version, &body.room_id, session_id, key_data)?; } Ok(add_backup_keys_for_room::v3::Response { count: (UInt::try_from( - services() + services .key_backups .count_keys(sender_user, &body.version)?, ) .expect("user backup keys count should not be that high")), - etag: services() - .key_backups - .get_etag(sender_user, &body.version)?, + etag: services.key_backups.get_etag(sender_user, &body.version)?, }) } @@ -210,12 +205,12 @@ pub(crate) async fn add_backup_keys_for_room_route( /// - Adds the keys to the backup /// - Returns the new number of keys in this backup and the etag pub(crate) async fn add_backup_keys_for_session_route( - body: Ruma<add_backup_keys_for_session::v3::Request>, + State(services): State<crate::State>, body: Ruma<add_backup_keys_for_session::v3::Request>, ) -> Result<add_backup_keys_for_session::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); if Some(&body.version) - != services() + != services .key_backups .get_latest_backup_version(sender_user)? .as_ref() @@ -226,20 +221,18 @@ pub(crate) async fn add_backup_keys_for_session_route( )); } - services() + services .key_backups .add_key(sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data)?; Ok(add_backup_keys_for_session::v3::Response { count: (UInt::try_from( - services() + services .key_backups .count_keys(sender_user, &body.version)?, ) .expect("user backup keys count should not be that high")), - etag: services() - .key_backups - .get_etag(sender_user, &body.version)?, + etag: services.key_backups.get_etag(sender_user, &body.version)?, }) } @@ -247,11 +240,11 @@ pub(crate) async fn add_backup_keys_for_session_route( /// /// Retrieves all keys from the backup. 
pub(crate) async fn get_backup_keys_route( - body: Ruma<get_backup_keys::v3::Request>, + State(services): State<crate::State>, body: Ruma<get_backup_keys::v3::Request>, ) -> Result<get_backup_keys::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let rooms = services().key_backups.get_all(sender_user, &body.version)?; + let rooms = services.key_backups.get_all(sender_user, &body.version)?; Ok(get_backup_keys::v3::Response { rooms, @@ -262,11 +255,11 @@ pub(crate) async fn get_backup_keys_route( /// /// Retrieves all keys from the backup for a given room. pub(crate) async fn get_backup_keys_for_room_route( - body: Ruma<get_backup_keys_for_room::v3::Request>, + State(services): State<crate::State>, body: Ruma<get_backup_keys_for_room::v3::Request>, ) -> Result<get_backup_keys_for_room::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sessions = services() + let sessions = services .key_backups .get_room(sender_user, &body.version, &body.room_id)?; @@ -279,11 +272,11 @@ pub(crate) async fn get_backup_keys_for_room_route( /// /// Retrieves a key from the backup. pub(crate) async fn get_backup_keys_for_session_route( - body: Ruma<get_backup_keys_for_session::v3::Request>, + State(services): State<crate::State>, body: Ruma<get_backup_keys_for_session::v3::Request>, ) -> Result<get_backup_keys_for_session::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let key_data = services() + let key_data = services .key_backups .get_session(sender_user, &body.version, &body.room_id, &body.session_id)? .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Backup key not found for this user's session."))?; @@ -297,24 +290,22 @@ pub(crate) async fn get_backup_keys_for_session_route( /// /// Delete the keys from the backup. 
pub(crate) async fn delete_backup_keys_route( - body: Ruma<delete_backup_keys::v3::Request>, + State(services): State<crate::State>, body: Ruma<delete_backup_keys::v3::Request>, ) -> Result<delete_backup_keys::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services() + services .key_backups .delete_all_keys(sender_user, &body.version)?; Ok(delete_backup_keys::v3::Response { count: (UInt::try_from( - services() + services .key_backups .count_keys(sender_user, &body.version)?, ) .expect("user backup keys count should not be that high")), - etag: services() - .key_backups - .get_etag(sender_user, &body.version)?, + etag: services.key_backups.get_etag(sender_user, &body.version)?, }) } @@ -322,24 +313,22 @@ pub(crate) async fn delete_backup_keys_route( /// /// Delete the keys from the backup for a given room. pub(crate) async fn delete_backup_keys_for_room_route( - body: Ruma<delete_backup_keys_for_room::v3::Request>, + State(services): State<crate::State>, body: Ruma<delete_backup_keys_for_room::v3::Request>, ) -> Result<delete_backup_keys_for_room::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services() + services .key_backups .delete_room_keys(sender_user, &body.version, &body.room_id)?; Ok(delete_backup_keys_for_room::v3::Response { count: (UInt::try_from( - services() + services .key_backups .count_keys(sender_user, &body.version)?, ) .expect("user backup keys count should not be that high")), - etag: services() - .key_backups - .get_etag(sender_user, &body.version)?, + etag: services.key_backups.get_etag(sender_user, &body.version)?, }) } @@ -347,23 +336,21 @@ pub(crate) async fn delete_backup_keys_for_room_route( /// /// Delete a key from the backup. 
pub(crate) async fn delete_backup_keys_for_session_route( - body: Ruma<delete_backup_keys_for_session::v3::Request>, + State(services): State<crate::State>, body: Ruma<delete_backup_keys_for_session::v3::Request>, ) -> Result<delete_backup_keys_for_session::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services() + services .key_backups .delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id)?; Ok(delete_backup_keys_for_session::v3::Response { count: (UInt::try_from( - services() + services .key_backups .count_keys(sender_user, &body.version)?, ) .expect("user backup keys count should not be that high")), - etag: services() - .key_backups - .get_etag(sender_user, &body.version)?, + etag: services.key_backups.get_etag(sender_user, &body.version)?, }) } diff --git a/src/api/client/capabilities.rs b/src/api/client/capabilities.rs index f8572fe82..c04347f80 100644 --- a/src/api/client/capabilities.rs +++ b/src/api/client/capabilities.rs @@ -1,29 +1,30 @@ use std::collections::BTreeMap; +use axum::extract::State; use ruma::api::client::discovery::get_capabilities::{ self, Capabilities, RoomVersionStability, RoomVersionsCapability, ThirdPartyIdChangesCapability, }; -use crate::{services, Result, Ruma}; +use crate::{Result, Ruma}; /// # `GET /_matrix/client/v3/capabilities` /// /// Get information on the supported feature set and other relevent capabilities /// of this server. 
pub(crate) async fn get_capabilities_route( - _body: Ruma<get_capabilities::v3::Request>, + State(services): State<crate::State>, _body: Ruma<get_capabilities::v3::Request>, ) -> Result<get_capabilities::v3::Response> { let mut available = BTreeMap::new(); - for room_version in &services().globals.unstable_room_versions { + for room_version in &services.globals.unstable_room_versions { available.insert(room_version.clone(), RoomVersionStability::Unstable); } - for room_version in &services().globals.stable_room_versions { + for room_version in &services.globals.stable_room_versions { available.insert(room_version.clone(), RoomVersionStability::Stable); } let mut capabilities = Capabilities::default(); capabilities.room_versions = RoomVersionsCapability { - default: services().globals.default_room_version(), + default: services.globals.default_room_version(), available, }; diff --git a/src/api/client/config.rs b/src/api/client/config.rs index 603c192a2..56d33ba7e 100644 --- a/src/api/client/config.rs +++ b/src/api/client/config.rs @@ -1,3 +1,4 @@ +use axum::extract::State; use ruma::{ api::client::{ config::{get_global_account_data, get_room_account_data, set_global_account_data, set_room_account_data}, @@ -10,15 +11,21 @@ use serde::Deserialize; use serde_json::{json, value::RawValue as RawJsonValue}; -use crate::{services, Error, Result, Ruma}; +use crate::{service::Services, Error, Result, Ruma}; /// # `PUT /_matrix/client/r0/user/{userId}/account_data/{type}` /// /// Sets some account data for the sender user. 
pub(crate) async fn set_global_account_data_route( - body: Ruma<set_global_account_data::v3::Request>, + State(services): State<crate::State>, body: Ruma<set_global_account_data::v3::Request>, ) -> Result<set_global_account_data::v3::Response> { - set_account_data(None, &body.sender_user, &body.event_type.to_string(), body.data.json())?; + set_account_data( + services, + None, + &body.sender_user, + &body.event_type.to_string(), + body.data.json(), + )?; Ok(set_global_account_data::v3::Response {}) } @@ -27,9 +34,10 @@ pub(crate) async fn set_global_account_data_route( /// /// Sets some room account data for the sender user. pub(crate) async fn set_room_account_data_route( - body: Ruma<set_room_account_data::v3::Request>, + State(services): State<crate::State>, body: Ruma<set_room_account_data::v3::Request>, ) -> Result<set_room_account_data::v3::Response> { set_account_data( + services, Some(&body.room_id), &body.sender_user, &body.event_type.to_string(), @@ -43,11 +51,11 @@ pub(crate) async fn set_room_account_data_route( /// /// Gets some account data for the sender user. pub(crate) async fn get_global_account_data_route( - body: Ruma<get_global_account_data::v3::Request>, + State(services): State<crate::State>, body: Ruma<get_global_account_data::v3::Request>, ) -> Result<get_global_account_data::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event: Box<RawJsonValue> = services() + let event: Box<RawJsonValue> = services .account_data .get(None, sender_user, body.event_type.to_string().into())? .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; @@ -65,11 +73,11 @@ pub(crate) async fn get_global_account_data_route( /// /// Gets some room account data for the sender user. 
pub(crate) async fn get_room_account_data_route( - body: Ruma<get_room_account_data::v3::Request>, + State(services): State<crate::State>, body: Ruma<get_room_account_data::v3::Request>, ) -> Result<get_room_account_data::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event: Box<RawJsonValue> = services() + let event: Box<RawJsonValue> = services .account_data .get(Some(&body.room_id), sender_user, body.event_type.clone())? .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; @@ -84,14 +92,15 @@ pub(crate) async fn get_room_account_data_route( } fn set_account_data( - room_id: Option<&RoomId>, sender_user: &Option<OwnedUserId>, event_type: &str, data: &RawJsonValue, + services: &Services, room_id: Option<&RoomId>, sender_user: &Option<OwnedUserId>, event_type: &str, + data: &RawJsonValue, ) -> Result<()> { let sender_user = sender_user.as_ref().expect("user is authenticated"); let data: serde_json::Value = serde_json::from_str(data.get()).map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?; - services().account_data.update( + services.account_data.update( room_id, sender_user, event_type.into(), diff --git a/src/api/client/context.rs b/src/api/client/context.rs index 2ca161c91..89a08153e 100644 --- a/src/api/client/context.rs +++ b/src/api/client/context.rs @@ -1,12 +1,13 @@ use std::collections::HashSet; +use axum::extract::State; use ruma::{ api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions}, events::StateEventType, }; use tracing::error; -use crate::{services, Error, Result, Ruma}; +use crate::{Error, Result, Ruma}; /// # `GET /_matrix/client/r0/rooms/{roomId}/context` /// @@ -14,7 +15,9 @@ /// /// - Only works if the user is joined (TODO: always allow, but only show events /// if the user was joined, depending on history_visibility) -pub(crate) async fn get_context_route(body: Ruma<get_context::v3::Request>) -> 
Result<get_context::v3::Response> { +pub(crate) async fn get_context_route( + State(services): State<crate::State>, body: Ruma<get_context::v3::Request>, +) -> Result<get_context::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); @@ -27,13 +30,13 @@ pub(crate) async fn get_context_route(body: Ruma<get_context::v3::Request>) -> R let mut lazy_loaded = HashSet::new(); - let base_token = services() + let base_token = services .rooms .timeline .get_pdu_count(&body.event_id)? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Base event id not found."))?; - let base_event = services() + let base_event = services .rooms .timeline .get_pdu(&body.event_id)? @@ -41,7 +44,7 @@ pub(crate) async fn get_context_route(body: Ruma<get_context::v3::Request>) -> R let room_id = base_event.room_id.clone(); - if !services() + if !services .rooms .state_accessor .user_can_see_event(sender_user, &room_id, &body.event_id)? @@ -52,7 +55,7 @@ pub(crate) async fn get_context_route(body: Ruma<get_context::v3::Request>) -> R )); } - if !services().rooms.lazy_loading.lazy_load_was_sent_before( + if !services.rooms.lazy_loading.lazy_load_was_sent_before( sender_user, sender_device, &room_id, @@ -67,14 +70,14 @@ pub(crate) async fn get_context_route(body: Ruma<get_context::v3::Request>) -> R let base_event = base_event.to_room_event(); - let events_before: Vec<_> = services() + let events_before: Vec<_> = services .rooms .timeline .pdus_until(sender_user, &room_id, base_token)? 
.take(limit / 2) .filter_map(Result::ok) // Remove buggy events .filter(|(_, pdu)| { - services() + services .rooms .state_accessor .user_can_see_event(sender_user, &room_id, &pdu.event_id) @@ -83,7 +86,7 @@ pub(crate) async fn get_context_route(body: Ruma<get_context::v3::Request>) -> R .collect(); for (_, event) in &events_before { - if !services().rooms.lazy_loading.lazy_load_was_sent_before( + if !services.rooms.lazy_loading.lazy_load_was_sent_before( sender_user, sender_device, &room_id, @@ -103,14 +106,14 @@ pub(crate) async fn get_context_route(body: Ruma<get_context::v3::Request>) -> R .map(|(_, pdu)| pdu.to_room_event()) .collect(); - let events_after: Vec<_> = services() + let events_after: Vec<_> = services .rooms .timeline .pdus_after(sender_user, &room_id, base_token)? .take(limit / 2) .filter_map(Result::ok) // Remove buggy events .filter(|(_, pdu)| { - services() + services .rooms .state_accessor .user_can_see_event(sender_user, &room_id, &pdu.event_id) @@ -119,7 +122,7 @@ pub(crate) async fn get_context_route(body: Ruma<get_context::v3::Request>) -> R .collect(); for (_, event) in &events_after { - if !services().rooms.lazy_loading.lazy_load_was_sent_before( + if !services.rooms.lazy_loading.lazy_load_was_sent_before( sender_user, sender_device, &room_id, @@ -130,7 +133,7 @@ pub(crate) async fn get_context_route(body: Ruma<get_context::v3::Request>) -> R } } - let shortstatehash = services() + let shortstatehash = services .rooms .state_accessor .pdu_shortstatehash( @@ -139,7 +142,7 @@ pub(crate) async fn get_context_route(body: Ruma<get_context::v3::Request>) -> R .map_or(&*body.event_id, |(_, e)| &*e.event_id), )? .map_or( - services() + services .rooms .state .get_room_shortstatehash(&room_id)? 
@@ -147,7 +150,7 @@ pub(crate) async fn get_context_route(body: Ruma<get_context::v3::Request>) -> R |hash| hash, ); - let state_ids = services() + let state_ids = services .rooms .state_accessor .state_full_ids(shortstatehash) @@ -165,20 +168,20 @@ pub(crate) async fn get_context_route(body: Ruma<get_context::v3::Request>) -> R let mut state = Vec::with_capacity(state_ids.len()); for (shortstatekey, id) in state_ids { - let (event_type, state_key) = services() + let (event_type, state_key) = services .rooms .short .get_statekey_from_short(shortstatekey)?; if event_type != StateEventType::RoomMember { - let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else { + let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { error!("Pdu in state not found: {}", id); continue; }; state.push(pdu.to_state_event()); } else if !lazy_load_enabled || lazy_loaded.contains(&state_key) { - let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else { + let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { error!("Pdu in state not found: {}", id); continue; }; diff --git a/src/api/client/device.rs b/src/api/client/device.rs index c917ba773..bad7f2844 100644 --- a/src/api/client/device.rs +++ b/src/api/client/device.rs @@ -1,3 +1,4 @@ +use axum::extract::State; use ruma::api::client::{ device::{self, delete_device, delete_devices, get_device, get_devices, update_device}, error::ErrorKind, @@ -5,15 +6,17 @@ }; use super::SESSION_ID_LENGTH; -use crate::{services, utils, Error, Result, Ruma}; +use crate::{utils, Error, Result, Ruma}; /// # `GET /_matrix/client/r0/devices` /// /// Get metadata on all devices of the sender user. 
-pub(crate) async fn get_devices_route(body: Ruma<get_devices::v3::Request>) -> Result<get_devices::v3::Response> { +pub(crate) async fn get_devices_route( + State(services): State<crate::State>, body: Ruma<get_devices::v3::Request>, +) -> Result<get_devices::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let devices: Vec<device::Device> = services() + let devices: Vec<device::Device> = services .users .all_devices_metadata(sender_user) .filter_map(Result::ok) // Filter out buggy devices @@ -27,10 +30,12 @@ pub(crate) async fn get_devices_route(body: Ruma<get_devices::v3::Request>) -> R /// # `GET /_matrix/client/r0/devices/{deviceId}` /// /// Get metadata on a single device of the sender user. -pub(crate) async fn get_device_route(body: Ruma<get_device::v3::Request>) -> Result<get_device::v3::Response> { +pub(crate) async fn get_device_route( + State(services): State<crate::State>, body: Ruma<get_device::v3::Request>, +) -> Result<get_device::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let device = services() + let device = services .users .get_device_metadata(sender_user, &body.body.device_id)? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; @@ -43,17 +48,19 @@ pub(crate) async fn get_device_route(body: Ruma<get_device::v3::Request>) -> Res /// # `PUT /_matrix/client/r0/devices/{deviceId}` /// /// Updates the metadata on a given device of the sender user. -pub(crate) async fn update_device_route(body: Ruma<update_device::v3::Request>) -> Result<update_device::v3::Response> { +pub(crate) async fn update_device_route( + State(services): State<crate::State>, body: Ruma<update_device::v3::Request>, +) -> Result<update_device::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let mut device = services() + let mut device = services .users .get_device_metadata(sender_user, &body.device_id)? 
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; device.display_name.clone_from(&body.display_name); - services() + services .users .update_device_metadata(sender_user, &body.device_id, &device)?; @@ -70,7 +77,9 @@ pub(crate) async fn update_device_route(body: Ruma<update_device::v3::Request>) /// last seen ts) /// - Forgets to-device events /// - Triggers device list updates -pub(crate) async fn delete_device_route(body: Ruma<delete_device::v3::Request>) -> Result<delete_device::v3::Response> { +pub(crate) async fn delete_device_route( + State(services): State<crate::State>, body: Ruma<delete_device::v3::Request>, +) -> Result<delete_device::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); @@ -86,7 +95,7 @@ pub(crate) async fn delete_device_route(body: Ruma<delete_device::v3::Request>) }; if let Some(auth) = &body.auth { - let (worked, uiaainfo) = services() + let (worked, uiaainfo) = services .uiaa .try_auth(sender_user, sender_device, auth, &uiaainfo)?; if !worked { @@ -95,7 +104,7 @@ pub(crate) async fn delete_device_route(body: Ruma<delete_device::v3::Request>) // Success! 
} else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - services() + services .uiaa .create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); @@ -103,9 +112,7 @@ pub(crate) async fn delete_device_route(body: Ruma<delete_device::v3::Request>) return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } - services() - .users - .remove_device(sender_user, &body.device_id)?; + services.users.remove_device(sender_user, &body.device_id)?; Ok(delete_device::v3::Response {}) } @@ -123,7 +130,7 @@ pub(crate) async fn delete_device_route(body: Ruma<delete_device::v3::Request>) /// - Forgets to-device events /// - Triggers device list updates pub(crate) async fn delete_devices_route( - body: Ruma<delete_devices::v3::Request>, + State(services): State<crate::State>, body: Ruma<delete_devices::v3::Request>, ) -> Result<delete_devices::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); @@ -140,7 +147,7 @@ pub(crate) async fn delete_devices_route( }; if let Some(auth) = &body.auth { - let (worked, uiaainfo) = services() + let (worked, uiaainfo) = services .uiaa .try_auth(sender_user, sender_device, auth, &uiaainfo)?; if !worked { @@ -149,7 +156,7 @@ pub(crate) async fn delete_devices_route( // Success! 
} else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - services() + services .uiaa .create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); @@ -158,7 +165,7 @@ pub(crate) async fn delete_devices_route( } for device_id in &body.devices { - services().users.remove_device(sender_user, device_id)?; + services.users.remove_device(sender_user, device_id)?; } Ok(delete_devices::v3::Response {}) diff --git a/src/api/client/directory.rs b/src/api/client/directory.rs index 8e12c0343..deebd250e 100644 --- a/src/api/client/directory.rs +++ b/src/api/client/directory.rs @@ -1,3 +1,4 @@ +use axum::extract::State; use axum_client_ip::InsecureClientIp; use conduit::{err, info, warn, Error, Result}; use ruma::{ @@ -20,7 +21,10 @@ uint, RoomId, ServerName, UInt, UserId, }; -use crate::{service::server_is_ours, services, Ruma}; +use crate::{ + service::{server_is_ours, Services}, + Ruma, +}; /// # `POST /_matrix/client/v3/publicRooms` /// @@ -29,10 +33,11 @@ /// - Rooms are ordered by the number of joined members #[tracing::instrument(skip_all, fields(%client), name = "publicrooms")] pub(crate) async fn get_public_rooms_filtered_route( - InsecureClientIp(client): InsecureClientIp, body: Ruma<get_public_rooms_filtered::v3::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: Ruma<get_public_rooms_filtered::v3::Request>, ) -> Result<get_public_rooms_filtered::v3::Response> { if let Some(server) = &body.server { - if services() + if services .globals .forbidden_remote_room_directory_server_names() .contains(server) @@ -45,6 +50,7 @@ pub(crate) async fn get_public_rooms_filtered_route( } let response = get_public_rooms_filtered_helper( + services, body.server.as_deref(), body.limit, body.since.as_deref(), @@ -67,10 +73,11 @@ pub(crate) async fn get_public_rooms_filtered_route( /// - Rooms are ordered by the number of joined members 
#[tracing::instrument(skip_all, fields(%client), name = "publicrooms")] pub(crate) async fn get_public_rooms_route( - InsecureClientIp(client): InsecureClientIp, body: Ruma<get_public_rooms::v3::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: Ruma<get_public_rooms::v3::Request>, ) -> Result<get_public_rooms::v3::Response> { if let Some(server) = &body.server { - if services() + if services .globals .forbidden_remote_room_directory_server_names() .contains(server) @@ -83,6 +90,7 @@ pub(crate) async fn get_public_rooms_route( } let response = get_public_rooms_filtered_helper( + services, body.server.as_deref(), body.limit, body.since.as_deref(), @@ -108,16 +116,17 @@ pub(crate) async fn get_public_rooms_route( /// Sets the visibility of a given room in the room directory. #[tracing::instrument(skip_all, fields(%client), name = "room_directory")] pub(crate) async fn set_room_visibility_route( - InsecureClientIp(client): InsecureClientIp, body: Ruma<set_room_visibility::v3::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: Ruma<set_room_visibility::v3::Request>, ) -> Result<set_room_visibility::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services().rooms.metadata.exists(&body.room_id)? { + if !services.rooms.metadata.exists(&body.room_id)? { // Return 404 if the room doesn't exist return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found")); } - if !user_can_publish_room(sender_user, &body.room_id)? { + if !user_can_publish_room(services, sender_user, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::forbidden(), "User is not allowed to publish this room", @@ -126,7 +135,7 @@ pub(crate) async fn set_room_visibility_route( match &body.visibility { room::Visibility::Public => { - if services().globals.config.lockdown_public_room_directory && !services().users.is_admin(sender_user)? 
{ + if services.globals.config.lockdown_public_room_directory && !services.users.is_admin(sender_user)? { info!( "Non-admin user {sender_user} tried to publish {0} to the room directory while \ \"lockdown_public_room_directory\" is enabled", @@ -139,10 +148,10 @@ pub(crate) async fn set_room_visibility_route( )); } - services().rooms.directory.set_public(&body.room_id)?; + services.rooms.directory.set_public(&body.room_id)?; info!("{sender_user} made {0} public", body.room_id); }, - room::Visibility::Private => services().rooms.directory.set_not_public(&body.room_id)?, + room::Visibility::Private => services.rooms.directory.set_not_public(&body.room_id)?, _ => { return Err(Error::BadRequest( ErrorKind::InvalidParam, @@ -158,15 +167,15 @@ pub(crate) async fn set_room_visibility_route( /// /// Gets the visibility of a given room in the room directory. pub(crate) async fn get_room_visibility_route( - body: Ruma<get_room_visibility::v3::Request>, + State(services): State<crate::State>, body: Ruma<get_room_visibility::v3::Request>, ) -> Result<get_room_visibility::v3::Response> { - if !services().rooms.metadata.exists(&body.room_id)? { + if !services.rooms.metadata.exists(&body.room_id)? { // Return 404 if the room doesn't exist return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found")); } Ok(get_room_visibility::v3::Response { - visibility: if services().rooms.directory.is_public_room(&body.room_id)? { + visibility: if services.rooms.directory.is_public_room(&body.room_id)? 
{ room::Visibility::Public } else { room::Visibility::Private @@ -175,10 +184,11 @@ pub(crate) async fn get_room_visibility_route( } pub(crate) async fn get_public_rooms_filtered_helper( - server: Option<&ServerName>, limit: Option<UInt>, since: Option<&str>, filter: &Filter, _network: &RoomNetwork, + services: &Services, server: Option<&ServerName>, limit: Option<UInt>, since: Option<&str>, filter: &Filter, + _network: &RoomNetwork, ) -> Result<get_public_rooms_filtered::v3::Response> { if let Some(other_server) = server.filter(|server_name| !server_is_ours(server_name)) { - let response = services() + let response = services .sending .send_federation_request( other_server, @@ -224,7 +234,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( } } - let mut all_rooms: Vec<_> = services() + let mut all_rooms: Vec<_> = services .rooms .directory .public_rooms() @@ -232,12 +242,12 @@ pub(crate) async fn get_public_rooms_filtered_helper( let room_id = room_id?; let chunk = PublicRoomsChunk { - canonical_alias: services() + canonical_alias: services .rooms .state_accessor .get_canonical_alias(&room_id)?, - name: services().rooms.state_accessor.get_name(&room_id)?, - num_joined_members: services() + name: services.rooms.state_accessor.get_name(&room_id)?, + num_joined_members: services .rooms .state_cache .room_joined_count(&room_id)? @@ -247,24 +257,24 @@ pub(crate) async fn get_public_rooms_filtered_helper( }) .try_into() .expect("user count should not be that big"), - topic: services() + topic: services .rooms .state_accessor .get_room_topic(&room_id) .unwrap_or(None), - world_readable: services().rooms.state_accessor.is_world_readable(&room_id)?, - guest_can_join: services() + world_readable: services.rooms.state_accessor.is_world_readable(&room_id)?, + guest_can_join: services .rooms .state_accessor .guest_can_join(&room_id)?, - avatar_url: services() + avatar_url: services .rooms .state_accessor .get_avatar(&room_id)? 
.into_option() .unwrap_or_default() .url, - join_rule: services() + join_rule: services .rooms .state_accessor .room_state_get(&room_id, &StateEventType::RoomJoinRules, "")? @@ -282,7 +292,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( .transpose()? .flatten() .ok_or_else(|| Error::bad_database("Missing room join rule event for room."))?, - room_type: services() + room_type: services .rooms .state_accessor .get_room_type(&room_id)?, @@ -361,12 +371,11 @@ pub(crate) async fn get_public_rooms_filtered_helper( /// Check whether the user can publish to the room directory via power levels of /// room history visibility event or room creator -fn user_can_publish_room(user_id: &UserId, room_id: &RoomId) -> Result<bool> { - if let Some(event) = - services() - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? +fn user_can_publish_room(services: &Services, user_id: &UserId, room_id: &RoomId) -> Result<bool> { + if let Some(event) = services + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? { serde_json::from_str(event.content.get()) .map_err(|_| Error::bad_database("Invalid event content for m.room.power_levels")) @@ -374,7 +383,7 @@ fn user_can_publish_room(user_id: &UserId, room_id: &RoomId) -> Result<bool> { RoomPowerLevels::from(content).user_can_send_state(user_id, StateEventType::RoomHistoryVisibility) }) } else if let Some(event) = - services() + services .rooms .state_accessor .room_state_get(room_id, &StateEventType::RoomCreate, "")? 
diff --git a/src/api/client/filter.rs b/src/api/client/filter.rs index 5a1614001..8b2690c69 100644 --- a/src/api/client/filter.rs +++ b/src/api/client/filter.rs @@ -1,18 +1,21 @@ +use axum::extract::State; use ruma::api::client::{ error::ErrorKind, filter::{create_filter, get_filter}, }; -use crate::{services, Error, Result, Ruma}; +use crate::{Error, Result, Ruma}; /// # `GET /_matrix/client/r0/user/{userId}/filter/{filterId}` /// /// Loads a filter that was previously created. /// /// - A user can only access their own filters -pub(crate) async fn get_filter_route(body: Ruma<get_filter::v3::Request>) -> Result<get_filter::v3::Response> { +pub(crate) async fn get_filter_route( + State(services): State<crate::State>, body: Ruma<get_filter::v3::Request>, +) -> Result<get_filter::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let Some(filter) = services().users.get_filter(sender_user, &body.filter_id)? else { + let Some(filter) = services.users.get_filter(sender_user, &body.filter_id)? else { return Err(Error::BadRequest(ErrorKind::NotFound, "Filter not found.")); }; @@ -22,9 +25,11 @@ pub(crate) async fn get_filter_route(body: Ruma<get_filter::v3::Request>) -> Res /// # `PUT /_matrix/client/r0/user/{userId}/filter` /// /// Creates a new filter to be used by other endpoints. 
-pub(crate) async fn create_filter_route(body: Ruma<create_filter::v3::Request>) -> Result<create_filter::v3::Response> { +pub(crate) async fn create_filter_route( + State(services): State<crate::State>, body: Ruma<create_filter::v3::Request>, +) -> Result<create_filter::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); Ok(create_filter::v3::Response::new( - services().users.create_filter(sender_user, &body.filter)?, + services.users.create_filter(sender_user, &body.filter)?, )) } diff --git a/src/api/client/keys.rs b/src/api/client/keys.rs index 7bb02a606..728ea7a93 100644 --- a/src/api/client/keys.rs +++ b/src/api/client/keys.rs @@ -3,6 +3,7 @@ time::Instant, }; +use axum::extract::State; use conduit::{utils, utils::math::continue_exponential_backoff_secs, Error, Result}; use futures_util::{stream::FuturesUnordered, StreamExt}; use ruma::{ @@ -22,7 +23,7 @@ use tracing::debug; use super::SESSION_ID_LENGTH; -use crate::{services, Ruma}; +use crate::{service::Services, Ruma}; /// # `POST /_matrix/client/r0/keys/upload` /// @@ -31,12 +32,14 @@ /// - Adds one time keys /// - If there are no device keys yet: Adds device keys (TODO: merge with /// existing keys?) -pub(crate) async fn upload_keys_route(body: Ruma<upload_keys::v3::Request>) -> Result<upload_keys::v3::Response> { +pub(crate) async fn upload_keys_route( + State(services): State<crate::State>, body: Ruma<upload_keys::v3::Request>, +) -> Result<upload_keys::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); for (key_key, key_value) in &body.one_time_keys { - services() + services .users .add_one_time_key(sender_user, sender_device, key_key, key_value)?; } @@ -44,19 +47,19 @@ pub(crate) async fn upload_keys_route(body: Ruma<upload_keys::v3::Request>) -> R if let Some(device_keys) = &body.device_keys { // TODO: merge this and the existing event? 
// This check is needed to assure that signatures are kept - if services() + if services .users .get_device_keys(sender_user, sender_device)? .is_none() { - services() + services .users .add_device_keys(sender_user, sender_device, device_keys)?; } } Ok(upload_keys::v3::Response { - one_time_key_counts: services() + one_time_key_counts: services .users .count_one_time_keys(sender_user, sender_device)?, }) @@ -70,10 +73,13 @@ pub(crate) async fn upload_keys_route(body: Ruma<upload_keys::v3::Request>) -> R /// - Gets master keys, self-signing keys, user signing keys and device keys. /// - The master and self-signing keys contain signatures that the user is /// allowed to see -pub(crate) async fn get_keys_route(body: Ruma<get_keys::v3::Request>) -> Result<get_keys::v3::Response> { +pub(crate) async fn get_keys_route( + State(services): State<crate::State>, body: Ruma<get_keys::v3::Request>, +) -> Result<get_keys::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); get_keys_helper( + services, Some(sender_user), &body.device_keys, |u| u == sender_user, @@ -85,8 +91,10 @@ pub(crate) async fn get_keys_route(body: Ruma<get_keys::v3::Request>) -> Result< /// # `POST /_matrix/client/r0/keys/claim` /// /// Claims one-time keys -pub(crate) async fn claim_keys_route(body: Ruma<claim_keys::v3::Request>) -> Result<claim_keys::v3::Response> { - claim_keys_helper(&body.one_time_keys).await +pub(crate) async fn claim_keys_route( + State(services): State<crate::State>, body: Ruma<claim_keys::v3::Request>, +) -> Result<claim_keys::v3::Response> { + claim_keys_helper(services, &body.one_time_keys).await } /// # `POST /_matrix/client/r0/keys/device_signing/upload` @@ -95,7 +103,7 @@ pub(crate) async fn claim_keys_route(body: Ruma<claim_keys::v3::Request>) -> Res /// /// - Requires UIAA to verify password pub(crate) async fn upload_signing_keys_route( - body: Ruma<upload_signing_keys::v3::Request>, + State(services): State<crate::State>, body: 
Ruma<upload_signing_keys::v3::Request>, ) -> Result<upload_signing_keys::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); @@ -112,7 +120,7 @@ pub(crate) async fn upload_signing_keys_route( }; if let Some(auth) = &body.auth { - let (worked, uiaainfo) = services() + let (worked, uiaainfo) = services .uiaa .try_auth(sender_user, sender_device, auth, &uiaainfo)?; if !worked { @@ -121,7 +129,7 @@ pub(crate) async fn upload_signing_keys_route( // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - services() + services .uiaa .create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); @@ -130,7 +138,7 @@ pub(crate) async fn upload_signing_keys_route( } if let Some(master_key) = &body.master_key { - services().users.add_cross_signing_keys( + services.users.add_cross_signing_keys( sender_user, master_key, &body.self_signing_key, @@ -146,7 +154,7 @@ pub(crate) async fn upload_signing_keys_route( /// /// Uploads end-to-end key signatures from the sender user. pub(crate) async fn upload_signatures_route( - body: Ruma<upload_signatures::v3::Request>, + State(services): State<crate::State>, body: Ruma<upload_signatures::v3::Request>, ) -> Result<upload_signatures::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -173,7 +181,7 @@ pub(crate) async fn upload_signatures_route( .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid signature value."))? 
.to_owned(), ); - services() + services .users .sign_key(user_id, key_id, signature, sender_user)?; } @@ -192,14 +200,14 @@ pub(crate) async fn upload_signatures_route( /// /// - TODO: left users pub(crate) async fn get_key_changes_route( - body: Ruma<get_key_changes::v3::Request>, + State(services): State<crate::State>, body: Ruma<get_key_changes::v3::Request>, ) -> Result<get_key_changes::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let mut device_list_updates = HashSet::new(); device_list_updates.extend( - services() + services .users .keys_changed( sender_user.as_str(), @@ -215,14 +223,14 @@ pub(crate) async fn get_key_changes_route( .filter_map(Result::ok), ); - for room_id in services() + for room_id in services .rooms .state_cache .rooms_joined(sender_user) .filter_map(Result::ok) { device_list_updates.extend( - services() + services .users .keys_changed( room_id.as_ref(), @@ -245,8 +253,8 @@ pub(crate) async fn get_key_changes_route( } pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>( - sender_user: Option<&UserId>, device_keys_input: &BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>, allowed_signatures: F, - include_display_names: bool, + services: &Services, sender_user: Option<&UserId>, device_keys_input: &BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>, + allowed_signatures: F, include_display_names: bool, ) -> Result<get_keys::v3::Response> { let mut master_keys = BTreeMap::new(); let mut self_signing_keys = BTreeMap::new(); @@ -268,10 +276,10 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>( if device_ids.is_empty() { let mut container = BTreeMap::new(); - for device_id in services().users.all_device_ids(user_id) { + for device_id in services.users.all_device_ids(user_id) { let device_id = device_id?; - if let Some(mut keys) = services().users.get_device_keys(user_id, &device_id)? 
{ - let metadata = services() + if let Some(mut keys) = services.users.get_device_keys(user_id, &device_id)? { + let metadata = services .users .get_device_metadata(user_id, &device_id)? .ok_or_else(|| Error::bad_database("all_device_keys contained nonexistent device."))?; @@ -286,8 +294,8 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>( } else { for device_id in device_ids { let mut container = BTreeMap::new(); - if let Some(mut keys) = services().users.get_device_keys(user_id, device_id)? { - let metadata = services() + if let Some(mut keys) = services.users.get_device_keys(user_id, device_id)? { + let metadata = services .users .get_device_metadata(user_id, device_id)? .ok_or(Error::BadRequest( @@ -303,21 +311,21 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>( } } - if let Some(master_key) = services() + if let Some(master_key) = services .users .get_master_key(sender_user, user_id, &allowed_signatures)? { master_keys.insert(user_id.to_owned(), master_key); } if let Some(self_signing_key) = - services() + services .users .get_self_signing_key(sender_user, user_id, &allowed_signatures)? { self_signing_keys.insert(user_id.to_owned(), self_signing_key); } if Some(user_id) == sender_user { - if let Some(user_signing_key) = services().users.get_user_signing_key(user_id)? { + if let Some(user_signing_key) = services.users.get_user_signing_key(user_id)? 
{ user_signing_keys.insert(user_id.to_owned(), user_signing_key); } } @@ -326,7 +334,7 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>( let mut failures = BTreeMap::new(); let back_off = |id| async { - match services() + match services .globals .bad_query_ratelimiter .write() @@ -345,7 +353,7 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>( let mut futures: FuturesUnordered<_> = get_over_federation .into_iter() .map(|(server, vec)| async move { - if let Some((time, tries)) = services() + if let Some((time, tries)) = services .globals .bad_query_ratelimiter .read() @@ -369,7 +377,7 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>( let request = federation::keys::get_keys::v1::Request { device_keys: device_keys_input_fed, }; - let response = services() + let response = services .sending .send_federation_request(server, request) .await; @@ -381,19 +389,19 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>( while let Some((server, response)) = futures.next().await { if let Ok(Ok(response)) = response { for (user, masterkey) in response.master_keys { - let (master_key_id, mut master_key) = services().users.parse_master_key(&user, &masterkey)?; + let (master_key_id, mut master_key) = services.users.parse_master_key(&user, &masterkey)?; if let Some(our_master_key) = - services() + services .users .get_key(&master_key_id, sender_user, &user, &allowed_signatures)? { - let (_, our_master_key) = services().users.parse_master_key(&user, &our_master_key)?; + let (_, our_master_key) = services.users.parse_master_key(&user, &our_master_key)?; master_key.signatures.extend(our_master_key.signatures); } let json = serde_json::to_value(master_key).expect("to_value always works"); let raw = serde_json::from_value(json).expect("Raw::from_value always works"); - services().users.add_cross_signing_keys( + services.users.add_cross_signing_keys( &user, &raw, &None, &None, false, /* Dont notify. 
A notification would trigger another key request resulting in an * endless loop */ @@ -444,7 +452,7 @@ fn add_unsigned_device_display_name( } pub(crate) async fn claim_keys_helper( - one_time_keys_input: &BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, DeviceKeyAlgorithm>>, + services: &Services, one_time_keys_input: &BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, DeviceKeyAlgorithm>>, ) -> Result<claim_keys::v3::Response> { let mut one_time_keys = BTreeMap::new(); @@ -460,7 +468,7 @@ pub(crate) async fn claim_keys_helper( let mut container = BTreeMap::new(); for (device_id, key_algorithm) in map { - if let Some(one_time_keys) = services() + if let Some(one_time_keys) = services .users .take_one_time_key(user_id, device_id, key_algorithm)? { @@ -483,7 +491,7 @@ pub(crate) async fn claim_keys_helper( } ( server, - services() + services .sending .send_federation_request( server, diff --git a/src/api/client/media.rs b/src/api/client/media.rs index 39640b232..1adcefdd8 100644 --- a/src/api/client/media.rs +++ b/src/api/client/media.rs @@ -2,6 +2,7 @@ use std::{io::Cursor, sync::Arc, time::Duration}; +use axum::extract::State; use axum_client_ip::InsecureClientIp; use conduit::{debug, error, utils::math::ruma_from_usize, warn}; use image::io::Reader as ImgReader; @@ -20,9 +21,8 @@ debug_warn, service::{ media::{FileMeta, UrlPreviewData}, - server_is_ours, + server_is_ours, Services, }, - services, utils::{ self, content_disposition::{content_disposition_type, make_content_disposition, sanitise_filename}, @@ -42,10 +42,10 @@ /// /// Returns max upload size. 
pub(crate) async fn get_media_config_route( - _body: Ruma<get_media_config::v3::Request>, + State(services): State<crate::State>, _body: Ruma<get_media_config::v3::Request>, ) -> Result<get_media_config::v3::Response> { Ok(get_media_config::v3::Response { - upload_size: ruma_from_usize(services().globals.config.max_request_size), + upload_size: ruma_from_usize(services.globals.config.max_request_size), }) } @@ -57,9 +57,11 @@ pub(crate) async fn get_media_config_route( /// /// Returns max upload size. pub(crate) async fn get_media_config_v1_route( - body: Ruma<get_media_config::v3::Request>, + State(services): State<crate::State>, body: Ruma<get_media_config::v3::Request>, ) -> Result<RumaResponse<get_media_config::v3::Response>> { - get_media_config_route(body).await.map(RumaResponse) + get_media_config_route(State(services), body) + .await + .map(RumaResponse) } /// # `GET /_matrix/media/v3/preview_url` @@ -67,17 +69,18 @@ pub(crate) async fn get_media_config_v1_route( /// Returns URL preview. 
#[tracing::instrument(skip_all, fields(%client), name = "url_preview")] pub(crate) async fn get_media_preview_route( - InsecureClientIp(client): InsecureClientIp, body: Ruma<get_media_preview::v3::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: Ruma<get_media_preview::v3::Request>, ) -> Result<get_media_preview::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let url = &body.url; - if !url_preview_allowed(url) { + if !url_preview_allowed(services, url) { warn!(%sender_user, "URL is not allowed to be previewed: {url}"); return Err(Error::BadRequest(ErrorKind::forbidden(), "URL is not allowed to be previewed")); } - match get_url_preview(url).await { + match get_url_preview(services, url).await { Ok(preview) => { let res = serde_json::value::to_raw_value(&preview).map_err(|e| { error!(%sender_user, "Failed to convert UrlPreviewData into a serde json value: {e}"); @@ -115,9 +118,10 @@ pub(crate) async fn get_media_preview_route( /// Returns URL preview. 
#[tracing::instrument(skip_all, fields(%client), name = "url_preview")] pub(crate) async fn get_media_preview_v1_route( - InsecureClientIp(client): InsecureClientIp, body: Ruma<get_media_preview::v3::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: Ruma<get_media_preview::v3::Request>, ) -> Result<RumaResponse<get_media_preview::v3::Response>> { - get_media_preview_route(InsecureClientIp(client), body) + get_media_preview_route(State(services), InsecureClientIp(client), body) .await .map(RumaResponse) } @@ -130,17 +134,14 @@ pub(crate) async fn get_media_preview_v1_route( /// - Media will be saved in the media/ directory #[tracing::instrument(skip_all, fields(%client), name = "media_upload")] pub(crate) async fn create_content_route( - InsecureClientIp(client): InsecureClientIp, body: Ruma<create_content::v3::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: Ruma<create_content::v3::Request>, ) -> Result<create_content::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let mxc = format!( - "mxc://{}/{}", - services().globals.server_name(), - utils::random_string(MXC_LENGTH) - ); + let mxc = format!("mxc://{}/{}", services.globals.server_name(), utils::random_string(MXC_LENGTH)); - services() + services .media .create( Some(sender_user.clone()), @@ -178,9 +179,10 @@ pub(crate) async fn create_content_route( /// - Media will be saved in the media/ directory #[tracing::instrument(skip_all, fields(%client), name = "media_upload")] pub(crate) async fn create_content_v1_route( - InsecureClientIp(client): InsecureClientIp, body: Ruma<create_content::v3::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: Ruma<create_content::v3::Request>, ) -> Result<RumaResponse<create_content::v3::Response>> { - create_content_route(InsecureClientIp(client), body) + 
create_content_route(State(services), InsecureClientIp(client), body) .await .map(RumaResponse) } @@ -195,7 +197,8 @@ pub(crate) async fn create_content_v1_route( /// seconds #[tracing::instrument(skip_all, fields(%client), name = "media_get")] pub(crate) async fn get_content_route( - InsecureClientIp(client): InsecureClientIp, body: Ruma<get_content::v3::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: Ruma<get_content::v3::Request>, ) -> Result<get_content::v3::Response> { let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); @@ -203,7 +206,7 @@ pub(crate) async fn get_content_route( content, content_type, content_disposition, - }) = services().media.get(&mxc).await? + }) = services.media.get(&mxc).await? { let content_disposition = Some(make_content_disposition(&content_type, content_disposition, None)); let file = content.expect("content"); @@ -217,6 +220,7 @@ pub(crate) async fn get_content_route( }) } else if !server_is_ours(&body.server_name) && body.allow_remote { let response = get_remote_content( + services, &mxc, &body.server_name, body.media_id.clone(), @@ -261,9 +265,10 @@ pub(crate) async fn get_content_route( /// seconds #[tracing::instrument(skip_all, fields(%client), name = "media_get")] pub(crate) async fn get_content_v1_route( - InsecureClientIp(client): InsecureClientIp, body: Ruma<get_content::v3::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: Ruma<get_content::v3::Request>, ) -> Result<RumaResponse<get_content::v3::Response>> { - get_content_route(InsecureClientIp(client), body) + get_content_route(State(services), InsecureClientIp(client), body) .await .map(RumaResponse) } @@ -278,7 +283,8 @@ pub(crate) async fn get_content_v1_route( /// seconds #[tracing::instrument(skip_all, fields(%client), name = "media_get")] pub(crate) async fn get_content_as_filename_route( - InsecureClientIp(client): InsecureClientIp, body: 
Ruma<get_content_as_filename::v3::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: Ruma<get_content_as_filename::v3::Request>, ) -> Result<get_content_as_filename::v3::Response> { let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); @@ -286,7 +292,7 @@ pub(crate) async fn get_content_as_filename_route( content, content_type, content_disposition, - }) = services().media.get(&mxc).await? + }) = services.media.get(&mxc).await? { let content_disposition = Some(make_content_disposition( &content_type, @@ -304,6 +310,7 @@ pub(crate) async fn get_content_as_filename_route( }) } else if !server_is_ours(&body.server_name) && body.allow_remote { match get_remote_content( + services, &mxc, &body.server_name, body.media_id.clone(), @@ -351,9 +358,10 @@ pub(crate) async fn get_content_as_filename_route( /// seconds #[tracing::instrument(skip_all, fields(%client), name = "media_get")] pub(crate) async fn get_content_as_filename_v1_route( - InsecureClientIp(client): InsecureClientIp, body: Ruma<get_content_as_filename::v3::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: Ruma<get_content_as_filename::v3::Request>, ) -> Result<RumaResponse<get_content_as_filename::v3::Response>> { - get_content_as_filename_route(InsecureClientIp(client), body) + get_content_as_filename_route(State(services), InsecureClientIp(client), body) .await .map(RumaResponse) } @@ -368,7 +376,8 @@ pub(crate) async fn get_content_as_filename_v1_route( /// seconds #[tracing::instrument(skip_all, fields(%client), name = "media_thumbnail_get")] pub(crate) async fn get_content_thumbnail_route( - InsecureClientIp(client): InsecureClientIp, body: Ruma<get_content_thumbnail::v3::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: Ruma<get_content_thumbnail::v3::Request>, ) -> Result<get_content_thumbnail::v3::Response> { let mxc = 
format!("mxc://{}/{}", body.server_name, body.media_id); @@ -376,7 +385,7 @@ pub(crate) async fn get_content_thumbnail_route( content, content_type, content_disposition, - }) = services() + }) = services .media .get_thumbnail( &mxc, @@ -400,7 +409,7 @@ pub(crate) async fn get_content_thumbnail_route( content_disposition, }) } else if !server_is_ours(&body.server_name) && body.allow_remote { - if services() + if services .globals .prevent_media_downloads_from() .contains(&body.server_name) @@ -411,7 +420,7 @@ pub(crate) async fn get_content_thumbnail_route( return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")); } - match services() + match services .sending .send_federation_request( &body.server_name, @@ -430,7 +439,7 @@ pub(crate) async fn get_content_thumbnail_route( .await { Ok(get_thumbnail_response) => { - services() + services .media .upload_thumbnail( None, @@ -481,17 +490,19 @@ pub(crate) async fn get_content_thumbnail_route( /// seconds #[tracing::instrument(skip_all, fields(%client), name = "media_thumbnail_get")] pub(crate) async fn get_content_thumbnail_v1_route( - InsecureClientIp(client): InsecureClientIp, body: Ruma<get_content_thumbnail::v3::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: Ruma<get_content_thumbnail::v3::Request>, ) -> Result<RumaResponse<get_content_thumbnail::v3::Response>> { - get_content_thumbnail_route(InsecureClientIp(client), body) + get_content_thumbnail_route(State(services), InsecureClientIp(client), body) .await .map(RumaResponse) } async fn get_remote_content( - mxc: &str, server_name: &ruma::ServerName, media_id: String, allow_redirect: bool, timeout_ms: Duration, + services: &Services, mxc: &str, server_name: &ruma::ServerName, media_id: String, allow_redirect: bool, + timeout_ms: Duration, ) -> Result<get_content::v3::Response, Error> { - if services() + if services .globals .prevent_media_downloads_from() .contains(&server_name.to_owned()) @@ -502,7 
+513,7 @@ async fn get_remote_content( return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")); } - let content_response = services() + let content_response = services .sending .send_federation_request( server_name, @@ -522,7 +533,7 @@ async fn get_remote_content( None, )); - services() + services .media .create( None, @@ -542,15 +553,11 @@ async fn get_remote_content( }) } -async fn download_image(client: &reqwest::Client, url: &str) -> Result<UrlPreviewData> { +async fn download_image(services: &Services, client: &reqwest::Client, url: &str) -> Result<UrlPreviewData> { let image = client.get(url).send().await?.bytes().await?; - let mxc = format!( - "mxc://{}/{}", - services().globals.server_name(), - utils::random_string(MXC_LENGTH) - ); + let mxc = format!("mxc://{}/{}", services.globals.server_name(), utils::random_string(MXC_LENGTH)); - services() + services .media .create(None, &mxc, None, None, &image) .await?; @@ -572,18 +579,18 @@ async fn download_image(client: &reqwest::Client, url: &str) -> Result<UrlPrevie }) } -async fn download_html(client: &reqwest::Client, url: &str) -> Result<UrlPreviewData> { +async fn download_html(services: &Services, client: &reqwest::Client, url: &str) -> Result<UrlPreviewData> { let mut response = client.get(url).send().await?; let mut bytes: Vec<u8> = Vec::new(); while let Some(chunk) = response.chunk().await? 
{ bytes.extend_from_slice(&chunk); - if bytes.len() > services().globals.url_preview_max_spider_size() { + if bytes.len() > services.globals.url_preview_max_spider_size() { debug!( "Response body from URL {} exceeds url_preview_max_spider_size ({}), not processing the rest of the \ response body and assuming our necessary data is in this range.", url, - services().globals.url_preview_max_spider_size() + services.globals.url_preview_max_spider_size() ); break; } @@ -595,7 +602,7 @@ async fn download_html(client: &reqwest::Client, url: &str) -> Result<UrlPreview let mut data = match html.opengraph.images.first() { None => UrlPreviewData::default(), - Some(obj) => download_image(client, &obj.url).await?, + Some(obj) => download_image(services, client, &obj.url).await?, }; let props = html.opengraph.properties; @@ -607,19 +614,19 @@ async fn download_html(client: &reqwest::Client, url: &str) -> Result<UrlPreview Ok(data) } -async fn request_url_preview(url: &str) -> Result<UrlPreviewData> { +async fn request_url_preview(services: &Services, url: &str) -> Result<UrlPreviewData> { if let Ok(ip) = IPAddress::parse(url) { - if !services().globals.valid_cidr_range(&ip) { + if !services.globals.valid_cidr_range(&ip) { return Err(Error::BadServerResponse("Requesting from this address is forbidden")); } } - let client = &services().globals.client.url_preview; + let client = &services.globals.client.url_preview; let response = client.head(url).send().await?; if let Some(remote_addr) = response.remote_addr() { if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) { - if !services().globals.valid_cidr_range(&ip) { + if !services.globals.valid_cidr_range(&ip) { return Err(Error::BadServerResponse("Requesting from this address is forbidden")); } } @@ -633,24 +640,24 @@ async fn request_url_preview(url: &str) -> Result<UrlPreviewData> { return Err(Error::BadRequest(ErrorKind::Unknown, "Unknown Content-Type")); }; let data = match content_type { - html if 
html.starts_with("text/html") => download_html(client, url).await?, - img if img.starts_with("image/") => download_image(client, url).await?, + html if html.starts_with("text/html") => download_html(services, client, url).await?, + img if img.starts_with("image/") => download_image(services, client, url).await?, _ => return Err(Error::BadRequest(ErrorKind::Unknown, "Unsupported Content-Type")), }; - services().media.set_url_preview(url, &data).await?; + services.media.set_url_preview(url, &data).await?; Ok(data) } -async fn get_url_preview(url: &str) -> Result<UrlPreviewData> { - if let Some(preview) = services().media.get_url_preview(url).await { +async fn get_url_preview(services: &Services, url: &str) -> Result<UrlPreviewData> { + if let Some(preview) = services.media.get_url_preview(url).await { return Ok(preview); } // ensure that only one request is made per URL let mutex_request = Arc::clone( - services() + services .media .url_preview_mutex .write() @@ -660,13 +667,13 @@ async fn get_url_preview(url: &str) -> Result<UrlPreviewData> { ); let _request_lock = mutex_request.lock().await; - match services().media.get_url_preview(url).await { + match services.media.get_url_preview(url).await { Some(preview) => Ok(preview), - None => request_url_preview(url).await, + None => request_url_preview(services, url).await, } } -fn url_preview_allowed(url_str: &str) -> bool { +fn url_preview_allowed(services: &Services, url_str: &str) -> bool { let url: Url = match Url::parse(url_str) { Ok(u) => u, Err(e) => { @@ -691,10 +698,10 @@ fn url_preview_allowed(url_str: &str) -> bool { Some(h) => h.to_owned(), }; - let allowlist_domain_contains = services().globals.url_preview_domain_contains_allowlist(); - let allowlist_domain_explicit = services().globals.url_preview_domain_explicit_allowlist(); - let denylist_domain_explicit = services().globals.url_preview_domain_explicit_denylist(); - let allowlist_url_contains = services().globals.url_preview_url_contains_allowlist(); + 
let allowlist_domain_contains = services.globals.url_preview_domain_contains_allowlist(); + let allowlist_domain_explicit = services.globals.url_preview_domain_explicit_allowlist(); + let denylist_domain_explicit = services.globals.url_preview_domain_explicit_denylist(); + let allowlist_url_contains = services.globals.url_preview_url_contains_allowlist(); if allowlist_domain_contains.contains(&"*".to_owned()) || allowlist_domain_explicit.contains(&"*".to_owned()) @@ -735,7 +742,7 @@ fn url_preview_allowed(url_str: &str) -> bool { } // check root domain if available and if user has root domain checks - if services().globals.url_preview_check_root_domain() { + if services.globals.url_preview_check_root_domain() { debug!("Checking root domain"); match host.split_once('.') { None => return false, diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 3adee631b..9fde99a4a 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -5,6 +5,7 @@ time::Instant, }; +use axum::extract::State; use axum_client_ip::InsecureClientIp; use conduit::{ debug, debug_warn, error, info, trace, utils, utils::math::continue_exponential_backoff_secs, warn, Error, @@ -43,9 +44,9 @@ pdu::{gen_event_id_canonical_json, PduBuilder}, rooms::state::RoomMutexGuard, sending::convert_to_outgoing_federation_event, - server_is_ours, user_is_local, + server_is_ours, user_is_local, Services, }, - services, Ruma, + Ruma, }; /// Checks if the room is banned in any way possible and the sender user is not @@ -53,14 +54,15 @@ /// /// Performs automatic deactivation if `auto_deactivate_banned_room_attempts` is /// enabled -#[tracing::instrument] +#[tracing::instrument(skip(services))] async fn banned_room_check( - user_id: &UserId, room_id: Option<&RoomId>, server_name: Option<&ServerName>, client_ip: IpAddr, + services: &Services, user_id: &UserId, room_id: Option<&RoomId>, server_name: Option<&ServerName>, + client_ip: IpAddr, ) -> Result<()> { - if 
!services().users.is_admin(user_id)? { + if !services.users.is_admin(user_id)? { if let Some(room_id) = room_id { - if services().rooms.metadata.is_banned(room_id)? - || services() + if services.rooms.metadata.is_banned(room_id)? + || services .globals .config .forbidden_remote_server_names @@ -71,13 +73,9 @@ async fn banned_room_check( room or banned room server name: {room_id}" ); - if services() - .globals - .config - .auto_deactivate_banned_room_attempts - { + if services.globals.config.auto_deactivate_banned_room_attempts { warn!("Automatically deactivating user {user_id} due to attempted banned room join"); - services() + services .admin .send_message(RoomMessageEventContent::text_plain(format!( "Automatically deactivating user {user_id} due to attempted banned room join from IP \ @@ -85,20 +83,20 @@ async fn banned_room_check( ))) .await; - if let Err(e) = services().users.deactivate_account(user_id) { + if let Err(e) = services.users.deactivate_account(user_id) { warn!(%user_id, %e, "Failed to deactivate account"); } - let all_joined_rooms: Vec<OwnedRoomId> = services() + let all_joined_rooms: Vec<OwnedRoomId> = services .rooms .state_cache .rooms_joined(user_id) .filter_map(Result::ok) .collect(); - update_displayname(user_id.into(), None, all_joined_rooms.clone()).await?; - update_avatar_url(user_id.into(), None, None, all_joined_rooms).await?; - leave_all_rooms(user_id).await; + update_displayname(services, user_id.into(), None, all_joined_rooms.clone()).await?; + update_avatar_url(services, user_id.into(), None, None, all_joined_rooms).await?; + leave_all_rooms(services, user_id).await; } return Err(Error::BadRequest( @@ -107,7 +105,7 @@ async fn banned_room_check( )); } } else if let Some(server_name) = server_name { - if services() + if services .globals .config .forbidden_remote_server_names @@ -118,13 +116,9 @@ async fn banned_room_check( that is globally forbidden. 
Rejecting.", ); - if services() - .globals - .config - .auto_deactivate_banned_room_attempts - { + if services.globals.config.auto_deactivate_banned_room_attempts { warn!("Automatically deactivating user {user_id} due to attempted banned room join"); - services() + services .admin .send_message(RoomMessageEventContent::text_plain(format!( "Automatically deactivating user {user_id} due to attempted banned room join from IP \ @@ -132,20 +126,20 @@ async fn banned_room_check( ))) .await; - if let Err(e) = services().users.deactivate_account(user_id) { + if let Err(e) = services.users.deactivate_account(user_id) { warn!(%user_id, %e, "Failed to deactivate account"); } - let all_joined_rooms: Vec<OwnedRoomId> = services() + let all_joined_rooms: Vec<OwnedRoomId> = services .rooms .state_cache .rooms_joined(user_id) .filter_map(Result::ok) .collect(); - update_displayname(user_id.into(), None, all_joined_rooms.clone()).await?; - update_avatar_url(user_id.into(), None, None, all_joined_rooms).await?; - leave_all_rooms(user_id).await; + update_displayname(services, user_id.into(), None, all_joined_rooms.clone()).await?; + update_avatar_url(services, user_id.into(), None, None, all_joined_rooms).await?; + leave_all_rooms(services, user_id).await; } return Err(Error::BadRequest( @@ -169,14 +163,22 @@ async fn banned_room_check( /// federation #[tracing::instrument(skip_all, fields(%client_ip), name = "join")] pub(crate) async fn join_room_by_id_route( - InsecureClientIp(client_ip): InsecureClientIp, body: Ruma<join_room_by_id::v3::Request>, + State(services): State<crate::State>, InsecureClientIp(client_ip): InsecureClientIp, + body: Ruma<join_room_by_id::v3::Request>, ) -> Result<join_room_by_id::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - banned_room_check(sender_user, Some(&body.room_id), body.room_id.server_name(), client_ip).await?; + banned_room_check( + services, + sender_user, + Some(&body.room_id), + 
body.room_id.server_name(), + client_ip, + ) + .await?; // There is no body.server_name for /roomId/join - let mut servers = services() + let mut servers = services .rooms .state_cache .servers_invite_via(&body.room_id) @@ -184,7 +186,7 @@ pub(crate) async fn join_room_by_id_route( .collect::<Vec<_>>(); servers.extend( - services() + services .rooms .state_cache .invite_state(sender_user, &body.room_id)? @@ -202,6 +204,7 @@ pub(crate) async fn join_room_by_id_route( } join_room_by_id_helper( + services, sender_user, &body.room_id, body.reason.clone(), @@ -222,18 +225,19 @@ pub(crate) async fn join_room_by_id_route( /// via room alias server name and room ID server name #[tracing::instrument(skip_all, fields(%client), name = "join")] pub(crate) async fn join_room_by_id_or_alias_route( - InsecureClientIp(client): InsecureClientIp, body: Ruma<join_room_by_id_or_alias::v3::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: Ruma<join_room_by_id_or_alias::v3::Request>, ) -> Result<join_room_by_id_or_alias::v3::Response> { let sender_user = body.sender_user.as_deref().expect("user is authenticated"); let body = body.body; let (servers, room_id) = match OwnedRoomId::try_from(body.room_id_or_alias) { Ok(room_id) => { - banned_room_check(sender_user, Some(&room_id), room_id.server_name(), client).await?; + banned_room_check(services, sender_user, Some(&room_id), room_id.server_name(), client).await?; let mut servers = body.server_name.clone(); servers.extend( - services() + services .rooms .state_cache .servers_invite_via(&room_id) @@ -241,7 +245,7 @@ pub(crate) async fn join_room_by_id_or_alias_route( ); servers.extend( - services() + services .rooms .state_cache .invite_state(sender_user, &room_id)? 
@@ -261,21 +265,21 @@ pub(crate) async fn join_room_by_id_or_alias_route( (servers, room_id) }, Err(room_alias) => { - let response = services() + let response = services .rooms .alias .resolve_alias(&room_alias, Some(&body.server_name.clone())) .await?; let (room_id, mut pre_servers) = response; - banned_room_check(sender_user, Some(&room_id), Some(room_alias.server_name()), client).await?; + banned_room_check(services, sender_user, Some(&room_id), Some(room_alias.server_name()), client).await?; let mut servers = body.server_name; if let Some(pre_servers) = &mut pre_servers { servers.append(pre_servers); } servers.extend( - services() + services .rooms .state_cache .servers_invite_via(&room_id) @@ -283,7 +287,7 @@ pub(crate) async fn join_room_by_id_or_alias_route( ); servers.extend( - services() + services .rooms .state_cache .invite_state(sender_user, &room_id)? @@ -301,6 +305,7 @@ pub(crate) async fn join_room_by_id_or_alias_route( }; let join_room_response = join_room_by_id_helper( + services, sender_user, &room_id, body.reason.clone(), @@ -319,10 +324,12 @@ pub(crate) async fn join_room_by_id_or_alias_route( /// Tries to leave the sender user from a room. /// /// - This should always work if the user is currently joined. -pub(crate) async fn leave_room_route(body: Ruma<leave_room::v3::Request>) -> Result<leave_room::v3::Response> { +pub(crate) async fn leave_room_route( + State(services): State<crate::State>, body: Ruma<leave_room::v3::Request>, +) -> Result<leave_room::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - leave_room(sender_user, &body.room_id, body.reason.clone()).await?; + leave_room(services, sender_user, &body.room_id, body.reason.clone()).await?; Ok(leave_room::v3::Response::new()) } @@ -332,11 +339,12 @@ pub(crate) async fn leave_room_route(body: Ruma<leave_room::v3::Request>) -> Res /// Tries to send an invite event into the room. 
#[tracing::instrument(skip_all, fields(%client), name = "invite")] pub(crate) async fn invite_user_route( - InsecureClientIp(client): InsecureClientIp, body: Ruma<invite_user::v3::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: Ruma<invite_user::v3::Request>, ) -> Result<invite_user::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services().users.is_admin(sender_user)? && services().globals.block_non_admin_invites() { + if !services.users.is_admin(sender_user)? && services.globals.block_non_admin_invites() { info!( "User {sender_user} is not an admin and attempted to send an invite to room {}", &body.room_id @@ -347,13 +355,13 @@ pub(crate) async fn invite_user_route( )); } - banned_room_check(sender_user, Some(&body.room_id), body.room_id.server_name(), client).await?; + banned_room_check(services, sender_user, Some(&body.room_id), body.room_id.server_name(), client).await?; if let invite_user::v3::InvitationRecipient::UserId { user_id, } = &body.recipient { - invite_helper(sender_user, user_id, &body.room_id, body.reason.clone(), false).await?; + invite_helper(services, sender_user, user_id, &body.room_id, body.reason.clone(), false).await?; Ok(invite_user::v3::Response {}) } else { Err(Error::BadRequest(ErrorKind::NotFound, "User not found.")) @@ -363,13 +371,15 @@ pub(crate) async fn invite_user_route( /// # `POST /_matrix/client/r0/rooms/{roomId}/kick` /// /// Tries to send a kick event into the room. 
-pub(crate) async fn kick_user_route(body: Ruma<kick_user::v3::Request>) -> Result<kick_user::v3::Response> { +pub(crate) async fn kick_user_route( + State(services): State<crate::State>, body: Ruma<kick_user::v3::Request>, +) -> Result<kick_user::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; + let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; let mut event: RoomMemberEventContent = serde_json::from_str( - services() + services .rooms .state_accessor .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())? @@ -385,7 +395,7 @@ pub(crate) async fn kick_user_route(body: Ruma<kick_user::v3::Request>) -> Resul event.membership = MembershipState::Leave; event.reason.clone_from(&body.reason); - services() + services .rooms .timeline .build_and_append_pdu( @@ -410,12 +420,14 @@ pub(crate) async fn kick_user_route(body: Ruma<kick_user::v3::Request>) -> Resul /// # `POST /_matrix/client/r0/rooms/{roomId}/ban` /// /// Tries to send a ban event into the room. -pub(crate) async fn ban_user_route(body: Ruma<ban_user::v3::Request>) -> Result<ban_user::v3::Response> { +pub(crate) async fn ban_user_route( + State(services): State<crate::State>, body: Ruma<ban_user::v3::Request>, +) -> Result<ban_user::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; + let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; - let event = services() + let event = services .rooms .state_accessor .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())? 
@@ -426,7 +438,7 @@ pub(crate) async fn ban_user_route(body: Ruma<ban_user::v3::Request>) -> Result< avatar_url: None, is_direct: None, third_party_invite: None, - blurhash: services().users.blurhash(&body.user_id).unwrap_or_default(), + blurhash: services.users.blurhash(&body.user_id).unwrap_or_default(), reason: body.reason.clone(), join_authorized_via_users_server: None, }), @@ -436,7 +448,7 @@ pub(crate) async fn ban_user_route(body: Ruma<ban_user::v3::Request>) -> Result< membership: MembershipState::Ban, displayname: None, avatar_url: None, - blurhash: services().users.blurhash(&body.user_id).unwrap_or_default(), + blurhash: services.users.blurhash(&body.user_id).unwrap_or_default(), reason: body.reason.clone(), join_authorized_via_users_server: None, ..event @@ -445,7 +457,7 @@ pub(crate) async fn ban_user_route(body: Ruma<ban_user::v3::Request>) -> Result< }, )?; - services() + services .rooms .timeline .build_and_append_pdu( @@ -470,13 +482,15 @@ pub(crate) async fn ban_user_route(body: Ruma<ban_user::v3::Request>) -> Result< /// # `POST /_matrix/client/r0/rooms/{roomId}/unban` /// /// Tries to send an unban event into the room. -pub(crate) async fn unban_user_route(body: Ruma<unban_user::v3::Request>) -> Result<unban_user::v3::Response> { +pub(crate) async fn unban_user_route( + State(services): State<crate::State>, body: Ruma<unban_user::v3::Request>, +) -> Result<unban_user::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; + let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; let mut event: RoomMemberEventContent = serde_json::from_str( - services() + services .rooms .state_accessor .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())? 
@@ -490,7 +504,7 @@ pub(crate) async fn unban_user_route(body: Ruma<unban_user::v3::Request>) -> Res event.reason.clone_from(&body.reason); event.join_authorized_via_users_server = None; - services() + services .rooms .timeline .build_and_append_pdu( @@ -521,10 +535,12 @@ pub(crate) async fn unban_user_route(body: Ruma<unban_user::v3::Request>) -> Res /// /// Note: Other devices of the user have no way of knowing the room was /// forgotten, so this has to be called from every device -pub(crate) async fn forget_room_route(body: Ruma<forget_room::v3::Request>) -> Result<forget_room::v3::Response> { +pub(crate) async fn forget_room_route( + State(services): State<crate::State>, body: Ruma<forget_room::v3::Request>, +) -> Result<forget_room::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if services() + if services .rooms .state_cache .is_joined(sender_user, &body.room_id)? @@ -535,7 +551,7 @@ pub(crate) async fn forget_room_route(body: Ruma<forget_room::v3::Request>) -> R )); } - services() + services .rooms .state_cache .forget(&body.room_id, sender_user)?; @@ -546,11 +562,13 @@ pub(crate) async fn forget_room_route(body: Ruma<forget_room::v3::Request>) -> R /// # `POST /_matrix/client/r0/joined_rooms` /// /// Lists all rooms the user has joined. 
-pub(crate) async fn joined_rooms_route(body: Ruma<joined_rooms::v3::Request>) -> Result<joined_rooms::v3::Response> { +pub(crate) async fn joined_rooms_route( + State(services): State<crate::State>, body: Ruma<joined_rooms::v3::Request>, +) -> Result<joined_rooms::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); Ok(joined_rooms::v3::Response { - joined_rooms: services() + joined_rooms: services .rooms .state_cache .rooms_joined(sender_user) @@ -566,11 +584,11 @@ pub(crate) async fn joined_rooms_route(body: Ruma<joined_rooms::v3::Request>) -> /// /// - Only works if the user is currently joined pub(crate) async fn get_member_events_route( - body: Ruma<get_member_events::v3::Request>, + State(services): State<crate::State>, body: Ruma<get_member_events::v3::Request>, ) -> Result<get_member_events::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services() + if !services .rooms .state_accessor .user_can_see_state_events(sender_user, &body.room_id)? @@ -582,7 +600,7 @@ pub(crate) async fn get_member_events_route( } Ok(get_member_events::v3::Response { - chunk: services() + chunk: services .rooms .state_accessor .room_state_full(&body.room_id) @@ -601,11 +619,11 @@ pub(crate) async fn get_member_events_route( /// - The sender user must be in the room /// - TODO: An appservice just needs a puppet joined pub(crate) async fn joined_members_route( - body: Ruma<joined_members::v3::Request>, + State(services): State<crate::State>, body: Ruma<joined_members::v3::Request>, ) -> Result<joined_members::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services() + if !services .rooms .state_accessor .user_can_see_state_events(sender_user, &body.room_id)? 
@@ -617,14 +635,14 @@ pub(crate) async fn joined_members_route( } let mut joined = BTreeMap::new(); - for user_id in services() + for user_id in services .rooms .state_cache .room_members(&body.room_id) .filter_map(Result::ok) { - let display_name = services().users.displayname(&user_id)?; - let avatar_url = services().users.avatar_url(&user_id)?; + let display_name = services.users.displayname(&user_id)?; + let avatar_url = services.users.avatar_url(&user_id)?; joined.insert( user_id, @@ -641,46 +659,48 @@ pub(crate) async fn joined_members_route( } pub async fn join_room_by_id_helper( - sender_user: &UserId, room_id: &RoomId, reason: Option<String>, servers: &[OwnedServerName], + services: &Services, sender_user: &UserId, room_id: &RoomId, reason: Option<String>, servers: &[OwnedServerName], third_party_signed: Option<&ThirdPartySigned>, ) -> Result<join_room_by_id::v3::Response> { - let state_lock = services().rooms.state.mutex.lock(room_id).await; + let state_lock = services.rooms.state.mutex.lock(room_id).await; - if matches!(services().rooms.state_cache.is_joined(sender_user, room_id), Ok(true)) { + if matches!(services.rooms.state_cache.is_joined(sender_user, room_id), Ok(true)) { debug_warn!("{sender_user} is already joined in {room_id}"); return Ok(join_room_by_id::v3::Response { room_id: room_id.into(), }); } - if services() + if services .rooms .state_cache - .server_in_room(services().globals.server_name(), room_id)? + .server_in_room(services.globals.server_name(), room_id)? 
|| servers.is_empty() || (servers.len() == 1 && server_is_ours(&servers[0])) { - join_room_by_id_helper_local(sender_user, room_id, reason, servers, third_party_signed, state_lock).await + join_room_by_id_helper_local(services, sender_user, room_id, reason, servers, third_party_signed, state_lock) + .await } else { // Ask a remote server if we are not participating in this room - join_room_by_id_helper_remote(sender_user, room_id, reason, servers, third_party_signed, state_lock).await + join_room_by_id_helper_remote(services, sender_user, room_id, reason, servers, third_party_signed, state_lock) + .await } } #[tracing::instrument(skip_all, fields(%sender_user, %room_id), name = "join_remote")] async fn join_room_by_id_helper_remote( - sender_user: &UserId, room_id: &RoomId, reason: Option<String>, servers: &[OwnedServerName], + services: &Services, sender_user: &UserId, room_id: &RoomId, reason: Option<String>, servers: &[OwnedServerName], _third_party_signed: Option<&ThirdPartySigned>, state_lock: RoomMutexGuard, ) -> Result<join_room_by_id::v3::Response> { info!("Joining {room_id} over federation."); - let (make_join_response, remote_server) = make_join_request(sender_user, room_id, servers).await?; + let (make_join_response, remote_server) = make_join_request(services, sender_user, room_id, servers).await?; info!("make_join finished"); let room_version_id = match make_join_response.room_version { Some(room_version) - if services() + if services .globals .supported_room_versions() .contains(&room_version) => @@ -705,7 +725,7 @@ async fn join_room_by_id_helper_remote( // TODO: Is origin needed? 
join_event_stub.insert( "origin".to_owned(), - CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), + CanonicalJsonValue::String(services.globals.server_name().as_str().to_owned()), ); join_event_stub.insert( "origin_server_ts".to_owned(), @@ -719,11 +739,11 @@ async fn join_room_by_id_helper_remote( "content".to_owned(), to_canonical_value(RoomMemberEventContent { membership: MembershipState::Join, - displayname: services().users.displayname(sender_user)?, - avatar_url: services().users.avatar_url(sender_user)?, + displayname: services.users.displayname(sender_user)?, + avatar_url: services.users.avatar_url(sender_user)?, is_direct: None, third_party_invite: None, - blurhash: services().users.blurhash(sender_user)?, + blurhash: services.users.blurhash(sender_user)?, reason, join_authorized_via_users_server: join_authorized_via_users_server.clone(), }) @@ -742,8 +762,8 @@ async fn join_room_by_id_helper_remote( // In order to create a compatible ref hash (EventID) the `hashes` field needs // to be present ruma::signatures::hash_and_sign_event( - services().globals.server_name().as_str(), - services().globals.keypair(), + services.globals.server_name().as_str(), + services.globals.keypair(), &mut join_event_stub, &room_version_id, ) @@ -764,7 +784,7 @@ async fn join_room_by_id_helper_remote( let mut join_event = join_event_stub; info!("Asking {remote_server} for send_join in room {room_id}"); - let send_join_response = services() + let send_join_response = services .sending .send_federation_request( &remote_server, @@ -852,7 +872,7 @@ async fn join_room_by_id_helper_remote( } } - services().rooms.short.get_or_create_shortroomid(room_id)?; + services.rooms.short.get_or_create_shortroomid(room_id)?; info!("Parsing join event"); let parsed_join_pdu = PduEvent::from_id_val(event_id, join_event.clone()) @@ -862,7 +882,7 @@ async fn join_room_by_id_helper_remote( let pub_key_map = RwLock::new(BTreeMap::new()); info!("Fetching join signing 
keys"); - services() + services .rooms .event_handler .fetch_join_signing_keys(&send_join_response, &room_version_id, &pub_key_map) @@ -873,7 +893,7 @@ async fn join_room_by_id_helper_remote( .room_state .state .iter() - .map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map)) + .map(|pdu| validate_and_add_event_id(services, pdu, &room_version_id, &pub_key_map)) { let Ok((event_id, value)) = result.await else { continue; @@ -884,12 +904,9 @@ async fn join_room_by_id_helper_remote( Error::BadServerResponse("Invalid PDU in send_join response.") })?; - services() - .rooms - .outlier - .add_pdu_outlier(&event_id, &value)?; + services.rooms.outlier.add_pdu_outlier(&event_id, &value)?; if let Some(state_key) = &pdu.state_key { - let shortstatekey = services() + let shortstatekey = services .rooms .short .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?; @@ -902,16 +919,13 @@ async fn join_room_by_id_helper_remote( .room_state .auth_chain .iter() - .map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map)) + .map(|pdu| validate_and_add_event_id(services, pdu, &room_version_id, &pub_key_map)) { let Ok((event_id, value)) = result.await else { continue; }; - services() - .rooms - .outlier - .add_pdu_outlier(&event_id, &value)?; + services.rooms.outlier.add_pdu_outlier(&event_id, &value)?; } debug!("Running send_join auth check"); @@ -921,12 +935,12 @@ async fn join_room_by_id_helper_remote( &parsed_join_pdu, None::<PduEvent>, // TODO: third party invite |k, s| { - services() + services .rooms .timeline .get_pdu( state.get( - &services() + &services .rooms .short .get_or_create_shortstatekey(&k.to_string().into(), s) @@ -946,37 +960,32 @@ async fn join_room_by_id_helper_remote( } info!("Saving state from send_join"); - let (statehash_before_join, new, removed) = services().rooms.state_compressor.save_state( + let (statehash_before_join, new, removed) = services.rooms.state_compressor.save_state( room_id, Arc::new( state 
.into_iter() - .map(|(k, id)| { - services() - .rooms - .state_compressor - .compress_state_event(k, &id) - }) + .map(|(k, id)| services.rooms.state_compressor.compress_state_event(k, &id)) .collect::<Result<_>>()?, ), )?; - services() + services .rooms .state .force_state(room_id, statehash_before_join, new, removed, &state_lock) .await?; info!("Updating joined counts for new room"); - services().rooms.state_cache.update_joined_count(room_id)?; + services.rooms.state_cache.update_joined_count(room_id)?; // We append to state before appending the pdu, so we don't have a moment in // time with the pdu without it's state. This is okay because append_pdu can't // fail. - let statehash_after_join = services().rooms.state.append_to_state(&parsed_join_pdu)?; + let statehash_after_join = services.rooms.state.append_to_state(&parsed_join_pdu)?; info!("Appending new room join event"); - services() + services .rooms .timeline .append_pdu( @@ -990,7 +999,7 @@ async fn join_room_by_id_helper_remote( info!("Setting final room state for new room"); // We set the room state after inserting the pdu, so that we never have a moment // in time where events in the current room state do not exist - services() + services .rooms .state .set_room_state(room_id, statehash_after_join, &state_lock)?; @@ -1000,16 +1009,15 @@ async fn join_room_by_id_helper_remote( #[tracing::instrument(skip_all, fields(%sender_user, %room_id), name = "join_local")] async fn join_room_by_id_helper_local( - sender_user: &UserId, room_id: &RoomId, reason: Option<String>, servers: &[OwnedServerName], + services: &Services, sender_user: &UserId, room_id: &RoomId, reason: Option<String>, servers: &[OwnedServerName], _third_party_signed: Option<&ThirdPartySigned>, state_lock: RoomMutexGuard, ) -> Result<join_room_by_id::v3::Response> { debug!("We can join locally"); - let join_rules_event = - services() - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; + let join_rules_event 
= services + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; let join_rules_event_content: Option<RoomJoinRulesEventContent> = join_rules_event .as_ref() @@ -1035,7 +1043,7 @@ async fn join_room_by_id_helper_local( _ => Vec::new(), }; - let local_members = services() + let local_members = services .rooms .state_cache .room_members(room_id) @@ -1046,14 +1054,14 @@ async fn join_room_by_id_helper_local( let mut join_authorized_via_users_server: Option<OwnedUserId> = None; if restriction_rooms.iter().any(|restriction_room_id| { - services() + services .rooms .state_cache .is_joined(sender_user, restriction_room_id) .unwrap_or(false) }) { for user in local_members { - if services() + if services .rooms .state_accessor .user_can_invite(room_id, &user, sender_user, &state_lock) @@ -1067,17 +1075,17 @@ async fn join_room_by_id_helper_local( let event = RoomMemberEventContent { membership: MembershipState::Join, - displayname: services().users.displayname(sender_user)?, - avatar_url: services().users.avatar_url(sender_user)?, + displayname: services.users.displayname(sender_user)?, + avatar_url: services.users.avatar_url(sender_user)?, is_direct: None, third_party_invite: None, - blurhash: services().users.blurhash(sender_user)?, + blurhash: services.users.blurhash(sender_user)?, reason: reason.clone(), join_authorized_via_users_server, }; // Try normal join first - let error = match services() + let error = match services .rooms .timeline .build_and_append_pdu( @@ -1104,11 +1112,11 @@ async fn join_room_by_id_helper_local( .any(|server_name| !server_is_ours(server_name)) { warn!("We couldn't do the join locally, maybe federation can help to satisfy the restricted join requirements"); - let (make_join_response, remote_server) = make_join_request(sender_user, room_id, servers).await?; + let (make_join_response, remote_server) = make_join_request(services, sender_user, room_id, servers).await?; let room_version_id = match 
make_join_response.room_version { Some(room_version_id) - if services() + if services .globals .supported_room_versions() .contains(&room_version_id) => @@ -1130,7 +1138,7 @@ async fn join_room_by_id_helper_local( // TODO: Is origin needed? join_event_stub.insert( "origin".to_owned(), - CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), + CanonicalJsonValue::String(services.globals.server_name().as_str().to_owned()), ); join_event_stub.insert( "origin_server_ts".to_owned(), @@ -1144,11 +1152,11 @@ async fn join_room_by_id_helper_local( "content".to_owned(), to_canonical_value(RoomMemberEventContent { membership: MembershipState::Join, - displayname: services().users.displayname(sender_user)?, - avatar_url: services().users.avatar_url(sender_user)?, + displayname: services.users.displayname(sender_user)?, + avatar_url: services.users.avatar_url(sender_user)?, is_direct: None, third_party_invite: None, - blurhash: services().users.blurhash(sender_user)?, + blurhash: services.users.blurhash(sender_user)?, reason, join_authorized_via_users_server, }) @@ -1167,8 +1175,8 @@ async fn join_room_by_id_helper_local( // In order to create a compatible ref hash (EventID) the `hashes` field needs // to be present ruma::signatures::hash_and_sign_event( - services().globals.server_name().as_str(), - services().globals.keypair(), + services.globals.server_name().as_str(), + services.globals.keypair(), &mut join_event_stub, &room_version_id, ) @@ -1188,7 +1196,7 @@ async fn join_room_by_id_helper_local( // It has enough fields to be called a proper event now let join_event = join_event_stub; - let send_join_response = services() + let send_join_response = services .sending .send_federation_request( &remote_server, @@ -1219,12 +1227,12 @@ async fn join_room_by_id_helper_local( drop(state_lock); let pub_key_map = RwLock::new(BTreeMap::new()); - services() + services .rooms .event_handler .fetch_required_signing_keys([&signed_value], &pub_key_map) 
.await?; - services() + services .rooms .event_handler .handle_incoming_pdu(&remote_server, room_id, &signed_event_id, signed_value, true, &pub_key_map) @@ -1240,7 +1248,7 @@ async fn join_room_by_id_helper_local( } async fn make_join_request( - sender_user: &UserId, room_id: &RoomId, servers: &[OwnedServerName], + services: &Services, sender_user: &UserId, room_id: &RoomId, servers: &[OwnedServerName], ) -> Result<(federation::membership::prepare_join_event::v1::Response, OwnedServerName)> { let mut make_join_response_and_server = Err(Error::BadServerResponse("No server available to assist in joining.")); @@ -1252,14 +1260,14 @@ async fn make_join_request( continue; } info!("Asking {remote_server} for make_join ({make_join_counter})"); - let make_join_response = services() + let make_join_response = services .sending .send_federation_request( remote_server, federation::membership::prepare_join_event::v1::Request { room_id: room_id.to_owned(), user_id: sender_user.to_owned(), - ver: services().globals.supported_room_versions(), + ver: services.globals.supported_room_versions(), }, ) .await; @@ -1309,7 +1317,8 @@ async fn make_join_request( } pub async fn validate_and_add_event_id( - pdu: &RawJsonValue, room_version: &RoomVersionId, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, + services: &Services, pdu: &RawJsonValue, room_version: &RoomVersionId, + pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, ) -> Result<(OwnedEventId, CanonicalJsonObject)> { let mut value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); @@ -1322,7 +1331,7 @@ pub async fn validate_and_add_event_id( .expect("ruma's reference hashes are valid event ids"); let back_off = |id| async { - match services() + match services .globals .bad_event_ratelimiter .write() @@ -1338,7 +1347,7 @@ pub async fn validate_and_add_event_id( } }; - if let Some((time, tries)) = services() + if let 
Some((time, tries)) = services .globals .bad_event_ratelimiter .read() @@ -1366,9 +1375,10 @@ pub async fn validate_and_add_event_id( } pub(crate) async fn invite_helper( - sender_user: &UserId, user_id: &UserId, room_id: &RoomId, reason: Option<String>, is_direct: bool, + services: &Services, sender_user: &UserId, user_id: &UserId, room_id: &RoomId, reason: Option<String>, + is_direct: bool, ) -> Result<()> { - if !services().users.is_admin(user_id)? && services().globals.block_non_admin_invites() { + if !services.users.is_admin(user_id)? && services.globals.block_non_admin_invites() { info!("User {sender_user} is not an admin and attempted to send an invite to room {room_id}"); return Err(Error::BadRequest( ErrorKind::forbidden(), @@ -1378,9 +1388,9 @@ pub(crate) async fn invite_helper( if !user_is_local(user_id) { let (pdu, pdu_json, invite_room_state) = { - let state_lock = services().rooms.state.mutex.lock(room_id).await; + let state_lock = services.rooms.state.mutex.lock(room_id).await; let content = to_raw_value(&RoomMemberEventContent { - avatar_url: services().users.avatar_url(user_id)?, + avatar_url: services.users.avatar_url(user_id)?, displayname: None, is_direct: Some(is_direct), membership: MembershipState::Invite, @@ -1391,7 +1401,7 @@ pub(crate) async fn invite_helper( }) .expect("member event is valid value"); - let (pdu, pdu_json) = services().rooms.timeline.create_hash_and_sign_event( + let (pdu, pdu_json) = services.rooms.timeline.create_hash_and_sign_event( PduBuilder { event_type: TimelineEventType::RoomMember, content, @@ -1404,16 +1414,16 @@ pub(crate) async fn invite_helper( &state_lock, )?; - let invite_room_state = services().rooms.state.calculate_invite_state(&pdu)?; + let invite_room_state = services.rooms.state.calculate_invite_state(&pdu)?; drop(state_lock); (pdu, pdu_json, invite_room_state) }; - let room_version_id = services().rooms.state.get_room_version(room_id)?; + let room_version_id = 
services.rooms.state.get_room_version(room_id)?; - let response = services() + let response = services .sending .send_federation_request( user_id.server_name(), @@ -1423,7 +1433,7 @@ pub(crate) async fn invite_helper( room_version: room_version_id.clone(), event: convert_to_outgoing_federation_event(pdu_json.clone()), invite_room_state, - via: services().rooms.state_cache.servers_route_via(room_id).ok(), + via: services.rooms.state_cache.servers_route_via(room_id).ok(), }, ) .await?; @@ -1459,13 +1469,13 @@ pub(crate) async fn invite_helper( ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; - services() + services .rooms .event_handler .fetch_required_signing_keys([&value], &pub_key_map) .await?; - let pdu_id: Vec<u8> = services() + let pdu_id: Vec<u8> = services .rooms .event_handler .handle_incoming_pdu(&origin, room_id, &event_id, value, true, &pub_key_map) @@ -1475,24 +1485,20 @@ pub(crate) async fn invite_helper( "Could not accept incoming PDU as timeline event.", ))?; - services().sending.send_pdu_room(room_id, &pdu_id)?; + services.sending.send_pdu_room(room_id, &pdu_id)?; return Ok(()); } - if !services() - .rooms - .state_cache - .is_joined(sender_user, room_id)? - { + if !services.rooms.state_cache.is_joined(sender_user, room_id)? 
{ return Err(Error::BadRequest( ErrorKind::forbidden(), "You don't have permission to view this room.", )); } - let state_lock = services().rooms.state.mutex.lock(room_id).await; + let state_lock = services.rooms.state.mutex.lock(room_id).await; - services() + services .rooms .timeline .build_and_append_pdu( @@ -1500,11 +1506,11 @@ pub(crate) async fn invite_helper( event_type: TimelineEventType::RoomMember, content: to_raw_value(&RoomMemberEventContent { membership: MembershipState::Invite, - displayname: services().users.displayname(user_id)?, - avatar_url: services().users.avatar_url(user_id)?, + displayname: services.users.displayname(user_id)?, + avatar_url: services.users.avatar_url(user_id)?, is_direct: Some(is_direct), third_party_invite: None, - blurhash: services().users.blurhash(user_id)?, + blurhash: services.users.blurhash(user_id)?, reason, join_authorized_via_users_server: None, }) @@ -1526,13 +1532,13 @@ pub(crate) async fn invite_helper( // Make a user leave all their joined rooms, forgets all rooms, and ignores // errors -pub async fn leave_all_rooms(user_id: &UserId) { - let all_rooms = services() +pub async fn leave_all_rooms(services: &Services, user_id: &UserId) { + let all_rooms = services .rooms .state_cache .rooms_joined(user_id) .chain( - services() + services .rooms .state_cache .rooms_invited(user_id) @@ -1546,35 +1552,35 @@ pub async fn leave_all_rooms(user_id: &UserId) { }; // ignore errors - if let Err(e) = leave_room(user_id, &room_id, None).await { + if let Err(e) = leave_room(services, user_id, &room_id, None).await { warn!(%room_id, %user_id, %e, "Failed to leave room"); } - if let Err(e) = services().rooms.state_cache.forget(&room_id, user_id) { + if let Err(e) = services.rooms.state_cache.forget(&room_id, user_id) { warn!(%room_id, %user_id, %e, "Failed to forget room"); } } } -pub async fn leave_room(user_id: &UserId, room_id: &RoomId, reason: Option<String>) -> Result<()> { +pub async fn leave_room(services: &Services, 
user_id: &UserId, room_id: &RoomId, reason: Option<String>) -> Result<()> { // Ask a remote server if we don't have this room - if !services() + if !services .rooms .state_cache - .server_in_room(services().globals.server_name(), room_id)? + .server_in_room(services.globals.server_name(), room_id)? { - if let Err(e) = remote_leave_room(user_id, room_id).await { + if let Err(e) = remote_leave_room(services, user_id, room_id).await { warn!("Failed to leave room {} remotely: {}", user_id, e); // Don't tell the client about this error } - let last_state = services() + let last_state = services .rooms .state_cache .invite_state(user_id, room_id)? - .map_or_else(|| services().rooms.state_cache.left_state(user_id, room_id), |s| Ok(Some(s)))?; + .map_or_else(|| services.rooms.state_cache.left_state(user_id, room_id), |s| Ok(Some(s)))?; // We always drop the invite, we can't rely on other servers - services().rooms.state_cache.update_membership( + services.rooms.state_cache.update_membership( room_id, user_id, RoomMemberEventContent::new(MembershipState::Leave), @@ -1584,10 +1590,10 @@ pub async fn leave_room(user_id: &UserId, room_id: &RoomId, reason: Option<Strin true, )?; } else { - let state_lock = services().rooms.state.mutex.lock(room_id).await; + let state_lock = services.rooms.state.mutex.lock(room_id).await; let member_event = - services() + services .rooms .state_accessor .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())?; @@ -1597,7 +1603,7 @@ pub async fn leave_room(user_id: &UserId, room_id: &RoomId, reason: Option<Strin None => { error!("Trying to leave a room you are not a member of."); - services().rooms.state_cache.update_membership( + services.rooms.state_cache.update_membership( room_id, user_id, RoomMemberEventContent::new(MembershipState::Leave), @@ -1619,7 +1625,7 @@ pub async fn leave_room(user_id: &UserId, room_id: &RoomId, reason: Option<Strin event.membership = MembershipState::Leave; event.reason = reason; - services() + 
services .rooms .timeline .build_and_append_pdu( @@ -1640,16 +1646,16 @@ pub async fn leave_room(user_id: &UserId, room_id: &RoomId, reason: Option<Strin Ok(()) } -async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> { +async fn remote_leave_room(services: &Services, user_id: &UserId, room_id: &RoomId) -> Result<()> { let mut make_leave_response_and_server = Err(Error::BadServerResponse("No server available to assist in leaving.")); - let invite_state = services() + let invite_state = services .rooms .state_cache .invite_state(user_id, room_id)? .ok_or(Error::BadRequest(ErrorKind::BadState, "User is not invited."))?; - let mut servers: HashSet<OwnedServerName> = services() + let mut servers: HashSet<OwnedServerName> = services .rooms .state_cache .servers_invite_via(room_id) @@ -1669,7 +1675,7 @@ async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> { debug!("servers in remote_leave_room: {servers:?}"); for remote_server in servers { - let make_leave_response = services() + let make_leave_response = services .sending .send_federation_request( &remote_server, @@ -1691,7 +1697,7 @@ async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> { let room_version_id = match make_leave_response.room_version { Some(version) - if services() + if services .globals .supported_room_versions() .contains(&version) => @@ -1707,7 +1713,7 @@ async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> { // TODO: Is origin needed? 
leave_event_stub.insert( "origin".to_owned(), - CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), + CanonicalJsonValue::String(services.globals.server_name().as_str().to_owned()), ); leave_event_stub.insert( "origin_server_ts".to_owned(), @@ -1729,8 +1735,8 @@ async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> { // In order to create a compatible ref hash (EventID) the `hashes` field needs // to be present ruma::signatures::hash_and_sign_event( - services().globals.server_name().as_str(), - services().globals.keypair(), + services.globals.server_name().as_str(), + services.globals.keypair(), &mut leave_event_stub, &room_version_id, ) @@ -1750,7 +1756,7 @@ async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> { // It has enough fields to be called a proper event now let leave_event = leave_event_stub; - services() + services .sending .send_federation_request( &remote_server, diff --git a/src/api/client/message.rs b/src/api/client/message.rs index c376ee522..c0b5cf0c3 100644 --- a/src/api/client/message.rs +++ b/src/api/client/message.rs @@ -1,5 +1,6 @@ use std::collections::{BTreeMap, HashSet}; +use axum::extract::State; use conduit::PduCount; use ruma::{ api::client::{ @@ -12,7 +13,10 @@ }; use serde_json::{from_str, Value}; -use crate::{service::pdu::PduBuilder, services, utils, Error, PduEvent, Result, Ruma}; +use crate::{ + service::{pdu::PduBuilder, Services}, + utils, Error, PduEvent, Result, Ruma, +}; /// # `PUT /_matrix/client/v3/rooms/{roomId}/send/{eventType}/{txnId}` /// @@ -24,21 +28,19 @@ /// - Tries to send the event into the room, auth rules will determine if it is /// allowed pub(crate) async fn send_message_event_route( - body: Ruma<send_message_event::v3::Request>, + State(services): State<crate::State>, body: Ruma<send_message_event::v3::Request>, ) -> Result<send_message_event::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is 
authenticated"); let sender_device = body.sender_device.as_deref(); - let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; + let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; // Forbid m.room.encrypted if encryption is disabled - if MessageLikeEventType::RoomEncrypted == body.event_type && !services().globals.allow_encryption() { + if MessageLikeEventType::RoomEncrypted == body.event_type && !services.globals.allow_encryption() { return Err(Error::BadRequest(ErrorKind::forbidden(), "Encryption has been disabled")); } - if body.event_type == MessageLikeEventType::CallInvite - && services().rooms.directory.is_public_room(&body.room_id)? - { + if body.event_type == MessageLikeEventType::CallInvite && services.rooms.directory.is_public_room(&body.room_id)? { return Err(Error::BadRequest( ErrorKind::forbidden(), "Room call invites are not allowed in public rooms", @@ -46,7 +48,7 @@ pub(crate) async fn send_message_event_route( } // Check if this is a new transaction id - if let Some(response) = services() + if let Some(response) = services .transaction_ids .existing_txnid(sender_user, sender_device, &body.txn_id)? 
{ @@ -71,7 +73,7 @@ pub(crate) async fn send_message_event_route( let mut unsigned = BTreeMap::new(); unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into()); - let event_id = services() + let event_id = services .rooms .timeline .build_and_append_pdu( @@ -89,7 +91,7 @@ pub(crate) async fn send_message_event_route( ) .await?; - services() + services .transaction_ids .add_txnid(sender_user, sender_device, &body.txn_id, event_id.as_bytes())?; @@ -105,7 +107,7 @@ pub(crate) async fn send_message_event_route( /// - Only works if the user is joined (TODO: always allow, but only show events /// where the user was joined, depending on `history_visibility`) pub(crate) async fn get_message_events_route( - body: Ruma<get_message_events::v3::Request>, + State(services): State<crate::State>, body: Ruma<get_message_events::v3::Request>, ) -> Result<get_message_events::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); @@ -123,7 +125,7 @@ pub(crate) async fn get_message_events_route( .as_ref() .and_then(|t| PduCount::try_from_string(t).ok()); - services() + services .rooms .lazy_loading .lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from) @@ -139,12 +141,12 @@ pub(crate) async fn get_message_events_route( match body.dir { ruma::api::Direction::Forward => { - let events_after: Vec<_> = services() + let events_after: Vec<_> = services .rooms .timeline .pdus_after(sender_user, &body.room_id, from)? 
.filter_map(Result::ok) // Filter out buggy events - .filter(|(_, pdu)| { contains_url_filter(pdu, &body.filter) && visibility_filter(pdu, sender_user, &body.room_id) + .filter(|(_, pdu)| { contains_url_filter(pdu, &body.filter) && visibility_filter(services, pdu, sender_user, &body.room_id) }) .take_while(|&(k, _)| Some(k) != to) // Stop at `to` @@ -157,7 +159,7 @@ pub(crate) async fn get_message_events_route( * https://github.com/vector-im/element-web/issues/21034 */ if !cfg!(feature = "element_hacks") - && !services().rooms.lazy_loading.lazy_load_was_sent_before( + && !services.rooms.lazy_loading.lazy_load_was_sent_before( sender_user, sender_device, &body.room_id, @@ -181,17 +183,17 @@ pub(crate) async fn get_message_events_route( resp.chunk = events_after; }, ruma::api::Direction::Backward => { - services() + services .rooms .timeline .backfill_if_required(&body.room_id, from) .await?; - let events_before: Vec<_> = services() + let events_before: Vec<_> = services .rooms .timeline .pdus_until(sender_user, &body.room_id, from)? 
.filter_map(Result::ok) // Filter out buggy events - .filter(|(_, pdu)| {contains_url_filter(pdu, &body.filter) && visibility_filter(pdu, sender_user, &body.room_id)}) + .filter(|(_, pdu)| {contains_url_filter(pdu, &body.filter) && visibility_filter(services, pdu, sender_user, &body.room_id)}) .take_while(|&(k, _)| Some(k) != to) // Stop at `to` .take(limit) .collect(); @@ -202,7 +204,7 @@ pub(crate) async fn get_message_events_route( * https://github.com/vector-im/element-web/issues/21034 */ if !cfg!(feature = "element_hacks") - && !services().rooms.lazy_loading.lazy_load_was_sent_before( + && !services.rooms.lazy_loading.lazy_load_was_sent_before( sender_user, sender_device, &body.room_id, @@ -229,11 +231,12 @@ pub(crate) async fn get_message_events_route( resp.state = Vec::new(); for ll_id in &lazy_loaded { - if let Some(member_event) = services().rooms.state_accessor.room_state_get( - &body.room_id, - &StateEventType::RoomMember, - ll_id.as_str(), - )? { + if let Some(member_event) = + services + .rooms + .state_accessor + .room_state_get(&body.room_id, &StateEventType::RoomMember, ll_id.as_str())? 
+ { resp.state.push(member_event.to_state_event()); } } @@ -241,7 +244,7 @@ pub(crate) async fn get_message_events_route( // remove the feature check when we are sure clients like element can handle it if !cfg!(feature = "element_hacks") { if let Some(next_token) = next_token { - services() + services .rooms .lazy_loading .lazy_load_mark_sent(sender_user, sender_device, &body.room_id, lazy_loaded, next_token) @@ -252,8 +255,8 @@ pub(crate) async fn get_message_events_route( Ok(resp) } -fn visibility_filter(pdu: &PduEvent, user_id: &UserId, room_id: &RoomId) -> bool { - services() +fn visibility_filter(services: &Services, pdu: &PduEvent, user_id: &UserId, room_id: &RoomId) -> bool { + services .rooms .state_accessor .user_can_see_event(user_id, room_id, &pdu.event_id) diff --git a/src/api/client/openid.rs b/src/api/client/openid.rs index a19320526..3e4c6ca8c 100644 --- a/src/api/client/openid.rs +++ b/src/api/client/openid.rs @@ -1,5 +1,6 @@ use std::time::Duration; +use axum::extract::State; use conduit::utils; use ruma::{ api::client::{account, error::ErrorKind}, @@ -7,7 +8,7 @@ }; use super::TOKEN_LENGTH; -use crate::{services, Error, Result, Ruma}; +use crate::{Error, Result, Ruma}; /// # `POST /_matrix/client/v3/user/{userId}/openid/request_token` /// @@ -15,7 +16,7 @@ /// /// - The token generated is only valid for the OpenID API pub(crate) async fn create_openid_token_route( - body: Ruma<account::request_openid_token::v3::Request>, + State(services): State<crate::State>, body: Ruma<account::request_openid_token::v3::Request>, ) -> Result<account::request_openid_token::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -28,14 +29,14 @@ pub(crate) async fn create_openid_token_route( let access_token = utils::random_string(TOKEN_LENGTH); - let expires_in = services() + let expires_in = services .users .create_openid_token(&body.user_id, &access_token)?; Ok(account::request_openid_token::v3::Response { access_token, 
token_type: TokenType::Bearer, - matrix_server_name: services().globals.config.server_name.clone(), + matrix_server_name: services.globals.config.server_name.clone(), expires_in: Duration::from_secs(expires_in), }) } diff --git a/src/api/client/presence.rs b/src/api/client/presence.rs index 775f82ecb..8384d5aca 100644 --- a/src/api/client/presence.rs +++ b/src/api/client/presence.rs @@ -1,22 +1,24 @@ use std::time::Duration; +use axum::extract::State; use ruma::api::client::{ error::ErrorKind, presence::{get_presence, set_presence}, }; -use crate::{services, Error, Result, Ruma}; +use crate::{Error, Result, Ruma}; /// # `PUT /_matrix/client/r0/presence/{userId}/status` /// /// Sets the presence state of the sender user. -pub(crate) async fn set_presence_route(body: Ruma<set_presence::v3::Request>) -> Result<set_presence::v3::Response> { - if !services().globals.allow_local_presence() { +pub(crate) async fn set_presence_route( + State(services): State<crate::State>, body: Ruma<set_presence::v3::Request>, +) -> Result<set_presence::v3::Response> { + if !services.globals.allow_local_presence() { return Err(Error::BadRequest(ErrorKind::forbidden(), "Presence is disabled on this server")); } let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if sender_user != &body.user_id && body.appservice_info.is_none() { return Err(Error::BadRequest( ErrorKind::InvalidParam, @@ -24,7 +26,7 @@ pub(crate) async fn set_presence_route(body: Ruma<set_presence::v3::Request>) -> )); } - services() + services .presence .set_presence(sender_user, &body.presence, None, None, body.status_msg.clone())?; @@ -36,8 +38,10 @@ pub(crate) async fn set_presence_route(body: Ruma<set_presence::v3::Request>) -> /// Gets the presence state of the given user. 
/// /// - Only works if you share a room with the user -pub(crate) async fn get_presence_route(body: Ruma<get_presence::v3::Request>) -> Result<get_presence::v3::Response> { - if !services().globals.allow_local_presence() { +pub(crate) async fn get_presence_route( + State(services): State<crate::State>, body: Ruma<get_presence::v3::Request>, +) -> Result<get_presence::v3::Response> { + if !services.globals.allow_local_presence() { return Err(Error::BadRequest(ErrorKind::forbidden(), "Presence is disabled on this server")); } @@ -45,12 +49,12 @@ pub(crate) async fn get_presence_route(body: Ruma<get_presence::v3::Request>) -> let mut presence_event = None; - for _room_id in services() + for _room_id in services .rooms .user .get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? { - if let Some(presence) = services().presence.get_presence(&body.user_id)? { + if let Some(presence) = services.presence.get_presence(&body.user_id)? { presence_event = Some(presence); break; } diff --git a/src/api/client/profile.rs b/src/api/client/profile.rs index b03059397..3b2c32ecf 100644 --- a/src/api/client/profile.rs +++ b/src/api/client/profile.rs @@ -1,3 +1,4 @@ +use axum::extract::State; use ruma::{ api::{ client::{ @@ -14,8 +15,8 @@ use tracing::warn; use crate::{ - service::{pdu::PduBuilder, user_is_local}, - services, Error, Result, Ruma, + service::{pdu::PduBuilder, user_is_local, Services}, + Error, Result, Ruma, }; /// # `PUT /_matrix/client/r0/profile/{userId}/displayname` @@ -24,21 +25,21 @@ /// /// - Also makes sure other users receive the update using presence EDUs pub(crate) async fn set_displayname_route( - body: Ruma<set_display_name::v3::Request>, + State(services): State<crate::State>, body: Ruma<set_display_name::v3::Request>, ) -> Result<set_display_name::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let all_joined_rooms: Vec<OwnedRoomId> = services() + let all_joined_rooms: Vec<OwnedRoomId> = services 
.rooms .state_cache .rooms_joined(sender_user) .filter_map(Result::ok) .collect(); - update_displayname(sender_user.clone(), body.displayname.clone(), all_joined_rooms).await?; + update_displayname(services, sender_user.clone(), body.displayname.clone(), all_joined_rooms).await?; - if services().globals.allow_local_presence() { + if services.globals.allow_local_presence() { // Presence update - services() + services .presence .ping_presence(sender_user, &PresenceState::Online)?; } @@ -53,11 +54,11 @@ pub(crate) async fn set_displayname_route( /// - If user is on another server and we do not have a local copy already fetch /// displayname over federation pub(crate) async fn get_displayname_route( - body: Ruma<get_display_name::v3::Request>, + State(services): State<crate::State>, body: Ruma<get_display_name::v3::Request>, ) -> Result<get_display_name::v3::Response> { if !user_is_local(&body.user_id) { // Create and update our local copy of the user - if let Ok(response) = services() + if let Ok(response) = services .sending .send_federation_request( body.user_id.server_name(), @@ -68,19 +69,19 @@ pub(crate) async fn get_displayname_route( ) .await { - if !services().users.exists(&body.user_id)? { - services().users.create(&body.user_id, None)?; + if !services.users.exists(&body.user_id)? { + services.users.create(&body.user_id, None)?; } - services() + services .users .set_displayname(&body.user_id, response.displayname.clone()) .await?; - services() + services .users .set_avatar_url(&body.user_id, response.avatar_url.clone()) .await?; - services() + services .users .set_blurhash(&body.user_id, response.blurhash.clone()) .await?; @@ -91,14 +92,14 @@ pub(crate) async fn get_displayname_route( } } - if !services().users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id)? 
{ // Return 404 if this user doesn't exist and we couldn't fetch it over // federation return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); } Ok(get_display_name::v3::Response { - displayname: services().users.displayname(&body.user_id)?, + displayname: services.users.displayname(&body.user_id)?, }) } @@ -108,10 +109,10 @@ pub(crate) async fn get_displayname_route( /// /// - Also makes sure other users receive the update using presence EDUs pub(crate) async fn set_avatar_url_route( - body: Ruma<set_avatar_url::v3::Request>, + State(services): State<crate::State>, body: Ruma<set_avatar_url::v3::Request>, ) -> Result<set_avatar_url::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let all_joined_rooms: Vec<OwnedRoomId> = services() + let all_joined_rooms: Vec<OwnedRoomId> = services .rooms .state_cache .rooms_joined(sender_user) @@ -119,6 +120,7 @@ pub(crate) async fn set_avatar_url_route( .collect(); update_avatar_url( + services, sender_user.clone(), body.avatar_url.clone(), body.blurhash.clone(), @@ -126,9 +128,9 @@ pub(crate) async fn set_avatar_url_route( ) .await?; - if services().globals.allow_local_presence() { + if services.globals.allow_local_presence() { // Presence update - services() + services .presence .ping_presence(sender_user, &PresenceState::Online)?; } @@ -143,11 +145,11 @@ pub(crate) async fn set_avatar_url_route( /// - If user is on another server and we do not have a local copy already fetch /// `avatar_url` and blurhash over federation pub(crate) async fn get_avatar_url_route( - body: Ruma<get_avatar_url::v3::Request>, + State(services): State<crate::State>, body: Ruma<get_avatar_url::v3::Request>, ) -> Result<get_avatar_url::v3::Response> { if !user_is_local(&body.user_id) { // Create and update our local copy of the user - if let Ok(response) = services() + if let Ok(response) = services .sending .send_federation_request( body.user_id.server_name(), @@ -158,19 +160,19 @@ 
pub(crate) async fn get_avatar_url_route( ) .await { - if !services().users.exists(&body.user_id)? { - services().users.create(&body.user_id, None)?; + if !services.users.exists(&body.user_id)? { + services.users.create(&body.user_id, None)?; } - services() + services .users .set_displayname(&body.user_id, response.displayname.clone()) .await?; - services() + services .users .set_avatar_url(&body.user_id, response.avatar_url.clone()) .await?; - services() + services .users .set_blurhash(&body.user_id, response.blurhash.clone()) .await?; @@ -182,15 +184,15 @@ pub(crate) async fn get_avatar_url_route( } } - if !services().users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id)? { // Return 404 if this user doesn't exist and we couldn't fetch it over // federation return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); } Ok(get_avatar_url::v3::Response { - avatar_url: services().users.avatar_url(&body.user_id)?, - blurhash: services().users.blurhash(&body.user_id)?, + avatar_url: services.users.avatar_url(&body.user_id)?, + blurhash: services.users.blurhash(&body.user_id)?, }) } @@ -200,10 +202,12 @@ pub(crate) async fn get_avatar_url_route( /// /// - If user is on another server and we do not have a local copy already, /// fetch profile over federation. -pub(crate) async fn get_profile_route(body: Ruma<get_profile::v3::Request>) -> Result<get_profile::v3::Response> { +pub(crate) async fn get_profile_route( + State(services): State<crate::State>, body: Ruma<get_profile::v3::Request>, +) -> Result<get_profile::v3::Response> { if !user_is_local(&body.user_id) { // Create and update our local copy of the user - if let Ok(response) = services() + if let Ok(response) = services .sending .send_federation_request( body.user_id.server_name(), @@ -214,19 +218,19 @@ pub(crate) async fn get_profile_route(body: Ruma<get_profile::v3::Request>) -> R ) .await { - if !services().users.exists(&body.user_id)? 
{ - services().users.create(&body.user_id, None)?; + if !services.users.exists(&body.user_id)? { + services.users.create(&body.user_id, None)?; } - services() + services .users .set_displayname(&body.user_id, response.displayname.clone()) .await?; - services() + services .users .set_avatar_url(&body.user_id, response.avatar_url.clone()) .await?; - services() + services .users .set_blurhash(&body.user_id, response.blurhash.clone()) .await?; @@ -239,23 +243,23 @@ pub(crate) async fn get_profile_route(body: Ruma<get_profile::v3::Request>) -> R } } - if !services().users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id)? { // Return 404 if this user doesn't exist and we couldn't fetch it over // federation return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); } Ok(get_profile::v3::Response { - avatar_url: services().users.avatar_url(&body.user_id)?, - blurhash: services().users.blurhash(&body.user_id)?, - displayname: services().users.displayname(&body.user_id)?, + avatar_url: services.users.avatar_url(&body.user_id)?, + blurhash: services.users.blurhash(&body.user_id)?, + displayname: services.users.displayname(&body.user_id)?, }) } pub async fn update_displayname( - user_id: OwnedUserId, displayname: Option<String>, all_joined_rooms: Vec<OwnedRoomId>, + services: &Services, user_id: OwnedUserId, displayname: Option<String>, all_joined_rooms: Vec<OwnedRoomId>, ) -> Result<()> { - services() + services .users .set_displayname(&user_id, displayname.clone()) .await?; @@ -271,7 +275,7 @@ pub async fn update_displayname( displayname: displayname.clone(), join_authorized_via_users_server: None, ..serde_json::from_str( - services() + services .rooms .state_accessor .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())? 
@@ -294,19 +298,20 @@ pub async fn update_displayname( .filter_map(Result::ok) .collect(); - update_all_rooms(all_joined_rooms, user_id).await; + update_all_rooms(services, all_joined_rooms, user_id).await; Ok(()) } pub async fn update_avatar_url( - user_id: OwnedUserId, avatar_url: Option<OwnedMxcUri>, blurhash: Option<String>, all_joined_rooms: Vec<OwnedRoomId>, + services: &Services, user_id: OwnedUserId, avatar_url: Option<OwnedMxcUri>, blurhash: Option<String>, + all_joined_rooms: Vec<OwnedRoomId>, ) -> Result<()> { - services() + services .users .set_avatar_url(&user_id, avatar_url.clone()) .await?; - services() + services .users .set_blurhash(&user_id, blurhash.clone()) .await?; @@ -323,7 +328,7 @@ pub async fn update_avatar_url( blurhash: blurhash.clone(), join_authorized_via_users_server: None, ..serde_json::from_str( - services() + services .rooms .state_accessor .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())? @@ -346,15 +351,17 @@ pub async fn update_avatar_url( .filter_map(Result::ok) .collect(); - update_all_rooms(all_joined_rooms, user_id).await; + update_all_rooms(services, all_joined_rooms, user_id).await; Ok(()) } -pub async fn update_all_rooms(all_joined_rooms: Vec<(PduBuilder, &OwnedRoomId)>, user_id: OwnedUserId) { +pub async fn update_all_rooms( + services: &Services, all_joined_rooms: Vec<(PduBuilder, &OwnedRoomId)>, user_id: OwnedUserId, +) { for (pdu_builder, room_id) in all_joined_rooms { - let state_lock = services().rooms.state.mutex.lock(room_id).await; - if let Err(e) = services() + let state_lock = services.rooms.state.mutex.lock(room_id).await; + if let Err(e) = services .rooms .timeline .build_and_append_pdu(pdu_builder, &user_id, room_id, &state_lock) diff --git a/src/api/client/push.rs b/src/api/client/push.rs index 5d7fd6306..26462e790 100644 --- a/src/api/client/push.rs +++ b/src/api/client/push.rs @@ -1,3 +1,4 @@ +use axum::extract::State; use ruma::{ api::client::{ error::ErrorKind, @@ -10,18 +11,18 @@ 
push::{InsertPushRuleError, RemovePushRuleError, Ruleset}, }; -use crate::{services, Error, Result, Ruma}; +use crate::{Error, Result, Ruma}; /// # `GET /_matrix/client/r0/pushrules/` /// /// Retrieves the push rules event for this user. pub(crate) async fn get_pushrules_all_route( - body: Ruma<get_pushrules_all::v3::Request>, + State(services): State<crate::State>, body: Ruma<get_pushrules_all::v3::Request>, ) -> Result<get_pushrules_all::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let event = - services() + services .account_data .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?; @@ -34,7 +35,7 @@ pub(crate) async fn get_pushrules_all_route( global: account_data.global, }) } else { - services().account_data.update( + services.account_data.update( None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into(), @@ -55,10 +56,12 @@ pub(crate) async fn get_pushrules_all_route( /// # `GET /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}` /// /// Retrieves a single specified push rule for this user. -pub(crate) async fn get_pushrule_route(body: Ruma<get_pushrule::v3::Request>) -> Result<get_pushrule::v3::Response> { +pub(crate) async fn get_pushrule_route( + State(services): State<crate::State>, body: Ruma<get_pushrule::v3::Request>, +) -> Result<get_pushrule::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services() + let event = services .account_data .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; @@ -84,7 +87,9 @@ pub(crate) async fn get_pushrule_route(body: Ruma<get_pushrule::v3::Request>) -> /// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}` /// /// Creates a single specified push rule for this user. 
-pub(crate) async fn set_pushrule_route(body: Ruma<set_pushrule::v3::Request>) -> Result<set_pushrule::v3::Response> { +pub(crate) async fn set_pushrule_route( + State(services): State<crate::State>, body: Ruma<set_pushrule::v3::Request>, +) -> Result<set_pushrule::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let body = body.body; @@ -95,7 +100,7 @@ pub(crate) async fn set_pushrule_route(body: Ruma<set_pushrule::v3::Request>) -> )); } - let event = services() + let event = services .account_data .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; @@ -134,7 +139,7 @@ pub(crate) async fn set_pushrule_route(body: Ruma<set_pushrule::v3::Request>) -> return Err(err); } - services().account_data.update( + services.account_data.update( None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into(), @@ -148,7 +153,7 @@ pub(crate) async fn set_pushrule_route(body: Ruma<set_pushrule::v3::Request>) -> /// /// Gets the actions of a single specified push rule for this user. pub(crate) async fn get_pushrule_actions_route( - body: Ruma<get_pushrule_actions::v3::Request>, + State(services): State<crate::State>, body: Ruma<get_pushrule_actions::v3::Request>, ) -> Result<get_pushrule_actions::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -159,7 +164,7 @@ pub(crate) async fn get_pushrule_actions_route( )); } - let event = services() + let event = services .account_data .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; @@ -183,7 +188,7 @@ pub(crate) async fn get_pushrule_actions_route( /// /// Sets the actions of a single specified push rule for this user. 
pub(crate) async fn set_pushrule_actions_route( - body: Ruma<set_pushrule_actions::v3::Request>, + State(services): State<crate::State>, body: Ruma<set_pushrule_actions::v3::Request>, ) -> Result<set_pushrule_actions::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -194,7 +199,7 @@ pub(crate) async fn set_pushrule_actions_route( )); } - let event = services() + let event = services .account_data .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; @@ -211,7 +216,7 @@ pub(crate) async fn set_pushrule_actions_route( return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")); } - services().account_data.update( + services.account_data.update( None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into(), @@ -225,7 +230,7 @@ pub(crate) async fn set_pushrule_actions_route( /// /// Gets the enabled status of a single specified push rule for this user. pub(crate) async fn get_pushrule_enabled_route( - body: Ruma<get_pushrule_enabled::v3::Request>, + State(services): State<crate::State>, body: Ruma<get_pushrule_enabled::v3::Request>, ) -> Result<get_pushrule_enabled::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -236,7 +241,7 @@ pub(crate) async fn get_pushrule_enabled_route( )); } - let event = services() + let event = services .account_data .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; @@ -259,7 +264,7 @@ pub(crate) async fn get_pushrule_enabled_route( /// /// Sets the enabled status of a single specified push rule for this user. 
pub(crate) async fn set_pushrule_enabled_route( - body: Ruma<set_pushrule_enabled::v3::Request>, + State(services): State<crate::State>, body: Ruma<set_pushrule_enabled::v3::Request>, ) -> Result<set_pushrule_enabled::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -270,7 +275,7 @@ pub(crate) async fn set_pushrule_enabled_route( )); } - let event = services() + let event = services .account_data .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; @@ -287,7 +292,7 @@ pub(crate) async fn set_pushrule_enabled_route( return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")); } - services().account_data.update( + services.account_data.update( None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into(), @@ -301,7 +306,7 @@ pub(crate) async fn set_pushrule_enabled_route( /// /// Deletes a single specified push rule for this user. pub(crate) async fn delete_pushrule_route( - body: Ruma<delete_pushrule::v3::Request>, + State(services): State<crate::State>, body: Ruma<delete_pushrule::v3::Request>, ) -> Result<delete_pushrule::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -312,7 +317,7 @@ pub(crate) async fn delete_pushrule_route( )); } - let event = services() + let event = services .account_data .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? 
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; @@ -336,7 +341,7 @@ pub(crate) async fn delete_pushrule_route( return Err(err); } - services().account_data.update( + services.account_data.update( None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into(), @@ -349,11 +354,13 @@ pub(crate) async fn delete_pushrule_route( /// # `GET /_matrix/client/r0/pushers` /// /// Gets all currently active pushers for the sender user. -pub(crate) async fn get_pushers_route(body: Ruma<get_pushers::v3::Request>) -> Result<get_pushers::v3::Response> { +pub(crate) async fn get_pushers_route( + State(services): State<crate::State>, body: Ruma<get_pushers::v3::Request>, +) -> Result<get_pushers::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); Ok(get_pushers::v3::Response { - pushers: services().pusher.get_pushers(sender_user)?, + pushers: services.pusher.get_pushers(sender_user)?, }) } @@ -362,10 +369,12 @@ pub(crate) async fn get_pushers_route(body: Ruma<get_pushers::v3::Request>) -> R /// Adds a pusher for the sender user. 
/// /// - TODO: Handle `append` -pub(crate) async fn set_pushers_route(body: Ruma<set_pusher::v3::Request>) -> Result<set_pusher::v3::Response> { +pub(crate) async fn set_pushers_route( + State(services): State<crate::State>, body: Ruma<set_pusher::v3::Request>, +) -> Result<set_pusher::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services().pusher.set_pusher(sender_user, &body.action)?; + services.pusher.set_pusher(sender_user, &body.action)?; Ok(set_pusher::v3::Response::default()) } diff --git a/src/api/client/read_marker.rs b/src/api/client/read_marker.rs index cbb36b813..f40f24932 100644 --- a/src/api/client/read_marker.rs +++ b/src/api/client/read_marker.rs @@ -1,5 +1,6 @@ use std::collections::BTreeMap; +use axum::extract::State; use conduit::PduCount; use ruma::{ api::client::{error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt}, @@ -10,7 +11,7 @@ MilliSecondsSinceUnixEpoch, }; -use crate::{services, Error, Result, Ruma}; +use crate::{Error, Result, Ruma}; /// # `POST /_matrix/client/r0/rooms/{roomId}/read_markers` /// @@ -20,7 +21,7 @@ /// - If `read_receipt` is set: Update private marker and public read receipt /// EDU pub(crate) async fn set_read_marker_route( - body: Ruma<set_read_marker::v3::Request>, + State(services): State<crate::State>, body: Ruma<set_read_marker::v3::Request>, ) -> Result<set_read_marker::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -30,7 +31,7 @@ pub(crate) async fn set_read_marker_route( event_id: fully_read.clone(), }, }; - services().account_data.update( + services.account_data.update( Some(&body.room_id), sender_user, RoomAccountDataEventType::FullyRead, @@ -39,14 +40,14 @@ pub(crate) async fn set_read_marker_route( } if body.private_read_receipt.is_some() || body.read_receipt.is_some() { - services() + services .rooms .user .reset_notification_counts(sender_user, &body.room_id)?; } if let Some(event) = 
&body.private_read_receipt { - let count = services() + let count = services .rooms .timeline .get_pdu_count(event)? @@ -60,7 +61,7 @@ pub(crate) async fn set_read_marker_route( }, PduCount::Normal(c) => c, }; - services() + services .rooms .read_receipt .private_read_set(&body.room_id, sender_user, count)?; @@ -82,7 +83,7 @@ pub(crate) async fn set_read_marker_route( let mut receipt_content = BTreeMap::new(); receipt_content.insert(event.to_owned(), receipts); - services().rooms.read_receipt.readreceipt_update( + services.rooms.read_receipt.readreceipt_update( sender_user, &body.room_id, &ruma::events::receipt::ReceiptEvent { @@ -99,7 +100,7 @@ pub(crate) async fn set_read_marker_route( /// /// Sets private read marker and public read receipt EDU. pub(crate) async fn create_receipt_route( - body: Ruma<create_receipt::v3::Request>, + State(services): State<crate::State>, body: Ruma<create_receipt::v3::Request>, ) -> Result<create_receipt::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -107,7 +108,7 @@ pub(crate) async fn create_receipt_route( &body.receipt_type, create_receipt::v3::ReceiptType::Read | create_receipt::v3::ReceiptType::ReadPrivate ) { - services() + services .rooms .user .reset_notification_counts(sender_user, &body.room_id)?; @@ -120,7 +121,7 @@ pub(crate) async fn create_receipt_route( event_id: body.event_id.clone(), }, }; - services().account_data.update( + services.account_data.update( Some(&body.room_id), sender_user, RoomAccountDataEventType::FullyRead, @@ -142,7 +143,7 @@ pub(crate) async fn create_receipt_route( let mut receipt_content = BTreeMap::new(); receipt_content.insert(body.event_id.clone(), receipts); - services().rooms.read_receipt.readreceipt_update( + services.rooms.read_receipt.readreceipt_update( sender_user, &body.room_id, &ruma::events::receipt::ReceiptEvent { @@ -152,7 +153,7 @@ pub(crate) async fn create_receipt_route( )?; }, create_receipt::v3::ReceiptType::ReadPrivate => { 
- let count = services() + let count = services .rooms .timeline .get_pdu_count(&body.event_id)? @@ -166,7 +167,7 @@ pub(crate) async fn create_receipt_route( }, PduCount::Normal(c) => c, }; - services() + services .rooms .read_receipt .private_read_set(&body.room_id, sender_user, count)?; diff --git a/src/api/client/redact.rs b/src/api/client/redact.rs index 308d12e5b..89446754c 100644 --- a/src/api/client/redact.rs +++ b/src/api/client/redact.rs @@ -1,23 +1,26 @@ +use axum::extract::State; use ruma::{ api::client::redact::redact_event, events::{room::redaction::RoomRedactionEventContent, TimelineEventType}, }; use serde_json::value::to_raw_value; -use crate::{service::pdu::PduBuilder, services, Result, Ruma}; +use crate::{service::pdu::PduBuilder, Result, Ruma}; /// # `PUT /_matrix/client/r0/rooms/{roomId}/redact/{eventId}/{txnId}` /// /// Tries to send a redaction event into the room. /// /// - TODO: Handle txn id -pub(crate) async fn redact_event_route(body: Ruma<redact_event::v3::Request>) -> Result<redact_event::v3::Response> { +pub(crate) async fn redact_event_route( + State(services): State<crate::State>, body: Ruma<redact_event::v3::Request>, +) -> Result<redact_event::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let body = body.body; - let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; + let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; - let event_id = services() + let event_id = services .rooms .timeline .build_and_append_pdu( diff --git a/src/api/client/relations.rs b/src/api/client/relations.rs index fb366bdc3..ae6459400 100644 --- a/src/api/client/relations.rs +++ b/src/api/client/relations.rs @@ -1,30 +1,28 @@ +use axum::extract::State; use ruma::api::client::relations::{ get_relating_events, get_relating_events_with_rel_type, get_relating_events_with_rel_type_and_event_type, }; -use crate::{services, Result, Ruma}; +use crate::{Result, Ruma}; /// # `GET 
/_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}` pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route( - body: Ruma<get_relating_events_with_rel_type_and_event_type::v1::Request>, + State(services): State<crate::State>, body: Ruma<get_relating_events_with_rel_type_and_event_type::v1::Request>, ) -> Result<get_relating_events_with_rel_type_and_event_type::v1::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let res = services() - .rooms - .pdu_metadata - .paginate_relations_with_filter( - sender_user, - &body.room_id, - &body.event_id, - &Some(body.event_type.clone()), - &Some(body.rel_type.clone()), - &body.from, - &body.to, - &body.limit, - body.recurse, - body.dir, - )?; + let res = services.rooms.pdu_metadata.paginate_relations_with_filter( + sender_user, + &body.room_id, + &body.event_id, + &Some(body.event_type.clone()), + &Some(body.rel_type.clone()), + &body.from, + &body.to, + &body.limit, + body.recurse, + body.dir, + )?; Ok(get_relating_events_with_rel_type_and_event_type::v1::Response { chunk: res.chunk, @@ -36,25 +34,22 @@ pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route( /// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}` pub(crate) async fn get_relating_events_with_rel_type_route( - body: Ruma<get_relating_events_with_rel_type::v1::Request>, + State(services): State<crate::State>, body: Ruma<get_relating_events_with_rel_type::v1::Request>, ) -> Result<get_relating_events_with_rel_type::v1::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let res = services() - .rooms - .pdu_metadata - .paginate_relations_with_filter( - sender_user, - &body.room_id, - &body.event_id, - &None, - &Some(body.rel_type.clone()), - &body.from, - &body.to, - &body.limit, - body.recurse, - body.dir, - )?; + let res = services.rooms.pdu_metadata.paginate_relations_with_filter( + sender_user, + 
&body.room_id, + &body.event_id, + &None, + &Some(body.rel_type.clone()), + &body.from, + &body.to, + &body.limit, + body.recurse, + body.dir, + )?; Ok(get_relating_events_with_rel_type::v1::Response { chunk: res.chunk, @@ -66,23 +61,20 @@ pub(crate) async fn get_relating_events_with_rel_type_route( /// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}` pub(crate) async fn get_relating_events_route( - body: Ruma<get_relating_events::v1::Request>, + State(services): State<crate::State>, body: Ruma<get_relating_events::v1::Request>, ) -> Result<get_relating_events::v1::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services() - .rooms - .pdu_metadata - .paginate_relations_with_filter( - sender_user, - &body.room_id, - &body.event_id, - &None, - &None, - &body.from, - &body.to, - &body.limit, - body.recurse, - body.dir, - ) + services.rooms.pdu_metadata.paginate_relations_with_filter( + sender_user, + &body.room_id, + &body.event_id, + &None, + &None, + &body.from, + &body.to, + &body.limit, + body.recurse, + body.dir, + ) } diff --git a/src/api/client/report.rs b/src/api/client/report.rs index dae6086f4..a16df4448 100644 --- a/src/api/client/report.rs +++ b/src/api/client/report.rs @@ -1,5 +1,6 @@ use std::time::Duration; +use axum::extract::State; use rand::Rng; use ruma::{ api::client::{error::ErrorKind, room::report_content}, @@ -9,13 +10,18 @@ use tokio::time::sleep; use tracing::info; -use crate::{debug_info, service::pdu::PduEvent, services, utils::HtmlEscape, Error, Result, Ruma}; +use crate::{ + debug_info, + service::{pdu::PduEvent, Services}, + utils::HtmlEscape, + Error, Result, Ruma, +}; /// # `POST /_matrix/client/v3/rooms/{roomId}/report/{eventId}` /// /// Reports an inappropriate event to homeserver admins pub(crate) async fn report_event_route( - body: Ruma<report_content::v3::Request>, + State(services): State<crate::State>, body: Ruma<report_content::v3::Request>, ) -> 
Result<report_content::v3::Response> { // user authentication let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -26,18 +32,26 @@ pub(crate) async fn report_event_route( ); // check if we know about the reported event ID or if it's invalid - let Some(pdu) = services().rooms.timeline.get_pdu(&body.event_id)? else { + let Some(pdu) = services.rooms.timeline.get_pdu(&body.event_id)? else { return Err(Error::BadRequest( ErrorKind::NotFound, "Event ID is not known to us or Event ID is invalid", )); }; - is_report_valid(&pdu.event_id, &body.room_id, sender_user, &body.reason, body.score, &pdu)?; + is_report_valid( + services, + &pdu.event_id, + &body.room_id, + sender_user, + &body.reason, + body.score, + &pdu, + )?; // send admin room message that we received the report with an @room ping for // urgency - services() + services .admin .send_message(message::RoomMessageEventContent::text_html( format!( @@ -79,8 +93,8 @@ pub(crate) async fn report_event_route( /// check if score is in valid range /// check if report reasoning is less than or equal to 750 characters fn is_report_valid( - event_id: &EventId, room_id: &RoomId, sender_user: &UserId, reason: &Option<String>, score: Option<ruma::Int>, - pdu: &std::sync::Arc<PduEvent>, + services: &Services, event_id: &EventId, room_id: &RoomId, sender_user: &UserId, reason: &Option<String>, + score: Option<ruma::Int>, pdu: &std::sync::Arc<PduEvent>, ) -> Result<bool> { debug_info!("Checking if report from user {sender_user} for event {event_id} in room {room_id} is valid"); @@ -91,7 +105,7 @@ fn is_report_valid( )); } - if !services() + if !services .rooms .state_cache .room_members(&pdu.room_id) diff --git a/src/api/client/room.rs b/src/api/client/room.rs index adf58b04d..c4d828220 100644 --- a/src/api/client/room.rs +++ b/src/api/client/room.rs @@ -1,5 +1,6 @@ use std::{cmp::max, collections::BTreeMap}; +use axum::extract::State; use conduit::{debug_info, debug_warn}; use ruma::{ api::client::{ @@ 
-30,8 +31,8 @@ use super::invite_helper; use crate::{ - service::{appservice::RegistrationInfo, pdu::PduBuilder}, - services, Error, Result, Ruma, + service::{appservice::RegistrationInfo, pdu::PduBuilder, Services}, + Error, Result, Ruma, }; /// Recommended transferable state events list from the spec @@ -63,44 +64,46 @@ /// - Send events listed in initial state /// - Send events implied by `name` and `topic` /// - Send invite events -pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> Result<create_room::v3::Response> { +pub(crate) async fn create_room_route( + State(services): State<crate::State>, body: Ruma<create_room::v3::Request>, +) -> Result<create_room::v3::Response> { use create_room::v3::RoomPreset; let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services().globals.allow_room_creation() + if !services.globals.allow_room_creation() && body.appservice_info.is_none() - && !services().users.is_admin(sender_user)? + && !services.users.is_admin(sender_user)? { return Err(Error::BadRequest(ErrorKind::forbidden(), "Room creation has been disabled.")); } let room_id: OwnedRoomId = if let Some(custom_room_id) = &body.room_id { - custom_room_id_check(custom_room_id)? + custom_room_id_check(services, custom_room_id)? 
} else { - RoomId::new(&services().globals.config.server_name) + RoomId::new(&services.globals.config.server_name) }; // check if room ID doesn't already exist instead of erroring on auth check - if services().rooms.short.get_shortroomid(&room_id)?.is_some() { + if services.rooms.short.get_shortroomid(&room_id)?.is_some() { return Err(Error::BadRequest( ErrorKind::RoomInUse, "Room with that custom room ID already exists", )); } - let _short_id = services().rooms.short.get_or_create_shortroomid(&room_id)?; - let state_lock = services().rooms.state.mutex.lock(&room_id).await; + let _short_id = services.rooms.short.get_or_create_shortroomid(&room_id)?; + let state_lock = services.rooms.state.mutex.lock(&room_id).await; let alias: Option<OwnedRoomAliasId> = if let Some(alias) = &body.room_alias_name { - Some(room_alias_check(alias, &body.appservice_info).await?) + Some(room_alias_check(services, alias, &body.appservice_info).await?) } else { None }; let room_version = match body.room_version.clone() { Some(room_version) => { - if services() + if services .globals .supported_room_versions() .contains(&room_version) @@ -113,7 +116,7 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R )); } }, - None => services().globals.default_room_version(), + None => services.globals.default_room_version(), }; let content = match &body.creation_content { @@ -184,7 +187,7 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R }; // 1. The room create event - services() + services .rooms .timeline .build_and_append_pdu( @@ -202,7 +205,7 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R .await?; // 2. 
Let the room creator join - services() + services .rooms .timeline .build_and_append_pdu( @@ -210,11 +213,11 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R event_type: TimelineEventType::RoomMember, content: to_raw_value(&RoomMemberEventContent { membership: MembershipState::Join, - displayname: services().users.displayname(sender_user)?, - avatar_url: services().users.avatar_url(sender_user)?, + displayname: services.users.displayname(sender_user)?, + avatar_url: services.users.avatar_url(sender_user)?, is_direct: Some(body.is_direct), third_party_invite: None, - blurhash: services().users.blurhash(sender_user)?, + blurhash: services.users.blurhash(sender_user)?, reason: None, join_authorized_via_users_server: None, }) @@ -249,7 +252,7 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R let power_levels_content = default_power_levels_content(&body.power_level_content_override, &body.visibility, users)?; - services() + services .rooms .timeline .build_and_append_pdu( @@ -268,7 +271,7 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R // 4. Canonical room alias if let Some(room_alias_id) = &alias { - services() + services .rooms .timeline .build_and_append_pdu( @@ -293,7 +296,7 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R // 5. 
Events set by preset // 5.1 Join Rules - services() + services .rooms .timeline .build_and_append_pdu( @@ -316,7 +319,7 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R .await?; // 5.2 History Visibility - services() + services .rooms .timeline .build_and_append_pdu( @@ -335,7 +338,7 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R .await?; // 5.3 Guest Access - services() + services .rooms .timeline .build_and_append_pdu( @@ -378,11 +381,11 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R pdu_builder.state_key.get_or_insert_with(String::new); // Silently skip encryption events if they are not allowed - if pdu_builder.event_type == TimelineEventType::RoomEncryption && !services().globals.allow_encryption() { + if pdu_builder.event_type == TimelineEventType::RoomEncryption && !services.globals.allow_encryption() { continue; } - services() + services .rooms .timeline .build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock) @@ -391,7 +394,7 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R // 7. Events implied by name and topic if let Some(name) = &body.name { - services() + services .rooms .timeline .build_and_append_pdu( @@ -411,7 +414,7 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R } if let Some(topic) = &body.topic { - services() + services .rooms .timeline .build_and_append_pdu( @@ -435,21 +438,21 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R // 8. 
Events implied by invite (and TODO: invite_3pid) drop(state_lock); for user_id in &body.invite { - if let Err(e) = invite_helper(sender_user, user_id, &room_id, None, body.is_direct).await { + if let Err(e) = invite_helper(services, sender_user, user_id, &room_id, None, body.is_direct).await { warn!(%e, "Failed to send invite"); } } // Homeserver specific stuff if let Some(alias) = alias { - services() + services .rooms .alias .set_alias(&alias, &room_id, sender_user)?; } if body.visibility == room::Visibility::Public { - services().rooms.directory.set_public(&room_id)?; + services.rooms.directory.set_public(&room_id)?; } info!("{sender_user} created a room with room ID {room_id}"); @@ -464,11 +467,11 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R /// - You have to currently be joined to the room (TODO: Respect history /// visibility) pub(crate) async fn get_room_event_route( - body: Ruma<get_room_event::v3::Request>, + State(services): State<crate::State>, body: Ruma<get_room_event::v3::Request>, ) -> Result<get_room_event::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services() + let event = services .rooms .timeline .get_pdu(&body.event_id)? @@ -477,7 +480,7 @@ pub(crate) async fn get_room_event_route( Error::BadRequest(ErrorKind::NotFound, "Event not found.") })?; - if !services() + if !services .rooms .state_accessor .user_can_see_event(sender_user, &event.room_id, &body.event_id)? 
@@ -502,10 +505,12 @@ pub(crate) async fn get_room_event_route( /// /// - Only users joined to the room are allowed to call this, or if /// `history_visibility` is world readable in the room -pub(crate) async fn get_room_aliases_route(body: Ruma<aliases::v3::Request>) -> Result<aliases::v3::Response> { +pub(crate) async fn get_room_aliases_route( + State(services): State<crate::State>, body: Ruma<aliases::v3::Request>, +) -> Result<aliases::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services() + if !services .rooms .state_accessor .user_can_see_state_events(sender_user, &body.room_id)? @@ -517,7 +522,7 @@ pub(crate) async fn get_room_aliases_route(body: Ruma<aliases::v3::Request>) -> } Ok(aliases::v3::Response { - aliases: services() + aliases: services .rooms .alias .local_aliases_for_room(&body.room_id) @@ -536,10 +541,12 @@ pub(crate) async fn get_room_aliases_route(body: Ruma<aliases::v3::Request>) -> /// - Transfers some state events /// - Moves local aliases /// - Modifies old room power levels to prevent users from speaking -pub(crate) async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) -> Result<upgrade_room::v3::Response> { +pub(crate) async fn upgrade_room_route( + State(services): State<crate::State>, body: Ruma<upgrade_room::v3::Request>, +) -> Result<upgrade_room::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services() + if !services .globals .supported_room_versions() .contains(&body.new_version) @@ -551,19 +558,19 @@ pub(crate) async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) -> } // Create a replacement room - let replacement_room = RoomId::new(services().globals.server_name()); + let replacement_room = RoomId::new(services.globals.server_name()); - let _short_id = services() + let _short_id = services .rooms .short .get_or_create_shortroomid(&replacement_room)?; - let state_lock = 
services().rooms.state.mutex.lock(&body.room_id).await; + let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; // Send a m.room.tombstone event to the old room to indicate that it is not // intended to be used any further Fail if the sender does not have the required // permissions - let tombstone_event_id = services() + let tombstone_event_id = services .rooms .timeline .build_and_append_pdu( @@ -586,11 +593,11 @@ pub(crate) async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) -> // Change lock to replacement room drop(state_lock); - let state_lock = services().rooms.state.mutex.lock(&replacement_room).await; + let state_lock = services.rooms.state.mutex.lock(&replacement_room).await; // Get the old room creation event let mut create_event_content = serde_json::from_str::<CanonicalJsonObject>( - services() + services .rooms .state_accessor .room_state_get(&body.room_id, &StateEventType::RoomCreate, "")? @@ -658,7 +665,7 @@ pub(crate) async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) -> return Err(Error::BadRequest(ErrorKind::BadJson, "Error forming creation event")); } - services() + services .rooms .timeline .build_and_append_pdu( @@ -676,7 +683,7 @@ pub(crate) async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) -> .await?; // Join the new room - services() + services .rooms .timeline .build_and_append_pdu( @@ -684,11 +691,11 @@ pub(crate) async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) -> event_type: TimelineEventType::RoomMember, content: to_raw_value(&RoomMemberEventContent { membership: MembershipState::Join, - displayname: services().users.displayname(sender_user)?, - avatar_url: services().users.avatar_url(sender_user)?, + displayname: services.users.displayname(sender_user)?, + avatar_url: services.users.avatar_url(sender_user)?, is_direct: None, third_party_invite: None, - blurhash: services().users.blurhash(sender_user)?, + blurhash: services.users.blurhash(sender_user)?, 
reason: None, join_authorized_via_users_server: None, }) @@ -705,7 +712,7 @@ pub(crate) async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) -> // Replicate transferable state events to the new room for event_type in TRANSFERABLE_STATE_EVENTS { - let event_content = match services() + let event_content = match services .rooms .state_accessor .room_state_get(&body.room_id, event_type, "")? @@ -714,7 +721,7 @@ pub(crate) async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) -> None => continue, // Skipping missing events. }; - services() + services .rooms .timeline .build_and_append_pdu( @@ -733,13 +740,13 @@ pub(crate) async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) -> } // Moves any local aliases to the new room - for alias in services() + for alias in services .rooms .alias .local_aliases_for_room(&body.room_id) .filter_map(Result::ok) { - services() + services .rooms .alias .set_alias(&alias, &replacement_room, sender_user)?; @@ -747,7 +754,7 @@ pub(crate) async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) -> // Get the old room power levels let mut power_levels_event_content: RoomPowerLevelsEventContent = serde_json::from_str( - services() + services .rooms .state_accessor .room_state_get(&body.room_id, &StateEventType::RoomPowerLevels, "")? 
@@ -772,7 +779,7 @@ pub(crate) async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) -> // Modify the power levels in the old room to prevent sending of events and // inviting new users - services() + services .rooms .timeline .build_and_append_pdu( @@ -841,7 +848,7 @@ fn default_power_levels_content( /// if a room is being created with a room alias, run our checks async fn room_alias_check( - room_alias_name: &str, appservice_info: &Option<RegistrationInfo>, + services: &Services, room_alias_name: &str, appservice_info: &Option<RegistrationInfo>, ) -> Result<OwnedRoomAliasId> { // Basic checks on the room alias validity if room_alias_name.contains(':') { @@ -858,7 +865,7 @@ async fn room_alias_check( } // check if room alias is forbidden - if services() + if services .globals .forbidden_alias_names() .is_match(room_alias_name) @@ -866,13 +873,13 @@ async fn room_alias_check( return Err(Error::BadRequest(ErrorKind::Unknown, "Room alias name is forbidden.")); } - let full_room_alias = RoomAliasId::parse(format!("#{}:{}", room_alias_name, services().globals.config.server_name)) + let full_room_alias = RoomAliasId::parse(format!("#{}:{}", room_alias_name, services.globals.config.server_name)) .map_err(|e| { - info!("Failed to parse room alias {room_alias_name}: {e}"); - Error::BadRequest(ErrorKind::InvalidParam, "Invalid room alias specified.") - })?; + info!("Failed to parse room alias {room_alias_name}: {e}"); + Error::BadRequest(ErrorKind::InvalidParam, "Invalid room alias specified.") + })?; - if services() + if services .rooms .alias .resolve_local_alias(&full_room_alias)? 
@@ -885,7 +892,7 @@ async fn room_alias_check( if !info.aliases.is_match(full_room_alias.as_str()) { return Err(Error::BadRequest(ErrorKind::Exclusive, "Room alias is not in namespace.")); } - } else if services() + } else if services .appservice .is_exclusive_alias(&full_room_alias) .await @@ -899,9 +906,9 @@ async fn room_alias_check( } /// if a room is being created with a custom room ID, run our checks against it -fn custom_room_id_check(custom_room_id: &str) -> Result<OwnedRoomId> { +fn custom_room_id_check(services: &Services, custom_room_id: &str) -> Result<OwnedRoomId> { // apply forbidden room alias checks to custom room IDs too - if services() + if services .globals .forbidden_alias_names() .is_match(custom_room_id) @@ -922,7 +929,7 @@ fn custom_room_id_check(custom_room_id: &str) -> Result<OwnedRoomId> { )); } - let full_room_id = format!("!{}:{}", custom_room_id, services().globals.config.server_name); + let full_room_id = format!("!{}:{}", custom_room_id, services.globals.config.server_name); debug_info!("Full custom room ID: {full_room_id}"); diff --git a/src/api/client/search.rs b/src/api/client/search.rs index d2a1668f6..b143bd2c7 100644 --- a/src/api/client/search.rs +++ b/src/api/client/search.rs @@ -1,5 +1,6 @@ use std::collections::BTreeMap; +use axum::extract::State; use ruma::{ api::client::{ error::ErrorKind, @@ -14,7 +15,7 @@ }; use tracing::debug; -use crate::{services, Error, Result, Ruma}; +use crate::{Error, Result, Ruma}; /// # `POST /_matrix/client/r0/search` /// @@ -22,7 +23,9 @@ /// /// - Only works if the user is currently joined to the room (TODO: Respect /// history visibility) -pub(crate) async fn search_events_route(body: Ruma<search_events::v3::Request>) -> Result<search_events::v3::Response> { +pub(crate) async fn search_events_route( + State(services): State<crate::State>, body: Ruma<search_events::v3::Request>, +) -> Result<search_events::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is 
authenticated"); let search_criteria = body.search_categories.room_events.as_ref().unwrap(); @@ -30,7 +33,7 @@ pub(crate) async fn search_events_route(body: Ruma<search_events::v3::Request>) let include_state = &search_criteria.include_state; let room_ids = filter.rooms.clone().unwrap_or_else(|| { - services() + services .rooms .state_cache .rooms_joined(sender_user) @@ -50,11 +53,7 @@ pub(crate) async fn search_events_route(body: Ruma<search_events::v3::Request>) if include_state.is_some_and(|include_state| include_state) { for room_id in &room_ids { - if !services() - .rooms - .state_cache - .is_joined(sender_user, room_id)? - { + if !services.rooms.state_cache.is_joined(sender_user, room_id)? { return Err(Error::BadRequest( ErrorKind::forbidden(), "You don't have permission to view this room.", @@ -62,12 +61,12 @@ pub(crate) async fn search_events_route(body: Ruma<search_events::v3::Request>) } // check if sender_user can see state events - if services() + if services .rooms .state_accessor .user_can_see_state_events(sender_user, room_id)? { - let room_state = services() + let room_state = services .rooms .state_accessor .room_state_full(room_id) @@ -91,18 +90,14 @@ pub(crate) async fn search_events_route(body: Ruma<search_events::v3::Request>) let mut searches = Vec::new(); for room_id in &room_ids { - if !services() - .rooms - .state_cache - .is_joined(sender_user, room_id)? - { + if !services.rooms.state_cache.is_joined(sender_user, room_id)? { return Err(Error::BadRequest( ErrorKind::forbidden(), "You don't have permission to view this room.", )); } - if let Some(search) = services() + if let Some(search) = services .rooms .search .search_pdus(room_id, &search_criteria.search_term)? @@ -135,14 +130,14 @@ pub(crate) async fn search_events_route(body: Ruma<search_events::v3::Request>) .iter() .skip(skip) .filter_map(|result| { - services() + services .rooms .timeline .get_pdu_from_id(result) .ok()? 
.filter(|pdu| { !pdu.is_redacted() - && services() + && services .rooms .state_accessor .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) diff --git a/src/api/client/session.rs b/src/api/client/session.rs index 6df46549d..32f3ed294 100644 --- a/src/api/client/session.rs +++ b/src/api/client/session.rs @@ -1,3 +1,4 @@ +use axum::extract::State; use ruma::{ api::client::{ error::ErrorKind, @@ -20,7 +21,7 @@ use tracing::{debug, info, warn}; use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; -use crate::{services, utils, utils::hash, Error, Result, Ruma}; +use crate::{utils, utils::hash, Error, Result, Ruma}; #[derive(Debug, Deserialize)] struct Claims { @@ -55,7 +56,9 @@ pub(crate) async fn get_login_types_route( /// Note: You can use [`GET /// /_matrix/client/r0/login`](fn.get_supported_versions_route.html) to see /// supported login types. -pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Response> { +pub(crate) async fn login_route( + State(services): State<crate::State>, body: Ruma<login::v3::Request>, +) -> Result<login::v3::Response> { // Validate login method // TODO: Other login methods let user_id = match &body.login_info { @@ -68,7 +71,7 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login: }) => { debug!("Got password login type"); let user_id = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier { - UserId::parse_with_server_name(user_id.to_lowercase(), services().globals.server_name()) + UserId::parse_with_server_name(user_id.to_lowercase(), services.globals.server_name()) } else if let Some(user) = user { UserId::parse(user) } else { @@ -77,7 +80,7 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login: } .map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; - let hash = services() + let hash = services .users .password_hash(&user_id)? 
.ok_or(Error::BadRequest(ErrorKind::forbidden(), "Wrong username or password."))?; @@ -96,7 +99,7 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login: token, }) => { debug!("Got token login type"); - if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() { + if let Some(jwt_decoding_key) = services.globals.jwt_decoding_key() { let token = jsonwebtoken::decode::<Claims>(token, jwt_decoding_key, &jsonwebtoken::Validation::default()) .map_err(|e| { @@ -106,7 +109,7 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login: let username = token.claims.sub.to_lowercase(); - UserId::parse_with_server_name(username, services().globals.server_name()).map_err(|e| { + UserId::parse_with_server_name(username, services.globals.server_name()).map_err(|e| { warn!("Failed to parse username from user logging in: {e}"); Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") })? @@ -124,7 +127,7 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login: }) => { debug!("Got appservice login type"); let user_id = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier { - UserId::parse_with_server_name(user_id.to_lowercase(), services().globals.server_name()) + UserId::parse_with_server_name(user_id.to_lowercase(), services.globals.server_name()) } else if let Some(user) = user { UserId::parse(user) } else { @@ -164,22 +167,22 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login: // Determine if device_id was provided and exists in the db for this user let device_exists = body.device_id.as_ref().map_or(false, |device_id| { - services() + services .users .all_device_ids(&user_id) .any(|x| x.as_ref().map_or(false, |v| v == device_id)) }); if device_exists { - services().users.set_token(&user_id, &device_id, &token)?; + services.users.set_token(&user_id, &device_id, &token)?; } else { - services() + services .users 
.create_device(&user_id, &device_id, &token, body.initial_device_display_name.clone())?; } // send client well-known if specified so the client knows to reconfigure itself - let client_discovery_info: Option<DiscoveryInfo> = services() + let client_discovery_info: Option<DiscoveryInfo> = services .globals .well_known_client() .as_ref() @@ -197,7 +200,7 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login: device_id, well_known: client_discovery_info, expires_in: None, - home_server: Some(services().globals.server_name().to_owned()), + home_server: Some(services.globals.server_name().to_owned()), refresh_token: None, }) } @@ -211,14 +214,16 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login: /// last seen ts) /// - Forgets to-device events /// - Triggers device list updates -pub(crate) async fn logout_route(body: Ruma<logout::v3::Request>) -> Result<logout::v3::Response> { +pub(crate) async fn logout_route( + State(services): State<crate::State>, body: Ruma<logout::v3::Request>, +) -> Result<logout::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - services().users.remove_device(sender_user, sender_device)?; + services.users.remove_device(sender_user, sender_device)?; // send device list update for user after logout - services().users.mark_device_key_update(sender_user)?; + services.users.mark_device_key_update(sender_user)?; Ok(logout::v3::Response::new()) } @@ -236,15 +241,17 @@ pub(crate) async fn logout_route(body: Ruma<logout::v3::Request>) -> Result<logo /// Note: This is equivalent to calling [`GET /// /_matrix/client/r0/logout`](fn.logout_route.html) from each device of this /// user. 
-pub(crate) async fn logout_all_route(body: Ruma<logout_all::v3::Request>) -> Result<logout_all::v3::Response> { +pub(crate) async fn logout_all_route( + State(services): State<crate::State>, body: Ruma<logout_all::v3::Request>, +) -> Result<logout_all::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - for device_id in services().users.all_device_ids(sender_user).flatten() { - services().users.remove_device(sender_user, &device_id)?; + for device_id in services.users.all_device_ids(sender_user).flatten() { + services.users.remove_device(sender_user, &device_id)?; } // send device list update for user after logout - services().users.mark_device_key_update(sender_user)?; + services.users.mark_device_key_update(sender_user)?; Ok(logout_all::v3::Response::new()) } diff --git a/src/api/client/space.rs b/src/api/client/space.rs index 0cf1b1073..a3031a3af 100644 --- a/src/api/client/space.rs +++ b/src/api/client/space.rs @@ -1,17 +1,20 @@ use std::str::FromStr; +use axum::extract::State; use ruma::{ api::client::{error::ErrorKind, space::get_hierarchy}, UInt, }; -use crate::{service::rooms::spaces::PaginationToken, services, Error, Result, Ruma}; +use crate::{service::rooms::spaces::PaginationToken, Error, Result, Ruma}; /// # `GET /_matrix/client/v1/rooms/{room_id}/hierarchy` /// /// Paginates over the space tree in a depth-first manner to locate child rooms /// of a given space. 
-pub(crate) async fn get_hierarchy_route(body: Ruma<get_hierarchy::v1::Request>) -> Result<get_hierarchy::v1::Response> { +pub(crate) async fn get_hierarchy_route( + State(services): State<crate::State>, body: Ruma<get_hierarchy::v1::Request>, +) -> Result<get_hierarchy::v1::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let limit = body @@ -39,7 +42,7 @@ pub(crate) async fn get_hierarchy_route(body: Ruma<get_hierarchy::v1::Request>) } } - services() + services .rooms .spaces .get_client_hierarchy( diff --git a/src/api/client/state.rs b/src/api/client/state.rs index 25b77fe3a..51217d001 100644 --- a/src/api/client/state.rs +++ b/src/api/client/state.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use axum::extract::State; use conduit::{debug_info, error}; use ruma::{ api::client::{ @@ -19,8 +20,8 @@ }; use crate::{ - service::{pdu::PduBuilder, server_is_ours}, - services, Error, Result, Ruma, RumaResponse, + service::{pdu::PduBuilder, server_is_ours, Services}, + Error, Result, Ruma, RumaResponse, }; /// # `PUT /_matrix/client/*/rooms/{roomId}/state/{eventType}/{stateKey}` @@ -32,12 +33,13 @@ /// allowed /// - If event is new `canonical_alias`: Rejects if alias is incorrect pub(crate) async fn send_state_event_for_key_route( - body: Ruma<send_state_event::v3::Request>, + State(services): State<crate::State>, body: Ruma<send_state_event::v3::Request>, ) -> Result<send_state_event::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); Ok(send_state_event::v3::Response { event_id: send_state_event_for_key_helper( + services, sender_user, &body.room_id, &body.event_type, @@ -58,9 +60,11 @@ pub(crate) async fn send_state_event_for_key_route( /// allowed /// - If event is new `canonical_alias`: Rejects if alias is incorrect pub(crate) async fn send_state_event_for_empty_key_route( - body: Ruma<send_state_event::v3::Request>, + State(services): State<crate::State>, body: 
Ruma<send_state_event::v3::Request>, ) -> Result<RumaResponse<send_state_event::v3::Response>> { - send_state_event_for_key_route(body).await.map(RumaResponse) + send_state_event_for_key_route(State(services), body) + .await + .map(RumaResponse) } /// # `GET /_matrix/client/v3/rooms/{roomid}/state` @@ -70,11 +74,11 @@ pub(crate) async fn send_state_event_for_empty_key_route( /// - If not joined: Only works if current room history visibility is world /// readable pub(crate) async fn get_state_events_route( - body: Ruma<get_state_events::v3::Request>, + State(services): State<crate::State>, body: Ruma<get_state_events::v3::Request>, ) -> Result<get_state_events::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services() + if !services .rooms .state_accessor .user_can_see_state_events(sender_user, &body.room_id)? @@ -86,7 +90,7 @@ pub(crate) async fn get_state_events_route( } Ok(get_state_events::v3::Response { - room_state: services() + room_state: services .rooms .state_accessor .room_state_full(&body.room_id) @@ -106,11 +110,11 @@ pub(crate) async fn get_state_events_route( /// - If not joined: Only works if current room history visibility is world /// readable pub(crate) async fn get_state_events_for_key_route( - body: Ruma<get_state_events_for_key::v3::Request>, + State(services): State<crate::State>, body: Ruma<get_state_events_for_key::v3::Request>, ) -> Result<get_state_events_for_key::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services() + if !services .rooms .state_accessor .user_can_see_state_events(sender_user, &body.room_id)? @@ -121,7 +125,7 @@ pub(crate) async fn get_state_events_for_key_route( )); } - let event = services() + let event = services .rooms .state_accessor .room_state_get(&body.room_id, &body.event_type, &body.state_key)? 
@@ -161,17 +165,20 @@ pub(crate) async fn get_state_events_for_key_route( /// - If not joined: Only works if current room history visibility is world /// readable pub(crate) async fn get_state_events_for_empty_key_route( - body: Ruma<get_state_events_for_key::v3::Request>, + State(services): State<crate::State>, body: Ruma<get_state_events_for_key::v3::Request>, ) -> Result<RumaResponse<get_state_events_for_key::v3::Response>> { - get_state_events_for_key_route(body).await.map(RumaResponse) + get_state_events_for_key_route(State(services), body) + .await + .map(RumaResponse) } async fn send_state_event_for_key_helper( - sender: &UserId, room_id: &RoomId, event_type: &StateEventType, json: &Raw<AnyStateEventContent>, state_key: String, + services: &Services, sender: &UserId, room_id: &RoomId, event_type: &StateEventType, + json: &Raw<AnyStateEventContent>, state_key: String, ) -> Result<Arc<EventId>> { - allowed_to_send_state_event(room_id, event_type, json).await?; - let state_lock = services().rooms.state.mutex.lock(room_id).await; - let event_id = services() + allowed_to_send_state_event(services, room_id, event_type, json).await?; + let state_lock = services.rooms.state.mutex.lock(room_id).await; + let event_id = services .rooms .timeline .build_and_append_pdu( @@ -192,12 +199,12 @@ async fn send_state_event_for_key_helper( } async fn allowed_to_send_state_event( - room_id: &RoomId, event_type: &StateEventType, json: &Raw<AnyStateEventContent>, + services: &Services, room_id: &RoomId, event_type: &StateEventType, json: &Raw<AnyStateEventContent>, ) -> Result<()> { match event_type { // Forbid m.room.encryption if encryption is disabled StateEventType::RoomEncryption => { - if !services().globals.allow_encryption() { + if !services.globals.allow_encryption() { return Err(Error::BadRequest(ErrorKind::forbidden(), "Encryption has been disabled")); } }, @@ -244,7 +251,7 @@ async fn allowed_to_send_state_event( for alias in aliases { if 
!server_is_ours(alias.server_name()) - || services() + || services .rooms .alias .resolve_local_alias(&alias)? diff --git a/src/api/client/sync.rs b/src/api/client/sync.rs index e425616b0..5739052c3 100644 --- a/src/api/client/sync.rs +++ b/src/api/client/sync.rs @@ -5,6 +5,7 @@ time::Duration, }; +use axum::extract::State; use conduit::{ error, utils::math::{ruma_from_u64, ruma_from_usize, usize_from_ruma, usize_from_u64_truncated}, @@ -17,7 +18,7 @@ self, v3::{ Ephemeral, Filter, GlobalAccountData, InviteState, InvitedRoom, JoinedRoom, LeftRoom, Presence, - RoomAccountData, RoomSummary, Rooms, State, Timeline, ToDevice, + RoomAccountData, RoomSummary, Rooms, State as RoomState, Timeline, ToDevice, }, v4::SlidingOp, DeviceLists, UnreadNotificationsCount, @@ -34,7 +35,10 @@ }; use tracing::{Instrument as _, Span}; -use crate::{service::pdu::EventHash, services, utils, Error, PduEvent, Result, Ruma, RumaResponse}; +use crate::{ + service::{pdu::EventHash, Services}, + utils, Error, PduEvent, Result, Ruma, RumaResponse, +}; /// # `GET /_matrix/client/r0/sync` /// @@ -72,23 +76,23 @@ /// - If the user left after `since`: `prev_batch` token, empty state (TODO: /// subset of the state at the point of the leave) pub(crate) async fn sync_events_route( - body: Ruma<sync_events::v3::Request>, + State(services): State<crate::State>, body: Ruma<sync_events::v3::Request>, ) -> Result<sync_events::v3::Response, RumaResponse<UiaaResponse>> { let sender_user = body.sender_user.expect("user is authenticated"); let sender_device = body.sender_device.expect("user is authenticated"); let body = body.body; // Presence update - if services().globals.allow_local_presence() { - services() + if services.globals.allow_local_presence() { + services .presence .ping_presence(&sender_user, &body.set_presence)?; } // Setup watchers, so if there's no response, we can wait for them - let watcher = services().globals.watch(&sender_user, &sender_device); + let watcher = 
services.globals.watch(&sender_user, &sender_device); - let next_batch = services().globals.current_count()?; + let next_batch = services.globals.current_count()?; let next_batchcount = PduCount::Normal(next_batch); let next_batch_string = next_batch.to_string(); @@ -96,7 +100,7 @@ pub(crate) async fn sync_events_route( let filter = match body.filter { None => FilterDefinition::default(), Some(Filter::FilterDefinition(filter)) => filter, - Some(Filter::FilterId(filter_id)) => services() + Some(Filter::FilterId(filter_id)) => services .users .get_filter(&sender_user, &filter_id)? .unwrap_or_default(), @@ -126,28 +130,29 @@ pub(crate) async fn sync_events_route( // Look for device list updates of this account device_list_updates.extend( - services() + services .users .keys_changed(sender_user.as_ref(), since, None) .filter_map(Result::ok), ); - if services().globals.allow_local_presence() { - process_presence_updates(&mut presence_updates, since, &sender_user).await?; + if services.globals.allow_local_presence() { + process_presence_updates(services, &mut presence_updates, since, &sender_user).await?; } - let all_joined_rooms = services() + let all_joined_rooms = services .rooms .state_cache .rooms_joined(&sender_user) .collect::<Vec<_>>(); // Coalesce database writes for the remainder of this scope. 
- let _cork = services().db.cork_and_flush(); + let _cork = services.db.cork_and_flush(); for room_id in all_joined_rooms { let room_id = room_id?; if let Ok(joined_room) = load_joined_room( + services, &sender_user, &sender_device, &room_id, @@ -170,13 +175,14 @@ pub(crate) async fn sync_events_route( } let mut left_rooms = BTreeMap::new(); - let all_left_rooms: Vec<_> = services() + let all_left_rooms: Vec<_> = services .rooms .state_cache .rooms_left(&sender_user) .collect(); for result in all_left_rooms { handle_left_room( + services, since, &result?.0, &sender_user, @@ -190,7 +196,7 @@ pub(crate) async fn sync_events_route( } let mut invited_rooms = BTreeMap::new(); - let all_invited_rooms: Vec<_> = services() + let all_invited_rooms: Vec<_> = services .rooms .state_cache .rooms_invited(&sender_user) @@ -199,10 +205,10 @@ pub(crate) async fn sync_events_route( let (room_id, invite_state_events) = result?; // Get and drop the lock to wait for remaining operations to finish - let insert_lock = services().rooms.timeline.mutex_insert.lock(&room_id).await; + let insert_lock = services.rooms.timeline.mutex_insert.lock(&room_id).await; drop(insert_lock); - let invite_count = services() + let invite_count = services .rooms .state_cache .get_invite_count(&room_id, &sender_user)?; @@ -223,14 +229,14 @@ pub(crate) async fn sync_events_route( } for user_id in left_encrypted_users { - let dont_share_encrypted_room = services() + let dont_share_encrypted_room = services .rooms .user .get_shared_rooms(vec![sender_user.clone(), user_id.clone()])? 
.filter_map(Result::ok) .filter_map(|other_room_id| { Some( - services() + services .rooms .state_accessor .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") @@ -247,7 +253,7 @@ pub(crate) async fn sync_events_route( } // Remove all to-device events the device received *last time* - services() + services .users .remove_to_device_events(&sender_user, &sender_device, since)?; @@ -266,7 +272,7 @@ pub(crate) async fn sync_events_route( .collect(), }, account_data: GlobalAccountData { - events: services() + events: services .account_data .changes_since(None, &sender_user, since)? .into_iter() @@ -281,11 +287,11 @@ pub(crate) async fn sync_events_route( changed: device_list_updates.into_iter().collect(), left: device_list_left.into_iter().collect(), }, - device_one_time_keys_count: services() + device_one_time_keys_count: services .users .count_one_time_keys(&sender_user, &sender_device)?, to_device: ToDevice { - events: services() + events: services .users .get_to_device_events(&sender_user, &sender_device)?, }, @@ -311,16 +317,18 @@ pub(crate) async fn sync_events_route( Ok(response) } +#[allow(clippy::too_many_arguments)] #[tracing::instrument(skip_all, fields(user_id = %sender_user, room_id = %room_id), name = "left_room")] async fn handle_left_room( - since: u64, room_id: &RoomId, sender_user: &UserId, left_rooms: &mut BTreeMap<ruma::OwnedRoomId, LeftRoom>, - next_batch_string: &str, full_state: bool, lazy_load_enabled: bool, + services: &Services, since: u64, room_id: &RoomId, sender_user: &UserId, + left_rooms: &mut BTreeMap<ruma::OwnedRoomId, LeftRoom>, next_batch_string: &str, full_state: bool, + lazy_load_enabled: bool, ) -> Result<()> { // Get and drop the lock to wait for remaining operations to finish - let insert_lock = services().rooms.timeline.mutex_insert.lock(room_id).await; + let insert_lock = services.rooms.timeline.mutex_insert.lock(room_id).await; drop(insert_lock); - let left_count = services() + let left_count = services .rooms 
.state_cache .get_left_count(room_id, sender_user)?; @@ -330,11 +338,11 @@ async fn handle_left_room( return Ok(()); } - if !services().rooms.metadata.exists(room_id)? { + if !services.rooms.metadata.exists(room_id)? { // This is just a rejected invite, not a room we know // Insert a leave event anyways let event = PduEvent { - event_id: EventId::new(services().globals.server_name()).into(), + event_id: EventId::new(services.globals.server_name()).into(), sender: sender_user.to_owned(), origin: None, origin_server_ts: utils::millis_since_unix_epoch() @@ -367,7 +375,7 @@ async fn handle_left_room( prev_batch: Some(next_batch_string.to_owned()), events: Vec::new(), }, - state: State { + state: RoomState { events: vec![event.to_sync_state_event()], }, }, @@ -377,27 +385,27 @@ async fn handle_left_room( let mut left_state_events = Vec::new(); - let since_shortstatehash = services() + let since_shortstatehash = services .rooms .user .get_token_shortstatehash(room_id, since)?; let since_state_ids = match since_shortstatehash { - Some(s) => services().rooms.state_accessor.state_full_ids(s).await?, + Some(s) => services.rooms.state_accessor.state_full_ids(s).await?, None => HashMap::new(), }; - let Some(left_event_id) = services().rooms.state_accessor.room_state_get_id( - room_id, - &StateEventType::RoomMember, - sender_user.as_str(), - )? + let Some(left_event_id) = + services + .rooms + .state_accessor + .room_state_get_id(room_id, &StateEventType::RoomMember, sender_user.as_str())? else { error!("Left room but no left state event"); return Ok(()); }; - let Some(left_shortstatehash) = services() + let Some(left_shortstatehash) = services .rooms .state_accessor .pdu_shortstatehash(&left_event_id)? 
@@ -406,13 +414,13 @@ async fn handle_left_room( return Ok(()); }; - let mut left_state_ids = services() + let mut left_state_ids = services .rooms .state_accessor .state_full_ids(left_shortstatehash) .await?; - let leave_shortstatekey = services() + let leave_shortstatekey = services .rooms .short .get_or_create_shortstatekey(&StateEventType::RoomMember, sender_user.as_str())?; @@ -422,7 +430,7 @@ async fn handle_left_room( let mut i: u8 = 0; for (key, id) in left_state_ids { if full_state || since_state_ids.get(&key) != Some(&id) { - let (event_type, state_key) = services().rooms.short.get_statekey_from_short(key)?; + let (event_type, state_key) = services.rooms.short.get_statekey_from_short(key)?; if !lazy_load_enabled || event_type != StateEventType::RoomMember @@ -430,7 +438,7 @@ async fn handle_left_room( // TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565 || (cfg!(feature = "element_hacks") && *sender_user == state_key) { - let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else { + let Some(pdu) = services.rooms.timeline.get_pdu(&id)? 
else { error!("Pdu in state not found: {}", id); continue; }; @@ -456,7 +464,7 @@ async fn handle_left_room( prev_batch: Some(next_batch_string.to_owned()), events: Vec::new(), }, - state: State { + state: RoomState { events: left_state_events, }, }, @@ -465,13 +473,13 @@ async fn handle_left_room( } async fn process_presence_updates( - presence_updates: &mut HashMap<OwnedUserId, PresenceEvent>, since: u64, syncing_user: &UserId, + services: &Services, presence_updates: &mut HashMap<OwnedUserId, PresenceEvent>, since: u64, syncing_user: &UserId, ) -> Result<()> { use crate::service::presence::Presence; // Take presence updates - for (user_id, _, presence_bytes) in services().presence.presence_since(since) { - if !services() + for (user_id, _, presence_bytes) in services.presence.presence_since(since) { + if !services .rooms .state_cache .user_sees_user(syncing_user, &user_id)? @@ -513,19 +521,20 @@ async fn process_presence_updates( #[allow(clippy::too_many_arguments)] async fn load_joined_room( - sender_user: &UserId, sender_device: &DeviceId, room_id: &RoomId, since: u64, sincecount: PduCount, - next_batch: u64, next_batchcount: PduCount, lazy_load_enabled: bool, lazy_load_send_redundant: bool, - full_state: bool, device_list_updates: &mut HashSet<OwnedUserId>, left_encrypted_users: &mut HashSet<OwnedUserId>, + services: &Services, sender_user: &UserId, sender_device: &DeviceId, room_id: &RoomId, since: u64, + sincecount: PduCount, next_batch: u64, next_batchcount: PduCount, lazy_load_enabled: bool, + lazy_load_send_redundant: bool, full_state: bool, device_list_updates: &mut HashSet<OwnedUserId>, + left_encrypted_users: &mut HashSet<OwnedUserId>, ) -> Result<JoinedRoom> { // Get and drop the lock to wait for remaining operations to finish // This will make sure the we have all events until next_batch - let insert_lock = services().rooms.timeline.mutex_insert.lock(room_id).await; + let insert_lock = services.rooms.timeline.mutex_insert.lock(room_id).await; 
drop(insert_lock); - let (timeline_pdus, limited) = load_timeline(sender_user, room_id, sincecount, 10)?; + let (timeline_pdus, limited) = load_timeline(services, sender_user, room_id, sincecount, 10)?; let send_notification_counts = !timeline_pdus.is_empty() - || services() + || services .rooms .user .last_notification_read(sender_user, room_id)? @@ -536,7 +545,7 @@ async fn load_joined_room( timeline_users.insert(event.sender.as_str().to_owned()); } - services() + services .rooms .lazy_loading .lazy_load_confirm_delivery(sender_user, sender_device, room_id, sincecount) @@ -544,11 +553,11 @@ async fn load_joined_room( // Database queries: - let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? else { + let Some(current_shortstatehash) = services.rooms.state.get_room_shortstatehash(room_id)? else { return Err!(Database(error!("Room {room_id} has no state"))); }; - let since_shortstatehash = services() + let since_shortstatehash = services .rooms .user .get_token_shortstatehash(room_id, since)?; @@ -560,12 +569,12 @@ async fn load_joined_room( } else { // Calculates joined_member_count, invited_member_count and heroes let calculate_counts = || { - let joined_member_count = services() + let joined_member_count = services .rooms .state_cache .room_joined_count(room_id)? .unwrap_or(0); - let invited_member_count = services() + let invited_member_count = services .rooms .state_cache .room_invited_count(room_id)? @@ -578,7 +587,7 @@ async fn load_joined_room( // Go through all PDUs and for each member event, check if the user is still // joined or invited until we have 5 or we reach the end - for hero in services() + for hero in services .rooms .timeline .all_pdus(sender_user, room_id)? @@ -594,8 +603,8 @@ async fn load_joined_room( // The membership was and still is invite or join if matches!(content.membership, MembershipState::Join | MembershipState::Invite) - && (services().rooms.state_cache.is_joined(&user_id, room_id)? 
- || services().rooms.state_cache.is_invited(&user_id, room_id)?) + && (services.rooms.state_cache.is_joined(&user_id, room_id)? + || services.rooms.state_cache.is_invited(&user_id, room_id)?) { Ok::<_, Error>(Some(user_id)) } else { @@ -622,7 +631,7 @@ async fn load_joined_room( let since_sender_member: Option<RoomMemberEventContent> = since_shortstatehash .and_then(|shortstatehash| { - services() + services .rooms .state_accessor .state_get(shortstatehash, &StateEventType::RoomMember, sender_user.as_str()) @@ -643,7 +652,7 @@ async fn load_joined_room( let (joined_member_count, invited_member_count, heroes) = calculate_counts()?; - let current_state_ids = services() + let current_state_ids = services .rooms .state_accessor .state_full_ids(current_shortstatehash) @@ -654,13 +663,13 @@ async fn load_joined_room( let mut i: u8 = 0; for (shortstatekey, id) in current_state_ids { - let (event_type, state_key) = services() + let (event_type, state_key) = services .rooms .short .get_statekey_from_short(shortstatekey)?; if event_type != StateEventType::RoomMember { - let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else { + let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { error!("Pdu in state not found: {}", id); continue; }; @@ -676,7 +685,7 @@ async fn load_joined_room( // TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565 || (cfg!(feature = "element_hacks") && *sender_user == state_key) { - let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else { + let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { error!("Pdu in state not found: {}", id); continue; }; @@ -695,14 +704,14 @@ async fn load_joined_room( } // Reset lazy loading because this is an initial sync - services() + services .rooms .lazy_loading .lazy_load_reset(sender_user, sender_device, room_id)?; // The state_events above should contain all timeline_users, let's mark them as // lazy loaded. 
- services() + services .rooms .lazy_loading .lazy_load_mark_sent(sender_user, sender_device, room_id, lazy_loaded, next_batchcount) @@ -716,12 +725,12 @@ async fn load_joined_room( let mut delta_state_events = Vec::new(); if since_shortstatehash != current_shortstatehash { - let current_state_ids = services() + let current_state_ids = services .rooms .state_accessor .state_full_ids(current_shortstatehash) .await?; - let since_state_ids = services() + let since_state_ids = services .rooms .state_accessor .state_full_ids(since_shortstatehash) @@ -729,7 +738,7 @@ async fn load_joined_room( for (key, id) in current_state_ids { if full_state || since_state_ids.get(&key) != Some(&id) { - let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else { + let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { error!("Pdu in state not found: {}", id); continue; }; @@ -740,13 +749,13 @@ async fn load_joined_room( } } - let encrypted_room = services() + let encrypted_room = services .rooms .state_accessor .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")? .is_some(); - let since_encryption = services().rooms.state_accessor.state_get( + let since_encryption = services.rooms.state_accessor.state_get( since_shortstatehash, &StateEventType::RoomEncryption, "", @@ -781,7 +790,7 @@ async fn load_joined_room( match new_membership { MembershipState::Join => { // A new user joined an encrypted room - if !share_encrypted_room(sender_user, &user_id, room_id)? { + if !share_encrypted_room(services, sender_user, &user_id, room_id)? 
{ device_list_updates.insert(user_id); } }, @@ -798,7 +807,7 @@ async fn load_joined_room( if joined_since_last_sync && encrypted_room || new_encrypted_room { // If the user is in a new encrypted room, give them all joined users device_list_updates.extend( - services() + services .rooms .state_cache .room_members(room_id) @@ -810,7 +819,7 @@ async fn load_joined_room( .filter(|user_id| { // Only send keys if the sender doesn't share an encrypted room with the target // already - !share_encrypted_room(sender_user, user_id, room_id).unwrap_or(false) + !share_encrypted_room(services, sender_user, user_id, room_id).unwrap_or(false) }), ); } @@ -848,14 +857,14 @@ async fn load_joined_room( continue; } - if !services().rooms.lazy_loading.lazy_load_was_sent_before( + if !services.rooms.lazy_loading.lazy_load_was_sent_before( sender_user, sender_device, room_id, &event.sender, )? || lazy_load_send_redundant { - if let Some(member_event) = services().rooms.state_accessor.room_state_get( + if let Some(member_event) = services.rooms.state_accessor.room_state_get( room_id, &StateEventType::RoomMember, event.sender.as_str(), @@ -866,7 +875,7 @@ async fn load_joined_room( } } - services() + services .rooms .lazy_loading .lazy_load_mark_sent(sender_user, sender_device, room_id, lazy_loaded, next_batchcount) @@ -884,7 +893,7 @@ async fn load_joined_room( // Look for device list updates in this room device_list_updates.extend( - services() + services .users .keys_changed(room_id.as_ref(), since, None) .filter_map(Result::ok), @@ -892,7 +901,7 @@ async fn load_joined_room( let notification_count = if send_notification_counts { Some( - services() + services .rooms .user .notification_count(sender_user, room_id)? @@ -905,7 +914,7 @@ async fn load_joined_room( let highlight_count = if send_notification_counts { Some( - services() + services .rooms .user .highlight_count(sender_user, room_id)? 
@@ -933,7 +942,7 @@ async fn load_joined_room( .map(|(_, pdu)| pdu.to_sync_room_event()) .collect(); - let mut edus: Vec<_> = services() + let mut edus: Vec<_> = services .rooms .read_receipt .readreceipts_since(room_id, since) @@ -941,10 +950,10 @@ async fn load_joined_room( .map(|(_, _, v)| v) .collect(); - if services().rooms.typing.last_typing_update(room_id).await? > since { + if services.rooms.typing.last_typing_update(room_id).await? > since { edus.push( serde_json::from_str( - &serde_json::to_string(&services().rooms.typing.typings_all(room_id).await?) + &serde_json::to_string(&services.rooms.typing.typings_all(room_id).await?) .expect("event is valid, we just created it"), ) .expect("event is valid, we just created it"), @@ -953,14 +962,14 @@ async fn load_joined_room( // Save the state after this sync so we can send the correct state diff next // sync - services() + services .rooms .user .associate_token_shortstatehash(room_id, next_batch, current_shortstatehash)?; Ok(JoinedRoom { account_data: RoomAccountData { - events: services() + events: services .account_data .changes_since(Some(room_id), sender_user, since)? .into_iter() @@ -985,7 +994,7 @@ async fn load_joined_room( prev_batch, events: room_events, }, - state: State { + state: RoomState { events: state_events .iter() .map(|pdu| pdu.to_sync_state_event()) @@ -999,16 +1008,16 @@ async fn load_joined_room( } fn load_timeline( - sender_user: &UserId, room_id: &RoomId, roomsincecount: PduCount, limit: u64, + services: &Services, sender_user: &UserId, room_id: &RoomId, roomsincecount: PduCount, limit: u64, ) -> Result<(Vec<(PduCount, PduEvent)>, bool), Error> { let timeline_pdus; - let limited = if services() + let limited = if services .rooms .timeline .last_timeline_count(sender_user, room_id)? > roomsincecount { - let mut non_timeline_pdus = services() + let mut non_timeline_pdus = services .rooms .timeline .pdus_until(sender_user, room_id, PduCount::max())? 
@@ -1040,8 +1049,10 @@ fn load_timeline( Ok((timeline_pdus, limited)) } -fn share_encrypted_room(sender_user: &UserId, user_id: &UserId, ignore_room: &RoomId) -> Result<bool> { - Ok(services() +fn share_encrypted_room( + services: &Services, sender_user: &UserId, user_id: &UserId, ignore_room: &RoomId, +) -> Result<bool> { + Ok(services .rooms .user .get_shared_rooms(vec![sender_user.to_owned(), user_id.to_owned()])? @@ -1049,7 +1060,7 @@ fn share_encrypted_room(sender_user: &UserId, user_id: &UserId, ignore_room: &Ro .filter(|room_id| room_id != ignore_room) .filter_map(|other_room_id| { Some( - services() + services .rooms .state_accessor .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") @@ -1064,15 +1075,15 @@ fn share_encrypted_room(sender_user: &UserId, user_id: &UserId, ignore_room: &Ro /// /// Sliding Sync endpoint (future endpoint: `/_matrix/client/v4/sync`) pub(crate) async fn sync_events_v4_route( - body: Ruma<sync_events::v4::Request>, + State(services): State<crate::State>, body: Ruma<sync_events::v4::Request>, ) -> Result<sync_events::v4::Response, RumaResponse<UiaaResponse>> { let sender_user = body.sender_user.expect("user is authenticated"); let sender_device = body.sender_device.expect("user is authenticated"); let mut body = body.body; // Setup watchers, so if there's no response, we can wait for them - let watcher = services().globals.watch(&sender_user, &sender_device); + let watcher = services.globals.watch(&sender_user, &sender_device); - let next_batch = services().globals.next_count()?; + let next_batch = services.globals.next_count()?; let globalsince = body .pos @@ -1082,21 +1093,19 @@ pub(crate) async fn sync_events_v4_route( if globalsince == 0 { if let Some(conn_id) = &body.conn_id { - services().users.forget_sync_request_connection( - sender_user.clone(), - sender_device.clone(), - conn_id.clone(), - ); + services + .users + .forget_sync_request_connection(sender_user.clone(), sender_device.clone(), conn_id.clone()); 
} } // Get sticky parameters from cache let known_rooms = - services() + services .users .update_sync_request_with_cache(sender_user.clone(), sender_device.clone(), &mut body); - let all_joined_rooms = services() + let all_joined_rooms = services .rooms .state_cache .rooms_joined(&sender_user) @@ -1104,7 +1113,7 @@ pub(crate) async fn sync_events_v4_route( .collect::<Vec<_>>(); if body.extensions.to_device.enabled.unwrap_or(false) { - services() + services .users .remove_to_device_events(&sender_user, &sender_device, globalsince)?; } @@ -1116,26 +1125,26 @@ pub(crate) async fn sync_events_v4_route( if body.extensions.e2ee.enabled.unwrap_or(false) { // Look for device list updates of this account device_list_changes.extend( - services() + services .users .keys_changed(sender_user.as_ref(), globalsince, None) .filter_map(Result::ok), ); for room_id in &all_joined_rooms { - let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? else { + let Some(current_shortstatehash) = services.rooms.state.get_room_shortstatehash(room_id)? else { error!("Room {} has no state", room_id); continue; }; - let since_shortstatehash = services() + let since_shortstatehash = services .rooms .user .get_token_shortstatehash(room_id, globalsince)?; let since_sender_member: Option<RoomMemberEventContent> = since_shortstatehash .and_then(|shortstatehash| { - services() + services .rooms .state_accessor .state_get(shortstatehash, &StateEventType::RoomMember, sender_user.as_str()) @@ -1148,7 +1157,7 @@ pub(crate) async fn sync_events_v4_route( .ok() }); - let encrypted_room = services() + let encrypted_room = services .rooms .state_accessor .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")? 
@@ -1160,7 +1169,7 @@ pub(crate) async fn sync_events_v4_route( continue; } - let since_encryption = services().rooms.state_accessor.state_get( + let since_encryption = services.rooms.state_accessor.state_get( since_shortstatehash, &StateEventType::RoomEncryption, "", @@ -1171,12 +1180,12 @@ pub(crate) async fn sync_events_v4_route( let new_encrypted_room = encrypted_room && since_encryption.is_none(); if encrypted_room { - let current_state_ids = services() + let current_state_ids = services .rooms .state_accessor .state_full_ids(current_shortstatehash) .await?; - let since_state_ids = services() + let since_state_ids = services .rooms .state_accessor .state_full_ids(since_shortstatehash) @@ -1184,7 +1193,7 @@ pub(crate) async fn sync_events_v4_route( for (key, id) in current_state_ids { if since_state_ids.get(&key) != Some(&id) { - let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else { + let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else { error!("Pdu in state not found: {}", id); continue; }; @@ -1205,7 +1214,7 @@ pub(crate) async fn sync_events_v4_route( match new_membership { MembershipState::Join => { // A new user joined an encrypted room - if !share_encrypted_room(&sender_user, &user_id, room_id)? { + if !share_encrypted_room(services, &sender_user, &user_id, room_id)? 
{ device_list_changes.insert(user_id); } }, @@ -1222,7 +1231,7 @@ pub(crate) async fn sync_events_v4_route( if joined_since_last_sync || new_encrypted_room { // If the user is in a new encrypted room, give them all joined users device_list_changes.extend( - services() + services .rooms .state_cache .room_members(room_id) @@ -1234,7 +1243,7 @@ pub(crate) async fn sync_events_v4_route( .filter(|user_id| { // Only send keys if the sender doesn't share an encrypted room with the target // already - !share_encrypted_room(&sender_user, user_id, room_id).unwrap_or(false) + !share_encrypted_room(services, &sender_user, user_id, room_id).unwrap_or(false) }), ); } @@ -1242,21 +1251,21 @@ pub(crate) async fn sync_events_v4_route( } // Look for device list updates in this room device_list_changes.extend( - services() + services .users .keys_changed(room_id.as_ref(), globalsince, None) .filter_map(Result::ok), ); } for user_id in left_encrypted_users { - let dont_share_encrypted_room = services() + let dont_share_encrypted_room = services .rooms .user .get_shared_rooms(vec![sender_user.clone(), user_id.clone()])? .filter_map(Result::ok) .filter_map(|other_room_id| { Some( - services() + services .rooms .state_accessor .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") @@ -1336,7 +1345,7 @@ pub(crate) async fn sync_events_v4_route( ); if let Some(conn_id) = &body.conn_id { - services().users.update_sync_known_rooms( + services.users.update_sync_known_rooms( sender_user.clone(), sender_device.clone(), conn_id.clone(), @@ -1349,7 +1358,7 @@ pub(crate) async fn sync_events_v4_route( let mut known_subscription_rooms = BTreeSet::new(); for (room_id, room) in &body.room_subscriptions { - if !services().rooms.metadata.exists(room_id)? { + if !services.rooms.metadata.exists(room_id)? 
{ continue; } let todo_room = todo_rooms @@ -1375,7 +1384,7 @@ pub(crate) async fn sync_events_v4_route( } if let Some(conn_id) = &body.conn_id { - services().users.update_sync_known_rooms( + services.users.update_sync_known_rooms( sender_user.clone(), sender_device.clone(), conn_id.clone(), @@ -1386,7 +1395,7 @@ pub(crate) async fn sync_events_v4_route( } if let Some(conn_id) = &body.conn_id { - services().users.update_sync_subscriptions( + services.users.update_sync_subscriptions( sender_user.clone(), sender_device.clone(), conn_id.clone(), @@ -1398,7 +1407,7 @@ pub(crate) async fn sync_events_v4_route( for (room_id, (required_state_request, timeline_limit, roomsince)) in &todo_rooms { let roomsincecount = PduCount::Normal(*roomsince); - let (timeline_pdus, limited) = load_timeline(&sender_user, room_id, roomsincecount, *timeline_limit)?; + let (timeline_pdus, limited) = load_timeline(services, &sender_user, room_id, roomsincecount, *timeline_limit)?; if roomsince != &0 && timeline_pdus.is_empty() { continue; @@ -1431,7 +1440,7 @@ pub(crate) async fn sync_events_v4_route( let required_state = required_state_request .iter() .map(|state| { - services() + services .rooms .state_accessor .room_state_get(room_id, &state.0, &state.1) @@ -1442,7 +1451,7 @@ pub(crate) async fn sync_events_v4_route( .collect(); // Heroes - let heroes = services() + let heroes = services .rooms .state_cache .room_members(room_id) @@ -1450,7 +1459,7 @@ pub(crate) async fn sync_events_v4_route( .filter(|member| member != &sender_user) .map(|member| { Ok::<_, Error>( - services() + services .rooms .state_accessor .get_member(room_id, &member)? 
@@ -1491,11 +1500,11 @@ pub(crate) async fn sync_events_v4_route( rooms.insert( room_id.clone(), sync_events::v4::SlidingSyncRoom { - name: services().rooms.state_accessor.get_name(room_id)?.or(name), + name: services.rooms.state_accessor.get_name(room_id)?.or(name), avatar: if let Some(heroes_avatar) = heroes_avatar { ruma::JsOption::Some(heroes_avatar) } else { - match services().rooms.state_accessor.get_avatar(room_id)? { + match services.rooms.state_accessor.get_avatar(room_id)? { ruma::JsOption::Some(avatar) => ruma::JsOption::from_option(avatar.url), ruma::JsOption::Null => ruma::JsOption::Null, ruma::JsOption::Undefined => ruma::JsOption::Undefined, @@ -1506,7 +1515,7 @@ pub(crate) async fn sync_events_v4_route( invite_state: None, unread_notifications: UnreadNotificationsCount { highlight_count: Some( - services() + services .rooms .user .highlight_count(&sender_user, room_id)? @@ -1514,7 +1523,7 @@ pub(crate) async fn sync_events_v4_route( .expect("notification count can't go that high"), ), notification_count: Some( - services() + services .rooms .user .notification_count(&sender_user, room_id)? @@ -1527,7 +1536,7 @@ pub(crate) async fn sync_events_v4_route( prev_batch, limited, joined_count: Some( - services() + services .rooms .state_cache .room_joined_count(room_id)? @@ -1536,7 +1545,7 @@ pub(crate) async fn sync_events_v4_route( .unwrap_or_else(|_| uint!(0)), ), invited_count: Some( - services() + services .rooms .state_cache .room_invited_count(room_id)? 
@@ -1571,7 +1580,7 @@ pub(crate) async fn sync_events_v4_route( extensions: sync_events::v4::Extensions { to_device: if body.extensions.to_device.enabled.unwrap_or(false) { Some(sync_events::v4::ToDevice { - events: services() + events: services .users .get_to_device_events(&sender_user, &sender_device)?, next_batch: next_batch.to_string(), @@ -1584,7 +1593,7 @@ pub(crate) async fn sync_events_v4_route( changed: device_list_changes.into_iter().collect(), left: device_list_left.into_iter().collect(), }, - device_one_time_keys_count: services() + device_one_time_keys_count: services .users .count_one_time_keys(&sender_user, &sender_device)?, // Fallback keys are not yet supported @@ -1592,7 +1601,7 @@ pub(crate) async fn sync_events_v4_route( }, account_data: sync_events::v4::AccountData { global: if body.extensions.account_data.enabled.unwrap_or(false) { - services() + services .account_data .changes_since(None, &sender_user, globalsince)? .into_iter() diff --git a/src/api/client/tag.rs b/src/api/client/tag.rs index 9fb60bc6e..301568e50 100644 --- a/src/api/client/tag.rs +++ b/src/api/client/tag.rs @@ -1,5 +1,6 @@ use std::collections::BTreeMap; +use axum::extract::State; use ruma::{ api::client::tag::{create_tag, delete_tag, get_tags}, events::{ @@ -8,17 +9,19 @@ }, }; -use crate::{services, Error, Result, Ruma}; +use crate::{Error, Result, Ruma}; /// # `PUT /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags/{tag}` /// /// Adds a tag to the room. /// /// - Inserts the tag into the tag event of the room account data. 
-pub(crate) async fn update_tag_route(body: Ruma<create_tag::v3::Request>) -> Result<create_tag::v3::Response> { +pub(crate) async fn update_tag_route( + State(services): State<crate::State>, body: Ruma<create_tag::v3::Request>, +) -> Result<create_tag::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services() + let event = services .account_data .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; @@ -38,7 +41,7 @@ pub(crate) async fn update_tag_route(body: Ruma<create_tag::v3::Request>) -> Res .tags .insert(body.tag.clone().into(), body.tag_info.clone()); - services().account_data.update( + services.account_data.update( Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag, @@ -53,10 +56,12 @@ pub(crate) async fn update_tag_route(body: Ruma<create_tag::v3::Request>) -> Res /// Deletes a tag from the room. /// /// - Removes the tag from the tag event of the room account data. -pub(crate) async fn delete_tag_route(body: Ruma<delete_tag::v3::Request>) -> Result<delete_tag::v3::Response> { +pub(crate) async fn delete_tag_route( + State(services): State<crate::State>, body: Ruma<delete_tag::v3::Request>, +) -> Result<delete_tag::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services() + let event = services .account_data .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; @@ -73,7 +78,7 @@ pub(crate) async fn delete_tag_route(body: Ruma<delete_tag::v3::Request>) -> Res tags_event.content.tags.remove(&body.tag.clone().into()); - services().account_data.update( + services.account_data.update( Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag, @@ -88,10 +93,12 @@ pub(crate) async fn delete_tag_route(body: Ruma<delete_tag::v3::Request>) -> Res /// Returns tags on the room. /// /// - Gets the tag event of the room account data. 
-pub(crate) async fn get_tags_route(body: Ruma<get_tags::v3::Request>) -> Result<get_tags::v3::Response> { +pub(crate) async fn get_tags_route( + State(services): State<crate::State>, body: Ruma<get_tags::v3::Request>, +) -> Result<get_tags::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services() + let event = services .account_data .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; diff --git a/src/api/client/threads.rs b/src/api/client/threads.rs index 2895573a9..8100f0e67 100644 --- a/src/api/client/threads.rs +++ b/src/api/client/threads.rs @@ -1,12 +1,15 @@ +use axum::extract::State; use ruma::{ api::client::{error::ErrorKind, threads::get_threads}, uint, }; -use crate::{services, Error, Result, Ruma}; +use crate::{Error, Result, Ruma}; /// # `GET /_matrix/client/r0/rooms/{roomId}/threads` -pub(crate) async fn get_threads_route(body: Ruma<get_threads::v1::Request>) -> Result<get_threads::v1::Response> { +pub(crate) async fn get_threads_route( + State(services): State<crate::State>, body: Ruma<get_threads::v1::Request>, +) -> Result<get_threads::v1::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); // Use limit or else 10, with maximum 100 @@ -24,14 +27,14 @@ pub(crate) async fn get_threads_route(body: Ruma<get_threads::v1::Request>) -> R u64::MAX }; - let threads = services() + let threads = services .rooms .threads .threads_until(sender_user, &body.room_id, from, &body.include)? 
.take(limit) .filter_map(Result::ok) .filter(|(_, pdu)| { - services() + services .rooms .state_accessor .user_can_see_event(sender_user, &body.room_id, &pdu.event_id) diff --git a/src/api/client/to_device.rs b/src/api/client/to_device.rs index 011e08f7d..8476ff41d 100644 --- a/src/api/client/to_device.rs +++ b/src/api/client/to_device.rs @@ -1,5 +1,6 @@ use std::collections::BTreeMap; +use axum::extract::State; use ruma::{ api::{ client::{error::ErrorKind, to_device::send_event_to_device}, @@ -8,19 +9,19 @@ to_device::DeviceIdOrAllDevices, }; -use crate::{services, user_is_local, Error, Result, Ruma}; +use crate::{user_is_local, Error, Result, Ruma}; /// # `PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId}` /// /// Send a to-device event to a set of client devices. pub(crate) async fn send_event_to_device_route( - body: Ruma<send_event_to_device::v3::Request>, + State(services): State<crate::State>, body: Ruma<send_event_to_device::v3::Request>, ) -> Result<send_event_to_device::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_deref(); // Check if this is a new transaction id - if services() + if services .transaction_ids .existing_txnid(sender_user, sender_device, &body.txn_id)? 
.is_some() @@ -35,9 +36,9 @@ pub(crate) async fn send_event_to_device_route( map.insert(target_device_id_maybe.clone(), event.clone()); let mut messages = BTreeMap::new(); messages.insert(target_user_id.clone(), map); - let count = services().globals.next_count()?; + let count = services.globals.next_count()?; - services().sending.send_edu_server( + services.sending.send_edu_server( target_user_id.server_name(), serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice(DirectDeviceContent { sender: sender_user.clone(), @@ -53,7 +54,7 @@ pub(crate) async fn send_event_to_device_route( match target_device_id_maybe { DeviceIdOrAllDevices::DeviceId(target_device_id) => { - services().users.add_to_device_event( + services.users.add_to_device_event( sender_user, target_user_id, target_device_id, @@ -65,8 +66,8 @@ pub(crate) async fn send_event_to_device_route( }, DeviceIdOrAllDevices::AllDevices => { - for target_device_id in services().users.all_device_ids(target_user_id) { - services().users.add_to_device_event( + for target_device_id in services.users.all_device_ids(target_user_id) { + services.users.add_to_device_event( sender_user, target_user_id, &target_device_id?, @@ -82,7 +83,7 @@ pub(crate) async fn send_event_to_device_route( } // Save transaction id with empty data - services() + services .transaction_ids .add_txnid(sender_user, sender_device, &body.txn_id, &[])?; diff --git a/src/api/client/typing.rs b/src/api/client/typing.rs index 52c8b3530..a06648e05 100644 --- a/src/api/client/typing.rs +++ b/src/api/client/typing.rs @@ -1,18 +1,19 @@ +use axum::extract::State; use ruma::api::client::{error::ErrorKind, typing::create_typing_event}; -use crate::{services, utils, Error, Result, Ruma}; +use crate::{utils, Error, Result, Ruma}; /// # `PUT /_matrix/client/r0/rooms/{roomId}/typing/{userId}` /// /// Sets the typing state of the sender user. 
pub(crate) async fn create_typing_event_route( - body: Ruma<create_typing_event::v3::Request>, + State(services): State<crate::State>, body: Ruma<create_typing_event::v3::Request>, ) -> Result<create_typing_event::v3::Response> { use create_typing_event::v3::Typing; let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services() + if !services .rooms .state_cache .is_joined(sender_user, &body.room_id)? @@ -23,20 +24,20 @@ pub(crate) async fn create_typing_event_route( if let Typing::Yes(duration) = body.state { let duration = utils::clamp( duration.as_millis().try_into().unwrap_or(u64::MAX), - services() + services .globals .config .typing_client_timeout_min_s .checked_mul(1000) .unwrap(), - services() + services .globals .config .typing_client_timeout_max_s .checked_mul(1000) .unwrap(), ); - services() + services .rooms .typing .typing_add( @@ -48,7 +49,7 @@ pub(crate) async fn create_typing_event_route( ) .await?; } else { - services() + services .rooms .typing .typing_remove(sender_user, &body.room_id) diff --git a/src/api/client/unstable.rs b/src/api/client/unstable.rs index e39db94e8..896509513 100644 --- a/src/api/client/unstable.rs +++ b/src/api/client/unstable.rs @@ -1,3 +1,4 @@ +use axum::extract::State; use axum_client_ip::InsecureClientIp; use conduit::warn; use ruma::{ @@ -6,7 +7,7 @@ OwnedRoomId, }; -use crate::{services, Error, Result, Ruma, RumaResponse}; +use crate::{Error, Result, Ruma, RumaResponse}; /// # `GET /_matrix/client/unstable/uk.half-shot.msc2666/user/mutual_rooms` /// @@ -17,7 +18,8 @@ /// An implementation of [MSC2666](https://github.com/matrix-org/matrix-spec-proposals/pull/2666) #[tracing::instrument(skip_all, fields(%client), name = "mutual_rooms")] pub(crate) async fn get_mutual_rooms_route( - InsecureClientIp(client): InsecureClientIp, body: Ruma<mutual_rooms::unstable::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: 
Ruma<mutual_rooms::unstable::Request>, ) -> Result<mutual_rooms::unstable::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -28,14 +30,14 @@ pub(crate) async fn get_mutual_rooms_route( )); } - if !services().users.exists(&body.user_id)? { + if !services.users.exists(&body.user_id)? { return Ok(mutual_rooms::unstable::Response { joined: vec![], next_batch_token: None, }); } - let mutual_rooms: Vec<OwnedRoomId> = services() + let mutual_rooms: Vec<OwnedRoomId> = services .rooms .user .get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? @@ -58,9 +60,10 @@ pub(crate) async fn get_mutual_rooms_route( /// /// An implementation of [MSC3266](https://github.com/matrix-org/matrix-spec-proposals/pull/3266) pub(crate) async fn get_room_summary_legacy( - InsecureClientIp(client): InsecureClientIp, body: Ruma<get_summary::msc3266::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: Ruma<get_summary::msc3266::Request>, ) -> Result<RumaResponse<get_summary::msc3266::Response>> { - get_room_summary(InsecureClientIp(client), body) + get_room_summary(State(services), InsecureClientIp(client), body) .await .map(RumaResponse) } @@ -74,22 +77,19 @@ pub(crate) async fn get_room_summary_legacy( /// An implementation of [MSC3266](https://github.com/matrix-org/matrix-spec-proposals/pull/3266) #[tracing::instrument(skip_all, fields(%client), name = "room_summary")] pub(crate) async fn get_room_summary( - InsecureClientIp(client): InsecureClientIp, body: Ruma<get_summary::msc3266::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: Ruma<get_summary::msc3266::Request>, ) -> Result<get_summary::msc3266::Response> { let sender_user = body.sender_user.as_ref(); - let room_id = services() - .rooms - .alias - .resolve(&body.room_id_or_alias) - .await?; + let room_id = services.rooms.alias.resolve(&body.room_id_or_alias).await?; - if 
!services().rooms.metadata.exists(&room_id)? { + if !services.rooms.metadata.exists(&room_id)? { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server")); } if sender_user.is_none() - && !services() + && !services .rooms .state_accessor .is_world_readable(&room_id) @@ -103,25 +103,25 @@ pub(crate) async fn get_room_summary( Ok(get_summary::msc3266::Response { room_id: room_id.clone(), - canonical_alias: services() + canonical_alias: services .rooms .state_accessor .get_canonical_alias(&room_id) .unwrap_or(None), - avatar_url: services() + avatar_url: services .rooms .state_accessor .get_avatar(&room_id)? .into_option() .unwrap_or_default() .url, - guest_can_join: services().rooms.state_accessor.guest_can_join(&room_id)?, - name: services() + guest_can_join: services.rooms.state_accessor.guest_can_join(&room_id)?, + name: services .rooms .state_accessor .get_name(&room_id) .unwrap_or(None), - num_joined_members: services() + num_joined_members: services .rooms .state_cache .room_joined_count(&room_id) @@ -132,21 +132,21 @@ pub(crate) async fn get_room_summary( }) .try_into() .expect("user count should not be that big"), - topic: services() + topic: services .rooms .state_accessor .get_room_topic(&room_id) .unwrap_or(None), - world_readable: services() + world_readable: services .rooms .state_accessor .is_world_readable(&room_id) .unwrap_or(false), - join_rule: services().rooms.state_accessor.get_join_rule(&room_id)?.0, - room_type: services().rooms.state_accessor.get_room_type(&room_id)?, - room_version: Some(services().rooms.state.get_room_version(&room_id)?), + join_rule: services.rooms.state_accessor.get_join_rule(&room_id)?.0, + room_type: services.rooms.state_accessor.get_room_type(&room_id)?, + room_version: Some(services.rooms.state.get_room_version(&room_id)?), membership: if let Some(sender_user) = sender_user { - services() + services .rooms .state_accessor .get_member(&room_id, sender_user)? 
@@ -154,7 +154,7 @@ pub(crate) async fn get_room_summary( } else { None }, - encryption: services() + encryption: services .rooms .state_accessor .get_room_encryption(&room_id) diff --git a/src/api/client/unversioned.rs b/src/api/client/unversioned.rs index 2b703f947..9a8f3220b 100644 --- a/src/api/client/unversioned.rs +++ b/src/api/client/unversioned.rs @@ -1,6 +1,6 @@ use std::collections::BTreeMap; -use axum::{response::IntoResponse, Json}; +use axum::{extract::State, response::IntoResponse, Json}; use ruma::api::client::{ discovery::{ discover_homeserver::{self, HomeserverInfo, SlidingSyncProxyInfo}, @@ -10,7 +10,7 @@ error::ErrorKind, }; -use crate::{services, Error, Result, Ruma}; +use crate::{Error, Result, Ruma}; /// # `GET /_matrix/client/versions` /// @@ -62,9 +62,9 @@ pub(crate) async fn get_supported_versions_route( /// /// Returns the .well-known URL if it is configured, otherwise returns 404. pub(crate) async fn well_known_client( - _body: Ruma<discover_homeserver::Request>, + State(services): State<crate::State>, _body: Ruma<discover_homeserver::Request>, ) -> Result<discover_homeserver::Response> { - let client_url = match services().globals.well_known_client() { + let client_url = match services.globals.well_known_client() { Some(url) => url.to_string(), None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")), }; @@ -84,22 +84,24 @@ pub(crate) async fn well_known_client( /// # `GET /.well-known/matrix/support` /// /// Server support contact and support page of a homeserver's domain. 
-pub(crate) async fn well_known_support(_body: Ruma<discover_support::Request>) -> Result<discover_support::Response> { - let support_page = services() +pub(crate) async fn well_known_support( + State(services): State<crate::State>, _body: Ruma<discover_support::Request>, +) -> Result<discover_support::Response> { + let support_page = services .globals .well_known_support_page() .as_ref() .map(ToString::to_string); - let role = services().globals.well_known_support_role().clone(); + let role = services.globals.well_known_support_role().clone(); // support page or role must be either defined for this to be valid if support_page.is_none() && role.is_none() { return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")); } - let email_address = services().globals.well_known_support_email().clone(); - let matrix_id = services().globals.well_known_support_mxid().clone(); + let email_address = services.globals.well_known_support_email().clone(); + let matrix_id = services.globals.well_known_support_mxid().clone(); // if a role is specified, an email address or matrix id is required if role.is_some() && (email_address.is_none() && matrix_id.is_none()) { @@ -134,10 +136,10 @@ pub(crate) async fn well_known_support(_body: Ruma<discover_support::Request>) - /// /// Endpoint provided by sliding sync proxy used by some clients such as Element /// Web as a non-standard health check. 
-pub(crate) async fn syncv3_client_server_json() -> Result<impl IntoResponse> { - let server_url = match services().globals.well_known_client() { +pub(crate) async fn syncv3_client_server_json(State(services): State<crate::State>) -> Result<impl IntoResponse> { + let server_url = match services.globals.well_known_client() { Some(url) => url.to_string(), - None => match services().globals.well_known_server() { + None => match services.globals.well_known_server() { Some(url) => url.to_string(), None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")), }, @@ -165,8 +167,8 @@ pub(crate) async fn conduwuit_server_version() -> Result<impl IntoResponse> { /// conduwuit-specific API to return the amount of users registered on this /// homeserver. Endpoint is disabled if federation is disabled for privacy. This /// only includes active users (not deactivated, no guests, etc) -pub(crate) async fn conduwuit_local_user_count() -> Result<impl IntoResponse> { - let user_count = services().users.list_local_users()?.len(); +pub(crate) async fn conduwuit_local_user_count(State(services): State<crate::State>) -> Result<impl IntoResponse> { + let user_count = services.users.list_local_users()?.len(); Ok(Json(serde_json::json!({ "count": user_count diff --git a/src/api/client/user_directory.rs b/src/api/client/user_directory.rs index ca5f834a9..87d4062cd 100644 --- a/src/api/client/user_directory.rs +++ b/src/api/client/user_directory.rs @@ -1,3 +1,4 @@ +use axum::extract::State; use ruma::{ api::client::user_directory::search_users, events::{ @@ -6,7 +7,7 @@ }, }; -use crate::{services, Result, Ruma}; +use crate::{Result, Ruma}; /// # `POST /_matrix/client/r0/user_directory/search` /// @@ -14,18 +15,20 @@ /// /// - Hides any local users that aren't in any public rooms (i.e. 
those that /// have the join rule set to public) and don't share a room with the sender -pub(crate) async fn search_users_route(body: Ruma<search_users::v3::Request>) -> Result<search_users::v3::Response> { +pub(crate) async fn search_users_route( + State(services): State<crate::State>, body: Ruma<search_users::v3::Request>, +) -> Result<search_users::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let limit = usize::try_from(body.limit).unwrap_or(10); // default limit is 10 - let mut users = services().users.iter().filter_map(|user_id| { + let mut users = services.users.iter().filter_map(|user_id| { // Filter out buggy users (they should not exist, but you never know...) let user_id = user_id.ok()?; let user = search_users::v3::User { user_id: user_id.clone(), - display_name: services().users.displayname(&user_id).ok()?, - avatar_url: services().users.avatar_url(&user_id).ok()?, + display_name: services.users.displayname(&user_id).ok()?, + avatar_url: services.users.avatar_url(&user_id).ok()?, }; let user_id_matches = user @@ -50,13 +53,13 @@ pub(crate) async fn search_users_route(body: Ruma<search_users::v3::Request>) -> // It's a matching user, but is the sender allowed to see them? 
let mut user_visible = false; - let user_is_in_public_rooms = services() + let user_is_in_public_rooms = services .rooms .state_cache .rooms_joined(&user_id) .filter_map(Result::ok) .any(|room| { - services() + services .rooms .state_accessor .room_state_get(&room, &StateEventType::RoomJoinRules, "") @@ -71,7 +74,7 @@ pub(crate) async fn search_users_route(body: Ruma<search_users::v3::Request>) -> if user_is_in_public_rooms { user_visible = true; } else { - let user_is_in_shared_rooms = services() + let user_is_in_shared_rooms = services .rooms .user .get_shared_rooms(vec![sender_user.clone(), user_id]) diff --git a/src/api/client/voip.rs b/src/api/client/voip.rs index 9608dc88f..ed6971ee9 100644 --- a/src/api/client/voip.rs +++ b/src/api/client/voip.rs @@ -1,12 +1,13 @@ use std::time::{Duration, SystemTime}; +use axum::extract::State; use base64::{engine::general_purpose, Engine as _}; use conduit::utils; use hmac::{Hmac, Mac}; use ruma::{api::client::voip::get_turn_server_info, SecondsSinceUnixEpoch, UserId}; use sha1::Sha1; -use crate::{services, Result, Ruma}; +use crate::{Result, Ruma}; const RANDOM_USER_ID_LENGTH: usize = 10; @@ -16,14 +17,14 @@ /// /// TODO: Returns information about the recommended turn server. 
pub(crate) async fn turn_server_route( - body: Ruma<get_turn_server_info::v3::Request>, + State(services): State<crate::State>, body: Ruma<get_turn_server_info::v3::Request>, ) -> Result<get_turn_server_info::v3::Response> { - let turn_secret = services().globals.turn_secret().clone(); + let turn_secret = services.globals.turn_secret().clone(); let (username, password) = if !turn_secret.is_empty() { let expiry = SecondsSinceUnixEpoch::from_system_time( SystemTime::now() - .checked_add(Duration::from_secs(services().globals.turn_ttl())) + .checked_add(Duration::from_secs(services.globals.turn_ttl())) .expect("TURN TTL should not get this high"), ) .expect("time is valid"); @@ -31,7 +32,7 @@ pub(crate) async fn turn_server_route( let user = body.sender_user.unwrap_or_else(|| { UserId::parse_with_server_name( utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(), - &services().globals.config.server_name, + &services.globals.config.server_name, ) .unwrap() }); @@ -46,15 +47,15 @@ pub(crate) async fn turn_server_route( (username, password) } else { ( - services().globals.turn_username().clone(), - services().globals.turn_password().clone(), + services.globals.turn_username().clone(), + services.globals.turn_password().clone(), ) }; Ok(get_turn_server_info::v3::Response { username, password, - uris: services().globals.turn_uris().to_vec(), - ttl: Duration::from_secs(services().globals.turn_ttl()), + uris: services.globals.turn_uris().to_vec(), + ttl: Duration::from_secs(services.globals.turn_ttl()), }) } diff --git a/src/api/router/args.rs b/src/api/router/args.rs index 776ce4f46..fa5b1e439 100644 --- a/src/api/router/args.rs +++ b/src/api/router/args.rs @@ -2,11 +2,12 @@ use axum::{async_trait, body::Body, extract::FromRequest}; use bytes::{BufMut, BytesMut}; -use conduit::{debug, err, trace, Error, Result}; +use conduit::{debug, err, trace, utils::string::EMPTY, Error, Result}; use ruma::{api::IncomingRequest, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, 
OwnedUserId, UserId}; +use service::Services; use super::{auth, auth::Auth, request, request::Request}; -use crate::{service::appservice::RegistrationInfo, services}; +use crate::service::appservice::RegistrationInfo; /// Extractor for Ruma request structs pub(crate) struct Args<T> { @@ -42,11 +43,12 @@ impl<T, S> FromRequest<S, Body> for Args<T> type Rejection = Error; async fn from_request(request: hyper::Request<Body>, _: &S) -> Result<Self, Self::Rejection> { - let mut request = request::from(request).await?; + let services = service::services(); // ??? + let mut request = request::from(services, request).await?; let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&request.body).ok(); - let auth = auth::auth(&mut request, &json_body, &T::METADATA).await?; + let auth = auth::auth(services, &mut request, &json_body, &T::METADATA).await?; Ok(Self { - body: make_body::<T>(&mut request, &mut json_body, &auth)?, + body: make_body::<T>(services, &mut request, &mut json_body, &auth)?, origin: auth.origin, sender_user: auth.sender_user, sender_device: auth.sender_device, @@ -62,13 +64,16 @@ impl<T> Deref for Args<T> { fn deref(&self) -> &Self::Target { &self.body } } -fn make_body<T>(request: &mut Request, json_body: &mut Option<CanonicalJsonValue>, auth: &Auth) -> Result<T> +fn make_body<T>( + services: &Services, request: &mut Request, json_body: &mut Option<CanonicalJsonValue>, auth: &Auth, +) -> Result<T> where T: IncomingRequest, { let body = if let Some(CanonicalJsonValue::Object(json_body)) = json_body { let user_id = auth.sender_user.clone().unwrap_or_else(|| { - UserId::parse_with_server_name("", services().globals.server_name()).expect("we know this is valid") + let server_name = services.globals.server_name(); + UserId::parse_with_server_name(EMPTY, server_name).expect("valid user_id") }); let uiaa_request = json_body @@ -77,9 +82,9 @@ fn make_body<T>(request: &mut Request, json_body: &mut Option<CanonicalJsonValue .and_then(|auth| 
auth.get("session")) .and_then(|session| session.as_str()) .and_then(|session| { - services().uiaa.get_uiaa_request( + services.uiaa.get_uiaa_request( &user_id, - &auth.sender_device.clone().unwrap_or_else(|| "".into()), + &auth.sender_device.clone().unwrap_or_else(|| EMPTY.into()), session, ) }); diff --git a/src/api/router/auth.rs b/src/api/router/auth.rs index 6c2922b97..3f08ddba8 100644 --- a/src/api/router/auth.rs +++ b/src/api/router/auth.rs @@ -6,17 +6,17 @@ typed_header::TypedHeaderRejectionReason, TypedHeader, }; -use conduit::Err; +use conduit::{warn, Err, Error, Result}; use http::uri::PathAndQuery; use ruma::{ api::{client::error::ErrorKind, AuthScheme, Metadata}, server_util::authorization::XMatrix, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId, }; -use tracing::warn; +use service::Services; use super::request::Request; -use crate::{service::appservice::RegistrationInfo, services, Error, Result}; +use crate::service::appservice::RegistrationInfo; enum Token { Appservice(Box<RegistrationInfo>), @@ -33,7 +33,7 @@ pub(super) struct Auth { } pub(super) async fn auth( - request: &mut Request, json_body: &Option<CanonicalJsonValue>, metadata: &Metadata, + services: &Services, request: &mut Request, json_body: &Option<CanonicalJsonValue>, metadata: &Metadata, ) -> Result<Auth> { let bearer: Option<TypedHeader<Authorization<Bearer>>> = request.parts.extract().await?; let token = match &bearer { @@ -42,9 +42,9 @@ pub(super) async fn auth( }; let token = if let Some(token) = token { - if let Some(reg_info) = services().appservice.find_from_token(token).await { + if let Some(reg_info) = services.appservice.find_from_token(token).await { Token::Appservice(Box::new(reg_info)) - } else if let Some((user_id, device_id)) = services().users.find_from_token(token)? { + } else if let Some((user_id, device_id)) = services.users.find_from_token(token)? 
{ Token::User((user_id, OwnedDeviceId::from(device_id))) } else { Token::Invalid @@ -57,7 +57,7 @@ pub(super) async fn auth( match request.parts.uri.path() { // TODO: can we check this better? "/_matrix/client/v3/publicRooms" | "/_matrix/client/r0/publicRooms" => { - if !services() + if !services .globals .config .allow_public_room_directory_without_auth @@ -98,7 +98,7 @@ pub(super) async fn auth( )) } }, - (AuthScheme::AccessToken, Token::Appservice(info)) => Ok(auth_appservice(request, info)?), + (AuthScheme::AccessToken, Token::Appservice(info)) => Ok(auth_appservice(services, request, info)?), (AuthScheme::None | AuthScheme::AccessTokenOptional | AuthScheme::AppserviceToken, Token::Appservice(info)) => { Ok(Auth { origin: None, @@ -110,7 +110,7 @@ pub(super) async fn auth( (AuthScheme::AccessToken, Token::None) => match request.parts.uri.path() { // TODO: can we check this better? "/_matrix/client/v3/voip/turnServer" | "/_matrix/client/r0/voip/turnServer" => { - if services().globals.config.turn_allow_guests { + if services.globals.config.turn_allow_guests { Ok(Auth { origin: None, sender_user: None, @@ -132,7 +132,7 @@ pub(super) async fn auth( sender_device: Some(device_id), appservice_info: None, }), - (AuthScheme::ServerSignatures, Token::None) => Ok(auth_server(request, json_body).await?), + (AuthScheme::ServerSignatures, Token::None) => Ok(auth_server(services, request, json_body).await?), (AuthScheme::None | AuthScheme::AppserviceToken | AuthScheme::AccessTokenOptional, Token::None) => Ok(Auth { sender_user: None, sender_device: None, @@ -150,7 +150,7 @@ pub(super) async fn auth( } } -fn auth_appservice(request: &Request, info: Box<RegistrationInfo>) -> Result<Auth> { +fn auth_appservice(services: &Services, request: &Request, info: Box<RegistrationInfo>) -> Result<Auth> { let user_id = request .query .user_id @@ -159,7 +159,7 @@ fn auth_appservice(request: &Request, info: Box<RegistrationInfo>) -> Result<Aut || { UserId::parse_with_server_name( 
info.registration.sender_localpart.as_str(), - services().globals.server_name(), + services.globals.server_name(), ) }, UserId::parse, @@ -170,7 +170,7 @@ fn auth_appservice(request: &Request, info: Box<RegistrationInfo>) -> Result<Aut return Err(Error::BadRequest(ErrorKind::Exclusive, "User is not in namespace.")); } - if !services().users.exists(&user_id)? { + if !services.users.exists(&user_id)? { return Err(Error::BadRequest(ErrorKind::forbidden(), "User does not exist.")); } @@ -182,8 +182,10 @@ fn auth_appservice(request: &Request, info: Box<RegistrationInfo>) -> Result<Aut }) } -async fn auth_server(request: &mut Request, json_body: &Option<CanonicalJsonValue>) -> Result<Auth> { - if !services().globals.allow_federation() { +async fn auth_server( + services: &Services, request: &mut Request, json_body: &Option<CanonicalJsonValue>, +) -> Result<Auth> { + if !services.globals.allow_federation() { return Err!(Config("allow_federation", "Federation is disabled.")); } @@ -216,7 +218,7 @@ async fn auth_server(request: &mut Request, json_body: &Option<CanonicalJsonValu ), )]); - let server_destination = services().globals.server_name().as_str().to_owned(); + let server_destination = services.globals.server_name().as_str().to_owned(); if let Some(destination) = x_matrix.destination.as_ref() { if destination != &server_destination { return Err(Error::BadRequest(ErrorKind::forbidden(), "Invalid authorization.")); @@ -247,7 +249,7 @@ async fn auth_server(request: &mut Request, json_body: &Option<CanonicalJsonValu request_map.insert("content".to_owned(), json_body.clone()); }; - let keys_result = services() + let keys_result = services .rooms .event_handler .fetch_signing_keys_for_server(origin, vec![x_matrix.key.to_string()]) diff --git a/src/api/router/request.rs b/src/api/router/request.rs index bed8d057b..16c83e52e 100644 --- a/src/api/router/request.rs +++ b/src/api/router/request.rs @@ -2,11 +2,10 @@ use axum::{extract::Path, RequestExt, RequestPartsExt}; use 
bytes::Bytes; -use conduit::err; +use conduit::{err, Result}; use http::request::Parts; use serde::Deserialize; - -use crate::{services, Result}; +use service::Services; #[derive(Deserialize)] pub(super) struct QueryParams { @@ -21,7 +20,7 @@ pub(super) struct Request { pub(super) parts: Parts, } -pub(super) async fn from(request: hyper::Request<axum::body::Body>) -> Result<Request> { +pub(super) async fn from(services: &Services, request: hyper::Request<axum::body::Body>) -> Result<Request> { let limited = request.with_limited_body(); let (mut parts, body) = limited.into_parts(); @@ -30,7 +29,7 @@ pub(super) async fn from(request: hyper::Request<axum::body::Body>) -> Result<Re let query = serde_html_form::from_str(query).map_err(|e| err!(Request(Unknown("Failed to read query parameters: {e}"))))?; - let max_body_size = services().globals.config.max_request_size; + let max_body_size = services.globals.config.max_request_size; let body = axum::body::to_bytes(body, max_body_size) .await diff --git a/src/api/server/backfill.rs b/src/api/server/backfill.rs index b432ae203..8dd38cad0 100644 --- a/src/api/server/backfill.rs +++ b/src/api/server/backfill.rs @@ -1,9 +1,10 @@ +use axum::extract::State; use conduit::{Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation::backfill::get_backfill}, uint, user_id, MilliSecondsSinceUnixEpoch, }; -use service::{sending::convert_to_outgoing_federation_event, services}; +use service::sending::convert_to_outgoing_federation_event; use crate::Ruma; @@ -11,19 +12,21 @@ /// /// Retrieves events from before the sender joined the room, if the room's /// history visibility allows. 
-pub(crate) async fn get_backfill_route(body: Ruma<get_backfill::v1::Request>) -> Result<get_backfill::v1::Response> { +pub(crate) async fn get_backfill_route( + State(services): State<crate::State>, body: Ruma<get_backfill::v1::Request>, +) -> Result<get_backfill::v1::Response> { let origin = body.origin.as_ref().expect("server is authenticated"); - services() + services .rooms .event_handler .acl_check(origin, &body.room_id)?; - if !services() + if !services .rooms .state_accessor .is_world_readable(&body.room_id)? - && !services() + && !services .rooms .state_cache .server_in_room(origin, &body.room_id)? @@ -34,7 +37,7 @@ pub(crate) async fn get_backfill_route(body: Ruma<get_backfill::v1::Request>) -> let until = body .v .iter() - .map(|event_id| services().rooms.timeline.get_pdu_count(event_id)) + .map(|event_id| services.rooms.timeline.get_pdu_count(event_id)) .filter_map(|r| r.ok().flatten()) .max() .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event not found."))?; @@ -45,7 +48,7 @@ pub(crate) async fn get_backfill_route(body: Ruma<get_backfill::v1::Request>) -> .try_into() .expect("UInt could not be converted to usize"); - let all_events = services() + let all_events = services .rooms .timeline .pdus_until(user_id!("@doesntmatter:conduit.rs"), &body.room_id, until)? 
@@ -55,20 +58,20 @@ pub(crate) async fn get_backfill_route(body: Ruma<get_backfill::v1::Request>) -> .filter_map(Result::ok) .filter(|(_, e)| { matches!( - services() + services .rooms .state_accessor .server_can_see_event(origin, &e.room_id, &e.event_id,), Ok(true), ) }) - .map(|(_, pdu)| services().rooms.timeline.get_pdu_json(&pdu.event_id)) + .map(|(_, pdu)| services.rooms.timeline.get_pdu_json(&pdu.event_id)) .filter_map(|r| r.ok().flatten()) .map(convert_to_outgoing_federation_event) .collect(); Ok(get_backfill::v1::Response { - origin: services().globals.server_name().to_owned(), + origin: services.globals.server_name().to_owned(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(), pdus: events, }) diff --git a/src/api/server/event.rs b/src/api/server/event.rs index f4c9d145f..e8e08c817 100644 --- a/src/api/server/event.rs +++ b/src/api/server/event.rs @@ -1,9 +1,10 @@ +use axum::extract::State; use conduit::{Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation::event::get_event}, MilliSecondsSinceUnixEpoch, RoomId, }; -use service::{sending::convert_to_outgoing_federation_event, services}; +use service::sending::convert_to_outgoing_federation_event; use crate::Ruma; @@ -13,10 +14,12 @@ /// /// - Only works if a user of this server is currently invited or joined the /// room -pub(crate) async fn get_event_route(body: Ruma<get_event::v1::Request>) -> Result<get_event::v1::Response> { +pub(crate) async fn get_event_route( + State(services): State<crate::State>, body: Ruma<get_event::v1::Request>, +) -> Result<get_event::v1::Response> { let origin = body.origin.as_ref().expect("server is authenticated"); - let event = services() + let event = services .rooms .timeline .get_pdu_json(&body.event_id)? 
@@ -30,16 +33,13 @@ pub(crate) async fn get_event_route(body: Ruma<get_event::v1::Request>) -> Resul let room_id = <&RoomId>::try_from(room_id_str).map_err(|_| Error::bad_database("Invalid room_id in event in database."))?; - if !services().rooms.state_accessor.is_world_readable(room_id)? - && !services() - .rooms - .state_cache - .server_in_room(origin, room_id)? + if !services.rooms.state_accessor.is_world_readable(room_id)? + && !services.rooms.state_cache.server_in_room(origin, room_id)? { return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); } - if !services() + if !services .rooms .state_accessor .server_can_see_event(origin, room_id, &body.event_id)? @@ -48,7 +48,7 @@ pub(crate) async fn get_event_route(body: Ruma<get_event::v1::Request>) -> Resul } Ok(get_event::v1::Response { - origin: services().globals.server_name().to_owned(), + origin: services.globals.server_name().to_owned(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(), pdu: convert_to_outgoing_federation_event(event), }) diff --git a/src/api/server/event_auth.rs b/src/api/server/event_auth.rs index bef5116b9..8d26b73a4 100644 --- a/src/api/server/event_auth.rs +++ b/src/api/server/event_auth.rs @@ -1,11 +1,12 @@ use std::sync::Arc; +use axum::extract::State; use conduit::{Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation::authorization::get_event_authorization}, RoomId, }; -use service::{sending::convert_to_outgoing_federation_event, services}; +use service::sending::convert_to_outgoing_federation_event; use crate::Ruma; @@ -15,20 +16,20 @@ /// /// - This does not include the event itself pub(crate) async fn get_event_authorization_route( - body: Ruma<get_event_authorization::v1::Request>, + State(services): State<crate::State>, body: Ruma<get_event_authorization::v1::Request>, ) -> Result<get_event_authorization::v1::Response> { let origin = body.origin.as_ref().expect("server is authenticated"); - services() + services .rooms .event_handler 
.acl_check(origin, &body.room_id)?; - if !services() + if !services .rooms .state_accessor .is_world_readable(&body.room_id)? - && !services() + && !services .rooms .state_cache .server_in_room(origin, &body.room_id)? @@ -36,7 +37,7 @@ pub(crate) async fn get_event_authorization_route( return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); } - let event = services() + let event = services .rooms .timeline .get_pdu_json(&body.event_id)? @@ -50,7 +51,7 @@ pub(crate) async fn get_event_authorization_route( let room_id = <&RoomId>::try_from(room_id_str).map_err(|_| Error::bad_database("Invalid room_id in event in database."))?; - let auth_chain_ids = services() + let auth_chain_ids = services .rooms .auth_chain .event_ids_iter(room_id, vec![Arc::from(&*body.event_id)]) @@ -58,7 +59,7 @@ pub(crate) async fn get_event_authorization_route( Ok(get_event_authorization::v1::Response { auth_chain: auth_chain_ids - .filter_map(|id| services().rooms.timeline.get_pdu_json(&id).ok()?) + .filter_map(|id| services.rooms.timeline.get_pdu_json(&id).ok()?) .map(convert_to_outgoing_federation_event) .collect(), }) diff --git a/src/api/server/get_missing_events.rs b/src/api/server/get_missing_events.rs index 5ab9abf84..378cd4fe3 100644 --- a/src/api/server/get_missing_events.rs +++ b/src/api/server/get_missing_events.rs @@ -1,9 +1,10 @@ +use axum::extract::State; use conduit::{Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation::event::get_missing_events}, OwnedEventId, RoomId, }; -use service::{sending::convert_to_outgoing_federation_event, services}; +use service::sending::convert_to_outgoing_federation_event; use crate::Ruma; @@ -11,20 +12,20 @@ /// /// Retrieves events that the sender is missing. 
pub(crate) async fn get_missing_events_route( - body: Ruma<get_missing_events::v1::Request>, + State(services): State<crate::State>, body: Ruma<get_missing_events::v1::Request>, ) -> Result<get_missing_events::v1::Response> { let origin = body.origin.as_ref().expect("server is authenticated"); - services() + services .rooms .event_handler .acl_check(origin, &body.room_id)?; - if !services() + if !services .rooms .state_accessor .is_world_readable(&body.room_id)? - && !services() + && !services .rooms .state_cache .server_in_room(origin, &body.room_id)? @@ -43,7 +44,7 @@ pub(crate) async fn get_missing_events_route( let mut i: usize = 0; while i < queued_events.len() && events.len() < limit { - if let Some(pdu) = services().rooms.timeline.get_pdu_json(&queued_events[i])? { + if let Some(pdu) = services.rooms.timeline.get_pdu_json(&queued_events[i])? { let room_id_str = pdu .get("room_id") .and_then(|val| val.as_str()) @@ -61,7 +62,7 @@ pub(crate) async fn get_missing_events_route( continue; } - if !services() + if !services .rooms .state_accessor .server_can_see_event(origin, &body.room_id, &queued_events[i])? diff --git a/src/api/server/hierarchy.rs b/src/api/server/hierarchy.rs index 84b30e80f..530ed1456 100644 --- a/src/api/server/hierarchy.rs +++ b/src/api/server/hierarchy.rs @@ -1,16 +1,19 @@ +use axum::extract::State; use ruma::api::{client::error::ErrorKind, federation::space::get_hierarchy}; -use crate::{services, Error, Result, Ruma}; +use crate::{Error, Result, Ruma}; /// # `GET /_matrix/federation/v1/hierarchy/{roomId}` /// /// Gets the space tree in a depth-first manner to locate child rooms of a given /// space. 
-pub(crate) async fn get_hierarchy_route(body: Ruma<get_hierarchy::v1::Request>) -> Result<get_hierarchy::v1::Response> { +pub(crate) async fn get_hierarchy_route( + State(services): State<crate::State>, body: Ruma<get_hierarchy::v1::Request>, +) -> Result<get_hierarchy::v1::Response> { let origin = body.origin.as_ref().expect("server is authenticated"); - if services().rooms.metadata.exists(&body.room_id)? { - services() + if services.rooms.metadata.exists(&body.room_id)? { + services .rooms .spaces .get_federation_hierarchy(&body.room_id, origin, body.suggested_only) diff --git a/src/api/server/invite.rs b/src/api/server/invite.rs index 89b900589..982a2a01e 100644 --- a/src/api/server/invite.rs +++ b/src/api/server/invite.rs @@ -1,3 +1,4 @@ +use axum::extract::State; use axum_client_ip::InsecureClientIp; use conduit::{utils, warn, Error, PduEvent, Result}; use ruma::{ @@ -6,7 +7,7 @@ serde::JsonObject, CanonicalJsonValue, EventId, OwnedUserId, }; -use service::{sending::convert_to_outgoing_federation_event, server_is_ours, services}; +use service::{sending::convert_to_outgoing_federation_event, server_is_ours}; use crate::Ruma; @@ -15,17 +16,18 @@ /// Invites a remote user to a room. 
#[tracing::instrument(skip_all, fields(%client), name = "invite")] pub(crate) async fn create_invite_route( - InsecureClientIp(client): InsecureClientIp, body: Ruma<create_invite::v2::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: Ruma<create_invite::v2::Request>, ) -> Result<create_invite::v2::Response> { let origin = body.origin.as_ref().expect("server is authenticated"); // ACL check origin - services() + services .rooms .event_handler .acl_check(origin, &body.room_id)?; - if !services() + if !services .globals .supported_room_versions() .contains(&body.room_version) @@ -39,7 +41,7 @@ pub(crate) async fn create_invite_route( } if let Some(server) = body.room_id.server_name() { - if services() + if services .globals .config .forbidden_remote_server_names @@ -52,7 +54,7 @@ pub(crate) async fn create_invite_route( } } - if services() + if services .globals .config .forbidden_remote_server_names @@ -94,14 +96,14 @@ pub(crate) async fn create_invite_route( } // Make sure we're not ACL'ed from their room. - services() + services .rooms .event_handler .acl_check(invited_user.server_name(), &body.room_id)?; ruma::signatures::hash_and_sign_event( - services().globals.server_name().as_str(), - services().globals.keypair(), + services.globals.server_name().as_str(), + services.globals.keypair(), &mut signed_event, &body.room_version, ) @@ -127,14 +129,14 @@ pub(crate) async fn create_invite_route( ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "sender is not a user ID."))?; - if services().rooms.metadata.is_banned(&body.room_id)? && !services().users.is_admin(&invited_user)? { + if services.rooms.metadata.is_banned(&body.room_id)? && !services.users.is_admin(&invited_user)? { return Err(Error::BadRequest( ErrorKind::forbidden(), "This room is banned on this homeserver.", )); } - if services().globals.block_non_admin_invites() && !services().users.is_admin(&invited_user)? 
{ + if services.globals.block_non_admin_invites() && !services.users.is_admin(&invited_user)? { return Err(Error::BadRequest( ErrorKind::forbidden(), "This server does not allow room invites.", @@ -155,12 +157,12 @@ pub(crate) async fn create_invite_route( // If we are active in the room, the remote server will notify us about the join // via /send - if !services() + if !services .rooms .state_cache - .server_in_room(services().globals.server_name(), &body.room_id)? + .server_in_room(services.globals.server_name(), &body.room_id)? { - services().rooms.state_cache.update_membership( + services.rooms.state_cache.update_membership( &body.room_id, &invited_user, RoomMemberEventContent::new(MembershipState::Invite), diff --git a/src/api/server/key.rs b/src/api/server/key.rs index 27d866397..686e44242 100644 --- a/src/api/server/key.rs +++ b/src/api/server/key.rs @@ -3,7 +3,7 @@ time::{Duration, SystemTime}, }; -use axum::{response::IntoResponse, Json}; +use axum::{extract::State, response::IntoResponse, Json}; use ruma::{ api::{ federation::discovery::{get_server_keys, ServerSigningKeys, VerifyKey}, @@ -13,7 +13,7 @@ MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, }; -use crate::{services, Result}; +use crate::Result; /// # `GET /_matrix/key/v2/server` /// @@ -23,20 +23,20 @@ /// this will be valid forever. 
// Response type for this endpoint is Json because we need to calculate a // signature for the response -pub(crate) async fn get_server_keys_route() -> Result<impl IntoResponse> { +pub(crate) async fn get_server_keys_route(State(services): State<crate::State>) -> Result<impl IntoResponse> { let verify_keys: BTreeMap<OwnedServerSigningKeyId, VerifyKey> = BTreeMap::from([( - format!("ed25519:{}", services().globals.keypair().version()) + format!("ed25519:{}", services.globals.keypair().version()) .try_into() .expect("found invalid server signing keys in DB"), VerifyKey { - key: Base64::new(services().globals.keypair().public_key().to_vec()), + key: Base64::new(services.globals.keypair().public_key().to_vec()), }, )]); let mut response = serde_json::from_slice( get_server_keys::v2::Response { server_key: Raw::new(&ServerSigningKeys { - server_name: services().globals.server_name().to_owned(), + server_name: services.globals.server_name().to_owned(), verify_keys, old_verify_keys: BTreeMap::new(), signatures: BTreeMap::new(), @@ -56,8 +56,8 @@ pub(crate) async fn get_server_keys_route() -> Result<impl IntoResponse> { .unwrap(); ruma::signatures::sign_json( - services().globals.server_name().as_str(), - services().globals.keypair(), + services.globals.server_name().as_str(), + services.globals.keypair(), &mut response, ) .unwrap(); @@ -71,4 +71,6 @@ pub(crate) async fn get_server_keys_route() -> Result<impl IntoResponse> { /// /// - Matrix does not support invalidating public keys, so the key returned by /// this will be valid forever. 
-pub(crate) async fn get_server_keys_deprecated_route() -> impl IntoResponse { get_server_keys_route().await } +pub(crate) async fn get_server_keys_deprecated_route(State(services): State<crate::State>) -> impl IntoResponse { + get_server_keys_route(State(services)).await +} diff --git a/src/api/server/make_join.rs b/src/api/server/make_join.rs index b5dadf7fd..e1beaa33c 100644 --- a/src/api/server/make_join.rs +++ b/src/api/server/make_join.rs @@ -1,3 +1,4 @@ +use axum::extract::State; use ruma::{ api::{client::error::ErrorKind, federation::membership::prepare_join_event}, events::{ @@ -12,15 +13,18 @@ use serde_json::value::to_raw_value; use tracing::warn; -use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma}; +use crate::{ + service::{pdu::PduBuilder, Services}, + Error, Result, Ruma, +}; /// # `GET /_matrix/federation/v1/make_join/{roomId}/{userId}` /// /// Creates a join template. pub(crate) async fn create_join_event_template_route( - body: Ruma<prepare_join_event::v1::Request>, + State(services): State<crate::State>, body: Ruma<prepare_join_event::v1::Request>, ) -> Result<prepare_join_event::v1::Response> { - if !services().rooms.metadata.exists(&body.room_id)? { + if !services.rooms.metadata.exists(&body.room_id)? 
{ return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); } @@ -33,12 +37,12 @@ pub(crate) async fn create_join_event_template_route( } // ACL check origin server - services() + services .rooms .event_handler .acl_check(origin, &body.room_id)?; - if services() + if services .globals .config .forbidden_remote_server_names @@ -56,7 +60,7 @@ pub(crate) async fn create_join_event_template_route( } if let Some(server) = body.room_id.server_name() { - if services() + if services .globals .config .forbidden_remote_server_names @@ -69,25 +73,25 @@ pub(crate) async fn create_join_event_template_route( } } - let room_version_id = services().rooms.state.get_room_version(&body.room_id)?; + let room_version_id = services.rooms.state.get_room_version(&body.room_id)?; - let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; + let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; - let join_authorized_via_users_server = if (services() + let join_authorized_via_users_server = if (services .rooms .state_cache .is_left(&body.user_id, &body.room_id) .unwrap_or(true)) - && user_can_perform_restricted_join(&body.user_id, &body.room_id, &room_version_id)? + && user_can_perform_restricted_join(services, &body.user_id, &body.room_id, &room_version_id)? 
{ - let auth_user = services() + let auth_user = services .rooms .state_cache .room_members(&body.room_id) .filter_map(Result::ok) - .filter(|user| user.server_name() == services().globals.server_name()) + .filter(|user| user.server_name() == services.globals.server_name()) .find(|user| { - services() + services .rooms .state_accessor .user_can_invite(&body.room_id, user, &body.user_id, &state_lock) @@ -106,7 +110,7 @@ pub(crate) async fn create_join_event_template_route( None }; - let room_version_id = services().rooms.state.get_room_version(&body.room_id)?; + let room_version_id = services.rooms.state.get_room_version(&body.room_id)?; if !body.ver.contains(&room_version_id) { return Err(Error::BadRequest( ErrorKind::IncompatibleRoomVersion { @@ -128,7 +132,7 @@ pub(crate) async fn create_join_event_template_route( }) .expect("member event is valid value"); - let (_pdu, mut pdu_json) = services().rooms.timeline.create_hash_and_sign_event( + let (_pdu, mut pdu_json) = services.rooms.timeline.create_hash_and_sign_event( PduBuilder { event_type: TimelineEventType::RoomMember, content, @@ -157,15 +161,14 @@ pub(crate) async fn create_join_event_template_route( /// externally, either by using the state cache or attempting to authorize the /// event. 
pub(crate) fn user_can_perform_restricted_join( - user_id: &UserId, room_id: &RoomId, room_version_id: &RoomVersionId, + services: &Services, user_id: &UserId, room_id: &RoomId, room_version_id: &RoomVersionId, ) -> Result<bool> { use RoomVersionId::*; - let join_rules_event = - services() - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; + let join_rules_event = services + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; let Some(join_rules_event_content) = join_rules_event .as_ref() @@ -198,7 +201,7 @@ pub(crate) fn user_can_perform_restricted_join( } }) .any(|m| { - services() + services .rooms .state_cache .is_joined(user_id, &m.room_id) diff --git a/src/api/server/make_leave.rs b/src/api/server/make_leave.rs index 63fc2b2eb..ae7237ad3 100644 --- a/src/api/server/make_leave.rs +++ b/src/api/server/make_leave.rs @@ -1,3 +1,4 @@ +use axum::extract::State; use conduit::{Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation::membership::prepare_leave_event}, @@ -9,15 +10,15 @@ use serde_json::value::to_raw_value; use super::make_join::maybe_strip_event_id; -use crate::{service::pdu::PduBuilder, services, Ruma}; +use crate::{service::pdu::PduBuilder, Ruma}; /// # `PUT /_matrix/federation/v1/make_leave/{roomId}/{eventId}` /// /// Creates a leave template. pub(crate) async fn create_leave_event_template_route( - body: Ruma<prepare_leave_event::v1::Request>, + State(services): State<crate::State>, body: Ruma<prepare_leave_event::v1::Request>, ) -> Result<prepare_leave_event::v1::Response> { - if !services().rooms.metadata.exists(&body.room_id)? { + if !services.rooms.metadata.exists(&body.room_id)? 
{ return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); } @@ -30,13 +31,13 @@ pub(crate) async fn create_leave_event_template_route( } // ACL check origin - services() + services .rooms .event_handler .acl_check(origin, &body.room_id)?; - let room_version_id = services().rooms.state.get_room_version(&body.room_id)?; - let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; + let room_version_id = services.rooms.state.get_room_version(&body.room_id)?; + let state_lock = services.rooms.state.mutex.lock(&body.room_id).await; let content = to_raw_value(&RoomMemberEventContent { avatar_url: None, blurhash: None, @@ -49,7 +50,7 @@ pub(crate) async fn create_leave_event_template_route( }) .expect("member event is valid value"); - let (_pdu, mut pdu_json) = services().rooms.timeline.create_hash_and_sign_event( + let (_pdu, mut pdu_json) = services.rooms.timeline.create_hash_and_sign_event( PduBuilder { event_type: TimelineEventType::RoomMember, content, diff --git a/src/api/server/openid.rs b/src/api/server/openid.rs index 322c7343e..6a1b99b75 100644 --- a/src/api/server/openid.rs +++ b/src/api/server/openid.rs @@ -1,16 +1,15 @@ +use axum::extract::State; use ruma::api::federation::openid::get_openid_userinfo; -use crate::{services, Result, Ruma}; +use crate::{Result, Ruma}; /// # `GET /_matrix/federation/v1/openid/userinfo` /// /// Get information about the user that generated the OpenID token. 
pub(crate) async fn get_openid_userinfo_route( - body: Ruma<get_openid_userinfo::v1::Request>, + State(services): State<crate::State>, body: Ruma<get_openid_userinfo::v1::Request>, ) -> Result<get_openid_userinfo::v1::Response> { Ok(get_openid_userinfo::v1::Response::new( - services() - .users - .find_from_openid_token(&body.access_token)?, + services.users.find_from_openid_token(&body.access_token)?, )) } diff --git a/src/api/server/publicrooms.rs b/src/api/server/publicrooms.rs index 13edd8d6b..1876dde17 100644 --- a/src/api/server/publicrooms.rs +++ b/src/api/server/publicrooms.rs @@ -1,3 +1,4 @@ +use axum::extract::State; use axum_client_ip::InsecureClientIp; use ruma::{ api::{ @@ -7,16 +8,17 @@ directory::Filter, }; -use crate::{services, Error, Result, Ruma}; +use crate::{Error, Result, Ruma}; /// # `POST /_matrix/federation/v1/publicRooms` /// /// Lists the public rooms on this server. #[tracing::instrument(skip_all, fields(%client), name = "publicrooms")] pub(crate) async fn get_public_rooms_filtered_route( - InsecureClientIp(client): InsecureClientIp, body: Ruma<get_public_rooms_filtered::v1::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: Ruma<get_public_rooms_filtered::v1::Request>, ) -> Result<get_public_rooms_filtered::v1::Response> { - if !services() + if !services .globals .allow_public_room_directory_over_federation() { @@ -24,6 +26,7 @@ pub(crate) async fn get_public_rooms_filtered_route( } let response = crate::client::get_public_rooms_filtered_helper( + services, None, body.limit, body.since.as_deref(), @@ -46,9 +49,10 @@ pub(crate) async fn get_public_rooms_filtered_route( /// Lists the public rooms on this server. 
#[tracing::instrument(skip_all, fields(%client), "publicrooms")] pub(crate) async fn get_public_rooms_route( - InsecureClientIp(client): InsecureClientIp, body: Ruma<get_public_rooms::v1::Request>, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, + body: Ruma<get_public_rooms::v1::Request>, ) -> Result<get_public_rooms::v1::Response> { - if !services() + if !services .globals .allow_public_room_directory_over_federation() { @@ -56,6 +60,7 @@ pub(crate) async fn get_public_rooms_route( } let response = crate::client::get_public_rooms_filtered_helper( + services, None, body.limit, body.since.as_deref(), diff --git a/src/api/server/query.rs b/src/api/server/query.rs index 89baedb33..dddf23e71 100644 --- a/src/api/server/query.rs +++ b/src/api/server/query.rs @@ -1,3 +1,4 @@ +use axum::extract::State; use get_profile_information::v1::ProfileField; use rand::seq::SliceRandom; use ruma::{ @@ -8,21 +9,21 @@ OwnedServerName, }; -use crate::{service::server_is_ours, services, Error, Result, Ruma}; +use crate::{service::server_is_ours, Error, Result, Ruma}; /// # `GET /_matrix/federation/v1/query/directory` /// /// Resolve a room alias to a room id. pub(crate) async fn get_room_information_route( - body: Ruma<get_room_information::v1::Request>, + State(services): State<crate::State>, body: Ruma<get_room_information::v1::Request>, ) -> Result<get_room_information::v1::Response> { - let room_id = services() + let room_id = services .rooms .alias .resolve_local_alias(&body.room_alias)? 
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Room alias not found."))?; - let mut servers: Vec<OwnedServerName> = services() + let mut servers: Vec<OwnedServerName> = services .rooms .state_cache .room_servers(&room_id) @@ -37,10 +38,10 @@ pub(crate) async fn get_room_information_route( // insert our server as the very first choice if in list if let Some(server_index) = servers .iter() - .position(|server| server == services().globals.server_name()) + .position(|server| server == services.globals.server_name()) { servers.swap_remove(server_index); - servers.insert(0, services().globals.server_name().to_owned()); + servers.insert(0, services.globals.server_name().to_owned()); } Ok(get_room_information::v1::Response { @@ -54,12 +55,9 @@ pub(crate) async fn get_room_information_route( /// /// Gets information on a profile. pub(crate) async fn get_profile_information_route( - body: Ruma<get_profile_information::v1::Request>, + State(services): State<crate::State>, body: Ruma<get_profile_information::v1::Request>, ) -> Result<get_profile_information::v1::Response> { - if !services() - .globals - .allow_profile_lookup_federation_requests() - { + if !services.globals.allow_profile_lookup_federation_requests() { return Err(Error::BadRequest( ErrorKind::forbidden(), "Profile lookup over federation is not allowed on this homeserver.", @@ -79,18 +77,18 @@ pub(crate) async fn get_profile_information_route( match &body.field { Some(ProfileField::DisplayName) => { - displayname = services().users.displayname(&body.user_id)?; + displayname = services.users.displayname(&body.user_id)?; }, Some(ProfileField::AvatarUrl) => { - avatar_url = services().users.avatar_url(&body.user_id)?; - blurhash = services().users.blurhash(&body.user_id)?; + avatar_url = services.users.avatar_url(&body.user_id)?; + blurhash = services.users.blurhash(&body.user_id)?; }, // TODO: what to do with custom Some(_) => {}, None => { - displayname = services().users.displayname(&body.user_id)?; - 
avatar_url = services().users.avatar_url(&body.user_id)?; - blurhash = services().users.blurhash(&body.user_id)?; + displayname = services.users.displayname(&body.user_id)?; + avatar_url = services.users.avatar_url(&body.user_id)?; + blurhash = services.users.blurhash(&body.user_id)?; }, } diff --git a/src/api/server/send.rs b/src/api/server/send.rs index 7c699e95c..f29344802 100644 --- a/src/api/server/send.rs +++ b/src/api/server/send.rs @@ -34,7 +34,7 @@ /// Push EDUs and PDUs to this server. #[tracing::instrument(skip_all, fields(%client), name = "send")] pub(crate) async fn send_transaction_message_route( - State(services): State<&Services>, InsecureClientIp(client): InsecureClientIp, + State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, body: Ruma<send_transaction_message::v1::Request>, ) -> Result<send_transaction_message::v1::Response> { let origin = body.origin.as_ref().expect("server is authenticated"); diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs index 577833d55..b72bfa035 100644 --- a/src/api/server/send_join.rs +++ b/src/api/server/send_join.rs @@ -2,6 +2,7 @@ use std::collections::BTreeMap; +use axum::extract::State; use conduit::{Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation::membership::create_join_event}, @@ -13,7 +14,7 @@ }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use service::{ - pdu::gen_event_id_canonical_json, sending::convert_to_outgoing_federation_event, services, user_is_local, + pdu::gen_event_id_canonical_json, sending::convert_to_outgoing_federation_event, user_is_local, Services, }; use tokio::sync::RwLock; use tracing::warn; @@ -22,18 +23,18 @@ /// helper method for /send_join v1 and v2 async fn create_join_event( - origin: &ServerName, room_id: &RoomId, pdu: &RawJsonValue, + services: &Services, origin: &ServerName, room_id: &RoomId, pdu: &RawJsonValue, ) -> Result<create_join_event::v1::RoomState> { - if 
!services().rooms.metadata.exists(room_id)? { + if !services.rooms.metadata.exists(room_id)? { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); } // ACL check origin server - services().rooms.event_handler.acl_check(origin, room_id)?; + services.rooms.event_handler.acl_check(origin, room_id)?; // We need to return the state prior to joining, let's keep a reference to that // here - let shortstatehash = services() + let shortstatehash = services .rooms .state .get_room_shortstatehash(room_id)? @@ -44,7 +45,7 @@ async fn create_join_event( // We do not add the event_id field to the pdu here because of signature and // hashes checks - let room_version_id = services().rooms.state.get_room_version(room_id)?; + let room_version_id = services.rooms.state.get_room_version(room_id)?; let Ok((event_id, mut value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { // Event could not be converted to canonical json @@ -96,7 +97,7 @@ async fn create_join_event( ) .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "sender is not a valid user ID."))?; - services() + services .rooms .event_handler .acl_check(sender.server_name(), room_id)?; @@ -128,18 +129,18 @@ async fn create_join_event( if content .join_authorized_via_users_server .is_some_and(|user| user_is_local(&user)) - && super::user_can_perform_restricted_join(&sender, room_id, &room_version_id).unwrap_or_default() + && super::user_can_perform_restricted_join(services, &sender, room_id, &room_version_id).unwrap_or_default() { ruma::signatures::hash_and_sign_event( - services().globals.server_name().as_str(), - services().globals.keypair(), + services.globals.server_name().as_str(), + services.globals.keypair(), &mut value, &room_version_id, ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Failed to sign event."))?; } - services() + services .rooms .event_handler .fetch_required_signing_keys([&value], &pub_key_map) @@ -155,13 +156,13 @@ async fn create_join_event( ) 
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "origin is not a server name."))?; - let mutex_lock = services() + let mutex_lock = services .rooms .event_handler .mutex_federation .lock(room_id) .await; - let pdu_id: Vec<u8> = services() + let pdu_id: Vec<u8> = services .rooms .event_handler .handle_incoming_pdu(&origin, room_id, &event_id, value.clone(), true, &pub_key_map) @@ -169,27 +170,27 @@ async fn create_join_event( .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Could not accept as timeline event."))?; drop(mutex_lock); - let state_ids = services() + let state_ids = services .rooms .state_accessor .state_full_ids(shortstatehash) .await?; - let auth_chain_ids = services() + let auth_chain_ids = services .rooms .auth_chain .event_ids_iter(room_id, state_ids.values().cloned().collect()) .await?; - services().sending.send_pdu_room(room_id, &pdu_id)?; + services.sending.send_pdu_room(room_id, &pdu_id)?; Ok(create_join_event::v1::RoomState { auth_chain: auth_chain_ids - .filter_map(|id| services().rooms.timeline.get_pdu_json(&id).ok().flatten()) + .filter_map(|id| services.rooms.timeline.get_pdu_json(&id).ok().flatten()) .map(convert_to_outgoing_federation_event) .collect(), state: state_ids .iter() - .filter_map(|(_, id)| services().rooms.timeline.get_pdu_json(id).ok().flatten()) + .filter_map(|(_, id)| services.rooms.timeline.get_pdu_json(id).ok().flatten()) .map(convert_to_outgoing_federation_event) .collect(), // Event field is required if the room version supports restricted join rules. @@ -204,11 +205,11 @@ async fn create_join_event( /// /// Submits a signed join event. 
pub(crate) async fn create_join_event_v1_route( - body: Ruma<create_join_event::v1::Request>, + State(services): State<crate::State>, body: Ruma<create_join_event::v1::Request>, ) -> Result<create_join_event::v1::Response> { let origin = body.origin.as_ref().expect("server is authenticated"); - if services() + if services .globals .config .forbidden_remote_server_names @@ -225,7 +226,7 @@ pub(crate) async fn create_join_event_v1_route( } if let Some(server) = body.room_id.server_name() { - if services() + if services .globals .config .forbidden_remote_server_names @@ -243,7 +244,7 @@ pub(crate) async fn create_join_event_v1_route( } } - let room_state = create_join_event(origin, &body.room_id, &body.pdu).await?; + let room_state = create_join_event(services, origin, &body.room_id, &body.pdu).await?; Ok(create_join_event::v1::Response { room_state, @@ -254,11 +255,11 @@ pub(crate) async fn create_join_event_v1_route( /// /// Submits a signed join event. pub(crate) async fn create_join_event_v2_route( - body: Ruma<create_join_event::v2::Request>, + State(services): State<crate::State>, body: Ruma<create_join_event::v2::Request>, ) -> Result<create_join_event::v2::Response> { let origin = body.origin.as_ref().expect("server is authenticated"); - if services() + if services .globals .config .forbidden_remote_server_names @@ -271,7 +272,7 @@ pub(crate) async fn create_join_event_v2_route( } if let Some(server) = body.room_id.server_name() { - if services() + if services .globals .config .forbidden_remote_server_names @@ -288,7 +289,7 @@ pub(crate) async fn create_join_event_v2_route( auth_chain, state, event, - } = create_join_event(origin, &body.room_id, &body.pdu).await?; + } = create_join_event(services, origin, &body.room_id, &body.pdu).await?; let room_state = create_join_event::v2::RoomState { members_omitted: false, auth_chain, diff --git a/src/api/server/send_leave.rs b/src/api/server/send_leave.rs index c4e17bbc3..f7484a330 100644 --- 
a/src/api/server/send_leave.rs +++ b/src/api/server/send_leave.rs @@ -2,6 +2,7 @@ use std::collections::BTreeMap; +use axum::extract::State; use ruma::{ api::{client::error::ErrorKind, federation::membership::create_leave_event}, events::{ @@ -14,19 +15,19 @@ use tokio::sync::RwLock; use crate::{ - service::{pdu::gen_event_id_canonical_json, server_is_ours}, - services, Error, Result, Ruma, + service::{pdu::gen_event_id_canonical_json, server_is_ours, Services}, + Error, Result, Ruma, }; /// # `PUT /_matrix/federation/v1/send_leave/{roomId}/{eventId}` /// /// Submits a signed leave event. pub(crate) async fn create_leave_event_v1_route( - body: Ruma<create_leave_event::v1::Request>, + State(services): State<crate::State>, body: Ruma<create_leave_event::v1::Request>, ) -> Result<create_leave_event::v1::Response> { let origin = body.origin.as_ref().expect("server is authenticated"); - create_leave_event(origin, &body.room_id, &body.pdu).await?; + create_leave_event(services, origin, &body.room_id, &body.pdu).await?; Ok(create_leave_event::v1::Response::new()) } @@ -35,28 +36,30 @@ pub(crate) async fn create_leave_event_v1_route( /// /// Submits a signed leave event. pub(crate) async fn create_leave_event_v2_route( - body: Ruma<create_leave_event::v2::Request>, + State(services): State<crate::State>, body: Ruma<create_leave_event::v2::Request>, ) -> Result<create_leave_event::v2::Response> { let origin = body.origin.as_ref().expect("server is authenticated"); - create_leave_event(origin, &body.room_id, &body.pdu).await?; + create_leave_event(services, origin, &body.room_id, &body.pdu).await?; Ok(create_leave_event::v2::Response::new()) } -async fn create_leave_event(origin: &ServerName, room_id: &RoomId, pdu: &RawJsonValue) -> Result<()> { - if !services().rooms.metadata.exists(room_id)? { +async fn create_leave_event( + services: &Services, origin: &ServerName, room_id: &RoomId, pdu: &RawJsonValue, +) -> Result<()> { + if !services.rooms.metadata.exists(room_id)? 
{ return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); } // ACL check origin - services().rooms.event_handler.acl_check(origin, room_id)?; + services.rooms.event_handler.acl_check(origin, room_id)?; let pub_key_map = RwLock::new(BTreeMap::new()); // We do not add the event_id field to the pdu here because of signature and // hashes checks - let room_version_id = services().rooms.state.get_room_version(room_id)?; + let room_version_id = services.rooms.state.get_room_version(room_id)?; let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { // Event could not be converted to canonical json return Err(Error::BadRequest( @@ -107,7 +110,7 @@ async fn create_leave_event(origin: &ServerName, room_id: &RoomId, pdu: &RawJson ) .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "User ID in sender is invalid."))?; - services() + services .rooms .event_handler .acl_check(sender.server_name(), room_id)?; @@ -145,19 +148,19 @@ async fn create_leave_event(origin: &ServerName, room_id: &RoomId, pdu: &RawJson ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "origin is not a server name."))?; - services() + services .rooms .event_handler .fetch_required_signing_keys([&value], &pub_key_map) .await?; - let mutex_lock = services() + let mutex_lock = services .rooms .event_handler .mutex_federation .lock(room_id) .await; - let pdu_id: Vec<u8> = services() + let pdu_id: Vec<u8> = services .rooms .event_handler .handle_incoming_pdu(&origin, room_id, &event_id, value, true, &pub_key_map) @@ -166,14 +169,14 @@ async fn create_leave_event(origin: &ServerName, room_id: &RoomId, pdu: &RawJson drop(mutex_lock); - let servers = services() + let servers = services .rooms .state_cache .room_servers(room_id) .filter_map(Result::ok) .filter(|server| !server_is_ours(server)); - services().sending.send_pdu_servers(servers, &pdu_id)?; + services.sending.send_pdu_servers(servers, &pdu_id)?; Ok(()) } diff --git 
a/src/api/server/state.rs b/src/api/server/state.rs index 220448401..24a11ccab 100644 --- a/src/api/server/state.rs +++ b/src/api/server/state.rs @@ -1,8 +1,9 @@ use std::sync::Arc; +use axum::extract::State; use conduit::{Error, Result}; use ruma::api::{client::error::ErrorKind, federation::event::get_room_state}; -use service::{sending::convert_to_outgoing_federation_event, services}; +use service::sending::convert_to_outgoing_federation_event; use crate::Ruma; @@ -10,20 +11,20 @@ /// /// Retrieves a snapshot of a room's state at a given event. pub(crate) async fn get_room_state_route( - body: Ruma<get_room_state::v1::Request>, + State(services): State<crate::State>, body: Ruma<get_room_state::v1::Request>, ) -> Result<get_room_state::v1::Response> { let origin = body.origin.as_ref().expect("server is authenticated"); - services() + services .rooms .event_handler .acl_check(origin, &body.room_id)?; - if !services() + if !services .rooms .state_accessor .is_world_readable(&body.room_id)? - && !services() + && !services .rooms .state_cache .server_in_room(origin, &body.room_id)? @@ -31,31 +32,22 @@ pub(crate) async fn get_room_state_route( return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); } - let shortstatehash = services() + let shortstatehash = services .rooms .state_accessor .pdu_shortstatehash(&body.event_id)? .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Pdu state not found."))?; - let pdus = services() + let pdus = services .rooms .state_accessor .state_full_ids(shortstatehash) .await? 
.into_values() - .map(|id| { - convert_to_outgoing_federation_event( - services() - .rooms - .timeline - .get_pdu_json(&id) - .unwrap() - .unwrap(), - ) - }) + .map(|id| convert_to_outgoing_federation_event(services.rooms.timeline.get_pdu_json(&id).unwrap().unwrap())) .collect(); - let auth_chain_ids = services() + let auth_chain_ids = services .rooms .auth_chain .event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)]) @@ -64,7 +56,7 @@ pub(crate) async fn get_room_state_route( Ok(get_room_state::v1::Response { auth_chain: auth_chain_ids .filter_map(|id| { - services() + services .rooms .timeline .get_pdu_json(&id) diff --git a/src/api/server/state_ids.rs b/src/api/server/state_ids.rs index 9fe695dc4..d22f2df4a 100644 --- a/src/api/server/state_ids.rs +++ b/src/api/server/state_ids.rs @@ -1,28 +1,29 @@ use std::sync::Arc; +use axum::extract::State; use ruma::api::{client::error::ErrorKind, federation::event::get_room_state_ids}; -use crate::{services, Error, Result, Ruma}; +use crate::{Error, Result, Ruma}; /// # `GET /_matrix/federation/v1/state_ids/{roomId}` /// /// Retrieves a snapshot of a room's state at a given event, in the form of /// event IDs. pub(crate) async fn get_room_state_ids_route( - body: Ruma<get_room_state_ids::v1::Request>, + State(services): State<crate::State>, body: Ruma<get_room_state_ids::v1::Request>, ) -> Result<get_room_state_ids::v1::Response> { let origin = body.origin.as_ref().expect("server is authenticated"); - services() + services .rooms .event_handler .acl_check(origin, &body.room_id)?; - if !services() + if !services .rooms .state_accessor .is_world_readable(&body.room_id)? - && !services() + && !services .rooms .state_cache .server_in_room(origin, &body.room_id)? 
@@ -30,13 +31,13 @@ pub(crate) async fn get_room_state_ids_route( return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); } - let shortstatehash = services() + let shortstatehash = services .rooms .state_accessor .pdu_shortstatehash(&body.event_id)? .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Pdu state not found."))?; - let pdu_ids = services() + let pdu_ids = services .rooms .state_accessor .state_full_ids(shortstatehash) @@ -45,7 +46,7 @@ pub(crate) async fn get_room_state_ids_route( .map(|id| (*id).to_owned()) .collect(); - let auth_chain_ids = services() + let auth_chain_ids = services .rooms .auth_chain .event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)]) diff --git a/src/api/server/user.rs b/src/api/server/user.rs index 5bfc070ce..949e5f380 100644 --- a/src/api/server/user.rs +++ b/src/api/server/user.rs @@ -1,3 +1,4 @@ +use axum::extract::State; use ruma::api::{ client::error::ErrorKind, federation::{ @@ -9,13 +10,15 @@ use crate::{ client::{claim_keys_helper, get_keys_helper}, service::user_is_local, - services, Error, Result, Ruma, + Error, Result, Ruma, }; /// # `GET /_matrix/federation/v1/user/devices/{userId}` /// /// Gets information on all devices of the user. -pub(crate) async fn get_devices_route(body: Ruma<get_devices::v1::Request>) -> Result<get_devices::v1::Response> { +pub(crate) async fn get_devices_route( + State(services): State<crate::State>, body: Ruma<get_devices::v1::Request>, +) -> Result<get_devices::v1::Response> { if !user_is_local(&body.user_id) { return Err(Error::BadRequest( ErrorKind::InvalidParam, @@ -27,25 +30,25 @@ pub(crate) async fn get_devices_route(body: Ruma<get_devices::v1::Request>) -> R Ok(get_devices::v1::Response { user_id: body.user_id.clone(), - stream_id: services() + stream_id: services .users .get_devicelist_version(&body.user_id)? 
.unwrap_or(0) .try_into() .expect("version will not grow that large"), - devices: services() + devices: services .users .all_devices_metadata(&body.user_id) .filter_map(Result::ok) .filter_map(|metadata| { let device_id_string = metadata.device_id.as_str().to_owned(); - let device_display_name = if services().globals.allow_device_name_federation() { + let device_display_name = if services.globals.allow_device_name_federation() { metadata.display_name } else { Some(device_id_string) }; Some(UserDevice { - keys: services() + keys: services .users .get_device_keys(&body.user_id, &metadata.device_id) .ok()??, @@ -54,10 +57,10 @@ pub(crate) async fn get_devices_route(body: Ruma<get_devices::v1::Request>) -> R }) }) .collect(), - master_key: services() + master_key: services .users .get_master_key(None, &body.user_id, &|u| u.server_name() == origin)?, - self_signing_key: services() + self_signing_key: services .users .get_self_signing_key(None, &body.user_id, &|u| u.server_name() == origin)?, }) @@ -66,7 +69,9 @@ pub(crate) async fn get_devices_route(body: Ruma<get_devices::v1::Request>) -> R /// # `POST /_matrix/federation/v1/user/keys/query` /// /// Gets devices and identity keys for the given users. 
-pub(crate) async fn get_keys_route(body: Ruma<get_keys::v1::Request>) -> Result<get_keys::v1::Response> { +pub(crate) async fn get_keys_route( + State(services): State<crate::State>, body: Ruma<get_keys::v1::Request>, +) -> Result<get_keys::v1::Response> { if body.device_keys.iter().any(|(u, _)| !user_is_local(u)) { return Err(Error::BadRequest( ErrorKind::InvalidParam, @@ -75,10 +80,11 @@ pub(crate) async fn get_keys_route(body: Ruma<get_keys::v1::Request>) -> Result< } let result = get_keys_helper( + services, None, &body.device_keys, |u| Some(u.server_name()) == body.origin.as_deref(), - services().globals.allow_device_name_federation(), + services.globals.allow_device_name_federation(), ) .await?; @@ -92,7 +98,9 @@ pub(crate) async fn get_keys_route(body: Ruma<get_keys::v1::Request>) -> Result< /// # `POST /_matrix/federation/v1/user/keys/claim` /// /// Claims one-time keys. -pub(crate) async fn claim_keys_route(body: Ruma<claim_keys::v1::Request>) -> Result<claim_keys::v1::Response> { +pub(crate) async fn claim_keys_route( + State(services): State<crate::State>, body: Ruma<claim_keys::v1::Request>, +) -> Result<claim_keys::v1::Response> { if body.one_time_keys.iter().any(|(u, _)| !user_is_local(u)) { return Err(Error::BadRequest( ErrorKind::InvalidParam, @@ -100,7 +108,7 @@ pub(crate) async fn claim_keys_route(body: Ruma<claim_keys::v1::Request>) -> Res )); } - let result = claim_keys_helper(&body.one_time_keys).await?; + let result = claim_keys_helper(services, &body.one_time_keys).await?; Ok(claim_keys::v1::Response { one_time_keys: result.one_time_keys, diff --git a/src/api/server/well_known.rs b/src/api/server/well_known.rs index 358626621..2cc8f2386 100644 --- a/src/api/server/well_known.rs +++ b/src/api/server/well_known.rs @@ -1,15 +1,16 @@ +use axum::extract::State; use ruma::api::{client::error::ErrorKind, federation::discovery::discover_homeserver}; -use crate::{services, Error, Result, Ruma}; +use crate::{Error, Result, Ruma}; /// # `GET 
/.well-known/matrix/server` /// /// Returns the .well-known URL if it is configured, otherwise returns 404. pub(crate) async fn well_known_server( - _body: Ruma<discover_homeserver::Request>, + State(services): State<crate::State>, _body: Ruma<discover_homeserver::Request>, ) -> Result<discover_homeserver::Response> { Ok(discover_homeserver::Response { - server: match services().globals.well_known_server() { + server: match services.globals.well_known_server() { Some(server_name) => server_name.to_owned(), None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")), }, diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 6f26161ee..75cae822d 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -176,6 +176,7 @@ pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> self.db.watch(user_id, device_id).await } + #[inline] pub fn server_name(&self) -> &ServerName { self.config.server_name.as_ref() } pub fn max_fetch_prev_events(&self) -> u16 { self.config.max_fetch_prev_events } -- GitLab From 2fd6f6b0ff1301544cdb3886069799ca103e3361 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Tue, 16 Jul 2024 21:33:28 +0000 Subject: [PATCH 03/47] add polymorphism to Services Signed-off-by: Jason Volk <jason@zemos.net> --- src/service/manager.rs | 2 +- src/service/service.rs | 15 ++++++++++++--- src/service/services.rs | 14 ++++++++------ 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/src/service/manager.rs b/src/service/manager.rs index af59b4a43..447cd6fe7 100644 --- a/src/service/manager.rs +++ b/src/service/manager.rs @@ -54,7 +54,7 @@ pub(super) async fn start(self: Arc<Self>) -> Result<()> { ); debug!("Starting service workers..."); - for service in self.services.service.values() { + for (service, ..) 
in self.services.service.values() { self.start_worker(&mut workers, service).await?; } diff --git a/src/service/service.rs b/src/service/service.rs index 3b8f4231d..ab41753a8 100644 --- a/src/service/service.rs +++ b/src/service/service.rs @@ -1,11 +1,11 @@ -use std::{collections::BTreeMap, fmt::Write, sync::Arc}; +use std::{any::Any, collections::BTreeMap, fmt::Write, sync::Arc}; use async_trait::async_trait; use conduit::{utils::string::split_once_infallible, Result, Server}; use database::Database; #[async_trait] -pub(crate) trait Service: Send + Sync { +pub(crate) trait Service: Any + Send + Sync { /// Implement the construction of the service instance. Services are /// generally singletons so expect this to only be called once for a /// service type. Note that it may be called again after a server reload, @@ -40,7 +40,16 @@ pub(crate) struct Args<'a> { pub(crate) _service: &'a Map, } -pub(crate) type Map = BTreeMap<String, Arc<dyn Service>>; +pub(crate) type Map = BTreeMap<String, MapVal>; +pub(crate) type MapVal = (Arc<dyn Service>, Arc<dyn Any + Send + Sync>); + +pub(crate) fn get<T: Any + Send + Sync>(map: &Map, name: &str) -> Option<Arc<T>> { + map.get(name).map(|(_, s)| { + s.clone() + .downcast::<T>() + .expect("Service must be correctly downcast.") + }) +} #[inline] pub(crate) fn make_name(module_path: &str) -> &str { split_once_infallible(module_path, "::").1 } diff --git a/src/service/services.rs b/src/service/services.rs index cc9ec2900..0aac10dc0 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -1,4 +1,4 @@ -use std::{collections::BTreeMap, fmt::Write, sync::Arc}; +use std::{any::Any, collections::BTreeMap, fmt::Write, sync::Arc}; use conduit::{debug, debug_info, info, trace, Result, Server}; use database::Database; @@ -7,7 +7,7 @@ use crate::{ account_data, admin, appservice, globals, key_backups, manager::Manager, - media, presence, pusher, rooms, sending, + media, presence, pusher, rooms, sending, service, service::{Args, 
Map, Service}, transaction_ids, uiaa, updates, users, }; @@ -44,7 +44,7 @@ macro_rules! build { db: &db, _service: &service, })?; - service.insert(built.name().to_owned(), built.clone()); + service.insert(built.name().to_owned(), (built.clone(), built.clone())); built }}; } @@ -128,7 +128,7 @@ pub async fn poll(&self) -> Result<()> { } pub async fn clear_cache(&self) { - for service in self.service.values() { + for (service, ..) in self.service.values() { service.clear_cache(); } @@ -143,7 +143,7 @@ pub async fn clear_cache(&self) { pub async fn memory_usage(&self) -> Result<String> { let mut out = String::new(); - for service in self.service.values() { + for (service, ..) in self.service.values() { service.memory_usage(&mut out)?; } @@ -163,9 +163,11 @@ pub async fn memory_usage(&self) -> Result<String> { fn interrupt(&self) { debug!("Interrupting services..."); - for (name, service) in &self.service { + for (name, (service, ..)) in &self.service { trace!("Interrupting {name}"); service.interrupt(); } } + + pub fn get<T: Any + Send + Sync>(&self, name: &str) -> Option<Arc<T>> { service::get::<T>(&self.service, name) } } -- GitLab From f465d77ad39d3d095346e30daf9d51270e3976be Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Tue, 16 Jul 2024 22:00:54 +0000 Subject: [PATCH 04/47] convert Resolver into a Service. 
Signed-off-by: Jason Volk <jason@zemos.net> --- src/admin/query/resolver.rs | 18 +- src/service/globals/client.rs | 4 +- src/service/globals/mod.rs | 16 +- src/service/globals/resolver.rs | 156 -------------- src/service/mod.rs | 1 + src/service/resolver/mod.rs | 349 ++++++++++++++++++++++++++++++++ src/service/sending/mod.rs | 2 +- src/service/sending/resolve.rs | 225 +------------------- src/service/sending/send.rs | 9 +- src/service/service.rs | 2 +- src/service/services.rs | 10 +- 11 files changed, 382 insertions(+), 410 deletions(-) delete mode 100644 src/service/globals/resolver.rs create mode 100644 src/service/resolver/mod.rs diff --git a/src/admin/query/resolver.rs b/src/admin/query/resolver.rs index 2a2554b5f..06cd8ba95 100644 --- a/src/admin/query/resolver.rs +++ b/src/admin/query/resolver.rs @@ -19,7 +19,7 @@ pub(super) async fn resolver(subcommand: Resolver) -> Result<RoomMessageEventCon } async fn destinations_cache(server_name: Option<OwnedServerName>) -> Result<RoomMessageEventContent> { - use service::sending::CachedDest; + use service::resolver::CachedDest; let mut out = String::new(); writeln!(out, "| Server Name | Destination | Hostname | Expires |")?; @@ -36,12 +36,7 @@ async fn destinations_cache(server_name: Option<OwnedServerName>) -> Result<Room writeln!(out, "| {name} | {dest} | {host} | {expire} |").expect("wrote line"); }; - let map = services() - .globals - .resolver - .destinations - .read() - .expect("locked"); + let map = services().resolver.destinations.read().expect("locked"); if let Some(server_name) = server_name.as_ref() { map.get_key_value(server_name).map(row); @@ -53,7 +48,7 @@ async fn destinations_cache(server_name: Option<OwnedServerName>) -> Result<Room } async fn overrides_cache(server_name: Option<String>) -> Result<RoomMessageEventContent> { - use service::sending::CachedOverride; + use service::resolver::CachedOverride; let mut out = String::new(); writeln!(out, "| Server Name | IP | Port | Expires |")?; @@ -70,12 
+65,7 @@ async fn overrides_cache(server_name: Option<String>) -> Result<RoomMessageEvent writeln!(out, "| {name} | {ips:?} | {port} | {expire} |").expect("wrote line"); }; - let map = services() - .globals - .resolver - .overrides - .read() - .expect("locked"); + let map = services().resolver.overrides.read().expect("locked"); if let Some(server_name) = server_name.as_ref() { map.get_key_value(server_name).map(row); diff --git a/src/service/globals/client.rs b/src/service/globals/client.rs index 1986d7d6e..d8b84dede 100644 --- a/src/service/globals/client.rs +++ b/src/service/globals/client.rs @@ -2,7 +2,7 @@ use reqwest::redirect; -use crate::{globals::resolver, Config, Result}; +use crate::{resolver, Config, Result}; pub struct Client { pub default: reqwest::Client, @@ -15,7 +15,7 @@ pub struct Client { } impl Client { - pub fn new(config: &Config, resolver: &Arc<resolver::Resolver>) -> Self { + pub fn new(config: &Config, resolver: &Arc<resolver::Service>) -> Self { Self { default: Self::base(config) .unwrap() diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 75cae822d..a2fba4d95 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -2,7 +2,6 @@ mod data; mod emerg_access; pub(super) mod migrations; -pub(crate) mod resolver; use std::{ collections::{BTreeMap, HashMap}, @@ -25,7 +24,7 @@ use tokio::sync::Mutex; use url::Url; -use crate::services; +use crate::{resolver, service, services}; pub struct Service { pub db: Data, @@ -34,7 +33,6 @@ pub struct Service { pub cidr_range_denylist: Vec<IPAddress>, keypair: Arc<ruma::signatures::Ed25519KeyPair>, jwt_decoding_key: Option<jsonwebtoken::DecodingKey>, - pub resolver: Arc<resolver::Resolver>, pub client: client::Client, pub stable_room_versions: Vec<RoomVersionId>, pub unstable_room_versions: Vec<RoomVersionId>, @@ -68,8 +66,6 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { .as_ref() .map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes())); - 
let resolver = Arc::new(resolver::Resolver::new(config)); - // Supported and stable room versions let stable_room_versions = vec![ RoomVersionId::V6, @@ -89,12 +85,14 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { cidr_range_denylist.push(cidr); } + let resolver = service::get::<resolver::Service>(args.service, "resolver") + .expect("resolver must be built prior to globals"); + let mut s = Self { db, config: config.clone(), cidr_range_denylist, keypair: Arc::new(keypair), - resolver: resolver.clone(), client: client::Client::new(config, &resolver), jwt_decoding_key, stable_room_versions, @@ -126,8 +124,6 @@ async fn worker(self: Arc<Self>) -> Result<()> { } fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { - self.resolver.memory_usage(out)?; - let bad_event_ratelimiter = self .bad_event_ratelimiter .read() @@ -146,8 +142,6 @@ fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { } fn clear_cache(&self) { - self.resolver.clear_cache(); - self.bad_event_ratelimiter .write() .expect("locked for writing") @@ -159,7 +153,7 @@ fn clear_cache(&self) { .clear(); } - fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } + fn name(&self) -> &str { service::make_name(std::module_path!()) } } impl Service { diff --git a/src/service/globals/resolver.rs b/src/service/globals/resolver.rs deleted file mode 100644 index 3002decf4..000000000 --- a/src/service/globals/resolver.rs +++ /dev/null @@ -1,156 +0,0 @@ -use std::{ - collections::HashMap, - fmt::Write, - future, iter, - net::{IpAddr, SocketAddr}, - sync::{Arc, RwLock}, - time::Duration, -}; - -use conduit::{error, Config, Result}; -use hickory_resolver::TokioAsyncResolver; -use reqwest::dns::{Addrs, Name, Resolve, Resolving}; -use ruma::OwnedServerName; - -use crate::sending::{CachedDest, CachedOverride}; - -type WellKnownMap = HashMap<OwnedServerName, CachedDest>; -type TlsNameMap = HashMap<String, CachedOverride>; - -pub struct Resolver { - pub destinations: 
Arc<RwLock<WellKnownMap>>, // actual_destination, host - pub overrides: Arc<RwLock<TlsNameMap>>, - pub(crate) resolver: Arc<TokioAsyncResolver>, - pub(crate) hooked: Arc<Hooked>, -} - -pub(crate) struct Hooked { - overrides: Arc<RwLock<TlsNameMap>>, - resolver: Arc<TokioAsyncResolver>, -} - -impl Resolver { - #[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)] - pub(super) fn new(config: &Config) -> Self { - let (sys_conf, mut opts) = hickory_resolver::system_conf::read_system_conf() - .inspect_err(|e| error!("Failed to set up hickory dns resolver with system config: {e}")) - .expect("DNS system config must be valid"); - - let mut conf = hickory_resolver::config::ResolverConfig::new(); - - if let Some(domain) = sys_conf.domain() { - conf.set_domain(domain.clone()); - } - - for sys_conf in sys_conf.search() { - conf.add_search(sys_conf.clone()); - } - - for sys_conf in sys_conf.name_servers() { - let mut ns = sys_conf.clone(); - - if config.query_over_tcp_only { - ns.protocol = hickory_resolver::config::Protocol::Tcp; - } - - ns.trust_negative_responses = !config.query_all_nameservers; - - conf.add_name_server(ns); - } - - opts.cache_size = config.dns_cache_entries as usize; - opts.negative_min_ttl = Some(Duration::from_secs(config.dns_min_ttl_nxdomain)); - opts.negative_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 30)); - opts.positive_min_ttl = Some(Duration::from_secs(config.dns_min_ttl)); - opts.positive_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 7)); - opts.timeout = Duration::from_secs(config.dns_timeout); - opts.attempts = config.dns_attempts as usize; - opts.try_tcp_on_error = config.dns_tcp_fallback; - opts.num_concurrent_reqs = 1; - opts.shuffle_dns_servers = true; - opts.rotate = true; - opts.ip_strategy = match config.ip_lookup_strategy { - 1 => hickory_resolver::config::LookupIpStrategy::Ipv4Only, - 2 => hickory_resolver::config::LookupIpStrategy::Ipv6Only, - 3 => 
hickory_resolver::config::LookupIpStrategy::Ipv4AndIpv6, - 4 => hickory_resolver::config::LookupIpStrategy::Ipv6thenIpv4, - _ => hickory_resolver::config::LookupIpStrategy::Ipv4thenIpv6, - }; - opts.authentic_data = false; - - let resolver = Arc::new(TokioAsyncResolver::tokio(conf, opts)); - let overrides = Arc::new(RwLock::new(TlsNameMap::new())); - Self { - destinations: Arc::new(RwLock::new(WellKnownMap::new())), - overrides: overrides.clone(), - resolver: resolver.clone(), - hooked: Arc::new(Hooked { - overrides, - resolver, - }), - } - } - - pub(super) fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { - let resolver_overrides_cache = self.overrides.read().expect("locked for reading").len(); - writeln!(out, "resolver_overrides_cache: {resolver_overrides_cache}")?; - - let resolver_destinations_cache = self.destinations.read().expect("locked for reading").len(); - writeln!(out, "resolver_destinations_cache: {resolver_destinations_cache}")?; - - Ok(()) - } - - pub(super) fn clear_cache(&self) { - self.overrides.write().expect("write locked").clear(); - self.destinations.write().expect("write locked").clear(); - self.resolver.clear_cache(); - } -} - -impl Resolve for Resolver { - fn resolve(&self, name: Name) -> Resolving { resolve_to_reqwest(self.resolver.clone(), name) } -} - -impl Resolve for Hooked { - fn resolve(&self, name: Name) -> Resolving { - let cached = self - .overrides - .read() - .expect("locked for reading") - .get(name.as_str()) - .filter(|cached| cached.valid()) - .cloned(); - - if let Some(cached) = cached { - cached_to_reqwest(&cached.ips, cached.port) - } else { - resolve_to_reqwest(self.resolver.clone(), name) - } - } -} - -fn cached_to_reqwest(override_name: &[IpAddr], port: u16) -> Resolving { - override_name - .first() - .map(|first_name| -> Resolving { - let saddr = SocketAddr::new(*first_name, port); - let result: Box<dyn Iterator<Item = SocketAddr> + Send> = Box::new(iter::once(saddr)); - Box::pin(future::ready(Ok(result))) - 
}) - .expect("must provide at least one override name") -} - -fn resolve_to_reqwest(resolver: Arc<TokioAsyncResolver>, name: Name) -> Resolving { - Box::pin(async move { - let results = resolver - .lookup_ip(name.as_str()) - .await? - .into_iter() - .map(|ip| SocketAddr::new(ip, 0)); - - let results: Addrs = Box::new(results); - - Ok(results) - }) -} diff --git a/src/service/mod.rs b/src/service/mod.rs index 81e0be3b5..870d865b6 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -12,6 +12,7 @@ pub mod media; pub mod presence; pub mod pusher; +pub mod resolver; pub mod rooms; pub mod sending; pub mod transaction_ids; diff --git a/src/service/resolver/mod.rs b/src/service/resolver/mod.rs new file mode 100644 index 000000000..62fd1625d --- /dev/null +++ b/src/service/resolver/mod.rs @@ -0,0 +1,349 @@ +use std::{ + collections::HashMap, + fmt, + fmt::Write, + future, iter, + net::{IpAddr, SocketAddr}, + sync::{Arc, RwLock}, + time::{Duration, SystemTime}, +}; + +use conduit::{err, trace, Result}; +use hickory_resolver::TokioAsyncResolver; +use reqwest::dns::{Addrs, Name, Resolve, Resolving}; +use ruma::{OwnedServerName, ServerName}; + +use crate::utils::rand; + +pub struct Service { + pub destinations: Arc<RwLock<WellKnownMap>>, // actual_destination, host + pub overrides: Arc<RwLock<TlsNameMap>>, + pub(crate) resolver: Arc<TokioAsyncResolver>, + pub(crate) hooked: Arc<Hooked>, +} + +pub(crate) struct Hooked { + overrides: Arc<RwLock<TlsNameMap>>, + resolver: Arc<TokioAsyncResolver>, +} + +#[derive(Clone, Debug)] +pub struct CachedDest { + pub dest: FedDest, + pub host: String, + pub expire: SystemTime, +} + +#[derive(Clone, Debug)] +pub struct CachedOverride { + pub ips: Vec<IpAddr>, + pub port: u16, + pub expire: SystemTime, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum FedDest { + Literal(SocketAddr), + Named(String, String), +} + +type WellKnownMap = HashMap<OwnedServerName, CachedDest>; +type TlsNameMap = HashMap<String, CachedOverride>; + +impl 
crate::Service for Service { + #[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)] + fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { + let config = &args.server.config; + let (sys_conf, mut opts) = hickory_resolver::system_conf::read_system_conf() + .map_err(|e| err!(error!("Failed to configure DNS resolver from system: {e}")))?; + + let mut conf = hickory_resolver::config::ResolverConfig::new(); + + if let Some(domain) = sys_conf.domain() { + conf.set_domain(domain.clone()); + } + + for sys_conf in sys_conf.search() { + conf.add_search(sys_conf.clone()); + } + + for sys_conf in sys_conf.name_servers() { + let mut ns = sys_conf.clone(); + + if config.query_over_tcp_only { + ns.protocol = hickory_resolver::config::Protocol::Tcp; + } + + ns.trust_negative_responses = !config.query_all_nameservers; + + conf.add_name_server(ns); + } + + opts.cache_size = config.dns_cache_entries as usize; + opts.negative_min_ttl = Some(Duration::from_secs(config.dns_min_ttl_nxdomain)); + opts.negative_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 30)); + opts.positive_min_ttl = Some(Duration::from_secs(config.dns_min_ttl)); + opts.positive_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 7)); + opts.timeout = Duration::from_secs(config.dns_timeout); + opts.attempts = config.dns_attempts as usize; + opts.try_tcp_on_error = config.dns_tcp_fallback; + opts.num_concurrent_reqs = 1; + opts.shuffle_dns_servers = true; + opts.rotate = true; + opts.ip_strategy = match config.ip_lookup_strategy { + 1 => hickory_resolver::config::LookupIpStrategy::Ipv4Only, + 2 => hickory_resolver::config::LookupIpStrategy::Ipv6Only, + 3 => hickory_resolver::config::LookupIpStrategy::Ipv4AndIpv6, + 4 => hickory_resolver::config::LookupIpStrategy::Ipv6thenIpv4, + _ => hickory_resolver::config::LookupIpStrategy::Ipv4thenIpv6, + }; + opts.authentic_data = false; + + let resolver = Arc::new(TokioAsyncResolver::tokio(conf, opts)); + let overrides = 
Arc::new(RwLock::new(TlsNameMap::new())); + Ok(Arc::new(Self { + destinations: Arc::new(RwLock::new(WellKnownMap::new())), + overrides: overrides.clone(), + resolver: resolver.clone(), + hooked: Arc::new(Hooked { + overrides, + resolver, + }), + })) + } + + fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { + let resolver_overrides_cache = self.overrides.read().expect("locked for reading").len(); + writeln!(out, "resolver_overrides_cache: {resolver_overrides_cache}")?; + + let resolver_destinations_cache = self.destinations.read().expect("locked for reading").len(); + writeln!(out, "resolver_destinations_cache: {resolver_destinations_cache}")?; + + Ok(()) + } + + fn clear_cache(&self) { + self.overrides.write().expect("write locked").clear(); + self.destinations.write().expect("write locked").clear(); + self.resolver.clear_cache(); + } + + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { + pub fn set_cached_destination(&self, name: OwnedServerName, dest: CachedDest) -> Option<CachedDest> { + trace!(?name, ?dest, "set cached destination"); + self.destinations + .write() + .expect("locked for writing") + .insert(name, dest) + } + + #[must_use] + pub fn get_cached_destination(&self, name: &ServerName) -> Option<CachedDest> { + self.destinations + .read() + .expect("locked for reading") + .get(name) + .cloned() + } + + pub fn set_cached_override(&self, name: String, over: CachedOverride) -> Option<CachedOverride> { + trace!(?name, ?over, "set cached override"); + self.overrides + .write() + .expect("locked for writing") + .insert(name, over) + } + + #[must_use] + pub fn get_cached_override(&self, name: &str) -> Option<CachedOverride> { + self.overrides + .read() + .expect("locked for reading") + .get(name) + .cloned() + } + + #[must_use] + pub fn has_cached_override(&self, name: &str) -> bool { + self.overrides + .read() + .expect("locked for reading") + .contains_key(name) + } +} + +impl Resolve for Service { + fn 
resolve(&self, name: Name) -> Resolving { resolve_to_reqwest(self.resolver.clone(), name) } +} + +impl Resolve for Hooked { + fn resolve(&self, name: Name) -> Resolving { + let cached = self + .overrides + .read() + .expect("locked for reading") + .get(name.as_str()) + .cloned(); + + if let Some(cached) = cached { + cached_to_reqwest(&cached.ips, cached.port) + } else { + resolve_to_reqwest(self.resolver.clone(), name) + } + } +} + +fn cached_to_reqwest(override_name: &[IpAddr], port: u16) -> Resolving { + override_name + .first() + .map(|first_name| -> Resolving { + let saddr = SocketAddr::new(*first_name, port); + let result: Box<dyn Iterator<Item = SocketAddr> + Send> = Box::new(iter::once(saddr)); + Box::pin(future::ready(Ok(result))) + }) + .expect("must provide at least one override name") +} + +fn resolve_to_reqwest(resolver: Arc<TokioAsyncResolver>, name: Name) -> Resolving { + Box::pin(async move { + let results = resolver + .lookup_ip(name.as_str()) + .await? + .into_iter() + .map(|ip| SocketAddr::new(ip, 0)); + + let results: Addrs = Box::new(results); + + Ok(results) + }) +} + +impl CachedDest { + #[inline] + #[must_use] + pub fn valid(&self) -> bool { true } + + //pub fn valid(&self) -> bool { self.expire > SystemTime::now() } + + #[must_use] + pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 18..60 * 60 * 36) } +} + +impl CachedOverride { + #[inline] + #[must_use] + pub fn valid(&self) -> bool { true } + + //pub fn valid(&self) -> bool { self.expire > SystemTime::now() } + + #[must_use] + pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 6..60 * 60 * 12) } +} + +pub(crate) fn get_ip_with_port(dest_str: &str) -> Option<FedDest> { + if let Ok(dest) = dest_str.parse::<SocketAddr>() { + Some(FedDest::Literal(dest)) + } else if let Ok(ip_addr) = dest_str.parse::<IpAddr>() { + Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448))) + } else { + None + } +} + +pub(crate) fn 
add_port_to_hostname(dest_str: &str) -> FedDest { + let (host, port) = match dest_str.find(':') { + None => (dest_str, ":8448"), + Some(pos) => dest_str.split_at(pos), + }; + + FedDest::Named(host.to_owned(), port.to_owned()) +} + +impl FedDest { + pub(crate) fn into_https_string(self) -> String { + match self { + Self::Literal(addr) => format!("https://{addr}"), + Self::Named(host, port) => format!("https://{host}{port}"), + } + } + + pub(crate) fn into_uri_string(self) -> String { + match self { + Self::Literal(addr) => addr.to_string(), + Self::Named(host, port) => format!("{host}{port}"), + } + } + + pub(crate) fn hostname(&self) -> String { + match &self { + Self::Literal(addr) => addr.ip().to_string(), + Self::Named(host, _) => host.clone(), + } + } + + #[inline] + #[allow(clippy::string_slice)] + pub(crate) fn port(&self) -> Option<u16> { + match &self { + Self::Literal(addr) => Some(addr.port()), + Self::Named(_, port) => port[1..].parse().ok(), + } + } +} + +impl fmt::Display for FedDest { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Named(host, port) => write!(f, "{host}{port}"), + Self::Literal(addr) => write!(f, "{addr}"), + } + } +} + +#[cfg(test)] +mod tests { + use super::{add_port_to_hostname, get_ip_with_port, FedDest}; + + #[test] + fn ips_get_default_ports() { + assert_eq!( + get_ip_with_port("1.1.1.1"), + Some(FedDest::Literal("1.1.1.1:8448".parse().unwrap())) + ); + assert_eq!( + get_ip_with_port("dead:beef::"), + Some(FedDest::Literal("[dead:beef::]:8448".parse().unwrap())) + ); + } + + #[test] + fn ips_keep_custom_ports() { + assert_eq!( + get_ip_with_port("1.1.1.1:1234"), + Some(FedDest::Literal("1.1.1.1:1234".parse().unwrap())) + ); + assert_eq!( + get_ip_with_port("[dead::beef]:8933"), + Some(FedDest::Literal("[dead::beef]:8933".parse().unwrap())) + ); + } + + #[test] + fn hostnames_get_default_ports() { + assert_eq!( + add_port_to_hostname("example.com"), + FedDest::Named(String::from("example.com"), 
String::from(":8448")) + ); + } + + #[test] + fn hostnames_keep_custom_ports() { + assert_eq!( + add_port_to_hostname("example.com:1337"), + FedDest::Named(String::from("example.com"), String::from(":1337")) + ); + } +} diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index 88b8b1897..fd9acd3db 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -7,7 +7,7 @@ use std::fmt::Debug; use conduit::{err, Result}; -pub use resolve::{resolve_actual_dest, CachedDest, CachedOverride, FedDest}; +pub use resolve::resolve_actual_dest; use ruma::{ api::{appservice::Registration, OutgoingRequest}, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, diff --git a/src/service/sending/resolve.rs b/src/service/sending/resolve.rs index 48dddf352..4c1e47583 100644 --- a/src/service/sending/resolve.rs +++ b/src/service/sending/resolve.rs @@ -1,40 +1,17 @@ use std::{ - fmt, fmt::Debug, net::{IpAddr, SocketAddr}, - time::SystemTime, }; -use conduit::{debug, debug_error, debug_info, debug_warn, trace, utils::rand, Err, Error, Result}; +use conduit::{debug, debug_error, debug_info, debug_warn, trace, Err, Error, Result}; use hickory_resolver::{error::ResolveError, lookup::SrvLookup}; use ipaddress::IPAddress; -use ruma::{OwnedServerName, ServerName}; - -use crate::services; - -/// Wraps either an literal IP address plus port, or a hostname plus complement -/// (colon-plus-port if it was specified). -/// -/// Note: A `FedDest::Named` might contain an IP address in string form if there -/// was no port specified to construct a `SocketAddr` with. 
-/// -/// # Examples: -/// ```rust -/// # use conduit_service::sending::FedDest; -/// # fn main() -> Result<(), std::net::AddrParseError> { -/// FedDest::Literal("198.51.100.3:8448".parse()?); -/// FedDest::Literal("[2001:db8::4:5]:443".parse()?); -/// FedDest::Named("matrix.example.org".to_owned(), String::new()); -/// FedDest::Named("matrix.example.org".to_owned(), ":8448".to_owned()); -/// FedDest::Named("198.51.100.5".to_owned(), String::new()); -/// # Ok(()) -/// # } -/// ``` -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum FedDest { - Literal(SocketAddr), - Named(String, String), -} +use ruma::ServerName; + +use crate::{ + resolver::{add_port_to_hostname, get_ip_with_port, CachedDest, CachedOverride, FedDest}, + services, +}; #[derive(Clone, Debug)] pub(crate) struct ActualDest { @@ -44,27 +21,10 @@ pub(crate) struct ActualDest { pub(crate) cached: bool, } -#[derive(Clone, Debug)] -pub struct CachedDest { - pub dest: FedDest, - pub host: String, - pub expire: SystemTime, -} - -#[derive(Clone, Debug)] -pub struct CachedOverride { - pub ips: Vec<IpAddr>, - pub port: u16, - pub expire: SystemTime, -} - #[tracing::instrument(skip_all, name = "resolve")] pub(crate) async fn get_actual_dest(server_name: &ServerName) -> Result<ActualDest> { let cached; - let cached_result = services() - .globals - .resolver - .get_cached_destination(server_name); + let cached_result = services().resolver.get_cached_destination(server_name); let CachedDest { dest, @@ -213,7 +173,7 @@ async fn actual_dest_5(dest: &ServerName, cache: bool) -> Result<FedDest> { #[tracing::instrument(skip_all, name = "well-known")] async fn request_well_known(dest: &str) -> Result<Option<String>> { trace!("Requesting well known for {dest}"); - if !services().globals.resolver.has_cached_override(dest) { + if !services().resolver.has_cached_override(dest) { query_and_cache_override(dest, dest, 8448).await?; } @@ -273,7 +233,6 @@ async fn conditional_query_and_cache_override(overname: &str, hostname: &str, 
po #[tracing::instrument(skip_all, name = "ip")] async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u16) -> Result<()> { match services() - .globals .resolver .resolver .lookup_ip(hostname.to_owned()) @@ -285,7 +244,7 @@ async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u1 debug_info!("{overname:?} overriden by {hostname:?}"); } - services().globals.resolver.set_cached_override( + services().resolver.set_cached_override( overname.to_owned(), CachedOverride { ips: override_ip.iter().collect(), @@ -314,7 +273,6 @@ async fn lookup_srv(hostname: &str) -> Result<SrvLookup, ResolveError> { debug!("querying SRV for {:?}", hostname); let hostname = hostname.trim_end_matches('.'); services() - .globals .resolver .resolver .srv_lookup(hostname.to_owned()) @@ -384,166 +342,3 @@ pub(crate) fn validate_ip(ip: &IPAddress) -> Result<()> { Ok(()) } - -fn get_ip_with_port(dest_str: &str) -> Option<FedDest> { - if let Ok(dest) = dest_str.parse::<SocketAddr>() { - Some(FedDest::Literal(dest)) - } else if let Ok(ip_addr) = dest_str.parse::<IpAddr>() { - Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448))) - } else { - None - } -} - -fn add_port_to_hostname(dest_str: &str) -> FedDest { - let (host, port) = match dest_str.find(':') { - None => (dest_str, ":8448"), - Some(pos) => dest_str.split_at(pos), - }; - - FedDest::Named(host.to_owned(), port.to_owned()) -} - -impl crate::globals::resolver::Resolver { - pub(crate) fn set_cached_destination(&self, name: OwnedServerName, dest: CachedDest) -> Option<CachedDest> { - trace!(?name, ?dest, "set cached destination"); - self.destinations - .write() - .expect("locked for writing") - .insert(name, dest) - } - - pub(crate) fn get_cached_destination(&self, name: &ServerName) -> Option<CachedDest> { - self.destinations - .read() - .expect("locked for reading") - .get(name) - .filter(|cached| cached.valid()) - .cloned() - } - - pub(crate) fn set_cached_override(&self, name: String, over: 
CachedOverride) -> Option<CachedOverride> { - trace!(?name, ?over, "set cached override"); - self.overrides - .write() - .expect("locked for writing") - .insert(name, over) - } - - pub(crate) fn has_cached_override(&self, name: &str) -> bool { - self.overrides - .read() - .expect("locked for reading") - .get(name) - .filter(|cached| cached.valid()) - .is_some() - } -} - -impl CachedDest { - #[inline] - #[must_use] - pub fn valid(&self) -> bool { true } - - //pub fn valid(&self) -> bool { self.expire > SystemTime::now() } - - #[must_use] - pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 18..60 * 60 * 36) } -} - -impl CachedOverride { - #[inline] - #[must_use] - pub fn valid(&self) -> bool { true } - - //pub fn valid(&self) -> bool { self.expire > SystemTime::now() } - - #[must_use] - pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 6..60 * 60 * 12) } -} - -impl FedDest { - fn into_https_string(self) -> String { - match self { - Self::Literal(addr) => format!("https://{addr}"), - Self::Named(host, port) => format!("https://{host}{port}"), - } - } - - fn into_uri_string(self) -> String { - match self { - Self::Literal(addr) => addr.to_string(), - Self::Named(host, port) => format!("{host}{port}"), - } - } - - fn hostname(&self) -> String { - match &self { - Self::Literal(addr) => addr.ip().to_string(), - Self::Named(host, _) => host.clone(), - } - } - - #[inline] - #[allow(clippy::string_slice)] - fn port(&self) -> Option<u16> { - match &self { - Self::Literal(addr) => Some(addr.port()), - Self::Named(_, port) => port[1..].parse().ok(), - } - } -} - -impl fmt::Display for FedDest { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Named(host, port) => write!(f, "{host}{port}"), - Self::Literal(addr) => write!(f, "{addr}"), - } - } -} - -#[cfg(test)] -mod tests { - use super::{add_port_to_hostname, get_ip_with_port, FedDest}; - - #[test] - fn ips_get_default_ports() { - 
assert_eq!( - get_ip_with_port("1.1.1.1"), - Some(FedDest::Literal("1.1.1.1:8448".parse().unwrap())) - ); - assert_eq!( - get_ip_with_port("dead:beef::"), - Some(FedDest::Literal("[dead:beef::]:8448".parse().unwrap())) - ); - } - - #[test] - fn ips_keep_custom_ports() { - assert_eq!( - get_ip_with_port("1.1.1.1:1234"), - Some(FedDest::Literal("1.1.1.1:1234".parse().unwrap())) - ); - assert_eq!( - get_ip_with_port("[dead::beef]:8933"), - Some(FedDest::Literal("[dead::beef]:8933".parse().unwrap())) - ); - } - - #[test] - fn hostnames_get_default_ports() { - assert_eq!( - add_port_to_hostname("example.com"), - FedDest::Named(String::from("example.com"), String::from(":8448")) - ); - } - - #[test] - fn hostnames_keep_custom_ports() { - assert_eq!( - add_port_to_hostname("example.com:1337"), - FedDest::Named(String::from("example.com"), String::from(":1337")) - ); - } -} diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs index 18a98828f..dc4541d6e 100644 --- a/src/service/sending/send.rs +++ b/src/service/sending/send.rs @@ -15,11 +15,8 @@ }; use tracing::{debug, trace}; -use super::{ - resolve, - resolve::{ActualDest, CachedDest}, -}; -use crate::{debug_error, debug_warn, services, Error, Result}; +use super::{resolve, resolve::ActualDest}; +use crate::{debug_error, debug_warn, resolver::CachedDest, services, Error, Result}; #[tracing::instrument(skip_all, name = "send")] pub async fn send<T>(client: &Client, dest: &ServerName, req: T) -> Result<T::IncomingResponse> @@ -109,7 +106,7 @@ async fn handle_response<T>( let response = T::IncomingResponse::try_from_http_response(http_response); if response.is_ok() && !actual.cached { - services().globals.resolver.set_cached_destination( + services().resolver.set_cached_destination( dest.to_owned(), CachedDest { dest: actual.dest.clone(), diff --git a/src/service/service.rs b/src/service/service.rs index ab41753a8..8b49d455a 100644 --- a/src/service/service.rs +++ b/src/service/service.rs @@ -37,7 +37,7 @@ 
fn memory_usage(&self, _out: &mut dyn Write) -> Result<()> { Ok(()) } pub(crate) struct Args<'a> { pub(crate) server: &'a Arc<Server>, pub(crate) db: &'a Arc<Database>, - pub(crate) _service: &'a Map, + pub(crate) service: &'a Map, } pub(crate) type Map = BTreeMap<String, MapVal>; diff --git a/src/service/services.rs b/src/service/services.rs index 0aac10dc0..fb13f24a3 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -7,12 +7,14 @@ use crate::{ account_data, admin, appservice, globals, key_backups, manager::Manager, - media, presence, pusher, rooms, sending, service, + media, presence, pusher, resolver, rooms, sending, service, service::{Args, Map, Service}, transaction_ids, uiaa, updates, users, }; pub struct Services { + pub resolver: Arc<resolver::Service>, + pub globals: Arc<globals::Service>, pub rooms: rooms::Service, pub appservice: Arc<appservice::Service>, pub pusher: Arc<pusher::Service>, @@ -26,7 +28,6 @@ pub struct Services { pub media: Arc<media::Service>, pub sending: Arc<sending::Service>, pub updates: Arc<updates::Service>, - pub globals: Arc<globals::Service>, manager: Mutex<Option<Arc<Manager>>>, pub(crate) service: Map, @@ -42,7 +43,7 @@ macro_rules! build { let built = <$tyname>::build(Args { server: &server, db: &db, - _service: &service, + service: &service, })?; service.insert(built.name().to_owned(), (built.clone(), built.clone())); built @@ -50,6 +51,8 @@ macro_rules! build { } Ok(Self { + resolver: build!(resolver::Service), + globals: build!(globals::Service), rooms: rooms::Service { alias: build!(rooms::alias::Service), auth_chain: build!(rooms::auth_chain::Service), @@ -84,7 +87,6 @@ macro_rules! 
build { media: build!(media::Service), sending: build!(sending::Service), updates: build!(updates::Service), - globals: build!(globals::Service), manager: Mutex::new(None), service, server, -- GitLab From 8a2ae401df8a34b9a2ae823aedee8857d5737b48 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Tue, 16 Jul 2024 22:29:42 +0000 Subject: [PATCH 05/47] convert Client into Service Signed-off-by: Jason Volk <jason@zemos.net> --- src/admin/federation/commands.rs | 1 - src/api/client/media.rs | 2 +- src/service/client/mod.rs | 143 ++++++++++++++++++++++++++++++ src/service/globals/mod.rs | 8 +- src/service/mod.rs | 3 +- src/service/pusher/mod.rs | 7 +- src/service/sending/appservice.rs | 1 - src/service/sending/mod.rs | 2 +- src/service/sending/resolve.rs | 1 - src/service/sending/sender.rs | 3 +- src/service/service.rs | 6 ++ src/service/services.rs | 4 +- src/service/updates/mod.rs | 1 - 13 files changed, 159 insertions(+), 23 deletions(-) create mode 100644 src/service/client/mod.rs diff --git a/src/admin/federation/commands.rs b/src/admin/federation/commands.rs index a97e7582e..331231aed 100644 --- a/src/admin/federation/commands.rs +++ b/src/admin/federation/commands.rs @@ -35,7 +35,6 @@ pub(super) async fn fetch_support_well_known( _body: Vec<&str>, server_name: Box<ServerName>, ) -> Result<RoomMessageEventContent> { let response = services() - .globals .client .default .get(format!("https://{server_name}/.well-known/matrix/support")) diff --git a/src/api/client/media.rs b/src/api/client/media.rs index 1adcefdd8..46f8152b0 100644 --- a/src/api/client/media.rs +++ b/src/api/client/media.rs @@ -621,7 +621,7 @@ async fn request_url_preview(services: &Services, url: &str) -> Result<UrlPrevie } } - let client = &services.globals.client.url_preview; + let client = &services.client.url_preview; let response = client.head(url).send().await?; if let Some(remote_addr) = response.remote_addr() { diff --git a/src/service/client/mod.rs b/src/service/client/mod.rs 
new file mode 100644 index 000000000..eae7214ea --- /dev/null +++ b/src/service/client/mod.rs @@ -0,0 +1,143 @@ +use std::{sync::Arc, time::Duration}; + +use conduit::{Config, Result}; +use reqwest::redirect; + +use crate::{resolver, service}; + +pub struct Service { + pub default: reqwest::Client, + pub url_preview: reqwest::Client, + pub well_known: reqwest::Client, + pub federation: reqwest::Client, + pub sender: reqwest::Client, + pub appservice: reqwest::Client, + pub pusher: reqwest::Client, +} + +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { + let config = &args.server.config; + let resolver = args + .get_service::<resolver::Service>("resolver") + .expect("resolver must be built prior to client"); + + Ok(Arc::new(Self { + default: base(config) + .unwrap() + .dns_resolver(resolver.clone()) + .build() + .unwrap(), + + url_preview: base(config) + .unwrap() + .dns_resolver(resolver.clone()) + .redirect(redirect::Policy::limited(3)) + .build() + .unwrap(), + + well_known: base(config) + .unwrap() + .dns_resolver(resolver.hooked.clone()) + .connect_timeout(Duration::from_secs(config.well_known_conn_timeout)) + .read_timeout(Duration::from_secs(config.well_known_timeout)) + .timeout(Duration::from_secs(config.well_known_timeout)) + .pool_max_idle_per_host(0) + .redirect(redirect::Policy::limited(4)) + .build() + .unwrap(), + + federation: base(config) + .unwrap() + .dns_resolver(resolver.hooked.clone()) + .read_timeout(Duration::from_secs(config.federation_timeout)) + .timeout(Duration::from_secs(config.federation_timeout)) + .pool_max_idle_per_host(config.federation_idle_per_host.into()) + .pool_idle_timeout(Duration::from_secs(config.federation_idle_timeout)) + .redirect(redirect::Policy::limited(3)) + .build() + .unwrap(), + + sender: base(config) + .unwrap() + .dns_resolver(resolver.hooked.clone()) + .read_timeout(Duration::from_secs(config.sender_timeout)) + .timeout(Duration::from_secs(config.sender_timeout)) + 
.pool_max_idle_per_host(1) + .pool_idle_timeout(Duration::from_secs(config.sender_idle_timeout)) + .redirect(redirect::Policy::limited(2)) + .build() + .unwrap(), + + appservice: base(config) + .unwrap() + .dns_resolver(resolver.clone()) + .connect_timeout(Duration::from_secs(5)) + .read_timeout(Duration::from_secs(config.appservice_timeout)) + .timeout(Duration::from_secs(config.appservice_timeout)) + .pool_max_idle_per_host(1) + .pool_idle_timeout(Duration::from_secs(config.appservice_idle_timeout)) + .redirect(redirect::Policy::limited(2)) + .build() + .unwrap(), + + pusher: base(config) + .unwrap() + .dns_resolver(resolver.clone()) + .pool_max_idle_per_host(1) + .pool_idle_timeout(Duration::from_secs(config.pusher_idle_timeout)) + .redirect(redirect::Policy::limited(2)) + .build() + .unwrap(), + })) + } + + fn name(&self) -> &str { service::make_name(std::module_path!()) } +} + +fn base(config: &Config) -> Result<reqwest::ClientBuilder> { + let mut builder = reqwest::Client::builder() + .hickory_dns(true) + .connect_timeout(Duration::from_secs(config.request_conn_timeout)) + .read_timeout(Duration::from_secs(config.request_timeout)) + .timeout(Duration::from_secs(config.request_total_timeout)) + .pool_idle_timeout(Duration::from_secs(config.request_idle_timeout)) + .pool_max_idle_per_host(config.request_idle_per_host.into()) + .user_agent(conduit::version::user_agent()) + .redirect(redirect::Policy::limited(6)) + .connection_verbose(true); + + #[cfg(feature = "gzip_compression")] + { + builder = if config.gzip_compression { + builder.gzip(true) + } else { + builder.gzip(false).no_gzip() + }; + }; + + #[cfg(feature = "brotli_compression")] + { + builder = if config.brotli_compression { + builder.brotli(true) + } else { + builder.brotli(false).no_brotli() + }; + }; + + #[cfg(not(feature = "gzip_compression"))] + { + builder = builder.no_gzip(); + }; + + #[cfg(not(feature = "brotli_compression"))] + { + builder = builder.no_brotli(); + }; + + if let Some(proxy) = 
config.proxy.to_proxy()? { + Ok(builder.proxy(proxy)) + } else { + Ok(builder) + } +} diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index a2fba4d95..811aff3ad 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -1,4 +1,3 @@ -mod client; mod data; mod emerg_access; pub(super) mod migrations; @@ -24,7 +23,7 @@ use tokio::sync::Mutex; use url::Url; -use crate::{resolver, service, services}; +use crate::{service, services}; pub struct Service { pub db: Data, @@ -33,7 +32,6 @@ pub struct Service { pub cidr_range_denylist: Vec<IPAddress>, keypair: Arc<ruma::signatures::Ed25519KeyPair>, jwt_decoding_key: Option<jsonwebtoken::DecodingKey>, - pub client: client::Client, pub stable_room_versions: Vec<RoomVersionId>, pub unstable_room_versions: Vec<RoomVersionId>, pub bad_event_ratelimiter: Arc<RwLock<HashMap<OwnedEventId, RateLimitState>>>, @@ -85,15 +83,11 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { cidr_range_denylist.push(cidr); } - let resolver = service::get::<resolver::Service>(args.service, "resolver") - .expect("resolver must be built prior to globals"); - let mut s = Self { db, config: config.clone(), cidr_range_denylist, keypair: Arc::new(keypair), - client: client::Client::new(config, &resolver), jwt_decoding_key, stable_room_versions, unstable_room_versions, diff --git a/src/service/mod.rs b/src/service/mod.rs index 870d865b6..46adb072d 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -7,6 +7,7 @@ pub mod account_data; pub mod admin; pub mod appservice; +pub mod client; pub mod globals; pub mod key_backups; pub mod media; @@ -25,7 +26,7 @@ use std::sync::{Arc, RwLock}; -pub(crate) use conduit::{config, debug_error, debug_warn, utils, Config, Error, Result, Server}; +pub(crate) use conduit::{config, debug_error, debug_warn, utils, Error, Result, Server}; pub use conduit::{pdu, PduBuilder, PduCount, PduEvent}; use database::Database; pub(crate) use service::{Args, Service}; diff --git 
a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index ea48ea7c9..38ea5b9af 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -84,12 +84,7 @@ pub async fn send_request<T>(&self, dest: &str, request: T) -> Result<T::Incomin } } - let response = services() - .globals - .client - .pusher - .execute(reqwest_request) - .await; + let response = services().client.pusher.execute(reqwest_request).await; match response { Ok(mut response) => { diff --git a/src/service/sending/appservice.rs b/src/service/sending/appservice.rs index 721424e17..9e060e811 100644 --- a/src/service/sending/appservice.rs +++ b/src/service/sending/appservice.rs @@ -49,7 +49,6 @@ pub(crate) async fn send_request<T>(registration: Registration, request: T) -> R let reqwest_request = reqwest::Request::try_from(http_request)?; let mut response = services() - .globals .client .appservice .execute(reqwest_request) diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index fd9acd3db..be10184db 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -194,7 +194,7 @@ pub async fn send_federation_request<T>(&self, dest: &ServerName, request: T) -> where T: OutgoingRequest + Debug + Send, { - let client = &services().globals.client.federation; + let client = &services().client.federation; send::send(client, dest, request).await } diff --git a/src/service/sending/resolve.rs b/src/service/sending/resolve.rs index 4c1e47583..f8d1c51cf 100644 --- a/src/service/sending/resolve.rs +++ b/src/service/sending/resolve.rs @@ -178,7 +178,6 @@ async fn request_well_known(dest: &str) -> Result<Option<String>> { } let response = services() - .globals .client .well_known .get(&format!("https://{dest}/.well-known/matrix/server")) diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index 2f542dfe4..e6b68e9ec 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -611,11 +611,10 @@ async fn send_events_dest_normal( 
} } - let client = &services().globals.client.sender; //debug_assert!(pdu_jsons.len() + edu_jsons.len() > 0, "sending empty // transaction"); send::send( - client, + &services().client.sender, server, send_transaction_message::v1::Request { origin: services().globals.server_name().to_owned(), diff --git a/src/service/service.rs b/src/service/service.rs index 8b49d455a..99b8723a0 100644 --- a/src/service/service.rs +++ b/src/service/service.rs @@ -43,6 +43,12 @@ pub(crate) struct Args<'a> { pub(crate) type Map = BTreeMap<String, MapVal>; pub(crate) type MapVal = (Arc<dyn Service>, Arc<dyn Any + Send + Sync>); +impl Args<'_> { + pub(crate) fn get_service<T: Any + Send + Sync>(&self, name: &str) -> Option<Arc<T>> { + get::<T>(self.service, name) + } +} + pub(crate) fn get<T: Any + Send + Sync>(map: &Map, name: &str) -> Option<Arc<T>> { map.get(name).map(|(_, s)| { s.clone() diff --git a/src/service/services.rs b/src/service/services.rs index fb13f24a3..03f3a9baa 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -5,7 +5,7 @@ use tokio::sync::Mutex; use crate::{ - account_data, admin, appservice, globals, key_backups, + account_data, admin, appservice, client, globals, key_backups, manager::Manager, media, presence, pusher, resolver, rooms, sending, service, service::{Args, Map, Service}, @@ -14,6 +14,7 @@ pub struct Services { pub resolver: Arc<resolver::Service>, + pub client: Arc<client::Service>, pub globals: Arc<globals::Service>, pub rooms: rooms::Service, pub appservice: Arc<appservice::Service>, @@ -52,6 +53,7 @@ macro_rules! 
build { Ok(Self { resolver: build!(resolver::Service), + client: build!(client::Service), globals: build!(globals::Service), rooms: rooms::Service { alias: build!(rooms::alias::Service), diff --git a/src/service/updates/mod.rs b/src/service/updates/mod.rs index 3fb680d63..db69d9b0f 100644 --- a/src/service/updates/mod.rs +++ b/src/service/updates/mod.rs @@ -64,7 +64,6 @@ impl Service { #[tracing::instrument(skip_all)] async fn handle_updates(&self) -> Result<()> { let response = services() - .globals .client .default .get(CHECK_FOR_UPDATES_URL) -- GitLab From 3ccd9ea326cefaf80946b4c334f314d3c87b0598 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Tue, 16 Jul 2024 23:38:48 +0000 Subject: [PATCH 06/47] consolidate all resolution in resolver; split units Signed-off-by: Jason Volk <jason@zemos.net> --- src/admin/debug/commands.rs | 7 +- src/admin/query/resolver.rs | 13 +- src/service/client/mod.rs | 14 +- src/service/resolver/actual.rs | 356 +++++++++++++++++++++++++++++++++ src/service/resolver/cache.rs | 114 +++++++++++ src/service/resolver/dns.rs | 129 ++++++++++++ src/service/resolver/fed.rs | 70 +++++++ src/service/resolver/mod.rs | 351 ++++---------------------------- src/service/resolver/tests.rs | 43 ++++ src/service/sending/mod.rs | 2 - src/service/sending/send.rs | 11 +- 11 files changed, 774 insertions(+), 336 deletions(-) create mode 100644 src/service/resolver/actual.rs create mode 100644 src/service/resolver/cache.rs create mode 100644 src/service/resolver/dns.rs create mode 100644 src/service/resolver/fed.rs create mode 100644 src/service/resolver/tests.rs diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index 7c7f93311..f319f5a55 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -15,7 +15,7 @@ events::room::message::RoomMessageEventContent, CanonicalJsonObject, EventId, OwnedRoomOrAliasId, RoomId, RoomVersionId, ServerName, }; -use service::{rooms::event_handler::parse_incoming_pdu, 
sending::resolve_actual_dest, services, PduEvent}; +use service::{rooms::event_handler::parse_incoming_pdu, services, PduEvent}; use tokio::sync::RwLock; use tracing_subscriber::EnvFilter; @@ -628,7 +628,10 @@ pub(super) async fn resolve_true_destination( let capture = Capture::new(state, Some(filter), capture::fmt_markdown(logs.clone())); let capture_scope = capture.start(); - let actual = resolve_actual_dest(&server_name, !no_cache).await?; + let actual = services() + .resolver + .resolve_actual_dest(&server_name, !no_cache) + .await?; drop(capture_scope); let msg = format!( diff --git a/src/admin/query/resolver.rs b/src/admin/query/resolver.rs index 06cd8ba95..37d179609 100644 --- a/src/admin/query/resolver.rs +++ b/src/admin/query/resolver.rs @@ -19,7 +19,7 @@ pub(super) async fn resolver(subcommand: Resolver) -> Result<RoomMessageEventCon } async fn destinations_cache(server_name: Option<OwnedServerName>) -> Result<RoomMessageEventContent> { - use service::resolver::CachedDest; + use service::resolver::cache::CachedDest; let mut out = String::new(); writeln!(out, "| Server Name | Destination | Hostname | Expires |")?; @@ -36,7 +36,12 @@ async fn destinations_cache(server_name: Option<OwnedServerName>) -> Result<Room writeln!(out, "| {name} | {dest} | {host} | {expire} |").expect("wrote line"); }; - let map = services().resolver.destinations.read().expect("locked"); + let map = services() + .resolver + .cache + .destinations + .read() + .expect("locked"); if let Some(server_name) = server_name.as_ref() { map.get_key_value(server_name).map(row); @@ -48,7 +53,7 @@ async fn destinations_cache(server_name: Option<OwnedServerName>) -> Result<Room } async fn overrides_cache(server_name: Option<String>) -> Result<RoomMessageEventContent> { - use service::resolver::CachedOverride; + use service::resolver::cache::CachedOverride; let mut out = String::new(); writeln!(out, "| Server Name | IP | Port | Expires |")?; @@ -65,7 +70,7 @@ async fn overrides_cache(server_name: 
Option<String>) -> Result<RoomMessageEvent writeln!(out, "| {name} | {ips:?} | {port} | {expire} |").expect("wrote line"); }; - let map = services().resolver.overrides.read().expect("locked"); + let map = services().resolver.cache.overrides.read().expect("locked"); if let Some(server_name) = server_name.as_ref() { map.get_key_value(server_name).map(row); diff --git a/src/service/client/mod.rs b/src/service/client/mod.rs index eae7214ea..cc8d52d1a 100644 --- a/src/service/client/mod.rs +++ b/src/service/client/mod.rs @@ -25,20 +25,20 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { default: base(config) .unwrap() - .dns_resolver(resolver.clone()) + .dns_resolver(resolver.resolver.clone()) .build() .unwrap(), url_preview: base(config) .unwrap() - .dns_resolver(resolver.clone()) + .dns_resolver(resolver.resolver.clone()) .redirect(redirect::Policy::limited(3)) .build() .unwrap(), well_known: base(config) .unwrap() - .dns_resolver(resolver.hooked.clone()) + .dns_resolver(resolver.resolver.hooked.clone()) .connect_timeout(Duration::from_secs(config.well_known_conn_timeout)) .read_timeout(Duration::from_secs(config.well_known_timeout)) .timeout(Duration::from_secs(config.well_known_timeout)) @@ -49,7 +49,7 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { federation: base(config) .unwrap() - .dns_resolver(resolver.hooked.clone()) + .dns_resolver(resolver.resolver.hooked.clone()) .read_timeout(Duration::from_secs(config.federation_timeout)) .timeout(Duration::from_secs(config.federation_timeout)) .pool_max_idle_per_host(config.federation_idle_per_host.into()) @@ -60,7 +60,7 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { sender: base(config) .unwrap() - .dns_resolver(resolver.hooked.clone()) + .dns_resolver(resolver.resolver.hooked.clone()) .read_timeout(Duration::from_secs(config.sender_timeout)) .timeout(Duration::from_secs(config.sender_timeout)) .pool_max_idle_per_host(1) @@ -71,7 +71,7 @@ fn build(args: crate::Args<'_>) 
-> Result<Arc<Self>> { appservice: base(config) .unwrap() - .dns_resolver(resolver.clone()) + .dns_resolver(resolver.resolver.clone()) .connect_timeout(Duration::from_secs(5)) .read_timeout(Duration::from_secs(config.appservice_timeout)) .timeout(Duration::from_secs(config.appservice_timeout)) @@ -83,7 +83,7 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { pusher: base(config) .unwrap() - .dns_resolver(resolver.clone()) + .dns_resolver(resolver.resolver.clone()) .pool_max_idle_per_host(1) .pool_idle_timeout(Duration::from_secs(config.pusher_idle_timeout)) .redirect(redirect::Policy::limited(2)) diff --git a/src/service/resolver/actual.rs b/src/service/resolver/actual.rs new file mode 100644 index 000000000..b2a00023f --- /dev/null +++ b/src/service/resolver/actual.rs @@ -0,0 +1,356 @@ +use std::{ + fmt::Debug, + net::{IpAddr, SocketAddr}, + sync::Arc, +}; + +use conduit::{debug, debug_error, debug_info, debug_warn, trace, Err, Error, Result}; +use hickory_resolver::{error::ResolveError, lookup::SrvLookup}; +use ipaddress::IPAddress; +use ruma::ServerName; + +use crate::{ + resolver::{ + cache::{CachedDest, CachedOverride}, + fed::{add_port_to_hostname, get_ip_with_port, FedDest}, + }, + services, +}; + +#[derive(Clone, Debug)] +pub(crate) struct ActualDest { + pub(crate) dest: FedDest, + pub(crate) host: String, + pub(crate) string: String, + pub(crate) cached: bool, +} + +impl super::Service { + #[tracing::instrument(skip_all, name = "resolve")] + pub(crate) async fn get_actual_dest(&self, server_name: &ServerName) -> Result<ActualDest> { + let cached; + let cached_result = self.get_cached_destination(server_name); + + let CachedDest { + dest, + host, + .. + } = if let Some(result) = cached_result { + cached = true; + result + } else { + cached = false; + validate_dest(server_name)?; + self.resolve_actual_dest(server_name, true).await? 
+ }; + + let string = dest.clone().into_https_string(); + Ok(ActualDest { + dest, + host, + string, + cached, + }) + } + + /// Returns: `actual_destination`, host header + /// Implemented according to the specification at <https://matrix.org/docs/spec/server_server/r0.1.4#resolving-server-names> + /// Numbers in comments below refer to bullet points in linked section of + /// specification + #[tracing::instrument(skip_all, name = "actual")] + pub async fn resolve_actual_dest(&self, dest: &ServerName, cache: bool) -> Result<CachedDest> { + trace!("Finding actual destination for {dest}"); + let mut host = dest.as_str().to_owned(); + let actual_dest = match get_ip_with_port(dest.as_str()) { + Some(host_port) => Self::actual_dest_1(host_port)?, + None => { + if let Some(pos) = dest.as_str().find(':') { + self.actual_dest_2(dest, cache, pos).await? + } else if let Some(delegated) = self.request_well_known(dest.as_str()).await? { + self.actual_dest_3(&mut host, cache, delegated).await? + } else if let Some(overrider) = self.query_srv_record(dest.as_str()).await? { + self.actual_dest_4(&host, cache, overrider).await? + } else { + self.actual_dest_5(dest, cache).await? 
+ } + }, + }; + + // Can't use get_ip_with_port here because we don't want to add a port + // to an IP address if it wasn't specified + let host = if let Ok(addr) = host.parse::<SocketAddr>() { + FedDest::Literal(addr) + } else if let Ok(addr) = host.parse::<IpAddr>() { + FedDest::Named(addr.to_string(), ":8448".to_owned()) + } else if let Some(pos) = host.find(':') { + let (host, port) = host.split_at(pos); + FedDest::Named(host.to_owned(), port.to_owned()) + } else { + FedDest::Named(host, ":8448".to_owned()) + }; + + debug!("Actual destination: {actual_dest:?} hostname: {host:?}"); + Ok(CachedDest { + dest: actual_dest, + host: host.into_uri_string(), + expire: CachedDest::default_expire(), + }) + } + + fn actual_dest_1(host_port: FedDest) -> Result<FedDest> { + debug!("1: IP literal with provided or default port"); + Ok(host_port) + } + + async fn actual_dest_2(&self, dest: &ServerName, cache: bool, pos: usize) -> Result<FedDest> { + debug!("2: Hostname with included port"); + let (host, port) = dest.as_str().split_at(pos); + self.conditional_query_and_cache_override(host, host, port.parse::<u16>().unwrap_or(8448), cache) + .await?; + Ok(FedDest::Named(host.to_owned(), port.to_owned())) + } + + async fn actual_dest_3(&self, host: &mut String, cache: bool, delegated: String) -> Result<FedDest> { + debug!("3: A .well-known file is available"); + *host = add_port_to_hostname(&delegated).into_uri_string(); + match get_ip_with_port(&delegated) { + Some(host_and_port) => Self::actual_dest_3_1(host_and_port), + None => { + if let Some(pos) = delegated.find(':') { + self.actual_dest_3_2(cache, delegated, pos).await + } else { + trace!("Delegated hostname has no port in this branch"); + if let Some(overrider) = self.query_srv_record(&delegated).await? 
{ + self.actual_dest_3_3(cache, delegated, overrider).await + } else { + self.actual_dest_3_4(cache, delegated).await + } + } + }, + } + } + + fn actual_dest_3_1(host_and_port: FedDest) -> Result<FedDest> { + debug!("3.1: IP literal in .well-known file"); + Ok(host_and_port) + } + + async fn actual_dest_3_2(&self, cache: bool, delegated: String, pos: usize) -> Result<FedDest> { + debug!("3.2: Hostname with port in .well-known file"); + let (host, port) = delegated.split_at(pos); + self.conditional_query_and_cache_override(host, host, port.parse::<u16>().unwrap_or(8448), cache) + .await?; + Ok(FedDest::Named(host.to_owned(), port.to_owned())) + } + + async fn actual_dest_3_3(&self, cache: bool, delegated: String, overrider: FedDest) -> Result<FedDest> { + debug!("3.3: SRV lookup successful"); + let force_port = overrider.port(); + self.conditional_query_and_cache_override(&delegated, &overrider.hostname(), force_port.unwrap_or(8448), cache) + .await?; + if let Some(port) = force_port { + Ok(FedDest::Named(delegated, format!(":{port}"))) + } else { + Ok(add_port_to_hostname(&delegated)) + } + } + + async fn actual_dest_3_4(&self, cache: bool, delegated: String) -> Result<FedDest> { + debug!("3.4: No SRV records, just use the hostname from .well-known"); + self.conditional_query_and_cache_override(&delegated, &delegated, 8448, cache) + .await?; + Ok(add_port_to_hostname(&delegated)) + } + + async fn actual_dest_4(&self, host: &str, cache: bool, overrider: FedDest) -> Result<FedDest> { + debug!("4: No .well-known; SRV record found"); + let force_port = overrider.port(); + self.conditional_query_and_cache_override(host, &overrider.hostname(), force_port.unwrap_or(8448), cache) + .await?; + if let Some(port) = force_port { + Ok(FedDest::Named(host.to_owned(), format!(":{port}"))) + } else { + Ok(add_port_to_hostname(host)) + } + } + + async fn actual_dest_5(&self, dest: &ServerName, cache: bool) -> Result<FedDest> { + debug!("5: No SRV record found"); + 
self.conditional_query_and_cache_override(dest.as_str(), dest.as_str(), 8448, cache) + .await?; + Ok(add_port_to_hostname(dest.as_str())) + } + + #[tracing::instrument(skip_all, name = "well-known")] + async fn request_well_known(&self, dest: &str) -> Result<Option<String>> { + trace!("Requesting well known for {dest}"); + if !self.has_cached_override(dest) { + self.query_and_cache_override(dest, dest, 8448).await?; + } + + let response = services() + .client + .well_known + .get(&format!("https://{dest}/.well-known/matrix/server")) + .send() + .await; + + trace!("response: {:?}", response); + if let Err(e) = &response { + debug!("error: {e:?}"); + return Ok(None); + } + + let response = response?; + if !response.status().is_success() { + debug!("response not 2XX"); + return Ok(None); + } + + let text = response.text().await?; + trace!("response text: {:?}", text); + if text.len() >= 12288 { + debug_warn!("response contains junk"); + return Ok(None); + } + + let body: serde_json::Value = serde_json::from_str(&text).unwrap_or_default(); + + let m_server = body + .get("m.server") + .unwrap_or(&serde_json::Value::Null) + .as_str() + .unwrap_or_default(); + + if ruma::identifiers_validation::server_name::validate(m_server).is_err() { + debug_error!("response content missing or invalid"); + return Ok(None); + } + + debug_info!("{:?} found at {:?}", dest, m_server); + Ok(Some(m_server.to_owned())) + } + + #[inline] + async fn conditional_query_and_cache_override( + &self, overname: &str, hostname: &str, port: u16, cache: bool, + ) -> Result<()> { + if cache { + self.query_and_cache_override(overname, hostname, port) + .await + } else { + Ok(()) + } + } + + #[tracing::instrument(skip_all, name = "ip")] + async fn query_and_cache_override(&self, overname: &'_ str, hostname: &'_ str, port: u16) -> Result<()> { + match services() + .resolver + .raw() + .lookup_ip(hostname.to_owned()) + .await + { + Err(e) => handle_resolve_error(&e), + Ok(override_ip) => { + if hostname != 
overname { + debug_info!("{overname:?} overriden by {hostname:?}"); + } + + services().resolver.set_cached_override( + overname.to_owned(), + CachedOverride { + ips: override_ip.iter().collect(), + port, + expire: CachedOverride::default_expire(), + }, + ); + + Ok(()) + }, + } + } + + #[tracing::instrument(skip_all, name = "srv")] + async fn query_srv_record(&self, hostname: &'_ str) -> Result<Option<FedDest>> { + fn handle_successful_srv(srv: &SrvLookup) -> Option<FedDest> { + srv.iter().next().map(|result| { + FedDest::Named( + result.target().to_string().trim_end_matches('.').to_owned(), + format!(":{}", result.port()), + ) + }) + } + + async fn lookup_srv( + resolver: Arc<super::TokioAsyncResolver>, hostname: &str, + ) -> Result<SrvLookup, ResolveError> { + debug!("querying SRV for {hostname:?}"); + let hostname = hostname.trim_end_matches('.'); + resolver.srv_lookup(hostname.to_owned()).await + } + + let hostnames = [format!("_matrix-fed._tcp.{hostname}."), format!("_matrix._tcp.{hostname}.")]; + + for hostname in hostnames { + match lookup_srv(self.raw(), &hostname).await { + Ok(result) => return Ok(handle_successful_srv(&result)), + Err(e) => handle_resolve_error(&e)?, + } + } + + Ok(None) + } +} + +#[allow(clippy::single_match_else)] +fn handle_resolve_error(e: &ResolveError) -> Result<()> { + use hickory_resolver::error::ResolveErrorKind; + + match *e.kind() { + ResolveErrorKind::NoRecordsFound { + .. 
+ } => { + // Raise to debug_warn if we can find out the result wasn't from cache + debug!("{e}"); + Ok(()) + }, + _ => Err!(error!("DNS {e}")), + } +} + +fn validate_dest(dest: &ServerName) -> Result<()> { + if dest == services().globals.server_name() { + return Err!("Won't send federation request to ourselves"); + } + + if dest.is_ip_literal() || IPAddress::is_valid(dest.host()) { + validate_dest_ip_literal(dest)?; + } + + Ok(()) +} + +fn validate_dest_ip_literal(dest: &ServerName) -> Result<()> { + trace!("Destination is an IP literal, checking against IP range denylist.",); + debug_assert!( + dest.is_ip_literal() || !IPAddress::is_valid(dest.host()), + "Destination is not an IP literal." + ); + let ip = IPAddress::parse(dest.host()).map_err(|e| { + debug_error!("Failed to parse IP literal from string: {}", e); + Error::BadServerResponse("Invalid IP address") + })?; + + validate_ip(&ip)?; + + Ok(()) +} + +pub(crate) fn validate_ip(ip: &IPAddress) -> Result<()> { + if !services().globals.valid_cidr_range(ip) { + return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); + } + + Ok(()) +} diff --git a/src/service/resolver/cache.rs b/src/service/resolver/cache.rs new file mode 100644 index 000000000..0fba24006 --- /dev/null +++ b/src/service/resolver/cache.rs @@ -0,0 +1,114 @@ +use std::{ + collections::HashMap, + net::IpAddr, + sync::{Arc, RwLock}, + time::SystemTime, +}; + +use conduit::trace; +use ruma::{OwnedServerName, ServerName}; + +use super::fed::FedDest; +use crate::utils::rand; + +pub struct Cache { + pub destinations: RwLock<WellKnownMap>, // actual_destination, host + pub overrides: RwLock<TlsNameMap>, +} + +#[derive(Clone, Debug)] +pub struct CachedDest { + pub dest: FedDest, + pub host: String, + pub expire: SystemTime, +} + +#[derive(Clone, Debug)] +pub struct CachedOverride { + pub ips: Vec<IpAddr>, + pub port: u16, + pub expire: SystemTime, +} + +pub type WellKnownMap = HashMap<OwnedServerName, CachedDest>; +pub type 
TlsNameMap = HashMap<String, CachedOverride>; + +impl Cache { + pub(super) fn new() -> Arc<Self> { + Arc::new(Self { + destinations: RwLock::new(WellKnownMap::new()), + overrides: RwLock::new(TlsNameMap::new()), + }) + } +} + +impl super::Service { + pub fn set_cached_destination(&self, name: OwnedServerName, dest: CachedDest) -> Option<CachedDest> { + trace!(?name, ?dest, "set cached destination"); + self.cache + .destinations + .write() + .expect("locked for writing") + .insert(name, dest) + } + + #[must_use] + pub fn get_cached_destination(&self, name: &ServerName) -> Option<CachedDest> { + self.cache + .destinations + .read() + .expect("locked for reading") + .get(name) + .cloned() + } + + pub fn set_cached_override(&self, name: String, over: CachedOverride) -> Option<CachedOverride> { + trace!(?name, ?over, "set cached override"); + self.cache + .overrides + .write() + .expect("locked for writing") + .insert(name, over) + } + + #[must_use] + pub fn get_cached_override(&self, name: &str) -> Option<CachedOverride> { + self.cache + .overrides + .read() + .expect("locked for reading") + .get(name) + .cloned() + } + + #[must_use] + pub fn has_cached_override(&self, name: &str) -> bool { + self.cache + .overrides + .read() + .expect("locked for reading") + .contains_key(name) + } +} + +impl CachedDest { + #[inline] + #[must_use] + pub fn valid(&self) -> bool { true } + + //pub fn valid(&self) -> bool { self.expire > SystemTime::now() } + + #[must_use] + pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 18..60 * 60 * 36) } +} + +impl CachedOverride { + #[inline] + #[must_use] + pub fn valid(&self) -> bool { true } + + //pub fn valid(&self) -> bool { self.expire > SystemTime::now() } + + #[must_use] + pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 6..60 * 60 * 12) } +} diff --git a/src/service/resolver/dns.rs b/src/service/resolver/dns.rs new file mode 100644 index 000000000..b77bbb84f --- /dev/null +++ 
b/src/service/resolver/dns.rs @@ -0,0 +1,129 @@ +use std::{ + future, iter, + net::{IpAddr, SocketAddr}, + sync::Arc, + time::Duration, +}; + +use conduit::{err, Result, Server}; +use hickory_resolver::TokioAsyncResolver; +use reqwest::dns::{Addrs, Name, Resolve, Resolving}; + +use super::cache::Cache; + +pub struct Resolver { + pub(crate) resolver: Arc<TokioAsyncResolver>, + pub(crate) hooked: Arc<Hooked>, +} + +pub(crate) struct Hooked { + resolver: Arc<TokioAsyncResolver>, + cache: Arc<Cache>, +} + +impl Resolver { + #[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)] + pub(super) fn build(server: &Arc<Server>, cache: Arc<Cache>) -> Result<Arc<Self>> { + let config = &server.config; + let (sys_conf, mut opts) = hickory_resolver::system_conf::read_system_conf() + .map_err(|e| err!(error!("Failed to configure DNS resolver from system: {e}")))?; + + let mut conf = hickory_resolver::config::ResolverConfig::new(); + + if let Some(domain) = sys_conf.domain() { + conf.set_domain(domain.clone()); + } + + for sys_conf in sys_conf.search() { + conf.add_search(sys_conf.clone()); + } + + for sys_conf in sys_conf.name_servers() { + let mut ns = sys_conf.clone(); + + if config.query_over_tcp_only { + ns.protocol = hickory_resolver::config::Protocol::Tcp; + } + + ns.trust_negative_responses = !config.query_all_nameservers; + + conf.add_name_server(ns); + } + + opts.cache_size = config.dns_cache_entries as usize; + opts.negative_min_ttl = Some(Duration::from_secs(config.dns_min_ttl_nxdomain)); + opts.negative_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 30)); + opts.positive_min_ttl = Some(Duration::from_secs(config.dns_min_ttl)); + opts.positive_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 7)); + opts.timeout = Duration::from_secs(config.dns_timeout); + opts.attempts = config.dns_attempts as usize; + opts.try_tcp_on_error = config.dns_tcp_fallback; + opts.num_concurrent_reqs = 1; + opts.shuffle_dns_servers = true; + opts.rotate 
= true; + opts.ip_strategy = match config.ip_lookup_strategy { + 1 => hickory_resolver::config::LookupIpStrategy::Ipv4Only, + 2 => hickory_resolver::config::LookupIpStrategy::Ipv6Only, + 3 => hickory_resolver::config::LookupIpStrategy::Ipv4AndIpv6, + 4 => hickory_resolver::config::LookupIpStrategy::Ipv6thenIpv4, + _ => hickory_resolver::config::LookupIpStrategy::Ipv4thenIpv6, + }; + opts.authentic_data = false; + + let resolver = Arc::new(TokioAsyncResolver::tokio(conf, opts)); + Ok(Arc::new(Self { + resolver: resolver.clone(), + hooked: Arc::new(Hooked { + resolver, + cache, + }), + })) + } +} + +impl Resolve for Resolver { + fn resolve(&self, name: Name) -> Resolving { resolve_to_reqwest(self.resolver.clone(), name) } +} + +impl Resolve for Hooked { + fn resolve(&self, name: Name) -> Resolving { + let cached = self + .cache + .overrides + .read() + .expect("locked for reading") + .get(name.as_str()) + .cloned(); + + if let Some(cached) = cached { + cached_to_reqwest(&cached.ips, cached.port) + } else { + resolve_to_reqwest(self.resolver.clone(), name) + } + } +} + +fn cached_to_reqwest(override_name: &[IpAddr], port: u16) -> Resolving { + override_name + .first() + .map(|first_name| -> Resolving { + let saddr = SocketAddr::new(*first_name, port); + let result: Box<dyn Iterator<Item = SocketAddr> + Send> = Box::new(iter::once(saddr)); + Box::pin(future::ready(Ok(result))) + }) + .expect("must provide at least one override name") +} + +fn resolve_to_reqwest(resolver: Arc<TokioAsyncResolver>, name: Name) -> Resolving { + Box::pin(async move { + let results = resolver + .lookup_ip(name.as_str()) + .await? 
+ .into_iter() + .map(|ip| SocketAddr::new(ip, 0)); + + let results: Addrs = Box::new(results); + + Ok(results) + }) +} diff --git a/src/service/resolver/fed.rs b/src/service/resolver/fed.rs new file mode 100644 index 000000000..10cbbbdd0 --- /dev/null +++ b/src/service/resolver/fed.rs @@ -0,0 +1,70 @@ +use std::{ + fmt, + net::{IpAddr, SocketAddr}, +}; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum FedDest { + Literal(SocketAddr), + Named(String, String), +} + +pub(crate) fn get_ip_with_port(dest_str: &str) -> Option<FedDest> { + if let Ok(dest) = dest_str.parse::<SocketAddr>() { + Some(FedDest::Literal(dest)) + } else if let Ok(ip_addr) = dest_str.parse::<IpAddr>() { + Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448))) + } else { + None + } +} + +pub(crate) fn add_port_to_hostname(dest_str: &str) -> FedDest { + let (host, port) = match dest_str.find(':') { + None => (dest_str, ":8448"), + Some(pos) => dest_str.split_at(pos), + }; + + FedDest::Named(host.to_owned(), port.to_owned()) +} + +impl FedDest { + pub(crate) fn into_https_string(self) -> String { + match self { + Self::Literal(addr) => format!("https://{addr}"), + Self::Named(host, port) => format!("https://{host}{port}"), + } + } + + pub(crate) fn into_uri_string(self) -> String { + match self { + Self::Literal(addr) => addr.to_string(), + Self::Named(host, port) => format!("{host}{port}"), + } + } + + pub(crate) fn hostname(&self) -> String { + match &self { + Self::Literal(addr) => addr.ip().to_string(), + Self::Named(host, _) => host.clone(), + } + } + + #[inline] + #[allow(clippy::string_slice)] + pub(crate) fn port(&self) -> Option<u16> { + match &self { + Self::Literal(addr) => Some(addr.port()), + Self::Named(_, port) => port[1..].parse().ok(), + } + } +} + +impl fmt::Display for FedDest { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Named(host, port) => write!(f, "{host}{port}"), + Self::Literal(addr) => write!(f, "{addr}"), + } + } +} diff --git 
a/src/service/resolver/mod.rs b/src/service/resolver/mod.rs index 62fd1625d..48ff8813d 100644 --- a/src/service/resolver/mod.rs +++ b/src/service/resolver/mod.rs @@ -1,349 +1,66 @@ -use std::{ - collections::HashMap, - fmt, - fmt::Write, - future, iter, - net::{IpAddr, SocketAddr}, - sync::{Arc, RwLock}, - time::{Duration, SystemTime}, -}; +pub mod actual; +pub mod cache; +mod dns; +pub mod fed; +mod tests; -use conduit::{err, trace, Result}; +use std::{fmt::Write, sync::Arc}; + +use conduit::Result; use hickory_resolver::TokioAsyncResolver; -use reqwest::dns::{Addrs, Name, Resolve, Resolving}; -use ruma::{OwnedServerName, ServerName}; -use crate::utils::rand; +use self::{cache::Cache, dns::Resolver}; pub struct Service { - pub destinations: Arc<RwLock<WellKnownMap>>, // actual_destination, host - pub overrides: Arc<RwLock<TlsNameMap>>, - pub(crate) resolver: Arc<TokioAsyncResolver>, - pub(crate) hooked: Arc<Hooked>, -} - -pub(crate) struct Hooked { - overrides: Arc<RwLock<TlsNameMap>>, - resolver: Arc<TokioAsyncResolver>, -} - -#[derive(Clone, Debug)] -pub struct CachedDest { - pub dest: FedDest, - pub host: String, - pub expire: SystemTime, -} - -#[derive(Clone, Debug)] -pub struct CachedOverride { - pub ips: Vec<IpAddr>, - pub port: u16, - pub expire: SystemTime, + pub cache: Arc<Cache>, + pub resolver: Arc<Resolver>, } -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum FedDest { - Literal(SocketAddr), - Named(String, String), -} - -type WellKnownMap = HashMap<OwnedServerName, CachedDest>; -type TlsNameMap = HashMap<String, CachedOverride>; - impl crate::Service for Service { #[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)] fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { - let config = &args.server.config; - let (sys_conf, mut opts) = hickory_resolver::system_conf::read_system_conf() - .map_err(|e| err!(error!("Failed to configure DNS resolver from system: {e}")))?; - - let mut conf = 
hickory_resolver::config::ResolverConfig::new(); - - if let Some(domain) = sys_conf.domain() { - conf.set_domain(domain.clone()); - } - - for sys_conf in sys_conf.search() { - conf.add_search(sys_conf.clone()); - } - - for sys_conf in sys_conf.name_servers() { - let mut ns = sys_conf.clone(); - - if config.query_over_tcp_only { - ns.protocol = hickory_resolver::config::Protocol::Tcp; - } - - ns.trust_negative_responses = !config.query_all_nameservers; - - conf.add_name_server(ns); - } - - opts.cache_size = config.dns_cache_entries as usize; - opts.negative_min_ttl = Some(Duration::from_secs(config.dns_min_ttl_nxdomain)); - opts.negative_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 30)); - opts.positive_min_ttl = Some(Duration::from_secs(config.dns_min_ttl)); - opts.positive_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 7)); - opts.timeout = Duration::from_secs(config.dns_timeout); - opts.attempts = config.dns_attempts as usize; - opts.try_tcp_on_error = config.dns_tcp_fallback; - opts.num_concurrent_reqs = 1; - opts.shuffle_dns_servers = true; - opts.rotate = true; - opts.ip_strategy = match config.ip_lookup_strategy { - 1 => hickory_resolver::config::LookupIpStrategy::Ipv4Only, - 2 => hickory_resolver::config::LookupIpStrategy::Ipv6Only, - 3 => hickory_resolver::config::LookupIpStrategy::Ipv4AndIpv6, - 4 => hickory_resolver::config::LookupIpStrategy::Ipv6thenIpv4, - _ => hickory_resolver::config::LookupIpStrategy::Ipv4thenIpv6, - }; - opts.authentic_data = false; - - let resolver = Arc::new(TokioAsyncResolver::tokio(conf, opts)); - let overrides = Arc::new(RwLock::new(TlsNameMap::new())); + let cache = Cache::new(); Ok(Arc::new(Self { - destinations: Arc::new(RwLock::new(WellKnownMap::new())), - overrides: overrides.clone(), - resolver: resolver.clone(), - hooked: Arc::new(Hooked { - overrides, - resolver, - }), + cache: cache.clone(), + resolver: Resolver::build(args.server, cache)?, })) } fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { - let 
resolver_overrides_cache = self.overrides.read().expect("locked for reading").len(); + let resolver_overrides_cache = self + .cache + .overrides + .read() + .expect("locked for reading") + .len(); writeln!(out, "resolver_overrides_cache: {resolver_overrides_cache}")?; - let resolver_destinations_cache = self.destinations.read().expect("locked for reading").len(); + let resolver_destinations_cache = self + .cache + .destinations + .read() + .expect("locked for reading") + .len(); writeln!(out, "resolver_destinations_cache: {resolver_destinations_cache}")?; Ok(()) } fn clear_cache(&self) { - self.overrides.write().expect("write locked").clear(); - self.destinations.write().expect("write locked").clear(); - self.resolver.clear_cache(); + self.cache.overrides.write().expect("write locked").clear(); + self.cache + .destinations + .write() + .expect("write locked") + .clear(); + self.resolver.resolver.clear_cache(); } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } impl Service { - pub fn set_cached_destination(&self, name: OwnedServerName, dest: CachedDest) -> Option<CachedDest> { - trace!(?name, ?dest, "set cached destination"); - self.destinations - .write() - .expect("locked for writing") - .insert(name, dest) - } - - #[must_use] - pub fn get_cached_destination(&self, name: &ServerName) -> Option<CachedDest> { - self.destinations - .read() - .expect("locked for reading") - .get(name) - .cloned() - } - - pub fn set_cached_override(&self, name: String, over: CachedOverride) -> Option<CachedOverride> { - trace!(?name, ?over, "set cached override"); - self.overrides - .write() - .expect("locked for writing") - .insert(name, over) - } - - #[must_use] - pub fn get_cached_override(&self, name: &str) -> Option<CachedOverride> { - self.overrides - .read() - .expect("locked for reading") - .get(name) - .cloned() - } - - #[must_use] - pub fn has_cached_override(&self, name: &str) -> bool { - self.overrides - .read() - .expect("locked for reading") - 
.contains_key(name) - } -} - -impl Resolve for Service { - fn resolve(&self, name: Name) -> Resolving { resolve_to_reqwest(self.resolver.clone(), name) } -} - -impl Resolve for Hooked { - fn resolve(&self, name: Name) -> Resolving { - let cached = self - .overrides - .read() - .expect("locked for reading") - .get(name.as_str()) - .cloned(); - - if let Some(cached) = cached { - cached_to_reqwest(&cached.ips, cached.port) - } else { - resolve_to_reqwest(self.resolver.clone(), name) - } - } -} - -fn cached_to_reqwest(override_name: &[IpAddr], port: u16) -> Resolving { - override_name - .first() - .map(|first_name| -> Resolving { - let saddr = SocketAddr::new(*first_name, port); - let result: Box<dyn Iterator<Item = SocketAddr> + Send> = Box::new(iter::once(saddr)); - Box::pin(future::ready(Ok(result))) - }) - .expect("must provide at least one override name") -} - -fn resolve_to_reqwest(resolver: Arc<TokioAsyncResolver>, name: Name) -> Resolving { - Box::pin(async move { - let results = resolver - .lookup_ip(name.as_str()) - .await? 
- .into_iter() - .map(|ip| SocketAddr::new(ip, 0)); - - let results: Addrs = Box::new(results); - - Ok(results) - }) -} - -impl CachedDest { #[inline] #[must_use] - pub fn valid(&self) -> bool { true } - - //pub fn valid(&self) -> bool { self.expire > SystemTime::now() } - - #[must_use] - pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 18..60 * 60 * 36) } -} - -impl CachedOverride { - #[inline] - #[must_use] - pub fn valid(&self) -> bool { true } - - //pub fn valid(&self) -> bool { self.expire > SystemTime::now() } - - #[must_use] - pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 6..60 * 60 * 12) } -} - -pub(crate) fn get_ip_with_port(dest_str: &str) -> Option<FedDest> { - if let Ok(dest) = dest_str.parse::<SocketAddr>() { - Some(FedDest::Literal(dest)) - } else if let Ok(ip_addr) = dest_str.parse::<IpAddr>() { - Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448))) - } else { - None - } -} - -pub(crate) fn add_port_to_hostname(dest_str: &str) -> FedDest { - let (host, port) = match dest_str.find(':') { - None => (dest_str, ":8448"), - Some(pos) => dest_str.split_at(pos), - }; - - FedDest::Named(host.to_owned(), port.to_owned()) -} - -impl FedDest { - pub(crate) fn into_https_string(self) -> String { - match self { - Self::Literal(addr) => format!("https://{addr}"), - Self::Named(host, port) => format!("https://{host}{port}"), - } - } - - pub(crate) fn into_uri_string(self) -> String { - match self { - Self::Literal(addr) => addr.to_string(), - Self::Named(host, port) => format!("{host}{port}"), - } - } - - pub(crate) fn hostname(&self) -> String { - match &self { - Self::Literal(addr) => addr.ip().to_string(), - Self::Named(host, _) => host.clone(), - } - } - - #[inline] - #[allow(clippy::string_slice)] - pub(crate) fn port(&self) -> Option<u16> { - match &self { - Self::Literal(addr) => Some(addr.port()), - Self::Named(_, port) => port[1..].parse().ok(), - } - } -} - -impl fmt::Display for FedDest { - 
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Named(host, port) => write!(f, "{host}{port}"), - Self::Literal(addr) => write!(f, "{addr}"), - } - } -} - -#[cfg(test)] -mod tests { - use super::{add_port_to_hostname, get_ip_with_port, FedDest}; - - #[test] - fn ips_get_default_ports() { - assert_eq!( - get_ip_with_port("1.1.1.1"), - Some(FedDest::Literal("1.1.1.1:8448".parse().unwrap())) - ); - assert_eq!( - get_ip_with_port("dead:beef::"), - Some(FedDest::Literal("[dead:beef::]:8448".parse().unwrap())) - ); - } - - #[test] - fn ips_keep_custom_ports() { - assert_eq!( - get_ip_with_port("1.1.1.1:1234"), - Some(FedDest::Literal("1.1.1.1:1234".parse().unwrap())) - ); - assert_eq!( - get_ip_with_port("[dead::beef]:8933"), - Some(FedDest::Literal("[dead::beef]:8933".parse().unwrap())) - ); - } - - #[test] - fn hostnames_get_default_ports() { - assert_eq!( - add_port_to_hostname("example.com"), - FedDest::Named(String::from("example.com"), String::from(":8448")) - ); - } - - #[test] - fn hostnames_keep_custom_ports() { - assert_eq!( - add_port_to_hostname("example.com:1337"), - FedDest::Named(String::from("example.com"), String::from(":1337")) - ); - } + pub fn raw(&self) -> Arc<TokioAsyncResolver> { self.resolver.resolver.clone() } } diff --git a/src/service/resolver/tests.rs b/src/service/resolver/tests.rs new file mode 100644 index 000000000..55cf0345d --- /dev/null +++ b/src/service/resolver/tests.rs @@ -0,0 +1,43 @@ +#![cfg(test)] + +use super::fed::{add_port_to_hostname, get_ip_with_port, FedDest}; + +#[test] +fn ips_get_default_ports() { + assert_eq!( + get_ip_with_port("1.1.1.1"), + Some(FedDest::Literal("1.1.1.1:8448".parse().unwrap())) + ); + assert_eq!( + get_ip_with_port("dead:beef::"), + Some(FedDest::Literal("[dead:beef::]:8448".parse().unwrap())) + ); +} + +#[test] +fn ips_keep_custom_ports() { + assert_eq!( + get_ip_with_port("1.1.1.1:1234"), + Some(FedDest::Literal("1.1.1.1:1234".parse().unwrap())) + ); + assert_eq!( + 
get_ip_with_port("[dead::beef]:8933"), + Some(FedDest::Literal("[dead::beef]:8933".parse().unwrap())) + ); +} + +#[test] +fn hostnames_get_default_ports() { + assert_eq!( + add_port_to_hostname("example.com"), + FedDest::Named(String::from("example.com"), String::from(":8448")) + ); +} + +#[test] +fn hostnames_keep_custom_ports() { + assert_eq!( + add_port_to_hostname("example.com:1337"), + FedDest::Named(String::from("example.com"), String::from(":1337")) + ); +} diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index be10184db..1eacca778 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -1,13 +1,11 @@ mod appservice; mod data; -mod resolve; mod send; mod sender; use std::fmt::Debug; use conduit::{err, Result}; -pub use resolve::resolve_actual_dest; use ruma::{ api::{appservice::Registration, OutgoingRequest}, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs index dc4541d6e..df3139c31 100644 --- a/src/service/sending/send.rs +++ b/src/service/sending/send.rs @@ -15,8 +15,11 @@ }; use tracing::{debug, trace}; -use super::{resolve, resolve::ActualDest}; -use crate::{debug_error, debug_warn, resolver::CachedDest, services, Error, Result}; +use crate::{ + debug_error, debug_warn, resolver, + resolver::{actual::ActualDest, cache::CachedDest}, + services, Error, Result, +}; #[tracing::instrument(skip_all, name = "send")] pub async fn send<T>(client: &Client, dest: &ServerName, req: T) -> Result<T::IncomingResponse> @@ -27,7 +30,7 @@ pub async fn send<T>(client: &Client, dest: &ServerName, req: T) -> Result<T::In return Err!(Config("allow_federation", "Federation is disabled.")); } - let actual = resolve::get_actual_dest(dest).await?; + let actual = services().resolver.get_actual_dest(dest).await?; let request = prepare::<T>(dest, &actual, req).await?; execute::<T>(client, dest, &actual, request).await } @@ -219,7 +222,7 @@ fn validate_url(url: 
&Url) -> Result<()> { if let Some(url_host) = url.host_str() { if let Ok(ip) = IPAddress::parse(url_host) { trace!("Checking request URL IP {ip:?}"); - resolve::validate_ip(&ip)?; + resolver::actual::validate_ip(&ip)?; } } -- GitLab From 29fc5b9b524db556a5b45d899ff25fb9883bf097 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Wed, 17 Jul 2024 01:00:57 +0000 Subject: [PATCH 07/47] de-global some services in services Signed-off-by: Jason Volk <jason@zemos.net> --- src/service/sending/data.rs | 4 ++-- src/service/sending/mod.rs | 17 ++++++++--------- src/service/sending/sender.rs | 26 ++++++++++++-------------- src/service/services.rs | 2 +- 4 files changed, 23 insertions(+), 26 deletions(-) diff --git a/src/service/sending/data.rs b/src/service/sending/data.rs index 657256185..9cb1c2670 100644 --- a/src/service/sending/data.rs +++ b/src/service/sending/data.rs @@ -14,7 +14,7 @@ pub struct Data { servercurrentevent_data: Arc<Map>, servernameevent_data: Arc<Map>, servername_educount: Arc<Map>, - _db: Arc<Database>, + pub(super) db: Arc<Database>, } impl Data { @@ -23,7 +23,7 @@ pub(super) fn new(db: Arc<Database>) -> Self { servercurrentevent_data: db["servercurrentevent_data"].clone(), servernameevent_data: db["servernameevent_data"].clone(), servername_educount: db["servername_educount"].clone(), - _db: db, + db, } } diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index 1eacca778..a6c3411f1 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -3,9 +3,9 @@ mod send; mod sender; -use std::fmt::Debug; +use std::{fmt::Debug, sync::Arc}; -use conduit::{err, Result}; +use conduit::{err, Result, Server}; use ruma::{ api::{appservice::Registration, OutgoingRequest}, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, @@ -18,12 +18,11 @@ pub struct Service { pub db: data::Data, + server: Arc<Server>, /// The state for a given state hash. 
sender: loole::Sender<Msg>, receiver: Mutex<loole::Receiver<Msg>>, - startup_netburst: bool, - startup_netburst_keep: i64, } #[derive(Clone, Debug, PartialEq, Eq)] @@ -53,7 +52,7 @@ impl Service { pub fn send_pdu_push(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> { let dest = Destination::Push(user.to_owned(), pushkey); let event = SendingEvent::Pdu(pdu_id.to_owned()); - let _cork = services().db.cork(); + let _cork = self.db.db.cork(); let keys = self.db.queue_requests(&[(&dest, event.clone())])?; self.dispatch(Msg { dest, @@ -66,7 +65,7 @@ pub fn send_pdu_push(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Re pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec<u8>) -> Result<()> { let dest = Destination::Appservice(appservice_id); let event = SendingEvent::Pdu(pdu_id); - let _cork = services().db.cork(); + let _cork = self.db.db.cork(); let keys = self.db.queue_requests(&[(&dest, event.clone())])?; self.dispatch(Msg { dest, @@ -93,7 +92,7 @@ pub fn send_pdu_servers<I: Iterator<Item = OwnedServerName>>(&self, servers: I, .into_iter() .map(|server| (Destination::Normal(server), SendingEvent::Pdu(pdu_id.to_owned()))) .collect::<Vec<_>>(); - let _cork = services().db.cork(); + let _cork = self.db.db.cork(); let keys = self.db.queue_requests( &requests .iter() @@ -115,7 +114,7 @@ pub fn send_pdu_servers<I: Iterator<Item = OwnedServerName>>(&self, servers: I, pub fn send_edu_server(&self, server: &ServerName, serialized: Vec<u8>) -> Result<()> { let dest = Destination::Normal(server.to_owned()); let event = SendingEvent::Edu(serialized); - let _cork = services().db.cork(); + let _cork = self.db.db.cork(); let keys = self.db.queue_requests(&[(&dest, event.clone())])?; self.dispatch(Msg { dest, @@ -142,7 +141,7 @@ pub fn send_edu_servers<I: Iterator<Item = OwnedServerName>>(&self, servers: I, .into_iter() .map(|server| (Destination::Normal(server), SendingEvent::Edu(serialized.clone()))) .collect::<Vec<_>>(); - let _cork = 
services().db.cork(); + let _cork = self.db.db.cork(); let keys = self.db.queue_requests( &requests .iter() diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index e6b68e9ec..a924ce556 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -49,14 +49,12 @@ enum TransactionStatus { #[async_trait] impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { - let config = &args.server.config; let (sender, receiver) = loole::unbounded(); Ok(Arc::new(Self { db: Data::new(args.db.clone()), + server: args.server.clone(), sender, receiver: Mutex::new(receiver), - startup_netburst: config.startup_netburst, - startup_netburst_keep: config.startup_netburst_keep, })) } @@ -119,7 +117,7 @@ fn handle_response_err( fn handle_response_ok( &self, dest: &Destination, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus, ) { - let _cork = services().db.cork(); + let _cork = self.db.db.cork(); self.db .delete_all_active_requests_for(dest) .expect("all active requests deleted"); @@ -174,11 +172,11 @@ async fn finish_responses(&self, futures: &mut SendingFutures<'_>, statuses: &mu } fn initial_requests(&self, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus) { - let keep = usize::try_from(self.startup_netburst_keep).unwrap_or(usize::MAX); + let keep = usize::try_from(self.server.config.startup_netburst_keep).unwrap_or(usize::MAX); let mut txns = HashMap::<Destination, Vec<SendingEvent>>::new(); for (key, dest, event) in self.db.active_requests().filter_map(Result::ok) { let entry = txns.entry(dest.clone()).or_default(); - if self.startup_netburst_keep >= 0 && entry.len() >= keep { + if self.server.config.startup_netburst_keep >= 0 && entry.len() >= keep { warn!("Dropping unsent event {:?} {:?}", dest, String::from_utf8_lossy(&key)); self.db .delete_active_request(&key) @@ -189,7 +187,7 @@ fn initial_requests(&self, futures: &SendingFutures<'_>, statuses: &mut CurTrans } for 
(dest, events) in txns { - if self.startup_netburst && !events.is_empty() { + if self.server.config.startup_netburst && !events.is_empty() { statuses.insert(dest.clone(), TransactionStatus::Running); futures.push(Box::pin(send_events(dest.clone(), events))); } @@ -210,7 +208,7 @@ fn select_events( return Ok(None); } - let _cork = services().db.cork(); + let _cork = self.db.db.cork(); let mut events = Vec::new(); // Must retry any previous transaction for this remote. @@ -224,7 +222,7 @@ fn select_events( } // Compose the next transaction - let _cork = services().db.cork(); + let _cork = self.db.db.cork(); if !new_events.is_empty() { self.db.mark_as_active(&new_events)?; for (e, _) in new_events { @@ -251,8 +249,8 @@ fn select_events_current(&self, dest: Destination, statuses: &mut CurTransaction .and_modify(|e| match e { TransactionStatus::Failed(tries, time) => { // Fail if a request has failed recently (exponential backoff) - let min = services().globals.config.sender_timeout; - let max = services().globals.config.sender_retry_backoff_limit; + let min = self.server.config.sender_timeout; + let max = self.server.config.sender_retry_backoff_limit; if continue_exponential_backoff_secs(min, max, time.elapsed(), *tries) { allow = false; } else { @@ -288,7 +286,7 @@ fn select_edus(&self, server_name: &ServerName) -> Result<(Vec<Vec<u8>>, u64)> { .filter(|user_id| user_is_local(user_id)), ); - if services().globals.allow_outgoing_read_receipts() + if self.server.config.allow_outgoing_read_receipts && !select_edus_receipts(&room_id, since, &mut max_edu_count, &mut events)? 
{ break; @@ -311,7 +309,7 @@ fn select_edus(&self, server_name: &ServerName) -> Result<(Vec<Vec<u8>>, u64)> { events.push(serde_json::to_vec(&edu).expect("json can be serialized")); } - if services().globals.allow_outgoing_presence() { + if self.server.config.allow_outgoing_presence { select_edus_presence(server_name, since, &mut max_edu_count, &mut events)?; } @@ -617,7 +615,7 @@ async fn send_events_dest_normal( &services().client.sender, server, send_transaction_message::v1::Request { - origin: services().globals.server_name().to_owned(), + origin: services().server.config.server_name.clone(), pdus: pdu_jsons, edus: edu_jsons, origin_server_ts: MilliSecondsSinceUnixEpoch::now(), diff --git a/src/service/services.rs b/src/service/services.rs index 03f3a9baa..d0f74e13a 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -52,9 +52,9 @@ macro_rules! build { } Ok(Self { + globals: build!(globals::Service), resolver: build!(resolver::Service), client: build!(client::Service), - globals: build!(globals::Service), rooms: rooms::Service { alias: build!(rooms::alias::Service), auth_chain: build!(rooms::auth_chain::Service), -- GitLab From b0ac5255c8a6d211df86dde8a22cccfd9189b8c0 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Wed, 17 Jul 2024 02:19:03 +0000 Subject: [PATCH 08/47] move sending service impl properly back to mod root Signed-off-by: Jason Volk <jason@zemos.net> --- src/service/sending/mod.rs | 27 +++++++++++++++++++++++++++ src/service/sending/sender.rs | 31 ++++--------------------------- 2 files changed, 31 insertions(+), 27 deletions(-) diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index a6c3411f1..26f43fd32 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -5,6 +5,7 @@ use std::{fmt::Debug, sync::Arc}; +use async_trait::async_trait; use conduit::{err, Result, Server}; use ruma::{ api::{appservice::Registration, OutgoingRequest}, @@ -47,6 +48,32 @@ pub enum SendingEvent { 
Flush, // none } +#[async_trait] +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { + let (sender, receiver) = loole::unbounded(); + Ok(Arc::new(Self { + db: data::Data::new(args.db.clone()), + server: args.server.clone(), + sender, + receiver: Mutex::new(receiver), + })) + } + + async fn worker(self: Arc<Self>) -> Result<()> { + // trait impl can't be split between files so this just glues to mod sender + self.sender().await + } + + fn interrupt(&self) { + if !self.sender.is_closed() { + self.sender.close(); + } + } + + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + impl Service { #[tracing::instrument(skip(self, pdu_id, user, pushkey), level = "debug")] pub fn send_pdu_push(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> { diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index a924ce556..df41db28b 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -2,11 +2,9 @@ cmp, collections::{BTreeMap, HashMap, HashSet}, fmt::Debug, - sync::Arc, time::{Duration, Instant}, }; -use async_trait::async_trait; use base64::{engine::general_purpose, Engine as _}; use conduit::{debug, debug_warn, error, trace, utils::math::continue_exponential_backoff_secs, warn}; use federation::transactions::send_transaction_message; @@ -24,9 +22,9 @@ ServerName, UInt, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; -use tokio::{sync::Mutex, time::sleep_until}; +use tokio::time::sleep_until; -use super::{appservice, data::Data, send, Destination, Msg, SendingEvent, Service}; +use super::{appservice, send, Destination, Msg, SendingEvent, Service}; use crate::{presence::Presence, services, user_is_local, utils::calculate_hash, Error, Result}; #[derive(Debug)] @@ -46,20 +44,9 @@ enum TransactionStatus { const SELECT_EDU_LIMIT: usize = 16; const CLEANUP_TIMEOUT_MS: u64 = 3500; -#[async_trait] -impl crate::Service for Service { - fn 
build(args: crate::Args<'_>) -> Result<Arc<Self>> { - let (sender, receiver) = loole::unbounded(); - Ok(Arc::new(Self { - db: Data::new(args.db.clone()), - server: args.server.clone(), - sender, - receiver: Mutex::new(receiver), - })) - } - +impl Service { #[tracing::instrument(skip_all, name = "sender")] - async fn worker(self: Arc<Self>) -> Result<()> { + pub(super) async fn sender(&self) -> Result<()> { let receiver = self.receiver.lock().await; let mut futures: SendingFutures<'_> = FuturesUnordered::new(); let mut statuses: CurTransactionStatus = CurTransactionStatus::new(); @@ -82,16 +69,6 @@ async fn worker(self: Arc<Self>) -> Result<()> { Ok(()) } - fn interrupt(&self) { - if !self.sender.is_closed() { - self.sender.close(); - } - } - - fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } -} - -impl Service { fn handle_response( &self, response: SendingResult, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus, ) { -- GitLab From b116984e46aa267e306da500db1b06ad7bc39ce1 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Wed, 17 Jul 2024 02:48:43 +0000 Subject: [PATCH 09/47] use mutex_map for url preview lock Signed-off-by: Jason Volk <jason@zemos.net> --- src/api/client/media.rs | 13 ++----------- src/service/media/mod.rs | 13 ++++++++----- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/src/api/client/media.rs b/src/api/client/media.rs index 46f8152b0..e219e4575 100644 --- a/src/api/client/media.rs +++ b/src/api/client/media.rs @@ -1,6 +1,6 @@ #![allow(deprecated)] -use std::{io::Cursor, sync::Arc, time::Duration}; +use std::{io::Cursor, time::Duration}; use axum::extract::State; use axum_client_ip::InsecureClientIp; @@ -656,16 +656,7 @@ async fn get_url_preview(services: &Services, url: &str) -> Result<UrlPreviewDat } // ensure that only one request is made per URL - let mutex_request = Arc::clone( - services - .media - .url_preview_mutex - .write() - .await - .entry(url.to_owned()) - 
.or_default(), - ); - let _request_lock = mutex_request.lock().await; + let _request_lock = services.media.url_preview_mutex.lock(url).await; match services.media.get_url_preview(url).await { Some(preview) => Ok(preview), diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index 1638235bd..4a7d38a48 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -2,18 +2,17 @@ mod tests; mod thumbnail; -use std::{collections::HashMap, path::PathBuf, sync::Arc, time::SystemTime}; +use std::{path::PathBuf, sync::Arc, time::SystemTime}; use async_trait::async_trait; use base64::{engine::general_purpose, Engine as _}; -use conduit::{debug, debug_error, err, error, utils, Err, Result, Server}; +use conduit::{debug, debug_error, err, error, utils, utils::MutexMap, Err, Result, Server}; use data::{Data, Metadata}; use ruma::{OwnedMxcUri, OwnedUserId}; use serde::Serialize; use tokio::{ fs, io::{AsyncReadExt, AsyncWriteExt, BufReader}, - sync::{Mutex, RwLock}, }; use crate::services; @@ -44,7 +43,7 @@ pub struct UrlPreviewData { pub struct Service { server: Arc<Server>, pub(crate) db: Data, - pub url_preview_mutex: RwLock<HashMap<String, Arc<Mutex<()>>>>, + pub url_preview_mutex: MutexMap<String, ()>, } #[async_trait] @@ -53,7 +52,7 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { server: args.server.clone(), db: Data::new(args.db), - url_preview_mutex: RwLock::new(HashMap::new()), + url_preview_mutex: MutexMap::new(), })) } @@ -274,10 +273,12 @@ async fn create_media_file(&self, key: &[u8]) -> Result<fs::File> { } #[inline] + #[must_use] pub fn get_media_file(&self, key: &[u8]) -> PathBuf { self.get_media_file_sha256(key) } /// new SHA256 file name media function. requires database migrated. 
uses /// SHA256 hash of the base64 key as the file name + #[must_use] pub fn get_media_file_sha256(&self, key: &[u8]) -> PathBuf { let mut r = self.get_media_dir(); // Using the hash of the base64 key as the filename @@ -292,6 +293,7 @@ pub fn get_media_file_sha256(&self, key: &[u8]) -> PathBuf { /// old base64 file name media function /// This is the old version of `get_media_file` that uses the full base64 /// key as the filename. + #[must_use] pub fn get_media_file_b64(&self, key: &[u8]) -> PathBuf { let mut r = self.get_media_dir(); let encoded = encode_key(key); @@ -299,6 +301,7 @@ pub fn get_media_file_b64(&self, key: &[u8]) -> PathBuf { r } + #[must_use] pub fn get_media_dir(&self) -> PathBuf { let mut r = PathBuf::new(); r.push(self.server.config.database_path.clone()); -- GitLab From 43432189578bc6327566748c24866c838358c6cc Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Wed, 17 Jul 2024 07:39:14 +0000 Subject: [PATCH 10/47] initialize some containers with_capacity Signed-off-by: Jason Volk <jason@zemos.net> --- src/service/sending/send.rs | 2 +- src/service/sending/sender.rs | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs index df3139c31..7901de48b 100644 --- a/src/service/sending/send.rs +++ b/src/service/sending/send.rs @@ -154,7 +154,7 @@ fn sign_request<T>(dest: &ServerName, http_request: &mut http::Request<Vec<u8>>) where T: OutgoingRequest + Debug + Send, { - let mut req_map = serde_json::Map::new(); + let mut req_map = serde_json::Map::with_capacity(8); if !http_request.body().is_empty() { req_map.insert( "content".to_owned(), diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index df41db28b..774c3d694 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -555,8 +555,18 @@ async fn send_events_dest_push( async fn send_events_dest_normal( dest: &Destination, server: &OwnedServerName, 
events: Vec<SendingEvent>, ) -> SendingResult { - let mut edu_jsons = Vec::new(); - let mut pdu_jsons = Vec::new(); + let mut pdu_jsons = Vec::with_capacity( + events + .iter() + .filter(|event| matches!(event, SendingEvent::Pdu(_))) + .count(), + ); + let mut edu_jsons = Vec::with_capacity( + events + .iter() + .filter(|event| matches!(event, SendingEvent::Edu(_))) + .count(), + ); for event in &events { match event { -- GitLab From a88f913a178ab278c81b65b8d918c5cbdac92149 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Wed, 17 Jul 2024 22:31:21 +0000 Subject: [PATCH 11/47] enrich interface for inter-service referencing Signed-off-by: Jason Volk <jason@zemos.net> --- src/service/client/mod.rs | 4 +--- src/service/service.rs | 23 ++++++++++++++++++++--- src/service/services.rs | 14 +++++++++++++- 3 files changed, 34 insertions(+), 7 deletions(-) diff --git a/src/service/client/mod.rs b/src/service/client/mod.rs index cc8d52d1a..03b0a1425 100644 --- a/src/service/client/mod.rs +++ b/src/service/client/mod.rs @@ -18,9 +18,7 @@ pub struct Service { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { let config = &args.server.config; - let resolver = args - .get_service::<resolver::Service>("resolver") - .expect("resolver must be built prior to client"); + let resolver = args.require_service::<resolver::Service>("resolver"); Ok(Arc::new(Self { default: base(config) diff --git a/src/service/service.rs b/src/service/service.rs index 99b8723a0..863b955b5 100644 --- a/src/service/service.rs +++ b/src/service/service.rs @@ -1,7 +1,7 @@ use std::{any::Any, collections::BTreeMap, fmt::Write, sync::Arc}; use async_trait::async_trait; -use conduit::{utils::string::split_once_infallible, Result, Server}; +use conduit::{err, error::inspect_log, utils::string::split_once_infallible, Err, Result, Server}; use database::Database; #[async_trait] @@ -44,9 +44,26 @@ pub(crate) struct Args<'a> { pub(crate) type MapVal = (Arc<dyn 
Service>, Arc<dyn Any + Send + Sync>); impl Args<'_> { - pub(crate) fn get_service<T: Any + Send + Sync>(&self, name: &str) -> Option<Arc<T>> { - get::<T>(self.service, name) + pub(crate) fn require_service<T: Any + Send + Sync>(&self, name: &str) -> Arc<T> { + self.try_get_service::<T>(name) + .inspect_err(inspect_log) + .expect("Failure to reference service required by another service.") } + + pub(crate) fn try_get_service<T: Any + Send + Sync>(&self, name: &str) -> Result<Arc<T>> { + try_get::<T>(self.service, name) + } +} + +pub(crate) fn try_get<T: Any + Send + Sync>(map: &Map, name: &str) -> Result<Arc<T>> { + map.get(name).map_or_else( + || Err!("Service {name:?} does not exist or has not been built yet."), + |(_, s)| { + s.clone() + .downcast::<T>() + .map_err(|_| err!("Service {name:?} must be correctly downcast.")) + }, + ) } pub(crate) fn get<T: Any + Send + Sync>(map: &Map, name: &str) -> Option<Arc<T>> { diff --git a/src/service/services.rs b/src/service/services.rs index d0f74e13a..136059cdc 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -173,5 +173,17 @@ fn interrupt(&self) { } } - pub fn get<T: Any + Send + Sync>(&self, name: &str) -> Option<Arc<T>> { service::get::<T>(&self.service, name) } + pub fn try_get<T>(&self, name: &str) -> Result<Arc<T>> + where + T: Any + Send + Sync, + { + service::try_get::<T>(&self.service, name) + } + + pub fn get<T>(&self, name: &str) -> Option<Arc<T>> + where + T: Any + Send + Sync, + { + service::get::<T>(&self.service, name) + } } -- GitLab From 3dc91525cec658cd7288658b756a15228450bbee Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Sun, 21 Jul 2024 01:08:03 +0000 Subject: [PATCH 12/47] fix over-tabulation Signed-off-by: Jason Volk <jason@zemos.net> --- src/service/rooms/spaces/mod.rs | 283 ++++++++++++++++---------------- 1 file changed, 143 insertions(+), 140 deletions(-) diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index 
02db7fab1..19a3ebbbd 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -269,7 +269,7 @@ async fn get_summary_and_children_federation( ) -> Result<Option<SummaryAccessibility>> { for server in via { debug_info!("Asking {server} for /hierarchy"); - if let Ok(response) = services() + let Ok(response) = services() .sending .send_federation_request( server, @@ -279,68 +279,70 @@ async fn get_summary_and_children_federation( }, ) .await - { - debug_info!("Got response from {server} for /hierarchy\n{response:?}"); - let summary = response.room.clone(); - - self.roomid_spacehierarchy_cache.lock().await.insert( - current_room.clone(), - Some(CachedSpaceHierarchySummary { - summary: summary.clone(), - }), - ); - - for child in response.children { - let mut guard = self.roomid_spacehierarchy_cache.lock().await; - if !guard.contains_key(current_room) { - guard.insert( - current_room.clone(), - Some(CachedSpaceHierarchySummary { - summary: { - let SpaceHierarchyChildSummary { - canonical_alias, - name, - num_joined_members, - room_id, - topic, - world_readable, - guest_can_join, - avatar_url, - join_rule, - room_type, - allowed_room_ids, - } = child; - - SpaceHierarchyParentSummary { - canonical_alias, - name, - num_joined_members, - room_id: room_id.clone(), - topic, - world_readable, - guest_can_join, - avatar_url, - join_rule, - room_type, - children_state: get_stripped_space_child_events(&room_id).await?.unwrap(), - allowed_room_ids, - } - }, - }), - ); - } - } - if is_accessable_child( - current_room, - &response.room.join_rule, - &Identifier::UserId(user_id), - &response.room.allowed_room_ids, - ) { - return Ok(Some(SummaryAccessibility::Accessible(Box::new(summary.clone())))); + else { + continue; + }; + + debug_info!("Got response from {server} for /hierarchy\n{response:?}"); + let summary = response.room.clone(); + + self.roomid_spacehierarchy_cache.lock().await.insert( + current_room.clone(), + Some(CachedSpaceHierarchySummary { + 
summary: summary.clone(), + }), + ); + + for child in response.children { + let mut guard = self.roomid_spacehierarchy_cache.lock().await; + if !guard.contains_key(current_room) { + guard.insert( + current_room.clone(), + Some(CachedSpaceHierarchySummary { + summary: { + let SpaceHierarchyChildSummary { + canonical_alias, + name, + num_joined_members, + room_id, + topic, + world_readable, + guest_can_join, + avatar_url, + join_rule, + room_type, + allowed_room_ids, + } = child; + + SpaceHierarchyParentSummary { + canonical_alias, + name, + num_joined_members, + room_id: room_id.clone(), + topic, + world_readable, + guest_can_join, + avatar_url, + join_rule, + room_type, + children_state: get_stripped_space_child_events(&room_id).await?.unwrap(), + allowed_room_ids, + } + }, + }), + ); } - - return Ok(Some(SummaryAccessibility::Inaccessible)); } + if is_accessable_child( + current_room, + &response.room.join_rule, + &Identifier::UserId(user_id), + &response.room.allowed_room_ids, + ) { + return Ok(Some(SummaryAccessibility::Accessible(Box::new(summary.clone())))); + } + + return Ok(Some(SummaryAccessibility::Inaccessible)); } self.roomid_spacehierarchy_cache @@ -461,24 +463,27 @@ pub async fn get_client_hierarchy( let mut results = Vec::new(); while let Some((current_room, via)) = { next_room_to_traverse(&mut stack, &mut parents) } { - if limit > results.len() { - match ( - self.get_summary_and_children_client(¤t_room, suggested_only, sender_user, &via) - .await?, - current_room == room_id, - ) { - (Some(SummaryAccessibility::Accessible(summary)), _) => { - let mut children: Vec<(OwnedRoomId, Vec<OwnedServerName>)> = - get_parent_children_via(&summary, suggested_only) - .into_iter() - .filter(|(room, _)| parents.iter().all(|parent| parent != room)) - .rev() - .collect(); - - if populate_results { - results.push(summary_to_chunk(*summary.clone())); - } else { - children = children + if results.len() >= limit { + break; + } + + match ( + 
self.get_summary_and_children_client(¤t_room, suggested_only, sender_user, &via) + .await?, + current_room == room_id, + ) { + (Some(SummaryAccessibility::Accessible(summary)), _) => { + let mut children: Vec<(OwnedRoomId, Vec<OwnedServerName>)> = + get_parent_children_via(&summary, suggested_only) + .into_iter() + .filter(|(room, _)| parents.iter().all(|parent| parent != room)) + .rev() + .collect(); + + if populate_results { + results.push(summary_to_chunk(*summary.clone())); + } else { + children = children .into_iter() .rev() .skip_while(|(room, _)| { @@ -495,39 +500,36 @@ pub async fn get_client_hierarchy( .rev() .collect(); - if children.is_empty() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Room IDs in token were not found.", - )); - } - - // We have reached the room after where we last left off - let parents_len = parents.len(); - if checked!(parents_len + 1)? == short_room_ids.len() { - populate_results = true; - } + if children.is_empty() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Room IDs in token were not found.", + )); } - let parents_len: u64 = parents.len().try_into()?; - if !children.is_empty() && parents_len < max_depth { - parents.push_back(current_room.clone()); - stack.push(children); + // We have reached the room after where we last left off + let parents_len = parents.len(); + if checked!(parents_len + 1)? == short_room_ids.len() { + populate_results = true; } - // Root room in the space hierarchy, we return an error - // if this one fails. 
- }, - (Some(SummaryAccessibility::Inaccessible), true) => { - return Err(Error::BadRequest(ErrorKind::forbidden(), "The requested room is inaccessible")); - }, - (None, true) => { - return Err(Error::BadRequest(ErrorKind::forbidden(), "The requested room was not found")); - }, - // Just ignore other unavailable rooms - (None | Some(SummaryAccessibility::Inaccessible), false) => (), - } - } else { - break; + } + + let parents_len: u64 = parents.len().try_into()?; + if !children.is_empty() && parents_len < max_depth { + parents.push_back(current_room.clone()); + stack.push(children); + } + // Root room in the space hierarchy, we return an error + // if this one fails. + }, + (Some(SummaryAccessibility::Inaccessible), true) => { + return Err(Error::BadRequest(ErrorKind::forbidden(), "The requested room is inaccessible")); + }, + (None, true) => { + return Err(Error::BadRequest(ErrorKind::forbidden(), "The requested room was not found")); + }, + // Just ignore other unavailable rooms + (None | Some(SummaryAccessibility::Inaccessible), false) => (), } } @@ -574,41 +576,42 @@ fn next_room_to_traverse( async fn get_stripped_space_child_events( room_id: &RoomId, ) -> Result<Option<Vec<Raw<HierarchySpaceChildEvent>>>, Error> { - if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { - let state = services() - .rooms - .state_accessor - .state_full_ids(current_shortstatehash) - .await?; - let mut children_pdus = Vec::new(); - for (key, id) in state { - let (event_type, state_key) = services().rooms.short.get_statekey_from_short(key)?; - if event_type != StateEventType::SpaceChild { - continue; - } + let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? 
else { + return Ok(None); + }; + + let state = services() + .rooms + .state_accessor + .state_full_ids(current_shortstatehash) + .await?; + let mut children_pdus = Vec::new(); + for (key, id) in state { + let (event_type, state_key) = services().rooms.short.get_statekey_from_short(key)?; + if event_type != StateEventType::SpaceChild { + continue; + } - let pdu = services() - .rooms - .timeline - .get_pdu(&id)? - .ok_or_else(|| Error::bad_database("Event in space state not found"))?; - - if serde_json::from_str::<SpaceChildEventContent>(pdu.content.get()) - .ok() - .map(|c| c.via) - .map_or(true, |v| v.is_empty()) - { - continue; - } + let pdu = services() + .rooms + .timeline + .get_pdu(&id)? + .ok_or_else(|| Error::bad_database("Event in space state not found"))?; + + if serde_json::from_str::<SpaceChildEventContent>(pdu.content.get()) + .ok() + .map(|c| c.via) + .map_or(true, |v| v.is_empty()) + { + continue; + } - if OwnedRoomId::try_from(state_key).is_ok() { - children_pdus.push(pdu.to_stripped_spacechild_state_event()); - } + if OwnedRoomId::try_from(state_key).is_ok() { + children_pdus.push(pdu.to_stripped_spacechild_state_event()); } - Ok(Some(children_pdus)) - } else { - Ok(None) } + + Ok(Some(children_pdus)) } /// With the given identifier, checks if a room is accessable -- GitLab From 9b20c6918f86c14b541356d2ae40a2c4a8558cf6 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Sun, 21 Jul 2024 00:21:34 +0000 Subject: [PATCH 13/47] add indirection for circular-dependencies between services Signed-off-by: Jason Volk <jason@zemos.net> --- src/service/mod.rs | 2 +- src/service/service.rs | 49 +++++++++++++++++++++++++++++++++-------- src/service/services.rs | 36 ++++++++++++++++++++---------- 3 files changed, 65 insertions(+), 22 deletions(-) diff --git a/src/service/mod.rs b/src/service/mod.rs index 46adb072d..21d1f5946 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -29,7 +29,7 @@ pub(crate) use conduit::{config, debug_error, 
debug_warn, utils, Error, Result, Server}; pub use conduit::{pdu, PduBuilder, PduCount, PduEvent}; use database::Database; -pub(crate) use service::{Args, Service}; +pub(crate) use service::{Args, Dep, Service}; pub use crate::{ globals::{server_is_ours, user_is_local}, diff --git a/src/service/service.rs b/src/service/service.rs index 863b955b5..ce4f15b2a 100644 --- a/src/service/service.rs +++ b/src/service/service.rs @@ -1,9 +1,16 @@ -use std::{any::Any, collections::BTreeMap, fmt::Write, sync::Arc}; +use std::{ + any::Any, + collections::BTreeMap, + fmt::Write, + ops::Deref, + sync::{Arc, OnceLock}, +}; use async_trait::async_trait; use conduit::{err, error::inspect_log, utils::string::split_once_infallible, Err, Result, Server}; use database::Database; +/// Abstract interface for a Service #[async_trait] pub(crate) trait Service: Any + Send + Sync { /// Implement the construction of the service instance. Services are @@ -34,27 +41,51 @@ fn memory_usage(&self, _out: &mut dyn Write) -> Result<()> { Ok(()) } fn name(&self) -> &str; } +/// Args are passed to `Service::build` when a service is constructed. This +/// allows for arguments to change with limited impact to the many services. pub(crate) struct Args<'a> { pub(crate) server: &'a Arc<Server>, pub(crate) db: &'a Arc<Database>, - pub(crate) service: &'a Map, + pub(crate) service: &'a Arc<Map>, +} + +/// Dep is a reference to a service used within another service. +/// Circular-dependencies between services require this indirection to allow the +/// referenced service construction after the referencing service. 
+pub(crate) struct Dep<T> { + dep: OnceLock<Arc<T>>, + service: Arc<Map>, + name: &'static str, } pub(crate) type Map = BTreeMap<String, MapVal>; pub(crate) type MapVal = (Arc<dyn Service>, Arc<dyn Any + Send + Sync>); -impl Args<'_> { - pub(crate) fn require_service<T: Any + Send + Sync>(&self, name: &str) -> Arc<T> { - self.try_get_service::<T>(name) - .inspect_err(inspect_log) - .expect("Failure to reference service required by another service.") +impl<T: Any + Send + Sync> Deref for Dep<T> { + type Target = Arc<T>; + + fn deref(&self) -> &Self::Target { + self.dep + .get_or_init(|| require::<T>(&self.service, self.name)) } +} - pub(crate) fn try_get_service<T: Any + Send + Sync>(&self, name: &str) -> Result<Arc<T>> { - try_get::<T>(self.service, name) +impl Args<'_> { + pub(crate) fn depend_service<T: Any + Send + Sync>(&self, name: &'static str) -> Dep<T> { + Dep::<T> { + dep: OnceLock::new(), + service: self.service.clone(), + name, + } } } +pub(crate) fn require<T: Any + Send + Sync>(map: &Map, name: &str) -> Arc<T> { + try_get::<T>(map, name) + .inspect_err(inspect_log) + .expect("Failure to reference service required by another service.") +} + pub(crate) fn try_get<T: Any + Send + Sync>(map: &Map, name: &str) -> Result<Arc<T>> { map.get(name).map_or_else( || Err!("Service {name:?} does not exist or has not been built yet."), diff --git a/src/service/services.rs b/src/service/services.rs index 136059cdc..68205323f 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -31,24 +31,36 @@ pub struct Services { pub updates: Arc<updates::Service>, manager: Mutex<Option<Arc<Manager>>>, - pub(crate) service: Map, + pub(crate) service: Arc<Map>, pub server: Arc<Server>, pub db: Arc<Database>, } +macro_rules! 
build_service { + ($map:ident, $server:ident, $db:ident, $tyname:ty) => {{ + let built = <$tyname>::build(Args { + server: &$server, + db: &$db, + service: &$map, + })?; + + Arc::get_mut(&mut $map) + .expect("must have mutable reference to services collection") + .insert(built.name().to_owned(), (built.clone(), built.clone())); + + trace!("built service #{}: {:?}", $map.len(), built.name()); + built + }}; +} + impl Services { + #[allow(clippy::cognitive_complexity)] pub fn build(server: Arc<Server>, db: Arc<Database>) -> Result<Self> { - let mut service: Map = BTreeMap::new(); + let mut service: Arc<Map> = Arc::new(BTreeMap::new()); macro_rules! build { - ($tyname:ty) => {{ - let built = <$tyname>::build(Args { - server: &server, - db: &db, - service: &service, - })?; - service.insert(built.name().to_owned(), (built.clone(), built.clone())); - built - }}; + ($srv:ty) => { + build_service!(service, server, db, $srv) + }; } Ok(Self { @@ -167,7 +179,7 @@ pub async fn memory_usage(&self) -> Result<String> { fn interrupt(&self) { debug!("Interrupting services..."); - for (name, (service, ..)) in &self.service { + for (name, (service, ..)) in self.service.iter() { trace!("Interrupting {name}"); service.interrupt(); } -- GitLab From 992c0a1e583e94589bad46ee56bf464b046931d8 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Sat, 20 Jul 2024 23:38:20 +0000 Subject: [PATCH 14/47] de-global services for admin Signed-off-by: Jason Volk <jason@zemos.net> --- src/admin/debug/commands.rs | 2 +- src/admin/handler.rs | 2 +- src/admin/mod.rs | 5 +- src/admin/room/room_moderation_commands.rs | 6 +- src/admin/user/commands.rs | 5 +- src/api/client/account.rs | 7 +- src/api/client/state.rs | 4 +- src/service/admin/console.rs | 20 +- src/service/admin/create.rs | 47 ++-- src/service/admin/grant.rs | 43 ++-- src/service/admin/mod.rs | 284 ++++++++++----------- src/service/globals/migrations.rs | 2 +- src/service/rooms/timeline/mod.rs | 5 +- src/service/users/mod.rs | 2 +- 
14 files changed, 214 insertions(+), 220 deletions(-) diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index f319f5a55..f0aa23cb6 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -8,7 +8,7 @@ use conduit::{ debug, info, log, log::{capture, Capture}, - warn, Error, Result, + warn, Error, PduEvent, Result, }; use ruma::{ api::{client::error::ErrorKind, federation::event::get_room_state}, diff --git a/src/admin/handler.rs b/src/admin/handler.rs index 95c7ed418..409abc18f 100644 --- a/src/admin/handler.rs +++ b/src/admin/handler.rs @@ -14,7 +14,7 @@ extern crate conduit_service as service; use conduit::{utils::string::common_prefix, Result}; -pub(crate) use service::admin::{Command, Service}; +pub(crate) use service::admin::Command; use service::admin::{CommandOutput, CommandResult, HandlerResult}; use crate::{ diff --git a/src/admin/mod.rs b/src/admin/mod.rs index 14856811f..e020ed436 100644 --- a/src/admin/mod.rs +++ b/src/admin/mod.rs @@ -20,10 +20,7 @@ pub(crate) use conduit::{mod_ctor, mod_dtor, Result}; pub(crate) use service::{services, user_is_local}; -pub(crate) use crate::{ - handler::Service, - utils::{escape_html, get_room_info}, -}; +pub(crate) use crate::utils::{escape_html, get_room_info}; mod_ctor! {} mod_dtor! 
{} diff --git a/src/admin/room/room_moderation_commands.rs b/src/admin/room/room_moderation_commands.rs index c9be44737..46354c0f7 100644 --- a/src/admin/room/room_moderation_commands.rs +++ b/src/admin/room/room_moderation_commands.rs @@ -2,7 +2,7 @@ use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomAliasId, RoomId, RoomOrAliasId}; use tracing::{debug, error, info, warn}; -use super::{super::Service, RoomModerationCommand}; +use super::RoomModerationCommand; use crate::{get_room_info, services, user_is_local, Result}; pub(super) async fn process(command: RoomModerationCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> { @@ -31,7 +31,7 @@ async fn ban_room( let admin_room_alias = &services().globals.admin_alias; - if let Some(admin_room_id) = Service::get_admin_room()? { + if let Some(admin_room_id) = services().admin.get_admin_room()? { if room.to_string().eq(&admin_room_id) || room.to_string().eq(admin_room_alias) { return Ok(RoomMessageEventContent::text_plain("Not allowed to ban the admin room.")); } @@ -198,7 +198,7 @@ async fn ban_list_of_rooms(body: Vec<&str>, force: bool, disable_federation: boo for &room in &rooms_s { match <&RoomOrAliasId>::try_from(room) { Ok(room_alias_or_id) => { - if let Some(admin_room_id) = Service::get_admin_room()? { + if let Some(admin_room_id) = services().admin.get_admin_room()? 
{ if room.to_owned().eq(&admin_room_id) || room.to_owned().eq(admin_room_alias) { info!("User specified admin room in bulk ban list, ignoring"); continue; diff --git a/src/admin/user/commands.rs b/src/admin/user/commands.rs index f6387ee82..fa1c52881 100644 --- a/src/admin/user/commands.rs +++ b/src/admin/user/commands.rs @@ -363,7 +363,10 @@ pub(super) async fn make_user_admin(_body: Vec<&str>, user_id: String) -> Result .unwrap_or_else(|| user_id.to_string()); assert!(service::user_is_local(&user_id), "Parsed user_id must be a local user"); - service::admin::make_user_admin(&user_id, displayname).await?; + services() + .admin + .make_user_admin(&user_id, displayname) + .await?; Ok(RoomMessageEventContent::notice_markdown(format!( "{user_id} has been granted admin privileges.", diff --git a/src/api/client/account.rs b/src/api/client/account.rs index 19ac89a08..b3495b429 100644 --- a/src/api/client/account.rs +++ b/src/api/client/account.rs @@ -347,9 +347,12 @@ pub(crate) async fn register_route( // If this is the first real user, grant them admin privileges except for guest // users Note: the server user, @conduit:servername, is generated first if !is_guest { - if let Some(admin_room) = service::admin::Service::get_admin_room()? { + if let Some(admin_room) = services.admin.get_admin_room()? { if services.rooms.state_cache.room_joined_count(&admin_room)? == Some(1) { - service::admin::make_user_admin(&user_id, displayname).await?; + services + .admin + .make_user_admin(&user_id, displayname) + .await?; warn!("Granting {user_id} admin privileges as the first user"); } diff --git a/src/api/client/state.rs b/src/api/client/state.rs index 51217d001..56ffd2ac5 100644 --- a/src/api/client/state.rs +++ b/src/api/client/state.rs @@ -210,7 +210,7 @@ async fn allowed_to_send_state_event( }, // admin room is a sensitive room, it should not ever be made public StateEventType::RoomJoinRules => { - if let Some(admin_room_id) = service::admin::Service::get_admin_room()? 
{ + if let Some(admin_room_id) = services.admin.get_admin_room()? { if admin_room_id == room_id { if let Ok(join_rule) = serde_json::from_str::<RoomJoinRulesEventContent>(json.json().get()) { if join_rule.join_rule == JoinRule::Public { @@ -225,7 +225,7 @@ async fn allowed_to_send_state_event( }, // admin room is a sensitive room, it should not ever be made world readable StateEventType::RoomHistoryVisibility => { - if let Some(admin_room_id) = service::admin::Service::get_admin_room()? { + if let Some(admin_room_id) = services.admin.get_admin_room()? { if admin_room_id == room_id { if let Ok(visibility_content) = serde_json::from_str::<RoomHistoryVisibilityEventContent>(json.json().get()) diff --git a/src/service/admin/console.rs b/src/service/admin/console.rs index c590b928f..9c3357917 100644 --- a/src/service/admin/console.rs +++ b/src/service/admin/console.rs @@ -4,7 +4,7 @@ sync::{Arc, Mutex}, }; -use conduit::{debug, defer, error, log}; +use conduit::{debug, defer, error, log, Server}; use futures_util::future::{AbortHandle, Abortable}; use ruma::events::room::message::RoomMessageEventContent; use rustyline_async::{Readline, ReadlineError, ReadlineEvent}; @@ -14,6 +14,7 @@ use crate::services; pub struct Console { + server: Arc<Server>, worker_join: Mutex<Option<JoinHandle<()>>>, input_abort: Mutex<Option<AbortHandle>>, command_abort: Mutex<Option<AbortHandle>>, @@ -25,9 +26,9 @@ pub struct Console { const HISTORY_LIMIT: usize = 48; impl Console { - #[must_use] - pub fn new() -> Arc<Self> { + pub(super) fn new(args: &crate::Args<'_>) -> Arc<Self> { Arc::new(Self { + server: args.server.clone(), worker_join: None.into(), input_abort: None.into(), command_abort: None.into(), @@ -37,7 +38,7 @@ pub fn new() -> Arc<Self> { } pub(super) async fn handle_signal(self: &Arc<Self>, sig: &'static str) { - if !services().server.running() { + if !self.server.running() { self.interrupt(); } else if sig == "SIGINT" { self.interrupt_command(); @@ -49,7 +50,7 @@ pub async fn 
start(self: &Arc<Self>) { let mut worker_join = self.worker_join.lock().expect("locked"); if worker_join.is_none() { let self_ = Arc::clone(self); - _ = worker_join.insert(services().server.runtime().spawn(self_.worker())); + _ = worker_join.insert(self.server.runtime().spawn(self_.worker())); } } @@ -89,16 +90,13 @@ pub fn interrupt_command(self: &Arc<Self>) { #[tracing::instrument(skip_all, name = "console")] async fn worker(self: Arc<Self>) { debug!("session starting"); - while services().server.running() { + while self.server.running() { match self.readline().await { Ok(event) => match event { ReadlineEvent::Line(string) => self.clone().handle(string).await, ReadlineEvent::Interrupted => continue, ReadlineEvent::Eof => break, - ReadlineEvent::Quit => services() - .server - .shutdown() - .unwrap_or_else(error::default_log), + ReadlineEvent::Quit => self.server.shutdown().unwrap_or_else(error::default_log), }, Err(error) => match error { ReadlineError::Closed => break, @@ -115,7 +113,7 @@ async fn worker(self: Arc<Self>) { } async fn readline(self: &Arc<Self>) -> Result<ReadlineEvent, ReadlineError> { - let _suppression = log::Suppress::new(&services().server); + let _suppression = log::Suppress::new(&self.server); let (mut readline, _writer) = Readline::new(PROMPT.to_owned())?; readline.set_tab_completer(Self::tab_complete); diff --git a/src/service/admin/create.rs b/src/service/admin/create.rs index fbb6a078d..18cbe0399 100644 --- a/src/service/admin/create.rs +++ b/src/service/admin/create.rs @@ -1,6 +1,6 @@ use std::collections::BTreeMap; -use conduit::{Error, Result}; +use conduit::{pdu::PduBuilder, warn, Error, Result}; use ruma::{ api::client::error::ErrorKind, events::{ @@ -21,26 +21,25 @@ RoomId, RoomVersionId, }; use serde_json::value::to_raw_value; -use tracing::warn; -use crate::{pdu::PduBuilder, services}; +use crate::Services; /// Create the admin room. 
/// /// Users in this room are considered admins by conduit, and the room can be /// used to issue admin commands by talking to the server user inside it. -pub async fn create_admin_room() -> Result<()> { - let room_id = RoomId::new(services().globals.server_name()); +pub async fn create_admin_room(services: &Services) -> Result<()> { + let room_id = RoomId::new(services.globals.server_name()); - let _short_id = services().rooms.short.get_or_create_shortroomid(&room_id)?; + let _short_id = services.rooms.short.get_or_create_shortroomid(&room_id)?; - let state_lock = services().rooms.state.mutex.lock(&room_id).await; + let state_lock = services.rooms.state.mutex.lock(&room_id).await; // Create a user for the server - let server_user = &services().globals.server_user; - services().users.create(server_user, None)?; + let server_user = &services.globals.server_user; + services.users.create(server_user, None)?; - let room_version = services().globals.default_room_version(); + let room_version = services.globals.default_room_version(); let mut content = { use RoomVersionId::*; @@ -62,7 +61,7 @@ pub async fn create_admin_room() -> Result<()> { content.room_version = room_version; // 1. The room create event - services() + services .rooms .timeline .build_and_append_pdu( @@ -80,7 +79,7 @@ pub async fn create_admin_room() -> Result<()> { .await?; // 2. 
Make conduit bot join - services() + services .rooms .timeline .build_and_append_pdu( @@ -111,7 +110,7 @@ pub async fn create_admin_room() -> Result<()> { let mut users = BTreeMap::new(); users.insert(server_user.clone(), 100.into()); - services() + services .rooms .timeline .build_and_append_pdu( @@ -133,7 +132,7 @@ pub async fn create_admin_room() -> Result<()> { .await?; // 4.1 Join Rules - services() + services .rooms .timeline .build_and_append_pdu( @@ -152,7 +151,7 @@ pub async fn create_admin_room() -> Result<()> { .await?; // 4.2 History Visibility - services() + services .rooms .timeline .build_and_append_pdu( @@ -171,7 +170,7 @@ pub async fn create_admin_room() -> Result<()> { .await?; // 4.3 Guest Access - services() + services .rooms .timeline .build_and_append_pdu( @@ -190,8 +189,8 @@ pub async fn create_admin_room() -> Result<()> { .await?; // 5. Events implied by name and topic - let room_name = format!("{} Admin Room", services().globals.server_name()); - services() + let room_name = format!("{} Admin Room", services.globals.server_name()); + services .rooms .timeline .build_and_append_pdu( @@ -209,14 +208,14 @@ pub async fn create_admin_room() -> Result<()> { ) .await?; - services() + services .rooms .timeline .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomTopic, content: to_raw_value(&RoomTopicEventContent { - topic: format!("Manage {}", services().globals.server_name()), + topic: format!("Manage {}", services.globals.server_name()), }) .expect("event is valid, we just created it"), unsigned: None, @@ -230,9 +229,9 @@ pub async fn create_admin_room() -> Result<()> { .await?; // 6. Room alias - let alias = &services().globals.admin_alias; + let alias = &services.globals.admin_alias; - services() + services .rooms .timeline .build_and_append_pdu( @@ -253,13 +252,13 @@ pub async fn create_admin_room() -> Result<()> { ) .await?; - services() + services .rooms .alias .set_alias(alias, &room_id, server_user)?; // 7. 
(ad-hoc) Disable room previews for everyone by default - services() + services .rooms .timeline .build_and_append_pdu( diff --git a/src/service/admin/grant.rs b/src/service/admin/grant.rs index 9a4ef242e..213225ace 100644 --- a/src/service/admin/grant.rs +++ b/src/service/admin/grant.rs @@ -14,23 +14,24 @@ }; use serde_json::value::to_raw_value; -use super::Service; -use crate::{pdu::PduBuilder, services}; +use crate::pdu::PduBuilder; -/// Invite the user to the conduit admin room. -/// -/// In conduit, this is equivalent to granting admin privileges. -pub async fn make_user_admin(user_id: &UserId, displayname: String) -> Result<()> { - if let Some(room_id) = Service::get_admin_room()? { - let state_lock = services().rooms.state.mutex.lock(&room_id).await; +impl super::Service { + /// Invite the user to the conduit admin room. + /// + /// In conduit, this is equivalent to granting admin privileges. + pub async fn make_user_admin(&self, user_id: &UserId, displayname: String) -> Result<()> { + let Some(room_id) = self.get_admin_room()? 
else { + return Ok(()); + }; + + let state_lock = self.state.mutex.lock(&room_id).await; // Use the server user to grant the new admin's power level - let server_user = &services().globals.server_user; + let server_user = &self.globals.server_user; // Invite and join the real user - services() - .rooms - .timeline + self.timeline .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomMember, @@ -54,9 +55,7 @@ pub async fn make_user_admin(user_id: &UserId, displayname: String) -> Result<() &state_lock, ) .await?; - services() - .rooms - .timeline + self.timeline .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomMember, @@ -86,9 +85,7 @@ pub async fn make_user_admin(user_id: &UserId, displayname: String) -> Result<() users.insert(server_user.clone(), 100.into()); users.insert(user_id.to_owned(), 100.into()); - services() - .rooms - .timeline + self.timeline .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomPowerLevels, @@ -108,12 +105,12 @@ pub async fn make_user_admin(user_id: &UserId, displayname: String) -> Result<() .await?; // Send welcome message - services().rooms.timeline.build_and_append_pdu( + self.timeline.build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomMessage, content: to_raw_value(&RoomMessageEventContent::text_html( - format!("## Thank you for trying out conduwuit!\n\nconduwuit is a fork of upstream Conduit which is in Beta. 
This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.\n\nHelpful links:\n> Git and Documentation: https://github.com/girlbossceo/conduwuit\n> Report issues: https://github.com/girlbossceo/conduwuit/issues\n\nFor a list of available commands, send the following message in this room: `@conduit:{}: --help`\n\nHere are some rooms you can join (by typing the command):\n\nconduwuit room (Ask questions and get notified on updates):\n`/join #conduwuit:puppygock.gay`", services().globals.server_name()), - format!("<h2>Thank you for trying out conduwuit!</h2>\n<p>conduwuit is a fork of upstream Conduit which is in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.</p>\n<p>Helpful links:</p>\n<blockquote>\n<p>Git and Documentation: https://github.com/girlbossceo/conduwuit<br>Report issues: https://github.com/girlbossceo/conduwuit/issues</p>\n</blockquote>\n<p>For a list of available commands, send the following message in this room: <code>@conduit:{}: --help</code></p>\n<p>Here are some rooms you can join (by typing the command):</p>\n<p>conduwuit room (Ask questions and get notified on updates):<br><code>/join #conduwuit:puppygock.gay</code></p>\n", services().globals.server_name()), + format!("## Thank you for trying out conduwuit!\n\nconduwuit is a fork of upstream Conduit which is in Beta. 
This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.\n\nHelpful links:\n> Git and Documentation: https://github.com/girlbossceo/conduwuit\n> Report issues: https://github.com/girlbossceo/conduwuit/issues\n\nFor a list of available commands, send the following message in this room: `@conduit:{}: --help`\n\nHere are some rooms you can join (by typing the command):\n\nconduwuit room (Ask questions and get notified on updates):\n`/join #conduwuit:puppygock.gay`", self.globals.server_name()), + format!("<h2>Thank you for trying out conduwuit!</h2>\n<p>conduwuit is a fork of upstream Conduit which is in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.</p>\n<p>Helpful links:</p>\n<blockquote>\n<p>Git and Documentation: https://github.com/girlbossceo/conduwuit<br>Report issues: https://github.com/girlbossceo/conduwuit/issues</p>\n</blockquote>\n<p>For a list of available commands, send the following message in this room: <code>@conduit:{}: --help</code></p>\n<p>Here are some rooms you can join (by typing the command):</p>\n<p>conduwuit room (Ask questions and get notified on updates):<br><code>/join #conduwuit:puppygock.gay</code></p>\n", self.globals.server_name()), )) .expect("event is valid, we just created it"), unsigned: None, @@ -124,7 +121,7 @@ pub async fn make_user_admin(user_id: &UserId, displayname: String) -> Result<() &room_id, &state_lock, ).await?; - } - Ok(()) + Ok(()) + } } diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index f63ebf097..fcb342129 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -9,9 +9,8 @@ }; use async_trait::async_trait; -use conduit::{debug, error, error::default_log, Error, Result}; +use conduit::{debug, error, error::default_log, pdu::PduBuilder, Error, PduEvent, Result, Server}; pub use 
create::create_admin_room; -pub use grant::make_user_admin; use loole::{Receiver, Sender}; use ruma::{ events::{ @@ -23,9 +22,15 @@ use serde_json::value::to_raw_value; use tokio::sync::{Mutex, RwLock}; -use crate::{pdu::PduBuilder, rooms::state::RoomMutexGuard, services, user_is_local, PduEvent}; +use crate::{globals, rooms, rooms::state::RoomMutexGuard, user_is_local}; pub struct Service { + server: Arc<Server>, + globals: Arc<globals::Service>, + alias: Arc<rooms::alias::Service>, + timeline: Arc<rooms::timeline::Service>, + state: Arc<rooms::state::Service>, + state_cache: Arc<rooms::state_cache::Service>, sender: Sender<Command>, receiver: Mutex<Receiver<Command>>, pub handle: RwLock<Option<Handler>>, @@ -50,21 +55,27 @@ pub struct Command { #[async_trait] impl crate::Service for Service { - fn build(_args: crate::Args<'_>) -> Result<Arc<Self>> { + fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { let (sender, receiver) = loole::bounded(COMMAND_QUEUE_LIMIT); Ok(Arc::new(Self { + server: args.server.clone(), + globals: args.require_service::<globals::Service>("globals"), + alias: args.require_service::<rooms::alias::Service>("rooms::alias"), + timeline: args.require_service::<rooms::timeline::Service>("rooms::timeline"), + state: args.require_service::<rooms::state::Service>("rooms::state"), + state_cache: args.require_service::<rooms::state_cache::Service>("rooms::state_cache"), sender, receiver: Mutex::new(receiver), handle: RwLock::new(None), complete: StdRwLock::new(None), #[cfg(feature = "console")] - console: console::Console::new(), + console: console::Console::new(&args), })) } async fn worker(self: Arc<Self>) -> Result<()> { let receiver = self.receiver.lock().await; - let mut signals = services().server.signal.subscribe(); + let mut signals = self.server.signal.subscribe(); loop { tokio::select! 
{ command = receiver.recv_async() => match command { @@ -104,9 +115,10 @@ pub async fn send_text(&self, body: &str) { } pub async fn send_message(&self, message_content: RoomMessageEventContent) { - if let Ok(Some(room_id)) = Self::get_admin_room() { - let user_id = &services().globals.server_user; - respond_to_room(message_content, &room_id, user_id).await; + if let Ok(Some(room_id)) = self.get_admin_room() { + let user_id = &self.globals.server_user; + self.respond_to_room(message_content, &room_id, user_id) + .await; } } @@ -147,7 +159,7 @@ async fn handle_signal(&self, #[allow(unused_variables)] sig: &'static str) { async fn handle_command(&self, command: Command) { match self.process_command(command).await { - Ok(Some(output)) => handle_response(output).await, + Ok(Some(output)) => self.handle_response(output).await, Ok(None) => debug!("Command successful with no response"), Err(e) => error!("Command processing error: {e}"), } @@ -163,8 +175,8 @@ async fn process_command(&self, command: Command) -> CommandResult { /// Checks whether a given user is an admin of this server pub async fn user_is_admin(&self, user_id: &UserId) -> Result<bool> { - if let Ok(Some(admin_room)) = Self::get_admin_room() { - services().rooms.state_cache.is_joined(user_id, &admin_room) + if let Ok(Some(admin_room)) = self.get_admin_room() { + self.state_cache.is_joined(user_id, &admin_room) } else { Ok(false) } @@ -174,16 +186,11 @@ pub async fn user_is_admin(&self, user_id: &UserId) -> Result<bool> { /// /// Errors are propagated from the database, and will have None if there is /// no admin room - pub fn get_admin_room() -> Result<Option<OwnedRoomId>> { - if let Some(room_id) = services() - .rooms - .alias - .resolve_local_alias(&services().globals.admin_alias)? - { - if services() - .rooms + pub fn get_admin_room(&self) -> Result<Option<OwnedRoomId>> { + if let Some(room_id) = self.alias.resolve_local_alias(&self.globals.admin_alias)? 
{ + if self .state_cache - .is_joined(&services().globals.server_user, &room_id)? + .is_joined(&self.globals.server_user, &room_id)? { return Ok(Some(room_id)); } @@ -191,142 +198,133 @@ pub fn get_admin_room() -> Result<Option<OwnedRoomId>> { Ok(None) } -} -async fn handle_response(content: RoomMessageEventContent) { - let Some(Relation::Reply { - in_reply_to, - }) = content.relates_to.as_ref() - else { - return; - }; - - let Ok(Some(pdu)) = services().rooms.timeline.get_pdu(&in_reply_to.event_id) else { - return; - }; - - let response_sender = if is_admin_room(&pdu.room_id) { - &services().globals.server_user - } else { - &pdu.sender - }; - - respond_to_room(content, &pdu.room_id, response_sender).await; -} + async fn handle_response(&self, content: RoomMessageEventContent) { + let Some(Relation::Reply { + in_reply_to, + }) = content.relates_to.as_ref() + else { + return; + }; -async fn respond_to_room(content: RoomMessageEventContent, room_id: &RoomId, user_id: &UserId) { - assert!( - services() - .admin - .user_is_admin(user_id) - .await - .expect("checked user is admin"), - "sender is not admin" - ); - - let state_lock = services().rooms.state.mutex.lock(room_id).await; - let response_pdu = PduBuilder { - event_type: TimelineEventType::RoomMessage, - content: to_raw_value(&content).expect("event is valid, we just created it"), - unsigned: None, - state_key: None, - redacts: None, - }; - - if let Err(e) = services() - .rooms - .timeline - .build_and_append_pdu(response_pdu, user_id, room_id, &state_lock) - .await - { - handle_response_error(e, room_id, user_id, &state_lock) + let Ok(Some(pdu)) = self.timeline.get_pdu(&in_reply_to.event_id) else { + return; + }; + + let response_sender = if self.is_admin_room(&pdu.room_id) { + &self.globals.server_user + } else { + &pdu.sender + }; + + self.respond_to_room(content, &pdu.room_id, response_sender) + .await; + } + + async fn respond_to_room(&self, content: RoomMessageEventContent, room_id: &RoomId, user_id: 
&UserId) { + assert!( + self.user_is_admin(user_id) + .await + .expect("checked user is admin"), + "sender is not admin" + ); + + let state_lock = self.state.mutex.lock(room_id).await; + let response_pdu = PduBuilder { + event_type: TimelineEventType::RoomMessage, + content: to_raw_value(&content).expect("event is valid, we just created it"), + unsigned: None, + state_key: None, + redacts: None, + }; + + if let Err(e) = self + .timeline + .build_and_append_pdu(response_pdu, user_id, room_id, &state_lock) .await - .unwrap_or_else(default_log); + { + self.handle_response_error(e, room_id, user_id, &state_lock) + .await + .unwrap_or_else(default_log); + } } -} -async fn handle_response_error( - e: Error, room_id: &RoomId, user_id: &UserId, state_lock: &RoomMutexGuard, -) -> Result<()> { - error!("Failed to build and append admin room response PDU: \"{e}\""); - let error_room_message = RoomMessageEventContent::text_plain(format!( - "Failed to build and append admin room PDU: \"{e}\"\n\nThe original admin command may have finished \ - successfully, but we could not return the output." - )); - - let response_pdu = PduBuilder { - event_type: TimelineEventType::RoomMessage, - content: to_raw_value(&error_room_message).expect("event is valid, we just created it"), - unsigned: None, - state_key: None, - redacts: None, - }; - - services() - .rooms - .timeline - .build_and_append_pdu(response_pdu, user_id, room_id, state_lock) - .await?; - - Ok(()) -} + async fn handle_response_error( + &self, e: Error, room_id: &RoomId, user_id: &UserId, state_lock: &RoomMutexGuard, + ) -> Result<()> { + error!("Failed to build and append admin room response PDU: \"{e}\""); + let error_room_message = RoomMessageEventContent::text_plain(format!( + "Failed to build and append admin room PDU: \"{e}\"\n\nThe original admin command may have finished \ + successfully, but we could not return the output." 
+ )); + + let response_pdu = PduBuilder { + event_type: TimelineEventType::RoomMessage, + content: to_raw_value(&error_room_message).expect("event is valid, we just created it"), + unsigned: None, + state_key: None, + redacts: None, + }; + + self.timeline + .build_and_append_pdu(response_pdu, user_id, room_id, state_lock) + .await?; -pub async fn is_admin_command(pdu: &PduEvent, body: &str) -> bool { - // Server-side command-escape with public echo - let is_escape = body.starts_with('\\'); - let is_public_escape = is_escape && body.trim_start_matches('\\').starts_with("!admin"); + Ok(()) + } - // Admin command with public echo (in admin room) - let server_user = &services().globals.server_user; - let is_public_prefix = body.starts_with("!admin") || body.starts_with(server_user.as_str()); + pub async fn is_admin_command(&self, pdu: &PduEvent, body: &str) -> bool { + // Server-side command-escape with public echo + let is_escape = body.starts_with('\\'); + let is_public_escape = is_escape && body.trim_start_matches('\\').starts_with("!admin"); - // Expected backward branch - if !is_public_escape && !is_public_prefix { - return false; - } + // Admin command with public echo (in admin room) + let server_user = &self.globals.server_user; + let is_public_prefix = body.starts_with("!admin") || body.starts_with(server_user.as_str()); - // only allow public escaped commands by local admins - if is_public_escape && !user_is_local(&pdu.sender) { - return false; - } + // Expected backward branch + if !is_public_escape && !is_public_prefix { + return false; + } - // Check if server-side command-escape is disabled by configuration - if is_public_escape && !services().globals.config.admin_escape_commands { - return false; - } + // only allow public escaped commands by local admins + if is_public_escape && !user_is_local(&pdu.sender) { + return false; + } - // Prevent unescaped !admin from being used outside of the admin room - if is_public_prefix && !is_admin_room(&pdu.room_id) { 
- return false; - } + // Check if server-side command-escape is disabled by configuration + if is_public_escape && !self.globals.config.admin_escape_commands { + return false; + } - // Only senders who are admin can proceed - if !services() - .admin - .user_is_admin(&pdu.sender) - .await - .unwrap_or(false) - { - return false; - } + // Prevent unescaped !admin from being used outside of the admin room + if is_public_prefix && !self.is_admin_room(&pdu.room_id) { + return false; + } - // This will evaluate to false if the emergency password is set up so that - // the administrator can execute commands as conduit - let emergency_password_set = services().globals.emergency_password().is_some(); - let from_server = pdu.sender == *server_user && !emergency_password_set; - if from_server && is_admin_room(&pdu.room_id) { - return false; - } + // Only senders who are admin can proceed + if !self.user_is_admin(&pdu.sender).await.unwrap_or(false) { + return false; + } - // Authentic admin command - true -} + // This will evaluate to false if the emergency password is set up so that + // the administrator can execute commands as conduit + let emergency_password_set = self.globals.emergency_password().is_some(); + let from_server = pdu.sender == *server_user && !emergency_password_set; + if from_server && self.is_admin_room(&pdu.room_id) { + return false; + } -#[must_use] -pub fn is_admin_room(room_id: &RoomId) -> bool { - if let Ok(Some(admin_room_id)) = Service::get_admin_room() { - admin_room_id == room_id - } else { - false + // Authentic admin command + true + } + + #[must_use] + pub fn is_admin_room(&self, room_id: &RoomId) -> bool { + if let Ok(Some(admin_room_id)) = self.get_admin_room() { + admin_room_id == room_id + } else { + false + } } } diff --git a/src/service/globals/migrations.rs b/src/service/globals/migrations.rs index 3948d1f50..d8c5f29b8 100644 --- a/src/service/globals/migrations.rs +++ b/src/service/globals/migrations.rs @@ -59,7 +59,7 @@ async fn 
fresh(db: &Arc<Database>, config: &Config) -> Result<()> { db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", &[])?; // Create the admin room and server user on first run - crate::admin::create_admin_room().await?; + crate::admin::create_admin_room(services).await?; warn!( "Created new {} database with version {DATABASE_VERSION}", diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 0bc5ade16..4c5e407a4 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -38,7 +38,6 @@ use tokio::sync::RwLock; use crate::{ - admin, appservice::NamespaceRegex, pdu::{EventHash, PduBuilder}, rooms::{event_handler::parse_incoming_pdu, state_compressor::CompressedStateEvent}, @@ -485,7 +484,7 @@ pub async fn append_pdu( .search .index_pdu(shortroomid, &pdu_id, &body)?; - if admin::is_admin_command(pdu, &body).await { + if services().admin.is_admin_command(pdu, &body).await { services() .admin .command(body, Some((*pdu.event_id).into())) @@ -784,7 +783,7 @@ pub async fn build_and_append_pdu( state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<Arc<EventId>> { let (pdu, pdu_json) = self.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)?; - if let Some(admin_room) = admin::Service::get_admin_room()? { + if let Some(admin_room) = services().admin.get_admin_room()? { if admin_room == room_id { match pdu.event_type() { TimelineEventType::RoomEncryption => { diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 0196e1aa9..e0a4dd1c4 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -247,7 +247,7 @@ pub fn is_deactivated(&self, user_id: &UserId) -> Result<bool> { self.db.is_deac /// Check if a user is an admin pub fn is_admin(&self, user_id: &UserId) -> Result<bool> { - if let Some(admin_room_id) = crate::admin::Service::get_admin_room()? 
{ + if let Some(admin_room_id) = services().admin.get_admin_room()? { services() .rooms .state_cache -- GitLab From 010e4ee35a372018e56e9572a927cf06394c4291 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Thu, 18 Jul 2024 06:37:47 +0000 Subject: [PATCH 15/47] de-global services for services Signed-off-by: Jason Volk <jason@zemos.net> --- src/admin/debug/commands.rs | 9 +- src/admin/mod.rs | 1 + src/api/client/alias.rs | 12 +- src/api/client/membership.rs | 17 +- src/api/client/sync.rs | 6 +- src/api/mod.rs | 2 + src/api/server/backfill.rs | 3 +- src/api/server/event.rs | 3 +- src/api/server/event_auth.rs | 3 +- src/api/server/get_missing_events.rs | 3 +- src/api/server/invite.rs | 6 +- src/api/server/send.rs | 3 +- src/api/server/send_join.rs | 8 +- src/api/server/state.rs | 9 +- src/router/mod.rs | 2 + src/service/account_data/data.rs | 17 +- src/service/account_data/mod.rs | 2 +- src/service/admin/console.rs | 14 +- src/service/admin/grant.rs | 19 +- src/service/admin/mod.rs | 63 +- src/service/appservice/mod.rs | 14 +- src/service/client/mod.rs | 2 +- src/service/emergency/mod.rs | 83 +++ src/service/globals/data.rs | 41 +- src/service/globals/emerg_access.rs | 54 -- src/service/globals/migrations.rs | 274 ++++--- src/service/globals/mod.rs | 15 +- src/service/key_backups/data.rs | 23 +- src/service/key_backups/mod.rs | 2 +- src/service/manager.rs | 20 +- src/service/media/data.rs | 4 +- src/service/media/mod.rs | 22 +- src/service/mod.rs | 4 +- src/service/presence/data.rs | 24 +- src/service/presence/mod.rs | 38 +- src/service/pusher/mod.rs | 46 +- src/service/resolver/actual.rs | 107 ++- src/service/resolver/cache.rs | 3 +- src/service/resolver/mod.rs | 15 +- src/service/rooms/alias/data.rs | 17 +- src/service/rooms/alias/mod.rs | 84 ++- src/service/rooms/alias/remote.rs | 108 +-- src/service/rooms/auth_chain/data.rs | 9 +- src/service/rooms/auth_chain/mod.rs | 31 +- src/service/rooms/directory/mod.rs | 4 +- 
src/service/rooms/event_handler/mod.rs | 229 +++--- .../rooms/event_handler/parse_incoming_pdu.rs | 41 +- .../rooms/event_handler/signing_keys.rs | 57 +- src/service/rooms/lazy_loading/mod.rs | 4 +- src/service/rooms/metadata/data.rs | 17 +- src/service/rooms/metadata/mod.rs | 5 +- src/service/rooms/mod.rs | 4 +- src/service/rooms/pdu_metadata/data.rs | 21 +- src/service/rooms/pdu_metadata/mod.rs | 30 +- src/service/rooms/read_receipt/data.rs | 19 +- src/service/rooms/read_receipt/mod.rs | 14 +- src/service/rooms/search/data.rs | 19 +- src/service/rooms/search/mod.rs | 2 +- src/service/rooms/short/data.rs | 25 +- src/service/rooms/short/mod.rs | 5 +- src/service/rooms/spaces/mod.rs | 367 ++++----- src/service/rooms/state/mod.rs | 110 ++- src/service/rooms/state_accessor/data.rs | 60 +- src/service/rooms/state_accessor/mod.rs | 32 +- src/service/rooms/state_cache/data.rs | 61 +- src/service/rooms/state_cache/mod.rs | 55 +- src/service/rooms/state_compressor/mod.rs | 25 +- src/service/rooms/threads/data.rs | 27 +- src/service/rooms/threads/mod.rs | 39 +- src/service/rooms/timeline/data.rs | 99 +-- src/service/rooms/timeline/mod.rs | 303 ++++---- src/service/rooms/typing/mod.rs | 44 +- src/service/rooms/user/data.rs | 27 +- src/service/rooms/user/mod.rs | 5 +- src/service/sending/appservice.rs | 22 +- src/service/sending/data.rs | 17 +- src/service/sending/mod.rs | 65 +- src/service/sending/send.rs | 160 ++-- src/service/sending/sender.rs | 705 +++++++++--------- src/service/service.rs | 66 +- src/service/services.rs | 118 +-- src/service/uiaa/mod.rs | 23 +- src/service/updates/mod.rs | 17 +- src/service/users/data.rs | 83 ++- src/service/users/mod.rs | 23 +- 85 files changed, 2442 insertions(+), 1849 deletions(-) create mode 100644 src/service/emergency/mod.rs delete mode 100644 src/service/globals/emerg_access.rs diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index f0aa23cb6..cbe524732 100644 --- a/src/admin/debug/commands.rs +++ 
b/src/admin/debug/commands.rs @@ -15,7 +15,7 @@ events::room::message::RoomMessageEventContent, CanonicalJsonObject, EventId, OwnedRoomOrAliasId, RoomId, RoomVersionId, ServerName, }; -use service::{rooms::event_handler::parse_incoming_pdu, services, PduEvent}; +use service::services; use tokio::sync::RwLock; use tracing_subscriber::EnvFilter; @@ -189,7 +189,10 @@ pub(super) async fn get_remote_pdu( debug!("Attempting to parse PDU: {:?}", &response.pdu); let parsed_pdu = { - let parsed_result = parse_incoming_pdu(&response.pdu); + let parsed_result = services() + .rooms + .event_handler + .parse_incoming_pdu(&response.pdu); let (event_id, value, room_id) = match parsed_result { Ok(t) => t, Err(e) => { @@ -510,7 +513,7 @@ pub(super) async fn force_set_room_state_from_server( let mut events = Vec::with_capacity(remote_state_response.pdus.len()); for pdu in remote_state_response.pdus.clone() { - events.push(match parse_incoming_pdu(&pdu) { + events.push(match services().rooms.event_handler.parse_incoming_pdu(&pdu) { Ok(t) => t, Err(e) => { warn!("Could not parse PDU, ignoring: {e}"); diff --git a/src/admin/mod.rs b/src/admin/mod.rs index e020ed436..c57659c17 100644 --- a/src/admin/mod.rs +++ b/src/admin/mod.rs @@ -1,3 +1,4 @@ +#![recursion_limit = "168"] #![allow(clippy::wildcard_imports)] pub(crate) mod appservice; diff --git a/src/api/client/alias.rs b/src/api/client/alias.rs index 88d1a4e6b..11617a0e8 100644 --- a/src/api/client/alias.rs +++ b/src/api/client/alias.rs @@ -22,7 +22,11 @@ pub(crate) async fn create_alias_route( ) -> Result<create_alias::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - service::rooms::alias::appservice_checks(&body.room_alias, &body.appservice_info).await?; + services + .rooms + .alias + .appservice_checks(&body.room_alias, &body.appservice_info) + .await?; // this isn't apart of alias_checks or delete alias route because we should // allow removing forbidden room aliases @@ -61,7 +65,11 @@ 
pub(crate) async fn delete_alias_route( ) -> Result<delete_alias::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - service::rooms::alias::appservice_checks(&body.room_alias, &body.appservice_info).await?; + services + .rooms + .alias + .appservice_checks(&body.room_alias, &body.appservice_info) + .await?; if services .rooms diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 9fde99a4a..e3630c726 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -43,7 +43,6 @@ service::{ pdu::{gen_event_id_canonical_json, PduBuilder}, rooms::state::RoomMutexGuard, - sending::convert_to_outgoing_federation_event, server_is_ours, user_is_local, Services, }, Ruma, @@ -791,7 +790,9 @@ async fn join_room_by_id_helper_remote( federation::membership::create_join_event::v2::Request { room_id: room_id.to_owned(), event_id: event_id.to_owned(), - pdu: convert_to_outgoing_federation_event(join_event.clone()), + pdu: services + .sending + .convert_to_outgoing_federation_event(join_event.clone()), omit_members: false, }, ) @@ -1203,7 +1204,9 @@ async fn join_room_by_id_helper_local( federation::membership::create_join_event::v2::Request { room_id: room_id.to_owned(), event_id: event_id.to_owned(), - pdu: convert_to_outgoing_federation_event(join_event.clone()), + pdu: services + .sending + .convert_to_outgoing_federation_event(join_event.clone()), omit_members: false, }, ) @@ -1431,7 +1434,9 @@ pub(crate) async fn invite_helper( room_id: room_id.to_owned(), event_id: (*pdu.event_id).to_owned(), room_version: room_version_id.clone(), - event: convert_to_outgoing_federation_event(pdu_json.clone()), + event: services + .sending + .convert_to_outgoing_federation_event(pdu_json.clone()), invite_room_state, via: services.rooms.state_cache.servers_route_via(room_id).ok(), }, @@ -1763,7 +1768,9 @@ async fn remote_leave_room(services: &Services, user_id: &UserId, room_id: &Room 
federation::membership::create_leave_event::v2::Request { room_id: room_id.to_owned(), event_id, - pdu: convert_to_outgoing_federation_event(leave_event.clone()), + pdu: services + .sending + .convert_to_outgoing_federation_event(leave_event.clone()), }, ) .await?; diff --git a/src/api/client/sync.rs b/src/api/client/sync.rs index 5739052c3..6eeb8fffe 100644 --- a/src/api/client/sync.rs +++ b/src/api/client/sync.rs @@ -475,8 +475,6 @@ async fn handle_left_room( async fn process_presence_updates( services: &Services, presence_updates: &mut HashMap<OwnedUserId, PresenceEvent>, since: u64, syncing_user: &UserId, ) -> Result<()> { - use crate::service::presence::Presence; - // Take presence updates for (user_id, _, presence_bytes) in services.presence.presence_since(since) { if !services @@ -487,7 +485,9 @@ async fn process_presence_updates( continue; } - let presence_event = Presence::from_json_bytes_to_event(&presence_bytes, &user_id)?; + let presence_event = services + .presence + .from_json_bytes_to_event(&presence_bytes, &user_id)?; match presence_updates.entry(user_id) { Entry::Vacant(slot) => { slot.insert(presence_event); diff --git a/src/api/mod.rs b/src/api/mod.rs index 793829344..0d80e5814 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,3 +1,5 @@ +#![recursion_limit = "160"] + pub mod client; pub mod router; pub mod server; diff --git a/src/api/server/backfill.rs b/src/api/server/backfill.rs index 8dd38cad0..1b665c19d 100644 --- a/src/api/server/backfill.rs +++ b/src/api/server/backfill.rs @@ -4,7 +4,6 @@ api::{client::error::ErrorKind, federation::backfill::get_backfill}, uint, user_id, MilliSecondsSinceUnixEpoch, }; -use service::sending::convert_to_outgoing_federation_event; use crate::Ruma; @@ -67,7 +66,7 @@ pub(crate) async fn get_backfill_route( }) .map(|(_, pdu)| services.rooms.timeline.get_pdu_json(&pdu.event_id)) .filter_map(|r| r.ok().flatten()) - .map(convert_to_outgoing_federation_event) + .map(|pdu| 
services.sending.convert_to_outgoing_federation_event(pdu)) .collect(); Ok(get_backfill::v1::Response { diff --git a/src/api/server/event.rs b/src/api/server/event.rs index e8e08c817..e11a01a20 100644 --- a/src/api/server/event.rs +++ b/src/api/server/event.rs @@ -4,7 +4,6 @@ api::{client::error::ErrorKind, federation::event::get_event}, MilliSecondsSinceUnixEpoch, RoomId, }; -use service::sending::convert_to_outgoing_federation_event; use crate::Ruma; @@ -50,6 +49,6 @@ pub(crate) async fn get_event_route( Ok(get_event::v1::Response { origin: services.globals.server_name().to_owned(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(), - pdu: convert_to_outgoing_federation_event(event), + pdu: services.sending.convert_to_outgoing_federation_event(event), }) } diff --git a/src/api/server/event_auth.rs b/src/api/server/event_auth.rs index 8d26b73a4..4b0f6bc00 100644 --- a/src/api/server/event_auth.rs +++ b/src/api/server/event_auth.rs @@ -6,7 +6,6 @@ api::{client::error::ErrorKind, federation::authorization::get_event_authorization}, RoomId, }; -use service::sending::convert_to_outgoing_federation_event; use crate::Ruma; @@ -60,7 +59,7 @@ pub(crate) async fn get_event_authorization_route( Ok(get_event_authorization::v1::Response { auth_chain: auth_chain_ids .filter_map(|id| services.rooms.timeline.get_pdu_json(&id).ok()?) 
- .map(convert_to_outgoing_federation_event) + .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) .collect(), }) } diff --git a/src/api/server/get_missing_events.rs b/src/api/server/get_missing_events.rs index 378cd4fe3..e2c3c93cf 100644 --- a/src/api/server/get_missing_events.rs +++ b/src/api/server/get_missing_events.rs @@ -4,7 +4,6 @@ api::{client::error::ErrorKind, federation::event::get_missing_events}, OwnedEventId, RoomId, }; -use service::sending::convert_to_outgoing_federation_event; use crate::Ruma; @@ -82,7 +81,7 @@ pub(crate) async fn get_missing_events_route( ) .map_err(|_| Error::bad_database("Invalid prev_events in event in database."))?, ); - events.push(convert_to_outgoing_federation_event(pdu)); + events.push(services.sending.convert_to_outgoing_federation_event(pdu)); } i = i.saturating_add(1); } diff --git a/src/api/server/invite.rs b/src/api/server/invite.rs index 982a2a01e..17e219205 100644 --- a/src/api/server/invite.rs +++ b/src/api/server/invite.rs @@ -7,7 +7,7 @@ serde::JsonObject, CanonicalJsonValue, EventId, OwnedUserId, }; -use service::{sending::convert_to_outgoing_federation_event, server_is_ours}; +use service::server_is_ours; use crate::Ruma; @@ -174,6 +174,8 @@ pub(crate) async fn create_invite_route( } Ok(create_invite::v2::Response { - event: convert_to_outgoing_federation_event(signed_event), + event: services + .sending + .convert_to_outgoing_federation_event(signed_event), }) } diff --git a/src/api/server/send.rs b/src/api/server/send.rs index f29344802..2f698d337 100644 --- a/src/api/server/send.rs +++ b/src/api/server/send.rs @@ -21,7 +21,6 @@ use tokio::sync::RwLock; use crate::{ - service::rooms::event_handler::parse_incoming_pdu, services::Services, utils::{self}, Error, Result, Ruma, @@ -89,7 +88,7 @@ async fn handle_pdus( ) -> Result<ResolvedMap> { let mut parsed_pdus = Vec::with_capacity(body.pdus.len()); for pdu in &body.pdus { - parsed_pdus.push(match parse_incoming_pdu(pdu) { + 
parsed_pdus.push(match services.rooms.event_handler.parse_incoming_pdu(pdu) { Ok(t) => t, Err(e) => { debug_warn!("Could not parse PDU: {e}"); diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs index b72bfa035..7f79a1d99 100644 --- a/src/api/server/send_join.rs +++ b/src/api/server/send_join.rs @@ -13,9 +13,7 @@ CanonicalJsonValue, OwnedServerName, OwnedUserId, RoomId, ServerName, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; -use service::{ - pdu::gen_event_id_canonical_json, sending::convert_to_outgoing_federation_event, user_is_local, Services, -}; +use service::{pdu::gen_event_id_canonical_json, user_is_local, Services}; use tokio::sync::RwLock; use tracing::warn; @@ -186,12 +184,12 @@ async fn create_join_event( Ok(create_join_event::v1::RoomState { auth_chain: auth_chain_ids .filter_map(|id| services.rooms.timeline.get_pdu_json(&id).ok().flatten()) - .map(convert_to_outgoing_federation_event) + .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) .collect(), state: state_ids .iter() .filter_map(|(_, id)| services.rooms.timeline.get_pdu_json(id).ok().flatten()) - .map(convert_to_outgoing_federation_event) + .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) .collect(), // Event field is required if the room version supports restricted join rules. event: Some( diff --git a/src/api/server/state.rs b/src/api/server/state.rs index 24a11ccab..d215236af 100644 --- a/src/api/server/state.rs +++ b/src/api/server/state.rs @@ -3,7 +3,6 @@ use axum::extract::State; use conduit::{Error, Result}; use ruma::api::{client::error::ErrorKind, federation::event::get_room_state}; -use service::sending::convert_to_outgoing_federation_event; use crate::Ruma; @@ -44,7 +43,11 @@ pub(crate) async fn get_room_state_route( .state_full_ids(shortstatehash) .await? 
.into_values() - .map(|id| convert_to_outgoing_federation_event(services.rooms.timeline.get_pdu_json(&id).unwrap().unwrap())) + .map(|id| { + services + .sending + .convert_to_outgoing_federation_event(services.rooms.timeline.get_pdu_json(&id).unwrap().unwrap()) + }) .collect(); let auth_chain_ids = services @@ -61,7 +64,7 @@ pub(crate) async fn get_room_state_route( .timeline .get_pdu_json(&id) .ok()? - .map(convert_to_outgoing_federation_event) + .map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) }) .collect(), pdus, diff --git a/src/router/mod.rs b/src/router/mod.rs index e9bae3c5f..03c70f6d7 100644 --- a/src/router/mod.rs +++ b/src/router/mod.rs @@ -1,3 +1,5 @@ +#![recursion_limit = "160"] + mod layers; mod request; mod router; diff --git a/src/service/account_data/data.rs b/src/service/account_data/data.rs index 7b3a3deed..439603be1 100644 --- a/src/service/account_data/data.rs +++ b/src/service/account_data/data.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, sync::Arc}; use conduit::{utils, warn, Error, Result}; -use database::{Database, Map}; +use database::Map; use ruma::{ api::client::error::ErrorKind, events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, @@ -9,18 +9,27 @@ RoomId, UserId, }; -use crate::services; +use crate::{globals, Dep}; pub(super) struct Data { roomuserdataid_accountdata: Arc<Map>, roomusertype_roomuserdataid: Arc<Map>, + services: Services, +} + +struct Services { + globals: Dep<globals::Service>, } impl Data { - pub(super) fn new(db: &Arc<Database>) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { roomuserdataid_accountdata: db["roomuserdataid_accountdata"].clone(), roomusertype_roomuserdataid: db["roomusertype_roomuserdataid"].clone(), + services: Services { + globals: args.depend::<globals::Service>("globals"), + }, } } @@ -40,7 +49,7 @@ pub(super) fn update( prefix.push(0xFF); let mut roomuserdataid = prefix.clone(); - 
roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + roomuserdataid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); roomuserdataid.push(0xFF); roomuserdataid.extend_from_slice(event_type.to_string().as_bytes()); diff --git a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs index 69d2f799a..c569889e6 100644 --- a/src/service/account_data/mod.rs +++ b/src/service/account_data/mod.rs @@ -17,7 +17,7 @@ pub struct Service { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { - db: Data::new(args.db), + db: Data::new(&args), })) } diff --git a/src/service/admin/console.rs b/src/service/admin/console.rs index 9c3357917..c9a288d94 100644 --- a/src/service/admin/console.rs +++ b/src/service/admin/console.rs @@ -11,10 +11,11 @@ use termimad::MadSkin; use tokio::task::JoinHandle; -use crate::services; +use crate::{admin, Dep}; pub struct Console { server: Arc<Server>, + admin: Dep<admin::Service>, worker_join: Mutex<Option<JoinHandle<()>>>, input_abort: Mutex<Option<AbortHandle>>, command_abort: Mutex<Option<AbortHandle>>, @@ -29,6 +30,7 @@ impl Console { pub(super) fn new(args: &crate::Args<'_>) -> Arc<Self> { Arc::new(Self { server: args.server.clone(), + admin: args.depend::<admin::Service>("admin"), worker_join: None.into(), input_abort: None.into(), command_abort: None.into(), @@ -116,7 +118,8 @@ async fn readline(self: &Arc<Self>) -> Result<ReadlineEvent, ReadlineError> { let _suppression = log::Suppress::new(&self.server); let (mut readline, _writer) = Readline::new(PROMPT.to_owned())?; - readline.set_tab_completer(Self::tab_complete); + let self_ = Arc::clone(self); + readline.set_tab_completer(move |line| self_.tab_complete(line)); self.set_history(&mut readline); let future = readline.readline(); @@ -154,7 +157,7 @@ async fn handle(self: Arc<Self>, line: String) { } async fn process(self: Arc<Self>, line: String) { - match 
services().admin.command_in_place(line, None).await { + match self.admin.command_in_place(line, None).await { Ok(Some(content)) => self.output(content).await, Err(e) => error!("processing command: {e}"), _ => (), @@ -184,9 +187,8 @@ fn add_history(&self, line: String) { history.truncate(HISTORY_LIMIT); } - fn tab_complete(line: &str) -> String { - services() - .admin + fn tab_complete(&self, line: &str) -> String { + self.admin .complete_command(line) .unwrap_or_else(|| line.to_owned()) } diff --git a/src/service/admin/grant.rs b/src/service/admin/grant.rs index 213225ace..c35f8c421 100644 --- a/src/service/admin/grant.rs +++ b/src/service/admin/grant.rs @@ -25,13 +25,14 @@ pub async fn make_user_admin(&self, user_id: &UserId, displayname: String) -> Re return Ok(()); }; - let state_lock = self.state.mutex.lock(&room_id).await; + let state_lock = self.services.state.mutex.lock(&room_id).await; // Use the server user to grant the new admin's power level - let server_user = &self.globals.server_user; + let server_user = &self.services.globals.server_user; // Invite and join the real user - self.timeline + self.services + .timeline .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomMember, @@ -55,7 +56,8 @@ pub async fn make_user_admin(&self, user_id: &UserId, displayname: String) -> Re &state_lock, ) .await?; - self.timeline + self.services + .timeline .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomMember, @@ -85,7 +87,8 @@ pub async fn make_user_admin(&self, user_id: &UserId, displayname: String) -> Re users.insert(server_user.clone(), 100.into()); users.insert(user_id.to_owned(), 100.into()); - self.timeline + self.services + .timeline .build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomPowerLevels, @@ -105,12 +108,12 @@ pub async fn make_user_admin(&self, user_id: &UserId, displayname: String) -> Re .await?; // Send welcome message - self.timeline.build_and_append_pdu( + 
self.services.timeline.build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomMessage, content: to_raw_value(&RoomMessageEventContent::text_html( - format!("## Thank you for trying out conduwuit!\n\nconduwuit is a fork of upstream Conduit which is in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.\n\nHelpful links:\n> Git and Documentation: https://github.com/girlbossceo/conduwuit\n> Report issues: https://github.com/girlbossceo/conduwuit/issues\n\nFor a list of available commands, send the following message in this room: `@conduit:{}: --help`\n\nHere are some rooms you can join (by typing the command):\n\nconduwuit room (Ask questions and get notified on updates):\n`/join #conduwuit:puppygock.gay`", self.globals.server_name()), - format!("<h2>Thank you for trying out conduwuit!</h2>\n<p>conduwuit is a fork of upstream Conduit which is in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.</p>\n<p>Helpful links:</p>\n<blockquote>\n<p>Git and Documentation: https://github.com/girlbossceo/conduwuit<br>Report issues: https://github.com/girlbossceo/conduwuit/issues</p>\n</blockquote>\n<p>For a list of available commands, send the following message in this room: <code>@conduit:{}: --help</code></p>\n<p>Here are some rooms you can join (by typing the command):</p>\n<p>conduwuit room (Ask questions and get notified on updates):<br><code>/join #conduwuit:puppygock.gay</code></p>\n", self.globals.server_name()), + format!("## Thank you for trying out conduwuit!\n\nconduwuit is a fork of upstream Conduit which is in Beta. 
This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.\n\nHelpful links:\n> Git and Documentation: https://github.com/girlbossceo/conduwuit\n> Report issues: https://github.com/girlbossceo/conduwuit/issues\n\nFor a list of available commands, send the following message in this room: `@conduit:{}: --help`\n\nHere are some rooms you can join (by typing the command):\n\nconduwuit room (Ask questions and get notified on updates):\n`/join #conduwuit:puppygock.gay`", self.services.globals.server_name()), + format!("<h2>Thank you for trying out conduwuit!</h2>\n<p>conduwuit is a fork of upstream Conduit which is in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.</p>\n<p>Helpful links:</p>\n<blockquote>\n<p>Git and Documentation: https://github.com/girlbossceo/conduwuit<br>Report issues: https://github.com/girlbossceo/conduwuit/issues</p>\n</blockquote>\n<p>For a list of available commands, send the following message in this room: <code>@conduit:{}: --help</code></p>\n<p>Here are some rooms you can join (by typing the command):</p>\n<p>conduwuit room (Ask questions and get notified on updates):<br><code>/join #conduwuit:puppygock.gay</code></p>\n", self.services.globals.server_name()), )) .expect("event is valid, we just created it"), unsigned: None, diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index fcb342129..8b9473a20 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -22,15 +22,10 @@ use serde_json::value::to_raw_value; use tokio::sync::{Mutex, RwLock}; -use crate::{globals, rooms, rooms::state::RoomMutexGuard, user_is_local}; +use crate::{globals, rooms, rooms::state::RoomMutexGuard, user_is_local, Dep}; pub struct Service { - server: Arc<Server>, - globals: Arc<globals::Service>, - alias: Arc<rooms::alias::Service>, - timeline: 
Arc<rooms::timeline::Service>, - state: Arc<rooms::state::Service>, - state_cache: Arc<rooms::state_cache::Service>, + services: Services, sender: Sender<Command>, receiver: Mutex<Receiver<Command>>, pub handle: RwLock<Option<Handler>>, @@ -39,6 +34,15 @@ pub struct Service { pub console: Arc<console::Console>, } +struct Services { + server: Arc<Server>, + globals: Dep<globals::Service>, + alias: Dep<rooms::alias::Service>, + timeline: Dep<rooms::timeline::Service>, + state: Dep<rooms::state::Service>, + state_cache: Dep<rooms::state_cache::Service>, +} + #[derive(Debug)] pub struct Command { pub command: String, @@ -58,12 +62,14 @@ impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { let (sender, receiver) = loole::bounded(COMMAND_QUEUE_LIMIT); Ok(Arc::new(Self { - server: args.server.clone(), - globals: args.require_service::<globals::Service>("globals"), - alias: args.require_service::<rooms::alias::Service>("rooms::alias"), - timeline: args.require_service::<rooms::timeline::Service>("rooms::timeline"), - state: args.require_service::<rooms::state::Service>("rooms::state"), - state_cache: args.require_service::<rooms::state_cache::Service>("rooms::state_cache"), + services: Services { + server: args.server.clone(), + globals: args.depend::<globals::Service>("globals"), + alias: args.depend::<rooms::alias::Service>("rooms::alias"), + timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"), + state: args.depend::<rooms::state::Service>("rooms::state"), + state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"), + }, sender, receiver: Mutex::new(receiver), handle: RwLock::new(None), @@ -75,7 +81,7 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { async fn worker(self: Arc<Self>) -> Result<()> { let receiver = self.receiver.lock().await; - let mut signals = self.server.signal.subscribe(); + let mut signals = self.services.server.signal.subscribe(); loop { tokio::select! 
{ command = receiver.recv_async() => match command { @@ -116,7 +122,7 @@ pub async fn send_text(&self, body: &str) { pub async fn send_message(&self, message_content: RoomMessageEventContent) { if let Ok(Some(room_id)) = self.get_admin_room() { - let user_id = &self.globals.server_user; + let user_id = &self.services.globals.server_user; self.respond_to_room(message_content, &room_id, user_id) .await; } @@ -176,7 +182,7 @@ async fn process_command(&self, command: Command) -> CommandResult { /// Checks whether a given user is an admin of this server pub async fn user_is_admin(&self, user_id: &UserId) -> Result<bool> { if let Ok(Some(admin_room)) = self.get_admin_room() { - self.state_cache.is_joined(user_id, &admin_room) + self.services.state_cache.is_joined(user_id, &admin_room) } else { Ok(false) } @@ -187,10 +193,15 @@ pub async fn user_is_admin(&self, user_id: &UserId) -> Result<bool> { /// Errors are propagated from the database, and will have None if there is /// no admin room pub fn get_admin_room(&self) -> Result<Option<OwnedRoomId>> { - if let Some(room_id) = self.alias.resolve_local_alias(&self.globals.admin_alias)? { + if let Some(room_id) = self + .services + .alias + .resolve_local_alias(&self.services.globals.admin_alias)? + { if self + .services .state_cache - .is_joined(&self.globals.server_user, &room_id)? + .is_joined(&self.services.globals.server_user, &room_id)? 
{ return Ok(Some(room_id)); } @@ -207,12 +218,12 @@ async fn handle_response(&self, content: RoomMessageEventContent) { return; }; - let Ok(Some(pdu)) = self.timeline.get_pdu(&in_reply_to.event_id) else { + let Ok(Some(pdu)) = self.services.timeline.get_pdu(&in_reply_to.event_id) else { return; }; let response_sender = if self.is_admin_room(&pdu.room_id) { - &self.globals.server_user + &self.services.globals.server_user } else { &pdu.sender }; @@ -229,7 +240,7 @@ async fn respond_to_room(&self, content: RoomMessageEventContent, room_id: &Room "sender is not admin" ); - let state_lock = self.state.mutex.lock(room_id).await; + let state_lock = self.services.state.mutex.lock(room_id).await; let response_pdu = PduBuilder { event_type: TimelineEventType::RoomMessage, content: to_raw_value(&content).expect("event is valid, we just created it"), @@ -239,6 +250,7 @@ async fn respond_to_room(&self, content: RoomMessageEventContent, room_id: &Room }; if let Err(e) = self + .services .timeline .build_and_append_pdu(response_pdu, user_id, room_id, &state_lock) .await @@ -266,7 +278,8 @@ async fn handle_response_error( redacts: None, }; - self.timeline + self.services + .timeline .build_and_append_pdu(response_pdu, user_id, room_id, state_lock) .await?; @@ -279,7 +292,7 @@ pub async fn is_admin_command(&self, pdu: &PduEvent, body: &str) -> bool { let is_public_escape = is_escape && body.trim_start_matches('\\').starts_with("!admin"); // Admin command with public echo (in admin room) - let server_user = &self.globals.server_user; + let server_user = &self.services.globals.server_user; let is_public_prefix = body.starts_with("!admin") || body.starts_with(server_user.as_str()); // Expected backward branch @@ -293,7 +306,7 @@ pub async fn is_admin_command(&self, pdu: &PduEvent, body: &str) -> bool { } // Check if server-side command-escape is disabled by configuration - if is_public_escape && !self.globals.config.admin_escape_commands { + if is_public_escape && 
!self.services.globals.config.admin_escape_commands { return false; } @@ -309,7 +322,7 @@ pub async fn is_admin_command(&self, pdu: &PduEvent, body: &str) -> bool { // This will evaluate to false if the emergency password is set up so that // the administrator can execute commands as conduit - let emergency_password_set = self.globals.emergency_password().is_some(); + let emergency_password_set = self.services.globals.emergency_password().is_some(); let from_server = pdu.sender == *server_user && !emergency_password_set; if from_server && self.is_admin_room(&pdu.room_id) { return false; diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index 24c9b8b07..c0752d565 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -12,7 +12,7 @@ }; use tokio::sync::RwLock; -use crate::services; +use crate::{sending, Dep}; /// Compiled regular expressions for a namespace #[derive(Clone, Debug)] @@ -118,9 +118,14 @@ fn try_from(value: Registration) -> Result<Self, regex::Error> { pub struct Service { pub db: Data, + services: Services, registration_info: RwLock<BTreeMap<String, RegistrationInfo>>, } +struct Services { + sending: Dep<sending::Service>, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { let mut registration_info = BTreeMap::new(); @@ -138,6 +143,9 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { db, + services: Services { + sending: args.depend::<sending::Service>("sending"), + }, registration_info: RwLock::new(registration_info), })) } @@ -178,7 +186,9 @@ pub async fn unregister_appservice(&self, service_name: &str) -> Result<()> { // deletes all active requests for the appservice if there are any so we stop // sending to the URL - services().sending.cleanup_events(service_name.to_owned())?; + self.services + .sending + .cleanup_events(service_name.to_owned())?; Ok(()) } diff --git a/src/service/client/mod.rs b/src/service/client/mod.rs index 
03b0a1425..386bd33ca 100644 --- a/src/service/client/mod.rs +++ b/src/service/client/mod.rs @@ -18,7 +18,7 @@ pub struct Service { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { let config = &args.server.config; - let resolver = args.require_service::<resolver::Service>("resolver"); + let resolver = args.require::<resolver::Service>("resolver"); Ok(Arc::new(Self { default: base(config) diff --git a/src/service/emergency/mod.rs b/src/service/emergency/mod.rs new file mode 100644 index 000000000..1bb0843d4 --- /dev/null +++ b/src/service/emergency/mod.rs @@ -0,0 +1,83 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use conduit::{error, warn, Result}; +use ruma::{ + events::{push_rules::PushRulesEventContent, GlobalAccountDataEvent, GlobalAccountDataEventType}, + push::Ruleset, +}; + +use crate::{account_data, globals, users, Dep}; + +pub struct Service { + services: Services, +} + +struct Services { + account_data: Dep<account_data::Service>, + globals: Dep<globals::Service>, + users: Dep<users::Service>, +} + +#[async_trait] +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { + Ok(Arc::new(Self { + services: Services { + account_data: args.depend::<account_data::Service>("account_data"), + globals: args.depend::<globals::Service>("globals"), + users: args.depend::<users::Service>("users"), + }, + })) + } + + async fn worker(self: Arc<Self>) -> Result<()> { + self.set_emergency_access() + .inspect_err(|e| error!("Could not set the configured emergency password for the conduit user: {e}"))?; + + Ok(()) + } + + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { + /// Sets the emergency password and push rules for the @conduit account in + /// case emergency password is set + fn set_emergency_access(&self) -> Result<bool> { + let conduit_user = &self.services.globals.server_user; + + self.services + .users + .set_password(conduit_user, 
self.services.globals.emergency_password().as_deref())?; + + let (ruleset, pwd_set) = match self.services.globals.emergency_password() { + Some(_) => (Ruleset::server_default(conduit_user), true), + None => (Ruleset::new(), false), + }; + + self.services.account_data.update( + None, + conduit_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(&GlobalAccountDataEvent { + content: PushRulesEventContent { + global: ruleset, + }, + }) + .expect("to json value always works"), + )?; + + if pwd_set { + warn!( + "The server account emergency password is set! Please unset it as soon as you finish admin account \ + recovery! You will be logged out of the server service account when you finish." + ); + } else { + // logs out any users still in the server service account and removes sessions + self.services.users.deactivate_account(conduit_user)?; + } + + Ok(pwd_set) + } +} diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 281c2a94a..5d6240cd4 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -3,7 +3,7 @@ sync::{Arc, RwLock}, }; -use conduit::{trace, utils, Error, Result}; +use conduit::{trace, utils, Error, Result, Server}; use database::{Database, Map}; use futures_util::{stream::FuturesUnordered, StreamExt}; use ruma::{ @@ -12,7 +12,7 @@ DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId, }; -use crate::services; +use crate::{rooms, Dep}; pub struct Data { global: Arc<Map>, @@ -28,14 +28,23 @@ pub struct Data { server_signingkeys: Arc<Map>, readreceiptid_readreceipt: Arc<Map>, userid_lastonetimekeyupdate: Arc<Map>, - pub(super) db: Arc<Database>, counter: RwLock<u64>, + pub(super) db: Arc<Database>, + services: Services, +} + +struct Services { + server: Arc<Server>, + short: Dep<rooms::short::Service>, + state_cache: Dep<rooms::state_cache::Service>, + typing: Dep<rooms::typing::Service>, } const COUNTER: &[u8] = b"c"; impl Data { - pub(super) fn 
new(db: &Arc<Database>) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { global: db["global"].clone(), todeviceid_events: db["todeviceid_events"].clone(), @@ -50,8 +59,14 @@ pub(super) fn new(db: &Arc<Database>) -> Self { server_signingkeys: db["server_signingkeys"].clone(), readreceiptid_readreceipt: db["readreceiptid_readreceipt"].clone(), userid_lastonetimekeyupdate: db["userid_lastonetimekeyupdate"].clone(), - db: db.clone(), counter: RwLock::new(Self::stored_count(&db["global"]).expect("initialized global counter")), + db: args.db.clone(), + services: Services { + server: args.server.clone(), + short: args.depend::<rooms::short::Service>("rooms::short"), + state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"), + typing: args.depend::<rooms::typing::Service>("rooms::typing"), + }, } } @@ -118,14 +133,14 @@ pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix)); // Events for rooms we are in - for room_id in services() - .rooms + for room_id in self + .services .state_cache .rooms_joined(user_id) .filter_map(Result::ok) { - let short_roomid = services() - .rooms + let short_roomid = self + .services .short .get_shortroomid(&room_id) .ok() @@ -143,7 +158,7 @@ pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> // EDUs futures.push(Box::pin(async move { - let _result = services().rooms.typing.wait_for_update(&room_id).await; + let _result = self.services.typing.wait_for_update(&room_id).await; })); futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix)); @@ -176,12 +191,12 @@ pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes)); futures.push(Box::pin(async move { - while services().server.running() { - let _result = 
services().server.signal.subscribe().recv().await; + while self.services.server.running() { + let _result = self.services.server.signal.subscribe().recv().await; } })); - if !services().server.running() { + if !self.services.server.running() { return Ok(()); } diff --git a/src/service/globals/emerg_access.rs b/src/service/globals/emerg_access.rs deleted file mode 100644 index 50c5f8c3d..000000000 --- a/src/service/globals/emerg_access.rs +++ /dev/null @@ -1,54 +0,0 @@ -use conduit::Result; -use ruma::{ - events::{push_rules::PushRulesEventContent, GlobalAccountDataEvent, GlobalAccountDataEventType}, - push::Ruleset, -}; -use tracing::{error, warn}; - -use crate::services; - -/// Set emergency access for the conduit user -pub(crate) fn init_emergency_access() { - if let Err(e) = set_emergency_access() { - error!("Could not set the configured emergency password for the conduit user: {e}"); - } -} - -/// Sets the emergency password and push rules for the @conduit account in case -/// emergency password is set -fn set_emergency_access() -> Result<bool> { - let conduit_user = &services().globals.server_user; - - services() - .users - .set_password(conduit_user, services().globals.emergency_password().as_deref())?; - - let (ruleset, pwd_set) = match services().globals.emergency_password() { - Some(_) => (Ruleset::server_default(conduit_user), true), - None => (Ruleset::new(), false), - }; - - services().account_data.update( - None, - conduit_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(&GlobalAccountDataEvent { - content: PushRulesEventContent { - global: ruleset, - }, - }) - .expect("to json value always works"), - )?; - - if pwd_set { - warn!( - "The server account emergency password is set! Please unset it as soon as you finish admin account \ - recovery! You will be logged out of the server service account when you finish." 
- ); - } else { - // logs out any users still in the server service account and removes sessions - services().users.deactivate_account(conduit_user)?; - } - - Ok(pwd_set) -} diff --git a/src/service/globals/migrations.rs b/src/service/globals/migrations.rs index d8c5f29b8..2fe22b0e0 100644 --- a/src/service/globals/migrations.rs +++ b/src/service/globals/migrations.rs @@ -10,7 +10,6 @@ }; use conduit::{debug, debug_info, debug_warn, error, info, utils, warn, Config, Error, Result}; -use database::Database; use itertools::Itertools; use ruma::{ events::{push_rules::PushRulesEvent, room::member::MembershipState, GlobalAccountDataEventType}, @@ -18,7 +17,7 @@ EventId, OwnedRoomId, RoomId, UserId, }; -use crate::services; +use crate::Services; /// The current schema version. /// - If database is opened at greater version we reject with error. The @@ -28,13 +27,13 @@ /// equal or lesser version. These are expected to be backward-compatible. const DATABASE_VERSION: u64 = 13; -pub(crate) async fn migrations(db: &Arc<Database>, config: &Config) -> Result<()> { +pub(crate) async fn migrations(services: &Services) -> Result<()> { // Matrix resource ownership is based on the server name; changing it // requires recreating the database from scratch. - if services().users.count()? > 0 { - let conduit_user = &services().globals.server_user; + if services.users.count()? > 0 { + let conduit_user = &services.globals.server_user; - if !services().users.exists(conduit_user)? { + if !services.users.exists(conduit_user)? { error!("The {} server user does not exist, and the database is not new.", conduit_user); return Err(Error::bad_database( "Cannot reuse an existing database after changing the server name, please delete the old one first.", @@ -42,15 +41,18 @@ pub(crate) async fn migrations(db: &Arc<Database>, config: &Config) -> Result<() } } - if services().users.count()? > 0 { - migrate(db, config).await + if services.users.count()? 
> 0 { + migrate(services).await } else { - fresh(db, config).await + fresh(services).await } } -async fn fresh(db: &Arc<Database>, config: &Config) -> Result<()> { - services() +async fn fresh(services: &Services) -> Result<()> { + let db = &services.db; + let config = &services.server.config; + + services .globals .db .bump_database_version(DATABASE_VERSION)?; @@ -70,97 +72,100 @@ async fn fresh(db: &Arc<Database>, config: &Config) -> Result<()> { } /// Apply any migrations -async fn migrate(db: &Arc<Database>, config: &Config) -> Result<()> { - if services().globals.db.database_version()? < 1 { - db_lt_1(db, config).await?; +async fn migrate(services: &Services) -> Result<()> { + let db = &services.db; + let config = &services.server.config; + + if services.globals.db.database_version()? < 1 { + db_lt_1(services).await?; } - if services().globals.db.database_version()? < 2 { - db_lt_2(db, config).await?; + if services.globals.db.database_version()? < 2 { + db_lt_2(services).await?; } - if services().globals.db.database_version()? < 3 { - db_lt_3(db, config).await?; + if services.globals.db.database_version()? < 3 { + db_lt_3(services).await?; } - if services().globals.db.database_version()? < 4 { - db_lt_4(db, config).await?; + if services.globals.db.database_version()? < 4 { + db_lt_4(services).await?; } - if services().globals.db.database_version()? < 5 { - db_lt_5(db, config).await?; + if services.globals.db.database_version()? < 5 { + db_lt_5(services).await?; } - if services().globals.db.database_version()? < 6 { - db_lt_6(db, config).await?; + if services.globals.db.database_version()? < 6 { + db_lt_6(services).await?; } - if services().globals.db.database_version()? < 7 { - db_lt_7(db, config).await?; + if services.globals.db.database_version()? < 7 { + db_lt_7(services).await?; } - if services().globals.db.database_version()? < 8 { - db_lt_8(db, config).await?; + if services.globals.db.database_version()? 
< 8 { + db_lt_8(services).await?; } - if services().globals.db.database_version()? < 9 { - db_lt_9(db, config).await?; + if services.globals.db.database_version()? < 9 { + db_lt_9(services).await?; } - if services().globals.db.database_version()? < 10 { - db_lt_10(db, config).await?; + if services.globals.db.database_version()? < 10 { + db_lt_10(services).await?; } - if services().globals.db.database_version()? < 11 { - db_lt_11(db, config).await?; + if services.globals.db.database_version()? < 11 { + db_lt_11(services).await?; } - if services().globals.db.database_version()? < 12 { - db_lt_12(db, config).await?; + if services.globals.db.database_version()? < 12 { + db_lt_12(services).await?; } // This migration can be reused as-is anytime the server-default rules are // updated. - if services().globals.db.database_version()? < 13 { - db_lt_13(db, config).await?; + if services.globals.db.database_version()? < 13 { + db_lt_13(services).await?; } if db["global"].get(b"feat_sha256_media")?.is_none() { - migrate_sha256_media(db, config).await?; + migrate_sha256_media(services).await?; } else if config.media_startup_check { - checkup_sha256_media(db, config).await?; + checkup_sha256_media(services).await?; } if db["global"] .get(b"fix_bad_double_separator_in_state_cache")? .is_none() { - fix_bad_double_separator_in_state_cache(db, config).await?; + fix_bad_double_separator_in_state_cache(services).await?; } if db["global"] .get(b"retroactively_fix_bad_data_from_roomuserid_joined")? 
.is_none() { - retroactively_fix_bad_data_from_roomuserid_joined(db, config).await?; + retroactively_fix_bad_data_from_roomuserid_joined(services).await?; } assert_eq!( - services().globals.db.database_version().unwrap(), + services.globals.db.database_version().unwrap(), DATABASE_VERSION, "Failed asserting local database version {} is equal to known latest conduwuit database version {}", - services().globals.db.database_version().unwrap(), + services.globals.db.database_version().unwrap(), DATABASE_VERSION, ); { - let patterns = services().globals.forbidden_usernames(); + let patterns = services.globals.forbidden_usernames(); if !patterns.is_empty() { - for user_id in services() + for user_id in services .users .iter() .filter_map(Result::ok) - .filter(|user| !services().users.is_deactivated(user).unwrap_or(true)) + .filter(|user| !services.users.is_deactivated(user).unwrap_or(true)) .filter(|user| user.server_name() == config.server_name) { let matches = patterns.matches(user_id.localpart()); @@ -179,11 +184,11 @@ async fn migrate(db: &Arc<Database>, config: &Config) -> Result<()> { } { - let patterns = services().globals.forbidden_alias_names(); + let patterns = services.globals.forbidden_alias_names(); if !patterns.is_empty() { - for address in services().rooms.metadata.iter_ids() { + for address in services.rooms.metadata.iter_ids() { let room_id = address?; - let room_aliases = services().rooms.alias.local_aliases_for_room(&room_id); + let room_aliases = services.rooms.alias.local_aliases_for_room(&room_id); for room_alias_result in room_aliases { let room_alias = room_alias_result?; let matches = patterns.matches(room_alias.alias()); @@ -211,7 +216,9 @@ async fn migrate(db: &Arc<Database>, config: &Config) -> Result<()> { Ok(()) } -async fn db_lt_1(db: &Arc<Database>, _config: &Config) -> Result<()> { +async fn db_lt_1(services: &Services) -> Result<()> { + let db = &services.db; + let roomserverids = &db["roomserverids"]; let serverroomids = 
&db["serverroomids"]; for (roomserverid, _) in roomserverids.iter() { @@ -228,12 +235,14 @@ async fn db_lt_1(db: &Arc<Database>, _config: &Config) -> Result<()> { serverroomids.insert(&serverroomid, &[])?; } - services().globals.db.bump_database_version(1)?; + services.globals.db.bump_database_version(1)?; info!("Migration: 0 -> 1 finished"); Ok(()) } -async fn db_lt_2(db: &Arc<Database>, _config: &Config) -> Result<()> { +async fn db_lt_2(services: &Services) -> Result<()> { + let db = &services.db; + // We accidentally inserted hashed versions of "" into the db instead of just "" let userid_password = &db["roomserverids"]; for (userid, password) in userid_password.iter() { @@ -245,12 +254,14 @@ async fn db_lt_2(db: &Arc<Database>, _config: &Config) -> Result<()> { } } - services().globals.db.bump_database_version(2)?; + services.globals.db.bump_database_version(2)?; info!("Migration: 1 -> 2 finished"); Ok(()) } -async fn db_lt_3(db: &Arc<Database>, _config: &Config) -> Result<()> { +async fn db_lt_3(services: &Services) -> Result<()> { + let db = &services.db; + // Move media to filesystem let mediaid_file = &db["mediaid_file"]; for (key, content) in mediaid_file.iter() { @@ -259,41 +270,45 @@ async fn db_lt_3(db: &Arc<Database>, _config: &Config) -> Result<()> { } #[allow(deprecated)] - let path = services().media.get_media_file(&key); + let path = services.media.get_media_file(&key); let mut file = fs::File::create(path)?; file.write_all(&content)?; mediaid_file.insert(&key, &[])?; } - services().globals.db.bump_database_version(3)?; + services.globals.db.bump_database_version(3)?; info!("Migration: 2 -> 3 finished"); Ok(()) } -async fn db_lt_4(_db: &Arc<Database>, config: &Config) -> Result<()> { - // Add federated users to services() as deactivated - for our_user in services().users.iter() { +async fn db_lt_4(services: &Services) -> Result<()> { + let config = &services.server.config; + + // Add federated users to services as deactivated + for our_user in 
services.users.iter() { let our_user = our_user?; - if services().users.is_deactivated(&our_user)? { + if services.users.is_deactivated(&our_user)? { continue; } - for room in services().rooms.state_cache.rooms_joined(&our_user) { - for user in services().rooms.state_cache.room_members(&room?) { + for room in services.rooms.state_cache.rooms_joined(&our_user) { + for user in services.rooms.state_cache.room_members(&room?) { let user = user?; if user.server_name() != config.server_name { info!(?user, "Migration: creating user"); - services().users.create(&user, None)?; + services.users.create(&user, None)?; } } } } - services().globals.db.bump_database_version(4)?; + services.globals.db.bump_database_version(4)?; info!("Migration: 3 -> 4 finished"); Ok(()) } -async fn db_lt_5(db: &Arc<Database>, _config: &Config) -> Result<()> { +async fn db_lt_5(services: &Services) -> Result<()> { + let db = &services.db; + // Upgrade user data store let roomuserdataid_accountdata = &db["roomuserdataid_accountdata"]; let roomusertype_roomuserdataid = &db["roomusertype_roomuserdataid"]; @@ -312,26 +327,30 @@ async fn db_lt_5(db: &Arc<Database>, _config: &Config) -> Result<()> { roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?; } - services().globals.db.bump_database_version(5)?; + services.globals.db.bump_database_version(5)?; info!("Migration: 4 -> 5 finished"); Ok(()) } -async fn db_lt_6(db: &Arc<Database>, _config: &Config) -> Result<()> { +async fn db_lt_6(services: &Services) -> Result<()> { + let db = &services.db; + // Set room member count let roomid_shortstatehash = &db["roomid_shortstatehash"]; for (roomid, _) in roomid_shortstatehash.iter() { let string = utils::string_from_bytes(&roomid).unwrap(); let room_id = <&RoomId>::try_from(string.as_str()).unwrap(); - services().rooms.state_cache.update_joined_count(room_id)?; + services.rooms.state_cache.update_joined_count(room_id)?; } - services().globals.db.bump_database_version(6)?; + 
services.globals.db.bump_database_version(6)?; info!("Migration: 5 -> 6 finished"); Ok(()) } -async fn db_lt_7(db: &Arc<Database>, _config: &Config) -> Result<()> { +async fn db_lt_7(services: &Services) -> Result<()> { + let db = &services.db; + // Upgrade state store let mut last_roomstates: HashMap<OwnedRoomId, u64> = HashMap::new(); let mut current_sstatehash: Option<u64> = None; @@ -347,7 +366,7 @@ async fn db_lt_7(db: &Arc<Database>, _config: &Config) -> Result<()> { let states_parents = last_roomsstatehash.map_or_else( || Ok(Vec::new()), |&last_roomsstatehash| { - services() + services .rooms .state_compressor .load_shortstatehash_info(last_roomsstatehash) @@ -371,7 +390,7 @@ async fn db_lt_7(db: &Arc<Database>, _config: &Config) -> Result<()> { (current_state, HashSet::new()) }; - services().rooms.state_compressor.save_state_from_diff( + services.rooms.state_compressor.save_state_from_diff( current_sstatehash, Arc::new(statediffnew), Arc::new(statediffremoved), @@ -380,7 +399,7 @@ async fn db_lt_7(db: &Arc<Database>, _config: &Config) -> Result<()> { )?; /* - let mut tmp = services().rooms.load_shortstatehash_info(¤t_sstatehash)?; + let mut tmp = services.rooms.load_shortstatehash_info(¤t_sstatehash)?; let state = tmp.pop().unwrap(); println!( "{}\t{}{:?}: {:?} + {:?} - {:?}", @@ -425,12 +444,7 @@ async fn db_lt_7(db: &Arc<Database>, _config: &Config) -> Result<()> { let event_id = shorteventid_eventid.get(&seventid).unwrap().unwrap(); let string = utils::string_from_bytes(&event_id).unwrap(); let event_id = <&EventId>::try_from(string.as_str()).unwrap(); - let pdu = services() - .rooms - .timeline - .get_pdu(event_id) - .unwrap() - .unwrap(); + let pdu = services.rooms.timeline.get_pdu(event_id).unwrap().unwrap(); if Some(&pdu.room_id) != current_room.as_ref() { current_room = Some(pdu.room_id.clone()); @@ -451,12 +465,14 @@ async fn db_lt_7(db: &Arc<Database>, _config: &Config) -> Result<()> { )?; } - services().globals.db.bump_database_version(7)?; + 
services.globals.db.bump_database_version(7)?; info!("Migration: 6 -> 7 finished"); Ok(()) } -async fn db_lt_8(db: &Arc<Database>, _config: &Config) -> Result<()> { +async fn db_lt_8(services: &Services) -> Result<()> { + let db = &services.db; + let roomid_shortstatehash = &db["roomid_shortstatehash"]; let roomid_shortroomid = &db["roomid_shortroomid"]; let pduid_pdu = &db["pduid_pdu"]; @@ -464,7 +480,7 @@ async fn db_lt_8(db: &Arc<Database>, _config: &Config) -> Result<()> { // Generate short room ids for all rooms for (room_id, _) in roomid_shortstatehash.iter() { - let shortroomid = services().globals.next_count()?.to_be_bytes(); + let shortroomid = services.globals.next_count()?.to_be_bytes(); roomid_shortroomid.insert(&room_id, &shortroomid)?; info!("Migration: 8"); } @@ -517,12 +533,14 @@ async fn db_lt_8(db: &Arc<Database>, _config: &Config) -> Result<()> { eventid_pduid.insert_batch(batch2.iter().map(database::KeyVal::from))?; - services().globals.db.bump_database_version(8)?; + services.globals.db.bump_database_version(8)?; info!("Migration: 7 -> 8 finished"); Ok(()) } -async fn db_lt_9(db: &Arc<Database>, _config: &Config) -> Result<()> { +async fn db_lt_9(services: &Services) -> Result<()> { + let db = &services.db; + let tokenids = &db["tokenids"]; let roomid_shortroomid = &db["roomid_shortroomid"]; @@ -574,12 +592,14 @@ async fn db_lt_9(db: &Arc<Database>, _config: &Config) -> Result<()> { tokenids.remove(&key)?; } - services().globals.db.bump_database_version(9)?; + services.globals.db.bump_database_version(9)?; info!("Migration: 8 -> 9 finished"); Ok(()) } -async fn db_lt_10(db: &Arc<Database>, _config: &Config) -> Result<()> { +async fn db_lt_10(services: &Services) -> Result<()> { + let db = &services.db; + let statekey_shortstatekey = &db["statekey_shortstatekey"]; let shortstatekey_statekey = &db["shortstatekey_statekey"]; @@ -589,28 +609,30 @@ async fn db_lt_10(db: &Arc<Database>, _config: &Config) -> Result<()> { } // Force E2EE device list 
updates so we can send them over federation - for user_id in services().users.iter().filter_map(Result::ok) { - services().users.mark_device_key_update(&user_id)?; + for user_id in services.users.iter().filter_map(Result::ok) { + services.users.mark_device_key_update(&user_id)?; } - services().globals.db.bump_database_version(10)?; + services.globals.db.bump_database_version(10)?; info!("Migration: 9 -> 10 finished"); Ok(()) } #[allow(unreachable_code)] -async fn db_lt_11(_db: &Arc<Database>, _config: &Config) -> Result<()> { - todo!("Dropping a column to clear data is not implemented yet."); +async fn db_lt_11(services: &Services) -> Result<()> { + error!("Dropping a column to clear data is not implemented yet."); //let userdevicesessionid_uiaarequest = &db["userdevicesessionid_uiaarequest"]; //userdevicesessionid_uiaarequest.clear()?; - services().globals.db.bump_database_version(11)?; + services.globals.db.bump_database_version(11)?; info!("Migration: 10 -> 11 finished"); Ok(()) } -async fn db_lt_12(_db: &Arc<Database>, config: &Config) -> Result<()> { - for username in services().users.list_local_users()? { +async fn db_lt_12(services: &Services) -> Result<()> { + let config = &services.server.config; + + for username in services.users.list_local_users()? 
{ let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) { Ok(u) => u, Err(e) => { @@ -619,7 +641,7 @@ async fn db_lt_12(_db: &Arc<Database>, config: &Config) -> Result<()> { }, }; - let raw_rules_list = services() + let raw_rules_list = services .account_data .get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into()) .unwrap() @@ -664,7 +686,7 @@ async fn db_lt_12(_db: &Arc<Database>, config: &Config) -> Result<()> { } } - services().account_data.update( + services.account_data.update( None, &user, GlobalAccountDataEventType::PushRules.to_string().into(), @@ -672,13 +694,15 @@ async fn db_lt_12(_db: &Arc<Database>, config: &Config) -> Result<()> { )?; } - services().globals.db.bump_database_version(12)?; + services.globals.db.bump_database_version(12)?; info!("Migration: 11 -> 12 finished"); Ok(()) } -async fn db_lt_13(_db: &Arc<Database>, config: &Config) -> Result<()> { - for username in services().users.list_local_users()? { +async fn db_lt_13(services: &Services) -> Result<()> { + let config = &services.server.config; + + for username in services.users.list_local_users()? 
{ let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) { Ok(u) => u, Err(e) => { @@ -687,7 +711,7 @@ async fn db_lt_13(_db: &Arc<Database>, config: &Config) -> Result<()> { }, }; - let raw_rules_list = services() + let raw_rules_list = services .account_data .get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into()) .unwrap() @@ -701,7 +725,7 @@ async fn db_lt_13(_db: &Arc<Database>, config: &Config) -> Result<()> { .global .update_with_server_default(user_default_rules); - services().account_data.update( + services.account_data.update( None, &user, GlobalAccountDataEventType::PushRules.to_string().into(), @@ -709,7 +733,7 @@ async fn db_lt_13(_db: &Arc<Database>, config: &Config) -> Result<()> { )?; } - services().globals.db.bump_database_version(13)?; + services.globals.db.bump_database_version(13)?; info!("Migration: 12 -> 13 finished"); Ok(()) } @@ -717,15 +741,17 @@ async fn db_lt_13(_db: &Arc<Database>, config: &Config) -> Result<()> { /// Migrates a media directory from legacy base64 file names to sha2 file names. /// All errors are fatal. Upon success the database is keyed to not perform this /// again. 
-async fn migrate_sha256_media(db: &Arc<Database>, _config: &Config) -> Result<()> { +async fn migrate_sha256_media(services: &Services) -> Result<()> { + let db = &services.db; + warn!("Migrating legacy base64 file names to sha256 file names"); let mediaid_file = &db["mediaid_file"]; // Move old media files to new names let mut changes = Vec::<(PathBuf, PathBuf)>::new(); for (key, _) in mediaid_file.iter() { - let old = services().media.get_media_file_b64(&key); - let new = services().media.get_media_file_sha256(&key); + let old = services.media.get_media_file_b64(&key); + let new = services.media.get_media_file_sha256(&key); debug!(?key, ?old, ?new, num = changes.len(), "change"); changes.push((old, new)); } @@ -739,8 +765,8 @@ async fn migrate_sha256_media(db: &Arc<Database>, _config: &Config) -> Result<() // Apply fix from when sha256_media was backward-incompat and bumped the schema // version from 13 to 14. For users satisfying these conditions we can go back. - if services().globals.db.database_version()? == 14 && DATABASE_VERSION == 13 { - services().globals.db.bump_database_version(13)?; + if services.globals.db.database_version()? == 14 && DATABASE_VERSION == 13 { + services.globals.db.bump_database_version(13)?; } db["global"].insert(b"feat_sha256_media", &[])?; @@ -752,14 +778,16 @@ async fn migrate_sha256_media(db: &Arc<Database>, _config: &Config) -> Result<() /// - Going back and forth to non-sha256 legacy binaries (e.g. upstream). /// - Deletion of artifacts in the media directory which will then fall out of /// sync with the database. 
-async fn checkup_sha256_media(db: &Arc<Database>, config: &Config) -> Result<()> { +async fn checkup_sha256_media(services: &Services) -> Result<()> { use crate::media::encode_key; debug!("Checking integrity of media directory"); + let db = &services.db; + let media = &services.media; + let config = &services.server.config; let mediaid_file = &db["mediaid_file"]; let mediaid_user = &db["mediaid_user"]; let dbs = (mediaid_file, mediaid_user); - let media = &services().media; let timer = Instant::now(); let dir = media.get_media_dir(); @@ -791,6 +819,7 @@ async fn handle_media_check( new_path: &OsStr, old_path: &OsStr, ) -> Result<()> { use crate::media::encode_key; + let (mediaid_file, mediaid_user) = dbs; let old_exists = files.contains(old_path); @@ -827,8 +856,10 @@ async fn handle_media_check( Ok(()) } -async fn fix_bad_double_separator_in_state_cache(db: &Arc<Database>, _config: &Config) -> Result<()> { +async fn fix_bad_double_separator_in_state_cache(services: &Services) -> Result<()> { warn!("Fixing bad double separator in state_cache roomuserid_joined"); + + let db = &services.db; let roomuserid_joined = &db["roomuserid_joined"]; let _cork = db.cork_and_sync(); @@ -864,11 +895,13 @@ async fn fix_bad_double_separator_in_state_cache(db: &Arc<Database>, _config: &C Ok(()) } -async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc<Database>, _config: &Config) -> Result<()> { +async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services) -> Result<()> { warn!("Retroactively fixing bad data from broken roomuserid_joined"); + + let db = &services.db; let _cork = db.cork_and_sync(); - let room_ids = services() + let room_ids = services .rooms .metadata .iter_ids() @@ -878,7 +911,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc<Database>, _ for room_id in room_ids.clone() { debug_info!("Fixing room {room_id}"); - let users_in_room = services() + let users_in_room = services .rooms .state_cache 
.room_members(&room_id) @@ -888,7 +921,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc<Database>, _ let joined_members = users_in_room .iter() .filter(|user_id| { - services() + services .rooms .state_accessor .get_member(&room_id, user_id) @@ -900,7 +933,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc<Database>, _ let non_joined_members = users_in_room .iter() .filter(|user_id| { - services() + services .rooms .state_accessor .get_member(&room_id, user_id) @@ -913,7 +946,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc<Database>, _ for user_id in joined_members { debug_info!("User is joined, marking as joined"); - services() + services .rooms .state_cache .mark_as_joined(user_id, &room_id)?; @@ -921,10 +954,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc<Database>, _ for user_id in non_joined_members { debug_info!("User is left or banned, marking as left"); - services() - .rooms - .state_cache - .mark_as_left(user_id, &room_id)?; + services.rooms.state_cache.mark_as_left(user_id, &room_id)?; } } @@ -933,7 +963,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc<Database>, _ "Updating joined count for room {room_id} to fix servers in room after correcting membership states" ); - services().rooms.state_cache.update_joined_count(&room_id)?; + services.rooms.state_cache.update_joined_count(&room_id)?; } db.db.cleanup()?; diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 811aff3ad..eab156eee 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -1,5 +1,4 @@ mod data; -mod emerg_access; pub(super) mod migrations; use std::{ @@ -9,7 +8,6 @@ time::Instant, }; -use async_trait::async_trait; use conduit::{error, trace, Config, Result}; use data::Data; use ipaddress::IPAddress; @@ -43,11 +41,10 @@ pub struct Service { type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries 
-#[async_trait] impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { + let db = Data::new(&args); let config = &args.server.config; - let db = Data::new(args.db); let keypair = db.load_keypair(); let keypair = match keypair { @@ -104,19 +101,13 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { .supported_room_versions() .contains(&s.config.default_room_version) { - error!(config=?s.config.default_room_version, fallback=?crate::config::default_default_room_version(), "Room version in config isn't supported, falling back to default version"); - s.config.default_room_version = crate::config::default_default_room_version(); + error!(config=?s.config.default_room_version, fallback=?conduit::config::default_default_room_version(), "Room version in config isn't supported, falling back to default version"); + s.config.default_room_version = conduit::config::default_default_room_version(); }; Ok(Arc::new(s)) } - async fn worker(self: Arc<Self>) -> Result<()> { - emerg_access::init_emergency_access(); - - Ok(()) - } - fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { let bad_event_ratelimiter = self .bad_event_ratelimiter diff --git a/src/service/key_backups/data.rs b/src/service/key_backups/data.rs index f17948959..30ac593b1 100644 --- a/src/service/key_backups/data.rs +++ b/src/service/key_backups/data.rs @@ -1,7 +1,7 @@ use std::{collections::BTreeMap, sync::Arc}; use conduit::{utils, Error, Result}; -use database::{Database, Map}; +use database::Map; use ruma::{ api::client::{ backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, @@ -11,25 +11,34 @@ OwnedRoomId, RoomId, UserId, }; -use crate::services; +use crate::{globals, Dep}; pub(super) struct Data { backupid_algorithm: Arc<Map>, backupid_etag: Arc<Map>, backupkeyid_backup: Arc<Map>, + services: Services, +} + +struct Services { + globals: Dep<globals::Service>, } impl Data { - pub(super) fn new(db: &Arc<Database>) -> Self { + pub(super) fn new(args: 
&crate::Args<'_>) -> Self { + let db = &args.db; Self { backupid_algorithm: db["backupid_algorithm"].clone(), backupid_etag: db["backupid_etag"].clone(), backupkeyid_backup: db["backupkeyid_backup"].clone(), + services: Services { + globals: args.depend::<globals::Service>("globals"), + }, } } pub(super) fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> { - let version = services().globals.next_count()?.to_string(); + let version = self.services.globals.next_count()?.to_string(); let mut key = user_id.as_bytes().to_vec(); key.push(0xFF); @@ -40,7 +49,7 @@ pub(super) fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<Backu &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), )?; self.backupid_etag - .insert(&key, &services().globals.next_count()?.to_be_bytes())?; + .insert(&key, &self.services.globals.next_count()?.to_be_bytes())?; Ok(version) } @@ -75,7 +84,7 @@ pub(super) fn update_backup( self.backupid_algorithm .insert(&key, backup_metadata.json().get().as_bytes())?; self.backupid_etag - .insert(&key, &services().globals.next_count()?.to_be_bytes())?; + .insert(&key, &self.services.globals.next_count()?.to_be_bytes())?; Ok(version.to_owned()) } @@ -152,7 +161,7 @@ pub(super) fn add_key( } self.backupid_etag - .insert(&key, &services().globals.next_count()?.to_be_bytes())?; + .insert(&key, &self.services.globals.next_count()?.to_be_bytes())?; key.push(0xFF); key.extend_from_slice(room_id.as_bytes()); diff --git a/src/service/key_backups/mod.rs b/src/service/key_backups/mod.rs index d83d44972..65d3c065e 100644 --- a/src/service/key_backups/mod.rs +++ b/src/service/key_backups/mod.rs @@ -17,7 +17,7 @@ pub struct Service { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { - db: Data::new(args.db), + db: Data::new(&args), })) } diff --git a/src/service/manager.rs b/src/service/manager.rs index 447cd6fe7..087fd3fac 
100644 --- a/src/service/manager.rs +++ b/src/service/manager.rs @@ -8,13 +8,13 @@ time::sleep, }; -use crate::{service::Service, Services}; +use crate::{service, service::Service, Services}; pub(crate) struct Manager { manager: Mutex<Option<JoinHandle<Result<()>>>>, workers: Mutex<Workers>, server: Arc<Server>, - services: &'static Services, + service: Arc<service::Map>, } type Workers = JoinSet<WorkerResult>; @@ -29,7 +29,7 @@ pub(super) fn new(services: &Services) -> Arc<Self> { manager: Mutex::new(None), workers: Mutex::new(JoinSet::new()), server: services.server.clone(), - services: crate::services(), + service: services.service.clone(), }) } @@ -53,9 +53,19 @@ pub(super) async fn start(self: Arc<Self>) -> Result<()> { .spawn(async move { self_.worker().await }), ); + // we can't hold the lock during the iteration with start_worker so the values + // are snapshotted here + let services: Vec<Arc<dyn Service>> = self + .service + .read() + .expect("locked for reading") + .values() + .map(|v| v.0.clone()) + .collect(); + debug!("Starting service workers..."); - for (service, ..) 
in self.services.service.values() { - self.start_worker(&mut workers, service).await?; + for service in services { + self.start_worker(&mut workers, &service).await?; } Ok(()) diff --git a/src/service/media/data.rs b/src/service/media/data.rs index e5856bbf5..617ec5267 100644 --- a/src/service/media/data.rs +++ b/src/service/media/data.rs @@ -1,10 +1,10 @@ use std::sync::Arc; -use conduit::{debug, debug_info, Error, Result}; +use conduit::{debug, debug_info, utils::string_from_bytes, Error, Result}; use database::{Database, Map}; use ruma::api::client::error::ErrorKind; -use crate::{media::UrlPreviewData, utils::string_from_bytes}; +use crate::media::UrlPreviewData; pub(crate) struct Data { mediaid_file: Arc<Map>, diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index 4a7d38a48..d5d518dca 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -15,7 +15,7 @@ io::{AsyncReadExt, AsyncWriteExt, BufReader}, }; -use crate::services; +use crate::{globals, Dep}; #[derive(Debug)] pub struct FileMeta { @@ -41,16 +41,24 @@ pub struct UrlPreviewData { } pub struct Service { - server: Arc<Server>, + services: Services, pub(crate) db: Data, pub url_preview_mutex: MutexMap<String, ()>, } +struct Services { + server: Arc<Server>, + globals: Dep<globals::Service>, +} + #[async_trait] impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { - server: args.server.clone(), + services: Services { + server: args.server.clone(), + globals: args.depend::<globals::Service>("globals"), + }, db: Data::new(args.db), url_preview_mutex: MutexMap::new(), })) @@ -164,7 +172,7 @@ pub async fn delete_all_remote_media_at_after_time(&self, time: String, force: b debug!("Parsed MXC key to URL: {mxc_s}"); let mxc = OwnedMxcUri::from(mxc_s); - if mxc.server_name() == Ok(services().globals.server_name()) { + if mxc.server_name() == Ok(self.services.globals.server_name()) { debug!("Ignoring local media MXC: {mxc}"); // 
ignore our own MXC URLs as this would be local media. continue; @@ -246,7 +254,7 @@ async fn remove_media_file(&self, key: &[u8]) -> Result<()> { let legacy_rm = fs::remove_file(&legacy); let (file_rm, legacy_rm) = tokio::join!(file_rm, legacy_rm); if let Err(e) = legacy_rm { - if self.server.config.media_compat_file_link { + if self.services.server.config.media_compat_file_link { debug_error!(?key, ?legacy, "Failed to remove legacy media symlink: {e}"); } } @@ -259,7 +267,7 @@ async fn create_media_file(&self, key: &[u8]) -> Result<fs::File> { debug!(?key, ?path, "Creating media file"); let file = fs::File::create(&path).await?; - if self.server.config.media_compat_file_link { + if self.services.server.config.media_compat_file_link { let legacy = self.get_media_file_b64(key); if let Err(e) = fs::symlink(&path, &legacy).await { debug_error!( @@ -304,7 +312,7 @@ pub fn get_media_file_b64(&self, key: &[u8]) -> PathBuf { #[must_use] pub fn get_media_dir(&self) -> PathBuf { let mut r = PathBuf::new(); - r.push(self.server.config.database_path.clone()); + r.push(self.services.server.config.database_path.clone()); r.push("media"); r } diff --git a/src/service/mod.rs b/src/service/mod.rs index 21d1f5946..6e749c99d 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -1,3 +1,4 @@ +#![recursion_limit = "160"] #![allow(refining_impl_trait)] mod manager; @@ -8,6 +9,7 @@ pub mod admin; pub mod appservice; pub mod client; +pub mod emergency; pub mod globals; pub mod key_backups; pub mod media; @@ -26,8 +28,8 @@ use std::sync::{Arc, RwLock}; -pub(crate) use conduit::{config, debug_error, debug_warn, utils, Error, Result, Server}; pub use conduit::{pdu, PduBuilder, PduCount, PduEvent}; +use conduit::{Result, Server}; use database::Database; pub(crate) use service::{Args, Dep, Service}; diff --git a/src/service/presence/data.rs b/src/service/presence/data.rs index dac8becfc..53f9d8c73 100644 --- a/src/service/presence/data.rs +++ b/src/service/presence/data.rs @@ -1,21 
+1,32 @@ use std::sync::Arc; use conduit::{debug_warn, utils, Error, Result}; -use database::{Database, Map}; +use database::Map; use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId}; -use crate::{presence::Presence, services}; +use crate::{globals, presence::Presence, users, Dep}; pub struct Data { presenceid_presence: Arc<Map>, userid_presenceid: Arc<Map>, + services: Services, +} + +struct Services { + globals: Dep<globals::Service>, + users: Dep<users::Service>, } impl Data { - pub(super) fn new(db: &Arc<Database>) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { presenceid_presence: db["presenceid_presence"].clone(), userid_presenceid: db["userid_presenceid"].clone(), + services: Services { + globals: args.depend::<globals::Service>("globals"), + users: args.depend::<users::Service>("users"), + }, } } @@ -28,7 +39,10 @@ pub fn get_presence(&self, user_id: &UserId) -> Result<Option<(u64, PresenceEven self.presenceid_presence .get(&key)? 
.map(|presence_bytes| -> Result<(u64, PresenceEvent)> { - Ok((count, Presence::from_json_bytes(&presence_bytes)?.to_presence_event(user_id)?)) + Ok(( + count, + Presence::from_json_bytes(&presence_bytes)?.to_presence_event(user_id, &self.services.users)?, + )) }) .transpose() } else { @@ -80,7 +94,7 @@ pub(super) fn set_presence( last_active_ts, status_msg, ); - let count = services().globals.next_count()?; + let count = self.services.globals.next_count()?; let key = presenceid_key(count, user_id); self.presenceid_presence diff --git a/src/service/presence/mod.rs b/src/service/presence/mod.rs index 254304bae..705ac4ffd 100644 --- a/src/service/presence/mod.rs +++ b/src/service/presence/mod.rs @@ -3,8 +3,7 @@ use std::{sync::Arc, time::Duration}; use async_trait::async_trait; -use conduit::{checked, debug, error, utils, Error, Result}; -use data::Data; +use conduit::{checked, debug, error, utils, Error, Result, Server}; use futures_util::{stream::FuturesUnordered, StreamExt}; use ruma::{ events::presence::{PresenceEvent, PresenceEventContent}, @@ -14,7 +13,8 @@ use serde::{Deserialize, Serialize}; use tokio::{sync::Mutex, time::sleep}; -use crate::{services, user_is_local}; +use self::data::Data; +use crate::{user_is_local, users, Dep}; /// Represents data required to be kept in order to implement the presence /// specification. @@ -37,11 +37,6 @@ pub fn new(state: PresenceState, currently_active: bool, last_active_ts: u64, st } } - pub fn from_json_bytes_to_event(bytes: &[u8], user_id: &UserId) -> Result<PresenceEvent> { - let presence = Self::from_json_bytes(bytes)?; - presence.to_presence_event(user_id) - } - pub fn from_json_bytes(bytes: &[u8]) -> Result<Self> { serde_json::from_slice(bytes).map_err(|_| Error::bad_database("Invalid presence data in database")) } @@ -51,7 +46,7 @@ pub fn to_json_bytes(&self) -> Result<Vec<u8>> { } /// Creates a PresenceEvent from available data. 
- pub fn to_presence_event(&self, user_id: &UserId) -> Result<PresenceEvent> { + pub fn to_presence_event(&self, user_id: &UserId, users: &users::Service) -> Result<PresenceEvent> { let now = utils::millis_since_unix_epoch(); let last_active_ago = if self.currently_active { None @@ -66,14 +61,15 @@ pub fn to_presence_event(&self, user_id: &UserId) -> Result<PresenceEvent> { status_msg: self.status_msg.clone(), currently_active: Some(self.currently_active), last_active_ago, - displayname: services().users.displayname(user_id)?, - avatar_url: services().users.avatar_url(user_id)?, + displayname: users.displayname(user_id)?, + avatar_url: users.avatar_url(user_id)?, }, }) } } pub struct Service { + services: Services, pub db: Data, pub timer_sender: loole::Sender<(OwnedUserId, Duration)>, timer_receiver: Mutex<loole::Receiver<(OwnedUserId, Duration)>>, @@ -82,6 +78,11 @@ pub struct Service { offline_timeout: u64, } +struct Services { + server: Arc<Server>, + users: Dep<users::Service>, +} + #[async_trait] impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { @@ -90,7 +91,11 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { let offline_timeout_s = config.presence_offline_timeout_s; let (timer_sender, timer_receiver) = loole::unbounded(); Ok(Arc::new(Self { - db: Data::new(args.db), + services: Services { + server: args.server.clone(), + users: args.depend::<users::Service>("users"), + }, + db: Data::new(&args), timer_sender, timer_receiver: Mutex::new(timer_receiver), timeout_remote_users: config.presence_timeout_remote_users, @@ -182,8 +187,8 @@ pub fn set_presence( if self.timeout_remote_users || user_is_local(user_id) { let timeout = match presence_state { - PresenceState::Online => services().globals.config.presence_idle_timeout_s, - _ => services().globals.config.presence_offline_timeout_s, + PresenceState::Online => self.services.server.config.presence_idle_timeout_s, + _ => 
self.services.server.config.presence_offline_timeout_s, }; self.timer_sender @@ -210,6 +215,11 @@ pub fn presence_since(&self, since: u64) -> Box<dyn Iterator<Item = (OwnedUserId self.db.presence_since(since) } + pub fn from_json_bytes_to_event(&self, bytes: &[u8], user_id: &UserId) -> Result<PresenceEvent> { + let presence = Presence::from_json_bytes(bytes)?; + presence.to_presence_event(user_id, &self.services.users) + } + fn process_presence_timer(&self, user_id: &OwnedUserId) -> Result<()> { let mut presence_state = PresenceState::Offline; let mut last_active_ago = None; diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index 38ea5b9af..873f0f49a 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -3,8 +3,7 @@ use std::{fmt::Debug, mem, sync::Arc}; use bytes::BytesMut; -use conduit::{debug_info, info, trace, warn, Error, Result}; -use data::Data; +use conduit::{debug_info, info, trace, utils::string_from_bytes, warn, Error, PduEvent, Result}; use ipaddress::IPAddress; use ruma::{ api::{ @@ -23,15 +22,32 @@ uint, RoomId, UInt, UserId, }; -use crate::{services, PduEvent}; +use self::data::Data; +use crate::{client, globals, rooms, users, Dep}; pub struct Service { + services: Services, db: Data, } +struct Services { + globals: Dep<globals::Service>, + client: Dep<client::Service>, + state_accessor: Dep<rooms::state_accessor::Service>, + state_cache: Dep<rooms::state_cache::Service>, + users: Dep<users::Service>, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { + services: Services { + globals: args.depend::<globals::Service>("globals"), + client: args.depend::<client::Service>("client"), + state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"), + state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"), + users: args.depend::<users::Service>("users"), + }, db: Data::new(args.db), })) } @@ -62,7 +78,7 @@ pub 
async fn send_request<T>(&self, dest: &str, request: T) -> Result<T::Incomin { const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_0]; - let dest = dest.replace(services().globals.notification_push_path(), ""); + let dest = dest.replace(self.services.globals.notification_push_path(), ""); trace!("Push gateway destination: {dest}"); let http_request = request @@ -78,13 +94,13 @@ pub async fn send_request<T>(&self, dest: &str, request: T) -> Result<T::Incomin if let Some(url_host) = reqwest_request.url().host_str() { trace!("Checking request URL for IP"); if let Ok(ip) = IPAddress::parse(url_host) { - if !services().globals.valid_cidr_range(&ip) { + if !self.services.globals.valid_cidr_range(&ip) { return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); } } } - let response = services().client.pusher.execute(reqwest_request).await; + let response = self.services.client.pusher.execute(reqwest_request).await; match response { Ok(mut response) => { @@ -93,7 +109,7 @@ pub async fn send_request<T>(&self, dest: &str, request: T) -> Result<T::Incomin trace!("Checking response destination's IP"); if let Some(remote_addr) = response.remote_addr() { if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) { - if !services().globals.valid_cidr_range(&ip) { + if !self.services.globals.valid_cidr_range(&ip) { return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); } } @@ -114,7 +130,7 @@ pub async fn send_request<T>(&self, dest: &str, request: T) -> Result<T::Incomin if !status.is_success() { info!("Push gateway {dest} returned unsuccessful HTTP response ({status})"); - debug_info!("Push gateway response body: {:?}", crate::utils::string_from_bytes(&body)); + debug_info!("Push gateway response body: {:?}", string_from_bytes(&body)); return Err(Error::BadServerResponse("Push gateway returned unsuccessful response")); } @@ -143,8 +159,8 @@ pub async fn send_push_notice( let mut notify = None; let mut tweaks = Vec::new(); - 
let power_levels: RoomPowerLevelsEventContent = services() - .rooms + let power_levels: RoomPowerLevelsEventContent = self + .services .state_accessor .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? .map(|ev| { @@ -195,15 +211,15 @@ pub fn get_actions<'a>( let ctx = PushConditionRoomCtx { room_id: room_id.to_owned(), member_count: UInt::try_from( - services() - .rooms + self.services .state_cache .room_joined_count(room_id)? .unwrap_or(1), ) .unwrap_or_else(|_| uint!(0)), user_id: user.to_owned(), - user_display_name: services() + user_display_name: self + .services .users .displayname(user)? .unwrap_or_else(|| user.localpart().to_owned()), @@ -263,9 +279,9 @@ async fn send_notice(&self, unread: UInt, pusher: &Pusher, tweaks: Vec<Tweak>, e notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str()); } - notifi.sender_display_name = services().users.displayname(&event.sender)?; + notifi.sender_display_name = self.services.users.displayname(&event.sender)?; - notifi.room_name = services().rooms.state_accessor.get_name(&event.room_id)?; + notifi.room_name = self.services.state_accessor.get_name(&event.room_id)?; self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)) .await?; diff --git a/src/service/resolver/actual.rs b/src/service/resolver/actual.rs index b2a00023f..4d5c132f5 100644 --- a/src/service/resolver/actual.rs +++ b/src/service/resolver/actual.rs @@ -9,12 +9,9 @@ use ipaddress::IPAddress; use ruma::ServerName; -use crate::{ - resolver::{ - cache::{CachedDest, CachedOverride}, - fed::{add_port_to_hostname, get_ip_with_port, FedDest}, - }, - services, +use crate::resolver::{ + cache::{CachedDest, CachedOverride}, + fed::{add_port_to_hostname, get_ip_with_port, FedDest}, }; #[derive(Clone, Debug)] @@ -40,7 +37,7 @@ pub(crate) async fn get_actual_dest(&self, server_name: &ServerName) -> Result<A result } else { cached = false; - validate_dest(server_name)?; + self.validate_dest(server_name)?; 
self.resolve_actual_dest(server_name, true).await? }; @@ -188,7 +185,8 @@ async fn request_well_known(&self, dest: &str) -> Result<Option<String>> { self.query_and_cache_override(dest, dest, 8448).await?; } - let response = services() + let response = self + .services .client .well_known .get(&format!("https://{dest}/.well-known/matrix/server")) @@ -245,19 +243,14 @@ async fn conditional_query_and_cache_override( #[tracing::instrument(skip_all, name = "ip")] async fn query_and_cache_override(&self, overname: &'_ str, hostname: &'_ str, port: u16) -> Result<()> { - match services() - .resolver - .raw() - .lookup_ip(hostname.to_owned()) - .await - { - Err(e) => handle_resolve_error(&e), + match self.raw().lookup_ip(hostname.to_owned()).await { + Err(e) => Self::handle_resolve_error(&e), Ok(override_ip) => { if hostname != overname { debug_info!("{overname:?} overriden by {hostname:?}"); } - services().resolver.set_cached_override( + self.set_cached_override( overname.to_owned(), CachedOverride { ips: override_ip.iter().collect(), @@ -295,62 +288,62 @@ async fn lookup_srv( for hostname in hostnames { match lookup_srv(self.raw(), &hostname).await { Ok(result) => return Ok(handle_successful_srv(&result)), - Err(e) => handle_resolve_error(&e)?, + Err(e) => Self::handle_resolve_error(&e)?, } } Ok(None) } -} -#[allow(clippy::single_match_else)] -fn handle_resolve_error(e: &ResolveError) -> Result<()> { - use hickory_resolver::error::ResolveErrorKind; + #[allow(clippy::single_match_else)] + fn handle_resolve_error(e: &ResolveError) -> Result<()> { + use hickory_resolver::error::ResolveErrorKind; - match *e.kind() { - ResolveErrorKind::NoRecordsFound { - .. - } => { - // Raise to debug_warn if we can find out the result wasn't from cache - debug!("{e}"); - Ok(()) - }, - _ => Err!(error!("DNS {e}")), + match *e.kind() { + ResolveErrorKind::NoRecordsFound { + .. 
+ } => { + // Raise to debug_warn if we can find out the result wasn't from cache + debug!("{e}"); + Ok(()) + }, + _ => Err!(error!("DNS {e}")), + } } -} -fn validate_dest(dest: &ServerName) -> Result<()> { - if dest == services().globals.server_name() { - return Err!("Won't send federation request to ourselves"); - } + fn validate_dest(&self, dest: &ServerName) -> Result<()> { + if dest == self.services.server.config.server_name { + return Err!("Won't send federation request to ourselves"); + } + + if dest.is_ip_literal() || IPAddress::is_valid(dest.host()) { + self.validate_dest_ip_literal(dest)?; + } - if dest.is_ip_literal() || IPAddress::is_valid(dest.host()) { - validate_dest_ip_literal(dest)?; + Ok(()) } - Ok(()) -} + fn validate_dest_ip_literal(&self, dest: &ServerName) -> Result<()> { + trace!("Destination is an IP literal, checking against IP range denylist.",); + debug_assert!( + dest.is_ip_literal() || !IPAddress::is_valid(dest.host()), + "Destination is not an IP literal." + ); + let ip = IPAddress::parse(dest.host()).map_err(|e| { + debug_error!("Failed to parse IP literal from string: {}", e); + Error::BadServerResponse("Invalid IP address") + })?; -fn validate_dest_ip_literal(dest: &ServerName) -> Result<()> { - trace!("Destination is an IP literal, checking against IP range denylist.",); - debug_assert!( - dest.is_ip_literal() || !IPAddress::is_valid(dest.host()), - "Destination is not an IP literal." 
- ); - let ip = IPAddress::parse(dest.host()).map_err(|e| { - debug_error!("Failed to parse IP literal from string: {}", e); - Error::BadServerResponse("Invalid IP address") - })?; + self.validate_ip(&ip)?; - validate_ip(&ip)?; + Ok(()) + } - Ok(()) -} + pub(crate) fn validate_ip(&self, ip: &IPAddress) -> Result<()> { + if !self.services.globals.valid_cidr_range(ip) { + return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); + } -pub(crate) fn validate_ip(ip: &IPAddress) -> Result<()> { - if !services().globals.valid_cidr_range(ip) { - return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); + Ok(()) } - - Ok(()) } diff --git a/src/service/resolver/cache.rs b/src/service/resolver/cache.rs index 0fba24006..465b59855 100644 --- a/src/service/resolver/cache.rs +++ b/src/service/resolver/cache.rs @@ -5,11 +5,10 @@ time::SystemTime, }; -use conduit::trace; +use conduit::{trace, utils::rand}; use ruma::{OwnedServerName, ServerName}; use super::fed::FedDest; -use crate::utils::rand; pub struct Cache { pub destinations: RwLock<WellKnownMap>, // actual_destination, host diff --git a/src/service/resolver/mod.rs b/src/service/resolver/mod.rs index 48ff8813d..457ea9ccc 100644 --- a/src/service/resolver/mod.rs +++ b/src/service/resolver/mod.rs @@ -6,14 +6,22 @@ use std::{fmt::Write, sync::Arc}; -use conduit::Result; +use conduit::{Result, Server}; use hickory_resolver::TokioAsyncResolver; use self::{cache::Cache, dns::Resolver}; +use crate::{client, globals, Dep}; pub struct Service { pub cache: Arc<Cache>, pub resolver: Arc<Resolver>, + services: Services, +} + +struct Services { + server: Arc<Server>, + client: Dep<client::Service>, + globals: Dep<globals::Service>, } impl crate::Service for Service { @@ -23,6 +31,11 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { cache: cache.clone(), resolver: Resolver::build(args.server, cache)?, + services: Services { + server: args.server.clone(), + client: 
args.depend::<client::Service>("client"), + globals: args.depend::<globals::Service>("globals"), + }, })) } diff --git a/src/service/rooms/alias/data.rs b/src/service/rooms/alias/data.rs index 302c21aed..efd2b5b76 100644 --- a/src/service/rooms/alias/data.rs +++ b/src/service/rooms/alias/data.rs @@ -1,23 +1,32 @@ use std::sync::Arc; use conduit::{utils, Error, Result}; -use database::{Database, Map}; +use database::Map; use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, UserId}; -use crate::services; +use crate::{globals, Dep}; pub(super) struct Data { alias_userid: Arc<Map>, alias_roomid: Arc<Map>, aliasid_alias: Arc<Map>, + services: Services, +} + +struct Services { + globals: Dep<globals::Service>, } impl Data { - pub(super) fn new(db: &Arc<Database>) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { alias_userid: db["alias_userid"].clone(), alias_roomid: db["alias_roomid"].clone(), aliasid_alias: db["aliasid_alias"].clone(), + services: Services { + globals: args.depend::<globals::Service>("globals"), + }, } } @@ -31,7 +40,7 @@ pub(super) fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: & let mut aliasid = room_id.as_bytes().to_vec(); aliasid.push(0xFF); - aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + aliasid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); self.aliasid_alias.insert(&aliasid, alias.as_bytes())?; Ok(()) diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index 792f5c988..344ab6d2b 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -4,9 +4,8 @@ use std::sync::Arc; use conduit::{err, Error, Result}; -use data::Data; use ruma::{ - api::{appservice, client::error::ErrorKind}, + api::client::error::ErrorKind, events::{ room::power_levels::{RoomPowerLevels, RoomPowerLevelsEventContent}, StateEventType, @@ -14,16 +13,33 @@ 
OwnedRoomAliasId, OwnedRoomId, OwnedServerName, RoomAliasId, RoomId, RoomOrAliasId, UserId, }; -use crate::{appservice::RegistrationInfo, server_is_ours, services}; +use self::data::Data; +use crate::{admin, appservice, appservice::RegistrationInfo, globals, rooms, sending, server_is_ours, Dep}; pub struct Service { db: Data, + services: Services, +} + +struct Services { + admin: Dep<admin::Service>, + appservice: Dep<appservice::Service>, + globals: Dep<globals::Service>, + sending: Dep<sending::Service>, + state_accessor: Dep<rooms::state_accessor::Service>, } impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { - db: Data::new(args.db), + db: Data::new(&args), + services: Services { + admin: args.depend::<admin::Service>("admin"), + appservice: args.depend::<appservice::Service>("appservice"), + globals: args.depend::<globals::Service>("globals"), + sending: args.depend::<sending::Service>("sending"), + state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"), + }, })) } @@ -33,7 +49,7 @@ fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } impl Service { #[tracing::instrument(skip(self))] pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> { - if alias == services().globals.admin_alias && user_id != services().globals.server_user { + if alias == self.services.globals.admin_alias && user_id != self.services.globals.server_user { Err(Error::BadRequest( ErrorKind::forbidden(), "Only the server user can set this alias", @@ -72,10 +88,10 @@ pub async fn resolve_alias( if !server_is_ours(room_alias.server_name()) && (!servers .as_ref() - .is_some_and(|servers| servers.contains(&services().globals.server_name().to_owned())) + .is_some_and(|servers| servers.contains(&self.services.globals.server_name().to_owned())) || servers.as_ref().is_none()) { - return remote::resolve(room_alias, servers).await; + return 
self.remote_resolve(room_alias, servers).await; } let room_id: Option<OwnedRoomId> = match self.resolve_local_alias(room_alias)? { @@ -111,7 +127,7 @@ async fn user_can_remove_alias(&self, alias: &RoomAliasId, user_id: &UserId) -> return Err(Error::BadRequest(ErrorKind::NotFound, "Alias not found.")); }; - let server_user = &services().globals.server_user; + let server_user = &self.services.globals.server_user; // The creator of an alias can remove it if self @@ -119,7 +135,7 @@ async fn user_can_remove_alias(&self, alias: &RoomAliasId, user_id: &UserId) -> .who_created_alias(alias)? .is_some_and(|user| user == user_id) // Server admins can remove any local alias - || services().admin.user_is_admin(user_id).await? + || self.services.admin.user_is_admin(user_id).await? // Always allow the server service account to remove the alias, since there may not be an admin room || server_user == user_id { @@ -127,8 +143,7 @@ async fn user_can_remove_alias(&self, alias: &RoomAliasId, user_id: &UserId) -> // Checking whether the user is able to change canonical aliases of the // room } else if let Some(event) = - services() - .rooms + self.services .state_accessor .room_state_get(&room_id, &StateEventType::RoomPowerLevels, "")? { @@ -140,8 +155,7 @@ async fn user_can_remove_alias(&self, alias: &RoomAliasId, user_id: &UserId) -> // If there is no power levels event, only the room creator can change // canonical aliases } else if let Some(event) = - services() - .rooms + self.services .state_accessor .room_state_get(&room_id, &StateEventType::RoomCreate, "")? 
{ @@ -152,14 +166,16 @@ async fn user_can_remove_alias(&self, alias: &RoomAliasId, user_id: &UserId) -> } async fn resolve_appservice_alias(&self, room_alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> { - for appservice in services().appservice.read().await.values() { + use ruma::api::appservice::query::query_room_alias; + + for appservice in self.services.appservice.read().await.values() { if appservice.aliases.is_match(room_alias.as_str()) && matches!( - services() + self.services .sending .send_appservice_request( appservice.registration.clone(), - appservice::query::query_room_alias::v1::Request { + query_room_alias::v1::Request { room_alias: room_alias.to_owned(), }, ) @@ -167,10 +183,7 @@ async fn resolve_appservice_alias(&self, room_alias: &RoomAliasId) -> Result<Opt Ok(Some(_opt_result)) ) { return Ok(Some( - services() - .rooms - .alias - .resolve_local_alias(room_alias)? + self.resolve_local_alias(room_alias)? .ok_or_else(|| err!(Request(NotFound("Room does not exist."))))?, )); } @@ -178,20 +191,27 @@ async fn resolve_appservice_alias(&self, room_alias: &RoomAliasId) -> Result<Opt Ok(None) } -} -pub async fn appservice_checks(room_alias: &RoomAliasId, appservice_info: &Option<RegistrationInfo>) -> Result<()> { - if !server_is_ours(room_alias.server_name()) { - return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server.")); - } + pub async fn appservice_checks( + &self, room_alias: &RoomAliasId, appservice_info: &Option<RegistrationInfo>, + ) -> Result<()> { + if !server_is_ours(room_alias.server_name()) { + return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server.")); + } - if let Some(ref info) = appservice_info { - if !info.aliases.is_match(room_alias.as_str()) { - return Err(Error::BadRequest(ErrorKind::Exclusive, "Room alias is not in namespace.")); + if let Some(ref info) = appservice_info { + if !info.aliases.is_match(room_alias.as_str()) { + return Err(Error::BadRequest(ErrorKind::Exclusive, 
"Room alias is not in namespace.")); + } + } else if self + .services + .appservice + .is_exclusive_alias(room_alias) + .await + { + return Err(Error::BadRequest(ErrorKind::Exclusive, "Room alias reserved by appservice.")); } - } else if services().appservice.is_exclusive_alias(room_alias).await { - return Err(Error::BadRequest(ErrorKind::Exclusive, "Room alias reserved by appservice.")); - } - Ok(()) + Ok(()) + } } diff --git a/src/service/rooms/alias/remote.rs b/src/service/rooms/alias/remote.rs index 7fcd27f55..5d835240b 100644 --- a/src/service/rooms/alias/remote.rs +++ b/src/service/rooms/alias/remote.rs @@ -1,71 +1,75 @@ -use conduit::{debug, debug_info, debug_warn, Error, Result}; +use conduit::{debug, debug_warn, Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation}, OwnedRoomId, OwnedServerName, RoomAliasId, }; -use crate::services; +impl super::Service { + pub(super) async fn remote_resolve( + &self, room_alias: &RoomAliasId, servers: Option<&Vec<OwnedServerName>>, + ) -> Result<(OwnedRoomId, Option<Vec<OwnedServerName>>)> { + debug!(?room_alias, ?servers, "resolve"); -pub(super) async fn resolve( - room_alias: &RoomAliasId, servers: Option<&Vec<OwnedServerName>>, -) -> Result<(OwnedRoomId, Option<Vec<OwnedServerName>>)> { - debug!(?room_alias, ?servers, "resolve"); + let mut response = self + .services + .sending + .send_federation_request( + room_alias.server_name(), + federation::query::get_room_information::v1::Request { + room_alias: room_alias.to_owned(), + }, + ) + .await; - let mut response = services() - .sending - .send_federation_request( - room_alias.server_name(), - federation::query::get_room_information::v1::Request { - room_alias: room_alias.to_owned(), - }, - ) - .await; + debug!("room alias server_name get_alias_helper response: {response:?}"); - debug!("room alias server_name get_alias_helper response: {response:?}"); - - if let Err(ref e) = response { - debug_info!( - "Server {} of the original room alias failed to 
assist in resolving room alias: {e}", - room_alias.server_name() - ); - } + if let Err(ref e) = response { + debug_warn!( + "Server {} of the original room alias failed to assist in resolving room alias: {e}", + room_alias.server_name(), + ); + } - if response.as_ref().is_ok_and(|resp| resp.servers.is_empty()) || response.as_ref().is_err() { - if let Some(servers) = servers { - for server in servers { - response = services() - .sending - .send_federation_request( - server, - federation::query::get_room_information::v1::Request { - room_alias: room_alias.to_owned(), - }, - ) - .await; - debug!("Got response from server {server} for room aliases: {response:?}"); + if response.as_ref().is_ok_and(|resp| resp.servers.is_empty()) || response.as_ref().is_err() { + if let Some(servers) = servers { + for server in servers { + response = self + .services + .sending + .send_federation_request( + server, + federation::query::get_room_information::v1::Request { + room_alias: room_alias.to_owned(), + }, + ) + .await; + debug!("Got response from server {server} for room aliases: {response:?}"); - if let Ok(ref response) = response { - if !response.servers.is_empty() { - break; + if let Ok(ref response) = response { + if !response.servers.is_empty() { + break; + } + debug_warn!( + "Server {server} responded with room aliases, but was empty? Response: {response:?}" + ); } - debug_warn!("Server {server} responded with room aliases, but was empty? 
Response: {response:?}"); } } } - } - if let Ok(response) = response { - let room_id = response.room_id; + } - let mut pre_servers = response.servers; - // since the room alis server responded, insert it into the list - pre_servers.push(room_alias.server_name().into()); + if let Ok(response) = response { + let room_id = response.room_id; - return Ok((room_id, Some(pre_servers))); - } + let mut pre_servers = response.servers; + // since the room alias server responded, insert it into the list + pre_servers.push(room_alias.server_name().into()); - Err(Error::BadRequest( - ErrorKind::NotFound, - "No servers could assist in resolving the room alias", - )) + return Ok((room_id, Some(pre_servers))); + } - Err(Error::BadRequest( + Err(Error::BadRequest( ErrorKind::NotFound, "No servers could assist in resolving the room alias", )) + } } diff --git a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs index 4e4682343..6e7c78359 100644 --- a/src/service/rooms/auth_chain/data.rs +++ b/src/service/rooms/auth_chain/data.rs @@ -3,8 +3,8 @@ sync::{Arc, Mutex}, }; -use conduit::{utils, utils::math::usize_from_f64, Result, Server}; -use database::{Database, Map}; +use conduit::{utils, utils::math::usize_from_f64, Result}; +use database::Map; use lru_cache::LruCache; pub(super) struct Data { @@ -13,8 +13,9 @@ pub(super) struct Data { } impl Data { - pub(super) fn new(server: &Arc<Server>, db: &Arc<Database>) -> Self { - let config = &server.config; + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; + let config = &args.server.config; let cache_size = f64::from(config.auth_chain_cache_capacity); let cache_size = usize_from_f64(cache_size * config.cache_capacity_modifier).expect("valid cache size"); Self { diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index 4e8c7bb21..9a1e7e67a 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -6,19 +6,29 @@ }; use conduit::{debug, error, 
trace, validated, warn, Err, Result}; -use data::Data; use ruma::{EventId, RoomId}; -use crate::services; +use self::data::Data; +use crate::{rooms, Dep}; pub struct Service { + services: Services, db: Data, } +struct Services { + short: Dep<rooms::short::Service>, + timeline: Dep<rooms::timeline::Service>, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { - db: Data::new(args.server, args.db), + services: Services { + short: args.depend::<rooms::short::Service>("rooms::short"), + timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"), + }, + db: Data::new(&args), })) } @@ -27,7 +37,7 @@ fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } impl Service { pub async fn event_ids_iter<'a>( - &self, room_id: &RoomId, starting_events_: Vec<Arc<EventId>>, + &'a self, room_id: &RoomId, starting_events_: Vec<Arc<EventId>>, ) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> { let mut starting_events: Vec<&EventId> = Vec::with_capacity(starting_events_.len()); for starting_event in &starting_events_ { @@ -38,7 +48,7 @@ pub async fn event_ids_iter<'a>( .get_auth_chain(room_id, &starting_events) .await? .into_iter() - .filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok())) + .filter_map(move |sid| self.services.short.get_eventid_from_short(sid).ok())) } #[tracing::instrument(skip_all, name = "auth_chain")] @@ -48,8 +58,8 @@ pub async fn get_auth_chain(&self, room_id: &RoomId, starting_events: &[&EventId let started = std::time::Instant::now(); let mut buckets = [BUCKET; NUM_BUCKETS]; - for (i, &short) in services() - .rooms + for (i, &short) in self + .services .short .multi_get_or_create_shorteventid(starting_events)? 
.iter() @@ -140,7 +150,7 @@ fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result<H while let Some(event_id) = todo.pop() { trace!(?event_id, "processing auth event"); - match services().rooms.timeline.get_pdu(&event_id) { + match self.services.timeline.get_pdu(&event_id) { Ok(Some(pdu)) => { if pdu.room_id != room_id { return Err!(Request(Forbidden( @@ -150,10 +160,7 @@ fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result<H ))); } for auth_event in &pdu.auth_events { - let sauthevent = services() - .rooms - .short - .get_or_create_shorteventid(auth_event)?; + let sauthevent = self.services.short.get_or_create_shorteventid(auth_event)?; if found.insert(sauthevent) { trace!(?event_id, ?auth_event, "adding auth event to processing queue"); diff --git a/src/service/rooms/directory/mod.rs b/src/service/rooms/directory/mod.rs index 23ec6b6b5..706e6c2e5 100644 --- a/src/service/rooms/directory/mod.rs +++ b/src/service/rooms/directory/mod.rs @@ -2,10 +2,10 @@ use std::sync::Arc; -use data::Data; +use conduit::Result; use ruma::{OwnedRoomId, RoomId}; -use crate::Result; +use self::data::Data; pub struct Service { db: Data, diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 6cb23b9fa..fd8e21855 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -10,12 +10,11 @@ }; use conduit::{ - debug, debug_error, debug_info, err, error, info, trace, + debug, debug_error, debug_info, err, error, info, pdu, trace, utils::{math::continue_exponential_backoff_secs, MutexMap}, - warn, Error, Result, + warn, Error, PduEvent, Result, }; use futures_util::Future; -pub use parse_incoming_pdu::parse_incoming_pdu; use ruma::{ api::{ client::error::ErrorKind, @@ -36,13 +35,28 @@ use tokio::sync::RwLock; use super::state_compressor::CompressedStateEvent; -use crate::{pdu, services, PduEvent}; +use crate::{globals, rooms, sending, Dep}; pub struct Service 
{ + services: Services, pub federation_handletime: StdRwLock<HandleTimeMap>, pub mutex_federation: RoomMutexMap, } +struct Services { + globals: Dep<globals::Service>, + sending: Dep<sending::Service>, + auth_chain: Dep<rooms::auth_chain::Service>, + metadata: Dep<rooms::metadata::Service>, + outlier: Dep<rooms::outlier::Service>, + pdu_metadata: Dep<rooms::pdu_metadata::Service>, + short: Dep<rooms::short::Service>, + state: Dep<rooms::state::Service>, + state_accessor: Dep<rooms::state_accessor::Service>, + state_compressor: Dep<rooms::state_compressor::Service>, + timeline: Dep<rooms::timeline::Service>, +} + type RoomMutexMap = MutexMap<OwnedRoomId, ()>; type HandleTimeMap = HashMap<OwnedRoomId, (OwnedEventId, Instant)>; @@ -55,8 +69,21 @@ pub struct Service { AsyncRecursiveType<'a, Result<(Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)>>; impl crate::Service for Service { - fn build(_args: crate::Args<'_>) -> Result<Arc<Self>> { + fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { + services: Services { + globals: args.depend::<globals::Service>("globals"), + sending: args.depend::<sending::Service>("sending"), + auth_chain: args.depend::<rooms::auth_chain::Service>("rooms::auth_chain"), + metadata: args.depend::<rooms::metadata::Service>("rooms::metadata"), + outlier: args.depend::<rooms::outlier::Service>("rooms::outlier"), + pdu_metadata: args.depend::<rooms::pdu_metadata::Service>("rooms::pdu_metadata"), + short: args.depend::<rooms::short::Service>("rooms::short"), + state: args.depend::<rooms::state::Service>("rooms::state"), + state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"), + state_compressor: args.depend::<rooms::state_compressor::Service>("rooms::state_compressor"), + timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"), + }, federation_handletime: HandleTimeMap::new().into(), mutex_federation: RoomMutexMap::new(), })) @@ -114,17 +141,17 @@ pub async fn 
handle_incoming_pdu<'a>( pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, ) -> Result<Option<Vec<u8>>> { // 1. Skip the PDU if we already have it as a timeline event - if let Some(pdu_id) = services().rooms.timeline.get_pdu_id(event_id)? { + if let Some(pdu_id) = self.services.timeline.get_pdu_id(event_id)? { return Ok(Some(pdu_id.to_vec())); } // 1.1 Check the server is in the room - if !services().rooms.metadata.exists(room_id)? { + if !self.services.metadata.exists(room_id)? { return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server")); } // 1.2 Check if the room is disabled - if services().rooms.metadata.is_disabled(room_id)? { + if self.services.metadata.is_disabled(room_id)? { return Err(Error::BadRequest( ErrorKind::forbidden(), "Federation of this room is currently disabled on this server.", @@ -147,8 +174,8 @@ pub async fn handle_incoming_pdu<'a>( self.acl_check(sender.server_name(), room_id)?; // Fetch create event - let create_event = services() - .rooms + let create_event = self + .services .state_accessor .room_state_get(room_id, &StateEventType::RoomCreate, "")? .ok_or_else(|| Error::bad_database("Failed to find create event in db."))?; @@ -156,8 +183,8 @@ pub async fn handle_incoming_pdu<'a>( // Procure the room version let room_version_id = Self::get_room_version_id(&create_event)?; - let first_pdu_in_room = services() - .rooms + let first_pdu_in_room = self + .services .timeline .first_pdu_in_room(room_id)? 
.ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; @@ -208,7 +235,8 @@ pub async fn handle_incoming_pdu<'a>( Ok(()) => continue, Err(e) => { warn!("Prev event {} failed: {}", prev_id, e); - match services() + match self + .services .globals .bad_event_ratelimiter .write() @@ -258,7 +286,7 @@ pub async fn handle_prev_pdu<'a>( create_event: &Arc<PduEvent>, first_pdu_in_room: &Arc<PduEvent>, prev_id: &EventId, ) -> Result<()> { // Check for disabled again because it might have changed - if services().rooms.metadata.is_disabled(room_id)? { + if self.services.metadata.is_disabled(room_id)? { debug!( "Federaton of room {room_id} is currently disabled on this server. Request by origin {origin} and \ event ID {event_id}" @@ -269,7 +297,8 @@ pub async fn handle_prev_pdu<'a>( )); } - if let Some((time, tries)) = services() + if let Some((time, tries)) = self + .services .globals .bad_event_ratelimiter .read() @@ -349,7 +378,7 @@ fn handle_outlier_pdu<'a>( }; // Skip the PDU if it is redacted and we already have it as an outlier event - if services().rooms.timeline.get_pdu_json(event_id)?.is_some() { + if self.services.timeline.get_pdu_json(event_id)?.is_some() { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Event was redacted and we already knew about it", @@ -401,7 +430,7 @@ fn handle_outlier_pdu<'a>( // Build map of auth events let mut auth_events = HashMap::with_capacity(incoming_pdu.auth_events.len()); for id in &incoming_pdu.auth_events { - let Some(auth_event) = services().rooms.timeline.get_pdu(id)? else { + let Some(auth_event) = self.services.timeline.get_pdu(id)? else { warn!("Could not find auth event {}", id); continue; }; @@ -454,8 +483,7 @@ fn handle_outlier_pdu<'a>( trace!("Validation successful."); // 7. Persist the event as an outlier. 
- services() - .rooms + self.services .outlier .add_pdu_outlier(&incoming_pdu.event_id, &val)?; @@ -470,12 +498,12 @@ pub async fn upgrade_outlier_to_timeline_pdu( origin: &ServerName, room_id: &RoomId, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, ) -> Result<Option<Vec<u8>>> { // Skip the PDU if we already have it as a timeline event - if let Ok(Some(pduid)) = services().rooms.timeline.get_pdu_id(&incoming_pdu.event_id) { + if let Ok(Some(pduid)) = self.services.timeline.get_pdu_id(&incoming_pdu.event_id) { return Ok(Some(pduid.to_vec())); } - if services() - .rooms + if self + .services .pdu_metadata .is_event_soft_failed(&incoming_pdu.event_id)? { @@ -521,14 +549,13 @@ pub async fn upgrade_outlier_to_timeline_pdu( &incoming_pdu, None::<PduEvent>, // TODO: third party invite |k, s| { - services() - .rooms + self.services .short .get_shortstatekey(&k.to_string().into(), s) .ok() .flatten() .and_then(|shortstatekey| state_at_incoming_event.get(&shortstatekey)) - .and_then(|event_id| services().rooms.timeline.get_pdu(event_id).ok().flatten()) + .and_then(|event_id| self.services.timeline.get_pdu(event_id).ok().flatten()) }, ) .map_err(|_e| Error::BadRequest(ErrorKind::forbidden(), "Auth check failed."))?; @@ -541,7 +568,7 @@ pub async fn upgrade_outlier_to_timeline_pdu( } debug!("Gathering auth events"); - let auth_events = services().rooms.state.get_auth_events( + let auth_events = self.services.state.get_auth_events( room_id, &incoming_pdu.kind, &incoming_pdu.sender, @@ -562,7 +589,7 @@ pub async fn upgrade_outlier_to_timeline_pdu( && match room_version_id { V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { if let Some(redact_id) = &incoming_pdu.redacts { - !services().rooms.state_accessor.user_can_redact( + !self.services.state_accessor.user_can_redact( redact_id, &incoming_pdu.sender, &incoming_pdu.room_id, @@ -577,7 +604,7 @@ pub async fn upgrade_outlier_to_timeline_pdu( .map_err(|_| Error::bad_database("Invalid content in redaction 
pdu."))?; if let Some(redact_id) = &content.redacts { - !services().rooms.state_accessor.user_can_redact( + !self.services.state_accessor.user_can_redact( redact_id, &incoming_pdu.sender, &incoming_pdu.room_id, @@ -594,12 +621,12 @@ pub async fn upgrade_outlier_to_timeline_pdu( // We start looking at current room state now, so lets lock the room trace!("Locking the room"); - let state_lock = services().rooms.state.mutex.lock(room_id).await; + let state_lock = self.services.state.mutex.lock(room_id).await; // Now we calculate the set of extremities this room has after the incoming // event has been applied. We start with the previous extremities (aka leaves) trace!("Calculating extremities"); - let mut extremities = services().rooms.state.get_forward_extremities(room_id)?; + let mut extremities = self.services.state.get_forward_extremities(room_id)?; trace!("Calculated {} extremities", extremities.len()); // Remove any forward extremities that are referenced by this incoming event's @@ -609,22 +636,13 @@ pub async fn upgrade_outlier_to_timeline_pdu( } // Only keep those extremities were not referenced yet - extremities.retain(|id| { - !matches!( - services() - .rooms - .pdu_metadata - .is_event_referenced(room_id, id), - Ok(true) - ) - }); + extremities.retain(|id| !matches!(self.services.pdu_metadata.is_event_referenced(room_id, id), Ok(true))); debug!("Retained {} extremities. 
Compressing state", extremities.len()); let state_ids_compressed = Arc::new( state_at_incoming_event .iter() .map(|(shortstatekey, id)| { - services() - .rooms + self.services .state_compressor .compress_state_event(*shortstatekey, id) }) @@ -637,8 +655,8 @@ pub async fn upgrade_outlier_to_timeline_pdu( // We also add state after incoming event to the fork states let mut state_after = state_at_incoming_event.clone(); if let Some(state_key) = &incoming_pdu.state_key { - let shortstatekey = services() - .rooms + let shortstatekey = self + .services .short .get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key)?; @@ -651,13 +669,12 @@ pub async fn upgrade_outlier_to_timeline_pdu( // Set the new room state to the resolved state debug!("Forcing new room state"); - let (sstatehash, new, removed) = services() - .rooms + let (sstatehash, new, removed) = self + .services .state_compressor .save_state(room_id, new_room_state)?; - services() - .rooms + self.services .state .force_state(room_id, sstatehash, new, removed, &state_lock) .await?; @@ -667,8 +684,7 @@ pub async fn upgrade_outlier_to_timeline_pdu( // if not soft fail it if soft_fail { debug!("Soft failing event"); - services() - .rooms + self.services .timeline .append_incoming_pdu( &incoming_pdu, @@ -682,8 +698,7 @@ pub async fn upgrade_outlier_to_timeline_pdu( // Soft fail, we keep the event as an outlier but don't add it to the timeline warn!("Event was soft failed: {:?}", incoming_pdu); - services() - .rooms + self.services .pdu_metadata .mark_event_soft_failed(&incoming_pdu.event_id)?; @@ -696,8 +711,8 @@ pub async fn upgrade_outlier_to_timeline_pdu( // Now that the event has passed all auth it is added into the timeline. // We use the `state_at_event` instead of `state_after` so we accurately // represent the state for this event. 
- let pdu_id = services() - .rooms + let pdu_id = self + .services .timeline .append_incoming_pdu( &incoming_pdu, @@ -723,14 +738,14 @@ pub async fn resolve_state( &self, room_id: &RoomId, room_version_id: &RoomVersionId, incoming_state: HashMap<u64, Arc<EventId>>, ) -> Result<Arc<HashSet<CompressedStateEvent>>> { debug!("Loading current room state ids"); - let current_sstatehash = services() - .rooms + let current_sstatehash = self + .services .state .get_room_shortstatehash(room_id)? .expect("every room has state"); - let current_state_ids = services() - .rooms + let current_state_ids = self + .services .state_accessor .state_full_ids(current_sstatehash) .await?; @@ -740,8 +755,7 @@ pub async fn resolve_state( let mut auth_chain_sets = Vec::with_capacity(fork_states.len()); for state in &fork_states { auth_chain_sets.push( - services() - .rooms + self.services .auth_chain .event_ids_iter(room_id, state.iter().map(|(_, id)| id.clone()).collect()) .await? @@ -755,8 +769,7 @@ pub async fn resolve_state( .map(|map| { map.into_iter() .filter_map(|(k, id)| { - services() - .rooms + self.services .short .get_statekey_from_short(k) .map(|(ty, st_key)| ((ty.to_string().into(), st_key), id)) @@ -766,11 +779,11 @@ pub async fn resolve_state( }) .collect(); - let lock = services().globals.stateres_mutex.lock(); + let lock = self.services.globals.stateres_mutex.lock(); debug!("Resolving state"); let state_resolve = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { - let res = services().rooms.timeline.get_pdu(id); + let res = self.services.timeline.get_pdu(id); if let Err(e) = &res { error!("Failed to fetch event: {}", e); } @@ -793,12 +806,11 @@ pub async fn resolve_state( let new_room_state = state .into_iter() .map(|((event_type, state_key), event_id)| { - let shortstatekey = services() - .rooms + let shortstatekey = self + .services .short .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?; - services() - .rooms + 
self.services .state_compressor .compress_state_event(shortstatekey, &event_id) }) @@ -814,15 +826,14 @@ pub async fn state_at_incoming_degree_one( &self, incoming_pdu: &Arc<PduEvent>, ) -> Result<Option<HashMap<u64, Arc<EventId>>>> { let prev_event = &*incoming_pdu.prev_events[0]; - let prev_event_sstatehash = services() - .rooms + let prev_event_sstatehash = self + .services .state_accessor .pdu_shortstatehash(prev_event)?; let state = if let Some(shortstatehash) = prev_event_sstatehash { Some( - services() - .rooms + self.services .state_accessor .state_full_ids(shortstatehash) .await, @@ -833,8 +844,8 @@ pub async fn state_at_incoming_degree_one( if let Some(Ok(mut state)) = state { debug!("Using cached state"); - let prev_pdu = services() - .rooms + let prev_pdu = self + .services .timeline .get_pdu(prev_event) .ok() @@ -842,8 +853,8 @@ pub async fn state_at_incoming_degree_one( .ok_or_else(|| Error::bad_database("Could not find prev event, but we know the state."))?; if let Some(state_key) = &prev_pdu.state_key { - let shortstatekey = services() - .rooms + let shortstatekey = self + .services .short .get_or_create_shortstatekey(&prev_pdu.kind.to_string().into(), state_key)?; @@ -866,13 +877,13 @@ pub async fn state_at_incoming_resolved( let mut okay = true; for prev_eventid in &incoming_pdu.prev_events { - let Ok(Some(prev_event)) = services().rooms.timeline.get_pdu(prev_eventid) else { + let Ok(Some(prev_event)) = self.services.timeline.get_pdu(prev_eventid) else { okay = false; break; }; - let Ok(Some(sstatehash)) = services() - .rooms + let Ok(Some(sstatehash)) = self + .services .state_accessor .pdu_shortstatehash(prev_eventid) else { @@ -891,15 +902,15 @@ pub async fn state_at_incoming_resolved( let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len()); for (sstatehash, prev_event) in extremity_sstatehashes { - let mut leaf_state: HashMap<_, _> = services() - .rooms + let mut leaf_state: HashMap<_, _> = self + .services .state_accessor 
.state_full_ids(sstatehash) .await?; if let Some(state_key) = &prev_event.state_key { - let shortstatekey = services() - .rooms + let shortstatekey = self + .services .short .get_or_create_shortstatekey(&prev_event.kind.to_string().into(), state_key)?; leaf_state.insert(shortstatekey, Arc::from(&*prev_event.event_id)); @@ -910,7 +921,7 @@ pub async fn state_at_incoming_resolved( let mut starting_events = Vec::with_capacity(leaf_state.len()); for (k, id) in leaf_state { - if let Ok((ty, st_key)) = services().rooms.short.get_statekey_from_short(k) { + if let Ok((ty, st_key)) = self.services.short.get_statekey_from_short(k) { // FIXME: Undo .to_string().into() when StateMap // is updated to use StateEventType state.insert((ty.to_string().into(), st_key), id.clone()); @@ -921,8 +932,7 @@ pub async fn state_at_incoming_resolved( } auth_chain_sets.push( - services() - .rooms + self.services .auth_chain .event_ids_iter(room_id, starting_events) .await? @@ -932,9 +942,9 @@ pub async fn state_at_incoming_resolved( fork_states.push(state); } - let lock = services().globals.stateres_mutex.lock(); + let lock = self.services.globals.stateres_mutex.lock(); let result = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { - let res = services().rooms.timeline.get_pdu(id); + let res = self.services.timeline.get_pdu(id); if let Err(e) = &res { error!("Failed to fetch event: {}", e); } @@ -947,8 +957,8 @@ pub async fn state_at_incoming_resolved( new_state .into_iter() .map(|((event_type, state_key), event_id)| { - let shortstatekey = services() - .rooms + let shortstatekey = self + .services .short .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?; Ok((shortstatekey, event_id)) @@ -974,7 +984,8 @@ async fn fetch_state( pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, event_id: &EventId, ) -> Result<Option<HashMap<u64, Arc<EventId>>>> { debug!("Fetching state ids"); - match services() + match self + .services .sending 
.send_federation_request( origin, @@ -1004,8 +1015,8 @@ async fn fetch_state( .clone() .ok_or_else(|| Error::bad_database("Found non-state pdu in state events."))?; - let shortstatekey = services() - .rooms + let shortstatekey = self + .services .short .get_or_create_shortstatekey(&pdu.kind.to_string().into(), &state_key)?; @@ -1022,8 +1033,8 @@ async fn fetch_state( } // The original create event must still be in the state - let create_shortstatekey = services() - .rooms + let create_shortstatekey = self + .services .short .get_shortstatekey(&StateEventType::RoomCreate, "")? .expect("Room exists"); @@ -1056,7 +1067,8 @@ pub fn fetch_and_handle_outliers<'a>( ) -> AsyncRecursiveCanonicalJsonVec<'a> { Box::pin(async move { let back_off = |id| async { - match services() + match self + .services .globals .bad_event_ratelimiter .write() @@ -1075,7 +1087,7 @@ pub fn fetch_and_handle_outliers<'a>( // a. Look in the main timeline (pduid_pdu tree) // b. Look at outlier pdu tree // (get_pdu_json checks both) - if let Ok(Some(local_pdu)) = services().rooms.timeline.get_pdu(id) { + if let Ok(Some(local_pdu)) = self.services.timeline.get_pdu(id) { trace!("Found {} in db", id); events_with_auth_events.push((id, Some(local_pdu), vec![])); continue; @@ -1089,7 +1101,8 @@ pub fn fetch_and_handle_outliers<'a>( let mut events_all = HashSet::with_capacity(todo_auth_events.len()); let mut i: u64 = 0; while let Some(next_id) = todo_auth_events.pop() { - if let Some((time, tries)) = services() + if let Some((time, tries)) = self + .services .globals .bad_event_ratelimiter .read() @@ -1114,13 +1127,14 @@ pub fn fetch_and_handle_outliers<'a>( tokio::task::yield_now().await; } - if let Ok(Some(_)) = services().rooms.timeline.get_pdu(&next_id) { + if let Ok(Some(_)) = self.services.timeline.get_pdu(&next_id) { trace!("Found {} in db", next_id); continue; } debug!("Fetching {} over federation.", next_id); - match services() + match self + .services .sending .send_federation_request( origin, 
@@ -1195,7 +1209,8 @@ pub fn fetch_and_handle_outliers<'a>( pdus.push((local_pdu, None)); } for (next_id, value) in events_in_reverse_order.iter().rev() { - if let Some((time, tries)) = services() + if let Some((time, tries)) = self + .services .globals .bad_event_ratelimiter .read() @@ -1244,8 +1259,8 @@ async fn fetch_prev( let mut eventid_info = HashMap::new(); let mut todo_outlier_stack: Vec<Arc<EventId>> = initial_set; - let first_pdu_in_room = services() - .rooms + let first_pdu_in_room = self + .services .timeline .first_pdu_in_room(room_id)? .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; @@ -1267,19 +1282,18 @@ async fn fetch_prev( { Self::check_room_id(room_id, &pdu)?; - if amount > services().globals.max_fetch_prev_events() { + if amount > self.services.globals.max_fetch_prev_events() { // Max limit reached debug!( "Max prev event limit reached! Limit: {}", - services().globals.max_fetch_prev_events() + self.services.globals.max_fetch_prev_events() ); graph.insert(prev_event_id.clone(), HashSet::new()); continue; } if let Some(json) = json_opt.or_else(|| { - services() - .rooms + self.services .outlier .get_outlier_pdu_json(&prev_event_id) .ok() @@ -1335,8 +1349,7 @@ async fn fetch_prev( #[tracing::instrument(skip_all)] pub fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> { let acl_event = if let Some(acl) = - services() - .rooms + self.services .state_accessor .room_state_get(room_id, &StateEventType::RoomServerAcl, "")? 
{ diff --git a/src/service/rooms/event_handler/parse_incoming_pdu.rs b/src/service/rooms/event_handler/parse_incoming_pdu.rs index 8fcd85496..a19862a54 100644 --- a/src/service/rooms/event_handler/parse_incoming_pdu.rs +++ b/src/service/rooms/event_handler/parse_incoming_pdu.rs @@ -1,29 +1,28 @@ -use conduit::{Err, Error, Result}; +use conduit::{pdu::gen_event_id_canonical_json, warn, Err, Error, Result}; use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, OwnedEventId, OwnedRoomId, RoomId}; use serde_json::value::RawValue as RawJsonValue; -use tracing::warn; -use crate::{pdu::gen_event_id_canonical_json, services}; +impl super::Service { + pub fn parse_incoming_pdu(&self, pdu: &RawJsonValue) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> { + let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { + warn!("Error parsing incoming event {pdu:?}: {e:?}"); + Error::BadServerResponse("Invalid PDU in server response") + })?; -pub fn parse_incoming_pdu(pdu: &RawJsonValue) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> { - let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { - warn!("Error parsing incoming event {pdu:?}: {e:?}"); - Error::BadServerResponse("Invalid PDU in server response") - })?; + let room_id: OwnedRoomId = value + .get("room_id") + .and_then(|id| RoomId::parse(id.as_str()?).ok()) + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid room id in pdu"))?; - let room_id: OwnedRoomId = value - .get("room_id") - .and_then(|id| RoomId::parse(id.as_str()?).ok()) - .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid room id in pdu"))?; + let Ok(room_version_id) = self.services.state.get_room_version(&room_id) else { + return Err!("Server is not in room {room_id}"); + }; - let Ok(room_version_id) = services().rooms.state.get_room_version(&room_id) else { - return Err!("Server is not in room {room_id}"); - }; + let Ok((event_id, value)) = 
gen_event_id_canonical_json(pdu, &room_version_id) else { + // Event could not be converted to canonical json + return Err!(Request(InvalidParam("Could not convert event to canonical json."))); + }; - let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { - // Event could not be converted to canonical json - return Err!(Request(InvalidParam("Could not convert event to canonical json."))); - }; - - Ok((event_id, value, room_id)) + Ok((event_id, value, room_id)) + } } diff --git a/src/service/rooms/event_handler/signing_keys.rs b/src/service/rooms/event_handler/signing_keys.rs index 2fa5b0df0..1ebcbefb3 100644 --- a/src/service/rooms/event_handler/signing_keys.rs +++ b/src/service/rooms/event_handler/signing_keys.rs @@ -3,7 +3,7 @@ time::{Duration, SystemTime}, }; -use conduit::{debug, error, info, trace, warn}; +use conduit::{debug, error, info, trace, warn, Error, Result}; use futures_util::{stream::FuturesUnordered, StreamExt}; use ruma::{ api::federation::{ @@ -21,8 +21,6 @@ use serde_json::value::RawValue as RawJsonValue; use tokio::sync::{RwLock, RwLockWriteGuard}; -use crate::{services, Error, Result}; - impl super::Service { pub async fn fetch_required_signing_keys<'a, E>( &'a self, events: E, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, @@ -147,7 +145,8 @@ async fn get_server_keys_from_cache( debug!("Loading signing keys for {}", origin); - let result: BTreeMap<_, _> = services() + let result: BTreeMap<_, _> = self + .services .globals .signing_keys_for(origin)? 
.into_iter() @@ -171,9 +170,10 @@ async fn batch_request_signing_keys( &self, mut servers: BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>>, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, ) -> Result<()> { - for server in services().globals.trusted_servers() { + for server in self.services.globals.trusted_servers() { debug!("Asking batch signing keys from trusted server {}", server); - match services() + match self + .services .sending .send_federation_request( server, @@ -199,7 +199,8 @@ async fn batch_request_signing_keys( // TODO: Check signature from trusted server? servers.remove(&k.server_name); - let result = services() + let result = self + .services .globals .db .add_signing_key(&k.server_name, k.clone())? @@ -234,7 +235,7 @@ async fn request_signing_keys( .into_keys() .map(|server| async move { ( - services() + self.services .sending .send_federation_request(&server, get_server_keys::v2::Request::new()) .await, @@ -248,7 +249,8 @@ async fn request_signing_keys( if let (Ok(get_keys_response), origin) = result { debug!("Result is from {origin}"); if let Ok(key) = get_keys_response.server_key.deserialize() { - let result: BTreeMap<_, _> = services() + let result: BTreeMap<_, _> = self + .services .globals .db .add_signing_key(&origin, key)? @@ -297,7 +299,7 @@ pub async fn fetch_join_signing_keys( return Ok(()); } - if services().globals.query_trusted_key_servers_first() { + if self.services.globals.query_trusted_key_servers_first() { info!( "query_trusted_key_servers_first is set to true, querying notary trusted key servers first for \ homeserver signing keys." @@ -349,7 +351,8 @@ pub async fn fetch_signing_keys_for_server( ) -> Result<BTreeMap<String, Base64>> { let contains_all_ids = |keys: &BTreeMap<String, Base64>| signature_ids.iter().all(|id| keys.contains_key(id)); - let mut result: BTreeMap<_, _> = services() + let mut result: BTreeMap<_, _> = self + .services .globals .signing_keys_for(origin)? 
.into_iter() @@ -362,15 +365,16 @@ pub async fn fetch_signing_keys_for_server( } // i didnt split this out into their own functions because it's relatively small - if services().globals.query_trusted_key_servers_first() { + if self.services.globals.query_trusted_key_servers_first() { info!( "query_trusted_key_servers_first is set to true, querying notary trusted servers first for {origin} \ keys" ); - for server in services().globals.trusted_servers() { + for server in self.services.globals.trusted_servers() { debug!("Asking notary server {server} for {origin}'s signing key"); - if let Some(server_keys) = services() + if let Some(server_keys) = self + .services .sending .send_federation_request( server, @@ -394,7 +398,10 @@ pub async fn fetch_signing_keys_for_server( }) { debug!("Got signing keys: {:?}", server_keys); for k in server_keys { - services().globals.db.add_signing_key(origin, k.clone())?; + self.services + .globals + .db + .add_signing_key(origin, k.clone())?; result.extend( k.verify_keys .into_iter() @@ -414,14 +421,15 @@ pub async fn fetch_signing_keys_for_server( } debug!("Asking {origin} for their signing keys over federation"); - if let Some(server_key) = services() + if let Some(server_key) = self + .services .sending .send_federation_request(origin, get_server_keys::v2::Request::new()) .await .ok() .and_then(|resp| resp.server_key.deserialize().ok()) { - services() + self.services .globals .db .add_signing_key(origin, server_key.clone())?; @@ -447,14 +455,15 @@ pub async fn fetch_signing_keys_for_server( info!("query_trusted_key_servers_first is set to false, querying {origin} first"); debug!("Asking {origin} for their signing keys over federation"); - if let Some(server_key) = services() + if let Some(server_key) = self + .services .sending .send_federation_request(origin, get_server_keys::v2::Request::new()) .await .ok() .and_then(|resp| resp.server_key.deserialize().ok()) { - services() + self.services .globals .db .add_signing_key(origin, 
server_key.clone())?; @@ -477,9 +486,10 @@ pub async fn fetch_signing_keys_for_server( } } - for server in services().globals.trusted_servers() { + for server in self.services.globals.trusted_servers() { debug!("Asking notary server {server} for {origin}'s signing key"); - if let Some(server_keys) = services() + if let Some(server_keys) = self + .services .sending .send_federation_request( server, @@ -503,7 +513,10 @@ pub async fn fetch_signing_keys_for_server( }) { debug!("Got signing keys: {:?}", server_keys); for k in server_keys { - services().globals.db.add_signing_key(origin, k.clone())?; + self.services + .globals + .db + .add_signing_key(origin, k.clone())?; result.extend( k.verify_keys .into_iter() diff --git a/src/service/rooms/lazy_loading/mod.rs b/src/service/rooms/lazy_loading/mod.rs index 96f623f2e..64764198c 100644 --- a/src/service/rooms/lazy_loading/mod.rs +++ b/src/service/rooms/lazy_loading/mod.rs @@ -6,10 +6,10 @@ sync::{Arc, Mutex}, }; -use data::Data; +use conduit::{PduCount, Result}; use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId}; -use crate::{PduCount, Result}; +use self::data::Data; pub struct Service { db: Data, diff --git a/src/service/rooms/metadata/data.rs b/src/service/rooms/metadata/data.rs index 763dd0e80..efe681b1b 100644 --- a/src/service/rooms/metadata/data.rs +++ b/src/service/rooms/metadata/data.rs @@ -1,30 +1,39 @@ use std::sync::Arc; use conduit::{error, utils, Error, Result}; -use database::{Database, Map}; +use database::Map; use ruma::{OwnedRoomId, RoomId}; -use crate::services; +use crate::{rooms, Dep}; pub(super) struct Data { disabledroomids: Arc<Map>, bannedroomids: Arc<Map>, roomid_shortroomid: Arc<Map>, pduid_pdu: Arc<Map>, + services: Services, +} + +struct Services { + short: Dep<rooms::short::Service>, } impl Data { - pub(super) fn new(db: &Arc<Database>) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { disabledroomids: 
db["disabledroomids"].clone(), bannedroomids: db["bannedroomids"].clone(), roomid_shortroomid: db["roomid_shortroomid"].clone(), pduid_pdu: db["pduid_pdu"].clone(), + services: Services { + short: args.depend::<rooms::short::Service>("rooms::short"), + }, } } pub(super) fn exists(&self, room_id: &RoomId) -> Result<bool> { - let prefix = match services().rooms.short.get_shortroomid(room_id)? { + let prefix = match self.services.short.get_shortroomid(room_id)? { Some(b) => b.to_be_bytes().to_vec(), None => return Ok(false), }; diff --git a/src/service/rooms/metadata/mod.rs b/src/service/rooms/metadata/mod.rs index ec34a82c2..7415c53b7 100644 --- a/src/service/rooms/metadata/mod.rs +++ b/src/service/rooms/metadata/mod.rs @@ -3,9 +3,10 @@ use std::sync::Arc; use conduit::Result; -use data::Data; use ruma::{OwnedRoomId, RoomId}; +use self::data::Data; + pub struct Service { db: Data, } @@ -13,7 +14,7 @@ pub struct Service { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { - db: Data::new(args.db), + db: Data::new(&args), })) } diff --git a/src/service/rooms/mod.rs b/src/service/rooms/mod.rs index ef50b0946..44a83582d 100644 --- a/src/service/rooms/mod.rs +++ b/src/service/rooms/mod.rs @@ -33,13 +33,13 @@ pub struct Service { pub read_receipt: Arc<read_receipt::Service>, pub search: Arc<search::Service>, pub short: Arc<short::Service>, + pub spaces: Arc<spaces::Service>, pub state: Arc<state::Service>, pub state_accessor: Arc<state_accessor::Service>, pub state_cache: Arc<state_cache::Service>, pub state_compressor: Arc<state_compressor::Service>, - pub timeline: Arc<timeline::Service>, pub threads: Arc<threads::Service>, + pub timeline: Arc<timeline::Service>, pub typing: Arc<typing::Service>, - pub spaces: Arc<spaces::Service>, pub user: Arc<user::Service>, } diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index 24c756fd3..d1649da81 100644 --- 
a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -1,26 +1,35 @@ use std::{mem::size_of, sync::Arc}; -use conduit::{utils, Error, Result}; -use database::{Database, Map}; +use conduit::{utils, Error, PduCount, PduEvent, Result}; +use database::Map; use ruma::{EventId, RoomId, UserId}; -use crate::{services, PduCount, PduEvent}; +use crate::{rooms, Dep}; pub(super) struct Data { tofrom_relation: Arc<Map>, referencedevents: Arc<Map>, softfailedeventids: Arc<Map>, + services: Services, +} + +struct Services { + timeline: Dep<rooms::timeline::Service>, } type PdusIterItem = Result<(PduCount, PduEvent)>; type PdusIterator<'a> = Box<dyn Iterator<Item = PdusIterItem> + 'a>; impl Data { - pub(super) fn new(db: &Arc<Database>) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { tofrom_relation: db["tofrom_relation"].clone(), referencedevents: db["referencedevents"].clone(), softfailedeventids: db["softfailedeventids"].clone(), + services: Services { + timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"), + }, } } @@ -57,8 +66,8 @@ pub(super) fn relations_until<'a>( let mut pduid = shortroomid.to_be_bytes().to_vec(); pduid.extend_from_slice(&from.to_be_bytes()); - let mut pdu = services() - .rooms + let mut pdu = self + .services .timeline .get_pdu_from_id(&pduid)? 
.ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?; diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index 05067aa8e..7546dcb29 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -2,8 +2,7 @@ use std::sync::Arc; -use conduit::Result; -use data::Data; +use conduit::{PduCount, PduEvent, Result}; use ruma::{ api::{client::relations::get_relating_events, Direction}, events::{relation::RelationType, TimelineEventType}, @@ -11,12 +10,20 @@ }; use serde::Deserialize; -use crate::{services, PduCount, PduEvent}; +use self::data::Data; +use crate::{rooms, Dep}; pub struct Service { + services: Services, db: Data, } +struct Services { + short: Dep<rooms::short::Service>, + state_accessor: Dep<rooms::state_accessor::Service>, + timeline: Dep<rooms::timeline::Service>, +} + #[derive(Clone, Debug, Deserialize)] struct ExtractRelType { rel_type: RelationType, @@ -30,7 +37,12 @@ struct ExtractRelatesToEventId { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { - db: Data::new(args.db), + services: Services { + short: args.depend::<rooms::short::Service>("rooms::short"), + state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"), + timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"), + }, + db: Data::new(&args), })) } @@ -101,8 +113,7 @@ pub fn paginate_relations_with_filter( }) .take(limit) .filter(|(_, pdu)| { - services() - .rooms + self.services .state_accessor .user_can_see_event(sender_user, room_id, &pdu.event_id) .unwrap_or(false) @@ -147,8 +158,7 @@ pub fn paginate_relations_with_filter( }) .take(limit) .filter(|(_, pdu)| { - services() - .rooms + self.services .state_accessor .user_can_see_event(sender_user, room_id, &pdu.event_id) .unwrap_or(false) @@ -180,10 +190,10 @@ pub fn paginate_relations_with_filter( pub fn relations_until<'a>( &'a self, user_id: 
&'a UserId, room_id: &'a RoomId, target: &'a EventId, until: PduCount, max_depth: u8, ) -> Result<Vec<(PduCount, PduEvent)>> { - let room_id = services().rooms.short.get_or_create_shortroomid(room_id)?; + let room_id = self.services.short.get_or_create_shortroomid(room_id)?; #[allow(unknown_lints)] #[allow(clippy::manual_unwrap_or_default)] - let target = match services().rooms.timeline.get_pdu_count(target)? { + let target = match self.services.timeline.get_pdu_count(target)? { Some(PduCount::Normal(c)) => c, // TODO: Support backfilled relations _ => 0, // This will result in an empty iterator diff --git a/src/service/rooms/read_receipt/data.rs b/src/service/rooms/read_receipt/data.rs index 06eaf6555..0f400ff3a 100644 --- a/src/service/rooms/read_receipt/data.rs +++ b/src/service/rooms/read_receipt/data.rs @@ -1,14 +1,14 @@ use std::{mem::size_of, sync::Arc}; use conduit::{utils, Error, Result}; -use database::{Database, Map}; +use database::Map; use ruma::{ events::{receipt::ReceiptEvent, AnySyncEphemeralRoomEvent}, serde::Raw, CanonicalJsonObject, OwnedUserId, RoomId, UserId, }; -use crate::services; +use crate::{globals, Dep}; type AnySyncEphemeralRoomEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedUserId, u64, Raw<AnySyncEphemeralRoomEvent>)>> + 'a>; @@ -16,15 +16,24 @@ pub(super) struct Data { roomuserid_privateread: Arc<Map>, roomuserid_lastprivatereadupdate: Arc<Map>, + services: Services, readreceiptid_readreceipt: Arc<Map>, } +struct Services { + globals: Dep<globals::Service>, +} + impl Data { - pub(super) fn new(db: &Arc<Database>) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { roomuserid_privateread: db["roomuserid_privateread"].clone(), roomuserid_lastprivatereadupdate: db["roomuserid_lastprivatereadupdate"].clone(), readreceiptid_readreceipt: db["readreceiptid_readreceipt"].clone(), + services: Services { + globals: args.depend::<globals::Service>("globals"), + }, } } @@ -51,7 +60,7 @@ pub(super) fn 
readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, even } let mut room_latest_id = prefix; - room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + room_latest_id.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); room_latest_id.push(0xFF); room_latest_id.extend_from_slice(user_id.as_bytes()); @@ -108,7 +117,7 @@ pub(super) fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: .insert(&key, &count.to_be_bytes())?; self.roomuserid_lastprivatereadupdate - .insert(&key, &services().globals.next_count()?.to_be_bytes()) + .insert(&key, &self.services.globals.next_count()?.to_be_bytes()) } pub(super) fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { diff --git a/src/service/rooms/read_receipt/mod.rs b/src/service/rooms/read_receipt/mod.rs index 9375276ee..d202d8935 100644 --- a/src/service/rooms/read_receipt/mod.rs +++ b/src/service/rooms/read_receipt/mod.rs @@ -6,16 +6,24 @@ use data::Data; use ruma::{events::receipt::ReceiptEvent, serde::Raw, OwnedUserId, RoomId, UserId}; -use crate::services; +use crate::{sending, Dep}; pub struct Service { + services: Services, db: Data, } +struct Services { + sending: Dep<sending::Service>, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { - db: Data::new(args.db), + services: Services { + sending: args.depend::<sending::Service>("sending"), + }, + db: Data::new(&args), })) } @@ -26,7 +34,7 @@ impl Service { /// Replaces the previous read receipt. 
pub fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) -> Result<()> { self.db.readreceipt_update(user_id, room_id, event)?; - services().sending.flush_room(room_id)?; + self.services.sending.flush_room(room_id)?; Ok(()) } diff --git a/src/service/rooms/search/data.rs b/src/service/rooms/search/data.rs index 79b23cba3..a0086095b 100644 --- a/src/service/rooms/search/data.rs +++ b/src/service/rooms/search/data.rs @@ -1,21 +1,30 @@ use std::sync::Arc; use conduit::{utils, Result}; -use database::{Database, Map}; +use database::Map; use ruma::RoomId; -use crate::services; +use crate::{rooms, Dep}; type SearchPdusResult<'a> = Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>>; pub(super) struct Data { tokenids: Arc<Map>, + services: Services, +} + +struct Services { + short: Dep<rooms::short::Service>, } impl Data { - pub(super) fn new(db: &Arc<Database>) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { tokenids: db["tokenids"].clone(), + services: Services { + short: args.depend::<rooms::short::Service>("rooms::short"), + }, } } @@ -51,8 +60,8 @@ pub(super) fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: } pub(super) fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> { - let prefix = services() - .rooms + let prefix = self + .services .short .get_shortroomid(room_id)? 
.expect("room exists") diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs index 082dd432f..8caa0ce35 100644 --- a/src/service/rooms/search/mod.rs +++ b/src/service/rooms/search/mod.rs @@ -13,7 +13,7 @@ pub struct Service { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { - db: Data::new(args.db), + db: Data::new(&args), })) } diff --git a/src/service/rooms/short/data.rs b/src/service/rooms/short/data.rs index 883c3c1de..963bd9277 100644 --- a/src/service/rooms/short/data.rs +++ b/src/service/rooms/short/data.rs @@ -1,10 +1,10 @@ use std::sync::Arc; use conduit::{utils, warn, Error, Result}; -use database::{Database, Map}; +use database::Map; use ruma::{events::StateEventType, EventId, RoomId}; -use crate::services; +use crate::{globals, Dep}; pub(super) struct Data { eventid_shorteventid: Arc<Map>, @@ -13,10 +13,16 @@ pub(super) struct Data { shortstatekey_statekey: Arc<Map>, roomid_shortroomid: Arc<Map>, statehash_shortstatehash: Arc<Map>, + services: Services, +} + +struct Services { + globals: Dep<globals::Service>, } impl Data { - pub(super) fn new(db: &Arc<Database>) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { eventid_shorteventid: db["eventid_shorteventid"].clone(), shorteventid_eventid: db["shorteventid_eventid"].clone(), @@ -24,6 +30,9 @@ pub(super) fn new(db: &Arc<Database>) -> Self { shortstatekey_statekey: db["shortstatekey_statekey"].clone(), roomid_shortroomid: db["roomid_shortroomid"].clone(), statehash_shortstatehash: db["statehash_shortstatehash"].clone(), + services: Services { + globals: args.depend::<globals::Service>("globals"), + }, } } @@ -31,7 +40,7 @@ pub(super) fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u6 let short = if let Some(shorteventid) = self.eventid_shorteventid.get(event_id.as_bytes())? 
{ utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))? } else { - let shorteventid = services().globals.next_count()?; + let shorteventid = self.services.globals.next_count()?; self.eventid_shorteventid .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; self.shorteventid_eventid @@ -59,7 +68,7 @@ pub(super) fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> utils::u64_from_bytes(short).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?, ), None => { - let short = services().globals.next_count()?; + let short = self.services.globals.next_count()?; self.eventid_shorteventid .insert(keys[i], &short.to_be_bytes())?; self.shorteventid_eventid @@ -98,7 +107,7 @@ pub(super) fn get_or_create_shortstatekey(&self, event_type: &StateEventType, st let short = if let Some(shortstatekey) = self.statekey_shortstatekey.get(&statekey_vec)? { utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))? } else { - let shortstatekey = services().globals.next_count()?; + let shortstatekey = self.services.globals.next_count()?; self.statekey_shortstatekey .insert(&statekey_vec, &shortstatekey.to_be_bytes())?; self.shortstatekey_statekey @@ -158,7 +167,7 @@ pub(super) fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<( true, ) } else { - let shortstatehash = services().globals.next_count()?; + let shortstatehash = self.services.globals.next_count()?; self.statehash_shortstatehash .insert(state_hash, &shortstatehash.to_be_bytes())?; (shortstatehash, false) @@ -176,7 +185,7 @@ pub(super) fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> Ok(if let Some(short) = self.roomid_shortroomid.get(room_id.as_bytes())? { utils::u64_from_bytes(&short).map_err(|_| Error::bad_database("Invalid shortroomid in db."))? 
} else { - let short = services().globals.next_count()?; + let short = self.services.globals.next_count()?; self.roomid_shortroomid .insert(room_id.as_bytes(), &short.to_be_bytes())?; short diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index 0979fb4fb..bfe0e9a0e 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -3,9 +3,10 @@ use std::sync::Arc; use conduit::Result; -use data::Data; use ruma::{events::StateEventType, EventId, RoomId}; +use self::data::Data; + pub struct Service { db: Data, } @@ -13,7 +14,7 @@ pub struct Service { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { - db: Data::new(args.db), + db: Data::new(&args), })) } diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index 19a3ebbbd..24d612d87 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -28,7 +28,7 @@ }; use tokio::sync::Mutex; -use crate::services; +use crate::{rooms, sending, Dep}; pub struct CachedSpaceHierarchySummary { summary: SpaceHierarchyParentSummary, @@ -119,42 +119,18 @@ enum Identifier<'a> { } pub struct Service { + services: Services, pub roomid_spacehierarchy_cache: Mutex<LruCache<OwnedRoomId, Option<CachedSpaceHierarchySummary>>>, } -// Here because cannot implement `From` across ruma-federation-api and -// ruma-client-api types -impl From<CachedSpaceHierarchySummary> for SpaceHierarchyRoomsChunk { - fn from(value: CachedSpaceHierarchySummary) -> Self { - let SpaceHierarchyParentSummary { - canonical_alias, - name, - num_joined_members, - room_id, - topic, - world_readable, - guest_can_join, - avatar_url, - join_rule, - room_type, - children_state, - .. 
- } = value.summary; - - Self { - canonical_alias, - name, - num_joined_members, - room_id, - topic, - world_readable, - guest_can_join, - avatar_url, - join_rule, - room_type, - children_state, - } - } +struct Services { + state_accessor: Dep<rooms::state_accessor::Service>, + state_cache: Dep<rooms::state_cache::Service>, + state: Dep<rooms::state::Service>, + short: Dep<rooms::short::Service>, + event_handler: Dep<rooms::event_handler::Service>, + timeline: Dep<rooms::timeline::Service>, + sending: Dep<sending::Service>, } impl crate::Service for Service { @@ -163,6 +139,15 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { let cache_size = f64::from(config.roomid_spacehierarchy_cache_capacity); let cache_size = cache_size * config.cache_capacity_modifier; Ok(Arc::new(Self { + services: Services { + state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"), + state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"), + state: args.depend::<rooms::state::Service>("rooms::state"), + short: args.depend::<rooms::short::Service>("rooms::short"), + event_handler: args.depend::<rooms::event_handler::Service>("rooms::event_handler"), + timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"), + sending: args.depend::<sending::Service>("sending"), + }, roomid_spacehierarchy_cache: Mutex::new(LruCache::new(usize_from_f64(cache_size)?)), })) } @@ -226,7 +211,7 @@ async fn get_summary_and_children_local( .as_ref() { return Ok(if let Some(cached) = cached { - if is_accessable_child( + if self.is_accessible_child( current_room, &cached.summary.join_rule, &identifier, @@ -242,8 +227,8 @@ async fn get_summary_and_children_local( } Ok( - if let Some(children_pdus) = get_stripped_space_child_events(current_room).await? { - let summary = Self::get_room_summary(current_room, children_pdus, &identifier); + if let Some(children_pdus) = self.get_stripped_space_child_events(current_room).await? 
{ + let summary = self.get_room_summary(current_room, children_pdus, &identifier); if let Ok(summary) = summary { self.roomid_spacehierarchy_cache.lock().await.insert( current_room.clone(), @@ -269,7 +254,8 @@ async fn get_summary_and_children_federation( ) -> Result<Option<SummaryAccessibility>> { for server in via { debug_info!("Asking {server} for /hierarchy"); - let Ok(response) = services() + let Ok(response) = self + .services .sending .send_federation_request( server, @@ -325,7 +311,10 @@ async fn get_summary_and_children_federation( avatar_url, join_rule, room_type, - children_state: get_stripped_space_child_events(&room_id).await?.unwrap(), + children_state: self + .get_stripped_space_child_events(&room_id) + .await? + .unwrap(), allowed_room_ids, } }, @@ -333,7 +322,7 @@ async fn get_summary_and_children_federation( ); } } - if is_accessable_child( + if self.is_accessible_child( current_room, &response.room.join_rule, &Identifier::UserId(user_id), @@ -370,12 +359,13 @@ async fn get_summary_and_children_client( } fn get_room_summary( - current_room: &OwnedRoomId, children_state: Vec<Raw<HierarchySpaceChildEvent>>, identifier: &Identifier<'_>, + &self, current_room: &OwnedRoomId, children_state: Vec<Raw<HierarchySpaceChildEvent>>, + identifier: &Identifier<'_>, ) -> Result<SpaceHierarchyParentSummary, Error> { let room_id: &RoomId = current_room; - let join_rule = services() - .rooms + let join_rule = self + .services .state_accessor .room_state_get(room_id, &StateEventType::RoomJoinRules, "")? .map(|s| { @@ -386,12 +376,12 @@ fn get_room_summary( .transpose()? 
.unwrap_or(JoinRule::Invite); - let allowed_room_ids = services() - .rooms + let allowed_room_ids = self + .services .state_accessor .allowed_room_ids(join_rule.clone()); - if !is_accessable_child(current_room, &join_rule.clone().into(), identifier, &allowed_room_ids) { + if !self.is_accessible_child(current_room, &join_rule.clone().into(), identifier, &allowed_room_ids) { debug!("User is not allowed to see room {room_id}"); // This error will be caught later return Err(Error::BadRequest(ErrorKind::forbidden(), "User is not allowed to see the room")); @@ -400,18 +390,18 @@ fn get_room_summary( let join_rule = join_rule.into(); Ok(SpaceHierarchyParentSummary { - canonical_alias: services() - .rooms + canonical_alias: self + .services .state_accessor .get_canonical_alias(room_id) .unwrap_or(None), - name: services() - .rooms + name: self + .services .state_accessor .get_name(room_id) .unwrap_or(None), - num_joined_members: services() - .rooms + num_joined_members: self + .services .state_cache .room_joined_count(room_id) .unwrap_or_default() @@ -422,22 +412,22 @@ fn get_room_summary( .try_into() .expect("user count should not be that big"), room_id: room_id.to_owned(), - topic: services() - .rooms + topic: self + .services .state_accessor .get_room_topic(room_id) .unwrap_or(None), - world_readable: services().rooms.state_accessor.is_world_readable(room_id)?, - guest_can_join: services().rooms.state_accessor.guest_can_join(room_id)?, - avatar_url: services() - .rooms + world_readable: self.services.state_accessor.is_world_readable(room_id)?, + guest_can_join: self.services.state_accessor.guest_can_join(room_id)?, + avatar_url: self + .services .state_accessor .get_avatar(room_id)? 
.into_option() .unwrap_or_default() .url, join_rule, - room_type: services().rooms.state_accessor.get_room_type(room_id)?, + room_type: self.services.state_accessor.get_room_type(room_id)?, children_state, allowed_room_ids, }) @@ -487,7 +477,7 @@ pub async fn get_client_hierarchy( .into_iter() .rev() .skip_while(|(room, _)| { - if let Ok(short) = services().rooms.short.get_shortroomid(room) + if let Ok(short) = self.services.short.get_shortroomid(room) { short.as_ref() != short_room_ids.get(parents.len()) } else { @@ -541,7 +531,7 @@ pub async fn get_client_hierarchy( let mut short_room_ids = vec![]; for room in parents { - short_room_ids.push(services().rooms.short.get_or_create_shortroomid(&room)?); + short_room_ids.push(self.services.short.get_or_create_shortroomid(&room)?); } Some( @@ -559,128 +549,152 @@ pub async fn get_client_hierarchy( rooms: results, }) } -} -fn next_room_to_traverse( - stack: &mut Vec<Vec<(OwnedRoomId, Vec<OwnedServerName>)>>, parents: &mut VecDeque<OwnedRoomId>, -) -> Option<(OwnedRoomId, Vec<OwnedServerName>)> { - while stack.last().map_or(false, Vec::is_empty) { - stack.pop(); - parents.pop_back(); - } + /// Simply returns the stripped m.space.child events of a room + async fn get_stripped_space_child_events( + &self, room_id: &RoomId, + ) -> Result<Option<Vec<Raw<HierarchySpaceChildEvent>>>, Error> { + let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? 
else { + return Ok(None); + }; - stack.last_mut().and_then(Vec::pop) -} + let state = self + .services + .state_accessor + .state_full_ids(current_shortstatehash) + .await?; + let mut children_pdus = Vec::new(); + for (key, id) in state { + let (event_type, state_key) = self.services.short.get_statekey_from_short(key)?; + if event_type != StateEventType::SpaceChild { + continue; + } -/// Simply returns the stripped m.space.child events of a room -async fn get_stripped_space_child_events( - room_id: &RoomId, -) -> Result<Option<Vec<Raw<HierarchySpaceChildEvent>>>, Error> { - let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? else { - return Ok(None); - }; - - let state = services() - .rooms - .state_accessor - .state_full_ids(current_shortstatehash) - .await?; - let mut children_pdus = Vec::new(); - for (key, id) in state { - let (event_type, state_key) = services().rooms.short.get_statekey_from_short(key)?; - if event_type != StateEventType::SpaceChild { - continue; - } + let pdu = self + .services + .timeline + .get_pdu(&id)? + .ok_or_else(|| Error::bad_database("Event in space state not found"))?; - let pdu = services() - .rooms - .timeline - .get_pdu(&id)? 
- .ok_or_else(|| Error::bad_database("Event in space state not found"))?; + if serde_json::from_str::<SpaceChildEventContent>(pdu.content.get()) + .ok() + .map(|c| c.via) + .map_or(true, |v| v.is_empty()) + { + continue; + } - if serde_json::from_str::<SpaceChildEventContent>(pdu.content.get()) - .ok() - .map(|c| c.via) - .map_or(true, |v| v.is_empty()) - { - continue; + if OwnedRoomId::try_from(state_key).is_ok() { + children_pdus.push(pdu.to_stripped_spacechild_state_event()); + } } - if OwnedRoomId::try_from(state_key).is_ok() { - children_pdus.push(pdu.to_stripped_spacechild_state_event()); - } + Ok(Some(children_pdus)) } - Ok(Some(children_pdus)) -} - -/// With the given identifier, checks if a room is accessable -fn is_accessable_child( - current_room: &OwnedRoomId, join_rule: &SpaceRoomJoinRule, identifier: &Identifier<'_>, - allowed_room_ids: &Vec<OwnedRoomId>, -) -> bool { - // Note: unwrap_or_default for bool means false - match identifier { - Identifier::ServerName(server_name) => { - let room_id: &RoomId = current_room; - - // Checks if ACLs allow for the server to participate - if services() - .rooms - .event_handler - .acl_check(server_name, room_id) - .is_err() - { - return false; - } - }, - Identifier::UserId(user_id) => { - if services() - .rooms - .state_cache - .is_joined(user_id, current_room) - .unwrap_or_default() - || services() - .rooms + /// With the given identifier, checks if a room is accessable + fn is_accessible_child( + &self, current_room: &OwnedRoomId, join_rule: &SpaceRoomJoinRule, identifier: &Identifier<'_>, + allowed_room_ids: &Vec<OwnedRoomId>, + ) -> bool { + // Note: unwrap_or_default for bool means false + match identifier { + Identifier::ServerName(server_name) => { + let room_id: &RoomId = current_room; + + // Checks if ACLs allow for the server to participate + if self + .services + .event_handler + .acl_check(server_name, room_id) + .is_err() + { + return false; + } + }, + Identifier::UserId(user_id) => { + if self + 
.services .state_cache - .is_invited(user_id, current_room) + .is_joined(user_id, current_room) .unwrap_or_default() - { - return true; - } - }, - } // Takes care of join rules - match join_rule { - SpaceRoomJoinRule::Restricted => { - for room in allowed_room_ids { - match identifier { - Identifier::UserId(user) => { - if services() - .rooms - .state_cache - .is_joined(user, room) - .unwrap_or_default() - { - return true; - } - }, - Identifier::ServerName(server) => { - if services() - .rooms - .state_cache - .server_in_room(server, room) - .unwrap_or_default() - { - return true; - } - }, + || self + .services + .state_cache + .is_invited(user_id, current_room) + .unwrap_or_default() + { + return true; } - } - false - }, - SpaceRoomJoinRule::Public | SpaceRoomJoinRule::Knock | SpaceRoomJoinRule::KnockRestricted => true, - // Invite only, Private, or Custom join rule - _ => false, + }, + } // Takes care of join rules + match join_rule { + SpaceRoomJoinRule::Restricted => { + for room in allowed_room_ids { + match identifier { + Identifier::UserId(user) => { + if self + .services + .state_cache + .is_joined(user, room) + .unwrap_or_default() + { + return true; + } + }, + Identifier::ServerName(server) => { + if self + .services + .state_cache + .server_in_room(server, room) + .unwrap_or_default() + { + return true; + } + }, + } + } + false + }, + SpaceRoomJoinRule::Public | SpaceRoomJoinRule::Knock | SpaceRoomJoinRule::KnockRestricted => true, + // Invite only, Private, or Custom join rule + _ => false, + } + } +} + +// Here because cannot implement `From` across ruma-federation-api and +// ruma-client-api types +impl From<CachedSpaceHierarchySummary> for SpaceHierarchyRoomsChunk { + fn from(value: CachedSpaceHierarchySummary) -> Self { + let SpaceHierarchyParentSummary { + canonical_alias, + name, + num_joined_members, + room_id, + topic, + world_readable, + guest_can_join, + avatar_url, + join_rule, + room_type, + children_state, + .. 
+ } = value.summary; + + Self { + canonical_alias, + name, + num_joined_members, + room_id, + topic, + world_readable, + guest_can_join, + avatar_url, + join_rule, + room_type, + children_state, + } } } @@ -736,3 +750,14 @@ fn get_parent_children_via( }) .collect() } + +fn next_room_to_traverse( + stack: &mut Vec<Vec<(OwnedRoomId, Vec<OwnedServerName>)>>, parents: &mut VecDeque<OwnedRoomId>, +) -> Option<(OwnedRoomId, Vec<OwnedServerName>)> { + while stack.last().map_or(false, Vec::is_empty) { + stack.pop(); + parents.pop_back(); + } + + stack.last_mut().and_then(Vec::pop) +} diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index a3a317a58..cb219bc03 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -8,7 +8,7 @@ use conduit::{ utils::{calculate_hash, MutexMap, MutexMapGuard}, - warn, Error, Result, + warn, Error, PduEvent, Result, }; use data::Data; use ruma::{ @@ -23,19 +23,39 @@ }; use super::state_compressor::CompressedStateEvent; -use crate::{services, PduEvent}; +use crate::{globals, rooms, Dep}; pub struct Service { + services: Services, db: Data, pub mutex: RoomMutexMap, } +struct Services { + globals: Dep<globals::Service>, + short: Dep<rooms::short::Service>, + spaces: Dep<rooms::spaces::Service>, + state_cache: Dep<rooms::state_cache::Service>, + state_accessor: Dep<rooms::state_accessor::Service>, + state_compressor: Dep<rooms::state_compressor::Service>, + timeline: Dep<rooms::timeline::Service>, +} + type RoomMutexMap = MutexMap<OwnedRoomId, ()>; pub type RoomMutexGuard = MutexMapGuard<OwnedRoomId, ()>; impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { + services: Services { + globals: args.depend::<globals::Service>("globals"), + short: args.depend::<rooms::short::Service>("rooms::short"), + spaces: args.depend::<rooms::spaces::Service>("rooms::spaces"), + state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"), 
+ state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"), + state_compressor: args.depend::<rooms::state_compressor::Service>("rooms::state_compressor"), + timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"), + }, db: Data::new(args.db), mutex: RoomMutexMap::new(), })) @@ -62,14 +82,13 @@ pub async fn force_state( state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<()> { for event_id in statediffnew.iter().filter_map(|new| { - services() - .rooms + self.services .state_compressor .parse_compressed_state_event(new) .ok() .map(|(_, id)| id) }) { - let Some(pdu) = services().rooms.timeline.get_pdu_json(&event_id)? else { + let Some(pdu) = self.services.timeline.get_pdu_json(&event_id)? else { continue; }; @@ -94,7 +113,7 @@ pub async fn force_state( continue; }; - services().rooms.state_cache.update_membership( + self.services.state_cache.update_membership( room_id, &user_id, membership_event, @@ -105,8 +124,7 @@ pub async fn force_state( )?; }, TimelineEventType::SpaceChild => { - services() - .rooms + self.services .spaces .roomid_spacehierarchy_cache .lock() @@ -117,7 +135,7 @@ pub async fn force_state( } } - services().rooms.state_cache.update_joined_count(room_id)?; + self.services.state_cache.update_joined_count(room_id)?; self.db .set_room_state(room_id, shortstatehash, state_lock)?; @@ -133,10 +151,7 @@ pub async fn force_state( pub fn set_event_state( &self, event_id: &EventId, room_id: &RoomId, state_ids_compressed: Arc<HashSet<CompressedStateEvent>>, ) -> Result<u64> { - let shorteventid = services() - .rooms - .short - .get_or_create_shorteventid(event_id)?; + let shorteventid = self.services.short.get_or_create_shorteventid(event_id)?; let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?; @@ -147,20 +162,15 @@ pub fn set_event_state( .collect::<Vec<_>>(), ); - let (shortstatehash, already_existed) = services() - .rooms + let 
(shortstatehash, already_existed) = self + .services .short .get_or_create_shortstatehash(&state_hash)?; if !already_existed { let states_parents = previous_shortstatehash.map_or_else( || Ok(Vec::new()), - |p| { - services() - .rooms - .state_compressor - .load_shortstatehash_info(p) - }, + |p| self.services.state_compressor.load_shortstatehash_info(p), )?; let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { @@ -179,7 +189,7 @@ pub fn set_event_state( } else { (state_ids_compressed, Arc::new(HashSet::new())) }; - services().rooms.state_compressor.save_state_from_diff( + self.services.state_compressor.save_state_from_diff( shortstatehash, statediffnew, statediffremoved, @@ -199,8 +209,8 @@ pub fn set_event_state( /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. #[tracing::instrument(skip(self, new_pdu), level = "debug")] pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result<u64> { - let shorteventid = services() - .rooms + let shorteventid = self + .services .short .get_or_create_shorteventid(&new_pdu.event_id)?; @@ -214,21 +224,16 @@ pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result<u64> { let states_parents = previous_shortstatehash.map_or_else( || Ok(Vec::new()), #[inline] - |p| { - services() - .rooms - .state_compressor - .load_shortstatehash_info(p) - }, + |p| self.services.state_compressor.load_shortstatehash_info(p), )?; - let shortstatekey = services() - .rooms + let shortstatekey = self + .services .short .get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)?; - let new = services() - .rooms + let new = self + .services .state_compressor .compress_state_event(shortstatekey, &new_pdu.event_id)?; @@ -246,7 +251,7 @@ pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result<u64> { } // TODO: statehash with deterministic inputs - let shortstatehash = services().globals.next_count()?; + let shortstatehash = self.services.globals.next_count()?; let mut 
statediffnew = HashSet::new(); statediffnew.insert(new); @@ -256,7 +261,7 @@ pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result<u64> { statediffremoved.insert(*replaces); } - services().rooms.state_compressor.save_state_from_diff( + self.services.state_compressor.save_state_from_diff( shortstatehash, Arc::new(statediffnew), Arc::new(statediffremoved), @@ -275,22 +280,20 @@ pub fn calculate_invite_state(&self, invite_event: &PduEvent) -> Result<Vec<Raw< let mut state = Vec::new(); // Add recommended events if let Some(e) = - services() - .rooms + self.services .state_accessor .room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")? { state.push(e.to_stripped_state_event()); } if let Some(e) = - services() - .rooms + self.services .state_accessor .room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")? { state.push(e.to_stripped_state_event()); } - if let Some(e) = services().rooms.state_accessor.room_state_get( + if let Some(e) = self.services.state_accessor.room_state_get( &invite_event.room_id, &StateEventType::RoomCanonicalAlias, "", @@ -298,22 +301,20 @@ pub fn calculate_invite_state(&self, invite_event: &PduEvent) -> Result<Vec<Raw< state.push(e.to_stripped_state_event()); } if let Some(e) = - services() - .rooms + self.services .state_accessor .room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")? { state.push(e.to_stripped_state_event()); } if let Some(e) = - services() - .rooms + self.services .state_accessor .room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")? { state.push(e.to_stripped_state_event()); } - if let Some(e) = services().rooms.state_accessor.room_state_get( + if let Some(e) = self.services.state_accessor.room_state_get( &invite_event.room_id, &StateEventType::RoomMember, invite_event.sender.as_str(), @@ -339,8 +340,8 @@ pub fn set_room_state( /// Returns the room's version. 
#[tracing::instrument(skip(self), level = "debug")] pub fn get_room_version(&self, room_id: &RoomId) -> Result<RoomVersionId> { - let create_event = services() - .rooms + let create_event = self + .services .state_accessor .room_state_get(room_id, &StateEventType::RoomCreate, "")?; @@ -393,8 +394,7 @@ pub fn get_auth_events( let mut sauthevents = auth_events .into_iter() .filter_map(|(event_type, state_key)| { - services() - .rooms + self.services .short .get_shortstatekey(&event_type.to_string().into(), &state_key) .ok() @@ -403,8 +403,8 @@ pub fn get_auth_events( }) .collect::<HashMap<_, _>>(); - let full_state = services() - .rooms + let full_state = self + .services .state_compressor .load_shortstatehash_info(shortstatehash)? .pop() @@ -414,16 +414,14 @@ pub fn get_auth_events( Ok(full_state .iter() .filter_map(|compressed| { - services() - .rooms + self.services .state_compressor .parse_compressed_state_event(compressed) .ok() }) .filter_map(|(shortstatekey, event_id)| sauthevents.remove(&shortstatekey).map(|k| (k, event_id))) .filter_map(|(k, event_id)| { - services() - .rooms + self.services .timeline .get_pdu(&event_id) .ok() diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs index 7e9daeda2..4c85148db 100644 --- a/src/service/rooms/state_accessor/data.rs +++ b/src/service/rooms/state_accessor/data.rs @@ -1,28 +1,43 @@ use std::{collections::HashMap, sync::Arc}; -use conduit::{utils, Error, Result}; -use database::{Database, Map}; +use conduit::{utils, Error, PduEvent, Result}; +use database::Map; use ruma::{events::StateEventType, EventId, RoomId}; -use crate::{services, PduEvent}; +use crate::{rooms, Dep}; pub(super) struct Data { eventid_shorteventid: Arc<Map>, shorteventid_shortstatehash: Arc<Map>, + services: Services, +} + +struct Services { + short: Dep<rooms::short::Service>, + state: Dep<rooms::state::Service>, + state_compressor: Dep<rooms::state_compressor::Service>, + timeline: 
Dep<rooms::timeline::Service>, } impl Data { - pub(super) fn new(db: &Arc<Database>) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { eventid_shorteventid: db["eventid_shorteventid"].clone(), shorteventid_shortstatehash: db["shorteventid_shortstatehash"].clone(), + services: Services { + short: args.depend::<rooms::short::Service>("rooms::short"), + state: args.depend::<rooms::state::Service>("rooms::state"), + state_compressor: args.depend::<rooms::state_compressor::Service>("rooms::state_compressor"), + timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"), + }, } } #[allow(unused_qualifications)] // async traits pub(super) async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> { - let full_state = services() - .rooms + let full_state = self + .services .state_compressor .load_shortstatehash_info(shortstatehash)? .pop() @@ -31,8 +46,8 @@ pub(super) async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap let mut result = HashMap::new(); let mut i: u8 = 0; for compressed in full_state.iter() { - let parsed = services() - .rooms + let parsed = self + .services .state_compressor .parse_compressed_state_event(compressed)?; result.insert(parsed.0, parsed.1); @@ -49,8 +64,8 @@ pub(super) async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap pub(super) async fn state_full( &self, shortstatehash: u64, ) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { - let full_state = services() - .rooms + let full_state = self + .services .state_compressor .load_shortstatehash_info(shortstatehash)? .pop() @@ -60,11 +75,11 @@ pub(super) async fn state_full( let mut result = HashMap::new(); let mut i: u8 = 0; for compressed in full_state.iter() { - let (_, eventid) = services() - .rooms + let (_, eventid) = self + .services .state_compressor .parse_compressed_state_event(compressed)?; - if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? 
{ + if let Some(pdu) = self.services.timeline.get_pdu(&eventid)? { result.insert( ( pdu.kind.to_string().into(), @@ -92,15 +107,15 @@ pub(super) async fn state_full( pub(super) fn state_get_id( &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, ) -> Result<Option<Arc<EventId>>> { - let Some(shortstatekey) = services() - .rooms + let Some(shortstatekey) = self + .services .short .get_shortstatekey(event_type, state_key)? else { return Ok(None); }; - let full_state = services() - .rooms + let full_state = self + .services .state_compressor .load_shortstatehash_info(shortstatehash)? .pop() @@ -110,8 +125,7 @@ pub(super) fn state_get_id( .iter() .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) .and_then(|compressed| { - services() - .rooms + self.services .state_compressor .parse_compressed_state_event(compressed) .ok() @@ -125,7 +139,7 @@ pub(super) fn state_get( &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, ) -> Result<Option<Arc<PduEvent>>> { self.state_get_id(shortstatehash, event_type, state_key)? - .map_or(Ok(None), |event_id| services().rooms.timeline.get_pdu(&event_id)) + .map_or(Ok(None), |event_id| self.services.timeline.get_pdu(&event_id)) } /// Returns the state hash for this pdu. @@ -149,7 +163,7 @@ pub(super) fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64 pub(super) async fn room_state_full( &self, room_id: &RoomId, ) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { - if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { + if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? 
{ self.state_full(current_shortstatehash).await } else { Ok(HashMap::new()) @@ -161,7 +175,7 @@ pub(super) async fn room_state_full( pub(super) fn room_state_get_id( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, ) -> Result<Option<Arc<EventId>>> { - if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { + if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? { self.state_get_id(current_shortstatehash, event_type, state_key) } else { Ok(None) @@ -173,7 +187,7 @@ pub(super) fn room_state_get_id( pub(super) fn room_state_get( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, ) -> Result<Option<Arc<PduEvent>>> { - if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { + if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? { self.state_get(current_shortstatehash, event_type, state_key) } else { Ok(None) diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index bd3eb0a15..2526f1bdd 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -6,7 +6,7 @@ sync::{Arc, Mutex as StdMutex, Mutex}, }; -use conduit::{err, error, utils::math::usize_from_f64, warn, Error, Result}; +use conduit::{err, error, pdu::PduBuilder, utils::math::usize_from_f64, warn, Error, PduEvent, Result}; use data::Data; use lru_cache::LruCache; use ruma::{ @@ -33,14 +33,20 @@ }; use serde_json::value::to_raw_value; -use crate::{pdu::PduBuilder, rooms::state::RoomMutexGuard, services, PduEvent}; +use crate::{rooms, rooms::state::RoomMutexGuard, Dep}; pub struct Service { + services: Services, db: Data, pub server_visibility_cache: Mutex<LruCache<(OwnedServerName, u64), bool>>, pub user_visibility_cache: Mutex<LruCache<(OwnedUserId, u64), bool>>, } +struct Services { + state_cache: Dep<rooms::state_cache::Service>, + 
timeline: Dep<rooms::timeline::Service>, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { let config = &args.server.config; @@ -50,7 +56,11 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { f64::from(config.user_visibility_cache_capacity) * config.cache_capacity_modifier; Ok(Arc::new(Self { - db: Data::new(args.db), + services: Services { + state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"), + timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"), + }, + db: Data::new(&args), server_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(server_visibility_cache_capacity)?)), user_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(user_visibility_cache_capacity)?)), })) @@ -164,8 +174,8 @@ pub fn server_can_see_event(&self, origin: &ServerName, room_id: &RoomId, event_ }) .unwrap_or(HistoryVisibility::Shared); - let mut current_server_members = services() - .rooms + let mut current_server_members = self + .services .state_cache .room_members(room_id) .filter_map(Result::ok) @@ -212,7 +222,7 @@ pub fn user_can_see_event(&self, user_id: &UserId, room_id: &RoomId, event_id: & return Ok(*visibility); } - let currently_member = services().rooms.state_cache.is_joined(user_id, room_id)?; + let currently_member = self.services.state_cache.is_joined(user_id, room_id)?; let history_visibility = self .state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")? @@ -258,7 +268,7 @@ pub fn user_can_see_event(&self, user_id: &UserId, room_id: &RoomId, event_id: & /// the room's history_visibility at that event's state. 
#[tracing::instrument(skip(self, user_id, room_id))] pub fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { - let currently_member = services().rooms.state_cache.is_joined(user_id, room_id)?; + let currently_member = self.services.state_cache.is_joined(user_id, room_id)?; let history_visibility = self .room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")? @@ -342,8 +352,8 @@ pub fn user_can_invite( redacts: None, }; - Ok(services() - .rooms + Ok(self + .services .timeline .create_hash_and_sign_event(new_event, sender, room_id, state_lock) .is_ok()) @@ -413,7 +423,7 @@ pub fn user_can_redact( // Falling back on m.room.create to judge power level if let Some(pdu) = self.room_state_get(room_id, &StateEventType::RoomCreate, "")? { Ok(pdu.sender == sender - || if let Ok(Some(pdu)) = services().rooms.timeline.get_pdu(redacts) { + || if let Ok(Some(pdu)) = self.services.timeline.get_pdu(redacts) { pdu.sender == sender } else { false @@ -430,7 +440,7 @@ pub fn user_can_redact( .map(|event: RoomPowerLevels| { event.user_can_redact_event_of_other(sender) || event.user_can_redact_own_event(sender) - && if let Ok(Some(pdu)) = services().rooms.timeline.get_pdu(redacts) { + && if let Ok(Some(pdu)) = self.services.timeline.get_pdu(redacts) { if federation { pdu.sender.server_name() == sender.server_name() } else { diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index 2b9fbe941..cbda73cf5 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -4,7 +4,7 @@ }; use conduit::{utils, Error, Result}; -use database::{Database, Map}; +use database::Map; use itertools::Itertools; use ruma::{ events::{AnyStrippedStateEvent, AnySyncStateEvent}, @@ -12,44 +12,55 @@ OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; -use crate::{appservice::RegistrationInfo, services, user_is_local}; +use crate::{appservice::RegistrationInfo, 
globals, user_is_local, users, Dep}; type StrippedStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a>; type AnySyncStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + 'a>; type AppServiceInRoomCache = RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>; pub(super) struct Data { - userroomid_joined: Arc<Map>, - roomuserid_joined: Arc<Map>, - userroomid_invitestate: Arc<Map>, - roomuserid_invitecount: Arc<Map>, - userroomid_leftstate: Arc<Map>, - roomuserid_leftcount: Arc<Map>, + pub(super) appservice_in_room_cache: AppServiceInRoomCache, + roomid_invitedcount: Arc<Map>, roomid_inviteviaservers: Arc<Map>, - roomuseroncejoinedids: Arc<Map>, roomid_joinedcount: Arc<Map>, - roomid_invitedcount: Arc<Map>, roomserverids: Arc<Map>, + roomuserid_invitecount: Arc<Map>, + roomuserid_joined: Arc<Map>, + roomuserid_leftcount: Arc<Map>, + roomuseroncejoinedids: Arc<Map>, serverroomids: Arc<Map>, - pub(super) appservice_in_room_cache: AppServiceInRoomCache, + userroomid_invitestate: Arc<Map>, + userroomid_joined: Arc<Map>, + userroomid_leftstate: Arc<Map>, + services: Services, +} + +struct Services { + globals: Dep<globals::Service>, + users: Dep<users::Service>, } impl Data { - pub(super) fn new(db: &Arc<Database>) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { - userroomid_joined: db["userroomid_joined"].clone(), - roomuserid_joined: db["roomuserid_joined"].clone(), - userroomid_invitestate: db["userroomid_invitestate"].clone(), - roomuserid_invitecount: db["roomuserid_invitecount"].clone(), - userroomid_leftstate: db["userroomid_leftstate"].clone(), - roomuserid_leftcount: db["roomuserid_leftcount"].clone(), + appservice_in_room_cache: RwLock::new(HashMap::new()), + roomid_invitedcount: db["roomid_invitedcount"].clone(), roomid_inviteviaservers: db["roomid_inviteviaservers"].clone(), - roomuseroncejoinedids: 
db["roomuseroncejoinedids"].clone(), roomid_joinedcount: db["roomid_joinedcount"].clone(), - roomid_invitedcount: db["roomid_invitedcount"].clone(), roomserverids: db["roomserverids"].clone(), + roomuserid_invitecount: db["roomuserid_invitecount"].clone(), + roomuserid_joined: db["roomuserid_joined"].clone(), + roomuserid_leftcount: db["roomuserid_leftcount"].clone(), + roomuseroncejoinedids: db["roomuseroncejoinedids"].clone(), serverroomids: db["serverroomids"].clone(), - appservice_in_room_cache: RwLock::new(HashMap::new()), + userroomid_invitestate: db["userroomid_invitestate"].clone(), + userroomid_joined: db["userroomid_joined"].clone(), + userroomid_leftstate: db["userroomid_leftstate"].clone(), + services: Services { + globals: args.depend::<globals::Service>("globals"), + users: args.depend::<users::Service>("users"), + }, } } @@ -100,7 +111,7 @@ pub(super) fn mark_as_invited( &serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"), )?; self.roomuserid_invitecount - .insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; + .insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?; self.userroomid_joined.remove(&userroom_id)?; self.roomuserid_joined.remove(&roomuser_id)?; self.userroomid_leftstate.remove(&userroom_id)?; @@ -144,7 +155,7 @@ pub(super) fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result< &serde_json::to_vec(&Vec::<Raw<AnySyncStateEvent>>::new()).unwrap(), )?; // TODO self.roomuserid_leftcount - .insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; + .insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?; self.userroomid_joined.remove(&userroom_id)?; self.roomuserid_joined.remove(&roomuser_id)?; self.userroomid_invitestate.remove(&userroom_id)?; @@ -228,7 +239,7 @@ pub(super) fn appservice_in_room(&self, room_id: &RoomId, appservice: &Registrat } else { let bridge_user_id = UserId::parse_with_server_name( 
appservice.registration.sender_localpart.as_str(), - services().globals.server_name(), + self.services.globals.server_name(), ) .ok(); @@ -356,7 +367,7 @@ pub(super) fn active_local_users_in_room<'a>( ) -> Box<dyn Iterator<Item = OwnedUserId> + 'a> { Box::new( self.local_users_in_room(room_id) - .filter(|user| !services().users.is_deactivated(user).unwrap_or(true)), + .filter(|user| !self.services.users.is_deactivated(user).unwrap_or(true)), ) } diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 48215817e..ac2f688e7 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -21,16 +21,28 @@ OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; -use crate::{appservice::RegistrationInfo, services, user_is_local}; +use crate::{account_data, appservice::RegistrationInfo, rooms, user_is_local, users, Dep}; pub struct Service { + services: Services, db: Data, } +struct Services { + account_data: Dep<account_data::Service>, + state_accessor: Dep<rooms::state_accessor::Service>, + users: Dep<users::Service>, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { - db: Data::new(args.db), + services: Services { + account_data: args.depend::<account_data::Service>("account_data"), + state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"), + users: args.depend::<users::Service>("users"), + }, + db: Data::new(&args), })) } @@ -54,18 +66,18 @@ pub fn update_membership( // update #[allow(clippy::collapsible_if)] if !user_is_local(user_id) { - if !services().users.exists(user_id)? { - services().users.create(user_id, None)?; + if !self.services.users.exists(user_id)? { + self.services.users.create(user_id, None)?; } /* // Try to update our local copy of the user if ours does not match - if ((services().users.displayname(user_id)? 
!= membership_event.displayname) - || (services().users.avatar_url(user_id)? != membership_event.avatar_url) - || (services().users.blurhash(user_id)? != membership_event.blurhash)) + if ((self.services.users.displayname(user_id)? != membership_event.displayname) + || (self.services.users.avatar_url(user_id)? != membership_event.avatar_url) + || (self.services.users.blurhash(user_id)? != membership_event.blurhash)) && (membership != MembershipState::Leave) { - let response = services() + let response = self.services .sending .send_federation_request( user_id.server_name(), @@ -76,9 +88,9 @@ pub fn update_membership( ) .await; - services().users.set_displayname(user_id, response.displayname.clone()).await?; - services().users.set_avatar_url(user_id, response.avatar_url).await?; - services().users.set_blurhash(user_id, response.blurhash).await?; + self.services.users.set_displayname(user_id, response.displayname.clone()).await?; + self.services.users.set_avatar_url(user_id, response.avatar_url).await?; + self.services.users.set_blurhash(user_id, response.blurhash).await?; }; */ } @@ -91,8 +103,8 @@ pub fn update_membership( self.db.mark_as_once_joined(user_id, room_id)?; // Check if the room has a predecessor - if let Some(predecessor) = services() - .rooms + if let Some(predecessor) = self + .services .state_accessor .room_state_get(room_id, &StateEventType::RoomCreate, "")? .and_then(|create| serde_json::from_str(create.content.get()).ok()) @@ -124,21 +136,23 @@ pub fn update_membership( // .ok(); // Copy old tags to new room - if let Some(tag_event) = services() + if let Some(tag_event) = self + .services .account_data .get(Some(&predecessor.room_id), user_id, RoomAccountDataEventType::Tag)? .map(|event| { serde_json::from_str(event.get()) .map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}")))) }) { - services() + self.services .account_data .update(Some(room_id), user_id, RoomAccountDataEventType::Tag, &tag_event?) 
.ok(); }; // Copy direct chat flag - if let Some(direct_event) = services() + if let Some(direct_event) = self + .services .account_data .get(None, user_id, GlobalAccountDataEventType::Direct.to_string().into())? .map(|event| { @@ -156,7 +170,7 @@ pub fn update_membership( } if room_ids_updated { - services().account_data.update( + self.services.account_data.update( None, user_id, GlobalAccountDataEventType::Direct.to_string().into(), @@ -171,7 +185,8 @@ pub fn update_membership( }, MembershipState::Invite => { // We want to know if the sender is ignored by the receiver - let is_ignored = services() + let is_ignored = self + .services .account_data .get( None, // Ignored users are in global account data @@ -393,8 +408,8 @@ pub fn servers_invite_via(&self, room_id: &RoomId) -> impl Iterator<Item = Resul /// See <https://spec.matrix.org/v1.10/appendices/#routing> #[tracing::instrument(skip(self))] pub fn servers_route_via(&self, room_id: &RoomId) -> Result<Vec<OwnedServerName>> { - let most_powerful_user_server = services() - .rooms + let most_powerful_user_server = self + .services .state_accessor .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? 
.map(|pdu| { diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index 422c562b8..2550774e1 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -13,7 +13,7 @@ use ruma::{EventId, RoomId}; use self::data::StateDiff; -use crate::services; +use crate::{rooms, Dep}; type StateInfoLruCache = Mutex< LruCache< @@ -48,16 +48,25 @@ pub struct Service { db: Data, - + services: Services, pub stateinfo_cache: StateInfoLruCache, } +struct Services { + short: Dep<rooms::short::Service>, + state: Dep<rooms::state::Service>, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { let config = &args.server.config; let cache_capacity = f64::from(config.stateinfo_cache_capacity) * config.cache_capacity_modifier; Ok(Arc::new(Self { db: Data::new(args.db), + services: Services { + short: args.depend::<rooms::short::Service>("rooms::short"), + state: args.depend::<rooms::state::Service>("rooms::state"), + }, stateinfo_cache: StdMutex::new(LruCache::new(usize_from_f64(cache_capacity)?)), })) } @@ -124,8 +133,8 @@ pub fn load_shortstatehash_info(&self, shortstatehash: u64) -> ShortStateInfoRes pub fn compress_state_event(&self, shortstatekey: u64, event_id: &EventId) -> Result<CompressedStateEvent> { let mut v = shortstatekey.to_be_bytes().to_vec(); v.extend_from_slice( - &services() - .rooms + &self + .services .short .get_or_create_shorteventid(event_id)? 
.to_be_bytes(), @@ -138,7 +147,7 @@ pub fn compress_state_event(&self, shortstatekey: u64, event_id: &EventId) -> Re pub fn parse_compressed_state_event(&self, compressed_event: &CompressedStateEvent) -> Result<(u64, Arc<EventId>)> { Ok(( utils::u64_from_bytes(&compressed_event[0..size_of::<u64>()]).expect("bytes have right length"), - services().rooms.short.get_eventid_from_short( + self.services.short.get_eventid_from_short( utils::u64_from_bytes(&compressed_event[size_of::<u64>()..]).expect("bytes have right length"), )?, )) @@ -282,7 +291,7 @@ pub fn save_state_from_diff( pub fn save_state( &self, room_id: &RoomId, new_state_ids_compressed: Arc<HashSet<CompressedStateEvent>>, ) -> HashSetCompressStateEvent { - let previous_shortstatehash = services().rooms.state.get_room_shortstatehash(room_id)?; + let previous_shortstatehash = self.services.state.get_room_shortstatehash(room_id)?; let state_hash = utils::calculate_hash( &new_state_ids_compressed @@ -291,8 +300,8 @@ pub fn save_state( .collect::<Vec<_>>(), ); - let (new_shortstatehash, already_existed) = services() - .rooms + let (new_shortstatehash, already_existed) = self + .services .short .get_or_create_shortstatehash(&state_hash)?; diff --git a/src/service/rooms/threads/data.rs b/src/service/rooms/threads/data.rs index c4a1a2945..fb279a007 100644 --- a/src/service/rooms/threads/data.rs +++ b/src/service/rooms/threads/data.rs @@ -1,29 +1,40 @@ use std::{mem::size_of, sync::Arc}; -use conduit::{checked, utils, Error, Result}; -use database::{Database, Map}; +use conduit::{checked, utils, Error, PduEvent, Result}; +use database::Map; use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId}; -use crate::{services, PduEvent}; +use crate::{rooms, Dep}; type PduEventIterResult<'a> = Result<Box<dyn Iterator<Item = Result<(u64, PduEvent)>> + 'a>>; pub(super) struct Data { threadid_userids: Arc<Map>, + services: Services, +} + +struct Services { + short: 
Dep<rooms::short::Service>, + timeline: Dep<rooms::timeline::Service>, } impl Data { - pub(super) fn new(db: &Arc<Database>) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { threadid_userids: db["threadid_userids"].clone(), + services: Services { + short: args.depend::<rooms::short::Service>("rooms::short"), + timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"), + }, } } pub(super) fn threads_until<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads, ) -> PduEventIterResult<'a> { - let prefix = services() - .rooms + let prefix = self + .services .short .get_shortroomid(room_id)? .expect("room exists") @@ -40,8 +51,8 @@ pub(super) fn threads_until<'a>( .map(move |(pduid, _users)| { let count = utils::u64_from_bytes(&pduid[(size_of::<u64>())..]) .map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?; - let mut pdu = services() - .rooms + let mut pdu = self + .services .timeline .get_pdu_from_id(&pduid)? 
.ok_or_else(|| Error::bad_database("Invalid pduid reference in threadid_userids"))?; diff --git a/src/service/rooms/threads/mod.rs b/src/service/rooms/threads/mod.rs index dd2686b06..ae51cd0f9 100644 --- a/src/service/rooms/threads/mod.rs +++ b/src/service/rooms/threads/mod.rs @@ -2,7 +2,7 @@ use std::{collections::BTreeMap, sync::Arc}; -use conduit::{Error, Result}; +use conduit::{Error, PduEvent, Result}; use data::Data; use ruma::{ api::client::{error::ErrorKind, threads::get_threads::v1::IncludeThreads}, @@ -11,16 +11,24 @@ }; use serde_json::json; -use crate::{services, PduEvent}; +use crate::{rooms, Dep}; pub struct Service { + services: Services, db: Data, } +struct Services { + timeline: Dep<rooms::timeline::Service>, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { - db: Data::new(args.db), + services: Services { + timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"), + }, + db: Data::new(&args), })) } @@ -35,22 +43,22 @@ pub fn threads_until<'a>( } pub fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<()> { - let root_id = &services() - .rooms + let root_id = self + .services .timeline .get_pdu_id(root_event_id)? .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Invalid event id in thread message"))?; - let root_pdu = services() - .rooms + let root_pdu = self + .services .timeline - .get_pdu_from_id(root_id)? + .get_pdu_from_id(&root_id)? .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found"))?; - let mut root_pdu_json = services() - .rooms + let mut root_pdu_json = self + .services .timeline - .get_pdu_json_from_id(root_id)? + .get_pdu_json_from_id(&root_id)? 
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found"))?; if let CanonicalJsonValue::Object(unsigned) = root_pdu_json @@ -93,20 +101,19 @@ pub fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<( ); } - services() - .rooms + self.services .timeline - .replace_pdu(root_id, &root_pdu_json, &root_pdu)?; + .replace_pdu(&root_id, &root_pdu_json, &root_pdu)?; } let mut users = Vec::new(); - if let Some(userids) = self.db.get_participants(root_id)? { + if let Some(userids) = self.db.get_participants(&root_id)? { users.extend_from_slice(&userids); } else { users.push(root_pdu.sender); } users.push(pdu.sender.clone()); - self.db.update_participants(root_id, &users) + self.db.update_participants(&root_id, &users) } } diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index ec975b99d..5917e96b7 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -4,19 +4,25 @@ sync::{Arc, Mutex}, }; -use conduit::{checked, error, utils, Error, Result}; +use conduit::{checked, error, utils, Error, PduCount, PduEvent, Result}; use database::{Database, Map}; use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId}; -use crate::{services, PduCount, PduEvent}; +use crate::{rooms, Dep}; pub(super) struct Data { + eventid_outlierpdu: Arc<Map>, eventid_pduid: Arc<Map>, pduid_pdu: Arc<Map>, - eventid_outlierpdu: Arc<Map>, - userroomid_notificationcount: Arc<Map>, userroomid_highlightcount: Arc<Map>, + userroomid_notificationcount: Arc<Map>, pub(super) lasttimelinecount_cache: LastTimelineCountCache, + pub(super) db: Arc<Database>, + services: Services, +} + +struct Services { + short: Dep<rooms::short::Service>, } type PdusIterItem = Result<(PduCount, PduEvent)>; @@ -24,14 +30,19 @@ pub(super) struct Data { type LastTimelineCountCache = Mutex<HashMap<OwnedRoomId, PduCount>>; impl Data { - pub(super) fn new(db: 
&Arc<Database>) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { + eventid_outlierpdu: db["eventid_outlierpdu"].clone(), eventid_pduid: db["eventid_pduid"].clone(), pduid_pdu: db["pduid_pdu"].clone(), - eventid_outlierpdu: db["eventid_outlierpdu"].clone(), - userroomid_notificationcount: db["userroomid_notificationcount"].clone(), userroomid_highlightcount: db["userroomid_highlightcount"].clone(), + userroomid_notificationcount: db["userroomid_notificationcount"].clone(), lasttimelinecount_cache: Mutex::new(HashMap::new()), + db: args.db.clone(), + services: Services { + short: args.depend::<rooms::short::Service>("rooms::short"), + }, } } @@ -210,7 +221,7 @@ pub(super) fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, /// happened before the event with id `until` in reverse-chronological /// order. pub(super) fn pdus_until(&self, user_id: &UserId, room_id: &RoomId, until: PduCount) -> Result<PdusIterator<'_>> { - let (prefix, current) = count_to_id(room_id, until, 1, true)?; + let (prefix, current) = self.count_to_id(room_id, until, 1, true)?; let user_id = user_id.to_owned(); @@ -232,7 +243,7 @@ pub(super) fn pdus_until(&self, user_id: &UserId, room_id: &RoomId, until: PduCo } pub(super) fn pdus_after(&self, user_id: &UserId, room_id: &RoomId, from: PduCount) -> Result<PdusIterator<'_>> { - let (prefix, current) = count_to_id(room_id, from, 1, false)?; + let (prefix, current) = self.count_to_id(room_id, from, 1, false)?; let user_id = user_id.to_owned(); @@ -277,6 +288,41 @@ pub(super) fn increment_notification_counts( .increment_batch(highlights_batch.iter().map(Vec::as_slice))?; Ok(()) } + + pub(super) fn count_to_id( + &self, room_id: &RoomId, count: PduCount, offset: u64, subtract: bool, + ) -> Result<(Vec<u8>, Vec<u8>)> { + let prefix = self + .services + .short + .get_shortroomid(room_id)? + .ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))? 
+ .to_be_bytes() + .to_vec(); + let mut pdu_id = prefix.clone(); + // +1 so we don't send the base event + let count_raw = match count { + PduCount::Normal(x) => { + if subtract { + x.saturating_sub(offset) + } else { + x.saturating_add(offset) + } + }, + PduCount::Backfilled(x) => { + pdu_id.extend_from_slice(&0_u64.to_be_bytes()); + let num = u64::MAX.saturating_sub(x); + if subtract { + num.saturating_sub(offset) + } else { + num.saturating_add(offset) + } + }, + }; + pdu_id.extend_from_slice(&count_raw.to_be_bytes()); + + Ok((prefix, pdu_id)) + } } /// Returns the `count` of this pdu's id. @@ -294,38 +340,3 @@ pub(super) fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> { Ok(PduCount::Normal(last_u64)) } } - -pub(super) fn count_to_id( - room_id: &RoomId, count: PduCount, offset: u64, subtract: bool, -) -> Result<(Vec<u8>, Vec<u8>)> { - let prefix = services() - .rooms - .short - .get_shortroomid(room_id)? - .ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))? 
- .to_be_bytes() - .to_vec(); - let mut pdu_id = prefix.clone(); - // +1 so we don't send the base event - let count_raw = match count { - PduCount::Normal(x) => { - if subtract { - x.saturating_sub(offset) - } else { - x.saturating_add(offset) - } - }, - PduCount::Backfilled(x) => { - pdu_id.extend_from_slice(&0_u64.to_be_bytes()); - let num = u64::MAX.saturating_sub(x); - if subtract { - num.saturating_sub(offset) - } else { - num.saturating_add(offset) - } - }, - }; - pdu_id.extend_from_slice(&count_raw.to_be_bytes()); - - Ok((prefix, pdu_id)) -} diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 4c5e407a4..50d294753 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -7,11 +7,12 @@ }; use conduit::{ - debug, error, info, utils, + debug, error, info, + pdu::{EventHash, PduBuilder, PduCount, PduEvent}, + utils, utils::{MutexMap, MutexMapGuard}, - validated, warn, Error, Result, + validated, warn, Error, Result, Server, }; -use data::Data; use itertools::Itertools; use ruma::{ api::{client::error::ErrorKind, federation}, @@ -37,11 +38,10 @@ use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use tokio::sync::RwLock; +use self::data::Data; use crate::{ - appservice::NamespaceRegex, - pdu::{EventHash, PduBuilder}, - rooms::{event_handler::parse_incoming_pdu, state_compressor::CompressedStateEvent}, - server_is_ours, services, PduCount, PduEvent, + account_data, admin, appservice, appservice::NamespaceRegex, globals, pusher, rooms, + rooms::state_compressor::CompressedStateEvent, sending, server_is_ours, Dep, }; // Update Relationships @@ -67,17 +67,61 @@ struct ExtractBody { } pub struct Service { + services: Services, db: Data, pub mutex_insert: RoomMutexMap, } +struct Services { + server: Arc<Server>, + account_data: Dep<account_data::Service>, + appservice: Dep<appservice::Service>, + admin: Dep<admin::Service>, + alias: Dep<rooms::alias::Service>, + globals: 
Dep<globals::Service>, + short: Dep<rooms::short::Service>, + state: Dep<rooms::state::Service>, + state_cache: Dep<rooms::state_cache::Service>, + state_accessor: Dep<rooms::state_accessor::Service>, + pdu_metadata: Dep<rooms::pdu_metadata::Service>, + read_receipt: Dep<rooms::read_receipt::Service>, + sending: Dep<sending::Service>, + user: Dep<rooms::user::Service>, + pusher: Dep<pusher::Service>, + threads: Dep<rooms::threads::Service>, + search: Dep<rooms::search::Service>, + spaces: Dep<rooms::spaces::Service>, + event_handler: Dep<rooms::event_handler::Service>, +} + type RoomMutexMap = MutexMap<OwnedRoomId, ()>; pub type RoomMutexGuard = MutexMapGuard<OwnedRoomId, ()>; impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { - db: Data::new(args.db), + services: Services { + server: args.server.clone(), + account_data: args.depend::<account_data::Service>("account_data"), + appservice: args.depend::<appservice::Service>("appservice"), + admin: args.depend::<admin::Service>("admin"), + alias: args.depend::<rooms::alias::Service>("rooms::alias"), + globals: args.depend::<globals::Service>("globals"), + short: args.depend::<rooms::short::Service>("rooms::short"), + state: args.depend::<rooms::state::Service>("rooms::state"), + state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"), + state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"), + pdu_metadata: args.depend::<rooms::pdu_metadata::Service>("rooms::pdu_metadata"), + read_receipt: args.depend::<rooms::read_receipt::Service>("rooms::read_receipt"), + sending: args.depend::<sending::Service>("sending"), + user: args.depend::<rooms::user::Service>("rooms::user"), + pusher: args.depend::<pusher::Service>("pusher"), + threads: args.depend::<rooms::threads::Service>("rooms::threads"), + search: args.depend::<rooms::search::Service>("rooms::search"), + spaces: 
args.depend::<rooms::spaces::Service>("rooms::spaces"), + event_handler: args.depend::<rooms::event_handler::Service>("rooms::event_handler"), + }, + db: Data::new(&args), mutex_insert: RoomMutexMap::new(), })) } @@ -217,10 +261,10 @@ pub async fn append_pdu( state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<Vec<u8>> { // Coalesce database writes for the remainder of this scope. - let _cork = services().db.cork_and_flush(); + let _cork = self.db.db.cork_and_flush(); - let shortroomid = services() - .rooms + let shortroomid = self + .services .short .get_shortroomid(&pdu.room_id)? .expect("room exists"); @@ -233,14 +277,14 @@ pub async fn append_pdu( .entry("unsigned".to_owned()) .or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::default())) { - if let Some(shortstatehash) = services() - .rooms + if let Some(shortstatehash) = self + .services .state_accessor .pdu_shortstatehash(&pdu.event_id) .unwrap() { - if let Some(prev_state) = services() - .rooms + if let Some(prev_state) = self + .services .state_accessor .state_get(shortstatehash, &pdu.kind.to_string().into(), state_key) .unwrap() @@ -270,30 +314,26 @@ pub async fn append_pdu( } // We must keep track of all events that have been referenced. 
- services() - .rooms + self.services .pdu_metadata .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; - services() - .rooms + self.services .state .set_forward_extremities(&pdu.room_id, leaves, state_lock)?; let insert_lock = self.mutex_insert.lock(&pdu.room_id).await; - let count1 = services().globals.next_count()?; + let count1 = self.services.globals.next_count()?; // Mark as read first so the sending client doesn't get a notification even if // appending fails - services() - .rooms + self.services .read_receipt .private_read_set(&pdu.room_id, &pdu.sender, count1)?; - services() - .rooms + self.services .user .reset_notification_counts(&pdu.sender, &pdu.room_id)?; - let count2 = services().globals.next_count()?; + let count2 = self.services.globals.next_count()?; let mut pdu_id = shortroomid.to_be_bytes().to_vec(); pdu_id.extend_from_slice(&count2.to_be_bytes()); @@ -303,8 +343,8 @@ pub async fn append_pdu( drop(insert_lock); // See if the event matches any known pushers - let power_levels: RoomPowerLevelsEventContent = services() - .rooms + let power_levels: RoomPowerLevelsEventContent = self + .services .state_accessor .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? .map(|ev| { @@ -319,8 +359,8 @@ pub async fn append_pdu( let mut notifies = Vec::new(); let mut highlights = Vec::new(); - let mut push_target = services() - .rooms + let mut push_target = self + .services .state_cache .active_local_users_in_room(&pdu.room_id) .collect_vec(); @@ -341,7 +381,8 @@ pub async fn append_pdu( continue; } - let rules_for_user = services() + let rules_for_user = self + .services .account_data .get(None, user, GlobalAccountDataEventType::PushRules.to_string().into())? .map(|event| { @@ -357,7 +398,7 @@ pub async fn append_pdu( let mut notify = false; for action in - services() + self.services .pusher .get_actions(user, &rules_for_user, &power_levels, &sync_pdu, &pdu.room_id)? 
{ @@ -378,8 +419,10 @@ pub async fn append_pdu( highlights.push(user.clone()); } - for push_key in services().pusher.get_pushkeys(user) { - services().sending.send_pdu_push(&pdu_id, user, push_key?)?; + for push_key in self.services.pusher.get_pushkeys(user) { + self.services + .sending + .send_pdu_push(&pdu_id, user, push_key?)?; } } @@ -390,11 +433,11 @@ pub async fn append_pdu( TimelineEventType::RoomRedaction => { use RoomVersionId::*; - let room_version_id = services().rooms.state.get_room_version(&pdu.room_id)?; + let room_version_id = self.services.state.get_room_version(&pdu.room_id)?; match room_version_id { V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { if let Some(redact_id) = &pdu.redacts { - if services().rooms.state_accessor.user_can_redact( + if self.services.state_accessor.user_can_redact( redact_id, &pdu.sender, &pdu.room_id, @@ -412,7 +455,7 @@ pub async fn append_pdu( })?; if let Some(redact_id) = &content.redacts { - if services().rooms.state_accessor.user_can_redact( + if self.services.state_accessor.user_can_redact( redact_id, &pdu.sender, &pdu.room_id, @@ -433,8 +476,7 @@ pub async fn append_pdu( }, TimelineEventType::SpaceChild => { if let Some(_state_key) = &pdu.state_key { - services() - .rooms + self.services .spaces .roomid_spacehierarchy_cache .lock() @@ -455,7 +497,7 @@ pub async fn append_pdu( let invite_state = match content.membership { MembershipState::Invite => { - let state = services().rooms.state.calculate_invite_state(pdu)?; + let state = self.services.state.calculate_invite_state(pdu)?; Some(state) }, _ => None, @@ -463,7 +505,7 @@ pub async fn append_pdu( // Update our membership info, we do this here incase a user is invited // and immediately leaves we need the DB to record the invite event for auth - services().rooms.state_cache.update_membership( + self.services.state_cache.update_membership( &pdu.room_id, &target_user_id, content, @@ -479,13 +521,12 @@ pub async fn append_pdu( .map_err(|_| 
Error::bad_database("Invalid content in pdu."))?; if let Some(body) = content.body { - services() - .rooms + self.services .search .index_pdu(shortroomid, &pdu_id, &body)?; - if services().admin.is_admin_command(pdu, &body).await { - services() + if self.services.admin.is_admin_command(pdu, &body).await { + self.services .admin .command(body, Some((*pdu.event_id).into())) .await; @@ -497,8 +538,7 @@ pub async fn append_pdu( if let Ok(content) = serde_json::from_str::<ExtractRelatesToEventId>(pdu.content.get()) { if let Some(related_pducount) = self.get_pdu_count(&content.relates_to.event_id)? { - services() - .rooms + self.services .pdu_metadata .add_relation(PduCount::Normal(count2), related_pducount)?; } @@ -512,29 +552,25 @@ pub async fn append_pdu( // We need to do it again here, because replies don't have // event_id as a top level field if let Some(related_pducount) = self.get_pdu_count(&in_reply_to.event_id)? { - services() - .rooms + self.services .pdu_metadata .add_relation(PduCount::Normal(count2), related_pducount)?; } }, Relation::Thread(thread) => { - services() - .rooms - .threads - .add_to_thread(&thread.event_id, pdu)?; + self.services.threads.add_to_thread(&thread.event_id, pdu)?; }, _ => {}, // TODO: Aggregate other types } } - for appservice in services().appservice.read().await.values() { - if services() - .rooms + for appservice in self.services.appservice.read().await.values() { + if self + .services .state_cache .appservice_in_room(&pdu.room_id, appservice)? 
{ - services() + self.services .sending .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; continue; @@ -550,7 +586,7 @@ pub async fn append_pdu( { let appservice_uid = appservice.registration.sender_localpart.as_str(); if state_key_uid == appservice_uid { - services() + self.services .sending .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; continue; @@ -567,8 +603,7 @@ pub async fn append_pdu( .map_or(false, |state_key| users.is_match(state_key)) }; let matching_aliases = |aliases: &NamespaceRegex| { - services() - .rooms + self.services .alias .local_aliases_for_room(&pdu.room_id) .filter_map(Result::ok) @@ -579,7 +614,7 @@ pub async fn append_pdu( || appservice.rooms.is_match(pdu.room_id.as_str()) || matching_users(&appservice.users) { - services() + self.services .sending .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; } @@ -603,8 +638,8 @@ pub fn create_hash_and_sign_event( redacts, } = pdu_builder; - let prev_events: Vec<_> = services() - .rooms + let prev_events: Vec<_> = self + .services .state .get_forward_extremities(room_id)? 
.into_iter() @@ -612,28 +647,23 @@ pub fn create_hash_and_sign_event( .collect(); // If there was no create event yet, assume we are creating a room - let room_version_id = services() - .rooms - .state - .get_room_version(room_id) - .or_else(|_| { - if event_type == TimelineEventType::RoomCreate { - let content = serde_json::from_str::<RoomCreateEventContent>(content.get()) - .expect("Invalid content in RoomCreate pdu."); - Ok(content.room_version) - } else { - Err(Error::InconsistentRoomState( - "non-create event for room of unknown version", - room_id.to_owned(), - )) - } - })?; + let room_version_id = self.services.state.get_room_version(room_id).or_else(|_| { + if event_type == TimelineEventType::RoomCreate { + let content = serde_json::from_str::<RoomCreateEventContent>(content.get()) + .expect("Invalid content in RoomCreate pdu."); + Ok(content.room_version) + } else { + Err(Error::InconsistentRoomState( + "non-create event for room of unknown version", + room_id.to_owned(), + )) + } + })?; let room_version = RoomVersion::new(&room_version_id).expect("room version is supported"); let auth_events = - services() - .rooms + self.services .state .get_auth_events(room_id, &event_type, sender, state_key.as_deref(), &content)?; @@ -649,8 +679,7 @@ pub fn create_hash_and_sign_event( if let Some(state_key) = &state_key { if let Some(prev_pdu) = - services() - .rooms + self.services .state_accessor .room_state_get(room_id, &event_type.to_string().into(), state_key)? 
{ @@ -730,12 +759,12 @@ pub fn create_hash_and_sign_event( // Add origin because synapse likes that (and it's required in the spec) pdu_json.insert( "origin".to_owned(), - to_canonical_value(services().globals.server_name()).expect("server name is a valid CanonicalJsonValue"), + to_canonical_value(self.services.globals.server_name()).expect("server name is a valid CanonicalJsonValue"), ); match ruma::signatures::hash_and_sign_event( - services().globals.server_name().as_str(), - services().globals.keypair(), + self.services.globals.server_name().as_str(), + self.services.globals.keypair(), &mut pdu_json, &room_version_id, ) { @@ -763,8 +792,8 @@ pub fn create_hash_and_sign_event( ); // Generate short event id - let _shorteventid = services() - .rooms + let _shorteventid = self + .services .short .get_or_create_shorteventid(&pdu.event_id)?; @@ -783,7 +812,7 @@ pub async fn build_and_append_pdu( state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<Arc<EventId>> { let (pdu, pdu_json) = self.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)?; - if let Some(admin_room) = services().admin.get_admin_room()? { + if let Some(admin_room) = self.services.admin.get_admin_room()? 
{ if admin_room == room_id { match pdu.event_type() { TimelineEventType::RoomEncryption => { @@ -798,7 +827,7 @@ pub async fn build_and_append_pdu( .state_key() .filter(|v| v.starts_with('@')) .unwrap_or(sender.as_str()); - let server_user = &services().globals.server_user.to_string(); + let server_user = &self.services.globals.server_user.to_string(); let content = serde_json::from_str::<RoomMemberEventContent>(pdu.content.get()) .map_err(|_| Error::bad_database("Invalid content in pdu"))?; @@ -812,8 +841,8 @@ pub async fn build_and_append_pdu( )); } - let count = services() - .rooms + let count = self + .services .state_cache .room_members(room_id) .filter_map(Result::ok) @@ -837,8 +866,8 @@ pub async fn build_and_append_pdu( )); } - let count = services() - .rooms + let count = self + .services .state_cache .room_members(room_id) .filter_map(Result::ok) @@ -861,15 +890,14 @@ pub async fn build_and_append_pdu( // If redaction event is not authorized, do not append it to the timeline if pdu.kind == TimelineEventType::RoomRedaction { use RoomVersionId::*; - match services().rooms.state.get_room_version(&pdu.room_id)? { + match self.services.state.get_room_version(&pdu.room_id)? { V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { if let Some(redact_id) = &pdu.redacts { - if !services().rooms.state_accessor.user_can_redact( - redact_id, - &pdu.sender, - &pdu.room_id, - false, - )? { + if !self + .services + .state_accessor + .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false)? + { return Err(Error::BadRequest(ErrorKind::forbidden(), "User cannot redact this event.")); } }; @@ -879,12 +907,11 @@ pub async fn build_and_append_pdu( .map_err(|_| Error::bad_database("Invalid content in redaction pdu."))?; if let Some(redact_id) = &content.redacts { - if !services().rooms.state_accessor.user_can_redact( - redact_id, - &pdu.sender, - &pdu.room_id, - false, - )? 
{ + if !self + .services + .state_accessor + .user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false)? + { return Err(Error::BadRequest(ErrorKind::forbidden(), "User cannot redact this event.")); } } @@ -895,7 +922,7 @@ pub async fn build_and_append_pdu( // We append to state before appending the pdu, so we don't have a moment in // time with the pdu without it's state. This is okay because append_pdu can't // fail. - let statehashid = services().rooms.state.append_to_state(&pdu)?; + let statehashid = self.services.state.append_to_state(&pdu)?; let pdu_id = self .append_pdu( @@ -910,13 +937,12 @@ pub async fn build_and_append_pdu( // We set the room state after inserting the pdu, so that we never have a moment // in time where events in the current room state do not exist - services() - .rooms + self.services .state .set_room_state(room_id, statehashid, state_lock)?; - let mut servers: HashSet<OwnedServerName> = services() - .rooms + let mut servers: HashSet<OwnedServerName> = self + .services .state_cache .room_servers(room_id) .filter_map(Result::ok) @@ -936,9 +962,9 @@ pub async fn build_and_append_pdu( // Remove our server from the server list since it will be added to it by // room_servers() and/or the if statement above - servers.remove(services().globals.server_name()); + servers.remove(self.services.globals.server_name()); - services() + self.services .sending .send_pdu_servers(servers.into_iter(), &pdu_id)?; @@ -960,18 +986,15 @@ pub async fn append_incoming_pdu( // We append to state before appending the pdu, so we don't have a moment in // time with the pdu without it's state. This is okay because append_pdu can't // fail. 
- services() - .rooms + self.services .state .set_event_state(&pdu.event_id, &pdu.room_id, state_ids_compressed)?; if soft_fail { - services() - .rooms + self.services .pdu_metadata .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; - services() - .rooms + self.services .state .set_forward_extremities(&pdu.room_id, new_room_leaves, state_lock)?; return Ok(None); @@ -1022,14 +1045,13 @@ pub fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent, shortroomid: u64 if let Ok(content) = serde_json::from_str::<ExtractBody>(pdu.content.get()) { if let Some(body) = content.body { - services() - .rooms + self.services .search .deindex_pdu(shortroomid, &pdu_id, &body)?; } } - let room_version_id = services().rooms.state.get_room_version(&pdu.room_id)?; + let room_version_id = self.services.state.get_room_version(&pdu.room_id)?; pdu.redact(room_version_id, reason)?; @@ -1058,8 +1080,8 @@ pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Re return Ok(()); } - let power_levels: RoomPowerLevelsEventContent = services() - .rooms + let power_levels: RoomPowerLevelsEventContent = self + .services .state_accessor .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? 
.map(|ev| { @@ -1077,8 +1099,8 @@ pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Re } }); - let room_alias_servers = services() - .rooms + let room_alias_servers = self + .services .alias .local_aliases_for_room(room_id) .filter_map(|alias| { @@ -1090,14 +1112,13 @@ pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Re let servers = room_mods .chain(room_alias_servers) - .chain(services().globals.config.trusted_servers.clone()) + .chain(self.services.server.config.trusted_servers.clone()) .filter(|server_name| { if server_is_ours(server_name) { return false; } - services() - .rooms + self.services .state_cache .server_in_room(server_name, room_id) .unwrap_or(false) @@ -1105,7 +1126,8 @@ pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Re for backfill_server in servers { info!("Asking {backfill_server} for backfill"); - let response = services() + let response = self + .services .sending .send_federation_request( &backfill_server, @@ -1141,11 +1163,11 @@ pub async fn backfill_pdu( &self, origin: &ServerName, pdu: Box<RawJsonValue>, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, ) -> Result<()> { - let (event_id, value, room_id) = parse_incoming_pdu(&pdu)?; + let (event_id, value, room_id) = self.services.event_handler.parse_incoming_pdu(&pdu)?; // Lock so we cannot backfill the same pdu twice at the same time - let mutex_lock = services() - .rooms + let mutex_lock = self + .services .event_handler .mutex_federation .lock(&room_id) @@ -1158,14 +1180,12 @@ pub async fn backfill_pdu( return Ok(()); } - services() - .rooms + self.services .event_handler .fetch_required_signing_keys([&value], pub_key_map) .await?; - services() - .rooms + self.services .event_handler .handle_incoming_pdu(origin, &room_id, &event_id, value, false, pub_key_map) .await?; @@ -1173,8 +1193,8 @@ pub async fn backfill_pdu( let value = self.get_pdu_json(&event_id)?.expect("We just created 
it"); let pdu = self.get_pdu(&event_id)?.expect("We just created it"); - let shortroomid = services() - .rooms + let shortroomid = self + .services .short .get_shortroomid(&room_id)? .expect("room exists"); @@ -1182,7 +1202,7 @@ pub async fn backfill_pdu( let insert_lock = self.mutex_insert.lock(&room_id).await; let max = u64::MAX; - let count = services().globals.next_count()?; + let count = self.services.globals.next_count()?; let mut pdu_id = shortroomid.to_be_bytes().to_vec(); pdu_id.extend_from_slice(&0_u64.to_be_bytes()); pdu_id.extend_from_slice(&(validated!(max - count)?).to_be_bytes()); @@ -1197,8 +1217,7 @@ pub async fn backfill_pdu( .map_err(|_| Error::bad_database("Invalid content in pdu."))?; if let Some(body) = content.body { - services() - .rooms + self.services .search .index_pdu(shortroomid, &pdu_id, &body)?; } diff --git a/src/service/rooms/typing/mod.rs b/src/service/rooms/typing/mod.rs index 715e31624..d863f2175 100644 --- a/src/service/rooms/typing/mod.rs +++ b/src/service/rooms/typing/mod.rs @@ -1,6 +1,6 @@ use std::{collections::BTreeMap, sync::Arc}; -use conduit::{debug_info, trace, utils, Result}; +use conduit::{debug_info, trace, utils, Result, Server}; use ruma::{ api::federation::transactions::edu::{Edu, TypingContent}, events::SyncEphemeralRoomEvent, @@ -8,19 +8,31 @@ }; use tokio::sync::{broadcast, RwLock}; -use crate::{services, user_is_local}; +use crate::{globals, sending, user_is_local, Dep}; pub struct Service { - pub typing: RwLock<BTreeMap<OwnedRoomId, BTreeMap<OwnedUserId, u64>>>, // u64 is unix timestamp of timeout - pub last_typing_update: RwLock<BTreeMap<OwnedRoomId, u64>>, /* timestamp of the last change to - * typing - * users */ + server: Arc<Server>, + services: Services, + /// u64 is unix timestamp of timeout + pub typing: RwLock<BTreeMap<OwnedRoomId, BTreeMap<OwnedUserId, u64>>>, + /// timestamp of the last change to typing users + pub last_typing_update: RwLock<BTreeMap<OwnedRoomId, u64>>, pub typing_update_sender: 
broadcast::Sender<OwnedRoomId>, } +struct Services { + globals: Dep<globals::Service>, + sending: Dep<sending::Service>, +} + impl crate::Service for Service { - fn build(_args: crate::Args<'_>) -> Result<Arc<Self>> { + fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { + server: args.server.clone(), + services: Services { + globals: args.depend::<globals::Service>("globals"), + sending: args.depend::<sending::Service>("sending"), + }, typing: RwLock::new(BTreeMap::new()), last_typing_update: RwLock::new(BTreeMap::new()), typing_update_sender: broadcast::channel(100).0, @@ -45,14 +57,14 @@ pub async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) self.last_typing_update .write() .await - .insert(room_id.to_owned(), services().globals.next_count()?); + .insert(room_id.to_owned(), self.services.globals.next_count()?); if self.typing_update_sender.send(room_id.to_owned()).is_err() { trace!("receiver found what it was looking for and is no longer interested"); } // update federation if user_is_local(user_id) { - Self::federation_send(room_id, user_id, true)?; + self.federation_send(room_id, user_id, true)?; } Ok(()) @@ -71,14 +83,14 @@ pub async fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result< self.last_typing_update .write() .await - .insert(room_id.to_owned(), services().globals.next_count()?); + .insert(room_id.to_owned(), self.services.globals.next_count()?); if self.typing_update_sender.send(room_id.to_owned()).is_err() { trace!("receiver found what it was looking for and is no longer interested"); } // update federation if user_is_local(user_id) { - Self::federation_send(room_id, user_id, false)?; + self.federation_send(room_id, user_id, false)?; } Ok(()) @@ -126,7 +138,7 @@ async fn typings_maintain(&self, room_id: &RoomId) -> Result<()> { self.last_typing_update .write() .await - .insert(room_id.to_owned(), services().globals.next_count()?); + .insert(room_id.to_owned(), 
self.services.globals.next_count()?); if self.typing_update_sender.send(room_id.to_owned()).is_err() { trace!("receiver found what it was looking for and is no longer interested"); } @@ -134,7 +146,7 @@ async fn typings_maintain(&self, room_id: &RoomId) -> Result<()> { // update federation for user in removable { if user_is_local(&user) { - Self::federation_send(room_id, &user, false)?; + self.federation_send(room_id, &user, false)?; } } } @@ -171,15 +183,15 @@ pub async fn typings_all( }) } - fn federation_send(room_id: &RoomId, user_id: &UserId, typing: bool) -> Result<()> { + fn federation_send(&self, room_id: &RoomId, user_id: &UserId, typing: bool) -> Result<()> { debug_assert!(user_is_local(user_id), "tried to broadcast typing status of remote user",); - if !services().globals.config.allow_outgoing_typing { + if !self.server.config.allow_outgoing_typing { return Ok(()); } let edu = Edu::Typing(TypingContent::new(room_id.to_owned(), user_id.to_owned(), typing)); - services() + self.services .sending .send_edu_room(room_id, serde_json::to_vec(&edu).expect("Serialized Edu::Typing"))?; diff --git a/src/service/rooms/user/data.rs b/src/service/rooms/user/data.rs index 618caae00..c71316153 100644 --- a/src/service/rooms/user/data.rs +++ b/src/service/rooms/user/data.rs @@ -1,10 +1,10 @@ use std::sync::Arc; use conduit::{utils, Error, Result}; -use database::{Database, Map}; +use database::Map; use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; -use crate::services; +use crate::{globals, rooms, Dep}; pub(super) struct Data { userroomid_notificationcount: Arc<Map>, @@ -12,16 +12,27 @@ pub(super) struct Data { roomuserid_lastnotificationread: Arc<Map>, roomsynctoken_shortstatehash: Arc<Map>, userroomid_joined: Arc<Map>, + services: Services, +} + +struct Services { + globals: Dep<globals::Service>, + short: Dep<rooms::short::Service>, } impl Data { - pub(super) fn new(db: &Arc<Database>) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = 
&args.db; Self { userroomid_notificationcount: db["userroomid_notificationcount"].clone(), userroomid_highlightcount: db["userroomid_highlightcount"].clone(), roomuserid_lastnotificationread: db["userroomid_highlightcount"].clone(), //< NOTE: known bug from conduit roomsynctoken_shortstatehash: db["roomsynctoken_shortstatehash"].clone(), userroomid_joined: db["userroomid_joined"].clone(), + services: Services { + globals: args.depend::<globals::Service>("globals"), + short: args.depend::<rooms::short::Service>("rooms::short"), + }, } } @@ -39,7 +50,7 @@ pub(super) fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomI .insert(&userroom_id, &0_u64.to_be_bytes())?; self.roomuserid_lastnotificationread - .insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; + .insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?; Ok(()) } @@ -87,8 +98,8 @@ pub(super) fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) pub(super) fn associate_token_shortstatehash( &self, room_id: &RoomId, token: u64, shortstatehash: u64, ) -> Result<()> { - let shortroomid = services() - .rooms + let shortroomid = self + .services .short .get_shortroomid(room_id)? .expect("room exists"); @@ -101,8 +112,8 @@ pub(super) fn associate_token_shortstatehash( } pub(super) fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> { - let shortroomid = services() - .rooms + let shortroomid = self + .services .short .get_shortroomid(room_id)? 
.expect("room exists"); diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs index 12124a576..93d38470f 100644 --- a/src/service/rooms/user/mod.rs +++ b/src/service/rooms/user/mod.rs @@ -3,9 +3,10 @@ use std::sync::Arc; use conduit::Result; -use data::Data; use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; +use self::data::Data; + pub struct Service { db: Data, } @@ -13,7 +14,7 @@ pub struct Service { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { - db: Data::new(args.db), + db: Data::new(&args), })) } diff --git a/src/service/sending/appservice.rs b/src/service/sending/appservice.rs index 9e060e811..5ed40ad93 100644 --- a/src/service/sending/appservice.rs +++ b/src/service/sending/appservice.rs @@ -1,16 +1,17 @@ use std::{fmt::Debug, mem}; use bytes::BytesMut; +use conduit::{debug_error, trace, utils, warn, Error, Result}; +use reqwest::Client; use ruma::api::{appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken}; -use tracing::{trace, warn}; - -use crate::{debug_error, services, utils, Error, Result}; /// Sends a request to an appservice /// /// Only returns Ok(None) if there is no url specified in the appservice /// registration file -pub(crate) async fn send_request<T>(registration: Registration, request: T) -> Result<Option<T::IncomingResponse>> +pub(crate) async fn send_request<T>( + client: &Client, registration: Registration, request: T, +) -> Result<Option<T::IncomingResponse>> where T: OutgoingRequest + Debug + Send, { @@ -48,15 +49,10 @@ pub(crate) async fn send_request<T>(registration: Registration, request: T) -> R let reqwest_request = reqwest::Request::try_from(http_request)?; - let mut response = services() - .client - .appservice - .execute(reqwest_request) - .await - .map_err(|e| { - warn!("Could not send request to appservice \"{}\" at {dest}: {e}", registration.id); - e - })?; + let mut response = 
client.execute(reqwest_request).await.map_err(|e| { + warn!("Could not send request to appservice \"{}\" at {dest}: {e}", registration.id); + e + })?; // reqwest::Response -> http::Response conversion let status = response.status(); diff --git a/src/service/sending/data.rs b/src/service/sending/data.rs index 9cb1c2670..6c8e2544d 100644 --- a/src/service/sending/data.rs +++ b/src/service/sending/data.rs @@ -5,7 +5,7 @@ use ruma::{ServerName, UserId}; use super::{Destination, SendingEvent}; -use crate::services; +use crate::{globals, Dep}; type OutgoingSendingIter<'a> = Box<dyn Iterator<Item = Result<(Vec<u8>, Destination, SendingEvent)>> + 'a>; type SendingEventIter<'a> = Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEvent)>> + 'a>; @@ -15,15 +15,24 @@ pub struct Data { servernameevent_data: Arc<Map>, servername_educount: Arc<Map>, pub(super) db: Arc<Database>, + services: Services, +} + +struct Services { + globals: Dep<globals::Service>, } impl Data { - pub(super) fn new(db: Arc<Database>) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { servercurrentevent_data: db["servercurrentevent_data"].clone(), servernameevent_data: db["servernameevent_data"].clone(), servername_educount: db["servername_educount"].clone(), - db, + db: args.db.clone(), + services: Services { + globals: args.depend::<globals::Service>("globals"), + }, } } @@ -78,7 +87,7 @@ pub(super) fn queue_requests(&self, requests: &[(&Destination, SendingEvent)]) - if let SendingEvent::Pdu(value) = &event { key.extend_from_slice(value); } else { - key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + key.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); } let value = if let SendingEvent::Edu(value) = &event { &**value diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index 26f43fd32..6f091b04f 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -6,26 +6,39 @@ use 
std::{fmt::Debug, sync::Arc}; use async_trait::async_trait; -use conduit::{err, Result, Server}; +use conduit::{err, warn, Result, Server}; use ruma::{ api::{appservice::Registration, OutgoingRequest}, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; -pub use sender::convert_to_outgoing_federation_event; use tokio::sync::Mutex; -use tracing::warn; -use crate::{server_is_ours, services}; +use crate::{account_data, client, globals, presence, pusher, resolver, rooms, server_is_ours, users, Dep}; pub struct Service { - pub db: data::Data, server: Arc<Server>, - - /// The state for a given state hash. + services: Services, + pub db: data::Data, sender: loole::Sender<Msg>, receiver: Mutex<loole::Receiver<Msg>>, } +struct Services { + client: Dep<client::Service>, + globals: Dep<globals::Service>, + resolver: Dep<resolver::Service>, + state: Dep<rooms::state::Service>, + state_cache: Dep<rooms::state_cache::Service>, + user: Dep<rooms::user::Service>, + users: Dep<users::Service>, + presence: Dep<presence::Service>, + read_receipt: Dep<rooms::read_receipt::Service>, + timeline: Dep<rooms::timeline::Service>, + account_data: Dep<account_data::Service>, + appservice: Dep<crate::appservice::Service>, + pusher: Dep<pusher::Service>, +} + #[derive(Clone, Debug, PartialEq, Eq)] struct Msg { dest: Destination, @@ -53,8 +66,23 @@ impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { let (sender, receiver) = loole::unbounded(); Ok(Arc::new(Self { - db: data::Data::new(args.db.clone()), server: args.server.clone(), + services: Services { + client: args.depend::<client::Service>("client"), + globals: args.depend::<globals::Service>("globals"), + resolver: args.depend::<resolver::Service>("resolver"), + state: args.depend::<rooms::state::Service>("rooms::state"), + state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"), + user: args.depend::<rooms::user::Service>("rooms::user"), + users: 
args.depend::<users::Service>("users"), + presence: args.depend::<presence::Service>("presence"), + read_receipt: args.depend::<rooms::read_receipt::Service>("rooms::read_receipt"), + timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"), + account_data: args.depend::<account_data::Service>("account_data"), + appservice: args.depend::<crate::appservice::Service>("appservice"), + pusher: args.depend::<pusher::Service>("pusher"), + }, + db: data::Data::new(&args), sender, receiver: Mutex::new(receiver), })) @@ -103,8 +131,8 @@ pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec<u8>) -> Res #[tracing::instrument(skip(self, room_id, pdu_id), level = "debug")] pub fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<()> { - let servers = services() - .rooms + let servers = self + .services .state_cache .room_servers(room_id) .filter_map(Result::ok) @@ -152,8 +180,8 @@ pub fn send_edu_server(&self, server: &ServerName, serialized: Vec<u8>) -> Resul #[tracing::instrument(skip(self, room_id, serialized), level = "debug")] pub fn send_edu_room(&self, room_id: &RoomId, serialized: Vec<u8>) -> Result<()> { - let servers = services() - .rooms + let servers = self + .services .state_cache .room_servers(room_id) .filter_map(Result::ok) @@ -189,8 +217,8 @@ pub fn send_edu_servers<I: Iterator<Item = OwnedServerName>>(&self, servers: I, #[tracing::instrument(skip(self, room_id), level = "debug")] pub fn flush_room(&self, room_id: &RoomId) -> Result<()> { - let servers = services() - .rooms + let servers = self + .services .state_cache .room_servers(room_id) .filter_map(Result::ok) @@ -213,13 +241,13 @@ pub fn flush_servers<I: Iterator<Item = OwnedServerName>>(&self, servers: I) -> Ok(()) } - #[tracing::instrument(skip(self, request), name = "request")] + #[tracing::instrument(skip_all, name = "request")] pub async fn send_federation_request<T>(&self, dest: &ServerName, request: T) -> Result<T::IncomingResponse> where T: OutgoingRequest 
+ Debug + Send, { - let client = &services().client.federation; - send::send(client, dest, request).await + let client = &self.services.client.federation; + self.send(client, dest, request).await } /// Sends a request to an appservice @@ -232,7 +260,8 @@ pub async fn send_appservice_request<T>( where T: OutgoingRequest + Debug + Send, { - appservice::send_request(registration, request).await + let client = &self.services.client.appservice; + appservice::send_request(client, registration, request).await } /// Cleanup event data diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs index 7901de48b..b3a84d622 100644 --- a/src/service/sending/send.rs +++ b/src/service/sending/send.rs @@ -1,6 +1,8 @@ use std::{fmt::Debug, mem}; -use conduit::Err; +use conduit::{ + debug, debug_error, debug_warn, error::inspect_debug_log, trace, utils::string::EMPTY, Err, Error, Result, +}; use http::{header::AUTHORIZATION, HeaderValue}; use ipaddress::IPAddress; use reqwest::{Client, Method, Request, Response, Url}; @@ -13,75 +15,91 @@ server_util::authorization::XMatrix, ServerName, }; -use tracing::{debug, trace}; use crate::{ - debug_error, debug_warn, resolver, + globals, resolver, resolver::{actual::ActualDest, cache::CachedDest}, - services, Error, Result, }; -#[tracing::instrument(skip_all, name = "send")] -pub async fn send<T>(client: &Client, dest: &ServerName, req: T) -> Result<T::IncomingResponse> -where - T: OutgoingRequest + Debug + Send, -{ - if !services().globals.allow_federation() { - return Err!(Config("allow_federation", "Federation is disabled.")); - } +impl super::Service { + #[tracing::instrument(skip(self, client, req), name = "send")] + pub async fn send<T>(&self, client: &Client, dest: &ServerName, req: T) -> Result<T::IncomingResponse> + where + T: OutgoingRequest + Debug + Send, + { + if !self.server.config.allow_federation { + return Err!(Config("allow_federation", "Federation is disabled.")); + } - let actual = 
services().resolver.get_actual_dest(dest).await?; - let request = prepare::<T>(dest, &actual, req).await?; - execute::<T>(client, dest, &actual, request).await -} + let actual = self.services.resolver.get_actual_dest(dest).await?; + let request = self.prepare::<T>(dest, &actual, req).await?; + self.execute::<T>(dest, &actual, request, client).await + } -async fn execute<T>( - client: &Client, dest: &ServerName, actual: &ActualDest, request: Request, -) -> Result<T::IncomingResponse> -where - T: OutgoingRequest + Debug + Send, -{ - let method = request.method().clone(); - let url = request.url().clone(); - debug!( - method = ?method, - url = ?url, - "Sending request", - ); - match client.execute(request).await { - Ok(response) => handle_response::<T>(dest, actual, &method, &url, response).await, - Err(e) => handle_error::<T>(dest, actual, &method, &url, e), + async fn execute<T>( + &self, dest: &ServerName, actual: &ActualDest, request: Request, client: &Client, + ) -> Result<T::IncomingResponse> + where + T: OutgoingRequest + Debug + Send, + { + let url = request.url().clone(); + let method = request.method().clone(); + + debug!(?method, ?url, "Sending request"); + match client.execute(request).await { + Ok(response) => handle_response::<T>(&self.services.resolver, dest, actual, &method, &url, response).await, + Err(error) => handle_error::<T>(dest, actual, &method, &url, error), + } } -} -async fn prepare<T>(dest: &ServerName, actual: &ActualDest, req: T) -> Result<Request> -where - T: OutgoingRequest + Debug + Send, -{ - const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_5]; + async fn prepare<T>(&self, dest: &ServerName, actual: &ActualDest, req: T) -> Result<Request> + where + T: OutgoingRequest + Debug + Send, + { + const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_5]; + const SATIR: SendAccessToken<'_> = SendAccessToken::IfRequired(EMPTY); - trace!("Preparing request"); + trace!("Preparing request"); + let mut http_request = req + 
.try_into_http_request::<Vec<u8>>(&actual.string, SATIR, &VERSIONS) + .map_err(|_| Error::BadServerResponse("Invalid destination"))?; - let mut http_request = req - .try_into_http_request::<Vec<u8>>(&actual.string, SendAccessToken::IfRequired(""), &VERSIONS) - .map_err(|_e| Error::BadServerResponse("Invalid destination"))?; + sign_request::<T>(&self.services.globals, dest, &mut http_request); - sign_request::<T>(dest, &mut http_request); + let request = Request::try_from(http_request)?; + self.validate_url(request.url())?; - let request = Request::try_from(http_request)?; - validate_url(request.url())?; + Ok(request) + } - Ok(request) + fn validate_url(&self, url: &Url) -> Result<()> { + if let Some(url_host) = url.host_str() { + if let Ok(ip) = IPAddress::parse(url_host) { + trace!("Checking request URL IP {ip:?}"); + self.services.resolver.validate_ip(&ip)?; + } + } + + Ok(()) + } } async fn handle_response<T>( - dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url, mut response: Response, + resolver: &resolver::Service, dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url, + mut response: Response, ) -> Result<T::IncomingResponse> where T: OutgoingRequest + Debug + Send, { - trace!("Received response from {} for {} with {}", actual.string, url, response.url()); let status = response.status(); + trace!( + ?status, ?method, + request_url = ?url, + response_url = ?response.url(), + "Received response from {}", + actual.string, + ); + let mut http_response_builder = http::Response::builder() .status(status) .version(response.version()); @@ -92,11 +110,13 @@ async fn handle_response<T>( .expect("http::response::Builder is usable"), ); - trace!("Waiting for response body"); - let body = response.bytes().await.unwrap_or_else(|e| { - debug_error!("server error {}", e); - Vec::new().into() - }); // TODO: handle timeout + // TODO: handle timeout + trace!("Waiting for response body..."); + let body = response + .bytes() + .await + 
.inspect_err(inspect_debug_log) + .unwrap_or_else(|_| Vec::new().into()); let http_response = http_response_builder .body(body) @@ -109,7 +129,7 @@ async fn handle_response<T>( let response = T::IncomingResponse::try_from_http_response(http_response); if response.is_ok() && !actual.cached { - services().resolver.set_cached_destination( + resolver.set_cached_destination( dest.to_owned(), CachedDest { dest: actual.dest.clone(), @@ -120,7 +140,7 @@ async fn handle_response<T>( } match response { - Err(_e) => Err(Error::BadServerResponse("Server returned bad 200 response.")), + Err(_) => Err(Error::BadServerResponse("Server returned bad 200 response.")), Ok(response) => Ok(response), } } @@ -150,7 +170,7 @@ fn handle_error<T>( Err(e.into()) } -fn sign_request<T>(dest: &ServerName, http_request: &mut http::Request<Vec<u8>>) +fn sign_request<T>(globals: &globals::Service, dest: &ServerName, http_request: &mut http::Request<Vec<u8>>) where T: OutgoingRequest + Debug + Send, { @@ -172,16 +192,12 @@ fn sign_request<T>(dest: &ServerName, http_request: &mut http::Request<Vec<u8>>) .to_string() .into(), ); - req_map.insert("origin".to_owned(), services().globals.server_name().as_str().into()); + req_map.insert("origin".to_owned(), globals.server_name().as_str().into()); req_map.insert("destination".to_owned(), dest.as_str().into()); let mut req_json = serde_json::from_value(req_map.into()).expect("valid JSON is valid BTreeMap"); - ruma::signatures::sign_json( - services().globals.server_name().as_str(), - services().globals.keypair(), - &mut req_json, - ) - .expect("our request json is what ruma expects"); + ruma::signatures::sign_json(globals.server_name().as_str(), globals.keypair(), &mut req_json) + .expect("our request json is what ruma expects"); let req_json: serde_json::Map<String, serde_json::Value> = serde_json::from_slice(&serde_json::to_vec(&req_json).unwrap()).unwrap(); @@ -207,24 +223,8 @@ fn sign_request<T>(dest: &ServerName, http_request: &mut 
http::Request<Vec<u8>>) http_request.headers_mut().insert( AUTHORIZATION, - HeaderValue::from(&XMatrix::new( - services().globals.config.server_name.clone(), - dest.to_owned(), - key, - sig, - )), + HeaderValue::from(&XMatrix::new(globals.config.server_name.clone(), dest.to_owned(), key, sig)), ); } } } - -fn validate_url(url: &Url) -> Result<()> { - if let Some(url_host) = url.host_str() { - if let Ok(ip) = IPAddress::parse(url_host) { - trace!("Checking request URL IP {ip:?}"); - resolver::actual::validate_ip(&ip)?; - } - } - - Ok(()) -} diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index 774c3d694..0668ce242 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -6,7 +6,11 @@ }; use base64::{engine::general_purpose, Engine as _}; -use conduit::{debug, debug_warn, error, trace, utils::math::continue_exponential_backoff_secs, warn}; +use conduit::{ + debug, debug_warn, error, trace, + utils::{calculate_hash, math::continue_exponential_backoff_secs}, + warn, Error, Result, +}; use federation::transactions::send_transaction_message; use futures_util::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; use ruma::{ @@ -24,8 +28,8 @@ use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use tokio::time::sleep_until; -use super::{appservice, send, Destination, Msg, SendingEvent, Service}; -use crate::{presence::Presence, services, user_is_local, utils::calculate_hash, Error, Result}; +use super::{appservice, Destination, Msg, SendingEvent, Service}; +use crate::user_is_local; #[derive(Debug)] enum TransactionStatus { @@ -69,8 +73,8 @@ pub(super) async fn sender(&self) -> Result<()> { Ok(()) } - fn handle_response( - &self, response: SendingResult, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus, + fn handle_response<'a>( + &'a self, response: SendingResult, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus, ) { match response { Ok(dest) => 
self.handle_response_ok(&dest, futures, statuses), @@ -91,8 +95,8 @@ fn handle_response_err( }); } - fn handle_response_ok( - &self, dest: &Destination, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus, + fn handle_response_ok<'a>( + &'a self, dest: &Destination, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus, ) { let _cork = self.db.db.cork(); self.db @@ -113,24 +117,24 @@ fn handle_response_ok( .mark_as_active(&new_events) .expect("marked as active"); let new_events_vec = new_events.into_iter().map(|(event, _)| event).collect(); - futures.push(Box::pin(send_events(dest.clone(), new_events_vec))); + futures.push(Box::pin(self.send_events(dest.clone(), new_events_vec))); } else { statuses.remove(dest); } } - fn handle_request(&self, msg: Msg, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus) { + fn handle_request<'a>(&'a self, msg: Msg, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus) { let iv = vec![(msg.event, msg.queue_id)]; if let Ok(Some(events)) = self.select_events(&msg.dest, iv, statuses) { if !events.is_empty() { - futures.push(Box::pin(send_events(msg.dest, events))); + futures.push(Box::pin(self.send_events(msg.dest, events))); } else { statuses.remove(&msg.dest); } } } - async fn finish_responses(&self, futures: &mut SendingFutures<'_>, statuses: &mut CurTransactionStatus) { + async fn finish_responses<'a>(&'a self, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus) { let now = Instant::now(); let timeout = Duration::from_millis(CLEANUP_TIMEOUT_MS); let deadline = now.checked_add(timeout).unwrap_or(now); @@ -148,7 +152,7 @@ async fn finish_responses(&self, futures: &mut SendingFutures<'_>, statuses: &mu debug_warn!("Leaving with {} unfinished requests...", futures.len()); } - fn initial_requests(&self, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus) { + fn initial_requests<'a>(&'a self, futures: &SendingFutures<'a>, statuses: &mut 
CurTransactionStatus) { let keep = usize::try_from(self.server.config.startup_netburst_keep).unwrap_or(usize::MAX); let mut txns = HashMap::<Destination, Vec<SendingEvent>>::new(); for (key, dest, event) in self.db.active_requests().filter_map(Result::ok) { @@ -166,12 +170,12 @@ fn initial_requests(&self, futures: &SendingFutures<'_>, statuses: &mut CurTrans for (dest, events) in txns { if self.server.config.startup_netburst && !events.is_empty() { statuses.insert(dest.clone(), TransactionStatus::Running); - futures.push(Box::pin(send_events(dest.clone(), events))); + futures.push(Box::pin(self.send_events(dest.clone(), events))); } } } - #[tracing::instrument(skip_all)] + #[tracing::instrument(skip_all, level = "debug")] fn select_events( &self, dest: &Destination, @@ -218,7 +222,7 @@ fn select_events( Ok(Some(events)) } - #[tracing::instrument(skip_all)] + #[tracing::instrument(skip_all, level = "debug")] fn select_events_current(&self, dest: Destination, statuses: &mut CurTransactionStatus) -> Result<(bool, bool)> { let (mut allow, mut retry) = (true, false); statuses @@ -244,7 +248,7 @@ fn select_events_current(&self, dest: Destination, statuses: &mut CurTransaction Ok((allow, retry)) } - #[tracing::instrument(skip_all)] + #[tracing::instrument(skip_all, level = "debug")] fn select_edus(&self, server_name: &ServerName) -> Result<(Vec<Vec<u8>>, u64)> { // u64: count of last edu let since = self.db.get_latest_educount(server_name)?; @@ -252,11 +256,11 @@ fn select_edus(&self, server_name: &ServerName) -> Result<(Vec<Vec<u8>>, u64)> { let mut max_edu_count = since; let mut device_list_changes = HashSet::new(); - for room_id in services().rooms.state_cache.server_rooms(server_name) { + for room_id in self.services.state_cache.server_rooms(server_name) { let room_id = room_id?; // Look for device list updates in this room device_list_changes.extend( - services() + self.services .users .keys_changed(room_id.as_ref(), since, None) .filter_map(Result::ok) @@ -264,7 
+268,7 @@ fn select_edus(&self, server_name: &ServerName) -> Result<(Vec<Vec<u8>>, u64)> { ); if self.server.config.allow_outgoing_read_receipts - && !select_edus_receipts(&room_id, since, &mut max_edu_count, &mut events)? + && !self.select_edus_receipts(&room_id, since, &mut max_edu_count, &mut events)? { break; } @@ -287,381 +291,390 @@ fn select_edus(&self, server_name: &ServerName) -> Result<(Vec<Vec<u8>>, u64)> { } if self.server.config.allow_outgoing_presence { - select_edus_presence(server_name, since, &mut max_edu_count, &mut events)?; + self.select_edus_presence(server_name, since, &mut max_edu_count, &mut events)?; } Ok((events, max_edu_count)) } -} -/// Look for presence -fn select_edus_presence( - server_name: &ServerName, since: u64, max_edu_count: &mut u64, events: &mut Vec<Vec<u8>>, -) -> Result<bool> { - // Look for presence updates for this server - let mut presence_updates = Vec::new(); - for (user_id, count, presence_bytes) in services().presence.presence_since(since) { - *max_edu_count = cmp::max(count, *max_edu_count); - - if !user_is_local(&user_id) { - continue; - } + /// Look for presence + fn select_edus_presence( + &self, server_name: &ServerName, since: u64, max_edu_count: &mut u64, events: &mut Vec<Vec<u8>>, + ) -> Result<bool> { + // Look for presence updates for this server + let mut presence_updates = Vec::new(); + for (user_id, count, presence_bytes) in self.services.presence.presence_since(since) { + *max_edu_count = cmp::max(count, *max_edu_count); + + if !user_is_local(&user_id) { + continue; + } - if !services() - .rooms - .state_cache - .server_sees_user(server_name, &user_id)? - { - continue; - } + if !self + .services + .state_cache + .server_sees_user(server_name, &user_id)? 
+ { + continue; + } - let presence_event = Presence::from_json_bytes_to_event(&presence_bytes, &user_id)?; - presence_updates.push(PresenceUpdate { - user_id, - presence: presence_event.content.presence, - currently_active: presence_event.content.currently_active.unwrap_or(false), - last_active_ago: presence_event - .content - .last_active_ago - .unwrap_or_else(|| uint!(0)), - status_msg: presence_event.content.status_msg, - }); + let presence_event = self + .services + .presence + .from_json_bytes_to_event(&presence_bytes, &user_id)?; + presence_updates.push(PresenceUpdate { + user_id, + presence: presence_event.content.presence, + currently_active: presence_event.content.currently_active.unwrap_or(false), + last_active_ago: presence_event + .content + .last_active_ago + .unwrap_or_else(|| uint!(0)), + status_msg: presence_event.content.status_msg, + }); - if presence_updates.len() >= SELECT_EDU_LIMIT { - break; + if presence_updates.len() >= SELECT_EDU_LIMIT { + break; + } + } + + if !presence_updates.is_empty() { + let presence_content = Edu::Presence(PresenceContent::new(presence_updates)); + events.push(serde_json::to_vec(&presence_content).expect("PresenceEvent can be serialized")); } - } - if !presence_updates.is_empty() { - let presence_content = Edu::Presence(PresenceContent::new(presence_updates)); - events.push(serde_json::to_vec(&presence_content).expect("PresenceEvent can be serialized")); + Ok(true) } - Ok(true) -} + /// Look for read receipts in this room + fn select_edus_receipts( + &self, room_id: &RoomId, since: u64, max_edu_count: &mut u64, events: &mut Vec<Vec<u8>>, + ) -> Result<bool> { + for r in self + .services + .read_receipt + .readreceipts_since(room_id, since) + { + let (user_id, count, read_receipt) = r?; + *max_edu_count = cmp::max(count, *max_edu_count); -/// Look for read receipts in this room -fn select_edus_receipts( - room_id: &RoomId, since: u64, max_edu_count: &mut u64, events: &mut Vec<Vec<u8>>, -) -> Result<bool> { - for r in 
services() - .rooms - .read_receipt - .readreceipts_since(room_id, since) - { - let (user_id, count, read_receipt) = r?; - *max_edu_count = cmp::max(count, *max_edu_count); - - if !user_is_local(&user_id) { - continue; - } + if !user_is_local(&user_id) { + continue; + } - let event = serde_json::from_str(read_receipt.json().get()) - .map_err(|_| Error::bad_database("Invalid edu event in read_receipts."))?; - let federation_event = if let AnySyncEphemeralRoomEvent::Receipt(r) = event { - let mut read = BTreeMap::new(); - - let (event_id, mut receipt) = r - .content - .0 - .into_iter() - .next() - .expect("we only use one event per read receipt"); - let receipt = receipt - .remove(&ReceiptType::Read) - .expect("our read receipts always set this") - .remove(&user_id) - .expect("our read receipts always have the user here"); - - read.insert( - user_id, - ReceiptData { - data: receipt.clone(), - event_ids: vec![event_id.clone()], - }, - ); + let event = serde_json::from_str(read_receipt.json().get()) + .map_err(|_| Error::bad_database("Invalid edu event in read_receipts."))?; + let federation_event = if let AnySyncEphemeralRoomEvent::Receipt(r) = event { + let mut read = BTreeMap::new(); + + let (event_id, mut receipt) = r + .content + .0 + .into_iter() + .next() + .expect("we only use one event per read receipt"); + let receipt = receipt + .remove(&ReceiptType::Read) + .expect("our read receipts always set this") + .remove(&user_id) + .expect("our read receipts always have the user here"); + + read.insert( + user_id, + ReceiptData { + data: receipt.clone(), + event_ids: vec![event_id.clone()], + }, + ); - let receipt_map = ReceiptMap { - read, - }; + let receipt_map = ReceiptMap { + read, + }; - let mut receipts = BTreeMap::new(); - receipts.insert(room_id.to_owned(), receipt_map); + let mut receipts = BTreeMap::new(); + receipts.insert(room_id.to_owned(), receipt_map); - Edu::Receipt(ReceiptContent { - receipts, - }) - } else { - Error::bad_database("Invalid event 
type in read_receipts"); - continue; - }; + Edu::Receipt(ReceiptContent { + receipts, + }) + } else { + Error::bad_database("Invalid event type in read_receipts"); + continue; + }; - events.push(serde_json::to_vec(&federation_event).expect("json can be serialized")); + events.push(serde_json::to_vec(&federation_event).expect("json can be serialized")); - if events.len() >= SELECT_EDU_LIMIT { - return Ok(false); + if events.len() >= SELECT_EDU_LIMIT { + return Ok(false); + } } - } - - Ok(true) -} -async fn send_events(dest: Destination, events: Vec<SendingEvent>) -> SendingResult { - //debug_assert!(!events.is_empty(), "sending empty transaction"); - match dest { - Destination::Normal(ref server) => send_events_dest_normal(&dest, server, events).await, - Destination::Appservice(ref id) => send_events_dest_appservice(&dest, id, events).await, - Destination::Push(ref userid, ref pushkey) => send_events_dest_push(&dest, userid, pushkey, events).await, + Ok(true) } -} -#[tracing::instrument(skip(dest, events))] -async fn send_events_dest_appservice(dest: &Destination, id: &str, events: Vec<SendingEvent>) -> SendingResult { - let mut pdu_jsons = Vec::new(); - - for event in &events { - match event { - SendingEvent::Pdu(pdu_id) => { - pdu_jsons.push( - services() - .rooms - .timeline - .get_pdu_from_id(pdu_id) - .map_err(|e| (dest.clone(), e))? - .ok_or_else(|| { - ( - dest.clone(), - Error::bad_database("[Appservice] Event in servernameevent_data not found in db."), - ) - })? - .to_room_event(), - ); - }, - SendingEvent::Edu(_) | SendingEvent::Flush => { - // Appservices don't need EDUs (?) 
and flush only; - // no new content + async fn send_events(&self, dest: Destination, events: Vec<SendingEvent>) -> SendingResult { + //debug_assert!(!events.is_empty(), "sending empty transaction"); + match dest { + Destination::Normal(ref server) => self.send_events_dest_normal(&dest, server, events).await, + Destination::Appservice(ref id) => self.send_events_dest_appservice(&dest, id, events).await, + Destination::Push(ref userid, ref pushkey) => { + self.send_events_dest_push(&dest, userid, pushkey, events) + .await }, } } - //debug_assert!(!pdu_jsons.is_empty(), "sending empty transaction"); - match appservice::send_request( - services() - .appservice - .get_registration(id) - .await - .ok_or_else(|| { - ( - dest.clone(), - Error::bad_database("[Appservice] Could not load registration from db."), - ) - })?, - ruma::api::appservice::event::push_events::v1::Request { - events: pdu_jsons, - txn_id: (&*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash( - &events - .iter() - .map(|e| match e { - SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b, - SendingEvent::Flush => &[], - }) - .collect::<Vec<_>>(), - ))) - .into(), - }, - ) - .await - { - Ok(_) => Ok(dest.clone()), - Err(e) => Err((dest.clone(), e)), - } -} + #[tracing::instrument(skip(self, dest, events), name = "appservice")] + async fn send_events_dest_appservice( + &self, dest: &Destination, id: &str, events: Vec<SendingEvent>, + ) -> SendingResult { + let mut pdu_jsons = Vec::new(); + + for event in &events { + match event { + SendingEvent::Pdu(pdu_id) => { + pdu_jsons.push( + self.services + .timeline + .get_pdu_from_id(pdu_id) + .map_err(|e| (dest.clone(), e))? + .ok_or_else(|| { + ( + dest.clone(), + Error::bad_database("[Appservice] Event in servernameevent_data not found in db."), + ) + })? + .to_room_event(), + ); + }, + SendingEvent::Edu(_) | SendingEvent::Flush => { + // Appservices don't need EDUs (?) 
and flush only; + // no new content + }, + } + } -#[tracing::instrument(skip(dest, events))] -async fn send_events_dest_push( - dest: &Destination, userid: &OwnedUserId, pushkey: &str, events: Vec<SendingEvent>, -) -> SendingResult { - let mut pdus = Vec::new(); - - for event in &events { - match event { - SendingEvent::Pdu(pdu_id) => { - pdus.push( - services() - .rooms - .timeline - .get_pdu_from_id(pdu_id) - .map_err(|e| (dest.clone(), e))? - .ok_or_else(|| { - ( - dest.clone(), - Error::bad_database("[Push] Event in servernameevent_data not found in db."), - ) - })?, - ); - }, - SendingEvent::Edu(_) | SendingEvent::Flush => { - // Push gateways don't need EDUs (?) and flush only; - // no new content + //debug_assert!(!pdu_jsons.is_empty(), "sending empty transaction"); + let client = &self.services.client.appservice; + match appservice::send_request( + client, + self.services + .appservice + .get_registration(id) + .await + .ok_or_else(|| { + ( + dest.clone(), + Error::bad_database("[Appservice] Could not load registration from db."), + ) + })?, + ruma::api::appservice::event::push_events::v1::Request { + events: pdu_jsons, + txn_id: (&*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash( + &events + .iter() + .map(|e| match e { + SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b, + SendingEvent::Flush => &[], + }) + .collect::<Vec<_>>(), + ))) + .into(), }, + ) + .await + { + Ok(_) => Ok(dest.clone()), + Err(e) => Err((dest.clone(), e)), } } - for pdu in pdus { - // Redacted events are not notification targets (we don't send push for them) - if let Some(unsigned) = &pdu.unsigned { - if let Ok(unsigned) = serde_json::from_str::<serde_json::Value>(unsigned.get()) { - if unsigned.get("redacted_because").is_some() { - continue; - } + #[tracing::instrument(skip(self, dest, events), name = "push")] + async fn send_events_dest_push( + &self, dest: &Destination, userid: &OwnedUserId, pushkey: &str, events: Vec<SendingEvent>, + ) -> SendingResult { + let mut pdus 
= Vec::new(); + + for event in &events { + match event { + SendingEvent::Pdu(pdu_id) => { + pdus.push( + self.services + .timeline + .get_pdu_from_id(pdu_id) + .map_err(|e| (dest.clone(), e))? + .ok_or_else(|| { + ( + dest.clone(), + Error::bad_database("[Push] Event in servernameevent_data not found in db."), + ) + })?, + ); + }, + SendingEvent::Edu(_) | SendingEvent::Flush => { + // Push gateways don't need EDUs (?) and flush only; + // no new content + }, } } - let Some(pusher) = services() - .pusher - .get_pusher(userid, pushkey) - .map_err(|e| (dest.clone(), e))? - else { - continue; - }; + for pdu in pdus { + // Redacted events are not notification targets (we don't send push for them) + if let Some(unsigned) = &pdu.unsigned { + if let Ok(unsigned) = serde_json::from_str::<serde_json::Value>(unsigned.get()) { + if unsigned.get("redacted_because").is_some() { + continue; + } + } + } - let rules_for_user = services() - .account_data - .get(None, userid, GlobalAccountDataEventType::PushRules.to_string().into()) - .unwrap_or_default() - .and_then(|event| serde_json::from_str::<PushRulesEvent>(event.get()).ok()) - .map_or_else(|| push::Ruleset::server_default(userid), |ev: PushRulesEvent| ev.content.global); - - let unread: UInt = services() - .rooms - .user - .notification_count(userid, &pdu.room_id) - .map_err(|e| (dest.clone(), e))? - .try_into() - .expect("notification count can't go that high"); - - let _response = services() - .pusher - .send_push_notice(userid, unread, &pusher, rules_for_user, &pdu) - .await - .map(|_response| dest.clone()) - .map_err(|e| (dest.clone(), e)); - } + let Some(pusher) = self + .services + .pusher + .get_pusher(userid, pushkey) + .map_err(|e| (dest.clone(), e))? 
+ else { + continue; + }; - Ok(dest.clone()) -} + let rules_for_user = self + .services + .account_data + .get(None, userid, GlobalAccountDataEventType::PushRules.to_string().into()) + .unwrap_or_default() + .and_then(|event| serde_json::from_str::<PushRulesEvent>(event.get()).ok()) + .map_or_else(|| push::Ruleset::server_default(userid), |ev: PushRulesEvent| ev.content.global); + + let unread: UInt = self + .services + .user + .notification_count(userid, &pdu.room_id) + .map_err(|e| (dest.clone(), e))? + .try_into() + .expect("notification count can't go that high"); + + let _response = self + .services + .pusher + .send_push_notice(userid, unread, &pusher, rules_for_user, &pdu) + .await + .map(|_response| dest.clone()) + .map_err(|e| (dest.clone(), e)); + } -#[tracing::instrument(skip(dest, events), name = "")] -async fn send_events_dest_normal( - dest: &Destination, server: &OwnedServerName, events: Vec<SendingEvent>, -) -> SendingResult { - let mut pdu_jsons = Vec::with_capacity( - events - .iter() - .filter(|event| matches!(event, SendingEvent::Pdu(_))) - .count(), - ); - let mut edu_jsons = Vec::with_capacity( - events - .iter() - .filter(|event| matches!(event, SendingEvent::Edu(_))) - .count(), - ); - - for event in &events { - match event { - SendingEvent::Pdu(pdu_id) => pdu_jsons.push(convert_to_outgoing_federation_event( + Ok(dest.clone()) + } + + #[tracing::instrument(skip(self, dest, events), name = "", level = "debug")] + async fn send_events_dest_normal( + &self, dest: &Destination, server: &OwnedServerName, events: Vec<SendingEvent>, + ) -> SendingResult { + let mut pdu_jsons = Vec::with_capacity( + events + .iter() + .filter(|event| matches!(event, SendingEvent::Pdu(_))) + .count(), + ); + let mut edu_jsons = Vec::with_capacity( + events + .iter() + .filter(|event| matches!(event, SendingEvent::Edu(_))) + .count(), + ); + + for event in &events { + match event { // TODO: check room version and remove event_id if needed - services() - .rooms - 
.timeline - .get_pdu_json_from_id(pdu_id) - .map_err(|e| (dest.clone(), e))? - .ok_or_else(|| { - error!(?dest, ?server, ?pdu_id, "event not found"); - ( - dest.clone(), - Error::bad_database("[Normal] Event in servernameevent_data not found in db."), - ) - })?, - )), - SendingEvent::Edu(edu) => { - if let Ok(raw) = serde_json::from_slice(edu) { - edu_jsons.push(raw); - } - }, - SendingEvent::Flush => { - // flush only; no new content - }, + SendingEvent::Pdu(pdu_id) => pdu_jsons.push( + self.convert_to_outgoing_federation_event( + self.services + .timeline + .get_pdu_json_from_id(pdu_id) + .map_err(|e| (dest.clone(), e))? + .ok_or_else(|| { + error!(?dest, ?server, ?pdu_id, "event not found"); + ( + dest.clone(), + Error::bad_database("[Normal] Event in servernameevent_data not found in db."), + ) + })?, + ), + ), + SendingEvent::Edu(edu) => { + if let Ok(raw) = serde_json::from_slice(edu) { + edu_jsons.push(raw); + } + }, + SendingEvent::Flush => {}, // flush only; no new content + } } - } - //debug_assert!(pdu_jsons.len() + edu_jsons.len() > 0, "sending empty - // transaction"); - send::send( - &services().client.sender, - server, - send_transaction_message::v1::Request { - origin: services().server.config.server_name.clone(), + //debug_assert!(pdu_jsons.len() + edu_jsons.len() > 0, "sending empty + // transaction"); + let transaction_id = &*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash( + &events + .iter() + .map(|e| match e { + SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b, + SendingEvent::Flush => &[], + }) + .collect::<Vec<_>>(), + )); + + let request = send_transaction_message::v1::Request { + origin: self.server.config.server_name.clone(), pdus: pdu_jsons, edus: edu_jsons, origin_server_ts: MilliSecondsSinceUnixEpoch::now(), - transaction_id: (&*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash( - &events - .iter() - .map(|e| match e { - SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b, - SendingEvent::Flush => &[], - }) - 
.collect::<Vec<_>>(), - ))) - .into(), - }, - ) - .await - .map(|response| { - for pdu in response.pdus { - if pdu.1.is_err() { - warn!("error for {} from remote: {:?}", pdu.0, pdu.1); - } - } - dest.clone() - }) - .map_err(|e| (dest.clone(), e)) -} + transaction_id: transaction_id.into(), + }; -/// This does not return a full `Pdu` it is only to satisfy ruma's types. -#[tracing::instrument] -pub fn convert_to_outgoing_federation_event(mut pdu_json: CanonicalJsonObject) -> Box<RawJsonValue> { - if let Some(unsigned) = pdu_json - .get_mut("unsigned") - .and_then(|val| val.as_object_mut()) - { - unsigned.remove("transaction_id"); + let client = &self.services.client.sender; + self.send(client, server, request) + .await + .inspect(|response| { + response + .pdus + .iter() + .filter(|(_, res)| res.is_err()) + .for_each(|(pdu_id, res)| warn!("error for {pdu_id} from remote: {res:?}")); + }) + .map(|_| dest.clone()) + .map_err(|e| (dest.clone(), e)) } - // room v3 and above removed the "event_id" field from remote PDU format - if let Some(room_id) = pdu_json - .get("room_id") - .and_then(|val| RoomId::parse(val.as_str()?).ok()) - { - match services().rooms.state.get_room_version(&room_id) { - Ok(room_version_id) => match room_version_id { - RoomVersionId::V1 | RoomVersionId::V2 => {}, - _ => _ = pdu_json.remove("event_id"), - }, - Err(_) => _ = pdu_json.remove("event_id"), + /// This does not return a full `Pdu` it is only to satisfy ruma's types. 
+ pub fn convert_to_outgoing_federation_event(&self, mut pdu_json: CanonicalJsonObject) -> Box<RawJsonValue> { + if let Some(unsigned) = pdu_json + .get_mut("unsigned") + .and_then(|val| val.as_object_mut()) + { + unsigned.remove("transaction_id"); } - } else { - pdu_json.remove("event_id"); - } - // TODO: another option would be to convert it to a canonical string to validate - // size and return a Result<Raw<...>> - // serde_json::from_str::<Raw<_>>( - // ruma::serde::to_canonical_json_string(pdu_json).expect("CanonicalJson is - // valid serde_json::Value"), ) - // .expect("Raw::from_value always works") + // room v3 and above removed the "event_id" field from remote PDU format + if let Some(room_id) = pdu_json + .get("room_id") + .and_then(|val| RoomId::parse(val.as_str()?).ok()) + { + match self.services.state.get_room_version(&room_id) { + Ok(room_version_id) => match room_version_id { + RoomVersionId::V1 | RoomVersionId::V2 => {}, + _ => _ = pdu_json.remove("event_id"), + }, + Err(_) => _ = pdu_json.remove("event_id"), + } + } else { + pdu_json.remove("event_id"); + } - to_raw_value(&pdu_json).expect("CanonicalJson is valid serde_json::Value") + // TODO: another option would be to convert it to a canonical string to validate + // size and return a Result<Raw<...>> + // serde_json::from_str::<Raw<_>>( + // ruma::serde::to_canonical_json_string(pdu_json).expect("CanonicalJson is + // valid serde_json::Value"), ) + // .expect("Raw::from_value always works") + + to_raw_value(&pdu_json).expect("CanonicalJson is valid serde_json::Value") + } } diff --git a/src/service/service.rs b/src/service/service.rs index ce4f15b2a..bf3b891bd 100644 --- a/src/service/service.rs +++ b/src/service/service.rs @@ -3,7 +3,7 @@ collections::BTreeMap, fmt::Write, ops::Deref, - sync::{Arc, OnceLock}, + sync::{Arc, OnceLock, RwLock}, }; use async_trait::async_trait; @@ -50,20 +50,20 @@ pub(crate) struct Args<'a> { } /// Dep is a reference to a service used within another service. 
-/// Circular-dependencies between services require this indirection to allow the -/// referenced service construction after the referencing service. +/// Circular-dependencies between services require this indirection. pub(crate) struct Dep<T> { dep: OnceLock<Arc<T>>, service: Arc<Map>, name: &'static str, } -pub(crate) type Map = BTreeMap<String, MapVal>; +pub(crate) type Map = RwLock<BTreeMap<String, MapVal>>; pub(crate) type MapVal = (Arc<dyn Service>, Arc<dyn Any + Send + Sync>); -impl<T: Any + Send + Sync> Deref for Dep<T> { +impl<T: Send + Sync + 'static> Deref for Dep<T> { type Target = Arc<T>; + /// Dereference a dependency. The dependency must be ready or panics. fn deref(&self) -> &Self::Target { self.dep .get_or_init(|| require::<T>(&self.service, self.name)) @@ -71,39 +71,61 @@ fn deref(&self) -> &Self::Target { } impl Args<'_> { - pub(crate) fn depend_service<T: Any + Send + Sync>(&self, name: &'static str) -> Dep<T> { + /// Create a lazy-reference to a service when constructing another Service. + pub(crate) fn depend<T: Send + Sync + 'static>(&self, name: &'static str) -> Dep<T> { Dep::<T> { dep: OnceLock::new(), service: self.service.clone(), name, } } + + /// Create a reference immediately to a service when constructing another + /// Service. The other service must be constructed. + pub(crate) fn require<T: Send + Sync + 'static>(&self, name: &str) -> Arc<T> { require::<T>(self.service, name) } } -pub(crate) fn require<T: Any + Send + Sync>(map: &Map, name: &str) -> Arc<T> { +/// Reference a Service by name. Panics if the Service does not exist or was +/// incorrectly cast. 
+pub(crate) fn require<T: Send + Sync + 'static>(map: &Map, name: &str) -> Arc<T> { try_get::<T>(map, name) .inspect_err(inspect_log) .expect("Failure to reference service required by another service.") } -pub(crate) fn try_get<T: Any + Send + Sync>(map: &Map, name: &str) -> Result<Arc<T>> { - map.get(name).map_or_else( - || Err!("Service {name:?} does not exist or has not been built yet."), - |(_, s)| { - s.clone() - .downcast::<T>() - .map_err(|_| err!("Service {name:?} must be correctly downcast.")) - }, - ) +/// Reference a Service by name. Returns Err if the Service does not exist or +/// was incorrectly cast. +pub(crate) fn try_get<T: Send + Sync + 'static>(map: &Map, name: &str) -> Result<Arc<T>> { + map.read() + .expect("locked for reading") + .get(name) + .map_or_else( + || Err!("Service {name:?} does not exist or has not been built yet."), + |(_, s)| { + s.clone() + .downcast::<T>() + .map_err(|_| err!("Service {name:?} must be correctly downcast.")) + }, + ) } -pub(crate) fn get<T: Any + Send + Sync>(map: &Map, name: &str) -> Option<Arc<T>> { - map.get(name).map(|(_, s)| { - s.clone() - .downcast::<T>() - .expect("Service must be correctly downcast.") - }) +/// Reference a Service by name. Returns None if the Service does not exist, but +/// panics if incorrectly cast. +/// +/// # Panics +/// Incorrect type is not a silent failure (None) as the type never has a reason +/// to be incorrect. +pub(crate) fn get<T: Send + Sync + 'static>(map: &Map, name: &str) -> Option<Arc<T>> { + map.read() + .expect("locked for reading") + .get(name) + .map(|(_, s)| { + s.clone() + .downcast::<T>() + .expect("Service must be correctly downcast.") + }) } +/// Utility for service implementations; see Service::name() in the trait. 
#[inline] pub(crate) fn make_name(module_path: &str) -> &str { split_once_infallible(module_path, "::").1 } diff --git a/src/service/services.rs b/src/service/services.rs index 68205323f..59909f8cb 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -1,11 +1,16 @@ -use std::{any::Any, collections::BTreeMap, fmt::Write, sync::Arc}; +use std::{ + any::Any, + collections::BTreeMap, + fmt::Write, + sync::{Arc, RwLock}, +}; use conduit::{debug, debug_info, info, trace, Result, Server}; use database::Database; use tokio::sync::Mutex; use crate::{ - account_data, admin, appservice, client, globals, key_backups, + account_data, admin, appservice, client, emergency, globals, key_backups, manager::Manager, media, presence, pusher, resolver, rooms, sending, service, service::{Args, Map, Service}, @@ -13,22 +18,23 @@ }; pub struct Services { - pub resolver: Arc<resolver::Service>, - pub client: Arc<client::Service>, - pub globals: Arc<globals::Service>, - pub rooms: rooms::Service, - pub appservice: Arc<appservice::Service>, - pub pusher: Arc<pusher::Service>, - pub transaction_ids: Arc<transaction_ids::Service>, - pub uiaa: Arc<uiaa::Service>, - pub users: Arc<users::Service>, pub account_data: Arc<account_data::Service>, - pub presence: Arc<presence::Service>, pub admin: Arc<admin::Service>, + pub appservice: Arc<appservice::Service>, + pub client: Arc<client::Service>, + pub emergency: Arc<emergency::Service>, + pub globals: Arc<globals::Service>, pub key_backups: Arc<key_backups::Service>, pub media: Arc<media::Service>, + pub presence: Arc<presence::Service>, + pub pusher: Arc<pusher::Service>, + pub resolver: Arc<resolver::Service>, + pub rooms: rooms::Service, pub sending: Arc<sending::Service>, + pub transaction_ids: Arc<transaction_ids::Service>, + pub uiaa: Arc<uiaa::Service>, pub updates: Arc<updates::Service>, + pub users: Arc<users::Service>, manager: Mutex<Option<Arc<Manager>>>, pub(crate) service: Arc<Map>, @@ -36,37 +42,34 @@ pub struct 
Services { pub db: Arc<Database>, } -macro_rules! build_service { - ($map:ident, $server:ident, $db:ident, $tyname:ty) => {{ - let built = <$tyname>::build(Args { - server: &$server, - db: &$db, - service: &$map, - })?; - - Arc::get_mut(&mut $map) - .expect("must have mutable reference to services collection") - .insert(built.name().to_owned(), (built.clone(), built.clone())); - - trace!("built service #{}: {:?}", $map.len(), built.name()); - built - }}; -} - impl Services { #[allow(clippy::cognitive_complexity)] pub fn build(server: Arc<Server>, db: Arc<Database>) -> Result<Self> { - let mut service: Arc<Map> = Arc::new(BTreeMap::new()); + let service: Arc<Map> = Arc::new(RwLock::new(BTreeMap::new())); macro_rules! build { - ($srv:ty) => { - build_service!(service, server, db, $srv) - }; + ($tyname:ty) => {{ + let built = <$tyname>::build(Args { + db: &db, + server: &server, + service: &service, + })?; + add_service(&service, built.clone(), built.clone()); + built + }}; } Ok(Self { - globals: build!(globals::Service), + account_data: build!(account_data::Service), + admin: build!(admin::Service), + appservice: build!(appservice::Service), resolver: build!(resolver::Service), client: build!(client::Service), + emergency: build!(emergency::Service), + globals: build!(globals::Service), + key_backups: build!(key_backups::Service), + media: build!(media::Service), + presence: build!(presence::Service), + pusher: build!(pusher::Service), rooms: rooms::Service { alias: build!(rooms::alias::Service), auth_chain: build!(rooms::auth_chain::Service), @@ -79,28 +82,22 @@ macro_rules! 
build { read_receipt: build!(rooms::read_receipt::Service), search: build!(rooms::search::Service), short: build!(rooms::short::Service), + spaces: build!(rooms::spaces::Service), state: build!(rooms::state::Service), state_accessor: build!(rooms::state_accessor::Service), state_cache: build!(rooms::state_cache::Service), state_compressor: build!(rooms::state_compressor::Service), - timeline: build!(rooms::timeline::Service), threads: build!(rooms::threads::Service), + timeline: build!(rooms::timeline::Service), typing: build!(rooms::typing::Service), - spaces: build!(rooms::spaces::Service), user: build!(rooms::user::Service), }, - appservice: build!(appservice::Service), - pusher: build!(pusher::Service), + sending: build!(sending::Service), transaction_ids: build!(transaction_ids::Service), uiaa: build!(uiaa::Service), - users: build!(users::Service), - account_data: build!(account_data::Service), - presence: build!(presence::Service), - admin: build!(admin::Service), - key_backups: build!(key_backups::Service), - media: build!(media::Service), - sending: build!(sending::Service), updates: build!(updates::Service), + users: build!(users::Service), + manager: Mutex::new(None), service, server, @@ -111,7 +108,7 @@ macro_rules! build { pub(super) async fn start(&self) -> Result<()> { debug_info!("Starting services..."); - globals::migrations::migrations(&self.db, &self.server.config).await?; + globals::migrations::migrations(self).await?; self.manager .lock() .await @@ -144,7 +141,7 @@ pub async fn poll(&self) -> Result<()> { } pub async fn clear_cache(&self) { - for (service, ..) in self.service.values() { + for (service, ..) in self.service.read().expect("locked for reading").values() { service.clear_cache(); } @@ -159,7 +156,7 @@ pub async fn clear_cache(&self) { pub async fn memory_usage(&self) -> Result<String> { let mut out = String::new(); - for (service, ..) in self.service.values() { + for (service, ..) 
in self.service.read().expect("locked for reading").values() { service.memory_usage(&mut out)?; } @@ -179,23 +176,26 @@ pub async fn memory_usage(&self) -> Result<String> { fn interrupt(&self) { debug!("Interrupting services..."); - for (name, (service, ..)) in self.service.iter() { + for (name, (service, ..)) in self.service.read().expect("locked for reading").iter() { trace!("Interrupting {name}"); service.interrupt(); } } - pub fn try_get<T>(&self, name: &str) -> Result<Arc<T>> - where - T: Any + Send + Sync, - { + pub fn try_get<T: Send + Sync + 'static>(&self, name: &str) -> Result<Arc<T>> { service::try_get::<T>(&self.service, name) } - pub fn get<T>(&self, name: &str) -> Option<Arc<T>> - where - T: Any + Send + Sync, - { - service::get::<T>(&self.service, name) - } + pub fn get<T: Send + Sync + 'static>(&self, name: &str) -> Option<Arc<T>> { service::get::<T>(&self.service, name) } +} + +fn add_service(map: &Arc<Map>, s: Arc<dyn Service>, a: Arc<dyn Any + Send + Sync>) { + let name = s.name(); + let len = map.read().expect("locked for reading").len(); + + trace!("built service #{len}: {name:?}"); + + map.write() + .expect("locked for writing") + .insert(name.to_owned(), (s, a)); } diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 4b953ffb5..6041bbd34 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -2,7 +2,7 @@ use std::sync::Arc; -use conduit::{utils, utils::hash, Error, Result}; +use conduit::{error, utils, utils::hash, Error, Result, Server}; use data::Data; use ruma::{ api::client::{ @@ -11,19 +11,30 @@ }, CanonicalJsonValue, DeviceId, UserId, }; -use tracing::error; -use crate::services; +use crate::{globals, users, Dep}; pub const SESSION_ID_LENGTH: usize = 32; pub struct Service { + server: Arc<Server>, + services: Services, pub db: Data, } +struct Services { + globals: Dep<globals::Service>, + users: Dep<users::Service>, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> 
{ Ok(Arc::new(Self { + server: args.server.clone(), + services: Services { + globals: args.depend::<globals::Service>("globals"), + users: args.depend::<users::Service>("users"), + }, db: Data::new(args.db), })) } @@ -87,11 +98,11 @@ pub fn try_auth( return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized.")); }; - let user_id = UserId::parse_with_server_name(username.clone(), services().globals.server_name()) + let user_id = UserId::parse_with_server_name(username.clone(), self.services.globals.server_name()) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid."))?; // Check if password is correct - if let Some(hash) = services().users.password_hash(&user_id)? { + if let Some(hash) = self.services.users.password_hash(&user_id)? { let hash_matches = hash::verify_password(password, &hash).is_ok(); if !hash_matches { uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { @@ -106,7 +117,7 @@ pub fn try_auth( uiaainfo.completed.push(AuthType::Password); }, AuthData::RegistrationToken(t) => { - if Some(t.token.trim()) == services().globals.config.registration_token.as_deref() { + if Some(t.token.trim()) == self.server.config.registration_token.as_deref() { uiaainfo.completed.push(AuthType::RegistrationToken); } else { uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { diff --git a/src/service/updates/mod.rs b/src/service/updates/mod.rs index db69d9b0f..f89471cdf 100644 --- a/src/service/updates/mod.rs +++ b/src/service/updates/mod.rs @@ -7,14 +7,20 @@ use serde::Deserialize; use tokio::{sync::Notify, time::interval}; -use crate::services; +use crate::{admin, client, Dep}; pub struct Service { + services: Services, db: Arc<Map>, interrupt: Notify, interval: Duration, } +struct Services { + admin: Dep<admin::Service>, + client: Dep<client::Service>, +} + #[derive(Deserialize)] struct CheckForUpdatesResponse { updates: Vec<CheckForUpdatesResponseEntry>, @@ -35,6 +41,10 @@ struct 
CheckForUpdatesResponseEntry { impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { + services: Services { + admin: args.depend::<admin::Service>("admin"), + client: args.depend::<client::Service>("client"), + }, db: args.db["global"].clone(), interrupt: Notify::new(), interval: Duration::from_secs(CHECK_FOR_UPDATES_INTERVAL), @@ -63,7 +73,8 @@ fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } impl Service { #[tracing::instrument(skip_all)] async fn handle_updates(&self) -> Result<()> { - let response = services() + let response = self + .services .client .default .get(CHECK_FOR_UPDATES_URL) @@ -78,7 +89,7 @@ async fn handle_updates(&self) -> Result<()> { last_update_id = last_update_id.max(update.id); if update.id > self.last_check_for_updates_id()? { info!("{:#}", update.message); - services() + self.services .admin .send_message(RoomMessageEventContent::text_markdown(format!( "### the following is a message from the conduwuit puppy\n\nit was sent on `{}`:\n\n@room: {}", diff --git a/src/service/users/data.rs b/src/service/users/data.rs index 5546adb15..2dcde7ce6 100644 --- a/src/service/users/data.rs +++ b/src/service/users/data.rs @@ -1,7 +1,7 @@ use std::{collections::BTreeMap, mem::size_of, sync::Arc}; -use conduit::{debug_info, err, utils, warn, Err, Error, Result}; -use database::{Database, Map}; +use conduit::{debug_info, err, utils, warn, Err, Error, Result, Server}; +use database::Map; use ruma::{ api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, @@ -11,52 +11,65 @@ OwnedMxcUri, OwnedUserId, UInt, UserId, }; -use crate::{services, users::clean_signatures}; +use crate::{globals, rooms, users::clean_signatures, Dep}; pub struct Data { - userid_password: Arc<Map>, + keychangeid_userid: Arc<Map>, + keyid_key: Arc<Map>, + onetimekeyid_onetimekeys: Arc<Map>, + openidtoken_expiresatuserid: Arc<Map>, + 
todeviceid_events: Arc<Map>, token_userdeviceid: Arc<Map>, - userid_displayname: Arc<Map>, + userdeviceid_metadata: Arc<Map>, + userdeviceid_token: Arc<Map>, + userfilterid_filter: Arc<Map>, userid_avatarurl: Arc<Map>, userid_blurhash: Arc<Map>, userid_devicelistversion: Arc<Map>, - userdeviceid_token: Arc<Map>, - userdeviceid_metadata: Arc<Map>, - onetimekeyid_onetimekeys: Arc<Map>, + userid_displayname: Arc<Map>, userid_lastonetimekeyupdate: Arc<Map>, - keyid_key: Arc<Map>, userid_masterkeyid: Arc<Map>, + userid_password: Arc<Map>, userid_selfsigningkeyid: Arc<Map>, userid_usersigningkeyid: Arc<Map>, - openidtoken_expiresatuserid: Arc<Map>, - keychangeid_userid: Arc<Map>, - todeviceid_events: Arc<Map>, - userfilterid_filter: Arc<Map>, - _db: Arc<Database>, + services: Services, +} + +struct Services { + server: Arc<Server>, + globals: Dep<globals::Service>, + state_cache: Dep<rooms::state_cache::Service>, + state_accessor: Dep<rooms::state_accessor::Service>, } impl Data { - pub(super) fn new(db: Arc<Database>) -> Self { + pub(super) fn new(args: &crate::Args<'_>) -> Self { + let db = &args.db; Self { - userid_password: db["userid_password"].clone(), + keychangeid_userid: db["keychangeid_userid"].clone(), + keyid_key: db["keyid_key"].clone(), + onetimekeyid_onetimekeys: db["onetimekeyid_onetimekeys"].clone(), + openidtoken_expiresatuserid: db["openidtoken_expiresatuserid"].clone(), + todeviceid_events: db["todeviceid_events"].clone(), token_userdeviceid: db["token_userdeviceid"].clone(), - userid_displayname: db["userid_displayname"].clone(), + userdeviceid_metadata: db["userdeviceid_metadata"].clone(), + userdeviceid_token: db["userdeviceid_token"].clone(), + userfilterid_filter: db["userfilterid_filter"].clone(), userid_avatarurl: db["userid_avatarurl"].clone(), userid_blurhash: db["userid_blurhash"].clone(), userid_devicelistversion: db["userid_devicelistversion"].clone(), - userdeviceid_token: db["userdeviceid_token"].clone(), - userdeviceid_metadata: 
db["userdeviceid_metadata"].clone(), - onetimekeyid_onetimekeys: db["onetimekeyid_onetimekeys"].clone(), + userid_displayname: db["userid_displayname"].clone(), userid_lastonetimekeyupdate: db["userid_lastonetimekeyupdate"].clone(), - keyid_key: db["keyid_key"].clone(), userid_masterkeyid: db["userid_masterkeyid"].clone(), + userid_password: db["userid_password"].clone(), userid_selfsigningkeyid: db["userid_selfsigningkeyid"].clone(), userid_usersigningkeyid: db["userid_usersigningkeyid"].clone(), - openidtoken_expiresatuserid: db["openidtoken_expiresatuserid"].clone(), - keychangeid_userid: db["keychangeid_userid"].clone(), - todeviceid_events: db["todeviceid_events"].clone(), - userfilterid_filter: db["userfilterid_filter"].clone(), - _db: db, + services: Services { + server: args.server.clone(), + globals: args.depend::<globals::Service>("globals"), + state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"), + state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"), + }, } } @@ -377,7 +390,7 @@ pub(super) fn add_one_time_key( )?; self.userid_lastonetimekeyupdate - .insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?; + .insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes())?; Ok(()) } @@ -403,7 +416,7 @@ pub(super) fn take_one_time_key( prefix.push(b':'); self.userid_lastonetimekeyupdate - .insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?; + .insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes())?; self.onetimekeyid_onetimekeys .scan_prefix(prefix) @@ -631,16 +644,16 @@ pub(super) fn keys_changed<'a>( } pub(super) fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { - let count = services().globals.next_count()?.to_be_bytes(); - for room_id in services() - .rooms + let count = self.services.globals.next_count()?.to_be_bytes(); + for room_id in self + .services .state_cache .rooms_joined(user_id) 
.filter_map(Result::ok) { // Don't send key updates to unencrypted rooms - if services() - .rooms + if self + .services .state_accessor .room_state_get(&room_id, &StateEventType::RoomEncryption, "")? .is_none() @@ -750,7 +763,7 @@ pub(super) fn add_to_device_event( key.push(0xFF); key.extend_from_slice(target_device_id.as_bytes()); key.push(0xFF); - key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + key.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes()); let mut json = serde_json::Map::new(); json.insert("type".to_owned(), event_type.to_owned().into()); @@ -916,7 +929,7 @@ pub(super) fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<Opt pub(super) fn create_openid_token(&self, user_id: &UserId, token: &str) -> Result<u64> { use std::num::Saturating as Sat; - let expires_in = services().globals.config.openid_token_ttl; + let expires_in = self.services.server.config.openid_token_ttl; let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in) * Sat(1000); let mut value = expires_at.0.to_be_bytes().to_vec(); diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index e0a4dd1c4..4c80d0d37 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -7,7 +7,6 @@ }; use conduit::{Error, Result}; -use data::Data; use ruma::{ api::client::{ device::Device, @@ -24,7 +23,8 @@ UInt, UserId, }; -use crate::services; +use self::data::Data; +use crate::{admin, rooms, Dep}; pub struct SlidingSyncCache { lists: BTreeMap<String, SyncRequestList>, @@ -36,14 +36,24 @@ pub struct SlidingSyncCache { type DbConnections = Mutex<BTreeMap<(OwnedUserId, OwnedDeviceId, String), Arc<Mutex<SlidingSyncCache>>>>; pub struct Service { + services: Services, pub db: Data, pub connections: DbConnections, } +struct Services { + admin: Dep<admin::Service>, + state_cache: Dep<rooms::state_cache::Service>, +} + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { 
Ok(Arc::new(Self { - db: Data::new(args.db.clone()), + services: Services { + admin: args.depend::<admin::Service>("admin"), + state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"), + }, + db: Data::new(&args), connections: StdMutex::new(BTreeMap::new()), })) } @@ -247,11 +257,8 @@ pub fn is_deactivated(&self, user_id: &UserId) -> Result<bool> { self.db.is_deac /// Check if a user is an admin pub fn is_admin(&self, user_id: &UserId) -> Result<bool> { - if let Some(admin_room_id) = services().admin.get_admin_room()? { - services() - .rooms - .state_cache - .is_joined(user_id, &admin_room_id) + if let Some(admin_room_id) = self.services.admin.get_admin_room()? { + self.services.state_cache.is_joined(user_id, &admin_room_id) } else { Ok(false) } -- GitLab From 59efabbbc2feee08589a2118683d29dddb942186 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Mon, 22 Jul 2024 07:43:51 +0000 Subject: [PATCH 16/47] de-global server_is_ours / user_is_local Signed-off-by: Jason Volk <jason@zemos.net> --- src/admin/federation/commands.rs | 2 +- src/admin/mod.rs | 2 +- src/admin/room/room_commands.rs | 2 +- src/admin/room/room_directory_commands.rs | 2 +- src/admin/room/room_moderation_commands.rs | 10 +++---- src/admin/user/commands.rs | 32 ++++++++++++--------- src/admin/utils.rs | 30 +++++++++----------- src/api/client/account.rs | 13 +++------ src/api/client/alias.rs | 10 +++---- src/api/client/directory.rs | 8 ++---- src/api/client/keys.rs | 8 ++---- src/api/client/media.rs | 33 +++++++++++----------- src/api/client/membership.rs | 23 ++++++++------- src/api/client/profile.rs | 14 ++++----- src/api/client/state.rs | 10 +++---- src/api/client/to_device.rs | 5 ++-- src/api/mod.rs | 4 +-- src/api/server/invite.rs | 3 +- src/api/server/query.rs | 5 ++-- src/api/server/send_join.rs | 7 ++--- src/api/server/send_leave.rs | 7 +++-- src/api/server/user.rs | 18 ++++++++---- src/core/server.rs | 3 ++ src/service/admin/mod.rs | 4 +-- 
src/service/globals/mod.rs | 16 +++++------ src/service/mod.rs | 5 +--- src/service/presence/mod.rs | 6 ++-- src/service/rooms/alias/mod.rs | 13 +++++++-- src/service/rooms/state_cache/data.rs | 4 +-- src/service/rooms/state_cache/mod.rs | 6 ++-- src/service/rooms/timeline/mod.rs | 12 ++++---- src/service/rooms/typing/mod.rs | 14 +++++---- src/service/sending/mod.rs | 8 +++--- src/service/sending/sender.rs | 7 ++--- 34 files changed, 178 insertions(+), 168 deletions(-) diff --git a/src/admin/federation/commands.rs b/src/admin/federation/commands.rs index 331231aed..d6ecd3f7c 100644 --- a/src/admin/federation/commands.rs +++ b/src/admin/federation/commands.rs @@ -90,7 +90,7 @@ pub(super) async fn remote_user_in_rooms(_body: Vec<&str>, user_id: Box<UserId>) .state_cache .rooms_joined(&user_id) .filter_map(Result::ok) - .map(|room_id| get_room_info(&room_id)) + .map(|room_id| get_room_info(services(), &room_id)) .collect(); if rooms.is_empty() { diff --git a/src/admin/mod.rs b/src/admin/mod.rs index c57659c17..b183f3f64 100644 --- a/src/admin/mod.rs +++ b/src/admin/mod.rs @@ -19,7 +19,7 @@ extern crate conduit_service as service; pub(crate) use conduit::{mod_ctor, mod_dtor, Result}; -pub(crate) use service::{services, user_is_local}; +pub(crate) use service::services; pub(crate) use crate::utils::{escape_html, get_room_info}; diff --git a/src/admin/room/room_commands.rs b/src/admin/room/room_commands.rs index d47edce20..1a387c7e1 100644 --- a/src/admin/room/room_commands.rs +++ b/src/admin/room/room_commands.rs @@ -39,7 +39,7 @@ pub(super) async fn list( true }) - .map(|room_id| get_room_info(&room_id)) + .map(|room_id| get_room_info(services(), &room_id)) }) .collect::<Vec<_>>(); rooms.sort_by_key(|r| r.1); diff --git a/src/admin/room/room_directory_commands.rs b/src/admin/room/room_directory_commands.rs index 5e81bd89c..c9b4eb9e0 100644 --- a/src/admin/room/room_directory_commands.rs +++ b/src/admin/room/room_directory_commands.rs @@ -29,7 +29,7 @@ pub(super) async 
fn process(command: RoomDirectoryCommand, _body: Vec<&str>) -> .directory .public_rooms() .filter_map(Result::ok) - .map(|id: OwnedRoomId| get_room_info(&id)) + .map(|id: OwnedRoomId| get_room_info(services(), &id)) .collect::<Vec<_>>(); rooms.sort_by_key(|r| r.1); rooms.reverse(); diff --git a/src/admin/room/room_moderation_commands.rs b/src/admin/room/room_moderation_commands.rs index 46354c0f7..8ad8295b0 100644 --- a/src/admin/room/room_moderation_commands.rs +++ b/src/admin/room/room_moderation_commands.rs @@ -1,9 +1,9 @@ use api::client::leave_room; +use conduit::{debug, error, info, warn, Result}; use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomAliasId, RoomId, RoomOrAliasId}; -use tracing::{debug, error, info, warn}; use super::RoomModerationCommand; -use crate::{get_room_info, services, user_is_local, Result}; +use crate::{get_room_info, services}; pub(super) async fn process(command: RoomModerationCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> { match command { @@ -110,11 +110,11 @@ async fn ban_room( .room_members(&room_id) .filter_map(|user| { user.ok().filter(|local_user| { - user_is_local(local_user) + services().globals.user_is_local(local_user) // additional wrapped check here is to avoid adding remote users // who are in the admin room to the list of local users (would // fail auth check) - && (user_is_local(local_user) + && (services().globals.user_is_local(local_user) // since this is a force operation, assume user is an admin // if somehow this fails && services() @@ -484,7 +484,7 @@ async fn list_banned_rooms(_body: Vec<&str>) -> Result<RoomMessageEventContent> let mut rooms = room_ids .into_iter() - .map(|room_id| get_room_info(&room_id)) + .map(|room_id| get_room_info(services(), &room_id)) .collect::<Vec<_>>(); rooms.sort_by_key(|r| r.1); rooms.reverse(); diff --git a/src/admin/user/commands.rs b/src/admin/user/commands.rs index fa1c52881..69019d79e 100644 --- a/src/admin/user/commands.rs +++ 
b/src/admin/user/commands.rs @@ -36,7 +36,7 @@ pub(super) async fn create( _body: Vec<&str>, username: String, password: Option<String>, ) -> Result<RoomMessageEventContent> { // Validate user id - let user_id = parse_local_user_id(&username)?; + let user_id = parse_local_user_id(services(), &username)?; if services().users.exists(&user_id)? { return Ok(RoomMessageEventContent::text_plain(format!("Userid {user_id} already exists"))); @@ -134,7 +134,7 @@ pub(super) async fn deactivate( _body: Vec<&str>, no_leave_rooms: bool, user_id: String, ) -> Result<RoomMessageEventContent> { // Validate user id - let user_id = parse_local_user_id(&user_id)?; + let user_id = parse_local_user_id(services(), &user_id)?; // don't deactivate the server service account if user_id == services().globals.server_user { @@ -170,7 +170,7 @@ pub(super) async fn deactivate( } pub(super) async fn reset_password(_body: Vec<&str>, username: String) -> Result<RoomMessageEventContent> { - let user_id = parse_local_user_id(&username)?; + let user_id = parse_local_user_id(services(), &username)?; if user_id == services().globals.server_user { return Ok(RoomMessageEventContent::text_plain( @@ -211,7 +211,7 @@ pub(super) async fn deactivate_all( let mut admins = Vec::new(); for username in usernames { - match parse_active_local_user_id(username) { + match parse_active_local_user_id(services(), username) { Ok(user_id) => { if services().users.is_admin(&user_id)? 
&& !force { services() @@ -292,14 +292,14 @@ pub(super) async fn deactivate_all( pub(super) async fn list_joined_rooms(_body: Vec<&str>, user_id: String) -> Result<RoomMessageEventContent> { // Validate user id - let user_id = parse_local_user_id(&user_id)?; + let user_id = parse_local_user_id(services(), &user_id)?; let mut rooms: Vec<(OwnedRoomId, u64, String)> = services() .rooms .state_cache .rooms_joined(&user_id) .filter_map(Result::ok) - .map(|room_id| get_room_info(&room_id)) + .map(|room_id| get_room_info(services(), &room_id)) .collect(); if rooms.is_empty() { @@ -344,10 +344,13 @@ pub(super) async fn list_joined_rooms(_body: Vec<&str>, user_id: String) -> Resu pub(super) async fn force_join_room( _body: Vec<&str>, user_id: String, room_id: OwnedRoomOrAliasId, ) -> Result<RoomMessageEventContent> { - let user_id = parse_local_user_id(&user_id)?; + let user_id = parse_local_user_id(services(), &user_id)?; let room_id = services().rooms.alias.resolve(&room_id).await?; - assert!(service::user_is_local(&user_id), "Parsed user_id must be a local user"); + assert!( + services().globals.user_is_local(&user_id), + "Parsed user_id must be a local user" + ); join_room_by_id_helper(services(), &user_id, &room_id, None, &[], None).await?; Ok(RoomMessageEventContent::notice_markdown(format!( @@ -356,13 +359,16 @@ pub(super) async fn force_join_room( } pub(super) async fn make_user_admin(_body: Vec<&str>, user_id: String) -> Result<RoomMessageEventContent> { - let user_id = parse_local_user_id(&user_id)?; + let user_id = parse_local_user_id(services(), &user_id)?; let displayname = services() .users .displayname(&user_id)? 
.unwrap_or_else(|| user_id.to_string()); - assert!(service::user_is_local(&user_id), "Parsed user_id must be a local user"); + assert!( + services().globals.user_is_local(&user_id), + "Parsed user_id must be a local user" + ); services() .admin .make_user_admin(&user_id, displayname) @@ -376,7 +382,7 @@ pub(super) async fn make_user_admin(_body: Vec<&str>, user_id: String) -> Result pub(super) async fn put_room_tag( _body: Vec<&str>, user_id: String, room_id: Box<RoomId>, tag: String, ) -> Result<RoomMessageEventContent> { - let user_id = parse_active_local_user_id(&user_id)?; + let user_id = parse_active_local_user_id(services(), &user_id)?; let event = services() .account_data @@ -411,7 +417,7 @@ pub(super) async fn put_room_tag( pub(super) async fn delete_room_tag( _body: Vec<&str>, user_id: String, room_id: Box<RoomId>, tag: String, ) -> Result<RoomMessageEventContent> { - let user_id = parse_active_local_user_id(&user_id)?; + let user_id = parse_active_local_user_id(services(), &user_id)?; let event = services() .account_data @@ -443,7 +449,7 @@ pub(super) async fn delete_room_tag( pub(super) async fn get_room_tags( _body: Vec<&str>, user_id: String, room_id: Box<RoomId>, ) -> Result<RoomMessageEventContent> { - let user_id = parse_active_local_user_id(&user_id)?; + let user_id = parse_active_local_user_id(services(), &user_id)?; let event = services() .account_data diff --git a/src/admin/utils.rs b/src/admin/utils.rs index fda42e6e2..8d3d15ae4 100644 --- a/src/admin/utils.rs +++ b/src/admin/utils.rs @@ -1,8 +1,6 @@ -use conduit_core::{err, Err}; +use conduit_core::{err, Err, Result}; use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; -use service::user_is_local; - -use crate::{services, Result}; +use service::Services; pub(crate) fn escape_html(s: &str) -> String { s.replace('&', "&") @@ -10,17 +8,17 @@ pub(crate) fn escape_html(s: &str) -> String { .replace('>', ">") } -pub(crate) fn get_room_info(id: &RoomId) -> (OwnedRoomId, u64, String) { +pub(crate) 
fn get_room_info(services: &Services, id: &RoomId) -> (OwnedRoomId, u64, String) { ( id.into(), - services() + services .rooms .state_cache .room_joined_count(id) .ok() .flatten() .unwrap_or(0), - services() + services .rooms .state_accessor .get_name(id) @@ -31,16 +29,16 @@ pub(crate) fn get_room_info(id: &RoomId) -> (OwnedRoomId, u64, String) { } /// Parses user ID -pub(crate) fn parse_user_id(user_id: &str) -> Result<OwnedUserId> { - UserId::parse_with_server_name(user_id.to_lowercase(), services().globals.server_name()) +pub(crate) fn parse_user_id(services: &Services, user_id: &str) -> Result<OwnedUserId> { + UserId::parse_with_server_name(user_id.to_lowercase(), services.globals.server_name()) .map_err(|e| err!("The supplied username is not a valid username: {e}")) } /// Parses user ID as our local user -pub(crate) fn parse_local_user_id(user_id: &str) -> Result<OwnedUserId> { - let user_id = parse_user_id(user_id)?; +pub(crate) fn parse_local_user_id(services: &Services, user_id: &str) -> Result<OwnedUserId> { + let user_id = parse_user_id(services, user_id)?; - if !user_is_local(&user_id) { + if !services.globals.user_is_local(&user_id) { return Err!("User {user_id:?} does not belong to our server."); } @@ -48,14 +46,14 @@ pub(crate) fn parse_local_user_id(user_id: &str) -> Result<OwnedUserId> { } /// Parses user ID that is an active (not guest or deactivated) local user -pub(crate) fn parse_active_local_user_id(user_id: &str) -> Result<OwnedUserId> { - let user_id = parse_local_user_id(user_id)?; +pub(crate) fn parse_active_local_user_id(services: &Services, user_id: &str) -> Result<OwnedUserId> { + let user_id = parse_local_user_id(services, user_id)?; - if !services().users.exists(&user_id)? { + if !services.users.exists(&user_id)? { return Err!("User {user_id:?} does not exist on this server."); } - if services().users.is_deactivated(&user_id)? { + if services.users.is_deactivated(&user_id)? 
{ return Err!("User {user_id:?} is deactivated."); } diff --git a/src/api/client/account.rs b/src/api/client/account.rs index b3495b429..7c2bb0b6a 100644 --- a/src/api/client/account.rs +++ b/src/api/client/account.rs @@ -2,7 +2,7 @@ use axum::extract::State; use axum_client_ip::InsecureClientIp; -use conduit::debug_info; +use conduit::{debug_info, error, info, utils, warn, Error, Result}; use register::RegistrationKind; use ruma::{ api::client::{ @@ -18,14 +18,9 @@ events::{room::message::RoomMessageEventContent, GlobalAccountDataEventType}, push, OwnedRoomId, UserId, }; -use tracing::{error, info, warn}; use super::{join_room_by_id_helper, DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; -use crate::{ - service::user_is_local, - utils::{self}, - Error, Result, Ruma, -}; +use crate::Ruma; const RANDOM_USER_ID_LENGTH: usize = 10; @@ -48,7 +43,7 @@ pub(crate) async fn get_register_available_route( // Validate user id let user_id = UserId::parse_with_server_name(body.username.to_lowercase(), services.globals.server_name()) .ok() - .filter(|user_id| !user_id.is_historical() && user_is_local(user_id)) + .filter(|user_id| !user_id.is_historical() && services.globals.user_is_local(user_id)) .ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; // Check if username is creative enough @@ -136,7 +131,7 @@ pub(crate) async fn register_route( let proposed_user_id = UserId::parse_with_server_name(username.to_lowercase(), services.globals.server_name()) .ok() - .filter(|user_id| !user_id.is_historical() && user_is_local(user_id)) + .filter(|user_id| !user_id.is_historical() && services.globals.user_is_local(user_id)) .ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; if services.users.exists(&proposed_user_id)? 
{ diff --git a/src/api/client/alias.rs b/src/api/client/alias.rs index 11617a0e8..dbc75e641 100644 --- a/src/api/client/alias.rs +++ b/src/api/client/alias.rs @@ -1,4 +1,5 @@ use axum::extract::State; +use conduit::{debug, Error, Result}; use rand::seq::SliceRandom; use ruma::{ api::client::{ @@ -7,12 +8,9 @@ }, OwnedServerName, RoomAliasId, RoomId, }; -use tracing::debug; +use service::Services; -use crate::{ - service::{server_is_ours, Services}, - Error, Result, Ruma, -}; +use crate::Ruma; /// # `PUT /_matrix/client/v3/directory/room/{roomAlias}` /// @@ -142,7 +140,7 @@ fn room_available_servers( // prefer the room alias server first if let Some(server_index) = servers .iter() - .position(|server_name| server_is_ours(server_name)) + .position(|server_name| services.globals.server_is_ours(server_name)) { servers.swap_remove(server_index); servers.insert(0, services.globals.server_name().to_owned()); diff --git a/src/api/client/directory.rs b/src/api/client/directory.rs index deebd250e..cb30b60a5 100644 --- a/src/api/client/directory.rs +++ b/src/api/client/directory.rs @@ -20,11 +20,9 @@ }, uint, RoomId, ServerName, UInt, UserId, }; +use service::Services; -use crate::{ - service::{server_is_ours, Services}, - Ruma, -}; +use crate::Ruma; /// # `POST /_matrix/client/v3/publicRooms` /// @@ -187,7 +185,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( services: &Services, server: Option<&ServerName>, limit: Option<UInt>, since: Option<&str>, filter: &Filter, _network: &RoomNetwork, ) -> Result<get_public_rooms_filtered::v3::Response> { - if let Some(other_server) = server.filter(|server_name| !server_is_ours(server_name)) { + if let Some(other_server) = server.filter(|server_name| !services.globals.server_is_ours(server_name)) { let response = services .sending .send_federation_request( diff --git a/src/api/client/keys.rs b/src/api/client/keys.rs index 728ea7a93..8489dde35 100644 --- a/src/api/client/keys.rs +++ b/src/api/client/keys.rs @@ -4,7 +4,7 @@ }; 
use axum::extract::State; -use conduit::{utils, utils::math::continue_exponential_backoff_secs, Error, Result}; +use conduit::{debug, utils, utils::math::continue_exponential_backoff_secs, Error, Result}; use futures_util::{stream::FuturesUnordered, StreamExt}; use ruma::{ api::{ @@ -19,8 +19,6 @@ DeviceKeyAlgorithm, OwnedDeviceId, OwnedUserId, UserId, }; use serde_json::json; -use service::user_is_local; -use tracing::debug; use super::SESSION_ID_LENGTH; use crate::{service::Services, Ruma}; @@ -266,7 +264,7 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>( for (user_id, device_ids) in device_keys_input { let user_id: &UserId = user_id; - if !user_is_local(user_id) { + if !services.globals.user_is_local(user_id) { get_over_federation .entry(user_id.server_name()) .or_insert_with(Vec::new) @@ -459,7 +457,7 @@ pub(crate) async fn claim_keys_helper( let mut get_over_federation = BTreeMap::new(); for (user_id, map) in one_time_keys_input { - if !user_is_local(user_id) { + if !services.globals.user_is_local(user_id) { get_over_federation .entry(user_id.server_name()) .or_insert_with(Vec::new) diff --git a/src/api/client/media.rs b/src/api/client/media.rs index e219e4575..78463fc6b 100644 --- a/src/api/client/media.rs +++ b/src/api/client/media.rs @@ -4,7 +4,15 @@ use axum::extract::State; use axum_client_ip::InsecureClientIp; -use conduit::{debug, error, utils::math::ruma_from_usize, warn}; +use conduit::{ + debug, debug_warn, error, + utils::{ + self, + content_disposition::{content_disposition_type, make_content_disposition, sanitise_filename}, + math::ruma_from_usize, + }, + warn, Error, Result, +}; use image::io::Reader as ImgReader; use ipaddress::IPAddress; use reqwest::Url; @@ -15,20 +23,13 @@ get_media_preview, }, }; +use service::{ + media::{FileMeta, UrlPreviewData}, + Services, +}; use webpage::HTML; -use crate::{ - debug_warn, - service::{ - media::{FileMeta, UrlPreviewData}, - server_is_ours, Services, - }, - utils::{ - self, - 
content_disposition::{content_disposition_type, make_content_disposition, sanitise_filename}, - }, - Error, Result, Ruma, RumaResponse, -}; +use crate::{Ruma, RumaResponse}; /// generated MXC ID (`media-id`) length const MXC_LENGTH: usize = 32; @@ -218,7 +219,7 @@ pub(crate) async fn get_content_route( cross_origin_resource_policy: Some(CORP_CROSS_ORIGIN.to_owned()), cache_control: Some(CACHE_CONTROL_IMMUTABLE.into()), }) - } else if !server_is_ours(&body.server_name) && body.allow_remote { + } else if !services.globals.server_is_ours(&body.server_name) && body.allow_remote { let response = get_remote_content( services, &mxc, @@ -308,7 +309,7 @@ pub(crate) async fn get_content_as_filename_route( cross_origin_resource_policy: Some(CORP_CROSS_ORIGIN.to_owned()), cache_control: Some(CACHE_CONTROL_IMMUTABLE.into()), }) - } else if !server_is_ours(&body.server_name) && body.allow_remote { + } else if !services.globals.server_is_ours(&body.server_name) && body.allow_remote { match get_remote_content( services, &mxc, @@ -408,7 +409,7 @@ pub(crate) async fn get_content_thumbnail_route( cache_control: Some(CACHE_CONTROL_IMMUTABLE.into()), content_disposition, }) - } else if !server_is_ours(&body.server_name) && body.allow_remote { + } else if !services.globals.server_is_ours(&body.server_name) && body.allow_remote { if services .globals .prevent_media_downloads_from() diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index e3630c726..d3b2d8f66 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -8,8 +8,11 @@ use axum::extract::State; use axum_client_ip::InsecureClientIp; use conduit::{ - debug, debug_warn, error, info, trace, utils, utils::math::continue_exponential_backoff_secs, warn, Error, - PduEvent, Result, + debug, debug_warn, error, info, + pdu::{gen_event_id_canonical_json, PduBuilder}, + trace, utils, + utils::math::continue_exponential_backoff_secs, + warn, Error, PduEvent, Result, }; use ruma::{ api::{ @@ -36,15 
+39,11 @@ OwnedUserId, RoomId, RoomVersionId, ServerName, UserId, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; +use service::{rooms::state::RoomMutexGuard, Services}; use tokio::sync::RwLock; use crate::{ client::{update_avatar_url, update_displayname}, - service::{ - pdu::{gen_event_id_canonical_json, PduBuilder}, - rooms::state::RoomMutexGuard, - server_is_ours, user_is_local, Services, - }, Ruma, }; @@ -675,7 +674,7 @@ pub async fn join_room_by_id_helper( .state_cache .server_in_room(services.globals.server_name(), room_id)? || servers.is_empty() - || (servers.len() == 1 && server_is_ours(&servers[0])) + || (servers.len() == 1 && services.globals.server_is_ours(&servers[0])) { join_room_by_id_helper_local(services, sender_user, room_id, reason, servers, third_party_signed, state_lock) .await @@ -1049,7 +1048,7 @@ async fn join_room_by_id_helper_local( .state_cache .room_members(room_id) .filter_map(Result::ok) - .filter(|user| user_is_local(user)) + .filter(|user| services.globals.user_is_local(user)) .collect::<Vec<OwnedUserId>>(); let mut join_authorized_via_users_server: Option<OwnedUserId> = None; @@ -1110,7 +1109,7 @@ async fn join_room_by_id_helper_local( if !restriction_rooms.is_empty() && servers .iter() - .any(|server_name| !server_is_ours(server_name)) + .any(|server_name| !services.globals.server_is_ours(server_name)) { warn!("We couldn't do the join locally, maybe federation can help to satisfy the restricted join requirements"); let (make_join_response, remote_server) = make_join_request(services, sender_user, room_id, servers).await?; @@ -1259,7 +1258,7 @@ async fn make_join_request( let mut incompatible_room_version_count: u8 = 0; for remote_server in servers { - if server_is_ours(remote_server) { + if services.globals.server_is_ours(remote_server) { continue; } info!("Asking {remote_server} for make_join ({make_join_counter})"); @@ -1389,7 +1388,7 @@ pub(crate) async fn invite_helper( )); } - if !user_is_local(user_id) { + 
if !services.globals.user_is_local(user_id) { let (pdu, pdu_json, invite_room_state) = { let state_lock = services.rooms.state.mutex.lock(room_id).await; let content = to_raw_value(&RoomMemberEventContent { diff --git a/src/api/client/profile.rs b/src/api/client/profile.rs index 3b2c32ecf..9e9bcf8e0 100644 --- a/src/api/client/profile.rs +++ b/src/api/client/profile.rs @@ -1,4 +1,5 @@ use axum::extract::State; +use conduit::{pdu::PduBuilder, warn, Error, Result}; use ruma::{ api::{ client::{ @@ -12,12 +13,9 @@ OwnedMxcUri, OwnedRoomId, OwnedUserId, }; use serde_json::value::to_raw_value; -use tracing::warn; +use service::Services; -use crate::{ - service::{pdu::PduBuilder, user_is_local, Services}, - Error, Result, Ruma, -}; +use crate::Ruma; /// # `PUT /_matrix/client/r0/profile/{userId}/displayname` /// @@ -56,7 +54,7 @@ pub(crate) async fn set_displayname_route( pub(crate) async fn get_displayname_route( State(services): State<crate::State>, body: Ruma<get_display_name::v3::Request>, ) -> Result<get_display_name::v3::Response> { - if !user_is_local(&body.user_id) { + if !services.globals.user_is_local(&body.user_id) { // Create and update our local copy of the user if let Ok(response) = services .sending @@ -147,7 +145,7 @@ pub(crate) async fn set_avatar_url_route( pub(crate) async fn get_avatar_url_route( State(services): State<crate::State>, body: Ruma<get_avatar_url::v3::Request>, ) -> Result<get_avatar_url::v3::Response> { - if !user_is_local(&body.user_id) { + if !services.globals.user_is_local(&body.user_id) { // Create and update our local copy of the user if let Ok(response) = services .sending @@ -205,7 +203,7 @@ pub(crate) async fn get_avatar_url_route( pub(crate) async fn get_profile_route( State(services): State<crate::State>, body: Ruma<get_profile::v3::Request>, ) -> Result<get_profile::v3::Response> { - if !user_is_local(&body.user_id) { + if !services.globals.user_is_local(&body.user_id) { // Create and update our local copy of the user if let 
Ok(response) = services .sending diff --git a/src/api/client/state.rs b/src/api/client/state.rs index 56ffd2ac5..7af4f5f97 100644 --- a/src/api/client/state.rs +++ b/src/api/client/state.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use axum::extract::State; -use conduit::{debug_info, error}; +use conduit::{debug_info, error, pdu::PduBuilder, Error, Result}; use ruma::{ api::client::{ error::ErrorKind, @@ -18,11 +18,9 @@ serde::Raw, EventId, RoomId, UserId, }; +use service::Services; -use crate::{ - service::{pdu::PduBuilder, server_is_ours, Services}, - Error, Result, Ruma, RumaResponse, -}; +use crate::{Ruma, RumaResponse}; /// # `PUT /_matrix/client/*/rooms/{roomId}/state/{eventType}/{stateKey}` /// @@ -250,7 +248,7 @@ async fn allowed_to_send_state_event( } for alias in aliases { - if !server_is_ours(alias.server_name()) + if !services.globals.server_is_ours(alias.server_name()) || services .rooms .alias diff --git a/src/api/client/to_device.rs b/src/api/client/to_device.rs index 8476ff41d..1f557ad7b 100644 --- a/src/api/client/to_device.rs +++ b/src/api/client/to_device.rs @@ -1,6 +1,7 @@ use std::collections::BTreeMap; use axum::extract::State; +use conduit::{Error, Result}; use ruma::{ api::{ client::{error::ErrorKind, to_device::send_event_to_device}, @@ -9,7 +10,7 @@ to_device::DeviceIdOrAllDevices, }; -use crate::{user_is_local, Error, Result, Ruma}; +use crate::Ruma; /// # `PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId}` /// @@ -31,7 +32,7 @@ pub(crate) async fn send_event_to_device_route( for (target_user_id, map) in &body.messages { for (target_device_id_maybe, event) in map { - if !user_is_local(target_user_id) { + if !services.globals.user_is_local(target_user_id) { let mut map = BTreeMap::new(); map.insert(target_device_id_maybe.clone(), event.clone()); let mut messages = BTreeMap::new(); diff --git a/src/api/mod.rs b/src/api/mod.rs index 0d80e5814..c7411b6c1 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -7,8 +7,8 @@ extern crate 
conduit_core as conduit; extern crate conduit_service as service; -pub(crate) use conduit::{debug_info, debug_warn, pdu::PduEvent, utils, Error, Result}; -pub(crate) use service::{services, user_is_local}; +pub(crate) use conduit::{debug_info, pdu::PduEvent, utils, Error, Result}; +pub(crate) use service::services; pub use crate::router::State; pub(crate) use crate::router::{Ruma, RumaResponse}; diff --git a/src/api/server/invite.rs b/src/api/server/invite.rs index 17e219205..688e026c5 100644 --- a/src/api/server/invite.rs +++ b/src/api/server/invite.rs @@ -7,7 +7,6 @@ serde::JsonObject, CanonicalJsonValue, EventId, OwnedUserId, }; -use service::server_is_ours; use crate::Ruma; @@ -88,7 +87,7 @@ pub(crate) async fn create_invite_route( ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "state_key is not a user ID."))?; - if !server_is_ours(invited_user.server_name()) { + if !services.globals.server_is_ours(invited_user.server_name()) { return Err(Error::BadRequest( ErrorKind::InvalidParam, "User does not belong to this homeserver.", diff --git a/src/api/server/query.rs b/src/api/server/query.rs index dddf23e71..5712f46a0 100644 --- a/src/api/server/query.rs +++ b/src/api/server/query.rs @@ -1,4 +1,5 @@ use axum::extract::State; +use conduit::{Error, Result}; use get_profile_information::v1::ProfileField; use rand::seq::SliceRandom; use ruma::{ @@ -9,7 +10,7 @@ OwnedServerName, }; -use crate::{service::server_is_ours, Error, Result, Ruma}; +use crate::Ruma; /// # `GET /_matrix/federation/v1/query/directory` /// @@ -64,7 +65,7 @@ pub(crate) async fn get_profile_information_route( )); } - if !server_is_ours(body.user_id.server_name()) { + if !services.globals.server_is_ours(body.user_id.server_name()) { return Err(Error::BadRequest( ErrorKind::InvalidParam, "User does not belong to this server.", diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs index 7f79a1d99..4cd29795f 100644 --- a/src/api/server/send_join.rs +++ 
b/src/api/server/send_join.rs @@ -3,7 +3,7 @@ use std::collections::BTreeMap; use axum::extract::State; -use conduit::{Error, Result}; +use conduit::{pdu::gen_event_id_canonical_json, warn, Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation::membership::create_join_event}, events::{ @@ -13,9 +13,8 @@ CanonicalJsonValue, OwnedServerName, OwnedUserId, RoomId, ServerName, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; -use service::{pdu::gen_event_id_canonical_json, user_is_local, Services}; +use service::Services; use tokio::sync::RwLock; -use tracing::warn; use crate::Ruma; @@ -126,7 +125,7 @@ async fn create_join_event( if content .join_authorized_via_users_server - .is_some_and(|user| user_is_local(&user)) + .is_some_and(|user| services.globals.user_is_local(&user)) && super::user_can_perform_restricted_join(services, &sender, room_id, &room_version_id).unwrap_or_default() { ruma::signatures::hash_and_sign_event( diff --git a/src/api/server/send_leave.rs b/src/api/server/send_leave.rs index f7484a330..b1b8fec81 100644 --- a/src/api/server/send_leave.rs +++ b/src/api/server/send_leave.rs @@ -3,6 +3,7 @@ use std::collections::BTreeMap; use axum::extract::State; +use conduit::{Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation::membership::create_leave_event}, events::{ @@ -15,8 +16,8 @@ use tokio::sync::RwLock; use crate::{ - service::{pdu::gen_event_id_canonical_json, server_is_ours, Services}, - Error, Result, Ruma, + service::{pdu::gen_event_id_canonical_json, Services}, + Ruma, }; /// # `PUT /_matrix/federation/v1/send_leave/{roomId}/{eventId}` @@ -174,7 +175,7 @@ async fn create_leave_event( .state_cache .room_servers(room_id) .filter_map(Result::ok) - .filter(|server| !server_is_ours(server)); + .filter(|server| !services.globals.server_is_ours(server)); services.sending.send_pdu_servers(servers, &pdu_id)?; diff --git a/src/api/server/user.rs b/src/api/server/user.rs index 949e5f380..bd0372e66 100644 
--- a/src/api/server/user.rs +++ b/src/api/server/user.rs @@ -1,4 +1,5 @@ use axum::extract::State; +use conduit::{Error, Result}; use ruma::api::{ client::error::ErrorKind, federation::{ @@ -9,8 +10,7 @@ use crate::{ client::{claim_keys_helper, get_keys_helper}, - service::user_is_local, - Error, Result, Ruma, + Ruma, }; /// # `GET /_matrix/federation/v1/user/devices/{userId}` @@ -19,7 +19,7 @@ pub(crate) async fn get_devices_route( State(services): State<crate::State>, body: Ruma<get_devices::v1::Request>, ) -> Result<get_devices::v1::Response> { - if !user_is_local(&body.user_id) { + if !services.globals.user_is_local(&body.user_id) { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Tried to access user from other server.", @@ -72,7 +72,11 @@ pub(crate) async fn get_devices_route( pub(crate) async fn get_keys_route( State(services): State<crate::State>, body: Ruma<get_keys::v1::Request>, ) -> Result<get_keys::v1::Response> { - if body.device_keys.iter().any(|(u, _)| !user_is_local(u)) { + if body + .device_keys + .iter() + .any(|(u, _)| !services.globals.user_is_local(u)) + { return Err(Error::BadRequest( ErrorKind::InvalidParam, "User does not belong to this server.", @@ -101,7 +105,11 @@ pub(crate) async fn get_keys_route( pub(crate) async fn claim_keys_route( State(services): State<crate::State>, body: Ruma<claim_keys::v1::Request>, ) -> Result<claim_keys::v1::Response> { - if body.one_time_keys.iter().any(|(u, _)| !user_is_local(u)) { + if body + .one_time_keys + .iter() + .any(|(u, _)| !services.globals.user_is_local(u)) + { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Tried to access user from other server.", diff --git a/src/core/server.rs b/src/core/server.rs index bf0ab99d1..89f1dea58 100644 --- a/src/core/server.rs +++ b/src/core/server.rs @@ -109,4 +109,7 @@ pub fn runtime(&self) -> &runtime::Handle { #[inline] pub fn running(&self) -> bool { !self.stopping.load(Ordering::Acquire) } + + #[inline] + pub fn is_ours(&self, name: &str) 
-> bool { name == self.config.server_name } } diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index 8b9473a20..1d51bf381 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -22,7 +22,7 @@ use serde_json::value::to_raw_value; use tokio::sync::{Mutex, RwLock}; -use crate::{globals, rooms, rooms::state::RoomMutexGuard, user_is_local, Dep}; +use crate::{globals, rooms, rooms::state::RoomMutexGuard, Dep}; pub struct Service { services: Services, @@ -301,7 +301,7 @@ pub async fn is_admin_command(&self, pdu: &PduEvent, body: &str) -> bool { } // only allow public escaped commands by local admins - if is_public_escape && !user_is_local(&pdu.sender) { + if is_public_escape && !self.services.globals.user_is_local(&pdu.sender) { return false; } diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index eab156eee..c2c7d9697 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -21,7 +21,7 @@ use tokio::sync::Mutex; use url::Url; -use crate::{service, services}; +use crate::service; pub struct Service { pub db: Data, @@ -302,13 +302,11 @@ pub fn valid_cidr_range(&self, ip: &IPAddress) -> bool { true } -} -#[inline] -#[must_use] -pub fn server_is_ours(server_name: &ServerName) -> bool { server_name == services().globals.config.server_name } + /// checks if `user_id` is local to us via server_name comparison + #[inline] + pub fn user_is_local(&self, user_id: &UserId) -> bool { self.server_is_ours(user_id.server_name()) } -/// checks if `user_id` is local to us via server_name comparison -#[inline] -#[must_use] -pub fn user_is_local(user_id: &UserId) -> bool { server_is_ours(user_id.server_name()) } + #[inline] + pub fn server_is_ours(&self, server_name: &ServerName) -> bool { server_name == self.config.server_name } +} diff --git a/src/service/mod.rs b/src/service/mod.rs index 6e749c99d..ce106809c 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -33,10 +33,7 @@ use database::Database; pub(crate) 
use service::{Args, Dep, Service}; -pub use crate::{ - globals::{server_is_ours, user_is_local}, - services::Services, -}; +pub use crate::services::Services; conduit::mod_ctor! {} conduit::mod_dtor! {} diff --git a/src/service/presence/mod.rs b/src/service/presence/mod.rs index 705ac4ffd..b86733045 100644 --- a/src/service/presence/mod.rs +++ b/src/service/presence/mod.rs @@ -14,7 +14,7 @@ use tokio::{sync::Mutex, time::sleep}; use self::data::Data; -use crate::{user_is_local, users, Dep}; +use crate::{globals, users, Dep}; /// Represents data required to be kept in order to implement the presence /// specification. @@ -80,6 +80,7 @@ pub struct Service { struct Services { server: Arc<Server>, + globals: Dep<globals::Service>, users: Dep<users::Service>, } @@ -93,6 +94,7 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { services: Services { server: args.server.clone(), + globals: args.depend::<globals::Service>("globals"), users: args.depend::<users::Service>("users"), }, db: Data::new(&args), @@ -185,7 +187,7 @@ pub fn set_presence( self.db .set_presence(user_id, presence_state, currently_active, last_active_ago, status_msg)?; - if self.timeout_remote_users || user_is_local(user_id) { + if self.timeout_remote_users || self.services.globals.user_is_local(user_id) { let timeout = match presence_state { PresenceState::Online => self.services.server.config.presence_idle_timeout_s, _ => self.services.server.config.presence_offline_timeout_s, diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index 344ab6d2b..f2e01ab54 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -14,7 +14,7 @@ }; use self::data::Data; -use crate::{admin, appservice, appservice::RegistrationInfo, globals, rooms, sending, server_is_ours, Dep}; +use crate::{admin, appservice, appservice::RegistrationInfo, globals, rooms, sending, Dep}; pub struct Service { db: Data, @@ -85,7 +85,10 @@ pub async fn resolve(&self, 
room: &RoomOrAliasId) -> Result<OwnedRoomId> { pub async fn resolve_alias( &self, room_alias: &RoomAliasId, servers: Option<&Vec<OwnedServerName>>, ) -> Result<(OwnedRoomId, Option<Vec<OwnedServerName>>)> { - if !server_is_ours(room_alias.server_name()) + if !self + .services + .globals + .server_is_ours(room_alias.server_name()) && (!servers .as_ref() .is_some_and(|servers| servers.contains(&self.services.globals.server_name().to_owned())) @@ -195,7 +198,11 @@ async fn resolve_appservice_alias(&self, room_alias: &RoomAliasId) -> Result<Opt pub async fn appservice_checks( &self, room_alias: &RoomAliasId, appservice_info: &Option<RegistrationInfo>, ) -> Result<()> { - if !server_is_ours(room_alias.server_name()) { + if !self + .services + .globals + .server_is_ours(room_alias.server_name()) + { return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server.")); } diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index cbda73cf5..19c73ea1c 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -12,7 +12,7 @@ OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; -use crate::{appservice::RegistrationInfo, globals, user_is_local, users, Dep}; +use crate::{appservice::RegistrationInfo, globals, users, Dep}; type StrippedStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a>; type AnySyncStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + 'a>; @@ -355,7 +355,7 @@ pub(super) fn local_users_in_room<'a>(&'a self, room_id: &RoomId) -> Box<dyn Ite Box::new( self.room_members(room_id) .filter_map(Result::ok) - .filter(|user| user_is_local(user)), + .filter(|user| self.services.globals.user_is_local(user)), ) } diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index ac2f688e7..71899ceb9 100644 --- 
a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -21,7 +21,7 @@ OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; -use crate::{account_data, appservice::RegistrationInfo, rooms, user_is_local, users, Dep}; +use crate::{account_data, appservice::RegistrationInfo, globals, rooms, users, Dep}; pub struct Service { services: Services, @@ -30,6 +30,7 @@ pub struct Service { struct Services { account_data: Dep<account_data::Service>, + globals: Dep<globals::Service>, state_accessor: Dep<rooms::state_accessor::Service>, users: Dep<users::Service>, } @@ -39,6 +40,7 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { Ok(Arc::new(Self { services: Services { account_data: args.depend::<account_data::Service>("account_data"), + globals: args.depend::<globals::Service>("globals"), state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"), users: args.depend::<users::Service>("users"), }, @@ -65,7 +67,7 @@ pub fn update_membership( // TODO: use futures to update remote profiles without blocking the membership // update #[allow(clippy::collapsible_if)] - if !user_is_local(user_id) { + if !self.services.globals.user_is_local(user_id) { if !self.services.users.exists(user_id)? 
{ self.services.users.create(user_id, None)?; } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 50d294753..82b1cd0cf 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -41,7 +41,7 @@ use self::data::Data; use crate::{ account_data, admin, appservice, appservice::NamespaceRegex, globals, pusher, rooms, - rooms::state_compressor::CompressedStateEvent, sending, server_is_ours, Dep, + rooms::state_compressor::CompressedStateEvent, sending, Dep, }; // Update Relationships @@ -846,7 +846,7 @@ pub async fn build_and_append_pdu( .state_cache .room_members(room_id) .filter_map(Result::ok) - .filter(|m| server_is_ours(m.server_name()) && m != target) + .filter(|m| self.services.globals.server_is_ours(m.server_name()) && m != target) .count(); if count < 2 { warn!("Last admin cannot leave from admins room"); @@ -871,7 +871,7 @@ pub async fn build_and_append_pdu( .state_cache .room_members(room_id) .filter_map(Result::ok) - .filter(|m| server_is_ours(m.server_name()) && m != target) + .filter(|m| self.services.globals.server_is_ours(m.server_name()) && m != target) .count(); if count < 2 { warn!("Last admin cannot be banned in admins room"); @@ -1092,7 +1092,7 @@ pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Re .unwrap_or_default(); let room_mods = power_levels.users.iter().filter_map(|(user_id, level)| { - if level > &power_levels.users_default && !server_is_ours(user_id.server_name()) { + if level > &power_levels.users_default && !self.services.globals.user_is_local(user_id) { Some(user_id.server_name().to_owned()) } else { None @@ -1106,7 +1106,7 @@ pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Re .filter_map(|alias| { alias .ok() - .filter(|alias| !server_is_ours(alias.server_name())) + .filter(|alias| !self.services.globals.server_is_ours(alias.server_name())) .map(|alias| alias.server_name().to_owned()) }); @@ -1114,7 +1114,7 @@ 
pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Re .chain(room_alias_servers) .chain(self.services.server.config.trusted_servers.clone()) .filter(|server_name| { - if server_is_ours(server_name) { + if self.services.globals.server_is_ours(server_name) { return false; } diff --git a/src/service/rooms/typing/mod.rs b/src/service/rooms/typing/mod.rs index d863f2175..3cf1cdd59 100644 --- a/src/service/rooms/typing/mod.rs +++ b/src/service/rooms/typing/mod.rs @@ -8,7 +8,7 @@ }; use tokio::sync::{broadcast, RwLock}; -use crate::{globals, sending, user_is_local, Dep}; +use crate::{globals, sending, Dep}; pub struct Service { server: Arc<Server>, @@ -63,7 +63,7 @@ pub async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) } // update federation - if user_is_local(user_id) { + if self.services.globals.user_is_local(user_id) { self.federation_send(room_id, user_id, true)?; } @@ -89,7 +89,7 @@ pub async fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result< } // update federation - if user_is_local(user_id) { + if self.services.globals.user_is_local(user_id) { self.federation_send(room_id, user_id, false)?; } @@ -145,7 +145,7 @@ async fn typings_maintain(&self, room_id: &RoomId) -> Result<()> { // update federation for user in removable { - if user_is_local(&user) { + if self.services.globals.user_is_local(&user) { self.federation_send(room_id, &user, false)?; } } @@ -184,7 +184,11 @@ pub async fn typings_all( } fn federation_send(&self, room_id: &RoomId, user_id: &UserId, typing: bool) -> Result<()> { - debug_assert!(user_is_local(user_id), "tried to broadcast typing status of remote user",); + debug_assert!( + self.services.globals.user_is_local(user_id), + "tried to broadcast typing status of remote user", + ); + if !self.server.config.allow_outgoing_typing { return Ok(()); } diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index 6f091b04f..fc32d04f9 100644 --- 
a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -13,7 +13,7 @@ }; use tokio::sync::Mutex; -use crate::{account_data, client, globals, presence, pusher, resolver, rooms, server_is_ours, users, Dep}; +use crate::{account_data, client, globals, presence, pusher, resolver, rooms, users, Dep}; pub struct Service { server: Arc<Server>, @@ -136,7 +136,7 @@ pub fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<()> { .state_cache .room_servers(room_id) .filter_map(Result::ok) - .filter(|server_name| !server_is_ours(server_name)); + .filter(|server_name| !self.services.globals.server_is_ours(server_name)); self.send_pdu_servers(servers, pdu_id) } @@ -185,7 +185,7 @@ pub fn send_edu_room(&self, room_id: &RoomId, serialized: Vec<u8>) -> Result<()> .state_cache .room_servers(room_id) .filter_map(Result::ok) - .filter(|server_name| !server_is_ours(server_name)); + .filter(|server_name| !self.services.globals.server_is_ours(server_name)); self.send_edu_servers(servers, serialized) } @@ -222,7 +222,7 @@ pub fn flush_room(&self, room_id: &RoomId) -> Result<()> { .state_cache .room_servers(room_id) .filter_map(Result::ok) - .filter(|server_name| !server_is_ours(server_name)); + .filter(|server_name| !self.services.globals.server_is_ours(server_name)); self.flush_servers(servers) } diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index 0668ce242..9fbc30eaf 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -29,7 +29,6 @@ use tokio::time::sleep_until; use super::{appservice, Destination, Msg, SendingEvent, Service}; -use crate::user_is_local; #[derive(Debug)] enum TransactionStatus { @@ -264,7 +263,7 @@ fn select_edus(&self, server_name: &ServerName) -> Result<(Vec<Vec<u8>>, u64)> { .users .keys_changed(room_id.as_ref(), since, None) .filter_map(Result::ok) - .filter(|user_id| user_is_local(user_id)), + .filter(|user_id| self.services.globals.user_is_local(user_id)), ); if 
self.server.config.allow_outgoing_read_receipts @@ -306,7 +305,7 @@ fn select_edus_presence( for (user_id, count, presence_bytes) in self.services.presence.presence_since(since) { *max_edu_count = cmp::max(count, *max_edu_count); - if !user_is_local(&user_id) { + if !self.services.globals.user_is_local(&user_id) { continue; } @@ -358,7 +357,7 @@ fn select_edus_receipts( let (user_id, count, read_receipt) = r?; *max_edu_count = cmp::max(count, *max_edu_count); - if !user_is_local(&user_id) { + if !self.services.globals.user_is_local(&user_id) { continue; } -- GitLab From 2fb43dd38dfa1a86632921ebb65ca10f34093e33 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Mon, 22 Jul 2024 21:16:46 +0000 Subject: [PATCH 17/47] infra to synthesize program options with config options Signed-off-by: Jason Volk <jason@zemos.net> --- src/core/config/mod.rs | 2 +- src/main/clap.rs | 12 +++++++++++- src/main/main.rs | 2 +- src/main/server.rs | 5 +++-- 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index b40ebb65f..a689701aa 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -405,7 +405,7 @@ struct ListeningAddr { impl Config { /// Initialize config - pub fn new(path: Option<PathBuf>) -> Result<Self> { + pub fn new(path: &Option<PathBuf>) -> Result<Self> { let raw_config = if let Some(config_file_env) = Env::var("CONDUIT_CONFIG") { Figment::new() .merge(Toml::file(config_file_env).nested()) diff --git a/src/main/clap.rs b/src/main/clap.rs index 6ce164f53..e82fea16e 100644 --- a/src/main/clap.rs +++ b/src/main/clap.rs @@ -3,6 +3,7 @@ use std::path::PathBuf; use clap::Parser; +use conduit::{Config, Result}; /// Commandline arguments #[derive(Parser, Debug)] @@ -15,4 +16,13 @@ pub(crate) struct Args { /// Parse commandline arguments into structured data #[must_use] -pub(crate) fn parse() -> Args { Args::parse() } +pub(super) fn parse() -> Args { Args::parse() } + +/// Synthesize any command 
line options with configuration file options. +pub(crate) fn update(config: &mut Config, args: &Args) -> Result<()> { + // Indicate the admin console should be spawned automatically if the + // configuration file hasn't already. + config.admin_console_automatic |= args.console.unwrap_or(false); + + Ok(()) +} diff --git a/src/main/main.rs b/src/main/main.rs index 959e86100..b13e117db 100644 --- a/src/main/main.rs +++ b/src/main/main.rs @@ -33,7 +33,7 @@ fn main() -> Result<(), Error> { .build() .expect("built runtime"); - let server: Arc<Server> = Server::build(args, Some(runtime.handle()))?; + let server: Arc<Server> = Server::build(&args, Some(runtime.handle()))?; runtime.spawn(signal::signal(server.clone())); runtime.block_on(async_main(&server))?; diff --git a/src/main/server.rs b/src/main/server.rs index 73c06f0ca..b1cd6936e 100644 --- a/src/main/server.rs +++ b/src/main/server.rs @@ -21,8 +21,9 @@ pub(crate) struct Server { } impl Server { - pub(crate) fn build(args: Args, runtime: Option<&runtime::Handle>) -> Result<Arc<Self>, Error> { - let config = Config::new(args.config)?; + pub(crate) fn build(args: &Args, runtime: Option<&runtime::Handle>) -> Result<Arc<Self>, Error> { + let mut config = Config::new(&args.config)?; + crate::clap::update(&mut config, args)?; #[cfg(feature = "sentry_telemetry")] let sentry_guard = crate::sentry::init(&config); -- GitLab From 263e3380883ec2ae801f7be8fa0be7388b158377 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Mon, 22 Jul 2024 21:21:44 +0000 Subject: [PATCH 18/47] add --console program option to automatically spawn Signed-off-by: Jason Volk <jason@zemos.net> --- src/core/config/mod.rs | 6 ++++++ src/main/clap.rs | 6 +++++- src/service/admin/mod.rs | 22 +++++++++++++++++++--- 3 files changed, 30 insertions(+), 4 deletions(-) diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index a689701aa..6cef402af 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -330,6 +330,8 @@ pub 
struct Config { pub block_non_admin_invites: bool, #[serde(default = "true_fn")] pub admin_escape_commands: bool, + #[serde(default)] + pub admin_console_automatic: bool, #[serde(default)] pub sentry: bool, @@ -579,6 +581,10 @@ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { &self.block_non_admin_invites.to_string(), ); line("Enable admin escape commands", &self.admin_escape_commands.to_string()); + line( + "Activate admin console after startup", + &self.admin_console_automatic.to_string(), + ); line("Allow outgoing federated typing", &self.allow_outgoing_typing.to_string()); line("Allow incoming federated typing", &self.allow_incoming_typing.to_string()); line( diff --git a/src/main/clap.rs b/src/main/clap.rs index e82fea16e..16d88978e 100644 --- a/src/main/clap.rs +++ b/src/main/clap.rs @@ -12,6 +12,10 @@ pub(crate) struct Args { #[arg(short, long)] /// Optional argument to the path of a conduwuit config TOML file pub(crate) config: Option<PathBuf>, + + /// Activate admin command console automatically after startup. + #[arg(long, num_args(0))] + pub(crate) console: bool, } /// Parse commandline arguments into structured data @@ -22,7 +26,7 @@ pub(super) fn parse() -> Args { Args::parse() } pub(crate) fn update(config: &mut Config, args: &Args) -> Result<()> { // Indicate the admin console should be spawned automatically if the // configuration file hasn't already. - config.admin_console_automatic |= args.console.unwrap_or(false); + config.admin_console_automatic |= args.console; Ok(()) } diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index 1d51bf381..b3879deb7 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -82,6 +82,8 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { async fn worker(self: Arc<Self>) -> Result<()> { let receiver = self.receiver.lock().await; let mut signals = self.services.server.signal.subscribe(); + self.console_auto_start().await; + loop { tokio::select! 
{ command = receiver.recv_async() => match command { @@ -95,9 +97,7 @@ async fn worker(self: Arc<Self>) -> Result<()> { } } - //TODO: not unwind safe - #[cfg(feature = "console")] - self.console.close().await; + self.console_auto_stop().await; //TODO: not unwind safe Ok(()) } @@ -340,4 +340,20 @@ pub fn is_admin_room(&self, room_id: &RoomId) -> bool { false } } + + /// Possibly spawn the terminal console at startup if configured. + async fn console_auto_start(&self) { + #[cfg(feature = "console")] + if self.services.server.config.admin_console_automatic { + // Allow more of the startup sequence to execute before spawning + tokio::task::yield_now().await; + self.console.start().await; + } + } + + /// Shutdown the console when the admin worker terminates. + async fn console_auto_stop(&self) { + #[cfg(feature = "console")] + self.console.close().await; + } } -- GitLab From ccfa939bd3995645868ddfc5fb9f39f3fa6eaabb Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Mon, 22 Jul 2024 22:24:17 +0000 Subject: [PATCH 19/47] split admin command enum from handler Signed-off-by: Jason Volk <jason@zemos.net> --- src/admin/admin.rs | 66 ++++++++++++++++ src/admin/handler.rs | 91 ++++------------------- src/admin/mod.rs | 10 ++- src/admin/room/room_commands.rs | 2 +- src/admin/room/room_directory_commands.rs | 2 +- src/admin/tests.rs | 2 +- 6 files changed, 89 insertions(+), 84 deletions(-) create mode 100644 src/admin/admin.rs diff --git a/src/admin/admin.rs b/src/admin/admin.rs new file mode 100644 index 000000000..e55f6d610 --- /dev/null +++ b/src/admin/admin.rs @@ -0,0 +1,66 @@ +use clap::Parser; +use conduit::Result; +use ruma::events::room::message::RoomMessageEventContent; + +use crate::{ + appservice, appservice::AppserviceCommand, check, check::CheckCommand, debug, debug::DebugCommand, federation, + federation::FederationCommand, media, media::MediaCommand, query, query::QueryCommand, room, room::RoomCommand, + server, server::ServerCommand, user, 
user::UserCommand, +}; + +#[derive(Parser)] +#[command(name = "admin", version = env!("CARGO_PKG_VERSION"))] +pub(crate) enum AdminCommand { + #[command(subcommand)] + /// - Commands for managing appservices + Appservices(AppserviceCommand), + + #[command(subcommand)] + /// - Commands for managing local users + Users(UserCommand), + + #[command(subcommand)] + /// - Commands for managing rooms + Rooms(RoomCommand), + + #[command(subcommand)] + /// - Commands for managing federation + Federation(FederationCommand), + + #[command(subcommand)] + /// - Commands for managing the server + Server(ServerCommand), + + #[command(subcommand)] + /// - Commands for managing media + Media(MediaCommand), + + #[command(subcommand)] + /// - Commands for checking integrity + Check(CheckCommand), + + #[command(subcommand)] + /// - Commands for debugging things + Debug(DebugCommand), + + #[command(subcommand)] + /// - Low-level queries for database getters and iterators + Query(QueryCommand), +} + +#[tracing::instrument(skip_all, name = "command")] +pub(crate) async fn process(command: AdminCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> { + let reply_message_content = match command { + AdminCommand::Appservices(command) => appservice::process(command, body).await?, + AdminCommand::Media(command) => media::process(command, body).await?, + AdminCommand::Users(command) => user::process(command, body).await?, + AdminCommand::Rooms(command) => room::process(command, body).await?, + AdminCommand::Federation(command) => federation::process(command, body).await?, + AdminCommand::Server(command) => server::process(command, body).await?, + AdminCommand::Debug(command) => debug::process(command, body).await?, + AdminCommand::Query(command) => query::process(command, body).await?, + AdminCommand::Check(command) => check::process(command, body).await?, + }; + + Ok(reply_message_content) +} diff --git a/src/admin/handler.rs b/src/admin/handler.rs index 409abc18f..46712077a 100644 --- 
a/src/admin/handler.rs +++ b/src/admin/handler.rs @@ -11,64 +11,16 @@ OwnedEventId, }; -extern crate conduit_service as service; - use conduit::{utils::string::common_prefix, Result}; -pub(crate) use service::admin::Command; -use service::admin::{CommandOutput, CommandResult, HandlerResult}; +use service::admin::{Command, CommandOutput, CommandResult, HandlerResult}; -use crate::{ - appservice, appservice::AppserviceCommand, check, check::CheckCommand, debug, debug::DebugCommand, federation, - federation::FederationCommand, media, media::MediaCommand, query, query::QueryCommand, room, room::RoomCommand, - server, server::ServerCommand, services, user, user::UserCommand, -}; -pub(crate) const PAGE_SIZE: usize = 100; - -#[derive(Parser)] -#[command(name = "admin", version = env!("CARGO_PKG_VERSION"))] -pub(crate) enum AdminCommand { - #[command(subcommand)] - /// - Commands for managing appservices - Appservices(AppserviceCommand), - - #[command(subcommand)] - /// - Commands for managing local users - Users(UserCommand), - - #[command(subcommand)] - /// - Commands for managing rooms - Rooms(RoomCommand), - - #[command(subcommand)] - /// - Commands for managing federation - Federation(FederationCommand), - - #[command(subcommand)] - /// - Commands for managing the server - Server(ServerCommand), - - #[command(subcommand)] - /// - Commands for managing media - Media(MediaCommand), - - #[command(subcommand)] - /// - Commands for checking integrity - Check(CheckCommand), - - #[command(subcommand)] - /// - Commands for debugging things - Debug(DebugCommand), - - #[command(subcommand)] - /// - Low-level queries for database getters and iterators - Query(QueryCommand), -} +use crate::{admin, admin::AdminCommand, services}; #[must_use] -pub(crate) fn handle(command: Command) -> HandlerResult { Box::pin(handle_command(command)) } +pub(super) fn handle(command: Command) -> HandlerResult { Box::pin(handle_command(command)) } #[must_use] -pub(crate) fn complete(line: &str) -> 
String { complete_admin_command(AdminCommand::command(), line) } +pub(super) fn complete(line: &str) -> String { complete_command(AdminCommand::command(), line) } #[tracing::instrument(skip_all, name = "admin")] async fn handle_command(command: Command) -> CommandResult { @@ -80,7 +32,7 @@ async fn handle_command(command: Command) -> CommandResult { } async fn process_command(command: &Command) -> CommandOutput { - process_admin_message(&command.command) + process(&command.command) .await .and_then(|content| reply(content, command.reply_id.clone())) } @@ -104,11 +56,11 @@ fn reply(mut content: RoomMessageEventContent, reply_id: Option<OwnedEventId>) - } // Parse and process a message from the admin room -async fn process_admin_message(msg: &str) -> CommandOutput { +async fn process(msg: &str) -> CommandOutput { let mut lines = msg.lines().filter(|l| !l.trim().is_empty()); let command = lines.next().expect("each string has at least one line"); let body = lines.collect::<Vec<_>>(); - let parsed = match parse_admin_command(command) { + let parsed = match parse_command(command) { Ok(parsed) => parsed, Err(error) => { let server_name = services().globals.server_name(); @@ -118,7 +70,7 @@ async fn process_admin_message(msg: &str) -> CommandOutput { }; let timer = Instant::now(); - let result = process_admin_command(parsed, body).await; + let result = admin::process(parsed, body).await; let elapsed = timer.elapsed(); conduit::debug!(?command, ok = result.is_ok(), "command processed in {elapsed:?}"); match result { @@ -129,31 +81,14 @@ async fn process_admin_message(msg: &str) -> CommandOutput { } } -#[tracing::instrument(skip_all, name = "command")] -async fn process_admin_command(command: AdminCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> { - let reply_message_content = match command { - AdminCommand::Appservices(command) => appservice::process(command, body).await?, - AdminCommand::Media(command) => media::process(command, body).await?, - 
AdminCommand::Users(command) => user::process(command, body).await?, - AdminCommand::Rooms(command) => room::process(command, body).await?, - AdminCommand::Federation(command) => federation::process(command, body).await?, - AdminCommand::Server(command) => server::process(command, body).await?, - AdminCommand::Debug(command) => debug::process(command, body).await?, - AdminCommand::Query(command) => query::process(command, body).await?, - AdminCommand::Check(command) => check::process(command, body).await?, - }; - - Ok(reply_message_content) -} - // Parse chat messages from the admin room into an AdminCommand object -fn parse_admin_command(command_line: &str) -> Result<AdminCommand, String> { - let argv = parse_command_line(command_line); +fn parse_command(command_line: &str) -> Result<AdminCommand, String> { + let argv = parse_line(command_line); AdminCommand::try_parse_from(argv).map_err(|error| error.to_string()) } -fn complete_admin_command(mut cmd: clap::Command, line: &str) -> String { - let argv = parse_command_line(line); +fn complete_command(mut cmd: clap::Command, line: &str) -> String { + let argv = parse_line(line); let mut ret = Vec::<String>::with_capacity(argv.len().saturating_add(1)); 'token: for token in argv.into_iter().skip(1) { @@ -196,7 +131,7 @@ fn complete_admin_command(mut cmd: clap::Command, line: &str) -> String { } // Parse chat messages from the admin room into an AdminCommand object -fn parse_command_line(command_line: &str) -> Vec<String> { +fn parse_line(command_line: &str) -> Vec<String> { let mut argv = command_line .split_whitespace() .map(str::to_owned) diff --git a/src/admin/mod.rs b/src/admin/mod.rs index b183f3f64..ff2aefd5c 100644 --- a/src/admin/mod.rs +++ b/src/admin/mod.rs @@ -1,18 +1,20 @@ #![recursion_limit = "168"] #![allow(clippy::wildcard_imports)] +pub(crate) mod admin; +pub(crate) mod handler; +mod tests; +pub(crate) mod utils; + pub(crate) mod appservice; pub(crate) mod check; pub(crate) mod debug; pub(crate) mod 
federation; -pub(crate) mod handler; pub(crate) mod media; pub(crate) mod query; pub(crate) mod room; pub(crate) mod server; -mod tests; pub(crate) mod user; -pub(crate) mod utils; extern crate conduit_api as api; extern crate conduit_core as conduit; @@ -23,6 +25,8 @@ pub(crate) use crate::utils::{escape_html, get_room_info}; +pub(crate) const PAGE_SIZE: usize = 100; + mod_ctor! {} mod_dtor! {} diff --git a/src/admin/room/room_commands.rs b/src/admin/room/room_commands.rs index 1a387c7e1..cf0f3ddbd 100644 --- a/src/admin/room/room_commands.rs +++ b/src/admin/room/room_commands.rs @@ -2,7 +2,7 @@ use ruma::events::room::message::RoomMessageEventContent; -use crate::{escape_html, get_room_info, handler::PAGE_SIZE, services, Result}; +use crate::{escape_html, get_room_info, services, Result, PAGE_SIZE}; pub(super) async fn list( _body: Vec<&str>, page: Option<usize>, exclude_disabled: bool, exclude_banned: bool, diff --git a/src/admin/room/room_directory_commands.rs b/src/admin/room/room_directory_commands.rs index c9b4eb9e0..912e970c6 100644 --- a/src/admin/room/room_directory_commands.rs +++ b/src/admin/room/room_directory_commands.rs @@ -3,7 +3,7 @@ use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId}; use super::RoomDirectoryCommand; -use crate::{escape_html, get_room_info, handler::PAGE_SIZE, services, Result}; +use crate::{escape_html, get_room_info, services, Result, PAGE_SIZE}; pub(super) async fn process(command: RoomDirectoryCommand, _body: Vec<&str>) -> Result<RoomMessageEventContent> { match command { diff --git a/src/admin/tests.rs b/src/admin/tests.rs index 69ccd896c..296d48887 100644 --- a/src/admin/tests.rs +++ b/src/admin/tests.rs @@ -12,7 +12,7 @@ fn get_help_inner(input: &str) { use clap::Parser; - use crate::handler::AdminCommand; + use crate::admin::AdminCommand; let Err(error) = AdminCommand::try_parse_from(["argv[0] doesn't matter", input]) else { panic!("no error!"); -- GitLab From 5ed95ea3572a9bd23942fdaaf900e18a40159aec 
Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Mon, 22 Jul 2024 22:45:25 +0000 Subject: [PATCH 20/47] contextualize handler in object Signed-off-by: Jason Volk <jason@zemos.net> --- src/admin/handler.rs | 229 +++++++++++++++++++++++-------------------- 1 file changed, 123 insertions(+), 106 deletions(-) diff --git a/src/admin/handler.rs b/src/admin/handler.rs index 46712077a..ff04d3782 100644 --- a/src/admin/handler.rs +++ b/src/admin/handler.rs @@ -12,15 +12,27 @@ }; use conduit::{utils::string::common_prefix, Result}; -use service::admin::{Command, CommandOutput, CommandResult, HandlerResult}; +use service::{ + admin::{Command, CommandOutput, CommandResult, HandlerResult}, + Services, +}; + +use crate::{admin, admin::AdminCommand}; -use crate::{admin, admin::AdminCommand, services}; +struct Handler<'a> { + services: &'a Services, +} #[must_use] -pub(super) fn handle(command: Command) -> HandlerResult { Box::pin(handle_command(command)) } +pub(super) fn complete(line: &str) -> String { + Handler { + services: service::services(), + } + .complete_command(AdminCommand::command(), line) +} #[must_use] -pub(super) fn complete(line: &str) -> String { complete_command(AdminCommand::command(), line) } +pub(super) fn handle(command: Command) -> HandlerResult { Box::pin(handle_command(command)) } #[tracing::instrument(skip_all, name = "admin")] async fn handle_command(command: Command) -> CommandResult { @@ -32,9 +44,12 @@ async fn handle_command(command: Command) -> CommandResult { } async fn process_command(command: &Command) -> CommandOutput { - process(&command.command) - .await - .and_then(|content| reply(content, command.reply_id.clone())) + Handler { + services: service::services(), + } + .process(&command.command) + .await + .and_then(|content| reply(content, command.reply_id.clone())) } fn handle_panic(error: &Error, command: Command) -> CommandResult { @@ -55,122 +70,124 @@ fn reply(mut content: RoomMessageEventContent, reply_id: 
Option<OwnedEventId>) - Some(content) } -// Parse and process a message from the admin room -async fn process(msg: &str) -> CommandOutput { - let mut lines = msg.lines().filter(|l| !l.trim().is_empty()); - let command = lines.next().expect("each string has at least one line"); - let body = lines.collect::<Vec<_>>(); - let parsed = match parse_command(command) { - Ok(parsed) => parsed, - Err(error) => { - let server_name = services().globals.server_name(); - let message = error.replace("server.name", server_name.as_str()); - return Some(RoomMessageEventContent::notice_markdown(message)); - }, - }; - - let timer = Instant::now(); - let result = admin::process(parsed, body).await; - let elapsed = timer.elapsed(); - conduit::debug!(?command, ok = result.is_ok(), "command processed in {elapsed:?}"); - match result { - Ok(reply) => Some(reply), - Err(error) => Some(RoomMessageEventContent::notice_markdown(format!( - "Encountered an error while handling the command:\n```\n{error:#?}\n```" - ))), +impl Handler<'_> { + // Parse and process a message from the admin room + async fn process(&self, msg: &str) -> CommandOutput { + let mut lines = msg.lines().filter(|l| !l.trim().is_empty()); + let command = lines.next().expect("each string has at least one line"); + let body = lines.collect::<Vec<_>>(); + let parsed = match self.parse_command(command) { + Ok(parsed) => parsed, + Err(error) => { + let server_name = self.services.globals.server_name(); + let message = error.replace("server.name", server_name.as_str()); + return Some(RoomMessageEventContent::notice_markdown(message)); + }, + }; + + let timer = Instant::now(); + let result = Box::pin(admin::process(parsed, body)).await; + let elapsed = timer.elapsed(); + conduit::debug!(?command, ok = result.is_ok(), "command processed in {elapsed:?}"); + match result { + Ok(reply) => Some(reply), + Err(error) => Some(RoomMessageEventContent::notice_markdown(format!( + "Encountered an error while handling the 
command:\n```\n{error:#?}\n```" + ))), + } } -} -// Parse chat messages from the admin room into an AdminCommand object -fn parse_command(command_line: &str) -> Result<AdminCommand, String> { - let argv = parse_line(command_line); - AdminCommand::try_parse_from(argv).map_err(|error| error.to_string()) -} - -fn complete_command(mut cmd: clap::Command, line: &str) -> String { - let argv = parse_line(line); - let mut ret = Vec::<String>::with_capacity(argv.len().saturating_add(1)); + // Parse chat messages from the admin room into an AdminCommand object + fn parse_command(&self, command_line: &str) -> Result<AdminCommand, String> { + let argv = self.parse_line(command_line); + AdminCommand::try_parse_from(argv).map_err(|error| error.to_string()) + } - 'token: for token in argv.into_iter().skip(1) { - let cmd_ = cmd.clone(); - let mut choice = Vec::new(); + fn complete_command(&self, mut cmd: clap::Command, line: &str) -> String { + let argv = self.parse_line(line); + let mut ret = Vec::<String>::with_capacity(argv.len().saturating_add(1)); + + 'token: for token in argv.into_iter().skip(1) { + let cmd_ = cmd.clone(); + let mut choice = Vec::new(); + + for sub in cmd_.get_subcommands() { + let name = sub.get_name(); + if *name == token { + // token already complete; recurse to subcommand + ret.push(token); + cmd.clone_from(sub); + continue 'token; + } else if name.starts_with(&token) { + // partial match; add to choices + choice.push(name); + } + } - for sub in cmd_.get_subcommands() { - let name = sub.get_name(); - if *name == token { - // token already complete; recurse to subcommand + if choice.len() == 1 { + // One choice. 
Add extra space because it's complete + let choice = *choice.first().expect("only choice"); + ret.push(choice.to_owned()); + ret.push(String::new()); + } else if choice.is_empty() { + // Nothing found, return original string ret.push(token); - cmd.clone_from(sub); - continue 'token; - } else if name.starts_with(&token) { - // partial match; add to choices - choice.push(name); + } else { + // Find the common prefix + ret.push(common_prefix(&choice).into()); } - } - if choice.len() == 1 { - // One choice. Add extra space because it's complete - let choice = *choice.first().expect("only choice"); - ret.push(choice.to_owned()); - ret.push(String::new()); - } else if choice.is_empty() { - // Nothing found, return original string - ret.push(token); - } else { - // Find the common prefix - ret.push(common_prefix(&choice).into()); + // Return from completion + return ret.join(" "); } - // Return from completion - return ret.join(" "); + // Return from no completion. Needs a space though. + ret.push(String::new()); + ret.join(" ") } - // Return from no completion. Needs a space though. 
- ret.push(String::new()); - ret.join(" ") -} + // Parse chat messages from the admin room into an AdminCommand object + fn parse_line(&self, command_line: &str) -> Vec<String> { + let mut argv = command_line + .split_whitespace() + .map(str::to_owned) + .collect::<Vec<String>>(); -// Parse chat messages from the admin room into an AdminCommand object -fn parse_line(command_line: &str) -> Vec<String> { - let mut argv = command_line - .split_whitespace() - .map(str::to_owned) - .collect::<Vec<String>>(); + // Remove any escapes that came with a server-side escape command + if !argv.is_empty() && argv[0].ends_with("admin") { + argv[0] = argv[0].trim_start_matches('\\').into(); + } - // Remove any escapes that came with a server-side escape command - if !argv.is_empty() && argv[0].ends_with("admin") { - argv[0] = argv[0].trim_start_matches('\\').into(); - } + // First indice has to be "admin" but for console convenience we add it here + let server_user = self.services.globals.server_user.as_str(); + if !argv.is_empty() && !argv[0].ends_with("admin") && !argv[0].starts_with(server_user) { + argv.insert(0, "admin".to_owned()); + } - // First indice has to be "admin" but for console convenience we add it here - let server_user = services().globals.server_user.as_str(); - if !argv.is_empty() && !argv[0].ends_with("admin") && !argv[0].starts_with(server_user) { - argv.insert(0, "admin".to_owned()); - } + // Replace `help command` with `command --help` + // Clap has a help subcommand, but it omits the long help description. + if argv.len() > 1 && argv[1] == "help" { + argv.remove(1); + argv.push("--help".to_owned()); + } - // Replace `help command` with `command --help` - // Clap has a help subcommand, but it omits the long help description. 
- if argv.len() > 1 && argv[1] == "help" { - argv.remove(1); - argv.push("--help".to_owned()); - } + // Backwards compatibility with `register_appservice`-style commands + if argv.len() > 1 && argv[1].contains('_') { + argv[1] = argv[1].replace('_', "-"); + } - // Backwards compatibility with `register_appservice`-style commands - if argv.len() > 1 && argv[1].contains('_') { - argv[1] = argv[1].replace('_', "-"); - } + // Backwards compatibility with `register_appservice`-style commands + if argv.len() > 2 && argv[2].contains('_') { + argv[2] = argv[2].replace('_', "-"); + } - // Backwards compatibility with `register_appservice`-style commands - if argv.len() > 2 && argv[2].contains('_') { - argv[2] = argv[2].replace('_', "-"); - } + // if the user is using the `query` command (argv[1]), replace the database + // function/table calls with underscores to match the codebase + if argv.len() > 3 && argv[1].eq("query") { + argv[3] = argv[3].replace('_', "-"); + } - // if the user is using the `query` command (argv[1]), replace the database - // function/table calls with underscores to match the codebase - if argv.len() > 3 && argv[1].eq("query") { - argv[3] = argv[3].replace('_', "-"); + trace!(?command_line, ?argv, "parse"); + argv } - - trace!(?command_line, ?argv, "parse"); - argv } -- GitLab From 5a17fbccf56e395f9eea476442967275cfd544cf Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Fri, 5 Jul 2024 08:40:02 +0000 Subject: [PATCH 21/47] add type_name debug tool Signed-off-by: Jason Volk <jason@zemos.net> --- src/core/debug.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/core/debug.rs b/src/core/debug.rs index 14d0be87a..5e52f399a 100644 --- a/src/core/debug.rs +++ b/src/core/debug.rs @@ -80,3 +80,8 @@ pub fn trap() { #[must_use] pub fn panic_str(p: &Box<dyn Any + Send>) -> &'static str { p.downcast_ref::<&str>().copied().unwrap_or_default() } + +#[cfg(debug_assertions)] +#[inline(always)] +#[must_use] +pub fn type_name<T>(_: &T) -> 
&'static str { std::any::type_name::<T>() } -- GitLab From 91b49a7786482e1bb7ffa3c7e8afe688e8f94afd Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Tue, 23 Jul 2024 06:57:14 +0000 Subject: [PATCH 22/47] add basic exchange util Signed-off-by: Jason Volk <jason@zemos.net> --- src/core/utils/mod.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/core/utils/mod.rs b/src/core/utils/mod.rs index bbd528290..d7b6a72a1 100644 --- a/src/core/utils/mod.rs +++ b/src/core/utils/mod.rs @@ -26,8 +26,16 @@ pub use sys::available_parallelism; pub use time::now_millis as millis_since_unix_epoch; +#[inline] pub fn clamp<T: Ord>(val: T, min: T, max: T) -> T { cmp::min(cmp::max(val, min), max) } +#[inline] +pub fn exchange<T: Clone>(state: &mut T, source: T) -> T { + let ret = state.clone(); + *state = source; + ret +} + #[must_use] pub fn generate_keypair() -> Vec<u8> { let mut value = rand::string(8).as_bytes().to_vec(); -- GitLab From 5c0bf2912205ab46989a5165080df8bce77686a2 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Tue, 23 Jul 2024 07:03:33 +0000 Subject: [PATCH 23/47] add util for camel to snake case conversion Signed-off-by: Jason Volk <jason@zemos.net> --- src/core/utils/string.rs | 37 ++++++++++++++++++++++++++++++++++++- src/core/utils/tests.rs | 16 ++++++++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/src/core/utils/string.rs b/src/core/utils/string.rs index 106d0cb77..fc3891dfb 100644 --- a/src/core/utils/string.rs +++ b/src/core/utils/string.rs @@ -1,4 +1,4 @@ -use crate::Result; +use crate::{utils::exchange, Result}; pub const EMPTY: &str = ""; @@ -26,6 +26,41 @@ macro_rules! 
is_format { }; } +#[inline] +#[must_use] +pub fn camel_to_snake_string(s: &str) -> String { + let est_len = s + .chars() + .fold(s.len(), |est, c| est.saturating_add(usize::from(c.is_ascii_uppercase()))); + + let mut ret = String::with_capacity(est_len); + camel_to_snake_case(&mut ret, s.as_bytes()).expect("string-to-string stream error"); + ret +} + +#[inline] +pub fn camel_to_snake_case<I, O>(output: &mut O, input: I) -> Result<()> +where + I: std::io::Read, + O: std::fmt::Write, +{ + let mut state = false; + input + .bytes() + .take_while(Result::is_ok) + .map(Result::unwrap) + .map(char::from) + .try_for_each(|ch| { + let m = ch.is_ascii_uppercase(); + let s = exchange(&mut state, !m); + if m && s { + output.write_char('_')?; + } + output.write_char(ch.to_ascii_lowercase())?; + Result::<()>::Ok(()) + }) +} + /// Find the common prefix from a collection of strings and return a slice /// ``` /// use conduit_core::utils::string::common_prefix; diff --git a/src/core/utils/tests.rs b/src/core/utils/tests.rs index 439689470..e91accdf4 100644 --- a/src/core/utils/tests.rs +++ b/src/core/utils/tests.rs @@ -134,3 +134,19 @@ async fn mutex_map_contend() { tokio::try_join!(join_b, join_a).expect("joined"); assert!(map.is_empty(), "Must be empty"); } + +#[test] +fn camel_to_snake_case_0() { + use utils::string::camel_to_snake_string; + + let res = camel_to_snake_string("CamelToSnakeCase"); + assert_eq!(res, "camel_to_snake_case"); +} + +#[test] +fn camel_to_snake_case_1() { + use utils::string::camel_to_snake_string; + + let res = camel_to_snake_string("CAmelTOSnakeCase"); + assert_eq!(res, "camel_tosnake_case"); +} -- GitLab From 4458efa2b2fdb850f68e9afa3ffd72d21155bed7 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Mon, 17 Jun 2024 07:21:51 +0000 Subject: [PATCH 24/47] rename signing_keys_for to verify_keys_for Signed-off-by: Jason Volk <jason@zemos.net> --- src/admin/query/globals.rs | 2 +- src/service/globals/data.rs | 2 +- src/service/globals/mod.rs 
| 4 ++-- src/service/rooms/event_handler/signing_keys.rs | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/admin/query/globals.rs b/src/admin/query/globals.rs index 9bdd38fca..389860711 100644 --- a/src/admin/query/globals.rs +++ b/src/admin/query/globals.rs @@ -46,7 +46,7 @@ pub(super) async fn globals(subcommand: Globals) -> Result<RoomMessageEventConte origin, } => { let timer = tokio::time::Instant::now(); - let results = services().globals.db.signing_keys_for(&origin); + let results = services().globals.db.verify_keys_for(&origin); let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 5d6240cd4..d73ddec66 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -289,7 +289,7 @@ pub fn add_signing_key( /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found /// for the server. - pub fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> { + pub fn verify_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> { let signingkeys = self .server_signingkeys .get(origin.as_bytes())? diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index c2c7d9697..d46eb79e4 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -272,8 +272,8 @@ pub fn supported_room_versions(&self) -> Vec<RoomVersionId> { /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found /// for the server. 
- pub fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> { - let mut keys = self.db.signing_keys_for(origin)?; + pub fn verify_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> { + let mut keys = self.db.verify_keys_for(origin)?; if origin == self.server_name() { keys.insert( format!("ed25519:{}", self.keypair().version()) diff --git a/src/service/rooms/event_handler/signing_keys.rs b/src/service/rooms/event_handler/signing_keys.rs index 1ebcbefb3..d7cb3a9da 100644 --- a/src/service/rooms/event_handler/signing_keys.rs +++ b/src/service/rooms/event_handler/signing_keys.rs @@ -148,7 +148,7 @@ async fn get_server_keys_from_cache( let result: BTreeMap<_, _> = self .services .globals - .signing_keys_for(origin)? + .verify_keys_for(origin)? .into_iter() .map(|(k, v)| (k.to_string(), v.key)) .collect(); @@ -354,7 +354,7 @@ pub async fn fetch_signing_keys_for_server( let mut result: BTreeMap<_, _> = self .services .globals - .signing_keys_for(origin)? + .verify_keys_for(origin)? .into_iter() .map(|(k, v)| (k.to_string(), v.key)) .collect(); -- GitLab From c64adbec0e267c8507d6f2dd3c566cdd614a99c0 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Mon, 17 Jun 2024 07:49:52 +0000 Subject: [PATCH 25/47] split signing_keys_for from verify_keys_for Signed-off-by: Jason Volk <jason@zemos.net> --- src/service/globals/data.rs | 13 ++++++++++--- src/service/globals/mod.rs | 9 ++++++++- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index d73ddec66..5b5d9f09d 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -291,9 +291,7 @@ pub fn add_signing_key( /// for the server. pub fn verify_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> { let signingkeys = self - .server_signingkeys - .get(origin.as_bytes())? 
- .and_then(|bytes| serde_json::from_slice(&bytes).ok()) + .signing_keys_for(origin)? .map_or_else(BTreeMap::new, |keys: ServerSigningKeys| { let mut tree = keys.verify_keys; tree.extend( @@ -307,6 +305,15 @@ pub fn verify_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServe Ok(signingkeys) } + pub fn signing_keys_for(&self, origin: &ServerName) -> Result<Option<ServerSigningKeys>> { + let signingkeys = self + .server_signingkeys + .get(origin.as_bytes())? + .and_then(|bytes| serde_json::from_slice(&bytes).ok()); + + Ok(signingkeys) + } + pub fn database_version(&self) -> Result<u64> { self.global.get(b"version")?.map_or(Ok(0), |version| { utils::u64_from_bytes(&version).map_err(|_| Error::bad_database("Database version id is invalid.")) diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index d46eb79e4..2c588dce0 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -13,7 +13,10 @@ use ipaddress::IPAddress; use regex::RegexSet; use ruma::{ - api::{client::discovery::discover_support::ContactRole, federation::discovery::VerifyKey}, + api::{ + client::discovery::discover_support::ContactRole, + federation::discovery::{ServerSigningKeys, VerifyKey}, + }, serde::Base64, DeviceId, OwnedEventId, OwnedRoomAliasId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, RoomAliasId, RoomVersionId, ServerName, UserId, @@ -288,6 +291,10 @@ pub fn verify_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServe Ok(keys) } + pub fn signing_keys_for(&self, origin: &ServerName) -> Result<Option<ServerSigningKeys>> { + self.db.signing_keys_for(origin) + } + pub fn well_known_client(&self) -> &Option<Url> { &self.config.well_known.client } pub fn well_known_server(&self) -> &Option<OwnedServerName> { &self.config.well_known.server } -- GitLab From f841c2356d402d2a7b59f4683ea02a3f14134873 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Mon, 17 Jun 2024 21:46:23 +0000 Subject: [PATCH 26/47] preliminary 
get-signing-keys command Signed-off-by: Jason Volk <jason@zemos.net> --- src/admin/debug/commands.rs | 45 +++++++++++++++++++++++++++++++++++++ src/admin/debug/mod.rs | 12 ++++++++++ 2 files changed, 57 insertions(+) diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index cbe524732..3a7d35449 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -1,5 +1,6 @@ use std::{ collections::{BTreeMap, HashMap}, + fmt::Write, sync::{Arc, Mutex}, time::Instant, }; @@ -605,6 +606,50 @@ pub(super) async fn force_set_room_state_from_server( )) } +pub(super) async fn get_signing_keys( + _body: Vec<&str>, server_name: Option<Box<ServerName>>, _cached: bool, +) -> Result<RoomMessageEventContent> { + let server_name = server_name.unwrap_or_else(|| services().server.config.server_name.clone().into()); + let signing_keys = services().globals.signing_keys_for(&server_name)?; + + Ok(RoomMessageEventContent::notice_markdown(format!( + "```rs\n{signing_keys:#?}\n```" + ))) +} + +#[allow(dead_code)] +pub(super) async fn get_verify_keys( + _body: Vec<&str>, server_name: Option<Box<ServerName>>, cached: bool, +) -> Result<RoomMessageEventContent> { + let server_name = server_name.unwrap_or_else(|| services().server.config.server_name.clone().into()); + let mut out = String::new(); + + if cached { + writeln!(out, "| Key ID | VerifyKey |")?; + writeln!(out, "| --- | --- |")?; + for (key_id, verify_key) in services().globals.verify_keys_for(&server_name)? 
{ + writeln!(out, "| {key_id} | {verify_key:?} |")?; + } + + return Ok(RoomMessageEventContent::notice_markdown(out)); + } + + let signature_ids: Vec<String> = Vec::new(); + let keys = services() + .rooms + .event_handler + .fetch_signing_keys_for_server(&server_name, signature_ids) + .await?; + + writeln!(out, "| Key ID | Public Key |")?; + writeln!(out, "| --- | --- |")?; + for (key_id, key) in keys { + writeln!(out, "| {key_id} | {key} |")?; + } + + Ok(RoomMessageEventContent::notice_markdown(out)) +} + pub(super) async fn resolve_true_destination( _body: Vec<&str>, server_name: Box<ServerName>, no_cache: bool, ) -> Result<RoomMessageEventContent> { diff --git a/src/admin/debug/mod.rs b/src/admin/debug/mod.rs index 7d6cafa77..41527cf50 100644 --- a/src/admin/debug/mod.rs +++ b/src/admin/debug/mod.rs @@ -76,6 +76,14 @@ pub(super) enum DebugCommand { room_id: OwnedRoomOrAliasId, }, + /// - Get and display signing keys from local cache or remote server. + GetSigningKeys { + server_name: Option<Box<ServerName>>, + + #[arg(short, long)] + cached: bool, + }, + /// - Sends a federation request to the remote server's /// `/_matrix/federation/v1/version` endpoint and measures the latency it /// took for the server to respond @@ -177,6 +185,10 @@ pub(super) async fn process(command: DebugCommand, body: Vec<&str>) -> Result<Ro DebugCommand::Echo { message, } => echo(body, message).await?, + DebugCommand::GetSigningKeys { + server_name, + cached, + } => get_signing_keys(body, server_name, cached).await?, DebugCommand::GetAuthChain { event_id, } => get_auth_chain(body, event_id).await?, -- GitLab From 2468e0c3ded9a26cc4014e2d4e35e6d968f8f4cc Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Wed, 24 Jul 2024 00:13:03 +0000 Subject: [PATCH 27/47] unconditionally derive Debug on subcommand enums Signed-off-by: Jason Volk <jason@zemos.net> --- src/admin/admin.rs | 2 +- src/admin/appservice/mod.rs | 3 +-- src/admin/check/mod.rs | 3 +-- src/admin/debug/mod.rs | 3 
+-- src/admin/debug/tester.rs | 3 +-- src/admin/federation/mod.rs | 3 +-- src/admin/media/mod.rs | 3 +-- src/admin/query/mod.rs | 30 ++++++++++-------------------- src/admin/room/mod.rs | 15 +++++---------- src/admin/server/mod.rs | 3 +-- src/admin/user/mod.rs | 3 +-- 11 files changed, 24 insertions(+), 47 deletions(-) diff --git a/src/admin/admin.rs b/src/admin/admin.rs index e55f6d610..f5fe5dc22 100644 --- a/src/admin/admin.rs +++ b/src/admin/admin.rs @@ -8,7 +8,7 @@ server, server::ServerCommand, user, user::UserCommand, }; -#[derive(Parser)] +#[derive(Debug, Parser)] #[command(name = "admin", version = env!("CARGO_PKG_VERSION"))] pub(crate) enum AdminCommand { #[command(subcommand)] diff --git a/src/admin/appservice/mod.rs b/src/admin/appservice/mod.rs index 87ab1b6d3..81e04087c 100644 --- a/src/admin/appservice/mod.rs +++ b/src/admin/appservice/mod.rs @@ -6,8 +6,7 @@ use self::commands::*; -#[cfg_attr(test, derive(Debug))] -#[derive(Subcommand)] +#[derive(Debug, Subcommand)] pub(super) enum AppserviceCommand { /// - Register an appservice using its registration YAML /// diff --git a/src/admin/check/mod.rs b/src/admin/check/mod.rs index fdbe3d0d7..f1cfa2b94 100644 --- a/src/admin/check/mod.rs +++ b/src/admin/check/mod.rs @@ -6,8 +6,7 @@ use self::commands::*; -#[cfg_attr(test, derive(Debug))] -#[derive(Subcommand)] +#[derive(Debug, Subcommand)] pub(super) enum CheckCommand { AllUsers, } diff --git a/src/admin/debug/mod.rs b/src/admin/debug/mod.rs index 41527cf50..0df477487 100644 --- a/src/admin/debug/mod.rs +++ b/src/admin/debug/mod.rs @@ -8,8 +8,7 @@ use self::commands::*; -#[cfg_attr(test, derive(Debug))] -#[derive(Subcommand)] +#[derive(Debug, Subcommand)] pub(super) enum DebugCommand { /// - Echo input of admin command Echo { diff --git a/src/admin/debug/tester.rs b/src/admin/debug/tester.rs index 6982ef52b..2765a344d 100644 --- a/src/admin/debug/tester.rs +++ b/src/admin/debug/tester.rs @@ -2,8 +2,7 @@ use crate::Result; -#[derive(clap::Subcommand)] 
-#[cfg_attr(test, derive(Debug))] +#[derive(Debug, clap::Subcommand)] pub(crate) enum TesterCommand { Tester, Timer, diff --git a/src/admin/federation/mod.rs b/src/admin/federation/mod.rs index 46e323b23..d02b42956 100644 --- a/src/admin/federation/mod.rs +++ b/src/admin/federation/mod.rs @@ -6,8 +6,7 @@ use self::commands::*; -#[cfg_attr(test, derive(Debug))] -#[derive(Subcommand)] +#[derive(Debug, Subcommand)] pub(super) enum FederationCommand { /// - List all rooms we are currently handling an incoming pdu from IncomingFederation, diff --git a/src/admin/media/mod.rs b/src/admin/media/mod.rs index 60f7a9bb6..d30c55d0b 100644 --- a/src/admin/media/mod.rs +++ b/src/admin/media/mod.rs @@ -6,8 +6,7 @@ use self::commands::*; -#[cfg_attr(test, derive(Debug))] -#[derive(Subcommand)] +#[derive(Debug, Subcommand)] pub(super) enum MediaCommand { /// - Deletes a single media file from our database and on the filesystem /// via a single MXC URL diff --git a/src/admin/query/mod.rs b/src/admin/query/mod.rs index ea7036d04..c86f4f538 100644 --- a/src/admin/query/mod.rs +++ b/src/admin/query/mod.rs @@ -21,8 +21,7 @@ room_alias::room_alias, sending::sending, users::users, }; -#[cfg_attr(test, derive(Debug))] -#[derive(Subcommand)] +#[derive(Debug, Subcommand)] /// Query tables from database pub(super) enum QueryCommand { /// - account_data.rs iterators and getters @@ -62,8 +61,7 @@ pub(super) enum QueryCommand { Resolver(Resolver), } -#[cfg_attr(test, derive(Debug))] -#[derive(Subcommand)] +#[derive(Debug, Subcommand)] /// All the getters and iterators from src/database/key_value/account_data.rs pub(super) enum AccountData { /// - Returns all changes to the account data that happened after `since`. 
@@ -87,8 +85,7 @@ pub(super) enum AccountData { }, } -#[cfg_attr(test, derive(Debug))] -#[derive(Subcommand)] +#[derive(Debug, Subcommand)] /// All the getters and iterators from src/database/key_value/appservice.rs pub(super) enum Appservice { /// - Gets the appservice registration info/details from the ID as a string @@ -101,8 +98,7 @@ pub(super) enum Appservice { All, } -#[cfg_attr(test, derive(Debug))] -#[derive(Subcommand)] +#[derive(Debug, Subcommand)] /// All the getters and iterators from src/database/key_value/presence.rs pub(super) enum Presence { /// - Returns the latest presence event for the given user. @@ -119,8 +115,7 @@ pub(super) enum Presence { }, } -#[cfg_attr(test, derive(Debug))] -#[derive(Subcommand)] +#[derive(Debug, Subcommand)] /// All the getters and iterators from src/database/key_value/rooms/alias.rs pub(super) enum RoomAlias { ResolveLocalAlias { @@ -138,8 +133,7 @@ pub(super) enum RoomAlias { AllLocalAliases, } -#[cfg_attr(test, derive(Debug))] -#[derive(Subcommand)] +#[derive(Debug, Subcommand)] pub(super) enum RoomStateCache { ServerInRoom { server: Box<ServerName>, @@ -210,8 +204,7 @@ pub(super) enum RoomStateCache { }, } -#[cfg_attr(test, derive(Debug))] -#[derive(Subcommand)] +#[derive(Debug, Subcommand)] /// All the getters and iterators from src/database/key_value/globals.rs pub(super) enum Globals { DatabaseVersion, @@ -229,8 +222,7 @@ pub(super) enum Globals { }, } -#[cfg_attr(test, derive(Debug))] -#[derive(Subcommand)] +#[derive(Debug, Subcommand)] /// All the getters and iterators from src/database/key_value/sending.rs pub(super) enum Sending { /// - Queries database for all `servercurrentevent_data` @@ -285,15 +277,13 @@ pub(super) enum Sending { }, } -#[cfg_attr(test, derive(Debug))] -#[derive(Subcommand)] +#[derive(Debug, Subcommand)] /// All the getters and iterators from src/database/key_value/users.rs pub(super) enum Users { Iter, } -#[cfg_attr(test, derive(Debug))] -#[derive(Subcommand)] +#[derive(Debug, Subcommand)] 
/// Resolver service and caches pub(super) enum Resolver { /// Query the destinations cache diff --git a/src/admin/room/mod.rs b/src/admin/room/mod.rs index b4fa15bdb..8f125f0a3 100644 --- a/src/admin/room/mod.rs +++ b/src/admin/room/mod.rs @@ -10,8 +10,7 @@ use self::room_commands::list; -#[cfg_attr(test, derive(Debug))] -#[derive(Subcommand)] +#[derive(Debug, Subcommand)] pub(super) enum RoomCommand { /// - List all rooms the server knows about List { @@ -43,8 +42,7 @@ pub(super) enum RoomCommand { Directory(RoomDirectoryCommand), } -#[cfg_attr(test, derive(Debug))] -#[derive(Subcommand)] +#[derive(Debug, Subcommand)] pub(super) enum RoomInfoCommand { /// - List joined members in a room ListJoinedMembers { @@ -60,8 +58,7 @@ pub(super) enum RoomInfoCommand { }, } -#[cfg_attr(test, derive(Debug))] -#[derive(Subcommand)] +#[derive(Debug, Subcommand)] pub(super) enum RoomAliasCommand { /// - Make an alias point to a room. Set { @@ -96,8 +93,7 @@ pub(super) enum RoomAliasCommand { }, } -#[cfg_attr(test, derive(Debug))] -#[derive(Subcommand)] +#[derive(Debug, Subcommand)] pub(super) enum RoomDirectoryCommand { /// - Publish a room to the room directory Publish { @@ -117,8 +113,7 @@ pub(super) enum RoomDirectoryCommand { }, } -#[cfg_attr(test, derive(Debug))] -#[derive(Subcommand)] +#[derive(Debug, Subcommand)] pub(super) enum RoomModerationCommand { /// - Bans a room from local users joining and evicts all our local users /// from the room. 
Also blocks any invites (local and remote) for the diff --git a/src/admin/server/mod.rs b/src/admin/server/mod.rs index 41b101802..604dacc7f 100644 --- a/src/admin/server/mod.rs +++ b/src/admin/server/mod.rs @@ -6,8 +6,7 @@ use self::commands::*; -#[cfg_attr(test, derive(Debug))] -#[derive(Subcommand)] +#[derive(Debug, Subcommand)] pub(super) enum ServerCommand { /// - Time elapsed since startup Uptime, diff --git a/src/admin/user/mod.rs b/src/admin/user/mod.rs index cdb5fa5ea..1b92d668c 100644 --- a/src/admin/user/mod.rs +++ b/src/admin/user/mod.rs @@ -6,8 +6,7 @@ use self::commands::*; -#[cfg_attr(test, derive(Debug))] -#[derive(Subcommand)] +#[derive(Debug, Subcommand)] pub(super) enum UserCommand { /// - Create a new user Create { -- GitLab From 85f734ec74d98cda9d6557dd56e4f4b99a976ed0 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Wed, 24 Jul 2024 00:14:03 +0000 Subject: [PATCH 28/47] =?UTF-8?q?proc=20macro=20=E2=9C=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Jason Volk <jason@zemos.net> --- Cargo.lock | 11 +++++++ Cargo.toml | 15 +++++++++ src/admin/Cargo.toml | 1 + src/admin/debug/commands.rs | 13 ++++---- src/admin/debug/mod.rs | 64 +++---------------------------------- src/admin/handler.rs | 6 ++-- src/admin/mod.rs | 1 + src/macros/Cargo.toml | 24 ++++++++++++++ src/macros/admin.rs | 50 +++++++++++++++++++++++++++++ src/macros/mod.rs | 8 +++++ 10 files changed, 123 insertions(+), 70 deletions(-) create mode 100644 src/macros/Cargo.toml create mode 100644 src/macros/admin.rs create mode 100644 src/macros/mod.rs diff --git a/Cargo.lock b/Cargo.lock index e9730053e..0649f3be1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -612,6 +612,7 @@ dependencies = [ "clap", "conduit_api", "conduit_core", + "conduit_macros", "conduit_service", "const-str", "futures-util", @@ -712,6 +713,16 @@ dependencies = [ "tracing", ] +[[package]] +name = "conduit_macros" +version = "0.4.5" 
+dependencies = [ + "conduit_core", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "conduit_router" version = "0.4.6" diff --git a/Cargo.toml b/Cargo.toml index f5c2ad241..5fcb03ef8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -413,6 +413,16 @@ default-features = false [workspace.dependencies.checked_ops] version = "0.1" +[workspace.dependencies.syn] +version = "1.0" +features = ["full", "extra-traits"] + +[workspace.dependencies.quote] +version = "1.0" + +[workspace.dependencies.proc-macro2] +version = "1.0.86" + # # Patches @@ -480,6 +490,11 @@ package = "conduit_core" path = "src/core" default-features = false +[workspace.dependencies.conduit-macros] +package = "conduit_macros" +path = "src/macros" +default-features = false + ############################################################################### # # Release profiles diff --git a/src/admin/Cargo.toml b/src/admin/Cargo.toml index 1e13fb7a9..d756b3cbd 100644 --- a/src/admin/Cargo.toml +++ b/src/admin/Cargo.toml @@ -29,6 +29,7 @@ release_max_log_level = [ clap.workspace = true conduit-api.workspace = true conduit-core.workspace = true +conduit-macros.workspace = true conduit-service.workspace = true const-str.workspace = true futures-util.workspace = true diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index 3a7d35449..4efefce7c 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -477,7 +477,7 @@ pub(super) async fn latest_pdu_in_room(_body: Vec<&str>, room_id: Box<RoomId>) - #[tracing::instrument(skip(_body))] pub(super) async fn force_set_room_state_from_server( - _body: Vec<&str>, server_name: Box<ServerName>, room_id: Box<RoomId>, + _body: Vec<&str>, room_id: Box<RoomId>, server_name: Box<ServerName>, ) -> Result<RoomMessageEventContent> { if !services() .rooms @@ -691,18 +691,19 @@ pub(super) async fn resolve_true_destination( Ok(RoomMessageEventContent::text_markdown(msg)) } -#[must_use] -pub(super) fn memory_stats() -> 
RoomMessageEventContent { +pub(super) async fn memory_stats(_body: Vec<&str>) -> Result<RoomMessageEventContent> { let html_body = conduit::alloc::memory_stats(); if html_body.is_none() { - return RoomMessageEventContent::text_plain("malloc stats are not supported on your compiled malloc."); + return Ok(RoomMessageEventContent::text_plain( + "malloc stats are not supported on your compiled malloc.", + )); } - RoomMessageEventContent::text_html( + Ok(RoomMessageEventContent::text_html( "This command's output can only be viewed by clients that render HTML.".to_owned(), html_body.expect("string result"), - ) + )) } #[cfg(tokio_unstable)] diff --git a/src/admin/debug/mod.rs b/src/admin/debug/mod.rs index 0df477487..82b37c535 100644 --- a/src/admin/debug/mod.rs +++ b/src/admin/debug/mod.rs @@ -3,11 +3,12 @@ use clap::Subcommand; use conduit::Result; +use conduit_macros::admin_command_dispatch; use ruma::{events::room::message::RoomMessageEventContent, EventId, OwnedRoomOrAliasId, RoomId, ServerName}; -use tester::TesterCommand; -use self::commands::*; +use self::{commands::*, tester::TesterCommand}; +#[admin_command_dispatch] #[derive(Debug, Subcommand)] pub(super) enum DebugCommand { /// - Echo input of admin command @@ -176,63 +177,6 @@ pub(super) enum DebugCommand { /// - Developer test stubs #[command(subcommand)] + #[allow(non_snake_case)] Tester(TesterCommand), } - -pub(super) async fn process(command: DebugCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> { - Ok(match command { - DebugCommand::Echo { - message, - } => echo(body, message).await?, - DebugCommand::GetSigningKeys { - server_name, - cached, - } => get_signing_keys(body, server_name, cached).await?, - DebugCommand::GetAuthChain { - event_id, - } => get_auth_chain(body, event_id).await?, - DebugCommand::ParsePdu => parse_pdu(body).await?, - DebugCommand::GetPdu { - event_id, - } => get_pdu(body, event_id).await?, - DebugCommand::GetRemotePdu { - event_id, - server, - } => get_remote_pdu(body, 
event_id, server).await?, - DebugCommand::GetRoomState { - room_id, - } => get_room_state(body, room_id).await?, - DebugCommand::Ping { - server, - } => ping(body, server).await?, - DebugCommand::ForceDeviceListUpdates => force_device_list_updates(body).await?, - DebugCommand::ChangeLogLevel { - filter, - reset, - } => change_log_level(body, filter, reset).await?, - DebugCommand::SignJson => sign_json(body).await?, - DebugCommand::VerifyJson => verify_json(body).await?, - DebugCommand::FirstPduInRoom { - room_id, - } => first_pdu_in_room(body, room_id).await?, - DebugCommand::LatestPduInRoom { - room_id, - } => latest_pdu_in_room(body, room_id).await?, - DebugCommand::GetRemotePduList { - server, - force, - } => get_remote_pdu_list(body, server, force).await?, - DebugCommand::ForceSetRoomStateFromServer { - room_id, - server_name, - } => force_set_room_state_from_server(body, server_name, room_id).await?, - DebugCommand::ResolveTrueDestination { - server_name, - no_cache, - } => resolve_true_destination(body, server_name, no_cache).await?, - DebugCommand::MemoryStats => memory_stats(), - DebugCommand::RuntimeMetrics => runtime_metrics(body).await?, - DebugCommand::RuntimeInterval => runtime_interval(body).await?, - DebugCommand::Tester(command) => tester::process(command, body).await?, - }) -} diff --git a/src/admin/handler.rs b/src/admin/handler.rs index ff04d3782..6acb19bfd 100644 --- a/src/admin/handler.rs +++ b/src/admin/handler.rs @@ -1,7 +1,7 @@ use std::{panic::AssertUnwindSafe, time::Instant}; use clap::{CommandFactory, Parser}; -use conduit::{error, trace, Error}; +use conduit::{error, trace, utils::string::common_prefix, Error, Result}; use futures_util::future::FutureExt; use ruma::{ events::{ @@ -10,8 +10,6 @@ }, OwnedEventId, }; - -use conduit::{utils::string::common_prefix, Result}; use service::{ admin::{Command, CommandOutput, CommandResult, HandlerResult}, Services, @@ -36,7 +34,7 @@ pub(super) fn handle(command: Command) -> HandlerResult { 
Box::pin(handle_comman #[tracing::instrument(skip_all, name = "admin")] async fn handle_command(command: Command) -> CommandResult { - AssertUnwindSafe(process_command(&command)) + AssertUnwindSafe(Box::pin(process_command(&command))) .catch_unwind() .await .map_err(Error::from_panic) diff --git a/src/admin/mod.rs b/src/admin/mod.rs index ff2aefd5c..7d752ff86 100644 --- a/src/admin/mod.rs +++ b/src/admin/mod.rs @@ -1,5 +1,6 @@ #![recursion_limit = "168"] #![allow(clippy::wildcard_imports)] +#![allow(clippy::enum_glob_use)] pub(crate) mod admin; pub(crate) mod handler; diff --git a/src/macros/Cargo.toml b/src/macros/Cargo.toml new file mode 100644 index 000000000..b9a35aab7 --- /dev/null +++ b/src/macros/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "conduit_macros" +categories.workspace = true +description.workspace = true +edition.workspace = true +keywords.workspace = true +license.workspace = true +readme.workspace = true +repository.workspace = true +version.workspace = true + +[lib] +name = "conduit_macros" +path = "mod.rs" +proc-macro = true + +[dependencies] +syn.workspace = true +quote.workspace = true +proc-macro2.workspace = true +conduit-core.workspace = true + +[lints] +workspace = true diff --git a/src/macros/admin.rs b/src/macros/admin.rs new file mode 100644 index 000000000..e95e402ae --- /dev/null +++ b/src/macros/admin.rs @@ -0,0 +1,50 @@ +use conduit_core::utils::string::camel_to_snake_string; +use proc_macro::{Span, TokenStream}; +use proc_macro2::TokenStream as TokenStream2; +use quote::quote; +use syn::{parse_macro_input, AttributeArgs, Fields, Ident, ItemEnum, Variant}; + +pub(super) fn command_dispatch(args: TokenStream, input_: TokenStream) -> TokenStream { + let input = input_.clone(); + let item = parse_macro_input!(input as ItemEnum); + let _args = parse_macro_input!(args as AttributeArgs); + let arm = item.variants.iter().map(dispatch_arm); + let name = item.ident; + let q = quote! 
{ + pub(super) async fn process(command: #name, body: Vec<&str>) -> Result<RoomMessageEventContent> { + use #name::*; + #[allow(non_snake_case)] + Ok(match command { + #( #arm )* + }) + } + }; + + [input_, q.into()].into_iter().collect::<TokenStream>() +} + +fn dispatch_arm(v: &Variant) -> TokenStream2 { + let name = &v.ident; + let target = camel_to_snake_string(&format!("{name}")); + let handler = Ident::new(&target, Span::call_site().into()); + match &v.fields { + Fields::Named(fields) => { + let field = fields.named.iter().filter_map(|f| f.ident.as_ref()); + let arg = field.clone(); + quote! { + #name { #( #field ),* } => Box::pin(#handler(body, #( #arg ),*)).await?, + } + }, + Fields::Unnamed(fields) => { + let field = &fields.unnamed.first().expect("one field"); + quote! { + #name ( #field ) => Box::pin(#handler::process(#field, body)).await?, + } + }, + Fields::Unit => { + quote! { + #name => Box::pin(#handler(body)).await?, + } + }, + } +} diff --git a/src/macros/mod.rs b/src/macros/mod.rs new file mode 100644 index 000000000..0aba7560e --- /dev/null +++ b/src/macros/mod.rs @@ -0,0 +1,8 @@ +mod admin; + +use proc_macro::TokenStream; + +#[proc_macro_attribute] +pub fn admin_command_dispatch(args: TokenStream, input: TokenStream) -> TokenStream { + admin::command_dispatch(args, input) +} -- GitLab From 111cbea6fc70b25fde58a33dc1e8594de79040e0 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Wed, 24 Jul 2024 00:15:03 +0000 Subject: [PATCH 29/47] add debug time command Signed-off-by: Jason Volk <jason@zemos.net> --- src/admin/debug/commands.rs | 9 +++++++-- src/admin/debug/mod.rs | 3 +++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index 4efefce7c..ff68641c3 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -2,14 +2,14 @@ collections::{BTreeMap, HashMap}, fmt::Write, sync::{Arc, Mutex}, - time::Instant, + time::{Instant, SystemTime}, }; use 
api::client::validate_and_add_event_id; use conduit::{ debug, info, log, log::{capture, Capture}, - warn, Error, PduEvent, Result, + utils, warn, Error, PduEvent, Result, }; use ruma::{ api::{client::error::ErrorKind, federation::event::get_room_state}, @@ -739,3 +739,8 @@ pub(super) async fn runtime_interval(_body: Vec<&str>) -> Result<RoomMessageEven "Runtime metrics require building with `tokio_unstable`.", )) } + +pub(super) async fn time(_body: Vec<&str>) -> Result<RoomMessageEventContent> { + let now = SystemTime::now(); + Ok(RoomMessageEventContent::text_markdown(utils::time::format(now, "%+"))) +} diff --git a/src/admin/debug/mod.rs b/src/admin/debug/mod.rs index 82b37c535..77bc36b8e 100644 --- a/src/admin/debug/mod.rs +++ b/src/admin/debug/mod.rs @@ -175,6 +175,9 @@ pub(super) enum DebugCommand { /// invocation. RuntimeInterval, + /// - Print the current time + Time, + /// - Developer test stubs #[command(subcommand)] #[allow(non_snake_case)] -- GitLab From ee864bcd9e0133f4c002d68ebc373b54b183eb16 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Wed, 24 Jul 2024 01:26:23 +0000 Subject: [PATCH 30/47] normalize admin debug command handlers Signed-off-by: Jason Volk <jason@zemos.net> --- src/admin/debug/commands.rs | 57 +++++++++++++++++++------------------ src/macros/admin.rs | 4 +-- 2 files changed, 31 insertions(+), 30 deletions(-) diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index ff68641c3..32fba4095 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -20,13 +20,13 @@ use tokio::sync::RwLock; use tracing_subscriber::EnvFilter; -pub(super) async fn echo(_body: Vec<&str>, message: Vec<String>) -> Result<RoomMessageEventContent> { +pub(super) async fn echo(_body: &[&str], message: Vec<String>) -> Result<RoomMessageEventContent> { let message = message.join(" "); Ok(RoomMessageEventContent::notice_plain(message)) } -pub(super) async fn get_auth_chain(_body: Vec<&str>, event_id: 
Box<EventId>) -> Result<RoomMessageEventContent> { +pub(super) async fn get_auth_chain(_body: &[&str], event_id: Box<EventId>) -> Result<RoomMessageEventContent> { let event_id = Arc::<EventId>::from(event_id); if let Some(event) = services().rooms.timeline.get_pdu_json(&event_id)? { let room_id_str = event @@ -52,7 +52,7 @@ pub(super) async fn get_auth_chain(_body: Vec<&str>, event_id: Box<EventId>) -> } } -pub(super) async fn parse_pdu(body: Vec<&str>) -> Result<RoomMessageEventContent> { +pub(super) async fn parse_pdu(body: &[&str]) -> Result<RoomMessageEventContent> { if body.len() < 2 || !body[0].trim().starts_with("```") || body.last().unwrap_or(&"").trim() != "```" { return Ok(RoomMessageEventContent::text_plain( "Expected code block in command body. Add --help for details.", @@ -80,7 +80,7 @@ pub(super) async fn parse_pdu(body: Vec<&str>) -> Result<RoomMessageEventContent } } -pub(super) async fn get_pdu(_body: Vec<&str>, event_id: Box<EventId>) -> Result<RoomMessageEventContent> { +pub(super) async fn get_pdu(_body: &[&str], event_id: Box<EventId>) -> Result<RoomMessageEventContent> { let mut outlier = false; let mut pdu_json = services() .rooms @@ -108,7 +108,7 @@ pub(super) async fn get_pdu(_body: Vec<&str>, event_id: Box<EventId>) -> Result< } pub(super) async fn get_remote_pdu_list( - body: Vec<&str>, server: Box<ServerName>, force: bool, + body: &[&str], server: Box<ServerName>, force: bool, ) -> Result<RoomMessageEventContent> { if !services().globals.config.allow_federation { return Ok(RoomMessageEventContent::text_plain( @@ -130,14 +130,15 @@ pub(super) async fn get_remote_pdu_list( } let list = body - .clone() - .drain(1..body.len().checked_sub(1).unwrap()) + .iter() + .collect::<Vec<_>>() + .drain(1..body.len().saturating_sub(1)) .filter_map(|pdu| EventId::parse(pdu).ok()) .collect::<Vec<_>>(); for pdu in list { if force { - if let Err(e) = get_remote_pdu(Vec::new(), Box::from(pdu), server.clone()).await { + if let Err(e) = get_remote_pdu(&[], 
Box::from(pdu), server.clone()).await { services() .admin .send_message(RoomMessageEventContent::text_plain(format!( @@ -147,7 +148,7 @@ pub(super) async fn get_remote_pdu_list( warn!(%e, "Failed to get remote PDU, ignoring error"); } } else { - get_remote_pdu(Vec::new(), Box::from(pdu), server.clone()).await?; + get_remote_pdu(&[], Box::from(pdu), server.clone()).await?; } } @@ -155,7 +156,7 @@ pub(super) async fn get_remote_pdu_list( } pub(super) async fn get_remote_pdu( - _body: Vec<&str>, event_id: Box<EventId>, server: Box<ServerName>, + _body: &[&str], event_id: Box<EventId>, server: Box<ServerName>, ) -> Result<RoomMessageEventContent> { if !services().globals.config.allow_federation { return Ok(RoomMessageEventContent::text_plain( @@ -240,7 +241,7 @@ pub(super) async fn get_remote_pdu( } } -pub(super) async fn get_room_state(_body: Vec<&str>, room: OwnedRoomOrAliasId) -> Result<RoomMessageEventContent> { +pub(super) async fn get_room_state(_body: &[&str], room: OwnedRoomOrAliasId) -> Result<RoomMessageEventContent> { let room_id = services().rooms.alias.resolve(&room).await?; let room_state = services() .rooms @@ -267,7 +268,7 @@ pub(super) async fn get_room_state(_body: Vec<&str>, room: OwnedRoomOrAliasId) - Ok(RoomMessageEventContent::notice_markdown(format!("```json\n{json}\n```"))) } -pub(super) async fn ping(_body: Vec<&str>, server: Box<ServerName>) -> Result<RoomMessageEventContent> { +pub(super) async fn ping(_body: &[&str], server: Box<ServerName>) -> Result<RoomMessageEventContent> { if server == services().globals.server_name() { return Ok(RoomMessageEventContent::text_plain( "Not allowed to send federation requests to ourselves.", @@ -305,7 +306,7 @@ pub(super) async fn ping(_body: Vec<&str>, server: Box<ServerName>) -> Result<Ro } } -pub(super) async fn force_device_list_updates(_body: Vec<&str>) -> Result<RoomMessageEventContent> { +pub(super) async fn force_device_list_updates(_body: &[&str]) -> Result<RoomMessageEventContent> { // Force E2EE 
device list updates for all users for user_id in services().users.iter().filter_map(Result::ok) { services().users.mark_device_key_update(&user_id)?; @@ -316,7 +317,7 @@ pub(super) async fn force_device_list_updates(_body: Vec<&str>) -> Result<RoomMe } pub(super) async fn change_log_level( - _body: Vec<&str>, filter: Option<String>, reset: bool, + _body: &[&str], filter: Option<String>, reset: bool, ) -> Result<RoomMessageEventContent> { let handles = &["console"]; @@ -380,7 +381,7 @@ pub(super) async fn change_log_level( Ok(RoomMessageEventContent::text_plain("No log level was specified.")) } -pub(super) async fn sign_json(body: Vec<&str>) -> Result<RoomMessageEventContent> { +pub(super) async fn sign_json(body: &[&str]) -> Result<RoomMessageEventContent> { if body.len() < 2 || !body[0].trim().starts_with("```") || body.last().unwrap_or(&"").trim() != "```" { return Ok(RoomMessageEventContent::text_plain( "Expected code block in command body. Add --help for details.", @@ -403,7 +404,7 @@ pub(super) async fn sign_json(body: Vec<&str>) -> Result<RoomMessageEventContent } } -pub(super) async fn verify_json(body: Vec<&str>) -> Result<RoomMessageEventContent> { +pub(super) async fn verify_json(body: &[&str]) -> Result<RoomMessageEventContent> { if body.len() < 2 || !body[0].trim().starts_with("```") || body.last().unwrap_or(&"").trim() != "```" { return Ok(RoomMessageEventContent::text_plain( "Expected code block in command body. 
Add --help for details.", @@ -434,7 +435,7 @@ pub(super) async fn verify_json(body: Vec<&str>) -> Result<RoomMessageEventConte } #[tracing::instrument(skip(_body))] -pub(super) async fn first_pdu_in_room(_body: Vec<&str>, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { +pub(super) async fn first_pdu_in_room(_body: &[&str], room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { if !services() .rooms .state_cache @@ -455,7 +456,7 @@ pub(super) async fn first_pdu_in_room(_body: Vec<&str>, room_id: Box<RoomId>) -> } #[tracing::instrument(skip(_body))] -pub(super) async fn latest_pdu_in_room(_body: Vec<&str>, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { +pub(super) async fn latest_pdu_in_room(_body: &[&str], room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { if !services() .rooms .state_cache @@ -477,7 +478,7 @@ pub(super) async fn latest_pdu_in_room(_body: Vec<&str>, room_id: Box<RoomId>) - #[tracing::instrument(skip(_body))] pub(super) async fn force_set_room_state_from_server( - _body: Vec<&str>, room_id: Box<RoomId>, server_name: Box<ServerName>, + _body: &[&str], room_id: Box<RoomId>, server_name: Box<ServerName>, ) -> Result<RoomMessageEventContent> { if !services() .rooms @@ -607,7 +608,7 @@ pub(super) async fn force_set_room_state_from_server( } pub(super) async fn get_signing_keys( - _body: Vec<&str>, server_name: Option<Box<ServerName>>, _cached: bool, + _body: &[&str], server_name: Option<Box<ServerName>>, _cached: bool, ) -> Result<RoomMessageEventContent> { let server_name = server_name.unwrap_or_else(|| services().server.config.server_name.clone().into()); let signing_keys = services().globals.signing_keys_for(&server_name)?; @@ -619,7 +620,7 @@ pub(super) async fn get_signing_keys( #[allow(dead_code)] pub(super) async fn get_verify_keys( - _body: Vec<&str>, server_name: Option<Box<ServerName>>, cached: bool, + _body: &[&str], server_name: Option<Box<ServerName>>, cached: bool, ) -> Result<RoomMessageEventContent> { let 
server_name = server_name.unwrap_or_else(|| services().server.config.server_name.clone().into()); let mut out = String::new(); @@ -651,7 +652,7 @@ pub(super) async fn get_verify_keys( } pub(super) async fn resolve_true_destination( - _body: Vec<&str>, server_name: Box<ServerName>, no_cache: bool, + _body: &[&str], server_name: Box<ServerName>, no_cache: bool, ) -> Result<RoomMessageEventContent> { if !services().globals.config.allow_federation { return Ok(RoomMessageEventContent::text_plain( @@ -691,7 +692,7 @@ pub(super) async fn resolve_true_destination( Ok(RoomMessageEventContent::text_markdown(msg)) } -pub(super) async fn memory_stats(_body: Vec<&str>) -> Result<RoomMessageEventContent> { +pub(super) async fn memory_stats(_body: &[&str]) -> Result<RoomMessageEventContent> { let html_body = conduit::alloc::memory_stats(); if html_body.is_none() { @@ -707,7 +708,7 @@ pub(super) async fn memory_stats(_body: Vec<&str>) -> Result<RoomMessageEventCon } #[cfg(tokio_unstable)] -pub(super) async fn runtime_metrics(_body: Vec<&str>) -> Result<RoomMessageEventContent> { +pub(super) async fn runtime_metrics(_body: &[&str]) -> Result<RoomMessageEventContent> { let out = services().server.metrics.runtime_metrics().map_or_else( || "Runtime metrics are not available.".to_owned(), |metrics| format!("```rs\n{metrics:#?}\n```"), @@ -717,14 +718,14 @@ pub(super) async fn runtime_metrics(_body: Vec<&str>) -> Result<RoomMessageEvent } #[cfg(not(tokio_unstable))] -pub(super) async fn runtime_metrics(_body: Vec<&str>) -> Result<RoomMessageEventContent> { +pub(super) async fn runtime_metrics(_body: &[&str]) -> Result<RoomMessageEventContent> { Ok(RoomMessageEventContent::text_markdown( "Runtime metrics require building with `tokio_unstable`.", )) } #[cfg(tokio_unstable)] -pub(super) async fn runtime_interval(_body: Vec<&str>) -> Result<RoomMessageEventContent> { +pub(super) async fn runtime_interval(_body: &[&str]) -> Result<RoomMessageEventContent> { let out = 
services().server.metrics.runtime_interval().map_or_else( || "Runtime metrics are not available.".to_owned(), |metrics| format!("```rs\n{metrics:#?}\n```"), @@ -734,13 +735,13 @@ pub(super) async fn runtime_interval(_body: Vec<&str>) -> Result<RoomMessageEven } #[cfg(not(tokio_unstable))] -pub(super) async fn runtime_interval(_body: Vec<&str>) -> Result<RoomMessageEventContent> { +pub(super) async fn runtime_interval(_body: &[&str]) -> Result<RoomMessageEventContent> { Ok(RoomMessageEventContent::text_markdown( "Runtime metrics require building with `tokio_unstable`.", )) } -pub(super) async fn time(_body: Vec<&str>) -> Result<RoomMessageEventContent> { +pub(super) async fn time(_body: &[&str]) -> Result<RoomMessageEventContent> { let now = SystemTime::now(); Ok(RoomMessageEventContent::text_markdown(utils::time::format(now, "%+"))) } diff --git a/src/macros/admin.rs b/src/macros/admin.rs index e95e402ae..a1a5b22ae 100644 --- a/src/macros/admin.rs +++ b/src/macros/admin.rs @@ -32,7 +32,7 @@ fn dispatch_arm(v: &Variant) -> TokenStream2 { let field = fields.named.iter().filter_map(|f| f.ident.as_ref()); let arg = field.clone(); quote! { - #name { #( #field ),* } => Box::pin(#handler(body, #( #arg ),*)).await?, + #name { #( #field ),* } => Box::pin(#handler(&body, #( #arg ),*)).await?, } }, Fields::Unnamed(fields) => { @@ -43,7 +43,7 @@ fn dispatch_arm(v: &Variant) -> TokenStream2 { }, Fields::Unit => { quote! 
{ - #name => Box::pin(#handler(body)).await?, + #name => Box::pin(#handler(&body)).await?, } }, } -- GitLab From d7d874f88d4d00204883bf3f773d76d75d0d54f1 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Wed, 24 Jul 2024 01:06:58 +0000 Subject: [PATCH 31/47] start core info module; move version to info Signed-off-by: Jason Volk <jason@zemos.net> --- src/core/info/mod.rs | 4 ++++ src/core/{ => info}/version.rs | 13 +++++++------ src/core/mod.rs | 4 ++-- 3 files changed, 13 insertions(+), 8 deletions(-) create mode 100644 src/core/info/mod.rs rename src/core/{ => info}/version.rs (75%) diff --git a/src/core/info/mod.rs b/src/core/info/mod.rs new file mode 100644 index 000000000..42ec971e9 --- /dev/null +++ b/src/core/info/mod.rs @@ -0,0 +1,4 @@ +//! Information about the project. This module contains version, build, system, +//! etc information which can be queried by admins or used by developers. + +pub mod version; diff --git a/src/core/version.rs b/src/core/info/version.rs similarity index 75% rename from src/core/version.rs rename to src/core/info/version.rs index 2876cea87..fb71d4e17 100644 --- a/src/core/version.rs +++ b/src/core/info/version.rs @@ -1,9 +1,10 @@ -/// one true function for returning the conduwuit version with the necessary -/// CONDUWUIT_VERSION_EXTRA env variables used if specified -/// -/// Set the environment variable `CONDUWUIT_VERSION_EXTRA` to any UTF-8 string -/// to include it in parenthesis after the SemVer version. A common value are -/// git commit hashes. +//! one true function for returning the conduwuit version with the necessary +//! CONDUWUIT_VERSION_EXTRA env variables used if specified +//! +//! Set the environment variable `CONDUWUIT_VERSION_EXTRA` to any UTF-8 string +//! to include it in parenthesis after the SemVer version. A common value are +//! git commit hashes. 
+ use std::sync::OnceLock; static BRANDING: &str = "Conduwuit"; diff --git a/src/core/mod.rs b/src/core/mod.rs index 9716b46e8..749d7b080 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -2,19 +2,19 @@ pub mod config; pub mod debug; pub mod error; +pub mod info; pub mod log; pub mod metrics; pub mod mods; pub mod pdu; pub mod server; pub mod utils; -pub mod version; pub use config::Config; pub use error::Error; +pub use info::{version, version::version}; pub use pdu::{PduBuilder, PduCount, PduEvent}; pub use server::Server; -pub use version::version; pub type Result<T, E = Error> = std::result::Result<T, E>; -- GitLab From 7d487d53d81d8095c4fd9128f848e0bde47427b5 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Wed, 24 Jul 2024 03:55:01 +0000 Subject: [PATCH 32/47] de-cycle conduit_macros from conduit_core. Signed-off-by: Jason Volk <jason@zemos.net> --- Cargo.lock | 1 - src/macros/Cargo.toml | 1 - src/macros/admin.rs | 3 ++- src/macros/mod.rs | 1 + src/macros/utils.rs | 25 +++++++++++++++++++++++++ 5 files changed, 28 insertions(+), 3 deletions(-) create mode 100644 src/macros/utils.rs diff --git a/Cargo.lock b/Cargo.lock index 0649f3be1..71d79adb8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -717,7 +717,6 @@ dependencies = [ name = "conduit_macros" version = "0.4.5" dependencies = [ - "conduit_core", "proc-macro2", "quote", "syn 1.0.109", diff --git a/src/macros/Cargo.toml b/src/macros/Cargo.toml index b9a35aab7..ca98f169e 100644 --- a/src/macros/Cargo.toml +++ b/src/macros/Cargo.toml @@ -18,7 +18,6 @@ proc-macro = true syn.workspace = true quote.workspace = true proc-macro2.workspace = true -conduit-core.workspace = true [lints] workspace = true diff --git a/src/macros/admin.rs b/src/macros/admin.rs index a1a5b22ae..e1d294b91 100644 --- a/src/macros/admin.rs +++ b/src/macros/admin.rs @@ -1,9 +1,10 @@ -use conduit_core::utils::string::camel_to_snake_string; use proc_macro::{Span, TokenStream}; use proc_macro2::TokenStream as TokenStream2; use 
quote::quote; use syn::{parse_macro_input, AttributeArgs, Fields, Ident, ItemEnum, Variant}; +use crate::utils::camel_to_snake_string; + pub(super) fn command_dispatch(args: TokenStream, input_: TokenStream) -> TokenStream { let input = input_.clone(); let item = parse_macro_input!(input as ItemEnum); diff --git a/src/macros/mod.rs b/src/macros/mod.rs index 0aba7560e..6f286d66d 100644 --- a/src/macros/mod.rs +++ b/src/macros/mod.rs @@ -1,4 +1,5 @@ mod admin; +mod utils; use proc_macro::TokenStream; diff --git a/src/macros/utils.rs b/src/macros/utils.rs new file mode 100644 index 000000000..f512c56c2 --- /dev/null +++ b/src/macros/utils.rs @@ -0,0 +1,25 @@ +#[must_use] +pub(crate) fn camel_to_snake_string(s: &str) -> String { + let mut output = String::with_capacity( + s.chars() + .fold(s.len(), |a, ch| a.saturating_add(usize::from(ch.is_ascii_uppercase()))), + ); + + let mut state = false; + s.chars().for_each(|ch| { + let m = ch.is_ascii_uppercase(); + let s = exchange(&mut state, !m); + if m && s { + output.push('_'); + } + output.push(ch.to_ascii_lowercase()); + }); + + output +} + +pub(crate) fn exchange<T: Clone>(state: &mut T, source: T) -> T { + let ret = state.clone(); + *state = source; + ret +} -- GitLab From f01423164479345493c06c86121678113e27d58e Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Wed, 24 Jul 2024 09:04:16 +0000 Subject: [PATCH 33/47] add conf item to disable rocksdb compaction Signed-off-by: Jason Volk <jason@zemos.net> --- src/core/config/mod.rs | 3 +++ src/database/opts.rs | 1 + 2 files changed, 4 insertions(+) diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index 6cef402af..ae2735454 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -236,6 +236,8 @@ pub struct Config { pub rocksdb_compaction_prio_idle: bool, #[serde(default = "true_fn")] pub rocksdb_compaction_ioprio_idle: bool, + #[serde(default = "true_fn")] + pub rocksdb_compaction: bool, pub emergency_password: Option<String>, @@ 
-712,6 +714,7 @@ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { "RocksDB Compaction Idle IOPriority", &self.rocksdb_compaction_ioprio_idle.to_string(), ); + line("RocksDB Compaction enabled", &self.rocksdb_compaction.to_string()); line("Media integrity checks on startup", &self.media_startup_check.to_string()); line("Media compatibility filesystem links", &self.media_compat_file_link.to_string()); line("Prevent Media Downloads From", { diff --git a/src/database/opts.rs b/src/database/opts.rs index d22364548..3c3ad4242 100644 --- a/src/database/opts.rs +++ b/src/database/opts.rs @@ -68,6 +68,7 @@ pub(crate) fn db_options(config: &Config, env: &mut Env, row_cache: &Cache, col_ set_compression_defaults(&mut opts, config); // Misc + opts.set_disable_auto_compactions(!config.rocksdb_compaction); opts.create_if_missing(true); // Default: https://github.com/facebook/rocksdb/wiki/WAL-Recovery-Modes#ktoleratecorruptedtailrecords -- GitLab From 936d2915e2257be20213e3501e96be3df10c5f49 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Wed, 24 Jul 2024 09:10:01 +0000 Subject: [PATCH 34/47] add cargo manifest reflection Signed-off-by: Jason Volk <jason@zemos.net> --- Cargo.lock | 12 ++++++++ Cargo.toml | 4 +++ src/core/Cargo.toml | 2 ++ src/core/error/mod.rs | 2 ++ src/core/info/cargo.rs | 66 ++++++++++++++++++++++++++++++++++++++++++ src/core/info/mod.rs | 1 + src/macros/cargo.rs | 51 ++++++++++++++++++++++++++++++++ src/macros/mod.rs | 4 +++ 8 files changed, 142 insertions(+) create mode 100644 src/core/info/cargo.rs create mode 100644 src/macros/cargo.rs diff --git a/Cargo.lock b/Cargo.lock index 71d79adb8..d9af38d24 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -472,6 +472,16 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "cargo_toml" +version = "0.20.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad639525b1c67b6a298f378417b060fbc04618bea559482a8484381cce27d965" +dependencies = [ + "serde", + "toml", +] 
+ [[package]] name = "cc" version = "1.1.6" @@ -666,8 +676,10 @@ dependencies = [ "argon2", "axum 0.7.5", "bytes", + "cargo_toml", "checked_ops", "chrono", + "conduit_macros", "const-str", "either", "figment", diff --git a/Cargo.toml b/Cargo.toml index 5fcb03ef8..acaed7017 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,10 @@ name = "conduit" [workspace.dependencies.const-str] version = "0.5.7" +[workspace.dependencies.cargo_toml] +version = "0.20" +features = ["features"] + [workspace.dependencies.sanitize-filename] version = "0.5.0" diff --git a/src/core/Cargo.toml b/src/core/Cargo.toml index 453d7b13e..620aad02d 100644 --- a/src/core/Cargo.toml +++ b/src/core/Cargo.toml @@ -53,8 +53,10 @@ sha256_media = [] argon2.workspace = true axum.workspace = true bytes.workspace = true +cargo_toml.workspace = true checked_ops.workspace = true chrono.workspace = true +conduit-macros.workspace = true const-str.workspace = true either.workspace = true figment.workspace = true diff --git a/src/core/error/mod.rs b/src/core/error/mod.rs index 9439261e5..8664d740a 100644 --- a/src/core/error/mod.rs +++ b/src/core/error/mod.rs @@ -55,6 +55,8 @@ pub enum Error { Http(#[from] http::Error), #[error("{0}")] HttpHeader(#[from] http::header::InvalidHeaderValue), + #[error("{0}")] + CargoToml(#[from] cargo_toml::Error), // ruma #[error("{0}")] diff --git a/src/core/info/cargo.rs b/src/core/info/cargo.rs new file mode 100644 index 000000000..0d2db1add --- /dev/null +++ b/src/core/info/cargo.rs @@ -0,0 +1,66 @@ +//! Information about the build related to Cargo. This is a frontend interface +//! informed by proc-macros that capture raw information at build time which is +//! further processed at runtime either during static initialization or as +//! necessary. + +use std::sync::OnceLock; + +use cargo_toml::Manifest; +use conduit_macros::cargo_manifest; + +use crate::Result; + +// Raw captures of the cargo manifest for each crate. 
This is provided by a +// proc-macro at build time since the source directory and the cargo toml's may +// not be present during execution. + +#[cargo_manifest] +const WORKSPACE_MANIFEST: &'static str = (); +#[cargo_manifest("macros")] +const MACROS_MANIFEST: &'static str = (); +#[cargo_manifest("core")] +const CORE_MANIFEST: &'static str = (); +#[cargo_manifest("database")] +const DATABASE_MANIFEST: &'static str = (); +#[cargo_manifest("service")] +const SERVICE_MANIFEST: &'static str = (); +#[cargo_manifest("admin")] +const ADMIN_MANIFEST: &'static str = (); +#[cargo_manifest("router")] +const ROUTER_MANIFEST: &'static str = (); +#[cargo_manifest("main")] +const MAIN_MANIFEST: &'static str = (); + +/// Processed list of features access all project crates. This is generated from +/// the data in the MANIFEST strings and contains all possible project features. +/// For *enabled* features see the info::rustc module instead. +static FEATURES: OnceLock<Vec<String>> = OnceLock::new(); + +/// List of all possible features for the project. For *enabled* features in +/// this build see the companion function in info::rustc. 
+pub fn features() -> &'static Vec<String> { + FEATURES.get_or_init(|| init_features().unwrap_or_else(|e| panic!("Failed initialize features: {e}"))) +} + +fn init_features() -> Result<Vec<String>> { + let mut features = Vec::new(); + append_features(&mut features, WORKSPACE_MANIFEST)?; + append_features(&mut features, MACROS_MANIFEST)?; + append_features(&mut features, CORE_MANIFEST)?; + append_features(&mut features, DATABASE_MANIFEST)?; + append_features(&mut features, SERVICE_MANIFEST)?; + append_features(&mut features, ADMIN_MANIFEST)?; + append_features(&mut features, ROUTER_MANIFEST)?; + append_features(&mut features, MAIN_MANIFEST)?; + features.sort(); + features.dedup(); + + Ok(features) +} + +fn append_features(features: &mut Vec<String>, manifest: &str) -> Result<()> { + let manifest = Manifest::from_str(manifest)?; + features.extend(manifest.features.keys().cloned()); + + Ok(()) +} diff --git a/src/core/info/mod.rs b/src/core/info/mod.rs index 42ec971e9..7749bbdc8 100644 --- a/src/core/info/mod.rs +++ b/src/core/info/mod.rs @@ -1,4 +1,5 @@ //! Information about the project. This module contains version, build, system, //! etc information which can be queried by admins or used by developers. 
+pub mod cargo; pub mod version; diff --git a/src/macros/cargo.rs b/src/macros/cargo.rs new file mode 100644 index 000000000..17132a6ca --- /dev/null +++ b/src/macros/cargo.rs @@ -0,0 +1,51 @@ +use std::{fs::read_to_string, path::PathBuf}; + +use proc_macro::TokenStream; +use quote::quote; +use syn::{parse_macro_input, AttributeArgs, ItemConst, Lit, NestedMeta}; + +pub(super) fn manifest(args: TokenStream, item: TokenStream) -> TokenStream { + let item = parse_macro_input!(item as ItemConst); + let args = parse_macro_input!(args as AttributeArgs); + let member = args.into_iter().find_map(|arg| { + let NestedMeta::Lit(arg) = arg else { + return None; + }; + let Lit::Str(arg) = arg else { + return None; + }; + Some(arg.value()) + }); + + let path = manifest_path(member.as_deref()); + let manifest = read_to_string(&path).unwrap_or_default(); + + let name = item.ident; + let val = manifest.as_str(); + let ret = quote! { + const #name: &'static str = #val; + }; + + ret.into() +} + +#[allow(clippy::option_env_unwrap)] +fn manifest_path(member: Option<&str>) -> PathBuf { + let mut path: PathBuf = option_env!("CARGO_MANIFEST_DIR") + .expect("missing CARGO_MANIFEST_DIR in environment") + .into(); + + // conduwuit/src/macros/ -> conduwuit/src/ + path.pop(); + + if let Some(member) = member { + // conduwuit/$member/Cargo.toml + path.push(member); + } else { + // conduwuit/src/ -> conduwuit/ + path.pop(); + } + + path.push("Cargo.toml"); + path +} diff --git a/src/macros/mod.rs b/src/macros/mod.rs index 6f286d66d..718583a4d 100644 --- a/src/macros/mod.rs +++ b/src/macros/mod.rs @@ -1,4 +1,5 @@ mod admin; +mod cargo; mod utils; use proc_macro::TokenStream; @@ -7,3 +8,6 @@ pub fn admin_command_dispatch(args: TokenStream, input: TokenStream) -> TokenStream { admin::command_dispatch(args, input) } + +#[proc_macro_attribute] +pub fn cargo_manifest(args: TokenStream, input: TokenStream) -> TokenStream { cargo::manifest(args, input) } -- GitLab From 
2100618d47ae71ba0d272c65a0b5ddd5b3d7baf2 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Wed, 24 Jul 2024 23:01:00 +0000 Subject: [PATCH 35/47] add rustc build flags reflection Signed-off-by: Jason Volk <jason@zemos.net> --- Cargo.lock | 11 +++++++++ Cargo.toml | 3 +++ src/admin/mod.rs | 7 +++--- src/core/Cargo.toml | 1 + src/core/info/mod.rs | 3 +++ src/core/info/rustc.rs | 53 ++++++++++++++++++++++++++++++++++++++++++ src/core/mod.rs | 7 +++++- src/core/utils/mod.rs | 1 + src/database/mod.rs | 1 + src/macros/mod.rs | 4 ++++ src/macros/rustc.rs | 27 +++++++++++++++++++++ src/main/main.rs | 4 +++- src/router/mod.rs | 1 + src/service/mod.rs | 1 + 14 files changed, 119 insertions(+), 5 deletions(-) create mode 100644 src/core/info/rustc.rs create mode 100644 src/macros/rustc.rs diff --git a/Cargo.lock b/Cargo.lock index d9af38d24..52ea07c9e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -681,6 +681,7 @@ dependencies = [ "chrono", "conduit_macros", "const-str", + "ctor", "either", "figment", "hardened_malloc-rs", @@ -1018,6 +1019,16 @@ dependencies = [ "typenum", ] +[[package]] +name = "ctor" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edb49164822f3ee45b17acd4a208cfc1251410cf0cad9a833234c9890774dd9f" +dependencies = [ + "quote", + "syn 2.0.71", +] + [[package]] name = "curve25519-dalek" version = "4.1.3" diff --git a/Cargo.toml b/Cargo.toml index acaed7017..b65ba7ad3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,9 @@ name = "conduit" [workspace.dependencies.const-str] version = "0.5.7" +[workspace.dependencies.ctor] +version = "0.2.8" + [workspace.dependencies.cargo_toml] version = "0.20" features = ["features"] diff --git a/src/admin/mod.rs b/src/admin/mod.rs index 7d752ff86..cd1110ee3 100644 --- a/src/admin/mod.rs +++ b/src/admin/mod.rs @@ -21,15 +21,16 @@ extern crate conduit_core as conduit; extern crate conduit_service as service; -pub(crate) use conduit::{mod_ctor, mod_dtor, Result}; 
+pub(crate) use conduit::Result; pub(crate) use service::services; pub(crate) use crate::utils::{escape_html, get_room_info}; pub(crate) const PAGE_SIZE: usize = 100; -mod_ctor! {} -mod_dtor! {} +conduit::mod_ctor! {} +conduit::mod_dtor! {} +conduit::rustc_flags_capture! {} /// Install the admin command handler pub async fn init() { diff --git a/src/core/Cargo.toml b/src/core/Cargo.toml index 620aad02d..ea772a839 100644 --- a/src/core/Cargo.toml +++ b/src/core/Cargo.toml @@ -58,6 +58,7 @@ checked_ops.workspace = true chrono.workspace = true conduit-macros.workspace = true const-str.workspace = true +ctor.workspace = true either.workspace = true figment.workspace = true http-body-util.workspace = true diff --git a/src/core/info/mod.rs b/src/core/info/mod.rs index 7749bbdc8..e11a60219 100644 --- a/src/core/info/mod.rs +++ b/src/core/info/mod.rs @@ -2,4 +2,7 @@ //! etc information which can be queried by admins or used by developers. pub mod cargo; +pub mod rustc; pub mod version; + +pub use conduit_macros::rustc_flags_capture; diff --git a/src/core/info/rustc.rs b/src/core/info/rustc.rs new file mode 100644 index 000000000..048c0cd59 --- /dev/null +++ b/src/core/info/rustc.rs @@ -0,0 +1,53 @@ +//! Information about the build related to rustc. This is a frontend interface +//! informed by proc-macros at build time. Since the project is split into +//! several crates, lower-level information is supplied from each crate during +//! static initialization. + +use std::{ + collections::BTreeMap, + sync::{Mutex, OnceLock}, +}; + +use crate::utils::exchange; + +/// Raw capture of rustc flags used to build each crate in the project. Informed +/// by rustc_flags_capture macro (one in each crate's mod.rs). This is +/// done during static initialization which is why it's mutex-protected and pub. +/// Should not be written to by anything other than our macro. 
+pub static FLAGS: Mutex<BTreeMap<&str, &[&str]>> = Mutex::new(BTreeMap::new()); + +/// Processed list of enabled features across all project crates. This is +/// generated from the data in FLAGS. +static FEATURES: OnceLock<Vec<&'static str>> = OnceLock::new(); + +/// List of features enabled for the project. +pub fn features() -> &'static Vec<&'static str> { FEATURES.get_or_init(init_features) } + +fn init_features() -> Vec<&'static str> { + let mut features = Vec::new(); + FLAGS + .lock() + .expect("locked") + .iter() + .for_each(|(_, flags)| append_features(&mut features, flags)); + + features.sort_unstable(); + features.dedup(); + features +} + +fn append_features(features: &mut Vec<&'static str>, flags: &[&'static str]) { + let mut next_is_cfg = false; + for flag in flags { + let is_cfg = *flag == "--cfg"; + let is_feature = flag.starts_with("feature="); + if exchange(&mut next_is_cfg, is_cfg) && is_feature { + if let Some(feature) = flag + .split_once('=') + .map(|(_, feature)| feature.trim_matches('"')) + { + features.push(feature); + } + } + } +} diff --git a/src/core/mod.rs b/src/core/mod.rs index 749d7b080..b302fdcc6 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -12,12 +12,17 @@ pub use config::Config; pub use error::Error; -pub use info::{version, version::version}; +pub use info::{rustc_flags_capture, version, version::version}; pub use pdu::{PduBuilder, PduCount, PduEvent}; pub use server::Server; +pub use utils::{ctor, dtor}; + +pub use crate as conduit_core; pub type Result<T, E = Error> = std::result::Result<T, E>; +rustc_flags_capture! 
{} + #[cfg(not(conduit_mods))] pub mod mods { #[macro_export] diff --git a/src/core/utils/mod.rs b/src/core/utils/mod.rs index d7b6a72a1..767b65a93 100644 --- a/src/core/utils/mod.rs +++ b/src/core/utils/mod.rs @@ -15,6 +15,7 @@ use std::cmp::{self, Ordering}; +pub use ::ctor::{ctor, dtor}; pub use bytes::{increment, u64_from_bytes, u64_from_u8, u64_from_u8x8}; pub use debug::slice_truncated as debug_slice_truncated; pub use hash::calculate_hash; diff --git a/src/database/mod.rs b/src/database/mod.rs index 283224f66..6446624ca 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -23,3 +23,4 @@ conduit::mod_ctor! {} conduit::mod_dtor! {} +conduit::rustc_flags_capture! {} diff --git a/src/macros/mod.rs b/src/macros/mod.rs index 718583a4d..94ea781e3 100644 --- a/src/macros/mod.rs +++ b/src/macros/mod.rs @@ -1,5 +1,6 @@ mod admin; mod cargo; +mod rustc; mod utils; use proc_macro::TokenStream; @@ -11,3 +12,6 @@ pub fn admin_command_dispatch(args: TokenStream, input: TokenStream) -> TokenStr #[proc_macro_attribute] pub fn cargo_manifest(args: TokenStream, input: TokenStream) -> TokenStream { cargo::manifest(args, input) } + +#[proc_macro] +pub fn rustc_flags_capture(args: TokenStream) -> TokenStream { rustc::flags_capture(args) } diff --git a/src/macros/rustc.rs b/src/macros/rustc.rs new file mode 100644 index 000000000..a3bab26f3 --- /dev/null +++ b/src/macros/rustc.rs @@ -0,0 +1,27 @@ +use proc_macro::TokenStream; +use quote::quote; + +pub(super) fn flags_capture(args: TokenStream) -> TokenStream { + let cargo_crate_name = std::env::var("CARGO_CRATE_NAME"); + let crate_name = match cargo_crate_name.as_ref() { + Err(_) => return args, + Ok(crate_name) => crate_name.trim_start_matches("conduit_"), + }; + + let flag = std::env::args().collect::<Vec<_>>(); + let ret = quote! 
{ + #[conduit_core::ctor] + fn _set_rustc_flags() { + let flags = &[#( #flag ),*]; + conduit_core::info::rustc::FLAGS.lock().expect("locked").insert(#crate_name, flags); + } + + // static strings have to be yanked on module unload + #[conduit_core::dtor] + fn _unset_rustc_flags() { + conduit_core::info::rustc::FLAGS.lock().expect("locked").remove(#crate_name); + } + }; + + ret.into() +} diff --git a/src/main/main.rs b/src/main/main.rs index b13e117db..b8cb24ff1 100644 --- a/src/main/main.rs +++ b/src/main/main.rs @@ -14,7 +14,7 @@ time::Duration, }; -use conduit::{debug_info, error, utils::available_parallelism, Error, Result}; +use conduit::{debug_info, error, rustc_flags_capture, utils::available_parallelism, Error, Result}; use server::Server; use tokio::runtime; @@ -22,6 +22,8 @@ const WORKER_MIN: usize = 2; const WORKER_KEEPALIVE: u64 = 36; +rustc_flags_capture! {} + fn main() -> Result<(), Error> { let args = clap::parse(); let runtime = runtime::Builder::new_multi_thread() diff --git a/src/router/mod.rs b/src/router/mod.rs index 03c70f6d7..13fe39087 100644 --- a/src/router/mod.rs +++ b/src/router/mod.rs @@ -14,6 +14,7 @@ conduit::mod_ctor! {} conduit::mod_dtor! {} +conduit::rustc_flags_capture! {} #[no_mangle] pub extern "Rust" fn start(server: &Arc<Server>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> { diff --git a/src/service/mod.rs b/src/service/mod.rs index ce106809c..b6ec58b5c 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -37,6 +37,7 @@ conduit::mod_ctor! {} conduit::mod_dtor! {} +conduit::rustc_flags_capture! 
{} static SERVICES: RwLock<Option<&Services>> = RwLock::new(None); -- GitLab From 8bb69eb81ddf1fb90a47f02ee078da943438352c Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Thu, 25 Jul 2024 01:41:31 +0000 Subject: [PATCH 36/47] add simple ast dimension diagnostic Signed-off-by: Jason Volk <jason@zemos.net> --- src/core/debug.rs | 3 +++ src/macros/debug.rs | 28 ++++++++++++++++++++++++++++ src/macros/mod.rs | 4 ++++ 3 files changed, 35 insertions(+) create mode 100644 src/macros/debug.rs diff --git a/src/core/debug.rs b/src/core/debug.rs index 5e52f399a..63f25d472 100644 --- a/src/core/debug.rs +++ b/src/core/debug.rs @@ -1,5 +1,8 @@ use std::{any::Any, panic}; +/// Export debug proc_macros +pub use conduit_macros::recursion_depth; + /// Export all of the ancillary tools from here as well. pub use crate::utils::debug::*; diff --git a/src/macros/debug.rs b/src/macros/debug.rs new file mode 100644 index 000000000..5251fa179 --- /dev/null +++ b/src/macros/debug.rs @@ -0,0 +1,28 @@ +use std::cmp; + +use proc_macro::TokenStream; +use syn::{parse_macro_input, AttributeArgs, Item}; + +pub(super) fn recursion_depth(args: TokenStream, item_: TokenStream) -> TokenStream { + let item = item_.clone(); + let item = parse_macro_input!(item as Item); + let _args = parse_macro_input!(args as AttributeArgs); + + let mut best: usize = 0; + let mut count: usize = 0; + // think you'd find a fancy recursive ast visitor? 
think again + let tree = format!("{item:#?}"); + for line in tree.lines() { + let trim = line.trim_start_matches(' '); + let diff = line.len().saturating_sub(trim.len()); + let level = diff / 4; + best = cmp::max(level, best); + count = count.saturating_add(1); + } + + println!("--- Recursion Diagnostic ---"); + println!("DEPTH: {best}"); + println!("LENGTH: {count}"); + + item_ +} diff --git a/src/macros/mod.rs b/src/macros/mod.rs index 94ea781e3..b01e5275a 100644 --- a/src/macros/mod.rs +++ b/src/macros/mod.rs @@ -1,5 +1,6 @@ mod admin; mod cargo; +mod debug; mod rustc; mod utils; @@ -13,5 +14,8 @@ pub fn admin_command_dispatch(args: TokenStream, input: TokenStream) -> TokenStr #[proc_macro_attribute] pub fn cargo_manifest(args: TokenStream, input: TokenStream) -> TokenStream { cargo::manifest(args, input) } +#[proc_macro_attribute] +pub fn recursion_depth(args: TokenStream, input: TokenStream) -> TokenStream { debug::recursion_depth(args, input) } + #[proc_macro] pub fn rustc_flags_capture(args: TokenStream) -> TokenStream { rustc::flags_capture(args) } -- GitLab From 4e975887cf2e2e7f81d28789fca2fc4a2fb32657 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Wed, 24 Jul 2024 23:53:48 +0000 Subject: [PATCH 37/47] add command to list features Signed-off-by: Jason Volk <jason@zemos.net> --- src/admin/server/commands.rs | 45 +++++++++++++++++++++++++++++++++++- src/admin/server/mod.rs | 17 ++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/src/admin/server/commands.rs b/src/admin/server/commands.rs index e45037365..c30f8f871 100644 --- a/src/admin/server/commands.rs +++ b/src/admin/server/commands.rs @@ -1,4 +1,6 @@ -use conduit::{utils::time, warn, Err, Result}; +use std::fmt::Write; + +use conduit::{info, utils::time, warn, Err, Result}; use ruma::events::room::message::RoomMessageEventContent; use crate::services; @@ -19,6 +21,47 @@ pub(super) async fn show_config(_body: Vec<&str>) -> Result<RoomMessageEventCont 
Ok(RoomMessageEventContent::text_plain(format!("{}", services().globals.config))) } +pub(super) async fn list_features( + _body: Vec<&str>, available: bool, enabled: bool, comma: bool, +) -> Result<RoomMessageEventContent> { + let delim = if comma { + "," + } else { + " " + }; + if enabled && !available { + let features = info::rustc::features().join(delim); + let out = format!("```\n{features}\n```"); + return Ok(RoomMessageEventContent::text_markdown(out)); + } + + if available && !enabled { + let features = info::cargo::features().join(delim); + let out = format!("```\n{features}\n```"); + return Ok(RoomMessageEventContent::text_markdown(out)); + } + + let mut features = String::new(); + let enabled = info::rustc::features(); + let available = info::cargo::features(); + for feature in available { + let active = enabled.contains(&feature.as_str()); + let emoji = if active { + "✅" + } else { + "âŒ" + }; + let remark = if active { + "[enabled]" + } else { + "" + }; + writeln!(features, "{emoji} {feature} {remark}")?; + } + + Ok(RoomMessageEventContent::text_markdown(features)) +} + pub(super) async fn memory_usage(_body: Vec<&str>) -> Result<RoomMessageEventContent> { let services_usage = services().memory_usage().await?; let database_usage = services().db.db.memory_usage()?; diff --git a/src/admin/server/mod.rs b/src/admin/server/mod.rs index 604dacc7f..958cc54ba 100644 --- a/src/admin/server/mod.rs +++ b/src/admin/server/mod.rs @@ -14,6 +14,18 @@ pub(super) enum ServerCommand { /// - Show configuration values ShowConfig, + /// - List the features built into the server + ListFeatures { + #[arg(short, long)] + available: bool, + + #[arg(short, long)] + enabled: bool, + + #[arg(short, long)] + comma: bool, + }, + /// - Print database memory usage statistics MemoryUsage, @@ -54,6 +66,11 @@ pub(super) async fn process(command: ServerCommand, body: Vec<&str>) -> Result<R Ok(match command { ServerCommand::Uptime => uptime(body).await?, ServerCommand::ShowConfig => 
show_config(body).await?, + ServerCommand::ListFeatures { + available, + enabled, + comma, + } => list_features(body, available, enabled, comma).await?, ServerCommand::MemoryUsage => memory_usage(body).await?, ServerCommand::ClearCaches => clear_caches(body).await?, ServerCommand::ListBackups => list_backups(body).await?, -- GitLab From c423a8365692c6a02780d3f4224c4a0fdcd16e0e Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Thu, 25 Jul 2024 02:59:54 +0000 Subject: [PATCH 38/47] add cli override for any configuration item Signed-off-by: Jason Volk <jason@zemos.net> --- Cargo.lock | 1 + Cargo.toml | 5 +++++ src/core/Cargo.toml | 1 + src/core/config/mod.rs | 17 +++++++++------- src/core/error/mod.rs | 6 ++++++ src/core/mod.rs | 1 + src/main/clap.rs | 46 +++++++++++++++++++++++++++++++++++++----- src/main/server.rs | 5 +++-- 8 files changed, 68 insertions(+), 14 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 52ea07c9e..c73ef931b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -708,6 +708,7 @@ dependencies = [ "tikv-jemallocator", "tokio", "tokio-metrics", + "toml", "tracing", "tracing-core", "tracing-subscriber", diff --git a/Cargo.toml b/Cargo.toml index b65ba7ad3..66ba73855 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,6 +35,11 @@ version = "0.2.8" version = "0.20" features = ["features"] +[workspace.dependencies.toml] +version = "0.8.14" +default-features = false +features = ["parse"] + [workspace.dependencies.sanitize-filename] version = "0.5.0" diff --git a/src/core/Cargo.toml b/src/core/Cargo.toml index ea772a839..f915c43ea 100644 --- a/src/core/Cargo.toml +++ b/src/core/Cargo.toml @@ -86,6 +86,7 @@ tikv-jemalloc-sys.optional = true tikv-jemalloc-sys.workspace = true tokio.workspace = true tokio-metrics.workspace = true +toml.workspace = true tracing-core.workspace = true tracing-subscriber.workspace = true tracing.workspace = true diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index ae2735454..a9443fd2d 100644 --- 
a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -9,10 +9,8 @@ Either, Either::{Left, Right}, }; -use figment::{ - providers::{Env, Format, Toml}, - Figment, -}; +use figment::providers::{Env, Format, Toml}; +pub use figment::{value::Value as FigmentValue, Figment}; use itertools::Itertools; use regex::RegexSet; use ruma::{ @@ -408,8 +406,8 @@ struct ListeningAddr { ]; impl Config { - /// Initialize config - pub fn new(path: &Option<PathBuf>) -> Result<Self> { + /// Pre-initialize config + pub fn load(path: &Option<PathBuf>) -> Result<Figment> { let raw_config = if let Some(config_file_env) = Env::var("CONDUIT_CONFIG") { Figment::new() .merge(Toml::file(config_file_env).nested()) @@ -431,13 +429,18 @@ pub fn new(path: &Option<PathBuf>) -> Result<Self> { .merge(Env::prefixed("CONDUWUIT_").global().split("__")) }; + Ok(raw_config) + } + + /// Finalize config + pub fn new(raw_config: &Figment) -> Result<Self> { let config = match raw_config.extract::<Self>() { Err(e) => return Err!("There was a problem with your configuration file: {e}"), Ok(config) => config, }; // don't start if we're listening on both UNIX sockets and TCP at same time - check::is_dual_listening(&raw_config)?; + check::is_dual_listening(raw_config)?; Ok(config) } diff --git a/src/core/error/mod.rs b/src/core/error/mod.rs index 8664d740a..5406d5ff3 100644 --- a/src/core/error/mod.rs +++ b/src/core/error/mod.rs @@ -57,6 +57,12 @@ pub enum Error { HttpHeader(#[from] http::header::InvalidHeaderValue), #[error("{0}")] CargoToml(#[from] cargo_toml::Error), + #[error("{0}")] + FigmentError(#[from] figment::error::Error), + #[error("{0}")] + TomlSerError(#[from] toml::ser::Error), + #[error("{0}")] + TomlDeError(#[from] toml::de::Error), // ruma #[error("{0}")] diff --git a/src/core/mod.rs b/src/core/mod.rs index b302fdcc6..5ed5ea159 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -10,6 +10,7 @@ pub mod server; pub mod utils; +pub use ::toml; pub use config::Config; pub use error::Error; pub 
use info::{rustc_flags_capture, version, version::version}; diff --git a/src/main/clap.rs b/src/main/clap.rs index 16d88978e..92bd73c16 100644 --- a/src/main/clap.rs +++ b/src/main/clap.rs @@ -3,16 +3,24 @@ use std::path::PathBuf; use clap::Parser; -use conduit::{Config, Result}; +use conduit::{ + config::{Figment, FigmentValue}, + err, toml, Err, Result, +}; /// Commandline arguments #[derive(Parser, Debug)] #[clap(version = conduit::version(), about, long_about = None)] pub(crate) struct Args { #[arg(short, long)] - /// Optional argument to the path of a conduwuit config TOML file + /// Path to the config TOML file (optional) pub(crate) config: Option<PathBuf>, + /// Override a configuration variable using TOML 'key=value' syntax + #[arg(long, short('O'))] + pub(crate) option: Vec<String>, + + #[cfg(feature = "console")] /// Activate admin command console automatically after startup. #[arg(long, num_args(0))] pub(crate) console: bool, @@ -23,10 +31,38 @@ pub(crate) struct Args { pub(super) fn parse() -> Args { Args::parse() } /// Synthesize any command line options with configuration file options. -pub(crate) fn update(config: &mut Config, args: &Args) -> Result<()> { +pub(crate) fn update(mut config: Figment, args: &Args) -> Result<Figment> { + #[cfg(feature = "console")] // Indicate the admin console should be spawned automatically if the // configuration file hasn't already. - config.admin_console_automatic |= args.console; + if args.console { + config = config.join(("admin_console_automatic", true)); + } + + // All other individual overrides can go last in case we have options which + // set multiple conf items at once and the user still needs granular overrides. 
+ for option in &args.option { + let (key, val) = option + .split_once('=') + .ok_or_else(|| err!("Missing '=' in -O/--option: {option:?}"))?; + + if key.is_empty() { + return Err!("Missing key= in -O/--option: {option:?}"); + } + + if val.is_empty() { + return Err!("Missing =val in -O/--option: {option:?}"); + } + + // The value has to pass for what would appear as a line in the TOML file. + let val = toml::from_str::<FigmentValue>(option)?; + let FigmentValue::Dict(_, val) = val else { + panic!("Unexpected Figment Value: {val:#?}"); + }; + + // Figment::merge() overrides existing + config = config.merge((key, val[key].clone())); + } - Ok(()) + Ok(config) } diff --git a/src/main/server.rs b/src/main/server.rs index b1cd6936e..71cdadce4 100644 --- a/src/main/server.rs +++ b/src/main/server.rs @@ -22,8 +22,9 @@ pub(crate) struct Server { impl Server { pub(crate) fn build(args: &Args, runtime: Option<&runtime::Handle>) -> Result<Arc<Self>, Error> { - let mut config = Config::new(&args.config)?; - crate::clap::update(&mut config, args)?; + let raw_config = Config::load(&args.config)?; + let raw_config = crate::clap::update(raw_config, args)?; + let config = Config::new(&raw_config)?; #[cfg(feature = "sentry_telemetry")] let sentry_guard = crate::sentry::init(&config); -- GitLab From 271959ee271229e21594c7ddb2d2f52165ad8955 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Thu, 25 Jul 2024 21:29:37 +0000 Subject: [PATCH 39/47] add debug list-dependencies admin command Signed-off-by: Jason Volk <jason@zemos.net> --- src/admin/debug/commands.rs | 24 ++++++++++++++++++++++++ src/admin/debug/mod.rs | 6 ++++++ src/core/info/cargo.rs | 24 +++++++++++++++++++++++- 3 files changed, 53 insertions(+), 1 deletion(-) diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index 32fba4095..307495d70 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -745,3 +745,27 @@ pub(super) async fn time(_body: &[&str]) -> 
Result<RoomMessageEventContent> { let now = SystemTime::now(); Ok(RoomMessageEventContent::text_markdown(utils::time::format(now, "%+"))) } + +pub(super) async fn list_dependencies(_body: &[&str], names: bool) -> Result<RoomMessageEventContent> { + if names { + let out = info::cargo::dependencies_names().join(" "); + return Ok(RoomMessageEventContent::notice_markdown(out)); + } + + let deps = info::cargo::dependencies(); + let mut out = String::new(); + writeln!(out, "| name | version | features |")?; + writeln!(out, "| ---- | ------- | -------- |")?; + for (name, dep) in deps { + let version = dep.try_req().unwrap_or("*"); + let feats = dep.req_features(); + let feats = if !feats.is_empty() { + feats.join(" ") + } else { + String::new() + }; + writeln!(out, "{name} | {version} | {feats}")?; + } + + Ok(RoomMessageEventContent::notice_markdown(out)) +} diff --git a/src/admin/debug/mod.rs b/src/admin/debug/mod.rs index 77bc36b8e..babd6047d 100644 --- a/src/admin/debug/mod.rs +++ b/src/admin/debug/mod.rs @@ -178,6 +178,12 @@ pub(super) enum DebugCommand { /// - Print the current time Time, + /// - List dependencies + ListDependencies { + #[arg(short, long)] + names: bool, + }, + /// - Developer test stubs #[command(subcommand)] #[allow(non_snake_case)] diff --git a/src/core/info/cargo.rs b/src/core/info/cargo.rs index 0d2db1add..544bbb8fd 100644 --- a/src/core/info/cargo.rs +++ b/src/core/info/cargo.rs @@ -5,7 +5,7 @@ use std::sync::OnceLock; -use cargo_toml::Manifest; +use cargo_toml::{DepsSet, Manifest}; use conduit_macros::cargo_manifest; use crate::Result; @@ -36,6 +36,18 @@ /// For *enabled* features see the info::rustc module instead. static FEATURES: OnceLock<Vec<String>> = OnceLock::new(); +/// Processed list of dependencies. This is generated from the datas captured in +/// the MANIFEST. 
+static DEPENDENCIES: OnceLock<DepsSet> = OnceLock::new(); + +#[must_use] +pub fn dependencies_names() -> Vec<&'static str> { dependencies().keys().map(String::as_str).collect() } + +pub fn dependencies() -> &'static DepsSet { + DEPENDENCIES + .get_or_init(|| init_dependencies().unwrap_or_else(|e| panic!("Failed to initialize dependencies: {e}"))) +} + /// List of all possible features for the project. For *enabled* features in /// this build see the companion function in info::rustc. pub fn features() -> &'static Vec<String> { @@ -64,3 +76,13 @@ fn append_features(features: &mut Vec<String>, manifest: &str) -> Result<()> { Ok(()) } + +fn init_dependencies() -> Result<DepsSet> { + let manifest = Manifest::from_str(WORKSPACE_MANIFEST)?; + Ok(manifest + .workspace + .as_ref() + .expect("manifest has workspace section") + .dependencies + .clone()) +} -- GitLab From 68f42baf736258b90924cce1aee2511765d1abc2 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Thu, 25 Jul 2024 22:13:22 +0000 Subject: [PATCH 40/47] rename admin Command to CommandInput Signed-off-by: Jason Volk <jason@zemos.net> --- src/admin/handler.rs | 10 +++++----- src/service/admin/mod.rs | 18 +++++++++--------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/admin/handler.rs b/src/admin/handler.rs index 6acb19bfd..26a5ea419 100644 --- a/src/admin/handler.rs +++ b/src/admin/handler.rs @@ -11,7 +11,7 @@ OwnedEventId, }; use service::{ - admin::{Command, CommandOutput, CommandResult, HandlerResult}, + admin::{CommandInput, CommandOutput, CommandResult, HandlerResult}, Services, }; @@ -30,10 +30,10 @@ pub(super) fn complete(line: &str) -> String { } #[must_use] -pub(super) fn handle(command: Command) -> HandlerResult { Box::pin(handle_command(command)) } +pub(super) fn handle(command: CommandInput) -> HandlerResult { Box::pin(handle_command(command)) } #[tracing::instrument(skip_all, name = "admin")] -async fn handle_command(command: Command) -> CommandResult { +async fn 
handle_command(command: CommandInput) -> CommandResult { AssertUnwindSafe(Box::pin(process_command(&command))) .catch_unwind() .await @@ -41,7 +41,7 @@ async fn handle_command(command: Command) -> CommandResult { .or_else(|error| handle_panic(&error, command)) } -async fn process_command(command: &Command) -> CommandOutput { +async fn process_command(command: &CommandInput) -> CommandOutput { Handler { services: service::services(), } @@ -50,7 +50,7 @@ async fn process_command(command: &Command) -> CommandOutput { .and_then(|content| reply(content, command.reply_id.clone())) } -fn handle_panic(error: &Error, command: Command) -> CommandResult { +fn handle_panic(error: &Error, command: CommandInput) -> CommandResult { let link = "Please submit a [bug report](https://github.com/girlbossceo/conduwuit/issues/new). 🥺"; let msg = format!("Panic occurred while processing command:\n```\n{error:#?}\n```\n{link}"); let content = RoomMessageEventContent::notice_markdown(msg); diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index b3879deb7..6241c6684 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -26,8 +26,8 @@ pub struct Service { services: Services, - sender: Sender<Command>, - receiver: Mutex<Receiver<Command>>, + sender: Sender<CommandInput>, + receiver: Mutex<Receiver<CommandInput>>, pub handle: RwLock<Option<Handler>>, pub complete: StdRwLock<Option<Completer>>, #[cfg(feature = "console")] @@ -44,13 +44,13 @@ struct Services { } #[derive(Debug)] -pub struct Command { +pub struct CommandInput { pub command: String, pub reply_id: Option<OwnedEventId>, } pub type Completer = fn(&str) -> String; -pub type Handler = fn(Command) -> HandlerResult; +pub type Handler = fn(CommandInput) -> HandlerResult; pub type HandlerResult = Pin<Box<dyn Future<Output = CommandResult> + Send>>; pub type CommandResult = Result<CommandOutput, Error>; pub type CommandOutput = Option<RoomMessageEventContent>; @@ -129,7 +129,7 @@ pub async fn 
send_message(&self, message_content: RoomMessageEventContent) { } pub async fn command(&self, command: String, reply_id: Option<OwnedEventId>) { - self.send(Command { + self.send(CommandInput { command, reply_id, }) @@ -139,7 +139,7 @@ pub async fn command(&self, command: String, reply_id: Option<OwnedEventId>) { pub async fn command_in_place( &self, command: String, reply_id: Option<OwnedEventId>, ) -> Result<Option<RoomMessageEventContent>> { - self.process_command(Command { + self.process_command(CommandInput { command, reply_id, }) @@ -153,7 +153,7 @@ pub fn complete_command(&self, command: &str) -> Option<String> { .map(|complete| complete(command)) } - async fn send(&self, message: Command) { + async fn send(&self, message: CommandInput) { debug_assert!(!self.sender.is_closed(), "channel closed"); self.sender.send_async(message).await.expect("message sent"); } @@ -163,7 +163,7 @@ async fn handle_signal(&self, #[allow(unused_variables)] sig: &'static str) { self.console.handle_signal(sig).await; } - async fn handle_command(&self, command: Command) { + async fn handle_command(&self, command: CommandInput) { match self.process_command(command).await { Ok(Some(output)) => self.handle_response(output).await, Ok(None) => debug!("Command successful with no response"), @@ -171,7 +171,7 @@ async fn handle_command(&self, command: Command) { } } - async fn process_command(&self, command: Command) -> CommandResult { + async fn process_command(&self, command: CommandInput) -> CommandResult { if let Some(handle) = self.handle.read().await.as_ref() { handle(command).await } else { -- GitLab From 96f6a75bc82a9dabd25f84b917b5358244bd55a4 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Fri, 26 Jul 2024 06:13:30 +0000 Subject: [PATCH 41/47] add refutable pattern function macro Signed-off-by: Jason Volk <jason@zemos.net> --- src/macros/mod.rs | 4 +++ src/macros/refutable.rs | 61 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+) create 
mode 100644 src/macros/refutable.rs diff --git a/src/macros/mod.rs b/src/macros/mod.rs index b01e5275a..a0e61324a 100644 --- a/src/macros/mod.rs +++ b/src/macros/mod.rs @@ -1,6 +1,7 @@ mod admin; mod cargo; mod debug; +mod refutable; mod rustc; mod utils; @@ -19,3 +20,6 @@ pub fn recursion_depth(args: TokenStream, input: TokenStream) -> TokenStream { d #[proc_macro] pub fn rustc_flags_capture(args: TokenStream) -> TokenStream { rustc::flags_capture(args) } + +#[proc_macro_attribute] +pub fn refutable(args: TokenStream, input: TokenStream) -> TokenStream { refutable::refutable(args, input) } diff --git a/src/macros/refutable.rs b/src/macros/refutable.rs new file mode 100644 index 000000000..6a6884e0e --- /dev/null +++ b/src/macros/refutable.rs @@ -0,0 +1,61 @@ +use proc_macro::{Span, TokenStream}; +use quote::{quote, ToTokens}; +use syn::{parse_macro_input, AttributeArgs, FnArg::Typed, Ident, ItemFn, Pat, PatIdent, PatType, Stmt}; + +pub(super) fn refutable(args: TokenStream, input: TokenStream) -> TokenStream { + let _args = parse_macro_input!(args as AttributeArgs); + let mut item = parse_macro_input!(input as ItemFn); + + let inputs = item.sig.inputs.clone(); + let stmt = &mut item.block.stmts; + let sig = &mut item.sig; + for (i, input) in inputs.iter().enumerate() { + let Typed(PatType { + pat, + .. + }) = input + else { + continue; + }; + + let Pat::Struct(ref pat) = **pat else { + continue; + }; + + let variant = &pat.path; + let fields = &pat.fields; + + // new versions of syn can replace this kronecker kludge with get_mut() + for (j, input) in sig.inputs.iter_mut().enumerate() { + if i != j { + continue; + } + + let Typed(PatType { + ref mut pat, + .. + }) = input + else { + continue; + }; + + let name = format!("_args_{i}"); + *pat = Box::new(Pat::Ident(PatIdent { + ident: Ident::new(&name, Span::call_site().into()), + attrs: Vec::new(), + by_ref: None, + mutability: None, + subpat: None, + })); + + let field = fields.iter(); + let refute = quote! 
{ + let #variant { #( #field ),*, .. } = #name else { panic!("incorrect variant passed to function argument {i}"); }; + }; + + stmt.insert(0, syn::parse2::<Stmt>(refute).expect("syntax error")); + } + } + + item.into_token_stream().into() +} -- GitLab From 3b5607ecdc772f9f30d572b436f66c222862b8c2 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Fri, 26 Jul 2024 06:41:26 +0000 Subject: [PATCH 42/47] add macro for out-of-line definitions Signed-off-by: Jason Volk <jason@zemos.net> --- src/macros/implement.rs | 23 +++++++++++++++++++++++ src/macros/mod.rs | 4 ++++ 2 files changed, 27 insertions(+) create mode 100644 src/macros/implement.rs diff --git a/src/macros/implement.rs b/src/macros/implement.rs new file mode 100644 index 000000000..1a06e588c --- /dev/null +++ b/src/macros/implement.rs @@ -0,0 +1,23 @@ +use proc_macro::TokenStream; +use quote::{quote, ToTokens}; +use syn::{parse_macro_input, AttributeArgs, ItemFn, Meta, NestedMeta}; + +pub(super) fn implement(args: TokenStream, input: TokenStream) -> TokenStream { + let args = parse_macro_input!(args as AttributeArgs); + let item = parse_macro_input!(input as ItemFn); + + let NestedMeta::Meta(Meta::Path(receiver)) = args + .first() + .expect("missing path to trait or item to implement") + else { + panic!("invalid path to item for implement"); + }; + + let out = quote! 
{ + impl #receiver { + #item + } + }; + + out.into_token_stream().into() +} diff --git a/src/macros/mod.rs b/src/macros/mod.rs index a0e61324a..8d7e2e5b7 100644 --- a/src/macros/mod.rs +++ b/src/macros/mod.rs @@ -1,6 +1,7 @@ mod admin; mod cargo; mod debug; +mod implement; mod refutable; mod rustc; mod utils; @@ -23,3 +24,6 @@ pub fn rustc_flags_capture(args: TokenStream) -> TokenStream { rustc::flags_capt #[proc_macro_attribute] pub fn refutable(args: TokenStream, input: TokenStream) -> TokenStream { refutable::refutable(args, input) } + +#[proc_macro_attribute] +pub fn implement(args: TokenStream, input: TokenStream) -> TokenStream { implement::implement(args, input) } -- GitLab From ca82b59c6fad717ea8b0d3143cb2759464e48ddf Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Fri, 26 Jul 2024 20:40:07 +0000 Subject: [PATCH 43/47] upgrade to syn 2.x Signed-off-by: Jason Volk <jason@zemos.net> --- Cargo.lock | 3 +- Cargo.toml | 2 +- src/core/info/cargo.rs | 14 +++++----- src/macros/Cargo.toml | 1 + src/macros/admin.rs | 35 +++++++++++++---------- src/macros/cargo.rs | 46 ++++++++++++++---------------- src/macros/debug.rs | 11 ++++---- src/macros/implement.rs | 21 ++++++++------ src/macros/mod.rs | 37 ++++++++++++++++++++---- src/macros/refutable.rs | 62 ++++++++++++++++++----------------------- src/macros/utils.rs | 15 ++++++++++ 11 files changed, 143 insertions(+), 104 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c73ef931b..ea215eec2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -731,9 +731,10 @@ dependencies = [ name = "conduit_macros" version = "0.4.5" dependencies = [ + "itertools 0.13.0", "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.71", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 66ba73855..17e7e7126 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -426,7 +426,7 @@ default-features = false version = "0.1" [workspace.dependencies.syn] -version = "1.0" +version = "2.0" features = ["full", "extra-traits"] 
[workspace.dependencies.quote] diff --git a/src/core/info/cargo.rs b/src/core/info/cargo.rs index 544bbb8fd..012a08e06 100644 --- a/src/core/info/cargo.rs +++ b/src/core/info/cargo.rs @@ -16,19 +16,19 @@ #[cargo_manifest] const WORKSPACE_MANIFEST: &'static str = (); -#[cargo_manifest("macros")] +#[cargo_manifest(crate = "macros")] const MACROS_MANIFEST: &'static str = (); -#[cargo_manifest("core")] +#[cargo_manifest(crate = "core")] const CORE_MANIFEST: &'static str = (); -#[cargo_manifest("database")] +#[cargo_manifest(crate = "database")] const DATABASE_MANIFEST: &'static str = (); -#[cargo_manifest("service")] +#[cargo_manifest(crate = "service")] const SERVICE_MANIFEST: &'static str = (); -#[cargo_manifest("admin")] +#[cargo_manifest(crate = "admin")] const ADMIN_MANIFEST: &'static str = (); -#[cargo_manifest("router")] +#[cargo_manifest(crate = "router")] const ROUTER_MANIFEST: &'static str = (); -#[cargo_manifest("main")] +#[cargo_manifest(crate = "main")] const MAIN_MANIFEST: &'static str = (); /// Processed list of features access all project crates. 
This is generated from diff --git a/src/macros/Cargo.toml b/src/macros/Cargo.toml index ca98f169e..9e8665784 100644 --- a/src/macros/Cargo.toml +++ b/src/macros/Cargo.toml @@ -18,6 +18,7 @@ proc-macro = true syn.workspace = true quote.workspace = true proc-macro2.workspace = true +itertools.workspace = true [lints] workspace = true diff --git a/src/macros/admin.rs b/src/macros/admin.rs index e1d294b91..4189d64f9 100644 --- a/src/macros/admin.rs +++ b/src/macros/admin.rs @@ -1,17 +1,15 @@ +use itertools::Itertools; use proc_macro::{Span, TokenStream}; use proc_macro2::TokenStream as TokenStream2; -use quote::quote; -use syn::{parse_macro_input, AttributeArgs, Fields, Ident, ItemEnum, Variant}; +use quote::{quote, ToTokens}; +use syn::{Error, Fields, Ident, ItemEnum, Meta, Variant}; -use crate::utils::camel_to_snake_string; +use crate::{utils::camel_to_snake_string, Result}; -pub(super) fn command_dispatch(args: TokenStream, input_: TokenStream) -> TokenStream { - let input = input_.clone(); - let item = parse_macro_input!(input as ItemEnum); - let _args = parse_macro_input!(args as AttributeArgs); - let arm = item.variants.iter().map(dispatch_arm); - let name = item.ident; - let q = quote! { +pub(super) fn command_dispatch(item: ItemEnum, _args: &[Meta]) -> Result<TokenStream> { + let name = &item.ident; + let arm: Vec<TokenStream2> = item.variants.iter().map(dispatch_arm).try_collect()?; + let switch = quote! 
{ pub(super) async fn process(command: #name, body: Vec<&str>) -> Result<RoomMessageEventContent> { use #name::*; #[allow(non_snake_case)] @@ -21,14 +19,17 @@ pub(super) async fn process(command: #name, body: Vec<&str>) -> Result<RoomMessa } }; - [input_, q.into()].into_iter().collect::<TokenStream>() + Ok([item.into_token_stream(), switch] + .into_iter() + .collect::<TokenStream2>() + .into()) } -fn dispatch_arm(v: &Variant) -> TokenStream2 { +fn dispatch_arm(v: &Variant) -> Result<TokenStream2> { let name = &v.ident; let target = camel_to_snake_string(&format!("{name}")); let handler = Ident::new(&target, Span::call_site().into()); - match &v.fields { + let res = match &v.fields { Fields::Named(fields) => { let field = fields.named.iter().filter_map(|f| f.ident.as_ref()); let arg = field.clone(); @@ -37,7 +38,9 @@ fn dispatch_arm(v: &Variant) -> TokenStream2 { } }, Fields::Unnamed(fields) => { - let field = &fields.unnamed.first().expect("one field"); + let Some(ref field) = fields.unnamed.first() else { + return Err(Error::new(Span::call_site().into(), "One unnamed field required")); + }; quote! 
{ #name ( #field ) => Box::pin(#handler::process(#field, body)).await?, } @@ -47,5 +50,7 @@ fn dispatch_arm(v: &Variant) -> TokenStream2 { #name => Box::pin(#handler(&body)).await?, } }, - } + }; + + Ok(res) } diff --git a/src/macros/cargo.rs b/src/macros/cargo.rs index 17132a6ca..cd36658ef 100644 --- a/src/macros/cargo.rs +++ b/src/macros/cargo.rs @@ -1,39 +1,34 @@ use std::{fs::read_to_string, path::PathBuf}; -use proc_macro::TokenStream; +use proc_macro::{Span, TokenStream}; use quote::quote; -use syn::{parse_macro_input, AttributeArgs, ItemConst, Lit, NestedMeta}; - -pub(super) fn manifest(args: TokenStream, item: TokenStream) -> TokenStream { - let item = parse_macro_input!(item as ItemConst); - let args = parse_macro_input!(args as AttributeArgs); - let member = args.into_iter().find_map(|arg| { - let NestedMeta::Lit(arg) = arg else { - return None; - }; - let Lit::Str(arg) = arg else { - return None; - }; - Some(arg.value()) - }); - - let path = manifest_path(member.as_deref()); - let manifest = read_to_string(&path).unwrap_or_default(); +use syn::{Error, ItemConst, Meta}; - let name = item.ident; +use crate::{utils, Result}; + +pub(super) fn manifest(item: ItemConst, args: &[Meta]) -> Result<TokenStream> { + let member = utils::get_named_string(args, "crate"); + let path = manifest_path(member.as_deref())?; + let manifest = read_to_string(&path).unwrap_or_default(); let val = manifest.as_str(); + let name = item.ident; let ret = quote! 
{ const #name: &'static str = #val; }; - ret.into() + Ok(ret.into()) } #[allow(clippy::option_env_unwrap)] -fn manifest_path(member: Option<&str>) -> PathBuf { - let mut path: PathBuf = option_env!("CARGO_MANIFEST_DIR") - .expect("missing CARGO_MANIFEST_DIR in environment") - .into(); +fn manifest_path(member: Option<&str>) -> Result<PathBuf> { + let Some(path) = option_env!("CARGO_MANIFEST_DIR") else { + return Err(Error::new( + Span::call_site().into(), + "missing CARGO_MANIFEST_DIR in environment", + )); + }; + + let mut path: PathBuf = path.into(); // conduwuit/src/macros/ -> conduwuit/src/ path.pop(); @@ -47,5 +42,6 @@ fn manifest_path(member: Option<&str>) -> PathBuf { } path.push("Cargo.toml"); - path + + Ok(path) } diff --git a/src/macros/debug.rs b/src/macros/debug.rs index 5251fa179..e83fd44ec 100644 --- a/src/macros/debug.rs +++ b/src/macros/debug.rs @@ -1,13 +1,12 @@ use std::cmp; use proc_macro::TokenStream; -use syn::{parse_macro_input, AttributeArgs, Item}; +use quote::ToTokens; +use syn::{Item, Meta}; -pub(super) fn recursion_depth(args: TokenStream, item_: TokenStream) -> TokenStream { - let item = item_.clone(); - let item = parse_macro_input!(item as Item); - let _args = parse_macro_input!(args as AttributeArgs); +use crate::Result; +pub(super) fn recursion_depth(item: Item, _args: &[Meta]) -> Result<TokenStream> { let mut best: usize = 0; let mut count: usize = 0; // think you'd find a fancy recursive ast visitor? 
think again @@ -24,5 +23,5 @@ pub(super) fn recursion_depth(args: TokenStream, item_: TokenStream) -> TokenStr println!("DEPTH: {best}"); println!("LENGTH: {count}"); - item_ + Ok(item.into_token_stream().into()) } diff --git a/src/macros/implement.rs b/src/macros/implement.rs index 1a06e588c..b5c8d787a 100644 --- a/src/macros/implement.rs +++ b/src/macros/implement.rs @@ -1,23 +1,26 @@ use proc_macro::TokenStream; -use quote::{quote, ToTokens}; -use syn::{parse_macro_input, AttributeArgs, ItemFn, Meta, NestedMeta}; +use quote::quote; +use syn::{ItemFn, Meta, MetaList}; -pub(super) fn implement(args: TokenStream, input: TokenStream) -> TokenStream { - let args = parse_macro_input!(args as AttributeArgs); - let item = parse_macro_input!(input as ItemFn); +use crate::Result; - let NestedMeta::Meta(Meta::Path(receiver)) = args +pub(super) fn implement(item: ItemFn, args: &[Meta]) -> Result<TokenStream> { + let Meta::List(MetaList { + path, + .. + }) = &args .first() .expect("missing path to trait or item to implement") else { panic!("invalid path to item for implement"); }; + let input = item; let out = quote! 
{ - impl #receiver { - #item + impl #path { + #input } }; - out.into_token_stream().into() + Ok(out.into()) } diff --git a/src/macros/mod.rs b/src/macros/mod.rs index 8d7e2e5b7..1a5494bb0 100644 --- a/src/macros/mod.rs +++ b/src/macros/mod.rs @@ -7,23 +7,50 @@ mod utils; use proc_macro::TokenStream; +use syn::{ + parse::{Parse, Parser}, + parse_macro_input, Error, Item, ItemConst, ItemEnum, ItemFn, Meta, +}; + +pub(crate) type Result<T> = std::result::Result<T, Error>; #[proc_macro_attribute] pub fn admin_command_dispatch(args: TokenStream, input: TokenStream) -> TokenStream { - admin::command_dispatch(args, input) + attribute_macro::<ItemEnum, _>(args, input, admin::command_dispatch) } #[proc_macro_attribute] -pub fn cargo_manifest(args: TokenStream, input: TokenStream) -> TokenStream { cargo::manifest(args, input) } +pub fn cargo_manifest(args: TokenStream, input: TokenStream) -> TokenStream { + attribute_macro::<ItemConst, _>(args, input, cargo::manifest) +} #[proc_macro_attribute] -pub fn recursion_depth(args: TokenStream, input: TokenStream) -> TokenStream { debug::recursion_depth(args, input) } +pub fn recursion_depth(args: TokenStream, input: TokenStream) -> TokenStream { + attribute_macro::<Item, _>(args, input, debug::recursion_depth) +} #[proc_macro] pub fn rustc_flags_capture(args: TokenStream) -> TokenStream { rustc::flags_capture(args) } #[proc_macro_attribute] -pub fn refutable(args: TokenStream, input: TokenStream) -> TokenStream { refutable::refutable(args, input) } +pub fn refutable(args: TokenStream, input: TokenStream) -> TokenStream { + attribute_macro::<ItemFn, _>(args, input, refutable::refutable) +} #[proc_macro_attribute] -pub fn implement(args: TokenStream, input: TokenStream) -> TokenStream { implement::implement(args, input) } +pub fn implement(args: TokenStream, input: TokenStream) -> TokenStream { + attribute_macro::<ItemFn, _>(args, input, implement::implement) +} + +fn attribute_macro<I, F>(args: TokenStream, input: TokenStream, func: 
F) -> TokenStream +where + F: Fn(I, &[Meta]) -> Result<TokenStream>, + I: Parse, +{ + let item = parse_macro_input!(input as I); + syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated + .parse(args) + .map(|args| args.iter().cloned().collect::<Vec<_>>()) + .and_then(|ref args| func(item, args)) + .unwrap_or_else(|e| e.to_compile_error().into()) +} diff --git a/src/macros/refutable.rs b/src/macros/refutable.rs index 6a6884e0e..facb4729d 100644 --- a/src/macros/refutable.rs +++ b/src/macros/refutable.rs @@ -1,11 +1,10 @@ use proc_macro::{Span, TokenStream}; use quote::{quote, ToTokens}; -use syn::{parse_macro_input, AttributeArgs, FnArg::Typed, Ident, ItemFn, Pat, PatIdent, PatType, Stmt}; +use syn::{FnArg::Typed, Ident, ItemFn, Meta, Pat, PatIdent, PatType, Stmt}; -pub(super) fn refutable(args: TokenStream, input: TokenStream) -> TokenStream { - let _args = parse_macro_input!(args as AttributeArgs); - let mut item = parse_macro_input!(input as ItemFn); +use crate::Result; +pub(super) fn refutable(mut item: ItemFn, _args: &[Meta]) -> Result<TokenStream> { let inputs = item.sig.inputs.clone(); let stmt = &mut item.block.stmts; let sig = &mut item.sig; @@ -25,37 +24,30 @@ pub(super) fn refutable(args: TokenStream, input: TokenStream) -> TokenStream { let variant = &pat.path; let fields = &pat.fields; - // new versions of syn can replace this kronecker kludge with get_mut() - for (j, input) in sig.inputs.iter_mut().enumerate() { - if i != j { - continue; - } - - let Typed(PatType { - ref mut pat, - .. - }) = input - else { - continue; - }; - - let name = format!("_args_{i}"); - *pat = Box::new(Pat::Ident(PatIdent { - ident: Ident::new(&name, Span::call_site().into()), - attrs: Vec::new(), - by_ref: None, - mutability: None, - subpat: None, - })); - - let field = fields.iter(); - let refute = quote! { - let #variant { #( #field ),*, .. 
} = #name else { panic!("incorrect variant passed to function argument {i}"); }; - }; - - stmt.insert(0, syn::parse2::<Stmt>(refute).expect("syntax error")); - } + let Some(Typed(PatType { + ref mut pat, + .. + })) = sig.inputs.get_mut(i) + else { + continue; + }; + + let name = format!("_args_{i}"); + *pat = Box::new(Pat::Ident(PatIdent { + ident: Ident::new(&name, Span::call_site().into()), + attrs: Vec::new(), + by_ref: None, + mutability: None, + subpat: None, + })); + + let field = fields.iter(); + let refute = quote! { + let #variant { #( #field ),*, .. } = #name else { panic!("incorrect variant passed to function argument {i}"); }; + }; + + stmt.insert(0, syn::parse2::<Stmt>(refute)?); } - item.into_token_stream().into() + Ok(item.into_token_stream().into()) } diff --git a/src/macros/utils.rs b/src/macros/utils.rs index f512c56c2..7ae55d399 100644 --- a/src/macros/utils.rs +++ b/src/macros/utils.rs @@ -1,3 +1,18 @@ +use syn::{Expr, Lit, Meta}; + +pub(crate) fn get_named_string(args: &[Meta], name: &str) -> Option<String> { + args.iter().find_map(|arg| { + let value = arg.require_name_value().ok()?; + let Expr::Lit(ref lit) = value.value else { + return None; + }; + let Lit::Str(ref str) = lit.lit else { + return None; + }; + value.path.is_ident(name).then_some(str.value()) + }) +} + #[must_use] pub(crate) fn camel_to_snake_string(s: &str) -> String { let mut output = String::with_capacity( -- GitLab From 7a3cc3941e825a3d063ab365a51a9b8b1a4e2cbd Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Fri, 26 Jul 2024 21:31:31 +0000 Subject: [PATCH 44/47] parse generics for implement macro Signed-off-by: Jason Volk <jason@zemos.net> --- src/macros/implement.rs | 34 ++++++++++++++++++++-------------- src/macros/utils.rs | 10 +++++++++- 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/src/macros/implement.rs b/src/macros/implement.rs index b5c8d787a..b798bae5b 100644 --- a/src/macros/implement.rs +++ b/src/macros/implement.rs @@ -1,26 
+1,32 @@ -use proc_macro::TokenStream; +use proc_macro::{Span, TokenStream}; use quote::quote; -use syn::{ItemFn, Meta, MetaList}; +use syn::{Error, ItemFn, Meta, Path}; +use utils::get_named_generics; -use crate::Result; +use crate::{utils, Result}; pub(super) fn implement(item: ItemFn, args: &[Meta]) -> Result<TokenStream> { - let Meta::List(MetaList { - path, - .. - }) = &args - .first() - .expect("missing path to trait or item to implement") - else { - panic!("invalid path to item for implement"); - }; - + let generics = get_named_generics(args, "generics")?; + let receiver = get_receiver(args)?; + let params = get_named_generics(args, "params")?; let input = item; let out = quote! { - impl #path { + impl #generics #receiver #params { #input } }; Ok(out.into()) } + +fn get_receiver(args: &[Meta]) -> Result<Path> { + let receiver = &args + .first() + .ok_or_else(|| Error::new(Span::call_site().into(), "Missing required argument to receiver"))?; + + let Meta::Path(receiver) = receiver else { + return Err(Error::new(Span::call_site().into(), "First argument is not path to receiver")); + }; + + Ok(receiver.clone()) +} diff --git a/src/macros/utils.rs b/src/macros/utils.rs index 7ae55d399..58074e3a0 100644 --- a/src/macros/utils.rs +++ b/src/macros/utils.rs @@ -1,4 +1,12 @@ -use syn::{Expr, Lit, Meta}; +use syn::{parse_str, Expr, Generics, Lit, Meta}; + +use crate::Result; + +pub(crate) fn get_named_generics(args: &[Meta], name: &str) -> Result<Generics> { + const DEFAULT: &str = "<>"; + + parse_str::<Generics>(&get_named_string(args, name).unwrap_or_else(|| DEFAULT.to_owned())) +} pub(crate) fn get_named_string(args: &[Meta], name: &str) -> Option<String> { args.iter().find_map(|arg| { -- GitLab From 7e50db419362c6838d5eaeefccee24c7686b57e3 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Sat, 27 Jul 2024 00:11:41 +0000 Subject: [PATCH 45/47] de-global services from admin Signed-off-by: Jason Volk <jason@zemos.net> --- src/admin/admin.rs | 34 +-- 
src/admin/appservice/commands.rs | 29 +- src/admin/appservice/mod.rs | 23 +- src/admin/check/commands.rs | 6 +- src/admin/check/mod.rs | 6 +- src/admin/command.rs | 6 + src/admin/debug/commands.rs | 258 +++++++++------- src/admin/debug/mod.rs | 6 +- src/admin/debug/tester.rs | 18 +- src/admin/federation/commands.rs | 39 ++- src/admin/federation/mod.rs | 23 +- src/admin/handler.rs | 23 +- src/admin/media/commands.rs | 37 +-- src/admin/media/mod.rs | 19 +- src/admin/mod.rs | 7 +- src/admin/query/account_data.rs | 46 ++- src/admin/query/appservice.rs | 30 +- src/admin/query/globals.rs | 49 ++- src/admin/query/mod.rs | 281 +----------------- src/admin/query/presence.rs | 36 ++- src/admin/query/resolver.rs | 44 ++- src/admin/query/room_alias.rs | 41 ++- src/admin/query/room_state_cache.rs | 170 +++++++---- src/admin/query/sending.rs | 91 +++++- src/admin/query/users.rs | 19 +- .../room/{room_alias_commands.rs => alias.rs} | 67 ++++- .../room/{room_commands.rs => commands.rs} | 19 +- ...oom_directory_commands.rs => directory.rs} | 38 ++- .../room/{room_info_commands.rs => info.rs} | 51 ++-- src/admin/room/mod.rs | 166 +---------- ...m_moderation_commands.rs => moderation.rs} | 197 +++++++----- src/admin/server/commands.rs | 77 +++-- src/admin/server/mod.rs | 35 +-- src/admin/user/commands.rs | 180 ++++++----- src/admin/user/mod.rs | 56 +--- src/macros/admin.rs | 22 +- src/macros/mod.rs | 5 + 37 files changed, 1129 insertions(+), 1125 deletions(-) create mode 100644 src/admin/command.rs rename src/admin/room/{room_alias_commands.rs => alias.rs} (73%) rename src/admin/room/{room_commands.rs => commands.rs} (82%) rename src/admin/room/{room_directory_commands.rs => directory.rs} (68%) rename src/admin/room/{room_info_commands.rs => info.rs} (54%) rename src/admin/room/{room_moderation_commands.rs => moderation.rs} (72%) diff --git a/src/admin/admin.rs b/src/admin/admin.rs index f5fe5dc22..fa4972056 100644 --- a/src/admin/admin.rs +++ b/src/admin/admin.rs @@ -3,14 +3,14 @@ 
use ruma::events::room::message::RoomMessageEventContent; use crate::{ - appservice, appservice::AppserviceCommand, check, check::CheckCommand, debug, debug::DebugCommand, federation, - federation::FederationCommand, media, media::MediaCommand, query, query::QueryCommand, room, room::RoomCommand, - server, server::ServerCommand, user, user::UserCommand, + appservice, appservice::AppserviceCommand, check, check::CheckCommand, command::Command, debug, + debug::DebugCommand, federation, federation::FederationCommand, media, media::MediaCommand, query, + query::QueryCommand, room, room::RoomCommand, server, server::ServerCommand, user, user::UserCommand, }; #[derive(Debug, Parser)] #[command(name = "admin", version = env!("CARGO_PKG_VERSION"))] -pub(crate) enum AdminCommand { +pub(super) enum AdminCommand { #[command(subcommand)] /// - Commands for managing appservices Appservices(AppserviceCommand), @@ -49,18 +49,18 @@ pub(crate) enum AdminCommand { } #[tracing::instrument(skip_all, name = "command")] -pub(crate) async fn process(command: AdminCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> { - let reply_message_content = match command { - AdminCommand::Appservices(command) => appservice::process(command, body).await?, - AdminCommand::Media(command) => media::process(command, body).await?, - AdminCommand::Users(command) => user::process(command, body).await?, - AdminCommand::Rooms(command) => room::process(command, body).await?, - AdminCommand::Federation(command) => federation::process(command, body).await?, - AdminCommand::Server(command) => server::process(command, body).await?, - AdminCommand::Debug(command) => debug::process(command, body).await?, - AdminCommand::Query(command) => query::process(command, body).await?, - AdminCommand::Check(command) => check::process(command, body).await?, - }; +pub(super) async fn process(command: AdminCommand, context: &Command<'_>) -> Result<RoomMessageEventContent> { + use AdminCommand::*; - 
Ok(reply_message_content) + Ok(match command { + Appservices(command) => appservice::process(command, context).await?, + Media(command) => media::process(command, context).await?, + Users(command) => user::process(command, context).await?, + Rooms(command) => room::process(command, context).await?, + Federation(command) => federation::process(command, context).await?, + Server(command) => server::process(command, context).await?, + Debug(command) => debug::process(command, context).await?, + Query(command) => query::process(command, context).await?, + Check(command) => check::process(command, context).await?, + }) } diff --git a/src/admin/appservice/commands.rs b/src/admin/appservice/commands.rs index 5cce83510..7d6378f31 100644 --- a/src/admin/appservice/commands.rs +++ b/src/admin/appservice/commands.rs @@ -1,18 +1,20 @@ use ruma::{api::appservice::Registration, events::room::message::RoomMessageEventContent}; -use crate::{services, Result}; +use crate::{admin_command, Result}; -pub(super) async fn register(body: Vec<&str>) -> Result<RoomMessageEventContent> { - if body.len() < 2 || !body[0].trim().starts_with("```") || body.last().unwrap_or(&"").trim() != "```" { +#[admin_command] +pub(super) async fn register(&self) -> Result<RoomMessageEventContent> { + if self.body.len() < 2 || !self.body[0].trim().starts_with("```") || self.body.last().unwrap_or(&"").trim() != "```" + { return Ok(RoomMessageEventContent::text_plain( "Expected code block in command body. 
Add --help for details.", )); } - let appservice_config = body[1..body.len().checked_sub(1).unwrap()].join("\n"); + let appservice_config = self.body[1..self.body.len().checked_sub(1).unwrap()].join("\n"); let parsed_config = serde_yaml::from_str::<Registration>(&appservice_config); match parsed_config { - Ok(yaml) => match services().appservice.register_appservice(yaml).await { + Ok(yaml) => match self.services.appservice.register_appservice(yaml).await { Ok(id) => Ok(RoomMessageEventContent::text_plain(format!( "Appservice registered with ID: {id}." ))), @@ -26,8 +28,10 @@ pub(super) async fn register(body: Vec<&str>) -> Result<RoomMessageEventContent> } } -pub(super) async fn unregister(_body: Vec<&str>, appservice_identifier: String) -> Result<RoomMessageEventContent> { - match services() +#[admin_command] +pub(super) async fn unregister(&self, appservice_identifier: String) -> Result<RoomMessageEventContent> { + match self + .services .appservice .unregister_appservice(&appservice_identifier) .await @@ -39,8 +43,10 @@ pub(super) async fn unregister(_body: Vec<&str>, appservice_identifier: String) } } -pub(super) async fn show(_body: Vec<&str>, appservice_identifier: String) -> Result<RoomMessageEventContent> { - match services() +#[admin_command] +pub(super) async fn show_appservice_config(&self, appservice_identifier: String) -> Result<RoomMessageEventContent> { + match self + .services .appservice .get_registration(&appservice_identifier) .await @@ -54,8 +60,9 @@ pub(super) async fn show(_body: Vec<&str>, appservice_identifier: String) -> Res } } -pub(super) async fn list(_body: Vec<&str>) -> Result<RoomMessageEventContent> { - let appservices = services().appservice.iter_ids().await; +#[admin_command] +pub(super) async fn list_registered(&self) -> Result<RoomMessageEventContent> { + let appservices = self.services.appservice.iter_ids().await; let output = format!("Appservices ({}): {}", appservices.len(), appservices.join(", ")); 
Ok(RoomMessageEventContent::text_plain(output)) } diff --git a/src/admin/appservice/mod.rs b/src/admin/appservice/mod.rs index 81e04087c..ca5f46bba 100644 --- a/src/admin/appservice/mod.rs +++ b/src/admin/appservice/mod.rs @@ -2,11 +2,11 @@ use clap::Subcommand; use conduit::Result; -use ruma::events::room::message::RoomMessageEventContent; -use self::commands::*; +use crate::admin_command_dispatch; #[derive(Debug, Subcommand)] +#[admin_command_dispatch] pub(super) enum AppserviceCommand { /// - Register an appservice using its registration YAML /// @@ -28,24 +28,13 @@ pub(super) enum AppserviceCommand { /// - Show an appservice's config using its ID /// /// You can find the ID using the `list-appservices` command. - Show { + #[clap(alias("show"))] + ShowAppserviceConfig { /// The appservice to show appservice_identifier: String, }, /// - List all the currently registered appservices - List, -} - -pub(super) async fn process(command: AppserviceCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> { - Ok(match command { - AppserviceCommand::Register => register(body).await?, - AppserviceCommand::Unregister { - appservice_identifier, - } => unregister(body, appservice_identifier).await?, - AppserviceCommand::Show { - appservice_identifier, - } => show(body, appservice_identifier).await?, - AppserviceCommand::List => list(body).await?, - }) + #[clap(alias("list"))] + ListRegistered, } diff --git a/src/admin/check/commands.rs b/src/admin/check/commands.rs index 1fbea8f6e..a757d5044 100644 --- a/src/admin/check/commands.rs +++ b/src/admin/check/commands.rs @@ -1,12 +1,14 @@ use conduit::Result; +use conduit_macros::implement; use ruma::events::room::message::RoomMessageEventContent; -use crate::services; +use crate::{services, Command}; /// Uses the iterator in `src/database/key_value/users.rs` to iterator over /// every user in our database (remote and local). 
Reports total count, any /// errors if there were any, etc -pub(super) async fn check_all_users(_body: Vec<&str>) -> Result<RoomMessageEventContent> { +#[implement(Command, params = "<'_>")] +pub(super) async fn check_all_users(&self) -> Result<RoomMessageEventContent> { let timer = tokio::time::Instant::now(); let results = services().users.db.iter(); let query_time = timer.elapsed(); diff --git a/src/admin/check/mod.rs b/src/admin/check/mod.rs index f1cfa2b94..e543e5b54 100644 --- a/src/admin/check/mod.rs +++ b/src/admin/check/mod.rs @@ -4,15 +4,15 @@ use conduit::Result; use ruma::events::room::message::RoomMessageEventContent; -use self::commands::*; +use crate::Command; #[derive(Debug, Subcommand)] pub(super) enum CheckCommand { AllUsers, } -pub(super) async fn process(command: CheckCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> { +pub(super) async fn process(command: CheckCommand, context: &Command<'_>) -> Result<RoomMessageEventContent> { Ok(match command { - CheckCommand::AllUsers => check_all_users(body).await?, + CheckCommand::AllUsers => context.check_all_users().await?, }) } diff --git a/src/admin/command.rs b/src/admin/command.rs new file mode 100644 index 000000000..fbfdd2ba8 --- /dev/null +++ b/src/admin/command.rs @@ -0,0 +1,6 @@ +use service::Services; + +pub(crate) struct Command<'a> { + pub(crate) services: &'a Services, + pub(crate) body: &'a [&'a str], +} diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index 307495d70..cf4ab31d7 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -16,19 +16,22 @@ events::room::message::RoomMessageEventContent, CanonicalJsonObject, EventId, OwnedRoomOrAliasId, RoomId, RoomVersionId, ServerName, }; -use service::services; use tokio::sync::RwLock; use tracing_subscriber::EnvFilter; -pub(super) async fn echo(_body: &[&str], message: Vec<String>) -> Result<RoomMessageEventContent> { +use crate::admin_command; + +#[admin_command] +pub(super) async fn 
echo(&self, message: Vec<String>) -> Result<RoomMessageEventContent> { let message = message.join(" "); Ok(RoomMessageEventContent::notice_plain(message)) } -pub(super) async fn get_auth_chain(_body: &[&str], event_id: Box<EventId>) -> Result<RoomMessageEventContent> { +#[admin_command] +pub(super) async fn get_auth_chain(&self, event_id: Box<EventId>) -> Result<RoomMessageEventContent> { let event_id = Arc::<EventId>::from(event_id); - if let Some(event) = services().rooms.timeline.get_pdu_json(&event_id)? { + if let Some(event) = self.services.rooms.timeline.get_pdu_json(&event_id)? { let room_id_str = event .get("room_id") .and_then(|val| val.as_str()) @@ -36,13 +39,16 @@ pub(super) async fn get_auth_chain(_body: &[&str], event_id: Box<EventId>) -> Re let room_id = <&RoomId>::try_from(room_id_str) .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; + let start = Instant::now(); - let count = services() + let count = self + .services .rooms .auth_chain .event_ids_iter(room_id, vec![event_id]) .await? .count(); + let elapsed = start.elapsed(); Ok(RoomMessageEventContent::text_plain(format!( "Loaded auth chain with length {count} in {elapsed:?}" @@ -52,14 +58,16 @@ pub(super) async fn get_auth_chain(_body: &[&str], event_id: Box<EventId>) -> Re } } -pub(super) async fn parse_pdu(body: &[&str]) -> Result<RoomMessageEventContent> { - if body.len() < 2 || !body[0].trim().starts_with("```") || body.last().unwrap_or(&"").trim() != "```" { +#[admin_command] +pub(super) async fn parse_pdu(&self) -> Result<RoomMessageEventContent> { + if self.body.len() < 2 || !self.body[0].trim().starts_with("```") || self.body.last().unwrap_or(&"").trim() != "```" + { return Ok(RoomMessageEventContent::text_plain( "Expected code block in command body. 
Add --help for details.", )); } - let string = body[1..body.len().saturating_sub(1)].join("\n"); + let string = self.body[1..self.body.len().saturating_sub(1)].join("\n"); match serde_json::from_str(&string) { Ok(value) => match ruma::signatures::reference_hash(&value, &RoomVersionId::V6) { Ok(hash) => { @@ -80,15 +88,17 @@ pub(super) async fn parse_pdu(body: &[&str]) -> Result<RoomMessageEventContent> } } -pub(super) async fn get_pdu(_body: &[&str], event_id: Box<EventId>) -> Result<RoomMessageEventContent> { +#[admin_command] +pub(super) async fn get_pdu(&self, event_id: Box<EventId>) -> Result<RoomMessageEventContent> { let mut outlier = false; - let mut pdu_json = services() + let mut pdu_json = self + .services .rooms .timeline .get_non_outlier_pdu_json(&event_id)?; if pdu_json.is_none() { outlier = true; - pdu_json = services().rooms.timeline.get_pdu_json(&event_id)?; + pdu_json = self.services.rooms.timeline.get_pdu_json(&event_id)?; } match pdu_json { Some(json) => { @@ -107,39 +117,42 @@ pub(super) async fn get_pdu(_body: &[&str], event_id: Box<EventId>) -> Result<Ro } } +#[admin_command] pub(super) async fn get_remote_pdu_list( - body: &[&str], server: Box<ServerName>, force: bool, + &self, server: Box<ServerName>, force: bool, ) -> Result<RoomMessageEventContent> { - if !services().globals.config.allow_federation { + if !self.services.globals.config.allow_federation { return Ok(RoomMessageEventContent::text_plain( "Federation is disabled on this homeserver.", )); } - if server == services().globals.server_name() { + if server == self.services.globals.server_name() { return Ok(RoomMessageEventContent::text_plain( "Not allowed to send federation requests to ourselves. 
Please use `get-pdu` for fetching local PDUs from \ the database.", )); } - if body.len() < 2 || !body[0].trim().starts_with("```") || body.last().unwrap_or(&"").trim() != "```" { + if self.body.len() < 2 || !self.body[0].trim().starts_with("```") || self.body.last().unwrap_or(&"").trim() != "```" + { return Ok(RoomMessageEventContent::text_plain( "Expected code block in command body. Add --help for details.", )); } - let list = body + let list = self + .body .iter() .collect::<Vec<_>>() - .drain(1..body.len().saturating_sub(1)) + .drain(1..self.body.len().saturating_sub(1)) .filter_map(|pdu| EventId::parse(pdu).ok()) .collect::<Vec<_>>(); for pdu in list { if force { - if let Err(e) = get_remote_pdu(&[], Box::from(pdu), server.clone()).await { - services() + if let Err(e) = self.get_remote_pdu(Box::from(pdu), server.clone()).await { + self.services .admin .send_message(RoomMessageEventContent::text_plain(format!( "Failed to get remote PDU, ignoring error: {e}" @@ -148,29 +161,31 @@ pub(super) async fn get_remote_pdu_list( warn!(%e, "Failed to get remote PDU, ignoring error"); } } else { - get_remote_pdu(&[], Box::from(pdu), server.clone()).await?; + self.get_remote_pdu(Box::from(pdu), server.clone()).await?; } } Ok(RoomMessageEventContent::text_plain("Fetched list of remote PDUs.")) } +#[admin_command] pub(super) async fn get_remote_pdu( - _body: &[&str], event_id: Box<EventId>, server: Box<ServerName>, + &self, event_id: Box<EventId>, server: Box<ServerName>, ) -> Result<RoomMessageEventContent> { - if !services().globals.config.allow_federation { + if !self.services.globals.config.allow_federation { return Ok(RoomMessageEventContent::text_plain( "Federation is disabled on this homeserver.", )); } - if server == services().globals.server_name() { + if server == self.services.globals.server_name() { return Ok(RoomMessageEventContent::text_plain( "Not allowed to send federation requests to ourselves. 
Please use `get-pdu` for fetching local PDUs.", )); } - match services() + match self + .services .sending .send_federation_request( &server, @@ -191,7 +206,8 @@ pub(super) async fn get_remote_pdu( debug!("Attempting to parse PDU: {:?}", &response.pdu); let parsed_pdu = { - let parsed_result = services() + let parsed_result = self + .services .rooms .event_handler .parse_incoming_pdu(&response.pdu); @@ -212,7 +228,7 @@ pub(super) async fn get_remote_pdu( let pub_key_map = RwLock::new(BTreeMap::new()); debug!("Attempting to fetch homeserver signing keys for {server}"); - services() + self.services .rooms .event_handler .fetch_required_signing_keys(parsed_pdu.iter().map(|(_event_id, event, _room_id)| event), &pub_key_map) @@ -222,7 +238,7 @@ pub(super) async fn get_remote_pdu( }); info!("Attempting to handle event ID {event_id} as backfilled PDU"); - services() + self.services .rooms .timeline .backfill_pdu(&server, response.pdu, &pub_key_map) @@ -241,9 +257,11 @@ pub(super) async fn get_remote_pdu( } } -pub(super) async fn get_room_state(_body: &[&str], room: OwnedRoomOrAliasId) -> Result<RoomMessageEventContent> { - let room_id = services().rooms.alias.resolve(&room).await?; - let room_state = services() +#[admin_command] +pub(super) async fn get_room_state(&self, room: OwnedRoomOrAliasId) -> Result<RoomMessageEventContent> { + let room_id = self.services.rooms.alias.resolve(&room).await?; + let room_state = self + .services .rooms .state_accessor .room_state_full(&room_id) @@ -268,8 +286,9 @@ pub(super) async fn get_room_state(_body: &[&str], room: OwnedRoomOrAliasId) -> Ok(RoomMessageEventContent::notice_markdown(format!("```json\n{json}\n```"))) } -pub(super) async fn ping(_body: &[&str], server: Box<ServerName>) -> Result<RoomMessageEventContent> { - if server == services().globals.server_name() { +#[admin_command] +pub(super) async fn ping(&self, server: Box<ServerName>) -> Result<RoomMessageEventContent> { + if server == self.services.globals.server_name() { 
return Ok(RoomMessageEventContent::text_plain( "Not allowed to send federation requests to ourselves.", )); @@ -277,7 +296,8 @@ pub(super) async fn ping(_body: &[&str], server: Box<ServerName>) -> Result<Room let timer = tokio::time::Instant::now(); - match services() + match self + .services .sending .send_federation_request(&server, ruma::api::federation::discovery::get_server_version::v1::Request {}) .await @@ -306,23 +326,23 @@ pub(super) async fn ping(_body: &[&str], server: Box<ServerName>) -> Result<Room } } -pub(super) async fn force_device_list_updates(_body: &[&str]) -> Result<RoomMessageEventContent> { +#[admin_command] +pub(super) async fn force_device_list_updates(&self) -> Result<RoomMessageEventContent> { // Force E2EE device list updates for all users - for user_id in services().users.iter().filter_map(Result::ok) { - services().users.mark_device_key_update(&user_id)?; + for user_id in self.services.users.iter().filter_map(Result::ok) { + self.services.users.mark_device_key_update(&user_id)?; } Ok(RoomMessageEventContent::text_plain( "Marked all devices for all users as having new keys to update", )) } -pub(super) async fn change_log_level( - _body: &[&str], filter: Option<String>, reset: bool, -) -> Result<RoomMessageEventContent> { +#[admin_command] +pub(super) async fn change_log_level(&self, filter: Option<String>, reset: bool) -> Result<RoomMessageEventContent> { let handles = &["console"]; if reset { - let old_filter_layer = match EnvFilter::try_new(&services().globals.config.log) { + let old_filter_layer = match EnvFilter::try_new(&self.services.globals.config.log) { Ok(s) => s, Err(e) => { return Ok(RoomMessageEventContent::text_plain(format!( @@ -331,7 +351,8 @@ pub(super) async fn change_log_level( }, }; - match services() + match self + .services .server .log .reload @@ -340,7 +361,7 @@ pub(super) async fn change_log_level( Ok(()) => { return Ok(RoomMessageEventContent::text_plain(format!( "Successfully changed log level back to config 
value {}", - services().globals.config.log + self.services.globals.config.log ))); }, Err(e) => { @@ -361,7 +382,8 @@ pub(super) async fn change_log_level( }, }; - match services() + match self + .services .server .log .reload @@ -381,19 +403,21 @@ pub(super) async fn change_log_level( Ok(RoomMessageEventContent::text_plain("No log level was specified.")) } -pub(super) async fn sign_json(body: &[&str]) -> Result<RoomMessageEventContent> { - if body.len() < 2 || !body[0].trim().starts_with("```") || body.last().unwrap_or(&"").trim() != "```" { +#[admin_command] +pub(super) async fn sign_json(&self) -> Result<RoomMessageEventContent> { + if self.body.len() < 2 || !self.body[0].trim().starts_with("```") || self.body.last().unwrap_or(&"").trim() != "```" + { return Ok(RoomMessageEventContent::text_plain( "Expected code block in command body. Add --help for details.", )); } - let string = body[1..body.len().checked_sub(1).unwrap()].join("\n"); + let string = self.body[1..self.body.len().checked_sub(1).unwrap()].join("\n"); match serde_json::from_str(&string) { Ok(mut value) => { ruma::signatures::sign_json( - services().globals.server_name().as_str(), - services().globals.keypair(), + self.services.globals.server_name().as_str(), + self.services.globals.keypair(), &mut value, ) .expect("our request json is what ruma expects"); @@ -404,19 +428,21 @@ pub(super) async fn sign_json(body: &[&str]) -> Result<RoomMessageEventContent> } } -pub(super) async fn verify_json(body: &[&str]) -> Result<RoomMessageEventContent> { - if body.len() < 2 || !body[0].trim().starts_with("```") || body.last().unwrap_or(&"").trim() != "```" { +#[admin_command] +pub(super) async fn verify_json(&self) -> Result<RoomMessageEventContent> { + if self.body.len() < 2 || !self.body[0].trim().starts_with("```") || self.body.last().unwrap_or(&"").trim() != "```" + { return Ok(RoomMessageEventContent::text_plain( "Expected code block in command body. 
Add --help for details.", )); } - let string = body[1..body.len().checked_sub(1).unwrap()].join("\n"); + let string = self.body[1..self.body.len().checked_sub(1).unwrap()].join("\n"); match serde_json::from_str(&string) { Ok(value) => { let pub_key_map = RwLock::new(BTreeMap::new()); - services() + self.services .rooms .event_handler .fetch_required_signing_keys([&value], &pub_key_map) @@ -434,19 +460,22 @@ pub(super) async fn verify_json(body: &[&str]) -> Result<RoomMessageEventContent } } -#[tracing::instrument(skip(_body))] -pub(super) async fn first_pdu_in_room(_body: &[&str], room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { - if !services() +#[admin_command] +#[tracing::instrument(skip(self))] +pub(super) async fn first_pdu_in_room(&self, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { + if !self + .services .rooms .state_cache - .server_in_room(&services().globals.config.server_name, &room_id)? + .server_in_room(&self.services.globals.config.server_name, &room_id)? { return Ok(RoomMessageEventContent::text_plain( "We are not participating in the room / we don't know about the room ID.", )); } - let first_pdu = services() + let first_pdu = self + .services .rooms .timeline .first_pdu_in_room(&room_id)? @@ -455,19 +484,22 @@ pub(super) async fn first_pdu_in_room(_body: &[&str], room_id: Box<RoomId>) -> R Ok(RoomMessageEventContent::text_plain(format!("{first_pdu:?}"))) } -#[tracing::instrument(skip(_body))] -pub(super) async fn latest_pdu_in_room(_body: &[&str], room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { - if !services() +#[admin_command] +#[tracing::instrument(skip(self))] +pub(super) async fn latest_pdu_in_room(&self, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { + if !self + .services .rooms .state_cache - .server_in_room(&services().globals.config.server_name, &room_id)? + .server_in_room(&self.services.globals.config.server_name, &room_id)? 
{ return Ok(RoomMessageEventContent::text_plain( "We are not participating in the room / we don't know about the room ID.", )); } - let latest_pdu = services() + let latest_pdu = self + .services .rooms .timeline .latest_pdu_in_room(&room_id)? @@ -476,32 +508,36 @@ pub(super) async fn latest_pdu_in_room(_body: &[&str], room_id: Box<RoomId>) -> Ok(RoomMessageEventContent::text_plain(format!("{latest_pdu:?}"))) } -#[tracing::instrument(skip(_body))] +#[admin_command] +#[tracing::instrument(skip(self))] pub(super) async fn force_set_room_state_from_server( - _body: &[&str], room_id: Box<RoomId>, server_name: Box<ServerName>, + &self, room_id: Box<RoomId>, server_name: Box<ServerName>, ) -> Result<RoomMessageEventContent> { - if !services() + if !self + .services .rooms .state_cache - .server_in_room(&services().globals.config.server_name, &room_id)? + .server_in_room(&self.services.globals.config.server_name, &room_id)? { return Ok(RoomMessageEventContent::text_plain( "We are not participating in the room / we don't know about the room ID.", )); } - let first_pdu = services() + let first_pdu = self + .services .rooms .timeline .latest_pdu_in_room(&room_id)? 
.ok_or_else(|| Error::bad_database("Failed to find the latest PDU in database"))?; - let room_version = services().rooms.state.get_room_version(&room_id)?; + let room_version = self.services.rooms.state.get_room_version(&room_id)?; let mut state: HashMap<u64, Arc<EventId>> = HashMap::new(); let pub_key_map = RwLock::new(BTreeMap::new()); - let remote_state_response = services() + let remote_state_response = self + .services .sending .send_federation_request( &server_name, @@ -515,7 +551,7 @@ pub(super) async fn force_set_room_state_from_server( let mut events = Vec::with_capacity(remote_state_response.pdus.len()); for pdu in remote_state_response.pdus.clone() { - events.push(match services().rooms.event_handler.parse_incoming_pdu(&pdu) { + events.push(match self.services.rooms.event_handler.parse_incoming_pdu(&pdu) { Ok(t) => t, Err(e) => { warn!("Could not parse PDU, ignoring: {e}"); @@ -525,7 +561,7 @@ pub(super) async fn force_set_room_state_from_server( } info!("Fetching required signing keys for all the state events we got"); - services() + self.services .rooms .event_handler .fetch_required_signing_keys(events.iter().map(|(_event_id, event, _room_id)| event), &pub_key_map) @@ -535,7 +571,7 @@ pub(super) async fn force_set_room_state_from_server( for result in remote_state_response .pdus .iter() - .map(|pdu| validate_and_add_event_id(services(), pdu, &room_version, &pub_key_map)) + .map(|pdu| validate_and_add_event_id(self.services, pdu, &room_version, &pub_key_map)) { let Ok((event_id, value)) = result.await else { continue; @@ -546,12 +582,13 @@ pub(super) async fn force_set_room_state_from_server( Error::BadServerResponse("Invalid PDU in send_join response.") })?; - services() + self.services .rooms .outlier .add_pdu_outlier(&event_id, &value)?; if let Some(state_key) = &pdu.state_key { - let shortstatekey = services() + let shortstatekey = self + .services .rooms .short .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?; @@ -563,32 
+600,34 @@ pub(super) async fn force_set_room_state_from_server( for result in remote_state_response .auth_chain .iter() - .map(|pdu| validate_and_add_event_id(services(), pdu, &room_version, &pub_key_map)) + .map(|pdu| validate_and_add_event_id(self.services, pdu, &room_version, &pub_key_map)) { let Ok((event_id, value)) = result.await else { continue; }; - services() + self.services .rooms .outlier .add_pdu_outlier(&event_id, &value)?; } - let new_room_state = services() + let new_room_state = self + .services .rooms .event_handler .resolve_state(room_id.clone().as_ref(), &room_version, state) .await?; info!("Forcing new room state"); - let (short_state_hash, new, removed) = services() + let (short_state_hash, new, removed) = self + .services .rooms .state_compressor .save_state(room_id.clone().as_ref(), new_room_state)?; - let state_lock = services().rooms.state.mutex.lock(&room_id).await; - services() + let state_lock = self.services.rooms.state.mutex.lock(&room_id).await; + self.services .rooms .state .force_state(room_id.clone().as_ref(), short_state_hash, new, removed, &state_lock) @@ -598,7 +637,10 @@ pub(super) async fn force_set_room_state_from_server( "Updating joined counts for room just in case (e.g. 
we may have found a difference in the room's \ m.room.member state" ); - services().rooms.state_cache.update_joined_count(&room_id)?; + self.services + .rooms + .state_cache + .update_joined_count(&room_id)?; drop(state_lock); @@ -607,28 +649,30 @@ pub(super) async fn force_set_room_state_from_server( )) } +#[admin_command] pub(super) async fn get_signing_keys( - _body: &[&str], server_name: Option<Box<ServerName>>, _cached: bool, + &self, server_name: Option<Box<ServerName>>, _cached: bool, ) -> Result<RoomMessageEventContent> { - let server_name = server_name.unwrap_or_else(|| services().server.config.server_name.clone().into()); - let signing_keys = services().globals.signing_keys_for(&server_name)?; + let server_name = server_name.unwrap_or_else(|| self.services.server.config.server_name.clone().into()); + let signing_keys = self.services.globals.signing_keys_for(&server_name)?; Ok(RoomMessageEventContent::notice_markdown(format!( "```rs\n{signing_keys:#?}\n```" ))) } +#[admin_command] #[allow(dead_code)] pub(super) async fn get_verify_keys( - _body: &[&str], server_name: Option<Box<ServerName>>, cached: bool, + &self, server_name: Option<Box<ServerName>>, cached: bool, ) -> Result<RoomMessageEventContent> { - let server_name = server_name.unwrap_or_else(|| services().server.config.server_name.clone().into()); + let server_name = server_name.unwrap_or_else(|| self.services.server.config.server_name.clone().into()); let mut out = String::new(); if cached { writeln!(out, "| Key ID | VerifyKey |")?; writeln!(out, "| --- | --- |")?; - for (key_id, verify_key) in services().globals.verify_keys_for(&server_name)? { + for (key_id, verify_key) in self.services.globals.verify_keys_for(&server_name)? 
{ writeln!(out, "| {key_id} | {verify_key:?} |")?; } @@ -636,7 +680,8 @@ pub(super) async fn get_verify_keys( } let signature_ids: Vec<String> = Vec::new(); - let keys = services() + let keys = self + .services .rooms .event_handler .fetch_signing_keys_for_server(&server_name, signature_ids) @@ -651,16 +696,17 @@ pub(super) async fn get_verify_keys( Ok(RoomMessageEventContent::notice_markdown(out)) } +#[admin_command] pub(super) async fn resolve_true_destination( - _body: &[&str], server_name: Box<ServerName>, no_cache: bool, + &self, server_name: Box<ServerName>, no_cache: bool, ) -> Result<RoomMessageEventContent> { - if !services().globals.config.allow_federation { + if !self.services.globals.config.allow_federation { return Ok(RoomMessageEventContent::text_plain( "Federation is disabled on this homeserver.", )); } - if server_name == services().globals.config.server_name { + if server_name == self.services.globals.config.server_name { return Ok(RoomMessageEventContent::text_plain( "Not allowed to send federation requests to ourselves. 
Please use `get-pdu` for fetching local PDUs.", )); @@ -672,12 +718,13 @@ pub(super) async fn resolve_true_destination( && matches!(data.span_name(), "actual" | "well-known" | "srv") }; - let state = &services().server.log.capture; + let state = &self.services.server.log.capture; let logs = Arc::new(Mutex::new(String::new())); let capture = Capture::new(state, Some(filter), capture::fmt_markdown(logs.clone())); let capture_scope = capture.start(); - let actual = services() + let actual = self + .services .resolver .resolve_actual_dest(&server_name, !no_cache) .await?; @@ -692,7 +739,8 @@ pub(super) async fn resolve_true_destination( Ok(RoomMessageEventContent::text_markdown(msg)) } -pub(super) async fn memory_stats(_body: &[&str]) -> Result<RoomMessageEventContent> { +#[admin_command] +pub(super) async fn memory_stats(&self) -> Result<RoomMessageEventContent> { let html_body = conduit::alloc::memory_stats(); if html_body.is_none() { @@ -708,8 +756,9 @@ pub(super) async fn memory_stats(_body: &[&str]) -> Result<RoomMessageEventConte } #[cfg(tokio_unstable)] -pub(super) async fn runtime_metrics(_body: &[&str]) -> Result<RoomMessageEventContent> { - let out = services().server.metrics.runtime_metrics().map_or_else( +#[admin_command] +pub(super) async fn runtime_metrics(&self) -> Result<RoomMessageEventContent> { + let out = self.services.server.metrics.runtime_metrics().map_or_else( || "Runtime metrics are not available.".to_owned(), |metrics| format!("```rs\n{metrics:#?}\n```"), ); @@ -718,15 +767,17 @@ pub(super) async fn runtime_metrics(_body: &[&str]) -> Result<RoomMessageEventCo } #[cfg(not(tokio_unstable))] -pub(super) async fn runtime_metrics(_body: &[&str]) -> Result<RoomMessageEventContent> { +#[admin_command] +pub(super) async fn runtime_metrics(&self) -> Result<RoomMessageEventContent> { Ok(RoomMessageEventContent::text_markdown( "Runtime metrics require building with `tokio_unstable`.", )) } #[cfg(tokio_unstable)] -pub(super) async fn 
runtime_interval(_body: &[&str]) -> Result<RoomMessageEventContent> { - let out = services().server.metrics.runtime_interval().map_or_else( +#[admin_command] +pub(super) async fn runtime_interval(&self) -> Result<RoomMessageEventContent> { + let out = self.services.server.metrics.runtime_interval().map_or_else( || "Runtime metrics are not available.".to_owned(), |metrics| format!("```rs\n{metrics:#?}\n```"), ); @@ -735,18 +786,21 @@ pub(super) async fn runtime_interval(_body: &[&str]) -> Result<RoomMessageEventC } #[cfg(not(tokio_unstable))] -pub(super) async fn runtime_interval(_body: &[&str]) -> Result<RoomMessageEventContent> { +#[admin_command] +pub(super) async fn runtime_interval(&self) -> Result<RoomMessageEventContent> { Ok(RoomMessageEventContent::text_markdown( "Runtime metrics require building with `tokio_unstable`.", )) } -pub(super) async fn time(_body: &[&str]) -> Result<RoomMessageEventContent> { +#[admin_command] +pub(super) async fn time(&self) -> Result<RoomMessageEventContent> { let now = SystemTime::now(); Ok(RoomMessageEventContent::text_markdown(utils::time::format(now, "%+"))) } -pub(super) async fn list_dependencies(_body: &[&str], names: bool) -> Result<RoomMessageEventContent> { +#[admin_command] +pub(super) async fn list_dependencies(&self, names: bool) -> Result<RoomMessageEventContent> { if names { let out = info::cargo::dependencies_names().join(" "); return Ok(RoomMessageEventContent::notice_markdown(out)); diff --git a/src/admin/debug/mod.rs b/src/admin/debug/mod.rs index babd6047d..fbe6fd264 100644 --- a/src/admin/debug/mod.rs +++ b/src/admin/debug/mod.rs @@ -3,10 +3,10 @@ use clap::Subcommand; use conduit::Result; -use conduit_macros::admin_command_dispatch; -use ruma::{events::room::message::RoomMessageEventContent, EventId, OwnedRoomOrAliasId, RoomId, ServerName}; +use ruma::{EventId, OwnedRoomOrAliasId, RoomId, ServerName}; -use self::{commands::*, tester::TesterCommand}; +use self::tester::TesterCommand; +use 
crate::admin_command_dispatch; #[admin_command_dispatch] #[derive(Debug, Subcommand)] diff --git a/src/admin/debug/tester.rs b/src/admin/debug/tester.rs index 2765a344d..af4ea2dca 100644 --- a/src/admin/debug/tester.rs +++ b/src/admin/debug/tester.rs @@ -1,33 +1,29 @@ use ruma::events::room::message::RoomMessageEventContent; -use crate::Result; +use crate::{admin_command, admin_command_dispatch, Result}; +#[admin_command_dispatch] #[derive(Debug, clap::Subcommand)] pub(crate) enum TesterCommand { Tester, Timer, } -pub(super) async fn process(command: TesterCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> { - match command { - TesterCommand::Tester => tester(body).await, - TesterCommand::Timer => timer(body).await, - } -} - #[inline(never)] #[rustfmt::skip] #[allow(unused_variables)] -async fn tester(body: Vec<&str>) -> Result<RoomMessageEventContent> { +#[admin_command] +async fn tester(&self) -> Result<RoomMessageEventContent> { Ok(RoomMessageEventContent::notice_plain("completed")) } #[inline(never)] #[rustfmt::skip] -async fn timer(body: Vec<&str>) -> Result<RoomMessageEventContent> { +#[admin_command] +async fn timer(&self) -> Result<RoomMessageEventContent> { let started = std::time::Instant::now(); - timed(&body); + timed(self.body); let elapsed = started.elapsed(); Ok(RoomMessageEventContent::notice_plain(format!("completed in {elapsed:#?}"))) diff --git a/src/admin/federation/commands.rs b/src/admin/federation/commands.rs index d6ecd3f7c..8917a46b9 100644 --- a/src/admin/federation/commands.rs +++ b/src/admin/federation/commands.rs @@ -1,21 +1,26 @@ use std::fmt::Write; +use conduit::Result; use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomId, ServerName, UserId}; -use crate::{escape_html, get_room_info, services, Result}; +use crate::{admin_command, escape_html, get_room_info}; -pub(super) async fn disable_room(_body: Vec<&str>, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { - 
services().rooms.metadata.disable_room(&room_id, true)?; +#[admin_command] +pub(super) async fn disable_room(&self, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { + self.services.rooms.metadata.disable_room(&room_id, true)?; Ok(RoomMessageEventContent::text_plain("Room disabled.")) } -pub(super) async fn enable_room(_body: Vec<&str>, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { - services().rooms.metadata.disable_room(&room_id, false)?; +#[admin_command] +pub(super) async fn enable_room(&self, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { + self.services.rooms.metadata.disable_room(&room_id, false)?; Ok(RoomMessageEventContent::text_plain("Room enabled.")) } -pub(super) async fn incoming_federation(_body: Vec<&str>) -> Result<RoomMessageEventContent> { - let map = services() +#[admin_command] +pub(super) async fn incoming_federation(&self) -> Result<RoomMessageEventContent> { + let map = self + .services .rooms .event_handler .federation_handletime @@ -31,10 +36,10 @@ pub(super) async fn incoming_federation(_body: Vec<&str>) -> Result<RoomMessageE Ok(RoomMessageEventContent::text_plain(&msg)) } -pub(super) async fn fetch_support_well_known( - _body: Vec<&str>, server_name: Box<ServerName>, -) -> Result<RoomMessageEventContent> { - let response = services() +#[admin_command] +pub(super) async fn fetch_support_well_known(&self, server_name: Box<ServerName>) -> Result<RoomMessageEventContent> { + let response = self + .services .client .default .get(format!("https://{server_name}/.well-known/matrix/support")) @@ -72,25 +77,27 @@ pub(super) async fn fetch_support_well_known( ))) } -pub(super) async fn remote_user_in_rooms(_body: Vec<&str>, user_id: Box<UserId>) -> Result<RoomMessageEventContent> { - if user_id.server_name() == services().globals.config.server_name { +#[admin_command] +pub(super) async fn remote_user_in_rooms(&self, user_id: Box<UserId>) -> Result<RoomMessageEventContent> { + if user_id.server_name() == 
self.services.globals.config.server_name { return Ok(RoomMessageEventContent::text_plain( "User belongs to our server, please use `list-joined-rooms` user admin command instead.", )); } - if !services().users.exists(&user_id)? { + if !self.services.users.exists(&user_id)? { return Ok(RoomMessageEventContent::text_plain( "Remote user does not exist in our database.", )); } - let mut rooms: Vec<(OwnedRoomId, u64, String)> = services() + let mut rooms: Vec<(OwnedRoomId, u64, String)> = self + .services .rooms .state_cache .rooms_joined(&user_id) .filter_map(Result::ok) - .map(|room_id| get_room_info(services(), &room_id)) + .map(|room_id| get_room_info(self.services, &room_id)) .collect(); if rooms.is_empty() { diff --git a/src/admin/federation/mod.rs b/src/admin/federation/mod.rs index d02b42956..8f5d3fae5 100644 --- a/src/admin/federation/mod.rs +++ b/src/admin/federation/mod.rs @@ -2,10 +2,11 @@ use clap::Subcommand; use conduit::Result; -use ruma::{events::room::message::RoomMessageEventContent, RoomId, ServerName, UserId}; +use ruma::{RoomId, ServerName, UserId}; -use self::commands::*; +use crate::admin_command_dispatch; +#[admin_command_dispatch] #[derive(Debug, Subcommand)] pub(super) enum FederationCommand { /// - List all rooms we are currently handling an incoming pdu from @@ -39,21 +40,3 @@ pub(super) enum FederationCommand { user_id: Box<UserId>, }, } - -pub(super) async fn process(command: FederationCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> { - Ok(match command { - FederationCommand::DisableRoom { - room_id, - } => disable_room(body, room_id).await?, - FederationCommand::EnableRoom { - room_id, - } => enable_room(body, room_id).await?, - FederationCommand::IncomingFederation => incoming_federation(body).await?, - FederationCommand::FetchSupportWellKnown { - server_name, - } => fetch_support_well_known(body, server_name).await?, - FederationCommand::RemoteUserInRooms { - user_id, - } => remote_user_in_rooms(body, user_id).await?, - }) 
-} diff --git a/src/admin/handler.rs b/src/admin/handler.rs index 26a5ea419..32360c855 100644 --- a/src/admin/handler.rs +++ b/src/admin/handler.rs @@ -15,10 +15,10 @@ Services, }; -use crate::{admin, admin::AdminCommand}; +use crate::{admin, admin::AdminCommand, Command}; -struct Handler<'a> { - services: &'a Services, +struct Handler { + services: &'static Services, } #[must_use] @@ -68,13 +68,12 @@ fn reply(mut content: RoomMessageEventContent, reply_id: Option<OwnedEventId>) - Some(content) } -impl Handler<'_> { +impl Handler { // Parse and process a message from the admin room async fn process(&self, msg: &str) -> CommandOutput { let mut lines = msg.lines().filter(|l| !l.trim().is_empty()); let command = lines.next().expect("each string has at least one line"); - let body = lines.collect::<Vec<_>>(); - let parsed = match self.parse_command(command) { + let (parsed, body) = match self.parse_command(command) { Ok(parsed) => parsed, Err(error) => { let server_name = self.services.globals.server_name(); @@ -84,7 +83,12 @@ async fn process(&self, msg: &str) -> CommandOutput { }; let timer = Instant::now(); - let result = Box::pin(admin::process(parsed, body)).await; + let body: Vec<&str> = body.iter().map(String::as_str).collect(); + let context = Command { + services: self.services, + body: &body, + }; + let result = Box::pin(admin::process(parsed, &context)).await; let elapsed = timer.elapsed(); conduit::debug!(?command, ok = result.is_ok(), "command processed in {elapsed:?}"); match result { @@ -96,9 +100,10 @@ async fn process(&self, msg: &str) -> CommandOutput { } // Parse chat messages from the admin room into an AdminCommand object - fn parse_command(&self, command_line: &str) -> Result<AdminCommand, String> { + fn parse_command(&self, command_line: &str) -> Result<(AdminCommand, Vec<String>), String> { let argv = self.parse_line(command_line); - AdminCommand::try_parse_from(argv).map_err(|error| error.to_string()) + let com = 
AdminCommand::try_parse_from(&argv).map_err(|error| error.to_string())?; + Ok((com, argv)) } fn complete_command(&self, mut cmd: clap::Command, line: &str) -> String { diff --git a/src/admin/media/commands.rs b/src/admin/media/commands.rs index d29d5f47d..7906d951b 100644 --- a/src/admin/media/commands.rs +++ b/src/admin/media/commands.rs @@ -1,11 +1,11 @@ -use conduit::Result; +use conduit::{debug, info, Result}; use ruma::{events::room::message::RoomMessageEventContent, EventId, MxcUri}; -use tracing::{debug, info}; -use crate::services; +use crate::admin_command; +#[admin_command] pub(super) async fn delete( - _body: Vec<&str>, mxc: Option<Box<MxcUri>>, event_id: Option<Box<EventId>>, + &self, mxc: Option<Box<MxcUri>>, event_id: Option<Box<EventId>>, ) -> Result<RoomMessageEventContent> { if event_id.is_some() && mxc.is_some() { return Ok(RoomMessageEventContent::text_plain( @@ -15,7 +15,7 @@ pub(super) async fn delete( if let Some(mxc) = mxc { debug!("Got MXC URL: {mxc}"); - services().media.delete(mxc.as_ref()).await?; + self.services.media.delete(mxc.as_ref()).await?; return Ok(RoomMessageEventContent::text_plain( "Deleted the MXC from our database and on our filesystem.", @@ -27,7 +27,7 @@ pub(super) async fn delete( let mut mxc_deletion_count: usize = 0; // parsing the PDU for any MXC URLs begins here - if let Some(event_json) = services().rooms.timeline.get_pdu_json(&event_id)? { + if let Some(event_json) = self.services.rooms.timeline.get_pdu_json(&event_id)? 
{ if let Some(content_key) = event_json.get("content") { debug!("Event ID has \"content\"."); let content_obj = content_key.as_object(); @@ -123,7 +123,7 @@ pub(super) async fn delete( } for mxc_url in mxc_urls { - services().media.delete(&mxc_url).await?; + self.services.media.delete(&mxc_url).await?; mxc_deletion_count = mxc_deletion_count.saturating_add(1); } @@ -138,23 +138,26 @@ pub(super) async fn delete( )) } -pub(super) async fn delete_list(body: Vec<&str>) -> Result<RoomMessageEventContent> { - if body.len() < 2 || !body[0].trim().starts_with("```") || body.last().unwrap_or(&"").trim() != "```" { +#[admin_command] +pub(super) async fn delete_list(&self) -> Result<RoomMessageEventContent> { + if self.body.len() < 2 || !self.body[0].trim().starts_with("```") || self.body.last().unwrap_or(&"").trim() != "```" + { return Ok(RoomMessageEventContent::text_plain( "Expected code block in command body. Add --help for details.", )); } - let mxc_list = body - .clone() - .drain(1..body.len().checked_sub(1).unwrap()) + let mxc_list = self + .body + .to_vec() + .drain(1..self.body.len().checked_sub(1).unwrap()) .collect::<Vec<_>>(); let mut mxc_deletion_count: usize = 0; for mxc in mxc_list { debug!("Deleting MXC {mxc} in bulk"); - services().media.delete(mxc).await?; + self.services.media.delete(mxc).await?; mxc_deletion_count = mxc_deletion_count .checked_add(1) .expect("mxc_deletion_count should not get this high"); @@ -165,10 +168,10 @@ pub(super) async fn delete_list(body: Vec<&str>) -> Result<RoomMessageEventConte ))) } -pub(super) async fn delete_past_remote_media( - _body: Vec<&str>, duration: String, force: bool, -) -> Result<RoomMessageEventContent> { - let deleted_count = services() +#[admin_command] +pub(super) async fn delete_past_remote_media(&self, duration: String, force: bool) -> Result<RoomMessageEventContent> { + let deleted_count = self + .services .media .delete_all_remote_media_at_after_time(duration, force) .await?; diff --git 
a/src/admin/media/mod.rs b/src/admin/media/mod.rs index d30c55d0b..31cbf810e 100644 --- a/src/admin/media/mod.rs +++ b/src/admin/media/mod.rs @@ -2,10 +2,11 @@ use clap::Subcommand; use conduit::Result; -use ruma::{events::room::message::RoomMessageEventContent, EventId, MxcUri}; +use ruma::{EventId, MxcUri}; -use self::commands::*; +use crate::admin_command_dispatch; +#[admin_command_dispatch] #[derive(Debug, Subcommand)] pub(super) enum MediaCommand { /// - Deletes a single media file from our database and on the filesystem @@ -36,17 +37,3 @@ pub(super) enum MediaCommand { force: bool, }, } - -pub(super) async fn process(command: MediaCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> { - Ok(match command { - MediaCommand::Delete { - mxc, - event_id, - } => delete(body, mxc, event_id).await?, - MediaCommand::DeleteList => delete_list(body).await?, - MediaCommand::DeletePastRemoteMedia { - duration, - force, - } => delete_past_remote_media(body, duration, force).await?, - }) -} diff --git a/src/admin/mod.rs b/src/admin/mod.rs index cd1110ee3..5d4c8f5e4 100644 --- a/src/admin/mod.rs +++ b/src/admin/mod.rs @@ -3,6 +3,7 @@ #![allow(clippy::enum_glob_use)] pub(crate) mod admin; +pub(crate) mod command; pub(crate) mod handler; mod tests; pub(crate) mod utils; @@ -22,9 +23,13 @@ extern crate conduit_service as service; pub(crate) use conduit::Result; +pub(crate) use conduit_macros::{admin_command, admin_command_dispatch}; pub(crate) use service::services; -pub(crate) use crate::utils::{escape_html, get_room_info}; +pub(crate) use crate::{ + command::Command, + utils::{escape_html, get_room_info}, +}; pub(crate) const PAGE_SIZE: usize = 100; diff --git a/src/admin/query/account_data.rs b/src/admin/query/account_data.rs index ce9caedb4..e18c298a3 100644 --- a/src/admin/query/account_data.rs +++ b/src/admin/query/account_data.rs @@ -1,18 +1,48 @@ -use ruma::events::room::message::RoomMessageEventContent; +use clap::Subcommand; +use conduit::Result; +use ruma::{ + 
events::{room::message::RoomMessageEventContent, RoomAccountDataEventType}, + RoomId, UserId, +}; -use super::AccountData; -use crate::{services, Result}; +use crate::Command; +#[derive(Debug, Subcommand)] /// All the getters and iterators from src/database/key_value/account_data.rs -pub(super) async fn account_data(subcommand: AccountData) -> Result<RoomMessageEventContent> { +pub(crate) enum AccountDataCommand { + /// - Returns all changes to the account data that happened after `since`. + ChangesSince { + /// Full user ID + user_id: Box<UserId>, + /// UNIX timestamp since (u64) + since: u64, + /// Optional room ID of the account data + room_id: Option<Box<RoomId>>, + }, + + /// - Searches the account data for a specific kind. + Get { + /// Full user ID + user_id: Box<UserId>, + /// Account data event type + kind: RoomAccountDataEventType, + /// Optional room ID of the account data + room_id: Option<Box<RoomId>>, + }, +} + +/// All the getters and iterators from src/database/key_value/account_data.rs +pub(super) async fn process(subcommand: AccountDataCommand, context: &Command<'_>) -> Result<RoomMessageEventContent> { + let services = context.services; + match subcommand { - AccountData::ChangesSince { + AccountDataCommand::ChangesSince { user_id, since, room_id, } => { let timer = tokio::time::Instant::now(); - let results = services() + let results = services .account_data .changes_since(room_id.as_deref(), &user_id, since)?; let query_time = timer.elapsed(); @@ -21,13 +51,13 @@ pub(super) async fn account_data(subcommand: AccountData) -> Result<RoomMessageE "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" ))) }, - AccountData::Get { + AccountDataCommand::Get { user_id, kind, room_id, } => { let timer = tokio::time::Instant::now(); - let results = services() + let results = services .account_data .get(room_id.as_deref(), &user_id, kind)?; let query_time = timer.elapsed(); diff --git a/src/admin/query/appservice.rs 
b/src/admin/query/appservice.rs index 77ac96419..683c228f7 100644 --- a/src/admin/query/appservice.rs +++ b/src/admin/query/appservice.rs @@ -1,16 +1,32 @@ +use clap::Subcommand; +use conduit::Result; use ruma::events::room::message::RoomMessageEventContent; -use super::Appservice; -use crate::{services, Result}; +use crate::Command; +#[derive(Debug, Subcommand)] /// All the getters and iterators from src/database/key_value/appservice.rs -pub(super) async fn appservice(subcommand: Appservice) -> Result<RoomMessageEventContent> { +pub(crate) enum AppserviceCommand { + /// - Gets the appservice registration info/details from the ID as a string + GetRegistration { + /// Appservice registration ID + appservice_id: Box<str>, + }, + + /// - Gets all appservice registrations with their ID and registration info + All, +} + +/// All the getters and iterators from src/database/key_value/appservice.rs +pub(super) async fn process(subcommand: AppserviceCommand, context: &Command<'_>) -> Result<RoomMessageEventContent> { + let services = context.services; + match subcommand { - Appservice::GetRegistration { + AppserviceCommand::GetRegistration { appservice_id, } => { let timer = tokio::time::Instant::now(); - let results = services() + let results = services .appservice .db .get_registration(appservice_id.as_ref()); @@ -20,9 +36,9 @@ pub(super) async fn appservice(subcommand: Appservice) -> Result<RoomMessageEven "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" ))) }, - Appservice::All => { + AppserviceCommand::All => { let timer = tokio::time::Instant::now(); - let results = services().appservice.all(); + let results = services.appservice.all(); let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/globals.rs b/src/admin/query/globals.rs index 389860711..5f271c2c4 100644 --- a/src/admin/query/globals.rs +++ b/src/admin/query/globals.rs @@ -1,52 +1,73 @@ -use 
ruma::events::room::message::RoomMessageEventContent; +use clap::Subcommand; +use conduit::Result; +use ruma::{events::room::message::RoomMessageEventContent, ServerName}; -use super::Globals; -use crate::{services, Result}; +use crate::Command; +#[derive(Debug, Subcommand)] /// All the getters and iterators from src/database/key_value/globals.rs -pub(super) async fn globals(subcommand: Globals) -> Result<RoomMessageEventContent> { +pub(crate) enum GlobalsCommand { + DatabaseVersion, + + CurrentCount, + + LastCheckForUpdatesId, + + LoadKeypair, + + /// - This returns an empty `Ok(BTreeMap<..>)` when there are no keys found + /// for the server. + SigningKeysFor { + origin: Box<ServerName>, + }, +} + +/// All the getters and iterators from src/database/key_value/globals.rs +pub(super) async fn process(subcommand: GlobalsCommand, context: &Command<'_>) -> Result<RoomMessageEventContent> { + let services = context.services; + match subcommand { - Globals::DatabaseVersion => { + GlobalsCommand::DatabaseVersion => { let timer = tokio::time::Instant::now(); - let results = services().globals.db.database_version(); + let results = services.globals.db.database_version(); let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" ))) }, - Globals::CurrentCount => { + GlobalsCommand::CurrentCount => { let timer = tokio::time::Instant::now(); - let results = services().globals.db.current_count(); + let results = services.globals.db.current_count(); let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" ))) }, - Globals::LastCheckForUpdatesId => { + GlobalsCommand::LastCheckForUpdatesId => { let timer = tokio::time::Instant::now(); - let results = services().updates.last_check_for_updates_id(); + let results = services.updates.last_check_for_updates_id(); let query_time = timer.elapsed(); 
Ok(RoomMessageEventContent::notice_markdown(format!( "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" ))) }, - Globals::LoadKeypair => { + GlobalsCommand::LoadKeypair => { let timer = tokio::time::Instant::now(); - let results = services().globals.db.load_keypair(); + let results = services.globals.db.load_keypair(); let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" ))) }, - Globals::SigningKeysFor { + GlobalsCommand::SigningKeysFor { origin, } => { let timer = tokio::time::Instant::now(); - let results = services().globals.db.verify_keys_for(&origin); + let results = services.globals.db.verify_keys_for(&origin); let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/mod.rs b/src/admin/query/mod.rs index c86f4f538..1aa28c48b 100644 --- a/src/admin/query/mod.rs +++ b/src/admin/query/mod.rs @@ -10,304 +10,51 @@ use clap::Subcommand; use conduit::Result; -use room_state_cache::room_state_cache; -use ruma::{ - events::{room::message::RoomMessageEventContent, RoomAccountDataEventType}, - OwnedServerName, RoomAliasId, RoomId, ServerName, UserId, -}; use self::{ - account_data::account_data, appservice::appservice, globals::globals, presence::presence, resolver::resolver, - room_alias::room_alias, sending::sending, users::users, + account_data::AccountDataCommand, appservice::AppserviceCommand, globals::GlobalsCommand, + presence::PresenceCommand, resolver::ResolverCommand, room_alias::RoomAliasCommand, + room_state_cache::RoomStateCacheCommand, sending::SendingCommand, users::UsersCommand, }; +use crate::admin_command_dispatch; +#[admin_command_dispatch] #[derive(Debug, Subcommand)] /// Query tables from database pub(super) enum QueryCommand { /// - account_data.rs iterators and getters #[command(subcommand)] - AccountData(AccountData), + AccountData(AccountDataCommand), /// - appservice.rs 
iterators and getters #[command(subcommand)] - Appservice(Appservice), + Appservice(AppserviceCommand), /// - presence.rs iterators and getters #[command(subcommand)] - Presence(Presence), + Presence(PresenceCommand), /// - rooms/alias.rs iterators and getters #[command(subcommand)] - RoomAlias(RoomAlias), + RoomAlias(RoomAliasCommand), /// - rooms/state_cache iterators and getters #[command(subcommand)] - RoomStateCache(RoomStateCache), + RoomStateCache(RoomStateCacheCommand), /// - globals.rs iterators and getters #[command(subcommand)] - Globals(Globals), + Globals(GlobalsCommand), /// - sending.rs iterators and getters #[command(subcommand)] - Sending(Sending), + Sending(SendingCommand), /// - users.rs iterators and getters #[command(subcommand)] - Users(Users), + Users(UsersCommand), /// - resolver service #[command(subcommand)] - Resolver(Resolver), -} - -#[derive(Debug, Subcommand)] -/// All the getters and iterators from src/database/key_value/account_data.rs -pub(super) enum AccountData { - /// - Returns all changes to the account data that happened after `since`. - ChangesSince { - /// Full user ID - user_id: Box<UserId>, - /// UNIX timestamp since (u64) - since: u64, - /// Optional room ID of the account data - room_id: Option<Box<RoomId>>, - }, - - /// - Searches the account data for a specific kind. 
- Get { - /// Full user ID - user_id: Box<UserId>, - /// Account data event type - kind: RoomAccountDataEventType, - /// Optional room ID of the account data - room_id: Option<Box<RoomId>>, - }, -} - -#[derive(Debug, Subcommand)] -/// All the getters and iterators from src/database/key_value/appservice.rs -pub(super) enum Appservice { - /// - Gets the appservice registration info/details from the ID as a string - GetRegistration { - /// Appservice registration ID - appservice_id: Box<str>, - }, - - /// - Gets all appservice registrations with their ID and registration info - All, -} - -#[derive(Debug, Subcommand)] -/// All the getters and iterators from src/database/key_value/presence.rs -pub(super) enum Presence { - /// - Returns the latest presence event for the given user. - GetPresence { - /// Full user ID - user_id: Box<UserId>, - }, - - /// - Iterator of the most recent presence updates that happened after the - /// event with id `since`. - PresenceSince { - /// UNIX timestamp since (u64) - since: u64, - }, -} - -#[derive(Debug, Subcommand)] -/// All the getters and iterators from src/database/key_value/rooms/alias.rs -pub(super) enum RoomAlias { - ResolveLocalAlias { - /// Full room alias - alias: Box<RoomAliasId>, - }, - - /// - Iterator of all our local room aliases for the room ID - LocalAliasesForRoom { - /// Full room ID - room_id: Box<RoomId>, - }, - - /// - Iterator of all our local aliases in our database with their room IDs - AllLocalAliases, -} - -#[derive(Debug, Subcommand)] -pub(super) enum RoomStateCache { - ServerInRoom { - server: Box<ServerName>, - room_id: Box<RoomId>, - }, - - RoomServers { - room_id: Box<RoomId>, - }, - - ServerRooms { - server: Box<ServerName>, - }, - - RoomMembers { - room_id: Box<RoomId>, - }, - - LocalUsersInRoom { - room_id: Box<RoomId>, - }, - - ActiveLocalUsersInRoom { - room_id: Box<RoomId>, - }, - - RoomJoinedCount { - room_id: Box<RoomId>, - }, - - RoomInvitedCount { - room_id: Box<RoomId>, - }, - - 
RoomUserOnceJoined { - room_id: Box<RoomId>, - }, - - RoomMembersInvited { - room_id: Box<RoomId>, - }, - - GetInviteCount { - room_id: Box<RoomId>, - user_id: Box<UserId>, - }, - - GetLeftCount { - room_id: Box<RoomId>, - user_id: Box<UserId>, - }, - - RoomsJoined { - user_id: Box<UserId>, - }, - - RoomsLeft { - user_id: Box<UserId>, - }, - - RoomsInvited { - user_id: Box<UserId>, - }, - - InviteState { - user_id: Box<UserId>, - room_id: Box<RoomId>, - }, -} - -#[derive(Debug, Subcommand)] -/// All the getters and iterators from src/database/key_value/globals.rs -pub(super) enum Globals { - DatabaseVersion, - - CurrentCount, - - LastCheckForUpdatesId, - - LoadKeypair, - - /// - This returns an empty `Ok(BTreeMap<..>)` when there are no keys found - /// for the server. - SigningKeysFor { - origin: Box<ServerName>, - }, -} - -#[derive(Debug, Subcommand)] -/// All the getters and iterators from src/database/key_value/sending.rs -pub(super) enum Sending { - /// - Queries database for all `servercurrentevent_data` - ActiveRequests, - - /// - Queries database for `servercurrentevent_data` but for a specific - /// destination - /// - /// This command takes only *one* format of these arguments: - /// - /// appservice_id - /// server_name - /// user_id AND push_key - /// - /// See src/service/sending/mod.rs for the definition of the `Destination` - /// enum - ActiveRequestsFor { - #[arg(short, long)] - appservice_id: Option<String>, - #[arg(short, long)] - server_name: Option<Box<ServerName>>, - #[arg(short, long)] - user_id: Option<Box<UserId>>, - #[arg(short, long)] - push_key: Option<String>, - }, - - /// - Queries database for `servernameevent_data` which are the queued up - /// requests that will eventually be sent - /// - /// This command takes only *one* format of these arguments: - /// - /// appservice_id - /// server_name - /// user_id AND push_key - /// - /// See src/service/sending/mod.rs for the definition of the `Destination` - /// enum - QueuedRequests { - 
#[arg(short, long)] - appservice_id: Option<String>, - #[arg(short, long)] - server_name: Option<Box<ServerName>>, - #[arg(short, long)] - user_id: Option<Box<UserId>>, - #[arg(short, long)] - push_key: Option<String>, - }, - - GetLatestEduCount { - server_name: Box<ServerName>, - }, -} - -#[derive(Debug, Subcommand)] -/// All the getters and iterators from src/database/key_value/users.rs -pub(super) enum Users { - Iter, -} - -#[derive(Debug, Subcommand)] -/// Resolver service and caches -pub(super) enum Resolver { - /// Query the destinations cache - DestinationsCache { - server_name: Option<OwnedServerName>, - }, - - /// Query the overrides cache - OverridesCache { - name: Option<String>, - }, -} - -/// Processes admin query commands -pub(super) async fn process(command: QueryCommand, _body: Vec<&str>) -> Result<RoomMessageEventContent> { - Ok(match command { - QueryCommand::AccountData(command) => account_data(command).await?, - QueryCommand::Appservice(command) => appservice(command).await?, - QueryCommand::Presence(command) => presence(command).await?, - QueryCommand::RoomAlias(command) => room_alias(command).await?, - QueryCommand::RoomStateCache(command) => room_state_cache(command).await?, - QueryCommand::Globals(command) => globals(command).await?, - QueryCommand::Sending(command) => sending(command).await?, - QueryCommand::Users(command) => users(command).await?, - QueryCommand::Resolver(command) => resolver(command).await?, - }) + Resolver(ResolverCommand), } diff --git a/src/admin/query/presence.rs b/src/admin/query/presence.rs index c47b7a51d..145ecd9b1 100644 --- a/src/admin/query/presence.rs +++ b/src/admin/query/presence.rs @@ -1,27 +1,47 @@ -use ruma::events::room::message::RoomMessageEventContent; +use clap::Subcommand; +use conduit::Result; +use ruma::{events::room::message::RoomMessageEventContent, UserId}; -use super::Presence; -use crate::{services, Result}; +use crate::Command; + +#[derive(Debug, Subcommand)] +/// All the getters and 
iterators from src/database/key_value/presence.rs +pub(crate) enum PresenceCommand { + /// - Returns the latest presence event for the given user. + GetPresence { + /// Full user ID + user_id: Box<UserId>, + }, + + /// - Iterator of the most recent presence updates that happened after the + /// event with id `since`. + PresenceSince { + /// UNIX timestamp since (u64) + since: u64, + }, +} /// All the getters and iterators in key_value/presence.rs -pub(super) async fn presence(subcommand: Presence) -> Result<RoomMessageEventContent> { +pub(super) async fn process(subcommand: PresenceCommand, context: &Command<'_>) -> Result<RoomMessageEventContent> { + let services = context.services; + match subcommand { - Presence::GetPresence { + PresenceCommand::GetPresence { user_id, } => { let timer = tokio::time::Instant::now(); - let results = services().presence.db.get_presence(&user_id)?; + let results = services.presence.db.get_presence(&user_id)?; let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" ))) }, - Presence::PresenceSince { + PresenceCommand::PresenceSince { since, } => { let timer = tokio::time::Instant::now(); - let results = services().presence.db.presence_since(since); + let results = services.presence.db.presence_since(since); let presence_since: Vec<(_, _, _)> = results.collect(); let query_time = timer.elapsed(); diff --git a/src/admin/query/resolver.rs b/src/admin/query/resolver.rs index 37d179609..e8340dad8 100644 --- a/src/admin/query/resolver.rs +++ b/src/admin/query/resolver.rs @@ -1,24 +1,28 @@ use std::fmt::Write; +use clap::Subcommand; use conduit::{utils::time, Result}; use ruma::{events::room::message::RoomMessageEventContent, OwnedServerName}; -use super::Resolver; -use crate::services; +use crate::{admin_command, admin_command_dispatch}; -/// All the getters and iterators in key_value/users.rs -pub(super) async fn resolver(subcommand: Resolver) 
-> Result<RoomMessageEventContent> { - match subcommand { - Resolver::DestinationsCache { - server_name, - } => destinations_cache(server_name).await, - Resolver::OverridesCache { - name, - } => overrides_cache(name).await, - } +#[admin_command_dispatch] +#[derive(Debug, Subcommand)] +/// Resolver service and caches +pub(crate) enum ResolverCommand { + /// Query the destinations cache + DestinationsCache { + server_name: Option<OwnedServerName>, + }, + + /// Query the overrides cache + OverridesCache { + name: Option<String>, + }, } -async fn destinations_cache(server_name: Option<OwnedServerName>) -> Result<RoomMessageEventContent> { +#[admin_command] +async fn destinations_cache(&self, server_name: Option<OwnedServerName>) -> Result<RoomMessageEventContent> { use service::resolver::cache::CachedDest; let mut out = String::new(); @@ -36,7 +40,8 @@ async fn destinations_cache(server_name: Option<OwnedServerName>) -> Result<Room writeln!(out, "| {name} | {dest} | {host} | {expire} |").expect("wrote line"); }; - let map = services() + let map = self + .services .resolver .cache .destinations @@ -52,7 +57,8 @@ async fn destinations_cache(server_name: Option<OwnedServerName>) -> Result<Room Ok(RoomMessageEventContent::notice_markdown(out)) } -async fn overrides_cache(server_name: Option<String>) -> Result<RoomMessageEventContent> { +#[admin_command] +async fn overrides_cache(&self, server_name: Option<String>) -> Result<RoomMessageEventContent> { use service::resolver::cache::CachedOverride; let mut out = String::new(); @@ -70,7 +76,13 @@ async fn overrides_cache(server_name: Option<String>) -> Result<RoomMessageEvent writeln!(out, "| {name} | {ips:?} | {port} | {expire} |").expect("wrote line"); }; - let map = services().resolver.cache.overrides.read().expect("locked"); + let map = self + .services + .resolver + .cache + .overrides + .read() + .expect("locked"); if let Some(server_name) = server_name.as_ref() { map.get_key_value(server_name).map(row); diff --git 
a/src/admin/query/room_alias.rs b/src/admin/query/room_alias.rs index d2c16801d..1809e26a0 100644 --- a/src/admin/query/room_alias.rs +++ b/src/admin/query/room_alias.rs @@ -1,27 +1,48 @@ -use ruma::events::room::message::RoomMessageEventContent; +use clap::Subcommand; +use conduit::Result; +use ruma::{events::room::message::RoomMessageEventContent, RoomAliasId, RoomId}; -use super::RoomAlias; -use crate::{services, Result}; +use crate::Command; + +#[derive(Debug, Subcommand)] +/// All the getters and iterators from src/database/key_value/rooms/alias.rs +pub(crate) enum RoomAliasCommand { + ResolveLocalAlias { + /// Full room alias + alias: Box<RoomAliasId>, + }, + + /// - Iterator of all our local room aliases for the room ID + LocalAliasesForRoom { + /// Full room ID + room_id: Box<RoomId>, + }, + + /// - Iterator of all our local aliases in our database with their room IDs + AllLocalAliases, +} /// All the getters and iterators in src/database/key_value/rooms/alias.rs -pub(super) async fn room_alias(subcommand: RoomAlias) -> Result<RoomMessageEventContent> { +pub(super) async fn process(subcommand: RoomAliasCommand, context: &Command<'_>) -> Result<RoomMessageEventContent> { + let services = context.services; + match subcommand { - RoomAlias::ResolveLocalAlias { + RoomAliasCommand::ResolveLocalAlias { alias, } => { let timer = tokio::time::Instant::now(); - let results = services().rooms.alias.resolve_local_alias(&alias); + let results = services.rooms.alias.resolve_local_alias(&alias); let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" ))) }, - RoomAlias::LocalAliasesForRoom { + RoomAliasCommand::LocalAliasesForRoom { room_id, } => { let timer = tokio::time::Instant::now(); - let results = services().rooms.alias.local_aliases_for_room(&room_id); + let results = services.rooms.alias.local_aliases_for_room(&room_id); let aliases: Vec<_> = results.collect(); let 
query_time = timer.elapsed(); @@ -29,9 +50,9 @@ pub(super) async fn room_alias(subcommand: RoomAlias) -> Result<RoomMessageEvent "Query completed in {query_time:?}:\n\n```rs\n{aliases:#?}\n```" ))) }, - RoomAlias::AllLocalAliases => { + RoomAliasCommand::AllLocalAliases => { let timer = tokio::time::Instant::now(); - let results = services().rooms.alias.all_local_aliases(); + let results = services.rooms.alias.all_local_aliases(); let aliases: Vec<_> = results.collect(); let query_time = timer.elapsed(); diff --git a/src/admin/query/room_state_cache.rs b/src/admin/query/room_state_cache.rs index aed2b4a26..4215cf8d6 100644 --- a/src/admin/query/room_state_cache.rs +++ b/src/admin/query/room_state_cache.rs @@ -1,71 +1,136 @@ -use ruma::events::room::message::RoomMessageEventContent; +use clap::Subcommand; +use conduit::Result; +use ruma::{events::room::message::RoomMessageEventContent, RoomId, ServerName, UserId}; -use super::RoomStateCache; -use crate::{services, Result}; +use crate::Command; + +#[derive(Debug, Subcommand)] +pub(crate) enum RoomStateCacheCommand { + ServerInRoom { + server: Box<ServerName>, + room_id: Box<RoomId>, + }, + + RoomServers { + room_id: Box<RoomId>, + }, + + ServerRooms { + server: Box<ServerName>, + }, + + RoomMembers { + room_id: Box<RoomId>, + }, + + LocalUsersInRoom { + room_id: Box<RoomId>, + }, + + ActiveLocalUsersInRoom { + room_id: Box<RoomId>, + }, + + RoomJoinedCount { + room_id: Box<RoomId>, + }, + + RoomInvitedCount { + room_id: Box<RoomId>, + }, + + RoomUserOnceJoined { + room_id: Box<RoomId>, + }, + + RoomMembersInvited { + room_id: Box<RoomId>, + }, + + GetInviteCount { + room_id: Box<RoomId>, + user_id: Box<UserId>, + }, + + GetLeftCount { + room_id: Box<RoomId>, + user_id: Box<UserId>, + }, + + RoomsJoined { + user_id: Box<UserId>, + }, + + RoomsLeft { + user_id: Box<UserId>, + }, + + RoomsInvited { + user_id: Box<UserId>, + }, + + InviteState { + user_id: Box<UserId>, + room_id: Box<RoomId>, + }, +} + +pub(super) async 
fn process( + subcommand: RoomStateCacheCommand, context: &Command<'_>, +) -> Result<RoomMessageEventContent> { + let services = context.services; -pub(super) async fn room_state_cache(subcommand: RoomStateCache) -> Result<RoomMessageEventContent> { match subcommand { - RoomStateCache::ServerInRoom { + RoomStateCacheCommand::ServerInRoom { server, room_id, } => { let timer = tokio::time::Instant::now(); - let result = services() - .rooms - .state_cache - .server_in_room(&server, &room_id); + let result = services.rooms.state_cache.server_in_room(&server, &room_id); let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( "Query completed in {query_time:?}:\n\n```rs\n{result:#?}\n```" ))) }, - RoomStateCache::RoomServers { + RoomStateCacheCommand::RoomServers { room_id, } => { let timer = tokio::time::Instant::now(); - let results: Result<Vec<_>> = services() - .rooms - .state_cache - .room_servers(&room_id) - .collect(); + let results: Result<Vec<_>> = services.rooms.state_cache.room_servers(&room_id).collect(); let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" ))) }, - RoomStateCache::ServerRooms { + RoomStateCacheCommand::ServerRooms { server, } => { let timer = tokio::time::Instant::now(); - let results: Result<Vec<_>> = services().rooms.state_cache.server_rooms(&server).collect(); + let results: Result<Vec<_>> = services.rooms.state_cache.server_rooms(&server).collect(); let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" ))) }, - RoomStateCache::RoomMembers { + RoomStateCacheCommand::RoomMembers { room_id, } => { let timer = tokio::time::Instant::now(); - let results: Result<Vec<_>> = services() - .rooms - .state_cache - .room_members(&room_id) - .collect(); + let results: Result<Vec<_>> = 
services.rooms.state_cache.room_members(&room_id).collect(); let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" ))) }, - RoomStateCache::LocalUsersInRoom { + RoomStateCacheCommand::LocalUsersInRoom { room_id, } => { let timer = tokio::time::Instant::now(); - let results: Vec<_> = services() + let results: Vec<_> = services .rooms .state_cache .local_users_in_room(&room_id) @@ -76,11 +141,11 @@ pub(super) async fn room_state_cache(subcommand: RoomStateCache) -> Result<RoomM "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" ))) }, - RoomStateCache::ActiveLocalUsersInRoom { + RoomStateCacheCommand::ActiveLocalUsersInRoom { room_id, } => { let timer = tokio::time::Instant::now(); - let results: Vec<_> = services() + let results: Vec<_> = services .rooms .state_cache .active_local_users_in_room(&room_id) @@ -91,33 +156,33 @@ pub(super) async fn room_state_cache(subcommand: RoomStateCache) -> Result<RoomM "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" ))) }, - RoomStateCache::RoomJoinedCount { + RoomStateCacheCommand::RoomJoinedCount { room_id, } => { let timer = tokio::time::Instant::now(); - let results = services().rooms.state_cache.room_joined_count(&room_id); + let results = services.rooms.state_cache.room_joined_count(&room_id); let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" ))) }, - RoomStateCache::RoomInvitedCount { + RoomStateCacheCommand::RoomInvitedCount { room_id, } => { let timer = tokio::time::Instant::now(); - let results = services().rooms.state_cache.room_invited_count(&room_id); + let results = services.rooms.state_cache.room_invited_count(&room_id); let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" ))) }, - 
RoomStateCache::RoomUserOnceJoined { + RoomStateCacheCommand::RoomUserOnceJoined { room_id, } => { let timer = tokio::time::Instant::now(); - let results: Result<Vec<_>> = services() + let results: Result<Vec<_>> = services .rooms .state_cache .room_useroncejoined(&room_id) @@ -128,11 +193,11 @@ pub(super) async fn room_state_cache(subcommand: RoomStateCache) -> Result<RoomM "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" ))) }, - RoomStateCache::RoomMembersInvited { + RoomStateCacheCommand::RoomMembersInvited { room_id, } => { let timer = tokio::time::Instant::now(); - let results: Result<Vec<_>> = services() + let results: Result<Vec<_>> = services .rooms .state_cache .room_members_invited(&room_id) @@ -143,12 +208,12 @@ pub(super) async fn room_state_cache(subcommand: RoomStateCache) -> Result<RoomM "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" ))) }, - RoomStateCache::GetInviteCount { + RoomStateCacheCommand::GetInviteCount { room_id, user_id, } => { let timer = tokio::time::Instant::now(); - let results = services() + let results = services .rooms .state_cache .get_invite_count(&room_id, &user_id); @@ -158,12 +223,12 @@ pub(super) async fn room_state_cache(subcommand: RoomStateCache) -> Result<RoomM "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" ))) }, - RoomStateCache::GetLeftCount { + RoomStateCacheCommand::GetLeftCount { room_id, user_id, } => { let timer = tokio::time::Instant::now(); - let results = services() + let results = services .rooms .state_cache .get_left_count(&room_id, &user_id); @@ -173,56 +238,45 @@ pub(super) async fn room_state_cache(subcommand: RoomStateCache) -> Result<RoomM "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" ))) }, - RoomStateCache::RoomsJoined { + RoomStateCacheCommand::RoomsJoined { user_id, } => { let timer = tokio::time::Instant::now(); - let results: Result<Vec<_>> = services() - .rooms - .state_cache - .rooms_joined(&user_id) - .collect(); + let 
results: Result<Vec<_>> = services.rooms.state_cache.rooms_joined(&user_id).collect(); let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" ))) }, - RoomStateCache::RoomsInvited { + RoomStateCacheCommand::RoomsInvited { user_id, } => { let timer = tokio::time::Instant::now(); - let results: Result<Vec<_>> = services() - .rooms - .state_cache - .rooms_invited(&user_id) - .collect(); + let results: Result<Vec<_>> = services.rooms.state_cache.rooms_invited(&user_id).collect(); let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" ))) }, - RoomStateCache::RoomsLeft { + RoomStateCacheCommand::RoomsLeft { user_id, } => { let timer = tokio::time::Instant::now(); - let results: Result<Vec<_>> = services().rooms.state_cache.rooms_left(&user_id).collect(); + let results: Result<Vec<_>> = services.rooms.state_cache.rooms_left(&user_id).collect(); let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```" ))) }, - RoomStateCache::InviteState { + RoomStateCacheCommand::InviteState { user_id, room_id, } => { let timer = tokio::time::Instant::now(); - let results = services() - .rooms - .state_cache - .invite_state(&user_id, &room_id); + let results = services.rooms.state_cache.invite_state(&user_id, &room_id); let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/sending.rs b/src/admin/query/sending.rs index 4e82695d3..6d54bddfd 100644 --- a/src/admin/query/sending.rs +++ b/src/admin/query/sending.rs @@ -1,14 +1,73 @@ -use ruma::events::room::message::RoomMessageEventContent; +use clap::Subcommand; +use conduit::Result; +use ruma::{events::room::message::RoomMessageEventContent, ServerName, UserId}; +use 
service::sending::Destination; -use super::Sending; -use crate::{service::sending::Destination, services, Result}; +use crate::Command; + +#[derive(Debug, Subcommand)] +/// All the getters and iterators from src/database/key_value/sending.rs +pub(crate) enum SendingCommand { + /// - Queries database for all `servercurrentevent_data` + ActiveRequests, + + /// - Queries database for `servercurrentevent_data` but for a specific + /// destination + /// + /// This command takes only *one* format of these arguments: + /// + /// appservice_id + /// server_name + /// user_id AND push_key + /// + /// See src/service/sending/mod.rs for the definition of the `Destination` + /// enum + ActiveRequestsFor { + #[arg(short, long)] + appservice_id: Option<String>, + #[arg(short, long)] + server_name: Option<Box<ServerName>>, + #[arg(short, long)] + user_id: Option<Box<UserId>>, + #[arg(short, long)] + push_key: Option<String>, + }, + + /// - Queries database for `servernameevent_data` which are the queued up + /// requests that will eventually be sent + /// + /// This command takes only *one* format of these arguments: + /// + /// appservice_id + /// server_name + /// user_id AND push_key + /// + /// See src/service/sending/mod.rs for the definition of the `Destination` + /// enum + QueuedRequests { + #[arg(short, long)] + appservice_id: Option<String>, + #[arg(short, long)] + server_name: Option<Box<ServerName>>, + #[arg(short, long)] + user_id: Option<Box<UserId>>, + #[arg(short, long)] + push_key: Option<String>, + }, + + GetLatestEduCount { + server_name: Box<ServerName>, + }, +} /// All the getters and iterators in key_value/sending.rs -pub(super) async fn sending(subcommand: Sending) -> Result<RoomMessageEventContent> { +pub(super) async fn process(subcommand: SendingCommand, context: &Command<'_>) -> Result<RoomMessageEventContent> { + let services = context.services; + match subcommand { - Sending::ActiveRequests => { + SendingCommand::ActiveRequests => { let timer = 
tokio::time::Instant::now(); - let results = services().sending.db.active_requests(); + let results = services.sending.db.active_requests(); let active_requests: Result<Vec<(_, _, _)>> = results.collect(); let query_time = timer.elapsed(); @@ -16,7 +75,7 @@ pub(super) async fn sending(subcommand: Sending) -> Result<RoomMessageEventConte "Query completed in {query_time:?}:\n\n```rs\n{active_requests:#?}\n```" ))) }, - Sending::QueuedRequests { + SendingCommand::QueuedRequests { appservice_id, server_name, user_id, @@ -38,12 +97,12 @@ pub(super) async fn sending(subcommand: Sending) -> Result<RoomMessageEventConte )); } - services() + services .sending .db .queued_requests(&Destination::Appservice(appservice_id)) }, - (None, Some(server_name), None, None) => services() + (None, Some(server_name), None, None) => services .sending .db .queued_requests(&Destination::Normal(server_name.into())), @@ -55,7 +114,7 @@ pub(super) async fn sending(subcommand: Sending) -> Result<RoomMessageEventConte )); } - services() + services .sending .db .queued_requests(&Destination::Push(user_id.into(), push_key)) @@ -81,7 +140,7 @@ pub(super) async fn sending(subcommand: Sending) -> Result<RoomMessageEventConte "Query completed in {query_time:?}:\n\n```rs\n{queued_requests:#?}\n```" ))) }, - Sending::ActiveRequestsFor { + SendingCommand::ActiveRequestsFor { appservice_id, server_name, user_id, @@ -104,12 +163,12 @@ pub(super) async fn sending(subcommand: Sending) -> Result<RoomMessageEventConte )); } - services() + services .sending .db .active_requests_for(&Destination::Appservice(appservice_id)) }, - (None, Some(server_name), None, None) => services() + (None, Some(server_name), None, None) => services .sending .db .active_requests_for(&Destination::Normal(server_name.into())), @@ -121,7 +180,7 @@ pub(super) async fn sending(subcommand: Sending) -> Result<RoomMessageEventConte )); } - services() + services .sending .db .active_requests_for(&Destination::Push(user_id.into(), push_key)) 
@@ -147,11 +206,11 @@ pub(super) async fn sending(subcommand: Sending) -> Result<RoomMessageEventConte "Query completed in {query_time:?}:\n\n```rs\n{active_requests:#?}\n```" ))) }, - Sending::GetLatestEduCount { + SendingCommand::GetLatestEduCount { server_name, } => { let timer = tokio::time::Instant::now(); - let results = services().sending.db.get_latest_educount(&server_name); + let results = services.sending.db.get_latest_educount(&server_name); let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/users.rs b/src/admin/query/users.rs index 2e73bff37..fee12fbfc 100644 --- a/src/admin/query/users.rs +++ b/src/admin/query/users.rs @@ -1,14 +1,23 @@ +use clap::Subcommand; +use conduit::Result; use ruma::events::room::message::RoomMessageEventContent; -use super::Users; -use crate::{services, Result}; +use crate::Command; + +#[derive(Debug, Subcommand)] +/// All the getters and iterators from src/database/key_value/users.rs +pub(crate) enum UsersCommand { + Iter, +} /// All the getters and iterators in key_value/users.rs -pub(super) async fn users(subcommand: Users) -> Result<RoomMessageEventContent> { +pub(super) async fn process(subcommand: UsersCommand, context: &Command<'_>) -> Result<RoomMessageEventContent> { + let services = context.services; + match subcommand { - Users::Iter => { + UsersCommand::Iter => { let timer = tokio::time::Instant::now(); - let results = services().users.db.iter(); + let results = services.users.db.iter(); let users = results.collect::<Vec<_>>(); let query_time = timer.elapsed(); diff --git a/src/admin/room/room_alias_commands.rs b/src/admin/room/alias.rs similarity index 73% rename from src/admin/room/room_alias_commands.rs rename to src/admin/room/alias.rs index 4f43ac6e8..415b8a083 100644 --- a/src/admin/room/room_alias_commands.rs +++ b/src/admin/room/alias.rs @@ -1,12 +1,49 @@ use std::fmt::Write; -use ruma::{events::room::message::RoomMessageEventContent, 
RoomAliasId}; +use clap::Subcommand; +use conduit::Result; +use ruma::{events::room::message::RoomMessageEventContent, RoomAliasId, RoomId}; -use super::RoomAliasCommand; -use crate::{escape_html, services, Result}; +use crate::{escape_html, Command}; -pub(super) async fn process(command: RoomAliasCommand, _body: Vec<&str>) -> Result<RoomMessageEventContent> { - let server_user = &services().globals.server_user; +#[derive(Debug, Subcommand)] +pub(crate) enum RoomAliasCommand { + /// - Make an alias point to a room. + Set { + #[arg(short, long)] + /// Set the alias even if a room is already using it + force: bool, + + /// The room id to set the alias on + room_id: Box<RoomId>, + + /// The alias localpart to use (`alias`, not `#alias:servername.tld`) + room_alias_localpart: String, + }, + + /// - Remove a local alias + Remove { + /// The alias localpart to remove (`alias`, not `#alias:servername.tld`) + room_alias_localpart: String, + }, + + /// - Show which room is using an alias + Which { + /// The alias localpart to look up (`alias`, not + /// `#alias:servername.tld`) + room_alias_localpart: String, + }, + + /// - List aliases currently being used + List { + /// If set, only list the aliases for this room + room_id: Option<Box<RoomId>>, + }, +} + +pub(super) async fn process(command: RoomAliasCommand, context: &Command<'_>) -> Result<RoomMessageEventContent> { + let services = context.services; + let server_user = &services.globals.server_user; match command { RoomAliasCommand::Set { @@ -19,7 +56,7 @@ pub(super) async fn process(command: RoomAliasCommand, _body: Vec<&str>) -> Resu | RoomAliasCommand::Which { ref room_alias_localpart, } => { - let room_alias_str = format!("#{}:{}", room_alias_localpart, services().globals.server_name()); + let room_alias_str = format!("#{}:{}", room_alias_localpart, services.globals.server_name()); let room_alias = match RoomAliasId::parse_box(room_alias_str) { Ok(alias) => alias, Err(err) => return 
Ok(RoomMessageEventContent::text_plain(format!("Failed to parse alias: {err}"))), @@ -29,8 +66,8 @@ pub(super) async fn process(command: RoomAliasCommand, _body: Vec<&str>) -> Resu force, room_id, .. - } => match (force, services().rooms.alias.resolve_local_alias(&room_alias)) { - (true, Ok(Some(id))) => match services() + } => match (force, services.rooms.alias.resolve_local_alias(&room_alias)) { + (true, Ok(Some(id))) => match services .rooms .alias .set_alias(&room_alias, &room_id, server_user) @@ -43,7 +80,7 @@ pub(super) async fn process(command: RoomAliasCommand, _body: Vec<&str>) -> Resu (false, Ok(Some(id))) => Ok(RoomMessageEventContent::text_plain(format!( "Refusing to overwrite in use alias for {id}, use -f or --force to overwrite" ))), - (_, Ok(None)) => match services() + (_, Ok(None)) => match services .rooms .alias .set_alias(&room_alias, &room_id, server_user) @@ -55,8 +92,8 @@ pub(super) async fn process(command: RoomAliasCommand, _body: Vec<&str>) -> Resu }, RoomAliasCommand::Remove { .. - } => match services().rooms.alias.resolve_local_alias(&room_alias) { - Ok(Some(id)) => match services() + } => match services.rooms.alias.resolve_local_alias(&room_alias) { + Ok(Some(id)) => match services .rooms .alias .remove_alias(&room_alias, server_user) @@ -70,7 +107,7 @@ pub(super) async fn process(command: RoomAliasCommand, _body: Vec<&str>) -> Resu }, RoomAliasCommand::Which { .. 
- } => match services().rooms.alias.resolve_local_alias(&room_alias) { + } => match services.rooms.alias.resolve_local_alias(&room_alias) { Ok(Some(id)) => Ok(RoomMessageEventContent::text_plain(format!("Alias resolves to {id}"))), Ok(None) => Ok(RoomMessageEventContent::text_plain("Alias isn't in use.")), Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to lookup alias: {err}"))), @@ -84,7 +121,7 @@ pub(super) async fn process(command: RoomAliasCommand, _body: Vec<&str>) -> Resu room_id, } => { if let Some(room_id) = room_id { - let aliases = services() + let aliases = services .rooms .alias .local_aliases_for_room(&room_id) @@ -109,14 +146,14 @@ pub(super) async fn process(command: RoomAliasCommand, _body: Vec<&str>) -> Resu Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to list aliases: {err}"))), } } else { - let aliases = services() + let aliases = services .rooms .alias .all_local_aliases() .collect::<Result<Vec<_>, _>>(); match aliases { Ok(aliases) => { - let server_name = services().globals.server_name(); + let server_name = services.globals.server_name(); let plain_list = aliases .iter() .fold(String::new(), |mut output, (alias, id)| { diff --git a/src/admin/room/room_commands.rs b/src/admin/room/commands.rs similarity index 82% rename from src/admin/room/room_commands.rs rename to src/admin/room/commands.rs index cf0f3ddbd..2e14a8498 100644 --- a/src/admin/room/room_commands.rs +++ b/src/admin/room/commands.rs @@ -1,15 +1,18 @@ use std::fmt::Write; +use conduit::Result; use ruma::events::room::message::RoomMessageEventContent; -use crate::{escape_html, get_room_info, services, Result, PAGE_SIZE}; +use crate::{admin_command, escape_html, get_room_info, PAGE_SIZE}; -pub(super) async fn list( - _body: Vec<&str>, page: Option<usize>, exclude_disabled: bool, exclude_banned: bool, +#[admin_command] +pub(super) async fn list_rooms( + &self, page: Option<usize>, exclude_disabled: bool, exclude_banned: bool, ) -> 
Result<RoomMessageEventContent> { // TODO: i know there's a way to do this with clap, but i can't seem to find it let page = page.unwrap_or(1); - let mut rooms = services() + let mut rooms = self + .services .rooms .metadata .iter_ids() @@ -18,7 +21,8 @@ pub(super) async fn list( .ok() .filter(|room_id| { if exclude_disabled - && services() + && self + .services .rooms .metadata .is_disabled(room_id) @@ -28,7 +32,8 @@ pub(super) async fn list( } if exclude_banned - && services() + && self + .services .rooms .metadata .is_banned(room_id) @@ -39,7 +44,7 @@ pub(super) async fn list( true }) - .map(|room_id| get_room_info(services(), &room_id)) + .map(|room_id| get_room_info(self.services, &room_id)) }) .collect::<Vec<_>>(); rooms.sort_by_key(|r| r.1); diff --git a/src/admin/room/room_directory_commands.rs b/src/admin/room/directory.rs similarity index 68% rename from src/admin/room/room_directory_commands.rs rename to src/admin/room/directory.rs index 912e970c6..7bba2eb7b 100644 --- a/src/admin/room/room_directory_commands.rs +++ b/src/admin/room/directory.rs @@ -1,21 +1,43 @@ use std::fmt::Write; -use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId}; +use clap::Subcommand; +use conduit::Result; +use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomId}; -use super::RoomDirectoryCommand; -use crate::{escape_html, get_room_info, services, Result, PAGE_SIZE}; +use crate::{escape_html, get_room_info, Command, PAGE_SIZE}; -pub(super) async fn process(command: RoomDirectoryCommand, _body: Vec<&str>) -> Result<RoomMessageEventContent> { +#[derive(Debug, Subcommand)] +pub(crate) enum RoomDirectoryCommand { + /// - Publish a room to the room directory + Publish { + /// The room id of the room to publish + room_id: Box<RoomId>, + }, + + /// - Unpublish a room to the room directory + Unpublish { + /// The room id of the room to unpublish + room_id: Box<RoomId>, + }, + + /// - List rooms that are published + List { + page: Option<usize>, 
+ }, +} + +pub(super) async fn process(command: RoomDirectoryCommand, context: &Command<'_>) -> Result<RoomMessageEventContent> { + let services = context.services; match command { RoomDirectoryCommand::Publish { room_id, - } => match services().rooms.directory.set_public(&room_id) { + } => match services.rooms.directory.set_public(&room_id) { Ok(()) => Ok(RoomMessageEventContent::text_plain("Room published")), Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to update room: {err}"))), }, RoomDirectoryCommand::Unpublish { room_id, - } => match services().rooms.directory.set_not_public(&room_id) { + } => match services.rooms.directory.set_not_public(&room_id) { Ok(()) => Ok(RoomMessageEventContent::text_plain("Room unpublished")), Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to update room: {err}"))), }, @@ -24,12 +46,12 @@ pub(super) async fn process(command: RoomDirectoryCommand, _body: Vec<&str>) -> } => { // TODO: i know there's a way to do this with clap, but i can't seem to find it let page = page.unwrap_or(1); - let mut rooms = services() + let mut rooms = services .rooms .directory .public_rooms() .filter_map(Result::ok) - .map(|id: OwnedRoomId| get_room_info(services(), &id)) + .map(|id: OwnedRoomId| get_room_info(services, &id)) .collect::<Vec<_>>(); rooms.sort_by_key(|r| r.1); rooms.reverse(); diff --git a/src/admin/room/room_info_commands.rs b/src/admin/room/info.rs similarity index 54% rename from src/admin/room/room_info_commands.rs rename to src/admin/room/info.rs index 9182616fd..8ba0a7963 100644 --- a/src/admin/room/room_info_commands.rs +++ b/src/admin/room/info.rs @@ -1,22 +1,30 @@ +use clap::Subcommand; +use conduit::Result; use ruma::{events::room::message::RoomMessageEventContent, RoomId}; -use service::services; -use super::RoomInfoCommand; -use crate::Result; +use crate::{admin_command, admin_command_dispatch}; -pub(super) async fn process(command: RoomInfoCommand, body: Vec<&str>) -> 
Result<RoomMessageEventContent> { - match command { - RoomInfoCommand::ListJoinedMembers { - room_id, - } => list_joined_members(body, room_id).await, - RoomInfoCommand::ViewRoomTopic { - room_id, - } => view_room_topic(body, room_id).await, - } +#[admin_command_dispatch] +#[derive(Debug, Subcommand)] +pub(crate) enum RoomInfoCommand { + /// - List joined members in a room + ListJoinedMembers { + room_id: Box<RoomId>, + }, + + /// - Displays room topic + /// + /// Room topics can be huge, so this is in its + /// own separate command + ViewRoomTopic { + room_id: Box<RoomId>, + }, } -async fn list_joined_members(_body: Vec<&str>, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { - let room_name = services() +#[admin_command] +async fn list_joined_members(&self, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { + let room_name = self + .services .rooms .state_accessor .get_name(&room_id) @@ -24,7 +32,8 @@ async fn list_joined_members(_body: Vec<&str>, room_id: Box<RoomId>) -> Result<R .flatten() .unwrap_or_else(|| room_id.to_string()); - let members = services() + let members = self + .services .rooms .state_cache .room_members(&room_id) @@ -35,7 +44,7 @@ async fn list_joined_members(_body: Vec<&str>, room_id: Box<RoomId>) -> Result<R .map(|user_id| { ( user_id.clone(), - services() + self.services .users .displayname(&user_id) .unwrap_or(None) @@ -58,8 +67,14 @@ async fn list_joined_members(_body: Vec<&str>, room_id: Box<RoomId>) -> Result<R Ok(RoomMessageEventContent::notice_markdown(output_plain)) } -async fn view_room_topic(_body: Vec<&str>, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { - let Some(room_topic) = services().rooms.state_accessor.get_room_topic(&room_id)? else { +#[admin_command] +async fn view_room_topic(&self, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { + let Some(room_topic) = self + .services + .rooms + .state_accessor + .get_room_topic(&room_id)? 
+ else { return Ok(RoomMessageEventContent::text_plain("Room does not have a room topic set.")); }; diff --git a/src/admin/room/mod.rs b/src/admin/room/mod.rs index 8f125f0a3..da7acb80c 100644 --- a/src/admin/room/mod.rs +++ b/src/admin/room/mod.rs @@ -1,19 +1,23 @@ -mod room_alias_commands; -mod room_commands; -mod room_directory_commands; -mod room_info_commands; -mod room_moderation_commands; +mod alias; +mod commands; +mod directory; +mod info; +mod moderation; use clap::Subcommand; use conduit::Result; -use ruma::{events::room::message::RoomMessageEventContent, RoomId, RoomOrAliasId}; -use self::room_commands::list; +use self::{ + alias::RoomAliasCommand, directory::RoomDirectoryCommand, info::RoomInfoCommand, moderation::RoomModerationCommand, +}; +use crate::admin_command_dispatch; +#[admin_command_dispatch] #[derive(Debug, Subcommand)] pub(super) enum RoomCommand { /// - List all rooms the server knows about - List { + #[clap(alias = "list")] + ListRooms { page: Option<usize>, /// Excludes rooms that we have federation disabled with @@ -41,149 +45,3 @@ pub(super) enum RoomCommand { /// - Manage the room directory Directory(RoomDirectoryCommand), } - -#[derive(Debug, Subcommand)] -pub(super) enum RoomInfoCommand { - /// - List joined members in a room - ListJoinedMembers { - room_id: Box<RoomId>, - }, - - /// - Displays room topic - /// - /// Room topics can be huge, so this is in its - /// own separate command - ViewRoomTopic { - room_id: Box<RoomId>, - }, -} - -#[derive(Debug, Subcommand)] -pub(super) enum RoomAliasCommand { - /// - Make an alias point to a room. 
- Set { - #[arg(short, long)] - /// Set the alias even if a room is already using it - force: bool, - - /// The room id to set the alias on - room_id: Box<RoomId>, - - /// The alias localpart to use (`alias`, not `#alias:servername.tld`) - room_alias_localpart: String, - }, - - /// - Remove a local alias - Remove { - /// The alias localpart to remove (`alias`, not `#alias:servername.tld`) - room_alias_localpart: String, - }, - - /// - Show which room is using an alias - Which { - /// The alias localpart to look up (`alias`, not - /// `#alias:servername.tld`) - room_alias_localpart: String, - }, - - /// - List aliases currently being used - List { - /// If set, only list the aliases for this room - room_id: Option<Box<RoomId>>, - }, -} - -#[derive(Debug, Subcommand)] -pub(super) enum RoomDirectoryCommand { - /// - Publish a room to the room directory - Publish { - /// The room id of the room to publish - room_id: Box<RoomId>, - }, - - /// - Unpublish a room to the room directory - Unpublish { - /// The room id of the room to unpublish - room_id: Box<RoomId>, - }, - - /// - List rooms that are published - List { - page: Option<usize>, - }, -} - -#[derive(Debug, Subcommand)] -pub(super) enum RoomModerationCommand { - /// - Bans a room from local users joining and evicts all our local users - /// from the room. Also blocks any invites (local and remote) for the - /// banned room. - /// - /// Server admins (users in the conduwuit admin room) will not be evicted - /// and server admins can still join the room. 
To evict admins too, use - /// --force (also ignores errors) To disable incoming federation of the - /// room, use --disable-federation - BanRoom { - #[arg(short, long)] - /// Evicts admins out of the room and ignores any potential errors when - /// making our local users leave the room - force: bool, - - #[arg(long)] - /// Disables incoming federation of the room after banning and evicting - /// users - disable_federation: bool, - - /// The room in the format of `!roomid:example.com` or a room alias in - /// the format of `#roomalias:example.com` - room: Box<RoomOrAliasId>, - }, - - /// - Bans a list of rooms (room IDs and room aliases) from a newline - /// delimited codeblock similar to `user deactivate-all` - BanListOfRooms { - #[arg(short, long)] - /// Evicts admins out of the room and ignores any potential errors when - /// making our local users leave the room - force: bool, - - #[arg(long)] - /// Disables incoming federation of the room after banning and evicting - /// users - disable_federation: bool, - }, - - /// - Unbans a room to allow local users to join again - /// - /// To re-enable incoming federation of the room, use --enable-federation - UnbanRoom { - #[arg(long)] - /// Enables incoming federation of the room after unbanning - enable_federation: bool, - - /// The room in the format of `!roomid:example.com` or a room alias in - /// the format of `#roomalias:example.com` - room: Box<RoomOrAliasId>, - }, - - /// - List of all rooms we have banned - ListBannedRooms, -} - -pub(super) async fn process(command: RoomCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> { - Ok(match command { - RoomCommand::Info(command) => room_info_commands::process(command, body).await?, - - RoomCommand::Alias(command) => room_alias_commands::process(command, body).await?, - - RoomCommand::Directory(command) => room_directory_commands::process(command, body).await?, - - RoomCommand::Moderation(command) => room_moderation_commands::process(command, body).await?, - 
- RoomCommand::List { - page, - exclude_disabled, - exclude_banned, - } => list(body, page, exclude_disabled, exclude_banned).await?, - }) -} diff --git a/src/admin/room/room_moderation_commands.rs b/src/admin/room/moderation.rs similarity index 72% rename from src/admin/room/room_moderation_commands.rs rename to src/admin/room/moderation.rs index 8ad8295b0..869103798 100644 --- a/src/admin/room/room_moderation_commands.rs +++ b/src/admin/room/moderation.rs @@ -1,37 +1,77 @@ use api::client::leave_room; +use clap::Subcommand; use conduit::{debug, error, info, warn, Result}; use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomAliasId, RoomId, RoomOrAliasId}; -use super::RoomModerationCommand; -use crate::{get_room_info, services}; - -pub(super) async fn process(command: RoomModerationCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> { - match command { - RoomModerationCommand::BanRoom { - force, - room, - disable_federation, - } => ban_room(body, force, room, disable_federation).await, - RoomModerationCommand::BanListOfRooms { - force, - disable_federation, - } => ban_list_of_rooms(body, force, disable_federation).await, - RoomModerationCommand::UnbanRoom { - room, - enable_federation, - } => unban_room(body, room, enable_federation).await, - RoomModerationCommand::ListBannedRooms => list_banned_rooms(body).await, - } +use crate::{admin_command, admin_command_dispatch, get_room_info}; + +#[admin_command_dispatch] +#[derive(Debug, Subcommand)] +pub(crate) enum RoomModerationCommand { + /// - Bans a room from local users joining and evicts all our local users + /// from the room. Also blocks any invites (local and remote) for the + /// banned room. + /// + /// Server admins (users in the conduwuit admin room) will not be evicted + /// and server admins can still join the room. 
To evict admins too, use + /// --force (also ignores errors) To disable incoming federation of the + /// room, use --disable-federation + BanRoom { + #[arg(short, long)] + /// Evicts admins out of the room and ignores any potential errors when + /// making our local users leave the room + force: bool, + + #[arg(long)] + /// Disables incoming federation of the room after banning and evicting + /// users + disable_federation: bool, + + /// The room in the format of `!roomid:example.com` or a room alias in + /// the format of `#roomalias:example.com` + room: Box<RoomOrAliasId>, + }, + + /// - Bans a list of rooms (room IDs and room aliases) from a newline + /// delimited codeblock similar to `user deactivate-all` + BanListOfRooms { + #[arg(short, long)] + /// Evicts admins out of the room and ignores any potential errors when + /// making our local users leave the room + force: bool, + + #[arg(long)] + /// Disables incoming federation of the room after banning and evicting + /// users + disable_federation: bool, + }, + + /// - Unbans a room to allow local users to join again + /// + /// To re-enable incoming federation of the room, use --enable-federation + UnbanRoom { + #[arg(long)] + /// Enables incoming federation of the room after unbanning + enable_federation: bool, + + /// The room in the format of `!roomid:example.com` or a room alias in + /// the format of `#roomalias:example.com` + room: Box<RoomOrAliasId>, + }, + + /// - List of all rooms we have banned + ListBannedRooms, } +#[admin_command] async fn ban_room( - _body: Vec<&str>, force: bool, room: Box<RoomOrAliasId>, disable_federation: bool, + &self, force: bool, disable_federation: bool, room: Box<RoomOrAliasId>, ) -> Result<RoomMessageEventContent> { debug!("Got room alias or ID: {}", room); - let admin_room_alias = &services().globals.admin_alias; + let admin_room_alias = &self.services.globals.admin_alias; - if let Some(admin_room_id) = services().admin.get_admin_room()? 
{ + if let Some(admin_room_id) = self.services.admin.get_admin_room()? { if room.to_string().eq(&admin_room_id) || room.to_string().eq(admin_room_alias) { return Ok(RoomMessageEventContent::text_plain("Not allowed to ban the admin room.")); } @@ -50,7 +90,7 @@ async fn ban_room( debug!("Room specified is a room ID, banning room ID"); - services().rooms.metadata.ban_room(&room_id, true)?; + self.services.rooms.metadata.ban_room(&room_id, true)?; room_id } else if room.is_room_alias_id() { @@ -69,12 +109,13 @@ async fn ban_room( get_alias_helper to fetch room ID remotely" ); - let room_id = if let Some(room_id) = services().rooms.alias.resolve_local_alias(&room_alias)? { + let room_id = if let Some(room_id) = self.services.rooms.alias.resolve_local_alias(&room_alias)? { room_id } else { debug!("We don't have this room alias to a room ID locally, attempting to fetch room ID over federation"); - match services() + match self + .services .rooms .alias .resolve_alias(&room_alias, None) @@ -92,7 +133,7 @@ async fn ban_room( } }; - services().rooms.metadata.ban_room(&room_id, true)?; + self.services.rooms.metadata.ban_room(&room_id, true)?; room_id } else { @@ -104,20 +145,21 @@ async fn ban_room( debug!("Making all users leave the room {}", &room); if force { - for local_user in services() + for local_user in self + .services .rooms .state_cache .room_members(&room_id) .filter_map(|user| { user.ok().filter(|local_user| { - services().globals.user_is_local(local_user) + self.services.globals.user_is_local(local_user) // additional wrapped check here is to avoid adding remote users // who are in the admin room to the list of local users (would // fail auth check) - && (services().globals.user_is_local(local_user) + && (self.services.globals.user_is_local(local_user) // since this is a force operation, assume user is an admin // if somehow this fails - && services() + && self.services .users .is_admin(local_user) .unwrap_or(true)) @@ -128,30 +170,31 @@ async fn ban_room( 
&local_user, &room_id ); - if let Err(e) = leave_room(services(), &local_user, &room_id, None).await { + if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await { warn!(%e, "Failed to leave room"); } } } else { - for local_user in services() + for local_user in self + .services .rooms .state_cache .room_members(&room_id) .filter_map(|user| { user.ok().filter(|local_user| { - local_user.server_name() == services().globals.server_name() + local_user.server_name() == self.services.globals.server_name() // additional wrapped check here is to avoid adding remote users // who are in the admin room to the list of local users (would fail auth check) && (local_user.server_name() - == services().globals.server_name() - && !services() + == self.services.globals.server_name() + && !self.services .users .is_admin(local_user) .unwrap_or(false)) }) }) { debug!("Attempting leave for user {} in room {}", &local_user, &room_id); - if let Err(e) = leave_room(services(), &local_user, &room_id, None).await { + if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await { error!( "Error attempting to make local user {} leave room {} during room banning: {}", &local_user, &room_id, e @@ -166,7 +209,7 @@ async fn ban_room( } if disable_federation { - services().rooms.metadata.disable_room(&room_id, true)?; + self.services.rooms.metadata.disable_room(&room_id, true)?; return Ok(RoomMessageEventContent::text_plain( "Room banned, removed all our local users, and disabled incoming federation with room.", )); @@ -178,19 +221,22 @@ async fn ban_room( )) } -async fn ban_list_of_rooms(body: Vec<&str>, force: bool, disable_federation: bool) -> Result<RoomMessageEventContent> { - if body.len() < 2 || !body[0].trim().starts_with("```") || body.last().unwrap_or(&"").trim() != "```" { +#[admin_command] +async fn ban_list_of_rooms(&self, force: bool, disable_federation: bool) -> Result<RoomMessageEventContent> { + if self.body.len() < 2 || 
!self.body[0].trim().starts_with("```") || self.body.last().unwrap_or(&"").trim() != "```" + { return Ok(RoomMessageEventContent::text_plain( "Expected code block in command body. Add --help for details.", )); } - let rooms_s = body - .clone() - .drain(1..body.len().saturating_sub(1)) + let rooms_s = self + .body + .to_vec() + .drain(1..self.body.len().saturating_sub(1)) .collect::<Vec<_>>(); - let admin_room_alias = &services().globals.admin_alias; + let admin_room_alias = &self.services.globals.admin_alias; let mut room_ban_count: usize = 0; let mut room_ids: Vec<OwnedRoomId> = Vec::new(); @@ -198,7 +244,7 @@ async fn ban_list_of_rooms(body: Vec<&str>, force: bool, disable_federation: boo for &room in &rooms_s { match <&RoomOrAliasId>::try_from(room) { Ok(room_alias_or_id) => { - if let Some(admin_room_id) = services().admin.get_admin_room()? { + if let Some(admin_room_id) = self.services.admin.get_admin_room()? { if room.to_owned().eq(&admin_room_id) || room.to_owned().eq(admin_room_alias) { info!("User specified admin room in bulk ban list, ignoring"); continue; @@ -231,7 +277,7 @@ async fn ban_list_of_rooms(body: Vec<&str>, force: bool, disable_federation: boo match RoomAliasId::parse(room_alias_or_id) { Ok(room_alias) => { let room_id = - if let Some(room_id) = services().rooms.alias.resolve_local_alias(&room_alias)? { + if let Some(room_id) = self.services.rooms.alias.resolve_local_alias(&room_alias)? 
{ room_id } else { debug!( @@ -239,7 +285,8 @@ async fn ban_list_of_rooms(body: Vec<&str>, force: bool, disable_federation: boo ID over federation" ); - match services() + match self + .services .rooms .alias .resolve_alias(&room_alias, None) @@ -303,28 +350,35 @@ async fn ban_list_of_rooms(body: Vec<&str>, force: bool, disable_federation: boo } for room_id in room_ids { - if services().rooms.metadata.ban_room(&room_id, true).is_ok() { + if self + .services + .rooms + .metadata + .ban_room(&room_id, true) + .is_ok() + { debug!("Banned {room_id} successfully"); room_ban_count = room_ban_count.saturating_add(1); } debug!("Making all users leave the room {}", &room_id); if force { - for local_user in services() + for local_user in self + .services .rooms .state_cache .room_members(&room_id) .filter_map(|user| { user.ok().filter(|local_user| { - local_user.server_name() == services().globals.server_name() + local_user.server_name() == self.services.globals.server_name() // additional wrapped check here is to avoid adding remote // users who are in the admin room to the list of local // users (would fail auth check) && (local_user.server_name() - == services().globals.server_name() + == self.services.globals.server_name() // since this is a force operation, assume user is an // admin if somehow this fails - && services() + && self.services .users .is_admin(local_user) .unwrap_or(true)) @@ -334,31 +388,32 @@ async fn ban_list_of_rooms(body: Vec<&str>, force: bool, disable_federation: boo "Attempting leave for user {} in room {} (forced, ignoring all errors, evicting admins too)", &local_user, room_id ); - if let Err(e) = leave_room(services(), &local_user, &room_id, None).await { + if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await { warn!(%e, "Failed to leave room"); } } } else { - for local_user in services() + for local_user in self + .services .rooms .state_cache .room_members(&room_id) .filter_map(|user| { user.ok().filter(|local_user| { - 
local_user.server_name() == services().globals.server_name() + local_user.server_name() == self.services.globals.server_name() // additional wrapped check here is to avoid adding remote // users who are in the admin room to the list of local // users (would fail auth check) && (local_user.server_name() - == services().globals.server_name() - && !services() + == self.services.globals.server_name() + && !self.services .users .is_admin(local_user) .unwrap_or(false)) }) }) { debug!("Attempting leave for user {} in room {}", &local_user, &room_id); - if let Err(e) = leave_room(services(), &local_user, &room_id, None).await { + if let Err(e) = leave_room(self.services, &local_user, &room_id, None).await { error!( "Error attempting to make local user {} leave room {} during bulk room banning: {}", &local_user, &room_id, e @@ -374,7 +429,7 @@ async fn ban_list_of_rooms(body: Vec<&str>, force: bool, disable_federation: boo } if disable_federation { - services().rooms.metadata.disable_room(&room_id, true)?; + self.services.rooms.metadata.disable_room(&room_id, true)?; } } @@ -390,9 +445,8 @@ async fn ban_list_of_rooms(body: Vec<&str>, force: bool, disable_federation: boo } } -async fn unban_room( - _body: Vec<&str>, room: Box<RoomOrAliasId>, enable_federation: bool, -) -> Result<RoomMessageEventContent> { +#[admin_command] +async fn unban_room(&self, enable_federation: bool, room: Box<RoomOrAliasId>) -> Result<RoomMessageEventContent> { let room_id = if room.is_room_id() { let room_id = match RoomId::parse(&room) { Ok(room_id) => room_id, @@ -406,7 +460,7 @@ async fn unban_room( debug!("Room specified is a room ID, unbanning room ID"); - services().rooms.metadata.ban_room(&room_id, false)?; + self.services.rooms.metadata.ban_room(&room_id, false)?; room_id } else if room.is_room_alias_id() { @@ -425,12 +479,13 @@ async fn unban_room( get_alias_helper to fetch room ID remotely" ); - let room_id = if let Some(room_id) = services().rooms.alias.resolve_local_alias(&room_alias)? 
{ + let room_id = if let Some(room_id) = self.services.rooms.alias.resolve_local_alias(&room_alias)? { room_id } else { debug!("We don't have this room alias to a room ID locally, attempting to fetch room ID over federation"); - match services() + match self + .services .rooms .alias .resolve_alias(&room_alias, None) @@ -448,7 +503,7 @@ async fn unban_room( } }; - services().rooms.metadata.ban_room(&room_id, false)?; + self.services.rooms.metadata.ban_room(&room_id, false)?; room_id } else { @@ -459,7 +514,7 @@ async fn unban_room( }; if enable_federation { - services().rooms.metadata.disable_room(&room_id, false)?; + self.services.rooms.metadata.disable_room(&room_id, false)?; return Ok(RoomMessageEventContent::text_plain("Room unbanned.")); } @@ -469,8 +524,10 @@ async fn unban_room( )) } -async fn list_banned_rooms(_body: Vec<&str>) -> Result<RoomMessageEventContent> { - let rooms = services() +#[admin_command] +async fn list_banned_rooms(&self) -> Result<RoomMessageEventContent> { + let rooms = self + .services .rooms .metadata .list_banned_rooms() @@ -484,7 +541,7 @@ async fn list_banned_rooms(_body: Vec<&str>) -> Result<RoomMessageEventContent> let mut rooms = room_ids .into_iter() - .map(|room_id| get_room_info(services(), &room_id)) + .map(|room_id| get_room_info(self.services, &room_id)) .collect::<Vec<_>>(); rooms.sort_by_key(|r| r.1); rooms.reverse(); diff --git a/src/admin/server/commands.rs b/src/admin/server/commands.rs index c30f8f871..34b0f0a01 100644 --- a/src/admin/server/commands.rs +++ b/src/admin/server/commands.rs @@ -1,12 +1,14 @@ -use std::fmt::Write; +use std::{fmt::Write, sync::Arc}; use conduit::{info, utils::time, warn, Err, Result}; use ruma::events::room::message::RoomMessageEventContent; -use crate::services; +use crate::admin_command; -pub(super) async fn uptime(_body: Vec<&str>) -> Result<RoomMessageEventContent> { - let elapsed = services() +#[admin_command] +pub(super) async fn uptime(&self) -> Result<RoomMessageEventContent> { + 
let elapsed = self + .services .server .started .elapsed() @@ -16,13 +18,15 @@ pub(super) async fn uptime(_body: Vec<&str>) -> Result<RoomMessageEventContent> Ok(RoomMessageEventContent::notice_plain(format!("{result}."))) } -pub(super) async fn show_config(_body: Vec<&str>) -> Result<RoomMessageEventContent> { +#[admin_command] +pub(super) async fn show_config(&self) -> Result<RoomMessageEventContent> { // Construct and send the response - Ok(RoomMessageEventContent::text_plain(format!("{}", services().globals.config))) + Ok(RoomMessageEventContent::text_plain(format!("{}", self.services.globals.config))) } +#[admin_command] pub(super) async fn list_features( - _body: Vec<&str>, available: bool, enabled: bool, comma: bool, + &self, available: bool, enabled: bool, comma: bool, ) -> Result<RoomMessageEventContent> { let delim = if comma { "," @@ -62,9 +66,10 @@ pub(super) async fn list_features( Ok(RoomMessageEventContent::text_markdown(features)) } -pub(super) async fn memory_usage(_body: Vec<&str>) -> Result<RoomMessageEventContent> { - let services_usage = services().memory_usage().await?; - let database_usage = services().db.db.memory_usage()?; +#[admin_command] +pub(super) async fn memory_usage(&self) -> Result<RoomMessageEventContent> { + let services_usage = self.services.memory_usage().await?; + let database_usage = self.services.db.db.memory_usage()?; let allocator_usage = conduit::alloc::memory_usage().map_or(String::new(), |s| format!("\nAllocator:\n{s}")); Ok(RoomMessageEventContent::text_plain(format!( @@ -72,14 +77,16 @@ pub(super) async fn memory_usage(_body: Vec<&str>) -> Result<RoomMessageEventCon ))) } -pub(super) async fn clear_caches(_body: Vec<&str>) -> Result<RoomMessageEventContent> { - services().clear_cache().await; +#[admin_command] +pub(super) async fn clear_caches(&self) -> Result<RoomMessageEventContent> { + self.services.clear_cache().await; Ok(RoomMessageEventContent::text_plain("Done.")) } -pub(super) async fn list_backups(_body: 
Vec<&str>) -> Result<RoomMessageEventContent> { - let result = services().globals.db.backup_list()?; +#[admin_command] +pub(super) async fn list_backups(&self) -> Result<RoomMessageEventContent> { + let result = self.services.globals.db.backup_list()?; if result.is_empty() { Ok(RoomMessageEventContent::text_plain("No backups found.")) @@ -88,46 +95,51 @@ pub(super) async fn list_backups(_body: Vec<&str>) -> Result<RoomMessageEventCon } } -pub(super) async fn backup_database(_body: Vec<&str>) -> Result<RoomMessageEventContent> { - let mut result = services() +#[admin_command] +pub(super) async fn backup_database(&self) -> Result<RoomMessageEventContent> { + let globals = Arc::clone(&self.services.globals); + let mut result = self + .services .server .runtime() - .spawn_blocking(move || match services().globals.db.backup() { + .spawn_blocking(move || match globals.db.backup() { Ok(()) => String::new(), Err(e) => (*e).to_string(), }) - .await - .unwrap(); + .await?; if result.is_empty() { - result = services().globals.db.backup_list()?; + result = self.services.globals.db.backup_list()?; } - Ok(RoomMessageEventContent::text_plain(&result)) + Ok(RoomMessageEventContent::notice_markdown(result)) } -pub(super) async fn list_database_files(_body: Vec<&str>) -> Result<RoomMessageEventContent> { - let result = services().globals.db.file_list()?; +#[admin_command] +pub(super) async fn list_database_files(&self) -> Result<RoomMessageEventContent> { + let result = self.services.globals.db.file_list()?; Ok(RoomMessageEventContent::notice_markdown(result)) } -pub(super) async fn admin_notice(_body: Vec<&str>, message: Vec<String>) -> Result<RoomMessageEventContent> { +#[admin_command] +pub(super) async fn admin_notice(&self, message: Vec<String>) -> Result<RoomMessageEventContent> { let message = message.join(" "); - services().admin.send_text(&message).await; + self.services.admin.send_text(&message).await; Ok(RoomMessageEventContent::notice_plain("Notice was sent to #admins")) 
} -#[cfg(conduit_mods)] -pub(super) async fn reload(_body: Vec<&str>) -> Result<RoomMessageEventContent> { - services().server.reload()?; +#[admin_command] +pub(super) async fn reload_mods(&self) -> Result<RoomMessageEventContent> { + self.services.server.reload()?; Ok(RoomMessageEventContent::notice_plain("Reloading server...")) } +#[admin_command] #[cfg(unix)] -pub(super) async fn restart(_body: Vec<&str>, force: bool) -> Result<RoomMessageEventContent> { +pub(super) async fn restart(&self, force: bool) -> Result<RoomMessageEventContent> { use conduit::utils::sys::current_exe_deleted; if !force && current_exe_deleted() { @@ -137,14 +149,15 @@ pub(super) async fn restart(_body: Vec<&str>, force: bool) -> Result<RoomMessage ); } - services().server.restart()?; + self.services.server.restart()?; Ok(RoomMessageEventContent::notice_plain("Restarting server...")) } -pub(super) async fn shutdown(_body: Vec<&str>) -> Result<RoomMessageEventContent> { +#[admin_command] +pub(super) async fn shutdown(&self) -> Result<RoomMessageEventContent> { warn!("shutdown command"); - services().server.shutdown()?; + self.services.server.shutdown()?; Ok(RoomMessageEventContent::notice_plain("Shutting down server...")) } diff --git a/src/admin/server/mod.rs b/src/admin/server/mod.rs index 958cc54ba..222c537a0 100644 --- a/src/admin/server/mod.rs +++ b/src/admin/server/mod.rs @@ -2,10 +2,10 @@ use clap::Subcommand; use conduit::Result; -use ruma::events::room::message::RoomMessageEventContent; -use self::commands::*; +use crate::admin_command_dispatch; +#[admin_command_dispatch] #[derive(Debug, Subcommand)] pub(super) enum ServerCommand { /// - Time elapsed since startup @@ -47,9 +47,9 @@ pub(super) enum ServerCommand { message: Vec<String>, }, - #[cfg(conduit_mods)] /// - Hot-reload the server - Reload, + #[clap(alias = "reload")] + ReloadMods, #[cfg(unix)] /// - Restart the server @@ -61,30 +61,3 @@ pub(super) enum ServerCommand { /// - Shutdown the server Shutdown, } - -pub(super) 
async fn process(command: ServerCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> { - Ok(match command { - ServerCommand::Uptime => uptime(body).await?, - ServerCommand::ShowConfig => show_config(body).await?, - ServerCommand::ListFeatures { - available, - enabled, - comma, - } => list_features(body, available, enabled, comma).await?, - ServerCommand::MemoryUsage => memory_usage(body).await?, - ServerCommand::ClearCaches => clear_caches(body).await?, - ServerCommand::ListBackups => list_backups(body).await?, - ServerCommand::BackupDatabase => backup_database(body).await?, - ServerCommand::ListDatabaseFiles => list_database_files(body).await?, - ServerCommand::AdminNotice { - message, - } => admin_notice(body, message).await?, - #[cfg(conduit_mods)] - ServerCommand::Reload => reload(body).await?, - #[cfg(unix)] - ServerCommand::Restart { - force, - } => restart(body, force).await?, - ServerCommand::Shutdown => shutdown(body).await?, - }) -} diff --git a/src/admin/user/commands.rs b/src/admin/user/commands.rs index 69019d79e..bdd35d59e 100644 --- a/src/admin/user/commands.rs +++ b/src/admin/user/commands.rs @@ -1,7 +1,7 @@ use std::{collections::BTreeMap, fmt::Write as _}; use api::client::{join_room_by_id_helper, leave_all_rooms, update_avatar_url, update_displayname}; -use conduit::{utils, Result}; +use conduit::{error, info, utils, warn, Result}; use ruma::{ events::{ room::message::RoomMessageEventContent, @@ -10,17 +10,17 @@ }, OwnedRoomId, OwnedRoomOrAliasId, OwnedUserId, RoomId, }; -use tracing::{error, info, warn}; use crate::{ - escape_html, get_room_info, services, + admin_command, escape_html, get_room_info, utils::{parse_active_local_user_id, parse_local_user_id}, }; const AUTO_GEN_PASSWORD_LENGTH: usize = 25; -pub(super) async fn list(_body: Vec<&str>) -> Result<RoomMessageEventContent> { - match services().users.list_local_users() { +#[admin_command] +pub(super) async fn list_users(&self) -> Result<RoomMessageEventContent> { + match 
self.services.users.list_local_users() { Ok(users) => { let mut plain_msg = format!("Found {} local user account(s):\n```\n", users.len()); plain_msg += users.join("\n").as_str(); @@ -32,13 +32,12 @@ pub(super) async fn list(_body: Vec<&str>) -> Result<RoomMessageEventContent> { } } -pub(super) async fn create( - _body: Vec<&str>, username: String, password: Option<String>, -) -> Result<RoomMessageEventContent> { +#[admin_command] +pub(super) async fn create_user(&self, username: String, password: Option<String>) -> Result<RoomMessageEventContent> { // Validate user id - let user_id = parse_local_user_id(services(), &username)?; + let user_id = parse_local_user_id(self.services, &username)?; - if services().users.exists(&user_id)? { + if self.services.users.exists(&user_id)? { return Ok(RoomMessageEventContent::text_plain(format!("Userid {user_id} already exists"))); } @@ -51,30 +50,33 @@ pub(super) async fn create( let password = password.unwrap_or_else(|| utils::random_string(AUTO_GEN_PASSWORD_LENGTH)); // Create user - services().users.create(&user_id, Some(password.as_str()))?; + self.services + .users + .create(&user_id, Some(password.as_str()))?; // Default to pretty displayname let mut displayname = user_id.localpart().to_owned(); // If `new_user_displayname_suffix` is set, registration will push whatever // content is set to the user's display name with a space before it - if !services() + if !self + .services .globals .config .new_user_displayname_suffix .is_empty() { - write!(displayname, " {}", services().globals.config.new_user_displayname_suffix) + write!(displayname, " {}", self.services.globals.config.new_user_displayname_suffix) .expect("should be able to write to string buffer"); } - services() + self.services .users .set_displayname(&user_id, Some(displayname)) .await?; // Initial account data - services().account_data.update( + self.services.account_data.update( None, &user_id, ruma::events::GlobalAccountDataEventType::PushRules @@ -88,12 +90,13 
@@ pub(super) async fn create( .expect("to json value always works"), )?; - if !services().globals.config.auto_join_rooms.is_empty() { - for room in &services().globals.config.auto_join_rooms { - if !services() + if !self.services.globals.config.auto_join_rooms.is_empty() { + for room in &self.services.globals.config.auto_join_rooms { + if !self + .services .rooms .state_cache - .server_in_room(services().globals.server_name(), room)? + .server_in_room(self.services.globals.server_name(), room)? { warn!("Skipping room {room} to automatically join as we have never joined before."); continue; @@ -101,11 +104,11 @@ pub(super) async fn create( if let Some(room_id_server_name) = room.server_name() { match join_room_by_id_helper( - services(), + self.services, &user_id, room, Some("Automatically joining this room upon registration".to_owned()), - &[room_id_server_name.to_owned(), services().globals.server_name().to_owned()], + &[room_id_server_name.to_owned(), self.services.globals.server_name().to_owned()], None, ) .await @@ -130,38 +133,38 @@ pub(super) async fn create( ))) } -pub(super) async fn deactivate( - _body: Vec<&str>, no_leave_rooms: bool, user_id: String, -) -> Result<RoomMessageEventContent> { +#[admin_command] +pub(super) async fn deactivate(&self, no_leave_rooms: bool, user_id: String) -> Result<RoomMessageEventContent> { // Validate user id - let user_id = parse_local_user_id(services(), &user_id)?; + let user_id = parse_local_user_id(self.services, &user_id)?; // don't deactivate the server service account - if user_id == services().globals.server_user { + if user_id == self.services.globals.server_user { return Ok(RoomMessageEventContent::text_plain( "Not allowed to deactivate the server service account.", )); } - services().users.deactivate_account(&user_id)?; + self.services.users.deactivate_account(&user_id)?; if !no_leave_rooms { - services() + self.services .admin .send_message(RoomMessageEventContent::text_plain(format!( "Making {user_id} leave 
all rooms after deactivation..." ))) .await; - let all_joined_rooms: Vec<OwnedRoomId> = services() + let all_joined_rooms: Vec<OwnedRoomId> = self + .services .rooms .state_cache .rooms_joined(&user_id) .filter_map(Result::ok) .collect(); - update_displayname(services(), user_id.clone(), None, all_joined_rooms.clone()).await?; - update_avatar_url(services(), user_id.clone(), None, None, all_joined_rooms).await?; - leave_all_rooms(services(), &user_id).await; + update_displayname(self.services, user_id.clone(), None, all_joined_rooms.clone()).await?; + update_avatar_url(self.services, user_id.clone(), None, None, all_joined_rooms).await?; + leave_all_rooms(self.services, &user_id).await; } Ok(RoomMessageEventContent::text_plain(format!( @@ -169,10 +172,11 @@ pub(super) async fn deactivate( ))) } -pub(super) async fn reset_password(_body: Vec<&str>, username: String) -> Result<RoomMessageEventContent> { - let user_id = parse_local_user_id(services(), &username)?; +#[admin_command] +pub(super) async fn reset_password(&self, username: String) -> Result<RoomMessageEventContent> { + let user_id = parse_local_user_id(self.services, &username)?; - if user_id == services().globals.server_user { + if user_id == self.services.globals.server_user { return Ok(RoomMessageEventContent::text_plain( "Not allowed to set the password for the server account. 
Please use the emergency password config option.", )); @@ -180,7 +184,8 @@ pub(super) async fn reset_password(_body: Vec<&str>, username: String) -> Result let new_password = utils::random_string(AUTO_GEN_PASSWORD_LENGTH); - match services() + match self + .services .users .set_password(&user_id, Some(new_password.as_str())) { @@ -193,28 +198,29 @@ pub(super) async fn reset_password(_body: Vec<&str>, username: String) -> Result } } -pub(super) async fn deactivate_all( - body: Vec<&str>, no_leave_rooms: bool, force: bool, -) -> Result<RoomMessageEventContent> { - if body.len() < 2 || !body[0].trim().starts_with("```") || body.last().unwrap_or(&"").trim() != "```" { +#[admin_command] +pub(super) async fn deactivate_all(&self, no_leave_rooms: bool, force: bool) -> Result<RoomMessageEventContent> { + if self.body.len() < 2 || !self.body[0].trim().starts_with("```") || self.body.last().unwrap_or(&"").trim() != "```" + { return Ok(RoomMessageEventContent::text_plain( "Expected code block in command body. Add --help for details.", )); } - let usernames = body - .clone() - .drain(1..body.len().saturating_sub(1)) + let usernames = self + .body + .to_vec() + .drain(1..self.body.len().saturating_sub(1)) .collect::<Vec<_>>(); let mut user_ids: Vec<OwnedUserId> = Vec::with_capacity(usernames.len()); let mut admins = Vec::new(); for username in usernames { - match parse_active_local_user_id(services(), username) { + match parse_active_local_user_id(self.services, username) { Ok(user_id) => { - if services().users.is_admin(&user_id)? && !force { - services() + if self.services.users.is_admin(&user_id)? 
&& !force { + self.services .admin .send_message(RoomMessageEventContent::text_plain(format!( "{username} is an admin and --force is not set, skipping over" @@ -225,8 +231,8 @@ pub(super) async fn deactivate_all( } // don't deactivate the server service account - if user_id == services().globals.server_user { - services() + if user_id == self.services.globals.server_user { + self.services .admin .send_message(RoomMessageEventContent::text_plain(format!( "{username} is the server service account, skipping over" @@ -238,7 +244,7 @@ pub(super) async fn deactivate_all( user_ids.push(user_id); }, Err(e) => { - services() + self.services .admin .send_message(RoomMessageEventContent::text_plain(format!( "{username} is not a valid username, skipping over: {e}" @@ -252,24 +258,25 @@ pub(super) async fn deactivate_all( let mut deactivation_count: usize = 0; for user_id in user_ids { - match services().users.deactivate_account(&user_id) { + match self.services.users.deactivate_account(&user_id) { Ok(()) => { deactivation_count = deactivation_count.saturating_add(1); if !no_leave_rooms { info!("Forcing user {user_id} to leave all rooms apart of deactivate-all"); - let all_joined_rooms: Vec<OwnedRoomId> = services() + let all_joined_rooms: Vec<OwnedRoomId> = self + .services .rooms .state_cache .rooms_joined(&user_id) .filter_map(Result::ok) .collect(); - update_displayname(services(), user_id.clone(), None, all_joined_rooms.clone()).await?; - update_avatar_url(services(), user_id.clone(), None, None, all_joined_rooms).await?; - leave_all_rooms(services(), &user_id).await; + update_displayname(self.services, user_id.clone(), None, all_joined_rooms.clone()).await?; + update_avatar_url(self.services, user_id.clone(), None, None, all_joined_rooms).await?; + leave_all_rooms(self.services, &user_id).await; } }, Err(e) => { - services() + self.services .admin .send_message(RoomMessageEventContent::text_plain(format!("Failed deactivating user: {e}"))) .await; @@ -290,16 +297,18 @@ 
pub(super) async fn deactivate_all( } } -pub(super) async fn list_joined_rooms(_body: Vec<&str>, user_id: String) -> Result<RoomMessageEventContent> { +#[admin_command] +pub(super) async fn list_joined_rooms(&self, user_id: String) -> Result<RoomMessageEventContent> { // Validate user id - let user_id = parse_local_user_id(services(), &user_id)?; + let user_id = parse_local_user_id(self.services, &user_id)?; - let mut rooms: Vec<(OwnedRoomId, u64, String)> = services() + let mut rooms: Vec<(OwnedRoomId, u64, String)> = self + .services .rooms .state_cache .rooms_joined(&user_id) .filter_map(Result::ok) - .map(|room_id| get_room_info(services(), &room_id)) + .map(|room_id| get_room_info(self.services, &room_id)) .collect(); if rooms.is_empty() { @@ -341,35 +350,38 @@ pub(super) async fn list_joined_rooms(_body: Vec<&str>, user_id: String) -> Resu Ok(RoomMessageEventContent::text_html(output_plain, output_html)) } +#[admin_command] pub(super) async fn force_join_room( - _body: Vec<&str>, user_id: String, room_id: OwnedRoomOrAliasId, + &self, user_id: String, room_id: OwnedRoomOrAliasId, ) -> Result<RoomMessageEventContent> { - let user_id = parse_local_user_id(services(), &user_id)?; - let room_id = services().rooms.alias.resolve(&room_id).await?; + let user_id = parse_local_user_id(self.services, &user_id)?; + let room_id = self.services.rooms.alias.resolve(&room_id).await?; assert!( - services().globals.user_is_local(&user_id), + self.services.globals.user_is_local(&user_id), "Parsed user_id must be a local user" ); - join_room_by_id_helper(services(), &user_id, &room_id, None, &[], None).await?; + join_room_by_id_helper(self.services, &user_id, &room_id, None, &[], None).await?; Ok(RoomMessageEventContent::notice_markdown(format!( "{user_id} has been joined to {room_id}.", ))) } -pub(super) async fn make_user_admin(_body: Vec<&str>, user_id: String) -> Result<RoomMessageEventContent> { - let user_id = parse_local_user_id(services(), &user_id)?; - let displayname = 
services() +#[admin_command] +pub(super) async fn make_user_admin(&self, user_id: String) -> Result<RoomMessageEventContent> { + let user_id = parse_local_user_id(self.services, &user_id)?; + let displayname = self + .services .users .displayname(&user_id)? .unwrap_or_else(|| user_id.to_string()); assert!( - services().globals.user_is_local(&user_id), + self.services.globals.user_is_local(&user_id), "Parsed user_id must be a local user" ); - services() + self.services .admin .make_user_admin(&user_id, displayname) .await?; @@ -379,12 +391,14 @@ pub(super) async fn make_user_admin(_body: Vec<&str>, user_id: String) -> Result ))) } +#[admin_command] pub(super) async fn put_room_tag( - _body: Vec<&str>, user_id: String, room_id: Box<RoomId>, tag: String, + &self, user_id: String, room_id: Box<RoomId>, tag: String, ) -> Result<RoomMessageEventContent> { - let user_id = parse_active_local_user_id(services(), &user_id)?; + let user_id = parse_active_local_user_id(self.services, &user_id)?; - let event = services() + let event = self + .services .account_data .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)?; @@ -402,7 +416,7 @@ pub(super) async fn put_room_tag( .tags .insert(tag.clone().into(), TagInfo::new()); - services().account_data.update( + self.services.account_data.update( Some(&room_id), &user_id, RoomAccountDataEventType::Tag, @@ -414,12 +428,14 @@ pub(super) async fn put_room_tag( ))) } +#[admin_command] pub(super) async fn delete_room_tag( - _body: Vec<&str>, user_id: String, room_id: Box<RoomId>, tag: String, + &self, user_id: String, room_id: Box<RoomId>, tag: String, ) -> Result<RoomMessageEventContent> { - let user_id = parse_active_local_user_id(services(), &user_id)?; + let user_id = parse_active_local_user_id(self.services, &user_id)?; - let event = services() + let event = self + .services .account_data .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)?; @@ -434,7 +450,7 @@ pub(super) async fn delete_room_tag( 
tags_event.content.tags.remove(&tag.clone().into()); - services().account_data.update( + self.services.account_data.update( Some(&room_id), &user_id, RoomAccountDataEventType::Tag, @@ -446,12 +462,12 @@ pub(super) async fn delete_room_tag( ))) } -pub(super) async fn get_room_tags( - _body: Vec<&str>, user_id: String, room_id: Box<RoomId>, -) -> Result<RoomMessageEventContent> { - let user_id = parse_active_local_user_id(services(), &user_id)?; +#[admin_command] +pub(super) async fn get_room_tags(&self, user_id: String, room_id: Box<RoomId>) -> Result<RoomMessageEventContent> { + let user_id = parse_active_local_user_id(self.services, &user_id)?; - let event = services() + let event = self + .services .account_data .get(Some(&room_id), &user_id, RoomAccountDataEventType::Tag)?; diff --git a/src/admin/user/mod.rs b/src/admin/user/mod.rs index 1b92d668c..b0c0bd1ec 100644 --- a/src/admin/user/mod.rs +++ b/src/admin/user/mod.rs @@ -2,14 +2,16 @@ use clap::Subcommand; use conduit::Result; -use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomOrAliasId, RoomId}; +use ruma::{OwnedRoomOrAliasId, RoomId}; -use self::commands::*; +use crate::admin_command_dispatch; +#[admin_command_dispatch] #[derive(Debug, Subcommand)] pub(super) enum UserCommand { /// - Create a new user - Create { + #[clap(alias = "create")] + CreateUser { /// Username of the new user username: String, /// Password of the new user, if unspecified one is generated @@ -56,7 +58,8 @@ pub(super) enum UserCommand { }, /// - List local users in the database - List, + #[clap(alias = "list")] + ListUsers, /// - Lists all the rooms (local and remote) that the specified user is /// joined in @@ -101,48 +104,3 @@ pub(super) enum UserCommand { room_id: Box<RoomId>, }, } - -pub(super) async fn process(command: UserCommand, body: Vec<&str>) -> Result<RoomMessageEventContent> { - Ok(match command { - UserCommand::List => list(body).await?, - UserCommand::Create { - username, - password, - } => create(body, 
username, password).await?, - UserCommand::Deactivate { - no_leave_rooms, - user_id, - } => deactivate(body, no_leave_rooms, user_id).await?, - UserCommand::ResetPassword { - username, - } => reset_password(body, username).await?, - UserCommand::DeactivateAll { - no_leave_rooms, - force, - } => deactivate_all(body, no_leave_rooms, force).await?, - UserCommand::ListJoinedRooms { - user_id, - } => list_joined_rooms(body, user_id).await?, - UserCommand::ForceJoinRoom { - user_id, - room_id, - } => force_join_room(body, user_id, room_id).await?, - UserCommand::MakeUserAdmin { - user_id, - } => make_user_admin(body, user_id).await?, - UserCommand::PutRoomTag { - user_id, - room_id, - tag, - } => put_room_tag(body, user_id, room_id, tag).await?, - UserCommand::DeleteRoomTag { - user_id, - room_id, - tag, - } => delete_room_tag(body, user_id, room_id, tag).await?, - UserCommand::GetRoomTags { - user_id, - room_id, - } => get_room_tags(body, user_id, room_id).await?, - }) -} diff --git a/src/macros/admin.rs b/src/macros/admin.rs index 4189d64f9..d4ce7ad5f 100644 --- a/src/macros/admin.rs +++ b/src/macros/admin.rs @@ -2,15 +2,27 @@ use proc_macro::{Span, TokenStream}; use proc_macro2::TokenStream as TokenStream2; use quote::{quote, ToTokens}; -use syn::{Error, Fields, Ident, ItemEnum, Meta, Variant}; +use syn::{parse_quote, Attribute, Error, Fields, Ident, ItemEnum, ItemFn, Meta, Variant}; use crate::{utils::camel_to_snake_string, Result}; +pub(super) fn command(mut item: ItemFn, _args: &[Meta]) -> Result<TokenStream> { + let attr: Attribute = parse_quote! { + #[conduit_macros::implement(crate::Command, params = "<'_>")] + }; + + item.attrs.push(attr); + Ok(item.into_token_stream().into()) +} + pub(super) fn command_dispatch(item: ItemEnum, _args: &[Meta]) -> Result<TokenStream> { let name = &item.ident; let arm: Vec<TokenStream2> = item.variants.iter().map(dispatch_arm).try_collect()?; let switch = quote! 
{ - pub(super) async fn process(command: #name, body: Vec<&str>) -> Result<RoomMessageEventContent> { + pub(super) async fn process( + command: #name, + context: &crate::Command<'_> + ) -> Result<ruma::events::room::message::RoomMessageEventContent> { use #name::*; #[allow(non_snake_case)] Ok(match command { @@ -34,7 +46,7 @@ fn dispatch_arm(v: &Variant) -> Result<TokenStream2> { let field = fields.named.iter().filter_map(|f| f.ident.as_ref()); let arg = field.clone(); quote! { - #name { #( #field ),* } => Box::pin(#handler(&body, #( #arg ),*)).await?, + #name { #( #field ),* } => Box::pin(context.#handler(#( #arg ),*)).await?, } }, Fields::Unnamed(fields) => { @@ -42,12 +54,12 @@ fn dispatch_arm(v: &Variant) -> Result<TokenStream2> { return Err(Error::new(Span::call_site().into(), "One unnamed field required")); }; quote! { - #name ( #field ) => Box::pin(#handler::process(#field, body)).await?, + #name ( #field ) => Box::pin(#handler::process(#field, context)).await?, } }, Fields::Unit => { quote! 
{ - #name => Box::pin(#handler(&body)).await?, + #name => Box::pin(context.#handler()).await?, } }, }; diff --git a/src/macros/mod.rs b/src/macros/mod.rs index 1a5494bb0..d32cda71c 100644 --- a/src/macros/mod.rs +++ b/src/macros/mod.rs @@ -14,6 +14,11 @@ pub(crate) type Result<T> = std::result::Result<T, Error>; +#[proc_macro_attribute] +pub fn admin_command(args: TokenStream, input: TokenStream) -> TokenStream { + attribute_macro::<ItemFn, _>(args, input, admin::command) +} + #[proc_macro_attribute] pub fn admin_command_dispatch(args: TokenStream, input: TokenStream) -> TokenStream { attribute_macro::<ItemEnum, _>(args, input, admin::command_dispatch) -- GitLab From 2f85a5c1ac5fba02773004db7aa5e855c08364f4 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Sat, 27 Jul 2024 07:17:07 +0000 Subject: [PATCH 46/47] de-global services Signed-off-by: Jason Volk <jason@zemos.net> --- src/admin/check/commands.rs | 4 +- src/admin/handler.rs | 241 ++++++++++++++++------------------ src/admin/mod.rs | 22 +--- src/api/client/account.rs | 8 +- src/api/client/alias.rs | 2 +- src/api/client/config.rs | 4 +- src/api/client/directory.rs | 6 +- src/api/client/keys.rs | 4 +- src/api/client/media.rs | 8 +- src/api/client/membership.rs | 16 +-- src/api/client/message.rs | 4 +- src/api/client/profile.rs | 4 +- src/api/client/report.rs | 2 +- src/api/client/room.rs | 6 +- src/api/client/state.rs | 2 +- src/api/client/sync.rs | 13 +- src/api/mod.rs | 2 +- src/api/router.rs | 4 +- src/api/router/args.rs | 7 +- src/api/server/make_join.rs | 2 +- src/api/server/publicrooms.rs | 4 +- src/api/server/send.rs | 4 +- src/api/server/send_join.rs | 4 +- src/api/server/send_leave.rs | 4 +- src/api/server/user.rs | 4 +- src/database/database.rs | 6 +- src/main/main.rs | 37 +++++- src/main/server.rs | 6 +- src/router/layers.rs | 20 +-- src/router/mod.rs | 11 +- src/router/router.rs | 16 +-- src/router/run.rs | 47 ++++--- src/router/serve/mod.rs | 18 ++- src/service/admin/mod.rs | 20 ++- 
src/service/mod.rs | 53 +------- src/service/services.rs | 16 ++- 36 files changed, 311 insertions(+), 320 deletions(-) diff --git a/src/admin/check/commands.rs b/src/admin/check/commands.rs index a757d5044..0a9830464 100644 --- a/src/admin/check/commands.rs +++ b/src/admin/check/commands.rs @@ -2,7 +2,7 @@ use conduit_macros::implement; use ruma::events::room::message::RoomMessageEventContent; -use crate::{services, Command}; +use crate::Command; /// Uses the iterator in `src/database/key_value/users.rs` to iterator over /// every user in our database (remote and local). Reports total count, any @@ -10,7 +10,7 @@ #[implement(Command, params = "<'_>")] pub(super) async fn check_all_users(&self) -> Result<RoomMessageEventContent> { let timer = tokio::time::Instant::now(); - let results = services().users.db.iter(); + let results = self.services.users.db.iter(); let query_time = timer.elapsed(); let users = results.collect::<Vec<_>>(); diff --git a/src/admin/handler.rs b/src/admin/handler.rs index 32360c855..a7e4d79a1 100644 --- a/src/admin/handler.rs +++ b/src/admin/handler.rs @@ -1,4 +1,4 @@ -use std::{panic::AssertUnwindSafe, time::Instant}; +use std::{panic::AssertUnwindSafe, sync::Arc, time::Instant}; use clap::{CommandFactory, Parser}; use conduit::{error, trace, utils::string::common_prefix, Error, Result}; @@ -17,37 +17,27 @@ use crate::{admin, admin::AdminCommand, Command}; -struct Handler { - services: &'static Services, -} - #[must_use] -pub(super) fn complete(line: &str) -> String { - Handler { - services: service::services(), - } - .complete_command(AdminCommand::command(), line) -} +pub(super) fn complete(line: &str) -> String { complete_command(AdminCommand::command(), line) } #[must_use] -pub(super) fn handle(command: CommandInput) -> HandlerResult { Box::pin(handle_command(command)) } +pub(super) fn handle(services: Arc<Services>, command: CommandInput) -> HandlerResult { + Box::pin(handle_command(services, command)) +} 
#[tracing::instrument(skip_all, name = "admin")] -async fn handle_command(command: CommandInput) -> CommandResult { - AssertUnwindSafe(Box::pin(process_command(&command))) +async fn handle_command(services: Arc<Services>, command: CommandInput) -> CommandResult { + AssertUnwindSafe(Box::pin(process_command(services, &command))) .catch_unwind() .await .map_err(Error::from_panic) .or_else(|error| handle_panic(&error, command)) } -async fn process_command(command: &CommandInput) -> CommandOutput { - Handler { - services: service::services(), - } - .process(&command.command) - .await - .and_then(|content| reply(content, command.reply_id.clone())) +async fn process_command(services: Arc<Services>, command: &CommandInput) -> CommandOutput { + process(services, &command.command) + .await + .and_then(|content| reply(content, command.reply_id.clone())) } fn handle_panic(error: &Error, command: CommandInput) -> CommandResult { @@ -68,129 +58,126 @@ fn reply(mut content: RoomMessageEventContent, reply_id: Option<OwnedEventId>) - Some(content) } -impl Handler { - // Parse and process a message from the admin room - async fn process(&self, msg: &str) -> CommandOutput { - let mut lines = msg.lines().filter(|l| !l.trim().is_empty()); - let command = lines.next().expect("each string has at least one line"); - let (parsed, body) = match self.parse_command(command) { - Ok(parsed) => parsed, - Err(error) => { - let server_name = self.services.globals.server_name(); - let message = error.replace("server.name", server_name.as_str()); - return Some(RoomMessageEventContent::notice_markdown(message)); - }, - }; - - let timer = Instant::now(); - let body: Vec<&str> = body.iter().map(String::as_str).collect(); - let context = Command { - services: self.services, - body: &body, - }; - let result = Box::pin(admin::process(parsed, &context)).await; - let elapsed = timer.elapsed(); - conduit::debug!(?command, ok = result.is_ok(), "command processed in {elapsed:?}"); - match result { - Ok(reply) 
=> Some(reply), - Err(error) => Some(RoomMessageEventContent::notice_markdown(format!( - "Encountered an error while handling the command:\n```\n{error:#?}\n```" - ))), - } +// Parse and process a message from the admin room +async fn process(services: Arc<Services>, msg: &str) -> CommandOutput { + let mut lines = msg.lines().filter(|l| !l.trim().is_empty()); + let command = lines.next().expect("each string has at least one line"); + let (parsed, body) = match parse_command(command) { + Ok(parsed) => parsed, + Err(error) => { + let server_name = services.globals.server_name(); + let message = error.replace("server.name", server_name.as_str()); + return Some(RoomMessageEventContent::notice_markdown(message)); + }, + }; + + let timer = Instant::now(); + let body: Vec<&str> = body.iter().map(String::as_str).collect(); + let context = Command { + services: &services, + body: &body, + }; + let result = Box::pin(admin::process(parsed, &context)).await; + let elapsed = timer.elapsed(); + conduit::debug!(?command, ok = result.is_ok(), "command processed in {elapsed:?}"); + match result { + Ok(reply) => Some(reply), + Err(error) => Some(RoomMessageEventContent::notice_markdown(format!( + "Encountered an error while handling the command:\n```\n{error:#?}\n```" + ))), } +} - // Parse chat messages from the admin room into an AdminCommand object - fn parse_command(&self, command_line: &str) -> Result<(AdminCommand, Vec<String>), String> { - let argv = self.parse_line(command_line); - let com = AdminCommand::try_parse_from(&argv).map_err(|error| error.to_string())?; - Ok((com, argv)) - } +// Parse chat messages from the admin room into an AdminCommand object +fn parse_command(command_line: &str) -> Result<(AdminCommand, Vec<String>), String> { + let argv = parse_line(command_line); + let com = AdminCommand::try_parse_from(&argv).map_err(|error| error.to_string())?; + Ok((com, argv)) +} - fn complete_command(&self, mut cmd: clap::Command, line: &str) -> String { - let argv = 
self.parse_line(line); - let mut ret = Vec::<String>::with_capacity(argv.len().saturating_add(1)); - - 'token: for token in argv.into_iter().skip(1) { - let cmd_ = cmd.clone(); - let mut choice = Vec::new(); - - for sub in cmd_.get_subcommands() { - let name = sub.get_name(); - if *name == token { - // token already complete; recurse to subcommand - ret.push(token); - cmd.clone_from(sub); - continue 'token; - } else if name.starts_with(&token) { - // partial match; add to choices - choice.push(name); - } - } +fn complete_command(mut cmd: clap::Command, line: &str) -> String { + let argv = parse_line(line); + let mut ret = Vec::<String>::with_capacity(argv.len().saturating_add(1)); + + 'token: for token in argv.into_iter().skip(1) { + let cmd_ = cmd.clone(); + let mut choice = Vec::new(); - if choice.len() == 1 { - // One choice. Add extra space because it's complete - let choice = *choice.first().expect("only choice"); - ret.push(choice.to_owned()); - ret.push(String::new()); - } else if choice.is_empty() { - // Nothing found, return original string + for sub in cmd_.get_subcommands() { + let name = sub.get_name(); + if *name == token { + // token already complete; recurse to subcommand ret.push(token); - } else { - // Find the common prefix - ret.push(common_prefix(&choice).into()); + cmd.clone_from(sub); + continue 'token; + } else if name.starts_with(&token) { + // partial match; add to choices + choice.push(name); } + } - // Return from completion - return ret.join(" "); + if choice.len() == 1 { + // One choice. Add extra space because it's complete + let choice = *choice.first().expect("only choice"); + ret.push(choice.to_owned()); + ret.push(String::new()); + } else if choice.is_empty() { + // Nothing found, return original string + ret.push(token); + } else { + // Find the common prefix + ret.push(common_prefix(&choice).into()); } - // Return from no completion. Needs a space though. 
- ret.push(String::new()); - ret.join(" ") + // Return from completion + return ret.join(" "); } - // Parse chat messages from the admin room into an AdminCommand object - fn parse_line(&self, command_line: &str) -> Vec<String> { - let mut argv = command_line - .split_whitespace() - .map(str::to_owned) - .collect::<Vec<String>>(); + // Return from no completion. Needs a space though. + ret.push(String::new()); + ret.join(" ") +} - // Remove any escapes that came with a server-side escape command - if !argv.is_empty() && argv[0].ends_with("admin") { - argv[0] = argv[0].trim_start_matches('\\').into(); - } +// Parse chat messages from the admin room into an AdminCommand object +fn parse_line(command_line: &str) -> Vec<String> { + let mut argv = command_line + .split_whitespace() + .map(str::to_owned) + .collect::<Vec<String>>(); - // First indice has to be "admin" but for console convenience we add it here - let server_user = self.services.globals.server_user.as_str(); - if !argv.is_empty() && !argv[0].ends_with("admin") && !argv[0].starts_with(server_user) { - argv.insert(0, "admin".to_owned()); - } + // Remove any escapes that came with a server-side escape command + if !argv.is_empty() && argv[0].ends_with("admin") { + argv[0] = argv[0].trim_start_matches('\\').into(); + } - // Replace `help command` with `command --help` - // Clap has a help subcommand, but it omits the long help description. 
- if argv.len() > 1 && argv[1] == "help" { - argv.remove(1); - argv.push("--help".to_owned()); - } + // First indice has to be "admin" but for console convenience we add it here + if !argv.is_empty() && !argv[0].ends_with("admin") && !argv[0].starts_with('@') { + argv.insert(0, "admin".to_owned()); + } - // Backwards compatibility with `register_appservice`-style commands - if argv.len() > 1 && argv[1].contains('_') { - argv[1] = argv[1].replace('_', "-"); - } + // Replace `help command` with `command --help` + // Clap has a help subcommand, but it omits the long help description. + if argv.len() > 1 && argv[1] == "help" { + argv.remove(1); + argv.push("--help".to_owned()); + } - // Backwards compatibility with `register_appservice`-style commands - if argv.len() > 2 && argv[2].contains('_') { - argv[2] = argv[2].replace('_', "-"); - } + // Backwards compatibility with `register_appservice`-style commands + if argv.len() > 1 && argv[1].contains('_') { + argv[1] = argv[1].replace('_', "-"); + } - // if the user is using the `query` command (argv[1]), replace the database - // function/table calls with underscores to match the codebase - if argv.len() > 3 && argv[1].eq("query") { - argv[3] = argv[3].replace('_', "-"); - } + // Backwards compatibility with `register_appservice`-style commands + if argv.len() > 2 && argv[2].contains('_') { + argv[2] = argv[2].replace('_', "-"); + } - trace!(?command_line, ?argv, "parse"); - argv + // if the user is using the `query` command (argv[1]), replace the database + // function/table calls with underscores to match the codebase + if argv.len() > 3 && argv[1].eq("query") { + argv[3] = argv[3].replace('_', "-"); } + + trace!(?command_line, ?argv, "parse"); + argv } diff --git a/src/admin/mod.rs b/src/admin/mod.rs index 5d4c8f5e4..fb1c02be7 100644 --- a/src/admin/mod.rs +++ b/src/admin/mod.rs @@ -1,4 +1,4 @@ -#![recursion_limit = "168"] +#![recursion_limit = "192"] #![allow(clippy::wildcard_imports)] 
#![allow(clippy::enum_glob_use)] @@ -24,7 +24,6 @@ pub(crate) use conduit::Result; pub(crate) use conduit_macros::{admin_command, admin_command_dispatch}; -pub(crate) use service::services; pub(crate) use crate::{ command::Command, @@ -38,26 +37,19 @@ conduit::rustc_flags_capture! {} /// Install the admin command handler -pub async fn init() { - _ = services() - .admin +pub async fn init(admin_service: &service::admin::Service) { + _ = admin_service .complete .write() .expect("locked for writing") .insert(handler::complete); - _ = services() - .admin - .handle - .write() - .await - .insert(handler::handle); + _ = admin_service.handle.write().await.insert(handler::handle); } /// Uninstall the admin command handler -pub async fn fini() { - _ = services().admin.handle.write().await.take(); - _ = services() - .admin +pub async fn fini(admin_service: &service::admin::Service) { + _ = admin_service.handle.write().await.take(); + _ = admin_service .complete .write() .expect("locked for writing") diff --git a/src/api/client/account.rs b/src/api/client/account.rs index 7c2bb0b6a..f093c459c 100644 --- a/src/api/client/account.rs +++ b/src/api/client/account.rs @@ -370,7 +370,7 @@ pub(crate) async fn register_route( if let Some(room_id_server_name) = room.server_name() { if let Err(e) = join_room_by_id_helper( - services, + &services, &user_id, room, Some("Automatically joining this room upon registration".to_owned()), @@ -562,11 +562,11 @@ pub(crate) async fn deactivate_route( .rooms_joined(sender_user) .filter_map(Result::ok) .collect(); - super::update_displayname(services, sender_user.clone(), None, all_joined_rooms.clone()).await?; - super::update_avatar_url(services, sender_user.clone(), None, None, all_joined_rooms).await?; + super::update_displayname(&services, sender_user.clone(), None, all_joined_rooms.clone()).await?; + super::update_avatar_url(&services, sender_user.clone(), None, None, all_joined_rooms).await?; // Make the user leave all rooms before deactivation 
- super::leave_all_rooms(services, sender_user).await; + super::leave_all_rooms(&services, sender_user).await; info!("User {sender_user} deactivated their account."); services diff --git a/src/api/client/alias.rs b/src/api/client/alias.rs index dbc75e641..18d1c5b0e 100644 --- a/src/api/client/alias.rs +++ b/src/api/client/alias.rs @@ -107,7 +107,7 @@ pub(crate) async fn get_alias_route( return Err(Error::BadRequest(ErrorKind::NotFound, "Room with alias not found.")); }; - let servers = room_available_servers(services, &room_id, &room_alias, &pre_servers); + let servers = room_available_servers(&services, &room_id, &room_alias, &pre_servers); debug!(?room_alias, ?room_id, "available servers: {servers:?}"); Ok(get_alias::v3::Response::new(room_id, servers)) diff --git a/src/api/client/config.rs b/src/api/client/config.rs index 56d33ba7e..61cc97ff5 100644 --- a/src/api/client/config.rs +++ b/src/api/client/config.rs @@ -20,7 +20,7 @@ pub(crate) async fn set_global_account_data_route( State(services): State<crate::State>, body: Ruma<set_global_account_data::v3::Request>, ) -> Result<set_global_account_data::v3::Response> { set_account_data( - services, + &services, None, &body.sender_user, &body.event_type.to_string(), @@ -37,7 +37,7 @@ pub(crate) async fn set_room_account_data_route( State(services): State<crate::State>, body: Ruma<set_room_account_data::v3::Request>, ) -> Result<set_room_account_data::v3::Response> { set_account_data( - services, + &services, Some(&body.room_id), &body.sender_user, &body.event_type.to_string(), diff --git a/src/api/client/directory.rs b/src/api/client/directory.rs index cb30b60a5..6054bd9cd 100644 --- a/src/api/client/directory.rs +++ b/src/api/client/directory.rs @@ -48,7 +48,7 @@ pub(crate) async fn get_public_rooms_filtered_route( } let response = get_public_rooms_filtered_helper( - services, + &services, body.server.as_deref(), body.limit, body.since.as_deref(), @@ -88,7 +88,7 @@ pub(crate) async fn get_public_rooms_route( } let 
response = get_public_rooms_filtered_helper( - services, + &services, body.server.as_deref(), body.limit, body.since.as_deref(), @@ -124,7 +124,7 @@ pub(crate) async fn set_room_visibility_route( return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found")); } - if !user_can_publish_room(services, sender_user, &body.room_id)? { + if !user_can_publish_room(&services, sender_user, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::forbidden(), "User is not allowed to publish this room", diff --git a/src/api/client/keys.rs b/src/api/client/keys.rs index 8489dde35..6fa7a8ad2 100644 --- a/src/api/client/keys.rs +++ b/src/api/client/keys.rs @@ -77,7 +77,7 @@ pub(crate) async fn get_keys_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); get_keys_helper( - services, + &services, Some(sender_user), &body.device_keys, |u| u == sender_user, @@ -92,7 +92,7 @@ pub(crate) async fn get_keys_route( pub(crate) async fn claim_keys_route( State(services): State<crate::State>, body: Ruma<claim_keys::v3::Request>, ) -> Result<claim_keys::v3::Response> { - claim_keys_helper(services, &body.one_time_keys).await + claim_keys_helper(&services, &body.one_time_keys).await } /// # `POST /_matrix/client/r0/keys/device_signing/upload` diff --git a/src/api/client/media.rs b/src/api/client/media.rs index 78463fc6b..f0afa290e 100644 --- a/src/api/client/media.rs +++ b/src/api/client/media.rs @@ -76,12 +76,12 @@ pub(crate) async fn get_media_preview_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let url = &body.url; - if !url_preview_allowed(services, url) { + if !url_preview_allowed(&services, url) { warn!(%sender_user, "URL is not allowed to be previewed: {url}"); return Err(Error::BadRequest(ErrorKind::forbidden(), "URL is not allowed to be previewed")); } - match get_url_preview(services, url).await { + match get_url_preview(&services, url).await { Ok(preview) => { let res = 
serde_json::value::to_raw_value(&preview).map_err(|e| { error!(%sender_user, "Failed to convert UrlPreviewData into a serde json value: {e}"); @@ -221,7 +221,7 @@ pub(crate) async fn get_content_route( }) } else if !services.globals.server_is_ours(&body.server_name) && body.allow_remote { let response = get_remote_content( - services, + &services, &mxc, &body.server_name, body.media_id.clone(), @@ -311,7 +311,7 @@ pub(crate) async fn get_content_as_filename_route( }) } else if !services.globals.server_is_ours(&body.server_name) && body.allow_remote { match get_remote_content( - services, + &services, &mxc, &body.server_name, body.media_id.clone(), diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index d3b2d8f66..ca7e6b6ff 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -167,7 +167,7 @@ pub(crate) async fn join_room_by_id_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); banned_room_check( - services, + &services, sender_user, Some(&body.room_id), body.room_id.server_name(), @@ -202,7 +202,7 @@ pub(crate) async fn join_room_by_id_route( } join_room_by_id_helper( - services, + &services, sender_user, &body.room_id, body.reason.clone(), @@ -231,7 +231,7 @@ pub(crate) async fn join_room_by_id_or_alias_route( let (servers, room_id) = match OwnedRoomId::try_from(body.room_id_or_alias) { Ok(room_id) => { - banned_room_check(services, sender_user, Some(&room_id), room_id.server_name(), client).await?; + banned_room_check(&services, sender_user, Some(&room_id), room_id.server_name(), client).await?; let mut servers = body.server_name.clone(); servers.extend( @@ -270,7 +270,7 @@ pub(crate) async fn join_room_by_id_or_alias_route( .await?; let (room_id, mut pre_servers) = response; - banned_room_check(services, sender_user, Some(&room_id), Some(room_alias.server_name()), client).await?; + banned_room_check(&services, sender_user, Some(&room_id), Some(room_alias.server_name()), 
client).await?; let mut servers = body.server_name; if let Some(pre_servers) = &mut pre_servers { @@ -303,7 +303,7 @@ pub(crate) async fn join_room_by_id_or_alias_route( }; let join_room_response = join_room_by_id_helper( - services, + &services, sender_user, &room_id, body.reason.clone(), @@ -327,7 +327,7 @@ pub(crate) async fn leave_room_route( ) -> Result<leave_room::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - leave_room(services, sender_user, &body.room_id, body.reason.clone()).await?; + leave_room(&services, sender_user, &body.room_id, body.reason.clone()).await?; Ok(leave_room::v3::Response::new()) } @@ -353,13 +353,13 @@ pub(crate) async fn invite_user_route( )); } - banned_room_check(services, sender_user, Some(&body.room_id), body.room_id.server_name(), client).await?; + banned_room_check(&services, sender_user, Some(&body.room_id), body.room_id.server_name(), client).await?; if let invite_user::v3::InvitationRecipient::UserId { user_id, } = &body.recipient { - invite_helper(services, sender_user, user_id, &body.room_id, body.reason.clone(), false).await?; + invite_helper(&services, sender_user, user_id, &body.room_id, body.reason.clone(), false).await?; Ok(invite_user::v3::Response {}) } else { Err(Error::BadRequest(ErrorKind::NotFound, "User not found.")) diff --git a/src/api/client/message.rs b/src/api/client/message.rs index c0b5cf0c3..9aae4aaf6 100644 --- a/src/api/client/message.rs +++ b/src/api/client/message.rs @@ -146,7 +146,7 @@ pub(crate) async fn get_message_events_route( .timeline .pdus_after(sender_user, &body.room_id, from)? 
.filter_map(Result::ok) // Filter out buggy events - .filter(|(_, pdu)| { contains_url_filter(pdu, &body.filter) && visibility_filter(services, pdu, sender_user, &body.room_id) + .filter(|(_, pdu)| { contains_url_filter(pdu, &body.filter) && visibility_filter(&services, pdu, sender_user, &body.room_id) }) .take_while(|&(k, _)| Some(k) != to) // Stop at `to` @@ -193,7 +193,7 @@ pub(crate) async fn get_message_events_route( .timeline .pdus_until(sender_user, &body.room_id, from)? .filter_map(Result::ok) // Filter out buggy events - .filter(|(_, pdu)| {contains_url_filter(pdu, &body.filter) && visibility_filter(services, pdu, sender_user, &body.room_id)}) + .filter(|(_, pdu)| {contains_url_filter(pdu, &body.filter) && visibility_filter(&services, pdu, sender_user, &body.room_id)}) .take_while(|&(k, _)| Some(k) != to) // Stop at `to` .take(limit) .collect(); diff --git a/src/api/client/profile.rs b/src/api/client/profile.rs index 9e9bcf8e0..8f6d90568 100644 --- a/src/api/client/profile.rs +++ b/src/api/client/profile.rs @@ -33,7 +33,7 @@ pub(crate) async fn set_displayname_route( .filter_map(Result::ok) .collect(); - update_displayname(services, sender_user.clone(), body.displayname.clone(), all_joined_rooms).await?; + update_displayname(&services, sender_user.clone(), body.displayname.clone(), all_joined_rooms).await?; if services.globals.allow_local_presence() { // Presence update @@ -118,7 +118,7 @@ pub(crate) async fn set_avatar_url_route( .collect(); update_avatar_url( - services, + &services, sender_user.clone(), body.avatar_url.clone(), body.blurhash.clone(), diff --git a/src/api/client/report.rs b/src/api/client/report.rs index a16df4448..dc87fd214 100644 --- a/src/api/client/report.rs +++ b/src/api/client/report.rs @@ -40,7 +40,7 @@ pub(crate) async fn report_event_route( }; is_report_valid( - services, + &services, &pdu.event_id, &body.room_id, sender_user, diff --git a/src/api/client/room.rs b/src/api/client/room.rs index c4d828220..c78ba6edb 100644 --- 
a/src/api/client/room.rs +++ b/src/api/client/room.rs @@ -79,7 +79,7 @@ pub(crate) async fn create_room_route( } let room_id: OwnedRoomId = if let Some(custom_room_id) = &body.room_id { - custom_room_id_check(services, custom_room_id)? + custom_room_id_check(&services, custom_room_id)? } else { RoomId::new(&services.globals.config.server_name) }; @@ -96,7 +96,7 @@ pub(crate) async fn create_room_route( let state_lock = services.rooms.state.mutex.lock(&room_id).await; let alias: Option<OwnedRoomAliasId> = if let Some(alias) = &body.room_alias_name { - Some(room_alias_check(services, alias, &body.appservice_info).await?) + Some(room_alias_check(&services, alias, &body.appservice_info).await?) } else { None }; @@ -438,7 +438,7 @@ pub(crate) async fn create_room_route( // 8. Events implied by invite (and TODO: invite_3pid) drop(state_lock); for user_id in &body.invite { - if let Err(e) = invite_helper(services, sender_user, user_id, &room_id, None, body.is_direct).await { + if let Err(e) = invite_helper(&services, sender_user, user_id, &room_id, None, body.is_direct).await { warn!(%e, "Failed to send invite"); } } diff --git a/src/api/client/state.rs b/src/api/client/state.rs index 7af4f5f97..d0fb83d17 100644 --- a/src/api/client/state.rs +++ b/src/api/client/state.rs @@ -37,7 +37,7 @@ pub(crate) async fn send_state_event_for_key_route( Ok(send_state_event::v3::Response { event_id: send_state_event_for_key_helper( - services, + &services, sender_user, &body.room_id, &body.event_type, diff --git a/src/api/client/sync.rs b/src/api/client/sync.rs index 6eeb8fffe..34baf7c1c 100644 --- a/src/api/client/sync.rs +++ b/src/api/client/sync.rs @@ -137,7 +137,7 @@ pub(crate) async fn sync_events_route( ); if services.globals.allow_local_presence() { - process_presence_updates(services, &mut presence_updates, since, &sender_user).await?; + process_presence_updates(&services, &mut presence_updates, since, &sender_user).await?; } let all_joined_rooms = services @@ -152,7 +152,7 @@ 
pub(crate) async fn sync_events_route( for room_id in all_joined_rooms { let room_id = room_id?; if let Ok(joined_room) = load_joined_room( - services, + &services, &sender_user, &sender_device, &room_id, @@ -182,7 +182,7 @@ pub(crate) async fn sync_events_route( .collect(); for result in all_left_rooms { handle_left_room( - services, + &services, since, &result?.0, &sender_user, @@ -1214,7 +1214,7 @@ pub(crate) async fn sync_events_v4_route( match new_membership { MembershipState::Join => { // A new user joined an encrypted room - if !share_encrypted_room(services, &sender_user, &user_id, room_id)? { + if !share_encrypted_room(&services, &sender_user, &user_id, room_id)? { device_list_changes.insert(user_id); } }, @@ -1243,7 +1243,7 @@ pub(crate) async fn sync_events_v4_route( .filter(|user_id| { // Only send keys if the sender doesn't share an encrypted room with the target // already - !share_encrypted_room(services, &sender_user, user_id, room_id).unwrap_or(false) + !share_encrypted_room(&services, &sender_user, user_id, room_id).unwrap_or(false) }), ); } @@ -1407,7 +1407,8 @@ pub(crate) async fn sync_events_v4_route( for (room_id, (required_state_request, timeline_limit, roomsince)) in &todo_rooms { let roomsincecount = PduCount::Normal(*roomsince); - let (timeline_pdus, limited) = load_timeline(services, &sender_user, room_id, roomsincecount, *timeline_limit)?; + let (timeline_pdus, limited) = + load_timeline(&services, &sender_user, room_id, roomsincecount, *timeline_limit)?; if roomsince != &0 && timeline_pdus.is_empty() { continue; diff --git a/src/api/mod.rs b/src/api/mod.rs index c7411b6c1..7fa70873a 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,4 +1,4 @@ -#![recursion_limit = "160"] +#![recursion_limit = "192"] pub mod client; pub mod router; diff --git a/src/api/router.rs b/src/api/router.rs index 761c173cc..d624de32b 100644 --- a/src/api/router.rs +++ b/src/api/router.rs @@ -4,6 +4,8 @@ mod request; mod response; +use std::sync::Arc; + use 
axum::{ response::IntoResponse, routing::{any, get, post}, @@ -16,7 +18,7 @@ pub(super) use self::{args::Args as Ruma, response::RumaResponse}; use crate::{client, server}; -pub type State = &'static service::Services; +pub type State = Arc<service::Services>; pub fn build(router: Router<State>, server: &Server) -> Router<State> { let config = &server.config; diff --git a/src/api/router/args.rs b/src/api/router/args.rs index fa5b1e439..a3d09dff5 100644 --- a/src/api/router/args.rs +++ b/src/api/router/args.rs @@ -7,7 +7,7 @@ use service::Services; use super::{auth, auth::Auth, request, request::Request}; -use crate::service::appservice::RegistrationInfo; +use crate::{service::appservice::RegistrationInfo, State}; /// Extractor for Ruma request structs pub(crate) struct Args<T> { @@ -36,14 +36,13 @@ pub(crate) struct Args<T> { } #[async_trait] -impl<T, S> FromRequest<S, Body> for Args<T> +impl<T> FromRequest<State, Body> for Args<T> where T: IncomingRequest, { type Rejection = Error; - async fn from_request(request: hyper::Request<Body>, _: &S) -> Result<Self, Self::Rejection> { - let services = service::services(); // ??? + async fn from_request(request: hyper::Request<Body>, services: &State) -> Result<Self, Self::Rejection> { let mut request = request::from(services, request).await?; let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&request.body).ok(); let auth = auth::auth(services, &mut request, &json_body, &T::METADATA).await?; diff --git a/src/api/server/make_join.rs b/src/api/server/make_join.rs index e1beaa33c..e9b1a6c78 100644 --- a/src/api/server/make_join.rs +++ b/src/api/server/make_join.rs @@ -82,7 +82,7 @@ pub(crate) async fn create_join_event_template_route( .state_cache .is_left(&body.user_id, &body.room_id) .unwrap_or(true)) - && user_can_perform_restricted_join(services, &body.user_id, &body.room_id, &room_version_id)? + && user_can_perform_restricted_join(&services, &body.user_id, &body.room_id, &room_version_id)? 
{ let auth_user = services .rooms diff --git a/src/api/server/publicrooms.rs b/src/api/server/publicrooms.rs index 1876dde17..af8a58464 100644 --- a/src/api/server/publicrooms.rs +++ b/src/api/server/publicrooms.rs @@ -26,7 +26,7 @@ pub(crate) async fn get_public_rooms_filtered_route( } let response = crate::client::get_public_rooms_filtered_helper( - services, + &services, None, body.limit, body.since.as_deref(), @@ -60,7 +60,7 @@ pub(crate) async fn get_public_rooms_route( } let response = crate::client::get_public_rooms_filtered_helper( - services, + &services, None, body.limit, body.since.as_deref(), diff --git a/src/api/server/send.rs b/src/api/server/send.rs index 2f698d337..394289a6f 100644 --- a/src/api/server/send.rs +++ b/src/api/server/send.rs @@ -62,8 +62,8 @@ pub(crate) async fn send_transaction_message_route( "Starting txn", ); - let resolved_map = handle_pdus(services, &client, &body, origin, &txn_start_time).await?; - handle_edus(services, &client, &body, origin).await?; + let resolved_map = handle_pdus(&services, &client, &body, origin, &txn_start_time).await?; + handle_edus(&services, &client, &body, origin).await?; debug!( pdus = ?body.pdus.len(), diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs index 4cd29795f..17f563831 100644 --- a/src/api/server/send_join.rs +++ b/src/api/server/send_join.rs @@ -241,7 +241,7 @@ pub(crate) async fn create_join_event_v1_route( } } - let room_state = create_join_event(services, origin, &body.room_id, &body.pdu).await?; + let room_state = create_join_event(&services, origin, &body.room_id, &body.pdu).await?; Ok(create_join_event::v1::Response { room_state, @@ -286,7 +286,7 @@ pub(crate) async fn create_join_event_v2_route( auth_chain, state, event, - } = create_join_event(services, origin, &body.room_id, &body.pdu).await?; + } = create_join_event(&services, origin, &body.room_id, &body.pdu).await?; let room_state = create_join_event::v2::RoomState { members_omitted: false, auth_chain, diff 
--git a/src/api/server/send_leave.rs b/src/api/server/send_leave.rs index b1b8fec81..ef4c8c454 100644 --- a/src/api/server/send_leave.rs +++ b/src/api/server/send_leave.rs @@ -28,7 +28,7 @@ pub(crate) async fn create_leave_event_v1_route( ) -> Result<create_leave_event::v1::Response> { let origin = body.origin.as_ref().expect("server is authenticated"); - create_leave_event(services, origin, &body.room_id, &body.pdu).await?; + create_leave_event(&services, origin, &body.room_id, &body.pdu).await?; Ok(create_leave_event::v1::Response::new()) } @@ -41,7 +41,7 @@ pub(crate) async fn create_leave_event_v2_route( ) -> Result<create_leave_event::v2::Response> { let origin = body.origin.as_ref().expect("server is authenticated"); - create_leave_event(services, origin, &body.room_id, &body.pdu).await?; + create_leave_event(&services, origin, &body.room_id, &body.pdu).await?; Ok(create_leave_event::v2::Response::new()) } diff --git a/src/api/server/user.rs b/src/api/server/user.rs index bd0372e66..e9a400a79 100644 --- a/src/api/server/user.rs +++ b/src/api/server/user.rs @@ -84,7 +84,7 @@ pub(crate) async fn get_keys_route( } let result = get_keys_helper( - services, + &services, None, &body.device_keys, |u| Some(u.server_name()) == body.origin.as_deref(), @@ -116,7 +116,7 @@ pub(crate) async fn claim_keys_route( )); } - let result = claim_keys_helper(services, &body.one_time_keys).await?; + let result = claim_keys_helper(&services, &body.one_time_keys).await?; Ok(claim_keys::v1::Response { one_time_keys: result.one_time_keys, diff --git a/src/database/database.rs b/src/database/database.rs index 44bb655cc..1d7bbc33a 100644 --- a/src/database/database.rs +++ b/src/database/database.rs @@ -11,12 +11,12 @@ pub struct Database { impl Database { /// Load an existing database or create a new one. 
- pub async fn open(server: &Arc<Server>) -> Result<Self> { + pub async fn open(server: &Arc<Server>) -> Result<Arc<Self>> { let db = Engine::open(server)?; - Ok(Self { + Ok(Arc::new(Self { db: db.clone(), map: maps::open(&db)?, - }) + })) } #[inline] diff --git a/src/main/main.rs b/src/main/main.rs index b8cb24ff1..8703eef2b 100644 --- a/src/main/main.rs +++ b/src/main/main.rs @@ -1,3 +1,5 @@ +#![recursion_limit = "192"] + pub(crate) mod clap; mod mods; mod restart; @@ -57,17 +59,38 @@ fn main() -> Result<(), Error> { async fn async_main(server: &Arc<Server>) -> Result<(), Error> { extern crate conduit_router as router; - if let Err(error) = router::start(&server.server).await { - error!("Critical error starting server: {error}"); - return Err(error); - } - - if let Err(error) = router::run(&server.server).await { + match router::start(&server.server).await { + Ok(services) => server.services.lock().await.insert(services), + Err(error) => { + error!("Critical error starting server: {error}"); + return Err(error); + }, + }; + + if let Err(error) = router::run( + server + .services + .lock() + .await + .as_ref() + .expect("services initialized"), + ) + .await + { error!("Critical error running server: {error}"); return Err(error); } - if let Err(error) = router::stop(&server.server).await { + if let Err(error) = router::stop( + server + .services + .lock() + .await + .take() + .expect("services initialied"), + ) + .await + { error!("Critical error stopping server: {error}"); return Err(error); } diff --git a/src/main/server.rs b/src/main/server.rs index 71cdadce4..e435b2f44 100644 --- a/src/main/server.rs +++ b/src/main/server.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use conduit::{config::Config, info, log::Log, utils::sys, Error, Result}; -use tokio::runtime; +use tokio::{runtime, sync::Mutex}; use crate::{clap::Args, tracing::TracingFlameGuard}; @@ -10,6 +10,8 @@ pub(crate) struct Server { /// Server runtime state; public portion pub(crate) server: 
Arc<conduit::Server>, + pub(crate) services: Mutex<Option<Arc<conduit_service::Services>>>, + _tracing_flame_guard: TracingFlameGuard, #[cfg(feature = "sentry_telemetry")] @@ -54,6 +56,8 @@ pub(crate) fn build(args: &Args, runtime: Option<&runtime::Handle>) -> Result<Ar }, )), + services: None.into(), + _tracing_flame_guard: tracing_flame_guard, #[cfg(feature = "sentry_telemetry")] diff --git a/src/router/layers.rs b/src/router/layers.rs index 67342eb30..2b143666b 100644 --- a/src/router/layers.rs +++ b/src/router/layers.rs @@ -6,6 +6,7 @@ }; use axum_client_ip::SecureClientIpSource; use conduit::{error, Result, Server}; +use conduit_service::Services; use http::{ header::{self, HeaderName}, HeaderValue, Method, StatusCode, @@ -34,7 +35,8 @@ const CONDUWUIT_PERMISSIONS_POLICY: &[&str] = &["interest-cohort=()", "browsing-topics=()"]; -pub(crate) fn build(server: &Arc<Server>) -> Result<Router> { +pub(crate) fn build(services: &Arc<Services>) -> Result<Router> { + let server = &services.server; let layers = ServiceBuilder::new(); #[cfg(feature = "sentry_telemetry")] @@ -83,7 +85,7 @@ pub(crate) fn build(server: &Arc<Server>) -> Result<Router> { .layer(body_limit_layer(server)) .layer(CatchPanicLayer::custom(catch_panic)); - Ok(router::build(server).layer(layers)) + Ok(router::build(services).layer(layers)) } #[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))] @@ -151,12 +153,14 @@ fn body_limit_layer(server: &Server) -> DefaultBodyLimit { DefaultBodyLimit::max #[allow(clippy::needless_pass_by_value)] #[tracing::instrument(skip_all, name = "panic")] fn catch_panic(err: Box<dyn Any + Send + 'static>) -> http::Response<http_body_util::Full<bytes::Bytes>> { - conduit_service::services() - .server - .metrics - .requests_panic - .fetch_add(1, std::sync::atomic::Ordering::Release); - + //TODO: XXX + /* + conduit_service::services() + .server + .metrics + .requests_panic + .fetch_add(1, 
std::sync::atomic::Ordering::Release); + */ let details = if let Some(s) = err.downcast_ref::<String>() { s.clone() } else if let Some(s) = err.downcast_ref::<&str>() { diff --git a/src/router/mod.rs b/src/router/mod.rs index 13fe39087..67ebc0e3f 100644 --- a/src/router/mod.rs +++ b/src/router/mod.rs @@ -11,22 +11,23 @@ use std::{future::Future, pin::Pin, sync::Arc}; use conduit::{Result, Server}; +use conduit_service::Services; conduit::mod_ctor! {} conduit::mod_dtor! {} conduit::rustc_flags_capture! {} #[no_mangle] -pub extern "Rust" fn start(server: &Arc<Server>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> { +pub extern "Rust" fn start(server: &Arc<Server>) -> Pin<Box<dyn Future<Output = Result<Arc<Services>>> + Send>> { Box::pin(run::start(server.clone())) } #[no_mangle] -pub extern "Rust" fn stop(server: &Arc<Server>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> { - Box::pin(run::stop(server.clone())) +pub extern "Rust" fn stop(services: Arc<Services>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> { + Box::pin(run::stop(services)) } #[no_mangle] -pub extern "Rust" fn run(server: &Arc<Server>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> { - Box::pin(run::run(server.clone())) +pub extern "Rust" fn run(services: &Arc<Services>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> { + Box::pin(run::run(services.clone())) } diff --git a/src/router/router.rs b/src/router/router.rs index 7c374b47d..3527f1e6d 100644 --- a/src/router/router.rs +++ b/src/router/router.rs @@ -1,20 +1,20 @@ use std::sync::Arc; use axum::{response::IntoResponse, routing::get, Router}; -use conduit::{Error, Server}; +use conduit::Error; +use conduit_api::State; +use conduit_service::Services; use http::{StatusCode, Uri}; use ruma::api::client::error::ErrorKind; -extern crate conduit_api as api; -extern crate conduit_service as service; +pub(crate) fn build(services: &Arc<Services>) -> Router { + let router = Router::<State>::new(); + let state = services.clone(); 
-pub(crate) fn build(server: &Arc<Server>) -> Router { - let router = Router::<api::State>::new(); - - api::router::build(router, server) + conduit_api::router::build(router, &services.server) .route("/", get(it_works)) .fallback(not_found) - .with_state(service::services()) + .with_state(state) } async fn not_found(_uri: Uri) -> impl IntoResponse { diff --git a/src/router/run.rs b/src/router/run.rs index cb5d2abf3..395aa8c4d 100644 --- a/src/router/run.rs +++ b/src/router/run.rs @@ -1,28 +1,30 @@ -use std::{sync::Arc, time::Duration}; +extern crate conduit_admin as admin; +extern crate conduit_core as conduit; +extern crate conduit_service as service; + +use std::{ + sync::{atomic::Ordering, Arc}, + time::Duration, +}; use axum_server::Handle as ServerHandle; +use conduit::{debug, debug_error, debug_info, error, info, Error, Result, Server}; +use service::Services; use tokio::{ sync::broadcast::{self, Sender}, task::JoinHandle, }; -extern crate conduit_admin as admin; -extern crate conduit_core as conduit; -extern crate conduit_service as service; - -use std::sync::atomic::Ordering; - -use conduit::{debug, debug_info, error, info, Error, Result, Server}; - use crate::serve; /// Main loop base #[tracing::instrument(skip_all)] -pub(crate) async fn run(server: Arc<Server>) -> Result<()> { +pub(crate) async fn run(services: Arc<Services>) -> Result<()> { + let server = &services.server; debug!("Start"); // Install the admin room callback here for now - admin::init().await; + admin::init(&services.admin).await; // Setup shutdown/signal handling let handle = ServerHandle::new(); @@ -33,13 +35,13 @@ pub(crate) async fn run(server: Arc<Server>) -> Result<()> { let mut listener = server .runtime() - .spawn(serve::serve(server.clone(), handle.clone(), tx.subscribe())); + .spawn(serve::serve(services.clone(), handle.clone(), tx.subscribe())); // Focal point debug!("Running"); let res = tokio::select! 
{ res = &mut listener => res.map_err(Error::from).unwrap_or_else(Err), - res = service::services().poll() => handle_services_poll(&server, res, listener).await, + res = services.poll() => handle_services_poll(server, res, listener).await, }; // Join the signal handler before we leave. @@ -47,7 +49,7 @@ pub(crate) async fn run(server: Arc<Server>) -> Result<()> { _ = sigs.await; // Remove the admin room callback - admin::fini().await; + admin::fini(&services.admin).await; debug_info!("Finish"); res @@ -55,26 +57,33 @@ pub(crate) async fn run(server: Arc<Server>) -> Result<()> { /// Async initializations #[tracing::instrument(skip_all)] -pub(crate) async fn start(server: Arc<Server>) -> Result<()> { +pub(crate) async fn start(server: Arc<Server>) -> Result<Arc<Services>> { debug!("Starting..."); - service::start(&server).await?; + let services = Services::build(server).await?.start().await?; #[cfg(feature = "systemd")] sd_notify::notify(true, &[sd_notify::NotifyState::Ready]).expect("failed to notify systemd of ready state"); debug!("Started"); - Ok(()) + Ok(services) } /// Async destructions #[tracing::instrument(skip_all)] -pub(crate) async fn stop(_server: Arc<Server>) -> Result<()> { +pub(crate) async fn stop(services: Arc<Services>) -> Result<()> { debug!("Shutting down..."); // Wait for all completions before dropping or we'll lose them to the module // unload and explode. 
- service::stop().await; + services.stop().await; + + if let Err(services) = Arc::try_unwrap(services) { + debug_error!( + "{} dangling references to Services after shutdown", + Arc::strong_count(&services) + ); + } debug!("Cleaning up..."); diff --git a/src/router/serve/mod.rs b/src/router/serve/mod.rs index 4e9234443..58bf2de8e 100644 --- a/src/router/serve/mod.rs +++ b/src/router/serve/mod.rs @@ -5,22 +5,26 @@ use std::sync::Arc; use axum_server::Handle as ServerHandle; -use conduit::{Result, Server}; +use conduit::Result; +use conduit_service::Services; use tokio::sync::broadcast; -use crate::layers; +use super::layers; /// Serve clients -pub(super) async fn serve(server: Arc<Server>, handle: ServerHandle, shutdown: broadcast::Receiver<()>) -> Result<()> { +pub(super) async fn serve( + services: Arc<Services>, handle: ServerHandle, shutdown: broadcast::Receiver<()>, +) -> Result<()> { + let server = &services.server; let config = &server.config; let addrs = config.get_bind_addrs(); - let app = layers::build(&server)?; + let app = layers::build(&services)?; if cfg!(unix) && config.unix_socket_path.is_some() { - unix::serve(&server, app, shutdown).await + unix::serve(server, app, shutdown).await } else if config.tls.is_some() { - tls::serve(&server, app, handle, addrs).await + tls::serve(server, app, handle, addrs).await } else { - plain::serve(&server, app, handle, addrs).await + plain::serve(server, app, handle, addrs).await } } diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index 6241c6684..71f3b73e6 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -9,7 +9,7 @@ }; use async_trait::async_trait; -use conduit::{debug, error, error::default_log, pdu::PduBuilder, Error, PduEvent, Result, Server}; +use conduit::{debug, error, error::default_log, pdu::PduBuilder, Err, Error, PduEvent, Result, Server}; pub use create::create_admin_room; use loole::{Receiver, Sender}; use ruma::{ @@ -41,6 +41,7 @@ struct Services { timeline: 
Dep<rooms::timeline::Service>, state: Dep<rooms::state::Service>, state_cache: Dep<rooms::state_cache::Service>, + services: StdRwLock<Option<Arc<crate::Services>>>, } #[derive(Debug)] @@ -50,7 +51,7 @@ pub struct CommandInput { } pub type Completer = fn(&str) -> String; -pub type Handler = fn(CommandInput) -> HandlerResult; +pub type Handler = fn(Arc<crate::Services>, CommandInput) -> HandlerResult; pub type HandlerResult = Pin<Box<dyn Future<Output = CommandResult> + Send>>; pub type CommandResult = Result<CommandOutput, Error>; pub type CommandOutput = Option<RoomMessageEventContent>; @@ -69,6 +70,7 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"), state: args.depend::<rooms::state::Service>("rooms::state"), state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"), + services: None.into(), }, sender, receiver: Mutex::new(receiver), @@ -172,10 +174,14 @@ async fn handle_command(&self, command: CommandInput) { } async fn process_command(&self, command: CommandInput) -> CommandResult { + let Some(services) = self.services.services.read().expect("locked").clone() else { + return Err!("Services self-reference not initialized."); + }; + if let Some(handle) = self.handle.read().await.as_ref() { - handle(command).await + handle(services, command).await } else { - Err(Error::Err("Admin module is not loaded.".into())) + Err!("Admin module is not loaded.") } } @@ -356,4 +362,10 @@ async fn console_auto_stop(&self) { #[cfg(feature = "console")] self.console.close().await; } + + /// Sets the self-reference to crate::Services which will provide context to + /// the admin commands. 
+ pub(super) fn set_services(&self, services: Option<Arc<crate::Services>>) { + *self.services.services.write().expect("locked for writing") = services; + } } diff --git a/src/service/mod.rs b/src/service/mod.rs index b6ec58b5c..084bfac83 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -1,4 +1,4 @@ -#![recursion_limit = "160"] +#![recursion_limit = "192"] #![allow(refining_impl_trait)] mod manager; @@ -26,11 +26,7 @@ extern crate conduit_core as conduit; extern crate conduit_database as database; -use std::sync::{Arc, RwLock}; - pub use conduit::{pdu, PduBuilder, PduCount, PduEvent}; -use conduit::{Result, Server}; -use database::Database; pub(crate) use service::{Args, Dep, Service}; pub use crate::services::Services; @@ -38,50 +34,3 @@ conduit::mod_ctor! {} conduit::mod_dtor! {} conduit::rustc_flags_capture! {} - -static SERVICES: RwLock<Option<&Services>> = RwLock::new(None); - -pub async fn start(server: &Arc<Server>) -> Result<()> { - let d = Arc::new(Database::open(server).await?); - let s = Box::new(Services::build(server.clone(), d)?); - _ = SERVICES.write().expect("write locked").insert(Box::leak(s)); - - services().start().await -} - -pub async fn stop() { - services().stop().await; - - // Deactivate services(). Any further use will panic the caller. - let s = SERVICES - .write() - .expect("write locked") - .take() - .expect("services initialized"); - - let s: *mut Services = std::ptr::from_ref(s).cast_mut(); - //SAFETY: Services was instantiated in init() and leaked into the SERVICES - // global perusing as 'static for the duration of service. Now we reclaim - // it to drop it before unloading the module. If this is not done there wil - // be multiple instances after module reload. 
- let s = unsafe { Box::from_raw(s) }; - - // Drop it so we encounter any trouble before the infolog message - drop(s); -} - -#[must_use] -pub fn services() -> &'static Services { - SERVICES - .read() - .expect("SERVICES locked for reading") - .expect("SERVICES initialized with Services instance") -} - -#[inline] -pub fn available() -> bool { - SERVICES - .read() - .expect("SERVICES locked for reading") - .is_some() -} diff --git a/src/service/services.rs b/src/service/services.rs index 59909f8cb..b283db6cf 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -44,7 +44,8 @@ pub struct Services { impl Services { #[allow(clippy::cognitive_complexity)] - pub fn build(server: Arc<Server>, db: Arc<Database>) -> Result<Self> { + pub async fn build(server: Arc<Server>) -> Result<Arc<Self>> { + let db = Database::open(&server).await?; let service: Arc<Map> = Arc::new(RwLock::new(BTreeMap::new())); macro_rules! build { ($tyname:ty) => {{ @@ -58,7 +59,7 @@ macro_rules! build { }}; } - Ok(Self { + Ok(Arc::new(Self { account_data: build!(account_data::Service), admin: build!(admin::Service), appservice: build!(appservice::Service), @@ -102,12 +103,13 @@ macro_rules! 
build { service, server, db, - }) + })) } - pub(super) async fn start(&self) -> Result<()> { + pub async fn start(self: &Arc<Self>) -> Result<Arc<Self>> { debug_info!("Starting services..."); + self.admin.set_services(Some(Arc::clone(self))); globals::migrations::migrations(self).await?; self.manager .lock() @@ -118,10 +120,10 @@ pub(super) async fn start(&self) -> Result<()> { .await?; debug_info!("Services startup complete."); - Ok(()) + Ok(Arc::clone(self)) } - pub(super) async fn stop(&self) { + pub async fn stop(&self) { info!("Shutting down services..."); self.interrupt(); @@ -129,6 +131,8 @@ pub(super) async fn stop(&self) { manager.stop().await; } + self.admin.set_services(None); + debug_info!("Services shutdown complete."); } -- GitLab From 954cfc6bb7f87c381e56f028d49f503134c732de Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Sat, 27 Jul 2024 08:28:35 +0000 Subject: [PATCH 47/47] bump cargo Signed-off-by: Jason Volk <jason@zemos.net> --- Cargo.lock | 86 +++++++++++++++++++++++++++--------------------------- Cargo.toml | 4 +-- 2 files changed, 45 insertions(+), 45 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ea215eec2..83c443cc5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -43,9 +43,9 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.7" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b" +checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" [[package]] name = "anyhow" @@ -544,9 +544,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.9" +version = "4.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64acc1846d54c1fe936a78dc189c34e28d3f5afc348403f28ecf53660b9b8462" +checksum = "35723e6a11662c2afb578bcf0b88bf6ea8e21282a953428f240574fcc3a2b5b3" dependencies = [ "clap_builder", "clap_derive", @@ -554,9 +554,9 @@ dependencies = [ 
[[package]] name = "clap_builder" -version = "4.5.9" +version = "4.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fb8393d67ba2e7bfaf28a23458e4e2b543cc73a99595511eb207fdb8aede942" +checksum = "49eb96cbfa7cfa35017b7cd548c75b14c3118c98b423041d70562665e07fb0fa" dependencies = [ "anstyle", "clap_lex", @@ -564,9 +564,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.8" +version = "4.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2bac35c6dafb060fd4d275d9a4ffae97917c13a6327903a8be2153cd964f7085" +checksum = "5d029b67f89d30bbb547c89fd5161293c0aec155fc691d7924b64550662db93e" dependencies = [ "heck 0.5.0", "proc-macro2", @@ -576,9 +576,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b82cf0babdbd58558212896d1a4272303a57bdb245c2bf1147185fb45640e70" +checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" [[package]] name = "color_quant" @@ -729,12 +729,12 @@ dependencies = [ [[package]] name = "conduit_macros" -version = "0.4.5" +version = "0.4.6" dependencies = [ "itertools 0.13.0", "proc-macro2", "quote", - "syn 2.0.71", + "syn 2.0.72", ] [[package]] @@ -1028,7 +1028,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edb49164822f3ee45b17acd4a208cfc1251410cf0cad9a833234c9890774dd9f" dependencies = [ "quote", - "syn 2.0.71", + "syn 2.0.72", ] [[package]] @@ -1717,7 +1717,7 @@ dependencies = [ "http 1.1.0", "hyper 1.4.1", "hyper-util", - "rustls 0.23.11", + "rustls 0.23.12", "rustls-native-certs", "rustls-pki-types", "tokio", @@ -1896,9 +1896,9 @@ checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "jobserver" -version = "0.1.31" +version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"d2b099aaa34a9751c5bf0878add70444e1ed2dd73f347be99003d4577277de6e" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" dependencies = [ "libc", ] @@ -1967,9 +1967,9 @@ dependencies = [ [[package]] name = "lazy-regex" -version = "3.1.0" +version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d12be4595afdf58bd19e4a9f4e24187da2a66700786ff660a418e9059937a4c" +checksum = "576c8060ecfdf2e56995cf3274b4f2d71fa5e4fa3607c1c0b63c10180ee58741" dependencies = [ "lazy-regex-proc_macros", "once_cell", @@ -1978,9 +1978,9 @@ dependencies = [ [[package]] name = "lazy-regex-proc_macros" -version = "3.1.0" +version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44bcd58e6c97a7fcbaffcdc95728b393b8d98933bfadad49ed4097845b57ef0b" +checksum = "9efb9e65d4503df81c615dc33ff07042a9408ac7f26b45abee25566f7fbfd12c" dependencies = [ "proc-macro2", "quote", @@ -2329,9 +2329,9 @@ dependencies = [ [[package]] name = "object" -version = "0.36.1" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "081b846d1d56ddfc18fdf1a922e4f6e07a11768ea1b92dec44e42b72712ccfce" +checksum = "3f203fa8daa7bb185f760ae12bd8e097f63d17041dcdcaf675ac54cdf863170e" dependencies = [ "memchr", ] @@ -2765,7 +2765,7 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash", - "rustls 0.23.11", + "rustls 0.23.12", "thiserror", "tokio", "tracing", @@ -2781,7 +2781,7 @@ dependencies = [ "rand", "ring", "rustc-hash", - "rustls 0.23.11", + "rustls 0.23.12", "slab", "thiserror", "tinyvec", @@ -2920,7 +2920,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.11", + "rustls 0.23.12", "rustls-native-certs", "rustls-pemfile", "rustls-pki-types", @@ -3261,9 +3261,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.11" +version = "0.23.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"4828ea528154ae444e5a642dbb7d5623354030dc9822b83fd9bb79683c7399d0" +checksum = "c58f8c84392efc0a126acce10fa59ff7b3d2ac06ab451a33f2741989b806b044" dependencies = [ "log", "once_cell", @@ -3622,9 +3622,9 @@ dependencies = [ [[package]] name = "serde_spanned" -version = "0.6.6" +version = "0.6.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79e674e01f999af37c49f70a6ede167a8a60b2503e56c5599532a65baa5969a0" +checksum = "eb5b1b31579f3811bf615c144393417496f152e12ac8b7663bf664f4a815306d" dependencies = [ "serde", ] @@ -4120,7 +4120,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.23.11", + "rustls 0.23.12", "rustls-pki-types", "tokio", ] @@ -4163,21 +4163,21 @@ dependencies = [ [[package]] name = "toml" -version = "0.8.15" +version = "0.8.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac2caab0bf757388c6c0ae23b3293fdb463fee59434529014f85e3263b995c28" +checksum = "81967dd0dd2c1ab0bc3468bd7caecc32b8a4aa47d0c8c695d8c2b2108168d62c" dependencies = [ "serde", "serde_spanned", "toml_datetime", - "toml_edit 0.22.16", + "toml_edit 0.22.17", ] [[package]] name = "toml_datetime" -version = "0.6.6" +version = "0.6.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4badfd56924ae69bcc9039335b2e017639ce3f9b001c393c1b2d1ef846ce2cbf" +checksum = "f8fb9f64314842840f1d940ac544da178732128f1c78c21772e876579e0da1db" dependencies = [ "serde", ] @@ -4195,15 +4195,15 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.22.16" +version = "0.22.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "278f3d518e152219c994ce877758516bca5e118eaed6996192a774fb9fbf0788" +checksum = "8d9f8729f5aea9562aac1cc0441f5d6de3cff1ee0c5d67293eeca5eb36ee7c16" dependencies = [ "indexmap 2.2.6", "serde", "serde_spanned", "toml_datetime", - 
"winnow 0.6.15", + "winnow 0.6.16", ] [[package]] @@ -4482,7 +4482,7 @@ dependencies = [ "base64 0.22.1", "log", "once_cell", - "rustls 0.23.11", + "rustls 0.23.12", "rustls-pki-types", "url", "webpki-roots", @@ -4536,9 +4536,9 @@ checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" [[package]] name = "version_check" -version = "0.9.4" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" [[package]] name = "want" @@ -4881,9 +4881,9 @@ dependencies = [ [[package]] name = "winnow" -version = "0.6.15" +version = "0.6.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "557404e450152cd6795bb558bca69e43c585055f4606e3bcae5894fc6dac9ba0" +checksum = "b480ae9340fc261e6be3e95a1ba86d54ae3f9171132a73ce8d4bbaf68339507c" dependencies = [ "memchr", ] diff --git a/Cargo.toml b/Cargo.toml index 17e7e7126..48e6ac1af 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -426,11 +426,11 @@ default-features = false version = "0.1" [workspace.dependencies.syn] -version = "2.0" +version = "2.0.72" features = ["full", "extra-traits"] [workspace.dependencies.quote] -version = "1.0" +version = "1.0.36" [workspace.dependencies.proc-macro2] version = "1.0.86" -- GitLab