From 6e50b07bf56315b002d6ef2f98221ebbc87c171b Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Sun, 9 Jun 2024 03:57:18 +0000 Subject: [PATCH 01/13] Fix large future Signed-off-by: Jason Volk <jason@zemos.net> --- src/admin/handler.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/admin/handler.rs b/src/admin/handler.rs index 1a2ff38e6..b09bdb8df 100644 --- a/src/admin/handler.rs +++ b/src/admin/handler.rs @@ -81,7 +81,8 @@ async fn handle_event(event: AdminRoomEvent, admin_room: OwnedRoomId, server_use let (mut message_content, reply) = match event { AdminRoomEvent::SendMessage(content) => (content, None), AdminRoomEvent::ProcessMessage(room_message, reply_id) => { - (process_admin_message(room_message).await, Some(reply_id)) + // This future is ~8 KiB so it's better to start it off the stack. + (Box::pin(process_admin_message(room_message)).await, Some(reply_id)) }, }; -- GitLab From fc1b8326e63fae30bc5d8b02590cfbba4f8a320b Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Sun, 9 Jun 2024 04:22:47 +0000 Subject: [PATCH 02/13] split join_room_by_id_helper into local and remote prior stack frame allocated 180 KiB Signed-off-by: Jason Volk <jason@zemos.net> --- clippy.toml | 2 +- src/api/client/membership.rs | 940 ++++++++++++++++++----------------- 2 files changed, 475 insertions(+), 467 deletions(-) diff --git a/clippy.toml b/clippy.toml index 0a04ecd29..1743ee5c5 100644 --- a/clippy.toml +++ b/clippy.toml @@ -2,6 +2,6 @@ array-size-threshold = 4096 cognitive-complexity-threshold = 94 # TODO reduce me ALARA excessive-nesting-threshold = 11 # TODO reduce me to 4 or 5 future-size-threshold = 7745 # TODO reduce me ALARA -stack-size-threshold = 178030 # reduce me ALARA +stack-size-threshold = 173577 # reduce me ALARA too-many-lines-threshold = 700 # TODO reduce me to <= 100 type-complexity-threshold = 250 # reduce me to ~200 diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 3b8920f8a..b50ef2cbd 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -30,7 +30,7 @@ OwnedUserId, RoomId, RoomVersionId, ServerName, UserId, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; -use tokio::sync::RwLock; +use tokio::sync::{MutexGuard, RwLock}; use tracing::{debug, error, info, trace, warn}; use super::get_alias_helper; @@ -627,7 +627,7 @@ pub(crate) async fn joined_members_route( pub async fn join_room_by_id_helper( sender_user: Option<&UserId>, room_id: &RoomId, reason: Option<String>, servers: &[OwnedServerName], - _third_party_signed: Option<&ThirdPartySigned>, + third_party_signed: Option<&ThirdPartySigned>, ) -> Result<join_room_by_id::v3::Response> { let sender_user = sender_user.expect("user is authenticated"); @@ -655,27 +655,468 @@ pub async fn join_room_by_id_helper( .state_cache .server_in_room(services().globals.server_name(), room_id)? 
{ - info!("Joining {room_id} over federation."); + join_room_by_id_helper_remote(sender_user, room_id, reason, servers, third_party_signed, state_lock).await + } else { + join_room_by_id_helper_local(sender_user, room_id, reason, servers, third_party_signed, state_lock).await + } +} - let (make_join_response, remote_server) = make_join_request(sender_user, room_id, servers).await?; +async fn join_room_by_id_helper_remote( + sender_user: &UserId, room_id: &RoomId, reason: Option<String>, servers: &[OwnedServerName], + _third_party_signed: Option<&ThirdPartySigned>, state_lock: MutexGuard<'_, ()>, +) -> Result<join_room_by_id::v3::Response> { + info!("Joining {room_id} over federation."); + + let (make_join_response, remote_server) = make_join_request(sender_user, room_id, servers).await?; + + info!("make_join finished"); + + let room_version_id = match make_join_response.room_version { + Some(room_version) + if services() + .globals + .supported_room_versions() + .contains(&room_version) => + { + room_version + }, + _ => return Err(Error::BadServerResponse("Room version is not supported")), + }; + + let mut join_event_stub: CanonicalJsonObject = serde_json::from_str(make_join_response.event.get()) + .map_err(|_| Error::BadServerResponse("Invalid make_join event json received from server."))?; + + let join_authorized_via_users_server = join_event_stub + .get("content") + .map(|s| { + s.as_object()? + .get("join_authorised_via_users_server")? + .as_str() + }) + .and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok()); + + // TODO: Is origin needed? + join_event_stub.insert( + "origin".to_owned(), + CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), + ); + join_event_stub.insert( + "origin_server_ts".to_owned(), + CanonicalJsonValue::Integer( + utils::millis_since_unix_epoch() + .try_into() + .expect("Timestamp is valid js_int value"), + ), + ); + join_event_stub.insert( + "content".to_owned(), + to_canonical_value(RoomMemberEventContent { + membership: MembershipState::Join, + displayname: services().users.displayname(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, + is_direct: None, + third_party_invite: None, + blurhash: services().users.blurhash(sender_user)?, + reason, + join_authorized_via_users_server: join_authorized_via_users_server.clone(), + }) + .expect("event is valid, we just created it"), + ); + + // We keep the "event_id" in the pdu only in v1 or + // v2 rooms + match room_version_id { + RoomVersionId::V1 | RoomVersionId::V2 => {}, + _ => { + join_event_stub.remove("event_id"); + }, + }; + + // In order to create a compatible ref hash (EventID) the `hashes` field needs + // to be present + ruma::signatures::hash_and_sign_event( + services().globals.server_name().as_str(), + services().globals.keypair(), + &mut join_event_stub, + &room_version_id, + ) + .expect("event is valid, we just created it"); + + // Generate event id + let event_id = format!( + "${}", + ruma::signatures::reference_hash(&join_event_stub, &room_version_id) + .expect("ruma can calculate reference hashes") + ); + let event_id = <&EventId>::try_from(event_id.as_str()).expect("ruma's reference hashes are valid event ids"); + + // Add event_id back + join_event_stub.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); + + // It has enough fields to be called a proper event now + let mut join_event = join_event_stub; + + info!("Asking {remote_server} for send_join in room {room_id}"); + let send_join_response = 
services() + .sending + .send_federation_request( + &remote_server, + federation::membership::create_join_event::v2::Request { + room_id: room_id.to_owned(), + event_id: event_id.to_owned(), + pdu: PduEvent::convert_to_outgoing_federation_event(join_event.clone()), + omit_members: false, + }, + ) + .await?; + + info!("send_join finished"); + + if join_authorized_via_users_server.is_some() { + match &room_version_id { + RoomVersionId::V1 + | RoomVersionId::V2 + | RoomVersionId::V3 + | RoomVersionId::V4 + | RoomVersionId::V5 + | RoomVersionId::V6 + | RoomVersionId::V7 => { + warn!( + "Found `join_authorised_via_users_server` but room {} is version {}. Ignoring.", + room_id, &room_version_id + ); + }, + // only room versions 8 and above using `join_authorized_via_users_server` (restricted joins) need to + // validate and send signatures + RoomVersionId::V8 | RoomVersionId::V9 | RoomVersionId::V10 | RoomVersionId::V11 => { + if let Some(signed_raw) = &send_join_response.room_state.event { + info!( + "There is a signed event. This room is probably using restricted joins. Adding signature to \ + our event" + ); + let Ok((signed_event_id, signed_value)) = gen_event_id_canonical_json(signed_raw, &room_version_id) + else { + // Event could not be converted to canonical json + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Could not convert event to canonical json.", + )); + }; + + if signed_event_id != event_id { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Server sent event with wrong event id", + )); + } + + match signed_value["signatures"] + .as_object() + .ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Server sent invalid signatures type", + )) + .and_then(|e| { + e.get(remote_server.as_str()) + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Server did not send its signature")) + }) { + Ok(signature) => { + join_event + .get_mut("signatures") + .expect("we created a valid pdu") + .as_object_mut() + .expect("we created a valid pdu") + .insert(remote_server.to_string(), signature.clone()); + }, + Err(e) => { + warn!( + "Server {remote_server} sent invalid signature in sendjoin signatures for event \ + {signed_value:?}: {e:?}", + ); + }, + } + } + }, + _ => { + warn!( + "Unexpected or unsupported room version {} for room {}", + &room_version_id, room_id + ); + return Err(Error::BadRequest( + ErrorKind::BadJson, + "Unexpected or unsupported room version found", + )); + }, + } + } + + services().rooms.short.get_or_create_shortroomid(room_id)?; + + info!("Parsing join event"); + let parsed_join_pdu = PduEvent::from_id_val(event_id, join_event.clone()) + .map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?; + + let mut state = HashMap::new(); + let pub_key_map = RwLock::new(BTreeMap::new()); + + info!("Fetching join signing keys"); + services() + .rooms + .event_handler + .fetch_join_signing_keys(&send_join_response, &room_version_id, &pub_key_map) + .await?; + + info!("Going through send_join response room_state"); + for result in send_join_response + .room_state + .state + .iter() + .map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map)) + { + let Ok((event_id, value)) = result.await else { + continue; + }; + + let pdu = PduEvent::from_id_val(&event_id, value.clone()).map_err(|e| { + warn!("Invalid PDU in send_join response: {} {:?}", e, value); + Error::BadServerResponse("Invalid PDU in send_join response.") + })?; + + services() + .rooms + .outlier + .add_pdu_outlier(&event_id, &value)?; + if let Some(state_key) = 
&pdu.state_key {
+			let shortstatekey = services()
+				.rooms
+				.short
+				.get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?;
+			state.insert(shortstatekey, pdu.event_id.clone());
+		}
+	}
+
+	info!("Going through send_join response auth_chain");
+	for result in send_join_response
+		.room_state
+		.auth_chain
+		.iter()
+		.map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map))
+	{
+		let Ok((event_id, value)) = result.await else {
+			continue;
+		};
+
+		services()
+			.rooms
+			.outlier
+			.add_pdu_outlier(&event_id, &value)?;
+	}
+
+	debug!("Running send_join auth check");
+
+	let auth_check = state_res::event_auth::auth_check(
+		&state_res::RoomVersion::new(&room_version_id).expect("room version is supported"),
+		&parsed_join_pdu,
+		None::<PduEvent>, // TODO: third party invite
+		|k, s| {
+			services()
+				.rooms
+				.timeline
+				.get_pdu(
+					state.get(
+						&services()
+							.rooms
+							.short
+							.get_or_create_shortstatekey(&k.to_string().into(), s)
+							.ok()?,
+					)?,
+				)
+				.ok()?
+		},
+	)
+	.map_err(|e| {
+		warn!("Auth check failed: {e}");
+		Error::BadRequest(ErrorKind::forbidden(), "Auth check failed")
+	})?;
+
+	if !auth_check {
+		return Err(Error::BadRequest(ErrorKind::forbidden(), "Auth check failed"));
+	}
+
+	info!("Saving state from send_join");
+	let (statehash_before_join, new, removed) = services().rooms.state_compressor.save_state(
+		room_id,
+		Arc::new(
+			state
+				.into_iter()
+				.map(|(k, id)| {
+					services()
+						.rooms
+						.state_compressor
+						.compress_state_event(k, &id)
+				})
+				.collect::<Result<_>>()?,
+		),
+	)?;
+
+	services()
+		.rooms
+		.state
+		.force_state(room_id, statehash_before_join, new, removed, &state_lock)
+		.await?;
+
+	info!("Updating joined counts for new room");
+	services().rooms.state_cache.update_joined_count(room_id)?;
+
+	// We append to state before appending the pdu, so we don't have a moment in
+	// time with the pdu without its state. This is okay because append_pdu can't
+	// fail.
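+	// All of the steps below run under the caller's `state_lock` guard, so no
+	// other task can interleave a competing state change for this room between
+	// the state snapshot, the PDU write, and the final state publication.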
+ let statehash_after_join = services().rooms.state.append_to_state(&parsed_join_pdu)?; + + info!("Appending new room join event"); + services() + .rooms + .timeline + .append_pdu( + &parsed_join_pdu, + join_event, + vec![(*parsed_join_pdu.event_id).to_owned()], + &state_lock, + ) + .await?; + + info!("Setting final room state for new room"); + // We set the room state after inserting the pdu, so that we never have a moment + // in time where events in the current room state do not exist + services() + .rooms + .state + .set_room_state(room_id, statehash_after_join, &state_lock)?; + + Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) +} + +async fn join_room_by_id_helper_local( + sender_user: &UserId, room_id: &RoomId, reason: Option<String>, servers: &[OwnedServerName], + _third_party_signed: Option<&ThirdPartySigned>, state_lock: MutexGuard<'_, ()>, +) -> Result<join_room_by_id::v3::Response> { + info!("We can join locally"); + + let join_rules_event = + services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; + + let join_rules_event_content: Option<RoomJoinRulesEventContent> = join_rules_event + .as_ref() + .map(|join_rules_event| { + serde_json::from_str(join_rules_event.content.get()).map_err(|e| { + warn!("Invalid join rules event: {}", e); + Error::bad_database("Invalid join rules event in db.") + }) + }) + .transpose()?; + + let restriction_rooms = match join_rules_event_content { + Some(RoomJoinRulesEventContent { + join_rule: JoinRule::Restricted(restricted) | JoinRule::KnockRestricted(restricted), + }) => restricted + .allow + .into_iter() + .filter_map(|a| match a { + AllowRule::RoomMembership(r) => Some(r.room_id), + _ => None, + }) + .collect(), + _ => Vec::new(), + }; + + let local_members = services() + .rooms + .state_cache + .room_members(room_id) + .filter_map(Result::ok) + .filter(|user| user_is_local(user)) + .collect::<Vec<OwnedUserId>>(); + + let mut authorized_user: Option<OwnedUserId> = None; + + if restriction_rooms.iter().any(|restriction_room_id| { + services() + .rooms + .state_cache + .is_joined(sender_user, restriction_room_id) + .unwrap_or(false) + }) { + for user in local_members { + if services() + .rooms + .state_accessor + .user_can_invite(room_id, &user, sender_user, &state_lock) + .await + .unwrap_or(false) + { + authorized_user = Some(user); + break; + } + } + } - info!("make_join finished"); + let event = RoomMemberEventContent { + membership: MembershipState::Join, + displayname: services().users.displayname(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, + is_direct: None, + third_party_invite: None, + blurhash: services().users.blurhash(sender_user)?, + reason: reason.clone(), + join_authorized_via_users_server: authorized_user, + }; + + // Try normal join first + let error = match services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&event).expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(sender_user.to_string()), + redacts: None, + }, + sender_user, + room_id, + &state_lock, + ) + .await + { + Ok(_event_id) => return Ok(join_room_by_id::v3::Response::new(room_id.to_owned())), + Err(e) => e, + }; + + if !restriction_rooms.is_empty() + && servers + .iter() + .any(|server_name| !server_is_ours(server_name)) + { + info!("We couldn't do the join locally, maybe federation can help to satisfy the restricted join requirements"); + let (make_join_response, 
remote_server) = make_join_request(sender_user, room_id, servers).await?; let room_version_id = match make_join_response.room_version { - Some(room_version) + Some(room_version_id) if services() .globals .supported_room_versions() - .contains(&room_version) => + .contains(&room_version_id) => { - room_version + room_version_id }, _ => return Err(Error::BadServerResponse("Room version is not supported")), }; - let mut join_event_stub: CanonicalJsonObject = serde_json::from_str(make_join_response.event.get()) .map_err(|_| Error::BadServerResponse("Invalid make_join event json received from server."))?; - let join_authorized_via_users_server = join_event_stub .get("content") .map(|s| { @@ -684,7 +1125,6 @@ pub async fn join_room_by_id_helper( .as_str() }) .and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok()); - // TODO: Is origin needed? join_event_stub.insert( "origin".to_owned(), @@ -708,7 +1148,7 @@ pub async fn join_room_by_id_helper( third_party_invite: None, blurhash: services().users.blurhash(sender_user)?, reason, - join_authorized_via_users_server: join_authorized_via_users_server.clone(), + join_authorized_via_users_server, }) .expect("event is valid, we just created it"), ); @@ -744,9 +1184,8 @@ pub async fn join_room_by_id_helper( join_event_stub.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); // It has enough fields to be called a proper event now - let mut join_event = join_event_stub; + let join_event = join_event_stub; - info!("Asking {remote_server} for send_join in room {room_id}"); let send_join_response = services() .sending .send_federation_request( @@ -760,470 +1199,39 @@ pub async fn join_room_by_id_helper( ) .await?; - info!("send_join finished"); - - if join_authorized_via_users_server.is_some() { - match &room_version_id { - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 => { - warn!( - "Found `join_authorised_via_users_server` but room {} is version {}. Ignoring.", - room_id, &room_version_id - ); - }, - // only room versions 8 and above using `join_authorized_via_users_server` (restricted joins) need to - // validate and send signatures - RoomVersionId::V8 | RoomVersionId::V9 | RoomVersionId::V10 | RoomVersionId::V11 => { - if let Some(signed_raw) = &send_join_response.room_state.event { - info!( - "There is a signed event. This room is probably using restricted joins. 
Adding signature \ - to our event" - ); - let Ok((signed_event_id, signed_value)) = - gen_event_id_canonical_json(signed_raw, &room_version_id) - else { - // Event could not be converted to canonical json - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Could not convert event to canonical json.", - )); - }; - - if signed_event_id != event_id { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Server sent event with wrong event id", - )); - } - - match signed_value["signatures"] - .as_object() - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Server sent invalid signatures type", - )) - .and_then(|e| { - e.get(remote_server.as_str()).ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Server did not send its signature", - )) - }) { - Ok(signature) => { - join_event - .get_mut("signatures") - .expect("we created a valid pdu") - .as_object_mut() - .expect("we created a valid pdu") - .insert(remote_server.to_string(), signature.clone()); - }, - Err(e) => { - warn!( - "Server {remote_server} sent invalid signature in sendjoin signatures for event \ - {signed_value:?}: {e:?}", - ); - }, - } - } - }, - _ => { - warn!( - "Unexpected or unsupported room version {} for room {}", - &room_version_id, room_id - ); - return Err(Error::BadRequest( - ErrorKind::BadJson, - "Unexpected or unsupported room version found", - )); - }, - } - } - - services().rooms.short.get_or_create_shortroomid(room_id)?; - - info!("Parsing join event"); - let parsed_join_pdu = PduEvent::from_id_val(event_id, join_event.clone()) - .map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?; - - let mut state = HashMap::new(); - let pub_key_map = RwLock::new(BTreeMap::new()); - - info!("Fetching join signing keys"); - services() - .rooms - .event_handler - .fetch_join_signing_keys(&send_join_response, &room_version_id, &pub_key_map) - .await?; - - info!("Going through send_join response room_state"); - for result in send_join_response - .room_state - .state - .iter() - .map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map)) - { - let Ok((event_id, value)) = result.await else { - continue; + if let Some(signed_raw) = send_join_response.room_state.event { + let Ok((signed_event_id, signed_value)) = gen_event_id_canonical_json(&signed_raw, &room_version_id) else { + // Event could not be converted to canonical json + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Could not convert event to canonical json.", + )); }; - let pdu = PduEvent::from_id_val(&event_id, value.clone()).map_err(|e| { - warn!("Invalid PDU in send_join response: {} {:?}", e, value); - Error::BadServerResponse("Invalid PDU in send_join response.") - })?; - - services() - .rooms - .outlier - .add_pdu_outlier(&event_id, &value)?; - if let Some(state_key) = &pdu.state_key { - let shortstatekey = services() - .rooms - .short - .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?; - state.insert(shortstatekey, pdu.event_id.clone()); + if signed_event_id != event_id { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Server sent event with wrong event id", + )); } - } - - info!("Going through send_join response auth_chain"); - for result in send_join_response - .room_state - .auth_chain - .iter() - .map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map)) - { - let Ok((event_id, value)) = result.await else { - continue; - }; - - services() - .rooms - .outlier - .add_pdu_outlier(&event_id, &value)?; - } - - debug!("Running send_join auth check"); - 
- let auth_check = state_res::event_auth::auth_check( - &state_res::RoomVersion::new(&room_version_id).expect("room version is supported"), - &parsed_join_pdu, - None::<PduEvent>, // TODO: third party invite - |k, s| { - services() - .rooms - .timeline - .get_pdu( - state.get( - &services() - .rooms - .short - .get_or_create_shortstatekey(&k.to_string().into(), s) - .ok()?, - )?, - ) - .ok()? - }, - ) - .map_err(|e| { - warn!("Auth check failed: {e}"); - Error::BadRequest(ErrorKind::forbidden(), "Auth check failed") - })?; - - if !auth_check { - return Err(Error::BadRequest(ErrorKind::forbidden(), "Auth check failed")); - } - - info!("Saving state from send_join"); - let (statehash_before_join, new, removed) = services().rooms.state_compressor.save_state( - room_id, - Arc::new( - state - .into_iter() - .map(|(k, id)| { - services() - .rooms - .state_compressor - .compress_state_event(k, &id) - }) - .collect::<Result<_>>()?, - ), - )?; - - services() - .rooms - .state - .force_state(room_id, statehash_before_join, new, removed, &state_lock) - .await?; - - info!("Updating joined counts for new room"); - services().rooms.state_cache.update_joined_count(room_id)?; - // We append to state before appending the pdu, so we don't have a moment in - // time with the pdu without it's state. This is okay because append_pdu can't - // fail. - let statehash_after_join = services().rooms.state.append_to_state(&parsed_join_pdu)?; - - info!("Appending new room join event"); - services() - .rooms - .timeline - .append_pdu( - &parsed_join_pdu, - join_event, - vec![(*parsed_join_pdu.event_id).to_owned()], - &state_lock, - ) - .await?; - - info!("Setting final room state for new room"); - // We set the room state after inserting the pdu, so that we never have a moment - // in time where events in the current room state do not exist - services() - .rooms - .state - .set_room_state(room_id, statehash_after_join, &state_lock)?; - } else { - info!("We can join locally"); - - let join_rules_event = + drop(state_lock); + let pub_key_map = RwLock::new(BTreeMap::new()); services() .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; - - let join_rules_event_content: Option<RoomJoinRulesEventContent> = join_rules_event - .as_ref() - .map(|join_rules_event| { - serde_json::from_str(join_rules_event.content.get()).map_err(|e| { - warn!("Invalid join rules event: {}", e); - Error::bad_database("Invalid join rules event in db.") - }) - }) - .transpose()?; - - let restriction_rooms = match join_rules_event_content { - Some(RoomJoinRulesEventContent { - join_rule: JoinRule::Restricted(restricted) | JoinRule::KnockRestricted(restricted), - }) => restricted - .allow - .into_iter() - .filter_map(|a| match a { - AllowRule::RoomMembership(r) => Some(r.room_id), - _ => None, - }) - .collect(), - _ => Vec::new(), - }; - - let local_members = services() - .rooms - .state_cache - .room_members(room_id) - .filter_map(Result::ok) - .filter(|user| user_is_local(user)) - .collect::<Vec<OwnedUserId>>(); - - let mut authorized_user: Option<OwnedUserId> = None; - - if restriction_rooms.iter().any(|restriction_room_id| { + .event_handler + .fetch_required_signing_keys([&signed_value], &pub_key_map) + .await?; services() .rooms - .state_cache - .is_joined(sender_user, restriction_room_id) - .unwrap_or(false) - }) { - for user in local_members { - if services() - .rooms - .state_accessor - .user_can_invite(room_id, &user, sender_user, &state_lock) - .await - .unwrap_or(false) - { - authorized_user = 
Some(user); - break; - } - } - } - - let event = RoomMemberEventContent { - membership: MembershipState::Join, - displayname: services().users.displayname(sender_user)?, - avatar_url: services().users.avatar_url(sender_user)?, - is_direct: None, - third_party_invite: None, - blurhash: services().users.blurhash(sender_user)?, - reason: reason.clone(), - join_authorized_via_users_server: authorized_user, - }; - - // Try normal join first - let error = match services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&event).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(sender_user.to_string()), - redacts: None, - }, - sender_user, - room_id, - &state_lock, - ) - .await - { - Ok(_event_id) => return Ok(join_room_by_id::v3::Response::new(room_id.to_owned())), - Err(e) => e, - }; - - if !restriction_rooms.is_empty() - && servers - .iter() - .any(|server_name| !server_is_ours(server_name)) - { - info!( - "We couldn't do the join locally, maybe federation can help to satisfy the restricted join \ - requirements" - ); - let (make_join_response, remote_server) = make_join_request(sender_user, room_id, servers).await?; - - let room_version_id = match make_join_response.room_version { - Some(room_version_id) - if services() - .globals - .supported_room_versions() - .contains(&room_version_id) => - { - room_version_id - }, - _ => return Err(Error::BadServerResponse("Room version is not supported")), - }; - let mut join_event_stub: CanonicalJsonObject = serde_json::from_str(make_join_response.event.get()) - .map_err(|_| Error::BadServerResponse("Invalid make_join event json received from server."))?; - let join_authorized_via_users_server = join_event_stub - .get("content") - .map(|s| { - s.as_object()? - .get("join_authorised_via_users_server")? - .as_str() - }) - .and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok()); - // TODO: Is origin needed? 
- join_event_stub.insert( - "origin".to_owned(), - CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), - ); - join_event_stub.insert( - "origin_server_ts".to_owned(), - CanonicalJsonValue::Integer( - utils::millis_since_unix_epoch() - .try_into() - .expect("Timestamp is valid js_int value"), - ), - ); - join_event_stub.insert( - "content".to_owned(), - to_canonical_value(RoomMemberEventContent { - membership: MembershipState::Join, - displayname: services().users.displayname(sender_user)?, - avatar_url: services().users.avatar_url(sender_user)?, - is_direct: None, - third_party_invite: None, - blurhash: services().users.blurhash(sender_user)?, - reason, - join_authorized_via_users_server, - }) - .expect("event is valid, we just created it"), - ); - - // We keep the "event_id" in the pdu only in v1 or - // v2 rooms - match room_version_id { - RoomVersionId::V1 | RoomVersionId::V2 => {}, - _ => { - join_event_stub.remove("event_id"); - }, - }; - - // In order to create a compatible ref hash (EventID) the `hashes` field needs - // to be present - ruma::signatures::hash_and_sign_event( - services().globals.server_name().as_str(), - services().globals.keypair(), - &mut join_event_stub, - &room_version_id, - ) - .expect("event is valid, we just created it"); - - // Generate event id - let event_id = format!( - "${}", - ruma::signatures::reference_hash(&join_event_stub, &room_version_id) - .expect("ruma can calculate reference hashes") - ); - let event_id = - <&EventId>::try_from(event_id.as_str()).expect("ruma's reference hashes are valid event ids"); - - // Add event_id back - join_event_stub.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); - - // It has enough fields to be called a proper event now - let join_event = join_event_stub; - - let send_join_response = services() - .sending - .send_federation_request( - &remote_server, - federation::membership::create_join_event::v2::Request { - room_id: room_id.to_owned(), - event_id: event_id.to_owned(), - pdu: PduEvent::convert_to_outgoing_federation_event(join_event.clone()), - omit_members: false, - }, - ) + .event_handler + .handle_incoming_pdu(&remote_server, room_id, &signed_event_id, signed_value, true, &pub_key_map) .await?; - - if let Some(signed_raw) = send_join_response.room_state.event { - let Ok((signed_event_id, signed_value)) = gen_event_id_canonical_json(&signed_raw, &room_version_id) - else { - // Event could not be converted to canonical json - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Could not convert event to canonical json.", - )); - }; - - if signed_event_id != event_id { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Server sent event with wrong event id", - )); - } - - drop(state_lock); - let pub_key_map = RwLock::new(BTreeMap::new()); - services() - .rooms - .event_handler - .fetch_required_signing_keys([&signed_value], &pub_key_map) - .await?; - services() - .rooms - .event_handler - .handle_incoming_pdu(&remote_server, room_id, &signed_event_id, signed_value, true, &pub_key_map) - .await?; - } else { - return Err(error); - } } else { return Err(error); } + } else { + return Err(error); } Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) -- GitLab From db2c9f28b60977addb9b4605bd1229b7e5c56c16 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Sun, 9 Jun 2024 04:40:19 +0000 Subject: [PATCH 03/13] split admin room moderation commands prior stack frame allocated 170 KiB Signed-off-by: Jason Volk 
<jason@zemos.net> --- clippy.toml | 2 +- src/admin/room/room_moderation_commands.rs | 818 +++++++++++---------- 2 files changed, 411 insertions(+), 409 deletions(-) diff --git a/clippy.toml b/clippy.toml index 1743ee5c5..c942b93c7 100644 --- a/clippy.toml +++ b/clippy.toml @@ -2,6 +2,6 @@ array-size-threshold = 4096 cognitive-complexity-threshold = 94 # TODO reduce me ALARA excessive-nesting-threshold = 11 # TODO reduce me to 4 or 5 future-size-threshold = 7745 # TODO reduce me ALARA -stack-size-threshold = 173577 # reduce me ALARA +stack-size-threshold = 144000 # reduce me ALARA too-many-lines-threshold = 700 # TODO reduce me to <= 100 type-complexity-threshold = 250 # reduce me to ~200 diff --git a/src/admin/room/room_moderation_commands.rs b/src/admin/room/room_moderation_commands.rs index f5176147b..6b759c9bb 100644 --- a/src/admin/room/room_moderation_commands.rs +++ b/src/admin/room/room_moderation_commands.rs @@ -18,83 +18,296 @@ pub(crate) async fn process(command: RoomModerationCommand, body: Vec<&str>) -> force, room, disable_federation, - } => { - debug!("Got room alias or ID: {}", room); + } => ban_room(body, force, room, disable_federation).await, + RoomModerationCommand::BanListOfRooms { + force, + disable_federation, + } => ban_list_of_rooms(body, force, disable_federation).await, + RoomModerationCommand::UnbanRoom { + room, + enable_federation, + } => unban_room(body, room, enable_federation).await, + RoomModerationCommand::ListBannedRooms => list_banned_rooms(body).await, + } +} - let admin_room_alias: Box<RoomAliasId> = format!("#admins:{}", services().globals.server_name()) - .try_into() - .expect("#admins:server_name is a valid alias name"); +async fn ban_room( + _body: Vec<&str>, force: bool, room: Box<RoomOrAliasId>, disable_federation: bool, +) -> Result<RoomMessageEventContent> { + debug!("Got room alias or ID: {}", room); - if let Some(admin_room_id) = Service::get_admin_room().await? { - if room.to_string().eq(&admin_room_id) || room.to_string().eq(&admin_room_alias) { - return Ok(RoomMessageEventContent::text_plain("Not allowed to ban the admin room.")); - } - } + let admin_room_alias: Box<RoomAliasId> = format!("#admins:{}", services().globals.server_name()) + .try_into() + .expect("#admins:server_name is a valid alias name"); - let room_id = if room.is_room_id() { - let room_id = match RoomId::parse(&room) { - Ok(room_id) => room_id, - Err(e) => { - return Ok(RoomMessageEventContent::text_plain(format!( - "Failed to parse room ID {room}. Please note that this requires a full room ID \ - (`!awIh6gGInaS5wLQJwa:example.com`) or a room alias (`#roomalias:example.com`): {e}" - ))) - }, - }; + if let Some(admin_room_id) = Service::get_admin_room().await? { + if room.to_string().eq(&admin_room_id) || room.to_string().eq(&admin_room_alias) { + return Ok(RoomMessageEventContent::text_plain("Not allowed to ban the admin room.")); + } + } - debug!("Room specified is a room ID, banning room ID"); + let room_id = if room.is_room_id() { + let room_id = match RoomId::parse(&room) { + Ok(room_id) => room_id, + Err(e) => { + return Ok(RoomMessageEventContent::text_plain(format!( + "Failed to parse room ID {room}. 
Please note that this requires a full room ID \ + (`!awIh6gGInaS5wLQJwa:example.com`) or a room alias (`#roomalias:example.com`): {e}" + ))) + }, + }; - services().rooms.metadata.ban_room(&room_id, true)?; + debug!("Room specified is a room ID, banning room ID"); - room_id - } else if room.is_room_alias_id() { - let room_alias = match RoomAliasId::parse(&room) { - Ok(room_alias) => room_alias, - Err(e) => { - return Ok(RoomMessageEventContent::text_plain(format!( - "Failed to parse room ID {room}. Please note that this requires a full room ID \ - (`!awIh6gGInaS5wLQJwa:example.com`) or a room alias (`#roomalias:example.com`): {e}" - ))) - }, - }; - - debug!( - "Room specified is not a room ID, attempting to resolve room alias to a room ID locally, if not \ - using get_alias_helper to fetch room ID remotely" + services().rooms.metadata.ban_room(&room_id, true)?; + + room_id + } else if room.is_room_alias_id() { + let room_alias = match RoomAliasId::parse(&room) { + Ok(room_alias) => room_alias, + Err(e) => { + return Ok(RoomMessageEventContent::text_plain(format!( + "Failed to parse room ID {room}. Please note that this requires a full room ID \ + (`!awIh6gGInaS5wLQJwa:example.com`) or a room alias (`#roomalias:example.com`): {e}" + ))) + }, + }; + + debug!( + "Room specified is not a room ID, attempting to resolve room alias to a room ID locally, if not using \ + get_alias_helper to fetch room ID remotely" + ); + + let room_id = if let Some(room_id) = services().rooms.alias.resolve_local_alias(&room_alias)? { + room_id + } else { + debug!("We don't have this room alias to a room ID locally, attempting to fetch room ID over federation"); + + match get_alias_helper(room_alias, None).await { + Ok(response) => { + debug!("Got federation response fetching room ID for room {room}: {:?}", response); + response.room_id + }, + Err(e) => { + return Ok(RoomMessageEventContent::text_plain(format!( + "Failed to resolve room alias {room} to a room ID: {e}" + ))); + }, + } + }; + + services().rooms.metadata.ban_room(&room_id, true)?; + + room_id + } else { + return Ok(RoomMessageEventContent::text_plain( + "Room specified is not a room ID or room alias. 
Please note that this requires a full room ID \ + (`!awIh6gGInaS5wLQJwa:example.com`) or a room alias (`#roomalias:example.com`)", + )); + }; + + debug!("Making all users leave the room {}", &room); + if force { + for local_user in services() + .rooms + .state_cache + .room_members(&room_id) + .filter_map(|user| { + user.ok().filter(|local_user| { + user_is_local(local_user) + // additional wrapped check here is to avoid adding remote users + // who are in the admin room to the list of local users (would fail auth check) + && (user_is_local(local_user) + && services() + .users + .is_admin(local_user) + .unwrap_or(true)) // since this is a force + // operation, assume user + // is an admin if somehow + // this fails + }) + }) + .collect::<Vec<OwnedUserId>>() + { + debug!( + "Attempting leave for user {} in room {} (forced, ignoring all errors, evicting admins too)", + &local_user, &room_id + ); + + if let Err(e) = leave_room(&local_user, &room_id, None).await { + warn!(%e, "Failed to leave room"); + } + } + } else { + for local_user in services() + .rooms + .state_cache + .room_members(&room_id) + .filter_map(|user| { + user.ok().filter(|local_user| { + local_user.server_name() == services().globals.server_name() + // additional wrapped check here is to avoid adding remote users + // who are in the admin room to the list of local users (would fail auth check) + && (local_user.server_name() + == services().globals.server_name() + && !services() + .users + .is_admin(local_user) + .unwrap_or(false)) + }) + }) + .collect::<Vec<OwnedUserId>>() + { + debug!("Attempting leave for user {} in room {}", &local_user, &room_id); + if let Err(e) = leave_room(&local_user, &room_id, None).await { + error!( + "Error attempting to make local user {} leave room {} during room banning: {}", + &local_user, &room_id, e ); + return Ok(RoomMessageEventContent::text_plain(format!( + "Error attempting to make local user {} leave room {} during room banning (room is still banned \ + but not removing any more users): {}\nIf you would like to ignore errors, use --force", + &local_user, &room_id, e + ))); + } + } + } - let room_id = if let Some(room_id) = services().rooms.alias.resolve_local_alias(&room_alias)? 
{ - room_id - } else { - debug!( - "We don't have this room alias to a room ID locally, attempting to fetch room ID over \ - federation" - ); + if disable_federation { + services().rooms.metadata.disable_room(&room_id, true)?; + return Ok(RoomMessageEventContent::text_plain( + "Room banned, removed all our local users, and disabled incoming federation with room.", + )); + } + + Ok(RoomMessageEventContent::text_plain( + "Room banned and removed all our local users, use `!admin federation disable-room` to stop receiving new \ + inbound federation events as well if needed.", + )) +} - match get_alias_helper(room_alias, None).await { - Ok(response) => { - debug!("Got federation response fetching room ID for room {room}: {:?}", response); - response.room_id - }, - Err(e) => { - return Ok(RoomMessageEventContent::text_plain(format!( - "Failed to resolve room alias {room} to a room ID: {e}" - ))); - }, +async fn ban_list_of_rooms(body: Vec<&str>, force: bool, disable_federation: bool) -> Result<RoomMessageEventContent> { + if body.len() > 2 && body[0].trim().starts_with("```") && body.last().unwrap().trim() == "```" { + let rooms_s = body.clone().drain(1..body.len() - 1).collect::<Vec<_>>(); + + let admin_room_alias: Box<RoomAliasId> = format!("#admins:{}", services().globals.server_name()) + .try_into() + .expect("#admins:server_name is a valid alias name"); + + let mut room_ban_count: usize = 0; + let mut room_ids: Vec<OwnedRoomId> = Vec::new(); + + for &room in &rooms_s { + match <&RoomOrAliasId>::try_from(room) { + Ok(room_alias_or_id) => { + if let Some(admin_room_id) = Service::get_admin_room().await? { + if room.to_owned().eq(&admin_room_id) || room.to_owned().eq(&admin_room_alias) { + info!("User specified admin room in bulk ban list, ignoring"); + continue; + } } - }; - services().rooms.metadata.ban_room(&room_id, true)?; + if room_alias_or_id.is_room_id() { + let room_id = match RoomId::parse(room_alias_or_id) { + Ok(room_id) => room_id, + Err(e) => { + if force { + // ignore rooms we failed to parse if we're force banning + warn!( + "Error parsing room \"{room}\" during bulk room banning, ignoring error and \ + logging here: {e}" + ); + continue; + } + + return Ok(RoomMessageEventContent::text_plain(format!( + "{room} is not a valid room ID or room alias, please fix the list and try again: \ + {e}" + ))); + }, + }; + + room_ids.push(room_id); + } - room_id - } else { - return Ok(RoomMessageEventContent::text_plain( - "Room specified is not a room ID or room alias. Please note that this requires a full room ID \ - (`!awIh6gGInaS5wLQJwa:example.com`) or a room alias (`#roomalias:example.com`)", - )); - }; + if room_alias_or_id.is_room_alias_id() { + match RoomAliasId::parse(room_alias_or_id) { + Ok(room_alias) => { + let room_id = + if let Some(room_id) = services().rooms.alias.resolve_local_alias(&room_alias)? 
{ + room_id + } else { + debug!( + "We don't have this room alias to a room ID locally, attempting to fetch \ + room ID over federation" + ); + + match get_alias_helper(room_alias, None).await { + Ok(response) => { + debug!( + "Got federation response fetching room ID for room {room}: {:?}", + response + ); + response.room_id + }, + Err(e) => { + // don't fail if force blocking + if force { + warn!("Failed to resolve room alias {room} to a room ID: {e}"); + continue; + } + + return Ok(RoomMessageEventContent::text_plain(format!( + "Failed to resolve room alias {room} to a room ID: {e}" + ))); + }, + } + }; + + room_ids.push(room_id); + }, + Err(e) => { + if force { + // ignore rooms we failed to parse if we're force deleting + error!( + "Error parsing room \"{room}\" during bulk room banning, ignoring error and \ + logging here: {e}" + ); + continue; + } - debug!("Making all users leave the room {}", &room); + return Ok(RoomMessageEventContent::text_plain(format!( + "{room} is not a valid room ID or room alias, please fix the list and try again: \ + {e}" + ))); + }, + } + } + }, + Err(e) => { + if force { + // ignore rooms we failed to parse if we're force deleting + error!( + "Error parsing room \"{room}\" during bulk room banning, ignoring error and logging here: \ + {e}" + ); + continue; + } + + return Ok(RoomMessageEventContent::text_plain(format!( + "{room} is not a valid room ID or room alias, please fix the list and try again: {e}" + ))); + }, + } + } + + for room_id in room_ids { + if services().rooms.metadata.ban_room(&room_id, true).is_ok() { + debug!("Banned {room_id} successfully"); + room_ban_count = room_ban_count.saturating_add(1); + } + + debug!("Making all users leave the room {}", &room_id); if force { for local_user in services() .rooms @@ -102,26 +315,28 @@ pub(crate) async fn process(command: RoomModerationCommand, body: Vec<&str>) -> .room_members(&room_id) .filter_map(|user| { user.ok().filter(|local_user| { - user_is_local(local_user) - // additional wrapped check here is to avoid adding remote users - // who are in the admin room to the list of local users (would fail auth check) - && (user_is_local(local_user) - && services() - .users - .is_admin(local_user) - .unwrap_or(true)) // since this is a force - // operation, assume user - // is an admin if somehow - // this fails + local_user.server_name() == services().globals.server_name() + // additional wrapped check here is to avoid adding remote users + // who are in the admin room to the list of local users (would fail auth check) + && (local_user.server_name() + == services().globals.server_name() + && services() + .users + .is_admin(local_user) + .unwrap_or(true)) // since this is a + // force operation, + // assume user is + // an admin if + // somehow this + // fails }) }) .collect::<Vec<OwnedUserId>>() { debug!( "Attempting leave for user {} in room {} (forced, ignoring all errors, evicting admins too)", - &local_user, &room_id + &local_user, room_id ); - if let Err(e) = leave_room(&local_user, &room_id, None).await { warn!(%e, "Failed to leave room"); } @@ -134,14 +349,14 @@ pub(crate) async fn process(command: RoomModerationCommand, body: Vec<&str>) -> .filter_map(|user| { user.ok().filter(|local_user| { local_user.server_name() == services().globals.server_name() - // additional wrapped check here is to avoid adding remote users - // who are in the admin room to the list of local users (would fail auth check) - && (local_user.server_name() - == services().globals.server_name() - && !services() - .users - 
.is_admin(local_user) - .unwrap_or(false)) + // additional wrapped check here is to avoid adding remote users + // who are in the admin room to the list of local users (would fail auth check) + && (local_user.server_name() + == services().globals.server_name() + && !services() + .users + .is_admin(local_user) + .unwrap_or(false)) }) }) .collect::<Vec<OwnedUserId>>() @@ -149,13 +364,13 @@ pub(crate) async fn process(command: RoomModerationCommand, body: Vec<&str>) -> debug!("Attempting leave for user {} in room {}", &local_user, &room_id); if let Err(e) = leave_room(&local_user, &room_id, None).await { error!( - "Error attempting to make local user {} leave room {} during room banning: {}", + "Error attempting to make local user {} leave room {} during bulk room banning: {}", &local_user, &room_id, e ); return Ok(RoomMessageEventContent::text_plain(format!( "Error attempting to make local user {} leave room {} during room banning (room is still \ - banned but not removing any more users): {}\nIf you would like to ignore errors, use \ - --force", + banned but not removing any more users and not banning any more rooms): {}\nIf you would \ + like to ignore errors, use --force", &local_user, &room_id, e ))); } @@ -164,341 +379,128 @@ pub(crate) async fn process(command: RoomModerationCommand, body: Vec<&str>) -> if disable_federation { services().rooms.metadata.disable_room(&room_id, true)?; - return Ok(RoomMessageEventContent::text_plain( - "Room banned, removed all our local users, and disabled incoming federation with room.", - )); } + } + + if disable_federation { + return Ok(RoomMessageEventContent::text_plain(format!( + "Finished bulk room ban, banned {room_ban_count} total rooms, evicted all users, and disabled \ + incoming federation with the room." + ))); + } + return Ok(RoomMessageEventContent::text_plain(format!( + "Finished bulk room ban, banned {room_ban_count} total rooms and evicted all users." + ))); + } - Ok(RoomMessageEventContent::text_plain( - "Room banned and removed all our local users, use `!admin federation disable-room` to stop receiving \ - new inbound federation events as well if needed.", - )) - }, - RoomModerationCommand::BanListOfRooms { - force, - disable_federation, - } => { - if body.len() > 2 && body[0].trim().starts_with("```") && body.last().unwrap().trim() == "```" { - let rooms_s = body.clone().drain(1..body.len() - 1).collect::<Vec<_>>(); - - let admin_room_alias: Box<RoomAliasId> = format!("#admins:{}", services().globals.server_name()) - .try_into() - .expect("#admins:server_name is a valid alias name"); - - let mut room_ban_count: usize = 0; - let mut room_ids: Vec<OwnedRoomId> = Vec::new(); - - for &room in &rooms_s { - match <&RoomOrAliasId>::try_from(room) { - Ok(room_alias_or_id) => { - if let Some(admin_room_id) = Service::get_admin_room().await? 
{ - if room.to_owned().eq(&admin_room_id) || room.to_owned().eq(&admin_room_alias) { - info!("User specified admin room in bulk ban list, ignoring"); - continue; - } - } - - if room_alias_or_id.is_room_id() { - let room_id = match RoomId::parse(room_alias_or_id) { - Ok(room_id) => room_id, - Err(e) => { - if force { - // ignore rooms we failed to parse if we're force banning - warn!( - "Error parsing room \"{room}\" during bulk room banning, ignoring \ - error and logging here: {e}" - ); - continue; - } - - return Ok(RoomMessageEventContent::text_plain(format!( - "{room} is not a valid room ID or room alias, please fix the list and try \ - again: {e}" - ))); - }, - }; + Ok(RoomMessageEventContent::text_plain( + "Expected code block in command body. Add --help for details.", + )) +} - room_ids.push(room_id); - } - - if room_alias_or_id.is_room_alias_id() { - match RoomAliasId::parse(room_alias_or_id) { - Ok(room_alias) => { - let room_id = if let Some(room_id) = - services().rooms.alias.resolve_local_alias(&room_alias)? - { - room_id - } else { - debug!( - "We don't have this room alias to a room ID locally, attempting to \ - fetch room ID over federation" - ); - - match get_alias_helper(room_alias, None).await { - Ok(response) => { - debug!( - "Got federation response fetching room ID for room {room}: \ - {:?}", - response - ); - response.room_id - }, - Err(e) => { - // don't fail if force blocking - if force { - warn!("Failed to resolve room alias {room} to a room ID: {e}"); - continue; - } - - return Ok(RoomMessageEventContent::text_plain(format!( - "Failed to resolve room alias {room} to a room ID: {e}" - ))); - }, - } - }; - - room_ids.push(room_id); - }, - Err(e) => { - if force { - // ignore rooms we failed to parse if we're force deleting - error!( - "Error parsing room \"{room}\" during bulk room banning, ignoring \ - error and logging here: {e}" - ); - continue; - } +async fn unban_room( + _body: Vec<&str>, room: Box<RoomOrAliasId>, enable_federation: bool, +) -> Result<RoomMessageEventContent> { + let room_id = if room.is_room_id() { + let room_id = match RoomId::parse(&room) { + Ok(room_id) => room_id, + Err(e) => { + return Ok(RoomMessageEventContent::text_plain(format!( + "Failed to parse room ID {room}. 
Please note that this requires a full room ID \ + (`!awIh6gGInaS5wLQJwa:example.com`) or a room alias (`#roomalias:example.com`): {e}" + ))) + }, + }; - return Ok(RoomMessageEventContent::text_plain(format!( - "{room} is not a valid room ID or room alias, please fix the list and try \ - again: {e}" - ))); - }, - } - } - }, - Err(e) => { - if force { - // ignore rooms we failed to parse if we're force deleting - error!( - "Error parsing room \"{room}\" during bulk room banning, ignoring error and \ - logging here: {e}" - ); - continue; - } - - return Ok(RoomMessageEventContent::text_plain(format!( - "{room} is not a valid room ID or room alias, please fix the list and try again: {e}" - ))); - }, - } - } + debug!("Room specified is a room ID, unbanning room ID"); - for room_id in room_ids { - if services().rooms.metadata.ban_room(&room_id, true).is_ok() { - debug!("Banned {room_id} successfully"); - room_ban_count = room_ban_count.saturating_add(1); - } + services().rooms.metadata.ban_room(&room_id, false)?; - debug!("Making all users leave the room {}", &room_id); - if force { - for local_user in services() - .rooms - .state_cache - .room_members(&room_id) - .filter_map(|user| { - user.ok().filter(|local_user| { - local_user.server_name() == services().globals.server_name() - // additional wrapped check here is to avoid adding remote users - // who are in the admin room to the list of local users (would fail auth check) - && (local_user.server_name() - == services().globals.server_name() - && services() - .users - .is_admin(local_user) - .unwrap_or(true)) // since this is a - // force operation, - // assume user is - // an admin if - // somehow this - // fails - }) - }) - .collect::<Vec<OwnedUserId>>() - { - debug!( - "Attempting leave for user {} in room {} (forced, ignoring all errors, evicting \ - admins too)", - &local_user, room_id - ); - if let Err(e) = leave_room(&local_user, &room_id, None).await { - warn!(%e, "Failed to leave room"); - } - } - } else { - for local_user in services() - .rooms - .state_cache - .room_members(&room_id) - .filter_map(|user| { - user.ok().filter(|local_user| { - local_user.server_name() == services().globals.server_name() - // additional wrapped check here is to avoid adding remote users - // who are in the admin room to the list of local users (would fail auth check) - && (local_user.server_name() - == services().globals.server_name() - && !services() - .users - .is_admin(local_user) - .unwrap_or(false)) - }) - }) - .collect::<Vec<OwnedUserId>>() - { - debug!("Attempting leave for user {} in room {}", &local_user, &room_id); - if let Err(e) = leave_room(&local_user, &room_id, None).await { - error!( - "Error attempting to make local user {} leave room {} during bulk room banning: {}", - &local_user, &room_id, e - ); - return Ok(RoomMessageEventContent::text_plain(format!( - "Error attempting to make local user {} leave room {} during room banning (room \ - is still banned but not removing any more users and not banning any more rooms): \ - {}\nIf you would like to ignore errors, use --force", - &local_user, &room_id, e - ))); - } - } - } - - if disable_federation { - services().rooms.metadata.disable_room(&room_id, true)?; - } - } - - if disable_federation { + room_id + } else if room.is_room_alias_id() { + let room_alias = match RoomAliasId::parse(&room) { + Ok(room_alias) => room_alias, + Err(e) => { + return Ok(RoomMessageEventContent::text_plain(format!( + "Failed to parse room ID {room}. 
Please note that this requires a full room ID \ + (`!awIh6gGInaS5wLQJwa:example.com`) or a room alias (`#roomalias:example.com`): {e}" + ))) + }, + }; + + debug!( + "Room specified is not a room ID, attempting to resolve room alias to a room ID locally, if not using \ + get_alias_helper to fetch room ID remotely" + ); + + let room_id = if let Some(room_id) = services().rooms.alias.resolve_local_alias(&room_alias)? { + room_id + } else { + debug!("We don't have this room alias to a room ID locally, attempting to fetch room ID over federation"); + + match get_alias_helper(room_alias, None).await { + Ok(response) => { + debug!("Got federation response fetching room ID for room {room}: {:?}", response); + response.room_id + }, + Err(e) => { return Ok(RoomMessageEventContent::text_plain(format!( - "Finished bulk room ban, banned {room_ban_count} total rooms, evicted all users, and disabled \ - incoming federation with the room." + "Failed to resolve room alias {room} to a room ID: {e}" ))); - } - return Ok(RoomMessageEventContent::text_plain(format!( - "Finished bulk room ban, banned {room_ban_count} total rooms and evicted all users." - ))); + }, } + }; - Ok(RoomMessageEventContent::text_plain( - "Expected code block in command body. Add --help for details.", - )) - }, - RoomModerationCommand::UnbanRoom { - room, - enable_federation, - } => { - let room_id = if room.is_room_id() { - let room_id = match RoomId::parse(&room) { - Ok(room_id) => room_id, - Err(e) => { - return Ok(RoomMessageEventContent::text_plain(format!( - "Failed to parse room ID {room}. Please note that this requires a full room ID \ - (`!awIh6gGInaS5wLQJwa:example.com`) or a room alias (`#roomalias:example.com`): {e}" - ))) - }, - }; - - debug!("Room specified is a room ID, unbanning room ID"); - - services().rooms.metadata.ban_room(&room_id, false)?; + services().rooms.metadata.ban_room(&room_id, false)?; - room_id - } else if room.is_room_alias_id() { - let room_alias = match RoomAliasId::parse(&room) { - Ok(room_alias) => room_alias, - Err(e) => { - return Ok(RoomMessageEventContent::text_plain(format!( - "Failed to parse room ID {room}. Please note that this requires a full room ID \ - (`!awIh6gGInaS5wLQJwa:example.com`) or a room alias (`#roomalias:example.com`): {e}" - ))) - }, - }; - - debug!( - "Room specified is not a room ID, attempting to resolve room alias to a room ID locally, if not \ - using get_alias_helper to fetch room ID remotely" - ); + room_id + } else { + return Ok(RoomMessageEventContent::text_plain( + "Room specified is not a room ID or room alias. Please note that this requires a full room ID \ + (`!awIh6gGInaS5wLQJwa:example.com`) or a room alias (`#roomalias:example.com`)", + )); + }; - let room_id = if let Some(room_id) = services().rooms.alias.resolve_local_alias(&room_alias)? 
{ - room_id - } else { - debug!( - "We don't have this room alias to a room ID locally, attempting to fetch room ID over \ - federation" - ); - - match get_alias_helper(room_alias, None).await { - Ok(response) => { - debug!("Got federation response fetching room ID for room {room}: {:?}", response); - response.room_id - }, - Err(e) => { - return Ok(RoomMessageEventContent::text_plain(format!( - "Failed to resolve room alias {room} to a room ID: {e}" - ))); - }, - } - }; - - services().rooms.metadata.ban_room(&room_id, false)?; + if enable_federation { + services().rooms.metadata.disable_room(&room_id, false)?; + return Ok(RoomMessageEventContent::text_plain("Room unbanned.")); + } - room_id - } else { - return Ok(RoomMessageEventContent::text_plain( - "Room specified is not a room ID or room alias. Please note that this requires a full room ID \ - (`!awIh6gGInaS5wLQJwa:example.com`) or a room alias (`#roomalias:example.com`)", - )); - }; - - if enable_federation { - services().rooms.metadata.disable_room(&room_id, false)?; - return Ok(RoomMessageEventContent::text_plain("Room unbanned.")); - } + Ok(RoomMessageEventContent::text_plain( + "Room unbanned, you may need to re-enable federation with the room using enable-room if this is a remote room \ + to make it fully functional.", + )) +} - Ok(RoomMessageEventContent::text_plain( - "Room unbanned, you may need to re-enable federation with the room using enable-room if this is a \ - remote room to make it fully functional.", - )) +async fn list_banned_rooms(_body: Vec<&str>) -> Result<RoomMessageEventContent> { + let rooms = services() + .rooms + .metadata + .list_banned_rooms() + .collect::<Result<Vec<_>, _>>(); + + match rooms { + Ok(room_ids) => { + // TODO: add room name from our state cache if available, default to the room ID + // as the room name if we dont have it TODO: do same if we have a room alias for + // this + let plain_list = room_ids.iter().fold(String::new(), |mut output, room_id| { + writeln!(output, "- `{room_id}`").unwrap(); + output + }); + + let html_list = room_ids.iter().fold(String::new(), |mut output, room_id| { + writeln!(output, "<li><code>{}</code></li>", escape_html(room_id.as_ref())).unwrap(); + output + }); + + let plain = format!("Rooms:\n{plain_list}"); + let html = format!("Rooms:\n<ul>{html_list}</ul>"); + Ok(RoomMessageEventContent::text_html(plain, html)) }, - RoomModerationCommand::ListBannedRooms => { - let rooms = services() - .rooms - .metadata - .list_banned_rooms() - .collect::<Result<Vec<_>, _>>(); - - match rooms { - Ok(room_ids) => { - // TODO: add room name from our state cache if available, default to the room ID - // as the room name if we dont have it TODO: do same if we have a room alias for - // this - let plain_list = room_ids.iter().fold(String::new(), |mut output, room_id| { - writeln!(output, "- `{room_id}`").unwrap(); - output - }); - - let html_list = room_ids.iter().fold(String::new(), |mut output, room_id| { - writeln!(output, "<li><code>{}</code></li>", escape_html(room_id.as_ref())).unwrap(); - output - }); - - let plain = format!("Rooms:\n{plain_list}"); - let html = format!("Rooms:\n<ul>{html_list}</ul>"); - Ok(RoomMessageEventContent::text_html(plain, html)) - }, - Err(e) => { - error!("Failed to list banned rooms: {}", e); - Ok(RoomMessageEventContent::text_plain(format!("Unable to list room aliases: {e}"))) - }, - } + Err(e) => { + error!("Failed to list banned rooms: {}", e); + Ok(RoomMessageEventContent::text_plain(format!("Unable to list room aliases: {e}"))) }, } } -- 
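The moderation split above applies the same recipe as the join split in patch 02: each match arm becomes a thin dispatch into its own async fn, so the heavy bodies stop sharing one stack frame and one generated future, which is what lets the clippy stack-size-threshold ratchet down from 178030 to 173577 to 144000 across these patches. A minimal sketch of the dispatch shape, using a hypothetical Command enum standing in for RoomModerationCommand:

	// Hypothetical stand-ins; only the dispatch shape mirrors the patch.
	enum Command {
		Ban { force: bool },
		Unban,
		List,
	}

	async fn ban(_force: bool) { /* heavy body gets its own frame and future */ }
	async fn unban() {}
	async fn list() {}

	async fn process(cmd: Command) {
		match cmd {
			Command::Ban { force } => ban(force).await,
			Command::Unban => unban().await,
			Command::List => list().await,
		}
	}
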
GitLab From b65f05ce19f7f581a5770d82b82b2ee25a9e9971 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Wed, 29 May 2024 14:57:53 +0000 Subject: [PATCH 04/13] simplify lifetime parameters Signed-off-by: Jason Volk <jason@zemos.net> --- src/service/rooms/state_cache/data.rs | 4 ++-- src/service/rooms/state_cache/mod.rs | 24 ++++++++++++------------ 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index 1c0571c95..acf390d98 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -72,7 +72,7 @@ fn mark_as_invited( fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>>; /// Returns an iterator over all rooms this user joined. - fn rooms_joined<'a>(&'a self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>; + fn rooms_joined(&self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + '_>; /// Returns an iterator over all rooms a user was invited to. fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a>; @@ -494,7 +494,7 @@ fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u6 /// Returns an iterator over all rooms this user joined. #[tracing::instrument(skip(self))] - fn rooms_joined<'a>(&'a self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { + fn rooms_joined(&self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + '_> { Box::new( self.userroomid_joined .scan_prefix(user_id.as_bytes().to_vec()) diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 1cfd56db6..2b3f5fb7b 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -239,7 +239,7 @@ pub fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { self.db /// Returns an iterator of all servers participating in this room. #[tracing::instrument(skip(self))] - pub fn room_servers<'a>(&'a self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedServerName>> + 'a { + pub fn room_servers(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedServerName>> + '_ { self.db.room_servers(room_id) } @@ -251,7 +251,7 @@ pub fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result<bo /// Returns an iterator of all rooms a server participates in (as far as we /// know). #[tracing::instrument(skip(self))] - pub fn server_rooms<'a>(&'a self, server: &ServerName) -> impl Iterator<Item = Result<OwnedRoomId>> + 'a { + pub fn server_rooms(&self, server: &ServerName) -> impl Iterator<Item = Result<OwnedRoomId>> + '_ { self.db.server_rooms(server) } @@ -282,7 +282,7 @@ pub fn user_sees_user(&self, user_a: &UserId, user_b: &UserId) -> Result<bool> { /// Returns an iterator over all joined members of a room. #[tracing::instrument(skip(self))] - pub fn room_members<'a>(&'a self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedUserId>> + 'a { + pub fn room_members(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedUserId>> + '_ { self.db.room_members(room_id) } @@ -308,13 +308,13 @@ pub fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>> { self /// Returns an iterator over all User IDs who ever joined a room. 
#[tracing::instrument(skip(self))] - pub fn room_useroncejoined<'a>(&'a self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedUserId>> + 'a { + pub fn room_useroncejoined(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedUserId>> + '_ { self.db.room_useroncejoined(room_id) } /// Returns an iterator over all invited members of a room. #[tracing::instrument(skip(self))] - pub fn room_members_invited<'a>(&'a self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedUserId>> + 'a { + pub fn room_members_invited(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedUserId>> + '_ { self.db.room_members_invited(room_id) } @@ -330,15 +330,15 @@ pub fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Optio /// Returns an iterator over all rooms this user joined. #[tracing::instrument(skip(self))] - pub fn rooms_joined<'a>(&'a self, user_id: &UserId) -> impl Iterator<Item = Result<OwnedRoomId>> + 'a { + pub fn rooms_joined(&self, user_id: &UserId) -> impl Iterator<Item = Result<OwnedRoomId>> + '_ { self.db.rooms_joined(user_id) } /// Returns an iterator over all rooms a user was invited to. #[tracing::instrument(skip(self))] - pub fn rooms_invited<'a>( - &'a self, user_id: &UserId, - ) -> impl Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a { + pub fn rooms_invited( + &self, user_id: &UserId, + ) -> impl Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + '_ { self.db.rooms_invited(user_id) } @@ -354,9 +354,9 @@ pub fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Option<Ve /// Returns an iterator over all rooms a user left. #[tracing::instrument(skip(self))] - pub fn rooms_left<'a>( - &'a self, user_id: &UserId, - ) -> impl Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + 'a { + pub fn rooms_left( + &self, user_id: &UserId, + ) -> impl Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + '_ { self.db.rooms_left(user_id) } -- GitLab From 67f4285504c20206c85da9a62e6c0fda521a2c3a Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Sun, 9 Jun 2024 05:15:27 +0000 Subject: [PATCH 05/13] Fix branches sharing code Signed-off-by: Jason Volk <jason@zemos.net> --- Cargo.toml | 1 - src/database/rocksdb/mod.rs | 9 +++------ src/service/admin.rs | 6 ++---- src/service/rooms/threads/mod.rs | 3 +-- 4 files changed, 6 insertions(+), 13 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a36e1fd1a..7c2630427 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -715,7 +715,6 @@ correctness = "warn" nursery = "warn" ## some sadness -branches_sharing_code = { level = "allow", priority = 1 } # TODO derive_partial_eq_without_eq = { level = "allow", priority = 1 } # TODO equatable_if_let = { level = "allow", priority = 1 } # TODO future_not_send = { level = "allow", priority = 1 } # TODO diff --git a/src/database/rocksdb/mod.rs b/src/database/rocksdb/mod.rs index 28ce50720..f0c3882f0 100644 --- a/src/database/rocksdb/mod.rs +++ b/src/database/rocksdb/mod.rs @@ -182,7 +182,7 @@ fn backup(&self) -> Result<(), Box<dyn std::error::Error>> { let options = BackupEngineOptions::new(path.unwrap())?; let mut engine = BackupEngine::open(&options, &self.env)?; - let ret = if self.config.database_backups_to_keep > 0 { + if self.config.database_backups_to_keep > 0 { if let Err(e) = engine.create_new_backup_flush(&self.rocks, true) { return Err(Box::new(e)); } @@ -193,10 +193,7 @@ fn backup(&self) -> Result<(), Box<dyn std::error::Error>> { "Created database backup #{} using {} 
bytes in {} files", info.backup_id, info.size, info.num_files, ); - Ok(()) - } else { - Ok(()) - }; + } if self.config.database_backups_to_keep >= 0 { let keep = u32::try_from(self.config.database_backups_to_keep)?; @@ -205,7 +202,7 @@ fn backup(&self) -> Result<(), Box<dyn std::error::Error>> { } } - ret + Ok(()) } fn backup_list(&self) -> Result<String> { diff --git a/src/service/admin.rs b/src/service/admin.rs index 19923d0aa..63d6d5c64 100644 --- a/src/service/admin.rs +++ b/src/service/admin.rs @@ -529,10 +529,8 @@ pub async fn make_user_admin(&self, user_id: &UserId, displayname: String) -> Re &room_id, &state_lock, ).await?; - - Ok(()) - } else { - Ok(()) } + + Ok(()) } } diff --git a/src/service/rooms/threads/mod.rs b/src/service/rooms/threads/mod.rs index a7d5c4342..6c48e842c 100644 --- a/src/service/rooms/threads/mod.rs +++ b/src/service/rooms/threads/mod.rs @@ -91,11 +91,10 @@ pub fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<( let mut users = Vec::new(); if let Some(userids) = self.db.get_participants(root_id)? { users.extend_from_slice(&userids); - users.push(pdu.sender.clone()); } else { users.push(root_pdu.sender); - users.push(pdu.sender.clone()); } + users.push(pdu.sender.clone()); self.db.update_participants(root_id, &users) } -- GitLab From b525031a254f2b95e0ed20fc4c78de02e138328f Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Sun, 9 Jun 2024 05:20:31 +0000 Subject: [PATCH 06/13] Fix derive partial eq without eq Signed-off-by: Jason Volk <jason@zemos.net> --- Cargo.toml | 1 - src/service/rooms/spaces/mod.rs | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7c2630427..a6cce3f78 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -715,7 +715,6 @@ correctness = "warn" nursery = "warn" ## some sadness -derive_partial_eq_without_eq = { level = "allow", priority = 1 } # TODO equatable_if_let = { level = "allow", priority = 1 } # TODO future_not_send = { level = "allow", priority = 1 } # TODO missing_const_for_fn = { level = "allow", priority = 1 } # TODO diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index 72ebaf4a7..fec3fdcce 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -233,7 +233,7 @@ fn new(root: OwnedRoomId, max_depth: usize) -> Self { } // Note: perhaps use some better form of token rather than just room count -#[derive(Debug, PartialEq)] +#[derive(Debug, Eq, PartialEq)] pub struct PagnationToken { pub skip: UInt, pub limit: UInt, -- GitLab From 89d7d4832439c06344322d0aa601a7802c3ede25 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Sun, 9 Jun 2024 05:25:51 +0000 Subject: [PATCH 07/13] Fix equatable if let Signed-off-by: Jason Volk <jason@zemos.net> --- Cargo.toml | 1 - src/api/client/membership.rs | 2 +- src/api/client/report.rs | 4 ++-- src/service/rooms/spaces/mod.rs | 17 ++++++++++------- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a6cce3f78..f52b9c8f6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -715,7 +715,6 @@ correctness = "warn" nursery = "warn" ## some sadness -equatable_if_let = { level = "allow", priority = 1 } # TODO future_not_send = { level = "allow", priority = 1 } # TODO missing_const_for_fn = { level = "allow", priority = 1 } # TODO needless_collect = { level = "allow", priority = 1 } # TODO diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index b50ef2cbd..f59da7b7d 100644 --- 
a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -631,7 +631,7 @@ pub async fn join_room_by_id_helper( ) -> Result<join_room_by_id::v3::Response> { let sender_user = sender_user.expect("user is authenticated"); - if let Ok(true) = services().rooms.state_cache.is_joined(sender_user, room_id) { + if matches!(services().rooms.state_cache.is_joined(sender_user, room_id), Ok(true)) { info!("{sender_user} is already joined in {room_id}"); return Ok(join_room_by_id::v3::Response { room_id: room_id.into(), diff --git a/src/api/client/report.rs b/src/api/client/report.rs index 3e6ba8803..dae6086f4 100644 --- a/src/api/client/report.rs +++ b/src/api/client/report.rs @@ -104,14 +104,14 @@ fn is_report_valid( )); } - if let Some(true) = score.map(|s| s > int!(0) || s < int!(-100)) { + if score.map(|s| s > int!(0) || s < int!(-100)) == Some(true) { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Invalid score, must be within 0 to -100", )); }; - if let Some(true) = reason.clone().map(|s| s.len() >= 750) { + if reason.clone().map(|s| s.len() >= 750) == Some(true) { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Reason too long, should be 750 characters or fewer", diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index fec3fdcce..18b9258e9 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -108,7 +108,7 @@ fn first_untraversed(&mut self) -> Option<NodeId> { // whole space tree. // // You should only ever encounter a traversed node when going up through parents - while let Some(true) = self.traversed(current) { + while self.traversed(current) == Some(true) { if let Some(next) = self.next_sibling(current) { current = next; } else if let Some(parent) = self.parent(current) { @@ -821,12 +821,15 @@ fn is_accessable_child_recurse( SpaceRoomJoinRule::Restricted => { for room in allowed_room_ids { if let Ok((join_rule, allowed_room_ids)) = get_join_rule(room) { - if let Ok(true) = is_accessable_child_recurse( - room, - &join_rule, - identifier, - &allowed_room_ids, - recurse_num + 1, + if matches!( + is_accessable_child_recurse( + room, + &join_rule, + identifier, + &allowed_room_ids, + recurse_num + 1, + ), + Ok(true) ) { return Ok(true); } -- GitLab From 7688d678709c8da6e6b5466b7f658bb51fcc0db5 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Sun, 9 Jun 2024 06:50:44 +0000 Subject: [PATCH 08/13] Fix needless pass by ref mut Signed-off-by: Jason Volk <jason@zemos.net> --- Cargo.toml | 1 - src/service/sending/sender.rs | 10 +++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f52b9c8f6..3633bc1d9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -718,7 +718,6 @@ nursery = "warn" future_not_send = { level = "allow", priority = 1 } # TODO missing_const_for_fn = { level = "allow", priority = 1 } # TODO needless_collect = { level = "allow", priority = 1 } # TODO -needless_pass_by_ref_mut = { level = "allow", priority = 1 } # TODO option_if_let_else = { level = "allow", priority = 1 } # TODO redundant_pub_crate = { level = "allow", priority = 1 } # TODO significant_drop_in_scrutinee = { level = "allow", priority = 1 } # TODO diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index 39b7b6003..7efe3c72f 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -60,12 +60,12 @@ async fn handler(&self) -> Result<()> { let mut futures: SendingFutures<'_> = FuturesUnordered::new(); let mut statuses: 
CurTransactionStatus = CurTransactionStatus::new(); - self.initial_transactions(&mut futures, &mut statuses); + self.initial_transactions(&futures, &mut statuses); loop { debug_assert!(!receiver.is_closed(), "channel error"); tokio::select! { request = receiver.recv_async() => match request { - Ok(request) => self.handle_request(request, &mut futures, &mut statuses), + Ok(request) => self.handle_request(request, &futures, &mut statuses), Err(_) => return Ok(()), }, Some(response) = futures.next() => { @@ -98,7 +98,7 @@ fn handle_response_err( } fn handle_response_ok( - &self, dest: &Destination, futures: &mut SendingFutures<'_>, statuses: &mut CurTransactionStatus, + &self, dest: &Destination, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus, ) { let _cork = services().globals.db.cork(); self.db @@ -125,7 +125,7 @@ fn handle_response_ok( } } - fn handle_request(&self, msg: Msg, futures: &mut SendingFutures<'_>, statuses: &mut CurTransactionStatus) { + fn handle_request(&self, msg: Msg, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus) { let iv = vec![(msg.event, msg.queue_id)]; if let Ok(Some(events)) = self.select_events(&msg.dest, iv, statuses) { if !events.is_empty() { @@ -136,7 +136,7 @@ fn handle_request(&self, msg: Msg, futures: &mut SendingFutures<'_>, statuses: & } } - fn initial_transactions(&self, futures: &mut SendingFutures<'_>, statuses: &mut CurTransactionStatus) { + fn initial_transactions(&self, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus) { let keep = usize::try_from(self.startup_netburst_keep).unwrap_or(usize::MAX); let mut txns = HashMap::<Destination, Vec<SendingEvent>>::new(); for (key, dest, event) in self.db.active_requests().filter_map(Result::ok) { -- GitLab From a8de5d1e60e5bfdb8a1db35172dcec58df4520d8 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Sun, 9 Jun 2024 06:10:21 +0000 Subject: [PATCH 09/13] Fix futures not Send Signed-off-by: Jason Volk <jason@zemos.net> --- Cargo.lock | 24 +++++++++---------- Cargo.toml | 1 - src/api/client/keys.rs | 2 +- src/router/mod.rs | 6 ++--- src/service/globals/migrations.rs | 7 +++++- src/service/pusher/mod.rs | 2 +- .../rooms/event_handler/signing_keys.rs | 2 +- src/service/sending/appservice.rs | 2 +- src/service/sending/mod.rs | 4 ++-- src/service/sending/send.rs | 12 +++++----- 10 files changed, 33 insertions(+), 29 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2c3702221..ade7465fa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2749,7 +2749,7 @@ dependencies = [ [[package]] name = "ruma" version = "0.10.1" -source = "git+https://github.com/girlbossceo/ruwuma?branch=conduwuit-changes#e9302a955614429ca969eb4c7f561fc87a4f6f15" +source = "git+https://github.com/girlbossceo/ruwuma?branch=conduwuit-changes#f8f6db89d8efa2b4fddbfcaf502759b80cd2d6cb" dependencies = [ "assign", "js_int", @@ -2769,7 +2769,7 @@ dependencies = [ [[package]] name = "ruma-appservice-api" version = "0.10.0" -source = "git+https://github.com/girlbossceo/ruwuma?branch=conduwuit-changes#e9302a955614429ca969eb4c7f561fc87a4f6f15" +source = "git+https://github.com/girlbossceo/ruwuma?branch=conduwuit-changes#f8f6db89d8efa2b4fddbfcaf502759b80cd2d6cb" dependencies = [ "js_int", "ruma-common", @@ -2781,7 +2781,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.18.0" -source = "git+https://github.com/girlbossceo/ruwuma?branch=conduwuit-changes#e9302a955614429ca969eb4c7f561fc87a4f6f15" +source = 
"git+https://github.com/girlbossceo/ruwuma?branch=conduwuit-changes#f8f6db89d8efa2b4fddbfcaf502759b80cd2d6cb" dependencies = [ "as_variant", "assign", @@ -2804,7 +2804,7 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.13.0" -source = "git+https://github.com/girlbossceo/ruwuma?branch=conduwuit-changes#e9302a955614429ca969eb4c7f561fc87a4f6f15" +source = "git+https://github.com/girlbossceo/ruwuma?branch=conduwuit-changes#f8f6db89d8efa2b4fddbfcaf502759b80cd2d6cb" dependencies = [ "as_variant", "base64 0.22.1", @@ -2834,7 +2834,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.28.1" -source = "git+https://github.com/girlbossceo/ruwuma?branch=conduwuit-changes#e9302a955614429ca969eb4c7f561fc87a4f6f15" +source = "git+https://github.com/girlbossceo/ruwuma?branch=conduwuit-changes#f8f6db89d8efa2b4fddbfcaf502759b80cd2d6cb" dependencies = [ "as_variant", "indexmap 2.2.6", @@ -2856,7 +2856,7 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?branch=conduwuit-changes#e9302a955614429ca969eb4c7f561fc87a4f6f15" +source = "git+https://github.com/girlbossceo/ruwuma?branch=conduwuit-changes#f8f6db89d8efa2b4fddbfcaf502759b80cd2d6cb" dependencies = [ "js_int", "ruma-common", @@ -2868,7 +2868,7 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.9.5" -source = "git+https://github.com/girlbossceo/ruwuma?branch=conduwuit-changes#e9302a955614429ca969eb4c7f561fc87a4f6f15" +source = "git+https://github.com/girlbossceo/ruwuma?branch=conduwuit-changes#f8f6db89d8efa2b4fddbfcaf502759b80cd2d6cb" dependencies = [ "js_int", "thiserror", @@ -2877,7 +2877,7 @@ dependencies = [ [[package]] name = "ruma-identity-service-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?branch=conduwuit-changes#e9302a955614429ca969eb4c7f561fc87a4f6f15" +source = "git+https://github.com/girlbossceo/ruwuma?branch=conduwuit-changes#f8f6db89d8efa2b4fddbfcaf502759b80cd2d6cb" dependencies = [ "js_int", "ruma-common", @@ -2887,7 +2887,7 @@ dependencies = [ [[package]] name = "ruma-macros" version = "0.13.0" -source = "git+https://github.com/girlbossceo/ruwuma?branch=conduwuit-changes#e9302a955614429ca969eb4c7f561fc87a4f6f15" +source = "git+https://github.com/girlbossceo/ruwuma?branch=conduwuit-changes#f8f6db89d8efa2b4fddbfcaf502759b80cd2d6cb" dependencies = [ "once_cell", "proc-macro-crate", @@ -2902,7 +2902,7 @@ dependencies = [ [[package]] name = "ruma-push-gateway-api" version = "0.9.0" -source = "git+https://github.com/girlbossceo/ruwuma?branch=conduwuit-changes#e9302a955614429ca969eb4c7f561fc87a4f6f15" +source = "git+https://github.com/girlbossceo/ruwuma?branch=conduwuit-changes#f8f6db89d8efa2b4fddbfcaf502759b80cd2d6cb" dependencies = [ "js_int", "ruma-common", @@ -2914,7 +2914,7 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.15.0" -source = "git+https://github.com/girlbossceo/ruwuma?branch=conduwuit-changes#e9302a955614429ca969eb4c7f561fc87a4f6f15" +source = "git+https://github.com/girlbossceo/ruwuma?branch=conduwuit-changes#f8f6db89d8efa2b4fddbfcaf502759b80cd2d6cb" dependencies = [ "base64 0.22.1", "ed25519-dalek", @@ -2930,7 +2930,7 @@ dependencies = [ [[package]] name = "ruma-state-res" version = "0.11.0" -source = "git+https://github.com/girlbossceo/ruwuma?branch=conduwuit-changes#e9302a955614429ca969eb4c7f561fc87a4f6f15" +source = "git+https://github.com/girlbossceo/ruwuma?branch=conduwuit-changes#f8f6db89d8efa2b4fddbfcaf502759b80cd2d6cb" 
dependencies = [ "itertools 0.12.1", "js_int", diff --git a/Cargo.toml b/Cargo.toml index 3633bc1d9..4c9ba3f5f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -715,7 +715,6 @@ correctness = "warn" nursery = "warn" ## some sadness -future_not_send = { level = "allow", priority = 1 } # TODO missing_const_for_fn = { level = "allow", priority = 1 } # TODO needless_collect = { level = "allow", priority = 1 } # TODO option_if_let_else = { level = "allow", priority = 1 } # TODO diff --git a/src/api/client/keys.rs b/src/api/client/keys.rs index b021bea19..0cfb01384 100644 --- a/src/api/client/keys.rs +++ b/src/api/client/keys.rs @@ -248,7 +248,7 @@ pub(crate) async fn get_key_changes_route( }) } -pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>( +pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>( sender_user: Option<&UserId>, device_keys_input: &BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>, allowed_signatures: F, include_display_names: bool, ) -> Result<get_keys::v3::Response> { diff --git a/src/router/mod.rs b/src/router/mod.rs index 9bc273566..e9bae3c5f 100644 --- a/src/router/mod.rs +++ b/src/router/mod.rs @@ -14,16 +14,16 @@ conduit::mod_dtor! {} #[no_mangle] -pub extern "Rust" fn start(server: &Arc<Server>) -> Pin<Box<dyn Future<Output = Result<()>>>> { +pub extern "Rust" fn start(server: &Arc<Server>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> { Box::pin(run::start(server.clone())) } #[no_mangle] -pub extern "Rust" fn stop(server: &Arc<Server>) -> Pin<Box<dyn Future<Output = Result<()>>>> { +pub extern "Rust" fn stop(server: &Arc<Server>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> { Box::pin(run::stop(server.clone())) } #[no_mangle] -pub extern "Rust" fn run(server: &Arc<Server>) -> Pin<Box<dyn Future<Output = Result<()>>>> { +pub extern "Rust" fn run(server: &Arc<Server>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> { Box::pin(run::run(server.clone())) } diff --git a/src/service/globals/migrations.rs b/src/service/globals/migrations.rs index aa172b5c4..0d0978a77 100644 --- a/src/service/globals/migrations.rs +++ b/src/service/globals/migrations.rs @@ -528,15 +528,20 @@ pub(crate) async fn migrations(db: &KeyValueDatabase, config: &Config) -> Result #[cfg(feature = "sha256_media")] { + use std::path::PathBuf; if services().globals.database_version()? 
< 14 && cfg!(feature = "sha256_media") { warn!("sha256_media feature flag is enabled, migrating legacy base64 file names to sha256 file names"); // Move old media files to new names + let mut changes = Vec::<(PathBuf, PathBuf)>::new(); for (key, _) in db.mediaid_file.iter() { let old_path = services().globals.get_media_file(&key); debug!("Old file path: {old_path:?}"); let path = services().globals.get_media_file_new(&key); debug!("New file path: {path:?}"); - // move the file to the new location + changes.push((old_path, path)); + } + // move the file to the new location + for (old_path, path) in changes { if old_path.exists() { tokio::fs::rename(&old_path, &path).await?; } diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index 636ba4dd7..5736ef1d9 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -47,7 +47,7 @@ pub fn get_pushkeys(&self, sender: &UserId) -> Box<dyn Iterator<Item = Result<St #[tracing::instrument(skip(self, dest, request))] pub async fn send_request<T>(&self, dest: &str, request: T) -> Result<T::IncomingResponse> where - T: OutgoingRequest + Debug, + T: OutgoingRequest + Debug + Send, { const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_0]; diff --git a/src/service/rooms/event_handler/signing_keys.rs b/src/service/rooms/event_handler/signing_keys.rs index 91ba0aa4e..d22af9bd9 100644 --- a/src/service/rooms/event_handler/signing_keys.rs +++ b/src/service/rooms/event_handler/signing_keys.rs @@ -28,7 +28,7 @@ pub async fn fetch_required_signing_keys<'a, E>( &'a self, events: E, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, ) -> Result<()> where - E: IntoIterator<Item = &'a BTreeMap<String, CanonicalJsonValue>>, + E: IntoIterator<Item = &'a BTreeMap<String, CanonicalJsonValue>> + Send, { let mut server_key_ids = HashMap::new(); for event in events { diff --git a/src/service/sending/appservice.rs b/src/service/sending/appservice.rs index 4df5b9340..721424e17 100644 --- a/src/service/sending/appservice.rs +++ b/src/service/sending/appservice.rs @@ -12,7 +12,7 @@ /// registration file pub(crate) async fn send_request<T>(registration: Registration, request: T) -> Result<Option<T::IncomingResponse>> where - T: OutgoingRequest + Debug, + T: OutgoingRequest + Debug + Send, { const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_0]; diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index 4ad40f677..ab813489b 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -219,7 +219,7 @@ pub fn flush_servers<I: Iterator<Item = OwnedServerName>>(&self, servers: I) -> #[tracing::instrument(skip(self, request), name = "request")] pub async fn send_federation_request<T>(&self, dest: &ServerName, request: T) -> Result<T::IncomingResponse> where - T: OutgoingRequest + Debug, + T: OutgoingRequest + Debug + Send, { let client = &services().globals.client.federation; send::send(client, dest, request).await @@ -233,7 +233,7 @@ pub async fn send_appservice_request<T>( &self, registration: Registration, request: T, ) -> Result<Option<T::IncomingResponse>> where - T: OutgoingRequest + Debug, + T: OutgoingRequest + Debug + Send, { appservice::send_request(registration, request).await } diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs index e8432d125..f600f196b 100644 --- a/src/service/sending/send.rs +++ b/src/service/sending/send.rs @@ -18,7 +18,7 @@ #[tracing::instrument(skip_all, name = "send")] pub async fn send<T>(client: &Client, dest: &ServerName, req: T) -> 
Result<T::IncomingResponse> where - T: OutgoingRequest + Debug, + T: OutgoingRequest + Debug + Send, { if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); @@ -33,7 +33,7 @@ async fn execute<T>( client: &Client, dest: &ServerName, actual: &ActualDest, request: Request, ) -> Result<T::IncomingResponse> where - T: OutgoingRequest + Debug, + T: OutgoingRequest + Debug + Send, { let method = request.method().clone(); let url = request.url().clone(); @@ -50,7 +50,7 @@ async fn execute<T>( async fn prepare<T>(dest: &ServerName, actual: &ActualDest, req: T) -> Result<Request> where - T: OutgoingRequest + Debug, + T: OutgoingRequest + Debug + Send, { const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_5]; @@ -72,7 +72,7 @@ async fn handle_response<T>( dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url, mut response: Response, ) -> Result<T::IncomingResponse> where - T: OutgoingRequest + Debug, + T: OutgoingRequest + Debug + Send, { trace!("Received response from {} for {} with {}", actual.string, url, response.url()); let status = response.status(); @@ -121,7 +121,7 @@ fn handle_error<T>( _dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url, mut e: reqwest::Error, ) -> Result<T::IncomingResponse> where - T: OutgoingRequest + Debug, + T: OutgoingRequest + Debug + Send, { if e.is_timeout() || e.is_connect() { e = e.without_url(); @@ -144,7 +144,7 @@ fn handle_error<T>( fn sign_request<T>(dest: &ServerName, http_request: &mut http::Request<Vec<u8>>) where - T: OutgoingRequest + Debug, + T: OutgoingRequest + Debug + Send, { let mut req_map = serde_json::Map::new(); if !http_request.body().is_empty() { -- GitLab From c3c91e9d8043690f7babcfc00173440f0360cd27 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Sun, 9 Jun 2024 10:13:49 +0000 Subject: [PATCH 10/13] Fix suboptimal flops Signed-off-by: Jason Volk <jason@zemos.net> --- Cargo.toml | 1 - src/database/sqlite/mod.rs | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4c9ba3f5f..dbf3cf652 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -721,7 +721,6 @@ option_if_let_else = { level = "allow", priority = 1 } # TODO redundant_pub_crate = { level = "allow", priority = 1 } # TODO significant_drop_in_scrutinee = { level = "allow", priority = 1 } # TODO significant_drop_tightening = { level = "allow", priority = 1 } # TODO -suboptimal_flops = { level = "allow", priority = 1 } # TODO use_self = { level = "allow", priority = 1 } # TODO useless_let_if_seq = { level = "allow", priority = 1 } # TODO diff --git a/src/database/sqlite/mod.rs b/src/database/sqlite/mod.rs index 61d78a9a9..745c6aa78 100644 --- a/src/database/sqlite/mod.rs +++ b/src/database/sqlite/mod.rs @@ -109,7 +109,7 @@ fn open(config: &Config) -> Result<Self> { clippy::cast_sign_loss )] let cache_size_per_thread = ((config.db_cache_capacity_mb * 1024.0) - / ((conduit::utils::available_parallelism() as f64 * 2.0) + 1.0)) as u32; + / (conduit::utils::available_parallelism() as f64).mul_add(2.0, 1.0)) as u32; let writer = Mutex::new(Engine::prepare_conn(&path, cache_size_per_thread)?); -- GitLab From eae41fc411188a7edc649506b736ef46ffed1d74 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Sun, 9 Jun 2024 10:23:06 +0000 Subject: [PATCH 11/13] Fix use-self Signed-off-by: Jason Volk <jason@zemos.net> --- Cargo.toml | 1 - src/api/ruma_wrapper/mod.rs | 2 +- src/core/config/mod.rs | 2 +- src/core/config/proxy.rs | 26 +++++++++++++------------- 
src/core/log.rs | 6 +++--- src/core/pducount.rs | 12 ++++++------ src/database/cork.rs | 2 +- src/database/kvdatabase.rs | 2 +- src/database/rocksdb/mod.rs | 4 ++-- src/database/sqlite/mod.rs | 4 ++-- src/main/server.rs | 4 ++-- src/service/appservice/mod.rs | 6 +++--- src/service/globals/client.rs | 4 ++-- src/service/globals/resolver.rs | 2 +- src/service/pdu.rs | 2 +- src/service/rooms/spaces/mod.rs | 6 +++--- src/service/sending/mod.rs | 6 +++--- src/service/sending/resolve.rs | 4 ++-- 18 files changed, 47 insertions(+), 48 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index dbf3cf652..74c8a0c73 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -721,7 +721,6 @@ option_if_let_else = { level = "allow", priority = 1 } # TODO redundant_pub_crate = { level = "allow", priority = 1 } # TODO significant_drop_in_scrutinee = { level = "allow", priority = 1 } # TODO significant_drop_tightening = { level = "allow", priority = 1 } # TODO -use_self = { level = "allow", priority = 1 } # TODO useless_let_if_seq = { level = "allow", priority = 1 } # TODO ################### diff --git a/src/api/ruma_wrapper/mod.rs b/src/api/ruma_wrapper/mod.rs index fb633b912..09c07cef9 100644 --- a/src/api/ruma_wrapper/mod.rs +++ b/src/api/ruma_wrapper/mod.rs @@ -55,7 +55,7 @@ async fn from_request(request: hyper::Request<Body>, _: &S) -> Result<Self, Self let mut request = request::from(request).await?; let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&request.body).ok(); let auth = auth::auth(&mut request, &json_body, &T::METADATA).await?; - Ok(Ruma { + Ok(Self { body: make_body::<T>(&mut request, &mut json_body, &auth)?, origin: auth.origin, sender_user: auth.sender_user, diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index ae9902e11..dadb000b8 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -409,7 +409,7 @@ pub fn new(path: Option<PathBuf>) -> Result<Self, Error> { .merge(Env::prefixed("CONDUWUIT_").global().split("__")) }; - let config = match raw_config.extract::<Config>() { + let config = match raw_config.extract::<Self>() { Err(e) => return Err(Error::BadConfig(format!("{e}"))), Ok(config) => config, }; diff --git a/src/core/config/proxy.rs b/src/core/config/proxy.rs index f41a92f60..a2d64fa61 100644 --- a/src/core/config/proxy.rs +++ b/src/core/config/proxy.rs @@ -42,11 +42,11 @@ pub enum ProxyConfig { impl ProxyConfig { pub fn to_proxy(&self) -> Result<Option<Proxy>> { Ok(match self.clone() { - ProxyConfig::None => None, - ProxyConfig::Global { + Self::None => None, + Self::Global { url, } => Some(Proxy::all(url)?), - ProxyConfig::ByDomain(proxies) => Some(Proxy::custom(move |url| { + Self::ByDomain(proxies) => Some(Proxy::custom(move |url| { proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching // proxy })), @@ -108,18 +108,18 @@ enum WildCardedDomain { impl WildCardedDomain { fn matches(&self, domain: &str) -> bool { match self { - WildCardedDomain::WildCard => true, - WildCardedDomain::WildCarded(d) => domain.ends_with(d), - WildCardedDomain::Exact(d) => domain == d, + Self::WildCard => true, + Self::WildCarded(d) => domain.ends_with(d), + Self::Exact(d) => domain == d, } } fn more_specific_than(&self, other: &Self) -> bool { match (self, other) { - (WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false, - (_, WildCardedDomain::WildCard) => true, - (WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a), - (WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => a != b && 
a.ends_with(b), + (Self::WildCard, Self::WildCard) => false, + (_, Self::WildCard) => true, + (Self::Exact(a), Self::WildCarded(_)) => other.matches(a), + (Self::WildCarded(a), Self::WildCarded(b)) => a != b && a.ends_with(b), _ => false, } } @@ -130,11 +130,11 @@ impl std::str::FromStr for WildCardedDomain { fn from_str(s: &str) -> Result<Self, Self::Err> { // maybe do some domain validation? Ok(if s.starts_with("*.") { - WildCardedDomain::WildCarded(s[1..].to_owned()) + Self::WildCarded(s[1..].to_owned()) } else if s == "*" { - WildCardedDomain::WildCarded(String::new()) + Self::WildCarded(String::new()) } else { - WildCardedDomain::Exact(s.to_owned()) + Self::Exact(s.to_owned()) }) } } diff --git a/src/core/log.rs b/src/core/log.rs index d31ca194e..eae25eb08 100644 --- a/src/core/log.rs +++ b/src/core/log.rs @@ -21,7 +21,7 @@ pub trait ReloadHandle<L> { } impl<L, S> ReloadHandle<L> for reload::Handle<L, S> { - fn reload(&self, new_value: L) -> Result<(), reload::Error> { reload::Handle::reload(self, new_value) } + fn reload(&self, new_value: L) -> Result<(), reload::Error> { Self::reload(self, new_value) } } struct LogLevelReloadHandlesInner { @@ -37,8 +37,8 @@ pub struct LogLevelReloadHandles { impl LogLevelReloadHandles { #[must_use] - pub fn new(handles: Vec<Box<dyn ReloadHandle<EnvFilter> + Send + Sync>>) -> LogLevelReloadHandles { - LogLevelReloadHandles { + pub fn new(handles: Vec<Box<dyn ReloadHandle<EnvFilter> + Send + Sync>>) -> Self { + Self { inner: Arc::new(LogLevelReloadHandlesInner { handles, }), diff --git a/src/core/pducount.rs b/src/core/pducount.rs index 8adb4ca5d..094988b69 100644 --- a/src/core/pducount.rs +++ b/src/core/pducount.rs @@ -29,8 +29,8 @@ pub fn try_from_string(token: &str) -> Result<Self> { #[must_use] pub fn stringify(&self) -> String { match self { - PduCount::Backfilled(x) => format!("-{x}"), - PduCount::Normal(x) => x.to_string(), + Self::Backfilled(x) => format!("-{x}"), + Self::Normal(x) => x.to_string(), } } } @@ -42,10 +42,10 @@ fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) impl Ord for PduCount { fn cmp(&self, other: &Self) -> Ordering { match (self, other) { - (PduCount::Normal(s), PduCount::Normal(o)) => s.cmp(o), - (PduCount::Backfilled(s), PduCount::Backfilled(o)) => o.cmp(s), - (PduCount::Normal(_), PduCount::Backfilled(_)) => Ordering::Greater, - (PduCount::Backfilled(_), PduCount::Normal(_)) => Ordering::Less, + (Self::Normal(s), Self::Normal(o)) => s.cmp(o), + (Self::Backfilled(s), Self::Backfilled(o)) => o.cmp(s), + (Self::Normal(_), Self::Backfilled(_)) => Ordering::Greater, + (Self::Backfilled(_), Self::Normal(_)) => Ordering::Less, } } } diff --git a/src/database/cork.rs b/src/database/cork.rs index 752260a6b..dcabd7072 100644 --- a/src/database/cork.rs +++ b/src/database/cork.rs @@ -11,7 +11,7 @@ pub struct Cork { impl Cork { pub fn new(db: &Arc<dyn KeyValueDatabaseEngine>, flush: bool, sync: bool) -> Self { db.cork().unwrap(); - Cork { + Self { db: db.clone(), flush, sync, diff --git a/src/database/kvdatabase.rs b/src/database/kvdatabase.rs index 906712dc3..ce3bf30f9 100644 --- a/src/database/kvdatabase.rs +++ b/src/database/kvdatabase.rs @@ -157,7 +157,7 @@ pub struct KeyValueDatabase { impl KeyValueDatabase { /// Load an existing database or create a new one. 
#[allow(clippy::too_many_lines)] - pub async fn load_or_create(server: &Arc<Server>) -> Result<KeyValueDatabase> { + pub async fn load_or_create(server: &Arc<Server>) -> Result<Self> { let config = &server.config; check_db_setup(config)?; let builder = build(config)?; diff --git a/src/database/rocksdb/mod.rs b/src/database/rocksdb/mod.rs index f0c3882f0..c9493cab1 100644 --- a/src/database/rocksdb/mod.rs +++ b/src/database/rocksdb/mod.rs @@ -84,7 +84,7 @@ fn open(config: &Config) -> Result<Self> { db.latest_sequence_number(), load_time.elapsed() ); - Ok(Arc::new(Engine { + Ok(Self::new(Engine { config: config.clone(), row_cache, col_cache, @@ -110,7 +110,7 @@ fn open_tree(&self, name: &'static str) -> Result<Arc<dyn KvTree>> { Ok(Arc::new(RocksDbEngineTree { name, - db: Arc::clone(self), + db: Self::clone(self), watchers: Watchers::default(), })) } diff --git a/src/database/sqlite/mod.rs b/src/database/sqlite/mod.rs index 745c6aa78..562d2a10b 100644 --- a/src/database/sqlite/mod.rs +++ b/src/database/sqlite/mod.rs @@ -113,7 +113,7 @@ fn open(config: &Config) -> Result<Self> { let writer = Mutex::new(Engine::prepare_conn(&path, cache_size_per_thread)?); - let arc = Arc::new(Engine { + let arc = Self::new(Engine { writer, read_conn_tls: ThreadLocal::new(), read_iterator_conn_tls: ThreadLocal::new(), @@ -131,7 +131,7 @@ fn open_tree(&self, name: &str) -> Result<Arc<dyn KvTree>> { )?; Ok(Arc::new(SqliteTable { - engine: Arc::clone(self), + engine: Self::clone(self), name: name.to_owned(), watchers: Watchers::default(), })) diff --git a/src/main/server.rs b/src/main/server.rs index cd5b1cdb3..b1bccfe94 100644 --- a/src/main/server.rs +++ b/src/main/server.rs @@ -29,7 +29,7 @@ pub(crate) struct Server { } impl Server { - pub(crate) fn build(args: Args, runtime: Option<&runtime::Handle>) -> Result<Arc<Server>, Error> { + pub(crate) fn build(args: Args, runtime: Option<&runtime::Handle>) -> Result<Arc<Self>, Error> { let config = Config::new(args.config)?; #[cfg(feature = "sentry_telemetry")] @@ -49,7 +49,7 @@ pub(crate) fn build(args: Args, runtime: Option<&runtime::Handle>) -> Result<Arc conduit::version::conduwuit(), ); - Ok(Arc::new(Server { + Ok(Arc::new(Self { server: Arc::new(conduit::Server::new(config, runtime.cloned(), tracing_reload_handle)), _tracing_flame_guard: tracing_flame_guard, diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index 49faa6f61..1f29c4154 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -75,7 +75,7 @@ fn try_from(value: Vec<Namespace>) -> Result<Self, regex::Error> { } } - Ok(NamespaceRegex { + Ok(Self { exclusive: if exclusive.is_empty() { None } else { @@ -102,8 +102,8 @@ pub struct RegistrationInfo { impl TryFrom<Registration> for RegistrationInfo { type Error = regex::Error; - fn try_from(value: Registration) -> Result<RegistrationInfo, regex::Error> { - Ok(RegistrationInfo { + fn try_from(value: Registration) -> Result<Self, regex::Error> { + Ok(Self { users: value.namespaces.users.clone().try_into()?, aliases: value.namespaces.aliases.clone().try_into()?, rooms: value.namespaces.rooms.clone().try_into()?, diff --git a/src/service/globals/client.rs b/src/service/globals/client.rs index 33f6d85f7..5e1b129da 100644 --- a/src/service/globals/client.rs +++ b/src/service/globals/client.rs @@ -15,8 +15,8 @@ pub struct Client { } impl Client { - pub fn new(config: &Config, resolver: &Arc<resolver::Resolver>) -> Client { - Client { + pub fn new(config: &Config, resolver: &Arc<resolver::Resolver>) -> Self { + 
Self { default: Self::base(config) .unwrap() .dns_resolver(resolver.clone()) diff --git a/src/service/globals/resolver.rs b/src/service/globals/resolver.rs index e90201ede..39e4cfb0e 100644 --- a/src/service/globals/resolver.rs +++ b/src/service/globals/resolver.rs @@ -83,7 +83,7 @@ pub fn new(config: &Config) -> Self { let resolver = Arc::new(TokioAsyncResolver::tokio(conf, opts)); let overrides = Arc::new(StdRwLock::new(TlsNameMap::new())); - Resolver { + Self { destinations: Arc::new(RwLock::new(WellKnownMap::new())), overrides: overrides.clone(), resolver: resolver.clone(), diff --git a/src/service/pdu.rs b/src/service/pdu.rs index f608200b2..6274d6350 100644 --- a/src/service/pdu.rs +++ b/src/service/pdu.rs @@ -56,7 +56,7 @@ pub struct PduEvent { impl PduEvent { #[tracing::instrument(skip(self))] - pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &PduEvent) -> crate::Result<()> { + pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &Self) -> crate::Result<()> { self.unsigned = None; let mut content = serde_json::from_str(self.content.get()) diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index 18b9258e9..6fcaccb9c 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -211,7 +211,7 @@ fn push(&mut self, parent: NodeId, mut children: Vec<(OwnedRoomId, Vec<OwnedServ fn new(root: OwnedRoomId, max_depth: usize) -> Self { let zero_depth = max_depth == 0; - Arena { + Self { nodes: vec![Node { parent: None, next_sibling: None, @@ -248,7 +248,7 @@ fn from_str(value: &str) -> Result<Self> { let mut values = value.split('_'); let mut pag_tok = || { - Some(PagnationToken { + Some(Self { skip: UInt::from_str(values.next()?).ok()?, limit: UInt::from_str(values.next()?).ok()?, max_depth: UInt::from_str(values.next()?).ok()?, @@ -316,7 +316,7 @@ fn from(value: CachedSpaceHierarchySummary) -> Self { .. 
} = value.summary; - SpaceHierarchyRoomsChunk { + Self { canonical_alias, name, num_joined_members, diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index ab813489b..225d37c7e 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -259,19 +259,19 @@ impl Destination { #[tracing::instrument(skip(self))] pub fn get_prefix(&self) -> Vec<u8> { let mut prefix = match self { - Destination::Appservice(server) => { + Self::Appservice(server) => { let mut p = b"+".to_vec(); p.extend_from_slice(server.as_bytes()); p }, - Destination::Push(user, pushkey) => { + Self::Push(user, pushkey) => { let mut p = b"$".to_vec(); p.extend_from_slice(user.as_bytes()); p.push(0xFF); p.extend_from_slice(pushkey.as_bytes()); p }, - Destination::Normal(server) => { + Self::Normal(server) => { let mut p = Vec::new(); p.extend_from_slice(server.as_bytes()); p diff --git a/src/service/sending/resolve.rs b/src/service/sending/resolve.rs index b79f18c88..1d043d8bf 100644 --- a/src/service/sending/resolve.rs +++ b/src/service/sending/resolve.rs @@ -421,8 +421,8 @@ fn port(&self) -> Option<u16> { impl fmt::Display for FedDest { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - FedDest::Named(host, port) => write!(f, "{host}{port}"), - FedDest::Literal(addr) => write!(f, "{addr}"), + Self::Named(host, port) => write!(f, "{host}{port}"), + Self::Literal(addr) => write!(f, "{addr}"), } } } -- GitLab From f52acd9cdf2c217ae45677840ba1e821bbee6cb4 Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Sun, 9 Jun 2024 10:29:53 +0000 Subject: [PATCH 12/13] Fix idiomatic let if Signed-off-by: Jason Volk <jason@zemos.net> --- Cargo.toml | 1 - src/api/client/account.rs | 9 ++++----- src/api/client/sync.rs | 9 ++++----- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 74c8a0c73..ce074a791 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -721,7 +721,6 @@ option_if_let_else = { level = "allow", priority = 1 } # TODO redundant_pub_crate = { level = "allow", priority = 1 } # TODO significant_drop_in_scrutinee = { level = "allow", priority = 1 } # TODO significant_drop_tightening = { level = "allow", priority = 1 } # TODO -useless_let_if_seq = { level = "allow", priority = 1 } # TODO ################### pedantic = "warn" diff --git a/src/api/client/account.rs b/src/api/client/account.rs index a5ff5d829..72830f9d4 100644 --- a/src/api/client/account.rs +++ b/src/api/client/account.rs @@ -173,8 +173,7 @@ pub(crate) async fn register_route(body: Ruma<register::v3::Request>) -> Result< // UIAA let mut uiaainfo; - let skip_auth; - if services().globals.config.registration_token.is_some() { + let skip_auth = if services().globals.config.registration_token.is_some() { // Registration token required uiaainfo = UiaaInfo { flows: vec![AuthFlow { @@ -185,7 +184,7 @@ pub(crate) async fn register_route(body: Ruma<register::v3::Request>) -> Result< session: None, auth_error: None, }; - skip_auth = body.appservice_info.is_some(); + body.appservice_info.is_some() } else { // No registration token necessary, but clients must still go through the flow uiaainfo = UiaaInfo { @@ -197,8 +196,8 @@ pub(crate) async fn register_route(body: Ruma<register::v3::Request>) -> Result< session: None, auth_error: None, }; - skip_auth = body.appservice_info.is_some() || is_guest; - } + body.appservice_info.is_some() || is_guest + }; if !skip_auth { if let Some(auth) = &body.auth { diff --git a/src/api/client/sync.rs b/src/api/client/sync.rs index 
8a83cbe76..b441eb6c5 100644 --- a/src/api/client/sync.rs +++ b/src/api/client/sync.rs @@ -1042,8 +1042,7 @@ fn load_timeline( sender_user: &UserId, room_id: &RoomId, roomsincecount: PduCount, limit: u64, ) -> Result<(Vec<(PduCount, PduEvent)>, bool), Error> { let timeline_pdus; - let limited; - if services() + let limited = if services() .rooms .timeline .last_timeline_count(sender_user, room_id)? @@ -1073,11 +1072,11 @@ fn load_timeline( // They /sync response doesn't always return all messages, so we say the output // is limited unless there are events in non_timeline_pdus - limited = non_timeline_pdus.next().is_some(); + non_timeline_pdus.next().is_some() } else { timeline_pdus = Vec::new(); - limited = false; - } + false + }; Ok((timeline_pdus, limited)) } -- GitLab From f0557e3303724f862f38a8a2bad882caadce614e Mon Sep 17 00:00:00 2001 From: Jason Volk <jason@zemos.net> Date: Sun, 9 Jun 2024 20:50:40 +0000 Subject: [PATCH 13/13] split migrations function Signed-off-by: Jason Volk <jason@zemos.net> --- src/service/globals/migrations.rs | 1219 +++++++++++++++-------------- 1 file changed, 641 insertions(+), 578 deletions(-) diff --git a/src/service/globals/migrations.rs b/src/service/globals/migrations.rs index 0d0978a77..d8936de7b 100644 --- a/src/service/globals/migrations.rs +++ b/src/service/globals/migrations.rs @@ -16,7 +16,7 @@ }; use tracing::{debug, error, info, warn}; -use crate::{globals::data::Data, services, utils, Config, Error, Result}; +use crate::{services, utils, Config, Error, Result}; pub(crate) async fn migrations(db: &KeyValueDatabase, config: &Config) -> Result<()> { // Matrix resource ownership is based on the server name; changing it @@ -44,513 +44,62 @@ pub(crate) async fn migrations(db: &KeyValueDatabase, config: &Config) -> Result if services().users.count()? > 0 { // MIGRATIONS if services().globals.database_version()? < 1 { - for (roomserverid, _) in db.roomserverids.iter() { - let mut parts = roomserverid.split(|&b| b == 0xFF); - let room_id = parts.next().expect("split always returns one element"); - let Some(servername) = parts.next() else { - error!("Migration: Invalid roomserverid in db."); - continue; - }; - let mut serverroomid = servername.to_vec(); - serverroomid.push(0xFF); - serverroomid.extend_from_slice(room_id); - - db.serverroomids.insert(&serverroomid, &[])?; - } - - services().globals.bump_database_version(1)?; - - warn!("Migration: 0 -> 1 finished"); + db_lt_1(db, config).await?; } if services().globals.database_version()? < 2 { - // We accidentally inserted hashed versions of "" into the db instead of just "" - for (userid, password) in db.userid_password.iter() { - let empty_pass = utils::hash::password("").expect("our own password to be properly hashed"); - let password = std::str::from_utf8(&password).expect("password is valid utf-8"); - let empty_hashed_password = utils::hash::verify_password(password, &empty_pass).is_ok(); - if empty_hashed_password { - db.userid_password.insert(&userid, b"")?; - } - } - - services().globals.bump_database_version(2)?; - - warn!("Migration: 1 -> 2 finished"); + db_lt_2(db, config).await?; } if services().globals.database_version()? 
< 3 { - // Move media to filesystem - for (key, content) in db.mediaid_file.iter() { - if content.is_empty() { - continue; - } - - #[allow(deprecated)] - let path = services().globals.get_media_file(&key); - let mut file = fs::File::create(path)?; - file.write_all(&content)?; - db.mediaid_file.insert(&key, &[])?; - } - - services().globals.bump_database_version(3)?; - - warn!("Migration: 2 -> 3 finished"); + db_lt_3(db, config).await?; } if services().globals.database_version()? < 4 { - // Add federated users to services() as deactivated - for our_user in services().users.iter() { - let our_user = our_user?; - if services().users.is_deactivated(&our_user)? { - continue; - } - for room in services().rooms.state_cache.rooms_joined(&our_user) { - for user in services().rooms.state_cache.room_members(&room?) { - let user = user?; - if user.server_name() != config.server_name { - info!(?user, "Migration: creating user"); - services().users.create(&user, None)?; - } - } - } - } - - services().globals.bump_database_version(4)?; - - warn!("Migration: 3 -> 4 finished"); + db_lt_4(db, config).await?; } if services().globals.database_version()? < 5 { - // Upgrade user data store - for (roomuserdataid, _) in db.roomuserdataid_accountdata.iter() { - let mut parts = roomuserdataid.split(|&b| b == 0xFF); - let room_id = parts.next().unwrap(); - let user_id = parts.next().unwrap(); - let event_type = roomuserdataid.rsplit(|&b| b == 0xFF).next().unwrap(); - - let mut key = room_id.to_vec(); - key.push(0xFF); - key.extend_from_slice(user_id); - key.push(0xFF); - key.extend_from_slice(event_type); - - db.roomusertype_roomuserdataid - .insert(&key, &roomuserdataid)?; - } - - services().globals.bump_database_version(5)?; - - warn!("Migration: 4 -> 5 finished"); + db_lt_5(db, config).await?; } if services().globals.database_version()? < 6 { - // Set room member count - for (roomid, _) in db.roomid_shortstatehash.iter() { - let string = utils::string_from_bytes(&roomid).unwrap(); - let room_id = <&RoomId>::try_from(string.as_str()).unwrap(); - services().rooms.state_cache.update_joined_count(room_id)?; - } - - services().globals.bump_database_version(6)?; - - warn!("Migration: 5 -> 6 finished"); + db_lt_6(db, config).await?; } if services().globals.database_version()? 
< 7 { - // Upgrade state store - let mut last_roomstates: HashMap<OwnedRoomId, u64> = HashMap::new(); - let mut current_sstatehash: Option<u64> = None; - let mut current_room = None; - let mut current_state = HashSet::new(); - - let handle_state = |current_sstatehash: u64, - current_room: &RoomId, - current_state: HashSet<_>, - last_roomstates: &mut HashMap<_, _>| { - let last_roomsstatehash = last_roomstates.get(current_room); - - let states_parents = last_roomsstatehash.map_or_else( - || Ok(Vec::new()), - |&last_roomsstatehash| { - services() - .rooms - .state_compressor - .load_shortstatehash_info(last_roomsstatehash) - }, - )?; - - let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { - let statediffnew = current_state - .difference(&parent_stateinfo.1) - .copied() - .collect::<HashSet<_>>(); - - let statediffremoved = parent_stateinfo - .1 - .difference(&current_state) - .copied() - .collect::<HashSet<_>>(); - - (statediffnew, statediffremoved) - } else { - (current_state, HashSet::new()) - }; - - services().rooms.state_compressor.save_state_from_diff( - current_sstatehash, - Arc::new(statediffnew), - Arc::new(statediffremoved), - 2, // every state change is 2 event changes on average - states_parents, - )?; - - /* - let mut tmp = services().rooms.load_shortstatehash_info(&current_sstatehash)?; - let state = tmp.pop().unwrap(); - println!( - "{}\t{}{:?}: {:?} + {:?} - {:?}", - current_room, - " ".repeat(tmp.len()), - utils::u64_from_bytes(&current_sstatehash).unwrap(), - tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()), - state - .2 - .iter() - .map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap()) - .collect::<Vec<_>>(), - state - .3 - .iter() - .map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap()) - .collect::<Vec<_>>() - ); - */ - - Ok::<_, Error>(()) - }; - - for (k, seventid) in db.db.open_tree("stateid_shorteventid")?.iter() { - let sstatehash = utils::u64_from_bytes(&k[0..size_of::<u64>()]).expect("number of bytes is correct"); - let sstatekey = k[size_of::<u64>()..].to_vec(); - if Some(sstatehash) != current_sstatehash { - if let Some(current_sstatehash) = current_sstatehash { - handle_state( - current_sstatehash, - current_room.as_deref().unwrap(), - current_state, - &mut last_roomstates, - )?; - last_roomstates.insert(current_room.clone().unwrap(), current_sstatehash); - } - current_state = HashSet::new(); - current_sstatehash = Some(sstatehash); - - let event_id = db.shorteventid_eventid.get(&seventid).unwrap().unwrap(); - let string = utils::string_from_bytes(&event_id).unwrap(); - let event_id = <&EventId>::try_from(string.as_str()).unwrap(); - let pdu = services() - .rooms - .timeline - .get_pdu(event_id) - .unwrap() - .unwrap(); - - if Some(&pdu.room_id) != current_room.as_ref() { - current_room = Some(pdu.room_id.clone()); - } - } - - let mut val = sstatekey; - val.extend_from_slice(&seventid); - current_state.insert(val.try_into().expect("size is correct")); - } - - if let Some(current_sstatehash) = current_sstatehash { - handle_state( - current_sstatehash, - current_room.as_deref().unwrap(), - current_state, - &mut last_roomstates, - )?; - } - - services().globals.bump_database_version(7)?; - - warn!("Migration: 6 -> 7 finished"); + db_lt_7(db, config).await?; } if services().globals.database_version()?
< 8 { - // Generate short room ids for all rooms - for (room_id, _) in db.roomid_shortstatehash.iter() { - let shortroomid = services().globals.next_count()?.to_be_bytes(); - db.roomid_shortroomid.insert(&room_id, &shortroomid)?; - info!("Migration: 8"); - } - // Update pduids db layout - let mut batch = db.pduid_pdu.iter().filter_map(|(key, v)| { - if !key.starts_with(b"!") { - return None; - } - let mut parts = key.splitn(2, |&b| b == 0xFF); - let room_id = parts.next().unwrap(); - let count = parts.next().unwrap(); - - let short_room_id = db - .roomid_shortroomid - .get(room_id) - .unwrap() - .expect("shortroomid should exist"); - - let mut new_key = short_room_id; - new_key.extend_from_slice(count); - - Some((new_key, v)) - }); - - db.pduid_pdu.insert_batch(&mut batch)?; - - let mut batch2 = db.eventid_pduid.iter().filter_map(|(k, value)| { - if !value.starts_with(b"!") { - return None; - } - let mut parts = value.splitn(2, |&b| b == 0xFF); - let room_id = parts.next().unwrap(); - let count = parts.next().unwrap(); - - let short_room_id = db - .roomid_shortroomid - .get(room_id) - .unwrap() - .expect("shortroomid should exist"); - - let mut new_value = short_room_id; - new_value.extend_from_slice(count); - - Some((k, new_value)) - }); - - db.eventid_pduid.insert_batch(&mut batch2)?; - - services().globals.bump_database_version(8)?; - - warn!("Migration: 7 -> 8 finished"); + db_lt_8(db, config).await?; } if services().globals.database_version()? < 9 { - // Update tokenids db layout - let mut iter = db - .tokenids - .iter() - .filter_map(|(key, _)| { - if !key.starts_with(b"!") { - return None; - } - let mut parts = key.splitn(4, |&b| b == 0xFF); - let room_id = parts.next().unwrap(); - let word = parts.next().unwrap(); - let _pdu_id_room = parts.next().unwrap(); - let pdu_id_count = parts.next().unwrap(); - - let short_room_id = db - .roomid_shortroomid - .get(room_id) - .unwrap() - .expect("shortroomid should exist"); - let mut new_key = short_room_id; - new_key.extend_from_slice(word); - new_key.push(0xFF); - new_key.extend_from_slice(pdu_id_count); - Some((new_key, Vec::new())) - }) - .peekable(); - - while iter.peek().is_some() { - db.tokenids.insert_batch(&mut iter.by_ref().take(1000))?; - debug!("Inserted smaller batch"); - } - - info!("Deleting starts"); - - let batch2: Vec<_> = db - .tokenids - .iter() - .filter_map(|(key, _)| { - if key.starts_with(b"!") { - Some(key) - } else { - None - } - }) - .collect(); - - for key in batch2 { - db.tokenids.remove(&key)?; - } - - services().globals.bump_database_version(9)?; - - warn!("Migration: 8 -> 9 finished"); + db_lt_9(db, config).await?; } if services().globals.database_version()? < 10 { - // Add other direction for shortstatekeys - for (statekey, shortstatekey) in db.statekey_shortstatekey.iter() { - db.shortstatekey_statekey - .insert(&shortstatekey, &statekey)?; - } - - // Force E2EE device list updates so we can send them over federation - for user_id in services().users.iter().filter_map(Result::ok) { - services().users.mark_device_key_update(&user_id)?; - } - - services().globals.bump_database_version(10)?; - - warn!("Migration: 9 -> 10 finished"); + db_lt_10(db, config).await?; } if services().globals.database_version()? < 11 { - db.db - .open_tree("userdevicesessionid_uiaarequest")? - .clear()?; - services().globals.bump_database_version(11)?; - - warn!("Migration: 10 -> 11 finished"); + db_lt_11(db, config).await?; } if services().globals.database_version()? < 12 { - for username in services().users.list_local_users()? 
-        for username in services().users.list_local_users()? {
-            let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) {
-                Ok(u) => u,
-                Err(e) => {
-                    warn!("Invalid username {username}: {e}");
-                    continue;
-                },
-            };
-
-            let raw_rules_list = services()
-                .account_data
-                .get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into())
-                .unwrap()
-                .expect("Username is invalid");
-
-            let mut account_data = serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap();
-            let rules_list = &mut account_data.content.global;
-
-            //content rule
-            {
-                let content_rule_transformation = [".m.rules.contains_user_name", ".m.rule.contains_user_name"];
-
-                let rule = rules_list.content.get(content_rule_transformation[0]);
-                if rule.is_some() {
-                    let mut rule = rule.unwrap().clone();
-                    content_rule_transformation[1].clone_into(&mut rule.rule_id);
-                    rules_list
-                        .content
-                        .shift_remove(content_rule_transformation[0]);
-                    rules_list.content.insert(rule);
-                }
-            }
-
-            //underride rules
-            {
-                let underride_rule_transformation = [
-                    [".m.rules.call", ".m.rule.call"],
-                    [".m.rules.room_one_to_one", ".m.rule.room_one_to_one"],
-                    [".m.rules.encrypted_room_one_to_one", ".m.rule.encrypted_room_one_to_one"],
-                    [".m.rules.message", ".m.rule.message"],
-                    [".m.rules.encrypted", ".m.rule.encrypted"],
-                ];
-
-                for transformation in underride_rule_transformation {
-                    let rule = rules_list.underride.get(transformation[0]);
-                    if let Some(rule) = rule {
-                        let mut rule = rule.clone();
-                        transformation[1].clone_into(&mut rule.rule_id);
-                        rules_list.underride.shift_remove(transformation[0]);
-                        rules_list.underride.insert(rule);
-                    }
-                }
-            }
-
-            services().account_data.update(
-                None,
-                &user,
-                GlobalAccountDataEventType::PushRules.to_string().into(),
-                &serde_json::to_value(account_data).expect("to json value always works"),
-            )?;
-        }
-
-        services().globals.bump_database_version(12)?;
-
-        warn!("Migration: 11 -> 12 finished");
+        db_lt_12(db, config).await?;
    }

    // This migration can be reused as-is anytime the server-default rules are
    // updated.
    if services().globals.database_version()? < 13 {
-        for username in services().users.list_local_users()? {
-            let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) {
-                Ok(u) => u,
-                Err(e) => {
-                    warn!("Invalid username {username}: {e}");
-                    continue;
-                },
-            };
-
-            let raw_rules_list = services()
-                .account_data
-                .get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into())
-                .unwrap()
-                .expect("Username is invalid");
-
-            let mut account_data = serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap();
-
-            let user_default_rules = Ruleset::server_default(&user);
-            account_data
-                .content
-                .global
-                .update_with_server_default(user_default_rules);
-
-            services().account_data.update(
-                None,
-                &user,
-                GlobalAccountDataEventType::PushRules.to_string().into(),
-                &serde_json::to_value(account_data).expect("to json value always works"),
-            )?;
-        }
-
-        services().globals.bump_database_version(13)?;
-
-        warn!("Migration: 12 -> 13 finished");
+        db_lt_13(db, config).await?;
    }

    #[cfg(feature = "sha256_media")]
-    {
-        use std::path::PathBuf;
-        if services().globals.database_version()? < 14 && cfg!(feature = "sha256_media") {
-            warn!("sha256_media feature flag is enabled, migrating legacy base64 file names to sha256 file names");
-            // Move old media files to new names
-            let mut changes = Vec::<(PathBuf, PathBuf)>::new();
-            for (key, _) in db.mediaid_file.iter() {
-                let old_path = services().globals.get_media_file(&key);
-                debug!("Old file path: {old_path:?}");
-                let path = services().globals.get_media_file_new(&key);
-                debug!("New file path: {path:?}");
-                changes.push((old_path, path));
-            }
-            // move the file to the new location
-            for (old_path, path) in changes {
-                if old_path.exists() {
-                    tokio::fs::rename(&old_path, &path).await?;
-                }
-            }
-
-            services().globals.bump_database_version(14)?;
-
-            warn!("Migration: 13 -> 14 finished");
-        }
+    if services().globals.database_version()? < 14 && cfg!(feature = "sha256_media") {
+        feat_sha256_media(db, config).await?;
    }

    if db
@@ -558,37 +107,7 @@ pub(crate) async fn migrations(db: &KeyValueDatabase, config: &Config) -> Result
        .get(b"fix_bad_double_separator_in_state_cache")?
        .is_none()
    {
-        warn!("Fixing bad double separator in state_cache roomuserid_joined");
-        let mut iter_count: usize = 0;
-
-        let _cork = db.cork();
-
-        for (mut key, value) in db.roomuserid_joined.iter() {
-            iter_count = iter_count.saturating_add(1);
-            debug_info!(%iter_count);
-            let first_sep_index = key.iter().position(|&i| i == 0xFF).unwrap();
-
-            if key
-                .iter()
-                .get(first_sep_index..=first_sep_index + 1)
-                .copied()
-                .collect_vec() == vec![0xFF, 0xFF]
-            {
-                debug_warn!("Found bad key: {key:?}");
-                db.roomuserid_joined.remove(&key)?;
-
-                key.remove(first_sep_index);
-                debug_warn!("Fixed key: {key:?}");
-                db.roomuserid_joined.insert(&key, &value)?;
-            }
-        }
-
-        db.cleanup()?;
-
-        warn!("Finished fixing");
-
-        db.global
-            .insert(b"fix_bad_double_separator_in_state_cache", &[])?;
+        fix_bad_double_separator_in_state_cache(db, config).await?;
    }

    if db
@@ -596,86 +115,7 @@ pub(crate) async fn migrations(db: &KeyValueDatabase, config: &Config) -> Result
        .get(b"retroactively_fix_bad_data_from_roomuserid_joined")?
        .is_none()
    {
-        warn!("Retroactively fixing bad data from broken roomuserid_joined");
-
-        let room_ids = services()
-            .rooms
-            .metadata
-            .iter_ids()
-            .filter_map(Result::ok)
-            .collect_vec();
-
-        let _cork = db.cork();
-
-        for room_id in room_ids.clone() {
-            debug_info!("Fixing room {room_id}");
-
-            let users_in_room = services()
-                .rooms
-                .state_cache
-                .room_members(&room_id)
-                .filter_map(Result::ok)
-                .collect_vec();
-
-            let joined_members = users_in_room
-                .iter()
-                .filter(|user_id| {
-                    services()
-                        .rooms
-                        .state_accessor
-                        .get_member(&room_id, user_id)
-                        .unwrap_or(None)
-                        .map_or(false, |membership| membership.membership == MembershipState::Join)
-                })
-                .collect_vec();
-
-            let non_joined_members = users_in_room
-                .iter()
-                .filter(|user_id| {
-                    services()
-                        .rooms
-                        .state_accessor
-                        .get_member(&room_id, user_id)
-                        .unwrap_or(None)
-                        .map_or(false, |membership| {
-                            membership.membership == MembershipState::Leave
-                                || membership.membership == MembershipState::Ban
-                        })
-                })
-                .collect_vec();
-
-            for user_id in joined_members {
-                debug_info!("User is joined, marking as joined");
-                services()
-                    .rooms
-                    .state_cache
-                    .mark_as_joined(user_id, &room_id)?;
-            }
-
-            for user_id in non_joined_members {
-                debug_info!("User is left or banned, marking as left");
-                services()
-                    .rooms
-                    .state_cache
-                    .mark_as_left(user_id, &room_id)?;
-            }
-        }
-
-        for room_id in room_ids {
-            debug_info!(
-                "Updating joined count for room {room_id} to fix servers in room after correcting membership \
-                 states"
-            );
-
-            services().rooms.state_cache.update_joined_count(&room_id)?;
-        }
-
-        db.cleanup()?;
-
-        warn!("Finished fixing");
-
-        db.global
-            .insert(b"retroactively_fix_bad_data_from_roomuserid_joined", &[])?;
+        retroactively_fix_bad_data_from_roomuserid_joined(db, config).await?;
    }

    assert_eq!(
@@ -761,3 +201,626 @@ pub(crate) async fn migrations(db: &KeyValueDatabase, config: &Config) -> Result
    Ok(())
}
+
+async fn db_lt_1(db: &KeyValueDatabase, _config: &Config) -> Result<()> {
+    for (roomserverid, _) in db.roomserverids.iter() {
+        let mut parts = roomserverid.split(|&b| b == 0xFF);
+        let room_id = parts.next().expect("split always returns one element");
+        let Some(servername) = parts.next() else {
+            error!("Migration: Invalid roomserverid in db.");
+            continue;
+        };
+        let mut serverroomid = servername.to_vec();
+        serverroomid.push(0xFF);
+        serverroomid.extend_from_slice(room_id);
+
+        db.serverroomids.insert(&serverroomid, &[])?;
+    }
+
+    services().globals.bump_database_version(1)?;
+    info!("Migration: 0 -> 1 finished");
+    Ok(())
+}
+
+async fn db_lt_2(db: &KeyValueDatabase, _config: &Config) -> Result<()> {
+    // We accidentally inserted hashed versions of "" into the db instead of just ""
+    for (userid, password) in db.userid_password.iter() {
+        let empty_pass = utils::hash::password("").expect("our own password to be properly hashed");
+        let password = std::str::from_utf8(&password).expect("password is valid utf-8");
+        let empty_hashed_password = utils::hash::verify_password(password, &empty_pass).is_ok();
+        if empty_hashed_password {
+            db.userid_password.insert(&userid, b"")?;
+        }
+    }
+
+    services().globals.bump_database_version(2)?;
+    info!("Migration: 1 -> 2 finished");
+    Ok(())
+}
+
+async fn db_lt_3(db: &KeyValueDatabase, _config: &Config) -> Result<()> {
+    // Move media to filesystem
+    for (key, content) in db.mediaid_file.iter() {
+        if content.is_empty() {
+            continue;
+        }
+
+        #[allow(deprecated)]
+        let path = services().globals.get_media_file(&key);
+        let mut file = fs::File::create(path)?;
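+        // Spill the blob out to its own file, then blank the tree value so only the key remains.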
+        file.write_all(&content)?;
+        db.mediaid_file.insert(&key, &[])?;
+    }
+
+    services().globals.bump_database_version(3)?;
+    info!("Migration: 2 -> 3 finished");
+    Ok(())
+}
+
+async fn db_lt_4(_db: &KeyValueDatabase, config: &Config) -> Result<()> {
+    // Add federated users to services() as deactivated
+    for our_user in services().users.iter() {
+        let our_user = our_user?;
+        if services().users.is_deactivated(&our_user)? {
+            continue;
+        }
+        for room in services().rooms.state_cache.rooms_joined(&our_user) {
+            for user in services().rooms.state_cache.room_members(&room?) {
+                let user = user?;
+                if user.server_name() != config.server_name {
+                    info!(?user, "Migration: creating user");
+                    services().users.create(&user, None)?;
+                }
+            }
+        }
+    }
+
+    services().globals.bump_database_version(4)?;
+    info!("Migration: 3 -> 4 finished");
+    Ok(())
+}
+
+async fn db_lt_5(db: &KeyValueDatabase, _config: &Config) -> Result<()> {
+    // Upgrade user data store
+    for (roomuserdataid, _) in db.roomuserdataid_accountdata.iter() {
+        let mut parts = roomuserdataid.split(|&b| b == 0xFF);
+        let room_id = parts.next().unwrap();
+        let user_id = parts.next().unwrap();
+        let event_type = roomuserdataid.rsplit(|&b| b == 0xFF).next().unwrap();
+
+        let mut key = room_id.to_vec();
+        key.push(0xFF);
+        key.extend_from_slice(user_id);
+        key.push(0xFF);
+        key.extend_from_slice(event_type);
+
+        db.roomusertype_roomuserdataid
+            .insert(&key, &roomuserdataid)?;
+    }
+
+    services().globals.bump_database_version(5)?;
+    info!("Migration: 4 -> 5 finished");
+    Ok(())
+}
+
+async fn db_lt_6(db: &KeyValueDatabase, _config: &Config) -> Result<()> {
+    // Set room member count
+    for (roomid, _) in db.roomid_shortstatehash.iter() {
+        let string = utils::string_from_bytes(&roomid).unwrap();
+        let room_id = <&RoomId>::try_from(string.as_str()).unwrap();
+        services().rooms.state_cache.update_joined_count(room_id)?;
+    }
+
+    services().globals.bump_database_version(6)?;
+    info!("Migration: 5 -> 6 finished");
+    Ok(())
+}
+
+async fn db_lt_7(db: &KeyValueDatabase, _config: &Config) -> Result<()> {
+    // Upgrade state store
+    let mut last_roomstates: HashMap<OwnedRoomId, u64> = HashMap::new();
+    let mut current_sstatehash: Option<u64> = None;
+    let mut current_room = None;
+    let mut current_state = HashSet::new();
+
+    let handle_state = |current_sstatehash: u64,
+                        current_room: &RoomId,
+                        current_state: HashSet<_>,
+                        last_roomstates: &mut HashMap<_, _>| {
+        let last_roomsstatehash = last_roomstates.get(current_room);
+
+        let states_parents = last_roomsstatehash.map_or_else(
+            || Ok(Vec::new()),
+            |&last_roomsstatehash| {
+                services()
+                    .rooms
+                    .state_compressor
+                    .load_shortstatehash_info(last_roomsstatehash)
+            },
+        )?;
+
+        let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() {
+            let statediffnew = current_state
+                .difference(&parent_stateinfo.1)
+                .copied()
+                .collect::<HashSet<_>>();
+
+            let statediffremoved = parent_stateinfo
+                .1
+                .difference(&current_state)
+                .copied()
+                .collect::<HashSet<_>>();
+
+            (statediffnew, statediffremoved)
+        } else {
+            (current_state, HashSet::new())
+        };
+
+        services().rooms.state_compressor.save_state_from_diff(
+            current_sstatehash,
+            Arc::new(statediffnew),
+            Arc::new(statediffremoved),
+            2, // every state change is 2 event changes on average
+            states_parents,
+        )?;
+
+        /*
+        let mut tmp = services().rooms.load_shortstatehash_info(&current_sstatehash)?;
+        let state = tmp.pop().unwrap();
+        println!(
+            "{}\t{}{:?}: {:?} + {:?} - {:?}",
+            current_room,
+            " ".repeat(tmp.len()),
+            utils::u64_from_bytes(&current_sstatehash).unwrap(),
+            tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()),
+            state
+                .2
+                .iter()
+                .map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap())
+                .collect::<Vec<_>>(),
+            state
+                .3
+                .iter()
+                .map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap())
+                .collect::<Vec<_>>()
+        );
+        */
+
+        Ok::<_, Error>(())
+    };
+
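+    // Walk stateid_shorteventid in key order; each time the shortstatehash
+    // changes, flush the accumulated state for the previous room as a diff.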
+    for (k, seventid) in db.db.open_tree("stateid_shorteventid")?.iter() {
+        let sstatehash = utils::u64_from_bytes(&k[0..size_of::<u64>()]).expect("number of bytes is correct");
+        let sstatekey = k[size_of::<u64>()..].to_vec();
+        if Some(sstatehash) != current_sstatehash {
+            if let Some(current_sstatehash) = current_sstatehash {
+                handle_state(
+                    current_sstatehash,
+                    current_room.as_deref().unwrap(),
+                    current_state,
+                    &mut last_roomstates,
+                )?;
+                last_roomstates.insert(current_room.clone().unwrap(), current_sstatehash);
+            }
+            current_state = HashSet::new();
+            current_sstatehash = Some(sstatehash);
+
+            let event_id = db.shorteventid_eventid.get(&seventid).unwrap().unwrap();
+            let string = utils::string_from_bytes(&event_id).unwrap();
+            let event_id = <&EventId>::try_from(string.as_str()).unwrap();
+            let pdu = services()
+                .rooms
+                .timeline
+                .get_pdu(event_id)
+                .unwrap()
+                .unwrap();
+
+            if Some(&pdu.room_id) != current_room.as_ref() {
+                current_room = Some(pdu.room_id.clone());
+            }
+        }
+
+        let mut val = sstatekey;
+        val.extend_from_slice(&seventid);
+        current_state.insert(val.try_into().expect("size is correct"));
+    }
+
+    if let Some(current_sstatehash) = current_sstatehash {
+        handle_state(
+            current_sstatehash,
+            current_room.as_deref().unwrap(),
+            current_state,
+            &mut last_roomstates,
+        )?;
+    }
+
+    services().globals.bump_database_version(7)?;
+    info!("Migration: 6 -> 7 finished");
+    Ok(())
+}
+
+async fn db_lt_8(db: &KeyValueDatabase, _config: &Config) -> Result<()> {
+    // Generate short room ids for all rooms
+    for (room_id, _) in db.roomid_shortstatehash.iter() {
+        let shortroomid = services().globals.next_count()?.to_be_bytes();
+        db.roomid_shortroomid.insert(&room_id, &shortroomid)?;
+        info!("Migration: 8");
+    }
+    // Update pduids db layout
+    let mut batch = db.pduid_pdu.iter().filter_map(|(key, v)| {
+        if !key.starts_with(b"!") {
+            return None;
+        }
+        let mut parts = key.splitn(2, |&b| b == 0xFF);
+        let room_id = parts.next().unwrap();
+        let count = parts.next().unwrap();
+
+        let short_room_id = db
+            .roomid_shortroomid
+            .get(room_id)
+            .unwrap()
+            .expect("shortroomid should exist");
+
+        let mut new_key = short_room_id;
+        new_key.extend_from_slice(count);
+
+        Some((new_key, v))
+    });
+
+    db.pduid_pdu.insert_batch(&mut batch)?;
+
+    let mut batch2 = db.eventid_pduid.iter().filter_map(|(k, value)| {
+        if !value.starts_with(b"!") {
+            return None;
+        }
+        let mut parts = value.splitn(2, |&b| b == 0xFF);
+        let room_id = parts.next().unwrap();
+        let count = parts.next().unwrap();
+
+        let short_room_id = db
+            .roomid_shortroomid
+            .get(room_id)
+            .unwrap()
+            .expect("shortroomid should exist");
+
+        let mut new_value = short_room_id;
+        new_value.extend_from_slice(count);
+
+        Some((k, new_value))
+    });
+
+    db.eventid_pduid.insert_batch(&mut batch2)?;
+
+    services().globals.bump_database_version(8)?;
+    info!("Migration: 7 -> 8 finished");
+    Ok(())
+}
+
+async fn db_lt_9(db: &KeyValueDatabase, _config: &Config) -> Result<()> {
+    // Update tokenids db layout
+    let mut iter = db
+        .tokenids
+        .iter()
+        .filter_map(|(key, _)| {
+            if !key.starts_with(b"!") {
+                return None;
+            }
+            let mut parts = key.splitn(4, |&b| b == 0xFF);
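+            // Old layout: room_id 0xFF word 0xFF pdu_id_room 0xFF pdu_id_count.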
+            let room_id = parts.next().unwrap();
+            let word = parts.next().unwrap();
+            let _pdu_id_room = parts.next().unwrap();
+            let pdu_id_count = parts.next().unwrap();
+
+            let short_room_id = db
+                .roomid_shortroomid
+                .get(room_id)
+                .unwrap()
+                .expect("shortroomid should exist");
+            let mut new_key = short_room_id;
+            new_key.extend_from_slice(word);
+            new_key.push(0xFF);
+            new_key.extend_from_slice(pdu_id_count);
+            Some((new_key, Vec::new()))
+        })
+        .peekable();
+
+    while iter.peek().is_some() {
+        db.tokenids.insert_batch(&mut iter.by_ref().take(1000))?;
+        debug!("Inserted smaller batch");
+    }
+
+    info!("Deleting starts");
+
+    let batch2: Vec<_> = db
+        .tokenids
+        .iter()
+        .filter_map(|(key, _)| {
+            if key.starts_with(b"!") {
+                Some(key)
+            } else {
+                None
+            }
+        })
+        .collect();
+
+    for key in batch2 {
+        db.tokenids.remove(&key)?;
+    }
+
+    services().globals.bump_database_version(9)?;
+    info!("Migration: 8 -> 9 finished");
+    Ok(())
+}
+
+async fn db_lt_10(db: &KeyValueDatabase, _config: &Config) -> Result<()> {
+    // Add other direction for shortstatekeys
+    for (statekey, shortstatekey) in db.statekey_shortstatekey.iter() {
+        db.shortstatekey_statekey
+            .insert(&shortstatekey, &statekey)?;
+    }
+
+    // Force E2EE device list updates so we can send them over federation
+    for user_id in services().users.iter().filter_map(Result::ok) {
+        services().users.mark_device_key_update(&user_id)?;
+    }
+
+    services().globals.bump_database_version(10)?;
+    info!("Migration: 9 -> 10 finished");
+    Ok(())
+}
+
+async fn db_lt_11(db: &KeyValueDatabase, _config: &Config) -> Result<()> {
+    db.db
+        .open_tree("userdevicesessionid_uiaarequest")?
+        .clear()?;
+
+    services().globals.bump_database_version(11)?;
+    info!("Migration: 10 -> 11 finished");
+    Ok(())
+}
+
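+/// Rename legacy ".m.rules.*" push rule IDs to the ".m.rule.*" names used by the spec.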
+async fn db_lt_12(_db: &KeyValueDatabase, config: &Config) -> Result<()> {
+    for username in services().users.list_local_users()? {
+        let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) {
+            Ok(u) => u,
+            Err(e) => {
+                warn!("Invalid username {username}: {e}");
+                continue;
+            },
+        };
+
+        let raw_rules_list = services()
+            .account_data
+            .get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into())
+            .unwrap()
+            .expect("Username is invalid");
+
+        let mut account_data = serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap();
+        let rules_list = &mut account_data.content.global;
+
+        //content rule
+        {
+            let content_rule_transformation = [".m.rules.contains_user_name", ".m.rule.contains_user_name"];
+
+            let rule = rules_list.content.get(content_rule_transformation[0]);
+            if rule.is_some() {
+                let mut rule = rule.unwrap().clone();
+                content_rule_transformation[1].clone_into(&mut rule.rule_id);
+                rules_list
+                    .content
+                    .shift_remove(content_rule_transformation[0]);
+                rules_list.content.insert(rule);
+            }
+        }
+
+        //underride rules
+        {
+            let underride_rule_transformation = [
+                [".m.rules.call", ".m.rule.call"],
+                [".m.rules.room_one_to_one", ".m.rule.room_one_to_one"],
+                [".m.rules.encrypted_room_one_to_one", ".m.rule.encrypted_room_one_to_one"],
+                [".m.rules.message", ".m.rule.message"],
+                [".m.rules.encrypted", ".m.rule.encrypted"],
+            ];
+
+            for transformation in underride_rule_transformation {
+                let rule = rules_list.underride.get(transformation[0]);
+                if let Some(rule) = rule {
+                    let mut rule = rule.clone();
+                    transformation[1].clone_into(&mut rule.rule_id);
+                    rules_list.underride.shift_remove(transformation[0]);
+                    rules_list.underride.insert(rule);
+                }
+            }
+        }
+
+        services().account_data.update(
+            None,
+            &user,
+            GlobalAccountDataEventType::PushRules.to_string().into(),
+            &serde_json::to_value(account_data).expect("to json value always works"),
+        )?;
+    }
+
+    services().globals.bump_database_version(12)?;
+    info!("Migration: 11 -> 12 finished");
+    Ok(())
+}
+
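+/// Re-apply the current server-default push rules for every local user.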
+async fn db_lt_13(_db: &KeyValueDatabase, config: &Config) -> Result<()> {
+    for username in services().users.list_local_users()? {
+        let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) {
+            Ok(u) => u,
+            Err(e) => {
+                warn!("Invalid username {username}: {e}");
+                continue;
+            },
+        };
+
+        let raw_rules_list = services()
+            .account_data
+            .get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into())
+            .unwrap()
+            .expect("Username is invalid");
+
+        let mut account_data = serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap();
+
+        let user_default_rules = Ruleset::server_default(&user);
+        account_data
+            .content
+            .global
+            .update_with_server_default(user_default_rules);
+
+        services().account_data.update(
+            None,
+            &user,
+            GlobalAccountDataEventType::PushRules.to_string().into(),
+            &serde_json::to_value(account_data).expect("to json value always works"),
+        )?;
+    }
+
+    services().globals.bump_database_version(13)?;
+    info!("Migration: 12 -> 13 finished");
+    Ok(())
+}
+
+#[cfg(feature = "sha256_media")]
+async fn feat_sha256_media(db: &KeyValueDatabase, _config: &Config) -> Result<()> {
+    use std::path::PathBuf;
+    warn!("sha256_media feature flag is enabled, migrating legacy base64 file names to sha256 file names");
+    // Move old media files to new names
+    let mut changes = Vec::<(PathBuf, PathBuf)>::new();
+    for (key, _) in db.mediaid_file.iter() {
+        let old_path = services().globals.get_media_file(&key);
+        debug!("Old file path: {old_path:?}");
+        let path = services().globals.get_media_file_new(&key);
+        debug!("New file path: {path:?}");
+        changes.push((old_path, path));
+    }
+    // move the file to the new location
+    for (old_path, path) in changes {
+        if old_path.exists() {
+            tokio::fs::rename(&old_path, &path).await?;
+        }
+    }
+
+    services().globals.bump_database_version(14)?;
+    info!("Migration: 13 -> 14 finished");
+    Ok(())
+}
+
+async fn fix_bad_double_separator_in_state_cache(db: &KeyValueDatabase, _config: &Config) -> Result<()> {
+    warn!("Fixing bad double separator in state_cache roomuserid_joined");
+    let mut iter_count: usize = 0;
+
+    let _cork = db.db.cork();
+
+    for (mut key, value) in db.roomuserid_joined.iter() {
+        iter_count = iter_count.saturating_add(1);
+        debug_info!(%iter_count);
+        let first_sep_index = key.iter().position(|&i| i == 0xFF).unwrap();
+
+        if key
+            .iter()
+            .get(first_sep_index..=first_sep_index + 1)
+            .copied()
+            .collect_vec()
+            == vec![0xFF, 0xFF]
+        {
+            debug_warn!("Found bad key: {key:?}");
+            db.roomuserid_joined.remove(&key)?;
+
+            key.remove(first_sep_index);
+            debug_warn!("Fixed key: {key:?}");
+            db.roomuserid_joined.insert(&key, &value)?;
+        }
+    }
+
+    db.db.cleanup()?;
+    db.global
+        .insert(b"fix_bad_double_separator_in_state_cache", &[])?;
+
+    info!("Finished fixing");
+    Ok(())
+}
+
+async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &KeyValueDatabase, _config: &Config) -> Result<()> {
+    warn!("Retroactively fixing bad data from broken roomuserid_joined");
+
+    let room_ids = services()
+        .rooms
+        .metadata
+        .iter_ids()
+        .filter_map(Result::ok)
+        .collect_vec();
+
+    let _cork = db.db.cork();
+
+    for room_id in room_ids.clone() {
+        debug_info!("Fixing room {room_id}");
+
+        let users_in_room = services()
+            .rooms
+            .state_cache
+            .room_members(&room_id)
+            .filter_map(Result::ok)
+            .collect_vec();
+
+        let joined_members = users_in_room
+            .iter()
+            .filter(|user_id| {
+                services()
+                    .rooms
+                    .state_accessor
+                    .get_member(&room_id, user_id)
+                    .unwrap_or(None)
+                    .map_or(false, |membership| membership.membership == MembershipState::Join)
+            })
+            .collect_vec();
+
+        let non_joined_members = users_in_room
+            .iter()
+ .filter(|user_id| { + services() + .rooms + .state_accessor + .get_member(&room_id, user_id) + .unwrap_or(None) + .map_or(false, |membership| { + membership.membership == MembershipState::Leave || membership.membership == MembershipState::Ban + }) + }) + .collect_vec(); + + for user_id in joined_members { + debug_info!("User is joined, marking as joined"); + services() + .rooms + .state_cache + .mark_as_joined(user_id, &room_id)?; + } + + for user_id in non_joined_members { + debug_info!("User is left or banned, marking as left"); + services() + .rooms + .state_cache + .mark_as_left(user_id, &room_id)?; + } + } + + for room_id in room_ids { + debug_info!( + "Updating joined count for room {room_id} to fix servers in room after correcting membership states" + ); + + services().rooms.state_cache.update_joined_count(&room_id)?; + } + + db.db.cleanup()?; + db.global + .insert(b"retroactively_fix_bad_data_from_roomuserid_joined", &[])?; + + info!("Finished fixing"); + Ok(()) +} -- GitLab
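
A note on why this refactor helps: rustc reserves a slot in the enclosing function's stack frame for sizable locals, and locals from sibling blocks are not reliably overlapped, so a monolithic migrations() that inlines every migration body can carry all of their buffers at once. Moving each body behind a function call bounds the live frame to whichever single helper is currently running. A minimal sketch of the effect, with hypothetical names and sizes (illustrative code, not from this series):

    use std::hint::black_box;

    #[inline(never)]
    fn step_a() {
        let buf = [0u8; 64 * 1024]; // this frame exists only during the call
        black_box(&buf);
    }

    #[inline(never)]
    fn step_b() {
        let buf = [0u8; 64 * 1024];
        black_box(&buf);
    }

    #[inline(never)]
    fn monolithic() {
        // Both buffers are locals of one frame: ~128 KiB live at once.
        let a = [0u8; 64 * 1024];
        let b = [0u8; 64 * 1024];
        black_box((&a, &b));
    }

    fn main() {
        monolithic(); // one oversized frame
        step_a();     // peak ~64 KiB, released on return
        step_b();     // peak ~64 KiB, released on return
    }

This is the quantity clippy's large_stack_frames lint (the stack-size-threshold knob in clippy.toml) measures, which is why the threshold can be ratcheted down as each oversized function is split.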