aboutsummaryrefslogtreecommitdiffstats
path: root/crates/atuin-ai/src/commands/inline.rs
diff options
context:
space:
mode:
authorMichelle Tilley <michelle@michelletilley.net>2026-02-24 11:48:20 -0800
committerGitHub <noreply@github.com>2026-02-24 11:48:20 -0800
commit6ea760bb6b36da241961e8ecd60cb2c5e15c0a78 (patch)
tree18ebbb710cea24e30bc69b5d6bc807518a950746 /crates/atuin-ai/src/commands/inline.rs
parentfix: forward $PATH to tmux popup in zsh (#3198) (diff)
downloadatuin-6ea760bb6b36da241961e8ecd60cb2c5e15c0a78.zip
feat: Generate commands or ask questions with `atuin ai` (#3199)
This PR refines the system created in #3178 to be suitable for a v1 release. --- ## Overview `atuin-ai` is a separate binary that allows for generating commands and asking questions from the command line. It is fully opt-in. ## Usage `atuin ai init` will output bindings for your shell. Currently, bash, zsh, and fish are supported. ```bash eval "$(atuin ai init)" ``` Once the hooks are installed, just press `?` on an empty prompt line to call up the TUI. `atuin ai` requires an account on [Atuin Hub](https://hub.atuin.sh/); you will be prompted to log in on first use. ## Features ### Command generation Prompt the LLM to create a command, and get one back, no fuss. Press `enter` to run, or `tab` to insert. ``` ┌Ask questions or generate a command:──────────────────────────┐ │ │ │ > Get a list of running docker containers │ │ │ ├──────────────────────────────────────────────────────────────┤ │ │ │ $ docker ps │ │ │ └────[Enter]: Run [Tab]: Insert [f]: Follow-up [Esc]: Cancel┘ ``` ### Follow-up You can follow-up with `f` to specify a refinement prompt to update the command that will be inserted. ``` ┌Ask questions or generate a command:──────────────────────────┐ │ │ │ > Get a list of running docker containers │ │ │ ├──────────────────────────────────────────────────────────────┤ │ │ │ $ docker ps │ │ │ ├──────────────────────────────────────────────────────────────┤ │ │ │ > Actually I want to get all docker containers │ │ │ ├──────────────────────────────────────────────────────────────┤ │ │ │ $ docker ps -a │ │ │ └────[Enter]: Run [Tab]: Insert [f]: Follow-up [Esc]: Cancel┘ ``` You can also follow-up with questions to get responses in natural language. ``` ┌Ask questions or generate a command:──────────────────────────┐ │ │ │ > Get a list of running docker containers │ │ │ ├──────────────────────────────────────────────────────────────┤ │ │ │ $ docker ps │ │ │ ├──────────────────────────────────────────────────────────────┤ │ │ │ > Actually I want to get all docker containers │ │ │ ├──────────────────────────────────────────────────────────────┤ │ │ │ $ docker ps -a │ │ │ ├──────────────────────────────────────────────────────────────┤ │ │ │ > What other useful flags to `docker ps` should I know? │ │ │ ├──────────────────────────────────────────────────────────────┤ │ │ │ Here are some handy `docker ps` flags: │ │ │ │ - `-q` — Only show container IDs (great for piping to │ │ other commands) │ │ - `-s` — Show container sizes │ │ - `-n 5` — Show the last 5 created containers │ │ - `-l` — Show only the latest created container │ │ - `--no-trunc` — Don't truncate output (shows full IDs and │ │ commands) │ │ - `-f` or `--filter` — Filter by condition, e.g.: │ │ - `-f status=exited` — only exited containers │ │ - `-f name=myapp` — filter by name │ │ - `-f ancestor=nginx` — filter by image │ │ - `--format` — Custom output using Go templates, e.g.: │ │ `--format "table {{.Names}}\t{{.Status}}\t{{.Ports}}"` │ │ │ │ A common combo is `docker ps -aq` to get all container │ │ IDs, useful for bulk operations like `docker rm $(docker │ │ ps -aq)`. │ │ │ └────[Enter]: Run [Tab]: Insert [f]: Follow-up [Esc]: Cancel┘ ``` You can use `enter` or `tab` at any time to run or insert the last suggested command, even if it was suggested in a previous turn. ### Conversational and search usage If you prompt the LLM with a question that doesn't imply you want to generate a command, it can respond in natural language, and use web search if necessary to fetch the data it needs. ``` ┌Ask questions or generate a command:──────────────────────────┐ │ │ │ > What is the latest version of atuin? │ │ │ ├──────────────────────────────────────────────────────────────┤ │ │ │ ✓ Used 2 tools │ │ │ │ The latest version of Atuin is **v18.12.0**, available on │ │ the [GitHub releases │ │ page](https://github.com/atuinsh/atuin/releases). │ │ │ └─────────────────────────────────[f]: Follow-up [Esc]: Cancel┘ ``` ### Dangerous or low-confidence command detection The LLM scores its confidence in the command, as well as how dangerous the command is. This information is shown if a threshold is exceeded, and requires an extra confirmation step before running automatically with `enter`. The Atuin Hub server also monitors suggested commands for dangerous patterns the LLM didn't catch, and appends its own assessment at the end of the LLM's own assessment. ``` ┌Ask questions or generate a command:──────────────────────────┐ │ │ │ > Delete all files from $HOME │ │ │ ├──────────────────────────────────────────────────────────────┤ │ │ │ $ rm -rf $HOME/* │ │ │ │ ! ⚠️ This will PERMANENTLY delete ALL files and directories │ │ in your home directory, including documents, downloads, │ │ configurations, SSH keys, and everything else. This is │ │ irreversible and will likely break your system. Also note │ │ this won't delete hidden (dot) files — if you want those │ │ too, that's even more destructive.; [Server] Recursive │ │ delete of critical directory │ │ │ └────[Enter]: Run [Tab]: Insert [f]: Follow-up [Esc]: Cancel┘ ``` --------- Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'crates/atuin-ai/src/commands/inline.rs')
-rw-r--r--crates/atuin-ai/src/commands/inline.rs924
1 files changed, 472 insertions, 452 deletions
diff --git a/crates/atuin-ai/src/commands/inline.rs b/crates/atuin-ai/src/commands/inline.rs
index cfa27db4..3f9278a2 100644
--- a/crates/atuin-ai/src/commands/inline.rs
+++ b/crates/atuin-ai/src/commands/inline.rs
@@ -1,52 +1,52 @@
+use crate::commands::detect_shell;
+use crate::tui::render::render;
+use crate::tui::{
+ App, AppEvent, AppMode, ConversationEvent, EventLoop, ExitAction, RenderContext, TerminalGuard,
+ calculate_needed_height, install_panic_hook,
+};
+use atuin_client::theme::ThemeManager;
use atuin_common::tls::ensure_crypto_provider;
use crossterm::{
- cursor,
event::{self, Event, KeyCode},
terminal::{disable_raw_mode, enable_raw_mode},
};
+use eventsource_stream::Eventsource;
use eyre::{Context as _, Result, bail};
-use ratatui::{
- Frame, Terminal, TerminalOptions, Viewport,
- backend::CrosstermBackend,
- layout::{Alignment, Rect},
- text::Line,
- widgets::{Block, Borders, Paragraph, Wrap},
-};
+use futures::StreamExt;
use reqwest::Url;
-use serde::{Deserialize, Serialize};
-use std::time::Duration;
-
-#[derive(Debug, Serialize)]
-struct GenerateRequest {
- query: String,
- description: String,
- context: GenerateContext,
-}
-
-#[derive(Debug, Serialize)]
-struct GenerateContext {
- os: String,
- shell: String,
- pwd: Option<String>,
-}
-
-#[derive(Debug, Deserialize)]
-struct GenerateResponse {
- command: String,
- #[serde(default)]
- explanation: Option<String>,
-}
+use std::io::Write;
pub async fn run(
initial_command: Option<String>,
natural_language: bool,
api_endpoint: Option<String>,
+ api_token: Option<String>,
+ keep_output: bool,
+ debug_state_file: Option<String>,
) -> Result<()> {
+ // Install panic hook once at entry point to ensure terminal restoration
+ install_panic_hook();
+
+ // Token and endpoint priority:
+ // 1. Command line arguments/environment variables
+ // 2. Settings file
+ // 3. Default
let settings = atuin_client::settings::Settings::new()?;
- let endpoint = api_endpoint
- .as_deref()
- .unwrap_or(settings.hub_address.as_str());
- let token = ensure_hub_session(&settings, endpoint).await?;
+ let endpoint = api_endpoint.as_deref().unwrap_or(
+ settings
+ .ai
+ .ai_endpoint
+ .as_deref()
+ .unwrap_or("https://hub.atuin.sh"),
+ );
+ let api_token = api_token.as_deref().or(settings.ai.ai_api_token.as_deref());
+
+ let token = if let Some(token) = &api_token {
+ token.to_string()
+ } else {
+ ensure_hub_session(&settings, endpoint).await?
+ };
+
let action = run_inline_tui(
endpoint.to_string(),
token,
@@ -55,6 +55,8 @@ pub async fn run(
} else {
initial_command
},
+ keep_output,
+ debug_state_file,
)
.await?;
emit_shell_result(action.0, &action.1);
@@ -95,55 +97,172 @@ async fn ensure_hub_session(
Ok(token)
}
-async fn generate_command(
- hub_address: &str,
- token: &str,
- description: &str,
-) -> Result<GenerateResponse> {
- ensure_crypto_provider();
- let endpoint = hub_url(hub_address, "/api/cli/generate")?;
- let request = GenerateRequest {
- query: description.to_string(),
- description: description.to_string(),
- context: GenerateContext {
- os: detect_os(),
- shell: detect_shell(),
- pwd: std::env::current_dir()
- .ok()
- .map(|path| path.to_string_lossy().into_owned()),
- },
- };
+/// SSE event received from chat endpoint
+#[derive(Debug, Clone)]
+enum ChatStreamEvent {
+ /// Text chunk to display
+ TextChunk(String),
+ /// Tool call event (need to echo back, may contain suggest_command)
+ ToolCall {
+ id: String,
+ name: String,
+ input: serde_json::Value,
+ },
+ /// Tool result from server-side execution
+ ToolResult {
+ tool_use_id: String,
+ content: String,
+ is_error: bool,
+ },
+ /// Status update from server
+ Status(String),
+ /// Stream complete
+ Done { session_id: String },
+ /// Error from server
+ Error(String),
+}
+
+fn create_chat_stream(
+ hub_address: String,
+ token: String,
+ session_id: Option<String>,
+ messages: Vec<serde_json::Value>,
+ settings: &atuin_client::settings::Settings,
+) -> std::pin::Pin<Box<dyn futures::Stream<Item = Result<ChatStreamEvent>> + Send>> {
+ let send_cwd = settings.ai.send_cwd;
+
+ Box::pin(async_stream::stream! {
+ ensure_crypto_provider();
+ let endpoint = match hub_url(&hub_address, "/api/cli/chat") {
+ Ok(url) => url,
+ Err(e) => {
+ yield Err(e);
+ return;
+ }
+ };
- let client = reqwest::Client::new();
- let response = client
- .post(endpoint)
- .bearer_auth(token)
- .json(&request)
- .send()
- .await
- .context("failed to call Atuin Hub generate endpoint")?;
+ // Build request body
+ let mut request_body = serde_json::json!({
+ "messages": messages,
+ "context": {
+ "os": detect_os(),
+ "shell": detect_shell(),
+ "pwd": if send_cwd { std::env::current_dir()
+ .ok()
+ .map(|path| path.to_string_lossy().into_owned()) } else { None },
+ }
+ });
+
+ // Include session_id only if present (not on first request)
+ if let Some(ref sid) = session_id {
+ request_body["session_id"] = serde_json::json!(sid);
+ }
- if response.status().is_success() {
- let generated = response
- .json::<GenerateResponse>()
+
+ let client = reqwest::Client::new();
+ let response = match client
+ .post(endpoint.clone())
+ .header("Accept", "text/event-stream")
+ .bearer_auth(&token)
+ .json(&request_body)
+ .send()
.await
- .context("failed to decode generate response")?;
+ {
+ Ok(resp) => resp,
+ Err(e) => {
+ yield Err(eyre::eyre!("Failed to send SSE request: {}", e));
+ return;
+ }
+ };
- if generated.command.trim().is_empty() {
- bail!("Hub returned an empty command. Please try again with a more specific request.");
+ let status = response.status();
+ if status == reqwest::StatusCode::UNAUTHORIZED {
+ // Clear saved session on auth error
+ let _ = atuin_client::hub::delete_session().await;
+ yield Err(eyre::eyre!("Hub session expired. Re-run to authenticate again."));
+ return;
+ }
+ if !status.is_success() {
+ let body = response.text().await.unwrap_or_default();
+ yield Err(eyre::eyre!("SSE request failed ({}): {}", status, body));
+ return;
}
- return Ok(generated);
- }
+ let byte_stream = response.bytes_stream();
+ let mut stream = byte_stream.eventsource();
- if response.status() == reqwest::StatusCode::UNAUTHORIZED {
- atuin_client::hub::delete_session().await?;
- bail!("Hub session expired. Re-run to authenticate again.");
- }
+ while let Some(event) = stream.next().await {
+ match event {
+ Ok(sse_event) => {
+ let event_type = sse_event.event.as_str();
+ let data = sse_event.data.clone();
+
+ tracing::debug!(event_type = %event_type, data = %data, "SSE event received");
- let status = response.status();
- let body = response.text().await.unwrap_or_default();
- bail!("Hub request failed ({status}): {body}");
+ match event_type {
+ "text" => {
+ if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data)
+ && let Some(content) = json.get("content").and_then(|v| v.as_str())
+ {
+ yield Ok(ChatStreamEvent::TextChunk(content.to_string()));
+ }
+ }
+ "tool_call" => {
+ if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) {
+ let id = json.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string();
+ let name = json.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string();
+ let input = json.get("input").cloned().unwrap_or(serde_json::json!({}));
+ yield Ok(ChatStreamEvent::ToolCall { id, name, input });
+ }
+ }
+ "tool_result" => {
+ if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) {
+ let tool_use_id = json.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or("").to_string();
+ let content = json.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string();
+ let is_error = json.get("is_error").and_then(|v| v.as_bool()).unwrap_or(false);
+ yield Ok(ChatStreamEvent::ToolResult { tool_use_id, content, is_error });
+ }
+ }
+ "status" => {
+ if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data)
+ && let Some(state) = json.get("state").and_then(|v| v.as_str())
+ {
+ yield Ok(ChatStreamEvent::Status(state.to_string()));
+ }
+ }
+ "done" => {
+ if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) {
+ let session_id = json.get("session_id")
+ .and_then(|v| v.as_str())
+ .unwrap_or("")
+ .to_string();
+ yield Ok(ChatStreamEvent::Done { session_id });
+ } else {
+ yield Ok(ChatStreamEvent::Done { session_id: String::new() });
+ }
+ break;
+ }
+ "error" => {
+ if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) {
+ let message = json.get("message").and_then(|v| v.as_str()).unwrap_or("Unknown error").to_string();
+ yield Ok(ChatStreamEvent::Error(message));
+ } else {
+ yield Ok(ChatStreamEvent::Error(data));
+ }
+ break;
+ }
+ _ => {
+ // Unknown event type, ignore
+ }
+ }
+ }
+ Err(e) => {
+ yield Err(eyre::eyre!("SSE error: {}", e));
+ break;
+ }
+ }
+ }
+ })
}
fn hub_url(base: &str, path: &str) -> Result<Url> {
@@ -162,35 +281,11 @@ fn detect_os() -> String {
match std::env::consts::OS {
"macos" => "macos".to_string(),
"linux" => "linux".to_string(),
+ "windows" => "windows".to_string(),
_ => "linux".to_string(),
}
}
-fn detect_shell() -> String {
- if let Ok(shell) = std::env::var("ATUIN_SHELL")
- && !shell.trim().is_empty()
- {
- return shell;
- }
-
- let shell = std::env::var("SHELL")
- .ok()
- .and_then(|value| {
- std::path::Path::new(&value)
- .file_name()
- .map(std::ffi::OsStr::to_string_lossy)
- .map(std::borrow::Cow::into_owned)
- })
- .filter(|value| !value.trim().is_empty());
-
- match shell.as_deref() {
- Some("zsh") => "zsh".to_string(),
- Some("fish") => "fish".to_string(),
- Some("bash") => "bash".to_string(),
- _ => "bash".to_string(),
- }
-}
-
#[derive(Clone, Copy)]
enum Action {
Execute,
@@ -198,105 +293,306 @@ enum Action {
Cancel,
}
+/// Serialize AppState to JSON for debug logging
+fn state_to_json(state: &crate::tui::AppState) -> serde_json::Value {
+ let events: Vec<serde_json::Value> = state.events.iter().map(|e| e.to_json()).collect();
+
+ let mode = match state.mode {
+ AppMode::Input => "Input",
+ AppMode::Generating => "Generating",
+ AppMode::Streaming => "Streaming",
+ AppMode::Review => "Review",
+ AppMode::Error => "Error",
+ };
+
+ // Get input and cursor from textarea
+ let input = state.input();
+ let cursor = state.textarea.cursor();
+
+ let mut json = serde_json::json!({
+ "events": events,
+ "mode": mode,
+ "input": input,
+ "cursor_row": cursor.0,
+ "cursor_col": cursor.1,
+ "spinner_frame": state.spinner_frame,
+ "confirmation_pending": state.confirmation_pending,
+ });
+
+ // Add streaming fields if in streaming mode
+ if !state.streaming_text.is_empty() {
+ json["streaming_text"] = serde_json::json!(state.streaming_text);
+ }
+ if let Some(ref status) = state.streaming_status {
+ json["streaming_status"] = serde_json::json!(status.display_text());
+ }
+ if let Some(ref err) = state.error {
+ json["error"] = serde_json::json!(err);
+ }
+
+ json
+}
+
+/// Debug logger that writes state changes to a file
+struct DebugStateLogger {
+ file: std::fs::File,
+ entry_count: usize,
+ width: u16,
+}
+
+impl DebugStateLogger {
+ fn new(path: &str) -> Result<Self> {
+ let file = std::fs::File::create(path)
+ .with_context(|| format!("Failed to create debug state file: {}", path))?;
+ // Get terminal width, default to 80
+ let (width, _) = crossterm::terminal::size().unwrap_or((80, 24));
+ Ok(Self {
+ file,
+ entry_count: 0,
+ width,
+ })
+ }
+
+ fn log(&mut self, label: &str, state: &crate::tui::AppState) {
+ use crate::tui::calculate_needed_height;
+
+ self.entry_count += 1;
+ let timestamp_ms = std::time::SystemTime::now()
+ .duration_since(std::time::UNIX_EPOCH)
+ .map(|d| d.as_millis())
+ .unwrap_or(0);
+
+ // Calculate the actual content height needed for this state
+ let content_height = calculate_needed_height(state);
+
+ let mut state_json = state_to_json(state);
+ // Add dimensions for accurate replay
+ state_json["width"] = serde_json::json!(self.width);
+ state_json["height"] = serde_json::json!(content_height);
+
+ let entry = serde_json::json!({
+ "entry": self.entry_count,
+ "label": label,
+ "timestamp_ms": timestamp_ms,
+ "state": state_json,
+ });
+
+ // Write as JSONL (one JSON object per line)
+ if let Err(e) = writeln!(self.file, "{}", entry) {
+ tracing::warn!("Failed to write debug state: {}", e);
+ }
+ let _ = self.file.flush();
+ }
+}
+
async fn run_inline_tui(
endpoint: String,
token: String,
initial_prompt: Option<String>,
+ keep_output: bool,
+ debug_state_file: Option<String>,
) -> Result<(Action, String)> {
- let mut ui = InlineUi::new()?;
- let mut prompt = initial_prompt.unwrap_or_default();
- let mut spinner_idx = 0usize;
+ // Initialize terminal guard and app state
+ let mut guard = TerminalGuard::new(keep_output)?;
+ let mut app = App::new();
+ if let Some(prompt) = initial_prompt {
+ // Set initial text in textarea
+ let mut textarea = tui_textarea::TextArea::from(prompt.lines());
+ // Disable underline on cursor line
+ textarea.set_cursor_line_style(ratatui::style::Style::default());
+ // Enable word wrapping
+ textarea.set_wrap_mode(tui_textarea::WrapMode::Word);
+ // Move cursor to end
+ textarea.move_cursor(tui_textarea::CursorMove::End);
+ app.state.textarea = textarea;
+ }
- loop {
- ui.render_prompt(&prompt)?;
- if !event::poll(Duration::from_millis(250)).context("failed to poll for input")? {
- continue;
- }
+ // Initialize debug state logger if requested
+ let mut debug_logger = debug_state_file
+ .map(|path| DebugStateLogger::new(&path))
+ .transpose()?;
- let ev = event::read().context("failed to read terminal event")?;
- let Event::Key(key) = ev else {
- continue;
+ // Helper macro to log state changes
+ macro_rules! log_state {
+ ($label:expr) => {
+ if let Some(ref mut logger) = debug_logger {
+ logger.log($label, &app.state);
+ }
};
+ }
- match key.code {
- KeyCode::Esc => return Ok((Action::Cancel, String::new())),
- KeyCode::Backspace => {
- prompt.pop();
- }
- KeyCode::Enter => {
- let query = prompt.trim().to_string();
- if query.is_empty() {
- return Ok((Action::Cancel, String::new()));
- }
+ // Log initial state
+ log_state!("init");
- let response = loop {
- let endpoint_clone = endpoint.clone();
- let token_clone = token.clone();
- let query_clone = query.clone();
- let task = tokio::spawn(async move {
- generate_command(&endpoint_clone, &token_clone, &query_clone).await
- });
+ // Load theme
+ let settings = atuin_client::settings::Settings::new()?;
+ let mut theme_manager = ThemeManager::new(None, None);
+ let theme = theme_manager.load_theme(&settings.theme.name, None);
- let generated = loop {
- if task.is_finished() {
- break task.await.context("generate task join failed")?;
- }
+ // Initialize event loop
+ let mut event_loop = EventLoop::new();
- ui.render_generating(&prompt, spinner_idx)?;
- spinner_idx = (spinner_idx + 1) % SPINNER_FRAMES.len();
+ // Track chat stream
+ let mut chat_stream: Option<
+ std::pin::Pin<Box<dyn futures::Stream<Item = Result<ChatStreamEvent>> + Send>>,
+ > = None;
- if event::poll(Duration::from_millis(100))
- .context("failed to poll while generating")?
- {
- let ev = event::read().context("failed reading generate event")?;
- if let Event::Key(key) = ev
- && key.code == KeyCode::Esc
- {
- task.abort();
- return Ok((Action::Cancel, String::new()));
- }
- }
- };
+ loop {
+ // Ensure viewport is large enough for current content (capped at terminal height)
+ let needed_height = calculate_needed_height(&app.state);
+ let actual_height = guard.ensure_height(needed_height)?;
+
+ // Render current state
+ let anchor_col = guard.anchor_col();
+ let ctx = RenderContext {
+ theme,
+ anchor_col,
+ textarea: Some(&app.state.textarea),
+ max_height: actual_height,
+ };
+ // Handle draw errors gracefully - cursor position reads can fail during resize
+ if let Err(e) = guard.terminal().draw(|frame| {
+ render(frame, &app.state, &ctx);
+ }) {
+ let err_msg = e.to_string();
+ if err_msg.contains("cursor position") {
+ // Cursor position read failed (common during terminal resize)
+ // Skip this frame and continue - next frame will likely succeed
+ tracing::debug!(
+ "Skipping frame due to cursor position read error: {}",
+ err_msg
+ );
+ continue;
+ }
+ return Err(e.into());
+ }
+
+ // Get next event
+ let event = event_loop.run().await?;
+
+ // Handle event based on app mode
+ match event {
+ AppEvent::Key(key) => {
+ app.handle_key(key);
+ log_state!("key");
+ }
+ AppEvent::Tick => {
+ app.state.tick();
- match generated {
- Ok(value) => break value,
- Err(err) => {
- ui.render_error(&prompt, &err.to_string())?;
- if !wait_for_retry_or_cancel()? {
- return Ok((Action::Cancel, String::new()));
+ // Poll chat stream if active - keep polling until done regardless of mode
+ // (mode may change to Review before we receive the done event with session_id)
+ if let Some(stream) = &mut chat_stream {
+ let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref());
+ match stream.as_mut().poll_next(&mut cx) {
+ std::task::Poll::Ready(Some(Ok(event))) => match event {
+ ChatStreamEvent::TextChunk(text) => {
+ tracing::debug!(text = %text, "Processing TextChunk");
+ app.state.append_streaming_text(&text);
+ log_state!("text_chunk");
}
+ ChatStreamEvent::ToolCall { id, name, input } => {
+ tracing::debug!(id = %id, name = %name, "Processing ToolCall");
+ app.state.add_tool_call(id, name, input);
+ log_state!("tool_call");
+ }
+ ChatStreamEvent::ToolResult {
+ tool_use_id,
+ content,
+ is_error,
+ } => {
+ tracing::debug!(tool_use_id = %tool_use_id, "Processing ToolResult");
+ app.state.add_tool_result(tool_use_id, content, is_error);
+ log_state!("tool_result");
+ }
+ ChatStreamEvent::Status(status) => {
+ tracing::debug!(status = %status, "Processing Status");
+ app.state.update_streaming_status(&status);
+ log_state!("status");
+ }
+ ChatStreamEvent::Done { session_id } => {
+ tracing::debug!(session_id = %session_id, "Processing Done");
+ chat_stream = None;
+ if !session_id.is_empty() {
+ app.state.store_session_id(session_id);
+ }
+ app.state.finalize_streaming();
+ log_state!("done");
+ }
+ ChatStreamEvent::Error(msg) => {
+ tracing::debug!(error = %msg, "Processing Error");
+ chat_stream = None;
+ app.state.streaming_error(msg);
+ log_state!("error");
+ }
+ },
+ std::task::Poll::Ready(Some(Err(e))) => {
+ chat_stream = None;
+ app.state.streaming_error(e.to_string());
+ log_state!("stream_error");
+ }
+ std::task::Poll::Ready(None) => {
+ chat_stream = None;
+ app.state.finalize_streaming();
+ log_state!("stream_end");
}
+ std::task::Poll::Pending => {}
}
- };
+ }
+ }
+ _ => {}
+ }
- loop {
- ui.render_review(&prompt, &response)?;
- if !event::poll(Duration::from_millis(250))
- .context("failed to poll in review")?
- {
- continue;
- }
+ // Handle user cancellation (Esc during streaming) - drop the stream
+ if app.state.was_interrupted && chat_stream.is_some() {
+ tracing::debug!("User cancelled streaming, dropping chat stream");
+ chat_stream = None;
+ app.state.was_interrupted = false; // Reset the flag
+ }
- let ev = event::read().context("failed to read review event")?;
- let Event::Key(key) = ev else {
- continue;
- };
+ // Check exit condition
+ if app.state.should_exit {
+ break;
+ }
- match key.code {
- KeyCode::Enter => return Ok((Action::Execute, response.command)),
- KeyCode::Tab => return Ok((Action::Insert, response.command)),
- KeyCode::Esc => return Ok((Action::Cancel, String::new())),
- KeyCode::Char('e') => break,
- _ => {}
- }
+ // Handle generation trigger - unified path for all turns
+ if app.state.mode == AppMode::Generating && chat_stream.is_none() {
+ // Get the last user message from events
+ let last_user_content = app.state.events.iter().rev().find_map(|e| {
+ if let ConversationEvent::UserMessage { content } = e {
+ Some(content.clone())
+ } else {
+ None
}
+ });
+
+ if last_user_content.is_some() {
+ // Build messages in Claude API format
+ let messages = app.state.events_to_messages();
+
+ // Transition to streaming mode
+ app.state.start_streaming();
+ log_state!("start_streaming");
+
+ // Start the chat stream
+ chat_stream = Some(create_chat_stream(
+ endpoint.clone(),
+ token.clone(),
+ app.state.session_id.clone(),
+ messages,
+ &settings,
+ ));
}
- KeyCode::Char(c) => {
- prompt.push(c);
- }
- _ => {}
}
}
+
+ // Map exit action to return value
+ let result = match app.state.exit_action {
+ Some(ExitAction::Execute(cmd)) => (Action::Execute, cmd),
+ Some(ExitAction::Insert(cmd)) => (Action::Insert, cmd),
+ _ => (Action::Cancel, String::new()),
+ };
+
+ Ok(result)
}
struct RawModeGuard;
@@ -330,279 +626,3 @@ fn wait_for_login_confirmation() -> Result<bool> {
}
}
}
-
-fn wait_for_retry_or_cancel() -> Result<bool> {
- loop {
- let ev = event::read().context("failed to read retry/cancel key")?;
- if let Event::Key(key) = ev {
- match key.code {
- KeyCode::Enter | KeyCode::Char('r') => return Ok(true),
- KeyCode::Esc => return Ok(false),
- _ => {}
- }
- }
- }
-}
-
-const SPINNER_FRAMES: [&str; 4] = ["/", "-", "\\", "|"];
-
-struct InlineUi {
- terminal: Terminal<CrosstermBackend<std::io::Stdout>>,
- anchor_col: u16,
-}
-
-impl InlineUi {
- fn new() -> Result<Self> {
- let anchor_col = cursor::position().map(|(x, _)| x).unwrap_or(0);
- enable_raw_mode().context("failed to enable raw mode for inline UI")?;
- let backend = CrosstermBackend::new(std::io::stdout());
- let terminal = Terminal::with_options(
- backend,
- TerminalOptions {
- viewport: Viewport::Inline(16),
- },
- )
- .context("failed to initialize inline UI")?;
- Ok(Self {
- terminal,
- anchor_col,
- })
- }
-
- fn render_prompt(&mut self, prompt: &str) -> Result<()> {
- self.render(Screen::Prompt {
- prompt,
- footer: "[Enter]: Accept [Esc]: Cancel",
- })
- }
-
- fn render_generating(&mut self, prompt: &str, spinner_idx: usize) -> Result<()> {
- self.render(Screen::Generating {
- prompt,
- footer: "[Esc]: Cancel",
- spinner_idx,
- })
- }
-
- fn render_review(&mut self, prompt: &str, response: &GenerateResponse) -> Result<()> {
- self.render(Screen::Review {
- prompt,
- response,
- footer: "[Enter]: Run [Tab]: Insert [e]: Edit [Esc]: Cancel",
- })
- }
-
- fn render_error(&mut self, prompt: &str, err: &str) -> Result<()> {
- self.render(Screen::Error {
- prompt,
- err,
- footer: "[Enter]/[r]: Retry [Esc]: Cancel",
- })
- }
-
- fn render(&mut self, screen: Screen<'_>) -> Result<()> {
- self.terminal
- .draw(|f| draw_screen(f, screen, self.anchor_col))
- .context("failed rendering inline UI")?;
- Ok(())
- }
-}
-
-impl Drop for InlineUi {
- fn drop(&mut self) {
- let _ = self.terminal.clear();
- let _ = disable_raw_mode();
- }
-}
-
-enum Screen<'a> {
- Prompt {
- prompt: &'a str,
- footer: &'a str,
- },
- Generating {
- prompt: &'a str,
- footer: &'a str,
- spinner_idx: usize,
- },
- Review {
- prompt: &'a str,
- response: &'a GenerateResponse,
- footer: &'a str,
- },
- Error {
- prompt: &'a str,
- err: &'a str,
- footer: &'a str,
- },
-}
-
-fn draw_screen(frame: &mut Frame, screen: Screen<'_>, anchor_col: u16) {
- let area = frame.area();
- let desired_width = 64u16.min(area.width.saturating_sub(2)).max(32);
- let content_width = usize::from(desired_width.saturating_sub(2)).max(1);
- let (content_preview, _, _) = build_screen_content(&screen, content_width);
- let desired_height = (wrapped_line_count(&content_preview, content_width) as u16)
- .saturating_add(2)
- .min(area.height.max(1))
- .max(3);
-
- let max_x = area.x + area.width.saturating_sub(desired_width);
- let preferred_x = area.x + anchor_col.saturating_sub(2);
- let card = Rect {
- x: preferred_x.min(max_x),
- y: area.y,
- width: desired_width,
- height: desired_height,
- };
-
- let footer = match &screen {
- Screen::Prompt { footer, .. }
- | Screen::Generating { footer, .. }
- | Screen::Review { footer, .. }
- | Screen::Error { footer, .. } => *footer,
- };
-
- let block = Block::default()
- .borders(Borders::ALL)
- .title("Describe the command you'd like to generate:")
- .title_bottom(Line::from(footer).alignment(Alignment::Right));
-
- let content_area = block.inner(card);
- frame.render_widget(block, card);
-
- let (content, show_cursor, cursor_prompt) =
- build_screen_content(&screen, usize::from(content_area.width).max(1));
-
- let paragraph = Paragraph::new(content).wrap(Wrap { trim: false });
- frame.render_widget(paragraph, content_area);
-
- if show_cursor {
- let width = usize::from(content_area.width).max(1);
- let (cursor_row, cursor_col) =
- prompt_cursor_position(cursor_prompt.as_deref().unwrap_or_default(), width);
- let cursor_x = content_area.x.saturating_add(cursor_col);
- let cursor_y = content_area.y.saturating_add(cursor_row);
- frame.set_cursor_position((cursor_x, cursor_y));
- }
-}
-
-fn format_prompt(prompt: &str) -> String {
- if prompt.is_empty() {
- return "> ".to_string();
- }
- format!("> {prompt}")
-}
-
-fn wrapped_line_count(text: &str, width: usize) -> usize {
- if width == 0 {
- return 1;
- }
-
- text.split('\n')
- .map(|line| {
- let len = line.chars().count();
- len.max(1).div_ceil(width)
- })
- .sum::<usize>()
- .max(1)
-}
-
-fn build_screen_content(
- screen: &Screen<'_>,
- content_width: usize,
-) -> (String, bool, Option<String>) {
- match screen {
- Screen::Prompt { prompt, .. } => {
- let formatted = format_prompt(prompt);
- (formatted, true, Some((*prompt).to_string()))
- }
- Screen::Generating {
- prompt,
- spinner_idx,
- ..
- } => (
- format!(
- "{}\n\n{} Generating...",
- format_prompt(prompt),
- SPINNER_FRAMES[*spinner_idx]
- ),
- false,
- None,
- ),
- Screen::Review {
- prompt, response, ..
- } => {
- let separator = "─".repeat(content_width.max(1));
- let mut text = format!(
- "{}\n\n{}\n\n$ {}\n",
- format_prompt(prompt),
- separator,
- response.command
- );
- if let Some(explanation) = &response.explanation {
- text.push('\n');
- text.push_str(explanation);
- }
- (text, false, None)
- }
- Screen::Error { prompt, err, .. } => (
- format!("{}\n\nRequest failed:\n{}", format_prompt(prompt), err),
- false,
- None,
- ),
- }
-}
-
-fn prompt_cursor_position(prompt: &str, width: usize) -> (u16, u16) {
- if width == 0 {
- return (0, 0);
- }
-
- // The visible prompt line is always `> {prompt}`.
- // We mimic word-wrapping so cursor tracking matches visual layout.
- let mut row = 0usize;
- let mut col = 2usize; // "> "
-
- let mut saw_any_word = false;
- for word in prompt.split_whitespace() {
- let word_len = word.chars().count();
- if !saw_any_word {
- saw_any_word = true;
- if col + word_len <= width {
- col += word_len;
- } else if word_len >= width {
- let used = width.saturating_sub(col);
- let remaining = word_len.saturating_sub(used);
- row += 1 + (remaining / width);
- col = remaining % width;
- } else {
- row += 1;
- col = word_len;
- }
- continue;
- }
-
- if col + 1 + word_len <= width {
- col += 1 + word_len;
- } else if word_len >= width {
- row += 1 + (word_len / width);
- col = word_len % width;
- } else {
- row += 1;
- col = word_len;
- }
- }
-
- // Keep trailing spaces user typed.
- let trailing_spaces = prompt.chars().rev().take_while(|c| *c == ' ').count();
- for _ in 0..trailing_spaces {
- if col >= width {
- row += 1;
- col = 0;
- }
- col += 1;
- }
-
- (row as u16, col as u16)
-}