Skip to content
Snippets Groups Projects
Commit eb73d8c6 authored by Benjamin Lee's avatar Benjamin Lee Committed by 🥺
Browse files

fix: de-index pdus when redacted


bit of code dedupe as well

Co-authored-by: default avatarstrawberry <strawberry@puppygock.gay>
Signed-off-by: default avatarstrawberry <strawberry@puppygock.gay>
parent 20a54aac
No related branches found
No related tags found
No related merge requests found
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
pub trait Data: Send + Sync { pub trait Data: Send + Sync {
fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()>; fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()>;
fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()>;
fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a>; fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a>;
} }
...@@ -34,6 +36,22 @@ fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Resu ...@@ -34,6 +36,22 @@ fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Resu
self.tokenids.insert_batch(&mut batch) self.tokenids.insert_batch(&mut batch)
} }
fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
let batch = tokenize(message_body).map(|word| {
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(word.as_bytes());
key.push(0xFF);
key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here
key
});
for token in batch {
self.tokenids.remove(&token)?;
}
Ok(())
}
fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> { fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> {
let prefix = services() let prefix = services()
.rooms .rooms
......
...@@ -17,6 +17,11 @@ pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> ...@@ -17,6 +17,11 @@ pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) ->
self.db.index_pdu(shortroomid, pdu_id, message_body) self.db.index_pdu(shortroomid, pdu_id, message_body)
} }
#[tracing::instrument(skip(self))]
pub fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
self.db.deindex_pdu(shortroomid, pdu_id, message_body)
}
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn search_pdus<'a>( pub fn search_pdus<'a>(
&'a self, room_id: &RoomId, search_string: &str, &'a self, room_id: &RoomId, search_string: &str,
......
...@@ -68,6 +68,11 @@ struct ExtractRelatesToEventId { ...@@ -68,6 +68,11 @@ struct ExtractRelatesToEventId {
relates_to: ExtractEventId, relates_to: ExtractEventId,
} }
#[derive(Deserialize)]
struct ExtractBody {
body: Option<String>,
}
pub struct Service { pub struct Service {
pub db: Arc<dyn Data>, pub db: Arc<dyn Data>,
...@@ -397,7 +402,7 @@ pub async fn append_pdu( ...@@ -397,7 +402,7 @@ pub async fn append_pdu(
| RoomVersionId::V9 | RoomVersionId::V9
| RoomVersionId::V10 => { | RoomVersionId::V10 => {
if let Some(redact_id) = &pdu.redacts { if let Some(redact_id) = &pdu.redacts {
self.redact_pdu(redact_id, pdu)?; self.redact_pdu(redact_id, pdu, shortroomid)?;
} }
}, },
RoomVersionId::V11 => { RoomVersionId::V11 => {
...@@ -407,7 +412,7 @@ pub async fn append_pdu( ...@@ -407,7 +412,7 @@ pub async fn append_pdu(
Error::bad_database("Invalid content in redaction pdu.") Error::bad_database("Invalid content in redaction pdu.")
})?; })?;
if let Some(redact_id) = &content.redacts { if let Some(redact_id) = &content.redacts {
self.redact_pdu(redact_id, pdu)?; self.redact_pdu(redact_id, pdu, shortroomid)?;
} }
}, },
_ => { _ => {
...@@ -463,11 +468,6 @@ pub async fn append_pdu( ...@@ -463,11 +468,6 @@ pub async fn append_pdu(
} }
}, },
TimelineEventType::RoomMessage => { TimelineEventType::RoomMessage => {
#[derive(Deserialize)]
struct ExtractBody {
body: Option<String>,
}
let content = serde_json::from_str::<ExtractBody>(pdu.content.get()) let content = serde_json::from_str::<ExtractBody>(pdu.content.get())
.map_err(|_| Error::bad_database("Invalid content in pdu."))?; .map_err(|_| Error::bad_database("Invalid content in pdu."))?;
...@@ -984,14 +984,26 @@ pub fn pdus_after<'a>( ...@@ -984,14 +984,26 @@ pub fn pdus_after<'a>(
/// Replace a PDU with the redacted form. /// Replace a PDU with the redacted form.
#[tracing::instrument(skip(self, reason))] #[tracing::instrument(skip(self, reason))]
pub fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent) -> Result<()> { pub fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent, shortroomid: u64) -> Result<()> {
// TODO: Don't reserialize, keep original json // TODO: Don't reserialize, keep original json
if let Some(pdu_id) = self.get_pdu_id(event_id)? { if let Some(pdu_id) = self.get_pdu_id(event_id)? {
let mut pdu = self let mut pdu = self
.get_pdu_from_id(&pdu_id)? .get_pdu_from_id(&pdu_id)?
.ok_or_else(|| Error::bad_database("PDU ID points to invalid PDU."))?; .ok_or_else(|| Error::bad_database("PDU ID points to invalid PDU."))?;
if let Ok(content) = serde_json::from_str::<ExtractBody>(pdu.content.get()) {
if let Some(body) = content.body {
services()
.rooms
.search
.deindex_pdu(shortroomid, &pdu_id, &body)?;
}
}
let room_version_id = services().rooms.state.get_room_version(&pdu.room_id)?; let room_version_id = services().rooms.state.get_room_version(&pdu.room_id)?;
pdu.redact(room_version_id, reason)?; pdu.redact(room_version_id, reason)?;
self.replace_pdu( self.replace_pdu(
&pdu_id, &pdu_id,
&utils::to_canonical_object(&pdu).map_err(|e| { &utils::to_canonical_object(&pdu).map_err(|e| {
...@@ -1188,11 +1200,6 @@ pub async fn backfill_pdu( ...@@ -1188,11 +1200,6 @@ pub async fn backfill_pdu(
drop(insert_lock); drop(insert_lock);
if pdu.kind == TimelineEventType::RoomMessage { if pdu.kind == TimelineEventType::RoomMessage {
#[derive(Deserialize)]
struct ExtractBody {
body: Option<String>,
}
let content = serde_json::from_str::<ExtractBody>(pdu.content.get()) let content = serde_json::from_str::<ExtractBody>(pdu.content.get())
.map_err(|_| Error::bad_database("Invalid content in pdu."))?; .map_err(|_| Error::bad_database("Invalid content in pdu."))?;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment