From e305889b7250a97a7e83c96f98d4e65a570be35b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Timo=20K=C3=B6sters?= <timo@koesters.xyz>
Date: Wed, 24 Mar 2021 08:48:28 +0100
Subject: [PATCH] feat: room_account_data endpoints

---
 src/client_server/config.rs | 54 ++++++++++++++++++++++++++++++++++++-
 src/main.rs                 |  2 ++
 2 files changed, 55 insertions(+), 1 deletion(-)

diff --git a/src/client_server/config.rs b/src/client_server/config.rs
index a53b7cd4c..6abcba235 100644
--- a/src/client_server/config.rs
+++ b/src/client_server/config.rs
@@ -3,7 +3,7 @@
 use ruma::{
     api::client::{
         error::ErrorKind,
-        r0::config::{get_global_account_data, set_global_account_data},
+        r0::config::{get_room_account_data, get_global_account_data, set_room_account_data, set_global_account_data},
     },
     events::{custom::CustomEventContent, BasicEvent},
     serde::Raw,
@@ -43,6 +43,37 @@ pub async fn set_global_account_data_route(
     Ok(set_global_account_data::Response.into())
 }
 
+#[cfg_attr(
+    feature = "conduit_bin",
+    put("/_matrix/client/r0/user/<_>/rooms/<_>/account_data/<_>", data = "<body>")
+)]
+#[tracing::instrument(skip(db, body))]
+pub async fn set_room_account_data_route(
+    db: State<'_, Database>,
+    body: Ruma<set_room_account_data::Request<'_>>,
+) -> ConduitResult<set_room_account_data::Response> {
+    let sender_user = body.sender_user.as_ref().expect("user is authenticated");
+
+    let data = serde_json::from_str(body.data.get())
+        .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?;
+
+    let event_type = body.event_type.to_string();
+
+    db.account_data.update(
+        Some(&body.room_id),
+        sender_user,
+        event_type.clone().into(),
+        &BasicEvent {
+            content: CustomEventContent { event_type, data },
+        },
+        &db.globals,
+    )?;
+
+    db.flush().await?;
+
+    Ok(set_room_account_data::Response.into())
+}
+
 #[cfg_attr(
     feature = "conduit_bin",
     get("/_matrix/client/r0/user/<_>/account_data/<_>", data = "<body>")
@@ -63,3 +94,24 @@ pub async fn get_global_account_data_route(
 
     Ok(get_global_account_data::Response { account_data: data }.into())
 }
+
+#[cfg_attr(
+    feature = "conduit_bin",
+    get("/_matrix/client/r0/user/<_>/rooms/<_>/account_data/<_>", data = "<body>")
+)]
+#[tracing::instrument(skip(db, body))]
+pub async fn get_room_account_data_route(
+    db: State<'_, Database>,
+    body: Ruma<get_room_account_data::Request<'_>>,
+) -> ConduitResult<get_room_account_data::Response> {
+    let sender_user = body.sender_user.as_ref().expect("user is authenticated");
+
+    let data = db
+        .account_data
+        .get::<Raw<ruma::events::AnyBasicEvent>>(Some(&body.room_id), sender_user, body.event_type.clone().into())?
+        .ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?;
+
+    db.flush().await?;
+
+    Ok(get_room_account_data::Response { account_data: data }.into())
+}
diff --git a/src/main.rs b/src/main.rs
index 327aefab6..696ce5c87 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -77,7 +77,9 @@ fn setup_rocket() -> (rocket::Rocket, Config) {
                 client_server::get_filter_route,
                 client_server::create_filter_route,
                 client_server::set_global_account_data_route,
+                client_server::set_room_account_data_route,
                 client_server::get_global_account_data_route,
+                client_server::get_room_account_data_route,
                 client_server::set_displayname_route,
                 client_server::get_displayname_route,
                 client_server::set_avatar_url_route,
-- 
GitLab