From 5e31a81cd2207f053b8cd8ad84ebe2a2f691b29d Mon Sep 17 00:00:00 2001 From: Benedikt Peetz Date: Wed, 10 Jun 2026 22:01:45 +0200 Subject: chore: Remove some unused rust code --- .gitignore | 2 + Cargo.lock | 349 +--- Cargo.toml | 18 +- crates/atuin-ai/Cargo.toml | 74 - .../20260413000000_create_ai_sessions.sql | 32 - .../20260417000000_add_session_metadata.sql | 9 - crates/atuin-ai/render-tests.sh | 34 - crates/atuin-ai/replay-states.sh | 101 - crates/atuin-ai/src/commands.rs | 158 -- crates/atuin-ai/src/commands/init.rs | 233 --- crates/atuin-ai/src/commands/inline.rs | 587 ------ crates/atuin-ai/src/context.rs | 121 -- crates/atuin-ai/src/context_window.rs | 578 ------ crates/atuin-ai/src/diff.rs | 328 --- crates/atuin-ai/src/driver.rs | 1030 ---------- crates/atuin-ai/src/edit_permissions.rs | 108 - crates/atuin-ai/src/event_serde.rs | 397 ---- crates/atuin-ai/src/file_tracker.rs | 234 --- crates/atuin-ai/src/fsm/effects.rs | 99 - crates/atuin-ai/src/fsm/events.rs | 140 -- crates/atuin-ai/src/fsm/mod.rs | 1103 ---------- crates/atuin-ai/src/fsm/tests.rs | 890 -------- crates/atuin-ai/src/fsm/tools.rs | 178 -- crates/atuin-ai/src/history_format.rs | 120 -- crates/atuin-ai/src/lib.rs | 19 - crates/atuin-ai/src/permissions/check.rs | 71 - crates/atuin-ai/src/permissions/file.rs | 26 - crates/atuin-ai/src/permissions/mod.rs | 7 - crates/atuin-ai/src/permissions/resolver.rs | 31 - crates/atuin-ai/src/permissions/rule.rs | 106 - crates/atuin-ai/src/permissions/shell.rs | 1335 ------------ crates/atuin-ai/src/permissions/walker.rs | 121 -- crates/atuin-ai/src/permissions/writer.rs | 199 -- crates/atuin-ai/src/session.rs | 509 ----- crates/atuin-ai/src/skills/frontmatter.rs | 233 --- crates/atuin-ai/src/skills/mod.rs | 468 ----- crates/atuin-ai/src/skills/walker.rs | 178 -- crates/atuin-ai/src/snapshots.rs | 414 ---- crates/atuin-ai/src/store.rs | 554 ----- crates/atuin-ai/src/stream.rs | 288 --- crates/atuin-ai/src/tools/descriptor.rs | 129 -- crates/atuin-ai/src/tools/mod.rs | 2159 -------------------- crates/atuin-ai/src/tui/components/atuin_ai.rs | 143 -- crates/atuin-ai/src/tui/components/input_box.rs | 220 -- crates/atuin-ai/src/tui/components/markdown.rs | 210 -- crates/atuin-ai/src/tui/components/mod.rs | 5 - crates/atuin-ai/src/tui/components/select.rs | 95 - .../src/tui/components/session_continue.rs | 49 - crates/atuin-ai/src/tui/content/help.md | 6 - crates/atuin-ai/src/tui/events.rs | 67 - crates/atuin-ai/src/tui/mod.rs | 7 - crates/atuin-ai/src/tui/slash.rs | 79 - crates/atuin-ai/src/tui/state.rs | 237 --- crates/atuin-ai/src/tui/view/mod.rs | 978 --------- crates/atuin-ai/src/tui/view/turn.rs | 606 ------ crates/atuin-ai/src/user_context/interpolate.rs | 279 --- crates/atuin-ai/src/user_context/mod.rs | 68 - crates/atuin-ai/src/user_context/walker.rs | 90 - crates/atuin-ai/test-renders.json | 295 --- crates/atuin-daemon/Cargo.toml | 1 - crates/atuin-daemon/src/components/sync.rs | 17 - crates/atuin-dotfiles/Cargo.toml | 25 - crates/atuin-dotfiles/src/lib.rs | 2 - crates/atuin-dotfiles/src/shell.rs | 241 --- crates/atuin-dotfiles/src/shell/bash.rs | 68 - crates/atuin-dotfiles/src/shell/fish.rs | 69 - crates/atuin-dotfiles/src/shell/powershell.rs | 169 -- crates/atuin-dotfiles/src/shell/xonsh.rs | 68 - crates/atuin-dotfiles/src/shell/zsh.rs | 68 - crates/atuin-dotfiles/src/store.rs | 421 ---- crates/atuin-dotfiles/src/store/alias.rs | 1 - crates/atuin-dotfiles/src/store/var.rs | 542 ----- crates/atuin-scripts/Cargo.toml | 34 - .../20250326160051_create_scripts.down.sql | 2 - .../20250326160051_create_scripts.up.sql | 17 - .../20250402170430_unique_names.down.sql | 2 - .../migrations/20250402170430_unique_names.up.sql | 2 - crates/atuin-scripts/src/database.rs | 371 ---- crates/atuin-scripts/src/execution.rs | 286 --- crates/atuin-scripts/src/lib.rs | 4 - crates/atuin-scripts/src/settings.rs | 1 - crates/atuin-scripts/src/store.rs | 114 -- crates/atuin-scripts/src/store/record.rs | 215 -- crates/atuin-scripts/src/store/script.rs | 151 -- crates/atuin/Cargo.toml | 44 +- crates/atuin/contrib/pi/atuin.ts | 87 - crates/atuin/src/command/client.rs | 39 +- crates/atuin/src/command/client/dotfiles.rs | 28 - crates/atuin/src/command/client/dotfiles/alias.rs | 187 -- crates/atuin/src/command/client/dotfiles/var.rs | 197 -- crates/atuin/src/command/client/hook.rs | 479 ----- crates/atuin/src/command/client/init.rs | 81 +- crates/atuin/src/command/client/init/bash.rs | 26 +- crates/atuin/src/command/client/init/fish.rs | 20 +- crates/atuin/src/command/client/init/powershell.rs | 15 +- crates/atuin/src/command/client/init/xonsh.rs | 15 +- crates/atuin/src/command/client/init/zsh.rs | 21 +- crates/atuin/src/command/client/scripts.rs | 590 ------ crates/atuin/src/command/client/store/rebuild.rs | 36 - crates/atuin/src/command/client/wrapped.rs | 54 +- crates/atuin/src/sync.rs | 10 - flake.lock | 98 +- flake.nix | 109 +- 103 files changed, 88 insertions(+), 22143 deletions(-) delete mode 100644 crates/atuin-ai/Cargo.toml delete mode 100644 crates/atuin-ai/migrations/20260413000000_create_ai_sessions.sql delete mode 100644 crates/atuin-ai/migrations/20260417000000_add_session_metadata.sql delete mode 100755 crates/atuin-ai/render-tests.sh delete mode 100755 crates/atuin-ai/replay-states.sh delete mode 100644 crates/atuin-ai/src/commands.rs delete mode 100644 crates/atuin-ai/src/commands/init.rs delete mode 100644 crates/atuin-ai/src/commands/inline.rs delete mode 100644 crates/atuin-ai/src/context.rs delete mode 100644 crates/atuin-ai/src/context_window.rs delete mode 100644 crates/atuin-ai/src/diff.rs delete mode 100644 crates/atuin-ai/src/driver.rs delete mode 100644 crates/atuin-ai/src/edit_permissions.rs delete mode 100644 crates/atuin-ai/src/event_serde.rs delete mode 100644 crates/atuin-ai/src/file_tracker.rs delete mode 100644 crates/atuin-ai/src/fsm/effects.rs delete mode 100644 crates/atuin-ai/src/fsm/events.rs delete mode 100644 crates/atuin-ai/src/fsm/mod.rs delete mode 100644 crates/atuin-ai/src/fsm/tests.rs delete mode 100644 crates/atuin-ai/src/fsm/tools.rs delete mode 100644 crates/atuin-ai/src/history_format.rs delete mode 100644 crates/atuin-ai/src/lib.rs delete mode 100644 crates/atuin-ai/src/permissions/check.rs delete mode 100644 crates/atuin-ai/src/permissions/file.rs delete mode 100644 crates/atuin-ai/src/permissions/mod.rs delete mode 100644 crates/atuin-ai/src/permissions/resolver.rs delete mode 100644 crates/atuin-ai/src/permissions/rule.rs delete mode 100644 crates/atuin-ai/src/permissions/shell.rs delete mode 100644 crates/atuin-ai/src/permissions/walker.rs delete mode 100644 crates/atuin-ai/src/permissions/writer.rs delete mode 100644 crates/atuin-ai/src/session.rs delete mode 100644 crates/atuin-ai/src/skills/frontmatter.rs delete mode 100644 crates/atuin-ai/src/skills/mod.rs delete mode 100644 crates/atuin-ai/src/skills/walker.rs delete mode 100644 crates/atuin-ai/src/snapshots.rs delete mode 100644 crates/atuin-ai/src/store.rs delete mode 100644 crates/atuin-ai/src/stream.rs delete mode 100644 crates/atuin-ai/src/tools/descriptor.rs delete mode 100644 crates/atuin-ai/src/tools/mod.rs delete mode 100644 crates/atuin-ai/src/tui/components/atuin_ai.rs delete mode 100644 crates/atuin-ai/src/tui/components/input_box.rs delete mode 100644 crates/atuin-ai/src/tui/components/markdown.rs delete mode 100644 crates/atuin-ai/src/tui/components/mod.rs delete mode 100644 crates/atuin-ai/src/tui/components/select.rs delete mode 100644 crates/atuin-ai/src/tui/components/session_continue.rs delete mode 100644 crates/atuin-ai/src/tui/content/help.md delete mode 100644 crates/atuin-ai/src/tui/events.rs delete mode 100644 crates/atuin-ai/src/tui/mod.rs delete mode 100644 crates/atuin-ai/src/tui/slash.rs delete mode 100644 crates/atuin-ai/src/tui/state.rs delete mode 100644 crates/atuin-ai/src/tui/view/mod.rs delete mode 100644 crates/atuin-ai/src/tui/view/turn.rs delete mode 100644 crates/atuin-ai/src/user_context/interpolate.rs delete mode 100644 crates/atuin-ai/src/user_context/mod.rs delete mode 100644 crates/atuin-ai/src/user_context/walker.rs delete mode 100644 crates/atuin-ai/test-renders.json delete mode 100644 crates/atuin-dotfiles/Cargo.toml delete mode 100644 crates/atuin-dotfiles/src/lib.rs delete mode 100644 crates/atuin-dotfiles/src/shell.rs delete mode 100644 crates/atuin-dotfiles/src/shell/bash.rs delete mode 100644 crates/atuin-dotfiles/src/shell/fish.rs delete mode 100644 crates/atuin-dotfiles/src/shell/powershell.rs delete mode 100644 crates/atuin-dotfiles/src/shell/xonsh.rs delete mode 100644 crates/atuin-dotfiles/src/shell/zsh.rs delete mode 100644 crates/atuin-dotfiles/src/store.rs delete mode 100644 crates/atuin-dotfiles/src/store/alias.rs delete mode 100644 crates/atuin-dotfiles/src/store/var.rs delete mode 100644 crates/atuin-scripts/Cargo.toml delete mode 100644 crates/atuin-scripts/migrations/20250326160051_create_scripts.down.sql delete mode 100644 crates/atuin-scripts/migrations/20250326160051_create_scripts.up.sql delete mode 100644 crates/atuin-scripts/migrations/20250402170430_unique_names.down.sql delete mode 100644 crates/atuin-scripts/migrations/20250402170430_unique_names.up.sql delete mode 100644 crates/atuin-scripts/src/database.rs delete mode 100644 crates/atuin-scripts/src/execution.rs delete mode 100644 crates/atuin-scripts/src/lib.rs delete mode 100644 crates/atuin-scripts/src/settings.rs delete mode 100644 crates/atuin-scripts/src/store.rs delete mode 100644 crates/atuin-scripts/src/store/record.rs delete mode 100644 crates/atuin-scripts/src/store/script.rs delete mode 100644 crates/atuin/contrib/pi/atuin.ts delete mode 100644 crates/atuin/src/command/client/dotfiles.rs delete mode 100644 crates/atuin/src/command/client/dotfiles/alias.rs delete mode 100644 crates/atuin/src/command/client/dotfiles/var.rs delete mode 100644 crates/atuin/src/command/client/hook.rs delete mode 100644 crates/atuin/src/command/client/scripts.rs diff --git a/.gitignore b/.gitignore index c4ccffb0..985e2a88 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,8 @@ publish.sh .envrc .planning/ +.direnv + ui/backend/target ui/backend/gen diff --git a/Cargo.lock b/Cargo.lock index 73c3ed63..3c9d56ce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -143,40 +143,12 @@ dependencies = [ "password-hash", ] -[[package]] -name = "arraydeque" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d902e3d592a523def97af8f317b08ce16b7ab854c1985a0c671e6f15cebc236" - [[package]] name = "arrayvec" version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" -[[package]] -name = "async-stream" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" -dependencies = [ - "async-stream-impl", - "futures-core", - "pin-project-lite", -] - -[[package]] -name = "async-stream-impl" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.117", -] - [[package]] name = "async-trait" version = "0.1.89" @@ -218,16 +190,13 @@ version = "18.16.1" dependencies = [ "arboard", "async-trait", - "atuin-ai", "atuin-client", "atuin-common", "atuin-daemon", - "atuin-dotfiles", "atuin-history", "atuin-kv", "atuin-nucleo-matcher", "atuin-pty-proxy", - "atuin-scripts", "atuin-server", "atuin-server-database", "atuin-server-postgres", @@ -269,60 +238,6 @@ dependencies = [ "tracing-tree", "unicode-width 0.2.2", "uuid", - "windows-sys 0.61.2", -] - -[[package]] -name = "atuin-ai" -version = "18.16.1" -dependencies = [ - "async-stream", - "async-trait", - "atuin-client", - "atuin-common", - "atuin-daemon", - "chrono", - "chrono-humanize", - "clap", - "crossterm", - "directories", - "eventsource-stream", - "eye_declare", - "eyre", - "fs-err", - "futures", - "glob-match", - "imara-diff", - "pretty_assertions", - "pulldown-cmark", - "ratatui", - "ratatui-core", - "ratatui-widgets", - "regex", - "reqwest", - "serde", - "serde_json", - "shellexpand", - "sqlx", - "tempfile", - "thiserror 2.0.18", - "time", - "tokio", - "toml", - "toml_edit", - "tracing", - "tracing-appender", - "tracing-subscriber", - "tree-sitter", - "tree-sitter-bash", - "tree-sitter-fish", - "tui-textarea-2", - "typed-builder 0.18.2", - "unicode-width 0.2.2", - "uuid", - "vt100", - "xxhash-rust", - "yaml-rust2", ] [[package]] @@ -374,7 +289,7 @@ dependencies = [ "time", "tiny-bip39", "tokio", - "typed-builder 0.18.2", + "typed-builder", "urlencoding", "uuid", "whoami 2.1.1", @@ -396,7 +311,7 @@ dependencies = [ "sysinfo", "thiserror 2.0.18", "time", - "typed-builder 0.18.2", + "typed-builder", "uuid", ] @@ -406,7 +321,6 @@ version = "18.16.1" dependencies = [ "atuin-client", "atuin-common", - "atuin-dotfiles", "atuin-history", "atuin-nucleo", "dashmap", @@ -433,20 +347,6 @@ dependencies = [ "uuid", ] -[[package]] -name = "atuin-dotfiles" -version = "18.16.1" -dependencies = [ - "atuin-client", - "atuin-common", - "crypto_secretbox", - "eyre", - "rand 0.8.5", - "rmp", - "serde", - "tokio", -] - [[package]] name = "atuin-history" version = "18.16.1" @@ -473,7 +373,7 @@ dependencies = [ "tokio", "tracing", "tracing-subscriber", - "typed-builder 0.18.2", + "typed-builder", ] [[package]] @@ -515,28 +415,6 @@ dependencies = [ "vt100", ] -[[package]] -name = "atuin-scripts" -version = "18.16.1" -dependencies = [ - "atuin-client", - "atuin-common", - "eyre", - "minijinja", - "pretty_assertions", - "rmp", - "serde", - "serde_json", - "sql-builder", - "sqlx", - "tempfile", - "tokio", - "tracing", - "tracing-subscriber", - "typed-builder 0.18.2", - "uuid", -] - [[package]] name = "atuin-server" version = "18.16.1" @@ -840,22 +718,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" dependencies = [ "iana-time-zone", - "js-sys", "num-traits", "serde", - "wasm-bindgen", "windows-link", ] -[[package]] -name = "chrono-humanize" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "799627e6b4d27827a814e837b9d8a504832086081806d45b1afa34dc982b023b" -dependencies = [ - "chrono", -] - [[package]] name = "cipher" version = "0.4.4" @@ -1124,7 +991,6 @@ dependencies = [ "derive_more", "document-features", "filedescriptor", - "futures-core", "mio", "parking_lot", "rustix", @@ -1478,15 +1344,6 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" -[[package]] -name = "encoding_rs" -version = "0.8.35" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" -dependencies = [ - "cfg-if", -] - [[package]] name = "equivalent" version = "1.0.2" @@ -1540,44 +1397,6 @@ dependencies = [ "pin-project-lite", ] -[[package]] -name = "eventsource-stream" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" -dependencies = [ - "futures-core", - "nom 7.1.3", - "pin-project-lite", -] - -[[package]] -name = "eye_declare" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e4caff9f5574315258489ecb2d27aad2cc057521f4b0bc6819a1cd82a648f40" -dependencies = [ - "crossterm", - "eye_declare_macros", - "futures", - "ratatui-core", - "ratatui-widgets", - "tokio", - "typed-builder 0.23.2", - "unicode-width 0.2.2", -] - -[[package]] -name = "eye_declare_macros" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb8ac27b19f79b61a8afc0a55a3148eb2dafc2aaec66c1103432e3e76e7c25af" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.117", -] - [[package]] name = "eyre" version = "0.6.12" @@ -1858,15 +1677,6 @@ dependencies = [ "windows-link", ] -[[package]] -name = "getopts" -version = "0.2.24" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfe4fbac503b8d1f88e6676011885f34b7174f46e59956bba534ba83abded4df" -dependencies = [ - "unicode-width 0.2.2", -] - [[package]] name = "getrandom" version = "0.2.17" @@ -1903,12 +1713,6 @@ dependencies = [ "wasip3", ] -[[package]] -name = "glob-match" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985c9503b412198aa4197559e9a318524ebc4519c229bfa05a535828c950b9d" - [[package]] name = "h2" version = "0.4.13" @@ -1975,15 +1779,6 @@ dependencies = [ "hashbrown 0.15.5", ] -[[package]] -name = "hashlink" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea0b22561a9c04a7cb1a302c013e0259cd3b4bb619f145b32f72b8b4bcbed230" -dependencies = [ - "hashbrown 0.16.1", -] - [[package]] name = "heck" version = "0.5.0" @@ -2287,16 +2082,6 @@ dependencies = [ "icu_properties", ] -[[package]] -name = "imara-diff" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f01d462f766df78ab820dd06f5eb700233c51f0f4c2e846520eaf4ba6aa5c5c" -dependencies = [ - "hashbrown 0.15.5", - "memchr", -] - [[package]] name = "indenter" version = "0.3.4" @@ -2746,12 +2531,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a64a92489e2744ce060c349162be1c5f33c6969234104dbd99ddb5feb08b8c15" -[[package]] -name = "memo-map" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b" - [[package]] name = "memoffset" version = "0.9.1" @@ -2829,16 +2608,6 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" -[[package]] -name = "minijinja" -version = "2.18.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "328251e58ad8e415be6198888fc207502727dc77945806421ab34f35bf012e7d" -dependencies = [ - "memo-map", - "serde", -] - [[package]] name = "minimal-lexical" version = "0.2.1" @@ -3639,18 +3408,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "83c41efbf8f90ac44de7f3a868f0867851d261b56291732d0cbf7cceaaeb55a6" dependencies = [ "bitflags 2.11.0", - "getopts", "memchr", - "pulldown-cmark-escape", "unicase", ] -[[package]] -name = "pulldown-cmark-escape" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "007d8adb5ddab6f8e3f491ac63566a7d5002cc7ed73901f72057943fa71ae1ae" - [[package]] name = "pulldown-cmark-to-cmark" version = "22.0.0" @@ -4362,7 +4123,6 @@ version = "1.0.149" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" dependencies = [ - "indexmap 2.13.0", "itoa", "memchr", "serde", @@ -4653,7 +4413,7 @@ dependencies = [ "futures-io", "futures-util", "hashbrown 0.15.5", - "hashlink 0.10.0", + "hashlink", "indexmap 2.13.0", "log", "memchr", @@ -4834,12 +4594,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" -[[package]] -name = "streaming-iterator" -version = "0.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b2231b7c3057d5e4ad0156fb3dc807d900806020c5ffa3ee6ff2c8c76fb8520" - [[package]] name = "stringprep" version = "0.1.5" @@ -5228,12 +4982,10 @@ version = "1.1.1+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "994b95d9e7bae62b34bab0e2a4510b801fa466066a6a8b2b57361fa1eba068ee" dependencies = [ - "indexmap 2.13.0", "serde_core", "serde_spanned", "toml_datetime", "toml_parser", - "toml_writer", "winnow 1.0.1", ] @@ -5502,46 +5254,6 @@ dependencies = [ "tracing-subscriber", ] -[[package]] -name = "tree-sitter" -version = "0.26.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "887bd495d0582c5e3e0d8ece2233666169fa56a9644d172fc22ad179ab2d0538" -dependencies = [ - "cc", - "regex", - "regex-syntax", - "serde_json", - "streaming-iterator", - "tree-sitter-language", -] - -[[package]] -name = "tree-sitter-bash" -version = "0.25.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e5ec769279cc91b561d3df0d8a5deb26b0ad40d183127f409494d6d8fc53062" -dependencies = [ - "cc", - "tree-sitter-language", -] - -[[package]] -name = "tree-sitter-fish" -version = "3.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "014e3b299f251e9c2e372e3b5e1b0323ef21196e9aa2e90a5bc1f6130cbe8b18" -dependencies = [ - "cc", - "tree-sitter", -] - -[[package]] -name = "tree-sitter-language" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "009994f150cc0cd50ff54917d5bc8bffe8cad10ca10d81c34da2ec421ae61782" - [[package]] name = "tree_magic_mini" version = "3.2.2" @@ -5559,36 +5271,13 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" -[[package]] -name = "tui-textarea-2" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74a31ca0965e3ff6a7ac5ecb02b20a88b4f68ebf138d8ae438e8510b27a1f00f" -dependencies = [ - "crossterm", - "portable-atomic", - "ratatui-core", - "ratatui-widgets", - "unicode-segmentation", - "unicode-width 0.2.2", -] - [[package]] name = "typed-builder" version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77739c880e00693faef3d65ea3aad725f196da38b22fdc7ea6ded6e1ce4d3add" dependencies = [ - "typed-builder-macro 0.18.2", -] - -[[package]] -name = "typed-builder" -version = "0.23.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31aa81521b70f94402501d848ccc0ecaa8f93c8eb6999eb9747e72287757ffda" -dependencies = [ - "typed-builder-macro 0.23.2", + "typed-builder-macro", ] [[package]] @@ -5602,17 +5291,6 @@ dependencies = [ "syn 2.0.117", ] -[[package]] -name = "typed-builder-macro" -version = "0.23.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "076a02dc54dd46795c2e9c8282ed40bcfb1e22747e955de9389a1de28190fb26" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.117", -] - [[package]] name = "typenum" version = "1.19.0" @@ -6750,23 +6428,6 @@ version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ea6fc2961e4ef194dcbfe56bb845534d0dc8098940c7e5c012a258bfec6701bd" -[[package]] -name = "xxhash-rust" -version = "0.8.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3" - -[[package]] -name = "yaml-rust2" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "631a50d867fafb7093e709d75aaee9e0e0d5deb934021fcea25ac2fe09edc51e" -dependencies = [ - "arraydeque", - "encoding_rs", - "hashlink 0.11.0", -] - [[package]] name = "yansi" version = "1.0.1" diff --git a/Cargo.toml b/Cargo.toml index 805e1b89..bb02487c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ exclude = ["ui/backend", "crates/atuin-nucleo/matcher/fuzz"] [workspace.package] version = "18.16.1" authors = ["Ellie Huxtable "] -rust-version = "1.96.0" +rust-version = "1.95.0" license = "MIT" homepage = "https://atuin.sh" repository = "https://github.com/atuinsh/atuin" @@ -76,19 +76,9 @@ xxhash-rust = { version = "0.8", features = ["xxh3"] } vt100 = "0.16" regex = "1.10.5" toml_edit = "0.25.4" - -[workspace.dependencies.tracing-subscriber] -version = "0.3" -features = ["ansi", "fmt", "registry", "env-filter", "json"] - -[workspace.dependencies.reqwest] -version = "0.13" -features = ["json", "rustls-no-provider", "stream"] -default-features = false - -[workspace.dependencies.sqlx] -version = "0.8" -features = ["runtime-tokio-rustls", "time", "postgres", "uuid"] +tracing-subscriber = { version = "0.3", features = ["ansi", "fmt", "registry", "env-filter", "json"] } +reqwest = { version = "0.13", features = ["json", "rustls-no-provider", "stream"], default-features = false } +sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "time", "postgres", "uuid"] } # The profile that 'cargo dist' will build with [profile.dist] diff --git a/crates/atuin-ai/Cargo.toml b/crates/atuin-ai/Cargo.toml deleted file mode 100644 index 027bd490..00000000 --- a/crates/atuin-ai/Cargo.toml +++ /dev/null @@ -1,74 +0,0 @@ -[package] -name = "atuin-ai" -edition = "2024" -description = "AI integration for Atuin CLI" - -rust-version = { workspace = true } -version = { workspace = true } -authors = { workspace = true } -license = { workspace = true } -homepage = { workspace = true } -repository = { workspace = true } - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[features] -default = [] -daemon = [] -tree-sitter = ["dep:tree-sitter-lib", "dep:tree-sitter-bash", "dep:tree-sitter-fish"] - -[dependencies] -async-trait = { workspace = true } -atuin-client = { workspace = true } -atuin-common = { workspace = true } -atuin-daemon = { workspace = true } -tokio = { workspace = true } -eyre = { workspace = true } -clap = { workspace = true, features = ["derive", "env"] } -tracing = { workspace = true } -tracing-subscriber = { workspace = true, features = [ - "ansi", - "fmt", - "registry", - "env-filter", -] } -directories = { workspace = true } -tracing-appender = "0.2.4" -reqwest = { workspace = true } -serde = { workspace = true } -serde_json = { workspace = true } -crossterm = { workspace = true, features = ["use-dev-tty", "event-stream"] } -ratatui = { workspace = true } -fs-err = { workspace = true } -futures = "0.3" -eventsource-stream = "0.2" -pulldown-cmark = "0.13.0" -async-stream = "0.3" -uuid = { workspace = true } -tui-textarea-2 = "0.10.2" -unicode-width = "0.2" -eye_declare = "0.5.1" -ratatui-core = "0.1" -ratatui-widgets = "0.3" -thiserror = { workspace = true } -glob-match = { workspace = true } -regex = { workspace = true } -time = { workspace = true } -toml = "1.1" -toml_edit = { workspace = true } -tree-sitter-lib = { package = "tree-sitter", version = "0.26.8", optional = true } -tree-sitter-bash = { version = "0.25.1", optional = true } -tree-sitter-fish = { version = "3.6.0", optional = true } -sqlx = { workspace = true, features = ["sqlite"] } -typed-builder = { workspace = true } -shellexpand = { workspace = true } -imara-diff = { workspace = true } -xxhash-rust = { workspace = true } -vt100 = { workspace = true } -yaml-rust2 = "0.11" -tempfile = { workspace = true } -chrono = "0.4" -chrono-humanize = "0.2" - -[dev-dependencies] -pretty_assertions = { workspace = true } diff --git a/crates/atuin-ai/migrations/20260413000000_create_ai_sessions.sql b/crates/atuin-ai/migrations/20260413000000_create_ai_sessions.sql deleted file mode 100644 index 906a5726..00000000 --- a/crates/atuin-ai/migrations/20260413000000_create_ai_sessions.sql +++ /dev/null @@ -1,32 +0,0 @@ -CREATE TABLE IF NOT EXISTS sessions ( - id TEXT PRIMARY KEY, - head_id TEXT, - server_session_id TEXT, - directory TEXT, - git_root TEXT, - created_at INTEGER NOT NULL, - updated_at INTEGER NOT NULL, - archived_at INTEGER -); - -CREATE INDEX idx_sessions_directory ON sessions(directory); -CREATE INDEX idx_sessions_git_root ON sessions(git_root); -CREATE INDEX idx_sessions_updated_at ON sessions(updated_at); -CREATE INDEX idx_sessions_created_at ON sessions(created_at); - -CREATE TABLE IF NOT EXISTS session_events ( - id TEXT PRIMARY KEY, - session_id TEXT NOT NULL, - parent_id TEXT, - invocation_id TEXT NOT NULL, - event_type TEXT NOT NULL, - event_data TEXT NOT NULL, - created_at INTEGER NOT NULL, - - FOREIGN KEY (session_id) REFERENCES sessions(id) -); - -CREATE INDEX idx_session_events_session_id ON session_events(session_id); -CREATE INDEX idx_session_events_parent_id ON session_events(parent_id); -CREATE INDEX idx_session_events_invocation_id ON session_events(invocation_id); -CREATE INDEX idx_session_events_created_at ON session_events(created_at); diff --git a/crates/atuin-ai/migrations/20260417000000_add_session_metadata.sql b/crates/atuin-ai/migrations/20260417000000_add_session_metadata.sql deleted file mode 100644 index f97dfd1b..00000000 --- a/crates/atuin-ai/migrations/20260417000000_add_session_metadata.sql +++ /dev/null @@ -1,9 +0,0 @@ -CREATE TABLE IF NOT EXISTS session_metadata ( - session_id TEXT NOT NULL, - key TEXT NOT NULL, - value TEXT NOT NULL, - updated_at INTEGER NOT NULL, - - PRIMARY KEY (session_id, key), - FOREIGN KEY (session_id) REFERENCES sessions(id) -); diff --git a/crates/atuin-ai/render-tests.sh b/crates/atuin-ai/render-tests.sh deleted file mode 100755 index 8dedc76e..00000000 --- a/crates/atuin-ai/render-tests.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash -# Render all test cases from test-renders.json -# Usage: ./render-tests.sh [test_name] -# With no args: renders all tests -# With arg: renders only matching test (e.g., ./render-tests.sh 05) - -set -e -cd "$(dirname "$0")" - -JSON_FILE="test-renders.json" -FILTER="${1:-}" - -# Build once -cargo build -p atuin-ai --quiet - -# Count tests -TOTAL=$(jq length "$JSON_FILE") - -for i in $(seq 0 $((TOTAL - 1))); do - NAME=$(jq -r ".[$i].name" "$JSON_FILE") - DESC=$(jq -r ".[$i].description" "$JSON_FILE") - STATE=$(jq -c ".[$i].state" "$JSON_FILE") - - # Skip if filter provided and doesn't match - if [[ -n "$FILTER" && ! "$NAME" =~ $FILTER ]]; then - continue - fi - - echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" - echo "[$NAME] $DESC" - echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" - echo "$STATE" | cargo run -p atuin-ai --quiet -- debug-render -f plain - echo "" -done diff --git a/crates/atuin-ai/replay-states.sh b/crates/atuin-ai/replay-states.sh deleted file mode 100755 index 791ad47e..00000000 --- a/crates/atuin-ai/replay-states.sh +++ /dev/null @@ -1,101 +0,0 @@ -#!/bin/bash -# Replay state snapshots from a debug state JSONL file -# Usage: ./replay-states.sh [entry-number] -# With no entry: renders all frames in sequence (press Enter to advance) -# With entry number: renders just that frame - -set -e -# cd "$(dirname "$0")" - -STATE_FILE="${1:-}" -ENTRY_FILTER="${2:-}" - -if [[ -z "$STATE_FILE" ]]; then - echo "Usage: $0 [entry-number]" - echo "" - echo "Examples:" - echo " $0 /tmp/state.jsonl # Interactive replay of all frames" - echo " $0 /tmp/state.jsonl 15 # Show just entry 15" - exit 1 -fi - -if [[ ! -f "$STATE_FILE" ]]; then - echo "Error: File not found: $STATE_FILE" - exit 1 -fi - -# Build once -cargo build -p atuin --quiet - -# Count entries -TOTAL=$(wc -l < "$STATE_FILE" | tr -d ' ') - -if [[ -n "$ENTRY_FILTER" ]]; then - # Show single entry - LINE=$(sed -n "${ENTRY_FILTER}p" "$STATE_FILE") - if [[ -z "$LINE" ]]; then - echo "Error: Entry $ENTRY_FILTER not found (file has $TOTAL entries)" - exit 1 - fi - - ENTRY=$(echo "$LINE" | jq -r '.entry') - LABEL=$(echo "$LINE" | jq -r '.label') - STATE=$(echo "$LINE" | jq -c '.state') - - echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" - echo "[$ENTRY/$TOTAL] $LABEL" - echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" - echo "$STATE" | cargo run -p atuin --quiet -- ai debug-render -f ansi -else - # Interactive replay - echo "Replaying $TOTAL frames from $STATE_FILE" - echo "Press Enter to advance, 'q' to quit, or number+Enter to jump" - echo "" - - CURRENT=1 - while [[ $CURRENT -le $TOTAL ]]; do - LINE=$(sed -n "${CURRENT}p" "$STATE_FILE") - ENTRY=$(echo "$LINE" | jq -r '.entry') - LABEL=$(echo "$LINE" | jq -r '.label') - STATE=$(echo "$LINE" | jq -c '.state') - - clear - echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" - echo "[$CURRENT/$TOTAL] $LABEL" - echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" - echo "$STATE" | cargo run -p atuin --quiet -- ai debug-render -f ansi - echo "" - echo "[Enter: next] [p: prev] [number: jump] [s: show state JSON] [q: quit]" - - read -r INPUT - case "$INPUT" in - q|Q) - break - ;; - p|P) - if [[ $CURRENT -gt 1 ]]; then - CURRENT=$((CURRENT - 1)) - fi - ;; - s|S) - echo "" - echo "State JSON:" - echo "$STATE" | jq . - echo "" - echo "Press Enter to continue..." - read -r - ;; - ''|' ') - CURRENT=$((CURRENT + 1)) - ;; - *[0-9]*) - if [[ "$INPUT" =~ ^[0-9]+$ ]] && [[ "$INPUT" -ge 1 ]] && [[ "$INPUT" -le $TOTAL ]]; then - CURRENT=$INPUT - else - echo "Invalid entry number (1-$TOTAL)" - sleep 1 - fi - ;; - esac - done -fi diff --git a/crates/atuin-ai/src/commands.rs b/crates/atuin-ai/src/commands.rs deleted file mode 100644 index cdbc8f2d..00000000 --- a/crates/atuin-ai/src/commands.rs +++ /dev/null @@ -1,158 +0,0 @@ -use std::{ - fs, - path::{Path, PathBuf}, -}; - -use atuin_common::shell::Shell; -use clap::{Args, Subcommand}; -use eyre::Result; -use tracing_appender::rolling::{RollingFileAppender, Rotation}; -use tracing_subscriber::{EnvFilter, Layer, fmt, layer::SubscriberExt, util::SubscriberInitExt}; -pub mod init; -pub(crate) mod inline; - -#[derive(Args, Debug)] -pub struct AiArgs { - /// Enable verbose logging - #[arg(short, long, global = true)] - verbose: bool, - - /// Custom API endpoint; defaults to reading from the `ai.endpoint` setting. - #[arg(long, global = true)] - api_endpoint: Option, - - /// Custom API token; defaults to reading from the `ai.api_token` setting. - #[arg(long, global = true)] - api_token: Option, -} - -#[derive(Subcommand, Debug)] -pub enum Commands { - /// Initialize shell integration - Init { - /// Shell to generate integration for; defaults to "auto" - #[arg(value_name = "SHELL", default_value = "auto")] - shell: String, - }, - - /// Inline completion mode with small TUI overlay - Inline { - #[command(flatten)] - args: AiArgs, - - /// Current command line to complete - #[arg(value_name = "COMMAND")] - command: Option, - - /// Use the hook mode - #[arg(long, hide = true)] - hook: bool, - }, -} - -pub async fn run( - command: Commands, - settings: &atuin_client::settings::Settings, -) -> eyre::Result<()> { - match command { - Commands::Init { shell } => init::run(shell).await, - Commands::Inline { - command, - hook, - args, - .. - } => { - if settings.logs.ai_enabled() { - init_logging(settings, args.verbose)?; - } - - inline::run(command, args.api_endpoint, args.api_token, settings, hook).await - } - } -} - -pub(crate) fn detect_shell() -> Option { - Some(Shell::current().to_string()) -} - -/// Initializes logging for the AI commands. -fn init_logging(settings: &atuin_client::settings::Settings, verbose: bool) -> Result<()> { - // ATUIN_LOG env var overrides config file level settings - let env_log_set = std::env::var("ATUIN_LOG").is_ok(); - - // Base filter from env var (or empty if not set) - let base_filter = - EnvFilter::from_env("ATUIN_LOG").add_directive("sqlx_sqlite::regexp=off".parse()?); - - // Use config level unless ATUIN_LOG is set - let filter = if env_log_set { - base_filter - } else { - EnvFilter::default() - .add_directive(settings.logs.ai_level().as_directive().parse()?) - .add_directive("sqlx_sqlite::regexp=off".parse()?) - }; - - let log_dir = PathBuf::from(&settings.logs.dir); - let ai_log_filename = settings.logs.ai.file.clone(); - - // Clean up old log files - cleanup_old_logs(&log_dir, &ai_log_filename, settings.logs.ai_retention()); - - let console_layer = if verbose { - Some( - fmt::layer() - .with_writer(std::io::stderr) - .with_ansi(true) - .with_target(false) - .with_filter(filter.clone()), - ) - } else { - None - }; - - let file_appender = RollingFileAppender::new(Rotation::DAILY, &log_dir, &ai_log_filename); - - let base = tracing_subscriber::registry().with( - fmt::layer() - .with_writer(file_appender) - .with_ansi(false) - .with_filter(filter), - ); - - if let Some(console_layer) = console_layer { - base.with(console_layer).init(); - } else { - base.init(); - }; - - Ok(()) -} - -fn cleanup_old_logs(log_dir: &Path, prefix: &str, retention_days: u64) { - let cutoff = std::time::SystemTime::now() - - std::time::Duration::from_secs(retention_days * 24 * 60 * 60); - - let Ok(entries) = fs::read_dir(log_dir) else { - return; - }; - - for entry in entries.flatten() { - let path = entry.path(); - let Some(name) = path.file_name().and_then(|n| n.to_str()) else { - continue; - }; - - // Match files like "search.log.2024-02-23" or "daemon.log.2024-02-23" - if !name.starts_with(prefix) || name == prefix { - continue; - } - - if let Ok(metadata) = entry.metadata() - && let Ok(modified) = metadata.modified() - && modified < cutoff - { - let _ = fs::remove_file(&path); - } - } -} diff --git a/crates/atuin-ai/src/commands/init.rs b/crates/atuin-ai/src/commands/init.rs deleted file mode 100644 index 1f03f5b1..00000000 --- a/crates/atuin-ai/src/commands/init.rs +++ /dev/null @@ -1,233 +0,0 @@ -use crate::commands::detect_shell; - -pub(crate) async fn run(shell: String) -> eyre::Result<()> { - let integration = match shell.as_str() { - "zsh" => generate_zsh_integration(), - "bash" => generate_bash_integration(), - "fish" => generate_fish_integration(), - "auto" => generate_auto_integration()?, - _ => eyre::bail!("Unsupported shell: {}", shell), - }; - - println!("{}", integration); - Ok(()) -} - -fn generate_auto_integration() -> eyre::Result<&'static str> { - let shell = detect_shell(); - match shell.as_deref() { - Some("zsh") => Ok(generate_zsh_integration()), - Some("bash") => Ok(generate_bash_integration()), - Some("fish") => Ok(generate_fish_integration()), - Some(s) => eyre::bail!("Unsupported shell: {}", s), - None => eyre::bail!("Could not detect shell"), - } -} - -/// Generate the zsh integration function - pure function for easy testing -pub fn generate_zsh_integration() -> &'static str { - r#" -# TUI uses an alternate screen, so no explicit cleanup is needed. -_atuin_ai_cleanup() { - true -} - -# Question mark at start of line - natural language mode. -# Named with 'self-' prefix so bracketed-paste-magic activates it during -# paste, allowing url-quote-magic to escape ? in pasted URLs via self-insert. -self-atuin-ai-question-mark() { - # If buffer is empty or just contains '?', trigger natural language mode - if [[ -z "$BUFFER" || "$BUFFER" == "?" ]]; then - BUFFER="" - local output - output=$(atuin ai inline --hook 3>&1 1>&2 2>&3) - - # Clean up the inline viewport - _atuin_ai_cleanup - - if [[ $output == __atuin_ai_print__:* ]]; then - zle -I - echo "${output#__atuin_ai_print__:}" - elif [[ $output == __atuin_ai_cancel__ ]]; then - zle reset-prompt - elif [[ $output == __atuin_ai_execute__:* ]]; then - RBUFFER="" - LBUFFER=${output#__atuin_ai_execute__:} - zle reset-prompt - zle accept-line - elif [[ $output == __atuin_ai_insert__:* ]]; then - RBUFFER="" - LBUFFER=${output#__atuin_ai_insert__:} - zle reset-prompt - elif [[ -n $output ]]; then - RBUFFER="" - LBUFFER=$output - zle reset-prompt - else - zle reset-prompt - fi - else - zle self-insert - fi -} - -# Set up keybindings -zle -N self-atuin-ai-question-mark -bindkey '?' self-atuin-ai-question-mark # Question mark -"# - .trim() -} - -/// Generate the bash integration function - pure function for easy testing -pub fn generate_bash_integration() -> &'static str { - r#" -# Question mark at start of line - natural language mode -_atuin_ai_question_mark() { - # If buffer is empty or just contains '?', trigger natural language mode - if [[ -z "$READLINE_LINE" || "$READLINE_LINE" == "?" ]]; then - READLINE_LINE="" - READLINE_POINT=0 - - local output - output=$(atuin ai inline --hook 3>&1 1>&2 2>&3) - - if [[ $output == __atuin_ai_print__:* ]]; then - echo "${output#__atuin_ai_print__:}" - READLINE_LINE="" - READLINE_POINT=0 - elif [[ $output == __atuin_ai_cancel__ ]]; then - READLINE_LINE="" - READLINE_POINT=0 - elif [[ $output == __atuin_ai_execute__:* ]]; then - # Execute the command immediately - READLINE_LINE=${output#__atuin_ai_execute__:} - READLINE_POINT=${#READLINE_LINE} - # Note: We can't directly execute in bash bind -x, but we can - # use a workaround by binding to a macro that accepts the line - bind '"\C-x\C-a": accept-line' - bind -x '"\C-x\C-e": _atuin_ai_question_mark' - elif [[ $output == __atuin_ai_insert__:* ]]; then - # Insert the command for editing - READLINE_LINE=${output#__atuin_ai_insert__:} - READLINE_POINT=${#READLINE_LINE} - elif [[ -n $output ]]; then - # Default: insert for editing - READLINE_LINE=$output - READLINE_POINT=${#READLINE_LINE} - fi - else - # Not at empty prompt, just insert the question mark - READLINE_LINE="${READLINE_LINE:0:READLINE_POINT}?${READLINE_LINE:READLINE_POINT}" - ((READLINE_POINT++)) - fi -} - -# Set up keybindings -# Bash requires special handling: we use bind -x for the function, -# but need a two-step approach for execute mode -__atuin_ai_accept_line="" - -_atuin_ai_question_mark_wrapper() { - _atuin_ai_question_mark - if [[ -n "$__atuin_ai_accept_line" ]]; then - __atuin_ai_accept_line="" - fi -} - -bind -x '"?": _atuin_ai_question_mark' -"# - .trim() -} - -/// Generate the fish integration function - pure function for easy testing -pub fn generate_fish_integration() -> &'static str { - r#" -# Question mark at start of line - natural language mode -function _atuin_ai_question_mark - set -l buf (commandline -b) - - # If buffer is empty or just contains '?', trigger natural language mode - if test -z "$buf" -o "$buf" = "?" - commandline -r "" - - # Run atuin ai inline, swapping stdout and stderr - set -l output (atuin ai inline --hook 3>&1 1>&2 2>&3 | string collect) - - if string match --quiet '__atuin_ai_print__:*' "$output" - echo (string replace "__atuin_ai_print__:" "" -- "$output" | string collect) - commandline -f repaint - else if test "$output" = "__atuin_ai_cancel__" - commandline -f repaint - else if string match --quiet '__atuin_ai_execute__:*' "$output" - # Execute the command immediately - set -l cmd (string replace "__atuin_ai_execute__:" "" -- "$output" | string collect) - commandline -r "$cmd" - commandline -f repaint - commandline -f execute - else if string match --quiet '__atuin_ai_insert__:*' "$output" - # Insert the command for editing - set -l cmd (string replace "__atuin_ai_insert__:" "" -- "$output" | string collect) - commandline -r "$cmd" - commandline -f repaint - else if test -n "$output" - # Default: insert for editing - commandline -r "$output" - commandline -f repaint - else - commandline -f repaint - end - else if not contains -- "$fish_key_bindings" fish_vi_key_bindings fish_hybrid_key_bindings - # Not at empty prompt, just insert the question mark - commandline -i "?" - end -end - -# Set up keybindings -bind "?" _atuin_ai_question_mark -"# - .trim() -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_generate_zsh_integration() { - let result = generate_zsh_integration(); - assert!(result.contains("self-atuin-ai-question-mark")); - assert!(result.contains("bindkey")); - assert!(result.contains("atuin ai inline --hook")); - assert!(result.contains("__atuin_ai_print__")); - assert!(result.contains("__atuin_ai_cancel__")); - assert!(result.contains("__atuin_ai_execute__")); - assert!(result.contains("__atuin_ai_insert__")); - assert!(result.contains("zle self-insert")); - } - - #[test] - fn test_generate_bash_integration() { - let result = generate_bash_integration(); - assert!(result.contains("_atuin_ai_question_mark")); - assert!(result.contains("bind")); - assert!(result.contains("READLINE_LINE")); - assert!(result.contains("atuin ai inline --hook")); - assert!(result.contains("__atuin_ai_print__")); - assert!(result.contains("__atuin_ai_cancel__")); - assert!(result.contains("__atuin_ai_execute__")); - assert!(result.contains("__atuin_ai_insert__")); - } - - #[test] - fn test_generate_fish_integration() { - let result = generate_fish_integration(); - assert!(result.contains("_atuin_ai_question_mark")); - assert!(result.contains("bind")); - assert!(result.contains("commandline")); - assert!(result.contains("atuin ai inline --hook")); - assert!(result.contains("__atuin_ai_print__")); - assert!(result.contains("__atuin_ai_cancel__")); - assert!(result.contains("__atuin_ai_execute__")); - assert!(result.contains("__atuin_ai_insert__")); - } -} diff --git a/crates/atuin-ai/src/commands/inline.rs b/crates/atuin-ai/src/commands/inline.rs deleted file mode 100644 index 6d1f9c51..00000000 --- a/crates/atuin-ai/src/commands/inline.rs +++ /dev/null @@ -1,587 +0,0 @@ -use std::path::PathBuf; -use std::sync::mpsc; - -use crate::context::{AppContext, ClientContext}; -use crate::driver::{DriverEvent, IoContext, ViewState, run_driver}; -use crate::fsm::AgentFsm; -use crate::fsm::effects::ExitAction; -use crate::session::{LocalSessionService, SessionManager, SessionService}; -use crate::tui::events::AiTuiEvent; -use crate::tui::state::ConversationEvent; -use crate::tui::view::ai_view; -use atuin_client::database::{Database, Sqlite}; -use eye_declare::{Application, CtrlCBehavior}; -use eyre::{Context as _, Result, bail}; -use tracing::{debug, info}; - -pub(crate) async fn run( - initial_command: Option, - api_endpoint: Option, - api_token: Option, - settings: &atuin_client::settings::Settings, - output_for_hook: bool, -) -> Result<()> { - if settings.ai.enabled == Some(false) { - return Ok(()); - } - - if settings.ai.enabled.is_none() { - match prompt_ai_setup()? { - SetupChoice::EnableAi => { - set_ai_enabled(true).await?; - } - SetupChoice::DisableKeybind => { - set_ai_enabled(false).await?; - emit_shell_result(Action::Cancel, output_for_hook); - return Ok(()); - } - SetupChoice::Cancel => { - emit_shell_result(Action::Cancel, output_for_hook); - return Ok(()); - } - } - } - - let endpoint = api_endpoint.as_deref().unwrap_or( - settings - .ai - .endpoint - .as_deref() - .unwrap_or("https://hub.atuin.sh"), - ); - let api_token = api_token.as_deref().or(settings.ai.api_token.as_deref()); - - let token = if let Some(token) = &api_token { - token.to_string() - } else { - ensure_hub_session(settings).await? - }; - - let history_db_path = PathBuf::from(settings.db_path.as_str()); - let history_db = Sqlite::new(history_db_path, settings.local_timeout) - .await - .context("failed to open history database for AI")?; - - // Support both legacy [ai] send_cwd and new [ai.opening] send_cwd - let send_cwd = - settings.ai.opening.send_cwd.unwrap_or(false) || settings.ai.send_cwd.unwrap_or(false); - - let last_command = if settings.ai.opening.send_last_command.unwrap_or(false) { - history_db.last().await.ok().flatten() - } else { - None - }; - - let git_root = std::env::current_dir() - .ok() - .and_then(|cwd| atuin_common::utils::in_git_repo(cwd.to_str()?)); - - let ctx = AppContext { - endpoint: endpoint.to_string(), - token, - send_cwd, - last_command, - history_db: std::sync::Arc::new(history_db), - git_root, - capabilities: settings.ai.capabilities.clone(), - daemon_enabled: settings.daemon.enabled, - }; - - let action = run_inline_tui(ctx, initial_command, settings).await?; - emit_shell_result(action, output_for_hook); - - Ok(()) -} - -async fn ensure_hub_session(settings: &atuin_client::settings::Settings) -> Result { - if let Some(token) = atuin_client::hub::get_session_token().await? { - debug!("Found Hub session, using existing token"); - return Ok(token); - } - - let hub_address = settings.active_hub_endpoint().unwrap_or_default(); - let will_sync = settings.is_hub_sync(); - - info!("No Hub session found, prompting for authentication"); - - println!("Atuin AI requires authenticating with Atuin Hub."); - if will_sync { - println!( - "Once logged in, your shell history will be synchronized via Atuin Hub if auto_sync is enabled or when manually syncing." - ); - } - println!( - "If you have an existing Atuin sync account, you can log in with your existing credentials." - ); - println!("Press enter to begin (or esc to cancel)."); - if !wait_for_login_confirmation()? { - bail!("authentication canceled"); - } - - debug!("Starting Atuin Hub authentication..."); - println!("Authenticating with Atuin Hub..."); - - let session = atuin_client::hub::HubAuthSession::start(hub_address.as_ref()).await?; - println!("Open this URL to continue:"); - println!("{}", session.auth_url); - - let token = session - .wait_for_completion( - atuin_client::hub::DEFAULT_AUTH_TIMEOUT, - atuin_client::hub::DEFAULT_POLL_INTERVAL, - ) - .await?; - - info!("Authentication complete, saving session token"); - atuin_client::hub::save_session(&token).await?; - - if let Ok(meta) = atuin_client::settings::Settings::meta_store().await - && let Ok(Some(cli_token)) = meta.session_token().await - { - debug!("CLI session found, attempting to link accounts"); - if let Err(e) = atuin_client::hub::link_account(hub_address.as_ref(), &cli_token).await { - debug!("Could not link CLI account to Hub: {}", e); - } else { - info!("Successfully linked CLI account to Hub"); - } - } - - Ok(token) -} - -// ─────────────────────────────────────────────────────────────────── - -async fn run_inline_tui( - ctx: AppContext, - initial_prompt: Option, - settings: &atuin_client::settings::Settings, -) -> Result { - let client_ctx = ClientContext::detect(); - - // Open the session service and check for a resumable session - let service = LocalSessionService::open(&settings.ai.db_path, settings.local_timeout) - .await - .context("failed to open AI session database")?; - - let cwd = std::env::current_dir() - .ok() - .map(|p| p.to_string_lossy().into_owned()); - let git_root_str = ctx - .git_root - .as_ref() - .map(|p| p.to_string_lossy().into_owned()); - - let session_window_mins = settings.ai.session_continue_minutes.max(0); // treat negative values as 0 to avoid confusion - let max_age_secs: i64 = session_window_mins * 60; - - let resumable = service - .find_resumable(cwd.as_deref(), git_root_str.as_deref(), max_age_secs) - .await?; - - // ─── Build FSM ─────────────────────────────────────────────── - let (session_mgr, fsm, file_tracker, edit_permissions) = if let Some(stored) = resumable { - debug!(session_id = %stored.id, "resuming AI session"); - let (mgr, mut events, server_sid, last_event_ts, invocation_id) = - SessionManager::resume(Box::new(service), &stored).await?; - - let has_api_content = events.iter().any(|e| e.is_api_content()); - - if has_api_content { - events.push(ConversationEvent::SystemContext { - content: "[Note: The user has started a new invocation of Atuin AI. Prior messages from this session are from an earlier invocation.]".to_string(), - }); - let view_start = events.len(); - let last_time = last_event_ts.and_then(|ts| chrono::DateTime::from_timestamp(ts, 0)); - - let ft = if let Ok(Some(json)) = - mgr.get_metadata(crate::file_tracker::METADATA_KEY).await - && let Ok(tracker) = crate::file_tracker::FileReadTracker::from_json(&json) - { - tracker - } else { - Default::default() - }; - - let ep = if let Ok(Some(json)) = mgr - .get_metadata(crate::edit_permissions::METADATA_KEY) - .await - && let Ok(cache) = crate::edit_permissions::EditPermissionCache::from_json(&json) - { - cache - } else { - Default::default() - }; - - let caps = ctx.capabilities_as_strings(); - let fsm = AgentFsm::from_session( - events, - server_sid, - caps, - invocation_id, - view_start, - true, - last_time, - ); - (mgr, fsm, ft, ep) - } else { - debug!("resumable session has no API-visible content, starting fresh"); - let caps = ctx.capabilities_as_strings(); - let fsm = AgentFsm::new(caps, invocation_id); - (mgr, fsm, Default::default(), Default::default()) - } - } else { - debug!("creating new AI session"); - let mgr = - SessionManager::create_new(Box::new(service), cwd.as_deref(), git_root_str.as_deref()); - let invocation_id = uuid::Uuid::now_v7().to_string(); - let caps = ctx.capabilities_as_strings(); - let fsm = AgentFsm::new(caps, invocation_id); - (mgr, fsm, Default::default(), Default::default()) - }; - - // ─── Snapshot store ───────────────────────────────────────── - let snapshot_dir = atuin_common::utils::data_dir() - .join("ai") - .join("snapshots") - .join(session_mgr.session_id()); - let snapshot_store = crate::snapshots::SnapshotStore::open(snapshot_dir).ok(); - - let in_git_project = ctx.git_root.is_some(); - - // ─── Discover skills ─────────────────────────────────────── - let project_root = ctx - .git_root - .clone() - .or_else(|| std::env::current_dir().ok()); - let skill_registry = crate::skills::SkillRegistry::discover(project_root.as_deref()).await; - - // ─── Build initial ViewState from FSM ─────────────────────── - let initial_view = build_view_state(&fsm, in_git_project, &skill_registry); - - // ─── Build IoContext ──────────────────────────────────────── - let io = IoContext { - app_ctx: ctx.clone(), - client_ctx: client_ctx.clone(), - session_mgr, - file_tracker, - edit_permissions, - snapshot_store, - skill_registry, - }; - - // ─── Channel + Application ────────────────────────────────── - // Components emit DriverEvent::Tui(AiTuiEvent) via a wrapping sender. - // Spawned tasks emit DriverEvent::Fsm(Event) directly. - let (tx, rx) = mpsc::channel::(); - - // Wrap sender for components: they send AiTuiEvent, we wrap it - let tui_tx = DriverEventSender(tx.clone()); - - println!(); - - if let Some(prompt) = initial_prompt { - let _ = tui_tx - .0 - .send(DriverEvent::Tui(AiTuiEvent::SubmitInput(prompt))); - } - - let (mut app, handle) = Application::builder() - .state(initial_view) - .view(ai_view) - .ctrl_c(CtrlCBehavior::Deliver) - .keyboard_protocol(eye_declare::KeyboardProtocol::Enhanced) - .bracketed_paste(true) - .with_context(tui_tx) - .extra_newlines_at_exit(1) - .on_commit(|committed, state| { - if let Some(key) = &committed.key - && let Some(id_str) = key.strip_prefix("turn-") - && let Ok(id) = id_str.parse::() - { - let new_count = id + 1; - if new_count > state.committed_turn_count { - state.committed_turn_count = new_count; - } - } - }) - .build()?; - - // ─── Driver loop ──────────────────────────────────────────── - let h = handle.clone(); - let exiting = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)); - let exiting_clone = exiting.clone(); - let dispatch_handle = tokio::task::spawn_blocking(move || { - run_driver(fsm, io, h, rx, tx, exiting_clone, in_git_project); - }); - - let run_result = app.run_loop().await; - let _ = dispatch_handle.await; - run_result?; - - let result = match app.state().exit_action { - Some(ExitAction::Execute(ref cmd)) => Action::Execute(cmd.clone()), - Some(ExitAction::Insert(ref cmd)) => Action::Insert(cmd.clone()), - _ => Action::Cancel, - }; - - Ok(result) -} - -/// Wrapper around `mpsc::Sender` that components use as context. -/// -/// Components call `tx.send(AiTuiEvent::...)` via eye-declare's context system. -/// This wrapper implements the same interface but wraps events in `DriverEvent::Tui`. -#[derive(Debug, Clone)] -pub(crate) struct DriverEventSender(pub mpsc::Sender); - -impl DriverEventSender { - pub fn send(&self, event: AiTuiEvent) -> Result<(), mpsc::SendError> { - self.0 - .send(DriverEvent::Tui(event)) - .map_err(|_| mpsc::SendError(AiTuiEvent::Exit)) - } -} - -/// Build a ViewState snapshot from FSM state. Used for the initial view -/// and by the driver for ongoing sync. -fn build_view_state( - fsm: &AgentFsm, - in_git_project: bool, - skill_registry: &crate::skills::SkillRegistry, -) -> ViewState { - let safe_start = fsm.ctx.view_start_index.min(fsm.ctx.events.len()); - - let mut slash_registry = crate::tui::slash::SlashCommandRegistry::default(); - let mut skill_names = std::collections::HashSet::new(); - for skill in skill_registry.all() { - slash_registry.register(crate::tui::slash::SlashCommand::new( - &skill.name, - &skill.description, - )); - skill_names.insert(skill.name.clone()); - } - - let tools = fsm.ctx.tools.clone(); - let visible_events = fsm.ctx.events[safe_start..].to_vec(); - let archived_events = fsm.ctx.archived_events.clone(); - - let mut archived_builder = crate::tui::view::turn::TurnBuilder::new(&tools); - for event in &archived_events { - archived_builder.add_event(event); - } - let archived_turns = archived_builder.build(); - let archived_turn_count = archived_turns.len(); - - let mut visible_builder = - crate::tui::view::turn::TurnBuilder::new_starting_at(&tools, archived_turn_count); - for event in &visible_events { - visible_builder.add_event(event); - } - let visible_turns = visible_builder.build(); - - let mut turns = archived_turns; - turns.extend(visible_turns); - - let has_command = visible_events.iter().any(|e| { - matches!(e, ConversationEvent::ToolCall { name, input, .. } - if name == "suggest_command" - && input.get("command").and_then(|v| v.as_str()).is_some()) - }); - - ViewState { - agent_state: fsm.state.clone(), - visible_events, - all_events: fsm.ctx.events.clone(), - session_id: fsm.ctx.session_id.clone(), - tools, - current_response: fsm.ctx.current_response.clone(), - is_resumed: fsm.ctx.is_resumed, - last_event_time: fsm.ctx.last_event_time, - in_git_project, - archived_events, - turns, - has_command, - committed_turn_count: 0, - archived_turn_count, - is_input_blank: true, - slash_command_input: None, - slash_command_search_results: Vec::new(), - exit_action: None, - slash_registry, - skill_names, - } -} - -// ─────────────────────────────────────────────────────────────────── -// Helpers -// ─────────────────────────────────────────────────────────────────── - -enum SetupChoice { - EnableAi, - DisableKeybind, - Cancel, -} - -fn prompt_ai_setup() -> Result { - use crossterm::{ - cursor, - event::{self, Event, KeyCode}, - terminal, - }; - - let options = ["Enable Atuin AI", "Disable ? Keybind", "Cancel"]; - let mut selected: usize = 0; - let mut stdout = std::io::stdout(); - - // Print header before raw mode so newlines render correctly. - // Use stdout because the shell hook swaps stdout/stderr — stdout goes - // to the terminal in both hook and non-hook modes. - println!(); - println!(" Atuin AI is not yet configured."); - println!(); - - terminal::enable_raw_mode().context("failed to enable raw mode")?; - struct Guard; - impl Drop for Guard { - fn drop(&mut self) { - let _ = terminal::disable_raw_mode(); - } - } - let _guard = Guard; - - crossterm::execute!(stdout, cursor::Hide)?; - - loop { - render_setup_options(&mut stdout, &options, selected)?; - - let ev = event::read().context("failed to read key event")?; - - crossterm::execute!(stdout, cursor::MoveUp(options.len() as u16))?; - - if let Event::Key(key) = ev { - match key.code { - KeyCode::Up | KeyCode::Char('k') => { - selected = selected.saturating_sub(1); - } - KeyCode::Down | KeyCode::Char('j') if selected < options.len() - 1 => { - selected += 1; - } - KeyCode::Enter => break, - KeyCode::Esc => { - selected = 2; - break; - } - _ => {} - } - } - } - - // Final render with selection visible - render_setup_options(&mut stdout, &options, selected)?; - crossterm::execute!(stdout, cursor::Show)?; - - Ok(match selected { - 0 => SetupChoice::EnableAi, - 1 => SetupChoice::DisableKeybind, - _ => SetupChoice::Cancel, - }) -} - -fn render_setup_options( - w: &mut impl std::io::Write, - options: &[&str], - selected: usize, -) -> Result<()> { - use crossterm::{ - style::Stylize, - terminal::{Clear, ClearType}, - }; - - for (i, option) in options.iter().enumerate() { - if i == selected { - write!(w, "\r {}", format!("> {option}").bold().cyan())?; - } else { - write!(w, "\r {option}")?; - } - crossterm::execute!(w, Clear(ClearType::UntilNewLine))?; - write!(w, "\r\n")?; - } - w.flush()?; - Ok(()) -} - -async fn set_ai_enabled(enabled: bool) -> Result<()> { - let config_file = atuin_client::settings::Settings::get_config_path()?; - let config_str = tokio::fs::read_to_string(&config_file).await?; - let mut doc = config_str.parse::()?; - - if !doc.contains_key("ai") { - doc["ai"] = toml_edit::table(); - } - doc["ai"]["enabled"] = toml_edit::value(enabled); - - tokio::fs::write(&config_file, doc.to_string()).await?; - - if !enabled { - println!( - "Atuin AI keybind disabled. You can re-enable with `atuin config set ai.enabled true`.", - ); - println!("Restart your shell for changes to take effect."); - // Two printlns to ensure the message is visible above the shell prompt after program ends. - println!(); - println!(); - } - - Ok(()) -} - -fn wait_for_login_confirmation() -> Result { - use crossterm::{ - event::{self, Event, KeyCode}, - terminal::{disable_raw_mode, enable_raw_mode}, - }; - - enable_raw_mode().context("failed enabling raw mode for login prompt")?; - struct Guard; - impl Drop for Guard { - fn drop(&mut self) { - let _ = disable_raw_mode(); - } - } - let _guard = Guard; - - loop { - let ev = event::read().context("failed to read login confirmation key")?; - if let Event::Key(key) = ev { - match key.code { - KeyCode::Enter => return Ok(true), - KeyCode::Esc => return Ok(false), - _ => {} - } - } - } -} - -#[derive(Clone)] -enum Action { - Execute(String), - Insert(String), - Cancel, -} - -fn emit_shell_result(action: Action, output_for_hook: bool) { - if output_for_hook { - match action { - Action::Execute(output) => eprintln!("__atuin_ai_execute__:{output}"), - Action::Insert(output) => eprintln!("__atuin_ai_insert__:{output}"), - Action::Cancel => eprintln!("__atuin_ai_cancel__"), - } - } else { - match action { - Action::Execute(output) | Action::Insert(output) => { - println!("{output}"); - } - Action::Cancel => {} - } - } -} diff --git a/crates/atuin-ai/src/context.rs b/crates/atuin-ai/src/context.rs deleted file mode 100644 index f891a9fc..00000000 --- a/crates/atuin-ai/src/context.rs +++ /dev/null @@ -1,121 +0,0 @@ -use std::path::PathBuf; -use std::sync::Arc; - -use atuin_client::distro::detect_linux_distribution; -use atuin_client::history::History; -use atuin_client::settings::AiCapabilities; - -/// Session-scoped context for the AI chat session. -/// Holds the API configuration and client settings needed by the event loop and stream task. -#[derive(Clone, Debug)] -pub(crate) struct AppContext { - pub endpoint: String, - pub token: String, - pub send_cwd: bool, - pub last_command: Option, - pub history_db: Arc, - /// Git root of the current working directory, if inside a git repo. - /// Resolves through worktrees to the main repo root. - pub git_root: Option, - pub capabilities: AiCapabilities, - pub daemon_enabled: bool, -} - -pub(crate) fn history_output_capability_available(daemon_enabled: bool) -> bool { - cfg!(feature = "daemon") && daemon_enabled -} - -impl AppContext { - pub(crate) fn capabilities_as_strings(&self) -> Vec { - let mut caps = vec!["client_invocations".to_string()]; - if self.capabilities.enable_history_search.unwrap_or(true) { - caps.push("client_v1_atuin_history".to_string()); - } - if self.capabilities.enable_file_tools.unwrap_or(true) { - caps.push("client_v1_read_file".to_string()); - caps.push("client_v1_edit_file".to_string()); - caps.push("client_v1_write_file".to_string()); - } - if self.capabilities.enable_command_execution.unwrap_or(true) { - caps.push("client_v1_execute_shell_command".to_string()); - } - if history_output_capability_available(self.daemon_enabled) - && self.capabilities.enable_history_output.unwrap_or(true) - { - caps.push("client_v1_atuin_output".to_string()); - } - caps.push("client_v1_load_skill".to_string()); - if let Ok(extra) = std::env::var("ATUIN_AI__ADDITIONAL_CAPS") { - caps.extend( - extra - .split(',') - .map(|s| s.trim().to_string()) - .filter(|s| !s.is_empty()), - ); - } - caps - } -} - -/// Machine identity — computed once per session. -#[derive(Clone, Debug)] -pub(crate) struct ClientContext { - pub os: String, - pub shell: Option, - pub distro: Option, -} - -impl ClientContext { - pub(crate) fn detect() -> Self { - let os = detect_os(); - let shell = crate::commands::detect_shell(); - let distro = if os == "linux" { - Some(detect_linux_distribution()) - } else { - None - }; - Self { os, shell, distro } - } - - /// Serialize to the JSON format the API expects for the "context" field. - /// The `pwd` field is always dynamic (current working directory), so it's - /// computed fresh on each call if `send_cwd` is true. - pub(crate) fn to_json( - &self, - send_cwd: bool, - last_command: Option<&History>, - ) -> serde_json::Value { - let mut ctx = serde_json::json!({ - "os": self.os, - "shell": self.shell, - "pwd": if send_cwd { - std::env::current_dir().ok().map(|p| p.to_string_lossy().into_owned()) - } else { - None - }, - }); - - if let Some(history) = last_command { - ctx["last_command"] = serde_json::json!(crate::history_format::format_last_command( - history, - crate::history_format::current_local_offset(), - )); - } - - if let Some(ref distro) = self.distro { - ctx["distro"] = serde_json::json!(distro); - } - - ctx - } -} - -/// Move the `detect_os` function here since it's about client identity. -fn detect_os() -> String { - match std::env::consts::OS { - "macos" => "macos".to_string(), - "linux" => "linux".to_string(), - "windows" => "windows".to_string(), - other => format!("Other: {other}"), - } -} diff --git a/crates/atuin-ai/src/context_window.rs b/crates/atuin-ai/src/context_window.rs deleted file mode 100644 index dcef05aa..00000000 --- a/crates/atuin-ai/src/context_window.rs +++ /dev/null @@ -1,578 +0,0 @@ -//! Context window management for API requests. -//! -//! Full conversation events are always persisted to disk. This module handles -//! truncation at send time so the API payload stays within a character budget. -//! -//! Strategy: **frozen prefix + live tail**. The first N turns form a stable -//! prefix that stays identical across requests (maximizing prompt cache hits). -//! The most recent turns form the live tail. When the total exceeds the budget, -//! turns between prefix and tail are dropped with a truncation marker. The -//! prefix never shifts, avoiding cache invalidation. - -use std::ops::Range; - -use crate::tui::{ConversationEvent, events_to_messages}; - -/// Default character budget for the context window. -/// Roughly ~50K tokens at ~4 chars/token — generous enough that truncation -/// only kicks in for genuinely long sessions. -const DEFAULT_BUDGET_CHARS: usize = 200_000; - -/// Number of initial turns to freeze as the stable prefix. -const FROZEN_PREFIX_TURNS: usize = 1; - -/// Builds API messages from conversation events while respecting a character -/// budget using frozen prefix + live tail truncation. -pub(crate) struct ContextWindowBuilder { - budget: usize, -} - -impl ContextWindowBuilder { - pub fn new(budget: usize) -> Self { - Self { budget } - } - - pub fn with_default_budget() -> Self { - Self::new(DEFAULT_BUDGET_CHARS) - } - - /// Build API messages from conversation events, applying the context - /// window budget. Returns the messages to send in the API request. - pub fn build(&self, events: &[ConversationEvent]) -> Vec { - if events.is_empty() { - return Vec::new(); - } - - let turns = group_into_turns(events); - - // Convert each turn's events to API messages independently. - // This is safe because the combining logic (Text + ToolCall merging) - // only operates within a single assistant response, which never - // spans turn boundaries. - let turn_messages: Vec> = turns - .iter() - .map(|range| events_to_messages(&events[range.clone()])) - .collect(); - - let turn_chars: Vec = turn_messages.iter().map(|m| estimate_chars(m)).collect(); - let total_chars: usize = turn_chars.iter().sum(); - - if total_chars <= self.budget { - return turn_messages.into_iter().flatten().collect(); - } - - // --- Over budget: apply frozen prefix + live tail --- - - let prefix_count = FROZEN_PREFIX_TURNS.min(turns.len()); - let prefix_chars: usize = turn_chars[..prefix_count].iter().sum(); - - let marker = truncation_marker(); - let marker_chars = estimate_chars(std::slice::from_ref(&marker)); - - let mut remaining = self.budget.saturating_sub(prefix_chars + marker_chars); - - // Work backwards from the end, accumulating tail turns that fit. - let mut tail_start = turns.len(); - for i in (prefix_count..turns.len()).rev() { - if turn_chars[i] <= remaining { - remaining -= turn_chars[i]; - tail_start = i; - } else { - break; - } - } - - // Always include at least the most recent turn, even if it alone - // exceeds the budget — sending something is better than nothing. - if tail_start >= turns.len() && turns.len() > prefix_count { - tail_start = turns.len() - 1; - } - - let mut result = Vec::new(); - - // Frozen prefix - for msgs in &turn_messages[..prefix_count] { - result.extend(msgs.iter().cloned()); - } - - // Truncation marker (only if turns were actually dropped) - if tail_start > prefix_count { - result.push(marker); - } - - // Live tail - for msgs in &turn_messages[tail_start..] { - result.extend(msgs.iter().cloned()); - } - - result - } -} - -/// Marker message inserted where turns were dropped. Uses user role since -/// the preceding prefix typically ends with an assistant message. -fn truncation_marker() -> serde_json::Value { - serde_json::json!({ - "role": "user", - "content": "[Earlier conversation context was omitted to fit within the context window. The conversation continues below.]" - }) -} - -/// Group conversation events into turns. A new turn starts at each -/// `UserMessage` or `SystemContext` event. Everything between boundaries -/// belongs to the preceding turn (assistant text, tool calls, tool results, -/// out-of-band output). -fn group_into_turns(events: &[ConversationEvent]) -> Vec> { - let mut turns = Vec::new(); - let mut start = 0; - - for (i, event) in events.iter().enumerate() { - if i > start - && matches!( - event, - ConversationEvent::UserMessage { .. } | ConversationEvent::SystemContext { .. } - ) - { - turns.push(start..i); - start = i; - } - } - - if start < events.len() { - turns.push(start..events.len()); - } - - turns -} - -/// Rough character-count estimate for a set of messages. Uses the JSON -/// serialization length as a proxy — not exact tokens, but proportional -/// and cheap to compute. -fn estimate_chars(messages: &[serde_json::Value]) -> usize { - messages.iter().map(|m| m.to_string().len()).sum() -} - -#[cfg(test)] -mod tests { - use super::*; - - fn user(content: &str) -> ConversationEvent { - ConversationEvent::UserMessage { - content: content.to_string(), - } - } - - fn text(content: &str) -> ConversationEvent { - ConversationEvent::Text { - content: content.to_string(), - } - } - - fn tool_call(id: &str, name: &str) -> ConversationEvent { - ConversationEvent::ToolCall { - id: id.to_string(), - name: name.to_string(), - input: serde_json::json!({"command": "ls"}), - } - } - - fn tool_result(tool_use_id: &str, content: &str) -> ConversationEvent { - ConversationEvent::ToolResult { - tool_use_id: tool_use_id.to_string(), - content: content.to_string(), - is_error: false, - remote: false, - content_length: None, - } - } - - fn system_context(content: &str) -> ConversationEvent { - ConversationEvent::SystemContext { - content: content.to_string(), - } - } - - fn oob(content: &str) -> ConversationEvent { - ConversationEvent::OutOfBandOutput { - name: "test".to_string(), - command: None, - content: content.to_string(), - } - } - - // --- group_into_turns --- - - #[test] - fn empty_events_produce_no_turns() { - assert!(group_into_turns(&[]).is_empty()); - } - - #[test] - fn single_user_message_is_one_turn() { - let events = vec![user("hello")]; - let turns = group_into_turns(&events); - assert_eq!(turns, vec![0..1]); - } - - #[test] - fn user_assistant_is_one_turn() { - let events = vec![user("hello"), text("hi there")]; - let turns = group_into_turns(&events); - assert_eq!(turns, vec![0..2]); - } - - #[test] - fn two_turns_split_at_user_message() { - let events = vec![ - user("first"), - text("response 1"), - user("second"), - text("response 2"), - ]; - let turns = group_into_turns(&events); - assert_eq!(turns, vec![0..2, 2..4]); - } - - #[test] - fn tool_calls_and_results_stay_in_same_turn() { - let events = vec![ - user("list files"), - text("Let me check"), - tool_call("tc1", "suggest_command"), - tool_result("tc1", "file1\nfile2"), - text("Here are your files"), - ]; - let turns = group_into_turns(&events); - assert_eq!(turns, vec![0..5]); - } - - #[test] - fn system_context_starts_new_turn() { - let events = vec![ - user("hello"), - text("hi"), - system_context("invocation boundary"), - user("next question"), - text("answer"), - ]; - let turns = group_into_turns(&events); - assert_eq!(turns, vec![0..2, 2..3, 3..5]); - } - - #[test] - fn oob_events_stay_in_current_turn() { - let events = vec![user("hello"), oob("some output"), text("response")]; - let turns = group_into_turns(&events); - assert_eq!(turns, vec![0..3]); - } - - #[test] - fn leading_text_without_user_message() { - // Edge case: events start with assistant text (shouldn't happen - // normally but handle gracefully) - let events = vec![text("orphaned"), user("hello"), text("hi")]; - let turns = group_into_turns(&events); - assert_eq!(turns, vec![0..1, 1..3]); - } - - // --- ContextWindowBuilder --- - - #[test] - fn empty_events_produce_empty_messages() { - let builder = ContextWindowBuilder::with_default_budget(); - assert!(builder.build(&[]).is_empty()); - } - - #[test] - fn under_budget_returns_all_messages() { - let events = vec![user("hello"), text("hi"), user("how are you"), text("good")]; - let builder = ContextWindowBuilder::with_default_budget(); - let messages = builder.build(&events); - - // Should produce 4 messages (2 user + 2 assistant) - assert_eq!(messages.len(), 4); - assert_eq!(messages[0]["role"], "user"); - assert_eq!(messages[0]["content"], "hello"); - assert_eq!(messages[1]["role"], "assistant"); - assert_eq!(messages[1]["content"], "hi"); - assert_eq!(messages[2]["role"], "user"); - assert_eq!(messages[2]["content"], "how are you"); - assert_eq!(messages[3]["role"], "assistant"); - assert_eq!(messages[3]["content"], "good"); - } - - #[test] - fn over_budget_truncates_middle_turns() { - // Create events where each turn has known content. Use a tiny - // budget so truncation is triggered with just a few turns. - let events = vec![ - user("turn-1-user"), - text("turn-1-assistant"), - user("turn-2-user"), - text("turn-2-assistant"), - user("turn-3-user"), - text("turn-3-assistant"), - user("turn-4-user"), - text("turn-4-assistant-final"), - ]; - - // Calculate sizes to set budget that keeps turn 1 (prefix) + turn 4 (tail) - // but drops turns 2 and 3. - let all_messages = events_to_messages(&events); - let total_chars: usize = all_messages.iter().map(|m| m.to_string().len()).sum(); - - // Set budget to roughly half — enough for prefix + last turn + marker - let turn1_msgs = events_to_messages(&events[0..2]); - let turn4_msgs = events_to_messages(&events[6..8]); - let marker_chars = estimate_chars(std::slice::from_ref(&truncation_marker())); - let needed = estimate_chars(&turn1_msgs) + estimate_chars(&turn4_msgs) + marker_chars; - - // Budget allows prefix + marker + last turn but not the middle turns - assert!( - needed < total_chars, - "test setup: needed ({needed}) should be less than total ({total_chars})" - ); - let builder = ContextWindowBuilder::new(needed + 10); // small margin - - let messages = builder.build(&events); - - // Should have: turn 1 (2 msgs) + marker (1 msg) + turn 4 (2 msgs) = 5 - assert_eq!(messages.len(), 5, "expected prefix + marker + tail"); - assert_eq!(messages[0]["content"], "turn-1-user"); - assert_eq!(messages[1]["content"], "turn-1-assistant"); - assert!( - messages[2]["content"].as_str().unwrap().contains("omitted"), - "middle message should be truncation marker" - ); - assert_eq!(messages[3]["content"], "turn-4-user"); - assert_eq!(messages[4]["content"], "turn-4-assistant-final"); - } - - #[test] - fn very_tight_budget_keeps_prefix_and_last_turn() { - let events = vec![ - user("first"), - text("response-1"), - user("second"), - text("response-2"), - user("third"), - text("response-3"), - ]; - - // Budget of 1 — forces the "always include last turn" fallback - let builder = ContextWindowBuilder::new(1); - let messages = builder.build(&events); - - // Should have prefix (turn 1) + marker + last turn (turn 3) - assert!( - messages.len() >= 3, - "should have at least prefix + marker + tail" - ); - - // First message should be from turn 1 - assert_eq!(messages[0]["content"], "first"); - - // Last messages should be from the final turn - let last = messages.last().unwrap(); - assert_eq!(last["content"], "response-3"); - } - - #[test] - fn single_turn_always_returned() { - let events = vec![user("hello"), text("hi there")]; - - // Even with a tiny budget, the single turn must be returned - let builder = ContextWindowBuilder::new(1); - let messages = builder.build(&events); - assert_eq!(messages.len(), 2); - } - - #[test] - fn tool_calls_preserved_through_truncation() { - let events = vec![ - // Turn 1: simple exchange - user("turn 1"), - text("response 1"), - // Turn 2: with tool calls (will be dropped) - user("turn 2"), - text("checking"), - tool_call("tc1", "suggest_command"), - tool_result("tc1", "output"), - text("done"), - // Turn 3: final turn (kept in tail) - user("turn 3"), - text("final response"), - ]; - - // Budget that fits turn 1 + turn 3 + marker but not turn 2 - let turn1 = events_to_messages(&events[0..2]); - let turn3 = events_to_messages(&events[7..9]); - let marker_cost = estimate_chars(std::slice::from_ref(&truncation_marker())); - let budget = estimate_chars(&turn1) + estimate_chars(&turn3) + marker_cost + 10; - - let builder = ContextWindowBuilder::new(budget); - let messages = builder.build(&events); - - // Verify turn 2 (the tool call turn) was dropped - let has_tool_use = messages.iter().any(|m| { - m["content"] - .as_array() - .is_some_and(|arr| arr.iter().any(|b| b["type"] == "tool_use")) - }); - assert!(!has_tool_use, "tool call turn should have been truncated"); - - // Verify first and last turns present - assert_eq!(messages[0]["content"], "turn 1"); - assert_eq!(messages.last().unwrap()["content"], "final response"); - } - - #[test] - fn tail_accumulates_multiple_turns_when_budget_allows() { - // Use long content so turn sizes dwarf the truncation marker. - let padding = "x".repeat(500); - let events = vec![ - user(&format!("turn-1-user-{padding}")), - text(&format!("turn-1-response-{padding}")), - user(&format!("turn-2-user-{padding}")), - text(&format!("turn-2-response-{padding}")), - user(&format!("turn-3-user-{padding}")), - text(&format!("turn-3-response-{padding}")), - user(&format!("turn-4-user-{padding}")), - text(&format!("turn-4-response-{padding}")), - ]; - - // Budget that fits everything except turn 2 - let all = events_to_messages(&events); - let total = estimate_chars(&all); - let turn2 = events_to_messages(&events[2..4]); - let turn2_chars = estimate_chars(&turn2); - - let marker_cost = estimate_chars(std::slice::from_ref(&truncation_marker())); - let budget = total - turn2_chars + marker_cost + 5; - assert!( - budget < total, - "budget must be less than total for truncation to trigger" - ); - - let builder = ContextWindowBuilder::new(budget); - let messages = builder.build(&events); - - // Should have: prefix (t1: 2 msgs) + marker (1 msg) + t3 (2 msgs) + t4 (2 msgs) = 7 - // (turn 2 dropped) - assert_eq!(messages.len(), 7); - assert!( - messages[0]["content"] - .as_str() - .unwrap() - .starts_with("turn-1-user-") - ); - assert!( - messages[1]["content"] - .as_str() - .unwrap() - .starts_with("turn-1-response-") - ); - assert!(messages[2]["content"].as_str().unwrap().contains("omitted")); - assert!( - messages[3]["content"] - .as_str() - .unwrap() - .starts_with("turn-3-user-") - ); - assert!( - messages[4]["content"] - .as_str() - .unwrap() - .starts_with("turn-3-response-") - ); - assert!( - messages[5]["content"] - .as_str() - .unwrap() - .starts_with("turn-4-user-") - ); - assert!( - messages[6]["content"] - .as_str() - .unwrap() - .starts_with("turn-4-response-") - ); - } - - #[test] - fn no_marker_when_no_turns_dropped() { - // Two turns, both fit in budget - let events = vec![user("a"), text("b"), user("c"), text("d")]; - - let builder = ContextWindowBuilder::with_default_budget(); - let messages = builder.build(&events); - - // No truncation marker - assert_eq!(messages.len(), 4); - assert!( - !messages - .iter() - .any(|m| m["content"].as_str().is_some_and(|s| s.contains("omitted"))) - ); - } - - #[test] - fn tool_use_and_tool_result_never_split() { - // Invariant: a tool_use and its matching tool_result must always - // end up in the same turn, so truncation can't orphan one from - // the other. This test verifies that ToolResult does NOT start - // a new turn boundary. - let padding = "x".repeat(500); - let events = vec![ - // Turn 1 (prefix) - user(&format!("turn-1-{padding}")), - text(&format!("resp-1-{padding}")), - // Turn 2: contains a tool_use → tool_result pair (will be dropped) - user(&format!("turn-2-{padding}")), - text("checking"), - tool_call("tc1", "suggest_command"), - tool_result("tc1", &format!("output-{padding}")), - text(&format!("done-{padding}")), - // Turn 3 (tail) - user(&format!("turn-3-{padding}")), - text(&format!("resp-3-{padding}")), - ]; - - // Budget that fits turn 1 + turn 3 + marker, but not turn 2 - let turn1 = events_to_messages(&events[0..2]); - let turn3 = events_to_messages(&events[7..9]); - let marker_cost = estimate_chars(std::slice::from_ref(&truncation_marker())); - let budget = estimate_chars(&turn1) + estimate_chars(&turn3) + marker_cost + 10; - - let builder = ContextWindowBuilder::new(budget); - let messages = builder.build(&events); - - // Verify: every tool_use has a matching tool_result, and vice versa - let tool_use_ids: Vec<&str> = messages - .iter() - .filter_map(|m| m["content"].as_array()) - .flatten() - .filter(|b| b["type"] == "tool_use") - .filter_map(|b| b["id"].as_str()) - .collect(); - - let tool_result_ids: Vec<&str> = messages - .iter() - .filter_map(|m| m["content"].as_array()) - .flatten() - .filter(|b| b["type"] == "tool_result") - .filter_map(|b| b["tool_use_id"].as_str()) - .collect(); - - assert_eq!( - tool_use_ids, tool_result_ids, - "every tool_use must have a matching tool_result (and vice versa)" - ); - - // Turn 2 was dropped entirely, so no tool IDs should be present - assert!( - !tool_use_ids.contains(&"tc1"), - "dropped turn's tool_use should not appear" - ); - } -} diff --git a/crates/atuin-ai/src/diff.rs b/crates/atuin-ai/src/diff.rs deleted file mode 100644 index e704175c..00000000 --- a/crates/atuin-ai/src/diff.rs +++ /dev/null @@ -1,328 +0,0 @@ -//! Structured diff computation for edit previews. -//! -//! Computes a line-level diff between old and new file content using -//! imara-diff's Histogram algorithm, producing structured hunks with -//! typed lines (Context, Added, Removed) suitable for TUI rendering. - -use imara_diff::{Algorithm, Diff, InternedInput}; - -/// Number of context lines to show around each change. -const CONTEXT_LINES: u32 = 3; - -/// A structured diff preview for a file edit, ready for rendering. -#[derive(Debug, Clone)] -pub(crate) struct EditPreview { - pub hunks: Vec, -} - -/// A contiguous group of diff lines (context + changes). -#[derive(Debug, Clone)] -pub(crate) struct DiffHunk { - /// 1-indexed line number of the first line in this hunk (in the original file). - pub before_start: u32, - /// 1-indexed line number of the first line in this hunk (in the new file). - pub after_start: u32, - pub lines: Vec, -} - -/// A single line in a diff hunk. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum DiffLine { - /// Unchanged line (shown for context). - Context(String), - /// Line added in the new version. - Added(String), - /// Line removed from the old version. - Removed(String), -} - -impl EditPreview { - /// Compute a structured diff between old and new file content. - /// - /// Uses the Histogram algorithm with line-level granularity and - /// indentation-aware postprocessing for readable output. - pub fn compute(old: &str, new: &str) -> Self { - let input = InternedInput::new(old, new); - let mut diff = Diff::compute(Algorithm::Histogram, &input); - diff.postprocess_lines(&input); - - let raw_hunks: Vec<_> = diff.hunks().collect(); - if raw_hunks.is_empty() { - return EditPreview { hunks: Vec::new() }; - } - - // Merge hunks that are within 2*CONTEXT_LINES of each other - // (same logic as unified diff format). - let mut merged_groups: Vec> = Vec::new(); - let mut current_group: Vec<&imara_diff::Hunk> = vec![&raw_hunks[0]]; - - for hunk in &raw_hunks[1..] { - let prev = current_group.last().unwrap(); - if hunk.before.start.saturating_sub(prev.before.end) <= 2 * CONTEXT_LINES { - current_group.push(hunk); - } else { - merged_groups.push(current_group); - current_group = vec![hunk]; - } - } - merged_groups.push(current_group); - - // Build structured hunks from merged groups - let hunks = merged_groups - .into_iter() - .map(|group| build_hunk(&group, &input)) - .collect(); - - EditPreview { hunks } - } - - /// The highest line number (from either file) that will be displayed. - /// Used to calculate gutter width. - pub fn max_line_number(&self) -> u32 { - self.hunks - .iter() - .map(|h| { - let mut before_pos = h.before_start; - let mut after_pos = h.after_start; - for line in &h.lines { - match line { - DiffLine::Context(_) => { - before_pos += 1; - after_pos += 1; - } - DiffLine::Removed(_) => before_pos += 1, - DiffLine::Added(_) => after_pos += 1, - } - } - before_pos.max(after_pos).saturating_sub(1) - }) - .max() - .unwrap_or(0) - } -} - -/// Maximum lines to show in a write preview. -const WRITE_PREVIEW_LINES: usize = 10; - -/// A content preview for a write_file operation. -/// -/// Shows the first N lines of the written content plus a count of -/// remaining lines if truncated. -#[derive(Debug, Clone)] -pub(crate) struct WritePreview { - /// First lines of content (up to WRITE_PREVIEW_LINES). - pub lines: Vec, - /// Total number of lines in the written file. - pub total_lines: usize, -} - -impl WritePreview { - /// Create a preview from file content. - pub fn from_content(content: &str) -> Self { - let all_lines: Vec<&str> = content.lines().collect(); - let total_lines = all_lines.len(); - let lines = all_lines - .into_iter() - .take(WRITE_PREVIEW_LINES) - .map(String::from) - .collect(); - WritePreview { lines, total_lines } - } - - /// Number of lines not shown in the preview. - pub fn remaining_lines(&self) -> usize { - self.total_lines.saturating_sub(self.lines.len()) - } -} - -/// Build a single DiffHunk from a group of adjacent raw hunks. -fn build_hunk(group: &[&imara_diff::Hunk], input: &InternedInput<&str>) -> DiffHunk { - let first = group.first().unwrap(); - let last = group.last().unwrap(); - - let context_start = first.before.start.saturating_sub(CONTEXT_LINES); - let context_end = (last.before.end + CONTEXT_LINES).min(input.before.len() as u32); - - // The after-file position of context_start: same offset as before since - // context before the first change is identical in both files. - let after_context_start = first.after.start - (first.before.start - context_start); - - let mut lines = Vec::new(); - let mut pos = context_start; - - for hunk in group { - // Context lines before this hunk - for i in pos..hunk.before.start { - lines.push(DiffLine::Context(token_text(input, true, i))); - } - - // Removed lines - for i in hunk.before.start..hunk.before.end { - lines.push(DiffLine::Removed(token_text(input, true, i))); - } - - // Added lines - for i in hunk.after.start..hunk.after.end { - lines.push(DiffLine::Added(token_text(input, false, i))); - } - - pos = hunk.before.end; - } - - // Trailing context - for i in pos..context_end { - lines.push(DiffLine::Context(token_text(input, true, i))); - } - - DiffHunk { - before_start: context_start + 1, // 1-indexed - after_start: after_context_start + 1, // 1-indexed - lines, - } -} - -/// Extract the text content of a token, trimming the trailing newline -/// that imara-diff includes in line-based tokenization. -fn token_text(input: &InternedInput<&str>, is_before: bool, idx: u32) -> String { - let tokens = if is_before { - &input.before - } else { - &input.after - }; - let text = input.interner[tokens[idx as usize]]; - text.strip_suffix('\n') - .unwrap_or(text) - .strip_suffix('\r') - .unwrap_or(text.strip_suffix('\n').unwrap_or(text)) - .to_string() -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn no_changes_produces_empty_preview() { - let preview = EditPreview::compute("hello\nworld\n", "hello\nworld\n"); - assert!(preview.hunks.is_empty()); - } - - #[test] - fn single_line_replacement() { - let old = "line1\nline2\nline3\n"; - let new = "line1\nchanged\nline3\n"; - let preview = EditPreview::compute(old, new); - - assert_eq!(preview.hunks.len(), 1); - let hunk = &preview.hunks[0]; - - // Should have: context(line1), removed(line2), added(changed), context(line3) - assert!(hunk.lines.contains(&DiffLine::Context("line1".into()))); - assert!(hunk.lines.contains(&DiffLine::Removed("line2".into()))); - assert!(hunk.lines.contains(&DiffLine::Added("changed".into()))); - assert!(hunk.lines.contains(&DiffLine::Context("line3".into()))); - } - - #[test] - fn addition_only() { - let old = "aaa\nbbb\n"; - let new = "aaa\nnew_line\nbbb\n"; - let preview = EditPreview::compute(old, new); - - assert_eq!(preview.hunks.len(), 1); - let hunk = &preview.hunks[0]; - assert!(hunk.lines.contains(&DiffLine::Added("new_line".into()))); - // Original lines are context - assert!(hunk.lines.contains(&DiffLine::Context("aaa".into()))); - assert!(hunk.lines.contains(&DiffLine::Context("bbb".into()))); - } - - #[test] - fn removal_only() { - let old = "aaa\nremove_me\nbbb\n"; - let new = "aaa\nbbb\n"; - let preview = EditPreview::compute(old, new); - - assert_eq!(preview.hunks.len(), 1); - let hunk = &preview.hunks[0]; - assert!(hunk.lines.contains(&DiffLine::Removed("remove_me".into()))); - } - - #[test] - fn distant_changes_produce_separate_hunks() { - // Two changes separated by more than 2*CONTEXT_LINES (6) lines - let old = "1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n11\n12\n"; - let new = "1\nX\n3\n4\n5\n6\n7\n8\n9\n10\n11\nY\n"; - let preview = EditPreview::compute(old, new); - - assert_eq!(preview.hunks.len(), 2); - } - - #[test] - fn close_changes_merge_into_one_hunk() { - // Two changes separated by fewer than 2*CONTEXT_LINES lines - let old = "1\n2\n3\n4\n5\n"; - let new = "X\n2\n3\n4\nY\n"; - let preview = EditPreview::compute(old, new); - - assert_eq!(preview.hunks.len(), 1); - } - - #[test] - fn context_is_limited() { - // With CONTEXT_LINES=3, a change at line 10 shouldn't include line 1 - let mut old_lines: Vec<&str> = (1..=20).map(|_| "unchanged").collect(); - old_lines[9] = "target"; - let old = old_lines.join("\n") + "\n"; - let new = old.replace("target", "replaced"); - - let preview = EditPreview::compute(&old, &new); - assert_eq!(preview.hunks.len(), 1); - - // Should have at most 3 context lines before + 3 after + 1 removed + 1 added = 8 lines - assert!(preview.hunks[0].lines.len() <= 8); - } - - #[test] - fn max_line_number_reflects_file_position() { - let old = "a\nb\nc\n"; - let new = "a\nX\nc\n"; - let preview = EditPreview::compute(old, new); - // 3-line file, context + removed lines span positions 1-3 - assert_eq!(preview.max_line_number(), 3); - } - - #[test] - fn start_line_is_correct_for_later_changes() { - // Change at line 10 with 3 context lines → before_start = 7 - let mut lines: Vec = (1..=15).map(|i| format!("line{i}")).collect(); - let old = lines.join("\n") + "\n"; - lines[9] = "CHANGED".to_string(); - let new = lines.join("\n") + "\n"; - - let preview = EditPreview::compute(&old, &new); - assert_eq!(preview.hunks.len(), 1); - assert_eq!(preview.hunks[0].before_start, 7); // line 10 - 3 context = line 7 - assert_eq!(preview.hunks[0].after_start, 7); // same for a simple replacement - } - - #[test] - fn multiline_replacement() { - let old = "[section]\nkey1 = old1\nkey2 = old2\n[other]\n"; - let new = "[section]\nkey1 = new1\nkey2 = new2\n[other]\n"; - let preview = EditPreview::compute(old, new); - - assert_eq!(preview.hunks.len(), 1); - let hunk = &preview.hunks[0]; - assert!( - hunk.lines - .contains(&DiffLine::Removed("key1 = old1".into())) - ); - assert!( - hunk.lines - .contains(&DiffLine::Removed("key2 = old2".into())) - ); - assert!(hunk.lines.contains(&DiffLine::Added("key1 = new1".into()))); - assert!(hunk.lines.contains(&DiffLine::Added("key2 = new2".into()))); - } -} diff --git a/crates/atuin-ai/src/driver.rs b/crates/atuin-ai/src/driver.rs deleted file mode 100644 index 82d666ef..00000000 --- a/crates/atuin-ai/src/driver.rs +++ /dev/null @@ -1,1030 +0,0 @@ -//! Driver loop for the agent FSM. -//! -//! Receives events from the channel, calls `fsm.handle()`, syncs ViewState -//! to the Handle, and executes effects (spawning async tasks for IO). -//! -//! The driver runs on a blocking thread (`spawn_blocking`) so it can call -//! `blocking_recv()` on the Handle and `block_on()` for async persistence. - -use std::path::PathBuf; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::mpsc; - -use eye_declare::Handle; - -use crate::context::{AppContext, ClientContext}; -use crate::edit_permissions::EditPermissionCache; -use crate::file_tracker::FileReadTracker; -use crate::fsm::effects::{Effect, ExitAction, PermissionTarget}; -use crate::fsm::events::{Event, PermissionChoice, PermissionResponse}; -use crate::fsm::tools::ToolPreviewData; -use crate::fsm::{AgentFsm, AgentState}; -use crate::permissions::resolver::PermissionResolver; -use crate::permissions::writer; -use crate::session::SessionManager; -use crate::stream::ChatRequest; -use crate::tools::ClientToolCall; -use crate::tui::events::{AiTuiEvent, PermissionResult}; -use crate::tui::state::ConversationEvent; -use crate::tui::view::turn; - -// ============================================================================ -// Driver event — the unified channel type -// ============================================================================ - -/// Events processed by the driver loop. -/// -/// Components emit `Tui` variants via the channel. Spawned async tasks -/// (stream, tool execution) emit `Fsm` variants directly. -#[derive(Debug)] -pub(crate) enum DriverEvent { - /// Event from a TUI component (key press, input change, etc.) - Tui(AiTuiEvent), - /// Internal FSM event (from spawned stream/tool tasks) - Fsm(Event), -} - -// ============================================================================ -// IO context (driver-owned, not visible to FSM) -// ============================================================================ - -pub(crate) struct IoContext { - pub app_ctx: AppContext, - pub client_ctx: ClientContext, - pub session_mgr: SessionManager, - pub file_tracker: FileReadTracker, - pub edit_permissions: EditPermissionCache, - pub snapshot_store: Option, - pub skill_registry: crate::skills::SkillRegistry, -} - -// ============================================================================ -// ViewState (Handle payload for the render thread) -// ============================================================================ - -/// State pushed to the Handle for the view/render thread. -/// Synced from the FSM after each transition. -#[derive(Debug)] -pub(crate) struct ViewState { - // ─── From FSM ─────────────────────────────────────────────── - pub agent_state: AgentState, - pub visible_events: Vec, - pub all_events: Vec, - pub session_id: Option, - pub tools: crate::fsm::tools::ToolManager, - pub current_response: String, - - // ─── Session metadata (set once) ──────────────────────────── - pub is_resumed: bool, - pub last_event_time: Option>, - pub in_git_project: bool, - - // ─── View-only ────────────────────────────────────────────── - pub archived_events: Vec, - - // ─── Pre-computed for rendering ──────────────────────────── - pub turns: Vec, - pub has_command: bool, - pub committed_turn_count: usize, - pub archived_turn_count: usize, - - // ─── Ephemeral interaction state ──────────────────────────── - pub is_input_blank: bool, - pub slash_command_input: Option, - pub slash_command_search_results: Vec, - pub exit_action: Option, - pub slash_registry: crate::tui::slash::SlashCommandRegistry, - pub skill_names: std::collections::HashSet, -} - -impl ViewState { - pub fn is_exiting(&self) -> bool { - self.exit_action.is_some() - } - - pub fn is_busy(&self) -> bool { - matches!(self.agent_state, AgentState::Turn { .. }) - } - - pub fn has_confirmation(&self) -> bool { - matches!( - self.agent_state, - AgentState::Idle { - confirmation: Some(_) - } - ) - } - - pub fn is_input_active(&self) -> bool { - matches!(self.agent_state, AgentState::Idle { .. }) && !self.has_confirmation() - } - - pub fn footer_text(&self) -> &'static str { - match &self.agent_state { - AgentState::Idle { confirmation: None } => { - if self.has_command && self.is_input_blank { - "[Enter] Execute suggested command [Tab] Insert Command" - } else { - "[Enter] Send [Shift+Enter] New line [Esc] Exit" - } - } - AgentState::Idle { - confirmation: Some(_), - } => "[Enter] Confirm dangerous command [Esc] Cancel", - AgentState::Turn { .. } => "[Esc] Cancel", - AgentState::Error(_) => "[Enter]/[r] Retry [Esc] Exit", - } - } -} - -// ============================================================================ -// Main driver loop -// ============================================================================ - -struct DriverContext<'a> { - fsm: &'a mut AgentFsm, - io: &'a mut IoContext, - handle: &'a Handle, - tx: &'a mpsc::Sender, - exiting: &'a Arc, - stream_cancel_tx: &'a mut Option>, - tool_abort_txs: &'a mut std::collections::HashMap>, -} - -/// Main driver loop. Processes events, transitions FSM, syncs view, executes effects. -/// -/// Runs on a blocking thread. Returns when the event channel closes or exit is requested. -/// The Handle already contains the initial ViewState (set by Application::builder). -pub(crate) fn run_driver( - mut fsm: AgentFsm, - mut io: IoContext, - handle: Handle, - rx: mpsc::Receiver, - tx: mpsc::Sender, - exiting: Arc, - in_git_project: bool, -) { - // Dropping the sender cancels the stream (receiver sees Err on changed()). - let mut stream_cancel_tx: Option> = None; - // Per-tool interrupt senders for shell commands. - let mut tool_abort_txs: std::collections::HashMap> = - std::collections::HashMap::new(); - - while let Ok(driver_event) = rx.recv() { - // Log and translate DriverEvent to FSM Event (or handle directly) - let fsm_event = match driver_event { - DriverEvent::Fsm(event) => { - tracing::trace!(?event, state = ?fsm.state, "FSM event"); - Some(event) - } - DriverEvent::Tui(tui_event) => { - tracing::trace!(?tui_event, state = ?fsm.state, "TUI event"); - translate_tui_event(tui_event, &handle) - } - }; - - if let Some(event) = fsm_event { - // Feed event to FSM - let effects = fsm.handle(event); - tracing::trace!(?effects, state = ?fsm.state, "FSM transition"); - - // Sync ViewState to Handle (FSM owns all state now) - sync_view_state(&handle, &fsm, in_git_project); - - // Execute effects (only persist when FSM says to) - for effect in &effects { - if matches!(effect, Effect::Persist) { - persist(&fsm, &mut io); - } - - let ctx = DriverContext { - fsm: &mut fsm, - io: &mut io, - handle: &handle, - tx: &tx, - exiting: &exiting, - stream_cancel_tx: &mut stream_cancel_tx, - tool_abort_txs: &mut tool_abort_txs, - }; - - execute_effect(effect, ctx); - } - - // Final sync after effects — ensures the render thread sees - // the absolute final state even if effects modified anything. - if !effects.is_empty() { - sync_view_state(&handle, &fsm, in_git_project); - } - } - // InputUpdated (the only event that returns None) already pushed - // its view-only changes via handle.update() — no FSM state changed, - // so skip the expensive sync_view_state that clones all events. - - if exiting.load(Ordering::Acquire) { - break; - } - tracing::trace!(state = ?fsm.state, "driver loop iteration complete, waiting for next event"); - } -} - -// ============================================================================ -// TUI event translation -// ============================================================================ - -/// Translate a TUI event into an FSM event. -/// Returns None for events handled directly (e.g. InputUpdated). -fn translate_tui_event(event: AiTuiEvent, handle: &Handle) -> Option { - match event { - AiTuiEvent::SubmitInput(input) => { - // Clear slash state and reset is_input_blank (the InputBox clears - // its text on submit but doesn't fire InputUpdated for the clear). - handle.update(|vs| { - vs.slash_command_input = None; - vs.slash_command_search_results.clear(); - vs.is_input_blank = true; - }); - - let input = input.trim().to_string(); - if input.is_empty() { - Some(Event::ExecuteCommand) - } else if input == "/new" { - Some(Event::NewSession) - } else if input.starts_with('/') { - if let Some((skill_name, arguments)) = resolve_skill_name(&input, handle) { - Some(Event::RequestSkillLoad { - name: skill_name, - arguments, - }) - } else { - let content = resolve_slash_command(&input, handle); - Some(Event::SlashCommand { - command: input, - content, - }) - } - } else { - Some(Event::UserSubmit(input)) - } - } - AiTuiEvent::InputUpdated(text) => { - let is_blank = text.is_empty(); - - // Hot path (every keystroke); uses handle.update_tracked - // to allow read()ing the state without marking it dirty. - handle.update_tracked(move |vs| { - if vs.read().is_input_blank != is_blank { - vs.is_input_blank = is_blank; - } - - if text.starts_with('/') { - let query = text.trim_start_matches('/').to_string(); - let mut results = vs.slash_registry.search_fuzzy(&query); - results.sort_by(|a, b| { - b.relevance - .partial_cmp(&a.relevance) - .unwrap_or(std::cmp::Ordering::Equal) - }); - vs.slash_command_input = Some(query); - vs.slash_command_search_results = results; - } else { - if vs.read().slash_command_input.is_some() { - vs.slash_command_input = None; - } - - if !vs.read().slash_command_search_results.is_empty() { - vs.slash_command_search_results.clear(); - } - } - }); - None - } - AiTuiEvent::CancelGeneration => Some(Event::Cancel), - AiTuiEvent::ExecuteCommand => Some(Event::ExecuteCommand), - AiTuiEvent::InsertCommand => Some(Event::InsertCommand), - AiTuiEvent::CancelConfirmation => Some(Event::Cancel), - AiTuiEvent::InterruptToolExecution => Some(Event::InterruptTools), - AiTuiEvent::Retry => Some(Event::Retry), - AiTuiEvent::Exit => Some(Event::Cancel), - AiTuiEvent::SelectPermission(result) => { - let tool_id = handle - .fetch(|vs| vs.tools.awaiting_permission().map(|t| t.id.clone())) - .blocking_recv() - .ok() - .flatten(); - - let tool_id = tool_id?; - - let choice = match result { - PermissionResult::Allow => PermissionChoice::Allow, - PermissionResult::AllowFileForSession => PermissionChoice::AllowForSession, - PermissionResult::AlwaysAllowInDir => PermissionChoice::AlwaysAllowInProject, - PermissionResult::AlwaysAllow => PermissionChoice::AlwaysAllow, - PermissionResult::Deny => PermissionChoice::Deny, - }; - Some(Event::PermissionUserChoice { tool_id, choice }) - } - AiTuiEvent::SlashCommand(cmd) => { - if let Some((skill_name, arguments)) = resolve_skill_name(&cmd, handle) { - Some(Event::RequestSkillLoad { - name: skill_name, - arguments, - }) - } else { - let content = resolve_slash_command(&cmd, handle); - Some(Event::SlashCommand { - command: cmd, - content, - }) - } - } - } -} - -/// Resolve a slash command to its output content. -/// If the input starts with `/`, check whether the command name matches a -/// registered skill. Returns `Some((skill_name, arguments))` if it does. -fn resolve_skill_name(input: &str, handle: &Handle) -> Option<(String, Option)> { - let after_slash = input.trim_start_matches('/'); - let cmd_name = after_slash.split_whitespace().next()?.to_string(); - - let is_skill = handle - .fetch({ - let cmd_name = cmd_name.clone(); - move |vs| vs.skill_names.contains(&cmd_name) - }) - .blocking_recv() - .unwrap_or(false); - - if !is_skill { - return None; - } - - let args = after_slash - .strip_prefix(&cmd_name) - .map(|s| s.trim()) - .filter(|s| !s.is_empty()) - .map(|s| s.to_string()); - - Some((cmd_name, args)) -} - -fn resolve_slash_command(command: &str, handle: &Handle) -> String { - match command.trim() { - "/help" => { - let commands = handle - .fetch(|vs| { - vs.slash_registry - .get_commands() - .iter() - .map(|cmd| format!("- `/{}` — {}", cmd.name, cmd.description)) - .collect::>() - .join("\n") - }) - .blocking_recv() - .unwrap_or_default(); - include_str!("tui/content/help.md").replace("{commands}", &commands) - } - _ => format!("Unknown command: {command}"), - } -} - -// ============================================================================ -// ViewState sync -// ============================================================================ - -fn sync_view_state(handle: &Handle, fsm: &AgentFsm, in_git_project: bool) { - let state = fsm.state.clone(); - let safe_start = fsm.ctx.view_start_index.min(fsm.ctx.events.len()); - let mut visible_events = fsm.ctx.events[safe_start..].to_vec(); - let all_events = fsm.ctx.events.clone(); - let tools = fsm.ctx.tools.clone(); - let current_response = fsm.ctx.current_response.clone(); - let session_id = fsm.ctx.session_id.clone(); - let is_resumed = fsm.ctx.is_resumed; - let last_event_time = fsm.ctx.last_event_time; - let archived_events = fsm.ctx.archived_events.clone(); - - // Inject streaming text as a synthetic event for live rendering. - // The FSM commits text to events on stream end; this makes it visible during streaming. - let trimmed = current_response.trim_start(); - if !trimmed.is_empty() { - visible_events.push(ConversationEvent::Text { - content: trimmed.to_string(), - }); - } - - // Pre-compute turns and has_command on the driver thread so the - // render-thread view function doesn't redo O(n) work every frame. - let mut archived_builder = turn::TurnBuilder::new(&tools); - for event in &archived_events { - archived_builder.add_event(event); - } - let archived_turns = archived_builder.build(); - let archived_turn_count = archived_turns.len(); - - let mut visible_builder = turn::TurnBuilder::new_starting_at(&tools, archived_turn_count); - for event in &visible_events { - visible_builder.add_event(event); - } - let visible_turns = visible_builder.build(); - - let mut turns = archived_turns; - turns.extend(visible_turns); - - let has_command = visible_events.iter().any(|e| { - if let ConversationEvent::ToolCall { name, input, .. } = e { - name == "suggest_command" && input.get("command").and_then(|v| v.as_str()).is_some() - } else { - false - } - }); - - tracing::trace!(?state, "sync_view_state pushing to handle"); - handle.update(move |vs| { - vs.agent_state = state; - vs.visible_events = visible_events; - vs.all_events = all_events; - vs.tools = tools; - vs.current_response = current_response; - vs.session_id = session_id; - vs.is_resumed = is_resumed; - vs.last_event_time = last_event_time; - vs.in_git_project = in_git_project; - vs.archived_events = archived_events; - vs.turns = turns; - vs.has_command = has_command; - vs.archived_turn_count = archived_turn_count; - }); -} - -// ============================================================================ -// Effect execution -// ============================================================================ - -fn execute_effect(effect: &Effect, ctx: DriverContext) { - let DriverContext { - fsm, - io, - handle, - tx, - exiting, - stream_cancel_tx, - tool_abort_txs, - } = ctx; - - match effect { - Effect::StartStream { - messages, - session_id, - } => { - // Cancel any existing stream before starting a new one - stream_cancel_tx.take(); - - let (cancel_tx, cancel_rx) = tokio::sync::watch::channel(()); - *stream_cancel_tx = Some(cancel_tx); - - let tx = tx.clone(); - let app = io.app_ctx.clone(); - let cc = io.client_ctx.clone(); - let (skill_summaries, skill_overflow) = io.skill_registry.server_skills(); - let request = ChatRequest::new( - messages.clone(), - session_id.clone(), - &app.capabilities, - app.daemon_enabled, - fsm.ctx.invocation_id.clone(), - ); - tokio::spawn(async move { - run_stream_bridge( - request, - app, - cc, - tx, - cancel_rx, - skill_summaries, - skill_overflow, - ) - .await; - }); - } - - Effect::AbortStream => { - // Drop the sender — the bridge's cancel_rx.changed() will error, - // breaking the stream loop and dropping the HTTP connection. - stream_cancel_tx.take(); - } - - Effect::CheckPermission { tool_id, tool } => { - let tool_id = tool_id.clone(); - let tool = tool.clone(); - let tx = tx.clone(); - - // Auto-approved tools (e.g. load_skill) bypass permission checks entirely - if tool.is_auto_approved() { - let _ = tx.send(DriverEvent::Fsm(Event::PermissionResolved { - tool_id, - response: PermissionResponse::Allowed, - })); - return; - } - - let working_dir = tool - .target_dir() - .map(|p| p.to_path_buf()) - .or_else(|| std::env::current_dir().ok()) - .unwrap_or_else(|| PathBuf::from(".")); - - // Check session grants first (synchronous) - if let Some(resolved) = tool.resolved_file_path() - && io.edit_permissions.has_valid_grant(&resolved) - { - let _ = tx.send(DriverEvent::Fsm(Event::PermissionResolved { - tool_id, - response: PermissionResponse::SessionGranted, - })); - return; - } - - tokio::spawn(async move { - let response = match PermissionResolver::new(working_dir).await { - Ok(resolver) => match resolver.check(&tool).await { - Ok(crate::permissions::check::PermissionResponse::Allowed) => { - PermissionResponse::Allowed - } - Ok(crate::permissions::check::PermissionResponse::Denied) => { - PermissionResponse::Denied - } - Ok(crate::permissions::check::PermissionResponse::Ask) => { - PermissionResponse::Ask - } - Err(_) => PermissionResponse::Ask, - }, - Err(_) => PermissionResponse::Ask, - }; - let _ = tx.send(DriverEvent::Fsm(Event::PermissionResolved { - tool_id, - response, - })); - }); - } - - Effect::ExecuteTool { tool_id, tool } => { - let tool_id = tool_id.clone(); - let tx = tx.clone(); - let db = io.app_ctx.history_db.clone(); - - match &tool { - ClientToolCall::Shell(shell_call) => { - let shell_call = shell_call.clone(); - let tx_preview = tx.clone(); - let tool_id_for_preview = tool_id.clone(); - - // Create interrupt channel and store the sender for AbortTool - let (interrupt_tx, interrupt_rx) = tokio::sync::oneshot::channel(); - tool_abort_txs.insert(tool_id.clone(), interrupt_tx); - - tokio::spawn(async move { - let (output_tx, mut output_rx) = - tokio::sync::mpsc::channel::>(16); - - let preview_id = tool_id_for_preview; - let tx_fwd = tx_preview; - tokio::spawn(async move { - while let Some(lines) = output_rx.recv().await { - let _ = tx_fwd.send(DriverEvent::Fsm(Event::ToolPreviewUpdate { - tool_id: preview_id.clone(), - lines, - exit_code: None, - })); - } - }); - - let outcome = crate::tools::execute_shell_command_streaming( - &shell_call, - output_tx, - interrupt_rx, - ) - .await; - - let preview = - if let crate::tools::ToolOutcome::Structured { exit_code, .. } = - &outcome - { - Some(ToolPreviewData::Shell { - lines: vec![], - exit_code: *exit_code, - // Reason is set by the FSM in handle_tool_done - // based on whether it was a user interrupt or timeout. - interrupted: None, - }) - } else { - None - }; - - let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone { - tool_id, - outcome, - preview, - })); - }); - } - ClientToolCall::Edit(edit_call) => { - let resolved = edit_call.resolved_path(); - - // Capture old content for snapshot + diff preview - let old_content = std::fs::read(&resolved).ok(); - if let Some(ref content) = old_content - && let Some(ref mut store) = io.snapshot_store - && let Err(e) = store.ensure_snapshot(&resolved, content) - { - tracing::warn!("Failed to snapshot before edit: {e}"); - } - - // Edit is fast (file read + string replace + write) — run inline - let (outcome, new_content) = edit_call.execute(&resolved, &io.file_tracker); - - // Update file tracker with new content - if let Some(new_bytes) = &new_content - && let Ok(mtime) = std::fs::metadata(&resolved).and_then(|m| m.modified()) - { - io.file_tracker - .update_after_edit(&resolved, new_bytes, mtime); - } - - // Compute diff preview - let preview = match (&old_content, &new_content) { - (Some(old_bytes), Some(new_bytes)) => { - let old_str = String::from_utf8_lossy(old_bytes); - let new_str = String::from_utf8_lossy(new_bytes); - let diff = crate::diff::EditPreview::compute(&old_str, &new_str); - if diff.hunks.is_empty() { - None - } else { - Some(ToolPreviewData::Edit(diff)) - } - } - _ => None, - }; - - let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone { - tool_id, - outcome, - preview, - })); - } - ClientToolCall::Write(write_call) => { - let resolved = write_call.resolved_path(); - - // Snapshot existing file before overwriting - if let Ok(content) = std::fs::read(&resolved) - && let Some(ref mut store) = io.snapshot_store - && let Err(e) = store.ensure_snapshot(&resolved, &content) - { - tracing::warn!("Failed to snapshot before write: {e}"); - } - - // Write is fast (atomic file write) — run inline - let (outcome, written_bytes) = write_call.execute(&resolved); - - // Update file tracker with new content - if let Some(new_bytes) = &written_bytes - && let Ok(mtime) = std::fs::metadata(&resolved).and_then(|m| m.modified()) - { - io.file_tracker - .update_after_edit(&resolved, new_bytes, mtime); - } - - let preview = if !outcome.is_error() { - Some(ToolPreviewData::Write( - crate::diff::WritePreview::from_content(&write_call.content), - )) - } else { - None - }; - - let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone { - tool_id, - outcome, - preview, - })); - } - ClientToolCall::Read(read_call) => { - // Read is fast (file read) — run inline so we can update file_tracker - let outcome = read_call.execute(); - - // Track the read for freshness checking on subsequent edits - if !outcome.is_error() { - let resolved = read_call.resolved_path(); - if resolved.is_file() - && let Ok(content) = std::fs::read(&resolved) - && let Ok(mtime) = - std::fs::metadata(&resolved).and_then(|m| m.modified()) - { - io.file_tracker.record_read(resolved, &content, mtime); - } - } - - let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone { - tool_id, - outcome, - preview: None, - })); - } - ClientToolCall::AtuinHistory(tool) => { - // History search needs async DB access - let tool = tool.clone(); - tokio::spawn(async move { - let outcome = tool.execute(&db).await; - let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone { - tool_id, - outcome, - preview: None, - })); - }); - } - ClientToolCall::AtuinOutput(tool) => { - let tool = tool.clone(); - tokio::spawn(async move { - let outcome = tool.execute().await; - let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone { - tool_id, - outcome, - preview: None, - })); - }); - } - ClientToolCall::LoadSkill(skill_call) => { - let skill_name = skill_call.name.clone(); - let registry = io.skill_registry.clone(); - let shell = io - .client_ctx - .shell - .clone() - .unwrap_or_else(|| "sh".to_string()); - - tokio::spawn(async move { - let content = - load_skill_content(®istry, &skill_name, &shell, None).await; - let outcome = crate::tools::ToolOutcome::Success(content); - let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone { - tool_id, - outcome, - preview: None, - })); - }); - } - } - } - - Effect::LoadSkill { name, arguments } => { - let name = name.clone(); - let arguments = arguments.clone(); - let registry = io.skill_registry.clone(); - let shell = io - .client_ctx - .shell - .clone() - .unwrap_or_else(|| "sh".to_string()); - let tx = tx.clone(); - tokio::spawn(async move { - let content = - load_skill_content(®istry, &name, &shell, arguments.as_deref()).await; - let _ = tx.send(DriverEvent::Fsm(Event::SkillLoaded { - name, - arguments, - content, - })); - }); - } - - Effect::AbortTool { tool_id } => { - if let Some(abort_tx) = tool_abort_txs.remove(tool_id) { - let _ = abort_tx.send(()); - } - } - - Effect::Persist => { - // Handled inline in the driver loop (before this function is called). - } - - Effect::WritePermissionRule { - target, - rule, - disposition, - } => { - let file_path = match target { - PermissionTarget::Project => { - let project_root = io - .app_ctx - .git_root - .clone() - .or_else(|| std::env::current_dir().ok()) - .unwrap_or_else(|| PathBuf::from(".")); - writer::project_permissions_path(&project_root) - } - PermissionTarget::Global => writer::global_permissions_path(), - }; - let rule = rule.clone(); - let disposition = disposition.clone(); - tokio::spawn(async move { - if let Err(e) = writer::write_rule(&file_path, &rule, disposition).await { - tracing::error!("Failed to write permission rule: {e}"); - } - }); - } - - Effect::CacheSessionGrant { path } => { - io.edit_permissions.grant(path.clone()); - } - - Effect::ArchiveSession => { - let rt = tokio::runtime::Handle::current(); - if let Err(e) = rt.block_on(io.session_mgr.archive_and_reset()) { - tracing::warn!("Failed to archive session: {e}"); - } - } - - Effect::ScheduleTimeout { - timeout_id, - duration, - kind, - } => { - let timeout_id = *timeout_id; - let duration = *duration; - let kind = kind.clone(); - let tx = tx.clone(); - tokio::spawn(async move { - tokio::time::sleep(duration).await; - use crate::fsm::effects::TimeoutKind; - let event = match kind { - TimeoutKind::Confirmation => Event::ConfirmationTimeout { timeout_id }, - TimeoutKind::ToolExecution { tool_id } => Event::ToolExecutionTimeout { - timeout_id, - tool_id, - }, - }; - let _ = tx.send(DriverEvent::Fsm(event)); - }); - } - - Effect::ExitApp(action) => { - let action = action.clone(); - handle.update(move |vs| { - vs.exit_action = Some(action); - }); - exiting.store(true, Ordering::Release); - let h2 = handle.clone(); - h2.exit(); - } - } -} - -// ============================================================================ -// Persistence -// ============================================================================ - -fn persist(fsm: &AgentFsm, io: &mut IoContext) { - let start = std::time::Instant::now(); - let rt = tokio::runtime::Handle::current(); - - if let Err(e) = rt.block_on(io.session_mgr.persist_events(&fsm.ctx.events)) { - tracing::warn!("Failed to persist session events: {e}"); - } - if let Some(ref sid) = fsm.ctx.session_id - && let Err(e) = rt.block_on(io.session_mgr.persist_server_session_id(sid)) - { - tracing::warn!("Failed to persist server session ID: {e}"); - } - if let Ok(json) = io.file_tracker.to_json() - && let Err(e) = rt.block_on( - io.session_mgr - .set_metadata(crate::file_tracker::METADATA_KEY, &json), - ) - { - tracing::warn!("Failed to persist file tracker: {e}"); - } - if let Ok(json) = io.edit_permissions.to_json() - && let Err(e) = rt.block_on( - io.session_mgr - .set_metadata(crate::edit_permissions::METADATA_KEY, &json), - ) - { - tracing::warn!("Failed to persist edit permissions: {e}"); - } - tracing::trace!(elapsed_ms = start.elapsed().as_millis(), "persist complete"); -} - -// ============================================================================ -// Skill loading -// ============================================================================ - -async fn load_skill_content( - registry: &crate::skills::SkillRegistry, - name: &str, - shell: &str, - arguments: Option<&str>, -) -> String { - match registry.load(name, shell, arguments).await { - Ok(body) => body, - Err(e) => format!("Failed to load skill '{name}': {e}"), - } -} - -// ============================================================================ -// Stream bridge -// ============================================================================ - -async fn run_stream_bridge( - request: ChatRequest, - app_ctx: AppContext, - client_ctx: ClientContext, - tx: mpsc::Sender, - mut cancel_rx: tokio::sync::watch::Receiver<()>, - skill_summaries: Vec, - skill_overflow: Option, -) { - use crate::stream::{StreamContent, StreamControl, StreamFrame, create_chat_stream}; - use futures::StreamExt; - - // Gather user context files (TERMINAL.md) and interpolate commands. - let shell = client_ctx.shell.as_deref().unwrap_or("sh"); - let start_dir = std::env::current_dir().unwrap_or_default(); - let global_ctx_path = crate::user_context::global_context_path(); - let user_contexts = - crate::user_context::gather(&start_dir, Some(&global_ctx_path), shell).await; - - let stream = create_chat_stream( - app_ctx.endpoint.clone(), - app_ctx.token.clone(), - request, - client_ctx, - app_ctx.send_cwd, - app_ctx.last_command.clone(), - user_contexts, - skill_summaries, - skill_overflow, - ); - futures::pin_mut!(stream); - - let _ = tx.send(DriverEvent::Fsm(Event::StreamStarted)); - - loop { - // Select between the next stream frame and cancellation. - // When the driver drops the cancel sender, changed() returns Err - // and we break — dropping the HTTP stream and cancelling the request. - let frame = tokio::select! { - biased; - _ = cancel_rx.changed() => break, - frame = stream.next() => match frame { - Some(frame) => frame, - None => break, - }, - }; - - let event = match frame { - Ok(StreamFrame::Content(content)) => match content { - StreamContent::TextChunk(text) => Some(Event::StreamChunk(text)), - StreamContent::ToolCall { id, name, input } => { - if name == "suggest_command" { - Some(Event::SuggestCommand { id, input }) - } else { - Some(Event::StreamToolCall { id, name, input }) - } - } - StreamContent::ToolResult { - tool_use_id, - content, - is_error, - remote, - content_length, - } => Some(Event::StreamServerToolResult { - tool_use_id, - content, - is_error, - remote, - content_length, - }), - }, - Ok(StreamFrame::Control(control)) => match control { - StreamControl::StatusChanged(status) => Some(Event::StreamStatusChanged(status)), - StreamControl::Done { session_id } => Some(Event::StreamDone { session_id }), - StreamControl::Error(msg) => Some(Event::StreamError(msg)), - }, - Err(e) => Some(Event::StreamError(e.to_string())), - }; - - if let Some(event) = event { - // StreamDone and StreamError are terminal — the server won't send more. - // SuggestCommand is NOT terminal: the server sends StreamDone after it - // with the session_id we need to capture. - let is_terminal = matches!(event, Event::StreamDone { .. } | Event::StreamError(_)); - if tx.send(DriverEvent::Fsm(event)).is_err() { - break; - } - if is_terminal { - break; - } - } - } -} diff --git a/crates/atuin-ai/src/edit_permissions.rs b/crates/atuin-ai/src/edit_permissions.rs deleted file mode 100644 index 5015a007..00000000 --- a/crates/atuin-ai/src/edit_permissions.rs +++ /dev/null @@ -1,108 +0,0 @@ -//! Session-scoped permission cache for file edits. -//! -//! When the user selects "Allow this file for this session", the grant is -//! recorded here with a timestamp. Subsequent edits to the same file skip -//! the permission prompt as long as the grant hasn't expired. -//! -//! Grants are time-limited (1 hour TTL) so they don't outlive the user's -//! attention in long-running sessions. Persisted as JSON in session -//! metadata so they survive across CLI invocations. - -use std::collections::HashMap; -use std::path::{Path, PathBuf}; -use std::time::SystemTime; - -use eyre::Result; -use serde::{Deserialize, Serialize}; - -/// Session metadata key for persistence. -pub(crate) const METADATA_KEY: &str = "edit_permissions"; - -/// How long a session-scoped edit permission remains valid. -const TTL_MS: i64 = 60 * 60 * 1000; // 1 hour - -/// Cache of per-file edit permission grants within a session. -#[derive(Debug, Default, Clone, Serialize, Deserialize)] -pub(crate) struct EditPermissionCache { - /// Maps canonical file paths to the grant timestamp (unix millis). - grants: HashMap, -} - -impl EditPermissionCache { - /// Record a permission grant for a file. - pub fn grant(&mut self, path: PathBuf) { - self.grants.insert(path, now_ms()); - } - - /// Check whether there's a valid (non-expired) grant for a file. - pub fn has_valid_grant(&self, path: &Path) -> bool { - if let Some(&granted_at) = self.grants.get(path) { - (now_ms() - granted_at) < TTL_MS - } else { - false - } - } - - pub fn to_json(&self) -> Result { - Ok(serde_json::to_string(self)?) - } - - pub fn from_json(json: &str) -> Result { - Ok(serde_json::from_str(json)?) - } -} - -fn now_ms() -> i64 { - SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .map(|d| d.as_millis() as i64) - .unwrap_or(0) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn grant_and_check() { - let mut cache = EditPermissionCache::default(); - let path = PathBuf::from("/Users/me/.config/foo.toml"); - - assert!(!cache.has_valid_grant(&path)); - cache.grant(path.clone()); - assert!(cache.has_valid_grant(&path)); - } - - #[test] - fn different_paths_are_independent() { - let mut cache = EditPermissionCache::default(); - let a = PathBuf::from("/etc/hosts"); - let b = PathBuf::from("/etc/resolv.conf"); - - cache.grant(a.clone()); - assert!(cache.has_valid_grant(&a)); - assert!(!cache.has_valid_grant(&b)); - } - - #[test] - fn roundtrip_json() { - let mut cache = EditPermissionCache::default(); - cache.grant(PathBuf::from("/some/file.toml")); - - let json = cache.to_json().unwrap(); - let restored = EditPermissionCache::from_json(&json).unwrap(); - assert!(restored.has_valid_grant(Path::new("/some/file.toml"))); - } - - #[test] - fn expired_grant_is_invalid() { - let mut cache = EditPermissionCache::default(); - let path = PathBuf::from("/expired/file.toml"); - - // Insert a grant from 2 hours ago - let two_hours_ago = now_ms() - (2 * 60 * 60 * 1000); - cache.grants.insert(path.clone(), two_hours_ago); - - assert!(!cache.has_valid_grant(&path)); - } -} diff --git a/crates/atuin-ai/src/event_serde.rs b/crates/atuin-ai/src/event_serde.rs deleted file mode 100644 index e3f9d6f7..00000000 --- a/crates/atuin-ai/src/event_serde.rs +++ /dev/null @@ -1,397 +0,0 @@ -//! Manual serialization for ConversationEvent to/from storage format. -//! -//! The storage format is decoupled from the Rust enum so the two can evolve -//! independently. Each event is stored as an `(event_type, event_data)` pair -//! where `event_data` is a JSON string. - -use eyre::{Result, eyre}; -use serde_json::Value; - -use crate::tui::ConversationEvent; - -/// Serialize a ConversationEvent into an (event_type, event_data_json) pair -/// suitable for database storage. -pub(crate) fn serialize_event(event: &ConversationEvent) -> (String, String) { - match event { - ConversationEvent::UserMessage { content } => ( - "user_message".to_string(), - serde_json::json!({ "content": content }).to_string(), - ), - ConversationEvent::Text { content } => ( - "text".to_string(), - serde_json::json!({ "content": content }).to_string(), - ), - ConversationEvent::ToolCall { id, name, input } => ( - "tool_call".to_string(), - serde_json::json!({ - "id": id, - "name": name, - "input": input, - }) - .to_string(), - ), - ConversationEvent::ToolResult { - tool_use_id, - content, - is_error, - remote, - content_length, - } => ( - "tool_result".to_string(), - serde_json::json!({ - "tool_use_id": tool_use_id, - "content": content, - "is_error": is_error, - "remote": remote, - "content_length": content_length, - }) - .to_string(), - ), - ConversationEvent::OutOfBandOutput { - name, - command, - content, - } => ( - "out_of_band_output".to_string(), - serde_json::json!({ - "name": name, - "command": command, - "content": content, - }) - .to_string(), - ), - ConversationEvent::SystemContext { content } => ( - "system_context".to_string(), - serde_json::json!({ "content": content }).to_string(), - ), - ConversationEvent::SkillInvocation { - name, - arguments, - content, - } => ( - "skill_invocation".to_string(), - serde_json::json!({ - "name": name, - "arguments": arguments, - "content": content, - }) - .to_string(), - ), - } -} - -/// Deserialize an (event_type, event_data_json) pair from storage back into a -/// ConversationEvent. -pub(crate) fn deserialize_event(event_type: &str, event_data: &str) -> Result { - let data: Value = serde_json::from_str(event_data) - .map_err(|e| eyre!("failed to parse event_data JSON: {e}"))?; - - match event_type { - "user_message" => Ok(ConversationEvent::UserMessage { - content: json_string(&data, "content")?, - }), - "text" => Ok(ConversationEvent::Text { - content: json_string(&data, "content")?, - }), - "tool_call" => Ok(ConversationEvent::ToolCall { - id: json_string(&data, "id")?, - name: json_string(&data, "name")?, - input: data - .get("input") - .cloned() - .ok_or_else(|| eyre!("tool_call missing 'input' field"))?, - }), - "tool_result" => Ok(ConversationEvent::ToolResult { - tool_use_id: json_string(&data, "tool_use_id")?, - content: json_string(&data, "content")?, - is_error: data - .get("is_error") - .and_then(Value::as_bool) - .ok_or_else(|| eyre!("tool_result missing 'is_error' field"))?, - remote: data.get("remote").and_then(Value::as_bool).unwrap_or(false), - content_length: data - .get("content_length") - .and_then(Value::as_u64) - .map(|v| v as usize), - }), - "out_of_band_output" => Ok(ConversationEvent::OutOfBandOutput { - name: json_string(&data, "name")?, - command: data - .get("command") - .and_then(|v| if v.is_null() { None } else { v.as_str() }) - .map(String::from), - content: json_string(&data, "content")?, - }), - "system_context" => Ok(ConversationEvent::SystemContext { - content: json_string(&data, "content")?, - }), - "skill_invocation" => Ok(ConversationEvent::SkillInvocation { - name: json_string(&data, "name")?, - arguments: data - .get("arguments") - .and_then(|v| if v.is_null() { None } else { v.as_str() }) - .map(String::from), - content: json_string(&data, "content")?, - }), - other => Err(eyre!("unknown event type: {other}")), - } -} - -fn json_string(data: &Value, field: &str) -> Result { - data.get(field) - .and_then(Value::as_str) - .map(String::from) - .ok_or_else(|| eyre!("missing or non-string field '{field}'")) -} - -#[cfg(test)] -mod tests { - use super::*; - - fn round_trip(event: &ConversationEvent) -> ConversationEvent { - let (event_type, event_data) = serialize_event(event); - deserialize_event(&event_type, &event_data).unwrap() - } - - #[test] - fn test_user_message() { - let event = ConversationEvent::UserMessage { - content: "hello world".to_string(), - }; - let result = round_trip(&event); - assert!( - matches!(result, ConversationEvent::UserMessage { content } if content == "hello world") - ); - } - - #[test] - fn test_text() { - let event = ConversationEvent::Text { - content: "response text".to_string(), - }; - let result = round_trip(&event); - assert!( - matches!(result, ConversationEvent::Text { content } if content == "response text") - ); - } - - #[test] - fn test_tool_call() { - let input = serde_json::json!({"command": "ls -la", "danger": "low"}); - let event = ConversationEvent::ToolCall { - id: "tc_123".to_string(), - name: "suggest_command".to_string(), - input: input.clone(), - }; - let result = round_trip(&event); - match result { - ConversationEvent::ToolCall { - id, - name, - input: result_input, - } => { - assert_eq!(id, "tc_123"); - assert_eq!(name, "suggest_command"); - assert_eq!(result_input, input); - } - _ => panic!("expected ToolCall"), - } - } - - #[test] - fn test_tool_result() { - let event = ConversationEvent::ToolResult { - tool_use_id: "tc_123".to_string(), - content: "file contents here".to_string(), - is_error: false, - remote: false, - content_length: None, - }; - let result = round_trip(&event); - match result { - ConversationEvent::ToolResult { - tool_use_id, - content, - is_error, - remote, - content_length, - } => { - assert_eq!(tool_use_id, "tc_123"); - assert_eq!(content, "file contents here"); - assert!(!is_error); - assert!(!remote); - assert!(content_length.is_none()); - } - _ => panic!("expected ToolResult"), - } - } - - #[test] - fn test_tool_result_error() { - let event = ConversationEvent::ToolResult { - tool_use_id: "tc_456".to_string(), - content: "permission denied".to_string(), - is_error: true, - remote: false, - content_length: None, - }; - let result = round_trip(&event); - match result { - ConversationEvent::ToolResult { is_error, .. } => assert!(is_error), - _ => panic!("expected ToolResult"), - } - } - - #[test] - fn test_tool_result_remote() { - let event = ConversationEvent::ToolResult { - tool_use_id: "tc_789".to_string(), - content: "ref:abc123".to_string(), - is_error: false, - remote: true, - content_length: Some(4096), - }; - let result = round_trip(&event); - match result { - ConversationEvent::ToolResult { - remote, - content_length, - .. - } => { - assert!(remote); - assert_eq!(content_length, Some(4096)); - } - _ => panic!("expected ToolResult"), - } - } - - #[test] - fn test_tool_result_backwards_compat() { - // Old stored data without remote/content_length fields should deserialize - // with defaults (remote=false, content_length=None) - let event = deserialize_event( - "tool_result", - r#"{"tool_use_id":"tc_old","content":"old result","is_error":false}"#, - ) - .unwrap(); - match event { - ConversationEvent::ToolResult { - remote, - content_length, - .. - } => { - assert!(!remote); - assert!(content_length.is_none()); - } - _ => panic!("expected ToolResult"), - } - } - - #[test] - fn test_out_of_band_with_command() { - let event = ConversationEvent::OutOfBandOutput { - name: "System".to_string(), - command: Some("/help".to_string()), - content: "help text".to_string(), - }; - let result = round_trip(&event); - match result { - ConversationEvent::OutOfBandOutput { - name, - command, - content, - } => { - assert_eq!(name, "System"); - assert_eq!(command.as_deref(), Some("/help")); - assert_eq!(content, "help text"); - } - _ => panic!("expected OutOfBandOutput"), - } - } - - #[test] - fn test_out_of_band_without_command() { - let event = ConversationEvent::OutOfBandOutput { - name: "System".to_string(), - command: None, - content: "some output".to_string(), - }; - let result = round_trip(&event); - match result { - ConversationEvent::OutOfBandOutput { command, .. } => { - assert!(command.is_none()); - } - _ => panic!("expected OutOfBandOutput"), - } - } - - #[test] - fn test_unknown_event_type() { - let result = deserialize_event("banana", "{}"); - assert!(result.is_err()); - assert!( - result - .unwrap_err() - .to_string() - .contains("unknown event type") - ); - } - - #[test] - fn test_invalid_json() { - let result = deserialize_event("text", "not json"); - assert!(result.is_err()); - } - - #[test] - fn test_missing_field() { - let result = deserialize_event("text", "{}"); - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("content")); - } - - #[test] - fn test_text_with_special_characters() { - let event = ConversationEvent::Text { - content: "line1\nline2\ttab \"quotes\" \\backslash 🎉".to_string(), - }; - let result = round_trip(&event); - assert!( - matches!(result, ConversationEvent::Text { content } if content == "line1\nline2\ttab \"quotes\" \\backslash 🎉") - ); - } - - #[test] - fn test_tool_call_with_nested_input() { - let input = serde_json::json!({ - "command": "echo 'hello'", - "nested": { "a": [1, 2, 3], "b": null } - }); - let event = ConversationEvent::ToolCall { - id: "tc_1".to_string(), - name: "execute_shell_command".to_string(), - input: input.clone(), - }; - let result = round_trip(&event); - match result { - ConversationEvent::ToolCall { - input: result_input, - .. - } => { - assert_eq!(result_input, input); - } - _ => panic!("expected ToolCall"), - } - } - - #[test] - fn test_system_context() { - let event = ConversationEvent::SystemContext { - content: "[system: new invocation started]".to_string(), - }; - let result = round_trip(&event); - assert!( - matches!(result, ConversationEvent::SystemContext { content } if content == "[system: new invocation started]") - ); - } -} diff --git a/crates/atuin-ai/src/file_tracker.rs b/crates/atuin-ai/src/file_tracker.rs deleted file mode 100644 index feee1ee8..00000000 --- a/crates/atuin-ai/src/file_tracker.rs +++ /dev/null @@ -1,234 +0,0 @@ -//! Tracks which files have been read in the current session, for freshness -//! checking before edits. -//! -//! The tracker records the content hash and mtime of each file at the time -//! it was last read. Before an edit, the tracker verifies the file hasn't -//! changed since the last read — catching both external modifications and -//! concurrent tool calls. -//! -//! Persisted as JSON in session metadata so it survives across CLI -//! invocations within the same logical session. - -use std::collections::HashMap; -use std::path::{Path, PathBuf}; -use std::time::SystemTime; - -use eyre::Result; -use serde::{Deserialize, Serialize}; - -/// Metadata key used for session_metadata persistence. -pub(crate) const METADATA_KEY: &str = "file_read_tracker"; - -/// State recorded for a single file read. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) struct FileReadState { - /// Hash of the file contents at the time of the last read. - pub content_hash: u64, - /// File mtime (as milliseconds since epoch) at the time of the last read. - /// Millisecond precision ensures sub-second modifications are detected. - pub mtime_ms: i64, -} - -/// Tracks file read state for freshness checking. -#[derive(Debug, Default, Clone, Serialize, Deserialize)] -pub(crate) struct FileReadTracker { - reads: HashMap, -} - -/// Result of a freshness check. -pub(crate) enum FreshnessCheck { - /// File is fresh — the content hasn't changed since the last read. - Fresh, - /// File has never been read in this session. - NotRead, - /// File has been modified since the last read. - Stale, -} - -impl FileReadTracker { - /// Record that a file was read. Call this after a successful `read_file` - /// execution. The `path` should be canonical (absolute, tilde-expanded). - pub fn record_read(&mut self, path: PathBuf, content: &[u8], mtime: SystemTime) { - let content_hash = hash_content(content); - let mtime_ms = system_time_to_ms(mtime); - - self.reads.insert( - path, - FileReadState { - content_hash, - mtime_ms, - }, - ); - } - - /// Check whether a file is fresh (unchanged since last read). - /// - /// Uses mtime as a fast path — only re-hashes if mtime differs. - pub fn check_freshness(&self, path: &Path) -> Result { - let state = match self.reads.get(path) { - Some(s) => s, - None => return Ok(FreshnessCheck::NotRead), - }; - - // Stat the file - let metadata = match std::fs::metadata(path) { - Ok(m) => m, - Err(_) => return Ok(FreshnessCheck::Stale), // file deleted or inaccessible - }; - - let current_mtime_ms = - system_time_to_ms(metadata.modified().unwrap_or(SystemTime::UNIX_EPOCH)); - - // Fast path: mtime unchanged → fresh - if current_mtime_ms == state.mtime_ms { - return Ok(FreshnessCheck::Fresh); - } - - // Mtime changed — re-hash to confirm - let content = std::fs::read(path)?; - let current_hash = hash_content(&content); - - if current_hash == state.content_hash { - Ok(FreshnessCheck::Fresh) - } else { - Ok(FreshnessCheck::Stale) - } - } - - /// Update the tracker entry after a successful edit (new content written). - pub fn update_after_edit(&mut self, path: &Path, new_content: &[u8], new_mtime: SystemTime) { - let content_hash = hash_content(new_content); - let mtime_ms = system_time_to_ms(new_mtime); - - self.reads.insert( - path.to_path_buf(), - FileReadState { - content_hash, - mtime_ms, - }, - ); - } - - /// Serialize to JSON for session metadata persistence. - pub fn to_json(&self) -> Result { - Ok(serde_json::to_string(self)?) - } - - /// Deserialize from JSON session metadata. - pub fn from_json(json: &str) -> Result { - Ok(serde_json::from_str(json)?) - } -} - -fn system_time_to_ms(t: SystemTime) -> i64 { - t.duration_since(SystemTime::UNIX_EPOCH) - .map(|d| d.as_millis() as i64) - .unwrap_or(0) -} - -fn hash_content(content: &[u8]) -> u64 { - xxhash_rust::xxh3::xxh3_64(content) -} - -#[cfg(test)] -mod tests { - use super::*; - use std::io::Write; - use tempfile::NamedTempFile; - - #[test] - fn record_and_check_fresh() { - let mut tracker = FileReadTracker::default(); - let mut tmp = NamedTempFile::new().unwrap(); - write!(tmp, "hello world").unwrap(); - - let path = tmp.path().to_path_buf(); - let content = std::fs::read(&path).unwrap(); - let mtime = std::fs::metadata(&path).unwrap().modified().unwrap(); - - tracker.record_read(path.clone(), &content, mtime); - - assert!(matches!( - tracker.check_freshness(&path).unwrap(), - FreshnessCheck::Fresh - )); - } - - #[test] - fn check_not_read() { - let tracker = FileReadTracker::default(); - let path = PathBuf::from("/nonexistent/file.txt"); - assert!(matches!( - tracker.check_freshness(&path).unwrap(), - FreshnessCheck::NotRead - )); - } - - #[test] - fn check_stale_after_modification() { - let mut tracker = FileReadTracker::default(); - let mut tmp = NamedTempFile::new().unwrap(); - write!(tmp, "original").unwrap(); - - let path = tmp.path().to_path_buf(); - let content = std::fs::read(&path).unwrap(); - let mtime = std::fs::metadata(&path).unwrap().modified().unwrap(); - - tracker.record_read(path.clone(), &content, mtime); - - // Small delay to ensure the filesystem mtime advances - std::thread::sleep(std::time::Duration::from_millis(10)); - - // Modify the file - std::fs::write(&path, "modified").unwrap(); - - assert!(matches!( - tracker.check_freshness(&path).unwrap(), - FreshnessCheck::Stale - )); - } - - #[test] - fn update_after_edit_makes_fresh() { - let mut tracker = FileReadTracker::default(); - let mut tmp = NamedTempFile::new().unwrap(); - write!(tmp, "original").unwrap(); - - let path = tmp.path().to_path_buf(); - let content = std::fs::read(&path).unwrap(); - let mtime = std::fs::metadata(&path).unwrap().modified().unwrap(); - - tracker.record_read(path.clone(), &content, mtime); - - // Simulate an edit - let new_content = b"edited content"; - std::fs::write(&path, new_content).unwrap(); - let new_mtime = std::fs::metadata(&path).unwrap().modified().unwrap(); - tracker.update_after_edit(&path, new_content, new_mtime); - - assert!(matches!( - tracker.check_freshness(&path).unwrap(), - FreshnessCheck::Fresh - )); - } - - #[test] - fn roundtrip_json() { - let mut tracker = FileReadTracker::default(); - tracker.reads.insert( - PathBuf::from("/some/file.toml"), - FileReadState { - content_hash: 12345, - mtime_ms: 1700000000000, - }, - ); - - let json = tracker.to_json().unwrap(); - let restored = FileReadTracker::from_json(&json).unwrap(); - assert_eq!(restored.reads.len(), 1); - assert_eq!( - restored.reads[&PathBuf::from("/some/file.toml")].content_hash, - 12345 - ); - } -} diff --git a/crates/atuin-ai/src/fsm/effects.rs b/crates/atuin-ai/src/fsm/effects.rs deleted file mode 100644 index adc9628e..00000000 --- a/crates/atuin-ai/src/fsm/effects.rs +++ /dev/null @@ -1,99 +0,0 @@ -//! Effects (outputs) from the agent FSM. -//! -//! The FSM returns these as data; the driver is responsible for executing them. - -use std::path::PathBuf; -use std::time::Duration; - -use serde_json::Value; - -use crate::permissions::rule::Rule; -use crate::permissions::writer::RuleDisposition; -use crate::tools::ClientToolCall; - -/// Where to write a permission rule. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum PermissionTarget { - /// Project-level: `/.atuin/permissions.ai.toml` - Project, - /// Global: `~/.config/atuin/permissions.ai.toml` - Global, -} - -/// Side effects the driver should execute after a state transition. -#[derive(Debug, Clone)] -pub(crate) enum Effect { - // ─── Network ──────────────────────────────────────────────── - /// Start a new streaming request to the server. - StartStream { - messages: Vec, - session_id: Option, - }, - /// Abort the active stream connection. - AbortStream, - - // ─── Tool orchestration ───────────────────────────────────── - /// Run the permission resolver for a tool call. - CheckPermission { - tool_id: String, - tool: ClientToolCall, - }, - /// Execute a tool (file read, edit, write, shell, history search). - ExecuteTool { - tool_id: String, - tool: ClientToolCall, - }, - /// Kill a running tool (send interrupt to shell command). - AbortTool { tool_id: String }, - /// Load a skill's content asynchronously (read + interpolate). - LoadSkill { - name: String, - arguments: Option, - }, - - // ─── Persistence ──────────────────────────────────────────── - /// Persist current conversation state to disk. - Persist, - /// Write a permanent permission rule to disk. - WritePermissionRule { - target: PermissionTarget, - rule: Rule, - disposition: RuleDisposition, - }, - /// Cache a session-scoped file permission grant. - CacheSessionGrant { path: PathBuf }, - /// Archive current session and start fresh (IO only — state already updated by FSM). - ArchiveSession, - - // ─── Timers ───────────────────────────────────────────────── - /// Schedule a timer that fires an event after the given delay. - ScheduleTimeout { - timeout_id: u64, - duration: Duration, - kind: TimeoutKind, - }, - - // ─── Exit ─────────────────────────────────────────────────── - /// Exit the application with the given action. - ExitApp(ExitAction), -} - -/// What kind of timeout was scheduled. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum TimeoutKind { - /// Dangerous command confirmation dialog auto-dismiss. - Confirmation, - /// Shell tool execution timeout — abort the tool if it's still running. - ToolExecution { tool_id: String }, -} - -/// What to do when exiting the TUI. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum ExitAction { - /// Run the suggested command. - Execute(String), - /// Insert the command into the shell without running. - Insert(String), - /// Exit without action. - Cancel, -} diff --git a/crates/atuin-ai/src/fsm/events.rs b/crates/atuin-ai/src/fsm/events.rs deleted file mode 100644 index e591db41..00000000 --- a/crates/atuin-ai/src/fsm/events.rs +++ /dev/null @@ -1,140 +0,0 @@ -//! Events (inputs) to the agent FSM. - -use serde_json::Value; - -use crate::tools::ToolOutcome; - -/// Events that drive state transitions in the agent FSM. -#[derive(Debug, Clone)] -pub(crate) enum Event { - // ─── User actions ─────────────────────────────────────────── - /// User submitted a message from the input box. - UserSubmit(String), - /// User pressed Esc or equivalent cancel action. - Cancel, - /// User pressed Enter to execute the suggested command. - ExecuteCommand, - /// User pressed Tab to insert the suggested command. - InsertCommand, - /// User chose to retry after an error. - Retry, - /// User interrupted executing tools (Ctrl+C / Esc during shell execution). - InterruptTools, - - // ─── Stream lifecycle ─────────────────────────────────────── - /// Stream connection established, first frame received. - StreamStarted, - /// Received a chunk of streamed text content. - StreamChunk(String), - /// Stream delivered a client-side tool call. - StreamToolCall { - id: String, - name: String, - input: Value, - }, - /// Stream delivered a server-side tool result (executed remotely). - StreamServerToolResult { - tool_use_id: String, - content: String, - is_error: bool, - remote: bool, - content_length: Option, - }, - /// Stream status changed (e.g. "thinking", "searching"). - StreamStatusChanged(String), - /// Stream ended normally. - StreamDone { session_id: String }, - /// Stream encountered an error. - StreamError(String), - - // ─── Suggest command (terminal tool call) ─────────────────── - /// The suggest_command tool call acts as a stream terminal event. - /// This is the server signaling "turn complete, here's the command." - SuggestCommand { id: String, input: Value }, - - // ─── Tool lifecycle ───────────────────────────────────────── - /// Permission resolver completed for a tool. - PermissionResolved { - tool_id: String, - response: PermissionResponse, - }, - /// User made a permission choice via the dialog. - PermissionUserChoice { - tool_id: String, - choice: PermissionChoice, - }, - /// Tool execution completed. - ToolExecutionDone { - tool_id: String, - outcome: ToolOutcome, - /// Preview data computed by the driver (diff, content preview, final shell state). - preview: Option, - }, - /// Live preview update for an executing shell command. - ToolPreviewUpdate { - tool_id: String, - lines: Vec, - exit_code: Option, - }, - - // ─── Timers ───────────────────────────────────────────────── - /// Confirmation timeout expired. - ConfirmationTimeout { timeout_id: u64 }, - /// Shell tool execution timeout expired. - ToolExecutionTimeout { timeout_id: u64, tool_id: String }, - - // ─── Session management ───────────────────────────────────── - /// User ran /new to start a fresh session. - NewSession, - - // ─── Slash commands ───────────────────────────────────────── - /// User submitted a slash command (other than /new). - /// The driver resolves known commands (like /help) and passes the - /// rendered content; the FSM just pushes an OOB event. - SlashCommand { command: String, content: String }, - - // ─── Skills ──────────────────────────────────────────────── - /// User invoked a skill via /skill-name. FSM emits a LoadSkill - /// effect; the driver loads the content asynchronously and sends - /// SkillLoaded when ready. - RequestSkillLoad { - name: String, - arguments: Option, - }, - /// A skill's content has been loaded and interpolated. - /// Pushes skill content as OOB context and starts a turn so the - /// LLM sees the skill and acts on it. - SkillLoaded { - name: String, - arguments: Option, - content: String, - }, -} - -/// Result of the permission resolver check. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum PermissionResponse { - /// Rule allows this tool call — execute immediately. - Allowed, - /// Rule denies this tool call — reject with error. - Denied, - /// No matching rule — ask the user. - Ask, - /// Session-scoped grant exists — execute immediately (bypass resolver). - SessionGranted, -} - -/// User's choice from the permission dialog. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum PermissionChoice { - /// Allow this one time. - Allow, - /// Allow this file for the remainder of the session. - AllowForSession, - /// Always allow in this project (writes to project permissions file). - AlwaysAllowInProject, - /// Always allow globally (writes to global permissions file, scoped to file). - AlwaysAllow, - /// Deny this tool call. - Deny, -} diff --git a/crates/atuin-ai/src/fsm/mod.rs b/crates/atuin-ai/src/fsm/mod.rs deleted file mode 100644 index 3d72a3ae..00000000 --- a/crates/atuin-ai/src/fsm/mod.rs +++ /dev/null @@ -1,1103 +0,0 @@ -//! Agent conversation FSM. -//! -//! Pure state machine that returns effects as data. -//! The driver is responsible for executing effects and feeding events back. -//! -//! The FSM owns the conversation event log and tool lifecycle state. -//! It never performs IO directly. - -pub(crate) mod effects; -pub(crate) mod events; -pub(crate) mod tools; - -#[cfg(test)] -mod tests; - -use std::collections::HashMap; - -use serde_json::Value; - -use crate::context_window::ContextWindowBuilder; -use crate::tui::state::ConversationEvent; - -use effects::{Effect, ExitAction, PermissionTarget, TimeoutKind}; -use events::{Event, PermissionChoice, PermissionResponse}; -use tools::{ToolManager, ToolState}; - -// ============================================================================ -// State -// ============================================================================ - -/// The discrete states of the agent FSM. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum AgentState { - /// Waiting for user input. - Idle { - confirmation: Option, - }, - - /// A conversation turn is in progress. - Turn { stream: StreamPhase }, - - /// Unrecoverable error. User can retry or exit. - Error(String), -} - -/// Stream connection lifecycle within a Turn. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum StreamPhase { - /// Request sent, awaiting first stream frame. - Connecting, - /// Actively receiving streamed response. - Streaming { status: Option }, - /// Stream connection has ended (Done received). - Done, -} - -/// Streaming status indicators from server. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum StreamingStatus { - Processing, - Searching, - Thinking, - WaitingForTools, -} - -impl StreamingStatus { - pub(crate) fn from_str(s: &str) -> Self { - match s { - "processing" => Self::Processing, - "searching" => Self::Searching, - "waiting_for_tools" => Self::WaitingForTools, - _ => Self::Thinking, - } - } -} - -/// Pending dangerous command confirmation state. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) struct PendingConfirmation { - pub command: String, - pub timeout_id: u64, -} - -// ============================================================================ -// Context -// ============================================================================ - -/// Shared context owned by the FSM. -#[derive(Debug, Clone)] -pub(crate) struct AgentContext { - /// The full conversation event log (source of truth for API + persistence). - pub events: Vec, - /// Server-assigned session ID. - pub session_id: Option, - /// Accumulated text from current stream (committed to events on tool call or stream end). - pub current_response: String, - /// Per-tool lifecycle state and cached render data. - /// Tools persist across turns for rendering history. - pub tools: ToolManager, - /// Tool IDs that belong to the current turn. Cleared on continuation start. - /// Used to determine whether a turn needs continuation (has unprocessed results). - current_turn_tool_ids: Vec, - /// Maps timeout_id → tool_id for active tool execution timeouts. - /// Cleaned up when a tool completes naturally, so stale timeouts are ignored. - tool_timeout_ids: HashMap, - /// Counter for generating unique timeout IDs. - next_timeout_id: u64, - /// Capabilities advertised to the server. - pub capabilities: Vec, - /// Unique invocation ID for this CLI invocation. - pub invocation_id: String, - - // ─── View state (owned by FSM for atomic transitions) ─────── - /// Index into events where the current TUI invocation starts. - /// Events before this are context for the API but not rendered. - pub view_start_index: usize, - /// Whether this session was resumed from a prior invocation. - pub is_resumed: bool, - /// Time of the last event from a previous invocation. - pub last_event_time: Option>, - /// Events from archived sessions (/new) still rendered on screen. - pub archived_events: Vec, -} - -impl AgentContext { - fn next_timeout_id(&mut self) -> u64 { - let id = self.next_timeout_id; - self.next_timeout_id += 1; - id - } -} - -// ============================================================================ -// The Agent FSM -// ============================================================================ - -/// The agent finite state machine. -/// -/// Pure state machine — `handle()` takes an event, mutates internal state, -/// and returns effects as data for the driver to execute. -#[derive(Debug, Clone)] -pub(crate) struct AgentFsm { - pub state: AgentState, - pub ctx: AgentContext, -} - -impl AgentFsm { - /// Create a new FSM in Idle state. - pub fn new(capabilities: Vec, invocation_id: String) -> Self { - Self { - state: AgentState::Idle { confirmation: None }, - ctx: AgentContext { - events: Vec::new(), - session_id: None, - current_response: String::new(), - tools: ToolManager::new(), - current_turn_tool_ids: Vec::new(), - tool_timeout_ids: HashMap::new(), - next_timeout_id: 0, - capabilities, - invocation_id, - view_start_index: 0, - is_resumed: false, - last_event_time: None, - archived_events: Vec::new(), - }, - } - } - - /// Create an FSM from saved session state (for resume). - pub fn from_session( - events: Vec, - session_id: Option, - capabilities: Vec, - invocation_id: String, - view_start_index: usize, - is_resumed: bool, - last_event_time: Option>, - ) -> Self { - Self { - state: AgentState::Idle { confirmation: None }, - ctx: AgentContext { - events, - session_id, - current_response: String::new(), - tools: ToolManager::new(), - current_turn_tool_ids: Vec::new(), - tool_timeout_ids: HashMap::new(), - next_timeout_id: 0, - capabilities, - invocation_id, - view_start_index, - is_resumed, - last_event_time, - archived_events: Vec::new(), - }, - } - } - - /// Handle an event, returning effects to execute. - pub fn handle(&mut self, event: Event) -> Vec { - match (&self.state, event) { - // ================================================================ - // Idle state - // ================================================================ - (AgentState::Idle { confirmation: None }, Event::UserSubmit(msg)) => { - self.start_turn(msg) - } - - ( - AgentState::Idle { - confirmation: Some(_), - }, - Event::UserSubmit(msg), - ) => self.start_turn(msg), - - (AgentState::Idle { confirmation: None }, Event::ExecuteCommand) => { - let cmd = self.current_command(); - let Some(cmd) = cmd else { - // No command suggested — exit - return vec![Effect::ExitApp(ExitAction::Cancel)]; - }; - if self.is_current_command_dangerous() { - let timeout_id = self.ctx.next_timeout_id(); - self.state = AgentState::Idle { - confirmation: Some(PendingConfirmation { - command: cmd, - timeout_id, - }), - }; - vec![Effect::ScheduleTimeout { - timeout_id, - duration: std::time::Duration::from_secs(5), - kind: TimeoutKind::Confirmation, - }] - } else { - vec![Effect::ExitApp(ExitAction::Execute(cmd))] - } - } - - ( - AgentState::Idle { - confirmation: Some(_), - }, - Event::ExecuteCommand, - ) => { - let confirm = self.state_confirmation().unwrap().clone(); - self.state = AgentState::Idle { confirmation: None }; - vec![Effect::ExitApp(ExitAction::Execute(confirm.command))] - } - - (AgentState::Idle { .. }, Event::InsertCommand) => { - let cmd = self.current_command(); - match cmd { - Some(cmd) => vec![Effect::ExitApp(ExitAction::Insert(cmd))], - None => vec![], - } - } - - ( - AgentState::Idle { - confirmation: Some(_), - }, - Event::Cancel, - ) => { - self.state = AgentState::Idle { confirmation: None }; - vec![] - } - - (AgentState::Idle { confirmation: None }, Event::Cancel) => { - vec![Effect::ExitApp(ExitAction::Cancel)] - } - - (AgentState::Idle { .. }, Event::ConfirmationTimeout { timeout_id }) => { - if self - .state_confirmation() - .is_some_and(|c| c.timeout_id == timeout_id) - { - self.state = AgentState::Idle { confirmation: None }; - } - vec![] - } - - (AgentState::Idle { .. }, Event::NewSession) => { - // Archive visible events so they remain on screen but aren't - // sent to the API. Tools persist for rendering. - let visible = self.ctx.events[self.ctx.view_start_index..].to_vec(); - self.ctx.archived_events.extend(visible); - - self.ctx.events.clear(); - self.ctx.session_id = None; - self.ctx.current_turn_tool_ids.clear(); - self.ctx.view_start_index = 0; - self.ctx.is_resumed = false; - - // Add OOB indicator for the new session - self.ctx.events.push(ConversationEvent::OutOfBandOutput { - name: "System".to_string(), - command: Some("/new".to_string()), - content: "Started a new session.".to_string(), - }); - - self.state = AgentState::Idle { confirmation: None }; - vec![Effect::ArchiveSession, Effect::Persist] - } - - (AgentState::Idle { .. }, Event::SlashCommand { command, content }) => { - self.handle_slash_command(&command, &content); - vec![] - } - - ( - AgentState::Idle { .. }, - Event::SkillLoaded { - name, - arguments, - content, - }, - ) => { - self.ctx.events.push(ConversationEvent::SkillInvocation { - name, - arguments, - content, - }); - self.ctx.current_response.clear(); - self.ctx.current_turn_tool_ids.clear(); - - let messages = self.build_messages(); - let session_id = self.ctx.session_id.clone(); - self.state = AgentState::Turn { - stream: StreamPhase::Connecting, - }; - vec![Effect::StartStream { - messages, - session_id, - }] - } - - // ================================================================ - // Turn — stream lifecycle - // ================================================================ - ( - AgentState::Turn { - stream: StreamPhase::Connecting, - }, - Event::StreamStarted, - ) => { - self.state = AgentState::Turn { - stream: StreamPhase::Streaming { status: None }, - }; - vec![] - } - - ( - AgentState::Turn { - stream: StreamPhase::Connecting, - }, - Event::StreamError(e), - ) => { - self.state = AgentState::Error(e); - vec![] - } - - ( - AgentState::Turn { - stream: StreamPhase::Streaming { .. }, - }, - Event::StreamChunk(text), - ) => { - self.ctx.current_response.push_str(&text); - vec![] - } - - ( - AgentState::Turn { - stream: StreamPhase::Streaming { .. }, - }, - Event::StreamStatusChanged(status), - ) => { - self.state = AgentState::Turn { - stream: StreamPhase::Streaming { - status: Some(StreamingStatus::from_str(&status)), - }, - }; - vec![] - } - - (AgentState::Turn { .. }, Event::StreamToolCall { id, name, input }) => { - self.commit_streaming_text(); - self.handle_stream_tool_call(id, name, input) - } - - (AgentState::Turn { .. }, Event::SuggestCommand { id, input }) => { - self.commit_streaming_text(); - // Push the suggest_command as a ToolCall event (protocol requirement) - self.ctx.events.push(ConversationEvent::ToolCall { - id, - name: "suggest_command".to_string(), - input, - }); - self.state = AgentState::Idle { confirmation: None }; - vec![Effect::Persist] - } - - ( - AgentState::Turn { - stream: StreamPhase::Streaming { .. }, - }, - Event::StreamServerToolResult { - tool_use_id, - content, - is_error, - remote, - content_length, - }, - ) => { - self.ctx.events.push(ConversationEvent::ToolResult { - tool_use_id, - content, - is_error, - remote, - content_length, - }); - vec![] - } - - (AgentState::Turn { .. }, Event::StreamDone { session_id }) => { - self.commit_streaming_text(); - if !session_id.is_empty() { - self.ctx.session_id = Some(session_id); - } - self.state = AgentState::Turn { - stream: StreamPhase::Done, - }; - self.check_turn_completion() - } - - ( - AgentState::Turn { - stream: StreamPhase::Streaming { .. }, - }, - Event::StreamError(e), - ) => { - // Abort any executing tools on stream error - let abort_effects: Vec<_> = self - .ctx - .tools - .executing_ids() - .into_iter() - .map(|tool_id| Effect::AbortTool { tool_id }) - .collect(); - self.ctx.tool_timeout_ids.clear(); - self.state = AgentState::Error(e); - abort_effects - } - - // ================================================================ - // Turn — tool lifecycle (any stream phase) - // ================================================================ - (AgentState::Turn { .. }, Event::PermissionResolved { tool_id, response }) => { - self.handle_permission_resolved(tool_id, response) - } - - (AgentState::Turn { .. }, Event::PermissionUserChoice { tool_id, choice }) => { - self.handle_permission_choice(tool_id, choice) - } - - ( - AgentState::Turn { .. }, - Event::ToolExecutionDone { - tool_id, - outcome, - preview, - }, - ) => self.handle_tool_done(tool_id, outcome, preview), - - ( - AgentState::Turn { .. }, - Event::ToolPreviewUpdate { - tool_id, - lines, - exit_code, - }, - ) => { - if let Some(tracked) = self.ctx.tools.get_mut(&tool_id) { - if tracked.is_resolved() { - // Tool already completed — a late preview update raced with - // ToolExecutionDone. Update lines (they may carry the final - // screen) but preserve the finalized exit_code/interrupted. - if let Some(tools::ToolPreviewData::Shell { - lines: existing_lines, - .. - }) = &mut tracked.preview - { - *existing_lines = lines; - } - } else { - tracked.preview = Some(tools::ToolPreviewData::Shell { - lines, - exit_code, - interrupted: None, - }); - } - } - vec![] - } - - (AgentState::Turn { .. }, Event::InterruptTools) => { - let ids = self.ctx.tools.executing_ids(); - for id in &ids { - if let Some(tracked) = self.ctx.tools.get_mut(id) { - tracked.interrupt_reason = Some(tools::InterruptReason::User); - } - // Clear any pending execution timeout for this tool - self.ctx.tool_timeout_ids.retain(|_, tid| tid != id); - } - ids.into_iter() - .map(|tool_id| Effect::AbortTool { tool_id }) - .collect() - } - - ( - AgentState::Turn { .. }, - Event::ToolExecutionTimeout { - timeout_id, - tool_id, - }, - ) => self.handle_tool_execution_timeout(timeout_id, tool_id), - - // ─── Cancel during Turn ───────────────────────────────────── - (AgentState::Turn { stream }, Event::Cancel) => { - let mut effects = Vec::new(); - - // Abort stream if still active - if !matches!(stream, StreamPhase::Done) { - effects.push(Effect::AbortStream); - } - - // Cancel all pending tools - let pending = self.ctx.tools.pending_ids(); - for id in &pending { - if let Some(tracked) = self.ctx.tools.get_mut(id) { - if tracked.state == ToolState::Executing { - effects.push(Effect::AbortTool { - tool_id: id.clone(), - }); - } - tracked.state = ToolState::Completed; - } - self.ctx.events.push(ConversationEvent::ToolResult { - tool_use_id: id.clone(), - content: "Error: user cancelled this operation".to_string(), - is_error: true, - remote: false, - content_length: None, - }); - } - - // Commit any partial streaming text - self.commit_streaming_text_as_cancelled(); - - // Add context so the LLM knows what happened - if !pending.is_empty() { - self.ctx.events.push(ConversationEvent::SystemContext { - content: "The user cancelled the previous generation. Tool calls that were in progress have been aborted.".to_string(), - }); - } - - // Clear timeout mappings — stale timeouts will be ignored by the guard - self.ctx.tool_timeout_ids.clear(); - - self.state = AgentState::Idle { confirmation: None }; - effects.push(Effect::Persist); - effects - } - - // ================================================================ - // Error state - // ================================================================ - (AgentState::Error(_), Event::Retry) => { - let messages = self.build_messages(); - let session_id = self.ctx.session_id.clone(); - self.state = AgentState::Turn { - stream: StreamPhase::Connecting, - }; - vec![Effect::StartStream { - messages, - session_id, - }] - } - - (AgentState::Error(_), Event::Cancel) => { - vec![Effect::ExitApp(ExitAction::Cancel)] - } - - // ================================================================ - // Fallthrough — ignore events with no valid transition - // ================================================================ - - // StreamDone can arrive after SuggestCommand (which already moved to Idle). - // We still need to capture the session_id from it. - (_, Event::StreamDone { session_id }) => { - if !session_id.is_empty() { - self.ctx.session_id = Some(session_id); - } - vec![Effect::Persist] - } - - (_, Event::SlashCommand { command, content }) => { - self.handle_slash_command(&command, &content); - vec![] - } - - // RequestSkillLoad during non-idle: still emit the effect - (_, Event::RequestSkillLoad { name, arguments }) => { - vec![Effect::LoadSkill { name, arguments }] - } - - // SkillLoaded during non-idle: queue so it's visible - // in context for the next turn. - ( - _, - Event::SkillLoaded { - name, - arguments, - content, - }, - ) => { - self.ctx.events.push(ConversationEvent::SkillInvocation { - name, - arguments, - content, - }); - vec![] - } - - _ => vec![], - } - } - - // ──────────────────────────────────────────────────────────────────── - // Private helpers - // ──────────────────────────────────────────────────────────────────── - - /// Start a new turn: push user message, build messages, emit StartStream. - fn start_turn(&mut self, msg: String) -> Vec { - self.ctx - .events - .push(ConversationEvent::UserMessage { content: msg }); - // Don't clear tools — completed tools persist for rendering history. - // Tools are only cleared on /new (session reset). - self.ctx.current_response.clear(); - self.ctx.current_turn_tool_ids.clear(); - - let messages = self.build_messages(); - let session_id = self.ctx.session_id.clone(); - self.state = AgentState::Turn { - stream: StreamPhase::Connecting, - }; - vec![Effect::StartStream { - messages, - session_id, - }] - } - - /// Build API messages from the conversation event log. - fn build_messages(&self) -> Vec { - ContextWindowBuilder::with_default_budget().build(&self.ctx.events) - } - - /// Commit accumulated streaming text to the event log. - fn commit_streaming_text(&mut self) { - let text = std::mem::take(&mut self.ctx.current_response); - let trimmed = text.trim_start().to_string(); - if !trimmed.is_empty() { - self.ctx - .events - .push(ConversationEvent::Text { content: trimmed }); - } - } - - /// Commit streaming text with a cancellation suffix. - fn commit_streaming_text_as_cancelled(&mut self) { - let text = std::mem::take(&mut self.ctx.current_response); - let trimmed = text.trim_start().to_string(); - if !trimmed.is_empty() { - self.ctx.events.push(ConversationEvent::Text { - content: format!("{trimmed}\n\n[User cancelled this generation]"), - }); - } - } - - /// Handle a client-side tool call from the stream. - fn handle_stream_tool_call(&mut self, id: String, name: String, input: Value) -> Vec { - // Parse the tool call - let tool = match crate::tools::ClientToolCall::try_from((name.as_str(), &input)) { - Ok(tool) => tool, - Err(_) => { - // Unknown tool — push as event but don't track - self.ctx - .events - .push(ConversationEvent::ToolCall { id, name, input }); - return vec![]; - } - }; - - // Capability gating - if let Some(required_cap) = tool.descriptor().capability - && !self.ctx.capabilities.iter().any(|c| c == required_cap) - { - self.ctx.events.push(ConversationEvent::ToolCall { - id: id.clone(), - name, - input, - }); - self.ctx.events.push(ConversationEvent::ToolResult { - tool_use_id: id, - content: format!( - "Tool not enabled: capability '{required_cap}' was not advertised by this client" - ), - is_error: true, - remote: false, - content_length: None, - }); - return vec![]; - } - - // Track the tool and push ToolCall event - let tool_for_effect = tool.clone(); - self.ctx.tools.insert(id.clone(), tool); - self.ctx.current_turn_tool_ids.push(id.clone()); - self.ctx.events.push(ConversationEvent::ToolCall { - id: id.clone(), - name, - input, - }); - - // Transition to Turn if we were Streaming - if let AgentState::Turn { - stream: StreamPhase::Streaming { .. }, - } = &self.state - { - self.state = AgentState::Turn { - stream: StreamPhase::Streaming { status: None }, - }; - } - - vec![Effect::CheckPermission { - tool_id: id, - tool: tool_for_effect, - }] - } - - /// Handle permission resolver result. - fn handle_permission_resolved( - &mut self, - tool_id: String, - response: PermissionResponse, - ) -> Vec { - let Some(tracked) = self.ctx.tools.get_mut(&tool_id) else { - return vec![]; - }; - - // If already resolved (e.g. cancelled while permission check was in flight), - // ignore the stale result to avoid re-executing a cancelled tool. - if tracked.is_resolved() { - return vec![]; - } - - match response { - PermissionResponse::Allowed | PermissionResponse::SessionGranted => { - tracked.state = ToolState::Executing; - let tool = tracked.tool.clone(); - self.emit_execute_tool(tool_id, tool) - } - PermissionResponse::Ask => { - tracked.state = ToolState::AwaitingPermission; - vec![] - } - PermissionResponse::Denied => { - tracked.state = ToolState::Denied; - self.ctx.events.push(ConversationEvent::ToolResult { - tool_use_id: tool_id, - content: "Permission denied on the user's system".to_string(), - is_error: true, - remote: false, - content_length: None, - }); - self.check_turn_completion() - } - } - } - - /// Handle user's permission choice from the dialog. - fn handle_permission_choice( - &mut self, - tool_id: String, - choice: PermissionChoice, - ) -> Vec { - let Some(tracked) = self.ctx.tools.get_mut(&tool_id) else { - return vec![]; - }; - - if tracked.is_resolved() { - return vec![]; - } - - match choice { - PermissionChoice::Allow => { - tracked.state = ToolState::Executing; - let tool = tracked.tool.clone(); - self.emit_execute_tool(tool_id, tool) - } - PermissionChoice::AllowForSession => { - tracked.state = ToolState::Executing; - let tool = tracked.tool.clone(); - let mut effects = self.emit_execute_tool(tool_id, tool.clone()); - if let Some(path) = tool.resolved_file_path() { - effects.push(Effect::CacheSessionGrant { path }); - } - effects - } - PermissionChoice::AlwaysAllowInProject => { - tracked.state = ToolState::Executing; - let tool = tracked.tool.clone(); - let rule = crate::permissions::rule::Rule { - tool: tool.rule_name().to_string(), - scope: None, // project file provides the scoping - }; - let mut effects = self.emit_execute_tool(tool_id, tool); - effects.push(Effect::WritePermissionRule { - target: PermissionTarget::Project, - rule, - disposition: crate::permissions::writer::RuleDisposition::Allow, - }); - effects - } - PermissionChoice::AlwaysAllow => { - tracked.state = ToolState::Executing; - let tool = tracked.tool.clone(); - let scope = tool - .resolved_file_path() - .map(|p| p.to_string_lossy().to_string()); - let rule = crate::permissions::rule::Rule { - tool: tool.rule_name().to_string(), - scope, - }; - let mut effects = self.emit_execute_tool(tool_id, tool); - effects.push(Effect::WritePermissionRule { - target: PermissionTarget::Global, - rule, - disposition: crate::permissions::writer::RuleDisposition::Allow, - }); - effects - } - PermissionChoice::Deny => { - tracked.state = ToolState::Denied; - self.ctx.events.push(ConversationEvent::ToolResult { - tool_use_id: tool_id, - content: "Permission denied by the user".to_string(), - is_error: true, - remote: false, - content_length: None, - }); - self.check_turn_completion() - } - } - } - - /// Handle tool execution completion. - fn handle_tool_done( - &mut self, - tool_id: String, - outcome: crate::tools::ToolOutcome, - preview: Option, - ) -> Vec { - let Some(tracked) = self.ctx.tools.get_mut(&tool_id) else { - return vec![]; - }; - - // If already completed (e.g. cancelled), ignore stale result - if tracked.is_resolved() { - return vec![]; - } - - tracked.state = ToolState::Completed; - - // If the FSM tagged this tool with an interrupt reason (user or timeout), - // use it; otherwise derive from the outcome's interrupted flag. - let reason = tracked.interrupt_reason.take().or({ - if let crate::tools::ToolOutcome::Structured { - interrupted: true, .. - } = &outcome - { - Some(tools::InterruptReason::User) - } else { - None - } - }); - - // Merge shell preview: the final ToolExecutionDone carries exit_code/interrupted - // but has empty lines (the live lines were accumulated via ToolPreviewUpdate). - // Preserve the accumulated lines and fold in the terminal metadata. - match (&mut tracked.preview, preview) { - ( - Some(tools::ToolPreviewData::Shell { - exit_code, - interrupted, - .. - }), - Some(tools::ToolPreviewData::Shell { - exit_code: final_exit, - .. - }), - ) => { - *exit_code = final_exit; - *interrupted = reason.clone(); - } - (_, Some(mut p)) => { - if let tools::ToolPreviewData::Shell { - ref mut interrupted, - .. - } = p - { - *interrupted = reason.clone(); - } - tracked.preview = Some(p); - } - _ => {} - } - - // Clean up any pending execution timeout for this tool - self.ctx.tool_timeout_ids.retain(|_, tid| tid != &tool_id); - - let content = outcome.format_for_llm(reason.as_ref()); - let is_error = outcome.is_error(); - self.ctx.events.push(ConversationEvent::ToolResult { - tool_use_id: tool_id, - content, - is_error, - remote: false, - content_length: None, - }); - - self.check_turn_completion() - } - - /// Handle a tool execution timeout. Aborts the tool if it's still running. - fn handle_tool_execution_timeout(&mut self, timeout_id: u64, tool_id: String) -> Vec { - // Guard: only act if this timeout is still registered (not cleaned up by natural completion) - if self.ctx.tool_timeout_ids.remove(&timeout_id).is_none() { - return vec![]; - } - - let Some(tracked) = self.ctx.tools.get_mut(&tool_id) else { - return vec![]; - }; - - if tracked.is_resolved() { - return vec![]; - } - - // Tag the tool so handle_tool_done can distinguish timeout from user interrupt. - // Only shell tools have entries in tool_timeout_ids, so this is always Shell. - let timeout_secs = match &tracked.tool { - crate::tools::ClientToolCall::Shell(s) => s.timeout_secs, - _ => unreachable!("only shell tools have execution timeouts"), - }; - tracked.interrupt_reason = Some(tools::InterruptReason::Timeout(timeout_secs)); - - // Abort the tool — the driver sends the interrupt signal via oneshot, - // and execute_shell_command_streaming returns a Structured outcome with - // interrupted: true and partial stdout/stderr. This flows through the - // normal ToolExecutionDone path. - vec![Effect::AbortTool { tool_id }] - } - - /// Emit effects to begin executing a tool. For shell commands, also schedules - /// an execution timeout based on the LLM-specified timeout_secs. - fn emit_execute_tool( - &mut self, - tool_id: String, - tool: crate::tools::ClientToolCall, - ) -> Vec { - let mut effects = vec![Effect::ExecuteTool { - tool_id: tool_id.clone(), - tool: tool.clone(), - }]; - - if let crate::tools::ClientToolCall::Shell(ref shell) = tool { - let timeout_id = self.ctx.next_timeout_id(); - self.ctx - .tool_timeout_ids - .insert(timeout_id, tool_id.clone()); - effects.push(Effect::ScheduleTimeout { - timeout_id, - duration: std::time::Duration::from_secs(shell.timeout_secs), - kind: TimeoutKind::ToolExecution { tool_id }, - }); - } - - effects - } - - /// Check if the turn is complete (stream done + all tools resolved). - /// If so, either continue the conversation or go Idle. - fn check_turn_completion(&mut self) -> Vec { - // Stream must be done - if !matches!( - self.state, - AgentState::Turn { - stream: StreamPhase::Done - } - ) { - return vec![]; - } - - // All current-turn tools must be resolved before the turn can complete - if !self.ctx.tools.all_resolved(&self.ctx.current_turn_tool_ids) { - return vec![]; - } - - // Turn is complete. Check if we need to continue (tool results to send back). - // We continue if this turn had any client tool calls (the LLM needs to see - // the results and respond). - if !self.ctx.current_turn_tool_ids.is_empty() { - // Continue conversation with tool results. - // Don't clear tools — they persist for rendering history. - // Clear turn IDs so the continuation turn doesn't loop. - self.ctx.current_turn_tool_ids.clear(); - let messages = self.build_messages(); - let session_id = self.ctx.session_id.clone(); - self.ctx.current_response.clear(); - self.state = AgentState::Turn { - stream: StreamPhase::Connecting, - }; - vec![Effect::StartStream { - messages, - session_id, - }] - } else { - // No tools — turn is done, go idle - self.state = AgentState::Idle { confirmation: None }; - vec![Effect::Persist] - } - } - - /// Extract the current confirmation state (if any). - fn state_confirmation(&self) -> Option<&PendingConfirmation> { - if let AgentState::Idle { - confirmation: Some(ref c), - } = self.state - { - Some(c) - } else { - None - } - } - - /// Get the most recent suggested command from the conversation. - /// Get the most recent command from the current invocation only. - fn current_command(&self) -> Option { - self.current_invocation_events() - .rev() - .find_map(|e| e.as_command()) - .map(|s| s.to_string()) - } - - /// Check if the most recent command is dangerous. - fn is_current_command_dangerous(&self) -> bool { - self.current_invocation_events() - .rev() - .find_map(|e| { - if let ConversationEvent::ToolCall { name, input, .. } = e - && name == "suggest_command" - { - let danger = input - .get("danger") - .and_then(|v| v.as_str()) - .unwrap_or("low"); - Some(danger == "high" || danger == "medium" || danger == "med") - } else { - None - } - }) - .unwrap_or(false) - } - - /// Events from the current invocation only (from view_start_index onward). - fn current_invocation_events(&self) -> impl DoubleEndedIterator { - let start = self.ctx.view_start_index.min(self.ctx.events.len()); - self.ctx.events[start..].iter() - } - - /// Handle a slash command by pushing an OOB event. - fn handle_slash_command(&mut self, command: &str, content: &str) { - self.ctx.events.push(ConversationEvent::OutOfBandOutput { - name: "System".to_string(), - command: Some(command.to_string()), - content: content.to_string(), - }); - } -} diff --git a/crates/atuin-ai/src/fsm/tests.rs b/crates/atuin-ai/src/fsm/tests.rs deleted file mode 100644 index 51c23915..00000000 --- a/crates/atuin-ai/src/fsm/tests.rs +++ /dev/null @@ -1,890 +0,0 @@ -//! Pure FSM transition tests. No IO, no async. - -use serde_json::json; - -use super::*; -use effects::{Effect, ExitAction}; -use events::{Event, PermissionChoice, PermissionResponse}; - -fn new_fsm() -> AgentFsm { - AgentFsm::new( - vec!["client_v1_read_file".to_string()], - "test-inv".to_string(), - ) -} - -// ============================================================================ -// Idle → Turn -// ============================================================================ - -#[test] -fn user_submit_starts_turn() { - let mut fsm = new_fsm(); - - let effects = fsm.handle(Event::UserSubmit("hello".into())); - - assert!(matches!( - fsm.state, - AgentState::Turn { - stream: StreamPhase::Connecting - } - )); - assert_eq!(effects.len(), 1); - assert!(matches!(effects[0], Effect::StartStream { .. })); - // User message was pushed to events - assert!(fsm.ctx.events.iter().any(|e| matches!( - e, - ConversationEvent::UserMessage { content } if content == "hello" - ))); -} - -#[test] -fn stream_started_transitions_to_streaming() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("hello".into())); - - let effects = fsm.handle(Event::StreamStarted); - - assert!(matches!( - fsm.state, - AgentState::Turn { - stream: StreamPhase::Streaming { status: None } - } - )); - assert!(effects.is_empty()); -} - -#[test] -fn stream_chunk_accumulates_text() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("hello".into())); - fsm.handle(Event::StreamStarted); - - fsm.handle(Event::StreamChunk("Hello ".into())); - fsm.handle(Event::StreamChunk("world!".into())); - - assert_eq!(fsm.ctx.current_response, "Hello world!"); -} - -#[test] -fn stream_done_without_tools_goes_idle() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("hello".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamChunk("Hi there!".into())); - - let effects = fsm.handle(Event::StreamDone { - session_id: "s1".into(), - }); - - assert_eq!(fsm.state, AgentState::Idle { confirmation: None }); - assert_eq!(fsm.ctx.session_id, Some("s1".to_string())); - assert!(effects.iter().any(|e| matches!(e, Effect::Persist))); - // Text was committed to events - assert!(fsm.ctx.events.iter().any(|e| matches!( - e, - ConversationEvent::Text { content } if content == "Hi there!" - ))); -} - -// ============================================================================ -// Tool lifecycle -// ============================================================================ - -#[test] -fn stream_tool_call_tracks_tool_and_emits_check_permission() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("read a file".into())); - fsm.handle(Event::StreamStarted); - - let effects = fsm.handle(Event::StreamToolCall { - id: "t1".into(), - name: "read_file".into(), - input: json!({"file_path": "/tmp/test.txt"}), - }); - - assert!(fsm.ctx.tools.get("t1").is_some()); - assert_eq!(effects.len(), 1); - assert!(matches!(effects[0], Effect::CheckPermission { .. })); -} - -#[test] -fn permission_allowed_transitions_to_executing() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("read".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamToolCall { - id: "t1".into(), - name: "read_file".into(), - input: json!({"file_path": "/tmp/test.txt"}), - }); - - let effects = fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - - assert_eq!(fsm.ctx.tools.get("t1").unwrap().state, ToolState::Executing); - assert!(matches!(effects[0], Effect::ExecuteTool { .. })); -} - -#[test] -fn permission_ask_transitions_to_awaiting() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("read".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamToolCall { - id: "t1".into(), - name: "read_file".into(), - input: json!({"file_path": "/tmp/test.txt"}), - }); - - let effects = fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Ask, - }); - - assert_eq!( - fsm.ctx.tools.get("t1").unwrap().state, - ToolState::AwaitingPermission - ); - assert!(effects.is_empty()); -} - -#[test] -fn tool_done_after_stream_done_continues_conversation() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("read".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamToolCall { - id: "t1".into(), - name: "read_file".into(), - input: json!({"file_path": "/tmp/test.txt"}), - }); - fsm.handle(Event::StreamDone { - session_id: "".into(), - }); - fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - - // Now in Turn { Done } with one tool Executing - let effects = fsm.handle(Event::ToolExecutionDone { - tool_id: "t1".into(), - outcome: crate::tools::ToolOutcome::Success("file contents".into()), - preview: None, - }); - - // Turn complete → continuation - assert!(matches!( - fsm.state, - AgentState::Turn { - stream: StreamPhase::Connecting - } - )); - assert!( - effects - .iter() - .any(|e| matches!(e, Effect::StartStream { .. })) - ); -} - -#[test] -fn continuation_turn_without_new_tools_goes_idle() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("read".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamToolCall { - id: "t1".into(), - name: "read_file".into(), - input: json!({"file_path": "/tmp/test.txt"}), - }); - fsm.handle(Event::StreamDone { - session_id: "".into(), - }); - fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - // Tool completes → continuation starts - fsm.handle(Event::ToolExecutionDone { - tool_id: "t1".into(), - outcome: crate::tools::ToolOutcome::Success("contents".into()), - preview: None, - }); - assert!(matches!( - fsm.state, - AgentState::Turn { - stream: StreamPhase::Connecting - } - )); - - // Continuation stream: text only, no new tools - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamChunk("Here's the file.".into())); - let effects = fsm.handle(Event::StreamDone { - session_id: "".into(), - }); - - // Should go Idle, NOT start another continuation - assert_eq!(fsm.state, AgentState::Idle { confirmation: None }); - assert!(effects.iter().any(|e| matches!(e, Effect::Persist))); - assert!( - !effects - .iter() - .any(|e| matches!(e, Effect::StartStream { .. })) - ); -} - -#[test] -fn tool_done_before_stream_done_stays_in_turn() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("read".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamToolCall { - id: "t1".into(), - name: "read_file".into(), - input: json!({"file_path": "/tmp/test.txt"}), - }); - fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - - // Tool completes but stream hasn't sent Done yet - let effects = fsm.handle(Event::ToolExecutionDone { - tool_id: "t1".into(), - outcome: crate::tools::ToolOutcome::Success("contents".into()), - preview: None, - }); - - // Still in Turn — stream phase is Streaming, not Done - assert!(matches!( - fsm.state, - AgentState::Turn { - stream: StreamPhase::Streaming { .. } - } - )); - assert!(effects.is_empty()); -} - -// ============================================================================ -// Cancel -// ============================================================================ - -#[test] -fn cancel_during_streaming_goes_idle() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("hello".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamChunk("partial text".into())); - - let effects = fsm.handle(Event::Cancel); - - assert_eq!(fsm.state, AgentState::Idle { confirmation: None }); - assert!(effects.iter().any(|e| matches!(e, Effect::AbortStream))); - assert!(effects.iter().any(|e| matches!(e, Effect::Persist))); - // Partial text committed with cancel suffix - assert!(fsm.ctx.events.iter().any(|e| matches!( - e, - ConversationEvent::Text { content } if content.contains("[User cancelled") - ))); -} - -#[test] -fn stale_permission_resolved_after_cancel_is_ignored() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("read".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamToolCall { - id: "t1".into(), - name: "read_file".into(), - input: json!({"file_path": "/tmp/test.txt"}), - }); - fsm.handle(Event::StreamDone { - session_id: "".into(), - }); - // Tool is in CheckingPermission, cancel happens before permission resolves - fsm.handle(Event::Cancel); - assert_eq!(fsm.state, AgentState::Idle { confirmation: None }); - - // Stale permission result arrives — tool is already Completed (cancelled) - let effects = fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - - // Should NOT emit ExecuteTool — the tool was cancelled - assert!(effects.is_empty()); -} - -#[test] -fn cancel_during_turn_with_pending_tools() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("hello".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamToolCall { - id: "t1".into(), - name: "read_file".into(), - input: json!({"file_path": "/tmp/test.txt"}), - }); - fsm.handle(Event::StreamDone { - session_id: "".into(), - }); - fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - // Tool is Executing, stream is Done - - let effects = fsm.handle(Event::Cancel); - - assert_eq!(fsm.state, AgentState::Idle { confirmation: None }); - assert!( - effects - .iter() - .any(|e| matches!(e, Effect::AbortTool { .. })) - ); - // Error ToolResult injected - assert!(fsm.ctx.events.iter().any(|e| matches!( - e, - ConversationEvent::ToolResult { tool_use_id, is_error: true, .. } if tool_use_id == "t1" - ))); - // SystemContext about cancellation - assert!(fsm.ctx.events.iter().any(|e| matches!( - e, - ConversationEvent::SystemContext { content } if content.contains("cancelled") - ))); -} - -#[test] -fn stale_tool_result_after_cancel_is_ignored() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("hello".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamToolCall { - id: "t1".into(), - name: "read_file".into(), - input: json!({"file_path": "/tmp/test.txt"}), - }); - fsm.handle(Event::StreamDone { - session_id: "".into(), - }); - fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - fsm.handle(Event::Cancel); - - // Stale event arrives - let effects = fsm.handle(Event::ToolExecutionDone { - tool_id: "t1".into(), - outcome: crate::tools::ToolOutcome::Success("contents".into()), - preview: None, - }); - - assert_eq!(fsm.state, AgentState::Idle { confirmation: None }); - assert!(effects.is_empty()); -} - -// ============================================================================ -// Confirmation -// ============================================================================ - -#[test] -fn dangerous_command_enters_confirmation() { - let mut fsm = new_fsm(); - // Simulate a dangerous command in history - fsm.ctx.events.push(ConversationEvent::ToolCall { - id: "sc1".into(), - name: "suggest_command".into(), - input: json!({"command": "rm -rf /", "description": "bad", "confidence": "high", "danger": "high"}), - }); - - let effects = fsm.handle(Event::ExecuteCommand); - - assert!(matches!( - fsm.state, - AgentState::Idle { - confirmation: Some(_) - } - )); - assert!( - effects - .iter() - .any(|e| matches!(e, Effect::ScheduleTimeout { .. })) - ); -} - -#[test] -fn second_execute_confirms_and_exits() { - let mut fsm = new_fsm(); - fsm.ctx.events.push(ConversationEvent::ToolCall { - id: "sc1".into(), - name: "suggest_command".into(), - input: json!({"command": "rm -rf /", "description": "bad", "confidence": "high", "danger": "high"}), - }); - fsm.handle(Event::ExecuteCommand); - - let effects = fsm.handle(Event::ExecuteCommand); - - assert!(effects.iter().any(|e| matches!( - e, - Effect::ExitApp(ExitAction::Execute(cmd)) if cmd == "rm -rf /" - ))); -} - -#[test] -fn confirmation_timeout_clears_confirmation() { - let mut fsm = new_fsm(); - fsm.ctx.events.push(ConversationEvent::ToolCall { - id: "sc1".into(), - name: "suggest_command".into(), - input: json!({"command": "rm -rf /", "description": "bad", "confidence": "high", "danger": "high"}), - }); - fsm.handle(Event::ExecuteCommand); - let timeout_id = match &fsm.state { - AgentState::Idle { - confirmation: Some(c), - } => c.timeout_id, - _ => panic!("expected confirmation"), - }; - - fsm.handle(Event::ConfirmationTimeout { timeout_id }); - - assert_eq!(fsm.state, AgentState::Idle { confirmation: None }); -} - -// ============================================================================ -// Error / Retry -// ============================================================================ - -#[test] -fn stream_error_goes_to_error_state() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("hello".into())); - fsm.handle(Event::StreamStarted); - - fsm.handle(Event::StreamError("network error".into())); - - assert_eq!(fsm.state, AgentState::Error("network error".to_string())); -} - -#[test] -fn retry_from_error_starts_new_stream() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("hello".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamError("fail".into())); - - let effects = fsm.handle(Event::Retry); - - assert!(matches!( - fsm.state, - AgentState::Turn { - stream: StreamPhase::Connecting - } - )); - assert!( - effects - .iter() - .any(|e| matches!(e, Effect::StartStream { .. })) - ); -} - -// ============================================================================ -// Permission choices -// ============================================================================ - -#[test] -fn permission_deny_completes_turn_and_continues() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("read".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamToolCall { - id: "t1".into(), - name: "read_file".into(), - input: json!({"file_path": "/tmp/test.txt"}), - }); - fsm.handle(Event::StreamDone { - session_id: "".into(), - }); - fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Ask, - }); - - let effects = fsm.handle(Event::PermissionUserChoice { - tool_id: "t1".into(), - choice: PermissionChoice::Deny, - }); - - // Turn should complete since all tools resolved and stream is done - // → continuation needed (there was a tool result to send back) - assert!(matches!( - fsm.state, - AgentState::Turn { - stream: StreamPhase::Connecting - } - )); - assert!( - effects - .iter() - .any(|e| matches!(e, Effect::StartStream { .. })) - ); - // Error result was injected - assert!(fsm.ctx.events.iter().any(|e| matches!( - e, - ConversationEvent::ToolResult { tool_use_id, is_error: true, .. } if tool_use_id == "t1" - ))); -} - -// ============================================================================ -// Shell execution timeouts -// ============================================================================ - -fn fsm_with_shell() -> AgentFsm { - AgentFsm::new( - vec![ - "client_v1_read_file".to_string(), - "client_v1_execute_shell_command".to_string(), - ], - "test-inv".to_string(), - ) -} - -fn shell_tool_call_event(id: &str) -> Event { - Event::StreamToolCall { - id: id.into(), - name: "execute_shell_command".into(), - input: json!({ - "command": "sleep 999", - "shell": "bash", - "timeout": 60, - "description": "test" - }), - } -} - -#[test] -fn shell_tool_schedules_execution_timeout() { - let mut fsm = fsm_with_shell(); - fsm.handle(Event::UserSubmit("run something".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(shell_tool_call_event("t1")); - - let effects = fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - - // Should have ExecuteTool + ScheduleTimeout - assert!( - effects - .iter() - .any(|e| matches!(e, Effect::ExecuteTool { .. })) - ); - assert!(effects.iter().any(|e| matches!( - e, - Effect::ScheduleTimeout { kind: effects::TimeoutKind::ToolExecution { tool_id }, .. } - if tool_id == "t1" - ))); - assert!(!fsm.ctx.tool_timeout_ids.is_empty()); -} - -#[test] -fn read_tool_does_not_schedule_timeout() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("read".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamToolCall { - id: "t1".into(), - name: "read_file".into(), - input: json!({"file_path": "/tmp/test.txt"}), - }); - - let effects = fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - - assert!( - effects - .iter() - .any(|e| matches!(e, Effect::ExecuteTool { .. })) - ); - assert!( - !effects - .iter() - .any(|e| matches!(e, Effect::ScheduleTimeout { .. })) - ); - assert!(fsm.ctx.tool_timeout_ids.is_empty()); -} - -#[test] -fn tool_completion_clears_timeout_mapping() { - let mut fsm = fsm_with_shell(); - fsm.handle(Event::UserSubmit("run".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(shell_tool_call_event("t1")); - fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - fsm.handle(Event::StreamDone { - session_id: "s1".into(), - }); - - assert!(!fsm.ctx.tool_timeout_ids.is_empty()); - - // Tool completes naturally - fsm.handle(Event::ToolExecutionDone { - tool_id: "t1".into(), - outcome: crate::tools::ToolOutcome::Success("done".into()), - preview: None, - }); - - assert!(fsm.ctx.tool_timeout_ids.is_empty()); -} - -#[test] -fn stale_timeout_after_natural_completion_is_ignored() { - let mut fsm = fsm_with_shell(); - fsm.handle(Event::UserSubmit("run".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(shell_tool_call_event("t1")); - fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - fsm.handle(Event::StreamDone { - session_id: "s1".into(), - }); - - // Tool completes naturally - fsm.handle(Event::ToolExecutionDone { - tool_id: "t1".into(), - outcome: crate::tools::ToolOutcome::Success("done".into()), - preview: None, - }); - - // Stale timeout fires — should be no-op - let effects = fsm.handle(Event::ToolExecutionTimeout { - timeout_id: 0, - tool_id: "t1".into(), - }); - - assert!(effects.is_empty()); -} - -#[test] -fn timeout_fires_before_completion_emits_abort() { - let mut fsm = fsm_with_shell(); - fsm.handle(Event::UserSubmit("run".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(shell_tool_call_event("t1")); - fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - fsm.handle(Event::StreamDone { - session_id: "s1".into(), - }); - - // Timeout fires while tool is still executing - let effects = fsm.handle(Event::ToolExecutionTimeout { - timeout_id: 0, - tool_id: "t1".into(), - }); - - assert_eq!(effects.len(), 1); - assert!(matches!( - effects[0], - Effect::AbortTool { ref tool_id } if tool_id == "t1" - )); - // Timeout mapping cleaned up - assert!(fsm.ctx.tool_timeout_ids.is_empty()); -} - -#[test] -fn timeout_respects_llm_specified_duration() { - let mut fsm = fsm_with_shell(); - fsm.handle(Event::UserSubmit("run".into())); - fsm.handle(Event::StreamStarted); - - // Tool call with timeout: 120 - fsm.handle(Event::StreamToolCall { - id: "t1".into(), - name: "execute_shell_command".into(), - input: json!({ - "command": "cargo build", - "shell": "bash", - "timeout": 120, - "description": "build" - }), - }); - - let effects = fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - - let timeout_effect = effects - .iter() - .find(|e| matches!(e, Effect::ScheduleTimeout { .. })); - assert!(matches!( - timeout_effect, - Some(Effect::ScheduleTimeout { duration, .. }) if *duration == std::time::Duration::from_secs(120) - )); -} - -#[test] -fn cancel_clears_timeout_mappings() { - let mut fsm = fsm_with_shell(); - fsm.handle(Event::UserSubmit("run".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(shell_tool_call_event("t1")); - fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - - assert!(!fsm.ctx.tool_timeout_ids.is_empty()); - - fsm.handle(Event::Cancel); - - assert!(fsm.ctx.tool_timeout_ids.is_empty()); -} - -#[test] -fn timeout_abort_propagates_timeout_reason_to_preview_and_llm() { - use super::tools::InterruptReason; - - let mut fsm = fsm_with_shell(); - fsm.handle(Event::UserSubmit("run".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(shell_tool_call_event("t1")); - fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - fsm.handle(Event::StreamDone { - session_id: "s1".into(), - }); - - // Timeout fires - fsm.handle(Event::ToolExecutionTimeout { - timeout_id: 0, - tool_id: "t1".into(), - }); - - // Tool completes after abort (interrupted: true from execute_shell_command_streaming) - fsm.handle(Event::ToolExecutionDone { - tool_id: "t1".into(), - outcome: crate::tools::ToolOutcome::Structured { - stdout: "partial output".into(), - stderr: String::new(), - exit_code: None, - duration_ms: 60000, - interrupted: true, - }, - preview: Some(super::tools::ToolPreviewData::Shell { - lines: vec!["partial output".into()], - exit_code: None, - interrupted: None, // FSM overrides this with the reason - }), - }); - - // Preview should carry Timeout reason - let tracked = fsm.ctx.tools.get("t1").unwrap(); - let preview = tracked.shell_preview().unwrap(); - assert_eq!(preview.interrupted, Some(InterruptReason::Timeout(60))); - - // LLM content should say "Timed out" not "Interrupted by user" - let tool_result = fsm.ctx.events.iter().find( - |e| matches!(e, ConversationEvent::ToolResult { tool_use_id, .. } if tool_use_id == "t1"), - ); - if let Some(ConversationEvent::ToolResult { content, .. }) = tool_result { - assert!( - content.contains("[Timed out after 60s]"), - "Expected timeout message, got: {content}" - ); - assert!(!content.contains("[Interrupted by user]")); - } else { - panic!("No ToolResult found for t1"); - } -} - -#[test] -fn user_interrupt_propagates_user_reason_to_preview_and_llm() { - use super::tools::InterruptReason; - - let mut fsm = fsm_with_shell(); - fsm.handle(Event::UserSubmit("run".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(shell_tool_call_event("t1")); - fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - fsm.handle(Event::StreamDone { - session_id: "s1".into(), - }); - - // User interrupts - fsm.handle(Event::InterruptTools); - - // Tool completes after abort - fsm.handle(Event::ToolExecutionDone { - tool_id: "t1".into(), - outcome: crate::tools::ToolOutcome::Structured { - stdout: "partial".into(), - stderr: String::new(), - exit_code: None, - duration_ms: 5000, - interrupted: true, - }, - preview: Some(super::tools::ToolPreviewData::Shell { - lines: vec!["partial".into()], - exit_code: None, - interrupted: None, // FSM overrides this with the reason - }), - }); - - // Preview should carry User reason - let tracked = fsm.ctx.tools.get("t1").unwrap(); - let preview = tracked.shell_preview().unwrap(); - assert_eq!(preview.interrupted, Some(InterruptReason::User)); - - // LLM content should say "Interrupted by user" - let tool_result = fsm.ctx.events.iter().find( - |e| matches!(e, ConversationEvent::ToolResult { tool_use_id, .. } if tool_use_id == "t1"), - ); - if let Some(ConversationEvent::ToolResult { content, .. }) = tool_result { - assert!( - content.contains("[Interrupted by user]"), - "Expected user interrupt message, got: {content}" - ); - } else { - panic!("No ToolResult found for t1"); - } -} - -#[test] -fn user_interrupt_clears_timeout_mappings_for_aborted_tools() { - let mut fsm = fsm_with_shell(); - fsm.handle(Event::UserSubmit("run".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(shell_tool_call_event("t1")); - fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - - assert!(!fsm.ctx.tool_timeout_ids.is_empty()); - - fsm.handle(Event::InterruptTools); - - assert!(fsm.ctx.tool_timeout_ids.is_empty()); -} diff --git a/crates/atuin-ai/src/fsm/tools.rs b/crates/atuin-ai/src/fsm/tools.rs deleted file mode 100644 index 96348672..00000000 --- a/crates/atuin-ai/src/fsm/tools.rs +++ /dev/null @@ -1,178 +0,0 @@ -//! Tool lifecycle management within the FSM. -//! -//! Each tool call goes through an independent lifecycle. The ToolManager -//! tracks all tools in the current turn and provides the "all resolved" -//! check that gates turn completion. - -use crate::diff::{EditPreview, WritePreview}; -use crate::tools::ClientToolCall; - -/// Why a tool execution was interrupted. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum InterruptReason { - /// User pressed Ctrl+C or Esc during execution. - User, - /// The LLM-specified execution timeout expired. - Timeout(u64), -} - -/// Per-tool lifecycle state. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum ToolState { - /// Permission resolver is running asynchronously. - CheckingPermission, - /// Waiting for user to grant/deny via the permission dialog. - AwaitingPermission, - /// Actively executing. - Executing, - /// Execution completed (result injected into conversation). - Completed, - /// User denied permission (error result injected into conversation). - Denied, -} - -/// Cached preview data for rendering tool output. -#[derive(Debug, Clone)] -pub(crate) enum ToolPreviewData { - /// Shell command VT100 output lines. - Shell { - lines: Vec, - exit_code: Option, - interrupted: Option, - }, - /// File edit diff preview. - Edit(EditPreview), - /// File write content preview. - Write(WritePreview), -} - -/// A tracked tool call with its current lifecycle state. -#[derive(Debug, Clone)] -pub(crate) struct TrackedTool { - pub id: String, - pub tool: ClientToolCall, - pub state: ToolState, - /// Cached preview data for rendering (populated during/after execution). - pub preview: Option, - /// Set by the FSM when it emits AbortTool, so that ToolExecutionDone - /// can distinguish user interrupts from timeouts. - pub interrupt_reason: Option, -} - -impl TrackedTool { - /// Whether this tool has reached a terminal state. - pub fn is_resolved(&self) -> bool { - matches!(self.state, ToolState::Completed | ToolState::Denied) - } - - /// Extract shell preview data (for TurnBuilder compatibility). - pub fn shell_preview(&self) -> Option { - match &self.preview { - Some(ToolPreviewData::Shell { - lines, - exit_code, - interrupted, - }) => Some(crate::tools::ToolPreview { - lines: lines.clone(), - exit_code: *exit_code, - interrupted: interrupted.clone(), - }), - _ => None, - } - } - - /// Extract edit diff preview (for TurnBuilder compatibility). - pub fn edit_preview(&self) -> Option<&EditPreview> { - match &self.preview { - Some(ToolPreviewData::Edit(p)) => Some(p), - _ => None, - } - } - - /// Extract write content preview (for TurnBuilder compatibility). - pub fn write_preview(&self) -> Option<&WritePreview> { - match &self.preview { - Some(ToolPreviewData::Write(p)) => Some(p), - _ => None, - } - } -} - -/// Manages tool call lifecycles for a single turn. -/// -/// Tools are inserted when received from the stream and progress through -/// their lifecycle independently. The manager provides aggregate queries -/// (all resolved, any awaiting permission, etc.) that the FSM uses for -/// state transitions. -#[derive(Debug, Clone, Default)] -pub(crate) struct ToolManager { - tools: Vec, -} - -impl ToolManager { - pub fn new() -> Self { - Self { tools: Vec::new() } - } - - /// Insert a new tool in CheckingPermission state. - pub fn insert(&mut self, id: String, tool: ClientToolCall) { - self.tools.push(TrackedTool { - id, - tool, - state: ToolState::CheckingPermission, - preview: None, - interrupt_reason: None, - }); - } - - /// Look up a tool by ID. - pub fn get(&self, id: &str) -> Option<&TrackedTool> { - self.tools.iter().find(|t| t.id == id) - } - - /// Look up a tool mutably by ID. - pub fn get_mut(&mut self, id: &str) -> Option<&mut TrackedTool> { - self.tools.iter_mut().find(|t| t.id == id) - } - - /// True if all tools from the given set of IDs have reached a terminal state. - /// Returns true for an empty set (vacuously — no tools to wait for). - pub fn all_resolved(&self, tool_ids: &[String]) -> bool { - tool_ids - .iter() - .all(|id| self.get(id).is_some_and(|t| t.is_resolved())) - } - - /// Find the first tool awaiting user permission. - pub fn awaiting_permission(&self) -> Option<&TrackedTool> { - self.tools - .iter() - .find(|t| t.state == ToolState::AwaitingPermission) - } - - /// Get IDs of all non-resolved tools (for cancel). - pub fn pending_ids(&self) -> Vec { - self.tools - .iter() - .filter(|t| !t.is_resolved()) - .map(|t| t.id.clone()) - .collect() - } - - /// Get IDs of all currently executing tools (for interrupt/abort). - pub fn executing_ids(&self) -> Vec { - self.tools - .iter() - .filter(|t| t.state == ToolState::Executing) - .map(|t| t.id.clone()) - .collect() - } - - /// True if any tool has a shell preview with live output. - pub fn has_executing_preview(&self) -> bool { - self.tools.iter().any(|t| { - t.state == ToolState::Executing - && matches!(t.preview, Some(ToolPreviewData::Shell { .. })) - }) - } -} diff --git a/crates/atuin-ai/src/history_format.rs b/crates/atuin-ai/src/history_format.rs deleted file mode 100644 index 24aa963e..00000000 --- a/crates/atuin-ai/src/history_format.rs +++ /dev/null @@ -1,120 +0,0 @@ -use atuin_client::history::History; -use time::UtcOffset; - -pub(crate) fn current_local_offset() -> UtcOffset { - UtcOffset::current_local_offset().unwrap_or(UtcOffset::UTC) -} - -pub(crate) fn format_last_command(history: &History, local_offset: UtcOffset) -> String { - format!( - "History ID: {} - `{}`\n{}", - history.id, - history.command, - format_history_metadata(history, local_offset) - ) -} - -pub(crate) fn format_history_search_result( - ordinal: usize, - history: &History, - local_offset: UtcOffset, -) -> String { - format!( - "## #{}. (History ID: {}):\n`{}`\n{}\n", - ordinal, - history.id, - history.command, - format_history_metadata(history, local_offset) - ) -} - -fn format_history_metadata(history: &History, local_offset: UtcOffset) -> String { - format!( - "[{}] (in `{}`, exit {}){}", - format_timestamp(history, local_offset), - history.cwd, - history.exit, - format_duration(history.duration) - ) -} - -fn format_timestamp(history: &History, local_offset: UtcOffset) -> String { - let ts = history.timestamp.to_offset(local_offset); - format!( - "{:04}-{:02}-{:02} {:02}:{:02}:{:02}", - ts.year(), - ts.month() as u8, - ts.day(), - ts.hour(), - ts.minute(), - ts.second(), - ) -} - -fn format_duration(nanos: i64) -> String { - if nanos <= 0 { - return String::new(); - } - - let total_secs = nanos / 1_000_000_000; - let millis = (nanos % 1_000_000_000) / 1_000_000; - - if total_secs >= 3600 { - let hours = total_secs / 3600; - let mins = (total_secs % 3600) / 60; - let secs = total_secs % 60; - format!(", {hours}h{mins}m{secs}s") - } else if total_secs >= 60 { - let mins = total_secs / 60; - let secs = total_secs % 60; - format!(", {mins}m{secs}s") - } else if total_secs > 0 { - if millis > 0 { - format!(", {total_secs}.{millis:03}s") - } else { - format!(", {total_secs}s") - } - } else { - format!(", {millis}ms") - } -} - -#[cfg(test)] -mod tests { - use atuin_client::history::{History, HistoryId}; - use time::{OffsetDateTime, UtcOffset}; - - use super::*; - - fn history(duration: i64) -> History { - History { - id: HistoryId("018f011c-9a0a-7000-8000-000000000001".to_string()), - timestamp: OffsetDateTime::UNIX_EPOCH, - duration, - exit: 2, - command: "cargo test".to_string(), - cwd: "/repo".to_string(), - session: String::new(), - hostname: String::new(), - author: String::new(), - intent: None, - deleted_at: None, - } - } - - #[test] - fn formats_last_command() { - assert_eq!( - format_last_command(&history(1_234_000_000), UtcOffset::UTC), - "History ID: 018f011c-9a0a-7000-8000-000000000001 - `cargo test`\n[1970-01-01 00:00:00] (in `/repo`, exit 2), 1.234s" - ); - } - - #[test] - fn formats_history_search_result() { - assert_eq!( - format_history_search_result(3, &history(0), UtcOffset::UTC), - "## #3. (History ID: 018f011c-9a0a-7000-8000-000000000001):\n`cargo test`\n[1970-01-01 00:00:00] (in `/repo`, exit 2)\n" - ); - } -} diff --git a/crates/atuin-ai/src/lib.rs b/crates/atuin-ai/src/lib.rs deleted file mode 100644 index f972d4ff..00000000 --- a/crates/atuin-ai/src/lib.rs +++ /dev/null @@ -1,19 +0,0 @@ -pub mod commands; -pub(crate) mod context; -pub(crate) mod context_window; -pub(crate) mod diff; -pub(crate) mod driver; -pub(crate) mod edit_permissions; -pub(crate) mod event_serde; -pub(crate) mod file_tracker; -pub(crate) mod fsm; -pub(crate) mod history_format; -pub(crate) mod permissions; -pub(crate) mod session; -pub(crate) mod skills; -pub(crate) mod snapshots; -pub(crate) mod store; -pub(crate) mod stream; -pub(crate) mod tools; -pub(crate) mod tui; -pub(crate) mod user_context; diff --git a/crates/atuin-ai/src/permissions/check.rs b/crates/atuin-ai/src/permissions/check.rs deleted file mode 100644 index bb1eae0c..00000000 --- a/crates/atuin-ai/src/permissions/check.rs +++ /dev/null @@ -1,71 +0,0 @@ -use eyre::Result; - -use crate::{permissions::file::RuleFile, tools::PermissibleToolCall}; - -pub(crate) struct PermissionRequest<'t> { - call: &'t (dyn PermissibleToolCall + Send + Sync), -} - -impl<'t> PermissionRequest<'t> { - pub fn new(call: &'t (dyn PermissibleToolCall + Send + Sync)) -> Self { - Self { call } - } -} - -pub(crate) enum PermissionResponse { - Allowed, - Denied, - Ask, -} - -pub(crate) struct PermissionChecker { - files: Vec, -} - -impl PermissionChecker { - pub fn new(files: Vec) -> Self { - Self { files } - } - - pub async fn check<'t>( - &self, - request: &'t PermissionRequest<'t>, - ) -> Result { - // Files are in order from deepest to shallowest, so we can stop at the first match. - // Within a file, the priority is ask -> deny -> allow - // The first rule type that matches is the one that applies, even if a later rule would contradict it. - for file in &self.files { - for rule in &file.content.permissions.ask { - if request.call.matches_rule(rule) { - tracing::debug!( - "Permission 'ASK' by rule: {} in file: {}", - rule, - file.path.display() - ); - return Ok(PermissionResponse::Ask); - } - } - - for rule in &file.content.permissions.deny { - if request.call.matches_rule(rule) { - tracing::debug!( - "Permission 'DENY' by rule: {} in file: {}", - rule, - file.path.display() - ); - return Ok(PermissionResponse::Denied); - } - } - - if request.call.all_covered_by(&file.content.permissions.allow) { - tracing::debug!( - "Permission 'ALLOW' by rules in file: {}", - file.path.display() - ); - return Ok(PermissionResponse::Allowed); - } - } - - Ok(PermissionResponse::Ask) - } -} diff --git a/crates/atuin-ai/src/permissions/file.rs b/crates/atuin-ai/src/permissions/file.rs deleted file mode 100644 index c973f55b..00000000 --- a/crates/atuin-ai/src/permissions/file.rs +++ /dev/null @@ -1,26 +0,0 @@ -use std::path::PathBuf; - -use serde::{Deserialize, Serialize}; - -use crate::permissions::rule::Rule; - -#[derive(Debug, Clone)] -pub(crate) struct RuleFile { - pub path: PathBuf, - pub content: RuleFileContent, -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub(crate) struct RuleFileContent { - pub permissions: RuleFilePermissions, -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub(crate) struct RuleFilePermissions { - #[serde(default)] - pub allow: Vec, - #[serde(default)] - pub deny: Vec, - #[serde(default)] - pub ask: Vec, -} diff --git a/crates/atuin-ai/src/permissions/mod.rs b/crates/atuin-ai/src/permissions/mod.rs deleted file mode 100644 index fce64a51..00000000 --- a/crates/atuin-ai/src/permissions/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub(crate) mod check; -pub(crate) mod file; -pub(crate) mod resolver; -pub(crate) mod rule; -pub(crate) mod shell; -pub(crate) mod walker; -pub(crate) mod writer; diff --git a/crates/atuin-ai/src/permissions/resolver.rs b/crates/atuin-ai/src/permissions/resolver.rs deleted file mode 100644 index dc4f83bf..00000000 --- a/crates/atuin-ai/src/permissions/resolver.rs +++ /dev/null @@ -1,31 +0,0 @@ -use std::path::PathBuf; - -use eyre::Result; - -use crate::permissions::check::{PermissionChecker, PermissionRequest, PermissionResponse}; -use crate::permissions::walker::PermissionWalker; -use crate::permissions::writer; -use crate::tools::ClientToolCall; - -/// Resolves permissions for client tool calls by walking the filesystem to find permission files, -pub(crate) struct PermissionResolver { - checker: PermissionChecker, -} - -impl PermissionResolver { - /// Create a new resolver that walks from `working_dir` to root for project - /// permissions, and also checks the global permissions file. - pub async fn new(working_dir: PathBuf) -> Result { - let global_file = writer::global_permissions_path(); - let mut walker = PermissionWalker::new(working_dir, Some(global_file)); - walker.walk().await?; - let checker = PermissionChecker::new(walker.rules().to_owned()); - Ok(Self { checker }) - } - - /// Check whether `tool` is allowed, denied, or needs user confirmation. - pub async fn check(&self, tool: &ClientToolCall) -> Result { - let request = PermissionRequest::new(tool); - self.checker.check(&request).await - } -} diff --git a/crates/atuin-ai/src/permissions/rule.rs b/crates/atuin-ai/src/permissions/rule.rs deleted file mode 100644 index 8fa3fa4a..00000000 --- a/crates/atuin-ai/src/permissions/rule.rs +++ /dev/null @@ -1,106 +0,0 @@ -use std::sync::OnceLock; - -use regex::Regex; -use serde::{Deserialize, Serialize}; - -static RULE_RE: OnceLock = OnceLock::new(); - -#[derive(Debug, thiserror::Error)] -pub(crate) enum RuleError { - #[error("invalid rule format: {0}")] - InvalidRule(String), -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) struct Rule { - pub tool: String, - pub scope: Option, -} - -impl std::fmt::Display for Rule { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self.scope.as_ref() { - Some(scope) => write!(f, "{}({})", self.tool, scope), - None => write!(f, "{}", self.tool), - } - } -} - -impl Serialize for Rule { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - serializer.serialize_str(&self.to_string()) - } -} - -impl<'de> Deserialize<'de> for Rule { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let s = String::deserialize(deserializer)?; - Self::try_from(s.as_str()).map_err(serde::de::Error::custom) - } -} -impl TryFrom<&str> for Rule { - type Error = RuleError; - - fn try_from(value: &str) -> Result { - let value = value.trim(); - let re = RULE_RE.get_or_init(|| Regex::new(r"^(\w+)(?:\((.*)\))?$").unwrap()); - let caps = re - .captures(value) - .ok_or(RuleError::InvalidRule(value.to_string()))?; - let tool = caps.get(1).unwrap().as_str().to_string(); - let scope = caps.get(2).map(|m| m.as_str().to_string()); - Ok(Rule { tool, scope }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_rule_try_from() { - assert_eq!( - Rule::try_from("Read").unwrap(), - Rule { - tool: "Read".to_string(), - scope: None - } - ); - assert_eq!( - Rule::try_from("Read(*)").unwrap(), - Rule { - tool: "Read".to_string(), - scope: Some("*".to_string()) - } - ); - assert_eq!( - Rule::try_from("Write(*.md)").unwrap(), - Rule { - tool: "Write".to_string(), - scope: Some("*.md".to_string()) - } - ); - assert_eq!( - Rule::try_from("Shell(git commit *)").unwrap(), - Rule { - tool: "Shell".to_string(), - scope: Some("git commit *".to_string()) - } - ); - assert_eq!( - Rule::try_from("Shell(echo ())").unwrap(), - Rule { - tool: "Shell".to_string(), - scope: Some("echo ()".to_string()) - } - ); - assert!(Rule::try_from("Shell(git commit *").is_err()); - assert!(Rule::try_from("Shell(git commit *)!").is_err()); - } -} diff --git a/crates/atuin-ai/src/permissions/shell.rs b/crates/atuin-ai/src/permissions/shell.rs deleted file mode 100644 index 29b9f5d8..00000000 --- a/crates/atuin-ai/src/permissions/shell.rs +++ /dev/null @@ -1,1335 +0,0 @@ -/// Extracted command info from a shell command string. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) struct ShellCommand { - /// The command name (first word), e.g. "git" - pub name: String, - /// The full invocation including arguments, e.g. "git commit -m msg" - pub full: String, -} - -/// A parsed shell command with all subcommands extracted. -#[derive(Debug)] -pub(crate) struct ParsedShellCommand { - pub subcommands: Vec, -} - -/// Supported shell families for parsing. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub(crate) enum ShellKind { - /// POSIX sh, bash, zsh — all share similar syntax - Posix, - /// fish shell - Fish, - /// nushell or unknown — fallback to word-level extraction - Other, -} - -impl ShellKind { - pub(crate) fn from_shell_name(name: &str) -> Self { - match name { - "bash" | "sh" | "zsh" | "dash" | "ksh" => Self::Posix, - "fish" => Self::Fish, - _ => Self::Other, - } - } -} - -/// Parse a shell command string and extract all subcommands. -pub(crate) fn parse_shell_command(code: &str, shell: ShellKind) -> ParsedShellCommand { - #[cfg(feature = "tree-sitter")] - match shell { - ShellKind::Posix => ts::parse_posix(code), - ShellKind::Fish => ts::parse_fish(code), - ShellKind::Other => parse_fallback(code), - } - - #[cfg(not(feature = "tree-sitter"))] - { - let _ = shell; - parse_fallback(code) - } -} - -// ──────────────────────────────────────────────────────────────── -// Tree-sitter parsers (POSIX + Fish) -// Disabled on platforms where tree-sitter doesn't cross-compile -// (e.g. Windows); falls back to word-level extraction. -// ──────────────────────────────────────────────────────────────── - -#[cfg(feature = "tree-sitter")] -mod ts { - use super::{ParsedShellCommand, ShellCommand, parse_fallback}; - use tree_sitter_lib::{Parser, Tree}; - - fn bash_parser() -> Parser { - let mut parser = Parser::new(); - parser - .set_language(&tree_sitter_bash::LANGUAGE.into()) - .expect("failed to set bash language"); - parser - } - - pub(super) fn parse_posix(code: &str) -> ParsedShellCommand { - let mut parser = bash_parser(); - let Some(tree) = parser.parse(code, None) else { - return parse_fallback(code); - }; - - let mut commands = Vec::new(); - walk_bash_tree(&tree, code.as_bytes(), &mut commands); - ParsedShellCommand { - subcommands: commands, - } - } - - /// Leaf node kinds that never contain nested commands. - const BASH_LEAVES: &[&str] = &[ - "command_name", - "word", - "number", - "simple_expansion", - "expansion", - "arithmetic_expansion", - "ansi_c_string", - "special_variable_name", - "variable_name", - "file_descriptor", - "heredoc_body", - "heredoc_start", - "regex", - "heredoc_redirect", - ]; - - fn walk_bash_tree(tree: &Tree, source: &[u8], commands: &mut Vec) { - walk_bash_node(tree.root_node(), source, commands); - } - - fn walk_bash_node( - node: tree_sitter_lib::Node, - source: &[u8], - commands: &mut Vec, - ) { - match node.kind() { - "command" => { - if let Some(cmd) = extract_bash_command(node, source) { - commands.push(cmd); - } - // Descend into all non-leaf children to find nested commands - // (e.g. command_substitution inside a string inside a command) - let mut cursor = node.walk(); - for child in node.children(&mut cursor) { - if !BASH_LEAVES.contains(&child.kind()) { - walk_bash_node(child, source, commands); - } - } - } - // Other nodes: descend into all children - _ => { - let mut cursor = node.walk(); - for child in node.children(&mut cursor) { - walk_bash_node(child, source, commands); - } - } - } - } - - /// Extract the full command string and name from a bash `command` node. - fn extract_bash_command(node: tree_sitter_lib::Node, source: &[u8]) -> Option { - // A `command` node has children like: - // variable_assignment* command_name argument* redirect* - // We want the command_name and all arguments (skipping assignments and redirects). - let mut name = None; - let mut name_start = None; - let mut arg_end = None; - - let mut cursor = node.walk(); - for child in node.children(&mut cursor) { - match child.kind() { - "command_name" => { - name = child.utf8_text(source).ok().map(|s| s.to_string()); - name_start = Some(child.start_byte()); - } - "word" - | "string" - | "raw_string" - | "concatenation" - | "number" - | "simple_expansion" - | "expansion" - | "arithmetic_expansion" - | "ansi_c_string" - | "process_substitution" => { - arg_end = Some(child.end_byte()); - } - _ => {} - } - } - - let name = name?; - let full = if let (Some(start), Some(end)) = (name_start, arg_end) { - std::str::from_utf8(&source[start..end]).ok()?.to_string() - } else { - name.clone() - }; - - Some(ShellCommand { name, full }) - } - - // ──────────────────────────────────────────────────────────────── - // Fish parser - // ──────────────────────────────────────────────────────────────── - - fn fish_parser() -> Parser { - let mut parser = Parser::new(); - parser - .set_language(&tree_sitter_fish::language()) - .expect("failed to set fish language"); - parser - } - - pub(super) fn parse_fish(code: &str) -> ParsedShellCommand { - let mut parser = fish_parser(); - let Some(tree) = parser.parse(code, None) else { - return parse_fallback(code); - }; - - let mut commands = Vec::new(); - walk_fish_tree(&tree, code.as_bytes(), &mut commands); - ParsedShellCommand { - subcommands: commands, - } - } - - const FISH_COMPOUND: &[&str] = &[ - "conditional_execution", - "pipe", - "job", - "command_substitution", - "block", - "for_statement", - "while_statement", - "if_statement", - "switch_statement", - "function_definition", - "begin_statement", - "redirected_statement", - ]; - - fn walk_fish_tree(tree: &Tree, source: &[u8], commands: &mut Vec) { - walk_fish_node(tree.root_node(), source, commands); - } - - fn walk_fish_node( - node: tree_sitter_lib::Node, - source: &[u8], - commands: &mut Vec, - ) { - match node.kind() { - "command" => { - if let Some(cmd) = extract_fish_command(node, source) { - commands.push(cmd); - } - // Still descend into compound children (e.g. command_substitution inside a command) - let mut cursor = node.walk(); - for child in node.children(&mut cursor) { - if FISH_COMPOUND.contains(&child.kind()) { - walk_fish_node(child, source, commands); - } - } - } - // Other nodes: descend into all children - _ => { - let mut cursor = node.walk(); - for child in node.children(&mut cursor) { - walk_fish_node(child, source, commands); - } - } - } - } - - fn extract_fish_command(node: tree_sitter_lib::Node, source: &[u8]) -> Option { - // In fish, a `command` node has: - // name (command_name or word) followed by arguments (word, string, etc.) - let mut name = None; - - let mut cursor = node.walk(); - for child in node.children(&mut cursor) { - match child.kind() { - "command_name" | "word" => { - let text = child.utf8_text(source).ok()?.to_string(); - if name.is_none() { - name = Some(text); - } - } - "string" - | "concatenation" - | "command_substitution" - | "escape_sequence" - | "double_quote_string" - | "single_quote_string" => {} - _ => {} - } - } - - let name = name?; - // Get the full text of the command node - let full = node.utf8_text(source).ok()?.trim().to_string(); - - Some(ShellCommand { name, full }) - } -} // mod ts - -// ──────────────────────────────────────────────────────────────── -// Fallback (word-level extraction for nushell / unknown shells) -// ──────────────────────────────────────────────────────────────── - -fn parse_fallback(code: &str) -> ParsedShellCommand { - // Simple heuristic: split by &&, ||, ;, | and take the first word of each segment. - // This is intentionally simple — for unknown shells we can't do better. - let mut commands = Vec::new(); - let mut segment = String::new(); - let mut chars = code.chars().peekable(); - - while let Some(c) = chars.next() { - match c { - ';' => { - push_segment(&mut segment, &mut commands); - } - '|' => { - if chars.peek() == Some(&'|') { - chars.next(); - } - push_segment(&mut segment, &mut commands); - } - '&' if chars.peek() == Some(&'&') => { - chars.next(); - push_segment(&mut segment, &mut commands); - } - _ => segment.push(c), - } - } - push_segment(&mut segment, &mut commands); - - ParsedShellCommand { - subcommands: commands, - } -} - -fn push_segment(segment: &mut String, commands: &mut Vec) { - let trimmed = segment.trim(); - if !trimmed.is_empty() - && let Some(name) = trimmed.split_whitespace().next() - { - commands.push(ShellCommand { - name: name.to_string(), - full: trimmed.to_string(), - }); - } - segment.clear(); -} - -// ──────────────────────────────────────────────────────────────── -// Scope matching -// ──────────────────────────────────────────────────────────────── - -/// Check if any of the extracted subcommands match the given scope pattern. -/// -/// Matching semantics depend on where the `*` wildcard appears: -/// - `*` alone — matches everything -/// - `ls *` (space before `*`) — matches `ls` and `ls -a` but not `lsof` -/// - `git commit *` — matches `git commit -m "msg"` (word boundary) -/// - `ls*` (no space before `*`) — matches `lsof`, `ls`, `ls -a` (prefix/glob) -/// - `rm` (no wildcard) — matches exactly `rm` -/// - `git * amend` — matches `git commit amend` (middle wildcard matches zero+ words) -/// -/// When `prefix_bare` is true, a bare pattern without wildcards (e.g. `rm`) -/// uses word-boundary prefix matching — `rm` matches `rm -rf /`. When false, -/// bare patterns require an exact match — `rm` only matches `rm`. -/// -/// Allow rules should pass `prefix_bare: false` (strict), while deny/ask rules -/// should pass `prefix_bare: true` (broad) so that denying `rm` also blocks -/// `rm -rf /`. -pub(crate) fn any_subcommand_matches( - subcommands: &[ShellCommand], - prefix_bare: bool, - scope: &str, -) -> bool { - let scope = scope.trim(); - - if scope.is_empty() || scope == "*" { - return true; - } - - if let Some(prefix) = scope.strip_suffix(" *") { - // Word-boundary matching: `ls *` matches `ls` and `ls -a` but not `lsof` - return subcommands.iter().any(|cmd| { - if prefix.is_empty() { - return true; - } - let cmd_words: Vec<&str> = cmd.full.split_whitespace().collect(); - let prefix_words: Vec<&str> = prefix.split_whitespace().collect(); - cmd_words.len() >= prefix_words.len() - && cmd_words[..prefix_words.len()] == prefix_words[..] - }); - } - - if let Some(prefix) = scope.strip_suffix('*') { - // Prefix/glob matching: `ls*` matches `lsof`, `ls`, etc. - return subcommands.iter().any(|cmd| cmd.full.starts_with(prefix)); - } - - if scope.contains('*') { - // Middle wildcard: `git * amend` — each `*` matches zero or more words - return subcommands - .iter() - .any(|cmd| scope_matches_words(scope, cmd.full.split_whitespace().collect())); - } - - // No wildcard: exact or prefix depending on context - let scope_words: Vec<&str> = scope.split_whitespace().collect(); - subcommands.iter().any(|cmd| { - let cmd_words: Vec<&str> = cmd.full.split_whitespace().collect(); - if prefix_bare { - cmd_words.len() >= scope_words.len() - && cmd_words[..scope_words.len()] == scope_words[..] - } else { - cmd_words == scope_words - } - }) -} - -/// Match a scope pattern containing `*` wildcards against a sequence of words. -/// Each `*` matches zero or more words. Consecutive `*` collapse into one. -fn scope_matches_words(scope: &str, words: Vec<&str>) -> bool { - let parts: Vec<&str> = scope.split('*').collect(); - if parts.len() == 1 { - // No wildcard (shouldn't reach here, but handle it) - let scope_words: Vec<&str> = scope.split_whitespace().collect(); - return words.len() >= scope_words.len() && words[..scope_words.len()] == scope_words[..]; - } - - // Each segment between * is a sequence of literal words that must appear in order. - // Walk through `words` consuming segments left to right. - let mut word_idx = 0; - - for (i, part) in parts.iter().enumerate() { - let segment_words: Vec<&str> = part.split_whitespace().collect(); - if segment_words.is_empty() { - continue; - } - - // Find the segment words starting from word_idx - if i == 0 { - // First segment must match at the start - if words.len() < segment_words.len() - || words[..segment_words.len()] != segment_words[..] - { - return false; - } - word_idx = segment_words.len(); - } else if i == parts.len() - 1 { - // Last segment must match at the end - if words.len() - word_idx < segment_words.len() { - return false; - } - let start = words.len() - segment_words.len(); - return words[start..] == segment_words[..]; - } else { - // Middle segment: find it anywhere after word_idx - let found = find_subslice(&words[word_idx..], &segment_words); - match found { - Some(idx) => word_idx += idx + segment_words.len(), - None => return false, - } - } - } - - true -} - -/// Find the first occurrence of `needle` as a contiguous subsequence in `haystack`. -fn find_subslice(haystack: &[&str], needle: &[&str]) -> Option { - if needle.is_empty() { - return Some(0); - } - if haystack.len() < needle.len() { - return None; - } - (0..=haystack.len() - needle.len()).find(|&i| haystack[i..i + needle.len()] == needle[..]) -} - -// ──────────────────────────────────────────────────────────────── -// Tests -// ──────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - - fn names(cmds: &[ShellCommand]) -> Vec<&str> { - cmds.iter().map(|c| c.name.as_str()).collect() - } - - fn fulls(cmds: &[ShellCommand]) -> Vec<&str> { - cmds.iter().map(|c| c.full.as_str()).collect() - } - - #[test] - fn simple_command() { - let result = parse_shell_command("ls -la /tmp", ShellKind::Posix); - assert_eq!(names(&result.subcommands), vec!["ls"]); - assert_eq!(fulls(&result.subcommands), vec!["ls -la /tmp"]); - } - - #[test] - fn pipeline() { - let result = parse_shell_command("cat file.txt | grep foo | wc -l", ShellKind::Posix); - assert_eq!(names(&result.subcommands), vec!["cat", "grep", "wc"]); - } - - #[test] - fn command_chaining() { - let result = parse_shell_command("git add . && git commit -m 'hi'", ShellKind::Posix); - assert_eq!(names(&result.subcommands), vec!["git", "git"]); - assert_eq!( - fulls(&result.subcommands), - vec!["git add .", "git commit -m 'hi'"] - ); - } - - #[cfg(feature = "tree-sitter")] - #[test] - fn command_substitution() { - let result = parse_shell_command("echo $(git rev-parse HEAD)", ShellKind::Posix); - let n = names(&result.subcommands); - assert!(n.contains(&"echo"), "should contain echo: {n:?}"); - assert!(n.contains(&"git"), "should contain git: {n:?}"); - } - - #[cfg(feature = "tree-sitter")] - #[test] - fn backtick_substitution() { - let result = parse_shell_command("echo `date`", ShellKind::Posix); - let n = names(&result.subcommands); - assert!(n.contains(&"echo"), "should contain echo: {n:?}"); - assert!(n.contains(&"date"), "should contain date: {n:?}"); - } - - #[cfg(feature = "tree-sitter")] - #[test] - fn subshell() { - let result = parse_shell_command("(cd /tmp && ls)", ShellKind::Posix); - assert_eq!(names(&result.subcommands), vec!["cd", "ls"]); - } - - #[test] - fn semicolon_separated() { - let result = parse_shell_command("echo hello; echo world", ShellKind::Posix); - assert_eq!(names(&result.subcommands), vec!["echo", "echo"]); - } - - #[cfg(feature = "tree-sitter")] - #[test] - fn for_loop() { - let result = parse_shell_command("for f in *.txt; do cat $f; done", ShellKind::Posix); - assert!(names(&result.subcommands).contains(&"cat")); - } - - #[cfg(feature = "tree-sitter")] - #[test] - fn if_statement() { - let result = parse_shell_command( - "if [ -f foo ]; then cat foo; else echo nope; fi", - ShellKind::Posix, - ); - let n = names(&result.subcommands); - assert!(n.contains(&"cat"), "should contain cat: {n:?}"); - assert!(n.contains(&"echo"), "should contain echo: {n:?}"); - } - - #[test] - fn scope_matching_wildcard() { - let commands = vec![ - ShellCommand { - name: "git".into(), - full: "git commit -m msg".into(), - }, - ShellCommand { - name: "npm".into(), - full: "npm test".into(), - }, - ]; - assert!(any_subcommand_matches(&commands, true, "*")); - } - - #[test] - fn scope_matching_prefix() { - let commands = vec![ - ShellCommand { - name: "git".into(), - full: "git commit -m msg".into(), - }, - ShellCommand { - name: "npm".into(), - full: "npm test".into(), - }, - ]; - assert!(any_subcommand_matches(&commands, true, "git commit *")); - assert!(!any_subcommand_matches(&commands, true, "git push *")); - assert!(!any_subcommand_matches(&commands, true, "git push")); - assert!(any_subcommand_matches(&commands, true, "npm *")); - assert!(any_subcommand_matches(&commands, true, "npm test")); - - // prefix_bare=true: bare "git commit" prefix-matches "git commit -m msg" (deny/ask) - assert!(any_subcommand_matches(&commands, true, "git commit")); - // prefix_bare=false: bare "git commit" does NOT match "git commit -m msg" (allow) - assert!(!any_subcommand_matches(&commands, false, "git commit")); - // Exact match works in both modes when command has no extra args - assert!(any_subcommand_matches(&commands, false, "npm test")); - } - - #[test] - fn scope_word_boundary_vs_glob() { - let commands = vec![ - ShellCommand { - name: "ls".into(), - full: "ls -a".into(), - }, - ShellCommand { - name: "lsof".into(), - full: "lsof -i :3000".into(), - }, - ]; - // `ls *` — word boundary: matches `ls -a` but not `lsof` - assert!(any_subcommand_matches(&commands, true, "ls *")); - assert!(!any_subcommand_matches(&commands, true, "cat *")); - assert!(any_subcommand_matches(&commands, true, "lsof *")); - - // `ls*` — glob/prefix: matches both `ls -a` and `lsof` - assert!(any_subcommand_matches(&commands, true, "ls*")); - } - - #[test] - fn scope_exact_match() { - let commands = vec![ShellCommand { - name: "ls".into(), - full: "ls".into(), - }]; - assert!(any_subcommand_matches(&commands, true, "ls")); - assert!(!any_subcommand_matches(&commands, true, "cat")); - } - - #[cfg(feature = "tree-sitter")] - #[test] - fn nested_substitution() { - let result = parse_shell_command( - "echo \"Result: $(git log --oneline | head -1)\"", - ShellKind::Posix, - ); - let n = names(&result.subcommands); - assert!(n.contains(&"echo"), "should contain echo: {n:?}"); - assert!(n.contains(&"git"), "should contain git: {n:?}"); - assert!(n.contains(&"head"), "should contain head: {n:?}"); - } - - #[test] - fn fallback_splits_correctly() { - let result = parse_shell_command("ls && cat foo || echo fail", ShellKind::Other); - let n = names(&result.subcommands); - assert!(n.contains(&"ls"), "should contain ls: {n:?}"); - assert!(n.contains(&"cat"), "should contain cat: {n:?}"); - assert!(n.contains(&"echo"), "should contain echo: {n:?}"); - } - - #[test] - fn fish_simple_command() { - let result = parse_shell_command("ls -la /tmp", ShellKind::Fish); - assert_eq!(names(&result.subcommands), vec!["ls"]); - } - - #[test] - fn fish_conditional() { - let result = parse_shell_command("git add .; and git commit -m hi", ShellKind::Fish); - let n = names(&result.subcommands); - assert!(n.contains(&"git"), "should contain git: {n:?}"); - } - - #[cfg(feature = "tree-sitter")] - #[test] - fn fish_command_substitution() { - let result = parse_shell_command("echo (date)", ShellKind::Fish); - let n = names(&result.subcommands); - assert!(n.contains(&"echo"), "should contain echo: {n:?}"); - assert!(n.contains(&"date"), "should contain date: {n:?}"); - } - - #[cfg(feature = "tree-sitter")] - #[test] - fn variable_assignment_excluded() { - let result = parse_shell_command("FOO=bar ls -la /tmp", ShellKind::Posix); - assert_eq!(names(&result.subcommands), vec!["ls"]); - assert_eq!(fulls(&result.subcommands), vec!["ls -la /tmp"]); - } - - #[cfg(feature = "tree-sitter")] - #[test] - fn variable_assignment_multiple() { - let result = parse_shell_command("A=1 B=2 git status", ShellKind::Posix); - assert_eq!(names(&result.subcommands), vec!["git"]); - assert_eq!(fulls(&result.subcommands), vec!["git status"]); - } - - #[test] - fn fallback_double_ampersand_and_pipe_pipe() { - let result = parse_shell_command("ls && cat foo || echo fail", ShellKind::Other); - assert_eq!(names(&result.subcommands), vec!["ls", "cat", "echo"]); - assert_eq!( - fulls(&result.subcommands), - vec!["ls", "cat foo", "echo fail"] - ); - } - - #[test] - fn fallback_pipe_without_double() { - let result = parse_shell_command("ls | grep foo", ShellKind::Other); - assert_eq!(names(&result.subcommands), vec!["ls", "grep"]); - assert_eq!(fulls(&result.subcommands), vec!["ls", "grep foo"]); - } - - #[test] - fn scope_middle_wildcard() { - let commands = vec![ShellCommand { - name: "git".into(), - full: "git commit -m amend".into(), - }]; - assert!(any_subcommand_matches(&commands, true, "git * amend")); - assert!(any_subcommand_matches( - &commands, - true, - "git commit * amend" - )); - assert!(!any_subcommand_matches(&commands, true, "git push * amend")); - } - - #[test] - fn scope_middle_wildcard_zero_words() { - let commands = vec![ShellCommand { - name: "git".into(), - full: "git commit".into(), - }]; - // `*` matches zero words, so `git * commit` should match `git commit` - assert!(any_subcommand_matches(&commands, true, "git * commit")); - } - - #[test] - fn scope_leading_wildcard() { - let commands = vec![ShellCommand { - name: "docker".into(), - full: "docker run --rm alpine".into(), - }]; - assert!(any_subcommand_matches(&commands, true, "* alpine")); - assert!(!any_subcommand_matches(&commands, true, "* ubuntu")); - } - - #[test] - fn scope_multiple_wildcards() { - let commands = vec![ShellCommand { - name: "git".into(), - full: "git rebase -i HEAD~5".into(), - }]; - assert!(any_subcommand_matches(&commands, true, "git * -i * HEAD~5")); - assert!(!any_subcommand_matches( - &commands, - true, - "git * -i * HEAD~10" - )); - } -} - -#[cfg(all(test, feature = "tree-sitter"))] -mod adversarial { - use super::*; - - fn cmd_names(cmds: &[ShellCommand]) -> Vec<&str> { - cmds.iter().map(|c| c.name.as_str()).collect() - } - - /// Helper: assert that parsing POSIX extracts all expected command names - fn assert_posix(code: &str, expected: &[&str]) { - let result = parse_shell_command(code, ShellKind::Posix); - let mut got: Vec<&str> = result.subcommands.iter().map(|c| c.name.as_str()).collect(); - got.sort(); - let mut want: Vec<&str> = expected.to_vec(); - want.sort(); - assert_eq!( - got, want, - "POSIX parse of {:?}:\n got: {:?}\n want: {:?}", - code, got, want - ); - } - - fn assert_fish(code: &str, expected: &[&str]) { - let result = parse_shell_command(code, ShellKind::Fish); - let mut got: Vec<&str> = result.subcommands.iter().map(|c| c.name.as_str()).collect(); - got.sort(); - let mut want: Vec<&str> = expected.to_vec(); - want.sort(); - assert_eq!( - got, want, - "Fish parse of {:?}:\n got: {:?}\n want: {:?}", - code, got, want - ); - } - - // ──────────────────────────────────────────────────────────── - // Level 1: Basic compounds - // ──────────────────────────────────────────────────────────── - - #[test] - fn a01_triple_chain() { - assert_posix("a && b && c", &["a", "b", "c"]); - } - - #[test] - fn a02_or_chain() { - assert_posix("a || b || c", &["a", "b", "c"]); - } - - #[test] - fn a03_mixed_chain() { - assert_posix("a && b || c && d", &["a", "b", "c", "d"]); - } - - #[test] - fn a04_long_pipeline() { - assert_posix( - "cat foo | grep bar | awk '{print $1}' | sort | uniq -c", - &["cat", "grep", "awk", "sort", "uniq"], - ); - } - - #[test] - fn a05_semicolons() { - assert_posix("a; b; c; d", &["a", "b", "c", "d"]); - } - - // ──────────────────────────────────────────────────────────── - // Level 2: Nested substitution - // ──────────────────────────────────────────────────────────── - - #[test] - fn a06_nested_dollar() { - assert_posix( - "echo $(basename $(dirname /foo/bar))", - &["echo", "basename", "dirname"], - ); - } - - #[test] - fn a07_deeply_nested() { - // 4 nested echos, all should be extracted - assert_posix( - "echo $(echo $(echo $(echo deep)))", - &["echo", "echo", "echo", "echo"], - ); - } - - #[test] - fn a08_backtick_in_echo() { - assert_posix("echo `hostname`", &["echo", "hostname"]); - } - - #[test] - fn a09_mixed_substitutions() { - assert_posix("echo $(date) `uname`", &["echo", "date", "uname"]); - } - - // ──────────────────────────────────────────────────────────── - // Level 3: Subshells and grouping - // ──────────────────────────────────────────────────────────── - - #[test] - fn a10_subshell_chain() { - assert_posix("(cd /tmp && ls -la)", &["cd", "ls"]); - } - - #[test] - fn a11_nested_subshells() { - assert_posix("( (inner_cmd) )", &["inner_cmd"]); - } - - #[test] - fn a12_brace_group() { - assert_posix("{ cd /tmp; ls; }", &["cd", "ls"]); - } - - // ──────────────────────────────────────────────────────────── - // Level 4: Variable assignments - // ──────────────────────────────────────────────────────────── - - #[test] - fn a13_single_var_assignment() { - let result = parse_shell_command("FOO=bar ls", ShellKind::Posix); - assert_eq!(cmd_names(&result.subcommands), &["ls"]); - assert_eq!(result.subcommands[0].full, "ls"); - } - - #[test] - fn a14_multiple_var_assignments() { - let result = parse_shell_command("A=1 B=2 C=3 git status", ShellKind::Posix); - assert_eq!(cmd_names(&result.subcommands), &["git"]); - assert_eq!(result.subcommands[0].full, "git status"); - } - - #[test] - fn a15_var_assignment_no_command() { - // Variable assignment only — no command to extract - assert_posix("FOO=bar", &[]); - } - - #[test] - fn a16_var_assignment_in_pipeline() { - assert_posix("FOO=bar ls | BAZ=qux grep foo", &["ls", "grep"]); - } - - // ──────────────────────────────────────────────────────────── - // Level 5: Control flow - // ──────────────────────────────────────────────────────────── - - #[test] - fn a17_if_then_else() { - assert_posix( - "if [ -f foo ]; then cat foo; else echo missing; fi", - &["cat", "echo"], - ); - } - - #[test] - fn a18_elif_chain() { - // Two cat commands (then + elif branch), one echo (else branch). - // [ is part of the test_condition, not extracted as a command. - assert_posix( - "if [ -f a ]; then cat a; elif [ -f b ]; then cat b; else echo none; fi", - &["cat", "cat", "echo"], - ); - } - - #[test] - fn a19_for_loop() { - assert_posix("for f in *.txt; do cat \"$f\"; done", &["cat"]); - } - - #[test] - fn a20_while_loop() { - // read in the condition is a real command - assert_posix( - "while read line; do echo \"$line\"; done < input.txt", - &["echo", "read"], - ); - } - - #[test] - fn f07_if_statement() { - // test in if-condition is a real command - assert_fish( - "if test -f foo; cat foo; else; echo missing; end", - &["cat", "echo", "test"], - ); - } - - #[test] - fn f09_while_loop() { - // `true` in the condition is a real command - assert_fish( - "while true; echo tick; sleep 1; end", - &["echo", "sleep", "true"], - ); - } - - // ──────────────────────────────────────────────────────────── - // Level 6: Redirections - // ──────────────────────────────────────────────────────────── - - #[test] - fn a23_redirect_out() { - assert_posix("ls > output.txt", &["ls"]); - } - - #[test] - fn a24_redirect_append() { - assert_posix("ls >> output.txt 2>&1", &["ls"]); - } - - #[test] - fn a25_here_string() { - assert_posix("grep foo <<< \"hello world\"", &["grep"]); - } - - #[test] - fn a26_redirect_in_pipeline() { - assert_posix("cat < input.txt | sort | uniq", &["cat", "sort", "uniq"]); - } - - #[test] - fn a27_process_substitution() { - assert_posix( - "diff <(sort a.txt) <(sort b.txt)", - &["diff", "sort", "sort"], - ); - } - - // ──────────────────────────────────────────────────────────── - // Level 7: Function definitions - // ──────────────────────────────────────────────────────────── - - #[test] - fn a28_function_def() { - assert_posix("foo() { echo hello; }", &["echo"]); - } - - #[test] - fn a29_function_with_subshell() { - assert_posix( - "build() { cargo build && cargo test; }", - &["cargo", "cargo"], - ); - } - - // ──────────────────────────────────────────────────────────── - // Level 8: Edge cases — empties, weird quoting - // ──────────────────────────────────────────────────────────── - - #[test] - fn a30_empty_string() { - let result = parse_shell_command("", ShellKind::Posix); - assert!(result.subcommands.is_empty()); - } - - #[test] - fn a31_whitespace_only() { - let result = parse_shell_command(" \t \n ", ShellKind::Posix); - assert!(result.subcommands.is_empty()); - } - - #[test] - fn a32_single_command_no_args() { - assert_posix("ls", &["ls"]); - } - - #[test] - fn a33_command_with_single_quotes() { - assert_posix("echo 'hello world'", &["echo"]); - } - - #[test] - fn a34_command_with_double_quotes() { - assert_posix("echo \"hello world\"", &["echo"]); - } - - #[test] - fn a35_escaped_spaces() { - // ls\ -la is a single word in bash, not "ls" with flag "-la" - assert_posix("ls\\ -la", &["ls\\ -la"]); - } - - #[test] - fn a36_command_with_dollar_var() { - assert_posix("echo $HOME/.bashrc", &["echo"]); - } - - // ──────────────────────────────────────────────────────────── - // Level 9: Background jobs and coproc - // ──────────────────────────────────────────────────────────── - - #[test] - fn a37_background_job() { - assert_posix("sleep 10 &", &["sleep"]); - } - - #[test] - fn a38_background_chain() { - assert_posix("sleep 10 && echo done &", &["sleep", "echo"]); - } - - // ──────────────────────────────────────────────────────────── - // Level 10: Real-world complex commands - // ──────────────────────────────────────────────────────────── - - #[test] - fn a39_docker_build_and_run() { - assert_posix( - "docker build -t app . && docker run --rm app npm test", - &["docker", "docker"], - ); - } - - #[test] - fn a40_git_rebase_interactive() { - assert_posix( - "GIT_SEQUENCE_EDITOR=\"sed -i 's/pick/reword/'\" git rebase -i HEAD~5", - &["git"], - ); - } - - #[test] - fn a41_find_with_exec() { - // tree-sitter-bash does not parse -exec body as commands — only `find` is extracted. - // This is a known limitation: args to -exec/-execdir are opaque to the parser. - assert_posix("find . -name '*.rs' -exec grep -l 'unsafe' {} +", &["find"]); - } - - #[test] - fn a42_curl_pipe_sh() { - assert_posix( - "curl -sSL https://example.com/install.sh | bash", - &["curl", "bash"], - ); - } - - #[test] - fn a43_xargs() { - assert_posix("find . -name '*.tmp' | xargs rm -f", &["find", "xargs"]); - } - - #[test] - fn a44_npm_script_chain() { - assert_posix( - "npm run build && npm run test && npm run lint", - &["npm", "npm", "npm"], - ); - } - - #[test] - fn a45_make_with_redirect() { - assert_posix( - "make -j$(nproc) 2>&1 | tee build.log", - &["make", "nproc", "tee"], - ); - } - - #[test] - fn a46_sudo_chain() { - assert_posix("sudo apt update && sudo apt upgrade -y", &["sudo", "sudo"]); - } - - #[test] - fn a47_here_doc_with_subcommand() { - assert_posix("cat < output.txt", &["ls"]); - } - - #[test] - fn f13_redirect_append() { - assert_fish("ls >> output.txt", &["ls"]); - } - - #[test] - fn f14_here_string() { - assert_fish("grep foo <<< \"hello\"", &["grep"]); - } - - #[test] - fn f15_curl_pipe() { - assert_fish( - "curl -sSL https://example.com/install.sh | bash", - &["curl", "bash"], - ); - } - - #[test] - fn f16_double_ampersand() { - assert_fish("git add . && git commit -m hi", &["git", "git"]); - } - - #[test] - fn f17_double_pipe() { - assert_fish("test -f foo || echo missing", &["test", "echo"]); - } - - #[test] - fn f18_empty() { - let result = parse_shell_command("", ShellKind::Fish); - assert!(result.subcommands.is_empty()); - } - - #[test] - fn f19_whitespace() { - let result = parse_shell_command(" ", ShellKind::Fish); - assert!(result.subcommands.is_empty()); - } - - // ──────────────────────────────────────────────────────────── - // Level 12: Scope matching adversarial - // ──────────────────────────────────────────────────────────── - - #[test] - fn s01_empty_scope() { - let commands = vec![ShellCommand { - name: "ls".into(), - full: "ls".into(), - }]; - // Empty scope matches everything (nothing to constrain) - assert!(any_subcommand_matches(&commands, true, "")); - } - - #[test] - fn s03_only_wildcard_space_star() { - let commands = vec![ShellCommand { - name: "ls".into(), - full: "ls".into(), - }]; - // " *" with empty prefix = match anything - assert!(any_subcommand_matches(&commands, true, " *")); - } - - #[test] - fn s04_glob_matches_empty() { - let commands = vec![ShellCommand { - name: "ls".into(), - full: "ls".into(), - }]; - // `ls*` matches `ls` (prefix match with nothing after) - assert!(any_subcommand_matches(&commands, true, "ls*")); - } - - #[test] - fn s05_middle_wildcard_empty_match() { - // `git * commit` matches `git commit` (* = zero words) - let commands = vec![ShellCommand { - name: "git".into(), - full: "git commit".into(), - }]; - assert!(any_subcommand_matches(&commands, true, "git * commit")); - } - - #[test] - fn s06_consecutive_wildcards() { - // `git ** commit` should behave like `git * commit` - let commands = vec![ShellCommand { - name: "git".into(), - full: "git commit".into(), - }]; - assert!(any_subcommand_matches(&commands, true, "git ** commit")); - } - - #[test] - fn s07_case_sensitivity() { - let commands = vec![ShellCommand { - name: "LS".into(), - full: "LS -la".into(), - }]; - // Wildcard: case matters - assert!(!any_subcommand_matches(&commands, true, "ls *")); - assert!(any_subcommand_matches(&commands, true, "LS *")); - // prefix_bare=true: bare "LS" prefix-matches "LS -la" - assert!(!any_subcommand_matches(&commands, true, "ls")); - assert!(any_subcommand_matches(&commands, true, "LS")); - // prefix_bare=false: bare "LS" does NOT match "LS -la" - assert!(!any_subcommand_matches(&commands, false, "LS")); - } - - #[test] - fn s08_multi_word_exact_no_subcommand() { - // `git commit` should not match `git commit-amend` - let commands = vec![ShellCommand { - name: "git".into(), - full: "git commit-amend".into(), - }]; - assert!(!any_subcommand_matches(&commands, true, "git commit")); - } -} diff --git a/crates/atuin-ai/src/permissions/walker.rs b/crates/atuin-ai/src/permissions/walker.rs deleted file mode 100644 index 3bda01c3..00000000 --- a/crates/atuin-ai/src/permissions/walker.rs +++ /dev/null @@ -1,121 +0,0 @@ -use std::path::{Path, PathBuf}; - -use eyre::Result; -use tokio::task::JoinSet; - -use crate::permissions::file::{RuleFile, RuleFileContent}; - -#[derive(Debug)] -struct FoundRuleFile { - depth: usize, - file: RuleFile, -} - -pub(crate) struct PermissionWalker { - start: PathBuf, - /// Direct path to the global permissions file (e.g. `~/.config/atuin/permissions.ai.toml`). - global_permissions_file: Option, - rules: Vec, -} - -impl PermissionWalker { - pub fn new(start: PathBuf, global_permissions_file: Option) -> Self { - Self { - start, - global_permissions_file, - rules: Vec::new(), - } - } - - pub fn rules(&self) -> &[RuleFile] { - &self.rules - } - - /// Walks the filesystem starting from the start path and collecting permission files along the way. - /// Walks to the root, then checks the global permissions file, if any. - pub async fn walk(&mut self) -> Result<()> { - let dirs_to_check: Vec = self.start.ancestors().map(PathBuf::from).collect(); - let dir_count = dirs_to_check.len(); - - let mut set: JoinSet>> = JoinSet::new(); - - for (index, path) in dirs_to_check.into_iter().enumerate() { - set.spawn(async move { - match check_dir_for_permissions(&path).await { - Ok(Some(rule_file)) => Ok(Some(FoundRuleFile { - depth: index, - file: rule_file, - })), - Ok(None) => Ok(None), - Err(e) => Err(e), - } - }); - } - - // Check the global file separately (it's a direct file path, not a dir/.atuin/ pattern) - if let Some(global_path) = self.global_permissions_file.clone() { - let depth = dir_count; // sorts after all directory-walk entries - set.spawn(async move { - match load_permissions_file(&global_path).await { - Ok(Some(rule_file)) => Ok(Some(FoundRuleFile { - depth, - file: rule_file, - })), - Ok(None) => Ok(None), - Err(e) => Err(e), - } - }); - } - - let capacity = dir_count + usize::from(self.global_permissions_file.is_some()); - let mut found = Vec::with_capacity(capacity); - while let Some(result) = set.join_next().await { - let result = result?; // JoinErrors result in failure to walk the filesystem - - match result { - Ok(Some(FoundRuleFile { depth, file })) => { - found.push((depth, file)); - } - Ok(None) => { - continue; - } - Err(e) => { - tracing::error!( - "Error while walking filesystem for permissions check; skipping: {}", - e - ); - continue; - } - } - } - // join_next() returns in order of completion, not order of spawn - found.sort_by_key(|(depth, _)| *depth); - self.rules = found.into_iter().map(|(_, file)| file).collect(); - - Ok(()) - } -} - -/// Checks a directory for `.atuin/permissions.ai.toml` and returns the RuleFile if found. -async fn check_dir_for_permissions(path: &Path) -> Result> { - let file_path = path.join(".atuin").join("permissions.ai.toml"); - load_permissions_file(&file_path).await -} - -/// Load a permissions file from an exact path. Returns None if the file doesn't exist. -async fn load_permissions_file(file_path: &Path) -> Result> { - if !tokio::fs::try_exists(file_path).await? { - return Ok(None); - } - - let raw = tokio::fs::read_to_string(file_path).await?; - let content: RuleFileContent = toml::from_str(&raw)?; - - // Use the file's parent as the rule file path (for logging/debugging) - let path = file_path - .parent() - .map(Path::to_path_buf) - .unwrap_or_else(|| file_path.to_path_buf()); - - Ok(Some(RuleFile { path, content })) -} diff --git a/crates/atuin-ai/src/permissions/writer.rs b/crates/atuin-ai/src/permissions/writer.rs deleted file mode 100644 index ffef404e..00000000 --- a/crates/atuin-ai/src/permissions/writer.rs +++ /dev/null @@ -1,199 +0,0 @@ -use std::path::Path; - -use eyre::Result; - -use crate::permissions::rule::Rule; - -/// Whether a rule should be added to the allow or deny list. -#[derive(Debug, Clone)] -#[allow(dead_code)] -pub(crate) enum RuleDisposition { - Allow, - Deny, -} - -/// Write a permission rule to a `permissions.ai.toml` file. -/// -/// If the file doesn't exist it is created (along with parent directories). -/// If it does exist, `toml_edit` is used to append the rule while preserving -/// existing formatting and comments. -/// -/// **Not concurrent-safe.** The read-modify-write cycle is not atomic. In the -/// current UI this is fine — the Select widget serializes permission decisions — -/// but callers should not invoke this concurrently for the same file. -pub(crate) async fn write_rule( - file_path: &Path, - rule: &Rule, - disposition: RuleDisposition, -) -> Result<()> { - let content = if tokio::fs::try_exists(file_path).await.unwrap_or(false) { - tokio::fs::read_to_string(file_path).await? - } else { - String::new() - }; - - let mut doc: toml_edit::DocumentMut = content.parse()?; - - // Ensure [permissions] table exists - if !doc.contains_key("permissions") { - doc["permissions"] = toml_edit::Item::Table(toml_edit::Table::new()); - } - - let key = match disposition { - RuleDisposition::Allow => "allow", - RuleDisposition::Deny => "deny", - }; - - // Use as_table_like_mut so both standard and inline tables work. - let permissions = doc["permissions"] - .as_table_like_mut() - .ok_or_else(|| eyre::eyre!("[permissions] is not a table"))?; - - // Get or create the array - if !permissions.contains_key(key) { - permissions.insert(key, toml_edit::Item::Value(toml_edit::Array::new().into())); - } - - let array = permissions - .get_mut(key) - .and_then(|item| item.as_value_mut()) - .and_then(|v| v.as_array_mut()) - .ok_or_else(|| eyre::eyre!("permissions.{key} is not an array"))?; - - // Don't add duplicates - let rule_str = rule.to_string(); - let already_present = array.iter().any(|v| v.as_str() == Some(&rule_str)); - if !already_present { - array.push(rule_str); - } - - // Write back, creating parent directories as needed - if let Some(parent) = file_path.parent() { - tokio::fs::create_dir_all(parent).await?; - } - tokio::fs::write(file_path, doc.to_string()).await?; - - Ok(()) -} - -/// Build the path to the project-level permissions file. -/// `project_root` is typically a git root or the current working directory. -pub(crate) fn project_permissions_path(project_root: &Path) -> std::path::PathBuf { - project_root.join(".atuin").join("permissions.ai.toml") -} - -/// Build the path to the global permissions file (sibling of atuin config). -pub(crate) fn global_permissions_path() -> std::path::PathBuf { - atuin_common::utils::config_dir().join("permissions.ai.toml") -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn creates_new_file_with_allow_rule() { - let dir = tempfile::tempdir().unwrap(); - let file = dir.path().join("permissions.ai.toml"); - let rule = Rule { - tool: "AtuinHistory".to_string(), - scope: None, - }; - - write_rule(&file, &rule, RuleDisposition::Allow) - .await - .unwrap(); - - let content = tokio::fs::read_to_string(&file).await.unwrap(); - assert!(content.contains("[permissions]")); - assert!(content.contains(r#""AtuinHistory""#)); - } - - #[tokio::test] - async fn appends_to_existing_file() { - let dir = tempfile::tempdir().unwrap(); - let file = dir.path().join("permissions.ai.toml"); - let existing = r#"# My permissions -[permissions] -allow = ["Read"] -"#; - tokio::fs::write(&file, existing).await.unwrap(); - - let rule = Rule { - tool: "AtuinHistory".to_string(), - scope: None, - }; - write_rule(&file, &rule, RuleDisposition::Allow) - .await - .unwrap(); - - let content = tokio::fs::read_to_string(&file).await.unwrap(); - // Comment preserved - assert!(content.contains("# My permissions")); - // Both rules present - assert!(content.contains(r#""Read""#)); - assert!(content.contains(r#""AtuinHistory""#)); - } - - #[tokio::test] - async fn does_not_duplicate_existing_rule() { - let dir = tempfile::tempdir().unwrap(); - let file = dir.path().join("permissions.ai.toml"); - let existing = r#"[permissions] -allow = ["AtuinHistory"] -"#; - tokio::fs::write(&file, existing).await.unwrap(); - - let rule = Rule { - tool: "AtuinHistory".to_string(), - scope: None, - }; - write_rule(&file, &rule, RuleDisposition::Allow) - .await - .unwrap(); - - let content = tokio::fs::read_to_string(&file).await.unwrap(); - // Should appear exactly once - assert_eq!(content.matches("AtuinHistory").count(), 1); - } - - #[tokio::test] - async fn handles_inline_table_permissions() { - let dir = tempfile::tempdir().unwrap(); - let file = dir.path().join("permissions.ai.toml"); - // Inline table style — as_table_mut() would return None for this - let existing = r#"permissions = { allow = ["Read"] } -"#; - tokio::fs::write(&file, existing).await.unwrap(); - - let rule = Rule { - tool: "AtuinHistory".to_string(), - scope: None, - }; - write_rule(&file, &rule, RuleDisposition::Allow) - .await - .unwrap(); - - let content = tokio::fs::read_to_string(&file).await.unwrap(); - assert!(content.contains(r#""Read""#)); - assert!(content.contains(r#""AtuinHistory""#)); - } - - #[tokio::test] - async fn writes_deny_rule() { - let dir = tempfile::tempdir().unwrap(); - let file = dir.path().join("permissions.ai.toml"); - let rule = Rule { - tool: "Shell".to_string(), - scope: None, - }; - - write_rule(&file, &rule, RuleDisposition::Deny) - .await - .unwrap(); - - let content = tokio::fs::read_to_string(&file).await.unwrap(); - assert!(content.contains("deny")); - assert!(content.contains(r#""Shell""#)); - } -} diff --git a/crates/atuin-ai/src/session.rs b/crates/atuin-ai/src/session.rs deleted file mode 100644 index 848330fc..00000000 --- a/crates/atuin-ai/src/session.rs +++ /dev/null @@ -1,509 +0,0 @@ -//! Session service abstraction and manager. -//! -//! The TUI interacts with sessions through `SessionManager`, which wraps a -//! `SessionService` trait. Today the only implementation is `LocalSessionService` -//! (direct SQLite). When the daemon owns session state, a gRPC-backed -//! implementation can be swapped in without changing the TUI code. - -use async_trait::async_trait; -use eyre::Result; - -use crate::event_serde; -use crate::store::{AiSessionStore, StoredEvent, StoredSession}; -use crate::tui::ConversationEvent; - -// --------------------------------------------------------------------------- -// Trait -// --------------------------------------------------------------------------- - -#[async_trait] -pub(crate) trait SessionService: Send + Sync { - async fn create_session( - &self, - id: &str, - directory: Option<&str>, - git_root: Option<&str>, - ) -> Result; - - async fn find_resumable( - &self, - directory: Option<&str>, - git_root: Option<&str>, - max_age_secs: i64, - ) -> Result>; - - async fn load_events(&self, session_id: &str) -> Result>; - - async fn append_event( - &self, - session_id: &str, - event_id: &str, - parent_id: Option<&str>, - invocation_id: &str, - event_type: &str, - event_data: &str, - ) -> Result<()>; - - async fn update_server_session_id( - &self, - session_id: &str, - server_session_id: &str, - ) -> Result<()>; - - async fn archive(&self, session_id: &str) -> Result<()>; - - async fn get_metadata(&self, session_id: &str, key: &str) -> Result>; - async fn set_metadata(&self, session_id: &str, key: &str, value: &str) -> Result<()>; -} - -// --------------------------------------------------------------------------- -// Local implementation (direct SQLite) -// --------------------------------------------------------------------------- - -pub(crate) struct LocalSessionService { - store: AiSessionStore, -} - -impl LocalSessionService { - pub async fn open(path: impl AsRef, timeout: f64) -> Result { - let store = AiSessionStore::new(path, timeout).await?; - Ok(Self { store }) - } -} - -#[async_trait] -impl SessionService for LocalSessionService { - async fn create_session( - &self, - id: &str, - directory: Option<&str>, - git_root: Option<&str>, - ) -> Result { - self.store.create_session(id, directory, git_root).await - } - - async fn find_resumable( - &self, - directory: Option<&str>, - git_root: Option<&str>, - max_age_secs: i64, - ) -> Result> { - self.store - .find_resumable_session(directory, git_root, max_age_secs) - .await - } - - async fn load_events(&self, session_id: &str) -> Result> { - self.store.load_events(session_id).await - } - - async fn append_event( - &self, - session_id: &str, - event_id: &str, - parent_id: Option<&str>, - invocation_id: &str, - event_type: &str, - event_data: &str, - ) -> Result<()> { - self.store - .append_event( - session_id, - event_id, - parent_id, - invocation_id, - event_type, - event_data, - ) - .await - } - - async fn update_server_session_id( - &self, - session_id: &str, - server_session_id: &str, - ) -> Result<()> { - self.store - .update_server_session_id(session_id, server_session_id) - .await - } - - async fn archive(&self, session_id: &str) -> Result<()> { - self.store.archive_session(session_id).await - } - - async fn get_metadata(&self, session_id: &str, key: &str) -> Result> { - self.store.get_metadata(session_id, key).await - } - - async fn set_metadata(&self, session_id: &str, key: &str, value: &str) -> Result<()> { - self.store.set_metadata(session_id, key, value).await - } -} - -// --------------------------------------------------------------------------- -// SessionManager -// --------------------------------------------------------------------------- - -/// High-level session manager used by the TUI dispatch loop. -/// -/// Owns the current session identity, tracks what has been persisted, and -/// handles serialization between `ConversationEvent` and the storage format. -pub(crate) struct SessionManager { - service: Box, - session_id: String, - invocation_id: String, - /// Number of events already persisted. `persist_events` only writes the - /// delta from this index onward. - persisted_count: usize, - /// ID of the last persisted event, used as `parent_id` for the next one. - head_id: Option, - /// Stored for creating a new session on `/new`. - directory: Option, - git_root: Option, - /// Whether the session row has been created in the database. New sessions - /// are deferred until the first event is persisted, so empty sessions - /// don't linger and get spuriously resumed. - persisted_to_db: bool, -} - -impl SessionManager { - /// Create a new session manager. The database row is deferred until the - /// first event is persisted. - pub fn create_new( - service: Box, - directory: Option<&str>, - git_root: Option<&str>, - ) -> Self { - let session_id = atuin_common::utils::uuid_v7().to_string(); - let invocation_id = atuin_common::utils::uuid_v7().to_string(); - - Self { - service, - session_id, - invocation_id, - persisted_count: 0, - head_id: None, - directory: directory.map(String::from), - git_root: git_root.map(String::from), - persisted_to_db: false, - } - } - - /// Load an existing session and return a manager for it, along with the - /// deserialized conversation events, the server session ID, and the - /// timestamp of the last stored event. - pub async fn resume( - service: Box, - stored: &StoredSession, - ) -> Result<( - Self, - Vec, - Option, - Option, - String, - )> { - let invocation_id = atuin_common::utils::uuid_v7().to_string(); - let stored_events = service.load_events(&stored.id).await?; - - let mut events = Vec::with_capacity(stored_events.len()); - let mut last_event_id = None; - let mut last_event_ts = None; - for se in &stored_events { - events.push(event_serde::deserialize_event( - &se.event_type, - &se.event_data, - )?); - last_event_id = Some(se.id.clone()); - last_event_ts = Some(se.created_at); - } - - let manager = Self { - service, - session_id: stored.id.clone(), - invocation_id: invocation_id.clone(), - persisted_count: events.len(), - head_id: last_event_id, - directory: stored.directory.clone(), - git_root: stored.git_root.clone(), - persisted_to_db: true, - }; - - Ok(( - manager, - events, - stored.server_session_id.clone(), - last_event_ts, - invocation_id, - )) - } - - /// Ensure the session row exists in the database. - async fn ensure_persisted(&mut self) -> Result<()> { - if !self.persisted_to_db { - self.service - .create_session( - &self.session_id, - self.directory.as_deref(), - self.git_root.as_deref(), - ) - .await?; - self.persisted_to_db = true; - } - Ok(()) - } - - /// Persist any new events since the last persist call. - pub async fn persist_events(&mut self, events: &[ConversationEvent]) -> Result<()> { - if self.persisted_count >= events.len() { - return Ok(()); - } - self.ensure_persisted().await?; - for event in &events[self.persisted_count..] { - let event_id = atuin_common::utils::uuid_v7().to_string(); - let (event_type, event_data) = event_serde::serialize_event(event); - - self.service - .append_event( - &self.session_id, - &event_id, - self.head_id.as_deref(), - &self.invocation_id, - &event_type, - &event_data, - ) - .await?; - - self.head_id = Some(event_id); - self.persisted_count += 1; - } - Ok(()) - } - - /// Persist the server session ID if it has changed. - pub async fn persist_server_session_id(&mut self, server_session_id: &str) -> Result<()> { - self.ensure_persisted().await?; - self.service - .update_server_session_id(&self.session_id, server_session_id) - .await - } - - /// Archive the current session (for `/new` command). - #[allow(dead_code)] // used in tests; will be used by dispatch for `/new` - pub async fn archive(&self) -> Result<()> { - if self.persisted_to_db { - self.service.archive(&self.session_id).await?; - } - Ok(()) - } - - /// Archive the current session and reset to a fresh one. - /// The new session row is deferred until the first event is persisted. - pub async fn archive_and_reset(&mut self) -> Result<()> { - if self.persisted_to_db { - self.service.archive(&self.session_id).await?; - } - - self.session_id = atuin_common::utils::uuid_v7().to_string(); - self.invocation_id = atuin_common::utils::uuid_v7().to_string(); - self.persisted_count = 0; - self.head_id = None; - self.persisted_to_db = false; - Ok(()) - } - - #[allow(dead_code)] // used in tests; part of public API for dispatch/daemon - pub fn session_id(&self) -> &str { - &self.session_id - } - - #[allow(dead_code)] // used in tests; part of public API for dispatch/daemon - pub fn invocation_id(&self) -> &str { - &self.invocation_id - } - - /// Read a metadata value for the current session. - pub async fn get_metadata(&self, key: &str) -> Result> { - if !self.persisted_to_db { - return Ok(None); - } - self.service.get_metadata(&self.session_id, key).await - } - - /// Write a metadata value for the current session. - pub async fn set_metadata(&mut self, key: &str, value: &str) -> Result<()> { - self.ensure_persisted().await?; - self.service - .set_metadata(&self.session_id, key, value) - .await - } -} - -#[cfg(test)] -mod tests { - use super::*; - - async fn test_service() -> Box { - let svc = LocalSessionService::open("sqlite::memory:", 2.0) - .await - .unwrap(); - Box::new(svc) - } - - #[tokio::test] - async fn test_create_new_and_persist() { - let service = test_service().await; - let mut mgr = SessionManager::create_new(service, Some("/tmp"), None); - - let events = vec![ - ConversationEvent::UserMessage { - content: "hello".to_string(), - }, - ConversationEvent::Text { - content: "hi there".to_string(), - }, - ]; - - mgr.persist_events(&events).await.unwrap(); - - // Persist again with no new events — should be a no-op - mgr.persist_events(&events).await.unwrap(); - } - - #[tokio::test] - async fn test_create_and_resume() { - // Create a session and persist some events - let svc = LocalSessionService::open("sqlite::memory:", 2.0) - .await - .unwrap(); - - let session_id = atuin_common::utils::uuid_v7().to_string(); - svc.create_session(&session_id, Some("/project"), Some("/project")) - .await - .unwrap(); - - let events = vec![ - ConversationEvent::UserMessage { - content: "how do I list files?".to_string(), - }, - ConversationEvent::Text { - content: "Use ls".to_string(), - }, - ConversationEvent::ToolCall { - id: "tc_1".to_string(), - name: "suggest_command".to_string(), - input: serde_json::json!({"command": "ls -la"}), - }, - ]; - - // Persist events manually through the service - let inv_id = "inv-1"; - let mut parent: Option = None; - for event in &events { - let eid = atuin_common::utils::uuid_v7().to_string(); - let (etype, edata) = event_serde::serialize_event(event); - svc.append_event(&session_id, &eid, parent.as_deref(), inv_id, &etype, &edata) - .await - .unwrap(); - parent = Some(eid); - } - - svc.update_server_session_id(&session_id, "srv-abc") - .await - .unwrap(); - - // Now find and resume the session with a fresh service connection - let stored = svc - .find_resumable(Some("/project"), Some("/project"), 3600) - .await - .unwrap() - .expect("should find session"); - - let (mut mgr, loaded_events, server_sid, last_ts, _invocation_id) = - SessionManager::resume(Box::new(svc), &stored) - .await - .unwrap(); - - assert_eq!(loaded_events.len(), 3); - assert_eq!(server_sid.as_deref(), Some("srv-abc")); - assert_ne!(mgr.invocation_id(), inv_id, "new invocation ID on resume"); - assert!(last_ts.is_some(), "should have a last event timestamp"); - - // Persisting again with the same events should be a no-op - mgr.persist_events(&loaded_events).await.unwrap(); - } - - #[tokio::test] - async fn test_incremental_persist() { - let service = test_service().await; - let mut mgr = SessionManager::create_new(service, Some("/tmp"), None); - - let mut events = vec![ConversationEvent::UserMessage { - content: "first".to_string(), - }]; - mgr.persist_events(&events).await.unwrap(); - - // Add more events and persist again — only the new ones should be written - events.push(ConversationEvent::Text { - content: "response".to_string(), - }); - events.push(ConversationEvent::UserMessage { - content: "second".to_string(), - }); - mgr.persist_events(&events).await.unwrap(); - - // Verify by loading through a fresh service (can't easily here since - // the service is moved, but the lack of errors confirms correctness) - } - - #[tokio::test] - async fn test_archive() { - let svc = LocalSessionService::open("sqlite::memory:", 2.0) - .await - .unwrap(); - - let mgr = SessionManager::create_new(Box::new(svc), Some("/tmp"), None); - - mgr.archive().await.unwrap(); - } - - #[tokio::test] - async fn test_persist_server_session_id() { - let service = test_service().await; - let mut mgr = SessionManager::create_new(service, Some("/tmp"), None); - - mgr.persist_server_session_id("srv-123").await.unwrap(); - } - - #[tokio::test] - async fn test_parent_chain_integrity() { - // Verify that persisted events form a proper parent chain - let svc = LocalSessionService::open("sqlite::memory:", 2.0) - .await - .unwrap(); - - let session_id = { - let mut mgr = SessionManager::create_new(Box::new(svc), Some("/tmp"), None); - - let events = vec![ - ConversationEvent::UserMessage { - content: "a".to_string(), - }, - ConversationEvent::Text { - content: "b".to_string(), - }, - ConversationEvent::UserMessage { - content: "c".to_string(), - }, - ]; - mgr.persist_events(&events).await.unwrap(); - mgr.session_id().to_string() - }; - - // Re-open the store and load events to verify the chain - // (Can't do this with in-memory DB since it's gone, but the - // lack of FK constraint violations during persist confirms the - // parent_id values are valid) - let _ = session_id; - } -} diff --git a/crates/atuin-ai/src/skills/frontmatter.rs b/crates/atuin-ai/src/skills/frontmatter.rs deleted file mode 100644 index 759dffcc..00000000 --- a/crates/atuin-ai/src/skills/frontmatter.rs +++ /dev/null @@ -1,233 +0,0 @@ -//! YAML frontmatter parsing for `SKILL.md` files. -//! -//! Extracts the YAML block between `---` delimiters and parses it with -//! `yaml-rust2`. Returns the parsed fields and the byte offset where the -//! body begins (after the closing `---`). - -use yaml_rust2::YamlLoader; - -/// Parsed frontmatter fields from a `SKILL.md` file. -#[derive(Debug, Default)] -pub(crate) struct Frontmatter { - pub name: Option, - pub description: Option, - pub disable_model_invocation: bool, -} - -/// Result of splitting a skill file into frontmatter + body. -#[derive(Debug)] -pub(crate) struct ParsedSkillFile { - pub frontmatter: Frontmatter, - /// Everything after the closing `---` delimiter. - pub body: String, -} - -/// Parse a `SKILL.md` file's content into frontmatter and body. -/// -/// If no frontmatter delimiters are found, all content is treated as body -/// with default frontmatter. -pub(crate) fn parse(content: &str) -> ParsedSkillFile { - let Some((yaml_str, body)) = split_frontmatter(content) else { - return ParsedSkillFile { - frontmatter: Frontmatter::default(), - body: content.to_string(), - }; - }; - - let frontmatter = match YamlLoader::load_from_str(yaml_str) { - Ok(docs) if !docs.is_empty() => extract_fields(&docs[0]), - Ok(_) => Frontmatter::default(), - Err(e) => { - tracing::warn!("Failed to parse skill frontmatter: {e}"); - Frontmatter::default() - } - }; - - ParsedSkillFile { frontmatter, body } -} - -/// Split content on `---` delimiters. Returns `(yaml_str, body)` or `None` -/// if frontmatter is not present. -fn split_frontmatter(content: &str) -> Option<(&str, String)> { - let trimmed = content.trim_start(); - - // Must start with `---` - if !trimmed.starts_with("---") { - return None; - } - - // Find the end of the opening delimiter line - let after_open = trimmed.get(3..)?.trim_start_matches(|c: char| c != '\n'); - let after_open = after_open.strip_prefix('\n').unwrap_or(after_open); - - // Find the closing `---` - let close_pos = after_open - .lines() - .enumerate() - .find(|(_, line)| line.trim() == "---") - .map(|(i, _)| { - after_open - .lines() - .take(i) - .map(|l| l.len() + 1) // +1 for newline - .sum::() - })?; - - let yaml_str = &after_open[..close_pos]; - let rest = &after_open[close_pos..]; - // Skip the closing `---` line - let body = rest - .strip_prefix("---") - .unwrap_or(rest) - .trim_start_matches(|c: char| c != '\n'); - let body = body.strip_prefix('\n').unwrap_or(body); - - Some((yaml_str, body.to_string())) -} - -fn extract_fields(doc: &yaml_rust2::Yaml) -> Frontmatter { - use yaml_rust2::Yaml; - - let name = match &doc["name"] { - Yaml::String(s) => Some(s.clone()), - _ => None, - }; - - let description = match &doc["description"] { - Yaml::String(s) => Some(s.trim().to_string()), - _ => None, - }; - - let disable_model_invocation = match &doc["disable-model-invocation"] { - Yaml::Boolean(b) => *b, - Yaml::String(s) => matches!(s.as_str(), "true" | "yes" | "1"), - _ => false, - }; - - Frontmatter { - name, - description, - disable_model_invocation, - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn basic_frontmatter() { - let content = "\ ---- -name: my-skill -description: A test skill -disable-model-invocation: true ---- - -Body content here. -"; - let parsed = parse(content); - assert_eq!(parsed.frontmatter.name.as_deref(), Some("my-skill")); - assert_eq!( - parsed.frontmatter.description.as_deref(), - Some("A test skill") - ); - assert!(parsed.frontmatter.disable_model_invocation); - assert_eq!(parsed.body.trim(), "Body content here."); - } - - #[test] - fn multiline_folded_description() { - let content = "\ ---- -name: release -description: > - Orchestrate a multi-step release — version bumping, changelog - generation, PR creation, tagging, and publishing. -disable-model-invocation: true ---- - -# Release steps -"; - let parsed = parse(content); - assert_eq!(parsed.frontmatter.name.as_deref(), Some("release")); - let desc = parsed.frontmatter.description.unwrap(); - assert!(desc.contains("Orchestrate a multi-step release")); - assert!(desc.contains("publishing")); - assert!(parsed.frontmatter.disable_model_invocation); - assert!(parsed.body.contains("# Release steps")); - } - - #[test] - fn no_frontmatter() { - let content = "Just a body with no frontmatter."; - let parsed = parse(content); - assert!(parsed.frontmatter.name.is_none()); - assert!(parsed.frontmatter.description.is_none()); - assert!(!parsed.frontmatter.disable_model_invocation); - assert_eq!(parsed.body, content); - } - - #[test] - fn empty_frontmatter() { - let content = "\ ---- ---- - -Body after empty frontmatter. -"; - let parsed = parse(content); - assert!(parsed.frontmatter.name.is_none()); - assert!(parsed.frontmatter.description.is_none()); - assert_eq!(parsed.body.trim(), "Body after empty frontmatter."); - } - - #[test] - fn missing_fields_use_defaults() { - let content = "\ ---- -name: partial ---- - -Some body. -"; - let parsed = parse(content); - assert_eq!(parsed.frontmatter.name.as_deref(), Some("partial")); - assert!(parsed.frontmatter.description.is_none()); - assert!(!parsed.frontmatter.disable_model_invocation); - } - - #[test] - fn unknown_fields_ignored() { - let content = "\ ---- -name: my-skill -future-field: some value -another: 42 ---- - -Body. -"; - let parsed = parse(content); - assert_eq!(parsed.frontmatter.name.as_deref(), Some("my-skill")); - } - - #[test] - fn body_with_triple_dashes() { - let content = "\ ---- -name: test ---- - -Some body. - ---- - -More body after a horizontal rule. -"; - let parsed = parse(content); - assert_eq!(parsed.frontmatter.name.as_deref(), Some("test")); - assert!(parsed.body.contains("Some body.")); - assert!(parsed.body.contains("More body after a horizontal rule.")); - } -} diff --git a/crates/atuin-ai/src/skills/mod.rs b/crates/atuin-ai/src/skills/mod.rs deleted file mode 100644 index 36b3a2ae..00000000 --- a/crates/atuin-ai/src/skills/mod.rs +++ /dev/null @@ -1,468 +0,0 @@ -//! AI skill discovery, metadata, and lazy loading. -//! -//! Skills are markdown files (`SKILL.md`) with YAML frontmatter that define -//! reusable instructions for the LLM. Only skill metadata (name + description) -//! is sent to the server; full content is loaded on demand via `load_skill`. - -mod frontmatter; -pub(crate) mod walker; - -use std::path::Path; - -use eyre::{Result, eyre}; - -use crate::user_context::interpolate; - -/// Per-skill description truncation limit (before budget calculation). -const MAX_DESCRIPTION_LEN: usize = 1024; - -/// Default total character budget for skill descriptions sent to the server. -const DEFAULT_DESCRIPTION_BUDGET: usize = 9992; - -/// JSON overhead per skill entry: `{"name":"","description":""},` ≈ 30 chars. -const PER_ENTRY_OVERHEAD: usize = 30; - -/// Metadata for a discovered skill. Produced at discovery time from -/// frontmatter only — the body is not read until `load()`. -#[derive(Debug, Clone)] -pub(crate) struct SkillDescriptor { - pub name: String, - pub description: String, - pub source_path: std::path::PathBuf, - pub disable_model_invocation: bool, -} - -/// A name + description pair ready to serialize into the request payload. -#[derive(Debug, Clone, serde::Serialize)] -pub(crate) struct SkillSummary { - pub name: String, - pub description: String, -} - -/// Holds discovered skills and provides lookup, budget packing, and loading. -#[derive(Debug, Clone)] -pub(crate) struct SkillRegistry { - skills: Vec, -} - -impl SkillRegistry { - /// Discover skills from project and global directories. - pub async fn discover(project_root: Option<&Path>) -> Self { - let global_dir = walker::global_skills_dir(); - let project_dir = project_root.map(walker::project_skills_dir); - - Self::discover_from_dirs(project_dir.as_deref(), &global_dir).await - } - - /// Discover skills from explicit directory paths. Useful for testing. - pub async fn discover_from_dirs( - project_skills_dir: Option<&Path>, - global_skills_dir: &Path, - ) -> Self { - let raw_files = walker::discover(project_skills_dir, global_skills_dir).await; - - let mut skills = Vec::new(); - let mut seen_names = std::collections::HashSet::new(); - - for raw in raw_files { - let parsed = frontmatter::parse(&raw.content); - let fm = parsed.frontmatter; - - let name = fm.name.unwrap_or_else(|| sanitize_name(&raw.dir_name)); - - // Deduplicate: first seen wins (project before global) - if !seen_names.insert(name.clone()) { - continue; - } - - let description = fm - .description - .or_else(|| first_paragraph(&parsed.body)) - .unwrap_or_default(); - - skills.push(SkillDescriptor { - name, - description, - source_path: raw.path, - disable_model_invocation: fm.disable_model_invocation, - }); - } - - Self { skills } - } - - /// Create an empty registry. - #[cfg(test)] - pub fn empty() -> Self { - Self { skills: Vec::new() } - } - - /// Look up a skill by name. - pub fn get(&self, name: &str) -> Option<&SkillDescriptor> { - self.skills.iter().find(|s| s.name == name) - } - - /// All discovered skills. - pub fn all(&self) -> &[SkillDescriptor] { - &self.skills - } - - /// Whether any non-disabled skills exist (determines capability advertisement). - #[cfg(test)] - pub fn has_server_visible_skills(&self) -> bool { - self.skills.iter().any(|s| !s.disable_model_invocation) - } - - /// Pack skill descriptions into the server payload under a character budget. - /// - /// Returns the summaries that fit plus an optional overflow message. - pub fn server_skills(&self) -> (Vec, Option) { - self.server_skills_with_budget(DEFAULT_DESCRIPTION_BUDGET) - } - - pub fn server_skills_with_budget(&self, budget: usize) -> (Vec, Option) { - let eligible: Vec<&SkillDescriptor> = self - .skills - .iter() - .filter(|s| !s.disable_model_invocation) - .collect(); - - let mut summaries = Vec::new(); - let mut used = 0; - let mut overflow_names = Vec::new(); - - for skill in &eligible { - let truncated_desc = truncate_description(&skill.description, MAX_DESCRIPTION_LEN); - let entry_size = skill.name.len() + truncated_desc.len() + PER_ENTRY_OVERHEAD; - - if used + entry_size > budget && !summaries.is_empty() { - overflow_names.push(skill.name.as_str()); - continue; - } - - used += entry_size; - summaries.push(SkillSummary { - name: skill.name.clone(), - description: truncated_desc, - }); - } - - let overflow = if overflow_names.is_empty() { - None - } else { - Some(format!( - "{} additional skill(s) not listed due to size limits: {}", - overflow_names.len(), - overflow_names.join(", ") - )) - }; - - (summaries, overflow) - } - - /// Load a skill's full body content, with argument substitution and - /// `!`` interpolation applied. - /// - /// `$ARGUMENTS` in the body is replaced with the provided arguments before - /// shell interpolation runs. If `$ARGUMENTS` does not appear in the body - /// and arguments were provided, they are appended as `ARGUMENTS: `. - pub async fn load(&self, name: &str, shell: &str, arguments: Option<&str>) -> Result { - let skill = self - .get(name) - .ok_or_else(|| eyre!("Unknown skill: {name}"))?; - - let content = tokio::fs::read_to_string(&skill.source_path).await?; - let parsed = frontmatter::parse(&content); - let body = parsed.body; - - if body.trim().is_empty() { - return Ok(format!("(Skill '{name}' has no body content)")); - } - - let body = substitute_arguments(&body, arguments); - - Ok(interpolate::interpolate(&body, shell).await) - } -} - -/// Replace `$ARGUMENTS` placeholders in skill body text. -/// -/// If `$ARGUMENTS` appears in the body, all occurrences are replaced with the -/// argument string (or empty string if none). If `$ARGUMENTS` does not appear -/// and arguments were provided, they are appended on a new line. -fn substitute_arguments(body: &str, arguments: Option<&str>) -> String { - let args = arguments.unwrap_or(""); - - if body.contains("$ARGUMENTS") { - return body.replace("$ARGUMENTS", args); - } - - if !args.is_empty() { - return format!("{body}\n\nARGUMENTS: {args}"); - } - - body.to_string() -} - -/// Sanitize a directory name into a valid skill name. -fn sanitize_name(name: &str) -> String { - name.chars() - .map(|c| { - if c.is_ascii_alphanumeric() || c == '-' { - c - } else { - '-' - } - }) - .collect::() - .to_lowercase() -} - -/// Extract the first non-empty paragraph from markdown body text. -fn first_paragraph(body: &str) -> Option { - let trimmed = body.trim(); - if trimmed.is_empty() { - return None; - } - - let para: String = trimmed - .lines() - .take_while(|line| !line.trim().is_empty()) - .collect::>() - .join(" "); - - let para = para.trim().to_string(); - if para.is_empty() { None } else { Some(para) } -} - -/// Truncate a description to `max_len` characters, adding ellipsis if cut. -fn truncate_description(desc: &str, max_len: usize) -> String { - if desc.len() <= max_len { - return desc.to_string(); - } - let mut end = max_len.saturating_sub(3); - // Avoid splitting a multi-byte char - while !desc.is_char_boundary(end) && end > 0 { - end -= 1; - } - format!("{}...", &desc[..end]) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn sanitize_name_basic() { - assert_eq!(sanitize_name("My Skill"), "my-skill"); - assert_eq!(sanitize_name("deploy_prod"), "deploy-prod"); - assert_eq!(sanitize_name("code-review"), "code-review"); - } - - #[test] - fn first_paragraph_extraction() { - assert_eq!( - first_paragraph("Hello world\nSecond line\n\nNew paragraph"), - Some("Hello world Second line".to_string()) - ); - assert_eq!(first_paragraph(""), None); - assert_eq!(first_paragraph("\n\n"), None); - assert_eq!( - first_paragraph("Single line"), - Some("Single line".to_string()) - ); - } - - #[test] - fn truncate_description_short() { - assert_eq!(truncate_description("short", 100), "short"); - } - - #[test] - fn substitute_arguments_replaces_placeholder() { - let body = "Deploy $ARGUMENTS to production."; - assert_eq!( - substitute_arguments(body, Some("patch")), - "Deploy patch to production." - ); - } - - #[test] - fn substitute_arguments_multiple_occurrences() { - let body = "Run $ARGUMENTS then verify $ARGUMENTS worked."; - assert_eq!( - substitute_arguments(body, Some("migrate")), - "Run migrate then verify migrate worked." - ); - } - - #[test] - fn substitute_arguments_appends_when_no_placeholder() { - let body = "Do the thing."; - let result = substitute_arguments(body, Some("extra context")); - assert!(result.starts_with("Do the thing.")); - assert!(result.contains("ARGUMENTS: extra context")); - } - - #[test] - fn substitute_arguments_no_args_no_placeholder() { - let body = "Just a body."; - assert_eq!(substitute_arguments(body, None), "Just a body."); - } - - #[test] - fn substitute_arguments_no_args_clears_placeholder() { - let body = "Deploy $ARGUMENTS to production."; - assert_eq!(substitute_arguments(body, None), "Deploy to production."); - } - - #[test] - fn truncate_description_long() { - let long = "a".repeat(600); - let result = truncate_description(&long, 512); - assert!(result.len() <= 512); - assert!(result.ends_with("...")); - } - - #[test] - fn budget_packing() { - let registry = SkillRegistry { - skills: vec![ - SkillDescriptor { - name: "a".to_string(), - description: "Short desc".to_string(), - source_path: "a/SKILL.md".into(), - disable_model_invocation: false, - }, - SkillDescriptor { - name: "b".to_string(), - description: "Another desc".to_string(), - source_path: "b/SKILL.md".into(), - disable_model_invocation: false, - }, - ], - }; - - let (summaries, overflow) = registry.server_skills_with_budget(4096); - assert_eq!(summaries.len(), 2); - assert!(overflow.is_none()); - } - - #[test] - fn budget_overflow() { - let registry = SkillRegistry { - skills: vec![ - SkillDescriptor { - name: "first".to_string(), - description: "x".repeat(200), - source_path: "a/SKILL.md".into(), - disable_model_invocation: false, - }, - SkillDescriptor { - name: "second".to_string(), - description: "y".repeat(200), - source_path: "b/SKILL.md".into(), - disable_model_invocation: false, - }, - ], - }; - - // Budget only fits one - let (summaries, overflow) = registry.server_skills_with_budget(260); - assert_eq!(summaries.len(), 1); - assert_eq!(summaries[0].name, "first"); - let overflow = overflow.unwrap(); - assert!(overflow.contains("second")); - assert!(overflow.contains("1 additional")); - } - - #[test] - fn disabled_skills_excluded_from_server() { - let registry = SkillRegistry { - skills: vec![ - SkillDescriptor { - name: "visible".to_string(), - description: "I show up".to_string(), - source_path: "a/SKILL.md".into(), - disable_model_invocation: false, - }, - SkillDescriptor { - name: "hidden".to_string(), - description: "I don't".to_string(), - source_path: "b/SKILL.md".into(), - disable_model_invocation: true, - }, - ], - }; - - let (summaries, _) = registry.server_skills(); - assert_eq!(summaries.len(), 1); - assert_eq!(summaries[0].name, "visible"); - - // But all() includes both - assert_eq!(registry.all().len(), 2); - } - - #[test] - fn has_server_visible_skills() { - let empty = SkillRegistry::empty(); - assert!(!empty.has_server_visible_skills()); - - let all_disabled = SkillRegistry { - skills: vec![SkillDescriptor { - name: "hidden".to_string(), - description: String::new(), - source_path: "a/SKILL.md".into(), - disable_model_invocation: true, - }], - }; - assert!(!all_disabled.has_server_visible_skills()); - - let some_visible = SkillRegistry { - skills: vec![SkillDescriptor { - name: "visible".to_string(), - description: String::new(), - source_path: "a/SKILL.md".into(), - disable_model_invocation: false, - }], - }; - assert!(some_visible.has_server_visible_skills()); - } - - #[tokio::test] - async fn end_to_end_discover() { - let dir = tempfile::tempdir().unwrap(); - let skills_dir = dir.path().join("skills"); - - // Create a skill with frontmatter - let skill_dir = skills_dir.join("my-skill"); - std::fs::create_dir_all(&skill_dir).unwrap(); - std::fs::write( - skill_dir.join("SKILL.md"), - "---\nname: my-skill\ndescription: A test skill\n---\n\nBody here.\n", - ) - .unwrap(); - - // Create a skill with multiline description - let skill_dir2 = skills_dir.join("release"); - std::fs::create_dir_all(&skill_dir2).unwrap(); - std::fs::write( - skill_dir2.join("SKILL.md"), - "---\nname: release\ndescription: >\n Multi-line\n description here.\n---\n\nRelease steps.\n", - ) - .unwrap(); - - let registry = SkillRegistry::discover_from_dirs( - Some(&skills_dir), - &std::path::PathBuf::from("/nonexistent"), - ) - .await; - assert_eq!(registry.all().len(), 2); - - let my_skill = registry.get("my-skill").unwrap(); - assert_eq!(my_skill.description, "A test skill"); - - let release = registry.get("release").unwrap(); - assert!(release.description.contains("Multi-line")); - } -} diff --git a/crates/atuin-ai/src/skills/walker.rs b/crates/atuin-ai/src/skills/walker.rs deleted file mode 100644 index b93845f9..00000000 --- a/crates/atuin-ai/src/skills/walker.rs +++ /dev/null @@ -1,178 +0,0 @@ -//! Filesystem discovery for `SKILL.md` files. -//! -//! Recursively scans `.atuin/skills/` directories at the project and global -//! levels. Supports nested directories for organization (e.g. -//! `.atuin/skills/ops/deploy/SKILL.md`). - -use std::path::{Path, PathBuf}; - -const SKILL_FILENAME: &str = "SKILL.md"; - -/// A skill file found on disk, before body interpolation. -#[derive(Debug)] -pub(crate) struct RawSkillFile { - /// Full path to the SKILL.md file. - pub path: PathBuf, - /// The parent directory name, used as fallback skill name. - pub dir_name: String, - /// Whether this is a project-level skill (vs global). - #[allow(dead_code)] - pub is_project: bool, - /// Raw file content. - pub content: String, -} - -/// Discover all `SKILL.md` files across project and global skill directories. -/// -/// Project skills come first in the returned list (higher priority for -/// deduplication). -pub(crate) async fn discover( - project_skills_dir: Option<&Path>, - global_skills_dir: &Path, -) -> Vec { - let mut files = Vec::new(); - - // Project skills first (higher priority) - if let Some(dir) = project_skills_dir.filter(|d| d.is_dir()) { - scan_dir(dir, true, &mut files).await; - } - - // Global skills second - if global_skills_dir.is_dir() { - scan_dir(global_skills_dir, false, &mut files).await; - } - - files -} - -/// The default global skills directory (`~/.config/atuin/skills/`). -pub(crate) fn global_skills_dir() -> PathBuf { - atuin_common::utils::config_dir().join("skills") -} - -/// Given a project working directory, return the project skills directory. -pub(crate) fn project_skills_dir(project_root: &Path) -> PathBuf { - project_root.join(".atuin").join("skills") -} - -/// Recursively scan a directory for `SKILL.md` files. -async fn scan_dir(dir: &Path, is_project: bool, out: &mut Vec) { - let mut entries = match tokio::fs::read_dir(dir).await { - Ok(entries) => entries, - Err(e) => { - tracing::debug!("Could not read skills directory {}: {e}", dir.display()); - return; - } - }; - - let mut subdirs = Vec::new(); - - while let Ok(Some(entry)) = entries.next_entry().await { - let path = entry.path(); - - if path.is_dir() { - // Check for SKILL.md directly in this directory - let skill_path = path.join(SKILL_FILENAME); - if skill_path.is_file() { - let dir_name = path - .file_name() - .and_then(|n| n.to_str()) - .unwrap_or("unknown") - .to_string(); - - match tokio::fs::read_to_string(&skill_path).await { - Ok(content) => { - out.push(RawSkillFile { - path: skill_path, - dir_name, - is_project, - content, - }); - } - Err(e) => { - tracing::warn!("Failed to read skill file {}: {e}", skill_path.display()); - } - } - } - - // Collect subdirectories for recursive scanning - subdirs.push(path); - } - } - - for subdir in subdirs { - Box::pin(scan_dir(&subdir, is_project, out)).await; - } -} - -#[cfg(test)] -mod tests { - use super::*; - - fn setup_skill(dir: &Path, rel_path: &str, content: &str) { - let skill_dir = dir.join(rel_path); - std::fs::create_dir_all(&skill_dir).unwrap(); - std::fs::write(skill_dir.join(SKILL_FILENAME), content).unwrap(); - } - - #[tokio::test] - async fn discovers_project_skills() { - let dir = tempfile::tempdir().unwrap(); - let skills_dir = dir.path().join("skills"); - setup_skill(&skills_dir, "deploy", "---\nname: deploy\n---\nDeploy."); - - let files = discover(Some(&skills_dir), Path::new("/nonexistent")).await; - assert_eq!(files.len(), 1); - assert_eq!(files[0].dir_name, "deploy"); - assert!(files[0].is_project); - } - - #[tokio::test] - async fn discovers_global_skills() { - let dir = tempfile::tempdir().unwrap(); - let skills_dir = dir.path().join("skills"); - setup_skill(&skills_dir, "review", "---\nname: review\n---\nReview."); - - let files = discover(None, &skills_dir).await; - assert_eq!(files.len(), 1); - assert_eq!(files[0].dir_name, "review"); - assert!(!files[0].is_project); - } - - #[tokio::test] - async fn discovers_nested_skills() { - let dir = tempfile::tempdir().unwrap(); - let skills_dir = dir.path().join("skills"); - setup_skill(&skills_dir, "ops/deploy", "---\nname: deploy\n---\n"); - setup_skill(&skills_dir, "ops/rollback", "---\nname: rollback\n---\n"); - - let files = discover(Some(&skills_dir), Path::new("/nonexistent")).await; - assert_eq!(files.len(), 2); - } - - #[tokio::test] - async fn project_comes_before_global() { - let project = tempfile::tempdir().unwrap(); - let global = tempfile::tempdir().unwrap(); - let project_skills = project.path().join("skills"); - let global_skills = global.path().join("skills"); - - setup_skill(&project_skills, "a-skill", "project"); - setup_skill(&global_skills, "b-skill", "global"); - - let files = discover(Some(&project_skills), &global_skills).await; - assert_eq!(files.len(), 2); - assert!(files[0].is_project); - assert!(!files[1].is_project); - } - - #[tokio::test] - async fn missing_directories_handled() { - let files = discover( - Some(Path::new("/does/not/exist")), - Path::new("/also/missing"), - ) - .await; - assert!(files.is_empty()); - } -} diff --git a/crates/atuin-ai/src/snapshots.rs b/crates/atuin-ai/src/snapshots.rs deleted file mode 100644 index d46223a8..00000000 --- a/crates/atuin-ai/src/snapshots.rs +++ /dev/null @@ -1,414 +0,0 @@ -//! Backup snapshots for files before AI edits. -//! -//! Before the first edit to a file within a session, a snapshot of the -//! original content is saved so the user can recover if needed. Snapshots -//! are stored alongside a manifest that maps sanitized filenames back to -//! their original paths. -//! -//! Filenames use percent-encoding (`/` → `%2F`) so the snapshot directory -//! is human-readable via `ls`. - -use std::collections::HashMap; -use std::io::Write; -use std::path::{Path, PathBuf}; - -use eyre::{Result, eyre}; -use serde::{Deserialize, Serialize}; -use time::OffsetDateTime; - -/// Snapshot store for a single session. -/// -/// Each session gets its own directory under the snapshots root: -/// `/ai/snapshots//` -/// -/// Files are stored with percent-encoded filenames derived from their -/// canonical paths, alongside a `manifest.json` that maps filenames -/// back to original paths with timestamps. -#[derive(Debug)] -pub(crate) struct SnapshotStore { - session_dir: PathBuf, - manifest: SnapshotManifest, -} - -#[derive(Debug, Default, Serialize, Deserialize)] -struct SnapshotManifest { - files: HashMap, -} - -#[derive(Debug, Serialize, Deserialize)] -struct SnapshotEntry { - original_path: String, - snapshot_at: String, - size_bytes: u64, -} - -impl SnapshotStore { - /// Open or create a snapshot store for the given session directory. - /// - /// If a manifest already exists (from a prior CLI invocation in the same - /// session), it's loaded so we don't re-snapshot files that were already - /// backed up. - pub fn open(session_dir: PathBuf) -> Result { - let manifest_path = session_dir.join("manifest.json"); - let manifest = if manifest_path.exists() { - let data = fs_err::read_to_string(&manifest_path)?; - serde_json::from_str(&data)? - } else { - SnapshotManifest::default() - }; - - Ok(Self { - session_dir, - manifest, - }) - } - - /// Snapshot a file's contents if it hasn't been snapshotted yet this session. - /// - /// Returns `true` if a new snapshot was created, `false` if one already - /// existed. The `canonical_path` should be absolute (already tilde-expanded - /// and resolved). - pub fn ensure_snapshot(&mut self, canonical_path: &Path, content: &[u8]) -> Result { - let filename = sanitize_path(canonical_path); - - if self.manifest.files.contains_key(&filename) { - return Ok(false); - } - - fs_err::create_dir_all(&self.session_dir)?; - - let snapshot_path = self.session_dir.join(&filename); - atomic_write_file(&snapshot_path, content)?; - - let now = OffsetDateTime::now_utc(); - let entry = SnapshotEntry { - original_path: canonical_path.to_string_lossy().into_owned(), - snapshot_at: format_iso8601(now), - size_bytes: content.len() as u64, - }; - - self.manifest.files.insert(filename, entry); - self.save_manifest()?; - - Ok(true) - } - - /// Whether a file has already been snapshotted in this session. - #[cfg(test)] - pub fn has_snapshot(&self, canonical_path: &Path) -> bool { - let filename = sanitize_path(canonical_path); - self.manifest.files.contains_key(&filename) - } - - fn save_manifest(&self) -> Result<()> { - let json = serde_json::to_string_pretty(&self.manifest)?; - atomic_write_file(&self.session_dir.join("manifest.json"), json.as_bytes()) - } -} - -/// Percent-encode a path for use as a filename. -/// -/// Encodes `%` as `%25`, `/` as `%2F`, and `\` as `%5C`, then strips -/// leading separators and drive prefixes (e.g. `C:\`). The result is -/// always a flat filename safe for use with `Path::join` on any platform. -/// -/// Example (Unix): `/Users/me/.config/foo.toml` → `Users%2Fme%2F.config%2Ffoo.toml` -/// Example (Windows): `C:\Users\me\config.toml` → `Users%5Cme%5Cconfig.toml` -pub(crate) fn sanitize_path(path: &Path) -> String { - let s = path.to_string_lossy(); - // Strip drive letter prefix on Windows (e.g. "C:\") - let s = s.strip_prefix('/').unwrap_or_else(|| { - // Handle Windows drive prefix like "C:\" or "C:/" - if s.len() >= 3 - && s.as_bytes()[0].is_ascii_alphabetic() - && s.as_bytes()[1] == b':' - && (s.as_bytes()[2] == b'\\' || s.as_bytes()[2] == b'/') - { - &s[3..] - } else { - &s - } - }); - s.replace('%', "%25") - .replace('/', "%2F") - .replace('\\', "%5C") -} - -/// Write a file atomically using temp-file-then-rename. -/// -/// Creates a temporary file in the same directory as `target`, writes -/// content, fsyncs, then renames into place. Preserves permissions from -/// the original file if it exists. -pub(crate) fn atomic_write_file(target: &Path, content: &[u8]) -> Result<()> { - let dir = target - .parent() - .ok_or_else(|| eyre!("target path has no parent directory"))?; - fs_err::create_dir_all(dir)?; - - let mut tmp = tempfile::NamedTempFile::new_in(dir)?; - tmp.write_all(content)?; - tmp.as_file().sync_all()?; - - // Preserve permissions from original if it exists - if let Ok(meta) = std::fs::metadata(target) { - std::fs::set_permissions(tmp.path(), meta.permissions())?; - } - - tmp.persist(target).map_err(|e| { - eyre!( - "failed to persist atomic write to {}: {}", - target.display(), - e - ) - })?; - Ok(()) -} - -fn format_iso8601(dt: OffsetDateTime) -> String { - format!( - "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z", - dt.year(), - dt.month() as u8, - dt.day(), - dt.hour(), - dt.minute(), - dt.second(), - ) -} - -#[cfg(test)] -mod tests { - use super::*; - - // ── sanitize_path ────────────────────────────────────────── - - #[test] - fn sanitize_absolute_path() { - let path = Path::new("/Users/me/.config/atuin/config.toml"); - assert_eq!( - sanitize_path(path), - "Users%2Fme%2F.config%2Fatuin%2Fconfig.toml" - ); - } - - #[test] - fn sanitize_preserves_existing_percent() { - let path = Path::new("/data/100%done/file.txt"); - assert_eq!(sanitize_path(path), "data%2F100%25done%2Ffile.txt"); - } - - #[test] - fn sanitize_relative_path() { - let path = Path::new("relative/path.txt"); - assert_eq!(sanitize_path(path), "relative%2Fpath.txt"); - } - - #[test] - fn sanitize_no_collision_between_similar_paths() { - let a = sanitize_path(Path::new("/foo/bar-baz")); - let b = sanitize_path(Path::new("/foo/bar/baz")); - assert_ne!(a, b); - } - - #[test] - fn sanitize_backslash_encoded() { - // Windows-style path: backslashes become %5C, drive prefix stripped - let s = sanitize_path(Path::new("C:\\Users\\me\\config.toml")); - assert!(!s.contains('\\'), "backslashes must be encoded: {s}"); - assert!(!s.starts_with("C:"), "drive prefix must be stripped: {s}"); - assert!(s.contains("Users")); - assert!(s.contains("config.toml")); - } - - #[test] - fn sanitize_result_is_flat_filename() { - // The result must not be interpreted as a path with separators - // when passed to Path::join — no raw / or \ allowed. - let unix = sanitize_path(Path::new("/home/user/file.txt")); - assert!(!unix.contains('/')); - // Construct as if on Windows - let win = "C:\\Users\\me\\file.txt"; - let encoded = win - .strip_prefix("C:\\") - .unwrap() - .replace('%', "%25") - .replace('/', "%2F") - .replace('\\', "%5C"); - assert!(!encoded.contains('\\')); - assert!(!encoded.contains('/')); - } - - // ── atomic_write_file ────────────────────────────────────── - - #[test] - fn atomic_write_creates_file() { - let dir = tempfile::tempdir().unwrap(); - let target = dir.path().join("test.txt"); - - atomic_write_file(&target, b"hello world").unwrap(); - - assert_eq!(std::fs::read_to_string(&target).unwrap(), "hello world"); - } - - #[test] - fn atomic_write_overwrites_existing() { - let dir = tempfile::tempdir().unwrap(); - let target = dir.path().join("test.txt"); - - std::fs::write(&target, "old content").unwrap(); - atomic_write_file(&target, b"new content").unwrap(); - - assert_eq!(std::fs::read_to_string(&target).unwrap(), "new content"); - } - - #[test] - fn atomic_write_creates_parent_dirs() { - let dir = tempfile::tempdir().unwrap(); - let target = dir.path().join("sub").join("dir").join("test.txt"); - - atomic_write_file(&target, b"nested").unwrap(); - - assert_eq!(std::fs::read_to_string(&target).unwrap(), "nested"); - } - - #[cfg(unix)] - #[test] - fn atomic_write_preserves_permissions() { - use std::os::unix::fs::PermissionsExt; - - let dir = tempfile::tempdir().unwrap(); - let target = dir.path().join("test.txt"); - - std::fs::write(&target, "original").unwrap(); - std::fs::set_permissions(&target, std::fs::Permissions::from_mode(0o600)).unwrap(); - - atomic_write_file(&target, b"updated").unwrap(); - - let mode = std::fs::metadata(&target).unwrap().permissions().mode() & 0o777; - assert_eq!(mode, 0o600); - } - - // ── SnapshotStore ────────────────────────────────────────── - - #[test] - fn snapshot_creates_file_and_manifest() { - let dir = tempfile::tempdir().unwrap(); - let session_dir = dir.path().join("session-abc"); - let mut store = SnapshotStore::open(session_dir.clone()).unwrap(); - - let file_path = Path::new("/Users/me/.config/foo.toml"); - let created = store - .ensure_snapshot(file_path, b"[key]\nval = 1\n") - .unwrap(); - - assert!(created); - assert!(store.has_snapshot(file_path)); - - // Snapshot file on disk - let expected_file = session_dir.join("Users%2Fme%2F.config%2Ffoo.toml"); - assert!(expected_file.exists()); - assert_eq!( - std::fs::read_to_string(&expected_file).unwrap(), - "[key]\nval = 1\n" - ); - - // Manifest on disk - let manifest_path = session_dir.join("manifest.json"); - assert!(manifest_path.exists()); - let manifest: serde_json::Value = - serde_json::from_str(&std::fs::read_to_string(&manifest_path).unwrap()).unwrap(); - let files = manifest["files"].as_object().unwrap(); - assert_eq!(files.len(), 1); - let entry = &files["Users%2Fme%2F.config%2Ffoo.toml"]; - assert_eq!( - entry["original_path"].as_str().unwrap(), - "/Users/me/.config/foo.toml" - ); - assert_eq!(entry["size_bytes"].as_u64().unwrap(), 14); - } - - #[test] - fn snapshot_is_idempotent() { - let dir = tempfile::tempdir().unwrap(); - let session_dir = dir.path().join("session-abc"); - let mut store = SnapshotStore::open(session_dir.clone()).unwrap(); - - let path = Path::new("/etc/hosts"); - let first = store.ensure_snapshot(path, b"first content").unwrap(); - let second = store.ensure_snapshot(path, b"different content").unwrap(); - - assert!(first); - assert!(!second); - - // Original content preserved, not overwritten - let snapshot_file = session_dir.join("etc%2Fhosts"); - assert_eq!( - std::fs::read_to_string(snapshot_file).unwrap(), - "first content" - ); - } - - #[test] - fn snapshot_store_loads_existing_manifest() { - let dir = tempfile::tempdir().unwrap(); - let session_dir = dir.path().join("session-abc"); - - // First store: create a snapshot - { - let mut store = SnapshotStore::open(session_dir.clone()).unwrap(); - store - .ensure_snapshot(Path::new("/etc/hosts"), b"127.0.0.1") - .unwrap(); - } - - // Second store (simulates new CLI invocation): should see existing snapshot - { - let mut store = SnapshotStore::open(session_dir).unwrap(); - assert!(store.has_snapshot(Path::new("/etc/hosts"))); - - let created = store - .ensure_snapshot(Path::new("/etc/hosts"), b"new content") - .unwrap(); - assert!(!created); - } - } - - #[test] - fn snapshot_multiple_files() { - let dir = tempfile::tempdir().unwrap(); - let session_dir = dir.path().join("session-abc"); - let mut store = SnapshotStore::open(session_dir.clone()).unwrap(); - - store - .ensure_snapshot(Path::new("/etc/hosts"), b"hosts content") - .unwrap(); - store - .ensure_snapshot(Path::new("/Users/me/.bashrc"), b"bashrc content") - .unwrap(); - - assert!(store.has_snapshot(Path::new("/etc/hosts"))); - assert!(store.has_snapshot(Path::new("/Users/me/.bashrc"))); - assert!(!store.has_snapshot(Path::new("/nonexistent"))); - - // Both snapshot files exist - assert!(session_dir.join("etc%2Fhosts").exists()); - assert!(session_dir.join("Users%2Fme%2F.bashrc").exists()); - - // Manifest has both entries - let manifest: serde_json::Value = serde_json::from_str( - &std::fs::read_to_string(session_dir.join("manifest.json")).unwrap(), - ) - .unwrap(); - assert_eq!(manifest["files"].as_object().unwrap().len(), 2); - } - - #[test] - fn format_iso8601_produces_valid_format() { - let dt = OffsetDateTime::from_unix_timestamp(1700000000).unwrap(); - let formatted = format_iso8601(dt); - assert_eq!(formatted.len(), 20); - assert!(formatted.starts_with("2023-")); - assert!(formatted.contains('T')); - assert!(formatted.ends_with('Z')); - } -} diff --git a/crates/atuin-ai/src/store.rs b/crates/atuin-ai/src/store.rs deleted file mode 100644 index 20b9e881..00000000 --- a/crates/atuin-ai/src/store.rs +++ /dev/null @@ -1,554 +0,0 @@ -use std::path::Path; -use std::str::FromStr; -use std::time::Duration; - -use eyre::{Result, eyre}; -use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions}; -use time::OffsetDateTime; - -// Database row mappings — all columns are kept even if not yet read in -// non-test code, since they're part of the schema and used in tests. -#[derive(Debug)] -#[allow(dead_code)] -pub(crate) struct StoredSession { - pub id: String, - pub head_id: Option, - pub server_session_id: Option, - pub directory: Option, - pub git_root: Option, - pub created_at: i64, - pub updated_at: i64, - pub archived_at: Option, -} - -#[derive(Debug)] -#[allow(dead_code)] -pub(crate) struct StoredEvent { - pub id: String, - pub session_id: String, - pub parent_id: Option, - pub invocation_id: String, - pub event_type: String, - pub event_data: String, - pub created_at: i64, -} - -/// Row type returned by session queries (avoids clippy::type_complexity). -type SessionRow = ( - String, - Option, - Option, - Option, - Option, - i64, - i64, - Option, -); - -/// Row type returned by event queries. -type EventRow = (String, String, Option, String, String, String, i64); - -pub(crate) struct AiSessionStore { - pool: SqlitePool, -} - -impl AiSessionStore { - pub async fn new(path: impl AsRef, timeout: f64) -> Result { - let path = path.as_ref(); - let path_str = path - .as_os_str() - .to_str() - .ok_or_else(|| eyre!("AI session database path is not valid UTF-8: {path:?}"))?; - - let is_memory = path_str.contains(":memory:"); - - if !is_memory - && !path.exists() - && let Some(dir) = path.parent() - { - fs_err::create_dir_all(dir)?; - } - - let opts = SqliteConnectOptions::from_str(path_str)? - .journal_mode(SqliteJournalMode::Wal) - .optimize_on_close(true, None) - .create_if_missing(true); - - let pool = SqlitePoolOptions::new() - .acquire_timeout(Duration::from_secs_f64(timeout)) - .connect_with(opts) - .await?; - - sqlx::migrate!("./migrations").run(&pool).await?; - - #[cfg(unix)] - if !is_memory { - use std::os::unix::fs::PermissionsExt; - std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o600))?; - } - - Ok(Self { pool }) - } - - pub async fn create_session( - &self, - id: &str, - directory: Option<&str>, - git_root: Option<&str>, - ) -> Result { - let now = OffsetDateTime::now_utc().unix_timestamp(); - - sqlx::query( - "INSERT INTO sessions (id, directory, git_root, created_at, updated_at) - VALUES (?1, ?2, ?3, ?4, ?4)", - ) - .bind(id) - .bind(directory) - .bind(git_root) - .bind(now) - .execute(&self.pool) - .await?; - - Ok(StoredSession { - id: id.to_string(), - head_id: None, - server_session_id: None, - directory: directory.map(String::from), - git_root: git_root.map(String::from), - created_at: now, - updated_at: now, - archived_at: None, - }) - } - - #[allow(dead_code)] // used in tests; will be used by daemon service - pub async fn get_session(&self, id: &str) -> Result> { - let row: Option = sqlx::query_as( - "SELECT id, head_id, server_session_id, directory, git_root, - created_at, updated_at, archived_at - FROM sessions WHERE id = ?1", - ) - .bind(id) - .fetch_optional(&self.pool) - .await?; - - Ok(row.map( - |( - id, - head_id, - server_session_id, - directory, - git_root, - created_at, - updated_at, - archived_at, - )| { - StoredSession { - id, - head_id, - server_session_id, - directory, - git_root, - created_at, - updated_at, - archived_at, - } - }, - )) - } - - /// Find the most recent non-archived session matching the given directory or git - /// root, updated within `max_age_secs` seconds. - pub async fn find_resumable_session( - &self, - directory: Option<&str>, - git_root: Option<&str>, - max_age_secs: i64, - ) -> Result> { - let cutoff = OffsetDateTime::now_utc().unix_timestamp() - max_age_secs; - - let row: Option = sqlx::query_as( - "SELECT id, head_id, server_session_id, directory, git_root, - created_at, updated_at, archived_at - FROM sessions - WHERE archived_at IS NULL - AND updated_at > ?1 - AND (directory = ?2 OR (git_root IS NOT NULL AND git_root = ?3)) - ORDER BY updated_at DESC - LIMIT 1", - ) - .bind(cutoff) - .bind(directory) - .bind(git_root) - .fetch_optional(&self.pool) - .await?; - - Ok(row.map( - |( - id, - head_id, - server_session_id, - directory, - git_root, - created_at, - updated_at, - archived_at, - )| { - StoredSession { - id, - head_id, - server_session_id, - directory, - git_root, - created_at, - updated_at, - archived_at, - } - }, - )) - } - - /// Append a single event and update the session's `head_id` and `updated_at`. - pub async fn append_event( - &self, - session_id: &str, - event_id: &str, - parent_id: Option<&str>, - invocation_id: &str, - event_type: &str, - event_data: &str, - ) -> Result<()> { - let now = OffsetDateTime::now_utc().unix_timestamp(); - - let mut tx = self.pool.begin().await?; - - sqlx::query( - "INSERT INTO session_events (id, session_id, parent_id, invocation_id, event_type, event_data, created_at) - VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)", - ) - .bind(event_id) - .bind(session_id) - .bind(parent_id) - .bind(invocation_id) - .bind(event_type) - .bind(event_data) - .bind(now) - .execute(&mut *tx) - .await?; - - sqlx::query("UPDATE sessions SET head_id = ?1, updated_at = ?2 WHERE id = ?3") - .bind(event_id) - .bind(now) - .bind(session_id) - .execute(&mut *tx) - .await?; - - tx.commit().await?; - Ok(()) - } - - /// Load all events for a session, ordered chronologically. - pub async fn load_events(&self, session_id: &str) -> Result> { - let rows: Vec = sqlx::query_as( - "SELECT id, session_id, parent_id, invocation_id, event_type, event_data, created_at - FROM session_events - WHERE session_id = ?1 - ORDER BY created_at ASC, rowid ASC", - ) - .bind(session_id) - .fetch_all(&self.pool) - .await?; - - Ok(rows - .into_iter() - .map( - |(id, session_id, parent_id, invocation_id, event_type, event_data, created_at)| { - StoredEvent { - id, - session_id, - parent_id, - invocation_id, - event_type, - event_data, - created_at, - } - }, - ) - .collect()) - } - - pub async fn update_server_session_id( - &self, - session_id: &str, - server_session_id: &str, - ) -> Result<()> { - sqlx::query("UPDATE sessions SET server_session_id = ?1 WHERE id = ?2") - .bind(server_session_id) - .bind(session_id) - .execute(&self.pool) - .await?; - Ok(()) - } - - pub async fn archive_session(&self, session_id: &str) -> Result<()> { - let now = OffsetDateTime::now_utc().unix_timestamp(); - sqlx::query("UPDATE sessions SET archived_at = ?1 WHERE id = ?2") - .bind(now) - .bind(session_id) - .execute(&self.pool) - .await?; - Ok(()) - } - - // ── Session metadata (key-value per session) ── - - /// Read a metadata value for a session. Returns `None` if the key doesn't - /// exist or the session hasn't been persisted yet. - pub async fn get_metadata(&self, session_id: &str, key: &str) -> Result> { - let row: Option<(String,)> = - sqlx::query_as("SELECT value FROM session_metadata WHERE session_id = ?1 AND key = ?2") - .bind(session_id) - .bind(key) - .fetch_optional(&self.pool) - .await?; - - Ok(row.map(|(v,)| v)) - } - - /// Write a metadata value for a session (upsert). - pub async fn set_metadata(&self, session_id: &str, key: &str, value: &str) -> Result<()> { - let now = OffsetDateTime::now_utc().unix_timestamp(); - sqlx::query( - "INSERT INTO session_metadata (session_id, key, value, updated_at) - VALUES (?1, ?2, ?3, ?4) - ON CONFLICT (session_id, key) DO UPDATE SET value = ?3, updated_at = ?4", - ) - .bind(session_id) - .bind(key) - .bind(value) - .bind(now) - .execute(&self.pool) - .await?; - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - async fn new_test_store() -> AiSessionStore { - AiSessionStore::new("sqlite::memory:", 2.0).await.unwrap() - } - - #[tokio::test] - async fn test_create_and_get_session() { - let store = new_test_store().await; - - let session = store - .create_session("s1", Some("/home/user/project"), Some("/home/user/project")) - .await - .unwrap(); - assert_eq!(session.id, "s1"); - assert!(session.head_id.is_none()); - assert!(session.archived_at.is_none()); - - let loaded = store.get_session("s1").await.unwrap().unwrap(); - assert_eq!(loaded.id, "s1"); - assert_eq!(loaded.directory.as_deref(), Some("/home/user/project")); - } - - #[tokio::test] - async fn test_get_nonexistent_session() { - let store = new_test_store().await; - assert!(store.get_session("nope").await.unwrap().is_none()); - } - - #[tokio::test] - async fn test_append_and_load_events() { - let store = new_test_store().await; - store - .create_session("s1", Some("/tmp"), None) - .await - .unwrap(); - - store - .append_event( - "s1", - "e1", - None, - "inv1", - "user_message", - r#"{"content":"hello"}"#, - ) - .await - .unwrap(); - store - .append_event( - "s1", - "e2", - Some("e1"), - "inv1", - "text", - r#"{"content":"hi there"}"#, - ) - .await - .unwrap(); - - let events = store.load_events("s1").await.unwrap(); - assert_eq!(events.len(), 2); - assert_eq!(events[0].id, "e1"); - assert!(events[0].parent_id.is_none()); - assert_eq!(events[0].invocation_id, "inv1"); - assert_eq!(events[1].id, "e2"); - assert_eq!(events[1].parent_id.as_deref(), Some("e1")); - - let session = store.get_session("s1").await.unwrap().unwrap(); - assert_eq!(session.head_id.as_deref(), Some("e2")); - } - - #[tokio::test] - async fn test_find_resumable_session() { - let store = new_test_store().await; - store - .create_session("s1", Some("/home/user/project"), None) - .await - .unwrap(); - - let found = store - .find_resumable_session(Some("/home/user/project"), None, 3600) - .await - .unwrap(); - assert!(found.is_some()); - assert_eq!(found.unwrap().id, "s1"); - } - - #[tokio::test] - async fn test_find_resumable_by_git_root() { - let store = new_test_store().await; - store - .create_session( - "s1", - Some("/home/user/project/sub"), - Some("/home/user/project"), - ) - .await - .unwrap(); - - let found = store - .find_resumable_session(Some("/different/dir"), Some("/home/user/project"), 3600) - .await - .unwrap(); - assert!(found.is_some()); - assert_eq!(found.unwrap().id, "s1"); - } - - #[tokio::test] - async fn test_find_resumable_skips_archived() { - let store = new_test_store().await; - store - .create_session("s1", Some("/tmp"), None) - .await - .unwrap(); - store.archive_session("s1").await.unwrap(); - - let found = store - .find_resumable_session(Some("/tmp"), None, 3600) - .await - .unwrap(); - assert!(found.is_none()); - } - - #[tokio::test] - async fn test_find_resumable_no_match_different_dir() { - let store = new_test_store().await; - store - .create_session("s1", Some("/home/user/project"), None) - .await - .unwrap(); - - let found = store - .find_resumable_session(Some("/other/dir"), None, 3600) - .await - .unwrap(); - assert!(found.is_none()); - } - - #[tokio::test] - async fn test_archive_session() { - let store = new_test_store().await; - store - .create_session("s1", Some("/tmp"), None) - .await - .unwrap(); - - store.archive_session("s1").await.unwrap(); - - let session = store.get_session("s1").await.unwrap().unwrap(); - assert!(session.archived_at.is_some()); - } - - #[tokio::test] - async fn test_update_server_session_id() { - let store = new_test_store().await; - store - .create_session("s1", Some("/tmp"), None) - .await - .unwrap(); - - store - .update_server_session_id("s1", "server-abc") - .await - .unwrap(); - - let session = store.get_session("s1").await.unwrap().unwrap(); - assert_eq!(session.server_session_id.as_deref(), Some("server-abc")); - } - - #[tokio::test] - async fn test_find_resumable_does_not_mutate() { - let store = new_test_store().await; - store - .create_session("s1", Some("/tmp"), None) - .await - .unwrap(); - - let before = store.get_session("s1").await.unwrap().unwrap(); - store - .find_resumable_session(Some("/tmp"), None, 3600) - .await - .unwrap() - .unwrap(); - let after = store.get_session("s1").await.unwrap().unwrap(); - - assert_eq!(before.updated_at, after.updated_at); - } - - #[tokio::test] - async fn test_events_ordered_chronologically() { - let store = new_test_store().await; - store - .create_session("s1", Some("/tmp"), None) - .await - .unwrap(); - - store - .append_event("s1", "e1", None, "inv1", "user_message", "{}") - .await - .unwrap(); - store - .append_event("s1", "e2", Some("e1"), "inv1", "text", "{}") - .await - .unwrap(); - store - .append_event("s1", "e3", Some("e2"), "inv2", "user_message", "{}") - .await - .unwrap(); - - let events = store.load_events("s1").await.unwrap(); - assert_eq!(events.len(), 3); - assert!(events[0].created_at <= events[1].created_at); - assert!(events[1].created_at <= events[2].created_at); - assert_eq!(events[2].invocation_id, "inv2"); - } -} diff --git a/crates/atuin-ai/src/stream.rs b/crates/atuin-ai/src/stream.rs deleted file mode 100644 index e78dc2e1..00000000 --- a/crates/atuin-ai/src/stream.rs +++ /dev/null @@ -1,288 +0,0 @@ -// ─────────────────────────────────────────────────────────────────── -// SSE streaming -// ─────────────────────────────────────────────────────────────────── - -use atuin_client::history::History; -use atuin_client::settings::AiCapabilities; - -use crate::context::history_output_capability_available; -use atuin_common::tls::ensure_crypto_provider; - -use eventsource_stream::Eventsource; -use eyre::{Context, Result}; -use futures::StreamExt; -use reqwest::Url; -use reqwest::header::USER_AGENT; - -use crate::context::ClientContext; - -static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION")); - -/// Frames that alter the stream lifecycle — terminal or state-changing. -#[derive(Debug, Clone)] -pub(crate) enum StreamControl { - Done { session_id: String }, - Error(String), - StatusChanged(String), -} - -/// Frames that carry conversation content — they mutate the event log. -#[derive(Debug, Clone)] -pub(crate) enum StreamContent { - TextChunk(String), - ToolCall { - id: String, - name: String, - input: serde_json::Value, - }, - ToolResult { - tool_use_id: String, - content: String, - is_error: bool, - remote: bool, - content_length: Option, - }, -} - -/// A frame from the SSE stream, classified as control or content. -#[derive(Debug, Clone)] -pub(crate) enum StreamFrame { - Content(StreamContent), - Control(StreamControl), -} - -/// Per-turn request payload for the chat API. -pub(crate) struct ChatRequest { - pub messages: Vec, - pub session_id: Option, - pub capabilities: Vec, - pub invocation_id: String, -} - -impl ChatRequest { - pub(crate) fn new( - messages: Vec, - session_id: Option, - capabilities: &AiCapabilities, - history_output_available: bool, - invocation_id: String, - ) -> Self { - let mut caps = vec![ - "client_invocations".to_string(), - "client_v1_load_skill".to_string(), - ]; - if capabilities.enable_history_search.unwrap_or(true) { - caps.push("client_v1_atuin_history".to_string()); - } - if capabilities.enable_file_tools.unwrap_or(true) { - caps.push("client_v1_read_file".to_string()); - caps.push("client_v1_edit_file".to_string()); - caps.push("client_v1_write_file".to_string()); - } - if capabilities.enable_command_execution.unwrap_or(true) { - caps.push("client_v1_execute_shell_command".to_string()); - } - if history_output_capability_available(history_output_available) - && capabilities.enable_history_output.unwrap_or(true) - { - caps.push("client_v1_atuin_output".to_string()); - } - if let Ok(extra) = std::env::var("ATUIN_AI__ADDITIONAL_CAPS") { - caps.extend( - extra - .split(',') - .map(|s| s.trim().to_string()) - .filter(|s| !s.is_empty()), - ); - } - - Self { - messages, - session_id, - capabilities: caps, - invocation_id, - } - } -} - -#[allow(clippy::too_many_arguments)] -pub(crate) fn create_chat_stream( - hub_address: String, - token: String, - request: ChatRequest, - client_ctx: ClientContext, - send_cwd: bool, - last_command: Option, - user_contexts: Vec, - skill_summaries: Vec, - skill_overflow: Option, -) -> std::pin::Pin> + Send>> { - Box::pin(async_stream::stream! { - ensure_crypto_provider(); - let endpoint = match hub_url(&hub_address, "/api/cli/chat") { - Ok(url) => url, - Err(e) => { - yield Err(e); - return; - } - }; - - tracing::debug!("Sending SSE request to {endpoint}"); - - let context = client_ctx.to_json(send_cwd, last_command.as_ref()); - - let mut config = serde_json::json!({ - "capabilities": request.capabilities, - }); - - if !user_contexts.is_empty() { - config["user_contexts"] = serde_json::json!(user_contexts); - } - - if !skill_summaries.is_empty() { - config["skills"] = serde_json::json!(skill_summaries); - if let Some(ref overflow) = skill_overflow { - config["skills_overflow"] = serde_json::json!(overflow); - } - } - - if let Ok(model) = std::env::var("ATUIN_AI__MODEL") - && !model.trim().is_empty() { - config["model"] = serde_json::json!(model.trim()); - - } - - - let mut request_body = serde_json::json!({ - "messages": request.messages, - "context": context, - "config": config, - "invocation_id": request.invocation_id - }); - - if let Some(ref sid) = request.session_id { - tracing::trace!("Including session_id in request: {sid}"); - request_body["session_id"] = serde_json::json!(sid); - } - - let client = reqwest::Client::new(); - let response = match client - .post(endpoint.clone()) - .header("Accept", "text/event-stream") - .header(USER_AGENT, APP_USER_AGENT) - .bearer_auth(&token) - .json(&request_body) - .send() - .await - { - Ok(resp) => resp, - Err(e) => { - yield Err(eyre::eyre!("Failed to send SSE request: {}", e)); - return; - } - }; - - let status = response.status(); - if status == reqwest::StatusCode::UNAUTHORIZED { - tracing::error!("SSE request failed with status: {status}, clearing session"); - let _ = atuin_client::hub::delete_session().await; - yield Err(eyre::eyre!("Hub session expired. Re-run to authenticate again.")); - return; - } - if !status.is_success() { - let body = response.text().await.unwrap_or_default(); - tracing::error!("SSE request failed ({}): {}", status, body); - yield Err(eyre::eyre!("SSE request failed ({}): {}", status, body)); - return; - } - - let byte_stream = response.bytes_stream(); - let mut stream = byte_stream.eventsource(); - - while let Some(event) = stream.next().await { - match event { - Ok(sse_event) => { - let event_type = sse_event.event.as_str(); - let data = sse_event.data.clone(); - - tracing::debug!(event_type = %event_type, "SSE event received"); - - match event_type { - "text" => { - if let Ok(json) = serde_json::from_str::(&data) - && let Some(content) = json.get("content").and_then(|v| v.as_str()) - { - yield Ok(StreamFrame::Content(StreamContent::TextChunk(content.to_string()))); - } - } - "tool_call" => { - if let Ok(json) = serde_json::from_str::(&data) { - let id = json.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let name = json.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let input = json.get("input").cloned().unwrap_or(serde_json::json!({})); - yield Ok(StreamFrame::Content(StreamContent::ToolCall { id, name, input })); - } - } - "tool_result" => { - if let Ok(json) = serde_json::from_str::(&data) { - let tool_use_id = json.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let content = json.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let is_error = json.get("is_error").and_then(|v| v.as_bool()).unwrap_or(false); - let remote = json.get("remote").and_then(|v| v.as_bool()).unwrap_or(false); - let content_length = json.get("content_length").and_then(|v| v.as_u64()).map(|v| v as usize); - yield Ok(StreamFrame::Content(StreamContent::ToolResult { tool_use_id, content, is_error, remote, content_length })); - } - } - "status" => { - if let Ok(json) = serde_json::from_str::(&data) - && let Some(state) = json.get("state").and_then(|v| v.as_str()) - { - yield Ok(StreamFrame::Control(StreamControl::StatusChanged(state.to_string()))); - } - } - "done" => { - if let Ok(json) = serde_json::from_str::(&data) { - let session_id = json.get("session_id") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - yield Ok(StreamFrame::Control(StreamControl::Done { session_id })); - } else { - yield Ok(StreamFrame::Control(StreamControl::Done { session_id: String::new() })); - } - break; - } - "error" => { - if let Ok(json) = serde_json::from_str::(&data) { - let message = json.get("message").and_then(|v| v.as_str()).unwrap_or("Unknown error").to_string(); - tracing::error!("SSE error: {}", message); - yield Ok(StreamFrame::Control(StreamControl::Error(message))); - } else { - tracing::error!("SSE error: {}", data); - yield Ok(StreamFrame::Control(StreamControl::Error(data))); - } - break; - } - _ => {} - } - } - Err(e) => { - yield Err(eyre::eyre!("SSE error: {}", e)); - break; - } - } - } - }) -} - -fn hub_url(base: &str, path: &str) -> Result { - let base_with_slash = if base.ends_with('/') { - base.to_string() - } else { - format!("{base}/") - }; - let stripped = path.strip_prefix('/').unwrap_or(path); - Url::parse(&base_with_slash)? - .join(stripped) - .context("failed to build hub URL") -} diff --git a/crates/atuin-ai/src/tools/descriptor.rs b/crates/atuin-ai/src/tools/descriptor.rs deleted file mode 100644 index 4190540c..00000000 --- a/crates/atuin-ai/src/tools/descriptor.rs +++ /dev/null @@ -1,129 +0,0 @@ -/// Centralized metadata for a tool type. -/// -/// Covers both client-side tools (ones the CLI executes locally) and -/// server-side tools (ones the API executes remotely). This is the single -/// source of truth for display text and classification. -pub(crate) struct ToolDescriptor { - /// Canonical wire names for this tool (the names the server sends). - pub canonical_names: &'static [&'static str], - /// The capability string the client must advertise for this tool to be - /// accepted. `None` for server-side tools (always accepted). - pub capability: Option<&'static str>, - /// Imperative verb for permission prompts (e.g. "read", "run"). - pub display_verb: &'static str, - /// Present-tense progressive verb for spinners (e.g. "Reading file..."). - pub progressive_verb: &'static str, - /// Past-tense verb for summaries (e.g. "Read file"). - pub past_verb: &'static str, - /// Whether this tool is executed client-side (by the CLI). - #[expect(dead_code)] - pub is_client: bool, -} - -// ── Client-side tool descriptors ── - -pub(crate) const READ: &ToolDescriptor = &ToolDescriptor { - canonical_names: &["read_file"], - capability: Some("client_v1_read_file"), - display_verb: "read", - progressive_verb: "Reading file...", - past_verb: "Read file", - is_client: true, -}; - -pub(crate) const EDIT: &ToolDescriptor = &ToolDescriptor { - canonical_names: &["edit_file"], - capability: Some("client_v1_edit_file"), - display_verb: "edit", - progressive_verb: "Editing file...", - past_verb: "Edited file", - is_client: true, -}; - -pub(crate) const WRITE: &ToolDescriptor = &ToolDescriptor { - canonical_names: &["write_file"], - capability: Some("client_v1_write_file"), - display_verb: "write to", - progressive_verb: "Writing file...", - past_verb: "Wrote file", - is_client: true, -}; - -pub(crate) const SHELL: &ToolDescriptor = &ToolDescriptor { - canonical_names: &["execute_shell_command"], - capability: Some("client_v1_execute_shell_command"), - display_verb: "run", - progressive_verb: "Running command...", - past_verb: "Ran command", - is_client: true, -}; - -pub(crate) const ATUIN_HISTORY: &ToolDescriptor = &ToolDescriptor { - canonical_names: &["atuin_history"], - capability: Some("client_v1_atuin_history"), - display_verb: "search your Atuin history for", - progressive_verb: "Searching...", - past_verb: "Searched", - is_client: true, -}; - -pub(crate) const ATUIN_OUTPUT: &ToolDescriptor = &ToolDescriptor { - canonical_names: &["atuin_output"], - capability: Some("client_v1_atuin_output"), - display_verb: "view the output for command", - progressive_verb: "Viewing output...", - past_verb: "Viewed output", - is_client: true, -}; - -pub(crate) const LOAD_SKILL: &ToolDescriptor = &ToolDescriptor { - canonical_names: &["load_skill"], - capability: Some("client_v1_load_skill"), - display_verb: "load skill", - progressive_verb: "Loading skill...", - past_verb: "Loaded skill", - is_client: true, -}; - -// ── Server-side tool descriptors ── -// These appear in tool summaries but aren't client-side tools. - -pub(crate) const SERVER_SEARCH: &ToolDescriptor = &ToolDescriptor { - canonical_names: &["web_search"], - capability: None, - display_verb: "search", - progressive_verb: "Searching...", - past_verb: "Searched", - is_client: false, -}; - -pub(crate) const SERVER_SCRAPE: &ToolDescriptor = &ToolDescriptor { - canonical_names: &["web_scrape"], - capability: None, - display_verb: "scrape", - progressive_verb: "Scraping...", - past_verb: "Scraped", - is_client: false, -}; - -/// All known tool descriptors, for lookup by name. -const ALL_DESCRIPTORS: &[&ToolDescriptor] = &[ - READ, - EDIT, - WRITE, - SHELL, - ATUIN_HISTORY, - ATUIN_OUTPUT, - LOAD_SKILL, - SERVER_SEARCH, - SERVER_SCRAPE, -]; - -/// Look up a tool descriptor by its canonical wire name. -/// Returns None for unknown tool names. -pub(crate) fn by_name(name: &str) -> Option<&'static ToolDescriptor> { - ALL_DESCRIPTORS - .iter() - .find(|d| d.canonical_names.contains(&name)) - .copied() -} diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs deleted file mode 100644 index d1352661..00000000 --- a/crates/atuin-ai/src/tools/mod.rs +++ /dev/null @@ -1,2159 +0,0 @@ -use std::{ - io::BufRead, - path::{Path, PathBuf}, - time::Duration, -}; - -use eyre::Result; -use uuid::Uuid; - -const DEFAULT_FILE_READ_LINES: u64 = 100; -const MAX_FILE_READ_LINES: u64 = 1000; - -pub(crate) mod descriptor; - -use crate::permissions::rule::Rule; - -/// Check whether a file path matches a scope glob pattern. -/// -/// Resolves relative paths against the current directory before matching so -/// that `./foo.md` and `/cwd/foo.md` match the same glob. Supports `*`, `**`, -/// `?`, and `[...]` via `glob_match`. -fn path_matches_scope(path: &Path, scope: &str) -> bool { - let path = if path.is_relative() { - std::env::current_dir() - .map(|cwd| cwd.join(path)) - .unwrap_or_else(|_| path.to_path_buf()) - } else { - path.to_path_buf() - }; - // Normalize to forward slashes so globs work on Windows too. - let path_str = path.to_string_lossy().replace('\\', "/"); - - // If the scope is also relative, try matching against both the absolute - // path and just the filename/relative portion. - if !scope.starts_with('/') { - // Match against filename (e.g. "*.md" matches any .md file) - if let Some(name) = path.file_name().and_then(|n| n.to_str()) - && glob_match::glob_match(scope, name) - { - return true; - } - // Also try matching against the full absolute path in case the scope - // is a relative multi-segment pattern like "crates/**/*.rs" - if glob_match::glob_match(scope, &path_str) { - return true; - } - // And match relative to cwd (so "src/*.rs" works from project root) - if let Ok(cwd) = std::env::current_dir() - && let Ok(rel) = path.strip_prefix(&cwd) - { - let rel_str = rel.to_string_lossy().replace('\\', "/"); - return glob_match::glob_match(scope, &rel_str); - } - return false; - } - - // Absolute scope — match against absolute path - glob_match::glob_match(scope, &path_str) -} - -/// Result of executing a client-side tool. -#[derive(Debug, Clone)] -pub(crate) enum ToolOutcome { - /// Simple success with a text result (used by Read, AtuinHistory). - Success(String), - /// Error with a message. - Error(String), - /// Structured shell result with separated stdout, stderr, exit code, and duration. - Structured { - stdout: String, - stderr: String, - exit_code: Option, - duration_ms: u64, - interrupted: bool, - }, -} - -impl ToolOutcome { - /// Format this outcome as a string for the tool result sent to the LLM. - /// - /// The optional `interrupt_reason` overrides the generic interrupted message - /// with a specific one (user interrupt vs timeout). - pub fn format_for_llm( - &self, - interrupt_reason: Option<&crate::fsm::tools::InterruptReason>, - ) -> String { - match self { - ToolOutcome::Success(s) => s.clone(), - ToolOutcome::Error(e) => e.clone(), - ToolOutcome::Structured { - stdout, - stderr, - exit_code, - duration_ms, - interrupted, - } => { - let mut parts = Vec::new(); - - if let Some(code) = exit_code { - parts.push(format!("Exit code: {code}")); - } - - parts.push(format!("Duration: {duration_ms}ms")); - - if !stdout.is_empty() { - parts.push(format!("stdout:\n{stdout}")); - } else { - parts.push("stdout: (empty)".to_string()); - } - - if !stderr.is_empty() { - parts.push(format!("stderr:\n{stderr}")); - } else { - parts.push("stderr: (empty)".to_string()); - } - - if *interrupted { - use crate::fsm::tools::InterruptReason; - let msg = match interrupt_reason { - Some(InterruptReason::Timeout(secs)) => { - format!("[Timed out after {secs}s]") - } - _ => "[Interrupted by user]".to_string(), - }; - parts.push(msg); - } - - parts.join("\n\n") - } - } - } - - /// Whether this outcome represents an error. - pub fn is_error(&self) -> bool { - match self { - ToolOutcome::Error(_) => true, - ToolOutcome::Structured { - exit_code: Some(code), - .. - } if *code != 0 => true, - _ => false, - } - } -} - -/// Cached VT100 preview data for a shell tool call. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) struct ToolPreview { - pub lines: Vec, - pub exit_code: Option, - pub interrupted: Option, -} - -/// A tool call from the server, with parsed input parameters. -#[derive(Debug, Clone)] -pub(crate) enum ClientToolCall { - Read(ReadToolCall), - Edit(EditToolCall), - Write(WriteToolCall), - Shell(ShellToolCall), - AtuinHistory(AtuinHistoryToolCall), - AtuinOutput(AtuinOutputToolCall), - LoadSkill(LoadSkillToolCall), -} - -impl TryFrom<(&str, &serde_json::Value)> for ClientToolCall { - type Error = eyre::Error; - - fn try_from((name, input): (&str, &serde_json::Value)) -> Result { - match name { - "read_file" => Ok(ClientToolCall::Read(ReadToolCall::try_from(input)?)), - "edit_file" => Ok(ClientToolCall::Edit(EditToolCall::try_from(input)?)), - "write_file" => Ok(ClientToolCall::Write(WriteToolCall::try_from(input)?)), - "execute_shell_command" => Ok(ClientToolCall::Shell(ShellToolCall::try_from(input)?)), - "atuin_history" => Ok(ClientToolCall::AtuinHistory( - AtuinHistoryToolCall::try_from(input)?, - )), - "atuin_output" => Ok(ClientToolCall::AtuinOutput(AtuinOutputToolCall::try_from( - input, - )?)), - "load_skill" => Ok(ClientToolCall::LoadSkill(LoadSkillToolCall::try_from( - input, - )?)), - _ => Err(eyre::eyre!("Unknown tool call: {name}")), - } - } -} - -impl ClientToolCall { - pub(crate) fn descriptor(&self) -> &'static descriptor::ToolDescriptor { - match self { - ClientToolCall::Read(_) => descriptor::READ, - ClientToolCall::Edit(_) => descriptor::EDIT, - ClientToolCall::Write(_) => descriptor::WRITE, - ClientToolCall::Shell(_) => descriptor::SHELL, - ClientToolCall::AtuinHistory(_) => descriptor::ATUIN_HISTORY, - ClientToolCall::AtuinOutput(_) => descriptor::ATUIN_OUTPUT, - ClientToolCall::LoadSkill(_) => descriptor::LOAD_SKILL, - } - } - - /// The permission rule name for this tool category. - /// - /// Edit and Write share the `"Write"` rule name — a Write permission - /// covers both str_replace edits and full file creates. Write also - /// implies Read (checked in `ReadToolCall::matches_rule`). - pub(crate) fn rule_name(&self) -> &'static str { - match self { - ClientToolCall::Read(_) => "Read", - ClientToolCall::Edit(_) => "Write", - ClientToolCall::Write(_) => "Write", - ClientToolCall::Shell(_) => "Shell", - ClientToolCall::AtuinHistory(_) => "AtuinHistory", - ClientToolCall::AtuinOutput(_) => "AtuinOutput", - ClientToolCall::LoadSkill(_) => "LoadSkill", - } - } - - /// The resolved file path for this tool call, if it's a file-based tool. - /// Used to build scoped permission rules like `Write(/abs/path/to/file)`. - pub(crate) fn resolved_file_path(&self) -> Option { - match self { - ClientToolCall::Read(tool) => Some(tool.resolved_path()), - ClientToolCall::Edit(tool) => Some(tool.resolved_path()), - ClientToolCall::Write(tool) => Some(tool.resolved_path()), - ClientToolCall::Shell(_) - | ClientToolCall::AtuinHistory(_) - | ClientToolCall::AtuinOutput(_) - | ClientToolCall::LoadSkill(_) => None, - } - } - - pub(crate) fn matches_rule(&self, rule: &Rule) -> bool { - match self { - ClientToolCall::Read(tool) => tool.matches_rule(rule), - ClientToolCall::Edit(tool) => tool.matches_rule(rule), - ClientToolCall::Write(tool) => tool.matches_rule(rule), - ClientToolCall::Shell(tool) => tool.matches_rule(rule), - ClientToolCall::AtuinHistory(tool) => tool.matches_rule(rule), - ClientToolCall::AtuinOutput(tool) => tool.matches_rule(rule), - ClientToolCall::LoadSkill(tool) => tool.matches_rule(rule), - } - } - - pub(crate) fn target_dir(&self) -> Option<&Path> { - match self { - ClientToolCall::Read(tool) => tool.target_dir(), - ClientToolCall::Edit(tool) => tool.target_dir(), - ClientToolCall::Write(tool) => tool.target_dir(), - ClientToolCall::Shell(tool) => tool.target_dir(), - ClientToolCall::AtuinHistory(tool) => tool.target_dir(), - ClientToolCall::AtuinOutput(tool) => tool.target_dir(), - ClientToolCall::LoadSkill(tool) => tool.target_dir(), - } - } -} - -/// A trait for tool calls that can be checked against permission rules. -pub(crate) trait PermissibleToolCall { - /// Checks if this tool call matches the given permission rule. - fn matches_rule(&self, rule: &Rule) -> bool; - - /// Check if every part of this tool call is covered by at least one rule in - /// the set. For compound operations (e.g. shell pipelines), each sub-part - /// must be individually covered. The default treats the call as atomic — - /// any single matching rule is sufficient. - fn all_covered_by(&self, rules: &[Rule]) -> bool { - rules.iter().any(|r| self.matches_rule(r)) - } - - /// Returns the target directory of this tool call, if applicable, for checking against directory-based rules. - fn target_dir(&self) -> Option<&Path> { - None - } -} - -impl PermissibleToolCall for ClientToolCall { - fn matches_rule(&self, rule: &Rule) -> bool { - self.matches_rule(rule) - } - - fn all_covered_by(&self, rules: &[Rule]) -> bool { - match self { - ClientToolCall::Shell(tool) => tool.all_covered_by(rules), - // LoadSkill is always auto-approved, but support rules for completeness - _ => rules.iter().any(|r| self.matches_rule(r)), - } - } - - fn target_dir(&self) -> Option<&Path> { - self.target_dir() - } -} - -/// Returns true if this tool call should bypass the permission system entirely. -impl ClientToolCall { - pub(crate) fn is_auto_approved(&self) -> bool { - matches!(self, ClientToolCall::LoadSkill(_)) - } -} - -/// Expand shell constructs (`~`, `$HOME`, etc.) in a path string. -/// -/// Tool call paths arrive as raw strings from the API without shell -/// expansion. Uses `shellexpand` (same as `atuin-client`). -fn expand_path(path: &str) -> PathBuf { - PathBuf::from(shellexpand::tilde(path).into_owned()) -} - -#[derive(Debug, Clone)] -pub(crate) struct ReadToolCall { - pub path: PathBuf, - pub offset: u64, - pub limit: u64, -} - -impl TryFrom<&serde_json::Value> for ReadToolCall { - type Error = eyre::Error; - - fn try_from(value: &serde_json::Value) -> Result { - let path = value - .get("file_path") - .and_then(|v| v.as_str()) - .ok_or(eyre::eyre!("Missing path"))?; - - let offset = value.get("offset").and_then(|v| v.as_u64()).unwrap_or(0); - let limit = value - .get("limit") - .and_then(|v| v.as_u64()) - .unwrap_or(DEFAULT_FILE_READ_LINES) - .min(MAX_FILE_READ_LINES); - - Ok(ReadToolCall { - path: expand_path(path), - offset, - limit, - }) - } -} - -impl ReadToolCall { - pub fn resolved_path(&self) -> PathBuf { - if self.path.is_relative() { - std::env::current_dir() - .map(|cwd| cwd.join(&self.path)) - .unwrap_or_else(|_| self.path.clone()) - } else { - self.path.clone() - } - } - - pub fn execute(&self) -> ToolOutcome { - let path = self.resolved_path(); - - if !path.exists() { - return ToolOutcome::Error(format!("Error: file does not exist: {}", path.display())); - } - - if path.is_dir() { - let Some(files) = std::fs::read_dir(&path).ok().and_then(|entries| { - entries - .filter_map(|entry| entry.ok()) - .map(|entry| entry.file_name().to_string_lossy().to_string()) - .collect::>() - .into() - }) else { - return ToolOutcome::Error(format!( - "Error: could not read directory: {}", - path.display() - )); - }; - - return ToolOutcome::Success(format!("Directory contents:\n{}", files.join("\n"))); - } - - let file = match std::fs::File::open(&path) { - Ok(file) => file, - Err(e) => return ToolOutcome::Error(format!("Error opening file: {e}")), - }; - let reader = std::io::BufReader::new(file); - - let raw_lines = reader - .lines() - .skip(self.offset as usize) - .take(self.limit as usize) - .collect::, _>>(); - - match raw_lines { - Ok(lines) => { - let first_line_no = self.offset as usize + 1; - let last_line_no = first_line_no + lines.len().saturating_sub(1); - let width = last_line_no.max(1).ilog10() as usize + 1; - - let numbered: String = lines - .iter() - .enumerate() - .map(|(i, line)| format!("{:>width$}\t{line}", first_line_no + i)) - .collect::>() - .join("\n"); - - if numbered.len() > 100_000 { - ToolOutcome::Error(format!( - "Error: file is too large to read ({} bytes in {} lines); use view_range to read a subset of the file", - numbered.len(), - lines.len() - )) - } else { - ToolOutcome::Success(numbered) - } - } - Err(e) => ToolOutcome::Error(format!("Error reading file: {e}")), - } - } -} - -impl PermissibleToolCall for ReadToolCall { - fn target_dir(&self) -> Option<&Path> { - Some(&self.path) - } - - fn matches_rule(&self, rule: &Rule) -> bool { - // Write implies Read — a Write permission on a path also permits reading it. - if rule.tool != "Read" && rule.tool != "Write" { - return false; - } - - match rule.scope.as_deref() { - None | Some("*") => true, - Some(scope) => path_matches_scope(&self.path, scope), - } - } -} - -#[derive(Debug, Clone)] -pub(crate) struct EditToolCall { - pub path: PathBuf, - pub old_string: String, - pub new_string: String, - pub replace_all: bool, -} - -impl TryFrom<&serde_json::Value> for EditToolCall { - type Error = eyre::Error; - - fn try_from(value: &serde_json::Value) -> Result { - let path = value - .get("file_path") - .and_then(|v| v.as_str()) - .ok_or(eyre::eyre!("Missing file_path"))?; - - let old_string = value - .get("old_string") - .and_then(|v| v.as_str()) - .ok_or(eyre::eyre!("Missing old_string"))?; - - let new_string = value - .get("new_string") - .and_then(|v| v.as_str()) - .ok_or(eyre::eyre!("Missing new_string"))?; - - let replace_all = value - .get("replace_all") - .and_then(|v| v.as_bool()) - .unwrap_or(false); - - Ok(EditToolCall { - path: expand_path(path), - old_string: old_string.to_string(), - new_string: new_string.to_string(), - replace_all, - }) - } -} - -impl EditToolCall { - /// Resolve the edit path to an absolute path. - pub fn resolved_path(&self) -> PathBuf { - if self.path.is_relative() { - std::env::current_dir() - .map(|cwd| cwd.join(&self.path)) - .unwrap_or_else(|_| self.path.clone()) - } else { - self.path.clone() - } - } - - /// Execute the edit against the filesystem. - /// - /// Checks freshness via the provided tracker, validates matches, applies - /// the replacement, and writes atomically. Returns the outcome and (on - /// success) the new file content bytes for tracker updates. - /// - /// Callers should snapshot the file before calling this method and - /// update the file tracker after a successful return. - pub fn execute( - &self, - resolved_path: &Path, - file_tracker: &crate::file_tracker::FileReadTracker, - ) -> (ToolOutcome, Option>) { - use crate::file_tracker::FreshnessCheck; - - // 1. Basic validation - if !resolved_path.exists() { - return ( - ToolOutcome::Error(format!( - "Error: file does not exist: {}", - resolved_path.display() - )), - None, - ); - } - if resolved_path.is_dir() { - return ( - ToolOutcome::Error(format!( - "Error: path is a directory, not a file: {}", - resolved_path.display() - )), - None, - ); - } - if self.old_string.is_empty() { - return ( - ToolOutcome::Error( - "old_string must not be empty. To create a new file, use write_file instead." - .to_string(), - ), - None, - ); - } - - // 2. Freshness check - match file_tracker.check_freshness(resolved_path) { - Ok(FreshnessCheck::NotRead) => { - return ( - ToolOutcome::Error( - "File has not been read yet. Read it first before editing.".to_string(), - ), - None, - ); - } - Ok(FreshnessCheck::Stale) => { - return ( - ToolOutcome::Error( - "File has been modified since read, either by the user or by a linter. Read it again before attempting to edit it.".to_string(), - ), - None, - ); - } - Err(e) => { - return ( - ToolOutcome::Error(format!("Error checking file state: {e}")), - None, - ); - } - Ok(FreshnessCheck::Fresh) => {} - } - - // 3. Read current contents - let content = match std::fs::read_to_string(resolved_path) { - Ok(c) => c, - Err(e) => return (ToolOutcome::Error(format!("Error reading file: {e}")), None), - }; - - // 4. Find and validate matches - let match_count = content.matches(&self.old_string).count(); - - if match_count == 0 { - return ( - ToolOutcome::Error(format!( - "old_string not found in {}. Make sure it matches exactly, including whitespace and indentation.", - resolved_path.display() - )), - None, - ); - } - - if match_count > 1 && !self.replace_all { - return ( - ToolOutcome::Error(format!( - "Found {match_count} matches of old_string in {}, but replace_all is false. Either provide more context to make the match unique, or set replace_all to true.", - resolved_path.display() - )), - None, - ); - } - - // 5. Apply replacement - let new_content = if self.replace_all { - content.replace(&self.old_string, &self.new_string) - } else { - content.replacen(&self.old_string, &self.new_string, 1) - }; - - // 6. Write atomically - let new_bytes = new_content.into_bytes(); - if let Err(e) = crate::snapshots::atomic_write_file(resolved_path, &new_bytes) { - return (ToolOutcome::Error(format!("Error writing file: {e}")), None); - } - - // 7. Success - let verb = if match_count == 1 { - "occurrence" - } else { - "occurrences" - }; - ( - ToolOutcome::Success(format!( - "Edited {}: replaced {match_count} {verb} of old_string with new_string.", - resolved_path.display() - )), - Some(new_bytes), - ) - } -} - -impl PermissibleToolCall for EditToolCall { - fn target_dir(&self) -> Option<&Path> { - Some(&self.path) - } - - fn matches_rule(&self, rule: &Rule) -> bool { - if rule.tool != "Write" { - return false; - } - - match rule.scope.as_deref() { - None | Some("*") => true, - Some(scope) => path_matches_scope(&self.path, scope), - } - } -} - -#[derive(Debug, Clone)] -pub(crate) struct WriteToolCall { - pub path: PathBuf, - pub content: String, - pub overwrite: bool, -} - -impl TryFrom<&serde_json::Value> for WriteToolCall { - type Error = eyre::Error; - - fn try_from(value: &serde_json::Value) -> Result { - let path = value - .get("file_path") - .and_then(|v| v.as_str()) - .ok_or(eyre::eyre!("Missing file_path"))?; - - let content = value - .get("content") - .and_then(|v| v.as_str()) - .ok_or(eyre::eyre!("Missing content"))?; - - let overwrite = value - .get("overwrite") - .and_then(|v| v.as_bool()) - .unwrap_or(false); - - Ok(WriteToolCall { - path: expand_path(path), - content: content.to_string(), - overwrite, - }) - } -} - -impl WriteToolCall { - /// Resolve the write path to an absolute path. - pub fn resolved_path(&self) -> PathBuf { - if self.path.is_relative() { - std::env::current_dir() - .map(|cwd| cwd.join(&self.path)) - .unwrap_or_else(|_| self.path.clone()) - } else { - self.path.clone() - } - } - - /// Execute the write operation. - /// - /// Creates a new file or overwrites an existing one (if `overwrite` is set). - /// Returns the outcome and the written bytes (for tracker updates). - pub fn execute(&self, resolved_path: &Path) -> (ToolOutcome, Option>) { - if resolved_path.is_dir() { - return ( - ToolOutcome::Error(format!( - "Error: path is a directory, not a file: {}", - resolved_path.display() - )), - None, - ); - } - if resolved_path.exists() && !self.overwrite { - return ( - ToolOutcome::Error(format!( - "File already exists: {}. Set overwrite to true to replace it, or use edit_file to make targeted changes.", - resolved_path.display() - )), - None, - ); - } - - // Capture before the write — after atomic_write the file always exists. - let existed = resolved_path.exists(); - - // Write atomically - let content_bytes = self.content.as_bytes().to_vec(); - if let Err(e) = crate::snapshots::atomic_write_file(resolved_path, &content_bytes) { - return (ToolOutcome::Error(format!("Error writing file: {e}")), None); - } - - let line_count = self.content.lines().count(); - let verb = if existed { "Overwrote" } else { "Created" }; - ( - ToolOutcome::Success(format!( - "{verb} {} ({line_count} lines).", - resolved_path.display() - )), - Some(content_bytes), - ) - } -} - -impl PermissibleToolCall for WriteToolCall { - fn target_dir(&self) -> Option<&Path> { - Some(&self.path) - } - - fn matches_rule(&self, rule: &Rule) -> bool { - if rule.tool != "Write" { - return false; - } - - match rule.scope.as_deref() { - None | Some("*") => true, - Some(scope) => path_matches_scope(&self.path, scope), - } - } -} - -#[derive(Debug, Clone)] -pub(crate) struct ShellToolCall { - pub dir: Option, - pub command: String, - pub shell: String, - /// Maximum execution time in seconds (from LLM). Clamped to 1..=600, default 30. - pub timeout_secs: u64, - // allow dead code here; this will be tied into o11y and user-facing descriptions - #[expect(dead_code)] - pub description: Option, -} - -impl TryFrom<&serde_json::Value> for ShellToolCall { - type Error = eyre::Error; - - fn try_from(value: &serde_json::Value) -> Result { - let dir = value.get("dir").and_then(|v| v.as_str()); - - let command = value - .get("command") - .and_then(|v| v.as_str()) - .ok_or(eyre::eyre!("Missing command"))?; - - let shell = value - .get("shell") - .and_then(|v| v.as_str()) - .unwrap_or("bash") - .to_string(); - - let timeout_secs = value - .get("timeout") - .and_then(|v| v.as_u64()) - .filter(|&v| v > 0) - .unwrap_or(30) - .min(600); - - let description = value - .get("description") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - - Ok(ShellToolCall { - dir: dir.map(expand_path), - command: command.to_string(), - shell, - timeout_secs, - description, - }) - } -} - -impl PermissibleToolCall for ShellToolCall { - fn target_dir(&self) -> Option<&Path> { - self.dir.as_deref() - } - - fn matches_rule(&self, rule: &Rule) -> bool { - if rule.tool != "Shell" { - return false; - } - - let Some(scope) = rule.scope.as_ref() else { - // Shell without scope matches all shell commands - return true; - }; - - let shell_kind = crate::permissions::shell::ShellKind::from_shell_name(&self.shell); - let parsed = crate::permissions::shell::parse_shell_command(&self.command, shell_kind); - // Deny/ask path: prefix_bare = true so `deny = ["Shell(rm)"]` blocks `rm -rf /` - crate::permissions::shell::any_subcommand_matches(&parsed.subcommands, true, scope) - } - - /// For compound shell commands, every subcommand must be individually - /// covered by at least one rule. This ensures that `allow = ["Shell(git *)"]` - /// does not silently permit `git add . && rm -rf /`. - fn all_covered_by(&self, rules: &[Rule]) -> bool { - use crate::permissions::shell; - - let shell_kind = shell::ShellKind::from_shell_name(&self.shell); - let parsed = shell::parse_shell_command(&self.command, shell_kind); - - // If parsing yields nothing, don't vacuously allow — fall through to ask. - !parsed.subcommands.is_empty() - && parsed.subcommands.iter().all(|subcmd| { - rules.iter().any(|rule| { - if rule.tool != "Shell" { - return false; - } - match rule.scope.as_deref() { - None | Some("*") => true, - // Allow path: prefix_bare = false so `Shell(git commit)` - // only allows exactly `git commit`, not `git commit --amend` - Some(scope) => shell::any_subcommand_matches( - std::slice::from_ref(subcmd), - false, - scope, - ), - } - }) - }) - } -} - -/// Preview viewport height for VT100 emulation. -const PREVIEW_HEIGHT: u16 = 10; - -/// Default terminal width for VT100 emulation. -const PREVIEW_WIDTH: u16 = 120; - -/// Normalize newlines for VT100 processing. -/// -/// When subprocess output is captured via pipes (no PTY), bare `\n` (LF) bytes -/// are not translated to `\r\n` (CR+LF) the way a kernel terminal driver would -/// with the `ONLCR` flag. In VT100, LF only moves the cursor down without -/// returning to column 0. This causes lines to start at progressively higher -/// column offsets and eventually wrap, producing garbled output. -/// -/// This function inserts `\r` before any `\n` that isn't already preceded by -/// `\r`, mimicking the terminal driver's ONLCR behavior. -fn normalize_newlines_for_vt100(data: &[u8]) -> Vec { - let mut out = Vec::with_capacity(data.len() + data.len() / 8); - for (i, &b) in data.iter().enumerate() { - if b == b'\n' && (i == 0 || data[i - 1] != b'\r') { - out.push(b'\r'); - } - out.push(b); - } - out -} - -/// Extract plain text lines from a VT100 screen buffer. -/// -/// Strips trailing blank lines so the result only contains rows with actual -/// content. Without this, the fixed-size VT100 screen (PREVIEW_HEIGHT rows) -/// would always return that many lines, and downstream components that use -/// tail-mode display (like the Viewport) would show the blank padding rows -/// instead of the real output. -fn vt100_screen_lines(screen: &vt100::Screen) -> Vec { - let (rows, cols) = screen.size(); - let mut lines = Vec::with_capacity(rows as usize); - for row in 0..rows { - let mut line = String::with_capacity(cols as usize); - for col in 0..cols { - if let Some(cell) = screen.cell(row, col) { - line.push_str(cell.contents()); - } - } - lines.push(line.trim_end().to_string()); - } - while lines.last().is_some_and(|l| l.is_empty()) { - lines.pop(); - } - lines -} - -/// Strip ANSI escape sequences from raw bytes using a VT100 parser. -/// -/// Uses a large virtual screen so scrollback is preserved, then extracts -/// the plain text contents. This handles all escape sequences (colors, -/// cursor movement, progress bars, etc.) not just simple SGR codes. -fn strip_ansi_via_vt100(raw: &[u8]) -> String { - if raw.is_empty() { - return String::new(); - } - // Normalize bare LF to CR+LF so lines start at column 0 in the VT100 screen. - let normalized = normalize_newlines_for_vt100(raw); - // Feed bytes into a VT100 parser large enough to hold all output, then - // read back the plain text. We estimate rows from the number of newlines - // (not total byte length) because real output typically has short lines - // that would be severely under-counted by a bytes÷width estimate. - let newline_count = normalized.iter().filter(|&&b| b == b'\n').count(); - let wrap_estimate = normalized.len() / PREVIEW_WIDTH as usize; - let estimated_rows = (newline_count + wrap_estimate + 1).min(10_000) as u16; - let mut parser = vt100::Parser::new(estimated_rows, PREVIEW_WIDTH, 0); - parser.process(&normalized); - let screen = parser.screen(); - // screen.contents() returns the full plain-text content with trailing - // whitespace trimmed per line and trailing blank lines removed. - screen.contents() -} - -/// Execute a shell command with VT100 emulation and streaming output. -/// -/// Feeds stdout+stderr into a `vt100::Parser` so that ANSI escape sequences, -/// progress bars (`\r`), and cursor movement are handled correctly. Periodically -/// sends the current screen state as `Vec` through `output_tx` for the -/// live preview. -/// -/// Captures the FULL stdout and stderr separately for the tool result sent to the LLM. -/// Returns a `ToolOutcome::Structured` with full output, exit code, and duration. -pub(crate) async fn execute_shell_command_streaming( - shell_call: &ShellToolCall, - output_tx: tokio::sync::mpsc::Sender>, - mut interrupt_rx: tokio::sync::oneshot::Receiver<()>, -) -> ToolOutcome { - use tokio::io::AsyncReadExt; - - let start = std::time::Instant::now(); - - // TODO: check if this is proper for all shells we support - let mut cmd = tokio::process::Command::new(&shell_call.shell); - cmd.arg("-c").arg(&shell_call.command); - cmd.stdout(std::process::Stdio::piped()); - cmd.stderr(std::process::Stdio::piped()); - - if let Some(ref dir) = shell_call.dir { - cmd.current_dir(dir); - } - - let mut child = match cmd.spawn() { - Ok(child) => child, - Err(e) => return ToolOutcome::Error(format!("Failed to spawn command: {e}")), - }; - - let stdout = child.stdout.take().expect("stdout was piped"); - let stderr = child.stderr.take().expect("stderr was piped"); - - // VT100 emulator for the live preview (viewport-sized) - let mut parser = vt100::Parser::new(PREVIEW_HEIGHT, PREVIEW_WIDTH, 0); - - let mut stdout_reader = tokio::io::BufReader::new(stdout); - let mut stderr_reader = tokio::io::BufReader::new(stderr); - - let mut stdout_buf = [0u8; 4096]; - let mut stderr_buf = [0u8; 4096]; - let mut stdout_done = false; - let mut stderr_done = false; - - // Full output buffers (for the LLM, not the preview) - let mut full_stdout = Vec::::new(); - let mut full_stderr = Vec::::new(); - - let mut interval = tokio::time::interval(Duration::from_millis(50)); - - // Send initial empty screen - let initial_lines = vt100_screen_lines(parser.screen()); - let _ = output_tx.send(initial_lines).await; - - let mut interrupted = false; - - loop { - tokio::select! { - biased; - - // Check for interrupt signal - _ = &mut interrupt_rx, if !interrupted => { - interrupted = true; - let _ = child.start_kill(); - } - - // Read stdout - result = stdout_reader.read(&mut stdout_buf), if !stdout_done => { - match result { - Ok(0) => stdout_done = true, - Ok(n) => { - full_stdout.extend_from_slice(&stdout_buf[..n]); - let normalized = normalize_newlines_for_vt100(&stdout_buf[..n]); - parser.process(&normalized); - } - Err(_) => stdout_done = true, - } - } - - // Read stderr - result = stderr_reader.read(&mut stderr_buf), if !stderr_done => { - match result { - Ok(0) => stderr_done = true, - Ok(n) => { - full_stderr.extend_from_slice(&stderr_buf[..n]); - // Feed stderr to the preview parser too, so it shows in the VT100 screen - let normalized = normalize_newlines_for_vt100(&stderr_buf[..n]); - parser.process(&normalized); - } - Err(_) => stderr_done = true, - } - } - - // Periodic screen snapshot for preview - _ = interval.tick() => { - let lines = vt100_screen_lines(parser.screen()); - let _ = output_tx.send(lines).await; - } - } - - // Exit when both streams are done - if stdout_done && stderr_done { - break; - } - } - - // Wait for process to finish - let exit_code = match child.wait().await { - Ok(status) => status.code(), - Err(e) => { - if interrupted { - None - } else { - return ToolOutcome::Error(format!("Failed to wait for command: {e}")); - } - } - }; - - let duration = start.elapsed(); - - // Send final screen state - let final_lines = vt100_screen_lines(parser.screen()); - let _ = output_tx.send(final_lines).await; - - // Strip ANSI escape sequences for clean LLM output by running - // the raw bytes through a VT100 parser and extracting plain text. - let stdout_text = strip_ansi_via_vt100(&full_stdout); - let stderr_text = strip_ansi_via_vt100(&full_stderr); - - ToolOutcome::Structured { - stdout: stdout_text, - stderr: stderr_text, - exit_code, - duration_ms: duration.as_millis() as u64, - interrupted, - } -} - -#[derive(Debug, Clone)] -pub(crate) struct AtuinHistoryToolCall { - pub filter_modes: Vec, - pub query: String, - pub limit: i64, -} - -#[derive(Debug, Clone)] -pub(crate) enum HistorySearchFilterMode { - Global, - Host, - Session, - Directory, - Workspace, -} - -impl From<&HistorySearchFilterMode> for atuin_client::settings::FilterMode { - fn from(mode: &HistorySearchFilterMode) -> Self { - match mode { - HistorySearchFilterMode::Global => Self::Global, - HistorySearchFilterMode::Host => Self::Host, - HistorySearchFilterMode::Session => Self::Session, - HistorySearchFilterMode::Directory => Self::Directory, - HistorySearchFilterMode::Workspace => Self::Workspace, - } - } -} - -impl TryFrom<&serde_json::Value> for AtuinHistoryToolCall { - type Error = eyre::Error; - - fn try_from(value: &serde_json::Value) -> Result { - let filter_modes = value - .get("filter_modes") - .and_then(|v| v.as_array()) - .ok_or(eyre::eyre!("Missing filter_modes"))?; - - let filter_modes = filter_modes - .iter() - .map(|v| { - let mode = v.as_str().ok_or(eyre::eyre!("Invalid filter mode"))?; - match mode { - "global" => Ok(HistorySearchFilterMode::Global), - "host" => Ok(HistorySearchFilterMode::Host), - "session" => Ok(HistorySearchFilterMode::Session), - "directory" => Ok(HistorySearchFilterMode::Directory), - "workspace" => Ok(HistorySearchFilterMode::Workspace), - _ => Err(eyre::eyre!("Invalid filter mode: {mode}")), - } - }) - .collect::>>()?; - - let query = value - .get("query") - .and_then(|v| v.as_str()) - .ok_or(eyre::eyre!("Missing query"))?; - - let limit = value - .get("limit") - .and_then(|v| v.as_i64()) - .unwrap_or(10) - .clamp(1, 50); - - Ok(AtuinHistoryToolCall { - filter_modes, - query: query.to_string(), - limit, - }) - } -} - -impl PermissibleToolCall for AtuinHistoryToolCall { - fn target_dir(&self) -> Option<&Path> { - None - } - - fn matches_rule(&self, rule: &Rule) -> bool { - rule.tool == "AtuinHistory" - } -} - -impl AtuinHistoryToolCall { - pub(crate) async fn execute(&self, db: &atuin_client::database::Sqlite) -> ToolOutcome { - use atuin_client::database::{self, Database as _, OptFilters}; - use atuin_client::settings::SearchMode; - - let context = match database::current_context().await { - Ok(ctx) => ctx, - Err(e) => return ToolOutcome::Error(format!("Failed to get history context: {e}")), - }; - - let filter_mode = self - .filter_modes - .first() - .map(atuin_client::settings::FilterMode::from) - .unwrap_or(atuin_client::settings::FilterMode::Global); - - let filter_options = OptFilters { - limit: Some(self.limit), - ..Default::default() - }; - - let results = match db - .search( - SearchMode::Fuzzy, - filter_mode, - &context, - &self.query, - filter_options, - ) - .await - { - Ok(results) => results, - Err(e) => return ToolOutcome::Error(format!("History search failed: {e}")), - }; - - if results.is_empty() { - return ToolOutcome::Success("No matching history entries found.".to_string()); - } - - let local_offset = crate::history_format::current_local_offset(); - - let formatted: Vec = results - .iter() - .enumerate() - .map(|(i, history)| { - crate::history_format::format_history_search_result(i + 1, history, local_offset) - }) - .collect(); - - ToolOutcome::Success(formatted.join("\n")) - } -} - -#[derive(Debug, Clone)] -pub(crate) struct AtuinOutputToolCall { - pub history_id: Uuid, - pub ranges: Vec<(i64, i64)>, -} - -impl TryFrom<&serde_json::Value> for AtuinOutputToolCall { - type Error = eyre::Error; - - fn try_from(value: &serde_json::Value) -> Result { - let history_id = value - .get("history_id") - .and_then(|v| v.as_str()) - .and_then(|v| Uuid::parse_str(v).ok()) - .ok_or(eyre::eyre!("Missing or invalid history ID"))?; - - let ranges = value - .get("ranges") - .and_then(|v| v.as_array()) - .map(Vec::as_slice) - .unwrap_or(&[]); - - let ranges = ranges - .iter() - .map(|r| { - let range = r - .as_array() - .filter(|a| a.len() == 2) - .ok_or_else(|| eyre::eyre!("Each range must be a [start, end] array"))?; - - let start = range[0] - .as_i64() - .ok_or_else(|| eyre::eyre!("Range start must be an integer"))?; - let end = range[1] - .as_i64() - .ok_or_else(|| eyre::eyre!("Range end must be an integer"))?; - - Ok((start, end)) - }) - .collect::, eyre::Error>>()?; - - Ok(Self { history_id, ranges }) - } -} - -impl PermissibleToolCall for AtuinOutputToolCall { - fn target_dir(&self) -> Option<&Path> { - None - } - - fn matches_rule(&self, rule: &Rule) -> bool { - rule.tool == "AtuinOutput" - } -} - -fn format_output_lines_for_llm(lines: &[atuin_daemon::semantic::OutputLine]) -> String { - let width = lines - .iter() - .map(|line| line.line_number) - .max() - .unwrap_or(1) - .max(1) - .ilog10() as usize - + 1; - let mut formatted = Vec::with_capacity(lines.len()); - let mut previous_line_number = None; - - for line in lines { - if let Some(previous) = previous_line_number { - let skipped = line.line_number.saturating_sub(previous + 1); - if skipped > 0 { - formatted.push(format!("[...skipped {skipped} lines...]")); - } - } - - formatted.push(format!("{:>width$}\t{}", line.line_number, line.content)); - previous_line_number = Some(line.line_number); - } - - formatted.join("\n") -} - -impl AtuinOutputToolCall { - pub(crate) async fn execute(&self) -> ToolOutcome { - let settings = match atuin_client::settings::Settings::new() { - Ok(settings) => settings, - Err(e) => return ToolOutcome::Error(format!("Failed to load Atuin settings: {e}")), - }; - - let mut client = match atuin_daemon::SemanticClient::from_settings(&settings).await { - Ok(client) => client, - Err(e) => return ToolOutcome::Error(format!("Failed to connect to Atuin daemon: {e}")), - }; - - let history_id = self.history_id.as_simple().to_string(); - let response = match client - .command_output(history_id.clone(), self.ranges.clone()) - .await - { - Ok(response) => response, - Err(e) => return ToolOutcome::Error(format!("Failed to fetch command output: {e}")), - }; - - if !response.found { - return ToolOutcome::Success(format!( - "No captured output found for history ID {history_id}." - )); - } - - if response.total_lines == 0 { - return ToolOutcome::Success(format!( - "Captured output for history ID {history_id} is empty." - )); - } - - let output = format_output_lines_for_llm(&response.lines); - if output.is_empty() { - return ToolOutcome::Success(format!( - "No lines selected from captured output for history ID {history_id}." - )); - } - - let total_output = if response.output_truncated { - format!( - "{} bytes captured, {} bytes observed before truncation, {} lines", - response.total_bytes, response.output_observed_bytes, response.total_lines - ) - } else { - format!( - "{} bytes, {} lines", - response.total_bytes, response.total_lines - ) - }; - - ToolOutcome::Success(format!( - "History ID: {history_id}\nTotal output: {total_output}\nSelected output:\n{output}" - )) - } -} - -#[derive(Debug, Clone)] -pub(crate) struct LoadSkillToolCall { - pub name: String, -} - -impl TryFrom<&serde_json::Value> for LoadSkillToolCall { - type Error = eyre::Error; - - fn try_from(value: &serde_json::Value) -> Result { - let name = value - .get("name") - .and_then(|v| v.as_str()) - .ok_or(eyre::eyre!("Missing skill name"))?; - - Ok(LoadSkillToolCall { - name: name.to_string(), - }) - } -} - -impl PermissibleToolCall for LoadSkillToolCall { - fn target_dir(&self) -> Option<&Path> { - None - } - - fn matches_rule(&self, rule: &Rule) -> bool { - rule.tool == "LoadSkill" - } -} - -#[cfg(test)] -mod tests { - use super::*; - - fn read_rule(scope: Option<&str>) -> Rule { - Rule { - tool: "Read".to_string(), - scope: scope.map(String::from), - } - } - - fn write_rule(scope: Option<&str>) -> Rule { - Rule { - tool: "Write".to_string(), - scope: scope.map(String::from), - } - } - - fn read_tool(path: &str) -> ReadToolCall { - ReadToolCall { - path: expand_path(path), - offset: 0, - limit: 100, - } - } - - fn write_tool(path: &str) -> WriteToolCall { - WriteToolCall { - path: expand_path(path), - content: String::new(), - overwrite: false, - } - } - - // ── Cross-platform tests ── - - #[test] - fn atuin_output_ranges_are_optional() { - let input = serde_json::json!({ - "history_id": "018f0000000070008000000000000000" - }); - - let call = AtuinOutputToolCall::try_from(&input).unwrap(); - - assert_eq!( - call.history_id.as_simple().to_string(), - "018f0000000070008000000000000000" - ); - assert!(call.ranges.is_empty()); - } - - #[test] - fn atuin_output_parses_line_ranges() { - let input = serde_json::json!({ - "history_id": "018f0000000070008000000000000000", - "ranges": [[0, 30], [-100, -1]] - }); - - let call = AtuinOutputToolCall::try_from(&input).unwrap(); - - assert_eq!(call.ranges, vec![(0, 30), (-100, -1)]); - } - - #[test] - fn atuin_output_formats_lines_like_read_file() { - let lines = vec![ - atuin_daemon::semantic::OutputLine { - line_number: 98, - content: "near end".to_string(), - }, - atuin_daemon::semantic::OutputLine { - line_number: 100, - content: "end".to_string(), - }, - ]; - - assert_eq!( - format_output_lines_for_llm(&lines), - " 98\tnear end\n[...skipped 1 lines...]\n100\tend" - ); - } - - #[test] - fn no_scope_matches_everything() { - assert!(read_tool("any/path.txt").matches_rule(&read_rule(None))); - assert!(write_tool("any/path.txt").matches_rule(&write_rule(None))); - } - - #[test] - fn wildcard_star_matches_everything() { - assert!(read_tool("foo/bar.rs").matches_rule(&read_rule(Some("*")))); - } - - #[test] - fn write_implies_read() { - // A Write rule also permits reads on the same path - assert!(read_tool("foo.txt").matches_rule(&write_rule(None))); - // But a Read rule does not permit writes - assert!(!write_tool("foo.txt").matches_rule(&read_rule(None))); - } - - #[test] - fn edit_uses_write_rule() { - let edit = EditToolCall { - path: expand_path("/home/user/config.toml"), - old_string: "x".into(), - new_string: "y".into(), - replace_all: false, - }; - assert!(edit.matches_rule(&write_rule(None))); - assert!(!edit.matches_rule(&read_rule(None))); - } - - #[test] - fn extension_glob() { - assert!(read_tool("notes.md").matches_rule(&read_rule(Some("*.md")))); - assert!(!read_tool("notes.txt").matches_rule(&read_rule(Some("*.md")))); - } - - #[test] - fn relative_multi_segment_glob() { - // This matches against the path relative to cwd - let cwd = std::env::current_dir().unwrap(); - let abs = cwd - .join("crates") - .join("atuin-ai") - .join("src") - .join("lib.rs"); - let tool = read_tool(abs.to_str().unwrap()); - assert!(tool.matches_rule(&read_rule(Some("crates/**/*.rs")))); - assert!(!tool.matches_rule(&read_rule(Some("crates/**/*.py")))); - } - - // ── all_covered_by tests (compound shell command semantics) ── - - fn shell_rule(scope: Option<&str>) -> Rule { - Rule { - tool: "Shell".to_string(), - scope: scope.map(String::from), - } - } - - fn shell_tool(command: &str) -> ShellToolCall { - ShellToolCall { - dir: None, - command: command.to_string(), - shell: "bash".to_string(), - timeout_secs: 30, - description: None, - } - } - - #[test] - fn all_covered_by_simple_command() { - let rules = vec![shell_rule(Some("git *"))]; - assert!(shell_tool("git add .").all_covered_by(&rules)); - assert!(!shell_tool("npm test").all_covered_by(&rules)); - } - - #[test] - fn all_covered_by_compound_all_covered() { - let rules = vec![shell_rule(Some("git *")), shell_rule(Some("npm *"))]; - assert!(shell_tool("git add . && npm test").all_covered_by(&rules)); - } - - #[test] - fn all_covered_by_compound_partially_covered() { - // Only git is allowed — npm subcommand is not covered, so the - // compound command must not be auto-allowed. - let rules = vec![shell_rule(Some("git *"))]; - assert!(!shell_tool("git add . && npm test").all_covered_by(&rules)); - } - - #[test] - fn all_covered_by_unscoped_shell_rule() { - // Shell without scope covers everything - let rules = vec![shell_rule(None)]; - assert!(shell_tool("git add . && rm -rf /").all_covered_by(&rules)); - } - - #[test] - fn all_covered_by_wildcard_shell_rule() { - let rules = vec![shell_rule(Some("*"))]; - assert!(shell_tool("git add . && npm test").all_covered_by(&rules)); - } - - #[test] - fn all_covered_by_non_shell_tool_unchanged() { - // Non-shell tools use the default (any single rule matches) - let rules = vec![read_rule(Some("*.md"))]; - assert!(read_tool("notes.md").all_covered_by(&rules)); - assert!(!read_tool("notes.txt").all_covered_by(&rules)); - } - - #[test] - fn matches_rule_still_uses_any_semantics() { - // matches_rule (used for deny/ask) still triggers on any subcommand - let rule = shell_rule(Some("rm *")); - assert!(shell_tool("git add . && rm -rf /").matches_rule(&rule)); - } - - #[test] - fn bare_pattern_asymmetry() { - // Deny (matches_rule, prefix_bare=true): bare "rm" blocks "rm -rf /" - let deny_rule = shell_rule(Some("rm")); - assert!(shell_tool("rm -rf /").matches_rule(&deny_rule)); - - // Allow (all_covered_by, prefix_bare=false): bare "rm" only allows exactly "rm" - let allow_rules = vec![shell_rule(Some("rm"))]; - assert!(shell_tool("rm").all_covered_by(&allow_rules)); - assert!(!shell_tool("rm -rf /").all_covered_by(&allow_rules)); - - // Bare prefix match is word-boundary, not substring — "rm" must not match "rmbackup" - assert!(!shell_tool("rmbackup").matches_rule(&deny_rule)); - assert!(!shell_tool("rmbackup /tmp").matches_rule(&deny_rule)); - } - - // ── Unix-specific tests (absolute paths with forward slashes) ── - - #[cfg(unix)] - mod unix { - use super::*; - - #[test] - fn absolute_glob() { - assert!( - read_tool("/home/user/src/main.rs") - .matches_rule(&read_rule(Some("/home/user/src/*.rs"))) - ); - assert!( - !read_tool("/home/user/docs/readme.md") - .matches_rule(&read_rule(Some("/home/user/src/*.rs"))) - ); - } - - #[test] - fn double_star_glob() { - assert!( - read_tool("/project/crates/foo/src/lib.rs") - .matches_rule(&read_rule(Some("/project/crates/**/*.rs"))) - ); - assert!( - !read_tool("/project/crates/foo/src/lib.py") - .matches_rule(&read_rule(Some("/project/crates/**/*.rs"))) - ); - } - } - - // ── edit_file execution tests ── - - mod edit { - use super::*; - use crate::file_tracker::FileReadTracker; - - /// Helper: create a temp file (with a closed handle), record it in a tracker. - /// Returns the TempDir (keeps the path alive) and tracker. - /// The file handle is closed so atomic_write_file can rename over it on Windows. - fn setup_tracked_file(content: &str) -> (tempfile::TempDir, PathBuf, FileReadTracker) { - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("test_file.toml"); - std::fs::write(&path, content).unwrap(); - - let file_content = std::fs::read(&path).unwrap(); - let mtime = std::fs::metadata(&path).unwrap().modified().unwrap(); - - let mut tracker = FileReadTracker::default(); - tracker.record_read(path.clone(), &file_content, mtime); - - (dir, path, tracker) - } - - fn edit_call(path: &Path, old: &str, new: &str, replace_all: bool) -> EditToolCall { - EditToolCall { - path: path.to_path_buf(), - old_string: old.to_string(), - new_string: new.to_string(), - replace_all, - } - } - - #[test] - fn successful_single_replacement() { - let (_dir, path, tracker) = setup_tracked_file("[section]\nkey = old_value\n"); - - let call = edit_call(&path, "old_value", "new_value", false); - let (outcome, new_bytes) = call.execute(&path, &tracker); - - assert!(matches!(outcome, ToolOutcome::Success(_))); - assert!(new_bytes.is_some()); - assert_eq!( - std::fs::read_to_string(&path).unwrap(), - "[section]\nkey = new_value\n" - ); - } - - #[test] - fn successful_replace_all() { - let (_dir, path, tracker) = setup_tracked_file("aaa bbb aaa ccc aaa"); - - let call = edit_call(&path, "aaa", "xxx", true); - let (outcome, _) = call.execute(&path, &tracker); - - assert!(matches!(outcome, ToolOutcome::Success(ref s) if s.contains("3 occurrences"))); - assert_eq!( - std::fs::read_to_string(&path).unwrap(), - "xxx bbb xxx ccc xxx" - ); - } - - #[test] - fn error_file_not_read() { - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("unread.txt"); - std::fs::write(&path, "content").unwrap(); - let tracker = FileReadTracker::default(); // empty — never read - - let call = edit_call(&path, "x", "y", false); - let (outcome, new_bytes) = call.execute(&path, &tracker); - - assert!(new_bytes.is_none()); - match outcome { - ToolOutcome::Error(msg) => { - assert!(msg.contains("not been read yet"), "got: {msg}"); - } - _ => panic!("expected error"), - } - } - - #[test] - fn error_file_modified_since_read() { - let (_dir, path, tracker) = setup_tracked_file("original"); - - // Modify the file after the read was recorded - std::thread::sleep(std::time::Duration::from_millis(10)); - std::fs::write(&path, "modified externally").unwrap(); - - let call = edit_call(&path, "original", "replaced", false); - let (outcome, new_bytes) = call.execute(&path, &tracker); - - assert!(new_bytes.is_none()); - match outcome { - ToolOutcome::Error(msg) => { - assert!(msg.contains("modified since read"), "got: {msg}"); - } - _ => panic!("expected error"), - } - } - - #[test] - fn error_no_match() { - let (_dir, path, tracker) = setup_tracked_file("hello world"); - - let call = edit_call(&path, "nonexistent", "replacement", false); - let (outcome, new_bytes) = call.execute(&path, &tracker); - - assert!(new_bytes.is_none()); - match outcome { - ToolOutcome::Error(msg) => { - assert!(msg.contains("not found"), "got: {msg}"); - } - _ => panic!("expected error"), - } - } - - #[test] - fn error_multiple_matches_without_replace_all() { - let (_dir, path, tracker) = setup_tracked_file("foo bar foo baz foo"); - - let call = edit_call(&path, "foo", "qux", false); - let (outcome, new_bytes) = call.execute(&path, &tracker); - - assert!(new_bytes.is_none()); - match outcome { - ToolOutcome::Error(msg) => { - assert!(msg.contains("3 matches"), "got: {msg}"); - assert!(msg.contains("replace_all"), "got: {msg}"); - } - _ => panic!("expected error"), - } - // File should be unchanged - assert_eq!( - std::fs::read_to_string(&path).unwrap(), - "foo bar foo baz foo" - ); - } - - #[test] - fn error_empty_old_string() { - let (_dir, path, tracker) = setup_tracked_file("content"); - - let call = edit_call(&path, "", "something", false); - let (outcome, new_bytes) = call.execute(&path, &tracker); - - assert!(new_bytes.is_none()); - assert!(matches!(outcome, ToolOutcome::Error(_))); - } - - #[test] - fn error_file_does_not_exist() { - let tracker = FileReadTracker::default(); - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("nonexistent.txt"); - - let call = edit_call(&path, "x", "y", false); - let (outcome, new_bytes) = call.execute(&path, &tracker); - - assert!(new_bytes.is_none()); - match outcome { - ToolOutcome::Error(msg) => { - assert!(msg.contains("does not exist"), "got: {msg}"); - } - _ => panic!("expected error"), - } - } - - #[test] - fn preserves_file_when_no_match() { - let original = "[config]\nport = 8080\nhost = localhost\n"; - let (_dir, path, tracker) = setup_tracked_file(original); - - let call = edit_call(&path, "port = 9090", "port = 3000", false); - let (outcome, _) = call.execute(&path, &tracker); - - assert!(matches!(outcome, ToolOutcome::Error(_))); - assert_eq!(std::fs::read_to_string(&path).unwrap(), original); - } - - #[test] - fn multiline_replacement() { - let content = "[section]\nkey1 = val1\nkey2 = val2\n[other]\n"; - let (_dir, path, tracker) = setup_tracked_file(content); - - let call = edit_call( - &path, - "key1 = val1\nkey2 = val2", - "key1 = new1\nkey2 = new2", - false, - ); - let (outcome, new_bytes) = call.execute(&path, &tracker); - - assert!(matches!(outcome, ToolOutcome::Success(_))); - assert!(new_bytes.is_some()); - assert_eq!( - std::fs::read_to_string(&path).unwrap(), - "[section]\nkey1 = new1\nkey2 = new2\n[other]\n" - ); - } - } - - // ── Integration tests: full edit lifecycle ── - // - // These exercise the cross-component flow that dispatch orchestrates: - // FileReadTracker → SnapshotStore → EditToolCall.execute → tracker update - - mod edit_integration { - use super::*; - use crate::edit_permissions::EditPermissionCache; - use crate::file_tracker::FileReadTracker; - use crate::snapshots::SnapshotStore; - - /// Simulate a file read (what dispatch does after ReadToolCall.execute). - fn simulate_read(tracker: &mut FileReadTracker, path: &std::path::Path) { - let content = std::fs::read(path).unwrap(); - let mtime = std::fs::metadata(path).unwrap().modified().unwrap(); - tracker.record_read(path.to_path_buf(), &content, mtime); - } - - /// Simulate a tracker update after edit (what dispatch does after execute). - fn simulate_tracker_update( - tracker: &mut FileReadTracker, - path: &std::path::Path, - new_bytes: &[u8], - ) { - let mtime = std::fs::metadata(path).unwrap().modified().unwrap(); - tracker.update_after_edit(path, new_bytes, mtime); - } - - #[test] - fn full_read_snapshot_edit_cycle() { - let dir = tempfile::tempdir().unwrap(); - let file_path = dir.path().join("config.toml"); - std::fs::write(&file_path, "[db]\nhost = localhost\nport = 5432\n").unwrap(); - - let snapshot_dir = dir.path().join("snapshots").join("session-1"); - let mut tracker = FileReadTracker::default(); - let mut store = SnapshotStore::open(snapshot_dir.clone()).unwrap(); - - // 1. Simulate reading the file - simulate_read(&mut tracker, &file_path); - - // 2. Snapshot before edit - let original = std::fs::read(&file_path).unwrap(); - store.ensure_snapshot(&file_path, &original).unwrap(); - - // 3. Execute edit - let call = EditToolCall { - path: file_path.clone(), - old_string: "host = localhost".to_string(), - new_string: "host = 10.0.0.1".to_string(), - replace_all: false, - }; - let (outcome, new_bytes) = call.execute(&file_path, &tracker); - assert!(matches!(outcome, ToolOutcome::Success(_))); - let new_bytes = new_bytes.unwrap(); - - // 4. Update tracker (simulating what dispatch does) - simulate_tracker_update(&mut tracker, &file_path, &new_bytes); - - // Verify: file was edited - assert_eq!( - std::fs::read_to_string(&file_path).unwrap(), - "[db]\nhost = 10.0.0.1\nport = 5432\n" - ); - - // Verify: snapshot has original content - assert!(store.has_snapshot(&file_path)); - let snapshot_name = crate::snapshots::sanitize_path(&file_path); - let snapshot_content = - std::fs::read_to_string(snapshot_dir.join(snapshot_name)).unwrap(); - assert_eq!(snapshot_content, "[db]\nhost = localhost\nport = 5432\n"); - } - - #[test] - fn second_edit_without_reread() { - let dir = tempfile::tempdir().unwrap(); - let file_path = dir.path().join("config.toml"); - std::fs::write(&file_path, "key1 = aaa\nkey2 = bbb\n").unwrap(); - - let mut tracker = FileReadTracker::default(); - - // Read the file - simulate_read(&mut tracker, &file_path); - - // First edit - let call1 = EditToolCall { - path: file_path.clone(), - old_string: "key1 = aaa".to_string(), - new_string: "key1 = xxx".to_string(), - replace_all: false, - }; - let (outcome, new_bytes) = call1.execute(&file_path, &tracker); - assert!(matches!(outcome, ToolOutcome::Success(_))); - simulate_tracker_update(&mut tracker, &file_path, &new_bytes.unwrap()); - - // Second edit — should work without re-reading because tracker was updated - let call2 = EditToolCall { - path: file_path.clone(), - old_string: "key2 = bbb".to_string(), - new_string: "key2 = yyy".to_string(), - replace_all: false, - }; - let (outcome, new_bytes) = call2.execute(&file_path, &tracker); - assert!(matches!(outcome, ToolOutcome::Success(_))); - assert!(new_bytes.is_some()); - assert_eq!( - std::fs::read_to_string(&file_path).unwrap(), - "key1 = xxx\nkey2 = yyy\n" - ); - } - - #[test] - fn external_modification_between_edits() { - let dir = tempfile::tempdir().unwrap(); - let file_path = dir.path().join("config.toml"); - std::fs::write(&file_path, "value = original\n").unwrap(); - - let mut tracker = FileReadTracker::default(); - simulate_read(&mut tracker, &file_path); - - // First edit succeeds - let call1 = EditToolCall { - path: file_path.clone(), - old_string: "value = original".to_string(), - new_string: "value = edited".to_string(), - replace_all: false, - }; - let (outcome, new_bytes) = call1.execute(&file_path, &tracker); - assert!(matches!(outcome, ToolOutcome::Success(_))); - simulate_tracker_update(&mut tracker, &file_path, &new_bytes.unwrap()); - - // External modification (e.g., user edits the file) - std::thread::sleep(std::time::Duration::from_millis(10)); - std::fs::write(&file_path, "value = user_changed\n").unwrap(); - - // Second edit should fail (stale) - let call2 = EditToolCall { - path: file_path.clone(), - old_string: "value = edited".to_string(), - new_string: "value = second_edit".to_string(), - replace_all: false, - }; - let (outcome, new_bytes) = call2.execute(&file_path, &tracker); - assert!(new_bytes.is_none()); - match outcome { - ToolOutcome::Error(msg) => assert!(msg.contains("modified since read")), - _ => panic!("expected stale error"), - } - - // File should be unchanged (the user's edit preserved) - assert_eq!( - std::fs::read_to_string(&file_path).unwrap(), - "value = user_changed\n" - ); - } - - #[test] - fn snapshot_only_created_once_per_file() { - let dir = tempfile::tempdir().unwrap(); - let file_path = dir.path().join("config.toml"); - std::fs::write(&file_path, "a = 1\nb = 2\n").unwrap(); - - let snapshot_dir = dir.path().join("snapshots").join("session-1"); - let mut tracker = FileReadTracker::default(); - let mut store = SnapshotStore::open(snapshot_dir).unwrap(); - - simulate_read(&mut tracker, &file_path); - - // First edit — snapshot should be created - let original = std::fs::read(&file_path).unwrap(); - let created = store.ensure_snapshot(&file_path, &original).unwrap(); - assert!(created); - - let call1 = EditToolCall { - path: file_path.clone(), - old_string: "a = 1".to_string(), - new_string: "a = 10".to_string(), - replace_all: false, - }; - let (_, new_bytes) = call1.execute(&file_path, &tracker); - simulate_tracker_update(&mut tracker, &file_path, &new_bytes.unwrap()); - - // Second edit — snapshot should NOT be recreated - let content_before_second = std::fs::read(&file_path).unwrap(); - let created = store - .ensure_snapshot(&file_path, &content_before_second) - .unwrap(); - assert!(!created); // idempotent — already snapshotted - } - - #[test] - fn permission_cache_grant_and_check() { - let mut cache = EditPermissionCache::default(); - let path = std::path::PathBuf::from("/Users/me/.config/atuin/config.toml"); - - // Initially no grant - assert!(!cache.has_valid_grant(&path)); - - // Grant permission - cache.grant(path.clone()); - assert!(cache.has_valid_grant(&path)); - - // Different file has no grant - assert!(!cache.has_valid_grant(std::path::Path::new("/other/file.toml"))); - - // Roundtrip through JSON (simulates session persistence) - let json = cache.to_json().unwrap(); - let restored = EditPermissionCache::from_json(&json).unwrap(); - assert!(restored.has_valid_grant(&path)); - } - } - - // ── write_file execution tests ── - - mod write { - use super::*; - - #[test] - fn creates_new_file() { - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("new_file.txt"); - - let call = WriteToolCall { - path: path.clone(), - content: "hello\nworld\n".to_string(), - overwrite: false, - }; - let (outcome, new_bytes) = call.execute(&path); - - assert!(matches!(outcome, ToolOutcome::Success(ref s) if s.contains("Created"))); - assert!(new_bytes.is_some()); - assert_eq!(std::fs::read_to_string(&path).unwrap(), "hello\nworld\n"); - } - - #[test] - fn error_file_exists_without_overwrite() { - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("existing.txt"); - std::fs::write(&path, "original").unwrap(); - - let call = WriteToolCall { - path: path.clone(), - content: "new content".to_string(), - overwrite: false, - }; - let (outcome, new_bytes) = call.execute(&path); - - assert!(new_bytes.is_none()); - match outcome { - ToolOutcome::Error(msg) => { - assert!(msg.contains("already exists"), "got: {msg}"); - assert!(msg.contains("overwrite"), "got: {msg}"); - } - _ => panic!("expected error"), - } - // Original preserved - assert_eq!(std::fs::read_to_string(&path).unwrap(), "original"); - } - - #[test] - fn overwrites_existing_file_when_flag_set() { - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("existing.txt"); - std::fs::write(&path, "original").unwrap(); - - let call = WriteToolCall { - path: path.clone(), - content: "replaced content\n".to_string(), - overwrite: true, - }; - let (outcome, new_bytes) = call.execute(&path); - - assert!(matches!(outcome, ToolOutcome::Success(_))); - assert!(new_bytes.is_some()); - assert_eq!( - std::fs::read_to_string(&path).unwrap(), - "replaced content\n" - ); - } - - #[test] - fn creates_parent_directories() { - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("sub").join("dir").join("file.txt"); - - let call = WriteToolCall { - path: path.clone(), - content: "nested\n".to_string(), - overwrite: false, - }; - let (outcome, _) = call.execute(&path); - - assert!(matches!(outcome, ToolOutcome::Success(_))); - assert_eq!(std::fs::read_to_string(&path).unwrap(), "nested\n"); - } - - #[test] - fn error_path_is_directory() { - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().to_path_buf(); - - let call = WriteToolCall { - path: path.clone(), - content: "content".to_string(), - overwrite: false, - }; - let (outcome, new_bytes) = call.execute(&path); - - assert!(new_bytes.is_none()); - assert!(matches!(outcome, ToolOutcome::Error(ref msg) if msg.contains("directory"))); - } - } - - // ── Windows-specific tests (absolute paths with drive letters) ── - - #[cfg(windows)] - mod windows { - use super::*; - - #[test] - fn absolute_glob() { - assert!( - read_tool(r"C:\Users\dev\src\main.rs") - .matches_rule(&read_rule(Some("C:/Users/dev/src/*.rs"))) - ); - assert!( - !read_tool(r"C:\Users\dev\docs\readme.md") - .matches_rule(&read_rule(Some("C:/Users/dev/src/*.rs"))) - ); - } - - #[test] - fn double_star_glob() { - assert!( - read_tool(r"C:\project\crates\foo\src\lib.rs") - .matches_rule(&read_rule(Some("C:/project/crates/**/*.rs"))) - ); - assert!( - !read_tool(r"C:\project\crates\foo\src\lib.py") - .matches_rule(&read_rule(Some("C:/project/crates/**/*.rs"))) - ); - } - } -} diff --git a/crates/atuin-ai/src/tui/components/atuin_ai.rs b/crates/atuin-ai/src/tui/components/atuin_ai.rs deleted file mode 100644 index 31dff1c3..00000000 --- a/crates/atuin-ai/src/tui/components/atuin_ai.rs +++ /dev/null @@ -1,143 +0,0 @@ -//! Top-level AtuinAi component that translates key events into AiTuiEvents. -//! -//! Global shortcuts (Ctrl+C, Esc) are handled in the capture phase so they -//! fire regardless of which child is focused. Contextual shortcuts (Enter, -//! Tab) are handled in the bubble phase so child components like the -//! permission Select can consume them first. - -use crossterm::event::{Event, KeyCode, KeyEvent, KeyEventKind, KeyModifiers}; -use eye_declare::{Elements, EventResult, Hooks, component, props}; - -use crate::commands::inline::DriverEventSender; -use crate::tui::events::AiTuiEvent; -use crate::tui::state::AppMode; - -/// Top-level wrapper component for the AI TUI. -/// -/// Props carry the current mode so `handle_event` can translate keys -/// into the right `AiTuiEvent`. Children are rendered via slot children. -#[props] -pub(crate) struct AtuinAi { - pub mode: AppMode, - pub has_command: bool, - pub is_input_blank: bool, - pub pending_confirmation: bool, - pub has_executing_preview: bool, -} - -#[derive(Default)] -pub(crate) struct AtuinAiState { - tx: Option, -} - -#[component(props = AtuinAi, state = AtuinAiState, children = Elements)] -fn atuin_ai( - _props: &AtuinAi, - _state: &AtuinAiState, - hooks: &mut Hooks, - children: Elements, -) -> Elements { - hooks.use_context::(|tx, _, state| { - state.tx = tx.cloned(); - }); - - // Capture phase: global shortcuts that must fire regardless of child focus. - hooks.use_event_capture(move |event, props, state| { - let Event::Key(KeyEvent { - code, - kind: KeyEventKind::Press, - modifiers, - .. - }) = event - else { - return EventResult::Ignored; - }; - - let Some(ref tx) = state.read().tx else { - return EventResult::Ignored; - }; - - // Ctrl+C — interrupt executing command or exit - if modifiers.contains(KeyModifiers::CONTROL) && *code == KeyCode::Char('c') { - if props.has_executing_preview { - let _ = tx.send(AiTuiEvent::InterruptToolExecution); - } else { - let _ = tx.send(AiTuiEvent::Exit); - } - return EventResult::Consumed; - } - - // Esc — always handled at the top level - if *code == KeyCode::Esc { - match props.mode { - AppMode::Input => { - if props.has_executing_preview { - let _ = tx.send(AiTuiEvent::InterruptToolExecution); - } else if props.pending_confirmation { - let _ = tx.send(AiTuiEvent::CancelConfirmation); - } else { - let _ = tx.send(AiTuiEvent::Exit); - } - } - AppMode::Generating | AppMode::Streaming => { - let _ = tx.send(AiTuiEvent::CancelGeneration); - } - AppMode::Error => { - let _ = tx.send(AiTuiEvent::Exit); - } - } - return EventResult::Consumed; - } - - if *code == KeyCode::Tab - && matches!(props.mode, AppMode::Input) - && modifiers.contains(KeyModifiers::NONE) - && props.has_command - && props.is_input_blank - { - let _ = tx.send(AiTuiEvent::InsertCommand); - return EventResult::Consumed; - } - - EventResult::Ignored - }); - - // Bubble phase: contextual shortcuts that children (e.g. Select) may handle first. - hooks.use_event(move |event, props, state| { - let Event::Key(KeyEvent { - code, - kind: KeyEventKind::Press, - .. - }) = event - else { - return EventResult::Ignored; - }; - - let Some(ref tx) = state.read().tx else { - return EventResult::Ignored; - }; - - match props.mode { - AppMode::Input => match code { - KeyCode::Enter => { - if props.has_command && props.is_input_blank { - let _ = tx.send(AiTuiEvent::ExecuteCommand); - return EventResult::Consumed; - } - EventResult::Ignored - } - _ => EventResult::Ignored, - }, - AppMode::Error => match code { - KeyCode::Enter | KeyCode::Char('r') => { - let _ = tx.send(AiTuiEvent::Retry); - EventResult::Consumed - } - _ => EventResult::Ignored, - }, - _ => EventResult::Ignored, - } - }); - - children -} diff --git a/crates/atuin-ai/src/tui/components/input_box.rs b/crates/atuin-ai/src/tui/components/input_box.rs deleted file mode 100644 index 6b81322c..00000000 --- a/crates/atuin-ai/src/tui/components/input_box.rs +++ /dev/null @@ -1,220 +0,0 @@ -//! Bordered input box component for the AI TUI. -//! -//! Wraps tui-textarea's TextArea, which handles rendering, wrapping, cursor -//! positioning, and height measurement natively. The component configures the -//! TextArea's block (border + titles) and forwards events to it. -//! -//! On Enter, sends `AiTuiEvent::SubmitInput` via the context-provided channel. - -use std::sync::{Arc, Mutex}; - -use crossterm::event::KeyModifiers; -use eye_declare::{Canvas, Elements, EventResult, Hooks, component, element, props}; -use ratatui::widgets::{Block, Borders, Padding}; -use ratatui_core::{ - layout::Rect, - style::{Color, Modifier, Style}, - text::Line, - widgets::Widget, -}; -use tui_textarea::TextArea; - -use crate::commands::inline::DriverEventSender; -use crate::tui::{events::AiTuiEvent, slash::SlashCommandSearchResult}; - -/// A bordered text input box backed by tui-textarea. -/// -/// Props configure the chrome (title, footer). The TextArea itself lives -/// in the component's State so it owns cursor, wrapping, and rendering. -#[props] -pub(crate) struct InputBox { - /// Title shown in top-left border - pub title: String, - /// Right-side label in top border - pub title_right: String, - /// Footer text shown in bottom border (keybinding hints) - pub footer: String, - /// Whether the input is currently active (shows cursor, accepts input) - pub active: bool, - /// If the user has typed a slash command, this holds the best match for it. - pub slash_suggestion: Option, -} - -pub(crate) struct InputBoxState { - textarea: Arc>>, - tx: Option, -} - -impl Default for InputBoxState { - fn default() -> Self { - let mut textarea = TextArea::default(); - textarea.set_cursor_line_style(ratatui::style::Style::default()); - textarea.set_wrap_mode(tui_textarea::WrapMode::Word); - textarea.set_placeholder_text("Type a message..."); - textarea.set_placeholder_style( - ratatui::style::Style::default() - .fg(ratatui::style::Color::DarkGray) - .add_modifier(ratatui::style::Modifier::ITALIC), - ); - Self { - textarea: Arc::new(Mutex::new(textarea)), - tx: None, - } - } -} - -fn make_block(props: &InputBox) -> Block<'static> { - let border_style = Style::default().fg(Color::DarkGray); - let title_style = Style::default() - .fg(Color::Gray) - .add_modifier(Modifier::BOLD); - - let mut block = Block::default() - .borders(Borders::ALL) - .border_style(border_style) - .padding(Padding::horizontal(1)); - - if !props.title.is_empty() { - block = - block.title_top(Line::styled(format!(" {} ", props.title), title_style).left_aligned()); - } - if !props.title_right.is_empty() { - block = block.title_top( - Line::styled(format!(" {} ", props.title_right), border_style).right_aligned(), - ); - } - if !props.footer.is_empty() { - block = block.title_bottom( - Line::styled(format!(" {} ", props.footer), border_style).right_aligned(), - ); - } - - block -} - -#[component(props = InputBox, state = InputBoxState)] -fn input_box( - props: &InputBox, - state: &InputBoxState, - hooks: &mut Hooks, -) -> Elements { - // Always focusable so focus isn't lost when the permission Select is - // removed from the tree. The `active` prop controls visual state and - // whether keystrokes are processed, not focusability. - hooks.use_focusable(true); - hooks.use_autofocus(); - - hooks.use_context::(|tx, _, state| { - state.tx = tx.cloned(); - }); - - hooks.use_event(move |event, props, state| { - let state = state.read(); - - if !props.active { - return EventResult::Ignored; - } - - if let crossterm::event::Event::Paste(text) = event { - let mut textarea = state.textarea.lock().unwrap(); - textarea.insert_str(text); - return EventResult::Consumed; - } - - if let crossterm::event::Event::Key(key) = event { - if key.kind != crossterm::event::KeyEventKind::Press { - return EventResult::Ignored; - } - - let mut textarea = state.textarea.lock().unwrap(); - - match key.code { - crossterm::event::KeyCode::Char('j') - if key.modifiers.contains(KeyModifiers::CONTROL) => - { - textarea.insert_newline(); - return EventResult::Consumed; - } - crossterm::event::KeyCode::Tab if props.slash_suggestion.is_some() => { - // If there's a slash command suggestion, Tab accepts it. - if let Some(suggestion) = &props.slash_suggestion { - textarea.clear(); - textarea.insert_str(format!("/{}", suggestion.command.name)); - // Manually trigger an input update event so the slash suggestion box can update immediately - if let Some(ref tx) = state.tx { - let _ = tx.send(AiTuiEvent::InputUpdated(textarea.lines().join("\n"))); - } - return EventResult::Consumed; - } - } - crossterm::event::KeyCode::Enter => { - if key.modifiers.contains(KeyModifiers::SHIFT) { - textarea.insert_newline(); - return EventResult::Consumed; - } else { - let text = textarea.lines().join("\n"); - if text.trim().is_empty() { - return EventResult::Ignored; - } - - textarea.clear(); - - if let Some(ref tx) = state.tx { - let _ = tx.send(AiTuiEvent::SubmitInput(text)); - } - return EventResult::Consumed; - } - } - _ => {} - } - - // All other keys: forward to textarea. - // tui-textarea can convert crossterm events itself. - textarea.input(*key); - - if let Some(ref tx) = state.tx { - let _ = tx.send(AiTuiEvent::InputUpdated(textarea.lines().join("\n"))); - } - return EventResult::Consumed; - } - - EventResult::Ignored - }); - - let textarea = state.textarea.clone(); - let block = make_block(props); - let active = props.active; - element!( - Canvas(render_fn: move |area, buf| { - let mut area = area; - - if area.height < 3 || area.width < 4 { - return; - } - - let height = { - // TextArea handles scrolling internally if content overflows. - let inner = block.inner(Rect::new(0, 0, area.width, u16::MAX)); - let chrome = (u16::MAX).saturating_sub(inner.height); - let content = textarea.lock().unwrap().measure(area.width - 4); - chrome + content.preferred_rows - }; - - area.height = height.min(7); - let inner = block.clone().inner(area); - block.clone().render(area, buf); - - let mut textarea = textarea.lock().unwrap(); - if active { - textarea.set_cursor_style(Style::default().add_modifier(Modifier::REVERSED)); - textarea.set_placeholder_text("Type a message..."); - } else { - textarea.set_cursor_style(Style::default()); - textarea.set_placeholder_text(""); - } - - // Render textarea into the inner area - textarea.render(inner, buf); - }) - ) -} diff --git a/crates/atuin-ai/src/tui/components/markdown.rs b/crates/atuin-ai/src/tui/components/markdown.rs deleted file mode 100644 index 607520b7..00000000 --- a/crates/atuin-ai/src/tui/components/markdown.rs +++ /dev/null @@ -1,210 +0,0 @@ -//! Markdown rendering component using pulldown-cmark. -//! -//! More robust than eye-declare's built-in Markdown component: -//! uses a proper CommonMark parser rather than line-by-line regex. - -use eye_declare::{Component, props}; -use pulldown_cmark::{Event, Parser, Tag, TagEnd}; -use ratatui_core::{ - buffer::Buffer, - layout::Rect, - style::{Color, Modifier, Style}, - text::{Line, Span, Text}, - widgets::Widget, -}; -use ratatui_widgets::paragraph::{Paragraph, Wrap}; - -/// A markdown rendering component backed by pulldown-cmark. -#[props] -pub(crate) struct Markdown { - pub source: String, -} - -/// Style configuration for markdown rendering. -pub(crate) struct MarkdownStyles { - pub base: Style, - pub code_inline: Style, - pub code_block: Style, - pub bold: Style, - pub italic: Style, - pub heading: Style, -} - -impl MarkdownStyles { - pub fn new() -> Self { - let base = Style::default(); - Self { - base, - code_inline: Style::default().fg(Color::Yellow), - code_block: Style::default().fg(Color::Green), - bold: base.add_modifier(Modifier::BOLD), - italic: base.add_modifier(Modifier::ITALIC), - heading: Style::default() - .fg(Color::Cyan) - .add_modifier(Modifier::BOLD), - } - } -} - -impl Default for MarkdownStyles { - fn default() -> Self { - Self::new() - } -} - -impl Component for Markdown { - type State = MarkdownStyles; - - fn render(&self, area: Rect, buf: &mut Buffer, state: &Self::State) { - if self.source.is_empty() || area.width == 0 || area.height == 0 { - return; - } - let text = parse_markdown(&self.source, state); - Paragraph::new(text) - .wrap(Wrap { trim: false }) - .render(area, buf); - } - - fn desired_height(&self, width: u16, state: &Self::State) -> Option { - if self.source.is_empty() || width == 0 { - return Some(0); - } - let text = parse_markdown(&self.source, state); - Some( - Paragraph::new(text) - .wrap(Wrap { trim: false }) - .line_count(width) as u16, - ) - } - - fn initial_state(&self) -> Option { - Some(MarkdownStyles::new()) - } -} - -/// Parse markdown source into styled ratatui Text using pulldown-cmark. -fn parse_markdown<'a>(source: &'a str, styles: &'a MarkdownStyles) -> Text<'static> { - let parser = Parser::new(source); - let mut lines: Vec>> = vec![Vec::new()]; - let mut current_line = 0; - - let mut style_stack: Vec