From 137e3008ea04d36f9562eeadc61b276032fd2ddf Mon Sep 17 00:00:00 2001
From: Jason Volk <jason@zemos.net>
Date: Wed, 6 Nov 2024 21:02:23 +0000
Subject: [PATCH] merge rooms threads data and service

Signed-off-by: Jason Volk <jason@zemos.net>
---
 src/api/client/relations.rs            | 12 ++--
 src/api/client/threads.rs              | 10 ++-
 src/service/rooms/pdu_metadata/data.rs | 15 ++---
 src/service/rooms/threads/data.rs      | 90 --------------------------
 src/service/rooms/threads/mod.rs       | 88 +++++++++++++++++++------
 5 files changed, 91 insertions(+), 124 deletions(-)
 delete mode 100644 src/service/rooms/threads/data.rs

diff --git a/src/api/client/relations.rs b/src/api/client/relations.rs
index ef7035e2f..b5d1485bd 100644
--- a/src/api/client/relations.rs
+++ b/src/api/client/relations.rs
@@ -97,7 +97,7 @@ async fn paginate_relations_with_filter(
 	filter_event_type: Option<TimelineEventType>, filter_rel_type: Option<RelationType>, from: Option<&str>,
 	to: Option<&str>, limit: Option<UInt>, recurse: bool, dir: Direction,
 ) -> Result<get_relating_events::v1::Response> {
-	let from: PduCount = from
+	let start: PduCount = from
 		.map(str::parse)
 		.transpose()?
 		.unwrap_or_else(|| match dir {
@@ -124,7 +124,7 @@ async fn paginate_relations_with_filter(
 	let events: Vec<PdusIterItem> = services
 		.rooms
 		.pdu_metadata
-		.get_relations(sender_user, room_id, target, from, limit, depth, dir)
+		.get_relations(sender_user, room_id, target, start, limit, depth, dir)
 		.await
 		.into_iter()
 		.filter(|(_, pdu)| {
@@ -146,16 +146,20 @@ async fn paginate_relations_with_filter(
 		.await;
 
 	let next_batch = match dir {
-		Direction::Backward => events.first(),
 		Direction::Forward => events.last(),
+		Direction::Backward => events.first(),
 	}
 	.map(at!(0))
+	.map(|count| match dir {
+		Direction::Forward => count.saturating_add(1),
+		Direction::Backward => count.saturating_sub(1),
+	})
 	.as_ref()
 	.map(ToString::to_string);
 
 	Ok(get_relating_events::v1::Response {
 		next_batch,
-		prev_batch: Some(from.to_string()),
+		prev_batch: from.map(Into::into),
 		recursion_depth: recurse.then_some(depth.into()),
 		chunk: events
 			.into_iter()
diff --git a/src/api/client/threads.rs b/src/api/client/threads.rs
index 02cf79926..8d4e399bb 100644
--- a/src/api/client/threads.rs
+++ b/src/api/client/threads.rs
@@ -1,5 +1,5 @@
 use axum::extract::State;
-use conduit::{PduCount, PduEvent};
+use conduit::{at, PduCount, PduEvent};
 use futures::StreamExt;
 use ruma::{api::client::threads::get_threads, uint};
 
@@ -44,12 +44,16 @@ pub(crate) async fn get_threads_route(
 	Ok(get_threads::v1::Response {
 		next_batch: threads
 			.last()
-			.map(|(count, _)| count)
+			.filter(|_| threads.len() >= limit)
+			.map(at!(0))
+			.map(|count| count.saturating_sub(1))
+			.as_ref()
 			.map(ToString::to_string),
 
 		chunk: threads
 			.into_iter()
-			.map(|(_, pdu)| pdu.to_room_event())
+			.map(at!(1))
+			.map(|pdu| pdu.to_room_event())
 			.collect(),
 	})
 }
diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs
index 3fc065915..f3e1ced8b 100644
--- a/src/service/rooms/pdu_metadata/data.rs
+++ b/src/service/rooms/pdu_metadata/data.rs
@@ -1,5 +1,6 @@
 use std::{mem::size_of, sync::Arc};
 
+use arrayvec::ArrayVec;
 use conduit::{
 	result::LogErr,
 	utils::{stream::TryIgnore, u64_from_u8, ReadyExt},
@@ -54,15 +55,13 @@ pub(super) fn add_relation(&self, from: u64, to: u64) {
 	pub(super) fn get_relations<'a>(
 		&'a self, user_id: &'a UserId, shortroomid: ShortRoomId, target: ShortEventId, from: PduCount, dir: Direction,
 	) -> impl Stream<Item = PdusIterItem> + Send + '_ {
-		let current: RawPduId = PduId {
-			shortroomid,
-			shorteventid: from,
-		}
-		.into();
-
+		let mut current = ArrayVec::<u8, 16>::new();
+		current.extend(target.to_be_bytes());
+		current.extend(from.into_unsigned().to_be_bytes());
+		let current = current.as_slice();
 		match dir {
-			Direction::Forward => self.tofrom_relation.raw_keys_from(&current).boxed(),
-			Direction::Backward => self.tofrom_relation.rev_raw_keys_from(&current).boxed(),
+			Direction::Forward => self.tofrom_relation.raw_keys_from(current).boxed(),
+			Direction::Backward => self.tofrom_relation.rev_raw_keys_from(current).boxed(),
 		}
 		.ignore_err()
 		.ready_take_while(move |key| key.starts_with(&target.to_be_bytes()))
diff --git a/src/service/rooms/threads/data.rs b/src/service/rooms/threads/data.rs
deleted file mode 100644
index c26dabb40..000000000
--- a/src/service/rooms/threads/data.rs
+++ /dev/null
@@ -1,90 +0,0 @@
-use std::sync::Arc;
-
-use conduit::{
-	result::LogErr,
-	utils::{stream::TryIgnore, ReadyExt},
-	PduCount, PduEvent, Result,
-};
-use database::{Deserialized, Map};
-use futures::{Stream, StreamExt};
-use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId};
-
-use crate::{
-	rooms,
-	rooms::{
-		short::ShortRoomId,
-		timeline::{PduId, RawPduId},
-	},
-	Dep,
-};
-
-pub(super) struct Data {
-	threadid_userids: Arc<Map>,
-	services: Services,
-}
-
-struct Services {
-	short: Dep<rooms::short::Service>,
-	timeline: Dep<rooms::timeline::Service>,
-}
-
-impl Data {
-	pub(super) fn new(args: &crate::Args<'_>) -> Self {
-		let db = &args.db;
-		Self {
-			threadid_userids: db["threadid_userids"].clone(),
-			services: Services {
-				short: args.depend::<rooms::short::Service>("rooms::short"),
-				timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
-			},
-		}
-	}
-
-	#[inline]
-	pub(super) async fn threads_until<'a>(
-		&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount, _include: &'a IncludeThreads,
-	) -> Result<impl Stream<Item = (PduCount, PduEvent)> + Send + 'a> {
-		let shortroomid: ShortRoomId = self.services.short.get_shortroomid(room_id).await?;
-
-		let current: RawPduId = PduId {
-			shortroomid,
-			shorteventid: until.saturating_sub(1),
-		}
-		.into();
-
-		let stream = self
-			.threadid_userids
-			.rev_raw_keys_from(&current)
-			.ignore_err()
-			.map(RawPduId::from)
-			.ready_take_while(move |pdu_id| pdu_id.shortroomid() == shortroomid.to_be_bytes())
-			.filter_map(move |pdu_id| async move {
-				let mut pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await.ok()?;
-				let pdu_id: PduId = pdu_id.into();
-
-				if pdu.sender != user_id {
-					pdu.remove_transaction_id().log_err().ok();
-				}
-
-				Some((pdu_id.shorteventid, pdu))
-			});
-
-		Ok(stream)
-	}
-
-	pub(super) fn update_participants(&self, root_id: &RawPduId, participants: &[OwnedUserId]) -> Result {
-		let users = participants
-			.iter()
-			.map(|user| user.as_bytes())
-			.collect::<Vec<_>>()
-			.join(&[0xFF][..]);
-
-		self.threadid_userids.insert(root_id, &users);
-
-		Ok(())
-	}
-
-	pub(super) async fn get_participants(&self, root_id: &RawPduId) -> Result<Vec<OwnedUserId>> {
-		self.threadid_userids.get(root_id).await.deserialized()
-	}
-}
diff --git a/src/service/rooms/threads/mod.rs b/src/service/rooms/threads/mod.rs
index 025030307..fcc629e1c 100644
--- a/src/service/rooms/threads/mod.rs
+++ b/src/service/rooms/threads/mod.rs
@@ -1,34 +1,44 @@
-mod data;
-
 use std::{collections::BTreeMap, sync::Arc};
 
-use conduit::{err, PduCount, PduEvent, Result};
-use data::Data;
-use futures::Stream;
+use conduit::{
+	err,
+	utils::{stream::TryIgnore, ReadyExt},
+	PduCount, PduEvent, PduId, RawPduId, Result,
+};
+use database::{Deserialized, Map};
+use futures::{Stream, StreamExt};
 use ruma::{
 	api::client::threads::get_threads::v1::IncludeThreads, events::relation::BundledThread, uint, CanonicalJsonValue,
-	EventId, RoomId, UserId,
+	EventId, OwnedUserId, RoomId, UserId,
 };
 use serde_json::json;
 
-use crate::{rooms, Dep};
+use crate::{rooms, rooms::short::ShortRoomId, Dep};
 
 pub struct Service {
-	services: Services,
 	db: Data,
+	services: Services,
 }
 
 struct Services {
+	short: Dep<rooms::short::Service>,
 	timeline: Dep<rooms::timeline::Service>,
 }
 
+pub(super) struct Data {
+	threadid_userids: Arc<Map>,
+}
+
 impl crate::Service for Service {
 	fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
 		Ok(Arc::new(Self {
+			db: Data {
+				threadid_userids: args.db["threadid_userids"].clone(),
+			},
 			services: Services {
+				short: args.depend::<rooms::short::Service>("rooms::short"),
 				timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
 			},
-			db: Data::new(&args),
 		}))
 	}
 
@@ -36,14 +46,6 @@ fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
 }
 
 impl Service {
-	pub async fn threads_until<'a>(
-		&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount, include: &'a IncludeThreads,
-	) -> Result<impl Stream<Item = (PduCount, PduEvent)> + Send + 'a> {
-		self.db
-			.threads_until(user_id, room_id, until, include)
-			.await
-	}
-
 	pub async fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<()> {
 		let root_id = self
 			.services
@@ -113,13 +115,61 @@ pub async fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Re
 		}
 
 		let mut users = Vec::new();
-		if let Ok(userids) = self.db.get_participants(&root_id).await {
+		if let Ok(userids) = self.get_participants(&root_id).await {
 			users.extend_from_slice(&userids);
 		} else {
 			users.push(root_pdu.sender);
 		}
 		users.push(pdu.sender.clone());
 
-		self.db.update_participants(&root_id, &users)
+		self.update_participants(&root_id, &users)
+	}
+
+	pub async fn threads_until<'a>(
+		&'a self, user_id: &'a UserId, room_id: &'a RoomId, shorteventid: PduCount, _inc: &'a IncludeThreads,
+	) -> Result<impl Stream<Item = (PduCount, PduEvent)> + Send + 'a> {
+		let shortroomid: ShortRoomId = self.services.short.get_shortroomid(room_id).await?;
+
+		let current: RawPduId = PduId {
+			shortroomid,
+			shorteventid,
+		}
+		.into();
+
+		let stream = self
+			.db
+			.threadid_userids
+			.rev_raw_keys_from(&current)
+			.ignore_err()
+			.map(RawPduId::from)
+			.ready_take_while(move |pdu_id| pdu_id.shortroomid() == shortroomid.to_be_bytes())
+			.filter_map(move |pdu_id| async move {
+				let mut pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await.ok()?;
+				let pdu_id: PduId = pdu_id.into();
+
+				if pdu.sender != user_id {
+					pdu.remove_transaction_id().ok();
+				}
+
+				Some((pdu_id.shorteventid, pdu))
+			});
+
+		Ok(stream)
+	}
+
+	pub(super) fn update_participants(&self, root_id: &RawPduId, participants: &[OwnedUserId]) -> Result {
+		let users = participants
+			.iter()
+			.map(|user| user.as_bytes())
+			.collect::<Vec<_>>()
+			.join(&[0xFF][..]);
+
+		self.db.threadid_userids.insert(root_id, &users);
+
+		Ok(())
+	}
+
+	pub(super) async fn get_participants(&self, root_id: &RawPduId) -> Result<Vec<OwnedUserId>> {
+		self.db.threadid_userids.get(root_id).await.deserialized()
 	}
 }
-- 
GitLab