diff options
Diffstat (limited to 'crates/atuin-ai/src/tui/state.rs')
| -rw-r--r-- | crates/atuin-ai/src/tui/state.rs | 530 |
1 files changed, 530 insertions, 0 deletions
diff --git a/crates/atuin-ai/src/tui/state.rs b/crates/atuin-ai/src/tui/state.rs new file mode 100644 index 00000000..ba9c8ac6 --- /dev/null +++ b/crates/atuin-ai/src/tui/state.rs @@ -0,0 +1,530 @@ +//! Domain state types for the TUI application +//! +//! This module contains the core state types that represent the application's +//! domain model. Conversation events match the API protocol format. + +use std::time::Instant; +use tui_textarea::TextArea; + +use super::spinner::{ACTIVE_SPINNER, active_tick_interval}; + +/// Streaming status indicators from server +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum StreamingStatus { + Processing, + Searching, + Thinking, + WaitingForTools, +} + +impl StreamingStatus { + pub fn from_status_str(s: &str) -> Self { + match s { + "processing" => Self::Processing, + "searching" => Self::Searching, + "waiting_for_tools" => Self::WaitingForTools, + _ => Self::Thinking, // Default to thinking for "thinking" and unknown + } + } + + pub fn display_text(&self) -> &'static str { + match self { + Self::Processing => "Processing...", + Self::Searching => "Searching...", + Self::Thinking => "Thinking...", + Self::WaitingForTools => "Waiting for tools...", + } + } +} + +/// Conversation event types matching the API protocol +#[derive(Debug, Clone)] +pub enum ConversationEvent { + /// User message (what the user typed) + UserMessage { content: String }, + /// Text content from assistant (streamed or complete) + Text { content: String }, + /// Tool call from assistant + ToolCall { + id: String, + name: String, + input: serde_json::Value, + }, + /// Tool result (usually from server-side execution) + ToolResult { + tool_use_id: String, + content: String, + is_error: bool, + }, +} + +impl ConversationEvent { + /// Convert to JSON for API calls + pub fn to_json(&self) -> serde_json::Value { + match self { + ConversationEvent::UserMessage { content } => serde_json::json!({ + "type": "user_message", + "content": content + }), + ConversationEvent::Text { content } => serde_json::json!({ + "type": "text", + "content": content + }), + ConversationEvent::ToolCall { id, name, input } => serde_json::json!({ + "type": "tool_call", + "id": id, + "name": name, + "input": input + }), + ConversationEvent::ToolResult { + tool_use_id, + content, + is_error, + } => serde_json::json!({ + "type": "tool_result", + "tool_use_id": tool_use_id, + "content": content, + "is_error": is_error + }), + } + } + + /// Extract command from a suggest_command tool call + pub fn as_command(&self) -> Option<&str> { + if let ConversationEvent::ToolCall { name, input, .. } = self + && name == "suggest_command" + { + // command can be null for pure conversational turns + return input.get("command").and_then(|v| v.as_str()); + } + None + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AppMode { + /// User is typing input + Input, + /// Waiting for generation (showing spinner) + Generating, + /// Streaming SSE response + Streaming, + /// Reviewing generated command + Review, + /// Error state, can retry + Error, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ExitAction { + /// Run the command + Execute(String), + /// Insert command without running + Insert(String), + /// User canceled + Cancel, +} + +/// Application state - the domain model +/// +/// Conversation is stored as a sequence of events matching the API protocol. +/// The view model is derived from this state via `Blocks::from_state()`. +pub struct AppState { + /// Current application mode + pub mode: AppMode, + /// Conversation events (source of truth, matches API protocol) + pub events: Vec<ConversationEvent>, + /// Text being streamed (accumulated, flushed to Text event on completion) + pub streaming_text: String, + /// Active text input (uses tui-textarea for proper cursor handling) + pub textarea: TextArea<'static>, + /// Current error message (renders at end of blocks) + pub error: Option<String>, + /// Whether app should exit + pub should_exit: bool, + /// Exit action (set when exiting) + pub exit_action: Option<ExitAction>, + /// Session ID from server (store after first response, send on subsequent) + pub session_id: Option<String>, + /// Current streaming status (for spinner text) + pub streaming_status: Option<StreamingStatus>, + /// Whether current turn was interrupted by user + pub was_interrupted: bool, + /// Spinner animation state + pub spinner_frame: usize, + /// When spinner frame last advanced (for timing control) + pub last_spinner_tick: Instant, + /// When streaming started (for spinner delay) + pub streaming_started: Option<Instant>, + /// True when user has pressed Enter once on a dangerous command + pub confirmation_pending: bool, +} + +/// Create a TextArea with our preferred configuration +fn create_textarea() -> TextArea<'static> { + let mut textarea = TextArea::default(); + // Disable underline on cursor line - it's distracting + textarea.set_cursor_line_style(ratatui::style::Style::default()); + // Enable word wrapping + textarea.set_wrap_mode(tui_textarea::WrapMode::Word); + textarea +} + +impl AppState { + pub fn new() -> Self { + Self { + mode: AppMode::Input, + events: Vec::new(), + streaming_text: String::new(), + textarea: create_textarea(), + error: None, + should_exit: false, + exit_action: None, + session_id: None, + streaming_status: None, + was_interrupted: false, + spinner_frame: 0, + last_spinner_tick: Instant::now(), + streaming_started: None, + confirmation_pending: false, + } + } + + /// Get the current input text + pub fn input(&self) -> String { + self.textarea.lines().join("\n") + } + + /// Check if input is empty + pub fn input_is_empty(&self) -> bool { + self.textarea.is_empty() + } + + /// Clear the input + pub fn clear_input(&mut self) { + self.textarea = create_textarea(); + } + + /// Convert conversation events to Claude API message format + /// Groups consecutive tool calls, handles role alternation + pub fn events_to_messages(&self) -> Vec<serde_json::Value> { + let mut messages = Vec::new(); + let mut i = 0; + let events = &self.events; + + while i < events.len() { + match &events[i] { + ConversationEvent::UserMessage { content } => { + messages.push(serde_json::json!({ + "role": "user", + "content": content + })); + i += 1; + } + ConversationEvent::Text { content } => { + messages.push(serde_json::json!({ + "role": "assistant", + "content": content + })); + i += 1; + } + ConversationEvent::ToolCall { .. } => { + // Group consecutive tool calls into single assistant message + let mut tool_uses = Vec::new(); + while i < events.len() { + if let ConversationEvent::ToolCall { id, name, input } = &events[i] { + tool_uses.push(serde_json::json!({ + "type": "tool_use", + "id": id, + "name": name, + "input": input + })); + i += 1; + } else { + break; + } + } + messages.push(serde_json::json!({ + "role": "assistant", + "content": tool_uses + })); + } + ConversationEvent::ToolResult { + tool_use_id, + content, + is_error, + } => { + messages.push(serde_json::json!({ + "role": "user", + "content": [{ + "type": "tool_result", + "tool_use_id": tool_use_id, + "content": content, + "is_error": is_error + }] + })); + i += 1; + } + } + } + + messages + } + + // ===== Generation lifecycle methods ===== + + /// Start generating from current input + pub fn start_generating(&mut self) { + // Add user message event + self.events.push(ConversationEvent::UserMessage { + content: self.input(), + }); + + // Clear input, switch mode + self.clear_input(); + self.mode = AppMode::Generating; + } + + /// Generation complete with command (legacy method, kept for compatibility) + pub fn generation_complete( + &mut self, + command: String, + explanation: Option<String>, + dangerous: bool, + warnings: Vec<String>, + ) { + // Add explanation as text event if present + if let Some(ref exp) = explanation { + self.events.push(ConversationEvent::Text { + content: exp.clone(), + }); + } + + // Add tool_call event for suggest_command + let tool_id = format!("gen_{}", uuid::Uuid::new_v4().simple()); + let mut tool_input = serde_json::json!({ + "command": command, + "conversation_only": false, + "confidence": "high" + }); + if let Some(ref exp) = explanation { + tool_input["message"] = serde_json::json!(exp); + } + if dangerous { + tool_input["danger"] = serde_json::json!("high"); + } + if !warnings.is_empty() { + tool_input["warning"] = serde_json::json!(warnings.join("; ")); + } + + self.events.push(ConversationEvent::ToolCall { + id: tool_id, + name: "suggest_command".to_string(), + input: tool_input, + }); + + self.mode = AppMode::Review; + } + + /// Generation error occurred + pub fn generation_error(&mut self, error: String) { + self.error = Some(error); + self.mode = AppMode::Error; + } + + /// Cancel during generation + pub fn cancel_generation(&mut self) { + // Remove the last user message since generation was cancelled + if let Some(ConversationEvent::UserMessage { .. }) = self.events.last() { + self.events.pop(); + } + self.mode = AppMode::Input; + self.clear_input(); + } + + // ===== Streaming lifecycle methods ===== + + /// Start streaming response + pub fn start_streaming(&mut self) { + self.streaming_text.clear(); + self.streaming_status = None; + self.was_interrupted = false; + self.streaming_started = Some(Instant::now()); + self.mode = AppMode::Streaming; + } + + /// Store session ID from server response + pub fn store_session_id(&mut self, session_id: String) { + self.session_id = Some(session_id); + } + + /// Update streaming status from SSE event + pub fn update_streaming_status(&mut self, status: &str) { + self.streaming_status = Some(StreamingStatus::from_status_str(status)); + } + + /// Cancel streaming with context preservation + pub fn cancel_streaming(&mut self) { + // Mark as interrupted + self.was_interrupted = true; + + // Flush partial text with interruption marker if any + // Trim leading whitespace since LLM responses often start with \n\n + let content = std::mem::take(&mut self.streaming_text); + let trimmed = content.trim_start(); + if !trimmed.is_empty() { + let interrupted_text = format!("{trimmed}\n\n[User cancelled this generation]"); + self.events.push(ConversationEvent::Text { + content: interrupted_text, + }); + } + + // Clear status and return to input + self.streaming_status = None; + self.confirmation_pending = false; + self.mode = AppMode::Input; + } + + /// Append text chunk during streaming + /// Trims leading whitespace from the first chunk(s) since LLM responses often start with \n\n + pub fn append_streaming_text(&mut self, chunk: &str) { + if self.streaming_text.is_empty() { + // First chunk(s): trim leading whitespace + let trimmed = chunk.trim_start(); + if !trimmed.is_empty() { + self.streaming_text.push_str(trimmed); + } + } else { + // Subsequent chunks: append as-is + self.streaming_text.push_str(chunk); + } + } + + /// Add a tool call event during streaming + /// Flushes any pending streaming text first to maintain correct event order + /// For suggest_command, also transitions to Review mode since that ends the LLM turn + pub fn add_tool_call(&mut self, id: String, name: String, input: serde_json::Value) { + // Flush streaming text before adding tool call to maintain correct order + let content = std::mem::take(&mut self.streaming_text); + let trimmed = content.trim_start(); + if !trimmed.is_empty() { + self.events.push(ConversationEvent::Text { + content: trimmed.to_string(), + }); + } + + // suggest_command marks the end of the LLM turn - transition to Review + let is_suggest_command = name == "suggest_command"; + + self.events + .push(ConversationEvent::ToolCall { id, name, input }); + + if is_suggest_command { + self.streaming_status = None; + self.streaming_started = None; + self.mode = AppMode::Review; + } + } + + /// Add a tool result event during streaming + pub fn add_tool_result(&mut self, tool_use_id: String, content: String, is_error: bool) { + self.events.push(ConversationEvent::ToolResult { + tool_use_id, + content, + is_error, + }); + } + + /// Finalize streaming - flush accumulated text to event + pub fn finalize_streaming(&mut self) { + // Flush streaming text to a Text event if non-empty + // Trim leading whitespace since LLM responses often start with \n\n + let content = std::mem::take(&mut self.streaming_text); + let trimmed = content.trim_start(); + if !trimmed.is_empty() { + self.events.push(ConversationEvent::Text { + content: trimmed.to_string(), + }); + } + self.streaming_status = None; + self.streaming_started = None; + self.mode = AppMode::Review; + } + + /// Streaming error + pub fn streaming_error(&mut self, error: String) { + // Discard any partial streaming text + self.streaming_text.clear(); + self.streaming_started = None; + self.error = Some(error); + self.mode = AppMode::Error; + } + + // ===== Edit mode and exit methods ===== + + /// Start edit mode for refinement + pub fn start_edit_mode(&mut self) { + self.confirmation_pending = false; + self.clear_input(); + self.mode = AppMode::Input; + } + + /// Exit with action + pub fn exit(&mut self, action: ExitAction) { + self.exit_action = Some(action); + self.should_exit = true; + } + + /// Retry after error + pub fn retry(&mut self) { + self.error = None; + self.mode = AppMode::Generating; + } + + // ===== Utility methods ===== + + /// Advance spinner frame if enough time has passed + /// Called on every event loop tick (50ms), but only advances spinner + /// when the active spinner's interval has elapsed + pub fn tick(&mut self) { + let interval = active_tick_interval(); + if self.last_spinner_tick.elapsed() >= interval { + self.spinner_frame = (self.spinner_frame + 1) % ACTIVE_SPINNER.frame_count(); + self.last_spinner_tick = Instant::now(); + } + } + + /// Get the most recent command from events + pub fn current_command(&self) -> Option<&str> { + self.events.iter().rev().find_map(|e| e.as_command()) + } + + /// Check if the most recent command suggestion is marked dangerous + /// Checks the `danger` field for "high", "medium", or "med" values + pub fn is_current_command_dangerous(&self) -> bool { + self.events + .iter() + .rev() + .find_map(|e| { + if let ConversationEvent::ToolCall { name, input, .. } = e + && name == "suggest_command" + { + let danger_level = input + .get("danger") + .and_then(|v| v.as_str()) + .unwrap_or("low"); + return Some( + danger_level == "high" || danger_level == "medium" || danger_level == "med", + ); + } + None + }) + .unwrap_or(false) + } +} + +impl Default for AppState { + fn default() -> Self { + Self::new() + } +} |
