use eyre::{Context as EyreContext, Result}; use tonic::Code; use tonic::transport::{Channel, Endpoint, Uri}; use tower::service_fn; use hyper_util::rt::TokioIo; #[cfg(unix)] use tokio::net::UnixStream; use tracing::{Level, instrument, span}; use crate::atuin_daemon::generated; use crate::{ atuin_client::{ database::Context, history::History, settings::{FilterMode, Settings}, }, atuin_daemon::{ events::DaemonEvent, generated::{ control::{ ForceSyncEvent, HistoryDeletedEvent, HistoryPrunedEvent, HistoryRebuiltEvent, SendEventRequest, SettingsReloadedEvent, ShutdownEvent, control_client::ControlClient as ControlServiceClient, }, history::{ EndHistoryReply, EndHistoryRequest, ShutdownRequest, StartHistoryReply, StartHistoryRequest, StatusReply, StatusRequest, TailHistoryReply, TailHistoryRequest, history_client::HistoryClient as HistoryServiceClient, }, search::{ FilterMode as RpcFilterMode, SearchContext as RpcSearchContext, SearchRequest, SearchResponse, search_client::SearchClient as SearchServiceClient, }, semantic::{ CommandCapture, RecordCommandsReply, semantic_client::SemanticClient as SemanticServiceClient, }, }, }, }; pub(crate) struct HistoryClient { client: HistoryServiceClient, } #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub(crate) enum DaemonClientErrorKind { Connect, Unavailable, Unimplemented, Other, } #[must_use] pub(crate) fn classify_error(error: &eyre::Report) -> DaemonClientErrorKind { for cause in error.chain() { if cause.downcast_ref::().is_some() { return DaemonClientErrorKind::Connect; } if let Some(status) = cause.downcast_ref::() { return match status.code() { Code::Unavailable => DaemonClientErrorKind::Unavailable, Code::Unimplemented => DaemonClientErrorKind::Unimplemented, _ => DaemonClientErrorKind::Other, }; } } DaemonClientErrorKind::Other } // Wrap the grpc client impl HistoryClient { #[cfg(unix)] pub(crate) async fn new(path: String) -> Result { use eyre::Context; let log_path = path.clone(); let channel = Endpoint::try_from("http://atuin_local_daemon:0")? .connect_with_connector(service_fn(move |_: Uri| { let path = path.clone(); async move { Ok::<_, std::io::Error>(TokioIo::new(UnixStream::connect(path.clone()).await?)) } })) .await .wrap_err_with(|| { format!( "failed to connect to local atuin daemon at {}. Is it running?", &log_path ) })?; let client = HistoryServiceClient::new(channel); Ok(HistoryClient { client }) } pub(crate) async fn start_history(&mut self, h: History) -> Result { let req = StartHistoryRequest { command: h.command, cwd: h.cwd, hostname: h.hostname, session: h.session, timestamp: h.timestamp.unix_timestamp_nanos() as u64, author: h.author, intent: h.intent.unwrap_or_default(), }; Ok(self.client.start_history(req).await?.into_inner()) } pub(crate) async fn end_history( &mut self, id: String, duration: u64, exit: i64, ) -> Result { let req = EndHistoryRequest { id, exit, duration }; Ok(self.client.end_history(req).await?.into_inner()) } pub(crate) async fn status(&mut self) -> Result { Ok(self.client.status(StatusRequest {}).await?.into_inner()) } pub(crate) async fn tail_history(&mut self) -> Result> { Ok(self .client .tail_history(TailHistoryRequest {}) .await? .into_inner()) } pub(crate) async fn shutdown(&mut self) -> Result { let resp = self.client.shutdown(ShutdownRequest {}).await?.into_inner(); Ok(resp.accepted) } } pub(crate) struct SearchClient { client: SearchServiceClient, } impl SearchClient { #[cfg(unix)] pub(crate) async fn new(path: String) -> Result { let log_path = path.clone(); let channel = Endpoint::try_from("http://atuin_local_daemon:0")? .connect_with_connector(service_fn(move |_: Uri| { let path = path.clone(); async move { Ok::<_, std::io::Error>(TokioIo::new(UnixStream::connect(path.clone()).await?)) } })) .await .wrap_err_with(|| { format!( "failed to connect to local atuin daemon at {}. Is it running?", &log_path ) })?; let client = SearchServiceClient::new(channel); Ok(SearchClient { client }) } #[instrument(skip_all, level = Level::TRACE, name = "daemon_client_search", fields(query = %query, query_id = query_id))] pub(crate) async fn search( &mut self, query: String, query_id: u64, filter_mode: FilterMode, context: Option, ) -> Result> { let request = SearchRequest { query, query_id, filter_mode: RpcFilterMode::from(filter_mode).into(), context: context.map(RpcSearchContext::from), }; let request_stream = tokio_stream::once(request); let response = span!(Level::TRACE, "daemon_client_search.request") .in_scope(async || self.client.search(request_stream).await) .await?; Ok(response.into_inner()) } } impl From for RpcFilterMode { fn from(filter_mode: FilterMode) -> Self { match filter_mode { FilterMode::Global => RpcFilterMode::Global, FilterMode::Host => RpcFilterMode::Host, FilterMode::Session => RpcFilterMode::Session, FilterMode::Directory => RpcFilterMode::Directory, FilterMode::Workspace => RpcFilterMode::Workspace, FilterMode::SessionPreload => RpcFilterMode::SessionPreload, } } } impl From for RpcSearchContext { fn from(context: Context) -> Self { RpcSearchContext { session_id: context.session, cwd: context.cwd, hostname: context.hostname, host_id: context.host_id, git_root: context .git_root .map(|path| path.to_string_lossy().to_string()), } } } pub(crate) struct SemanticClient { client: SemanticServiceClient, } impl SemanticClient { #[cfg(unix)] pub(crate) async fn new(path: String) -> Result { let log_path = path.clone(); let channel = Endpoint::try_from("http://atuin_local_daemon:0")? .connect_with_connector(service_fn(move |_: Uri| { let path = path.clone(); async move { Ok::<_, std::io::Error>(TokioIo::new(UnixStream::connect(path.clone()).await?)) } })) .await .wrap_err_with(|| { format!( "failed to connect to local atuin daemon at {}. Is it running?", &log_path ) })?; let client = SemanticServiceClient::new(channel); Ok(SemanticClient { client }) } #[cfg(unix)] pub(crate) async fn from_settings(settings: &Settings) -> Result { Self::new(settings.daemon.socket_path.clone()).await } pub(crate) async fn record_commands( &mut self, captures: Vec, ) -> Result { let stream = tokio_stream::iter(captures); Ok(self.client.record_commands(stream).await?.into_inner()) } } // ============================================================================ // Control Client // ============================================================================ /// Client for the Control gRPC service. /// /// Used to inject events into a running daemon from external processes. pub(crate) struct ControlClient { client: ControlServiceClient, } impl ControlClient { /// Connect to the daemon's control service. #[cfg(unix)] pub(crate) async fn new(path: String) -> Result { let log_path = path.clone(); let channel = Endpoint::try_from("http://atuin_local_daemon:0")? .connect_with_connector(service_fn(move |_: Uri| { let path = path.clone(); async move { Ok::<_, std::io::Error>(TokioIo::new(UnixStream::connect(path.clone()).await?)) } })) .await .wrap_err_with(|| { format!( "failed to connect to local atuin daemon at {}. Is it running?", &log_path ) })?; let client = ControlServiceClient::new(channel); Ok(ControlClient { client }) } /// Connect using settings. #[cfg(unix)] pub(crate) async fn from_settings(settings: &Settings) -> Result { Self::new(settings.daemon.socket_path.clone()).await } /// Send an event to the daemon. pub(crate) async fn send_event(&mut self, event: DaemonEvent) -> Result<()> { let proto_event = daemon_event_to_proto(event); let request = SendEventRequest { event: Some(proto_event), }; self.client.send_event(request).await?; Ok(()) } } /// Convert a daemon event to its proto representation. fn daemon_event_to_proto(event: DaemonEvent) -> generated::control::send_event_request::Event { use generated::control::send_event_request::Event; match event { DaemonEvent::HistoryPruned => Event::HistoryPruned(HistoryPrunedEvent {}), DaemonEvent::HistoryRebuilt => Event::HistoryRebuilt(HistoryRebuiltEvent {}), DaemonEvent::HistoryDeleted { ids } => Event::HistoryDeleted(HistoryDeletedEvent { ids: ids.into_iter().map(|id| id.0).collect(), }), DaemonEvent::ForceSync => Event::ForceSync(ForceSyncEvent {}), DaemonEvent::SettingsReloaded => Event::SettingsReloaded(SettingsReloadedEvent {}), DaemonEvent::ShutdownRequested => Event::Shutdown(ShutdownEvent {}), // These events are internal and not sent via the control service DaemonEvent::HistoryStarted(_) | DaemonEvent::HistoryEnded(_) | DaemonEvent::RecordsAdded(_) | DaemonEvent::SyncCompleted { .. } | DaemonEvent::SyncFailed { .. } => { // Use shutdown as a fallback, though this shouldn't happen tracing::warn!("attempted to send internal event via control service"); Event::Shutdown(ShutdownEvent {}) } } }