diff options
| author | Michelle Tilley <michelle@michelletilley.net> | 2026-04-14 16:03:08 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-04-15 00:03:08 +0100 |
| commit | fd188da879d977ca847f10708c39dd4801a204c4 (patch) | |
| tree | 592bfe2644f8bd9be3563f176eabf29e55fa9a9b /crates | |
| parent | fix: dependency fix (#3414) (diff) | |
| download | atuin-fd188da879d977ca847f10708c39dd4801a204c4.zip | |
feat: Allow resuming previous AI sessions (#3407)
This PR introduces session continuation to Atuin AI.
* Conversations with Atuin AI are stored in a local SQLite database
* Upon startup, Atuin AI tries to find a session to resume based on its
directory/workspace and the time since the last event
* If found, Atuin AI will show a note that the session has been resumed,
and an event is added to help the LLM know where the invocation
boundaries are
* If not, Atuin AI will create a new conversation
* The user can create a new conversation with `/new`
* The new setting `ai.session_continue_minutes`, which defaults to `60`,
controls how old the last event in a session can be before it's no
longer considered for automatic resuming.
<img width="1055" height="593" alt="image"
src="https://github.com/user-attachments/assets/3f9ff01a-ef64-44a9-b0e2-3a4252c5746f"
/>
## Architecture
A new `SessionService` trait defines an API contract for a service that
can manage session data. `LocalSessionService` implements this, with
`DaemonSessionService` a possible future extension point.
`SessionManager` owns a `dyn SessionService` and delegates as
appropriate.
Diffstat (limited to 'crates')
| -rw-r--r-- | crates/atuin-ai/Cargo.toml | 7 | ||||
| -rw-r--r-- | crates/atuin-ai/migrations/20260413000000_create_ai_sessions.sql | 32 | ||||
| -rw-r--r-- | crates/atuin-ai/src/commands/inline.rs | 74 | ||||
| -rw-r--r-- | crates/atuin-ai/src/context_window.rs | 578 | ||||
| -rw-r--r-- | crates/atuin-ai/src/event_serde.rs | 376 | ||||
| -rw-r--r-- | crates/atuin-ai/src/lib.rs | 4 | ||||
| -rw-r--r-- | crates/atuin-ai/src/session.rs | 482 | ||||
| -rw-r--r-- | crates/atuin-ai/src/store.rs | 522 | ||||
| -rw-r--r-- | crates/atuin-ai/src/stream.rs | 10 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tui/components/input_box.rs | 16 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tui/components/mod.rs | 1 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tui/components/session_continue.rs | 49 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tui/content/help.md | 3 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tui/dispatch.rs | 117 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tui/mod.rs | 3 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tui/slash.rs | 79 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tui/state.rs | 337 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tui/view/mod.rs | 37 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tui/view/turn.rs | 3 | ||||
| -rw-r--r-- | crates/atuin-client/src/settings.rs | 9 |
20 files changed, 2585 insertions, 154 deletions
diff --git a/crates/atuin-ai/Cargo.toml b/crates/atuin-ai/Cargo.toml index c5f66695..3bdd45d2 100644 --- a/crates/atuin-ai/Cargo.toml +++ b/crates/atuin-ai/Cargo.toml @@ -17,6 +17,7 @@ default = [] tree-sitter = ["dep:tree-sitter-lib", "dep:tree-sitter-bash", "dep:tree-sitter-fish"] [dependencies] +async-trait = { workspace = true } atuin-client = { workspace = true } atuin-common = { workspace = true } tokio = { workspace = true } @@ -36,6 +37,7 @@ serde = { workspace = true } serde_json = { workspace = true } crossterm = { workspace = true, features = ["use-dev-tty", "event-stream"] } ratatui = { workspace = true } +fs-err = { workspace = true } futures = "0.3" eventsource-stream = "0.2" pulldown-cmark = "0.13.0" @@ -43,7 +45,7 @@ async-stream = "0.3" uuid = { workspace = true } tui-textarea-2 = "0.10.2" unicode-width = "0.2" -eye_declare = "0.4" +eye_declare = "0.4.2" ratatui-core = "0.1" ratatui-widgets = "0.3" thiserror = { workspace = true } @@ -55,8 +57,11 @@ toml_edit = { workspace = true } tree-sitter-lib = { package = "tree-sitter", version = "0.26.8", optional = true } tree-sitter-bash = { version = "0.25.1", optional = true } tree-sitter-fish = { version = "3.6.0", optional = true } +sqlx = { workspace = true, features = ["sqlite"] } typed-builder = { workspace = true } vt100 = { workspace = true } +chrono = "0.4" +chrono-humanize = "0.2" [dev-dependencies] pretty_assertions = { workspace = true } diff --git a/crates/atuin-ai/migrations/20260413000000_create_ai_sessions.sql b/crates/atuin-ai/migrations/20260413000000_create_ai_sessions.sql new file mode 100644 index 00000000..906a5726 --- /dev/null +++ b/crates/atuin-ai/migrations/20260413000000_create_ai_sessions.sql @@ -0,0 +1,32 @@ +CREATE TABLE IF NOT EXISTS sessions ( + id TEXT PRIMARY KEY, + head_id TEXT, + server_session_id TEXT, + directory TEXT, + git_root TEXT, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL, + archived_at INTEGER +); + +CREATE INDEX idx_sessions_directory ON sessions(directory); +CREATE INDEX idx_sessions_git_root ON sessions(git_root); +CREATE INDEX idx_sessions_updated_at ON sessions(updated_at); +CREATE INDEX idx_sessions_created_at ON sessions(created_at); + +CREATE TABLE IF NOT EXISTS session_events ( + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + parent_id TEXT, + invocation_id TEXT NOT NULL, + event_type TEXT NOT NULL, + event_data TEXT NOT NULL, + created_at INTEGER NOT NULL, + + FOREIGN KEY (session_id) REFERENCES sessions(id) +); + +CREATE INDEX idx_session_events_session_id ON session_events(session_id); +CREATE INDEX idx_session_events_parent_id ON session_events(parent_id); +CREATE INDEX idx_session_events_invocation_id ON session_events(invocation_id); +CREATE INDEX idx_session_events_created_at ON session_events(created_at); diff --git a/crates/atuin-ai/src/commands/inline.rs b/crates/atuin-ai/src/commands/inline.rs index b37bb72f..2e6beca2 100644 --- a/crates/atuin-ai/src/commands/inline.rs +++ b/crates/atuin-ai/src/commands/inline.rs @@ -2,6 +2,7 @@ use std::path::PathBuf; use std::sync::mpsc; use crate::context::{AppContext, ClientContext}; +use crate::session::{LocalSessionService, SessionManager, SessionService}; use crate::tui::dispatch; use crate::tui::events::AiTuiEvent; use crate::tui::state::{ExitAction, Session}; @@ -83,7 +84,7 @@ pub(crate) async fn run( capabilities: settings.ai.capabilities.clone(), }; - let action = run_inline_tui(ctx, initial_command).await?; + let action = run_inline_tui(ctx, initial_command, settings).await?; emit_shell_result(action, output_for_hook); Ok(()) @@ -147,12 +148,74 @@ async fn ensure_hub_session(settings: &atuin_client::settings::Settings) -> Resu // ─────────────────────────────────────────────────────────────────── -async fn run_inline_tui(ctx: AppContext, initial_prompt: Option<String>) -> Result<Action> { +async fn run_inline_tui( + ctx: AppContext, + initial_prompt: Option<String>, + settings: &atuin_client::settings::Settings, +) -> Result<Action> { let client_ctx = ClientContext::detect(); - let (tx, rx) = mpsc::channel::<AiTuiEvent>(); + // Open the session service and check for a resumable session + let service = LocalSessionService::open(&settings.ai.db_path, settings.local_timeout) + .await + .context("failed to open AI session database")?; + + let cwd = std::env::current_dir() + .ok() + .map(|p| p.to_string_lossy().into_owned()); + let git_root_str = ctx + .git_root + .as_ref() + .map(|p| p.to_string_lossy().into_owned()); + + let session_window_mins = settings.ai.session_continue_minutes.max(0); // treat negative values as 0 to avoid confusion + let max_age_secs: i64 = session_window_mins * 60; + + let resumable = service + .find_resumable(cwd.as_deref(), git_root_str.as_deref(), max_age_secs) + .await?; - let initial_state = Session::new(ctx.git_root.is_some()); + let (session_mgr, initial_state) = if let Some(stored) = resumable { + debug!(session_id = %stored.id, "resuming AI session"); + let (mgr, events, server_sid, last_event_ts, invocation_id) = + SessionManager::resume(Box::new(service), &stored).await?; + + // Only treat this as a meaningful resume if there are API-visible events + // (not just OutOfBandOutput or SystemContext). + let has_api_content = events.iter().any(|e| e.is_api_content()); + + if has_api_content { + let mut session = Session::new(ctx.git_root.is_some(), Some(invocation_id)); + session.conversation.events = events; + session.conversation.session_id = server_sid; + // Inject an invocation boundary so the LLM knows prior messages + // are from an earlier interaction. + session.conversation.events.push( + crate::tui::state::ConversationEvent::SystemContext { + content: "[Note: The user has started a new invocation of Atuin AI. Prior messages from this session are from an earlier invocation.]".to_string(), + }, + ); + session.view_start_index = session.conversation.events.len(); + session.is_resumed = true; + session.last_event_time = + last_event_ts.and_then(|ts| chrono::DateTime::from_timestamp(ts, 0)); + (mgr, session) + } else { + // No meaningful content — treat as a fresh session + debug!("resumable session has no API-visible content, starting fresh"); + ( + mgr, + Session::new(ctx.git_root.is_some(), Some(invocation_id)), + ) + } + } else { + debug!("creating new AI session"); + let mgr = + SessionManager::create_new(Box::new(service), cwd.as_deref(), git_root_str.as_deref()); + (mgr, Session::new(ctx.git_root.is_some(), None)) + }; + + let (tx, rx) = mpsc::channel::<AiTuiEvent>(); println!(); @@ -177,8 +240,9 @@ async fn run_inline_tui(ctx: AppContext, initial_prompt: Option<String>) -> Resu tokio::task::spawn_blocking(move || { let tx = tx.clone(); let client_ctx = client_ctx; + let mut session_mgr = session_mgr; while let Ok(event) = rx.recv() { - dispatch::dispatch(&h, event, &tx, &ctx, &client_ctx); + dispatch::dispatch(&h, event, &tx, &ctx, &client_ctx, &mut session_mgr); } }); diff --git a/crates/atuin-ai/src/context_window.rs b/crates/atuin-ai/src/context_window.rs new file mode 100644 index 00000000..dcef05aa --- /dev/null +++ b/crates/atuin-ai/src/context_window.rs @@ -0,0 +1,578 @@ +//! Context window management for API requests. +//! +//! Full conversation events are always persisted to disk. This module handles +//! truncation at send time so the API payload stays within a character budget. +//! +//! Strategy: **frozen prefix + live tail**. The first N turns form a stable +//! prefix that stays identical across requests (maximizing prompt cache hits). +//! The most recent turns form the live tail. When the total exceeds the budget, +//! turns between prefix and tail are dropped with a truncation marker. The +//! prefix never shifts, avoiding cache invalidation. + +use std::ops::Range; + +use crate::tui::{ConversationEvent, events_to_messages}; + +/// Default character budget for the context window. +/// Roughly ~50K tokens at ~4 chars/token — generous enough that truncation +/// only kicks in for genuinely long sessions. +const DEFAULT_BUDGET_CHARS: usize = 200_000; + +/// Number of initial turns to freeze as the stable prefix. +const FROZEN_PREFIX_TURNS: usize = 1; + +/// Builds API messages from conversation events while respecting a character +/// budget using frozen prefix + live tail truncation. +pub(crate) struct ContextWindowBuilder { + budget: usize, +} + +impl ContextWindowBuilder { + pub fn new(budget: usize) -> Self { + Self { budget } + } + + pub fn with_default_budget() -> Self { + Self::new(DEFAULT_BUDGET_CHARS) + } + + /// Build API messages from conversation events, applying the context + /// window budget. Returns the messages to send in the API request. + pub fn build(&self, events: &[ConversationEvent]) -> Vec<serde_json::Value> { + if events.is_empty() { + return Vec::new(); + } + + let turns = group_into_turns(events); + + // Convert each turn's events to API messages independently. + // This is safe because the combining logic (Text + ToolCall merging) + // only operates within a single assistant response, which never + // spans turn boundaries. + let turn_messages: Vec<Vec<serde_json::Value>> = turns + .iter() + .map(|range| events_to_messages(&events[range.clone()])) + .collect(); + + let turn_chars: Vec<usize> = turn_messages.iter().map(|m| estimate_chars(m)).collect(); + let total_chars: usize = turn_chars.iter().sum(); + + if total_chars <= self.budget { + return turn_messages.into_iter().flatten().collect(); + } + + // --- Over budget: apply frozen prefix + live tail --- + + let prefix_count = FROZEN_PREFIX_TURNS.min(turns.len()); + let prefix_chars: usize = turn_chars[..prefix_count].iter().sum(); + + let marker = truncation_marker(); + let marker_chars = estimate_chars(std::slice::from_ref(&marker)); + + let mut remaining = self.budget.saturating_sub(prefix_chars + marker_chars); + + // Work backwards from the end, accumulating tail turns that fit. + let mut tail_start = turns.len(); + for i in (prefix_count..turns.len()).rev() { + if turn_chars[i] <= remaining { + remaining -= turn_chars[i]; + tail_start = i; + } else { + break; + } + } + + // Always include at least the most recent turn, even if it alone + // exceeds the budget — sending something is better than nothing. + if tail_start >= turns.len() && turns.len() > prefix_count { + tail_start = turns.len() - 1; + } + + let mut result = Vec::new(); + + // Frozen prefix + for msgs in &turn_messages[..prefix_count] { + result.extend(msgs.iter().cloned()); + } + + // Truncation marker (only if turns were actually dropped) + if tail_start > prefix_count { + result.push(marker); + } + + // Live tail + for msgs in &turn_messages[tail_start..] { + result.extend(msgs.iter().cloned()); + } + + result + } +} + +/// Marker message inserted where turns were dropped. Uses user role since +/// the preceding prefix typically ends with an assistant message. +fn truncation_marker() -> serde_json::Value { + serde_json::json!({ + "role": "user", + "content": "[Earlier conversation context was omitted to fit within the context window. The conversation continues below.]" + }) +} + +/// Group conversation events into turns. A new turn starts at each +/// `UserMessage` or `SystemContext` event. Everything between boundaries +/// belongs to the preceding turn (assistant text, tool calls, tool results, +/// out-of-band output). +fn group_into_turns(events: &[ConversationEvent]) -> Vec<Range<usize>> { + let mut turns = Vec::new(); + let mut start = 0; + + for (i, event) in events.iter().enumerate() { + if i > start + && matches!( + event, + ConversationEvent::UserMessage { .. } | ConversationEvent::SystemContext { .. } + ) + { + turns.push(start..i); + start = i; + } + } + + if start < events.len() { + turns.push(start..events.len()); + } + + turns +} + +/// Rough character-count estimate for a set of messages. Uses the JSON +/// serialization length as a proxy — not exact tokens, but proportional +/// and cheap to compute. +fn estimate_chars(messages: &[serde_json::Value]) -> usize { + messages.iter().map(|m| m.to_string().len()).sum() +} + +#[cfg(test)] +mod tests { + use super::*; + + fn user(content: &str) -> ConversationEvent { + ConversationEvent::UserMessage { + content: content.to_string(), + } + } + + fn text(content: &str) -> ConversationEvent { + ConversationEvent::Text { + content: content.to_string(), + } + } + + fn tool_call(id: &str, name: &str) -> ConversationEvent { + ConversationEvent::ToolCall { + id: id.to_string(), + name: name.to_string(), + input: serde_json::json!({"command": "ls"}), + } + } + + fn tool_result(tool_use_id: &str, content: &str) -> ConversationEvent { + ConversationEvent::ToolResult { + tool_use_id: tool_use_id.to_string(), + content: content.to_string(), + is_error: false, + remote: false, + content_length: None, + } + } + + fn system_context(content: &str) -> ConversationEvent { + ConversationEvent::SystemContext { + content: content.to_string(), + } + } + + fn oob(content: &str) -> ConversationEvent { + ConversationEvent::OutOfBandOutput { + name: "test".to_string(), + command: None, + content: content.to_string(), + } + } + + // --- group_into_turns --- + + #[test] + fn empty_events_produce_no_turns() { + assert!(group_into_turns(&[]).is_empty()); + } + + #[test] + fn single_user_message_is_one_turn() { + let events = vec![user("hello")]; + let turns = group_into_turns(&events); + assert_eq!(turns, vec![0..1]); + } + + #[test] + fn user_assistant_is_one_turn() { + let events = vec![user("hello"), text("hi there")]; + let turns = group_into_turns(&events); + assert_eq!(turns, vec![0..2]); + } + + #[test] + fn two_turns_split_at_user_message() { + let events = vec![ + user("first"), + text("response 1"), + user("second"), + text("response 2"), + ]; + let turns = group_into_turns(&events); + assert_eq!(turns, vec![0..2, 2..4]); + } + + #[test] + fn tool_calls_and_results_stay_in_same_turn() { + let events = vec![ + user("list files"), + text("Let me check"), + tool_call("tc1", "suggest_command"), + tool_result("tc1", "file1\nfile2"), + text("Here are your files"), + ]; + let turns = group_into_turns(&events); + assert_eq!(turns, vec![0..5]); + } + + #[test] + fn system_context_starts_new_turn() { + let events = vec![ + user("hello"), + text("hi"), + system_context("invocation boundary"), + user("next question"), + text("answer"), + ]; + let turns = group_into_turns(&events); + assert_eq!(turns, vec![0..2, 2..3, 3..5]); + } + + #[test] + fn oob_events_stay_in_current_turn() { + let events = vec![user("hello"), oob("some output"), text("response")]; + let turns = group_into_turns(&events); + assert_eq!(turns, vec![0..3]); + } + + #[test] + fn leading_text_without_user_message() { + // Edge case: events start with assistant text (shouldn't happen + // normally but handle gracefully) + let events = vec![text("orphaned"), user("hello"), text("hi")]; + let turns = group_into_turns(&events); + assert_eq!(turns, vec![0..1, 1..3]); + } + + // --- ContextWindowBuilder --- + + #[test] + fn empty_events_produce_empty_messages() { + let builder = ContextWindowBuilder::with_default_budget(); + assert!(builder.build(&[]).is_empty()); + } + + #[test] + fn under_budget_returns_all_messages() { + let events = vec![user("hello"), text("hi"), user("how are you"), text("good")]; + let builder = ContextWindowBuilder::with_default_budget(); + let messages = builder.build(&events); + + // Should produce 4 messages (2 user + 2 assistant) + assert_eq!(messages.len(), 4); + assert_eq!(messages[0]["role"], "user"); + assert_eq!(messages[0]["content"], "hello"); + assert_eq!(messages[1]["role"], "assistant"); + assert_eq!(messages[1]["content"], "hi"); + assert_eq!(messages[2]["role"], "user"); + assert_eq!(messages[2]["content"], "how are you"); + assert_eq!(messages[3]["role"], "assistant"); + assert_eq!(messages[3]["content"], "good"); + } + + #[test] + fn over_budget_truncates_middle_turns() { + // Create events where each turn has known content. Use a tiny + // budget so truncation is triggered with just a few turns. + let events = vec![ + user("turn-1-user"), + text("turn-1-assistant"), + user("turn-2-user"), + text("turn-2-assistant"), + user("turn-3-user"), + text("turn-3-assistant"), + user("turn-4-user"), + text("turn-4-assistant-final"), + ]; + + // Calculate sizes to set budget that keeps turn 1 (prefix) + turn 4 (tail) + // but drops turns 2 and 3. + let all_messages = events_to_messages(&events); + let total_chars: usize = all_messages.iter().map(|m| m.to_string().len()).sum(); + + // Set budget to roughly half — enough for prefix + last turn + marker + let turn1_msgs = events_to_messages(&events[0..2]); + let turn4_msgs = events_to_messages(&events[6..8]); + let marker_chars = estimate_chars(std::slice::from_ref(&truncation_marker())); + let needed = estimate_chars(&turn1_msgs) + estimate_chars(&turn4_msgs) + marker_chars; + + // Budget allows prefix + marker + last turn but not the middle turns + assert!( + needed < total_chars, + "test setup: needed ({needed}) should be less than total ({total_chars})" + ); + let builder = ContextWindowBuilder::new(needed + 10); // small margin + + let messages = builder.build(&events); + + // Should have: turn 1 (2 msgs) + marker (1 msg) + turn 4 (2 msgs) = 5 + assert_eq!(messages.len(), 5, "expected prefix + marker + tail"); + assert_eq!(messages[0]["content"], "turn-1-user"); + assert_eq!(messages[1]["content"], "turn-1-assistant"); + assert!( + messages[2]["content"].as_str().unwrap().contains("omitted"), + "middle message should be truncation marker" + ); + assert_eq!(messages[3]["content"], "turn-4-user"); + assert_eq!(messages[4]["content"], "turn-4-assistant-final"); + } + + #[test] + fn very_tight_budget_keeps_prefix_and_last_turn() { + let events = vec![ + user("first"), + text("response-1"), + user("second"), + text("response-2"), + user("third"), + text("response-3"), + ]; + + // Budget of 1 — forces the "always include last turn" fallback + let builder = ContextWindowBuilder::new(1); + let messages = builder.build(&events); + + // Should have prefix (turn 1) + marker + last turn (turn 3) + assert!( + messages.len() >= 3, + "should have at least prefix + marker + tail" + ); + + // First message should be from turn 1 + assert_eq!(messages[0]["content"], "first"); + + // Last messages should be from the final turn + let last = messages.last().unwrap(); + assert_eq!(last["content"], "response-3"); + } + + #[test] + fn single_turn_always_returned() { + let events = vec![user("hello"), text("hi there")]; + + // Even with a tiny budget, the single turn must be returned + let builder = ContextWindowBuilder::new(1); + let messages = builder.build(&events); + assert_eq!(messages.len(), 2); + } + + #[test] + fn tool_calls_preserved_through_truncation() { + let events = vec![ + // Turn 1: simple exchange + user("turn 1"), + text("response 1"), + // Turn 2: with tool calls (will be dropped) + user("turn 2"), + text("checking"), + tool_call("tc1", "suggest_command"), + tool_result("tc1", "output"), + text("done"), + // Turn 3: final turn (kept in tail) + user("turn 3"), + text("final response"), + ]; + + // Budget that fits turn 1 + turn 3 + marker but not turn 2 + let turn1 = events_to_messages(&events[0..2]); + let turn3 = events_to_messages(&events[7..9]); + let marker_cost = estimate_chars(std::slice::from_ref(&truncation_marker())); + let budget = estimate_chars(&turn1) + estimate_chars(&turn3) + marker_cost + 10; + + let builder = ContextWindowBuilder::new(budget); + let messages = builder.build(&events); + + // Verify turn 2 (the tool call turn) was dropped + let has_tool_use = messages.iter().any(|m| { + m["content"] + .as_array() + .is_some_and(|arr| arr.iter().any(|b| b["type"] == "tool_use")) + }); + assert!(!has_tool_use, "tool call turn should have been truncated"); + + // Verify first and last turns present + assert_eq!(messages[0]["content"], "turn 1"); + assert_eq!(messages.last().unwrap()["content"], "final response"); + } + + #[test] + fn tail_accumulates_multiple_turns_when_budget_allows() { + // Use long content so turn sizes dwarf the truncation marker. + let padding = "x".repeat(500); + let events = vec![ + user(&format!("turn-1-user-{padding}")), + text(&format!("turn-1-response-{padding}")), + user(&format!("turn-2-user-{padding}")), + text(&format!("turn-2-response-{padding}")), + user(&format!("turn-3-user-{padding}")), + text(&format!("turn-3-response-{padding}")), + user(&format!("turn-4-user-{padding}")), + text(&format!("turn-4-response-{padding}")), + ]; + + // Budget that fits everything except turn 2 + let all = events_to_messages(&events); + let total = estimate_chars(&all); + let turn2 = events_to_messages(&events[2..4]); + let turn2_chars = estimate_chars(&turn2); + + let marker_cost = estimate_chars(std::slice::from_ref(&truncation_marker())); + let budget = total - turn2_chars + marker_cost + 5; + assert!( + budget < total, + "budget must be less than total for truncation to trigger" + ); + + let builder = ContextWindowBuilder::new(budget); + let messages = builder.build(&events); + + // Should have: prefix (t1: 2 msgs) + marker (1 msg) + t3 (2 msgs) + t4 (2 msgs) = 7 + // (turn 2 dropped) + assert_eq!(messages.len(), 7); + assert!( + messages[0]["content"] + .as_str() + .unwrap() + .starts_with("turn-1-user-") + ); + assert!( + messages[1]["content"] + .as_str() + .unwrap() + .starts_with("turn-1-response-") + ); + assert!(messages[2]["content"].as_str().unwrap().contains("omitted")); + assert!( + messages[3]["content"] + .as_str() + .unwrap() + .starts_with("turn-3-user-") + ); + assert!( + messages[4]["content"] + .as_str() + .unwrap() + .starts_with("turn-3-response-") + ); + assert!( + messages[5]["content"] + .as_str() + .unwrap() + .starts_with("turn-4-user-") + ); + assert!( + messages[6]["content"] + .as_str() + .unwrap() + .starts_with("turn-4-response-") + ); + } + + #[test] + fn no_marker_when_no_turns_dropped() { + // Two turns, both fit in budget + let events = vec![user("a"), text("b"), user("c"), text("d")]; + + let builder = ContextWindowBuilder::with_default_budget(); + let messages = builder.build(&events); + + // No truncation marker + assert_eq!(messages.len(), 4); + assert!( + !messages + .iter() + .any(|m| m["content"].as_str().is_some_and(|s| s.contains("omitted"))) + ); + } + + #[test] + fn tool_use_and_tool_result_never_split() { + // Invariant: a tool_use and its matching tool_result must always + // end up in the same turn, so truncation can't orphan one from + // the other. This test verifies that ToolResult does NOT start + // a new turn boundary. + let padding = "x".repeat(500); + let events = vec![ + // Turn 1 (prefix) + user(&format!("turn-1-{padding}")), + text(&format!("resp-1-{padding}")), + // Turn 2: contains a tool_use → tool_result pair (will be dropped) + user(&format!("turn-2-{padding}")), + text("checking"), + tool_call("tc1", "suggest_command"), + tool_result("tc1", &format!("output-{padding}")), + text(&format!("done-{padding}")), + // Turn 3 (tail) + user(&format!("turn-3-{padding}")), + text(&format!("resp-3-{padding}")), + ]; + + // Budget that fits turn 1 + turn 3 + marker, but not turn 2 + let turn1 = events_to_messages(&events[0..2]); + let turn3 = events_to_messages(&events[7..9]); + let marker_cost = estimate_chars(std::slice::from_ref(&truncation_marker())); + let budget = estimate_chars(&turn1) + estimate_chars(&turn3) + marker_cost + 10; + + let builder = ContextWindowBuilder::new(budget); + let messages = builder.build(&events); + + // Verify: every tool_use has a matching tool_result, and vice versa + let tool_use_ids: Vec<&str> = messages + .iter() + .filter_map(|m| m["content"].as_array()) + .flatten() + .filter(|b| b["type"] == "tool_use") + .filter_map(|b| b["id"].as_str()) + .collect(); + + let tool_result_ids: Vec<&str> = messages + .iter() + .filter_map(|m| m["content"].as_array()) + .flatten() + .filter(|b| b["type"] == "tool_result") + .filter_map(|b| b["tool_use_id"].as_str()) + .collect(); + + assert_eq!( + tool_use_ids, tool_result_ids, + "every tool_use must have a matching tool_result (and vice versa)" + ); + + // Turn 2 was dropped entirely, so no tool IDs should be present + assert!( + !tool_use_ids.contains(&"tc1"), + "dropped turn's tool_use should not appear" + ); + } +} diff --git a/crates/atuin-ai/src/event_serde.rs b/crates/atuin-ai/src/event_serde.rs new file mode 100644 index 00000000..546d6e5b --- /dev/null +++ b/crates/atuin-ai/src/event_serde.rs @@ -0,0 +1,376 @@ +//! Manual serialization for ConversationEvent to/from storage format. +//! +//! The storage format is decoupled from the Rust enum so the two can evolve +//! independently. Each event is stored as an `(event_type, event_data)` pair +//! where `event_data` is a JSON string. + +use eyre::{Result, eyre}; +use serde_json::Value; + +use crate::tui::ConversationEvent; + +/// Serialize a ConversationEvent into an (event_type, event_data_json) pair +/// suitable for database storage. +pub(crate) fn serialize_event(event: &ConversationEvent) -> (String, String) { + match event { + ConversationEvent::UserMessage { content } => ( + "user_message".to_string(), + serde_json::json!({ "content": content }).to_string(), + ), + ConversationEvent::Text { content } => ( + "text".to_string(), + serde_json::json!({ "content": content }).to_string(), + ), + ConversationEvent::ToolCall { id, name, input } => ( + "tool_call".to_string(), + serde_json::json!({ + "id": id, + "name": name, + "input": input, + }) + .to_string(), + ), + ConversationEvent::ToolResult { + tool_use_id, + content, + is_error, + remote, + content_length, + } => ( + "tool_result".to_string(), + serde_json::json!({ + "tool_use_id": tool_use_id, + "content": content, + "is_error": is_error, + "remote": remote, + "content_length": content_length, + }) + .to_string(), + ), + ConversationEvent::OutOfBandOutput { + name, + command, + content, + } => ( + "out_of_band_output".to_string(), + serde_json::json!({ + "name": name, + "command": command, + "content": content, + }) + .to_string(), + ), + ConversationEvent::SystemContext { content } => ( + "system_context".to_string(), + serde_json::json!({ "content": content }).to_string(), + ), + } +} + +/// Deserialize an (event_type, event_data_json) pair from storage back into a +/// ConversationEvent. +pub(crate) fn deserialize_event(event_type: &str, event_data: &str) -> Result<ConversationEvent> { + let data: Value = serde_json::from_str(event_data) + .map_err(|e| eyre!("failed to parse event_data JSON: {e}"))?; + + match event_type { + "user_message" => Ok(ConversationEvent::UserMessage { + content: json_string(&data, "content")?, + }), + "text" => Ok(ConversationEvent::Text { + content: json_string(&data, "content")?, + }), + "tool_call" => Ok(ConversationEvent::ToolCall { + id: json_string(&data, "id")?, + name: json_string(&data, "name")?, + input: data + .get("input") + .cloned() + .ok_or_else(|| eyre!("tool_call missing 'input' field"))?, + }), + "tool_result" => Ok(ConversationEvent::ToolResult { + tool_use_id: json_string(&data, "tool_use_id")?, + content: json_string(&data, "content")?, + is_error: data + .get("is_error") + .and_then(Value::as_bool) + .ok_or_else(|| eyre!("tool_result missing 'is_error' field"))?, + remote: data.get("remote").and_then(Value::as_bool).unwrap_or(false), + content_length: data + .get("content_length") + .and_then(Value::as_u64) + .map(|v| v as usize), + }), + "out_of_band_output" => Ok(ConversationEvent::OutOfBandOutput { + name: json_string(&data, "name")?, + command: data + .get("command") + .and_then(|v| if v.is_null() { None } else { v.as_str() }) + .map(String::from), + content: json_string(&data, "content")?, + }), + "system_context" => Ok(ConversationEvent::SystemContext { + content: json_string(&data, "content")?, + }), + other => Err(eyre!("unknown event type: {other}")), + } +} + +fn json_string(data: &Value, field: &str) -> Result<String> { + data.get(field) + .and_then(Value::as_str) + .map(String::from) + .ok_or_else(|| eyre!("missing or non-string field '{field}'")) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn round_trip(event: &ConversationEvent) -> ConversationEvent { + let (event_type, event_data) = serialize_event(event); + deserialize_event(&event_type, &event_data).unwrap() + } + + #[test] + fn test_user_message() { + let event = ConversationEvent::UserMessage { + content: "hello world".to_string(), + }; + let result = round_trip(&event); + assert!( + matches!(result, ConversationEvent::UserMessage { content } if content == "hello world") + ); + } + + #[test] + fn test_text() { + let event = ConversationEvent::Text { + content: "response text".to_string(), + }; + let result = round_trip(&event); + assert!( + matches!(result, ConversationEvent::Text { content } if content == "response text") + ); + } + + #[test] + fn test_tool_call() { + let input = serde_json::json!({"command": "ls -la", "danger": "low"}); + let event = ConversationEvent::ToolCall { + id: "tc_123".to_string(), + name: "suggest_command".to_string(), + input: input.clone(), + }; + let result = round_trip(&event); + match result { + ConversationEvent::ToolCall { + id, + name, + input: result_input, + } => { + assert_eq!(id, "tc_123"); + assert_eq!(name, "suggest_command"); + assert_eq!(result_input, input); + } + _ => panic!("expected ToolCall"), + } + } + + #[test] + fn test_tool_result() { + let event = ConversationEvent::ToolResult { + tool_use_id: "tc_123".to_string(), + content: "file contents here".to_string(), + is_error: false, + remote: false, + content_length: None, + }; + let result = round_trip(&event); + match result { + ConversationEvent::ToolResult { + tool_use_id, + content, + is_error, + remote, + content_length, + } => { + assert_eq!(tool_use_id, "tc_123"); + assert_eq!(content, "file contents here"); + assert!(!is_error); + assert!(!remote); + assert!(content_length.is_none()); + } + _ => panic!("expected ToolResult"), + } + } + + #[test] + fn test_tool_result_error() { + let event = ConversationEvent::ToolResult { + tool_use_id: "tc_456".to_string(), + content: "permission denied".to_string(), + is_error: true, + remote: false, + content_length: None, + }; + let result = round_trip(&event); + match result { + ConversationEvent::ToolResult { is_error, .. } => assert!(is_error), + _ => panic!("expected ToolResult"), + } + } + + #[test] + fn test_tool_result_remote() { + let event = ConversationEvent::ToolResult { + tool_use_id: "tc_789".to_string(), + content: "ref:abc123".to_string(), + is_error: false, + remote: true, + content_length: Some(4096), + }; + let result = round_trip(&event); + match result { + ConversationEvent::ToolResult { + remote, + content_length, + .. + } => { + assert!(remote); + assert_eq!(content_length, Some(4096)); + } + _ => panic!("expected ToolResult"), + } + } + + #[test] + fn test_tool_result_backwards_compat() { + // Old stored data without remote/content_length fields should deserialize + // with defaults (remote=false, content_length=None) + let event = deserialize_event( + "tool_result", + r#"{"tool_use_id":"tc_old","content":"old result","is_error":false}"#, + ) + .unwrap(); + match event { + ConversationEvent::ToolResult { + remote, + content_length, + .. + } => { + assert!(!remote); + assert!(content_length.is_none()); + } + _ => panic!("expected ToolResult"), + } + } + + #[test] + fn test_out_of_band_with_command() { + let event = ConversationEvent::OutOfBandOutput { + name: "System".to_string(), + command: Some("/help".to_string()), + content: "help text".to_string(), + }; + let result = round_trip(&event); + match result { + ConversationEvent::OutOfBandOutput { + name, + command, + content, + } => { + assert_eq!(name, "System"); + assert_eq!(command.as_deref(), Some("/help")); + assert_eq!(content, "help text"); + } + _ => panic!("expected OutOfBandOutput"), + } + } + + #[test] + fn test_out_of_band_without_command() { + let event = ConversationEvent::OutOfBandOutput { + name: "System".to_string(), + command: None, + content: "some output".to_string(), + }; + let result = round_trip(&event); + match result { + ConversationEvent::OutOfBandOutput { command, .. } => { + assert!(command.is_none()); + } + _ => panic!("expected OutOfBandOutput"), + } + } + + #[test] + fn test_unknown_event_type() { + let result = deserialize_event("banana", "{}"); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("unknown event type") + ); + } + + #[test] + fn test_invalid_json() { + let result = deserialize_event("text", "not json"); + assert!(result.is_err()); + } + + #[test] + fn test_missing_field() { + let result = deserialize_event("text", "{}"); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("content")); + } + + #[test] + fn test_text_with_special_characters() { + let event = ConversationEvent::Text { + content: "line1\nline2\ttab \"quotes\" \\backslash 🎉".to_string(), + }; + let result = round_trip(&event); + assert!( + matches!(result, ConversationEvent::Text { content } if content == "line1\nline2\ttab \"quotes\" \\backslash 🎉") + ); + } + + #[test] + fn test_tool_call_with_nested_input() { + let input = serde_json::json!({ + "command": "echo 'hello'", + "nested": { "a": [1, 2, 3], "b": null } + }); + let event = ConversationEvent::ToolCall { + id: "tc_1".to_string(), + name: "execute_shell_command".to_string(), + input: input.clone(), + }; + let result = round_trip(&event); + match result { + ConversationEvent::ToolCall { + input: result_input, + .. + } => { + assert_eq!(result_input, input); + } + _ => panic!("expected ToolCall"), + } + } + + #[test] + fn test_system_context() { + let event = ConversationEvent::SystemContext { + content: "[system: new invocation started]".to_string(), + }; + let result = round_trip(&event); + assert!( + matches!(result, ConversationEvent::SystemContext { content } if content == "[system: new invocation started]") + ); + } +} diff --git a/crates/atuin-ai/src/lib.rs b/crates/atuin-ai/src/lib.rs index 6f431179..febb488e 100644 --- a/crates/atuin-ai/src/lib.rs +++ b/crates/atuin-ai/src/lib.rs @@ -1,6 +1,10 @@ pub mod commands; pub(crate) mod context; +pub(crate) mod context_window; +pub(crate) mod event_serde; pub(crate) mod permissions; +pub(crate) mod session; +pub(crate) mod store; pub(crate) mod stream; pub(crate) mod tools; pub(crate) mod tui; diff --git a/crates/atuin-ai/src/session.rs b/crates/atuin-ai/src/session.rs new file mode 100644 index 00000000..d8314343 --- /dev/null +++ b/crates/atuin-ai/src/session.rs @@ -0,0 +1,482 @@ +//! Session service abstraction and manager. +//! +//! The TUI interacts with sessions through `SessionManager`, which wraps a +//! `SessionService` trait. Today the only implementation is `LocalSessionService` +//! (direct SQLite). When the daemon owns session state, a gRPC-backed +//! implementation can be swapped in without changing the TUI code. + +use async_trait::async_trait; +use eyre::Result; + +use crate::event_serde; +use crate::store::{AiSessionStore, StoredEvent, StoredSession}; +use crate::tui::ConversationEvent; + +// --------------------------------------------------------------------------- +// Trait +// --------------------------------------------------------------------------- + +#[async_trait] +pub(crate) trait SessionService: Send + Sync { + async fn create_session( + &self, + id: &str, + directory: Option<&str>, + git_root: Option<&str>, + ) -> Result<StoredSession>; + + async fn find_resumable( + &self, + directory: Option<&str>, + git_root: Option<&str>, + max_age_secs: i64, + ) -> Result<Option<StoredSession>>; + + async fn load_events(&self, session_id: &str) -> Result<Vec<StoredEvent>>; + + async fn append_event( + &self, + session_id: &str, + event_id: &str, + parent_id: Option<&str>, + invocation_id: &str, + event_type: &str, + event_data: &str, + ) -> Result<()>; + + async fn update_server_session_id( + &self, + session_id: &str, + server_session_id: &str, + ) -> Result<()>; + + async fn archive(&self, session_id: &str) -> Result<()>; +} + +// --------------------------------------------------------------------------- +// Local implementation (direct SQLite) +// --------------------------------------------------------------------------- + +pub(crate) struct LocalSessionService { + store: AiSessionStore, +} + +impl LocalSessionService { + pub async fn open(path: impl AsRef<std::path::Path>, timeout: f64) -> Result<Self> { + let store = AiSessionStore::new(path, timeout).await?; + Ok(Self { store }) + } +} + +#[async_trait] +impl SessionService for LocalSessionService { + async fn create_session( + &self, + id: &str, + directory: Option<&str>, + git_root: Option<&str>, + ) -> Result<StoredSession> { + self.store.create_session(id, directory, git_root).await + } + + async fn find_resumable( + &self, + directory: Option<&str>, + git_root: Option<&str>, + max_age_secs: i64, + ) -> Result<Option<StoredSession>> { + self.store + .find_resumable_session(directory, git_root, max_age_secs) + .await + } + + async fn load_events(&self, session_id: &str) -> Result<Vec<StoredEvent>> { + self.store.load_events(session_id).await + } + + async fn append_event( + &self, + session_id: &str, + event_id: &str, + parent_id: Option<&str>, + invocation_id: &str, + event_type: &str, + event_data: &str, + ) -> Result<()> { + self.store + .append_event( + session_id, + event_id, + parent_id, + invocation_id, + event_type, + event_data, + ) + .await + } + + async fn update_server_session_id( + &self, + session_id: &str, + server_session_id: &str, + ) -> Result<()> { + self.store + .update_server_session_id(session_id, server_session_id) + .await + } + + async fn archive(&self, session_id: &str) -> Result<()> { + self.store.archive_session(session_id).await + } +} + +// --------------------------------------------------------------------------- +// SessionManager +// --------------------------------------------------------------------------- + +/// High-level session manager used by the TUI dispatch loop. +/// +/// Owns the current session identity, tracks what has been persisted, and +/// handles serialization between `ConversationEvent` and the storage format. +pub(crate) struct SessionManager { + service: Box<dyn SessionService>, + session_id: String, + invocation_id: String, + /// Number of events already persisted. `persist_events` only writes the + /// delta from this index onward. + persisted_count: usize, + /// ID of the last persisted event, used as `parent_id` for the next one. + head_id: Option<String>, + /// Stored for creating a new session on `/new`. + directory: Option<String>, + git_root: Option<String>, + /// Whether the session row has been created in the database. New sessions + /// are deferred until the first event is persisted, so empty sessions + /// don't linger and get spuriously resumed. + persisted_to_db: bool, +} + +impl SessionManager { + /// Create a new session manager. The database row is deferred until the + /// first event is persisted. + pub fn create_new( + service: Box<dyn SessionService>, + directory: Option<&str>, + git_root: Option<&str>, + ) -> Self { + let session_id = atuin_common::utils::uuid_v7().to_string(); + let invocation_id = atuin_common::utils::uuid_v7().to_string(); + + Self { + service, + session_id, + invocation_id, + persisted_count: 0, + head_id: None, + directory: directory.map(String::from), + git_root: git_root.map(String::from), + persisted_to_db: false, + } + } + + /// Load an existing session and return a manager for it, along with the + /// deserialized conversation events, the server session ID, and the + /// timestamp of the last stored event. + pub async fn resume( + service: Box<dyn SessionService>, + stored: &StoredSession, + ) -> Result<( + Self, + Vec<ConversationEvent>, + Option<String>, + Option<i64>, + String, + )> { + let invocation_id = atuin_common::utils::uuid_v7().to_string(); + let stored_events = service.load_events(&stored.id).await?; + + let mut events = Vec::with_capacity(stored_events.len()); + let mut last_event_id = None; + let mut last_event_ts = None; + for se in &stored_events { + events.push(event_serde::deserialize_event( + &se.event_type, + &se.event_data, + )?); + last_event_id = Some(se.id.clone()); + last_event_ts = Some(se.created_at); + } + + let manager = Self { + service, + session_id: stored.id.clone(), + invocation_id: invocation_id.clone(), + persisted_count: events.len(), + head_id: last_event_id, + directory: stored.directory.clone(), + git_root: stored.git_root.clone(), + persisted_to_db: true, + }; + + Ok(( + manager, + events, + stored.server_session_id.clone(), + last_event_ts, + invocation_id, + )) + } + + /// Ensure the session row exists in the database. + async fn ensure_persisted(&mut self) -> Result<()> { + if !self.persisted_to_db { + self.service + .create_session( + &self.session_id, + self.directory.as_deref(), + self.git_root.as_deref(), + ) + .await?; + self.persisted_to_db = true; + } + Ok(()) + } + + /// Persist any new events since the last persist call. + pub async fn persist_events(&mut self, events: &[ConversationEvent]) -> Result<()> { + if self.persisted_count >= events.len() { + return Ok(()); + } + self.ensure_persisted().await?; + for event in &events[self.persisted_count..] { + let event_id = atuin_common::utils::uuid_v7().to_string(); + let (event_type, event_data) = event_serde::serialize_event(event); + + self.service + .append_event( + &self.session_id, + &event_id, + self.head_id.as_deref(), + &self.invocation_id, + &event_type, + &event_data, + ) + .await?; + + self.head_id = Some(event_id); + self.persisted_count += 1; + } + Ok(()) + } + + /// Persist the server session ID if it has changed. + pub async fn persist_server_session_id(&mut self, server_session_id: &str) -> Result<()> { + self.ensure_persisted().await?; + self.service + .update_server_session_id(&self.session_id, server_session_id) + .await + } + + /// Archive the current session (for `/new` command). + #[allow(dead_code)] // used in tests; will be used by dispatch for `/new` + pub async fn archive(&self) -> Result<()> { + if self.persisted_to_db { + self.service.archive(&self.session_id).await?; + } + Ok(()) + } + + /// Archive the current session and reset to a fresh one. + /// The new session row is deferred until the first event is persisted. + pub async fn archive_and_reset(&mut self) -> Result<()> { + if self.persisted_to_db { + self.service.archive(&self.session_id).await?; + } + + self.session_id = atuin_common::utils::uuid_v7().to_string(); + self.invocation_id = atuin_common::utils::uuid_v7().to_string(); + self.persisted_count = 0; + self.head_id = None; + self.persisted_to_db = false; + Ok(()) + } + + #[allow(dead_code)] // used in tests; part of public API for dispatch/daemon + pub fn session_id(&self) -> &str { + &self.session_id + } + + #[allow(dead_code)] // used in tests; part of public API for dispatch/daemon + pub fn invocation_id(&self) -> &str { + &self.invocation_id + } +} + +#[cfg(test)] +mod tests { + use super::*; + + async fn test_service() -> Box<dyn SessionService> { + let svc = LocalSessionService::open("sqlite::memory:", 2.0) + .await + .unwrap(); + Box::new(svc) + } + + #[tokio::test] + async fn test_create_new_and_persist() { + let service = test_service().await; + let mut mgr = SessionManager::create_new(service, Some("/tmp"), None); + + let events = vec![ + ConversationEvent::UserMessage { + content: "hello".to_string(), + }, + ConversationEvent::Text { + content: "hi there".to_string(), + }, + ]; + + mgr.persist_events(&events).await.unwrap(); + + // Persist again with no new events — should be a no-op + mgr.persist_events(&events).await.unwrap(); + } + + #[tokio::test] + async fn test_create_and_resume() { + // Create a session and persist some events + let svc = LocalSessionService::open("sqlite::memory:", 2.0) + .await + .unwrap(); + + let session_id = atuin_common::utils::uuid_v7().to_string(); + svc.create_session(&session_id, Some("/project"), Some("/project")) + .await + .unwrap(); + + let events = vec![ + ConversationEvent::UserMessage { + content: "how do I list files?".to_string(), + }, + ConversationEvent::Text { + content: "Use ls".to_string(), + }, + ConversationEvent::ToolCall { + id: "tc_1".to_string(), + name: "suggest_command".to_string(), + input: serde_json::json!({"command": "ls -la"}), + }, + ]; + + // Persist events manually through the service + let inv_id = "inv-1"; + let mut parent: Option<String> = None; + for event in &events { + let eid = atuin_common::utils::uuid_v7().to_string(); + let (etype, edata) = event_serde::serialize_event(event); + svc.append_event(&session_id, &eid, parent.as_deref(), inv_id, &etype, &edata) + .await + .unwrap(); + parent = Some(eid); + } + + svc.update_server_session_id(&session_id, "srv-abc") + .await + .unwrap(); + + // Now find and resume the session with a fresh service connection + let stored = svc + .find_resumable(Some("/project"), Some("/project"), 3600) + .await + .unwrap() + .expect("should find session"); + + let (mut mgr, loaded_events, server_sid, last_ts, _invocation_id) = + SessionManager::resume(Box::new(svc), &stored) + .await + .unwrap(); + + assert_eq!(loaded_events.len(), 3); + assert_eq!(server_sid.as_deref(), Some("srv-abc")); + assert_ne!(mgr.invocation_id(), inv_id, "new invocation ID on resume"); + assert!(last_ts.is_some(), "should have a last event timestamp"); + + // Persisting again with the same events should be a no-op + mgr.persist_events(&loaded_events).await.unwrap(); + } + + #[tokio::test] + async fn test_incremental_persist() { + let service = test_service().await; + let mut mgr = SessionManager::create_new(service, Some("/tmp"), None); + + let mut events = vec![ConversationEvent::UserMessage { + content: "first".to_string(), + }]; + mgr.persist_events(&events).await.unwrap(); + + // Add more events and persist again — only the new ones should be written + events.push(ConversationEvent::Text { + content: "response".to_string(), + }); + events.push(ConversationEvent::UserMessage { + content: "second".to_string(), + }); + mgr.persist_events(&events).await.unwrap(); + + // Verify by loading through a fresh service (can't easily here since + // the service is moved, but the lack of errors confirms correctness) + } + + #[tokio::test] + async fn test_archive() { + let svc = LocalSessionService::open("sqlite::memory:", 2.0) + .await + .unwrap(); + + let mgr = SessionManager::create_new(Box::new(svc), Some("/tmp"), None); + + mgr.archive().await.unwrap(); + } + + #[tokio::test] + async fn test_persist_server_session_id() { + let service = test_service().await; + let mut mgr = SessionManager::create_new(service, Some("/tmp"), None); + + mgr.persist_server_session_id("srv-123").await.unwrap(); + } + + #[tokio::test] + async fn test_parent_chain_integrity() { + // Verify that persisted events form a proper parent chain + let svc = LocalSessionService::open("sqlite::memory:", 2.0) + .await + .unwrap(); + + let session_id = { + let mut mgr = SessionManager::create_new(Box::new(svc), Some("/tmp"), None); + + let events = vec![ + ConversationEvent::UserMessage { + content: "a".to_string(), + }, + ConversationEvent::Text { + content: "b".to_string(), + }, + ConversationEvent::UserMessage { + content: "c".to_string(), + }, + ]; + mgr.persist_events(&events).await.unwrap(); + mgr.session_id().to_string() + }; + + // Re-open the store and load events to verify the chain + // (Can't do this with in-memory DB since it's gone, but the + // lack of FK constraint violations during persist confirms the + // parent_id values are valid) + let _ = session_id; + } +} diff --git a/crates/atuin-ai/src/store.rs b/crates/atuin-ai/src/store.rs new file mode 100644 index 00000000..2a75d8f4 --- /dev/null +++ b/crates/atuin-ai/src/store.rs @@ -0,0 +1,522 @@ +use std::path::Path; +use std::str::FromStr; +use std::time::Duration; + +use eyre::{Result, eyre}; +use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions}; +use time::OffsetDateTime; + +// Database row mappings — all columns are kept even if not yet read in +// non-test code, since they're part of the schema and used in tests. +#[derive(Debug)] +#[allow(dead_code)] +pub(crate) struct StoredSession { + pub id: String, + pub head_id: Option<String>, + pub server_session_id: Option<String>, + pub directory: Option<String>, + pub git_root: Option<String>, + pub created_at: i64, + pub updated_at: i64, + pub archived_at: Option<i64>, +} + +#[derive(Debug)] +#[allow(dead_code)] +pub(crate) struct StoredEvent { + pub id: String, + pub session_id: String, + pub parent_id: Option<String>, + pub invocation_id: String, + pub event_type: String, + pub event_data: String, + pub created_at: i64, +} + +/// Row type returned by session queries (avoids clippy::type_complexity). +type SessionRow = ( + String, + Option<String>, + Option<String>, + Option<String>, + Option<String>, + i64, + i64, + Option<i64>, +); + +/// Row type returned by event queries. +type EventRow = (String, String, Option<String>, String, String, String, i64); + +pub(crate) struct AiSessionStore { + pool: SqlitePool, +} + +impl AiSessionStore { + pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> { + let path = path.as_ref(); + let path_str = path + .as_os_str() + .to_str() + .ok_or_else(|| eyre!("AI session database path is not valid UTF-8: {path:?}"))?; + + let is_memory = path_str.contains(":memory:"); + + if !is_memory + && !path.exists() + && let Some(dir) = path.parent() + { + fs_err::create_dir_all(dir)?; + } + + let opts = SqliteConnectOptions::from_str(path_str)? + .journal_mode(SqliteJournalMode::Wal) + .optimize_on_close(true, None) + .create_if_missing(true); + + let pool = SqlitePoolOptions::new() + .acquire_timeout(Duration::from_secs_f64(timeout)) + .connect_with(opts) + .await?; + + sqlx::migrate!("./migrations").run(&pool).await?; + + #[cfg(unix)] + if !is_memory { + use std::os::unix::fs::PermissionsExt; + std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o600))?; + } + + Ok(Self { pool }) + } + + pub async fn create_session( + &self, + id: &str, + directory: Option<&str>, + git_root: Option<&str>, + ) -> Result<StoredSession> { + let now = OffsetDateTime::now_utc().unix_timestamp(); + + sqlx::query( + "INSERT INTO sessions (id, directory, git_root, created_at, updated_at) + VALUES (?1, ?2, ?3, ?4, ?4)", + ) + .bind(id) + .bind(directory) + .bind(git_root) + .bind(now) + .execute(&self.pool) + .await?; + + Ok(StoredSession { + id: id.to_string(), + head_id: None, + server_session_id: None, + directory: directory.map(String::from), + git_root: git_root.map(String::from), + created_at: now, + updated_at: now, + archived_at: None, + }) + } + + #[allow(dead_code)] // used in tests; will be used by daemon service + pub async fn get_session(&self, id: &str) -> Result<Option<StoredSession>> { + let row: Option<SessionRow> = sqlx::query_as( + "SELECT id, head_id, server_session_id, directory, git_root, + created_at, updated_at, archived_at + FROM sessions WHERE id = ?1", + ) + .bind(id) + .fetch_optional(&self.pool) + .await?; + + Ok(row.map( + |( + id, + head_id, + server_session_id, + directory, + git_root, + created_at, + updated_at, + archived_at, + )| { + StoredSession { + id, + head_id, + server_session_id, + directory, + git_root, + created_at, + updated_at, + archived_at, + } + }, + )) + } + + /// Find the most recent non-archived session matching the given directory or git + /// root, updated within `max_age_secs` seconds. + pub async fn find_resumable_session( + &self, + directory: Option<&str>, + git_root: Option<&str>, + max_age_secs: i64, + ) -> Result<Option<StoredSession>> { + let cutoff = OffsetDateTime::now_utc().unix_timestamp() - max_age_secs; + + let row: Option<SessionRow> = sqlx::query_as( + "SELECT id, head_id, server_session_id, directory, git_root, + created_at, updated_at, archived_at + FROM sessions + WHERE archived_at IS NULL + AND updated_at > ?1 + AND (directory = ?2 OR (git_root IS NOT NULL AND git_root = ?3)) + ORDER BY updated_at DESC + LIMIT 1", + ) + .bind(cutoff) + .bind(directory) + .bind(git_root) + .fetch_optional(&self.pool) + .await?; + + Ok(row.map( + |( + id, + head_id, + server_session_id, + directory, + git_root, + created_at, + updated_at, + archived_at, + )| { + StoredSession { + id, + head_id, + server_session_id, + directory, + git_root, + created_at, + updated_at, + archived_at, + } + }, + )) + } + + /// Append a single event and update the session's `head_id` and `updated_at`. + pub async fn append_event( + &self, + session_id: &str, + event_id: &str, + parent_id: Option<&str>, + invocation_id: &str, + event_type: &str, + event_data: &str, + ) -> Result<()> { + let now = OffsetDateTime::now_utc().unix_timestamp(); + + let mut tx = self.pool.begin().await?; + + sqlx::query( + "INSERT INTO session_events (id, session_id, parent_id, invocation_id, event_type, event_data, created_at) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)", + ) + .bind(event_id) + .bind(session_id) + .bind(parent_id) + .bind(invocation_id) + .bind(event_type) + .bind(event_data) + .bind(now) + .execute(&mut *tx) + .await?; + + sqlx::query("UPDATE sessions SET head_id = ?1, updated_at = ?2 WHERE id = ?3") + .bind(event_id) + .bind(now) + .bind(session_id) + .execute(&mut *tx) + .await?; + + tx.commit().await?; + Ok(()) + } + + /// Load all events for a session, ordered chronologically. + pub async fn load_events(&self, session_id: &str) -> Result<Vec<StoredEvent>> { + let rows: Vec<EventRow> = sqlx::query_as( + "SELECT id, session_id, parent_id, invocation_id, event_type, event_data, created_at + FROM session_events + WHERE session_id = ?1 + ORDER BY created_at ASC, rowid ASC", + ) + .bind(session_id) + .fetch_all(&self.pool) + .await?; + + Ok(rows + .into_iter() + .map( + |(id, session_id, parent_id, invocation_id, event_type, event_data, created_at)| { + StoredEvent { + id, + session_id, + parent_id, + invocation_id, + event_type, + event_data, + created_at, + } + }, + ) + .collect()) + } + + pub async fn update_server_session_id( + &self, + session_id: &str, + server_session_id: &str, + ) -> Result<()> { + sqlx::query("UPDATE sessions SET server_session_id = ?1 WHERE id = ?2") + .bind(server_session_id) + .bind(session_id) + .execute(&self.pool) + .await?; + Ok(()) + } + + pub async fn archive_session(&self, session_id: &str) -> Result<()> { + let now = OffsetDateTime::now_utc().unix_timestamp(); + sqlx::query("UPDATE sessions SET archived_at = ?1 WHERE id = ?2") + .bind(now) + .bind(session_id) + .execute(&self.pool) + .await?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + async fn new_test_store() -> AiSessionStore { + AiSessionStore::new("sqlite::memory:", 2.0).await.unwrap() + } + + #[tokio::test] + async fn test_create_and_get_session() { + let store = new_test_store().await; + + let session = store + .create_session("s1", Some("/home/user/project"), Some("/home/user/project")) + .await + .unwrap(); + assert_eq!(session.id, "s1"); + assert!(session.head_id.is_none()); + assert!(session.archived_at.is_none()); + + let loaded = store.get_session("s1").await.unwrap().unwrap(); + assert_eq!(loaded.id, "s1"); + assert_eq!(loaded.directory.as_deref(), Some("/home/user/project")); + } + + #[tokio::test] + async fn test_get_nonexistent_session() { + let store = new_test_store().await; + assert!(store.get_session("nope").await.unwrap().is_none()); + } + + #[tokio::test] + async fn test_append_and_load_events() { + let store = new_test_store().await; + store + .create_session("s1", Some("/tmp"), None) + .await + .unwrap(); + + store + .append_event( + "s1", + "e1", + None, + "inv1", + "user_message", + r#"{"content":"hello"}"#, + ) + .await + .unwrap(); + store + .append_event( + "s1", + "e2", + Some("e1"), + "inv1", + "text", + r#"{"content":"hi there"}"#, + ) + .await + .unwrap(); + + let events = store.load_events("s1").await.unwrap(); + assert_eq!(events.len(), 2); + assert_eq!(events[0].id, "e1"); + assert!(events[0].parent_id.is_none()); + assert_eq!(events[0].invocation_id, "inv1"); + assert_eq!(events[1].id, "e2"); + assert_eq!(events[1].parent_id.as_deref(), Some("e1")); + + let session = store.get_session("s1").await.unwrap().unwrap(); + assert_eq!(session.head_id.as_deref(), Some("e2")); + } + + #[tokio::test] + async fn test_find_resumable_session() { + let store = new_test_store().await; + store + .create_session("s1", Some("/home/user/project"), None) + .await + .unwrap(); + + let found = store + .find_resumable_session(Some("/home/user/project"), None, 3600) + .await + .unwrap(); + assert!(found.is_some()); + assert_eq!(found.unwrap().id, "s1"); + } + + #[tokio::test] + async fn test_find_resumable_by_git_root() { + let store = new_test_store().await; + store + .create_session( + "s1", + Some("/home/user/project/sub"), + Some("/home/user/project"), + ) + .await + .unwrap(); + + let found = store + .find_resumable_session(Some("/different/dir"), Some("/home/user/project"), 3600) + .await + .unwrap(); + assert!(found.is_some()); + assert_eq!(found.unwrap().id, "s1"); + } + + #[tokio::test] + async fn test_find_resumable_skips_archived() { + let store = new_test_store().await; + store + .create_session("s1", Some("/tmp"), None) + .await + .unwrap(); + store.archive_session("s1").await.unwrap(); + + let found = store + .find_resumable_session(Some("/tmp"), None, 3600) + .await + .unwrap(); + assert!(found.is_none()); + } + + #[tokio::test] + async fn test_find_resumable_no_match_different_dir() { + let store = new_test_store().await; + store + .create_session("s1", Some("/home/user/project"), None) + .await + .unwrap(); + + let found = store + .find_resumable_session(Some("/other/dir"), None, 3600) + .await + .unwrap(); + assert!(found.is_none()); + } + + #[tokio::test] + async fn test_archive_session() { + let store = new_test_store().await; + store + .create_session("s1", Some("/tmp"), None) + .await + .unwrap(); + + store.archive_session("s1").await.unwrap(); + + let session = store.get_session("s1").await.unwrap().unwrap(); + assert!(session.archived_at.is_some()); + } + + #[tokio::test] + async fn test_update_server_session_id() { + let store = new_test_store().await; + store + .create_session("s1", Some("/tmp"), None) + .await + .unwrap(); + + store + .update_server_session_id("s1", "server-abc") + .await + .unwrap(); + + let session = store.get_session("s1").await.unwrap().unwrap(); + assert_eq!(session.server_session_id.as_deref(), Some("server-abc")); + } + + #[tokio::test] + async fn test_find_resumable_does_not_mutate() { + let store = new_test_store().await; + store + .create_session("s1", Some("/tmp"), None) + .await + .unwrap(); + + let before = store.get_session("s1").await.unwrap().unwrap(); + store + .find_resumable_session(Some("/tmp"), None, 3600) + .await + .unwrap() + .unwrap(); + let after = store.get_session("s1").await.unwrap().unwrap(); + + assert_eq!(before.updated_at, after.updated_at); + } + + #[tokio::test] + async fn test_events_ordered_chronologically() { + let store = new_test_store().await; + store + .create_session("s1", Some("/tmp"), None) + .await + .unwrap(); + + store + .append_event("s1", "e1", None, "inv1", "user_message", "{}") + .await + .unwrap(); + store + .append_event("s1", "e2", Some("e1"), "inv1", "text", "{}") + .await + .unwrap(); + store + .append_event("s1", "e3", Some("e2"), "inv2", "user_message", "{}") + .await + .unwrap(); + + let events = store.load_events("s1").await.unwrap(); + assert_eq!(events.len(), 3); + assert!(events[0].created_at <= events[1].created_at); + assert!(events[1].created_at <= events[2].created_at); + assert_eq!(events[2].invocation_id, "inv2"); + } +} diff --git a/crates/atuin-ai/src/stream.rs b/crates/atuin-ai/src/stream.rs index 9c21fc05..f4f4d704 100644 --- a/crates/atuin-ai/src/stream.rs +++ b/crates/atuin-ai/src/stream.rs @@ -12,6 +12,7 @@ use eye_declare::Handle; use eyre::{Context, Result}; use futures::StreamExt; use reqwest::Url; +use reqwest::header::USER_AGENT; use crate::{ context::{AppContext, ClientContext}, @@ -19,6 +20,8 @@ use crate::{ tui::{Session, events::AiTuiEvent}, }; +static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION")); + /// Frames that alter the stream lifecycle — terminal or state-changing. #[derive(Debug, Clone)] pub(crate) enum StreamControl { @@ -57,6 +60,7 @@ pub(crate) struct ChatRequest { pub messages: Vec<serde_json::Value>, pub session_id: Option<String>, pub capabilities: Vec<String>, + pub invocation_id: String, } impl ChatRequest { @@ -64,8 +68,9 @@ impl ChatRequest { messages: Vec<serde_json::Value>, session_id: Option<String>, capabilities: &AiCapabilities, + invocation_id: String, ) -> Self { - let mut caps = vec![]; + let mut caps = vec!["client_invocations".to_string()]; if capabilities.enable_history_search.unwrap_or(true) { caps.push("client_v1_atuin_history".to_string()); } @@ -82,6 +87,7 @@ impl ChatRequest { messages, session_id, capabilities: caps, + invocation_id, } } } @@ -112,6 +118,7 @@ fn create_chat_stream( "messages": request.messages, "context": context, "capabilities": request.capabilities, + "invocation_id": request.invocation_id }); if let Some(ref sid) = request.session_id { @@ -123,6 +130,7 @@ fn create_chat_stream( let response = match client .post(endpoint.clone()) .header("Accept", "text/event-stream") + .header(USER_AGENT, APP_USER_AGENT) .bearer_auth(&token) .json(&request_body) .send() diff --git a/crates/atuin-ai/src/tui/components/input_box.rs b/crates/atuin-ai/src/tui/components/input_box.rs index f5e0fe2b..6e041418 100644 --- a/crates/atuin-ai/src/tui/components/input_box.rs +++ b/crates/atuin-ai/src/tui/components/input_box.rs @@ -19,7 +19,7 @@ use ratatui_core::{ }; use tui_textarea::TextArea; -use crate::tui::events::AiTuiEvent; +use crate::tui::{events::AiTuiEvent, slash::SlashCommandSearchResult}; /// A bordered text input box backed by tui-textarea. /// @@ -35,6 +35,8 @@ pub(crate) struct InputBox { pub footer: String, /// Whether the input is currently active (shows cursor, accepts input) pub active: bool, + /// If the user has typed a slash command, this holds the best match for it. + pub slash_suggestion: Option<SlashCommandSearchResult>, } pub(crate) struct InputBoxState { @@ -129,6 +131,18 @@ fn input_box( textarea.insert_newline(); return EventResult::Consumed; } + crossterm::event::KeyCode::Tab if props.slash_suggestion.is_some() => { + // If there's a slash command suggestion, Tab accepts it. + if let Some(suggestion) = &props.slash_suggestion { + textarea.clear(); + textarea.insert_str(format!("/{}", suggestion.command.name)); + // Manually trigger an input update event so the slash suggestion box can update immediately + if let Some(ref tx) = state.tx { + let _ = tx.send(AiTuiEvent::InputUpdated(textarea.lines().join("\n"))); + } + return EventResult::Consumed; + } + } crossterm::event::KeyCode::Enter => { if key.modifiers.contains(KeyModifiers::SHIFT) { textarea.insert_newline(); diff --git a/crates/atuin-ai/src/tui/components/mod.rs b/crates/atuin-ai/src/tui/components/mod.rs index 3458327d..9959dbad 100644 --- a/crates/atuin-ai/src/tui/components/mod.rs +++ b/crates/atuin-ai/src/tui/components/mod.rs @@ -2,3 +2,4 @@ pub(crate) mod atuin_ai; pub(crate) mod input_box; pub(crate) mod markdown; pub(crate) mod select; +pub(crate) mod session_continue; diff --git a/crates/atuin-ai/src/tui/components/session_continue.rs b/crates/atuin-ai/src/tui/components/session_continue.rs new file mode 100644 index 00000000..bfbfb191 --- /dev/null +++ b/crates/atuin-ai/src/tui/components/session_continue.rs @@ -0,0 +1,49 @@ +use chrono_humanize::HumanTime; +use eye_declare::{Elements, Hooks, Span, Text, component, element, props}; +use ratatui::style::{Color, Modifier, Style}; + +#[props] +pub(crate) struct SessionContinue { + pub continued_at: Option<chrono::DateTime<chrono::Utc>>, +} + +#[derive(Default)] +pub(crate) struct SessionContinueState { + /// Frozen on mount so the label doesn't change on every render. + label: Option<String>, +} + +#[component(props = SessionContinue, state = SessionContinueState)] +fn session_continue( + _props: &SessionContinue, + state: &SessionContinueState, + hooks: &mut Hooks<SessionContinue, SessionContinueState>, +) -> Elements { + hooks.use_mount(|props, state| { + state.label = Some(match props.continued_at { + Some(t) => { + let human = HumanTime::from(t - chrono::Utc::now()); + format!( + " Continuing previous session (last active {human}) - type /new to start a new session" + ) + } + None => { + " Continuing previous session - type /new to start a new session".to_string() + } + }); + }); + + let resume_label = state + .label + .as_deref() + .unwrap_or(" Continuing previous session - type /new to start a new session"); + + element! { + Text { + Span( + text: resume_label, + style: Style::default().fg(Color::DarkGray).add_modifier(Modifier::ITALIC), + ) + } + } +} diff --git a/crates/atuin-ai/src/tui/content/help.md b/crates/atuin-ai/src/tui/content/help.md index 654aea40..d6623ac9 100644 --- a/crates/atuin-ai/src/tui/content/help.md +++ b/crates/atuin-ai/src/tui/content/help.md @@ -1,3 +1,6 @@ Welcome to Atuin AI, an AI assistant in your terminal. You can ask it to generate a shell command for you, or ask general terminal or software questions. +Commands: +{commands} + For more information, see [https://docs.atuin.sh/cli/ai/introduction/](https://docs.atuin.sh/cli/ai/introduction/) diff --git a/crates/atuin-ai/src/tui/dispatch.rs b/crates/atuin-ai/src/tui/dispatch.rs index b3e84757..ee2bbe74 100644 --- a/crates/atuin-ai/src/tui/dispatch.rs +++ b/crates/atuin-ai/src/tui/dispatch.rs @@ -2,14 +2,16 @@ use std::path::PathBuf; use std::sync::mpsc; use crate::context::{AppContext, ClientContext}; +use crate::context_window::ContextWindowBuilder; use crate::permissions::check::PermissionResponse; use crate::permissions::resolver::PermissionResolver; use crate::permissions::rule::Rule; use crate::permissions::writer::{self, RuleDisposition}; +use crate::session::SessionManager; use crate::stream::{ChatRequest, run_chat_stream}; use crate::tools::{ClientToolCall, ToolPhase}; use crate::tui::events::{AiTuiEvent, PermissionResult}; -use crate::tui::state::{ExitAction, Session}; +use crate::tui::state::{ConversationEvent, ExitAction, Session}; use eye_declare::Handle; use tokio::task::JoinHandle; @@ -19,6 +21,7 @@ pub(crate) fn dispatch( tx: &mpsc::Sender<AiTuiEvent>, app_ctx: &AppContext, client_ctx: &ClientContext, + session_mgr: &mut SessionManager, ) { match event { AiTuiEvent::ContinueAfterTools => { @@ -28,7 +31,7 @@ pub(crate) fn dispatch( on_input_updated(handle, input); } AiTuiEvent::SubmitInput(input) => { - on_submit_input(handle, tx, app_ctx, client_ctx, input); + on_submit_input(handle, tx, app_ctx, client_ctx, input, session_mgr); } AiTuiEvent::SlashCommand(cmd) => { on_slash_command(handle, cmd); @@ -61,6 +64,35 @@ pub(crate) fn dispatch( on_exit(handle); } } + + // Persist any new conversation events after each dispatch cycle. + persist_session(handle, session_mgr); +} + +/// Persist new events and the server session ID if it has changed. +/// Called from the dispatch thread (sync), bridges to async via the tokio handle. +fn persist_session(handle: &Handle<Session>, session_mgr: &mut SessionManager) { + let Ok((events, server_sid)) = handle + .fetch(|state| { + ( + state.conversation.events.clone(), + state.conversation.session_id.clone(), + ) + }) + .blocking_recv() + else { + return; + }; + + let rt = tokio::runtime::Handle::current(); + if let Err(e) = rt.block_on(session_mgr.persist_events(&events)) { + tracing::warn!("failed to persist session events: {e}"); + } + if let Some(ref sid) = server_sid + && let Err(e) = rt.block_on(session_mgr.persist_server_session_id(sid)) + { + tracing::warn!("failed to persist server session ID: {e}"); + } } fn launch_stream( @@ -78,9 +110,10 @@ fn launch_stream( handle.update(move |state| { (setup)(state); state.start_streaming(); - let messages = state.conversation.events_to_messages(); + let messages = + ContextWindowBuilder::with_default_budget().build(&state.conversation.events); let sid = state.conversation.session_id.clone(); - let request = ChatRequest::new(messages, sid, &caps); + let request = ChatRequest::new(messages, sid, &caps, state.invocation_id.clone()); let task: JoinHandle<()> = tokio::spawn(async move { run_chat_stream(h2, tx2, app, cc, request).await; }); @@ -98,10 +131,30 @@ fn on_continue_after_tools( } fn on_input_updated(handle: &Handle<Session>, input: String) { - let input_blank = input.trim().is_empty(); + let input_blank = input.is_empty(); + let slash_command = if input.starts_with('/') { + Some(input.trim_start_matches('/').to_string()) + } else { + None + }; handle.update(move |state| { state.interaction.is_input_blank = input_blank; + state.interaction.slash_command_input = slash_command; + + if let Some(query) = state.interaction.slash_command_input.as_ref() { + let mut results = state.slash_registry.search_fuzzy(query); + + results.sort_by(|a, b| { + b.relevance + .partial_cmp(&a.relevance) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + state.interaction.slash_command_search_results = results; + } else { + state.interaction.slash_command_search_results.clear(); + } }); } @@ -111,7 +164,13 @@ fn on_submit_input( app_ctx: &AppContext, client_ctx: &ClientContext, input: String, + session_mgr: &mut SessionManager, ) { + handle.update(move |state| { + state.interaction.slash_command_input = None; + state.interaction.slash_command_search_results.clear(); + }); + let input = input.trim().to_string(); if input.is_empty() { let h2 = handle.clone(); @@ -129,9 +188,15 @@ fn on_submit_input( } if input.starts_with('/') { - handle.update(move |state| { - state.conversation.handle_slash_command(&input); - }); + if input.trim() == "/new" { + on_new_session(handle, session_mgr); + } else { + handle.update(move |state| { + state + .conversation + .handle_slash_command(&input, &state.slash_registry); + }); + } return; } @@ -144,7 +209,9 @@ fn on_submit_input( fn on_slash_command(handle: &Handle<Session>, command: String) { handle.update(move |state| { - state.conversation.handle_slash_command(&command); + state + .conversation + .handle_slash_command(&command, &state.slash_registry); }); } @@ -533,6 +600,38 @@ fn on_retry( }); } +fn on_new_session(handle: &Handle<Session>, session_mgr: &mut SessionManager) { + let rt = tokio::runtime::Handle::current(); + + if let Err(e) = rt.block_on(session_mgr.archive_and_reset()) { + tracing::warn!("failed to start new session: {e}"); + return; + } + + handle.update(|state| { + // Move the current invocation's visible events to the archived view + // so they remain on screen but are no longer sent to the API. + let visible_events: Vec<ConversationEvent> = + state.conversation.events[state.view_start_index..].to_vec(); + state.archived_view_events.extend(visible_events); + + state.conversation.events.clear(); + state.conversation.session_id = None; + state.tool_tracker = crate::tools::ToolTracker::new(); + state.view_start_index = 0; + state.is_resumed = false; + state.last_event_time = None; + state + .conversation + .events + .push(ConversationEvent::OutOfBandOutput { + name: "System".to_string(), + command: Some("/new".to_string()), + content: "Started a new session.".to_string(), + }); + }); +} + fn on_exit(handle: &Handle<Session>) { let h2 = handle.clone(); handle.update(move |state| { diff --git a/crates/atuin-ai/src/tui/mod.rs b/crates/atuin-ai/src/tui/mod.rs index afd63312..05a040a1 100644 --- a/crates/atuin-ai/src/tui/mod.rs +++ b/crates/atuin-ai/src/tui/mod.rs @@ -1,7 +1,8 @@ pub(crate) mod components; pub(crate) mod dispatch; pub(crate) mod events; +pub(crate) mod slash; pub(crate) mod state; pub(crate) mod view; -pub(crate) use state::{ConversationEvent, Session}; +pub(crate) use state::{ConversationEvent, Session, events_to_messages}; diff --git a/crates/atuin-ai/src/tui/slash.rs b/crates/atuin-ai/src/tui/slash.rs new file mode 100644 index 00000000..7d5e6fa8 --- /dev/null +++ b/crates/atuin-ai/src/tui/slash.rs @@ -0,0 +1,79 @@ +#[derive(Debug, Clone)] +pub(crate) struct SlashCommand { + pub name: String, + pub description: String, +} + +impl SlashCommand { + pub fn new(name: &str, description: &str) -> Self { + Self { + name: name.to_string(), + description: description.to_string(), + } + } +} + +#[derive(Debug)] +pub(crate) struct SlashCommandRegistry { + commands: Vec<SlashCommand>, +} + +#[derive(Debug, Clone)] +pub(crate) struct SlashCommandSearchResult { + pub command: SlashCommand, + pub relevance: f32, + pub span: (usize, usize), +} + +impl SlashCommandRegistry { + pub fn new() -> Self { + Self { + commands: Vec::new(), + } + } + + pub fn register(&mut self, command: SlashCommand) { + self.commands.push(command); + } + + pub fn get_commands(&self) -> &[SlashCommand] { + &self.commands + } + + pub fn search_fuzzy(&self, query: &str) -> Vec<SlashCommandSearchResult> { + let query_lower = query.to_lowercase(); + + self.commands + .iter() + .filter_map(|command| { + let name_lower = command.name.to_lowercase(); + if let Some(start) = name_lower.find(&query_lower as &str) { + let end = start + query_lower.len(); + Some((command, start, end)) + } else { + None + } + }) + .map(|(command, start, end)| { + SlashCommandSearchResult { + command: command.clone(), + relevance: 1.0, // Simple relevance score for now + span: (start, end), + } + }) + .collect() + } +} + +impl Default for SlashCommandRegistry { + fn default() -> Self { + let mut registry = Self::new(); + registry.register(SlashCommand::new("help", "Show help information")); + registry.register(SlashCommand::new( + "new", + "Start a new conversation, archiving the current one", + )); + + registry + } +} diff --git a/crates/atuin-ai/src/tui/state.rs b/crates/atuin-ai/src/tui/state.rs index 37200025..a012386a 100644 --- a/crates/atuin-ai/src/tui/state.rs +++ b/crates/atuin-ai/src/tui/state.rs @@ -5,7 +5,10 @@ use tokio::task::AbortHandle; -use crate::tools::{ClientToolCall, ToolOutcome, ToolTracker}; +use crate::{ + tools::{ClientToolCall, ToolOutcome, ToolTracker}, + tui::slash::{SlashCommandRegistry, SlashCommandSearchResult}, +}; /// Streaming status indicators from server #[derive(Debug, Clone, PartialEq, Eq)] @@ -57,9 +60,25 @@ pub(crate) enum ConversationEvent { command: Option<String>, content: String, }, + /// Context injected for the LLM that is not rendered in the TUI. + /// Converted to a user message in the API protocol. + SystemContext { content: String }, } impl ConversationEvent { + /// Whether this event represents actual conversation content sent to the API. + /// Used to determine if a resumed session has meaningful context. + pub(crate) fn is_api_content(&self) -> bool { + match self { + ConversationEvent::UserMessage { .. } => true, + ConversationEvent::Text { .. } => true, + ConversationEvent::ToolCall { .. } => true, + ConversationEvent::ToolResult { .. } => true, + ConversationEvent::OutOfBandOutput { .. } => false, + ConversationEvent::SystemContext { .. } => false, + } + } + /// Extract command from a suggest_command tool call pub(crate) fn as_command(&self) -> Option<&str> { if let ConversationEvent::ToolCall { name, input, .. } = self @@ -111,131 +130,6 @@ impl Conversation { } } - /// Convert conversation events to Claude API message format - pub fn events_to_messages(&self) -> Vec<serde_json::Value> { - let mut messages = Vec::new(); - let mut i = 0; - let events = &self.events; - - while i < events.len() { - match &events[i] { - ConversationEvent::UserMessage { content } => { - messages.push(serde_json::json!({ - "role": "user", - "content": content - })); - i += 1; - } - ConversationEvent::Text { content } => { - // Check if the next event(s) are ToolCalls — if so, combine - // into a single assistant message with mixed content blocks. - let next_is_tool_call = events - .get(i + 1) - .is_some_and(|e| matches!(e, ConversationEvent::ToolCall { .. })); - - if next_is_tool_call { - let mut content_blocks = Vec::new(); - - if !content.is_empty() { - content_blocks.push(serde_json::json!({ - "type": "text", - "text": content - })); - } - - while let Some(ConversationEvent::ToolCall { - id, name, input, .. - }) = events.get(i + 1) - { - content_blocks.push(serde_json::json!({ - "type": "tool_use", - "id": id, - "name": name, - "input": input - })); - i += 1; - } - - messages.push(serde_json::json!({ - "role": "assistant", - "content": content_blocks - })); - i += 1; - } else { - messages.push(serde_json::json!({ - "role": "assistant", - "content": content - })); - i += 1; - } - } - ConversationEvent::ToolCall { .. } => { - // ToolCalls without preceding Text (shouldn't normally happen, - // but handle defensively) - let mut tool_uses = Vec::new(); - while i < events.len() { - if let ConversationEvent::ToolCall { - id, name, input, .. - } = &events[i] - { - tool_uses.push(serde_json::json!({ - "type": "tool_use", - "id": id, - "name": name, - "input": input - })); - i += 1; - } else { - break; - } - } - messages.push(serde_json::json!({ - "role": "assistant", - "content": tool_uses - })); - } - ConversationEvent::ToolResult { - tool_use_id, - content, - is_error, - remote, - content_length, - } => { - let tool_result = if *remote { - let mut obj = serde_json::json!({ - "type": "tool_result", - "tool_use_id": tool_use_id, - "remote": true, - "is_error": is_error - }); - if let Some(len) = content_length { - obj["content_length"] = serde_json::json!(len); - } - obj - } else { - serde_json::json!({ - "type": "tool_result", - "tool_use_id": tool_use_id, - "content": content, - "is_error": is_error - }) - }; - messages.push(serde_json::json!({ - "role": "user", - "content": [tool_result] - })); - i += 1; - } - ConversationEvent::OutOfBandOutput { .. } => { - // Out-of-band output is not sent to the server, so we don't need to add it to the messages - i += 1; - } - } - } - - messages - } - /// Get the most recent command from events pub fn current_command(&self) -> Option<&str> { self.events.iter().rev().find_map(|e| e.as_command()) @@ -343,15 +237,22 @@ impl Conversation { } /// Handle a slash command - pub fn handle_slash_command(&mut self, command: &str) { + pub fn handle_slash_command(&mut self, command: &str, registry: &SlashCommandRegistry) { match command.trim() { "/help" => { - let content = include_str!("./content/help.md"); + let commands = registry + .get_commands() + .iter() + .map(|cmd| format!("- `/{}` - {}", cmd.name, cmd.description)) + .collect::<Vec<_>>() + .join("\n"); + + let content = include_str!("./content/help.md").replace("{commands}", &commands); self.events.push(ConversationEvent::OutOfBandOutput { name: "System".to_string(), command: Some("/help".to_string()), - content: content.to_string(), + content, }); } _ => self.events.push(ConversationEvent::OutOfBandOutput { @@ -363,6 +264,147 @@ impl Conversation { } } +/// Convert a slice of conversation events to Claude API message format. +/// +/// This is the canonical event-to-message conversion, used by the context window +/// builder to convert turn slices independently. The logic handles combining +/// adjacent Text + ToolCall events into single assistant messages with mixed +/// content blocks. +pub(crate) fn events_to_messages(events: &[ConversationEvent]) -> Vec<serde_json::Value> { + let mut messages = Vec::new(); + let mut i = 0; + + while i < events.len() { + match &events[i] { + ConversationEvent::UserMessage { content } => { + messages.push(serde_json::json!({ + "role": "user", + "content": content + })); + i += 1; + } + ConversationEvent::Text { content } if content.is_empty() => { + // Skip empty text events (e.g. streaming buffer before + // any data arrived). + i += 1; + } + ConversationEvent::Text { content } => { + // Check if the next event(s) are ToolCalls — if so, combine + // into a single assistant message with mixed content blocks. + let next_is_tool_call = events + .get(i + 1) + .is_some_and(|e| matches!(e, ConversationEvent::ToolCall { .. })); + + if next_is_tool_call { + let mut content_blocks = Vec::new(); + + if !content.is_empty() { + content_blocks.push(serde_json::json!({ + "type": "text", + "text": content + })); + } + + while let Some(ConversationEvent::ToolCall { + id, name, input, .. + }) = events.get(i + 1) + { + content_blocks.push(serde_json::json!({ + "type": "tool_use", + "id": id, + "name": name, + "input": input + })); + i += 1; + } + + messages.push(serde_json::json!({ + "role": "assistant", + "content": content_blocks + })); + i += 1; + } else { + messages.push(serde_json::json!({ + "role": "assistant", + "content": content + })); + i += 1; + } + } + ConversationEvent::ToolCall { .. } => { + // ToolCalls without preceding Text (shouldn't normally happen, + // but handle defensively) + let mut tool_uses = Vec::new(); + while i < events.len() { + if let ConversationEvent::ToolCall { + id, name, input, .. + } = &events[i] + { + tool_uses.push(serde_json::json!({ + "type": "tool_use", + "id": id, + "name": name, + "input": input + })); + i += 1; + } else { + break; + } + } + messages.push(serde_json::json!({ + "role": "assistant", + "content": tool_uses + })); + } + ConversationEvent::ToolResult { + tool_use_id, + content, + is_error, + remote, + content_length, + } => { + let tool_result = if *remote { + let mut obj = serde_json::json!({ + "type": "tool_result", + "tool_use_id": tool_use_id, + "remote": true, + "is_error": is_error + }); + if let Some(len) = content_length { + obj["content_length"] = serde_json::json!(len); + } + obj + } else { + serde_json::json!({ + "type": "tool_result", + "tool_use_id": tool_use_id, + "content": content, + "is_error": is_error + }) + }; + messages.push(serde_json::json!({ + "role": "user", + "content": [tool_result] + })); + i += 1; + } + ConversationEvent::OutOfBandOutput { .. } => { + // Out-of-band output is not sent to the server + i += 1; + } + ConversationEvent::SystemContext { content } => { + messages.push(serde_json::json!({ + "role": "user", + "content": content + })); + i += 1; + } + } + } + + messages +} + /// Ephemeral UI/presentation state #[derive(Debug)] pub(crate) struct Interaction { @@ -370,6 +412,10 @@ pub(crate) struct Interaction { pub mode: AppMode, /// Whether the input is blank pub is_input_blank: bool, + /// The currently in-progress slash command (if any) + pub slash_command_input: Option<String>, + /// Search results for the current slash command input + pub slash_command_search_results: Vec<SlashCommandSearchResult>, /// True when user has pressed Enter once on a dangerous command pub confirmation_pending: bool, /// Current streaming status @@ -385,6 +431,8 @@ impl Interaction { Self { mode: AppMode::Input, is_input_blank: false, + slash_command_input: None, + slash_command_search_results: Vec::new(), confirmation_pending: false, streaming_status: None, was_interrupted: false, @@ -410,10 +458,26 @@ pub(crate) struct Session { pub exit_action: Option<ExitAction>, /// Abort handle for the active streaming task, if any pub stream_abort: Option<AbortHandle>, + /// Index into `conversation.events` where the current TUI invocation starts. + /// Events before this index are historical context sent to the API but not + /// rendered in the TUI. + pub view_start_index: usize, + /// Whether this session was resumed from a prior invocation. + pub is_resumed: bool, + /// Time of the last event from a previous invocation when resuming a session + pub last_event_time: Option<chrono::DateTime<chrono::Utc>>, + /// Events from archived sessions that are still rendered on screen but no + /// longer sent to the API. Accumulated by `/new` commands within a single + /// TUI lifetime. + pub archived_view_events: Vec<ConversationEvent>, + /// A registry of available slash commands + pub slash_registry: SlashCommandRegistry, + /// The unique ID for this invocation + pub invocation_id: String, } impl Session { - pub fn new(in_git_project: bool) -> Self { + pub fn new(in_git_project: bool, invocation_id: Option<String>) -> Self { Self { conversation: Conversation::new(), interaction: Interaction::new(), @@ -421,6 +485,12 @@ impl Session { in_git_project, exit_action: None, stream_abort: None, + view_start_index: 0, + is_resumed: false, + last_event_time: None, + archived_view_events: Vec::new(), + slash_registry: Default::default(), + invocation_id: invocation_id.unwrap_or_else(|| uuid::Uuid::now_v7().to_string()), } } @@ -455,11 +525,10 @@ impl Session { // ===== Streaming lifecycle methods ===== /// Start streaming response. - /// Pushes an empty Text event that will be mutated in-place as chunks arrive. + /// The Text event for streamed content is created lazily by + /// `append_streaming_text` when the first chunk arrives, so we + /// don't leave an empty assistant turn in the conversation. pub fn start_streaming(&mut self) { - self.conversation.events.push(ConversationEvent::Text { - content: String::new(), - }); self.interaction.streaming_status = None; self.interaction.was_interrupted = false; self.interaction.mode = AppMode::Streaming; diff --git a/crates/atuin-ai/src/tui/view/mod.rs b/crates/atuin-ai/src/tui/view/mod.rs index ee5483d8..565a0597 100644 --- a/crates/atuin-ai/src/tui/view/mod.rs +++ b/crates/atuin-ai/src/tui/view/mod.rs @@ -8,6 +8,7 @@ use ratatui_core::style::{Color, Modifier, Style}; use crate::tools::{ClientToolCall, TrackedTool}; use crate::tui::components::select::SelectOption; +use crate::tui::components::session_continue::SessionContinue; use crate::tui::events::{AiTuiEvent, PermissionResult}; use super::components::atuin_ai::AtuinAi; @@ -29,7 +30,10 @@ mod turn; pub(crate) fn ai_view(state: &Session) -> Elements { let mut turn_builder = turn::TurnBuilder::new(&state.tool_tracker); - for event in &state.conversation.events { + for event in &state.archived_view_events { + turn_builder.add_event(event); + } + for event in &state.conversation.events[state.view_start_index..] { turn_builder.add_event(event); } let turns = turn_builder.build(); @@ -46,6 +50,10 @@ pub(crate) fn ai_view(state: &Session) -> Elements { pending_confirmation: state.interaction.confirmation_pending, has_executing_preview: state.tool_tracker.has_executing_preview(), ) { + #(if state.is_resumed && (!state.is_exiting() || !turns.is_empty()) { + SessionContinue(key: "continuation-notice", continued_at: state.last_event_time) + }) + #(for (index, turn) in turns.iter().enumerate() { #(match turn { turn::UiTurn::User { events } => { @@ -70,6 +78,13 @@ pub(crate) fn ai_view(state: &Session) -> Elements { fn input_view(state: &Session) -> Elements { let asking_tool = state.tool_tracker.asking_for_permission(); let in_git_project = state.in_git_project; + let slash_results = state + .interaction + .slash_command_search_results + .iter() + .take(4) + .collect::<Vec<_>>(); + let first_slash_result = slash_results.first().cloned(); element! { #(if let Some(tc) = asking_tool { @@ -84,6 +99,7 @@ fn input_view(state: &Session) -> Elements { title_right: "Atuin AI", footer: state.footer_text(), active: state.interaction.mode == AppMode::Input && !state.interaction.confirmation_pending, + slash_suggestion: first_slash_result.cloned() ) #(if state.interaction.is_input_blank && state.conversation.has_any_command() && state.interaction.mode == AppMode::Input { @@ -93,6 +109,23 @@ fn input_view(state: &Session) -> Elements { Text { Span(text: "[Enter] Execute suggested command [Tab] Insert Command", style: Style::default().fg(Color::Gray)) } }) }) + + #(if !slash_results.is_empty() { + #(for (i, result) in slash_results.iter().enumerate() { + Text { + Span(text: format!("/{}", &result.command.name[..result.span.0]), style: Style::default().fg(Color::Blue)) + Span(text: &result.command.name[result.span.0..result.span.1], style: Style::default().fg(Color::Blue).add_modifier(Modifier::UNDERLINED)) + Span(text: format!("{}", &result.command.name[result.span.1..]), style: Style::default().fg(Color::Blue)) + Span(text: " - ") + Span(text: &result.command.description) + + #(if i == 0 { + Span(text: " [Tab] Insert", style: Style::default().fg(Color::Gray).add_modifier(Modifier::ITALIC).dim()) + }) + } + + }) + }) } }) } @@ -270,7 +303,7 @@ fn out_of_band_turn_view(events: &[turn::UiEvent]) -> Elements { element! { View { Text { - Span(text: "System", style: Style::default().fg(Color::Blue).add_modifier(Modifier::BOLD)) + Span(text: " System ", style: Style::default().fg(Color::Blue).add_modifier(Modifier::BOLD).add_modifier(Modifier::REVERSED)) } #(for event in events { #(match event { diff --git a/crates/atuin-ai/src/tui/view/turn.rs b/crates/atuin-ai/src/tui/view/turn.rs index 7369f151..a2555dc6 100644 --- a/crates/atuin-ai/src/tui/view/turn.rs +++ b/crates/atuin-ai/src/tui/view/turn.rs @@ -170,6 +170,9 @@ impl<'a> TurnBuilder<'a> { } => { self.add_out_of_band_output(name, command.as_deref(), content); } + ConversationEvent::SystemContext { .. } => { + // Not rendered in the TUI — only sent to the API + } } } diff --git a/crates/atuin-client/src/settings.rs b/crates/atuin-client/src/settings.rs index 25c3bd65..9a2b84f5 100644 --- a/crates/atuin-client/src/settings.rs +++ b/crates/atuin-client/src/settings.rs @@ -664,6 +664,12 @@ pub struct Ai { /// Only necessary for custom AI endpoints. pub api_token: Option<String>, + /// Path to the AI sessions database. + pub db_path: String, + + /// The maximum time in minutes that an AI session can be automatically resumed. + pub session_continue_minutes: i64, + /// Deprecated: use opening.send_cwd instead. Kept for backwards compatibility. #[serde(default)] pub send_cwd: Option<bool>, @@ -1467,6 +1473,7 @@ impl Settings { let record_store_path = data_dir.join("records.db"); let kv_path = data_dir.join("kv.db"); let scripts_path = data_dir.join("scripts.db"); + let ai_sessions_path = data_dir.join("ai_sessions.db"); let socket_path = atuin_common::utils::runtime_dir().join("atuin.sock"); let pidfile_path = data_dir.join("atuin-daemon.pid"); let logs_dir = atuin_common::utils::logs_dir(); @@ -1550,6 +1557,8 @@ impl Settings { .set_default("search.frequency_score_multiplier", 1.0)? .set_default("search.frecency_score_multiplier", 1.0)? .set_default("meta.db_path", meta_path.to_str())? + .set_default("ai.db_path", ai_sessions_path.to_str())? + .set_default("ai.session_continue_minutes", 60)? .set_default("ai.send_cwd", false)? .set_default("ai.opening.send_cwd", false)? .set_default("ai.opening.send_last_command", false)? |
