aboutsummaryrefslogtreecommitdiffstats
path: root/crates/atuin-client/src/database.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/atuin-client/src/database.rs')
-rw-r--r--crates/atuin-client/src/database.rs192
1 files changed, 192 insertions, 0 deletions
diff --git a/crates/atuin-client/src/database.rs b/crates/atuin-client/src/database.rs
index 5f292bec..7c63368d 100644
--- a/crates/atuin-client/src/database.rs
+++ b/crates/atuin-client/src/database.rs
@@ -138,9 +138,13 @@ pub trait Database: Send + Sync + 'static {
async fn all_with_count(&self) -> Result<Vec<(History, i32)>>;
+ fn all_paged(&self, page_size: usize, include_deleted: bool, unique: bool) -> Paged;
+
async fn stats(&self, h: &History) -> Result<HistoryStats>;
async fn get_dups(&self, before: i64, dupkeep: u32) -> Result<Vec<History>>;
+
+ fn clone_boxed(&self) -> Box<dyn Database + 'static>;
}
// Intended for use on a developer machine and not a sync server.
@@ -650,6 +654,10 @@ impl Database for Sqlite {
Ok(res)
}
+ fn all_paged(&self, page_size: usize, include_deleted: bool, unique: bool) -> Paged {
+ Paged::new(Box::new(self.clone()), page_size, include_deleted, unique)
+ }
+
// deleted_at doesn't mean the actual time that the user deleted it,
// but the time that the system marks it as deleted
async fn delete(&self, mut h: History) -> Result<()> {
@@ -814,6 +822,70 @@ impl Database for Sqlite {
Ok(res)
}
+
+ fn clone_boxed(&self) -> Box<dyn Database + 'static> {
+ Box::new(self.clone())
+ }
+}
+
+pub struct Paged {
+ database: Box<dyn Database + 'static>,
+ page_size: usize,
+ last_id: Option<String>,
+ include_deleted: bool,
+ unique: bool,
+}
+
+impl Paged {
+ pub fn new(
+ database: Box<dyn Database + 'static>,
+ page_size: usize,
+ include_deleted: bool,
+ unique: bool,
+ ) -> Self {
+ Self {
+ database,
+ page_size,
+ last_id: None,
+ include_deleted,
+ unique,
+ }
+ }
+
+ pub async fn next(&mut self) -> Result<Option<Vec<History>>> {
+ let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted());
+
+ query.field("*").order_desc("id");
+
+ if !self.include_deleted {
+ query.and_where_is_null("deleted_at");
+ }
+
+ if self.unique {
+ // We want to deduplicate on command, but the user can search via cwd, hostname, and session.
+ // Without those fields, filter modes won't work right. With those fields, we get duplicates.
+ // This must be handled upstream.
+ query
+ .group_by("command, cwd, hostname, session")
+ .having("max(timestamp)");
+ }
+
+ query.limit(self.page_size);
+
+ if let Some(last_id) = &self.last_id {
+ query.and_where_lt("id", quote(last_id));
+ }
+
+ let query = query.sql().expect("bug in list query. please report");
+ let res = self.database.query_history(&query).await?;
+
+ if res.is_empty() {
+ Ok(None)
+ } else {
+ self.last_id = Some(res.last().unwrap().id.0.clone());
+ Ok(Some(res))
+ }
+ }
}
trait SqlBuilderExt {
@@ -1166,6 +1238,126 @@ mod test {
}
#[tokio::test(flavor = "multi_thread")]
+ async fn test_paged_basic() {
+ let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
+ .await
+ .unwrap();
+
+ // Add 5 history items
+ for i in 0..5 {
+ new_history_item(&mut db, &format!("command{}", i))
+ .await
+ .unwrap();
+ }
+
+ // Create a paged iterator with page_size of 2
+ let mut paged = db.all_paged(2, false, false);
+
+ // First page should have 2 items
+ let page1 = paged.next().await.unwrap();
+ assert!(page1.is_some());
+ assert_eq!(page1.unwrap().len(), 2);
+
+ // Second page should have 2 items
+ let page2 = paged.next().await.unwrap();
+ assert!(page2.is_some());
+ assert_eq!(page2.unwrap().len(), 2);
+
+ // Third page should have 1 item
+ let page3 = paged.next().await.unwrap();
+ assert!(page3.is_some());
+ assert_eq!(page3.unwrap().len(), 1);
+
+ // Fourth page should be None (exhausted)
+ let page4 = paged.next().await.unwrap();
+ assert!(page4.is_none());
+ }
+
+ #[tokio::test(flavor = "multi_thread")]
+ async fn test_paged_empty() {
+ let db = Sqlite::new("sqlite::memory:", test_local_timeout())
+ .await
+ .unwrap();
+
+ // Create a paged iterator on empty database
+ let mut paged = db.all_paged(10, false, false);
+
+ // Should return None immediately
+ let page = paged.next().await.unwrap();
+ assert!(page.is_none());
+ }
+
+ #[tokio::test(flavor = "multi_thread")]
+ async fn test_paged_unique() {
+ let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
+ .await
+ .unwrap();
+
+ // Add duplicate commands
+ new_history_item(&mut db, "duplicate").await.unwrap();
+ new_history_item(&mut db, "duplicate").await.unwrap();
+ new_history_item(&mut db, "unique1").await.unwrap();
+ new_history_item(&mut db, "unique2").await.unwrap();
+
+ // Without unique flag - should get all 4
+ let mut paged = db.all_paged(10, false, false);
+ let page = paged.next().await.unwrap().unwrap();
+ assert_eq!(page.len(), 4);
+
+ // With unique flag - should get 3 (duplicates collapsed)
+ let mut paged_unique = db.all_paged(10, false, true);
+ let page_unique = paged_unique.next().await.unwrap().unwrap();
+ assert_eq!(page_unique.len(), 3);
+ }
+
+ #[tokio::test(flavor = "multi_thread")]
+ async fn test_paged_include_deleted() {
+ let mut db = Sqlite::new("sqlite::memory:", test_local_timeout())
+ .await
+ .unwrap();
+
+ // Add items
+ new_history_item(&mut db, "keep1").await.unwrap();
+ new_history_item(&mut db, "keep2").await.unwrap();
+ new_history_item(&mut db, "delete_me").await.unwrap();
+
+ // Delete one item
+ let all = db
+ .list(
+ &[],
+ &Context {
+ hostname: "".to_string(),
+ session: "".to_string(),
+ cwd: "".to_string(),
+ host_id: "".to_string(),
+ git_root: None,
+ },
+ None,
+ false,
+ false,
+ )
+ .await
+ .unwrap();
+
+ let to_delete = all
+ .iter()
+ .find(|h| h.command == "delete_me")
+ .unwrap()
+ .clone();
+ db.delete(to_delete).await.unwrap();
+
+ // Without include_deleted - should get 2
+ let mut paged = db.all_paged(10, false, false);
+ let page = paged.next().await.unwrap().unwrap();
+ assert_eq!(page.len(), 2);
+
+ // With include_deleted - should get 3
+ let mut paged_deleted = db.all_paged(10, true, false);
+ let page_deleted = paged_deleted.next().await.unwrap().unwrap();
+ assert_eq!(page_deleted.len(), 3);
+ }
+
+ #[tokio::test(flavor = "multi_thread")]
async fn test_search_bench_dupes() {
let context = Context {
hostname: "test:host".to_string(),