diff options
Diffstat (limited to 'crates/atuin-ai/src/tui/dispatch.rs')
| -rw-r--r-- | crates/atuin-ai/src/tui/dispatch.rs | 240 |
1 files changed, 104 insertions, 136 deletions
diff --git a/crates/atuin-ai/src/tui/dispatch.rs b/crates/atuin-ai/src/tui/dispatch.rs index ee2bbe74..ea895c01 100644 --- a/crates/atuin-ai/src/tui/dispatch.rs +++ b/crates/atuin-ai/src/tui/dispatch.rs @@ -1,4 +1,6 @@ use std::path::PathBuf; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::mpsc; use crate::context::{AppContext, ClientContext}; @@ -15,64 +17,55 @@ use crate::tui::state::{ConversationEvent, ExitAction, Session}; use eye_declare::Handle; use tokio::task::JoinHandle; -pub(crate) fn dispatch( - handle: &Handle<Session>, - event: AiTuiEvent, - tx: &mpsc::Sender<AiTuiEvent>, - app_ctx: &AppContext, - client_ctx: &ClientContext, - session_mgr: &mut SessionManager, -) { +/// Shared context for the dispatch loop. Bundles the references every +/// handler might need so `dispatch` doesn't forward a different subset +/// to each one. +pub(crate) struct DispatchContext<'a> { + pub handle: &'a Handle<Session>, + pub tx: &'a mpsc::Sender<AiTuiEvent>, + pub app_ctx: &'a AppContext, + pub client_ctx: &'a ClientContext, + pub session_mgr: &'a mut SessionManager, + /// Set by any handler that calls `h.exit()`. Read by `dispatch()` + /// to break the loop — without round-tripping through the handle, + /// which would hang if the TUI has already stopped. + pub exiting: Arc<AtomicBool>, +} + +/// Dispatch a single event. Returns `true` to keep the loop running, +/// `false` to shut down (after the final persist has completed). +pub(crate) fn dispatch(ctx: &mut DispatchContext, event: AiTuiEvent) -> bool { match event { - AiTuiEvent::ContinueAfterTools => { - on_continue_after_tools(handle, tx, app_ctx, client_ctx); - } - AiTuiEvent::InputUpdated(input) => { - on_input_updated(handle, input); - } - AiTuiEvent::SubmitInput(input) => { - on_submit_input(handle, tx, app_ctx, client_ctx, input, session_mgr); - } - AiTuiEvent::SlashCommand(cmd) => { - on_slash_command(handle, cmd); - } - AiTuiEvent::CheckToolCallPermission(id) => { - on_check_tool_permission(handle, tx, app_ctx, id); - } - AiTuiEvent::SelectPermission(result) => { - on_select_permission(handle, tx, app_ctx, result); - } - AiTuiEvent::CancelGeneration => { - on_cancel_generation(handle); - } - AiTuiEvent::ExecuteCommand => { - on_execute_command(handle); - } - AiTuiEvent::CancelConfirmation => { - on_cancel_confirmation(handle); - } - AiTuiEvent::InterruptToolExecution => { - on_interrupt_tool_execution(handle); - } - AiTuiEvent::InsertCommand => { - on_insert_command(handle); - } - AiTuiEvent::Retry => { - on_retry(handle, tx, app_ctx, client_ctx); - } - AiTuiEvent::Exit => { - on_exit(handle); - } + AiTuiEvent::ContinueAfterTools => on_continue_after_tools(ctx), + AiTuiEvent::InputUpdated(input) => on_input_updated(ctx, input), + AiTuiEvent::SubmitInput(input) => on_submit_input(ctx, input), + AiTuiEvent::SlashCommand(cmd) => on_slash_command(ctx, cmd), + AiTuiEvent::CheckToolCallPermission(id) => on_check_tool_permission(ctx, id), + AiTuiEvent::SelectPermission(result) => on_select_permission(ctx, result), + AiTuiEvent::CancelGeneration => on_cancel_generation(ctx), + AiTuiEvent::ExecuteCommand => on_execute_command(ctx), + AiTuiEvent::CancelConfirmation => on_cancel_confirmation(ctx), + AiTuiEvent::InterruptToolExecution => on_interrupt_tool_execution(ctx), + AiTuiEvent::InsertCommand => on_insert_command(ctx), + AiTuiEvent::Retry => on_retry(ctx), + AiTuiEvent::Exit => on_exit(ctx), } // Persist any new conversation events after each dispatch cycle. - persist_session(handle, session_mgr); + persist_session(ctx); + + // The exiting flag is set by any handler that calls h.exit(). We + // read it here rather than querying state through the handle, + // because the TUI thread may have already stopped processing + // handle requests by this point. + !ctx.exiting.load(Ordering::Acquire) } /// Persist new events and the server session ID if it has changed. /// Called from the dispatch thread (sync), bridges to async via the tokio handle. -fn persist_session(handle: &Handle<Session>, session_mgr: &mut SessionManager) { - let Ok((events, server_sid)) = handle +fn persist_session(ctx: &mut DispatchContext) { + let Ok((events, server_sid)) = ctx + .handle .fetch(|state| { ( state.conversation.events.clone(), @@ -85,29 +78,23 @@ fn persist_session(handle: &Handle<Session>, session_mgr: &mut SessionManager) { }; let rt = tokio::runtime::Handle::current(); - if let Err(e) = rt.block_on(session_mgr.persist_events(&events)) { + if let Err(e) = rt.block_on(ctx.session_mgr.persist_events(&events)) { tracing::warn!("failed to persist session events: {e}"); } if let Some(ref sid) = server_sid - && let Err(e) = rt.block_on(session_mgr.persist_server_session_id(sid)) + && let Err(e) = rt.block_on(ctx.session_mgr.persist_server_session_id(sid)) { tracing::warn!("failed to persist server session ID: {e}"); } } -fn launch_stream( - handle: &Handle<Session>, - tx: &mpsc::Sender<AiTuiEvent>, - app_ctx: &AppContext, - client_ctx: &ClientContext, - setup: impl FnOnce(&mut Session) + Send + 'static, -) { - let h2 = handle.clone(); - let tx2 = tx.clone(); - let app = app_ctx.clone(); - let cc = client_ctx.clone(); - let caps = app_ctx.capabilities.clone(); - handle.update(move |state| { +fn launch_stream(ctx: &DispatchContext, setup: impl FnOnce(&mut Session) + Send + 'static) { + let h2 = ctx.handle.clone(); + let tx2 = ctx.tx.clone(); + let app = ctx.app_ctx.clone(); + let cc = ctx.client_ctx.clone(); + let caps = ctx.app_ctx.capabilities.clone(); + ctx.handle.update(move |state| { (setup)(state); state.start_streaming(); let messages = @@ -121,16 +108,11 @@ fn launch_stream( }); } -fn on_continue_after_tools( - handle: &Handle<Session>, - tx: &mpsc::Sender<AiTuiEvent>, - app_ctx: &AppContext, - client_ctx: &ClientContext, -) { - launch_stream(handle, tx, app_ctx, client_ctx, |_state| {}); +fn on_continue_after_tools(ctx: &mut DispatchContext) { + launch_stream(ctx, |_state| {}); } -fn on_input_updated(handle: &Handle<Session>, input: String) { +fn on_input_updated(ctx: &mut DispatchContext, input: String) { let input_blank = input.is_empty(); let slash_command = if input.starts_with('/') { Some(input.trim_start_matches('/').to_string()) @@ -138,7 +120,7 @@ fn on_input_updated(handle: &Handle<Session>, input: String) { None }; - handle.update(move |state| { + ctx.handle.update(move |state| { state.interaction.is_input_blank = input_blank; state.interaction.slash_command_input = slash_command; @@ -158,23 +140,17 @@ fn on_input_updated(handle: &Handle<Session>, input: String) { }); } -fn on_submit_input( - handle: &Handle<Session>, - tx: &mpsc::Sender<AiTuiEvent>, - app_ctx: &AppContext, - client_ctx: &ClientContext, - input: String, - session_mgr: &mut SessionManager, -) { - handle.update(move |state| { +fn on_submit_input(ctx: &mut DispatchContext, input: String) { + ctx.handle.update(move |state| { state.interaction.slash_command_input = None; state.interaction.slash_command_search_results.clear(); }); let input = input.trim().to_string(); if input.is_empty() { - let h2 = handle.clone(); - handle.update(move |state| { + let h2 = ctx.handle.clone(); + let exiting = ctx.exiting.clone(); + ctx.handle.update(move |state| { if state.conversation.has_any_command() { state.exit_action = Some(ExitAction::Execute( state.conversation.current_command().unwrap().to_string(), @@ -182,6 +158,7 @@ fn on_submit_input( } else { state.exit_action = Some(ExitAction::Cancel); } + exiting.store(true, Ordering::Release); h2.exit(); }); return; @@ -189,9 +166,9 @@ fn on_submit_input( if input.starts_with('/') { if input.trim() == "/new" { - on_new_session(handle, session_mgr); + on_new_session(ctx); } else { - handle.update(move |state| { + ctx.handle.update(move |state| { state .conversation .handle_slash_command(&input, &state.slash_registry); @@ -201,14 +178,14 @@ fn on_submit_input( } // Start generation and spawn streaming task - launch_stream(handle, tx, app_ctx, client_ctx, |state| { + launch_stream(ctx, |state| { state.start_generating(input); state.interaction.is_input_blank = true; }); } -fn on_slash_command(handle: &Handle<Session>, command: String) { - handle.update(move |state| { +fn on_slash_command(ctx: &mut DispatchContext, command: String) { + ctx.handle.update(move |state| { state .conversation .handle_slash_command(&command, &state.slash_registry); @@ -330,15 +307,10 @@ fn execute_shell_tool( // Permission handlers // ─────────────────────────────────────────────────────────────────── -fn on_check_tool_permission( - handle: &Handle<Session>, - tx: &mpsc::Sender<AiTuiEvent>, - app_ctx: &AppContext, - id: String, -) { - let h2 = handle.clone(); - let tx_for_task = tx.clone(); - let db = app_ctx.history_db.clone(); +fn on_check_tool_permission(ctx: &mut DispatchContext, id: String) { + let h2 = ctx.handle.clone(); + let tx_for_task = ctx.tx.clone(); + let db = ctx.app_ctx.history_db.clone(); tokio::spawn(async move { let id_for_error = id.clone(); @@ -427,19 +399,14 @@ async fn check_tool_permission_inner( Ok(()) } -fn on_select_permission( - handle: &Handle<Session>, - tx: &mpsc::Sender<AiTuiEvent>, - app_ctx: &AppContext, - permission: PermissionResult, -) { - let tx = tx.clone(); - let h2 = handle.clone(); +fn on_select_permission(ctx: &mut DispatchContext, permission: PermissionResult) { + let tx = ctx.tx.clone(); + let h2 = ctx.handle.clone(); match permission { PermissionResult::Allow => { // Fetch the tool that's asking for permission, then execute it - let db = app_ctx.history_db.clone(); + let db = ctx.app_ctx.history_db.clone(); tokio::spawn(async move { let Ok(Some((tool_id, tool))) = h2 .fetch(move |state| { @@ -457,8 +424,8 @@ fn on_select_permission( }); } PermissionResult::AlwaysAllowInDir => { - let db = app_ctx.history_db.clone(); - let git_root = app_ctx.git_root.clone(); + let db = ctx.app_ctx.history_db.clone(); + let git_root = ctx.app_ctx.git_root.clone(); tokio::spawn(async move { let Ok(Some((tool_id, tool))) = h2 .fetch(move |state| { @@ -490,7 +457,7 @@ fn on_select_permission( }); } PermissionResult::AlwaysAllow => { - let db = app_ctx.history_db.clone(); + let db = ctx.app_ctx.history_db.clone(); tokio::spawn(async move { let Ok(Some((tool_id, tool))) = h2 .fetch(move |state| { @@ -541,8 +508,8 @@ fn on_select_permission( // Other handlers // ─────────────────────────────────────────────────────────────────── -fn on_cancel_generation(handle: &Handle<Session>) { - handle.update(|state| match state.interaction.mode { +fn on_cancel_generation(ctx: &mut DispatchContext) { + ctx.handle.update(|state| match state.interaction.mode { crate::tui::state::AppMode::Generating => { state.cancel_generation(); } @@ -553,9 +520,10 @@ fn on_cancel_generation(handle: &Handle<Session>) { }); } -fn on_execute_command(handle: &Handle<Session>) { - let h2 = handle.clone(); - handle.update(move |state| { +fn on_execute_command(ctx: &mut DispatchContext) { + let h2 = ctx.handle.clone(); + let exiting = ctx.exiting.clone(); + ctx.handle.update(move |state| { let cmd = state.conversation.current_command().map(|c| c.to_string()); if let Some(cmd) = cmd { if state.conversation.is_current_command_dangerous() @@ -565,50 +533,48 @@ fn on_execute_command(handle: &Handle<Session>) { } else { state.interaction.confirmation_pending = false; state.exit_action = Some(ExitAction::Execute(cmd)); + exiting.store(true, Ordering::Release); h2.exit(); } } }); } -fn on_cancel_confirmation(handle: &Handle<Session>) { - handle.update(move |state| { +fn on_cancel_confirmation(ctx: &mut DispatchContext) { + ctx.handle.update(move |state| { state.interaction.confirmation_pending = false; }); } -fn on_insert_command(handle: &Handle<Session>) { - let h2 = handle.clone(); - handle.update(move |state| { +fn on_insert_command(ctx: &mut DispatchContext) { + let h2 = ctx.handle.clone(); + let exiting = ctx.exiting.clone(); + ctx.handle.update(move |state| { let cmd = state.conversation.current_command().map(|c| c.to_string()); if let Some(cmd) = cmd { state.interaction.confirmation_pending = false; state.exit_action = Some(ExitAction::Insert(cmd)); + exiting.store(true, Ordering::Release); h2.exit(); } }); } -fn on_retry( - handle: &Handle<Session>, - tx: &mpsc::Sender<AiTuiEvent>, - app_ctx: &AppContext, - client_ctx: &ClientContext, -) { - launch_stream(handle, tx, app_ctx, client_ctx, |state| { +fn on_retry(ctx: &mut DispatchContext) { + launch_stream(ctx, |state| { state.retry(); }); } -fn on_new_session(handle: &Handle<Session>, session_mgr: &mut SessionManager) { +fn on_new_session(ctx: &mut DispatchContext) { let rt = tokio::runtime::Handle::current(); - if let Err(e) = rt.block_on(session_mgr.archive_and_reset()) { + if let Err(e) = rt.block_on(ctx.session_mgr.archive_and_reset()) { tracing::warn!("failed to start new session: {e}"); return; } - handle.update(|state| { + ctx.handle.update(|state| { // Move the current invocation's visible events to the archived view // so they remain on screen but are no longer sent to the API. let visible_events: Vec<ConversationEvent> = @@ -632,19 +598,21 @@ fn on_new_session(handle: &Handle<Session>, session_mgr: &mut SessionManager) { }); } -fn on_exit(handle: &Handle<Session>) { - let h2 = handle.clone(); - handle.update(move |state| { +fn on_exit(ctx: &mut DispatchContext) { + let h2 = ctx.handle.clone(); + let exiting = ctx.exiting.clone(); + ctx.handle.update(move |state| { if let Some(abort) = state.stream_abort.take() { abort.abort(); } state.exit_action = Some(ExitAction::Cancel); + exiting.store(true, Ordering::Release); h2.exit(); }); } -fn on_interrupt_tool_execution(handle: &Handle<Session>) { - handle.update(move |state| { +fn on_interrupt_tool_execution(ctx: &mut DispatchContext) { + ctx.handle.update(move |state| { // Find executing previews, send interrupt, and mark as interrupted for tracked in state.tool_tracker.iter_mut() { if let ToolPhase::ExecutingWithPreview { |
