diff options
Diffstat (limited to 'crates/atuin-ai/src')
| -rw-r--r-- | crates/atuin-ai/src/commands/inline.rs | 30 | ||||
| -rw-r--r-- | crates/atuin-ai/src/context.rs | 1 | ||||
| -rw-r--r-- | crates/atuin-ai/src/driver.rs | 150 | ||||
| -rw-r--r-- | crates/atuin-ai/src/event_serde.rs | 21 | ||||
| -rw-r--r-- | crates/atuin-ai/src/fsm/effects.rs | 5 | ||||
| -rw-r--r-- | crates/atuin-ai/src/fsm/events.rs | 17 | ||||
| -rw-r--r-- | crates/atuin-ai/src/fsm/mod.rs | 50 | ||||
| -rw-r--r-- | crates/atuin-ai/src/lib.rs | 1 | ||||
| -rw-r--r-- | crates/atuin-ai/src/skills/frontmatter.rs | 233 | ||||
| -rw-r--r-- | crates/atuin-ai/src/skills/mod.rs | 468 | ||||
| -rw-r--r-- | crates/atuin-ai/src/skills/walker.rs | 178 | ||||
| -rw-r--r-- | crates/atuin-ai/src/stream.rs | 15 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tools/descriptor.rs | 10 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tools/mod.rs | 54 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tui/state.rs | 24 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tui/view/mod.rs | 30 | ||||
| -rw-r--r-- | crates/atuin-ai/src/tui/view/turn.rs | 14 | ||||
| -rw-r--r-- | crates/atuin-ai/src/user_context/mod.rs | 2 |
18 files changed, 1273 insertions, 30 deletions
diff --git a/crates/atuin-ai/src/commands/inline.rs b/crates/atuin-ai/src/commands/inline.rs index adedc542..70f26c65 100644 --- a/crates/atuin-ai/src/commands/inline.rs +++ b/crates/atuin-ai/src/commands/inline.rs @@ -247,8 +247,15 @@ async fn run_inline_tui( let in_git_project = ctx.git_root.is_some(); + // ─── Discover skills ─────────────────────────────────────── + let project_root = ctx + .git_root + .clone() + .or_else(|| std::env::current_dir().ok()); + let skill_registry = crate::skills::SkillRegistry::discover(project_root.as_deref()).await; + // ─── Build initial ViewState from FSM ─────────────────────── - let initial_view = build_view_state(&fsm, in_git_project); + let initial_view = build_view_state(&fsm, in_git_project, &skill_registry); // ─── Build IoContext ──────────────────────────────────────── let io = IoContext { @@ -258,6 +265,7 @@ async fn run_inline_tui( file_tracker, edit_permissions, snapshot_store, + skill_registry, }; // ─── Channel + Application ────────────────────────────────── @@ -324,8 +332,23 @@ impl DriverEventSender { /// Build a ViewState snapshot from FSM state. Used for the initial view /// and by the driver for ongoing sync. -fn build_view_state(fsm: &AgentFsm, in_git_project: bool) -> ViewState { +fn build_view_state( + fsm: &AgentFsm, + in_git_project: bool, + skill_registry: &crate::skills::SkillRegistry, +) -> ViewState { let safe_start = fsm.ctx.view_start_index.min(fsm.ctx.events.len()); + + let mut slash_registry = crate::tui::slash::SlashCommandRegistry::default(); + let mut skill_names = std::collections::HashSet::new(); + for skill in skill_registry.all() { + slash_registry.register(crate::tui::slash::SlashCommand::new( + &skill.name, + &skill.description, + )); + skill_names.insert(skill.name.clone()); + } + ViewState { agent_state: fsm.state.clone(), visible_events: fsm.ctx.events[safe_start..].to_vec(), @@ -341,7 +364,8 @@ fn build_view_state(fsm: &AgentFsm, in_git_project: bool) -> ViewState { slash_command_input: None, slash_command_search_results: Vec::new(), exit_action: None, - slash_registry: Default::default(), + slash_registry, + skill_names, } } diff --git a/crates/atuin-ai/src/context.rs b/crates/atuin-ai/src/context.rs index 93fcf9b9..625de0c6 100644 --- a/crates/atuin-ai/src/context.rs +++ b/crates/atuin-ai/src/context.rs @@ -33,6 +33,7 @@ impl AppContext { if self.capabilities.enable_command_execution.unwrap_or(true) { caps.push("client_v1_execute_shell_command".to_string()); } + caps.push("client_v1_load_skill".to_string()); if let Ok(extra) = std::env::var("ATUIN_AI__ADDITIONAL_CAPS") { caps.extend( extra diff --git a/crates/atuin-ai/src/driver.rs b/crates/atuin-ai/src/driver.rs index b5e1c275..1285f2da 100644 --- a/crates/atuin-ai/src/driver.rs +++ b/crates/atuin-ai/src/driver.rs @@ -55,6 +55,7 @@ pub(crate) struct IoContext { pub file_tracker: FileReadTracker, pub edit_permissions: EditPermissionCache, pub snapshot_store: Option<crate::snapshots::SnapshotStore>, + pub skill_registry: crate::skills::SkillRegistry, } // ============================================================================ @@ -87,6 +88,7 @@ pub(crate) struct ViewState { pub slash_command_search_results: Vec<crate::tui::slash::SlashCommandSearchResult>, pub exit_action: Option<ExitAction>, pub slash_registry: crate::tui::slash::SlashCommandRegistry, + pub skill_names: std::collections::HashSet<String>, } impl ViewState { @@ -253,11 +255,18 @@ fn translate_tui_event(event: AiTuiEvent, handle: &Handle<ViewState>) -> Option< } else if input == "/new" { Some(Event::NewSession) } else if input.starts_with('/') { - let content = resolve_slash_command(&input, handle); - Some(Event::SlashCommand { - command: input, - content, - }) + if let Some((skill_name, arguments)) = resolve_skill_name(&input, handle) { + Some(Event::RequestSkillLoad { + name: skill_name, + arguments, + }) + } else { + let content = resolve_slash_command(&input, handle); + Some(Event::SlashCommand { + command: input, + content, + }) + } } else { Some(Event::UserSubmit(input)) } @@ -295,7 +304,9 @@ fn translate_tui_event(event: AiTuiEvent, handle: &Handle<ViewState>) -> Option< .fetch(|vs| vs.tools.awaiting_permission().map(|t| t.id.clone())) .blocking_recv() .ok() - .flatten()?; + .flatten(); + + let tool_id = tool_id?; let choice = match result { PermissionResult::Allow => PermissionChoice::Allow, @@ -307,16 +318,50 @@ fn translate_tui_event(event: AiTuiEvent, handle: &Handle<ViewState>) -> Option< Some(Event::PermissionUserChoice { tool_id, choice }) } AiTuiEvent::SlashCommand(cmd) => { - let content = resolve_slash_command(&cmd, handle); - Some(Event::SlashCommand { - command: cmd, - content, - }) + if let Some((skill_name, arguments)) = resolve_skill_name(&cmd, handle) { + Some(Event::RequestSkillLoad { + name: skill_name, + arguments, + }) + } else { + let content = resolve_slash_command(&cmd, handle); + Some(Event::SlashCommand { + command: cmd, + content, + }) + } } } } /// Resolve a slash command to its output content. +/// If the input starts with `/`, check whether the command name matches a +/// registered skill. Returns `Some((skill_name, arguments))` if it does. +fn resolve_skill_name(input: &str, handle: &Handle<ViewState>) -> Option<(String, Option<String>)> { + let after_slash = input.trim_start_matches('/'); + let cmd_name = after_slash.split_whitespace().next()?.to_string(); + + let is_skill = handle + .fetch({ + let cmd_name = cmd_name.clone(); + move |vs| vs.skill_names.contains(&cmd_name) + }) + .blocking_recv() + .unwrap_or(false); + + if !is_skill { + return None; + } + + let args = after_slash + .strip_prefix(&cmd_name) + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .map(|s| s.to_string()); + + Some((cmd_name, args)) +} + fn resolve_slash_command(command: &str, handle: &Handle<ViewState>) -> String { match command.trim() { "/help" => { @@ -406,6 +451,7 @@ fn execute_effect(effect: &Effect, ctx: DriverContext) { let tx = tx.clone(); let app = io.app_ctx.clone(); let cc = io.client_ctx.clone(); + let (skill_summaries, skill_overflow) = io.skill_registry.server_skills(); let request = ChatRequest::new( messages.clone(), session_id.clone(), @@ -413,7 +459,16 @@ fn execute_effect(effect: &Effect, ctx: DriverContext) { fsm.ctx.invocation_id.clone(), ); tokio::spawn(async move { - run_stream_bridge(request, app, cc, tx, cancel_rx).await; + run_stream_bridge( + request, + app, + cc, + tx, + cancel_rx, + skill_summaries, + skill_overflow, + ) + .await; }); } @@ -427,6 +482,16 @@ fn execute_effect(effect: &Effect, ctx: DriverContext) { let tool_id = tool_id.clone(); let tool = tool.clone(); let tx = tx.clone(); + + // Auto-approved tools (e.g. load_skill) bypass permission checks entirely + if tool.is_auto_approved() { + let _ = tx.send(DriverEvent::Fsm(Event::PermissionResolved { + tool_id, + response: PermissionResponse::Allowed, + })); + return; + } + let working_dir = tool .target_dir() .map(|p| p.to_path_buf()) @@ -641,9 +706,50 @@ fn execute_effect(effect: &Effect, ctx: DriverContext) { })); }); } + ClientToolCall::LoadSkill(skill_call) => { + let skill_name = skill_call.name.clone(); + let registry = io.skill_registry.clone(); + let shell = io + .client_ctx + .shell + .clone() + .unwrap_or_else(|| "sh".to_string()); + + tokio::spawn(async move { + let content = + load_skill_content(®istry, &skill_name, &shell, None).await; + let outcome = crate::tools::ToolOutcome::Success(content); + let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone { + tool_id, + outcome, + preview: None, + })); + }); + } } } + Effect::LoadSkill { name, arguments } => { + let name = name.clone(); + let arguments = arguments.clone(); + let registry = io.skill_registry.clone(); + let shell = io + .client_ctx + .shell + .clone() + .unwrap_or_else(|| "sh".to_string()); + let tx = tx.clone(); + tokio::spawn(async move { + let content = + load_skill_content(®istry, &name, &shell, arguments.as_deref()).await; + let _ = tx.send(DriverEvent::Fsm(Event::SkillLoaded { + name, + arguments, + content, + })); + }); + } + Effect::AbortTool { tool_id } => { if let Some(abort_tx) = tool_abort_txs.remove(tool_id) { let _ = abort_tx.send(()); @@ -762,6 +868,22 @@ fn persist(fsm: &AgentFsm, io: &mut IoContext) { } // ============================================================================ +// Skill loading +// ============================================================================ + +async fn load_skill_content( + registry: &crate::skills::SkillRegistry, + name: &str, + shell: &str, + arguments: Option<&str>, +) -> String { + match registry.load(name, shell, arguments).await { + Ok(body) => body, + Err(e) => format!("Failed to load skill '{name}': {e}"), + } +} + +// ============================================================================ // Stream bridge // ============================================================================ @@ -771,6 +893,8 @@ async fn run_stream_bridge( client_ctx: ClientContext, tx: mpsc::Sender<DriverEvent>, mut cancel_rx: tokio::sync::watch::Receiver<()>, + skill_summaries: Vec<crate::skills::SkillSummary>, + skill_overflow: Option<String>, ) { use crate::stream::{StreamContent, StreamControl, StreamFrame, create_chat_stream}; use futures::StreamExt; @@ -790,6 +914,8 @@ async fn run_stream_bridge( app_ctx.send_cwd, app_ctx.last_command.clone(), user_contexts, + skill_summaries, + skill_overflow, ); futures::pin_mut!(stream); diff --git a/crates/atuin-ai/src/event_serde.rs b/crates/atuin-ai/src/event_serde.rs index 546d6e5b..e3f9d6f7 100644 --- a/crates/atuin-ai/src/event_serde.rs +++ b/crates/atuin-ai/src/event_serde.rs @@ -64,6 +64,19 @@ pub(crate) fn serialize_event(event: &ConversationEvent) -> (String, String) { "system_context".to_string(), serde_json::json!({ "content": content }).to_string(), ), + ConversationEvent::SkillInvocation { + name, + arguments, + content, + } => ( + "skill_invocation".to_string(), + serde_json::json!({ + "name": name, + "arguments": arguments, + "content": content, + }) + .to_string(), + ), } } @@ -112,6 +125,14 @@ pub(crate) fn deserialize_event(event_type: &str, event_data: &str) -> Result<Co "system_context" => Ok(ConversationEvent::SystemContext { content: json_string(&data, "content")?, }), + "skill_invocation" => Ok(ConversationEvent::SkillInvocation { + name: json_string(&data, "name")?, + arguments: data + .get("arguments") + .and_then(|v| if v.is_null() { None } else { v.as_str() }) + .map(String::from), + content: json_string(&data, "content")?, + }), other => Err(eyre!("unknown event type: {other}")), } } diff --git a/crates/atuin-ai/src/fsm/effects.rs b/crates/atuin-ai/src/fsm/effects.rs index 306f1401..adc9628e 100644 --- a/crates/atuin-ai/src/fsm/effects.rs +++ b/crates/atuin-ai/src/fsm/effects.rs @@ -45,6 +45,11 @@ pub(crate) enum Effect { }, /// Kill a running tool (send interrupt to shell command). AbortTool { tool_id: String }, + /// Load a skill's content asynchronously (read + interpolate). + LoadSkill { + name: String, + arguments: Option<String>, + }, // ─── Persistence ──────────────────────────────────────────── /// Persist current conversation state to disk. diff --git a/crates/atuin-ai/src/fsm/events.rs b/crates/atuin-ai/src/fsm/events.rs index 6fecda08..e591db41 100644 --- a/crates/atuin-ai/src/fsm/events.rs +++ b/crates/atuin-ai/src/fsm/events.rs @@ -92,6 +92,23 @@ pub(crate) enum Event { /// The driver resolves known commands (like /help) and passes the /// rendered content; the FSM just pushes an OOB event. SlashCommand { command: String, content: String }, + + // ─── Skills ──────────────────────────────────────────────── + /// User invoked a skill via /skill-name. FSM emits a LoadSkill + /// effect; the driver loads the content asynchronously and sends + /// SkillLoaded when ready. + RequestSkillLoad { + name: String, + arguments: Option<String>, + }, + /// A skill's content has been loaded and interpolated. + /// Pushes skill content as OOB context and starts a turn so the + /// LLM sees the skill and acts on it. + SkillLoaded { + name: String, + arguments: Option<String>, + content: String, + }, } /// Result of the permission resolver check. diff --git a/crates/atuin-ai/src/fsm/mod.rs b/crates/atuin-ai/src/fsm/mod.rs index 25de41f3..3d72a3ae 100644 --- a/crates/atuin-ai/src/fsm/mod.rs +++ b/crates/atuin-ai/src/fsm/mod.rs @@ -309,6 +309,33 @@ impl AgentFsm { vec![] } + ( + AgentState::Idle { .. }, + Event::SkillLoaded { + name, + arguments, + content, + }, + ) => { + self.ctx.events.push(ConversationEvent::SkillInvocation { + name, + arguments, + content, + }); + self.ctx.current_response.clear(); + self.ctx.current_turn_tool_ids.clear(); + + let messages = self.build_messages(); + let session_id = self.ctx.session_id.clone(); + self.state = AgentState::Turn { + stream: StreamPhase::Connecting, + }; + vec![Effect::StartStream { + messages, + session_id, + }] + } + // ================================================================ // Turn — stream lifecycle // ================================================================ @@ -584,6 +611,29 @@ impl AgentFsm { vec![] } + // RequestSkillLoad during non-idle: still emit the effect + (_, Event::RequestSkillLoad { name, arguments }) => { + vec![Effect::LoadSkill { name, arguments }] + } + + // SkillLoaded during non-idle: queue so it's visible + // in context for the next turn. + ( + _, + Event::SkillLoaded { + name, + arguments, + content, + }, + ) => { + self.ctx.events.push(ConversationEvent::SkillInvocation { + name, + arguments, + content, + }); + vec![] + } + _ => vec![], } } diff --git a/crates/atuin-ai/src/lib.rs b/crates/atuin-ai/src/lib.rs index 289f6ea2..b3587739 100644 --- a/crates/atuin-ai/src/lib.rs +++ b/crates/atuin-ai/src/lib.rs @@ -9,6 +9,7 @@ pub(crate) mod file_tracker; pub(crate) mod fsm; pub(crate) mod permissions; pub(crate) mod session; +pub(crate) mod skills; pub(crate) mod snapshots; pub(crate) mod store; pub(crate) mod stream; diff --git a/crates/atuin-ai/src/skills/frontmatter.rs b/crates/atuin-ai/src/skills/frontmatter.rs new file mode 100644 index 00000000..759dffcc --- /dev/null +++ b/crates/atuin-ai/src/skills/frontmatter.rs @@ -0,0 +1,233 @@ +//! YAML frontmatter parsing for `SKILL.md` files. +//! +//! Extracts the YAML block between `---` delimiters and parses it with +//! `yaml-rust2`. Returns the parsed fields and the byte offset where the +//! body begins (after the closing `---`). + +use yaml_rust2::YamlLoader; + +/// Parsed frontmatter fields from a `SKILL.md` file. +#[derive(Debug, Default)] +pub(crate) struct Frontmatter { + pub name: Option<String>, + pub description: Option<String>, + pub disable_model_invocation: bool, +} + +/// Result of splitting a skill file into frontmatter + body. +#[derive(Debug)] +pub(crate) struct ParsedSkillFile { + pub frontmatter: Frontmatter, + /// Everything after the closing `---` delimiter. + pub body: String, +} + +/// Parse a `SKILL.md` file's content into frontmatter and body. +/// +/// If no frontmatter delimiters are found, all content is treated as body +/// with default frontmatter. +pub(crate) fn parse(content: &str) -> ParsedSkillFile { + let Some((yaml_str, body)) = split_frontmatter(content) else { + return ParsedSkillFile { + frontmatter: Frontmatter::default(), + body: content.to_string(), + }; + }; + + let frontmatter = match YamlLoader::load_from_str(yaml_str) { + Ok(docs) if !docs.is_empty() => extract_fields(&docs[0]), + Ok(_) => Frontmatter::default(), + Err(e) => { + tracing::warn!("Failed to parse skill frontmatter: {e}"); + Frontmatter::default() + } + }; + + ParsedSkillFile { frontmatter, body } +} + +/// Split content on `---` delimiters. Returns `(yaml_str, body)` or `None` +/// if frontmatter is not present. +fn split_frontmatter(content: &str) -> Option<(&str, String)> { + let trimmed = content.trim_start(); + + // Must start with `---` + if !trimmed.starts_with("---") { + return None; + } + + // Find the end of the opening delimiter line + let after_open = trimmed.get(3..)?.trim_start_matches(|c: char| c != '\n'); + let after_open = after_open.strip_prefix('\n').unwrap_or(after_open); + + // Find the closing `---` + let close_pos = after_open + .lines() + .enumerate() + .find(|(_, line)| line.trim() == "---") + .map(|(i, _)| { + after_open + .lines() + .take(i) + .map(|l| l.len() + 1) // +1 for newline + .sum::<usize>() + })?; + + let yaml_str = &after_open[..close_pos]; + let rest = &after_open[close_pos..]; + // Skip the closing `---` line + let body = rest + .strip_prefix("---") + .unwrap_or(rest) + .trim_start_matches(|c: char| c != '\n'); + let body = body.strip_prefix('\n').unwrap_or(body); + + Some((yaml_str, body.to_string())) +} + +fn extract_fields(doc: &yaml_rust2::Yaml) -> Frontmatter { + use yaml_rust2::Yaml; + + let name = match &doc["name"] { + Yaml::String(s) => Some(s.clone()), + _ => None, + }; + + let description = match &doc["description"] { + Yaml::String(s) => Some(s.trim().to_string()), + _ => None, + }; + + let disable_model_invocation = match &doc["disable-model-invocation"] { + Yaml::Boolean(b) => *b, + Yaml::String(s) => matches!(s.as_str(), "true" | "yes" | "1"), + _ => false, + }; + + Frontmatter { + name, + description, + disable_model_invocation, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn basic_frontmatter() { + let content = "\ +--- +name: my-skill +description: A test skill +disable-model-invocation: true +--- + +Body content here. +"; + let parsed = parse(content); + assert_eq!(parsed.frontmatter.name.as_deref(), Some("my-skill")); + assert_eq!( + parsed.frontmatter.description.as_deref(), + Some("A test skill") + ); + assert!(parsed.frontmatter.disable_model_invocation); + assert_eq!(parsed.body.trim(), "Body content here."); + } + + #[test] + fn multiline_folded_description() { + let content = "\ +--- +name: release +description: > + Orchestrate a multi-step release — version bumping, changelog + generation, PR creation, tagging, and publishing. +disable-model-invocation: true +--- + +# Release steps +"; + let parsed = parse(content); + assert_eq!(parsed.frontmatter.name.as_deref(), Some("release")); + let desc = parsed.frontmatter.description.unwrap(); + assert!(desc.contains("Orchestrate a multi-step release")); + assert!(desc.contains("publishing")); + assert!(parsed.frontmatter.disable_model_invocation); + assert!(parsed.body.contains("# Release steps")); + } + + #[test] + fn no_frontmatter() { + let content = "Just a body with no frontmatter."; + let parsed = parse(content); + assert!(parsed.frontmatter.name.is_none()); + assert!(parsed.frontmatter.description.is_none()); + assert!(!parsed.frontmatter.disable_model_invocation); + assert_eq!(parsed.body, content); + } + + #[test] + fn empty_frontmatter() { + let content = "\ +--- +--- + +Body after empty frontmatter. +"; + let parsed = parse(content); + assert!(parsed.frontmatter.name.is_none()); + assert!(parsed.frontmatter.description.is_none()); + assert_eq!(parsed.body.trim(), "Body after empty frontmatter."); + } + + #[test] + fn missing_fields_use_defaults() { + let content = "\ +--- +name: partial +--- + +Some body. +"; + let parsed = parse(content); + assert_eq!(parsed.frontmatter.name.as_deref(), Some("partial")); + assert!(parsed.frontmatter.description.is_none()); + assert!(!parsed.frontmatter.disable_model_invocation); + } + + #[test] + fn unknown_fields_ignored() { + let content = "\ +--- +name: my-skill +future-field: some value +another: 42 +--- + +Body. +"; + let parsed = parse(content); + assert_eq!(parsed.frontmatter.name.as_deref(), Some("my-skill")); + } + + #[test] + fn body_with_triple_dashes() { + let content = "\ +--- +name: test +--- + +Some body. + +--- + +More body after a horizontal rule. +"; + let parsed = parse(content); + assert_eq!(parsed.frontmatter.name.as_deref(), Some("test")); + assert!(parsed.body.contains("Some body.")); + assert!(parsed.body.contains("More body after a horizontal rule.")); + } +} diff --git a/crates/atuin-ai/src/skills/mod.rs b/crates/atuin-ai/src/skills/mod.rs new file mode 100644 index 00000000..36b3a2ae --- /dev/null +++ b/crates/atuin-ai/src/skills/mod.rs @@ -0,0 +1,468 @@ +//! AI skill discovery, metadata, and lazy loading. +//! +//! Skills are markdown files (`SKILL.md`) with YAML frontmatter that define +//! reusable instructions for the LLM. Only skill metadata (name + description) +//! is sent to the server; full content is loaded on demand via `load_skill`. + +mod frontmatter; +pub(crate) mod walker; + +use std::path::Path; + +use eyre::{Result, eyre}; + +use crate::user_context::interpolate; + +/// Per-skill description truncation limit (before budget calculation). +const MAX_DESCRIPTION_LEN: usize = 1024; + +/// Default total character budget for skill descriptions sent to the server. +const DEFAULT_DESCRIPTION_BUDGET: usize = 9992; + +/// JSON overhead per skill entry: `{"name":"","description":""},` ≈ 30 chars. +const PER_ENTRY_OVERHEAD: usize = 30; + +/// Metadata for a discovered skill. Produced at discovery time from +/// frontmatter only — the body is not read until `load()`. +#[derive(Debug, Clone)] +pub(crate) struct SkillDescriptor { + pub name: String, + pub description: String, + pub source_path: std::path::PathBuf, + pub disable_model_invocation: bool, +} + +/// A name + description pair ready to serialize into the request payload. +#[derive(Debug, Clone, serde::Serialize)] +pub(crate) struct SkillSummary { + pub name: String, + pub description: String, +} + +/// Holds discovered skills and provides lookup, budget packing, and loading. +#[derive(Debug, Clone)] +pub(crate) struct SkillRegistry { + skills: Vec<SkillDescriptor>, +} + +impl SkillRegistry { + /// Discover skills from project and global directories. + pub async fn discover(project_root: Option<&Path>) -> Self { + let global_dir = walker::global_skills_dir(); + let project_dir = project_root.map(walker::project_skills_dir); + + Self::discover_from_dirs(project_dir.as_deref(), &global_dir).await + } + + /// Discover skills from explicit directory paths. Useful for testing. + pub async fn discover_from_dirs( + project_skills_dir: Option<&Path>, + global_skills_dir: &Path, + ) -> Self { + let raw_files = walker::discover(project_skills_dir, global_skills_dir).await; + + let mut skills = Vec::new(); + let mut seen_names = std::collections::HashSet::new(); + + for raw in raw_files { + let parsed = frontmatter::parse(&raw.content); + let fm = parsed.frontmatter; + + let name = fm.name.unwrap_or_else(|| sanitize_name(&raw.dir_name)); + + // Deduplicate: first seen wins (project before global) + if !seen_names.insert(name.clone()) { + continue; + } + + let description = fm + .description + .or_else(|| first_paragraph(&parsed.body)) + .unwrap_or_default(); + + skills.push(SkillDescriptor { + name, + description, + source_path: raw.path, + disable_model_invocation: fm.disable_model_invocation, + }); + } + + Self { skills } + } + + /// Create an empty registry. + #[cfg(test)] + pub fn empty() -> Self { + Self { skills: Vec::new() } + } + + /// Look up a skill by name. + pub fn get(&self, name: &str) -> Option<&SkillDescriptor> { + self.skills.iter().find(|s| s.name == name) + } + + /// All discovered skills. + pub fn all(&self) -> &[SkillDescriptor] { + &self.skills + } + + /// Whether any non-disabled skills exist (determines capability advertisement). + #[cfg(test)] + pub fn has_server_visible_skills(&self) -> bool { + self.skills.iter().any(|s| !s.disable_model_invocation) + } + + /// Pack skill descriptions into the server payload under a character budget. + /// + /// Returns the summaries that fit plus an optional overflow message. + pub fn server_skills(&self) -> (Vec<SkillSummary>, Option<String>) { + self.server_skills_with_budget(DEFAULT_DESCRIPTION_BUDGET) + } + + pub fn server_skills_with_budget(&self, budget: usize) -> (Vec<SkillSummary>, Option<String>) { + let eligible: Vec<&SkillDescriptor> = self + .skills + .iter() + .filter(|s| !s.disable_model_invocation) + .collect(); + + let mut summaries = Vec::new(); + let mut used = 0; + let mut overflow_names = Vec::new(); + + for skill in &eligible { + let truncated_desc = truncate_description(&skill.description, MAX_DESCRIPTION_LEN); + let entry_size = skill.name.len() + truncated_desc.len() + PER_ENTRY_OVERHEAD; + + if used + entry_size > budget && !summaries.is_empty() { + overflow_names.push(skill.name.as_str()); + continue; + } + + used += entry_size; + summaries.push(SkillSummary { + name: skill.name.clone(), + description: truncated_desc, + }); + } + + let overflow = if overflow_names.is_empty() { + None + } else { + Some(format!( + "{} additional skill(s) not listed due to size limits: {}", + overflow_names.len(), + overflow_names.join(", ") + )) + }; + + (summaries, overflow) + } + + /// Load a skill's full body content, with argument substitution and + /// `!`` interpolation applied. + /// + /// `$ARGUMENTS` in the body is replaced with the provided arguments before + /// shell interpolation runs. If `$ARGUMENTS` does not appear in the body + /// and arguments were provided, they are appended as `ARGUMENTS: <value>`. + pub async fn load(&self, name: &str, shell: &str, arguments: Option<&str>) -> Result<String> { + let skill = self + .get(name) + .ok_or_else(|| eyre!("Unknown skill: {name}"))?; + + let content = tokio::fs::read_to_string(&skill.source_path).await?; + let parsed = frontmatter::parse(&content); + let body = parsed.body; + + if body.trim().is_empty() { + return Ok(format!("(Skill '{name}' has no body content)")); + } + + let body = substitute_arguments(&body, arguments); + + Ok(interpolate::interpolate(&body, shell).await) + } +} + +/// Replace `$ARGUMENTS` placeholders in skill body text. +/// +/// If `$ARGUMENTS` appears in the body, all occurrences are replaced with the +/// argument string (or empty string if none). If `$ARGUMENTS` does not appear +/// and arguments were provided, they are appended on a new line. +fn substitute_arguments(body: &str, arguments: Option<&str>) -> String { + let args = arguments.unwrap_or(""); + + if body.contains("$ARGUMENTS") { + return body.replace("$ARGUMENTS", args); + } + + if !args.is_empty() { + return format!("{body}\n\nARGUMENTS: {args}"); + } + + body.to_string() +} + +/// Sanitize a directory name into a valid skill name. +fn sanitize_name(name: &str) -> String { + name.chars() + .map(|c| { + if c.is_ascii_alphanumeric() || c == '-' { + c + } else { + '-' + } + }) + .collect::<String>() + .to_lowercase() +} + +/// Extract the first non-empty paragraph from markdown body text. +fn first_paragraph(body: &str) -> Option<String> { + let trimmed = body.trim(); + if trimmed.is_empty() { + return None; + } + + let para: String = trimmed + .lines() + .take_while(|line| !line.trim().is_empty()) + .collect::<Vec<_>>() + .join(" "); + + let para = para.trim().to_string(); + if para.is_empty() { None } else { Some(para) } +} + +/// Truncate a description to `max_len` characters, adding ellipsis if cut. +fn truncate_description(desc: &str, max_len: usize) -> String { + if desc.len() <= max_len { + return desc.to_string(); + } + let mut end = max_len.saturating_sub(3); + // Avoid splitting a multi-byte char + while !desc.is_char_boundary(end) && end > 0 { + end -= 1; + } + format!("{}...", &desc[..end]) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sanitize_name_basic() { + assert_eq!(sanitize_name("My Skill"), "my-skill"); + assert_eq!(sanitize_name("deploy_prod"), "deploy-prod"); + assert_eq!(sanitize_name("code-review"), "code-review"); + } + + #[test] + fn first_paragraph_extraction() { + assert_eq!( + first_paragraph("Hello world\nSecond line\n\nNew paragraph"), + Some("Hello world Second line".to_string()) + ); + assert_eq!(first_paragraph(""), None); + assert_eq!(first_paragraph("\n\n"), None); + assert_eq!( + first_paragraph("Single line"), + Some("Single line".to_string()) + ); + } + + #[test] + fn truncate_description_short() { + assert_eq!(truncate_description("short", 100), "short"); + } + + #[test] + fn substitute_arguments_replaces_placeholder() { + let body = "Deploy $ARGUMENTS to production."; + assert_eq!( + substitute_arguments(body, Some("patch")), + "Deploy patch to production." + ); + } + + #[test] + fn substitute_arguments_multiple_occurrences() { + let body = "Run $ARGUMENTS then verify $ARGUMENTS worked."; + assert_eq!( + substitute_arguments(body, Some("migrate")), + "Run migrate then verify migrate worked." + ); + } + + #[test] + fn substitute_arguments_appends_when_no_placeholder() { + let body = "Do the thing."; + let result = substitute_arguments(body, Some("extra context")); + assert!(result.starts_with("Do the thing.")); + assert!(result.contains("ARGUMENTS: extra context")); + } + + #[test] + fn substitute_arguments_no_args_no_placeholder() { + let body = "Just a body."; + assert_eq!(substitute_arguments(body, None), "Just a body."); + } + + #[test] + fn substitute_arguments_no_args_clears_placeholder() { + let body = "Deploy $ARGUMENTS to production."; + assert_eq!(substitute_arguments(body, None), "Deploy to production."); + } + + #[test] + fn truncate_description_long() { + let long = "a".repeat(600); + let result = truncate_description(&long, 512); + assert!(result.len() <= 512); + assert!(result.ends_with("...")); + } + + #[test] + fn budget_packing() { + let registry = SkillRegistry { + skills: vec![ + SkillDescriptor { + name: "a".to_string(), + description: "Short desc".to_string(), + source_path: "a/SKILL.md".into(), + disable_model_invocation: false, + }, + SkillDescriptor { + name: "b".to_string(), + description: "Another desc".to_string(), + source_path: "b/SKILL.md".into(), + disable_model_invocation: false, + }, + ], + }; + + let (summaries, overflow) = registry.server_skills_with_budget(4096); + assert_eq!(summaries.len(), 2); + assert!(overflow.is_none()); + } + + #[test] + fn budget_overflow() { + let registry = SkillRegistry { + skills: vec![ + SkillDescriptor { + name: "first".to_string(), + description: "x".repeat(200), + source_path: "a/SKILL.md".into(), + disable_model_invocation: false, + }, + SkillDescriptor { + name: "second".to_string(), + description: "y".repeat(200), + source_path: "b/SKILL.md".into(), + disable_model_invocation: false, + }, + ], + }; + + // Budget only fits one + let (summaries, overflow) = registry.server_skills_with_budget(260); + assert_eq!(summaries.len(), 1); + assert_eq!(summaries[0].name, "first"); + let overflow = overflow.unwrap(); + assert!(overflow.contains("second")); + assert!(overflow.contains("1 additional")); + } + + #[test] + fn disabled_skills_excluded_from_server() { + let registry = SkillRegistry { + skills: vec![ + SkillDescriptor { + name: "visible".to_string(), + description: "I show up".to_string(), + source_path: "a/SKILL.md".into(), + disable_model_invocation: false, + }, + SkillDescriptor { + name: "hidden".to_string(), + description: "I don't".to_string(), + source_path: "b/SKILL.md".into(), + disable_model_invocation: true, + }, + ], + }; + + let (summaries, _) = registry.server_skills(); + assert_eq!(summaries.len(), 1); + assert_eq!(summaries[0].name, "visible"); + + // But all() includes both + assert_eq!(registry.all().len(), 2); + } + + #[test] + fn has_server_visible_skills() { + let empty = SkillRegistry::empty(); + assert!(!empty.has_server_visible_skills()); + + let all_disabled = SkillRegistry { + skills: vec![SkillDescriptor { + name: "hidden".to_string(), + description: String::new(), + source_path: "a/SKILL.md".into(), + disable_model_invocation: true, + }], + }; + assert!(!all_disabled.has_server_visible_skills()); + + let some_visible = SkillRegistry { + skills: vec![SkillDescriptor { + name: "visible".to_string(), + description: String::new(), + source_path: "a/SKILL.md".into(), + disable_model_invocation: false, + }], + }; + assert!(some_visible.has_server_visible_skills()); + } + + #[tokio::test] + async fn end_to_end_discover() { + let dir = tempfile::tempdir().unwrap(); + let skills_dir = dir.path().join("skills"); + + // Create a skill with frontmatter + let skill_dir = skills_dir.join("my-skill"); + std::fs::create_dir_all(&skill_dir).unwrap(); + std::fs::write( + skill_dir.join("SKILL.md"), + "---\nname: my-skill\ndescription: A test skill\n---\n\nBody here.\n", + ) + .unwrap(); + + // Create a skill with multiline description + let skill_dir2 = skills_dir.join("release"); + std::fs::create_dir_all(&skill_dir2).unwrap(); + std::fs::write( + skill_dir2.join("SKILL.md"), + "---\nname: release\ndescription: >\n Multi-line\n description here.\n---\n\nRelease steps.\n", + ) + .unwrap(); + + let registry = SkillRegistry::discover_from_dirs( + Some(&skills_dir), + &std::path::PathBuf::from("/nonexistent"), + ) + .await; + assert_eq!(registry.all().len(), 2); + + let my_skill = registry.get("my-skill").unwrap(); + assert_eq!(my_skill.description, "A test skill"); + + let release = registry.get("release").unwrap(); + assert!(release.description.contains("Multi-line")); + } +} diff --git a/crates/atuin-ai/src/skills/walker.rs b/crates/atuin-ai/src/skills/walker.rs new file mode 100644 index 00000000..b93845f9 --- /dev/null +++ b/crates/atuin-ai/src/skills/walker.rs @@ -0,0 +1,178 @@ +//! Filesystem discovery for `SKILL.md` files. +//! +//! Recursively scans `.atuin/skills/` directories at the project and global +//! levels. Supports nested directories for organization (e.g. +//! `.atuin/skills/ops/deploy/SKILL.md`). + +use std::path::{Path, PathBuf}; + +const SKILL_FILENAME: &str = "SKILL.md"; + +/// A skill file found on disk, before body interpolation. +#[derive(Debug)] +pub(crate) struct RawSkillFile { + /// Full path to the SKILL.md file. + pub path: PathBuf, + /// The parent directory name, used as fallback skill name. + pub dir_name: String, + /// Whether this is a project-level skill (vs global). + #[allow(dead_code)] + pub is_project: bool, + /// Raw file content. + pub content: String, +} + +/// Discover all `SKILL.md` files across project and global skill directories. +/// +/// Project skills come first in the returned list (higher priority for +/// deduplication). +pub(crate) async fn discover( + project_skills_dir: Option<&Path>, + global_skills_dir: &Path, +) -> Vec<RawSkillFile> { + let mut files = Vec::new(); + + // Project skills first (higher priority) + if let Some(dir) = project_skills_dir.filter(|d| d.is_dir()) { + scan_dir(dir, true, &mut files).await; + } + + // Global skills second + if global_skills_dir.is_dir() { + scan_dir(global_skills_dir, false, &mut files).await; + } + + files +} + +/// The default global skills directory (`~/.config/atuin/skills/`). +pub(crate) fn global_skills_dir() -> PathBuf { + atuin_common::utils::config_dir().join("skills") +} + +/// Given a project working directory, return the project skills directory. +pub(crate) fn project_skills_dir(project_root: &Path) -> PathBuf { + project_root.join(".atuin").join("skills") +} + +/// Recursively scan a directory for `SKILL.md` files. +async fn scan_dir(dir: &Path, is_project: bool, out: &mut Vec<RawSkillFile>) { + let mut entries = match tokio::fs::read_dir(dir).await { + Ok(entries) => entries, + Err(e) => { + tracing::debug!("Could not read skills directory {}: {e}", dir.display()); + return; + } + }; + + let mut subdirs = Vec::new(); + + while let Ok(Some(entry)) = entries.next_entry().await { + let path = entry.path(); + + if path.is_dir() { + // Check for SKILL.md directly in this directory + let skill_path = path.join(SKILL_FILENAME); + if skill_path.is_file() { + let dir_name = path + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or("unknown") + .to_string(); + + match tokio::fs::read_to_string(&skill_path).await { + Ok(content) => { + out.push(RawSkillFile { + path: skill_path, + dir_name, + is_project, + content, + }); + } + Err(e) => { + tracing::warn!("Failed to read skill file {}: {e}", skill_path.display()); + } + } + } + + // Collect subdirectories for recursive scanning + subdirs.push(path); + } + } + + for subdir in subdirs { + Box::pin(scan_dir(&subdir, is_project, out)).await; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn setup_skill(dir: &Path, rel_path: &str, content: &str) { + let skill_dir = dir.join(rel_path); + std::fs::create_dir_all(&skill_dir).unwrap(); + std::fs::write(skill_dir.join(SKILL_FILENAME), content).unwrap(); + } + + #[tokio::test] + async fn discovers_project_skills() { + let dir = tempfile::tempdir().unwrap(); + let skills_dir = dir.path().join("skills"); + setup_skill(&skills_dir, "deploy", "---\nname: deploy\n---\nDeploy."); + + let files = discover(Some(&skills_dir), Path::new("/nonexistent")).await; + assert_eq!(files.len(), 1); + assert_eq!(files[0].dir_name, "deploy"); + assert!(files[0].is_project); + } + + #[tokio::test] + async fn discovers_global_skills() { + let dir = tempfile::tempdir().unwrap(); + let skills_dir = dir.path().join("skills"); + setup_skill(&skills_dir, "review", "---\nname: review\n---\nReview."); + + let files = discover(None, &skills_dir).await; + assert_eq!(files.len(), 1); + assert_eq!(files[0].dir_name, "review"); + assert!(!files[0].is_project); + } + + #[tokio::test] + async fn discovers_nested_skills() { + let dir = tempfile::tempdir().unwrap(); + let skills_dir = dir.path().join("skills"); + setup_skill(&skills_dir, "ops/deploy", "---\nname: deploy\n---\n"); + setup_skill(&skills_dir, "ops/rollback", "---\nname: rollback\n---\n"); + + let files = discover(Some(&skills_dir), Path::new("/nonexistent")).await; + assert_eq!(files.len(), 2); + } + + #[tokio::test] + async fn project_comes_before_global() { + let project = tempfile::tempdir().unwrap(); + let global = tempfile::tempdir().unwrap(); + let project_skills = project.path().join("skills"); + let global_skills = global.path().join("skills"); + + setup_skill(&project_skills, "a-skill", "project"); + setup_skill(&global_skills, "b-skill", "global"); + + let files = discover(Some(&project_skills), &global_skills).await; + assert_eq!(files.len(), 2); + assert!(files[0].is_project); + assert!(!files[1].is_project); + } + + #[tokio::test] + async fn missing_directories_handled() { + let files = discover( + Some(Path::new("/does/not/exist")), + Path::new("/also/missing"), + ) + .await; + assert!(files.is_empty()); + } +} diff --git a/crates/atuin-ai/src/stream.rs b/crates/atuin-ai/src/stream.rs index e7155a08..d921b39c 100644 --- a/crates/atuin-ai/src/stream.rs +++ b/crates/atuin-ai/src/stream.rs @@ -63,7 +63,10 @@ impl ChatRequest { capabilities: &AiCapabilities, invocation_id: String, ) -> Self { - let mut caps = vec!["client_invocations".to_string()]; + let mut caps = vec![ + "client_invocations".to_string(), + "client_v1_load_skill".to_string(), + ]; if capabilities.enable_history_search.unwrap_or(true) { caps.push("client_v1_atuin_history".to_string()); } @@ -93,6 +96,7 @@ impl ChatRequest { } } +#[allow(clippy::too_many_arguments)] pub(crate) fn create_chat_stream( hub_address: String, token: String, @@ -101,6 +105,8 @@ pub(crate) fn create_chat_stream( send_cwd: bool, last_command: Option<String>, user_contexts: Vec<crate::user_context::UserContext>, + skill_summaries: Vec<crate::skills::SkillSummary>, + skill_overflow: Option<String>, ) -> std::pin::Pin<Box<dyn futures::Stream<Item = Result<StreamFrame>> + Send>> { Box::pin(async_stream::stream! { ensure_crypto_provider(); @@ -124,6 +130,13 @@ pub(crate) fn create_chat_stream( config["user_contexts"] = serde_json::json!(user_contexts); } + if !skill_summaries.is_empty() { + config["skills"] = serde_json::json!(skill_summaries); + if let Some(ref overflow) = skill_overflow { + config["skills_overflow"] = serde_json::json!(overflow); + } + } + let mut request_body = serde_json::json!({ "messages": request.messages, "context": context, diff --git a/crates/atuin-ai/src/tools/descriptor.rs b/crates/atuin-ai/src/tools/descriptor.rs index 6ccb595f..06858bf8 100644 --- a/crates/atuin-ai/src/tools/descriptor.rs +++ b/crates/atuin-ai/src/tools/descriptor.rs @@ -67,6 +67,15 @@ pub(crate) const ATUIN_HISTORY: &ToolDescriptor = &ToolDescriptor { is_client: true, }; +pub(crate) const LOAD_SKILL: &ToolDescriptor = &ToolDescriptor { + canonical_names: &["load_skill"], + capability: Some("client_v1_load_skill"), + display_verb: "load skill", + progressive_verb: "Loading skill...", + past_verb: "Loaded skill", + is_client: true, +}; + // ── Server-side tool descriptors ── // These appear in tool summaries but aren't client-side tools. @@ -95,6 +104,7 @@ const ALL_DESCRIPTORS: &[&ToolDescriptor] = &[ WRITE, SHELL, ATUIN_HISTORY, + LOAD_SKILL, SERVER_SEARCH, SERVER_SCRAPE, ]; diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs index e66d64b8..fdda10a4 100644 --- a/crates/atuin-ai/src/tools/mod.rs +++ b/crates/atuin-ai/src/tools/mod.rs @@ -158,6 +158,7 @@ pub(crate) enum ClientToolCall { Write(WriteToolCall), Shell(ShellToolCall), AtuinHistory(AtuinHistoryToolCall), + LoadSkill(LoadSkillToolCall), } impl TryFrom<(&str, &serde_json::Value)> for ClientToolCall { @@ -172,6 +173,9 @@ impl TryFrom<(&str, &serde_json::Value)> for ClientToolCall { "atuin_history" => Ok(ClientToolCall::AtuinHistory( AtuinHistoryToolCall::try_from(input)?, )), + "load_skill" => Ok(ClientToolCall::LoadSkill(LoadSkillToolCall::try_from( + input, + )?)), _ => Err(eyre::eyre!("Unknown tool call: {name}")), } } @@ -185,6 +189,7 @@ impl ClientToolCall { ClientToolCall::Write(_) => descriptor::WRITE, ClientToolCall::Shell(_) => descriptor::SHELL, ClientToolCall::AtuinHistory(_) => descriptor::ATUIN_HISTORY, + ClientToolCall::LoadSkill(_) => descriptor::LOAD_SKILL, } } @@ -200,6 +205,7 @@ impl ClientToolCall { ClientToolCall::Write(_) => "Write", ClientToolCall::Shell(_) => "Shell", ClientToolCall::AtuinHistory(_) => "AtuinHistory", + ClientToolCall::LoadSkill(_) => "LoadSkill", } } @@ -210,7 +216,9 @@ impl ClientToolCall { ClientToolCall::Read(tool) => Some(tool.resolved_path()), ClientToolCall::Edit(tool) => Some(tool.resolved_path()), ClientToolCall::Write(tool) => Some(tool.resolved_path()), - _ => None, + ClientToolCall::Shell(_) + | ClientToolCall::AtuinHistory(_) + | ClientToolCall::LoadSkill(_) => None, } } @@ -221,6 +229,7 @@ impl ClientToolCall { ClientToolCall::Write(tool) => tool.matches_rule(rule), ClientToolCall::Shell(tool) => tool.matches_rule(rule), ClientToolCall::AtuinHistory(tool) => tool.matches_rule(rule), + ClientToolCall::LoadSkill(tool) => tool.matches_rule(rule), } } @@ -231,6 +240,7 @@ impl ClientToolCall { ClientToolCall::Write(tool) => tool.target_dir(), ClientToolCall::Shell(tool) => tool.target_dir(), ClientToolCall::AtuinHistory(tool) => tool.target_dir(), + ClientToolCall::LoadSkill(tool) => tool.target_dir(), } } @@ -239,6 +249,10 @@ impl ClientToolCall { match self { ClientToolCall::Read(tool) => tool.execute(), ClientToolCall::AtuinHistory(tool) => tool.execute(db).await, + // LoadSkill is handled separately by the driver (needs registry access) + ClientToolCall::LoadSkill(_) => { + ToolOutcome::Error("LoadSkill must be executed via the driver".to_string()) + } _ => ToolOutcome::Error("Client-side tool execution not yet implemented".to_string()), } } @@ -271,6 +285,7 @@ impl PermissableToolCall for ClientToolCall { fn all_covered_by(&self, rules: &[Rule]) -> bool { match self { ClientToolCall::Shell(tool) => tool.all_covered_by(rules), + // LoadSkill is always auto-approved, but support rules for completeness _ => rules.iter().any(|r| self.matches_rule(r)), } } @@ -280,6 +295,13 @@ impl PermissableToolCall for ClientToolCall { } } +/// Returns true if this tool call should bypass the permission system entirely. +impl ClientToolCall { + pub(crate) fn is_auto_approved(&self) -> bool { + matches!(self, ClientToolCall::LoadSkill(_)) + } +} + /// Expand shell constructs (`~`, `$HOME`, etc.) in a path string. /// /// Tool call paths arrive as raw strings from the API without shell @@ -1197,6 +1219,36 @@ impl AtuinHistoryToolCall { } } +#[derive(Debug, Clone)] +pub(crate) struct LoadSkillToolCall { + pub name: String, +} + +impl TryFrom<&serde_json::Value> for LoadSkillToolCall { + type Error = eyre::Error; + + fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> { + let name = value + .get("name") + .and_then(|v| v.as_str()) + .ok_or(eyre::eyre!("Missing skill name"))?; + + Ok(LoadSkillToolCall { + name: name.to_string(), + }) + } +} + +impl PermissableToolCall for LoadSkillToolCall { + fn target_dir(&self) -> Option<&Path> { + None + } + + fn matches_rule(&self, rule: &Rule) -> bool { + rule.tool == "LoadSkill" + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/atuin-ai/src/tui/state.rs b/crates/atuin-ai/src/tui/state.rs index e008bd3c..71da6ff5 100644 --- a/crates/atuin-ai/src/tui/state.rs +++ b/crates/atuin-ai/src/tui/state.rs @@ -37,6 +37,14 @@ pub(crate) enum ConversationEvent { /// Context injected for the LLM that is not rendered in the TUI. /// Converted to a user message in the API protocol. SystemContext { content: String }, + /// A skill was loaded and its content injected into the conversation. + /// Serialized as a full user message for the API but rendered compactly + /// in the TUI (just the `/name args` invocation line). + SkillInvocation { + name: String, + arguments: Option<String>, + content: String, + }, } impl ConversationEvent { @@ -49,6 +57,7 @@ impl ConversationEvent { ConversationEvent::ToolResult { .. } => true, ConversationEvent::OutOfBandOutput { .. } => false, ConversationEvent::SystemContext { .. } => false, + ConversationEvent::SkillInvocation { .. } => true, } } @@ -206,6 +215,21 @@ pub(crate) fn events_to_messages(events: &[ConversationEvent]) -> Vec<serde_json })); i += 1; } + ConversationEvent::SkillInvocation { + name, + arguments, + content, + } => { + let header = match arguments { + Some(args) => format!("[Loaded skill: {name}]\n[Arguments: {args}]"), + None => format!("[Loaded skill: {name}]"), + }; + messages.push(serde_json::json!({ + "role": "user", + "content": format!("{header}\n\n{content}") + })); + i += 1; + } } } diff --git a/crates/atuin-ai/src/tui/view/mod.rs b/crates/atuin-ai/src/tui/view/mod.rs index 2061ec38..96ad5d85 100644 --- a/crates/atuin-ai/src/tui/view/mod.rs +++ b/crates/atuin-ai/src/tui/view/mod.rs @@ -73,7 +73,7 @@ pub(crate) fn ai_view(state: &ViewState) -> Elements { user_turn_view(events, index == 0) } turn::UiTurn::Agent { events } => { - agent_turn_view(events, busy && index == last_index) + agent_turn_view(events, busy && index == last_index, state.tools.awaiting_permission().is_some()) } turn::UiTurn::OutOfBand { events } => { out_of_band_turn_view(events) @@ -85,7 +85,7 @@ pub(crate) fn ai_view(state: &ViewState) -> Elements { let needs_pending_banner = busy && !matches!(turns.last(), Some(turn::UiTurn::Agent { .. })); if needs_pending_banner { let empty: &[turn::UiEvent] = &[]; - agent_turn_view(empty, true) + agent_turn_view(empty, true, false) } else { element! {} } @@ -170,6 +170,7 @@ fn tool_call_view(tool_call: &crate::fsm::tools::TrackedTool, in_git_project: bo ClientToolCall::Write(tool) => tool.path.display().to_string(), ClientToolCall::Shell(tool) => tool.command.clone(), ClientToolCall::AtuinHistory(tool) => tool.query.clone(), + ClientToolCall::LoadSkill(tool) => format!("skill: {}", tool.name), }; let select_options = permission_options_for_tool(&tool_call.tool, in_git_project); @@ -273,21 +274,16 @@ fn user_turn_view(events: &[turn::UiEvent], first_turn: bool) -> Elements { } } -fn agent_turn_view(events: &[turn::UiEvent], busy: bool) -> Elements { +fn agent_turn_view(events: &[turn::UiEvent], busy: bool, showing_ui: bool) -> Elements { let label_style = Style::default() .fg(Color::Yellow) .add_modifier(Modifier::BOLD); element! { View { - Spinner( - label: " Atuin AI ", - label_style: label_style.reversed(), - done_label_style: label_style.reversed(), - hide_checkmark: true, - label_first: true, - done: !busy, - ) + Text { + Span(text: " Atuin AI ", style: label_style.reversed()) + } #(for (i, event) in events.iter().enumerate() { #(if i > 0 { Text { Span(text: "") } @@ -325,7 +321,8 @@ fn agent_turn_view(events: &[turn::UiEvent], busy: bool) -> Elements { tool_status_view(&details.name, &details.status) }, turn::ToolRenderData::FileRead { .. } - | turn::ToolRenderData::HistorySearch { .. } => { + | turn::ToolRenderData::HistorySearch { .. } + | turn::ToolRenderData::SkillLoad { .. } => { element!{} }, }) @@ -350,6 +347,15 @@ fn agent_turn_view(events: &[turn::UiEvent], busy: bool) -> Elements { _ => element!{} }) }) + + #(if busy && !showing_ui { + View(key: "agent-working-spinner", padding_left: Cells::from(2), padding_top: Cells::from(1)) { + Spinner( + label: "", + spinner_style: Style::default().fg(Color::Yellow).add_modifier(Modifier::BOLD), + ) + } + }) } } } diff --git a/crates/atuin-ai/src/tui/view/turn.rs b/crates/atuin-ai/src/tui/view/turn.rs index 98ae5eff..9f4460eb 100644 --- a/crates/atuin-ai/src/tui/view/turn.rs +++ b/crates/atuin-ai/src/tui/view/turn.rs @@ -151,6 +151,8 @@ pub(crate) enum ToolRenderData { query: String, filter_modes: Vec<HistorySearchFilterMode>, }, + /// Skill loading — read-only, auto-approved. + SkillLoad { _name: String }, /// Server-side tool — no client rendering data available. Remote, } @@ -257,6 +259,15 @@ impl<'a> TurnBuilder<'a> { ConversationEvent::SystemContext { .. } => { // Not rendered in the TUI — only sent to the API } + ConversationEvent::SkillInvocation { + name, arguments, .. + } => { + let display = match arguments { + Some(args) => format!("/{name} {args}"), + None => format!("/{name}"), + }; + self.add_user_message(&display); + } } } @@ -459,6 +470,9 @@ impl<'a> TurnBuilder<'a> { query: history.query.clone(), filter_modes: history.filter_modes.clone(), }, + ClientToolCall::LoadSkill(skill) => ToolRenderData::SkillLoad { + _name: skill.name.clone(), + }, } } else { // Not in tracker → server-side tool diff --git a/crates/atuin-ai/src/user_context/mod.rs b/crates/atuin-ai/src/user_context/mod.rs index 295efdec..fdeb890b 100644 --- a/crates/atuin-ai/src/user_context/mod.rs +++ b/crates/atuin-ai/src/user_context/mod.rs @@ -5,7 +5,7 @@ //! by walking the filesystem, commands are executed, and the interpolated //! content is sent to the server as `config.user_contexts`. -mod interpolate; +pub(crate) mod interpolate; mod walker; use std::path::Path; |
