diff --git a/src/api/client_server/alias.rs b/src/api/client_server/alias.rs index 0e930fc3f889e5c05f0c99c35f33508cde3dd2de..be4a8ed0a83aea9d247af00e85a0ccacef8e2f44 100644 --- a/src/api/client_server/alias.rs +++ b/src/api/client_server/alias.rs @@ -1,5 +1,4 @@ use rand::seq::SliceRandom; -use regex::Regex; use ruma::{ api::{ appservice, @@ -116,19 +115,12 @@ pub(crate) async fn get_alias_helper(room_alias: OwnedRoomAliasId) -> Result<get match services().rooms.alias.resolve_local_alias(&room_alias)? { Some(r) => room_id = Some(r), None => { - for (_id, registration) in services().appservice.all()? { - let aliases = registration - .namespaces - .aliases - .iter() - .filter_map(|alias| Regex::new(alias.regex.as_str()).ok()) - .collect::<Vec<_>>(); - - if aliases.iter().any(|aliases| aliases.is_match(room_alias.as_str())) + for appservice in services().appservice.registration_info.read().await.values() { + if appservice.aliases.is_match(room_alias.as_str()) && if let Some(opt_result) = services() .sending .send_appservice_request( - registration, + appservice.registration.clone(), appservice::query::query_room_alias::v1::Request { room_alias: room_alias.clone(), }, @@ -144,7 +136,7 @@ pub(crate) async fn get_alias_helper(room_alias: OwnedRoomAliasId) -> Result<get .rooms .alias .resolve_local_alias(&room_alias)? - .ok_or_else(|| Error::bad_config("Appservice lied to us. Room does not exist."))?, + .ok_or_else(|| Error::bad_config("Room does not exist."))?, ); break; } diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs index c2cb3833c9e6c1127821e1eff5d53703377c2628..6c19dbe81a91a60b3418baf92bcf2622e0833db8 100644 --- a/src/database/key_value/rooms/state_cache.rs +++ b/src/database/key_value/rooms/state_cache.rs @@ -1,14 +1,16 @@ use std::{collections::HashSet, sync::Arc}; -use regex::Regex; use ruma::{ - api::appservice::Registration, events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; -use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; +use crate::{ + database::KeyValueDatabase, + service::{self, appservice::RegistrationInfo}, + services, utils, Error, Result, +}; type StrippedStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a>; @@ -160,19 +162,20 @@ fn get_our_real_users(&self, room_id: &RoomId) -> Result<Arc<HashSet<OwnedUserId } #[tracing::instrument(skip(self, room_id, appservice))] - fn appservice_in_room(&self, room_id: &RoomId, appservice: &(String, Registration)) -> Result<bool> { - let maybe = - self.appservice_in_room_cache.read().unwrap().get(room_id).and_then(|map| map.get(&appservice.0)).copied(); + fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result<bool> { + let maybe = self + .appservice_in_room_cache + .read() + .unwrap() + .get(room_id) + .and_then(|map| map.get(&appservice.registration.id)) + .copied(); if let Some(b) = maybe { Ok(b) } else { - let namespaces = &appservice.1.namespaces; - let users = - namespaces.users.iter().filter_map(|users| Regex::new(users.regex.as_str()).ok()).collect::<Vec<_>>(); - let bridge_user_id = UserId::parse_with_server_name( - appservice.1.sender_localpart.as_str(), + appservice.registration.sender_localpart.as_str(), services().globals.server_name(), ) .ok(); @@ -180,14 +183,14 @@ fn appservice_in_room(&self, room_id: &RoomId, appservice: &(String, Registratio let in_room = bridge_user_id.map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false)) || self .room_members(room_id) - .any(|userid| userid.map_or(false, |userid| users.iter().any(|r| r.is_match(userid.as_str())))); + .any(|userid| userid.map_or(false, |userid| appservice.users.is_match(userid.as_str()))); self.appservice_in_room_cache .write() .unwrap() .entry(room_id.to_owned()) .or_default() - .insert(appservice.0.clone(), in_room); + .insert(appservice.registration.id.clone(), in_room); Ok(in_room) } diff --git a/src/database/mod.rs b/src/database/mod.rs index 7222d1b0f96bf5a046ec6df3e807d431b5c73f62..2cb22c9073d60acd3910216440847e8a0f24d9a8 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -967,6 +967,14 @@ pub async fn load_or_create(config: Config) -> Result<()> { ); } + // Inserting registraions into cache + for appservice in services().appservice.all()? { + services().appservice.registration_info.write().await.insert( + appservice.0, + appservice.1.try_into().expect("Should be validated on registration"), + ); + } + services().admin.start_handler(); // Set emergency access for the conduit user diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index f7e27c38fcc1efea3e8e3c018b19fab3019c6360..06b00ce43045aed042189512c929cbebd6c29567 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -567,7 +567,7 @@ async fn process_admin_command(&self, command: AdminCommand, body: Vec<&str>) -> let appservice_config = body[1..body.len() - 1].join("\n"); let parsed_config = serde_yaml::from_str::<Registration>(&appservice_config); match parsed_config { - Ok(yaml) => match services().appservice.register_appservice(yaml) { + Ok(yaml) => match services().appservice.register_appservice(yaml).await { Ok(id) => { RoomMessageEventContent::text_plain(format!("Appservice registered with ID: {id}.")) }, @@ -587,7 +587,7 @@ async fn process_admin_command(&self, command: AdminCommand, body: Vec<&str>) -> }, AppserviceCommand::Unregister { appservice_identifier, - } => match services().appservice.unregister_appservice(&appservice_identifier) { + } => match services().appservice.unregister_appservice(&appservice_identifier).await { Ok(()) => RoomMessageEventContent::text_plain("Appservice unregistered."), Err(e) => RoomMessageEventContent::text_plain(format!("Failed to unregister appservice: {e}")), }, diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index 5700731de5a54158c5776a10f2b990c2ae9400cf..5884e4a83f4b583587b07bb34d8a2e46b7d4b017 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -1,24 +1,118 @@ mod data; +use std::collections::HashMap; + pub(crate) use data::Data; -use ruma::api::appservice::Registration; +use regex::RegexSet; +use ruma::api::appservice::{Namespace, Registration}; +use tokio::sync::RwLock; + +use crate::{services, Result}; + +/// Compiled regular expressions for a namespace +pub struct NamespaceRegex { + pub exclusive: Option<RegexSet>, + pub non_exclusive: Option<RegexSet>, +} + +impl NamespaceRegex { + /// Checks if this namespace has rights to a namespace + pub fn is_match(&self, heystack: &str) -> bool { + if self.is_exclusive_match(heystack) { + return true; + } + + if let Some(non_exclusive) = &self.non_exclusive { + if non_exclusive.is_match(heystack) { + return true; + } + } + false + } + + /// Checks if this namespace has exlusive rights to a namespace + pub fn is_exclusive_match(&self, heystack: &str) -> bool { + if let Some(exclusive) = &self.exclusive { + if exclusive.is_match(heystack) { + return true; + } + } + false + } +} + +impl TryFrom<Vec<Namespace>> for NamespaceRegex { + type Error = regex::Error; + + fn try_from(value: Vec<Namespace>) -> Result<Self, regex::Error> { + let mut exclusive = vec![]; + let mut non_exclusive = vec![]; -use crate::Result; + for namespace in value { + if namespace.exclusive { + exclusive.push(namespace.regex); + } else { + non_exclusive.push(namespace.regex); + } + } + + Ok(NamespaceRegex { + exclusive: if exclusive.is_empty() { + None + } else { + Some(RegexSet::new(exclusive)?) + }, + non_exclusive: if non_exclusive.is_empty() { + None + } else { + Some(RegexSet::new(non_exclusive)?) + }, + }) + } +} + +/// Compiled regular expressions for an appservice +pub struct RegistrationInfo { + pub registration: Registration, + pub users: NamespaceRegex, + pub aliases: NamespaceRegex, + pub rooms: NamespaceRegex, +} + +impl TryFrom<Registration> for RegistrationInfo { + type Error = regex::Error; + + fn try_from(value: Registration) -> Result<RegistrationInfo, regex::Error> { + Ok(RegistrationInfo { + users: value.namespaces.users.clone().try_into()?, + aliases: value.namespaces.aliases.clone().try_into()?, + rooms: value.namespaces.rooms.clone().try_into()?, + registration: value, + }) + } +} pub struct Service { pub db: &'static dyn Data, + pub registration_info: RwLock<HashMap<String, RegistrationInfo>>, } impl Service { /// Registers an appservice and returns the ID to the caller - pub fn register_appservice(&self, yaml: Registration) -> Result<String> { self.db.register_appservice(yaml) } + pub async fn register_appservice(&self, yaml: Registration) -> Result<String> { + services().appservice.registration_info.write().await.insert(yaml.id.clone(), yaml.clone().try_into()?); + + self.db.register_appservice(yaml) + } /// Remove an appservice registration /// /// # Arguments /// /// * `service_name` - the name you send to register the service previously - pub fn unregister_appservice(&self, service_name: &str) -> Result<()> { + pub async fn unregister_appservice(&self, service_name: &str) -> Result<()> { + services().appservice.registration_info.write().await.remove(service_name); + self.db.unregister_appservice(service_name) } diff --git a/src/service/mod.rs b/src/service/mod.rs index b4e7bb4e15d01f60ae7a866e5bc39bcdd0321b89..25413647200aeae8accbc54caaf84a4309680e12 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -57,6 +57,7 @@ pub fn build< Ok(Self { appservice: appservice::Service { db, + registration_info: RwLock::new(HashMap::new()), }, pusher: pusher::Service { db, diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index 44ae57f8228e13c054f18d9a42fc6e5b7730a406..6e97396f2390a3df9dc767d16a7318bcca0e947d 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -1,13 +1,12 @@ use std::{collections::HashSet, sync::Arc}; use ruma::{ - api::appservice::Registration, events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; -use crate::Result; +use crate::{service::appservice::RegistrationInfo, Result}; type StrippedStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a>; @@ -25,7 +24,7 @@ fn mark_as_invited( fn get_our_real_users(&self, room_id: &RoomId) -> Result<Arc<HashSet<OwnedUserId>>>; - fn appservice_in_room(&self, room_id: &RoomId, appservice: &(String, Registration)) -> Result<bool>; + fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result<bool>; /// Makes a user forget a room. fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()>; diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 7f305712358c6003b4afee89691dbc1d79148933..0ce82e1f5ca21c694e9132d0cfb03664ca021ab8 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -2,7 +2,7 @@ pub use data::Data; use ruma::{ - api::{appservice::Registration, federation}, + api::federation, events::{ direct::DirectEvent, ignored_user_list::IgnoredUserListEvent, @@ -17,7 +17,7 @@ }; use tracing::warn; -use crate::{services, Error, Result}; +use crate::{service::appservice::RegistrationInfo, services, Error, Result}; mod data; @@ -201,7 +201,7 @@ pub fn get_our_real_users(&self, room_id: &RoomId) -> Result<Arc<HashSet<OwnedUs } #[tracing::instrument(skip(self, room_id, appservice))] - pub fn appservice_in_room(&self, room_id: &RoomId, appservice: &(String, Registration)) -> Result<bool> { + pub fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result<bool> { self.db.appservice_in_room(room_id, appservice) } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 95063371911e0812779f1a9544a9e763328bccb9..9ffe972e1630b6cf591cf10fab8140665110baea 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -7,7 +7,6 @@ }; pub use data::Data; -use regex::Regex; use ruma::{ api::{client::error::ErrorKind, federation}, canonical_json::to_canonical_value, @@ -36,7 +35,10 @@ use super::state_compressor::CompressedStateEvent; use crate::{ api::server_server, - service::pdu::{EventHash, PduBuilder}, + service::{ + appservice::NamespaceRegex, + pdu::{EventHash, PduBuilder}, + }, services, utils, Error, PduEvent, Result, }; @@ -506,9 +508,9 @@ struct ExtractRelatesToEventId { } } - for appservice in services().appservice.all()? { - if services().rooms.state_cache.appservice_in_room(&pdu.room_id, &appservice)? { - services().sending.send_pdu_appservice(appservice.0, pdu_id.clone())?; + for appservice in services().appservice.registration_info.read().await.values() { + if services().rooms.state_cache.appservice_in_room(&pdu.room_id, appservice)? { + services().sending.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; continue; } @@ -518,30 +520,20 @@ struct ExtractRelatesToEventId { if let Some(state_key_uid) = &pdu.state_key.as_ref().and_then(|state_key| UserId::parse(state_key.as_str()).ok()) { - let appservice_uid = appservice.1.sender_localpart.as_str(); + let appservice_uid = appservice.registration.sender_localpart.as_str(); if state_key_uid == appservice_uid { - services().sending.send_pdu_appservice(appservice.0, pdu_id.clone())?; + services().sending.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; continue; } } } - let namespaces = appservice.1.namespaces; - - // TODO: create some helper function to change from Strings to Regexes - let users = - namespaces.users.iter().filter_map(|user| Regex::new(user.regex.as_str()).ok()).collect::<Vec<_>>(); - let aliases = - namespaces.aliases.iter().filter_map(|alias| Regex::new(alias.regex.as_str()).ok()).collect::<Vec<_>>(); - let rooms = - namespaces.rooms.iter().filter_map(|room| Regex::new(room.regex.as_str()).ok()).collect::<Vec<_>>(); - - let matching_users = |users: &Regex| { - users.is_match(pdu.sender.as_str()) + let matching_users = |users: &NamespaceRegex| { + appservice.users.is_match(pdu.sender.as_str()) || pdu.kind == TimelineEventType::RoomMember && pdu.state_key.as_ref().map_or(false, |state_key| users.is_match(state_key)) }; - let matching_aliases = |aliases: &Regex| { + let matching_aliases = |aliases: &NamespaceRegex| { services() .rooms .alias @@ -550,11 +542,11 @@ struct ExtractRelatesToEventId { .any(|room_alias| aliases.is_match(room_alias.as_str())) }; - if aliases.iter().any(matching_aliases) - || rooms.iter().any(|namespace| namespace.is_match(pdu.room_id.as_str())) - || users.iter().any(matching_users) + if matching_aliases(&appservice.aliases) + || appservice.rooms.is_match(pdu.room_id.as_str()) + || matching_users(&appservice.users) { - services().sending.send_pdu_appservice(appservice.0, pdu_id.clone())?; + services().sending.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; } } diff --git a/src/utils/error.rs b/src/utils/error.rs index 60209860d1d5e5f086644c2feed174f51074dfee..4463ffbb7dc27f50af11214e7c11459f32f9cd3b 100644 --- a/src/utils/error.rs +++ b/src/utils/error.rs @@ -39,6 +39,11 @@ pub enum Error { #[from] source: reqwest::Error, }, + #[error("Could build regular expression: {source}")] + RegexError { + #[from] + source: regex::Error, + }, #[error("{0}")] FederationError(OwnedServerName, RumaError), #[error("Could not do this io: {source}")]