From ab15ec6c32f2e5463369fe7a29f5ea2e6d9c4f2d Mon Sep 17 00:00:00 2001
From: Tglman <tglman@tglman.com>
Date: Fri, 18 Jun 2021 00:38:32 +0100
Subject: [PATCH] feat: Integration with persy using background ops

---
 Cargo.toml                        |   5 +
 src/database.rs                   |   6 +
 src/database/abstraction.rs       |   5 +-
 src/database/abstraction/persy.rs | 245 ++++++++++++++++++++++++++++++
 src/error.rs                      |  15 ++
 5 files changed, 275 insertions(+), 1 deletion(-)
 create mode 100644 src/database/abstraction/persy.rs

diff --git a/Cargo.toml b/Cargo.toml
index c87d949c0..2dbd3fd3d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -28,6 +28,10 @@ tokio = "1.11.0"
 # Used for storing data permanently
 sled = { version = "0.34.6", features = ["compression", "no_metrics"], optional = true }
 #sled = { git = "https://github.com/spacejam/sled.git", rev = "e4640e0773595229f398438886f19bca6f7326a2", features = ["compression"] }
+persy = { git = "https://gitlab.com/tglman/persy.git", branch="master" , optional = true, features=["background_ops"] }
+# Used by the persy write cache for background flush
+timer = "0.2"
+chrono = "0.4"
 
 # Used for the http request / response body type for Ruma endpoints used with reqwest
 bytes = "1.1.0"
@@ -87,6 +91,7 @@ sha-1 = "0.9.8"
 [features]
 default = ["conduit_bin", "backend_sqlite", "backend_rocksdb"]
 backend_sled = ["sled"]
+backend_persy = ["persy"]
 backend_sqlite = ["sqlite"]
 backend_heed = ["heed", "crossbeam"]
 backend_rocksdb = ["rocksdb"]
diff --git a/src/database.rs b/src/database.rs
index d688ff9fc..c2cd9f291 100644
--- a/src/database.rs
+++ b/src/database.rs
@@ -255,6 +255,12 @@ pub async fn load_or_create(config: &Config) -> Result<Arc<TokioRwLock<Self>>> {
                 #[cfg(feature = "rocksdb")]
                 Arc::new(Arc::<abstraction::rocksdb::Engine>::open(config)?)
             }
+            "persy" => {
+                #[cfg(not(feature = "persy"))]
+                return Err(Error::BadConfig("Database backend not found."));
+                #[cfg(feature = "persy")]
+                Arc::new(Arc::<abstraction::persy::Engine>::open(config)?)
+            }
             _ => {
                 return Err(Error::BadConfig("Database backend not found."));
             }
diff --git a/src/database/abstraction.rs b/src/database/abstraction.rs
index 17bd971fc..9a3771f34 100644
--- a/src/database/abstraction.rs
+++ b/src/database/abstraction.rs
@@ -15,7 +15,10 @@
 #[cfg(feature = "rocksdb")]
 pub mod rocksdb;
 
-#[cfg(any(feature = "sqlite", feature = "rocksdb", feature = "heed"))]
+#[cfg(feature = "persy")]
+pub mod persy;
+
+#[cfg(any(feature = "sqlite", feature = "rocksdb", feature = "heed", feature="persy"))]
 pub mod watchers;
 
 pub trait DatabaseEngine: Send + Sync {
diff --git a/src/database/abstraction/persy.rs b/src/database/abstraction/persy.rs
new file mode 100644
index 000000000..5d633ab43
--- /dev/null
+++ b/src/database/abstraction/persy.rs
@@ -0,0 +1,245 @@
+use crate::{
+    database::{
+        abstraction::{DatabaseEngine, Tree},
+        Config,
+    },
+    Result,
+};
+use persy::{ByteVec, OpenOptions, Persy, Transaction, TransactionConfig, ValueMode};
+
+use std::{
+    collections::HashMap,
+    future::Future,
+    pin::Pin,
+    sync::{Arc, RwLock},
+};
+
+use tokio::sync::oneshot::Sender;
+use tracing::warn;
+
+pub struct PersyEngine {
+    persy: Persy,
+}
+
+impl DatabaseEngine for PersyEngine {
+    fn open(config: &Config) -> Result<Arc<Self>> {
+        let mut cfg = persy::Config::new();
+        cfg.change_cache_size((config.db_cache_capacity_mb * 1024.0 * 1024.0) as u64);
+
+        let persy = OpenOptions::new()
+            .create(true)
+            .config(cfg)
+            .open(&format!("{}/db.persy", config.database_path))?;
+        Ok(Arc::new(PersyEngine { persy }))
+    }
+
+    fn open_tree(self: &Arc<Self>, name: &'static str) -> Result<Arc<dyn Tree>> {
+        // Create if it doesn't exist
+        if !self.persy.exists_index(name)? {
+            let mut tx = self.persy.begin()?;
+            tx.create_index::<ByteVec, ByteVec>(name, ValueMode::Replace)?;
+            tx.prepare()?.commit()?;
+        }
+
+        Ok(Arc::new(PersyTree {
+            persy: self.persy.clone(),
+            name: name.to_owned(),
+            watchers: RwLock::new(HashMap::new()),
+        }))
+    }
+
+    fn flush(self: &Arc<Self>) -> Result<()> {
+        Ok(())
+    }
+}
+
+pub struct PersyTree {
+    persy: Persy,
+    name: String,
+    watchers: RwLock<HashMap<Vec<u8>, Vec<Sender<()>>>>,
+}
+
+impl PersyTree {
+    fn begin(&self) -> Result<Transaction> {
+        Ok(self
+            .persy
+            .begin_with(TransactionConfig::new().set_background_sync(true))?)
+    }
+}
+
+impl Tree for PersyTree {
+    #[tracing::instrument(skip(self, key))]
+    fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
+        let result = self
+            .persy
+            .get::<ByteVec, ByteVec>(&self.name, &ByteVec::from(key))?
+            .next()
+            .map(|v| (*v).to_owned());
+        Ok(result)
+    }
+
+    #[tracing::instrument(skip(self, key, value))]
+    fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
+        self.insert_batch(&mut Some((key.to_owned(), value.to_owned())).into_iter())?;
+        let watchers = self.watchers.read().unwrap();
+        let mut triggered = Vec::new();
+
+        for length in 0..=key.len() {
+            if watchers.contains_key(&key[..length]) {
+                triggered.push(&key[..length]);
+            }
+        }
+
+        drop(watchers);
+
+        if !triggered.is_empty() {
+            let mut watchers = self.watchers.write().unwrap();
+            for prefix in triggered {
+                if let Some(txs) = watchers.remove(prefix) {
+                    for tx in txs {
+                        let _ = tx.send(());
+                    }
+                }
+            }
+        }
+        Ok(())
+    }
+
+    #[tracing::instrument(skip(self, iter))]
+    fn insert_batch<'a>(&self, iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>) -> Result<()> {
+        let mut tx = self.begin()?;
+        for (key, value) in iter {
+            tx.put::<ByteVec, ByteVec>(
+                &self.name,
+                ByteVec::from(key.clone()),
+                ByteVec::from(value),
+            )?;
+        }
+        tx.prepare()?.commit()?;
+        Ok(())
+    }
+
+    #[tracing::instrument(skip(self, iter))]
+    fn increment_batch<'a>(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()> {
+        let mut tx = self.begin()?;
+        for key in iter {
+            let old = tx
+                .get::<ByteVec, ByteVec>(&self.name, &ByteVec::from(key.clone()))?
+                .next()
+                .map(|v| (*v).to_owned());
+            let new = crate::utils::increment(old.as_deref()).unwrap();
+            tx.put::<ByteVec, ByteVec>(&self.name, ByteVec::from(key), ByteVec::from(new))?;
+        }
+        tx.prepare()?.commit()?;
+        Ok(())
+    }
+
+    #[tracing::instrument(skip(self, key))]
+    fn remove(&self, key: &[u8]) -> Result<()> {
+        let mut tx = self.begin()?;
+        tx.remove::<ByteVec, ByteVec>(&self.name, ByteVec::from(key), None)?;
+        tx.prepare()?.commit()?;
+        Ok(())
+    }
+
+    #[tracing::instrument(skip(self))]
+    fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
+        let iter = self.persy.range::<ByteVec, ByteVec, _>(&self.name, ..);
+        match iter {
+            Ok(iter) => Box::new(iter.filter_map(|(k, v)| {
+                v.into_iter()
+                    .map(|val| ((*k).to_owned().into(), (*val).to_owned().into()))
+                    .next()
+            })),
+            Err(e) => {
+                warn!("error iterating {:?}", e);
+                Box::new(std::iter::empty())
+            }
+        }
+    }
+
+    #[tracing::instrument(skip(self, from, backwards))]
+    fn iter_from<'a>(
+        &'a self,
+        from: &[u8],
+        backwards: bool,
+    ) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
+        let range = if backwards {
+            self.persy
+                .range::<ByteVec, ByteVec, _>(&self.name, ..=ByteVec::from(from))
+        } else {
+            self.persy
+                .range::<ByteVec, ByteVec, _>(&self.name, ByteVec::from(from)..)
+        };
+        match range {
+            Ok(iter) => {
+                let map = iter.filter_map(|(k, v)| {
+                    v.into_iter()
+                        .map(|val| ((*k).to_owned().into(), (*val).to_owned().into()))
+                        .next()
+                });
+                if backwards {
+                    Box::new(map.rev())
+                } else {
+                    Box::new(map)
+                }
+            }
+            Err(e) => {
+                warn!("error iterating with prefix {:?}", e);
+                Box::new(std::iter::empty())
+            }
+        }
+    }
+
+    #[tracing::instrument(skip(self, key))]
+    fn increment(&self, key: &[u8]) -> Result<Vec<u8>> {
+        self.increment_batch(&mut Some(key.to_owned()).into_iter())?;
+        Ok(self.get(key)?.unwrap())
+    }
+
+    #[tracing::instrument(skip(self, prefix))]
+    fn scan_prefix<'a>(
+        &'a self,
+        prefix: Vec<u8>,
+    ) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
+        let range_prefix = ByteVec::from(prefix.clone());
+        let range = self
+            .persy
+            .range::<ByteVec, ByteVec, _>(&self.name, range_prefix..);
+
+        match range {
+            Ok(iter) => {
+                let owned_prefix = prefix.clone();
+                Box::new(
+                    iter.take_while(move |(k, _)| (*k).starts_with(&owned_prefix))
+                        .filter_map(|(k, v)| {
+                            v.into_iter()
+                                .map(|val| ((*k).to_owned().into(), (*val).to_owned().into()))
+                                .next()
+                        }),
+                )
+            }
+            Err(e) => {
+                warn!("error scanning prefix {:?}", e);
+                Box::new(std::iter::empty())
+            }
+        }
+    }
+
+    #[tracing::instrument(skip(self, prefix))]
+    fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
+        let (tx, rx) = tokio::sync::oneshot::channel();
+
+        self.watchers
+            .write()
+            .unwrap()
+            .entry(prefix.to_vec())
+            .or_default()
+            .push(tx);
+
+        Box::pin(async move {
+            // Tx is never destroyed
+            rx.await.unwrap();
+        })
+    }
+}
diff --git a/src/error.rs b/src/error.rs
index 4d427da49..5ffe48c97 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -8,6 +8,9 @@
 use thiserror::Error;
 use tracing::warn;
 
+#[cfg(feature = "persy")]
+use persy::PersyError;
+
 #[cfg(feature = "conduit_bin")]
 use {
     crate::RumaResponse,
@@ -36,6 +39,9 @@ pub enum Error {
         #[from]
         source: rusqlite::Error,
     },
+    #[cfg(feature = "persy")]
+    #[error("There was a problem with the connection to the persy database.")]
+    PersyError { source: PersyError },
     #[cfg(feature = "heed")]
     #[error("There was a problem with the connection to the heed database: {error}")]
     HeedError { error: String },
@@ -142,3 +148,12 @@ fn respond_to(self, r: &'r Request<'_>) -> response::Result<'o> {
         self.to_response().respond_to(r)
     }
 }
+
+#[cfg(feature = "persy")]
+impl<T: Into<PersyError>> From<persy::PE<T>> for Error {
+    fn from(err: persy::PE<T>) -> Self {
+        Error::PersyError {
+            source: err.error().into(),
+        }
+    }
+}
-- 
GitLab