diff options
| author | Ellie Huxtable <ellie@atuin.sh> | 2026-02-13 11:37:58 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-02-13 11:37:58 -0800 |
| commit | d52c4d6003adced1f6763261a7f9132719be5533 (patch) | |
| tree | 94d2f7c815814135edbc2ee0af56d2cedd1a1724 /crates/atuin-daemon | |
| parent | feat: add Atuin AI inline CLI MVP (#3178) (diff) | |
| download | atuin-d52c4d6003adced1f6763261a7f9132719be5533.zip | |
feat: add autostart and pid management to daemon (#3180)
Diffstat (limited to 'crates/atuin-daemon')
| -rw-r--r-- | crates/atuin-daemon/Cargo.toml | 4 | ||||
| -rw-r--r-- | crates/atuin-daemon/proto/history.proto | 21 | ||||
| -rw-r--r-- | crates/atuin-daemon/src/client.rs | 51 | ||||
| -rw-r--r-- | crates/atuin-daemon/src/server.rs | 89 | ||||
| -rw-r--r-- | crates/atuin-daemon/tests/lifecycle.rs | 127 |
5 files changed, 268 insertions, 24 deletions
diff --git a/crates/atuin-daemon/Cargo.toml b/crates/atuin-daemon/Cargo.toml index 9adbe5e8..8d3b4ab6 100644 --- a/crates/atuin-daemon/Cargo.toml +++ b/crates/atuin-daemon/Cargo.toml @@ -39,6 +39,10 @@ rand.workspace = true [target.'cfg(target_os = "linux")'.dependencies] listenfd = "1.0.1" +[dev-dependencies] +tempfile = { workspace = true } +atuin-common = { path = "../atuin-common", version = "18.12.1" } + [build-dependencies] protox = "0.8.0" tonic-build = "0.12" diff --git a/crates/atuin-daemon/proto/history.proto b/crates/atuin-daemon/proto/history.proto index 1172b91b..9fbd3372 100644 --- a/crates/atuin-daemon/proto/history.proto +++ b/crates/atuin-daemon/proto/history.proto @@ -18,14 +18,35 @@ message EndHistoryRequest { message StartHistoryReply { string id = 1; + string version = 2; + uint32 protocol = 3; } message EndHistoryReply { string id = 1; uint64 idx = 2; + string version = 3; + uint32 protocol = 4; +} + +message StatusRequest {} + +message StatusReply { + bool healthy = 1; + string version = 2; + uint32 pid = 3; + uint32 protocol = 4; +} + +message ShutdownRequest {} + +message ShutdownReply { + bool accepted = 1; } service History { rpc StartHistory(StartHistoryRequest) returns (StartHistoryReply); rpc EndHistory(EndHistoryRequest) returns (EndHistoryReply); + rpc Status(StatusRequest) returns (StatusReply); + rpc Shutdown(ShutdownRequest) returns (ShutdownReply); } diff --git a/crates/atuin-daemon/src/client.rs b/crates/atuin-daemon/src/client.rs index a4b4690e..05067bda 100644 --- a/crates/atuin-daemon/src/client.rs +++ b/crates/atuin-daemon/src/client.rs @@ -1,6 +1,7 @@ use eyre::{Context, Result}; #[cfg(windows)] use tokio::net::TcpStream; +use tonic::Code; use tonic::transport::{Channel, Endpoint, Uri}; use tower::service_fn; @@ -12,13 +13,41 @@ use tokio::net::UnixStream; use atuin_client::history::History; use crate::history::{ - EndHistoryRequest, StartHistoryRequest, history_client::HistoryClient as HistoryServiceClient, + EndHistoryReply, EndHistoryRequest, ShutdownRequest, StartHistoryReply, StartHistoryRequest, + StatusReply, StatusRequest, history_client::HistoryClient as HistoryServiceClient, }; pub struct HistoryClient { client: HistoryServiceClient<Channel>, } +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum DaemonClientErrorKind { + Connect, + Unavailable, + Unimplemented, + Other, +} + +#[must_use] +pub fn classify_error(error: &eyre::Report) -> DaemonClientErrorKind { + for cause in error.chain() { + if cause.downcast_ref::<tonic::transport::Error>().is_some() { + return DaemonClientErrorKind::Connect; + } + + if let Some(status) = cause.downcast_ref::<tonic::Status>() { + return match status.code() { + Code::Unavailable => DaemonClientErrorKind::Unavailable, + Code::Unimplemented => DaemonClientErrorKind::Unimplemented, + _ => DaemonClientErrorKind::Other, + }; + } + } + + DaemonClientErrorKind::Other +} + // Wrap the grpc client impl HistoryClient { #[cfg(unix)] @@ -67,7 +96,7 @@ impl HistoryClient { Ok(HistoryClient { client }) } - pub async fn start_history(&mut self, h: History) -> Result<String> { + pub async fn start_history(&mut self, h: History) -> Result<StartHistoryReply> { let req = StartHistoryRequest { command: h.command, cwd: h.cwd, @@ -76,9 +105,7 @@ impl HistoryClient { timestamp: h.timestamp.unix_timestamp_nanos() as u64, }; - let resp = self.client.start_history(req).await?; - - Ok(resp.into_inner().id) + Ok(self.client.start_history(req).await?.into_inner()) } pub async fn end_history( @@ -86,12 +113,18 @@ impl HistoryClient { id: String, duration: u64, exit: i64, - ) -> Result<(String, u64)> { + ) -> Result<EndHistoryReply> { let req = EndHistoryRequest { id, duration, exit }; - let resp = self.client.end_history(req).await?; - let resp = resp.into_inner(); + Ok(self.client.end_history(req).await?.into_inner()) + } + + pub async fn status(&mut self) -> Result<StatusReply> { + Ok(self.client.status(StatusRequest {}).await?.into_inner()) + } - Ok((resp.id, resp.idx)) + pub async fn shutdown(&mut self) -> Result<bool> { + let resp = self.client.shutdown(ShutdownRequest {}).await?.into_inner(); + Ok(resp.accepted) } } diff --git a/crates/atuin-daemon/src/server.rs b/crates/atuin-daemon/src/server.rs index 2cba1753..9622d2b6 100644 --- a/crates/atuin-daemon/src/server.rs +++ b/crates/atuin-daemon/src/server.rs @@ -4,10 +4,12 @@ use atuin_client::encryption; use atuin_client::history::store::HistoryStore; use atuin_client::record::sqlite_store::SqliteStore; use atuin_client::settings::Settings; +use std::io::ErrorKind; #[cfg(unix)] use std::path::PathBuf; use std::sync::Arc; use time::OffsetDateTime; +use tokio::sync::watch; use tracing::{Level, instrument}; use atuin_client::database::{Database, Sqlite as HistoryDatabase}; @@ -19,9 +21,12 @@ use tonic::{Request, Response, Status, transport::Server}; use crate::history::history_server::{History as HistorySvc, HistoryServer}; use crate::history::{EndHistoryReply, EndHistoryRequest, StartHistoryReply, StartHistoryRequest}; +use crate::history::{ShutdownReply, ShutdownRequest, StatusReply, StatusRequest}; mod sync; +const DAEMON_PROTOCOL_VERSION: u32 = 1; + #[derive(Debug)] pub struct HistoryService { // A store for WIP history @@ -29,14 +34,20 @@ pub struct HistoryService { running: Arc<DashMap<HistoryId, History>>, store: HistoryStore, history_db: HistoryDatabase, + shutdown_tx: watch::Sender<bool>, } impl HistoryService { - pub fn new(store: HistoryStore, history_db: HistoryDatabase) -> Self { + pub fn new( + store: HistoryStore, + history_db: HistoryDatabase, + shutdown_tx: watch::Sender<bool>, + ) -> Self { Self { running: Arc::new(DashMap::new()), store, history_db, + shutdown_tx, } } } @@ -77,7 +88,11 @@ impl HistorySvc for HistoryService { tracing::info!(id = id.to_string(), "start history"); running.insert(id.clone(), h); - let reply = StartHistoryReply { id: id.to_string() }; + let reply = StartHistoryReply { + id: id.to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + protocol: DAEMON_PROTOCOL_VERSION, + }; Ok(Response::new(reply)) } @@ -122,6 +137,8 @@ impl HistorySvc for HistoryService { let reply = EndHistoryReply { id: id.0.to_string(), idx, + version: env!("CARGO_PKG_VERSION").to_string(), + protocol: DAEMON_PROTOCOL_VERSION, }; return Ok(Response::new(reply)); @@ -131,10 +148,35 @@ impl HistorySvc for HistoryService { "could not find history with id: {id}" ))) } + + #[instrument(skip_all, level = Level::INFO)] + async fn status( + &self, + _request: Request<StatusRequest>, + ) -> Result<Response<StatusReply>, Status> { + let reply = StatusReply { + // If status RPC responds, the daemon control plane is healthy. + healthy: true, + version: env!("CARGO_PKG_VERSION").to_string(), + pid: std::process::id(), + protocol: DAEMON_PROTOCOL_VERSION, + }; + + Ok(Response::new(reply)) + } + + #[instrument(skip_all, level = Level::INFO)] + async fn shutdown( + &self, + _request: Request<ShutdownRequest>, + ) -> Result<Response<ShutdownReply>, Status> { + let _ = self.shutdown_tx.send(true); + Ok(Response::new(ShutdownReply { accepted: true })) + } } #[cfg(unix)] -async fn shutdown_signal(socket: Option<PathBuf>) { +async fn shutdown_signal(socket: Option<PathBuf>, mut shutdown_rx: watch::Receiver<bool>) { let mut term = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) .expect("failed to register sigterm handler"); let mut int = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt()) @@ -143,26 +185,38 @@ async fn shutdown_signal(socket: Option<PathBuf>) { tokio::select! { _ = term.recv() => {}, _ = int.recv() => {}, + _ = shutdown_rx.changed() => {}, } eprintln!("Removing socket..."); if let Some(socket) = socket { - std::fs::remove_file(socket).expect("failed to remove socket"); + match std::fs::remove_file(socket) { + Ok(()) => {} + Err(err) if err.kind() == ErrorKind::NotFound => {} + Err(err) => { + eprintln!("failed to remove socket: {err}"); + } + } } eprintln!("Shutting down..."); } #[cfg(windows)] -async fn shutdown_signal() { - tokio::signal::windows::ctrl_c() - .expect("failed to register signal handler") - .recv() - .await; +async fn shutdown_signal(mut shutdown_rx: watch::Receiver<bool>) { + let mut ctrl_c = tokio::signal::windows::ctrl_c().expect("failed to register signal handler"); + tokio::select! { + _ = ctrl_c.recv() => {}, + _ = shutdown_rx.changed() => {}, + } eprintln!("Shutting down..."); } #[cfg(unix)] -async fn start_server(settings: Settings, history: HistoryService) -> Result<()> { +async fn start_server( + settings: Settings, + history: HistoryService, + shutdown_rx: watch::Receiver<bool>, +) -> Result<()> { use tokio::net::UnixListener; use tokio_stream::wrappers::UnixListenerStream; @@ -215,7 +269,7 @@ async fn start_server(settings: Settings, history: HistoryService) -> Result<()> .add_service(HistoryServer::new(history)) .serve_with_incoming_shutdown( uds_stream, - shutdown_signal(cleanup.then_some(socket_path.into())), + shutdown_signal(cleanup.then_some(socket_path.into()), shutdown_rx), ) .await?; @@ -223,7 +277,11 @@ async fn start_server(settings: Settings, history: HistoryService) -> Result<()> } #[cfg(not(unix))] -async fn start_server(settings: Settings, history: HistoryService) -> Result<()> { +async fn start_server( + settings: Settings, + history: HistoryService, + shutdown_rx: watch::Receiver<bool>, +) -> Result<()> { use tokio::net::TcpListener; use tokio_stream::wrappers::TcpListenerStream; @@ -236,7 +294,7 @@ async fn start_server(settings: Settings, history: HistoryService) -> Result<()> Server::builder() .add_service(HistoryServer::new(history)) - .serve_with_incoming_shutdown(tcp_stream, shutdown_signal()) + .serve_with_incoming_shutdown(tcp_stream, shutdown_signal(shutdown_rx)) .await?; Ok(()) } @@ -257,7 +315,8 @@ pub async fn listen( let host_id = Settings::host_id().await?; let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); - let history = HistoryService::new(history_store.clone(), history_db.clone()); + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let history = HistoryService::new(history_store.clone(), history_db.clone(), shutdown_tx); // start services tokio::spawn(sync::worker( @@ -267,5 +326,5 @@ pub async fn listen( history_db, )); - start_server(settings, history).await + start_server(settings, history, shutdown_rx).await } diff --git a/crates/atuin-daemon/tests/lifecycle.rs b/crates/atuin-daemon/tests/lifecycle.rs new file mode 100644 index 00000000..56457fa7 --- /dev/null +++ b/crates/atuin-daemon/tests/lifecycle.rs @@ -0,0 +1,127 @@ +//! Integration tests for the daemon server lifecycle. +//! +//! Each test spins up a real gRPC server on a temporary unix socket, +//! connects a client, and exercises the daemon RPCs. + +#[cfg(unix)] +mod unix { + use std::time::Duration; + + use atuin_client::database::Sqlite; + use atuin_client::history::store::HistoryStore; + use atuin_client::record::sqlite_store::SqliteStore; + use atuin_common::record::HostId; + use atuin_common::utils::uuid_v7; + use atuin_daemon::client::HistoryClient; + use atuin_daemon::history::history_server::HistoryServer; + use atuin_daemon::server::HistoryService; + use tempfile::TempDir; + use tokio::net::UnixListener; + use tokio::sync::watch; + use tokio_stream::wrappers::UnixListenerStream; + use tonic::transport::Server; + + /// Spins up a daemon server on a temp socket and returns a connected client, + /// the shutdown sender, and the temp dir (must be held to keep paths alive). + async fn start_test_daemon() -> (HistoryClient, watch::Sender<bool>, TempDir) { + let tmp = tempfile::tempdir().unwrap(); + + let db_path = tmp.path().join("history.db"); + let record_path = tmp.path().join("records.db"); + + let history_db = Sqlite::new(&db_path, 5.0).await.unwrap(); + let store = SqliteStore::new(&record_path, 5.0).await.unwrap(); + + let host_id = HostId(uuid_v7()); + let encryption_key = [0u8; 32]; + let history_store = HistoryStore::new(store, host_id, encryption_key); + + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let service = HistoryService::new(history_store, history_db, shutdown_tx.clone()); + + let socket_path = tmp.path().join("test.sock"); + let uds = UnixListener::bind(&socket_path).unwrap(); + let stream = UnixListenerStream::new(uds); + + let mut rx = shutdown_rx.clone(); + tokio::spawn(async move { + Server::builder() + .add_service(HistoryServer::new(service)) + .serve_with_incoming_shutdown(stream, async move { + let _ = rx.changed().await; + }) + .await + .unwrap(); + }); + + // Give the server a moment to bind. + tokio::time::sleep(Duration::from_millis(50)).await; + + let client = HistoryClient::new(socket_path.to_string_lossy().to_string()) + .await + .unwrap(); + + (client, shutdown_tx, tmp) + } + + #[tokio::test] + async fn test_status() { + let (mut client, _shutdown, _tmp) = start_test_daemon().await; + + let status = client.status().await.unwrap(); + assert!(status.healthy); + assert_eq!(status.version, env!("CARGO_PKG_VERSION")); + assert_eq!(status.protocol, 1); + assert!(status.pid > 0); + } + + #[tokio::test] + async fn test_start_end_history() { + use atuin_client::history::History; + + let (mut client, _shutdown, _tmp) = start_test_daemon().await; + + let history = History::daemon() + .timestamp(time::OffsetDateTime::now_utc()) + .command("echo hello".to_string()) + .cwd("/tmp".to_string()) + .session("test-session".to_string()) + .hostname("test-host".to_string()) + .build() + .into(); + + let start_reply = client.start_history(history).await.unwrap(); + assert!(!start_reply.id.is_empty()); + + let end_reply = client + .end_history(start_reply.id, 1_000_000, 0) + .await + .unwrap(); + assert!(!end_reply.id.is_empty()); + } + + #[tokio::test] + async fn test_end_unknown_history_fails() { + let (mut client, _shutdown, _tmp) = start_test_daemon().await; + + let result = client + .end_history("nonexistent-id".to_string(), 1000, 0) + .await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_shutdown() { + let (mut client, _shutdown_tx, _tmp) = start_test_daemon().await; + + let accepted = client.shutdown().await.unwrap(); + assert!(accepted); + + // Give server time to shut down. + tokio::time::sleep(Duration::from_millis(100)).await; + + // Subsequent calls should fail since the server is gone. + let result = client.status().await; + assert!(result.is_err()); + } +} |
