From 6ea760bb6b36da241961e8ecd60cb2c5e15c0a78 Mon Sep 17 00:00:00 2001 From: Michelle Tilley Date: Tue, 24 Feb 2026 11:48:20 -0800 Subject: feat: Generate commands or ask questions with `atuin ai` (#3199) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR refines the system created in #3178 to be suitable for a v1 release. --- ## Overview `atuin-ai` is a separate binary that allows for generating commands and asking questions from the command line. It is fully opt-in. ## Usage `atuin ai init` will output bindings for your shell. Currently, bash, zsh, and fish are supported. ```bash eval "$(atuin ai init)" ``` Once the hooks are installed, just press `?` on an empty prompt line to call up the TUI. `atuin ai` requires an account on [Atuin Hub](https://hub.atuin.sh/); you will be prompted to log in on first use. ## Features ### Command generation Prompt the LLM to create a command, and get one back, no fuss. Press `enter` to run, or `tab` to insert. ``` ┌Ask questions or generate a command:──────────────────────────┐ │ │ │ > Get a list of running docker containers │ │ │ ├──────────────────────────────────────────────────────────────┤ │ │ │ $ docker ps │ │ │ └────[Enter]: Run [Tab]: Insert [f]: Follow-up [Esc]: Cancel┘ ``` ### Follow-up You can follow-up with `f` to specify a refinement prompt to update the command that will be inserted. ``` ┌Ask questions or generate a command:──────────────────────────┐ │ │ │ > Get a list of running docker containers │ │ │ ├──────────────────────────────────────────────────────────────┤ │ │ │ $ docker ps │ │ │ ├──────────────────────────────────────────────────────────────┤ │ │ │ > Actually I want to get all docker containers │ │ │ ├──────────────────────────────────────────────────────────────┤ │ │ │ $ docker ps -a │ │ │ └────[Enter]: Run [Tab]: Insert [f]: Follow-up [Esc]: Cancel┘ ``` You can also follow-up with questions to get responses in natural language. ``` ┌Ask questions or generate a command:──────────────────────────┐ │ │ │ > Get a list of running docker containers │ │ │ ├──────────────────────────────────────────────────────────────┤ │ │ │ $ docker ps │ │ │ ├──────────────────────────────────────────────────────────────┤ │ │ │ > Actually I want to get all docker containers │ │ │ ├──────────────────────────────────────────────────────────────┤ │ │ │ $ docker ps -a │ │ │ ├──────────────────────────────────────────────────────────────┤ │ │ │ > What other useful flags to `docker ps` should I know? │ │ │ ├──────────────────────────────────────────────────────────────┤ │ │ │ Here are some handy `docker ps` flags: │ │ │ │ - `-q` — Only show container IDs (great for piping to │ │ other commands) │ │ - `-s` — Show container sizes │ │ - `-n 5` — Show the last 5 created containers │ │ - `-l` — Show only the latest created container │ │ - `--no-trunc` — Don't truncate output (shows full IDs and │ │ commands) │ │ - `-f` or `--filter` — Filter by condition, e.g.: │ │ - `-f status=exited` — only exited containers │ │ - `-f name=myapp` — filter by name │ │ - `-f ancestor=nginx` — filter by image │ │ - `--format` — Custom output using Go templates, e.g.: │ │ `--format "table {{.Names}}\t{{.Status}}\t{{.Ports}}"` │ │ │ │ A common combo is `docker ps -aq` to get all container │ │ IDs, useful for bulk operations like `docker rm $(docker │ │ ps -aq)`. │ │ │ └────[Enter]: Run [Tab]: Insert [f]: Follow-up [Esc]: Cancel┘ ``` You can use `enter` or `tab` at any time to run or insert the last suggested command, even if it was suggested in a previous turn. ### Conversational and search usage If you prompt the LLM with a question that doesn't imply you want to generate a command, it can respond in natural language, and use web search if necessary to fetch the data it needs. ``` ┌Ask questions or generate a command:──────────────────────────┐ │ │ │ > What is the latest version of atuin? │ │ │ ├──────────────────────────────────────────────────────────────┤ │ │ │ ✓ Used 2 tools │ │ │ │ The latest version of Atuin is **v18.12.0**, available on │ │ the [GitHub releases │ │ page](https://github.com/atuinsh/atuin/releases). │ │ │ └─────────────────────────────────[f]: Follow-up [Esc]: Cancel┘ ``` ### Dangerous or low-confidence command detection The LLM scores its confidence in the command, as well as how dangerous the command is. This information is shown if a threshold is exceeded, and requires an extra confirmation step before running automatically with `enter`. The Atuin Hub server also monitors suggested commands for dangerous patterns the LLM didn't catch, and appends its own assessment at the end of the LLM's own assessment. ``` ┌Ask questions or generate a command:──────────────────────────┐ │ │ │ > Delete all files from $HOME │ │ │ ├──────────────────────────────────────────────────────────────┤ │ │ │ $ rm -rf $HOME/* │ │ │ │ ! ⚠️ This will PERMANENTLY delete ALL files and directories │ │ in your home directory, including documents, downloads, │ │ configurations, SSH keys, and everything else. This is │ │ irreversible and will likely break your system. Also note │ │ this won't delete hidden (dot) files — if you want those │ │ too, that's even more destructive.; [Server] Recursive │ │ delete of critical directory │ │ │ └────[Enter]: Run [Tab]: Insert [f]: Follow-up [Esc]: Cancel┘ ``` --------- Co-authored-by: Claude Opus 4.5 --- crates/atuin-ai/src/commands/inline.rs | 932 +++++++++++++++++---------------- 1 file changed, 476 insertions(+), 456 deletions(-) (limited to 'crates/atuin-ai/src/commands/inline.rs') diff --git a/crates/atuin-ai/src/commands/inline.rs b/crates/atuin-ai/src/commands/inline.rs index cfa27db4..3f9278a2 100644 --- a/crates/atuin-ai/src/commands/inline.rs +++ b/crates/atuin-ai/src/commands/inline.rs @@ -1,52 +1,52 @@ +use crate::commands::detect_shell; +use crate::tui::render::render; +use crate::tui::{ + App, AppEvent, AppMode, ConversationEvent, EventLoop, ExitAction, RenderContext, TerminalGuard, + calculate_needed_height, install_panic_hook, +}; +use atuin_client::theme::ThemeManager; use atuin_common::tls::ensure_crypto_provider; use crossterm::{ - cursor, event::{self, Event, KeyCode}, terminal::{disable_raw_mode, enable_raw_mode}, }; +use eventsource_stream::Eventsource; use eyre::{Context as _, Result, bail}; -use ratatui::{ - Frame, Terminal, TerminalOptions, Viewport, - backend::CrosstermBackend, - layout::{Alignment, Rect}, - text::Line, - widgets::{Block, Borders, Paragraph, Wrap}, -}; +use futures::StreamExt; use reqwest::Url; -use serde::{Deserialize, Serialize}; -use std::time::Duration; - -#[derive(Debug, Serialize)] -struct GenerateRequest { - query: String, - description: String, - context: GenerateContext, -} - -#[derive(Debug, Serialize)] -struct GenerateContext { - os: String, - shell: String, - pwd: Option, -} - -#[derive(Debug, Deserialize)] -struct GenerateResponse { - command: String, - #[serde(default)] - explanation: Option, -} +use std::io::Write; pub async fn run( initial_command: Option, natural_language: bool, api_endpoint: Option, + api_token: Option, + keep_output: bool, + debug_state_file: Option, ) -> Result<()> { + // Install panic hook once at entry point to ensure terminal restoration + install_panic_hook(); + + // Token and endpoint priority: + // 1. Command line arguments/environment variables + // 2. Settings file + // 3. Default let settings = atuin_client::settings::Settings::new()?; - let endpoint = api_endpoint - .as_deref() - .unwrap_or(settings.hub_address.as_str()); - let token = ensure_hub_session(&settings, endpoint).await?; + let endpoint = api_endpoint.as_deref().unwrap_or( + settings + .ai + .ai_endpoint + .as_deref() + .unwrap_or("https://hub.atuin.sh"), + ); + let api_token = api_token.as_deref().or(settings.ai.ai_api_token.as_deref()); + + let token = if let Some(token) = &api_token { + token.to_string() + } else { + ensure_hub_session(&settings, endpoint).await? + }; + let action = run_inline_tui( endpoint.to_string(), token, @@ -55,6 +55,8 @@ pub async fn run( } else { initial_command }, + keep_output, + debug_state_file, ) .await?; emit_shell_result(action.0, &action.1); @@ -95,55 +97,172 @@ async fn ensure_hub_session( Ok(token) } -async fn generate_command( - hub_address: &str, - token: &str, - description: &str, -) -> Result { - ensure_crypto_provider(); - let endpoint = hub_url(hub_address, "/api/cli/generate")?; - let request = GenerateRequest { - query: description.to_string(), - description: description.to_string(), - context: GenerateContext { - os: detect_os(), - shell: detect_shell(), - pwd: std::env::current_dir() - .ok() - .map(|path| path.to_string_lossy().into_owned()), - }, - }; +/// SSE event received from chat endpoint +#[derive(Debug, Clone)] +enum ChatStreamEvent { + /// Text chunk to display + TextChunk(String), + /// Tool call event (need to echo back, may contain suggest_command) + ToolCall { + id: String, + name: String, + input: serde_json::Value, + }, + /// Tool result from server-side execution + ToolResult { + tool_use_id: String, + content: String, + is_error: bool, + }, + /// Status update from server + Status(String), + /// Stream complete + Done { session_id: String }, + /// Error from server + Error(String), +} + +fn create_chat_stream( + hub_address: String, + token: String, + session_id: Option, + messages: Vec, + settings: &atuin_client::settings::Settings, +) -> std::pin::Pin> + Send>> { + let send_cwd = settings.ai.send_cwd; + + Box::pin(async_stream::stream! { + ensure_crypto_provider(); + let endpoint = match hub_url(&hub_address, "/api/cli/chat") { + Ok(url) => url, + Err(e) => { + yield Err(e); + return; + } + }; + + // Build request body + let mut request_body = serde_json::json!({ + "messages": messages, + "context": { + "os": detect_os(), + "shell": detect_shell(), + "pwd": if send_cwd { std::env::current_dir() + .ok() + .map(|path| path.to_string_lossy().into_owned()) } else { None }, + } + }); + + // Include session_id only if present (not on first request) + if let Some(ref sid) = session_id { + request_body["session_id"] = serde_json::json!(sid); + } - let client = reqwest::Client::new(); - let response = client - .post(endpoint) - .bearer_auth(token) - .json(&request) - .send() - .await - .context("failed to call Atuin Hub generate endpoint")?; - - if response.status().is_success() { - let generated = response - .json::() + + let client = reqwest::Client::new(); + let response = match client + .post(endpoint.clone()) + .header("Accept", "text/event-stream") + .bearer_auth(&token) + .json(&request_body) + .send() .await - .context("failed to decode generate response")?; + { + Ok(resp) => resp, + Err(e) => { + yield Err(eyre::eyre!("Failed to send SSE request: {}", e)); + return; + } + }; - if generated.command.trim().is_empty() { - bail!("Hub returned an empty command. Please try again with a more specific request."); + let status = response.status(); + if status == reqwest::StatusCode::UNAUTHORIZED { + // Clear saved session on auth error + let _ = atuin_client::hub::delete_session().await; + yield Err(eyre::eyre!("Hub session expired. Re-run to authenticate again.")); + return; + } + if !status.is_success() { + let body = response.text().await.unwrap_or_default(); + yield Err(eyre::eyre!("SSE request failed ({}): {}", status, body)); + return; } - return Ok(generated); - } + let byte_stream = response.bytes_stream(); + let mut stream = byte_stream.eventsource(); - if response.status() == reqwest::StatusCode::UNAUTHORIZED { - atuin_client::hub::delete_session().await?; - bail!("Hub session expired. Re-run to authenticate again."); - } + while let Some(event) = stream.next().await { + match event { + Ok(sse_event) => { + let event_type = sse_event.event.as_str(); + let data = sse_event.data.clone(); + + tracing::debug!(event_type = %event_type, data = %data, "SSE event received"); - let status = response.status(); - let body = response.text().await.unwrap_or_default(); - bail!("Hub request failed ({status}): {body}"); + match event_type { + "text" => { + if let Ok(json) = serde_json::from_str::(&data) + && let Some(content) = json.get("content").and_then(|v| v.as_str()) + { + yield Ok(ChatStreamEvent::TextChunk(content.to_string())); + } + } + "tool_call" => { + if let Ok(json) = serde_json::from_str::(&data) { + let id = json.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let name = json.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let input = json.get("input").cloned().unwrap_or(serde_json::json!({})); + yield Ok(ChatStreamEvent::ToolCall { id, name, input }); + } + } + "tool_result" => { + if let Ok(json) = serde_json::from_str::(&data) { + let tool_use_id = json.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let content = json.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let is_error = json.get("is_error").and_then(|v| v.as_bool()).unwrap_or(false); + yield Ok(ChatStreamEvent::ToolResult { tool_use_id, content, is_error }); + } + } + "status" => { + if let Ok(json) = serde_json::from_str::(&data) + && let Some(state) = json.get("state").and_then(|v| v.as_str()) + { + yield Ok(ChatStreamEvent::Status(state.to_string())); + } + } + "done" => { + if let Ok(json) = serde_json::from_str::(&data) { + let session_id = json.get("session_id") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + yield Ok(ChatStreamEvent::Done { session_id }); + } else { + yield Ok(ChatStreamEvent::Done { session_id: String::new() }); + } + break; + } + "error" => { + if let Ok(json) = serde_json::from_str::(&data) { + let message = json.get("message").and_then(|v| v.as_str()).unwrap_or("Unknown error").to_string(); + yield Ok(ChatStreamEvent::Error(message)); + } else { + yield Ok(ChatStreamEvent::Error(data)); + } + break; + } + _ => { + // Unknown event type, ignore + } + } + } + Err(e) => { + yield Err(eyre::eyre!("SSE error: {}", e)); + break; + } + } + } + }) } fn hub_url(base: &str, path: &str) -> Result { @@ -162,35 +281,11 @@ fn detect_os() -> String { match std::env::consts::OS { "macos" => "macos".to_string(), "linux" => "linux".to_string(), + "windows" => "windows".to_string(), _ => "linux".to_string(), } } -fn detect_shell() -> String { - if let Ok(shell) = std::env::var("ATUIN_SHELL") - && !shell.trim().is_empty() - { - return shell; - } - - let shell = std::env::var("SHELL") - .ok() - .and_then(|value| { - std::path::Path::new(&value) - .file_name() - .map(std::ffi::OsStr::to_string_lossy) - .map(std::borrow::Cow::into_owned) - }) - .filter(|value| !value.trim().is_empty()); - - match shell.as_deref() { - Some("zsh") => "zsh".to_string(), - Some("fish") => "fish".to_string(), - Some("bash") => "bash".to_string(), - _ => "bash".to_string(), - } -} - #[derive(Clone, Copy)] enum Action { Execute, @@ -198,105 +293,306 @@ enum Action { Cancel, } +/// Serialize AppState to JSON for debug logging +fn state_to_json(state: &crate::tui::AppState) -> serde_json::Value { + let events: Vec = state.events.iter().map(|e| e.to_json()).collect(); + + let mode = match state.mode { + AppMode::Input => "Input", + AppMode::Generating => "Generating", + AppMode::Streaming => "Streaming", + AppMode::Review => "Review", + AppMode::Error => "Error", + }; + + // Get input and cursor from textarea + let input = state.input(); + let cursor = state.textarea.cursor(); + + let mut json = serde_json::json!({ + "events": events, + "mode": mode, + "input": input, + "cursor_row": cursor.0, + "cursor_col": cursor.1, + "spinner_frame": state.spinner_frame, + "confirmation_pending": state.confirmation_pending, + }); + + // Add streaming fields if in streaming mode + if !state.streaming_text.is_empty() { + json["streaming_text"] = serde_json::json!(state.streaming_text); + } + if let Some(ref status) = state.streaming_status { + json["streaming_status"] = serde_json::json!(status.display_text()); + } + if let Some(ref err) = state.error { + json["error"] = serde_json::json!(err); + } + + json +} + +/// Debug logger that writes state changes to a file +struct DebugStateLogger { + file: std::fs::File, + entry_count: usize, + width: u16, +} + +impl DebugStateLogger { + fn new(path: &str) -> Result { + let file = std::fs::File::create(path) + .with_context(|| format!("Failed to create debug state file: {}", path))?; + // Get terminal width, default to 80 + let (width, _) = crossterm::terminal::size().unwrap_or((80, 24)); + Ok(Self { + file, + entry_count: 0, + width, + }) + } + + fn log(&mut self, label: &str, state: &crate::tui::AppState) { + use crate::tui::calculate_needed_height; + + self.entry_count += 1; + let timestamp_ms = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_millis()) + .unwrap_or(0); + + // Calculate the actual content height needed for this state + let content_height = calculate_needed_height(state); + + let mut state_json = state_to_json(state); + // Add dimensions for accurate replay + state_json["width"] = serde_json::json!(self.width); + state_json["height"] = serde_json::json!(content_height); + + let entry = serde_json::json!({ + "entry": self.entry_count, + "label": label, + "timestamp_ms": timestamp_ms, + "state": state_json, + }); + + // Write as JSONL (one JSON object per line) + if let Err(e) = writeln!(self.file, "{}", entry) { + tracing::warn!("Failed to write debug state: {}", e); + } + let _ = self.file.flush(); + } +} + async fn run_inline_tui( endpoint: String, token: String, initial_prompt: Option, + keep_output: bool, + debug_state_file: Option, ) -> Result<(Action, String)> { - let mut ui = InlineUi::new()?; - let mut prompt = initial_prompt.unwrap_or_default(); - let mut spinner_idx = 0usize; + // Initialize terminal guard and app state + let mut guard = TerminalGuard::new(keep_output)?; + let mut app = App::new(); + if let Some(prompt) = initial_prompt { + // Set initial text in textarea + let mut textarea = tui_textarea::TextArea::from(prompt.lines()); + // Disable underline on cursor line + textarea.set_cursor_line_style(ratatui::style::Style::default()); + // Enable word wrapping + textarea.set_wrap_mode(tui_textarea::WrapMode::Word); + // Move cursor to end + textarea.move_cursor(tui_textarea::CursorMove::End); + app.state.textarea = textarea; + } - loop { - ui.render_prompt(&prompt)?; - if !event::poll(Duration::from_millis(250)).context("failed to poll for input")? { - continue; - } + // Initialize debug state logger if requested + let mut debug_logger = debug_state_file + .map(|path| DebugStateLogger::new(&path)) + .transpose()?; - let ev = event::read().context("failed to read terminal event")?; - let Event::Key(key) = ev else { - continue; + // Helper macro to log state changes + macro_rules! log_state { + ($label:expr) => { + if let Some(ref mut logger) = debug_logger { + logger.log($label, &app.state); + } }; + } - match key.code { - KeyCode::Esc => return Ok((Action::Cancel, String::new())), - KeyCode::Backspace => { - prompt.pop(); - } - KeyCode::Enter => { - let query = prompt.trim().to_string(); - if query.is_empty() { - return Ok((Action::Cancel, String::new())); - } + // Log initial state + log_state!("init"); - let response = loop { - let endpoint_clone = endpoint.clone(); - let token_clone = token.clone(); - let query_clone = query.clone(); - let task = tokio::spawn(async move { - generate_command(&endpoint_clone, &token_clone, &query_clone).await - }); - - let generated = loop { - if task.is_finished() { - break task.await.context("generate task join failed")?; - } + // Load theme + let settings = atuin_client::settings::Settings::new()?; + let mut theme_manager = ThemeManager::new(None, None); + let theme = theme_manager.load_theme(&settings.theme.name, None); - ui.render_generating(&prompt, spinner_idx)?; - spinner_idx = (spinner_idx + 1) % SPINNER_FRAMES.len(); + // Initialize event loop + let mut event_loop = EventLoop::new(); - if event::poll(Duration::from_millis(100)) - .context("failed to poll while generating")? - { - let ev = event::read().context("failed reading generate event")?; - if let Event::Key(key) = ev - && key.code == KeyCode::Esc - { - task.abort(); - return Ok((Action::Cancel, String::new())); + // Track chat stream + let mut chat_stream: Option< + std::pin::Pin> + Send>>, + > = None; + + loop { + // Ensure viewport is large enough for current content (capped at terminal height) + let needed_height = calculate_needed_height(&app.state); + let actual_height = guard.ensure_height(needed_height)?; + + // Render current state + let anchor_col = guard.anchor_col(); + let ctx = RenderContext { + theme, + anchor_col, + textarea: Some(&app.state.textarea), + max_height: actual_height, + }; + // Handle draw errors gracefully - cursor position reads can fail during resize + if let Err(e) = guard.terminal().draw(|frame| { + render(frame, &app.state, &ctx); + }) { + let err_msg = e.to_string(); + if err_msg.contains("cursor position") { + // Cursor position read failed (common during terminal resize) + // Skip this frame and continue - next frame will likely succeed + tracing::debug!( + "Skipping frame due to cursor position read error: {}", + err_msg + ); + continue; + } + return Err(e.into()); + } + + // Get next event + let event = event_loop.run().await?; + + // Handle event based on app mode + match event { + AppEvent::Key(key) => { + app.handle_key(key); + log_state!("key"); + } + AppEvent::Tick => { + app.state.tick(); + + // Poll chat stream if active - keep polling until done regardless of mode + // (mode may change to Review before we receive the done event with session_id) + if let Some(stream) = &mut chat_stream { + let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref()); + match stream.as_mut().poll_next(&mut cx) { + std::task::Poll::Ready(Some(Ok(event))) => match event { + ChatStreamEvent::TextChunk(text) => { + tracing::debug!(text = %text, "Processing TextChunk"); + app.state.append_streaming_text(&text); + log_state!("text_chunk"); } - } - }; - - match generated { - Ok(value) => break value, - Err(err) => { - ui.render_error(&prompt, &err.to_string())?; - if !wait_for_retry_or_cancel()? { - return Ok((Action::Cancel, String::new())); + ChatStreamEvent::ToolCall { id, name, input } => { + tracing::debug!(id = %id, name = %name, "Processing ToolCall"); + app.state.add_tool_call(id, name, input); + log_state!("tool_call"); + } + ChatStreamEvent::ToolResult { + tool_use_id, + content, + is_error, + } => { + tracing::debug!(tool_use_id = %tool_use_id, "Processing ToolResult"); + app.state.add_tool_result(tool_use_id, content, is_error); + log_state!("tool_result"); + } + ChatStreamEvent::Status(status) => { + tracing::debug!(status = %status, "Processing Status"); + app.state.update_streaming_status(&status); + log_state!("status"); + } + ChatStreamEvent::Done { session_id } => { + tracing::debug!(session_id = %session_id, "Processing Done"); + chat_stream = None; + if !session_id.is_empty() { + app.state.store_session_id(session_id); + } + app.state.finalize_streaming(); + log_state!("done"); } + ChatStreamEvent::Error(msg) => { + tracing::debug!(error = %msg, "Processing Error"); + chat_stream = None; + app.state.streaming_error(msg); + log_state!("error"); + } + }, + std::task::Poll::Ready(Some(Err(e))) => { + chat_stream = None; + app.state.streaming_error(e.to_string()); + log_state!("stream_error"); } - } - }; - - loop { - ui.render_review(&prompt, &response)?; - if !event::poll(Duration::from_millis(250)) - .context("failed to poll in review")? - { - continue; - } - - let ev = event::read().context("failed to read review event")?; - let Event::Key(key) = ev else { - continue; - }; - - match key.code { - KeyCode::Enter => return Ok((Action::Execute, response.command)), - KeyCode::Tab => return Ok((Action::Insert, response.command)), - KeyCode::Esc => return Ok((Action::Cancel, String::new())), - KeyCode::Char('e') => break, - _ => {} + std::task::Poll::Ready(None) => { + chat_stream = None; + app.state.finalize_streaming(); + log_state!("stream_end"); + } + std::task::Poll::Pending => {} } } } - KeyCode::Char(c) => { - prompt.push(c); - } _ => {} } + + // Handle user cancellation (Esc during streaming) - drop the stream + if app.state.was_interrupted && chat_stream.is_some() { + tracing::debug!("User cancelled streaming, dropping chat stream"); + chat_stream = None; + app.state.was_interrupted = false; // Reset the flag + } + + // Check exit condition + if app.state.should_exit { + break; + } + + // Handle generation trigger - unified path for all turns + if app.state.mode == AppMode::Generating && chat_stream.is_none() { + // Get the last user message from events + let last_user_content = app.state.events.iter().rev().find_map(|e| { + if let ConversationEvent::UserMessage { content } = e { + Some(content.clone()) + } else { + None + } + }); + + if last_user_content.is_some() { + // Build messages in Claude API format + let messages = app.state.events_to_messages(); + + // Transition to streaming mode + app.state.start_streaming(); + log_state!("start_streaming"); + + // Start the chat stream + chat_stream = Some(create_chat_stream( + endpoint.clone(), + token.clone(), + app.state.session_id.clone(), + messages, + &settings, + )); + } + } } + + // Map exit action to return value + let result = match app.state.exit_action { + Some(ExitAction::Execute(cmd)) => (Action::Execute, cmd), + Some(ExitAction::Insert(cmd)) => (Action::Insert, cmd), + _ => (Action::Cancel, String::new()), + }; + + Ok(result) } struct RawModeGuard; @@ -330,279 +626,3 @@ fn wait_for_login_confirmation() -> Result { } } } - -fn wait_for_retry_or_cancel() -> Result { - loop { - let ev = event::read().context("failed to read retry/cancel key")?; - if let Event::Key(key) = ev { - match key.code { - KeyCode::Enter | KeyCode::Char('r') => return Ok(true), - KeyCode::Esc => return Ok(false), - _ => {} - } - } - } -} - -const SPINNER_FRAMES: [&str; 4] = ["/", "-", "\\", "|"]; - -struct InlineUi { - terminal: Terminal>, - anchor_col: u16, -} - -impl InlineUi { - fn new() -> Result { - let anchor_col = cursor::position().map(|(x, _)| x).unwrap_or(0); - enable_raw_mode().context("failed to enable raw mode for inline UI")?; - let backend = CrosstermBackend::new(std::io::stdout()); - let terminal = Terminal::with_options( - backend, - TerminalOptions { - viewport: Viewport::Inline(16), - }, - ) - .context("failed to initialize inline UI")?; - Ok(Self { - terminal, - anchor_col, - }) - } - - fn render_prompt(&mut self, prompt: &str) -> Result<()> { - self.render(Screen::Prompt { - prompt, - footer: "[Enter]: Accept [Esc]: Cancel", - }) - } - - fn render_generating(&mut self, prompt: &str, spinner_idx: usize) -> Result<()> { - self.render(Screen::Generating { - prompt, - footer: "[Esc]: Cancel", - spinner_idx, - }) - } - - fn render_review(&mut self, prompt: &str, response: &GenerateResponse) -> Result<()> { - self.render(Screen::Review { - prompt, - response, - footer: "[Enter]: Run [Tab]: Insert [e]: Edit [Esc]: Cancel", - }) - } - - fn render_error(&mut self, prompt: &str, err: &str) -> Result<()> { - self.render(Screen::Error { - prompt, - err, - footer: "[Enter]/[r]: Retry [Esc]: Cancel", - }) - } - - fn render(&mut self, screen: Screen<'_>) -> Result<()> { - self.terminal - .draw(|f| draw_screen(f, screen, self.anchor_col)) - .context("failed rendering inline UI")?; - Ok(()) - } -} - -impl Drop for InlineUi { - fn drop(&mut self) { - let _ = self.terminal.clear(); - let _ = disable_raw_mode(); - } -} - -enum Screen<'a> { - Prompt { - prompt: &'a str, - footer: &'a str, - }, - Generating { - prompt: &'a str, - footer: &'a str, - spinner_idx: usize, - }, - Review { - prompt: &'a str, - response: &'a GenerateResponse, - footer: &'a str, - }, - Error { - prompt: &'a str, - err: &'a str, - footer: &'a str, - }, -} - -fn draw_screen(frame: &mut Frame, screen: Screen<'_>, anchor_col: u16) { - let area = frame.area(); - let desired_width = 64u16.min(area.width.saturating_sub(2)).max(32); - let content_width = usize::from(desired_width.saturating_sub(2)).max(1); - let (content_preview, _, _) = build_screen_content(&screen, content_width); - let desired_height = (wrapped_line_count(&content_preview, content_width) as u16) - .saturating_add(2) - .min(area.height.max(1)) - .max(3); - - let max_x = area.x + area.width.saturating_sub(desired_width); - let preferred_x = area.x + anchor_col.saturating_sub(2); - let card = Rect { - x: preferred_x.min(max_x), - y: area.y, - width: desired_width, - height: desired_height, - }; - - let footer = match &screen { - Screen::Prompt { footer, .. } - | Screen::Generating { footer, .. } - | Screen::Review { footer, .. } - | Screen::Error { footer, .. } => *footer, - }; - - let block = Block::default() - .borders(Borders::ALL) - .title("Describe the command you'd like to generate:") - .title_bottom(Line::from(footer).alignment(Alignment::Right)); - - let content_area = block.inner(card); - frame.render_widget(block, card); - - let (content, show_cursor, cursor_prompt) = - build_screen_content(&screen, usize::from(content_area.width).max(1)); - - let paragraph = Paragraph::new(content).wrap(Wrap { trim: false }); - frame.render_widget(paragraph, content_area); - - if show_cursor { - let width = usize::from(content_area.width).max(1); - let (cursor_row, cursor_col) = - prompt_cursor_position(cursor_prompt.as_deref().unwrap_or_default(), width); - let cursor_x = content_area.x.saturating_add(cursor_col); - let cursor_y = content_area.y.saturating_add(cursor_row); - frame.set_cursor_position((cursor_x, cursor_y)); - } -} - -fn format_prompt(prompt: &str) -> String { - if prompt.is_empty() { - return "> ".to_string(); - } - format!("> {prompt}") -} - -fn wrapped_line_count(text: &str, width: usize) -> usize { - if width == 0 { - return 1; - } - - text.split('\n') - .map(|line| { - let len = line.chars().count(); - len.max(1).div_ceil(width) - }) - .sum::() - .max(1) -} - -fn build_screen_content( - screen: &Screen<'_>, - content_width: usize, -) -> (String, bool, Option) { - match screen { - Screen::Prompt { prompt, .. } => { - let formatted = format_prompt(prompt); - (formatted, true, Some((*prompt).to_string())) - } - Screen::Generating { - prompt, - spinner_idx, - .. - } => ( - format!( - "{}\n\n{} Generating...", - format_prompt(prompt), - SPINNER_FRAMES[*spinner_idx] - ), - false, - None, - ), - Screen::Review { - prompt, response, .. - } => { - let separator = "─".repeat(content_width.max(1)); - let mut text = format!( - "{}\n\n{}\n\n$ {}\n", - format_prompt(prompt), - separator, - response.command - ); - if let Some(explanation) = &response.explanation { - text.push('\n'); - text.push_str(explanation); - } - (text, false, None) - } - Screen::Error { prompt, err, .. } => ( - format!("{}\n\nRequest failed:\n{}", format_prompt(prompt), err), - false, - None, - ), - } -} - -fn prompt_cursor_position(prompt: &str, width: usize) -> (u16, u16) { - if width == 0 { - return (0, 0); - } - - // The visible prompt line is always `> {prompt}`. - // We mimic word-wrapping so cursor tracking matches visual layout. - let mut row = 0usize; - let mut col = 2usize; // "> " - - let mut saw_any_word = false; - for word in prompt.split_whitespace() { - let word_len = word.chars().count(); - if !saw_any_word { - saw_any_word = true; - if col + word_len <= width { - col += word_len; - } else if word_len >= width { - let used = width.saturating_sub(col); - let remaining = word_len.saturating_sub(used); - row += 1 + (remaining / width); - col = remaining % width; - } else { - row += 1; - col = word_len; - } - continue; - } - - if col + 1 + word_len <= width { - col += 1 + word_len; - } else if word_len >= width { - row += 1 + (word_len / width); - col = word_len % width; - } else { - row += 1; - col = word_len; - } - } - - // Keep trailing spaces user typed. - let trailing_spaces = prompt.chars().rev().take_while(|c| *c == ' ').count(); - for _ in 0..trailing_spaces { - if col >= width { - row += 1; - col = 0; - } - col += 1; - } - - (row as u16, col as u16) -} -- cgit v1.3.1