use std::fs::{self, File, OpenOptions}; use std::io::{ErrorKind, Write}; #[cfg(unix)] use std::os::unix::net::UnixStream as StdUnixStream; use std::path::{Path, PathBuf}; use std::process::{Command, Stdio}; use std::time::{Duration, Instant}; use crate::atuin_client::{ database::ClientSqlite, history::History, record::sqlite_store::SqliteStore, settings::Settings, }; use crate::atuin_daemon::DaemonEvent; use crate::atuin_daemon::client::{ ControlClient, DaemonClientErrorKind, HistoryClient, classify_error, }; use clap::Subcommand; #[cfg(unix)] use daemonize::Daemonize; use eyre::{Result, WrapErr, bail, eyre}; use fs4::fs_std::FileExt; use tokio::time::sleep; #[derive(clap::Args, Debug)] pub(crate) struct Cmd { #[command(subcommand)] subcmd: SubCmd, } #[derive(Subcommand, Debug)] #[command(infer_subcommands = true)] pub(crate) enum SubCmd { /// Start the daemon server Start { #[arg(long, hide = true)] daemonize: bool, /// Also write daemon logs to the console (useful for debugging) #[arg(long)] show_logs: bool, /// Force start: kill existing daemon process and reset the socket #[arg(long)] force: bool, }, /// Show the daemon's current status Status, /// Stop the daemon gracefully Stop, /// Restart the daemon (stop, then start in background) Restart, } impl Cmd { /// Returns `true` when the process should daemonize before creating the /// async runtime or opening any database connections. #[cfg(unix)] pub(crate) fn should_daemonize(&self) -> bool { match &self.subcmd { SubCmd::Start { daemonize, .. } => *daemonize, _ => false, } } /// Returns `true` when logs should also be written to the console. pub(crate) fn show_logs(&self) -> bool { match &self.subcmd { SubCmd::Start { show_logs, .. } => *show_logs, _ => false, } } pub(crate) async fn run( self, settings: Settings, store: SqliteStore, history_db: ClientSqlite, ) -> Result<()> { match self.subcmd { SubCmd::Start { force, .. } => run(settings, store, history_db, force).await, SubCmd::Status => status_cmd(&settings).await, SubCmd::Stop => stop_cmd(&settings).await, SubCmd::Restart => restart_cmd(&settings).await, } } } const DAEMON_VERSION: &str = env!("CARGO_PKG_VERSION"); const DAEMON_PROTOCOL_VERSION: u32 = 1; const STARTUP_POLL: Duration = Duration::from_millis(40); const LOCK_POLL: Duration = Duration::from_millis(20); const LEGACY_DAEMON_RESTART_MESSAGE: &str = "legacy daemon detected; restart daemon manually"; struct PidfileGuard { file: File, } impl PidfileGuard { fn acquire(path: &Path) -> Result { let mut file = open_lock_file(path)?; if !file.try_lock_exclusive()? { bail!( "daemon already running (pidfile lock busy at {})", path.display() ); } file.set_len(0) .wrap_err_with(|| format!("could not truncate daemon pidfile {}", path.display()))?; writeln!(file, "{}", std::process::id()) .and_then(|()| writeln!(file, "{DAEMON_VERSION}")) .wrap_err_with(|| format!("could not write daemon pidfile {}", path.display()))?; Ok(Self { file }) } } impl Drop for PidfileGuard { fn drop(&mut self) { drop(self.file.unlock()); } } enum Probe { Ready(HistoryClient), NeedsRestart(String), Unreachable(eyre::Report), } fn daemon_matches_expected(version: &str, protocol: u32) -> bool { version == DAEMON_VERSION && protocol == DAEMON_PROTOCOL_VERSION } fn daemon_mismatch_message(version: &str, protocol: u32) -> String { if protocol == DAEMON_PROTOCOL_VERSION { format!("daemon is out of date: expected {DAEMON_VERSION}, got {version}") } else { format!("daemon protocol mismatch: expected {DAEMON_PROTOCOL_VERSION}, got {protocol}") } } fn is_legacy_daemon_error(err: &eyre::Report) -> bool { matches!(classify_error(err), DaemonClientErrorKind::Unimplemented) } fn open_lock_file(path: &Path) -> Result { if let Some(parent) = path.parent() { fs::create_dir_all(parent) .wrap_err_with(|| format!("could not create lock directory {}", parent.display()))?; } OpenOptions::new() .read(true) .write(true) .create(true) .truncate(false) .open(path) .wrap_err_with(|| format!("could not open lock file {}", path.display())) } async fn wait_for_lock(path: &Path, timeout: Duration) -> Result { let file = open_lock_file(path)?; let start = Instant::now(); loop { match file.try_lock_exclusive() { Ok(true) => return Ok(file), Ok(false) => { if start.elapsed() >= timeout { bail!("timed out waiting for lock at {}", path.display()); } sleep(LOCK_POLL).await; } Err(err) => { return Err(eyre!("could not lock {}: {err}", path.display())); } } } } async fn wait_for_pidfile_available(path: &Path, timeout: Duration) -> Result<()> { let file = wait_for_lock(path, timeout).await?; file.unlock() .wrap_err_with(|| format!("failed to unlock {}", path.display()))?; Ok(()) } async fn connect_client(settings: &Settings) -> Result { HistoryClient::new( #[cfg(unix)] settings.daemon.socket_path.clone(), ) .await } async fn probe(settings: &Settings) -> Probe { let mut client = match connect_client(settings).await { Ok(client) => client, Err(err) => return Probe::Unreachable(err), }; match client.status().await { Ok(status) => { if daemon_matches_expected(&status.version, status.protocol) { Probe::Ready(client) } else { Probe::NeedsRestart(daemon_mismatch_message(&status.version, status.protocol)) } } Err(err) => Probe::Unreachable(err), } } async fn request_shutdown(settings: &Settings) { if let Ok(mut client) = connect_client(settings).await { drop(client.shutdown().await); } } fn spawn_daemon_process() -> Result<()> { let exe = std::env::current_exe().wrap_err("could not locate atuin executable")?; let mut cmd = Command::new(exe); cmd.arg("daemon") .arg("start") .stdin(Stdio::null()) .stdout(Stdio::null()) .stderr(Stdio::null()); #[cfg(unix)] cmd.arg("--daemonize"); cmd.spawn().wrap_err("failed to spawn daemon process")?; Ok(()) } fn startup_timeout(settings: &Settings) -> Duration { Duration::from_secs_f64(settings.local_timeout.max(0.5) + 2.0) } #[cfg(unix)] fn remove_stale_socket_if_present(settings: &Settings) -> Result<()> { if settings.daemon.systemd_socket { return Ok(()); } let socket_path = Path::new(&settings.daemon.socket_path); if !socket_path.exists() { return Ok(()); } match StdUnixStream::connect(socket_path) { Ok(stream) => { drop(stream); Ok(()) } Err(err) if err.kind() == ErrorKind::ConnectionRefused => { fs::remove_file(socket_path).wrap_err_with(|| { format!( "failed to remove stale daemon socket {}", socket_path.display() ) })?; Ok(()) } Err(err) if err.kind() == ErrorKind::NotFound => Ok(()), Err(_) => Ok(()), } } async fn wait_until_ready(settings: &Settings, timeout: Duration) -> Result { let start = Instant::now(); let mut last_error = eyre!("daemon did not become ready"); loop { match probe(settings).await { Probe::Ready(client) => return Ok(client), Probe::NeedsRestart(reason) => { last_error = eyre!(reason); } Probe::Unreachable(err) => { if is_legacy_daemon_error(&err) { return Err(err.wrap_err(LEGACY_DAEMON_RESTART_MESSAGE)); } last_error = err; } } if start.elapsed() >= timeout { return Err(last_error.wrap_err(format!( "timed out waiting for daemon startup after {}ms", timeout.as_millis() ))); } sleep(STARTUP_POLL).await; } } pub(crate) async fn start_history(settings: &Settings, history: History) -> Result { match async { connect_client(settings) .await? .start_history(history.clone()) .await } .await { Ok(resp) => { if daemon_matches_expected(&resp.version, resp.protocol) { return Ok(resp.id); } Err(eyre!( "{}. Restart the daemon manually", daemon_mismatch_message(&resp.version, resp.protocol) )) } Err(err) => Err(err), } } pub(crate) async fn end_history( settings: &Settings, id: String, duration: u64, exit: i64, ) -> Result<()> { match async { connect_client(settings) .await? .end_history(id.clone(), duration, exit) .await } .await { Ok(resp) => { if daemon_matches_expected(&resp.version, resp.protocol) { return Ok(()); } Err(eyre!( "{}. Restart the daemon manually", daemon_mismatch_message(&resp.version, resp.protocol) )) } Err(err) => Err(err), } } /// Emit a daemon event. pub(crate) async fn emit_event(settings: &Settings, event: DaemonEvent) { // Try to connect and send match ControlClient::from_settings(settings).await { Ok(mut client) => { if let Err(e) = client.send_event(event).await { tracing::debug!(?e, "failed to send event to daemon"); } } Err(e) => { tracing::debug!(?e, "daemon not available, skipping event emission"); } } } pub(crate) async fn tail_client(settings: &Settings) -> Result { match probe(settings).await { Probe::Ready(client) => Ok(client), Probe::NeedsRestart(reason) => { bail!("{reason}. Restart the daemon manually"); } Probe::Unreachable(err) if is_legacy_daemon_error(&err) => { Err(err.wrap_err(LEGACY_DAEMON_RESTART_MESSAGE)) } Probe::Unreachable(err) => Err(err), } } async fn status_cmd(settings: &Settings) -> Result<()> { match probe(settings).await { Probe::Ready(mut client) => { let status = client.status().await?; println!("Daemon running"); println!(" PID: {}", status.pid); println!(" Version: {}", status.version); println!(" Protocol: {}", status.protocol); println!(" Healthy: {}", status.healthy); #[cfg(unix)] println!(" Socket: {}", settings.daemon.socket_path); } Probe::NeedsRestart(reason) => { println!("Daemon running (needs restart)"); println!(" Reason: {reason}"); } Probe::Unreachable(_) => { println!("Daemon is not running"); } } Ok(()) } async fn stop_cmd(settings: &Settings) -> Result<()> { let Ok(mut client) = connect_client(settings).await else { println!("Daemon is not running"); return Ok(()); }; match client.shutdown().await { Ok(true) => { println!("Shutdown requested"); let pidfile_path = PathBuf::from(&settings.daemon.pidfile_path); let timeout = Duration::from_secs(5); match wait_for_pidfile_available(&pidfile_path, timeout).await { Ok(()) => println!("Daemon stopped"), Err(_) => println!("Daemon may still be shutting down"), } Ok(()) } Ok(false) => bail!("Daemon rejected shutdown request"), Err(err) => Err(err.wrap_err("Failed to send shutdown request")), } } async fn restart_cmd(settings: &Settings) -> Result<()> { // Stop if running match probe(settings).await { Probe::Ready(_) | Probe::NeedsRestart(_) => { request_shutdown(settings).await; println!("Stopping daemon..."); let pidfile_path = PathBuf::from(&settings.daemon.pidfile_path); let timeout = Duration::from_secs(5); wait_for_pidfile_available(&pidfile_path, timeout) .await .wrap_err("Timed out waiting for old daemon to stop")?; } Probe::Unreachable(_) => { println!("No daemon running"); } } #[cfg(unix)] remove_stale_socket_if_present(settings)?; spawn_daemon_process()?; println!("Starting daemon..."); let timeout = startup_timeout(settings); let status = wait_until_ready(settings, timeout).await?.status().await?; println!("Daemon restarted"); println!(" PID: {}", status.pid); println!(" Version: {}", status.version); Ok(()) } /// Daemonize the current process. Must be called before creating the tokio /// runtime or opening database connections, since `fork()` inside an async /// runtime corrupts its internal state. #[cfg(unix)] pub(crate) fn daemonize_current_process() -> Result<()> { let cwd = std::env::current_dir().wrap_err("could not determine current directory for daemon")?; Daemonize::new() .working_directory(cwd) .start() .wrap_err("failed to daemonize process")?; Ok(()) } async fn run( settings: Settings, store: SqliteStore, history_db: ClientSqlite, force: bool, ) -> Result<()> { if force { force_cleanup(&settings); } let pidfile_path = PathBuf::from(&settings.daemon.pidfile_path); let _pidfile_guard = PidfileGuard::acquire(&pidfile_path)?; crate::atuin_daemon::boot(settings, store, history_db).await?; Ok(()) } /// Force cleanup: kill existing daemon process and remove socket. fn force_cleanup(settings: &Settings) { let pidfile_path = Path::new(&settings.daemon.pidfile_path); // Read and kill the existing process if pidfile exists if pidfile_path.exists() { if let Ok(contents) = fs::read_to_string(pidfile_path) && let Some(pid_str) = contents.lines().next() && let Ok(pid) = pid_str.parse::() { kill_process(pid); // Give it a moment to release resources std::thread::sleep(Duration::from_millis(100)); } // Remove the pidfile if let Err(e) = fs::remove_file(pidfile_path) && e.kind() != ErrorKind::NotFound { tracing::warn!("failed to remove pidfile: {e}"); } } // Remove the socket file #[cfg(unix)] { let socket_path = Path::new(&settings.daemon.socket_path); if socket_path.exists() && let Err(e) = fs::remove_file(socket_path) && e.kind() != ErrorKind::NotFound { tracing::warn!("failed to remove socket: {e}"); } } } /// Kill a process by PID. #[cfg(unix)] fn kill_process(pid: u32) { // Use kill command to send SIGTERM for graceful shutdown drop( Command::new("kill") .args(["-TERM", &pid.to_string()]) .stdout(Stdio::null()) .stderr(Stdio::null()) .status(), ); } #[cfg(test)] mod tests { use super::{ DAEMON_PROTOCOL_VERSION, DAEMON_VERSION, daemon_matches_expected, daemon_mismatch_message, }; #[test] fn test_version_matches() { assert!(daemon_matches_expected( DAEMON_VERSION, DAEMON_PROTOCOL_VERSION )); } #[test] fn test_version_mismatch() { assert!(!daemon_matches_expected("0.0.0", DAEMON_PROTOCOL_VERSION)); assert!(!daemon_matches_expected(DAEMON_VERSION, 999)); assert!(!daemon_matches_expected("0.0.0", 999)); } #[test] fn test_mismatch_message_version() { let msg = daemon_mismatch_message("0.0.0", DAEMON_PROTOCOL_VERSION); assert!(msg.contains("out of date"), "got: {msg}"); assert!(msg.contains("0.0.0")); assert!(msg.contains(DAEMON_VERSION)); } #[test] fn test_mismatch_message_protocol() { let msg = daemon_mismatch_message(DAEMON_VERSION, 999); assert!(msg.contains("protocol mismatch"), "got: {msg}"); } }