From dba0575e7584f4b5440b61df401fbd50462b8df3 Mon Sep 17 00:00:00 2001
From: Jason Volk <jason@zemos.net>
Date: Wed, 10 Apr 2024 13:55:09 -0700
Subject: [PATCH] some optimizations to get_auth_chain()

Signed-off-by: Jason Volk <jason@zemos.net>
---
 src/api/server_server.rs               |   8 +-
 src/service/admin/debug.rs             |   2 +-
 src/service/rooms/auth_chain/mod.rs    | 103 +++++++++++++++----------
 src/service/rooms/event_handler/mod.rs |   4 +-
 4 files changed, 71 insertions(+), 46 deletions(-)

diff --git a/src/api/server_server.rs b/src/api/server_server.rs
index f01ff273e..85f3e0ab9 100644
--- a/src/api/server_server.rs
+++ b/src/api/server_server.rs
@@ -734,7 +734,7 @@ pub async fn get_event_authorization_route(
 	let auth_chain_ids = services()
 		.rooms
 		.auth_chain
-		.get_auth_chain(room_id, vec![Arc::from(&*body.event_id)])
+		.event_ids_iter(room_id, vec![Arc::from(&*body.event_id)])
 		.await?;
 
 	Ok(get_event_authorization::v1::Response {
@@ -794,7 +794,7 @@ pub async fn get_room_state_route(body: Ruma<get_room_state::v1::Request>) -> Re
 	let auth_chain_ids = services()
 		.rooms
 		.auth_chain
-		.get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)])
+		.event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)])
 		.await?;
 
 	Ok(get_room_state::v1::Response {
@@ -854,7 +854,7 @@ pub async fn get_room_state_ids_route(
 	let auth_chain_ids = services()
 		.rooms
 		.auth_chain
-		.get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)])
+		.event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)])
 		.await?;
 
 	Ok(get_room_state_ids::v1::Response {
@@ -1142,7 +1142,7 @@ async fn create_join_event(
 	let auth_chain_ids = services()
 		.rooms
 		.auth_chain
-		.get_auth_chain(room_id, state_ids.values().cloned().collect())
+		.event_ids_iter(room_id, state_ids.values().cloned().collect())
 		.await?;
 
 	services().sending.send_pdu_room(room_id, &pdu_id)?;
diff --git a/src/service/admin/debug.rs b/src/service/admin/debug.rs
index e2775c939..c27f5900c 100644
--- a/src/service/admin/debug.rs
+++ b/src/service/admin/debug.rs
@@ -104,7 +104,7 @@ pub(crate) async fn process(command: DebugCommand, body: Vec<&str>) -> Result<Ro
 				let count = services()
 					.rooms
 					.auth_chain
-					.get_auth_chain(room_id, vec![event_id])
+					.event_ids_iter(room_id, vec![event_id])
 					.await?
 					.count();
 				let elapsed = start.elapsed();
diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs
index 03c49faf8..f3bf50f82 100644
--- a/src/service/rooms/auth_chain/mod.rs
+++ b/src/service/rooms/auth_chain/mod.rs
@@ -15,28 +15,47 @@ pub struct Service {
 }
 
 impl Service {
-	pub async fn get_auth_chain<'a>(
-		&self, room_id: &RoomId, starting_events: Vec<Arc<EventId>>,
+	/// Computes the auth chain of `starting_events_` via [`Self::get_auth_chain`] and returns
+	/// an iterator mapping each short id back to its event id; failed lookups are skipped.
+	pub async fn event_ids_iter<'a>(
+		&self, room_id: &RoomId, starting_events_: Vec<Arc<EventId>>,
 	) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> {
-		const NUM_BUCKETS: usize = 50;
-
-		let mut buckets = vec![BTreeSet::new(); NUM_BUCKETS];
-
-		let mut i = 0;
-		for id in starting_events {
-			let short = services().rooms.short.get_or_create_shorteventid(&id)?;
-			let bucket_id = (short % NUM_BUCKETS as u64) as usize;
-			buckets[bucket_id].insert((short, id.clone()));
-			i += 1;
-			if i % 100 == 0 {
-				tokio::task::yield_now().await;
-			}
+		// Borrow the ids out of their `Arc`s; `get_auth_chain` takes a slice of `&EventId`.
+		let starting_events: Vec<&EventId> = starting_events_.iter().map(|id| &**id).collect();
+
+		Ok(self
+			.get_auth_chain(room_id, &starting_events)
+			.await?
+			.into_iter()
+			.filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok()))
+	}
+
+	pub async fn get_auth_chain(&self, room_id: &RoomId, starting_events: &[&EventId]) -> Result<Vec<u64>> {
+		const NUM_BUCKETS: usize = 50; //TODO: change possible w/o disrupting db?
+		const BUCKET: BTreeSet<(u64, &EventId)> = BTreeSet::new();
+
+		let started = std::time::Instant::now();
+		let mut buckets = [BUCKET; NUM_BUCKETS];
+		for (i, short) in services()
+			.rooms
+			.short
+			.multi_get_or_create_shorteventid(starting_events)?
+			.iter()
+			.enumerate()
+		{
+			let bucket = short % NUM_BUCKETS as u64;
+			buckets[bucket as usize].insert((*short, starting_events[i]));
 		}
 
-		let mut full_auth_chain = HashSet::new();
+		debug!(
+			starting_events = ?starting_events.len(),
+			elapsed = ?started.elapsed(),
+			"start",
+		);
 
 		let mut hits = 0;
 		let mut misses = 0;
+		let mut full_auth_chain = Vec::new();
 		for chunk in buckets {
 			if chunk.is_empty() {
 				continue;
@@ -48,68 +67,68 @@ pub async fn get_auth_chain<'a>(
 				.auth_chain
 				.get_cached_eventid_authchain(&chunk_key)?
 			{
-				hits += 1;
 				full_auth_chain.extend(cached.iter().copied());
+				hits += 1;
 				continue;
 			}
-			misses += 1;
 
-			let mut chunk_cache = HashSet::new();
 			let mut hits2 = 0;
 			let mut misses2 = 0;
-			let mut i = 0;
+			let mut chunk_cache = Vec::new();
 			for (sevent_id, event_id) in chunk {
 				if let Some(cached) = services()
 					.rooms
 					.auth_chain
 					.get_cached_eventid_authchain(&[sevent_id])?
 				{
-					hits2 += 1;
 					chunk_cache.extend(cached.iter().copied());
+					hits2 += 1;
 				} else {
-					misses2 += 1;
-					let auth_chain = Arc::new(self.get_auth_chain_inner(room_id, &event_id)?);
+					let auth_chain = self.get_auth_chain_inner(room_id, event_id)?;
 					services()
 						.rooms
 						.auth_chain
 						.cache_auth_chain(vec![sevent_id], &auth_chain)?;
+					chunk_cache.extend(auth_chain.iter());
+					misses2 += 1;
 					debug!(
 						event_id = ?event_id,
 						chain_length = ?auth_chain.len(),
+						chunk_cache_length = ?chunk_cache.len(),
+						elapsed = ?started.elapsed(),
 						"Cache missed event"
 					);
-					chunk_cache.extend(auth_chain.iter());
-
-					i += 1;
-					if i % 100 == 0 {
-						tokio::task::yield_now().await;
-					}
 				};
 			}
+
+			chunk_cache.sort_unstable();
+			chunk_cache.dedup();
+			services()
+				.rooms
+				.auth_chain
+				.cache_auth_chain_vec(chunk_key, &chunk_cache)?;
+			full_auth_chain.extend(chunk_cache.iter());
+			misses += 1;
 			debug!(
 				chunk_cache_length = ?chunk_cache.len(),
 				hits = ?hits2,
 				misses = ?misses2,
+				elapsed = ?started.elapsed(),
 				"Chunk missed",
 			);
-			let chunk_cache = Arc::new(chunk_cache);
-			services()
-				.rooms
-				.auth_chain
-				.cache_auth_chain(chunk_key, &chunk_cache)?;
-			full_auth_chain.extend(chunk_cache.iter());
 		}
 
+		full_auth_chain.sort();
+		full_auth_chain.dedup();
 		debug!(
 			chain_length = ?full_auth_chain.len(),
 			hits = ?hits,
 			misses = ?misses,
-			"Auth chain stats",
+			elapsed = ?started.elapsed(),
+			"done",
 		);
 
-		Ok(full_auth_chain
-			.into_iter()
-			.filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok()))
+		Ok(full_auth_chain)
 	}
 
 	#[tracing::instrument(skip(self, event_id))]
@@ -155,4 +174,10 @@ pub fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: &HashSet<u64>) -> Resu
 		self.db
 			.cache_auth_chain(key, auth_chain.iter().copied().collect::<Arc<[u64]>>())
 	}
+
+	/// Stores `auth_chain` in the auth chain cache under `key`.
+	#[tracing::instrument(skip(self))]
+	pub fn cache_auth_chain_vec(&self, key: Vec<u64>, auth_chain: &[u64]) -> Result<()> {
+		self.db.cache_auth_chain(key, Arc::from(auth_chain))
+	}
 }
diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs
index 0f89f5016..43fcf8f79 100644
--- a/src/service/rooms/event_handler/mod.rs
+++ b/src/service/rooms/event_handler/mod.rs
@@ -578,7 +578,7 @@ pub async fn upgrade_outlier_to_timeline_pdu(
 						services()
 							.rooms
 							.auth_chain
-							.get_auth_chain(room_id, starting_events)
+							.event_ids_iter(room_id, starting_events)
 							.await?
 							.collect(),
 					);
@@ -909,7 +909,7 @@ async fn resolve_state(
 				services()
 					.rooms
 					.auth_chain
-					.get_auth_chain(room_id, state.iter().map(|(_, id)| id.clone()).collect())
+					.event_ids_iter(room_id, state.iter().map(|(_, id)| id.clone()).collect())
 					.await?
 					.collect(),
 			);
-- 
GitLab