diff options
Diffstat (limited to 'crates/atuin-ai/src/permissions')
| -rw-r--r-- | crates/atuin-ai/src/permissions/check.rs | 74 | ||||
| -rw-r--r-- | crates/atuin-ai/src/permissions/file.rs | 26 | ||||
| -rw-r--r-- | crates/atuin-ai/src/permissions/mod.rs | 7 | ||||
| -rw-r--r-- | crates/atuin-ai/src/permissions/resolver.rs | 31 | ||||
| -rw-r--r-- | crates/atuin-ai/src/permissions/rule.rs | 106 | ||||
| -rw-r--r-- | crates/atuin-ai/src/permissions/shell.rs | 1297 | ||||
| -rw-r--r-- | crates/atuin-ai/src/permissions/walker.rs | 121 | ||||
| -rw-r--r-- | crates/atuin-ai/src/permissions/writer.rs | 198 |
8 files changed, 1860 insertions, 0 deletions
diff --git a/crates/atuin-ai/src/permissions/check.rs b/crates/atuin-ai/src/permissions/check.rs new file mode 100644 index 00000000..6b908b93 --- /dev/null +++ b/crates/atuin-ai/src/permissions/check.rs @@ -0,0 +1,74 @@ +use eyre::Result; + +use crate::{permissions::file::RuleFile, tools::PermissableToolCall}; + +pub(crate) struct PermissionRequest<'t> { + call: &'t (dyn PermissableToolCall + Send + Sync), +} + +impl<'t> PermissionRequest<'t> { + pub fn new(call: &'t (dyn PermissableToolCall + Send + Sync)) -> Self { + Self { call } + } +} + +pub(crate) enum PermissionResponse { + Allowed, + Denied, + Ask, +} + +pub(crate) struct PermissionChecker { + files: Vec<RuleFile>, +} + +impl PermissionChecker { + pub fn new(files: Vec<RuleFile>) -> Self { + Self { files } + } + + pub async fn check<'t>( + &self, + request: &'t PermissionRequest<'t>, + ) -> Result<PermissionResponse> { + // Files are in order from deepest to shallowest, so we can stop at the first match. + // Within a file, the priority is ask -> deny -> allow + // The first rule type that matches is the one that applies, even if a later rule would contradict it. 
+ for file in &self.files { + for rule in &file.content.permissions.ask { + if request.call.matches_rule(rule) { + tracing::debug!( + "Permission 'ASK' by rule: {} in file: {}", + rule, + file.path.display() + ); + return Ok(PermissionResponse::Ask); + } + } + + for rule in &file.content.permissions.deny { + if request.call.matches_rule(rule) { + tracing::debug!( + "Permission 'DENY' by rule: {} in file: {}", + rule, + file.path.display() + ); + return Ok(PermissionResponse::Denied); + } + } + + for rule in &file.content.permissions.allow { + if request.call.matches_rule(rule) { + tracing::debug!( + "Permission 'ALLOW' by rule: {} in file: {}", + rule, + file.path.display() + ); + return Ok(PermissionResponse::Allowed); + } + } + } + + Ok(PermissionResponse::Ask) + } +} diff --git a/crates/atuin-ai/src/permissions/file.rs b/crates/atuin-ai/src/permissions/file.rs new file mode 100644 index 00000000..c973f55b --- /dev/null +++ b/crates/atuin-ai/src/permissions/file.rs @@ -0,0 +1,26 @@ +use std::path::PathBuf; + +use serde::{Deserialize, Serialize}; + +use crate::permissions::rule::Rule; + +#[derive(Debug, Clone)] +pub(crate) struct RuleFile { + pub path: PathBuf, + pub content: RuleFileContent, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub(crate) struct RuleFileContent { + pub permissions: RuleFilePermissions, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub(crate) struct RuleFilePermissions { + #[serde(default)] + pub allow: Vec<Rule>, + #[serde(default)] + pub deny: Vec<Rule>, + #[serde(default)] + pub ask: Vec<Rule>, +} diff --git a/crates/atuin-ai/src/permissions/mod.rs b/crates/atuin-ai/src/permissions/mod.rs new file mode 100644 index 00000000..fce64a51 --- /dev/null +++ b/crates/atuin-ai/src/permissions/mod.rs @@ -0,0 +1,7 @@ +pub(crate) mod check; +pub(crate) mod file; +pub(crate) mod resolver; +pub(crate) mod rule; +pub(crate) mod shell; +pub(crate) mod walker; +pub(crate) mod writer; diff --git 
a/crates/atuin-ai/src/permissions/resolver.rs b/crates/atuin-ai/src/permissions/resolver.rs new file mode 100644 index 00000000..dc4f83bf --- /dev/null +++ b/crates/atuin-ai/src/permissions/resolver.rs @@ -0,0 +1,31 @@ +use std::path::PathBuf; + +use eyre::Result; + +use crate::permissions::check::{PermissionChecker, PermissionRequest, PermissionResponse}; +use crate::permissions::walker::PermissionWalker; +use crate::permissions::writer; +use crate::tools::ClientToolCall; + +/// Resolves permissions for client tool calls by walking the filesystem to find permission files and checking each call against the rules they contain. +pub(crate) struct PermissionResolver { + checker: PermissionChecker, +} + +impl PermissionResolver { + /// Create a new resolver that walks from `working_dir` to root for project + /// permissions, and also checks the global permissions file. + pub async fn new(working_dir: PathBuf) -> Result<Self> { + let global_file = writer::global_permissions_path(); + let mut walker = PermissionWalker::new(working_dir, Some(global_file)); + walker.walk().await?; + let checker = PermissionChecker::new(walker.rules().to_owned()); + Ok(Self { checker }) + } + + /// Check whether `tool` is allowed, denied, or needs user confirmation. 
+ pub async fn check(&self, tool: &ClientToolCall) -> Result<PermissionResponse> { + let request = PermissionRequest::new(tool); + self.checker.check(&request).await + } +} diff --git a/crates/atuin-ai/src/permissions/rule.rs b/crates/atuin-ai/src/permissions/rule.rs new file mode 100644 index 00000000..8fa3fa4a --- /dev/null +++ b/crates/atuin-ai/src/permissions/rule.rs @@ -0,0 +1,106 @@ +use std::sync::OnceLock; + +use regex::Regex; +use serde::{Deserialize, Serialize}; + +static RULE_RE: OnceLock<Regex> = OnceLock::new(); + +#[derive(Debug, thiserror::Error)] +pub(crate) enum RuleError { + #[error("invalid rule format: {0}")] + InvalidRule(String), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct Rule { + pub tool: String, + pub scope: Option<String>, +} + +impl std::fmt::Display for Rule { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.scope.as_ref() { + Some(scope) => write!(f, "{}({})", self.tool, scope), + None => write!(f, "{}", self.tool), + } + } +} + +impl Serialize for Rule { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + serializer.serialize_str(&self.to_string()) + } +} + +impl<'de> Deserialize<'de> for Rule { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + Self::try_from(s.as_str()).map_err(serde::de::Error::custom) + } +} +impl TryFrom<&str> for Rule { + type Error = RuleError; + + fn try_from(value: &str) -> Result<Self, Self::Error> { + let value = value.trim(); + let re = RULE_RE.get_or_init(|| Regex::new(r"^(\w+)(?:\((.*)\))?$").unwrap()); + let caps = re + .captures(value) + .ok_or(RuleError::InvalidRule(value.to_string()))?; + let tool = caps.get(1).unwrap().as_str().to_string(); + let scope = caps.get(2).map(|m| m.as_str().to_string()); + Ok(Rule { tool, scope }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn 
test_rule_try_from() { + assert_eq!( + Rule::try_from("Read").unwrap(), + Rule { + tool: "Read".to_string(), + scope: None + } + ); + assert_eq!( + Rule::try_from("Read(*)").unwrap(), + Rule { + tool: "Read".to_string(), + scope: Some("*".to_string()) + } + ); + assert_eq!( + Rule::try_from("Write(*.md)").unwrap(), + Rule { + tool: "Write".to_string(), + scope: Some("*.md".to_string()) + } + ); + assert_eq!( + Rule::try_from("Shell(git commit *)").unwrap(), + Rule { + tool: "Shell".to_string(), + scope: Some("git commit *".to_string()) + } + ); + assert_eq!( + Rule::try_from("Shell(echo ())").unwrap(), + Rule { + tool: "Shell".to_string(), + scope: Some("echo ()".to_string()) + } + ); + assert!(Rule::try_from("Shell(git commit *").is_err()); + assert!(Rule::try_from("Shell(git commit *)!").is_err()); + } +} diff --git a/crates/atuin-ai/src/permissions/shell.rs b/crates/atuin-ai/src/permissions/shell.rs new file mode 100644 index 00000000..7a2eee2e --- /dev/null +++ b/crates/atuin-ai/src/permissions/shell.rs @@ -0,0 +1,1297 @@ +/// Extracted command info from a shell command string. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct ShellCommand { + /// The command name (first word), e.g. "git" + pub name: String, + /// The full invocation including arguments, e.g. "git commit -m msg" + pub full: String, +} + +/// A parsed shell command with all subcommands extracted. +#[derive(Debug)] +pub(crate) struct ParsedShellCommand { + pub subcommands: Vec<ShellCommand>, +} + +/// Supported shell families for parsing. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum ShellKind { + /// POSIX sh, bash, zsh — all share similar syntax + Posix, + /// fish shell + Fish, + /// nushell or unknown — fallback to word-level extraction + Other, +} + +impl ShellKind { + pub(crate) fn from_shell_name(name: &str) -> Self { + match name { + "bash" | "sh" | "zsh" | "dash" | "ksh" => Self::Posix, + "fish" => Self::Fish, + _ => Self::Other, + } + } +} + +/// Parse a shell command string and extract all subcommands. +pub(crate) fn parse_shell_command(code: &str, shell: ShellKind) -> ParsedShellCommand { + #[cfg(feature = "tree-sitter")] + match shell { + ShellKind::Posix => ts::parse_posix(code), + ShellKind::Fish => ts::parse_fish(code), + ShellKind::Other => parse_fallback(code), + } + + #[cfg(not(feature = "tree-sitter"))] + { + let _ = shell; + parse_fallback(code) + } +} + +// ──────────────────────────────────────────────────────────────── +// Tree-sitter parsers (POSIX + Fish) +// Disabled on platforms where tree-sitter doesn't cross-compile +// (e.g. Windows); falls back to word-level extraction. +// ──────────────────────────────────────────────────────────────── + +#[cfg(feature = "tree-sitter")] +mod ts { + use super::{ParsedShellCommand, ShellCommand, parse_fallback}; + use tree_sitter_lib::{Parser, Tree}; + + fn bash_parser() -> Parser { + let mut parser = Parser::new(); + parser + .set_language(&tree_sitter_bash::LANGUAGE.into()) + .expect("failed to set bash language"); + parser + } + + pub(super) fn parse_posix(code: &str) -> ParsedShellCommand { + let mut parser = bash_parser(); + let Some(tree) = parser.parse(code, None) else { + return parse_fallback(code); + }; + + let mut commands = Vec::new(); + walk_bash_tree(&tree, code.as_bytes(), &mut commands); + ParsedShellCommand { + subcommands: commands, + } + } + + /// Leaf node kinds that never contain nested commands. 
+ const BASH_LEAVES: &[&str] = &[ + "command_name", + "word", + "number", + "simple_expansion", + "expansion", + "arithmetic_expansion", + "ansi_c_string", + "special_variable_name", + "variable_name", + "file_descriptor", + "heredoc_body", + "heredoc_start", + "regex", + "heredoc_redirect", + ]; + + fn walk_bash_tree(tree: &Tree, source: &[u8], commands: &mut Vec<ShellCommand>) { + walk_bash_node(tree.root_node(), source, commands); + } + + fn walk_bash_node( + node: tree_sitter_lib::Node, + source: &[u8], + commands: &mut Vec<ShellCommand>, + ) { + match node.kind() { + "command" => { + if let Some(cmd) = extract_bash_command(node, source) { + commands.push(cmd); + } + // Descend into all non-leaf children to find nested commands + // (e.g. command_substitution inside a string inside a command) + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if !BASH_LEAVES.contains(&child.kind()) { + walk_bash_node(child, source, commands); + } + } + } + // Other nodes: descend into all children + _ => { + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + walk_bash_node(child, source, commands); + } + } + } + } + + /// Extract the full command string and name from a bash `command` node. + fn extract_bash_command(node: tree_sitter_lib::Node, source: &[u8]) -> Option<ShellCommand> { + // A `command` node has children like: + // variable_assignment* command_name argument* redirect* + // We want the command_name and all arguments (skipping assignments and redirects). 
+ let mut name = None; + let mut name_start = None; + let mut arg_end = None; + + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + match child.kind() { + "command_name" => { + name = child.utf8_text(source).ok().map(|s| s.to_string()); + name_start = Some(child.start_byte()); + } + "word" + | "string" + | "raw_string" + | "concatenation" + | "number" + | "simple_expansion" + | "expansion" + | "arithmetic_expansion" + | "ansi_c_string" + | "process_substitution" => { + arg_end = Some(child.end_byte()); + } + _ => {} + } + } + + let name = name?; + let full = if let (Some(start), Some(end)) = (name_start, arg_end) { + std::str::from_utf8(&source[start..end]).ok()?.to_string() + } else { + name.clone() + }; + + Some(ShellCommand { name, full }) + } + + // ──────────────────────────────────────────────────────────────── + // Fish parser + // ──────────────────────────────────────────────────────────────── + + fn fish_parser() -> Parser { + let mut parser = Parser::new(); + parser + .set_language(&tree_sitter_fish::language()) + .expect("failed to set fish language"); + parser + } + + pub(super) fn parse_fish(code: &str) -> ParsedShellCommand { + let mut parser = fish_parser(); + let Some(tree) = parser.parse(code, None) else { + return parse_fallback(code); + }; + + let mut commands = Vec::new(); + walk_fish_tree(&tree, code.as_bytes(), &mut commands); + ParsedShellCommand { + subcommands: commands, + } + } + + const FISH_COMPOUND: &[&str] = &[ + "conditional_execution", + "pipe", + "job", + "command_substitution", + "block", + "for_statement", + "while_statement", + "if_statement", + "switch_statement", + "function_definition", + "begin_statement", + "redirected_statement", + ]; + + fn walk_fish_tree(tree: &Tree, source: &[u8], commands: &mut Vec<ShellCommand>) { + walk_fish_node(tree.root_node(), source, commands); + } + + fn walk_fish_node( + node: tree_sitter_lib::Node, + source: &[u8], + commands: &mut Vec<ShellCommand>, + ) { + match 
node.kind() { + "command" => { + if let Some(cmd) = extract_fish_command(node, source) { + commands.push(cmd); + } + // Still descend into compound children (e.g. command_substitution inside a command) + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if FISH_COMPOUND.contains(&child.kind()) { + walk_fish_node(child, source, commands); + } + } + } + // Other nodes: descend into all children + _ => { + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + walk_fish_node(child, source, commands); + } + } + } + } + + fn extract_fish_command(node: tree_sitter_lib::Node, source: &[u8]) -> Option<ShellCommand> { + // In fish, a `command` node has: + // name (command_name or word) followed by arguments (word, string, etc.) + let mut name = None; + + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + match child.kind() { + "command_name" | "word" => { + let text = child.utf8_text(source).ok()?.to_string(); + if name.is_none() { + name = Some(text); + } + } + "string" + | "concatenation" + | "command_substitution" + | "escape_sequence" + | "double_quote_string" + | "single_quote_string" => {} + _ => {} + } + } + + let name = name?; + // Get the full text of the command node + let full = node.utf8_text(source).ok()?.trim().to_string(); + + Some(ShellCommand { name, full }) + } +} // mod ts + +// ──────────────────────────────────────────────────────────────── +// Fallback (word-level extraction for nushell / unknown shells) +// ──────────────────────────────────────────────────────────────── + +fn parse_fallback(code: &str) -> ParsedShellCommand { + // Simple heuristic: split by &&, ||, ;, | and take the first word of each segment. + // This is intentionally simple — for unknown shells we can't do better. 
+ let mut commands = Vec::new(); + let mut segment = String::new(); + let mut chars = code.chars().peekable(); + + while let Some(c) = chars.next() { + match c { + ';' => { + push_segment(&mut segment, &mut commands); + } + '|' => { + if chars.peek() == Some(&'|') { + chars.next(); + } + push_segment(&mut segment, &mut commands); + } + '&' if chars.peek() == Some(&'&') => { + chars.next(); + push_segment(&mut segment, &mut commands); + } + _ => segment.push(c), + } + } + push_segment(&mut segment, &mut commands); + + ParsedShellCommand { + subcommands: commands, + } +} + +fn push_segment(segment: &mut String, commands: &mut Vec<ShellCommand>) { + let trimmed = segment.trim(); + if !trimmed.is_empty() + && let Some(name) = trimmed.split_whitespace().next() + { + commands.push(ShellCommand { + name: name.to_string(), + full: trimmed.to_string(), + }); + } + segment.clear(); +} + +// ──────────────────────────────────────────────────────────────── +// Scope matching +// ──────────────────────────────────────────────────────────────── + +/// Check if any of the extracted subcommands match the given scope pattern. 
+/// +/// Matching semantics depend on where the `*` wildcard appears: +/// - `*` alone — matches everything +/// - `ls *` (space before `*`) — matches `ls` and `ls -a` but not `lsof` +/// - `git commit *` — matches `git commit -m "msg"` (word boundary) +/// - `ls*` (no space before `*`) — matches `lsof`, `ls`, `ls -a` (prefix/glob) +/// - `rm` (no wildcard) — matches `rm` and `rm -rf` but not `rmdir` (word-boundary prefix) +/// - `git * amend` — matches `git commit amend` (middle wildcard matches zero+ words) +pub(crate) fn any_subcommand_matches(subcommands: &[ShellCommand], scope: &str) -> bool { + let scope = scope.trim(); + + if scope == "*" { + return true; + } + + if let Some(prefix) = scope.strip_suffix(" *") { + // Word-boundary matching: `ls *` matches `ls` and `ls -a` but not `lsof` + return subcommands.iter().any(|cmd| { + if prefix.is_empty() { + return true; + } + let cmd_words: Vec<&str> = cmd.full.split_whitespace().collect(); + let prefix_words: Vec<&str> = prefix.split_whitespace().collect(); + cmd_words.len() >= prefix_words.len() + && cmd_words[..prefix_words.len()] == prefix_words[..] + }); + } + + if let Some(prefix) = scope.strip_suffix('*') { + // Prefix/glob matching: `ls*` matches `lsof`, `ls`, etc. + return subcommands.iter().any(|cmd| cmd.full.starts_with(prefix)); + } + + if scope.contains('*') { + // Middle wildcard: `git * amend` — each `*` matches zero or more words + return subcommands + .iter() + .any(|cmd| scope_matches_words(scope, cmd.full.split_whitespace().collect())); + } + + // No wildcard: word-boundary prefix match + let scope_words: Vec<&str> = scope.split_whitespace().collect(); + subcommands.iter().any(|cmd| { + let cmd_words: Vec<&str> = cmd.full.split_whitespace().collect(); + cmd_words.len() >= scope_words.len() && cmd_words[..scope_words.len()] == scope_words[..] + }) +} + +/// Match a scope pattern containing `*` wildcards against a sequence of words. +/// Each `*` matches zero or more words. Consecutive `*` collapse into one. 
+fn scope_matches_words(scope: &str, words: Vec<&str>) -> bool { + let parts: Vec<&str> = scope.split('*').collect(); + if parts.len() == 1 { + // No wildcard (shouldn't reach here, but handle it) + let scope_words: Vec<&str> = scope.split_whitespace().collect(); + return words.len() >= scope_words.len() && words[..scope_words.len()] == scope_words[..]; + } + + // Each segment between * is a sequence of literal words that must appear in order. + // Walk through `words` consuming segments left to right. + let mut word_idx = 0; + + for (i, part) in parts.iter().enumerate() { + let segment_words: Vec<&str> = part.split_whitespace().collect(); + if segment_words.is_empty() { + continue; + } + + // Find the segment words starting from word_idx + if i == 0 { + // First segment must match at the start + if words.len() < segment_words.len() + || words[..segment_words.len()] != segment_words[..] + { + return false; + } + word_idx = segment_words.len(); + } else if i == parts.len() - 1 { + // Last segment must match at the end + if words.len() - word_idx < segment_words.len() { + return false; + } + let start = words.len() - segment_words.len(); + return words[start..] == segment_words[..]; + } else { + // Middle segment: find it anywhere after word_idx + let found = find_subslice(&words[word_idx..], &segment_words); + match found { + Some(idx) => word_idx += idx + segment_words.len(), + None => return false, + } + } + } + + true +} + +/// Find the first occurrence of `needle` as a contiguous subsequence in `haystack`. 
+fn find_subslice(haystack: &[&str], needle: &[&str]) -> Option<usize> { + if needle.is_empty() { + return Some(0); + } + if haystack.len() < needle.len() { + return None; + } + (0..=haystack.len() - needle.len()).find(|&i| haystack[i..i + needle.len()] == needle[..]) +} + +// ──────────────────────────────────────────────────────────────── +// Tests +// ──────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + fn names(cmds: &[ShellCommand]) -> Vec<&str> { + cmds.iter().map(|c| c.name.as_str()).collect() + } + + fn fulls(cmds: &[ShellCommand]) -> Vec<&str> { + cmds.iter().map(|c| c.full.as_str()).collect() + } + + #[test] + fn simple_command() { + let result = parse_shell_command("ls -la /tmp", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["ls"]); + assert_eq!(fulls(&result.subcommands), vec!["ls -la /tmp"]); + } + + #[test] + fn pipeline() { + let result = parse_shell_command("cat file.txt | grep foo | wc -l", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["cat", "grep", "wc"]); + } + + #[test] + fn command_chaining() { + let result = parse_shell_command("git add . 
&& git commit -m 'hi'", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["git", "git"]); + assert_eq!( + fulls(&result.subcommands), + vec!["git add .", "git commit -m 'hi'"] + ); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn command_substitution() { + let result = parse_shell_command("echo $(git rev-parse HEAD)", ShellKind::Posix); + let n = names(&result.subcommands); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + assert!(n.contains(&"git"), "should contain git: {n:?}"); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn backtick_substitution() { + let result = parse_shell_command("echo `date`", ShellKind::Posix); + let n = names(&result.subcommands); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + assert!(n.contains(&"date"), "should contain date: {n:?}"); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn subshell() { + let result = parse_shell_command("(cd /tmp && ls)", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["cd", "ls"]); + } + + #[test] + fn semicolon_separated() { + let result = parse_shell_command("echo hello; echo world", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["echo", "echo"]); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn for_loop() { + let result = parse_shell_command("for f in *.txt; do cat $f; done", ShellKind::Posix); + assert!(names(&result.subcommands).contains(&"cat")); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn if_statement() { + let result = parse_shell_command( + "if [ -f foo ]; then cat foo; else echo nope; fi", + ShellKind::Posix, + ); + let n = names(&result.subcommands); + assert!(n.contains(&"cat"), "should contain cat: {n:?}"); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + } + + #[test] + fn scope_matching_wildcard() { + let commands = vec![ + ShellCommand { + name: "git".into(), + full: "git commit -m msg".into(), + }, + ShellCommand { + name: "npm".into(), + full: "npm test".into(), 
+ }, + ]; + assert!(any_subcommand_matches(&commands, "*")); + } + + #[test] + fn scope_matching_prefix() { + let commands = vec![ + ShellCommand { + name: "git".into(), + full: "git commit -m msg".into(), + }, + ShellCommand { + name: "npm".into(), + full: "npm test".into(), + }, + ]; + assert!(any_subcommand_matches(&commands, "git commit *")); + assert!(any_subcommand_matches(&commands, "git commit")); + assert!(!any_subcommand_matches(&commands, "git push *")); + assert!(!any_subcommand_matches(&commands, "git push")); + assert!(any_subcommand_matches(&commands, "npm *")); + } + + #[test] + fn scope_word_boundary_vs_glob() { + let commands = vec![ + ShellCommand { + name: "ls".into(), + full: "ls -a".into(), + }, + ShellCommand { + name: "lsof".into(), + full: "lsof -i :3000".into(), + }, + ]; + // `ls *` — word boundary: matches `ls -a` but not `lsof` + assert!(any_subcommand_matches(&commands, "ls *")); + assert!(!any_subcommand_matches(&commands, "cat *")); + assert!(any_subcommand_matches(&commands, "lsof *")); + + // `ls*` — glob/prefix: matches both `ls -a` and `lsof` + assert!(any_subcommand_matches(&commands, "ls*")); + } + + #[test] + fn scope_exact_match() { + let commands = vec![ShellCommand { + name: "ls".into(), + full: "ls".into(), + }]; + assert!(any_subcommand_matches(&commands, "ls")); + assert!(!any_subcommand_matches(&commands, "cat")); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn nested_substitution() { + let result = parse_shell_command( + "echo \"Result: $(git log --oneline | head -1)\"", + ShellKind::Posix, + ); + let n = names(&result.subcommands); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + assert!(n.contains(&"git"), "should contain git: {n:?}"); + assert!(n.contains(&"head"), "should contain head: {n:?}"); + } + + #[test] + fn fallback_splits_correctly() { + let result = parse_shell_command("ls && cat foo || echo fail", ShellKind::Other); + let n = names(&result.subcommands); + assert!(n.contains(&"ls"), 
"should contain ls: {n:?}"); + assert!(n.contains(&"cat"), "should contain cat: {n:?}"); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + } + + #[test] + fn fish_simple_command() { + let result = parse_shell_command("ls -la /tmp", ShellKind::Fish); + assert_eq!(names(&result.subcommands), vec!["ls"]); + } + + #[test] + fn fish_conditional() { + let result = parse_shell_command("git add .; and git commit -m hi", ShellKind::Fish); + let n = names(&result.subcommands); + assert!(n.contains(&"git"), "should contain git: {n:?}"); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn fish_command_substitution() { + let result = parse_shell_command("echo (date)", ShellKind::Fish); + let n = names(&result.subcommands); + assert!(n.contains(&"echo"), "should contain echo: {n:?}"); + assert!(n.contains(&"date"), "should contain date: {n:?}"); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn variable_assignment_excluded() { + let result = parse_shell_command("FOO=bar ls -la /tmp", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["ls"]); + assert_eq!(fulls(&result.subcommands), vec!["ls -la /tmp"]); + } + + #[cfg(feature = "tree-sitter")] + #[test] + fn variable_assignment_multiple() { + let result = parse_shell_command("A=1 B=2 git status", ShellKind::Posix); + assert_eq!(names(&result.subcommands), vec!["git"]); + assert_eq!(fulls(&result.subcommands), vec!["git status"]); + } + + #[test] + fn fallback_double_ampersand_and_pipe_pipe() { + let result = parse_shell_command("ls && cat foo || echo fail", ShellKind::Other); + assert_eq!(names(&result.subcommands), vec!["ls", "cat", "echo"]); + assert_eq!( + fulls(&result.subcommands), + vec!["ls", "cat foo", "echo fail"] + ); + } + + #[test] + fn fallback_pipe_without_double() { + let result = parse_shell_command("ls | grep foo", ShellKind::Other); + assert_eq!(names(&result.subcommands), vec!["ls", "grep"]); + assert_eq!(fulls(&result.subcommands), vec!["ls", "grep foo"]); + } + + #[test] + 
fn scope_middle_wildcard() { + let commands = vec![ShellCommand { + name: "git".into(), + full: "git commit -m amend".into(), + }]; + assert!(any_subcommand_matches(&commands, "git * amend")); + assert!(any_subcommand_matches(&commands, "git commit * amend")); + assert!(!any_subcommand_matches(&commands, "git push * amend")); + } + + #[test] + fn scope_middle_wildcard_zero_words() { + let commands = vec![ShellCommand { + name: "git".into(), + full: "git commit".into(), + }]; + // `*` matches zero words, so `git * commit` should match `git commit` + assert!(any_subcommand_matches(&commands, "git * commit")); + } + + #[test] + fn scope_leading_wildcard() { + let commands = vec![ShellCommand { + name: "docker".into(), + full: "docker run --rm alpine".into(), + }]; + assert!(any_subcommand_matches(&commands, "* alpine")); + assert!(!any_subcommand_matches(&commands, "* ubuntu")); + } + + #[test] + fn scope_multiple_wildcards() { + let commands = vec![ShellCommand { + name: "git".into(), + full: "git rebase -i HEAD~5".into(), + }]; + assert!(any_subcommand_matches(&commands, "git * -i * HEAD~5")); + assert!(!any_subcommand_matches(&commands, "git * -i * HEAD~10")); + } +} + +#[cfg(all(test, feature = "tree-sitter"))] +mod adversarial { + use super::*; + + fn cmd_names(cmds: &[ShellCommand]) -> Vec<&str> { + cmds.iter().map(|c| c.name.as_str()).collect() + } + + /// Helper: assert that parsing POSIX extracts all expected command names + fn assert_posix(code: &str, expected: &[&str]) { + let result = parse_shell_command(code, ShellKind::Posix); + let mut got: Vec<&str> = result.subcommands.iter().map(|c| c.name.as_str()).collect(); + got.sort(); + let mut want: Vec<&str> = expected.to_vec(); + want.sort(); + assert_eq!( + got, want, + "POSIX parse of {:?}:\n got: {:?}\n want: {:?}", + code, got, want + ); + } + + fn assert_fish(code: &str, expected: &[&str]) { + let result = parse_shell_command(code, ShellKind::Fish); + let mut got: Vec<&str> = 
result.subcommands.iter().map(|c| c.name.as_str()).collect(); + got.sort(); + let mut want: Vec<&str> = expected.to_vec(); + want.sort(); + assert_eq!( + got, want, + "Fish parse of {:?}:\n got: {:?}\n want: {:?}", + code, got, want + ); + } + + // ──────────────────────────────────────────────────────────── + // Level 1: Basic compounds + // ──────────────────────────────────────────────────────────── + + #[test] + fn a01_triple_chain() { + assert_posix("a && b && c", &["a", "b", "c"]); + } + + #[test] + fn a02_or_chain() { + assert_posix("a || b || c", &["a", "b", "c"]); + } + + #[test] + fn a03_mixed_chain() { + assert_posix("a && b || c && d", &["a", "b", "c", "d"]); + } + + #[test] + fn a04_long_pipeline() { + assert_posix( + "cat foo | grep bar | awk '{print $1}' | sort | uniq -c", + &["cat", "grep", "awk", "sort", "uniq"], + ); + } + + #[test] + fn a05_semicolons() { + assert_posix("a; b; c; d", &["a", "b", "c", "d"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 2: Nested substitution + // ──────────────────────────────────────────────────────────── + + #[test] + fn a06_nested_dollar() { + assert_posix( + "echo $(basename $(dirname /foo/bar))", + &["echo", "basename", "dirname"], + ); + } + + #[test] + fn a07_deeply_nested() { + // 4 nested echos, all should be extracted + assert_posix( + "echo $(echo $(echo $(echo deep)))", + &["echo", "echo", "echo", "echo"], + ); + } + + #[test] + fn a08_backtick_in_echo() { + assert_posix("echo `hostname`", &["echo", "hostname"]); + } + + #[test] + fn a09_mixed_substitutions() { + assert_posix("echo $(date) `uname`", &["echo", "date", "uname"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 3: Subshells and grouping + // ──────────────────────────────────────────────────────────── + + #[test] + fn a10_subshell_chain() { + assert_posix("(cd /tmp && ls -la)", &["cd", "ls"]); + } + + #[test] + fn a11_nested_subshells() { + assert_posix("( 
(inner_cmd) )", &["inner_cmd"]); + } + + #[test] + fn a12_brace_group() { + assert_posix("{ cd /tmp; ls; }", &["cd", "ls"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 4: Variable assignments + // ──────────────────────────────────────────────────────────── + + #[test] + fn a13_single_var_assignment() { + let result = parse_shell_command("FOO=bar ls", ShellKind::Posix); + assert_eq!(cmd_names(&result.subcommands), &["ls"]); + assert_eq!(result.subcommands[0].full, "ls"); + } + + #[test] + fn a14_multiple_var_assignments() { + let result = parse_shell_command("A=1 B=2 C=3 git status", ShellKind::Posix); + assert_eq!(cmd_names(&result.subcommands), &["git"]); + assert_eq!(result.subcommands[0].full, "git status"); + } + + #[test] + fn a15_var_assignment_no_command() { + // Variable assignment only — no command to extract + assert_posix("FOO=bar", &[]); + } + + #[test] + fn a16_var_assignment_in_pipeline() { + assert_posix("FOO=bar ls | BAZ=qux grep foo", &["ls", "grep"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 5: Control flow + // ──────────────────────────────────────────────────────────── + + #[test] + fn a17_if_then_else() { + assert_posix( + "if [ -f foo ]; then cat foo; else echo missing; fi", + &["cat", "echo"], + ); + } + + #[test] + fn a18_elif_chain() { + // Two cat commands (then + elif branch), one echo (else branch). + // [ is part of the test_condition, not extracted as a command. 
+ assert_posix( + "if [ -f a ]; then cat a; elif [ -f b ]; then cat b; else echo none; fi", + &["cat", "cat", "echo"], + ); + } + + #[test] + fn a19_for_loop() { + assert_posix("for f in *.txt; do cat \"$f\"; done", &["cat"]); + } + + #[test] + fn a20_while_loop() { + // read in the condition is a real command + assert_posix( + "while read line; do echo \"$line\"; done < input.txt", + &["echo", "read"], + ); + } + + #[test] + fn f07_if_statement() { + // test in if-condition is a real command + assert_fish( + "if test -f foo; cat foo; else; echo missing; end", + &["cat", "echo", "test"], + ); + } + + #[test] + fn f09_while_loop() { + // `true` in the condition is a real command + assert_fish( + "while true; echo tick; sleep 1; end", + &["echo", "sleep", "true"], + ); + } + + // ──────────────────────────────────────────────────────────── + // Level 6: Redirections + // ──────────────────────────────────────────────────────────── + + #[test] + fn a23_redirect_out() { + assert_posix("ls > output.txt", &["ls"]); + } + + #[test] + fn a24_redirect_append() { + assert_posix("ls >> output.txt 2>&1", &["ls"]); + } + + #[test] + fn a25_here_string() { + assert_posix("grep foo <<< \"hello world\"", &["grep"]); + } + + #[test] + fn a26_redirect_in_pipeline() { + assert_posix("cat < input.txt | sort | uniq", &["cat", "sort", "uniq"]); + } + + #[test] + fn a27_process_substitution() { + assert_posix( + "diff <(sort a.txt) <(sort b.txt)", + &["diff", "sort", "sort"], + ); + } + + // ──────────────────────────────────────────────────────────── + // Level 7: Function definitions + // ──────────────────────────────────────────────────────────── + + #[test] + fn a28_function_def() { + assert_posix("foo() { echo hello; }", &["echo"]); + } + + #[test] + fn a29_function_with_subshell() { + assert_posix( + "build() { cargo build && cargo test; }", + &["cargo", "cargo"], + ); + } + + // ──────────────────────────────────────────────────────────── + // Level 8: Edge cases — empties, weird 
quoting + // ──────────────────────────────────────────────────────────── + + #[test] + fn a30_empty_string() { + let result = parse_shell_command("", ShellKind::Posix); + assert!(result.subcommands.is_empty()); + } + + #[test] + fn a31_whitespace_only() { + let result = parse_shell_command(" \t \n ", ShellKind::Posix); + assert!(result.subcommands.is_empty()); + } + + #[test] + fn a32_single_command_no_args() { + assert_posix("ls", &["ls"]); + } + + #[test] + fn a33_command_with_single_quotes() { + assert_posix("echo 'hello world'", &["echo"]); + } + + #[test] + fn a34_command_with_double_quotes() { + assert_posix("echo \"hello world\"", &["echo"]); + } + + #[test] + fn a35_escaped_spaces() { + // ls\ -la is a single word in bash, not "ls" with flag "-la" + assert_posix("ls\\ -la", &["ls\\ -la"]); + } + + #[test] + fn a36_command_with_dollar_var() { + assert_posix("echo $HOME/.bashrc", &["echo"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 9: Background jobs and coproc + // ──────────────────────────────────────────────────────────── + + #[test] + fn a37_background_job() { + assert_posix("sleep 10 &", &["sleep"]); + } + + #[test] + fn a38_background_chain() { + assert_posix("sleep 10 && echo done &", &["sleep", "echo"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 10: Real-world complex commands + // ──────────────────────────────────────────────────────────── + + #[test] + fn a39_docker_build_and_run() { + assert_posix( + "docker build -t app . && docker run --rm app npm test", + &["docker", "docker"], + ); + } + + #[test] + fn a40_git_rebase_interactive() { + assert_posix( + "GIT_SEQUENCE_EDITOR=\"sed -i 's/pick/reword/'\" git rebase -i HEAD~5", + &["git"], + ); + } + + #[test] + fn a41_find_with_exec() { + // tree-sitter-bash does not parse -exec body as commands — only `find` is extracted. + // This is a known limitation: args to -exec/-execdir are opaque to the parser. 
+ assert_posix("find . -name '*.rs' -exec grep -l 'unsafe' {} +", &["find"]); + } + + #[test] + fn a42_curl_pipe_sh() { + assert_posix( + "curl -sSL https://example.com/install.sh | bash", + &["curl", "bash"], + ); + } + + #[test] + fn a43_xargs() { + assert_posix("find . -name '*.tmp' | xargs rm -f", &["find", "xargs"]); + } + + #[test] + fn a44_npm_script_chain() { + assert_posix( + "npm run build && npm run test && npm run lint", + &["npm", "npm", "npm"], + ); + } + + #[test] + fn a45_make_with_redirect() { + assert_posix( + "make -j$(nproc) 2>&1 | tee build.log", + &["make", "nproc", "tee"], + ); + } + + #[test] + fn a46_sudo_chain() { + assert_posix("sudo apt update && sudo apt upgrade -y", &["sudo", "sudo"]); + } + + #[test] + fn a47_here_doc_with_subcommand() { + assert_posix("cat <<EOF\nhello $(whoami)\nEOF", &["cat", "whoami"]); + } + + #[test] + fn a48_eval_with_command() { + assert_posix("eval \"echo hello\"", &["eval"]); + } + + #[test] + fn a49_exec_replace() { + assert_posix("exec ls", &["exec"]); + } + + #[test] + fn a50_source_script() { + assert_posix("source ~/.bashrc", &["source"]); + } + + // ──────────────────────────────────────────────────────────── + // Level 11: Fish-specific tests + // ──────────────────────────────────────────────────────────── + + #[test] + fn f01_simple() { + assert_fish("ls -la /tmp", &["ls"]); + } + + #[test] + fn f02_pipe() { + assert_fish("cat foo | grep bar | sort", &["cat", "grep", "sort"]); + } + + #[test] + fn f03_and() { + assert_fish("git add .; and git commit -m hi", &["git", "git"]); + } + + #[test] + fn f04_or() { + assert_fish("test -f foo; or echo missing", &["test", "echo"]); + } + + #[test] + fn f04_not() { + // fish parses `not test -f foo` — `not` is a modifier, `test` is the command + assert_fish("not test -f foo", &["test"]); + } + + #[test] + fn f05_command_substitution() { + assert_fish("echo (date)", &["echo", "date"]); + } + + #[test] + fn f06_nested_substitution() { + assert_fish( + "echo 
(basename (dirname /foo/bar))", + &["echo", "basename", "dirname"], + ); + } + + #[test] + fn f06_begin_end() { + assert_fish("begin; ls; echo done; end", &["ls", "echo"]); + } + + #[test] + fn f10_switch() { + // Two echo commands, one per case branch + assert_fish( + "switch $x; case foo; echo foo; case bar; echo bar; end", + &["echo", "echo"], + ); + } + + #[test] + fn f08_for_loop() { + assert_fish("for f in *.txt; cat $f; end", &["cat"]); + } + + #[test] + fn a21_case_statement() { + // Two echo branches + assert_posix( + "case $x in foo) echo foo;; bar) echo bar;; esac", + &["echo", "echo"], + ); + } + + #[test] + fn f11_function_def() { + assert_fish("function greet; echo hello $argv; end", &["echo"]); + } + + #[test] + fn f12_redirect() { + assert_fish("ls > output.txt", &["ls"]); + } + + #[test] + fn f13_redirect_append() { + assert_fish("ls >> output.txt", &["ls"]); + } + + #[test] + fn f14_here_string() { + assert_fish("grep foo <<< \"hello\"", &["grep"]); + } + + #[test] + fn f15_curl_pipe() { + assert_fish( + "curl -sSL https://example.com/install.sh | bash", + &["curl", "bash"], + ); + } + + #[test] + fn f16_double_ampersand() { + assert_fish("git add . 
&& git commit -m hi", &["git", "git"]); + } + + #[test] + fn f17_double_pipe() { + assert_fish("test -f foo || echo missing", &["test", "echo"]); + } + + #[test] + fn f18_empty() { + let result = parse_shell_command("", ShellKind::Fish); + assert!(result.subcommands.is_empty()); + } + + #[test] + fn f19_whitespace() { + let result = parse_shell_command(" ", ShellKind::Fish); + assert!(result.subcommands.is_empty()); + } + + // ──────────────────────────────────────────────────────────── + // Level 12: Scope matching adversarial + // ──────────────────────────────────────────────────────────── + + #[test] + fn s01_empty_scope() { + let commands = vec![ShellCommand { + name: "ls".into(), + full: "ls".into(), + }]; + // Empty scope matches everything (nothing to constrain) + assert!(any_subcommand_matches(&commands, "")); + } + + #[test] + fn s03_only_wildcard_space_star() { + let commands = vec![ShellCommand { + name: "ls".into(), + full: "ls".into(), + }]; + // " *" with empty prefix = match anything + assert!(any_subcommand_matches(&commands, " *")); + } + + #[test] + fn s04_glob_matches_empty() { + let commands = vec![ShellCommand { + name: "ls".into(), + full: "ls".into(), + }]; + // `ls*` matches `ls` (prefix match with nothing after) + assert!(any_subcommand_matches(&commands, "ls*")); + } + + #[test] + fn s05_middle_wildcard_empty_match() { + // `git * commit` matches `git commit` (* = zero words) + let commands = vec![ShellCommand { + name: "git".into(), + full: "git commit".into(), + }]; + assert!(any_subcommand_matches(&commands, "git * commit")); + } + + #[test] + fn s06_consecutive_wildcards() { + // `git ** commit` should behave like `git * commit` + let commands = vec![ShellCommand { + name: "git".into(), + full: "git commit".into(), + }]; + assert!(any_subcommand_matches(&commands, "git ** commit")); + } + + #[test] + fn s07_case_sensitivity() { + let commands = vec![ShellCommand { + name: "LS".into(), + full: "LS -la".into(), + }]; + 
assert!(!any_subcommand_matches(&commands, "ls")); + assert!(any_subcommand_matches(&commands, "LS")); + } + + #[test] + fn s08_multi_word_exact_no_subcommand() { + // `git commit` should not match `git commit-amend` + let commands = vec![ShellCommand { + name: "git".into(), + full: "git commit-amend".into(), + }]; + assert!(!any_subcommand_matches(&commands, "git commit")); + } +} diff --git a/crates/atuin-ai/src/permissions/walker.rs b/crates/atuin-ai/src/permissions/walker.rs new file mode 100644 index 00000000..3bda01c3 --- /dev/null +++ b/crates/atuin-ai/src/permissions/walker.rs @@ -0,0 +1,121 @@ +use std::path::{Path, PathBuf}; + +use eyre::Result; +use tokio::task::JoinSet; + +use crate::permissions::file::{RuleFile, RuleFileContent}; + +#[derive(Debug)] +struct FoundRuleFile { + depth: usize, + file: RuleFile, +} + +pub(crate) struct PermissionWalker { + start: PathBuf, + /// Direct path to the global permissions file (e.g. `~/.config/atuin/permissions.ai.toml`). + global_permissions_file: Option<PathBuf>, + rules: Vec<RuleFile>, +} + +impl PermissionWalker { + pub fn new(start: PathBuf, global_permissions_file: Option<PathBuf>) -> Self { + Self { + start, + global_permissions_file, + rules: Vec::new(), + } + } + + pub fn rules(&self) -> &[RuleFile] { + &self.rules + } + + /// Walks the filesystem starting from the start path and collecting permission files along the way. + /// Walks to the root, then checks the global permissions file, if any. 
+ pub async fn walk(&mut self) -> Result<()> { + let dirs_to_check: Vec<PathBuf> = self.start.ancestors().map(PathBuf::from).collect(); + let dir_count = dirs_to_check.len(); + + let mut set: JoinSet<Result<Option<FoundRuleFile>>> = JoinSet::new(); + + for (index, path) in dirs_to_check.into_iter().enumerate() { + set.spawn(async move { + match check_dir_for_permissions(&path).await { + Ok(Some(rule_file)) => Ok(Some(FoundRuleFile { + depth: index, + file: rule_file, + })), + Ok(None) => Ok(None), + Err(e) => Err(e), + } + }); + } + + // Check the global file separately (it's a direct file path, not a dir/.atuin/ pattern) + if let Some(global_path) = self.global_permissions_file.clone() { + let depth = dir_count; // sorts after all directory-walk entries + set.spawn(async move { + match load_permissions_file(&global_path).await { + Ok(Some(rule_file)) => Ok(Some(FoundRuleFile { + depth, + file: rule_file, + })), + Ok(None) => Ok(None), + Err(e) => Err(e), + } + }); + } + + let capacity = dir_count + usize::from(self.global_permissions_file.is_some()); + let mut found = Vec::with_capacity(capacity); + while let Some(result) = set.join_next().await { + let result = result?; // JoinErrors result in failure to walk the filesystem + + match result { + Ok(Some(FoundRuleFile { depth, file })) => { + found.push((depth, file)); + } + Ok(None) => { + continue; + } + Err(e) => { + tracing::error!( + "Error while walking filesystem for permissions check; skipping: {}", + e + ); + continue; + } + } + } + // join_next() returns in order of completion, not order of spawn + found.sort_by_key(|(depth, _)| *depth); + self.rules = found.into_iter().map(|(_, file)| file).collect(); + + Ok(()) + } +} + +/// Checks a directory for `.atuin/permissions.ai.toml` and returns the RuleFile if found. 
+async fn check_dir_for_permissions(path: &Path) -> Result<Option<RuleFile>> { + let file_path = path.join(".atuin").join("permissions.ai.toml"); + load_permissions_file(&file_path).await +} + +/// Load a permissions file from an exact path. Returns None if the file doesn't exist. +async fn load_permissions_file(file_path: &Path) -> Result<Option<RuleFile>> { + if !tokio::fs::try_exists(file_path).await? { + return Ok(None); + } + + let raw = tokio::fs::read_to_string(file_path).await?; + let content: RuleFileContent = toml::from_str(&raw)?; + + // Use the file's parent as the rule file path (for logging/debugging) + let path = file_path + .parent() + .map(Path::to_path_buf) + .unwrap_or_else(|| file_path.to_path_buf()); + + Ok(Some(RuleFile { path, content })) +} diff --git a/crates/atuin-ai/src/permissions/writer.rs b/crates/atuin-ai/src/permissions/writer.rs new file mode 100644 index 00000000..b2bd9482 --- /dev/null +++ b/crates/atuin-ai/src/permissions/writer.rs @@ -0,0 +1,198 @@ +use std::path::Path; + +use eyre::Result; + +use crate::permissions::rule::Rule; + +/// Whether a rule should be added to the allow or deny list. +#[allow(dead_code)] +pub(crate) enum RuleDisposition { + Allow, + Deny, +} + +/// Write a permission rule to a `permissions.ai.toml` file. +/// +/// If the file doesn't exist it is created (along with parent directories). +/// If it does exist, `toml_edit` is used to append the rule while preserving +/// existing formatting and comments. +/// +/// **Not concurrent-safe.** The read-modify-write cycle is not atomic. In the +/// current UI this is fine — the Select widget serializes permission decisions — +/// but callers should not invoke this concurrently for the same file. +pub(crate) async fn write_rule( + file_path: &Path, + rule: &Rule, + disposition: RuleDisposition, +) -> Result<()> { + let content = if tokio::fs::try_exists(file_path).await.unwrap_or(false) { + tokio::fs::read_to_string(file_path).await? 
+ } else { + String::new() + }; + + let mut doc: toml_edit::DocumentMut = content.parse()?; + + // Ensure [permissions] table exists + if !doc.contains_key("permissions") { + doc["permissions"] = toml_edit::Item::Table(toml_edit::Table::new()); + } + + let key = match disposition { + RuleDisposition::Allow => "allow", + RuleDisposition::Deny => "deny", + }; + + // Use as_table_like_mut so both standard and inline tables work. + let permissions = doc["permissions"] + .as_table_like_mut() + .ok_or_else(|| eyre::eyre!("[permissions] is not a table"))?; + + // Get or create the array + if !permissions.contains_key(key) { + permissions.insert(key, toml_edit::Item::Value(toml_edit::Array::new().into())); + } + + let array = permissions + .get_mut(key) + .and_then(|item| item.as_value_mut()) + .and_then(|v| v.as_array_mut()) + .ok_or_else(|| eyre::eyre!("permissions.{key} is not an array"))?; + + // Don't add duplicates + let rule_str = rule.to_string(); + let already_present = array.iter().any(|v| v.as_str() == Some(&rule_str)); + if !already_present { + array.push(rule_str); + } + + // Write back, creating parent directories as needed + if let Some(parent) = file_path.parent() { + tokio::fs::create_dir_all(parent).await?; + } + tokio::fs::write(file_path, doc.to_string()).await?; + + Ok(()) +} + +/// Build the path to the project-level permissions file. +/// `project_root` is typically a git root or the current working directory. +pub(crate) fn project_permissions_path(project_root: &Path) -> std::path::PathBuf { + project_root.join(".atuin").join("permissions.ai.toml") +} + +/// Build the path to the global permissions file (sibling of atuin config). 
+pub(crate) fn global_permissions_path() -> std::path::PathBuf { + atuin_common::utils::config_dir().join("permissions.ai.toml") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn creates_new_file_with_allow_rule() { + let dir = tempfile::tempdir().unwrap(); + let file = dir.path().join("permissions.ai.toml"); + let rule = Rule { + tool: "AtuinHistory".to_string(), + scope: None, + }; + + write_rule(&file, &rule, RuleDisposition::Allow) + .await + .unwrap(); + + let content = tokio::fs::read_to_string(&file).await.unwrap(); + assert!(content.contains("[permissions]")); + assert!(content.contains(r#""AtuinHistory""#)); + } + + #[tokio::test] + async fn appends_to_existing_file() { + let dir = tempfile::tempdir().unwrap(); + let file = dir.path().join("permissions.ai.toml"); + let existing = r#"# My permissions +[permissions] +allow = ["Read"] +"#; + tokio::fs::write(&file, existing).await.unwrap(); + + let rule = Rule { + tool: "AtuinHistory".to_string(), + scope: None, + }; + write_rule(&file, &rule, RuleDisposition::Allow) + .await + .unwrap(); + + let content = tokio::fs::read_to_string(&file).await.unwrap(); + // Comment preserved + assert!(content.contains("# My permissions")); + // Both rules present + assert!(content.contains(r#""Read""#)); + assert!(content.contains(r#""AtuinHistory""#)); + } + + #[tokio::test] + async fn does_not_duplicate_existing_rule() { + let dir = tempfile::tempdir().unwrap(); + let file = dir.path().join("permissions.ai.toml"); + let existing = r#"[permissions] +allow = ["AtuinHistory"] +"#; + tokio::fs::write(&file, existing).await.unwrap(); + + let rule = Rule { + tool: "AtuinHistory".to_string(), + scope: None, + }; + write_rule(&file, &rule, RuleDisposition::Allow) + .await + .unwrap(); + + let content = tokio::fs::read_to_string(&file).await.unwrap(); + // Should appear exactly once + assert_eq!(content.matches("AtuinHistory").count(), 1); + } + + #[tokio::test] + async fn 
handles_inline_table_permissions() { + let dir = tempfile::tempdir().unwrap(); + let file = dir.path().join("permissions.ai.toml"); + // Inline table style — as_table_mut() would return None for this + let existing = r#"permissions = { allow = ["Read"] } +"#; + tokio::fs::write(&file, existing).await.unwrap(); + + let rule = Rule { + tool: "AtuinHistory".to_string(), + scope: None, + }; + write_rule(&file, &rule, RuleDisposition::Allow) + .await + .unwrap(); + + let content = tokio::fs::read_to_string(&file).await.unwrap(); + assert!(content.contains(r#""Read""#)); + assert!(content.contains(r#""AtuinHistory""#)); + } + + #[tokio::test] + async fn writes_deny_rule() { + let dir = tempfile::tempdir().unwrap(); + let file = dir.path().join("permissions.ai.toml"); + let rule = Rule { + tool: "Shell".to_string(), + scope: None, + }; + + write_rule(&file, &rule, RuleDisposition::Deny) + .await + .unwrap(); + + let content = tokio::fs::read_to_string(&file).await.unwrap(); + assert!(content.contains("deny")); + assert!(content.contains(r#""Shell""#)); + } +} |
