diff options
Diffstat (limited to 'crates/atuin-client/src')
| -rw-r--r-- | crates/atuin-client/src/database.rs | 192 | ||||
| -rw-r--r-- | crates/atuin-client/src/hub.rs | 20 | ||||
| -rw-r--r-- | crates/atuin-client/src/settings.rs | 221 |
3 files changed, 431 insertions, 2 deletions
diff --git a/crates/atuin-client/src/database.rs b/crates/atuin-client/src/database.rs index 5f292bec..7c63368d 100644 --- a/crates/atuin-client/src/database.rs +++ b/crates/atuin-client/src/database.rs @@ -138,9 +138,13 @@ pub trait Database: Send + Sync + 'static { async fn all_with_count(&self) -> Result<Vec<(History, i32)>>; + fn all_paged(&self, page_size: usize, include_deleted: bool, unique: bool) -> Paged; + async fn stats(&self, h: &History) -> Result<HistoryStats>; async fn get_dups(&self, before: i64, dupkeep: u32) -> Result<Vec<History>>; + + fn clone_boxed(&self) -> Box<dyn Database + 'static>; } // Intended for use on a developer machine and not a sync server. @@ -650,6 +654,10 @@ impl Database for Sqlite { Ok(res) } + fn all_paged(&self, page_size: usize, include_deleted: bool, unique: bool) -> Paged { + Paged::new(Box::new(self.clone()), page_size, include_deleted, unique) + } + // deleted_at doesn't mean the actual time that the user deleted it, // but the time that the system marks it as deleted async fn delete(&self, mut h: History) -> Result<()> { @@ -814,6 +822,70 @@ impl Database for Sqlite { Ok(res) } + + fn clone_boxed(&self) -> Box<dyn Database + 'static> { + Box::new(self.clone()) + } +} + +pub struct Paged { + database: Box<dyn Database + 'static>, + page_size: usize, + last_id: Option<String>, + include_deleted: bool, + unique: bool, +} + +impl Paged { + pub fn new( + database: Box<dyn Database + 'static>, + page_size: usize, + include_deleted: bool, + unique: bool, + ) -> Self { + Self { + database, + page_size, + last_id: None, + include_deleted, + unique, + } + } + + pub async fn next(&mut self) -> Result<Option<Vec<History>>> { + let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted()); + + query.field("*").order_desc("id"); + + if !self.include_deleted { + query.and_where_is_null("deleted_at"); + } + + if self.unique { + // We want to deduplicate on command, but the user can search via cwd, hostname, and session. + // Without those fields, filter modes won't work right. With those fields, we get duplicates. + // This must be handled upstream. + query + .group_by("command, cwd, hostname, session") + .having("max(timestamp)"); + } + + query.limit(self.page_size); + + if let Some(last_id) = &self.last_id { + query.and_where_lt("id", quote(last_id)); + } + + let query = query.sql().expect("bug in list query. please report"); + let res = self.database.query_history(&query).await?; + + if res.is_empty() { + Ok(None) + } else { + self.last_id = Some(res.last().unwrap().id.0.clone()); + Ok(Some(res)) + } + } } trait SqlBuilderExt { @@ -1166,6 +1238,126 @@ mod test { } #[tokio::test(flavor = "multi_thread")] + async fn test_paged_basic() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + + // Add 5 history items + for i in 0..5 { + new_history_item(&mut db, &format!("command{}", i)) + .await + .unwrap(); + } + + // Create a paged iterator with page_size of 2 + let mut paged = db.all_paged(2, false, false); + + // First page should have 2 items + let page1 = paged.next().await.unwrap(); + assert!(page1.is_some()); + assert_eq!(page1.unwrap().len(), 2); + + // Second page should have 2 items + let page2 = paged.next().await.unwrap(); + assert!(page2.is_some()); + assert_eq!(page2.unwrap().len(), 2); + + // Third page should have 1 item + let page3 = paged.next().await.unwrap(); + assert!(page3.is_some()); + assert_eq!(page3.unwrap().len(), 1); + + // Fourth page should be None (exhausted) + let page4 = paged.next().await.unwrap(); + assert!(page4.is_none()); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_paged_empty() { + let db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + + // Create a paged iterator on empty database + let mut paged = db.all_paged(10, false, false); + + // Should return None immediately + let page = paged.next().await.unwrap(); + assert!(page.is_none()); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_paged_unique() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + + // Add duplicate commands + new_history_item(&mut db, "duplicate").await.unwrap(); + new_history_item(&mut db, "duplicate").await.unwrap(); + new_history_item(&mut db, "unique1").await.unwrap(); + new_history_item(&mut db, "unique2").await.unwrap(); + + // Without unique flag - should get all 4 + let mut paged = db.all_paged(10, false, false); + let page = paged.next().await.unwrap().unwrap(); + assert_eq!(page.len(), 4); + + // With unique flag - should get 3 (duplicates collapsed) + let mut paged_unique = db.all_paged(10, false, true); + let page_unique = paged_unique.next().await.unwrap().unwrap(); + assert_eq!(page_unique.len(), 3); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_paged_include_deleted() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + + // Add items + new_history_item(&mut db, "keep1").await.unwrap(); + new_history_item(&mut db, "keep2").await.unwrap(); + new_history_item(&mut db, "delete_me").await.unwrap(); + + // Delete one item + let all = db + .list( + &[], + &Context { + hostname: "".to_string(), + session: "".to_string(), + cwd: "".to_string(), + host_id: "".to_string(), + git_root: None, + }, + None, + false, + false, + ) + .await + .unwrap(); + + let to_delete = all + .iter() + .find(|h| h.command == "delete_me") + .unwrap() + .clone(); + db.delete(to_delete).await.unwrap(); + + // Without include_deleted - should get 2 + let mut paged = db.all_paged(10, false, false); + let page = paged.next().await.unwrap().unwrap(); + assert_eq!(page.len(), 2); + + // With include_deleted - should get 3 + let mut paged_deleted = db.all_paged(10, true, false); + let page_deleted = paged_deleted.next().await.unwrap().unwrap(); + assert_eq!(page_deleted.len(), 3); + } + + #[tokio::test(flavor = "multi_thread")] async fn test_search_bench_dupes() { let context = Context { hostname: "test:host".to_string(), diff --git a/crates/atuin-client/src/hub.rs b/crates/atuin-client/src/hub.rs index 5b34574b..b94c69ea 100644 --- a/crates/atuin-client/src/hub.rs +++ b/crates/atuin-client/src/hub.rs @@ -58,10 +58,14 @@ impl HubAuthSession { /// /// Returns a session containing the code and auth URL that the user should visit. pub async fn start(settings: &Settings) -> Result<Self> { + debug!("Starting Hub authentication process..."); + let code_response = request_code(&settings.hub_address) .await .context("Failed to request authentication code from Hub")?; + debug!("Received code from Hub"); + let code = code_response.code; let auth_url = format!("{}/auth/cli?code={}", settings.hub_address, code); @@ -79,8 +83,10 @@ impl HubAuthSession { match verify_code(&self.hub_address, &self.code).await { Ok(response) => { if let Some(token) = response.token { + debug!("Authentication complete, received token"); Ok(HubAuthStatus::Complete(token)) } else if let Some(error) = response.error { + error!("Authentication failed: {}", error); Ok(HubAuthStatus::Failed(error)) } else { Ok(HubAuthStatus::Pending) @@ -105,8 +111,11 @@ impl HubAuthSession { ) -> Result<String> { let start = std::time::Instant::now(); + debug!("Polling for Hub authentication completion..."); + loop { if start.elapsed() > timeout { + warn!("Authentication loop exited due to timeout"); bail!("Authentication timed out. Please try again."); } @@ -181,17 +190,21 @@ async fn handle_resp_error(resp: reqwest::Response) -> Result<reqwest::Response> let status = resp.status(); if status == StatusCode::SERVICE_UNAVAILABLE { + error!("Service unavailable: check https://status.atuin.sh"); bail!("Service unavailable: check https://status.atuin.sh"); } if status == StatusCode::TOO_MANY_REQUESTS { + error!("Rate limited; please wait before trying again"); bail!("Rate limited; please wait before trying again"); } if !status.is_success() { if let Ok(error) = resp.json::<ErrorResponse>().await { + error!("Hub error: {} - {}", status, error.reason); bail!("Hub error: {} - {}", status, error.reason); } + error!("Hub request failed with status: {}", status); bail!("Hub request failed with status: {}", status); } @@ -204,6 +217,8 @@ async fn request_code(address: &str) -> Result<CliCodeResponse> { let url = make_url(address, "/auth/cli/code")?; let client = reqwest::Client::new(); + debug!("Requesting code from Hub at {url}"); + let resp = client .post(&url) .header(USER_AGENT, APP_USER_AGENT) @@ -219,9 +234,12 @@ async fn request_code(address: &str) -> Result<CliCodeResponse> { /// Poll to verify the CLI auth code and get the session token async fn verify_code(address: &str, code: &str) -> Result<CliVerifyResponse> { ensure_crypto_provider(); - let url = make_url(address, &format!("/auth/cli/verify?code={}", code))?; + let base = make_url(address, "/auth/cli/verify")?; + let url = format!("{base}?code={code}"); let client = reqwest::Client::new(); + debug!("Verifying code with Hub at {base}?code=******"); + let resp = client .post(&url) .header(USER_AGENT, APP_USER_AGENT) diff --git a/crates/atuin-client/src/settings.rs b/crates/atuin-client/src/settings.rs index a15ce461..8e874832 100644 --- a/crates/atuin-client/src/settings.rs +++ b/crates/atuin-client/src/settings.rs @@ -42,6 +42,10 @@ pub enum SearchMode { #[serde(rename = "skim")] Skim, + + #[serde(rename = "daemon-fuzzy")] + #[clap(aliases = &["daemon-fuzzy"])] + DaemonFuzzy, } impl SearchMode { @@ -51,6 +55,7 @@ impl SearchMode { SearchMode::FullText => "FULLTXT", SearchMode::Fuzzy => "FUZZY", SearchMode::Skim => "SKIM", + SearchMode::DaemonFuzzy => "DAEMON", } } pub fn next(&self, settings: &Settings) -> Self { @@ -58,9 +63,13 @@ impl SearchMode { SearchMode::Prefix => SearchMode::FullText, // if the user is using skim, we go to skim SearchMode::FullText if settings.search_mode == SearchMode::Skim => SearchMode::Skim, + // if the user is using daemon-fuzzy, we go to daemon-fuzzy + SearchMode::FullText if settings.search_mode == SearchMode::DaemonFuzzy => { + SearchMode::DaemonFuzzy + } // otherwise fuzzy. SearchMode::FullText => SearchMode::Fuzzy, - SearchMode::Fuzzy | SearchMode::Skim => SearchMode::Prefix, + SearchMode::Fuzzy | SearchMode::Skim | SearchMode::DaemonFuzzy => SearchMode::Prefix, } } } @@ -477,6 +486,78 @@ pub struct Tmux { pub height: String, } +/// Log level for file logging. Maps to tracing's LevelFilter. +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Deserialize, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum LogLevel { + Trace, + Debug, + #[default] + Info, + Warn, + Error, +} + +impl LogLevel { + /// Convert to a tracing directive string for use with EnvFilter. + pub fn as_directive(&self) -> &'static str { + match self { + LogLevel::Trace => "trace", + LogLevel::Debug => "debug", + LogLevel::Info => "info", + LogLevel::Warn => "warn", + LogLevel::Error => "error", + } + } +} + +/// Configuration for a specific log type (search or daemon). +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +pub struct LogConfig { + /// Log file name (relative to dir) or absolute path. + pub file: String, + + /// Override global enabled setting for this log type. + pub enabled: Option<bool>, + + /// Override global level setting for this log type. + pub level: Option<LogLevel>, + + /// Override global retention days setting for this log type. + pub retention: Option<u64>, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Logs { + /// Enable file logging globally. Defaults to true. + #[serde(default = "Logs::default_enabled")] + pub enabled: bool, + + /// Directory for log files. Defaults to ~/.atuin/logs + pub dir: String, + + /// Default log level for file logging. Defaults to "info". + /// Note: ATUIN_LOG environment variable overrides this. + #[serde(default)] + pub level: LogLevel, + + /// Default retention days for log files. Defaults to 4. + #[serde(default = "Logs::default_retention")] + pub retention: u64, + + /// Search log settings + #[serde(default)] + pub search: LogConfig, + + /// Daemon log settings + #[serde(default)] + pub daemon: LogConfig, + + /// AI log settings + #[serde(default)] + pub ai: LogConfig, +} + #[derive(Default, Clone, Debug, Deserialize, Serialize)] pub struct Ai { /// The address of the Atuin AI endpoint. Used for AI features like command generation. @@ -523,6 +604,117 @@ impl Default for Daemon { } } +impl Default for Logs { + fn default() -> Self { + Self { + enabled: true, + dir: "".to_string(), + level: LogLevel::default(), + retention: Self::default_retention(), + search: LogConfig { + file: "search.log".to_string(), + ..Default::default() + }, + daemon: LogConfig { + file: "daemon.log".to_string(), + ..Default::default() + }, + ai: LogConfig { + file: "ai.log".to_string(), + ..Default::default() + }, + } + } +} + +impl Logs { + fn default_enabled() -> bool { + true + } + + fn default_retention() -> u64 { + 4 + } + + /// Returns whether search logging is enabled. + /// Uses search-specific setting if set, otherwise falls back to global. + pub fn search_enabled(&self) -> bool { + self.search.enabled.unwrap_or(self.enabled) + } + + /// Returns whether daemon logging is enabled. + /// Uses daemon-specific setting if set, otherwise falls back to global. + pub fn daemon_enabled(&self) -> bool { + self.daemon.enabled.unwrap_or(self.enabled) + } + + /// Returns whether AI logging is enabled. + /// Uses AI-specific setting if set, otherwise falls back to global. + pub fn ai_enabled(&self) -> bool { + self.ai.enabled.unwrap_or(self.enabled) + } + + /// Returns the log level for search logging. + /// Uses search-specific setting if set, otherwise falls back to global. + pub fn search_level(&self) -> LogLevel { + self.search.level.unwrap_or(self.level) + } + + /// Returns the log level for daemon logging. + /// Uses daemon-specific setting if set, otherwise falls back to global. + pub fn daemon_level(&self) -> LogLevel { + self.daemon.level.unwrap_or(self.level) + } + + /// Returns the log level for AI logging. + /// Uses AI-specific setting if set, otherwise falls back to global. + pub fn ai_level(&self) -> LogLevel { + self.ai.level.unwrap_or(self.level) + } + + /// Returns the retention days for search logging. + /// Uses search-specific setting if set, otherwise falls back to global. + pub fn search_retention(&self) -> u64 { + self.search.retention.unwrap_or(self.retention) + } + + /// Returns the retention days for daemon logging. + /// Uses daemon-specific setting if set, otherwise falls back to global. + pub fn daemon_retention(&self) -> u64 { + self.daemon.retention.unwrap_or(self.retention) + } + + /// Returns the retention days for AI logging. + /// Uses AI-specific setting if set, otherwise falls back to global. + pub fn ai_retention(&self) -> u64 { + self.ai.retention.unwrap_or(self.retention) + } + + /// Returns the full path for the search log file. + /// If `file` is an absolute path, returns it directly. + /// Otherwise, joins it with `dir`. + pub fn search_path(&self) -> PathBuf { + let path = PathBuf::from(&self.search.file); + if path.is_absolute() { + path + } else { + PathBuf::from(&self.dir).join(path) + } + } + + /// Returns the full path for the daemon log file. + /// If `file` is an absolute path, returns it directly. + /// Otherwise, joins it with `dir`. + pub fn daemon_path(&self) -> PathBuf { + let path = PathBuf::from(&self.daemon.file); + if path.is_absolute() { + path + } else { + PathBuf::from(&self.dir).join(path) + } + } +} + impl Default for Search { fn default() -> Self { Self { @@ -849,6 +1041,9 @@ pub struct Settings { pub tmux: Tmux, #[serde(default)] + pub logs: Logs, + + #[serde(default)] pub meta: meta::Settings, #[serde(default)] @@ -1033,6 +1228,7 @@ impl Settings { let scripts_path = data_dir.join("scripts.db"); let socket_path = atuin_common::utils::runtime_dir().join("atuin.sock"); let pidfile_path = data_dir.join("atuin-daemon.pid"); + let logs_dir = atuin_common::utils::logs_dir(); let key_path = data_dir.join("key"); let meta_path = data_dir.join("meta.db"); @@ -1101,6 +1297,12 @@ impl Settings { .set_default("daemon.pidfile_path", pidfile_path.to_str())? .set_default("daemon.systemd_socket", false)? .set_default("daemon.tcp_port", 8889)? + .set_default("logs.enabled", true)? + .set_default("logs.dir", logs_dir.to_str())? + .set_default("logs.level", "info")? + .set_default("logs.search.file", "search.log")? + .set_default("logs.daemon.file", "daemon.log")? + .set_default("logs.ai.file", "ai.log")? .set_default("kv.db_path", kv_path.to_str())? .set_default("scripts.db_path", scripts_path.to_str())? .set_default("meta.db_path", meta_path.to_str())? @@ -1218,6 +1420,9 @@ impl Settings { settings.key_path = Self::expand_path(settings.key_path)?; settings.daemon.socket_path = Self::expand_path(settings.daemon.socket_path)?; settings.daemon.pidfile_path = Self::expand_path(settings.daemon.pidfile_path)?; + settings.logs.dir = Self::expand_path(settings.logs.dir)?; + settings.logs.search.file = Self::expand_path(settings.logs.search.file)?; + settings.logs.daemon.file = Self::expand_path(settings.logs.daemon.file)?; // Validate UI settings settings.ui.validate()?; @@ -1264,6 +1469,20 @@ impl Default for Settings { } } +/// Initialize the meta store configuration for testing. +/// +/// This should only be used in tests. It allows tests to bypass the normal +/// Settings::new() flow while still being able to use Settings::host_id() +/// and other meta store dependent functions. +/// +/// # Safety +/// This function is not thread-safe with concurrent calls to Settings::new() +/// or other meta store initialization. Only call from tests. +#[doc(hidden)] +pub fn init_meta_config_for_testing(meta_db_path: impl Into<String>, local_timeout: f64) { + META_CONFIG.set((meta_db_path.into(), local_timeout)).ok(); +} + #[cfg(test)] pub(crate) fn test_local_timeout() -> f64 { std::env::var("ATUIN_TEST_LOCAL_TIMEOUT") |
