diff options
Diffstat (limited to 'atuin-client/src/database.rs')
| -rw-r--r-- | atuin-client/src/database.rs | 216 |
1 files changed, 157 insertions, 59 deletions
diff --git a/atuin-client/src/database.rs b/atuin-client/src/database.rs index 9efde2cd..d1b892e2 100644 --- a/atuin-client/src/database.rs +++ b/atuin-client/src/database.rs @@ -1,3 +1,4 @@ +use std::env; use std::path::Path; use std::str::FromStr; @@ -16,7 +17,29 @@ use sqlx::{ use super::history::History; use super::ordering; -use super::settings::SearchMode; +use super::settings::{FilterMode, SearchMode}; + +pub struct Context { + session: String, + cwd: String, + hostname: String, +} + +pub fn current_context() -> Context { + let session = + env::var("ATUIN_SESSION").expect("failed to find ATUIN_SESSION - check your shell setup"); + let hostname = format!("{}:{}", whoami::hostname(), whoami::username()); + let cwd = match env::current_dir() { + Ok(dir) => dir.display().to_string(), + Err(_) => String::from(""), + }; + + Context { + session, + hostname, + cwd, + } +} #[async_trait] pub trait Database { @@ -24,7 +47,13 @@ pub trait Database { async fn save_bulk(&mut self, h: &[History]) -> Result<()>; async fn load(&self, id: &str) -> Result<History>; - async fn list(&self, max: Option<usize>, unique: bool) -> Result<Vec<History>>; + async fn list( + &self, + filter: FilterMode, + context: &Context, + max: Option<usize>, + unique: bool, + ) -> Result<Vec<History>>; async fn range( &self, from: chrono::DateTime<Utc>, @@ -42,6 +71,8 @@ pub trait Database { &self, limit: Option<i64>, search_mode: SearchMode, + filter: FilterMode, + context: &Context, query: &str, ) -> Result<Vec<History>>; @@ -179,33 +210,52 @@ impl Database for Sqlite { } // make a unique list, that only shows the *newest* version of things - async fn list(&self, max: Option<usize>, unique: bool) -> Result<Vec<History>> { + async fn list( + &self, + filter: FilterMode, + context: &Context, + max: Option<usize>, + unique: bool, + ) -> Result<Vec<History>> { debug!("listing history"); - // very likely vulnerable to SQL injection - // however, this is client side, and only used by the client, on their - // own data. They can just open the db file... - // otherwise building the query is awkward + // gotta get that query builder in soon cuz I kinda hate this + let query = if unique { + "where timestamp = ( + select max(timestamp) from history + where h.command = history.command + )" + } else { + "" + } + .to_string(); + + let mut join = if unique { "and" } else { "where" }.to_string(); + + let filter_query = match filter { + FilterMode::Global => { + join = "".to_string(); + "".to_string() + } + FilterMode::Host => format!("hostname = '{}'", context.hostname).to_string(), + FilterMode::Session => format!("session = '{}'", context.session).to_string(), + FilterMode::Directory => format!("cwd = '{}'", context.cwd).to_string(), + }; + + let filter = format!("{} {}", join, filter_query); + + let limit = if let Some(max) = max { + format!("limit {}", max) + } else { + "".to_string() + }; + let query = format!( "select * from history h {} order by timestamp desc - {}", - // inject the unique check - if unique { - "where timestamp = ( - select max(timestamp) from history - where h.command = history.command - )" - } else { - "" - }, - // inject the limit - if let Some(max) = max { - format!("limit {}", max) - } else { - "".to_string() - } + {} {}", + query, filter, limit, ); let res = sqlx::query(query.as_str()) @@ -281,6 +331,8 @@ impl Database for Sqlite { &self, limit: Option<i64>, search_mode: SearchMode, + filter: FilterMode, + context: &Context, query: &str, ) -> Result<Vec<History>> { let orig_query = query; @@ -350,6 +402,13 @@ impl Database for Sqlite { } }; + let filter_sql = match filter { + FilterMode::Global => String::from(""), + FilterMode::Session => format!("and session = '{}'", context.session), + FilterMode::Directory => format!("and cwd = '{}'", context.cwd), + FilterMode::Host => format!("and hostname = '{}'", context.hostname), + }; + let res = query_params .iter() .fold( @@ -357,10 +416,12 @@ impl Database for Sqlite { format!( "select * from history h where {} + {} group by command having max(timestamp) order by timestamp desc {}", query_sql.as_str(), + filter_sql.as_str(), limit.clone() ) .as_str(), @@ -392,10 +453,18 @@ mod test { async fn assert_search_eq<'a>( db: &impl Database, mode: SearchMode, + filter_mode: FilterMode, query: &str, expected: usize, ) -> Result<Vec<History>> { - let results = db.search(None, mode, query).await?; + let context = Context { + hostname: "test:host".to_string(), + session: "beepboopiamasession".to_string(), + cwd: "/home/ellie".to_string(), + }; + + let results = db.search(None, mode, filter_mode, &context, query).await?; + assert_eq!( results.len(), expected, @@ -409,10 +478,11 @@ mod test { async fn assert_search_commands( db: &impl Database, mode: SearchMode, + filter_mode: FilterMode, query: &str, expected_commands: Vec<&str>, ) { - let results = assert_search_eq(db, mode, query, expected_commands.len()) + let results = assert_search_eq(db, mode, filter_mode, query, expected_commands.len()) .await .unwrap(); let commands: Vec<&str> = results.iter().map(|a| a.command.as_str()).collect(); @@ -437,13 +507,13 @@ mod test { let mut db = Sqlite::new("sqlite::memory:").await.unwrap(); new_history_item(&mut db, "ls /home/ellie").await.unwrap(); - assert_search_eq(&db, SearchMode::Prefix, "ls", 1) + assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls", 1) .await .unwrap(); - assert_search_eq(&db, SearchMode::Prefix, "/home", 0) + assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "/home", 0) .await .unwrap(); - assert_search_eq(&db, SearchMode::Prefix, "ls ", 0) + assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls ", 0) .await .unwrap(); } @@ -453,13 +523,13 @@ mod test { let mut db = Sqlite::new("sqlite::memory:").await.unwrap(); new_history_item(&mut db, "ls /home/ellie").await.unwrap(); - assert_search_eq(&db, SearchMode::FullText, "ls", 1) + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls", 1) .await .unwrap(); - assert_search_eq(&db, SearchMode::FullText, "/home", 1) + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home", 1) .await .unwrap(); - assert_search_eq(&db, SearchMode::FullText, "ls ", 0) + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls ", 0) .await .unwrap(); } @@ -474,70 +544,82 @@ mod test { .await .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, "ls /", 3) + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls /", 3) .await .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, "ls/", 2) + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls/", 2) .await .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, "l/h/", 2) + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "l/h/", 2) .await .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, "/h/e", 3) + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e", 3) .await .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, "/hmoe/", 0) + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/hmoe/", 0) .await .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, "ellie/home", 0) + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie/home", 0) .await .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, "lsellie", 1) + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "lsellie", 1) .await .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, " ", 4) + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, " ", 4) .await .unwrap(); // single term operators - assert_search_eq(&db, SearchMode::Fuzzy, "^ls", 2) + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls", 2) .await .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, "'ls", 2) + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "'ls", 2) .await .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, "ellie$", 2) + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie$", 2) .await .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, "!^ls", 2) + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!^ls", 2) .await .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, "!ellie", 1) + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie", 1) .await .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, "!ellie$", 2) + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie$", 2) .await .unwrap(); // multiple terms - assert_search_eq(&db, SearchMode::Fuzzy, "ls !ellie", 1) + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls !ellie", 1) .await .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, "^ls !e$", 1) + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls !e$", 1) .await .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, "home !^ls", 2) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, "'frank | 'rustup", 2) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, "'frank | 'rustup 'ls", 1) + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "home !^ls", 2) .await .unwrap(); + assert_search_eq( + &db, + SearchMode::Fuzzy, + FilterMode::Global, + "'frank | 'rustup", + 2, + ) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::Fuzzy, + FilterMode::Global, + "'frank | 'rustup 'ls", + 1, + ) + .await + .unwrap(); // case matching - assert_search_eq(&db, SearchMode::Fuzzy, "Ellie", 1) + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "Ellie", 1) .await .unwrap(); } @@ -551,18 +633,31 @@ mod test { new_history_item(&mut db, "corburl").await.unwrap(); // if fuzzy reordering is on, it should come back in a more sensible order - assert_search_commands(&db, SearchMode::Fuzzy, "curl", vec!["curl", "corburl"]).await; + assert_search_commands( + &db, + SearchMode::Fuzzy, + FilterMode::Global, + "curl", + vec!["curl", "corburl"], + ) + .await; - assert_search_eq(&db, SearchMode::Fuzzy, "xxxx", 0) + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "xxxx", 0) .await .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, "", 2) + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "", 2) .await .unwrap(); } #[tokio::test(flavor = "multi_thread")] async fn test_search_bench_dupes() { + let context = Context { + hostname: "test:host".to_string(), + session: "beepboopiamasession".to_string(), + cwd: "/home/ellie".to_string(), + }; + let mut db = Sqlite::new("sqlite::memory:").await.unwrap(); for _i in 1..10000 { new_history_item(&mut db, "i am a duplicated command") @@ -570,7 +665,10 @@ mod test { .unwrap(); } let start = Instant::now(); - let _results = db.search(None, SearchMode::Fuzzy, "").await.unwrap(); + let _results = db + .search(None, SearchMode::Fuzzy, FilterMode::Global, &context, "") + .await + .unwrap(); let duration = start.elapsed(); assert!(duration < Duration::from_secs(15)); |
