aboutsummaryrefslogtreecommitdiffstats
path: root/crates/atuin-ai/src/driver.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/atuin-ai/src/driver.rs')
-rw-r--r--crates/atuin-ai/src/driver.rs150
1 files changed, 138 insertions, 12 deletions
diff --git a/crates/atuin-ai/src/driver.rs b/crates/atuin-ai/src/driver.rs
index b5e1c275..1285f2da 100644
--- a/crates/atuin-ai/src/driver.rs
+++ b/crates/atuin-ai/src/driver.rs
@@ -55,6 +55,7 @@ pub(crate) struct IoContext {
pub file_tracker: FileReadTracker,
pub edit_permissions: EditPermissionCache,
pub snapshot_store: Option<crate::snapshots::SnapshotStore>,
+ pub skill_registry: crate::skills::SkillRegistry,
}
// ============================================================================
@@ -87,6 +88,7 @@ pub(crate) struct ViewState {
pub slash_command_search_results: Vec<crate::tui::slash::SlashCommandSearchResult>,
pub exit_action: Option<ExitAction>,
pub slash_registry: crate::tui::slash::SlashCommandRegistry,
+ pub skill_names: std::collections::HashSet<String>,
}
impl ViewState {
@@ -253,11 +255,18 @@ fn translate_tui_event(event: AiTuiEvent, handle: &Handle<ViewState>) -> Option<
} else if input == "/new" {
Some(Event::NewSession)
} else if input.starts_with('/') {
- let content = resolve_slash_command(&input, handle);
- Some(Event::SlashCommand {
- command: input,
- content,
- })
+ if let Some((skill_name, arguments)) = resolve_skill_name(&input, handle) {
+ Some(Event::RequestSkillLoad {
+ name: skill_name,
+ arguments,
+ })
+ } else {
+ let content = resolve_slash_command(&input, handle);
+ Some(Event::SlashCommand {
+ command: input,
+ content,
+ })
+ }
} else {
Some(Event::UserSubmit(input))
}
@@ -295,7 +304,9 @@ fn translate_tui_event(event: AiTuiEvent, handle: &Handle<ViewState>) -> Option<
.fetch(|vs| vs.tools.awaiting_permission().map(|t| t.id.clone()))
.blocking_recv()
.ok()
- .flatten()?;
+ .flatten();
+
+ let tool_id = tool_id?;
let choice = match result {
PermissionResult::Allow => PermissionChoice::Allow,
@@ -307,16 +318,50 @@ fn translate_tui_event(event: AiTuiEvent, handle: &Handle<ViewState>) -> Option<
Some(Event::PermissionUserChoice { tool_id, choice })
}
AiTuiEvent::SlashCommand(cmd) => {
- let content = resolve_slash_command(&cmd, handle);
- Some(Event::SlashCommand {
- command: cmd,
- content,
- })
+ if let Some((skill_name, arguments)) = resolve_skill_name(&cmd, handle) {
+ Some(Event::RequestSkillLoad {
+ name: skill_name,
+ arguments,
+ })
+ } else {
+ let content = resolve_slash_command(&cmd, handle);
+ Some(Event::SlashCommand {
+ command: cmd,
+ content,
+ })
+ }
}
}
}
/// Resolve a slash command to its output content.
+/// If the input starts with `/`, check whether the command name matches a
+/// registered skill. Returns `Some((skill_name, arguments))` if it does.
+fn resolve_skill_name(input: &str, handle: &Handle<ViewState>) -> Option<(String, Option<String>)> {
+ let after_slash = input.trim_start_matches('/');
+ let cmd_name = after_slash.split_whitespace().next()?.to_string();
+
+ let is_skill = handle
+ .fetch({
+ let cmd_name = cmd_name.clone();
+ move |vs| vs.skill_names.contains(&cmd_name)
+ })
+ .blocking_recv()
+ .unwrap_or(false);
+
+ if !is_skill {
+ return None;
+ }
+
+ let args = after_slash
+ .strip_prefix(&cmd_name)
+ .map(|s| s.trim())
+ .filter(|s| !s.is_empty())
+ .map(|s| s.to_string());
+
+ Some((cmd_name, args))
+}
+
fn resolve_slash_command(command: &str, handle: &Handle<ViewState>) -> String {
match command.trim() {
"/help" => {
@@ -406,6 +451,7 @@ fn execute_effect(effect: &Effect, ctx: DriverContext) {
let tx = tx.clone();
let app = io.app_ctx.clone();
let cc = io.client_ctx.clone();
+ let (skill_summaries, skill_overflow) = io.skill_registry.server_skills();
let request = ChatRequest::new(
messages.clone(),
session_id.clone(),
@@ -413,7 +459,16 @@ fn execute_effect(effect: &Effect, ctx: DriverContext) {
fsm.ctx.invocation_id.clone(),
);
tokio::spawn(async move {
- run_stream_bridge(request, app, cc, tx, cancel_rx).await;
+ run_stream_bridge(
+ request,
+ app,
+ cc,
+ tx,
+ cancel_rx,
+ skill_summaries,
+ skill_overflow,
+ )
+ .await;
});
}
@@ -427,6 +482,16 @@ fn execute_effect(effect: &Effect, ctx: DriverContext) {
let tool_id = tool_id.clone();
let tool = tool.clone();
let tx = tx.clone();
+
+ // Auto-approved tools (e.g. load_skill) bypass permission checks entirely
+ if tool.is_auto_approved() {
+ let _ = tx.send(DriverEvent::Fsm(Event::PermissionResolved {
+ tool_id,
+ response: PermissionResponse::Allowed,
+ }));
+ return;
+ }
+
let working_dir = tool
.target_dir()
.map(|p| p.to_path_buf())
@@ -641,9 +706,50 @@ fn execute_effect(effect: &Effect, ctx: DriverContext) {
}));
});
}
+ ClientToolCall::LoadSkill(skill_call) => {
+ let skill_name = skill_call.name.clone();
+ let registry = io.skill_registry.clone();
+ let shell = io
+ .client_ctx
+ .shell
+ .clone()
+ .unwrap_or_else(|| "sh".to_string());
+
+ tokio::spawn(async move {
+ let content =
+ load_skill_content(&registry, &skill_name, &shell, None).await;
+ let outcome = crate::tools::ToolOutcome::Success(content);
+ let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone {
+ tool_id,
+ outcome,
+ preview: None,
+ }));
+ });
+ }
}
}
+ Effect::LoadSkill { name, arguments } => {
+ let name = name.clone();
+ let arguments = arguments.clone();
+ let registry = io.skill_registry.clone();
+ let shell = io
+ .client_ctx
+ .shell
+ .clone()
+ .unwrap_or_else(|| "sh".to_string());
+ let tx = tx.clone();
+ tokio::spawn(async move {
+ let content =
+ load_skill_content(&registry, &name, &shell, arguments.as_deref()).await;
+ let _ = tx.send(DriverEvent::Fsm(Event::SkillLoaded {
+ name,
+ arguments,
+ content,
+ }));
+ });
+ }
+
Effect::AbortTool { tool_id } => {
if let Some(abort_tx) = tool_abort_txs.remove(tool_id) {
let _ = abort_tx.send(());
@@ -762,6 +868,22 @@ fn persist(fsm: &AgentFsm, io: &mut IoContext) {
}
// ============================================================================
+// Skill loading
+// ============================================================================
+
+async fn load_skill_content(
+ registry: &crate::skills::SkillRegistry,
+ name: &str,
+ shell: &str,
+ arguments: Option<&str>,
+) -> String {
+ match registry.load(name, shell, arguments).await {
+ Ok(body) => body,
+ Err(e) => format!("Failed to load skill '{name}': {e}"),
+ }
+}
+
+// ============================================================================
// Stream bridge
// ============================================================================
@@ -771,6 +893,8 @@ async fn run_stream_bridge(
client_ctx: ClientContext,
tx: mpsc::Sender<DriverEvent>,
mut cancel_rx: tokio::sync::watch::Receiver<()>,
+ skill_summaries: Vec<crate::skills::SkillSummary>,
+ skill_overflow: Option<String>,
) {
use crate::stream::{StreamContent, StreamControl, StreamFrame, create_chat_stream};
use futures::StreamExt;
@@ -790,6 +914,8 @@ async fn run_stream_bridge(
app_ctx.send_cwd,
app_ctx.last_command.clone(),
user_contexts,
+ skill_summaries,
+ skill_overflow,
);
futures::pin_mut!(stream);