diff options
| author | Michelle Tilley <michelle@michelletilley.net> | 2026-06-08 09:12:45 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-06-08 09:12:45 -0700 |
| commit | bcdf8c8cde31e826000f1b2d6eeaebdd865a07c1 (patch) | |
| tree | f62f66e4dede22ce73ea5dafe69881d6af9b3101 | |
| parent | chore(deps): bump debian from bookworm-20260421-slim to bookworm-20260518-sli... (diff) | |
| download | atuin-bcdf8c8cde31e826000f1b2d6eeaebdd865a07c1.zip | |
feat: Capture command output + expose to new `atuin_output` tool (#3510)
36 files changed, 3023 insertions, 596 deletions
@@ -280,6 +280,7 @@ dependencies = [ "async-trait", "atuin-client", "atuin-common", + "atuin-daemon", "chrono", "chrono-humanize", "clap", diff --git a/crates/atuin-ai/Cargo.toml b/crates/atuin-ai/Cargo.toml index 06e50a4e..027bd490 100644 --- a/crates/atuin-ai/Cargo.toml +++ b/crates/atuin-ai/Cargo.toml @@ -14,12 +14,14 @@ repository = { workspace = true } [features] default = [] +daemon = [] tree-sitter = ["dep:tree-sitter-lib", "dep:tree-sitter-bash", "dep:tree-sitter-fish"] [dependencies] async-trait = { workspace = true } atuin-client = { workspace = true } atuin-common = { workspace = true } +atuin-daemon = { workspace = true } tokio = { workspace = true } eyre = { workspace = true } clap = { workspace = true, features = ["derive", "env"] } diff --git a/crates/atuin-ai/src/commands/inline.rs b/crates/atuin-ai/src/commands/inline.rs index 989b95c0..6d1f9c51 100644 --- a/crates/atuin-ai/src/commands/inline.rs +++ b/crates/atuin-ai/src/commands/inline.rs @@ -67,7 +67,7 @@ pub(crate) async fn run( settings.ai.opening.send_cwd.unwrap_or(false) || settings.ai.send_cwd.unwrap_or(false); let last_command = if settings.ai.opening.send_last_command.unwrap_or(false) { - history_db.last().await.ok().flatten().map(|h| h.command) + history_db.last().await.ok().flatten() } else { None }; @@ -84,6 +84,7 @@ pub(crate) async fn run( history_db: std::sync::Arc::new(history_db), git_root, capabilities: settings.ai.capabilities.clone(), + daemon_enabled: settings.daemon.enabled, }; let action = run_inline_tui(ctx, initial_command, settings).await?; diff --git a/crates/atuin-ai/src/context.rs b/crates/atuin-ai/src/context.rs index 625de0c6..f891a9fc 100644 --- a/crates/atuin-ai/src/context.rs +++ b/crates/atuin-ai/src/context.rs @@ -2,6 +2,7 @@ use std::path::PathBuf; use std::sync::Arc; use atuin_client::distro::detect_linux_distribution; +use atuin_client::history::History; use atuin_client::settings::AiCapabilities; /// Session-scoped context for the AI chat session. 
@@ -11,12 +12,17 @@ pub(crate) struct AppContext { pub endpoint: String, pub token: String, pub send_cwd: bool, - pub last_command: Option<String>, + pub last_command: Option<History>, pub history_db: Arc<atuin_client::database::Sqlite>, /// Git root of the current working directory, if inside a git repo. /// Resolves through worktrees to the main repo root. pub git_root: Option<PathBuf>, pub capabilities: AiCapabilities, + pub daemon_enabled: bool, +} + +pub(crate) fn history_output_capability_available(daemon_enabled: bool) -> bool { + cfg!(feature = "daemon") && daemon_enabled } impl AppContext { @@ -33,6 +39,11 @@ impl AppContext { if self.capabilities.enable_command_execution.unwrap_or(true) { caps.push("client_v1_execute_shell_command".to_string()); } + if history_output_capability_available(self.daemon_enabled) + && self.capabilities.enable_history_output.unwrap_or(true) + { + caps.push("client_v1_atuin_output".to_string()); + } caps.push("client_v1_load_skill".to_string()); if let Ok(extra) = std::env::var("ATUIN_AI__ADDITIONAL_CAPS") { caps.extend( @@ -69,7 +80,11 @@ impl ClientContext { /// Serialize to the JSON format the API expects for the "context" field. /// The `pwd` field is always dynamic (current working directory), so it's /// computed fresh on each call if `send_cwd` is true. 
- pub(crate) fn to_json(&self, send_cwd: bool, last_command: Option<&str>) -> serde_json::Value { + pub(crate) fn to_json( + &self, + send_cwd: bool, + last_command: Option<&History>, + ) -> serde_json::Value { let mut ctx = serde_json::json!({ "os": self.os, "shell": self.shell, @@ -78,9 +93,15 @@ impl ClientContext { } else { None }, - "last_command": last_command, }); + if let Some(history) = last_command { + ctx["last_command"] = serde_json::json!(crate::history_format::format_last_command( + history, + crate::history_format::current_local_offset(), + )); + } + if let Some(ref distro) = self.distro { ctx["distro"] = serde_json::json!(distro); } diff --git a/crates/atuin-ai/src/driver.rs b/crates/atuin-ai/src/driver.rs index ddb839b7..82d666ef 100644 --- a/crates/atuin-ai/src/driver.rs +++ b/crates/atuin-ai/src/driver.rs @@ -492,6 +492,7 @@ fn execute_effect(effect: &Effect, ctx: DriverContext) { messages.clone(), session_id.clone(), &app.capabilities, + app.daemon_enabled, fsm.ctx.invocation_id.clone(), ); tokio::spawn(async move { @@ -570,7 +571,6 @@ fn execute_effect(effect: &Effect, ctx: DriverContext) { Effect::ExecuteTool { tool_id, tool } => { let tool_id = tool_id.clone(); - let tool = tool.clone(); let tx = tx.clone(); let db = io.app_ctx.history_db.clone(); @@ -731,8 +731,9 @@ fn execute_effect(effect: &Effect, ctx: DriverContext) { preview: None, })); } - ClientToolCall::AtuinHistory(_) => { + ClientToolCall::AtuinHistory(tool) => { // History search needs async DB access + let tool = tool.clone(); tokio::spawn(async move { let outcome = tool.execute(&db).await; let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone { @@ -742,6 +743,17 @@ fn execute_effect(effect: &Effect, ctx: DriverContext) { })); }); } + ClientToolCall::AtuinOutput(tool) => { + let tool = tool.clone(); + tokio::spawn(async move { + let outcome = tool.execute().await; + let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone { + tool_id, + outcome, + preview: None, + })); + }); 
+ } ClientToolCall::LoadSkill(skill_call) => { let skill_name = skill_call.name.clone(); let registry = io.skill_registry.clone(); diff --git a/crates/atuin-ai/src/history_format.rs b/crates/atuin-ai/src/history_format.rs new file mode 100644 index 00000000..24aa963e --- /dev/null +++ b/crates/atuin-ai/src/history_format.rs @@ -0,0 +1,120 @@ +use atuin_client::history::History; +use time::UtcOffset; + +pub(crate) fn current_local_offset() -> UtcOffset { + UtcOffset::current_local_offset().unwrap_or(UtcOffset::UTC) +} + +pub(crate) fn format_last_command(history: &History, local_offset: UtcOffset) -> String { + format!( + "History ID: {} - `{}`\n{}", + history.id, + history.command, + format_history_metadata(history, local_offset) + ) +} + +pub(crate) fn format_history_search_result( + ordinal: usize, + history: &History, + local_offset: UtcOffset, +) -> String { + format!( + "## #{}. (History ID: {}):\n`{}`\n{}\n", + ordinal, + history.id, + history.command, + format_history_metadata(history, local_offset) + ) +} + +fn format_history_metadata(history: &History, local_offset: UtcOffset) -> String { + format!( + "[{}] (in `{}`, exit {}){}", + format_timestamp(history, local_offset), + history.cwd, + history.exit, + format_duration(history.duration) + ) +} + +fn format_timestamp(history: &History, local_offset: UtcOffset) -> String { + let ts = history.timestamp.to_offset(local_offset); + format!( + "{:04}-{:02}-{:02} {:02}:{:02}:{:02}", + ts.year(), + ts.month() as u8, + ts.day(), + ts.hour(), + ts.minute(), + ts.second(), + ) +} + +fn format_duration(nanos: i64) -> String { + if nanos <= 0 { + return String::new(); + } + + let total_secs = nanos / 1_000_000_000; + let millis = (nanos % 1_000_000_000) / 1_000_000; + + if total_secs >= 3600 { + let hours = total_secs / 3600; + let mins = (total_secs % 3600) / 60; + let secs = total_secs % 60; + format!(", {hours}h{mins}m{secs}s") + } else if total_secs >= 60 { + let mins = total_secs / 60; + let secs = total_secs % 
60; + format!(", {mins}m{secs}s") + } else if total_secs > 0 { + if millis > 0 { + format!(", {total_secs}.{millis:03}s") + } else { + format!(", {total_secs}s") + } + } else { + format!(", {millis}ms") + } +} + +#[cfg(test)] +mod tests { + use atuin_client::history::{History, HistoryId}; + use time::{OffsetDateTime, UtcOffset}; + + use super::*; + + fn history(duration: i64) -> History { + History { + id: HistoryId("018f011c-9a0a-7000-8000-000000000001".to_string()), + timestamp: OffsetDateTime::UNIX_EPOCH, + duration, + exit: 2, + command: "cargo test".to_string(), + cwd: "/repo".to_string(), + session: String::new(), + hostname: String::new(), + author: String::new(), + intent: None, + deleted_at: None, + } + } + + #[test] + fn formats_last_command() { + assert_eq!( + format_last_command(&history(1_234_000_000), UtcOffset::UTC), + "History ID: 018f011c-9a0a-7000-8000-000000000001 - `cargo test`\n[1970-01-01 00:00:00] (in `/repo`, exit 2), 1.234s" + ); + } + + #[test] + fn formats_history_search_result() { + assert_eq!( + format_history_search_result(3, &history(0), UtcOffset::UTC), + "## #3. 
(History ID: 018f011c-9a0a-7000-8000-000000000001):\n`cargo test`\n[1970-01-01 00:00:00] (in `/repo`, exit 2)\n" + ); + } +} diff --git a/crates/atuin-ai/src/lib.rs b/crates/atuin-ai/src/lib.rs index b3587739..f972d4ff 100644 --- a/crates/atuin-ai/src/lib.rs +++ b/crates/atuin-ai/src/lib.rs @@ -7,6 +7,7 @@ pub(crate) mod edit_permissions; pub(crate) mod event_serde; pub(crate) mod file_tracker; pub(crate) mod fsm; +pub(crate) mod history_format; pub(crate) mod permissions; pub(crate) mod session; pub(crate) mod skills; diff --git a/crates/atuin-ai/src/permissions/check.rs b/crates/atuin-ai/src/permissions/check.rs index 96abc3ab..bb1eae0c 100644 --- a/crates/atuin-ai/src/permissions/check.rs +++ b/crates/atuin-ai/src/permissions/check.rs @@ -1,13 +1,13 @@ use eyre::Result; -use crate::{permissions::file::RuleFile, tools::PermissableToolCall}; +use crate::{permissions::file::RuleFile, tools::PermissibleToolCall}; pub(crate) struct PermissionRequest<'t> { - call: &'t (dyn PermissableToolCall + Send + Sync), + call: &'t (dyn PermissibleToolCall + Send + Sync), } impl<'t> PermissionRequest<'t> { - pub fn new(call: &'t (dyn PermissableToolCall + Send + Sync)) -> Self { + pub fn new(call: &'t (dyn PermissibleToolCall + Send + Sync)) -> Self { Self { call } } } diff --git a/crates/atuin-ai/src/stream.rs b/crates/atuin-ai/src/stream.rs index 084e8238..e78dc2e1 100644 --- a/crates/atuin-ai/src/stream.rs +++ b/crates/atuin-ai/src/stream.rs @@ -2,7 +2,10 @@ // SSE streaming // ─────────────────────────────────────────────────────────────────── +use atuin_client::history::History; use atuin_client::settings::AiCapabilities; + +use crate::context::history_output_capability_available; use atuin_common::tls::ensure_crypto_provider; use eventsource_stream::Eventsource; @@ -61,6 +64,7 @@ impl ChatRequest { messages: Vec<serde_json::Value>, session_id: Option<String>, capabilities: &AiCapabilities, + history_output_available: bool, invocation_id: String, ) -> Self { let mut caps = 
vec![ @@ -78,6 +82,11 @@ impl ChatRequest { if capabilities.enable_command_execution.unwrap_or(true) { caps.push("client_v1_execute_shell_command".to_string()); } + if history_output_capability_available(history_output_available) + && capabilities.enable_history_output.unwrap_or(true) + { + caps.push("client_v1_atuin_output".to_string()); + } if let Ok(extra) = std::env::var("ATUIN_AI__ADDITIONAL_CAPS") { caps.extend( extra @@ -103,7 +112,7 @@ pub(crate) fn create_chat_stream( request: ChatRequest, client_ctx: ClientContext, send_cwd: bool, - last_command: Option<String>, + last_command: Option<History>, user_contexts: Vec<crate::user_context::UserContext>, skill_summaries: Vec<crate::skills::SkillSummary>, skill_overflow: Option<String>, @@ -120,7 +129,7 @@ pub(crate) fn create_chat_stream( tracing::debug!("Sending SSE request to {endpoint}"); - let context = client_ctx.to_json(send_cwd, last_command.as_deref()); + let context = client_ctx.to_json(send_cwd, last_command.as_ref()); let mut config = serde_json::json!({ "capabilities": request.capabilities, diff --git a/crates/atuin-ai/src/tools/descriptor.rs b/crates/atuin-ai/src/tools/descriptor.rs index 06858bf8..4190540c 100644 --- a/crates/atuin-ai/src/tools/descriptor.rs +++ b/crates/atuin-ai/src/tools/descriptor.rs @@ -67,6 +67,15 @@ pub(crate) const ATUIN_HISTORY: &ToolDescriptor = &ToolDescriptor { is_client: true, }; +pub(crate) const ATUIN_OUTPUT: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["atuin_output"], + capability: Some("client_v1_atuin_output"), + display_verb: "view the output for command", + progressive_verb: "Viewing output...", + past_verb: "Viewed output", + is_client: true, +}; + pub(crate) const LOAD_SKILL: &ToolDescriptor = &ToolDescriptor { canonical_names: &["load_skill"], capability: Some("client_v1_load_skill"), @@ -104,6 +113,7 @@ const ALL_DESCRIPTORS: &[&ToolDescriptor] = &[ WRITE, SHELL, ATUIN_HISTORY, + ATUIN_OUTPUT, LOAD_SKILL, SERVER_SEARCH, SERVER_SCRAPE, diff --git 
a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs index fdda10a4..d1352661 100644 --- a/crates/atuin-ai/src/tools/mod.rs +++ b/crates/atuin-ai/src/tools/mod.rs @@ -5,6 +5,7 @@ use std::{ }; use eyre::Result; +use uuid::Uuid; const DEFAULT_FILE_READ_LINES: u64 = 100; const MAX_FILE_READ_LINES: u64 = 1000; @@ -158,6 +159,7 @@ pub(crate) enum ClientToolCall { Write(WriteToolCall), Shell(ShellToolCall), AtuinHistory(AtuinHistoryToolCall), + AtuinOutput(AtuinOutputToolCall), LoadSkill(LoadSkillToolCall), } @@ -173,6 +175,9 @@ impl TryFrom<(&str, &serde_json::Value)> for ClientToolCall { "atuin_history" => Ok(ClientToolCall::AtuinHistory( AtuinHistoryToolCall::try_from(input)?, )), + "atuin_output" => Ok(ClientToolCall::AtuinOutput(AtuinOutputToolCall::try_from( + input, + )?)), "load_skill" => Ok(ClientToolCall::LoadSkill(LoadSkillToolCall::try_from( input, )?)), @@ -189,6 +194,7 @@ impl ClientToolCall { ClientToolCall::Write(_) => descriptor::WRITE, ClientToolCall::Shell(_) => descriptor::SHELL, ClientToolCall::AtuinHistory(_) => descriptor::ATUIN_HISTORY, + ClientToolCall::AtuinOutput(_) => descriptor::ATUIN_OUTPUT, ClientToolCall::LoadSkill(_) => descriptor::LOAD_SKILL, } } @@ -205,6 +211,7 @@ impl ClientToolCall { ClientToolCall::Write(_) => "Write", ClientToolCall::Shell(_) => "Shell", ClientToolCall::AtuinHistory(_) => "AtuinHistory", + ClientToolCall::AtuinOutput(_) => "AtuinOutput", ClientToolCall::LoadSkill(_) => "LoadSkill", } } @@ -218,6 +225,7 @@ impl ClientToolCall { ClientToolCall::Write(tool) => Some(tool.resolved_path()), ClientToolCall::Shell(_) | ClientToolCall::AtuinHistory(_) + | ClientToolCall::AtuinOutput(_) | ClientToolCall::LoadSkill(_) => None, } } @@ -229,6 +237,7 @@ impl ClientToolCall { ClientToolCall::Write(tool) => tool.matches_rule(rule), ClientToolCall::Shell(tool) => tool.matches_rule(rule), ClientToolCall::AtuinHistory(tool) => tool.matches_rule(rule), + ClientToolCall::AtuinOutput(tool) => tool.matches_rule(rule), 
ClientToolCall::LoadSkill(tool) => tool.matches_rule(rule), } } @@ -240,26 +249,14 @@ impl ClientToolCall { ClientToolCall::Write(tool) => tool.target_dir(), ClientToolCall::Shell(tool) => tool.target_dir(), ClientToolCall::AtuinHistory(tool) => tool.target_dir(), + ClientToolCall::AtuinOutput(tool) => tool.target_dir(), ClientToolCall::LoadSkill(tool) => tool.target_dir(), } } - - /// Execute this client-side tool and return the result. - pub async fn execute(&self, db: &atuin_client::database::Sqlite) -> ToolOutcome { - match self { - ClientToolCall::Read(tool) => tool.execute(), - ClientToolCall::AtuinHistory(tool) => tool.execute(db).await, - // LoadSkill is handled separately by the driver (needs registry access) - ClientToolCall::LoadSkill(_) => { - ToolOutcome::Error("LoadSkill must be executed via the driver".to_string()) - } - _ => ToolOutcome::Error("Client-side tool execution not yet implemented".to_string()), - } - } } /// A trait for tool calls that can be checked against permission rules. -pub(crate) trait PermissableToolCall { +pub(crate) trait PermissibleToolCall { /// Checks if this tool call matches the given permission rule. 
fn matches_rule(&self, rule: &Rule) -> bool; @@ -277,7 +274,7 @@ pub(crate) trait PermissableToolCall { } } -impl PermissableToolCall for ClientToolCall { +impl PermissibleToolCall for ClientToolCall { fn matches_rule(&self, rule: &Rule) -> bool { self.matches_rule(rule) } @@ -416,7 +413,7 @@ impl ReadToolCall { } } -impl PermissableToolCall for ReadToolCall { +impl PermissibleToolCall for ReadToolCall { fn target_dir(&self) -> Option<&Path> { Some(&self.path) } @@ -616,7 +613,7 @@ impl EditToolCall { } } -impl PermissableToolCall for EditToolCall { +impl PermissibleToolCall for EditToolCall { fn target_dir(&self) -> Option<&Path> { Some(&self.path) } @@ -724,7 +721,7 @@ impl WriteToolCall { } } -impl PermissableToolCall for WriteToolCall { +impl PermissibleToolCall for WriteToolCall { fn target_dir(&self) -> Option<&Path> { Some(&self.path) } @@ -792,7 +789,7 @@ impl TryFrom<&serde_json::Value> for ShellToolCall { } } -impl PermissableToolCall for ShellToolCall { +impl PermissibleToolCall for ShellToolCall { fn target_dir(&self) -> Option<&Path> { self.dir.as_deref() } @@ -1134,7 +1131,7 @@ impl TryFrom<&serde_json::Value> for AtuinHistoryToolCall { } } -impl PermissableToolCall for AtuinHistoryToolCall { +impl PermissibleToolCall for AtuinHistoryToolCall { fn target_dir(&self) -> Option<&Path> { None } @@ -1148,7 +1145,6 @@ impl AtuinHistoryToolCall { pub(crate) async fn execute(&self, db: &atuin_client::database::Sqlite) -> ToolOutcome { use atuin_client::database::{self, Database as _, OptFilters}; use atuin_client::settings::SearchMode; - use time::UtcOffset; let context = match database::current_context().await { Ok(ctx) => ctx, @@ -1184,34 +1180,13 @@ impl AtuinHistoryToolCall { return ToolOutcome::Success("No matching history entries found.".to_string()); } - let local_offset = UtcOffset::current_local_offset().unwrap_or(UtcOffset::UTC); + let local_offset = crate::history_format::current_local_offset(); let formatted: Vec<String> = results .iter() 
.enumerate() - .map(|(i, h)| { - let ts = h.timestamp.to_offset(local_offset); - let time_str = format!( - "{:04}-{:02}-{:02} {:02}:{:02}:{:02}", - ts.year(), - ts.month() as u8, - ts.day(), - ts.hour(), - ts.minute(), - ts.second(), - ); - - let duration_str = format_duration(h.duration); - - format!( - "{}. `{}` [{}] ({}, exit: {}){}", - i + 1, - h.command, - time_str, - h.cwd, - h.exit, - duration_str, - ) + .map(|(i, history)| { + crate::history_format::format_history_search_result(i + 1, history, local_offset) }) .collect(); @@ -1220,6 +1195,146 @@ impl AtuinHistoryToolCall { } #[derive(Debug, Clone)] +pub(crate) struct AtuinOutputToolCall { + pub history_id: Uuid, + pub ranges: Vec<(i64, i64)>, +} + +impl TryFrom<&serde_json::Value> for AtuinOutputToolCall { + type Error = eyre::Error; + + fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> { + let history_id = value + .get("history_id") + .and_then(|v| v.as_str()) + .and_then(|v| Uuid::parse_str(v).ok()) + .ok_or(eyre::eyre!("Missing or invalid history ID"))?; + + let ranges = value + .get("ranges") + .and_then(|v| v.as_array()) + .map(Vec::as_slice) + .unwrap_or(&[]); + + let ranges = ranges + .iter() + .map(|r| { + let range = r + .as_array() + .filter(|a| a.len() == 2) + .ok_or_else(|| eyre::eyre!("Each range must be a [start, end] array"))?; + + let start = range[0] + .as_i64() + .ok_or_else(|| eyre::eyre!("Range start must be an integer"))?; + let end = range[1] + .as_i64() + .ok_or_else(|| eyre::eyre!("Range end must be an integer"))?; + + Ok((start, end)) + }) + .collect::<Result<Vec<(i64, i64)>, eyre::Error>>()?; + + Ok(Self { history_id, ranges }) + } +} + +impl PermissibleToolCall for AtuinOutputToolCall { + fn target_dir(&self) -> Option<&Path> { + None + } + + fn matches_rule(&self, rule: &Rule) -> bool { + rule.tool == "AtuinOutput" + } +} + +fn format_output_lines_for_llm(lines: &[atuin_daemon::semantic::OutputLine]) -> String { + let width = lines + .iter() + .map(|line| 
line.line_number) + .max() + .unwrap_or(1) + .max(1) + .ilog10() as usize + + 1; + let mut formatted = Vec::with_capacity(lines.len()); + let mut previous_line_number = None; + + for line in lines { + if let Some(previous) = previous_line_number { + let skipped = line.line_number.saturating_sub(previous + 1); + if skipped > 0 { + formatted.push(format!("[...skipped {skipped} lines...]")); + } + } + + formatted.push(format!("{:>width$}\t{}", line.line_number, line.content)); + previous_line_number = Some(line.line_number); + } + + formatted.join("\n") +} + +impl AtuinOutputToolCall { + pub(crate) async fn execute(&self) -> ToolOutcome { + let settings = match atuin_client::settings::Settings::new() { + Ok(settings) => settings, + Err(e) => return ToolOutcome::Error(format!("Failed to load Atuin settings: {e}")), + }; + + let mut client = match atuin_daemon::SemanticClient::from_settings(&settings).await { + Ok(client) => client, + Err(e) => return ToolOutcome::Error(format!("Failed to connect to Atuin daemon: {e}")), + }; + + let history_id = self.history_id.as_simple().to_string(); + let response = match client + .command_output(history_id.clone(), self.ranges.clone()) + .await + { + Ok(response) => response, + Err(e) => return ToolOutcome::Error(format!("Failed to fetch command output: {e}")), + }; + + if !response.found { + return ToolOutcome::Success(format!( + "No captured output found for history ID {history_id}." + )); + } + + if response.total_lines == 0 { + return ToolOutcome::Success(format!( + "Captured output for history ID {history_id} is empty." + )); + } + + let output = format_output_lines_for_llm(&response.lines); + if output.is_empty() { + return ToolOutcome::Success(format!( + "No lines selected from captured output for history ID {history_id}." 
+ )); + } + + let total_output = if response.output_truncated { + format!( + "{} bytes captured, {} bytes observed before truncation, {} lines", + response.total_bytes, response.output_observed_bytes, response.total_lines + ) + } else { + format!( + "{} bytes, {} lines", + response.total_bytes, response.total_lines + ) + }; + + ToolOutcome::Success(format!( + "History ID: {history_id}\nTotal output: {total_output}\nSelected output:\n{output}" + )) + } +} + +#[derive(Debug, Clone)] pub(crate) struct LoadSkillToolCall { pub name: String, } @@ -1239,7 +1354,7 @@ impl TryFrom<&serde_json::Value> for LoadSkillToolCall { } } -impl PermissableToolCall for LoadSkillToolCall { +impl PermissibleToolCall for LoadSkillToolCall { fn target_dir(&self) -> Option<&Path> { None } @@ -1286,6 +1401,52 @@ mod tests { // ── Cross-platform tests ── #[test] + fn atuin_output_ranges_are_optional() { + let input = serde_json::json!({ + "history_id": "018f0000000070008000000000000000" + }); + + let call = AtuinOutputToolCall::try_from(&input).unwrap(); + + assert_eq!( + call.history_id.as_simple().to_string(), + "018f0000000070008000000000000000" + ); + assert!(call.ranges.is_empty()); + } + + #[test] + fn atuin_output_parses_line_ranges() { + let input = serde_json::json!({ + "history_id": "018f0000000070008000000000000000", + "ranges": [[0, 30], [-100, -1]] + }); + + let call = AtuinOutputToolCall::try_from(&input).unwrap(); + + assert_eq!(call.ranges, vec![(0, 30), (-100, -1)]); + } + + #[test] + fn atuin_output_formats_lines_like_read_file() { + let lines = vec![ + atuin_daemon::semantic::OutputLine { + line_number: 98, + content: "near end".to_string(), + }, + atuin_daemon::semantic::OutputLine { + line_number: 100, + content: "end".to_string(), + }, + ]; + + assert_eq!( + format_output_lines_for_llm(&lines), + " 98\tnear end\n[...skipped 1 lines...]\n100\tend" + ); + } + + #[test] fn no_scope_matches_everything() { assert!(read_tool("any/path.txt").matches_rule(&read_rule(None))); 
assert!(write_tool("any/path.txt").matches_rule(&write_rule(None))); @@ -1996,31 +2157,3 @@ mod tests { } } } - -fn format_duration(nanos: i64) -> String { - if nanos <= 0 { - return String::new(); - } - - let total_secs = nanos / 1_000_000_000; - let millis = (nanos % 1_000_000_000) / 1_000_000; - - if total_secs >= 3600 { - let hours = total_secs / 3600; - let mins = (total_secs % 3600) / 60; - let secs = total_secs % 60; - format!(", {hours}h{mins}m{secs}s") - } else if total_secs >= 60 { - let mins = total_secs / 60; - let secs = total_secs % 60; - format!(", {mins}m{secs}s") - } else if total_secs > 0 { - if millis > 0 { - format!(", {total_secs}.{millis:03}s") - } else { - format!(", {total_secs}s") - } - } else { - format!(", {millis}ms") - } -} diff --git a/crates/atuin-ai/src/tui/view/mod.rs b/crates/atuin-ai/src/tui/view/mod.rs index 73dc2ad7..b594cedf 100644 --- a/crates/atuin-ai/src/tui/view/mod.rs +++ b/crates/atuin-ai/src/tui/view/mod.rs @@ -168,6 +168,7 @@ fn tool_call_view(tool_call: &crate::fsm::tools::TrackedTool, in_git_project: bo ClientToolCall::Write(tool) => tool.path.display().to_string(), ClientToolCall::Shell(tool) => tool.command.clone(), ClientToolCall::AtuinHistory(tool) => tool.query.clone(), + ClientToolCall::AtuinOutput(tool) => tool.history_id.to_string(), ClientToolCall::LoadSkill(tool) => format!("skill: {}", tool.name), }; diff --git a/crates/atuin-ai/src/tui/view/turn.rs b/crates/atuin-ai/src/tui/view/turn.rs index c74395b8..aa1f55fa 100644 --- a/crates/atuin-ai/src/tui/view/turn.rs +++ b/crates/atuin-ai/src/tui/view/turn.rs @@ -495,6 +495,7 @@ impl<'a> TurnBuilder<'a> { query: history.query.clone(), filter_modes: history.filter_modes.clone(), }, + ClientToolCall::AtuinOutput(_) => ToolRenderData::Remote, ClientToolCall::LoadSkill(skill) => ToolRenderData::SkillLoad { _name: skill.name.clone(), }, diff --git a/crates/atuin-client/src/settings.rs b/crates/atuin-client/src/settings.rs index 4df404c4..1be6f363 100644 --- 
a/crates/atuin-client/src/settings.rs +++ b/crates/atuin-client/src/settings.rs @@ -687,6 +687,8 @@ pub struct Ai { pub struct AiCapabilities { /// Whether the AI can request to search Atuin history. `None` = unset (defaults to enabled, and the ai will ask for permission). pub enable_history_search: Option<bool>, + /// Whether the AI can request to view the stored output, if any, for Atuin history entries. `None` = unset (defaults to enabled, and the ai will ask for permission). + pub enable_history_output: Option<bool>, /// Whether the AI can request to read and write files. `None` = unset (defaults to enabled, and the ai will ask for permission). pub enable_file_tools: Option<bool>, /// Whether the AI can request to execute bash commands. `None` = unset (defaults to enabled, and the ai will ask for permission). diff --git a/crates/atuin-daemon/build.rs b/crates/atuin-daemon/build.rs index 7034aa04..7808a07b 100644 --- a/crates/atuin-daemon/build.rs +++ b/crates/atuin-daemon/build.rs @@ -7,6 +7,7 @@ fn main() -> std::io::Result<()> { "proto/history.proto", "proto/search.proto", "proto/control.proto", + "proto/semantic.proto", ]; let proto_include_dirs = ["proto"]; diff --git a/crates/atuin-daemon/proto/semantic.proto b/crates/atuin-daemon/proto/semantic.proto new file mode 100644 index 00000000..07e550c8 --- /dev/null +++ b/crates/atuin-daemon/proto/semantic.proto @@ -0,0 +1,47 @@ +syntax = "proto3"; +package semantic; + +service Semantic { + rpc RecordCommands(stream CommandCapture) returns (RecordCommandsReply); + rpc CommandOutput(CommandOutputRequest) returns (CommandOutputReply); +} + +message CommandCapture { + string prompt = 1; + string command = 2; + string output = 3; + optional int32 exit_code = 4; + optional string history_id = 5; + optional string session_id = 6; + bool output_truncated = 7; + uint64 output_observed_bytes = 8; +} + +message RecordCommandsReply { + uint64 accepted = 1; +} + +message CommandOutputRequest { + string history_id = 1; + 
repeated OutputRange ranges = 2; +} + +message OutputRange { + int64 start = 1; + int64 end = 2; +} + +message OutputLine { + uint64 line_number = 1; + string content = 2; +} + +message CommandOutputReply { + bool found = 1; + string output = 2; + uint64 total_bytes = 3; + uint64 total_lines = 4; + repeated OutputLine lines = 5; + bool output_truncated = 6; + uint64 output_observed_bytes = 7; +} diff --git a/crates/atuin-daemon/src/client.rs b/crates/atuin-daemon/src/client.rs index 5f4ce20f..c18e0e46 100644 --- a/crates/atuin-daemon/src/client.rs +++ b/crates/atuin-daemon/src/client.rs @@ -30,6 +30,10 @@ use crate::search::{ FilterMode as RpcFilterMode, SearchContext as RpcSearchContext, SearchRequest, SearchResponse, search_client::SearchClient as SearchServiceClient, }; +use crate::semantic::{ + CommandCapture, CommandOutputReply, CommandOutputRequest, OutputRange, RecordCommandsReply, + semantic_client::SemanticClient as SemanticServiceClient, +}; pub struct HistoryClient { client: HistoryServiceClient<Channel>, @@ -256,6 +260,92 @@ impl From<Context> for RpcSearchContext { } } +pub struct SemanticClient { + client: SemanticServiceClient<Channel>, +} + +impl SemanticClient { + #[cfg(unix)] + pub async fn new(path: String) -> Result<Self> { + let log_path = path.clone(); + let channel = Endpoint::try_from("http://atuin_local_daemon:0")? + .connect_with_connector(service_fn(move |_: Uri| { + let path = path.clone(); + + async move { + Ok::<_, std::io::Error>(TokioIo::new(UnixStream::connect(path.clone()).await?)) + } + })) + .await + .wrap_err_with(|| { + format!( + "failed to connect to local atuin daemon at {}. Is it running?", + &log_path + ) + })?; + + let client = SemanticServiceClient::new(channel); + + Ok(SemanticClient { client }) + } + + #[cfg(not(unix))] + pub async fn new(port: u64) -> Result<Self> { + let channel = Endpoint::try_from("http://atuin_local_daemon:0")? 
+ .connect_with_connector(service_fn(move |_: Uri| { + let url = format!("127.0.0.1:{port}"); + + async move { + Ok::<_, std::io::Error>(TokioIo::new(TcpStream::connect(url.clone()).await?)) + } + })) + .await + .wrap_err_with(|| { + format!( + "failed to connect to local atuin daemon at 127.0.0.1:{port}. Is it running?" + ) + })?; + + let client = SemanticServiceClient::new(channel); + + Ok(SemanticClient { client }) + } + + #[cfg(unix)] + pub async fn from_settings(settings: &Settings) -> Result<Self> { + Self::new(settings.daemon.socket_path.clone()).await + } + + #[cfg(not(unix))] + pub async fn from_settings(settings: &Settings) -> Result<Self> { + Self::new(settings.daemon.tcp_port).await + } + + pub async fn record_commands( + &mut self, + captures: Vec<CommandCapture>, + ) -> Result<RecordCommandsReply> { + let stream = tokio_stream::iter(captures); + Ok(self.client.record_commands(stream).await?.into_inner()) + } + + pub async fn command_output( + &mut self, + history_id: String, + ranges: Vec<(i64, i64)>, + ) -> Result<CommandOutputReply> { + let request = CommandOutputRequest { + history_id, + ranges: ranges + .into_iter() + .map(|(start, end)| OutputRange { start, end }) + .collect(), + }; + + Ok(self.client.command_output(request).await?.into_inner()) + } +} + // ============================================================================ // Control Client // ============================================================================ diff --git a/crates/atuin-daemon/src/components/mod.rs b/crates/atuin-daemon/src/components/mod.rs index 5950d5d5..447e31df 100644 --- a/crates/atuin-daemon/src/components/mod.rs +++ b/crates/atuin-daemon/src/components/mod.rs @@ -11,12 +11,15 @@ //! //! - [`history::HistoryComponent`]: Command history lifecycle management //! - [`search::SearchComponent`]: Fuzzy search over history +//! - [`semantic::SemanticComponent`]: In-memory semantic command captures //! 
- [`sync::SyncComponent`]: Cloud sync pub mod history; pub mod search; +pub mod semantic; pub mod sync; pub use history::HistoryComponent; pub use search::SearchComponent; +pub use semantic::SemanticComponent; pub use sync::SyncComponent; diff --git a/crates/atuin-daemon/src/components/semantic.rs b/crates/atuin-daemon/src/components/semantic.rs new file mode 100644 index 00000000..dff38fd3 --- /dev/null +++ b/crates/atuin-daemon/src/components/semantic.rs @@ -0,0 +1,900 @@ +//! Semantic command capture component. +//! +//! This is a prototype in-memory store for completed command captures emitted +//! by atuin-pty-proxy. It keeps recent captures per Atuin session and indexes +//! them by history ID for AI tool lookup. + +use std::collections::{HashMap, VecDeque}; +use std::fmt::{Display, Formatter}; +use std::sync::Arc; + +use atuin_client::history::{History, HistoryId}; +use eyre::Result; +use tokio::sync::Mutex; +use tonic::{Request, Response, Status, Streaming}; +use tracing::{Level, instrument}; + +use crate::{ + daemon::{Component, DaemonHandle}, + events::DaemonEvent, + semantic::{ + CommandCapture, CommandOutputReply, CommandOutputRequest, OutputLine, RecordCommandsReply, + semantic_server::{Semantic as SemanticSvc, SemanticServer}, + }, +}; + +const MAX_SESSIONS: usize = 20; +const MAX_COMMANDS_PER_SESSION: usize = 128; +const MAX_BYTES_PER_SESSION: usize = 32 * 1024 * 1024; +const MAX_PENDING_HISTORIES: usize = 128; + +/// Stores completed command captures and associates them with history events. 
+pub struct SemanticComponent { + inner: Arc<SemanticComponentInner>, +} + +struct SemanticComponentInner { + state: Mutex<SemanticState>, +} + +#[derive(Default)] +struct SemanticState { + sessions: HashMap<SessionId, SessionCaptures>, + session_lru: VecDeque<SessionId>, + history_index: HashMap<HistoryId, CaptureRef>, + pending_histories: VecDeque<History>, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct SessionId(String); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +struct CaptureId(u64); + +#[derive(Debug, Clone, PartialEq, Eq)] +struct CaptureRef { + session_id: SessionId, + capture_id: CaptureId, +} + +#[derive(Default)] +struct SessionCaptures { + next_id: u64, + records: VecDeque<StoredCapture>, + output_bytes: usize, +} + +struct StoredCapture { + id: CaptureId, + history_id: HistoryId, + output_bytes: usize, + record: SemanticCommandRecord, +} + +struct EvictedCapture { + history_id: HistoryId, + capture_id: CaptureId, +} + +#[derive(Debug, Clone)] +struct SemanticCommandRecord { + capture: CommandCapture, + history: Option<History>, +} + +impl SemanticComponent { + pub fn new() -> Self { + Self { + inner: Arc::new(SemanticComponentInner { + state: Mutex::new(SemanticState::default()), + }), + } + } + + pub fn grpc_service(&self) -> SemanticServer<SemanticGrpcService> { + SemanticServer::new(SemanticGrpcService { + inner: self.inner.clone(), + }) + } +} + +impl Default for SemanticComponent { + fn default() -> Self { + Self::new() + } +} + +#[tonic::async_trait] +impl Component for SemanticComponent { + fn name(&self) -> &'static str { + "semantic" + } + + async fn start(&mut self, _handle: DaemonHandle) -> Result<()> { + tracing::info!("semantic component started"); + Ok(()) + } + + async fn handle_event(&mut self, event: &DaemonEvent) -> Result<()> { + if let DaemonEvent::HistoryEnded(history) = event { + self.inner.record_history(history.clone()).await; + } + + Ok(()) + } + + async fn stop(&mut self) -> Result<()> { + let state = 
self.inner.state.lock().await; + tracing::info!( + sessions = state.sessions.len(), + records = state.record_count(), + indexed_histories = state.history_index.len(), + pending_histories = state.pending_histories.len(), + "semantic component stopped" + ); + Ok(()) + } +} + +impl SemanticComponentInner { + async fn record_capture(&self, capture: CommandCapture) -> bool { + let mut state = self.state.lock().await; + state.record_capture(capture) + } + + async fn record_history(&self, history: History) { + let mut state = self.state.lock().await; + state.record_history(history); + } + + async fn command_output(&self, request: &CommandOutputRequest) -> CommandOutputReply { + let mut state = self.state.lock().await; + state.command_output(request) + } +} + +impl SemanticState { + fn record_capture(&mut self, mut capture: CommandCapture) -> bool { + let Some(history_id) = history_id_from_str(capture.history_id.as_deref()) else { + tracing::debug!( + command_bytes = capture.command.len(), + prompt_bytes = capture.prompt.len(), + output_bytes = capture.output.len(), + output_truncated = capture.output_truncated, + "dropping semantic command capture without history id" + ); + return false; + }; + + let history = take_pending_history(&mut self.pending_histories, &history_id); + let Some(session_id) = capture + .session_id + .as_deref() + .and_then(|session_id| SessionId::try_from(session_id).ok()) + .or_else(|| { + history + .as_ref() + .and_then(|history| SessionId::try_from(history.session.as_str()).ok()) + }) + else { + tracing::debug!( + history_id = %history_id, + command_bytes = capture.command.len(), + prompt_bytes = capture.prompt.len(), + output_bytes = capture.output.len(), + output_truncated = capture.output_truncated, + "dropping semantic command capture without session id" + ); + return false; + }; + + capture.history_id = Some(history_id.to_string()); + capture.session_id = Some(session_id.to_string()); + if capture.output_observed_bytes == 0 { + 
capture.output_observed_bytes = capture.output.len() as u64; + } + + let record = SemanticCommandRecord { capture, history }; + log_record(&record, "recorded semantic command capture"); + self.push_record(session_id, history_id, record); + true + } + + fn record_history(&mut self, history: History) { + let history_id = history.id.clone(); + + if let Some(capture_ref) = self.history_index.get(&history_id).cloned() { + if let Some(stored) = self.stored_capture_mut(&capture_ref) { + stored.record.history = Some(history); + log_record( + &stored.record, + "associated semantic command capture with history", + ); + return; + } + + self.history_index.remove(&history_id); + } + + tracing::debug!( + id = %history.id, + command_bytes = history.command.len(), + "history ended before semantic capture arrived" + ); + push_pending_history(&mut self.pending_histories, history); + } + + fn command_output(&mut self, request: &CommandOutputRequest) -> CommandOutputReply { + let Some(history_id) = history_id_from_str(Some(&request.history_id)) else { + return command_output_not_found(); + }; + let Some(capture_ref) = self.history_index.get(&history_id).cloned() else { + return command_output_not_found(); + }; + + let Some(reply) = self.command_output_for_ref(&capture_ref, &request.ranges) else { + self.history_index.remove(&history_id); + return command_output_not_found(); + }; + + self.touch_session(&capture_ref.session_id); + reply + } + + fn command_output_for_ref( + &self, + capture_ref: &CaptureRef, + ranges: &[crate::semantic::OutputRange], + ) -> Option<CommandOutputReply> { + let stored = self + .sessions + .get(&capture_ref.session_id)? 
+ .stored_capture(capture_ref.capture_id)?; + let output = &stored.record.capture.output; + let output_observed_bytes = stored + .record + .capture + .output_observed_bytes + .max(output.len() as u64); + + Some(CommandOutputReply { + found: true, + output: String::new(), + total_bytes: output.len() as u64, + total_lines: output.lines().count() as u64, + lines: select_output_ranges(output, ranges), + output_truncated: stored.record.capture.output_truncated, + output_observed_bytes, + }) + } + + fn push_record( + &mut self, + session_id: SessionId, + history_id: HistoryId, + record: SemanticCommandRecord, + ) { + self.touch_session(&session_id); + + let (capture_id, evicted) = { + let session = self.sessions.entry(session_id.clone()).or_default(); + session.push(history_id.clone(), record) + }; + + let capture_ref = CaptureRef { + session_id: session_id.clone(), + capture_id, + }; + self.history_index.insert(history_id, capture_ref); + + for evicted in evicted { + self.remove_history_index_if_matches( + &session_id, + &evicted.history_id, + evicted.capture_id, + ); + } + + self.expire_lru_sessions(); + } + + fn touch_session(&mut self, session_id: &SessionId) { + if let Some(index) = self.session_lru.iter().position(|id| id == session_id) { + self.session_lru.remove(index); + } + self.session_lru.push_back(session_id.clone()); + } + + fn expire_lru_sessions(&mut self) { + while self.session_lru.len() > MAX_SESSIONS { + let Some(session_id) = self.session_lru.pop_front() else { + break; + }; + let Some(session) = self.sessions.remove(&session_id) else { + continue; + }; + + for stored in session.records { + self.remove_history_index_if_matches(&session_id, &stored.history_id, stored.id); + } + } + } + + fn remove_history_index_if_matches( + &mut self, + session_id: &SessionId, + history_id: &HistoryId, + capture_id: CaptureId, + ) { + if self + .history_index + .get(history_id) + .is_some_and(|capture_ref| { + &capture_ref.session_id == session_id && 
capture_ref.capture_id == capture_id + }) + { + self.history_index.remove(history_id); + } + } + + fn stored_capture_mut(&mut self, capture_ref: &CaptureRef) -> Option<&mut StoredCapture> { + self.sessions + .get_mut(&capture_ref.session_id)? + .stored_capture_mut(capture_ref.capture_id) + } + + fn record_count(&self) -> usize { + self.sessions + .values() + .map(|session| session.records.len()) + .sum() + } +} + +impl SessionCaptures { + fn push( + &mut self, + history_id: HistoryId, + record: SemanticCommandRecord, + ) -> (CaptureId, Vec<EvictedCapture>) { + self.push_with_limits( + history_id, + record, + MAX_COMMANDS_PER_SESSION, + MAX_BYTES_PER_SESSION, + ) + } + + fn push_with_limits( + &mut self, + history_id: HistoryId, + record: SemanticCommandRecord, + max_commands: usize, + max_output_bytes: usize, + ) -> (CaptureId, Vec<EvictedCapture>) { + let capture_id = CaptureId(self.next_id); + self.next_id = self.next_id.saturating_add(1); + let output_bytes = record.capture.output.len(); + self.output_bytes = self.output_bytes.saturating_add(output_bytes); + self.records.push_back(StoredCapture { + id: capture_id, + history_id, + output_bytes, + record, + }); + + ( + capture_id, + self.evict_to_limits(max_commands, max_output_bytes), + ) + } + + fn evict_to_limits( + &mut self, + max_commands: usize, + max_output_bytes: usize, + ) -> Vec<EvictedCapture> { + let mut evicted = Vec::new(); + while self.records.len() > max_commands || self.output_bytes > max_output_bytes { + let Some(record) = self.records.pop_front() else { + break; + }; + self.output_bytes = self.output_bytes.saturating_sub(record.output_bytes); + evicted.push(EvictedCapture { + history_id: record.history_id, + capture_id: record.id, + }); + } + evicted + } + + fn stored_capture(&self, capture_id: CaptureId) -> Option<&StoredCapture> { + self.records.iter().find(|record| record.id == capture_id) + } + + fn stored_capture_mut(&mut self, capture_id: CaptureId) -> Option<&mut StoredCapture> { + 
self.records + .iter_mut() + .find(|record| record.id == capture_id) + } +} + +impl TryFrom<&str> for SessionId { + type Error = (); + + fn try_from(value: &str) -> std::result::Result<Self, Self::Error> { + let value = value.trim(); + if value.is_empty() { + return Err(()); + } + + Ok(Self(value.to_string())) + } +} + +impl TryFrom<String> for SessionId { + type Error = (); + + fn try_from(value: String) -> std::result::Result<Self, Self::Error> { + Self::try_from(value.as_str()) + } +} + +impl AsRef<str> for SessionId { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl Display for SessionId { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.0) + } +} + +pub struct SemanticGrpcService { + inner: Arc<SemanticComponentInner>, +} + +#[tonic::async_trait] +impl SemanticSvc for SemanticGrpcService { + #[instrument(skip_all, level = Level::INFO)] + async fn record_commands( + &self, + request: Request<Streaming<CommandCapture>>, + ) -> Result<Response<RecordCommandsReply>, Status> { + let mut stream = request.into_inner(); + let mut accepted = 0_u64; + + while let Some(capture) = stream.message().await? 
/// Shrinks `records` to at most `max_len` entries by discarding from the front.
///
/// The newest entries (those at the back) are kept; a deque that is already
/// short enough is left untouched.
fn trim_front<T>(records: &mut VecDeque<T>, max_len: usize) {
    let excess = records.len().saturating_sub(max_len);
    // Dropping the drain iterator removes (and drops) the leading `excess`
    // elements in front-to-back order.
    records.drain(..excess);
}
/// Resolves a requested line range against an output of `line_count` lines.
///
/// `start` and `end` are zero-based, inclusive line indices. A negative value
/// counts back from the end of the output (`-1` is the last line). After
/// resolution the range is clamped to the lines that actually exist.
///
/// Returns `None` when there are no lines, when the resolved range lies
/// entirely before or after the output, or when it is inverted
/// (`start > end`). Otherwise returns the clamped `(start, end)` pair.
fn normalize_line_range(start: i64, end: i64, line_count: usize) -> Option<(usize, usize)> {
    let line_count = i64::try_from(line_count).ok()?;

    // Negative indices are offsets from the end. `line_count` is non-negative,
    // so adding a negative index cannot overflow.
    let resolve = |index: i64| if index < 0 { line_count + index } else { index };
    let (start, end) = (resolve(start), resolve(end));

    if end < 0 || start >= line_count {
        return None;
    }

    let start = start.max(0);
    let end = end.min(line_count - 1);
    if start > end {
        return None;
    }

    // Both bounds now lie in `0..line_count`, so the conversions succeed; the
    // checked form keeps that invariant explicit instead of a silent `as` cast.
    Some((usize::try_from(start).ok()?, usize::try_from(end).ok()?))
}
record.capture.output_observed_bytes, + capture_exit_code = ?record.capture.exit_code, + history_exit = ?exit, + duration = ?duration, + author = ?author, + "{message}" + ); +} + +#[cfg(test)] +mod tests { + use super::*; + use time::OffsetDateTime; + + fn history(id: &str, session: &str, command: &str) -> History { + History { + id: HistoryId(id.to_string()), + timestamp: OffsetDateTime::UNIX_EPOCH, + duration: 0, + exit: 0, + command: command.to_string(), + cwd: String::new(), + session: session.to_string(), + hostname: String::new(), + author: String::new(), + intent: None, + deleted_at: None, + } + } + + fn capture(history_id: Option<&str>, session_id: Option<&str>, output: &str) -> CommandCapture { + CommandCapture { + prompt: String::new(), + command: String::new(), + output: output.to_string(), + exit_code: None, + history_id: history_id.map(str::to_string), + session_id: session_id.map(str::to_string), + output_truncated: false, + output_observed_bytes: output.len() as u64, + } + } + + fn command_output(state: &mut SemanticState, history_id: &str) -> CommandOutputReply { + state.command_output(&CommandOutputRequest { + history_id: history_id.to_string(), + ranges: Vec::new(), + }) + } + + fn output_line(line_number: u64, content: &str) -> OutputLine { + OutputLine { + line_number, + content: content.to_string(), + } + } + + #[test] + fn drops_capture_without_history_id() { + let mut state = SemanticState::default(); + + assert!(!state.record_capture(capture(None, Some("session-1"), "output"))); + assert!(!command_output(&mut state, "id-1").found); + assert_eq!(state.record_count(), 0); + } + + #[test] + fn stores_capture_by_session_and_history_id() { + let mut state = SemanticState::default(); + + assert!(state.record_capture(capture(Some("id-1"), Some("session-1"), "output"))); + + let reply = command_output(&mut state, "id-1"); + assert!(reply.found); + assert_eq!(reply.total_bytes, 6); + assert_eq!(reply.output_observed_bytes, 6); + 
assert_eq!(reply.lines, vec![output_line(1, "output")]); + } + + #[test] + fn uses_pending_history_session_when_capture_session_is_missing() { + let mut state = SemanticState::default(); + + state.record_history(history("id-1", "session-from-history", "cargo test")); + assert!(state.record_capture(capture(Some("id-1"), None, "output"))); + + assert!( + state + .sessions + .contains_key(&SessionId("session-from-history".to_string())) + ); + assert!(command_output(&mut state, "id-1").found); + } + + #[test] + fn associates_history_by_id_after_capture_arrives() { + let mut state = SemanticState::default(); + + assert!(state.record_capture(capture(Some("id-1"), Some("session-1"), "output"))); + state.record_history(history("id-1", "session-1", "different command")); + + let capture_ref = state + .history_index + .get(&HistoryId("id-1".to_string())) + .unwrap(); + let stored = state + .sessions + .get(&capture_ref.session_id) + .unwrap() + .stored_capture(capture_ref.capture_id) + .unwrap(); + assert!(stored.record.history.is_some()); + } + + #[test] + fn evicts_oldest_command_when_session_ring_is_full() { + let mut state = SemanticState::default(); + + for index in 0..=MAX_COMMANDS_PER_SESSION { + assert!(state.record_capture(capture( + Some(&format!("id-{index}")), + Some("session-1"), + "output", + ))); + } + + assert!(!command_output(&mut state, "id-0").found); + assert!(command_output(&mut state, &format!("id-{MAX_COMMANDS_PER_SESSION}")).found); + assert_eq!(state.record_count(), MAX_COMMANDS_PER_SESSION); + } + + #[test] + fn evicts_oldest_session_after_lru_limit() { + let mut state = SemanticState::default(); + + for index in 0..MAX_SESSIONS { + assert!(state.record_capture(capture( + Some(&format!("id-{index}")), + Some(&format!("session-{index}")), + "output", + ))); + } + assert!(command_output(&mut state, "id-0").found); + + assert!(state.record_capture(capture(Some("new-id"), Some("new-session"), "output",))); + + assert!(command_output(&mut state, 
"id-0").found); + assert!(!command_output(&mut state, "id-1").found); + assert!(command_output(&mut state, "new-id").found); + assert_eq!(state.sessions.len(), MAX_SESSIONS); + } + + #[test] + fn evicts_by_session_byte_limit() { + let mut session = SessionCaptures::default(); + let first_output = "x".repeat(10); + let second_output = "y"; + let (_, evicted_first) = session.push_with_limits( + HistoryId("first".to_string()), + SemanticCommandRecord { + capture: capture(Some("first"), Some("session-1"), &first_output), + history: None, + }, + MAX_COMMANDS_PER_SESSION, + 10, + ); + assert!(evicted_first.is_empty()); + + let (_, evicted_second) = session.push_with_limits( + HistoryId("second".to_string()), + SemanticCommandRecord { + capture: capture(Some("second"), Some("session-1"), second_output), + history: None, + }, + MAX_COMMANDS_PER_SESSION, + 10, + ); + + assert_eq!(evicted_second.len(), 1); + assert_eq!(evicted_second[0].history_id, HistoryId("first".to_string())); + assert_eq!(session.records.len(), 1); + assert_eq!(session.output_bytes, 1); + } + + #[test] + fn command_output_reports_truncation_metadata() { + let mut state = SemanticState::default(); + let mut capture = capture(Some("id-1"), Some("session-1"), "partial"); + capture.output_truncated = true; + capture.output_observed_bytes = 1024; + + assert!(state.record_capture(capture)); + + let reply = command_output(&mut state, "id-1"); + assert!(reply.output_truncated); + assert_eq!(reply.total_bytes, 7); + assert_eq!(reply.output_observed_bytes, 1024); + } + + #[test] + fn output_ranges_are_line_based_inclusive_and_support_negative_offsets() { + let output = "zero\none\ntwo\nthree\nfour"; + let ranges = vec![ + crate::semantic::OutputRange { start: 1, end: 2 }, + crate::semantic::OutputRange { start: -2, end: -1 }, + ]; + + assert_eq!( + select_output_ranges(output, &ranges), + vec![ + output_line(2, "one"), + output_line(3, "two"), + output_line(4, "three"), + output_line(5, "four"), + ] + ); + } + + 
#[test] + fn output_ranges_merge_overlaps_and_adjacent_ranges() { + let output = (0..100) + .map(|n| format!("line {n}")) + .collect::<Vec<_>>() + .join("\n"); + let ranges = vec![ + crate::semantic::OutputRange { start: 0, end: 100 }, + crate::semantic::OutputRange { + start: -100, + end: -1, + }, + ]; + + let selected = select_output_ranges(&output, &ranges); + + assert_eq!(selected.len(), 100); + assert_eq!(selected.first(), Some(&output_line(1, "line 0"))); + assert_eq!(selected.last(), Some(&output_line(100, "line 99"))); + } + + #[test] + fn output_ranges_can_leave_gaps_for_client_formatting() { + let output = "zero\none\ntwo\nthree\nfour"; + let ranges = vec![ + crate::semantic::OutputRange { start: 0, end: 1 }, + crate::semantic::OutputRange { start: 4, end: 4 }, + ]; + + assert_eq!( + select_output_ranges(output, &ranges), + vec![ + output_line(1, "zero"), + output_line(2, "one"), + output_line(5, "four"), + ] + ); + } + + #[test] + fn empty_output_ranges_default_to_first_thousand_lines() { + let output = (0..1001) + .map(|n| format!("line {n}")) + .collect::<Vec<_>>() + .join("\n"); + + let selected = select_output_ranges(&output, &[]); + + assert_eq!(selected.len(), 1000); + assert_eq!(selected.first(), Some(&output_line(1, "line 0"))); + assert_eq!(selected.last(), Some(&output_line(1000, "line 999"))); + } + + #[test] + fn output_ranges_skip_ranges_fully_outside_output() { + let output = "zero\none\ntwo"; + let ranges = vec![ + crate::semantic::OutputRange { start: 10, end: 20 }, + crate::semantic::OutputRange { + start: -20, + end: -10, + }, + ]; + + assert_eq!(select_output_ranges(output, &ranges), Vec::new()); + } +} diff --git a/crates/atuin-daemon/src/lib.rs b/crates/atuin-daemon/src/lib.rs index 84f808e4..27d3932b 100644 --- a/crates/atuin-daemon/src/lib.rs +++ b/crates/atuin-daemon/src/lib.rs @@ -10,6 +10,7 @@ pub mod daemon; pub mod events; pub mod history; pub mod search; +pub mod semantic; pub mod server; // Re-export core daemon types for 
convenience @@ -17,10 +18,10 @@ pub use daemon::{Component, Daemon, DaemonBuilder, DaemonHandle}; pub use events::DaemonEvent; // Re-export components -pub use components::{HistoryComponent, SearchComponent, SyncComponent}; +pub use components::{HistoryComponent, SearchComponent, SemanticComponent, SyncComponent}; // Re-export client helpers -pub use client::{ControlClient, emit_event, emit_event_with_settings}; +pub use client::{ControlClient, SemanticClient, emit_event, emit_event_with_settings}; /// Boot the daemon using the new component-based architecture. /// @@ -34,12 +35,14 @@ pub async fn boot( // Create the components let history_component = HistoryComponent::new(); let search_component = SearchComponent::new(); + let semantic_component = SemanticComponent::new(); let sync_component = SyncComponent::new(); // Get the gRPC services before moving components into the daemon // (The services share state with the components via Arc) let history_service = history_component.grpc_service(); let search_service = search_component.grpc_service(); + let semantic_service = semantic_component.grpc_service(); // Build the daemon let mut daemon = Daemon::builder(settings.clone()) @@ -47,6 +50,7 @@ pub async fn boot( .history_db(history_db) .component(history_component) .component(search_component) + .component(semantic_component) .component(sync_component) .build() .await?; @@ -93,6 +97,7 @@ pub async fn boot( settings, history_service, search_service, + semantic_service, control_service.into_server(), handle, ) diff --git a/crates/atuin-daemon/src/semantic/mod.rs b/crates/atuin-daemon/src/semantic/mod.rs new file mode 100644 index 00000000..c3511676 --- /dev/null +++ b/crates/atuin-daemon/src/semantic/mod.rs @@ -0,0 +1,3 @@ +//! Semantic command capture gRPC service types. 
+ +tonic::include_proto!("semantic"); diff --git a/crates/atuin-daemon/src/server.rs b/crates/atuin-daemon/src/server.rs index a11de612..b823cff2 100644 --- a/crates/atuin-daemon/src/server.rs +++ b/crates/atuin-daemon/src/server.rs @@ -2,10 +2,12 @@ use eyre::Result; use crate::components::history::HistoryGrpcService; use crate::components::search::SearchGrpcService; +use crate::components::semantic::SemanticGrpcService; use crate::control::{ControlService, control_server::ControlServer}; use crate::daemon::DaemonHandle; use crate::history::history_server::HistoryServer; use crate::search::search_server::SearchServer; +use crate::semantic::semantic_server::SemanticServer; use atuin_client::settings::Settings; @@ -18,6 +20,7 @@ pub async fn run_grpc_server( settings: Settings, history_service: HistoryServer<HistoryGrpcService>, search_service: SearchServer<SearchGrpcService>, + semantic_service: SemanticServer<SemanticGrpcService>, control_service: ControlServer<ControlService>, handle: DaemonHandle, ) -> Result<()> { @@ -101,6 +104,7 @@ pub async fn run_grpc_server( if let Err(e) = Server::builder() .add_service(history_service) .add_service(search_service) + .add_service(semantic_service) .add_service(control_service) .serve_with_incoming_shutdown(uds_stream, shutdown_signal) .await @@ -118,6 +122,7 @@ pub async fn run_grpc_server( settings: Settings, history_service: HistoryServer<HistoryGrpcService>, search_service: SearchServer<SearchGrpcService>, + semantic_service: SemanticServer<SemanticGrpcService>, control_service: ControlServer<ControlService>, handle: DaemonHandle, ) -> Result<()> { @@ -152,6 +157,7 @@ pub async fn run_grpc_server( if let Err(e) = Server::builder() .add_service(history_service) .add_service(search_service) + .add_service(semantic_service) .add_service(control_service) .serve_with_incoming_shutdown(tcp_stream, shutdown_signal) .await diff --git a/crates/atuin-pty-proxy/src/capture.rs b/crates/atuin-pty-proxy/src/capture.rs new file mode 
100644 index 00000000..6426035b --- /dev/null +++ b/crates/atuin-pty-proxy/src/capture.rs @@ -0,0 +1,467 @@ +use std::sync::Arc; +use std::sync::atomic::{AtomicU16, Ordering}; + +use crate::osc133::{Event, Params, Parser, Zone}; + +const HISTORY_ID_PARAM: &str = "history_id"; +const SESSION_ID_PARAM: &str = "session_id"; +const MAX_OUTPUT_CAPTURE_BYTES: usize = 1024 * 1024; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CommandCapture { + pub prompt: String, + pub command: String, + pub output: String, + pub exit_code: Option<i32>, + pub history_id: Option<String>, + pub session_id: Option<String>, + pub output_truncated: bool, + pub output_observed_bytes: u64, +} + +pub type CommandCaptureSink = Box<dyn Fn(CommandCapture) + Send + 'static>; + +#[derive(Default)] +struct CaptureBuffers { + prompt: Vec<u8>, + command: Vec<u8>, + output: Vec<u8>, + output_observed_bytes: u64, + output_truncated: bool, + exit_code: Option<i32>, + history_id: Option<String>, + session_id: Option<String>, +} + +pub(crate) struct CommandCaptureTracker { + parser: Parser, + zone: Zone, + buffers: CaptureBuffers, + cols: Arc<AtomicU16>, +} + +impl CommandCaptureTracker { + pub(crate) fn new(cols: Arc<AtomicU16>) -> Self { + Self { + parser: Parser::new(), + zone: Zone::Unknown, + buffers: CaptureBuffers::default(), + cols, + } + } + + pub(crate) fn push(&mut self, data: &[u8], mut on_capture: impl FnMut(CommandCapture)) { + let mut events = Vec::new(); + self.parser + .push_located(data, |located| events.push(located)); + + let mut start = 0; + for located in events { + let marker_start = located.start_offset.min(data.len()).max(start); + let offset = located.offset.min(data.len()); + self.append(&data[start..marker_start]); + self.handle_event(located.event, &located.params, &mut on_capture); + self.zone = located.zone; + start = offset; + } + + let append_end = self + .parser + .incomplete_osc_sequence_start() + .map_or(data.len(), |sequence_start| { + 
sequence_start.min(data.len()).max(start) + }); + if start < append_end { + self.append(&data[start..append_end]); + } + } + + fn append(&mut self, data: &[u8]) { + match self.zone { + Zone::Prompt => self.buffers.prompt.extend_from_slice(data), + Zone::Input => self.buffers.command.extend_from_slice(data), + Zone::Output => self.append_output(data), + Zone::Unknown => {} + } + } + + fn append_output(&mut self, data: &[u8]) { + self.buffers.output_observed_bytes = self + .buffers + .output_observed_bytes + .saturating_add(data.len() as u64); + + if self.buffers.output_truncated { + return; + } + + let remaining = MAX_OUTPUT_CAPTURE_BYTES.saturating_sub(self.buffers.output.len()); + let retained = data.len().min(remaining); + self.buffers.output_truncated = retained < data.len(); + + if retained > 0 { + self.buffers.output.extend_from_slice(&data[..retained]); + } + } + + fn handle_event( + &mut self, + event: Event, + params: &Params, + on_capture: &mut impl FnMut(CommandCapture), + ) { + match event { + Event::PromptStart => { + if self.zone != Zone::Prompt { + self.buffers = CaptureBuffers::default(); + } + } + Event::CommandStart | Event::CommandExecuted => {} + Event::CommandFinished { exit_code } => { + let Some(history_id) = params.get(HISTORY_ID_PARAM).map(str::to_owned) else { + return; + }; + + if exit_code.is_some() || self.buffers.exit_code.is_none() { + self.buffers.exit_code = exit_code; + } + self.buffers.history_id = Some(history_id); + self.buffers.session_id = params.get(SESSION_ID_PARAM).map(str::to_owned); + + if let Some(capture) = self.finish_capture() { + on_capture(capture); + } + } + } + } + + fn finish_capture(&mut self) -> Option<CommandCapture> { + let buffers = std::mem::take(&mut self.buffers); + let cols = self.cols.load(Ordering::Relaxed).max(1); + let prompt = render_plain_text(&buffers.prompt, cols); + let command = render_plain_text(&buffers.command, cols) + .trim_matches(|c| c == '\r' || c == '\n') + .to_string(); + let output = 
/// Tidies rendered screen text for storage.
///
/// Trailing whitespace is stripped from every line, blank lines at the end are
/// dropped, and the remaining lines are joined with `\n`. Blank lines in the
/// middle of the text are preserved.
fn normalize_screen_contents(contents: &str) -> String {
    let trimmed: Vec<&str> = contents.lines().map(|line| line.trim_end()).collect();

    // Keep everything up to and including the last non-empty line.
    let keep = trimmed
        .iter()
        .rposition(|line| !line.is_empty())
        .map_or(0, |last| last + 1);

    trimmed[..keep].join("\n")
}
assert_eq!(render_plain_text(b"echo hi", 80), "echo hi"); + } + + #[test] + fn text_cleaning_strips_ansi_and_terminal_controls() { + let text = render_plain_text( + b"\x1b[32mhi\x1b[0m\r\n% \r \r", + 80, + ); + + assert_eq!(text, "hi"); + assert_no_terminal_controls(&text); + } + + #[test] + fn text_cleaning_preserves_valid_utf8_after_backspace() { + let text = render_plain_text("🦀x\x08 \x08 crab".as_bytes(), 80); + + assert_eq!(text, "🦀 crab"); + assert_no_terminal_controls(&text); + } + + #[test] + fn command_text_replays_backspaces() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + let input = + b"\x1b]133;A\x07$ \x1b]133;B\x07e\x08echo hi\r\n\x1b]133;C\x07hi\r\n\x1b]133;D;0;history_id=hist;session_id=sess\x07\x1b]133;A\x07$ "; + tracker.push(input, |capture| captures.push(capture)); + + assert_eq!(captures.len(), 1); + assert_eq!(captures[0].command, "echo hi"); + assert_eq!(captures[0].output, "hi"); + assert_no_terminal_controls(&captures[0].command); + assert_no_terminal_controls(&captures[0].output); + } + + #[test] + fn captures_complete_command() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;A\x07$ \x1b]133;B\x07echo hi\r\n\x1b]133;C\x07hi\r\n\x1b]133;D;0;history_id=hist;session_id=sess\x07\x1b]133;A\x07$ ", + |capture| captures.push(capture), + ); + + assert_eq!( + captures, + vec![CommandCapture { + prompt: "$".to_string(), + command: "echo hi".to_string(), + output: "hi".to_string(), + exit_code: Some(0), + history_id: Some("hist".to_string()), + session_id: Some("sess".to_string()), + output_truncated: false, + output_observed_bytes: 4, + }] + ); + } + + #[test] + fn strips_ansi_and_split_markers() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push(b"\x1b]133;A\x07\x1b[32m%\x1b[0m ", |_| {}); + tracker.push(b"\x1b]133;B\x07ls\x1b]133;C", |_| {}); + tracker.push( + 
b"\x07\x1b[31mfile\x1b[0m\r\n\x1b]133;D;1;history_id=hist;session_id=sess\x07\x1b]133;A\x07% ", + |capture| { + captures.push(capture); + }, + ); + + assert_eq!( + captures, + vec![CommandCapture { + prompt: "%".to_string(), + command: "ls".to_string(), + output: "file".to_string(), + exit_code: Some(1), + history_id: Some("hist".to_string()), + session_id: Some("sess".to_string()), + output_truncated: false, + output_observed_bytes: 15, + }] + ); + } + + #[test] + fn duplicate_prompt_start_does_not_reset_prompt_capture() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;A\x07$ \x1b]133;A\x07continued \x1b]133;B\x07echo hi\r\n\x1b]133;C\x07hi\r\n\x1b]133;D;0;history_id=hist;session_id=sess\x07\x1b]133;A\x07$ ", + |capture| captures.push(capture), + ); + + assert_eq!(captures.len(), 1); + assert_eq!(captures[0].prompt, "$ continued"); + assert_eq!(captures[0].command, "echo hi"); + assert_eq!(captures[0].output, "hi"); + } + + #[test] + fn bare_finish_without_metadata_is_ignored() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push(b"\x1b]133;C\x07line one\r\n\x1b]133;D;0\x07", |capture| { + captures.push(capture); + }); + + tracker.push(b"\x1b]133;A\x07$ ", |capture| captures.push(capture)); + + assert!(captures.is_empty()); + } + + #[test] + fn bare_finish_before_metadata_in_same_push_ignored() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;C\x07line one\r\n\x1b]133;D;1\x07\x1b]133;D;0;history_id=018f;session_id=abcd\x07", + |capture| captures.push(capture), + ); + + assert_eq!(captures.len(), 1); + assert_eq!(captures[0].output, "line one"); + assert_eq!(captures[0].exit_code, Some(0)); + assert_eq!(captures[0].history_id.as_deref(), Some("018f")); + assert_eq!(captures[0].session_id.as_deref(), Some("abcd")); + } + + #[test] + fn metadata_arriving_after_bare_finish_across_pushes() { + let mut tracker = tracker(80); + let mut 
captures = Vec::new(); + + tracker.push(b"\x1b]133;C\x07line one\r\n\x1b]133;D;0\x07", |capture| { + captures.push(capture); + }); + tracker.push(b"\x1b]133;D;0;history_id=018f", |capture| { + captures.push(capture) + }); + + assert!(captures.is_empty()); + + tracker.push(b";session_id=abcd\x07", |capture| captures.push(capture)); + + assert_eq!(captures.len(), 1); + assert_eq!(captures[0].output, "line one"); + assert_eq!(captures[0].exit_code, Some(0)); + assert_eq!(captures[0].history_id.as_deref(), Some("018f")); + assert_eq!(captures[0].session_id.as_deref(), Some("abcd")); + } + + #[test] + fn split_finish_marker_is_not_counted_as_output() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;C\x07line one\r\n\x1b]133;D;0;history_id=018f", + |capture| { + captures.push(capture); + }, + ); + assert!(captures.is_empty()); + + tracker.push(b";session_id=abcd\x07", |capture| captures.push(capture)); + + assert_eq!(captures.len(), 1); + assert_eq!(captures[0].output, "line one"); + assert_eq!(captures[0].output_observed_bytes, 10); + } + + #[test] + fn captures_output_with_history_metadata_from_d_marker() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;C\x07line one\r\n\x1b]133;D;0;history_id=018f;session_id=abcd\x07", + |capture| captures.push(capture), + ); + + assert_eq!( + captures, + vec![CommandCapture { + prompt: String::new(), + command: String::new(), + output: "line one".to_string(), + exit_code: Some(0), + history_id: Some("018f".to_string()), + session_id: Some("abcd".to_string()), + output_truncated: false, + output_observed_bytes: 10, + }] + ); + } + + #[test] + fn output_capture_is_capped_and_reports_observed_bytes() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + let mut input = b"\x1b]133;C\x07".to_vec(); + input.extend(std::iter::repeat_n(b'x', MAX_OUTPUT_CAPTURE_BYTES + 10)); + 
input.extend_from_slice(b"\x1b]133;D;0;history_id=big;session_id=session-1\x07"); + + tracker.push(&input, |capture| captures.push(capture)); + + assert_eq!(captures.len(), 1); + assert!(captures[0].output_truncated); + assert_eq!( + captures[0].output_observed_bytes, + (MAX_OUTPUT_CAPTURE_BYTES + 10) as u64 + ); + } + + #[test] + fn resets_buffers_between_c_d_only_captures() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;C\x07first\r\n\x1b]133;D;0;history_id=one\x07\x1b]133;C\x07second\r\n\x1b]133;D;1;history_id=two\x07", + |capture| captures.push(capture), + ); + + assert_eq!(captures.len(), 2); + assert_eq!(captures[0].output, "first"); + assert_eq!(captures[0].history_id.as_deref(), Some("one")); + assert_eq!(captures[1].output, "second"); + assert_eq!(captures[1].history_id.as_deref(), Some("two")); + } +} diff --git a/crates/atuin-pty-proxy/src/debug.rs b/crates/atuin-pty-proxy/src/debug.rs new file mode 100644 index 00000000..806bde90 --- /dev/null +++ b/crates/atuin-pty-proxy/src/debug.rs @@ -0,0 +1,53 @@ +use crate::osc133::{Event, Parser}; + +pub(crate) const RESET: &[u8] = b"\x1b[0m"; + +pub(crate) struct Osc133DebugHighlighter { + parser: Parser, +} + +impl Osc133DebugHighlighter { + pub(crate) fn new() -> Self { + Self { + parser: Parser::new(), + } + } + + pub(crate) fn render(&mut self, data: &[u8]) -> Vec<u8> { + let mut events = Vec::new(); + self.parser + .push_located(data, |located| events.push(located)); + + if events.is_empty() { + return data.to_vec(); + } + + let mut rendered = Vec::with_capacity(data.len() + (events.len() * 64)); + let mut start = 0; + + for located in events { + let offset = located.offset.min(data.len()); + if offset > start { + rendered.extend_from_slice(&data[start..offset]); + } + + rendered.extend_from_slice(event_label(&located.event)); + rendered.extend_from_slice(RESET); + start = offset; + } + + rendered.extend_from_slice(&data[start..]); + rendered + } +} + +fn 
event_label(event: &Event) -> &'static [u8] { + match event { + Event::PromptStart => b"\x1b[1;37;45m[OSC133:A prompt]\x1b[0m", + Event::CommandStart => b"\x1b[1;30;43m[OSC133:B input]\x1b[0m", + Event::CommandExecuted => b"\x1b[1;30;46m[OSC133:C output]\x1b[0m", + Event::CommandFinished { exit_code: Some(0) } => b"\x1b[1;37;42m[OSC133:D exit=0]\x1b[0m", + Event::CommandFinished { exit_code: Some(_) } => b"\x1b[1;37;41m[OSC133:D exit!=0]\x1b[0m", + Event::CommandFinished { exit_code: None } => b"\x1b[1;37;44m[OSC133:D exit=?]\x1b[0m", + } +} diff --git a/crates/atuin-pty-proxy/src/lib.rs b/crates/atuin-pty-proxy/src/lib.rs index 16b29dff..65b03df3 100644 --- a/crates/atuin-pty-proxy/src/lib.rs +++ b/crates/atuin-pty-proxy/src/lib.rs @@ -1,478 +1,48 @@ -pub mod osc133; - -use clap::{Args, Subcommand, ValueEnum}; - -#[derive(Subcommand, Debug)] -pub enum Cmd { - /// Print shell code to initialize atuin pty-proxy on shell startup - Init(Init), -} - -#[derive(Args, Debug)] -pub struct Init { - /// Shell to generate init for. 
If omitted, attempt auto-detection - #[arg(value_enum)] - shell: Option<Shell>, -} - -#[derive(Clone, Copy, Debug, Eq, PartialEq, ValueEnum)] -#[value(rename_all = "lower")] -#[allow(clippy::enum_variant_names, clippy::doc_markdown)] -enum Shell { - /// Zsh setup - Zsh, - /// Bash setup - Bash, - /// Fish setup - Fish, - /// Nu setup - Nu, -} - -impl Init { - fn run(self) -> Result<(), String> { - let shell = detect_shell(self.shell)?; - let script = render_init(shell); - print!("{script}"); - Ok(()) - } -} - -pub fn run(cmd: Option<Cmd>) { - match cmd { - Some(Cmd::Init(init)) => { - if let Err(err) = init.run() { - eprintln!("atuin pty-proxy: {err}"); - std::process::exit(1); - } - } - None => app::main(), - } -} - -fn detect_shell(cli_shell: Option<Shell>) -> Result<Shell, String> { - if let Some(shell) = cli_shell { - return Ok(shell); - } - - if let Ok(shell) = std::env::var("ATUIN_SHELL") - && let Some(shell) = shell_from_name(&shell) - { - return Ok(shell); - } - - if let Ok(shell) = std::env::var("SHELL") - && let Some(shell) = shell_from_name(&shell) - { - return Ok(shell); - } - - Err( - "could not detect a supported shell. 
Please specify one explicitly: bash, zsh, fish, or nu" - .to_string(), - ) -} - -fn shell_from_name(name: &str) -> Option<Shell> { - let shell = name - .trim() - .rsplit('/') - .next() - .unwrap_or(name) - .trim_start_matches('-') - .to_ascii_lowercase(); - - match shell.as_str() { - "bash" => Some(Shell::Bash), - "zsh" => Some(Shell::Zsh), - "fish" => Some(Shell::Fish), - "nu" => Some(Shell::Nu), - _ => None, - } -} - -fn render_init(shell: Shell) -> &'static str { - match shell { - Shell::Bash | Shell::Zsh => { - r#"if [[ "$-" == *i* ]] && [[ -t 0 ]] && [[ -t 1 ]]; then - _atuin_pty_proxy_tmux_current="${TMUX:-}" - _atuin_pty_proxy_tmux_previous="${ATUIN_PTY_PROXY_TMUX:-${ATUIN_HEX_TMUX:-}}" - - if [[ -z "${ATUIN_PTY_PROXY_ACTIVE:-${ATUIN_HEX_ACTIVE:-}}" ]] || [[ "$_atuin_pty_proxy_tmux_current" != "$_atuin_pty_proxy_tmux_previous" ]]; then - export ATUIN_PTY_PROXY_ACTIVE=1 - export ATUIN_PTY_PROXY_TMUX="$_atuin_pty_proxy_tmux_current" - exec atuin pty-proxy - fi - - unset _atuin_pty_proxy_tmux_current _atuin_pty_proxy_tmux_previous -fi -"# - } - Shell::Fish => { - r#"if status is-interactive; and test -t 0; and test -t 1 - set -l _atuin_pty_proxy_tmux_current "" - if set -q TMUX - set _atuin_pty_proxy_tmux_current "$TMUX" - end - - set -l _atuin_pty_proxy_tmux_previous "" - if set -q ATUIN_PTY_PROXY_TMUX - set _atuin_pty_proxy_tmux_previous "$ATUIN_PTY_PROXY_TMUX" - else if set -q ATUIN_HEX_TMUX - set _atuin_pty_proxy_tmux_previous "$ATUIN_HEX_TMUX" - end - - if not set -q ATUIN_PTY_PROXY_ACTIVE; and not set -q ATUIN_HEX_ACTIVE - set -gx ATUIN_PTY_PROXY_ACTIVE 1 - set -gx ATUIN_PTY_PROXY_TMUX "$_atuin_pty_proxy_tmux_current" - exec atuin pty-proxy - else if test "$_atuin_pty_proxy_tmux_current" != "$_atuin_pty_proxy_tmux_previous" - set -gx ATUIN_PTY_PROXY_ACTIVE 1 - set -gx ATUIN_PTY_PROXY_TMUX "$_atuin_pty_proxy_tmux_current" - exec atuin pty-proxy - end -end -"# - } - // Nushell cannot dynamically source the output of `atuin init nu`, - // so we only output 
the pty-proxy preamble here. Users must also set up - // `atuin init nu` separately. - Shell::Nu => { - r#"if (is-terminal --stdin) and (is-terminal --stdout) { - let tmux_current = ($env.TMUX? | default "") - let tmux_previous = ($env.ATUIN_PTY_PROXY_TMUX? | default ($env.ATUIN_HEX_TMUX? | default "")) - - if (($env.ATUIN_PTY_PROXY_ACTIVE? | default ($env.ATUIN_HEX_ACTIVE? | default "")) | is-empty) or ($tmux_current != $tmux_previous) { - $env.ATUIN_PTY_PROXY_ACTIVE = "1" - $env.ATUIN_PTY_PROXY_TMUX = $tmux_current - exec atuin pty-proxy - } -} -"# - } - } -} - -#[cfg(not(unix))] -mod app { - pub(crate) fn main() { - eprintln!("atuin pty-proxy currently supports unix platforms"); - std::process::exit(1); - } -} - #[cfg(unix)] -mod app { - use std::io::{Read, Write}; - use std::os::unix::net::UnixListener; - use std::sync::mpsc; - - use crossterm::terminal; - use portable_pty::{CommandBuilder, PtySize, native_pty_system}; - - enum ParserMsg { - Data(Vec<u8>), - Resize { rows: u16, cols: u16 }, - ScreenRequest(mpsc::Sender<Vec<u8>>), - } - - pub(crate) fn main() { - if let Err(e) = run() { - let _ = terminal::disable_raw_mode(); - eprintln!("atuin pty-proxy: {e:#}"); - std::process::exit(1); - } - } - - fn socket_path() -> std::path::PathBuf { - let dir = std::env::temp_dir(); - dir.join(format!("atuin-pty-proxy-{}.sock", std::process::id())) - } - - /// Wire format written to the Unix socket: - /// - /// ```text - /// [rows: u16 BE][cols: u16 BE][cursor_row: u16 BE][cursor_col: u16 BE] - /// [row_0_len: u32 BE][row_0_bytes...] - /// [row_1_len: u32 BE][row_1_bytes...] - /// ... - /// ``` - /// - /// Each row's bytes come from `screen.rows_formatted(0, cols)` and contain - /// pre-built ANSI escape sequences. The client can write them directly to - /// stdout without needing its own vt100 parser. 
- fn encode_screen(parser: &vt100::Parser) -> Vec<u8> { - let screen = parser.screen(); - let (rows, cols) = screen.size(); - let (cursor_row, cursor_col) = screen.cursor_position(); - - let mut buf: Vec<u8> = Vec::with_capacity(256 + (rows as usize * cols as usize)); - buf.extend_from_slice(&rows.to_be_bytes()); - buf.extend_from_slice(&cols.to_be_bytes()); - buf.extend_from_slice(&cursor_row.to_be_bytes()); - buf.extend_from_slice(&cursor_col.to_be_bytes()); - - for row_bytes in screen.rows_formatted(0, cols) { - let len = row_bytes.len() as u32; - buf.extend_from_slice(&len.to_be_bytes()); - buf.extend_from_slice(&row_bytes); - } - - buf - } - - fn handle_parser_msg(parser: &mut vt100::Parser, msg: ParserMsg) { - match msg { - ParserMsg::Data(data) => parser.process(&data), - ParserMsg::Resize { rows, cols } => parser.screen_mut().set_size(rows, cols), - ParserMsg::ScreenRequest(reply_tx) => { - let _ = reply_tx.send(encode_screen(parser)); - } - } - } - - fn run() -> eyre::Result<()> { - let (cols, rows) = terminal::size()?; - - let pty_system = native_pty_system(); - let pair = pty_system - .openpty(PtySize { - rows, - cols, - pixel_width: 0, - pixel_height: 0, - }) - .map_err(|e| eyre::eyre!("{e:#}"))?; - - // Set up socket path and expose it to child processes - let sock_path = socket_path(); - // Clean up any stale socket from a previous crash - let _ = std::fs::remove_file(&sock_path); - - let mut cmd = CommandBuilder::new_default_prog(); - cmd.cwd(std::env::current_dir()?); - cmd.env("ATUIN_PTY_PROXY_SOCKET", sock_path.as_os_str()); - cmd.env("ATUIN_HEX_SOCKET", sock_path.as_os_str()); - - let mut child = pair - .slave - .spawn_command(cmd) - .map_err(|e| eyre::eyre!("{e:#}"))?; - - // Close slave side in parent process - drop(pair.slave); - - let mut pty_reader = pair - .master - .try_clone_reader() - .map_err(|e| eyre::eyre!("{e:#}"))?; - let mut pty_writer = pair - .master - .take_writer() - .map_err(|e| eyre::eyre!("{e:#}"))?; - - // Channel: 
stdout/sigwinch/socket threads -> parser thread (bounded, non-blocking send) - let (msg_tx, msg_rx) = mpsc::sync_channel::<ParserMsg>(64); - - // --- Parser thread --- - // Maintains a persistent vt100::Parser fed bytes as they arrive. - // On screen request: reads current state directly (no replay). - std::thread::spawn(move || { - let mut parser = vt100::Parser::new(rows, cols, 0); - - loop { - // Block until at least one message arrives - let first = match msg_rx.recv() { - Ok(msg) => msg, - Err(_) => break, - }; - - handle_parser_msg(&mut parser, first); - - // Drain all remaining pending messages so the parser stays - // caught up during high-throughput bursts (e.g. `cat bigfile`). - // The channel holds at most 64 items, so this is bounded. - while let Ok(msg) = msg_rx.try_recv() { - handle_parser_msg(&mut parser, msg); - } - } - }); - - // --- Socket server thread --- - // Listens on Unix socket; on connection, requests screen state from parser thread. - { - let sock_path_clone = sock_path.clone(); - let screen_tx = msg_tx.clone(); - std::thread::spawn(move || { - let listener = match UnixListener::bind(&sock_path_clone) { - Ok(l) => l, - Err(e) => { - eprintln!("atuin pty-proxy: failed to bind socket: {e}"); - return; - } - }; - - for stream in listener.incoming() { - let mut stream = match stream { - Ok(s) => s, - Err(_) => break, - }; - - let (reply_tx, reply_rx) = mpsc::channel(); - if screen_tx.send(ParserMsg::ScreenRequest(reply_tx)).is_err() { - break; - } - if let Ok(data) = reply_rx.recv() { - let _ = stream.write_all(&data); - let _ = stream.flush(); - } - } - }); - } - - // Handle terminal resize via SIGWINCH - { - use signal_hook::consts::SIGWINCH; - use signal_hook::iterator::Signals; - - let master = pair.master; - let resize_tx = msg_tx.clone(); - let mut signals = Signals::new([SIGWINCH])?; - - std::thread::spawn(move || { - for _ in signals.forever() { - if let Ok((cols, rows)) = terminal::size() { - let _ = master.resize(PtySize { - rows, - 
cols, - pixel_width: 0, - pixel_height: 0, - }); - let _ = resize_tx.try_send(ParserMsg::Resize { rows, cols }); - } - } - }); - } - - terminal::enable_raw_mode()?; - - // PTY -> stdout (with OSC 133 parsing + buffer feed) - let stdout_thread = std::thread::spawn(move || { - let mut stdout = std::io::stdout(); - let mut parser = crate::osc133::Parser::new(); - let mut buf = [0u8; 8192]; - loop { - match pty_reader.read(&mut buf) { - Ok(0) | Err(_) => break, - Ok(n) => { - parser.push(&buf[..n], |_event| { - // Zone transitions are tracked inside the parser. - // Callers can query parser.zone() after push. - }); - - // Feed bytes to the shadow parser. Drops on backpressure — - // the screen snapshot may be stale during bursts, but - // self-corrects once output settles. - let _ = msg_tx.try_send(ParserMsg::Data(buf[..n].to_vec())); - - if stdout.write_all(&buf[..n]).is_err() { - break; - } - let _ = stdout.flush(); - } - } - } - }); - - // stdin -> PTY - std::thread::spawn(move || { - let mut stdin = std::io::stdin(); - let mut buf = [0u8; 8192]; - loop { - match stdin.read(&mut buf) { - Ok(0) | Err(_) => break, - Ok(n) => { - if pty_writer.write_all(&buf[..n]).is_err() { - break; - } - } - } - } - }); +mod capture; +#[cfg(unix)] +mod debug; +#[cfg(unix)] +mod osc133; +#[cfg(unix)] +mod pty_proxy; +#[cfg(unix)] +mod runtime; +#[cfg(unix)] +mod screen; - let status = child.wait()?; - let _ = stdout_thread.join(); +#[cfg(unix)] +pub use capture::{CommandCapture, CommandCaptureSink}; +#[cfg(unix)] +pub use pty_proxy::PtyProxy; - let _ = terminal::disable_raw_mode(); +#[cfg(not(unix))] +#[allow(dead_code)] +mod unsupported { + use clap::{Args, Subcommand}; - // Clean up socket file - let _ = std::fs::remove_file(&sock_path); + #[derive(Args, Debug)] + pub struct PtyProxy { + /// Highlight OSC 133 prompt, input, output, and exit-code regions + #[arg(long)] + debug_osc133: bool, - std::process::exit(process_exit_code(status.exit_code())); + #[command(subcommand)] + cmd: 
Option<Cmd>, } - fn process_exit_code(code: u32) -> i32 { - i32::try_from(code).unwrap_or(1) + #[derive(Subcommand, Debug)] + enum Cmd { + /// Print shell code to initialize atuin pty-proxy on shell startup + Init(Init), } - #[cfg(test)] - mod tests { - use super::process_exit_code; - - #[test] - fn process_exit_code_preserves_valid_values() { - assert_eq!(process_exit_code(0), 0); - assert_eq!(process_exit_code(127), 127); - assert_eq!(process_exit_code(i32::MAX as u32), i32::MAX); - } - - #[test] - fn process_exit_code_defaults_when_out_of_range() { - assert_eq!(process_exit_code(i32::MAX as u32 + 1), 1); - } + #[derive(Args, Debug)] + struct Init { + /// Shell to generate init for. If omitted, attempt auto-detection + shell: Option<String>, } } -#[cfg(test)] -mod tests { - use super::{Shell, render_init, shell_from_name}; - - #[test] - fn shell_from_name_handles_paths() { - assert_eq!(shell_from_name("/bin/zsh"), Some(Shell::Zsh)); - assert_eq!(shell_from_name("/usr/local/bin/bash"), Some(Shell::Bash)); - assert_eq!(shell_from_name("fish"), Some(Shell::Fish)); - assert_eq!(shell_from_name("nu"), Some(Shell::Nu)); - } - - #[test] - fn posix_init_uses_exec_and_tmux_guard() { - let script = render_init(Shell::Bash); - assert!(script.contains("exec atuin pty-proxy")); - assert!(script.contains("ATUIN_PTY_PROXY_TMUX")); - assert!(!script.contains("eval \"$(atuin init bash)\"")); - } - - #[test] - fn posix_init_has_no_double_braces() { - let script = render_init(Shell::Bash); - assert!(!script.contains("${{"), "double braces in bash init script"); - } - - #[test] - fn fish_init_uses_source() { - let script = render_init(Shell::Fish); - assert!(script.contains("exec atuin pty-proxy")); - assert!(!script.contains("atuin init fish | source")); - } - - #[test] - fn nu_init_uses_exec_and_tty_guard() { - let script = render_init(Shell::Nu); - assert!(script.contains("exec atuin pty-proxy")); - assert!(script.contains("ATUIN_PTY_PROXY_TMUX")); - 
assert!(script.contains("is-terminal --stdin")); - assert!(script.contains("is-terminal --stdout")); - assert!(script.contains("ATUIN_PTY_PROXY_ACTIVE")); - } -} +#[cfg(not(unix))] +pub use unsupported::PtyProxy; diff --git a/crates/atuin-pty-proxy/src/osc133.rs b/crates/atuin-pty-proxy/src/osc133.rs index d6ee1220..51fda848 100644 --- a/crates/atuin-pty-proxy/src/osc133.rs +++ b/crates/atuin-pty-proxy/src/osc133.rs @@ -9,18 +9,19 @@ //! | C | Command submitted — output begins | //! | D[;n] | Command finished with exit code *n* | //! -//! The wire format is `ESC ] 133 ; <cmd> [; <params>] ST` where ST is either -//! BEL (0x07) or ESC \ (0x1B 0x5C). +//! The wire format is `ESC ] 133 ; <cmd> [; <params>] ST` where ST is BEL +//! (0x07), ESC \ (0x1B 0x5C), or C1 ST (0x9C). //! //! # Design goals //! -//! * **Zero-copy** — the parser observes the byte stream without buffering or -//! modifying it. -//! * **Zero-alloc** — after construction no heap allocation occurs. +//! * **Transparent** — the parser observes the byte stream without modifying it; +//! the caller remains responsible for forwarding bytes to their destination. +//! * **Bounded** — OSC parameter buffering is capped so malformed output cannot +//! grow memory without limit. //! * **Non-blocking** — [`Parser::push`] processes whatever bytes are available //! and returns immediately. -//! * **Transparent** — the caller is responsible for forwarding bytes to their -//! destination; the parser only emits [`Event`]s through a callback. +//! * **Extensible** — marker parameters are preserved so Atuin-specific metadata +//! can ride alongside standard OSC 133 markers. /// Events emitted when an OSC 133 marker is detected. #[derive(Debug, Clone, PartialEq, Eq)] @@ -38,6 +39,63 @@ pub enum Event { }, } +/// Parameters attached to an OSC 133 marker. +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub struct Params { + items: Vec<Param>, +} + +impl Params { + /// Iterate over all marker parameters in order. 
+ #[cfg(test)] + #[inline] + pub fn iter(&self) -> impl Iterator<Item = &Param> { + self.items.iter() + } + + /// Return the value for the first `key=value` parameter with this key. + #[inline] + pub fn get(&self, key: &str) -> Option<&str> { + self.items.iter().find_map(|item| match item { + Param::KeyValue { + key: item_key, + value, + } if item_key == key => Some(value.as_str()), + Param::Value(_) | Param::KeyValue { .. } => None, + }) + } +} + +/// A single OSC 133 marker parameter. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Param { + /// A positional parameter without an equals sign. + Value(String), + /// A `key=value` parameter. + KeyValue { key: String, value: String }, +} + +/// An OSC 133 event with its position in the most recent input chunk. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LocatedEvent { + /// The OSC 133 event that was parsed. + pub event: Event, + /// Offset where this marker starts in the current chunk. + /// + /// If a marker started in an earlier [`Parser::push_located`] call, this is + /// `0` in the chunk that completed the marker. + pub start_offset: usize, + /// Offset immediately after this marker's terminator in the current chunk. + /// + /// If a marker spans multiple [`Parser::push_located`] calls, this is still + /// the offset in the chunk that completed the marker. + pub offset: usize, + /// The semantic zone after applying this event. + pub zone: Zone, + /// Metadata parameters attached to this marker. + pub params: Params, +} + /// The current semantic zone as determined by the most recent OSC 133 marker. #[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] #[allow(dead_code)] @@ -59,14 +117,14 @@ pub enum Zone { const ESC: u8 = 0x1B; const BEL: u8 = 0x07; +const C1_ST: u8 = 0x9C; const BACKSLASH: u8 = b'\\'; const RIGHT_BRACKET: u8 = b']'; -/// Maximum bytes we'll buffer for the OSC parameter string. 32 bytes is far -/// more than any valid OSC 133 payload needs (e.g. `133;D;127` is 9 bytes). 
-/// Longer (non-133) OSC sequences simply stop accumulating once the buffer is -/// full — the dispatch logic will harmlessly ignore them. -const PARAM_BUF_CAP: usize = 32; +/// Maximum bytes we'll buffer for the OSC parameter string. This is large enough +/// for Atuin metadata such as history/session IDs while still bounding malformed +/// OSC sequences. +const PARAM_BUF_CAP: usize = 512; // --------------------------------------------------------------------------- // State machine @@ -94,6 +152,7 @@ enum State { pub struct Parser { state: State, zone: Zone, + sequence_start: Option<usize>, param_buf: [u8; PARAM_BUF_CAP], param_len: usize, } @@ -111,6 +170,7 @@ impl Parser { Self { state: State::Ground, zone: Zone::Unknown, + sequence_start: None, param_buf: [0u8; PARAM_BUF_CAP], param_len: 0, } @@ -123,18 +183,40 @@ impl Parser { self.zone } + /// Start offset of an incomplete OSC sequence in the most recent chunk. + #[inline] + pub(crate) fn incomplete_osc_sequence_start(&self) -> Option<usize> { + matches!(self.state, State::OscParam | State::OscEsc) + .then(|| self.sequence_start.unwrap_or(0)) + } + /// Process a chunk of bytes, calling `on_event` for every OSC 133 marker /// found. /// /// All bytes in `data` should still be forwarded to the terminal by the /// caller — this method only *observes* the stream. + #[cfg(test)] #[inline] pub fn push(&mut self, data: &[u8], mut on_event: impl FnMut(Event)) { - for &byte in data { + self.push_located(data, |located| on_event(located.event)); + } + + /// Process a chunk of bytes, calling `on_event` for every OSC 133 marker + /// found with its byte offset in this chunk. + /// + /// The offset points to the first byte after the marker terminator, making + /// it suitable for callers that need to split the original chunk at marker + /// boundaries. 
+ #[inline] + pub fn push_located(&mut self, data: &[u8], mut on_event: impl FnMut(LocatedEvent)) { + self.sequence_start = (self.state != State::Ground).then_some(0); + + for (offset, &byte) in data.iter().enumerate() { match self.state { State::Ground => { if byte == ESC { self.state = State::Esc; + self.sequence_start = Some(offset); } } State::Esc => { @@ -143,12 +225,14 @@ impl Parser { self.param_len = 0; } else { self.state = State::Ground; + self.sequence_start = None; } } State::OscParam => { - if byte == BEL { - self.dispatch(&mut on_event); + if byte == BEL || byte == C1_ST { + self.dispatch(offset + 1, &mut on_event); self.state = State::Ground; + self.sequence_start = None; } else if byte == ESC { self.state = State::OscEsc; } else if self.param_len < PARAM_BUF_CAP { @@ -160,12 +244,13 @@ impl Parser { } State::OscEsc => { if byte == BACKSLASH { - self.dispatch(&mut on_event); + self.dispatch(offset + 1, &mut on_event); } // Whether we got a valid ST or not, return to ground. // (A new ESC ] would restart accumulation via the Ground // -> Esc -> OscParam path on the *next* byte.) self.state = State::Ground; + self.sequence_start = None; } } } @@ -174,46 +259,104 @@ impl Parser { /// Inspect the accumulated parameter buffer. If it holds an OSC 133 /// payload, emit the corresponding [`Event`] and update the zone. 
#[inline] - fn dispatch(&mut self, on_event: &mut impl FnMut(Event)) { - let params = &self.param_buf[..self.param_len]; + fn dispatch(&mut self, offset: usize, on_event: &mut impl FnMut(LocatedEvent)) { + let payload = &self.param_buf[..self.param_len]; + + if payload.len() < 5 || &payload[..4] != b"133;" { + return; + } - // Must start with "133;" - if params.len() < 5 || ¶ms[..4] != b"133;" { + if payload.len() > 5 && payload[5] != b';' { return; } - let cmd = params[4]; - let event = match cmd { + let metadata = payload.get(6..).unwrap_or_default(); + let cmd = payload[4]; + let (event, params) = match cmd { b'A' => { self.zone = Zone::Prompt; - Event::PromptStart + (Event::PromptStart, parse_params(metadata)) } b'B' => { self.zone = Zone::Input; - Event::CommandStart + (Event::CommandStart, parse_params(metadata)) } b'C' => { self.zone = Zone::Output; - Event::CommandExecuted + (Event::CommandExecuted, parse_params(metadata)) } b'D' => { - let exit_code = if params.len() > 6 && params[5] == b';' { - std::str::from_utf8(¶ms[6..]) - .ok() - .and_then(|s| s.parse::<i32>().ok()) - } else { - None - }; + let (exit_code, params) = parse_command_finished_params(metadata); self.zone = Zone::Unknown; - Event::CommandFinished { exit_code } + (Event::CommandFinished { exit_code }, params) } _ => return, }; - on_event(event); + on_event(LocatedEvent { + event, + start_offset: self.sequence_start.unwrap_or(0), + offset, + zone: self.zone, + params, + }); } } +fn parse_command_finished_params(metadata: &[u8]) -> (Option<i32>, Params) { + if metadata.is_empty() { + return (None, Params::default()); + } + + let Some(separator) = metadata.iter().position(|byte| *byte == b';') else { + return parse_exit_code(metadata).map_or_else( + || (None, parse_params(metadata)), + |exit_code| (Some(exit_code), Params::default()), + ); + }; + + let (first, rest) = metadata.split_at(separator); + let rest = &rest[1..]; + + parse_exit_code(first).map_or_else( + || (None, 
parse_params(metadata)), + |exit_code| (Some(exit_code), parse_params(rest)), + ) +} + +fn parse_exit_code(code: &[u8]) -> Option<i32> { + if code.is_empty() { + return None; + } + + std::str::from_utf8(code) + .ok() + .and_then(|code| code.parse::<i32>().ok()) +} + +fn parse_params(metadata: &[u8]) -> Params { + let items = metadata + .split(|byte| *byte == b';') + .filter(|part| !part.is_empty()) + .map(parse_param) + .collect(); + + Params { items } +} + +fn parse_param(param: &[u8]) -> Param { + let param = String::from_utf8_lossy(param); + + if let Some((key, value)) = param.split_once('=') { + return Param::KeyValue { + key: key.to_string(), + value: value.to_string(), + }; + } + + Param::Value(param.into_owned()) +} + // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- @@ -468,6 +611,12 @@ mod tests { assert!(parse_events(data).is_empty()); } + #[test] + fn marker_with_unexpected_trailing_bytes_ignored() { + let data = b"\x1b]133;ABC\x07"; + assert!(parse_events(data).is_empty()); + } + // -- Malformed sequences -------------------------------------------------- #[test] @@ -509,7 +658,7 @@ mod tests { fn very_long_osc_does_not_panic() { let mut data = Vec::new(); data.extend_from_slice(b"\x1b]"); - data.extend(std::iter::repeat(b'x').take(1000)); + data.extend(std::iter::repeat_n(b'x', 1000)); data.push(BEL); // Should not panic and should produce no event. 
assert!(parse_events(&data).is_empty()); @@ -589,6 +738,100 @@ mod tests { ); } + #[test] + fn detects_c1_st_terminator() { + let data = b"\x1b]133;A\x9c"; + assert_eq!(parse_events(data), vec![Event::PromptStart]); + } + + // -- Located event offsets ------------------------------------------------ + + #[test] + fn located_event_reports_offset_after_marker() { + let data = b"before\x1b]133;A\x07prompt"; + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push_located(data, |e| events.push(e)); + + assert_eq!( + events, + vec![LocatedEvent { + event: Event::PromptStart, + start_offset: b"before".len(), + offset: b"before\x1b]133;A\x07".len(), + zone: Zone::Prompt, + params: Params::default(), + }] + ); + } + + #[test] + fn located_event_offset_is_relative_to_completing_chunk() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push_located(b"\x1b]133;", |e| events.push(e)); + parser.push_located(b"D;42\x07after", |e| events.push(e)); + + assert_eq!( + events, + vec![LocatedEvent { + event: Event::CommandFinished { + exit_code: Some(42) + }, + start_offset: 0, + offset: b"D;42\x07".len(), + zone: Zone::Unknown, + params: Params::default(), + }] + ); + } + + #[test] + fn located_event_preserves_metadata_params() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push_located( + b"\x1b]133;D;127;history_id=018f;session_id=abcd;flag\x07", + |event| events.push(event), + ); + + assert_eq!(events.len(), 1); + let event = &events[0]; + assert_eq!( + event.event, + Event::CommandFinished { + exit_code: Some(127) + } + ); + assert_eq!(event.params.get("history_id"), Some("018f")); + assert_eq!(event.params.get("session_id"), Some("abcd")); + assert!( + event + .params + .iter() + .any(|param| param == &Param::Value("flag".to_string())) + ); + } + + #[test] + fn command_finished_metadata_without_exit_code_is_preserved() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + 
parser.push_located(b"\x1b]133;D;history_id=018f;session_id=abcd\x07", |event| { + events.push(event); + }); + + assert_eq!(events.len(), 1); + let event = &events[0]; + assert_eq!(event.event, Event::CommandFinished { exit_code: None }); + assert_eq!(event.params.get("history_id"), Some("018f")); + assert_eq!(event.params.get("session_id"), Some("abcd")); + } + // -- Default trait -------------------------------------------------------- #[test] diff --git a/crates/atuin-pty-proxy/src/pty_proxy.rs b/crates/atuin-pty-proxy/src/pty_proxy.rs new file mode 100644 index 00000000..030ef9b5 --- /dev/null +++ b/crates/atuin-pty-proxy/src/pty_proxy.rs @@ -0,0 +1,231 @@ +use clap::{Args, Subcommand, ValueEnum}; + +use crate::{CommandCaptureSink, runtime}; + +#[derive(Args, Debug)] +pub struct PtyProxy { + /// Highlight OSC 133 prompt, input, output, and exit-code regions + #[arg(long)] + debug_osc133: bool, + + #[command(subcommand)] + cmd: Option<Cmd>, +} + +#[derive(Subcommand, Debug)] +pub enum Cmd { + /// Print shell code to initialize atuin pty-proxy on shell startup + Init(Init), +} + +#[derive(Args, Debug)] +pub struct Init { + /// Shell to generate init for. 
If omitted, attempt auto-detection + #[arg(value_enum)] + shell: Option<Shell>, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq, ValueEnum)] +#[value(rename_all = "lower")] +#[allow(clippy::enum_variant_names, clippy::doc_markdown)] +enum Shell { + /// Zsh setup + Zsh, + /// Bash setup + Bash, + /// Fish setup + Fish, + /// Nu setup + Nu, +} + +pub(crate) struct RuntimeOptions { + pub(crate) debug_osc133: bool, + pub(crate) command_capture_sink: Option<CommandCaptureSink>, +} + +impl RuntimeOptions { + fn new(debug_osc133: bool, command_capture_sink: Option<CommandCaptureSink>) -> Self { + Self { + debug_osc133: debug_osc133 || env_flag("ATUIN_PTY_PROXY_DEBUG"), + command_capture_sink, + } + } +} + +impl PtyProxy { + pub fn run(self, command_capture_sink: Option<CommandCaptureSink>) { + match self.cmd { + Some(Cmd::Init(init)) => { + if let Err(err) = init.run() { + eprintln!("atuin pty-proxy: {err}"); + std::process::exit(1); + } + } + None => runtime::main(RuntimeOptions::new(self.debug_osc133, command_capture_sink)), + } + } +} + +impl Init { + fn run(self) -> Result<(), String> { + let shell = detect_shell(self.shell)?; + let script = render_init(shell); + print!("{script}"); + Ok(()) + } +} + +fn detect_shell(cli_shell: Option<Shell>) -> Result<Shell, String> { + if let Some(shell) = cli_shell { + return Ok(shell); + } + + if let Ok(shell) = std::env::var("ATUIN_SHELL") + && let Some(shell) = shell_from_name(&shell) + { + return Ok(shell); + } + + if let Ok(shell) = std::env::var("SHELL") + && let Some(shell) = shell_from_name(&shell) + { + return Ok(shell); + } + + Err( + "could not detect a supported shell. 
Please specify one explicitly: bash, zsh, fish, or nu" + .to_string(), + ) +} + +fn shell_from_name(name: &str) -> Option<Shell> { + let shell = name + .trim() + .rsplit('/') + .next() + .unwrap_or(name) + .trim_start_matches('-') + .to_ascii_lowercase(); + + match shell.as_str() { + "bash" => Some(Shell::Bash), + "zsh" => Some(Shell::Zsh), + "fish" => Some(Shell::Fish), + "nu" => Some(Shell::Nu), + _ => None, + } +} + +fn env_flag(name: &str) -> bool { + std::env::var(name).is_ok_and(|value| { + matches!( + value.trim().to_ascii_lowercase().as_str(), + "1" | "true" | "yes" | "on" + ) + }) +} + +fn render_init(shell: Shell) -> &'static str { + match shell { + Shell::Bash | Shell::Zsh => { + r#"if [[ "$-" == *i* ]] && [[ -t 0 ]] && [[ -t 1 ]]; then + _atuin_pty_proxy_tmux_current="${TMUX:-}" + _atuin_pty_proxy_tmux_previous="${ATUIN_PTY_PROXY_TMUX:-}" + + if [[ -z "${ATUIN_PTY_PROXY_ACTIVE:-}" ]] || [[ "$_atuin_pty_proxy_tmux_current" != "$_atuin_pty_proxy_tmux_previous" ]]; then + export ATUIN_PTY_PROXY_ACTIVE=1 + export ATUIN_PTY_PROXY_TMUX="$_atuin_pty_proxy_tmux_current" + exec atuin pty-proxy + fi + + unset _atuin_pty_proxy_tmux_current _atuin_pty_proxy_tmux_previous +fi +"# + } + Shell::Fish => { + r#"if status is-interactive; and test -t 0; and test -t 1 + set -l _atuin_pty_proxy_tmux_current "" + if set -q TMUX + set _atuin_pty_proxy_tmux_current "$TMUX" + end + + set -l _atuin_pty_proxy_tmux_previous "" + if set -q ATUIN_PTY_PROXY_TMUX + set _atuin_pty_proxy_tmux_previous "$ATUIN_PTY_PROXY_TMUX" + end + + if not set -q ATUIN_PTY_PROXY_ACTIVE + set -gx ATUIN_PTY_PROXY_ACTIVE 1 + set -gx ATUIN_PTY_PROXY_TMUX "$_atuin_pty_proxy_tmux_current" + exec atuin pty-proxy + else if test "$_atuin_pty_proxy_tmux_current" != "$_atuin_pty_proxy_tmux_previous" + set -gx ATUIN_PTY_PROXY_ACTIVE 1 + set -gx ATUIN_PTY_PROXY_TMUX "$_atuin_pty_proxy_tmux_current" + exec atuin pty-proxy + end +end +"# + } + // Nushell cannot dynamically source the output of `atuin init nu`, + // 
so we only output the pty-proxy preamble here. Users must also set up + // `atuin init nu` separately. + Shell::Nu => { + r#"if (is-terminal --stdin) and (is-terminal --stdout) { + let tmux_current = ($env.TMUX? | default "") + let tmux_previous = ($env.ATUIN_PTY_PROXY_TMUX? | default "") + + if (($env.ATUIN_PTY_PROXY_ACTIVE? | default "") | is-empty) or ($tmux_current != $tmux_previous) { + $env.ATUIN_PTY_PROXY_ACTIVE = "1" + $env.ATUIN_PTY_PROXY_TMUX = $tmux_current + exec atuin pty-proxy + } +} +"# + } + } +} + +#[cfg(test)] +mod tests { + use super::{Shell, render_init, shell_from_name}; + + #[test] + fn shell_from_name_handles_paths() { + assert_eq!(shell_from_name("/bin/zsh"), Some(Shell::Zsh)); + assert_eq!(shell_from_name("/usr/local/bin/bash"), Some(Shell::Bash)); + assert_eq!(shell_from_name("fish"), Some(Shell::Fish)); + assert_eq!(shell_from_name("nu"), Some(Shell::Nu)); + } + + #[test] + fn posix_init_uses_exec_and_tmux_guard() { + let script = render_init(Shell::Bash); + assert!(script.contains("exec atuin pty-proxy")); + assert!(script.contains("ATUIN_PTY_PROXY_TMUX")); + assert!(!script.contains("eval \"$(atuin init bash)\"")); + } + + #[test] + fn posix_init_has_no_double_braces() { + let script = render_init(Shell::Bash); + assert!(!script.contains("${{"), "double braces in bash init script"); + } + + #[test] + fn fish_init_uses_source() { + let script = render_init(Shell::Fish); + assert!(script.contains("exec atuin pty-proxy")); + assert!(!script.contains("atuin init fish | source")); + } + + #[test] + fn nu_init_uses_exec_and_tty_guard() { + let script = render_init(Shell::Nu); + assert!(script.contains("exec atuin pty-proxy")); + assert!(script.contains("ATUIN_PTY_PROXY_TMUX")); + assert!(script.contains("is-terminal --stdin")); + assert!(script.contains("is-terminal --stdout")); + assert!(script.contains("ATUIN_PTY_PROXY_ACTIVE")); + } +} diff --git a/crates/atuin-pty-proxy/src/runtime.rs b/crates/atuin-pty-proxy/src/runtime.rs new file mode 
100644 index 00000000..2b34fbb7 --- /dev/null +++ b/crates/atuin-pty-proxy/src/runtime.rs @@ -0,0 +1,184 @@ +use std::io::{Read, Write}; +use std::sync::Arc; +use std::sync::atomic::{AtomicU16, Ordering}; +use std::sync::mpsc; + +use crossterm::terminal; +use portable_pty::{CommandBuilder, PtySize, native_pty_system}; + +use crate::capture::CommandCaptureTracker; +use crate::debug::{Osc133DebugHighlighter, RESET}; +use crate::pty_proxy::RuntimeOptions; +use crate::screen::{self, Msg}; + +pub(crate) fn main(options: RuntimeOptions) { + if let Err(e) = run(options) { + let _ = terminal::disable_raw_mode(); + eprintln!("atuin pty-proxy: {e:#}"); + std::process::exit(1); + } +} + +fn run(options: RuntimeOptions) -> eyre::Result<()> { + let (cols, rows) = terminal::size()?; + + let pty_system = native_pty_system(); + let pair = pty_system + .openpty(PtySize { + rows, + cols, + pixel_width: 0, + pixel_height: 0, + }) + .map_err(|e| eyre::eyre!("{e:#}"))?; + + let sock_path = screen::socket_path(); + let _ = std::fs::remove_file(&sock_path); + + let mut cmd = CommandBuilder::new_default_prog(); + cmd.cwd(std::env::current_dir()?); + cmd.env("ATUIN_PTY_PROXY_SOCKET", sock_path.as_os_str()); + cmd.env("ATUIN_PTY_PROXY_ACTIVE", "1"); + + let mut child = pair + .slave + .spawn_command(cmd) + .map_err(|e| eyre::eyre!("{e:#}"))?; + + drop(pair.slave); + + let mut pty_reader = pair + .master + .try_clone_reader() + .map_err(|e| eyre::eyre!("{e:#}"))?; + let mut pty_writer = pair + .master + .take_writer() + .map_err(|e| eyre::eyre!("{e:#}"))?; + + let (msg_tx, msg_rx) = mpsc::sync_channel::<Msg>(64); + let current_cols = Arc::new(AtomicU16::new(cols.max(1))); + + screen::spawn_parser_thread(rows, cols, msg_rx); + screen::spawn_socket_server(sock_path.clone(), msg_tx.clone()); + spawn_resize_handler(pair.master, msg_tx.clone(), current_cols.clone())?; + + terminal::enable_raw_mode()?; + + let stdout_thread = std::thread::spawn(move || { + let mut stdout = std::io::stdout(); + let 
mut highlighter = options.debug_osc133.then(Osc133DebugHighlighter::new); + let mut capture_tracker = options + .command_capture_sink + .as_ref() + .map(|_| CommandCaptureTracker::new(current_cols)); + let mut buf = [0u8; 8192]; + + loop { + match pty_reader.read(&mut buf) { + Ok(0) | Err(_) => break, + Ok(n) => { + if let (Some(tracker), Some(sink)) = ( + capture_tracker.as_mut(), + options.command_capture_sink.as_ref(), + ) { + tracker.push(&buf[..n], sink); + } + + if let Some(highlighter) = highlighter.as_mut() { + let rendered = highlighter.render(&buf[..n]); + let _ = msg_tx.try_send(Msg::Data(rendered.clone())); + + if stdout.write_all(&rendered).is_err() { + break; + } + } else { + let _ = msg_tx.try_send(Msg::Data(buf[..n].to_vec())); + + if stdout.write_all(&buf[..n]).is_err() { + break; + } + } + let _ = stdout.flush(); + } + } + } + + if highlighter.is_some() { + let _ = stdout.write_all(RESET); + let _ = stdout.flush(); + } + }); + + std::thread::spawn(move || { + let mut stdin = std::io::stdin(); + let mut buf = [0u8; 8192]; + loop { + match stdin.read(&mut buf) { + Ok(0) | Err(_) => break, + Ok(n) => { + if pty_writer.write_all(&buf[..n]).is_err() { + break; + } + } + } + } + }); + + let status = child.wait()?; + let _ = stdout_thread.join(); + + let _ = terminal::disable_raw_mode(); + let _ = std::fs::remove_file(&sock_path); + + std::process::exit(process_exit_code(status.exit_code())); +} + +fn spawn_resize_handler( + master: Box<dyn portable_pty::MasterPty + Send>, + resize_tx: mpsc::SyncSender<Msg>, + current_cols: Arc<AtomicU16>, +) -> eyre::Result<()> { + use signal_hook::consts::SIGWINCH; + use signal_hook::iterator::Signals; + + let mut signals = Signals::new([SIGWINCH])?; + + std::thread::spawn(move || { + for _ in signals.forever() { + if let Ok((cols, rows)) = terminal::size() { + current_cols.store(cols.max(1), Ordering::Relaxed); + let _ = master.resize(PtySize { + rows, + cols, + pixel_width: 0, + pixel_height: 0, + }); + let _ = 
resize_tx.try_send(Msg::Resize { rows, cols }); + } + } + }); + + Ok(()) +} + +fn process_exit_code(code: u32) -> i32 { + i32::try_from(code).unwrap_or(1) +} + +#[cfg(test)] +mod tests { + use super::process_exit_code; + + #[test] + fn process_exit_code_preserves_valid_values() { + assert_eq!(process_exit_code(0), 0); + assert_eq!(process_exit_code(127), 127); + assert_eq!(process_exit_code(i32::MAX as u32), i32::MAX); + } + + #[test] + fn process_exit_code_defaults_when_out_of_range() { + assert_eq!(process_exit_code(i32::MAX as u32 + 1), 1); + } +} diff --git a/crates/atuin-pty-proxy/src/screen.rs b/crates/atuin-pty-proxy/src/screen.rs new file mode 100644 index 00000000..5b892e21 --- /dev/null +++ b/crates/atuin-pty-proxy/src/screen.rs @@ -0,0 +1,104 @@ +use std::io::Write; +use std::os::unix::net::UnixListener; +use std::path::PathBuf; +use std::sync::mpsc::{self, Receiver, SyncSender}; + +pub(crate) enum Msg { + Data(Vec<u8>), + Resize { rows: u16, cols: u16 }, + ScreenRequest(mpsc::Sender<Vec<u8>>), +} + +pub(crate) fn socket_path() -> PathBuf { + let dir = std::env::temp_dir(); + dir.join(format!("atuin-pty-proxy-{}.sock", std::process::id())) +} + +pub(crate) fn spawn_parser_thread(rows: u16, cols: u16, msg_rx: Receiver<Msg>) { + std::thread::spawn(move || { + let mut parser = vt100::Parser::new(rows, cols, 0); + + loop { + let first = match msg_rx.recv() { + Ok(msg) => msg, + Err(_) => break, + }; + + handle_parser_msg(&mut parser, first); + + while let Ok(msg) = msg_rx.try_recv() { + handle_parser_msg(&mut parser, msg); + } + } + }); +} + +pub(crate) fn spawn_socket_server(sock_path: PathBuf, screen_tx: SyncSender<Msg>) { + std::thread::spawn(move || { + let listener = match UnixListener::bind(&sock_path) { + Ok(l) => l, + Err(e) => { + eprintln!("atuin pty-proxy: failed to bind socket: {e}"); + return; + } + }; + + for stream in listener.incoming() { + let mut stream = match stream { + Ok(s) => s, + Err(_) => break, + }; + + let (reply_tx, reply_rx) = 
mpsc::channel(); + if screen_tx.send(Msg::ScreenRequest(reply_tx)).is_err() { + break; + } + if let Ok(data) = reply_rx.recv() { + let _ = stream.write_all(&data); + let _ = stream.flush(); + } + } + }); +} + +/// Wire format written to the Unix socket: +/// +/// ```text +/// [rows: u16 BE][cols: u16 BE][cursor_row: u16 BE][cursor_col: u16 BE] +/// [row_0_len: u32 BE][row_0_bytes...] +/// [row_1_len: u32 BE][row_1_bytes...] +/// ... +/// ``` +/// +/// Each row's bytes come from `screen.rows_formatted(0, cols)` and contain +/// pre-built ANSI escape sequences. The client can write them directly to +/// stdout without needing its own vt100 parser. +fn encode_screen(parser: &vt100::Parser) -> Vec<u8> { + let screen = parser.screen(); + let (rows, cols) = screen.size(); + let (cursor_row, cursor_col) = screen.cursor_position(); + + let mut buf: Vec<u8> = Vec::with_capacity(256 + (rows as usize * cols as usize)); + buf.extend_from_slice(&rows.to_be_bytes()); + buf.extend_from_slice(&cols.to_be_bytes()); + buf.extend_from_slice(&cursor_row.to_be_bytes()); + buf.extend_from_slice(&cursor_col.to_be_bytes()); + + for row_bytes in screen.rows_formatted(0, cols) { + let len = row_bytes.len() as u32; + buf.extend_from_slice(&len.to_be_bytes()); + buf.extend_from_slice(&row_bytes); + } + + buf +} + +fn handle_parser_msg(parser: &mut vt100::Parser, msg: Msg) { + match msg { + Msg::Data(data) => parser.process(&data), + Msg::Resize { rows, cols } => parser.screen_mut().set_size(rows, cols), + Msg::ScreenRequest(reply_tx) => { + let _ = reply_tx.send(encode_screen(parser)); + } + } +} diff --git a/crates/atuin/Cargo.toml b/crates/atuin/Cargo.toml index 0ac1e889..8e425232 100644 --- a/crates/atuin/Cargo.toml +++ b/crates/atuin/Cargo.toml @@ -36,7 +36,7 @@ atuin = { path = "/usr/bin/atuin" } default = ["client", "sync", "clipboard", "check-update", "daemon", "ai", "pty-proxy"] client = ["atuin-client"] sync = ["atuin-client/sync"] -daemon = ["atuin-client/daemon", "atuin-daemon"] 
+daemon = ["atuin-client/daemon", "atuin-daemon", "atuin-ai?/daemon"] ai = ["atuin-ai"] pty-proxy = ["dep:atuin-pty-proxy"] hex = ["pty-proxy"] diff --git a/crates/atuin/src/command/mod.rs b/crates/atuin/src/command/mod.rs index 7deb72d6..6cd221a4 100644 --- a/crates/atuin/src/command/mod.rs +++ b/crates/atuin/src/command/mod.rs @@ -24,10 +24,7 @@ pub enum AtuinCmd { /// PTY proxy for atuin #[cfg(feature = "pty-proxy")] #[command(alias = "hex")] - PtyProxy { - #[command(subcommand)] - cmd: Option<atuin_pty_proxy::Cmd>, - }, + PtyProxy(atuin_pty_proxy::PtyProxy), /// Generate a UUID Uuid, @@ -56,8 +53,8 @@ impl AtuinCmd { Self::Client(client) => client.run(), #[cfg(feature = "pty-proxy")] - Self::PtyProxy { cmd } => { - atuin_pty_proxy::run(cmd); + Self::PtyProxy(proxy) => { + run_pty_proxy(proxy); Ok(()) } @@ -74,3 +71,92 @@ impl AtuinCmd { } } } + +#[cfg(all(feature = "pty-proxy", unix))] +fn run_pty_proxy(proxy: atuin_pty_proxy::PtyProxy) { + #[cfg(feature = "daemon")] + proxy.run(semantic_command_capture_sink()); + + #[cfg(not(feature = "daemon"))] + proxy.run(None); +} + +#[cfg(all(feature = "pty-proxy", not(unix)))] +fn run_pty_proxy(_proxy: atuin_pty_proxy::PtyProxy) { + eprintln!("atuin pty-proxy currently supports unix platforms"); + std::process::exit(1); +} + +#[cfg(all(feature = "daemon", feature = "pty-proxy", unix))] +fn semantic_command_capture_sink() -> Option<atuin_pty_proxy::CommandCaptureSink> { + use std::sync::mpsc; + use std::time::Duration; + + if is_truthy_env("ATUIN_TERMINAL") { + return None; + } + + let settings = atuin_client::settings::Settings::new().ok()?; + let (tx, rx) = mpsc::sync_channel::<atuin_pty_proxy::CommandCapture>(128); + + std::thread::spawn(move || { + let Ok(runtime) = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + else { + return; + }; + + while let Ok(first) = rx.recv() { + let mut batch = vec![first]; + + while batch.len() < 64 { + match rx.recv_timeout(Duration::from_millis(25)) { + 
Ok(capture) => batch.push(capture), + Err(mpsc::RecvTimeoutError::Timeout | mpsc::RecvTimeoutError::Disconnected) => { + break; + } + } + } + + runtime.block_on(send_semantic_command_captures(&settings, batch)); + } + }); + + Some(Box::new(move |capture| { + let _ = tx.try_send(capture); + })) +} + +#[cfg(all(feature = "daemon", feature = "pty-proxy", unix))] +#[inline] +fn is_truthy_env(name: &str) -> bool { + std::env::var(name) + .ok() + .as_ref() + .is_some_and(|value| !value.trim().is_empty() && value.trim() != "false") +} + +#[cfg(all(feature = "daemon", feature = "pty-proxy", unix))] +async fn send_semantic_command_captures( + settings: &atuin_client::settings::Settings, + batch: Vec<atuin_pty_proxy::CommandCapture>, +) { + let captures = batch + .into_iter() + .map(|capture| atuin_daemon::semantic::CommandCapture { + prompt: capture.prompt, + command: capture.command, + output: capture.output, + exit_code: capture.exit_code, + history_id: capture.history_id, + session_id: capture.session_id, + output_truncated: capture.output_truncated, + output_observed_bytes: capture.output_observed_bytes, + }) + .collect(); + + if let Ok(mut client) = atuin_daemon::SemanticClient::from_settings(settings).await { + let _ = client.record_commands(captures).await; + } +} diff --git a/crates/atuin/src/shell/atuin.bash b/crates/atuin/src/shell/atuin.bash index 45fdced9..8b540bd7 100644 --- a/crates/atuin/src/shell/atuin.bash +++ b/crates/atuin/src/shell/atuin.bash @@ -20,6 +20,35 @@ fi ATUIN_STTY=$(stty -g) ATUIN_HISTORY_ID="" +__atuin_osc133_command_executed() { + [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]] || return + [[ -n "${ATUIN_HISTORY_ID:-}" && "$ATUIN_HISTORY_ID" != "__bash_preexec_failure__" ]] || return + + printf '\033]133;C\a' +} + +__atuin_osc133_command_finished() { + [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]] || return + [[ -n "${ATUIN_HISTORY_ID:-}" && "$ATUIN_HISTORY_ID" != "__bash_preexec_failure__" ]] || return + + printf 
'\033]133;D;%s;history_id=%s;session_id=%s\a' "$1" "$ATUIN_HISTORY_ID" "${ATUIN_SESSION:-}" +} + +__atuin_osc133_prompt_start=$'\001\033]133;A;cl=line\a\002' +__atuin_osc133_prompt_end=$'\001\033]133;B\a\002' + +__atuin_osc133_wrap_prompt() { + local __atuin_prompt="${PS1-}" + __atuin_prompt="${__atuin_prompt//$__atuin_osc133_prompt_start/}" + __atuin_prompt="${__atuin_prompt//$__atuin_osc133_prompt_end/}" + + if [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]]; then + PS1="${__atuin_osc133_prompt_start}${__atuin_prompt}${__atuin_osc133_prompt_end}" + else + PS1="$__atuin_prompt" + fi +} + export ATUIN_PREEXEC_BACKEND=$SHLVL:none __atuin_update_preexec_backend() { if [[ ${BLE_ATTACHED-} ]]; then @@ -59,15 +88,19 @@ __atuin_preexec() { local id id=$(atuin history start -- "$1" 2>/dev/null) export ATUIN_HISTORY_ID=$id + [[ -n ${__atuin_skip_osc133:-} ]] || __atuin_osc133_command_executed __atuin_preexec_time=${EPOCHREALTIME-} } __atuin_precmd() { local EXIT=$? __atuin_precmd_time=${EPOCHREALTIME-} + __atuin_osc133_wrap_prompt + [[ ! 
$ATUIN_HISTORY_ID ]] && return # If the previous preexec hook failed, we manually call __atuin_preexec + local __atuin_skip_osc133="" if [[ $ATUIN_HISTORY_ID == __bash_preexec_failure__ ]]; then # This is the command extraction code taken from bash-preexec local previous_command @@ -75,6 +108,7 @@ __atuin_precmd() { export LC_ALL=C HISTTIMEFORMAT='' builtin history 1 | sed '1 s/^ *[0-9][0-9]*[* ] //' ) + __atuin_skip_osc133=1 __atuin_preexec "$previous_command" fi @@ -106,6 +140,7 @@ __atuin_precmd() { fi fi + [[ -n ${__atuin_skip_osc133:-} ]] || __atuin_osc133_command_finished "$EXIT" (ATUIN_LOG=error atuin history end --exit "$EXIT" ${duration:+"--duration=$duration"} -- "$ATUIN_HISTORY_ID" &) >/dev/null 2>&1 export ATUIN_HISTORY_ID="" } diff --git a/crates/atuin/src/shell/atuin.fish b/crates/atuin/src/shell/atuin.fish index ddf55f3d..15b33451 100644 --- a/crates/atuin/src/shell/atuin.fish +++ b/crates/atuin/src/shell/atuin.fish @@ -4,9 +4,24 @@ if not set -q ATUIN_SESSION; or test "$ATUIN_SHLVL" != "$SHLVL" end set --erase ATUIN_HISTORY_ID +function _atuin_osc133_command_executed + set -q ATUIN_PTY_PROXY_ACTIVE; or return + test -n "$ATUIN_HISTORY_ID"; or return + + printf '\033]133;C\a' +end + +function _atuin_osc133_command_finished --argument-names exit_code + set -q ATUIN_PTY_PROXY_ACTIVE; or return + test -n "$ATUIN_HISTORY_ID"; or return + + printf '\033]133;D;%s;history_id=%s;session_id=%s\a' "$exit_code" "$ATUIN_HISTORY_ID" "$ATUIN_SESSION" +end + function _atuin_preexec --on-event fish_preexec if not test -n "$fish_private_mode" set -g ATUIN_HISTORY_ID (atuin history start -- "$argv[1]" 2>/dev/null) + _atuin_osc133_command_executed end end @@ -14,6 +29,7 @@ function _atuin_postexec --on-event fish_postexec set -l s $status if test -n "$ATUIN_HISTORY_ID" + _atuin_osc133_command_finished $s ATUIN_LOG=error atuin history end --exit $s -- $ATUIN_HISTORY_ID &>/dev/null & disown end diff --git a/crates/atuin/src/shell/atuin.nu 
b/crates/atuin/src/shell/atuin.nu index c1a38313..d37457e4 100644 --- a/crates/atuin/src/shell/atuin.nu +++ b/crates/atuin/src/shell/atuin.nu @@ -14,6 +14,28 @@ if 'ATUIN_SESSION' not-in $env or ('ATUIN_SHLVL' not-in $env) or ($env.ATUIN_SHL } hide-env -i ATUIN_HISTORY_ID +def _atuin_osc133_command_executed [] { + if 'ATUIN_PTY_PROXY_ACTIVE' not-in $env { + return + } + if 'ATUIN_HISTORY_ID' not-in $env or ($env.ATUIN_HISTORY_ID | is-empty) { + return + } + + print -n $"(char esc)]133;C(char bel)" +} + +def _atuin_osc133_command_finished [exit_code: int] { + if 'ATUIN_PTY_PROXY_ACTIVE' not-in $env { + return + } + if 'ATUIN_HISTORY_ID' not-in $env or ($env.ATUIN_HISTORY_ID | is-empty) { + return + } + + print -n $"(char esc)]133;D;($exit_code);history_id=($env.ATUIN_HISTORY_ID);session_id=($env.ATUIN_SESSION)(char bel)" +} + # Magic token to make sure we don't record commands run by keybindings let ATUIN_KEYBINDING_TOKEN = $"# (random uuid)" @@ -27,6 +49,7 @@ let _atuin_pre_execution = {|| } if not ($cmd | str starts-with $ATUIN_KEYBINDING_TOKEN) { $env.ATUIN_HISTORY_ID = (atuin history start -- $cmd | complete | get stdout | str trim) + _atuin_osc133_command_executed } } @@ -35,6 +58,7 @@ let _atuin_pre_prompt = {|| if 'ATUIN_HISTORY_ID' not-in $env { return } + _atuin_osc133_command_finished $last_exit with-env { ATUIN_LOG: error } { if (version).minor >= 104 or (version).major > 0 { job spawn { diff --git a/crates/atuin/src/shell/atuin.zsh b/crates/atuin/src/shell/atuin.zsh index 87f47531..7a7375aa 100644 --- a/crates/atuin/src/shell/atuin.zsh +++ b/crates/atuin/src/shell/atuin.zsh @@ -31,16 +31,54 @@ if [[ -z "${ATUIN_SESSION:-}" || "${ATUIN_SHLVL:-}" != "$SHLVL" ]]; then fi ATUIN_HISTORY_ID="" +__atuin_osc133_command_executed() { + [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]] || return + [[ -n "${ATUIN_HISTORY_ID:-}" ]] || return + + printf '\033]133;C\a' +} + +__atuin_osc133_command_finished() { + [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]] || return + [[ -n 
"${ATUIN_HISTORY_ID:-}" ]] || return + + printf '\033]133;D;%s;history_id=%s;session_id=%s\a' "$1" "$ATUIN_HISTORY_ID" "${ATUIN_SESSION:-}" +} + +__atuin_osc133_prompt_start=$'%{\033]133;A;cl=line\a%}' +__atuin_osc133_prompt_end=$'%{\033]133;B\a%}' + +__atuin_osc133_wrap_prompt() { + local __atuin_prompt="${PROMPT-}" + local __atuin_rprompt="${RPROMPT-}" + + __atuin_prompt="${__atuin_prompt//$__atuin_osc133_prompt_start/}" + __atuin_prompt="${__atuin_prompt//$__atuin_osc133_prompt_end/}" + __atuin_rprompt="${__atuin_rprompt//$__atuin_osc133_prompt_start/}" + __atuin_rprompt="${__atuin_rprompt//$__atuin_osc133_prompt_end/}" + + if [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]]; then + PROMPT="${__atuin_osc133_prompt_start}${__atuin_prompt}" + RPROMPT="${__atuin_rprompt}${__atuin_osc133_prompt_end}" + else + PROMPT="$__atuin_prompt" + RPROMPT="$__atuin_rprompt" + fi +} + _atuin_preexec() { local id id=$(atuin history start -- "$1" 2>/dev/null) export ATUIN_HISTORY_ID="$id" + __atuin_osc133_command_executed __atuin_preexec_time=${EPOCHREALTIME-} } _atuin_precmd() { local EXIT="$?" __atuin_precmd_time=${EPOCHREALTIME-} + __atuin_osc133_wrap_prompt + [[ -z "${ATUIN_HISTORY_ID:-}" ]] && return local duration="" @@ -48,6 +86,7 @@ _atuin_precmd() { printf -v duration %.0f $(((__atuin_precmd_time - __atuin_preexec_time) * 1000000000)) fi + __atuin_osc133_command_finished "$EXIT" (ATUIN_LOG=error atuin history end --exit $EXIT ${duration:+--duration=$duration} -- $ATUIN_HISTORY_ID &) >/dev/null 2>&1 export ATUIN_HISTORY_ID="" } diff --git a/docs/docs/ai/settings.md b/docs/docs/ai/settings.md index a8d3dab3..edc54aaf 100644 --- a/docs/docs/ai/settings.md +++ b/docs/docs/ai/settings.md @@ -1,6 +1,6 @@ # AI Settings -All the settings that control the behavior of [Atuin AI](./introduction.md) are specified in an `[ai]` section in your `config.toml`. See [the configuration documentation](../../configuration/config/) for more detailed information about Atuin's configuration system. 
+All the settings that control the behavior of [Atuin AI](./introduction.md) are specified in an `[ai]` section in your `config.toml`. See [the configuration documentation](../configuration/config.md) for more detailed information about Atuin's configuration system. ### enabled @@ -42,6 +42,12 @@ Default: `true` Whether or not to include the "history search" capability in the context sent to the LLM. This allows the AI to request to search your Atuin history for relevant commands when generating suggestions or answering questions. +### enable_history_output + +Default: `true` + +Whether or not to include the "history output" capability in the context sent to the LLM. This allows the AI to request to view the output of previous commands. This requires the [pty-proxy](../reference/pty-proxy.md) and [daemon](../reference/daemon.md) to be enabled and running in order for Atuin to capture commands' outputs. + ### enable_file_tools Default: `true` |
