From fa2da9e04849ccf2c37e3769d7fc776599186e2e Mon Sep 17 00:00:00 2001
From: CapsizeGlimmer <>
Date: Thu, 23 Jul 2020 23:03:24 -0400
Subject: [PATCH] Implement max_request_size config option

---
 src/client_server.rs    | 10 +++++-----
 src/database/globals.rs | 16 +++++++++++++---
 src/ruma_wrapper.rs     | 14 ++++++--------
 3 files changed, 24 insertions(+), 16 deletions(-)

diff --git a/src/client_server.rs b/src/client_server.rs
index a3f476082..eede5fd2e 100644
--- a/src/client_server.rs
+++ b/src/client_server.rs
@@ -2977,11 +2977,11 @@ pub fn send_event_to_device_route(
 }
 
 #[get("/_matrix/media/r0/config")]
-pub fn get_media_config_route() -> ConduitResult<get_media_config::Response> {
-    Ok(get_media_config::Response {
-        upload_size: (20_u32 * 1024 * 1024).into(), // 20 MB
-    }
-    .into())
+pub fn get_media_config_route(
+    db: State<'_, Database>,
+) -> ConduitResult<get_media_config::Response> {
+    let upload_size = db.globals.max_request_size().into();
+    Ok(get_media_config::Response { upload_size }.into())
 }
 
 #[post("/_matrix/media/r0/upload", data = "<body>")]
diff --git a/src/database/globals.rs b/src/database/globals.rs
index 3a257a54c..5db28069f 100644
--- a/src/database/globals.rs
+++ b/src/database/globals.rs
@@ -1,7 +1,7 @@
-use std::convert::TryInto;
-
 use crate::{utils, Error, Result};
 use ruma::ServerName;
+use std::convert::TryInto;
+
 pub const COUNTER: &str = "c";
 
 pub struct Globals {
@@ -9,6 +9,7 @@ pub struct Globals {
     keypair: ruma::signatures::Ed25519KeyPair,
     reqwest_client: reqwest::Client,
     server_name: Box<ServerName>,
+    max_request_size: u32,
     registration_disabled: bool,
     encryption_disabled: bool,
 }
@@ -32,7 +33,12 @@ pub fn load(globals: sled::Tree, config: &rocket::Config) -> Result<Self> {
                 .unwrap_or("localhost")
                 .to_string()
                 .try_into()
-                .map_err(|_| Error::BadConfig("Invalid server name found."))?,
+                .map_err(|_| Error::BadConfig("Invalid server_name."))?,
+            max_request_size: config
+                .get_int("max_request_size")
+                .unwrap_or(20 * 1024 * 1024) // Default to 20 MB
+                .try_into()
+                .map_err(|_| Error::BadConfig("Invalid max_request_size."))?,
             registration_disabled: config.get_bool("registration_disabled").unwrap_or(false),
             encryption_disabled: config.get_bool("encryption_disabled").unwrap_or(false),
         })
@@ -69,6 +75,10 @@ pub fn server_name(&self) -> &ServerName {
         self.server_name.as_ref()
     }
 
+    pub fn max_request_size(&self) -> u32 {
+        self.max_request_size
+    }
+
     pub fn registration_disabled(&self) -> bool {
         self.registration_disabled
     }
diff --git a/src/ruma_wrapper.rs b/src/ruma_wrapper.rs
index 15e50ba32..5b380b37f 100644
--- a/src/ruma_wrapper.rs
+++ b/src/ruma_wrapper.rs
@@ -11,8 +11,6 @@
 use std::{convert::TryInto, io::Cursor, ops::Deref};
 use tokio::io::AsyncReadExt;
 
-const MESSAGE_LIMIT: u64 = 20 * 1024 * 1024; // 20 MB
-
 /// This struct converts rocket requests into ruma structs by converting them into http requests
 /// first.
 pub struct Ruma<T> {
@@ -40,13 +38,12 @@ fn from_data(
     ) -> FromDataFuture<'a, Self, Self::Error> {
         Box::pin(async move {
             let data = rocket::try_outcome!(outcome.owned());
+            let db = request
+                .guard::<State<'_, crate::Database>>()
+                .await
+                .expect("database was loaded");
 
             let (user_id, device_id) = if T::METADATA.requires_authentication {
-                let db = request
-                    .guard::<State<'_, crate::Database>>()
-                    .await
-                    .expect("database was loaded");
-
                 // Get token from header or query value
                 let token = match request
                     .headers()
@@ -76,7 +73,8 @@ fn from_data(
                 http_request = http_request.header(header.name.as_str(), &*header.value);
             }
 
-            let mut handle = data.open().take(MESSAGE_LIMIT);
+            let limit = db.globals.max_request_size();
+            let mut handle = data.open().take(limit.into());
             let mut body = Vec::new();
             handle.read_to_end(&mut body).await.unwrap();
 
-- 
GitLab