diff options
Diffstat (limited to 'crates/atuin-pty-proxy')
| -rw-r--r-- | crates/atuin-pty-proxy/src/capture.rs | 467 | ||||
| -rw-r--r-- | crates/atuin-pty-proxy/src/debug.rs | 53 | ||||
| -rw-r--r-- | crates/atuin-pty-proxy/src/lib.rs | 502 | ||||
| -rw-r--r-- | crates/atuin-pty-proxy/src/osc133.rs | 313 | ||||
| -rw-r--r-- | crates/atuin-pty-proxy/src/pty_proxy.rs | 231 | ||||
| -rw-r--r-- | crates/atuin-pty-proxy/src/runtime.rs | 184 | ||||
| -rw-r--r-- | crates/atuin-pty-proxy/src/screen.rs | 104 |
7 files changed, 1353 insertions, 501 deletions
diff --git a/crates/atuin-pty-proxy/src/capture.rs b/crates/atuin-pty-proxy/src/capture.rs new file mode 100644 index 00000000..6426035b --- /dev/null +++ b/crates/atuin-pty-proxy/src/capture.rs @@ -0,0 +1,467 @@ +use std::sync::Arc; +use std::sync::atomic::{AtomicU16, Ordering}; + +use crate::osc133::{Event, Params, Parser, Zone}; + +const HISTORY_ID_PARAM: &str = "history_id"; +const SESSION_ID_PARAM: &str = "session_id"; +const MAX_OUTPUT_CAPTURE_BYTES: usize = 1024 * 1024; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CommandCapture { + pub prompt: String, + pub command: String, + pub output: String, + pub exit_code: Option<i32>, + pub history_id: Option<String>, + pub session_id: Option<String>, + pub output_truncated: bool, + pub output_observed_bytes: u64, +} + +pub type CommandCaptureSink = Box<dyn Fn(CommandCapture) + Send + 'static>; + +#[derive(Default)] +struct CaptureBuffers { + prompt: Vec<u8>, + command: Vec<u8>, + output: Vec<u8>, + output_observed_bytes: u64, + output_truncated: bool, + exit_code: Option<i32>, + history_id: Option<String>, + session_id: Option<String>, +} + +pub(crate) struct CommandCaptureTracker { + parser: Parser, + zone: Zone, + buffers: CaptureBuffers, + cols: Arc<AtomicU16>, +} + +impl CommandCaptureTracker { + pub(crate) fn new(cols: Arc<AtomicU16>) -> Self { + Self { + parser: Parser::new(), + zone: Zone::Unknown, + buffers: CaptureBuffers::default(), + cols, + } + } + + pub(crate) fn push(&mut self, data: &[u8], mut on_capture: impl FnMut(CommandCapture)) { + let mut events = Vec::new(); + self.parser + .push_located(data, |located| events.push(located)); + + let mut start = 0; + for located in events { + let marker_start = located.start_offset.min(data.len()).max(start); + let offset = located.offset.min(data.len()); + self.append(&data[start..marker_start]); + self.handle_event(located.event, &located.params, &mut on_capture); + self.zone = located.zone; + start = offset; + } + + let append_end = self + .parser + .incomplete_osc_sequence_start() + .map_or(data.len(), |sequence_start| { + sequence_start.min(data.len()).max(start) + }); + if start < append_end { + self.append(&data[start..append_end]); + } + } + + fn append(&mut self, data: &[u8]) { + match self.zone { + Zone::Prompt => self.buffers.prompt.extend_from_slice(data), + Zone::Input => self.buffers.command.extend_from_slice(data), + Zone::Output => self.append_output(data), + Zone::Unknown => {} + } + } + + fn append_output(&mut self, data: &[u8]) { + self.buffers.output_observed_bytes = self + .buffers + .output_observed_bytes + .saturating_add(data.len() as u64); + + if self.buffers.output_truncated { + return; + } + + let remaining = MAX_OUTPUT_CAPTURE_BYTES.saturating_sub(self.buffers.output.len()); + let retained = data.len().min(remaining); + self.buffers.output_truncated = retained < data.len(); + + if retained > 0 { + self.buffers.output.extend_from_slice(&data[..retained]); + } + } + + fn handle_event( + &mut self, + event: Event, + params: &Params, + on_capture: &mut impl FnMut(CommandCapture), + ) { + match event { + Event::PromptStart => { + if self.zone != Zone::Prompt { + self.buffers = CaptureBuffers::default(); + } + } + Event::CommandStart | Event::CommandExecuted => {} + Event::CommandFinished { exit_code } => { + let Some(history_id) = params.get(HISTORY_ID_PARAM).map(str::to_owned) else { + return; + }; + + if exit_code.is_some() || self.buffers.exit_code.is_none() { + self.buffers.exit_code = exit_code; + } + self.buffers.history_id = Some(history_id); + self.buffers.session_id = params.get(SESSION_ID_PARAM).map(str::to_owned); + + if let Some(capture) = self.finish_capture() { + on_capture(capture); + } + } + } + } + + fn finish_capture(&mut self) -> Option<CommandCapture> { + let buffers = std::mem::take(&mut self.buffers); + let cols = self.cols.load(Ordering::Relaxed).max(1); + let prompt = render_plain_text(&buffers.prompt, cols); + let command = render_plain_text(&buffers.command, cols) + .trim_matches(|c| c == '\r' || c == '\n') + .to_string(); + let output = render_plain_text(&buffers.output, cols); + let output_truncated = buffers.output_truncated; + let output_observed_bytes = buffers.output_observed_bytes; + let exit_code = buffers.exit_code; + let history_id = buffers.history_id; + let session_id = buffers.session_id; + + if command.is_empty() && output.is_empty() { + return None; + } + + Some(CommandCapture { + prompt, + command, + output, + exit_code, + history_id, + session_id, + output_truncated, + output_observed_bytes, + }) + } +} + +const CLEAN_TEXT_MAX_ROWS: usize = 10_000; + +fn render_plain_text(bytes: &[u8], cols: u16) -> String { + if bytes.is_empty() { + return String::new(); + } + + let cols = cols.max(1); + let mut parser = vt100::Parser::new(estimated_rows(bytes, cols), cols, 0); + parser.process(bytes); + normalize_screen_contents(&parser.screen().contents()) +} + +fn normalize_screen_contents(contents: &str) -> String { + let mut lines = contents.lines().map(str::trim_end).collect::<Vec<_>>(); + while lines.last().is_some_and(|line| line.is_empty()) { + lines.pop(); + } + lines.join("\n") +} + +fn estimated_rows(bytes: &[u8], cols: u16) -> u16 { + let newline_rows = bytes.iter().filter(|byte| **byte == b'\n').count() + 1; + let wrapped_rows = bytes.len() / cols as usize; + newline_rows + .saturating_add(wrapped_rows) + .saturating_add(1) + .clamp(1, CLEAN_TEXT_MAX_ROWS) as u16 +} + +#[cfg(test)] +mod tests { + use super::*; + + fn tracker(cols: u16) -> CommandCaptureTracker { + CommandCaptureTracker::new(Arc::new(AtomicU16::new(cols))) + } + + fn assert_no_terminal_controls(text: &str) { + assert!( + !text + .chars() + .any(|ch| ch.is_control() && ch != '\n' && ch != '\t'), + "text still contains terminal controls: {text:?}" + ); + } + + #[test] + fn command_text_collapses_terminal_echo_edits() { + assert_eq!(render_plain_text(b"e\x08echo hi", 80), "echo hi"); + assert_eq!( + render_plain_text( + b"e\x08echo\x08 \x08\x08 \x08\x08\x08e \x08\x08 \x08e\x08echo hi", + 80 + ), + "echo hi" + ); + assert_eq!(render_plain_text(b"echo hi", 80), "echo hi"); + } + + #[test] + fn text_cleaning_strips_ansi_and_terminal_controls() { + let text = render_plain_text( + b"\x1b[32mhi\x1b[0m\r\n% \r \r", + 80, + ); + + assert_eq!(text, "hi"); + assert_no_terminal_controls(&text); + } + + #[test] + fn text_cleaning_preserves_valid_utf8_after_backspace() { + let text = render_plain_text("🦀x\x08 \x08 crab".as_bytes(), 80); + + assert_eq!(text, "🦀 crab"); + assert_no_terminal_controls(&text); + } + + #[test] + fn command_text_replays_backspaces() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + let input = + b"\x1b]133;A\x07$ \x1b]133;B\x07e\x08echo hi\r\n\x1b]133;C\x07hi\r\n\x1b]133;D;0;history_id=hist;session_id=sess\x07\x1b]133;A\x07$ "; + tracker.push(input, |capture| captures.push(capture)); + + assert_eq!(captures.len(), 1); + assert_eq!(captures[0].command, "echo hi"); + assert_eq!(captures[0].output, "hi"); + assert_no_terminal_controls(&captures[0].command); + assert_no_terminal_controls(&captures[0].output); + } + + #[test] + fn captures_complete_command() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;A\x07$ \x1b]133;B\x07echo hi\r\n\x1b]133;C\x07hi\r\n\x1b]133;D;0;history_id=hist;session_id=sess\x07\x1b]133;A\x07$ ", + |capture| captures.push(capture), + ); + + assert_eq!( + captures, + vec![CommandCapture { + prompt: "$".to_string(), + command: "echo hi".to_string(), + output: "hi".to_string(), + exit_code: Some(0), + history_id: Some("hist".to_string()), + session_id: Some("sess".to_string()), + output_truncated: false, + output_observed_bytes: 4, + }] + ); + } + + #[test] + fn strips_ansi_and_split_markers() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push(b"\x1b]133;A\x07\x1b[32m%\x1b[0m ", |_| {}); + tracker.push(b"\x1b]133;B\x07ls\x1b]133;C", |_| {}); + tracker.push( + b"\x07\x1b[31mfile\x1b[0m\r\n\x1b]133;D;1;history_id=hist;session_id=sess\x07\x1b]133;A\x07% ", + |capture| { + captures.push(capture); + }, + ); + + assert_eq!( + captures, + vec![CommandCapture { + prompt: "%".to_string(), + command: "ls".to_string(), + output: "file".to_string(), + exit_code: Some(1), + history_id: Some("hist".to_string()), + session_id: Some("sess".to_string()), + output_truncated: false, + output_observed_bytes: 15, + }] + ); + } + + #[test] + fn duplicate_prompt_start_does_not_reset_prompt_capture() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;A\x07$ \x1b]133;A\x07continued \x1b]133;B\x07echo hi\r\n\x1b]133;C\x07hi\r\n\x1b]133;D;0;history_id=hist;session_id=sess\x07\x1b]133;A\x07$ ", + |capture| captures.push(capture), + ); + + assert_eq!(captures.len(), 1); + assert_eq!(captures[0].prompt, "$ continued"); + assert_eq!(captures[0].command, "echo hi"); + assert_eq!(captures[0].output, "hi"); + } + + #[test] + fn bare_finish_without_metadata_is_ignored() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push(b"\x1b]133;C\x07line one\r\n\x1b]133;D;0\x07", |capture| { + captures.push(capture); + }); + + tracker.push(b"\x1b]133;A\x07$ ", |capture| captures.push(capture)); + + assert!(captures.is_empty()); + } + + #[test] + fn bare_finish_before_metadata_in_same_push_ignored() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;C\x07line one\r\n\x1b]133;D;1\x07\x1b]133;D;0;history_id=018f;session_id=abcd\x07", + |capture| captures.push(capture), + ); + + assert_eq!(captures.len(), 1); + assert_eq!(captures[0].output, "line one"); + assert_eq!(captures[0].exit_code, Some(0)); + assert_eq!(captures[0].history_id.as_deref(), Some("018f")); + assert_eq!(captures[0].session_id.as_deref(), Some("abcd")); + } + + #[test] + fn metadata_arriving_after_bare_finish_across_pushes() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push(b"\x1b]133;C\x07line one\r\n\x1b]133;D;0\x07", |capture| { + captures.push(capture); + }); + tracker.push(b"\x1b]133;D;0;history_id=018f", |capture| { + captures.push(capture) + }); + + assert!(captures.is_empty()); + + tracker.push(b";session_id=abcd\x07", |capture| captures.push(capture)); + + assert_eq!(captures.len(), 1); + assert_eq!(captures[0].output, "line one"); + assert_eq!(captures[0].exit_code, Some(0)); + assert_eq!(captures[0].history_id.as_deref(), Some("018f")); + assert_eq!(captures[0].session_id.as_deref(), Some("abcd")); + } + + #[test] + fn split_finish_marker_is_not_counted_as_output() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;C\x07line one\r\n\x1b]133;D;0;history_id=018f", + |capture| { + captures.push(capture); + }, + ); + assert!(captures.is_empty()); + + tracker.push(b";session_id=abcd\x07", |capture| captures.push(capture)); + + assert_eq!(captures.len(), 1); + assert_eq!(captures[0].output, "line one"); + assert_eq!(captures[0].output_observed_bytes, 10); + } + + #[test] + fn captures_output_with_history_metadata_from_d_marker() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;C\x07line one\r\n\x1b]133;D;0;history_id=018f;session_id=abcd\x07", + |capture| captures.push(capture), + ); + + assert_eq!( + captures, + vec![CommandCapture { + prompt: String::new(), + command: String::new(), + output: "line one".to_string(), + exit_code: Some(0), + history_id: Some("018f".to_string()), + session_id: Some("abcd".to_string()), + output_truncated: false, + output_observed_bytes: 10, + }] + ); + } + + #[test] + fn output_capture_is_capped_and_reports_observed_bytes() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + let mut input = b"\x1b]133;C\x07".to_vec(); + input.extend(std::iter::repeat_n(b'x', MAX_OUTPUT_CAPTURE_BYTES + 10)); + input.extend_from_slice(b"\x1b]133;D;0;history_id=big;session_id=session-1\x07"); + + tracker.push(&input, |capture| captures.push(capture)); + + assert_eq!(captures.len(), 1); + assert!(captures[0].output_truncated); + assert_eq!( + captures[0].output_observed_bytes, + (MAX_OUTPUT_CAPTURE_BYTES + 10) as u64 + ); + } + + #[test] + fn resets_buffers_between_c_d_only_captures() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;C\x07first\r\n\x1b]133;D;0;history_id=one\x07\x1b]133;C\x07second\r\n\x1b]133;D;1;history_id=two\x07", + |capture| captures.push(capture), + ); + + assert_eq!(captures.len(), 2); + assert_eq!(captures[0].output, "first"); + assert_eq!(captures[0].history_id.as_deref(), Some("one")); + assert_eq!(captures[1].output, "second"); + assert_eq!(captures[1].history_id.as_deref(), Some("two")); + } +} diff --git a/crates/atuin-pty-proxy/src/debug.rs b/crates/atuin-pty-proxy/src/debug.rs new file mode 100644 index 00000000..806bde90 --- /dev/null +++ b/crates/atuin-pty-proxy/src/debug.rs @@ -0,0 +1,53 @@ +use crate::osc133::{Event, Parser}; + +pub(crate) const RESET: &[u8] = b"\x1b[0m"; + +pub(crate) struct Osc133DebugHighlighter { + parser: Parser, +} + +impl Osc133DebugHighlighter { + pub(crate) fn new() -> Self { + Self { + parser: Parser::new(), + } + } + + pub(crate) fn render(&mut self, data: &[u8]) -> Vec<u8> { + let mut events = Vec::new(); + self.parser + .push_located(data, |located| events.push(located)); + + if events.is_empty() { + return data.to_vec(); + } + + let mut rendered = Vec::with_capacity(data.len() + (events.len() * 64)); + let mut start = 0; + + for located in events { + let offset = located.offset.min(data.len()); + if offset > start { + rendered.extend_from_slice(&data[start..offset]); + } + + rendered.extend_from_slice(event_label(&located.event)); + rendered.extend_from_slice(RESET); + start = offset; + } + + rendered.extend_from_slice(&data[start..]); + rendered + } +} + +fn event_label(event: &Event) -> &'static [u8] { + match event { + Event::PromptStart => b"\x1b[1;37;45m[OSC133:A prompt]\x1b[0m", + Event::CommandStart => b"\x1b[1;30;43m[OSC133:B input]\x1b[0m", + Event::CommandExecuted => b"\x1b[1;30;46m[OSC133:C output]\x1b[0m", + Event::CommandFinished { exit_code: Some(0) } => b"\x1b[1;37;42m[OSC133:D exit=0]\x1b[0m", + Event::CommandFinished { exit_code: Some(_) } => b"\x1b[1;37;41m[OSC133:D exit!=0]\x1b[0m", + Event::CommandFinished { exit_code: None } => b"\x1b[1;37;44m[OSC133:D exit=?]\x1b[0m", + } +} diff --git a/crates/atuin-pty-proxy/src/lib.rs b/crates/atuin-pty-proxy/src/lib.rs index 16b29dff..65b03df3 100644 --- a/crates/atuin-pty-proxy/src/lib.rs +++ b/crates/atuin-pty-proxy/src/lib.rs @@ -1,478 +1,48 @@ -pub mod osc133; - -use clap::{Args, Subcommand, ValueEnum}; - -#[derive(Subcommand, Debug)] -pub enum Cmd { - /// Print shell code to initialize atuin pty-proxy on shell startup - Init(Init), -} - -#[derive(Args, Debug)] -pub struct Init { - /// Shell to generate init for. If omitted, attempt auto-detection - #[arg(value_enum)] - shell: Option<Shell>, -} - -#[derive(Clone, Copy, Debug, Eq, PartialEq, ValueEnum)] -#[value(rename_all = "lower")] -#[allow(clippy::enum_variant_names, clippy::doc_markdown)] -enum Shell { - /// Zsh setup - Zsh, - /// Bash setup - Bash, - /// Fish setup - Fish, - /// Nu setup - Nu, -} - -impl Init { - fn run(self) -> Result<(), String> { - let shell = detect_shell(self.shell)?; - let script = render_init(shell); - print!("{script}"); - Ok(()) - } -} - -pub fn run(cmd: Option<Cmd>) { - match cmd { - Some(Cmd::Init(init)) => { - if let Err(err) = init.run() { - eprintln!("atuin pty-proxy: {err}"); - std::process::exit(1); - } - } - None => app::main(), - } -} - -fn detect_shell(cli_shell: Option<Shell>) -> Result<Shell, String> { - if let Some(shell) = cli_shell { - return Ok(shell); - } - - if let Ok(shell) = std::env::var("ATUIN_SHELL") - && let Some(shell) = shell_from_name(&shell) - { - return Ok(shell); - } - - if let Ok(shell) = std::env::var("SHELL") - && let Some(shell) = shell_from_name(&shell) - { - return Ok(shell); - } - - Err( - "could not detect a supported shell. Please specify one explicitly: bash, zsh, fish, or nu" - .to_string(), - ) -} - -fn shell_from_name(name: &str) -> Option<Shell> { - let shell = name - .trim() - .rsplit('/') - .next() - .unwrap_or(name) - .trim_start_matches('-') - .to_ascii_lowercase(); - - match shell.as_str() { - "bash" => Some(Shell::Bash), - "zsh" => Some(Shell::Zsh), - "fish" => Some(Shell::Fish), - "nu" => Some(Shell::Nu), - _ => None, - } -} - -fn render_init(shell: Shell) -> &'static str { - match shell { - Shell::Bash | Shell::Zsh => { - r#"if [[ "$-" == *i* ]] && [[ -t 0 ]] && [[ -t 1 ]]; then - _atuin_pty_proxy_tmux_current="${TMUX:-}" - _atuin_pty_proxy_tmux_previous="${ATUIN_PTY_PROXY_TMUX:-${ATUIN_HEX_TMUX:-}}" - - if [[ -z "${ATUIN_PTY_PROXY_ACTIVE:-${ATUIN_HEX_ACTIVE:-}}" ]] || [[ "$_atuin_pty_proxy_tmux_current" != "$_atuin_pty_proxy_tmux_previous" ]]; then - export ATUIN_PTY_PROXY_ACTIVE=1 - export ATUIN_PTY_PROXY_TMUX="$_atuin_pty_proxy_tmux_current" - exec atuin pty-proxy - fi - - unset _atuin_pty_proxy_tmux_current _atuin_pty_proxy_tmux_previous -fi -"# - } - Shell::Fish => { - r#"if status is-interactive; and test -t 0; and test -t 1 - set -l _atuin_pty_proxy_tmux_current "" - if set -q TMUX - set _atuin_pty_proxy_tmux_current "$TMUX" - end - - set -l _atuin_pty_proxy_tmux_previous "" - if set -q ATUIN_PTY_PROXY_TMUX - set _atuin_pty_proxy_tmux_previous "$ATUIN_PTY_PROXY_TMUX" - else if set -q ATUIN_HEX_TMUX - set _atuin_pty_proxy_tmux_previous "$ATUIN_HEX_TMUX" - end - - if not set -q ATUIN_PTY_PROXY_ACTIVE; and not set -q ATUIN_HEX_ACTIVE - set -gx ATUIN_PTY_PROXY_ACTIVE 1 - set -gx ATUIN_PTY_PROXY_TMUX "$_atuin_pty_proxy_tmux_current" - exec atuin pty-proxy - else if test "$_atuin_pty_proxy_tmux_current" != "$_atuin_pty_proxy_tmux_previous" - set -gx ATUIN_PTY_PROXY_ACTIVE 1 - set -gx ATUIN_PTY_PROXY_TMUX "$_atuin_pty_proxy_tmux_current" - exec atuin pty-proxy - end -end -"# - } - // Nushell cannot dynamically source the output of `atuin init nu`, - // so we only output the pty-proxy preamble here. Users must also set up - // `atuin init nu` separately. - Shell::Nu => { - r#"if (is-terminal --stdin) and (is-terminal --stdout) { - let tmux_current = ($env.TMUX? | default "") - let tmux_previous = ($env.ATUIN_PTY_PROXY_TMUX? | default ($env.ATUIN_HEX_TMUX? | default "")) - - if (($env.ATUIN_PTY_PROXY_ACTIVE? | default ($env.ATUIN_HEX_ACTIVE? | default "")) | is-empty) or ($tmux_current != $tmux_previous) { - $env.ATUIN_PTY_PROXY_ACTIVE = "1" - $env.ATUIN_PTY_PROXY_TMUX = $tmux_current - exec atuin pty-proxy - } -} -"# - } - } -} - -#[cfg(not(unix))] -mod app { - pub(crate) fn main() { - eprintln!("atuin pty-proxy currently supports unix platforms"); - std::process::exit(1); - } -} - #[cfg(unix)] -mod app { - use std::io::{Read, Write}; - use std::os::unix::net::UnixListener; - use std::sync::mpsc; - - use crossterm::terminal; - use portable_pty::{CommandBuilder, PtySize, native_pty_system}; - - enum ParserMsg { - Data(Vec<u8>), - Resize { rows: u16, cols: u16 }, - ScreenRequest(mpsc::Sender<Vec<u8>>), - } - - pub(crate) fn main() { - if let Err(e) = run() { - let _ = terminal::disable_raw_mode(); - eprintln!("atuin pty-proxy: {e:#}"); - std::process::exit(1); - } - } - - fn socket_path() -> std::path::PathBuf { - let dir = std::env::temp_dir(); - dir.join(format!("atuin-pty-proxy-{}.sock", std::process::id())) - } - - /// Wire format written to the Unix socket: - /// - /// ```text - /// [rows: u16 BE][cols: u16 BE][cursor_row: u16 BE][cursor_col: u16 BE] - /// [row_0_len: u32 BE][row_0_bytes...] - /// [row_1_len: u32 BE][row_1_bytes...] - /// ... - /// ``` - /// - /// Each row's bytes come from `screen.rows_formatted(0, cols)` and contain - /// pre-built ANSI escape sequences. The client can write them directly to - /// stdout without needing its own vt100 parser. - fn encode_screen(parser: &vt100::Parser) -> Vec<u8> { - let screen = parser.screen(); - let (rows, cols) = screen.size(); - let (cursor_row, cursor_col) = screen.cursor_position(); - - let mut buf: Vec<u8> = Vec::with_capacity(256 + (rows as usize * cols as usize)); - buf.extend_from_slice(&rows.to_be_bytes()); - buf.extend_from_slice(&cols.to_be_bytes()); - buf.extend_from_slice(&cursor_row.to_be_bytes()); - buf.extend_from_slice(&cursor_col.to_be_bytes()); - - for row_bytes in screen.rows_formatted(0, cols) { - let len = row_bytes.len() as u32; - buf.extend_from_slice(&len.to_be_bytes()); - buf.extend_from_slice(&row_bytes); - } - - buf - } - - fn handle_parser_msg(parser: &mut vt100::Parser, msg: ParserMsg) { - match msg { - ParserMsg::Data(data) => parser.process(&data), - ParserMsg::Resize { rows, cols } => parser.screen_mut().set_size(rows, cols), - ParserMsg::ScreenRequest(reply_tx) => { - let _ = reply_tx.send(encode_screen(parser)); - } - } - } - - fn run() -> eyre::Result<()> { - let (cols, rows) = terminal::size()?; - - let pty_system = native_pty_system(); - let pair = pty_system - .openpty(PtySize { - rows, - cols, - pixel_width: 0, - pixel_height: 0, - }) - .map_err(|e| eyre::eyre!("{e:#}"))?; - - // Set up socket path and expose it to child processes - let sock_path = socket_path(); - // Clean up any stale socket from a previous crash - let _ = std::fs::remove_file(&sock_path); - - let mut cmd = CommandBuilder::new_default_prog(); - cmd.cwd(std::env::current_dir()?); - cmd.env("ATUIN_PTY_PROXY_SOCKET", sock_path.as_os_str()); - cmd.env("ATUIN_HEX_SOCKET", sock_path.as_os_str()); - - let mut child = pair - .slave - .spawn_command(cmd) - .map_err(|e| eyre::eyre!("{e:#}"))?; - - // Close slave side in parent process - drop(pair.slave); - - let mut pty_reader = pair - .master - .try_clone_reader() - .map_err(|e| eyre::eyre!("{e:#}"))?; - let mut pty_writer = pair - .master - .take_writer() - .map_err(|e| eyre::eyre!("{e:#}"))?; - - // Channel: stdout/sigwinch/socket threads -> parser thread (bounded, non-blocking send) - let (msg_tx, msg_rx) = mpsc::sync_channel::<ParserMsg>(64); - - // --- Parser thread --- - // Maintains a persistent vt100::Parser fed bytes as they arrive. - // On screen request: reads current state directly (no replay). - std::thread::spawn(move || { - let mut parser = vt100::Parser::new(rows, cols, 0); - - loop { - // Block until at least one message arrives - let first = match msg_rx.recv() { - Ok(msg) => msg, - Err(_) => break, - }; - - handle_parser_msg(&mut parser, first); - - // Drain all remaining pending messages so the parser stays - // caught up during high-throughput bursts (e.g. `cat bigfile`). - // The channel holds at most 64 items, so this is bounded. - while let Ok(msg) = msg_rx.try_recv() { - handle_parser_msg(&mut parser, msg); - } - } - }); - - // --- Socket server thread --- - // Listens on Unix socket; on connection, requests screen state from parser thread. - { - let sock_path_clone = sock_path.clone(); - let screen_tx = msg_tx.clone(); - std::thread::spawn(move || { - let listener = match UnixListener::bind(&sock_path_clone) { - Ok(l) => l, - Err(e) => { - eprintln!("atuin pty-proxy: failed to bind socket: {e}"); - return; - } - }; - - for stream in listener.incoming() { - let mut stream = match stream { - Ok(s) => s, - Err(_) => break, - }; - - let (reply_tx, reply_rx) = mpsc::channel(); - if screen_tx.send(ParserMsg::ScreenRequest(reply_tx)).is_err() { - break; - } - if let Ok(data) = reply_rx.recv() { - let _ = stream.write_all(&data); - let _ = stream.flush(); - } - } - }); - } - - // Handle terminal resize via SIGWINCH - { - use signal_hook::consts::SIGWINCH; - use signal_hook::iterator::Signals; - - let master = pair.master; - let resize_tx = msg_tx.clone(); - let mut signals = Signals::new([SIGWINCH])?; - - std::thread::spawn(move || { - for _ in signals.forever() { - if let Ok((cols, rows)) = terminal::size() { - let _ = master.resize(PtySize { - rows, - cols, - pixel_width: 0, - pixel_height: 0, - }); - let _ = resize_tx.try_send(ParserMsg::Resize { rows, cols }); - } - } - }); - } - - terminal::enable_raw_mode()?; - - // PTY -> stdout (with OSC 133 parsing + buffer feed) - let stdout_thread = std::thread::spawn(move || { - let mut stdout = std::io::stdout(); - let mut parser = crate::osc133::Parser::new(); - let mut buf = [0u8; 8192]; - loop { - match pty_reader.read(&mut buf) { - Ok(0) | Err(_) => break, - Ok(n) => { - parser.push(&buf[..n], |_event| { - // Zone transitions are tracked inside the parser. - // Callers can query parser.zone() after push. - }); - - // Feed bytes to the shadow parser. Drops on backpressure — - // the screen snapshot may be stale during bursts, but - // self-corrects once output settles. - let _ = msg_tx.try_send(ParserMsg::Data(buf[..n].to_vec())); - - if stdout.write_all(&buf[..n]).is_err() { - break; - } - let _ = stdout.flush(); - } - } - } - }); - - // stdin -> PTY - std::thread::spawn(move || { - let mut stdin = std::io::stdin(); - let mut buf = [0u8; 8192]; - loop { - match stdin.read(&mut buf) { - Ok(0) | Err(_) => break, - Ok(n) => { - if pty_writer.write_all(&buf[..n]).is_err() { - break; - } - } - } - } - }); +mod capture; +#[cfg(unix)] +mod debug; +#[cfg(unix)] +mod osc133; +#[cfg(unix)] +mod pty_proxy; +#[cfg(unix)] +mod runtime; +#[cfg(unix)] +mod screen; - let status = child.wait()?; - let _ = stdout_thread.join(); +#[cfg(unix)] +pub use capture::{CommandCapture, CommandCaptureSink}; +#[cfg(unix)] +pub use pty_proxy::PtyProxy; - let _ = terminal::disable_raw_mode(); +#[cfg(not(unix))] +#[allow(dead_code)] +mod unsupported { + use clap::{Args, Subcommand}; - // Clean up socket file - let _ = std::fs::remove_file(&sock_path); + #[derive(Args, Debug)] + pub struct PtyProxy { + /// Highlight OSC 133 prompt, input, output, and exit-code regions + #[arg(long)] + debug_osc133: bool, - std::process::exit(process_exit_code(status.exit_code())); + #[command(subcommand)] + cmd: Option<Cmd>, } - fn process_exit_code(code: u32) -> i32 { - i32::try_from(code).unwrap_or(1) + #[derive(Subcommand, Debug)] + enum Cmd { + /// Print shell code to initialize atuin pty-proxy on shell startup + Init(Init), } - #[cfg(test)] - mod tests { - use super::process_exit_code; - - #[test] - fn process_exit_code_preserves_valid_values() { - assert_eq!(process_exit_code(0), 0); - assert_eq!(process_exit_code(127), 127); - assert_eq!(process_exit_code(i32::MAX as u32), i32::MAX); - } - - #[test] - fn process_exit_code_defaults_when_out_of_range() { - assert_eq!(process_exit_code(i32::MAX as u32 + 1), 1); - } + #[derive(Args, Debug)] + struct Init { + /// Shell to generate init for. If omitted, attempt auto-detection + shell: Option<String>, } } -#[cfg(test)] -mod tests { - use super::{Shell, render_init, shell_from_name}; - - #[test] - fn shell_from_name_handles_paths() { - assert_eq!(shell_from_name("/bin/zsh"), Some(Shell::Zsh)); - assert_eq!(shell_from_name("/usr/local/bin/bash"), Some(Shell::Bash)); - assert_eq!(shell_from_name("fish"), Some(Shell::Fish)); - assert_eq!(shell_from_name("nu"), Some(Shell::Nu)); - } - - #[test] - fn posix_init_uses_exec_and_tmux_guard() { - let script = render_init(Shell::Bash); - assert!(script.contains("exec atuin pty-proxy")); - assert!(script.contains("ATUIN_PTY_PROXY_TMUX")); - assert!(!script.contains("eval \"$(atuin init bash)\"")); - } - - #[test] - fn posix_init_has_no_double_braces() { - let script = render_init(Shell::Bash); - assert!(!script.contains("${{"), "double braces in bash init script"); - } - - #[test] - fn fish_init_uses_source() { - let script = render_init(Shell::Fish); - assert!(script.contains("exec atuin pty-proxy")); - assert!(!script.contains("atuin init fish | source")); - } - - #[test] - fn nu_init_uses_exec_and_tty_guard() { - let script = render_init(Shell::Nu); - assert!(script.contains("exec atuin pty-proxy")); - assert!(script.contains("ATUIN_PTY_PROXY_TMUX")); - assert!(script.contains("is-terminal --stdin")); - assert!(script.contains("is-terminal --stdout")); - assert!(script.contains("ATUIN_PTY_PROXY_ACTIVE")); - } -} +#[cfg(not(unix))] +pub use unsupported::PtyProxy; diff --git a/crates/atuin-pty-proxy/src/osc133.rs b/crates/atuin-pty-proxy/src/osc133.rs index d6ee1220..51fda848 100644 --- a/crates/atuin-pty-proxy/src/osc133.rs +++ b/crates/atuin-pty-proxy/src/osc133.rs @@ -9,18 +9,19 @@ //! | C | Command submitted — output begins | //! | D[;n] | Command finished with exit code *n* | //! -//! The wire format is `ESC ] 133 ; <cmd> [; <params>] ST` where ST is either -//! BEL (0x07) or ESC \ (0x1B 0x5C). +//! The wire format is `ESC ] 133 ; <cmd> [; <params>] ST` where ST is BEL +//! (0x07), ESC \ (0x1B 0x5C), or C1 ST (0x9C). //! //! # Design goals //! -//! * **Zero-copy** — the parser observes the byte stream without buffering or -//! modifying it. -//! * **Zero-alloc** — after construction no heap allocation occurs. +//! * **Transparent** — the parser observes the byte stream without modifying it; +//! the caller remains responsible for forwarding bytes to their destination. +//! * **Bounded** — OSC parameter buffering is capped so malformed output cannot +//! grow memory without limit. //! * **Non-blocking** — [`Parser::push`] processes whatever bytes are available //! and returns immediately. -//! * **Transparent** — the caller is responsible for forwarding bytes to their -//! destination; the parser only emits [`Event`]s through a callback. +//! * **Extensible** — marker parameters are preserved so Atuin-specific metadata +//! can ride alongside standard OSC 133 markers. /// Events emitted when an OSC 133 marker is detected. #[derive(Debug, Clone, PartialEq, Eq)] @@ -38,6 +39,63 @@ pub enum Event { }, } +/// Parameters attached to an OSC 133 marker. +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub struct Params { + items: Vec<Param>, +} + +impl Params { + /// Iterate over all marker parameters in order. + #[cfg(test)] + #[inline] + pub fn iter(&self) -> impl Iterator<Item = &Param> { + self.items.iter() + } + + /// Return the value for the first `key=value` parameter with this key. + #[inline] + pub fn get(&self, key: &str) -> Option<&str> { + self.items.iter().find_map(|item| match item { + Param::KeyValue { + key: item_key, + value, + } if item_key == key => Some(value.as_str()), + Param::Value(_) | Param::KeyValue { .. } => None, + }) + } +} + +/// A single OSC 133 marker parameter. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Param { + /// A positional parameter without an equals sign. + Value(String), + /// A `key=value` parameter. + KeyValue { key: String, value: String }, +} + +/// An OSC 133 event with its position in the most recent input chunk. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LocatedEvent { + /// The OSC 133 event that was parsed. + pub event: Event, + /// Offset where this marker starts in the current chunk. + /// + /// If a marker started in an earlier [`Parser::push_located`] call, this is + /// `0` in the chunk that completed the marker. + pub start_offset: usize, + /// Offset immediately after this marker's terminator in the current chunk. + /// + /// If a marker spans multiple [`Parser::push_located`] calls, this is still + /// the offset in the chunk that completed the marker. + pub offset: usize, + /// The semantic zone after applying this event. + pub zone: Zone, + /// Metadata parameters attached to this marker. + pub params: Params, +} + /// The current semantic zone as determined by the most recent OSC 133 marker. #[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] #[allow(dead_code)] @@ -59,14 +117,14 @@ pub enum Zone { const ESC: u8 = 0x1B; const BEL: u8 = 0x07; +const C1_ST: u8 = 0x9C; const BACKSLASH: u8 = b'\\'; const RIGHT_BRACKET: u8 = b']'; -/// Maximum bytes we'll buffer for the OSC parameter string. 32 bytes is far -/// more than any valid OSC 133 payload needs (e.g. `133;D;127` is 9 bytes). -/// Longer (non-133) OSC sequences simply stop accumulating once the buffer is -/// full — the dispatch logic will harmlessly ignore them. -const PARAM_BUF_CAP: usize = 32; +/// Maximum bytes we'll buffer for the OSC parameter string. This is large enough +/// for Atuin metadata such as history/session IDs while still bounding malformed +/// OSC sequences. +const PARAM_BUF_CAP: usize = 512; // --------------------------------------------------------------------------- // State machine @@ -94,6 +152,7 @@ enum State { pub struct Parser { state: State, zone: Zone, + sequence_start: Option<usize>, param_buf: [u8; PARAM_BUF_CAP], param_len: usize, } @@ -111,6 +170,7 @@ impl Parser { Self { state: State::Ground, zone: Zone::Unknown, + sequence_start: None, param_buf: [0u8; PARAM_BUF_CAP], param_len: 0, } @@ -123,18 +183,40 @@ impl Parser { self.zone } + /// Start offset of an incomplete OSC sequence in the most recent chunk. + #[inline] + pub(crate) fn incomplete_osc_sequence_start(&self) -> Option<usize> { + matches!(self.state, State::OscParam | State::OscEsc) + .then(|| self.sequence_start.unwrap_or(0)) + } + /// Process a chunk of bytes, calling `on_event` for every OSC 133 marker /// found. /// /// All bytes in `data` should still be forwarded to the terminal by the /// caller — this method only *observes* the stream. + #[cfg(test)] #[inline] pub fn push(&mut self, data: &[u8], mut on_event: impl FnMut(Event)) { - for &byte in data { + self.push_located(data, |located| on_event(located.event)); + } + + /// Process a chunk of bytes, calling `on_event` for every OSC 133 marker + /// found with its byte offset in this chunk. + /// + /// The offset points to the first byte after the marker terminator, making + /// it suitable for callers that need to split the original chunk at marker + /// boundaries. + #[inline] + pub fn push_located(&mut self, data: &[u8], mut on_event: impl FnMut(LocatedEvent)) { + self.sequence_start = (self.state != State::Ground).then_some(0); + + for (offset, &byte) in data.iter().enumerate() { match self.state { State::Ground => { if byte == ESC { self.state = State::Esc; + self.sequence_start = Some(offset); } } State::Esc => { @@ -143,12 +225,14 @@ impl Parser { self.param_len = 0; } else { self.state = State::Ground; + self.sequence_start = None; } } State::OscParam => { - if byte == BEL { - self.dispatch(&mut on_event); + if byte == BEL || byte == C1_ST { + self.dispatch(offset + 1, &mut on_event); self.state = State::Ground; + self.sequence_start = None; } else if byte == ESC { self.state = State::OscEsc; } else if self.param_len < PARAM_BUF_CAP { @@ -160,12 +244,13 @@ impl Parser { } State::OscEsc => { if byte == BACKSLASH { - self.dispatch(&mut on_event); + self.dispatch(offset + 1, &mut on_event); } // Whether we got a valid ST or not, return to ground. // (A new ESC ] would restart accumulation via the Ground // -> Esc -> OscParam path on the *next* byte.) self.state = State::Ground; + self.sequence_start = None; } } } @@ -174,46 +259,104 @@ impl Parser { /// Inspect the accumulated parameter buffer. If it holds an OSC 133 /// payload, emit the corresponding [`Event`] and update the zone. #[inline] - fn dispatch(&mut self, on_event: &mut impl FnMut(Event)) { - let params = &self.param_buf[..self.param_len]; + fn dispatch(&mut self, offset: usize, on_event: &mut impl FnMut(LocatedEvent)) { + let payload = &self.param_buf[..self.param_len]; + + if payload.len() < 5 || &payload[..4] != b"133;" { + return; + } - // Must start with "133;" - if params.len() < 5 || ¶ms[..4] != b"133;" { + if payload.len() > 5 && payload[5] != b';' { return; } - let cmd = params[4]; - let event = match cmd { + let metadata = payload.get(6..).unwrap_or_default(); + let cmd = payload[4]; + let (event, params) = match cmd { b'A' => { self.zone = Zone::Prompt; - Event::PromptStart + (Event::PromptStart, parse_params(metadata)) } b'B' => { self.zone = Zone::Input; - Event::CommandStart + (Event::CommandStart, parse_params(metadata)) } b'C' => { self.zone = Zone::Output; - Event::CommandExecuted + (Event::CommandExecuted, parse_params(metadata)) } b'D' => { - let exit_code = if params.len() > 6 && params[5] == b';' { - std::str::from_utf8(¶ms[6..]) - .ok() - .and_then(|s| s.parse::<i32>().ok()) - } else { - None - }; + let (exit_code, params) = parse_command_finished_params(metadata); self.zone = Zone::Unknown; - Event::CommandFinished { exit_code } + (Event::CommandFinished { exit_code }, params) } _ => return, }; - on_event(event); + on_event(LocatedEvent { + event, + start_offset: self.sequence_start.unwrap_or(0), + offset, + zone: self.zone, + params, + }); } } +fn parse_command_finished_params(metadata: &[u8]) -> (Option<i32>, Params) { + if metadata.is_empty() { + return (None, Params::default()); + } + + let Some(separator) = metadata.iter().position(|byte| *byte == b';') else { + return parse_exit_code(metadata).map_or_else( + || (None, parse_params(metadata)), + |exit_code| (Some(exit_code), Params::default()), + ); + }; + + let (first, rest) = metadata.split_at(separator); + let rest = &rest[1..]; + + parse_exit_code(first).map_or_else( + || (None, parse_params(metadata)), + |exit_code| (Some(exit_code), parse_params(rest)), + ) +} + +fn parse_exit_code(code: &[u8]) -> Option<i32> { + if code.is_empty() { + return None; + } + + std::str::from_utf8(code) + .ok() + .and_then(|code| code.parse::<i32>().ok()) +} + +fn parse_params(metadata: &[u8]) -> Params { + let items = metadata + .split(|byte| *byte == b';') + .filter(|part| !part.is_empty()) + .map(parse_param) + .collect(); + + Params { items } +} + +fn parse_param(param: &[u8]) -> Param { + let param = String::from_utf8_lossy(param); + + if let Some((key, value)) = param.split_once('=') { + return Param::KeyValue { + key: key.to_string(), + value: value.to_string(), + }; + } + + Param::Value(param.into_owned()) +} + // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- @@ -468,6 +611,12 @@ mod tests { assert!(parse_events(data).is_empty()); } + #[test] + fn marker_with_unexpected_trailing_bytes_ignored() { + let data = b"\x1b]133;ABC\x07"; + assert!(parse_events(data).is_empty()); + } + // -- Malformed sequences -------------------------------------------------- #[test] @@ -509,7 +658,7 @@ mod tests { fn very_long_osc_does_not_panic() { let mut data = Vec::new(); data.extend_from_slice(b"\x1b]"); - data.extend(std::iter::repeat(b'x').take(1000)); + data.extend(std::iter::repeat_n(b'x', 1000)); data.push(BEL); // Should not panic and should produce no event. assert!(parse_events(&data).is_empty()); @@ -589,6 +738,100 @@ mod tests { ); } + #[test] + fn detects_c1_st_terminator() { + let data = b"\x1b]133;A\x9c"; + assert_eq!(parse_events(data), vec![Event::PromptStart]); + } + + // -- Located event offsets ------------------------------------------------ + + #[test] + fn located_event_reports_offset_after_marker() { + let data = b"before\x1b]133;A\x07prompt"; + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push_located(data, |e| events.push(e)); + + assert_eq!( + events, + vec![LocatedEvent { + event: Event::PromptStart, + start_offset: b"before".len(), + offset: b"before\x1b]133;A\x07".len(), + zone: Zone::Prompt, + params: Params::default(), + }] + ); + } + + #[test] + fn located_event_offset_is_relative_to_completing_chunk() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push_located(b"\x1b]133;", |e| events.push(e)); + parser.push_located(b"D;42\x07after", |e| events.push(e)); + + assert_eq!( + events, + vec![LocatedEvent { + event: Event::CommandFinished { + exit_code: Some(42) + }, + start_offset: 0, + offset: b"D;42\x07".len(), + zone: Zone::Unknown, + params: Params::default(), + }] + ); + } + + #[test] + fn located_event_preserves_metadata_params() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push_located( + b"\x1b]133;D;127;history_id=018f;session_id=abcd;flag\x07", + |event| events.push(event), + ); + + assert_eq!(events.len(), 1); + let event = &events[0]; + assert_eq!( + event.event, + Event::CommandFinished { + exit_code: Some(127) + } + ); + assert_eq!(event.params.get("history_id"), Some("018f")); + assert_eq!(event.params.get("session_id"), Some("abcd")); + assert!( + event + .params + .iter() + .any(|param| param == &Param::Value("flag".to_string())) + ); + } + + #[test] + fn command_finished_metadata_without_exit_code_is_preserved() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push_located(b"\x1b]133;D;history_id=018f;session_id=abcd\x07", |event| { + events.push(event); + }); + + assert_eq!(events.len(), 1); + let event = &events[0]; + assert_eq!(event.event, Event::CommandFinished { exit_code: None }); + assert_eq!(event.params.get("history_id"), Some("018f")); + assert_eq!(event.params.get("session_id"), Some("abcd")); + } + // -- Default trait -------------------------------------------------------- #[test] diff --git a/crates/atuin-pty-proxy/src/pty_proxy.rs b/crates/atuin-pty-proxy/src/pty_proxy.rs new file mode 100644 index 00000000..030ef9b5 --- /dev/null +++ b/crates/atuin-pty-proxy/src/pty_proxy.rs @@ -0,0 +1,231 @@ +use clap::{Args, Subcommand, ValueEnum}; + +use crate::{CommandCaptureSink, runtime}; + +#[derive(Args, Debug)] +pub struct PtyProxy { + /// Highlight OSC 133 prompt, input, output, and exit-code regions + #[arg(long)] + debug_osc133: bool, + + #[command(subcommand)] + cmd: Option<Cmd>, +} + +#[derive(Subcommand, Debug)] +pub enum Cmd { + /// Print shell code to initialize atuin pty-proxy on shell startup + Init(Init), +} + +#[derive(Args, Debug)] +pub struct Init { + /// Shell to generate init for. If omitted, attempt auto-detection + #[arg(value_enum)] + shell: Option<Shell>, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq, ValueEnum)] +#[value(rename_all = "lower")] +#[allow(clippy::enum_variant_names, clippy::doc_markdown)] +enum Shell { + /// Zsh setup + Zsh, + /// Bash setup + Bash, + /// Fish setup + Fish, + /// Nu setup + Nu, +} + +pub(crate) struct RuntimeOptions { + pub(crate) debug_osc133: bool, + pub(crate) command_capture_sink: Option<CommandCaptureSink>, +} + +impl RuntimeOptions { + fn new(debug_osc133: bool, command_capture_sink: Option<CommandCaptureSink>) -> Self { + Self { + debug_osc133: debug_osc133 || env_flag("ATUIN_PTY_PROXY_DEBUG"), + command_capture_sink, + } + } +} + +impl PtyProxy { + pub fn run(self, command_capture_sink: Option<CommandCaptureSink>) { + match self.cmd { + Some(Cmd::Init(init)) => { + if let Err(err) = init.run() { + eprintln!("atuin pty-proxy: {err}"); + std::process::exit(1); + } + } + None => runtime::main(RuntimeOptions::new(self.debug_osc133, command_capture_sink)), + } + } +} + +impl Init { + fn run(self) -> Result<(), String> { + let shell = detect_shell(self.shell)?; + let script = render_init(shell); + print!("{script}"); + Ok(()) + } +} + +fn detect_shell(cli_shell: Option<Shell>) -> Result<Shell, String> { + if let Some(shell) = cli_shell { + return Ok(shell); + } + + if let Ok(shell) = std::env::var("ATUIN_SHELL") + && let Some(shell) = shell_from_name(&shell) + { + return Ok(shell); + } + + if let Ok(shell) = std::env::var("SHELL") + && let Some(shell) = shell_from_name(&shell) + { + return Ok(shell); + } + + Err( + "could not detect a supported shell. Please specify one explicitly: bash, zsh, fish, or nu" + .to_string(), + ) +} + +fn shell_from_name(name: &str) -> Option<Shell> { + let shell = name + .trim() + .rsplit('/') + .next() + .unwrap_or(name) + .trim_start_matches('-') + .to_ascii_lowercase(); + + match shell.as_str() { + "bash" => Some(Shell::Bash), + "zsh" => Some(Shell::Zsh), + "fish" => Some(Shell::Fish), + "nu" => Some(Shell::Nu), + _ => None, + } +} + +fn env_flag(name: &str) -> bool { + std::env::var(name).is_ok_and(|value| { + matches!( + value.trim().to_ascii_lowercase().as_str(), + "1" | "true" | "yes" | "on" + ) + }) +} + +fn render_init(shell: Shell) -> &'static str { + match shell { + Shell::Bash | Shell::Zsh => { + r#"if [[ "$-" == *i* ]] && [[ -t 0 ]] && [[ -t 1 ]]; then + _atuin_pty_proxy_tmux_current="${TMUX:-}" + _atuin_pty_proxy_tmux_previous="${ATUIN_PTY_PROXY_TMUX:-}" + + if [[ -z "${ATUIN_PTY_PROXY_ACTIVE:-}" ]] || [[ "$_atuin_pty_proxy_tmux_current" != "$_atuin_pty_proxy_tmux_previous" ]]; then + export ATUIN_PTY_PROXY_ACTIVE=1 + export ATUIN_PTY_PROXY_TMUX="$_atuin_pty_proxy_tmux_current" + exec atuin pty-proxy + fi + + unset _atuin_pty_proxy_tmux_current _atuin_pty_proxy_tmux_previous +fi +"# + } + Shell::Fish => { + r#"if status is-interactive; and test -t 0; and test -t 1 + set -l _atuin_pty_proxy_tmux_current "" + if set -q TMUX + set _atuin_pty_proxy_tmux_current "$TMUX" + end + + set -l _atuin_pty_proxy_tmux_previous "" + if set -q ATUIN_PTY_PROXY_TMUX + set _atuin_pty_proxy_tmux_previous "$ATUIN_PTY_PROXY_TMUX" + end + + if not set -q ATUIN_PTY_PROXY_ACTIVE + set -gx ATUIN_PTY_PROXY_ACTIVE 1 + set -gx ATUIN_PTY_PROXY_TMUX "$_atuin_pty_proxy_tmux_current" + exec atuin pty-proxy + else if test "$_atuin_pty_proxy_tmux_current" != "$_atuin_pty_proxy_tmux_previous" + set -gx ATUIN_PTY_PROXY_ACTIVE 1 + set -gx ATUIN_PTY_PROXY_TMUX "$_atuin_pty_proxy_tmux_current" + exec atuin pty-proxy + end +end +"# + } + // Nushell cannot dynamically source the output of `atuin init nu`, + // so we only output the pty-proxy preamble here. Users must also set up + // `atuin init nu` separately. + Shell::Nu => { + r#"if (is-terminal --stdin) and (is-terminal --stdout) { + let tmux_current = ($env.TMUX? | default "") + let tmux_previous = ($env.ATUIN_PTY_PROXY_TMUX? | default "") + + if (($env.ATUIN_PTY_PROXY_ACTIVE? | default "") | is-empty) or ($tmux_current != $tmux_previous) { + $env.ATUIN_PTY_PROXY_ACTIVE = "1" + $env.ATUIN_PTY_PROXY_TMUX = $tmux_current + exec atuin pty-proxy + } +} +"# + } + } +} + +#[cfg(test)] +mod tests { + use super::{Shell, render_init, shell_from_name}; + + #[test] + fn shell_from_name_handles_paths() { + assert_eq!(shell_from_name("/bin/zsh"), Some(Shell::Zsh)); + assert_eq!(shell_from_name("/usr/local/bin/bash"), Some(Shell::Bash)); + assert_eq!(shell_from_name("fish"), Some(Shell::Fish)); + assert_eq!(shell_from_name("nu"), Some(Shell::Nu)); + } + + #[test] + fn posix_init_uses_exec_and_tmux_guard() { + let script = render_init(Shell::Bash); + assert!(script.contains("exec atuin pty-proxy")); + assert!(script.contains("ATUIN_PTY_PROXY_TMUX")); + assert!(!script.contains("eval \"$(atuin init bash)\"")); + } + + #[test] + fn posix_init_has_no_double_braces() { + let script = render_init(Shell::Bash); + assert!(!script.contains("${{"), "double braces in bash init script"); + } + + #[test] + fn fish_init_uses_source() { + let script = render_init(Shell::Fish); + assert!(script.contains("exec atuin pty-proxy")); + assert!(!script.contains("atuin init fish | source")); + } + + #[test] + fn nu_init_uses_exec_and_tty_guard() { + let script = render_init(Shell::Nu); + assert!(script.contains("exec atuin pty-proxy")); + assert!(script.contains("ATUIN_PTY_PROXY_TMUX")); + assert!(script.contains("is-terminal --stdin")); + assert!(script.contains("is-terminal --stdout")); + assert!(script.contains("ATUIN_PTY_PROXY_ACTIVE")); + } +} diff --git a/crates/atuin-pty-proxy/src/runtime.rs b/crates/atuin-pty-proxy/src/runtime.rs new file mode 100644 index 00000000..2b34fbb7 --- /dev/null +++ b/crates/atuin-pty-proxy/src/runtime.rs @@ -0,0 +1,184 @@ +use std::io::{Read, Write}; +use std::sync::Arc; +use std::sync::atomic::{AtomicU16, Ordering}; +use std::sync::mpsc; + +use crossterm::terminal; +use portable_pty::{CommandBuilder, PtySize, native_pty_system}; + +use crate::capture::CommandCaptureTracker; +use crate::debug::{Osc133DebugHighlighter, RESET}; +use crate::pty_proxy::RuntimeOptions; +use crate::screen::{self, Msg}; + +pub(crate) fn main(options: RuntimeOptions) { + if let Err(e) = run(options) { + let _ = terminal::disable_raw_mode(); + eprintln!("atuin pty-proxy: {e:#}"); + std::process::exit(1); + } +} + +fn run(options: RuntimeOptions) -> eyre::Result<()> { + let (cols, rows) = terminal::size()?; + + let pty_system = native_pty_system(); + let pair = pty_system + .openpty(PtySize { + rows, + cols, + pixel_width: 0, + pixel_height: 0, + }) + .map_err(|e| eyre::eyre!("{e:#}"))?; + + let sock_path = screen::socket_path(); + let _ = std::fs::remove_file(&sock_path); + + let mut cmd = CommandBuilder::new_default_prog(); + cmd.cwd(std::env::current_dir()?); + cmd.env("ATUIN_PTY_PROXY_SOCKET", sock_path.as_os_str()); + cmd.env("ATUIN_PTY_PROXY_ACTIVE", "1"); + + let mut child = pair + .slave + .spawn_command(cmd) + .map_err(|e| eyre::eyre!("{e:#}"))?; + + drop(pair.slave); + + let mut pty_reader = pair + .master + .try_clone_reader() + .map_err(|e| eyre::eyre!("{e:#}"))?; + let mut pty_writer = pair + .master + .take_writer() + .map_err(|e| eyre::eyre!("{e:#}"))?; + + let (msg_tx, msg_rx) = mpsc::sync_channel::<Msg>(64); + let current_cols = Arc::new(AtomicU16::new(cols.max(1))); + + screen::spawn_parser_thread(rows, cols, msg_rx); + screen::spawn_socket_server(sock_path.clone(), msg_tx.clone()); + spawn_resize_handler(pair.master, msg_tx.clone(), current_cols.clone())?; + + terminal::enable_raw_mode()?; + + let stdout_thread = std::thread::spawn(move || { + let mut stdout = std::io::stdout(); + let mut highlighter = options.debug_osc133.then(Osc133DebugHighlighter::new); + let mut capture_tracker = options + .command_capture_sink + .as_ref() + .map(|_| CommandCaptureTracker::new(current_cols)); + let mut buf = [0u8; 8192]; + + loop { + match pty_reader.read(&mut buf) { + Ok(0) | Err(_) => break, + Ok(n) => { + if let (Some(tracker), Some(sink)) = ( + capture_tracker.as_mut(), + options.command_capture_sink.as_ref(), + ) { + tracker.push(&buf[..n], sink); + } + + if let Some(highlighter) = highlighter.as_mut() { + let rendered = highlighter.render(&buf[..n]); + let _ = msg_tx.try_send(Msg::Data(rendered.clone())); + + if stdout.write_all(&rendered).is_err() { + break; + } + } else { + let _ = msg_tx.try_send(Msg::Data(buf[..n].to_vec())); + + if stdout.write_all(&buf[..n]).is_err() { + break; + } + } + let _ = stdout.flush(); + } + } + } + + if highlighter.is_some() { + let _ = stdout.write_all(RESET); + let _ = stdout.flush(); + } + }); + + std::thread::spawn(move || { + let mut stdin = std::io::stdin(); + let mut buf = [0u8; 8192]; + loop { + match stdin.read(&mut buf) { + Ok(0) | Err(_) => break, + Ok(n) => { + if pty_writer.write_all(&buf[..n]).is_err() { + break; + } + } + } + } + }); + + let status = child.wait()?; + let _ = stdout_thread.join(); + + let _ = terminal::disable_raw_mode(); + let _ = std::fs::remove_file(&sock_path); + + std::process::exit(process_exit_code(status.exit_code())); +} + +fn spawn_resize_handler( + master: Box<dyn portable_pty::MasterPty + Send>, + resize_tx: mpsc::SyncSender<Msg>, + current_cols: Arc<AtomicU16>, +) -> eyre::Result<()> { + use signal_hook::consts::SIGWINCH; + use signal_hook::iterator::Signals; + + let mut signals = Signals::new([SIGWINCH])?; + + std::thread::spawn(move || { + for _ in signals.forever() { + if let Ok((cols, rows)) = terminal::size() { + current_cols.store(cols.max(1), Ordering::Relaxed); + let _ = master.resize(PtySize { + rows, + cols, + pixel_width: 0, + pixel_height: 0, + }); + let _ = resize_tx.try_send(Msg::Resize { rows, cols }); + } + } + }); + + Ok(()) +} + +fn process_exit_code(code: u32) -> i32 { + i32::try_from(code).unwrap_or(1) +} + +#[cfg(test)] +mod tests { + use super::process_exit_code; + + #[test] + fn process_exit_code_preserves_valid_values() { + assert_eq!(process_exit_code(0), 0); + assert_eq!(process_exit_code(127), 127); + assert_eq!(process_exit_code(i32::MAX as u32), i32::MAX); + } + + #[test] + fn process_exit_code_defaults_when_out_of_range() { + assert_eq!(process_exit_code(i32::MAX as u32 + 1), 1); + } +} diff --git a/crates/atuin-pty-proxy/src/screen.rs b/crates/atuin-pty-proxy/src/screen.rs new file mode 100644 index 00000000..5b892e21 --- /dev/null +++ b/crates/atuin-pty-proxy/src/screen.rs @@ -0,0 +1,104 @@ +use std::io::Write; +use std::os::unix::net::UnixListener; +use std::path::PathBuf; +use std::sync::mpsc::{self, Receiver, SyncSender}; + +pub(crate) enum Msg { + Data(Vec<u8>), + Resize { rows: u16, cols: u16 }, + ScreenRequest(mpsc::Sender<Vec<u8>>), +} + +pub(crate) fn socket_path() -> PathBuf { + let dir = std::env::temp_dir(); + dir.join(format!("atuin-pty-proxy-{}.sock", std::process::id())) +} + +pub(crate) fn spawn_parser_thread(rows: u16, cols: u16, msg_rx: Receiver<Msg>) { + std::thread::spawn(move || { + let mut parser = vt100::Parser::new(rows, cols, 0); + + loop { + let first = match msg_rx.recv() { + Ok(msg) => msg, + Err(_) => break, + }; + + handle_parser_msg(&mut parser, first); + + while let Ok(msg) = msg_rx.try_recv() { + handle_parser_msg(&mut parser, msg); + } + } + }); +} + +pub(crate) fn spawn_socket_server(sock_path: PathBuf, screen_tx: SyncSender<Msg>) { + std::thread::spawn(move || { + let listener = match UnixListener::bind(&sock_path) { + Ok(l) => l, + Err(e) => { + eprintln!("atuin pty-proxy: failed to bind socket: {e}"); + return; + } + }; + + for stream in listener.incoming() { + let mut stream = match stream { + Ok(s) => s, + Err(_) => break, + }; + + let (reply_tx, reply_rx) = mpsc::channel(); + if screen_tx.send(Msg::ScreenRequest(reply_tx)).is_err() { + break; + } + if let Ok(data) = reply_rx.recv() { + let _ = stream.write_all(&data); + let _ = stream.flush(); + } + } + }); +} + +/// Wire format written to the Unix socket: +/// +/// ```text +/// [rows: u16 BE][cols: u16 BE][cursor_row: u16 BE][cursor_col: u16 BE] +/// [row_0_len: u32 BE][row_0_bytes...] +/// [row_1_len: u32 BE][row_1_bytes...] +/// ... +/// ``` +/// +/// Each row's bytes come from `screen.rows_formatted(0, cols)` and contain +/// pre-built ANSI escape sequences. The client can write them directly to +/// stdout without needing its own vt100 parser. +fn encode_screen(parser: &vt100::Parser) -> Vec<u8> { + let screen = parser.screen(); + let (rows, cols) = screen.size(); + let (cursor_row, cursor_col) = screen.cursor_position(); + + let mut buf: Vec<u8> = Vec::with_capacity(256 + (rows as usize * cols as usize)); + buf.extend_from_slice(&rows.to_be_bytes()); + buf.extend_from_slice(&cols.to_be_bytes()); + buf.extend_from_slice(&cursor_row.to_be_bytes()); + buf.extend_from_slice(&cursor_col.to_be_bytes()); + + for row_bytes in screen.rows_formatted(0, cols) { + let len = row_bytes.len() as u32; + buf.extend_from_slice(&len.to_be_bytes()); + buf.extend_from_slice(&row_bytes); + } + + buf +} + +fn handle_parser_msg(parser: &mut vt100::Parser, msg: Msg) { + match msg { + Msg::Data(data) => parser.process(&data), + Msg::Resize { rows, cols } => parser.screen_mut().set_size(rows, cols), + Msg::ScreenRequest(reply_tx) => { + let _ = reply_tx.send(encode_screen(parser)); + } + } +} |
