diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index f59da7b7de5f21970621930628cb6e09a3728476..2d04be2f2262f638420ed044d5a8f45276de4f41 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -152,21 +152,23 @@ pub(crate) async fn join_room_by_id_route( let mut servers = services() .rooms .state_cache - .servers_invite_via(&body.room_id)? - .unwrap_or( - services() - .rooms - .state_cache - .invite_state(sender_user, &body.room_id)? - .unwrap_or_default() - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()) - .collect::<Vec<_>>(), - ); + .servers_invite_via(&body.room_id) + .filter_map(Result::ok) + .collect::<Vec<_>>(); + + servers.extend( + services() + .rooms + .state_cache + .invite_state(sender_user, &body.room_id)? + .unwrap_or_default() + .iter() + .filter_map(|event| serde_json::from_str(event.json().get()).ok()) + .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) + .filter_map(|sender| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()), + ); if let Some(server) = body.room_id.server_name() { servers.push(server.into()); @@ -206,21 +208,22 @@ pub(crate) async fn join_room_by_id_or_alias_route( services() .rooms .state_cache - .servers_invite_via(&room_id)? - .unwrap_or( - services() - .rooms - .state_cache - .invite_state(sender_user, &room_id)? - .unwrap_or_default() - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()) - .collect(), - ), + .servers_invite_via(&room_id) + .filter_map(Result::ok), + ); + + servers.extend( + services() + .rooms + .state_cache + .invite_state(sender_user, &room_id)? + .unwrap_or_default() + .iter() + .filter_map(|event| serde_json::from_str(event.json().get()).ok()) + .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) + .filter_map(|sender| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()), ); if let Some(server) = room_id.server_name() { @@ -240,21 +243,22 @@ pub(crate) async fn join_room_by_id_or_alias_route( services() .rooms .state_cache - .servers_invite_via(&response.room_id)? - .unwrap_or( - services() - .rooms - .state_cache - .invite_state(sender_user, &response.room_id)? - .unwrap_or_default() - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()) - .collect(), - ), + .servers_invite_via(&response.room_id) + .filter_map(Result::ok), + ); + + servers.extend( + services() + .rooms + .state_cache + .invite_state(sender_user, &response.room_id)? + .unwrap_or_default() + .iter() + .filter_map(|event| serde_json::from_str(event.json().get()).ok()) + .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) + .filter_map(|sender| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()), ); (servers, response.room_id) @@ -1680,21 +1684,23 @@ async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> { .invite_state(user_id, room_id)? .ok_or(Error::BadRequest(ErrorKind::BadState, "User is not invited."))?; - let servers: HashSet<OwnedServerName> = services() + let mut servers: HashSet<OwnedServerName> = services() .rooms .state_cache - .servers_invite_via(room_id)? - .map_or( - invite_state - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()) - .collect::<HashSet<OwnedServerName>>(), - HashSet::from_iter, - ); + .servers_invite_via(room_id) + .filter_map(Result::ok) + .collect(); + + servers.extend( + invite_state + .iter() + .filter_map(|event| serde_json::from_str(event.json().get()).ok()) + .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) + .filter_map(|sender| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()) + .collect::<HashSet<OwnedServerName>>(), + ); debug!("servers in remote_leave_room: {servers:?}"); diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index acf390d989939d20bf0de05e3337e050442fa170..d380c1ab0418ac85147fec74b95cad39effe0def 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -6,7 +6,6 @@ serde::Raw, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; -use tracing::error; use crate::{ appservice::RegistrationInfo, @@ -94,7 +93,7 @@ fn mark_as_invited( /// Gets the servers to either accept or decline invites via for a given /// room. - fn servers_invite_via(&self, room_id: &RoomId) -> Result<Option<Vec<OwnedServerName>>>; + fn servers_invite_via<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a>; /// Add the given servers the list to accept or decline invites via for a /// given room. @@ -159,7 +158,10 @@ fn mark_as_invited( self.roomuserid_leftcount.remove(&roomuser_id)?; if let Some(servers) = invite_via { - let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new()); + let mut prev_servers = self + .servers_invite_via(room_id) + .filter_map(Result::ok) + .collect_vec(); #[allow(clippy::redundant_clone)] // this is a necessary clone? prev_servers.append(servers.clone().as_mut()); let servers = prev_servers.iter().rev().unique().rev().collect_vec(); @@ -639,30 +641,40 @@ fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { } #[tracing::instrument(skip(self))] - fn servers_invite_via(&self, room_id: &RoomId) -> Result<Option<Vec<OwnedServerName>>> { + fn servers_invite_via<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> { let key = room_id.as_bytes().to_vec(); - self.roomid_inviteviaservers - .get(&key)? - .map(|servers| { - let state = serde_json::from_slice(&servers).map_err(|e| { - error!("Invalid state in userroomid_leftstate: {e}"); - Error::bad_database("Invalid state in userroomid_leftstate.") - })?; - - Ok(state) - }) - .transpose() + Box::new( + self.roomid_inviteviaservers + .scan_prefix(key) + .map(|(_, servers)| { + ServerName::parse( + utils::string_from_bytes( + servers + .rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| { + Error::bad_database("Server name in roomid_inviteviaservers is invalid unicode.") + })?, + ) + .map_err(|_| Error::bad_database("Server name in roomid_inviteviaservers is invalid.")) + }), + ) } #[tracing::instrument(skip(self))] fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()> { - let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new()); - prev_servers.append(servers.to_owned().as_mut()); - - let servers = prev_servers.iter().rev().unique().rev().collect_vec(); - - let servers = servers + let mut prev_servers = self + .servers_invite_via(room_id) + .filter_map(Result::ok) + .collect_vec(); + prev_servers.extend(servers.to_owned()); + prev_servers.sort_unstable(); + prev_servers.dedup(); + + let servers = prev_servers .iter() .map(|server| server.as_bytes()) .collect_vec() diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 2b3f5fb7bbd93da8e06938f0366a996cef513fa9..c9d1bbd75e061db66343042f87aa55b42c97689f 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -377,7 +377,7 @@ pub fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { self.db.is_left(user_id, room_id) } #[tracing::instrument(skip(self))] - pub fn servers_invite_via(&self, room_id: &RoomId) -> Result<Option<Vec<OwnedServerName>>> { + pub fn servers_invite_via(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedServerName>> + '_ { self.db.servers_invite_via(room_id) } diff --git a/tests/test_results/complement/test_results.jsonl b/tests/test_results/complement/test_results.jsonl index 502a45305bb83c145f4922fedab1b8b715a2c091..51cd57d47fce9a9adc490c5298f26aca79b77f24 100644 --- a/tests/test_results/complement/test_results.jsonl +++ b/tests/test_results/complement/test_results.jsonl @@ -55,7 +55,7 @@ {"Action":"fail","Test":"TestFederationKeyUploadQuery/Can_claim_remote_one_time_key_using_POST"} {"Action":"fail","Test":"TestFederationKeyUploadQuery/Can_query_remote_device_keys_using_POST"} {"Action":"pass","Test":"TestFederationRedactSendsWithoutEvent"} -{"Action":"fail","Test":"TestFederationRejectInvite"} +{"Action":"pass","Test":"TestFederationRejectInvite"} {"Action":"fail","Test":"TestGetMissingEventsGapFilling"} {"Action":"fail","Test":"TestInboundCanReturnMissingEvents"} {"Action":"fail","Test":"TestInboundCanReturnMissingEvents/Inbound_federation_can_return_missing_events_for_invited_visibility"} @@ -199,7 +199,7 @@ {"Action":"pass","Test":"TestToDeviceMessagesOverFederation/good_connectivity"} {"Action":"pass","Test":"TestToDeviceMessagesOverFederation/interrupted_connectivity"} {"Action":"fail","Test":"TestToDeviceMessagesOverFederation/stopped_server"} -{"Action":"fail","Test":"TestUnbanViaInvite"} +{"Action":"pass","Test":"TestUnbanViaInvite"} {"Action":"fail","Test":"TestUnknownEndpoints"} {"Action":"pass","Test":"TestUnknownEndpoints/Client-server_endpoints"} {"Action":"fail","Test":"TestUnknownEndpoints/Key_endpoints"}