aboutsummaryrefslogtreecommitdiffstats
path: root/crates/atuin-ai/src/tools/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/atuin-ai/src/tools/mod.rs')
-rw-r--r--crates/atuin-ai/src/tools/mod.rs281
1 files changed, 207 insertions, 74 deletions
diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs
index fdda10a4..d1352661 100644
--- a/crates/atuin-ai/src/tools/mod.rs
+++ b/crates/atuin-ai/src/tools/mod.rs
@@ -5,6 +5,7 @@ use std::{
};
use eyre::Result;
+use uuid::Uuid;
const DEFAULT_FILE_READ_LINES: u64 = 100;
const MAX_FILE_READ_LINES: u64 = 1000;
@@ -158,6 +159,7 @@ pub(crate) enum ClientToolCall {
Write(WriteToolCall),
Shell(ShellToolCall),
AtuinHistory(AtuinHistoryToolCall),
+ AtuinOutput(AtuinOutputToolCall),
LoadSkill(LoadSkillToolCall),
}
@@ -173,6 +175,9 @@ impl TryFrom<(&str, &serde_json::Value)> for ClientToolCall {
"atuin_history" => Ok(ClientToolCall::AtuinHistory(
AtuinHistoryToolCall::try_from(input)?,
)),
+ "atuin_output" => Ok(ClientToolCall::AtuinOutput(AtuinOutputToolCall::try_from(
+ input,
+ )?)),
"load_skill" => Ok(ClientToolCall::LoadSkill(LoadSkillToolCall::try_from(
input,
)?)),
@@ -189,6 +194,7 @@ impl ClientToolCall {
ClientToolCall::Write(_) => descriptor::WRITE,
ClientToolCall::Shell(_) => descriptor::SHELL,
ClientToolCall::AtuinHistory(_) => descriptor::ATUIN_HISTORY,
+ ClientToolCall::AtuinOutput(_) => descriptor::ATUIN_OUTPUT,
ClientToolCall::LoadSkill(_) => descriptor::LOAD_SKILL,
}
}
@@ -205,6 +211,7 @@ impl ClientToolCall {
ClientToolCall::Write(_) => "Write",
ClientToolCall::Shell(_) => "Shell",
ClientToolCall::AtuinHistory(_) => "AtuinHistory",
+ ClientToolCall::AtuinOutput(_) => "AtuinOutput",
ClientToolCall::LoadSkill(_) => "LoadSkill",
}
}
@@ -218,6 +225,7 @@ impl ClientToolCall {
ClientToolCall::Write(tool) => Some(tool.resolved_path()),
ClientToolCall::Shell(_)
| ClientToolCall::AtuinHistory(_)
+ | ClientToolCall::AtuinOutput(_)
| ClientToolCall::LoadSkill(_) => None,
}
}
@@ -229,6 +237,7 @@ impl ClientToolCall {
ClientToolCall::Write(tool) => tool.matches_rule(rule),
ClientToolCall::Shell(tool) => tool.matches_rule(rule),
ClientToolCall::AtuinHistory(tool) => tool.matches_rule(rule),
+ ClientToolCall::AtuinOutput(tool) => tool.matches_rule(rule),
ClientToolCall::LoadSkill(tool) => tool.matches_rule(rule),
}
}
@@ -240,26 +249,14 @@ impl ClientToolCall {
ClientToolCall::Write(tool) => tool.target_dir(),
ClientToolCall::Shell(tool) => tool.target_dir(),
ClientToolCall::AtuinHistory(tool) => tool.target_dir(),
+ ClientToolCall::AtuinOutput(tool) => tool.target_dir(),
ClientToolCall::LoadSkill(tool) => tool.target_dir(),
}
}
-
- /// Execute this client-side tool and return the result.
- pub async fn execute(&self, db: &atuin_client::database::Sqlite) -> ToolOutcome {
- match self {
- ClientToolCall::Read(tool) => tool.execute(),
- ClientToolCall::AtuinHistory(tool) => tool.execute(db).await,
- // LoadSkill is handled separately by the driver (needs registry access)
- ClientToolCall::LoadSkill(_) => {
- ToolOutcome::Error("LoadSkill must be executed via the driver".to_string())
- }
- _ => ToolOutcome::Error("Client-side tool execution not yet implemented".to_string()),
- }
- }
}
/// A trait for tool calls that can be checked against permission rules.
-pub(crate) trait PermissableToolCall {
+pub(crate) trait PermissibleToolCall {
/// Checks if this tool call matches the given permission rule.
fn matches_rule(&self, rule: &Rule) -> bool;
@@ -277,7 +274,7 @@ pub(crate) trait PermissableToolCall {
}
}
-impl PermissableToolCall for ClientToolCall {
+impl PermissibleToolCall for ClientToolCall {
fn matches_rule(&self, rule: &Rule) -> bool {
self.matches_rule(rule)
}
@@ -416,7 +413,7 @@ impl ReadToolCall {
}
}
-impl PermissableToolCall for ReadToolCall {
+impl PermissibleToolCall for ReadToolCall {
fn target_dir(&self) -> Option<&Path> {
Some(&self.path)
}
@@ -616,7 +613,7 @@ impl EditToolCall {
}
}
-impl PermissableToolCall for EditToolCall {
+impl PermissibleToolCall for EditToolCall {
fn target_dir(&self) -> Option<&Path> {
Some(&self.path)
}
@@ -724,7 +721,7 @@ impl WriteToolCall {
}
}
-impl PermissableToolCall for WriteToolCall {
+impl PermissibleToolCall for WriteToolCall {
fn target_dir(&self) -> Option<&Path> {
Some(&self.path)
}
@@ -792,7 +789,7 @@ impl TryFrom<&serde_json::Value> for ShellToolCall {
}
}
-impl PermissableToolCall for ShellToolCall {
+impl PermissibleToolCall for ShellToolCall {
fn target_dir(&self) -> Option<&Path> {
self.dir.as_deref()
}
@@ -1134,7 +1131,7 @@ impl TryFrom<&serde_json::Value> for AtuinHistoryToolCall {
}
}
-impl PermissableToolCall for AtuinHistoryToolCall {
+impl PermissibleToolCall for AtuinHistoryToolCall {
fn target_dir(&self) -> Option<&Path> {
None
}
@@ -1148,7 +1145,6 @@ impl AtuinHistoryToolCall {
pub(crate) async fn execute(&self, db: &atuin_client::database::Sqlite) -> ToolOutcome {
use atuin_client::database::{self, Database as _, OptFilters};
use atuin_client::settings::SearchMode;
- use time::UtcOffset;
let context = match database::current_context().await {
Ok(ctx) => ctx,
@@ -1184,34 +1180,13 @@ impl AtuinHistoryToolCall {
return ToolOutcome::Success("No matching history entries found.".to_string());
}
- let local_offset = UtcOffset::current_local_offset().unwrap_or(UtcOffset::UTC);
+ let local_offset = crate::history_format::current_local_offset();
let formatted: Vec<String> = results
.iter()
.enumerate()
- .map(|(i, h)| {
- let ts = h.timestamp.to_offset(local_offset);
- let time_str = format!(
- "{:04}-{:02}-{:02} {:02}:{:02}:{:02}",
- ts.year(),
- ts.month() as u8,
- ts.day(),
- ts.hour(),
- ts.minute(),
- ts.second(),
- );
-
- let duration_str = format_duration(h.duration);
-
- format!(
- "{}. `{}` [{}] ({}, exit: {}){}",
- i + 1,
- h.command,
- time_str,
- h.cwd,
- h.exit,
- duration_str,
- )
+ .map(|(i, history)| {
+ crate::history_format::format_history_search_result(i + 1, history, local_offset)
})
.collect();
@@ -1220,6 +1195,146 @@ impl AtuinHistoryToolCall {
}
#[derive(Debug, Clone)]
+pub(crate) struct AtuinOutputToolCall {
+ pub history_id: Uuid,
+ pub ranges: Vec<(i64, i64)>,
+}
+
+impl TryFrom<&serde_json::Value> for AtuinOutputToolCall {
+ type Error = eyre::Error;
+
+ fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> {
+ let history_id = value
+ .get("history_id")
+ .and_then(|v| v.as_str())
+ .and_then(|v| Uuid::parse_str(v).ok())
+ .ok_or(eyre::eyre!("Missing or invalid history ID"))?;
+
+ let ranges = value
+ .get("ranges")
+ .and_then(|v| v.as_array())
+ .map(Vec::as_slice)
+ .unwrap_or(&[]);
+
+ let ranges = ranges
+ .iter()
+ .map(|r| {
+ let range = r
+ .as_array()
+ .filter(|a| a.len() == 2)
+ .ok_or_else(|| eyre::eyre!("Each range must be a [start, end] array"))?;
+
+ let start = range[0]
+ .as_i64()
+ .ok_or_else(|| eyre::eyre!("Range start must be an integer"))?;
+ let end = range[1]
+ .as_i64()
+ .ok_or_else(|| eyre::eyre!("Range end must be an integer"))?;
+
+ Ok((start, end))
+ })
+ .collect::<Result<Vec<(i64, i64)>, eyre::Error>>()?;
+
+ Ok(Self { history_id, ranges })
+ }
+}
+
+impl PermissibleToolCall for AtuinOutputToolCall {
+ fn target_dir(&self) -> Option<&Path> {
+ None
+ }
+
+ fn matches_rule(&self, rule: &Rule) -> bool {
+ rule.tool == "AtuinOutput"
+ }
+}
+
+fn format_output_lines_for_llm(lines: &[atuin_daemon::semantic::OutputLine]) -> String {
+ let width = lines
+ .iter()
+ .map(|line| line.line_number)
+ .max()
+ .unwrap_or(1)
+ .max(1)
+ .ilog10() as usize
+ + 1;
+ let mut formatted = Vec::with_capacity(lines.len());
+ let mut previous_line_number = None;
+
+ for line in lines {
+ if let Some(previous) = previous_line_number {
+ let skipped = line.line_number.saturating_sub(previous + 1);
+ if skipped > 0 {
+ formatted.push(format!("[...skipped {skipped} lines...]"));
+ }
+ }
+
+ formatted.push(format!("{:>width$}\t{}", line.line_number, line.content));
+ previous_line_number = Some(line.line_number);
+ }
+
+ formatted.join("\n")
+}
+
+impl AtuinOutputToolCall {
+ pub(crate) async fn execute(&self) -> ToolOutcome {
+ let settings = match atuin_client::settings::Settings::new() {
+ Ok(settings) => settings,
+ Err(e) => return ToolOutcome::Error(format!("Failed to load Atuin settings: {e}")),
+ };
+
+ let mut client = match atuin_daemon::SemanticClient::from_settings(&settings).await {
+ Ok(client) => client,
+ Err(e) => return ToolOutcome::Error(format!("Failed to connect to Atuin daemon: {e}")),
+ };
+
+ let history_id = self.history_id.as_simple().to_string();
+ let response = match client
+ .command_output(history_id.clone(), self.ranges.clone())
+ .await
+ {
+ Ok(response) => response,
+ Err(e) => return ToolOutcome::Error(format!("Failed to fetch command output: {e}")),
+ };
+
+ if !response.found {
+ return ToolOutcome::Success(format!(
+ "No captured output found for history ID {history_id}."
+ ));
+ }
+
+ if response.total_lines == 0 {
+ return ToolOutcome::Success(format!(
+ "Captured output for history ID {history_id} is empty."
+ ));
+ }
+
+ let output = format_output_lines_for_llm(&response.lines);
+ if output.is_empty() {
+ return ToolOutcome::Success(format!(
+ "No lines selected from captured output for history ID {history_id}."
+ ));
+ }
+
+ let total_output = if response.output_truncated {
+ format!(
+ "{} bytes captured, {} bytes observed before truncation, {} lines",
+ response.total_bytes, response.output_observed_bytes, response.total_lines
+ )
+ } else {
+ format!(
+ "{} bytes, {} lines",
+ response.total_bytes, response.total_lines
+ )
+ };
+
+ ToolOutcome::Success(format!(
+ "History ID: {history_id}\nTotal output: {total_output}\nSelected output:\n{output}"
+ ))
+ }
+}
+
+#[derive(Debug, Clone)]
pub(crate) struct LoadSkillToolCall {
pub name: String,
}
@@ -1239,7 +1354,7 @@ impl TryFrom<&serde_json::Value> for LoadSkillToolCall {
}
}
-impl PermissableToolCall for LoadSkillToolCall {
+impl PermissibleToolCall for LoadSkillToolCall {
fn target_dir(&self) -> Option<&Path> {
None
}
@@ -1286,6 +1401,52 @@ mod tests {
// ── Cross-platform tests ──
#[test]
+ fn atuin_output_ranges_are_optional() {
+ let input = serde_json::json!({
+ "history_id": "018f0000000070008000000000000000"
+ });
+
+ let call = AtuinOutputToolCall::try_from(&input).unwrap();
+
+ assert_eq!(
+ call.history_id.as_simple().to_string(),
+ "018f0000000070008000000000000000"
+ );
+ assert!(call.ranges.is_empty());
+ }
+
+ #[test]
+ fn atuin_output_parses_line_ranges() {
+ let input = serde_json::json!({
+ "history_id": "018f0000000070008000000000000000",
+ "ranges": [[0, 30], [-100, -1]]
+ });
+
+ let call = AtuinOutputToolCall::try_from(&input).unwrap();
+
+ assert_eq!(call.ranges, vec![(0, 30), (-100, -1)]);
+ }
+
+ #[test]
+ fn atuin_output_formats_lines_like_read_file() {
+ let lines = vec![
+ atuin_daemon::semantic::OutputLine {
+ line_number: 98,
+ content: "near end".to_string(),
+ },
+ atuin_daemon::semantic::OutputLine {
+ line_number: 100,
+ content: "end".to_string(),
+ },
+ ];
+
+ assert_eq!(
+ format_output_lines_for_llm(&lines),
+ " 98\tnear end\n[...skipped 1 lines...]\n100\tend"
+ );
+ }
+
+ #[test]
fn no_scope_matches_everything() {
assert!(read_tool("any/path.txt").matches_rule(&read_rule(None)));
assert!(write_tool("any/path.txt").matches_rule(&write_rule(None)));
@@ -1996,31 +2157,3 @@ mod tests {
}
}
}
-
-fn format_duration(nanos: i64) -> String {
- if nanos <= 0 {
- return String::new();
- }
-
- let total_secs = nanos / 1_000_000_000;
- let millis = (nanos % 1_000_000_000) / 1_000_000;
-
- if total_secs >= 3600 {
- let hours = total_secs / 3600;
- let mins = (total_secs % 3600) / 60;
- let secs = total_secs % 60;
- format!(", {hours}h{mins}m{secs}s")
- } else if total_secs >= 60 {
- let mins = total_secs / 60;
- let secs = total_secs % 60;
- format!(", {mins}m{secs}s")
- } else if total_secs > 0 {
- if millis > 0 {
- format!(", {total_secs}.{millis:03}s")
- } else {
- format!(", {total_secs}s")
- }
- } else {
- format!(", {millis}ms")
- }
-}