diff options
| author | Benedikt Peetz <benedikt.peetz@b-peetz.de> | 2026-06-12 01:54:21 +0200 |
|---|---|---|
| committer | Benedikt Peetz <benedikt.peetz@b-peetz.de> | 2026-06-12 01:54:21 +0200 |
| commit | bbdf38018b47328b5faa2cef635c37095045be72 (patch) | |
| tree | 8983817d547551ae12508a8ae8731b622d990af4 | |
| parent | feat(server): Make user stuff stateless (diff) | |
| download | atuin-bbdf38018b47328b5faa2cef635c37095045be72.zip | |
feat(server): Really make users stateless (with tests)
This commit also remove another load of unneeded features.
119 files changed, 2761 insertions, 4310 deletions
diff --git a/crates/turtle/db/server-pg-migrations/20230623070418_records.sql b/crates/turtle/db/server-pg-migrations/20230623070418_records.sql index 22437595..a3e5de2e 100644 --- a/crates/turtle/db/server-pg-migrations/20230623070418_records.sql +++ b/crates/turtle/db/server-pg-migrations/20230623070418_records.sql @@ -8,7 +8,7 @@ create table records ( version text not null, tag text not null, -- what is this? history, kv, whatever. Remember clients get a log per tag per host data text not null, -- store the actual history data, encrypted. I don't wanna know! - cek text not null, + cek text not null, user_id bigint not null, -- allow multiple users created_at timestamp not null default current_timestamp diff --git a/crates/turtle/db/server-pg-migrations/20240614104159_idx-cache.sql b/crates/turtle/db/server-pg-migrations/20240614104159_idx-cache.sql index 76425ed7..12bbbecc 100644 --- a/crates/turtle/db/server-pg-migrations/20240614104159_idx-cache.sql +++ b/crates/turtle/db/server-pg-migrations/20240614104159_idx-cache.sql @@ -1,8 +1,8 @@ -create table store_idx_cache( - id bigserial primary key, - user_id bigint, +CREATE TABLE store_idx_cache( + id BIGSERIAL PRIMARY KEY, + user_id UUID, - host uuid, - tag text, - idx bigint + host UUID, + tag TEXT, + idx BIGINT ); diff --git a/crates/turtle/db/server-pg-migrations/20260611222503_make_user-id_an_uuid.sql b/crates/turtle/db/server-pg-migrations/20260611222503_make_user-id_an_uuid.sql new file mode 100644 index 00000000..d31c23e2 --- /dev/null +++ b/crates/turtle/db/server-pg-migrations/20260611222503_make_user-id_an_uuid.sql @@ -0,0 +1,11 @@ +-- Add migration script here + +ALTER TABLE records +DROP COLUMN IF EXISTS user_id; +ALTER TABLE records +ADD COLUMN user_id UUID NOT NULL; + +ALTER TABLE store +DROP COLUMN IF EXISTS user_id; +ALTER TABLE store +ADD COLUMN user_id UUID NOT NULL; diff --git a/crates/turtle/src/atuin_client/api_client.rs b/crates/turtle/src/atuin_client/api_client.rs index b4657a47..15d96d93 100644 --- a/crates/turtle/src/atuin_client/api_client.rs +++ b/crates/turtle/src/atuin_client/api_client.rs @@ -2,52 +2,33 @@ use std::env; use std::time::Duration; use eyre::{Result, bail, eyre}; -use reqwest::{ - Response, StatusCode, Url, - header::{AUTHORIZATION, HeaderMap}, -}; +use reqwest::{Response, StatusCode, Url, header::HeaderMap}; use tracing::debug; +use uuid::Uuid; +use crate::atuin_common::{api::ErrorResponse, record::RecordStatus}; use crate::atuin_common::{ api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ATUIN_VERSION}, record::{EncryptedData, HostId, Record, RecordIdx}, tls::ensure_crypto_provider, }; -use crate::atuin_common::{ - api::{ErrorResponse, MeResponse}, - record::RecordStatus, -}; use semver::Version; static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION"),); -/// Authentication token for sync API requests. -/// -/// Used with `Token <token>` header. -#[derive(Debug, Clone)] -pub(crate) struct AuthToken(pub(crate) String); - -impl AuthToken { - /// Format the token as an Authorization header value - fn to_header_value(&self) -> String { - format!("Token {}", self.0) - } -} - pub(crate) struct Client<'a> { sync_addr: &'a str, client: reqwest::Client, + user_id: Uuid, } -fn make_url(address: &str, path: &str) -> Result<String> { +fn make_url(address: &str, path: &str, user_id: Uuid) -> Result<String> { + let address = address.strip_suffix('/').unwrap_or(address); + // `join()` expects a trailing `/` in order to join paths // e.g. it treats `http://host:port/subdir` as a file called `subdir` - let address = if address.ends_with('/') { - address - } else { - &format!("{address}/") - }; + let address = &format!("{address}/api/v0/{}/", user_id.to_string()); // passing a path with a leading `/` will cause `join()` to replace the entire URL path let path = path.strip_prefix("/").unwrap_or(path); @@ -123,18 +104,18 @@ async fn handle_resp_error(resp: Response) -> Result<Response> { impl<'a> Client<'a> { pub(crate) fn new( sync_addr: &'a str, - auth: AuthToken, connect_timeout: u64, timeout: u64, + user_id: Uuid, ) -> Result<Self> { ensure_crypto_provider(); let mut headers = HeaderMap::new(); - headers.insert(AUTHORIZATION, auth.to_header_value().parse()?); // used for semver server check headers.insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse()?); Ok(Client { + user_id, sync_addr, client: reqwest::Client::builder() .user_agent(APP_USER_AGENT) @@ -145,20 +126,8 @@ impl<'a> Client<'a> { }) } - pub(crate) async fn me(&self) -> Result<MeResponse> { - let url = make_url(self.sync_addr, "/api/v0/me")?; - let url = Url::parse(url.as_str())?; - - let resp = self.client.get(url).send().await?; - let resp = handle_resp_error(resp).await?; - - let status = resp.json::<MeResponse>().await?; - - Ok(status) - } - pub(crate) async fn delete_store(&self) -> Result<()> { - let url = make_url(self.sync_addr, "/api/v0/store")?; + let url = make_url(self.sync_addr, "/store", self.user_id)?; let url = Url::parse(url.as_str())?; let resp = self.client.delete(url).send().await?; @@ -169,7 +138,7 @@ impl<'a> Client<'a> { } pub(crate) async fn post_records(&self, records: &[Record<EncryptedData>]) -> Result<()> { - let url = make_url(self.sync_addr, "/api/v0/record")?; + let url = make_url(self.sync_addr, "/record", self.user_id)?; let url = Url::parse(url.as_str())?; debug!("uploading {} records to {url}", records.len()); @@ -192,9 +161,10 @@ impl<'a> Client<'a> { let url = make_url( self.sync_addr, &format!( - "/api/v0/record/next?host={}&tag={}&count={}&start={}", + "/record/next?host={}&tag={}&count={}&start={}", host.0, tag, count, start ), + self.user_id, )?; let url = Url::parse(url.as_str())?; @@ -208,7 +178,7 @@ impl<'a> Client<'a> { } pub(crate) async fn record_status(&self) -> Result<RecordStatus> { - let url = make_url(self.sync_addr, "/api/v0/record")?; + let url = make_url(self.sync_addr, "/record", self.user_id)?; let url = Url::parse(url.as_str())?; let resp = self.client.get(url).send().await?; diff --git a/crates/turtle/src/atuin_client/database.rs b/crates/turtle/src/atuin_client/database.rs index 1bfe93a7..f8b73809 100644 --- a/crates/turtle/src/atuin_client/database.rs +++ b/crates/turtle/src/atuin_client/database.rs @@ -5,9 +5,7 @@ use std::{ time::Duration, }; -use crate::atuin_client::history::{AUTHOR_FILTER_ALL_AGENT, AUTHOR_FILTER_ALL_USER, KNOWN_AGENTS}; use crate::atuin_common::utils; -use async_trait::async_trait; use fs_err as fs; use itertools::Itertools; use rand::{Rng, distributions::Alphanumeric}; @@ -55,8 +53,6 @@ pub(crate) struct OptFilters { pub(crate) offset: Option<i64>, pub(crate) reverse: bool, pub(crate) include_duplicates: bool, - /// Author filter. Supports special values `$all-user` and `$all-agent`. - pub(crate) authors: Vec<String>, } pub(crate) async fn current_context() -> eyre::Result<Context> { @@ -80,47 +76,15 @@ pub(crate) async fn current_context() -> eyre::Result<Context> { impl Context { pub(crate) fn from_history(entry: &History) -> Self { Context { - session: entry.session.to_string(), - cwd: entry.cwd.to_string(), - hostname: entry.hostname.to_string(), + session: entry.session.clone(), + cwd: entry.cwd.clone(), + hostname: entry.hostname.clone(), host_id: String::new(), git_root: utils::in_git_repo(entry.cwd.as_str()), } } } -/// Each entry is OR'd: `$all-user` → NOT IN agents, `$all-agent` → IN agents, literal → exact match. -fn apply_author_filter(sql: &mut SqlBuilder, authors: &[String]) { - let mut conditions: Vec<String> = Vec::new(); - let agent_list: String = KNOWN_AGENTS.iter().map(quote).join(", "); - let author_expr = "CASE \ - WHEN author IS NULL OR trim(author) = '' THEN \ - CASE \ - WHEN instr(hostname, ':') > 0 THEN substr(hostname, instr(hostname, ':') + 1) \ - ELSE hostname \ - END \ - ELSE author \ - END"; - - for author in authors { - match author.as_str() { - AUTHOR_FILTER_ALL_USER => { - conditions.push(format!("{author_expr} NOT IN ({agent_list})")); - } - AUTHOR_FILTER_ALL_AGENT => { - conditions.push(format!("{author_expr} IN ({agent_list})")); - } - literal => { - conditions.push(format!("{author_expr} = {}", quote(literal))); - } - } - } - - if !conditions.is_empty() { - sql.and_where(format!("({})", conditions.join(" OR "))); - } -} - fn get_session_start_time(session_id: &str) -> Option<i64> { if let Ok(uuid) = Uuid::parse_str(session_id) && let Some(timestamp) = uuid.get_timestamp() @@ -131,73 +95,22 @@ fn get_session_start_time(session_id: &str) -> Option<i64> { None } -#[async_trait] -pub(crate) trait Database: Send + Sync + 'static { - async fn save(&self, h: &History) -> Result<()>; - async fn save_bulk(&self, h: &[History]) -> Result<()>; - - async fn load(&self, id: &str) -> Result<Option<History>>; - async fn list( - &self, - filters: &[FilterMode], - context: &Context, - max: Option<usize>, - unique: bool, - include_deleted: bool, - ) -> Result<Vec<History>>; - async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>>; - - async fn update(&self, h: &History) -> Result<()>; - async fn history_count(&self, include_deleted: bool) -> Result<i64>; - - async fn last(&self) -> Result<Option<History>>; - async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>>; - - async fn delete(&self, h: History) -> Result<()>; - async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()>; - async fn deleted(&self) -> Result<Vec<History>>; - - // Yes I know, it's a lot. - // Could maybe break it down to a searchparams struct or smth but that feels a little... pointless. - // Been debating maybe a DSL for search? eg "before:time limit:1 the query" - #[expect(clippy::too_many_arguments)] - async fn search( - &self, - search_mode: SearchMode, - filter: FilterMode, - context: &Context, - query: &str, - filter_options: OptFilters, - ) -> Result<Vec<History>>; - - async fn query_history(&self, query: &str) -> Result<Vec<History>>; - - async fn all_with_count(&self) -> Result<Vec<(History, i32)>>; - - fn all_paged(&self, page_size: usize, include_deleted: bool, unique: bool) -> Paged; - - async fn stats(&self, h: &History) -> Result<HistoryStats>; - - async fn get_dups(&self, before: i64, dupkeep: u32) -> Result<Vec<History>>; - - fn clone_boxed(&self) -> Box<dyn Database + 'static>; -} - // Intended for use on a developer machine and not a sync server. // TODO: implement IntoIterator #[derive(Debug, Clone)] -pub(crate) struct Sqlite { +pub(crate) struct ClientSqlite { pub(crate) pool: SqlitePool, } -impl Sqlite { +impl ClientSqlite { pub(crate) async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> { let path = path.as_ref(); debug!("opening sqlite database at {path:?}"); if utils::broken_symlink(path) { eprintln!( - "Atuin: Sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement." + "Atuin: Sqlite db path ({}) is a broken symlink. Unable to read or create replacement.", + path.display() ); std::process::exit(1); } @@ -224,12 +137,6 @@ impl Sqlite { Ok(Self { pool }) } - pub(crate) async fn sqlite_version(&self) -> Result<String> { - sqlx::query_scalar("SELECT sqlite_version()") - .fetch_one(&self.pool) - .await - } - async fn setup_db(pool: &SqlitePool) -> Result<()> { debug!("running sqlite database setup"); @@ -272,7 +179,7 @@ impl Sqlite { Ok(()) } - fn query_history(row: SqliteRow) -> History { + fn query_history_inner(row: SqliteRow) -> History { let deleted_at: Option<i64> = row.get("deleted_at"); let hostname: String = row.get("hostname"); let author: Option<String> = row.try_get("author").ok().flatten(); @@ -304,9 +211,8 @@ impl Sqlite { } } -#[async_trait] -impl Database for Sqlite { - async fn save(&self, h: &History) -> Result<()> { +impl ClientSqlite { + pub(crate) async fn save(&self, h: &History) -> Result<()> { debug!("saving history to sqlite"); let mut tx = self.pool.begin().await?; Self::save_raw(&mut tx, h).await?; @@ -315,7 +221,7 @@ impl Database for Sqlite { Ok(()) } - async fn save_bulk(&self, h: &[History]) -> Result<()> { + pub(crate) async fn save_bulk(&self, h: &[History]) -> Result<()> { debug!("saving history to sqlite"); let mut tx = self.pool.begin().await?; @@ -329,19 +235,19 @@ impl Database for Sqlite { Ok(()) } - async fn load(&self, id: &str) -> Result<Option<History>> { + pub(crate) async fn load(&self, id: &str) -> Result<Option<History>> { debug!("loading history item {}", id); let res = sqlx::query("select * from history where id = ?1") .bind(id) - .map(Self::query_history) + .map(Self::query_history_inner) .fetch_optional(&self.pool) .await?; Ok(res) } - async fn update(&self, h: &History) -> Result<()> { + pub(crate) async fn update(&self, h: &History) -> Result<()> { debug!("updating sqlite history"); sqlx::query( @@ -367,7 +273,7 @@ impl Database for Sqlite { } // make a unique list, that only shows the *newest* version of things - async fn list( + pub(crate) async fn list( &self, filters: &[FilterMode], context: &Context, @@ -419,14 +325,18 @@ impl Database for Sqlite { let query = query.sql().expect("bug in list query. please report"); let res = sqlx::query(&query) - .map(Self::query_history) + .map(Self::query_history_inner) .fetch_all(&self.pool) .await?; Ok(res) } - async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>> { + pub(crate) async fn range( + &self, + from: OffsetDateTime, + to: OffsetDateTime, + ) -> Result<Vec<History>> { debug!("listing history from {:?} to {:?}", from, to); let res = sqlx::query( @@ -434,47 +344,25 @@ impl Database for Sqlite { ) .bind(from.unix_timestamp_nanos() as i64) .bind(to.unix_timestamp_nanos() as i64) - .map(Self::query_history) + .map(Self::query_history_inner) .fetch_all(&self.pool) .await?; Ok(res) } - async fn last(&self) -> Result<Option<History>> { + pub(crate) async fn last(&self) -> Result<Option<History>> { let res = sqlx::query( "select * from history where duration >= 0 order by timestamp desc limit 1", ) - .map(Self::query_history) + .map(Self::query_history_inner) .fetch_optional(&self.pool) .await?; Ok(res) } - async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>> { - let res = sqlx::query( - "select * from history where timestamp < ?1 order by timestamp desc limit ?2", - ) - .bind(timestamp.unix_timestamp_nanos() as i64) - .bind(count) - .map(Self::query_history) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } - - async fn deleted(&self) -> Result<Vec<History>> { - let res = sqlx::query("select * from history where deleted_at is not null") - .map(Self::query_history) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } - - async fn history_count(&self, include_deleted: bool) -> Result<i64> { + pub(crate) async fn history_count(&self, include_deleted: bool) -> Result<i64> { let query = if include_deleted { "select count(1) from history" } else { @@ -485,7 +373,10 @@ impl Database for Sqlite { Ok(res.0) } - async fn search( + // Yes I know, it's a lot. + // Could maybe break it down to a searchparams struct or smth but that feels a little... pointless. + // Been debating maybe a DSL for search? eg "before:time limit:1 the query" + pub(crate) async fn search( &self, search_mode: SearchMode, filter: FilterMode, @@ -541,54 +432,53 @@ impl Database for Sqlite { let orig_query = query; let mut regexes = Vec::new(); - match search_mode { - SearchMode::Prefix => sql.and_where_like_left("command", query.replace('*', "%")), - _ => { - let mut is_or = false; - for token in QueryTokenizer::new(query) { - // TODO smart case mode could be made configurable like in fzf - let (is_glob, glob) = if token.has_uppercase() { - (true, "*") - } else { - (false, "%") - }; - let param = match token { - QueryToken::Regex(r) => { - regexes.push(String::from(r)); + if search_mode == SearchMode::Prefix { + sql.and_where_like_left("command", query.replace('*', "%")) + } else { + let mut is_or = false; + for token in QueryTokenizer::new(query) { + // TODO smart case mode could be made configurable like in fzf + let (is_glob, glob) = if token.has_uppercase() { + (true, "*") + } else { + (false, "%") + }; + let param = match token { + QueryToken::Regex(r) => { + regexes.push(String::from(r)); + continue; + } + QueryToken::Or => { + if !is_or { + is_or = true; continue; + } else { + format!("{glob}|{glob}") } - QueryToken::Or => { - if !is_or { - is_or = true; - continue; - } else { - format!("{glob}|{glob}") - } - } - QueryToken::MatchStart(term, _) => { - format!("{term}{glob}") - } - QueryToken::MatchEnd(term, _) => { - format!("{glob}{term}") - } - QueryToken::MatchFull(term, _) => { + } + QueryToken::MatchStart(term, _) => { + format!("{term}{glob}") + } + QueryToken::MatchEnd(term, _) => { + format!("{glob}{term}") + } + QueryToken::MatchFull(term, _) => { + format!("{glob}{term}{glob}") + } + QueryToken::Match(term, _) => { + if search_mode == SearchMode::FullText { format!("{glob}{term}{glob}") + } else { + term.split("").join(glob) } - QueryToken::Match(term, _) => { - if search_mode == SearchMode::FullText { - format!("{glob}{term}{glob}") - } else { - term.split("").join(glob) - } - } - }; - - sql.fuzzy_condition("command", param, token.is_inverse(), is_glob, is_or); - is_or = false; - } + } + }; - &mut sql + sql.fuzzy_condition("command", param, token.is_inverse(), is_glob, is_or); + is_or = false; } + + &mut sql }; for regex in regexes { @@ -631,32 +521,28 @@ impl Database for Sqlite { .map(|after| sql.and_where_gt("timestamp", quote(after.unix_timestamp_nanos() as i64))) }); - if !filter_options.authors.is_empty() { - apply_author_filter(&mut sql, &filter_options.authors); - } - sql.and_where_is_null("deleted_at"); let query = sql.sql().expect("bug in search query. please report"); let res = sqlx::query(&query) - .map(Self::query_history) + .map(Self::query_history_inner) .fetch_all(&self.pool) .await?; Ok(ordering::reorder_fuzzy(search_mode, orig_query, res)) } - async fn query_history(&self, query: &str) -> Result<Vec<History>> { + pub(crate) async fn query_history(&self, query: &str) -> Result<Vec<History>> { let res = sqlx::query(query) - .map(Self::query_history) + .map(Self::query_history_inner) .fetch_all(&self.pool) .await?; Ok(res) } - async fn all_with_count(&self) -> Result<Vec<(History, i32)>> { + pub(crate) async fn all_with_count(&self) -> Result<Vec<(History, i32)>> { debug!("listing history"); let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted()); @@ -686,7 +572,7 @@ impl Database for Sqlite { let res = sqlx::query(&query) .map(|row: SqliteRow| { let count: i32 = row.get("count"); - (Self::query_history(row), count) + (Self::query_history_inner(row), count) }) .fetch_all(&self.pool) .await?; @@ -694,13 +580,13 @@ impl Database for Sqlite { Ok(res) } - fn all_paged(&self, page_size: usize, include_deleted: bool, unique: bool) -> Paged { - Paged::new(Box::new(self.clone()), page_size, include_deleted, unique) + pub(crate) fn all_paged(&self, page_size: usize, include_deleted: bool, unique: bool) -> Paged { + Paged::new(self.clone(), page_size, include_deleted, unique) } // deleted_at doesn't mean the actual time that the user deleted it, // but the time that the system marks it as deleted - async fn delete(&self, mut h: History) -> Result<()> { + pub(crate) async fn delete(&self, mut h: History) -> Result<()> { let now = OffsetDateTime::now_utc(); h.command = rand::thread_rng() .sample_iter(&Alphanumeric) @@ -714,7 +600,7 @@ impl Database for Sqlite { Ok(()) } - async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()> { + pub(crate) async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()> { let mut tx = self.pool.begin().await?; for id in ids { @@ -726,7 +612,7 @@ impl Database for Sqlite { Ok(()) } - async fn stats(&self, h: &History) -> Result<HistoryStats> { + pub(crate) async fn stats(&self, h: &History) -> Result<HistoryStats> { // We select the previous in the session by time let mut prev = SqlBuilder::select_from("history"); prev.field("*") @@ -791,14 +677,14 @@ impl Database for Sqlite { let prev = sqlx::query(&prev) .bind(h.timestamp.unix_timestamp_nanos() as i64) .bind(&h.session) - .map(Self::query_history) + .map(Self::query_history_inner) .fetch_optional(&self.pool) .await?; let next = sqlx::query(&next) .bind(h.timestamp.unix_timestamp_nanos() as i64) .bind(&h.session) - .map(Self::query_history) + .map(Self::query_history_inner) .fetch_optional(&self.pool) .await?; @@ -843,7 +729,7 @@ impl Database for Sqlite { }) } - async fn get_dups(&self, before: i64, dupkeep: u32) -> Result<Vec<History>> { + pub(crate) async fn get_dups(&self, before: i64, dupkeep: u32) -> Result<Vec<History>> { let res = sqlx::query( "SELECT * FROM ( SELECT *, ROW_NUMBER() @@ -856,20 +742,16 @@ impl Database for Sqlite { ) .bind(dupkeep) .bind(before) - .map(Self::query_history) + .map(Self::query_history_inner) .fetch_all(&self.pool) .await?; Ok(res) } - - fn clone_boxed(&self) -> Box<dyn Database + 'static> { - Box::new(self.clone()) - } } pub(crate) struct Paged { - database: Box<dyn Database + 'static>, + database: ClientSqlite, page_size: usize, last_id: Option<String>, include_deleted: bool, @@ -878,7 +760,7 @@ pub(crate) struct Paged { impl Paged { pub(crate) fn new( - database: Box<dyn Database + 'static>, + database: ClientSqlite, page_size: usize, include_deleted: bool, unique: bool, @@ -976,7 +858,7 @@ mod test { use std::time::{Duration, Instant}; async fn assert_search_eq( - db: &impl Database, + db: &ClientSqlite, mode: SearchMode, filter_mode: FilterMode, query: &str, @@ -1013,7 +895,7 @@ mod test { } async fn assert_search_commands( - db: &impl Database, + db: &ClientSqlite, mode: SearchMode, filter_mode: FilterMode, query: &str, @@ -1026,7 +908,7 @@ mod test { assert_eq!(commands, expected_commands); } - async fn new_history_item(db: &mut impl Database, cmd: &str) -> Result<()> { + async fn new_history_item(db: &mut ClientSqlite, cmd: &str) -> Result<()> { let mut captured: History = History::capture() .timestamp(OffsetDateTime::now_utc()) .command(cmd) @@ -1044,7 +926,7 @@ mod test { #[tokio::test(flavor = "multi_thread")] async fn test_search_prefix() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + let mut db = ClientSqlite::new("sqlite::memory:", test_local_timeout()) .await .unwrap(); new_history_item(&mut db, "ls /home/ellie").await.unwrap(); @@ -1062,7 +944,7 @@ mod test { #[tokio::test(flavor = "multi_thread")] async fn test_search_fulltext() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + let mut db = ClientSqlite::new("sqlite::memory:", test_local_timeout()) .await .unwrap(); new_history_item(&mut db, "ls /home/ellie").await.unwrap(); @@ -1148,7 +1030,7 @@ mod test { #[tokio::test(flavor = "multi_thread")] async fn test_search_fuzzy() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + let mut db = ClientSqlite::new("sqlite::memory:", test_local_timeout()) .await .unwrap(); new_history_item(&mut db, "ls /home/ellie").await.unwrap(); @@ -1251,7 +1133,7 @@ mod test { #[tokio::test(flavor = "multi_thread")] async fn test_search_reordered_fuzzy() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + let mut db = ClientSqlite::new("sqlite::memory:", test_local_timeout()) .await .unwrap(); // test ordering of results: we should choose the first, even though it happened longer ago. @@ -1279,7 +1161,7 @@ mod test { #[tokio::test(flavor = "multi_thread")] async fn test_paged_basic() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + let mut db = ClientSqlite::new("sqlite::memory:", test_local_timeout()) .await .unwrap(); @@ -1315,7 +1197,7 @@ mod test { #[tokio::test(flavor = "multi_thread")] async fn test_paged_empty() { - let db = Sqlite::new("sqlite::memory:", test_local_timeout()) + let db = ClientSqlite::new("sqlite::memory:", test_local_timeout()) .await .unwrap(); @@ -1329,7 +1211,7 @@ mod test { #[tokio::test(flavor = "multi_thread")] async fn test_paged_unique() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + let mut db = ClientSqlite::new("sqlite::memory:", test_local_timeout()) .await .unwrap(); @@ -1352,7 +1234,7 @@ mod test { #[tokio::test(flavor = "multi_thread")] async fn test_paged_include_deleted() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + let mut db = ClientSqlite::new("sqlite::memory:", test_local_timeout()) .await .unwrap(); @@ -1407,7 +1289,7 @@ mod test { git_root: None, }; - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + let mut db = ClientSqlite::new("sqlite::memory:", test_local_timeout()) .await .unwrap(); for _i in 1..10000 { @@ -1448,7 +1330,7 @@ pub(crate) enum QueryToken<'a> { Regex(&'a str), } -impl<'a> QueryToken<'a> { +impl QueryToken<'_> { pub(crate) fn has_uppercase(&self) -> bool { match self { Self::Match(term, _) diff --git a/crates/turtle/src/atuin_client/distro.rs b/crates/turtle/src/atuin_client/distro.rs deleted file mode 100644 index 00b92fe9..00000000 --- a/crates/turtle/src/atuin_client/distro.rs +++ /dev/null @@ -1,89 +0,0 @@ -use std::process::Command; - -/// Detect the Linux distribution from the system, -/// using system-specific release files and falling -/// back to lsb_release. -pub(crate) fn detect_linux_distribution() -> String { - detect_from_os_release() - .or_else(detect_from_debian_version) - .or_else(detect_from_centos_release) - .or_else(detect_from_redhat_release) - .or_else(detect_from_fedora_release) - .or_else(detect_from_arch_release) - .or_else(detect_from_alpine_release) - .or_else(detect_from_suse_release) - .or_else(detect_from_lsb_release) - .unwrap_or_else(|| "Unknown".to_string()) -} - -fn detect_from_os_release() -> Option<String> { - let content = std::fs::read_to_string("/etc/os-release").ok()?; - - content - .lines() - .find(|l| l.starts_with("PRETTY_NAME=")) - .and_then(|l| l.split_once('=').map(|s| s.1)) - .map(|s| s.trim_matches('"').to_string()) -} - -fn detect_from_debian_version() -> Option<String> { - std::fs::read_to_string("/etc/debian_version") - .ok() - .map(|v| format!("Debian {}", v.trim())) -} - -fn detect_from_centos_release() -> Option<String> { - std::fs::read_to_string("/etc/centos-release") - .ok() - .map(|v| v.trim().to_string()) -} - -fn detect_from_redhat_release() -> Option<String> { - std::fs::read_to_string("/etc/redhat-release") - .ok() - .map(|v| v.trim().to_string()) -} - -fn detect_from_fedora_release() -> Option<String> { - std::fs::read_to_string("/etc/fedora-release") - .ok() - .map(|v| v.trim().to_string()) -} - -fn detect_from_arch_release() -> Option<String> { - std::fs::read_to_string("/etc/arch-release") - .ok() - .filter(|v| !v.trim().is_empty()) - .map(|_| "Arch Linux".to_string()) -} - -fn detect_from_alpine_release() -> Option<String> { - std::fs::read_to_string("/etc/alpine-release") - .ok() - .map(|v| format!("Alpine {}", v.trim())) -} - -fn detect_from_suse_release() -> Option<String> { - std::fs::read_to_string("/etc/SuSE-release") - .ok() - .and_then(|content| content.lines().next().map(|l| l.trim().to_string())) -} - -fn detect_from_lsb_release() -> Option<String> { - let output = Command::new("lsb_release").arg("-a").output().ok()?; - - if !output.status.success() { - return None; - } - - let output = String::from_utf8(output.stdout).ok()?; - linux_distro_from_lsb_release(&output) -} - -fn linux_distro_from_lsb_release(output: &str) -> Option<String> { - output - .lines() - .find(|line| line.starts_with("Description:")) - .and_then(|line| line.split_once(':').map(|s| s.1)) - .map(|s| s.trim().to_string()) -} diff --git a/crates/turtle/src/atuin_client/encryption.rs b/crates/turtle/src/atuin_client/encryption.rs index f5d4f20d..e9c8d7e9 100644 --- a/crates/turtle/src/atuin_client/encryption.rs +++ b/crates/turtle/src/atuin_client/encryption.rs @@ -8,27 +8,16 @@ // clients must share the secret in order to be able to sync, as it is needed // to decrypt -use std::{io::prelude::*, path::PathBuf}; +use std::io::prelude::Write; use base64::prelude::{BASE64_STANDARD, Engine}; pub(crate) use crypto_secretbox::Key; -use crypto_secretbox::{ - AeadCore, AeadInPlace, KeyInit, XSalsa20Poly1305, - aead::{Nonce, OsRng}, -}; +use crypto_secretbox::{KeyInit, XSalsa20Poly1305, aead::OsRng}; use eyre::{Context, Result, bail, ensure, eyre}; use fs_err as fs; -use rmp::{Marker, decode::Bytes}; -use serde::{Deserialize, Serialize}; -use time::{OffsetDateTime, format_description::well_known::Rfc3339, macros::format_description}; +use rmp::Marker; -use crate::atuin_client::{history::History, settings::Settings}; - -#[derive(Debug, Serialize, Deserialize)] -pub(crate) struct EncryptedHistory { - pub(crate) ciphertext: Vec<u8>, - pub(crate) nonce: Nonce<XSalsa20Poly1305>, -} +use crate::atuin_client::settings::Settings; pub(crate) fn generate_encoded_key() -> Result<(Key, String)> { let key = XSalsa20Poly1305::generate_key(&mut OsRng); @@ -38,33 +27,27 @@ pub(crate) fn generate_encoded_key() -> Result<(Key, String)> { } pub(crate) fn new_key(settings: &Settings) -> Result<Key> { - let path = settings.key_path.as_str(); - let path = PathBuf::from(path); - - if path.exists() { + if settings.sync.encryption_key()?.is_some() { bail!("key already exists! cannot overwrite"); - } - - let (key, encoded) = generate_encoded_key()?; + } else if let Some(path) = settings.sync.encryption_key_path.as_ref() { + let (key, encoded) = generate_encoded_key()?; - let mut file = fs::File::create(path)?; - file.write_all(encoded.as_bytes())?; + let mut file = fs::File::create(path)?; + file.write_all(encoded.as_bytes())?; - Ok(key) + Ok(key) + } else { + bail!("No key-path set, cannot generate key") + } } // Loads the secret key, will create + save if it doesn't exist pub(crate) fn load_key(settings: &Settings) -> Result<Key> { - let path = settings.key_path.as_str(); - - let key = if PathBuf::from(path).exists() { - let key = fs_err::read_to_string(path)?; - decode_key(key)? + if let Some(key) = settings.sync.encryption_key()? { + Ok(key) } else { - new_key(settings)? - }; - - Ok(key) + Ok(new_key(settings)?) + } } pub(crate) fn encode_key(key: &Key) -> Result<String> { @@ -72,7 +55,7 @@ pub(crate) fn encode_key(key: &Key) -> Result<String> { rmp::encode::write_array_len(&mut buf, key.len() as u32) .wrap_err("could not encode key to message pack")?; for b in key { - rmp::encode::write_uint(&mut buf, *b as u64) + rmp::encode::write_uint(&mut buf, u64::from(*b)) .wrap_err("could not encode key to message pack")?; } let buf = BASE64_STANDARD.encode(buf); @@ -89,316 +72,37 @@ pub(crate) fn decode_key(key: String) -> Result<Key> { // old code wrote the key as a fixed length array of 32 bytes // new code writes the key with a length prefix - match <[u8; 32]>::try_from(&*buf) { - Ok(key) => Ok(key.into()), - Err(_) => { - let mut bytes = rmp::decode::Bytes::new(&buf); + if let Ok(key) = <[u8; 32]>::try_from(&*buf) { + Ok(key.into()) + } else { + let mut bytes = rmp::decode::Bytes::new(&buf); - match Marker::from_u8(buf[0]) { - Marker::Bin8 => { - let len = decode::read_bin_len(&mut bytes).map_err(|err| eyre!("{err:?}"))?; - ensure!(len == 32, "encryption key is not the correct size"); - let key = <[u8; 32]>::try_from(bytes.remaining_slice()) - .context("could not decode encryption key")?; - Ok(key.into()) - } - Marker::Array16 => { - let len = decode::read_array_len(&mut bytes).map_err(|err| eyre!("{err:?}"))?; - ensure!(len == 32, "encryption key is not the correct size"); + match Marker::from_u8(buf[0]) { + Marker::Bin8 => { + let len = decode::read_bin_len(&mut bytes).map_err(|err| eyre!("{err:?}"))?; + ensure!(len == 32, "encryption key is not the correct size"); + let key = <[u8; 32]>::try_from(bytes.remaining_slice()) + .context("could not decode encryption key")?; + Ok(key.into()) + } + Marker::Array16 => { + let len = decode::read_array_len(&mut bytes).map_err(|err| eyre!("{err:?}"))?; + ensure!(len == 32, "encryption key is not the correct size"); - let mut key = Key::default(); - for i in &mut key { - *i = rmp::decode::read_int(&mut bytes).map_err(|err| eyre!("{err:?}"))?; - } - Ok(key) + let mut key = Key::default(); + for i in &mut key { + *i = rmp::decode::read_int(&mut bytes).map_err(|err| eyre!("{err:?}"))?; } - _ => bail!("could not decode encryption key"), + Ok(key) } + _ => bail!("could not decode encryption key"), } } } -pub(crate) fn encrypt(history: &History, key: &Key) -> Result<EncryptedHistory> { - // serialize with msgpack - let mut buf = encode(history)?; - - let nonce = XSalsa20Poly1305::generate_nonce(&mut OsRng); - XSalsa20Poly1305::new(key) - .encrypt_in_place(&nonce, &[], &mut buf) - .map_err(|_| eyre!("could not encrypt"))?; - - Ok(EncryptedHistory { - ciphertext: buf, - nonce, - }) -} - -pub(crate) fn decrypt(mut encrypted_history: EncryptedHistory, key: &Key) -> Result<History> { - XSalsa20Poly1305::new(key) - .decrypt_in_place( - &encrypted_history.nonce, - &[], - &mut encrypted_history.ciphertext, - ) - .map_err(|_| eyre!("could not decrypt history"))?; - let plaintext = encrypted_history.ciphertext; - - let history = decode(&plaintext)?; - - Ok(history) -} - -fn format_rfc3339(ts: OffsetDateTime) -> Result<String> { - // horrible hack. chrono AutoSI limits to 0, 3, 6, or 9 decimal places for nanoseconds. - // time does not have this functionality. - static PARTIAL_RFC3339_0: &[time::format_description::FormatItem<'static>] = - format_description!("[year]-[month]-[day]T[hour]:[minute]:[second]Z"); - static PARTIAL_RFC3339_3: &[time::format_description::FormatItem<'static>] = - format_description!("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"); - static PARTIAL_RFC3339_6: &[time::format_description::FormatItem<'static>] = - format_description!("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:6]Z"); - static PARTIAL_RFC3339_9: &[time::format_description::FormatItem<'static>] = - format_description!("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:9]Z"); - - let fmt = match ts.nanosecond() { - 0 => PARTIAL_RFC3339_0, - ns if ns % 1_000_000 == 0 => PARTIAL_RFC3339_3, - ns if ns % 1_000 == 0 => PARTIAL_RFC3339_6, - _ => PARTIAL_RFC3339_9, - }; - - Ok(ts.format(fmt)?) -} - -fn encode(h: &History) -> Result<Vec<u8>> { - use rmp::encode; - - let mut output = vec![]; - // INFO: ensure this is updated when adding new fields - encode::write_array_len(&mut output, 9)?; - - encode::write_str(&mut output, &h.id.0)?; - encode::write_str(&mut output, &(format_rfc3339(h.timestamp)?))?; - encode::write_sint(&mut output, h.duration)?; - encode::write_sint(&mut output, h.exit)?; - encode::write_str(&mut output, &h.command)?; - encode::write_str(&mut output, &h.cwd)?; - encode::write_str(&mut output, &h.session)?; - encode::write_str(&mut output, &h.hostname)?; - match h.deleted_at { - Some(d) => encode::write_str(&mut output, &format_rfc3339(d)?)?, - None => encode::write_nil(&mut output)?, - } - - Ok(output) -} - -fn decode(bytes: &[u8]) -> Result<History> { - use rmp::decode::{self, DecodeStringError}; - - let mut bytes = Bytes::new(bytes); - - let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; - if nfields < 8 { - bail!("malformed decrypted history") - } - if nfields > 9 { - bail!("cannot decrypt history from a newer version of atuin"); - } - - let bytes = bytes.remaining_slice(); - let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (timestamp, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - - let mut bytes = Bytes::new(bytes); - let duration = decode::read_int(&mut bytes).map_err(error_report)?; - let exit = decode::read_int(&mut bytes).map_err(error_report)?; - - let bytes = bytes.remaining_slice(); - let (command, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (cwd, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (session, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (hostname, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - - // if we have more fields, try and get the deleted_at - let mut deleted_at = None; - let mut bytes = bytes; - if nfields > 8 { - bytes = match decode::read_str_from_slice(bytes) { - Ok((d, b)) => { - deleted_at = Some(d); - b - } - // we accept null here - Err(DecodeStringError::TypeMismatch(Marker::Null)) => { - // consume the null marker - let mut c = Bytes::new(bytes); - decode::read_nil(&mut c).map_err(error_report)?; - c.remaining_slice() - } - Err(err) => return Err(error_report(err)), - }; - } - - if !bytes.is_empty() { - bail!("trailing bytes in encoded history. malformed") - } - - Ok(History { - id: id.to_owned().into(), - timestamp: OffsetDateTime::parse(timestamp, &Rfc3339)?, - duration, - exit, - command: command.to_owned(), - cwd: cwd.to_owned(), - session: session.to_owned(), - hostname: hostname.to_owned(), - author: History::author_from_hostname(hostname), - intent: None, - deleted_at: deleted_at - .map(|t| OffsetDateTime::parse(t, &Rfc3339)) - .transpose()?, - }) -} - -fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { - eyre!("{err:?}") -} - #[cfg(test)] mod test { - use crypto_secretbox::{KeyInit, XSalsa20Poly1305, aead::OsRng}; use pretty_assertions::assert_eq; - use time::{OffsetDateTime, macros::datetime}; - - use crate::history::History; - - use super::{decode, decrypt, encode, encrypt}; - - #[test] - fn test_encrypt_decrypt() { - let key1 = XSalsa20Poly1305::generate_key(&mut OsRng); - let key2 = XSalsa20Poly1305::generate_key(&mut OsRng); - - let history = History::from_db() - .id("1".into()) - .timestamp(OffsetDateTime::now_utc()) - .command("ls".into()) - .cwd("/home/ellie".into()) - .exit(0) - .duration(1) - .session("beep boop".into()) - .hostname("booop".into()) - .author("booop".into()) - .intent(None) - .deleted_at(None) - .build() - .into(); - - let e1 = encrypt(&history, &key1).unwrap(); - let e2 = encrypt(&history, &key2).unwrap(); - - assert_ne!(e1.ciphertext, e2.ciphertext); - assert_ne!(e1.nonce, e2.nonce); - - // test decryption works - // this should pass - match decrypt(e1, &key1) { - Err(e) => panic!("failed to decrypt, got {e}"), - Ok(h) => assert_eq!(h, history), - }; - - // this should err - let _ = decrypt(e2, &key1).expect_err("expected an error decrypting with invalid key"); - } - - #[test] - fn test_decode() { - let bytes = [ - 0x99, 0xD9, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55, 53, 51, 56, - 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 187, 50, 48, 50, 51, 45, - 48, 53, 45, 50, 56, 84, 49, 56, 58, 51, 53, 58, 52, 48, 46, 54, 51, 51, 56, 55, 50, 90, - 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, 97, 116, 117, 115, 217, 42, - 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, - 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, 47, 99, 111, 100, 101, 47, 97, - 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, 51, 48, 54, 102, 50, 55, 52, 52, - 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, 102, 57, 52, 53, 55, 187, 102, - 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, 99, 111, 110, 114, 97, 100, 46, - 108, 117, 100, 103, 97, 116, 101, 192, - ]; - let history = History { - id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), - timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), - duration: 49206000, - exit: 0, - command: "git status".to_owned(), - cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), - session: "b97d9a306f274473a203d2eba41f9457".to_owned(), - hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), - author: "conrad.ludgate".to_owned(), - intent: None, - deleted_at: None, - }; - - let h = decode(&bytes).unwrap(); - assert_eq!(history, h); - - let b = encode(&h).unwrap(); - assert_eq!(&bytes, &*b); - } - - #[test] - fn test_decode_deleted() { - let history = History { - id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), - timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), - duration: 49206000, - exit: 0, - command: "git status".to_owned(), - cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), - session: "b97d9a306f274473a203d2eba41f9457".to_owned(), - hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), - author: "conrad.ludgate".to_owned(), - intent: None, - deleted_at: Some(datetime!(2023-05-28 18:35:40.633872 +00:00)), - }; - - let b = encode(&history).unwrap(); - let h = decode(&b).unwrap(); - assert_eq!(history, h); - } - - #[test] - fn test_decode_old() { - let bytes = [ - 0x98, 0xD9, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55, 53, 51, 56, - 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 187, 50, 48, 50, 51, 45, - 48, 53, 45, 50, 56, 84, 49, 56, 58, 51, 53, 58, 52, 48, 46, 54, 51, 51, 56, 55, 50, 90, - 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, 97, 116, 117, 115, 217, 42, - 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, - 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, 47, 99, 111, 100, 101, 47, 97, - 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, 51, 48, 54, 102, 50, 55, 52, 52, - 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, 102, 57, 52, 53, 55, 187, 102, - 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, 99, 111, 110, 114, 97, 100, 46, - 108, 117, 100, 103, 97, 116, 101, - ]; - let history = History { - id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), - timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), - duration: 49206000, - exit: 0, - command: "git status".to_owned(), - cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), - session: "b97d9a306f274473a203d2eba41f9457".to_owned(), - hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), - author: "conrad.ludgate".to_owned(), - intent: None, - deleted_at: None, - }; - - let h = decode(&bytes).unwrap(); - assert_eq!(history, h); - } #[test] fn key_encodings() { diff --git a/crates/turtle/src/atuin_client/history.rs b/crates/turtle/src/atuin_client/history.rs index 6bc0bc38..5e2f89f2 100644 --- a/crates/turtle/src/atuin_client/history.rs +++ b/crates/turtle/src/atuin_client/history.rs @@ -18,24 +18,6 @@ use time::OffsetDateTime; mod builder; pub(crate) mod store; -/// Known AI agent author values. Used to expand `$all-agent` and `$all-user` filters. -pub(crate) const KNOWN_AGENTS: &[&str] = &["claude-code", "codex", "copilot", "pi"]; -pub(crate) const AUTHOR_FILTER_ALL_USER: &str = "$all-user"; -pub(crate) const AUTHOR_FILTER_ALL_AGENT: &str = "$all-agent"; - -pub(crate) fn is_known_agent(author: &str) -> bool { - KNOWN_AGENTS.contains(&author) -} - -pub(crate) fn author_matches_filters(author: &str, filters: &[String]) -> bool { - filters.is_empty() - || filters.iter().any(|filter| match filter.as_str() { - AUTHOR_FILTER_ALL_USER => !is_known_agent(author), - AUTHOR_FILTER_ALL_AGENT => is_known_agent(author), - literal => author == literal, - }) -} - pub(crate) const HISTORY_VERSION_V0: &str = "v0"; pub(crate) const HISTORY_VERSION_V1: &str = "v1"; const HISTORY_RECORD_VERSION_V0: u16 = 0; @@ -527,8 +509,7 @@ impl History { } pub(crate) fn should_save(&self, settings: &Settings) -> bool { - !(self.command.starts_with(' ') - || self.command.is_empty() + !(self.command.is_empty() || settings.history_filter.is_match(&self.command) || settings.cwd_filter.is_match(&self.cwd) || (settings.secrets_filter && SECRET_PATTERNS_RE.is_match(&self.command))) @@ -540,12 +521,9 @@ mod tests { use regex::RegexSet; use time::macros::datetime; - use crate::{ - history::{AUTHOR_FILTER_ALL_AGENT, AUTHOR_FILTER_ALL_USER, HISTORY_VERSION}, - settings::Settings, - }; + use crate::atuin_client::{history::HISTORY_VERSION, settings::Settings}; - use super::{History, author_matches_filters, is_known_agent}; + use super::History; // Test that we don't save history where necessary #[test] @@ -553,7 +531,7 @@ mod tests { let settings = Settings { cwd_filter: RegexSet::new(["^/supasecret"]).unwrap(), history_filter: RegexSet::new(["^psql"]).unwrap(), - ..Settings::utc() + ..Settings::default() }; let normal_command: History = History::capture() @@ -607,19 +585,6 @@ mod tests { } #[test] - fn known_agents_include_pi() { - assert!(is_known_agent("pi")); - assert!(author_matches_filters( - "pi", - &[AUTHOR_FILTER_ALL_AGENT.to_string()] - )); - assert!(!author_matches_filters( - "pi", - &[AUTHOR_FILTER_ALL_USER.to_string()] - )); - } - - #[test] fn disable_secrets() { let settings = Settings { secrets_filter: false, @@ -641,7 +606,7 @@ mod tests { let history = History { id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), - duration: 49206000, + duration: 49_206_000, exit: 0, command: "git status".to_owned(), cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), @@ -669,7 +634,7 @@ mod tests { let history = History { id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), - duration: 49206000, + duration: 49_206_000, exit: 0, command: "git status".to_owned(), cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), @@ -693,7 +658,7 @@ mod tests { let history = History { id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), - duration: 49206000, + duration: 49_206_000, exit: 0, command: "git status".to_owned(), cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), @@ -735,7 +700,7 @@ mod tests { let current = History { id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), - duration: 49206000, + duration: 49_206_000, exit: 0, command: "git status".to_owned(), cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), diff --git a/crates/turtle/src/atuin_client/history/store.rs b/crates/turtle/src/atuin_client/history/store.rs index b2265698..a8162e21 100644 --- a/crates/turtle/src/atuin_client/history/store.rs +++ b/crates/turtle/src/atuin_client/history/store.rs @@ -6,7 +6,7 @@ use rmp::decode::Bytes; use tracing::debug; use crate::atuin_client::{ - database::{Database, current_context}, + database::{ClientSqlite, current_context}, record::{encryption::PASETO_V4, sqlite_store::SqliteStore, store::Store}, }; use crate::atuin_common::record::{DecryptedData, Host, HostId, Record, RecordId, RecordIdx}; @@ -226,7 +226,7 @@ impl HistoryStore { Ok(ret) } - pub(crate) async fn build(&self, database: &dyn Database) -> Result<()> { + pub(crate) async fn build(&self, database: &ClientSqlite) -> Result<()> { // I'd like to change how we rebuild and not couple this with the database, but need to // consider the structure more deeply. This will be easy to change. @@ -258,7 +258,11 @@ impl HistoryStore { Ok(()) } - pub(crate) async fn incremental_build(&self, database: &dyn Database, ids: &[RecordId]) -> Result<()> { + pub(crate) async fn incremental_build( + &self, + database: &ClientSqlite, + ids: &[RecordId], + ) -> Result<()> { for id in ids { let record = self.store.get(*id).await; @@ -310,7 +314,7 @@ impl HistoryStore { Ok(ret) } - pub(crate) async fn init_store(&self, db: &impl Database) -> Result<()> { + pub(crate) async fn init_store(&self, db: &ClientSqlite) -> Result<()> { let pb = ProgressBar::new_spinner(); pb.set_style( ProgressStyle::with_template("{spinner:.blue} {msg}") diff --git a/crates/turtle/src/atuin_client/import/bash.rs b/crates/turtle/src/atuin_client/import/bash.rs deleted file mode 100644 index e35634e7..00000000 --- a/crates/turtle/src/atuin_client/import/bash.rs +++ /dev/null @@ -1,221 +0,0 @@ -use std::{path::PathBuf, str}; - -use async_trait::async_trait; -use directories::UserDirs; -use eyre::{Result, eyre}; -use itertools::Itertools; -use time::{Duration, OffsetDateTime}; -use tracing::warn; - -use super::{Importer, Loader, get_histfile_path, unix_byte_lines}; -use crate::atuin_client::history::History; -use crate::atuin_client::import::read_to_end; - -#[derive(Debug)] -pub(crate) struct Bash { - bytes: Vec<u8>, -} - -fn default_histpath() -> Result<PathBuf> { - let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; - let home_dir = user_dirs.home_dir(); - - Ok(home_dir.join(".bash_history")) -} - -#[async_trait] -impl Importer for Bash { - const NAME: &'static str = "bash"; - - async fn new() -> Result<Self> { - let bytes = read_to_end(get_histfile_path(default_histpath)?)?; - Ok(Self { bytes }) - } - - async fn entries(&mut self) -> Result<usize> { - let count = unix_byte_lines(&self.bytes) - .map(LineType::from) - .filter(|line| matches!(line, LineType::Command(_))) - .count(); - Ok(count) - } - - async fn load(self, h: &mut impl Loader) -> Result<()> { - let lines = unix_byte_lines(&self.bytes) - .map(LineType::from) - .filter(|line| !matches!(line, LineType::NotUtf8)) // invalid utf8 are ignored - .collect_vec(); - - let (commands_before_first_timestamp, first_timestamp) = lines - .iter() - .enumerate() - .find_map(|(i, line)| match line { - LineType::Timestamp(t) => Some((i, *t)), - _ => None, - }) - // if no known timestamps, use now as base - .unwrap_or((lines.len(), OffsetDateTime::now_utc())); - - // if no timestamp is recorded, then use this increment to set an arbitrary timestamp - // to preserve ordering - // this increment is deliberately very small to prevent particularly fast fingers - // causing ordering issues; it also helps in handling the "here document" syntax, - // where several lines are recorded in succession without individual timestamps - let timestamp_increment = Duration::milliseconds(1); - - // make sure there is a minimum amount of time before the first known timestamp - // to fit all commands, given the default increment - let mut next_timestamp = - first_timestamp - timestamp_increment * commands_before_first_timestamp as i32; - - for line in lines.into_iter() { - match line { - LineType::NotUtf8 => unreachable!(), // already filtered - LineType::Empty => {} // do nothing - LineType::Timestamp(t) => { - if t < next_timestamp { - warn!( - "Time reversal detected in Bash history! Commands may be ordered incorrectly." - ); - } - next_timestamp = t; - } - LineType::Command(c) => { - let imported = History::import().timestamp(next_timestamp).command(c); - - h.push(imported.build().into()).await?; - next_timestamp += timestamp_increment; - } - } - } - - Ok(()) - } -} - -#[derive(Debug, Clone)] -enum LineType<'a> { - NotUtf8, - /// Can happen when using the "here document" syntax. - Empty, - /// A timestamp line start with a '#', followed immediately by an integer - /// that represents seconds since UNIX epoch. - Timestamp(OffsetDateTime), - /// Anything else. - Command(&'a str), -} -impl<'a> From<&'a [u8]> for LineType<'a> { - fn from(bytes: &'a [u8]) -> Self { - let Ok(line) = str::from_utf8(bytes) else { - return LineType::NotUtf8; - }; - if line.is_empty() { - return LineType::Empty; - } - - match try_parse_line_as_timestamp(line) { - Some(time) => LineType::Timestamp(time), - None => LineType::Command(line), - } - } -} - -fn try_parse_line_as_timestamp(line: &str) -> Option<OffsetDateTime> { - let seconds = line.strip_prefix('#')?.parse().ok()?; - OffsetDateTime::from_unix_timestamp(seconds).ok() -} - -#[cfg(test)] -mod test { - use std::cmp::Ordering; - - use itertools::{Itertools, assert_equal}; - - use crate::atuin_client::import::{Importer, tests::TestLoader}; - - use super::Bash; - - #[tokio::test] - async fn parse_no_timestamps() { - let bytes = r"cargo install atuin -cargo update -cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷ -" - .as_bytes() - .to_owned(); - - let mut bash = Bash { bytes }; - assert_eq!(bash.entries().await.unwrap(), 3); - - let mut loader = TestLoader::default(); - bash.load(&mut loader).await.unwrap(); - - assert_equal( - loader.buf.iter().map(|h| h.command.as_str()), - [ - "cargo install atuin", - "cargo update", - "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", - ], - ); - assert!(is_strictly_sorted(loader.buf.iter().map(|h| h.timestamp))) - } - - #[tokio::test] - async fn parse_with_timestamps() { - let bytes = b"#1672918999 -git reset -#1672919006 -git clean -dxf -#1672919020 -cd ../ -" - .to_vec(); - - let mut bash = Bash { bytes }; - assert_eq!(bash.entries().await.unwrap(), 3); - - let mut loader = TestLoader::default(); - bash.load(&mut loader).await.unwrap(); - - assert_equal( - loader.buf.iter().map(|h| h.command.as_str()), - ["git reset", "git clean -dxf", "cd ../"], - ); - assert_equal( - loader.buf.iter().map(|h| h.timestamp.unix_timestamp()), - [1_672_918_999, 1_672_919_006, 1_672_919_020], - ) - } - - #[tokio::test] - async fn parse_with_partial_timestamps() { - let bytes = b"git reset -#1672919006 -git clean -dxf -cd ../ -" - .to_vec(); - - let mut bash = Bash { bytes }; - assert_eq!(bash.entries().await.unwrap(), 3); - - let mut loader = TestLoader::default(); - bash.load(&mut loader).await.unwrap(); - - assert_equal( - loader.buf.iter().map(|h| h.command.as_str()), - ["git reset", "git clean -dxf", "cd ../"], - ); - assert!(is_strictly_sorted(loader.buf.iter().map(|h| h.timestamp))) - } - - fn is_strictly_sorted<T>(iter: impl IntoIterator<Item = T>) -> bool - where - T: Clone + PartialOrd, - { - iter.into_iter() - .tuple_windows() - .all(|(a, b)| matches!(a.partial_cmp(&b), Some(Ordering::Less))) - } -} diff --git a/crates/turtle/src/atuin_client/import/fish.rs b/crates/turtle/src/atuin_client/import/fish.rs deleted file mode 100644 index edc2d437..00000000 --- a/crates/turtle/src/atuin_client/import/fish.rs +++ /dev/null @@ -1,179 +0,0 @@ -// import old shell history! -// automatically hoover up all that we can find - -use std::path::PathBuf; - -use async_trait::async_trait; -use directories::BaseDirs; -use eyre::{Result, eyre}; -use time::OffsetDateTime; - -use super::{Importer, Loader, unix_byte_lines}; -use crate::atuin_client::history::History; -use crate::atuin_client::import::read_to_end; - -#[derive(Debug)] -pub(crate) struct Fish { - bytes: Vec<u8>, -} - -/// see https://fishshell.com/docs/current/interactive.html#searchable-command-history -fn default_histpath() -> Result<PathBuf> { - let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; - let data = std::env::var("XDG_DATA_HOME").map_or_else( - |_| base.home_dir().join(".local").join("share"), - PathBuf::from, - ); - - // fish supports multiple history sessions - // If `fish_history` var is missing, or set to `default`, use `fish` as the session - let session = std::env::var("fish_history").unwrap_or_else(|_| String::from("fish")); - let session = if session == "default" { - String::from("fish") - } else { - session - }; - - let mut histpath = data.join("fish"); - histpath.push(format!("{session}_history")); - - if histpath.exists() { - Ok(histpath) - } else { - Err(eyre!("Could not find history file.")) - } -} - -#[async_trait] -impl Importer for Fish { - const NAME: &'static str = "fish"; - - async fn new() -> Result<Self> { - let bytes = read_to_end(default_histpath()?)?; - Ok(Self { bytes }) - } - - async fn entries(&mut self) -> Result<usize> { - Ok(super::count_lines(&self.bytes)) - } - - async fn load(self, loader: &mut impl Loader) -> Result<()> { - let now = OffsetDateTime::now_utc(); - let mut time: Option<OffsetDateTime> = None; - let mut cmd: Option<String> = None; - - for b in unix_byte_lines(&self.bytes) { - let s = match std::str::from_utf8(b) { - Ok(s) => s, - Err(_) => continue, // we can skip past things like invalid utf8 - }; - - if let Some(c) = s.strip_prefix("- cmd: ") { - // first, we must deal with the prev cmd - if let Some(cmd) = cmd.take() { - let time = time.unwrap_or(now); - let entry = History::import().timestamp(time).command(cmd); - - loader.push(entry.build().into()).await?; - } - - // using raw strings to avoid needing escaping. - // replaces double backslashes with single backslashes - let c = c.replace(r"\\", r"\"); - // replaces escaped newlines - let c = c.replace(r"\n", "\n"); - // TODO: any other escape characters? - - cmd = Some(c); - } else if let Some(t) = s.strip_prefix(" when: ") { - // if t is not an int, just ignore this line - if let Ok(t) = t.parse::<i64>() { - time = Some(OffsetDateTime::from_unix_timestamp(t)?); - } - } else { - // ... ignore paths lines - } - } - - // we might have a trailing cmd - if let Some(cmd) = cmd.take() { - let time = time.unwrap_or(now); - let entry = History::import().timestamp(time).command(cmd); - - loader.push(entry.build().into()).await?; - } - - Ok(()) - } -} - -#[cfg(test)] -mod test { - - use crate::import::{Importer, tests::TestLoader}; - - use super::Fish; - - #[tokio::test] - async fn parse_complex() { - // complicated input with varying contents and escaped strings. - let bytes = r#"- cmd: history --help - when: 1639162832 -- cmd: cat ~/.bash_history - when: 1639162851 - paths: - - ~/.bash_history -- cmd: ls ~/.local/share/fish/fish_history - when: 1639162890 - paths: - - ~/.local/share/fish/fish_history -- cmd: cat ~/.local/share/fish/fish_history - when: 1639162893 - paths: - - ~/.local/share/fish/fish_history -ERROR -- CORRUPTED: ENTRY - CONTINUE: - - AS - - NORMAL -- cmd: echo "foo" \\\n'bar' baz - when: 1639162933 -- cmd: cat ~/.local/share/fish/fish_history - when: 1639162939 - paths: - - ~/.local/share/fish/fish_history -- cmd: echo "\\"" \\\\ "\\\\" - when: 1639163063 -- cmd: cat ~/.local/share/fish/fish_history - when: 1639163066 - paths: - - ~/.local/share/fish/fish_history -"# - .as_bytes() - .to_owned(); - - let fish = Fish { bytes }; - - let mut loader = TestLoader::default(); - fish.load(&mut loader).await.unwrap(); - let mut history = loader.buf.into_iter(); - - // simple wrapper for fish history entry - macro_rules! fishtory { - ($timestamp:expr_2021, $command:expr_2021) => { - let h = history.next().expect("missing entry in history"); - assert_eq!(h.command.as_str(), $command); - assert_eq!(h.timestamp.unix_timestamp(), $timestamp); - }; - } - - fishtory!(1639162832, "history --help"); - fishtory!(1639162851, "cat ~/.bash_history"); - fishtory!(1639162890, "ls ~/.local/share/fish/fish_history"); - fishtory!(1639162893, "cat ~/.local/share/fish/fish_history"); - fishtory!(1639162933, "echo \"foo\" \\\n'bar' baz"); - fishtory!(1639162939, "cat ~/.local/share/fish/fish_history"); - fishtory!(1639163063, r#"echo "\"" \\ "\\""#); - fishtory!(1639163066, "cat ~/.local/share/fish/fish_history"); - } -} diff --git a/crates/turtle/src/atuin_client/import/mod.rs b/crates/turtle/src/atuin_client/import/mod.rs deleted file mode 100644 index 81e01991..00000000 --- a/crates/turtle/src/atuin_client/import/mod.rs +++ /dev/null @@ -1,140 +0,0 @@ -use std::fs::File; -use std::io::Read; -use std::path::PathBuf; - -use async_trait::async_trait; -use eyre::{Result, bail}; -use memchr::Memchr; - -use crate::atuin_client::history::History; - -pub(crate) mod bash; -pub(crate) mod fish; -pub(crate) mod nu; -pub(crate) mod nu_histdb; -pub(crate) mod powershell; -pub(crate) mod replxx; -pub(crate) mod resh; -pub(crate) mod xonsh; -pub(crate) mod xonsh_sqlite; -pub(crate) mod zsh; -pub(crate) mod zsh_histdb; - -#[async_trait] -pub(crate) trait Importer: Sized { - const NAME: &'static str; - async fn new() -> Result<Self>; - async fn entries(&mut self) -> Result<usize>; - async fn load(self, loader: &mut impl Loader) -> Result<()>; -} - -#[async_trait] -pub(crate) trait Loader: Sync + Send { - async fn push(&mut self, hist: History) -> eyre::Result<()>; -} - -fn unix_byte_lines(input: &[u8]) -> impl Iterator<Item = &[u8]> { - UnixByteLines { - iter: memchr::memchr_iter(b'\n', input), - bytes: input, - i: 0, - } -} - -struct UnixByteLines<'a> { - iter: Memchr<'a>, - bytes: &'a [u8], - i: usize, -} - -impl<'a> Iterator for UnixByteLines<'a> { - type Item = &'a [u8]; - - fn next(&mut self) -> Option<Self::Item> { - let j = self.iter.next()?; - let out = &self.bytes[self.i..j]; - self.i = j + 1; - Some(out) - } - - fn count(self) -> usize - where - Self: Sized, - { - self.iter.count() - } -} - -fn count_lines(input: &[u8]) -> usize { - unix_byte_lines(input).count() -} - -fn get_histpath<D>(def: D) -> Result<PathBuf> -where - D: FnOnce() -> Result<PathBuf>, -{ - if let Ok(p) = std::env::var("HISTFILE") { - Ok(PathBuf::from(p)) - } else { - def() - } -} - -fn get_histfile_path<D>(def: D) -> Result<PathBuf> -where - D: FnOnce() -> Result<PathBuf>, -{ - get_histpath(def).and_then(is_file) -} - -fn get_histdir_path<D>(def: D) -> Result<PathBuf> -where - D: FnOnce() -> Result<PathBuf>, -{ - get_histpath(def).and_then(is_dir) -} - -fn read_to_end(path: PathBuf) -> Result<Vec<u8>> { - let mut bytes = Vec::new(); - let mut f = File::open(path)?; - f.read_to_end(&mut bytes)?; - Ok(bytes) -} -fn is_file(p: PathBuf) -> Result<PathBuf> { - if p.is_file() { - Ok(p) - } else { - bail!( - "Could not find history file {:?}. Try setting and exporting $HISTFILE", - p - ) - } -} -fn is_dir(p: PathBuf) -> Result<PathBuf> { - if p.is_dir() { - Ok(p) - } else { - bail!( - "Could not find history directory {:?}. Try setting and exporting $HISTFILE", - p - ) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[derive(Default)] - pub(crate) struct TestLoader { - pub(crate) buf: Vec<History>, - } - - #[async_trait] - impl Loader for TestLoader { - async fn push(&mut self, hist: History) -> Result<()> { - self.buf.push(hist); - Ok(()) - } - } -} diff --git a/crates/turtle/src/atuin_client/import/nu.rs b/crates/turtle/src/atuin_client/import/nu.rs deleted file mode 100644 index 1897a969..00000000 --- a/crates/turtle/src/atuin_client/import/nu.rs +++ /dev/null @@ -1,67 +0,0 @@ -// import old shell history! -// automatically hoover up all that we can find - -use std::path::PathBuf; - -use async_trait::async_trait; -use directories::BaseDirs; -use eyre::{Result, eyre}; -use time::OffsetDateTime; - -use super::{Importer, Loader, unix_byte_lines}; -use crate::atuin_client::history::History; -use crate::atuin_client::import::read_to_end; - -#[derive(Debug)] -pub(crate) struct Nu { - bytes: Vec<u8>, -} - -fn get_histpath() -> Result<PathBuf> { - let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; - let config_dir = base.config_dir().join("nushell"); - - let histpath = config_dir.join("history.txt"); - if histpath.exists() { - Ok(histpath) - } else { - Err(eyre!("Could not find history file.")) - } -} - -#[async_trait] -impl Importer for Nu { - const NAME: &'static str = "nu"; - - async fn new() -> Result<Self> { - let bytes = read_to_end(get_histpath()?)?; - Ok(Self { bytes }) - } - - async fn entries(&mut self) -> Result<usize> { - Ok(super::count_lines(&self.bytes)) - } - - async fn load(self, h: &mut impl Loader) -> Result<()> { - let now = OffsetDateTime::now_utc(); - - let mut counter = 0; - for b in unix_byte_lines(&self.bytes) { - let s = match std::str::from_utf8(b) { - Ok(s) => s, - Err(_) => continue, // we can skip past things like invalid utf8 - }; - - let cmd: String = s.replace("<\\n>", "\n"); - - let offset = time::Duration::nanoseconds(counter); - counter += 1; - - let entry = History::import().timestamp(now - offset).command(cmd); - - h.push(entry.build().into()).await?; - } - - Ok(()) - } -} diff --git a/crates/turtle/src/atuin_client/import/nu_histdb.rs b/crates/turtle/src/atuin_client/import/nu_histdb.rs deleted file mode 100644 index 1f66ea38..00000000 --- a/crates/turtle/src/atuin_client/import/nu_histdb.rs +++ /dev/null @@ -1,113 +0,0 @@ -// import old shell history! -// automatically hoover up all that we can find - -use std::path::PathBuf; - -use async_trait::async_trait; -use directories::BaseDirs; -use eyre::{Result, eyre}; -use sqlx::{Pool, sqlite::SqlitePool}; -use time::{Duration, OffsetDateTime}; - -use super::Importer; -use crate::atuin_client::history::History; -use crate::atuin_client::import::Loader; - -#[derive(sqlx::FromRow, Debug)] -pub(crate) struct HistDbEntry { - pub(crate) id: i64, - pub(crate) command_line: Vec<u8>, - pub(crate) start_timestamp: i64, - pub(crate) session_id: i64, - pub(crate) hostname: Vec<u8>, - pub(crate) cwd: Vec<u8>, - pub(crate) duration_ms: i64, - pub(crate) exit_status: i64, - pub(crate) more_info: Vec<u8>, -} - -impl From<HistDbEntry> for History { - fn from(histdb_item: HistDbEntry) -> Self { - let ts_secs = histdb_item.start_timestamp / 1000; - let ts_ns = (histdb_item.start_timestamp % 1000) * 1_000_000; - let imported = History::import() - .timestamp( - OffsetDateTime::from_unix_timestamp(ts_secs).unwrap() - + Duration::nanoseconds(ts_ns), - ) - .command(String::from_utf8(histdb_item.command_line).unwrap()) - .cwd(String::from_utf8(histdb_item.cwd).unwrap()) - .exit(histdb_item.exit_status) - .duration(histdb_item.duration_ms) - .session(format!("{:x}", histdb_item.session_id)) - .hostname(String::from_utf8(histdb_item.hostname).unwrap()); - - imported.build().into() - } -} - -#[derive(Debug)] -pub(crate) struct NuHistDb { - histdb: Vec<HistDbEntry>, -} - -/// Read db at given file, return vector of entries. -async fn hist_from_db(dbpath: PathBuf) -> Result<Vec<HistDbEntry>> { - let pool = SqlitePool::connect(dbpath.to_str().unwrap()).await?; - hist_from_db_conn(pool).await -} - -async fn hist_from_db_conn(pool: Pool<sqlx::Sqlite>) -> Result<Vec<HistDbEntry>> { - let query = r#" - SELECT - id, command_line, start_timestamp, session_id, hostname, cwd, duration_ms, exit_status, - more_info - FROM history - ORDER BY start_timestamp - "#; - let histdb_vec: Vec<HistDbEntry> = sqlx::query_as::<_, HistDbEntry>(query) - .fetch_all(&pool) - .await?; - Ok(histdb_vec) -} - -impl NuHistDb { - pub(crate) fn histpath() -> Result<PathBuf> { - let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; - let config_dir = base.config_dir().join("nushell"); - - let histdb_path = config_dir.join("history.sqlite3"); - if histdb_path.exists() { - Ok(histdb_path) - } else { - Err(eyre!("Could not find history file.")) - } - } -} - -#[async_trait] -impl Importer for NuHistDb { - // Not sure how this is used - const NAME: &'static str = "nu_histdb"; - - /// Creates a new NuHistDb and populates the history based on the pre-populated data - /// structure. - async fn new() -> Result<Self> { - let dbpath = NuHistDb::histpath()?; - let histdb_entry_vec = hist_from_db(dbpath).await?; - Ok(Self { - histdb: histdb_entry_vec, - }) - } - - async fn entries(&mut self) -> Result<usize> { - Ok(self.histdb.len()) - } - - async fn load(self, h: &mut impl Loader) -> Result<()> { - for i in self.histdb { - h.push(i.into()).await?; - } - Ok(()) - } -} diff --git a/crates/turtle/src/atuin_client/import/powershell.rs b/crates/turtle/src/atuin_client/import/powershell.rs deleted file mode 100644 index 09da0825..00000000 --- a/crates/turtle/src/atuin_client/import/powershell.rs +++ /dev/null @@ -1,202 +0,0 @@ -use async_trait::async_trait; -use directories::BaseDirs; -use eyre::{Result, eyre}; -use std::path::PathBuf; -use time::{Duration, OffsetDateTime}; - -use super::{Importer, Loader, count_lines, unix_byte_lines}; -use crate::atuin_client::history::History; -use crate::atuin_client::import::read_to_end; - -#[derive(Debug)] -pub(crate) struct PowerShell { - bytes: Vec<u8>, - line_count: Option<usize>, -} - -fn get_history_path() -> Result<PathBuf> { - let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; - - // The command line history in PowerShell is maintained by the PSReadLine module: - // https://learn.microsoft.com/en-us/powershell/module/psreadline/about/about_psreadline#command-history - // - // > PSReadLine maintains a history file containing all the commands and data you've entered from the command line. - // > The history files are a file named `$($Host.Name)_history.txt`. - // > On Windows systems the history file is stored at `$Env:APPDATA\Microsoft\Windows\PowerShell\PSReadLine`. - // > On non-Windows systems, the history files are stored at `$Env:XDG_DATA_HOME/powershell/PSReadLine` - // > or `$Env:HOME/.local/share/powershell/PSReadLine`. - - let dir = if cfg!(windows) { - base.data_dir() - .join("Microsoft") - .join("Windows") - .join("PowerShell") - .join("PSReadLine") - } else { - std::env::var("XDG_DATA_HOME") - .map_or_else( - |_| base.home_dir().join(".local").join("share"), - PathBuf::from, - ) - .join("powershell") - .join("PSReadLine") - }; - - // The history is stored in a file named `$($Host.Name)_history.txt`. - // For the default console host shipped by Microsoft,`$Host.Name` is `ConsoleHost`: - // https://learn.microsoft.com/en-us/dotnet/api/system.management.automation.host.pshost.name#remarks - - let file = dir.join("ConsoleHost_history.txt"); - - if file.is_file() { - Ok(file) - } else { - Err(eyre!("Could not find history file: {}", file.display())) - } -} - -#[async_trait] -impl Importer for PowerShell { - const NAME: &'static str = "PowerShell"; - - async fn new() -> Result<Self> { - let bytes = read_to_end(get_history_path()?)?; - Ok(Self { - bytes, - line_count: None, - }) - } - - async fn entries(&mut self) -> Result<usize> { - // Commands can be split over multiple lines, - // but this is only used for a progress bar, and multi-line commands - // should be quite rare, so this is not an issue in practice. - if self.line_count.is_none() { - self.line_count = Some(count_lines(&self.bytes)); - } - Ok(self.line_count.unwrap()) - } - - async fn load(mut self, h: &mut impl Loader) -> Result<()> { - let line_count = self.entries().await?; - let start = OffsetDateTime::now_utc() - Duration::milliseconds(line_count as i64); - - let mut counter = 0; - let mut iter = unix_byte_lines(&self.bytes); - - while let Some(s) = iter.next() { - let Ok(s) = read_line(s) else { - continue; // We can skip past things like invalid utf8 - }; - - let mut cmd = s.to_string(); - - // Multi-line commands end with a backtick, append the following lines. - while cmd.ends_with('`') { - cmd.pop(); - - let Some(next) = iter.next() else { - break; - }; - let Ok(next) = read_line(next) else { - break; - }; - - cmd.push('\n'); - cmd.push_str(next); - } - - if cmd.is_empty() { - continue; - } - - let offset = Duration::milliseconds(counter); - counter += 1; - - let entry = History::import().timestamp(start + offset).command(cmd); - h.push(entry.build().into()).await?; - } - - Ok(()) - } -} - -fn read_line(s: &[u8]) -> Result<&str> { - let s = str::from_utf8(s)?; - - // History is stored in CRLF on Windows, normalize the input to LF on all platforms. - let s = s.strip_suffix('\r').unwrap_or(s); - - Ok(s) -} - -#[cfg(test)] -mod test { - use super::*; - use crate::import::tests::TestLoader; - use itertools::assert_equal; - - const INPUT: &str = r#"cargo install atuin -cargo update -echo "first line` -second line` -` -last line" -echo foo - -echo bar -echo baz -"#; - - const EXPECTED: &[&str] = &[ - "cargo install atuin", - "cargo update", - "echo \"first line\nsecond line\n\nlast line\"", - "echo foo", - "echo bar", - "echo baz", - ]; - - #[tokio::test] - async fn test_import() { - let loader = import(INPUT).await; - - let actual = loader.buf.iter().map(|h| h.command.clone()); - let expected = EXPECTED.iter().map(|s| s.to_string()); - - assert_equal(actual, expected); - } - - #[tokio::test] - async fn test_crlf() { - let input = INPUT.replace("\n", "\r\n"); - let loader = import(input.as_str()).await; - - let actual = loader.buf.iter().map(|h| h.command.clone()); - let expected = EXPECTED.iter().map(|s| s.to_string()); - - assert_equal(actual, expected); - } - - #[tokio::test] - async fn test_timestamps() { - let loader = import(INPUT).await; - - let mut prev = loader.buf.first().unwrap().timestamp; - for current in loader.buf.iter().skip(1).map(|h| h.timestamp) { - assert!(current > prev); - prev = current; - } - } - - async fn import(input: &str) -> TestLoader { - let powershell = PowerShell { - bytes: input.as_bytes().to_vec(), - line_count: None, - }; - - let mut loader = TestLoader::default(); - powershell.load(&mut loader).await.unwrap(); - loader - } -} diff --git a/crates/turtle/src/atuin_client/import/replxx.rs b/crates/turtle/src/atuin_client/import/replxx.rs deleted file mode 100644 index fbce2598..00000000 --- a/crates/turtle/src/atuin_client/import/replxx.rs +++ /dev/null @@ -1,137 +0,0 @@ -use std::{path::PathBuf, str}; - -use async_trait::async_trait; -use directories::UserDirs; -use eyre::{Result, eyre}; -use time::{OffsetDateTime, PrimitiveDateTime, macros::format_description}; - -use super::{Importer, Loader, get_histfile_path, unix_byte_lines}; -use crate::atuin_client::history::History; -use crate::atuin_client::import::read_to_end; - -#[derive(Debug)] -pub(crate) struct Replxx { - bytes: Vec<u8>, -} - -fn default_histpath() -> Result<PathBuf> { - let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; - let home_dir = user_dirs.home_dir(); - - // There is no default histfile for replxx. - // Here we try a couple of common names. - let mut candidates = ["replxx_history.txt", ".histfile"].iter(); - loop { - match candidates.next() { - Some(candidate) => { - let histpath = home_dir.join(candidate); - if histpath.exists() { - break Ok(histpath); - } - } - None => { - break Err(eyre!( - "Could not find history file. Try setting and exporting $HISTFILE" - )); - } - } - } -} - -#[async_trait] -impl Importer for Replxx { - const NAME: &'static str = "replxx"; - - async fn new() -> Result<Self> { - let bytes = read_to_end(get_histfile_path(default_histpath)?)?; - Ok(Self { bytes }) - } - - async fn entries(&mut self) -> Result<usize> { - Ok(super::count_lines(&self.bytes) / 2) - } - - async fn load(self, h: &mut impl Loader) -> Result<()> { - let mut timestamp = OffsetDateTime::UNIX_EPOCH; - - for b in unix_byte_lines(&self.bytes) { - let s = std::str::from_utf8(b)?; - match try_parse_line_as_timestamp(s) { - Some(t) => timestamp = t, - None => { - // replxx uses ETB character (0x17) as line breaker - let cmd = s.replace('\u{0017}', "\n"); - let imported = History::import().timestamp(timestamp).command(cmd); - - h.push(imported.build().into()).await?; - } - } - } - - Ok(()) - } -} - -fn try_parse_line_as_timestamp(line: &str) -> Option<OffsetDateTime> { - // replxx history date time format: ### yyyy-mm-dd hh:mm:ss.xxx - let date_time_str = line.strip_prefix("### ")?; - let format = - format_description!("[year]-[month]-[day] [hour]:[minute]:[second].[subsecond digits:3]"); - - let primitive_date_time = PrimitiveDateTime::parse(date_time_str, format).ok()?; - // There is no safe way to get local time offset. - // For simplicity let's just assume UTC. - Some(primitive_date_time.assume_utc()) -} - -#[cfg(test)] -mod test { - - use crate::import::{Importer, tests::TestLoader}; - - use super::Replxx; - - #[tokio::test] - async fn parse_complex() { - let bytes = r#"### 2024-02-10 22:16:28.302 -select * from remote('127.0.0.1:20222', view(select 1)) -### 2024-02-10 22:16:36.919 -select * from numbers(10) -### 2024-02-10 22:16:41.710 -select * from system.numbers -### 2024-02-10 22:19:28.655 -select 1 -### 2024-02-22 11:15:33.046 -CREATE TABLE test( stamp DateTime('UTC'))ENGINE = MergeTreePARTITION BY toDate(stamp)order by tuple() as select toDateTime('2020-01-01')+number*60 from numbers(80000); -"# - .as_bytes() - .to_owned(); - - let replxx = Replxx { bytes }; - - let mut loader = TestLoader::default(); - replxx.load(&mut loader).await.unwrap(); - let mut history = loader.buf.into_iter(); - - // simple wrapper for replxx history entry - macro_rules! history { - ($timestamp:expr_2021, $command:expr_2021) => { - let h = history.next().expect("missing entry in history"); - assert_eq!(h.command.as_str(), $command); - assert_eq!(h.timestamp.unix_timestamp(), $timestamp); - }; - } - - history!( - 1707603388, - "select * from remote('127.0.0.1:20222', view(select 1))" - ); - history!(1707603396, "select * from numbers(10)"); - history!(1707603401, "select * from system.numbers"); - history!(1707603568, "select 1"); - history!( - 1708600533, - "CREATE TABLE test\n( stamp DateTime('UTC'))\nENGINE = MergeTree\nPARTITION BY toDate(stamp)\norder by tuple() as select toDateTime('2020-01-01')+number*60 from numbers(80000);" - ); - } -} diff --git a/crates/turtle/src/atuin_client/import/resh.rs b/crates/turtle/src/atuin_client/import/resh.rs deleted file mode 100644 index 2c75e387..00000000 --- a/crates/turtle/src/atuin_client/import/resh.rs +++ /dev/null @@ -1,140 +0,0 @@ -use std::path::PathBuf; - -use async_trait::async_trait; -use directories::UserDirs; -use eyre::{Result, eyre}; -use serde::Deserialize; - -use crate::atuin_common::utils::uuid_v7; -use time::OffsetDateTime; - -use super::{Importer, Loader, get_histfile_path, unix_byte_lines}; -use crate::atuin_client::history::History; -use crate::atuin_client::import::read_to_end; - -#[derive(Deserialize, Debug)] -#[serde(rename_all = "camelCase")] -pub(crate) struct ReshEntry { - pub(crate) cmd_line: String, - pub(crate) exit_code: i64, - pub(crate) shell: String, - pub(crate) uname: String, - pub(crate) session_id: String, - pub(crate) home: String, - pub(crate) lang: String, - pub(crate) lc_all: String, - pub(crate) login: String, - pub(crate) pwd: String, - pub(crate) pwd_after: String, - pub(crate) shell_env: String, - pub(crate) term: String, - pub(crate) real_pwd: String, - pub(crate) real_pwd_after: String, - pub(crate) pid: i64, - pub(crate) session_pid: i64, - pub(crate) host: String, - pub(crate) hosttype: String, - pub(crate) ostype: String, - pub(crate) machtype: String, - pub(crate) shlvl: i64, - pub(crate) timezone_before: String, - pub(crate) timezone_after: String, - pub(crate) realtime_before: f64, - pub(crate) realtime_after: f64, - pub(crate) realtime_before_local: f64, - pub(crate) realtime_after_local: f64, - pub(crate) realtime_duration: f64, - pub(crate) realtime_since_session_start: f64, - pub(crate) realtime_since_boot: f64, - pub(crate) git_dir: String, - pub(crate) git_real_dir: String, - pub(crate) git_origin_remote: String, - pub(crate) git_dir_after: String, - pub(crate) git_real_dir_after: String, - pub(crate) git_origin_remote_after: String, - pub(crate) machine_id: String, - pub(crate) os_release_id: String, - pub(crate) os_release_version_id: String, - pub(crate) os_release_id_like: String, - pub(crate) os_release_name: String, - pub(crate) os_release_pretty_name: String, - pub(crate) resh_uuid: String, - pub(crate) resh_version: String, - pub(crate) resh_revision: String, - pub(crate) parts_merged: bool, - pub(crate) recalled: bool, - pub(crate) recall_last_cmd_line: String, - pub(crate) cols: String, - pub(crate) lines: String, -} - -#[derive(Debug)] -pub(crate) struct Resh { - bytes: Vec<u8>, -} - -fn default_histpath() -> Result<PathBuf> { - let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; - let home_dir = user_dirs.home_dir(); - - Ok(home_dir.join(".resh_history.json")) -} - -#[async_trait] -impl Importer for Resh { - const NAME: &'static str = "resh"; - - async fn new() -> Result<Self> { - let bytes = read_to_end(get_histfile_path(default_histpath)?)?; - Ok(Self { bytes }) - } - - async fn entries(&mut self) -> Result<usize> { - Ok(super::count_lines(&self.bytes)) - } - - async fn load(self, h: &mut impl Loader) -> Result<()> { - for b in unix_byte_lines(&self.bytes) { - let s = match std::str::from_utf8(b) { - Ok(s) => s, - Err(_) => continue, // we can skip past things like invalid utf8 - }; - let entry = match serde_json::from_str::<ReshEntry>(s) { - Ok(e) => e, - Err(_) => continue, // skip invalid json :shrug: - }; - - #[expect(clippy::cast_possible_truncation)] - #[expect(clippy::cast_sign_loss)] - let timestamp = { - let secs = entry.realtime_before.floor() as i64; - let nanosecs = (entry.realtime_before.fract() * 1_000_000_000_f64).round() as i64; - OffsetDateTime::from_unix_timestamp(secs)? + time::Duration::nanoseconds(nanosecs) - }; - #[expect(clippy::cast_possible_truncation)] - #[expect(clippy::cast_sign_loss)] - let duration = { - let secs = entry.realtime_after.floor() as i64; - let nanosecs = (entry.realtime_after.fract() * 1_000_000_000_f64).round() as i64; - let base = OffsetDateTime::from_unix_timestamp(secs)? - + time::Duration::nanoseconds(nanosecs); - let difference = base - timestamp; - difference.whole_nanoseconds() as i64 - }; - - let imported = History::import() - .command(entry.cmd_line) - .timestamp(timestamp) - .duration(duration) - .exit(entry.exit_code) - .cwd(entry.pwd) - .hostname(entry.host) - // CHECK: should we add uuid here? It's not set in the other importers - .session(uuid_v7().as_simple().to_string()); - - h.push(imported.build().into()).await?; - } - - Ok(()) - } -} diff --git a/crates/turtle/src/atuin_client/import/xonsh.rs b/crates/turtle/src/atuin_client/import/xonsh.rs deleted file mode 100644 index 5df24284..00000000 --- a/crates/turtle/src/atuin_client/import/xonsh.rs +++ /dev/null @@ -1,234 +0,0 @@ -use std::env; -use std::fs::{self, File}; -use std::path::{Path, PathBuf}; - -use async_trait::async_trait; -use directories::BaseDirs; -use eyre::{Result, eyre}; -use serde::Deserialize; -use time::OffsetDateTime; -use uuid::Uuid; -use uuid::timestamp::{Timestamp, context::NoContext}; - -use super::{Importer, Loader, get_histdir_path}; -use crate::atuin_client::history::History; -use crate::atuin_client::utils::get_host_user; - -// Note: both HistoryFile and HistoryData have other keys present in the JSON, we don't -// care about them so we leave them unspecified so as to avoid deserializing unnecessarily. -#[derive(Debug, Deserialize)] -struct HistoryFile { - data: HistoryData, -} - -#[derive(Debug, Deserialize)] -struct HistoryData { - sessionid: String, - cmds: Vec<HistoryCmd>, -} - -#[derive(Debug, Deserialize)] -struct HistoryCmd { - cwd: String, - inp: String, - rtn: Option<i64>, - ts: (f64, f64), -} - -#[derive(Debug)] -pub(crate) struct Xonsh { - // history is stored as a bunch of json files, one per session - sessions: Vec<HistoryData>, - hostname: String, -} - -fn xonsh_hist_dir(xonsh_data_dir: Option<String>) -> Result<PathBuf> { - // if running within xonsh, this will be available - if let Some(d) = xonsh_data_dir { - let mut path = PathBuf::from(d); - path.push("history_json"); - return Ok(path); - } - - // otherwise, fall back to default - let base = BaseDirs::new().ok_or_else(|| eyre!("Could not determine home directory"))?; - - let hist_dir = base.data_dir().join("xonsh/history_json"); - if hist_dir.exists() || cfg!(test) { - Ok(hist_dir) - } else { - Err(eyre!("Could not find xonsh history files")) - } -} - -fn load_sessions(hist_dir: &Path) -> Result<Vec<HistoryData>> { - let mut sessions = vec![]; - for entry in fs::read_dir(hist_dir)? { - let p = entry?.path(); - let ext = p.extension().and_then(|e| e.to_str()); - if p.is_file() - && ext == Some("json") - && let Some(data) = load_session(&p)? - { - sessions.push(data); - } - } - Ok(sessions) -} - -fn load_session(path: &Path) -> Result<Option<HistoryData>> { - let file = File::open(path)?; - // empty files are not valid json, so we can't deserialize them - if file.metadata()?.len() == 0 { - return Ok(None); - } - - let mut hist_file: HistoryFile = serde_json::from_reader(file)?; - - // if there are commands in this session, replace the existing UUIDv4 - // with a UUIDv7 generated from the timestamp of the first command - if let Some(cmd) = hist_file.data.cmds.first() { - let seconds = cmd.ts.0.trunc() as u64; - let nanos = (cmd.ts.0.fract() * 1_000_000_000_f64) as u32; - let ts = Timestamp::from_unix(NoContext, seconds, nanos); - hist_file.data.sessionid = Uuid::new_v7(ts).to_string(); - } - Ok(Some(hist_file.data)) -} - -#[async_trait] -impl Importer for Xonsh { - const NAME: &'static str = "xonsh"; - - async fn new() -> Result<Self> { - // wrap xonsh-specific path resolver in general one so that it respects $HISTPATH - let xonsh_data_dir = env::var("XONSH_DATA_DIR").ok(); - let hist_dir = get_histdir_path(|| xonsh_hist_dir(xonsh_data_dir))?; - let sessions = load_sessions(&hist_dir)?; - let hostname = get_host_user(); - Ok(Xonsh { sessions, hostname }) - } - - async fn entries(&mut self) -> Result<usize> { - let total = self.sessions.iter().map(|s| s.cmds.len()).sum(); - Ok(total) - } - - async fn load(self, loader: &mut impl Loader) -> Result<()> { - for session in self.sessions { - for cmd in session.cmds { - let (start, end) = cmd.ts; - let ts_nanos = (start * 1_000_000_000_f64) as i128; - let timestamp = OffsetDateTime::from_unix_timestamp_nanos(ts_nanos)?; - - let duration = (end - start) * 1_000_000_000_f64; - - match cmd.rtn { - Some(exit) => { - let entry = History::import() - .timestamp(timestamp) - .duration(duration.trunc() as i64) - .exit(exit) - .command(cmd.inp.trim()) - .cwd(cmd.cwd) - .session(session.sessionid.clone()) - .hostname(self.hostname.clone()); - loader.push(entry.build().into()).await?; - } - None => { - let entry = History::import() - .timestamp(timestamp) - .duration(duration.trunc() as i64) - .command(cmd.inp.trim()) - .cwd(cmd.cwd) - .session(session.sessionid.clone()) - .hostname(self.hostname.clone()); - loader.push(entry.build().into()).await?; - } - } - } - } - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use time::macros::datetime; - - use super::*; - - use crate::history::History; - use crate::import::tests::TestLoader; - - #[test] - fn test_hist_dir_xonsh() { - let hist_dir = xonsh_hist_dir(Some("/home/user/xonsh_data".to_string())).unwrap(); - assert_eq!( - hist_dir, - PathBuf::from("/home/user/xonsh_data/history_json") - ); - } - - #[tokio::test] - async fn test_import() { - let dir = PathBuf::from("tests/data/xonsh"); - let sessions = load_sessions(&dir).unwrap(); - let hostname = "box:user".to_string(); - let xonsh = Xonsh { sessions, hostname }; - - let mut loader = TestLoader::default(); - xonsh.load(&mut loader).await.unwrap(); - // order in buf will depend on filenames, so sort by timestamp for consistency - loader.buf.sort_by_key(|h| h.timestamp); - for (actual, expected) in loader.buf.iter().zip(expected_hist_entries().iter()) { - assert_eq!(actual.timestamp, expected.timestamp); - assert_eq!(actual.command, expected.command); - assert_eq!(actual.cwd, expected.cwd); - assert_eq!(actual.exit, expected.exit); - assert_eq!(actual.duration, expected.duration); - assert_eq!(actual.hostname, expected.hostname); - } - } - - fn expected_hist_entries() -> [History; 4] { - [ - History::import() - .timestamp(datetime!(2024-02-6 04:17:59.478272256 +00:00:00)) - .command("echo hello world!".to_string()) - .cwd("/home/user/Documents/code/atuin".to_string()) - .exit(0) - .duration(4651069) - .hostname("box:user".to_string()) - .build() - .into(), - History::import() - .timestamp(datetime!(2024-02-06 04:18:01.70632832 +00:00:00)) - .command("ls -l".to_string()) - .cwd("/home/user/Documents/code/atuin".to_string()) - .exit(0) - .duration(21288633) - .hostname("box:user".to_string()) - .build() - .into(), - History::import() - .timestamp(datetime!(2024-02-06 17:41:31.142515968 +00:00:00)) - .command("false".to_string()) - .cwd("/home/user/Documents/code/atuin/atuin-client".to_string()) - .exit(1) - .duration(10269403) - .hostname("box:user".to_string()) - .build() - .into(), - History::import() - .timestamp(datetime!(2024-02-06 17:41:32.271584 +00:00:00)) - .command("exit".to_string()) - .cwd("/home/user/Documents/code/atuin/atuin-client".to_string()) - .exit(0) - .duration(4259347) - .hostname("box:user".to_string()) - .build() - .into(), - ] - } -} diff --git a/crates/turtle/src/atuin_client/import/xonsh_sqlite.rs b/crates/turtle/src/atuin_client/import/xonsh_sqlite.rs deleted file mode 100644 index 326fe74b..00000000 --- a/crates/turtle/src/atuin_client/import/xonsh_sqlite.rs +++ /dev/null @@ -1,217 +0,0 @@ -use std::env; -use std::path::PathBuf; - -use async_trait::async_trait; -use directories::BaseDirs; -use eyre::{Result, eyre}; -use futures::TryStreamExt; -use sqlx::{FromRow, Row, sqlite::SqlitePool}; -use time::OffsetDateTime; -use uuid::Uuid; -use uuid::timestamp::{Timestamp, context::NoContext}; - -use super::{Importer, Loader, get_histfile_path}; -use crate::atuin_client::history::History; -use crate::atuin_client::utils::get_host_user; - -#[derive(Debug, FromRow)] -struct HistDbEntry { - inp: String, - rtn: Option<i64>, - tsb: f64, - tse: f64, - cwd: String, - session_start: f64, -} - -impl HistDbEntry { - fn into_hist_with_hostname(self, hostname: String) -> History { - let ts_nanos = (self.tsb * 1_000_000_000_f64) as i128; - let timestamp = OffsetDateTime::from_unix_timestamp_nanos(ts_nanos).unwrap(); - - let session_ts_seconds = self.session_start.trunc() as u64; - let session_ts_nanos = (self.session_start.fract() * 1_000_000_000_f64) as u32; - let session_ts = Timestamp::from_unix(NoContext, session_ts_seconds, session_ts_nanos); - let session_id = Uuid::new_v7(session_ts).to_string(); - let duration = (self.tse - self.tsb) * 1_000_000_000_f64; - - if let Some(exit) = self.rtn { - let imported = History::import() - .timestamp(timestamp) - .duration(duration.trunc() as i64) - .exit(exit) - .command(self.inp) - .cwd(self.cwd) - .session(session_id) - .hostname(hostname); - imported.build().into() - } else { - let imported = History::import() - .timestamp(timestamp) - .duration(duration.trunc() as i64) - .command(self.inp) - .cwd(self.cwd) - .session(session_id) - .hostname(hostname); - imported.build().into() - } - } -} - -fn xonsh_db_path(xonsh_data_dir: Option<String>) -> Result<PathBuf> { - // if running within xonsh, this will be available - if let Some(d) = xonsh_data_dir { - let mut path = PathBuf::from(d); - path.push("xonsh-history.sqlite"); - return Ok(path); - } - - // otherwise, fall back to default - let base = BaseDirs::new().ok_or_else(|| eyre!("Could not determine home directory"))?; - - let hist_file = base.data_dir().join("xonsh/xonsh-history.sqlite"); - if hist_file.exists() || cfg!(test) { - Ok(hist_file) - } else { - Err(eyre!( - "Could not find xonsh history db at: {}", - hist_file.to_string_lossy() - )) - } -} - -#[derive(Debug)] -pub(crate) struct XonshSqlite { - pool: SqlitePool, - hostname: String, -} - -#[async_trait] -impl Importer for XonshSqlite { - const NAME: &'static str = "xonsh_sqlite"; - - async fn new() -> Result<Self> { - // wrap xonsh-specific path resolver in general one so that it respects $HISTPATH - let xonsh_data_dir = env::var("XONSH_DATA_DIR").ok(); - let db_path = get_histfile_path(|| xonsh_db_path(xonsh_data_dir))?; - let connection_str = db_path.to_str().ok_or_else(|| { - eyre!( - "Invalid path for SQLite database: {}", - db_path.to_string_lossy() - ) - })?; - - let pool = SqlitePool::connect(connection_str).await?; - let hostname = get_host_user(); - Ok(XonshSqlite { pool, hostname }) - } - - async fn entries(&mut self) -> Result<usize> { - let query = "SELECT COUNT(*) FROM xonsh_history"; - let row = sqlx::query(query).fetch_one(&self.pool).await?; - let count: u32 = row.get(0); - Ok(count as usize) - } - - async fn load(self, loader: &mut impl Loader) -> Result<()> { - let query = r#" - SELECT inp, rtn, tsb, tse, cwd, - MIN(tsb) OVER (PARTITION BY sessionid) AS session_start - FROM xonsh_history - ORDER BY rowid - "#; - - let mut entries = sqlx::query_as::<_, HistDbEntry>(query).fetch(&self.pool); - - let mut count = 0; - while let Some(entry) = entries.try_next().await? { - let hist = entry.into_hist_with_hostname(self.hostname.clone()); - loader.push(hist).await?; - count += 1; - } - - println!("Loaded: {count}"); - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use time::macros::datetime; - - use super::*; - - use crate::history::History; - use crate::import::tests::TestLoader; - - #[test] - fn test_db_path_xonsh() { - let db_path = xonsh_db_path(Some("/home/user/xonsh_data".to_string())).unwrap(); - assert_eq!( - db_path, - PathBuf::from("/home/user/xonsh_data/xonsh-history.sqlite") - ); - } - - #[tokio::test] - async fn test_import() { - let connection_str = "tests/data/xonsh-history.sqlite"; - let xonsh_sqlite = XonshSqlite { - pool: SqlitePool::connect(connection_str).await.unwrap(), - hostname: "box:user".to_string(), - }; - - let mut loader = TestLoader::default(); - xonsh_sqlite.load(&mut loader).await.unwrap(); - - for (actual, expected) in loader.buf.iter().zip(expected_hist_entries().iter()) { - assert_eq!(actual.timestamp, expected.timestamp); - assert_eq!(actual.command, expected.command); - assert_eq!(actual.cwd, expected.cwd); - assert_eq!(actual.exit, expected.exit); - assert_eq!(actual.duration, expected.duration); - assert_eq!(actual.hostname, expected.hostname); - } - } - - fn expected_hist_entries() -> [History; 4] { - [ - History::import() - .timestamp(datetime!(2024-02-6 17:56:21.130956288 +00:00:00)) - .command("echo hello world!".to_string()) - .cwd("/home/user/Documents/code/atuin".to_string()) - .exit(0) - .duration(2628564) - .hostname("box:user".to_string()) - .build() - .into(), - History::import() - .timestamp(datetime!(2024-02-06 17:56:28.190406144 +00:00:00)) - .command("ls -l".to_string()) - .cwd("/home/user/Documents/code/atuin".to_string()) - .exit(0) - .duration(9371519) - .hostname("box:user".to_string()) - .build() - .into(), - History::import() - .timestamp(datetime!(2024-02-06 17:56:46.989020928 +00:00:00)) - .command("false".to_string()) - .cwd("/home/user/Documents/code/atuin".to_string()) - .exit(1) - .duration(17337560) - .hostname("box:user".to_string()) - .build() - .into(), - History::import() - .timestamp(datetime!(2024-02-06 17:56:48.218384128 +00:00:00)) - .command("exit".to_string()) - .cwd("/home/user/Documents/code/atuin".to_string()) - .exit(0) - .duration(4599094) - .hostname("box:user".to_string()) - .build() - .into(), - ] - } -} diff --git a/crates/turtle/src/atuin_client/import/zsh.rs b/crates/turtle/src/atuin_client/import/zsh.rs deleted file mode 100644 index 55d082d3..00000000 --- a/crates/turtle/src/atuin_client/import/zsh.rs +++ /dev/null @@ -1,230 +0,0 @@ -// import old shell history! -// automatically hoover up all that we can find - -use std::borrow::Cow; -use std::path::PathBuf; - -use async_trait::async_trait; -use directories::UserDirs; -use eyre::{Result, eyre}; -use time::OffsetDateTime; - -use super::{Importer, Loader, get_histfile_path, unix_byte_lines}; -use crate::atuin_client::history::History; -use crate::atuin_client::import::read_to_end; - -#[derive(Debug)] -pub(crate) struct Zsh { - bytes: Vec<u8>, -} - -fn default_histpath() -> Result<PathBuf> { - // oh-my-zsh sets HISTFILE=~/.zhistory - // zsh has no default value for this var, but uses ~/.zhistory. - // zsh-newuser-install propose as default .histfile https://github.com/zsh-users/zsh/blob/master/Functions/Newuser/zsh-newuser-install#L794 - // we could maybe be smarter about this in the future :) - let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; - let home_dir = user_dirs.home_dir(); - - let mut candidates = [".zhistory", ".zsh_history", ".histfile"].iter(); - loop { - match candidates.next() { - Some(candidate) => { - let histpath = home_dir.join(candidate); - if histpath.exists() { - break Ok(histpath); - } - } - None => { - break Err(eyre!( - "Could not find history file. Try setting and exporting $HISTFILE" - )); - } - } - } -} - -#[async_trait] -impl Importer for Zsh { - const NAME: &'static str = "zsh"; - - async fn new() -> Result<Self> { - let bytes = read_to_end(get_histfile_path(default_histpath)?)?; - Ok(Self { bytes }) - } - - async fn entries(&mut self) -> Result<usize> { - Ok(super::count_lines(&self.bytes)) - } - - async fn load(self, h: &mut impl Loader) -> Result<()> { - let now = OffsetDateTime::now_utc(); - let mut line = String::new(); - - let mut counter = 0; - for b in unix_byte_lines(&self.bytes) { - let s = match unmetafy(b) { - Some(s) => s, - _ => continue, // we can skip past things like invalid utf8 - }; - - if let Some(s) = s.strip_suffix('\\') { - line.push_str(s); - line.push('\n'); - } else { - line.push_str(&s); - let command = std::mem::take(&mut line); - - if let Some(command) = command.strip_prefix(": ") { - counter += 1; - h.push(parse_extended(command, counter)).await?; - } else { - let offset = time::Duration::seconds(counter); - counter += 1; - - let imported = History::import() - // preserve ordering - .timestamp(now - offset) - .command(command.trim_end().to_string()); - - h.push(imported.build().into()).await?; - } - } - } - - Ok(()) - } -} - -fn parse_extended(line: &str, counter: i64) -> History { - let (time, duration) = line.split_once(':').unwrap(); - let (duration, command) = duration.split_once(';').unwrap(); - - let time = time - .parse::<i64>() - .ok() - .and_then(|t| OffsetDateTime::from_unix_timestamp(t).ok()) - .unwrap_or_else(OffsetDateTime::now_utc) - + time::Duration::milliseconds(counter); - - // use nanos, because why the hell not? we won't display them. - let duration = duration.parse::<i64>().map_or(-1, |t| t * 1_000_000_000); - - let imported = History::import() - .timestamp(time) - .command(command.trim_end().to_string()) - .duration(duration); - - imported.build().into() -} - -fn unmetafy(line: &[u8]) -> Option<Cow<'_, str>> { - if line.contains(&0x83) { - let mut s = Vec::with_capacity(line.len()); - let mut is_meta = false; - for ch in line { - if *ch == 0x83 { - is_meta = true; - } else if is_meta { - is_meta = false; - s.push(*ch ^ 32); - } else { - s.push(*ch) - } - } - String::from_utf8(s).ok().map(Cow::Owned) - } else { - std::str::from_utf8(line).ok().map(Cow::Borrowed) - } -} - -#[cfg(test)] -mod test { - use itertools::assert_equal; - - use crate::import::tests::TestLoader; - - use super::*; - - #[test] - fn test_parse_extended_simple() { - let parsed = parse_extended("1613322469:0;cargo install atuin", 0); - - assert_eq!(parsed.command, "cargo install atuin"); - assert_eq!(parsed.duration, 0); - assert_eq!( - parsed.timestamp, - OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() - ); - - let parsed = parse_extended("1613322469:10;cargo install atuin;cargo update", 0); - - assert_eq!(parsed.command, "cargo install atuin;cargo update"); - assert_eq!(parsed.duration, 10_000_000_000); - assert_eq!( - parsed.timestamp, - OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() - ); - - let parsed = parse_extended("1613322469:10;cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", 0); - - assert_eq!(parsed.command, "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷"); - assert_eq!(parsed.duration, 10_000_000_000); - assert_eq!( - parsed.timestamp, - OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() - ); - - let parsed = parse_extended("1613322469:10;cargo install \\n atuin\n", 0); - - assert_eq!(parsed.command, "cargo install \\n atuin"); - assert_eq!(parsed.duration, 10_000_000_000); - assert_eq!( - parsed.timestamp, - OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() - ); - } - - #[tokio::test] - async fn test_parse_file() { - let bytes = r": 1613322469:0;cargo install atuin -: 1613322469:10;cargo install atuin; \\ -cargo update -: 1613322469:10;cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷ -" - .as_bytes() - .to_owned(); - - let mut zsh = Zsh { bytes }; - assert_eq!(zsh.entries().await.unwrap(), 4); - - let mut loader = TestLoader::default(); - zsh.load(&mut loader).await.unwrap(); - - assert_equal( - loader.buf.iter().map(|h| h.command.as_str()), - [ - "cargo install atuin", - "cargo install atuin; \\\ncargo update", - "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", - ], - ); - } - - #[tokio::test] - async fn test_parse_metafied() { - let bytes = - b"echo \xe4\xbd\x83\x80\xe5\xa5\xbd\nls ~/\xe9\x83\xbf\xb3\xe4\xb9\x83\xb0\n".to_vec(); - - let mut zsh = Zsh { bytes }; - assert_eq!(zsh.entries().await.unwrap(), 2); - - let mut loader = TestLoader::default(); - zsh.load(&mut loader).await.unwrap(); - - assert_equal( - loader.buf.iter().map(|h| h.command.as_str()), - ["echo 你好", "ls ~/音乐"], - ); - } -} diff --git a/crates/turtle/src/atuin_client/import/zsh_histdb.rs b/crates/turtle/src/atuin_client/import/zsh_histdb.rs deleted file mode 100644 index 46622e32..00000000 --- a/crates/turtle/src/atuin_client/import/zsh_histdb.rs +++ /dev/null @@ -1,249 +0,0 @@ -// import old shell history from zsh-histdb! -// automatically hoover up all that we can find - -// As far as i can tell there are no version numbers in the histdb sqlite DB, so we're going based -// on the schema from 2022-05-01 -// -// I have run into some histories that will not import b/c of non UTF-8 characters. -// - -// -// An Example sqlite query for hsitdb data: -// -//id|session|command_id|place_id|exit_status|start_time|duration|id|argv|id|host|dir -// -// -// select -// history.id, -// history.start_time, -// places.host, -// places.dir, -// commands.argv -// from history -// left join commands on history.command_id = commands.id -// left join places on history.place_id = places.id ; -// -// CREATE TABLE history (id integer primary key autoincrement, -// session int, -// command_id int references commands (id), -// place_id int references places (id), -// exit_status int, -// start_time int, -// duration int); -// - -use std::collections::HashMap; -use std::path::{Path, PathBuf}; - -use async_trait::async_trait; -use crate::atuin_common::utils::uuid_v7; -use directories::UserDirs; -use eyre::{Result, eyre}; -use sqlx::{Pool, sqlite::SqlitePool}; -use time::PrimitiveDateTime; - -use super::Importer; -use crate::atuin_client::history::History; -use crate::atuin_client::import::Loader; -use crate::atuin_client::utils::{get_hostname, get_username}; - -#[derive(sqlx::FromRow, Debug)] -pub(crate) struct HistDbEntryCount { - pub(crate) count: usize, -} - -#[derive(sqlx::FromRow, Debug)] -pub(crate) struct HistDbEntry { - pub(crate) id: i64, - pub(crate) start_time: PrimitiveDateTime, - pub(crate) host: Vec<u8>, - pub(crate) dir: Vec<u8>, - pub(crate) argv: Vec<u8>, - pub(crate) duration: i64, - pub(crate) exit_status: i64, - pub(crate) session: i64, -} - -#[derive(Debug)] -pub(crate) struct ZshHistDb { - histdb: Vec<HistDbEntry>, - username: String, -} - -/// Read db at given file, return vector of entries. -async fn hist_from_db(dbpath: PathBuf) -> Result<Vec<HistDbEntry>> { - let pool = SqlitePool::connect(dbpath.to_str().unwrap()).await?; - hist_from_db_conn(pool).await -} - -async fn hist_from_db_conn(pool: Pool<sqlx::Sqlite>) -> Result<Vec<HistDbEntry>> { - let query = r#" - SELECT - history.id, history.start_time, history.duration, places.host, places.dir, - commands.argv, history.exit_status, history.session - FROM history - LEFT JOIN commands ON history.command_id = commands.id - LEFT JOIN places ON history.place_id = places.id - ORDER BY history.start_time - "#; - let histdb_vec: Vec<HistDbEntry> = sqlx::query_as::<_, HistDbEntry>(query) - .fetch_all(&pool) - .await?; - Ok(histdb_vec) -} - -impl ZshHistDb { - pub(crate) fn histpath_candidate() -> PathBuf { - // By default histdb database is `${HOME}/.histdb/zsh-history.db` - // This can be modified by ${HISTDB_FILE} - // - // if [[ -z ${HISTDB_FILE} ]]; then - // typeset -g HISTDB_FILE="${HOME}/.histdb/zsh-history.db" - let user_dirs = UserDirs::new().unwrap(); // should catch error here? - let home_dir = user_dirs.home_dir(); - std::env::var("HISTDB_FILE") - .as_ref() - .map(|x| Path::new(x).to_path_buf()) - .unwrap_or_else(|_err| home_dir.join(".histdb/zsh-history.db")) - } - pub(crate) fn histpath() -> Result<PathBuf> { - let histdb_path = ZshHistDb::histpath_candidate(); - if histdb_path.exists() { - Ok(histdb_path) - } else { - Err(eyre!( - "Could not find history file. Try setting $HISTDB_FILE" - )) - } - } -} - -#[async_trait] -impl Importer for ZshHistDb { - // Not sure how this is used - const NAME: &'static str = "zsh_histdb"; - - /// Creates a new ZshHistDb and populates the history based on the pre-populated data - /// structure. - async fn new() -> Result<Self> { - let dbpath = ZshHistDb::histpath()?; - let histdb_entry_vec = hist_from_db(dbpath).await?; - Ok(Self { - histdb: histdb_entry_vec, - username: get_username(), - }) - } - - async fn entries(&mut self) -> Result<usize> { - Ok(self.histdb.len()) - } - - async fn load(self, h: &mut impl Loader) -> Result<()> { - let mut session_map = HashMap::new(); - for entry in self.histdb { - let command = match std::str::from_utf8(&entry.argv) { - Ok(s) => s.trim_end(), - Err(_) => continue, // we can skip past things like invalid utf8 - }; - let cwd = match std::str::from_utf8(&entry.dir) { - Ok(s) => s.trim_end(), - Err(_) => continue, // we can skip past things like invalid utf8 - }; - let hostname = format!( - "{}:{}", - String::from_utf8(entry.host).unwrap_or_else(|_e| get_hostname()), - self.username - ); - let session = session_map.entry(entry.session).or_insert_with(uuid_v7); - - let imported = History::import() - .timestamp(entry.start_time.assume_utc()) - .command(command) - .cwd(cwd) - .duration(entry.duration * 1_000_000_000) - .exit(entry.exit_status) - .session(session.as_simple().to_string()) - .hostname(hostname) - .build(); - h.push(imported.into()).await?; - } - Ok(()) - } -} - -#[cfg(test)] -mod test { - - use super::*; - use sqlx::sqlite::SqlitePoolOptions; - use std::env; - #[tokio::test(flavor = "multi_thread")] - #[expect(unsafe_code)] - async fn test_env_vars() { - let test_env_db = "nonstd-zsh-history.db"; - let key = "HISTDB_FILE"; - // TODO: Audit that the environment access only happens in single-threaded code. - unsafe { env::set_var(key, test_env_db) }; - - // test the env got set - assert_eq!(env::var(key).unwrap(), test_env_db.to_string()); - - // test histdb returns the proper db from previous step - let histdb_path = ZshHistDb::histpath_candidate(); - assert_eq!(histdb_path.to_str().unwrap(), test_env_db); - } - - #[tokio::test(flavor = "multi_thread")] - async fn test_import() { - let pool: SqlitePool = SqlitePoolOptions::new() - .min_connections(2) - .connect(":memory:") - .await - .unwrap(); - - // sql dump directly from a test database. - let db_sql = r#" - PRAGMA foreign_keys=OFF; - BEGIN TRANSACTION; - CREATE TABLE commands (id integer primary key autoincrement, argv text, unique(argv) on conflict ignore); - INSERT INTO commands VALUES(1,'pwd'); - INSERT INTO commands VALUES(2,'curl google.com'); - INSERT INTO commands VALUES(3,'bash'); - CREATE TABLE places (id integer primary key autoincrement, host text, dir text, unique(host, dir) on conflict ignore); - INSERT INTO places VALUES(1,'mbp16.local','/home/noyez'); - CREATE TABLE history (id integer primary key autoincrement, - session int, - command_id int references commands (id), - place_id int references places (id), - exit_status int, - start_time int, - duration int); - INSERT INTO history VALUES(1,0,1,1,0,1651497918,1); - INSERT INTO history VALUES(2,0,2,1,0,1651497923,1); - INSERT INTO history VALUES(3,0,3,1,NULL,1651497930,NULL); - DELETE FROM sqlite_sequence; - INSERT INTO sqlite_sequence VALUES('commands',3); - INSERT INTO sqlite_sequence VALUES('places',3); - INSERT INTO sqlite_sequence VALUES('history',3); - CREATE INDEX hist_time on history(start_time); - CREATE INDEX place_dir on places(dir); - CREATE INDEX place_host on places(host); - CREATE INDEX history_command_place on history(command_id, place_id); - COMMIT; "#; - - sqlx::query(db_sql).execute(&pool).await.unwrap(); - - // test histdb iterator - let histdb_vec = hist_from_db_conn(pool).await.unwrap(); - let histdb = ZshHistDb { - histdb: histdb_vec, - username: get_username(), - }; - - println!("h: {:#?}", histdb.histdb); - println!("counter: {:?}", histdb.histdb.len()); - for i in histdb.histdb { - println!("{i:?}"); - } - } -} diff --git a/crates/turtle/src/atuin_client/meta.rs b/crates/turtle/src/atuin_client/meta.rs index f3815b9e..c5c89512 100644 --- a/crates/turtle/src/atuin_client/meta.rs +++ b/crates/turtle/src/atuin_client/meta.rs @@ -12,7 +12,6 @@ use uuid::Uuid; const KEY_HOST_ID: &str = "host_id"; const KEY_LAST_SYNC: &str = "last_sync_time"; -const KEY_SESSION: &str = "session"; pub(crate) struct MetaStore { pool: SqlitePool, @@ -98,15 +97,6 @@ impl MetaStore { Ok(()) } - pub(crate) async fn delete(&self, key: &str) -> Result<()> { - sqlx::query("DELETE FROM meta WHERE key = ?1") - .bind(key) - .execute(&self.pool) - .await?; - - Ok(()) - } - // Typed accessors pub(crate) async fn host_id(&self) -> Result<HostId> { diff --git a/crates/turtle/src/atuin_client/mod.rs b/crates/turtle/src/atuin_client/mod.rs index a4323f56..530b7d81 100644 --- a/crates/turtle/src/atuin_client/mod.rs +++ b/crates/turtle/src/atuin_client/mod.rs @@ -2,13 +2,10 @@ pub(crate) mod api_client; pub(crate) mod database; -pub(crate) mod distro; pub(crate) mod encryption; pub(crate) mod history; -pub(crate) mod import; pub(crate) mod meta; pub(crate) mod ordering; -pub(crate) mod plugin; pub(crate) mod record; pub(crate) mod secrets; pub(crate) mod settings; diff --git a/crates/turtle/src/atuin_client/plugin.rs b/crates/turtle/src/atuin_client/plugin.rs deleted file mode 100644 index e97b1dbf..00000000 --- a/crates/turtle/src/atuin_client/plugin.rs +++ /dev/null @@ -1,150 +0,0 @@ -use std::collections::HashMap; - -#[derive(Debug, Clone)] -pub(crate) struct OfficialPlugin { - pub(crate) name: String, - pub(crate) description: String, - pub(crate) install_message: String, -} - -impl OfficialPlugin { - pub(crate) fn new(name: &str, description: &str, install_message: &str) -> Self { - Self { - name: name.to_string(), - description: description.to_string(), - install_message: install_message.to_string(), - } - } -} - -pub(crate) struct OfficialPluginRegistry { - plugins: HashMap<String, OfficialPlugin>, -} - -impl OfficialPluginRegistry { - pub(crate) fn new() -> Self { - let mut registry = Self { - plugins: HashMap::new(), - }; - - // Register official plugins - registry.register_official_plugins(); - - registry - } - - fn register_official_plugins(&mut self) { - // atuin-update plugin - self.plugins.insert( - "update".to_string(), - OfficialPlugin::new( - "update", - "Update atuin to the latest version", - "The 'atuin update' command is provided by the atuin-update plugin.\n\ - It is only installed if you used the install script\n \ - If you used a package manager (brew, apt, etc), please continue to use it for updates", - ), - ); - } - - pub(crate) fn get_plugin(&self, name: &str) -> Option<&OfficialPlugin> { - self.plugins.get(name) - } - - pub(crate) fn is_official_plugin(&self, name: &str) -> bool { - self.plugins.contains_key(name) - } - - pub(crate) fn get_install_message(&self, name: &str) -> Option<&str> { - self.plugins - .get(name) - .map(|plugin| plugin.install_message.as_str()) - } -} - -impl Default for OfficialPluginRegistry { - fn default() -> Self { - Self::new() - } -} - -pub(crate) struct PluginContext { - #[cfg(windows)] - _update_on_windows: Option<UpdateOnWindowsContext>, -} - -impl PluginContext { - pub(crate) fn new(_subcommand: &str) -> Self { - PluginContext { - #[cfg(windows)] - _update_on_windows: (_subcommand == "update").then(UpdateOnWindowsContext::new), - } - } -} - -impl Drop for PluginContext { - fn drop(&mut self) {} -} - -#[cfg(windows)] -struct UpdateOnWindowsContext { - initial_exe: Option<std::path::PathBuf>, -} - -#[cfg(windows)] -impl UpdateOnWindowsContext { - const OLD_FILE_NAME: &'static str = "atuin.old"; - - pub(crate) fn new() -> Self { - // Windows doesn't let you overwrite a running exe, but it lets you rename it, - // so make some room for atuin-update to install the new version. - let initial_exe = std::env::current_exe().ok().and_then(|exe| { - std::fs::rename(&exe, exe.with_file_name(Self::OLD_FILE_NAME)).ok()?; - Some(exe) - }); - - Self { initial_exe } - } -} - -#[cfg(windows)] -impl Drop for UpdateOnWindowsContext { - fn drop(&mut self) { - if let Some(exe) = &self.initial_exe - && !exe.exists() - { - // The update failed, roll back the current exe to its initial name. - std::fs::rename(exe.with_file_name(Self::OLD_FILE_NAME), exe).unwrap_or_else(|e| { - eprintln!("Failed to roll back the update, you may need to reinstall Atuin: {e}"); - }); - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_registry_creation() { - let registry = OfficialPluginRegistry::new(); - assert!(registry.is_official_plugin("update")); - assert!(!registry.is_official_plugin("nonexistent")); - } - - #[test] - fn test_get_plugin() { - let registry = OfficialPluginRegistry::new(); - let plugin = registry.get_plugin("update"); - assert!(plugin.is_some()); - assert_eq!(plugin.unwrap().name, "update"); - } - - #[test] - fn test_get_install_message() { - let registry = OfficialPluginRegistry::new(); - let message = registry.get_install_message("update"); - assert!(message.is_some()); - assert!(message.unwrap().contains("atuin-update")); - } -} diff --git a/crates/turtle/src/atuin_client/record/encryption.rs b/crates/turtle/src/atuin_client/record/encryption.rs index dabd0fa7..70723bb7 100644 --- a/crates/turtle/src/atuin_client/record/encryption.rs +++ b/crates/turtle/src/atuin_client/record/encryption.rs @@ -328,7 +328,7 @@ mod tests { .version("v0".to_owned()) .tag("kv".to_owned()) .host(Host::new(HostId(uuid_v7()))) - .timestamp(1687244806000000) + .timestamp(1_687_244_806_000_000) .data(DecryptedData(vec![1, 2, 3, 4])) .idx(0) .build(); @@ -351,7 +351,7 @@ mod tests { .version("v0".to_owned()) .tag("kv".to_owned()) .host(Host::new(HostId(uuid_v7()))) - .timestamp(1687244806000000) + .timestamp(1_687_244_806_000_000) .data(DecryptedData(vec![1, 2, 3, 4])) .idx(0) .build(); diff --git a/crates/turtle/src/atuin_client/record/sync.rs b/crates/turtle/src/atuin_client/record/sync.rs index 4284da87..9a7abfba 100644 --- a/crates/turtle/src/atuin_client/record/sync.rs +++ b/crates/turtle/src/atuin_client/record/sync.rs @@ -1,7 +1,7 @@ // do a sync :O use std::{cmp::Ordering, fmt::Write}; -use eyre::Result; +use eyre::{OptionExt, Result}; use thiserror::Error; use tracing::error; @@ -60,15 +60,15 @@ pub(crate) enum Operation { pub(crate) async fn build_client(settings: &Settings) -> Result<Client<'_>, SyncError> { Client::new( - &settings.sync_address, + &settings.sync.address, + settings.network_connect_timeout, + settings.network_timeout, settings - .sync_auth() - .await + .sync + .user_id() .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })? - .into_auth_token() + .ok_or_eyre("No sync user-id set") .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?, - settings.network_connect_timeout, - settings.network_timeout, ) .map_err(|e| SyncError::OperationalError { msg: e.to_string() }) } diff --git a/crates/turtle/src/atuin_client/secrets.rs b/crates/turtle/src/atuin_client/secrets.rs index 09058071..30723890 100644 --- a/crates/turtle/src/atuin_client/secrets.rs +++ b/crates/turtle/src/atuin_client/secrets.rs @@ -3,41 +3,54 @@ use regex::RegexSet; use std::sync::LazyLock; +#[cfg(test)] pub(crate) enum TestValue<'a> { Single(&'a str), Multiple(&'a [&'a str]), } +#[cfg(test)] +type SpType<'a> = &'a [(&'a str, &'a str, TestValue<'a>)]; + +#[cfg(not(test))] +type SpType<'a> = &'a [(&'a str, &'a str)]; + /// A list of `(name, regex, test)`, where `test` should match against `regex`. -pub(crate) static SECRET_PATTERNS: &[(&str, &str, TestValue)] = &[ +pub(crate) static SECRET_PATTERNS: SpType = &[ ( "AWS Access Key ID", "A[KS]IA[0-9A-Z]{16}", + #[cfg(test)] TestValue::Single("AKIAIOSFODNN7EXAMPLE"), ), ( "AWS Secret Access Key env var", "AWS_SECRET_ACCESS_KEY", + #[cfg(test)] TestValue::Single("AWS_SECRET_ACCESS_KEY=KEYDATA"), ), ( "AWS Session Token env var", "AWS_SESSION_TOKEN", + #[cfg(test)] TestValue::Single("AWS_SESSION_TOKEN=KEYDATA"), ), ( "Microsoft Azure secret access key env var", "AZURE_.*_KEY", + #[cfg(test)] TestValue::Single("export AZURE_STORAGE_ACCOUNT_KEY=KEYDATA"), ), ( "Google cloud platform key env var", "GOOGLE_SERVICE_ACCOUNT_KEY", + #[cfg(test)] TestValue::Single("export GOOGLE_SERVICE_ACCOUNT_KEY=KEYDATA"), ), ( "Atuin login", r"atuin\s+login", + #[cfg(test)] TestValue::Single( "atuin login -u mycoolusername -p mycoolpassword -k \"lots of random words\"", ), @@ -45,11 +58,13 @@ pub(crate) static SECRET_PATTERNS: &[(&str, &str, TestValue)] = &[ ( "GitHub PAT (old)", "ghp_[a-zA-Z0-9]{36}", + #[cfg(test)] TestValue::Single("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"), // legit, I expired it ), ( "GitHub PAT (new)", "gh1_[A-Za-z0-9]{21}_[A-Za-z0-9]{59}|github_pat_[0-9][A-Za-z0-9]{21}_[A-Za-z0-9]{59}", + #[cfg(test)] TestValue::Multiple(&[ "gh1_1234567890abcdefghijk_1234567890abcdefghijklmnopqrstuvwxyz1234567890abcdefghijklm", "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", // also legit, also expired @@ -58,16 +73,19 @@ pub(crate) static SECRET_PATTERNS: &[(&str, &str, TestValue)] = &[ ( "GitHub OAuth Access Token", "gho_[A-Za-z0-9]{36}", + #[cfg(test)] TestValue::Single("gho_1234567890abcdefghijklmnopqrstuvwx000"), // not a real token ), ( "GitHub OAuth Access Token (user)", "ghu_[A-Za-z0-9]{36}", + #[cfg(test)] TestValue::Single("ghu_1234567890abcdefghijklmnopqrstuvwx000"), // not a real token ), ( "GitHub App Installation Access Token", "ghs_[A-Za-z0-9._-]{36,}", + #[cfg(test)] TestValue::Multiple(&[ "ghs_1234567890abcdefghijklmnopqrstuvwx000", // not a real token "ghs_abc-def.ghi_jklMNOP0123456789qrstuv-wxyzABCD", // new token format, fake data @@ -76,6 +94,7 @@ pub(crate) static SECRET_PATTERNS: &[(&str, &str, TestValue)] = &[ ( "GitHub Refresh Token", "ghr_[A-Za-z0-9]{76}", + #[cfg(test)] TestValue::Single( "ghr_1234567890abcdefghijklmnopqrstuvwx1234567890abcdefghijklmnopqrstuvwx1234567890abcdefghijklmnopqrstuvwx", ), // not a real token @@ -83,26 +102,31 @@ pub(crate) static SECRET_PATTERNS: &[(&str, &str, TestValue)] = &[ ( "GitHub App Installation Access Token v1", "v1\\.[0-9A-Fa-f]{40}", + #[cfg(test)] TestValue::Single("v1.1234567890abcdef1234567890abcdef12345678"), // not a real token ), ( "GitLab PAT", "glpat-[a-zA-Z0-9_]{20}", + #[cfg(test)] TestValue::Single("glpat-RkE_BG5p_bbjML21WSfy"), ), ( "Slack OAuth v2 bot", "xoxb-[0-9]{11}-[0-9]{11}-[0-9a-zA-Z]{24}", + #[cfg(test)] TestValue::Single("xoxb-17653672481-19874698323-pdFZKVeTuE8sk7oOcBrzbqgy"), ), ( "Slack OAuth v2 user token", "xoxp-[0-9]{11}-[0-9]{11}-[0-9a-zA-Z]{24}", + #[cfg(test)] TestValue::Single("xoxp-17653672481-19874698323-pdFZKVeTuE8sk7oOcBrzbqgy"), ), ( "Slack webhook", "T[a-zA-Z0-9_]{8}/B[a-zA-Z0-9_]{8}/[a-zA-Z0-9_]{24}", + #[cfg(test)] TestValue::Single( "https://hooks.slack.com/services/T00000000/B00000000/XXXXXXXXXXXXXXXXXXXXXXXX", ), @@ -110,26 +134,31 @@ pub(crate) static SECRET_PATTERNS: &[(&str, &str, TestValue)] = &[ ( "Stripe test key", "sk_test_[0-9a-zA-Z]{24}", + #[cfg(test)] TestValue::Single("sk_test_1234567890abcdefghijklmnop"), ), ( "Stripe live key", "sk_live_[0-9a-zA-Z]{24}", + #[cfg(test)] TestValue::Single("sk_live_1234567890abcdefghijklmnop"), ), ( "Netlify authentication token", "nf[pcoub]_[0-9a-zA-Z]{36}", + #[cfg(test)] TestValue::Single("nfp_nBh7BdJxUwyaBBwFzpyD29MMFT6pZ9wq5634"), ), ( "npm token", "npm_[A-Za-z0-9]{36}", + #[cfg(test)] TestValue::Single("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"), ), ( "Pulumi personal access token", "pul-[0-9a-f]{40}", + #[cfg(test)] TestValue::Single("pul-683c2770662c51d960d72ec27613be7653c5cb26"), ), ]; @@ -144,7 +173,7 @@ pub(crate) static SECRET_PATTERNS_RE: LazyLock<RegexSet> = LazyLock::new(|| { mod tests { use regex::Regex; - use crate::secrets::{SECRET_PATTERNS, TestValue}; + use crate::atuin_client::secrets::{SECRET_PATTERNS, TestValue}; #[test] fn test_secrets() { diff --git a/crates/turtle/src/atuin_client/settings.rs b/crates/turtle/src/atuin_client/settings.rs index e8ff98ee..d84e2eb0 100644 --- a/crates/turtle/src/atuin_client/settings.rs +++ b/crates/turtle/src/atuin_client/settings.rs @@ -1,8 +1,13 @@ -use std::{collections::HashMap, fmt, io::prelude::*, path::PathBuf, str::FromStr, sync::OnceLock}; +use crypto_secretbox::Key; +use std::{ + collections::HashMap, fmt, fs::read_to_string, io::prelude::Write, path::PathBuf, str::FromStr, + sync::OnceLock, +}; use tokio::sync::OnceCell; +use uuid::Uuid; -use crate::atuin_common::record::HostId; use crate::atuin_common::utils; +use crate::{atuin_client::encryption::decode_key, atuin_common::record::HostId}; use clap::ValueEnum; use config::{ Config, ConfigBuilder, Environment, File as ConfigFile, FileFormat, builder::DefaultState, @@ -217,16 +222,6 @@ pub(crate) enum KeymapMode { Auto, } -impl KeymapMode { - pub(crate) fn as_str(&self) -> &'static str { - match self { - KeymapMode::Emacs => "EMACS", - KeymapMode::VimNormal => "VIMNORMAL", - KeymapMode::VimInsert => "VIMINSERT", - KeymapMode::Auto => "AUTO", - } - } -} // We want to translate the config to crossterm::cursor::SetCursorStyle, but // the original type does not implement trait serde::Deserialize unfortunately. @@ -257,19 +252,6 @@ pub(crate) enum CursorStyle { SteadyBar, } -impl CursorStyle { - pub(crate) fn as_str(&self) -> &'static str { - match self { - CursorStyle::DefaultUserShape => "DEFAULT", - CursorStyle::BlinkingBlock => "BLINKBLOCK", - CursorStyle::SteadyBlock => "STEADYBLOCK", - CursorStyle::BlinkingUnderScore => "BLINKUNDERLINE", - CursorStyle::SteadyUnderScore => "STEADYUNDERLINE", - CursorStyle::BlinkingBar => "BLINKBAR", - CursorStyle::SteadyBar => "STEADYBAR", - } - } -} #[derive(Clone, Debug, Deserialize, Serialize)] pub(crate) struct Stats { @@ -330,36 +312,6 @@ impl Default for Stats { } } -/// Resolved authentication state for sync operations. -/// -/// Determined at runtime by examining which tokens are available and what -/// server the client is configured to talk to. Operations use this to pick -/// the right auth header and endpoint style. -#[cfg(feature = "sync")] -#[derive(Debug, Clone)] -pub(crate) enum SyncAuth { - /// Self-hosted Rust server. Uses `Authorization: Token <session>` and - /// legacy endpoints. - Legacy { token: String }, - - /// Not authenticated at all. Contains an actionable user-facing message. - NotLoggedIn { reason: String }, -} - -#[cfg(feature = "sync")] -impl SyncAuth { - /// Convert into the auth token type used by the API client. - /// - /// Returns an error with an actionable message for `NotLoggedIn`. - pub(crate) fn into_auth_token(self) -> Result<crate::atuin_client::api_client::AuthToken> { - use crate::atuin_client::api_client::AuthToken; - match self { - SyncAuth::Legacy { token } => Ok(AuthToken(token)), - SyncAuth::NotLoggedIn { reason } => Err(eyre!(reason)), - } - } -} - #[derive(Clone, Debug, Deserialize, Default, Serialize)] #[expect(clippy::struct_excessive_bools)] pub(crate) struct Keys { @@ -781,14 +733,6 @@ impl UiColumn { column_type, } } - - pub(crate) fn with_width(column_type: UiColumnType, width: u16) -> Self { - Self { - column_type, - width, - expand: column_type == UiColumnType::Command, - } - } } // Custom deserialize to handle both string and object formats: @@ -902,6 +846,82 @@ impl Default for Ui { } } +/// Sync-specific settings. +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +pub(crate) struct Sync { + /// The sync address for atuin. + pub(crate) address: String, + + #[serde(default)] + pub(crate) frequency: String, + + #[serde(default)] + pub(crate) auto: bool, + + #[serde(default)] + pub(crate) user_id_path: Option<PathBuf>, + + #[serde(default)] + pub(crate) encryption_key_path: Option<PathBuf>, +} + +impl Sync { + fn try_read_file(file: Option<&PathBuf>) -> Result<Option<String>> { + if let Some(path) = file { + if path.try_exists()? { + let user = read_to_string(path)?; + + if user.is_empty() { + Ok(None) + } else { + Ok(Some(user)) + } + } else { + // It's okay that the file doesn't exist. + // The important part is to error out if we can't access it (e.g. Because of missing + // permissions). + Ok(None) + } + } else { + Ok(None) + } + } + + pub(crate) fn have_sync_user(&self) -> Result<bool> { + let sa = self.user_id()?; + Ok(sa.is_some()) + } + + pub(crate) fn user_id(&self) -> Result<Option<Uuid>> { + Self::try_read_file(self.user_id_path.as_ref())? + .map(|file| Uuid::try_parse(&file).map_err(Into::into)) + .transpose() + } + pub(crate) fn encryption_key(&self) -> Result<Option<Key>> { + Self::try_read_file(self.encryption_key_path.as_ref())? + .map(decode_key) + .transpose() + } + + pub(crate) async fn should_sync(&self) -> Result<bool> { + if !self.auto || !self.have_sync_user()? { + return Ok(false); + } + + if self.frequency == "0" || self.frequency.is_empty() { + return Ok(true); + } + + match parse_duration(self.frequency.as_str()) { + Ok(d) => { + let d = time::Duration::try_from(d)?; + Ok(OffsetDateTime::now_utc() - Settings::last_sync().await? >= d) + } + Err(e) => Err(eyre!("failed to check sync: {}", e)), + } + } +} + #[derive(Clone, Debug, Deserialize, Serialize)] #[expect(clippy::struct_excessive_bools)] pub(crate) struct Settings { @@ -909,15 +929,9 @@ pub(crate) struct Settings { pub(crate) dialect: Dialect, pub(crate) timezone: Timezone, pub(crate) style: Style, - pub(crate) auto_sync: bool, - - /// The sync address for atuin. - pub(crate) sync_address: String, - pub(crate) sync_frequency: String, pub(crate) db_path: String, pub(crate) record_store_path: String, - pub(crate) key_path: String, pub(crate) search_mode: SearchMode, pub(crate) filter_mode: Option<FilterMode>, pub(crate) filter_mode_shell_up_key_binding: Option<FilterMode>, @@ -963,6 +977,9 @@ pub(crate) struct Settings { pub(crate) command_chaining: bool, #[serde(default)] + pub(crate) sync: Sync, + + #[serde(default)] pub(crate) stats: Stats, #[serde(default)] @@ -994,24 +1011,6 @@ pub(crate) struct Settings { } impl Settings { - pub(crate) fn utc() -> Self { - Self::builder() - .expect("Could not build default") - .set_override("timezone", "0") - .expect("failed to override timezone with UTC") - .build() - .expect("Could not build config") - .try_deserialize() - .expect("Could not deserialize config") - } - - pub(crate) fn effective_data_dir() -> PathBuf { - DATA_DIR - .get() - .cloned() - .unwrap_or_else(crate::atuin_common::utils::data_dir) - } - // -- Meta store: lazily initialized on first access -- pub(crate) async fn meta_store() -> Result<&'static crate::atuin_client::meta::MetaStore> { @@ -1037,34 +1036,6 @@ impl Settings { Self::meta_store().await?.save_sync_time().await } - pub(crate) async fn should_sync(&self) -> Result<bool> { - if !self.auto_sync || !self.have_sync_key().await? { - return Ok(false); - } - - if self.sync_frequency == "0" { - return Ok(true); - } - - match parse_duration(self.sync_frequency.as_str()) { - Ok(d) => { - let d = time::Duration::try_from(d)?; - Ok(OffsetDateTime::now_utc() - Settings::last_sync().await? >= d) - } - Err(e) => Err(eyre!("failed to check sync: {}", e)), - } - } - - pub(crate) async fn have_sync_key(&self) -> Result<bool> { - let sa = self.sync_auth().await?; - Ok(matches!(sa, SyncAuth::Legacy { .. })) - } - - pub(crate) async fn sync_auth(&self) -> Result<SyncAuth> { - // TODO(@bpeetz): Add this <2026-06-11> - todo!() - } - pub(crate) fn default_filter_mode(&self, git_root: bool) -> FilterMode { self.filter_mode .filter(|x| self.search.filters.contains(x)) @@ -1213,7 +1184,7 @@ impl Settings { let config_dir = crate::atuin_common::utils::config_dir(); create_dir_all(&config_dir) - .wrap_err_with(|| format!("could not create dir {config_dir:?}"))?; + .wrap_err_with(|| format!("could not create dir {}", config_dir.display()))?; let mut config_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { PathBuf::from(p) @@ -1414,12 +1385,8 @@ impl Settings { } pub(crate) fn paths_ok(&self) -> bool { - let paths = [ - &self.db_path, - &self.record_store_path, - &self.key_path, - &self.meta.db_path, - ]; + // TODO(@bpeetz): Add the `sync.*` paths <2026-06-11> + let paths = [&self.db_path, &self.record_store_path, &self.meta.db_path]; paths.iter().all(|p| !utils::broken_symlink(p)) } } @@ -1437,20 +1404,6 @@ impl Default for Settings { } } -/// Initialize the meta store configuration for testing. -/// -/// This should only be used in tests. It allows tests to bypass the normal -/// Settings::new() flow while still being able to use Settings::host_id() -/// and other meta store dependent functions. -/// -/// # Safety -/// This function is not thread-safe with concurrent calls to Settings::new() -/// or other meta store initialization. Only call from tests. -#[doc(hidden)] -pub(crate) fn init_meta_config_for_testing(meta_db_path: impl Into<String>, local_timeout: f64) { - META_CONFIG.set((meta_db_path.into(), local_timeout)).ok(); -} - #[cfg(test)] pub(crate) fn test_local_timeout() -> f64 { std::env::var("ATUIN_TEST_LOCAL_TIMEOUT") diff --git a/crates/turtle/src/atuin_common/api.rs b/crates/turtle/src/atuin_common/api.rs index c18db04f..0868943d 100644 --- a/crates/turtle/src/atuin_common/api.rs +++ b/crates/turtle/src/atuin_common/api.rs @@ -11,28 +11,6 @@ pub(crate) static ATUIN_VERSION: LazyLock<Version> = LazyLock::new(|| Version::parse(ATUIN_CARGO_VERSION).expect("failed to parse self semver")); #[derive(Debug, Serialize, Deserialize)] -pub(crate) struct RegisterResponse { - pub(crate) session: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub(crate) struct ChangePasswordRequest { - pub(crate) current_password: String, - pub(crate) new_password: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub(crate) struct LoginRequest { - pub(crate) username: String, - pub(crate) password: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub(crate) struct LoginResponse { - pub(crate) session: String, -} - -#[derive(Debug, Serialize, Deserialize)] pub(crate) struct ErrorResponse<'a> { pub(crate) reason: Cow<'a, str>, } @@ -42,8 +20,3 @@ pub(crate) struct IndexResponse { pub(crate) homage: String, pub(crate) version: String, } - -#[derive(Debug, Serialize, Deserialize)] -pub(crate) struct MeResponse { - pub(crate) username: String, -} diff --git a/crates/turtle/src/atuin_common/calendar.rs b/crates/turtle/src/atuin_common/calendar.rs deleted file mode 100644 index befe8c2e..00000000 --- a/crates/turtle/src/atuin_common/calendar.rs +++ /dev/null @@ -1,16 +0,0 @@ -// Calendar data -use serde::{Serialize, Deserialize}; - -pub(crate) enum TimePeriod { - YEAR, - MONTH, - DAY, -} - -#[derive(Debug, Serialize, Deserialize)] -pub(crate) struct TimePeriodInfo { - pub(crate) count: u64, - - // TODO: Use this for merkle tree magic - pub(crate) hash: String, -} diff --git a/crates/turtle/src/atuin_common/shell.rs b/crates/turtle/src/atuin_common/shell.rs index dbd9b982..d259b99e 100644 --- a/crates/turtle/src/atuin_common/shell.rs +++ b/crates/turtle/src/atuin_common/shell.rs @@ -1,9 +1,3 @@ -use std::{ffi::OsStr, path::Path, process::Command}; - -use serde::Serialize; -use sysinfo::{Process, System, get_current_pid}; -use thiserror::Error; - #[derive(PartialEq)] pub(crate) enum Shell { Sh, @@ -35,91 +29,15 @@ impl std::fmt::Display for Shell { } } -#[derive(Debug, Error, Serialize)] -pub(crate) enum ShellError { - #[error("shell not supported")] - NotSupported, - - #[error("failed to execute shell command: {0}")] - ExecError(String), -} - impl Shell { - pub(crate) fn current() -> Shell { - let sys = System::new_all(); - - let process = sys - .process(get_current_pid().expect("Failed to get current PID")) - .expect("Process with current pid does not exist"); - - let parent = sys - .process(process.parent().expect("Atuin running with no parent!")) - .expect("Process with parent pid does not exist"); - - let shell = parent.name().trim().to_lowercase(); - let shell = shell.strip_prefix('-').unwrap_or(&shell); - - Shell::from_string(shell.to_string()) - } - pub(crate) fn from_env() -> Shell { std::env::var("ATUIN_SHELL").map_or(Shell::Unknown, |shell| { - Shell::from_string(shell.trim().to_lowercase()) + Shell::from_string(shell.trim().to_lowercase().as_str()) }) } - pub(crate) fn config_file(&self) -> Option<std::path::PathBuf> { - let mut path = if let Some(base) = directories::BaseDirs::new() { - base.home_dir().to_owned() - } else { - return None; - }; - - // TODO: handle all shells - match self { - Shell::Bash => path.push(".bashrc"), - Shell::Zsh => path.push(".zshrc"), - Shell::Fish => path.push(".config/fish/config.fish"), - - _ => return None, - }; - - Some(path) - } - - /// Best-effort attempt to determine the default shell - /// This implementation will be different across different platforms - /// Caller should ensure to handle Shell::Unknown correctly - pub(crate) fn default_shell() -> Result<Shell, ShellError> { - let sys = System::name().unwrap_or("".to_string()).to_lowercase(); - - // TODO: Support Linux - // I'm pretty sure we can use /etc/passwd there, though there will probably be some issues - let path = if sys.contains("darwin") { - // This works in my testing so far - Shell::Sh.run_interactive([ - "dscl localhost -read \"/Local/Default/Users/$USER\" shell | awk '{print $2}'", - ])? - } else if cfg!(windows) { - return Ok(Shell::Powershell); - } else { - Shell::Sh.run_interactive(["getent passwd $LOGNAME | cut -d: -f7"])? - }; - - let path = Path::new(path.trim()); - let shell = path.file_name(); - - if shell.is_none() { - return Err(ShellError::NotSupported); - } - - Ok(Shell::from_string( - shell.unwrap().to_string_lossy().to_string(), - )) - } - - pub(crate) fn from_string(name: String) -> Shell { - match name.as_str() { + pub(crate) fn from_string(name: &str) -> Shell { + match name { "bash" => Shell::Bash, "fish" => Shell::Fish, "zsh" => Shell::Zsh, @@ -131,53 +49,4 @@ impl Shell { _ => Shell::Unknown, } } - - /// Returns true if the shell is posix-like - /// Note that while fish is not posix compliant, it behaves well enough for our current - /// featureset that this does not matter. - pub(crate) fn is_posixish(&self) -> bool { - matches!(self, Shell::Bash | Shell::Fish | Shell::Zsh) - } - - pub(crate) fn run_interactive<I, S>(&self, args: I) -> Result<String, ShellError> - where - I: IntoIterator<Item = S>, - S: AsRef<OsStr>, - { - let shell = self.to_string(); - let output = if self == &Self::Powershell { - Command::new(shell) - .args(args) - .output() - .map_err(|e| ShellError::ExecError(e.to_string()))? - } else { - Command::new(shell) - .arg("-ic") - .args(args) - .output() - .map_err(|e| ShellError::ExecError(e.to_string()))? - }; - - Ok(String::from_utf8(output.stdout).unwrap()) - } -} - -pub(crate) fn shell_name(parent: Option<&Process>) -> String { - let sys = System::new_all(); - - let parent = if let Some(parent) = parent { - parent - } else { - let process = sys - .process(get_current_pid().expect("Failed to get current PID")) - .expect("Process with current pid does not exist"); - - sys.process(process.parent().expect("Atuin running with no parent!")) - .expect("Process with parent pid does not exist") - }; - - let shell = parent.name().trim().to_lowercase(); - let shell = shell.strip_prefix('-').unwrap_or(&shell); - - shell.to_string() } diff --git a/crates/turtle/src/atuin_common/utils.rs b/crates/turtle/src/atuin_common/utils.rs index 09718241..ba0c8eb7 100644 --- a/crates/turtle/src/atuin_common/utils.rs +++ b/crates/turtle/src/atuin_common/utils.rs @@ -2,30 +2,8 @@ use std::borrow::Cow; use std::env; use std::path::{Path, PathBuf}; -use base64::prelude::{BASE64_URL_SAFE_NO_PAD, Engine}; -use getrandom::getrandom; use uuid::Uuid; -/// Generate N random bytes, using a cryptographically secure source -pub(crate) fn crypto_random_bytes<const N: usize>() -> [u8; N] { - // rand say they are in principle safe for crypto purposes, but that it is perhaps a better - // idea to use getrandom for things such as passwords. - let mut ret = [0u8; N]; - - getrandom(&mut ret).expect("Failed to generate random bytes!"); - - ret -} - -/// Generate N random bytes using a cryptographically secure source, return encoded as a string -pub(crate) fn crypto_random_string<const N: usize>() -> String { - let bytes = crypto_random_bytes::<N>(); - - // We only use this to create a random string, and won't be reversing it to find the original - // data - no padding is OK there. It may be in URLs. - BASE64_URL_SAFE_NO_PAD.encode(bytes) -} - pub(crate) fn uuid_v7() -> Uuid { Uuid::now_v7() } diff --git a/crates/turtle/src/atuin_daemon/components/history.rs b/crates/turtle/src/atuin_daemon/components/history.rs index ec41977f..b71543c1 100644 --- a/crates/turtle/src/atuin_daemon/components/history.rs +++ b/crates/turtle/src/atuin_daemon/components/history.rs @@ -5,7 +5,6 @@ use std::{pin::Pin, sync::Arc}; use crate::atuin_client::{ - database::Database, history::{History, HistoryId, store::HistoryStore}, settings::Settings, }; diff --git a/crates/turtle/src/atuin_daemon/components/search.rs b/crates/turtle/src/atuin_daemon/components/search.rs index 17decdad..832d05d8 100644 --- a/crates/turtle/src/atuin_daemon/components/search.rs +++ b/crates/turtle/src/atuin_daemon/components/search.rs @@ -5,7 +5,6 @@ use std::{pin::Pin, sync::Arc}; -use crate::atuin_client::database::Database; use eyre::Result; use tokio::sync::RwLock; use tokio_stream::Stream; @@ -394,15 +393,6 @@ fn convert_filter_mode( } } -#[cfg(windows)] -pub(crate) fn with_trailing_slash(s: &str) -> String { - if s.ends_with('\\') { - s.to_string() - } else { - format!("{}\\", s) - } -} - #[cfg(not(windows))] pub(crate) fn with_trailing_slash(s: &str) -> String { if s.ends_with('/') { diff --git a/crates/turtle/src/atuin_daemon/components/sync.rs b/crates/turtle/src/atuin_daemon/components/sync.rs index fdd00b5f..fbfbbd67 100644 --- a/crates/turtle/src/atuin_daemon/components/sync.rs +++ b/crates/turtle/src/atuin_daemon/components/sync.rs @@ -141,7 +141,7 @@ async fn sync_loop(handle: DaemonHandle, mut cmd_rx: mpsc::Receiver<SyncCommand> // Skip periodic ticks if auto_sync is disabled AND we're not retrying // a previous failure. Retries must continue regardless of auto_sync. - if !settings.auto_sync && sync_state == SyncState::Idle { + if !settings.sync.auto && sync_state == SyncState::Idle { tracing::debug!("auto_sync disabled, skipping periodic sync tick"); continue; } @@ -190,7 +190,7 @@ async fn do_sync_tick( tracing::info!("sync tick"); // Check if logged in - let logged_in = match settings.have_sync_key().await { + let logged_in = match settings.sync.have_sync_user() { Ok(v) => v, Err(e) => { tracing::warn!("failed to check login status, skipping sync tick: {e}"); diff --git a/crates/turtle/src/atuin_daemon/daemon.rs b/crates/turtle/src/atuin_daemon/daemon.rs index 3268548e..7583c197 100644 --- a/crates/turtle/src/atuin_daemon/daemon.rs +++ b/crates/turtle/src/atuin_daemon/daemon.rs @@ -11,7 +11,7 @@ use std::sync::Arc; use crate::atuin_client::{ - database::Sqlite as HistoryDatabase, encryption, record::sqlite_store::SqliteStore, + database::ClientSqlite as HistoryDatabase, encryption, record::sqlite_store::SqliteStore, settings::Settings, }; use eyre::{Context, Result}; diff --git a/crates/turtle/src/atuin_daemon/mod.rs b/crates/turtle/src/atuin_daemon/mod.rs index eac28f78..6037b5a8 100644 --- a/crates/turtle/src/atuin_daemon/mod.rs +++ b/crates/turtle/src/atuin_daemon/mod.rs @@ -1,4 +1,4 @@ -use crate::atuin_client::database::Sqlite as HistoryDatabase; +use crate::atuin_client::database::ClientSqlite as HistoryDatabase; use crate::atuin_client::record::sqlite_store::SqliteStore; use crate::atuin_client::settings::{Settings, watcher::global_settings_watcher}; use eyre::Result; diff --git a/crates/turtle/src/atuin_daemon/search/index.rs b/crates/turtle/src/atuin_daemon/search/index.rs index 446d7992..a23b3133 100644 --- a/crates/turtle/src/atuin_daemon/search/index.rs +++ b/crates/turtle/src/atuin_daemon/search/index.rs @@ -14,8 +14,7 @@ use std::{ use crate::atuin_client::settings::Search; use crate::{ - atuin_client::history::{History, is_known_agent}, - atuin_daemon::components::search::with_trailing_slash, + atuin_client::history::History, atuin_daemon::components::search::with_trailing_slash, }; use atuin_nucleo::{Injector, Nucleo, pattern}; use dashmap::DashMap; @@ -195,7 +194,11 @@ impl CommandData { /// Check if any invocation matches a directory prefix (workspace/git root). /// O(n) where n = number of unique directories for this command. - pub(crate) fn has_invocation_in_workspace(&self, prefix: &str, interner: &ThreadedRodeo) -> bool { + pub(crate) fn has_invocation_in_workspace( + &self, + prefix: &str, + interner: &ThreadedRodeo, + ) -> bool { self.directories .iter() .any(|&spur| interner.resolve(&spur).starts_with(prefix)) @@ -289,10 +292,6 @@ impl SearchIndex { /// If the command already exists, updates its invocation data. /// If it's a new command, adds it to both the map and Nucleo. pub(crate) fn add_history(&self, history: &History) { - if is_known_agent(&history.author) { - return; - } - let command = history.command.as_str(); // DashMap with Arc<str> keys can be looked up with &str via Borrow trait diff --git a/crates/turtle/src/atuin_server/database/calendar.rs b/crates/turtle/src/atuin_server/database/calendar.rs deleted file mode 100644 index f1c78262..00000000 --- a/crates/turtle/src/atuin_server/database/calendar.rs +++ /dev/null @@ -1,18 +0,0 @@ -// Calendar data - -use serde::{Deserialize, Serialize}; -use time::Month; - -pub(crate) enum TimePeriod { - Year, - Month { year: i32 }, - Day { year: i32, month: Month }, -} - -#[derive(Debug, Serialize, Deserialize)] -pub(crate) struct TimePeriodInfo { - pub(crate) count: u64, - - // TODO: Use this for merkle tree magic - pub(crate) hash: String, -} diff --git a/crates/turtle/src/atuin_server/database/db/mod.rs b/crates/turtle/src/atuin_server/database/db/mod.rs index 4ec51bf1..5b3c169b 100644 --- a/crates/turtle/src/atuin_server/database/db/mod.rs +++ b/crates/turtle/src/atuin_server/database/db/mod.rs @@ -17,13 +17,13 @@ mod wrappers; const MIN_PG_VERSION: u32 = 14; #[derive(Clone)] -pub struct Database { +pub struct ServerPostgres { pool: sqlx::Pool<sqlx::postgres::Postgres>, /// Optional read replica pool for read-only queries read_pool: Option<sqlx::Pool<sqlx::postgres::Postgres>>, } -impl Database { +impl ServerPostgres { /// Returns the appropriate pool for read operations. /// Uses read_pool if available, otherwise falls back to the primary pool. fn read_pool(&self) -> &sqlx::Pool<sqlx::postgres::Postgres> { @@ -31,7 +31,7 @@ impl Database { } } -impl Database { +impl ServerPostgres { pub(crate) async fn new(settings: &DbSettings) -> DbResult<Self> { let pool = PgPoolOptions::new() .max_connections(100) @@ -138,7 +138,7 @@ impl Database { .entry((i.host.id, &i.tag)) .and_modify(|e| { if i.idx > *e { - *e = i.idx + *e = i.idx; } }) .or_insert(i.idx); @@ -229,7 +229,8 @@ impl Database { // 3. If we don't use the cache, read from the store table // IDX_CACHE_ROLLOUT should be between 0 and 100. - let idx_cache_rollout = std::env::var("IDX_CACHE_ROLLOUT").unwrap_or("0".to_string()); + let idx_cache_rollout = + std::env::var("IDX_CACHE_ROLLOUT").unwrap_or_else(|_| "0".to_string()); let idx_cache_rollout = idx_cache_rollout.parse::<f64>().unwrap_or(0.0); let use_idx_cache = rand::thread_rng().gen_bool(idx_cache_rollout / 100.0); @@ -264,7 +265,7 @@ impl Database { let mut status = RecordStatus::new(); - for i in res.iter() { + for i in &res { status.set_raw(HostId(i.0), i.1.clone(), i.2 as u64); } diff --git a/crates/turtle/src/atuin_server/database/db/wrappers.rs b/crates/turtle/src/atuin_server/database/db/wrappers.rs index 40fd5b4a..8a52d56e 100644 --- a/crates/turtle/src/atuin_server/database/db/wrappers.rs +++ b/crates/turtle/src/atuin_server/database/db/wrappers.rs @@ -1,22 +1,8 @@ -use crate::{ - atuin_common::record::{EncryptedData, Host, Record}, - atuin_server::database::models::Session, -}; +use crate::atuin_common::record::{EncryptedData, Host, Record}; use sqlx::{Row, postgres::PgRow}; -pub struct DbSession(pub Session); pub struct DbRecord(pub Record<EncryptedData>); -impl<'a> ::sqlx::FromRow<'a, PgRow> for DbSession { - fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> { - Ok(Self(Session { - id: row.try_get("id")?, - user_id: row.try_get("user_id")?, - token: row.try_get("token")?, - })) - } -} - impl<'a> ::sqlx::FromRow<'a, PgRow> for DbRecord { fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> { let timestamp: i64 = row.try_get("timestamp")?; diff --git a/crates/turtle/src/atuin_server/database/mod.rs b/crates/turtle/src/atuin_server/database/mod.rs index a009ae1f..bb64767a 100644 --- a/crates/turtle/src/atuin_server/database/mod.rs +++ b/crates/turtle/src/atuin_server/database/mod.rs @@ -1,4 +1,3 @@ -pub(crate) mod calendar; pub(crate) mod db; pub(crate) mod models; @@ -14,7 +13,10 @@ pub(crate) enum DbError { impl Display for DbError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{self:?}") + match self { + DbError::NotFound => write!(f, "Not found"), + DbError::Other(report) => write!(f, "Other: {report}"), + } } } diff --git a/crates/turtle/src/atuin_server/mod.rs b/crates/turtle/src/atuin_server/mod.rs index c96a13bc..a4b10acf 100644 --- a/crates/turtle/src/atuin_server/mod.rs +++ b/crates/turtle/src/atuin_server/mod.rs @@ -2,7 +2,7 @@ use std::future::Future; use std::net::SocketAddr; use axum::{Router, serve}; -use database::db::Database; +use database::db::ServerPostgres; use eyre::{Context, Result}; pub(crate) mod database; @@ -78,7 +78,7 @@ pub(crate) async fn launch_metrics_server(host: String, port: u16) -> Result<()> } async fn make_router(settings: Settings) -> Result<Router, eyre::Error> { - let db = Database::new(&settings.db_settings) + let db = ServerPostgres::new(&settings.db_settings) .await .wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?; let r = router::router(db, settings); diff --git a/crates/turtle/src/atuin_server/router.rs b/crates/turtle/src/atuin_server/router.rs index dfc2cac4..d9cfc979 100644 --- a/crates/turtle/src/atuin_server/router.rs +++ b/crates/turtle/src/atuin_server/router.rs @@ -1,6 +1,6 @@ use crate::{ atuin_common::api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ErrorResponse}, - atuin_server::database::{db::Database, models::User}, + atuin_server::database::{db::ServerPostgres, models::User}, }; use axum::{ Router, @@ -67,11 +67,11 @@ async fn semver(request: Request, next: Next) -> Response { #[derive(Clone)] pub(crate) struct AppState { - pub(crate) database: Database, + pub(crate) database: ServerPostgres, pub(crate) settings: Settings, } -pub(crate) fn router(database: Database, settings: Settings) -> Router { +pub(crate) fn router(database: ServerPostgres, settings: Settings) -> Router { let routes = Router::new() .route("/", get(handlers::index)) .route("/api/v0/{user_id}/record", post(handlers::v0::record::post)) diff --git a/crates/turtle/src/command/client.rs b/crates/turtle/src/command/client.rs index 9d5b4605..9ab28e15 100644 --- a/crates/turtle/src/command/client.rs +++ b/crates/turtle/src/command/client.rs @@ -5,7 +5,7 @@ use clap::Subcommand; use eyre::{Result, WrapErr}; use crate::atuin_client::{ - database::Sqlite, record::sqlite_store::SqliteStore, settings::Settings, theme, + database::ClientSqlite, record::sqlite_store::SqliteStore, settings::Settings, theme, }; use tracing_appender::rolling::{RollingFileAppender, Rotation}; use tracing_subscriber::{ @@ -48,9 +48,7 @@ mod daemon; mod config; mod default_config; -mod doctor; mod history; -mod import; mod info; mod init; mod search; @@ -66,18 +64,12 @@ pub(crate) enum Cmd { #[command(subcommand)] History(history::Cmd), - /// Import shell history from file - #[command(subcommand)] - Import(import::Cmd), - - /// Calculate statistics for your history - Stats(stats::Cmd), - /// Interactive history search Search(search::Cmd), #[cfg(feature = "sync")] - #[command(flatten)] + #[command(subcommand)] + /// Request a sync or view sync status Sync(sync::Cmd), /// Manage the atuin server @@ -96,11 +88,11 @@ pub(crate) enum Cmd { #[command()] Info, - /// Run the doctor to check for common issues - #[command()] - Doctor, + /// Calculate statistics for your history + Stats(stats::Cmd), #[command()] + /// Display a recap of your last year's history Wrapped { year: Option<i32> }, /// *Experimental* Manage the background daemon @@ -113,6 +105,7 @@ pub(crate) enum Cmd { DefaultConfig, #[command(subcommand)] + /// Manage your configuration Config(config::Cmd), } @@ -131,19 +124,27 @@ impl Cmd { let runtime = runtime.enable_all().build().unwrap(); - // For non-history commands, we want to initialize logging and the theme manager before - // doing anything else. History commands are performance-sensitive and run before and after - // every shell command, so we want to skip any unnecessary initialization for them. - let settings = Settings::new().wrap_err("could not load client settings")?; - let theme_manager = theme::ThemeManager::new(settings.theme.debug, None); - let res = runtime.block_on(self.run_inner(settings, theme_manager)); + // Start the server before descending into the client-specific setup code. + // We simply cannot setup settings or a theme on the server, because the client-specific + // stuff will error out. + let res = if let Self::Server(server) = self { + runtime.block_on(server.run()) + } else { + // For non-history commands, we want to initialize logging and the theme manager before + // doing anything else. History commands are performance-sensitive and run before and after + // every shell command, so we want to skip any unnecessary initialization for them. + let settings = Settings::new().wrap_err("could not load client settings")?; + let theme_manager = theme::ThemeManager::new(settings.theme.debug, None); + + runtime.block_on(self.run_inner(settings, theme_manager)) + }; runtime.shutdown_timeout(std::time::Duration::from_millis(50)); res } - #[expect(clippy::too_many_lines, clippy::future_not_send)] + #[expect(clippy::too_many_lines)] async fn run_inner( self, mut settings: Settings, @@ -306,7 +307,6 @@ impl Cmd { init.run(&settings); return Ok(()); } - Self::Doctor => return doctor::run(&settings).await, Self::Config(config) => return config.run(&settings).await, _ => {} } @@ -314,14 +314,13 @@ impl Cmd { let db_path = PathBuf::from(settings.db_path.as_str()); let record_store_path = PathBuf::from(settings.record_store_path.as_str()); - let db = Sqlite::new(db_path, settings.local_timeout).await?; + let db = ClientSqlite::new(db_path, settings.local_timeout).await?; let sqlite_store = SqliteStore::new(record_store_path, settings.local_timeout).await?; let theme_name = settings.theme.name.clone(); let theme = theme_manager.load_theme(theme_name.as_str(), settings.theme.max_depth); match self { - Self::Import(import) => import.run(&db).await, Self::Stats(stats) => stats.run(&db, &settings, theme).await, Self::Search(search) => search.run(db, &mut settings, sqlite_store, theme).await, @@ -330,12 +329,7 @@ impl Cmd { Self::Store(store) => store.run(&settings, &db, sqlite_store).await, - Self::Server(server) => server.run().await, - - Self::Info => { - info::run(&settings); - Ok(()) - } + Self::Info => info::run(&settings), Self::DefaultConfig => { default_config::run(); @@ -347,7 +341,7 @@ impl Cmd { #[cfg(feature = "daemon")] Self::Daemon(cmd) => cmd.run(settings, sqlite_store, db).await, - Self::History(_) | Self::Init(_) | Self::Doctor | Self::Config(_) => { + Self::History(_) | Self::Init(_) | Self::Config(_) | Self::Server(_) => { unreachable!() } } diff --git a/crates/turtle/src/command/client/daemon.rs b/crates/turtle/src/command/client/daemon.rs index 2fb090aa..cb5dd118 100644 --- a/crates/turtle/src/command/client/daemon.rs +++ b/crates/turtle/src/command/client/daemon.rs @@ -7,7 +7,7 @@ use std::process::{Command, Stdio}; use std::time::{Duration, Instant}; use crate::atuin_client::{ - database::Sqlite, history::History, record::sqlite_store::SqliteStore, settings::Settings, + database::ClientSqlite, history::History, record::sqlite_store::SqliteStore, settings::Settings, }; use crate::atuin_daemon::DaemonEvent; use crate::atuin_daemon::client::{ @@ -86,7 +86,7 @@ impl Cmd { self, settings: Settings, store: SqliteStore, - history_db: Sqlite, + history_db: ClientSqlite, ) -> Result<()> { match self.subcmd { None => { @@ -634,7 +634,7 @@ pub(crate) fn daemonize_current_process() -> Result<()> { async fn run( settings: Settings, store: SqliteStore, - history_db: Sqlite, + history_db: ClientSqlite, force: bool, ) -> Result<()> { if force { diff --git a/crates/turtle/src/command/client/doctor.rs b/crates/turtle/src/command/client/doctor.rs deleted file mode 100644 index eec690a5..00000000 --- a/crates/turtle/src/command/client/doctor.rs +++ /dev/null @@ -1,405 +0,0 @@ -use std::process::Command; -use std::{env, str::FromStr}; - -use crate::atuin_client::database::Sqlite; -use crate::atuin_client::settings::Settings; -use crate::atuin_common::shell::{Shell, shell_name}; -use crate::atuin_common::utils; -use colored::Colorize; -use eyre::Result; -use serde::Serialize; - -use sysinfo::{Disks, System, get_current_pid}; - -#[derive(Debug, Serialize)] -struct ShellInfo { - pub(crate) name: String, - - // best-effort, not supported on all OSes - pub(crate) default: String, - - // Detect some shell plugins that the user has installed. - // I'm just going to start with preexec/blesh - pub(crate) plugins: Vec<String>, - - // The preexec framework used in the current session, if Atuin is loaded. - pub(crate) preexec: Option<String>, -} - -impl ShellInfo { - // HACK ALERT! - // Many of the shell vars we need to detect are not exported :( - // So, we're going to run a interactive session and directly check the - // variable. There's a chance this won't work, so it should not be fatal. - // - // Every shell we support handles `shell -ic 'command'` - fn shellvar_exists(shell: &str, var: &str) -> bool { - let cmd = Command::new(shell) - .args([ - "-ic", - format!("[ -z ${var} ] || echo ATUIN_DOCTOR_ENV_FOUND").as_str(), - ]) - .output() - .map_or(String::new(), |v| { - let out = v.stdout; - String::from_utf8(out).unwrap_or_default() - }); - - cmd.contains("ATUIN_DOCTOR_ENV_FOUND") - } - - fn detect_preexec_framework(shell: &str) -> Option<String> { - if env::var("ATUIN_SESSION").ok().is_none() { - None - } else if shell.starts_with("bash") || shell == "sh" { - env::var("ATUIN_PREEXEC_BACKEND") - .ok() - .filter(|value| !value.is_empty()) - .and_then(|atuin_preexec_backend| { - atuin_preexec_backend.rfind(':').and_then(|pos_colon| { - u32::from_str(&atuin_preexec_backend[..pos_colon]) - .ok() - .is_some_and(|preexec_shlvl| { - env::var("SHLVL") - .ok() - .and_then(|shlvl| u32::from_str(&shlvl).ok()) - .is_some_and(|shlvl| shlvl == preexec_shlvl) - }) - .then(|| atuin_preexec_backend[pos_colon + 1..].to_string()) - }) - }) - } else { - Some("built-in".to_string()) - } - } - - fn validate_plugin_blesh( - _shell: &str, - shell_process: &sysinfo::Process, - ble_session_id: &str, - ) -> Option<String> { - ble_session_id - .split('/') - .nth(1) - .and_then(|field| u32::from_str(field).ok()) - .filter(|&blesh_pid| blesh_pid == shell_process.pid().as_u32()) - .map(|_| "blesh".to_string()) - } - - pub(crate) fn plugins(shell: &str, shell_process: &sysinfo::Process) -> Vec<String> { - // consider a different detection approach if there are plugins - // that don't set shell vars - - enum PluginShellType { - Any, - Bash, - - // Note: these are currently unused - #[expect(dead_code)] - Zsh, - #[expect(dead_code)] - Fish, - #[expect(dead_code)] - Nushell, - #[expect(dead_code)] - Xonsh, - } - - enum PluginProbeType { - EnvironmentVariable(&'static str), - InteractiveShellVariable(&'static str), - } - - type PluginValidator = fn(&str, &sysinfo::Process, &str) -> Option<String>; - - let plugin_list: [( - &str, - PluginShellType, - PluginProbeType, - Option<PluginValidator>, - ); 3] = [ - ( - "atuin", - PluginShellType::Any, - PluginProbeType::EnvironmentVariable("ATUIN_SESSION"), - None, - ), - ( - "blesh", - PluginShellType::Bash, - PluginProbeType::EnvironmentVariable("BLE_SESSION_ID"), - Some(Self::validate_plugin_blesh), - ), - ( - "bash-preexec", - PluginShellType::Bash, - PluginProbeType::InteractiveShellVariable("bash_preexec_imported"), - None, - ), - ]; - - plugin_list - .into_iter() - .filter(|(_, shell_type, _, _)| match shell_type { - PluginShellType::Any => true, - PluginShellType::Bash => shell.starts_with("bash") || shell == "sh", - PluginShellType::Zsh => shell.starts_with("zsh"), - PluginShellType::Fish => shell.starts_with("fish"), - PluginShellType::Nushell => shell.starts_with("nu"), - PluginShellType::Xonsh => shell.starts_with("xonsh"), - }) - .filter_map(|(plugin, _, probe_type, validator)| -> Option<String> { - match probe_type { - PluginProbeType::EnvironmentVariable(env) => { - env::var(env).ok().filter(|value| !value.is_empty()) - } - PluginProbeType::InteractiveShellVariable(shellvar) => { - ShellInfo::shellvar_exists(shell, shellvar).then_some(String::default()) - } - } - .and_then(|value| { - validator.map_or_else( - || Some(plugin.to_string()), - |validator| validator(shell, shell_process, &value), - ) - }) - }) - .collect() - } - - pub(crate) fn new() -> Self { - // TODO: rework to use crate::atuin_common::Shell - - let sys = System::new_all(); - - let process = sys - .process(get_current_pid().expect("Failed to get current PID")) - .expect("Process with current pid does not exist"); - - let parent = sys - .process(process.parent().expect("Atuin running with no parent!")) - .expect("Process with parent pid does not exist"); - - let name = shell_name(Some(parent)); - - let plugins = ShellInfo::plugins(name.as_str(), parent); - - let default = Shell::default_shell().unwrap_or(Shell::Unknown).to_string(); - - let preexec = Self::detect_preexec_framework(name.as_str()); - - Self { - name, - default, - plugins, - preexec, - } - } -} - -#[derive(Debug, Serialize)] -struct DiskInfo { - pub(crate) name: String, - pub(crate) filesystem: String, -} - -#[derive(Debug, Serialize)] -struct SystemInfo { - pub(crate) os: String, - - pub(crate) arch: String, - - pub(crate) version: String, - pub(crate) disks: Vec<DiskInfo>, -} - -impl SystemInfo { - pub(crate) fn new() -> Self { - let disks = Disks::new_with_refreshed_list(); - let disks = disks - .list() - .iter() - .map(|d| DiskInfo { - name: d.name().to_os_string().into_string().unwrap(), - filesystem: d.file_system().to_os_string().into_string().unwrap(), - }) - .collect(); - - Self { - os: System::name().unwrap_or_else(|| "unknown".to_string()), - arch: System::cpu_arch().unwrap_or_else(|| "unknown".to_string()), - version: System::os_version().unwrap_or_else(|| "unknown".to_string()), - disks, - } - } -} - -#[derive(Debug, Serialize)] -struct SyncInfo { - pub(crate) auth_state: String, - pub(crate) auto_sync: bool, - - pub(crate) last_sync: String, -} - -impl SyncInfo { - pub(crate) async fn new(settings: &Settings) -> Result<Self> { - let has_cli_token = settings.have_sync_key().await?; - - let auth_state = if has_cli_token { - "Self-hosted (authenticated)".into() - } else { - "Not authenticated".into() - }; - - Ok(Self { - auth_state, - auto_sync: settings.auto_sync, - last_sync: Settings::last_sync() - .await - .map_or_else(|_| "no last sync".to_string(), |v| v.to_string()), - }) - } -} - -#[derive(Debug)] -struct SettingPaths { - db: String, - record_store: String, - key: String, -} - -impl SettingPaths { - pub(crate) fn new(settings: &Settings) -> Self { - Self { - db: settings.db_path.clone(), - record_store: settings.record_store_path.clone(), - key: settings.key_path.clone(), - } - } - - pub(crate) fn verify(&self) { - let paths = vec![ - ("ATUIN_DB_PATH", &self.db), - ("ATUIN_RECORD_STORE", &self.record_store), - ("ATUIN_KEY", &self.key), - ]; - - for (path_env_var, path) in paths { - if utils::broken_symlink(path) { - eprintln!( - "{path} (${path_env_var}) is a broken symlink. This may cause issues with Atuin." - ); - } - } - } -} - -#[derive(Debug, Serialize)] -struct AtuinInfo { - pub(crate) version: String, - pub(crate) commit: String, - - /// Whether the main Atuin sync server is in use - /// I'm just calling it Atuin Cloud for lack of a better name atm - pub(crate) sync: Option<SyncInfo>, - - pub(crate) sqlite_version: String, - - #[serde(skip)] // probably unnecessary to expose this - pub(crate) setting_paths: SettingPaths, -} - -impl AtuinInfo { - pub(crate) async fn new(settings: &Settings) -> Result<Self> { - let logged_in = settings.have_sync_key().await?; - - let sync = if logged_in { - Some(SyncInfo::new(settings).await?) - } else { - None - }; - - let sqlite_version = match Sqlite::new("sqlite::memory:", 0.1).await { - Ok(db) => db - .sqlite_version() - .await - .unwrap_or_else(|_| "unknown".to_string()), - Err(_) => "error".to_string(), - }; - - Ok(Self { - version: crate::VERSION.to_string(), - commit: crate::SHA.to_string(), - sync, - sqlite_version, - setting_paths: SettingPaths::new(settings), - }) - } -} - -#[derive(Debug, Serialize)] -struct DoctorDump { - pub(crate) atuin: AtuinInfo, - pub(crate) shell: ShellInfo, - pub(crate) system: SystemInfo, -} - -impl DoctorDump { - pub(crate) async fn new(settings: &Settings) -> Result<Self> { - Ok(Self { - atuin: AtuinInfo::new(settings).await?, - shell: ShellInfo::new(), - system: SystemInfo::new(), - }) - } -} - -fn checks(info: &DoctorDump) { - println!(); // spacing - // - let zfs_error = "[Filesystem] ZFS is known to have some issues with SQLite. Atuin uses SQLite heavily. If you are having poor performance, there are some workarounds here: https://github.com/atuinsh/atuin/issues/952".bold().red(); - let bash_plugin_error = "[Shell] If you are using Bash, Atuin requires that either bash-preexec or ble.sh (>= 0.4) be installed. An older ble.sh may not be detected. so ignore this if you have ble.sh >= 0.4 set up! Read more here: https://docs.atuin.sh/guide/installation/#bash".bold().red(); - let blesh_integration_error = "[Shell] Atuin and ble.sh seem to be loaded in the session, but the integration does not seem to be working. Please check the setup in .bashrc.".bold().red(); - - // ZFS: https://github.com/atuinsh/atuin/issues/952 - if info.system.disks.iter().any(|d| d.filesystem == "zfs") { - println!("{zfs_error}"); - } - - info.atuin.setting_paths.verify(); - - // Shell - if info.shell.name == "bash" { - if !info - .shell - .plugins - .iter() - .any(|p| p == "blesh" || p == "bash-preexec") - { - println!("{bash_plugin_error}"); - } - - if info.shell.plugins.iter().any(|plugin| plugin == "atuin") - && info.shell.plugins.iter().any(|plugin| plugin == "blesh") - && info.shell.preexec.as_ref().is_some_and(|val| val == "none") - { - println!("{blesh_integration_error}"); - } - } -} - -pub(crate) async fn run(settings: &Settings) -> Result<()> { - println!("{}", "Atuin Doctor".bold()); - println!("Checking for diagnostics"); - let dump = DoctorDump::new(settings).await?; - - checks(&dump); - - let dump = serde_json::to_string_pretty(&dump)?; - - println!("\nPlease include the output below with any bug reports or issues\n"); - println!("{dump}"); - - Ok(()) -} diff --git a/crates/turtle/src/command/client/history.rs b/crates/turtle/src/command/client/history.rs index e533759b..693098c0 100644 --- a/crates/turtle/src/command/client/history.rs +++ b/crates/turtle/src/command/client/history.rs @@ -21,7 +21,7 @@ use serde::Serialize; use crate::atuin_daemon::history::{HistoryEventKind, TailHistoryReply}; use crate::atuin_client::{ - database::{Database, Sqlite, current_context}, + database::{ClientSqlite, current_context}, encryption, history::{History, store::HistoryStore}, record::sqlite_store::SqliteStore, @@ -411,7 +411,7 @@ fn normalize_command_for_storage<'a>(command: &'a str, settings: &Settings) -> & } async fn handle_start( - db: &impl Database, + db: &ClientSqlite, settings: &Settings, command: &str, author: Option<&str>, @@ -484,7 +484,7 @@ async fn handle_daemon_start( #[expect(unused_variables)] async fn handle_end( - db: &impl Database, + db: &ClientSqlite, store: SqliteStore, history_store: HistoryStore, settings: &Settings, @@ -527,7 +527,7 @@ async fn handle_end( db.update(&h).await?; history_store.push(h).await?; - if settings.should_sync().await? { + if settings.sync.should_sync().await? { let (_, downloaded) = record::sync::sync(settings, &store, &history_store.encryption_key).await?; Settings::save_sync_time().await?; @@ -564,7 +564,7 @@ pub(super) async fn start_history_entry( } let db_path = PathBuf::from(settings.db_path.as_str()); - let db = Sqlite::new(db_path, settings.local_timeout).await?; + let db = ClientSqlite::new(db_path, settings.local_timeout).await?; handle_start(&db, settings, command, author, intent).await } @@ -582,7 +582,7 @@ pub(super) async fn end_history_entry( let db_path = PathBuf::from(settings.db_path.as_str()); let record_store_path = PathBuf::from(settings.record_store_path.as_str()); - let db = Sqlite::new(db_path, settings.local_timeout).await?; + let db = ClientSqlite::new(db_path, settings.local_timeout).await?; let store = SqliteStore::new(record_store_path, settings.local_timeout).await?; let encryption_key: [u8; 32] = encryption::load_key(settings) @@ -922,7 +922,7 @@ impl Cmd { #[expect(clippy::too_many_arguments)] #[expect(clippy::fn_params_excessive_bools)] async fn handle_list( - db: &impl Database, + db: &ClientSqlite, settings: &Settings, context: crate::atuin_client::database::Context, session: bool, @@ -964,7 +964,7 @@ impl Cmd { } async fn handle_prune( - db: &impl Database, + db: &ClientSqlite, settings: &Settings, store: SqliteStore, context: crate::atuin_client::database::Context, @@ -1017,7 +1017,7 @@ impl Cmd { } async fn handle_dedup( - db: &impl Database, + db: &ClientSqlite, settings: &Settings, store: SqliteStore, before: i64, @@ -1119,7 +1119,7 @@ impl Cmd { let db_path = PathBuf::from(settings.db_path.as_str()); let record_store_path = PathBuf::from(settings.record_store_path.as_str()); - let db = Sqlite::new(db_path, settings.local_timeout).await?; + let db = ClientSqlite::new(db_path, settings.local_timeout).await?; let store = SqliteStore::new(record_store_path, settings.local_timeout).await?; let encryption_key: [u8; 32] = encryption::load_key(settings) @@ -1233,7 +1233,7 @@ mod tests { #[tokio::test] async fn handle_start_saves_trimmed_command() { - let db = Sqlite::new("sqlite::memory:", 2.0).await.unwrap(); + let db = ClientSqlite::new("sqlite::memory:", 2.0).await.unwrap(); let settings = Settings::utc(); handle_start(&db, &settings, "ls \t", None, None) @@ -1251,7 +1251,7 @@ mod tests { #[tokio::test] async fn handle_start_can_keep_trailing_whitespace() { - let db = Sqlite::new("sqlite::memory:", 2.0).await.unwrap(); + let db = ClientSqlite::new("sqlite::memory:", 2.0).await.unwrap(); let settings = Settings { strip_trailing_whitespace: false, ..Settings::utc() diff --git a/crates/turtle/src/command/client/import.rs b/crates/turtle/src/command/client/import.rs deleted file mode 100644 index 3ec524d2..00000000 --- a/crates/turtle/src/command/client/import.rs +++ /dev/null @@ -1,186 +0,0 @@ -use std::env; - -use async_trait::async_trait; -use clap::Parser; -use eyre::Result; -use indicatif::ProgressBar; - -use crate::atuin_client::{ - database::Database, - history::History, - import::{ - Importer, Loader, bash::Bash, fish::Fish, nu::Nu, nu_histdb::NuHistDb, - powershell::PowerShell, replxx::Replxx, resh::Resh, xonsh::Xonsh, - xonsh_sqlite::XonshSqlite, zsh::Zsh, zsh_histdb::ZshHistDb, - }, -}; - -#[derive(Parser, Debug)] -#[command(infer_subcommands = true)] -pub(crate) enum Cmd { - /// Import history for the current shell - Auto, - - /// Import history from the zsh history file - Zsh, - /// Import history from the zsh history file - ZshHistDb, - /// Import history from the bash history file - Bash, - /// Import history from the replxx history file - Replxx, - /// Import history from the resh history file - Resh, - /// Import history from the fish history file - Fish, - /// Import history from the nu history file - Nu, - /// Import history from the nu history file - NuHistDb, - /// Import history from xonsh json files - Xonsh, - /// Import history from xonsh sqlite db - XonshSqlite, - /// Import history from the powershell history file - Powershell, -} - -const BATCH_SIZE: usize = 100; - -impl Cmd { - #[expect(clippy::cognitive_complexity)] - pub(crate) async fn run<DB: Database>(&self, db: &DB) -> Result<()> { - println!(" Atuin "); - println!("======================"); - println!(" \u{1f30d} "); - println!(" \u{1f418}\u{1f418}\u{1f418}\u{1f418} "); - println!(" \u{1f422} "); - println!("======================"); - println!("Importing history..."); - - match self { - Self::Auto => { - if cfg!(windows) { - return if env::var("PSModulePath").is_ok() { - println!("Detected PowerShell"); - import::<PowerShell, DB>(db).await - } else { - println!("Could not detect the current shell."); - println!("Please run atuin import <SHELL>."); - println!("To view a list of shells, run atuin import."); - Ok(()) - }; - } - - // $XONSH_HISTORY_BACKEND isn't always set, but $XONSH_HISTORY_FILE is - let xonsh_histfile = - env::var("XONSH_HISTORY_FILE").unwrap_or_else(|_| String::new()); - let shell = env::var("SHELL").unwrap_or_else(|_| String::from("NO_SHELL")); - - if xonsh_histfile.to_lowercase().ends_with(".json") { - println!("Detected Xonsh"); - import::<Xonsh, DB>(db).await - } else if xonsh_histfile.to_lowercase().ends_with(".sqlite") { - println!("Detected Xonsh (SQLite backend)"); - import::<XonshSqlite, DB>(db).await - } else if shell.ends_with("/zsh") { - if ZshHistDb::histpath().is_ok() { - println!( - "Detected Zsh-HistDb, using :{}", - ZshHistDb::histpath().unwrap().to_str().unwrap() - ); - import::<ZshHistDb, DB>(db).await - } else { - println!("Detected ZSH"); - import::<Zsh, DB>(db).await - } - } else if shell.ends_with("/fish") { - println!("Detected Fish"); - import::<Fish, DB>(db).await - } else if shell.ends_with("/bash") { - println!("Detected Bash"); - import::<Bash, DB>(db).await - } else if shell.ends_with("/nu") { - if NuHistDb::histpath().is_ok() { - println!( - "Detected Nu-HistDb, using :{}", - NuHistDb::histpath().unwrap().to_str().unwrap() - ); - import::<NuHistDb, DB>(db).await - } else { - println!("Detected Nushell"); - import::<Nu, DB>(db).await - } - } else if shell.ends_with("/pwsh") { - println!("Detected PowerShell"); - import::<PowerShell, DB>(db).await - } else { - println!("cannot import {shell} history"); - Ok(()) - } - } - - Self::Zsh => import::<Zsh, DB>(db).await, - Self::ZshHistDb => import::<ZshHistDb, DB>(db).await, - Self::Bash => import::<Bash, DB>(db).await, - Self::Replxx => import::<Replxx, DB>(db).await, - Self::Resh => import::<Resh, DB>(db).await, - Self::Fish => import::<Fish, DB>(db).await, - Self::Nu => import::<Nu, DB>(db).await, - Self::NuHistDb => import::<NuHistDb, DB>(db).await, - Self::Xonsh => import::<Xonsh, DB>(db).await, - Self::XonshSqlite => import::<XonshSqlite, DB>(db).await, - Self::Powershell => import::<PowerShell, DB>(db).await, - } - } -} - -pub(crate) struct HistoryImporter<'db, DB: Database> { - pb: ProgressBar, - buf: Vec<History>, - db: &'db DB, -} - -impl<'db, DB: Database> HistoryImporter<'db, DB> { - fn new(db: &'db DB, len: usize) -> Self { - Self { - pb: ProgressBar::new(len as u64), - buf: Vec::with_capacity(BATCH_SIZE), - db, - } - } - - async fn flush(self) -> Result<()> { - if !self.buf.is_empty() { - self.db.save_bulk(&self.buf).await?; - } - self.pb.finish(); - Ok(()) - } -} - -#[async_trait] -impl<DB: Database> Loader for HistoryImporter<'_, DB> { - async fn push(&mut self, hist: History) -> Result<()> { - self.pb.inc(1); - self.buf.push(hist); - if self.buf.len() == self.buf.capacity() { - self.db.save_bulk(&self.buf).await?; - self.buf.clear(); - } - Ok(()) - } -} - -async fn import<I: Importer + Send, DB: Database>(db: &DB) -> Result<()> { - println!("Importing history from {}", I::NAME); - - let mut importer = I::new().await?; - let len = importer.entries().await.unwrap(); - let mut loader = HistoryImporter::new(db, len); - importer.load(&mut loader).await?; - loader.flush().await?; - - println!("Import complete!"); - Ok(()) -} diff --git a/crates/turtle/src/command/client/info.rs b/crates/turtle/src/command/client/info.rs index fc944987..49c92193 100644 --- a/crates/turtle/src/command/client/info.rs +++ b/crates/turtle/src/command/client/info.rs @@ -1,8 +1,9 @@ use crate::atuin_client::settings::Settings;
-
use crate::{SHA, VERSION};
-pub(crate) fn run(settings: &Settings) {
+use eyre::Result;
+
+pub(crate) fn run(settings: &Settings) -> Result<()> {
let config = crate::atuin_common::utils::config_dir();
let mut config_file = config.clone();
config_file.push("config.toml");
@@ -14,7 +15,7 @@ pub(crate) fn run(settings: &Settings) { config_file.to_string_lossy(),
sever_config.to_string_lossy(),
settings.db_path,
- settings.key_path,
+ settings.sync.encryption_key()?,
settings.meta.db_path
);
@@ -28,4 +29,6 @@ pub(crate) fn run(settings: &Settings) { let print_out = format!("{config_paths}\n\n{env_vars}\n\n{general_info}");
println!("{print_out}");
+
+ Ok(())
}
diff --git a/crates/turtle/src/command/client/init.rs b/crates/turtle/src/command/client/init.rs index 0643cb73..0cdcd425 100644 --- a/crates/turtle/src/command/client/init.rs +++ b/crates/turtle/src/command/client/init.rs @@ -89,7 +89,7 @@ $env.config = ( } } - fn static_init(&self, settings: &Settings) { + fn static_init(&self) { match self.shell { Shell::Zsh => { zsh::init_static(self.disable_up_arrow, self.disable_ctrl_r); @@ -119,6 +119,6 @@ $env.config = ( ); } - self.static_init(settings); + self.static_init(); } } diff --git a/crates/turtle/src/command/client/search.rs b/crates/turtle/src/command/client/search.rs index 72112084..962e6b1e 100644 --- a/crates/turtle/src/command/client/search.rs +++ b/crates/turtle/src/command/client/search.rs @@ -1,12 +1,12 @@ use std::fs::File; use std::io::{IsTerminal as _, Write, stderr, stdout}; +use crate::atuin_client::database::ClientSqlite; use crate::atuin_common::utils::{self, Escapable as _}; use clap::Parser; use eyre::Result; use crate::atuin_client::{ - database::Database, database::{OptFilters, current_context}, encryption, history::{History, store::HistoryStore}, @@ -157,7 +157,7 @@ impl Cmd { #[expect(clippy::too_many_lines)] pub(crate) async fn run( self, - db: impl Database, + db: ClientSqlite, settings: &mut Settings, store: SqliteStore, theme: &Theme, @@ -253,7 +253,6 @@ impl Cmd { offset: self.offset, reverse: self.reverse, include_duplicates: self.include_duplicates, - authors: self.author.clone().unwrap_or_default(), }; let mut entries = @@ -310,7 +309,7 @@ async fn run_non_interactive( settings: &Settings, filter_options: OptFilters, query: &[String], - db: &impl Database, + db: &ClientSqlite, ) -> Result<Vec<History>> { let dir = if filter_options.cwd.as_deref() == Some(".") { Some(utils::get_current_dir()) diff --git a/crates/turtle/src/command/client/search/engines.rs b/crates/turtle/src/command/client/search/engines.rs index d6335a38..a84c4798 100644 --- a/crates/turtle/src/command/client/search/engines.rs +++ b/crates/turtle/src/command/client/search/engines.rs @@ -1,9 +1,9 @@ -use async_trait::async_trait; use crate::atuin_client::{ - database::{Context, Database, OptFilters}, - history::{AUTHOR_FILTER_ALL_USER, History, HistoryId}, + database::{ClientSqlite, Context, OptFilters}, + history::{History, HistoryId}, settings::{FilterMode, SearchMode, Settings}, }; +use async_trait::async_trait; use eyre::Result; use super::cursor::Cursor; @@ -67,10 +67,10 @@ pub(crate) trait SearchEngine: Send + Sync + 'static { async fn full_query( &mut self, state: &SearchState, - db: &mut dyn Database, + db: &mut ClientSqlite, ) -> Result<Vec<History>>; - async fn query(&mut self, state: &SearchState, db: &mut dyn Database) -> Result<Vec<History>> { + async fn query(&mut self, state: &SearchState, db: &mut ClientSqlite) -> Result<Vec<History>> { if state.input.as_str().is_empty() { Ok(db .search( @@ -80,7 +80,6 @@ pub(crate) trait SearchEngine: Send + Sync + 'static { "", OptFilters { limit: Some(200), - authors: vec![AUTHOR_FILTER_ALL_USER.to_string()], ..Default::default() }, ) diff --git a/crates/turtle/src/command/client/search/engines/daemon.rs b/crates/turtle/src/command/client/search/engines/daemon.rs index df5ab9f8..55b3c6f2 100644 --- a/crates/turtle/src/command/client/search/engines/daemon.rs +++ b/crates/turtle/src/command/client/search/engines/daemon.rs @@ -1,6 +1,6 @@ use crate::atuin_client::{ - database::{Database, OptFilters}, - history::{AUTHOR_FILTER_ALL_USER, History}, + database::{ClientSqlite, OptFilters}, + history::History, settings::{SearchMode, Settings}, }; use crate::atuin_daemon::client::{DaemonClientErrorKind, SearchClient, classify_error}; @@ -75,7 +75,7 @@ impl Search { async fn fallback_to_db_search( &self, state: &SearchState, - db: &dyn Database, + db: &ClientSqlite, ) -> Result<Vec<History>> { let results = db .search( @@ -85,7 +85,6 @@ impl Search { state.input.as_str(), OptFilters { limit: Some(200), - authors: vec![AUTHOR_FILTER_ALL_USER.to_string()], ..Default::default() }, ) @@ -95,7 +94,7 @@ impl Search { } #[instrument(skip_all, level = Level::TRACE, name = "hydrate_from_db", fields(count = ids.len()))] - async fn hydrate_from_db(&self, db: &dyn Database, ids: &[String]) -> Result<Vec<History>> { + async fn hydrate_from_db(&self, db: &ClientSqlite, ids: &[String]) -> Result<Vec<History>> { let placeholders: Vec<String> = ids.iter().map(|id| format!("'{id}'")).collect(); let sql_query = format!( "SELECT * FROM history WHERE id IN ({}) ORDER BY timestamp DESC", @@ -111,7 +110,7 @@ impl SearchEngine for Search { async fn full_query( &mut self, state: &SearchState, - db: &mut dyn Database, + db: &mut ClientSqlite, ) -> Result<Vec<History>> { let query = state.input.as_str().to_string(); diff --git a/crates/turtle/src/command/client/search/engines/db.rs b/crates/turtle/src/command/client/search/engines/db.rs index 86917a02..e6657b17 100644 --- a/crates/turtle/src/command/client/search/engines/db.rs +++ b/crates/turtle/src/command/client/search/engines/db.rs @@ -1,12 +1,10 @@ use super::{SearchEngine, SearchState}; -use async_trait::async_trait; use crate::atuin_client::{ - database::Database, - database::OptFilters, - database::{QueryToken, QueryTokenizer}, - history::{AUTHOR_FILTER_ALL_USER, History}, + database::{ClientSqlite, OptFilters, QueryToken, QueryTokenizer}, + history::History, settings::SearchMode, }; +use async_trait::async_trait; use eyre::Result; use norm::Metric; use norm::fzf::{FzfParser, FzfV2}; @@ -21,7 +19,7 @@ impl SearchEngine for Search { async fn full_query( &mut self, state: &SearchState, - db: &mut dyn Database, + db: &mut ClientSqlite, ) -> Result<Vec<History>> { let results = db .search( @@ -31,7 +29,6 @@ impl SearchEngine for Search { state.input.as_str(), OptFilters { limit: Some(200), - authors: vec![AUTHOR_FILTER_ALL_USER.to_string()], ..Default::default() }, ) diff --git a/crates/turtle/src/command/client/search/engines/skim.rs b/crates/turtle/src/command/client/search/engines/skim.rs index fe2bdea3..a6a77573 100644 --- a/crates/turtle/src/command/client/search/engines/skim.rs +++ b/crates/turtle/src/command/client/search/engines/skim.rs @@ -1,18 +1,13 @@ use std::path::Path; +use crate::atuin_client::{database::ClientSqlite, history::History, settings::FilterMode}; use async_trait::async_trait; -use crate::atuin_client::{ - database::Database, - history::{History, is_known_agent}, - settings::FilterMode, -}; use eyre::Result; use fuzzy_matcher::{FuzzyMatcher, skim::SkimMatcherV2}; use itertools::Itertools; use time::OffsetDateTime; use tokio::task::yield_now; use tracing::{Level, instrument, warn}; -use uuid; use super::{SearchEngine, SearchState}; @@ -36,7 +31,7 @@ impl SearchEngine for Search { async fn full_query( &mut self, state: &SearchState, - db: &mut dyn Database, + db: &mut ClientSqlite, ) -> Result<Vec<History>> { if self.all_history.is_empty() { self.all_history = load_all_history(db).await; @@ -56,7 +51,7 @@ impl SearchEngine for Search { } #[instrument(skip_all, level = Level::TRACE, name = "load_all_history")] -async fn load_all_history(db: &dyn Database) -> Vec<(History, i32)> { +async fn load_all_history(db: &ClientSqlite) -> Vec<(History, i32)> { db.all_with_count().await.unwrap() } @@ -76,9 +71,7 @@ async fn fuzzy_search( if i % 256 == 0 { yield_now().await; } - if is_known_agent(&history.author) { - continue; - } + let context = &state.context; let git_root = context .git_root diff --git a/crates/turtle/src/command/client/search/interactive.rs b/crates/turtle/src/command/client/search/interactive.rs index 380fc33b..1d067e50 100644 --- a/crates/turtle/src/command/client/search/interactive.rs +++ b/crates/turtle/src/command/client/search/interactive.rs @@ -6,7 +6,10 @@ use std::{ #[cfg(unix)] use std::io::Read as _; -use crate::atuin_common::{shell::Shell, utils::Escapable as _}; +use crate::{ + atuin_client::database::ClientSqlite, + atuin_common::{shell::Shell, utils::Escapable as _}, +}; use eyre::Result; use time::OffsetDateTime; use unicode_width::{UnicodeWidthChar, UnicodeWidthStr}; @@ -17,7 +20,7 @@ use super::{ history_list::{HistoryList, ListState}, }; use crate::atuin_client::{ - database::{Context, Database, current_context}, + database::{Context, current_context}, history::{History, HistoryId, HistoryStats, store::HistoryStore}, settings::{ CursorStyle, ExitMode, FilterMode, KeymapMode, PreviewStrategy, SearchMode, Settings, @@ -149,7 +152,7 @@ struct StyleState { impl State { async fn query_results( &mut self, - db: &mut dyn Database, + db: &mut ClientSqlite, smart_sort: bool, ) -> Result<Vec<History>> { let results = self.engine.query(&self.search, db).await?; @@ -1550,7 +1553,7 @@ fn compute_popup_placement( pub(crate) async fn history( query: &[String], settings: &Settings, - mut db: impl Database, + mut db: ClientSqlite, history_store: &HistoryStore, theme: &Theme, ) -> Result<String> { diff --git a/crates/turtle/src/command/client/server.rs b/crates/turtle/src/command/client/server.rs index def1dfb3..d821d6f8 100644 --- a/crates/turtle/src/command/client/server.rs +++ b/crates/turtle/src/command/client/server.rs @@ -24,7 +24,6 @@ pub(crate) enum Cmd { } impl Cmd { - #[expect(clippy::too_many_lines)] pub(crate) async fn run(self) -> Result<()> { match self { Cmd::Start { host, port } => { diff --git a/crates/turtle/src/command/client/stats.rs b/crates/turtle/src/command/client/stats.rs index 98401cd3..17432bb2 100644 --- a/crates/turtle/src/command/client/stats.rs +++ b/crates/turtle/src/command/client/stats.rs @@ -3,11 +3,8 @@ use eyre::Result; use interim::parse_date_string; use time::{Duration, OffsetDateTime, Time}; -use crate::atuin_client::{ - database::{Database, current_context}, - settings::Settings, - theme::Theme, -}; +use crate::atuin_client::database::ClientSqlite; +use crate::atuin_client::{database::current_context, settings::Settings, theme::Theme}; use crate::atuin_history::stats::{compute, pretty_print}; @@ -39,7 +36,12 @@ pub(crate) struct Cmd { } impl Cmd { - pub(crate) async fn run(&self, db: &impl Database, settings: &Settings, theme: &Theme) -> Result<()> { + pub(crate) async fn run( + &self, + db: &ClientSqlite, + settings: &Settings, + theme: &Theme, + ) -> Result<()> { let context = current_context().await?; let words = if self.period.is_empty() { String::from("all") diff --git a/crates/turtle/src/command/client/store.rs b/crates/turtle/src/command/client/store.rs index 3e9355b5..347c4bee 100644 --- a/crates/turtle/src/command/client/store.rs +++ b/crates/turtle/src/command/client/store.rs @@ -2,9 +2,7 @@ use clap::Subcommand; use eyre::Result; use crate::atuin_client::{ - database::Database, - record::{sqlite_store::SqliteStore, store::Store}, - settings::Settings, + database::ClientSqlite, record::{sqlite_store::SqliteStore, store::Store}, settings::Settings }; use itertools::Itertools; use time::{OffsetDateTime, UtcOffset}; @@ -51,7 +49,7 @@ impl Cmd { pub(crate) async fn run( &self, settings: &Settings, - database: &dyn Database, + database: &ClientSqlite, store: SqliteStore, ) -> Result<()> { match self { diff --git a/crates/turtle/src/command/client/store/pull.rs b/crates/turtle/src/command/client/store/pull.rs index 6b709a64..f2e628d6 100644 --- a/crates/turtle/src/command/client/store/pull.rs +++ b/crates/turtle/src/command/client/store/pull.rs @@ -2,12 +2,7 @@ use clap::Args; use eyre::Result; use crate::atuin_client::{ - database::Database, - encryption::load_key, - record::store::Store, - record::sync::Operation, - record::{sqlite_store::SqliteStore, sync}, - settings::Settings, + database::ClientSqlite, encryption::load_key, record::{sqlite_store::SqliteStore, store::Store, sync::{self, Operation}}, settings::Settings }; #[derive(Args, Debug)] @@ -32,7 +27,7 @@ impl Pull { &self, settings: &Settings, store: SqliteStore, - db: &dyn Database, + db: &ClientSqlite, ) -> Result<()> { if self.force { println!("Forcing local overwrite!"); diff --git a/crates/turtle/src/command/client/store/push.rs b/crates/turtle/src/command/client/store/push.rs index 30177dbd..beec613c 100644 --- a/crates/turtle/src/command/client/store/push.rs +++ b/crates/turtle/src/command/client/store/push.rs @@ -1,6 +1,6 @@ use crate::atuin_common::record::HostId; use clap::Args; -use eyre::Result; +use eyre::{OptionExt, Result}; use uuid::Uuid; use crate::atuin_client::{ @@ -42,11 +42,12 @@ impl Push { println!("Clearing remote store"); let client = Client::new( - &settings.sync_address, - settings.sync_auth().await?.into_auth_token()?, + &settings.sync.address, settings.network_connect_timeout, - settings.network_timeout * 10, // we may be deleting a lot of data... so up the - // timeout + // we may be deleting a lot of data... so increase the + // timeout + settings.network_timeout * 10, + settings.sync.user_id()?.ok_or_eyre("no sync user-id")?, ) .expect("failed to create client"); diff --git a/crates/turtle/src/command/client/store/rebuild.rs b/crates/turtle/src/command/client/store/rebuild.rs index 0959b74e..bee1aa05 100644 --- a/crates/turtle/src/command/client/store/rebuild.rs +++ b/crates/turtle/src/command/client/store/rebuild.rs @@ -5,7 +5,7 @@ use eyre::{Result, bail}; use crate::command::client::daemon as daemon_cmd; use crate::atuin_client::{ - database::Database, encryption, history::store::HistoryStore, + database::ClientSqlite, encryption, history::store::HistoryStore, record::sqlite_store::SqliteStore, settings::Settings, }; @@ -19,7 +19,7 @@ impl Rebuild { &self, settings: &Settings, store: SqliteStore, - database: &dyn Database, + database: &ClientSqlite, ) -> Result<()> { // keep it as a string and not an enum atm // would be super cool to build this dynamically in the future @@ -41,7 +41,7 @@ impl Rebuild { &self, settings: &Settings, store: SqliteStore, - database: &dyn Database, + database: &ClientSqlite, ) -> Result<()> { let encryption_key: [u8; 32] = encryption::load_key(settings)?.into(); diff --git a/crates/turtle/src/command/client/store/rekey.rs b/crates/turtle/src/command/client/store/rekey.rs index 3472222f..b99fb16a 100644 --- a/crates/turtle/src/command/client/store/rekey.rs +++ b/crates/turtle/src/command/client/store/rekey.rs @@ -32,9 +32,15 @@ impl Rekey { store.re_encrypt(¤t_key, &new_key).await?; - println!("Store rewritten. Saving new key"); - let mut file = File::create(settings.key_path.clone()).await?; - file.write_all(key.as_bytes()).await?; + if let Some(key_path) = settings.sync.encryption_key_path.as_ref() { + println!("Store rewritten. Saving new key"); + let mut file = File::create(key_path).await?; + file.write_all(key.as_bytes()).await?; + } else { + println!( + "No key-path (settings.sync.encryption_key_path) set in config, will not save new key." + ); + } Ok(()) } diff --git a/crates/turtle/src/command/client/sync.rs b/crates/turtle/src/command/client/sync.rs index 7adf90ed..84b74cc1 100644 --- a/crates/turtle/src/command/client/sync.rs +++ b/crates/turtle/src/command/client/sync.rs @@ -1,12 +1,12 @@ use clap::Subcommand; use eyre::{Result, WrapErr}; +use serde_json::json; -use crate::atuin_client::{ - database::Database, - encryption, - history::store::HistoryStore, - record::{sqlite_store::SqliteStore, store::Store, sync}, - settings::Settings, +use crate::{ + atuin_client::{ + database::ClientSqlite, encryption, history::store::HistoryStore, record::{sqlite_store::SqliteStore, store::Store, sync}, settings::Settings + }, + atuin_common::utils, }; mod status; @@ -15,14 +15,14 @@ mod status; #[command(infer_subcommands = true)] pub(crate) enum Cmd { /// Sync with the configured server - Sync { + Perform { /// Force re-download everything #[arg(long, short)] force: bool, }, - /// Print the encryption key for transfer to another machine - Key {}, + /// Print (or generate) the encryption key and user id for transfer to another machine + KeyAndId {}, /// Display the sync status Status, @@ -32,18 +32,28 @@ impl Cmd { pub(crate) async fn run( self, settings: Settings, - db: &impl Database, + db: &ClientSqlite, store: SqliteStore, ) -> Result<()> { match self { - Self::Sync { force } => run(&settings, force, db, store).await, + Self::Perform { force } => run(&settings, force, db, store).await, Self::Status => status::run(&settings).await, - Self::Key {} => { + Self::KeyAndId {} => { use crate::atuin_client::encryption::{encode_key, load_key}; + let key = load_key(&settings).wrap_err("could not load encryption key")?; + let user_id = settings + .sync + .user_id() + .wrap_err("Failed to load user-id")? + .unwrap_or_else(utils::uuid_v7); + + let key = encode_key(&key).wrap_err("could not encode encryption key")?; + + let json = serde_json::to_string_pretty(&json!({ "key": key, "user_id": user_id })) + .expect("Will always be formattable"); - let encode = encode_key(&key).wrap_err("could not encode encryption key")?; - println!("{encode}"); + println!("{json}"); Ok(()) } @@ -54,7 +64,7 @@ impl Cmd { async fn run( settings: &Settings, force: bool, - db: &impl Database, + db: &ClientSqlite, store: SqliteStore, ) -> Result<()> { let encryption_key: [u8; 32] = encryption::load_key(settings) diff --git a/crates/turtle/src/command/client/sync/status.rs b/crates/turtle/src/command/client/sync/status.rs index 27b10dbd..e75171eb 100644 --- a/crates/turtle/src/command/client/sync/status.rs +++ b/crates/turtle/src/command/client/sync/status.rs @@ -1,36 +1,24 @@ -use crate::atuin_client::{api_client, settings::Settings}; +use crate::atuin_client::settings::Settings; use crate::{SHA, VERSION}; use colored::Colorize; use eyre::{Result, bail}; pub(crate) async fn run(settings: &Settings) -> Result<()> { - if !settings.have_sync_key().await? { - bail!("You are not logged in to a sync server - cannot show sync status"); - } - - let client = api_client::Client::new( - &settings.sync_address, - settings.sync_auth().await?.into_auth_token()?, - settings.network_connect_timeout, - settings.network_timeout, - )?; - - let me = client.me().await?; - let last_sync = Settings::last_sync().await?; + if let Some(me) = settings.sync.user_id()? { + let last_sync = Settings::last_sync().await?; - println!("Atuin v{VERSION} - Build rev {SHA}\n"); + println!("Atuin v{VERSION} - Build rev {SHA}\n"); - println!("{}", "[Local]".green()); - - if settings.auto_sync { - println!("Sync frequency: {}", settings.sync_frequency); + println!("{}", "[Local]".green()); + println!("Sync frequency: {}", settings.sync.frequency); println!("Last sync: {}", last_sync.to_offset(settings.timezone.0)); - } + println!("Auto sync: {}", settings.sync.auto); - if settings.auto_sync { println!("{}", "[Remote]".green()); - println!("Address: {}", settings.sync_address); - println!("Username: {}", me.username); + println!("Address: {}", settings.sync.address); + println!("User id: {}", me); + } else { + bail!("You are not logged in to a sync server - cannot show sync status"); } Ok(()) diff --git a/crates/turtle/src/command/client/wrapped.rs b/crates/turtle/src/command/client/wrapped.rs index 5e41657e..d502d3ec 100644 --- a/crates/turtle/src/command/client/wrapped.rs +++ b/crates/turtle/src/command/client/wrapped.rs @@ -3,7 +3,8 @@ use eyre::Result; use std::collections::{HashMap, HashSet}; use time::{Date, Duration, Month, OffsetDateTime, Time}; -use crate::atuin_client::{database::Database, settings::Settings, theme::Theme}; +use crate::atuin_client::database::ClientSqlite; +use crate::atuin_client::{settings::Settings, theme::Theme}; use crate::atuin_history::stats::{Stats, compute}; @@ -268,7 +269,7 @@ fn print_fun_facts(wrapped_stats: &WrappedStats, stats: &Stats, year: i32) { pub(crate) async fn run( year: Option<i32>, - db: &impl Database, + db: &ClientSqlite, settings: &Settings, theme: &Theme, ) -> Result<()> { diff --git a/crates/turtle/src/command/external.rs b/crates/turtle/src/command/external.rs deleted file mode 100644 index a5daea21..00000000 --- a/crates/turtle/src/command/external.rs +++ /dev/null @@ -1,102 +0,0 @@ -use std::fmt::Write as _; -use std::process::Command; -use std::{io, process}; - -#[cfg(feature = "client")] -use crate::atuin_client::plugin::{OfficialPluginRegistry, PluginContext}; -use clap::CommandFactory; -use clap::builder::{StyledStr, Styles}; -use eyre::Result; - -use crate::Atuin; - -pub(crate) fn run(args: &[String]) -> Result<()> { - let subcommand = &args[0]; - let bin = format!("atuin-{subcommand}"); - let mut cmd = Command::new(&bin); - cmd.args(&args[1..]); - - #[cfg(feature = "client")] - let context = PluginContext::new(subcommand); - - let spawn_result = match cmd.spawn() { - Ok(child) => Ok(child), - Err(e) => match e.kind() { - io::ErrorKind::NotFound => { - let output = render_not_found(subcommand, &bin); - Err(output) - } - _ => Err(e.to_string().into()), - }, - }; - - match spawn_result { - Ok(mut child) => { - let status = child.wait()?; - if status.success() { - Ok(()) - } else { - #[cfg(feature = "client")] - drop(context); - - process::exit(status.code().unwrap_or(1)); - } - } - Err(e) => { - eprintln!("{}", e.ansi()); - - #[cfg(feature = "client")] - drop(context); - - process::exit(1); - } - } -} - -fn render_not_found(subcommand: &str, bin: &str) -> StyledStr { - let mut output = StyledStr::new(); - let styles = Styles::styled(); - - let error = styles.get_error(); - let invalid = styles.get_invalid(); - let literal = styles.get_literal(); - - #[cfg(feature = "client")] - { - let registry = OfficialPluginRegistry::new(); - - // Check if this is an official plugin - if let Some(install_message) = registry.get_install_message(subcommand) { - let _ = write!(output, "{error}error:{error:#} "); - let _ = write!( - output, - "'{invalid}{subcommand}{invalid:#}' is an official atuin plugin, but it's not installed" - ); - let _ = write!(output, "\n\n"); - let _ = write!(output, "{install_message}"); - return output; - } - } - - let mut atuin_cmd = Atuin::command(); - let usage = atuin_cmd.render_usage(); - - let _ = write!(output, "{error}error:{error:#} "); - let _ = write!( - output, - "unrecognized subcommand '{invalid}{subcommand}{invalid:#}' " - ); - let _ = write!( - output, - "and no executable named '{invalid}{bin}{invalid:#}' found in your PATH" - ); - let _ = write!(output, "\n\n"); - let _ = write!(output, "{usage}"); - let _ = write!(output, "\n\n"); - let _ = write!( - output, - "For more information, try '{literal}--help{literal:#}'." - ); - - output -} diff --git a/crates/turtle/src/command/mod.rs b/crates/turtle/src/command/mod.rs index 5d5d839e..308e1970 100644 --- a/crates/turtle/src/command/mod.rs +++ b/crates/turtle/src/command/mod.rs @@ -11,8 +11,6 @@ mod contributors; mod gen_completions; -mod external; - #[derive(Subcommand)] #[command(infer_subcommands = true)] #[expect(clippy::large_enum_variant)] @@ -33,9 +31,6 @@ pub(crate) enum AtuinCmd { /// Generate shell completions GenCompletions(gen_completions::Cmd), - - #[command(external_subcommand)] - External(Vec<String>), } impl AtuinCmd { @@ -67,7 +62,6 @@ impl AtuinCmd { Ok(()) } Self::GenCompletions(gen_completions) => gen_completions.run(), - Self::External(args) => external::run(&args), } } } diff --git a/crates/turtle/src/sync.rs b/crates/turtle/src/sync.rs index cb743097..abe1a201 100644 --- a/crates/turtle/src/sync.rs +++ b/crates/turtle/src/sync.rs @@ -1,8 +1,8 @@ use eyre::{Context, Result}; +use crate::atuin_client::database::ClientSqlite; use crate::atuin_client::{ - database::Database, history::store::HistoryStore, record::sqlite_store::SqliteStore, - settings::Settings, + history::store::HistoryStore, record::sqlite_store::SqliteStore, settings::Settings, }; use crate::atuin_common::record::RecordId; @@ -15,7 +15,7 @@ use crate::atuin_common::record::RecordId; pub(crate) async fn build( settings: &Settings, store: &SqliteStore, - db: &dyn Database, + db: &ClientSqlite, downloaded: Option<&[RecordId]>, ) -> Result<()> { let encryption_key: [u8; 32] = crate::atuin_client::encryption::load_key(settings) @@ -10,12 +10,30 @@ }: let system = "x86_64-linux"; pkgs = nixpkgs.outputs.legacyPackages.${system}; + + turtle = pkgs.callPackage ./nix/package.nix {}; + tests = let + test-turtle = turtle.overrideAttrs { + cargoBuildType = "debug"; + }; + in + pkgs.testers.runNixOSTest { + imports = [./tests/basic.nix]; + defaults = { + services.turtle.package = test-turtle; + environment.systemPackages = [ + test-turtle + ]; + }; + }; in { packages."${system}" = { - atuin = pkgs.callPackage ./atuin.nix {}; - default = self.outputs.packages.${system}.atuin; + inherit turtle; + default = self.outputs.packages.${system}.turtle; }; + checks."${system}".default = tests; + devShells."${system}".default = self.packages.${system}.default.overrideAttrs (super: { nativeBuildInputs = super.nativeBuildInputs diff --git a/nix/module.nix b/nix/module.nix new file mode 100644 index 00000000..da6fd02c --- /dev/null +++ b/nix/module.nix @@ -0,0 +1,144 @@ +{ + config, + pkgs, + lib, + ... +}: let + inherit (lib) mkOption types mkIf; + cfg = config.services.turtle; +in { + options = { + services.turtle = { + enable = lib.mkEnableOption "turtle server for shell history sync"; + + package = lib.mkPackageOption pkgs "turtle" {}; + + host = mkOption { + type = types.str; + default = "127.0.0.1"; + description = "The host address the turtle server should listen on."; + }; + + port = mkOption { + type = types.port; + default = 8888; + description = "The port the turtle server should listen on."; + }; + + database = { + createLocally = mkOption { + type = types.bool; + default = true; + description = "Create the database and database user locally."; + }; + + uri = mkOption { + type = types.nullOr types.str; + default = "postgresql:///turtle?host=/run/postgresql"; + example = "postgresql://turtle@localhost:5432/turtle"; + description = '' + URI to the database. + Can be set to null in which case ATUIN_DB_URI should be set through an EnvironmentFile + ''; + }; + }; + + environmentFile = lib.mkOption { + type = lib.types.nullOr lib.types.externalPath; + default = null; + description = '' + Environment file, used to set any secret ATUIN_* environment variables, such as ATUIN_DB_URI containing a password. + See https://docs.atuin.sh/cli/self-hosting/server-setup/#configuration for available environment variables. + ''; + }; + }; + }; + + config = mkIf cfg.enable { + assertions = [ + { + assertion = cfg.database.createLocally -> config.services.postgresql.enable; + message = "Postgresql must be enabled to create a local database"; + } + ]; + + services.postgresql = mkIf cfg.database.createLocally { + enable = true; + ensureUsers = [ + { + name = "turtle"; + ensureDBOwnership = true; + } + ]; + ensureDatabases = ["turtle"]; + }; + + systemd.services.turtle = { + description = "turtle server"; + requires = lib.optionals cfg.database.createLocally ["postgresql.target"]; + after = + [ + "network-online.target" + ] + ++ lib.optionals cfg.database.createLocally ["postgresql.target"]; + wants = + [ + "network-online.target" + ] + ++ lib.optionals cfg.database.createLocally ["postgresql.target"]; + wantedBy = ["multi-user.target"]; + + serviceConfig = { + ExecStart = "${lib.getExe' cfg.package "atuin"} server start"; + EnvironmentFile = lib.mkIf (cfg.environmentFile != null) [cfg.environmentFile]; + RuntimeDirectory = "turtle"; + RuntimeDirectoryMode = "0700"; + DynamicUser = true; + + # Hardening + CapabilityBoundingSet = ""; + LockPersonality = true; + NoNewPrivileges = true; + MemoryDenyWriteExecute = true; + PrivateDevices = true; + PrivateMounts = true; + PrivateTmp = true; + PrivateUsers = true; + ProcSubset = "pid"; + ProtectClock = true; + ProtectControlGroups = true; + ProtectHome = true; + ProtectHostname = true; + ProtectKernelLogs = true; + ProtectKernelModules = true; + ProtectKernelTunables = true; + ProtectProc = "invisible"; + ProtectSystem = "full"; + RemoveIPC = true; + RestrictAddressFamilies = [ + "AF_INET" + "AF_INET6" + # Required for connecting to database sockets, + "AF_UNIX" + ]; + RestrictNamespaces = true; + RestrictRealtime = true; + RestrictSUIDSGID = true; + SystemCallArchitectures = "native"; + SystemCallFilter = [ + "@system-service" + "~@privileged" + ]; + UMask = "0077"; + }; + + environment = + { + ATUIN_HOST = cfg.host; + ATUIN_PORT = toString cfg.port; + ATUIN_CONFIG_DIR = "/run/turtle"; # required to start, but not used as configuration is via environment variables + } + // lib.optionalAttrs (cfg.database.uri != null) {ATUIN_DB_URI = cfg.database.uri;}; + }; + }; +} diff --git a/atuin.nix b/nix/package.nix index 22022835..82a348b8 100644 --- a/atuin.nix +++ b/nix/package.nix @@ -12,12 +12,31 @@ libiconv, }: rustPlatform.buildRustPackage { - name = "atuin"; + name = "turtle"; - src = lib.cleanSource ./.; + src = lib.cleanSourceWith { + src = lib.cleanSource ./..; + filter = name: type: + (type == "directory") + || (builtins.elem (builtins.baseNameOf name) [ + "Cargo.toml" + "Cargo.lock" + "CONTRIBUTORS" + + "atuin.bash" + "atuin.fish" + "atuin.nu" + "atuin.ps1" + "atuin.xsh" + "atuin.zsh" + ]) + || (lib.strings.hasSuffix ".rs" (builtins.baseNameOf name)) + || (lib.strings.hasSuffix ".proto" (builtins.baseNameOf name)) + || (lib.strings.hasSuffix ".sql" (builtins.baseNameOf name)); + }; cargoLock = { - lockFile = ./Cargo.lock; + lockFile = ../Cargo.lock; # Allow dependencies to be fetched from git and avoid having to set the outputHashes manually allowBuiltinFetchGit = true; }; diff --git a/tests/basic.nix b/tests/basic.nix new file mode 100644 index 00000000..7495d093 --- /dev/null +++ b/tests/basic.nix @@ -0,0 +1,185 @@ +{pkgs, ...}: { + name = "turtle-sync"; + + node = {}; + + nodes = let + atuinSession = "01969ec6b8d07e30a9d2df0911fbfe2a"; + in { + acme = { + imports = [ + ./common/acme/server.nix + ./common/dns/client.nix + ../nix/module.nix + ]; + }; + name_server = {nodes, ...}: { + imports = [ + ./common/acme/client.nix + ./common/dns/server.nix + ../nix/module.nix + ]; + + vhack.dns.zones = { + "turtle-sync.server" = { + SOA = { + nameServer = "ns"; + adminEmail = "admin@server.com"; + serial = 2025012301; + }; + useOrigin = false; + + A = [ + nodes.server.networking.primaryIPAddress + ]; + AAAA = [ + nodes.server.networking.primaryIPv6Address + ]; + }; + }; + }; + server = {config, ...}: let + turtleCfg = config.services.turtle; + in { + imports = [ + ../nix/module.nix + ./common/acme/client.nix + ./common/dns/client.nix + ]; + + config = { + services = { + postgresql.enable = true; + turtle = { + enable = true; + host = "127.0.0.1"; + database.createLocally = true; + }; + nginx = { + enable = true; + + recommendedTlsSettings = true; + recommendedOptimisation = true; + recommendedGzipSettings = true; + recommendedProxySettings = true; + + virtualHosts."turtle-sync.server" = { + locations."/" = { + proxyPass = "http://${turtleCfg.host}:${toString turtleCfg.port}"; + + recommendedProxySettings = true; + proxyWebsockets = true; + }; + + enableACME = true; + forceSSL = true; + }; + }; + }; + networking.firewall = { + allowedTCPPorts = [80 443]; + }; + }; + }; + + client1 = { + config, + pkgs, + ... + }: { + imports = [ + ../nix/module.nix + ./common/acme/client.nix + ./common/dns/client.nix + ]; + + environment.sessionVariables.ATUIN_SESSION = atuinSession; + }; + client2 = { + config, + pkgs, + ... + }: { + imports = [ + ../nix/module.nix + ./common/acme/client.nix + ./common/dns/client.nix + ]; + + environment.sessionVariables.ATUIN_SESSION = atuinSession; + }; + }; + + testScript = {nodes, ...}: let + mkSyncConfig = pkgs.writeShellScript "write-turtle-sync-config" '' + mkdir --parents ~/.config/atuin/ + cat << EOF > ~/.config/atuin/config.toml + + [sync] + address = "https://turtle-sync.server" + user_id_path = "${pkgs.writeText "user-id" "019eb88a-6b51-7e52-b12c-7d30bd8e5928"}" + encryption_key_path = "${pkgs.writeText "encryption-key" "3AAgbWsDzL7M00/Mq0LMjsyOCy3MnsypBsyQzKbMywNGzNnMrUBozIINAxdbIiDMhQ=="}" + EOF + ''; + + runCommandAndRecordInTurtle = pkgs.writeShellScript "run-command-and-record-in-turtle" '' + # SPDX-SnippetBegin + # SPDX-SnippetCopyrightText: 2023 mentalisttraceur (https://github.com/mentalisttraceur) + # Source: https://github.com/atuinsh/atuin/issues/1188#issuecomment-1698354107 + run_and_record_in_turtle() + { + local id + local status + local escaped_command="$(printf '%q ' "$@")" + id="$(atuin history start -- "$escaped_command")" + "$@" + status=$? + atuin history end --exit $status "$id" + return $status + } + # SPDX-SnippetEnd + + run_and_record_in_turtle "$@" + ''; + + acme = import ./common/acme {inherit pkgs;}; + in + acme.prepare ["server" "client1" "client2"] + # Python + '' + server.wait_for_unit("turtle.service") + server.wait_for_open_port(443) + + # Wait for the server to acquire the acme certificate + client1.wait_until_succeeds("curl https://turtle-sync.server") + + with subtest("Setup client syncing"): + for client in [client1, client2]: + client.succeed("${mkSyncConfig}") + + with subtest("Can generate shell history"): + client1.succeed("${runCommandAndRecordInTurtle} echo hi - client 1") + client2.succeed("${runCommandAndRecordInTurtle} echo hi - client 2") + + with subtest("Can sync"): + for client in [client1, client2]: + client.succeed("atuin sync perform --force") + client1.succeed("atuin sync perform --force") + + + with subtest("Have correct tasks"): + hist1 = client1.succeed("atuin history list --format '{command}'").strip().split('\n') + hist2 = client2.succeed("atuin history list --format '{command}'").strip().split('\n') + + hist1.sort() + hist2.sort() + + canonicalHistory = [ + "echo hi - client 1", + "echo hi - client 2" + ] + + assert hist1 == hist2, f"The clients don't have the same amount of history items, client1: '{hist1}', client2: '{hist2}'" + assert hist1 == canonicalHistory, f"The history is not correct: '{hist1}' vs. '{canonicalHistory}'" + ''; +} diff --git a/tests/common/acme/certs/generate b/tests/common/acme/certs/generate new file mode 100755 index 00000000..0d6258eb --- /dev/null +++ b/tests/common/acme/certs/generate @@ -0,0 +1,66 @@ +#! /usr/bin/env nix-shell +#! nix-shell -p gnutls -p dash -i dash --impure +# shellcheck shell=dash + +# For development and testing. +# Create a CA key and cert, and use that to generate a server key and cert. +# Creates: +# ca.key.pem +# ca.cert.pem +# server.key.pem +# server.cert.pem + +export SEC_PARAM=ultra +export EXPIRATION_DAYS=123456 +export ORGANIZATION="Vhack.eu Test Keys" +export COUNTRY=EU +export SAN="acme.test" +export KEY_TYPE="ed25519" + +BASEDIR="$(dirname "$0")" +GENERATION_LOCATION="$BASEDIR/output" +cd "$BASEDIR" || { + echo "(BUG?) No basedir ('$BASEDIR')" 1>&2 + exit 1 +} + +ca=false +clients=false + +usage() { + echo "Usage: $0 --ca|--clients" + exit 2 +} + +if [ "$#" -eq 0 ]; then + usage +fi + +for arg in "$@"; do + case "$arg" in + "--ca") + ca=true + ;; + "--clients") + clients=true + ;; + *) + usage + ;; + esac +done + +[ -d "$GENERATION_LOCATION" ] || mkdir --parents "$GENERATION_LOCATION" +cd "$GENERATION_LOCATION" || echo "(BUG?) No generation location fould!" 1>&2 + +[ "$ca" = true ] && ../generate.ca + +# Creates: +# <client_name>.key.pem +# <client_name>.cert.pem +# +[ "$clients" = true ] && ../generate.client "acme.test" + +echo "(INFO) Look for the keys at: $GENERATION_LOCATION" + +# vim: ft=sh diff --git a/tests/common/acme/certs/generate.ca b/tests/common/acme/certs/generate.ca new file mode 100755 index 00000000..92832c54 --- /dev/null +++ b/tests/common/acme/certs/generate.ca @@ -0,0 +1,38 @@ +#! /usr/bin/env sh + +# Take the correct binary to create the certificates +CERTTOOL=$(command -v gnutls-certtool 2>/dev/null || command -v certtool 2>/dev/null) +if [ -z "$CERTTOOL" ]; then + echo "ERROR: No certtool found" >&2 + exit 1 +fi + +# Create a CA key. +$CERTTOOL \ + --generate-privkey \ + --sec-param "$SEC_PARAM" \ + --key-type "$KEY_TYPE" \ + --outfile ca.key.pem + +chmod 600 ca.key.pem + +# Sign a CA cert. +cat <<EOF >ca.template +country = $COUNTRY +dns_name = "$SAN" +expiration_days = $EXPIRATION_DAYS +organization = $ORGANIZATION +ca +EOF +#state = $STATE +#locality = $LOCALITY + +$CERTTOOL \ + --generate-self-signed \ + --load-privkey ca.key.pem \ + --template ca.template \ + --outfile ca.cert.pem + +chmod 600 ca.cert.pem + +# vim: ft=sh diff --git a/tests/common/acme/certs/generate.client b/tests/common/acme/certs/generate.client new file mode 100755 index 00000000..5930298a --- /dev/null +++ b/tests/common/acme/certs/generate.client @@ -0,0 +1,44 @@ +#! /usr/bin/env sh + +# Take the correct binary to create the certificates +CERTTOOL=$(command -v gnutls-certtool 2>/dev/null || command -v certtool 2>/dev/null) +if [ -z "$CERTTOOL" ]; then + echo "ERROR: No certtool found" >&2 + exit 1 +fi + +NAME=client +if [ $# -gt 0 ]; then + NAME="$1" +fi + +# Create a client key. +$CERTTOOL \ + --generate-privkey \ + --sec-param "$SEC_PARAM" \ + --key-type "$KEY_TYPE" \ + --outfile "$NAME".key.pem + +chmod 600 "$NAME".key.pem + +# Sign a client cert with the key. +cat <<EOF >"$NAME".template +dns_name = "$NAME" +dns_name = "$SAN" +expiration_days = $EXPIRATION_DAYS +organization = $ORGANIZATION +encryption_key +signing_key +EOF + +$CERTTOOL \ + --generate-certificate \ + --load-privkey "$NAME".key.pem \ + --load-ca-certificate ca.cert.pem \ + --load-ca-privkey ca.key.pem \ + --template "$NAME".template \ + --outfile "$NAME".cert.pem + +chmod 600 "$NAME".cert.pem + +# vim: ft=sh diff --git a/tests/common/acme/certs/output/acme.test.cert.pem b/tests/common/acme/certs/output/acme.test.cert.pem new file mode 100644 index 00000000..687101d1 --- /dev/null +++ b/tests/common/acme/certs/output/acme.test.cert.pem @@ -0,0 +1,11 @@ +-----BEGIN CERTIFICATE----- +MIIBjTCCAT+gAwIBAgIUfiDKld3eiPKuFhsaiHpPNmbMJU8wBQYDK2VwMCoxCzAJ +BgNVBAYTAkVVMRswGQYDVQQKExJWaGFjay5ldSBUZXN0IEtleXMwIBcNMjUwMzAx +MTEyNjU2WhgPMjM2MzAzMDYxMTI2NTZaMB0xGzAZBgNVBAoTElZoYWNrLmV1IFRl +c3QgS2V5czAqMAUGAytlcAMhAHYq2cjrfrlslWxvcKjs2cD7THbpmtq+jf/dlrKW +UEo8o4GBMH8wDAYDVR0TAQH/BAIwADAfBgNVHREEGDAWgglhY21lLnRlc3SCCWFj +bWUudGVzdDAOBgNVHQ8BAf8EBAMCB4AwHQYDVR0OBBYEFN/1UyS0jnC3LoryMIL2 +/6cdsYBBMB8GA1UdIwQYMBaAFLUZcL/zguHlulHg5GYyYhXmVt/6MAUGAytlcANB +ALz3u7lBreHeVZ0YXrwK3SDwlhWIH/SeUQwbxQlarzR47qu3cwQQ93Y1xjtOdu+h +hOM/ig3nLGVOT6qL8IsZrQk= +-----END CERTIFICATE----- diff --git a/tests/common/acme/certs/output/acme.test.key.pem b/tests/common/acme/certs/output/acme.test.key.pem new file mode 100644 index 00000000..06195b8c --- /dev/null +++ b/tests/common/acme/certs/output/acme.test.key.pem @@ -0,0 +1,25 @@ +Public Key Info: + Public Key Algorithm: EdDSA (Ed25519) + Key Security Level: High (256 bits) + +curve: Ed25519 +private key: + 9d:25:38:89:f2:37:d7:65:41:f5:24:ba:4c:19:fb:0f + 86:c8:a3:cf:f7:08:57:69:cc:64:cf:55:2d:8e:99:3e + + +x: + 76:2a:d9:c8:eb:7e:b9:6c:95:6c:6f:70:a8:ec:d9:c0 + fb:4c:76:e9:9a:da:be:8d:ff:dd:96:b2:96:50:4a:3c + + + +Public Key PIN: + pin-sha256:NPwZitkDv4isUmdiicSsM1t1OtYoxqhdvBUnqSc4bFQ= +Public Key ID: + sha256:34fc198ad903bf88ac52676289c4ac335b753ad628c6a85dbc1527a927386c54 + sha1:dff55324b48e70b72e8af23082f6ffa71db18041 + +-----BEGIN PRIVATE KEY----- +MC4CAQAwBQYDK2VwBCIEIJ0lOInyN9dlQfUkukwZ+w+GyKPP9whXacxkz1Utjpk+ +-----END PRIVATE KEY----- diff --git a/tests/common/acme/certs/output/acme.test.template b/tests/common/acme/certs/output/acme.test.template new file mode 100644 index 00000000..320a1701 --- /dev/null +++ b/tests/common/acme/certs/output/acme.test.template @@ -0,0 +1,5 @@ +dns_name = "acme.test" +dns_name = "acme.test" +expiration_days = 123456 +organization = Vhack.eu Test Keys +encryption_key diff --git a/tests/common/acme/certs/output/ca.cert.pem b/tests/common/acme/certs/output/ca.cert.pem new file mode 100644 index 00000000..0fa9d144 --- /dev/null +++ b/tests/common/acme/certs/output/ca.cert.pem @@ -0,0 +1,10 @@ +-----BEGIN CERTIFICATE----- +MIIBYDCCARKgAwIBAgIUdhVVcf+NgElqGuutU55FUDBtFVMwBQYDK2VwMCoxCzAJ +BgNVBAYTAkVVMRswGQYDVQQKExJWaGFjay5ldSBUZXN0IEtleXMwIBcNMjUwMzAx +MTEyNjU2WhgPMjM2MzAzMDYxMTI2NTZaMCoxCzAJBgNVBAYTAkVVMRswGQYDVQQK +ExJWaGFjay5ldSBUZXN0IEtleXMwKjAFBgMrZXADIQCkO1LhHINvJjt41JD6UEc4 +ZKKUubB8lKPxSOyTkFBOgqNIMEYwDwYDVR0TAQH/BAUwAwEB/zAUBgNVHREEDTAL +gglhY21lLnRlc3QwHQYDVR0OBBYEFLUZcL/zguHlulHg5GYyYhXmVt/6MAUGAytl +cANBAFMFFy5tjuQtp5GVEN6qM50L4lteQuxfhlQqmOOfl06HV6153wJnrlKaTOYO +t0dKlSqKROMYUYeU39xDp07MLAc= +-----END CERTIFICATE----- diff --git a/tests/common/acme/certs/output/ca.key.pem b/tests/common/acme/certs/output/ca.key.pem new file mode 100644 index 00000000..64263bcb --- /dev/null +++ b/tests/common/acme/certs/output/ca.key.pem @@ -0,0 +1,25 @@ +Public Key Info: + Public Key Algorithm: EdDSA (Ed25519) + Key Security Level: High (256 bits) + +curve: Ed25519 +private key: + 82:0d:fc:f0:d6:82:89:63:e5:bc:23:78:ba:98:38:83 + 09:2d:e0:78:4c:53:92:e3:db:5b:2f:e4:39:ce:96:3d + + +x: + a4:3b:52:e1:1c:83:6f:26:3b:78:d4:90:fa:50:47:38 + 64:a2:94:b9:b0:7c:94:a3:f1:48:ec:93:90:50:4e:82 + + + +Public Key PIN: + pin-sha256:jpzYZMOHDPCeSXxfL+YUXgSPcbO9MAs8foGMP5CJiD8= +Public Key ID: + sha256:8e9cd864c3870cf09e497c5f2fe6145e048f71b3bd300b3c7e818c3f9089883f + sha1:b51970bff382e1e5ba51e0e466326215e656dffa + +-----BEGIN PRIVATE KEY----- +MC4CAQAwBQYDK2VwBCIEIIIN/PDWgolj5bwjeLqYOIMJLeB4TFOS49tbL+Q5zpY9 +-----END PRIVATE KEY----- diff --git a/tests/common/acme/certs/output/ca.template b/tests/common/acme/certs/output/ca.template new file mode 100644 index 00000000..a2295d8d --- /dev/null +++ b/tests/common/acme/certs/output/ca.template @@ -0,0 +1,5 @@ +country = EU +dns_name = "acme.test" +expiration_days = 123456 +organization = Vhack.eu Test Keys +ca diff --git a/tests/common/acme/certs/snakeoil-certs.nix b/tests/common/acme/certs/snakeoil-certs.nix new file mode 100644 index 00000000..aeb6dfce --- /dev/null +++ b/tests/common/acme/certs/snakeoil-certs.nix @@ -0,0 +1,13 @@ +let + domain = "acme.test"; +in { + inherit domain; + ca = { + cert = ./output/ca.cert.pem; + key = ./output/ca.key.pem; + }; + "${domain}" = { + cert = ./output/. + "/${domain}.cert.pem"; + key = ./output/. + "/${domain}.key.pem"; + }; +} diff --git a/tests/common/acme/client.nix b/tests/common/acme/client.nix new file mode 100644 index 00000000..2b870e89 --- /dev/null +++ b/tests/common/acme/client.nix @@ -0,0 +1,21 @@ +{ + nodes, + lib, + ... +}: let + inherit (nodes.acme.test-support.acme) caCert; + inherit (nodes.acme.test-support.acme) caDomain; +in { + security = { + acme = { + acceptTerms = true; + defaults = { + server = "https://${caDomain}/dir"; + }; + }; + + pki = { + certificateFiles = lib.mkForce [caCert]; + }; + }; +} diff --git a/tests/common/acme/default.nix b/tests/common/acme/default.nix new file mode 100644 index 00000000..c756a4f1 --- /dev/null +++ b/tests/common/acme/default.nix @@ -0,0 +1,47 @@ +{pkgs}: let + add_pebble_ca_certs = pkgs.writeShellScript "fetch-and-set-ca" '' + set -xe + + # Fetch the randomly generated ca certificate + curl https://acme.test:15000/roots/0 > /tmp/ca.crt + curl https://acme.test:15000/intermediates/0 >> /tmp/ca.crt + + # Append it to the various system stores + # The file paths are from <nixpgks>/modules/security/ca.nix + for cert_path in "ssl/certs/ca-certificates.crt" "ssl/certs/ca-bundle.crt" "pki/tls/certs/ca-bundle.crt"; do + cert_path="/etc/$cert_path" + + mv "$cert_path" "$cert_path.old" + cat "$cert_path.old" > "$cert_path" + cat /tmp/ca.crt >> "$cert_path" + done + + export NIX_SSL_CERT_FILE=/tmp/ca.crt + export SSL_CERT_FILE=/tmp/ca.crt + + # TODO + # # P11-Kit trust source. + # environment.etc."ssl/trust-source".source = "$${cacertPackage.p11kit}/etc/ssl/trust-source"; + ''; +in { + prepare = clients: extra: + # The parens are needed for the syntax highlighting to work. + ( # python + '' + # Start dependencies for the other services + acme.start() + acme.wait_for_unit("pebble.service") + name_server.start() + name_server.wait_for_unit("nsd.service") + + # Start actual test + start_all() + + with subtest("Add pebble ca key to all services"): + for node in [name_server, ${builtins.concatStringsSep "," clients}]: + node.wait_until_succeeds("curl https://acme.test:15000/roots/0") + node.succeed("${add_pebble_ca_certs}") + '' + ) + + extra; +} diff --git a/tests/common/acme/server.nix b/tests/common/acme/server.nix new file mode 100644 index 00000000..997c944a --- /dev/null +++ b/tests/common/acme/server.nix @@ -0,0 +1,91 @@ +# Add this node as acme server. +# This also needs a DNS server. +{ + config, + pkgs, + lib, + ... +}: let + testCerts = import ./certs/snakeoil-certs.nix; + inherit (testCerts) domain; + + pebbleConf.pebble = { + listenAddress = "0.0.0.0:443"; + managementListenAddress = "0.0.0.0:15000"; + + # The cert and key are used only for the Web Front End (WFE) + certificate = testCerts.${domain}.cert; + privateKey = testCerts.${domain}.key; + + httpPort = 80; + tlsPort = 443; + ocspResponderURL = "http://${domain}:4002"; + strict = true; + }; + + pebbleConfFile = pkgs.writeText "pebble.conf" (builtins.toJSON pebbleConf); +in { + options.test-support.acme = { + caDomain = lib.mkOption { + type = lib.types.str; + default = domain; + readOnly = true; + description = '' + A domain name to use with the `nodes` attribute to + identify the CA server in the `client` config. + ''; + }; + caCert = lib.mkOption { + type = lib.types.path; + readOnly = true; + default = testCerts.ca.cert; + description = '' + A certificate file to use with the `nodes` attribute to + inject the test CA certificate used in the ACME server into + {option}`security.pki.certificateFiles`. + ''; + }; + }; + + config = { + networking = { + # This has priority 140, because modules/testing/test-instrumentation.nix + # already overrides this with priority 150. + nameservers = lib.mkOverride 140 ["127.0.0.1"]; + firewall.allowedTCPPorts = [ + 80 + 443 + 15000 + 4002 + ]; + + extraHosts = '' + 127.0.0.1 ${domain} + ${config.networking.primaryIPAddress} ${domain} + ''; + }; + + systemd.services = { + pebble = { + enable = true; + description = "Pebble ACME server"; + wantedBy = ["network.target"]; + environment = { + # We're not testing lego, we're just testing our configuration. + # No need to sleep. + PEBBLE_VA_NOSLEEP = "1"; + }; + + serviceConfig = { + RuntimeDirectory = "pebble"; + WorkingDirectory = "/run/pebble"; + + # Required to bind on privileged ports. + AmbientCapabilities = ["CAP_NET_BIND_SERVICE"]; + + ExecStart = "${pkgs.pebble}/bin/pebble -config ${pebbleConfFile}"; + }; + }; + }; + }; +} diff --git a/tests/common/dns/client.nix b/tests/common/dns/client.nix new file mode 100644 index 00000000..52f32671 --- /dev/null +++ b/tests/common/dns/client.nix @@ -0,0 +1,10 @@ +{ + lib, + nodes, + ... +}: { + networking.nameservers = lib.mkForce [ + nodes.name_server.networking.primaryIPAddress + nodes.name_server.networking.primaryIPv6Address + ]; +} diff --git a/tests/common/dns/module/default.nix b/tests/common/dns/module/default.nix new file mode 100644 index 00000000..8f4ad37a --- /dev/null +++ b/tests/common/dns/module/default.nix @@ -0,0 +1,86 @@ +{ + config, + lib, + ... +}: let + cfg = config.vhack.dns; + + zones = + builtins.mapAttrs (name: value: { + data = + dns.types.zone.renderToString name value; + }) + cfg.zones; + + dns = import ./dns {inherit lib;}; + + ports = let + parsePorts = listeners: let + splitAddress = addr: lib.splitString "@" addr; + + extractPort = addr: let + split = splitAddress addr; + in + lib.toInt ( + if (builtins.length split) == 2 + then builtins.elemAt split 1 + else "53" + ); + in + builtins.map extractPort listeners; + in + lib.unique (parsePorts cfg.interfaces); +in { + options.vhack.dns = { + enable = lib.mkEnableOption "custom dns server"; + + openFirewall = lib.mkOption { + type = lib.types.bool; + default = false; + description = '' + Open the following ports: + TCP (${lib.concatStringsSep ", " (map toString ports)}) + UDP (${lib.concatStringsSep ", " (map toString ports)}) + ''; + }; + + interfaces = lib.mkOption { + type = lib.types.listOf lib.types.str; + description = '' + A list of the interfaces to bind to. To select the port add `@` to the end of the + interface. The default port is 53. + ''; + example = [ + "192.168.1.3" + "2001:db8:1::3" + ]; + }; + + zones = lib.mkOption { + type = lib.types.attrsOf dns.types.zone.zone; + description = "DNS zones"; + }; + }; + + config = lib.mkIf cfg.enable { + services.nsd = { + enable = true; + verbosity = 4; + inherit (cfg) interfaces; + inherit zones; + }; + + networking.firewall.allowedUDPPorts = lib.mkIf cfg.openFirewall ports; + networking.firewall.allowedTCPPorts = lib.mkIf cfg.openFirewall ports; + + systemd.services.nsd = { + requires = [ + "network-online.target" + ]; + after = [ + "network.target" + "network-online.target" + ]; + }; + }; +} diff --git a/tests/common/dns/module/dns/default.nix b/tests/common/dns/module/dns/default.nix new file mode 100644 index 00000000..4ce07d8f --- /dev/null +++ b/tests/common/dns/module/dns/default.nix @@ -0,0 +1,13 @@ +# +# SPDX-FileCopyrightText: 2019 Kirill Elagin <https://kir.elagin.me/> +# +# SPDX-License-Identifier: MPL-2.0 or MIT +# +{lib}: let + util = import ./util {inherit lib;}; + types = import ./types {inherit lib util;}; +in { + inherit + types + ; +} diff --git a/tests/common/dns/module/dns/types/default.nix b/tests/common/dns/module/dns/types/default.nix new file mode 100644 index 00000000..ece315fa --- /dev/null +++ b/tests/common/dns/module/dns/types/default.nix @@ -0,0 +1,16 @@ +# +# SPDX-FileCopyrightText: 2019 Kirill Elagin <https://kir.elagin.me/> +# +# SPDX-License-Identifier: MPL-2.0 or MIT +# +{ + lib, + util, +}: let + simple = {types = import ./simple.nix {inherit lib;};}; +in { + record = import ./record.nix {inherit lib util;}; + records = import ./records {inherit lib util simple;}; + + zone = import ./zone.nix {inherit lib util simple;}; +} diff --git a/tests/common/dns/module/dns/types/record.nix b/tests/common/dns/module/dns/types/record.nix new file mode 100644 index 00000000..e992bf90 --- /dev/null +++ b/tests/common/dns/module/dns/types/record.nix @@ -0,0 +1,75 @@ +# +# SPDX-FileCopyrightText: 2019 Kirill Elagin <https://kir.elagin.me/> +# SPDX-FileCopyrightText: 2021 Naïm Favier <n@monade.li> +# +# SPDX-License-Identifier: MPL-2.0 or MIT +# +{lib, ...}: let + inherit (lib) hasSuffix isString mkOption removeSuffix types; + + recordType = rsubt: let + submodule = types.submodule { + options = + { + class = mkOption { + type = types.enum ["IN"]; + default = "IN"; + example = "IN"; + description = "Resource record class. Only IN is supported"; + }; + ttl = mkOption { + type = types.nullOr types.ints.unsigned; # TODO: u32 + default = null; + example = 300; + description = "Record caching duration (in seconds)"; + }; + } + // rsubt.options; + }; + in + ( + if rsubt ? fromString + then types.either types.str + else lib.id + ) + submodule; + + # name == "@" : use unqualified domain name + writeRecord = name: rsubt: data: let + data' = + if isString data && rsubt ? fromString + then + # add default values for the record type + (recordType rsubt).merge [] [ + { + file = ""; + value = rsubt.fromString data; + } + ] + else data; + name' = let + fname = rsubt.nameFixup or (n: _: n) name data'; + in + if name == "@" + then name + else if (hasSuffix ".@" name) + then removeSuffix ".@" fname + else "${fname}."; + inherit (rsubt) rtype; + in + lib.concatStringsSep " " (with data'; + [ + name' + ] + ++ lib.optionals (ttl != null) [ + (toString ttl) + ] + ++ [ + class + rtype + (rsubt.dataToString data') + ]); +in { + inherit recordType; + inherit writeRecord; +} diff --git a/tests/common/dns/module/dns/types/records/A.nix b/tests/common/dns/module/dns/types/records/A.nix new file mode 100644 index 00000000..296943ef --- /dev/null +++ b/tests/common/dns/module/dns/types/records/A.nix @@ -0,0 +1,19 @@ +# +# SPDX-FileCopyrightText: 2019 Kirill Elagin <https://kir.elagin.me/> +# +# SPDX-License-Identifier: MPL-2.0 or MIT +# +{lib, ...}: let + inherit (lib) mkOption types; +in { + rtype = "A"; + options = { + address = mkOption { + type = types.str; + example = "26.3.0.103"; + description = "IP address of the host"; + }; + }; + dataToString = {address, ...}: address; + fromString = address: {inherit address;}; +} diff --git a/tests/common/dns/module/dns/types/records/AAAA.nix b/tests/common/dns/module/dns/types/records/AAAA.nix new file mode 100644 index 00000000..4717176a --- /dev/null +++ b/tests/common/dns/module/dns/types/records/AAAA.nix @@ -0,0 +1,19 @@ +# +# SPDX-FileCopyrightText: 2019 Kirill Elagin <https://kir.elagin.me/> +# +# SPDX-License-Identifier: MPL-2.0 or MIT +# +{lib, ...}: let + inherit (lib) mkOption types; +in { + rtype = "AAAA"; + options = { + address = mkOption { + type = types.str; + example = "4321:0:1:2:3:4:567:89ab"; + description = "IPv6 address of the host"; + }; + }; + dataToString = {address, ...}: address; + fromString = address: {inherit address;}; +} diff --git a/tests/common/dns/module/dns/types/records/CAA.nix b/tests/common/dns/module/dns/types/records/CAA.nix new file mode 100644 index 00000000..4b405107 --- /dev/null +++ b/tests/common/dns/module/dns/types/records/CAA.nix @@ -0,0 +1,42 @@ +# +# SPDX-FileCopyrightText: 2019 Kirill Elagin <https://kir.elagin.me/> +# +# SPDX-License-Identifier: MPL-2.0 or MIT +# +# RFC 8659 +{lib, ...}: let + inherit (lib) mkOption types; +in { + rtype = "CAA"; + options = { + issuerCritical = mkOption { + type = types.bool; + example = true; + description = '' + If set to '1', indicates that the corresponding property tag + MUST be understood if the semantics of the CAA record are to be + correctly interpreted by an issuer + ''; + }; + tag = mkOption { + type = types.enum ["issue" "issuewild" "iodef"]; + example = "issue"; + description = "One of the defined property tags"; + }; + value = mkOption { + type = types.str; # section 4.1.1: not limited in length + example = "ca.example.net"; + description = "Value of the property"; + }; + }; + dataToString = { + issuerCritical, + tag, + value, + ... + }: ''${ + if issuerCritical + then "128" + else "0" + } ${tag} "${value}"''; +} diff --git a/tests/common/dns/module/dns/types/records/CNAME.nix b/tests/common/dns/module/dns/types/records/CNAME.nix new file mode 100644 index 00000000..095b078c --- /dev/null +++ b/tests/common/dns/module/dns/types/records/CNAME.nix @@ -0,0 +1,27 @@ +# +# SPDX-FileCopyrightText: 2019 Kirill Elagin <https://kir.elagin.me/> +# +# SPDX-License-Identifier: MPL-2.0 or MIT +# +# RFC 1035, 3.3.1 +{ + lib, + simple, + ... +}: let + inherit (lib) mkOption; +in { + rtype = "CNAME"; + options = { + cname = mkOption { + type = simple.types.domain-name; + example = "www.test.com"; + description = '' + A <domain-name> which specifies the canonical or primary name + for the owner. The owner name is an alias. + ''; + }; + }; + dataToString = {cname, ...}: "${cname}"; + fromString = cname: {inherit cname;}; +} diff --git a/tests/common/dns/module/dns/types/records/DKIM.nix b/tests/common/dns/module/dns/types/records/DKIM.nix new file mode 100644 index 00000000..31b2f67e --- /dev/null +++ b/tests/common/dns/module/dns/types/records/DKIM.nix @@ -0,0 +1,75 @@ +# +# SPDX-FileCopyrightText: 2020 Kirill Elagin <https://kir.elagin.me/> +# +# SPDX-License-Identifier: MPL-2.0 or MIT +# +# This is a “fake” record type, not actually part of DNS. +# It gets compiled down to a TXT record. +# RFC 6376 +{ + lib, + util, + ... +}: let + inherit (lib) mkOption types; +in rec { + rtype = "TXT"; + options = { + selector = mkOption { + type = types.str; + example = "mail"; + description = "DKIM selector name"; + }; + h = mkOption { + type = types.listOf types.str; + default = []; + example = ["sha1" "sha256"]; + description = "Acceptable hash algorithms. Empty means all of them"; + apply = lib.concatStringsSep ":"; + }; + k = mkOption { + type = types.nullOr types.str; + default = "rsa"; + example = "rsa"; + description = "Key type"; + }; + n = mkOption { + type = types.str; + default = ""; + example = "Just any kind of arbitrary notes."; + description = "Notes that might be of interest to a human"; + }; + p = mkOption { + type = types.str; + example = "MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDwIRP/UC3SBsEmGqZ9ZJW3/DkMoGeLnQg1fWn7/zYtIxN2SnFCjxOCKG9v3b4jYfcTNh5ijSsq631uBItLa7od+v/RtdC2UzJ1lWT947qR+Rcac2gbto/NMqJ0fzfVjH4OuKhitdY9tf6mcwGjaNBcWToIMmPSPDdQPNUYckcQ2QIDAQAB"; + description = "Public-key data (base64)"; + }; + s = mkOption { + type = types.listOf (types.enum ["*" "email"]); + default = ["*"]; + example = ["email"]; + description = "Service Type"; + apply = lib.concatStringsSep ":"; + }; + t = mkOption { + type = types.listOf (types.enum ["y" "s"]); + default = []; + example = ["y"]; + description = "Flags"; + apply = lib.concatStringsSep ":"; + }; + }; + dataToString = data: let + items = + ["v=DKIM1"] + ++ lib.pipe data [ + (builtins.intersectAttrs options) # remove garbage list `_module` + (lib.filterAttrs (_k: v: v != null && v != "")) + (lib.filterAttrs (k: _v: k != "selector")) + (lib.mapAttrsToList (k: v: "${k}=${v}")) + ]; + result = lib.concatStringsSep "; " items + ";"; + in + util.writeCharacterString result; + nameFixup = name: self: "${self.selector}._domainkey.${name}"; +} diff --git a/tests/common/dns/module/dns/types/records/DMARC.nix b/tests/common/dns/module/dns/types/records/DMARC.nix new file mode 100644 index 00000000..0f10f2c1 --- /dev/null +++ b/tests/common/dns/module/dns/types/records/DMARC.nix @@ -0,0 +1,108 @@ +# +# SPDX-FileCopyrightText: 2020 Kirill Elagin <https://kir.elagin.me/> +# +# SPDX-License-Identifier: MPL-2.0 or MIT +# +# This is a “fake” record type, not actually part of DNS. +# It gets compiled down to a TXT record. +# RFC 7489 +{ + lib, + util, + ... +}: let + inherit (lib) mkOption types; +in rec { + rtype = "TXT"; + options = { + adkim = mkOption { + type = types.enum ["relaxed" "strict"]; + default = "relaxed"; + example = "strict"; + description = "DKIM Identifier Alignment mode"; + apply = builtins.substring 0 1; + }; + aspf = mkOption { + type = types.enum ["relaxed" "strict"]; + default = "relaxed"; + example = "strict"; + description = "SPF Identifier Alignment mode"; + apply = builtins.substring 0 1; + }; + fo = mkOption { + type = types.listOf (types.enum ["0" "1" "d" "s"]); + default = ["0"]; + example = ["0" "1" "s"]; + description = "Failure reporting options"; + apply = lib.concatStringsSep ":"; + }; + p = mkOption { + type = types.enum ["none" "quarantine" "reject"]; + example = "quarantine"; + description = "Requested Mail Receiver policy"; + }; + pct = mkOption { + type = types.ints.between 0 100; + default = 100; + example = 30; + description = "Percentage of messages to which the DMARC policy is to be applied"; + apply = builtins.toString; + }; + rf = mkOption { + type = types.listOf (types.enum ["afrf"]); + default = ["afrf"]; + example = ["afrf"]; + description = "Format to be used for message-specific failure reports"; + apply = lib.concatStringsSep ":"; + }; + ri = mkOption { + type = types.ints.unsigned; # FIXME: u32 + default = 86400; + example = 12345; + description = "Interval requested between aggregate reports"; + apply = builtins.toString; + }; + rua = mkOption { + type = types.oneOf [types.str (types.listOf types.str)]; + default = []; + example = "mailto:dmarc+rua@example.com"; + description = "Addresses to which aggregate feedback is to be sent"; + apply = val: + # FIXME: need to encode commas in URIs + if builtins.isList val + then lib.concatStringsSep "," val + else val; + }; + ruf = mkOption { + type = types.listOf types.str; + default = []; + example = ["mailto:dmarc+ruf@example.com" "mailto:another+ruf@example.com"]; + description = "Addresses to which message-specific failure information is to be reported"; + apply = val: + # FIXME: need to encode commas in URIs + if builtins.isList val + then lib.concatStringsSep "," val + else val; + }; + sp = mkOption { + type = types.nullOr (types.enum ["none" "quarantine" "reject"]); + default = null; + example = "quarantine"; + description = "Requested Mail Receiver policy for all subdomains"; + }; + }; + dataToString = data: let + # The specification could be more clear on this, but `v` and `p` MUST + # be the first two tags in the record. + items = + ["v=DMARC1; p=${data.p}"] + ++ lib.pipe data [ + (builtins.intersectAttrs options) # remove garbage list `_module` + (lib.filterAttrs (k: v: v != null && v != "" && k != "p")) + (lib.mapAttrsToList (k: v: "${k}=${v}")) + ]; + result = lib.concatStringsSep "; " items + ";"; + in + util.writeCharacterString result; + nameFixup = name: _self: "_dmarc.${name}"; +} diff --git a/tests/common/dns/module/dns/types/records/DNAME.nix b/tests/common/dns/module/dns/types/records/DNAME.nix new file mode 100644 index 00000000..042ce95c --- /dev/null +++ b/tests/common/dns/module/dns/types/records/DNAME.nix @@ -0,0 +1,15 @@ +# RFC 6672 +{lib, ...}: let + inherit (lib) dns mkOption; +in { + rtype = "DNAME"; + options = { + dname = mkOption { + type = dns.types.domain-name; + example = "www.test.com"; + description = "A <domain-name> which provides redirection from a part of the DNS name tree to another part of the DNS name tree"; + }; + }; + dataToString = {dname, ...}: "${dname}"; + fromString = dname: {inherit dname;}; +} diff --git a/tests/common/dns/module/dns/types/records/DNSKEY.nix b/tests/common/dns/module/dns/types/records/DNSKEY.nix new file mode 100644 index 00000000..86ce3a10 --- /dev/null +++ b/tests/common/dns/module/dns/types/records/DNSKEY.nix @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: 2020 Aluísio Augusto Silva Gonçalves <https://aasg.name> +# +# SPDX-License-Identifier: MPL-2.0 or MIT +# RFC 4034, 2 +{lib, ...}: let + inherit (builtins) isInt split; + inherit (lib) concatStrings flatten mkOption types; + + dnssecOptions = import ./dnssec.nix {inherit lib;}; + inherit (dnssecOptions) mkDNSSECAlgorithmOption; +in { + rtype = "DNSKEY"; + options = { + flags = mkOption { + description = "Flags pertaining to this RR."; + type = types.either types.ints.u16 (types.submodule { + options = { + zoneSigningKey = mkOption { + description = "Whether this RR holds a zone signing key (ZSK)."; + type = types.bool; + default = false; + }; + secureEntryPoint = mkOption { + type = types.bool; + description = '' + Whether this RR holds a secure entry point. + In general, this means the key is a key-signing key (KSK), as opposed to a zone-signing key. + ''; + default = false; + }; + }; + }); + apply = value: + if isInt value + then value + else + ( + if value.zoneSigningKey + then 256 + else 0 + ) + + ( + if value.secureEntryPoint + then 1 + else 0 + ); + }; + algorithm = mkDNSSECAlgorithmOption { + description = "Algorithm of the key referenced by this RR."; + }; + publicKey = mkOption { + type = types.str; + description = "Base64-encoded public key."; + apply = value: concatStrings (flatten (split "[[:space:]]" value)); + }; + }; + dataToString = { + flags, + algorithm, + publicKey, + ... + }: "${toString flags} 3 ${toString algorithm} ${publicKey}"; +} diff --git a/tests/common/dns/module/dns/types/records/DS.nix b/tests/common/dns/module/dns/types/records/DS.nix new file mode 100644 index 00000000..76fac9a3 --- /dev/null +++ b/tests/common/dns/module/dns/types/records/DS.nix @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: 2020 Aluísio Augusto Silva Gonçalves <https://aasg.name> +# +# SPDX-License-Identifier: MPL-2.0 or MIT +# RFC 4034, 5 +{lib, ...}: let + inherit (lib) mkOption types; + + dnssecOptions = import ./dnssec.nix {inherit lib;}; + inherit (dnssecOptions) mkRegisteredNumberOption mkDNSSECAlgorithmOption; + + mkDSDigestTypeOption = args: + mkRegisteredNumberOption { + registryName = "Delegation Signer (DS) Resource Record (RR) Type Digest Algorithms"; + numberType = types.ints.u8; + # These mnemonics are unofficial, unlike the DNSSEC algorithm ones. + mnemonics = { + "sha-1" = 1; + "sha-256" = 2; + "gost" = 3; + "sha-384" = 4; + }; + }; +in { + rtype = "DS"; + options = { + keyTag = mkOption { + description = "Tag computed over the DNSKEY referenced by this RR to identify it."; + type = types.ints.u16; + }; + algorithm = mkDNSSECAlgorithmOption { + description = "Algorithm of the key referenced by this RR."; + }; + digestType = mkDSDigestTypeOption { + description = "Type of the digest given in the `digest` attribute."; + }; + digest = mkOption { + description = "Digest of the DNSKEY referenced by this RR."; + type = types.strMatching "[[:xdigit:]]+"; + }; + }; + dataToString = { + keyTag, + algorithm, + digestType, + digest, + ... + }: "${toString keyTag} ${toString algorithm} ${toString digestType} ${digest}"; +} diff --git a/tests/common/dns/module/dns/types/records/HTTPS.nix b/tests/common/dns/module/dns/types/records/HTTPS.nix new file mode 100644 index 00000000..6e2ef3df --- /dev/null +++ b/tests/common/dns/module/dns/types/records/HTTPS.nix @@ -0,0 +1,5 @@ +args: +import ./SVCB.nix args +// { + rtype = "HTTPS"; +} diff --git a/tests/common/dns/module/dns/types/records/MTA-STS.nix b/tests/common/dns/module/dns/types/records/MTA-STS.nix new file mode 100644 index 00000000..030490e1 --- /dev/null +++ b/tests/common/dns/module/dns/types/records/MTA-STS.nix @@ -0,0 +1,42 @@ +# +# SPDX-FileCopyrightText: 2025 Benedikt Peetz <benedikt.peetz@b-peetz.de> +# +# SPDX-License-Identifier: MPL-2.0 or MIT +# +# This is a “fake” record type, not actually part of DNS. +# It gets compiled down to a TXT record. +# RFC 8461 +{ + lib, + util, + ... +}: let + inherit (lib) mkOption types; +in rec { + rtype = "TXT"; + options = { + id = mkOption { + type = types.str; + example = "20160831085700Z"; + description = '' + A short string used to track policy updates. This string MUST + uniquely identify a given instance of a policy, such that senders + can determine when the policy has been updated by comparing to the + "id" of a previously seen policy. There is no implied ordering of + "id" fields between revisions. + ''; + }; + }; + dataToString = data: let + items = + ["v=STSv1"] + ++ lib.pipe data [ + (builtins.intersectAttrs options) # remove garbage list `_module` + (lib.filterAttrs (k: v: v != null && v != "")) + (lib.mapAttrsToList (k: v: "${k}=${v}")) + ]; + result = lib.concatStringsSep "; " items + ";"; + in + util.writeCharacterString result; + nameFixup = name: _self: "_mta-sts.${name}"; +} diff --git a/tests/common/dns/module/dns/types/records/MX.nix b/tests/common/dns/module/dns/types/records/MX.nix new file mode 100644 index 00000000..c25b89cf --- /dev/null +++ b/tests/common/dns/module/dns/types/records/MX.nix @@ -0,0 +1,32 @@ +# +# SPDX-FileCopyrightText: 2019 Kirill Elagin <https://kir.elagin.me/> +# +# SPDX-License-Identifier: MPL-2.0 or MIT +# +# RFC 1035, 3.3.9 +{ + lib, + simple, + ... +}: let + inherit (lib) mkOption types; +in { + rtype = "MX"; + options = { + preference = mkOption { + type = types.ints.u16; + example = 10; + description = "The preference given to this RR among others at the same owner. Lower values are preferred"; + }; + exchange = mkOption { + type = simple.types.domain-name; + example = "smtp.example.com."; + description = "A <domain-name> which specifies a host willing to act as a mail exchange for the owner name"; + }; + }; + dataToString = { + preference, + exchange, + ... + }: "${toString preference} ${exchange}"; +} diff --git a/tests/common/dns/module/dns/types/records/NS.nix b/tests/common/dns/module/dns/types/records/NS.nix new file mode 100644 index 00000000..ea60a911 --- /dev/null +++ b/tests/common/dns/module/dns/types/records/NS.nix @@ -0,0 +1,24 @@ +# +# SPDX-FileCopyrightText: 2019 Kirill Elagin <https://kir.elagin.me/> +# +# SPDX-License-Identifier: MPL-2.0 or MIT +# +# RFC 1035, 3.3.11 +{ + lib, + simple, + ... +}: let + inherit (lib) mkOption; +in { + rtype = "NS"; + options = { + nsdname = mkOption { + type = simple.types.domain-name; + example = "ns2.example.com"; + description = "A <domain-name> which specifies a host which should be authoritative for the specified class and domain"; + }; + }; + dataToString = {nsdname, ...}: "${nsdname}"; + fromString = nsdname: {inherit nsdname;}; +} diff --git a/tests/common/dns/module/dns/types/records/OPENPGPKEY.nix b/tests/common/dns/module/dns/types/records/OPENPGPKEY.nix new file mode 100644 index 00000000..1f39cb93 --- /dev/null +++ b/tests/common/dns/module/dns/types/records/OPENPGPKEY.nix @@ -0,0 +1,18 @@ +# RFC7929 +{ + lib, + util, + ... +}: let + inherit (lib) mkOption types; +in { + rtype = "OPENPGPKEY"; + options = { + data = mkOption { + type = types.str; + }; + }; + + dataToString = {data, ...}: util.writeCharacterString data; + fromString = data: {inherit data;}; +} diff --git a/tests/common/dns/module/dns/types/records/PTR.nix b/tests/common/dns/module/dns/types/records/PTR.nix new file mode 100644 index 00000000..075f82ee --- /dev/null +++ b/tests/common/dns/module/dns/types/records/PTR.nix @@ -0,0 +1,92 @@ +# +# SPDX-FileCopyrightText: 2019 Kirill Elagin <https://kir.elagin.me/> +# +# SPDX-License-Identifier: MPL-2.0 or MIT +# +# RFC 1035, 3.3.12 +{ + lib, + simple, + ... +}: let + inherit (lib) mkOption; + + inherit (lib.strings) stringToCharacters splitString; + + reverseIpv4 = input: + builtins.concatStringsSep "." (lib.lists.reverseList (splitString "." + input)); + + reverseIpv6 = input: let + split = splitString ":" input; + elementLength = builtins.length split; + + reverseString = string: + builtins.concatStringsSep "" (lib.lists.reverseList + (stringToCharacters string)); + in + reverseString (builtins.concatStringsSep "." (stringToCharacters (builtins.concatStringsSep + "" (builtins.map ( + part: let + c = stringToCharacters part; + in + if builtins.length c == 4 + then + # valid part + part + else if builtins.length c < 4 && builtins.length c > 0 + then + # leading zeros were elided + (builtins.concatStringsSep "" ( + builtins.map builtins.toString ( + builtins.genList (_: 0) (4 - (builtins.length c)) + ) + )) + + part + else if builtins.length c == 0 + then + # Multiple full blocks were elided. Only one of these can be in an + # IPv6 address, as such we can simply add (8 - (elementLength - 1)) `0000` + # blocks. We need to substract one from `elementLength` because + # this empty part is included in the `elementLength`. + builtins.concatStringsSep "" (builtins.genList (_: "0000") (8 - (elementLength - 1))) + else builtins.throw "Impossible" + ) + split)))); +in { + rtype = "PTR"; + options = { + name = mkOption { + type = simple.types.domain-name; + example = "mail2.server.com"; + description = "The <domain-name> which is defined by the IP."; + }; + ip = { + v4 = mkOption { + type = lib.types.nullOr lib.types.str; + example = "192.168.1.4"; + description = "The IPv4 address of the host."; + default = null; + apply = v: + if v != null + then reverseIpv4 v + else v; + }; + v6 = mkOption { + type = lib.types.nullOr lib.types.str; + example = "192.168.1.4"; + description = "The IPv6 address of the host."; + default = null; + apply = v: + if v != null + then reverseIpv6 v + else v; + }; + }; + }; + dataToString = {name, ...}: "${name}."; + nameFixup = name: self: + if self.ip.v6 == null + then "${self.ip.v4}.in-addr.arpa" + else "${self.ip.v6}.ip6.arpa"; +} diff --git a/tests/common/dns/module/dns/types/records/SOA.nix b/tests/common/dns/module/dns/types/records/SOA.nix new file mode 100644 index 00000000..db7436e9 --- /dev/null +++ b/tests/common/dns/module/dns/types/records/SOA.nix @@ -0,0 +1,65 @@ +# +# SPDX-FileCopyrightText: 2019 Kirill Elagin <https://kir.elagin.me/> +# +# SPDX-License-Identifier: MPL-2.0 or MIT +# +# RFC 1035, 3.3.13 +{ + lib, + simple, + ... +}: let + inherit (lib) concatStringsSep removeSuffix replaceStrings; + inherit (lib) mkOption types; +in { + rtype = "SOA"; + options = { + nameServer = mkOption { + type = simple.types.domain-name; + example = "ns1.example.com"; + description = "The <domain-name> of the name server that was the original or primary source of data for this zone. Don't forget the dot at the end!"; + }; + adminEmail = mkOption { + type = simple.types.domain-name; + example = "admin@example.com"; + description = "An email address of the person responsible for this zone. (Note: in traditional zone files you are supposed to put a dot instead of `@` in your address; you can use `@` with this module and it is recommended to do so. Also don't put the dot at the end!)"; + apply = s: replaceStrings ["@"] ["."] (removeSuffix "." s); + }; + serial = mkOption { + type = types.ints.unsigned; # TODO: u32 + example = 20; + description = "Version number of the original copy of the zone"; + }; + refresh = mkOption { + type = types.ints.unsigned; # TODO: u32 + default = 24 * 60 * 60; + example = 7200; + description = "Time interval before the zone should be refreshed"; + }; + retry = mkOption { + type = types.ints.unsigned; # TODO: u32 + default = 10 * 60; + example = 600; + description = "Time interval that should elapse before a failed refresh should be retried"; + }; + expire = mkOption { + type = types.ints.unsigned; # TODO: u32 + default = 10 * 24 * 60 * 60; + example = 3600000; + description = "Time value that specifies the upper limit on the time interval that can elapse before the zone is no longer authoritative"; + }; + minimum = mkOption { + type = types.ints.unsigned; # TODO: u32 + default = 60; + example = 60; + description = "Minimum TTL field that should be exported with any RR from this zone"; + }; + }; + dataToString = data @ { + nameServer, + adminEmail, + ... + }: let + numbers = map toString (with data; [serial refresh retry expire minimum]); + in "${nameServer} ${adminEmail}. (${concatStringsSep " " numbers})"; +} diff --git a/tests/common/dns/module/dns/types/records/SRV.nix b/tests/common/dns/module/dns/types/records/SRV.nix new file mode 100644 index 00000000..5f558edd --- /dev/null +++ b/tests/common/dns/module/dns/types/records/SRV.nix @@ -0,0 +1,51 @@ +# +# SPDX-FileCopyrightText: 2019 Kirill Elagin <https://kir.elagin.me/> +# +# SPDX-License-Identifier: MPL-2.0 or MIT +# +# RFC 2782 +{ + lib, + simple, + ... +}: let + inherit (lib) mkOption types; +in { + rtype = "SRV"; + options = { + service = mkOption { + type = types.str; + example = "foobar"; + description = "The symbolic name of the desired service. Do not add the underscore!"; + }; + proto = mkOption { + type = types.str; + example = "tcp"; + description = "The symbolic name of the desired protocol. Do not add the underscore!"; + }; + priority = mkOption { + type = types.ints.u16; + default = 0; + example = 0; + description = "The priority of this target host"; + }; + weight = mkOption { + type = types.ints.u16; + default = 100; + example = 20; + description = "The weight field specifies a relative weight for entries with the same priority. Larger weights SHOULD be given a proportionately higher probability of being selected"; + }; + port = mkOption { + type = types.ints.u16; + example = 9; + description = "The port on this target host of this service"; + }; + target = mkOption { + type = simple.types.domain-name; + example = ""; + description = "The domain name of the target host"; + }; + }; + dataToString = data: with data; "${toString priority} ${toString weight} ${toString port} ${target}"; + nameFixup = name: self: "_${self.service}._${self.proto}.${name}"; +} diff --git a/tests/common/dns/module/dns/types/records/SSHFP.nix b/tests/common/dns/module/dns/types/records/SSHFP.nix new file mode 100644 index 00000000..14098603 --- /dev/null +++ b/tests/common/dns/module/dns/types/records/SSHFP.nix @@ -0,0 +1,39 @@ +# RFC 4255 +{lib, ...}: let + inherit (lib) mkOption types; + inherit (builtins) attrNames; + algorithm = { + "rsa" = 1; + "dsa" = 2; + "ecdsa" = 3; # RFC 6594 + "ed25519" = 4; # RFC 7479 / RFC 8709 + "ed448" = 6; # RFC 8709 + }; + mode = { + "sha1" = 1; + "sha256" = 2; # RFC 6594 + }; +in { + rtype = "SSHFP"; + options = { + algorithm = mkOption { + example = "ed25519"; + type = types.enum (attrNames algorithm); + apply = value: algorithm.${value}; + }; + fingerprintType = mkOption { + example = "sha256"; + type = types.enum (attrNames mode); + apply = value: mode.${value}; + }; + fingerprint = mkOption { + type = types.str; + }; + }; + dataToString = { + algorithm, + fingerprintType, + fingerprint, + ... + }: "${toString algorithm} ${toString fingerprintType} ${fingerprint}"; +} diff --git a/tests/common/dns/module/dns/types/records/SVCB.nix b/tests/common/dns/module/dns/types/records/SVCB.nix new file mode 100644 index 00000000..62cbc3da --- /dev/null +++ b/tests/common/dns/module/dns/types/records/SVCB.nix @@ -0,0 +1,100 @@ +# rfc9460 +{lib, ...}: let + inherit + (lib) + concatStringsSep + filter + isInt + isList + mapAttrsToList + mkOption + types + ; + + mkSvcParams = params: + concatStringsSep " " ( + filter (s: s != "") ( + mapAttrsToList ( + name: value: + if value + then name + else if isList value + then "${name}=${concatStringsSep "," value}" + else if isInt value + then "${name}=${builtins.toString value}" + else "" + ) + params + ) + ); +in { + rtype = "SVCB"; + options = { + svcPriority = mkOption { + example = 1; + type = types.ints.u16; + }; + targetName = mkOption { + example = "."; + type = types.str; + }; + mandatory = mkOption { + example = ["ipv4hint"]; + default = null; + type = types.nullOr (types.nonEmptyListOf types.str); + }; + alpn = mkOption { + example = ["h2"]; + default = null; + type = types.nullOr (types.nonEmptyListOf types.str); + }; + no-default-alpn = mkOption { + example = true; + default = false; + type = types.bool; + }; + port = mkOption { + example = 443; + default = null; + type = types.nullOr types.port; + }; + ipv4hint = mkOption { + example = ["127.0.0.1"]; + default = null; + type = types.nullOr (types.nonEmptyListOf types.str); + }; + ipv6hint = mkOption { + example = ["::1"]; + default = null; + type = types.nullOr (types.nonEmptyListOf types.str); + }; + ech = mkOption { + type = types.nullOr types.str; + default = null; + }; + }; + dataToString = { + svcPriority, + targetName, + mandatory ? null, + alpn ? null, + no-default-alpn ? null, + port ? null, + ipv4hint ? null, + ipv6hint ? null, + ech ? null, + ... + }: "${toString svcPriority} ${targetName} ${ + mkSvcParams { + inherit + alpn + ech + ipv4hint + ipv6hint + mandatory + no-default-alpn + port + ; + } + }"; +} diff --git a/tests/common/dns/module/dns/types/records/TLSA.nix b/tests/common/dns/module/dns/types/records/TLSA.nix new file mode 100644 index 00000000..d92a29b0 --- /dev/null +++ b/tests/common/dns/module/dns/types/records/TLSA.nix @@ -0,0 +1,50 @@ +# RFC 6698 +{lib, ...}: let + inherit (lib) mkOption types; + inherit (builtins) attrNames; + + certUsage = { + "pkix-ta" = 0; + "pkix-ee" = 1; + "dane-ta" = 2; + "dane-ee" = 3; + }; + selectors = { + "cert" = 0; + "spki" = 1; + }; + match = { + "exact" = 0; + "sha256" = 1; + "sha512" = 2; + }; +in { + rtype = "TLSA"; + options = { + certUsage = mkOption { + example = "dane-ee"; + type = types.enum (attrNames certUsage); + apply = value: certUsage.${value}; + }; + selector = mkOption { + example = "spki"; + type = types.enum (attrNames selectors); + apply = value: selectors.${value}; + }; + matchingType = mkOption { + example = "sha256"; + type = types.enum (attrNames match); + apply = value: match.${value}; + }; + certificate = mkOption { + type = types.str; + }; + }; + dataToString = { + certUsage, + selector, + matchingType, + certificate, + ... + }: "${toString certUsage} ${toString selector} ${toString matchingType} ${certificate}"; +} diff --git a/tests/common/dns/module/dns/types/records/TXT.nix b/tests/common/dns/module/dns/types/records/TXT.nix new file mode 100644 index 00000000..d605ce82 --- /dev/null +++ b/tests/common/dns/module/dns/types/records/TXT.nix @@ -0,0 +1,24 @@ +# +# SPDX-FileCopyrightText: 2019 Kirill Elagin <https://kir.elagin.me/> +# +# SPDX-License-Identifier: MPL-2.0 or MIT +# +# RFC 1035, 3.3.14 +{ + lib, + util, + ... +}: let + inherit (lib) mkOption types; +in { + rtype = "TXT"; + options = { + data = mkOption { + type = types.str; + example = "favorite drink=orange juice"; + description = "Arbitrary information"; + }; + }; + dataToString = {data, ...}: util.writeCharacterString data; + fromString = data: {inherit data;}; +} diff --git a/tests/common/dns/module/dns/types/records/default.nix b/tests/common/dns/module/dns/types/records/default.nix new file mode 100644 index 00000000..76a86cdd --- /dev/null +++ b/tests/common/dns/module/dns/types/records/default.nix @@ -0,0 +1,43 @@ +# +# SPDX-FileCopyrightText: 2019 Kirill Elagin <https://kir.elagin.me/> +# +# SPDX-License-Identifier: MPL-2.0 or MIT +# +{ + lib, + util, + simple, +}: let + inherit (lib.attrsets) genAttrs; + + types = [ + "A" + "AAAA" + "CAA" + "CNAME" + "DNAME" + "MX" + "NS" + "SOA" + "SRV" + "TXT" + "PTR" + + # DNSSEC types + "DNSKEY" + "DS" + + # DANE types + "SSHFP" + "TLSA" + "OPENPGPKEY" + "SVCB" + "HTTPS" + + # Pseudo types + "DKIM" + "DMARC" + "MTA-STS" + ]; +in + genAttrs types (t: import (./. + "/${t}.nix") {inherit lib simple util;}) diff --git a/tests/common/dns/module/dns/types/records/dnssec.nix b/tests/common/dns/module/dns/types/records/dnssec.nix new file mode 100644 index 00000000..648f6762 --- /dev/null +++ b/tests/common/dns/module/dns/types/records/dnssec.nix @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: 2020 Aluísio Augusto Silva Gonçalves <https://aasg.name> +# +# SPDX-License-Identifier: MPL-2.0 or MIT +{lib}: let + inherit (builtins) attrNames isInt removeAttrs; + inherit (lib) mkOption types; +in rec { + mkRegisteredNumberOption = { + registryName, + numberType, + mnemonics, + } @ args: + mkOption + { + type = + types.either numberType (types.enum (attrNames mnemonics)) + // { + name = "registeredNumber"; + description = "number in IANA registry '${registryName}'"; + }; + apply = value: + if isInt value + then value + else mnemonics.${value}; + } + // removeAttrs args ["registryName" "numberType" "mnemonics"]; + + mkDNSSECAlgorithmOption = args: + mkRegisteredNumberOption { + registryName = "Domain Name System Security (DNSSEC) Algorithm Numbers"; + numberType = types.ints.u8; + mnemonics = { + "dsa" = 3; + "rsasha1" = 5; + "dsa-nsec3-sha1" = 6; + "rsasha1-nsec3-sha1" = 7; + "rsasha256" = 8; + "rsasha512" = 10; + "ecc-gost" = 12; + "ecdsap256sha256" = 13; + "ecdsap384sha384" = 14; + "ed25519" = 15; + "ed448" = 16; + "privatedns" = 253; + "privateoid" = 254; + }; + }; +} diff --git a/tests/common/dns/module/dns/types/simple.nix b/tests/common/dns/module/dns/types/simple.nix new file mode 100644 index 00000000..fece2c9b --- /dev/null +++ b/tests/common/dns/module/dns/types/simple.nix @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: 2021 Kirill Elagin <https://kir.elagin.me/> +# +# SPDX-License-Identifier: MPL-2.0 or MIT +{lib}: let + inherit (builtins) stringLength; +in { + # RFC 1035, 3.1 + domain-name = lib.types.addCheck lib.types.str (s: stringLength s <= 255); +} diff --git a/tests/common/dns/module/dns/types/zone.nix b/tests/common/dns/module/dns/types/zone.nix new file mode 100644 index 00000000..44ccb150 --- /dev/null +++ b/tests/common/dns/module/dns/types/zone.nix @@ -0,0 +1,119 @@ +# +# SPDX-FileCopyrightText: 2019 Kirill Elagin <https://kir.elagin.me/> +# SPDX-FileCopyrightText: 2021 Naïm Favier <n@monade.li> +# +# SPDX-License-Identifier: MPL-2.0 or MIT +# +{ + lib, + util, + simple, +}: let + inherit (builtins) filter removeAttrs; + inherit + (lib) + concatMapStringsSep + concatStringsSep + mapAttrs + mapAttrsToList + optionalString + ; + inherit (lib) mkOption literalExample types; + + inherit (import ./record.nix {inherit lib;}) recordType writeRecord; + + rsubtypes = import ./records {inherit lib util simple;}; + rsubtypes' = removeAttrs rsubtypes ["SOA"]; + + subzoneOptions = + { + subdomains = mkOption { + type = types.attrsOf subzone; + default = {}; + example = { + www = { + A = [{address = "1.1.1.1";}]; + }; + staging = { + A = [{address = "1.0.0.1";}]; + }; + }; + description = "Records for subdomains of the domain"; + }; + } + // mapAttrs (n: t: + mkOption { + type = types.listOf (recordType t); + default = []; + # example = [ t.example ]; # TODO: any way to auto-generate an example for submodule? + description = "List of ${n} records for this zone/subzone"; + }) + rsubtypes'; + + subzone = types.submodule { + options = subzoneOptions; + }; + + writeSubzone = name: zone: let + groupToString = pseudo: subt: + concatMapStringsSep "\n" (writeRecord name subt) zone."${pseudo}"; + groups = mapAttrsToList groupToString rsubtypes'; + groups' = filter (s: s != "") groups; + + writeSubzone' = subname: writeSubzone "${subname}.${name}"; + sub = concatStringsSep "\n\n" (mapAttrsToList writeSubzone' zone.subdomains); + in + concatStringsSep "\n\n" groups' + + optionalString (sub != "") ("\n\n" + sub); + zone = types.submodule ({name, ...}: { + options = + { + useOrigin = mkOption { + type = types.bool; + default = false; + description = "Wether to use $ORIGIN and unqualified name or fqdn when exporting the zone."; + }; + + TTL = mkOption { + type = types.ints.unsigned; + default = 24 * 60 * 60; + example = literalExample "60 * 60"; + description = "Default record caching duration. Sets the $TTL variable"; + }; + SOA = mkOption rec { + type = recordType rsubtypes.SOA; + example = + { + ttl = 24 * 60 * 60; + } + // type.example; + description = "SOA record"; + }; + } + // subzoneOptions; + }); + renderToString = name: { + useOrigin, + TTL, + SOA, + ... + } @ zone: + if useOrigin + then '' + $ORIGIN ${name}. + $TTL ${toString TTL} + + ${writeRecord "@" rsubtypes.SOA SOA} + + ${writeSubzone "@" zone} + '' + else '' + $TTL ${toString TTL} + + ${writeRecord name rsubtypes.SOA SOA} + + ${writeSubzone name zone} + ''; +in { + inherit zone subzone renderToString; +} diff --git a/tests/common/dns/module/dns/util/default.nix b/tests/common/dns/module/dns/util/default.nix new file mode 100644 index 00000000..59e661d7 --- /dev/null +++ b/tests/common/dns/module/dns/util/default.nix @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: 2021 Kirill Elagin <https://kir.elagin.me/> +# +# SPDX-License-Identifier: MPL-2.0 or MIT +{lib}: let + inherit + (builtins) + concatStringsSep + genList + stringLength + substring + ; + inherit + (lib.strings) + concatMapStrings + concatMapStringsSep + fixedWidthString + splitString + stringToCharacters + ; + inherit (lib.lists) filter reverseList; + + /* + Split a string into byte chunks, such that each output String is less then or equal to + `n` bytes. + + # Type + + splitInGroupsOf :: Integer -> String -> [String] + + # Arguments + + n + : The number of bytes to put into each String. + + s + : The String to split. + */ + splitInGroupsOf = n: s: let + groupCount = (stringLength s - 1) / n + 1; + in + genList (i: substring (i * n) n s) groupCount; + + # : str -> str + # Prepares a Nix string to be written to a zone file as a character-string + # literal: breaks it into chunks of 255 (per RFC 1035, 3.3) and encloses + # each chunk in quotation marks. + writeCharacterString = s: + if stringLength s <= 255 + then ''"${s}"'' + else concatMapStringsSep " " (x: ''"${x}"'') (splitInGroupsOf 255 s); + + # : str -> str, with length 4 (zeros are padded to the left) + align4Bytes = fixedWidthString 4 "0"; + + # : int -> str -> str + # Expands "" to 4n zeros and aligns the rest on 4 bytes + align4BytesOrExpand = n: v: + if v == "" + then (fixedWidthString (4 * n) "0" "") + else align4Bytes v; + + # : str -> [ str ] + # Returns the record of the ipv6 as a list + mkRecordAux = v6: let + splitted = splitString ":" v6; + n = 8 - builtins.length (filter (x: x != "") splitted); + in + stringToCharacters (concatMapStrings (align4BytesOrExpand n) splitted); + + # : str -> str + # Returns the reversed record of the ipv6 + mkReverseRecord = v6: + concatStringsSep "." (reverseList (mkRecordAux v6)) + ".ip6.arpa"; +in { + inherit writeCharacterString mkReverseRecord; +} diff --git a/tests/common/dns/server.nix b/tests/common/dns/server.nix new file mode 100644 index 00000000..1fb5dadb --- /dev/null +++ b/tests/common/dns/server.nix @@ -0,0 +1,43 @@ +{ + lib, + nodes, + ... +}: { + imports = [ + ./module + ]; + + networking.nameservers = lib.mkForce [ + nodes.name_server.networking.primaryIPAddress + nodes.name_server.networking.primaryIPv6Address + ]; + + vhack = { + dns = { + enable = true; + openFirewall = true; + interfaces = [ + nodes.name_server.networking.primaryIPAddress + nodes.name_server.networking.primaryIPv6Address + ]; + + zones = { + "acme.test" = { + SOA = { + nameServer = "ns"; + adminEmail = "admin@server.com"; + serial = 2025012301; + }; + useOrigin = false; + + A = [ + nodes.acme.networking.primaryIPAddress + ]; + AAAA = [ + nodes.acme.networking.primaryIPv6Address + ]; + }; + }; + }; + }; +} |
