diff options
| author | Ellie Huxtable <ellie@atuin.sh> | 2026-02-13 11:37:58 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-02-13 11:37:58 -0800 |
| commit | d52c4d6003adced1f6763261a7f9132719be5533 (patch) | |
| tree | 94d2f7c815814135edbc2ee0af56d2cedd1a1724 /crates | |
| parent | feat: add Atuin AI inline CLI MVP (#3178) (diff) | |
| download | atuin-d52c4d6003adced1f6763261a7f9132719be5533.zip | |
feat: add autostart and pid management to daemon (#3180)
Diffstat (limited to 'crates')
| -rw-r--r-- | crates/atuin-client/config.toml | 10 | ||||
| -rw-r--r-- | crates/atuin-client/src/settings.rs | 29 | ||||
| -rw-r--r-- | crates/atuin-daemon/Cargo.toml | 4 | ||||
| -rw-r--r-- | crates/atuin-daemon/proto/history.proto | 21 | ||||
| -rw-r--r-- | crates/atuin-daemon/src/client.rs | 51 | ||||
| -rw-r--r-- | crates/atuin-daemon/src/server.rs | 89 | ||||
| -rw-r--r-- | crates/atuin-daemon/tests/lifecycle.rs | 127 | ||||
| -rw-r--r-- | crates/atuin/Cargo.toml | 4 | ||||
| -rw-r--r-- | crates/atuin/src/command/client.rs | 15 | ||||
| -rw-r--r-- | crates/atuin/src/command/client/daemon.rs | 627 | ||||
| -rw-r--r-- | crates/atuin/src/command/client/history.rs | 23 |
11 files changed, 948 insertions, 52 deletions
diff --git a/crates/atuin-client/config.toml b/crates/atuin-client/config.toml index 03e093fc..6e67a4e1 100644 --- a/crates/atuin-client/config.toml +++ b/crates/atuin-client/config.toml @@ -259,9 +259,13 @@ records = true # strategy = "auto" [daemon] -## Enables using the daemon to sync. Requires the daemon to be running in the background. Start it with `atuin daemon` +## Enables using the daemon to sync. # enabled = false +## Automatically start and manage the daemon when needed. +## Not compatible with `systemd_socket = true`. +# autostart = false + ## How often the daemon should sync in seconds # sync_frequency = 300 @@ -270,6 +274,10 @@ records = true ## windows: Not Supported # socket_path = "~/.local/share/atuin/atuin.sock" +## The daemon pidfile used for lifecycle management. +## Defaults to the Atuin data directory. +# pidfile_path = "~/.local/share/atuin/atuin-daemon.pid" + ## Use systemd socket activation rather than opening the given path (the path must still be correct for the client) ## linux: false ## mac/windows: Not Supported diff --git a/crates/atuin-client/src/settings.rs b/crates/atuin-client/src/settings.rs index 7e062e75..1c35e6eb 100644 --- a/crates/atuin-client/src/settings.rs +++ b/crates/atuin-client/src/settings.rs @@ -436,16 +436,22 @@ pub struct Theme { #[derive(Clone, Debug, Deserialize, Serialize)] pub struct Daemon { /// Use the daemon to sync - /// If enabled, requires a running daemon with `atuin daemon` + /// If enabled, history hooks are routed through the daemon. #[serde(alias = "enable")] pub enabled: bool, + /// Automatically start and manage a local daemon when needed. + pub autostart: bool, + /// The daemon will handle sync on an interval. How often to sync, in seconds. pub sync_frequency: u64, /// The path to the unix socket used by the daemon pub socket_path: String, + /// Path to the daemon pidfile used for process coordination. + pub pidfile_path: String, + /// Use a socket passed via systemd's socket activation protocol, instead of the path pub systemd_socket: bool, @@ -493,8 +499,10 @@ impl Default for Daemon { fn default() -> Self { Self { enabled: false, + autostart: false, sync_frequency: 300, socket_path: "".to_string(), + pidfile_path: "".to_string(), systemd_socket: false, tcp_port: 8889, } @@ -1007,6 +1015,7 @@ impl Settings { let kv_path = data_dir.join("kv.db"); let scripts_path = data_dir.join("scripts.db"); let socket_path = atuin_common::utils::runtime_dir().join("atuin.sock"); + let pidfile_path = data_dir.join("atuin-daemon.pid"); let key_path = data_dir.join("key"); let meta_path = data_dir.join("meta.db"); @@ -1070,7 +1079,9 @@ impl Settings { .set_default("store_failed", true)? .set_default("daemon.sync_frequency", 300)? .set_default("daemon.enabled", false)? + .set_default("daemon.autostart", false)? .set_default("daemon.socket_path", socket_path.to_str())? + .set_default("daemon.pidfile_path", pidfile_path.to_str())? .set_default("daemon.systemd_socket", false)? .set_default("daemon.tcp_port", 8889)? .set_default("kv.db_path", kv_path.to_str())? @@ -1189,6 +1200,7 @@ impl Settings { settings.record_store_path = Self::expand_path(settings.record_store_path)?; settings.key_path = Self::expand_path(settings.key_path)?; settings.daemon.socket_path = Self::expand_path(settings.daemon.socket_path)?; + settings.daemon.pidfile_path = Self::expand_path(settings.daemon.pidfile_path)?; // Validate UI settings settings.ui.validate()?; @@ -1351,6 +1363,9 @@ mod tests { let kv_db_path: String = config.get("kv.db_path")?; let scripts_db_path: String = config.get("scripts.db_path")?; let meta_db_path: String = config.get("meta.db_path")?; + let daemon_socket_path: String = config.get("daemon.socket_path")?; + let daemon_pidfile_path: String = config.get("daemon.pidfile_path")?; + let daemon_autostart: bool = config.get("daemon.autostart")?; assert_eq!(db_path, custom_dir.join("history.db").to_str().unwrap()); assert_eq!(key_path, custom_dir.join("key").to_str().unwrap()); @@ -1364,6 +1379,18 @@ mod tests { custom_dir.join("scripts.db").to_str().unwrap() ); assert_eq!(meta_db_path, custom_dir.join("meta.db").to_str().unwrap()); + assert_eq!( + daemon_socket_path, + atuin_common::utils::runtime_dir() + .join("atuin.sock") + .to_str() + .unwrap() + ); + assert_eq!( + daemon_pidfile_path, + custom_dir.join("atuin-daemon.pid").to_str().unwrap() + ); + assert!(!daemon_autostart); Ok(()) } diff --git a/crates/atuin-daemon/Cargo.toml b/crates/atuin-daemon/Cargo.toml index 9adbe5e8..8d3b4ab6 100644 --- a/crates/atuin-daemon/Cargo.toml +++ b/crates/atuin-daemon/Cargo.toml @@ -39,6 +39,10 @@ rand.workspace = true [target.'cfg(target_os = "linux")'.dependencies] listenfd = "1.0.1" +[dev-dependencies] +tempfile = { workspace = true } +atuin-common = { path = "../atuin-common", version = "18.12.1" } + [build-dependencies] protox = "0.8.0" tonic-build = "0.12" diff --git a/crates/atuin-daemon/proto/history.proto b/crates/atuin-daemon/proto/history.proto index 1172b91b..9fbd3372 100644 --- a/crates/atuin-daemon/proto/history.proto +++ b/crates/atuin-daemon/proto/history.proto @@ -18,14 +18,35 @@ message EndHistoryRequest { message StartHistoryReply { string id = 1; + string version = 2; + uint32 protocol = 3; } message EndHistoryReply { string id = 1; uint64 idx = 2; + string version = 3; + uint32 protocol = 4; +} + +message StatusRequest {} + +message StatusReply { + bool healthy = 1; + string version = 2; + uint32 pid = 3; + uint32 protocol = 4; +} + +message ShutdownRequest {} + +message ShutdownReply { + bool accepted = 1; } service History { rpc StartHistory(StartHistoryRequest) returns (StartHistoryReply); rpc EndHistory(EndHistoryRequest) returns (EndHistoryReply); + rpc Status(StatusRequest) returns (StatusReply); + rpc Shutdown(ShutdownRequest) returns (ShutdownReply); } diff --git a/crates/atuin-daemon/src/client.rs b/crates/atuin-daemon/src/client.rs index a4b4690e..05067bda 100644 --- a/crates/atuin-daemon/src/client.rs +++ b/crates/atuin-daemon/src/client.rs @@ -1,6 +1,7 @@ use eyre::{Context, Result}; #[cfg(windows)] use tokio::net::TcpStream; +use tonic::Code; use tonic::transport::{Channel, Endpoint, Uri}; use tower::service_fn; @@ -12,13 +13,41 @@ use tokio::net::UnixStream; use atuin_client::history::History; use crate::history::{ - EndHistoryRequest, StartHistoryRequest, history_client::HistoryClient as HistoryServiceClient, + EndHistoryReply, EndHistoryRequest, ShutdownRequest, StartHistoryReply, StartHistoryRequest, + StatusReply, StatusRequest, history_client::HistoryClient as HistoryServiceClient, }; pub struct HistoryClient { client: HistoryServiceClient<Channel>, } +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum DaemonClientErrorKind { + Connect, + Unavailable, + Unimplemented, + Other, +} + +#[must_use] +pub fn classify_error(error: &eyre::Report) -> DaemonClientErrorKind { + for cause in error.chain() { + if cause.downcast_ref::<tonic::transport::Error>().is_some() { + return DaemonClientErrorKind::Connect; + } + + if let Some(status) = cause.downcast_ref::<tonic::Status>() { + return match status.code() { + Code::Unavailable => DaemonClientErrorKind::Unavailable, + Code::Unimplemented => DaemonClientErrorKind::Unimplemented, + _ => DaemonClientErrorKind::Other, + }; + } + } + + DaemonClientErrorKind::Other +} + // Wrap the grpc client impl HistoryClient { #[cfg(unix)] @@ -67,7 +96,7 @@ impl HistoryClient { Ok(HistoryClient { client }) } - pub async fn start_history(&mut self, h: History) -> Result<String> { + pub async fn start_history(&mut self, h: History) -> Result<StartHistoryReply> { let req = StartHistoryRequest { command: h.command, cwd: h.cwd, @@ -76,9 +105,7 @@ impl HistoryClient { timestamp: h.timestamp.unix_timestamp_nanos() as u64, }; - let resp = self.client.start_history(req).await?; - - Ok(resp.into_inner().id) + Ok(self.client.start_history(req).await?.into_inner()) } pub async fn end_history( @@ -86,12 +113,18 @@ impl HistoryClient { id: String, duration: u64, exit: i64, - ) -> Result<(String, u64)> { + ) -> Result<EndHistoryReply> { let req = EndHistoryRequest { id, duration, exit }; - let resp = self.client.end_history(req).await?; - let resp = resp.into_inner(); + Ok(self.client.end_history(req).await?.into_inner()) + } + + pub async fn status(&mut self) -> Result<StatusReply> { + Ok(self.client.status(StatusRequest {}).await?.into_inner()) + } - Ok((resp.id, resp.idx)) + pub async fn shutdown(&mut self) -> Result<bool> { + let resp = self.client.shutdown(ShutdownRequest {}).await?.into_inner(); + Ok(resp.accepted) } } diff --git a/crates/atuin-daemon/src/server.rs b/crates/atuin-daemon/src/server.rs index 2cba1753..9622d2b6 100644 --- a/crates/atuin-daemon/src/server.rs +++ b/crates/atuin-daemon/src/server.rs @@ -4,10 +4,12 @@ use atuin_client::encryption; use atuin_client::history::store::HistoryStore; use atuin_client::record::sqlite_store::SqliteStore; use atuin_client::settings::Settings; +use std::io::ErrorKind; #[cfg(unix)] use std::path::PathBuf; use std::sync::Arc; use time::OffsetDateTime; +use tokio::sync::watch; use tracing::{Level, instrument}; use atuin_client::database::{Database, Sqlite as HistoryDatabase}; @@ -19,9 +21,12 @@ use tonic::{Request, Response, Status, transport::Server}; use crate::history::history_server::{History as HistorySvc, HistoryServer}; use crate::history::{EndHistoryReply, EndHistoryRequest, StartHistoryReply, StartHistoryRequest}; +use crate::history::{ShutdownReply, ShutdownRequest, StatusReply, StatusRequest}; mod sync; +const DAEMON_PROTOCOL_VERSION: u32 = 1; + #[derive(Debug)] pub struct HistoryService { // A store for WIP history @@ -29,14 +34,20 @@ pub struct HistoryService { running: Arc<DashMap<HistoryId, History>>, store: HistoryStore, history_db: HistoryDatabase, + shutdown_tx: watch::Sender<bool>, } impl HistoryService { - pub fn new(store: HistoryStore, history_db: HistoryDatabase) -> Self { + pub fn new( + store: HistoryStore, + history_db: HistoryDatabase, + shutdown_tx: watch::Sender<bool>, + ) -> Self { Self { running: Arc::new(DashMap::new()), store, history_db, + shutdown_tx, } } } @@ -77,7 +88,11 @@ impl HistorySvc for HistoryService { tracing::info!(id = id.to_string(), "start history"); running.insert(id.clone(), h); - let reply = StartHistoryReply { id: id.to_string() }; + let reply = StartHistoryReply { + id: id.to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + protocol: DAEMON_PROTOCOL_VERSION, + }; Ok(Response::new(reply)) } @@ -122,6 +137,8 @@ impl HistorySvc for HistoryService { let reply = EndHistoryReply { id: id.0.to_string(), idx, + version: env!("CARGO_PKG_VERSION").to_string(), + protocol: DAEMON_PROTOCOL_VERSION, }; return Ok(Response::new(reply)); @@ -131,10 +148,35 @@ impl HistorySvc for HistoryService { "could not find history with id: {id}" ))) } + + #[instrument(skip_all, level = Level::INFO)] + async fn status( + &self, + _request: Request<StatusRequest>, + ) -> Result<Response<StatusReply>, Status> { + let reply = StatusReply { + // If status RPC responds, the daemon control plane is healthy. + healthy: true, + version: env!("CARGO_PKG_VERSION").to_string(), + pid: std::process::id(), + protocol: DAEMON_PROTOCOL_VERSION, + }; + + Ok(Response::new(reply)) + } + + #[instrument(skip_all, level = Level::INFO)] + async fn shutdown( + &self, + _request: Request<ShutdownRequest>, + ) -> Result<Response<ShutdownReply>, Status> { + let _ = self.shutdown_tx.send(true); + Ok(Response::new(ShutdownReply { accepted: true })) + } } #[cfg(unix)] -async fn shutdown_signal(socket: Option<PathBuf>) { +async fn shutdown_signal(socket: Option<PathBuf>, mut shutdown_rx: watch::Receiver<bool>) { let mut term = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) .expect("failed to register sigterm handler"); let mut int = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt()) @@ -143,26 +185,38 @@ async fn shutdown_signal(socket: Option<PathBuf>) { tokio::select! { _ = term.recv() => {}, _ = int.recv() => {}, + _ = shutdown_rx.changed() => {}, } eprintln!("Removing socket..."); if let Some(socket) = socket { - std::fs::remove_file(socket).expect("failed to remove socket"); + match std::fs::remove_file(socket) { + Ok(()) => {} + Err(err) if err.kind() == ErrorKind::NotFound => {} + Err(err) => { + eprintln!("failed to remove socket: {err}"); + } + } } eprintln!("Shutting down..."); } #[cfg(windows)] -async fn shutdown_signal() { - tokio::signal::windows::ctrl_c() - .expect("failed to register signal handler") - .recv() - .await; +async fn shutdown_signal(mut shutdown_rx: watch::Receiver<bool>) { + let mut ctrl_c = tokio::signal::windows::ctrl_c().expect("failed to register signal handler"); + tokio::select! { + _ = ctrl_c.recv() => {}, + _ = shutdown_rx.changed() => {}, + } eprintln!("Shutting down..."); } #[cfg(unix)] -async fn start_server(settings: Settings, history: HistoryService) -> Result<()> { +async fn start_server( + settings: Settings, + history: HistoryService, + shutdown_rx: watch::Receiver<bool>, +) -> Result<()> { use tokio::net::UnixListener; use tokio_stream::wrappers::UnixListenerStream; @@ -215,7 +269,7 @@ async fn start_server(settings: Settings, history: HistoryService) -> Result<()> .add_service(HistoryServer::new(history)) .serve_with_incoming_shutdown( uds_stream, - shutdown_signal(cleanup.then_some(socket_path.into())), + shutdown_signal(cleanup.then_some(socket_path.into()), shutdown_rx), ) .await?; @@ -223,7 +277,11 @@ async fn start_server(settings: Settings, history: HistoryService) -> Result<()> } #[cfg(not(unix))] -async fn start_server(settings: Settings, history: HistoryService) -> Result<()> { +async fn start_server( + settings: Settings, + history: HistoryService, + shutdown_rx: watch::Receiver<bool>, +) -> Result<()> { use tokio::net::TcpListener; use tokio_stream::wrappers::TcpListenerStream; @@ -236,7 +294,7 @@ async fn start_server(settings: Settings, history: HistoryService) -> Result<()> Server::builder() .add_service(HistoryServer::new(history)) - .serve_with_incoming_shutdown(tcp_stream, shutdown_signal()) + .serve_with_incoming_shutdown(tcp_stream, shutdown_signal(shutdown_rx)) .await?; Ok(()) } @@ -257,7 +315,8 @@ pub async fn listen( let host_id = Settings::host_id().await?; let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); - let history = HistoryService::new(history_store.clone(), history_db.clone()); + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let history = HistoryService::new(history_store.clone(), history_db.clone(), shutdown_tx); // start services tokio::spawn(sync::worker( @@ -267,5 +326,5 @@ pub async fn listen( history_db, )); - start_server(settings, history).await + start_server(settings, history, shutdown_rx).await } diff --git a/crates/atuin-daemon/tests/lifecycle.rs b/crates/atuin-daemon/tests/lifecycle.rs new file mode 100644 index 00000000..56457fa7 --- /dev/null +++ b/crates/atuin-daemon/tests/lifecycle.rs @@ -0,0 +1,127 @@ +//! Integration tests for the daemon server lifecycle. +//! +//! Each test spins up a real gRPC server on a temporary unix socket, +//! connects a client, and exercises the daemon RPCs. + +#[cfg(unix)] +mod unix { + use std::time::Duration; + + use atuin_client::database::Sqlite; + use atuin_client::history::store::HistoryStore; + use atuin_client::record::sqlite_store::SqliteStore; + use atuin_common::record::HostId; + use atuin_common::utils::uuid_v7; + use atuin_daemon::client::HistoryClient; + use atuin_daemon::history::history_server::HistoryServer; + use atuin_daemon::server::HistoryService; + use tempfile::TempDir; + use tokio::net::UnixListener; + use tokio::sync::watch; + use tokio_stream::wrappers::UnixListenerStream; + use tonic::transport::Server; + + /// Spins up a daemon server on a temp socket and returns a connected client, + /// the shutdown sender, and the temp dir (must be held to keep paths alive). + async fn start_test_daemon() -> (HistoryClient, watch::Sender<bool>, TempDir) { + let tmp = tempfile::tempdir().unwrap(); + + let db_path = tmp.path().join("history.db"); + let record_path = tmp.path().join("records.db"); + + let history_db = Sqlite::new(&db_path, 5.0).await.unwrap(); + let store = SqliteStore::new(&record_path, 5.0).await.unwrap(); + + let host_id = HostId(uuid_v7()); + let encryption_key = [0u8; 32]; + let history_store = HistoryStore::new(store, host_id, encryption_key); + + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let service = HistoryService::new(history_store, history_db, shutdown_tx.clone()); + + let socket_path = tmp.path().join("test.sock"); + let uds = UnixListener::bind(&socket_path).unwrap(); + let stream = UnixListenerStream::new(uds); + + let mut rx = shutdown_rx.clone(); + tokio::spawn(async move { + Server::builder() + .add_service(HistoryServer::new(service)) + .serve_with_incoming_shutdown(stream, async move { + let _ = rx.changed().await; + }) + .await + .unwrap(); + }); + + // Give the server a moment to bind. + tokio::time::sleep(Duration::from_millis(50)).await; + + let client = HistoryClient::new(socket_path.to_string_lossy().to_string()) + .await + .unwrap(); + + (client, shutdown_tx, tmp) + } + + #[tokio::test] + async fn test_status() { + let (mut client, _shutdown, _tmp) = start_test_daemon().await; + + let status = client.status().await.unwrap(); + assert!(status.healthy); + assert_eq!(status.version, env!("CARGO_PKG_VERSION")); + assert_eq!(status.protocol, 1); + assert!(status.pid > 0); + } + + #[tokio::test] + async fn test_start_end_history() { + use atuin_client::history::History; + + let (mut client, _shutdown, _tmp) = start_test_daemon().await; + + let history = History::daemon() + .timestamp(time::OffsetDateTime::now_utc()) + .command("echo hello".to_string()) + .cwd("/tmp".to_string()) + .session("test-session".to_string()) + .hostname("test-host".to_string()) + .build() + .into(); + + let start_reply = client.start_history(history).await.unwrap(); + assert!(!start_reply.id.is_empty()); + + let end_reply = client + .end_history(start_reply.id, 1_000_000, 0) + .await + .unwrap(); + assert!(!end_reply.id.is_empty()); + } + + #[tokio::test] + async fn test_end_unknown_history_fails() { + let (mut client, _shutdown, _tmp) = start_test_daemon().await; + + let result = client + .end_history("nonexistent-id".to_string(), 1000, 0) + .await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_shutdown() { + let (mut client, _shutdown_tx, _tmp) = start_test_daemon().await; + + let accepted = client.shutdown().await.unwrap(); + assert!(accepted); + + // Give server time to shut down. + tokio::time::sleep(Duration::from_millis(100)).await; + + // Subsequent calls should fail since the server is gone. + let result = client.status().await; + assert!(result.is_err()); + } +} diff --git a/crates/atuin/Cargo.toml b/crates/atuin/Cargo.toml index 3329f298..b2b45b6d 100644 --- a/crates/atuin/Cargo.toml +++ b/crates/atuin/Cargo.toml @@ -65,6 +65,7 @@ clap = { workspace = true } clap_complete = "4.5.8" clap_complete_nushell = "4.5.4" fs-err = { workspace = true } +fs4 = "0.13.1" rpassword = "7.0" semver = { workspace = true } rustix = { workspace = true } @@ -92,6 +93,9 @@ arboard = { version = "3.4", optional = true, features = [ "wayland-data-control", ] } +[target.'cfg(unix)'.dependencies] +daemonize = "0.5.0" + [dev-dependencies] tracing-tree = "0.4" diff --git a/crates/atuin/src/command/client.rs b/crates/atuin/src/command/client.rs index a0d4373f..0cb0a2ae 100644 --- a/crates/atuin/src/command/client.rs +++ b/crates/atuin/src/command/client.rs @@ -87,10 +87,10 @@ pub enum Cmd { #[command()] Wrapped { year: Option<i32> }, - /// *Experimental* Start the background daemon + /// *Experimental* Manage the background daemon #[cfg(feature = "daemon")] #[command()] - Daemon, + Daemon(daemon::Cmd), /// Print the default atuin configuration (config.toml) #[command()] @@ -99,6 +99,15 @@ pub enum Cmd { impl Cmd { pub fn run(self) -> Result<()> { + // Daemonize before creating the async runtime – fork() inside a live + // tokio runtime corrupts its internal state. + #[cfg(all(unix, feature = "daemon"))] + if let Self::Daemon(ref cmd) = self + && cmd.should_daemonize() + { + daemon::daemonize_current_process()?; + } + let runtime = tokio::runtime::Builder::new_current_thread() .enable_all() .build() @@ -179,7 +188,7 @@ impl Cmd { Self::Wrapped { year } => wrapped::run(year, &db, &settings, sqlite_store, theme).await, #[cfg(feature = "daemon")] - Self::Daemon => daemon::run(settings, sqlite_store, db).await, + Self::Daemon(cmd) => cmd.run(settings, sqlite_store, db).await, Self::History(_) | Self::Init(_) | Self::Doctor => unreachable!(), } diff --git a/crates/atuin/src/command/client/daemon.rs b/crates/atuin/src/command/client/daemon.rs index 38ba6908..a92e8f8e 100644 --- a/crates/atuin/src/command/client/daemon.rs +++ b/crates/atuin/src/command/client/daemon.rs @@ -1,10 +1,629 @@ -use eyre::Result; +use std::fs::{self, File, OpenOptions}; +use std::io::{ErrorKind, Write}; +#[cfg(unix)] +use std::os::unix::net::UnixStream as StdUnixStream; +use std::path::{Path, PathBuf}; +use std::process::{Command, Stdio}; +use std::time::{Duration, Instant}; -use atuin_client::{database::Sqlite, record::sqlite_store::SqliteStore, settings::Settings}; -use atuin_daemon::server::listen; +use atuin_client::{ + database::Sqlite, history::History, record::sqlite_store::SqliteStore, settings::Settings, +}; +use atuin_daemon::{ + client::{DaemonClientErrorKind, HistoryClient, classify_error}, + server::listen, +}; +use clap::Subcommand; +#[cfg(unix)] +use daemonize::Daemonize; +use eyre::{Result, WrapErr, bail, eyre}; +use fs4::fs_std::FileExt; +use tokio::time::sleep; + +#[derive(clap::Args, Debug)] +pub struct Cmd { + /// Internal flag for daemonization + #[arg(long, hide = true)] + daemonize: bool, + + #[command(subcommand)] + subcmd: Option<SubCmd>, +} + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum SubCmd { + /// Start the daemon server + Start { + #[arg(long, hide = true)] + daemonize: bool, + }, + + /// Show the daemon's current status + Status, + + /// Stop the daemon gracefully + Stop, + + /// Restart the daemon (stop, then start in background) + Restart, +} + +impl Cmd { + /// Returns `true` when the process should daemonize before creating the + /// async runtime or opening any database connections. + #[cfg(unix)] + pub fn should_daemonize(&self) -> bool { + match &self.subcmd { + Some(SubCmd::Start { daemonize }) => *daemonize, + None => self.daemonize, + _ => false, + } + } + + pub async fn run( + self, + settings: Settings, + store: SqliteStore, + history_db: Sqlite, + ) -> Result<()> { + match self.subcmd { + None => { + eprintln!("Warning: `atuin daemon` is deprecated, use `atuin daemon start`"); + run(settings, store, history_db).await + } + Some(SubCmd::Start { .. }) => run(settings, store, history_db).await, + Some(SubCmd::Status) => status_cmd(&settings).await, + Some(SubCmd::Stop) => stop_cmd(&settings).await, + Some(SubCmd::Restart) => restart_cmd(&settings).await, + } + } +} + +const DAEMON_VERSION: &str = env!("CARGO_PKG_VERSION"); +const DAEMON_PROTOCOL_VERSION: u32 = 1; +const STARTUP_POLL: Duration = Duration::from_millis(40); +const LOCK_POLL: Duration = Duration::from_millis(20); +const LEGACY_DAEMON_RESTART_MESSAGE: &str = "legacy daemon detected; restart daemon manually"; + +struct PidfileGuard { + file: File, +} + +impl PidfileGuard { + fn acquire(path: &Path) -> Result<Self> { + let mut file = open_lock_file(path)?; + + if !file.try_lock_exclusive()? { + bail!( + "daemon already running (pidfile lock busy at {})", + path.display() + ); + } + + file.set_len(0) + .wrap_err_with(|| format!("could not truncate daemon pidfile {}", path.display()))?; + writeln!(file, "{}", std::process::id()) + .and_then(|()| writeln!(file, "{DAEMON_VERSION}")) + .wrap_err_with(|| format!("could not write daemon pidfile {}", path.display()))?; + + Ok(Self { file }) + } +} + +impl Drop for PidfileGuard { + fn drop(&mut self) { + let _ = self.file.unlock(); + } +} + +enum Probe { + Ready(HistoryClient), + NeedsRestart(String), + Unreachable(eyre::Report), +} + +fn daemon_matches_expected(version: &str, protocol: u32) -> bool { + version == DAEMON_VERSION && protocol == DAEMON_PROTOCOL_VERSION +} + +fn daemon_mismatch_message(version: &str, protocol: u32) -> String { + if protocol == DAEMON_PROTOCOL_VERSION { + format!("daemon is out of date: expected {DAEMON_VERSION}, got {version}") + } else { + format!("daemon protocol mismatch: expected {DAEMON_PROTOCOL_VERSION}, got {protocol}") + } +} + +fn is_legacy_daemon_error(err: &eyre::Report) -> bool { + matches!(classify_error(err), DaemonClientErrorKind::Unimplemented) +} + +fn should_retry_after_error(err: &eyre::Report) -> bool { + matches!( + classify_error(err), + DaemonClientErrorKind::Connect + | DaemonClientErrorKind::Unavailable + | DaemonClientErrorKind::Unimplemented + ) +} + +fn daemon_startup_lock_path(pidfile_path: &Path) -> PathBuf { + let mut os = pidfile_path.as_os_str().to_os_string(); + os.push(".startup.lock"); + PathBuf::from(os) +} + +fn open_lock_file(path: &Path) -> Result<File> { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent) + .wrap_err_with(|| format!("could not create lock directory {}", parent.display()))?; + } + + OpenOptions::new() + .read(true) + .write(true) + .create(true) + .truncate(false) + .open(path) + .wrap_err_with(|| format!("could not open lock file {}", path.display())) +} + +async fn wait_for_lock(path: &Path, timeout: Duration) -> Result<File> { + let file = open_lock_file(path)?; + let start = Instant::now(); + + loop { + match file.try_lock_exclusive() { + Ok(true) => return Ok(file), + Ok(false) => { + if start.elapsed() >= timeout { + bail!("timed out waiting for lock at {}", path.display()); + } + + sleep(LOCK_POLL).await; + } + Err(err) => { + return Err(eyre!("could not lock {}: {err}", path.display())); + } + } + } +} + +async fn wait_for_pidfile_available(path: &Path, timeout: Duration) -> Result<()> { + let file = wait_for_lock(path, timeout).await?; + file.unlock() + .wrap_err_with(|| format!("failed to unlock {}", path.display()))?; + Ok(()) +} + +async fn connect_client(settings: &Settings) -> Result<HistoryClient> { + HistoryClient::new( + #[cfg(not(unix))] + settings.daemon.tcp_port, + #[cfg(unix)] + settings.daemon.socket_path.clone(), + ) + .await +} + +async fn probe(settings: &Settings) -> Probe { + let mut client = match connect_client(settings).await { + Ok(client) => client, + Err(err) => return Probe::Unreachable(err), + }; + + match client.status().await { + Ok(status) => { + if daemon_matches_expected(&status.version, status.protocol) { + Probe::Ready(client) + } else { + Probe::NeedsRestart(daemon_mismatch_message(&status.version, status.protocol)) + } + } + Err(err) => Probe::Unreachable(err), + } +} + +async fn request_shutdown(settings: &Settings) { + if let Ok(mut client) = connect_client(settings).await { + let _ = client.shutdown().await; + } +} + +fn spawn_daemon_process() -> Result<()> { + let exe = std::env::current_exe().wrap_err("could not locate atuin executable")?; + + let mut cmd = Command::new(exe); + cmd.arg("daemon") + .arg("start") + .stdin(Stdio::null()) + .stdout(Stdio::null()) + .stderr(Stdio::null()); + + #[cfg(unix)] + cmd.arg("--daemonize"); + + cmd.spawn().wrap_err("failed to spawn daemon process")?; + + Ok(()) +} + +fn startup_timeout(settings: &Settings) -> Duration { + Duration::from_secs_f64(settings.local_timeout.max(0.5) + 2.0) +} + +#[cfg(unix)] +fn remove_stale_socket_if_present(settings: &Settings) -> Result<()> { + if settings.daemon.systemd_socket { + return Ok(()); + } + + let socket_path = Path::new(&settings.daemon.socket_path); + if !socket_path.exists() { + return Ok(()); + } + + match StdUnixStream::connect(socket_path) { + Ok(stream) => { + drop(stream); + Ok(()) + } + Err(err) if err.kind() == ErrorKind::ConnectionRefused => { + fs::remove_file(socket_path).wrap_err_with(|| { + format!( + "failed to remove stale daemon socket {}", + socket_path.display() + ) + })?; + Ok(()) + } + Err(err) if err.kind() == ErrorKind::NotFound => Ok(()), + Err(_) => Ok(()), + } +} + +async fn wait_until_ready(settings: &Settings, timeout: Duration) -> Result<HistoryClient> { + let start = Instant::now(); + let mut last_error = eyre!("daemon did not become ready"); + + loop { + match probe(settings).await { + Probe::Ready(client) => return Ok(client), + Probe::NeedsRestart(reason) => { + last_error = eyre!(reason); + } + Probe::Unreachable(err) => { + if is_legacy_daemon_error(&err) { + return Err(err.wrap_err(LEGACY_DAEMON_RESTART_MESSAGE)); + } + last_error = err; + } + } + + if start.elapsed() >= timeout { + return Err(last_error.wrap_err(format!( + "timed out waiting for daemon startup after {}ms", + timeout.as_millis() + ))); + } + + sleep(STARTUP_POLL).await; + } +} + +fn ensure_autostart_supported(settings: &Settings) -> Result<()> { + #[cfg(unix)] + if settings.daemon.systemd_socket { + bail!( + "daemon autostart is incompatible with `daemon.systemd_socket = true`; use systemd to manage the daemon" + ); + } + #[cfg(not(unix))] + let _ = settings; + + Ok(()) +} + +async fn restart_daemon(settings: &Settings) -> Result<HistoryClient> { + ensure_autostart_supported(settings)?; + + let timeout = startup_timeout(settings); + let pidfile_path = PathBuf::from(&settings.daemon.pidfile_path); + let startup_lock_path = daemon_startup_lock_path(&pidfile_path); + let startup_lock = wait_for_lock(&startup_lock_path, timeout).await?; + + match probe(settings).await { + Probe::Ready(client) => { + drop(startup_lock); + return Ok(client); + } + Probe::NeedsRestart(_) => { + request_shutdown(settings).await; + } + Probe::Unreachable(err) => { + if is_legacy_daemon_error(&err) { + return Err(err.wrap_err(LEGACY_DAEMON_RESTART_MESSAGE)); + } + } + } + + // This prevents rapid-fire hook invocations from racing daemon restart. + wait_for_pidfile_available(&pidfile_path, timeout).await?; + + #[cfg(unix)] + remove_stale_socket_if_present(settings)?; + + spawn_daemon_process()?; + let client = wait_until_ready(settings, timeout).await?; + + drop(startup_lock); + Ok(client) +} + +fn ensure_reply_compatible(settings: &Settings, version: &str, protocol: u32) -> Result<()> { + if daemon_matches_expected(version, protocol) { + return Ok(()); + } + + let message = daemon_mismatch_message(version, protocol); + if settings.daemon.autostart { + bail!("{message}"); + } + + bail!("{message}. Enable `daemon.autostart = true` or restart the daemon manually"); +} + +pub async fn start_history(settings: &Settings, history: History) -> Result<String> { + match async { + connect_client(settings) + .await? + .start_history(history.clone()) + .await + } + .await + { + Ok(resp) => { + if daemon_matches_expected(&resp.version, resp.protocol) { + return Ok(resp.id); + } + + if !settings.daemon.autostart { + return Err(eyre!( + "{}. Enable `daemon.autostart = true` or restart the daemon manually", + daemon_mismatch_message(&resp.version, resp.protocol) + )); + } + } + Err(err) if !settings.daemon.autostart => return Err(err), + Err(err) if !should_retry_after_error(&err) => return Err(err), + Err(_) => {} + } + + let resp = restart_daemon(settings) + .await? + .start_history(history) + .await?; + ensure_reply_compatible(settings, &resp.version, resp.protocol)?; + Ok(resp.id) +} + +pub async fn end_history(settings: &Settings, id: String, duration: u64, exit: i64) -> Result<()> { + match async { + connect_client(settings) + .await? + .end_history(id.clone(), duration, exit) + .await + } + .await + { + Ok(resp) => { + if daemon_matches_expected(&resp.version, resp.protocol) { + return Ok(()); + } + + if !settings.daemon.autostart { + return Err(eyre!( + "{}. Enable `daemon.autostart = true` or restart the daemon manually", + daemon_mismatch_message(&resp.version, resp.protocol) + )); + } + + // End succeeded on the running daemon, so avoid replaying it. + // We only restart to make subsequent hook calls target the expected version. + let _ = restart_daemon(settings).await; + return Ok(()); + } + Err(err) if !settings.daemon.autostart => return Err(err), + Err(err) if !should_retry_after_error(&err) => return Err(err), + Err(_) => {} + } + + let resp = restart_daemon(settings) + .await? + .end_history(id, duration, exit) + .await?; + ensure_reply_compatible(settings, &resp.version, resp.protocol)?; + Ok(()) +} + +async fn status_cmd(settings: &Settings) -> Result<()> { + match probe(settings).await { + Probe::Ready(mut client) => { + let status = client.status().await?; + println!("Daemon running"); + println!(" PID: {}", status.pid); + println!(" Version: {}", status.version); + println!(" Protocol: {}", status.protocol); + println!(" Healthy: {}", status.healthy); + #[cfg(unix)] + println!(" Socket: {}", settings.daemon.socket_path); + #[cfg(not(unix))] + println!(" Port: {}", settings.daemon.tcp_port); + } + Probe::NeedsRestart(reason) => { + println!("Daemon running (needs restart)"); + println!(" Reason: {reason}"); + } + Probe::Unreachable(_) => { + println!("Daemon is not running"); + } + } + + Ok(()) +} + +async fn stop_cmd(settings: &Settings) -> Result<()> { + let Ok(mut client) = connect_client(settings).await else { + println!("Daemon is not running"); + return Ok(()); + }; + + match client.shutdown().await { + Ok(true) => { + println!("Shutdown requested"); + + let pidfile_path = PathBuf::from(&settings.daemon.pidfile_path); + let timeout = Duration::from_secs(5); + match wait_for_pidfile_available(&pidfile_path, timeout).await { + Ok(()) => println!("Daemon stopped"), + Err(_) => println!("Daemon may still be shutting down"), + } + + Ok(()) + } + Ok(false) => bail!("Daemon rejected shutdown request"), + Err(err) => Err(err.wrap_err("Failed to send shutdown request")), + } +} + +async fn restart_cmd(settings: &Settings) -> Result<()> { + // Stop if running + match probe(settings).await { + Probe::Ready(_) | Probe::NeedsRestart(_) => { + request_shutdown(settings).await; + println!("Stopping daemon..."); + + let pidfile_path = PathBuf::from(&settings.daemon.pidfile_path); + let timeout = Duration::from_secs(5); + wait_for_pidfile_available(&pidfile_path, timeout) + .await + .wrap_err("Timed out waiting for old daemon to stop")?; + } + Probe::Unreachable(_) => { + println!("No daemon running"); + } + } + + #[cfg(unix)] + remove_stale_socket_if_present(settings)?; + + spawn_daemon_process()?; + println!("Starting daemon..."); + + let timeout = startup_timeout(settings); + let status = wait_until_ready(settings, timeout).await?.status().await?; + + println!("Daemon restarted"); + println!(" PID: {}", status.pid); + println!(" Version: {}", status.version); + + Ok(()) +} + +/// Daemonize the current process. Must be called before creating the tokio +/// runtime or opening database connections, since `fork()` inside an async +/// runtime corrupts its internal state. +#[cfg(unix)] +pub fn daemonize_current_process() -> Result<()> { + let cwd = + std::env::current_dir().wrap_err("could not determine current directory for daemon")?; + + Daemonize::new() + .working_directory(cwd) + .start() + .wrap_err("failed to daemonize process")?; + + Ok(()) +} + +async fn run(settings: Settings, store: SqliteStore, history_db: Sqlite) -> Result<()> { + let pidfile_path = PathBuf::from(&settings.daemon.pidfile_path); + let _pidfile_guard = PidfileGuard::acquire(&pidfile_path)?; -pub async fn run(settings: Settings, store: SqliteStore, history_db: Sqlite) -> Result<()> { listen(settings, store, history_db).await?; Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_version_matches() { + assert!(daemon_matches_expected( + DAEMON_VERSION, + DAEMON_PROTOCOL_VERSION + )); + } + + #[test] + fn test_version_mismatch() { + assert!(!daemon_matches_expected("0.0.0", DAEMON_PROTOCOL_VERSION)); + assert!(!daemon_matches_expected(DAEMON_VERSION, 999)); + assert!(!daemon_matches_expected("0.0.0", 999)); + } + + #[test] + fn test_mismatch_message_version() { + let msg = daemon_mismatch_message("0.0.0", DAEMON_PROTOCOL_VERSION); + assert!(msg.contains("out of date"), "got: {msg}"); + assert!(msg.contains("0.0.0")); + assert!(msg.contains(DAEMON_VERSION)); + } + + #[test] + fn test_mismatch_message_protocol() { + let msg = daemon_mismatch_message(DAEMON_VERSION, 999); + assert!(msg.contains("protocol mismatch"), "got: {msg}"); + } + + #[test] + fn test_startup_lock_path() { + let pidfile = Path::new("/tmp/atuin-daemon.pid"); + let lock = daemon_startup_lock_path(pidfile); + assert_eq!(lock, PathBuf::from("/tmp/atuin-daemon.pid.startup.lock")); + } + + #[test] + fn test_pidfile_guard_acquire_and_drop() { + let tmp = tempfile::tempdir().unwrap(); + let pidfile = tmp.path().join("daemon.pid"); + + { + let _guard = PidfileGuard::acquire(&pidfile).unwrap(); + // Guard holds an exclusive lock — on Windows other handles cannot + // read the file, so we verify contents after the guard is dropped. + } + + let contents = std::fs::read_to_string(&pidfile).unwrap(); + let lines: Vec<&str> = contents.lines().collect(); + assert_eq!(lines.len(), 2); + assert_eq!(lines[0], std::process::id().to_string()); + assert_eq!(lines[1], DAEMON_VERSION); + + // After guard is dropped, lock should be released — acquiring again must succeed. + let _guard2 = PidfileGuard::acquire(&pidfile).unwrap(); + } + + #[test] + fn test_pidfile_guard_prevents_double_acquire() { + let tmp = tempfile::tempdir().unwrap(); + let pidfile = tmp.path().join("daemon.pid"); + + let _guard = PidfileGuard::acquire(&pidfile).unwrap(); + let result = PidfileGuard::acquire(&pidfile); + assert!(result.is_err()); + } +} diff --git a/crates/atuin/src/command/client/history.rs b/crates/atuin/src/command/client/history.rs index e8e544b5..ea22a2fd 100644 --- a/crates/atuin/src/command/client/history.rs +++ b/crates/atuin/src/command/client/history.rs @@ -27,6 +27,8 @@ use atuin_client::{record, sync}; use log::{debug, warn}; use time::{OffsetDateTime, macros::format_description}; +#[cfg(feature = "daemon")] +use super::daemon; use super::search::format_duration_into; #[derive(Subcommand, Debug)] @@ -392,15 +394,7 @@ impl Cmd { return Ok(()); } - let resp = atuin_daemon::client::HistoryClient::new( - #[cfg(not(unix))] - settings.daemon.tcp_port, - #[cfg(unix)] - settings.daemon.socket_path.clone(), - ) - .await? - .start_history(h) - .await?; + let resp = daemon::start_history(settings, h).await?; // print the ID // we use this as the key for calling end @@ -477,22 +471,13 @@ impl Cmd { } #[cfg(feature = "daemon")] - #[allow(unused_variables)] async fn handle_daemon_end( settings: &Settings, id: &str, exit: i64, duration: Option<u64>, ) -> Result<()> { - let resp = atuin_daemon::client::HistoryClient::new( - #[cfg(not(unix))] - settings.daemon.tcp_port, - #[cfg(unix)] - settings.daemon.socket_path.clone(), - ) - .await? - .end_history(id.to_string(), duration.unwrap_or(0), exit) - .await?; + daemon::end_history(settings, id.to_string(), duration.unwrap_or(0), exit).await?; Ok(()) } |
