diff options
Diffstat (limited to 'crates/atuin-ai/src/driver.rs')
| -rw-r--r-- | crates/atuin-ai/src/driver.rs | 150 |
1 files changed, 138 insertions, 12 deletions
diff --git a/crates/atuin-ai/src/driver.rs b/crates/atuin-ai/src/driver.rs index b5e1c275..1285f2da 100644 --- a/crates/atuin-ai/src/driver.rs +++ b/crates/atuin-ai/src/driver.rs @@ -55,6 +55,7 @@ pub(crate) struct IoContext { pub file_tracker: FileReadTracker, pub edit_permissions: EditPermissionCache, pub snapshot_store: Option<crate::snapshots::SnapshotStore>, + pub skill_registry: crate::skills::SkillRegistry, } // ============================================================================ @@ -87,6 +88,7 @@ pub(crate) struct ViewState { pub slash_command_search_results: Vec<crate::tui::slash::SlashCommandSearchResult>, pub exit_action: Option<ExitAction>, pub slash_registry: crate::tui::slash::SlashCommandRegistry, + pub skill_names: std::collections::HashSet<String>, } impl ViewState { @@ -253,11 +255,18 @@ fn translate_tui_event(event: AiTuiEvent, handle: &Handle<ViewState>) -> Option< } else if input == "/new" { Some(Event::NewSession) } else if input.starts_with('/') { - let content = resolve_slash_command(&input, handle); - Some(Event::SlashCommand { - command: input, - content, - }) + if let Some((skill_name, arguments)) = resolve_skill_name(&input, handle) { + Some(Event::RequestSkillLoad { + name: skill_name, + arguments, + }) + } else { + let content = resolve_slash_command(&input, handle); + Some(Event::SlashCommand { + command: input, + content, + }) + } } else { Some(Event::UserSubmit(input)) } @@ -295,7 +304,9 @@ fn translate_tui_event(event: AiTuiEvent, handle: &Handle<ViewState>) -> Option< .fetch(|vs| vs.tools.awaiting_permission().map(|t| t.id.clone())) .blocking_recv() .ok() - .flatten()?; + .flatten(); + + let tool_id = tool_id?; let choice = match result { PermissionResult::Allow => PermissionChoice::Allow, @@ -307,16 +318,50 @@ fn translate_tui_event(event: AiTuiEvent, handle: &Handle<ViewState>) -> Option< Some(Event::PermissionUserChoice { tool_id, choice }) } AiTuiEvent::SlashCommand(cmd) => { - let content = resolve_slash_command(&cmd, handle); - Some(Event::SlashCommand { - command: cmd, - content, - }) + if let Some((skill_name, arguments)) = resolve_skill_name(&cmd, handle) { + Some(Event::RequestSkillLoad { + name: skill_name, + arguments, + }) + } else { + let content = resolve_slash_command(&cmd, handle); + Some(Event::SlashCommand { + command: cmd, + content, + }) + } } } } /// Resolve a slash command to its output content. +/// If the input starts with `/`, check whether the command name matches a +/// registered skill. Returns `Some((skill_name, arguments))` if it does. +fn resolve_skill_name(input: &str, handle: &Handle<ViewState>) -> Option<(String, Option<String>)> { + let after_slash = input.trim_start_matches('/'); + let cmd_name = after_slash.split_whitespace().next()?.to_string(); + + let is_skill = handle + .fetch({ + let cmd_name = cmd_name.clone(); + move |vs| vs.skill_names.contains(&cmd_name) + }) + .blocking_recv() + .unwrap_or(false); + + if !is_skill { + return None; + } + + let args = after_slash + .strip_prefix(&cmd_name) + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .map(|s| s.to_string()); + + Some((cmd_name, args)) +} + fn resolve_slash_command(command: &str, handle: &Handle<ViewState>) -> String { match command.trim() { "/help" => { @@ -406,6 +451,7 @@ fn execute_effect(effect: &Effect, ctx: DriverContext) { let tx = tx.clone(); let app = io.app_ctx.clone(); let cc = io.client_ctx.clone(); + let (skill_summaries, skill_overflow) = io.skill_registry.server_skills(); let request = ChatRequest::new( messages.clone(), session_id.clone(), @@ -413,7 +459,16 @@ fn execute_effect(effect: &Effect, ctx: DriverContext) { fsm.ctx.invocation_id.clone(), ); tokio::spawn(async move { - run_stream_bridge(request, app, cc, tx, cancel_rx).await; + run_stream_bridge( + request, + app, + cc, + tx, + cancel_rx, + skill_summaries, + skill_overflow, + ) + .await; }); } @@ -427,6 +482,16 @@ fn execute_effect(effect: &Effect, ctx: DriverContext) { let tool_id = tool_id.clone(); let tool = tool.clone(); let tx = tx.clone(); + + // Auto-approved tools (e.g. load_skill) bypass permission checks entirely + if tool.is_auto_approved() { + let _ = tx.send(DriverEvent::Fsm(Event::PermissionResolved { + tool_id, + response: PermissionResponse::Allowed, + })); + return; + } + let working_dir = tool .target_dir() .map(|p| p.to_path_buf()) @@ -641,9 +706,50 @@ fn execute_effect(effect: &Effect, ctx: DriverContext) { })); }); } + ClientToolCall::LoadSkill(skill_call) => { + let skill_name = skill_call.name.clone(); + let registry = io.skill_registry.clone(); + let shell = io + .client_ctx + .shell + .clone() + .unwrap_or_else(|| "sh".to_string()); + + tokio::spawn(async move { + let content = + load_skill_content(®istry, &skill_name, &shell, None).await; + let outcome = crate::tools::ToolOutcome::Success(content); + let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone { + tool_id, + outcome, + preview: None, + })); + }); + } } } + Effect::LoadSkill { name, arguments } => { + let name = name.clone(); + let arguments = arguments.clone(); + let registry = io.skill_registry.clone(); + let shell = io + .client_ctx + .shell + .clone() + .unwrap_or_else(|| "sh".to_string()); + let tx = tx.clone(); + tokio::spawn(async move { + let content = + load_skill_content(®istry, &name, &shell, arguments.as_deref()).await; + let _ = tx.send(DriverEvent::Fsm(Event::SkillLoaded { + name, + arguments, + content, + })); + }); + } + Effect::AbortTool { tool_id } => { if let Some(abort_tx) = tool_abort_txs.remove(tool_id) { let _ = abort_tx.send(()); @@ -762,6 +868,22 @@ fn persist(fsm: &AgentFsm, io: &mut IoContext) { } // ============================================================================ +// Skill loading +// ============================================================================ + +async fn load_skill_content( + registry: &crate::skills::SkillRegistry, + name: &str, + shell: &str, + arguments: Option<&str>, +) -> String { + match registry.load(name, shell, arguments).await { + Ok(body) => body, + Err(e) => format!("Failed to load skill '{name}': {e}"), + } +} + +// ============================================================================ // Stream bridge // ============================================================================ @@ -771,6 +893,8 @@ async fn run_stream_bridge( client_ctx: ClientContext, tx: mpsc::Sender<DriverEvent>, mut cancel_rx: tokio::sync::watch::Receiver<()>, + skill_summaries: Vec<crate::skills::SkillSummary>, + skill_overflow: Option<String>, ) { use crate::stream::{StreamContent, StreamControl, StreamFrame, create_chat_stream}; use futures::StreamExt; @@ -790,6 +914,8 @@ async fn run_stream_bridge( app_ctx.send_cwd, app_ctx.last_command.clone(), user_contexts, + skill_summaries, + skill_overflow, ); futures::pin_mut!(stream); |
