diff options
| author | Ellie Huxtable <ellie@atuin.sh> | 2026-02-13 11:37:58 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-02-13 11:37:58 -0800 |
| commit | d52c4d6003adced1f6763261a7f9132719be5533 (patch) | |
| tree | 94d2f7c815814135edbc2ee0af56d2cedd1a1724 /crates/atuin-daemon/src | |
| parent | feat: add Atuin AI inline CLI MVP (#3178) (diff) | |
| download | atuin-d52c4d6003adced1f6763261a7f9132719be5533.zip | |
feat: add autostart and pid management to daemon (#3180)
Diffstat (limited to 'crates/atuin-daemon/src')
| -rw-r--r-- | crates/atuin-daemon/src/client.rs | 51 | ||||
| -rw-r--r-- | crates/atuin-daemon/src/server.rs | 89 |
2 files changed, 116 insertions, 24 deletions
diff --git a/crates/atuin-daemon/src/client.rs b/crates/atuin-daemon/src/client.rs index a4b4690e..05067bda 100644 --- a/crates/atuin-daemon/src/client.rs +++ b/crates/atuin-daemon/src/client.rs @@ -1,6 +1,7 @@ use eyre::{Context, Result}; #[cfg(windows)] use tokio::net::TcpStream; +use tonic::Code; use tonic::transport::{Channel, Endpoint, Uri}; use tower::service_fn; @@ -12,13 +13,41 @@ use tokio::net::UnixStream; use atuin_client::history::History; use crate::history::{ - EndHistoryRequest, StartHistoryRequest, history_client::HistoryClient as HistoryServiceClient, + EndHistoryReply, EndHistoryRequest, ShutdownRequest, StartHistoryReply, StartHistoryRequest, + StatusReply, StatusRequest, history_client::HistoryClient as HistoryServiceClient, }; pub struct HistoryClient { client: HistoryServiceClient<Channel>, } +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum DaemonClientErrorKind { + Connect, + Unavailable, + Unimplemented, + Other, +} + +#[must_use] +pub fn classify_error(error: &eyre::Report) -> DaemonClientErrorKind { + for cause in error.chain() { + if cause.downcast_ref::<tonic::transport::Error>().is_some() { + return DaemonClientErrorKind::Connect; + } + + if let Some(status) = cause.downcast_ref::<tonic::Status>() { + return match status.code() { + Code::Unavailable => DaemonClientErrorKind::Unavailable, + Code::Unimplemented => DaemonClientErrorKind::Unimplemented, + _ => DaemonClientErrorKind::Other, + }; + } + } + + DaemonClientErrorKind::Other +} + // Wrap the grpc client impl HistoryClient { #[cfg(unix)] @@ -67,7 +96,7 @@ impl HistoryClient { Ok(HistoryClient { client }) } - pub async fn start_history(&mut self, h: History) -> Result<String> { + pub async fn start_history(&mut self, h: History) -> Result<StartHistoryReply> { let req = StartHistoryRequest { command: h.command, cwd: h.cwd, @@ -76,9 +105,7 @@ impl HistoryClient { timestamp: h.timestamp.unix_timestamp_nanos() as u64, }; - let resp = self.client.start_history(req).await?; - - Ok(resp.into_inner().id) + Ok(self.client.start_history(req).await?.into_inner()) } pub async fn end_history( @@ -86,12 +113,18 @@ impl HistoryClient { id: String, duration: u64, exit: i64, - ) -> Result<(String, u64)> { + ) -> Result<EndHistoryReply> { let req = EndHistoryRequest { id, duration, exit }; - let resp = self.client.end_history(req).await?; - let resp = resp.into_inner(); + Ok(self.client.end_history(req).await?.into_inner()) + } + + pub async fn status(&mut self) -> Result<StatusReply> { + Ok(self.client.status(StatusRequest {}).await?.into_inner()) + } - Ok((resp.id, resp.idx)) + pub async fn shutdown(&mut self) -> Result<bool> { + let resp = self.client.shutdown(ShutdownRequest {}).await?.into_inner(); + Ok(resp.accepted) } } diff --git a/crates/atuin-daemon/src/server.rs b/crates/atuin-daemon/src/server.rs index 2cba1753..9622d2b6 100644 --- a/crates/atuin-daemon/src/server.rs +++ b/crates/atuin-daemon/src/server.rs @@ -4,10 +4,12 @@ use atuin_client::encryption; use atuin_client::history::store::HistoryStore; use atuin_client::record::sqlite_store::SqliteStore; use atuin_client::settings::Settings; +use std::io::ErrorKind; #[cfg(unix)] use std::path::PathBuf; use std::sync::Arc; use time::OffsetDateTime; +use tokio::sync::watch; use tracing::{Level, instrument}; use atuin_client::database::{Database, Sqlite as HistoryDatabase}; @@ -19,9 +21,12 @@ use tonic::{Request, Response, Status, transport::Server}; use crate::history::history_server::{History as HistorySvc, HistoryServer}; use crate::history::{EndHistoryReply, EndHistoryRequest, StartHistoryReply, StartHistoryRequest}; +use crate::history::{ShutdownReply, ShutdownRequest, StatusReply, StatusRequest}; mod sync; +const DAEMON_PROTOCOL_VERSION: u32 = 1; + #[derive(Debug)] pub struct HistoryService { // A store for WIP history @@ -29,14 +34,20 @@ pub struct HistoryService { running: Arc<DashMap<HistoryId, History>>, store: HistoryStore, history_db: HistoryDatabase, + shutdown_tx: watch::Sender<bool>, } impl HistoryService { - pub fn new(store: HistoryStore, history_db: HistoryDatabase) -> Self { + pub fn new( + store: HistoryStore, + history_db: HistoryDatabase, + shutdown_tx: watch::Sender<bool>, + ) -> Self { Self { running: Arc::new(DashMap::new()), store, history_db, + shutdown_tx, } } } @@ -77,7 +88,11 @@ impl HistorySvc for HistoryService { tracing::info!(id = id.to_string(), "start history"); running.insert(id.clone(), h); - let reply = StartHistoryReply { id: id.to_string() }; + let reply = StartHistoryReply { + id: id.to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + protocol: DAEMON_PROTOCOL_VERSION, + }; Ok(Response::new(reply)) } @@ -122,6 +137,8 @@ impl HistorySvc for HistoryService { let reply = EndHistoryReply { id: id.0.to_string(), idx, + version: env!("CARGO_PKG_VERSION").to_string(), + protocol: DAEMON_PROTOCOL_VERSION, }; return Ok(Response::new(reply)); @@ -131,10 +148,35 @@ impl HistorySvc for HistoryService { "could not find history with id: {id}" ))) } + + #[instrument(skip_all, level = Level::INFO)] + async fn status( + &self, + _request: Request<StatusRequest>, + ) -> Result<Response<StatusReply>, Status> { + let reply = StatusReply { + // If status RPC responds, the daemon control plane is healthy. + healthy: true, + version: env!("CARGO_PKG_VERSION").to_string(), + pid: std::process::id(), + protocol: DAEMON_PROTOCOL_VERSION, + }; + + Ok(Response::new(reply)) + } + + #[instrument(skip_all, level = Level::INFO)] + async fn shutdown( + &self, + _request: Request<ShutdownRequest>, + ) -> Result<Response<ShutdownReply>, Status> { + let _ = self.shutdown_tx.send(true); + Ok(Response::new(ShutdownReply { accepted: true })) + } } #[cfg(unix)] -async fn shutdown_signal(socket: Option<PathBuf>) { +async fn shutdown_signal(socket: Option<PathBuf>, mut shutdown_rx: watch::Receiver<bool>) { let mut term = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) .expect("failed to register sigterm handler"); let mut int = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt()) @@ -143,26 +185,38 @@ async fn shutdown_signal(socket: Option<PathBuf>) { tokio::select! { _ = term.recv() => {}, _ = int.recv() => {}, + _ = shutdown_rx.changed() => {}, } eprintln!("Removing socket..."); if let Some(socket) = socket { - std::fs::remove_file(socket).expect("failed to remove socket"); + match std::fs::remove_file(socket) { + Ok(()) => {} + Err(err) if err.kind() == ErrorKind::NotFound => {} + Err(err) => { + eprintln!("failed to remove socket: {err}"); + } + } } eprintln!("Shutting down..."); } #[cfg(windows)] -async fn shutdown_signal() { - tokio::signal::windows::ctrl_c() - .expect("failed to register signal handler") - .recv() - .await; +async fn shutdown_signal(mut shutdown_rx: watch::Receiver<bool>) { + let mut ctrl_c = tokio::signal::windows::ctrl_c().expect("failed to register signal handler"); + tokio::select! { + _ = ctrl_c.recv() => {}, + _ = shutdown_rx.changed() => {}, + } eprintln!("Shutting down..."); } #[cfg(unix)] -async fn start_server(settings: Settings, history: HistoryService) -> Result<()> { +async fn start_server( + settings: Settings, + history: HistoryService, + shutdown_rx: watch::Receiver<bool>, +) -> Result<()> { use tokio::net::UnixListener; use tokio_stream::wrappers::UnixListenerStream; @@ -215,7 +269,7 @@ async fn start_server(settings: Settings, history: HistoryService) -> Result<()> .add_service(HistoryServer::new(history)) .serve_with_incoming_shutdown( uds_stream, - shutdown_signal(cleanup.then_some(socket_path.into())), + shutdown_signal(cleanup.then_some(socket_path.into()), shutdown_rx), ) .await?; @@ -223,7 +277,11 @@ async fn start_server(settings: Settings, history: HistoryService) -> Result<()> } #[cfg(not(unix))] -async fn start_server(settings: Settings, history: HistoryService) -> Result<()> { +async fn start_server( + settings: Settings, + history: HistoryService, + shutdown_rx: watch::Receiver<bool>, +) -> Result<()> { use tokio::net::TcpListener; use tokio_stream::wrappers::TcpListenerStream; @@ -236,7 +294,7 @@ async fn start_server(settings: Settings, history: HistoryService) -> Result<()> Server::builder() .add_service(HistoryServer::new(history)) - .serve_with_incoming_shutdown(tcp_stream, shutdown_signal()) + .serve_with_incoming_shutdown(tcp_stream, shutdown_signal(shutdown_rx)) .await?; Ok(()) } @@ -257,7 +315,8 @@ pub async fn listen( let host_id = Settings::host_id().await?; let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); - let history = HistoryService::new(history_store.clone(), history_db.clone()); + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let history = HistoryService::new(history_store.clone(), history_db.clone(), shutdown_tx); // start services tokio::spawn(sync::worker( @@ -267,5 +326,5 @@ pub async fn listen( history_db, )); - start_server(settings, history).await + start_server(settings, history, shutdown_rx).await } |
