diff options
| author | Michelle Tilley <michelle@michelletilley.net> | 2026-04-14 16:03:08 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-04-15 00:03:08 +0100 |
| commit | fd188da879d977ca847f10708c39dd4801a204c4 (patch) | |
| tree | 592bfe2644f8bd9be3563f176eabf29e55fa9a9b /crates/atuin-ai/src/session.rs | |
| parent | fix: dependency fix (#3414) (diff) | |
| download | atuin-fd188da879d977ca847f10708c39dd4801a204c4.zip | |
feat: Allow resuming previous AI sessions (#3407)
This PR introduces session continuation to Atuin AI.
* Conversations with Atuin AI are stored in a local SQLite database
* Upon startup, Atuin AI tries to find a session to resume based on its
directory/workspace and the time since the last event
* If found, Atuin AI will show a note that the session has been resumed,
and an event is added to help the LLM know where the invocation
boundaries are
* If not, Atuin AI will create a new conversation
* The user can create a new conversation with `/new`
* The new setting `ai.session_continue_minutes`, which defaults to `60`,
controls how old the last event in a session can be before it's no
longer considered for automatic resuming.
<img width="1055" height="593" alt="image"
src="https://github.com/user-attachments/assets/3f9ff01a-ef64-44a9-b0e2-3a4252c5746f"
/>
## Architecture
A new `SessionService` trait defines an API contract for a service that
can manage session data. `LocalSessionService` implements this, with
`DaemonSessionService` a possible future extension point.
`SessionManager` owns a `dyn SessionService` and delegates as
appropriate.
Diffstat (limited to 'crates/atuin-ai/src/session.rs')
| -rw-r--r-- | crates/atuin-ai/src/session.rs | 482 |
1 files changed, 482 insertions, 0 deletions
diff --git a/crates/atuin-ai/src/session.rs b/crates/atuin-ai/src/session.rs new file mode 100644 index 00000000..d8314343 --- /dev/null +++ b/crates/atuin-ai/src/session.rs @@ -0,0 +1,482 @@ +//! Session service abstraction and manager. +//! +//! The TUI interacts with sessions through `SessionManager`, which wraps a +//! `SessionService` trait. Today the only implementation is `LocalSessionService` +//! (direct SQLite). When the daemon owns session state, a gRPC-backed +//! implementation can be swapped in without changing the TUI code. + +use async_trait::async_trait; +use eyre::Result; + +use crate::event_serde; +use crate::store::{AiSessionStore, StoredEvent, StoredSession}; +use crate::tui::ConversationEvent; + +// --------------------------------------------------------------------------- +// Trait +// --------------------------------------------------------------------------- + +#[async_trait] +pub(crate) trait SessionService: Send + Sync { + async fn create_session( + &self, + id: &str, + directory: Option<&str>, + git_root: Option<&str>, + ) -> Result<StoredSession>; + + async fn find_resumable( + &self, + directory: Option<&str>, + git_root: Option<&str>, + max_age_secs: i64, + ) -> Result<Option<StoredSession>>; + + async fn load_events(&self, session_id: &str) -> Result<Vec<StoredEvent>>; + + async fn append_event( + &self, + session_id: &str, + event_id: &str, + parent_id: Option<&str>, + invocation_id: &str, + event_type: &str, + event_data: &str, + ) -> Result<()>; + + async fn update_server_session_id( + &self, + session_id: &str, + server_session_id: &str, + ) -> Result<()>; + + async fn archive(&self, session_id: &str) -> Result<()>; +} + +// --------------------------------------------------------------------------- +// Local implementation (direct SQLite) +// --------------------------------------------------------------------------- + +pub(crate) struct LocalSessionService { + store: AiSessionStore, +} + +impl LocalSessionService { + pub async fn open(path: impl AsRef<std::path::Path>, timeout: f64) -> Result<Self> { + let store = AiSessionStore::new(path, timeout).await?; + Ok(Self { store }) + } +} + +#[async_trait] +impl SessionService for LocalSessionService { + async fn create_session( + &self, + id: &str, + directory: Option<&str>, + git_root: Option<&str>, + ) -> Result<StoredSession> { + self.store.create_session(id, directory, git_root).await + } + + async fn find_resumable( + &self, + directory: Option<&str>, + git_root: Option<&str>, + max_age_secs: i64, + ) -> Result<Option<StoredSession>> { + self.store + .find_resumable_session(directory, git_root, max_age_secs) + .await + } + + async fn load_events(&self, session_id: &str) -> Result<Vec<StoredEvent>> { + self.store.load_events(session_id).await + } + + async fn append_event( + &self, + session_id: &str, + event_id: &str, + parent_id: Option<&str>, + invocation_id: &str, + event_type: &str, + event_data: &str, + ) -> Result<()> { + self.store + .append_event( + session_id, + event_id, + parent_id, + invocation_id, + event_type, + event_data, + ) + .await + } + + async fn update_server_session_id( + &self, + session_id: &str, + server_session_id: &str, + ) -> Result<()> { + self.store + .update_server_session_id(session_id, server_session_id) + .await + } + + async fn archive(&self, session_id: &str) -> Result<()> { + self.store.archive_session(session_id).await + } +} + +// --------------------------------------------------------------------------- +// SessionManager +// --------------------------------------------------------------------------- + +/// High-level session manager used by the TUI dispatch loop. +/// +/// Owns the current session identity, tracks what has been persisted, and +/// handles serialization between `ConversationEvent` and the storage format. +pub(crate) struct SessionManager { + service: Box<dyn SessionService>, + session_id: String, + invocation_id: String, + /// Number of events already persisted. `persist_events` only writes the + /// delta from this index onward. + persisted_count: usize, + /// ID of the last persisted event, used as `parent_id` for the next one. + head_id: Option<String>, + /// Stored for creating a new session on `/new`. + directory: Option<String>, + git_root: Option<String>, + /// Whether the session row has been created in the database. New sessions + /// are deferred until the first event is persisted, so empty sessions + /// don't linger and get spuriously resumed. + persisted_to_db: bool, +} + +impl SessionManager { + /// Create a new session manager. The database row is deferred until the + /// first event is persisted. + pub fn create_new( + service: Box<dyn SessionService>, + directory: Option<&str>, + git_root: Option<&str>, + ) -> Self { + let session_id = atuin_common::utils::uuid_v7().to_string(); + let invocation_id = atuin_common::utils::uuid_v7().to_string(); + + Self { + service, + session_id, + invocation_id, + persisted_count: 0, + head_id: None, + directory: directory.map(String::from), + git_root: git_root.map(String::from), + persisted_to_db: false, + } + } + + /// Load an existing session and return a manager for it, along with the + /// deserialized conversation events, the server session ID, and the + /// timestamp of the last stored event. + pub async fn resume( + service: Box<dyn SessionService>, + stored: &StoredSession, + ) -> Result<( + Self, + Vec<ConversationEvent>, + Option<String>, + Option<i64>, + String, + )> { + let invocation_id = atuin_common::utils::uuid_v7().to_string(); + let stored_events = service.load_events(&stored.id).await?; + + let mut events = Vec::with_capacity(stored_events.len()); + let mut last_event_id = None; + let mut last_event_ts = None; + for se in &stored_events { + events.push(event_serde::deserialize_event( + &se.event_type, + &se.event_data, + )?); + last_event_id = Some(se.id.clone()); + last_event_ts = Some(se.created_at); + } + + let manager = Self { + service, + session_id: stored.id.clone(), + invocation_id: invocation_id.clone(), + persisted_count: events.len(), + head_id: last_event_id, + directory: stored.directory.clone(), + git_root: stored.git_root.clone(), + persisted_to_db: true, + }; + + Ok(( + manager, + events, + stored.server_session_id.clone(), + last_event_ts, + invocation_id, + )) + } + + /// Ensure the session row exists in the database. + async fn ensure_persisted(&mut self) -> Result<()> { + if !self.persisted_to_db { + self.service + .create_session( + &self.session_id, + self.directory.as_deref(), + self.git_root.as_deref(), + ) + .await?; + self.persisted_to_db = true; + } + Ok(()) + } + + /// Persist any new events since the last persist call. + pub async fn persist_events(&mut self, events: &[ConversationEvent]) -> Result<()> { + if self.persisted_count >= events.len() { + return Ok(()); + } + self.ensure_persisted().await?; + for event in &events[self.persisted_count..] { + let event_id = atuin_common::utils::uuid_v7().to_string(); + let (event_type, event_data) = event_serde::serialize_event(event); + + self.service + .append_event( + &self.session_id, + &event_id, + self.head_id.as_deref(), + &self.invocation_id, + &event_type, + &event_data, + ) + .await?; + + self.head_id = Some(event_id); + self.persisted_count += 1; + } + Ok(()) + } + + /// Persist the server session ID if it has changed. + pub async fn persist_server_session_id(&mut self, server_session_id: &str) -> Result<()> { + self.ensure_persisted().await?; + self.service + .update_server_session_id(&self.session_id, server_session_id) + .await + } + + /// Archive the current session (for `/new` command). + #[allow(dead_code)] // used in tests; will be used by dispatch for `/new` + pub async fn archive(&self) -> Result<()> { + if self.persisted_to_db { + self.service.archive(&self.session_id).await?; + } + Ok(()) + } + + /// Archive the current session and reset to a fresh one. + /// The new session row is deferred until the first event is persisted. + pub async fn archive_and_reset(&mut self) -> Result<()> { + if self.persisted_to_db { + self.service.archive(&self.session_id).await?; + } + + self.session_id = atuin_common::utils::uuid_v7().to_string(); + self.invocation_id = atuin_common::utils::uuid_v7().to_string(); + self.persisted_count = 0; + self.head_id = None; + self.persisted_to_db = false; + Ok(()) + } + + #[allow(dead_code)] // used in tests; part of public API for dispatch/daemon + pub fn session_id(&self) -> &str { + &self.session_id + } + + #[allow(dead_code)] // used in tests; part of public API for dispatch/daemon + pub fn invocation_id(&self) -> &str { + &self.invocation_id + } +} + +#[cfg(test)] +mod tests { + use super::*; + + async fn test_service() -> Box<dyn SessionService> { + let svc = LocalSessionService::open("sqlite::memory:", 2.0) + .await + .unwrap(); + Box::new(svc) + } + + #[tokio::test] + async fn test_create_new_and_persist() { + let service = test_service().await; + let mut mgr = SessionManager::create_new(service, Some("/tmp"), None); + + let events = vec![ + ConversationEvent::UserMessage { + content: "hello".to_string(), + }, + ConversationEvent::Text { + content: "hi there".to_string(), + }, + ]; + + mgr.persist_events(&events).await.unwrap(); + + // Persist again with no new events — should be a no-op + mgr.persist_events(&events).await.unwrap(); + } + + #[tokio::test] + async fn test_create_and_resume() { + // Create a session and persist some events + let svc = LocalSessionService::open("sqlite::memory:", 2.0) + .await + .unwrap(); + + let session_id = atuin_common::utils::uuid_v7().to_string(); + svc.create_session(&session_id, Some("/project"), Some("/project")) + .await + .unwrap(); + + let events = vec![ + ConversationEvent::UserMessage { + content: "how do I list files?".to_string(), + }, + ConversationEvent::Text { + content: "Use ls".to_string(), + }, + ConversationEvent::ToolCall { + id: "tc_1".to_string(), + name: "suggest_command".to_string(), + input: serde_json::json!({"command": "ls -la"}), + }, + ]; + + // Persist events manually through the service + let inv_id = "inv-1"; + let mut parent: Option<String> = None; + for event in &events { + let eid = atuin_common::utils::uuid_v7().to_string(); + let (etype, edata) = event_serde::serialize_event(event); + svc.append_event(&session_id, &eid, parent.as_deref(), inv_id, &etype, &edata) + .await + .unwrap(); + parent = Some(eid); + } + + svc.update_server_session_id(&session_id, "srv-abc") + .await + .unwrap(); + + // Now find and resume the session with a fresh service connection + let stored = svc + .find_resumable(Some("/project"), Some("/project"), 3600) + .await + .unwrap() + .expect("should find session"); + + let (mut mgr, loaded_events, server_sid, last_ts, _invocation_id) = + SessionManager::resume(Box::new(svc), &stored) + .await + .unwrap(); + + assert_eq!(loaded_events.len(), 3); + assert_eq!(server_sid.as_deref(), Some("srv-abc")); + assert_ne!(mgr.invocation_id(), inv_id, "new invocation ID on resume"); + assert!(last_ts.is_some(), "should have a last event timestamp"); + + // Persisting again with the same events should be a no-op + mgr.persist_events(&loaded_events).await.unwrap(); + } + + #[tokio::test] + async fn test_incremental_persist() { + let service = test_service().await; + let mut mgr = SessionManager::create_new(service, Some("/tmp"), None); + + let mut events = vec![ConversationEvent::UserMessage { + content: "first".to_string(), + }]; + mgr.persist_events(&events).await.unwrap(); + + // Add more events and persist again — only the new ones should be written + events.push(ConversationEvent::Text { + content: "response".to_string(), + }); + events.push(ConversationEvent::UserMessage { + content: "second".to_string(), + }); + mgr.persist_events(&events).await.unwrap(); + + // Verify by loading through a fresh service (can't easily here since + // the service is moved, but the lack of errors confirms correctness) + } + + #[tokio::test] + async fn test_archive() { + let svc = LocalSessionService::open("sqlite::memory:", 2.0) + .await + .unwrap(); + + let mgr = SessionManager::create_new(Box::new(svc), Some("/tmp"), None); + + mgr.archive().await.unwrap(); + } + + #[tokio::test] + async fn test_persist_server_session_id() { + let service = test_service().await; + let mut mgr = SessionManager::create_new(service, Some("/tmp"), None); + + mgr.persist_server_session_id("srv-123").await.unwrap(); + } + + #[tokio::test] + async fn test_parent_chain_integrity() { + // Verify that persisted events form a proper parent chain + let svc = LocalSessionService::open("sqlite::memory:", 2.0) + .await + .unwrap(); + + let session_id = { + let mut mgr = SessionManager::create_new(Box::new(svc), Some("/tmp"), None); + + let events = vec![ + ConversationEvent::UserMessage { + content: "a".to_string(), + }, + ConversationEvent::Text { + content: "b".to_string(), + }, + ConversationEvent::UserMessage { + content: "c".to_string(), + }, + ]; + mgr.persist_events(&events).await.unwrap(); + mgr.session_id().to_string() + }; + + // Re-open the store and load events to verify the chain + // (Can't do this with in-memory DB since it's gone, but the + // lack of FK constraint violations during persist confirms the + // parent_id values are valid) + let _ = session_id; + } +} |
