diff options
| author | Michelle Tilley <michelle@michelletilley.net> | 2026-04-10 13:24:57 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-04-10 20:24:57 +0000 |
| commit | 09279a428659cf41824737d3e0c97bcc19a8885a (patch) | |
| tree | 64731502c065df2483e8dd680d46c5559f3094f2 /crates/atuin-ai/src/tui/state.rs | |
| parent | feat: add strip_trailing_whitespace, on by default (#3390) (diff) | |
| download | atuin-09279a428659cf41824737d3e0c97bcc19a8885a.zip | |
feat: Client-tool execution + permission system (#3370)
Adds client-side tool execution to Atuin AI, starting with
`atuin_history`. The server can request tool calls, which are executed
locally with a permission system, and results are sent back to continue
the conversation.
Diffstat (limited to 'crates/atuin-ai/src/tui/state.rs')
| -rw-r--r-- | crates/atuin-ai/src/tui/state.rs | 600 |
1 files changed, 321 insertions, 279 deletions
diff --git a/crates/atuin-ai/src/tui/state.rs b/crates/atuin-ai/src/tui/state.rs index 4c5c2a1e..69b35909 100644 --- a/crates/atuin-ai/src/tui/state.rs +++ b/crates/atuin-ai/src/tui/state.rs @@ -5,9 +5,11 @@ use tokio::task::AbortHandle; +use crate::tools::{ClientToolCall, ToolOutcome, ToolTracker}; + /// Streaming status indicators from server #[derive(Debug, Clone, PartialEq, Eq)] -pub enum StreamingStatus { +pub(crate) enum StreamingStatus { Processing, Searching, Thinking, @@ -15,7 +17,7 @@ pub enum StreamingStatus { } impl StreamingStatus { - pub fn from_status_str(s: &str) -> Self { + pub(crate) fn from_status_str(s: &str) -> Self { match s { "processing" => Self::Processing, "searching" => Self::Searching, @@ -23,20 +25,11 @@ impl StreamingStatus { _ => Self::Thinking, } } - - pub fn display_text(&self) -> &'static str { - match self { - Self::Processing => "Processing...", - Self::Searching => "Searching...", - Self::Thinking => "Thinking...", - Self::WaitingForTools => "Waiting for tools...", - } - } } /// Conversation event types matching the API protocol #[derive(Debug, Clone)] -pub enum ConversationEvent { +pub(crate) enum ConversationEvent { /// User message (what the user typed) UserMessage { content: String }, /// Text content from assistant (streamed or complete) @@ -62,48 +55,8 @@ pub enum ConversationEvent { } impl ConversationEvent { - /// Convert to JSON for API calls - pub fn to_json(&self) -> serde_json::Value { - match self { - ConversationEvent::UserMessage { content } => serde_json::json!({ - "type": "user_message", - "content": content - }), - ConversationEvent::Text { content } => serde_json::json!({ - "type": "text", - "content": content - }), - ConversationEvent::ToolCall { id, name, input } => serde_json::json!({ - "type": "tool_call", - "id": id, - "name": name, - "input": input - }), - ConversationEvent::ToolResult { - tool_use_id, - content, - is_error, - } => serde_json::json!({ - "type": "tool_result", - "tool_use_id": tool_use_id, - "content": content, - "is_error": is_error - }), - ConversationEvent::OutOfBandOutput { - name, - command, - content, - } => serde_json::json!({ - "type": "out_of_band_output", - "name": name, - "command": command, - "content": content - }), - } - } - /// Extract command from a suggest_command tool call - pub fn as_command(&self) -> Option<&str> { + pub(crate) fn as_command(&self) -> Option<&str> { if let ConversationEvent::ToolCall { name, input, .. } = self && name == "suggest_command" { @@ -113,8 +66,9 @@ impl ConversationEvent { } } +/// Application mode for key handling and footer text. #[derive(Debug, Clone, PartialEq, Eq, Copy)] -pub enum AppMode { +pub(crate) enum AppMode { /// User is typing input Input, /// Waiting for generation (showing spinner) @@ -126,7 +80,7 @@ pub enum AppMode { } #[derive(Debug, Clone, PartialEq, Eq)] -pub enum ExitAction { +pub(crate) enum ExitAction { /// Run the command Execute(String), /// Insert command without running @@ -135,47 +89,20 @@ pub enum ExitAction { Cancel, } -/// Application state — the domain model -/// -/// Conversation is stored as a sequence of events matching the API protocol. -/// The view function derives the UI from this state. +/// Owned event log and session ID #[derive(Debug)] -pub struct AppState { - /// Current application mode - pub mode: AppMode, +pub(crate) struct Conversation { /// Conversation events (source of truth, matches API protocol) pub events: Vec<ConversationEvent>, - /// Current error message - pub error: Option<String>, - /// Exit action (set when exiting) - pub exit_action: Option<ExitAction>, /// Session ID from server pub session_id: Option<String>, - /// Current streaming status - pub streaming_status: Option<StreamingStatus>, - /// Whether the input is blank - pub is_input_blank: bool, - /// Whether current turn was interrupted by user - pub was_interrupted: bool, - /// True when user has pressed Enter once on a dangerous command - pub confirmation_pending: bool, - /// Abort handle for the active streaming task, if any - pub stream_abort: Option<AbortHandle>, } -impl AppState { +impl Conversation { pub fn new() -> Self { Self { - mode: AppMode::Input, events: Vec::new(), - error: None, - exit_action: None, session_id: None, - streaming_status: None, - is_input_blank: false, - was_interrupted: false, - confirmation_pending: false, - stream_abort: None, } } @@ -195,16 +122,57 @@ impl AppState { i += 1; } ConversationEvent::Text { content } => { - messages.push(serde_json::json!({ - "role": "assistant", - "content": content - })); - i += 1; + // Check if the next event(s) are ToolCalls — if so, combine + // into a single assistant message with mixed content blocks. + let next_is_tool_call = events + .get(i + 1) + .is_some_and(|e| matches!(e, ConversationEvent::ToolCall { .. })); + + if next_is_tool_call { + let mut content_blocks = Vec::new(); + + if !content.is_empty() { + content_blocks.push(serde_json::json!({ + "type": "text", + "text": content + })); + } + + while let Some(ConversationEvent::ToolCall { + id, name, input, .. + }) = events.get(i + 1) + { + content_blocks.push(serde_json::json!({ + "type": "tool_use", + "id": id, + "name": name, + "input": input + })); + i += 1; + } + + messages.push(serde_json::json!({ + "role": "assistant", + "content": content_blocks + })); + i += 1; + } else { + messages.push(serde_json::json!({ + "role": "assistant", + "content": content + })); + i += 1; + } } ConversationEvent::ToolCall { .. } => { + // ToolCalls without preceding Text (shouldn't normally happen, + // but handle defensively) let mut tool_uses = Vec::new(); while i < events.len() { - if let ConversationEvent::ToolCall { id, name, input } = &events[i] { + if let ConversationEvent::ToolCall { + id, name, input, .. + } = &events[i] + { tool_uses.push(serde_json::json!({ "type": "tool_use", "id": id, @@ -247,53 +215,42 @@ impl AppState { messages } - // ===== Generation lifecycle methods ===== - - /// Start generating from submitted input - pub fn start_generating(&mut self, input: String) { - self.events - .push(ConversationEvent::UserMessage { content: input }); - self.mode = AppMode::Generating; - } - - /// Generation error occurred - pub fn generation_error(&mut self, error: String) { - self.error = Some(error); - self.mode = AppMode::Error; - } - - /// Cancel during generation - pub fn cancel_generation(&mut self) { - if let Some(abort) = self.stream_abort.take() { - abort.abort(); - } - if let Some(ConversationEvent::UserMessage { .. }) = self.events.last() { - self.events.pop(); - } - self.mode = AppMode::Input; - } - - // ===== Streaming lifecycle methods ===== - - /// Start streaming response. - /// Pushes an empty Text event that will be mutated in-place as chunks arrive. - pub fn start_streaming(&mut self) { - self.events.push(ConversationEvent::Text { - content: String::new(), - }); - self.streaming_status = None; - self.was_interrupted = false; - self.mode = AppMode::Streaming; + /// Get the most recent command from events + pub fn current_command(&self) -> Option<&str> { + self.events.iter().rev().find_map(|e| e.as_command()) } - /// Store session ID from server response - pub fn store_session_id(&mut self, session_id: String) { - self.session_id = Some(session_id); + /// Check if any turn in the conversation has a command + pub fn has_any_command(&self) -> bool { + self.events.iter().any(|e| { + if let ConversationEvent::ToolCall { name, input, .. } = e { + name == "suggest_command" && input.get("command").and_then(|v| v.as_str()).is_some() + } else { + false + } + }) } - /// Update streaming status from SSE event - pub fn update_streaming_status(&mut self, status: &str) { - self.streaming_status = Some(StreamingStatus::from_status_str(status)); + /// Check if the most recent command is marked dangerous + pub fn is_current_command_dangerous(&self) -> bool { + self.events + .iter() + .rev() + .find_map(|e| { + if let ConversationEvent::ToolCall { name, input, .. } = e + && name == "suggest_command" + { + let danger_level = input + .get("danger") + .and_then(|v| v.as_str()) + .unwrap_or("low"); + return Some( + danger_level == "high" || danger_level == "medium" || danger_level == "med", + ); + } + None + }) + .unwrap_or(false) } /// Get a mutable reference to the last Text event's content (the streaming buffer). @@ -307,28 +264,15 @@ impl AppState { }) } - /// Cancel streaming with context preservation - pub fn cancel_streaming(&mut self) { - if let Some(abort) = self.stream_abort.take() { - abort.abort(); - } - self.was_interrupted = true; - - if let Some(content) = self.streaming_content_mut() { - let trimmed = content.trim_start().to_string(); - if trimmed.is_empty() { - // Remove the empty text event - *content = String::new(); + /// Remove trailing empty Text events from the events list + fn remove_empty_trailing_text(&mut self) { + while let Some(ConversationEvent::Text { content }) = self.events.last() { + if content.is_empty() { + self.events.pop(); } else { - *content = format!("{trimmed}\n\n[User cancelled this generation]"); + break; } } - // Remove trailing empty Text events - self.remove_empty_trailing_text(); - - self.streaming_status = None; - self.confirmation_pending = false; - self.mode = AppMode::Input; } /// Append text chunk during streaming (mutates the last Text event in-place) @@ -354,26 +298,6 @@ impl AppState { } } - /// Add a tool call event during streaming. - /// The current streaming text is already in events, so we just push the tool call. - pub fn add_tool_call(&mut self, id: String, name: String, input: serde_json::Value) { - // Trim the streaming text event - if let Some(content) = self.streaming_content_mut() { - let trimmed = content.trim_start().to_string(); - *content = trimmed; - } - self.remove_empty_trailing_text(); - - let is_suggest_command = name == "suggest_command"; - self.events - .push(ConversationEvent::ToolCall { id, name, input }); - - if is_suggest_command { - self.streaming_status = None; - self.mode = AppMode::Input; - } - } - /// Add a tool result event during streaming pub fn add_tool_result(&mut self, tool_use_id: String, content: String, is_error: bool) { self.events.push(ConversationEvent::ToolResult { @@ -383,47 +307,9 @@ impl AppState { }); } - /// Finalize streaming — trim the accumulated text and change mode - pub fn finalize_streaming(&mut self) { - if let Some(content) = self.streaming_content_mut() { - let trimmed = content.trim_start().to_string(); - *content = trimmed; - } - self.remove_empty_trailing_text(); - self.streaming_status = None; - self.mode = AppMode::Input; - } - - /// Streaming error — remove the partial text event - pub fn streaming_error(&mut self, error: String) { - self.remove_empty_trailing_text(); - self.error = Some(error); - self.mode = AppMode::Error; - } - - /// Remove trailing empty Text events from the events list - fn remove_empty_trailing_text(&mut self) { - while let Some(ConversationEvent::Text { content }) = self.events.last() { - if content.is_empty() { - self.events.pop(); - } else { - break; - } - } - } - - // ===== Edit mode and exit methods ===== - - /// Start edit mode for refinement - pub fn start_edit_mode(&mut self) { - self.confirmation_pending = false; - self.mode = AppMode::Input; - } - - /// Retry after error - pub fn retry(&mut self) { - self.error = None; - self.mode = AppMode::Generating; + /// Store session ID from server response + pub fn store_session_id(&mut self, session_id: String) { + self.session_id = Some(session_id); } /// Handle a slash command @@ -445,85 +331,247 @@ impl AppState { }), } } +} - // ===== Query methods ===== +/// Ephemeral UI/presentation state +#[derive(Debug)] +pub(crate) struct Interaction { + /// Current application mode + pub mode: AppMode, + /// Whether the input is blank + pub is_input_blank: bool, + /// True when user has pressed Enter once on a dangerous command + pub confirmation_pending: bool, + /// Current streaming status + pub streaming_status: Option<StreamingStatus>, + /// Whether current turn was interrupted by user + pub was_interrupted: bool, + /// Current error message + pub error: Option<String>, +} - /// Get the most recent command from events - pub fn current_command(&self) -> Option<&str> { - self.events.iter().rev().find_map(|e| e.as_command()) +impl Interaction { + pub fn new() -> Self { + Self { + mode: AppMode::Input, + is_input_blank: false, + confirmation_pending: false, + streaming_status: None, + was_interrupted: false, + error: None, + } } +} - /// Check if the most recent command is marked dangerous - pub fn is_current_command_dangerous(&self) -> bool { - self.events - .iter() - .rev() - .find_map(|e| { - if let ConversationEvent::ToolCall { name, input, .. } = e - && name == "suggest_command" - { - let danger_level = input - .get("danger") - .and_then(|v| v.as_str()) - .unwrap_or("low"); - return Some( - danger_level == "high" || danger_level == "medium" || danger_level == "med", - ); - } - None - }) - .unwrap_or(false) +/// Top-level session state +/// +/// Decomposed into `Conversation` (event log + session ID) and +/// `Interaction` (ephemeral UI state). Session methods that cross +/// both sub-structs live here. +#[derive(Debug)] +pub(crate) struct Session { + pub conversation: Conversation, + pub interaction: Interaction, + /// Tracks all tool calls through their full lifecycle. + pub tool_tracker: ToolTracker, + /// Whether the session is running inside a git project (for permission UI labels). + pub in_git_project: bool, + /// Exit action (set when exiting) + pub exit_action: Option<ExitAction>, + /// Abort handle for the active streaming task, if any + pub stream_abort: Option<AbortHandle>, +} + +impl Session { + pub fn new(in_git_project: bool) -> Self { + Self { + conversation: Conversation::new(), + interaction: Interaction::new(), + tool_tracker: ToolTracker::new(), + in_git_project, + exit_action: None, + stream_abort: None, + } } - /// Count non-suggest_command tool calls since the last user message - pub fn tool_count_since_last_user(&self) -> usize { - let last_user_idx = self + // ===== Generation lifecycle methods ===== + + /// Start generating from submitted input + pub fn start_generating(&mut self, input: String) { + self.conversation .events - .iter() - .rposition(|e| matches!(e, ConversationEvent::UserMessage { .. })) - .unwrap_or(0); + .push(ConversationEvent::UserMessage { content: input }); + self.interaction.mode = AppMode::Generating; + } - let mut completed = 0; - let mut in_flight = false; + /// Generation error occurred + #[expect(dead_code)] + pub fn generation_error(&mut self, error: String) { + self.interaction.error = Some(error); + self.interaction.mode = AppMode::Error; + } - for event in &self.events[last_user_idx..] { - match event { - ConversationEvent::ToolCall { name, .. } if name != "suggest_command" => { - if in_flight { - completed += 1; - } - in_flight = true; - } - ConversationEvent::ToolResult { .. } => { - if in_flight { - completed += 1; - in_flight = false; - } - } - _ => {} - } + /// Cancel during generation + pub fn cancel_generation(&mut self) { + if let Some(abort) = self.stream_abort.take() { + abort.abort(); + } + if let Some(ConversationEvent::UserMessage { .. }) = self.conversation.events.last() { + self.conversation.events.pop(); } + self.interaction.mode = AppMode::Input; + } - completed + // ===== Streaming lifecycle methods ===== + + /// Start streaming response. + /// Pushes an empty Text event that will be mutated in-place as chunks arrive. + pub fn start_streaming(&mut self) { + self.conversation.events.push(ConversationEvent::Text { + content: String::new(), + }); + self.interaction.streaming_status = None; + self.interaction.was_interrupted = false; + self.interaction.mode = AppMode::Streaming; } - /// Check if any turn in the conversation has a command - pub fn has_any_command(&self) -> bool { - self.events.iter().any(|e| { - if let ConversationEvent::ToolCall { name, input, .. } = e { - name == "suggest_command" && input.get("command").and_then(|v| v.as_str()).is_some() + /// Update streaming status from SSE event + pub fn update_streaming_status(&mut self, status: &str) { + self.interaction.streaming_status = Some(StreamingStatus::from_status_str(status)); + } + + /// Cancel streaming with context preservation + pub fn cancel_streaming(&mut self) { + if let Some(abort) = self.stream_abort.take() { + abort.abort(); + } + self.interaction.was_interrupted = true; + + if let Some(content) = self.conversation.streaming_content_mut() { + let trimmed = content.trim_start().to_string(); + if trimmed.is_empty() { + // Remove the empty text event + *content = String::new(); } else { - false + *content = format!("{trimmed}\n\n[User cancelled this generation]"); } - }) + } + // Remove trailing empty Text events + self.conversation.remove_empty_trailing_text(); + + self.interaction.streaming_status = None; + self.interaction.confirmation_pending = false; + self.interaction.mode = AppMode::Input; + } + + /// Add a tool call event during streaming. + /// The current streaming text is already in events, so we just push the tool call. + pub fn add_tool_call(&mut self, id: String, name: String, input: serde_json::Value) { + // Trim the streaming text event + if let Some(content) = self.conversation.streaming_content_mut() { + let trimmed = content.trim_start().to_string(); + *content = trimmed; + } + self.conversation.remove_empty_trailing_text(); + + let is_suggest_command = name == "suggest_command"; + self.conversation + .events + .push(ConversationEvent::ToolCall { id, name, input }); + + if is_suggest_command { + self.interaction.streaming_status = None; + self.interaction.mode = AppMode::Input; + } + } + + /// Finalize streaming — trim the accumulated text and change mode + pub fn finalize_streaming(&mut self) { + if let Some(content) = self.conversation.streaming_content_mut() { + let trimmed = content.trim_start().to_string(); + *content = trimmed; + } + self.conversation.remove_empty_trailing_text(); + self.interaction.streaming_status = None; + self.interaction.mode = AppMode::Input; + } + + /// Streaming error — remove the partial text event + pub fn streaming_error(&mut self, error: String) { + self.conversation.remove_empty_trailing_text(); + self.interaction.error = Some(error); + self.interaction.mode = AppMode::Error; + } + + pub(crate) fn handle_client_tool_call( + &mut self, + id: String, + tool: ClientToolCall, + input: serde_json::Value, + ) { + let desc = tool.descriptor(); + let name = desc.canonical_names[0].to_string(); + + self.tool_tracker.insert(id.clone(), tool); + + // Add the ToolCall event to the conversation immediately so it appears + // in the view. Preview data is sourced from tool_tracker. + self.conversation + .events + .push(ConversationEvent::ToolCall { id, name, input }); + + // Client tool calls can only happen at the last part of a turn + self.interaction.streaming_status = None; + self.interaction.mode = AppMode::Input; + } + + /// Retry after error + pub fn retry(&mut self) { + self.interaction.error = None; + self.interaction.mode = AppMode::Generating; + } + + // ===== Tool lifecycle methods ===== + + /// Finish a tool call: transition tracker to Completed, push ToolResult to conversation. + /// + /// For shell commands, captures the final preview from the ExecutingWithPreview phase + /// and patches exit_code/interrupted from the authoritative ToolOutcome. + pub fn finish_tool_call(&mut self, tool_id: &str, outcome: ToolOutcome) { + let mut preview = self.tool_tracker.get(tool_id).and_then(|t| t.preview()); + + // Patch preview with authoritative outcome data (handles race where + // final VT100 update hasn't been applied yet). + if let Some(ref mut p) = preview + && let ToolOutcome::Structured { + exit_code, + interrupted, + .. + } = &outcome + { + p.interrupted = *interrupted; + if p.exit_code.is_none() { + p.exit_code = *exit_code; + } + } + + // Transition tracker entry to Completed + if let Some(tracked) = self.tool_tracker.get_mut(tool_id) { + tracked.complete(preview); + } + + let content = outcome.format_for_llm(); + let is_error = outcome.is_error(); + self.conversation + .add_tool_result(tool_id.to_string(), content, is_error); } /// Get the footer text for current mode pub fn footer_text(&self) -> &'static str { - match self.mode { + match self.interaction.mode { AppMode::Input => { - if self.has_any_command() && self.is_input_blank { - if self.confirmation_pending { + if self.conversation.has_any_command() && self.interaction.is_input_blank { + if self.interaction.confirmation_pending { "[Enter] Confirm dangerous command [Esc] Cancel" } else { "[Enter] Execute suggested command [Tab] Insert Command" @@ -542,9 +590,3 @@ impl AppState { self.exit_action.is_some() } } - -impl Default for AppState { - fn default() -> Self { - Self::new() - } -} |
