From bcdf8c8cde31e826000f1b2d6eeaebdd865a07c1 Mon Sep 17 00:00:00 2001 From: Michelle Tilley Date: Mon, 8 Jun 2026 09:12:45 -0700 Subject: feat: Capture command output + expose to new `atuin_output` tool (#3510) --- crates/atuin-ai/Cargo.toml | 2 + crates/atuin-ai/src/commands/inline.rs | 3 +- crates/atuin-ai/src/context.rs | 27 +- crates/atuin-ai/src/driver.rs | 16 +- crates/atuin-ai/src/history_format.rs | 120 ++++ crates/atuin-ai/src/lib.rs | 1 + crates/atuin-ai/src/permissions/check.rs | 6 +- crates/atuin-ai/src/stream.rs | 13 +- crates/atuin-ai/src/tools/descriptor.rs | 10 + crates/atuin-ai/src/tools/mod.rs | 281 ++++++-- crates/atuin-ai/src/tui/view/mod.rs | 1 + crates/atuin-ai/src/tui/view/turn.rs | 1 + crates/atuin-client/src/settings.rs | 2 + crates/atuin-daemon/build.rs | 1 + crates/atuin-daemon/proto/semantic.proto | 47 ++ crates/atuin-daemon/src/client.rs | 90 +++ crates/atuin-daemon/src/components/mod.rs | 3 + crates/atuin-daemon/src/components/semantic.rs | 900 +++++++++++++++++++++++++ crates/atuin-daemon/src/lib.rs | 9 +- crates/atuin-daemon/src/semantic/mod.rs | 3 + crates/atuin-daemon/src/server.rs | 6 + crates/atuin-pty-proxy/src/capture.rs | 467 +++++++++++++ crates/atuin-pty-proxy/src/debug.rs | 53 ++ crates/atuin-pty-proxy/src/lib.rs | 502 +------------- crates/atuin-pty-proxy/src/osc133.rs | 313 ++++++++- crates/atuin-pty-proxy/src/pty_proxy.rs | 231 +++++++ crates/atuin-pty-proxy/src/runtime.rs | 184 +++++ crates/atuin-pty-proxy/src/screen.rs | 104 +++ crates/atuin/Cargo.toml | 2 +- crates/atuin/src/command/mod.rs | 98 ++- crates/atuin/src/shell/atuin.bash | 35 + crates/atuin/src/shell/atuin.fish | 16 + crates/atuin/src/shell/atuin.nu | 24 + crates/atuin/src/shell/atuin.zsh | 39 ++ 34 files changed, 3015 insertions(+), 595 deletions(-) create mode 100644 crates/atuin-ai/src/history_format.rs create mode 100644 crates/atuin-daemon/proto/semantic.proto create mode 100644 crates/atuin-daemon/src/components/semantic.rs create mode 100644 crates/atuin-daemon/src/semantic/mod.rs create mode 100644 crates/atuin-pty-proxy/src/capture.rs create mode 100644 crates/atuin-pty-proxy/src/debug.rs create mode 100644 crates/atuin-pty-proxy/src/pty_proxy.rs create mode 100644 crates/atuin-pty-proxy/src/runtime.rs create mode 100644 crates/atuin-pty-proxy/src/screen.rs (limited to 'crates') diff --git a/crates/atuin-ai/Cargo.toml b/crates/atuin-ai/Cargo.toml index 06e50a4e..027bd490 100644 --- a/crates/atuin-ai/Cargo.toml +++ b/crates/atuin-ai/Cargo.toml @@ -14,12 +14,14 @@ repository = { workspace = true } [features] default = [] +daemon = [] tree-sitter = ["dep:tree-sitter-lib", "dep:tree-sitter-bash", "dep:tree-sitter-fish"] [dependencies] async-trait = { workspace = true } atuin-client = { workspace = true } atuin-common = { workspace = true } +atuin-daemon = { workspace = true } tokio = { workspace = true } eyre = { workspace = true } clap = { workspace = true, features = ["derive", "env"] } diff --git a/crates/atuin-ai/src/commands/inline.rs b/crates/atuin-ai/src/commands/inline.rs index 989b95c0..6d1f9c51 100644 --- a/crates/atuin-ai/src/commands/inline.rs +++ b/crates/atuin-ai/src/commands/inline.rs @@ -67,7 +67,7 @@ pub(crate) async fn run( settings.ai.opening.send_cwd.unwrap_or(false) || settings.ai.send_cwd.unwrap_or(false); let last_command = if settings.ai.opening.send_last_command.unwrap_or(false) { - history_db.last().await.ok().flatten().map(|h| h.command) + history_db.last().await.ok().flatten() } else { None }; @@ -84,6 +84,7 @@ pub(crate) async fn run( history_db: std::sync::Arc::new(history_db), git_root, capabilities: settings.ai.capabilities.clone(), + daemon_enabled: settings.daemon.enabled, }; let action = run_inline_tui(ctx, initial_command, settings).await?; diff --git a/crates/atuin-ai/src/context.rs b/crates/atuin-ai/src/context.rs index 625de0c6..f891a9fc 100644 --- a/crates/atuin-ai/src/context.rs +++ b/crates/atuin-ai/src/context.rs @@ -2,6 +2,7 @@ use std::path::PathBuf; use std::sync::Arc; use atuin_client::distro::detect_linux_distribution; +use atuin_client::history::History; use atuin_client::settings::AiCapabilities; /// Session-scoped context for the AI chat session. @@ -11,12 +12,17 @@ pub(crate) struct AppContext { pub endpoint: String, pub token: String, pub send_cwd: bool, - pub last_command: Option, + pub last_command: Option, pub history_db: Arc, /// Git root of the current working directory, if inside a git repo. /// Resolves through worktrees to the main repo root. pub git_root: Option, pub capabilities: AiCapabilities, + pub daemon_enabled: bool, +} + +pub(crate) fn history_output_capability_available(daemon_enabled: bool) -> bool { + cfg!(feature = "daemon") && daemon_enabled } impl AppContext { @@ -33,6 +39,11 @@ impl AppContext { if self.capabilities.enable_command_execution.unwrap_or(true) { caps.push("client_v1_execute_shell_command".to_string()); } + if history_output_capability_available(self.daemon_enabled) + && self.capabilities.enable_history_output.unwrap_or(true) + { + caps.push("client_v1_atuin_output".to_string()); + } caps.push("client_v1_load_skill".to_string()); if let Ok(extra) = std::env::var("ATUIN_AI__ADDITIONAL_CAPS") { caps.extend( @@ -69,7 +80,11 @@ impl ClientContext { /// Serialize to the JSON format the API expects for the "context" field. /// The `pwd` field is always dynamic (current working directory), so it's /// computed fresh on each call if `send_cwd` is true. - pub(crate) fn to_json(&self, send_cwd: bool, last_command: Option<&str>) -> serde_json::Value { + pub(crate) fn to_json( + &self, + send_cwd: bool, + last_command: Option<&History>, + ) -> serde_json::Value { let mut ctx = serde_json::json!({ "os": self.os, "shell": self.shell, @@ -78,9 +93,15 @@ impl ClientContext { } else { None }, - "last_command": last_command, }); + if let Some(history) = last_command { + ctx["last_command"] = serde_json::json!(crate::history_format::format_last_command( + history, + crate::history_format::current_local_offset(), + )); + } + if let Some(ref distro) = self.distro { ctx["distro"] = serde_json::json!(distro); } diff --git a/crates/atuin-ai/src/driver.rs b/crates/atuin-ai/src/driver.rs index ddb839b7..82d666ef 100644 --- a/crates/atuin-ai/src/driver.rs +++ b/crates/atuin-ai/src/driver.rs @@ -492,6 +492,7 @@ fn execute_effect(effect: &Effect, ctx: DriverContext) { messages.clone(), session_id.clone(), &app.capabilities, + app.daemon_enabled, fsm.ctx.invocation_id.clone(), ); tokio::spawn(async move { @@ -570,7 +571,6 @@ fn execute_effect(effect: &Effect, ctx: DriverContext) { Effect::ExecuteTool { tool_id, tool } => { let tool_id = tool_id.clone(); - let tool = tool.clone(); let tx = tx.clone(); let db = io.app_ctx.history_db.clone(); @@ -731,8 +731,9 @@ fn execute_effect(effect: &Effect, ctx: DriverContext) { preview: None, })); } - ClientToolCall::AtuinHistory(_) => { + ClientToolCall::AtuinHistory(tool) => { // History search needs async DB access + let tool = tool.clone(); tokio::spawn(async move { let outcome = tool.execute(&db).await; let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone { @@ -742,6 +743,17 @@ fn execute_effect(effect: &Effect, ctx: DriverContext) { })); }); } + ClientToolCall::AtuinOutput(tool) => { + let tool = tool.clone(); + tokio::spawn(async move { + let outcome = tool.execute().await; + let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone { + tool_id, + outcome, + preview: None, + })); + }); + } ClientToolCall::LoadSkill(skill_call) => { let skill_name = skill_call.name.clone(); let registry = io.skill_registry.clone(); diff --git a/crates/atuin-ai/src/history_format.rs b/crates/atuin-ai/src/history_format.rs new file mode 100644 index 00000000..24aa963e --- /dev/null +++ b/crates/atuin-ai/src/history_format.rs @@ -0,0 +1,120 @@ +use atuin_client::history::History; +use time::UtcOffset; + +pub(crate) fn current_local_offset() -> UtcOffset { + UtcOffset::current_local_offset().unwrap_or(UtcOffset::UTC) +} + +pub(crate) fn format_last_command(history: &History, local_offset: UtcOffset) -> String { + format!( + "History ID: {} - `{}`\n{}", + history.id, + history.command, + format_history_metadata(history, local_offset) + ) +} + +pub(crate) fn format_history_search_result( + ordinal: usize, + history: &History, + local_offset: UtcOffset, +) -> String { + format!( + "## #{}. (History ID: {}):\n`{}`\n{}\n", + ordinal, + history.id, + history.command, + format_history_metadata(history, local_offset) + ) +} + +fn format_history_metadata(history: &History, local_offset: UtcOffset) -> String { + format!( + "[{}] (in `{}`, exit {}){}", + format_timestamp(history, local_offset), + history.cwd, + history.exit, + format_duration(history.duration) + ) +} + +fn format_timestamp(history: &History, local_offset: UtcOffset) -> String { + let ts = history.timestamp.to_offset(local_offset); + format!( + "{:04}-{:02}-{:02} {:02}:{:02}:{:02}", + ts.year(), + ts.month() as u8, + ts.day(), + ts.hour(), + ts.minute(), + ts.second(), + ) +} + +fn format_duration(nanos: i64) -> String { + if nanos <= 0 { + return String::new(); + } + + let total_secs = nanos / 1_000_000_000; + let millis = (nanos % 1_000_000_000) / 1_000_000; + + if total_secs >= 3600 { + let hours = total_secs / 3600; + let mins = (total_secs % 3600) / 60; + let secs = total_secs % 60; + format!(", {hours}h{mins}m{secs}s") + } else if total_secs >= 60 { + let mins = total_secs / 60; + let secs = total_secs % 60; + format!(", {mins}m{secs}s") + } else if total_secs > 0 { + if millis > 0 { + format!(", {total_secs}.{millis:03}s") + } else { + format!(", {total_secs}s") + } + } else { + format!(", {millis}ms") + } +} + +#[cfg(test)] +mod tests { + use atuin_client::history::{History, HistoryId}; + use time::{OffsetDateTime, UtcOffset}; + + use super::*; + + fn history(duration: i64) -> History { + History { + id: HistoryId("018f011c-9a0a-7000-8000-000000000001".to_string()), + timestamp: OffsetDateTime::UNIX_EPOCH, + duration, + exit: 2, + command: "cargo test".to_string(), + cwd: "/repo".to_string(), + session: String::new(), + hostname: String::new(), + author: String::new(), + intent: None, + deleted_at: None, + } + } + + #[test] + fn formats_last_command() { + assert_eq!( + format_last_command(&history(1_234_000_000), UtcOffset::UTC), + "History ID: 018f011c-9a0a-7000-8000-000000000001 - `cargo test`\n[1970-01-01 00:00:00] (in `/repo`, exit 2), 1.234s" + ); + } + + #[test] + fn formats_history_search_result() { + assert_eq!( + format_history_search_result(3, &history(0), UtcOffset::UTC), + "## #3. (History ID: 018f011c-9a0a-7000-8000-000000000001):\n`cargo test`\n[1970-01-01 00:00:00] (in `/repo`, exit 2)\n" + ); + } +} diff --git a/crates/atuin-ai/src/lib.rs b/crates/atuin-ai/src/lib.rs index b3587739..f972d4ff 100644 --- a/crates/atuin-ai/src/lib.rs +++ b/crates/atuin-ai/src/lib.rs @@ -7,6 +7,7 @@ pub(crate) mod edit_permissions; pub(crate) mod event_serde; pub(crate) mod file_tracker; pub(crate) mod fsm; +pub(crate) mod history_format; pub(crate) mod permissions; pub(crate) mod session; pub(crate) mod skills; diff --git a/crates/atuin-ai/src/permissions/check.rs b/crates/atuin-ai/src/permissions/check.rs index 96abc3ab..bb1eae0c 100644 --- a/crates/atuin-ai/src/permissions/check.rs +++ b/crates/atuin-ai/src/permissions/check.rs @@ -1,13 +1,13 @@ use eyre::Result; -use crate::{permissions::file::RuleFile, tools::PermissableToolCall}; +use crate::{permissions::file::RuleFile, tools::PermissibleToolCall}; pub(crate) struct PermissionRequest<'t> { - call: &'t (dyn PermissableToolCall + Send + Sync), + call: &'t (dyn PermissibleToolCall + Send + Sync), } impl<'t> PermissionRequest<'t> { - pub fn new(call: &'t (dyn PermissableToolCall + Send + Sync)) -> Self { + pub fn new(call: &'t (dyn PermissibleToolCall + Send + Sync)) -> Self { Self { call } } } diff --git a/crates/atuin-ai/src/stream.rs b/crates/atuin-ai/src/stream.rs index 084e8238..e78dc2e1 100644 --- a/crates/atuin-ai/src/stream.rs +++ b/crates/atuin-ai/src/stream.rs @@ -2,7 +2,10 @@ // SSE streaming // ─────────────────────────────────────────────────────────────────── +use atuin_client::history::History; use atuin_client::settings::AiCapabilities; + +use crate::context::history_output_capability_available; use atuin_common::tls::ensure_crypto_provider; use eventsource_stream::Eventsource; @@ -61,6 +64,7 @@ impl ChatRequest { messages: Vec, session_id: Option, capabilities: &AiCapabilities, + history_output_available: bool, invocation_id: String, ) -> Self { let mut caps = vec![ @@ -78,6 +82,11 @@ impl ChatRequest { if capabilities.enable_command_execution.unwrap_or(true) { caps.push("client_v1_execute_shell_command".to_string()); } + if history_output_capability_available(history_output_available) + && capabilities.enable_history_output.unwrap_or(true) + { + caps.push("client_v1_atuin_output".to_string()); + } if let Ok(extra) = std::env::var("ATUIN_AI__ADDITIONAL_CAPS") { caps.extend( extra @@ -103,7 +112,7 @@ pub(crate) fn create_chat_stream( request: ChatRequest, client_ctx: ClientContext, send_cwd: bool, - last_command: Option, + last_command: Option, user_contexts: Vec, skill_summaries: Vec, skill_overflow: Option, @@ -120,7 +129,7 @@ pub(crate) fn create_chat_stream( tracing::debug!("Sending SSE request to {endpoint}"); - let context = client_ctx.to_json(send_cwd, last_command.as_deref()); + let context = client_ctx.to_json(send_cwd, last_command.as_ref()); let mut config = serde_json::json!({ "capabilities": request.capabilities, diff --git a/crates/atuin-ai/src/tools/descriptor.rs b/crates/atuin-ai/src/tools/descriptor.rs index 06858bf8..4190540c 100644 --- a/crates/atuin-ai/src/tools/descriptor.rs +++ b/crates/atuin-ai/src/tools/descriptor.rs @@ -67,6 +67,15 @@ pub(crate) const ATUIN_HISTORY: &ToolDescriptor = &ToolDescriptor { is_client: true, }; +pub(crate) const ATUIN_OUTPUT: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["atuin_output"], + capability: Some("client_v1_atuin_output"), + display_verb: "view the output for command", + progressive_verb: "Viewing output...", + past_verb: "Viewed output", + is_client: true, +}; + pub(crate) const LOAD_SKILL: &ToolDescriptor = &ToolDescriptor { canonical_names: &["load_skill"], capability: Some("client_v1_load_skill"), @@ -104,6 +113,7 @@ const ALL_DESCRIPTORS: &[&ToolDescriptor] = &[ WRITE, SHELL, ATUIN_HISTORY, + ATUIN_OUTPUT, LOAD_SKILL, SERVER_SEARCH, SERVER_SCRAPE, diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs index fdda10a4..d1352661 100644 --- a/crates/atuin-ai/src/tools/mod.rs +++ b/crates/atuin-ai/src/tools/mod.rs @@ -5,6 +5,7 @@ use std::{ }; use eyre::Result; +use uuid::Uuid; const DEFAULT_FILE_READ_LINES: u64 = 100; const MAX_FILE_READ_LINES: u64 = 1000; @@ -158,6 +159,7 @@ pub(crate) enum ClientToolCall { Write(WriteToolCall), Shell(ShellToolCall), AtuinHistory(AtuinHistoryToolCall), + AtuinOutput(AtuinOutputToolCall), LoadSkill(LoadSkillToolCall), } @@ -173,6 +175,9 @@ impl TryFrom<(&str, &serde_json::Value)> for ClientToolCall { "atuin_history" => Ok(ClientToolCall::AtuinHistory( AtuinHistoryToolCall::try_from(input)?, )), + "atuin_output" => Ok(ClientToolCall::AtuinOutput(AtuinOutputToolCall::try_from( + input, + )?)), "load_skill" => Ok(ClientToolCall::LoadSkill(LoadSkillToolCall::try_from( input, )?)), @@ -189,6 +194,7 @@ impl ClientToolCall { ClientToolCall::Write(_) => descriptor::WRITE, ClientToolCall::Shell(_) => descriptor::SHELL, ClientToolCall::AtuinHistory(_) => descriptor::ATUIN_HISTORY, + ClientToolCall::AtuinOutput(_) => descriptor::ATUIN_OUTPUT, ClientToolCall::LoadSkill(_) => descriptor::LOAD_SKILL, } } @@ -205,6 +211,7 @@ impl ClientToolCall { ClientToolCall::Write(_) => "Write", ClientToolCall::Shell(_) => "Shell", ClientToolCall::AtuinHistory(_) => "AtuinHistory", + ClientToolCall::AtuinOutput(_) => "AtuinOutput", ClientToolCall::LoadSkill(_) => "LoadSkill", } } @@ -218,6 +225,7 @@ impl ClientToolCall { ClientToolCall::Write(tool) => Some(tool.resolved_path()), ClientToolCall::Shell(_) | ClientToolCall::AtuinHistory(_) + | ClientToolCall::AtuinOutput(_) | ClientToolCall::LoadSkill(_) => None, } } @@ -229,6 +237,7 @@ impl ClientToolCall { ClientToolCall::Write(tool) => tool.matches_rule(rule), ClientToolCall::Shell(tool) => tool.matches_rule(rule), ClientToolCall::AtuinHistory(tool) => tool.matches_rule(rule), + ClientToolCall::AtuinOutput(tool) => tool.matches_rule(rule), ClientToolCall::LoadSkill(tool) => tool.matches_rule(rule), } } @@ -240,26 +249,14 @@ impl ClientToolCall { ClientToolCall::Write(tool) => tool.target_dir(), ClientToolCall::Shell(tool) => tool.target_dir(), ClientToolCall::AtuinHistory(tool) => tool.target_dir(), + ClientToolCall::AtuinOutput(tool) => tool.target_dir(), ClientToolCall::LoadSkill(tool) => tool.target_dir(), } } - - /// Execute this client-side tool and return the result. - pub async fn execute(&self, db: &atuin_client::database::Sqlite) -> ToolOutcome { - match self { - ClientToolCall::Read(tool) => tool.execute(), - ClientToolCall::AtuinHistory(tool) => tool.execute(db).await, - // LoadSkill is handled separately by the driver (needs registry access) - ClientToolCall::LoadSkill(_) => { - ToolOutcome::Error("LoadSkill must be executed via the driver".to_string()) - } - _ => ToolOutcome::Error("Client-side tool execution not yet implemented".to_string()), - } - } } /// A trait for tool calls that can be checked against permission rules. -pub(crate) trait PermissableToolCall { +pub(crate) trait PermissibleToolCall { /// Checks if this tool call matches the given permission rule. fn matches_rule(&self, rule: &Rule) -> bool; @@ -277,7 +274,7 @@ pub(crate) trait PermissableToolCall { } } -impl PermissableToolCall for ClientToolCall { +impl PermissibleToolCall for ClientToolCall { fn matches_rule(&self, rule: &Rule) -> bool { self.matches_rule(rule) } @@ -416,7 +413,7 @@ impl ReadToolCall { } } -impl PermissableToolCall for ReadToolCall { +impl PermissibleToolCall for ReadToolCall { fn target_dir(&self) -> Option<&Path> { Some(&self.path) } @@ -616,7 +613,7 @@ impl EditToolCall { } } -impl PermissableToolCall for EditToolCall { +impl PermissibleToolCall for EditToolCall { fn target_dir(&self) -> Option<&Path> { Some(&self.path) } @@ -724,7 +721,7 @@ impl WriteToolCall { } } -impl PermissableToolCall for WriteToolCall { +impl PermissibleToolCall for WriteToolCall { fn target_dir(&self) -> Option<&Path> { Some(&self.path) } @@ -792,7 +789,7 @@ impl TryFrom<&serde_json::Value> for ShellToolCall { } } -impl PermissableToolCall for ShellToolCall { +impl PermissibleToolCall for ShellToolCall { fn target_dir(&self) -> Option<&Path> { self.dir.as_deref() } @@ -1134,7 +1131,7 @@ impl TryFrom<&serde_json::Value> for AtuinHistoryToolCall { } } -impl PermissableToolCall for AtuinHistoryToolCall { +impl PermissibleToolCall for AtuinHistoryToolCall { fn target_dir(&self) -> Option<&Path> { None } @@ -1148,7 +1145,6 @@ impl AtuinHistoryToolCall { pub(crate) async fn execute(&self, db: &atuin_client::database::Sqlite) -> ToolOutcome { use atuin_client::database::{self, Database as _, OptFilters}; use atuin_client::settings::SearchMode; - use time::UtcOffset; let context = match database::current_context().await { Ok(ctx) => ctx, @@ -1184,34 +1180,13 @@ impl AtuinHistoryToolCall { return ToolOutcome::Success("No matching history entries found.".to_string()); } - let local_offset = UtcOffset::current_local_offset().unwrap_or(UtcOffset::UTC); + let local_offset = crate::history_format::current_local_offset(); let formatted: Vec = results .iter() .enumerate() - .map(|(i, h)| { - let ts = h.timestamp.to_offset(local_offset); - let time_str = format!( - "{:04}-{:02}-{:02} {:02}:{:02}:{:02}", - ts.year(), - ts.month() as u8, - ts.day(), - ts.hour(), - ts.minute(), - ts.second(), - ); - - let duration_str = format_duration(h.duration); - - format!( - "{}. `{}` [{}] ({}, exit: {}){}", - i + 1, - h.command, - time_str, - h.cwd, - h.exit, - duration_str, - ) + .map(|(i, history)| { + crate::history_format::format_history_search_result(i + 1, history, local_offset) }) .collect(); @@ -1219,6 +1194,146 @@ impl AtuinHistoryToolCall { } } +#[derive(Debug, Clone)] +pub(crate) struct AtuinOutputToolCall { + pub history_id: Uuid, + pub ranges: Vec<(i64, i64)>, +} + +impl TryFrom<&serde_json::Value> for AtuinOutputToolCall { + type Error = eyre::Error; + + fn try_from(value: &serde_json::Value) -> Result { + let history_id = value + .get("history_id") + .and_then(|v| v.as_str()) + .and_then(|v| Uuid::parse_str(v).ok()) + .ok_or(eyre::eyre!("Missing or invalid history ID"))?; + + let ranges = value + .get("ranges") + .and_then(|v| v.as_array()) + .map(Vec::as_slice) + .unwrap_or(&[]); + + let ranges = ranges + .iter() + .map(|r| { + let range = r + .as_array() + .filter(|a| a.len() == 2) + .ok_or_else(|| eyre::eyre!("Each range must be a [start, end] array"))?; + + let start = range[0] + .as_i64() + .ok_or_else(|| eyre::eyre!("Range start must be an integer"))?; + let end = range[1] + .as_i64() + .ok_or_else(|| eyre::eyre!("Range end must be an integer"))?; + + Ok((start, end)) + }) + .collect::, eyre::Error>>()?; + + Ok(Self { history_id, ranges }) + } +} + +impl PermissibleToolCall for AtuinOutputToolCall { + fn target_dir(&self) -> Option<&Path> { + None + } + + fn matches_rule(&self, rule: &Rule) -> bool { + rule.tool == "AtuinOutput" + } +} + +fn format_output_lines_for_llm(lines: &[atuin_daemon::semantic::OutputLine]) -> String { + let width = lines + .iter() + .map(|line| line.line_number) + .max() + .unwrap_or(1) + .max(1) + .ilog10() as usize + + 1; + let mut formatted = Vec::with_capacity(lines.len()); + let mut previous_line_number = None; + + for line in lines { + if let Some(previous) = previous_line_number { + let skipped = line.line_number.saturating_sub(previous + 1); + if skipped > 0 { + formatted.push(format!("[...skipped {skipped} lines...]")); + } + } + + formatted.push(format!("{:>width$}\t{}", line.line_number, line.content)); + previous_line_number = Some(line.line_number); + } + + formatted.join("\n") +} + +impl AtuinOutputToolCall { + pub(crate) async fn execute(&self) -> ToolOutcome { + let settings = match atuin_client::settings::Settings::new() { + Ok(settings) => settings, + Err(e) => return ToolOutcome::Error(format!("Failed to load Atuin settings: {e}")), + }; + + let mut client = match atuin_daemon::SemanticClient::from_settings(&settings).await { + Ok(client) => client, + Err(e) => return ToolOutcome::Error(format!("Failed to connect to Atuin daemon: {e}")), + }; + + let history_id = self.history_id.as_simple().to_string(); + let response = match client + .command_output(history_id.clone(), self.ranges.clone()) + .await + { + Ok(response) => response, + Err(e) => return ToolOutcome::Error(format!("Failed to fetch command output: {e}")), + }; + + if !response.found { + return ToolOutcome::Success(format!( + "No captured output found for history ID {history_id}." + )); + } + + if response.total_lines == 0 { + return ToolOutcome::Success(format!( + "Captured output for history ID {history_id} is empty." + )); + } + + let output = format_output_lines_for_llm(&response.lines); + if output.is_empty() { + return ToolOutcome::Success(format!( + "No lines selected from captured output for history ID {history_id}." + )); + } + + let total_output = if response.output_truncated { + format!( + "{} bytes captured, {} bytes observed before truncation, {} lines", + response.total_bytes, response.output_observed_bytes, response.total_lines + ) + } else { + format!( + "{} bytes, {} lines", + response.total_bytes, response.total_lines + ) + }; + + ToolOutcome::Success(format!( + "History ID: {history_id}\nTotal output: {total_output}\nSelected output:\n{output}" + )) + } +} + #[derive(Debug, Clone)] pub(crate) struct LoadSkillToolCall { pub name: String, @@ -1239,7 +1354,7 @@ impl TryFrom<&serde_json::Value> for LoadSkillToolCall { } } -impl PermissableToolCall for LoadSkillToolCall { +impl PermissibleToolCall for LoadSkillToolCall { fn target_dir(&self) -> Option<&Path> { None } @@ -1285,6 +1400,52 @@ mod tests { // ── Cross-platform tests ── + #[test] + fn atuin_output_ranges_are_optional() { + let input = serde_json::json!({ + "history_id": "018f0000000070008000000000000000" + }); + + let call = AtuinOutputToolCall::try_from(&input).unwrap(); + + assert_eq!( + call.history_id.as_simple().to_string(), + "018f0000000070008000000000000000" + ); + assert!(call.ranges.is_empty()); + } + + #[test] + fn atuin_output_parses_line_ranges() { + let input = serde_json::json!({ + "history_id": "018f0000000070008000000000000000", + "ranges": [[0, 30], [-100, -1]] + }); + + let call = AtuinOutputToolCall::try_from(&input).unwrap(); + + assert_eq!(call.ranges, vec![(0, 30), (-100, -1)]); + } + + #[test] + fn atuin_output_formats_lines_like_read_file() { + let lines = vec![ + atuin_daemon::semantic::OutputLine { + line_number: 98, + content: "near end".to_string(), + }, + atuin_daemon::semantic::OutputLine { + line_number: 100, + content: "end".to_string(), + }, + ]; + + assert_eq!( + format_output_lines_for_llm(&lines), + " 98\tnear end\n[...skipped 1 lines...]\n100\tend" + ); + } + #[test] fn no_scope_matches_everything() { assert!(read_tool("any/path.txt").matches_rule(&read_rule(None))); @@ -1996,31 +2157,3 @@ mod tests { } } } - -fn format_duration(nanos: i64) -> String { - if nanos <= 0 { - return String::new(); - } - - let total_secs = nanos / 1_000_000_000; - let millis = (nanos % 1_000_000_000) / 1_000_000; - - if total_secs >= 3600 { - let hours = total_secs / 3600; - let mins = (total_secs % 3600) / 60; - let secs = total_secs % 60; - format!(", {hours}h{mins}m{secs}s") - } else if total_secs >= 60 { - let mins = total_secs / 60; - let secs = total_secs % 60; - format!(", {mins}m{secs}s") - } else if total_secs > 0 { - if millis > 0 { - format!(", {total_secs}.{millis:03}s") - } else { - format!(", {total_secs}s") - } - } else { - format!(", {millis}ms") - } -} diff --git a/crates/atuin-ai/src/tui/view/mod.rs b/crates/atuin-ai/src/tui/view/mod.rs index 73dc2ad7..b594cedf 100644 --- a/crates/atuin-ai/src/tui/view/mod.rs +++ b/crates/atuin-ai/src/tui/view/mod.rs @@ -168,6 +168,7 @@ fn tool_call_view(tool_call: &crate::fsm::tools::TrackedTool, in_git_project: bo ClientToolCall::Write(tool) => tool.path.display().to_string(), ClientToolCall::Shell(tool) => tool.command.clone(), ClientToolCall::AtuinHistory(tool) => tool.query.clone(), + ClientToolCall::AtuinOutput(tool) => tool.history_id.to_string(), ClientToolCall::LoadSkill(tool) => format!("skill: {}", tool.name), }; diff --git a/crates/atuin-ai/src/tui/view/turn.rs b/crates/atuin-ai/src/tui/view/turn.rs index c74395b8..aa1f55fa 100644 --- a/crates/atuin-ai/src/tui/view/turn.rs +++ b/crates/atuin-ai/src/tui/view/turn.rs @@ -495,6 +495,7 @@ impl<'a> TurnBuilder<'a> { query: history.query.clone(), filter_modes: history.filter_modes.clone(), }, + ClientToolCall::AtuinOutput(_) => ToolRenderData::Remote, ClientToolCall::LoadSkill(skill) => ToolRenderData::SkillLoad { _name: skill.name.clone(), }, diff --git a/crates/atuin-client/src/settings.rs b/crates/atuin-client/src/settings.rs index 4df404c4..1be6f363 100644 --- a/crates/atuin-client/src/settings.rs +++ b/crates/atuin-client/src/settings.rs @@ -687,6 +687,8 @@ pub struct Ai { pub struct AiCapabilities { /// Whether the AI can request to search Atuin history. `None` = unset (defaults to enabled, and the ai will ask for permission). pub enable_history_search: Option, + /// Whether the AI can request to view the stored output, if any, for Atuin history entries. `None` = unset (defaults to enabled, and the ai will ask for permission). + pub enable_history_output: Option, /// Whether the AI can request to read and write files. `None` = unset (defaults to enabled, and the ai will ask for permission). pub enable_file_tools: Option, /// Whether the AI can request to execute bash commands. `None` = unset (defaults to enabled, and the ai will ask for permission). diff --git a/crates/atuin-daemon/build.rs b/crates/atuin-daemon/build.rs index 7034aa04..7808a07b 100644 --- a/crates/atuin-daemon/build.rs +++ b/crates/atuin-daemon/build.rs @@ -7,6 +7,7 @@ fn main() -> std::io::Result<()> { "proto/history.proto", "proto/search.proto", "proto/control.proto", + "proto/semantic.proto", ]; let proto_include_dirs = ["proto"]; diff --git a/crates/atuin-daemon/proto/semantic.proto b/crates/atuin-daemon/proto/semantic.proto new file mode 100644 index 00000000..07e550c8 --- /dev/null +++ b/crates/atuin-daemon/proto/semantic.proto @@ -0,0 +1,47 @@ +syntax = "proto3"; +package semantic; + +service Semantic { + rpc RecordCommands(stream CommandCapture) returns (RecordCommandsReply); + rpc CommandOutput(CommandOutputRequest) returns (CommandOutputReply); +} + +message CommandCapture { + string prompt = 1; + string command = 2; + string output = 3; + optional int32 exit_code = 4; + optional string history_id = 5; + optional string session_id = 6; + bool output_truncated = 7; + uint64 output_observed_bytes = 8; +} + +message RecordCommandsReply { + uint64 accepted = 1; +} + +message CommandOutputRequest { + string history_id = 1; + repeated OutputRange ranges = 2; +} + +message OutputRange { + int64 start = 1; + int64 end = 2; +} + +message OutputLine { + uint64 line_number = 1; + string content = 2; +} + +message CommandOutputReply { + bool found = 1; + string output = 2; + uint64 total_bytes = 3; + uint64 total_lines = 4; + repeated OutputLine lines = 5; + bool output_truncated = 6; + uint64 output_observed_bytes = 7; +} diff --git a/crates/atuin-daemon/src/client.rs b/crates/atuin-daemon/src/client.rs index 5f4ce20f..c18e0e46 100644 --- a/crates/atuin-daemon/src/client.rs +++ b/crates/atuin-daemon/src/client.rs @@ -30,6 +30,10 @@ use crate::search::{ FilterMode as RpcFilterMode, SearchContext as RpcSearchContext, SearchRequest, SearchResponse, search_client::SearchClient as SearchServiceClient, }; +use crate::semantic::{ + CommandCapture, CommandOutputReply, CommandOutputRequest, OutputRange, RecordCommandsReply, + semantic_client::SemanticClient as SemanticServiceClient, +}; pub struct HistoryClient { client: HistoryServiceClient, @@ -256,6 +260,92 @@ impl From for RpcSearchContext { } } +pub struct SemanticClient { + client: SemanticServiceClient, +} + +impl SemanticClient { + #[cfg(unix)] + pub async fn new(path: String) -> Result { + let log_path = path.clone(); + let channel = Endpoint::try_from("http://atuin_local_daemon:0")? + .connect_with_connector(service_fn(move |_: Uri| { + let path = path.clone(); + + async move { + Ok::<_, std::io::Error>(TokioIo::new(UnixStream::connect(path.clone()).await?)) + } + })) + .await + .wrap_err_with(|| { + format!( + "failed to connect to local atuin daemon at {}. Is it running?", + &log_path + ) + })?; + + let client = SemanticServiceClient::new(channel); + + Ok(SemanticClient { client }) + } + + #[cfg(not(unix))] + pub async fn new(port: u64) -> Result { + let channel = Endpoint::try_from("http://atuin_local_daemon:0")? + .connect_with_connector(service_fn(move |_: Uri| { + let url = format!("127.0.0.1:{port}"); + + async move { + Ok::<_, std::io::Error>(TokioIo::new(TcpStream::connect(url.clone()).await?)) + } + })) + .await + .wrap_err_with(|| { + format!( + "failed to connect to local atuin daemon at 127.0.0.1:{port}. Is it running?" + ) + })?; + + let client = SemanticServiceClient::new(channel); + + Ok(SemanticClient { client }) + } + + #[cfg(unix)] + pub async fn from_settings(settings: &Settings) -> Result { + Self::new(settings.daemon.socket_path.clone()).await + } + + #[cfg(not(unix))] + pub async fn from_settings(settings: &Settings) -> Result { + Self::new(settings.daemon.tcp_port).await + } + + pub async fn record_commands( + &mut self, + captures: Vec, + ) -> Result { + let stream = tokio_stream::iter(captures); + Ok(self.client.record_commands(stream).await?.into_inner()) + } + + pub async fn command_output( + &mut self, + history_id: String, + ranges: Vec<(i64, i64)>, + ) -> Result { + let request = CommandOutputRequest { + history_id, + ranges: ranges + .into_iter() + .map(|(start, end)| OutputRange { start, end }) + .collect(), + }; + + Ok(self.client.command_output(request).await?.into_inner()) + } +} + // ============================================================================ // Control Client // ============================================================================ diff --git a/crates/atuin-daemon/src/components/mod.rs b/crates/atuin-daemon/src/components/mod.rs index 5950d5d5..447e31df 100644 --- a/crates/atuin-daemon/src/components/mod.rs +++ b/crates/atuin-daemon/src/components/mod.rs @@ -11,12 +11,15 @@ //! //! - [`history::HistoryComponent`]: Command history lifecycle management //! - [`search::SearchComponent`]: Fuzzy search over history +//! - [`semantic::SemanticComponent`]: In-memory semantic command captures //! - [`sync::SyncComponent`]: Cloud sync pub mod history; pub mod search; +pub mod semantic; pub mod sync; pub use history::HistoryComponent; pub use search::SearchComponent; +pub use semantic::SemanticComponent; pub use sync::SyncComponent; diff --git a/crates/atuin-daemon/src/components/semantic.rs b/crates/atuin-daemon/src/components/semantic.rs new file mode 100644 index 00000000..dff38fd3 --- /dev/null +++ b/crates/atuin-daemon/src/components/semantic.rs @@ -0,0 +1,900 @@ +//! Semantic command capture component. +//! +//! This is a prototype in-memory store for completed command captures emitted +//! by atuin-pty-proxy. It keeps recent captures per Atuin session and indexes +//! them by history ID for AI tool lookup. + +use std::collections::{HashMap, VecDeque}; +use std::fmt::{Display, Formatter}; +use std::sync::Arc; + +use atuin_client::history::{History, HistoryId}; +use eyre::Result; +use tokio::sync::Mutex; +use tonic::{Request, Response, Status, Streaming}; +use tracing::{Level, instrument}; + +use crate::{ + daemon::{Component, DaemonHandle}, + events::DaemonEvent, + semantic::{ + CommandCapture, CommandOutputReply, CommandOutputRequest, OutputLine, RecordCommandsReply, + semantic_server::{Semantic as SemanticSvc, SemanticServer}, + }, +}; + +const MAX_SESSIONS: usize = 20; +const MAX_COMMANDS_PER_SESSION: usize = 128; +const MAX_BYTES_PER_SESSION: usize = 32 * 1024 * 1024; +const MAX_PENDING_HISTORIES: usize = 128; + +/// Stores completed command captures and associates them with history events. +pub struct SemanticComponent { + inner: Arc, +} + +struct SemanticComponentInner { + state: Mutex, +} + +#[derive(Default)] +struct SemanticState { + sessions: HashMap, + session_lru: VecDeque, + history_index: HashMap, + pending_histories: VecDeque, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct SessionId(String); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +struct CaptureId(u64); + +#[derive(Debug, Clone, PartialEq, Eq)] +struct CaptureRef { + session_id: SessionId, + capture_id: CaptureId, +} + +#[derive(Default)] +struct SessionCaptures { + next_id: u64, + records: VecDeque, + output_bytes: usize, +} + +struct StoredCapture { + id: CaptureId, + history_id: HistoryId, + output_bytes: usize, + record: SemanticCommandRecord, +} + +struct EvictedCapture { + history_id: HistoryId, + capture_id: CaptureId, +} + +#[derive(Debug, Clone)] +struct SemanticCommandRecord { + capture: CommandCapture, + history: Option, +} + +impl SemanticComponent { + pub fn new() -> Self { + Self { + inner: Arc::new(SemanticComponentInner { + state: Mutex::new(SemanticState::default()), + }), + } + } + + pub fn grpc_service(&self) -> SemanticServer { + SemanticServer::new(SemanticGrpcService { + inner: self.inner.clone(), + }) + } +} + +impl Default for SemanticComponent { + fn default() -> Self { + Self::new() + } +} + +#[tonic::async_trait] +impl Component for SemanticComponent { + fn name(&self) -> &'static str { + "semantic" + } + + async fn start(&mut self, _handle: DaemonHandle) -> Result<()> { + tracing::info!("semantic component started"); + Ok(()) + } + + async fn handle_event(&mut self, event: &DaemonEvent) -> Result<()> { + if let DaemonEvent::HistoryEnded(history) = event { + self.inner.record_history(history.clone()).await; + } + + Ok(()) + } + + async fn stop(&mut self) -> Result<()> { + let state = self.inner.state.lock().await; + tracing::info!( + sessions = state.sessions.len(), + records = state.record_count(), + indexed_histories = state.history_index.len(), + pending_histories = state.pending_histories.len(), + "semantic component stopped" + ); + Ok(()) + } +} + +impl SemanticComponentInner { + async fn record_capture(&self, capture: CommandCapture) -> bool { + let mut state = self.state.lock().await; + state.record_capture(capture) + } + + async fn record_history(&self, history: History) { + let mut state = self.state.lock().await; + state.record_history(history); + } + + async fn command_output(&self, request: &CommandOutputRequest) -> CommandOutputReply { + let mut state = self.state.lock().await; + state.command_output(request) + } +} + +impl SemanticState { + fn record_capture(&mut self, mut capture: CommandCapture) -> bool { + let Some(history_id) = history_id_from_str(capture.history_id.as_deref()) else { + tracing::debug!( + command_bytes = capture.command.len(), + prompt_bytes = capture.prompt.len(), + output_bytes = capture.output.len(), + output_truncated = capture.output_truncated, + "dropping semantic command capture without history id" + ); + return false; + }; + + let history = take_pending_history(&mut self.pending_histories, &history_id); + let Some(session_id) = capture + .session_id + .as_deref() + .and_then(|session_id| SessionId::try_from(session_id).ok()) + .or_else(|| { + history + .as_ref() + .and_then(|history| SessionId::try_from(history.session.as_str()).ok()) + }) + else { + tracing::debug!( + history_id = %history_id, + command_bytes = capture.command.len(), + prompt_bytes = capture.prompt.len(), + output_bytes = capture.output.len(), + output_truncated = capture.output_truncated, + "dropping semantic command capture without session id" + ); + return false; + }; + + capture.history_id = Some(history_id.to_string()); + capture.session_id = Some(session_id.to_string()); + if capture.output_observed_bytes == 0 { + capture.output_observed_bytes = capture.output.len() as u64; + } + + let record = SemanticCommandRecord { capture, history }; + log_record(&record, "recorded semantic command capture"); + self.push_record(session_id, history_id, record); + true + } + + fn record_history(&mut self, history: History) { + let history_id = history.id.clone(); + + if let Some(capture_ref) = self.history_index.get(&history_id).cloned() { + if let Some(stored) = self.stored_capture_mut(&capture_ref) { + stored.record.history = Some(history); + log_record( + &stored.record, + "associated semantic command capture with history", + ); + return; + } + + self.history_index.remove(&history_id); + } + + tracing::debug!( + id = %history.id, + command_bytes = history.command.len(), + "history ended before semantic capture arrived" + ); + push_pending_history(&mut self.pending_histories, history); + } + + fn command_output(&mut self, request: &CommandOutputRequest) -> CommandOutputReply { + let Some(history_id) = history_id_from_str(Some(&request.history_id)) else { + return command_output_not_found(); + }; + let Some(capture_ref) = self.history_index.get(&history_id).cloned() else { + return command_output_not_found(); + }; + + let Some(reply) = self.command_output_for_ref(&capture_ref, &request.ranges) else { + self.history_index.remove(&history_id); + return command_output_not_found(); + }; + + self.touch_session(&capture_ref.session_id); + reply + } + + fn command_output_for_ref( + &self, + capture_ref: &CaptureRef, + ranges: &[crate::semantic::OutputRange], + ) -> Option { + let stored = self + .sessions + .get(&capture_ref.session_id)? + .stored_capture(capture_ref.capture_id)?; + let output = &stored.record.capture.output; + let output_observed_bytes = stored + .record + .capture + .output_observed_bytes + .max(output.len() as u64); + + Some(CommandOutputReply { + found: true, + output: String::new(), + total_bytes: output.len() as u64, + total_lines: output.lines().count() as u64, + lines: select_output_ranges(output, ranges), + output_truncated: stored.record.capture.output_truncated, + output_observed_bytes, + }) + } + + fn push_record( + &mut self, + session_id: SessionId, + history_id: HistoryId, + record: SemanticCommandRecord, + ) { + self.touch_session(&session_id); + + let (capture_id, evicted) = { + let session = self.sessions.entry(session_id.clone()).or_default(); + session.push(history_id.clone(), record) + }; + + let capture_ref = CaptureRef { + session_id: session_id.clone(), + capture_id, + }; + self.history_index.insert(history_id, capture_ref); + + for evicted in evicted { + self.remove_history_index_if_matches( + &session_id, + &evicted.history_id, + evicted.capture_id, + ); + } + + self.expire_lru_sessions(); + } + + fn touch_session(&mut self, session_id: &SessionId) { + if let Some(index) = self.session_lru.iter().position(|id| id == session_id) { + self.session_lru.remove(index); + } + self.session_lru.push_back(session_id.clone()); + } + + fn expire_lru_sessions(&mut self) { + while self.session_lru.len() > MAX_SESSIONS { + let Some(session_id) = self.session_lru.pop_front() else { + break; + }; + let Some(session) = self.sessions.remove(&session_id) else { + continue; + }; + + for stored in session.records { + self.remove_history_index_if_matches(&session_id, &stored.history_id, stored.id); + } + } + } + + fn remove_history_index_if_matches( + &mut self, + session_id: &SessionId, + history_id: &HistoryId, + capture_id: CaptureId, + ) { + if self + .history_index + .get(history_id) + .is_some_and(|capture_ref| { + &capture_ref.session_id == session_id && capture_ref.capture_id == capture_id + }) + { + self.history_index.remove(history_id); + } + } + + fn stored_capture_mut(&mut self, capture_ref: &CaptureRef) -> Option<&mut StoredCapture> { + self.sessions + .get_mut(&capture_ref.session_id)? + .stored_capture_mut(capture_ref.capture_id) + } + + fn record_count(&self) -> usize { + self.sessions + .values() + .map(|session| session.records.len()) + .sum() + } +} + +impl SessionCaptures { + fn push( + &mut self, + history_id: HistoryId, + record: SemanticCommandRecord, + ) -> (CaptureId, Vec) { + self.push_with_limits( + history_id, + record, + MAX_COMMANDS_PER_SESSION, + MAX_BYTES_PER_SESSION, + ) + } + + fn push_with_limits( + &mut self, + history_id: HistoryId, + record: SemanticCommandRecord, + max_commands: usize, + max_output_bytes: usize, + ) -> (CaptureId, Vec) { + let capture_id = CaptureId(self.next_id); + self.next_id = self.next_id.saturating_add(1); + let output_bytes = record.capture.output.len(); + self.output_bytes = self.output_bytes.saturating_add(output_bytes); + self.records.push_back(StoredCapture { + id: capture_id, + history_id, + output_bytes, + record, + }); + + ( + capture_id, + self.evict_to_limits(max_commands, max_output_bytes), + ) + } + + fn evict_to_limits( + &mut self, + max_commands: usize, + max_output_bytes: usize, + ) -> Vec { + let mut evicted = Vec::new(); + while self.records.len() > max_commands || self.output_bytes > max_output_bytes { + let Some(record) = self.records.pop_front() else { + break; + }; + self.output_bytes = self.output_bytes.saturating_sub(record.output_bytes); + evicted.push(EvictedCapture { + history_id: record.history_id, + capture_id: record.id, + }); + } + evicted + } + + fn stored_capture(&self, capture_id: CaptureId) -> Option<&StoredCapture> { + self.records.iter().find(|record| record.id == capture_id) + } + + fn stored_capture_mut(&mut self, capture_id: CaptureId) -> Option<&mut StoredCapture> { + self.records + .iter_mut() + .find(|record| record.id == capture_id) + } +} + +impl TryFrom<&str> for SessionId { + type Error = (); + + fn try_from(value: &str) -> std::result::Result { + let value = value.trim(); + if value.is_empty() { + return Err(()); + } + + Ok(Self(value.to_string())) + } +} + +impl TryFrom for SessionId { + type Error = (); + + fn try_from(value: String) -> std::result::Result { + Self::try_from(value.as_str()) + } +} + +impl AsRef for SessionId { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl Display for SessionId { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.0) + } +} + +pub struct SemanticGrpcService { + inner: Arc, +} + +#[tonic::async_trait] +impl SemanticSvc for SemanticGrpcService { + #[instrument(skip_all, level = Level::INFO)] + async fn record_commands( + &self, + request: Request>, + ) -> Result, Status> { + let mut stream = request.into_inner(); + let mut accepted = 0_u64; + + while let Some(capture) = stream.message().await? { + if self.inner.record_capture(capture).await { + accepted += 1; + } + } + + Ok(Response::new(RecordCommandsReply { accepted })) + } + + #[instrument(skip_all, level = Level::INFO)] + async fn command_output( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + if request.history_id.trim().is_empty() { + return Err(Status::invalid_argument("history_id is required")); + } + + Ok(Response::new(self.inner.command_output(&request).await)) + } +} + +fn history_id_from_str(value: Option<&str>) -> Option { + let value = value?.trim(); + (!value.is_empty()).then(|| HistoryId(value.to_string())) +} + +fn take_pending_history( + histories: &mut VecDeque, + history_id: &HistoryId, +) -> Option { + let index = histories + .iter() + .position(|history| &history.id == history_id)?; + histories.remove(index) +} + +fn push_pending_history(histories: &mut VecDeque, history: History) { + if let Some(index) = histories + .iter() + .position(|pending| pending.id == history.id) + { + histories.remove(index); + } + + histories.push_back(history); + trim_front(histories, MAX_PENDING_HISTORIES); +} + +fn trim_front(records: &mut VecDeque, max_len: usize) { + while records.len() > max_len { + records.pop_front(); + } +} + +fn command_output_not_found() -> CommandOutputReply { + CommandOutputReply { + found: false, + output: String::new(), + total_bytes: 0, + total_lines: 0, + lines: Vec::new(), + output_truncated: false, + output_observed_bytes: 0, + } +} + +fn select_output_ranges(output: &str, ranges: &[crate::semantic::OutputRange]) -> Vec { + let lines: Vec<&str> = output.lines().collect(); + if lines.is_empty() { + return Vec::new(); + } + + let ranges = if ranges.is_empty() { + vec![crate::semantic::OutputRange { start: 0, end: 999 }] + } else { + ranges.to_vec() + }; + + let mut ranges = ranges + .into_iter() + .filter_map(|range| normalize_line_range(range.start, range.end, lines.len())) + .collect::>(); + ranges.sort_unstable_by_key(|(start, _)| *start); + + let mut merged: Vec<(usize, usize)> = Vec::new(); + for (start, end) in ranges { + match merged.last_mut() { + Some((_, merged_end)) if start <= merged_end.saturating_add(1) => { + *merged_end = (*merged_end).max(end); + } + _ => merged.push((start, end)), + } + } + + merged + .into_iter() + .flat_map(|(start, end)| { + lines[start..=end] + .iter() + .enumerate() + .map(move |(offset, line)| OutputLine { + line_number: (start + offset + 1) as u64, + content: (*line).to_string(), + }) + }) + .collect() +} + +fn normalize_line_range(start: i64, end: i64, line_count: usize) -> Option<(usize, usize)> { + let line_count = i64::try_from(line_count).ok()?; + let start = if start < 0 { line_count + start } else { start }; + let end = if end < 0 { line_count + end } else { end }; + + if end < 0 || start >= line_count { + return None; + } + + let start = start.max(0); + let end = end.min(line_count - 1); + + (start <= end).then_some((start as usize, end as usize)) +} + +fn log_record(record: &SemanticCommandRecord, message: &'static str) { + let history_id = record.capture.history_id.as_deref().unwrap_or(""); + let associated_history_id = record + .history + .as_ref() + .map(|history| history.id.to_string()); + let exit = record.history.as_ref().map(|history| history.exit); + let duration = record.history.as_ref().map(|history| history.duration); + let author = record + .history + .as_ref() + .map(|history| history.author.as_str()); + let session_id = record.capture.session_id.as_deref(); + + tracing::debug!( + history_id = %history_id, + associated_history_id = ?associated_history_id, + session_id = ?session_id, + command_bytes = record.capture.command.len(), + prompt_bytes = record.capture.prompt.len(), + output_bytes = record.capture.output.len(), + output_truncated = record.capture.output_truncated, + output_observed_bytes = record.capture.output_observed_bytes, + capture_exit_code = ?record.capture.exit_code, + history_exit = ?exit, + duration = ?duration, + author = ?author, + "{message}" + ); +} + +#[cfg(test)] +mod tests { + use super::*; + use time::OffsetDateTime; + + fn history(id: &str, session: &str, command: &str) -> History { + History { + id: HistoryId(id.to_string()), + timestamp: OffsetDateTime::UNIX_EPOCH, + duration: 0, + exit: 0, + command: command.to_string(), + cwd: String::new(), + session: session.to_string(), + hostname: String::new(), + author: String::new(), + intent: None, + deleted_at: None, + } + } + + fn capture(history_id: Option<&str>, session_id: Option<&str>, output: &str) -> CommandCapture { + CommandCapture { + prompt: String::new(), + command: String::new(), + output: output.to_string(), + exit_code: None, + history_id: history_id.map(str::to_string), + session_id: session_id.map(str::to_string), + output_truncated: false, + output_observed_bytes: output.len() as u64, + } + } + + fn command_output(state: &mut SemanticState, history_id: &str) -> CommandOutputReply { + state.command_output(&CommandOutputRequest { + history_id: history_id.to_string(), + ranges: Vec::new(), + }) + } + + fn output_line(line_number: u64, content: &str) -> OutputLine { + OutputLine { + line_number, + content: content.to_string(), + } + } + + #[test] + fn drops_capture_without_history_id() { + let mut state = SemanticState::default(); + + assert!(!state.record_capture(capture(None, Some("session-1"), "output"))); + assert!(!command_output(&mut state, "id-1").found); + assert_eq!(state.record_count(), 0); + } + + #[test] + fn stores_capture_by_session_and_history_id() { + let mut state = SemanticState::default(); + + assert!(state.record_capture(capture(Some("id-1"), Some("session-1"), "output"))); + + let reply = command_output(&mut state, "id-1"); + assert!(reply.found); + assert_eq!(reply.total_bytes, 6); + assert_eq!(reply.output_observed_bytes, 6); + assert_eq!(reply.lines, vec![output_line(1, "output")]); + } + + #[test] + fn uses_pending_history_session_when_capture_session_is_missing() { + let mut state = SemanticState::default(); + + state.record_history(history("id-1", "session-from-history", "cargo test")); + assert!(state.record_capture(capture(Some("id-1"), None, "output"))); + + assert!( + state + .sessions + .contains_key(&SessionId("session-from-history".to_string())) + ); + assert!(command_output(&mut state, "id-1").found); + } + + #[test] + fn associates_history_by_id_after_capture_arrives() { + let mut state = SemanticState::default(); + + assert!(state.record_capture(capture(Some("id-1"), Some("session-1"), "output"))); + state.record_history(history("id-1", "session-1", "different command")); + + let capture_ref = state + .history_index + .get(&HistoryId("id-1".to_string())) + .unwrap(); + let stored = state + .sessions + .get(&capture_ref.session_id) + .unwrap() + .stored_capture(capture_ref.capture_id) + .unwrap(); + assert!(stored.record.history.is_some()); + } + + #[test] + fn evicts_oldest_command_when_session_ring_is_full() { + let mut state = SemanticState::default(); + + for index in 0..=MAX_COMMANDS_PER_SESSION { + assert!(state.record_capture(capture( + Some(&format!("id-{index}")), + Some("session-1"), + "output", + ))); + } + + assert!(!command_output(&mut state, "id-0").found); + assert!(command_output(&mut state, &format!("id-{MAX_COMMANDS_PER_SESSION}")).found); + assert_eq!(state.record_count(), MAX_COMMANDS_PER_SESSION); + } + + #[test] + fn evicts_oldest_session_after_lru_limit() { + let mut state = SemanticState::default(); + + for index in 0..MAX_SESSIONS { + assert!(state.record_capture(capture( + Some(&format!("id-{index}")), + Some(&format!("session-{index}")), + "output", + ))); + } + assert!(command_output(&mut state, "id-0").found); + + assert!(state.record_capture(capture(Some("new-id"), Some("new-session"), "output",))); + + assert!(command_output(&mut state, "id-0").found); + assert!(!command_output(&mut state, "id-1").found); + assert!(command_output(&mut state, "new-id").found); + assert_eq!(state.sessions.len(), MAX_SESSIONS); + } + + #[test] + fn evicts_by_session_byte_limit() { + let mut session = SessionCaptures::default(); + let first_output = "x".repeat(10); + let second_output = "y"; + let (_, evicted_first) = session.push_with_limits( + HistoryId("first".to_string()), + SemanticCommandRecord { + capture: capture(Some("first"), Some("session-1"), &first_output), + history: None, + }, + MAX_COMMANDS_PER_SESSION, + 10, + ); + assert!(evicted_first.is_empty()); + + let (_, evicted_second) = session.push_with_limits( + HistoryId("second".to_string()), + SemanticCommandRecord { + capture: capture(Some("second"), Some("session-1"), second_output), + history: None, + }, + MAX_COMMANDS_PER_SESSION, + 10, + ); + + assert_eq!(evicted_second.len(), 1); + assert_eq!(evicted_second[0].history_id, HistoryId("first".to_string())); + assert_eq!(session.records.len(), 1); + assert_eq!(session.output_bytes, 1); + } + + #[test] + fn command_output_reports_truncation_metadata() { + let mut state = SemanticState::default(); + let mut capture = capture(Some("id-1"), Some("session-1"), "partial"); + capture.output_truncated = true; + capture.output_observed_bytes = 1024; + + assert!(state.record_capture(capture)); + + let reply = command_output(&mut state, "id-1"); + assert!(reply.output_truncated); + assert_eq!(reply.total_bytes, 7); + assert_eq!(reply.output_observed_bytes, 1024); + } + + #[test] + fn output_ranges_are_line_based_inclusive_and_support_negative_offsets() { + let output = "zero\none\ntwo\nthree\nfour"; + let ranges = vec![ + crate::semantic::OutputRange { start: 1, end: 2 }, + crate::semantic::OutputRange { start: -2, end: -1 }, + ]; + + assert_eq!( + select_output_ranges(output, &ranges), + vec![ + output_line(2, "one"), + output_line(3, "two"), + output_line(4, "three"), + output_line(5, "four"), + ] + ); + } + + #[test] + fn output_ranges_merge_overlaps_and_adjacent_ranges() { + let output = (0..100) + .map(|n| format!("line {n}")) + .collect::>() + .join("\n"); + let ranges = vec![ + crate::semantic::OutputRange { start: 0, end: 100 }, + crate::semantic::OutputRange { + start: -100, + end: -1, + }, + ]; + + let selected = select_output_ranges(&output, &ranges); + + assert_eq!(selected.len(), 100); + assert_eq!(selected.first(), Some(&output_line(1, "line 0"))); + assert_eq!(selected.last(), Some(&output_line(100, "line 99"))); + } + + #[test] + fn output_ranges_can_leave_gaps_for_client_formatting() { + let output = "zero\none\ntwo\nthree\nfour"; + let ranges = vec![ + crate::semantic::OutputRange { start: 0, end: 1 }, + crate::semantic::OutputRange { start: 4, end: 4 }, + ]; + + assert_eq!( + select_output_ranges(output, &ranges), + vec![ + output_line(1, "zero"), + output_line(2, "one"), + output_line(5, "four"), + ] + ); + } + + #[test] + fn empty_output_ranges_default_to_first_thousand_lines() { + let output = (0..1001) + .map(|n| format!("line {n}")) + .collect::>() + .join("\n"); + + let selected = select_output_ranges(&output, &[]); + + assert_eq!(selected.len(), 1000); + assert_eq!(selected.first(), Some(&output_line(1, "line 0"))); + assert_eq!(selected.last(), Some(&output_line(1000, "line 999"))); + } + + #[test] + fn output_ranges_skip_ranges_fully_outside_output() { + let output = "zero\none\ntwo"; + let ranges = vec![ + crate::semantic::OutputRange { start: 10, end: 20 }, + crate::semantic::OutputRange { + start: -20, + end: -10, + }, + ]; + + assert_eq!(select_output_ranges(output, &ranges), Vec::new()); + } +} diff --git a/crates/atuin-daemon/src/lib.rs b/crates/atuin-daemon/src/lib.rs index 84f808e4..27d3932b 100644 --- a/crates/atuin-daemon/src/lib.rs +++ b/crates/atuin-daemon/src/lib.rs @@ -10,6 +10,7 @@ pub mod daemon; pub mod events; pub mod history; pub mod search; +pub mod semantic; pub mod server; // Re-export core daemon types for convenience @@ -17,10 +18,10 @@ pub use daemon::{Component, Daemon, DaemonBuilder, DaemonHandle}; pub use events::DaemonEvent; // Re-export components -pub use components::{HistoryComponent, SearchComponent, SyncComponent}; +pub use components::{HistoryComponent, SearchComponent, SemanticComponent, SyncComponent}; // Re-export client helpers -pub use client::{ControlClient, emit_event, emit_event_with_settings}; +pub use client::{ControlClient, SemanticClient, emit_event, emit_event_with_settings}; /// Boot the daemon using the new component-based architecture. /// @@ -34,12 +35,14 @@ pub async fn boot( // Create the components let history_component = HistoryComponent::new(); let search_component = SearchComponent::new(); + let semantic_component = SemanticComponent::new(); let sync_component = SyncComponent::new(); // Get the gRPC services before moving components into the daemon // (The services share state with the components via Arc) let history_service = history_component.grpc_service(); let search_service = search_component.grpc_service(); + let semantic_service = semantic_component.grpc_service(); // Build the daemon let mut daemon = Daemon::builder(settings.clone()) @@ -47,6 +50,7 @@ pub async fn boot( .history_db(history_db) .component(history_component) .component(search_component) + .component(semantic_component) .component(sync_component) .build() .await?; @@ -93,6 +97,7 @@ pub async fn boot( settings, history_service, search_service, + semantic_service, control_service.into_server(), handle, ) diff --git a/crates/atuin-daemon/src/semantic/mod.rs b/crates/atuin-daemon/src/semantic/mod.rs new file mode 100644 index 00000000..c3511676 --- /dev/null +++ b/crates/atuin-daemon/src/semantic/mod.rs @@ -0,0 +1,3 @@ +//! Semantic command capture gRPC service types. + +tonic::include_proto!("semantic"); diff --git a/crates/atuin-daemon/src/server.rs b/crates/atuin-daemon/src/server.rs index a11de612..b823cff2 100644 --- a/crates/atuin-daemon/src/server.rs +++ b/crates/atuin-daemon/src/server.rs @@ -2,10 +2,12 @@ use eyre::Result; use crate::components::history::HistoryGrpcService; use crate::components::search::SearchGrpcService; +use crate::components::semantic::SemanticGrpcService; use crate::control::{ControlService, control_server::ControlServer}; use crate::daemon::DaemonHandle; use crate::history::history_server::HistoryServer; use crate::search::search_server::SearchServer; +use crate::semantic::semantic_server::SemanticServer; use atuin_client::settings::Settings; @@ -18,6 +20,7 @@ pub async fn run_grpc_server( settings: Settings, history_service: HistoryServer, search_service: SearchServer, + semantic_service: SemanticServer, control_service: ControlServer, handle: DaemonHandle, ) -> Result<()> { @@ -101,6 +104,7 @@ pub async fn run_grpc_server( if let Err(e) = Server::builder() .add_service(history_service) .add_service(search_service) + .add_service(semantic_service) .add_service(control_service) .serve_with_incoming_shutdown(uds_stream, shutdown_signal) .await @@ -118,6 +122,7 @@ pub async fn run_grpc_server( settings: Settings, history_service: HistoryServer, search_service: SearchServer, + semantic_service: SemanticServer, control_service: ControlServer, handle: DaemonHandle, ) -> Result<()> { @@ -152,6 +157,7 @@ pub async fn run_grpc_server( if let Err(e) = Server::builder() .add_service(history_service) .add_service(search_service) + .add_service(semantic_service) .add_service(control_service) .serve_with_incoming_shutdown(tcp_stream, shutdown_signal) .await diff --git a/crates/atuin-pty-proxy/src/capture.rs b/crates/atuin-pty-proxy/src/capture.rs new file mode 100644 index 00000000..6426035b --- /dev/null +++ b/crates/atuin-pty-proxy/src/capture.rs @@ -0,0 +1,467 @@ +use std::sync::Arc; +use std::sync::atomic::{AtomicU16, Ordering}; + +use crate::osc133::{Event, Params, Parser, Zone}; + +const HISTORY_ID_PARAM: &str = "history_id"; +const SESSION_ID_PARAM: &str = "session_id"; +const MAX_OUTPUT_CAPTURE_BYTES: usize = 1024 * 1024; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CommandCapture { + pub prompt: String, + pub command: String, + pub output: String, + pub exit_code: Option, + pub history_id: Option, + pub session_id: Option, + pub output_truncated: bool, + pub output_observed_bytes: u64, +} + +pub type CommandCaptureSink = Box; + +#[derive(Default)] +struct CaptureBuffers { + prompt: Vec, + command: Vec, + output: Vec, + output_observed_bytes: u64, + output_truncated: bool, + exit_code: Option, + history_id: Option, + session_id: Option, +} + +pub(crate) struct CommandCaptureTracker { + parser: Parser, + zone: Zone, + buffers: CaptureBuffers, + cols: Arc, +} + +impl CommandCaptureTracker { + pub(crate) fn new(cols: Arc) -> Self { + Self { + parser: Parser::new(), + zone: Zone::Unknown, + buffers: CaptureBuffers::default(), + cols, + } + } + + pub(crate) fn push(&mut self, data: &[u8], mut on_capture: impl FnMut(CommandCapture)) { + let mut events = Vec::new(); + self.parser + .push_located(data, |located| events.push(located)); + + let mut start = 0; + for located in events { + let marker_start = located.start_offset.min(data.len()).max(start); + let offset = located.offset.min(data.len()); + self.append(&data[start..marker_start]); + self.handle_event(located.event, &located.params, &mut on_capture); + self.zone = located.zone; + start = offset; + } + + let append_end = self + .parser + .incomplete_osc_sequence_start() + .map_or(data.len(), |sequence_start| { + sequence_start.min(data.len()).max(start) + }); + if start < append_end { + self.append(&data[start..append_end]); + } + } + + fn append(&mut self, data: &[u8]) { + match self.zone { + Zone::Prompt => self.buffers.prompt.extend_from_slice(data), + Zone::Input => self.buffers.command.extend_from_slice(data), + Zone::Output => self.append_output(data), + Zone::Unknown => {} + } + } + + fn append_output(&mut self, data: &[u8]) { + self.buffers.output_observed_bytes = self + .buffers + .output_observed_bytes + .saturating_add(data.len() as u64); + + if self.buffers.output_truncated { + return; + } + + let remaining = MAX_OUTPUT_CAPTURE_BYTES.saturating_sub(self.buffers.output.len()); + let retained = data.len().min(remaining); + self.buffers.output_truncated = retained < data.len(); + + if retained > 0 { + self.buffers.output.extend_from_slice(&data[..retained]); + } + } + + fn handle_event( + &mut self, + event: Event, + params: &Params, + on_capture: &mut impl FnMut(CommandCapture), + ) { + match event { + Event::PromptStart => { + if self.zone != Zone::Prompt { + self.buffers = CaptureBuffers::default(); + } + } + Event::CommandStart | Event::CommandExecuted => {} + Event::CommandFinished { exit_code } => { + let Some(history_id) = params.get(HISTORY_ID_PARAM).map(str::to_owned) else { + return; + }; + + if exit_code.is_some() || self.buffers.exit_code.is_none() { + self.buffers.exit_code = exit_code; + } + self.buffers.history_id = Some(history_id); + self.buffers.session_id = params.get(SESSION_ID_PARAM).map(str::to_owned); + + if let Some(capture) = self.finish_capture() { + on_capture(capture); + } + } + } + } + + fn finish_capture(&mut self) -> Option { + let buffers = std::mem::take(&mut self.buffers); + let cols = self.cols.load(Ordering::Relaxed).max(1); + let prompt = render_plain_text(&buffers.prompt, cols); + let command = render_plain_text(&buffers.command, cols) + .trim_matches(|c| c == '\r' || c == '\n') + .to_string(); + let output = render_plain_text(&buffers.output, cols); + let output_truncated = buffers.output_truncated; + let output_observed_bytes = buffers.output_observed_bytes; + let exit_code = buffers.exit_code; + let history_id = buffers.history_id; + let session_id = buffers.session_id; + + if command.is_empty() && output.is_empty() { + return None; + } + + Some(CommandCapture { + prompt, + command, + output, + exit_code, + history_id, + session_id, + output_truncated, + output_observed_bytes, + }) + } +} + +const CLEAN_TEXT_MAX_ROWS: usize = 10_000; + +fn render_plain_text(bytes: &[u8], cols: u16) -> String { + if bytes.is_empty() { + return String::new(); + } + + let cols = cols.max(1); + let mut parser = vt100::Parser::new(estimated_rows(bytes, cols), cols, 0); + parser.process(bytes); + normalize_screen_contents(&parser.screen().contents()) +} + +fn normalize_screen_contents(contents: &str) -> String { + let mut lines = contents.lines().map(str::trim_end).collect::>(); + while lines.last().is_some_and(|line| line.is_empty()) { + lines.pop(); + } + lines.join("\n") +} + +fn estimated_rows(bytes: &[u8], cols: u16) -> u16 { + let newline_rows = bytes.iter().filter(|byte| **byte == b'\n').count() + 1; + let wrapped_rows = bytes.len() / cols as usize; + newline_rows + .saturating_add(wrapped_rows) + .saturating_add(1) + .clamp(1, CLEAN_TEXT_MAX_ROWS) as u16 +} + +#[cfg(test)] +mod tests { + use super::*; + + fn tracker(cols: u16) -> CommandCaptureTracker { + CommandCaptureTracker::new(Arc::new(AtomicU16::new(cols))) + } + + fn assert_no_terminal_controls(text: &str) { + assert!( + !text + .chars() + .any(|ch| ch.is_control() && ch != '\n' && ch != '\t'), + "text still contains terminal controls: {text:?}" + ); + } + + #[test] + fn command_text_collapses_terminal_echo_edits() { + assert_eq!(render_plain_text(b"e\x08echo hi", 80), "echo hi"); + assert_eq!( + render_plain_text( + b"e\x08echo\x08 \x08\x08 \x08\x08\x08e \x08\x08 \x08e\x08echo hi", + 80 + ), + "echo hi" + ); + assert_eq!(render_plain_text(b"echo hi", 80), "echo hi"); + } + + #[test] + fn text_cleaning_strips_ansi_and_terminal_controls() { + let text = render_plain_text( + b"\x1b[32mhi\x1b[0m\r\n% \r \r", + 80, + ); + + assert_eq!(text, "hi"); + assert_no_terminal_controls(&text); + } + + #[test] + fn text_cleaning_preserves_valid_utf8_after_backspace() { + let text = render_plain_text("🦀x\x08 \x08 crab".as_bytes(), 80); + + assert_eq!(text, "🦀 crab"); + assert_no_terminal_controls(&text); + } + + #[test] + fn command_text_replays_backspaces() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + let input = + b"\x1b]133;A\x07$ \x1b]133;B\x07e\x08echo hi\r\n\x1b]133;C\x07hi\r\n\x1b]133;D;0;history_id=hist;session_id=sess\x07\x1b]133;A\x07$ "; + tracker.push(input, |capture| captures.push(capture)); + + assert_eq!(captures.len(), 1); + assert_eq!(captures[0].command, "echo hi"); + assert_eq!(captures[0].output, "hi"); + assert_no_terminal_controls(&captures[0].command); + assert_no_terminal_controls(&captures[0].output); + } + + #[test] + fn captures_complete_command() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;A\x07$ \x1b]133;B\x07echo hi\r\n\x1b]133;C\x07hi\r\n\x1b]133;D;0;history_id=hist;session_id=sess\x07\x1b]133;A\x07$ ", + |capture| captures.push(capture), + ); + + assert_eq!( + captures, + vec![CommandCapture { + prompt: "$".to_string(), + command: "echo hi".to_string(), + output: "hi".to_string(), + exit_code: Some(0), + history_id: Some("hist".to_string()), + session_id: Some("sess".to_string()), + output_truncated: false, + output_observed_bytes: 4, + }] + ); + } + + #[test] + fn strips_ansi_and_split_markers() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push(b"\x1b]133;A\x07\x1b[32m%\x1b[0m ", |_| {}); + tracker.push(b"\x1b]133;B\x07ls\x1b]133;C", |_| {}); + tracker.push( + b"\x07\x1b[31mfile\x1b[0m\r\n\x1b]133;D;1;history_id=hist;session_id=sess\x07\x1b]133;A\x07% ", + |capture| { + captures.push(capture); + }, + ); + + assert_eq!( + captures, + vec![CommandCapture { + prompt: "%".to_string(), + command: "ls".to_string(), + output: "file".to_string(), + exit_code: Some(1), + history_id: Some("hist".to_string()), + session_id: Some("sess".to_string()), + output_truncated: false, + output_observed_bytes: 15, + }] + ); + } + + #[test] + fn duplicate_prompt_start_does_not_reset_prompt_capture() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;A\x07$ \x1b]133;A\x07continued \x1b]133;B\x07echo hi\r\n\x1b]133;C\x07hi\r\n\x1b]133;D;0;history_id=hist;session_id=sess\x07\x1b]133;A\x07$ ", + |capture| captures.push(capture), + ); + + assert_eq!(captures.len(), 1); + assert_eq!(captures[0].prompt, "$ continued"); + assert_eq!(captures[0].command, "echo hi"); + assert_eq!(captures[0].output, "hi"); + } + + #[test] + fn bare_finish_without_metadata_is_ignored() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push(b"\x1b]133;C\x07line one\r\n\x1b]133;D;0\x07", |capture| { + captures.push(capture); + }); + + tracker.push(b"\x1b]133;A\x07$ ", |capture| captures.push(capture)); + + assert!(captures.is_empty()); + } + + #[test] + fn bare_finish_before_metadata_in_same_push_ignored() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;C\x07line one\r\n\x1b]133;D;1\x07\x1b]133;D;0;history_id=018f;session_id=abcd\x07", + |capture| captures.push(capture), + ); + + assert_eq!(captures.len(), 1); + assert_eq!(captures[0].output, "line one"); + assert_eq!(captures[0].exit_code, Some(0)); + assert_eq!(captures[0].history_id.as_deref(), Some("018f")); + assert_eq!(captures[0].session_id.as_deref(), Some("abcd")); + } + + #[test] + fn metadata_arriving_after_bare_finish_across_pushes() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push(b"\x1b]133;C\x07line one\r\n\x1b]133;D;0\x07", |capture| { + captures.push(capture); + }); + tracker.push(b"\x1b]133;D;0;history_id=018f", |capture| { + captures.push(capture) + }); + + assert!(captures.is_empty()); + + tracker.push(b";session_id=abcd\x07", |capture| captures.push(capture)); + + assert_eq!(captures.len(), 1); + assert_eq!(captures[0].output, "line one"); + assert_eq!(captures[0].exit_code, Some(0)); + assert_eq!(captures[0].history_id.as_deref(), Some("018f")); + assert_eq!(captures[0].session_id.as_deref(), Some("abcd")); + } + + #[test] + fn split_finish_marker_is_not_counted_as_output() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;C\x07line one\r\n\x1b]133;D;0;history_id=018f", + |capture| { + captures.push(capture); + }, + ); + assert!(captures.is_empty()); + + tracker.push(b";session_id=abcd\x07", |capture| captures.push(capture)); + + assert_eq!(captures.len(), 1); + assert_eq!(captures[0].output, "line one"); + assert_eq!(captures[0].output_observed_bytes, 10); + } + + #[test] + fn captures_output_with_history_metadata_from_d_marker() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;C\x07line one\r\n\x1b]133;D;0;history_id=018f;session_id=abcd\x07", + |capture| captures.push(capture), + ); + + assert_eq!( + captures, + vec![CommandCapture { + prompt: String::new(), + command: String::new(), + output: "line one".to_string(), + exit_code: Some(0), + history_id: Some("018f".to_string()), + session_id: Some("abcd".to_string()), + output_truncated: false, + output_observed_bytes: 10, + }] + ); + } + + #[test] + fn output_capture_is_capped_and_reports_observed_bytes() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + let mut input = b"\x1b]133;C\x07".to_vec(); + input.extend(std::iter::repeat_n(b'x', MAX_OUTPUT_CAPTURE_BYTES + 10)); + input.extend_from_slice(b"\x1b]133;D;0;history_id=big;session_id=session-1\x07"); + + tracker.push(&input, |capture| captures.push(capture)); + + assert_eq!(captures.len(), 1); + assert!(captures[0].output_truncated); + assert_eq!( + captures[0].output_observed_bytes, + (MAX_OUTPUT_CAPTURE_BYTES + 10) as u64 + ); + } + + #[test] + fn resets_buffers_between_c_d_only_captures() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;C\x07first\r\n\x1b]133;D;0;history_id=one\x07\x1b]133;C\x07second\r\n\x1b]133;D;1;history_id=two\x07", + |capture| captures.push(capture), + ); + + assert_eq!(captures.len(), 2); + assert_eq!(captures[0].output, "first"); + assert_eq!(captures[0].history_id.as_deref(), Some("one")); + assert_eq!(captures[1].output, "second"); + assert_eq!(captures[1].history_id.as_deref(), Some("two")); + } +} diff --git a/crates/atuin-pty-proxy/src/debug.rs b/crates/atuin-pty-proxy/src/debug.rs new file mode 100644 index 00000000..806bde90 --- /dev/null +++ b/crates/atuin-pty-proxy/src/debug.rs @@ -0,0 +1,53 @@ +use crate::osc133::{Event, Parser}; + +pub(crate) const RESET: &[u8] = b"\x1b[0m"; + +pub(crate) struct Osc133DebugHighlighter { + parser: Parser, +} + +impl Osc133DebugHighlighter { + pub(crate) fn new() -> Self { + Self { + parser: Parser::new(), + } + } + + pub(crate) fn render(&mut self, data: &[u8]) -> Vec { + let mut events = Vec::new(); + self.parser + .push_located(data, |located| events.push(located)); + + if events.is_empty() { + return data.to_vec(); + } + + let mut rendered = Vec::with_capacity(data.len() + (events.len() * 64)); + let mut start = 0; + + for located in events { + let offset = located.offset.min(data.len()); + if offset > start { + rendered.extend_from_slice(&data[start..offset]); + } + + rendered.extend_from_slice(event_label(&located.event)); + rendered.extend_from_slice(RESET); + start = offset; + } + + rendered.extend_from_slice(&data[start..]); + rendered + } +} + +fn event_label(event: &Event) -> &'static [u8] { + match event { + Event::PromptStart => b"\x1b[1;37;45m[OSC133:A prompt]\x1b[0m", + Event::CommandStart => b"\x1b[1;30;43m[OSC133:B input]\x1b[0m", + Event::CommandExecuted => b"\x1b[1;30;46m[OSC133:C output]\x1b[0m", + Event::CommandFinished { exit_code: Some(0) } => b"\x1b[1;37;42m[OSC133:D exit=0]\x1b[0m", + Event::CommandFinished { exit_code: Some(_) } => b"\x1b[1;37;41m[OSC133:D exit!=0]\x1b[0m", + Event::CommandFinished { exit_code: None } => b"\x1b[1;37;44m[OSC133:D exit=?]\x1b[0m", + } +} diff --git a/crates/atuin-pty-proxy/src/lib.rs b/crates/atuin-pty-proxy/src/lib.rs index 16b29dff..65b03df3 100644 --- a/crates/atuin-pty-proxy/src/lib.rs +++ b/crates/atuin-pty-proxy/src/lib.rs @@ -1,478 +1,48 @@ -pub mod osc133; - -use clap::{Args, Subcommand, ValueEnum}; - -#[derive(Subcommand, Debug)] -pub enum Cmd { - /// Print shell code to initialize atuin pty-proxy on shell startup - Init(Init), -} - -#[derive(Args, Debug)] -pub struct Init { - /// Shell to generate init for. If omitted, attempt auto-detection - #[arg(value_enum)] - shell: Option, -} - -#[derive(Clone, Copy, Debug, Eq, PartialEq, ValueEnum)] -#[value(rename_all = "lower")] -#[allow(clippy::enum_variant_names, clippy::doc_markdown)] -enum Shell { - /// Zsh setup - Zsh, - /// Bash setup - Bash, - /// Fish setup - Fish, - /// Nu setup - Nu, -} - -impl Init { - fn run(self) -> Result<(), String> { - let shell = detect_shell(self.shell)?; - let script = render_init(shell); - print!("{script}"); - Ok(()) - } -} - -pub fn run(cmd: Option) { - match cmd { - Some(Cmd::Init(init)) => { - if let Err(err) = init.run() { - eprintln!("atuin pty-proxy: {err}"); - std::process::exit(1); - } - } - None => app::main(), - } -} - -fn detect_shell(cli_shell: Option) -> Result { - if let Some(shell) = cli_shell { - return Ok(shell); - } - - if let Ok(shell) = std::env::var("ATUIN_SHELL") - && let Some(shell) = shell_from_name(&shell) - { - return Ok(shell); - } - - if let Ok(shell) = std::env::var("SHELL") - && let Some(shell) = shell_from_name(&shell) - { - return Ok(shell); - } - - Err( - "could not detect a supported shell. Please specify one explicitly: bash, zsh, fish, or nu" - .to_string(), - ) -} - -fn shell_from_name(name: &str) -> Option { - let shell = name - .trim() - .rsplit('/') - .next() - .unwrap_or(name) - .trim_start_matches('-') - .to_ascii_lowercase(); - - match shell.as_str() { - "bash" => Some(Shell::Bash), - "zsh" => Some(Shell::Zsh), - "fish" => Some(Shell::Fish), - "nu" => Some(Shell::Nu), - _ => None, - } -} - -fn render_init(shell: Shell) -> &'static str { - match shell { - Shell::Bash | Shell::Zsh => { - r#"if [[ "$-" == *i* ]] && [[ -t 0 ]] && [[ -t 1 ]]; then - _atuin_pty_proxy_tmux_current="${TMUX:-}" - _atuin_pty_proxy_tmux_previous="${ATUIN_PTY_PROXY_TMUX:-${ATUIN_HEX_TMUX:-}}" - - if [[ -z "${ATUIN_PTY_PROXY_ACTIVE:-${ATUIN_HEX_ACTIVE:-}}" ]] || [[ "$_atuin_pty_proxy_tmux_current" != "$_atuin_pty_proxy_tmux_previous" ]]; then - export ATUIN_PTY_PROXY_ACTIVE=1 - export ATUIN_PTY_PROXY_TMUX="$_atuin_pty_proxy_tmux_current" - exec atuin pty-proxy - fi - - unset _atuin_pty_proxy_tmux_current _atuin_pty_proxy_tmux_previous -fi -"# - } - Shell::Fish => { - r#"if status is-interactive; and test -t 0; and test -t 1 - set -l _atuin_pty_proxy_tmux_current "" - if set -q TMUX - set _atuin_pty_proxy_tmux_current "$TMUX" - end - - set -l _atuin_pty_proxy_tmux_previous "" - if set -q ATUIN_PTY_PROXY_TMUX - set _atuin_pty_proxy_tmux_previous "$ATUIN_PTY_PROXY_TMUX" - else if set -q ATUIN_HEX_TMUX - set _atuin_pty_proxy_tmux_previous "$ATUIN_HEX_TMUX" - end - - if not set -q ATUIN_PTY_PROXY_ACTIVE; and not set -q ATUIN_HEX_ACTIVE - set -gx ATUIN_PTY_PROXY_ACTIVE 1 - set -gx ATUIN_PTY_PROXY_TMUX "$_atuin_pty_proxy_tmux_current" - exec atuin pty-proxy - else if test "$_atuin_pty_proxy_tmux_current" != "$_atuin_pty_proxy_tmux_previous" - set -gx ATUIN_PTY_PROXY_ACTIVE 1 - set -gx ATUIN_PTY_PROXY_TMUX "$_atuin_pty_proxy_tmux_current" - exec atuin pty-proxy - end -end -"# - } - // Nushell cannot dynamically source the output of `atuin init nu`, - // so we only output the pty-proxy preamble here. Users must also set up - // `atuin init nu` separately. - Shell::Nu => { - r#"if (is-terminal --stdin) and (is-terminal --stdout) { - let tmux_current = ($env.TMUX? | default "") - let tmux_previous = ($env.ATUIN_PTY_PROXY_TMUX? | default ($env.ATUIN_HEX_TMUX? | default "")) - - if (($env.ATUIN_PTY_PROXY_ACTIVE? | default ($env.ATUIN_HEX_ACTIVE? | default "")) | is-empty) or ($tmux_current != $tmux_previous) { - $env.ATUIN_PTY_PROXY_ACTIVE = "1" - $env.ATUIN_PTY_PROXY_TMUX = $tmux_current - exec atuin pty-proxy - } -} -"# - } - } -} - -#[cfg(not(unix))] -mod app { - pub(crate) fn main() { - eprintln!("atuin pty-proxy currently supports unix platforms"); - std::process::exit(1); - } -} - #[cfg(unix)] -mod app { - use std::io::{Read, Write}; - use std::os::unix::net::UnixListener; - use std::sync::mpsc; - - use crossterm::terminal; - use portable_pty::{CommandBuilder, PtySize, native_pty_system}; - - enum ParserMsg { - Data(Vec), - Resize { rows: u16, cols: u16 }, - ScreenRequest(mpsc::Sender>), - } - - pub(crate) fn main() { - if let Err(e) = run() { - let _ = terminal::disable_raw_mode(); - eprintln!("atuin pty-proxy: {e:#}"); - std::process::exit(1); - } - } - - fn socket_path() -> std::path::PathBuf { - let dir = std::env::temp_dir(); - dir.join(format!("atuin-pty-proxy-{}.sock", std::process::id())) - } - - /// Wire format written to the Unix socket: - /// - /// ```text - /// [rows: u16 BE][cols: u16 BE][cursor_row: u16 BE][cursor_col: u16 BE] - /// [row_0_len: u32 BE][row_0_bytes...] - /// [row_1_len: u32 BE][row_1_bytes...] - /// ... - /// ``` - /// - /// Each row's bytes come from `screen.rows_formatted(0, cols)` and contain - /// pre-built ANSI escape sequences. The client can write them directly to - /// stdout without needing its own vt100 parser. - fn encode_screen(parser: &vt100::Parser) -> Vec { - let screen = parser.screen(); - let (rows, cols) = screen.size(); - let (cursor_row, cursor_col) = screen.cursor_position(); - - let mut buf: Vec = Vec::with_capacity(256 + (rows as usize * cols as usize)); - buf.extend_from_slice(&rows.to_be_bytes()); - buf.extend_from_slice(&cols.to_be_bytes()); - buf.extend_from_slice(&cursor_row.to_be_bytes()); - buf.extend_from_slice(&cursor_col.to_be_bytes()); - - for row_bytes in screen.rows_formatted(0, cols) { - let len = row_bytes.len() as u32; - buf.extend_from_slice(&len.to_be_bytes()); - buf.extend_from_slice(&row_bytes); - } - - buf - } - - fn handle_parser_msg(parser: &mut vt100::Parser, msg: ParserMsg) { - match msg { - ParserMsg::Data(data) => parser.process(&data), - ParserMsg::Resize { rows, cols } => parser.screen_mut().set_size(rows, cols), - ParserMsg::ScreenRequest(reply_tx) => { - let _ = reply_tx.send(encode_screen(parser)); - } - } - } - - fn run() -> eyre::Result<()> { - let (cols, rows) = terminal::size()?; - - let pty_system = native_pty_system(); - let pair = pty_system - .openpty(PtySize { - rows, - cols, - pixel_width: 0, - pixel_height: 0, - }) - .map_err(|e| eyre::eyre!("{e:#}"))?; - - // Set up socket path and expose it to child processes - let sock_path = socket_path(); - // Clean up any stale socket from a previous crash - let _ = std::fs::remove_file(&sock_path); - - let mut cmd = CommandBuilder::new_default_prog(); - cmd.cwd(std::env::current_dir()?); - cmd.env("ATUIN_PTY_PROXY_SOCKET", sock_path.as_os_str()); - cmd.env("ATUIN_HEX_SOCKET", sock_path.as_os_str()); - - let mut child = pair - .slave - .spawn_command(cmd) - .map_err(|e| eyre::eyre!("{e:#}"))?; - - // Close slave side in parent process - drop(pair.slave); - - let mut pty_reader = pair - .master - .try_clone_reader() - .map_err(|e| eyre::eyre!("{e:#}"))?; - let mut pty_writer = pair - .master - .take_writer() - .map_err(|e| eyre::eyre!("{e:#}"))?; - - // Channel: stdout/sigwinch/socket threads -> parser thread (bounded, non-blocking send) - let (msg_tx, msg_rx) = mpsc::sync_channel::(64); - - // --- Parser thread --- - // Maintains a persistent vt100::Parser fed bytes as they arrive. - // On screen request: reads current state directly (no replay). - std::thread::spawn(move || { - let mut parser = vt100::Parser::new(rows, cols, 0); - - loop { - // Block until at least one message arrives - let first = match msg_rx.recv() { - Ok(msg) => msg, - Err(_) => break, - }; - - handle_parser_msg(&mut parser, first); - - // Drain all remaining pending messages so the parser stays - // caught up during high-throughput bursts (e.g. `cat bigfile`). - // The channel holds at most 64 items, so this is bounded. - while let Ok(msg) = msg_rx.try_recv() { - handle_parser_msg(&mut parser, msg); - } - } - }); - - // --- Socket server thread --- - // Listens on Unix socket; on connection, requests screen state from parser thread. - { - let sock_path_clone = sock_path.clone(); - let screen_tx = msg_tx.clone(); - std::thread::spawn(move || { - let listener = match UnixListener::bind(&sock_path_clone) { - Ok(l) => l, - Err(e) => { - eprintln!("atuin pty-proxy: failed to bind socket: {e}"); - return; - } - }; - - for stream in listener.incoming() { - let mut stream = match stream { - Ok(s) => s, - Err(_) => break, - }; - - let (reply_tx, reply_rx) = mpsc::channel(); - if screen_tx.send(ParserMsg::ScreenRequest(reply_tx)).is_err() { - break; - } - if let Ok(data) = reply_rx.recv() { - let _ = stream.write_all(&data); - let _ = stream.flush(); - } - } - }); - } - - // Handle terminal resize via SIGWINCH - { - use signal_hook::consts::SIGWINCH; - use signal_hook::iterator::Signals; - - let master = pair.master; - let resize_tx = msg_tx.clone(); - let mut signals = Signals::new([SIGWINCH])?; - - std::thread::spawn(move || { - for _ in signals.forever() { - if let Ok((cols, rows)) = terminal::size() { - let _ = master.resize(PtySize { - rows, - cols, - pixel_width: 0, - pixel_height: 0, - }); - let _ = resize_tx.try_send(ParserMsg::Resize { rows, cols }); - } - } - }); - } - - terminal::enable_raw_mode()?; - - // PTY -> stdout (with OSC 133 parsing + buffer feed) - let stdout_thread = std::thread::spawn(move || { - let mut stdout = std::io::stdout(); - let mut parser = crate::osc133::Parser::new(); - let mut buf = [0u8; 8192]; - loop { - match pty_reader.read(&mut buf) { - Ok(0) | Err(_) => break, - Ok(n) => { - parser.push(&buf[..n], |_event| { - // Zone transitions are tracked inside the parser. - // Callers can query parser.zone() after push. - }); - - // Feed bytes to the shadow parser. Drops on backpressure — - // the screen snapshot may be stale during bursts, but - // self-corrects once output settles. - let _ = msg_tx.try_send(ParserMsg::Data(buf[..n].to_vec())); - - if stdout.write_all(&buf[..n]).is_err() { - break; - } - let _ = stdout.flush(); - } - } - } - }); - - // stdin -> PTY - std::thread::spawn(move || { - let mut stdin = std::io::stdin(); - let mut buf = [0u8; 8192]; - loop { - match stdin.read(&mut buf) { - Ok(0) | Err(_) => break, - Ok(n) => { - if pty_writer.write_all(&buf[..n]).is_err() { - break; - } - } - } - } - }); +mod capture; +#[cfg(unix)] +mod debug; +#[cfg(unix)] +mod osc133; +#[cfg(unix)] +mod pty_proxy; +#[cfg(unix)] +mod runtime; +#[cfg(unix)] +mod screen; - let status = child.wait()?; - let _ = stdout_thread.join(); +#[cfg(unix)] +pub use capture::{CommandCapture, CommandCaptureSink}; +#[cfg(unix)] +pub use pty_proxy::PtyProxy; - let _ = terminal::disable_raw_mode(); +#[cfg(not(unix))] +#[allow(dead_code)] +mod unsupported { + use clap::{Args, Subcommand}; - // Clean up socket file - let _ = std::fs::remove_file(&sock_path); + #[derive(Args, Debug)] + pub struct PtyProxy { + /// Highlight OSC 133 prompt, input, output, and exit-code regions + #[arg(long)] + debug_osc133: bool, - std::process::exit(process_exit_code(status.exit_code())); + #[command(subcommand)] + cmd: Option, } - fn process_exit_code(code: u32) -> i32 { - i32::try_from(code).unwrap_or(1) + #[derive(Subcommand, Debug)] + enum Cmd { + /// Print shell code to initialize atuin pty-proxy on shell startup + Init(Init), } - #[cfg(test)] - mod tests { - use super::process_exit_code; - - #[test] - fn process_exit_code_preserves_valid_values() { - assert_eq!(process_exit_code(0), 0); - assert_eq!(process_exit_code(127), 127); - assert_eq!(process_exit_code(i32::MAX as u32), i32::MAX); - } - - #[test] - fn process_exit_code_defaults_when_out_of_range() { - assert_eq!(process_exit_code(i32::MAX as u32 + 1), 1); - } + #[derive(Args, Debug)] + struct Init { + /// Shell to generate init for. If omitted, attempt auto-detection + shell: Option, } } -#[cfg(test)] -mod tests { - use super::{Shell, render_init, shell_from_name}; - - #[test] - fn shell_from_name_handles_paths() { - assert_eq!(shell_from_name("/bin/zsh"), Some(Shell::Zsh)); - assert_eq!(shell_from_name("/usr/local/bin/bash"), Some(Shell::Bash)); - assert_eq!(shell_from_name("fish"), Some(Shell::Fish)); - assert_eq!(shell_from_name("nu"), Some(Shell::Nu)); - } - - #[test] - fn posix_init_uses_exec_and_tmux_guard() { - let script = render_init(Shell::Bash); - assert!(script.contains("exec atuin pty-proxy")); - assert!(script.contains("ATUIN_PTY_PROXY_TMUX")); - assert!(!script.contains("eval \"$(atuin init bash)\"")); - } - - #[test] - fn posix_init_has_no_double_braces() { - let script = render_init(Shell::Bash); - assert!(!script.contains("${{"), "double braces in bash init script"); - } - - #[test] - fn fish_init_uses_source() { - let script = render_init(Shell::Fish); - assert!(script.contains("exec atuin pty-proxy")); - assert!(!script.contains("atuin init fish | source")); - } - - #[test] - fn nu_init_uses_exec_and_tty_guard() { - let script = render_init(Shell::Nu); - assert!(script.contains("exec atuin pty-proxy")); - assert!(script.contains("ATUIN_PTY_PROXY_TMUX")); - assert!(script.contains("is-terminal --stdin")); - assert!(script.contains("is-terminal --stdout")); - assert!(script.contains("ATUIN_PTY_PROXY_ACTIVE")); - } -} +#[cfg(not(unix))] +pub use unsupported::PtyProxy; diff --git a/crates/atuin-pty-proxy/src/osc133.rs b/crates/atuin-pty-proxy/src/osc133.rs index d6ee1220..51fda848 100644 --- a/crates/atuin-pty-proxy/src/osc133.rs +++ b/crates/atuin-pty-proxy/src/osc133.rs @@ -9,18 +9,19 @@ //! | C | Command submitted — output begins | //! | D[;n] | Command finished with exit code *n* | //! -//! The wire format is `ESC ] 133 ; [; ] ST` where ST is either -//! BEL (0x07) or ESC \ (0x1B 0x5C). +//! The wire format is `ESC ] 133 ; [; ] ST` where ST is BEL +//! (0x07), ESC \ (0x1B 0x5C), or C1 ST (0x9C). //! //! # Design goals //! -//! * **Zero-copy** — the parser observes the byte stream without buffering or -//! modifying it. -//! * **Zero-alloc** — after construction no heap allocation occurs. +//! * **Transparent** — the parser observes the byte stream without modifying it; +//! the caller remains responsible for forwarding bytes to their destination. +//! * **Bounded** — OSC parameter buffering is capped so malformed output cannot +//! grow memory without limit. //! * **Non-blocking** — [`Parser::push`] processes whatever bytes are available //! and returns immediately. -//! * **Transparent** — the caller is responsible for forwarding bytes to their -//! destination; the parser only emits [`Event`]s through a callback. +//! * **Extensible** — marker parameters are preserved so Atuin-specific metadata +//! can ride alongside standard OSC 133 markers. /// Events emitted when an OSC 133 marker is detected. #[derive(Debug, Clone, PartialEq, Eq)] @@ -38,6 +39,63 @@ pub enum Event { }, } +/// Parameters attached to an OSC 133 marker. +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub struct Params { + items: Vec, +} + +impl Params { + /// Iterate over all marker parameters in order. + #[cfg(test)] + #[inline] + pub fn iter(&self) -> impl Iterator { + self.items.iter() + } + + /// Return the value for the first `key=value` parameter with this key. + #[inline] + pub fn get(&self, key: &str) -> Option<&str> { + self.items.iter().find_map(|item| match item { + Param::KeyValue { + key: item_key, + value, + } if item_key == key => Some(value.as_str()), + Param::Value(_) | Param::KeyValue { .. } => None, + }) + } +} + +/// A single OSC 133 marker parameter. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Param { + /// A positional parameter without an equals sign. + Value(String), + /// A `key=value` parameter. + KeyValue { key: String, value: String }, +} + +/// An OSC 133 event with its position in the most recent input chunk. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LocatedEvent { + /// The OSC 133 event that was parsed. + pub event: Event, + /// Offset where this marker starts in the current chunk. + /// + /// If a marker started in an earlier [`Parser::push_located`] call, this is + /// `0` in the chunk that completed the marker. + pub start_offset: usize, + /// Offset immediately after this marker's terminator in the current chunk. + /// + /// If a marker spans multiple [`Parser::push_located`] calls, this is still + /// the offset in the chunk that completed the marker. + pub offset: usize, + /// The semantic zone after applying this event. + pub zone: Zone, + /// Metadata parameters attached to this marker. + pub params: Params, +} + /// The current semantic zone as determined by the most recent OSC 133 marker. #[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] #[allow(dead_code)] @@ -59,14 +117,14 @@ pub enum Zone { const ESC: u8 = 0x1B; const BEL: u8 = 0x07; +const C1_ST: u8 = 0x9C; const BACKSLASH: u8 = b'\\'; const RIGHT_BRACKET: u8 = b']'; -/// Maximum bytes we'll buffer for the OSC parameter string. 32 bytes is far -/// more than any valid OSC 133 payload needs (e.g. `133;D;127` is 9 bytes). -/// Longer (non-133) OSC sequences simply stop accumulating once the buffer is -/// full — the dispatch logic will harmlessly ignore them. -const PARAM_BUF_CAP: usize = 32; +/// Maximum bytes we'll buffer for the OSC parameter string. This is large enough +/// for Atuin metadata such as history/session IDs while still bounding malformed +/// OSC sequences. +const PARAM_BUF_CAP: usize = 512; // --------------------------------------------------------------------------- // State machine @@ -94,6 +152,7 @@ enum State { pub struct Parser { state: State, zone: Zone, + sequence_start: Option, param_buf: [u8; PARAM_BUF_CAP], param_len: usize, } @@ -111,6 +170,7 @@ impl Parser { Self { state: State::Ground, zone: Zone::Unknown, + sequence_start: None, param_buf: [0u8; PARAM_BUF_CAP], param_len: 0, } @@ -123,18 +183,40 @@ impl Parser { self.zone } + /// Start offset of an incomplete OSC sequence in the most recent chunk. + #[inline] + pub(crate) fn incomplete_osc_sequence_start(&self) -> Option { + matches!(self.state, State::OscParam | State::OscEsc) + .then(|| self.sequence_start.unwrap_or(0)) + } + /// Process a chunk of bytes, calling `on_event` for every OSC 133 marker /// found. /// /// All bytes in `data` should still be forwarded to the terminal by the /// caller — this method only *observes* the stream. + #[cfg(test)] #[inline] pub fn push(&mut self, data: &[u8], mut on_event: impl FnMut(Event)) { - for &byte in data { + self.push_located(data, |located| on_event(located.event)); + } + + /// Process a chunk of bytes, calling `on_event` for every OSC 133 marker + /// found with its byte offset in this chunk. + /// + /// The offset points to the first byte after the marker terminator, making + /// it suitable for callers that need to split the original chunk at marker + /// boundaries. + #[inline] + pub fn push_located(&mut self, data: &[u8], mut on_event: impl FnMut(LocatedEvent)) { + self.sequence_start = (self.state != State::Ground).then_some(0); + + for (offset, &byte) in data.iter().enumerate() { match self.state { State::Ground => { if byte == ESC { self.state = State::Esc; + self.sequence_start = Some(offset); } } State::Esc => { @@ -143,12 +225,14 @@ impl Parser { self.param_len = 0; } else { self.state = State::Ground; + self.sequence_start = None; } } State::OscParam => { - if byte == BEL { - self.dispatch(&mut on_event); + if byte == BEL || byte == C1_ST { + self.dispatch(offset + 1, &mut on_event); self.state = State::Ground; + self.sequence_start = None; } else if byte == ESC { self.state = State::OscEsc; } else if self.param_len < PARAM_BUF_CAP { @@ -160,12 +244,13 @@ impl Parser { } State::OscEsc => { if byte == BACKSLASH { - self.dispatch(&mut on_event); + self.dispatch(offset + 1, &mut on_event); } // Whether we got a valid ST or not, return to ground. // (A new ESC ] would restart accumulation via the Ground // -> Esc -> OscParam path on the *next* byte.) self.state = State::Ground; + self.sequence_start = None; } } } @@ -174,46 +259,104 @@ impl Parser { /// Inspect the accumulated parameter buffer. If it holds an OSC 133 /// payload, emit the corresponding [`Event`] and update the zone. #[inline] - fn dispatch(&mut self, on_event: &mut impl FnMut(Event)) { - let params = &self.param_buf[..self.param_len]; + fn dispatch(&mut self, offset: usize, on_event: &mut impl FnMut(LocatedEvent)) { + let payload = &self.param_buf[..self.param_len]; + + if payload.len() < 5 || &payload[..4] != b"133;" { + return; + } - // Must start with "133;" - if params.len() < 5 || ¶ms[..4] != b"133;" { + if payload.len() > 5 && payload[5] != b';' { return; } - let cmd = params[4]; - let event = match cmd { + let metadata = payload.get(6..).unwrap_or_default(); + let cmd = payload[4]; + let (event, params) = match cmd { b'A' => { self.zone = Zone::Prompt; - Event::PromptStart + (Event::PromptStart, parse_params(metadata)) } b'B' => { self.zone = Zone::Input; - Event::CommandStart + (Event::CommandStart, parse_params(metadata)) } b'C' => { self.zone = Zone::Output; - Event::CommandExecuted + (Event::CommandExecuted, parse_params(metadata)) } b'D' => { - let exit_code = if params.len() > 6 && params[5] == b';' { - std::str::from_utf8(¶ms[6..]) - .ok() - .and_then(|s| s.parse::().ok()) - } else { - None - }; + let (exit_code, params) = parse_command_finished_params(metadata); self.zone = Zone::Unknown; - Event::CommandFinished { exit_code } + (Event::CommandFinished { exit_code }, params) } _ => return, }; - on_event(event); + on_event(LocatedEvent { + event, + start_offset: self.sequence_start.unwrap_or(0), + offset, + zone: self.zone, + params, + }); } } +fn parse_command_finished_params(metadata: &[u8]) -> (Option, Params) { + if metadata.is_empty() { + return (None, Params::default()); + } + + let Some(separator) = metadata.iter().position(|byte| *byte == b';') else { + return parse_exit_code(metadata).map_or_else( + || (None, parse_params(metadata)), + |exit_code| (Some(exit_code), Params::default()), + ); + }; + + let (first, rest) = metadata.split_at(separator); + let rest = &rest[1..]; + + parse_exit_code(first).map_or_else( + || (None, parse_params(metadata)), + |exit_code| (Some(exit_code), parse_params(rest)), + ) +} + +fn parse_exit_code(code: &[u8]) -> Option { + if code.is_empty() { + return None; + } + + std::str::from_utf8(code) + .ok() + .and_then(|code| code.parse::().ok()) +} + +fn parse_params(metadata: &[u8]) -> Params { + let items = metadata + .split(|byte| *byte == b';') + .filter(|part| !part.is_empty()) + .map(parse_param) + .collect(); + + Params { items } +} + +fn parse_param(param: &[u8]) -> Param { + let param = String::from_utf8_lossy(param); + + if let Some((key, value)) = param.split_once('=') { + return Param::KeyValue { + key: key.to_string(), + value: value.to_string(), + }; + } + + Param::Value(param.into_owned()) +} + // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- @@ -468,6 +611,12 @@ mod tests { assert!(parse_events(data).is_empty()); } + #[test] + fn marker_with_unexpected_trailing_bytes_ignored() { + let data = b"\x1b]133;ABC\x07"; + assert!(parse_events(data).is_empty()); + } + // -- Malformed sequences -------------------------------------------------- #[test] @@ -509,7 +658,7 @@ mod tests { fn very_long_osc_does_not_panic() { let mut data = Vec::new(); data.extend_from_slice(b"\x1b]"); - data.extend(std::iter::repeat(b'x').take(1000)); + data.extend(std::iter::repeat_n(b'x', 1000)); data.push(BEL); // Should not panic and should produce no event. assert!(parse_events(&data).is_empty()); @@ -589,6 +738,100 @@ mod tests { ); } + #[test] + fn detects_c1_st_terminator() { + let data = b"\x1b]133;A\x9c"; + assert_eq!(parse_events(data), vec![Event::PromptStart]); + } + + // -- Located event offsets ------------------------------------------------ + + #[test] + fn located_event_reports_offset_after_marker() { + let data = b"before\x1b]133;A\x07prompt"; + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push_located(data, |e| events.push(e)); + + assert_eq!( + events, + vec![LocatedEvent { + event: Event::PromptStart, + start_offset: b"before".len(), + offset: b"before\x1b]133;A\x07".len(), + zone: Zone::Prompt, + params: Params::default(), + }] + ); + } + + #[test] + fn located_event_offset_is_relative_to_completing_chunk() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push_located(b"\x1b]133;", |e| events.push(e)); + parser.push_located(b"D;42\x07after", |e| events.push(e)); + + assert_eq!( + events, + vec![LocatedEvent { + event: Event::CommandFinished { + exit_code: Some(42) + }, + start_offset: 0, + offset: b"D;42\x07".len(), + zone: Zone::Unknown, + params: Params::default(), + }] + ); + } + + #[test] + fn located_event_preserves_metadata_params() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push_located( + b"\x1b]133;D;127;history_id=018f;session_id=abcd;flag\x07", + |event| events.push(event), + ); + + assert_eq!(events.len(), 1); + let event = &events[0]; + assert_eq!( + event.event, + Event::CommandFinished { + exit_code: Some(127) + } + ); + assert_eq!(event.params.get("history_id"), Some("018f")); + assert_eq!(event.params.get("session_id"), Some("abcd")); + assert!( + event + .params + .iter() + .any(|param| param == &Param::Value("flag".to_string())) + ); + } + + #[test] + fn command_finished_metadata_without_exit_code_is_preserved() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push_located(b"\x1b]133;D;history_id=018f;session_id=abcd\x07", |event| { + events.push(event); + }); + + assert_eq!(events.len(), 1); + let event = &events[0]; + assert_eq!(event.event, Event::CommandFinished { exit_code: None }); + assert_eq!(event.params.get("history_id"), Some("018f")); + assert_eq!(event.params.get("session_id"), Some("abcd")); + } + // -- Default trait -------------------------------------------------------- #[test] diff --git a/crates/atuin-pty-proxy/src/pty_proxy.rs b/crates/atuin-pty-proxy/src/pty_proxy.rs new file mode 100644 index 00000000..030ef9b5 --- /dev/null +++ b/crates/atuin-pty-proxy/src/pty_proxy.rs @@ -0,0 +1,231 @@ +use clap::{Args, Subcommand, ValueEnum}; + +use crate::{CommandCaptureSink, runtime}; + +#[derive(Args, Debug)] +pub struct PtyProxy { + /// Highlight OSC 133 prompt, input, output, and exit-code regions + #[arg(long)] + debug_osc133: bool, + + #[command(subcommand)] + cmd: Option, +} + +#[derive(Subcommand, Debug)] +pub enum Cmd { + /// Print shell code to initialize atuin pty-proxy on shell startup + Init(Init), +} + +#[derive(Args, Debug)] +pub struct Init { + /// Shell to generate init for. If omitted, attempt auto-detection + #[arg(value_enum)] + shell: Option, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq, ValueEnum)] +#[value(rename_all = "lower")] +#[allow(clippy::enum_variant_names, clippy::doc_markdown)] +enum Shell { + /// Zsh setup + Zsh, + /// Bash setup + Bash, + /// Fish setup + Fish, + /// Nu setup + Nu, +} + +pub(crate) struct RuntimeOptions { + pub(crate) debug_osc133: bool, + pub(crate) command_capture_sink: Option, +} + +impl RuntimeOptions { + fn new(debug_osc133: bool, command_capture_sink: Option) -> Self { + Self { + debug_osc133: debug_osc133 || env_flag("ATUIN_PTY_PROXY_DEBUG"), + command_capture_sink, + } + } +} + +impl PtyProxy { + pub fn run(self, command_capture_sink: Option) { + match self.cmd { + Some(Cmd::Init(init)) => { + if let Err(err) = init.run() { + eprintln!("atuin pty-proxy: {err}"); + std::process::exit(1); + } + } + None => runtime::main(RuntimeOptions::new(self.debug_osc133, command_capture_sink)), + } + } +} + +impl Init { + fn run(self) -> Result<(), String> { + let shell = detect_shell(self.shell)?; + let script = render_init(shell); + print!("{script}"); + Ok(()) + } +} + +fn detect_shell(cli_shell: Option) -> Result { + if let Some(shell) = cli_shell { + return Ok(shell); + } + + if let Ok(shell) = std::env::var("ATUIN_SHELL") + && let Some(shell) = shell_from_name(&shell) + { + return Ok(shell); + } + + if let Ok(shell) = std::env::var("SHELL") + && let Some(shell) = shell_from_name(&shell) + { + return Ok(shell); + } + + Err( + "could not detect a supported shell. Please specify one explicitly: bash, zsh, fish, or nu" + .to_string(), + ) +} + +fn shell_from_name(name: &str) -> Option { + let shell = name + .trim() + .rsplit('/') + .next() + .unwrap_or(name) + .trim_start_matches('-') + .to_ascii_lowercase(); + + match shell.as_str() { + "bash" => Some(Shell::Bash), + "zsh" => Some(Shell::Zsh), + "fish" => Some(Shell::Fish), + "nu" => Some(Shell::Nu), + _ => None, + } +} + +fn env_flag(name: &str) -> bool { + std::env::var(name).is_ok_and(|value| { + matches!( + value.trim().to_ascii_lowercase().as_str(), + "1" | "true" | "yes" | "on" + ) + }) +} + +fn render_init(shell: Shell) -> &'static str { + match shell { + Shell::Bash | Shell::Zsh => { + r#"if [[ "$-" == *i* ]] && [[ -t 0 ]] && [[ -t 1 ]]; then + _atuin_pty_proxy_tmux_current="${TMUX:-}" + _atuin_pty_proxy_tmux_previous="${ATUIN_PTY_PROXY_TMUX:-}" + + if [[ -z "${ATUIN_PTY_PROXY_ACTIVE:-}" ]] || [[ "$_atuin_pty_proxy_tmux_current" != "$_atuin_pty_proxy_tmux_previous" ]]; then + export ATUIN_PTY_PROXY_ACTIVE=1 + export ATUIN_PTY_PROXY_TMUX="$_atuin_pty_proxy_tmux_current" + exec atuin pty-proxy + fi + + unset _atuin_pty_proxy_tmux_current _atuin_pty_proxy_tmux_previous +fi +"# + } + Shell::Fish => { + r#"if status is-interactive; and test -t 0; and test -t 1 + set -l _atuin_pty_proxy_tmux_current "" + if set -q TMUX + set _atuin_pty_proxy_tmux_current "$TMUX" + end + + set -l _atuin_pty_proxy_tmux_previous "" + if set -q ATUIN_PTY_PROXY_TMUX + set _atuin_pty_proxy_tmux_previous "$ATUIN_PTY_PROXY_TMUX" + end + + if not set -q ATUIN_PTY_PROXY_ACTIVE + set -gx ATUIN_PTY_PROXY_ACTIVE 1 + set -gx ATUIN_PTY_PROXY_TMUX "$_atuin_pty_proxy_tmux_current" + exec atuin pty-proxy + else if test "$_atuin_pty_proxy_tmux_current" != "$_atuin_pty_proxy_tmux_previous" + set -gx ATUIN_PTY_PROXY_ACTIVE 1 + set -gx ATUIN_PTY_PROXY_TMUX "$_atuin_pty_proxy_tmux_current" + exec atuin pty-proxy + end +end +"# + } + // Nushell cannot dynamically source the output of `atuin init nu`, + // so we only output the pty-proxy preamble here. Users must also set up + // `atuin init nu` separately. + Shell::Nu => { + r#"if (is-terminal --stdin) and (is-terminal --stdout) { + let tmux_current = ($env.TMUX? | default "") + let tmux_previous = ($env.ATUIN_PTY_PROXY_TMUX? | default "") + + if (($env.ATUIN_PTY_PROXY_ACTIVE? | default "") | is-empty) or ($tmux_current != $tmux_previous) { + $env.ATUIN_PTY_PROXY_ACTIVE = "1" + $env.ATUIN_PTY_PROXY_TMUX = $tmux_current + exec atuin pty-proxy + } +} +"# + } + } +} + +#[cfg(test)] +mod tests { + use super::{Shell, render_init, shell_from_name}; + + #[test] + fn shell_from_name_handles_paths() { + assert_eq!(shell_from_name("/bin/zsh"), Some(Shell::Zsh)); + assert_eq!(shell_from_name("/usr/local/bin/bash"), Some(Shell::Bash)); + assert_eq!(shell_from_name("fish"), Some(Shell::Fish)); + assert_eq!(shell_from_name("nu"), Some(Shell::Nu)); + } + + #[test] + fn posix_init_uses_exec_and_tmux_guard() { + let script = render_init(Shell::Bash); + assert!(script.contains("exec atuin pty-proxy")); + assert!(script.contains("ATUIN_PTY_PROXY_TMUX")); + assert!(!script.contains("eval \"$(atuin init bash)\"")); + } + + #[test] + fn posix_init_has_no_double_braces() { + let script = render_init(Shell::Bash); + assert!(!script.contains("${{"), "double braces in bash init script"); + } + + #[test] + fn fish_init_uses_source() { + let script = render_init(Shell::Fish); + assert!(script.contains("exec atuin pty-proxy")); + assert!(!script.contains("atuin init fish | source")); + } + + #[test] + fn nu_init_uses_exec_and_tty_guard() { + let script = render_init(Shell::Nu); + assert!(script.contains("exec atuin pty-proxy")); + assert!(script.contains("ATUIN_PTY_PROXY_TMUX")); + assert!(script.contains("is-terminal --stdin")); + assert!(script.contains("is-terminal --stdout")); + assert!(script.contains("ATUIN_PTY_PROXY_ACTIVE")); + } +} diff --git a/crates/atuin-pty-proxy/src/runtime.rs b/crates/atuin-pty-proxy/src/runtime.rs new file mode 100644 index 00000000..2b34fbb7 --- /dev/null +++ b/crates/atuin-pty-proxy/src/runtime.rs @@ -0,0 +1,184 @@ +use std::io::{Read, Write}; +use std::sync::Arc; +use std::sync::atomic::{AtomicU16, Ordering}; +use std::sync::mpsc; + +use crossterm::terminal; +use portable_pty::{CommandBuilder, PtySize, native_pty_system}; + +use crate::capture::CommandCaptureTracker; +use crate::debug::{Osc133DebugHighlighter, RESET}; +use crate::pty_proxy::RuntimeOptions; +use crate::screen::{self, Msg}; + +pub(crate) fn main(options: RuntimeOptions) { + if let Err(e) = run(options) { + let _ = terminal::disable_raw_mode(); + eprintln!("atuin pty-proxy: {e:#}"); + std::process::exit(1); + } +} + +fn run(options: RuntimeOptions) -> eyre::Result<()> { + let (cols, rows) = terminal::size()?; + + let pty_system = native_pty_system(); + let pair = pty_system + .openpty(PtySize { + rows, + cols, + pixel_width: 0, + pixel_height: 0, + }) + .map_err(|e| eyre::eyre!("{e:#}"))?; + + let sock_path = screen::socket_path(); + let _ = std::fs::remove_file(&sock_path); + + let mut cmd = CommandBuilder::new_default_prog(); + cmd.cwd(std::env::current_dir()?); + cmd.env("ATUIN_PTY_PROXY_SOCKET", sock_path.as_os_str()); + cmd.env("ATUIN_PTY_PROXY_ACTIVE", "1"); + + let mut child = pair + .slave + .spawn_command(cmd) + .map_err(|e| eyre::eyre!("{e:#}"))?; + + drop(pair.slave); + + let mut pty_reader = pair + .master + .try_clone_reader() + .map_err(|e| eyre::eyre!("{e:#}"))?; + let mut pty_writer = pair + .master + .take_writer() + .map_err(|e| eyre::eyre!("{e:#}"))?; + + let (msg_tx, msg_rx) = mpsc::sync_channel::(64); + let current_cols = Arc::new(AtomicU16::new(cols.max(1))); + + screen::spawn_parser_thread(rows, cols, msg_rx); + screen::spawn_socket_server(sock_path.clone(), msg_tx.clone()); + spawn_resize_handler(pair.master, msg_tx.clone(), current_cols.clone())?; + + terminal::enable_raw_mode()?; + + let stdout_thread = std::thread::spawn(move || { + let mut stdout = std::io::stdout(); + let mut highlighter = options.debug_osc133.then(Osc133DebugHighlighter::new); + let mut capture_tracker = options + .command_capture_sink + .as_ref() + .map(|_| CommandCaptureTracker::new(current_cols)); + let mut buf = [0u8; 8192]; + + loop { + match pty_reader.read(&mut buf) { + Ok(0) | Err(_) => break, + Ok(n) => { + if let (Some(tracker), Some(sink)) = ( + capture_tracker.as_mut(), + options.command_capture_sink.as_ref(), + ) { + tracker.push(&buf[..n], sink); + } + + if let Some(highlighter) = highlighter.as_mut() { + let rendered = highlighter.render(&buf[..n]); + let _ = msg_tx.try_send(Msg::Data(rendered.clone())); + + if stdout.write_all(&rendered).is_err() { + break; + } + } else { + let _ = msg_tx.try_send(Msg::Data(buf[..n].to_vec())); + + if stdout.write_all(&buf[..n]).is_err() { + break; + } + } + let _ = stdout.flush(); + } + } + } + + if highlighter.is_some() { + let _ = stdout.write_all(RESET); + let _ = stdout.flush(); + } + }); + + std::thread::spawn(move || { + let mut stdin = std::io::stdin(); + let mut buf = [0u8; 8192]; + loop { + match stdin.read(&mut buf) { + Ok(0) | Err(_) => break, + Ok(n) => { + if pty_writer.write_all(&buf[..n]).is_err() { + break; + } + } + } + } + }); + + let status = child.wait()?; + let _ = stdout_thread.join(); + + let _ = terminal::disable_raw_mode(); + let _ = std::fs::remove_file(&sock_path); + + std::process::exit(process_exit_code(status.exit_code())); +} + +fn spawn_resize_handler( + master: Box, + resize_tx: mpsc::SyncSender, + current_cols: Arc, +) -> eyre::Result<()> { + use signal_hook::consts::SIGWINCH; + use signal_hook::iterator::Signals; + + let mut signals = Signals::new([SIGWINCH])?; + + std::thread::spawn(move || { + for _ in signals.forever() { + if let Ok((cols, rows)) = terminal::size() { + current_cols.store(cols.max(1), Ordering::Relaxed); + let _ = master.resize(PtySize { + rows, + cols, + pixel_width: 0, + pixel_height: 0, + }); + let _ = resize_tx.try_send(Msg::Resize { rows, cols }); + } + } + }); + + Ok(()) +} + +fn process_exit_code(code: u32) -> i32 { + i32::try_from(code).unwrap_or(1) +} + +#[cfg(test)] +mod tests { + use super::process_exit_code; + + #[test] + fn process_exit_code_preserves_valid_values() { + assert_eq!(process_exit_code(0), 0); + assert_eq!(process_exit_code(127), 127); + assert_eq!(process_exit_code(i32::MAX as u32), i32::MAX); + } + + #[test] + fn process_exit_code_defaults_when_out_of_range() { + assert_eq!(process_exit_code(i32::MAX as u32 + 1), 1); + } +} diff --git a/crates/atuin-pty-proxy/src/screen.rs b/crates/atuin-pty-proxy/src/screen.rs new file mode 100644 index 00000000..5b892e21 --- /dev/null +++ b/crates/atuin-pty-proxy/src/screen.rs @@ -0,0 +1,104 @@ +use std::io::Write; +use std::os::unix::net::UnixListener; +use std::path::PathBuf; +use std::sync::mpsc::{self, Receiver, SyncSender}; + +pub(crate) enum Msg { + Data(Vec), + Resize { rows: u16, cols: u16 }, + ScreenRequest(mpsc::Sender>), +} + +pub(crate) fn socket_path() -> PathBuf { + let dir = std::env::temp_dir(); + dir.join(format!("atuin-pty-proxy-{}.sock", std::process::id())) +} + +pub(crate) fn spawn_parser_thread(rows: u16, cols: u16, msg_rx: Receiver) { + std::thread::spawn(move || { + let mut parser = vt100::Parser::new(rows, cols, 0); + + loop { + let first = match msg_rx.recv() { + Ok(msg) => msg, + Err(_) => break, + }; + + handle_parser_msg(&mut parser, first); + + while let Ok(msg) = msg_rx.try_recv() { + handle_parser_msg(&mut parser, msg); + } + } + }); +} + +pub(crate) fn spawn_socket_server(sock_path: PathBuf, screen_tx: SyncSender) { + std::thread::spawn(move || { + let listener = match UnixListener::bind(&sock_path) { + Ok(l) => l, + Err(e) => { + eprintln!("atuin pty-proxy: failed to bind socket: {e}"); + return; + } + }; + + for stream in listener.incoming() { + let mut stream = match stream { + Ok(s) => s, + Err(_) => break, + }; + + let (reply_tx, reply_rx) = mpsc::channel(); + if screen_tx.send(Msg::ScreenRequest(reply_tx)).is_err() { + break; + } + if let Ok(data) = reply_rx.recv() { + let _ = stream.write_all(&data); + let _ = stream.flush(); + } + } + }); +} + +/// Wire format written to the Unix socket: +/// +/// ```text +/// [rows: u16 BE][cols: u16 BE][cursor_row: u16 BE][cursor_col: u16 BE] +/// [row_0_len: u32 BE][row_0_bytes...] +/// [row_1_len: u32 BE][row_1_bytes...] +/// ... +/// ``` +/// +/// Each row's bytes come from `screen.rows_formatted(0, cols)` and contain +/// pre-built ANSI escape sequences. The client can write them directly to +/// stdout without needing its own vt100 parser. +fn encode_screen(parser: &vt100::Parser) -> Vec { + let screen = parser.screen(); + let (rows, cols) = screen.size(); + let (cursor_row, cursor_col) = screen.cursor_position(); + + let mut buf: Vec = Vec::with_capacity(256 + (rows as usize * cols as usize)); + buf.extend_from_slice(&rows.to_be_bytes()); + buf.extend_from_slice(&cols.to_be_bytes()); + buf.extend_from_slice(&cursor_row.to_be_bytes()); + buf.extend_from_slice(&cursor_col.to_be_bytes()); + + for row_bytes in screen.rows_formatted(0, cols) { + let len = row_bytes.len() as u32; + buf.extend_from_slice(&len.to_be_bytes()); + buf.extend_from_slice(&row_bytes); + } + + buf +} + +fn handle_parser_msg(parser: &mut vt100::Parser, msg: Msg) { + match msg { + Msg::Data(data) => parser.process(&data), + Msg::Resize { rows, cols } => parser.screen_mut().set_size(rows, cols), + Msg::ScreenRequest(reply_tx) => { + let _ = reply_tx.send(encode_screen(parser)); + } + } +} diff --git a/crates/atuin/Cargo.toml b/crates/atuin/Cargo.toml index 0ac1e889..8e425232 100644 --- a/crates/atuin/Cargo.toml +++ b/crates/atuin/Cargo.toml @@ -36,7 +36,7 @@ atuin = { path = "/usr/bin/atuin" } default = ["client", "sync", "clipboard", "check-update", "daemon", "ai", "pty-proxy"] client = ["atuin-client"] sync = ["atuin-client/sync"] -daemon = ["atuin-client/daemon", "atuin-daemon"] +daemon = ["atuin-client/daemon", "atuin-daemon", "atuin-ai?/daemon"] ai = ["atuin-ai"] pty-proxy = ["dep:atuin-pty-proxy"] hex = ["pty-proxy"] diff --git a/crates/atuin/src/command/mod.rs b/crates/atuin/src/command/mod.rs index 7deb72d6..6cd221a4 100644 --- a/crates/atuin/src/command/mod.rs +++ b/crates/atuin/src/command/mod.rs @@ -24,10 +24,7 @@ pub enum AtuinCmd { /// PTY proxy for atuin #[cfg(feature = "pty-proxy")] #[command(alias = "hex")] - PtyProxy { - #[command(subcommand)] - cmd: Option, - }, + PtyProxy(atuin_pty_proxy::PtyProxy), /// Generate a UUID Uuid, @@ -56,8 +53,8 @@ impl AtuinCmd { Self::Client(client) => client.run(), #[cfg(feature = "pty-proxy")] - Self::PtyProxy { cmd } => { - atuin_pty_proxy::run(cmd); + Self::PtyProxy(proxy) => { + run_pty_proxy(proxy); Ok(()) } @@ -74,3 +71,92 @@ impl AtuinCmd { } } } + +#[cfg(all(feature = "pty-proxy", unix))] +fn run_pty_proxy(proxy: atuin_pty_proxy::PtyProxy) { + #[cfg(feature = "daemon")] + proxy.run(semantic_command_capture_sink()); + + #[cfg(not(feature = "daemon"))] + proxy.run(None); +} + +#[cfg(all(feature = "pty-proxy", not(unix)))] +fn run_pty_proxy(_proxy: atuin_pty_proxy::PtyProxy) { + eprintln!("atuin pty-proxy currently supports unix platforms"); + std::process::exit(1); +} + +#[cfg(all(feature = "daemon", feature = "pty-proxy", unix))] +fn semantic_command_capture_sink() -> Option { + use std::sync::mpsc; + use std::time::Duration; + + if is_truthy_env("ATUIN_TERMINAL") { + return None; + } + + let settings = atuin_client::settings::Settings::new().ok()?; + let (tx, rx) = mpsc::sync_channel::(128); + + std::thread::spawn(move || { + let Ok(runtime) = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + else { + return; + }; + + while let Ok(first) = rx.recv() { + let mut batch = vec![first]; + + while batch.len() < 64 { + match rx.recv_timeout(Duration::from_millis(25)) { + Ok(capture) => batch.push(capture), + Err(mpsc::RecvTimeoutError::Timeout | mpsc::RecvTimeoutError::Disconnected) => { + break; + } + } + } + + runtime.block_on(send_semantic_command_captures(&settings, batch)); + } + }); + + Some(Box::new(move |capture| { + let _ = tx.try_send(capture); + })) +} + +#[cfg(all(feature = "daemon", feature = "pty-proxy", unix))] +#[inline] +fn is_truthy_env(name: &str) -> bool { + std::env::var(name) + .ok() + .as_ref() + .is_some_and(|value| !value.trim().is_empty() && value.trim() != "false") +} + +#[cfg(all(feature = "daemon", feature = "pty-proxy", unix))] +async fn send_semantic_command_captures( + settings: &atuin_client::settings::Settings, + batch: Vec, +) { + let captures = batch + .into_iter() + .map(|capture| atuin_daemon::semantic::CommandCapture { + prompt: capture.prompt, + command: capture.command, + output: capture.output, + exit_code: capture.exit_code, + history_id: capture.history_id, + session_id: capture.session_id, + output_truncated: capture.output_truncated, + output_observed_bytes: capture.output_observed_bytes, + }) + .collect(); + + if let Ok(mut client) = atuin_daemon::SemanticClient::from_settings(settings).await { + let _ = client.record_commands(captures).await; + } +} diff --git a/crates/atuin/src/shell/atuin.bash b/crates/atuin/src/shell/atuin.bash index 45fdced9..8b540bd7 100644 --- a/crates/atuin/src/shell/atuin.bash +++ b/crates/atuin/src/shell/atuin.bash @@ -20,6 +20,35 @@ fi ATUIN_STTY=$(stty -g) ATUIN_HISTORY_ID="" +__atuin_osc133_command_executed() { + [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]] || return + [[ -n "${ATUIN_HISTORY_ID:-}" && "$ATUIN_HISTORY_ID" != "__bash_preexec_failure__" ]] || return + + printf '\033]133;C\a' +} + +__atuin_osc133_command_finished() { + [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]] || return + [[ -n "${ATUIN_HISTORY_ID:-}" && "$ATUIN_HISTORY_ID" != "__bash_preexec_failure__" ]] || return + + printf '\033]133;D;%s;history_id=%s;session_id=%s\a' "$1" "$ATUIN_HISTORY_ID" "${ATUIN_SESSION:-}" +} + +__atuin_osc133_prompt_start=$'\001\033]133;A;cl=line\a\002' +__atuin_osc133_prompt_end=$'\001\033]133;B\a\002' + +__atuin_osc133_wrap_prompt() { + local __atuin_prompt="${PS1-}" + __atuin_prompt="${__atuin_prompt//$__atuin_osc133_prompt_start/}" + __atuin_prompt="${__atuin_prompt//$__atuin_osc133_prompt_end/}" + + if [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]]; then + PS1="${__atuin_osc133_prompt_start}${__atuin_prompt}${__atuin_osc133_prompt_end}" + else + PS1="$__atuin_prompt" + fi +} + export ATUIN_PREEXEC_BACKEND=$SHLVL:none __atuin_update_preexec_backend() { if [[ ${BLE_ATTACHED-} ]]; then @@ -59,15 +88,19 @@ __atuin_preexec() { local id id=$(atuin history start -- "$1" 2>/dev/null) export ATUIN_HISTORY_ID=$id + [[ -n ${__atuin_skip_osc133:-} ]] || __atuin_osc133_command_executed __atuin_preexec_time=${EPOCHREALTIME-} } __atuin_precmd() { local EXIT=$? __atuin_precmd_time=${EPOCHREALTIME-} + __atuin_osc133_wrap_prompt + [[ ! $ATUIN_HISTORY_ID ]] && return # If the previous preexec hook failed, we manually call __atuin_preexec + local __atuin_skip_osc133="" if [[ $ATUIN_HISTORY_ID == __bash_preexec_failure__ ]]; then # This is the command extraction code taken from bash-preexec local previous_command @@ -75,6 +108,7 @@ __atuin_precmd() { export LC_ALL=C HISTTIMEFORMAT='' builtin history 1 | sed '1 s/^ *[0-9][0-9]*[* ] //' ) + __atuin_skip_osc133=1 __atuin_preexec "$previous_command" fi @@ -106,6 +140,7 @@ __atuin_precmd() { fi fi + [[ -n ${__atuin_skip_osc133:-} ]] || __atuin_osc133_command_finished "$EXIT" (ATUIN_LOG=error atuin history end --exit "$EXIT" ${duration:+"--duration=$duration"} -- "$ATUIN_HISTORY_ID" &) >/dev/null 2>&1 export ATUIN_HISTORY_ID="" } diff --git a/crates/atuin/src/shell/atuin.fish b/crates/atuin/src/shell/atuin.fish index ddf55f3d..15b33451 100644 --- a/crates/atuin/src/shell/atuin.fish +++ b/crates/atuin/src/shell/atuin.fish @@ -4,9 +4,24 @@ if not set -q ATUIN_SESSION; or test "$ATUIN_SHLVL" != "$SHLVL" end set --erase ATUIN_HISTORY_ID +function _atuin_osc133_command_executed + set -q ATUIN_PTY_PROXY_ACTIVE; or return + test -n "$ATUIN_HISTORY_ID"; or return + + printf '\033]133;C\a' +end + +function _atuin_osc133_command_finished --argument-names exit_code + set -q ATUIN_PTY_PROXY_ACTIVE; or return + test -n "$ATUIN_HISTORY_ID"; or return + + printf '\033]133;D;%s;history_id=%s;session_id=%s\a' "$exit_code" "$ATUIN_HISTORY_ID" "$ATUIN_SESSION" +end + function _atuin_preexec --on-event fish_preexec if not test -n "$fish_private_mode" set -g ATUIN_HISTORY_ID (atuin history start -- "$argv[1]" 2>/dev/null) + _atuin_osc133_command_executed end end @@ -14,6 +29,7 @@ function _atuin_postexec --on-event fish_postexec set -l s $status if test -n "$ATUIN_HISTORY_ID" + _atuin_osc133_command_finished $s ATUIN_LOG=error atuin history end --exit $s -- $ATUIN_HISTORY_ID &>/dev/null & disown end diff --git a/crates/atuin/src/shell/atuin.nu b/crates/atuin/src/shell/atuin.nu index c1a38313..d37457e4 100644 --- a/crates/atuin/src/shell/atuin.nu +++ b/crates/atuin/src/shell/atuin.nu @@ -14,6 +14,28 @@ if 'ATUIN_SESSION' not-in $env or ('ATUIN_SHLVL' not-in $env) or ($env.ATUIN_SHL } hide-env -i ATUIN_HISTORY_ID +def _atuin_osc133_command_executed [] { + if 'ATUIN_PTY_PROXY_ACTIVE' not-in $env { + return + } + if 'ATUIN_HISTORY_ID' not-in $env or ($env.ATUIN_HISTORY_ID | is-empty) { + return + } + + print -n $"(char esc)]133;C(char bel)" +} + +def _atuin_osc133_command_finished [exit_code: int] { + if 'ATUIN_PTY_PROXY_ACTIVE' not-in $env { + return + } + if 'ATUIN_HISTORY_ID' not-in $env or ($env.ATUIN_HISTORY_ID | is-empty) { + return + } + + print -n $"(char esc)]133;D;($exit_code);history_id=($env.ATUIN_HISTORY_ID);session_id=($env.ATUIN_SESSION)(char bel)" +} + # Magic token to make sure we don't record commands run by keybindings let ATUIN_KEYBINDING_TOKEN = $"# (random uuid)" @@ -27,6 +49,7 @@ let _atuin_pre_execution = {|| } if not ($cmd | str starts-with $ATUIN_KEYBINDING_TOKEN) { $env.ATUIN_HISTORY_ID = (atuin history start -- $cmd | complete | get stdout | str trim) + _atuin_osc133_command_executed } } @@ -35,6 +58,7 @@ let _atuin_pre_prompt = {|| if 'ATUIN_HISTORY_ID' not-in $env { return } + _atuin_osc133_command_finished $last_exit with-env { ATUIN_LOG: error } { if (version).minor >= 104 or (version).major > 0 { job spawn { diff --git a/crates/atuin/src/shell/atuin.zsh b/crates/atuin/src/shell/atuin.zsh index 87f47531..7a7375aa 100644 --- a/crates/atuin/src/shell/atuin.zsh +++ b/crates/atuin/src/shell/atuin.zsh @@ -31,16 +31,54 @@ if [[ -z "${ATUIN_SESSION:-}" || "${ATUIN_SHLVL:-}" != "$SHLVL" ]]; then fi ATUIN_HISTORY_ID="" +__atuin_osc133_command_executed() { + [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]] || return + [[ -n "${ATUIN_HISTORY_ID:-}" ]] || return + + printf '\033]133;C\a' +} + +__atuin_osc133_command_finished() { + [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]] || return + [[ -n "${ATUIN_HISTORY_ID:-}" ]] || return + + printf '\033]133;D;%s;history_id=%s;session_id=%s\a' "$1" "$ATUIN_HISTORY_ID" "${ATUIN_SESSION:-}" +} + +__atuin_osc133_prompt_start=$'%{\033]133;A;cl=line\a%}' +__atuin_osc133_prompt_end=$'%{\033]133;B\a%}' + +__atuin_osc133_wrap_prompt() { + local __atuin_prompt="${PROMPT-}" + local __atuin_rprompt="${RPROMPT-}" + + __atuin_prompt="${__atuin_prompt//$__atuin_osc133_prompt_start/}" + __atuin_prompt="${__atuin_prompt//$__atuin_osc133_prompt_end/}" + __atuin_rprompt="${__atuin_rprompt//$__atuin_osc133_prompt_start/}" + __atuin_rprompt="${__atuin_rprompt//$__atuin_osc133_prompt_end/}" + + if [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]]; then + PROMPT="${__atuin_osc133_prompt_start}${__atuin_prompt}" + RPROMPT="${__atuin_rprompt}${__atuin_osc133_prompt_end}" + else + PROMPT="$__atuin_prompt" + RPROMPT="$__atuin_rprompt" + fi +} + _atuin_preexec() { local id id=$(atuin history start -- "$1" 2>/dev/null) export ATUIN_HISTORY_ID="$id" + __atuin_osc133_command_executed __atuin_preexec_time=${EPOCHREALTIME-} } _atuin_precmd() { local EXIT="$?" __atuin_precmd_time=${EPOCHREALTIME-} + __atuin_osc133_wrap_prompt + [[ -z "${ATUIN_HISTORY_ID:-}" ]] && return local duration="" @@ -48,6 +86,7 @@ _atuin_precmd() { printf -v duration %.0f $(((__atuin_precmd_time - __atuin_preexec_time) * 1000000000)) fi + __atuin_osc133_command_finished "$EXIT" (ATUIN_LOG=error atuin history end --exit $EXIT ${duration:+--duration=$duration} -- $ATUIN_HISTORY_ID &) >/dev/null 2>&1 export ATUIN_HISTORY_ID="" } -- cgit v1.3.1