aboutsummaryrefslogtreecommitdiffstats
path: root/atuin-client/src/database.rs
diff options
context:
space:
mode:
Diffstat (limited to 'atuin-client/src/database.rs')
-rw-r--r--atuin-client/src/database.rs216
1 files changed, 157 insertions, 59 deletions
diff --git a/atuin-client/src/database.rs b/atuin-client/src/database.rs
index 9efde2cd..d1b892e2 100644
--- a/atuin-client/src/database.rs
+++ b/atuin-client/src/database.rs
@@ -1,3 +1,4 @@
+use std::env;
use std::path::Path;
use std::str::FromStr;
@@ -16,7 +17,29 @@ use sqlx::{
use super::history::History;
use super::ordering;
-use super::settings::SearchMode;
+use super::settings::{FilterMode, SearchMode};
+
+pub struct Context {
+ session: String,
+ cwd: String,
+ hostname: String,
+}
+
+pub fn current_context() -> Context {
+ let session =
+ env::var("ATUIN_SESSION").expect("failed to find ATUIN_SESSION - check your shell setup");
+ let hostname = format!("{}:{}", whoami::hostname(), whoami::username());
+ let cwd = match env::current_dir() {
+ Ok(dir) => dir.display().to_string(),
+ Err(_) => String::from(""),
+ };
+
+ Context {
+ session,
+ hostname,
+ cwd,
+ }
+}
#[async_trait]
pub trait Database {
@@ -24,7 +47,13 @@ pub trait Database {
async fn save_bulk(&mut self, h: &[History]) -> Result<()>;
async fn load(&self, id: &str) -> Result<History>;
- async fn list(&self, max: Option<usize>, unique: bool) -> Result<Vec<History>>;
+ async fn list(
+ &self,
+ filter: FilterMode,
+ context: &Context,
+ max: Option<usize>,
+ unique: bool,
+ ) -> Result<Vec<History>>;
async fn range(
&self,
from: chrono::DateTime<Utc>,
@@ -42,6 +71,8 @@ pub trait Database {
&self,
limit: Option<i64>,
search_mode: SearchMode,
+ filter: FilterMode,
+ context: &Context,
query: &str,
) -> Result<Vec<History>>;
@@ -179,33 +210,52 @@ impl Database for Sqlite {
}
// make a unique list, that only shows the *newest* version of things
- async fn list(&self, max: Option<usize>, unique: bool) -> Result<Vec<History>> {
+ async fn list(
+ &self,
+ filter: FilterMode,
+ context: &Context,
+ max: Option<usize>,
+ unique: bool,
+ ) -> Result<Vec<History>> {
debug!("listing history");
- // very likely vulnerable to SQL injection
- // however, this is client side, and only used by the client, on their
- // own data. They can just open the db file...
- // otherwise building the query is awkward
+ // gotta get that query builder in soon cuz I kinda hate this
+ let query = if unique {
+ "where timestamp = (
+ select max(timestamp) from history
+ where h.command = history.command
+ )"
+ } else {
+ ""
+ }
+ .to_string();
+
+ let mut join = if unique { "and" } else { "where" }.to_string();
+
+ let filter_query = match filter {
+ FilterMode::Global => {
+ join = "".to_string();
+ "".to_string()
+ }
+ FilterMode::Host => format!("hostname = '{}'", context.hostname).to_string(),
+ FilterMode::Session => format!("session = '{}'", context.session).to_string(),
+ FilterMode::Directory => format!("cwd = '{}'", context.cwd).to_string(),
+ };
+
+ let filter = format!("{} {}", join, filter_query);
+
+ let limit = if let Some(max) = max {
+ format!("limit {}", max)
+ } else {
+ "".to_string()
+ };
+
let query = format!(
"select * from history h
{}
order by timestamp desc
- {}",
- // inject the unique check
- if unique {
- "where timestamp = (
- select max(timestamp) from history
- where h.command = history.command
- )"
- } else {
- ""
- },
- // inject the limit
- if let Some(max) = max {
- format!("limit {}", max)
- } else {
- "".to_string()
- }
+ {} {}",
+ query, filter, limit,
);
let res = sqlx::query(query.as_str())
@@ -281,6 +331,8 @@ impl Database for Sqlite {
&self,
limit: Option<i64>,
search_mode: SearchMode,
+ filter: FilterMode,
+ context: &Context,
query: &str,
) -> Result<Vec<History>> {
let orig_query = query;
@@ -350,6 +402,13 @@ impl Database for Sqlite {
}
};
+ let filter_sql = match filter {
+ FilterMode::Global => String::from(""),
+ FilterMode::Session => format!("and session = '{}'", context.session),
+ FilterMode::Directory => format!("and cwd = '{}'", context.cwd),
+ FilterMode::Host => format!("and hostname = '{}'", context.hostname),
+ };
+
let res = query_params
.iter()
.fold(
@@ -357,10 +416,12 @@ impl Database for Sqlite {
format!(
"select * from history h
where {}
+ {}
group by command
having max(timestamp)
order by timestamp desc {}",
query_sql.as_str(),
+ filter_sql.as_str(),
limit.clone()
)
.as_str(),
@@ -392,10 +453,18 @@ mod test {
async fn assert_search_eq<'a>(
db: &impl Database,
mode: SearchMode,
+ filter_mode: FilterMode,
query: &str,
expected: usize,
) -> Result<Vec<History>> {
- let results = db.search(None, mode, query).await?;
+ let context = Context {
+ hostname: "test:host".to_string(),
+ session: "beepboopiamasession".to_string(),
+ cwd: "/home/ellie".to_string(),
+ };
+
+ let results = db.search(None, mode, filter_mode, &context, query).await?;
+
assert_eq!(
results.len(),
expected,
@@ -409,10 +478,11 @@ mod test {
async fn assert_search_commands(
db: &impl Database,
mode: SearchMode,
+ filter_mode: FilterMode,
query: &str,
expected_commands: Vec<&str>,
) {
- let results = assert_search_eq(db, mode, query, expected_commands.len())
+ let results = assert_search_eq(db, mode, filter_mode, query, expected_commands.len())
.await
.unwrap();
let commands: Vec<&str> = results.iter().map(|a| a.command.as_str()).collect();
@@ -437,13 +507,13 @@ mod test {
let mut db = Sqlite::new("sqlite::memory:").await.unwrap();
new_history_item(&mut db, "ls /home/ellie").await.unwrap();
- assert_search_eq(&db, SearchMode::Prefix, "ls", 1)
+ assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls", 1)
.await
.unwrap();
- assert_search_eq(&db, SearchMode::Prefix, "/home", 0)
+ assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "/home", 0)
.await
.unwrap();
- assert_search_eq(&db, SearchMode::Prefix, "ls ", 0)
+ assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls ", 0)
.await
.unwrap();
}
@@ -453,13 +523,13 @@ mod test {
let mut db = Sqlite::new("sqlite::memory:").await.unwrap();
new_history_item(&mut db, "ls /home/ellie").await.unwrap();
- assert_search_eq(&db, SearchMode::FullText, "ls", 1)
+ assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls", 1)
.await
.unwrap();
- assert_search_eq(&db, SearchMode::FullText, "/home", 1)
+ assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home", 1)
.await
.unwrap();
- assert_search_eq(&db, SearchMode::FullText, "ls ", 0)
+ assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls ", 0)
.await
.unwrap();
}
@@ -474,70 +544,82 @@ mod test {
.await
.unwrap();
- assert_search_eq(&db, SearchMode::Fuzzy, "ls /", 3)
+ assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls /", 3)
.await
.unwrap();
- assert_search_eq(&db, SearchMode::Fuzzy, "ls/", 2)
+ assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls/", 2)
.await
.unwrap();
- assert_search_eq(&db, SearchMode::Fuzzy, "l/h/", 2)
+ assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "l/h/", 2)
.await
.unwrap();
- assert_search_eq(&db, SearchMode::Fuzzy, "/h/e", 3)
+ assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e", 3)
.await
.unwrap();
- assert_search_eq(&db, SearchMode::Fuzzy, "/hmoe/", 0)
+ assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/hmoe/", 0)
.await
.unwrap();
- assert_search_eq(&db, SearchMode::Fuzzy, "ellie/home", 0)
+ assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie/home", 0)
.await
.unwrap();
- assert_search_eq(&db, SearchMode::Fuzzy, "lsellie", 1)
+ assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "lsellie", 1)
.await
.unwrap();
- assert_search_eq(&db, SearchMode::Fuzzy, " ", 4)
+ assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, " ", 4)
.await
.unwrap();
// single term operators
- assert_search_eq(&db, SearchMode::Fuzzy, "^ls", 2)
+ assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls", 2)
.await
.unwrap();
- assert_search_eq(&db, SearchMode::Fuzzy, "'ls", 2)
+ assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "'ls", 2)
.await
.unwrap();
- assert_search_eq(&db, SearchMode::Fuzzy, "ellie$", 2)
+ assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie$", 2)
.await
.unwrap();
- assert_search_eq(&db, SearchMode::Fuzzy, "!^ls", 2)
+ assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!^ls", 2)
.await
.unwrap();
- assert_search_eq(&db, SearchMode::Fuzzy, "!ellie", 1)
+ assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie", 1)
.await
.unwrap();
- assert_search_eq(&db, SearchMode::Fuzzy, "!ellie$", 2)
+ assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie$", 2)
.await
.unwrap();
// multiple terms
- assert_search_eq(&db, SearchMode::Fuzzy, "ls !ellie", 1)
+ assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls !ellie", 1)
.await
.unwrap();
- assert_search_eq(&db, SearchMode::Fuzzy, "^ls !e$", 1)
+ assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls !e$", 1)
.await
.unwrap();
- assert_search_eq(&db, SearchMode::Fuzzy, "home !^ls", 2)
- .await
- .unwrap();
- assert_search_eq(&db, SearchMode::Fuzzy, "'frank | 'rustup", 2)
- .await
- .unwrap();
- assert_search_eq(&db, SearchMode::Fuzzy, "'frank | 'rustup 'ls", 1)
+ assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "home !^ls", 2)
.await
.unwrap();
+ assert_search_eq(
+ &db,
+ SearchMode::Fuzzy,
+ FilterMode::Global,
+ "'frank | 'rustup",
+ 2,
+ )
+ .await
+ .unwrap();
+ assert_search_eq(
+ &db,
+ SearchMode::Fuzzy,
+ FilterMode::Global,
+ "'frank | 'rustup 'ls",
+ 1,
+ )
+ .await
+ .unwrap();
// case matching
- assert_search_eq(&db, SearchMode::Fuzzy, "Ellie", 1)
+ assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "Ellie", 1)
.await
.unwrap();
}
@@ -551,18 +633,31 @@ mod test {
new_history_item(&mut db, "corburl").await.unwrap();
// if fuzzy reordering is on, it should come back in a more sensible order
- assert_search_commands(&db, SearchMode::Fuzzy, "curl", vec!["curl", "corburl"]).await;
+ assert_search_commands(
+ &db,
+ SearchMode::Fuzzy,
+ FilterMode::Global,
+ "curl",
+ vec!["curl", "corburl"],
+ )
+ .await;
- assert_search_eq(&db, SearchMode::Fuzzy, "xxxx", 0)
+ assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "xxxx", 0)
.await
.unwrap();
- assert_search_eq(&db, SearchMode::Fuzzy, "", 2)
+ assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "", 2)
.await
.unwrap();
}
#[tokio::test(flavor = "multi_thread")]
async fn test_search_bench_dupes() {
+ let context = Context {
+ hostname: "test:host".to_string(),
+ session: "beepboopiamasession".to_string(),
+ cwd: "/home/ellie".to_string(),
+ };
+
let mut db = Sqlite::new("sqlite::memory:").await.unwrap();
for _i in 1..10000 {
new_history_item(&mut db, "i am a duplicated command")
@@ -570,7 +665,10 @@ mod test {
.unwrap();
}
let start = Instant::now();
- let _results = db.search(None, SearchMode::Fuzzy, "").await.unwrap();
+ let _results = db
+ .search(None, SearchMode::Fuzzy, FilterMode::Global, &context, "")
+ .await
+ .unwrap();
let duration = start.elapsed();
assert!(duration < Duration::from_secs(15));