From eb73d8c669abbdb9ad104de9da08bb69895a9ab6 Mon Sep 17 00:00:00 2001
From: Benjamin Lee <benjamin@computer.surgery>
Date: Wed, 12 Jun 2024 00:33:12 -0400
Subject: [PATCH] fix: de-index pdus when redacted

bit of code dedupe as well

Co-authored-by: strawberry <strawberry@puppygock.gay>
Signed-off-by: strawberry <strawberry@puppygock.gay>
---
 src/service/rooms/search/data.rs  | 18 +++++++++++++++++
 src/service/rooms/search/mod.rs   |  5 +++++
 src/service/rooms/timeline/mod.rs | 33 +++++++++++++++++++------------
 3 files changed, 43 insertions(+), 13 deletions(-)

diff --git a/src/service/rooms/search/data.rs b/src/service/rooms/search/data.rs
index 0e6d251a3..7fa49989a 100644
--- a/src/service/rooms/search/data.rs
+++ b/src/service/rooms/search/data.rs
@@ -7,6 +7,8 @@
 pub trait Data: Send + Sync {
 	fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()>;
 
+	fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()>;
+
 	fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a>;
 }
 
@@ -34,6 +36,22 @@ fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Resu
 		self.tokenids.insert_batch(&mut batch)
 	}
 
+	fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
+		let batch = tokenize(message_body).map(|word| {
+			let mut key = shortroomid.to_be_bytes().to_vec();
+			key.extend_from_slice(word.as_bytes());
+			key.push(0xFF);
+			key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here
+			key
+		});
+
+		for token in batch {
+			self.tokenids.remove(&token)?;
+		}
+
+		Ok(())
+	}
+
 	fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> {
 		let prefix = services()
 			.rooms
diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs
index 80ac45ae2..9e7f14c73 100644
--- a/src/service/rooms/search/mod.rs
+++ b/src/service/rooms/search/mod.rs
@@ -17,6 +17,11 @@ pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) ->
 		self.db.index_pdu(shortroomid, pdu_id, message_body)
 	}
 
+	#[tracing::instrument(skip(self))]
+	pub fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
+		self.db.deindex_pdu(shortroomid, pdu_id, message_body)
+	}
+
 	#[tracing::instrument(skip(self))]
 	pub fn search_pdus<'a>(
 		&'a self, room_id: &RoomId, search_string: &str,
diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs
index df23da8fb..c7d045e92 100644
--- a/src/service/rooms/timeline/mod.rs
+++ b/src/service/rooms/timeline/mod.rs
@@ -68,6 +68,11 @@ struct ExtractRelatesToEventId {
 	relates_to: ExtractEventId,
 }
 
+#[derive(Deserialize)]
+struct ExtractBody {
+	body: Option<String>,
+}
+
 pub struct Service {
 	pub db: Arc<dyn Data>,
 
@@ -397,7 +402,7 @@ pub async fn append_pdu(
 					| RoomVersionId::V9
 					| RoomVersionId::V10 => {
 						if let Some(redact_id) = &pdu.redacts {
-							self.redact_pdu(redact_id, pdu)?;
+							self.redact_pdu(redact_id, pdu, shortroomid)?;
 						}
 					},
 					RoomVersionId::V11 => {
@@ -407,7 +412,7 @@ pub async fn append_pdu(
 								Error::bad_database("Invalid content in redaction pdu.")
 							})?;
 						if let Some(redact_id) = &content.redacts {
-							self.redact_pdu(redact_id, pdu)?;
+							self.redact_pdu(redact_id, pdu, shortroomid)?;
 						}
 					},
 					_ => {
@@ -463,11 +468,6 @@ pub async fn append_pdu(
 				}
 			},
 			TimelineEventType::RoomMessage => {
-				#[derive(Deserialize)]
-				struct ExtractBody {
-					body: Option<String>,
-				}
-
 				let content = serde_json::from_str::<ExtractBody>(pdu.content.get())
 					.map_err(|_| Error::bad_database("Invalid content in pdu."))?;
 
@@ -984,14 +984,26 @@ pub fn pdus_after<'a>(
 
 	/// Replace a PDU with the redacted form.
 	#[tracing::instrument(skip(self, reason))]
-	pub fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent) -> Result<()> {
+	pub fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent, shortroomid: u64) -> Result<()> {
 		// TODO: Don't reserialize, keep original json
 		if let Some(pdu_id) = self.get_pdu_id(event_id)? {
 			let mut pdu = self
 				.get_pdu_from_id(&pdu_id)?
 				.ok_or_else(|| Error::bad_database("PDU ID points to invalid PDU."))?;
+
+			if let Ok(content) = serde_json::from_str::<ExtractBody>(pdu.content.get()) {
+				if let Some(body) = content.body {
+					services()
+						.rooms
+						.search
+						.deindex_pdu(shortroomid, &pdu_id, &body)?;
+				}
+			}
+
 			let room_version_id = services().rooms.state.get_room_version(&pdu.room_id)?;
+
 			pdu.redact(room_version_id, reason)?;
+
 			self.replace_pdu(
 				&pdu_id,
 				&utils::to_canonical_object(&pdu).map_err(|e| {
@@ -1188,11 +1200,6 @@ pub async fn backfill_pdu(
 		drop(insert_lock);
 
 		if pdu.kind == TimelineEventType::RoomMessage {
-			#[derive(Deserialize)]
-			struct ExtractBody {
-				body: Option<String>,
-			}
-
 			let content = serde_json::from_str::<ExtractBody>(pdu.content.get())
 				.map_err(|_| Error::bad_database("Invalid content in pdu."))?;
 
-- 
GitLab