From 49eb418786369a32f9bc7f43b7db408f71fa7ea9 Mon Sep 17 00:00:00 2001
From: Matthias Ahouansou <matthias@ahouansou.cz>
Date: Wed, 3 Apr 2024 14:16:11 -0400
Subject: [PATCH] feat: support /make_join and /send_join for restricted rooms

from https://gitlab.com/famedly/conduit/-/merge_requests/618

Signed-off-by: strawberry <strawberry@puppygock.gay>
---
 src/api/server_server.rs                | 108 ++++++++++++++----------
 src/service/rooms/state_accessor/mod.rs |  13 +++
 2 files changed, 75 insertions(+), 46 deletions(-)

diff --git a/src/api/server_server.rs b/src/api/server_server.rs
index 350529f03..f87ef0a88 100644
--- a/src/api/server_server.rs
+++ b/src/api/server_server.rs
@@ -34,7 +34,7 @@
 	events::{
 		receipt::{ReceiptEvent, ReceiptEventContent, ReceiptType},
 		room::{
-			join_rules::{JoinRule, RoomJoinRulesEventContent},
+			join_rules::{AllowRule, JoinRule, RoomJoinRulesEventContent},
 			member::{MembershipState, RoomMemberEventContent},
 		},
 		StateEventType, TimelineEventType,
@@ -910,7 +910,6 @@ pub async fn create_join_event_template_route(
 	);
 	let state_lock = mutex_state.lock().await;
 
-	// TODO: Conduit does not implement restricted join rules yet, we always reject
 	let join_rules_event =
 		services()
 			.rooms
@@ -921,23 +920,57 @@ pub async fn create_join_event_template_route(
 		.as_ref()
 		.map(|join_rules_event| {
 			serde_json::from_str(join_rules_event.content.get()).map_err(|e| {
-				warn!("Invalid join rules event: {}", e);
+				warn!("Invalid join rules event: {e}");
 				Error::bad_database("Invalid join rules event in db.")
 			})
 		})
 		.transpose()?;
 
-	if let Some(join_rules_event_content) = join_rules_event_content {
-		if matches!(
-			join_rules_event_content.join_rule,
-			JoinRule::Restricted { .. } | JoinRule::KnockRestricted { .. }
-		) {
-			return Err(Error::BadRequest(
-				ErrorKind::UnableToAuthorizeJoin,
-				"Conduit does not support restricted rooms yet.",
-			));
+	let join_authorized_via_users_server = if let Some(join_rules_event_content) = join_rules_event_content {
+		if let JoinRule::Restricted(r) | JoinRule::KnockRestricted(r) = join_rules_event_content.join_rule {
+			if r.allow
+				.iter()
+				.filter_map(|rule| {
+					if let AllowRule::RoomMembership(membership) = rule {
+						Some(membership)
+					} else {
+						None
+					}
+				})
+				.any(|m| {
+					services()
+						.rooms
+						.state_cache
+						.is_joined(&body.user_id, &m.room_id)
+						.unwrap_or(false)
+				}) && services()
+				.rooms
+				.state_cache
+				.is_left(&body.user_id, &body.room_id)
+				.unwrap_or(true)
+			{
+				services()
+					.rooms
+					.state_cache
+					.room_members(&body.room_id)
+					.filter_map(Result::ok)
+					.find(|user| {
+						user.server_name() == services().globals.server_name()
+							&& services()
+								.rooms
+								.state_accessor
+								.user_can_invite(&body.room_id, user, &body.user_id)
+								.unwrap_or(false)
+					})
+			} else {
+				None
+			}
+		} else {
+			None
 		}
-	}
+	} else {
+		None
+	};
 
 	let room_version_id = services().rooms.state.get_room_version(&body.room_id)?;
 	if !body.ver.contains(&room_version_id) {
@@ -957,7 +990,7 @@ pub async fn create_join_event_template_route(
 		membership: MembershipState::Join,
 		third_party_invite: None,
 		reason: None,
-		join_authorized_via_users_server: None,
+		join_authorized_via_users_server,
 	})
 	.expect("member event is valid value");
 
@@ -1017,35 +1050,6 @@ async fn create_join_event(
 		.event_handler
 		.acl_check(sender_servername, room_id)?;
 
-	// TODO: Conduit does not implement restricted join rules yet, we always reject
-	let join_rules_event =
-		services()
-			.rooms
-			.state_accessor
-			.room_state_get(room_id, &StateEventType::RoomJoinRules, "")?;
-
-	let join_rules_event_content: Option<RoomJoinRulesEventContent> = join_rules_event
-		.as_ref()
-		.map(|join_rules_event| {
-			serde_json::from_str(join_rules_event.content.get()).map_err(|e| {
-				warn!("Invalid join rules event: {}", e);
-				Error::bad_database("Invalid join rules event in db.")
-			})
-		})
-		.transpose()?;
-
-	if let Some(join_rules_event_content) = join_rules_event_content {
-		if matches!(
-			join_rules_event_content.join_rule,
-			JoinRule::Restricted { .. } | JoinRule::KnockRestricted { .. }
-		) {
-			return Err(Error::BadRequest(
-				ErrorKind::UnableToAuthorizeJoin,
-				"Conduit does not support restricted rooms yet.",
-			));
-		}
-	}
-
 	// We need to return the state prior to joining, let's keep a reference to that
 	// here
 	let shortstatehash = services()
@@ -1060,7 +1064,7 @@ async fn create_join_event(
 	// We do not add the event_id field to the pdu here because of signature and
 	// hashes checks
 	let room_version_id = services().rooms.state.get_room_version(room_id)?;
-	let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else {
+	let Ok((event_id, mut value)) = gen_event_id_canonical_json(pdu, &room_version_id) else {
 		// Event could not be converted to canonical json
 		return Err(Error::BadRequest(
 			ErrorKind::InvalidParam,
@@ -1068,6 +1072,14 @@ async fn create_join_event(
 		));
 	};
 
+	ruma::signatures::hash_and_sign_event(
+		services().globals.server_name().as_str(),
+		services().globals.keypair(),
+		&mut value,
+		&room_version_id,
+	)
+	.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Failed to sign event."))?;
+
 	let origin: OwnedServerName = serde_json::from_value(
 		serde_json::to_value(
 			value
@@ -1097,7 +1109,7 @@ async fn create_join_event(
 	let pdu_id: Vec<u8> = services()
 		.rooms
 		.event_handler
-		.handle_incoming_pdu(&origin, &event_id, room_id, value, true, &pub_key_map)
+		.handle_incoming_pdu(&origin, &event_id, room_id, value.clone(), true, &pub_key_map)
 		.await?
 		.ok_or(Error::BadRequest(
 			ErrorKind::InvalidParam,
@@ -1128,7 +1140,11 @@ async fn create_join_event(
 			.filter_map(|(_, id)| services().rooms.timeline.get_pdu_json(id).ok().flatten())
 			.map(PduEvent::convert_to_outgoing_federation_event)
 			.collect(),
-		event: None, // TODO: handle restricted joins
+		// Event field is required if the room version supports restricted join rules.
+		event: Some(
+			to_raw_value(&CanonicalJsonValue::Object(value.clone()))
+				.expect("To raw json should not fail since only change was adding signature"),
+		),
 	})
 }
 
diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs
index d4c708f54..fdd62eddc 100644
--- a/src/service/rooms/state_accessor/mod.rs
+++ b/src/service/rooms/state_accessor/mod.rs
@@ -13,6 +13,7 @@
 			history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent},
 			member::{MembershipState, RoomMemberEventContent},
 			name::RoomNameEventContent,
+			power_levels::RoomPowerLevelsEventContent,
 		},
 		StateEventType,
 	},
@@ -139,6 +140,18 @@ pub fn server_can_see_event(&self, origin: &ServerName, room_id: &RoomId, event_
 		Ok(visibility)
 	}
 
+	/// Whether a user's power level is sufficient to invite other users
+	pub fn user_can_invite(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
+		self.room_state_get(room_id, &StateEventType::RoomPowerLevels, "")?
+			.map(|pdu_event| {
+				serde_json::from_str(pdu_event.content.get()).map(|content: RoomPowerLevelsEventContent| {
+					content.users.get(user_id).unwrap_or(&content.users_default) >= &content.invite
+				})
+			})
+			.unwrap_or(Ok(false))
+			.map_err(|_| Error::bad_database("Invalid history visibility event in database."))
+	}
+
 	/// Whether a user is allowed to see an event, based on
 	/// the room's history_visibility at that event's state.
 	#[tracing::instrument(skip(self, user_id, room_id, event_id))]
-- 
GitLab