From 54e6a41404f136b967a98664c739ed3d767a9ce0 Mon Sep 17 00:00:00 2001
From: Jason Volk <jason@zemos.net>
Date: Thu, 15 Aug 2024 20:08:53 +0000
Subject: [PATCH] move remote media fetchers into services

minor error simplification

Signed-off-by: Jason Volk <jason@zemos.net>
---
 src/api/client/media.rs     | 250 ++++++++++--------------------------
 src/service/globals/mod.rs  |   2 -
 src/service/media/mod.rs    |   5 +-
 src/service/media/remote.rs | 108 ++++++++++++++++
 4 files changed, 182 insertions(+), 183 deletions(-)
 create mode 100644 src/service/media/remote.rs

diff --git a/src/api/client/media.rs b/src/api/client/media.rs
index 9a7d9eaae..0326568f3 100644
--- a/src/api/client/media.rs
+++ b/src/api/client/media.rs
@@ -1,21 +1,16 @@
 #![allow(deprecated)]
 
-use std::time::Duration;
-
 use axum::extract::State;
 use axum_client_ip::InsecureClientIp;
 use conduit::{
-	debug_info, debug_warn, err, info,
+	err,
 	utils::{self, content_disposition::make_content_disposition, math::ruma_from_usize},
-	warn, Err, Error, Result,
+	Err, Result,
 };
 use ruma::api::client::media::{
 	create_content, get_content, get_content_as_filename, get_content_thumbnail, get_media_config, get_media_preview,
 };
-use service::{
-	media::{FileMeta, MXC_LENGTH},
-	Services,
-};
+use service::media::{FileMeta, MXC_LENGTH};
 
 use crate::{Ruma, RumaResponse};
 
@@ -62,24 +57,24 @@ pub(crate) async fn get_media_preview_route(
 
 	let url = &body.url;
 	if !services.media.url_preview_allowed(url) {
-		debug_info!(%sender_user, %url, "URL is not allowed to be previewed");
-		return Err!(Request(Forbidden("URL is not allowed to be previewed")));
+		return Err!(Request(Forbidden(
+			debug_warn!(%sender_user, %url, "URL is not allowed to be previewed")
+		)));
 	}
 
-	match services.media.get_url_preview(url).await {
-		Ok(preview) => {
-			let res = serde_json::value::to_raw_value(&preview).map_err(|e| {
-				warn!(%sender_user, "Failed to convert UrlPreviewData into a serde json value: {e}");
-				err!(Request(Unknown("Failed to generate a URL preview")))
-			})?;
-
-			Ok(get_media_preview::v3::Response::from_raw_value(res))
-		},
-		Err(e) => {
-			info!(%sender_user, "Failed to generate a URL preview: {e}");
-			Err!(Request(Unknown("Failed to generate a URL preview")))
-		},
-	}
+	let preview = services.media.get_url_preview(url).await.map_err(|e| {
+		err!(Request(Unknown(
+			debug_error!(%sender_user, %url, "Failed to fetch a URL preview: {e}")
+		)))
+	})?;
+
+	let res = serde_json::value::to_raw_value(&preview).map_err(|e| {
+		err!(Request(Unknown(
+			debug_error!(%sender_user, %url, "Failed to parse a URL preview: {e}")
+		)))
+	})?;
+
+	Ok(get_media_preview::v3::Response::from_raw_value(res))
 }
 
 /// # `GET /_matrix/media/v1/preview_url`
@@ -176,25 +171,25 @@ pub(crate) async fn get_content_route(
 	{
 		let content_disposition = make_content_disposition(content_disposition.as_ref(), content_type.as_deref(), None);
 
-		let file = content.expect("content");
 		Ok(get_content::v3::Response {
-			file,
+			file: content.expect("entire file contents"),
 			content_type: content_type.map(Into::into),
 			content_disposition: Some(content_disposition),
 			cross_origin_resource_policy: Some(CORP_CROSS_ORIGIN.into()),
 			cache_control: Some(CACHE_CONTROL_IMMUTABLE.into()),
 		})
 	} else if !services.globals.server_is_ours(&body.server_name) && body.allow_remote {
-		let response = get_remote_content(
-			&services,
-			&mxc,
-			&body.server_name,
-			body.media_id.clone(),
-			body.allow_redirect,
-			body.timeout_ms,
-		)
-		.await
-		.map_err(|e| err!(Request(NotFound(debug_warn!("Fetching media `{mxc}` failed: {e:?}")))))?;
+		let response = services
+			.media
+			.fetch_remote_content(
+				&mxc,
+				&body.server_name,
+				body.media_id.clone(),
+				body.allow_redirect,
+				body.timeout_ms,
+			)
+			.await
+			.map_err(|e| err!(Request(NotFound(debug_warn!(%mxc, "Fetching media failed: {e:?}")))))?;
 
 		let content_disposition =
 			make_content_disposition(response.content_disposition.as_ref(), response.content_type.as_deref(), None);
@@ -257,42 +252,36 @@ pub(crate) async fn get_content_as_filename_route(
 		let content_disposition =
 			make_content_disposition(content_disposition.as_ref(), content_type.as_deref(), Some(&body.filename));
 
-		let file = content.expect("content");
 		Ok(get_content_as_filename::v3::Response {
-			file,
+			file: content.expect("entire file contents"),
 			content_type: content_type.map(Into::into),
 			content_disposition: Some(content_disposition),
 			cross_origin_resource_policy: Some(CORP_CROSS_ORIGIN.into()),
 			cache_control: Some(CACHE_CONTROL_IMMUTABLE.into()),
 		})
 	} else if !services.globals.server_is_ours(&body.server_name) && body.allow_remote {
-		match get_remote_content(
-			&services,
-			&mxc,
-			&body.server_name,
-			body.media_id.clone(),
-			body.allow_redirect,
-			body.timeout_ms,
-		)
-		.await
-		{
-			Ok(remote_content_response) => {
-				let content_disposition = make_content_disposition(
-					remote_content_response.content_disposition.as_ref(),
-					remote_content_response.content_type.as_deref(),
-					None,
-				);
-
-				Ok(get_content_as_filename::v3::Response {
-					content_disposition: Some(content_disposition),
-					content_type: remote_content_response.content_type,
-					file: remote_content_response.file,
-					cross_origin_resource_policy: Some(CORP_CROSS_ORIGIN.into()),
-					cache_control: Some(CACHE_CONTROL_IMMUTABLE.into()),
-				})
-			},
-			Err(e) => Err!(Request(NotFound(debug_warn!("Fetching media `{mxc}` failed: {e:?}")))),
-		}
+		let response = services
+			.media
+			.fetch_remote_content(
+				&mxc,
+				&body.server_name,
+				body.media_id.clone(),
+				body.allow_redirect,
+				body.timeout_ms,
+			)
+			.await
+			.map_err(|e| err!(Request(NotFound(debug_warn!(%mxc, "Fetching media failed: {e:?}")))))?;
+
+		let content_disposition =
+			make_content_disposition(response.content_disposition.as_ref(), response.content_type.as_deref(), None);
+
+		Ok(get_content_as_filename::v3::Response {
+			content_disposition: Some(content_disposition),
+			content_type: response.content_type,
+			file: response.file,
+			cross_origin_resource_policy: Some(CORP_CROSS_ORIGIN.into()),
+			cache_control: Some(CACHE_CONTROL_IMMUTABLE.into()),
+		})
 	} else {
 		Err!(Request(NotFound("Media not found.")))
 	}
@@ -353,75 +342,31 @@ pub(crate) async fn get_content_thumbnail_route(
 		.await?
 	{
 		let content_disposition = make_content_disposition(content_disposition.as_ref(), content_type.as_deref(), None);
-		let file = content.expect("content");
 
 		Ok(get_content_thumbnail::v3::Response {
-			file,
+			file: content.expect("entire file contents"),
 			content_type: content_type.map(Into::into),
 			cross_origin_resource_policy: Some(CORP_CROSS_ORIGIN.into()),
 			cache_control: Some(CACHE_CONTROL_IMMUTABLE.into()),
 			content_disposition: Some(content_disposition),
 		})
 	} else if !services.globals.server_is_ours(&body.server_name) && body.allow_remote {
-		if services
-			.globals
-			.prevent_media_downloads_from()
-			.contains(&body.server_name)
-		{
-			// we'll lie to the client and say the blocked server's media was not found and
-			// log. the client has no way of telling anyways so this is a security bonus.
-			debug_warn!("Received request for media `{}` on blocklisted server", mxc);
-			return Err!(Request(NotFound("Media not found.")));
-		}
-
-		match services
-			.sending
-			.send_federation_request(
-				&body.server_name,
-				get_content_thumbnail::v3::Request {
-					allow_remote: body.allow_remote,
-					height: body.height,
-					width: body.width,
-					method: body.method.clone(),
-					server_name: body.server_name.clone(),
-					media_id: body.media_id.clone(),
-					timeout_ms: body.timeout_ms,
-					allow_redirect: body.allow_redirect,
-					animated: body.animated,
-				},
-			)
+		let response = services
+			.media
+			.fetch_remote_thumbnail(&mxc, &body)
 			.await
-		{
-			Ok(get_thumbnail_response) => {
-				services
-					.media
-					.upload_thumbnail(
-						None,
-						&mxc,
-						None,
-						get_thumbnail_response.content_type.as_deref(),
-						body.width.try_into().expect("all UInts are valid u32s"),
-						body.height.try_into().expect("all UInts are valid u32s"),
-						&get_thumbnail_response.file,
-					)
-					.await?;
-
-				let content_disposition = make_content_disposition(
-					get_thumbnail_response.content_disposition.as_ref(),
-					get_thumbnail_response.content_type.as_deref(),
-					None,
-				);
-
-				Ok(get_content_thumbnail::v3::Response {
-					file: get_thumbnail_response.file,
-					content_type: get_thumbnail_response.content_type,
-					cross_origin_resource_policy: Some(CORP_CROSS_ORIGIN.into()),
-					cache_control: Some(CACHE_CONTROL_IMMUTABLE.into()),
-					content_disposition: Some(content_disposition),
-				})
-			},
-			Err(e) => Err!(Request(NotFound(debug_warn!("Fetching media `{mxc}` failed: {e:?}")))),
-		}
+			.map_err(|e| err!(Request(NotFound(debug_warn!(%mxc, "Fetching media failed: {e:?}")))))?;
+
+		let content_disposition =
+			make_content_disposition(response.content_disposition.as_ref(), response.content_type.as_deref(), None);
+
+		Ok(get_content_thumbnail::v3::Response {
+			file: response.file,
+			content_type: response.content_type,
+			cross_origin_resource_policy: Some(CORP_CROSS_ORIGIN.into()),
+			cache_control: Some(CACHE_CONTROL_IMMUTABLE.into()),
+			content_disposition: Some(content_disposition),
+		})
 	} else {
 		Err!(Request(NotFound("Media not found.")))
 	}
@@ -448,58 +393,3 @@ pub(crate) async fn get_content_thumbnail_v1_route(
 		.await
 		.map(RumaResponse)
 }
-
-async fn get_remote_content(
-	services: &Services, mxc: &str, server_name: &ruma::ServerName, media_id: String, allow_redirect: bool,
-	timeout_ms: Duration,
-) -> Result<get_content::v3::Response, Error> {
-	if services
-		.globals
-		.prevent_media_downloads_from()
-		.contains(&server_name.to_owned())
-	{
-		// we'll lie to the client and say the blocked server's media was not found and
-		// log. the client has no way of telling anyways so this is a security bonus.
-		debug_warn!("Received request for media `{mxc}` on blocklisted server");
-		return Err!(Request(NotFound("Media not found.")));
-	}
-
-	let content_response = services
-		.sending
-		.send_federation_request(
-			server_name,
-			get_content::v3::Request {
-				allow_remote: true,
-				server_name: server_name.to_owned(),
-				media_id,
-				timeout_ms,
-				allow_redirect,
-			},
-		)
-		.await?;
-
-	let content_disposition = make_content_disposition(
-		content_response.content_disposition.as_ref(),
-		content_response.content_type.as_deref(),
-		None,
-	);
-
-	services
-		.media
-		.create(
-			None,
-			mxc,
-			Some(&content_disposition),
-			content_response.content_type.as_deref(),
-			&content_response.file,
-		)
-		.await?;
-
-	Ok(get_content::v3::Response {
-		file: content_response.file,
-		content_type: content_response.content_type,
-		content_disposition: Some(content_disposition),
-		cross_origin_resource_policy: Some(CORP_CROSS_ORIGIN.into()),
-		cache_control: Some(CACHE_CONTROL_IMMUTABLE.into()),
-	})
-}
diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs
index 2c588dce0..05fe1a77c 100644
--- a/src/service/globals/mod.rs
+++ b/src/service/globals/mod.rs
@@ -248,8 +248,6 @@ pub fn allow_incoming_read_receipts(&self) -> bool { self.config.allow_incoming_
 
 	pub fn allow_outgoing_read_receipts(&self) -> bool { self.config.allow_outgoing_read_receipts }
 
-	pub fn prevent_media_downloads_from(&self) -> &[OwnedServerName] { &self.config.prevent_media_downloads_from }
-
 	pub fn forbidden_remote_room_directory_server_names(&self) -> &[OwnedServerName] {
 		&self.config.forbidden_remote_room_directory_server_names
 	}
diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs
index 62baf5481..b63dc1413 100644
--- a/src/service/media/mod.rs
+++ b/src/service/media/mod.rs
@@ -1,5 +1,6 @@
 mod data;
 mod preview;
+mod remote;
 mod tests;
 mod thumbnail;
 
@@ -15,7 +16,7 @@
 	io::{AsyncReadExt, AsyncWriteExt, BufReader},
 };
 
-use crate::{client, globals, Dep};
+use crate::{client, globals, sending, Dep};
 
 #[derive(Debug)]
 pub struct FileMeta {
@@ -34,6 +35,7 @@ struct Services {
 	server: Arc<Server>,
 	client: Dep<client::Service>,
 	globals: Dep<globals::Service>,
+	sending: Dep<sending::Service>,
 }
 
 /// generated MXC ID (`media-id`) length
@@ -49,6 +51,7 @@ fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
 				server: args.server.clone(),
 				client: args.depend::<client::Service>("client"),
 				globals: args.depend::<globals::Service>("globals"),
+				sending: args.depend::<sending::Service>("sending"),
 			},
 		}))
 	}
diff --git a/src/service/media/remote.rs b/src/service/media/remote.rs
new file mode 100644
index 000000000..c07ded768
--- /dev/null
+++ b/src/service/media/remote.rs
@@ -0,0 +1,108 @@
+use std::time::Duration;
+
+use conduit::{debug_warn, err, implement, utils::content_disposition::make_content_disposition, Err, Error, Result};
+use ruma::{
+	api::client::media::{get_content, get_content_thumbnail},
+	ServerName,
+};
+
+#[implement(super::Service)]
+#[allow(deprecated)]
+pub async fn fetch_remote_thumbnail(
+	&self, mxc: &str, body: &get_content_thumbnail::v3::Request,
+) -> Result<get_content_thumbnail::v3::Response> {
+	let server_name = &body.server_name;
+	self.check_fetch_authorized(mxc, server_name)?;
+
+	let reponse = self
+		.services
+		.sending
+		.send_federation_request(
+			server_name,
+			get_content_thumbnail::v3::Request {
+				allow_remote: body.allow_remote,
+				height: body.height,
+				width: body.width,
+				method: body.method.clone(),
+				server_name: body.server_name.clone(),
+				media_id: body.media_id.clone(),
+				timeout_ms: body.timeout_ms,
+				allow_redirect: body.allow_redirect,
+				animated: body.animated,
+			},
+		)
+		.await?;
+
+	self.upload_thumbnail(
+		None,
+		mxc,
+		None,
+		reponse.content_type.as_deref(),
+		body.width
+			.try_into()
+			.map_err(|e| err!(Request(InvalidParam("Width is invalid: {e:?}"))))?,
+		body.height
+			.try_into()
+			.map_err(|e| err!(Request(InvalidParam("Height is invalid: {e:?}"))))?,
+		&reponse.file,
+	)
+	.await?;
+
+	Ok(reponse)
+}
+
+#[implement(super::Service)]
+#[allow(deprecated)]
+pub async fn fetch_remote_content(
+	&self, mxc: &str, server_name: &ServerName, media_id: String, allow_redirect: bool, timeout_ms: Duration,
+) -> Result<get_content::v3::Response, Error> {
+	self.check_fetch_authorized(mxc, server_name)?;
+
+	let response = self
+		.services
+		.sending
+		.send_federation_request(
+			server_name,
+			get_content::v3::Request {
+				allow_remote: true,
+				server_name: server_name.to_owned(),
+				media_id,
+				timeout_ms,
+				allow_redirect,
+			},
+		)
+		.await?;
+
+	let content_disposition =
+		make_content_disposition(response.content_disposition.as_ref(), response.content_type.as_deref(), None);
+
+	self.create(
+		None,
+		mxc,
+		Some(&content_disposition),
+		response.content_type.as_deref(),
+		&response.file,
+	)
+	.await?;
+
+	Ok(response)
+}
+
+#[implement(super::Service)]
+fn check_fetch_authorized(&self, mxc: &str, server_name: &ServerName) -> Result<()> {
+	if self
+		.services
+		.server
+		.config
+		.prevent_media_downloads_from
+		.iter()
+		.any(|entry| entry == server_name)
+	{
+		// we'll lie to the client and say the blocked server's media was not found and
+		// log. the client has no way of telling anyways so this is a security bonus.
+		debug_warn!(%mxc, "Received request for media on blocklisted server");
+		return Err!(Request(NotFound("Media not found.")));
+	}
+
+	Ok(())
+}
-- 
GitLab