diff options
| author | Benedikt Peetz <benedikt.peetz@b-peetz.de> | 2026-06-10 22:01:45 +0200 |
|---|---|---|
| committer | Benedikt Peetz <benedikt.peetz@b-peetz.de> | 2026-06-10 22:01:45 +0200 |
| commit | 5e31a81cd2207f053b8cd8ad84ebe2a2f691b29d (patch) | |
| tree | 5d76811ab0d693c01fa472d41aa2ceaf3bd0b415 | |
| parent | chore: Remove unneeded files (diff) | |
| download | atuin-5e31a81cd2207f053b8cd8ad84ebe2a2f691b29d.zip | |
chore: Remove some unused rust code
103 files changed, 88 insertions, 22143 deletions
@@ -9,6 +9,8 @@ publish.sh .envrc .planning/ +.direnv + ui/backend/target ui/backend/gen @@ -144,40 +144,12 @@ dependencies = [ ] [[package]] -name = "arraydeque" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d902e3d592a523def97af8f317b08ce16b7ab854c1985a0c671e6f15cebc236" - -[[package]] name = "arrayvec" version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] -name = "async-stream" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" -dependencies = [ - "async-stream-impl", - "futures-core", - "pin-project-lite", -] - -[[package]] -name = "async-stream-impl" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.117", -] - -[[package]] name = "async-trait" version = "0.1.89" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -218,16 +190,13 @@ version = "18.16.1" dependencies = [ "arboard", "async-trait", - "atuin-ai", "atuin-client", "atuin-common", "atuin-daemon", - "atuin-dotfiles", "atuin-history", "atuin-kv", "atuin-nucleo-matcher", "atuin-pty-proxy", - "atuin-scripts", "atuin-server", "atuin-server-database", "atuin-server-postgres", @@ -269,60 +238,6 @@ dependencies = [ "tracing-tree", "unicode-width 0.2.2", "uuid", - "windows-sys 0.61.2", -] - -[[package]] -name = "atuin-ai" -version = "18.16.1" -dependencies = [ - "async-stream", - "async-trait", - "atuin-client", - "atuin-common", - "atuin-daemon", - "chrono", - "chrono-humanize", - "clap", - "crossterm", - "directories", - "eventsource-stream", - "eye_declare", - "eyre", - "fs-err", - "futures", - "glob-match", - "imara-diff", - "pretty_assertions", - "pulldown-cmark", - "ratatui", - "ratatui-core", - "ratatui-widgets", - "regex", - "reqwest", - "serde", - "serde_json", - "shellexpand", - "sqlx", - "tempfile", - "thiserror 2.0.18", - "time", - "tokio", - "toml", - "toml_edit", - "tracing", - "tracing-appender", - "tracing-subscriber", - "tree-sitter", - "tree-sitter-bash", - "tree-sitter-fish", - "tui-textarea-2", - "typed-builder 0.18.2", - "unicode-width 0.2.2", - "uuid", - "vt100", - "xxhash-rust", - "yaml-rust2", ] [[package]] @@ -374,7 +289,7 @@ dependencies = [ "time", "tiny-bip39", "tokio", - "typed-builder 0.18.2", + "typed-builder", "urlencoding", "uuid", "whoami 2.1.1", @@ -396,7 +311,7 @@ dependencies = [ "sysinfo", "thiserror 2.0.18", "time", - "typed-builder 0.18.2", + "typed-builder", "uuid", ] @@ -406,7 +321,6 @@ version = "18.16.1" dependencies = [ "atuin-client", "atuin-common", - "atuin-dotfiles", "atuin-history", "atuin-nucleo", "dashmap", @@ -434,20 +348,6 @@ dependencies = [ ] [[package]] -name = "atuin-dotfiles" -version = "18.16.1" -dependencies = [ - "atuin-client", - "atuin-common", - "crypto_secretbox", - "eyre", - "rand 0.8.5", - "rmp", - "serde", - "tokio", -] - -[[package]] name = "atuin-history" version = "18.16.1" dependencies = [ @@ -473,7 +373,7 @@ dependencies = [ "tokio", "tracing", "tracing-subscriber", - "typed-builder 0.18.2", + "typed-builder", ] [[package]] @@ -516,28 +416,6 @@ dependencies = [ ] [[package]] -name = "atuin-scripts" -version = "18.16.1" -dependencies = [ - "atuin-client", - "atuin-common", - "eyre", - "minijinja", - "pretty_assertions", - "rmp", - "serde", - "serde_json", - "sql-builder", - "sqlx", - "tempfile", - "tokio", - "tracing", - "tracing-subscriber", - "typed-builder 0.18.2", - "uuid", -] - -[[package]] name = "atuin-server" version = "18.16.1" dependencies = [ @@ -840,23 +718,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" dependencies = [ "iana-time-zone", - "js-sys", "num-traits", "serde", - "wasm-bindgen", "windows-link", ] [[package]] -name = "chrono-humanize" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "799627e6b4d27827a814e837b9d8a504832086081806d45b1afa34dc982b023b" -dependencies = [ - "chrono", -] - -[[package]] name = "cipher" version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1124,7 +991,6 @@ dependencies = [ "derive_more", "document-features", "filedescriptor", - "futures-core", "mio", "parking_lot", "rustix", @@ -1479,15 +1345,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" [[package]] -name = "encoding_rs" -version = "0.8.35" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" -dependencies = [ - "cfg-if", -] - -[[package]] name = "equivalent" version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1541,44 +1398,6 @@ dependencies = [ ] [[package]] -name = "eventsource-stream" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" -dependencies = [ - "futures-core", - "nom 7.1.3", - "pin-project-lite", -] - -[[package]] -name = "eye_declare" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e4caff9f5574315258489ecb2d27aad2cc057521f4b0bc6819a1cd82a648f40" -dependencies = [ - "crossterm", - "eye_declare_macros", - "futures", - "ratatui-core", - "ratatui-widgets", - "tokio", - "typed-builder 0.23.2", - "unicode-width 0.2.2", -] - -[[package]] -name = "eye_declare_macros" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb8ac27b19f79b61a8afc0a55a3148eb2dafc2aaec66c1103432e3e76e7c25af" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.117", -] - -[[package]] name = "eyre" version = "0.6.12" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1859,15 +1678,6 @@ dependencies = [ ] [[package]] -name = "getopts" -version = "0.2.24" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfe4fbac503b8d1f88e6676011885f34b7174f46e59956bba534ba83abded4df" -dependencies = [ - "unicode-width 0.2.2", -] - -[[package]] name = "getrandom" version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1904,12 +1714,6 @@ dependencies = [ ] [[package]] -name = "glob-match" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985c9503b412198aa4197559e9a318524ebc4519c229bfa05a535828c950b9d" - -[[package]] name = "h2" version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1976,15 +1780,6 @@ dependencies = [ ] [[package]] -name = "hashlink" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea0b22561a9c04a7cb1a302c013e0259cd3b4bb619f145b32f72b8b4bcbed230" -dependencies = [ - "hashbrown 0.16.1", -] - -[[package]] name = "heck" version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2288,16 +2083,6 @@ dependencies = [ ] [[package]] -name = "imara-diff" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f01d462f766df78ab820dd06f5eb700233c51f0f4c2e846520eaf4ba6aa5c5c" -dependencies = [ - "hashbrown 0.15.5", - "memchr", -] - -[[package]] name = "indenter" version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2747,12 +2532,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a64a92489e2744ce060c349162be1c5f33c6969234104dbd99ddb5feb08b8c15" [[package]] -name = "memo-map" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b" - -[[package]] name = "memoffset" version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2830,16 +2609,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" [[package]] -name = "minijinja" -version = "2.18.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "328251e58ad8e415be6198888fc207502727dc77945806421ab34f35bf012e7d" -dependencies = [ - "memo-map", - "serde", -] - -[[package]] name = "minimal-lexical" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -3639,19 +3408,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "83c41efbf8f90ac44de7f3a868f0867851d261b56291732d0cbf7cceaaeb55a6" dependencies = [ "bitflags 2.11.0", - "getopts", "memchr", - "pulldown-cmark-escape", "unicase", ] [[package]] -name = "pulldown-cmark-escape" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "007d8adb5ddab6f8e3f491ac63566a7d5002cc7ed73901f72057943fa71ae1ae" - -[[package]] name = "pulldown-cmark-to-cmark" version = "22.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -4362,7 +4123,6 @@ version = "1.0.149" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" dependencies = [ - "indexmap 2.13.0", "itoa", "memchr", "serde", @@ -4653,7 +4413,7 @@ dependencies = [ "futures-io", "futures-util", "hashbrown 0.15.5", - "hashlink 0.10.0", + "hashlink", "indexmap 2.13.0", "log", "memchr", @@ -4835,12 +4595,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" [[package]] -name = "streaming-iterator" -version = "0.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b2231b7c3057d5e4ad0156fb3dc807d900806020c5ffa3ee6ff2c8c76fb8520" - -[[package]] name = "stringprep" version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -5228,12 +4982,10 @@ version = "1.1.1+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "994b95d9e7bae62b34bab0e2a4510b801fa466066a6a8b2b57361fa1eba068ee" dependencies = [ - "indexmap 2.13.0", "serde_core", "serde_spanned", "toml_datetime", "toml_parser", - "toml_writer", "winnow 1.0.1", ] @@ -5503,46 +5255,6 @@ dependencies = [ ] [[package]] -name = "tree-sitter" -version = "0.26.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "887bd495d0582c5e3e0d8ece2233666169fa56a9644d172fc22ad179ab2d0538" -dependencies = [ - "cc", - "regex", - "regex-syntax", - "serde_json", - "streaming-iterator", - "tree-sitter-language", -] - -[[package]] -name = "tree-sitter-bash" -version = "0.25.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e5ec769279cc91b561d3df0d8a5deb26b0ad40d183127f409494d6d8fc53062" -dependencies = [ - "cc", - "tree-sitter-language", -] - -[[package]] -name = "tree-sitter-fish" -version = "3.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "014e3b299f251e9c2e372e3b5e1b0323ef21196e9aa2e90a5bc1f6130cbe8b18" -dependencies = [ - "cc", - "tree-sitter", -] - -[[package]] -name = "tree-sitter-language" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "009994f150cc0cd50ff54917d5bc8bffe8cad10ca10d81c34da2ec421ae61782" - -[[package]] name = "tree_magic_mini" version = "3.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -5560,35 +5272,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" [[package]] -name = "tui-textarea-2" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74a31ca0965e3ff6a7ac5ecb02b20a88b4f68ebf138d8ae438e8510b27a1f00f" -dependencies = [ - "crossterm", - "portable-atomic", - "ratatui-core", - "ratatui-widgets", - "unicode-segmentation", - "unicode-width 0.2.2", -] - -[[package]] name = "typed-builder" version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77739c880e00693faef3d65ea3aad725f196da38b22fdc7ea6ded6e1ce4d3add" dependencies = [ - "typed-builder-macro 0.18.2", -] - -[[package]] -name = "typed-builder" -version = "0.23.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31aa81521b70f94402501d848ccc0ecaa8f93c8eb6999eb9747e72287757ffda" -dependencies = [ - "typed-builder-macro 0.23.2", + "typed-builder-macro", ] [[package]] @@ -5603,17 +5292,6 @@ dependencies = [ ] [[package]] -name = "typed-builder-macro" -version = "0.23.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "076a02dc54dd46795c2e9c8282ed40bcfb1e22747e955de9389a1de28190fb26" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.117", -] - -[[package]] name = "typenum" version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -6751,23 +6429,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ea6fc2961e4ef194dcbfe56bb845534d0dc8098940c7e5c012a258bfec6701bd" [[package]] -name = "xxhash-rust" -version = "0.8.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3" - -[[package]] -name = "yaml-rust2" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "631a50d867fafb7093e709d75aaee9e0e0d5deb934021fcea25ac2fe09edc51e" -dependencies = [ - "arraydeque", - "encoding_rs", - "hashlink 0.11.0", -] - -[[package]] name = "yansi" version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -11,7 +11,7 @@ exclude = ["ui/backend", "crates/atuin-nucleo/matcher/fuzz"] [workspace.package] version = "18.16.1" authors = ["Ellie Huxtable <ellie@atuin.sh>"] -rust-version = "1.96.0" +rust-version = "1.95.0" license = "MIT" homepage = "https://atuin.sh" repository = "https://github.com/atuinsh/atuin" @@ -76,19 +76,9 @@ xxhash-rust = { version = "0.8", features = ["xxh3"] } vt100 = "0.16" regex = "1.10.5" toml_edit = "0.25.4" - -[workspace.dependencies.tracing-subscriber] -version = "0.3" -features = ["ansi", "fmt", "registry", "env-filter", "json"] - -[workspace.dependencies.reqwest] -version = "0.13" -features = ["json", "rustls-no-provider", "stream"] -default-features = false - -[workspace.dependencies.sqlx] -version = "0.8" -features = ["runtime-tokio-rustls", "time", "postgres", "uuid"] +tracing-subscriber = { version = "0.3", features = ["ansi", "fmt", "registry", "env-filter", "json"] } +reqwest = { version = "0.13", features = ["json", "rustls-no-provider", "stream"], default-features = false } +sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "time", "postgres", "uuid"] } # The profile that 'cargo dist' will build with [profile.dist] diff --git a/crates/atuin-ai/Cargo.toml b/crates/atuin-ai/Cargo.toml deleted file mode 100644 index 027bd490..00000000 --- a/crates/atuin-ai/Cargo.toml +++ /dev/null @@ -1,74 +0,0 @@ -[package] -name = "atuin-ai" -edition = "2024" -description = "AI integration for Atuin CLI" - -rust-version = { workspace = true } -version = { workspace = true } -authors = { workspace = true } -license = { workspace = true } -homepage = { workspace = true } -repository = { workspace = true } - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[features] -default = [] -daemon = [] -tree-sitter = ["dep:tree-sitter-lib", "dep:tree-sitter-bash", "dep:tree-sitter-fish"] - -[dependencies] -async-trait = { workspace = true } -atuin-client = { workspace = true } -atuin-common = { workspace = true } -atuin-daemon = { workspace = true } -tokio = { workspace = true } -eyre = { workspace = true } -clap = { workspace = true, features = ["derive", "env"] } -tracing = { workspace = true } -tracing-subscriber = { workspace = true, features = [ - "ansi", - "fmt", - "registry", - "env-filter", -] } -directories = { workspace = true } -tracing-appender = "0.2.4" -reqwest = { workspace = true } -serde = { workspace = true } -serde_json = { workspace = true } -crossterm = { workspace = true, features = ["use-dev-tty", "event-stream"] } -ratatui = { workspace = true } -fs-err = { workspace = true } -futures = "0.3" -eventsource-stream = "0.2" -pulldown-cmark = "0.13.0" -async-stream = "0.3" -uuid = { workspace = true } -tui-textarea-2 = "0.10.2" -unicode-width = "0.2" -eye_declare = "0.5.1" -ratatui-core = "0.1" -ratatui-widgets = "0.3" -thiserror = { workspace = true } -glob-match = { workspace = true } -regex = { workspace = true } -time = { workspace = true } -toml = "1.1" -toml_edit = { workspace = true } -tree-sitter-lib = { package = "tree-sitter", version = "0.26.8", optional = true } -tree-sitter-bash = { version = "0.25.1", optional = true } -tree-sitter-fish = { version = "3.6.0", optional = true } -sqlx = { workspace = true, features = ["sqlite"] } -typed-builder = { workspace = true } -shellexpand = { workspace = true } -imara-diff = { workspace = true } -xxhash-rust = { workspace = true } -vt100 = { workspace = true } -yaml-rust2 = "0.11" -tempfile = { workspace = true } -chrono = "0.4" -chrono-humanize = "0.2" - -[dev-dependencies] -pretty_assertions = { workspace = true } diff --git a/crates/atuin-ai/migrations/20260413000000_create_ai_sessions.sql b/crates/atuin-ai/migrations/20260413000000_create_ai_sessions.sql deleted file mode 100644 index 906a5726..00000000 --- a/crates/atuin-ai/migrations/20260413000000_create_ai_sessions.sql +++ /dev/null @@ -1,32 +0,0 @@ -CREATE TABLE IF NOT EXISTS sessions ( - id TEXT PRIMARY KEY, - head_id TEXT, - server_session_id TEXT, - directory TEXT, - git_root TEXT, - created_at INTEGER NOT NULL, - updated_at INTEGER NOT NULL, - archived_at INTEGER -); - -CREATE INDEX idx_sessions_directory ON sessions(directory); -CREATE INDEX idx_sessions_git_root ON sessions(git_root); -CREATE INDEX idx_sessions_updated_at ON sessions(updated_at); -CREATE INDEX idx_sessions_created_at ON sessions(created_at); - -CREATE TABLE IF NOT EXISTS session_events ( - id TEXT PRIMARY KEY, - session_id TEXT NOT NULL, - parent_id TEXT, - invocation_id TEXT NOT NULL, - event_type TEXT NOT NULL, - event_data TEXT NOT NULL, - created_at INTEGER NOT NULL, - - FOREIGN KEY (session_id) REFERENCES sessions(id) -); - -CREATE INDEX idx_session_events_session_id ON session_events(session_id); -CREATE INDEX idx_session_events_parent_id ON session_events(parent_id); -CREATE INDEX idx_session_events_invocation_id ON session_events(invocation_id); -CREATE INDEX idx_session_events_created_at ON session_events(created_at); diff --git a/crates/atuin-ai/migrations/20260417000000_add_session_metadata.sql b/crates/atuin-ai/migrations/20260417000000_add_session_metadata.sql deleted file mode 100644 index f97dfd1b..00000000 --- a/crates/atuin-ai/migrations/20260417000000_add_session_metadata.sql +++ /dev/null @@ -1,9 +0,0 @@ -CREATE TABLE IF NOT EXISTS session_metadata ( - session_id TEXT NOT NULL, - key TEXT NOT NULL, - value TEXT NOT NULL, - updated_at INTEGER NOT NULL, - - PRIMARY KEY (session_id, key), - FOREIGN KEY (session_id) REFERENCES sessions(id) -); diff --git a/crates/atuin-ai/render-tests.sh b/crates/atuin-ai/render-tests.sh deleted file mode 100755 index 8dedc76e..00000000 --- a/crates/atuin-ai/render-tests.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash -# Render all test cases from test-renders.json -# Usage: ./render-tests.sh [test_name] -# With no args: renders all tests -# With arg: renders only matching test (e.g., ./render-tests.sh 05) - -set -e -cd "$(dirname "$0")" - -JSON_FILE="test-renders.json" -FILTER="${1:-}" - -# Build once -cargo build -p atuin-ai --quiet - -# Count tests -TOTAL=$(jq length "$JSON_FILE") - -for i in $(seq 0 $((TOTAL - 1))); do - NAME=$(jq -r ".[$i].name" "$JSON_FILE") - DESC=$(jq -r ".[$i].description" "$JSON_FILE") - STATE=$(jq -c ".[$i].state" "$JSON_FILE") - - # Skip if filter provided and doesn't match - if [[ -n "$FILTER" && ! "$NAME" =~ $FILTER ]]; then - continue - fi - - echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" - echo "[$NAME] $DESC" - echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" - echo "$STATE" | cargo run -p atuin-ai --quiet -- debug-render -f plain - echo "" -done diff --git a/crates/atuin-ai/replay-states.sh b/crates/atuin-ai/replay-states.sh deleted file mode 100755 index 791ad47e..00000000 --- a/crates/atuin-ai/replay-states.sh +++ /dev/null @@ -1,101 +0,0 @@ -#!/bin/bash -# Replay state snapshots from a debug state JSONL file -# Usage: ./replay-states.sh <state-file.jsonl> [entry-number] -# With no entry: renders all frames in sequence (press Enter to advance) -# With entry number: renders just that frame - -set -e -# cd "$(dirname "$0")" - -STATE_FILE="${1:-}" -ENTRY_FILTER="${2:-}" - -if [[ -z "$STATE_FILE" ]]; then - echo "Usage: $0 <state-file.jsonl> [entry-number]" - echo "" - echo "Examples:" - echo " $0 /tmp/state.jsonl # Interactive replay of all frames" - echo " $0 /tmp/state.jsonl 15 # Show just entry 15" - exit 1 -fi - -if [[ ! -f "$STATE_FILE" ]]; then - echo "Error: File not found: $STATE_FILE" - exit 1 -fi - -# Build once -cargo build -p atuin --quiet - -# Count entries -TOTAL=$(wc -l < "$STATE_FILE" | tr -d ' ') - -if [[ -n "$ENTRY_FILTER" ]]; then - # Show single entry - LINE=$(sed -n "${ENTRY_FILTER}p" "$STATE_FILE") - if [[ -z "$LINE" ]]; then - echo "Error: Entry $ENTRY_FILTER not found (file has $TOTAL entries)" - exit 1 - fi - - ENTRY=$(echo "$LINE" | jq -r '.entry') - LABEL=$(echo "$LINE" | jq -r '.label') - STATE=$(echo "$LINE" | jq -c '.state') - - echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" - echo "[$ENTRY/$TOTAL] $LABEL" - echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" - echo "$STATE" | cargo run -p atuin --quiet -- ai debug-render -f ansi -else - # Interactive replay - echo "Replaying $TOTAL frames from $STATE_FILE" - echo "Press Enter to advance, 'q' to quit, or number+Enter to jump" - echo "" - - CURRENT=1 - while [[ $CURRENT -le $TOTAL ]]; do - LINE=$(sed -n "${CURRENT}p" "$STATE_FILE") - ENTRY=$(echo "$LINE" | jq -r '.entry') - LABEL=$(echo "$LINE" | jq -r '.label') - STATE=$(echo "$LINE" | jq -c '.state') - - clear - echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" - echo "[$CURRENT/$TOTAL] $LABEL" - echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" - echo "$STATE" | cargo run -p atuin --quiet -- ai debug-render -f ansi - echo "" - echo "[Enter: next] [p: prev] [number: jump] [s: show state JSON] [q: quit]" - - read -r INPUT - case "$INPUT" in - q|Q) - break - ;; - p|P) - if [[ $CURRENT -gt 1 ]]; then - CURRENT=$((CURRENT - 1)) - fi - ;; - s|S) - echo "" - echo "State JSON:" - echo "$STATE" | jq . - echo "" - echo "Press Enter to continue..." - read -r - ;; - ''|' ') - CURRENT=$((CURRENT + 1)) - ;; - *[0-9]*) - if [[ "$INPUT" =~ ^[0-9]+$ ]] && [[ "$INPUT" -ge 1 ]] && [[ "$INPUT" -le $TOTAL ]]; then - CURRENT=$INPUT - else - echo "Invalid entry number (1-$TOTAL)" - sleep 1 - fi - ;; - esac - done -fi diff --git a/crates/atuin-ai/src/commands.rs b/crates/atuin-ai/src/commands.rs deleted file mode 100644 index cdbc8f2d..00000000 --- a/crates/atuin-ai/src/commands.rs +++ /dev/null @@ -1,158 +0,0 @@ -use std::{ - fs, - path::{Path, PathBuf}, -}; - -use atuin_common::shell::Shell; -use clap::{Args, Subcommand}; -use eyre::Result; -use tracing_appender::rolling::{RollingFileAppender, Rotation}; -use tracing_subscriber::{EnvFilter, Layer, fmt, layer::SubscriberExt, util::SubscriberInitExt}; -pub mod init; -pub(crate) mod inline; - -#[derive(Args, Debug)] -pub struct AiArgs { - /// Enable verbose logging - #[arg(short, long, global = true)] - verbose: bool, - - /// Custom API endpoint; defaults to reading from the `ai.endpoint` setting. - #[arg(long, global = true)] - api_endpoint: Option<String>, - - /// Custom API token; defaults to reading from the `ai.api_token` setting. - #[arg(long, global = true)] - api_token: Option<String>, -} - -#[derive(Subcommand, Debug)] -pub enum Commands { - /// Initialize shell integration - Init { - /// Shell to generate integration for; defaults to "auto" - #[arg(value_name = "SHELL", default_value = "auto")] - shell: String, - }, - - /// Inline completion mode with small TUI overlay - Inline { - #[command(flatten)] - args: AiArgs, - - /// Current command line to complete - #[arg(value_name = "COMMAND")] - command: Option<String>, - - /// Use the hook mode - #[arg(long, hide = true)] - hook: bool, - }, -} - -pub async fn run( - command: Commands, - settings: &atuin_client::settings::Settings, -) -> eyre::Result<()> { - match command { - Commands::Init { shell } => init::run(shell).await, - Commands::Inline { - command, - hook, - args, - .. - } => { - if settings.logs.ai_enabled() { - init_logging(settings, args.verbose)?; - } - - inline::run(command, args.api_endpoint, args.api_token, settings, hook).await - } - } -} - -pub(crate) fn detect_shell() -> Option<String> { - Some(Shell::current().to_string()) -} - -/// Initializes logging for the AI commands. -fn init_logging(settings: &atuin_client::settings::Settings, verbose: bool) -> Result<()> { - // ATUIN_LOG env var overrides config file level settings - let env_log_set = std::env::var("ATUIN_LOG").is_ok(); - - // Base filter from env var (or empty if not set) - let base_filter = - EnvFilter::from_env("ATUIN_LOG").add_directive("sqlx_sqlite::regexp=off".parse()?); - - // Use config level unless ATUIN_LOG is set - let filter = if env_log_set { - base_filter - } else { - EnvFilter::default() - .add_directive(settings.logs.ai_level().as_directive().parse()?) - .add_directive("sqlx_sqlite::regexp=off".parse()?) - }; - - let log_dir = PathBuf::from(&settings.logs.dir); - let ai_log_filename = settings.logs.ai.file.clone(); - - // Clean up old log files - cleanup_old_logs(&log_dir, &ai_log_filename, settings.logs.ai_retention()); - - let console_layer = if verbose { - Some( - fmt::layer() - .with_writer(std::io::stderr) - .with_ansi(true) - .with_target(false) - .with_filter(filter.clone()), - ) - } else { - None - }; - - let file_appender = RollingFileAppender::new(Rotation::DAILY, &log_dir, &ai_log_filename); - - let base = tracing_subscriber::registry().with( - fmt::layer() - .with_writer(file_appender) - .with_ansi(false) - .with_filter(filter), - ); - - if let Some(console_layer) = console_layer { - base.with(console_layer).init(); - } else { - base.init(); - }; - - Ok(()) -} - -fn cleanup_old_logs(log_dir: &Path, prefix: &str, retention_days: u64) { - let cutoff = std::time::SystemTime::now() - - std::time::Duration::from_secs(retention_days * 24 * 60 * 60); - - let Ok(entries) = fs::read_dir(log_dir) else { - return; - }; - - for entry in entries.flatten() { - let path = entry.path(); - let Some(name) = path.file_name().and_then(|n| n.to_str()) else { - continue; - }; - - // Match files like "search.log.2024-02-23" or "daemon.log.2024-02-23" - if !name.starts_with(prefix) || name == prefix { - continue; - } - - if let Ok(metadata) = entry.metadata() - && let Ok(modified) = metadata.modified() - && modified < cutoff - { - let _ = fs::remove_file(&path); - } - } -} diff --git a/crates/atuin-ai/src/commands/init.rs b/crates/atuin-ai/src/commands/init.rs deleted file mode 100644 index 1f03f5b1..00000000 --- a/crates/atuin-ai/src/commands/init.rs +++ /dev/null @@ -1,233 +0,0 @@ -use crate::commands::detect_shell; - -pub(crate) async fn run(shell: String) -> eyre::Result<()> { - let integration = match shell.as_str() { - "zsh" => generate_zsh_integration(), - "bash" => generate_bash_integration(), - "fish" => generate_fish_integration(), - "auto" => generate_auto_integration()?, - _ => eyre::bail!("Unsupported shell: {}", shell), - }; - - println!("{}", integration); - Ok(()) -} - -fn generate_auto_integration() -> eyre::Result<&'static str> { - let shell = detect_shell(); - match shell.as_deref() { - Some("zsh") => Ok(generate_zsh_integration()), - Some("bash") => Ok(generate_bash_integration()), - Some("fish") => Ok(generate_fish_integration()), - Some(s) => eyre::bail!("Unsupported shell: {}", s), - None => eyre::bail!("Could not detect shell"), - } -} - -/// Generate the zsh integration function - pure function for easy testing -pub fn generate_zsh_integration() -> &'static str { - r#" -# TUI uses an alternate screen, so no explicit cleanup is needed. -_atuin_ai_cleanup() { - true -} - -# Question mark at start of line - natural language mode. -# Named with 'self-' prefix so bracketed-paste-magic activates it during -# paste, allowing url-quote-magic to escape ? in pasted URLs via self-insert. -self-atuin-ai-question-mark() { - # If buffer is empty or just contains '?', trigger natural language mode - if [[ -z "$BUFFER" || "$BUFFER" == "?" ]]; then - BUFFER="" - local output - output=$(atuin ai inline --hook 3>&1 1>&2 2>&3) - - # Clean up the inline viewport - _atuin_ai_cleanup - - if [[ $output == __atuin_ai_print__:* ]]; then - zle -I - echo "${output#__atuin_ai_print__:}" - elif [[ $output == __atuin_ai_cancel__ ]]; then - zle reset-prompt - elif [[ $output == __atuin_ai_execute__:* ]]; then - RBUFFER="" - LBUFFER=${output#__atuin_ai_execute__:} - zle reset-prompt - zle accept-line - elif [[ $output == __atuin_ai_insert__:* ]]; then - RBUFFER="" - LBUFFER=${output#__atuin_ai_insert__:} - zle reset-prompt - elif [[ -n $output ]]; then - RBUFFER="" - LBUFFER=$output - zle reset-prompt - else - zle reset-prompt - fi - else - zle self-insert - fi -} - -# Set up keybindings -zle -N self-atuin-ai-question-mark -bindkey '?' self-atuin-ai-question-mark # Question mark -"# - .trim() -} - -/// Generate the bash integration function - pure function for easy testing -pub fn generate_bash_integration() -> &'static str { - r#" -# Question mark at start of line - natural language mode -_atuin_ai_question_mark() { - # If buffer is empty or just contains '?', trigger natural language mode - if [[ -z "$READLINE_LINE" || "$READLINE_LINE" == "?" ]]; then - READLINE_LINE="" - READLINE_POINT=0 - - local output - output=$(atuin ai inline --hook 3>&1 1>&2 2>&3) - - if [[ $output == __atuin_ai_print__:* ]]; then - echo "${output#__atuin_ai_print__:}" - READLINE_LINE="" - READLINE_POINT=0 - elif [[ $output == __atuin_ai_cancel__ ]]; then - READLINE_LINE="" - READLINE_POINT=0 - elif [[ $output == __atuin_ai_execute__:* ]]; then - # Execute the command immediately - READLINE_LINE=${output#__atuin_ai_execute__:} - READLINE_POINT=${#READLINE_LINE} - # Note: We can't directly execute in bash bind -x, but we can - # use a workaround by binding to a macro that accepts the line - bind '"\C-x\C-a": accept-line' - bind -x '"\C-x\C-e": _atuin_ai_question_mark' - elif [[ $output == __atuin_ai_insert__:* ]]; then - # Insert the command for editing - READLINE_LINE=${output#__atuin_ai_insert__:} - READLINE_POINT=${#READLINE_LINE} - elif [[ -n $output ]]; then - # Default: insert for editing - READLINE_LINE=$output - READLINE_POINT=${#READLINE_LINE} - fi - else - # Not at empty prompt, just insert the question mark - READLINE_LINE="${READLINE_LINE:0:READLINE_POINT}?${READLINE_LINE:READLINE_POINT}" - ((READLINE_POINT++)) - fi -} - -# Set up keybindings -# Bash requires special handling: we use bind -x for the function, -# but need a two-step approach for execute mode -__atuin_ai_accept_line="" - -_atuin_ai_question_mark_wrapper() { - _atuin_ai_question_mark - if [[ -n "$__atuin_ai_accept_line" ]]; then - __atuin_ai_accept_line="" - fi -} - -bind -x '"?": _atuin_ai_question_mark' -"# - .trim() -} - -/// Generate the fish integration function - pure function for easy testing -pub fn generate_fish_integration() -> &'static str { - r#" -# Question mark at start of line - natural language mode -function _atuin_ai_question_mark - set -l buf (commandline -b) - - # If buffer is empty or just contains '?', trigger natural language mode - if test -z "$buf" -o "$buf" = "?" - commandline -r "" - - # Run atuin ai inline, swapping stdout and stderr - set -l output (atuin ai inline --hook 3>&1 1>&2 2>&3 | string collect) - - if string match --quiet '__atuin_ai_print__:*' "$output" - echo (string replace "__atuin_ai_print__:" "" -- "$output" | string collect) - commandline -f repaint - else if test "$output" = "__atuin_ai_cancel__" - commandline -f repaint - else if string match --quiet '__atuin_ai_execute__:*' "$output" - # Execute the command immediately - set -l cmd (string replace "__atuin_ai_execute__:" "" -- "$output" | string collect) - commandline -r "$cmd" - commandline -f repaint - commandline -f execute - else if string match --quiet '__atuin_ai_insert__:*' "$output" - # Insert the command for editing - set -l cmd (string replace "__atuin_ai_insert__:" "" -- "$output" | string collect) - commandline -r "$cmd" - commandline -f repaint - else if test -n "$output" - # Default: insert for editing - commandline -r "$output" - commandline -f repaint - else - commandline -f repaint - end - else if not contains -- "$fish_key_bindings" fish_vi_key_bindings fish_hybrid_key_bindings - # Not at empty prompt, just insert the question mark - commandline -i "?" - end -end - -# Set up keybindings -bind "?" _atuin_ai_question_mark -"# - .trim() -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_generate_zsh_integration() { - let result = generate_zsh_integration(); - assert!(result.contains("self-atuin-ai-question-mark")); - assert!(result.contains("bindkey")); - assert!(result.contains("atuin ai inline --hook")); - assert!(result.contains("__atuin_ai_print__")); - assert!(result.contains("__atuin_ai_cancel__")); - assert!(result.contains("__atuin_ai_execute__")); - assert!(result.contains("__atuin_ai_insert__")); - assert!(result.contains("zle self-insert")); - } - - #[test] - fn test_generate_bash_integration() { - let result = generate_bash_integration(); - assert!(result.contains("_atuin_ai_question_mark")); - assert!(result.contains("bind")); - assert!(result.contains("READLINE_LINE")); - assert!(result.contains("atuin ai inline --hook")); - assert!(result.contains("__atuin_ai_print__")); - assert!(result.contains("__atuin_ai_cancel__")); - assert!(result.contains("__atuin_ai_execute__")); - assert!(result.contains("__atuin_ai_insert__")); - } - - #[test] - fn test_generate_fish_integration() { - let result = generate_fish_integration(); - assert!(result.contains("_atuin_ai_question_mark")); - assert!(result.contains("bind")); - assert!(result.contains("commandline")); - assert!(result.contains("atuin ai inline --hook")); - assert!(result.contains("__atuin_ai_print__")); - assert!(result.contains("__atuin_ai_cancel__")); - assert!(result.contains("__atuin_ai_execute__")); - assert!(result.contains("__atuin_ai_insert__")); - } -} diff --git a/crates/atuin-ai/src/commands/inline.rs b/crates/atuin-ai/src/commands/inline.rs deleted file mode 100644 index 6d1f9c51..00000000 --- a/crates/atuin-ai/src/commands/inline.rs +++ /dev/null @@ -1,587 +0,0 @@ -use std::path::PathBuf; -use std::sync::mpsc; - -use crate::context::{AppContext, ClientContext}; -use crate::driver::{DriverEvent, IoContext, ViewState, run_driver}; -use crate::fsm::AgentFsm; -use crate::fsm::effects::ExitAction; -use crate::session::{LocalSessionService, SessionManager, SessionService}; -use crate::tui::events::AiTuiEvent; -use crate::tui::state::ConversationEvent; -use crate::tui::view::ai_view; -use atuin_client::database::{Database, Sqlite}; -use eye_declare::{Application, CtrlCBehavior}; -use eyre::{Context as _, Result, bail}; -use tracing::{debug, info}; - -pub(crate) async fn run( - initial_command: Option<String>, - api_endpoint: Option<String>, - api_token: Option<String>, - settings: &atuin_client::settings::Settings, - output_for_hook: bool, -) -> Result<()> { - if settings.ai.enabled == Some(false) { - return Ok(()); - } - - if settings.ai.enabled.is_none() { - match prompt_ai_setup()? { - SetupChoice::EnableAi => { - set_ai_enabled(true).await?; - } - SetupChoice::DisableKeybind => { - set_ai_enabled(false).await?; - emit_shell_result(Action::Cancel, output_for_hook); - return Ok(()); - } - SetupChoice::Cancel => { - emit_shell_result(Action::Cancel, output_for_hook); - return Ok(()); - } - } - } - - let endpoint = api_endpoint.as_deref().unwrap_or( - settings - .ai - .endpoint - .as_deref() - .unwrap_or("https://hub.atuin.sh"), - ); - let api_token = api_token.as_deref().or(settings.ai.api_token.as_deref()); - - let token = if let Some(token) = &api_token { - token.to_string() - } else { - ensure_hub_session(settings).await? - }; - - let history_db_path = PathBuf::from(settings.db_path.as_str()); - let history_db = Sqlite::new(history_db_path, settings.local_timeout) - .await - .context("failed to open history database for AI")?; - - // Support both legacy [ai] send_cwd and new [ai.opening] send_cwd - let send_cwd = - settings.ai.opening.send_cwd.unwrap_or(false) || settings.ai.send_cwd.unwrap_or(false); - - let last_command = if settings.ai.opening.send_last_command.unwrap_or(false) { - history_db.last().await.ok().flatten() - } else { - None - }; - - let git_root = std::env::current_dir() - .ok() - .and_then(|cwd| atuin_common::utils::in_git_repo(cwd.to_str()?)); - - let ctx = AppContext { - endpoint: endpoint.to_string(), - token, - send_cwd, - last_command, - history_db: std::sync::Arc::new(history_db), - git_root, - capabilities: settings.ai.capabilities.clone(), - daemon_enabled: settings.daemon.enabled, - }; - - let action = run_inline_tui(ctx, initial_command, settings).await?; - emit_shell_result(action, output_for_hook); - - Ok(()) -} - -async fn ensure_hub_session(settings: &atuin_client::settings::Settings) -> Result<String> { - if let Some(token) = atuin_client::hub::get_session_token().await? { - debug!("Found Hub session, using existing token"); - return Ok(token); - } - - let hub_address = settings.active_hub_endpoint().unwrap_or_default(); - let will_sync = settings.is_hub_sync(); - - info!("No Hub session found, prompting for authentication"); - - println!("Atuin AI requires authenticating with Atuin Hub."); - if will_sync { - println!( - "Once logged in, your shell history will be synchronized via Atuin Hub if auto_sync is enabled or when manually syncing." - ); - } - println!( - "If you have an existing Atuin sync account, you can log in with your existing credentials." - ); - println!("Press enter to begin (or esc to cancel)."); - if !wait_for_login_confirmation()? { - bail!("authentication canceled"); - } - - debug!("Starting Atuin Hub authentication..."); - println!("Authenticating with Atuin Hub..."); - - let session = atuin_client::hub::HubAuthSession::start(hub_address.as_ref()).await?; - println!("Open this URL to continue:"); - println!("{}", session.auth_url); - - let token = session - .wait_for_completion( - atuin_client::hub::DEFAULT_AUTH_TIMEOUT, - atuin_client::hub::DEFAULT_POLL_INTERVAL, - ) - .await?; - - info!("Authentication complete, saving session token"); - atuin_client::hub::save_session(&token).await?; - - if let Ok(meta) = atuin_client::settings::Settings::meta_store().await - && let Ok(Some(cli_token)) = meta.session_token().await - { - debug!("CLI session found, attempting to link accounts"); - if let Err(e) = atuin_client::hub::link_account(hub_address.as_ref(), &cli_token).await { - debug!("Could not link CLI account to Hub: {}", e); - } else { - info!("Successfully linked CLI account to Hub"); - } - } - - Ok(token) -} - -// ─────────────────────────────────────────────────────────────────── - -async fn run_inline_tui( - ctx: AppContext, - initial_prompt: Option<String>, - settings: &atuin_client::settings::Settings, -) -> Result<Action> { - let client_ctx = ClientContext::detect(); - - // Open the session service and check for a resumable session - let service = LocalSessionService::open(&settings.ai.db_path, settings.local_timeout) - .await - .context("failed to open AI session database")?; - - let cwd = std::env::current_dir() - .ok() - .map(|p| p.to_string_lossy().into_owned()); - let git_root_str = ctx - .git_root - .as_ref() - .map(|p| p.to_string_lossy().into_owned()); - - let session_window_mins = settings.ai.session_continue_minutes.max(0); // treat negative values as 0 to avoid confusion - let max_age_secs: i64 = session_window_mins * 60; - - let resumable = service - .find_resumable(cwd.as_deref(), git_root_str.as_deref(), max_age_secs) - .await?; - - // ─── Build FSM ─────────────────────────────────────────────── - let (session_mgr, fsm, file_tracker, edit_permissions) = if let Some(stored) = resumable { - debug!(session_id = %stored.id, "resuming AI session"); - let (mgr, mut events, server_sid, last_event_ts, invocation_id) = - SessionManager::resume(Box::new(service), &stored).await?; - - let has_api_content = events.iter().any(|e| e.is_api_content()); - - if has_api_content { - events.push(ConversationEvent::SystemContext { - content: "[Note: The user has started a new invocation of Atuin AI. Prior messages from this session are from an earlier invocation.]".to_string(), - }); - let view_start = events.len(); - let last_time = last_event_ts.and_then(|ts| chrono::DateTime::from_timestamp(ts, 0)); - - let ft = if let Ok(Some(json)) = - mgr.get_metadata(crate::file_tracker::METADATA_KEY).await - && let Ok(tracker) = crate::file_tracker::FileReadTracker::from_json(&json) - { - tracker - } else { - Default::default() - }; - - let ep = if let Ok(Some(json)) = mgr - .get_metadata(crate::edit_permissions::METADATA_KEY) - .await - && let Ok(cache) = crate::edit_permissions::EditPermissionCache::from_json(&json) - { - cache - } else { - Default::default() - }; - - let caps = ctx.capabilities_as_strings(); - let fsm = AgentFsm::from_session( - events, - server_sid, - caps, - invocation_id, - view_start, - true, - last_time, - ); - (mgr, fsm, ft, ep) - } else { - debug!("resumable session has no API-visible content, starting fresh"); - let caps = ctx.capabilities_as_strings(); - let fsm = AgentFsm::new(caps, invocation_id); - (mgr, fsm, Default::default(), Default::default()) - } - } else { - debug!("creating new AI session"); - let mgr = - SessionManager::create_new(Box::new(service), cwd.as_deref(), git_root_str.as_deref()); - let invocation_id = uuid::Uuid::now_v7().to_string(); - let caps = ctx.capabilities_as_strings(); - let fsm = AgentFsm::new(caps, invocation_id); - (mgr, fsm, Default::default(), Default::default()) - }; - - // ─── Snapshot store ───────────────────────────────────────── - let snapshot_dir = atuin_common::utils::data_dir() - .join("ai") - .join("snapshots") - .join(session_mgr.session_id()); - let snapshot_store = crate::snapshots::SnapshotStore::open(snapshot_dir).ok(); - - let in_git_project = ctx.git_root.is_some(); - - // ─── Discover skills ─────────────────────────────────────── - let project_root = ctx - .git_root - .clone() - .or_else(|| std::env::current_dir().ok()); - let skill_registry = crate::skills::SkillRegistry::discover(project_root.as_deref()).await; - - // ─── Build initial ViewState from FSM ─────────────────────── - let initial_view = build_view_state(&fsm, in_git_project, &skill_registry); - - // ─── Build IoContext ──────────────────────────────────────── - let io = IoContext { - app_ctx: ctx.clone(), - client_ctx: client_ctx.clone(), - session_mgr, - file_tracker, - edit_permissions, - snapshot_store, - skill_registry, - }; - - // ─── Channel + Application ────────────────────────────────── - // Components emit DriverEvent::Tui(AiTuiEvent) via a wrapping sender. - // Spawned tasks emit DriverEvent::Fsm(Event) directly. - let (tx, rx) = mpsc::channel::<DriverEvent>(); - - // Wrap sender for components: they send AiTuiEvent, we wrap it - let tui_tx = DriverEventSender(tx.clone()); - - println!(); - - if let Some(prompt) = initial_prompt { - let _ = tui_tx - .0 - .send(DriverEvent::Tui(AiTuiEvent::SubmitInput(prompt))); - } - - let (mut app, handle) = Application::builder() - .state(initial_view) - .view(ai_view) - .ctrl_c(CtrlCBehavior::Deliver) - .keyboard_protocol(eye_declare::KeyboardProtocol::Enhanced) - .bracketed_paste(true) - .with_context(tui_tx) - .extra_newlines_at_exit(1) - .on_commit(|committed, state| { - if let Some(key) = &committed.key - && let Some(id_str) = key.strip_prefix("turn-") - && let Ok(id) = id_str.parse::<usize>() - { - let new_count = id + 1; - if new_count > state.committed_turn_count { - state.committed_turn_count = new_count; - } - } - }) - .build()?; - - // ─── Driver loop ──────────────────────────────────────────── - let h = handle.clone(); - let exiting = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)); - let exiting_clone = exiting.clone(); - let dispatch_handle = tokio::task::spawn_blocking(move || { - run_driver(fsm, io, h, rx, tx, exiting_clone, in_git_project); - }); - - let run_result = app.run_loop().await; - let _ = dispatch_handle.await; - run_result?; - - let result = match app.state().exit_action { - Some(ExitAction::Execute(ref cmd)) => Action::Execute(cmd.clone()), - Some(ExitAction::Insert(ref cmd)) => Action::Insert(cmd.clone()), - _ => Action::Cancel, - }; - - Ok(result) -} - -/// Wrapper around `mpsc::Sender<DriverEvent>` that components use as context. -/// -/// Components call `tx.send(AiTuiEvent::...)` via eye-declare's context system. -/// This wrapper implements the same interface but wraps events in `DriverEvent::Tui`. -#[derive(Debug, Clone)] -pub(crate) struct DriverEventSender(pub mpsc::Sender<DriverEvent>); - -impl DriverEventSender { - pub fn send(&self, event: AiTuiEvent) -> Result<(), mpsc::SendError<AiTuiEvent>> { - self.0 - .send(DriverEvent::Tui(event)) - .map_err(|_| mpsc::SendError(AiTuiEvent::Exit)) - } -} - -/// Build a ViewState snapshot from FSM state. Used for the initial view -/// and by the driver for ongoing sync. -fn build_view_state( - fsm: &AgentFsm, - in_git_project: bool, - skill_registry: &crate::skills::SkillRegistry, -) -> ViewState { - let safe_start = fsm.ctx.view_start_index.min(fsm.ctx.events.len()); - - let mut slash_registry = crate::tui::slash::SlashCommandRegistry::default(); - let mut skill_names = std::collections::HashSet::new(); - for skill in skill_registry.all() { - slash_registry.register(crate::tui::slash::SlashCommand::new( - &skill.name, - &skill.description, - )); - skill_names.insert(skill.name.clone()); - } - - let tools = fsm.ctx.tools.clone(); - let visible_events = fsm.ctx.events[safe_start..].to_vec(); - let archived_events = fsm.ctx.archived_events.clone(); - - let mut archived_builder = crate::tui::view::turn::TurnBuilder::new(&tools); - for event in &archived_events { - archived_builder.add_event(event); - } - let archived_turns = archived_builder.build(); - let archived_turn_count = archived_turns.len(); - - let mut visible_builder = - crate::tui::view::turn::TurnBuilder::new_starting_at(&tools, archived_turn_count); - for event in &visible_events { - visible_builder.add_event(event); - } - let visible_turns = visible_builder.build(); - - let mut turns = archived_turns; - turns.extend(visible_turns); - - let has_command = visible_events.iter().any(|e| { - matches!(e, ConversationEvent::ToolCall { name, input, .. } - if name == "suggest_command" - && input.get("command").and_then(|v| v.as_str()).is_some()) - }); - - ViewState { - agent_state: fsm.state.clone(), - visible_events, - all_events: fsm.ctx.events.clone(), - session_id: fsm.ctx.session_id.clone(), - tools, - current_response: fsm.ctx.current_response.clone(), - is_resumed: fsm.ctx.is_resumed, - last_event_time: fsm.ctx.last_event_time, - in_git_project, - archived_events, - turns, - has_command, - committed_turn_count: 0, - archived_turn_count, - is_input_blank: true, - slash_command_input: None, - slash_command_search_results: Vec::new(), - exit_action: None, - slash_registry, - skill_names, - } -} - -// ─────────────────────────────────────────────────────────────────── -// Helpers -// ─────────────────────────────────────────────────────────────────── - -enum SetupChoice { - EnableAi, - DisableKeybind, - Cancel, -} - -fn prompt_ai_setup() -> Result<SetupChoice> { - use crossterm::{ - cursor, - event::{self, Event, KeyCode}, - terminal, - }; - - let options = ["Enable Atuin AI", "Disable ? Keybind", "Cancel"]; - let mut selected: usize = 0; - let mut stdout = std::io::stdout(); - - // Print header before raw mode so newlines render correctly. - // Use stdout because the shell hook swaps stdout/stderr — stdout goes - // to the terminal in both hook and non-hook modes. - println!(); - println!(" Atuin AI is not yet configured."); - println!(); - - terminal::enable_raw_mode().context("failed to enable raw mode")?; - struct Guard; - impl Drop for Guard { - fn drop(&mut self) { - let _ = terminal::disable_raw_mode(); - } - } - let _guard = Guard; - - crossterm::execute!(stdout, cursor::Hide)?; - - loop { - render_setup_options(&mut stdout, &options, selected)?; - - let ev = event::read().context("failed to read key event")?; - - crossterm::execute!(stdout, cursor::MoveUp(options.len() as u16))?; - - if let Event::Key(key) = ev { - match key.code { - KeyCode::Up | KeyCode::Char('k') => { - selected = selected.saturating_sub(1); - } - KeyCode::Down | KeyCode::Char('j') if selected < options.len() - 1 => { - selected += 1; - } - KeyCode::Enter => break, - KeyCode::Esc => { - selected = 2; - break; - } - _ => {} - } - } - } - - // Final render with selection visible - render_setup_options(&mut stdout, &options, selected)?; - crossterm::execute!(stdout, cursor::Show)?; - - Ok(match selected { - 0 => SetupChoice::EnableAi, - 1 => SetupChoice::DisableKeybind, - _ => SetupChoice::Cancel, - }) -} - -fn render_setup_options( - w: &mut impl std::io::Write, - options: &[&str], - selected: usize, -) -> Result<()> { - use crossterm::{ - style::Stylize, - terminal::{Clear, ClearType}, - }; - - for (i, option) in options.iter().enumerate() { - if i == selected { - write!(w, "\r {}", format!("> {option}").bold().cyan())?; - } else { - write!(w, "\r {option}")?; - } - crossterm::execute!(w, Clear(ClearType::UntilNewLine))?; - write!(w, "\r\n")?; - } - w.flush()?; - Ok(()) -} - -async fn set_ai_enabled(enabled: bool) -> Result<()> { - let config_file = atuin_client::settings::Settings::get_config_path()?; - let config_str = tokio::fs::read_to_string(&config_file).await?; - let mut doc = config_str.parse::<toml_edit::DocumentMut>()?; - - if !doc.contains_key("ai") { - doc["ai"] = toml_edit::table(); - } - doc["ai"]["enabled"] = toml_edit::value(enabled); - - tokio::fs::write(&config_file, doc.to_string()).await?; - - if !enabled { - println!( - "Atuin AI keybind disabled. You can re-enable with `atuin config set ai.enabled true`.", - ); - println!("Restart your shell for changes to take effect."); - // Two printlns to ensure the message is visible above the shell prompt after program ends. - println!(); - println!(); - } - - Ok(()) -} - -fn wait_for_login_confirmation() -> Result<bool> { - use crossterm::{ - event::{self, Event, KeyCode}, - terminal::{disable_raw_mode, enable_raw_mode}, - }; - - enable_raw_mode().context("failed enabling raw mode for login prompt")?; - struct Guard; - impl Drop for Guard { - fn drop(&mut self) { - let _ = disable_raw_mode(); - } - } - let _guard = Guard; - - loop { - let ev = event::read().context("failed to read login confirmation key")?; - if let Event::Key(key) = ev { - match key.code { - KeyCode::Enter => return Ok(true), - KeyCode::Esc => return Ok(false), - _ => {} - } - } - } -} - -#[derive(Clone)] -enum Action { - Execute(String), - Insert(String), - Cancel, -} - -fn emit_shell_result(action: Action, output_for_hook: bool) { - if output_for_hook { - match action { - Action::Execute(output) => eprintln!("__atuin_ai_execute__:{output}"), - Action::Insert(output) => eprintln!("__atuin_ai_insert__:{output}"), - Action::Cancel => eprintln!("__atuin_ai_cancel__"), - } - } else { - match action { - Action::Execute(output) | Action::Insert(output) => { - println!("{output}"); - } - Action::Cancel => {} - } - } -} diff --git a/crates/atuin-ai/src/context.rs b/crates/atuin-ai/src/context.rs deleted file mode 100644 index f891a9fc..00000000 --- a/crates/atuin-ai/src/context.rs +++ /dev/null @@ -1,121 +0,0 @@ -use std::path::PathBuf; -use std::sync::Arc; - -use atuin_client::distro::detect_linux_distribution; -use atuin_client::history::History; -use atuin_client::settings::AiCapabilities; - -/// Session-scoped context for the AI chat session. -/// Holds the API configuration and client settings needed by the event loop and stream task. -#[derive(Clone, Debug)] -pub(crate) struct AppContext { - pub endpoint: String, - pub token: String, - pub send_cwd: bool, - pub last_command: Option<History>, - pub history_db: Arc<atuin_client::database::Sqlite>, - /// Git root of the current working directory, if inside a git repo. - /// Resolves through worktrees to the main repo root. - pub git_root: Option<PathBuf>, - pub capabilities: AiCapabilities, - pub daemon_enabled: bool, -} - -pub(crate) fn history_output_capability_available(daemon_enabled: bool) -> bool { - cfg!(feature = "daemon") && daemon_enabled -} - -impl AppContext { - pub(crate) fn capabilities_as_strings(&self) -> Vec<String> { - let mut caps = vec!["client_invocations".to_string()]; - if self.capabilities.enable_history_search.unwrap_or(true) { - caps.push("client_v1_atuin_history".to_string()); - } - if self.capabilities.enable_file_tools.unwrap_or(true) { - caps.push("client_v1_read_file".to_string()); - caps.push("client_v1_edit_file".to_string()); - caps.push("client_v1_write_file".to_string()); - } - if self.capabilities.enable_command_execution.unwrap_or(true) { - caps.push("client_v1_execute_shell_command".to_string()); - } - if history_output_capability_available(self.daemon_enabled) - && self.capabilities.enable_history_output.unwrap_or(true) - { - caps.push("client_v1_atuin_output".to_string()); - } - caps.push("client_v1_load_skill".to_string()); - if let Ok(extra) = std::env::var("ATUIN_AI__ADDITIONAL_CAPS") { - caps.extend( - extra - .split(',') - .map(|s| s.trim().to_string()) - .filter(|s| !s.is_empty()), - ); - } - caps - } -} - -/// Machine identity — computed once per session. -#[derive(Clone, Debug)] -pub(crate) struct ClientContext { - pub os: String, - pub shell: Option<String>, - pub distro: Option<String>, -} - -impl ClientContext { - pub(crate) fn detect() -> Self { - let os = detect_os(); - let shell = crate::commands::detect_shell(); - let distro = if os == "linux" { - Some(detect_linux_distribution()) - } else { - None - }; - Self { os, shell, distro } - } - - /// Serialize to the JSON format the API expects for the "context" field. - /// The `pwd` field is always dynamic (current working directory), so it's - /// computed fresh on each call if `send_cwd` is true. - pub(crate) fn to_json( - &self, - send_cwd: bool, - last_command: Option<&History>, - ) -> serde_json::Value { - let mut ctx = serde_json::json!({ - "os": self.os, - "shell": self.shell, - "pwd": if send_cwd { - std::env::current_dir().ok().map(|p| p.to_string_lossy().into_owned()) - } else { - None - }, - }); - - if let Some(history) = last_command { - ctx["last_command"] = serde_json::json!(crate::history_format::format_last_command( - history, - crate::history_format::current_local_offset(), - )); - } - - if let Some(ref distro) = self.distro { - ctx["distro"] = serde_json::json!(distro); - } - - ctx - } -} - -/// Move the `detect_os` function here since it's about client identity. -fn detect_os() -> String { - match std::env::consts::OS { - "macos" => "macos".to_string(), - "linux" => "linux".to_string(), - "windows" => "windows".to_string(), - other => format!("Other: {other}"), - } -} diff --git a/crates/atuin-ai/src/context_window.rs b/crates/atuin-ai/src/context_window.rs deleted file mode 100644 index dcef05aa..00000000 --- a/crates/atuin-ai/src/context_window.rs +++ /dev/null @@ -1,578 +0,0 @@ -//! Context window management for API requests. -//! -//! Full conversation events are always persisted to disk. This module handles -//! truncation at send time so the API payload stays within a character budget. -//! -//! Strategy: **frozen prefix + live tail**. The first N turns form a stable -//! prefix that stays identical across requests (maximizing prompt cache hits). -//! The most recent turns form the live tail. When the total exceeds the budget, -//! turns between prefix and tail are dropped with a truncation marker. The -//! prefix never shifts, avoiding cache invalidation. - -use std::ops::Range; - -use crate::tui::{ConversationEvent, events_to_messages}; - -/// Default character budget for the context window. -/// Roughly ~50K tokens at ~4 chars/token — generous enough that truncation -/// only kicks in for genuinely long sessions. -const DEFAULT_BUDGET_CHARS: usize = 200_000; - -/// Number of initial turns to freeze as the stable prefix. -const FROZEN_PREFIX_TURNS: usize = 1; - -/// Builds API messages from conversation events while respecting a character -/// budget using frozen prefix + live tail truncation. -pub(crate) struct ContextWindowBuilder { - budget: usize, -} - -impl ContextWindowBuilder { - pub fn new(budget: usize) -> Self { - Self { budget } - } - - pub fn with_default_budget() -> Self { - Self::new(DEFAULT_BUDGET_CHARS) - } - - /// Build API messages from conversation events, applying the context - /// window budget. Returns the messages to send in the API request. - pub fn build(&self, events: &[ConversationEvent]) -> Vec<serde_json::Value> { - if events.is_empty() { - return Vec::new(); - } - - let turns = group_into_turns(events); - - // Convert each turn's events to API messages independently. - // This is safe because the combining logic (Text + ToolCall merging) - // only operates within a single assistant response, which never - // spans turn boundaries. - let turn_messages: Vec<Vec<serde_json::Value>> = turns - .iter() - .map(|range| events_to_messages(&events[range.clone()])) - .collect(); - - let turn_chars: Vec<usize> = turn_messages.iter().map(|m| estimate_chars(m)).collect(); - let total_chars: usize = turn_chars.iter().sum(); - - if total_chars <= self.budget { - return turn_messages.into_iter().flatten().collect(); - } - - // --- Over budget: apply frozen prefix + live tail --- - - let prefix_count = FROZEN_PREFIX_TURNS.min(turns.len()); - let prefix_chars: usize = turn_chars[..prefix_count].iter().sum(); - - let marker = truncation_marker(); - let marker_chars = estimate_chars(std::slice::from_ref(&marker)); - - let mut remaining = self.budget.saturating_sub(prefix_chars + marker_chars); - - // Work backwards from the end, accumulating tail turns that fit. - let mut tail_start = turns.len(); - for i in (prefix_count..turns.len()).rev() { - if turn_chars[i] <= remaining { - remaining -= turn_chars[i]; - tail_start = i; - } else { - break; - } - } - - // Always include at least the most recent turn, even if it alone - // exceeds the budget — sending something is better than nothing. - if tail_start >= turns.len() && turns.len() > prefix_count { - tail_start = turns.len() - 1; - } - - let mut result = Vec::new(); - - // Frozen prefix - for msgs in &turn_messages[..prefix_count] { - result.extend(msgs.iter().cloned()); - } - - // Truncation marker (only if turns were actually dropped) - if tail_start > prefix_count { - result.push(marker); - } - - // Live tail - for msgs in &turn_messages[tail_start..] { - result.extend(msgs.iter().cloned()); - } - - result - } -} - -/// Marker message inserted where turns were dropped. Uses user role since -/// the preceding prefix typically ends with an assistant message. -fn truncation_marker() -> serde_json::Value { - serde_json::json!({ - "role": "user", - "content": "[Earlier conversation context was omitted to fit within the context window. The conversation continues below.]" - }) -} - -/// Group conversation events into turns. A new turn starts at each -/// `UserMessage` or `SystemContext` event. Everything between boundaries -/// belongs to the preceding turn (assistant text, tool calls, tool results, -/// out-of-band output). -fn group_into_turns(events: &[ConversationEvent]) -> Vec<Range<usize>> { - let mut turns = Vec::new(); - let mut start = 0; - - for (i, event) in events.iter().enumerate() { - if i > start - && matches!( - event, - ConversationEvent::UserMessage { .. } | ConversationEvent::SystemContext { .. } - ) - { - turns.push(start..i); - start = i; - } - } - - if start < events.len() { - turns.push(start..events.len()); - } - - turns -} - -/// Rough character-count estimate for a set of messages. Uses the JSON -/// serialization length as a proxy — not exact tokens, but proportional -/// and cheap to compute. -fn estimate_chars(messages: &[serde_json::Value]) -> usize { - messages.iter().map(|m| m.to_string().len()).sum() -} - -#[cfg(test)] -mod tests { - use super::*; - - fn user(content: &str) -> ConversationEvent { - ConversationEvent::UserMessage { - content: content.to_string(), - } - } - - fn text(content: &str) -> ConversationEvent { - ConversationEvent::Text { - content: content.to_string(), - } - } - - fn tool_call(id: &str, name: &str) -> ConversationEvent { - ConversationEvent::ToolCall { - id: id.to_string(), - name: name.to_string(), - input: serde_json::json!({"command": "ls"}), - } - } - - fn tool_result(tool_use_id: &str, content: &str) -> ConversationEvent { - ConversationEvent::ToolResult { - tool_use_id: tool_use_id.to_string(), - content: content.to_string(), - is_error: false, - remote: false, - content_length: None, - } - } - - fn system_context(content: &str) -> ConversationEvent { - ConversationEvent::SystemContext { - content: content.to_string(), - } - } - - fn oob(content: &str) -> ConversationEvent { - ConversationEvent::OutOfBandOutput { - name: "test".to_string(), - command: None, - content: content.to_string(), - } - } - - // --- group_into_turns --- - - #[test] - fn empty_events_produce_no_turns() { - assert!(group_into_turns(&[]).is_empty()); - } - - #[test] - fn single_user_message_is_one_turn() { - let events = vec![user("hello")]; - let turns = group_into_turns(&events); - assert_eq!(turns, vec![0..1]); - } - - #[test] - fn user_assistant_is_one_turn() { - let events = vec![user("hello"), text("hi there")]; - let turns = group_into_turns(&events); - assert_eq!(turns, vec![0..2]); - } - - #[test] - fn two_turns_split_at_user_message() { - let events = vec![ - user("first"), - text("response 1"), - user("second"), - text("response 2"), - ]; - let turns = group_into_turns(&events); - assert_eq!(turns, vec![0..2, 2..4]); - } - - #[test] - fn tool_calls_and_results_stay_in_same_turn() { - let events = vec![ - user("list files"), - text("Let me check"), - tool_call("tc1", "suggest_command"), - tool_result("tc1", "file1\nfile2"), - text("Here are your files"), - ]; - let turns = group_into_turns(&events); - assert_eq!(turns, vec![0..5]); - } - - #[test] - fn system_context_starts_new_turn() { - let events = vec![ - user("hello"), - text("hi"), - system_context("invocation boundary"), - user("next question"), - text("answer"), - ]; - let turns = group_into_turns(&events); - assert_eq!(turns, vec![0..2, 2..3, 3..5]); - } - - #[test] - fn oob_events_stay_in_current_turn() { - let events = vec![user("hello"), oob("some output"), text("response")]; - let turns = group_into_turns(&events); - assert_eq!(turns, vec![0..3]); - } - - #[test] - fn leading_text_without_user_message() { - // Edge case: events start with assistant text (shouldn't happen - // normally but handle gracefully) - let events = vec![text("orphaned"), user("hello"), text("hi")]; - let turns = group_into_turns(&events); - assert_eq!(turns, vec![0..1, 1..3]); - } - - // --- ContextWindowBuilder --- - - #[test] - fn empty_events_produce_empty_messages() { - let builder = ContextWindowBuilder::with_default_budget(); - assert!(builder.build(&[]).is_empty()); - } - - #[test] - fn under_budget_returns_all_messages() { - let events = vec![user("hello"), text("hi"), user("how are you"), text("good")]; - let builder = ContextWindowBuilder::with_default_budget(); - let messages = builder.build(&events); - - // Should produce 4 messages (2 user + 2 assistant) - assert_eq!(messages.len(), 4); - assert_eq!(messages[0]["role"], "user"); - assert_eq!(messages[0]["content"], "hello"); - assert_eq!(messages[1]["role"], "assistant"); - assert_eq!(messages[1]["content"], "hi"); - assert_eq!(messages[2]["role"], "user"); - assert_eq!(messages[2]["content"], "how are you"); - assert_eq!(messages[3]["role"], "assistant"); - assert_eq!(messages[3]["content"], "good"); - } - - #[test] - fn over_budget_truncates_middle_turns() { - // Create events where each turn has known content. Use a tiny - // budget so truncation is triggered with just a few turns. - let events = vec![ - user("turn-1-user"), - text("turn-1-assistant"), - user("turn-2-user"), - text("turn-2-assistant"), - user("turn-3-user"), - text("turn-3-assistant"), - user("turn-4-user"), - text("turn-4-assistant-final"), - ]; - - // Calculate sizes to set budget that keeps turn 1 (prefix) + turn 4 (tail) - // but drops turns 2 and 3. - let all_messages = events_to_messages(&events); - let total_chars: usize = all_messages.iter().map(|m| m.to_string().len()).sum(); - - // Set budget to roughly half — enough for prefix + last turn + marker - let turn1_msgs = events_to_messages(&events[0..2]); - let turn4_msgs = events_to_messages(&events[6..8]); - let marker_chars = estimate_chars(std::slice::from_ref(&truncation_marker())); - let needed = estimate_chars(&turn1_msgs) + estimate_chars(&turn4_msgs) + marker_chars; - - // Budget allows prefix + marker + last turn but not the middle turns - assert!( - needed < total_chars, - "test setup: needed ({needed}) should be less than total ({total_chars})" - ); - let builder = ContextWindowBuilder::new(needed + 10); // small margin - - let messages = builder.build(&events); - - // Should have: turn 1 (2 msgs) + marker (1 msg) + turn 4 (2 msgs) = 5 - assert_eq!(messages.len(), 5, "expected prefix + marker + tail"); - assert_eq!(messages[0]["content"], "turn-1-user"); - assert_eq!(messages[1]["content"], "turn-1-assistant"); - assert!( - messages[2]["content"].as_str().unwrap().contains("omitted"), - "middle message should be truncation marker" - ); - assert_eq!(messages[3]["content"], "turn-4-user"); - assert_eq!(messages[4]["content"], "turn-4-assistant-final"); - } - - #[test] - fn very_tight_budget_keeps_prefix_and_last_turn() { - let events = vec![ - user("first"), - text("response-1"), - user("second"), - text("response-2"), - user("third"), - text("response-3"), - ]; - - // Budget of 1 — forces the "always include last turn" fallback - let builder = ContextWindowBuilder::new(1); - let messages = builder.build(&events); - - // Should have prefix (turn 1) + marker + last turn (turn 3) - assert!( - messages.len() >= 3, - "should have at least prefix + marker + tail" - ); - - // First message should be from turn 1 - assert_eq!(messages[0]["content"], "first"); - - // Last messages should be from the final turn - let last = messages.last().unwrap(); - assert_eq!(last["content"], "response-3"); - } - - #[test] - fn single_turn_always_returned() { - let events = vec![user("hello"), text("hi there")]; - - // Even with a tiny budget, the single turn must be returned - let builder = ContextWindowBuilder::new(1); - let messages = builder.build(&events); - assert_eq!(messages.len(), 2); - } - - #[test] - fn tool_calls_preserved_through_truncation() { - let events = vec![ - // Turn 1: simple exchange - user("turn 1"), - text("response 1"), - // Turn 2: with tool calls (will be dropped) - user("turn 2"), - text("checking"), - tool_call("tc1", "suggest_command"), - tool_result("tc1", "output"), - text("done"), - // Turn 3: final turn (kept in tail) - user("turn 3"), - text("final response"), - ]; - - // Budget that fits turn 1 + turn 3 + marker but not turn 2 - let turn1 = events_to_messages(&events[0..2]); - let turn3 = events_to_messages(&events[7..9]); - let marker_cost = estimate_chars(std::slice::from_ref(&truncation_marker())); - let budget = estimate_chars(&turn1) + estimate_chars(&turn3) + marker_cost + 10; - - let builder = ContextWindowBuilder::new(budget); - let messages = builder.build(&events); - - // Verify turn 2 (the tool call turn) was dropped - let has_tool_use = messages.iter().any(|m| { - m["content"] - .as_array() - .is_some_and(|arr| arr.iter().any(|b| b["type"] == "tool_use")) - }); - assert!(!has_tool_use, "tool call turn should have been truncated"); - - // Verify first and last turns present - assert_eq!(messages[0]["content"], "turn 1"); - assert_eq!(messages.last().unwrap()["content"], "final response"); - } - - #[test] - fn tail_accumulates_multiple_turns_when_budget_allows() { - // Use long content so turn sizes dwarf the truncation marker. - let padding = "x".repeat(500); - let events = vec![ - user(&format!("turn-1-user-{padding}")), - text(&format!("turn-1-response-{padding}")), - user(&format!("turn-2-user-{padding}")), - text(&format!("turn-2-response-{padding}")), - user(&format!("turn-3-user-{padding}")), - text(&format!("turn-3-response-{padding}")), - user(&format!("turn-4-user-{padding}")), - text(&format!("turn-4-response-{padding}")), - ]; - - // Budget that fits everything except turn 2 - let all = events_to_messages(&events); - let total = estimate_chars(&all); - let turn2 = events_to_messages(&events[2..4]); - let turn2_chars = estimate_chars(&turn2); - - let marker_cost = estimate_chars(std::slice::from_ref(&truncation_marker())); - let budget = total - turn2_chars + marker_cost + 5; - assert!( - budget < total, - "budget must be less than total for truncation to trigger" - ); - - let builder = ContextWindowBuilder::new(budget); - let messages = builder.build(&events); - - // Should have: prefix (t1: 2 msgs) + marker (1 msg) + t3 (2 msgs) + t4 (2 msgs) = 7 - // (turn 2 dropped) - assert_eq!(messages.len(), 7); - assert!( - messages[0]["content"] - .as_str() - .unwrap() - .starts_with("turn-1-user-") - ); - assert!( - messages[1]["content"] - .as_str() - .unwrap() - .starts_with("turn-1-response-") - ); - assert!(messages[2]["content"].as_str().unwrap().contains("omitted")); - assert!( - messages[3]["content"] - .as_str() - .unwrap() - .starts_with("turn-3-user-") - ); - assert!( - messages[4]["content"] - .as_str() - .unwrap() - .starts_with("turn-3-response-") - ); - assert!( - messages[5]["content"] - .as_str() - .unwrap() - .starts_with("turn-4-user-") - ); - assert!( - messages[6]["content"] - .as_str() - .unwrap() - .starts_with("turn-4-response-") - ); - } - - #[test] - fn no_marker_when_no_turns_dropped() { - // Two turns, both fit in budget - let events = vec![user("a"), text("b"), user("c"), text("d")]; - - let builder = ContextWindowBuilder::with_default_budget(); - let messages = builder.build(&events); - - // No truncation marker - assert_eq!(messages.len(), 4); - assert!( - !messages - .iter() - .any(|m| m["content"].as_str().is_some_and(|s| s.contains("omitted"))) - ); - } - - #[test] - fn tool_use_and_tool_result_never_split() { - // Invariant: a tool_use and its matching tool_result must always - // end up in the same turn, so truncation can't orphan one from - // the other. This test verifies that ToolResult does NOT start - // a new turn boundary. - let padding = "x".repeat(500); - let events = vec![ - // Turn 1 (prefix) - user(&format!("turn-1-{padding}")), - text(&format!("resp-1-{padding}")), - // Turn 2: contains a tool_use → tool_result pair (will be dropped) - user(&format!("turn-2-{padding}")), - text("checking"), - tool_call("tc1", "suggest_command"), - tool_result("tc1", &format!("output-{padding}")), - text(&format!("done-{padding}")), - // Turn 3 (tail) - user(&format!("turn-3-{padding}")), - text(&format!("resp-3-{padding}")), - ]; - - // Budget that fits turn 1 + turn 3 + marker, but not turn 2 - let turn1 = events_to_messages(&events[0..2]); - let turn3 = events_to_messages(&events[7..9]); - let marker_cost = estimate_chars(std::slice::from_ref(&truncation_marker())); - let budget = estimate_chars(&turn1) + estimate_chars(&turn3) + marker_cost + 10; - - let builder = ContextWindowBuilder::new(budget); - let messages = builder.build(&events); - - // Verify: every tool_use has a matching tool_result, and vice versa - let tool_use_ids: Vec<&str> = messages - .iter() - .filter_map(|m| m["content"].as_array()) - .flatten() - .filter(|b| b["type"] == "tool_use") - .filter_map(|b| b["id"].as_str()) - .collect(); - - let tool_result_ids: Vec<&str> = messages - .iter() - .filter_map(|m| m["content"].as_array()) - .flatten() - .filter(|b| b["type"] == "tool_result") - .filter_map(|b| b["tool_use_id"].as_str()) - .collect(); - - assert_eq!( - tool_use_ids, tool_result_ids, - "every tool_use must have a matching tool_result (and vice versa)" - ); - - // Turn 2 was dropped entirely, so no tool IDs should be present - assert!( - !tool_use_ids.contains(&"tc1"), - "dropped turn's tool_use should not appear" - ); - } -} diff --git a/crates/atuin-ai/src/diff.rs b/crates/atuin-ai/src/diff.rs deleted file mode 100644 index e704175c..00000000 --- a/crates/atuin-ai/src/diff.rs +++ /dev/null @@ -1,328 +0,0 @@ -//! Structured diff computation for edit previews. -//! -//! Computes a line-level diff between old and new file content using -//! imara-diff's Histogram algorithm, producing structured hunks with -//! typed lines (Context, Added, Removed) suitable for TUI rendering. - -use imara_diff::{Algorithm, Diff, InternedInput}; - -/// Number of context lines to show around each change. -const CONTEXT_LINES: u32 = 3; - -/// A structured diff preview for a file edit, ready for rendering. -#[derive(Debug, Clone)] -pub(crate) struct EditPreview { - pub hunks: Vec<DiffHunk>, -} - -/// A contiguous group of diff lines (context + changes). -#[derive(Debug, Clone)] -pub(crate) struct DiffHunk { - /// 1-indexed line number of the first line in this hunk (in the original file). - pub before_start: u32, - /// 1-indexed line number of the first line in this hunk (in the new file). - pub after_start: u32, - pub lines: Vec<DiffLine>, -} - -/// A single line in a diff hunk. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum DiffLine { - /// Unchanged line (shown for context). - Context(String), - /// Line added in the new version. - Added(String), - /// Line removed from the old version. - Removed(String), -} - -impl EditPreview { - /// Compute a structured diff between old and new file content. - /// - /// Uses the Histogram algorithm with line-level granularity and - /// indentation-aware postprocessing for readable output. - pub fn compute(old: &str, new: &str) -> Self { - let input = InternedInput::new(old, new); - let mut diff = Diff::compute(Algorithm::Histogram, &input); - diff.postprocess_lines(&input); - - let raw_hunks: Vec<_> = diff.hunks().collect(); - if raw_hunks.is_empty() { - return EditPreview { hunks: Vec::new() }; - } - - // Merge hunks that are within 2*CONTEXT_LINES of each other - // (same logic as unified diff format). - let mut merged_groups: Vec<Vec<&imara_diff::Hunk>> = Vec::new(); - let mut current_group: Vec<&imara_diff::Hunk> = vec![&raw_hunks[0]]; - - for hunk in &raw_hunks[1..] { - let prev = current_group.last().unwrap(); - if hunk.before.start.saturating_sub(prev.before.end) <= 2 * CONTEXT_LINES { - current_group.push(hunk); - } else { - merged_groups.push(current_group); - current_group = vec![hunk]; - } - } - merged_groups.push(current_group); - - // Build structured hunks from merged groups - let hunks = merged_groups - .into_iter() - .map(|group| build_hunk(&group, &input)) - .collect(); - - EditPreview { hunks } - } - - /// The highest line number (from either file) that will be displayed. - /// Used to calculate gutter width. - pub fn max_line_number(&self) -> u32 { - self.hunks - .iter() - .map(|h| { - let mut before_pos = h.before_start; - let mut after_pos = h.after_start; - for line in &h.lines { - match line { - DiffLine::Context(_) => { - before_pos += 1; - after_pos += 1; - } - DiffLine::Removed(_) => before_pos += 1, - DiffLine::Added(_) => after_pos += 1, - } - } - before_pos.max(after_pos).saturating_sub(1) - }) - .max() - .unwrap_or(0) - } -} - -/// Maximum lines to show in a write preview. -const WRITE_PREVIEW_LINES: usize = 10; - -/// A content preview for a write_file operation. -/// -/// Shows the first N lines of the written content plus a count of -/// remaining lines if truncated. -#[derive(Debug, Clone)] -pub(crate) struct WritePreview { - /// First lines of content (up to WRITE_PREVIEW_LINES). - pub lines: Vec<String>, - /// Total number of lines in the written file. - pub total_lines: usize, -} - -impl WritePreview { - /// Create a preview from file content. - pub fn from_content(content: &str) -> Self { - let all_lines: Vec<&str> = content.lines().collect(); - let total_lines = all_lines.len(); - let lines = all_lines - .into_iter() - .take(WRITE_PREVIEW_LINES) - .map(String::from) - .collect(); - WritePreview { lines, total_lines } - } - - /// Number of lines not shown in the preview. - pub fn remaining_lines(&self) -> usize { - self.total_lines.saturating_sub(self.lines.len()) - } -} - -/// Build a single DiffHunk from a group of adjacent raw hunks. -fn build_hunk(group: &[&imara_diff::Hunk], input: &InternedInput<&str>) -> DiffHunk { - let first = group.first().unwrap(); - let last = group.last().unwrap(); - - let context_start = first.before.start.saturating_sub(CONTEXT_LINES); - let context_end = (last.before.end + CONTEXT_LINES).min(input.before.len() as u32); - - // The after-file position of context_start: same offset as before since - // context before the first change is identical in both files. - let after_context_start = first.after.start - (first.before.start - context_start); - - let mut lines = Vec::new(); - let mut pos = context_start; - - for hunk in group { - // Context lines before this hunk - for i in pos..hunk.before.start { - lines.push(DiffLine::Context(token_text(input, true, i))); - } - - // Removed lines - for i in hunk.before.start..hunk.before.end { - lines.push(DiffLine::Removed(token_text(input, true, i))); - } - - // Added lines - for i in hunk.after.start..hunk.after.end { - lines.push(DiffLine::Added(token_text(input, false, i))); - } - - pos = hunk.before.end; - } - - // Trailing context - for i in pos..context_end { - lines.push(DiffLine::Context(token_text(input, true, i))); - } - - DiffHunk { - before_start: context_start + 1, // 1-indexed - after_start: after_context_start + 1, // 1-indexed - lines, - } -} - -/// Extract the text content of a token, trimming the trailing newline -/// that imara-diff includes in line-based tokenization. -fn token_text(input: &InternedInput<&str>, is_before: bool, idx: u32) -> String { - let tokens = if is_before { - &input.before - } else { - &input.after - }; - let text = input.interner[tokens[idx as usize]]; - text.strip_suffix('\n') - .unwrap_or(text) - .strip_suffix('\r') - .unwrap_or(text.strip_suffix('\n').unwrap_or(text)) - .to_string() -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn no_changes_produces_empty_preview() { - let preview = EditPreview::compute("hello\nworld\n", "hello\nworld\n"); - assert!(preview.hunks.is_empty()); - } - - #[test] - fn single_line_replacement() { - let old = "line1\nline2\nline3\n"; - let new = "line1\nchanged\nline3\n"; - let preview = EditPreview::compute(old, new); - - assert_eq!(preview.hunks.len(), 1); - let hunk = &preview.hunks[0]; - - // Should have: context(line1), removed(line2), added(changed), context(line3) - assert!(hunk.lines.contains(&DiffLine::Context("line1".into()))); - assert!(hunk.lines.contains(&DiffLine::Removed("line2".into()))); - assert!(hunk.lines.contains(&DiffLine::Added("changed".into()))); - assert!(hunk.lines.contains(&DiffLine::Context("line3".into()))); - } - - #[test] - fn addition_only() { - let old = "aaa\nbbb\n"; - let new = "aaa\nnew_line\nbbb\n"; - let preview = EditPreview::compute(old, new); - - assert_eq!(preview.hunks.len(), 1); - let hunk = &preview.hunks[0]; - assert!(hunk.lines.contains(&DiffLine::Added("new_line".into()))); - // Original lines are context - assert!(hunk.lines.contains(&DiffLine::Context("aaa".into()))); - assert!(hunk.lines.contains(&DiffLine::Context("bbb".into()))); - } - - #[test] - fn removal_only() { - let old = "aaa\nremove_me\nbbb\n"; - let new = "aaa\nbbb\n"; - let preview = EditPreview::compute(old, new); - - assert_eq!(preview.hunks.len(), 1); - let hunk = &preview.hunks[0]; - assert!(hunk.lines.contains(&DiffLine::Removed("remove_me".into()))); - } - - #[test] - fn distant_changes_produce_separate_hunks() { - // Two changes separated by more than 2*CONTEXT_LINES (6) lines - let old = "1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n11\n12\n"; - let new = "1\nX\n3\n4\n5\n6\n7\n8\n9\n10\n11\nY\n"; - let preview = EditPreview::compute(old, new); - - assert_eq!(preview.hunks.len(), 2); - } - - #[test] - fn close_changes_merge_into_one_hunk() { - // Two changes separated by fewer than 2*CONTEXT_LINES lines - let old = "1\n2\n3\n4\n5\n"; - let new = "X\n2\n3\n4\nY\n"; - let preview = EditPreview::compute(old, new); - - assert_eq!(preview.hunks.len(), 1); - } - - #[test] - fn context_is_limited() { - // With CONTEXT_LINES=3, a change at line 10 shouldn't include line 1 - let mut old_lines: Vec<&str> = (1..=20).map(|_| "unchanged").collect(); - old_lines[9] = "target"; - let old = old_lines.join("\n") + "\n"; - let new = old.replace("target", "replaced"); - - let preview = EditPreview::compute(&old, &new); - assert_eq!(preview.hunks.len(), 1); - - // Should have at most 3 context lines before + 3 after + 1 removed + 1 added = 8 lines - assert!(preview.hunks[0].lines.len() <= 8); - } - - #[test] - fn max_line_number_reflects_file_position() { - let old = "a\nb\nc\n"; - let new = "a\nX\nc\n"; - let preview = EditPreview::compute(old, new); - // 3-line file, context + removed lines span positions 1-3 - assert_eq!(preview.max_line_number(), 3); - } - - #[test] - fn start_line_is_correct_for_later_changes() { - // Change at line 10 with 3 context lines → before_start = 7 - let mut lines: Vec<String> = (1..=15).map(|i| format!("line{i}")).collect(); - let old = lines.join("\n") + "\n"; - lines[9] = "CHANGED".to_string(); - let new = lines.join("\n") + "\n"; - - let preview = EditPreview::compute(&old, &new); - assert_eq!(preview.hunks.len(), 1); - assert_eq!(preview.hunks[0].before_start, 7); // line 10 - 3 context = line 7 - assert_eq!(preview.hunks[0].after_start, 7); // same for a simple replacement - } - - #[test] - fn multiline_replacement() { - let old = "[section]\nkey1 = old1\nkey2 = old2\n[other]\n"; - let new = "[section]\nkey1 = new1\nkey2 = new2\n[other]\n"; - let preview = EditPreview::compute(old, new); - - assert_eq!(preview.hunks.len(), 1); - let hunk = &preview.hunks[0]; - assert!( - hunk.lines - .contains(&DiffLine::Removed("key1 = old1".into())) - ); - assert!( - hunk.lines - .contains(&DiffLine::Removed("key2 = old2".into())) - ); - assert!(hunk.lines.contains(&DiffLine::Added("key1 = new1".into()))); - assert!(hunk.lines.contains(&DiffLine::Added("key2 = new2".into()))); - } -} diff --git a/crates/atuin-ai/src/driver.rs b/crates/atuin-ai/src/driver.rs deleted file mode 100644 index 82d666ef..00000000 --- a/crates/atuin-ai/src/driver.rs +++ /dev/null @@ -1,1030 +0,0 @@ -//! Driver loop for the agent FSM. -//! -//! Receives events from the channel, calls `fsm.handle()`, syncs ViewState -//! to the Handle, and executes effects (spawning async tasks for IO). -//! -//! The driver runs on a blocking thread (`spawn_blocking`) so it can call -//! `blocking_recv()` on the Handle and `block_on()` for async persistence. - -use std::path::PathBuf; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::mpsc; - -use eye_declare::Handle; - -use crate::context::{AppContext, ClientContext}; -use crate::edit_permissions::EditPermissionCache; -use crate::file_tracker::FileReadTracker; -use crate::fsm::effects::{Effect, ExitAction, PermissionTarget}; -use crate::fsm::events::{Event, PermissionChoice, PermissionResponse}; -use crate::fsm::tools::ToolPreviewData; -use crate::fsm::{AgentFsm, AgentState}; -use crate::permissions::resolver::PermissionResolver; -use crate::permissions::writer; -use crate::session::SessionManager; -use crate::stream::ChatRequest; -use crate::tools::ClientToolCall; -use crate::tui::events::{AiTuiEvent, PermissionResult}; -use crate::tui::state::ConversationEvent; -use crate::tui::view::turn; - -// ============================================================================ -// Driver event — the unified channel type -// ============================================================================ - -/// Events processed by the driver loop. -/// -/// Components emit `Tui` variants via the channel. Spawned async tasks -/// (stream, tool execution) emit `Fsm` variants directly. -#[derive(Debug)] -pub(crate) enum DriverEvent { - /// Event from a TUI component (key press, input change, etc.) - Tui(AiTuiEvent), - /// Internal FSM event (from spawned stream/tool tasks) - Fsm(Event), -} - -// ============================================================================ -// IO context (driver-owned, not visible to FSM) -// ============================================================================ - -pub(crate) struct IoContext { - pub app_ctx: AppContext, - pub client_ctx: ClientContext, - pub session_mgr: SessionManager, - pub file_tracker: FileReadTracker, - pub edit_permissions: EditPermissionCache, - pub snapshot_store: Option<crate::snapshots::SnapshotStore>, - pub skill_registry: crate::skills::SkillRegistry, -} - -// ============================================================================ -// ViewState (Handle payload for the render thread) -// ============================================================================ - -/// State pushed to the Handle for the view/render thread. -/// Synced from the FSM after each transition. -#[derive(Debug)] -pub(crate) struct ViewState { - // ─── From FSM ─────────────────────────────────────────────── - pub agent_state: AgentState, - pub visible_events: Vec<ConversationEvent>, - pub all_events: Vec<ConversationEvent>, - pub session_id: Option<String>, - pub tools: crate::fsm::tools::ToolManager, - pub current_response: String, - - // ─── Session metadata (set once) ──────────────────────────── - pub is_resumed: bool, - pub last_event_time: Option<chrono::DateTime<chrono::Utc>>, - pub in_git_project: bool, - - // ─── View-only ────────────────────────────────────────────── - pub archived_events: Vec<ConversationEvent>, - - // ─── Pre-computed for rendering ──────────────────────────── - pub turns: Vec<turn::UiTurn>, - pub has_command: bool, - pub committed_turn_count: usize, - pub archived_turn_count: usize, - - // ─── Ephemeral interaction state ──────────────────────────── - pub is_input_blank: bool, - pub slash_command_input: Option<String>, - pub slash_command_search_results: Vec<crate::tui::slash::SlashCommandSearchResult>, - pub exit_action: Option<ExitAction>, - pub slash_registry: crate::tui::slash::SlashCommandRegistry, - pub skill_names: std::collections::HashSet<String>, -} - -impl ViewState { - pub fn is_exiting(&self) -> bool { - self.exit_action.is_some() - } - - pub fn is_busy(&self) -> bool { - matches!(self.agent_state, AgentState::Turn { .. }) - } - - pub fn has_confirmation(&self) -> bool { - matches!( - self.agent_state, - AgentState::Idle { - confirmation: Some(_) - } - ) - } - - pub fn is_input_active(&self) -> bool { - matches!(self.agent_state, AgentState::Idle { .. }) && !self.has_confirmation() - } - - pub fn footer_text(&self) -> &'static str { - match &self.agent_state { - AgentState::Idle { confirmation: None } => { - if self.has_command && self.is_input_blank { - "[Enter] Execute suggested command [Tab] Insert Command" - } else { - "[Enter] Send [Shift+Enter] New line [Esc] Exit" - } - } - AgentState::Idle { - confirmation: Some(_), - } => "[Enter] Confirm dangerous command [Esc] Cancel", - AgentState::Turn { .. } => "[Esc] Cancel", - AgentState::Error(_) => "[Enter]/[r] Retry [Esc] Exit", - } - } -} - -// ============================================================================ -// Main driver loop -// ============================================================================ - -struct DriverContext<'a> { - fsm: &'a mut AgentFsm, - io: &'a mut IoContext, - handle: &'a Handle<ViewState>, - tx: &'a mpsc::Sender<DriverEvent>, - exiting: &'a Arc<AtomicBool>, - stream_cancel_tx: &'a mut Option<tokio::sync::watch::Sender<()>>, - tool_abort_txs: &'a mut std::collections::HashMap<String, tokio::sync::oneshot::Sender<()>>, -} - -/// Main driver loop. Processes events, transitions FSM, syncs view, executes effects. -/// -/// Runs on a blocking thread. Returns when the event channel closes or exit is requested. -/// The Handle already contains the initial ViewState (set by Application::builder). -pub(crate) fn run_driver( - mut fsm: AgentFsm, - mut io: IoContext, - handle: Handle<ViewState>, - rx: mpsc::Receiver<DriverEvent>, - tx: mpsc::Sender<DriverEvent>, - exiting: Arc<AtomicBool>, - in_git_project: bool, -) { - // Dropping the sender cancels the stream (receiver sees Err on changed()). - let mut stream_cancel_tx: Option<tokio::sync::watch::Sender<()>> = None; - // Per-tool interrupt senders for shell commands. - let mut tool_abort_txs: std::collections::HashMap<String, tokio::sync::oneshot::Sender<()>> = - std::collections::HashMap::new(); - - while let Ok(driver_event) = rx.recv() { - // Log and translate DriverEvent to FSM Event (or handle directly) - let fsm_event = match driver_event { - DriverEvent::Fsm(event) => { - tracing::trace!(?event, state = ?fsm.state, "FSM event"); - Some(event) - } - DriverEvent::Tui(tui_event) => { - tracing::trace!(?tui_event, state = ?fsm.state, "TUI event"); - translate_tui_event(tui_event, &handle) - } - }; - - if let Some(event) = fsm_event { - // Feed event to FSM - let effects = fsm.handle(event); - tracing::trace!(?effects, state = ?fsm.state, "FSM transition"); - - // Sync ViewState to Handle (FSM owns all state now) - sync_view_state(&handle, &fsm, in_git_project); - - // Execute effects (only persist when FSM says to) - for effect in &effects { - if matches!(effect, Effect::Persist) { - persist(&fsm, &mut io); - } - - let ctx = DriverContext { - fsm: &mut fsm, - io: &mut io, - handle: &handle, - tx: &tx, - exiting: &exiting, - stream_cancel_tx: &mut stream_cancel_tx, - tool_abort_txs: &mut tool_abort_txs, - }; - - execute_effect(effect, ctx); - } - - // Final sync after effects — ensures the render thread sees - // the absolute final state even if effects modified anything. - if !effects.is_empty() { - sync_view_state(&handle, &fsm, in_git_project); - } - } - // InputUpdated (the only event that returns None) already pushed - // its view-only changes via handle.update() — no FSM state changed, - // so skip the expensive sync_view_state that clones all events. - - if exiting.load(Ordering::Acquire) { - break; - } - tracing::trace!(state = ?fsm.state, "driver loop iteration complete, waiting for next event"); - } -} - -// ============================================================================ -// TUI event translation -// ============================================================================ - -/// Translate a TUI event into an FSM event. -/// Returns None for events handled directly (e.g. InputUpdated). -fn translate_tui_event(event: AiTuiEvent, handle: &Handle<ViewState>) -> Option<Event> { - match event { - AiTuiEvent::SubmitInput(input) => { - // Clear slash state and reset is_input_blank (the InputBox clears - // its text on submit but doesn't fire InputUpdated for the clear). - handle.update(|vs| { - vs.slash_command_input = None; - vs.slash_command_search_results.clear(); - vs.is_input_blank = true; - }); - - let input = input.trim().to_string(); - if input.is_empty() { - Some(Event::ExecuteCommand) - } else if input == "/new" { - Some(Event::NewSession) - } else if input.starts_with('/') { - if let Some((skill_name, arguments)) = resolve_skill_name(&input, handle) { - Some(Event::RequestSkillLoad { - name: skill_name, - arguments, - }) - } else { - let content = resolve_slash_command(&input, handle); - Some(Event::SlashCommand { - command: input, - content, - }) - } - } else { - Some(Event::UserSubmit(input)) - } - } - AiTuiEvent::InputUpdated(text) => { - let is_blank = text.is_empty(); - - // Hot path (every keystroke); uses handle.update_tracked - // to allow read()ing the state without marking it dirty. - handle.update_tracked(move |vs| { - if vs.read().is_input_blank != is_blank { - vs.is_input_blank = is_blank; - } - - if text.starts_with('/') { - let query = text.trim_start_matches('/').to_string(); - let mut results = vs.slash_registry.search_fuzzy(&query); - results.sort_by(|a, b| { - b.relevance - .partial_cmp(&a.relevance) - .unwrap_or(std::cmp::Ordering::Equal) - }); - vs.slash_command_input = Some(query); - vs.slash_command_search_results = results; - } else { - if vs.read().slash_command_input.is_some() { - vs.slash_command_input = None; - } - - if !vs.read().slash_command_search_results.is_empty() { - vs.slash_command_search_results.clear(); - } - } - }); - None - } - AiTuiEvent::CancelGeneration => Some(Event::Cancel), - AiTuiEvent::ExecuteCommand => Some(Event::ExecuteCommand), - AiTuiEvent::InsertCommand => Some(Event::InsertCommand), - AiTuiEvent::CancelConfirmation => Some(Event::Cancel), - AiTuiEvent::InterruptToolExecution => Some(Event::InterruptTools), - AiTuiEvent::Retry => Some(Event::Retry), - AiTuiEvent::Exit => Some(Event::Cancel), - AiTuiEvent::SelectPermission(result) => { - let tool_id = handle - .fetch(|vs| vs.tools.awaiting_permission().map(|t| t.id.clone())) - .blocking_recv() - .ok() - .flatten(); - - let tool_id = tool_id?; - - let choice = match result { - PermissionResult::Allow => PermissionChoice::Allow, - PermissionResult::AllowFileForSession => PermissionChoice::AllowForSession, - PermissionResult::AlwaysAllowInDir => PermissionChoice::AlwaysAllowInProject, - PermissionResult::AlwaysAllow => PermissionChoice::AlwaysAllow, - PermissionResult::Deny => PermissionChoice::Deny, - }; - Some(Event::PermissionUserChoice { tool_id, choice }) - } - AiTuiEvent::SlashCommand(cmd) => { - if let Some((skill_name, arguments)) = resolve_skill_name(&cmd, handle) { - Some(Event::RequestSkillLoad { - name: skill_name, - arguments, - }) - } else { - let content = resolve_slash_command(&cmd, handle); - Some(Event::SlashCommand { - command: cmd, - content, - }) - } - } - } -} - -/// Resolve a slash command to its output content. -/// If the input starts with `/`, check whether the command name matches a -/// registered skill. Returns `Some((skill_name, arguments))` if it does. -fn resolve_skill_name(input: &str, handle: &Handle<ViewState>) -> Option<(String, Option<String>)> { - let after_slash = input.trim_start_matches('/'); - let cmd_name = after_slash.split_whitespace().next()?.to_string(); - - let is_skill = handle - .fetch({ - let cmd_name = cmd_name.clone(); - move |vs| vs.skill_names.contains(&cmd_name) - }) - .blocking_recv() - .unwrap_or(false); - - if !is_skill { - return None; - } - - let args = after_slash - .strip_prefix(&cmd_name) - .map(|s| s.trim()) - .filter(|s| !s.is_empty()) - .map(|s| s.to_string()); - - Some((cmd_name, args)) -} - -fn resolve_slash_command(command: &str, handle: &Handle<ViewState>) -> String { - match command.trim() { - "/help" => { - let commands = handle - .fetch(|vs| { - vs.slash_registry - .get_commands() - .iter() - .map(|cmd| format!("- `/{}` — {}", cmd.name, cmd.description)) - .collect::<Vec<_>>() - .join("\n") - }) - .blocking_recv() - .unwrap_or_default(); - include_str!("tui/content/help.md").replace("{commands}", &commands) - } - _ => format!("Unknown command: {command}"), - } -} - -// ============================================================================ -// ViewState sync -// ============================================================================ - -fn sync_view_state(handle: &Handle<ViewState>, fsm: &AgentFsm, in_git_project: bool) { - let state = fsm.state.clone(); - let safe_start = fsm.ctx.view_start_index.min(fsm.ctx.events.len()); - let mut visible_events = fsm.ctx.events[safe_start..].to_vec(); - let all_events = fsm.ctx.events.clone(); - let tools = fsm.ctx.tools.clone(); - let current_response = fsm.ctx.current_response.clone(); - let session_id = fsm.ctx.session_id.clone(); - let is_resumed = fsm.ctx.is_resumed; - let last_event_time = fsm.ctx.last_event_time; - let archived_events = fsm.ctx.archived_events.clone(); - - // Inject streaming text as a synthetic event for live rendering. - // The FSM commits text to events on stream end; this makes it visible during streaming. - let trimmed = current_response.trim_start(); - if !trimmed.is_empty() { - visible_events.push(ConversationEvent::Text { - content: trimmed.to_string(), - }); - } - - // Pre-compute turns and has_command on the driver thread so the - // render-thread view function doesn't redo O(n) work every frame. - let mut archived_builder = turn::TurnBuilder::new(&tools); - for event in &archived_events { - archived_builder.add_event(event); - } - let archived_turns = archived_builder.build(); - let archived_turn_count = archived_turns.len(); - - let mut visible_builder = turn::TurnBuilder::new_starting_at(&tools, archived_turn_count); - for event in &visible_events { - visible_builder.add_event(event); - } - let visible_turns = visible_builder.build(); - - let mut turns = archived_turns; - turns.extend(visible_turns); - - let has_command = visible_events.iter().any(|e| { - if let ConversationEvent::ToolCall { name, input, .. } = e { - name == "suggest_command" && input.get("command").and_then(|v| v.as_str()).is_some() - } else { - false - } - }); - - tracing::trace!(?state, "sync_view_state pushing to handle"); - handle.update(move |vs| { - vs.agent_state = state; - vs.visible_events = visible_events; - vs.all_events = all_events; - vs.tools = tools; - vs.current_response = current_response; - vs.session_id = session_id; - vs.is_resumed = is_resumed; - vs.last_event_time = last_event_time; - vs.in_git_project = in_git_project; - vs.archived_events = archived_events; - vs.turns = turns; - vs.has_command = has_command; - vs.archived_turn_count = archived_turn_count; - }); -} - -// ============================================================================ -// Effect execution -// ============================================================================ - -fn execute_effect(effect: &Effect, ctx: DriverContext) { - let DriverContext { - fsm, - io, - handle, - tx, - exiting, - stream_cancel_tx, - tool_abort_txs, - } = ctx; - - match effect { - Effect::StartStream { - messages, - session_id, - } => { - // Cancel any existing stream before starting a new one - stream_cancel_tx.take(); - - let (cancel_tx, cancel_rx) = tokio::sync::watch::channel(()); - *stream_cancel_tx = Some(cancel_tx); - - let tx = tx.clone(); - let app = io.app_ctx.clone(); - let cc = io.client_ctx.clone(); - let (skill_summaries, skill_overflow) = io.skill_registry.server_skills(); - let request = ChatRequest::new( - messages.clone(), - session_id.clone(), - &app.capabilities, - app.daemon_enabled, - fsm.ctx.invocation_id.clone(), - ); - tokio::spawn(async move { - run_stream_bridge( - request, - app, - cc, - tx, - cancel_rx, - skill_summaries, - skill_overflow, - ) - .await; - }); - } - - Effect::AbortStream => { - // Drop the sender — the bridge's cancel_rx.changed() will error, - // breaking the stream loop and dropping the HTTP connection. - stream_cancel_tx.take(); - } - - Effect::CheckPermission { tool_id, tool } => { - let tool_id = tool_id.clone(); - let tool = tool.clone(); - let tx = tx.clone(); - - // Auto-approved tools (e.g. load_skill) bypass permission checks entirely - if tool.is_auto_approved() { - let _ = tx.send(DriverEvent::Fsm(Event::PermissionResolved { - tool_id, - response: PermissionResponse::Allowed, - })); - return; - } - - let working_dir = tool - .target_dir() - .map(|p| p.to_path_buf()) - .or_else(|| std::env::current_dir().ok()) - .unwrap_or_else(|| PathBuf::from(".")); - - // Check session grants first (synchronous) - if let Some(resolved) = tool.resolved_file_path() - && io.edit_permissions.has_valid_grant(&resolved) - { - let _ = tx.send(DriverEvent::Fsm(Event::PermissionResolved { - tool_id, - response: PermissionResponse::SessionGranted, - })); - return; - } - - tokio::spawn(async move { - let response = match PermissionResolver::new(working_dir).await { - Ok(resolver) => match resolver.check(&tool).await { - Ok(crate::permissions::check::PermissionResponse::Allowed) => { - PermissionResponse::Allowed - } - Ok(crate::permissions::check::PermissionResponse::Denied) => { - PermissionResponse::Denied - } - Ok(crate::permissions::check::PermissionResponse::Ask) => { - PermissionResponse::Ask - } - Err(_) => PermissionResponse::Ask, - }, - Err(_) => PermissionResponse::Ask, - }; - let _ = tx.send(DriverEvent::Fsm(Event::PermissionResolved { - tool_id, - response, - })); - }); - } - - Effect::ExecuteTool { tool_id, tool } => { - let tool_id = tool_id.clone(); - let tx = tx.clone(); - let db = io.app_ctx.history_db.clone(); - - match &tool { - ClientToolCall::Shell(shell_call) => { - let shell_call = shell_call.clone(); - let tx_preview = tx.clone(); - let tool_id_for_preview = tool_id.clone(); - - // Create interrupt channel and store the sender for AbortTool - let (interrupt_tx, interrupt_rx) = tokio::sync::oneshot::channel(); - tool_abort_txs.insert(tool_id.clone(), interrupt_tx); - - tokio::spawn(async move { - let (output_tx, mut output_rx) = - tokio::sync::mpsc::channel::<Vec<String>>(16); - - let preview_id = tool_id_for_preview; - let tx_fwd = tx_preview; - tokio::spawn(async move { - while let Some(lines) = output_rx.recv().await { - let _ = tx_fwd.send(DriverEvent::Fsm(Event::ToolPreviewUpdate { - tool_id: preview_id.clone(), - lines, - exit_code: None, - })); - } - }); - - let outcome = crate::tools::execute_shell_command_streaming( - &shell_call, - output_tx, - interrupt_rx, - ) - .await; - - let preview = - if let crate::tools::ToolOutcome::Structured { exit_code, .. } = - &outcome - { - Some(ToolPreviewData::Shell { - lines: vec![], - exit_code: *exit_code, - // Reason is set by the FSM in handle_tool_done - // based on whether it was a user interrupt or timeout. - interrupted: None, - }) - } else { - None - }; - - let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone { - tool_id, - outcome, - preview, - })); - }); - } - ClientToolCall::Edit(edit_call) => { - let resolved = edit_call.resolved_path(); - - // Capture old content for snapshot + diff preview - let old_content = std::fs::read(&resolved).ok(); - if let Some(ref content) = old_content - && let Some(ref mut store) = io.snapshot_store - && let Err(e) = store.ensure_snapshot(&resolved, content) - { - tracing::warn!("Failed to snapshot before edit: {e}"); - } - - // Edit is fast (file read + string replace + write) — run inline - let (outcome, new_content) = edit_call.execute(&resolved, &io.file_tracker); - - // Update file tracker with new content - if let Some(new_bytes) = &new_content - && let Ok(mtime) = std::fs::metadata(&resolved).and_then(|m| m.modified()) - { - io.file_tracker - .update_after_edit(&resolved, new_bytes, mtime); - } - - // Compute diff preview - let preview = match (&old_content, &new_content) { - (Some(old_bytes), Some(new_bytes)) => { - let old_str = String::from_utf8_lossy(old_bytes); - let new_str = String::from_utf8_lossy(new_bytes); - let diff = crate::diff::EditPreview::compute(&old_str, &new_str); - if diff.hunks.is_empty() { - None - } else { - Some(ToolPreviewData::Edit(diff)) - } - } - _ => None, - }; - - let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone { - tool_id, - outcome, - preview, - })); - } - ClientToolCall::Write(write_call) => { - let resolved = write_call.resolved_path(); - - // Snapshot existing file before overwriting - if let Ok(content) = std::fs::read(&resolved) - && let Some(ref mut store) = io.snapshot_store - && let Err(e) = store.ensure_snapshot(&resolved, &content) - { - tracing::warn!("Failed to snapshot before write: {e}"); - } - - // Write is fast (atomic file write) — run inline - let (outcome, written_bytes) = write_call.execute(&resolved); - - // Update file tracker with new content - if let Some(new_bytes) = &written_bytes - && let Ok(mtime) = std::fs::metadata(&resolved).and_then(|m| m.modified()) - { - io.file_tracker - .update_after_edit(&resolved, new_bytes, mtime); - } - - let preview = if !outcome.is_error() { - Some(ToolPreviewData::Write( - crate::diff::WritePreview::from_content(&write_call.content), - )) - } else { - None - }; - - let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone { - tool_id, - outcome, - preview, - })); - } - ClientToolCall::Read(read_call) => { - // Read is fast (file read) — run inline so we can update file_tracker - let outcome = read_call.execute(); - - // Track the read for freshness checking on subsequent edits - if !outcome.is_error() { - let resolved = read_call.resolved_path(); - if resolved.is_file() - && let Ok(content) = std::fs::read(&resolved) - && let Ok(mtime) = - std::fs::metadata(&resolved).and_then(|m| m.modified()) - { - io.file_tracker.record_read(resolved, &content, mtime); - } - } - - let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone { - tool_id, - outcome, - preview: None, - })); - } - ClientToolCall::AtuinHistory(tool) => { - // History search needs async DB access - let tool = tool.clone(); - tokio::spawn(async move { - let outcome = tool.execute(&db).await; - let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone { - tool_id, - outcome, - preview: None, - })); - }); - } - ClientToolCall::AtuinOutput(tool) => { - let tool = tool.clone(); - tokio::spawn(async move { - let outcome = tool.execute().await; - let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone { - tool_id, - outcome, - preview: None, - })); - }); - } - ClientToolCall::LoadSkill(skill_call) => { - let skill_name = skill_call.name.clone(); - let registry = io.skill_registry.clone(); - let shell = io - .client_ctx - .shell - .clone() - .unwrap_or_else(|| "sh".to_string()); - - tokio::spawn(async move { - let content = - load_skill_content(®istry, &skill_name, &shell, None).await; - let outcome = crate::tools::ToolOutcome::Success(content); - let _ = tx.send(DriverEvent::Fsm(Event::ToolExecutionDone { - tool_id, - outcome, - preview: None, - })); - }); - } - } - } - - Effect::LoadSkill { name, arguments } => { - let name = name.clone(); - let arguments = arguments.clone(); - let registry = io.skill_registry.clone(); - let shell = io - .client_ctx - .shell - .clone() - .unwrap_or_else(|| "sh".to_string()); - let tx = tx.clone(); - tokio::spawn(async move { - let content = - load_skill_content(®istry, &name, &shell, arguments.as_deref()).await; - let _ = tx.send(DriverEvent::Fsm(Event::SkillLoaded { - name, - arguments, - content, - })); - }); - } - - Effect::AbortTool { tool_id } => { - if let Some(abort_tx) = tool_abort_txs.remove(tool_id) { - let _ = abort_tx.send(()); - } - } - - Effect::Persist => { - // Handled inline in the driver loop (before this function is called). - } - - Effect::WritePermissionRule { - target, - rule, - disposition, - } => { - let file_path = match target { - PermissionTarget::Project => { - let project_root = io - .app_ctx - .git_root - .clone() - .or_else(|| std::env::current_dir().ok()) - .unwrap_or_else(|| PathBuf::from(".")); - writer::project_permissions_path(&project_root) - } - PermissionTarget::Global => writer::global_permissions_path(), - }; - let rule = rule.clone(); - let disposition = disposition.clone(); - tokio::spawn(async move { - if let Err(e) = writer::write_rule(&file_path, &rule, disposition).await { - tracing::error!("Failed to write permission rule: {e}"); - } - }); - } - - Effect::CacheSessionGrant { path } => { - io.edit_permissions.grant(path.clone()); - } - - Effect::ArchiveSession => { - let rt = tokio::runtime::Handle::current(); - if let Err(e) = rt.block_on(io.session_mgr.archive_and_reset()) { - tracing::warn!("Failed to archive session: {e}"); - } - } - - Effect::ScheduleTimeout { - timeout_id, - duration, - kind, - } => { - let timeout_id = *timeout_id; - let duration = *duration; - let kind = kind.clone(); - let tx = tx.clone(); - tokio::spawn(async move { - tokio::time::sleep(duration).await; - use crate::fsm::effects::TimeoutKind; - let event = match kind { - TimeoutKind::Confirmation => Event::ConfirmationTimeout { timeout_id }, - TimeoutKind::ToolExecution { tool_id } => Event::ToolExecutionTimeout { - timeout_id, - tool_id, - }, - }; - let _ = tx.send(DriverEvent::Fsm(event)); - }); - } - - Effect::ExitApp(action) => { - let action = action.clone(); - handle.update(move |vs| { - vs.exit_action = Some(action); - }); - exiting.store(true, Ordering::Release); - let h2 = handle.clone(); - h2.exit(); - } - } -} - -// ============================================================================ -// Persistence -// ============================================================================ - -fn persist(fsm: &AgentFsm, io: &mut IoContext) { - let start = std::time::Instant::now(); - let rt = tokio::runtime::Handle::current(); - - if let Err(e) = rt.block_on(io.session_mgr.persist_events(&fsm.ctx.events)) { - tracing::warn!("Failed to persist session events: {e}"); - } - if let Some(ref sid) = fsm.ctx.session_id - && let Err(e) = rt.block_on(io.session_mgr.persist_server_session_id(sid)) - { - tracing::warn!("Failed to persist server session ID: {e}"); - } - if let Ok(json) = io.file_tracker.to_json() - && let Err(e) = rt.block_on( - io.session_mgr - .set_metadata(crate::file_tracker::METADATA_KEY, &json), - ) - { - tracing::warn!("Failed to persist file tracker: {e}"); - } - if let Ok(json) = io.edit_permissions.to_json() - && let Err(e) = rt.block_on( - io.session_mgr - .set_metadata(crate::edit_permissions::METADATA_KEY, &json), - ) - { - tracing::warn!("Failed to persist edit permissions: {e}"); - } - tracing::trace!(elapsed_ms = start.elapsed().as_millis(), "persist complete"); -} - -// ============================================================================ -// Skill loading -// ============================================================================ - -async fn load_skill_content( - registry: &crate::skills::SkillRegistry, - name: &str, - shell: &str, - arguments: Option<&str>, -) -> String { - match registry.load(name, shell, arguments).await { - Ok(body) => body, - Err(e) => format!("Failed to load skill '{name}': {e}"), - } -} - -// ============================================================================ -// Stream bridge -// ============================================================================ - -async fn run_stream_bridge( - request: ChatRequest, - app_ctx: AppContext, - client_ctx: ClientContext, - tx: mpsc::Sender<DriverEvent>, - mut cancel_rx: tokio::sync::watch::Receiver<()>, - skill_summaries: Vec<crate::skills::SkillSummary>, - skill_overflow: Option<String>, -) { - use crate::stream::{StreamContent, StreamControl, StreamFrame, create_chat_stream}; - use futures::StreamExt; - - // Gather user context files (TERMINAL.md) and interpolate commands. - let shell = client_ctx.shell.as_deref().unwrap_or("sh"); - let start_dir = std::env::current_dir().unwrap_or_default(); - let global_ctx_path = crate::user_context::global_context_path(); - let user_contexts = - crate::user_context::gather(&start_dir, Some(&global_ctx_path), shell).await; - - let stream = create_chat_stream( - app_ctx.endpoint.clone(), - app_ctx.token.clone(), - request, - client_ctx, - app_ctx.send_cwd, - app_ctx.last_command.clone(), - user_contexts, - skill_summaries, - skill_overflow, - ); - futures::pin_mut!(stream); - - let _ = tx.send(DriverEvent::Fsm(Event::StreamStarted)); - - loop { - // Select between the next stream frame and cancellation. - // When the driver drops the cancel sender, changed() returns Err - // and we break — dropping the HTTP stream and cancelling the request. - let frame = tokio::select! { - biased; - _ = cancel_rx.changed() => break, - frame = stream.next() => match frame { - Some(frame) => frame, - None => break, - }, - }; - - let event = match frame { - Ok(StreamFrame::Content(content)) => match content { - StreamContent::TextChunk(text) => Some(Event::StreamChunk(text)), - StreamContent::ToolCall { id, name, input } => { - if name == "suggest_command" { - Some(Event::SuggestCommand { id, input }) - } else { - Some(Event::StreamToolCall { id, name, input }) - } - } - StreamContent::ToolResult { - tool_use_id, - content, - is_error, - remote, - content_length, - } => Some(Event::StreamServerToolResult { - tool_use_id, - content, - is_error, - remote, - content_length, - }), - }, - Ok(StreamFrame::Control(control)) => match control { - StreamControl::StatusChanged(status) => Some(Event::StreamStatusChanged(status)), - StreamControl::Done { session_id } => Some(Event::StreamDone { session_id }), - StreamControl::Error(msg) => Some(Event::StreamError(msg)), - }, - Err(e) => Some(Event::StreamError(e.to_string())), - }; - - if let Some(event) = event { - // StreamDone and StreamError are terminal — the server won't send more. - // SuggestCommand is NOT terminal: the server sends StreamDone after it - // with the session_id we need to capture. - let is_terminal = matches!(event, Event::StreamDone { .. } | Event::StreamError(_)); - if tx.send(DriverEvent::Fsm(event)).is_err() { - break; - } - if is_terminal { - break; - } - } - } -} diff --git a/crates/atuin-ai/src/edit_permissions.rs b/crates/atuin-ai/src/edit_permissions.rs deleted file mode 100644 index 5015a007..00000000 --- a/crates/atuin-ai/src/edit_permissions.rs +++ /dev/null @@ -1,108 +0,0 @@ -//! Session-scoped permission cache for file edits. -//! -//! When the user selects "Allow this file for this session", the grant is -//! recorded here with a timestamp. Subsequent edits to the same file skip -//! the permission prompt as long as the grant hasn't expired. -//! -//! Grants are time-limited (1 hour TTL) so they don't outlive the user's -//! attention in long-running sessions. Persisted as JSON in session -//! metadata so they survive across CLI invocations. - -use std::collections::HashMap; -use std::path::{Path, PathBuf}; -use std::time::SystemTime; - -use eyre::Result; -use serde::{Deserialize, Serialize}; - -/// Session metadata key for persistence. -pub(crate) const METADATA_KEY: &str = "edit_permissions"; - -/// How long a session-scoped edit permission remains valid. -const TTL_MS: i64 = 60 * 60 * 1000; // 1 hour - -/// Cache of per-file edit permission grants within a session. -#[derive(Debug, Default, Clone, Serialize, Deserialize)] -pub(crate) struct EditPermissionCache { - /// Maps canonical file paths to the grant timestamp (unix millis). - grants: HashMap<PathBuf, i64>, -} - -impl EditPermissionCache { - /// Record a permission grant for a file. - pub fn grant(&mut self, path: PathBuf) { - self.grants.insert(path, now_ms()); - } - - /// Check whether there's a valid (non-expired) grant for a file. - pub fn has_valid_grant(&self, path: &Path) -> bool { - if let Some(&granted_at) = self.grants.get(path) { - (now_ms() - granted_at) < TTL_MS - } else { - false - } - } - - pub fn to_json(&self) -> Result<String> { - Ok(serde_json::to_string(self)?) - } - - pub fn from_json(json: &str) -> Result<Self> { - Ok(serde_json::from_str(json)?) - } -} - -fn now_ms() -> i64 { - SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .map(|d| d.as_millis() as i64) - .unwrap_or(0) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn grant_and_check() { - let mut cache = EditPermissionCache::default(); - let path = PathBuf::from("/Users/me/.config/foo.toml"); - - assert!(!cache.has_valid_grant(&path)); - cache.grant(path.clone()); - assert!(cache.has_valid_grant(&path)); - } - - #[test] - fn different_paths_are_independent() { - let mut cache = EditPermissionCache::default(); - let a = PathBuf::from("/etc/hosts"); - let b = PathBuf::from("/etc/resolv.conf"); - - cache.grant(a.clone()); - assert!(cache.has_valid_grant(&a)); - assert!(!cache.has_valid_grant(&b)); - } - - #[test] - fn roundtrip_json() { - let mut cache = EditPermissionCache::default(); - cache.grant(PathBuf::from("/some/file.toml")); - - let json = cache.to_json().unwrap(); - let restored = EditPermissionCache::from_json(&json).unwrap(); - assert!(restored.has_valid_grant(Path::new("/some/file.toml"))); - } - - #[test] - fn expired_grant_is_invalid() { - let mut cache = EditPermissionCache::default(); - let path = PathBuf::from("/expired/file.toml"); - - // Insert a grant from 2 hours ago - let two_hours_ago = now_ms() - (2 * 60 * 60 * 1000); - cache.grants.insert(path.clone(), two_hours_ago); - - assert!(!cache.has_valid_grant(&path)); - } -} diff --git a/crates/atuin-ai/src/event_serde.rs b/crates/atuin-ai/src/event_serde.rs deleted file mode 100644 index e3f9d6f7..00000000 --- a/crates/atuin-ai/src/event_serde.rs +++ /dev/null @@ -1,397 +0,0 @@ -//! Manual serialization for ConversationEvent to/from storage format. -//! -//! The storage format is decoupled from the Rust enum so the two can evolve -//! independently. Each event is stored as an `(event_type, event_data)` pair -//! where `event_data` is a JSON string. - -use eyre::{Result, eyre}; -use serde_json::Value; - -use crate::tui::ConversationEvent; - -/// Serialize a ConversationEvent into an (event_type, event_data_json) pair -/// suitable for database storage. -pub(crate) fn serialize_event(event: &ConversationEvent) -> (String, String) { - match event { - ConversationEvent::UserMessage { content } => ( - "user_message".to_string(), - serde_json::json!({ "content": content }).to_string(), - ), - ConversationEvent::Text { content } => ( - "text".to_string(), - serde_json::json!({ "content": content }).to_string(), - ), - ConversationEvent::ToolCall { id, name, input } => ( - "tool_call".to_string(), - serde_json::json!({ - "id": id, - "name": name, - "input": input, - }) - .to_string(), - ), - ConversationEvent::ToolResult { - tool_use_id, - content, - is_error, - remote, - content_length, - } => ( - "tool_result".to_string(), - serde_json::json!({ - "tool_use_id": tool_use_id, - "content": content, - "is_error": is_error, - "remote": remote, - "content_length": content_length, - }) - .to_string(), - ), - ConversationEvent::OutOfBandOutput { - name, - command, - content, - } => ( - "out_of_band_output".to_string(), - serde_json::json!({ - "name": name, - "command": command, - "content": content, - }) - .to_string(), - ), - ConversationEvent::SystemContext { content } => ( - "system_context".to_string(), - serde_json::json!({ "content": content }).to_string(), - ), - ConversationEvent::SkillInvocation { - name, - arguments, - content, - } => ( - "skill_invocation".to_string(), - serde_json::json!({ - "name": name, - "arguments": arguments, - "content": content, - }) - .to_string(), - ), - } -} - -/// Deserialize an (event_type, event_data_json) pair from storage back into a -/// ConversationEvent. -pub(crate) fn deserialize_event(event_type: &str, event_data: &str) -> Result<ConversationEvent> { - let data: Value = serde_json::from_str(event_data) - .map_err(|e| eyre!("failed to parse event_data JSON: {e}"))?; - - match event_type { - "user_message" => Ok(ConversationEvent::UserMessage { - content: json_string(&data, "content")?, - }), - "text" => Ok(ConversationEvent::Text { - content: json_string(&data, "content")?, - }), - "tool_call" => Ok(ConversationEvent::ToolCall { - id: json_string(&data, "id")?, - name: json_string(&data, "name")?, - input: data - .get("input") - .cloned() - .ok_or_else(|| eyre!("tool_call missing 'input' field"))?, - }), - "tool_result" => Ok(ConversationEvent::ToolResult { - tool_use_id: json_string(&data, "tool_use_id")?, - content: json_string(&data, "content")?, - is_error: data - .get("is_error") - .and_then(Value::as_bool) - .ok_or_else(|| eyre!("tool_result missing 'is_error' field"))?, - remote: data.get("remote").and_then(Value::as_bool).unwrap_or(false), - content_length: data - .get("content_length") - .and_then(Value::as_u64) - .map(|v| v as usize), - }), - "out_of_band_output" => Ok(ConversationEvent::OutOfBandOutput { - name: json_string(&data, "name")?, - command: data - .get("command") - .and_then(|v| if v.is_null() { None } else { v.as_str() }) - .map(String::from), - content: json_string(&data, "content")?, - }), - "system_context" => Ok(ConversationEvent::SystemContext { - content: json_string(&data, "content")?, - }), - "skill_invocation" => Ok(ConversationEvent::SkillInvocation { - name: json_string(&data, "name")?, - arguments: data - .get("arguments") - .and_then(|v| if v.is_null() { None } else { v.as_str() }) - .map(String::from), - content: json_string(&data, "content")?, - }), - other => Err(eyre!("unknown event type: {other}")), - } -} - -fn json_string(data: &Value, field: &str) -> Result<String> { - data.get(field) - .and_then(Value::as_str) - .map(String::from) - .ok_or_else(|| eyre!("missing or non-string field '{field}'")) -} - -#[cfg(test)] -mod tests { - use super::*; - - fn round_trip(event: &ConversationEvent) -> ConversationEvent { - let (event_type, event_data) = serialize_event(event); - deserialize_event(&event_type, &event_data).unwrap() - } - - #[test] - fn test_user_message() { - let event = ConversationEvent::UserMessage { - content: "hello world".to_string(), - }; - let result = round_trip(&event); - assert!( - matches!(result, ConversationEvent::UserMessage { content } if content == "hello world") - ); - } - - #[test] - fn test_text() { - let event = ConversationEvent::Text { - content: "response text".to_string(), - }; - let result = round_trip(&event); - assert!( - matches!(result, ConversationEvent::Text { content } if content == "response text") - ); - } - - #[test] - fn test_tool_call() { - let input = serde_json::json!({"command": "ls -la", "danger": "low"}); - let event = ConversationEvent::ToolCall { - id: "tc_123".to_string(), - name: "suggest_command".to_string(), - input: input.clone(), - }; - let result = round_trip(&event); - match result { - ConversationEvent::ToolCall { - id, - name, - input: result_input, - } => { - assert_eq!(id, "tc_123"); - assert_eq!(name, "suggest_command"); - assert_eq!(result_input, input); - } - _ => panic!("expected ToolCall"), - } - } - - #[test] - fn test_tool_result() { - let event = ConversationEvent::ToolResult { - tool_use_id: "tc_123".to_string(), - content: "file contents here".to_string(), - is_error: false, - remote: false, - content_length: None, - }; - let result = round_trip(&event); - match result { - ConversationEvent::ToolResult { - tool_use_id, - content, - is_error, - remote, - content_length, - } => { - assert_eq!(tool_use_id, "tc_123"); - assert_eq!(content, "file contents here"); - assert!(!is_error); - assert!(!remote); - assert!(content_length.is_none()); - } - _ => panic!("expected ToolResult"), - } - } - - #[test] - fn test_tool_result_error() { - let event = ConversationEvent::ToolResult { - tool_use_id: "tc_456".to_string(), - content: "permission denied".to_string(), - is_error: true, - remote: false, - content_length: None, - }; - let result = round_trip(&event); - match result { - ConversationEvent::ToolResult { is_error, .. } => assert!(is_error), - _ => panic!("expected ToolResult"), - } - } - - #[test] - fn test_tool_result_remote() { - let event = ConversationEvent::ToolResult { - tool_use_id: "tc_789".to_string(), - content: "ref:abc123".to_string(), - is_error: false, - remote: true, - content_length: Some(4096), - }; - let result = round_trip(&event); - match result { - ConversationEvent::ToolResult { - remote, - content_length, - .. - } => { - assert!(remote); - assert_eq!(content_length, Some(4096)); - } - _ => panic!("expected ToolResult"), - } - } - - #[test] - fn test_tool_result_backwards_compat() { - // Old stored data without remote/content_length fields should deserialize - // with defaults (remote=false, content_length=None) - let event = deserialize_event( - "tool_result", - r#"{"tool_use_id":"tc_old","content":"old result","is_error":false}"#, - ) - .unwrap(); - match event { - ConversationEvent::ToolResult { - remote, - content_length, - .. - } => { - assert!(!remote); - assert!(content_length.is_none()); - } - _ => panic!("expected ToolResult"), - } - } - - #[test] - fn test_out_of_band_with_command() { - let event = ConversationEvent::OutOfBandOutput { - name: "System".to_string(), - command: Some("/help".to_string()), - content: "help text".to_string(), - }; - let result = round_trip(&event); - match result { - ConversationEvent::OutOfBandOutput { - name, - command, - content, - } => { - assert_eq!(name, "System"); - assert_eq!(command.as_deref(), Some("/help")); - assert_eq!(content, "help text"); - } - _ => panic!("expected OutOfBandOutput"), - } - } - - #[test] - fn test_out_of_band_without_command() { - let event = ConversationEvent::OutOfBandOutput { - name: "System".to_string(), - command: None, - content: "some output".to_string(), - }; - let result = round_trip(&event); - match result { - ConversationEvent::OutOfBandOutput { command, .. } => { - assert!(command.is_none()); - } - _ => panic!("expected OutOfBandOutput"), - } - } - - #[test] - fn test_unknown_event_type() { - let result = deserialize_event("banana", "{}"); - assert!(result.is_err()); - assert!( - result - .unwrap_err() - .to_string() - .contains("unknown event type") - ); - } - - #[test] - fn test_invalid_json() { - let result = deserialize_event("text", "not json"); - assert!(result.is_err()); - } - - #[test] - fn test_missing_field() { - let result = deserialize_event("text", "{}"); - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("content")); - } - - #[test] - fn test_text_with_special_characters() { - let event = ConversationEvent::Text { - content: "line1\nline2\ttab \"quotes\" \\backslash 🎉".to_string(), - }; - let result = round_trip(&event); - assert!( - matches!(result, ConversationEvent::Text { content } if content == "line1\nline2\ttab \"quotes\" \\backslash 🎉") - ); - } - - #[test] - fn test_tool_call_with_nested_input() { - let input = serde_json::json!({ - "command": "echo 'hello'", - "nested": { "a": [1, 2, 3], "b": null } - }); - let event = ConversationEvent::ToolCall { - id: "tc_1".to_string(), - name: "execute_shell_command".to_string(), - input: input.clone(), - }; - let result = round_trip(&event); - match result { - ConversationEvent::ToolCall { - input: result_input, - .. - } => { - assert_eq!(result_input, input); - } - _ => panic!("expected ToolCall"), - } - } - - #[test] - fn test_system_context() { - let event = ConversationEvent::SystemContext { - content: "[system: new invocation started]".to_string(), - }; - let result = round_trip(&event); - assert!( - matches!(result, ConversationEvent::SystemContext { content } if content == "[system: new invocation started]") - ); - } -} diff --git a/crates/atuin-ai/src/file_tracker.rs b/crates/atuin-ai/src/file_tracker.rs deleted file mode 100644 index feee1ee8..00000000 --- a/crates/atuin-ai/src/file_tracker.rs +++ /dev/null @@ -1,234 +0,0 @@ -//! Tracks which files have been read in the current session, for freshness -//! checking before edits. -//! -//! The tracker records the content hash and mtime of each file at the time -//! it was last read. Before an edit, the tracker verifies the file hasn't -//! changed since the last read — catching both external modifications and -//! concurrent tool calls. -//! -//! Persisted as JSON in session metadata so it survives across CLI -//! invocations within the same logical session. - -use std::collections::HashMap; -use std::path::{Path, PathBuf}; -use std::time::SystemTime; - -use eyre::Result; -use serde::{Deserialize, Serialize}; - -/// Metadata key used for session_metadata persistence. -pub(crate) const METADATA_KEY: &str = "file_read_tracker"; - -/// State recorded for a single file read. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) struct FileReadState { - /// Hash of the file contents at the time of the last read. - pub content_hash: u64, - /// File mtime (as milliseconds since epoch) at the time of the last read. - /// Millisecond precision ensures sub-second modifications are detected. - pub mtime_ms: i64, -} - -/// Tracks file read state for freshness checking. -#[derive(Debug, Default, Clone, Serialize, Deserialize)] -pub(crate) struct FileReadTracker { - reads: HashMap<PathBuf, FileReadState>, -} - -/// Result of a freshness check. -pub(crate) enum FreshnessCheck { - /// File is fresh — the content hasn't changed since the last read. - Fresh, - /// File has never been read in this session. - NotRead, - /// File has been modified since the last read. - Stale, -} - -impl FileReadTracker { - /// Record that a file was read. Call this after a successful `read_file` - /// execution. The `path` should be canonical (absolute, tilde-expanded). - pub fn record_read(&mut self, path: PathBuf, content: &[u8], mtime: SystemTime) { - let content_hash = hash_content(content); - let mtime_ms = system_time_to_ms(mtime); - - self.reads.insert( - path, - FileReadState { - content_hash, - mtime_ms, - }, - ); - } - - /// Check whether a file is fresh (unchanged since last read). - /// - /// Uses mtime as a fast path — only re-hashes if mtime differs. - pub fn check_freshness(&self, path: &Path) -> Result<FreshnessCheck> { - let state = match self.reads.get(path) { - Some(s) => s, - None => return Ok(FreshnessCheck::NotRead), - }; - - // Stat the file - let metadata = match std::fs::metadata(path) { - Ok(m) => m, - Err(_) => return Ok(FreshnessCheck::Stale), // file deleted or inaccessible - }; - - let current_mtime_ms = - system_time_to_ms(metadata.modified().unwrap_or(SystemTime::UNIX_EPOCH)); - - // Fast path: mtime unchanged → fresh - if current_mtime_ms == state.mtime_ms { - return Ok(FreshnessCheck::Fresh); - } - - // Mtime changed — re-hash to confirm - let content = std::fs::read(path)?; - let current_hash = hash_content(&content); - - if current_hash == state.content_hash { - Ok(FreshnessCheck::Fresh) - } else { - Ok(FreshnessCheck::Stale) - } - } - - /// Update the tracker entry after a successful edit (new content written). - pub fn update_after_edit(&mut self, path: &Path, new_content: &[u8], new_mtime: SystemTime) { - let content_hash = hash_content(new_content); - let mtime_ms = system_time_to_ms(new_mtime); - - self.reads.insert( - path.to_path_buf(), - FileReadState { - content_hash, - mtime_ms, - }, - ); - } - - /// Serialize to JSON for session metadata persistence. - pub fn to_json(&self) -> Result<String> { - Ok(serde_json::to_string(self)?) - } - - /// Deserialize from JSON session metadata. - pub fn from_json(json: &str) -> Result<Self> { - Ok(serde_json::from_str(json)?) - } -} - -fn system_time_to_ms(t: SystemTime) -> i64 { - t.duration_since(SystemTime::UNIX_EPOCH) - .map(|d| d.as_millis() as i64) - .unwrap_or(0) -} - -fn hash_content(content: &[u8]) -> u64 { - xxhash_rust::xxh3::xxh3_64(content) -} - -#[cfg(test)] -mod tests { - use super::*; - use std::io::Write; - use tempfile::NamedTempFile; - - #[test] - fn record_and_check_fresh() { - let mut tracker = FileReadTracker::default(); - let mut tmp = NamedTempFile::new().unwrap(); - write!(tmp, "hello world").unwrap(); - - let path = tmp.path().to_path_buf(); - let content = std::fs::read(&path).unwrap(); - let mtime = std::fs::metadata(&path).unwrap().modified().unwrap(); - - tracker.record_read(path.clone(), &content, mtime); - - assert!(matches!( - tracker.check_freshness(&path).unwrap(), - FreshnessCheck::Fresh - )); - } - - #[test] - fn check_not_read() { - let tracker = FileReadTracker::default(); - let path = PathBuf::from("/nonexistent/file.txt"); - assert!(matches!( - tracker.check_freshness(&path).unwrap(), - FreshnessCheck::NotRead - )); - } - - #[test] - fn check_stale_after_modification() { - let mut tracker = FileReadTracker::default(); - let mut tmp = NamedTempFile::new().unwrap(); - write!(tmp, "original").unwrap(); - - let path = tmp.path().to_path_buf(); - let content = std::fs::read(&path).unwrap(); - let mtime = std::fs::metadata(&path).unwrap().modified().unwrap(); - - tracker.record_read(path.clone(), &content, mtime); - - // Small delay to ensure the filesystem mtime advances - std::thread::sleep(std::time::Duration::from_millis(10)); - - // Modify the file - std::fs::write(&path, "modified").unwrap(); - - assert!(matches!( - tracker.check_freshness(&path).unwrap(), - FreshnessCheck::Stale - )); - } - - #[test] - fn update_after_edit_makes_fresh() { - let mut tracker = FileReadTracker::default(); - let mut tmp = NamedTempFile::new().unwrap(); - write!(tmp, "original").unwrap(); - - let path = tmp.path().to_path_buf(); - let content = std::fs::read(&path).unwrap(); - let mtime = std::fs::metadata(&path).unwrap().modified().unwrap(); - - tracker.record_read(path.clone(), &content, mtime); - - // Simulate an edit - let new_content = b"edited content"; - std::fs::write(&path, new_content).unwrap(); - let new_mtime = std::fs::metadata(&path).unwrap().modified().unwrap(); - tracker.update_after_edit(&path, new_content, new_mtime); - - assert!(matches!( - tracker.check_freshness(&path).unwrap(), - FreshnessCheck::Fresh - )); - } - - #[test] - fn roundtrip_json() { - let mut tracker = FileReadTracker::default(); - tracker.reads.insert( - PathBuf::from("/some/file.toml"), - FileReadState { - content_hash: 12345, - mtime_ms: 1700000000000, - }, - ); - - let json = tracker.to_json().unwrap(); - let restored = FileReadTracker::from_json(&json).unwrap(); - assert_eq!(restored.reads.len(), 1); - assert_eq!( - restored.reads[&PathBuf::from("/some/file.toml")].content_hash, - 12345 - ); - } -} diff --git a/crates/atuin-ai/src/fsm/effects.rs b/crates/atuin-ai/src/fsm/effects.rs deleted file mode 100644 index adc9628e..00000000 --- a/crates/atuin-ai/src/fsm/effects.rs +++ /dev/null @@ -1,99 +0,0 @@ -//! Effects (outputs) from the agent FSM. -//! -//! The FSM returns these as data; the driver is responsible for executing them. - -use std::path::PathBuf; -use std::time::Duration; - -use serde_json::Value; - -use crate::permissions::rule::Rule; -use crate::permissions::writer::RuleDisposition; -use crate::tools::ClientToolCall; - -/// Where to write a permission rule. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum PermissionTarget { - /// Project-level: `<git_root_or_cwd>/.atuin/permissions.ai.toml` - Project, - /// Global: `~/.config/atuin/permissions.ai.toml` - Global, -} - -/// Side effects the driver should execute after a state transition. -#[derive(Debug, Clone)] -pub(crate) enum Effect { - // ─── Network ──────────────────────────────────────────────── - /// Start a new streaming request to the server. - StartStream { - messages: Vec<Value>, - session_id: Option<String>, - }, - /// Abort the active stream connection. - AbortStream, - - // ─── Tool orchestration ───────────────────────────────────── - /// Run the permission resolver for a tool call. - CheckPermission { - tool_id: String, - tool: ClientToolCall, - }, - /// Execute a tool (file read, edit, write, shell, history search). - ExecuteTool { - tool_id: String, - tool: ClientToolCall, - }, - /// Kill a running tool (send interrupt to shell command). - AbortTool { tool_id: String }, - /// Load a skill's content asynchronously (read + interpolate). - LoadSkill { - name: String, - arguments: Option<String>, - }, - - // ─── Persistence ──────────────────────────────────────────── - /// Persist current conversation state to disk. - Persist, - /// Write a permanent permission rule to disk. - WritePermissionRule { - target: PermissionTarget, - rule: Rule, - disposition: RuleDisposition, - }, - /// Cache a session-scoped file permission grant. - CacheSessionGrant { path: PathBuf }, - /// Archive current session and start fresh (IO only — state already updated by FSM). - ArchiveSession, - - // ─── Timers ───────────────────────────────────────────────── - /// Schedule a timer that fires an event after the given delay. - ScheduleTimeout { - timeout_id: u64, - duration: Duration, - kind: TimeoutKind, - }, - - // ─── Exit ─────────────────────────────────────────────────── - /// Exit the application with the given action. - ExitApp(ExitAction), -} - -/// What kind of timeout was scheduled. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum TimeoutKind { - /// Dangerous command confirmation dialog auto-dismiss. - Confirmation, - /// Shell tool execution timeout — abort the tool if it's still running. - ToolExecution { tool_id: String }, -} - -/// What to do when exiting the TUI. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum ExitAction { - /// Run the suggested command. - Execute(String), - /// Insert the command into the shell without running. - Insert(String), - /// Exit without action. - Cancel, -} diff --git a/crates/atuin-ai/src/fsm/events.rs b/crates/atuin-ai/src/fsm/events.rs deleted file mode 100644 index e591db41..00000000 --- a/crates/atuin-ai/src/fsm/events.rs +++ /dev/null @@ -1,140 +0,0 @@ -//! Events (inputs) to the agent FSM. - -use serde_json::Value; - -use crate::tools::ToolOutcome; - -/// Events that drive state transitions in the agent FSM. -#[derive(Debug, Clone)] -pub(crate) enum Event { - // ─── User actions ─────────────────────────────────────────── - /// User submitted a message from the input box. - UserSubmit(String), - /// User pressed Esc or equivalent cancel action. - Cancel, - /// User pressed Enter to execute the suggested command. - ExecuteCommand, - /// User pressed Tab to insert the suggested command. - InsertCommand, - /// User chose to retry after an error. - Retry, - /// User interrupted executing tools (Ctrl+C / Esc during shell execution). - InterruptTools, - - // ─── Stream lifecycle ─────────────────────────────────────── - /// Stream connection established, first frame received. - StreamStarted, - /// Received a chunk of streamed text content. - StreamChunk(String), - /// Stream delivered a client-side tool call. - StreamToolCall { - id: String, - name: String, - input: Value, - }, - /// Stream delivered a server-side tool result (executed remotely). - StreamServerToolResult { - tool_use_id: String, - content: String, - is_error: bool, - remote: bool, - content_length: Option<usize>, - }, - /// Stream status changed (e.g. "thinking", "searching"). - StreamStatusChanged(String), - /// Stream ended normally. - StreamDone { session_id: String }, - /// Stream encountered an error. - StreamError(String), - - // ─── Suggest command (terminal tool call) ─────────────────── - /// The suggest_command tool call acts as a stream terminal event. - /// This is the server signaling "turn complete, here's the command." - SuggestCommand { id: String, input: Value }, - - // ─── Tool lifecycle ───────────────────────────────────────── - /// Permission resolver completed for a tool. - PermissionResolved { - tool_id: String, - response: PermissionResponse, - }, - /// User made a permission choice via the dialog. - PermissionUserChoice { - tool_id: String, - choice: PermissionChoice, - }, - /// Tool execution completed. - ToolExecutionDone { - tool_id: String, - outcome: ToolOutcome, - /// Preview data computed by the driver (diff, content preview, final shell state). - preview: Option<super::tools::ToolPreviewData>, - }, - /// Live preview update for an executing shell command. - ToolPreviewUpdate { - tool_id: String, - lines: Vec<String>, - exit_code: Option<i32>, - }, - - // ─── Timers ───────────────────────────────────────────────── - /// Confirmation timeout expired. - ConfirmationTimeout { timeout_id: u64 }, - /// Shell tool execution timeout expired. - ToolExecutionTimeout { timeout_id: u64, tool_id: String }, - - // ─── Session management ───────────────────────────────────── - /// User ran /new to start a fresh session. - NewSession, - - // ─── Slash commands ───────────────────────────────────────── - /// User submitted a slash command (other than /new). - /// The driver resolves known commands (like /help) and passes the - /// rendered content; the FSM just pushes an OOB event. - SlashCommand { command: String, content: String }, - - // ─── Skills ──────────────────────────────────────────────── - /// User invoked a skill via /skill-name. FSM emits a LoadSkill - /// effect; the driver loads the content asynchronously and sends - /// SkillLoaded when ready. - RequestSkillLoad { - name: String, - arguments: Option<String>, - }, - /// A skill's content has been loaded and interpolated. - /// Pushes skill content as OOB context and starts a turn so the - /// LLM sees the skill and acts on it. - SkillLoaded { - name: String, - arguments: Option<String>, - content: String, - }, -} - -/// Result of the permission resolver check. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum PermissionResponse { - /// Rule allows this tool call — execute immediately. - Allowed, - /// Rule denies this tool call — reject with error. - Denied, - /// No matching rule — ask the user. - Ask, - /// Session-scoped grant exists — execute immediately (bypass resolver). - SessionGranted, -} - -/// User's choice from the permission dialog. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum PermissionChoice { - /// Allow this one time. - Allow, - /// Allow this file for the remainder of the session. - AllowForSession, - /// Always allow in this project (writes to project permissions file). - AlwaysAllowInProject, - /// Always allow globally (writes to global permissions file, scoped to file). - AlwaysAllow, - /// Deny this tool call. - Deny, -} diff --git a/crates/atuin-ai/src/fsm/mod.rs b/crates/atuin-ai/src/fsm/mod.rs deleted file mode 100644 index 3d72a3ae..00000000 --- a/crates/atuin-ai/src/fsm/mod.rs +++ /dev/null @@ -1,1103 +0,0 @@ -//! Agent conversation FSM. -//! -//! Pure state machine that returns effects as data. -//! The driver is responsible for executing effects and feeding events back. -//! -//! The FSM owns the conversation event log and tool lifecycle state. -//! It never performs IO directly. - -pub(crate) mod effects; -pub(crate) mod events; -pub(crate) mod tools; - -#[cfg(test)] -mod tests; - -use std::collections::HashMap; - -use serde_json::Value; - -use crate::context_window::ContextWindowBuilder; -use crate::tui::state::ConversationEvent; - -use effects::{Effect, ExitAction, PermissionTarget, TimeoutKind}; -use events::{Event, PermissionChoice, PermissionResponse}; -use tools::{ToolManager, ToolState}; - -// ============================================================================ -// State -// ============================================================================ - -/// The discrete states of the agent FSM. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum AgentState { - /// Waiting for user input. - Idle { - confirmation: Option<PendingConfirmation>, - }, - - /// A conversation turn is in progress. - Turn { stream: StreamPhase }, - - /// Unrecoverable error. User can retry or exit. - Error(String), -} - -/// Stream connection lifecycle within a Turn. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum StreamPhase { - /// Request sent, awaiting first stream frame. - Connecting, - /// Actively receiving streamed response. - Streaming { status: Option<StreamingStatus> }, - /// Stream connection has ended (Done received). - Done, -} - -/// Streaming status indicators from server. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum StreamingStatus { - Processing, - Searching, - Thinking, - WaitingForTools, -} - -impl StreamingStatus { - pub(crate) fn from_str(s: &str) -> Self { - match s { - "processing" => Self::Processing, - "searching" => Self::Searching, - "waiting_for_tools" => Self::WaitingForTools, - _ => Self::Thinking, - } - } -} - -/// Pending dangerous command confirmation state. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) struct PendingConfirmation { - pub command: String, - pub timeout_id: u64, -} - -// ============================================================================ -// Context -// ============================================================================ - -/// Shared context owned by the FSM. -#[derive(Debug, Clone)] -pub(crate) struct AgentContext { - /// The full conversation event log (source of truth for API + persistence). - pub events: Vec<ConversationEvent>, - /// Server-assigned session ID. - pub session_id: Option<String>, - /// Accumulated text from current stream (committed to events on tool call or stream end). - pub current_response: String, - /// Per-tool lifecycle state and cached render data. - /// Tools persist across turns for rendering history. - pub tools: ToolManager, - /// Tool IDs that belong to the current turn. Cleared on continuation start. - /// Used to determine whether a turn needs continuation (has unprocessed results). - current_turn_tool_ids: Vec<String>, - /// Maps timeout_id → tool_id for active tool execution timeouts. - /// Cleaned up when a tool completes naturally, so stale timeouts are ignored. - tool_timeout_ids: HashMap<u64, String>, - /// Counter for generating unique timeout IDs. - next_timeout_id: u64, - /// Capabilities advertised to the server. - pub capabilities: Vec<String>, - /// Unique invocation ID for this CLI invocation. - pub invocation_id: String, - - // ─── View state (owned by FSM for atomic transitions) ─────── - /// Index into events where the current TUI invocation starts. - /// Events before this are context for the API but not rendered. - pub view_start_index: usize, - /// Whether this session was resumed from a prior invocation. - pub is_resumed: bool, - /// Time of the last event from a previous invocation. - pub last_event_time: Option<chrono::DateTime<chrono::Utc>>, - /// Events from archived sessions (/new) still rendered on screen. - pub archived_events: Vec<ConversationEvent>, -} - -impl AgentContext { - fn next_timeout_id(&mut self) -> u64 { - let id = self.next_timeout_id; - self.next_timeout_id += 1; - id - } -} - -// ============================================================================ -// The Agent FSM -// ============================================================================ - -/// The agent finite state machine. -/// -/// Pure state machine — `handle()` takes an event, mutates internal state, -/// and returns effects as data for the driver to execute. -#[derive(Debug, Clone)] -pub(crate) struct AgentFsm { - pub state: AgentState, - pub ctx: AgentContext, -} - -impl AgentFsm { - /// Create a new FSM in Idle state. - pub fn new(capabilities: Vec<String>, invocation_id: String) -> Self { - Self { - state: AgentState::Idle { confirmation: None }, - ctx: AgentContext { - events: Vec::new(), - session_id: None, - current_response: String::new(), - tools: ToolManager::new(), - current_turn_tool_ids: Vec::new(), - tool_timeout_ids: HashMap::new(), - next_timeout_id: 0, - capabilities, - invocation_id, - view_start_index: 0, - is_resumed: false, - last_event_time: None, - archived_events: Vec::new(), - }, - } - } - - /// Create an FSM from saved session state (for resume). - pub fn from_session( - events: Vec<ConversationEvent>, - session_id: Option<String>, - capabilities: Vec<String>, - invocation_id: String, - view_start_index: usize, - is_resumed: bool, - last_event_time: Option<chrono::DateTime<chrono::Utc>>, - ) -> Self { - Self { - state: AgentState::Idle { confirmation: None }, - ctx: AgentContext { - events, - session_id, - current_response: String::new(), - tools: ToolManager::new(), - current_turn_tool_ids: Vec::new(), - tool_timeout_ids: HashMap::new(), - next_timeout_id: 0, - capabilities, - invocation_id, - view_start_index, - is_resumed, - last_event_time, - archived_events: Vec::new(), - }, - } - } - - /// Handle an event, returning effects to execute. - pub fn handle(&mut self, event: Event) -> Vec<Effect> { - match (&self.state, event) { - // ================================================================ - // Idle state - // ================================================================ - (AgentState::Idle { confirmation: None }, Event::UserSubmit(msg)) => { - self.start_turn(msg) - } - - ( - AgentState::Idle { - confirmation: Some(_), - }, - Event::UserSubmit(msg), - ) => self.start_turn(msg), - - (AgentState::Idle { confirmation: None }, Event::ExecuteCommand) => { - let cmd = self.current_command(); - let Some(cmd) = cmd else { - // No command suggested — exit - return vec![Effect::ExitApp(ExitAction::Cancel)]; - }; - if self.is_current_command_dangerous() { - let timeout_id = self.ctx.next_timeout_id(); - self.state = AgentState::Idle { - confirmation: Some(PendingConfirmation { - command: cmd, - timeout_id, - }), - }; - vec![Effect::ScheduleTimeout { - timeout_id, - duration: std::time::Duration::from_secs(5), - kind: TimeoutKind::Confirmation, - }] - } else { - vec![Effect::ExitApp(ExitAction::Execute(cmd))] - } - } - - ( - AgentState::Idle { - confirmation: Some(_), - }, - Event::ExecuteCommand, - ) => { - let confirm = self.state_confirmation().unwrap().clone(); - self.state = AgentState::Idle { confirmation: None }; - vec![Effect::ExitApp(ExitAction::Execute(confirm.command))] - } - - (AgentState::Idle { .. }, Event::InsertCommand) => { - let cmd = self.current_command(); - match cmd { - Some(cmd) => vec![Effect::ExitApp(ExitAction::Insert(cmd))], - None => vec![], - } - } - - ( - AgentState::Idle { - confirmation: Some(_), - }, - Event::Cancel, - ) => { - self.state = AgentState::Idle { confirmation: None }; - vec![] - } - - (AgentState::Idle { confirmation: None }, Event::Cancel) => { - vec![Effect::ExitApp(ExitAction::Cancel)] - } - - (AgentState::Idle { .. }, Event::ConfirmationTimeout { timeout_id }) => { - if self - .state_confirmation() - .is_some_and(|c| c.timeout_id == timeout_id) - { - self.state = AgentState::Idle { confirmation: None }; - } - vec![] - } - - (AgentState::Idle { .. }, Event::NewSession) => { - // Archive visible events so they remain on screen but aren't - // sent to the API. Tools persist for rendering. - let visible = self.ctx.events[self.ctx.view_start_index..].to_vec(); - self.ctx.archived_events.extend(visible); - - self.ctx.events.clear(); - self.ctx.session_id = None; - self.ctx.current_turn_tool_ids.clear(); - self.ctx.view_start_index = 0; - self.ctx.is_resumed = false; - - // Add OOB indicator for the new session - self.ctx.events.push(ConversationEvent::OutOfBandOutput { - name: "System".to_string(), - command: Some("/new".to_string()), - content: "Started a new session.".to_string(), - }); - - self.state = AgentState::Idle { confirmation: None }; - vec![Effect::ArchiveSession, Effect::Persist] - } - - (AgentState::Idle { .. }, Event::SlashCommand { command, content }) => { - self.handle_slash_command(&command, &content); - vec![] - } - - ( - AgentState::Idle { .. }, - Event::SkillLoaded { - name, - arguments, - content, - }, - ) => { - self.ctx.events.push(ConversationEvent::SkillInvocation { - name, - arguments, - content, - }); - self.ctx.current_response.clear(); - self.ctx.current_turn_tool_ids.clear(); - - let messages = self.build_messages(); - let session_id = self.ctx.session_id.clone(); - self.state = AgentState::Turn { - stream: StreamPhase::Connecting, - }; - vec![Effect::StartStream { - messages, - session_id, - }] - } - - // ================================================================ - // Turn — stream lifecycle - // ================================================================ - ( - AgentState::Turn { - stream: StreamPhase::Connecting, - }, - Event::StreamStarted, - ) => { - self.state = AgentState::Turn { - stream: StreamPhase::Streaming { status: None }, - }; - vec![] - } - - ( - AgentState::Turn { - stream: StreamPhase::Connecting, - }, - Event::StreamError(e), - ) => { - self.state = AgentState::Error(e); - vec![] - } - - ( - AgentState::Turn { - stream: StreamPhase::Streaming { .. }, - }, - Event::StreamChunk(text), - ) => { - self.ctx.current_response.push_str(&text); - vec![] - } - - ( - AgentState::Turn { - stream: StreamPhase::Streaming { .. }, - }, - Event::StreamStatusChanged(status), - ) => { - self.state = AgentState::Turn { - stream: StreamPhase::Streaming { - status: Some(StreamingStatus::from_str(&status)), - }, - }; - vec![] - } - - (AgentState::Turn { .. }, Event::StreamToolCall { id, name, input }) => { - self.commit_streaming_text(); - self.handle_stream_tool_call(id, name, input) - } - - (AgentState::Turn { .. }, Event::SuggestCommand { id, input }) => { - self.commit_streaming_text(); - // Push the suggest_command as a ToolCall event (protocol requirement) - self.ctx.events.push(ConversationEvent::ToolCall { - id, - name: "suggest_command".to_string(), - input, - }); - self.state = AgentState::Idle { confirmation: None }; - vec![Effect::Persist] - } - - ( - AgentState::Turn { - stream: StreamPhase::Streaming { .. }, - }, - Event::StreamServerToolResult { - tool_use_id, - content, - is_error, - remote, - content_length, - }, - ) => { - self.ctx.events.push(ConversationEvent::ToolResult { - tool_use_id, - content, - is_error, - remote, - content_length, - }); - vec![] - } - - (AgentState::Turn { .. }, Event::StreamDone { session_id }) => { - self.commit_streaming_text(); - if !session_id.is_empty() { - self.ctx.session_id = Some(session_id); - } - self.state = AgentState::Turn { - stream: StreamPhase::Done, - }; - self.check_turn_completion() - } - - ( - AgentState::Turn { - stream: StreamPhase::Streaming { .. }, - }, - Event::StreamError(e), - ) => { - // Abort any executing tools on stream error - let abort_effects: Vec<_> = self - .ctx - .tools - .executing_ids() - .into_iter() - .map(|tool_id| Effect::AbortTool { tool_id }) - .collect(); - self.ctx.tool_timeout_ids.clear(); - self.state = AgentState::Error(e); - abort_effects - } - - // ================================================================ - // Turn — tool lifecycle (any stream phase) - // ================================================================ - (AgentState::Turn { .. }, Event::PermissionResolved { tool_id, response }) => { - self.handle_permission_resolved(tool_id, response) - } - - (AgentState::Turn { .. }, Event::PermissionUserChoice { tool_id, choice }) => { - self.handle_permission_choice(tool_id, choice) - } - - ( - AgentState::Turn { .. }, - Event::ToolExecutionDone { - tool_id, - outcome, - preview, - }, - ) => self.handle_tool_done(tool_id, outcome, preview), - - ( - AgentState::Turn { .. }, - Event::ToolPreviewUpdate { - tool_id, - lines, - exit_code, - }, - ) => { - if let Some(tracked) = self.ctx.tools.get_mut(&tool_id) { - if tracked.is_resolved() { - // Tool already completed — a late preview update raced with - // ToolExecutionDone. Update lines (they may carry the final - // screen) but preserve the finalized exit_code/interrupted. - if let Some(tools::ToolPreviewData::Shell { - lines: existing_lines, - .. - }) = &mut tracked.preview - { - *existing_lines = lines; - } - } else { - tracked.preview = Some(tools::ToolPreviewData::Shell { - lines, - exit_code, - interrupted: None, - }); - } - } - vec![] - } - - (AgentState::Turn { .. }, Event::InterruptTools) => { - let ids = self.ctx.tools.executing_ids(); - for id in &ids { - if let Some(tracked) = self.ctx.tools.get_mut(id) { - tracked.interrupt_reason = Some(tools::InterruptReason::User); - } - // Clear any pending execution timeout for this tool - self.ctx.tool_timeout_ids.retain(|_, tid| tid != id); - } - ids.into_iter() - .map(|tool_id| Effect::AbortTool { tool_id }) - .collect() - } - - ( - AgentState::Turn { .. }, - Event::ToolExecutionTimeout { - timeout_id, - tool_id, - }, - ) => self.handle_tool_execution_timeout(timeout_id, tool_id), - - // ─── Cancel during Turn ───────────────────────────────────── - (AgentState::Turn { stream }, Event::Cancel) => { - let mut effects = Vec::new(); - - // Abort stream if still active - if !matches!(stream, StreamPhase::Done) { - effects.push(Effect::AbortStream); - } - - // Cancel all pending tools - let pending = self.ctx.tools.pending_ids(); - for id in &pending { - if let Some(tracked) = self.ctx.tools.get_mut(id) { - if tracked.state == ToolState::Executing { - effects.push(Effect::AbortTool { - tool_id: id.clone(), - }); - } - tracked.state = ToolState::Completed; - } - self.ctx.events.push(ConversationEvent::ToolResult { - tool_use_id: id.clone(), - content: "Error: user cancelled this operation".to_string(), - is_error: true, - remote: false, - content_length: None, - }); - } - - // Commit any partial streaming text - self.commit_streaming_text_as_cancelled(); - - // Add context so the LLM knows what happened - if !pending.is_empty() { - self.ctx.events.push(ConversationEvent::SystemContext { - content: "The user cancelled the previous generation. Tool calls that were in progress have been aborted.".to_string(), - }); - } - - // Clear timeout mappings — stale timeouts will be ignored by the guard - self.ctx.tool_timeout_ids.clear(); - - self.state = AgentState::Idle { confirmation: None }; - effects.push(Effect::Persist); - effects - } - - // ================================================================ - // Error state - // ================================================================ - (AgentState::Error(_), Event::Retry) => { - let messages = self.build_messages(); - let session_id = self.ctx.session_id.clone(); - self.state = AgentState::Turn { - stream: StreamPhase::Connecting, - }; - vec![Effect::StartStream { - messages, - session_id, - }] - } - - (AgentState::Error(_), Event::Cancel) => { - vec![Effect::ExitApp(ExitAction::Cancel)] - } - - // ================================================================ - // Fallthrough — ignore events with no valid transition - // ================================================================ - - // StreamDone can arrive after SuggestCommand (which already moved to Idle). - // We still need to capture the session_id from it. - (_, Event::StreamDone { session_id }) => { - if !session_id.is_empty() { - self.ctx.session_id = Some(session_id); - } - vec![Effect::Persist] - } - - (_, Event::SlashCommand { command, content }) => { - self.handle_slash_command(&command, &content); - vec![] - } - - // RequestSkillLoad during non-idle: still emit the effect - (_, Event::RequestSkillLoad { name, arguments }) => { - vec![Effect::LoadSkill { name, arguments }] - } - - // SkillLoaded during non-idle: queue so it's visible - // in context for the next turn. - ( - _, - Event::SkillLoaded { - name, - arguments, - content, - }, - ) => { - self.ctx.events.push(ConversationEvent::SkillInvocation { - name, - arguments, - content, - }); - vec![] - } - - _ => vec![], - } - } - - // ──────────────────────────────────────────────────────────────────── - // Private helpers - // ──────────────────────────────────────────────────────────────────── - - /// Start a new turn: push user message, build messages, emit StartStream. - fn start_turn(&mut self, msg: String) -> Vec<Effect> { - self.ctx - .events - .push(ConversationEvent::UserMessage { content: msg }); - // Don't clear tools — completed tools persist for rendering history. - // Tools are only cleared on /new (session reset). - self.ctx.current_response.clear(); - self.ctx.current_turn_tool_ids.clear(); - - let messages = self.build_messages(); - let session_id = self.ctx.session_id.clone(); - self.state = AgentState::Turn { - stream: StreamPhase::Connecting, - }; - vec![Effect::StartStream { - messages, - session_id, - }] - } - - /// Build API messages from the conversation event log. - fn build_messages(&self) -> Vec<Value> { - ContextWindowBuilder::with_default_budget().build(&self.ctx.events) - } - - /// Commit accumulated streaming text to the event log. - fn commit_streaming_text(&mut self) { - let text = std::mem::take(&mut self.ctx.current_response); - let trimmed = text.trim_start().to_string(); - if !trimmed.is_empty() { - self.ctx - .events - .push(ConversationEvent::Text { content: trimmed }); - } - } - - /// Commit streaming text with a cancellation suffix. - fn commit_streaming_text_as_cancelled(&mut self) { - let text = std::mem::take(&mut self.ctx.current_response); - let trimmed = text.trim_start().to_string(); - if !trimmed.is_empty() { - self.ctx.events.push(ConversationEvent::Text { - content: format!("{trimmed}\n\n[User cancelled this generation]"), - }); - } - } - - /// Handle a client-side tool call from the stream. - fn handle_stream_tool_call(&mut self, id: String, name: String, input: Value) -> Vec<Effect> { - // Parse the tool call - let tool = match crate::tools::ClientToolCall::try_from((name.as_str(), &input)) { - Ok(tool) => tool, - Err(_) => { - // Unknown tool — push as event but don't track - self.ctx - .events - .push(ConversationEvent::ToolCall { id, name, input }); - return vec![]; - } - }; - - // Capability gating - if let Some(required_cap) = tool.descriptor().capability - && !self.ctx.capabilities.iter().any(|c| c == required_cap) - { - self.ctx.events.push(ConversationEvent::ToolCall { - id: id.clone(), - name, - input, - }); - self.ctx.events.push(ConversationEvent::ToolResult { - tool_use_id: id, - content: format!( - "Tool not enabled: capability '{required_cap}' was not advertised by this client" - ), - is_error: true, - remote: false, - content_length: None, - }); - return vec![]; - } - - // Track the tool and push ToolCall event - let tool_for_effect = tool.clone(); - self.ctx.tools.insert(id.clone(), tool); - self.ctx.current_turn_tool_ids.push(id.clone()); - self.ctx.events.push(ConversationEvent::ToolCall { - id: id.clone(), - name, - input, - }); - - // Transition to Turn if we were Streaming - if let AgentState::Turn { - stream: StreamPhase::Streaming { .. }, - } = &self.state - { - self.state = AgentState::Turn { - stream: StreamPhase::Streaming { status: None }, - }; - } - - vec![Effect::CheckPermission { - tool_id: id, - tool: tool_for_effect, - }] - } - - /// Handle permission resolver result. - fn handle_permission_resolved( - &mut self, - tool_id: String, - response: PermissionResponse, - ) -> Vec<Effect> { - let Some(tracked) = self.ctx.tools.get_mut(&tool_id) else { - return vec![]; - }; - - // If already resolved (e.g. cancelled while permission check was in flight), - // ignore the stale result to avoid re-executing a cancelled tool. - if tracked.is_resolved() { - return vec![]; - } - - match response { - PermissionResponse::Allowed | PermissionResponse::SessionGranted => { - tracked.state = ToolState::Executing; - let tool = tracked.tool.clone(); - self.emit_execute_tool(tool_id, tool) - } - PermissionResponse::Ask => { - tracked.state = ToolState::AwaitingPermission; - vec![] - } - PermissionResponse::Denied => { - tracked.state = ToolState::Denied; - self.ctx.events.push(ConversationEvent::ToolResult { - tool_use_id: tool_id, - content: "Permission denied on the user's system".to_string(), - is_error: true, - remote: false, - content_length: None, - }); - self.check_turn_completion() - } - } - } - - /// Handle user's permission choice from the dialog. - fn handle_permission_choice( - &mut self, - tool_id: String, - choice: PermissionChoice, - ) -> Vec<Effect> { - let Some(tracked) = self.ctx.tools.get_mut(&tool_id) else { - return vec![]; - }; - - if tracked.is_resolved() { - return vec![]; - } - - match choice { - PermissionChoice::Allow => { - tracked.state = ToolState::Executing; - let tool = tracked.tool.clone(); - self.emit_execute_tool(tool_id, tool) - } - PermissionChoice::AllowForSession => { - tracked.state = ToolState::Executing; - let tool = tracked.tool.clone(); - let mut effects = self.emit_execute_tool(tool_id, tool.clone()); - if let Some(path) = tool.resolved_file_path() { - effects.push(Effect::CacheSessionGrant { path }); - } - effects - } - PermissionChoice::AlwaysAllowInProject => { - tracked.state = ToolState::Executing; - let tool = tracked.tool.clone(); - let rule = crate::permissions::rule::Rule { - tool: tool.rule_name().to_string(), - scope: None, // project file provides the scoping - }; - let mut effects = self.emit_execute_tool(tool_id, tool); - effects.push(Effect::WritePermissionRule { - target: PermissionTarget::Project, - rule, - disposition: crate::permissions::writer::RuleDisposition::Allow, - }); - effects - } - PermissionChoice::AlwaysAllow => { - tracked.state = ToolState::Executing; - let tool = tracked.tool.clone(); - let scope = tool - .resolved_file_path() - .map(|p| p.to_string_lossy().to_string()); - let rule = crate::permissions::rule::Rule { - tool: tool.rule_name().to_string(), - scope, - }; - let mut effects = self.emit_execute_tool(tool_id, tool); - effects.push(Effect::WritePermissionRule { - target: PermissionTarget::Global, - rule, - disposition: crate::permissions::writer::RuleDisposition::Allow, - }); - effects - } - PermissionChoice::Deny => { - tracked.state = ToolState::Denied; - self.ctx.events.push(ConversationEvent::ToolResult { - tool_use_id: tool_id, - content: "Permission denied by the user".to_string(), - is_error: true, - remote: false, - content_length: None, - }); - self.check_turn_completion() - } - } - } - - /// Handle tool execution completion. - fn handle_tool_done( - &mut self, - tool_id: String, - outcome: crate::tools::ToolOutcome, - preview: Option<tools::ToolPreviewData>, - ) -> Vec<Effect> { - let Some(tracked) = self.ctx.tools.get_mut(&tool_id) else { - return vec![]; - }; - - // If already completed (e.g. cancelled), ignore stale result - if tracked.is_resolved() { - return vec![]; - } - - tracked.state = ToolState::Completed; - - // If the FSM tagged this tool with an interrupt reason (user or timeout), - // use it; otherwise derive from the outcome's interrupted flag. - let reason = tracked.interrupt_reason.take().or({ - if let crate::tools::ToolOutcome::Structured { - interrupted: true, .. - } = &outcome - { - Some(tools::InterruptReason::User) - } else { - None - } - }); - - // Merge shell preview: the final ToolExecutionDone carries exit_code/interrupted - // but has empty lines (the live lines were accumulated via ToolPreviewUpdate). - // Preserve the accumulated lines and fold in the terminal metadata. - match (&mut tracked.preview, preview) { - ( - Some(tools::ToolPreviewData::Shell { - exit_code, - interrupted, - .. - }), - Some(tools::ToolPreviewData::Shell { - exit_code: final_exit, - .. - }), - ) => { - *exit_code = final_exit; - *interrupted = reason.clone(); - } - (_, Some(mut p)) => { - if let tools::ToolPreviewData::Shell { - ref mut interrupted, - .. - } = p - { - *interrupted = reason.clone(); - } - tracked.preview = Some(p); - } - _ => {} - } - - // Clean up any pending execution timeout for this tool - self.ctx.tool_timeout_ids.retain(|_, tid| tid != &tool_id); - - let content = outcome.format_for_llm(reason.as_ref()); - let is_error = outcome.is_error(); - self.ctx.events.push(ConversationEvent::ToolResult { - tool_use_id: tool_id, - content, - is_error, - remote: false, - content_length: None, - }); - - self.check_turn_completion() - } - - /// Handle a tool execution timeout. Aborts the tool if it's still running. - fn handle_tool_execution_timeout(&mut self, timeout_id: u64, tool_id: String) -> Vec<Effect> { - // Guard: only act if this timeout is still registered (not cleaned up by natural completion) - if self.ctx.tool_timeout_ids.remove(&timeout_id).is_none() { - return vec![]; - } - - let Some(tracked) = self.ctx.tools.get_mut(&tool_id) else { - return vec![]; - }; - - if tracked.is_resolved() { - return vec![]; - } - - // Tag the tool so handle_tool_done can distinguish timeout from user interrupt. - // Only shell tools have entries in tool_timeout_ids, so this is always Shell. - let timeout_secs = match &tracked.tool { - crate::tools::ClientToolCall::Shell(s) => s.timeout_secs, - _ => unreachable!("only shell tools have execution timeouts"), - }; - tracked.interrupt_reason = Some(tools::InterruptReason::Timeout(timeout_secs)); - - // Abort the tool — the driver sends the interrupt signal via oneshot, - // and execute_shell_command_streaming returns a Structured outcome with - // interrupted: true and partial stdout/stderr. This flows through the - // normal ToolExecutionDone path. - vec![Effect::AbortTool { tool_id }] - } - - /// Emit effects to begin executing a tool. For shell commands, also schedules - /// an execution timeout based on the LLM-specified timeout_secs. - fn emit_execute_tool( - &mut self, - tool_id: String, - tool: crate::tools::ClientToolCall, - ) -> Vec<Effect> { - let mut effects = vec![Effect::ExecuteTool { - tool_id: tool_id.clone(), - tool: tool.clone(), - }]; - - if let crate::tools::ClientToolCall::Shell(ref shell) = tool { - let timeout_id = self.ctx.next_timeout_id(); - self.ctx - .tool_timeout_ids - .insert(timeout_id, tool_id.clone()); - effects.push(Effect::ScheduleTimeout { - timeout_id, - duration: std::time::Duration::from_secs(shell.timeout_secs), - kind: TimeoutKind::ToolExecution { tool_id }, - }); - } - - effects - } - - /// Check if the turn is complete (stream done + all tools resolved). - /// If so, either continue the conversation or go Idle. - fn check_turn_completion(&mut self) -> Vec<Effect> { - // Stream must be done - if !matches!( - self.state, - AgentState::Turn { - stream: StreamPhase::Done - } - ) { - return vec![]; - } - - // All current-turn tools must be resolved before the turn can complete - if !self.ctx.tools.all_resolved(&self.ctx.current_turn_tool_ids) { - return vec![]; - } - - // Turn is complete. Check if we need to continue (tool results to send back). - // We continue if this turn had any client tool calls (the LLM needs to see - // the results and respond). - if !self.ctx.current_turn_tool_ids.is_empty() { - // Continue conversation with tool results. - // Don't clear tools — they persist for rendering history. - // Clear turn IDs so the continuation turn doesn't loop. - self.ctx.current_turn_tool_ids.clear(); - let messages = self.build_messages(); - let session_id = self.ctx.session_id.clone(); - self.ctx.current_response.clear(); - self.state = AgentState::Turn { - stream: StreamPhase::Connecting, - }; - vec![Effect::StartStream { - messages, - session_id, - }] - } else { - // No tools — turn is done, go idle - self.state = AgentState::Idle { confirmation: None }; - vec![Effect::Persist] - } - } - - /// Extract the current confirmation state (if any). - fn state_confirmation(&self) -> Option<&PendingConfirmation> { - if let AgentState::Idle { - confirmation: Some(ref c), - } = self.state - { - Some(c) - } else { - None - } - } - - /// Get the most recent suggested command from the conversation. - /// Get the most recent command from the current invocation only. - fn current_command(&self) -> Option<String> { - self.current_invocation_events() - .rev() - .find_map(|e| e.as_command()) - .map(|s| s.to_string()) - } - - /// Check if the most recent command is dangerous. - fn is_current_command_dangerous(&self) -> bool { - self.current_invocation_events() - .rev() - .find_map(|e| { - if let ConversationEvent::ToolCall { name, input, .. } = e - && name == "suggest_command" - { - let danger = input - .get("danger") - .and_then(|v| v.as_str()) - .unwrap_or("low"); - Some(danger == "high" || danger == "medium" || danger == "med") - } else { - None - } - }) - .unwrap_or(false) - } - - /// Events from the current invocation only (from view_start_index onward). - fn current_invocation_events(&self) -> impl DoubleEndedIterator<Item = &ConversationEvent> { - let start = self.ctx.view_start_index.min(self.ctx.events.len()); - self.ctx.events[start..].iter() - } - - /// Handle a slash command by pushing an OOB event. - fn handle_slash_command(&mut self, command: &str, content: &str) { - self.ctx.events.push(ConversationEvent::OutOfBandOutput { - name: "System".to_string(), - command: Some(command.to_string()), - content: content.to_string(), - }); - } -} diff --git a/crates/atuin-ai/src/fsm/tests.rs b/crates/atuin-ai/src/fsm/tests.rs deleted file mode 100644 index 51c23915..00000000 --- a/crates/atuin-ai/src/fsm/tests.rs +++ /dev/null @@ -1,890 +0,0 @@ -//! Pure FSM transition tests. No IO, no async. - -use serde_json::json; - -use super::*; -use effects::{Effect, ExitAction}; -use events::{Event, PermissionChoice, PermissionResponse}; - -fn new_fsm() -> AgentFsm { - AgentFsm::new( - vec!["client_v1_read_file".to_string()], - "test-inv".to_string(), - ) -} - -// ============================================================================ -// Idle → Turn -// ============================================================================ - -#[test] -fn user_submit_starts_turn() { - let mut fsm = new_fsm(); - - let effects = fsm.handle(Event::UserSubmit("hello".into())); - - assert!(matches!( - fsm.state, - AgentState::Turn { - stream: StreamPhase::Connecting - } - )); - assert_eq!(effects.len(), 1); - assert!(matches!(effects[0], Effect::StartStream { .. })); - // User message was pushed to events - assert!(fsm.ctx.events.iter().any(|e| matches!( - e, - ConversationEvent::UserMessage { content } if content == "hello" - ))); -} - -#[test] -fn stream_started_transitions_to_streaming() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("hello".into())); - - let effects = fsm.handle(Event::StreamStarted); - - assert!(matches!( - fsm.state, - AgentState::Turn { - stream: StreamPhase::Streaming { status: None } - } - )); - assert!(effects.is_empty()); -} - -#[test] -fn stream_chunk_accumulates_text() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("hello".into())); - fsm.handle(Event::StreamStarted); - - fsm.handle(Event::StreamChunk("Hello ".into())); - fsm.handle(Event::StreamChunk("world!".into())); - - assert_eq!(fsm.ctx.current_response, "Hello world!"); -} - -#[test] -fn stream_done_without_tools_goes_idle() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("hello".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamChunk("Hi there!".into())); - - let effects = fsm.handle(Event::StreamDone { - session_id: "s1".into(), - }); - - assert_eq!(fsm.state, AgentState::Idle { confirmation: None }); - assert_eq!(fsm.ctx.session_id, Some("s1".to_string())); - assert!(effects.iter().any(|e| matches!(e, Effect::Persist))); - // Text was committed to events - assert!(fsm.ctx.events.iter().any(|e| matches!( - e, - ConversationEvent::Text { content } if content == "Hi there!" - ))); -} - -// ============================================================================ -// Tool lifecycle -// ============================================================================ - -#[test] -fn stream_tool_call_tracks_tool_and_emits_check_permission() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("read a file".into())); - fsm.handle(Event::StreamStarted); - - let effects = fsm.handle(Event::StreamToolCall { - id: "t1".into(), - name: "read_file".into(), - input: json!({"file_path": "/tmp/test.txt"}), - }); - - assert!(fsm.ctx.tools.get("t1").is_some()); - assert_eq!(effects.len(), 1); - assert!(matches!(effects[0], Effect::CheckPermission { .. })); -} - -#[test] -fn permission_allowed_transitions_to_executing() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("read".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamToolCall { - id: "t1".into(), - name: "read_file".into(), - input: json!({"file_path": "/tmp/test.txt"}), - }); - - let effects = fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - - assert_eq!(fsm.ctx.tools.get("t1").unwrap().state, ToolState::Executing); - assert!(matches!(effects[0], Effect::ExecuteTool { .. })); -} - -#[test] -fn permission_ask_transitions_to_awaiting() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("read".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamToolCall { - id: "t1".into(), - name: "read_file".into(), - input: json!({"file_path": "/tmp/test.txt"}), - }); - - let effects = fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Ask, - }); - - assert_eq!( - fsm.ctx.tools.get("t1").unwrap().state, - ToolState::AwaitingPermission - ); - assert!(effects.is_empty()); -} - -#[test] -fn tool_done_after_stream_done_continues_conversation() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("read".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamToolCall { - id: "t1".into(), - name: "read_file".into(), - input: json!({"file_path": "/tmp/test.txt"}), - }); - fsm.handle(Event::StreamDone { - session_id: "".into(), - }); - fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - - // Now in Turn { Done } with one tool Executing - let effects = fsm.handle(Event::ToolExecutionDone { - tool_id: "t1".into(), - outcome: crate::tools::ToolOutcome::Success("file contents".into()), - preview: None, - }); - - // Turn complete → continuation - assert!(matches!( - fsm.state, - AgentState::Turn { - stream: StreamPhase::Connecting - } - )); - assert!( - effects - .iter() - .any(|e| matches!(e, Effect::StartStream { .. })) - ); -} - -#[test] -fn continuation_turn_without_new_tools_goes_idle() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("read".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamToolCall { - id: "t1".into(), - name: "read_file".into(), - input: json!({"file_path": "/tmp/test.txt"}), - }); - fsm.handle(Event::StreamDone { - session_id: "".into(), - }); - fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - // Tool completes → continuation starts - fsm.handle(Event::ToolExecutionDone { - tool_id: "t1".into(), - outcome: crate::tools::ToolOutcome::Success("contents".into()), - preview: None, - }); - assert!(matches!( - fsm.state, - AgentState::Turn { - stream: StreamPhase::Connecting - } - )); - - // Continuation stream: text only, no new tools - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamChunk("Here's the file.".into())); - let effects = fsm.handle(Event::StreamDone { - session_id: "".into(), - }); - - // Should go Idle, NOT start another continuation - assert_eq!(fsm.state, AgentState::Idle { confirmation: None }); - assert!(effects.iter().any(|e| matches!(e, Effect::Persist))); - assert!( - !effects - .iter() - .any(|e| matches!(e, Effect::StartStream { .. })) - ); -} - -#[test] -fn tool_done_before_stream_done_stays_in_turn() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("read".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamToolCall { - id: "t1".into(), - name: "read_file".into(), - input: json!({"file_path": "/tmp/test.txt"}), - }); - fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - - // Tool completes but stream hasn't sent Done yet - let effects = fsm.handle(Event::ToolExecutionDone { - tool_id: "t1".into(), - outcome: crate::tools::ToolOutcome::Success("contents".into()), - preview: None, - }); - - // Still in Turn — stream phase is Streaming, not Done - assert!(matches!( - fsm.state, - AgentState::Turn { - stream: StreamPhase::Streaming { .. } - } - )); - assert!(effects.is_empty()); -} - -// ============================================================================ -// Cancel -// ============================================================================ - -#[test] -fn cancel_during_streaming_goes_idle() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("hello".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamChunk("partial text".into())); - - let effects = fsm.handle(Event::Cancel); - - assert_eq!(fsm.state, AgentState::Idle { confirmation: None }); - assert!(effects.iter().any(|e| matches!(e, Effect::AbortStream))); - assert!(effects.iter().any(|e| matches!(e, Effect::Persist))); - // Partial text committed with cancel suffix - assert!(fsm.ctx.events.iter().any(|e| matches!( - e, - ConversationEvent::Text { content } if content.contains("[User cancelled") - ))); -} - -#[test] -fn stale_permission_resolved_after_cancel_is_ignored() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("read".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamToolCall { - id: "t1".into(), - name: "read_file".into(), - input: json!({"file_path": "/tmp/test.txt"}), - }); - fsm.handle(Event::StreamDone { - session_id: "".into(), - }); - // Tool is in CheckingPermission, cancel happens before permission resolves - fsm.handle(Event::Cancel); - assert_eq!(fsm.state, AgentState::Idle { confirmation: None }); - - // Stale permission result arrives — tool is already Completed (cancelled) - let effects = fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - - // Should NOT emit ExecuteTool — the tool was cancelled - assert!(effects.is_empty()); -} - -#[test] -fn cancel_during_turn_with_pending_tools() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("hello".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamToolCall { - id: "t1".into(), - name: "read_file".into(), - input: json!({"file_path": "/tmp/test.txt"}), - }); - fsm.handle(Event::StreamDone { - session_id: "".into(), - }); - fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - // Tool is Executing, stream is Done - - let effects = fsm.handle(Event::Cancel); - - assert_eq!(fsm.state, AgentState::Idle { confirmation: None }); - assert!( - effects - .iter() - .any(|e| matches!(e, Effect::AbortTool { .. })) - ); - // Error ToolResult injected - assert!(fsm.ctx.events.iter().any(|e| matches!( - e, - ConversationEvent::ToolResult { tool_use_id, is_error: true, .. } if tool_use_id == "t1" - ))); - // SystemContext about cancellation - assert!(fsm.ctx.events.iter().any(|e| matches!( - e, - ConversationEvent::SystemContext { content } if content.contains("cancelled") - ))); -} - -#[test] -fn stale_tool_result_after_cancel_is_ignored() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("hello".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamToolCall { - id: "t1".into(), - name: "read_file".into(), - input: json!({"file_path": "/tmp/test.txt"}), - }); - fsm.handle(Event::StreamDone { - session_id: "".into(), - }); - fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - fsm.handle(Event::Cancel); - - // Stale event arrives - let effects = fsm.handle(Event::ToolExecutionDone { - tool_id: "t1".into(), - outcome: crate::tools::ToolOutcome::Success("contents".into()), - preview: None, - }); - - assert_eq!(fsm.state, AgentState::Idle { confirmation: None }); - assert!(effects.is_empty()); -} - -// ============================================================================ -// Confirmation -// ============================================================================ - -#[test] -fn dangerous_command_enters_confirmation() { - let mut fsm = new_fsm(); - // Simulate a dangerous command in history - fsm.ctx.events.push(ConversationEvent::ToolCall { - id: "sc1".into(), - name: "suggest_command".into(), - input: json!({"command": "rm -rf /", "description": "bad", "confidence": "high", "danger": "high"}), - }); - - let effects = fsm.handle(Event::ExecuteCommand); - - assert!(matches!( - fsm.state, - AgentState::Idle { - confirmation: Some(_) - } - )); - assert!( - effects - .iter() - .any(|e| matches!(e, Effect::ScheduleTimeout { .. })) - ); -} - -#[test] -fn second_execute_confirms_and_exits() { - let mut fsm = new_fsm(); - fsm.ctx.events.push(ConversationEvent::ToolCall { - id: "sc1".into(), - name: "suggest_command".into(), - input: json!({"command": "rm -rf /", "description": "bad", "confidence": "high", "danger": "high"}), - }); - fsm.handle(Event::ExecuteCommand); - - let effects = fsm.handle(Event::ExecuteCommand); - - assert!(effects.iter().any(|e| matches!( - e, - Effect::ExitApp(ExitAction::Execute(cmd)) if cmd == "rm -rf /" - ))); -} - -#[test] -fn confirmation_timeout_clears_confirmation() { - let mut fsm = new_fsm(); - fsm.ctx.events.push(ConversationEvent::ToolCall { - id: "sc1".into(), - name: "suggest_command".into(), - input: json!({"command": "rm -rf /", "description": "bad", "confidence": "high", "danger": "high"}), - }); - fsm.handle(Event::ExecuteCommand); - let timeout_id = match &fsm.state { - AgentState::Idle { - confirmation: Some(c), - } => c.timeout_id, - _ => panic!("expected confirmation"), - }; - - fsm.handle(Event::ConfirmationTimeout { timeout_id }); - - assert_eq!(fsm.state, AgentState::Idle { confirmation: None }); -} - -// ============================================================================ -// Error / Retry -// ============================================================================ - -#[test] -fn stream_error_goes_to_error_state() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("hello".into())); - fsm.handle(Event::StreamStarted); - - fsm.handle(Event::StreamError("network error".into())); - - assert_eq!(fsm.state, AgentState::Error("network error".to_string())); -} - -#[test] -fn retry_from_error_starts_new_stream() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("hello".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamError("fail".into())); - - let effects = fsm.handle(Event::Retry); - - assert!(matches!( - fsm.state, - AgentState::Turn { - stream: StreamPhase::Connecting - } - )); - assert!( - effects - .iter() - .any(|e| matches!(e, Effect::StartStream { .. })) - ); -} - -// ============================================================================ -// Permission choices -// ============================================================================ - -#[test] -fn permission_deny_completes_turn_and_continues() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("read".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamToolCall { - id: "t1".into(), - name: "read_file".into(), - input: json!({"file_path": "/tmp/test.txt"}), - }); - fsm.handle(Event::StreamDone { - session_id: "".into(), - }); - fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Ask, - }); - - let effects = fsm.handle(Event::PermissionUserChoice { - tool_id: "t1".into(), - choice: PermissionChoice::Deny, - }); - - // Turn should complete since all tools resolved and stream is done - // → continuation needed (there was a tool result to send back) - assert!(matches!( - fsm.state, - AgentState::Turn { - stream: StreamPhase::Connecting - } - )); - assert!( - effects - .iter() - .any(|e| matches!(e, Effect::StartStream { .. })) - ); - // Error result was injected - assert!(fsm.ctx.events.iter().any(|e| matches!( - e, - ConversationEvent::ToolResult { tool_use_id, is_error: true, .. } if tool_use_id == "t1" - ))); -} - -// ============================================================================ -// Shell execution timeouts -// ============================================================================ - -fn fsm_with_shell() -> AgentFsm { - AgentFsm::new( - vec![ - "client_v1_read_file".to_string(), - "client_v1_execute_shell_command".to_string(), - ], - "test-inv".to_string(), - ) -} - -fn shell_tool_call_event(id: &str) -> Event { - Event::StreamToolCall { - id: id.into(), - name: "execute_shell_command".into(), - input: json!({ - "command": "sleep 999", - "shell": "bash", - "timeout": 60, - "description": "test" - }), - } -} - -#[test] -fn shell_tool_schedules_execution_timeout() { - let mut fsm = fsm_with_shell(); - fsm.handle(Event::UserSubmit("run something".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(shell_tool_call_event("t1")); - - let effects = fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - - // Should have ExecuteTool + ScheduleTimeout - assert!( - effects - .iter() - .any(|e| matches!(e, Effect::ExecuteTool { .. })) - ); - assert!(effects.iter().any(|e| matches!( - e, - Effect::ScheduleTimeout { kind: effects::TimeoutKind::ToolExecution { tool_id }, .. } - if tool_id == "t1" - ))); - assert!(!fsm.ctx.tool_timeout_ids.is_empty()); -} - -#[test] -fn read_tool_does_not_schedule_timeout() { - let mut fsm = new_fsm(); - fsm.handle(Event::UserSubmit("read".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(Event::StreamToolCall { - id: "t1".into(), - name: "read_file".into(), - input: json!({"file_path": "/tmp/test.txt"}), - }); - - let effects = fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - - assert!( - effects - .iter() - .any(|e| matches!(e, Effect::ExecuteTool { .. })) - ); - assert!( - !effects - .iter() - .any(|e| matches!(e, Effect::ScheduleTimeout { .. })) - ); - assert!(fsm.ctx.tool_timeout_ids.is_empty()); -} - -#[test] -fn tool_completion_clears_timeout_mapping() { - let mut fsm = fsm_with_shell(); - fsm.handle(Event::UserSubmit("run".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(shell_tool_call_event("t1")); - fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - fsm.handle(Event::StreamDone { - session_id: "s1".into(), - }); - - assert!(!fsm.ctx.tool_timeout_ids.is_empty()); - - // Tool completes naturally - fsm.handle(Event::ToolExecutionDone { - tool_id: "t1".into(), - outcome: crate::tools::ToolOutcome::Success("done".into()), - preview: None, - }); - - assert!(fsm.ctx.tool_timeout_ids.is_empty()); -} - -#[test] -fn stale_timeout_after_natural_completion_is_ignored() { - let mut fsm = fsm_with_shell(); - fsm.handle(Event::UserSubmit("run".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(shell_tool_call_event("t1")); - fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - fsm.handle(Event::StreamDone { - session_id: "s1".into(), - }); - - // Tool completes naturally - fsm.handle(Event::ToolExecutionDone { - tool_id: "t1".into(), - outcome: crate::tools::ToolOutcome::Success("done".into()), - preview: None, - }); - - // Stale timeout fires — should be no-op - let effects = fsm.handle(Event::ToolExecutionTimeout { - timeout_id: 0, - tool_id: "t1".into(), - }); - - assert!(effects.is_empty()); -} - -#[test] -fn timeout_fires_before_completion_emits_abort() { - let mut fsm = fsm_with_shell(); - fsm.handle(Event::UserSubmit("run".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(shell_tool_call_event("t1")); - fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - fsm.handle(Event::StreamDone { - session_id: "s1".into(), - }); - - // Timeout fires while tool is still executing - let effects = fsm.handle(Event::ToolExecutionTimeout { - timeout_id: 0, - tool_id: "t1".into(), - }); - - assert_eq!(effects.len(), 1); - assert!(matches!( - effects[0], - Effect::AbortTool { ref tool_id } if tool_id == "t1" - )); - // Timeout mapping cleaned up - assert!(fsm.ctx.tool_timeout_ids.is_empty()); -} - -#[test] -fn timeout_respects_llm_specified_duration() { - let mut fsm = fsm_with_shell(); - fsm.handle(Event::UserSubmit("run".into())); - fsm.handle(Event::StreamStarted); - - // Tool call with timeout: 120 - fsm.handle(Event::StreamToolCall { - id: "t1".into(), - name: "execute_shell_command".into(), - input: json!({ - "command": "cargo build", - "shell": "bash", - "timeout": 120, - "description": "build" - }), - }); - - let effects = fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - - let timeout_effect = effects - .iter() - .find(|e| matches!(e, Effect::ScheduleTimeout { .. })); - assert!(matches!( - timeout_effect, - Some(Effect::ScheduleTimeout { duration, .. }) if *duration == std::time::Duration::from_secs(120) - )); -} - -#[test] -fn cancel_clears_timeout_mappings() { - let mut fsm = fsm_with_shell(); - fsm.handle(Event::UserSubmit("run".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(shell_tool_call_event("t1")); - fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - - assert!(!fsm.ctx.tool_timeout_ids.is_empty()); - - fsm.handle(Event::Cancel); - - assert!(fsm.ctx.tool_timeout_ids.is_empty()); -} - -#[test] -fn timeout_abort_propagates_timeout_reason_to_preview_and_llm() { - use super::tools::InterruptReason; - - let mut fsm = fsm_with_shell(); - fsm.handle(Event::UserSubmit("run".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(shell_tool_call_event("t1")); - fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - fsm.handle(Event::StreamDone { - session_id: "s1".into(), - }); - - // Timeout fires - fsm.handle(Event::ToolExecutionTimeout { - timeout_id: 0, - tool_id: "t1".into(), - }); - - // Tool completes after abort (interrupted: true from execute_shell_command_streaming) - fsm.handle(Event::ToolExecutionDone { - tool_id: "t1".into(), - outcome: crate::tools::ToolOutcome::Structured { - stdout: "partial output".into(), - stderr: String::new(), - exit_code: None, - duration_ms: 60000, - interrupted: true, - }, - preview: Some(super::tools::ToolPreviewData::Shell { - lines: vec!["partial output".into()], - exit_code: None, - interrupted: None, // FSM overrides this with the reason - }), - }); - - // Preview should carry Timeout reason - let tracked = fsm.ctx.tools.get("t1").unwrap(); - let preview = tracked.shell_preview().unwrap(); - assert_eq!(preview.interrupted, Some(InterruptReason::Timeout(60))); - - // LLM content should say "Timed out" not "Interrupted by user" - let tool_result = fsm.ctx.events.iter().find( - |e| matches!(e, ConversationEvent::ToolResult { tool_use_id, .. } if tool_use_id == "t1"), - ); - if let Some(ConversationEvent::ToolResult { content, .. }) = tool_result { - assert!( - content.contains("[Timed out after 60s]"), - "Expected timeout message, got: {content}" - ); - assert!(!content.contains("[Interrupted by user]")); - } else { - panic!("No ToolResult found for t1"); - } -} - -#[test] -fn user_interrupt_propagates_user_reason_to_preview_and_llm() { - use super::tools::InterruptReason; - - let mut fsm = fsm_with_shell(); - fsm.handle(Event::UserSubmit("run".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(shell_tool_call_event("t1")); - fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - fsm.handle(Event::StreamDone { - session_id: "s1".into(), - }); - - // User interrupts - fsm.handle(Event::InterruptTools); - - // Tool completes after abort - fsm.handle(Event::ToolExecutionDone { - tool_id: "t1".into(), - outcome: crate::tools::ToolOutcome::Structured { - stdout: "partial".into(), - stderr: String::new(), - exit_code: None, - duration_ms: 5000, - interrupted: true, - }, - preview: Some(super::tools::ToolPreviewData::Shell { - lines: vec!["partial".into()], - exit_code: None, - interrupted: None, // FSM overrides this with the reason - }), - }); - - // Preview should carry User reason - let tracked = fsm.ctx.tools.get("t1").unwrap(); - let preview = tracked.shell_preview().unwrap(); - assert_eq!(preview.interrupted, Some(InterruptReason::User)); - - // LLM content should say "Interrupted by user" - let tool_result = fsm.ctx.events.iter().find( - |e| matches!(e, ConversationEvent::ToolResult { tool_use_id, .. } if tool_use_id == "t1"), - ); - if let Some(ConversationEvent::ToolResult { content, .. }) = tool_result { - assert!( - content.contains("[Interrupted by user]"), - "Expected user interrupt message, got: {content}" - ); - } else { - panic!("No ToolResult found for t1"); - } -} - -#[test] -fn user_interrupt_clears_timeout_mappings_for_aborted_tools() { - let mut fsm = fsm_with_shell(); - fsm.handle(Event::UserSubmit("run".into())); - fsm.handle(Event::StreamStarted); - fsm.handle(shell_tool_call_event("t1")); - fsm.handle(Event::PermissionResolved { - tool_id: "t1".into(), - response: PermissionResponse::Allowed, - }); - - assert!(!fsm.ctx.tool_timeout_ids.is_empty()); - - fsm.handle(Event::InterruptTools); - - assert!(fsm.ctx.tool_timeout_ids.is_empty()); -} diff --git a/crates/atuin-ai/src/fsm/tools.rs b/crates/atuin-ai/src/fsm/tools.rs deleted file mode 100644 index 96348672..00000000 --- a/crates/atuin-ai/src/fsm/tools.rs +++ /dev/null @@ -1,178 +0,0 @@ -//! Tool lifecycle management within the FSM. -//! -//! Each tool call goes through an independent lifecycle. The ToolManager -//! tracks all tools in the current turn and provides the "all resolved" -//! check that gates turn completion. - -use crate::diff::{EditPreview, WritePreview}; -use crate::tools::ClientToolCall; - -/// Why a tool execution was interrupted. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum InterruptReason { - /// User pressed Ctrl+C or Esc during execution. - User, - /// The LLM-specified execution timeout expired. - Timeout(u64), -} - -/// Per-tool lifecycle state. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum ToolState { - /// Permission resolver is running asynchronously. - CheckingPermission, - /// Waiting for user to grant/deny via the permission dialog. - AwaitingPermission, - /// Actively executing. - Executing, - /// Execution completed (result injected into conversation). - Completed, - /// User denied permission (error result injected into conversation). - Denied, -} - -/// Cached preview data for rendering tool output. -#[derive(Debug, Clone)] -pub(crate) enum ToolPreviewData { - /// Shell command VT100 output lines. - Shell { - lines: Vec<String>, - exit_code: Option<i32>, - interrupted: Option<InterruptReason>, - }, - /// File edit diff preview. - Edit(EditPreview), - /// File write content preview. - Write(WritePreview), -} - -/// A tracked tool call with its current lifecycle state. -#[derive(Debug, Clone)] -pub(crate) struct TrackedTool { - pub id: String, - pub tool: ClientToolCall, - pub state: ToolState, - /// Cached preview data for rendering (populated during/after execution). - pub preview: Option<ToolPreviewData>, - /// Set by the FSM when it emits AbortTool, so that ToolExecutionDone - /// can distinguish user interrupts from timeouts. - pub interrupt_reason: Option<InterruptReason>, -} - -impl TrackedTool { - /// Whether this tool has reached a terminal state. - pub fn is_resolved(&self) -> bool { - matches!(self.state, ToolState::Completed | ToolState::Denied) - } - - /// Extract shell preview data (for TurnBuilder compatibility). - pub fn shell_preview(&self) -> Option<crate::tools::ToolPreview> { - match &self.preview { - Some(ToolPreviewData::Shell { - lines, - exit_code, - interrupted, - }) => Some(crate::tools::ToolPreview { - lines: lines.clone(), - exit_code: *exit_code, - interrupted: interrupted.clone(), - }), - _ => None, - } - } - - /// Extract edit diff preview (for TurnBuilder compatibility). - pub fn edit_preview(&self) -> Option<&EditPreview> { - match &self.preview { - Some(ToolPreviewData::Edit(p)) => Some(p), - _ => None, - } - } - - /// Extract write content preview (for TurnBuilder compatibility). - pub fn write_preview(&self) -> Option<&WritePreview> { - match &self.preview { - Some(ToolPreviewData::Write(p)) => Some(p), - _ => None, - } - } -} - -/// Manages tool call lifecycles for a single turn. -/// -/// Tools are inserted when received from the stream and progress through -/// their lifecycle independently. The manager provides aggregate queries -/// (all resolved, any awaiting permission, etc.) that the FSM uses for -/// state transitions. -#[derive(Debug, Clone, Default)] -pub(crate) struct ToolManager { - tools: Vec<TrackedTool>, -} - -impl ToolManager { - pub fn new() -> Self { - Self { tools: Vec::new() } - } - - /// Insert a new tool in CheckingPermission state. - pub fn insert(&mut self, id: String, tool: ClientToolCall) { - self.tools.push(TrackedTool { - id, - tool, - state: ToolState::CheckingPermission, - preview: None, - interrupt_reason: None, - }); - } - - /// Look up a tool by ID. - pub fn get(&self, id: &str) -> Option<&TrackedTool> { - self.tools.iter().find(|t| t.id == id) - } - - /// Look up a tool mutably by ID. - pub fn get_mut(&mut self, id: &str) -> Option<&mut TrackedTool> { - self.tools.iter_mut().find(|t| t.id == id) - } - - /// True if all tools from the given set of IDs have reached a terminal state. - /// Returns true for an empty set (vacuously — no tools to wait for). - pub fn all_resolved(&self, tool_ids: &[String]) -> bool { - tool_ids - .iter() - .all(|id| self.get(id).is_some_and(|t| t.is_resolved())) - } - - /// Find the first tool awaiting user permission. - pub fn awaiting_permission(&self) -> Option<&TrackedTool> { - self.tools - .iter() - .find(|t| t.state == ToolState::AwaitingPermission) - } - - /// Get IDs of all non-resolved tools (for cancel). - pub fn pending_ids(&self) -> Vec<String> { - self.tools - .iter() - .filter(|t| !t.is_resolved()) - .map(|t| t.id.clone()) - .collect() - } - - /// Get IDs of all currently executing tools (for interrupt/abort). - pub fn executing_ids(&self) -> Vec<String> { - self.tools - .iter() - .filter(|t| t.state == ToolState::Executing) - .map(|t| t.id.clone()) - .collect() - } - - /// True if any tool has a shell preview with live output. - pub fn has_executing_preview(&self) -> bool { - self.tools.iter().any(|t| { - t.state == ToolState::Executing - && matches!(t.preview, Some(ToolPreviewData::Shell { .. })) - }) - } -} diff --git a/crates/atuin-ai/src/history_format.rs b/crates/atuin-ai/src/history_format.rs deleted file mode 100644 index 24aa963e..00000000 --- a/crates/atuin-ai/src/history_format.rs +++ /dev/null @@ -1,120 +0,0 @@ -use atuin_client::history::History; -use time::UtcOffset; - -pub(crate) fn current_local_offset() -> UtcOffset { - UtcOffset::current_local_offset().unwrap_or(UtcOffset::UTC) -} - -pub(crate) fn format_last_command(history: &History, local_offset: UtcOffset) -> String { - format!( - "History ID: {} - `{}`\n{}", - history.id, - history.command, - format_history_metadata(history, local_offset) - ) -} - -pub(crate) fn format_history_search_result( - ordinal: usize, - history: &History, - local_offset: UtcOffset, -) -> String { - format!( - "## #{}. (History ID: {}):\n`{}`\n{}\n", - ordinal, - history.id, - history.command, - format_history_metadata(history, local_offset) - ) -} - -fn format_history_metadata(history: &History, local_offset: UtcOffset) -> String { - format!( - "[{}] (in `{}`, exit {}){}", - format_timestamp(history, local_offset), - history.cwd, - history.exit, - format_duration(history.duration) - ) -} - -fn format_timestamp(history: &History, local_offset: UtcOffset) -> String { - let ts = history.timestamp.to_offset(local_offset); - format!( - "{:04}-{:02}-{:02} {:02}:{:02}:{:02}", - ts.year(), - ts.month() as u8, - ts.day(), - ts.hour(), - ts.minute(), - ts.second(), - ) -} - -fn format_duration(nanos: i64) -> String { - if nanos <= 0 { - return String::new(); - } - - let total_secs = nanos / 1_000_000_000; - let millis = (nanos % 1_000_000_000) / 1_000_000; - - if total_secs >= 3600 { - let hours = total_secs / 3600; - let mins = (total_secs % 3600) / 60; - let secs = total_secs % 60; - format!(", {hours}h{mins}m{secs}s") - } else if total_secs >= 60 { - let mins = total_secs / 60; - let secs = total_secs % 60; - format!(", {mins}m{secs}s") - } else if total_secs > 0 { - if millis > 0 { - format!(", {total_secs}.{millis:03}s") - } else { - format!(", {total_secs}s") - } - } else { - format!(", {millis}ms") - } -} - -#[cfg(test)] -mod tests { - use atuin_client::history::{History, HistoryId}; - use time::{OffsetDateTime, UtcOffset}; - - use super::*; - - fn history(duration: i64) -> History { - History { - id: HistoryId("018f011c-9a0a-7000-8000-000000000001".to_string()), - timestamp: OffsetDateTime::UNIX_EPOCH, - duration, - exit: 2, - command: "cargo test".to_string(), - cwd: "/repo".to_string(), - session: String::new(), - hostname: String::new(), - author: String::new(), - intent: None, - deleted_at: None, - } - } - - #[test] - fn formats_last_command() { - assert_eq!( - format_last_command(&history(1_234_000_000), UtcOffset::UTC), - "History ID: 018f011c-9a0a-7000-8000-000000000001 - `cargo test`\n[1970-01-01 00:00:00] (in `/repo`, exit 2), 1.234s" - ); - } - - #[test] - fn formats_history_search_result() { - assert_eq!( - format_history_search_result(3, &history(0), UtcOffset::UTC), - "## #3. (History ID: 018f011c-9a0a-7000-8000-000000000001):\n`cargo test`\n[1970-01-01 00:00:00] (in `/repo`, exit 2)\n" - ); - } -} diff --git a/crates/atuin-ai/src/lib.rs b/crates/atuin-ai/src/lib.rs deleted file mode 100644 index f972d4ff..00000000 --- a/crates/atuin-ai/src/lib.rs +++ /dev/null @@ -1,19 +0,0 @@ -pub mod commands; -pub(crate) mod context; -pub(crate) mod context_window; -pub(crate) mod diff; -pub(crate) mod driver; -pub(crate) mod edit_permissions; -pub(crate) mod event_serde; -pub(crate) mod file_tracker; -pub(crate) mod fsm; -pub(crate) mod history_format; -pub(crate) mod permissions; -pub(crate) mod session; -pub(crate) mod skills; -pub(crate) mod snapshots; -pub(crate) mod store; -pub(crate) mod stream; -pub(crate) mod tools; -pub(crate) mod tui; -pub(crate) mod user_context; diff --git a/crates/atuin-ai/src/permissions/check.rs b/crates/atuin-ai/src/permissions/check.rs deleted file mode 100644 index bb1eae0c..00000000 --- a/crates/atuin-ai/src/permissions/check.rs +++ /dev/null @@ -1,71 +0,0 @@ -use eyre::Result; - -use crate::{permissions::file::RuleFile, tools::PermissibleToolCall}; - -pub(crate) struct PermissionRequest<'t> { - call: &'t (dyn PermissibleToolCall + Send + Sync), -} - -impl<'t> PermissionRequest<'t> { - pub fn new(call: &'t (dyn PermissibleToolCall + Send + Sync)) -> Self { - Self { call } - } -} - -pub(crate) enum PermissionResponse { - Allowed, - Denied, - Ask, -} - -pub(crate) struct PermissionChecker { - files: Vec<RuleFile>, -} - -impl PermissionChecker { - pub fn new(files: Vec<RuleFile>) -> Self { - Self { files } - } - - pub async fn check<'t>( - &self, - request: &'t PermissionRequest<'t>, - ) -> Result<PermissionResponse> { - // Files are in order from deepest to shallowest, so we can stop at the first match. - // Within a file, the priority is ask -> deny -> allow - // The first rule type that matches is the one that applies, even if a later rule would contradict it. - for file in &self.files { - for rule in &file.content.permissions.ask { - if request.call.matches_rule(rule) { - tracing::debug!( - "Permission 'ASK' by rule: {} in file: {}", - rule, - file.path.display() - ); - return Ok(PermissionResponse::Ask); - } - } - - for rule in &file.content.permissions.deny { - if request.call.matches_rule(rule) { - tracing::debug!( - "Permission 'DENY' by rule: {} in file: {}", - rule, - file.path.display() - ); - return Ok(PermissionResponse::Denied); - } - } - - if request.call.all_covered_by(&file.content.permissions.allow) { - tracing::debug!( - "Permission 'ALLOW' by rules in file: {}", - file.path.display() - ); - return Ok(PermissionResponse::Allowed); - } - } - - Ok(PermissionResponse::Ask) - } -} diff --git a/crates/atuin-ai/src/permissions/file.rs b/crates/atuin-ai/src/permissions/file.rs deleted file mode 100644 index c973f55b..00000000 --- a/crates/atuin-ai/src/permissions/file.rs +++ /dev/null @@ -1,26 +0,0 @@ -use std::path::PathBuf; - -use serde::{Deserialize, Serialize}; - -use crate::permissions::rule::Rule; - -#[derive(Debug, Clone)] -pub(crate) struct RuleFile { - pub path: PathBuf, - pub content: RuleFileContent, -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub(crate) struct RuleFileContent { - pub permissions: RuleFilePermissions, -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub(crate) struct RuleFilePermissions { - #[serde(default)] - pub allow: Vec<Rule>, - #[serde(default)] - pub deny: Vec<Rule>, - #[serde(default)] - pub ask: Vec<Rule>, -} diff --git a/crates/atuin-ai/src/permissions/mod.rs b/crates/atuin-ai/src/permissions/mod.rs deleted file mode 100644 index fce64a51..00000000 --- a/crates/atuin-ai/src/permissions/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub(crate) mod check; -pub(crate) mod file; -pub(crate) mod resolver; -pub(crate) mod rule; -pub(crate) mod shell; -pub(crate) mod walker; -pub(crate) mod writer; diff --git a/crates/atuin-ai/src/permissions/resolver.rs b/crates/atuin-ai/src/permissions/resolver.rs deleted file mode 100644 index dc4f83bf..00000000 --- a/crates/atuin-ai/src/permissions/resolver.rs +++ /dev/null @@ -1,31 +0,0 @@ -use std::path::PathBuf; - -use eyre::Result; - -use crate::permissions::check::{PermissionChecker, PermissionRequest, PermissionResponse}; -use crate::permissions::walker::PermissionWalker; -use crate::permissions::writer; -use crate::tools::ClientToolCall; - -/// Resolves permissions for client tool calls by walking the filesystem to find permission files, -pub(crate) struct PermissionResolver { - checker: PermissionChecker, -} - -impl PermissionResolver { - /// Create a new resolver that walks from `working_dir` to root for project - /// permissions, and also checks the global permissions file. - pub async fn new(working_dir: PathBuf) -> Result<Self> { - let global_file = writer::global_permissions_path(); - let mut walker = PermissionWalker::new(working_dir, Some(global_file)); - walker.walk().await?; - let checker = PermissionChecker::new(walker.rules().to_owned()); - Ok(Self { checker }) - } - - /// Check whether `tool` is allowed, denied, or needs user confirmation. - pub async fn check(&self, tool: &ClientToolCall) -> Result<PermissionResponse> { - let request = PermissionRequest::new(tool); - self.checker.check(&request).await - } -} diff --git a/crates/atuin-ai/src/permissions/rule.rs b/crates/atuin-ai/src/permissions/rule.rs deleted file mode 100644 index 8fa3fa4a..00000000 --- a/crates/atuin-ai/src/permissions/rule.rs +++ /dev/null @@ -1,106 +0,0 @@ -use std::sync::OnceLock; - -use regex::Regex; -use serde::{Deserialize, Serialize}; - -static RULE_RE: OnceLock<Regex> = OnceLock::new(); - -#[derive(Debug, thiserror::Error)] -pub(crate) enum RuleError { - #[error("invalid rule format: {0}")] - InvalidRule(String), -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) struct Rule { - pub tool: String, - pub scope: Option<String>, -} - -impl std::fmt::Display for Rule { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self.scope.as_ref() { - Some(scope) => write!(f, "{}({})", self.tool, scope), - None => write!(f, "{}", self.tool), - } - } -} - -impl Serialize for Rule { - fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> - where - S: serde::Serializer, - { - serializer.serialize_str(&self.to_string()) - } -} - -impl<'de> Deserialize<'de> for Rule { - fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> - where - D: serde::Deserializer<'de>, - { - let s = String::deserialize(deserializer)?; - Self::try_from(s.as_str()).map_err(serde::de::Error::custom) - } -} -impl TryFrom<&str> for Rule { - type Error = RuleError; - - fn try_from(value: &str) -> Result<Self, Self::Error> { - let value = value.trim(); - let re = RULE_RE.get_or_init(|| Regex::new(r"^(\w+)(?:\((.*)\))?$").unwrap()); - let caps = re - .captures(value) - .ok_or(RuleError::InvalidRule(value.to_string()))?; - let tool = caps.get(1).unwrap().as_str().to_string(); - let scope = caps.get(2).map(|m| m.as_str().to_string()); - Ok(Rule { tool, scope }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_rule_try_from() { - assert_eq!( - Rule::try_from("Read").unwrap(), - Rule { - tool: "Read".to_string(), - scope: None - } - ); - assert_eq!( - Rule::try_from("Read(*)").unwrap(), - Rule { - tool: "Read".to_string(), - scope: Some("*".to_string()) - } - ); - assert_eq!( - Rule::try_from("Write(*.md)").unwrap(), - Rule { - tool: "Write".to_string(), - scope: Some("*.md".to_string()) - } - ); - assert_eq!( - Rule::try_from("Shell(git commit *)").unwrap(), - Rule { - tool: "Shell".to_string(), - scope: Some("git commit *".to_string()) - } - ); - assert_eq!( - Rule::try_from("Shell(echo ())").unwrap(), - Rule { - tool: "Shell".to_string(), - scope: Some("echo ()".to_string()) - } - ); - assert!(Rule::try_from("Shell(git commit *").is_err()); - assert!(Rule::try_from("Shell(git commit *)!").is_err()); - } -} diff --git a/crates/atuin-ai/src/permissions/shell.rs b/crates/atuin-ai/src/permissions/shell.rs deleted file mode 100644 index 29b9f5d8..00000000 --- a/crates/atuin-ai/src/permissions/shell.rs +++ /dev/null @@ -1,1335 +0,0 @@ -/// Extracted command info from a shell command string. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) struct ShellCommand { - /// The command name (first word), e.g. "git" - pub name: String, - /// The full invocation including arguments, e.g. "git commit -m msg" - pub full: String, -} - -/// A parsed shell command with all subcommands extracted. -#[derive(Debug)] -pub(crate) struct ParsedShellCommand { - pub subcommands: Vec<ShellCommand>, -} - -/// Supported shell families for parsing. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub(crate) enum ShellKind { - /// POSIX sh, bash, zsh — all share similar syntax - Posix, - /// fish shell - Fish, - /// nushell or unknown — fallback to word-level extraction - Other, -} - -impl ShellKind { - pub(crate) fn from_shell_name(name: &str) -> Self { - match name { - "bash" | "sh" | "zsh" | "dash" | "ksh" => Self::Posix, - "fish" => Self::Fish, - _ => Self::Other, - } - } -} - -/// Parse a shell command string and extract all subcommands. -pub(crate) fn parse_shell_command(code: &str, shell: ShellKind) -> ParsedShellCommand { - #[cfg(feature = "tree-sitter")] - match shell { - ShellKind::Posix => ts::parse_posix(code), - ShellKind::Fish => ts::parse_fish(code), - ShellKind::Other => parse_fallback(code), - } - - #[cfg(not(feature = "tree-sitter"))] - { - let _ = shell; - parse_fallback(code) - } -} - -// ──────────────────────────────────────────────────────────────── -// Tree-sitter parsers (POSIX + Fish) -// Disabled on platforms where tree-sitter doesn't cross-compile -// (e.g. Windows); falls back to word-level extraction. -// ──────────────────────────────────────────────────────────────── - -#[cfg(feature = "tree-sitter")] -mod ts { - use super::{ParsedShellCommand, ShellCommand, parse_fallback}; - use tree_sitter_lib::{Parser, Tree}; - - fn bash_parser() -> Parser { - let mut parser = Parser::new(); - parser - .set_language(&tree_sitter_bash::LANGUAGE.into()) - .expect("failed to set bash language"); - parser - } - - pub(super) fn parse_posix(code: &str) -> ParsedShellCommand { - let mut parser = bash_parser(); - let Some(tree) = parser.parse(code, None) else { - return parse_fallback(code); - }; - - let mut commands = Vec::new(); - walk_bash_tree(&tree, code.as_bytes(), &mut commands); - ParsedShellCommand { - subcommands: commands, - } - } - - /// Leaf node kinds that never contain nested commands. - const BASH_LEAVES: &[&str] = &[ - "command_name", - "word", - "number", - "simple_expansion", - "expansion", - "arithmetic_expansion", - "ansi_c_string", - "special_variable_name", - "variable_name", - "file_descriptor", - "heredoc_body", - "heredoc_start", - "regex", - "heredoc_redirect", - ]; - - fn walk_bash_tree(tree: &Tree, source: &[u8], commands: &mut Vec<ShellCommand>) { - walk_bash_node(tree.root_node(), source, commands); - } - - fn walk_bash_node( - node: tree_sitter_lib::Node, - source: &[u8], - commands: &mut Vec<ShellCommand>, - ) { - match node.kind() { - "command" => { - if let Some(cmd) = extract_bash_command(node, source) { - commands.push(cmd); - } - // Descend into all non-leaf children to find nested commands - // (e.g. command_substitution inside a string inside a command) - let mut cursor = node.walk(); - for child in node.children(&mut cursor) { - if !BASH_LEAVES.contains(&child.kind()) { - walk_bash_node(child, source, commands); - } - } - } - // Other nodes: descend into all children - _ => { - let mut cursor = node.walk(); - for child in node.children(&mut cursor) { - walk_bash_node(child, source, commands); - } - } - } - } - - /// Extract the full command string and name from a bash `command` node. - fn extract_bash_command(node: tree_sitter_lib::Node, source: &[u8]) -> Option<ShellCommand> { - // A `command` node has children like: - // variable_assignment* command_name argument* redirect* - // We want the command_name and all arguments (skipping assignments and redirects). - let mut name = None; - let mut name_start = None; - let mut arg_end = None; - - let mut cursor = node.walk(); - for child in node.children(&mut cursor) { - match child.kind() { - "command_name" => { - name = child.utf8_text(source).ok().map(|s| s.to_string()); - name_start = Some(child.start_byte()); - } - "word" - | "string" - | "raw_string" - | "concatenation" - | "number" - | "simple_expansion" - | "expansion" - | "arithmetic_expansion" - | "ansi_c_string" - | "process_substitution" => { - arg_end = Some(child.end_byte()); - } - _ => {} - } - } - - let name = name?; - let full = if let (Some(start), Some(end)) = (name_start, arg_end) { - std::str::from_utf8(&source[start..end]).ok()?.to_string() - } else { - name.clone() - }; - - Some(ShellCommand { name, full }) - } - - // ──────────────────────────────────────────────────────────────── - // Fish parser - // ──────────────────────────────────────────────────────────────── - - fn fish_parser() -> Parser { - let mut parser = Parser::new(); - parser - .set_language(&tree_sitter_fish::language()) - .expect("failed to set fish language"); - parser - } - - pub(super) fn parse_fish(code: &str) -> ParsedShellCommand { - let mut parser = fish_parser(); - let Some(tree) = parser.parse(code, None) else { - return parse_fallback(code); - }; - - let mut commands = Vec::new(); - walk_fish_tree(&tree, code.as_bytes(), &mut commands); - ParsedShellCommand { - subcommands: commands, - } - } - - const FISH_COMPOUND: &[&str] = &[ - "conditional_execution", - "pipe", - "job", - "command_substitution", - "block", - "for_statement", - "while_statement", - "if_statement", - "switch_statement", - "function_definition", - "begin_statement", - "redirected_statement", - ]; - - fn walk_fish_tree(tree: &Tree, source: &[u8], commands: &mut Vec<ShellCommand>) { - walk_fish_node(tree.root_node(), source, commands); - } - - fn walk_fish_node( - node: tree_sitter_lib::Node, - source: &[u8], - commands: &mut Vec<ShellCommand>, - ) { - match node.kind() { - "command" => { - if let Some(cmd) = extract_fish_command(node, source) { - commands.push(cmd); - } - // Still descend into compound children (e.g. command_substitution inside a command) - let mut cursor = node.walk(); - for child in node.children(&mut cursor) { - if FISH_COMPOUND.contains(&child.kind()) { - walk_fish_node(child, source, commands); - } - } - } - // Other nodes: descend into all children - _ => { - let mut cursor = node.walk(); - for child in node.children(&mut cursor) { - walk_fish_node(child, source, commands); - } - } - } - } - - fn extract_fish_command(node: tree_sitter_lib::Node, source: &[u8]) -> Option<ShellCommand> { - // In fish, a `command` node has: - // name (command_name or word) followed by arguments (word, string, etc.) - let mut name = None; - - let mut cursor = node.walk(); - for child in node.children(&mut cursor) { - match child.kind() { - "command_name" | "word" => { - let text = child.utf8_text(source).ok()?.to_string(); - if name.is_none() { - name = Some(text); - } - } - "string" - | "concatenation" - | "command_substitution" - | "escape_sequence" - | "double_quote_string" - | "single_quote_string" => {} - _ => {} - } - } - - let name = name?; - // Get the full text of the command node - let full = node.utf8_text(source).ok()?.trim().to_string(); - - Some(ShellCommand { name, full }) - } -} // mod ts - -// ──────────────────────────────────────────────────────────────── -// Fallback (word-level extraction for nushell / unknown shells) -// ──────────────────────────────────────────────────────────────── - -fn parse_fallback(code: &str) -> ParsedShellCommand { - // Simple heuristic: split by &&, ||, ;, | and take the first word of each segment. - // This is intentionally simple — for unknown shells we can't do better. - let mut commands = Vec::new(); - let mut segment = String::new(); - let mut chars = code.chars().peekable(); - - while let Some(c) = chars.next() { - match c { - ';' => { - push_segment(&mut segment, &mut commands); - } - '|' => { - if chars.peek() == Some(&'|') { - chars.next(); - } - push_segment(&mut segment, &mut commands); - } - '&' if chars.peek() == Some(&'&') => { - chars.next(); - push_segment(&mut segment, &mut commands); - } - _ => segment.push(c), - } - } - push_segment(&mut segment, &mut commands); - - ParsedShellCommand { - subcommands: commands, - } -} - -fn push_segment(segment: &mut String, commands: &mut Vec<ShellCommand>) { - let trimmed = segment.trim(); - if !trimmed.is_empty() - && let Some(name) = trimmed.split_whitespace().next() - { - commands.push(ShellCommand { - name: name.to_string(), - full: trimmed.to_string(), - }); - } - segment.clear(); -} - -// ──────────────────────────────────────────────────────────────── -// Scope matching -// ──────────────────────────────────────────────────────────────── - -/// Check if any of the extracted subcommands match the given scope pattern. -/// -/// Matching semantics depend on where the `*` wildcard appears: -/// - `*` alone — matches everything -/// - `ls *` (space before `*`) — matches `ls` and `ls -a` but not `lsof` -/// - `git commit *` — matches `git commit -m "msg"` (word boundary) -/// - `ls*` (no space before `*`) — matches `lsof`, `ls`, `ls -a` (prefix/glob) -/// - `rm` (no wildcard) — matches exactly `rm` -/// - `git * amend` — matches `git commit amend` (middle wildcard matches zero+ words) -/// -/// When `prefix_bare` is true, a bare pattern without wildcards (e.g. `rm`) -/// uses word-boundary prefix matching — `rm` matches `rm -rf /`. When false, -/// bare patterns require an exact match — `rm` only matches `rm`. -/// -/// Allow rules should pass `prefix_bare: false` (strict), while deny/ask rules -/// should pass `prefix_bare: true` (broad) so that denying `rm` also blocks -/// `rm -rf /`. -pub(crate) fn any_subcommand_matches( - subcommands: &[ShellCommand], - prefix_bare: bool, - scope: &str, -) -> bool { - let scope = scope.trim(); - - if scope.is_empty() || scope == "*" { - return true; - } - - if let Some(prefix) = scope.strip_suffix(" *") { - // Word-boundary matching: `ls *` matches `ls` and `ls -a` but not `lsof` - return subcommands.iter().any(|cmd| { - if prefix.is_empty() { - return true; - } - let cmd_words: Vec<&str> = cmd.full.split_whitespace().collect(); - let prefix_words: Vec<&str> = prefix.split_whitespace().collect(); - cmd_words.len() >= prefix_words.len() - && cmd_words[..prefix_words.len()] == prefix_words[..] - }); - } - - if let Some(prefix) = scope.strip_suffix('*') { - // Prefix/glob matching: `ls*` matches `lsof`, `ls`, etc. - return subcommands.iter().any(|cmd| cmd.full.starts_with(prefix)); - } - - if scope.contains('*') { - // Middle wildcard: `git * amend` — each `*` matches zero or more words - return subcommands - .iter() - .any(|cmd| scope_matches_words(scope, cmd.full.split_whitespace().collect())); - } - - // No wildcard: exact or prefix depending on context - let scope_words: Vec<&str> = scope.split_whitespace().collect(); - subcommands.iter().any(|cmd| { - let cmd_words: Vec<&str> = cmd.full.split_whitespace().collect(); - if prefix_bare { - cmd_words.len() >= scope_words.len() - && cmd_words[..scope_words.len()] == scope_words[..] - } else { - cmd_words == scope_words - } - }) -} - -/// Match a scope pattern containing `*` wildcards against a sequence of words. -/// Each `*` matches zero or more words. Consecutive `*` collapse into one. -fn scope_matches_words(scope: &str, words: Vec<&str>) -> bool { - let parts: Vec<&str> = scope.split('*').collect(); - if parts.len() == 1 { - // No wildcard (shouldn't reach here, but handle it) - let scope_words: Vec<&str> = scope.split_whitespace().collect(); - return words.len() >= scope_words.len() && words[..scope_words.len()] == scope_words[..]; - } - - // Each segment between * is a sequence of literal words that must appear in order. - // Walk through `words` consuming segments left to right. - let mut word_idx = 0; - - for (i, part) in parts.iter().enumerate() { - let segment_words: Vec<&str> = part.split_whitespace().collect(); - if segment_words.is_empty() { - continue; - } - - // Find the segment words starting from word_idx - if i == 0 { - // First segment must match at the start - if words.len() < segment_words.len() - || words[..segment_words.len()] != segment_words[..] - { - return false; - } - word_idx = segment_words.len(); - } else if i == parts.len() - 1 { - // Last segment must match at the end - if words.len() - word_idx < segment_words.len() { - return false; - } - let start = words.len() - segment_words.len(); - return words[start..] == segment_words[..]; - } else { - // Middle segment: find it anywhere after word_idx - let found = find_subslice(&words[word_idx..], &segment_words); - match found { - Some(idx) => word_idx += idx + segment_words.len(), - None => return false, - } - } - } - - true -} - -/// Find the first occurrence of `needle` as a contiguous subsequence in `haystack`. -fn find_subslice(haystack: &[&str], needle: &[&str]) -> Option<usize> { - if needle.is_empty() { - return Some(0); - } - if haystack.len() < needle.len() { - return None; - } - (0..=haystack.len() - needle.len()).find(|&i| haystack[i..i + needle.len()] == needle[..]) -} - -// ──────────────────────────────────────────────────────────────── -// Tests -// ──────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - - fn names(cmds: &[ShellCommand]) -> Vec<&str> { - cmds.iter().map(|c| c.name.as_str()).collect() - } - - fn fulls(cmds: &[ShellCommand]) -> Vec<&str> { - cmds.iter().map(|c| c.full.as_str()).collect() - } - - #[test] - fn simple_command() { - let result = parse_shell_command("ls -la /tmp", ShellKind::Posix); - assert_eq!(names(&result.subcommands), vec!["ls"]); - assert_eq!(fulls(&result.subcommands), vec!["ls -la /tmp"]); - } - - #[test] - fn pipeline() { - let result = parse_shell_command("cat file.txt | grep foo | wc -l", ShellKind::Posix); - assert_eq!(names(&result.subcommands), vec!["cat", "grep", "wc"]); - } - - #[test] - fn command_chaining() { - let result = parse_shell_command("git add . && git commit -m 'hi'", ShellKind::Posix); - assert_eq!(names(&result.subcommands), vec!["git", "git"]); - assert_eq!( - fulls(&result.subcommands), - vec!["git add .", "git commit -m 'hi'"] - ); - } - - #[cfg(feature = "tree-sitter")] - #[test] - fn command_substitution() { - let result = parse_shell_command("echo $(git rev-parse HEAD)", ShellKind::Posix); - let n = names(&result.subcommands); - assert!(n.contains(&"echo"), "should contain echo: {n:?}"); - assert!(n.contains(&"git"), "should contain git: {n:?}"); - } - - #[cfg(feature = "tree-sitter")] - #[test] - fn backtick_substitution() { - let result = parse_shell_command("echo `date`", ShellKind::Posix); - let n = names(&result.subcommands); - assert!(n.contains(&"echo"), "should contain echo: {n:?}"); - assert!(n.contains(&"date"), "should contain date: {n:?}"); - } - - #[cfg(feature = "tree-sitter")] - #[test] - fn subshell() { - let result = parse_shell_command("(cd /tmp && ls)", ShellKind::Posix); - assert_eq!(names(&result.subcommands), vec!["cd", "ls"]); - } - - #[test] - fn semicolon_separated() { - let result = parse_shell_command("echo hello; echo world", ShellKind::Posix); - assert_eq!(names(&result.subcommands), vec!["echo", "echo"]); - } - - #[cfg(feature = "tree-sitter")] - #[test] - fn for_loop() { - let result = parse_shell_command("for f in *.txt; do cat $f; done", ShellKind::Posix); - assert!(names(&result.subcommands).contains(&"cat")); - } - - #[cfg(feature = "tree-sitter")] - #[test] - fn if_statement() { - let result = parse_shell_command( - "if [ -f foo ]; then cat foo; else echo nope; fi", - ShellKind::Posix, - ); - let n = names(&result.subcommands); - assert!(n.contains(&"cat"), "should contain cat: {n:?}"); - assert!(n.contains(&"echo"), "should contain echo: {n:?}"); - } - - #[test] - fn scope_matching_wildcard() { - let commands = vec![ - ShellCommand { - name: "git".into(), - full: "git commit -m msg".into(), - }, - ShellCommand { - name: "npm".into(), - full: "npm test".into(), - }, - ]; - assert!(any_subcommand_matches(&commands, true, "*")); - } - - #[test] - fn scope_matching_prefix() { - let commands = vec![ - ShellCommand { - name: "git".into(), - full: "git commit -m msg".into(), - }, - ShellCommand { - name: "npm".into(), - full: "npm test".into(), - }, - ]; - assert!(any_subcommand_matches(&commands, true, "git commit *")); - assert!(!any_subcommand_matches(&commands, true, "git push *")); - assert!(!any_subcommand_matches(&commands, true, "git push")); - assert!(any_subcommand_matches(&commands, true, "npm *")); - assert!(any_subcommand_matches(&commands, true, "npm test")); - - // prefix_bare=true: bare "git commit" prefix-matches "git commit -m msg" (deny/ask) - assert!(any_subcommand_matches(&commands, true, "git commit")); - // prefix_bare=false: bare "git commit" does NOT match "git commit -m msg" (allow) - assert!(!any_subcommand_matches(&commands, false, "git commit")); - // Exact match works in both modes when command has no extra args - assert!(any_subcommand_matches(&commands, false, "npm test")); - } - - #[test] - fn scope_word_boundary_vs_glob() { - let commands = vec![ - ShellCommand { - name: "ls".into(), - full: "ls -a".into(), - }, - ShellCommand { - name: "lsof".into(), - full: "lsof -i :3000".into(), - }, - ]; - // `ls *` — word boundary: matches `ls -a` but not `lsof` - assert!(any_subcommand_matches(&commands, true, "ls *")); - assert!(!any_subcommand_matches(&commands, true, "cat *")); - assert!(any_subcommand_matches(&commands, true, "lsof *")); - - // `ls*` — glob/prefix: matches both `ls -a` and `lsof` - assert!(any_subcommand_matches(&commands, true, "ls*")); - } - - #[test] - fn scope_exact_match() { - let commands = vec![ShellCommand { - name: "ls".into(), - full: "ls".into(), - }]; - assert!(any_subcommand_matches(&commands, true, "ls")); - assert!(!any_subcommand_matches(&commands, true, "cat")); - } - - #[cfg(feature = "tree-sitter")] - #[test] - fn nested_substitution() { - let result = parse_shell_command( - "echo \"Result: $(git log --oneline | head -1)\"", - ShellKind::Posix, - ); - let n = names(&result.subcommands); - assert!(n.contains(&"echo"), "should contain echo: {n:?}"); - assert!(n.contains(&"git"), "should contain git: {n:?}"); - assert!(n.contains(&"head"), "should contain head: {n:?}"); - } - - #[test] - fn fallback_splits_correctly() { - let result = parse_shell_command("ls && cat foo || echo fail", ShellKind::Other); - let n = names(&result.subcommands); - assert!(n.contains(&"ls"), "should contain ls: {n:?}"); - assert!(n.contains(&"cat"), "should contain cat: {n:?}"); - assert!(n.contains(&"echo"), "should contain echo: {n:?}"); - } - - #[test] - fn fish_simple_command() { - let result = parse_shell_command("ls -la /tmp", ShellKind::Fish); - assert_eq!(names(&result.subcommands), vec!["ls"]); - } - - #[test] - fn fish_conditional() { - let result = parse_shell_command("git add .; and git commit -m hi", ShellKind::Fish); - let n = names(&result.subcommands); - assert!(n.contains(&"git"), "should contain git: {n:?}"); - } - - #[cfg(feature = "tree-sitter")] - #[test] - fn fish_command_substitution() { - let result = parse_shell_command("echo (date)", ShellKind::Fish); - let n = names(&result.subcommands); - assert!(n.contains(&"echo"), "should contain echo: {n:?}"); - assert!(n.contains(&"date"), "should contain date: {n:?}"); - } - - #[cfg(feature = "tree-sitter")] - #[test] - fn variable_assignment_excluded() { - let result = parse_shell_command("FOO=bar ls -la /tmp", ShellKind::Posix); - assert_eq!(names(&result.subcommands), vec!["ls"]); - assert_eq!(fulls(&result.subcommands), vec!["ls -la /tmp"]); - } - - #[cfg(feature = "tree-sitter")] - #[test] - fn variable_assignment_multiple() { - let result = parse_shell_command("A=1 B=2 git status", ShellKind::Posix); - assert_eq!(names(&result.subcommands), vec!["git"]); - assert_eq!(fulls(&result.subcommands), vec!["git status"]); - } - - #[test] - fn fallback_double_ampersand_and_pipe_pipe() { - let result = parse_shell_command("ls && cat foo || echo fail", ShellKind::Other); - assert_eq!(names(&result.subcommands), vec!["ls", "cat", "echo"]); - assert_eq!( - fulls(&result.subcommands), - vec!["ls", "cat foo", "echo fail"] - ); - } - - #[test] - fn fallback_pipe_without_double() { - let result = parse_shell_command("ls | grep foo", ShellKind::Other); - assert_eq!(names(&result.subcommands), vec!["ls", "grep"]); - assert_eq!(fulls(&result.subcommands), vec!["ls", "grep foo"]); - } - - #[test] - fn scope_middle_wildcard() { - let commands = vec![ShellCommand { - name: "git".into(), - full: "git commit -m amend".into(), - }]; - assert!(any_subcommand_matches(&commands, true, "git * amend")); - assert!(any_subcommand_matches( - &commands, - true, - "git commit * amend" - )); - assert!(!any_subcommand_matches(&commands, true, "git push * amend")); - } - - #[test] - fn scope_middle_wildcard_zero_words() { - let commands = vec![ShellCommand { - name: "git".into(), - full: "git commit".into(), - }]; - // `*` matches zero words, so `git * commit` should match `git commit` - assert!(any_subcommand_matches(&commands, true, "git * commit")); - } - - #[test] - fn scope_leading_wildcard() { - let commands = vec![ShellCommand { - name: "docker".into(), - full: "docker run --rm alpine".into(), - }]; - assert!(any_subcommand_matches(&commands, true, "* alpine")); - assert!(!any_subcommand_matches(&commands, true, "* ubuntu")); - } - - #[test] - fn scope_multiple_wildcards() { - let commands = vec![ShellCommand { - name: "git".into(), - full: "git rebase -i HEAD~5".into(), - }]; - assert!(any_subcommand_matches(&commands, true, "git * -i * HEAD~5")); - assert!(!any_subcommand_matches( - &commands, - true, - "git * -i * HEAD~10" - )); - } -} - -#[cfg(all(test, feature = "tree-sitter"))] -mod adversarial { - use super::*; - - fn cmd_names(cmds: &[ShellCommand]) -> Vec<&str> { - cmds.iter().map(|c| c.name.as_str()).collect() - } - - /// Helper: assert that parsing POSIX extracts all expected command names - fn assert_posix(code: &str, expected: &[&str]) { - let result = parse_shell_command(code, ShellKind::Posix); - let mut got: Vec<&str> = result.subcommands.iter().map(|c| c.name.as_str()).collect(); - got.sort(); - let mut want: Vec<&str> = expected.to_vec(); - want.sort(); - assert_eq!( - got, want, - "POSIX parse of {:?}:\n got: {:?}\n want: {:?}", - code, got, want - ); - } - - fn assert_fish(code: &str, expected: &[&str]) { - let result = parse_shell_command(code, ShellKind::Fish); - let mut got: Vec<&str> = result.subcommands.iter().map(|c| c.name.as_str()).collect(); - got.sort(); - let mut want: Vec<&str> = expected.to_vec(); - want.sort(); - assert_eq!( - got, want, - "Fish parse of {:?}:\n got: {:?}\n want: {:?}", - code, got, want - ); - } - - // ──────────────────────────────────────────────────────────── - // Level 1: Basic compounds - // ──────────────────────────────────────────────────────────── - - #[test] - fn a01_triple_chain() { - assert_posix("a && b && c", &["a", "b", "c"]); - } - - #[test] - fn a02_or_chain() { - assert_posix("a || b || c", &["a", "b", "c"]); - } - - #[test] - fn a03_mixed_chain() { - assert_posix("a && b || c && d", &["a", "b", "c", "d"]); - } - - #[test] - fn a04_long_pipeline() { - assert_posix( - "cat foo | grep bar | awk '{print $1}' | sort | uniq -c", - &["cat", "grep", "awk", "sort", "uniq"], - ); - } - - #[test] - fn a05_semicolons() { - assert_posix("a; b; c; d", &["a", "b", "c", "d"]); - } - - // ──────────────────────────────────────────────────────────── - // Level 2: Nested substitution - // ──────────────────────────────────────────────────────────── - - #[test] - fn a06_nested_dollar() { - assert_posix( - "echo $(basename $(dirname /foo/bar))", - &["echo", "basename", "dirname"], - ); - } - - #[test] - fn a07_deeply_nested() { - // 4 nested echos, all should be extracted - assert_posix( - "echo $(echo $(echo $(echo deep)))", - &["echo", "echo", "echo", "echo"], - ); - } - - #[test] - fn a08_backtick_in_echo() { - assert_posix("echo `hostname`", &["echo", "hostname"]); - } - - #[test] - fn a09_mixed_substitutions() { - assert_posix("echo $(date) `uname`", &["echo", "date", "uname"]); - } - - // ──────────────────────────────────────────────────────────── - // Level 3: Subshells and grouping - // ──────────────────────────────────────────────────────────── - - #[test] - fn a10_subshell_chain() { - assert_posix("(cd /tmp && ls -la)", &["cd", "ls"]); - } - - #[test] - fn a11_nested_subshells() { - assert_posix("( (inner_cmd) )", &["inner_cmd"]); - } - - #[test] - fn a12_brace_group() { - assert_posix("{ cd /tmp; ls; }", &["cd", "ls"]); - } - - // ──────────────────────────────────────────────────────────── - // Level 4: Variable assignments - // ──────────────────────────────────────────────────────────── - - #[test] - fn a13_single_var_assignment() { - let result = parse_shell_command("FOO=bar ls", ShellKind::Posix); - assert_eq!(cmd_names(&result.subcommands), &["ls"]); - assert_eq!(result.subcommands[0].full, "ls"); - } - - #[test] - fn a14_multiple_var_assignments() { - let result = parse_shell_command("A=1 B=2 C=3 git status", ShellKind::Posix); - assert_eq!(cmd_names(&result.subcommands), &["git"]); - assert_eq!(result.subcommands[0].full, "git status"); - } - - #[test] - fn a15_var_assignment_no_command() { - // Variable assignment only — no command to extract - assert_posix("FOO=bar", &[]); - } - - #[test] - fn a16_var_assignment_in_pipeline() { - assert_posix("FOO=bar ls | BAZ=qux grep foo", &["ls", "grep"]); - } - - // ──────────────────────────────────────────────────────────── - // Level 5: Control flow - // ──────────────────────────────────────────────────────────── - - #[test] - fn a17_if_then_else() { - assert_posix( - "if [ -f foo ]; then cat foo; else echo missing; fi", - &["cat", "echo"], - ); - } - - #[test] - fn a18_elif_chain() { - // Two cat commands (then + elif branch), one echo (else branch). - // [ is part of the test_condition, not extracted as a command. - assert_posix( - "if [ -f a ]; then cat a; elif [ -f b ]; then cat b; else echo none; fi", - &["cat", "cat", "echo"], - ); - } - - #[test] - fn a19_for_loop() { - assert_posix("for f in *.txt; do cat \"$f\"; done", &["cat"]); - } - - #[test] - fn a20_while_loop() { - // read in the condition is a real command - assert_posix( - "while read line; do echo \"$line\"; done < input.txt", - &["echo", "read"], - ); - } - - #[test] - fn f07_if_statement() { - // test in if-condition is a real command - assert_fish( - "if test -f foo; cat foo; else; echo missing; end", - &["cat", "echo", "test"], - ); - } - - #[test] - fn f09_while_loop() { - // `true` in the condition is a real command - assert_fish( - "while true; echo tick; sleep 1; end", - &["echo", "sleep", "true"], - ); - } - - // ──────────────────────────────────────────────────────────── - // Level 6: Redirections - // ──────────────────────────────────────────────────────────── - - #[test] - fn a23_redirect_out() { - assert_posix("ls > output.txt", &["ls"]); - } - - #[test] - fn a24_redirect_append() { - assert_posix("ls >> output.txt 2>&1", &["ls"]); - } - - #[test] - fn a25_here_string() { - assert_posix("grep foo <<< \"hello world\"", &["grep"]); - } - - #[test] - fn a26_redirect_in_pipeline() { - assert_posix("cat < input.txt | sort | uniq", &["cat", "sort", "uniq"]); - } - - #[test] - fn a27_process_substitution() { - assert_posix( - "diff <(sort a.txt) <(sort b.txt)", - &["diff", "sort", "sort"], - ); - } - - // ──────────────────────────────────────────────────────────── - // Level 7: Function definitions - // ──────────────────────────────────────────────────────────── - - #[test] - fn a28_function_def() { - assert_posix("foo() { echo hello; }", &["echo"]); - } - - #[test] - fn a29_function_with_subshell() { - assert_posix( - "build() { cargo build && cargo test; }", - &["cargo", "cargo"], - ); - } - - // ──────────────────────────────────────────────────────────── - // Level 8: Edge cases — empties, weird quoting - // ──────────────────────────────────────────────────────────── - - #[test] - fn a30_empty_string() { - let result = parse_shell_command("", ShellKind::Posix); - assert!(result.subcommands.is_empty()); - } - - #[test] - fn a31_whitespace_only() { - let result = parse_shell_command(" \t \n ", ShellKind::Posix); - assert!(result.subcommands.is_empty()); - } - - #[test] - fn a32_single_command_no_args() { - assert_posix("ls", &["ls"]); - } - - #[test] - fn a33_command_with_single_quotes() { - assert_posix("echo 'hello world'", &["echo"]); - } - - #[test] - fn a34_command_with_double_quotes() { - assert_posix("echo \"hello world\"", &["echo"]); - } - - #[test] - fn a35_escaped_spaces() { - // ls\ -la is a single word in bash, not "ls" with flag "-la" - assert_posix("ls\\ -la", &["ls\\ -la"]); - } - - #[test] - fn a36_command_with_dollar_var() { - assert_posix("echo $HOME/.bashrc", &["echo"]); - } - - // ──────────────────────────────────────────────────────────── - // Level 9: Background jobs and coproc - // ──────────────────────────────────────────────────────────── - - #[test] - fn a37_background_job() { - assert_posix("sleep 10 &", &["sleep"]); - } - - #[test] - fn a38_background_chain() { - assert_posix("sleep 10 && echo done &", &["sleep", "echo"]); - } - - // ──────────────────────────────────────────────────────────── - // Level 10: Real-world complex commands - // ──────────────────────────────────────────────────────────── - - #[test] - fn a39_docker_build_and_run() { - assert_posix( - "docker build -t app . && docker run --rm app npm test", - &["docker", "docker"], - ); - } - - #[test] - fn a40_git_rebase_interactive() { - assert_posix( - "GIT_SEQUENCE_EDITOR=\"sed -i 's/pick/reword/'\" git rebase -i HEAD~5", - &["git"], - ); - } - - #[test] - fn a41_find_with_exec() { - // tree-sitter-bash does not parse -exec body as commands — only `find` is extracted. - // This is a known limitation: args to -exec/-execdir are opaque to the parser. - assert_posix("find . -name '*.rs' -exec grep -l 'unsafe' {} +", &["find"]); - } - - #[test] - fn a42_curl_pipe_sh() { - assert_posix( - "curl -sSL https://example.com/install.sh | bash", - &["curl", "bash"], - ); - } - - #[test] - fn a43_xargs() { - assert_posix("find . -name '*.tmp' | xargs rm -f", &["find", "xargs"]); - } - - #[test] - fn a44_npm_script_chain() { - assert_posix( - "npm run build && npm run test && npm run lint", - &["npm", "npm", "npm"], - ); - } - - #[test] - fn a45_make_with_redirect() { - assert_posix( - "make -j$(nproc) 2>&1 | tee build.log", - &["make", "nproc", "tee"], - ); - } - - #[test] - fn a46_sudo_chain() { - assert_posix("sudo apt update && sudo apt upgrade -y", &["sudo", "sudo"]); - } - - #[test] - fn a47_here_doc_with_subcommand() { - assert_posix("cat <<EOF\nhello $(whoami)\nEOF", &["cat", "whoami"]); - } - - #[test] - fn a48_eval_with_command() { - assert_posix("eval \"echo hello\"", &["eval"]); - } - - #[test] - fn a49_exec_replace() { - assert_posix("exec ls", &["exec"]); - } - - #[test] - fn a50_source_script() { - assert_posix("source ~/.bashrc", &["source"]); - } - - // ──────────────────────────────────────────────────────────── - // Level 11: Fish-specific tests - // ──────────────────────────────────────────────────────────── - - #[test] - fn f01_simple() { - assert_fish("ls -la /tmp", &["ls"]); - } - - #[test] - fn f02_pipe() { - assert_fish("cat foo | grep bar | sort", &["cat", "grep", "sort"]); - } - - #[test] - fn f03_and() { - assert_fish("git add .; and git commit -m hi", &["git", "git"]); - } - - #[test] - fn f04_or() { - assert_fish("test -f foo; or echo missing", &["test", "echo"]); - } - - #[test] - fn f04_not() { - // fish parses `not test -f foo` — `not` is a modifier, `test` is the command - assert_fish("not test -f foo", &["test"]); - } - - #[test] - fn f05_command_substitution() { - assert_fish("echo (date)", &["echo", "date"]); - } - - #[test] - fn f06_nested_substitution() { - assert_fish( - "echo (basename (dirname /foo/bar))", - &["echo", "basename", "dirname"], - ); - } - - #[test] - fn f06_begin_end() { - assert_fish("begin; ls; echo done; end", &["ls", "echo"]); - } - - #[test] - fn f10_switch() { - // Two echo commands, one per case branch - assert_fish( - "switch $x; case foo; echo foo; case bar; echo bar; end", - &["echo", "echo"], - ); - } - - #[test] - fn f08_for_loop() { - assert_fish("for f in *.txt; cat $f; end", &["cat"]); - } - - #[test] - fn a21_case_statement() { - // Two echo branches - assert_posix( - "case $x in foo) echo foo;; bar) echo bar;; esac", - &["echo", "echo"], - ); - } - - #[test] - fn f11_function_def() { - assert_fish("function greet; echo hello $argv; end", &["echo"]); - } - - #[test] - fn f12_redirect() { - assert_fish("ls > output.txt", &["ls"]); - } - - #[test] - fn f13_redirect_append() { - assert_fish("ls >> output.txt", &["ls"]); - } - - #[test] - fn f14_here_string() { - assert_fish("grep foo <<< \"hello\"", &["grep"]); - } - - #[test] - fn f15_curl_pipe() { - assert_fish( - "curl -sSL https://example.com/install.sh | bash", - &["curl", "bash"], - ); - } - - #[test] - fn f16_double_ampersand() { - assert_fish("git add . && git commit -m hi", &["git", "git"]); - } - - #[test] - fn f17_double_pipe() { - assert_fish("test -f foo || echo missing", &["test", "echo"]); - } - - #[test] - fn f18_empty() { - let result = parse_shell_command("", ShellKind::Fish); - assert!(result.subcommands.is_empty()); - } - - #[test] - fn f19_whitespace() { - let result = parse_shell_command(" ", ShellKind::Fish); - assert!(result.subcommands.is_empty()); - } - - // ──────────────────────────────────────────────────────────── - // Level 12: Scope matching adversarial - // ──────────────────────────────────────────────────────────── - - #[test] - fn s01_empty_scope() { - let commands = vec![ShellCommand { - name: "ls".into(), - full: "ls".into(), - }]; - // Empty scope matches everything (nothing to constrain) - assert!(any_subcommand_matches(&commands, true, "")); - } - - #[test] - fn s03_only_wildcard_space_star() { - let commands = vec![ShellCommand { - name: "ls".into(), - full: "ls".into(), - }]; - // " *" with empty prefix = match anything - assert!(any_subcommand_matches(&commands, true, " *")); - } - - #[test] - fn s04_glob_matches_empty() { - let commands = vec![ShellCommand { - name: "ls".into(), - full: "ls".into(), - }]; - // `ls*` matches `ls` (prefix match with nothing after) - assert!(any_subcommand_matches(&commands, true, "ls*")); - } - - #[test] - fn s05_middle_wildcard_empty_match() { - // `git * commit` matches `git commit` (* = zero words) - let commands = vec![ShellCommand { - name: "git".into(), - full: "git commit".into(), - }]; - assert!(any_subcommand_matches(&commands, true, "git * commit")); - } - - #[test] - fn s06_consecutive_wildcards() { - // `git ** commit` should behave like `git * commit` - let commands = vec![ShellCommand { - name: "git".into(), - full: "git commit".into(), - }]; - assert!(any_subcommand_matches(&commands, true, "git ** commit")); - } - - #[test] - fn s07_case_sensitivity() { - let commands = vec![ShellCommand { - name: "LS".into(), - full: "LS -la".into(), - }]; - // Wildcard: case matters - assert!(!any_subcommand_matches(&commands, true, "ls *")); - assert!(any_subcommand_matches(&commands, true, "LS *")); - // prefix_bare=true: bare "LS" prefix-matches "LS -la" - assert!(!any_subcommand_matches(&commands, true, "ls")); - assert!(any_subcommand_matches(&commands, true, "LS")); - // prefix_bare=false: bare "LS" does NOT match "LS -la" - assert!(!any_subcommand_matches(&commands, false, "LS")); - } - - #[test] - fn s08_multi_word_exact_no_subcommand() { - // `git commit` should not match `git commit-amend` - let commands = vec![ShellCommand { - name: "git".into(), - full: "git commit-amend".into(), - }]; - assert!(!any_subcommand_matches(&commands, true, "git commit")); - } -} diff --git a/crates/atuin-ai/src/permissions/walker.rs b/crates/atuin-ai/src/permissions/walker.rs deleted file mode 100644 index 3bda01c3..00000000 --- a/crates/atuin-ai/src/permissions/walker.rs +++ /dev/null @@ -1,121 +0,0 @@ -use std::path::{Path, PathBuf}; - -use eyre::Result; -use tokio::task::JoinSet; - -use crate::permissions::file::{RuleFile, RuleFileContent}; - -#[derive(Debug)] -struct FoundRuleFile { - depth: usize, - file: RuleFile, -} - -pub(crate) struct PermissionWalker { - start: PathBuf, - /// Direct path to the global permissions file (e.g. `~/.config/atuin/permissions.ai.toml`). - global_permissions_file: Option<PathBuf>, - rules: Vec<RuleFile>, -} - -impl PermissionWalker { - pub fn new(start: PathBuf, global_permissions_file: Option<PathBuf>) -> Self { - Self { - start, - global_permissions_file, - rules: Vec::new(), - } - } - - pub fn rules(&self) -> &[RuleFile] { - &self.rules - } - - /// Walks the filesystem starting from the start path and collecting permission files along the way. - /// Walks to the root, then checks the global permissions file, if any. - pub async fn walk(&mut self) -> Result<()> { - let dirs_to_check: Vec<PathBuf> = self.start.ancestors().map(PathBuf::from).collect(); - let dir_count = dirs_to_check.len(); - - let mut set: JoinSet<Result<Option<FoundRuleFile>>> = JoinSet::new(); - - for (index, path) in dirs_to_check.into_iter().enumerate() { - set.spawn(async move { - match check_dir_for_permissions(&path).await { - Ok(Some(rule_file)) => Ok(Some(FoundRuleFile { - depth: index, - file: rule_file, - })), - Ok(None) => Ok(None), - Err(e) => Err(e), - } - }); - } - - // Check the global file separately (it's a direct file path, not a dir/.atuin/ pattern) - if let Some(global_path) = self.global_permissions_file.clone() { - let depth = dir_count; // sorts after all directory-walk entries - set.spawn(async move { - match load_permissions_file(&global_path).await { - Ok(Some(rule_file)) => Ok(Some(FoundRuleFile { - depth, - file: rule_file, - })), - Ok(None) => Ok(None), - Err(e) => Err(e), - } - }); - } - - let capacity = dir_count + usize::from(self.global_permissions_file.is_some()); - let mut found = Vec::with_capacity(capacity); - while let Some(result) = set.join_next().await { - let result = result?; // JoinErrors result in failure to walk the filesystem - - match result { - Ok(Some(FoundRuleFile { depth, file })) => { - found.push((depth, file)); - } - Ok(None) => { - continue; - } - Err(e) => { - tracing::error!( - "Error while walking filesystem for permissions check; skipping: {}", - e - ); - continue; - } - } - } - // join_next() returns in order of completion, not order of spawn - found.sort_by_key(|(depth, _)| *depth); - self.rules = found.into_iter().map(|(_, file)| file).collect(); - - Ok(()) - } -} - -/// Checks a directory for `.atuin/permissions.ai.toml` and returns the RuleFile if found. -async fn check_dir_for_permissions(path: &Path) -> Result<Option<RuleFile>> { - let file_path = path.join(".atuin").join("permissions.ai.toml"); - load_permissions_file(&file_path).await -} - -/// Load a permissions file from an exact path. Returns None if the file doesn't exist. -async fn load_permissions_file(file_path: &Path) -> Result<Option<RuleFile>> { - if !tokio::fs::try_exists(file_path).await? { - return Ok(None); - } - - let raw = tokio::fs::read_to_string(file_path).await?; - let content: RuleFileContent = toml::from_str(&raw)?; - - // Use the file's parent as the rule file path (for logging/debugging) - let path = file_path - .parent() - .map(Path::to_path_buf) - .unwrap_or_else(|| file_path.to_path_buf()); - - Ok(Some(RuleFile { path, content })) -} diff --git a/crates/atuin-ai/src/permissions/writer.rs b/crates/atuin-ai/src/permissions/writer.rs deleted file mode 100644 index ffef404e..00000000 --- a/crates/atuin-ai/src/permissions/writer.rs +++ /dev/null @@ -1,199 +0,0 @@ -use std::path::Path; - -use eyre::Result; - -use crate::permissions::rule::Rule; - -/// Whether a rule should be added to the allow or deny list. -#[derive(Debug, Clone)] -#[allow(dead_code)] -pub(crate) enum RuleDisposition { - Allow, - Deny, -} - -/// Write a permission rule to a `permissions.ai.toml` file. -/// -/// If the file doesn't exist it is created (along with parent directories). -/// If it does exist, `toml_edit` is used to append the rule while preserving -/// existing formatting and comments. -/// -/// **Not concurrent-safe.** The read-modify-write cycle is not atomic. In the -/// current UI this is fine — the Select widget serializes permission decisions — -/// but callers should not invoke this concurrently for the same file. -pub(crate) async fn write_rule( - file_path: &Path, - rule: &Rule, - disposition: RuleDisposition, -) -> Result<()> { - let content = if tokio::fs::try_exists(file_path).await.unwrap_or(false) { - tokio::fs::read_to_string(file_path).await? - } else { - String::new() - }; - - let mut doc: toml_edit::DocumentMut = content.parse()?; - - // Ensure [permissions] table exists - if !doc.contains_key("permissions") { - doc["permissions"] = toml_edit::Item::Table(toml_edit::Table::new()); - } - - let key = match disposition { - RuleDisposition::Allow => "allow", - RuleDisposition::Deny => "deny", - }; - - // Use as_table_like_mut so both standard and inline tables work. - let permissions = doc["permissions"] - .as_table_like_mut() - .ok_or_else(|| eyre::eyre!("[permissions] is not a table"))?; - - // Get or create the array - if !permissions.contains_key(key) { - permissions.insert(key, toml_edit::Item::Value(toml_edit::Array::new().into())); - } - - let array = permissions - .get_mut(key) - .and_then(|item| item.as_value_mut()) - .and_then(|v| v.as_array_mut()) - .ok_or_else(|| eyre::eyre!("permissions.{key} is not an array"))?; - - // Don't add duplicates - let rule_str = rule.to_string(); - let already_present = array.iter().any(|v| v.as_str() == Some(&rule_str)); - if !already_present { - array.push(rule_str); - } - - // Write back, creating parent directories as needed - if let Some(parent) = file_path.parent() { - tokio::fs::create_dir_all(parent).await?; - } - tokio::fs::write(file_path, doc.to_string()).await?; - - Ok(()) -} - -/// Build the path to the project-level permissions file. -/// `project_root` is typically a git root or the current working directory. -pub(crate) fn project_permissions_path(project_root: &Path) -> std::path::PathBuf { - project_root.join(".atuin").join("permissions.ai.toml") -} - -/// Build the path to the global permissions file (sibling of atuin config). -pub(crate) fn global_permissions_path() -> std::path::PathBuf { - atuin_common::utils::config_dir().join("permissions.ai.toml") -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn creates_new_file_with_allow_rule() { - let dir = tempfile::tempdir().unwrap(); - let file = dir.path().join("permissions.ai.toml"); - let rule = Rule { - tool: "AtuinHistory".to_string(), - scope: None, - }; - - write_rule(&file, &rule, RuleDisposition::Allow) - .await - .unwrap(); - - let content = tokio::fs::read_to_string(&file).await.unwrap(); - assert!(content.contains("[permissions]")); - assert!(content.contains(r#""AtuinHistory""#)); - } - - #[tokio::test] - async fn appends_to_existing_file() { - let dir = tempfile::tempdir().unwrap(); - let file = dir.path().join("permissions.ai.toml"); - let existing = r#"# My permissions -[permissions] -allow = ["Read"] -"#; - tokio::fs::write(&file, existing).await.unwrap(); - - let rule = Rule { - tool: "AtuinHistory".to_string(), - scope: None, - }; - write_rule(&file, &rule, RuleDisposition::Allow) - .await - .unwrap(); - - let content = tokio::fs::read_to_string(&file).await.unwrap(); - // Comment preserved - assert!(content.contains("# My permissions")); - // Both rules present - assert!(content.contains(r#""Read""#)); - assert!(content.contains(r#""AtuinHistory""#)); - } - - #[tokio::test] - async fn does_not_duplicate_existing_rule() { - let dir = tempfile::tempdir().unwrap(); - let file = dir.path().join("permissions.ai.toml"); - let existing = r#"[permissions] -allow = ["AtuinHistory"] -"#; - tokio::fs::write(&file, existing).await.unwrap(); - - let rule = Rule { - tool: "AtuinHistory".to_string(), - scope: None, - }; - write_rule(&file, &rule, RuleDisposition::Allow) - .await - .unwrap(); - - let content = tokio::fs::read_to_string(&file).await.unwrap(); - // Should appear exactly once - assert_eq!(content.matches("AtuinHistory").count(), 1); - } - - #[tokio::test] - async fn handles_inline_table_permissions() { - let dir = tempfile::tempdir().unwrap(); - let file = dir.path().join("permissions.ai.toml"); - // Inline table style — as_table_mut() would return None for this - let existing = r#"permissions = { allow = ["Read"] } -"#; - tokio::fs::write(&file, existing).await.unwrap(); - - let rule = Rule { - tool: "AtuinHistory".to_string(), - scope: None, - }; - write_rule(&file, &rule, RuleDisposition::Allow) - .await - .unwrap(); - - let content = tokio::fs::read_to_string(&file).await.unwrap(); - assert!(content.contains(r#""Read""#)); - assert!(content.contains(r#""AtuinHistory""#)); - } - - #[tokio::test] - async fn writes_deny_rule() { - let dir = tempfile::tempdir().unwrap(); - let file = dir.path().join("permissions.ai.toml"); - let rule = Rule { - tool: "Shell".to_string(), - scope: None, - }; - - write_rule(&file, &rule, RuleDisposition::Deny) - .await - .unwrap(); - - let content = tokio::fs::read_to_string(&file).await.unwrap(); - assert!(content.contains("deny")); - assert!(content.contains(r#""Shell""#)); - } -} diff --git a/crates/atuin-ai/src/session.rs b/crates/atuin-ai/src/session.rs deleted file mode 100644 index 848330fc..00000000 --- a/crates/atuin-ai/src/session.rs +++ /dev/null @@ -1,509 +0,0 @@ -//! Session service abstraction and manager. -//! -//! The TUI interacts with sessions through `SessionManager`, which wraps a -//! `SessionService` trait. Today the only implementation is `LocalSessionService` -//! (direct SQLite). When the daemon owns session state, a gRPC-backed -//! implementation can be swapped in without changing the TUI code. - -use async_trait::async_trait; -use eyre::Result; - -use crate::event_serde; -use crate::store::{AiSessionStore, StoredEvent, StoredSession}; -use crate::tui::ConversationEvent; - -// --------------------------------------------------------------------------- -// Trait -// --------------------------------------------------------------------------- - -#[async_trait] -pub(crate) trait SessionService: Send + Sync { - async fn create_session( - &self, - id: &str, - directory: Option<&str>, - git_root: Option<&str>, - ) -> Result<StoredSession>; - - async fn find_resumable( - &self, - directory: Option<&str>, - git_root: Option<&str>, - max_age_secs: i64, - ) -> Result<Option<StoredSession>>; - - async fn load_events(&self, session_id: &str) -> Result<Vec<StoredEvent>>; - - async fn append_event( - &self, - session_id: &str, - event_id: &str, - parent_id: Option<&str>, - invocation_id: &str, - event_type: &str, - event_data: &str, - ) -> Result<()>; - - async fn update_server_session_id( - &self, - session_id: &str, - server_session_id: &str, - ) -> Result<()>; - - async fn archive(&self, session_id: &str) -> Result<()>; - - async fn get_metadata(&self, session_id: &str, key: &str) -> Result<Option<String>>; - async fn set_metadata(&self, session_id: &str, key: &str, value: &str) -> Result<()>; -} - -// --------------------------------------------------------------------------- -// Local implementation (direct SQLite) -// --------------------------------------------------------------------------- - -pub(crate) struct LocalSessionService { - store: AiSessionStore, -} - -impl LocalSessionService { - pub async fn open(path: impl AsRef<std::path::Path>, timeout: f64) -> Result<Self> { - let store = AiSessionStore::new(path, timeout).await?; - Ok(Self { store }) - } -} - -#[async_trait] -impl SessionService for LocalSessionService { - async fn create_session( - &self, - id: &str, - directory: Option<&str>, - git_root: Option<&str>, - ) -> Result<StoredSession> { - self.store.create_session(id, directory, git_root).await - } - - async fn find_resumable( - &self, - directory: Option<&str>, - git_root: Option<&str>, - max_age_secs: i64, - ) -> Result<Option<StoredSession>> { - self.store - .find_resumable_session(directory, git_root, max_age_secs) - .await - } - - async fn load_events(&self, session_id: &str) -> Result<Vec<StoredEvent>> { - self.store.load_events(session_id).await - } - - async fn append_event( - &self, - session_id: &str, - event_id: &str, - parent_id: Option<&str>, - invocation_id: &str, - event_type: &str, - event_data: &str, - ) -> Result<()> { - self.store - .append_event( - session_id, - event_id, - parent_id, - invocation_id, - event_type, - event_data, - ) - .await - } - - async fn update_server_session_id( - &self, - session_id: &str, - server_session_id: &str, - ) -> Result<()> { - self.store - .update_server_session_id(session_id, server_session_id) - .await - } - - async fn archive(&self, session_id: &str) -> Result<()> { - self.store.archive_session(session_id).await - } - - async fn get_metadata(&self, session_id: &str, key: &str) -> Result<Option<String>> { - self.store.get_metadata(session_id, key).await - } - - async fn set_metadata(&self, session_id: &str, key: &str, value: &str) -> Result<()> { - self.store.set_metadata(session_id, key, value).await - } -} - -// --------------------------------------------------------------------------- -// SessionManager -// --------------------------------------------------------------------------- - -/// High-level session manager used by the TUI dispatch loop. -/// -/// Owns the current session identity, tracks what has been persisted, and -/// handles serialization between `ConversationEvent` and the storage format. -pub(crate) struct SessionManager { - service: Box<dyn SessionService>, - session_id: String, - invocation_id: String, - /// Number of events already persisted. `persist_events` only writes the - /// delta from this index onward. - persisted_count: usize, - /// ID of the last persisted event, used as `parent_id` for the next one. - head_id: Option<String>, - /// Stored for creating a new session on `/new`. - directory: Option<String>, - git_root: Option<String>, - /// Whether the session row has been created in the database. New sessions - /// are deferred until the first event is persisted, so empty sessions - /// don't linger and get spuriously resumed. - persisted_to_db: bool, -} - -impl SessionManager { - /// Create a new session manager. The database row is deferred until the - /// first event is persisted. - pub fn create_new( - service: Box<dyn SessionService>, - directory: Option<&str>, - git_root: Option<&str>, - ) -> Self { - let session_id = atuin_common::utils::uuid_v7().to_string(); - let invocation_id = atuin_common::utils::uuid_v7().to_string(); - - Self { - service, - session_id, - invocation_id, - persisted_count: 0, - head_id: None, - directory: directory.map(String::from), - git_root: git_root.map(String::from), - persisted_to_db: false, - } - } - - /// Load an existing session and return a manager for it, along with the - /// deserialized conversation events, the server session ID, and the - /// timestamp of the last stored event. - pub async fn resume( - service: Box<dyn SessionService>, - stored: &StoredSession, - ) -> Result<( - Self, - Vec<ConversationEvent>, - Option<String>, - Option<i64>, - String, - )> { - let invocation_id = atuin_common::utils::uuid_v7().to_string(); - let stored_events = service.load_events(&stored.id).await?; - - let mut events = Vec::with_capacity(stored_events.len()); - let mut last_event_id = None; - let mut last_event_ts = None; - for se in &stored_events { - events.push(event_serde::deserialize_event( - &se.event_type, - &se.event_data, - )?); - last_event_id = Some(se.id.clone()); - last_event_ts = Some(se.created_at); - } - - let manager = Self { - service, - session_id: stored.id.clone(), - invocation_id: invocation_id.clone(), - persisted_count: events.len(), - head_id: last_event_id, - directory: stored.directory.clone(), - git_root: stored.git_root.clone(), - persisted_to_db: true, - }; - - Ok(( - manager, - events, - stored.server_session_id.clone(), - last_event_ts, - invocation_id, - )) - } - - /// Ensure the session row exists in the database. - async fn ensure_persisted(&mut self) -> Result<()> { - if !self.persisted_to_db { - self.service - .create_session( - &self.session_id, - self.directory.as_deref(), - self.git_root.as_deref(), - ) - .await?; - self.persisted_to_db = true; - } - Ok(()) - } - - /// Persist any new events since the last persist call. - pub async fn persist_events(&mut self, events: &[ConversationEvent]) -> Result<()> { - if self.persisted_count >= events.len() { - return Ok(()); - } - self.ensure_persisted().await?; - for event in &events[self.persisted_count..] { - let event_id = atuin_common::utils::uuid_v7().to_string(); - let (event_type, event_data) = event_serde::serialize_event(event); - - self.service - .append_event( - &self.session_id, - &event_id, - self.head_id.as_deref(), - &self.invocation_id, - &event_type, - &event_data, - ) - .await?; - - self.head_id = Some(event_id); - self.persisted_count += 1; - } - Ok(()) - } - - /// Persist the server session ID if it has changed. - pub async fn persist_server_session_id(&mut self, server_session_id: &str) -> Result<()> { - self.ensure_persisted().await?; - self.service - .update_server_session_id(&self.session_id, server_session_id) - .await - } - - /// Archive the current session (for `/new` command). - #[allow(dead_code)] // used in tests; will be used by dispatch for `/new` - pub async fn archive(&self) -> Result<()> { - if self.persisted_to_db { - self.service.archive(&self.session_id).await?; - } - Ok(()) - } - - /// Archive the current session and reset to a fresh one. - /// The new session row is deferred until the first event is persisted. - pub async fn archive_and_reset(&mut self) -> Result<()> { - if self.persisted_to_db { - self.service.archive(&self.session_id).await?; - } - - self.session_id = atuin_common::utils::uuid_v7().to_string(); - self.invocation_id = atuin_common::utils::uuid_v7().to_string(); - self.persisted_count = 0; - self.head_id = None; - self.persisted_to_db = false; - Ok(()) - } - - #[allow(dead_code)] // used in tests; part of public API for dispatch/daemon - pub fn session_id(&self) -> &str { - &self.session_id - } - - #[allow(dead_code)] // used in tests; part of public API for dispatch/daemon - pub fn invocation_id(&self) -> &str { - &self.invocation_id - } - - /// Read a metadata value for the current session. - pub async fn get_metadata(&self, key: &str) -> Result<Option<String>> { - if !self.persisted_to_db { - return Ok(None); - } - self.service.get_metadata(&self.session_id, key).await - } - - /// Write a metadata value for the current session. - pub async fn set_metadata(&mut self, key: &str, value: &str) -> Result<()> { - self.ensure_persisted().await?; - self.service - .set_metadata(&self.session_id, key, value) - .await - } -} - -#[cfg(test)] -mod tests { - use super::*; - - async fn test_service() -> Box<dyn SessionService> { - let svc = LocalSessionService::open("sqlite::memory:", 2.0) - .await - .unwrap(); - Box::new(svc) - } - - #[tokio::test] - async fn test_create_new_and_persist() { - let service = test_service().await; - let mut mgr = SessionManager::create_new(service, Some("/tmp"), None); - - let events = vec![ - ConversationEvent::UserMessage { - content: "hello".to_string(), - }, - ConversationEvent::Text { - content: "hi there".to_string(), - }, - ]; - - mgr.persist_events(&events).await.unwrap(); - - // Persist again with no new events — should be a no-op - mgr.persist_events(&events).await.unwrap(); - } - - #[tokio::test] - async fn test_create_and_resume() { - // Create a session and persist some events - let svc = LocalSessionService::open("sqlite::memory:", 2.0) - .await - .unwrap(); - - let session_id = atuin_common::utils::uuid_v7().to_string(); - svc.create_session(&session_id, Some("/project"), Some("/project")) - .await - .unwrap(); - - let events = vec![ - ConversationEvent::UserMessage { - content: "how do I list files?".to_string(), - }, - ConversationEvent::Text { - content: "Use ls".to_string(), - }, - ConversationEvent::ToolCall { - id: "tc_1".to_string(), - name: "suggest_command".to_string(), - input: serde_json::json!({"command": "ls -la"}), - }, - ]; - - // Persist events manually through the service - let inv_id = "inv-1"; - let mut parent: Option<String> = None; - for event in &events { - let eid = atuin_common::utils::uuid_v7().to_string(); - let (etype, edata) = event_serde::serialize_event(event); - svc.append_event(&session_id, &eid, parent.as_deref(), inv_id, &etype, &edata) - .await - .unwrap(); - parent = Some(eid); - } - - svc.update_server_session_id(&session_id, "srv-abc") - .await - .unwrap(); - - // Now find and resume the session with a fresh service connection - let stored = svc - .find_resumable(Some("/project"), Some("/project"), 3600) - .await - .unwrap() - .expect("should find session"); - - let (mut mgr, loaded_events, server_sid, last_ts, _invocation_id) = - SessionManager::resume(Box::new(svc), &stored) - .await - .unwrap(); - - assert_eq!(loaded_events.len(), 3); - assert_eq!(server_sid.as_deref(), Some("srv-abc")); - assert_ne!(mgr.invocation_id(), inv_id, "new invocation ID on resume"); - assert!(last_ts.is_some(), "should have a last event timestamp"); - - // Persisting again with the same events should be a no-op - mgr.persist_events(&loaded_events).await.unwrap(); - } - - #[tokio::test] - async fn test_incremental_persist() { - let service = test_service().await; - let mut mgr = SessionManager::create_new(service, Some("/tmp"), None); - - let mut events = vec![ConversationEvent::UserMessage { - content: "first".to_string(), - }]; - mgr.persist_events(&events).await.unwrap(); - - // Add more events and persist again — only the new ones should be written - events.push(ConversationEvent::Text { - content: "response".to_string(), - }); - events.push(ConversationEvent::UserMessage { - content: "second".to_string(), - }); - mgr.persist_events(&events).await.unwrap(); - - // Verify by loading through a fresh service (can't easily here since - // the service is moved, but the lack of errors confirms correctness) - } - - #[tokio::test] - async fn test_archive() { - let svc = LocalSessionService::open("sqlite::memory:", 2.0) - .await - .unwrap(); - - let mgr = SessionManager::create_new(Box::new(svc), Some("/tmp"), None); - - mgr.archive().await.unwrap(); - } - - #[tokio::test] - async fn test_persist_server_session_id() { - let service = test_service().await; - let mut mgr = SessionManager::create_new(service, Some("/tmp"), None); - - mgr.persist_server_session_id("srv-123").await.unwrap(); - } - - #[tokio::test] - async fn test_parent_chain_integrity() { - // Verify that persisted events form a proper parent chain - let svc = LocalSessionService::open("sqlite::memory:", 2.0) - .await - .unwrap(); - - let session_id = { - let mut mgr = SessionManager::create_new(Box::new(svc), Some("/tmp"), None); - - let events = vec![ - ConversationEvent::UserMessage { - content: "a".to_string(), - }, - ConversationEvent::Text { - content: "b".to_string(), - }, - ConversationEvent::UserMessage { - content: "c".to_string(), - }, - ]; - mgr.persist_events(&events).await.unwrap(); - mgr.session_id().to_string() - }; - - // Re-open the store and load events to verify the chain - // (Can't do this with in-memory DB since it's gone, but the - // lack of FK constraint violations during persist confirms the - // parent_id values are valid) - let _ = session_id; - } -} diff --git a/crates/atuin-ai/src/skills/frontmatter.rs b/crates/atuin-ai/src/skills/frontmatter.rs deleted file mode 100644 index 759dffcc..00000000 --- a/crates/atuin-ai/src/skills/frontmatter.rs +++ /dev/null @@ -1,233 +0,0 @@ -//! YAML frontmatter parsing for `SKILL.md` files. -//! -//! Extracts the YAML block between `---` delimiters and parses it with -//! `yaml-rust2`. Returns the parsed fields and the byte offset where the -//! body begins (after the closing `---`). - -use yaml_rust2::YamlLoader; - -/// Parsed frontmatter fields from a `SKILL.md` file. -#[derive(Debug, Default)] -pub(crate) struct Frontmatter { - pub name: Option<String>, - pub description: Option<String>, - pub disable_model_invocation: bool, -} - -/// Result of splitting a skill file into frontmatter + body. -#[derive(Debug)] -pub(crate) struct ParsedSkillFile { - pub frontmatter: Frontmatter, - /// Everything after the closing `---` delimiter. - pub body: String, -} - -/// Parse a `SKILL.md` file's content into frontmatter and body. -/// -/// If no frontmatter delimiters are found, all content is treated as body -/// with default frontmatter. -pub(crate) fn parse(content: &str) -> ParsedSkillFile { - let Some((yaml_str, body)) = split_frontmatter(content) else { - return ParsedSkillFile { - frontmatter: Frontmatter::default(), - body: content.to_string(), - }; - }; - - let frontmatter = match YamlLoader::load_from_str(yaml_str) { - Ok(docs) if !docs.is_empty() => extract_fields(&docs[0]), - Ok(_) => Frontmatter::default(), - Err(e) => { - tracing::warn!("Failed to parse skill frontmatter: {e}"); - Frontmatter::default() - } - }; - - ParsedSkillFile { frontmatter, body } -} - -/// Split content on `---` delimiters. Returns `(yaml_str, body)` or `None` -/// if frontmatter is not present. -fn split_frontmatter(content: &str) -> Option<(&str, String)> { - let trimmed = content.trim_start(); - - // Must start with `---` - if !trimmed.starts_with("---") { - return None; - } - - // Find the end of the opening delimiter line - let after_open = trimmed.get(3..)?.trim_start_matches(|c: char| c != '\n'); - let after_open = after_open.strip_prefix('\n').unwrap_or(after_open); - - // Find the closing `---` - let close_pos = after_open - .lines() - .enumerate() - .find(|(_, line)| line.trim() == "---") - .map(|(i, _)| { - after_open - .lines() - .take(i) - .map(|l| l.len() + 1) // +1 for newline - .sum::<usize>() - })?; - - let yaml_str = &after_open[..close_pos]; - let rest = &after_open[close_pos..]; - // Skip the closing `---` line - let body = rest - .strip_prefix("---") - .unwrap_or(rest) - .trim_start_matches(|c: char| c != '\n'); - let body = body.strip_prefix('\n').unwrap_or(body); - - Some((yaml_str, body.to_string())) -} - -fn extract_fields(doc: &yaml_rust2::Yaml) -> Frontmatter { - use yaml_rust2::Yaml; - - let name = match &doc["name"] { - Yaml::String(s) => Some(s.clone()), - _ => None, - }; - - let description = match &doc["description"] { - Yaml::String(s) => Some(s.trim().to_string()), - _ => None, - }; - - let disable_model_invocation = match &doc["disable-model-invocation"] { - Yaml::Boolean(b) => *b, - Yaml::String(s) => matches!(s.as_str(), "true" | "yes" | "1"), - _ => false, - }; - - Frontmatter { - name, - description, - disable_model_invocation, - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn basic_frontmatter() { - let content = "\ ---- -name: my-skill -description: A test skill -disable-model-invocation: true ---- - -Body content here. -"; - let parsed = parse(content); - assert_eq!(parsed.frontmatter.name.as_deref(), Some("my-skill")); - assert_eq!( - parsed.frontmatter.description.as_deref(), - Some("A test skill") - ); - assert!(parsed.frontmatter.disable_model_invocation); - assert_eq!(parsed.body.trim(), "Body content here."); - } - - #[test] - fn multiline_folded_description() { - let content = "\ ---- -name: release -description: > - Orchestrate a multi-step release — version bumping, changelog - generation, PR creation, tagging, and publishing. -disable-model-invocation: true ---- - -# Release steps -"; - let parsed = parse(content); - assert_eq!(parsed.frontmatter.name.as_deref(), Some("release")); - let desc = parsed.frontmatter.description.unwrap(); - assert!(desc.contains("Orchestrate a multi-step release")); - assert!(desc.contains("publishing")); - assert!(parsed.frontmatter.disable_model_invocation); - assert!(parsed.body.contains("# Release steps")); - } - - #[test] - fn no_frontmatter() { - let content = "Just a body with no frontmatter."; - let parsed = parse(content); - assert!(parsed.frontmatter.name.is_none()); - assert!(parsed.frontmatter.description.is_none()); - assert!(!parsed.frontmatter.disable_model_invocation); - assert_eq!(parsed.body, content); - } - - #[test] - fn empty_frontmatter() { - let content = "\ ---- ---- - -Body after empty frontmatter. -"; - let parsed = parse(content); - assert!(parsed.frontmatter.name.is_none()); - assert!(parsed.frontmatter.description.is_none()); - assert_eq!(parsed.body.trim(), "Body after empty frontmatter."); - } - - #[test] - fn missing_fields_use_defaults() { - let content = "\ ---- -name: partial ---- - -Some body. -"; - let parsed = parse(content); - assert_eq!(parsed.frontmatter.name.as_deref(), Some("partial")); - assert!(parsed.frontmatter.description.is_none()); - assert!(!parsed.frontmatter.disable_model_invocation); - } - - #[test] - fn unknown_fields_ignored() { - let content = "\ ---- -name: my-skill -future-field: some value -another: 42 ---- - -Body. -"; - let parsed = parse(content); - assert_eq!(parsed.frontmatter.name.as_deref(), Some("my-skill")); - } - - #[test] - fn body_with_triple_dashes() { - let content = "\ ---- -name: test ---- - -Some body. - ---- - -More body after a horizontal rule. -"; - let parsed = parse(content); - assert_eq!(parsed.frontmatter.name.as_deref(), Some("test")); - assert!(parsed.body.contains("Some body.")); - assert!(parsed.body.contains("More body after a horizontal rule.")); - } -} diff --git a/crates/atuin-ai/src/skills/mod.rs b/crates/atuin-ai/src/skills/mod.rs deleted file mode 100644 index 36b3a2ae..00000000 --- a/crates/atuin-ai/src/skills/mod.rs +++ /dev/null @@ -1,468 +0,0 @@ -//! AI skill discovery, metadata, and lazy loading. -//! -//! Skills are markdown files (`SKILL.md`) with YAML frontmatter that define -//! reusable instructions for the LLM. Only skill metadata (name + description) -//! is sent to the server; full content is loaded on demand via `load_skill`. - -mod frontmatter; -pub(crate) mod walker; - -use std::path::Path; - -use eyre::{Result, eyre}; - -use crate::user_context::interpolate; - -/// Per-skill description truncation limit (before budget calculation). -const MAX_DESCRIPTION_LEN: usize = 1024; - -/// Default total character budget for skill descriptions sent to the server. -const DEFAULT_DESCRIPTION_BUDGET: usize = 9992; - -/// JSON overhead per skill entry: `{"name":"","description":""},` ≈ 30 chars. -const PER_ENTRY_OVERHEAD: usize = 30; - -/// Metadata for a discovered skill. Produced at discovery time from -/// frontmatter only — the body is not read until `load()`. -#[derive(Debug, Clone)] -pub(crate) struct SkillDescriptor { - pub name: String, - pub description: String, - pub source_path: std::path::PathBuf, - pub disable_model_invocation: bool, -} - -/// A name + description pair ready to serialize into the request payload. -#[derive(Debug, Clone, serde::Serialize)] -pub(crate) struct SkillSummary { - pub name: String, - pub description: String, -} - -/// Holds discovered skills and provides lookup, budget packing, and loading. -#[derive(Debug, Clone)] -pub(crate) struct SkillRegistry { - skills: Vec<SkillDescriptor>, -} - -impl SkillRegistry { - /// Discover skills from project and global directories. - pub async fn discover(project_root: Option<&Path>) -> Self { - let global_dir = walker::global_skills_dir(); - let project_dir = project_root.map(walker::project_skills_dir); - - Self::discover_from_dirs(project_dir.as_deref(), &global_dir).await - } - - /// Discover skills from explicit directory paths. Useful for testing. - pub async fn discover_from_dirs( - project_skills_dir: Option<&Path>, - global_skills_dir: &Path, - ) -> Self { - let raw_files = walker::discover(project_skills_dir, global_skills_dir).await; - - let mut skills = Vec::new(); - let mut seen_names = std::collections::HashSet::new(); - - for raw in raw_files { - let parsed = frontmatter::parse(&raw.content); - let fm = parsed.frontmatter; - - let name = fm.name.unwrap_or_else(|| sanitize_name(&raw.dir_name)); - - // Deduplicate: first seen wins (project before global) - if !seen_names.insert(name.clone()) { - continue; - } - - let description = fm - .description - .or_else(|| first_paragraph(&parsed.body)) - .unwrap_or_default(); - - skills.push(SkillDescriptor { - name, - description, - source_path: raw.path, - disable_model_invocation: fm.disable_model_invocation, - }); - } - - Self { skills } - } - - /// Create an empty registry. - #[cfg(test)] - pub fn empty() -> Self { - Self { skills: Vec::new() } - } - - /// Look up a skill by name. - pub fn get(&self, name: &str) -> Option<&SkillDescriptor> { - self.skills.iter().find(|s| s.name == name) - } - - /// All discovered skills. - pub fn all(&self) -> &[SkillDescriptor] { - &self.skills - } - - /// Whether any non-disabled skills exist (determines capability advertisement). - #[cfg(test)] - pub fn has_server_visible_skills(&self) -> bool { - self.skills.iter().any(|s| !s.disable_model_invocation) - } - - /// Pack skill descriptions into the server payload under a character budget. - /// - /// Returns the summaries that fit plus an optional overflow message. - pub fn server_skills(&self) -> (Vec<SkillSummary>, Option<String>) { - self.server_skills_with_budget(DEFAULT_DESCRIPTION_BUDGET) - } - - pub fn server_skills_with_budget(&self, budget: usize) -> (Vec<SkillSummary>, Option<String>) { - let eligible: Vec<&SkillDescriptor> = self - .skills - .iter() - .filter(|s| !s.disable_model_invocation) - .collect(); - - let mut summaries = Vec::new(); - let mut used = 0; - let mut overflow_names = Vec::new(); - - for skill in &eligible { - let truncated_desc = truncate_description(&skill.description, MAX_DESCRIPTION_LEN); - let entry_size = skill.name.len() + truncated_desc.len() + PER_ENTRY_OVERHEAD; - - if used + entry_size > budget && !summaries.is_empty() { - overflow_names.push(skill.name.as_str()); - continue; - } - - used += entry_size; - summaries.push(SkillSummary { - name: skill.name.clone(), - description: truncated_desc, - }); - } - - let overflow = if overflow_names.is_empty() { - None - } else { - Some(format!( - "{} additional skill(s) not listed due to size limits: {}", - overflow_names.len(), - overflow_names.join(", ") - )) - }; - - (summaries, overflow) - } - - /// Load a skill's full body content, with argument substitution and - /// `!`` interpolation applied. - /// - /// `$ARGUMENTS` in the body is replaced with the provided arguments before - /// shell interpolation runs. If `$ARGUMENTS` does not appear in the body - /// and arguments were provided, they are appended as `ARGUMENTS: <value>`. - pub async fn load(&self, name: &str, shell: &str, arguments: Option<&str>) -> Result<String> { - let skill = self - .get(name) - .ok_or_else(|| eyre!("Unknown skill: {name}"))?; - - let content = tokio::fs::read_to_string(&skill.source_path).await?; - let parsed = frontmatter::parse(&content); - let body = parsed.body; - - if body.trim().is_empty() { - return Ok(format!("(Skill '{name}' has no body content)")); - } - - let body = substitute_arguments(&body, arguments); - - Ok(interpolate::interpolate(&body, shell).await) - } -} - -/// Replace `$ARGUMENTS` placeholders in skill body text. -/// -/// If `$ARGUMENTS` appears in the body, all occurrences are replaced with the -/// argument string (or empty string if none). If `$ARGUMENTS` does not appear -/// and arguments were provided, they are appended on a new line. -fn substitute_arguments(body: &str, arguments: Option<&str>) -> String { - let args = arguments.unwrap_or(""); - - if body.contains("$ARGUMENTS") { - return body.replace("$ARGUMENTS", args); - } - - if !args.is_empty() { - return format!("{body}\n\nARGUMENTS: {args}"); - } - - body.to_string() -} - -/// Sanitize a directory name into a valid skill name. -fn sanitize_name(name: &str) -> String { - name.chars() - .map(|c| { - if c.is_ascii_alphanumeric() || c == '-' { - c - } else { - '-' - } - }) - .collect::<String>() - .to_lowercase() -} - -/// Extract the first non-empty paragraph from markdown body text. -fn first_paragraph(body: &str) -> Option<String> { - let trimmed = body.trim(); - if trimmed.is_empty() { - return None; - } - - let para: String = trimmed - .lines() - .take_while(|line| !line.trim().is_empty()) - .collect::<Vec<_>>() - .join(" "); - - let para = para.trim().to_string(); - if para.is_empty() { None } else { Some(para) } -} - -/// Truncate a description to `max_len` characters, adding ellipsis if cut. -fn truncate_description(desc: &str, max_len: usize) -> String { - if desc.len() <= max_len { - return desc.to_string(); - } - let mut end = max_len.saturating_sub(3); - // Avoid splitting a multi-byte char - while !desc.is_char_boundary(end) && end > 0 { - end -= 1; - } - format!("{}...", &desc[..end]) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn sanitize_name_basic() { - assert_eq!(sanitize_name("My Skill"), "my-skill"); - assert_eq!(sanitize_name("deploy_prod"), "deploy-prod"); - assert_eq!(sanitize_name("code-review"), "code-review"); - } - - #[test] - fn first_paragraph_extraction() { - assert_eq!( - first_paragraph("Hello world\nSecond line\n\nNew paragraph"), - Some("Hello world Second line".to_string()) - ); - assert_eq!(first_paragraph(""), None); - assert_eq!(first_paragraph("\n\n"), None); - assert_eq!( - first_paragraph("Single line"), - Some("Single line".to_string()) - ); - } - - #[test] - fn truncate_description_short() { - assert_eq!(truncate_description("short", 100), "short"); - } - - #[test] - fn substitute_arguments_replaces_placeholder() { - let body = "Deploy $ARGUMENTS to production."; - assert_eq!( - substitute_arguments(body, Some("patch")), - "Deploy patch to production." - ); - } - - #[test] - fn substitute_arguments_multiple_occurrences() { - let body = "Run $ARGUMENTS then verify $ARGUMENTS worked."; - assert_eq!( - substitute_arguments(body, Some("migrate")), - "Run migrate then verify migrate worked." - ); - } - - #[test] - fn substitute_arguments_appends_when_no_placeholder() { - let body = "Do the thing."; - let result = substitute_arguments(body, Some("extra context")); - assert!(result.starts_with("Do the thing.")); - assert!(result.contains("ARGUMENTS: extra context")); - } - - #[test] - fn substitute_arguments_no_args_no_placeholder() { - let body = "Just a body."; - assert_eq!(substitute_arguments(body, None), "Just a body."); - } - - #[test] - fn substitute_arguments_no_args_clears_placeholder() { - let body = "Deploy $ARGUMENTS to production."; - assert_eq!(substitute_arguments(body, None), "Deploy to production."); - } - - #[test] - fn truncate_description_long() { - let long = "a".repeat(600); - let result = truncate_description(&long, 512); - assert!(result.len() <= 512); - assert!(result.ends_with("...")); - } - - #[test] - fn budget_packing() { - let registry = SkillRegistry { - skills: vec![ - SkillDescriptor { - name: "a".to_string(), - description: "Short desc".to_string(), - source_path: "a/SKILL.md".into(), - disable_model_invocation: false, - }, - SkillDescriptor { - name: "b".to_string(), - description: "Another desc".to_string(), - source_path: "b/SKILL.md".into(), - disable_model_invocation: false, - }, - ], - }; - - let (summaries, overflow) = registry.server_skills_with_budget(4096); - assert_eq!(summaries.len(), 2); - assert!(overflow.is_none()); - } - - #[test] - fn budget_overflow() { - let registry = SkillRegistry { - skills: vec![ - SkillDescriptor { - name: "first".to_string(), - description: "x".repeat(200), - source_path: "a/SKILL.md".into(), - disable_model_invocation: false, - }, - SkillDescriptor { - name: "second".to_string(), - description: "y".repeat(200), - source_path: "b/SKILL.md".into(), - disable_model_invocation: false, - }, - ], - }; - - // Budget only fits one - let (summaries, overflow) = registry.server_skills_with_budget(260); - assert_eq!(summaries.len(), 1); - assert_eq!(summaries[0].name, "first"); - let overflow = overflow.unwrap(); - assert!(overflow.contains("second")); - assert!(overflow.contains("1 additional")); - } - - #[test] - fn disabled_skills_excluded_from_server() { - let registry = SkillRegistry { - skills: vec![ - SkillDescriptor { - name: "visible".to_string(), - description: "I show up".to_string(), - source_path: "a/SKILL.md".into(), - disable_model_invocation: false, - }, - SkillDescriptor { - name: "hidden".to_string(), - description: "I don't".to_string(), - source_path: "b/SKILL.md".into(), - disable_model_invocation: true, - }, - ], - }; - - let (summaries, _) = registry.server_skills(); - assert_eq!(summaries.len(), 1); - assert_eq!(summaries[0].name, "visible"); - - // But all() includes both - assert_eq!(registry.all().len(), 2); - } - - #[test] - fn has_server_visible_skills() { - let empty = SkillRegistry::empty(); - assert!(!empty.has_server_visible_skills()); - - let all_disabled = SkillRegistry { - skills: vec![SkillDescriptor { - name: "hidden".to_string(), - description: String::new(), - source_path: "a/SKILL.md".into(), - disable_model_invocation: true, - }], - }; - assert!(!all_disabled.has_server_visible_skills()); - - let some_visible = SkillRegistry { - skills: vec![SkillDescriptor { - name: "visible".to_string(), - description: String::new(), - source_path: "a/SKILL.md".into(), - disable_model_invocation: false, - }], - }; - assert!(some_visible.has_server_visible_skills()); - } - - #[tokio::test] - async fn end_to_end_discover() { - let dir = tempfile::tempdir().unwrap(); - let skills_dir = dir.path().join("skills"); - - // Create a skill with frontmatter - let skill_dir = skills_dir.join("my-skill"); - std::fs::create_dir_all(&skill_dir).unwrap(); - std::fs::write( - skill_dir.join("SKILL.md"), - "---\nname: my-skill\ndescription: A test skill\n---\n\nBody here.\n", - ) - .unwrap(); - - // Create a skill with multiline description - let skill_dir2 = skills_dir.join("release"); - std::fs::create_dir_all(&skill_dir2).unwrap(); - std::fs::write( - skill_dir2.join("SKILL.md"), - "---\nname: release\ndescription: >\n Multi-line\n description here.\n---\n\nRelease steps.\n", - ) - .unwrap(); - - let registry = SkillRegistry::discover_from_dirs( - Some(&skills_dir), - &std::path::PathBuf::from("/nonexistent"), - ) - .await; - assert_eq!(registry.all().len(), 2); - - let my_skill = registry.get("my-skill").unwrap(); - assert_eq!(my_skill.description, "A test skill"); - - let release = registry.get("release").unwrap(); - assert!(release.description.contains("Multi-line")); - } -} diff --git a/crates/atuin-ai/src/skills/walker.rs b/crates/atuin-ai/src/skills/walker.rs deleted file mode 100644 index b93845f9..00000000 --- a/crates/atuin-ai/src/skills/walker.rs +++ /dev/null @@ -1,178 +0,0 @@ -//! Filesystem discovery for `SKILL.md` files. -//! -//! Recursively scans `.atuin/skills/` directories at the project and global -//! levels. Supports nested directories for organization (e.g. -//! `.atuin/skills/ops/deploy/SKILL.md`). - -use std::path::{Path, PathBuf}; - -const SKILL_FILENAME: &str = "SKILL.md"; - -/// A skill file found on disk, before body interpolation. -#[derive(Debug)] -pub(crate) struct RawSkillFile { - /// Full path to the SKILL.md file. - pub path: PathBuf, - /// The parent directory name, used as fallback skill name. - pub dir_name: String, - /// Whether this is a project-level skill (vs global). - #[allow(dead_code)] - pub is_project: bool, - /// Raw file content. - pub content: String, -} - -/// Discover all `SKILL.md` files across project and global skill directories. -/// -/// Project skills come first in the returned list (higher priority for -/// deduplication). -pub(crate) async fn discover( - project_skills_dir: Option<&Path>, - global_skills_dir: &Path, -) -> Vec<RawSkillFile> { - let mut files = Vec::new(); - - // Project skills first (higher priority) - if let Some(dir) = project_skills_dir.filter(|d| d.is_dir()) { - scan_dir(dir, true, &mut files).await; - } - - // Global skills second - if global_skills_dir.is_dir() { - scan_dir(global_skills_dir, false, &mut files).await; - } - - files -} - -/// The default global skills directory (`~/.config/atuin/skills/`). -pub(crate) fn global_skills_dir() -> PathBuf { - atuin_common::utils::config_dir().join("skills") -} - -/// Given a project working directory, return the project skills directory. -pub(crate) fn project_skills_dir(project_root: &Path) -> PathBuf { - project_root.join(".atuin").join("skills") -} - -/// Recursively scan a directory for `SKILL.md` files. -async fn scan_dir(dir: &Path, is_project: bool, out: &mut Vec<RawSkillFile>) { - let mut entries = match tokio::fs::read_dir(dir).await { - Ok(entries) => entries, - Err(e) => { - tracing::debug!("Could not read skills directory {}: {e}", dir.display()); - return; - } - }; - - let mut subdirs = Vec::new(); - - while let Ok(Some(entry)) = entries.next_entry().await { - let path = entry.path(); - - if path.is_dir() { - // Check for SKILL.md directly in this directory - let skill_path = path.join(SKILL_FILENAME); - if skill_path.is_file() { - let dir_name = path - .file_name() - .and_then(|n| n.to_str()) - .unwrap_or("unknown") - .to_string(); - - match tokio::fs::read_to_string(&skill_path).await { - Ok(content) => { - out.push(RawSkillFile { - path: skill_path, - dir_name, - is_project, - content, - }); - } - Err(e) => { - tracing::warn!("Failed to read skill file {}: {e}", skill_path.display()); - } - } - } - - // Collect subdirectories for recursive scanning - subdirs.push(path); - } - } - - for subdir in subdirs { - Box::pin(scan_dir(&subdir, is_project, out)).await; - } -} - -#[cfg(test)] -mod tests { - use super::*; - - fn setup_skill(dir: &Path, rel_path: &str, content: &str) { - let skill_dir = dir.join(rel_path); - std::fs::create_dir_all(&skill_dir).unwrap(); - std::fs::write(skill_dir.join(SKILL_FILENAME), content).unwrap(); - } - - #[tokio::test] - async fn discovers_project_skills() { - let dir = tempfile::tempdir().unwrap(); - let skills_dir = dir.path().join("skills"); - setup_skill(&skills_dir, "deploy", "---\nname: deploy\n---\nDeploy."); - - let files = discover(Some(&skills_dir), Path::new("/nonexistent")).await; - assert_eq!(files.len(), 1); - assert_eq!(files[0].dir_name, "deploy"); - assert!(files[0].is_project); - } - - #[tokio::test] - async fn discovers_global_skills() { - let dir = tempfile::tempdir().unwrap(); - let skills_dir = dir.path().join("skills"); - setup_skill(&skills_dir, "review", "---\nname: review\n---\nReview."); - - let files = discover(None, &skills_dir).await; - assert_eq!(files.len(), 1); - assert_eq!(files[0].dir_name, "review"); - assert!(!files[0].is_project); - } - - #[tokio::test] - async fn discovers_nested_skills() { - let dir = tempfile::tempdir().unwrap(); - let skills_dir = dir.path().join("skills"); - setup_skill(&skills_dir, "ops/deploy", "---\nname: deploy\n---\n"); - setup_skill(&skills_dir, "ops/rollback", "---\nname: rollback\n---\n"); - - let files = discover(Some(&skills_dir), Path::new("/nonexistent")).await; - assert_eq!(files.len(), 2); - } - - #[tokio::test] - async fn project_comes_before_global() { - let project = tempfile::tempdir().unwrap(); - let global = tempfile::tempdir().unwrap(); - let project_skills = project.path().join("skills"); - let global_skills = global.path().join("skills"); - - setup_skill(&project_skills, "a-skill", "project"); - setup_skill(&global_skills, "b-skill", "global"); - - let files = discover(Some(&project_skills), &global_skills).await; - assert_eq!(files.len(), 2); - assert!(files[0].is_project); - assert!(!files[1].is_project); - } - - #[tokio::test] - async fn missing_directories_handled() { - let files = discover( - Some(Path::new("/does/not/exist")), - Path::new("/also/missing"), - ) - .await; - assert!(files.is_empty()); - } -} diff --git a/crates/atuin-ai/src/snapshots.rs b/crates/atuin-ai/src/snapshots.rs deleted file mode 100644 index d46223a8..00000000 --- a/crates/atuin-ai/src/snapshots.rs +++ /dev/null @@ -1,414 +0,0 @@ -//! Backup snapshots for files before AI edits. -//! -//! Before the first edit to a file within a session, a snapshot of the -//! original content is saved so the user can recover if needed. Snapshots -//! are stored alongside a manifest that maps sanitized filenames back to -//! their original paths. -//! -//! Filenames use percent-encoding (`/` → `%2F`) so the snapshot directory -//! is human-readable via `ls`. - -use std::collections::HashMap; -use std::io::Write; -use std::path::{Path, PathBuf}; - -use eyre::{Result, eyre}; -use serde::{Deserialize, Serialize}; -use time::OffsetDateTime; - -/// Snapshot store for a single session. -/// -/// Each session gets its own directory under the snapshots root: -/// `<data_dir>/ai/snapshots/<session_id>/` -/// -/// Files are stored with percent-encoded filenames derived from their -/// canonical paths, alongside a `manifest.json` that maps filenames -/// back to original paths with timestamps. -#[derive(Debug)] -pub(crate) struct SnapshotStore { - session_dir: PathBuf, - manifest: SnapshotManifest, -} - -#[derive(Debug, Default, Serialize, Deserialize)] -struct SnapshotManifest { - files: HashMap<String, SnapshotEntry>, -} - -#[derive(Debug, Serialize, Deserialize)] -struct SnapshotEntry { - original_path: String, - snapshot_at: String, - size_bytes: u64, -} - -impl SnapshotStore { - /// Open or create a snapshot store for the given session directory. - /// - /// If a manifest already exists (from a prior CLI invocation in the same - /// session), it's loaded so we don't re-snapshot files that were already - /// backed up. - pub fn open(session_dir: PathBuf) -> Result<Self> { - let manifest_path = session_dir.join("manifest.json"); - let manifest = if manifest_path.exists() { - let data = fs_err::read_to_string(&manifest_path)?; - serde_json::from_str(&data)? - } else { - SnapshotManifest::default() - }; - - Ok(Self { - session_dir, - manifest, - }) - } - - /// Snapshot a file's contents if it hasn't been snapshotted yet this session. - /// - /// Returns `true` if a new snapshot was created, `false` if one already - /// existed. The `canonical_path` should be absolute (already tilde-expanded - /// and resolved). - pub fn ensure_snapshot(&mut self, canonical_path: &Path, content: &[u8]) -> Result<bool> { - let filename = sanitize_path(canonical_path); - - if self.manifest.files.contains_key(&filename) { - return Ok(false); - } - - fs_err::create_dir_all(&self.session_dir)?; - - let snapshot_path = self.session_dir.join(&filename); - atomic_write_file(&snapshot_path, content)?; - - let now = OffsetDateTime::now_utc(); - let entry = SnapshotEntry { - original_path: canonical_path.to_string_lossy().into_owned(), - snapshot_at: format_iso8601(now), - size_bytes: content.len() as u64, - }; - - self.manifest.files.insert(filename, entry); - self.save_manifest()?; - - Ok(true) - } - - /// Whether a file has already been snapshotted in this session. - #[cfg(test)] - pub fn has_snapshot(&self, canonical_path: &Path) -> bool { - let filename = sanitize_path(canonical_path); - self.manifest.files.contains_key(&filename) - } - - fn save_manifest(&self) -> Result<()> { - let json = serde_json::to_string_pretty(&self.manifest)?; - atomic_write_file(&self.session_dir.join("manifest.json"), json.as_bytes()) - } -} - -/// Percent-encode a path for use as a filename. -/// -/// Encodes `%` as `%25`, `/` as `%2F`, and `\` as `%5C`, then strips -/// leading separators and drive prefixes (e.g. `C:\`). The result is -/// always a flat filename safe for use with `Path::join` on any platform. -/// -/// Example (Unix): `/Users/me/.config/foo.toml` → `Users%2Fme%2F.config%2Ffoo.toml` -/// Example (Windows): `C:\Users\me\config.toml` → `Users%5Cme%5Cconfig.toml` -pub(crate) fn sanitize_path(path: &Path) -> String { - let s = path.to_string_lossy(); - // Strip drive letter prefix on Windows (e.g. "C:\") - let s = s.strip_prefix('/').unwrap_or_else(|| { - // Handle Windows drive prefix like "C:\" or "C:/" - if s.len() >= 3 - && s.as_bytes()[0].is_ascii_alphabetic() - && s.as_bytes()[1] == b':' - && (s.as_bytes()[2] == b'\\' || s.as_bytes()[2] == b'/') - { - &s[3..] - } else { - &s - } - }); - s.replace('%', "%25") - .replace('/', "%2F") - .replace('\\', "%5C") -} - -/// Write a file atomically using temp-file-then-rename. -/// -/// Creates a temporary file in the same directory as `target`, writes -/// content, fsyncs, then renames into place. Preserves permissions from -/// the original file if it exists. -pub(crate) fn atomic_write_file(target: &Path, content: &[u8]) -> Result<()> { - let dir = target - .parent() - .ok_or_else(|| eyre!("target path has no parent directory"))?; - fs_err::create_dir_all(dir)?; - - let mut tmp = tempfile::NamedTempFile::new_in(dir)?; - tmp.write_all(content)?; - tmp.as_file().sync_all()?; - - // Preserve permissions from original if it exists - if let Ok(meta) = std::fs::metadata(target) { - std::fs::set_permissions(tmp.path(), meta.permissions())?; - } - - tmp.persist(target).map_err(|e| { - eyre!( - "failed to persist atomic write to {}: {}", - target.display(), - e - ) - })?; - Ok(()) -} - -fn format_iso8601(dt: OffsetDateTime) -> String { - format!( - "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z", - dt.year(), - dt.month() as u8, - dt.day(), - dt.hour(), - dt.minute(), - dt.second(), - ) -} - -#[cfg(test)] -mod tests { - use super::*; - - // ── sanitize_path ────────────────────────────────────────── - - #[test] - fn sanitize_absolute_path() { - let path = Path::new("/Users/me/.config/atuin/config.toml"); - assert_eq!( - sanitize_path(path), - "Users%2Fme%2F.config%2Fatuin%2Fconfig.toml" - ); - } - - #[test] - fn sanitize_preserves_existing_percent() { - let path = Path::new("/data/100%done/file.txt"); - assert_eq!(sanitize_path(path), "data%2F100%25done%2Ffile.txt"); - } - - #[test] - fn sanitize_relative_path() { - let path = Path::new("relative/path.txt"); - assert_eq!(sanitize_path(path), "relative%2Fpath.txt"); - } - - #[test] - fn sanitize_no_collision_between_similar_paths() { - let a = sanitize_path(Path::new("/foo/bar-baz")); - let b = sanitize_path(Path::new("/foo/bar/baz")); - assert_ne!(a, b); - } - - #[test] - fn sanitize_backslash_encoded() { - // Windows-style path: backslashes become %5C, drive prefix stripped - let s = sanitize_path(Path::new("C:\\Users\\me\\config.toml")); - assert!(!s.contains('\\'), "backslashes must be encoded: {s}"); - assert!(!s.starts_with("C:"), "drive prefix must be stripped: {s}"); - assert!(s.contains("Users")); - assert!(s.contains("config.toml")); - } - - #[test] - fn sanitize_result_is_flat_filename() { - // The result must not be interpreted as a path with separators - // when passed to Path::join — no raw / or \ allowed. - let unix = sanitize_path(Path::new("/home/user/file.txt")); - assert!(!unix.contains('/')); - // Construct as if on Windows - let win = "C:\\Users\\me\\file.txt"; - let encoded = win - .strip_prefix("C:\\") - .unwrap() - .replace('%', "%25") - .replace('/', "%2F") - .replace('\\', "%5C"); - assert!(!encoded.contains('\\')); - assert!(!encoded.contains('/')); - } - - // ── atomic_write_file ────────────────────────────────────── - - #[test] - fn atomic_write_creates_file() { - let dir = tempfile::tempdir().unwrap(); - let target = dir.path().join("test.txt"); - - atomic_write_file(&target, b"hello world").unwrap(); - - assert_eq!(std::fs::read_to_string(&target).unwrap(), "hello world"); - } - - #[test] - fn atomic_write_overwrites_existing() { - let dir = tempfile::tempdir().unwrap(); - let target = dir.path().join("test.txt"); - - std::fs::write(&target, "old content").unwrap(); - atomic_write_file(&target, b"new content").unwrap(); - - assert_eq!(std::fs::read_to_string(&target).unwrap(), "new content"); - } - - #[test] - fn atomic_write_creates_parent_dirs() { - let dir = tempfile::tempdir().unwrap(); - let target = dir.path().join("sub").join("dir").join("test.txt"); - - atomic_write_file(&target, b"nested").unwrap(); - - assert_eq!(std::fs::read_to_string(&target).unwrap(), "nested"); - } - - #[cfg(unix)] - #[test] - fn atomic_write_preserves_permissions() { - use std::os::unix::fs::PermissionsExt; - - let dir = tempfile::tempdir().unwrap(); - let target = dir.path().join("test.txt"); - - std::fs::write(&target, "original").unwrap(); - std::fs::set_permissions(&target, std::fs::Permissions::from_mode(0o600)).unwrap(); - - atomic_write_file(&target, b"updated").unwrap(); - - let mode = std::fs::metadata(&target).unwrap().permissions().mode() & 0o777; - assert_eq!(mode, 0o600); - } - - // ── SnapshotStore ────────────────────────────────────────── - - #[test] - fn snapshot_creates_file_and_manifest() { - let dir = tempfile::tempdir().unwrap(); - let session_dir = dir.path().join("session-abc"); - let mut store = SnapshotStore::open(session_dir.clone()).unwrap(); - - let file_path = Path::new("/Users/me/.config/foo.toml"); - let created = store - .ensure_snapshot(file_path, b"[key]\nval = 1\n") - .unwrap(); - - assert!(created); - assert!(store.has_snapshot(file_path)); - - // Snapshot file on disk - let expected_file = session_dir.join("Users%2Fme%2F.config%2Ffoo.toml"); - assert!(expected_file.exists()); - assert_eq!( - std::fs::read_to_string(&expected_file).unwrap(), - "[key]\nval = 1\n" - ); - - // Manifest on disk - let manifest_path = session_dir.join("manifest.json"); - assert!(manifest_path.exists()); - let manifest: serde_json::Value = - serde_json::from_str(&std::fs::read_to_string(&manifest_path).unwrap()).unwrap(); - let files = manifest["files"].as_object().unwrap(); - assert_eq!(files.len(), 1); - let entry = &files["Users%2Fme%2F.config%2Ffoo.toml"]; - assert_eq!( - entry["original_path"].as_str().unwrap(), - "/Users/me/.config/foo.toml" - ); - assert_eq!(entry["size_bytes"].as_u64().unwrap(), 14); - } - - #[test] - fn snapshot_is_idempotent() { - let dir = tempfile::tempdir().unwrap(); - let session_dir = dir.path().join("session-abc"); - let mut store = SnapshotStore::open(session_dir.clone()).unwrap(); - - let path = Path::new("/etc/hosts"); - let first = store.ensure_snapshot(path, b"first content").unwrap(); - let second = store.ensure_snapshot(path, b"different content").unwrap(); - - assert!(first); - assert!(!second); - - // Original content preserved, not overwritten - let snapshot_file = session_dir.join("etc%2Fhosts"); - assert_eq!( - std::fs::read_to_string(snapshot_file).unwrap(), - "first content" - ); - } - - #[test] - fn snapshot_store_loads_existing_manifest() { - let dir = tempfile::tempdir().unwrap(); - let session_dir = dir.path().join("session-abc"); - - // First store: create a snapshot - { - let mut store = SnapshotStore::open(session_dir.clone()).unwrap(); - store - .ensure_snapshot(Path::new("/etc/hosts"), b"127.0.0.1") - .unwrap(); - } - - // Second store (simulates new CLI invocation): should see existing snapshot - { - let mut store = SnapshotStore::open(session_dir).unwrap(); - assert!(store.has_snapshot(Path::new("/etc/hosts"))); - - let created = store - .ensure_snapshot(Path::new("/etc/hosts"), b"new content") - .unwrap(); - assert!(!created); - } - } - - #[test] - fn snapshot_multiple_files() { - let dir = tempfile::tempdir().unwrap(); - let session_dir = dir.path().join("session-abc"); - let mut store = SnapshotStore::open(session_dir.clone()).unwrap(); - - store - .ensure_snapshot(Path::new("/etc/hosts"), b"hosts content") - .unwrap(); - store - .ensure_snapshot(Path::new("/Users/me/.bashrc"), b"bashrc content") - .unwrap(); - - assert!(store.has_snapshot(Path::new("/etc/hosts"))); - assert!(store.has_snapshot(Path::new("/Users/me/.bashrc"))); - assert!(!store.has_snapshot(Path::new("/nonexistent"))); - - // Both snapshot files exist - assert!(session_dir.join("etc%2Fhosts").exists()); - assert!(session_dir.join("Users%2Fme%2F.bashrc").exists()); - - // Manifest has both entries - let manifest: serde_json::Value = serde_json::from_str( - &std::fs::read_to_string(session_dir.join("manifest.json")).unwrap(), - ) - .unwrap(); - assert_eq!(manifest["files"].as_object().unwrap().len(), 2); - } - - #[test] - fn format_iso8601_produces_valid_format() { - let dt = OffsetDateTime::from_unix_timestamp(1700000000).unwrap(); - let formatted = format_iso8601(dt); - assert_eq!(formatted.len(), 20); - assert!(formatted.starts_with("2023-")); - assert!(formatted.contains('T')); - assert!(formatted.ends_with('Z')); - } -} diff --git a/crates/atuin-ai/src/store.rs b/crates/atuin-ai/src/store.rs deleted file mode 100644 index 20b9e881..00000000 --- a/crates/atuin-ai/src/store.rs +++ /dev/null @@ -1,554 +0,0 @@ -use std::path::Path; -use std::str::FromStr; -use std::time::Duration; - -use eyre::{Result, eyre}; -use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions}; -use time::OffsetDateTime; - -// Database row mappings — all columns are kept even if not yet read in -// non-test code, since they're part of the schema and used in tests. -#[derive(Debug)] -#[allow(dead_code)] -pub(crate) struct StoredSession { - pub id: String, - pub head_id: Option<String>, - pub server_session_id: Option<String>, - pub directory: Option<String>, - pub git_root: Option<String>, - pub created_at: i64, - pub updated_at: i64, - pub archived_at: Option<i64>, -} - -#[derive(Debug)] -#[allow(dead_code)] -pub(crate) struct StoredEvent { - pub id: String, - pub session_id: String, - pub parent_id: Option<String>, - pub invocation_id: String, - pub event_type: String, - pub event_data: String, - pub created_at: i64, -} - -/// Row type returned by session queries (avoids clippy::type_complexity). -type SessionRow = ( - String, - Option<String>, - Option<String>, - Option<String>, - Option<String>, - i64, - i64, - Option<i64>, -); - -/// Row type returned by event queries. -type EventRow = (String, String, Option<String>, String, String, String, i64); - -pub(crate) struct AiSessionStore { - pool: SqlitePool, -} - -impl AiSessionStore { - pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> { - let path = path.as_ref(); - let path_str = path - .as_os_str() - .to_str() - .ok_or_else(|| eyre!("AI session database path is not valid UTF-8: {path:?}"))?; - - let is_memory = path_str.contains(":memory:"); - - if !is_memory - && !path.exists() - && let Some(dir) = path.parent() - { - fs_err::create_dir_all(dir)?; - } - - let opts = SqliteConnectOptions::from_str(path_str)? - .journal_mode(SqliteJournalMode::Wal) - .optimize_on_close(true, None) - .create_if_missing(true); - - let pool = SqlitePoolOptions::new() - .acquire_timeout(Duration::from_secs_f64(timeout)) - .connect_with(opts) - .await?; - - sqlx::migrate!("./migrations").run(&pool).await?; - - #[cfg(unix)] - if !is_memory { - use std::os::unix::fs::PermissionsExt; - std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o600))?; - } - - Ok(Self { pool }) - } - - pub async fn create_session( - &self, - id: &str, - directory: Option<&str>, - git_root: Option<&str>, - ) -> Result<StoredSession> { - let now = OffsetDateTime::now_utc().unix_timestamp(); - - sqlx::query( - "INSERT INTO sessions (id, directory, git_root, created_at, updated_at) - VALUES (?1, ?2, ?3, ?4, ?4)", - ) - .bind(id) - .bind(directory) - .bind(git_root) - .bind(now) - .execute(&self.pool) - .await?; - - Ok(StoredSession { - id: id.to_string(), - head_id: None, - server_session_id: None, - directory: directory.map(String::from), - git_root: git_root.map(String::from), - created_at: now, - updated_at: now, - archived_at: None, - }) - } - - #[allow(dead_code)] // used in tests; will be used by daemon service - pub async fn get_session(&self, id: &str) -> Result<Option<StoredSession>> { - let row: Option<SessionRow> = sqlx::query_as( - "SELECT id, head_id, server_session_id, directory, git_root, - created_at, updated_at, archived_at - FROM sessions WHERE id = ?1", - ) - .bind(id) - .fetch_optional(&self.pool) - .await?; - - Ok(row.map( - |( - id, - head_id, - server_session_id, - directory, - git_root, - created_at, - updated_at, - archived_at, - )| { - StoredSession { - id, - head_id, - server_session_id, - directory, - git_root, - created_at, - updated_at, - archived_at, - } - }, - )) - } - - /// Find the most recent non-archived session matching the given directory or git - /// root, updated within `max_age_secs` seconds. - pub async fn find_resumable_session( - &self, - directory: Option<&str>, - git_root: Option<&str>, - max_age_secs: i64, - ) -> Result<Option<StoredSession>> { - let cutoff = OffsetDateTime::now_utc().unix_timestamp() - max_age_secs; - - let row: Option<SessionRow> = sqlx::query_as( - "SELECT id, head_id, server_session_id, directory, git_root, - created_at, updated_at, archived_at - FROM sessions - WHERE archived_at IS NULL - AND updated_at > ?1 - AND (directory = ?2 OR (git_root IS NOT NULL AND git_root = ?3)) - ORDER BY updated_at DESC - LIMIT 1", - ) - .bind(cutoff) - .bind(directory) - .bind(git_root) - .fetch_optional(&self.pool) - .await?; - - Ok(row.map( - |( - id, - head_id, - server_session_id, - directory, - git_root, - created_at, - updated_at, - archived_at, - )| { - StoredSession { - id, - head_id, - server_session_id, - directory, - git_root, - created_at, - updated_at, - archived_at, - } - }, - )) - } - - /// Append a single event and update the session's `head_id` and `updated_at`. - pub async fn append_event( - &self, - session_id: &str, - event_id: &str, - parent_id: Option<&str>, - invocation_id: &str, - event_type: &str, - event_data: &str, - ) -> Result<()> { - let now = OffsetDateTime::now_utc().unix_timestamp(); - - let mut tx = self.pool.begin().await?; - - sqlx::query( - "INSERT INTO session_events (id, session_id, parent_id, invocation_id, event_type, event_data, created_at) - VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)", - ) - .bind(event_id) - .bind(session_id) - .bind(parent_id) - .bind(invocation_id) - .bind(event_type) - .bind(event_data) - .bind(now) - .execute(&mut *tx) - .await?; - - sqlx::query("UPDATE sessions SET head_id = ?1, updated_at = ?2 WHERE id = ?3") - .bind(event_id) - .bind(now) - .bind(session_id) - .execute(&mut *tx) - .await?; - - tx.commit().await?; - Ok(()) - } - - /// Load all events for a session, ordered chronologically. - pub async fn load_events(&self, session_id: &str) -> Result<Vec<StoredEvent>> { - let rows: Vec<EventRow> = sqlx::query_as( - "SELECT id, session_id, parent_id, invocation_id, event_type, event_data, created_at - FROM session_events - WHERE session_id = ?1 - ORDER BY created_at ASC, rowid ASC", - ) - .bind(session_id) - .fetch_all(&self.pool) - .await?; - - Ok(rows - .into_iter() - .map( - |(id, session_id, parent_id, invocation_id, event_type, event_data, created_at)| { - StoredEvent { - id, - session_id, - parent_id, - invocation_id, - event_type, - event_data, - created_at, - } - }, - ) - .collect()) - } - - pub async fn update_server_session_id( - &self, - session_id: &str, - server_session_id: &str, - ) -> Result<()> { - sqlx::query("UPDATE sessions SET server_session_id = ?1 WHERE id = ?2") - .bind(server_session_id) - .bind(session_id) - .execute(&self.pool) - .await?; - Ok(()) - } - - pub async fn archive_session(&self, session_id: &str) -> Result<()> { - let now = OffsetDateTime::now_utc().unix_timestamp(); - sqlx::query("UPDATE sessions SET archived_at = ?1 WHERE id = ?2") - .bind(now) - .bind(session_id) - .execute(&self.pool) - .await?; - Ok(()) - } - - // ── Session metadata (key-value per session) ── - - /// Read a metadata value for a session. Returns `None` if the key doesn't - /// exist or the session hasn't been persisted yet. - pub async fn get_metadata(&self, session_id: &str, key: &str) -> Result<Option<String>> { - let row: Option<(String,)> = - sqlx::query_as("SELECT value FROM session_metadata WHERE session_id = ?1 AND key = ?2") - .bind(session_id) - .bind(key) - .fetch_optional(&self.pool) - .await?; - - Ok(row.map(|(v,)| v)) - } - - /// Write a metadata value for a session (upsert). - pub async fn set_metadata(&self, session_id: &str, key: &str, value: &str) -> Result<()> { - let now = OffsetDateTime::now_utc().unix_timestamp(); - sqlx::query( - "INSERT INTO session_metadata (session_id, key, value, updated_at) - VALUES (?1, ?2, ?3, ?4) - ON CONFLICT (session_id, key) DO UPDATE SET value = ?3, updated_at = ?4", - ) - .bind(session_id) - .bind(key) - .bind(value) - .bind(now) - .execute(&self.pool) - .await?; - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - async fn new_test_store() -> AiSessionStore { - AiSessionStore::new("sqlite::memory:", 2.0).await.unwrap() - } - - #[tokio::test] - async fn test_create_and_get_session() { - let store = new_test_store().await; - - let session = store - .create_session("s1", Some("/home/user/project"), Some("/home/user/project")) - .await - .unwrap(); - assert_eq!(session.id, "s1"); - assert!(session.head_id.is_none()); - assert!(session.archived_at.is_none()); - - let loaded = store.get_session("s1").await.unwrap().unwrap(); - assert_eq!(loaded.id, "s1"); - assert_eq!(loaded.directory.as_deref(), Some("/home/user/project")); - } - - #[tokio::test] - async fn test_get_nonexistent_session() { - let store = new_test_store().await; - assert!(store.get_session("nope").await.unwrap().is_none()); - } - - #[tokio::test] - async fn test_append_and_load_events() { - let store = new_test_store().await; - store - .create_session("s1", Some("/tmp"), None) - .await - .unwrap(); - - store - .append_event( - "s1", - "e1", - None, - "inv1", - "user_message", - r#"{"content":"hello"}"#, - ) - .await - .unwrap(); - store - .append_event( - "s1", - "e2", - Some("e1"), - "inv1", - "text", - r#"{"content":"hi there"}"#, - ) - .await - .unwrap(); - - let events = store.load_events("s1").await.unwrap(); - assert_eq!(events.len(), 2); - assert_eq!(events[0].id, "e1"); - assert!(events[0].parent_id.is_none()); - assert_eq!(events[0].invocation_id, "inv1"); - assert_eq!(events[1].id, "e2"); - assert_eq!(events[1].parent_id.as_deref(), Some("e1")); - - let session = store.get_session("s1").await.unwrap().unwrap(); - assert_eq!(session.head_id.as_deref(), Some("e2")); - } - - #[tokio::test] - async fn test_find_resumable_session() { - let store = new_test_store().await; - store - .create_session("s1", Some("/home/user/project"), None) - .await - .unwrap(); - - let found = store - .find_resumable_session(Some("/home/user/project"), None, 3600) - .await - .unwrap(); - assert!(found.is_some()); - assert_eq!(found.unwrap().id, "s1"); - } - - #[tokio::test] - async fn test_find_resumable_by_git_root() { - let store = new_test_store().await; - store - .create_session( - "s1", - Some("/home/user/project/sub"), - Some("/home/user/project"), - ) - .await - .unwrap(); - - let found = store - .find_resumable_session(Some("/different/dir"), Some("/home/user/project"), 3600) - .await - .unwrap(); - assert!(found.is_some()); - assert_eq!(found.unwrap().id, "s1"); - } - - #[tokio::test] - async fn test_find_resumable_skips_archived() { - let store = new_test_store().await; - store - .create_session("s1", Some("/tmp"), None) - .await - .unwrap(); - store.archive_session("s1").await.unwrap(); - - let found = store - .find_resumable_session(Some("/tmp"), None, 3600) - .await - .unwrap(); - assert!(found.is_none()); - } - - #[tokio::test] - async fn test_find_resumable_no_match_different_dir() { - let store = new_test_store().await; - store - .create_session("s1", Some("/home/user/project"), None) - .await - .unwrap(); - - let found = store - .find_resumable_session(Some("/other/dir"), None, 3600) - .await - .unwrap(); - assert!(found.is_none()); - } - - #[tokio::test] - async fn test_archive_session() { - let store = new_test_store().await; - store - .create_session("s1", Some("/tmp"), None) - .await - .unwrap(); - - store.archive_session("s1").await.unwrap(); - - let session = store.get_session("s1").await.unwrap().unwrap(); - assert!(session.archived_at.is_some()); - } - - #[tokio::test] - async fn test_update_server_session_id() { - let store = new_test_store().await; - store - .create_session("s1", Some("/tmp"), None) - .await - .unwrap(); - - store - .update_server_session_id("s1", "server-abc") - .await - .unwrap(); - - let session = store.get_session("s1").await.unwrap().unwrap(); - assert_eq!(session.server_session_id.as_deref(), Some("server-abc")); - } - - #[tokio::test] - async fn test_find_resumable_does_not_mutate() { - let store = new_test_store().await; - store - .create_session("s1", Some("/tmp"), None) - .await - .unwrap(); - - let before = store.get_session("s1").await.unwrap().unwrap(); - store - .find_resumable_session(Some("/tmp"), None, 3600) - .await - .unwrap() - .unwrap(); - let after = store.get_session("s1").await.unwrap().unwrap(); - - assert_eq!(before.updated_at, after.updated_at); - } - - #[tokio::test] - async fn test_events_ordered_chronologically() { - let store = new_test_store().await; - store - .create_session("s1", Some("/tmp"), None) - .await - .unwrap(); - - store - .append_event("s1", "e1", None, "inv1", "user_message", "{}") - .await - .unwrap(); - store - .append_event("s1", "e2", Some("e1"), "inv1", "text", "{}") - .await - .unwrap(); - store - .append_event("s1", "e3", Some("e2"), "inv2", "user_message", "{}") - .await - .unwrap(); - - let events = store.load_events("s1").await.unwrap(); - assert_eq!(events.len(), 3); - assert!(events[0].created_at <= events[1].created_at); - assert!(events[1].created_at <= events[2].created_at); - assert_eq!(events[2].invocation_id, "inv2"); - } -} diff --git a/crates/atuin-ai/src/stream.rs b/crates/atuin-ai/src/stream.rs deleted file mode 100644 index e78dc2e1..00000000 --- a/crates/atuin-ai/src/stream.rs +++ /dev/null @@ -1,288 +0,0 @@ -// ─────────────────────────────────────────────────────────────────── -// SSE streaming -// ─────────────────────────────────────────────────────────────────── - -use atuin_client::history::History; -use atuin_client::settings::AiCapabilities; - -use crate::context::history_output_capability_available; -use atuin_common::tls::ensure_crypto_provider; - -use eventsource_stream::Eventsource; -use eyre::{Context, Result}; -use futures::StreamExt; -use reqwest::Url; -use reqwest::header::USER_AGENT; - -use crate::context::ClientContext; - -static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION")); - -/// Frames that alter the stream lifecycle — terminal or state-changing. -#[derive(Debug, Clone)] -pub(crate) enum StreamControl { - Done { session_id: String }, - Error(String), - StatusChanged(String), -} - -/// Frames that carry conversation content — they mutate the event log. -#[derive(Debug, Clone)] -pub(crate) enum StreamContent { - TextChunk(String), - ToolCall { - id: String, - name: String, - input: serde_json::Value, - }, - ToolResult { - tool_use_id: String, - content: String, - is_error: bool, - remote: bool, - content_length: Option<usize>, - }, -} - -/// A frame from the SSE stream, classified as control or content. -#[derive(Debug, Clone)] -pub(crate) enum StreamFrame { - Content(StreamContent), - Control(StreamControl), -} - -/// Per-turn request payload for the chat API. -pub(crate) struct ChatRequest { - pub messages: Vec<serde_json::Value>, - pub session_id: Option<String>, - pub capabilities: Vec<String>, - pub invocation_id: String, -} - -impl ChatRequest { - pub(crate) fn new( - messages: Vec<serde_json::Value>, - session_id: Option<String>, - capabilities: &AiCapabilities, - history_output_available: bool, - invocation_id: String, - ) -> Self { - let mut caps = vec![ - "client_invocations".to_string(), - "client_v1_load_skill".to_string(), - ]; - if capabilities.enable_history_search.unwrap_or(true) { - caps.push("client_v1_atuin_history".to_string()); - } - if capabilities.enable_file_tools.unwrap_or(true) { - caps.push("client_v1_read_file".to_string()); - caps.push("client_v1_edit_file".to_string()); - caps.push("client_v1_write_file".to_string()); - } - if capabilities.enable_command_execution.unwrap_or(true) { - caps.push("client_v1_execute_shell_command".to_string()); - } - if history_output_capability_available(history_output_available) - && capabilities.enable_history_output.unwrap_or(true) - { - caps.push("client_v1_atuin_output".to_string()); - } - if let Ok(extra) = std::env::var("ATUIN_AI__ADDITIONAL_CAPS") { - caps.extend( - extra - .split(',') - .map(|s| s.trim().to_string()) - .filter(|s| !s.is_empty()), - ); - } - - Self { - messages, - session_id, - capabilities: caps, - invocation_id, - } - } -} - -#[allow(clippy::too_many_arguments)] -pub(crate) fn create_chat_stream( - hub_address: String, - token: String, - request: ChatRequest, - client_ctx: ClientContext, - send_cwd: bool, - last_command: Option<History>, - user_contexts: Vec<crate::user_context::UserContext>, - skill_summaries: Vec<crate::skills::SkillSummary>, - skill_overflow: Option<String>, -) -> std::pin::Pin<Box<dyn futures::Stream<Item = Result<StreamFrame>> + Send>> { - Box::pin(async_stream::stream! { - ensure_crypto_provider(); - let endpoint = match hub_url(&hub_address, "/api/cli/chat") { - Ok(url) => url, - Err(e) => { - yield Err(e); - return; - } - }; - - tracing::debug!("Sending SSE request to {endpoint}"); - - let context = client_ctx.to_json(send_cwd, last_command.as_ref()); - - let mut config = serde_json::json!({ - "capabilities": request.capabilities, - }); - - if !user_contexts.is_empty() { - config["user_contexts"] = serde_json::json!(user_contexts); - } - - if !skill_summaries.is_empty() { - config["skills"] = serde_json::json!(skill_summaries); - if let Some(ref overflow) = skill_overflow { - config["skills_overflow"] = serde_json::json!(overflow); - } - } - - if let Ok(model) = std::env::var("ATUIN_AI__MODEL") - && !model.trim().is_empty() { - config["model"] = serde_json::json!(model.trim()); - - } - - - let mut request_body = serde_json::json!({ - "messages": request.messages, - "context": context, - "config": config, - "invocation_id": request.invocation_id - }); - - if let Some(ref sid) = request.session_id { - tracing::trace!("Including session_id in request: {sid}"); - request_body["session_id"] = serde_json::json!(sid); - } - - let client = reqwest::Client::new(); - let response = match client - .post(endpoint.clone()) - .header("Accept", "text/event-stream") - .header(USER_AGENT, APP_USER_AGENT) - .bearer_auth(&token) - .json(&request_body) - .send() - .await - { - Ok(resp) => resp, - Err(e) => { - yield Err(eyre::eyre!("Failed to send SSE request: {}", e)); - return; - } - }; - - let status = response.status(); - if status == reqwest::StatusCode::UNAUTHORIZED { - tracing::error!("SSE request failed with status: {status}, clearing session"); - let _ = atuin_client::hub::delete_session().await; - yield Err(eyre::eyre!("Hub session expired. Re-run to authenticate again.")); - return; - } - if !status.is_success() { - let body = response.text().await.unwrap_or_default(); - tracing::error!("SSE request failed ({}): {}", status, body); - yield Err(eyre::eyre!("SSE request failed ({}): {}", status, body)); - return; - } - - let byte_stream = response.bytes_stream(); - let mut stream = byte_stream.eventsource(); - - while let Some(event) = stream.next().await { - match event { - Ok(sse_event) => { - let event_type = sse_event.event.as_str(); - let data = sse_event.data.clone(); - - tracing::debug!(event_type = %event_type, "SSE event received"); - - match event_type { - "text" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) - && let Some(content) = json.get("content").and_then(|v| v.as_str()) - { - yield Ok(StreamFrame::Content(StreamContent::TextChunk(content.to_string()))); - } - } - "tool_call" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { - let id = json.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let name = json.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let input = json.get("input").cloned().unwrap_or(serde_json::json!({})); - yield Ok(StreamFrame::Content(StreamContent::ToolCall { id, name, input })); - } - } - "tool_result" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { - let tool_use_id = json.get("tool_use_id").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let content = json.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let is_error = json.get("is_error").and_then(|v| v.as_bool()).unwrap_or(false); - let remote = json.get("remote").and_then(|v| v.as_bool()).unwrap_or(false); - let content_length = json.get("content_length").and_then(|v| v.as_u64()).map(|v| v as usize); - yield Ok(StreamFrame::Content(StreamContent::ToolResult { tool_use_id, content, is_error, remote, content_length })); - } - } - "status" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) - && let Some(state) = json.get("state").and_then(|v| v.as_str()) - { - yield Ok(StreamFrame::Control(StreamControl::StatusChanged(state.to_string()))); - } - } - "done" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { - let session_id = json.get("session_id") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - yield Ok(StreamFrame::Control(StreamControl::Done { session_id })); - } else { - yield Ok(StreamFrame::Control(StreamControl::Done { session_id: String::new() })); - } - break; - } - "error" => { - if let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) { - let message = json.get("message").and_then(|v| v.as_str()).unwrap_or("Unknown error").to_string(); - tracing::error!("SSE error: {}", message); - yield Ok(StreamFrame::Control(StreamControl::Error(message))); - } else { - tracing::error!("SSE error: {}", data); - yield Ok(StreamFrame::Control(StreamControl::Error(data))); - } - break; - } - _ => {} - } - } - Err(e) => { - yield Err(eyre::eyre!("SSE error: {}", e)); - break; - } - } - } - }) -} - -fn hub_url(base: &str, path: &str) -> Result<Url> { - let base_with_slash = if base.ends_with('/') { - base.to_string() - } else { - format!("{base}/") - }; - let stripped = path.strip_prefix('/').unwrap_or(path); - Url::parse(&base_with_slash)? - .join(stripped) - .context("failed to build hub URL") -} diff --git a/crates/atuin-ai/src/tools/descriptor.rs b/crates/atuin-ai/src/tools/descriptor.rs deleted file mode 100644 index 4190540c..00000000 --- a/crates/atuin-ai/src/tools/descriptor.rs +++ /dev/null @@ -1,129 +0,0 @@ -/// Centralized metadata for a tool type. -/// -/// Covers both client-side tools (ones the CLI executes locally) and -/// server-side tools (ones the API executes remotely). This is the single -/// source of truth for display text and classification. -pub(crate) struct ToolDescriptor { - /// Canonical wire names for this tool (the names the server sends). - pub canonical_names: &'static [&'static str], - /// The capability string the client must advertise for this tool to be - /// accepted. `None` for server-side tools (always accepted). - pub capability: Option<&'static str>, - /// Imperative verb for permission prompts (e.g. "read", "run"). - pub display_verb: &'static str, - /// Present-tense progressive verb for spinners (e.g. "Reading file..."). - pub progressive_verb: &'static str, - /// Past-tense verb for summaries (e.g. "Read file"). - pub past_verb: &'static str, - /// Whether this tool is executed client-side (by the CLI). - #[expect(dead_code)] - pub is_client: bool, -} - -// ── Client-side tool descriptors ── - -pub(crate) const READ: &ToolDescriptor = &ToolDescriptor { - canonical_names: &["read_file"], - capability: Some("client_v1_read_file"), - display_verb: "read", - progressive_verb: "Reading file...", - past_verb: "Read file", - is_client: true, -}; - -pub(crate) const EDIT: &ToolDescriptor = &ToolDescriptor { - canonical_names: &["edit_file"], - capability: Some("client_v1_edit_file"), - display_verb: "edit", - progressive_verb: "Editing file...", - past_verb: "Edited file", - is_client: true, -}; - -pub(crate) const WRITE: &ToolDescriptor = &ToolDescriptor { - canonical_names: &["write_file"], - capability: Some("client_v1_write_file"), - display_verb: "write to", - progressive_verb: "Writing file...", - past_verb: "Wrote file", - is_client: true, -}; - -pub(crate) const SHELL: &ToolDescriptor = &ToolDescriptor { - canonical_names: &["execute_shell_command"], - capability: Some("client_v1_execute_shell_command"), - display_verb: "run", - progressive_verb: "Running command...", - past_verb: "Ran command", - is_client: true, -}; - -pub(crate) const ATUIN_HISTORY: &ToolDescriptor = &ToolDescriptor { - canonical_names: &["atuin_history"], - capability: Some("client_v1_atuin_history"), - display_verb: "search your Atuin history for", - progressive_verb: "Searching...", - past_verb: "Searched", - is_client: true, -}; - -pub(crate) const ATUIN_OUTPUT: &ToolDescriptor = &ToolDescriptor { - canonical_names: &["atuin_output"], - capability: Some("client_v1_atuin_output"), - display_verb: "view the output for command", - progressive_verb: "Viewing output...", - past_verb: "Viewed output", - is_client: true, -}; - -pub(crate) const LOAD_SKILL: &ToolDescriptor = &ToolDescriptor { - canonical_names: &["load_skill"], - capability: Some("client_v1_load_skill"), - display_verb: "load skill", - progressive_verb: "Loading skill...", - past_verb: "Loaded skill", - is_client: true, -}; - -// ── Server-side tool descriptors ── -// These appear in tool summaries but aren't client-side tools. - -pub(crate) const SERVER_SEARCH: &ToolDescriptor = &ToolDescriptor { - canonical_names: &["web_search"], - capability: None, - display_verb: "search", - progressive_verb: "Searching...", - past_verb: "Searched", - is_client: false, -}; - -pub(crate) const SERVER_SCRAPE: &ToolDescriptor = &ToolDescriptor { - canonical_names: &["web_scrape"], - capability: None, - display_verb: "scrape", - progressive_verb: "Scraping...", - past_verb: "Scraped", - is_client: false, -}; - -/// All known tool descriptors, for lookup by name. -const ALL_DESCRIPTORS: &[&ToolDescriptor] = &[ - READ, - EDIT, - WRITE, - SHELL, - ATUIN_HISTORY, - ATUIN_OUTPUT, - LOAD_SKILL, - SERVER_SEARCH, - SERVER_SCRAPE, -]; - -/// Look up a tool descriptor by its canonical wire name. -/// Returns None for unknown tool names. -pub(crate) fn by_name(name: &str) -> Option<&'static ToolDescriptor> { - ALL_DESCRIPTORS - .iter() - .find(|d| d.canonical_names.contains(&name)) - .copied() -} diff --git a/crates/atuin-ai/src/tools/mod.rs b/crates/atuin-ai/src/tools/mod.rs deleted file mode 100644 index d1352661..00000000 --- a/crates/atuin-ai/src/tools/mod.rs +++ /dev/null @@ -1,2159 +0,0 @@ -use std::{ - io::BufRead, - path::{Path, PathBuf}, - time::Duration, -}; - -use eyre::Result; -use uuid::Uuid; - -const DEFAULT_FILE_READ_LINES: u64 = 100; -const MAX_FILE_READ_LINES: u64 = 1000; - -pub(crate) mod descriptor; - -use crate::permissions::rule::Rule; - -/// Check whether a file path matches a scope glob pattern. -/// -/// Resolves relative paths against the current directory before matching so -/// that `./foo.md` and `/cwd/foo.md` match the same glob. Supports `*`, `**`, -/// `?`, and `[...]` via `glob_match`. -fn path_matches_scope(path: &Path, scope: &str) -> bool { - let path = if path.is_relative() { - std::env::current_dir() - .map(|cwd| cwd.join(path)) - .unwrap_or_else(|_| path.to_path_buf()) - } else { - path.to_path_buf() - }; - // Normalize to forward slashes so globs work on Windows too. - let path_str = path.to_string_lossy().replace('\\', "/"); - - // If the scope is also relative, try matching against both the absolute - // path and just the filename/relative portion. - if !scope.starts_with('/') { - // Match against filename (e.g. "*.md" matches any .md file) - if let Some(name) = path.file_name().and_then(|n| n.to_str()) - && glob_match::glob_match(scope, name) - { - return true; - } - // Also try matching against the full absolute path in case the scope - // is a relative multi-segment pattern like "crates/**/*.rs" - if glob_match::glob_match(scope, &path_str) { - return true; - } - // And match relative to cwd (so "src/*.rs" works from project root) - if let Ok(cwd) = std::env::current_dir() - && let Ok(rel) = path.strip_prefix(&cwd) - { - let rel_str = rel.to_string_lossy().replace('\\', "/"); - return glob_match::glob_match(scope, &rel_str); - } - return false; - } - - // Absolute scope — match against absolute path - glob_match::glob_match(scope, &path_str) -} - -/// Result of executing a client-side tool. -#[derive(Debug, Clone)] -pub(crate) enum ToolOutcome { - /// Simple success with a text result (used by Read, AtuinHistory). - Success(String), - /// Error with a message. - Error(String), - /// Structured shell result with separated stdout, stderr, exit code, and duration. - Structured { - stdout: String, - stderr: String, - exit_code: Option<i32>, - duration_ms: u64, - interrupted: bool, - }, -} - -impl ToolOutcome { - /// Format this outcome as a string for the tool result sent to the LLM. - /// - /// The optional `interrupt_reason` overrides the generic interrupted message - /// with a specific one (user interrupt vs timeout). - pub fn format_for_llm( - &self, - interrupt_reason: Option<&crate::fsm::tools::InterruptReason>, - ) -> String { - match self { - ToolOutcome::Success(s) => s.clone(), - ToolOutcome::Error(e) => e.clone(), - ToolOutcome::Structured { - stdout, - stderr, - exit_code, - duration_ms, - interrupted, - } => { - let mut parts = Vec::new(); - - if let Some(code) = exit_code { - parts.push(format!("Exit code: {code}")); - } - - parts.push(format!("Duration: {duration_ms}ms")); - - if !stdout.is_empty() { - parts.push(format!("stdout:\n{stdout}")); - } else { - parts.push("stdout: (empty)".to_string()); - } - - if !stderr.is_empty() { - parts.push(format!("stderr:\n{stderr}")); - } else { - parts.push("stderr: (empty)".to_string()); - } - - if *interrupted { - use crate::fsm::tools::InterruptReason; - let msg = match interrupt_reason { - Some(InterruptReason::Timeout(secs)) => { - format!("[Timed out after {secs}s]") - } - _ => "[Interrupted by user]".to_string(), - }; - parts.push(msg); - } - - parts.join("\n\n") - } - } - } - - /// Whether this outcome represents an error. - pub fn is_error(&self) -> bool { - match self { - ToolOutcome::Error(_) => true, - ToolOutcome::Structured { - exit_code: Some(code), - .. - } if *code != 0 => true, - _ => false, - } - } -} - -/// Cached VT100 preview data for a shell tool call. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) struct ToolPreview { - pub lines: Vec<String>, - pub exit_code: Option<i32>, - pub interrupted: Option<crate::fsm::tools::InterruptReason>, -} - -/// A tool call from the server, with parsed input parameters. -#[derive(Debug, Clone)] -pub(crate) enum ClientToolCall { - Read(ReadToolCall), - Edit(EditToolCall), - Write(WriteToolCall), - Shell(ShellToolCall), - AtuinHistory(AtuinHistoryToolCall), - AtuinOutput(AtuinOutputToolCall), - LoadSkill(LoadSkillToolCall), -} - -impl TryFrom<(&str, &serde_json::Value)> for ClientToolCall { - type Error = eyre::Error; - - fn try_from((name, input): (&str, &serde_json::Value)) -> Result<Self, Self::Error> { - match name { - "read_file" => Ok(ClientToolCall::Read(ReadToolCall::try_from(input)?)), - "edit_file" => Ok(ClientToolCall::Edit(EditToolCall::try_from(input)?)), - "write_file" => Ok(ClientToolCall::Write(WriteToolCall::try_from(input)?)), - "execute_shell_command" => Ok(ClientToolCall::Shell(ShellToolCall::try_from(input)?)), - "atuin_history" => Ok(ClientToolCall::AtuinHistory( - AtuinHistoryToolCall::try_from(input)?, - )), - "atuin_output" => Ok(ClientToolCall::AtuinOutput(AtuinOutputToolCall::try_from( - input, - )?)), - "load_skill" => Ok(ClientToolCall::LoadSkill(LoadSkillToolCall::try_from( - input, - )?)), - _ => Err(eyre::eyre!("Unknown tool call: {name}")), - } - } -} - -impl ClientToolCall { - pub(crate) fn descriptor(&self) -> &'static descriptor::ToolDescriptor { - match self { - ClientToolCall::Read(_) => descriptor::READ, - ClientToolCall::Edit(_) => descriptor::EDIT, - ClientToolCall::Write(_) => descriptor::WRITE, - ClientToolCall::Shell(_) => descriptor::SHELL, - ClientToolCall::AtuinHistory(_) => descriptor::ATUIN_HISTORY, - ClientToolCall::AtuinOutput(_) => descriptor::ATUIN_OUTPUT, - ClientToolCall::LoadSkill(_) => descriptor::LOAD_SKILL, - } - } - - /// The permission rule name for this tool category. - /// - /// Edit and Write share the `"Write"` rule name — a Write permission - /// covers both str_replace edits and full file creates. Write also - /// implies Read (checked in `ReadToolCall::matches_rule`). - pub(crate) fn rule_name(&self) -> &'static str { - match self { - ClientToolCall::Read(_) => "Read", - ClientToolCall::Edit(_) => "Write", - ClientToolCall::Write(_) => "Write", - ClientToolCall::Shell(_) => "Shell", - ClientToolCall::AtuinHistory(_) => "AtuinHistory", - ClientToolCall::AtuinOutput(_) => "AtuinOutput", - ClientToolCall::LoadSkill(_) => "LoadSkill", - } - } - - /// The resolved file path for this tool call, if it's a file-based tool. - /// Used to build scoped permission rules like `Write(/abs/path/to/file)`. - pub(crate) fn resolved_file_path(&self) -> Option<PathBuf> { - match self { - ClientToolCall::Read(tool) => Some(tool.resolved_path()), - ClientToolCall::Edit(tool) => Some(tool.resolved_path()), - ClientToolCall::Write(tool) => Some(tool.resolved_path()), - ClientToolCall::Shell(_) - | ClientToolCall::AtuinHistory(_) - | ClientToolCall::AtuinOutput(_) - | ClientToolCall::LoadSkill(_) => None, - } - } - - pub(crate) fn matches_rule(&self, rule: &Rule) -> bool { - match self { - ClientToolCall::Read(tool) => tool.matches_rule(rule), - ClientToolCall::Edit(tool) => tool.matches_rule(rule), - ClientToolCall::Write(tool) => tool.matches_rule(rule), - ClientToolCall::Shell(tool) => tool.matches_rule(rule), - ClientToolCall::AtuinHistory(tool) => tool.matches_rule(rule), - ClientToolCall::AtuinOutput(tool) => tool.matches_rule(rule), - ClientToolCall::LoadSkill(tool) => tool.matches_rule(rule), - } - } - - pub(crate) fn target_dir(&self) -> Option<&Path> { - match self { - ClientToolCall::Read(tool) => tool.target_dir(), - ClientToolCall::Edit(tool) => tool.target_dir(), - ClientToolCall::Write(tool) => tool.target_dir(), - ClientToolCall::Shell(tool) => tool.target_dir(), - ClientToolCall::AtuinHistory(tool) => tool.target_dir(), - ClientToolCall::AtuinOutput(tool) => tool.target_dir(), - ClientToolCall::LoadSkill(tool) => tool.target_dir(), - } - } -} - -/// A trait for tool calls that can be checked against permission rules. -pub(crate) trait PermissibleToolCall { - /// Checks if this tool call matches the given permission rule. - fn matches_rule(&self, rule: &Rule) -> bool; - - /// Check if every part of this tool call is covered by at least one rule in - /// the set. For compound operations (e.g. shell pipelines), each sub-part - /// must be individually covered. The default treats the call as atomic — - /// any single matching rule is sufficient. - fn all_covered_by(&self, rules: &[Rule]) -> bool { - rules.iter().any(|r| self.matches_rule(r)) - } - - /// Returns the target directory of this tool call, if applicable, for checking against directory-based rules. - fn target_dir(&self) -> Option<&Path> { - None - } -} - -impl PermissibleToolCall for ClientToolCall { - fn matches_rule(&self, rule: &Rule) -> bool { - self.matches_rule(rule) - } - - fn all_covered_by(&self, rules: &[Rule]) -> bool { - match self { - ClientToolCall::Shell(tool) => tool.all_covered_by(rules), - // LoadSkill is always auto-approved, but support rules for completeness - _ => rules.iter().any(|r| self.matches_rule(r)), - } - } - - fn target_dir(&self) -> Option<&Path> { - self.target_dir() - } -} - -/// Returns true if this tool call should bypass the permission system entirely. -impl ClientToolCall { - pub(crate) fn is_auto_approved(&self) -> bool { - matches!(self, ClientToolCall::LoadSkill(_)) - } -} - -/// Expand shell constructs (`~`, `$HOME`, etc.) in a path string. -/// -/// Tool call paths arrive as raw strings from the API without shell -/// expansion. Uses `shellexpand` (same as `atuin-client`). -fn expand_path(path: &str) -> PathBuf { - PathBuf::from(shellexpand::tilde(path).into_owned()) -} - -#[derive(Debug, Clone)] -pub(crate) struct ReadToolCall { - pub path: PathBuf, - pub offset: u64, - pub limit: u64, -} - -impl TryFrom<&serde_json::Value> for ReadToolCall { - type Error = eyre::Error; - - fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> { - let path = value - .get("file_path") - .and_then(|v| v.as_str()) - .ok_or(eyre::eyre!("Missing path"))?; - - let offset = value.get("offset").and_then(|v| v.as_u64()).unwrap_or(0); - let limit = value - .get("limit") - .and_then(|v| v.as_u64()) - .unwrap_or(DEFAULT_FILE_READ_LINES) - .min(MAX_FILE_READ_LINES); - - Ok(ReadToolCall { - path: expand_path(path), - offset, - limit, - }) - } -} - -impl ReadToolCall { - pub fn resolved_path(&self) -> PathBuf { - if self.path.is_relative() { - std::env::current_dir() - .map(|cwd| cwd.join(&self.path)) - .unwrap_or_else(|_| self.path.clone()) - } else { - self.path.clone() - } - } - - pub fn execute(&self) -> ToolOutcome { - let path = self.resolved_path(); - - if !path.exists() { - return ToolOutcome::Error(format!("Error: file does not exist: {}", path.display())); - } - - if path.is_dir() { - let Some(files) = std::fs::read_dir(&path).ok().and_then(|entries| { - entries - .filter_map(|entry| entry.ok()) - .map(|entry| entry.file_name().to_string_lossy().to_string()) - .collect::<Vec<_>>() - .into() - }) else { - return ToolOutcome::Error(format!( - "Error: could not read directory: {}", - path.display() - )); - }; - - return ToolOutcome::Success(format!("Directory contents:\n{}", files.join("\n"))); - } - - let file = match std::fs::File::open(&path) { - Ok(file) => file, - Err(e) => return ToolOutcome::Error(format!("Error opening file: {e}")), - }; - let reader = std::io::BufReader::new(file); - - let raw_lines = reader - .lines() - .skip(self.offset as usize) - .take(self.limit as usize) - .collect::<Result<Vec<_>, _>>(); - - match raw_lines { - Ok(lines) => { - let first_line_no = self.offset as usize + 1; - let last_line_no = first_line_no + lines.len().saturating_sub(1); - let width = last_line_no.max(1).ilog10() as usize + 1; - - let numbered: String = lines - .iter() - .enumerate() - .map(|(i, line)| format!("{:>width$}\t{line}", first_line_no + i)) - .collect::<Vec<_>>() - .join("\n"); - - if numbered.len() > 100_000 { - ToolOutcome::Error(format!( - "Error: file is too large to read ({} bytes in {} lines); use view_range to read a subset of the file", - numbered.len(), - lines.len() - )) - } else { - ToolOutcome::Success(numbered) - } - } - Err(e) => ToolOutcome::Error(format!("Error reading file: {e}")), - } - } -} - -impl PermissibleToolCall for ReadToolCall { - fn target_dir(&self) -> Option<&Path> { - Some(&self.path) - } - - fn matches_rule(&self, rule: &Rule) -> bool { - // Write implies Read — a Write permission on a path also permits reading it. - if rule.tool != "Read" && rule.tool != "Write" { - return false; - } - - match rule.scope.as_deref() { - None | Some("*") => true, - Some(scope) => path_matches_scope(&self.path, scope), - } - } -} - -#[derive(Debug, Clone)] -pub(crate) struct EditToolCall { - pub path: PathBuf, - pub old_string: String, - pub new_string: String, - pub replace_all: bool, -} - -impl TryFrom<&serde_json::Value> for EditToolCall { - type Error = eyre::Error; - - fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> { - let path = value - .get("file_path") - .and_then(|v| v.as_str()) - .ok_or(eyre::eyre!("Missing file_path"))?; - - let old_string = value - .get("old_string") - .and_then(|v| v.as_str()) - .ok_or(eyre::eyre!("Missing old_string"))?; - - let new_string = value - .get("new_string") - .and_then(|v| v.as_str()) - .ok_or(eyre::eyre!("Missing new_string"))?; - - let replace_all = value - .get("replace_all") - .and_then(|v| v.as_bool()) - .unwrap_or(false); - - Ok(EditToolCall { - path: expand_path(path), - old_string: old_string.to_string(), - new_string: new_string.to_string(), - replace_all, - }) - } -} - -impl EditToolCall { - /// Resolve the edit path to an absolute path. - pub fn resolved_path(&self) -> PathBuf { - if self.path.is_relative() { - std::env::current_dir() - .map(|cwd| cwd.join(&self.path)) - .unwrap_or_else(|_| self.path.clone()) - } else { - self.path.clone() - } - } - - /// Execute the edit against the filesystem. - /// - /// Checks freshness via the provided tracker, validates matches, applies - /// the replacement, and writes atomically. Returns the outcome and (on - /// success) the new file content bytes for tracker updates. - /// - /// Callers should snapshot the file before calling this method and - /// update the file tracker after a successful return. - pub fn execute( - &self, - resolved_path: &Path, - file_tracker: &crate::file_tracker::FileReadTracker, - ) -> (ToolOutcome, Option<Vec<u8>>) { - use crate::file_tracker::FreshnessCheck; - - // 1. Basic validation - if !resolved_path.exists() { - return ( - ToolOutcome::Error(format!( - "Error: file does not exist: {}", - resolved_path.display() - )), - None, - ); - } - if resolved_path.is_dir() { - return ( - ToolOutcome::Error(format!( - "Error: path is a directory, not a file: {}", - resolved_path.display() - )), - None, - ); - } - if self.old_string.is_empty() { - return ( - ToolOutcome::Error( - "old_string must not be empty. To create a new file, use write_file instead." - .to_string(), - ), - None, - ); - } - - // 2. Freshness check - match file_tracker.check_freshness(resolved_path) { - Ok(FreshnessCheck::NotRead) => { - return ( - ToolOutcome::Error( - "File has not been read yet. Read it first before editing.".to_string(), - ), - None, - ); - } - Ok(FreshnessCheck::Stale) => { - return ( - ToolOutcome::Error( - "File has been modified since read, either by the user or by a linter. Read it again before attempting to edit it.".to_string(), - ), - None, - ); - } - Err(e) => { - return ( - ToolOutcome::Error(format!("Error checking file state: {e}")), - None, - ); - } - Ok(FreshnessCheck::Fresh) => {} - } - - // 3. Read current contents - let content = match std::fs::read_to_string(resolved_path) { - Ok(c) => c, - Err(e) => return (ToolOutcome::Error(format!("Error reading file: {e}")), None), - }; - - // 4. Find and validate matches - let match_count = content.matches(&self.old_string).count(); - - if match_count == 0 { - return ( - ToolOutcome::Error(format!( - "old_string not found in {}. Make sure it matches exactly, including whitespace and indentation.", - resolved_path.display() - )), - None, - ); - } - - if match_count > 1 && !self.replace_all { - return ( - ToolOutcome::Error(format!( - "Found {match_count} matches of old_string in {}, but replace_all is false. Either provide more context to make the match unique, or set replace_all to true.", - resolved_path.display() - )), - None, - ); - } - - // 5. Apply replacement - let new_content = if self.replace_all { - content.replace(&self.old_string, &self.new_string) - } else { - content.replacen(&self.old_string, &self.new_string, 1) - }; - - // 6. Write atomically - let new_bytes = new_content.into_bytes(); - if let Err(e) = crate::snapshots::atomic_write_file(resolved_path, &new_bytes) { - return (ToolOutcome::Error(format!("Error writing file: {e}")), None); - } - - // 7. Success - let verb = if match_count == 1 { - "occurrence" - } else { - "occurrences" - }; - ( - ToolOutcome::Success(format!( - "Edited {}: replaced {match_count} {verb} of old_string with new_string.", - resolved_path.display() - )), - Some(new_bytes), - ) - } -} - -impl PermissibleToolCall for EditToolCall { - fn target_dir(&self) -> Option<&Path> { - Some(&self.path) - } - - fn matches_rule(&self, rule: &Rule) -> bool { - if rule.tool != "Write" { - return false; - } - - match rule.scope.as_deref() { - None | Some("*") => true, - Some(scope) => path_matches_scope(&self.path, scope), - } - } -} - -#[derive(Debug, Clone)] -pub(crate) struct WriteToolCall { - pub path: PathBuf, - pub content: String, - pub overwrite: bool, -} - -impl TryFrom<&serde_json::Value> for WriteToolCall { - type Error = eyre::Error; - - fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> { - let path = value - .get("file_path") - .and_then(|v| v.as_str()) - .ok_or(eyre::eyre!("Missing file_path"))?; - - let content = value - .get("content") - .and_then(|v| v.as_str()) - .ok_or(eyre::eyre!("Missing content"))?; - - let overwrite = value - .get("overwrite") - .and_then(|v| v.as_bool()) - .unwrap_or(false); - - Ok(WriteToolCall { - path: expand_path(path), - content: content.to_string(), - overwrite, - }) - } -} - -impl WriteToolCall { - /// Resolve the write path to an absolute path. - pub fn resolved_path(&self) -> PathBuf { - if self.path.is_relative() { - std::env::current_dir() - .map(|cwd| cwd.join(&self.path)) - .unwrap_or_else(|_| self.path.clone()) - } else { - self.path.clone() - } - } - - /// Execute the write operation. - /// - /// Creates a new file or overwrites an existing one (if `overwrite` is set). - /// Returns the outcome and the written bytes (for tracker updates). - pub fn execute(&self, resolved_path: &Path) -> (ToolOutcome, Option<Vec<u8>>) { - if resolved_path.is_dir() { - return ( - ToolOutcome::Error(format!( - "Error: path is a directory, not a file: {}", - resolved_path.display() - )), - None, - ); - } - if resolved_path.exists() && !self.overwrite { - return ( - ToolOutcome::Error(format!( - "File already exists: {}. Set overwrite to true to replace it, or use edit_file to make targeted changes.", - resolved_path.display() - )), - None, - ); - } - - // Capture before the write — after atomic_write the file always exists. - let existed = resolved_path.exists(); - - // Write atomically - let content_bytes = self.content.as_bytes().to_vec(); - if let Err(e) = crate::snapshots::atomic_write_file(resolved_path, &content_bytes) { - return (ToolOutcome::Error(format!("Error writing file: {e}")), None); - } - - let line_count = self.content.lines().count(); - let verb = if existed { "Overwrote" } else { "Created" }; - ( - ToolOutcome::Success(format!( - "{verb} {} ({line_count} lines).", - resolved_path.display() - )), - Some(content_bytes), - ) - } -} - -impl PermissibleToolCall for WriteToolCall { - fn target_dir(&self) -> Option<&Path> { - Some(&self.path) - } - - fn matches_rule(&self, rule: &Rule) -> bool { - if rule.tool != "Write" { - return false; - } - - match rule.scope.as_deref() { - None | Some("*") => true, - Some(scope) => path_matches_scope(&self.path, scope), - } - } -} - -#[derive(Debug, Clone)] -pub(crate) struct ShellToolCall { - pub dir: Option<PathBuf>, - pub command: String, - pub shell: String, - /// Maximum execution time in seconds (from LLM). Clamped to 1..=600, default 30. - pub timeout_secs: u64, - // allow dead code here; this will be tied into o11y and user-facing descriptions - #[expect(dead_code)] - pub description: Option<String>, -} - -impl TryFrom<&serde_json::Value> for ShellToolCall { - type Error = eyre::Error; - - fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> { - let dir = value.get("dir").and_then(|v| v.as_str()); - - let command = value - .get("command") - .and_then(|v| v.as_str()) - .ok_or(eyre::eyre!("Missing command"))?; - - let shell = value - .get("shell") - .and_then(|v| v.as_str()) - .unwrap_or("bash") - .to_string(); - - let timeout_secs = value - .get("timeout") - .and_then(|v| v.as_u64()) - .filter(|&v| v > 0) - .unwrap_or(30) - .min(600); - - let description = value - .get("description") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - - Ok(ShellToolCall { - dir: dir.map(expand_path), - command: command.to_string(), - shell, - timeout_secs, - description, - }) - } -} - -impl PermissibleToolCall for ShellToolCall { - fn target_dir(&self) -> Option<&Path> { - self.dir.as_deref() - } - - fn matches_rule(&self, rule: &Rule) -> bool { - if rule.tool != "Shell" { - return false; - } - - let Some(scope) = rule.scope.as_ref() else { - // Shell without scope matches all shell commands - return true; - }; - - let shell_kind = crate::permissions::shell::ShellKind::from_shell_name(&self.shell); - let parsed = crate::permissions::shell::parse_shell_command(&self.command, shell_kind); - // Deny/ask path: prefix_bare = true so `deny = ["Shell(rm)"]` blocks `rm -rf /` - crate::permissions::shell::any_subcommand_matches(&parsed.subcommands, true, scope) - } - - /// For compound shell commands, every subcommand must be individually - /// covered by at least one rule. This ensures that `allow = ["Shell(git *)"]` - /// does not silently permit `git add . && rm -rf /`. - fn all_covered_by(&self, rules: &[Rule]) -> bool { - use crate::permissions::shell; - - let shell_kind = shell::ShellKind::from_shell_name(&self.shell); - let parsed = shell::parse_shell_command(&self.command, shell_kind); - - // If parsing yields nothing, don't vacuously allow — fall through to ask. - !parsed.subcommands.is_empty() - && parsed.subcommands.iter().all(|subcmd| { - rules.iter().any(|rule| { - if rule.tool != "Shell" { - return false; - } - match rule.scope.as_deref() { - None | Some("*") => true, - // Allow path: prefix_bare = false so `Shell(git commit)` - // only allows exactly `git commit`, not `git commit --amend` - Some(scope) => shell::any_subcommand_matches( - std::slice::from_ref(subcmd), - false, - scope, - ), - } - }) - }) - } -} - -/// Preview viewport height for VT100 emulation. -const PREVIEW_HEIGHT: u16 = 10; - -/// Default terminal width for VT100 emulation. -const PREVIEW_WIDTH: u16 = 120; - -/// Normalize newlines for VT100 processing. -/// -/// When subprocess output is captured via pipes (no PTY), bare `\n` (LF) bytes -/// are not translated to `\r\n` (CR+LF) the way a kernel terminal driver would -/// with the `ONLCR` flag. In VT100, LF only moves the cursor down without -/// returning to column 0. This causes lines to start at progressively higher -/// column offsets and eventually wrap, producing garbled output. -/// -/// This function inserts `\r` before any `\n` that isn't already preceded by -/// `\r`, mimicking the terminal driver's ONLCR behavior. -fn normalize_newlines_for_vt100(data: &[u8]) -> Vec<u8> { - let mut out = Vec::with_capacity(data.len() + data.len() / 8); - for (i, &b) in data.iter().enumerate() { - if b == b'\n' && (i == 0 || data[i - 1] != b'\r') { - out.push(b'\r'); - } - out.push(b); - } - out -} - -/// Extract plain text lines from a VT100 screen buffer. -/// -/// Strips trailing blank lines so the result only contains rows with actual -/// content. Without this, the fixed-size VT100 screen (PREVIEW_HEIGHT rows) -/// would always return that many lines, and downstream components that use -/// tail-mode display (like the Viewport) would show the blank padding rows -/// instead of the real output. -fn vt100_screen_lines(screen: &vt100::Screen) -> Vec<String> { - let (rows, cols) = screen.size(); - let mut lines = Vec::with_capacity(rows as usize); - for row in 0..rows { - let mut line = String::with_capacity(cols as usize); - for col in 0..cols { - if let Some(cell) = screen.cell(row, col) { - line.push_str(cell.contents()); - } - } - lines.push(line.trim_end().to_string()); - } - while lines.last().is_some_and(|l| l.is_empty()) { - lines.pop(); - } - lines -} - -/// Strip ANSI escape sequences from raw bytes using a VT100 parser. -/// -/// Uses a large virtual screen so scrollback is preserved, then extracts -/// the plain text contents. This handles all escape sequences (colors, -/// cursor movement, progress bars, etc.) not just simple SGR codes. -fn strip_ansi_via_vt100(raw: &[u8]) -> String { - if raw.is_empty() { - return String::new(); - } - // Normalize bare LF to CR+LF so lines start at column 0 in the VT100 screen. - let normalized = normalize_newlines_for_vt100(raw); - // Feed bytes into a VT100 parser large enough to hold all output, then - // read back the plain text. We estimate rows from the number of newlines - // (not total byte length) because real output typically has short lines - // that would be severely under-counted by a bytes÷width estimate. - let newline_count = normalized.iter().filter(|&&b| b == b'\n').count(); - let wrap_estimate = normalized.len() / PREVIEW_WIDTH as usize; - let estimated_rows = (newline_count + wrap_estimate + 1).min(10_000) as u16; - let mut parser = vt100::Parser::new(estimated_rows, PREVIEW_WIDTH, 0); - parser.process(&normalized); - let screen = parser.screen(); - // screen.contents() returns the full plain-text content with trailing - // whitespace trimmed per line and trailing blank lines removed. - screen.contents() -} - -/// Execute a shell command with VT100 emulation and streaming output. -/// -/// Feeds stdout+stderr into a `vt100::Parser` so that ANSI escape sequences, -/// progress bars (`\r`), and cursor movement are handled correctly. Periodically -/// sends the current screen state as `Vec<String>` through `output_tx` for the -/// live preview. -/// -/// Captures the FULL stdout and stderr separately for the tool result sent to the LLM. -/// Returns a `ToolOutcome::Structured` with full output, exit code, and duration. -pub(crate) async fn execute_shell_command_streaming( - shell_call: &ShellToolCall, - output_tx: tokio::sync::mpsc::Sender<Vec<String>>, - mut interrupt_rx: tokio::sync::oneshot::Receiver<()>, -) -> ToolOutcome { - use tokio::io::AsyncReadExt; - - let start = std::time::Instant::now(); - - // TODO: check if this is proper for all shells we support - let mut cmd = tokio::process::Command::new(&shell_call.shell); - cmd.arg("-c").arg(&shell_call.command); - cmd.stdout(std::process::Stdio::piped()); - cmd.stderr(std::process::Stdio::piped()); - - if let Some(ref dir) = shell_call.dir { - cmd.current_dir(dir); - } - - let mut child = match cmd.spawn() { - Ok(child) => child, - Err(e) => return ToolOutcome::Error(format!("Failed to spawn command: {e}")), - }; - - let stdout = child.stdout.take().expect("stdout was piped"); - let stderr = child.stderr.take().expect("stderr was piped"); - - // VT100 emulator for the live preview (viewport-sized) - let mut parser = vt100::Parser::new(PREVIEW_HEIGHT, PREVIEW_WIDTH, 0); - - let mut stdout_reader = tokio::io::BufReader::new(stdout); - let mut stderr_reader = tokio::io::BufReader::new(stderr); - - let mut stdout_buf = [0u8; 4096]; - let mut stderr_buf = [0u8; 4096]; - let mut stdout_done = false; - let mut stderr_done = false; - - // Full output buffers (for the LLM, not the preview) - let mut full_stdout = Vec::<u8>::new(); - let mut full_stderr = Vec::<u8>::new(); - - let mut interval = tokio::time::interval(Duration::from_millis(50)); - - // Send initial empty screen - let initial_lines = vt100_screen_lines(parser.screen()); - let _ = output_tx.send(initial_lines).await; - - let mut interrupted = false; - - loop { - tokio::select! { - biased; - - // Check for interrupt signal - _ = &mut interrupt_rx, if !interrupted => { - interrupted = true; - let _ = child.start_kill(); - } - - // Read stdout - result = stdout_reader.read(&mut stdout_buf), if !stdout_done => { - match result { - Ok(0) => stdout_done = true, - Ok(n) => { - full_stdout.extend_from_slice(&stdout_buf[..n]); - let normalized = normalize_newlines_for_vt100(&stdout_buf[..n]); - parser.process(&normalized); - } - Err(_) => stdout_done = true, - } - } - - // Read stderr - result = stderr_reader.read(&mut stderr_buf), if !stderr_done => { - match result { - Ok(0) => stderr_done = true, - Ok(n) => { - full_stderr.extend_from_slice(&stderr_buf[..n]); - // Feed stderr to the preview parser too, so it shows in the VT100 screen - let normalized = normalize_newlines_for_vt100(&stderr_buf[..n]); - parser.process(&normalized); - } - Err(_) => stderr_done = true, - } - } - - // Periodic screen snapshot for preview - _ = interval.tick() => { - let lines = vt100_screen_lines(parser.screen()); - let _ = output_tx.send(lines).await; - } - } - - // Exit when both streams are done - if stdout_done && stderr_done { - break; - } - } - - // Wait for process to finish - let exit_code = match child.wait().await { - Ok(status) => status.code(), - Err(e) => { - if interrupted { - None - } else { - return ToolOutcome::Error(format!("Failed to wait for command: {e}")); - } - } - }; - - let duration = start.elapsed(); - - // Send final screen state - let final_lines = vt100_screen_lines(parser.screen()); - let _ = output_tx.send(final_lines).await; - - // Strip ANSI escape sequences for clean LLM output by running - // the raw bytes through a VT100 parser and extracting plain text. - let stdout_text = strip_ansi_via_vt100(&full_stdout); - let stderr_text = strip_ansi_via_vt100(&full_stderr); - - ToolOutcome::Structured { - stdout: stdout_text, - stderr: stderr_text, - exit_code, - duration_ms: duration.as_millis() as u64, - interrupted, - } -} - -#[derive(Debug, Clone)] -pub(crate) struct AtuinHistoryToolCall { - pub filter_modes: Vec<HistorySearchFilterMode>, - pub query: String, - pub limit: i64, -} - -#[derive(Debug, Clone)] -pub(crate) enum HistorySearchFilterMode { - Global, - Host, - Session, - Directory, - Workspace, -} - -impl From<&HistorySearchFilterMode> for atuin_client::settings::FilterMode { - fn from(mode: &HistorySearchFilterMode) -> Self { - match mode { - HistorySearchFilterMode::Global => Self::Global, - HistorySearchFilterMode::Host => Self::Host, - HistorySearchFilterMode::Session => Self::Session, - HistorySearchFilterMode::Directory => Self::Directory, - HistorySearchFilterMode::Workspace => Self::Workspace, - } - } -} - -impl TryFrom<&serde_json::Value> for AtuinHistoryToolCall { - type Error = eyre::Error; - - fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> { - let filter_modes = value - .get("filter_modes") - .and_then(|v| v.as_array()) - .ok_or(eyre::eyre!("Missing filter_modes"))?; - - let filter_modes = filter_modes - .iter() - .map(|v| { - let mode = v.as_str().ok_or(eyre::eyre!("Invalid filter mode"))?; - match mode { - "global" => Ok(HistorySearchFilterMode::Global), - "host" => Ok(HistorySearchFilterMode::Host), - "session" => Ok(HistorySearchFilterMode::Session), - "directory" => Ok(HistorySearchFilterMode::Directory), - "workspace" => Ok(HistorySearchFilterMode::Workspace), - _ => Err(eyre::eyre!("Invalid filter mode: {mode}")), - } - }) - .collect::<Result<Vec<HistorySearchFilterMode>>>()?; - - let query = value - .get("query") - .and_then(|v| v.as_str()) - .ok_or(eyre::eyre!("Missing query"))?; - - let limit = value - .get("limit") - .and_then(|v| v.as_i64()) - .unwrap_or(10) - .clamp(1, 50); - - Ok(AtuinHistoryToolCall { - filter_modes, - query: query.to_string(), - limit, - }) - } -} - -impl PermissibleToolCall for AtuinHistoryToolCall { - fn target_dir(&self) -> Option<&Path> { - None - } - - fn matches_rule(&self, rule: &Rule) -> bool { - rule.tool == "AtuinHistory" - } -} - -impl AtuinHistoryToolCall { - pub(crate) async fn execute(&self, db: &atuin_client::database::Sqlite) -> ToolOutcome { - use atuin_client::database::{self, Database as _, OptFilters}; - use atuin_client::settings::SearchMode; - - let context = match database::current_context().await { - Ok(ctx) => ctx, - Err(e) => return ToolOutcome::Error(format!("Failed to get history context: {e}")), - }; - - let filter_mode = self - .filter_modes - .first() - .map(atuin_client::settings::FilterMode::from) - .unwrap_or(atuin_client::settings::FilterMode::Global); - - let filter_options = OptFilters { - limit: Some(self.limit), - ..Default::default() - }; - - let results = match db - .search( - SearchMode::Fuzzy, - filter_mode, - &context, - &self.query, - filter_options, - ) - .await - { - Ok(results) => results, - Err(e) => return ToolOutcome::Error(format!("History search failed: {e}")), - }; - - if results.is_empty() { - return ToolOutcome::Success("No matching history entries found.".to_string()); - } - - let local_offset = crate::history_format::current_local_offset(); - - let formatted: Vec<String> = results - .iter() - .enumerate() - .map(|(i, history)| { - crate::history_format::format_history_search_result(i + 1, history, local_offset) - }) - .collect(); - - ToolOutcome::Success(formatted.join("\n")) - } -} - -#[derive(Debug, Clone)] -pub(crate) struct AtuinOutputToolCall { - pub history_id: Uuid, - pub ranges: Vec<(i64, i64)>, -} - -impl TryFrom<&serde_json::Value> for AtuinOutputToolCall { - type Error = eyre::Error; - - fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> { - let history_id = value - .get("history_id") - .and_then(|v| v.as_str()) - .and_then(|v| Uuid::parse_str(v).ok()) - .ok_or(eyre::eyre!("Missing or invalid history ID"))?; - - let ranges = value - .get("ranges") - .and_then(|v| v.as_array()) - .map(Vec::as_slice) - .unwrap_or(&[]); - - let ranges = ranges - .iter() - .map(|r| { - let range = r - .as_array() - .filter(|a| a.len() == 2) - .ok_or_else(|| eyre::eyre!("Each range must be a [start, end] array"))?; - - let start = range[0] - .as_i64() - .ok_or_else(|| eyre::eyre!("Range start must be an integer"))?; - let end = range[1] - .as_i64() - .ok_or_else(|| eyre::eyre!("Range end must be an integer"))?; - - Ok((start, end)) - }) - .collect::<Result<Vec<(i64, i64)>, eyre::Error>>()?; - - Ok(Self { history_id, ranges }) - } -} - -impl PermissibleToolCall for AtuinOutputToolCall { - fn target_dir(&self) -> Option<&Path> { - None - } - - fn matches_rule(&self, rule: &Rule) -> bool { - rule.tool == "AtuinOutput" - } -} - -fn format_output_lines_for_llm(lines: &[atuin_daemon::semantic::OutputLine]) -> String { - let width = lines - .iter() - .map(|line| line.line_number) - .max() - .unwrap_or(1) - .max(1) - .ilog10() as usize - + 1; - let mut formatted = Vec::with_capacity(lines.len()); - let mut previous_line_number = None; - - for line in lines { - if let Some(previous) = previous_line_number { - let skipped = line.line_number.saturating_sub(previous + 1); - if skipped > 0 { - formatted.push(format!("[...skipped {skipped} lines...]")); - } - } - - formatted.push(format!("{:>width$}\t{}", line.line_number, line.content)); - previous_line_number = Some(line.line_number); - } - - formatted.join("\n") -} - -impl AtuinOutputToolCall { - pub(crate) async fn execute(&self) -> ToolOutcome { - let settings = match atuin_client::settings::Settings::new() { - Ok(settings) => settings, - Err(e) => return ToolOutcome::Error(format!("Failed to load Atuin settings: {e}")), - }; - - let mut client = match atuin_daemon::SemanticClient::from_settings(&settings).await { - Ok(client) => client, - Err(e) => return ToolOutcome::Error(format!("Failed to connect to Atuin daemon: {e}")), - }; - - let history_id = self.history_id.as_simple().to_string(); - let response = match client - .command_output(history_id.clone(), self.ranges.clone()) - .await - { - Ok(response) => response, - Err(e) => return ToolOutcome::Error(format!("Failed to fetch command output: {e}")), - }; - - if !response.found { - return ToolOutcome::Success(format!( - "No captured output found for history ID {history_id}." - )); - } - - if response.total_lines == 0 { - return ToolOutcome::Success(format!( - "Captured output for history ID {history_id} is empty." - )); - } - - let output = format_output_lines_for_llm(&response.lines); - if output.is_empty() { - return ToolOutcome::Success(format!( - "No lines selected from captured output for history ID {history_id}." - )); - } - - let total_output = if response.output_truncated { - format!( - "{} bytes captured, {} bytes observed before truncation, {} lines", - response.total_bytes, response.output_observed_bytes, response.total_lines - ) - } else { - format!( - "{} bytes, {} lines", - response.total_bytes, response.total_lines - ) - }; - - ToolOutcome::Success(format!( - "History ID: {history_id}\nTotal output: {total_output}\nSelected output:\n{output}" - )) - } -} - -#[derive(Debug, Clone)] -pub(crate) struct LoadSkillToolCall { - pub name: String, -} - -impl TryFrom<&serde_json::Value> for LoadSkillToolCall { - type Error = eyre::Error; - - fn try_from(value: &serde_json::Value) -> Result<Self, Self::Error> { - let name = value - .get("name") - .and_then(|v| v.as_str()) - .ok_or(eyre::eyre!("Missing skill name"))?; - - Ok(LoadSkillToolCall { - name: name.to_string(), - }) - } -} - -impl PermissibleToolCall for LoadSkillToolCall { - fn target_dir(&self) -> Option<&Path> { - None - } - - fn matches_rule(&self, rule: &Rule) -> bool { - rule.tool == "LoadSkill" - } -} - -#[cfg(test)] -mod tests { - use super::*; - - fn read_rule(scope: Option<&str>) -> Rule { - Rule { - tool: "Read".to_string(), - scope: scope.map(String::from), - } - } - - fn write_rule(scope: Option<&str>) -> Rule { - Rule { - tool: "Write".to_string(), - scope: scope.map(String::from), - } - } - - fn read_tool(path: &str) -> ReadToolCall { - ReadToolCall { - path: expand_path(path), - offset: 0, - limit: 100, - } - } - - fn write_tool(path: &str) -> WriteToolCall { - WriteToolCall { - path: expand_path(path), - content: String::new(), - overwrite: false, - } - } - - // ── Cross-platform tests ── - - #[test] - fn atuin_output_ranges_are_optional() { - let input = serde_json::json!({ - "history_id": "018f0000000070008000000000000000" - }); - - let call = AtuinOutputToolCall::try_from(&input).unwrap(); - - assert_eq!( - call.history_id.as_simple().to_string(), - "018f0000000070008000000000000000" - ); - assert!(call.ranges.is_empty()); - } - - #[test] - fn atuin_output_parses_line_ranges() { - let input = serde_json::json!({ - "history_id": "018f0000000070008000000000000000", - "ranges": [[0, 30], [-100, -1]] - }); - - let call = AtuinOutputToolCall::try_from(&input).unwrap(); - - assert_eq!(call.ranges, vec![(0, 30), (-100, -1)]); - } - - #[test] - fn atuin_output_formats_lines_like_read_file() { - let lines = vec![ - atuin_daemon::semantic::OutputLine { - line_number: 98, - content: "near end".to_string(), - }, - atuin_daemon::semantic::OutputLine { - line_number: 100, - content: "end".to_string(), - }, - ]; - - assert_eq!( - format_output_lines_for_llm(&lines), - " 98\tnear end\n[...skipped 1 lines...]\n100\tend" - ); - } - - #[test] - fn no_scope_matches_everything() { - assert!(read_tool("any/path.txt").matches_rule(&read_rule(None))); - assert!(write_tool("any/path.txt").matches_rule(&write_rule(None))); - } - - #[test] - fn wildcard_star_matches_everything() { - assert!(read_tool("foo/bar.rs").matches_rule(&read_rule(Some("*")))); - } - - #[test] - fn write_implies_read() { - // A Write rule also permits reads on the same path - assert!(read_tool("foo.txt").matches_rule(&write_rule(None))); - // But a Read rule does not permit writes - assert!(!write_tool("foo.txt").matches_rule(&read_rule(None))); - } - - #[test] - fn edit_uses_write_rule() { - let edit = EditToolCall { - path: expand_path("/home/user/config.toml"), - old_string: "x".into(), - new_string: "y".into(), - replace_all: false, - }; - assert!(edit.matches_rule(&write_rule(None))); - assert!(!edit.matches_rule(&read_rule(None))); - } - - #[test] - fn extension_glob() { - assert!(read_tool("notes.md").matches_rule(&read_rule(Some("*.md")))); - assert!(!read_tool("notes.txt").matches_rule(&read_rule(Some("*.md")))); - } - - #[test] - fn relative_multi_segment_glob() { - // This matches against the path relative to cwd - let cwd = std::env::current_dir().unwrap(); - let abs = cwd - .join("crates") - .join("atuin-ai") - .join("src") - .join("lib.rs"); - let tool = read_tool(abs.to_str().unwrap()); - assert!(tool.matches_rule(&read_rule(Some("crates/**/*.rs")))); - assert!(!tool.matches_rule(&read_rule(Some("crates/**/*.py")))); - } - - // ── all_covered_by tests (compound shell command semantics) ── - - fn shell_rule(scope: Option<&str>) -> Rule { - Rule { - tool: "Shell".to_string(), - scope: scope.map(String::from), - } - } - - fn shell_tool(command: &str) -> ShellToolCall { - ShellToolCall { - dir: None, - command: command.to_string(), - shell: "bash".to_string(), - timeout_secs: 30, - description: None, - } - } - - #[test] - fn all_covered_by_simple_command() { - let rules = vec![shell_rule(Some("git *"))]; - assert!(shell_tool("git add .").all_covered_by(&rules)); - assert!(!shell_tool("npm test").all_covered_by(&rules)); - } - - #[test] - fn all_covered_by_compound_all_covered() { - let rules = vec![shell_rule(Some("git *")), shell_rule(Some("npm *"))]; - assert!(shell_tool("git add . && npm test").all_covered_by(&rules)); - } - - #[test] - fn all_covered_by_compound_partially_covered() { - // Only git is allowed — npm subcommand is not covered, so the - // compound command must not be auto-allowed. - let rules = vec![shell_rule(Some("git *"))]; - assert!(!shell_tool("git add . && npm test").all_covered_by(&rules)); - } - - #[test] - fn all_covered_by_unscoped_shell_rule() { - // Shell without scope covers everything - let rules = vec![shell_rule(None)]; - assert!(shell_tool("git add . && rm -rf /").all_covered_by(&rules)); - } - - #[test] - fn all_covered_by_wildcard_shell_rule() { - let rules = vec![shell_rule(Some("*"))]; - assert!(shell_tool("git add . && npm test").all_covered_by(&rules)); - } - - #[test] - fn all_covered_by_non_shell_tool_unchanged() { - // Non-shell tools use the default (any single rule matches) - let rules = vec![read_rule(Some("*.md"))]; - assert!(read_tool("notes.md").all_covered_by(&rules)); - assert!(!read_tool("notes.txt").all_covered_by(&rules)); - } - - #[test] - fn matches_rule_still_uses_any_semantics() { - // matches_rule (used for deny/ask) still triggers on any subcommand - let rule = shell_rule(Some("rm *")); - assert!(shell_tool("git add . && rm -rf /").matches_rule(&rule)); - } - - #[test] - fn bare_pattern_asymmetry() { - // Deny (matches_rule, prefix_bare=true): bare "rm" blocks "rm -rf /" - let deny_rule = shell_rule(Some("rm")); - assert!(shell_tool("rm -rf /").matches_rule(&deny_rule)); - - // Allow (all_covered_by, prefix_bare=false): bare "rm" only allows exactly "rm" - let allow_rules = vec![shell_rule(Some("rm"))]; - assert!(shell_tool("rm").all_covered_by(&allow_rules)); - assert!(!shell_tool("rm -rf /").all_covered_by(&allow_rules)); - - // Bare prefix match is word-boundary, not substring — "rm" must not match "rmbackup" - assert!(!shell_tool("rmbackup").matches_rule(&deny_rule)); - assert!(!shell_tool("rmbackup /tmp").matches_rule(&deny_rule)); - } - - // ── Unix-specific tests (absolute paths with forward slashes) ── - - #[cfg(unix)] - mod unix { - use super::*; - - #[test] - fn absolute_glob() { - assert!( - read_tool("/home/user/src/main.rs") - .matches_rule(&read_rule(Some("/home/user/src/*.rs"))) - ); - assert!( - !read_tool("/home/user/docs/readme.md") - .matches_rule(&read_rule(Some("/home/user/src/*.rs"))) - ); - } - - #[test] - fn double_star_glob() { - assert!( - read_tool("/project/crates/foo/src/lib.rs") - .matches_rule(&read_rule(Some("/project/crates/**/*.rs"))) - ); - assert!( - !read_tool("/project/crates/foo/src/lib.py") - .matches_rule(&read_rule(Some("/project/crates/**/*.rs"))) - ); - } - } - - // ── edit_file execution tests ── - - mod edit { - use super::*; - use crate::file_tracker::FileReadTracker; - - /// Helper: create a temp file (with a closed handle), record it in a tracker. - /// Returns the TempDir (keeps the path alive) and tracker. - /// The file handle is closed so atomic_write_file can rename over it on Windows. - fn setup_tracked_file(content: &str) -> (tempfile::TempDir, PathBuf, FileReadTracker) { - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("test_file.toml"); - std::fs::write(&path, content).unwrap(); - - let file_content = std::fs::read(&path).unwrap(); - let mtime = std::fs::metadata(&path).unwrap().modified().unwrap(); - - let mut tracker = FileReadTracker::default(); - tracker.record_read(path.clone(), &file_content, mtime); - - (dir, path, tracker) - } - - fn edit_call(path: &Path, old: &str, new: &str, replace_all: bool) -> EditToolCall { - EditToolCall { - path: path.to_path_buf(), - old_string: old.to_string(), - new_string: new.to_string(), - replace_all, - } - } - - #[test] - fn successful_single_replacement() { - let (_dir, path, tracker) = setup_tracked_file("[section]\nkey = old_value\n"); - - let call = edit_call(&path, "old_value", "new_value", false); - let (outcome, new_bytes) = call.execute(&path, &tracker); - - assert!(matches!(outcome, ToolOutcome::Success(_))); - assert!(new_bytes.is_some()); - assert_eq!( - std::fs::read_to_string(&path).unwrap(), - "[section]\nkey = new_value\n" - ); - } - - #[test] - fn successful_replace_all() { - let (_dir, path, tracker) = setup_tracked_file("aaa bbb aaa ccc aaa"); - - let call = edit_call(&path, "aaa", "xxx", true); - let (outcome, _) = call.execute(&path, &tracker); - - assert!(matches!(outcome, ToolOutcome::Success(ref s) if s.contains("3 occurrences"))); - assert_eq!( - std::fs::read_to_string(&path).unwrap(), - "xxx bbb xxx ccc xxx" - ); - } - - #[test] - fn error_file_not_read() { - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("unread.txt"); - std::fs::write(&path, "content").unwrap(); - let tracker = FileReadTracker::default(); // empty — never read - - let call = edit_call(&path, "x", "y", false); - let (outcome, new_bytes) = call.execute(&path, &tracker); - - assert!(new_bytes.is_none()); - match outcome { - ToolOutcome::Error(msg) => { - assert!(msg.contains("not been read yet"), "got: {msg}"); - } - _ => panic!("expected error"), - } - } - - #[test] - fn error_file_modified_since_read() { - let (_dir, path, tracker) = setup_tracked_file("original"); - - // Modify the file after the read was recorded - std::thread::sleep(std::time::Duration::from_millis(10)); - std::fs::write(&path, "modified externally").unwrap(); - - let call = edit_call(&path, "original", "replaced", false); - let (outcome, new_bytes) = call.execute(&path, &tracker); - - assert!(new_bytes.is_none()); - match outcome { - ToolOutcome::Error(msg) => { - assert!(msg.contains("modified since read"), "got: {msg}"); - } - _ => panic!("expected error"), - } - } - - #[test] - fn error_no_match() { - let (_dir, path, tracker) = setup_tracked_file("hello world"); - - let call = edit_call(&path, "nonexistent", "replacement", false); - let (outcome, new_bytes) = call.execute(&path, &tracker); - - assert!(new_bytes.is_none()); - match outcome { - ToolOutcome::Error(msg) => { - assert!(msg.contains("not found"), "got: {msg}"); - } - _ => panic!("expected error"), - } - } - - #[test] - fn error_multiple_matches_without_replace_all() { - let (_dir, path, tracker) = setup_tracked_file("foo bar foo baz foo"); - - let call = edit_call(&path, "foo", "qux", false); - let (outcome, new_bytes) = call.execute(&path, &tracker); - - assert!(new_bytes.is_none()); - match outcome { - ToolOutcome::Error(msg) => { - assert!(msg.contains("3 matches"), "got: {msg}"); - assert!(msg.contains("replace_all"), "got: {msg}"); - } - _ => panic!("expected error"), - } - // File should be unchanged - assert_eq!( - std::fs::read_to_string(&path).unwrap(), - "foo bar foo baz foo" - ); - } - - #[test] - fn error_empty_old_string() { - let (_dir, path, tracker) = setup_tracked_file("content"); - - let call = edit_call(&path, "", "something", false); - let (outcome, new_bytes) = call.execute(&path, &tracker); - - assert!(new_bytes.is_none()); - assert!(matches!(outcome, ToolOutcome::Error(_))); - } - - #[test] - fn error_file_does_not_exist() { - let tracker = FileReadTracker::default(); - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("nonexistent.txt"); - - let call = edit_call(&path, "x", "y", false); - let (outcome, new_bytes) = call.execute(&path, &tracker); - - assert!(new_bytes.is_none()); - match outcome { - ToolOutcome::Error(msg) => { - assert!(msg.contains("does not exist"), "got: {msg}"); - } - _ => panic!("expected error"), - } - } - - #[test] - fn preserves_file_when_no_match() { - let original = "[config]\nport = 8080\nhost = localhost\n"; - let (_dir, path, tracker) = setup_tracked_file(original); - - let call = edit_call(&path, "port = 9090", "port = 3000", false); - let (outcome, _) = call.execute(&path, &tracker); - - assert!(matches!(outcome, ToolOutcome::Error(_))); - assert_eq!(std::fs::read_to_string(&path).unwrap(), original); - } - - #[test] - fn multiline_replacement() { - let content = "[section]\nkey1 = val1\nkey2 = val2\n[other]\n"; - let (_dir, path, tracker) = setup_tracked_file(content); - - let call = edit_call( - &path, - "key1 = val1\nkey2 = val2", - "key1 = new1\nkey2 = new2", - false, - ); - let (outcome, new_bytes) = call.execute(&path, &tracker); - - assert!(matches!(outcome, ToolOutcome::Success(_))); - assert!(new_bytes.is_some()); - assert_eq!( - std::fs::read_to_string(&path).unwrap(), - "[section]\nkey1 = new1\nkey2 = new2\n[other]\n" - ); - } - } - - // ── Integration tests: full edit lifecycle ── - // - // These exercise the cross-component flow that dispatch orchestrates: - // FileReadTracker → SnapshotStore → EditToolCall.execute → tracker update - - mod edit_integration { - use super::*; - use crate::edit_permissions::EditPermissionCache; - use crate::file_tracker::FileReadTracker; - use crate::snapshots::SnapshotStore; - - /// Simulate a file read (what dispatch does after ReadToolCall.execute). - fn simulate_read(tracker: &mut FileReadTracker, path: &std::path::Path) { - let content = std::fs::read(path).unwrap(); - let mtime = std::fs::metadata(path).unwrap().modified().unwrap(); - tracker.record_read(path.to_path_buf(), &content, mtime); - } - - /// Simulate a tracker update after edit (what dispatch does after execute). - fn simulate_tracker_update( - tracker: &mut FileReadTracker, - path: &std::path::Path, - new_bytes: &[u8], - ) { - let mtime = std::fs::metadata(path).unwrap().modified().unwrap(); - tracker.update_after_edit(path, new_bytes, mtime); - } - - #[test] - fn full_read_snapshot_edit_cycle() { - let dir = tempfile::tempdir().unwrap(); - let file_path = dir.path().join("config.toml"); - std::fs::write(&file_path, "[db]\nhost = localhost\nport = 5432\n").unwrap(); - - let snapshot_dir = dir.path().join("snapshots").join("session-1"); - let mut tracker = FileReadTracker::default(); - let mut store = SnapshotStore::open(snapshot_dir.clone()).unwrap(); - - // 1. Simulate reading the file - simulate_read(&mut tracker, &file_path); - - // 2. Snapshot before edit - let original = std::fs::read(&file_path).unwrap(); - store.ensure_snapshot(&file_path, &original).unwrap(); - - // 3. Execute edit - let call = EditToolCall { - path: file_path.clone(), - old_string: "host = localhost".to_string(), - new_string: "host = 10.0.0.1".to_string(), - replace_all: false, - }; - let (outcome, new_bytes) = call.execute(&file_path, &tracker); - assert!(matches!(outcome, ToolOutcome::Success(_))); - let new_bytes = new_bytes.unwrap(); - - // 4. Update tracker (simulating what dispatch does) - simulate_tracker_update(&mut tracker, &file_path, &new_bytes); - - // Verify: file was edited - assert_eq!( - std::fs::read_to_string(&file_path).unwrap(), - "[db]\nhost = 10.0.0.1\nport = 5432\n" - ); - - // Verify: snapshot has original content - assert!(store.has_snapshot(&file_path)); - let snapshot_name = crate::snapshots::sanitize_path(&file_path); - let snapshot_content = - std::fs::read_to_string(snapshot_dir.join(snapshot_name)).unwrap(); - assert_eq!(snapshot_content, "[db]\nhost = localhost\nport = 5432\n"); - } - - #[test] - fn second_edit_without_reread() { - let dir = tempfile::tempdir().unwrap(); - let file_path = dir.path().join("config.toml"); - std::fs::write(&file_path, "key1 = aaa\nkey2 = bbb\n").unwrap(); - - let mut tracker = FileReadTracker::default(); - - // Read the file - simulate_read(&mut tracker, &file_path); - - // First edit - let call1 = EditToolCall { - path: file_path.clone(), - old_string: "key1 = aaa".to_string(), - new_string: "key1 = xxx".to_string(), - replace_all: false, - }; - let (outcome, new_bytes) = call1.execute(&file_path, &tracker); - assert!(matches!(outcome, ToolOutcome::Success(_))); - simulate_tracker_update(&mut tracker, &file_path, &new_bytes.unwrap()); - - // Second edit — should work without re-reading because tracker was updated - let call2 = EditToolCall { - path: file_path.clone(), - old_string: "key2 = bbb".to_string(), - new_string: "key2 = yyy".to_string(), - replace_all: false, - }; - let (outcome, new_bytes) = call2.execute(&file_path, &tracker); - assert!(matches!(outcome, ToolOutcome::Success(_))); - assert!(new_bytes.is_some()); - assert_eq!( - std::fs::read_to_string(&file_path).unwrap(), - "key1 = xxx\nkey2 = yyy\n" - ); - } - - #[test] - fn external_modification_between_edits() { - let dir = tempfile::tempdir().unwrap(); - let file_path = dir.path().join("config.toml"); - std::fs::write(&file_path, "value = original\n").unwrap(); - - let mut tracker = FileReadTracker::default(); - simulate_read(&mut tracker, &file_path); - - // First edit succeeds - let call1 = EditToolCall { - path: file_path.clone(), - old_string: "value = original".to_string(), - new_string: "value = edited".to_string(), - replace_all: false, - }; - let (outcome, new_bytes) = call1.execute(&file_path, &tracker); - assert!(matches!(outcome, ToolOutcome::Success(_))); - simulate_tracker_update(&mut tracker, &file_path, &new_bytes.unwrap()); - - // External modification (e.g., user edits the file) - std::thread::sleep(std::time::Duration::from_millis(10)); - std::fs::write(&file_path, "value = user_changed\n").unwrap(); - - // Second edit should fail (stale) - let call2 = EditToolCall { - path: file_path.clone(), - old_string: "value = edited".to_string(), - new_string: "value = second_edit".to_string(), - replace_all: false, - }; - let (outcome, new_bytes) = call2.execute(&file_path, &tracker); - assert!(new_bytes.is_none()); - match outcome { - ToolOutcome::Error(msg) => assert!(msg.contains("modified since read")), - _ => panic!("expected stale error"), - } - - // File should be unchanged (the user's edit preserved) - assert_eq!( - std::fs::read_to_string(&file_path).unwrap(), - "value = user_changed\n" - ); - } - - #[test] - fn snapshot_only_created_once_per_file() { - let dir = tempfile::tempdir().unwrap(); - let file_path = dir.path().join("config.toml"); - std::fs::write(&file_path, "a = 1\nb = 2\n").unwrap(); - - let snapshot_dir = dir.path().join("snapshots").join("session-1"); - let mut tracker = FileReadTracker::default(); - let mut store = SnapshotStore::open(snapshot_dir).unwrap(); - - simulate_read(&mut tracker, &file_path); - - // First edit — snapshot should be created - let original = std::fs::read(&file_path).unwrap(); - let created = store.ensure_snapshot(&file_path, &original).unwrap(); - assert!(created); - - let call1 = EditToolCall { - path: file_path.clone(), - old_string: "a = 1".to_string(), - new_string: "a = 10".to_string(), - replace_all: false, - }; - let (_, new_bytes) = call1.execute(&file_path, &tracker); - simulate_tracker_update(&mut tracker, &file_path, &new_bytes.unwrap()); - - // Second edit — snapshot should NOT be recreated - let content_before_second = std::fs::read(&file_path).unwrap(); - let created = store - .ensure_snapshot(&file_path, &content_before_second) - .unwrap(); - assert!(!created); // idempotent — already snapshotted - } - - #[test] - fn permission_cache_grant_and_check() { - let mut cache = EditPermissionCache::default(); - let path = std::path::PathBuf::from("/Users/me/.config/atuin/config.toml"); - - // Initially no grant - assert!(!cache.has_valid_grant(&path)); - - // Grant permission - cache.grant(path.clone()); - assert!(cache.has_valid_grant(&path)); - - // Different file has no grant - assert!(!cache.has_valid_grant(std::path::Path::new("/other/file.toml"))); - - // Roundtrip through JSON (simulates session persistence) - let json = cache.to_json().unwrap(); - let restored = EditPermissionCache::from_json(&json).unwrap(); - assert!(restored.has_valid_grant(&path)); - } - } - - // ── write_file execution tests ── - - mod write { - use super::*; - - #[test] - fn creates_new_file() { - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("new_file.txt"); - - let call = WriteToolCall { - path: path.clone(), - content: "hello\nworld\n".to_string(), - overwrite: false, - }; - let (outcome, new_bytes) = call.execute(&path); - - assert!(matches!(outcome, ToolOutcome::Success(ref s) if s.contains("Created"))); - assert!(new_bytes.is_some()); - assert_eq!(std::fs::read_to_string(&path).unwrap(), "hello\nworld\n"); - } - - #[test] - fn error_file_exists_without_overwrite() { - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("existing.txt"); - std::fs::write(&path, "original").unwrap(); - - let call = WriteToolCall { - path: path.clone(), - content: "new content".to_string(), - overwrite: false, - }; - let (outcome, new_bytes) = call.execute(&path); - - assert!(new_bytes.is_none()); - match outcome { - ToolOutcome::Error(msg) => { - assert!(msg.contains("already exists"), "got: {msg}"); - assert!(msg.contains("overwrite"), "got: {msg}"); - } - _ => panic!("expected error"), - } - // Original preserved - assert_eq!(std::fs::read_to_string(&path).unwrap(), "original"); - } - - #[test] - fn overwrites_existing_file_when_flag_set() { - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("existing.txt"); - std::fs::write(&path, "original").unwrap(); - - let call = WriteToolCall { - path: path.clone(), - content: "replaced content\n".to_string(), - overwrite: true, - }; - let (outcome, new_bytes) = call.execute(&path); - - assert!(matches!(outcome, ToolOutcome::Success(_))); - assert!(new_bytes.is_some()); - assert_eq!( - std::fs::read_to_string(&path).unwrap(), - "replaced content\n" - ); - } - - #[test] - fn creates_parent_directories() { - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().join("sub").join("dir").join("file.txt"); - - let call = WriteToolCall { - path: path.clone(), - content: "nested\n".to_string(), - overwrite: false, - }; - let (outcome, _) = call.execute(&path); - - assert!(matches!(outcome, ToolOutcome::Success(_))); - assert_eq!(std::fs::read_to_string(&path).unwrap(), "nested\n"); - } - - #[test] - fn error_path_is_directory() { - let dir = tempfile::tempdir().unwrap(); - let path = dir.path().to_path_buf(); - - let call = WriteToolCall { - path: path.clone(), - content: "content".to_string(), - overwrite: false, - }; - let (outcome, new_bytes) = call.execute(&path); - - assert!(new_bytes.is_none()); - assert!(matches!(outcome, ToolOutcome::Error(ref msg) if msg.contains("directory"))); - } - } - - // ── Windows-specific tests (absolute paths with drive letters) ── - - #[cfg(windows)] - mod windows { - use super::*; - - #[test] - fn absolute_glob() { - assert!( - read_tool(r"C:\Users\dev\src\main.rs") - .matches_rule(&read_rule(Some("C:/Users/dev/src/*.rs"))) - ); - assert!( - !read_tool(r"C:\Users\dev\docs\readme.md") - .matches_rule(&read_rule(Some("C:/Users/dev/src/*.rs"))) - ); - } - - #[test] - fn double_star_glob() { - assert!( - read_tool(r"C:\project\crates\foo\src\lib.rs") - .matches_rule(&read_rule(Some("C:/project/crates/**/*.rs"))) - ); - assert!( - !read_tool(r"C:\project\crates\foo\src\lib.py") - .matches_rule(&read_rule(Some("C:/project/crates/**/*.rs"))) - ); - } - } -} diff --git a/crates/atuin-ai/src/tui/components/atuin_ai.rs b/crates/atuin-ai/src/tui/components/atuin_ai.rs deleted file mode 100644 index 31dff1c3..00000000 --- a/crates/atuin-ai/src/tui/components/atuin_ai.rs +++ /dev/null @@ -1,143 +0,0 @@ -//! Top-level AtuinAi component that translates key events into AiTuiEvents. -//! -//! Global shortcuts (Ctrl+C, Esc) are handled in the capture phase so they -//! fire regardless of which child is focused. Contextual shortcuts (Enter, -//! Tab) are handled in the bubble phase so child components like the -//! permission Select can consume them first. - -use crossterm::event::{Event, KeyCode, KeyEvent, KeyEventKind, KeyModifiers}; -use eye_declare::{Elements, EventResult, Hooks, component, props}; - -use crate::commands::inline::DriverEventSender; -use crate::tui::events::AiTuiEvent; -use crate::tui::state::AppMode; - -/// Top-level wrapper component for the AI TUI. -/// -/// Props carry the current mode so `handle_event` can translate keys -/// into the right `AiTuiEvent`. Children are rendered via slot children. -#[props] -pub(crate) struct AtuinAi { - pub mode: AppMode, - pub has_command: bool, - pub is_input_blank: bool, - pub pending_confirmation: bool, - pub has_executing_preview: bool, -} - -#[derive(Default)] -pub(crate) struct AtuinAiState { - tx: Option<DriverEventSender>, -} - -#[component(props = AtuinAi, state = AtuinAiState, children = Elements)] -fn atuin_ai( - _props: &AtuinAi, - _state: &AtuinAiState, - hooks: &mut Hooks<AtuinAi, AtuinAiState>, - children: Elements, -) -> Elements { - hooks.use_context::<DriverEventSender>(|tx, _, state| { - state.tx = tx.cloned(); - }); - - // Capture phase: global shortcuts that must fire regardless of child focus. - hooks.use_event_capture(move |event, props, state| { - let Event::Key(KeyEvent { - code, - kind: KeyEventKind::Press, - modifiers, - .. - }) = event - else { - return EventResult::Ignored; - }; - - let Some(ref tx) = state.read().tx else { - return EventResult::Ignored; - }; - - // Ctrl+C — interrupt executing command or exit - if modifiers.contains(KeyModifiers::CONTROL) && *code == KeyCode::Char('c') { - if props.has_executing_preview { - let _ = tx.send(AiTuiEvent::InterruptToolExecution); - } else { - let _ = tx.send(AiTuiEvent::Exit); - } - return EventResult::Consumed; - } - - // Esc — always handled at the top level - if *code == KeyCode::Esc { - match props.mode { - AppMode::Input => { - if props.has_executing_preview { - let _ = tx.send(AiTuiEvent::InterruptToolExecution); - } else if props.pending_confirmation { - let _ = tx.send(AiTuiEvent::CancelConfirmation); - } else { - let _ = tx.send(AiTuiEvent::Exit); - } - } - AppMode::Generating | AppMode::Streaming => { - let _ = tx.send(AiTuiEvent::CancelGeneration); - } - AppMode::Error => { - let _ = tx.send(AiTuiEvent::Exit); - } - } - return EventResult::Consumed; - } - - if *code == KeyCode::Tab - && matches!(props.mode, AppMode::Input) - && modifiers.contains(KeyModifiers::NONE) - && props.has_command - && props.is_input_blank - { - let _ = tx.send(AiTuiEvent::InsertCommand); - return EventResult::Consumed; - } - - EventResult::Ignored - }); - - // Bubble phase: contextual shortcuts that children (e.g. Select) may handle first. - hooks.use_event(move |event, props, state| { - let Event::Key(KeyEvent { - code, - kind: KeyEventKind::Press, - .. - }) = event - else { - return EventResult::Ignored; - }; - - let Some(ref tx) = state.read().tx else { - return EventResult::Ignored; - }; - - match props.mode { - AppMode::Input => match code { - KeyCode::Enter => { - if props.has_command && props.is_input_blank { - let _ = tx.send(AiTuiEvent::ExecuteCommand); - return EventResult::Consumed; - } - EventResult::Ignored - } - _ => EventResult::Ignored, - }, - AppMode::Error => match code { - KeyCode::Enter | KeyCode::Char('r') => { - let _ = tx.send(AiTuiEvent::Retry); - EventResult::Consumed - } - _ => EventResult::Ignored, - }, - _ => EventResult::Ignored, - } - }); - - children -} diff --git a/crates/atuin-ai/src/tui/components/input_box.rs b/crates/atuin-ai/src/tui/components/input_box.rs deleted file mode 100644 index 6b81322c..00000000 --- a/crates/atuin-ai/src/tui/components/input_box.rs +++ /dev/null @@ -1,220 +0,0 @@ -//! Bordered input box component for the AI TUI. -//! -//! Wraps tui-textarea's TextArea, which handles rendering, wrapping, cursor -//! positioning, and height measurement natively. The component configures the -//! TextArea's block (border + titles) and forwards events to it. -//! -//! On Enter, sends `AiTuiEvent::SubmitInput` via the context-provided channel. - -use std::sync::{Arc, Mutex}; - -use crossterm::event::KeyModifiers; -use eye_declare::{Canvas, Elements, EventResult, Hooks, component, element, props}; -use ratatui::widgets::{Block, Borders, Padding}; -use ratatui_core::{ - layout::Rect, - style::{Color, Modifier, Style}, - text::Line, - widgets::Widget, -}; -use tui_textarea::TextArea; - -use crate::commands::inline::DriverEventSender; -use crate::tui::{events::AiTuiEvent, slash::SlashCommandSearchResult}; - -/// A bordered text input box backed by tui-textarea. -/// -/// Props configure the chrome (title, footer). The TextArea itself lives -/// in the component's State so it owns cursor, wrapping, and rendering. -#[props] -pub(crate) struct InputBox { - /// Title shown in top-left border - pub title: String, - /// Right-side label in top border - pub title_right: String, - /// Footer text shown in bottom border (keybinding hints) - pub footer: String, - /// Whether the input is currently active (shows cursor, accepts input) - pub active: bool, - /// If the user has typed a slash command, this holds the best match for it. - pub slash_suggestion: Option<SlashCommandSearchResult>, -} - -pub(crate) struct InputBoxState { - textarea: Arc<Mutex<TextArea<'static>>>, - tx: Option<DriverEventSender>, -} - -impl Default for InputBoxState { - fn default() -> Self { - let mut textarea = TextArea::default(); - textarea.set_cursor_line_style(ratatui::style::Style::default()); - textarea.set_wrap_mode(tui_textarea::WrapMode::Word); - textarea.set_placeholder_text("Type a message..."); - textarea.set_placeholder_style( - ratatui::style::Style::default() - .fg(ratatui::style::Color::DarkGray) - .add_modifier(ratatui::style::Modifier::ITALIC), - ); - Self { - textarea: Arc::new(Mutex::new(textarea)), - tx: None, - } - } -} - -fn make_block(props: &InputBox) -> Block<'static> { - let border_style = Style::default().fg(Color::DarkGray); - let title_style = Style::default() - .fg(Color::Gray) - .add_modifier(Modifier::BOLD); - - let mut block = Block::default() - .borders(Borders::ALL) - .border_style(border_style) - .padding(Padding::horizontal(1)); - - if !props.title.is_empty() { - block = - block.title_top(Line::styled(format!(" {} ", props.title), title_style).left_aligned()); - } - if !props.title_right.is_empty() { - block = block.title_top( - Line::styled(format!(" {} ", props.title_right), border_style).right_aligned(), - ); - } - if !props.footer.is_empty() { - block = block.title_bottom( - Line::styled(format!(" {} ", props.footer), border_style).right_aligned(), - ); - } - - block -} - -#[component(props = InputBox, state = InputBoxState)] -fn input_box( - props: &InputBox, - state: &InputBoxState, - hooks: &mut Hooks<InputBox, InputBoxState>, -) -> Elements { - // Always focusable so focus isn't lost when the permission Select is - // removed from the tree. The `active` prop controls visual state and - // whether keystrokes are processed, not focusability. - hooks.use_focusable(true); - hooks.use_autofocus(); - - hooks.use_context::<DriverEventSender>(|tx, _, state| { - state.tx = tx.cloned(); - }); - - hooks.use_event(move |event, props, state| { - let state = state.read(); - - if !props.active { - return EventResult::Ignored; - } - - if let crossterm::event::Event::Paste(text) = event { - let mut textarea = state.textarea.lock().unwrap(); - textarea.insert_str(text); - return EventResult::Consumed; - } - - if let crossterm::event::Event::Key(key) = event { - if key.kind != crossterm::event::KeyEventKind::Press { - return EventResult::Ignored; - } - - let mut textarea = state.textarea.lock().unwrap(); - - match key.code { - crossterm::event::KeyCode::Char('j') - if key.modifiers.contains(KeyModifiers::CONTROL) => - { - textarea.insert_newline(); - return EventResult::Consumed; - } - crossterm::event::KeyCode::Tab if props.slash_suggestion.is_some() => { - // If there's a slash command suggestion, Tab accepts it. - if let Some(suggestion) = &props.slash_suggestion { - textarea.clear(); - textarea.insert_str(format!("/{}", suggestion.command.name)); - // Manually trigger an input update event so the slash suggestion box can update immediately - if let Some(ref tx) = state.tx { - let _ = tx.send(AiTuiEvent::InputUpdated(textarea.lines().join("\n"))); - } - return EventResult::Consumed; - } - } - crossterm::event::KeyCode::Enter => { - if key.modifiers.contains(KeyModifiers::SHIFT) { - textarea.insert_newline(); - return EventResult::Consumed; - } else { - let text = textarea.lines().join("\n"); - if text.trim().is_empty() { - return EventResult::Ignored; - } - - textarea.clear(); - - if let Some(ref tx) = state.tx { - let _ = tx.send(AiTuiEvent::SubmitInput(text)); - } - return EventResult::Consumed; - } - } - _ => {} - } - - // All other keys: forward to textarea. - // tui-textarea can convert crossterm events itself. - textarea.input(*key); - - if let Some(ref tx) = state.tx { - let _ = tx.send(AiTuiEvent::InputUpdated(textarea.lines().join("\n"))); - } - return EventResult::Consumed; - } - - EventResult::Ignored - }); - - let textarea = state.textarea.clone(); - let block = make_block(props); - let active = props.active; - element!( - Canvas(render_fn: move |area, buf| { - let mut area = area; - - if area.height < 3 || area.width < 4 { - return; - } - - let height = { - // TextArea handles scrolling internally if content overflows. - let inner = block.inner(Rect::new(0, 0, area.width, u16::MAX)); - let chrome = (u16::MAX).saturating_sub(inner.height); - let content = textarea.lock().unwrap().measure(area.width - 4); - chrome + content.preferred_rows - }; - - area.height = height.min(7); - let inner = block.clone().inner(area); - block.clone().render(area, buf); - - let mut textarea = textarea.lock().unwrap(); - if active { - textarea.set_cursor_style(Style::default().add_modifier(Modifier::REVERSED)); - textarea.set_placeholder_text("Type a message..."); - } else { - textarea.set_cursor_style(Style::default()); - textarea.set_placeholder_text(""); - } - - // Render textarea into the inner area - textarea.render(inner, buf); - }) - ) -} diff --git a/crates/atuin-ai/src/tui/components/markdown.rs b/crates/atuin-ai/src/tui/components/markdown.rs deleted file mode 100644 index 607520b7..00000000 --- a/crates/atuin-ai/src/tui/components/markdown.rs +++ /dev/null @@ -1,210 +0,0 @@ -//! Markdown rendering component using pulldown-cmark. -//! -//! More robust than eye-declare's built-in Markdown component: -//! uses a proper CommonMark parser rather than line-by-line regex. - -use eye_declare::{Component, props}; -use pulldown_cmark::{Event, Parser, Tag, TagEnd}; -use ratatui_core::{ - buffer::Buffer, - layout::Rect, - style::{Color, Modifier, Style}, - text::{Line, Span, Text}, - widgets::Widget, -}; -use ratatui_widgets::paragraph::{Paragraph, Wrap}; - -/// A markdown rendering component backed by pulldown-cmark. -#[props] -pub(crate) struct Markdown { - pub source: String, -} - -/// Style configuration for markdown rendering. -pub(crate) struct MarkdownStyles { - pub base: Style, - pub code_inline: Style, - pub code_block: Style, - pub bold: Style, - pub italic: Style, - pub heading: Style, -} - -impl MarkdownStyles { - pub fn new() -> Self { - let base = Style::default(); - Self { - base, - code_inline: Style::default().fg(Color::Yellow), - code_block: Style::default().fg(Color::Green), - bold: base.add_modifier(Modifier::BOLD), - italic: base.add_modifier(Modifier::ITALIC), - heading: Style::default() - .fg(Color::Cyan) - .add_modifier(Modifier::BOLD), - } - } -} - -impl Default for MarkdownStyles { - fn default() -> Self { - Self::new() - } -} - -impl Component for Markdown { - type State = MarkdownStyles; - - fn render(&self, area: Rect, buf: &mut Buffer, state: &Self::State) { - if self.source.is_empty() || area.width == 0 || area.height == 0 { - return; - } - let text = parse_markdown(&self.source, state); - Paragraph::new(text) - .wrap(Wrap { trim: false }) - .render(area, buf); - } - - fn desired_height(&self, width: u16, state: &Self::State) -> Option<u16> { - if self.source.is_empty() || width == 0 { - return Some(0); - } - let text = parse_markdown(&self.source, state); - Some( - Paragraph::new(text) - .wrap(Wrap { trim: false }) - .line_count(width) as u16, - ) - } - - fn initial_state(&self) -> Option<MarkdownStyles> { - Some(MarkdownStyles::new()) - } -} - -/// Parse markdown source into styled ratatui Text using pulldown-cmark. -fn parse_markdown<'a>(source: &'a str, styles: &'a MarkdownStyles) -> Text<'static> { - let parser = Parser::new(source); - let mut lines: Vec<Vec<Span<'static>>> = vec![Vec::new()]; - let mut current_line = 0; - - let mut style_stack: Vec<Style> = vec![styles.base]; - let mut in_code_block = false; - let mut in_list_item = false; - // True until the first paragraph inside a list item has been opened. - // The first paragraph should flow inline with the "- " prefix. - let mut list_item_first_para = false; - - for event in parser { - match event { - Event::Start(Tag::Strong) => { - let bold = style_stack.last().copied().unwrap_or(styles.bold); - style_stack.push(bold); - } - Event::End(TagEnd::Strong) => { - style_stack.pop(); - } - Event::Start(Tag::Emphasis) => { - let italic = style_stack.last().copied().unwrap_or(styles.italic); - style_stack.push(italic); - } - Event::End(TagEnd::Emphasis) => { - style_stack.pop(); - } - Event::Start(Tag::CodeBlock(_)) => { - in_code_block = true; - if !lines[current_line].is_empty() { - current_line += 1; - lines.push(Vec::new()); - current_line += 1; - lines.push(Vec::new()); - } - } - Event::End(TagEnd::CodeBlock) => { - in_code_block = false; - if !lines[current_line].is_empty() { - current_line += 1; - lines.push(Vec::new()); - } - } - Event::Code(code) => { - lines[current_line].push(Span::styled(format!("{}", code), styles.code_inline)); - } - Event::Text(text) => { - let current_style = if in_code_block { - styles.code_block - } else { - style_stack.last().copied().unwrap_or(styles.base) - }; - let prefix = if in_code_block { " " } else { "" }; - let parts: Vec<&str> = text.split('\n').collect(); - for (i, part) in parts.iter().enumerate() { - if i > 0 { - current_line += 1; - lines.push(Vec::new()); - } - if !part.is_empty() { - lines[current_line] - .push(Span::styled(format!("{}{}", prefix, part), current_style)); - } - } - } - Event::SoftBreak => { - let current_style = style_stack.last().copied().unwrap_or(styles.base); - lines[current_line].push(Span::styled(" ", current_style)); - } - Event::HardBreak => { - current_line += 1; - lines.push(Vec::new()); - } - Event::Start(Tag::Paragraph) => { - if in_list_item && list_item_first_para { - // First paragraph flows inline with the "- " prefix - list_item_first_para = false; - } else if current_line > 0 || !lines[0].is_empty() { - current_line += 1; - lines.push(Vec::new()); - if !in_list_item { - // Blank separator between paragraphs (but not inside list items) - current_line += 1; - lines.push(Vec::new()); - } - } - } - Event::End(TagEnd::Paragraph) => {} - Event::Start(Tag::Heading { .. }) => { - if current_line > 0 || !lines[0].is_empty() { - current_line += 1; - lines.push(Vec::new()); - current_line += 1; - lines.push(Vec::new()); - } - style_stack.push(styles.heading); - } - Event::End(TagEnd::Heading(_)) => { - style_stack.pop(); - } - Event::Start(Tag::Item) => { - if current_line > 0 || !lines[0].is_empty() { - current_line += 1; - lines.push(Vec::new()); - } - lines[current_line].push(Span::styled("- ", Style::default().fg(Color::DarkGray))); - in_list_item = true; - list_item_first_para = true; - } - Event::End(TagEnd::Item) => { - in_list_item = false; - } - Event::Start(Tag::List(_)) if current_line > 0 || !lines[0].is_empty() => { - current_line += 1; - lines.push(Vec::new()); - } - Event::End(TagEnd::List(_)) => {} - _ => {} - } - } - - let text_lines: Vec<Line<'static>> = lines.into_iter().map(Line::from).collect(); - Text::from(text_lines) -} diff --git a/crates/atuin-ai/src/tui/components/mod.rs b/crates/atuin-ai/src/tui/components/mod.rs deleted file mode 100644 index 9959dbad..00000000 --- a/crates/atuin-ai/src/tui/components/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub(crate) mod atuin_ai; -pub(crate) mod input_box; -pub(crate) mod markdown; -pub(crate) mod select; -pub(crate) mod session_continue; diff --git a/crates/atuin-ai/src/tui/components/select.rs b/crates/atuin-ai/src/tui/components/select.rs deleted file mode 100644 index 771d7830..00000000 --- a/crates/atuin-ai/src/tui/components/select.rs +++ /dev/null @@ -1,95 +0,0 @@ -use crossterm::event::KeyCode; -use eye_declare::{Elements, EventResult, Hooks, Span, Text, View, component, element, props}; -use ratatui::style::Style; -use typed_builder::TypedBuilder; - -use crate::commands::inline::DriverEventSender; -use crate::tui::events::AiTuiEvent; - -type OnSelectFn = Box<dyn Fn(&SelectOption) -> Option<AiTuiEvent> + Send + Sync + 'static>; - -#[derive(TypedBuilder)] -pub(crate) struct SelectOption { - #[builder(setter(into))] - pub label: String, - #[builder(setter(into))] - pub value: String, - #[builder(default = Style::default())] - pub label_style: Style, - #[builder(default = Style::default().reversed())] - pub selected_style: Style, -} - -#[derive(Default)] -pub(crate) struct PermissionSelectorState { - selected_option: usize, - tx: Option<DriverEventSender>, -} - -#[props] -pub(crate) struct Select { - pub options: Vec<SelectOption>, - pub on_select: OnSelectFn, -} - -#[component(props = Select, state = PermissionSelectorState)] -pub(crate) fn permission_selector( - props: &Select, - state: &PermissionSelectorState, - hooks: &mut Hooks<Select, PermissionSelectorState>, -) -> Elements { - hooks.use_focusable(true); - hooks.use_autofocus(); - - hooks.use_context::<DriverEventSender>(|tx, _, state| { - state.tx = tx.cloned(); - }); - - hooks.use_event(move |event, props, state| { - if !event.is_key_press() { - return EventResult::Ignored; - } - - if let crossterm::event::Event::Key(key) = event { - if key.kind != crossterm::event::KeyEventKind::Press { - return EventResult::Ignored; - } - - match key.code { - KeyCode::Up => { - state.selected_option = - (state.selected_option + props.options.len() - 1) % props.options.len(); - return EventResult::Consumed; - } - KeyCode::Down => { - state.selected_option = (state.selected_option + 1) % props.options.len(); - return EventResult::Consumed; - } - KeyCode::Enter => { - let option = &props.options[state.selected_option]; - if let Some(event) = (props.on_select)(option) - && let Some(ref tx) = state.tx - { - let _ = tx.send(event); - } - return EventResult::Consumed; - } - _ => {} - } - } - - EventResult::Ignored - }); - - element!( - View { - #(for (index, option) in props.options.iter().enumerate() { - Text { Span(text: &option.label, style: if index == state.selected_option { - option.selected_style - } else { - option.label_style - }) } - }) - } - ) -} diff --git a/crates/atuin-ai/src/tui/components/session_continue.rs b/crates/atuin-ai/src/tui/components/session_continue.rs deleted file mode 100644 index bfbfb191..00000000 --- a/crates/atuin-ai/src/tui/components/session_continue.rs +++ /dev/null @@ -1,49 +0,0 @@ -use chrono_humanize::HumanTime; -use eye_declare::{Elements, Hooks, Span, Text, component, element, props}; -use ratatui::style::{Color, Modifier, Style}; - -#[props] -pub(crate) struct SessionContinue { - pub continued_at: Option<chrono::DateTime<chrono::Utc>>, -} - -#[derive(Default)] -pub(crate) struct SessionContinueState { - /// Frozen on mount so the label doesn't change on every render. - label: Option<String>, -} - -#[component(props = SessionContinue, state = SessionContinueState)] -fn session_continue( - _props: &SessionContinue, - state: &SessionContinueState, - hooks: &mut Hooks<SessionContinue, SessionContinueState>, -) -> Elements { - hooks.use_mount(|props, state| { - state.label = Some(match props.continued_at { - Some(t) => { - let human = HumanTime::from(t - chrono::Utc::now()); - format!( - " Continuing previous session (last active {human}) - type /new to start a new session" - ) - } - None => { - " Continuing previous session - type /new to start a new session".to_string() - } - }); - }); - - let resume_label = state - .label - .as_deref() - .unwrap_or(" Continuing previous session - type /new to start a new session"); - - element! { - Text { - Span( - text: resume_label, - style: Style::default().fg(Color::DarkGray).add_modifier(Modifier::ITALIC), - ) - } - } -} diff --git a/crates/atuin-ai/src/tui/content/help.md b/crates/atuin-ai/src/tui/content/help.md deleted file mode 100644 index d6623ac9..00000000 --- a/crates/atuin-ai/src/tui/content/help.md +++ /dev/null @@ -1,6 +0,0 @@ -Welcome to Atuin AI, an AI assistant in your terminal. You can ask it to generate a shell command for you, or ask general terminal or software questions. - -Commands: -{commands} - -For more information, see [https://docs.atuin.sh/cli/ai/introduction/](https://docs.atuin.sh/cli/ai/introduction/) diff --git a/crates/atuin-ai/src/tui/events.rs b/crates/atuin-ai/src/tui/events.rs deleted file mode 100644 index abcb1bd9..00000000 --- a/crates/atuin-ai/src/tui/events.rs +++ /dev/null @@ -1,67 +0,0 @@ -/// Application-domain events emitted by UI components. -/// -/// Components translate raw key events into these semantic events, -/// which are sent via an `mpsc::Sender<AiTuiEvent>` provided through -/// eye-declare's context system. The main event loop in `inline.rs` -/// receives them and mutates `AppState` accordingly. -#[derive(Debug)] -pub(crate) enum AiTuiEvent { - /// User updated the input text - InputUpdated(String), - /// User submitted text input (Enter in Input mode) - SubmitInput(String), - /// User entered a slash command (e.g. "/help") - #[allow(unused)] - SlashCommand(String), - /// User selected a permission - SelectPermission(PermissionResult), - /// Cancel active generation or streaming (Esc during Generating/Streaming) - CancelGeneration, - /// Execute the suggested command - ExecuteCommand, - /// Insert command without executing - InsertCommand, - /// Cancel confirmation of dangerous command - CancelConfirmation, - /// Interrupt a running tool execution (Ctrl+C during ExecutingPreview) - InterruptToolExecution, - /// Retry after error - Retry, - /// Exit the application - Exit, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum PermissionResult { - Allow, - /// Per-file, time-limited grant scoped to the current session. - AllowFileForSession, - AlwaysAllowInDir, - AlwaysAllow, - Deny, -} - -impl PermissionResult { - /// String identifier used as the SelectOption value. - pub fn as_value_str(&self) -> &'static str { - match self { - Self::Allow => "allow", - Self::AllowFileForSession => "allow-file-session", - Self::AlwaysAllowInDir => "always-allow-in-dir", - Self::AlwaysAllow => "always-allow", - Self::Deny => "deny", - } - } - - /// Parse from a SelectOption value string. - pub fn from_value_str(s: &str) -> Option<Self> { - match s { - "allow" => Some(Self::Allow), - "allow-file-session" => Some(Self::AllowFileForSession), - "always-allow-in-dir" => Some(Self::AlwaysAllowInDir), - "always-allow" => Some(Self::AlwaysAllow), - "deny" => Some(Self::Deny), - _ => None, - } - } -} diff --git a/crates/atuin-ai/src/tui/mod.rs b/crates/atuin-ai/src/tui/mod.rs deleted file mode 100644 index 9727f362..00000000 --- a/crates/atuin-ai/src/tui/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub(crate) mod components; -pub(crate) mod events; -pub(crate) mod slash; -pub(crate) mod state; -pub(crate) mod view; - -pub(crate) use state::{ConversationEvent, events_to_messages}; diff --git a/crates/atuin-ai/src/tui/slash.rs b/crates/atuin-ai/src/tui/slash.rs deleted file mode 100644 index 7d5e6fa8..00000000 --- a/crates/atuin-ai/src/tui/slash.rs +++ /dev/null @@ -1,79 +0,0 @@ -#[derive(Debug, Clone)] -pub(crate) struct SlashCommand { - pub name: String, - pub description: String, -} - -impl SlashCommand { - pub fn new(name: &str, description: &str) -> Self { - Self { - name: name.to_string(), - description: description.to_string(), - } - } -} - -#[derive(Debug)] -pub(crate) struct SlashCommandRegistry { - commands: Vec<SlashCommand>, -} - -#[derive(Debug, Clone)] -pub(crate) struct SlashCommandSearchResult { - pub command: SlashCommand, - pub relevance: f32, - pub span: (usize, usize), -} - -impl SlashCommandRegistry { - pub fn new() -> Self { - Self { - commands: Vec::new(), - } - } - - pub fn register(&mut self, command: SlashCommand) { - self.commands.push(command); - } - - pub fn get_commands(&self) -> &[SlashCommand] { - &self.commands - } - - pub fn search_fuzzy(&self, query: &str) -> Vec<SlashCommandSearchResult> { - let query_lower = query.to_lowercase(); - - self.commands - .iter() - .filter_map(|command| { - let name_lower = command.name.to_lowercase(); - if let Some(start) = name_lower.find(&query_lower as &str) { - let end = start + query_lower.len(); - Some((command, start, end)) - } else { - None - } - }) - .map(|(command, start, end)| { - SlashCommandSearchResult { - command: command.clone(), - relevance: 1.0, // Simple relevance score for now - span: (start, end), - } - }) - .collect() - } -} - -impl Default for SlashCommandRegistry { - fn default() -> Self { - let mut registry = Self::new(); - registry.register(SlashCommand::new("help", "Show help information")); - registry.register(SlashCommand::new( - "new", - "Start a new conversation, archiving the current one", - )); - - registry - } -} diff --git a/crates/atuin-ai/src/tui/state.rs b/crates/atuin-ai/src/tui/state.rs deleted file mode 100644 index 71da6ff5..00000000 --- a/crates/atuin-ai/src/tui/state.rs +++ /dev/null @@ -1,237 +0,0 @@ -//! Core state types for the conversation protocol. -//! -//! ConversationEvent and events_to_messages are the canonical representations -//! used by both the FSM and the context window builder. AppMode is used by -//! the view layer for component prop derivation. - -/// Conversation event types matching the API protocol. -#[derive(Debug, Clone)] -pub(crate) enum ConversationEvent { - /// User message (what the user typed) - UserMessage { content: String }, - /// Text content from assistant (streamed or complete) - Text { content: String }, - /// Tool call from assistant - ToolCall { - id: String, - name: String, - input: serde_json::Value, - }, - /// Tool result (from server-side or client-side execution) - ToolResult { - tool_use_id: String, - content: String, - is_error: bool, - /// Server-side results are stored in the DB; the client sends an opaque - /// reference (`remote: true`) instead of the full content. - remote: bool, - /// Approximate content length for token estimation of remote results. - content_length: Option<usize>, - }, - /// Out-of-band output from the system — not sent to the server - OutOfBandOutput { - name: String, - command: Option<String>, - content: String, - }, - /// Context injected for the LLM that is not rendered in the TUI. - /// Converted to a user message in the API protocol. - SystemContext { content: String }, - /// A skill was loaded and its content injected into the conversation. - /// Serialized as a full user message for the API but rendered compactly - /// in the TUI (just the `/name args` invocation line). - SkillInvocation { - name: String, - arguments: Option<String>, - content: String, - }, -} - -impl ConversationEvent { - /// Whether this event represents actual conversation content sent to the API. - pub(crate) fn is_api_content(&self) -> bool { - match self { - ConversationEvent::UserMessage { .. } => true, - ConversationEvent::Text { .. } => true, - ConversationEvent::ToolCall { .. } => true, - ConversationEvent::ToolResult { .. } => true, - ConversationEvent::OutOfBandOutput { .. } => false, - ConversationEvent::SystemContext { .. } => false, - ConversationEvent::SkillInvocation { .. } => true, - } - } - - /// Extract command from a suggest_command tool call. - pub(crate) fn as_command(&self) -> Option<&str> { - if let ConversationEvent::ToolCall { name, input, .. } = self - && name == "suggest_command" - { - return input.get("command").and_then(|v| v.as_str()); - } - None - } -} - -/// Application mode for key handling and component props. -/// -/// Derived from AgentState in the view layer via `From<&AgentState>`. -#[derive(Debug, Clone, PartialEq, Eq, Copy)] -pub(crate) enum AppMode { - /// User is typing input - Input, - /// Waiting for generation (showing spinner) - Generating, - /// Streaming SSE response - Streaming, - /// Error state, can retry - Error, -} - -/// Convert a slice of conversation events to Claude API message format. -/// -/// This is the canonical event-to-message conversion, used by the context window -/// builder to convert turn slices independently. The logic handles combining -/// adjacent Text + ToolCall events into single assistant messages with mixed -/// content blocks. -pub(crate) fn events_to_messages(events: &[ConversationEvent]) -> Vec<serde_json::Value> { - let mut messages = Vec::new(); - let mut i = 0; - - while i < events.len() { - match &events[i] { - ConversationEvent::UserMessage { content } => { - messages.push(serde_json::json!({ - "role": "user", - "content": content - })); - i += 1; - } - ConversationEvent::Text { content } if content.is_empty() => { - i += 1; - } - ConversationEvent::Text { content } => { - let next_is_tool_call = events - .get(i + 1) - .is_some_and(|e| matches!(e, ConversationEvent::ToolCall { .. })); - - if next_is_tool_call { - let mut content_blocks = Vec::new(); - - if !content.is_empty() { - content_blocks.push(serde_json::json!({ - "type": "text", - "text": content - })); - } - - while let Some(ConversationEvent::ToolCall { - id, name, input, .. - }) = events.get(i + 1) - { - content_blocks.push(serde_json::json!({ - "type": "tool_use", - "id": id, - "name": name, - "input": input - })); - i += 1; - } - - messages.push(serde_json::json!({ - "role": "assistant", - "content": content_blocks - })); - i += 1; - } else { - messages.push(serde_json::json!({ - "role": "assistant", - "content": content - })); - i += 1; - } - } - ConversationEvent::ToolCall { .. } => { - let mut tool_uses = Vec::new(); - while i < events.len() { - if let ConversationEvent::ToolCall { - id, name, input, .. - } = &events[i] - { - tool_uses.push(serde_json::json!({ - "type": "tool_use", - "id": id, - "name": name, - "input": input - })); - i += 1; - } else { - break; - } - } - messages.push(serde_json::json!({ - "role": "assistant", - "content": tool_uses - })); - } - ConversationEvent::ToolResult { - tool_use_id, - content, - is_error, - remote, - content_length, - } => { - let tool_result = if *remote { - let mut obj = serde_json::json!({ - "type": "tool_result", - "tool_use_id": tool_use_id, - "remote": true, - "is_error": is_error - }); - if let Some(len) = content_length { - obj["content_length"] = serde_json::json!(len); - } - obj - } else { - serde_json::json!({ - "type": "tool_result", - "tool_use_id": tool_use_id, - "content": content, - "is_error": is_error - }) - }; - messages.push(serde_json::json!({ - "role": "user", - "content": [tool_result] - })); - i += 1; - } - ConversationEvent::OutOfBandOutput { .. } => { - i += 1; - } - ConversationEvent::SystemContext { content } => { - messages.push(serde_json::json!({ - "role": "user", - "content": content - })); - i += 1; - } - ConversationEvent::SkillInvocation { - name, - arguments, - content, - } => { - let header = match arguments { - Some(args) => format!("[Loaded skill: {name}]\n[Arguments: {args}]"), - None => format!("[Loaded skill: {name}]"), - }; - messages.push(serde_json::json!({ - "role": "user", - "content": format!("{header}\n\n{content}") - })); - i += 1; - } - } - } - - messages -} diff --git a/crates/atuin-ai/src/tui/view/mod.rs b/crates/atuin-ai/src/tui/view/mod.rs deleted file mode 100644 index b594cedf..00000000 --- a/crates/atuin-ai/src/tui/view/mod.rs +++ /dev/null @@ -1,978 +0,0 @@ -//! View function that builds the eye-declare element tree from app state. - -use eye_declare::{ - Cells, Column, Elements, HStack, Span, Spinner, Text, View, Viewport, WidthConstraint, element, -}; -use ratatui_core::style::{Color, Modifier, Style}; - -use crate::driver::ViewState; -use crate::fsm::{AgentState, StreamPhase}; -use crate::tools::{ClientToolCall, HistorySearchFilterMode, ToolPreview}; -use crate::tui::components::select::SelectOption; -use crate::tui::components::session_continue::SessionContinue; -use crate::tui::events::{AiTuiEvent, PermissionResult}; - -use super::components::atuin_ai::AtuinAi; -use super::components::input_box::InputBox; -use super::components::markdown::Markdown; -use super::components::select::Select; -use super::state::AppMode; - -pub(crate) mod turn; - -impl From<&AgentState> for AppMode { - fn from(state: &AgentState) -> Self { - match state { - AgentState::Idle { .. } => AppMode::Input, - AgentState::Turn { - stream: StreamPhase::Connecting, - } => AppMode::Generating, - AgentState::Turn { .. } => AppMode::Streaming, - AgentState::Error(_) => AppMode::Error, - } - } -} - -/// Build the element tree from current state. -/// -/// Layout (top to bottom): -/// - Conversation messages (user messages, agent responses, tool status) -/// - Streaming content (if actively streaming) -/// - Error display (if in error state) -/// - Spacer -/// - Input box (bordered, with contextual keybindings) -pub(crate) fn ai_view(state: &ViewState) -> Elements { - let committed = state.committed_turn_count; - let turns: Vec<&turn::UiTurn> = state.turns.iter().filter(|t| t.id >= committed).collect(); - let busy = state.is_busy(); - let last_index = turns.len().saturating_sub(1); - - // Turns are direct children of the root VStack so that eye_declare's - // on_commit can detect them scrolling into terminal scrollback and - // prune them from the tree. AtuinAi wraps only the interactive footer - // (input box, error display, pending banner) so its event capture/bubble - // handlers still fire for keyboard events. - element! { - #(if state.is_resumed && (!state.is_exiting() || !turns.is_empty()) { - SessionContinue(key: "continuation-notice", continued_at: state.last_event_time) - }) - - #(for (index, turn) in turns.iter().enumerate() { - #(match &turn.kind { - turn::UiTurnKind::User { events } => { - user_turn_view(events, index == 0, turn.id) - } - turn::UiTurnKind::Agent { events } => { - agent_turn_view(events, busy && index == last_index, state.tools.awaiting_permission().is_some(), turn.id) - } - turn::UiTurnKind::OutOfBand { events } => { - out_of_band_turn_view(events, turn.id) - } - }) - }) - - AtuinAi( - key: "footer", - mode: AppMode::from(&state.agent_state), - has_command: state.has_command, - is_input_blank: state.is_input_blank, - pending_confirmation: state.has_confirmation(), - has_executing_preview: state.tools.has_executing_preview(), - ) { - #({ - let needs_pending_banner = busy && !matches!(turns.last(), Some(turn::UiTurn { kind: turn::UiTurnKind::Agent { .. }, .. })); - if needs_pending_banner { - let empty: &[turn::UiEvent] = &[]; - agent_turn_view(empty, true, false, usize::MAX) - } else { - element! {} - } - }) - - #(if let AgentState::Error(ref msg) = state.agent_state { - View(key: "error-display", padding_left: Cells::from(2), padding_top: Cells::from(1)) { - Text { - Span(text: "Error: ", style: Style::default().fg(Color::Red).add_modifier(Modifier::BOLD)) - Span(text: msg, style: Style::default().fg(Color::Red)) - } - } - }) - - #(if !state.is_exiting() { - #(input_view(state)) - }) - } - } -} - -fn input_view(state: &ViewState) -> Elements { - let asking_tool = state.tools.awaiting_permission(); - let in_git_project = state.in_git_project; - let slash_results = state - .slash_command_search_results - .iter() - .take(4) - .collect::<Vec<_>>(); - let first_slash_result = slash_results.first().cloned(); - - element! { - #(if let Some(tc) = asking_tool { - #(tool_call_view(tc, in_git_project)) - }) - - #(if asking_tool.is_none() { - View(key: "input-box", padding_top: Cells::from(1)) { - InputBox( - key: "input", - title: "Generate a command or ask a question", - title_right: "Atuin AI", - footer: state.footer_text(), - active: state.is_input_active(), - slash_suggestion: first_slash_result.cloned() - ) - - #(if state.is_input_blank && state.has_command && state.is_input_active() { - #(if state.has_confirmation() { - Text { Span(text: "[Enter] Confirm dangerous command [Esc] Cancel", style: Style::default().fg(Color::Gray)) } - } else { - Text { Span(text: "[Enter] Execute suggested command [Tab] Insert Command", style: Style::default().fg(Color::Gray)) } - }) - }) - - #(if !slash_results.is_empty() { - #(for (i, result) in slash_results.iter().enumerate() { - Text { - Span(text: format!("/{}", &result.command.name[..result.span.0]), style: Style::default().fg(Color::Blue)) - Span(text: &result.command.name[result.span.0..result.span.1], style: Style::default().fg(Color::Blue).add_modifier(Modifier::UNDERLINED)) - Span(text: format!("{}", &result.command.name[result.span.1..]), style: Style::default().fg(Color::Blue)) - Span(text: " - ") - Span(text: &result.command.description) - - #(if i == 0 { - Span(text: " [Tab] Insert", style: Style::default().fg(Color::Gray).add_modifier(Modifier::ITALIC).dim()) - }) - } - - }) - }) - } - }) - } -} - -fn tool_call_view(tool_call: &crate::fsm::tools::TrackedTool, in_git_project: bool) -> Elements { - let verb = tool_call.tool.descriptor().display_verb; - let tool_desc = match &tool_call.tool { - ClientToolCall::Read(tool) => tool.path.display().to_string(), - ClientToolCall::Edit(tool) => tool.path.display().to_string(), - ClientToolCall::Write(tool) => tool.path.display().to_string(), - ClientToolCall::Shell(tool) => tool.command.clone(), - ClientToolCall::AtuinHistory(tool) => tool.query.clone(), - ClientToolCall::AtuinOutput(tool) => tool.history_id.to_string(), - ClientToolCall::LoadSkill(tool) => format!("skill: {}", tool.name), - }; - - let select_options = permission_options_for_tool(&tool_call.tool, in_git_project); - - element! { - View(key: format!("tool-call-{}", tool_call.id), padding_left: Cells::from(2), padding_top: Cells::from(1)) { - Text { - Span(text: format!("Atuin AI would like to {}: ", verb), style: Style::default()) - Span(text: &tool_desc, style: Style::default().fg(Color::Yellow)) - } - View(padding_left: Cells::from(2)) { - Select(options: select_options, on_select: Box::new(move |option: &SelectOption| { - PermissionResult::from_value_str(option.value.as_str()) - .map(AiTuiEvent::SelectPermission) - }) as Box<dyn Fn(&SelectOption) -> Option<AiTuiEvent> + Send + Sync>) - } - } - } -} - -/// Build the permission SelectOptions appropriate for a tool call. -/// -/// Edit tools get a per-file session-scoped option instead of the -/// workspace-level "Always allow in this directory". Other tools -/// keep the standard set. -fn permission_options_for_tool(tool: &ClientToolCall, in_git_project: bool) -> Vec<SelectOption> { - match tool { - ClientToolCall::Edit(_) | ClientToolCall::Write(_) => vec![ - SelectOption::builder() - .label("Allow") - .value(PermissionResult::Allow.as_value_str()) - .build(), - SelectOption::builder() - .label("Allow this file for this session") - .value(PermissionResult::AllowFileForSession.as_value_str()) - .build(), - SelectOption::builder() - .label("Always allow") - .value(PermissionResult::AlwaysAllow.as_value_str()) - .build(), - SelectOption::builder() - .label("Deny") - .value(PermissionResult::Deny.as_value_str()) - .build(), - ], - _ => { - let dir_label = if in_git_project { - "Always allow in this workspace" - } else { - "Always allow in this directory" - }; - vec![ - SelectOption::builder() - .label("Allow") - .value(PermissionResult::Allow.as_value_str()) - .build(), - SelectOption::builder() - .label(dir_label) - .value(PermissionResult::AlwaysAllowInDir.as_value_str()) - .build(), - SelectOption::builder() - .label("Always allow") - .value(PermissionResult::AlwaysAllow.as_value_str()) - .build(), - SelectOption::builder() - .label("Deny") - .value(PermissionResult::Deny.as_value_str()) - .build(), - ] - } - } -} - -fn user_turn_view(events: &[turn::UiEvent], first_turn: bool, turn_id: usize) -> Elements { - let label_style = Style::default() - .fg(Color::Cyan) - .add_modifier(Modifier::BOLD); - - let padding = if first_turn { 0 } else { 1 }; - - element! { - View(key: format!("turn-{turn_id}"), padding_top: Cells::from(padding)) { - Text { - Span(text: " You ", style: label_style.reversed()) - } - #(for event in events { - #(match event { - turn::UiEvent::Text { content } => { - element! { - View(padding_left: Cells::from(2)) { - Text { - Span(text: content, style: Style::default()) - } - } - } - }, - _ => element!{} - }) - }) - } - } -} - -fn agent_turn_view( - events: &[turn::UiEvent], - busy: bool, - showing_ui: bool, - turn_id: usize, -) -> Elements { - let label_style = Style::default() - .fg(Color::Yellow) - .add_modifier(Modifier::BOLD); - - element! { - View(key: format!("turn-{turn_id}")) { - Text { - Span(text: " Atuin AI ", style: label_style.reversed()) - } - #(for (i, event) in events.iter().enumerate() { - #(if i > 0 { - Text { Span(text: "") } - }) - #(match event { - turn::UiEvent::Text { content } => { - element! { - View(padding_left: Cells::from(2)) { - Markdown(source: content) - } - } - }, - turn::UiEvent::ToolSummary(summary) => { - tool_summary_view(summary) - }, - turn::UiEvent::SuggestedCommand(details) => { - suggested_command_view(details) - }, - turn::UiEvent::ToolCall(details) => { - let tool_key = details.tool_use_id.clone(); - - element! { - View(key: format!("tool-output-{tool_key}"), padding_left: Cells::from(2)) { - #(match &details.render_data { - turn::ToolRenderData::Shell { command, preview } => { - shell_tool_view(&tool_key, command, preview.as_ref()) - }, - turn::ToolRenderData::FileEdit { path, preview } => { - file_edit_tool_view(&tool_key, &details.status, path, preview.as_ref()) - }, - turn::ToolRenderData::FileWrite { path, preview } => { - file_write_tool_view(&tool_key, &details.status, path, preview.as_ref()) - }, - turn::ToolRenderData::Remote => { - tool_status_view(&details.name, &details.status) - }, - turn::ToolRenderData::FileRead { .. } - | turn::ToolRenderData::HistorySearch { .. } - | turn::ToolRenderData::SkillLoad { .. } => { - element!{} - }, - }) - } - } - } - turn::UiEvent::ToolGroup(group) => { - let group_key = group.calls - .first() - .map(|c| c.tool_use_id.as_str()) - .unwrap_or("empty"); - - element! { - View(key: format!("group-{group_key}"), padding_left: Cells::from(2)) { - #(match group.kind { - turn::ToolGroupKind::FileRead => file_read_group_view(group), - turn::ToolGroupKind::HistorySearch => history_search_group_view(group), - }) - } - } - } - _ => element!{} - }) - }) - - #(if busy && !showing_ui { - View(key: "agent-working-spinner", padding_left: Cells::from(2), padding_top: Cells::from(1)) { - Spinner( - label: "", - spinner_style: Style::default().fg(Color::Yellow).add_modifier(Modifier::BOLD), - ) - } - }) - } - } -} - -fn out_of_band_turn_view(events: &[turn::UiEvent], turn_id: usize) -> Elements { - element! { - View(key: format!("turn-{turn_id}")) { - Text { - Span(text: " System ", style: Style::default().fg(Color::Blue).add_modifier(Modifier::BOLD).add_modifier(Modifier::REVERSED)) - } - #(for event in events { - #(match event { - turn::UiEvent::OutOfBandOutput(details) => { - out_of_band_output_view(details) - } - _ => element!{} - }) - }) - } - } -} - -fn out_of_band_output_view(details: &turn::OutOfBandOutputDetails) -> Elements { - element! { - View(padding_left: Cells::from(2)) { - #(if details.command.is_some() { - Text { - Span(text: details.command.as_ref().unwrap(), style: Style::default().fg(Color::Blue)) - } - }) - Markdown(source: details.content.clone()) - } - } -} - -fn tool_summary_view(summary: &turn::ToolSummary) -> Elements { - element! { - Spinner(label: summary.summary(), done: !summary.any_pending()) - } -} - -/// Render a status indicator for a non-preview tool call (e.g. atuin_history, read_file). -fn tool_status_view(name: &str, status: &turn::ToolResultStatus) -> Elements { - match status { - turn::ToolResultStatus::Pending => { - element! { - Spinner( - label: format!("Running: {name}"), - label_style: Style::default().fg(Color::Yellow), - done: false, - ) - } - } - turn::ToolResultStatus::Success => { - element! { - Spinner( - label: format!("Ran: {name}"), - done: true, - ) - } - } - turn::ToolResultStatus::Error => { - element! { - Text { - Span(text: "✗ ", style: Style::default().fg(Color::Red)) - Span(text: format!("{name}: denied"), style: Style::default().fg(Color::Red)) - } - } - } - } -} - -// ─────────────────────────────────────────────────────────────────── -// Per-tool view functions -// ─────────────────────────────────────────────────────────────────── - -/// Max output lines shown for a shell command preview. -const MAX_SHELL_PREVIEW_LINES: u16 = 5; - -/// Render a shell command execution with live VT100 output viewport. -fn shell_tool_view(tool_key: &str, command: &str, preview: Option<&ToolPreview>) -> Elements { - let preview_done = preview.is_some_and(|p| p.exit_code.is_some() || p.interrupted.is_some()); - - element! { - #(if let Some(preview) = preview { - View(key: format!("preview-{tool_key}")) { - Spinner( - label: if preview_done { format!("Ran: {command}") } else { format!("Running: {command}") }, - done: preview_done, - hide_checkmark: true, - ) - HStack { - View(width: WidthConstraint::Fixed(2)) { - Text { Span(text: "└ ") } - } - Column { - Viewport( - key: format!("viewport-{tool_key}"), - lines: preview.lines.clone(), - height: (preview.lines.len() as u16).clamp(1, MAX_SHELL_PREVIEW_LINES), - style: Style::default().fg(Color::Gray), - wrap: false, - ) - } - } - #(shell_tool_footer(preview, preview_done)) - } - } else { - Spinner( - label: format!("Running: {command}"), - label_style: Style::default().fg(Color::Yellow), - done: false, - ) - }) - } -} - -fn shell_tool_footer(preview: &ToolPreview, preview_done: bool) -> Elements { - use crate::fsm::tools::InterruptReason; - - if let Some(reason) = &preview.interrupted { - let text = match reason { - InterruptReason::User => "Interrupted".to_string(), - InterruptReason::Timeout(secs) => format!("Timed out ({secs}s)"), - }; - return element! { - Text { - Span(text: text, style: Style::default().fg(Color::Red).add_modifier(Modifier::BOLD)) - } - }; - } - if !preview_done { - return element! { - Text { - Span(text: "[Ctrl+C] Interrupt", style: Style::default().fg(Color::DarkGray)) - } - }; - } - if let Some(code) = preview.exit_code { - let style = if code == 0 { - Style::default().fg(Color::Green) - } else { - Style::default().fg(Color::Red) - }; - return element! { - Text { Span(text: format!("Exit code: {code}"), style: style) } - }; - } - element! {} -} - -/// Render a file edit tool call with diff preview. -fn file_edit_tool_view( - key: &str, - status: &turn::ToolResultStatus, - path: &std::path::Path, - preview: Option<&crate::diff::EditPreview>, -) -> Elements { - use crate::diff::DiffLine; - - let display_path = format_path_for_display(path); - - let status_line = match status { - turn::ToolResultStatus::Pending => { - element! { - Spinner( - label: format!("Editing: {display_path}"), - label_style: Style::default().fg(Color::Yellow), - done: false, - ) - } - } - turn::ToolResultStatus::Success => { - element! { - Spinner(label: format!("Edited: {display_path}"), done: true) - } - } - turn::ToolResultStatus::Error => { - element! { - Text { - Span(text: "✗ ", style: Style::default().fg(Color::Red)) - Span(text: format!("Edit {display_path}: failed"), style: Style::default().fg(Color::Red)) - } - } - } - }; - - // If no preview, just show the status line - let Some(preview) = preview else { - return status_line; - }; - if preview.hunks.is_empty() { - return status_line; - } - - // Calculate the line number gutter width from the highest line number - let max_line_num = preview.max_line_number(); - let gutter_width = max_line_num.to_string().len().max(2) as u16 + 1; // +1 for spacing - - element! { - View(key: key.to_string()) { - #(status_line) - - View(key: format!("{key}-diff"), padding_left: Cells::from(2)) { - #(for (hunk_idx, hunk) in preview.hunks.iter().enumerate() { - #({ - let gutter_w = gutter_width; - let mut before_pos = hunk.before_start; - let mut after_pos = hunk.after_start; - let lines_rendered: Vec<_> = hunk.lines.iter().enumerate().map(|(line_idx, line)| { - let (prefix, text, style, gutter_text, gutter_style) = match line { - DiffLine::Context(t) => { - let num = format!("{:>width$}", after_pos, width = (gutter_w - 1) as usize); - before_pos += 1; - after_pos += 1; - (" ", t.as_str(), Style::default().fg(Color::DarkGray), num, Style::default().fg(Color::DarkGray)) - } - DiffLine::Removed(t) => { - let num = format!("{:>width$}", before_pos, width = (gutter_w - 1) as usize); - before_pos += 1; - ("-", t.as_str(), Style::default().fg(Color::Red), num, Style::default().fg(Color::Red)) - } - DiffLine::Added(t) => { - let num = format!("{:>width$}", after_pos, width = (gutter_w - 1) as usize); - after_pos += 1; - ("+", t.as_str(), Style::default().fg(Color::Green), num, Style::default().fg(Color::Green)) - } - }; - (line_idx, prefix, text.to_string(), style, gutter_text, gutter_style) - }).collect(); - - element! { - View(key: format!("{key}-hunk-{hunk_idx}")) { - #(for (line_idx, prefix, text, style, gutter_text, gutter_style) in &lines_rendered { - HStack(key: format!("{key}-hunk-{hunk_idx}-line-{line_idx}")) { - View(width: WidthConstraint::Fixed(gutter_w)) { - Text { Span(text: gutter_text, style: *gutter_style) } - } - View { - Text { - Span(text: *prefix, style: *style) - Span(text: text, style: *style) - } - } - } - }) - } - } - }) - }) - } - } - } -} - -/// Render a file write tool call with content preview. -fn file_write_tool_view( - key: &str, - status: &turn::ToolResultStatus, - path: &std::path::Path, - preview: Option<&crate::diff::WritePreview>, -) -> Elements { - let display_path = format_path_for_display(path); - - let status_line = match status { - turn::ToolResultStatus::Pending => { - element! { - Spinner( - label: format!("Writing: {display_path}"), - label_style: Style::default().fg(Color::Yellow), - done: false, - ) - } - } - turn::ToolResultStatus::Success => { - let line_info = preview - .map(|p| format!(" ({} lines)", p.total_lines)) - .unwrap_or_default(); - element! { - Spinner(label: format!("Wrote: {display_path}{line_info}"), done: true) - } - } - turn::ToolResultStatus::Error => { - element! { - Text { - Span(text: "✗ ", style: Style::default().fg(Color::Red)) - Span(text: format!("Write {display_path}: failed"), style: Style::default().fg(Color::Red)) - } - } - } - }; - - let Some(preview) = preview else { - return status_line; - }; - if preview.lines.is_empty() { - return status_line; - } - - let gutter_width = preview.total_lines.to_string().len().max(2) as u16 + 1; - let remaining = preview.remaining_lines(); - - element! { - View(key: key.to_string()) { - #(status_line) - - View(key: format!("{key}-content"), padding_left: Cells::from(2)) { - #(for (idx, line) in preview.lines.iter().enumerate() { - HStack(key: format!("{key}-line-{idx}")) { - View(width: WidthConstraint::Fixed(gutter_width)) { - Text { Span( - text: format!("{:>width$}", idx + 1, width = (gutter_width - 1) as usize), - style: Style::default().fg(Color::DarkGray) - ) } - } - View { - Text { Span(text: line, style: Style::default().fg(Color::DarkGray)) } - } - } - }) - - #(if remaining > 0 { - Text { - Span( - text: format!(" ... +{remaining} more lines"), - style: Style::default().fg(Color::DarkGray) - ) - } - }) - } - } - } -} - -// ─────────────────────────────────────────────────────────────────── -// Tool group view functions -// ─────────────────────────────────────────────────────────────────── - -/// Max entries shown under a tool group header. When the group holds more -/// than this, only the most recent `MAX_GROUP_ENTRIES` are displayed; the -/// count in the header line tells the full story. -const MAX_GROUP_ENTRIES: usize = 5; - -/// Format a filesystem path for display in tool rows. -/// -/// - Relative to the current working directory if the path is under it -/// - `~/...` prefix if the path is under the user's home directory -/// - Absolute otherwise (and relative paths pass through unchanged) -fn format_path_for_display(path: &std::path::Path) -> String { - if let Ok(cwd) = std::env::current_dir() - && let Ok(relative) = path.strip_prefix(&cwd) - { - return relative.display().to_string(); - } - - if let Ok(home) = std::env::var("HOME") - && let Ok(relative) = path.strip_prefix(&home) - { - return format!("~/{}", relative.display()); - } - - path.display().to_string() -} - -fn filter_mode_label(mode: &HistorySearchFilterMode) -> &'static str { - match mode { - HistorySearchFilterMode::Global => "global", - HistorySearchFilterMode::Host => "host", - HistorySearchFilterMode::Session => "session", - HistorySearchFilterMode::Directory => "directory", - HistorySearchFilterMode::Workspace => "workspace", - } -} - -/// Format a list of filter modes as `"(global, workspace)"`, or an empty -/// string if the list is empty. -fn format_filter_modes(modes: &[HistorySearchFilterMode]) -> String { - if modes.is_empty() { - return String::new(); - } - let parts: Vec<&'static str> = modes.iter().map(filter_mode_label).collect(); - format!("({})", parts.join(", ")) -} - -/// Tree-connector marker for a row in a grouped list: `└ ` for the first -/// visible row, two spaces for subsequent rows. -fn tree_marker(is_first: bool) -> &'static str { - if is_first { "└ " } else { " " } -} - -/// 2-char status marker column: ✓ / ✗ / blank. -fn status_marker_view(status: &turn::ToolResultStatus) -> Elements { - match status { - turn::ToolResultStatus::Pending => element! { - Text { Span(text: " ") } - }, - turn::ToolResultStatus::Success => element! { - Text { Span(text: "✓ ", style: Style::default().fg(Color::Green)) } - }, - turn::ToolResultStatus::Error => element! { - Text { Span(text: "✗ ", style: Style::default().fg(Color::Red)) } - }, - } -} - -/// Compute the slice of calls to show — the most recent `MAX_GROUP_ENTRIES`. -fn visible_group_calls(group: &turn::ToolGroup) -> &[turn::ToolCallDetails] { - let start = group.calls.len().saturating_sub(MAX_GROUP_ENTRIES); - &group.calls[start..] -} - -/// Render a single row in a grouped list: [tree marker][status][content]. -fn group_row_view(is_first: bool, status: &turn::ToolResultStatus, content: Elements) -> Elements { - element! { - HStack { - View(width: WidthConstraint::Fixed(2)) { - Text { Span(text: tree_marker(is_first)) } - } - View(width: WidthConstraint::Fixed(2)) { - #(status_marker_view(status)) - } - Column { - #(content) - } - } - } -} - -/// Render a group of consecutive `read_file` tool calls. -fn file_read_group_view(group: &turn::ToolGroup) -> Elements { - let count = group.calls.len(); - let label = if count == 1 { - "Read 1 file".to_string() - } else { - format!("Read {count} files") - }; - let done = !group.any_pending(); - let visible = visible_group_calls(group); - - element! { - Spinner(label: label, done: done, hide_checkmark: true) - #(for (i, details) in visible.iter().enumerate() { - #(file_read_row(i == 0, details)) - }) - } -} - -fn file_read_row(is_first: bool, details: &turn::ToolCallDetails) -> Elements { - let path_str = match &details.render_data { - turn::ToolRenderData::FileRead { path } => format_path_for_display(path), - _ => String::new(), - }; - - let content = element! { - Text { Span(text: path_str) } - }; - - group_row_view(is_first, &details.status, content) -} - -/// Render a group of consecutive `atuin_history` tool calls. -fn history_search_group_view(group: &turn::ToolGroup) -> Elements { - let done = !group.any_pending(); - let visible = visible_group_calls(group); - - element! { - Spinner(label: "Searched Atuin history:", done: done, hide_checkmark: true) - #(for (i, details) in visible.iter().enumerate() { - #(history_search_row(i == 0, details)) - }) - } -} - -fn history_search_row(is_first: bool, details: &turn::ToolCallDetails) -> Elements { - let (query, filter_modes) = match &details.render_data { - turn::ToolRenderData::HistorySearch { - query, - filter_modes, - } => (query.as_str(), filter_modes.as_slice()), - _ => ("", [].as_slice()), - }; - - let is_empty_query = query.trim().is_empty(); - let filter_label = format_filter_modes(filter_modes); - - let content = if is_empty_query { - element! { - Text { - Span( - text: "recent commands", - style: Style::default().fg(Color::Gray).add_modifier(Modifier::ITALIC), - ) - #(if !filter_label.is_empty() { - Span(text: " ") - Span(text: filter_label, style: Style::default().fg(Color::DarkGray)) - }) - } - } - } else { - element! { - Text { - Span(text: query.to_string()) - #(if !filter_label.is_empty() { - Span(text: " ") - Span(text: filter_label, style: Style::default().fg(Color::DarkGray)) - }) - } - } - }; - - group_row_view(is_first, &details.status, content) -} - -fn suggested_command_view(details: &turn::SuggestedCommandDetails) -> Elements { - let is_dangerous = matches!( - details.danger_level, - turn::DangerLevel::High(_) | turn::DangerLevel::Medium(_) - ); - let danger_notes = details.danger_level.notes(); - let danger_style = match details.danger_level { - turn::DangerLevel::High(_) => Style::default().fg(Color::Red), - turn::DangerLevel::Medium(_) => Style::default().fg(Color::Yellow), - turn::DangerLevel::Low(_) => Style::default().fg(Color::Green), - turn::DangerLevel::Unknown(_) => Style::default().fg(Color::Green), - }; - let danger_text = match details.danger_level { - turn::DangerLevel::High(_) => "High", - turn::DangerLevel::Medium(_) => "Medium", - turn::DangerLevel::Low(_) => "Low", - turn::DangerLevel::Unknown(_) => "Unknown", - }; - - let low_confidence = matches!( - details.confidence_level, - turn::ConfidenceLevel::Low(_) | turn::ConfidenceLevel::Medium(_) - ); - - let confidence_level = match details.confidence_level { - turn::ConfidenceLevel::Low(_) => "Low", - turn::ConfidenceLevel::Medium(_) => "Medium", - turn::ConfidenceLevel::High(_) => "High", - turn::ConfidenceLevel::Unknown(_) => "Unknown", - }; - - let confidence_notes = details.confidence_level.notes(); - - element! { - View { - Text { - Span(text: " Suggested command:", style: Style::default().fg(Color::Cyan)) - } - HStack { - View(width: WidthConstraint::Fixed(2)) { - Text { - #(if is_dangerous || low_confidence { - Span(text: "! ", style: Style::default().fg(Color::Yellow)) - } else { - Span(text: "$ ", style: Style::default().fg(Color::Blue)) - }) - } - } - Column { - Text { - Span(text: &details.command, style: Style::default().fg(Color::Green)) - } - } - } - #(if is_dangerous { - View(padding_left: Cells::from(2)) { - Text { - Span(text: "Danger: ", style: danger_style) - Span(text: danger_text, style: danger_style.add_modifier(Modifier::BOLD)) - } - } - }) - #(if is_dangerous && danger_notes.is_some() { - View(padding_left: Cells::from(2)) { - HStack { - View(width: WidthConstraint::Fixed(2)) { - Text { - Span(text: "└") - } - } - View(width: WidthConstraint::Fill) { - Markdown(source: danger_notes.unwrap()) - } - } - } - }) - #(if low_confidence { - View(padding_left: Cells::from(2)) { - Text { - Span(text: "Confidence: ", style: Style::default().fg(Color::Blue)) - Span(text: confidence_level, style: Style::default().fg(Color::Blue).add_modifier(Modifier::BOLD)) - } - } - }) - #(if low_confidence && confidence_notes.is_some() { - View(padding_left: Cells::from(2)) { - HStack { - View(width: WidthConstraint::Fixed(2)) { - Text { - Span(text: "└") - } - } - View(width: WidthConstraint::Fill) { - Markdown(source: confidence_notes.unwrap()) - } - } - } - }) - } - } -} - -// ai_view_old removed — superseded by ai_view above diff --git a/crates/atuin-ai/src/tui/view/turn.rs b/crates/atuin-ai/src/tui/view/turn.rs deleted file mode 100644 index aa1f55fa..00000000 --- a/crates/atuin-ai/src/tui/view/turn.rs +++ /dev/null @@ -1,606 +0,0 @@ -use std::path::PathBuf; - -use crate::fsm::tools::ToolManager; -use crate::tools::descriptor; -use crate::tools::{ClientToolCall, HistorySearchFilterMode, ToolPreview}; -use crate::tui::ConversationEvent; - -/// Server-sent danger level for a suggested command -#[derive(Debug)] -pub(crate) enum DangerLevel { - Low(Option<String>), - Medium(Option<String>), - High(Option<String>), - Unknown(Option<String>), -} - -impl DangerLevel { - pub(crate) fn notes(&self) -> Option<&String> { - match self { - DangerLevel::Low(notes) => notes.as_ref(), - DangerLevel::Medium(notes) => notes.as_ref(), - DangerLevel::High(notes) => notes.as_ref(), - DangerLevel::Unknown(notes) => notes.as_ref(), - } - } -} - -impl From<(&String, &String)> for DangerLevel { - fn from((danger_level, danger_notes): (&String, &String)) -> Self { - let notes = if danger_notes.is_empty() { - None - } else { - Some(danger_notes.to_string()) - }; - - match danger_level.as_str() { - "low" => DangerLevel::Low(notes), - "medium" => DangerLevel::Medium(notes), - "med" => DangerLevel::Medium(notes), - "high" => DangerLevel::High(notes), - _ => DangerLevel::Unknown(notes), - } - } -} - -/// Server-sent confidence level for a suggested command -#[derive(Debug)] -pub(crate) enum ConfidenceLevel { - Low(Option<String>), - Medium(Option<String>), - High(Option<String>), - Unknown(Option<String>), -} - -impl ConfidenceLevel { - pub(crate) fn notes(&self) -> Option<&String> { - match self { - ConfidenceLevel::Low(notes) => notes.as_ref(), - ConfidenceLevel::Medium(notes) => notes.as_ref(), - ConfidenceLevel::High(notes) => notes.as_ref(), - ConfidenceLevel::Unknown(notes) => notes.as_ref(), - } - } -} - -impl From<(&String, &String)> for ConfidenceLevel { - fn from((confidence_level, confidence_notes): (&String, &String)) -> Self { - let notes = if confidence_notes.is_empty() { - None - } else { - Some(confidence_notes.to_string()) - }; - - match confidence_level.as_str() { - "low" => ConfidenceLevel::Low(notes), - "medium" => ConfidenceLevel::Medium(notes), - "med" => ConfidenceLevel::Medium(notes), - "high" => ConfidenceLevel::High(notes), - _ => ConfidenceLevel::Unknown(notes), - } - } -} - -#[derive(Debug)] -pub(crate) enum UiEvent { - Text { - content: String, - }, - ToolCall(ToolCallDetails), - /// Consecutive client-side tool calls of the same groupable kind, collapsed - /// into one unit so the view can render a shared status line + a list of - /// individual entries. - ToolGroup(ToolGroup), - ToolSummary(ToolSummary), - SuggestedCommand(SuggestedCommandDetails), - OutOfBandOutput(OutOfBandOutputDetails), -} - -/// A run of consecutive client-side tool calls of the same groupable kind. -#[derive(Debug)] -pub(crate) struct ToolGroup { - pub(crate) kind: ToolGroupKind, - pub(crate) calls: Vec<ToolCallDetails>, -} - -impl ToolGroup { - /// True if any call in the group is still pending. - pub(crate) fn any_pending(&self) -> bool { - self.calls - .iter() - .any(|c| c.status == ToolResultStatus::Pending) - } -} - -/// Which kind of client-side tools this group holds. -/// -/// Only tool types that benefit from grouped presentation appear here. -/// Shell (needs its own viewport) and FileWrite (wants diffs/contents) are -/// intentionally absent — those render as individual `UiEvent::ToolCall`s. -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub(crate) enum ToolGroupKind { - FileRead, - HistorySearch, -} - -/// Tool-type-specific data for rendering in the view layer. -/// -/// Each variant carries the data a per-tool renderer component needs. -/// Built by TurnBuilder from ToolTracker + ConversationEvent data. -#[derive(Debug)] -pub(crate) enum ToolRenderData { - /// Shell command with live/cached VT100 output preview. - Shell { - command: String, - preview: Option<ToolPreview>, - }, - /// File read operation. - FileRead { path: PathBuf }, - /// File edit (str_replace) operation. - FileEdit { - path: PathBuf, - preview: Option<crate::diff::EditPreview>, - }, - /// File write/create operation. - FileWrite { - path: PathBuf, - preview: Option<crate::diff::WritePreview>, - }, - /// Atuin history search. - HistorySearch { - query: String, - filter_modes: Vec<HistorySearchFilterMode>, - }, - /// Skill loading — read-only, auto-approved. - SkillLoad { _name: String }, - /// Server-side tool — no client rendering data available. - Remote, -} - -impl ToolRenderData { - pub(crate) fn is_remote(&self) -> bool { - matches!(self, ToolRenderData::Remote) - } - - /// The group kind this tool should collapse into, if any. - /// - /// Returns `None` for tools that render as individual `UiEvent::ToolCall`s - /// (shell, file writes, remote). - pub(crate) fn group_kind(&self) -> Option<ToolGroupKind> { - match self { - ToolRenderData::FileRead { .. } => Some(ToolGroupKind::FileRead), - ToolRenderData::HistorySearch { .. } => Some(ToolGroupKind::HistorySearch), - _ => None, - } - } -} - -#[derive(Debug)] -pub(crate) struct ToolCallDetails { - pub(crate) tool_use_id: String, - pub(crate) name: String, - pub(crate) status: ToolResultStatus, - pub(crate) render_data: ToolRenderData, -} - -#[derive(Debug)] -pub(crate) struct SuggestedCommandDetails { - pub(crate) command: String, - pub(crate) danger_level: DangerLevel, - pub(crate) confidence_level: ConfidenceLevel, -} - -#[derive(Debug)] -pub(crate) struct OutOfBandOutputDetails { - pub(crate) command: Option<String>, - pub(crate) content: String, -} - -#[derive(Debug, PartialEq, Eq)] -pub(crate) enum ToolResultStatus { - Pending, - Success, - Error, -} - -#[derive(Debug)] -pub(crate) struct UiTurn { - pub(crate) id: usize, - pub(crate) kind: UiTurnKind, -} - -#[derive(Debug)] -pub(crate) enum UiTurnKind { - User { events: Vec<UiEvent> }, - Agent { events: Vec<UiEvent> }, - OutOfBand { events: Vec<UiEvent> }, -} - -pub(crate) struct TurnBuilder<'a> { - turns: Vec<UiTurnKind>, - current_turn: Option<UiTurnKind>, - tracker: &'a ToolManager, - next_id: usize, -} - -/// A struct to iteratively build [UiTurn] events from [ConversationEvent]s. -impl<'a> TurnBuilder<'a> { - pub(crate) fn new(tracker: &'a ToolManager) -> Self { - Self { - turns: Vec::new(), - current_turn: None, - tracker, - next_id: 0, - } - } - - pub(crate) fn new_starting_at(tracker: &'a ToolManager, start_id: usize) -> Self { - Self { - turns: Vec::new(), - current_turn: None, - tracker, - next_id: start_id, - } - } - - pub(crate) fn add_event(&mut self, event: &ConversationEvent) { - match event { - ConversationEvent::UserMessage { content } => { - self.add_user_message(content); - } - ConversationEvent::Text { content } => { - self.add_agent_text(content); - } - ConversationEvent::ToolCall { id, name, input } => { - if name == "suggest_command" { - self.add_suggested_command(input); - } else { - self.add_tool_call(id, name, input); - } - } - ConversationEvent::ToolResult { - tool_use_id, - content, - is_error, - .. - } => { - self.add_tool_result(tool_use_id, content, *is_error); - } - ConversationEvent::OutOfBandOutput { - name, - command, - content, - } => { - self.add_out_of_band_output(name, command.as_deref(), content); - } - ConversationEvent::SystemContext { .. } => { - // Not rendered in the TUI — only sent to the API - } - ConversationEvent::SkillInvocation { - name, arguments, .. - } => { - let display = match arguments { - Some(args) => format!("/{name} {args}"), - None => format!("/{name}"), - }; - self.add_user_message(&display); - } - } - } - - pub(crate) fn build(&mut self) -> Vec<UiTurn> { - self.commit_turn(); - - // Within each agent turn: - // - Consecutive remote tool calls collapse into a ToolSummary - // - Consecutive client-side tool calls of the same group kind collapse - // into a ToolGroup (e.g. N file reads → one group) - // - All other events pass through unchanged - for turn in &mut self.turns { - if let UiTurnKind::Agent { events } = turn { - let mut new_events: Vec<UiEvent> = Vec::new(); - let mut pending_remote: Vec<ToolCallDetails> = Vec::new(); - let mut pending_group: Option<(ToolGroupKind, Vec<ToolCallDetails>)> = None; - - for event in events.drain(..) { - match event { - UiEvent::ToolCall(details) if details.render_data.is_remote() => { - flush_group(&mut pending_group, &mut new_events); - pending_remote.push(details); - } - UiEvent::ToolCall(details) - if details.render_data.group_kind().is_some() => - { - flush_remote(&mut pending_remote, &mut new_events); - - let kind = details.render_data.group_kind().unwrap(); - match pending_group.as_mut() { - Some((current_kind, calls)) if *current_kind == kind => { - calls.push(details); - } - _ => { - flush_group(&mut pending_group, &mut new_events); - pending_group = Some((kind, vec![details])); - } - } - } - other => { - flush_remote(&mut pending_remote, &mut new_events); - flush_group(&mut pending_group, &mut new_events); - new_events.push(other); - } - } - } - - flush_remote(&mut pending_remote, &mut new_events); - flush_group(&mut pending_group, &mut new_events); - - *events = new_events; - } - } - - let kinds = std::mem::take(&mut self.turns); - kinds - .into_iter() - .enumerate() - .map(|(i, kind)| UiTurn { - id: self.next_id + i, - kind, - }) - .collect() - } - - fn commit_turn(&mut self) { - if let Some(turn) = self.current_turn.take() { - self.turns.push(turn); - } - } - - fn start_user_turn(&mut self) { - if !matches!(self.current_turn, Some(UiTurnKind::User { .. })) { - self.commit_turn(); - self.current_turn = Some(UiTurnKind::User { events: vec![] }); - } - } - - fn start_agent_turn(&mut self) { - if !matches!(self.current_turn, Some(UiTurnKind::Agent { .. })) { - self.commit_turn(); - self.current_turn = Some(UiTurnKind::Agent { events: vec![] }); - } - } - - fn start_out_of_band_turn(&mut self) { - if !matches!(self.current_turn, Some(UiTurnKind::OutOfBand { .. })) { - self.commit_turn(); - self.current_turn = Some(UiTurnKind::OutOfBand { events: vec![] }); - } - } - - fn current_events_mut(&mut self) -> &mut Vec<UiEvent> { - match self.current_turn.as_mut().unwrap() { - UiTurnKind::User { events } - | UiTurnKind::Agent { events } - | UiTurnKind::OutOfBand { events } => events, - } - } - - fn add_user_message(&mut self, content: &str) { - self.start_user_turn(); - self.current_events_mut().push(UiEvent::Text { - content: content.to_string(), - }); - } - - fn add_agent_text(&mut self, content: &str) { - if content.trim().is_empty() { - return; - } - self.start_agent_turn(); - self.current_events_mut().push(UiEvent::Text { - content: content.to_string(), - }); - } - - fn add_suggested_command(&mut self, input: &serde_json::Value) { - let command = input - .get("command") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - - if command.is_empty() { - return; - } - - self.start_agent_turn(); - { - let events = self.current_events_mut(); - let danger_level = input - .get("danger") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - - let confidence_level = input - .get("confidence") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - - let danger_notes = input - .get("danger_notes") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - - let confidence_notes = input - .get("confidence_notes") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - - let danger = DangerLevel::from((&danger_level, &danger_notes)); - let confidence = ConfidenceLevel::from((&confidence_level, &confidence_notes)); - - events.push(UiEvent::SuggestedCommand(SuggestedCommandDetails { - command: input - .get("command") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(), - danger_level: danger, - confidence_level: confidence, - })); - } - } - - fn add_tool_call(&mut self, id: &str, name: &str, _input: &serde_json::Value) { - let render_data = self.build_render_data(id, name); - - self.start_agent_turn(); - self.current_events_mut() - .push(UiEvent::ToolCall(ToolCallDetails { - tool_use_id: id.to_string(), - name: name.to_string(), - status: ToolResultStatus::Pending, - render_data, - })); - } - - /// Build tool-type-specific render data from the ToolTracker. - /// - /// For client-side tools, the tracker holds the typed `ClientToolCall` and - /// any live/cached preview data. For server-side (or unknown) tools, we - /// fall back to `ToolRenderData::Remote`. - fn build_render_data(&self, id: &str, _name: &str) -> ToolRenderData { - if let Some(tracked) = self.tracker.get(id) { - match &tracked.tool { - ClientToolCall::Shell(shell) => ToolRenderData::Shell { - command: shell.command.clone(), - preview: tracked.shell_preview(), - }, - ClientToolCall::Read(read) => ToolRenderData::FileRead { - path: read.path.clone(), - }, - ClientToolCall::Edit(edit) => ToolRenderData::FileEdit { - path: edit.path.clone(), - preview: tracked.edit_preview().cloned(), - }, - ClientToolCall::Write(write) => ToolRenderData::FileWrite { - path: write.path.clone(), - preview: tracked.write_preview().cloned(), - }, - ClientToolCall::AtuinHistory(history) => ToolRenderData::HistorySearch { - query: history.query.clone(), - filter_modes: history.filter_modes.clone(), - }, - ClientToolCall::AtuinOutput(_) => ToolRenderData::Remote, - ClientToolCall::LoadSkill(skill) => ToolRenderData::SkillLoad { - _name: skill.name.clone(), - }, - } - } else { - // Not in tracker → server-side tool - ToolRenderData::Remote - } - } - - fn add_tool_result(&mut self, tool_use_id: &str, _content: &str, is_error: bool) { - self.start_agent_turn(); - let events = self.current_events_mut(); - let event = events.iter_mut().find(|e| match e { - UiEvent::ToolCall(ToolCallDetails { - tool_use_id: id, .. - }) => id == tool_use_id, - _ => false, - }); - if let Some(UiEvent::ToolCall(ToolCallDetails { status, .. })) = event { - *status = if is_error { - ToolResultStatus::Error - } else { - ToolResultStatus::Success - }; - } - } - - fn add_out_of_band_output(&mut self, _name: &str, command: Option<&str>, content: &str) { - self.start_out_of_band_turn(); - self.current_events_mut() - .push(UiEvent::OutOfBandOutput(OutOfBandOutputDetails { - command: command.map(|c| c.to_string()), - content: content.to_string(), - })); - } -} - -/// Drain pending remote tool calls into a `ToolSummary`. -fn flush_remote(pending: &mut Vec<ToolCallDetails>, out: &mut Vec<UiEvent>) { - if !pending.is_empty() { - out.push(UiEvent::ToolSummary(ToolSummary { - tool_calls: std::mem::take(pending), - })); - } -} - -/// Drain a pending client-side tool group into a `ToolGroup`. -fn flush_group( - pending: &mut Option<(ToolGroupKind, Vec<ToolCallDetails>)>, - out: &mut Vec<UiEvent>, -) { - if let Some((kind, calls)) = pending.take() { - out.push(UiEvent::ToolGroup(ToolGroup { kind, calls })); - } -} - -#[derive(Debug)] -pub(crate) struct ToolSummary { - tool_calls: Vec<ToolCallDetails>, -} - -impl ToolSummary { - /// Determines the summary line: - /// - If any call is pending, use present tense verb with `-ing` - /// - If multiple calls are complete, say "Used n tools" - /// - If a single call is complete, use past tense verb - pub(crate) fn summary(&self) -> String { - if self.any_pending() { - // Find the last pending tool for the active verb - if let Some(pending) = self - .tool_calls - .iter() - .rev() - .find(|t| t.status == ToolResultStatus::Pending) - { - return Self::progressive_verb(&pending.name); - } - } - - if self.tool_calls.len() == 1 { - return Self::past_verb(&self.tool_calls[0].name); - } - - format!("Used {} tools", self.tool_calls.len()) - } - - /// Determines if the spinner should be spinning - pub(crate) fn any_pending(&self) -> bool { - self.tool_calls - .iter() - .any(|tool_call| tool_call.status == ToolResultStatus::Pending) - } - - /// Present-tense progressive verb for a tool name (e.g. "Searching...") - fn progressive_verb(name: &str) -> String { - descriptor::by_name(name) - .map(|d| d.progressive_verb.to_string()) - .unwrap_or_else(|| format!("Running {}...", name.replace('_', " "))) - } - - /// Past-tense verb for a tool name (e.g. "Searched") - fn past_verb(name: &str) -> String { - descriptor::by_name(name) - .map(|d| d.past_verb.to_string()) - .unwrap_or_else(|| format!("Ran {}", name.replace('_', " "))) - } -} diff --git a/crates/atuin-ai/src/user_context/interpolate.rs b/crates/atuin-ai/src/user_context/interpolate.rs deleted file mode 100644 index 91e34ab4..00000000 --- a/crates/atuin-ai/src/user_context/interpolate.rs +++ /dev/null @@ -1,279 +0,0 @@ -//! Parse `.atuin/ai-context.md` files and execute embedded commands. -//! -//! Two interpolation syntaxes are supported: -//! -//! **Inline:** `!`command`` — the `!` immediately before a code span triggers -//! execution. The entire `!`...`` span is replaced with the command's stdout. -//! -//! **Block:** -//! ````markdown -//! ```! -//! command -//! ``` -//! ```` -//! A fenced code block with `!` as the info string. The block body is executed -//! as a script and the entire fenced block is replaced with stdout. -//! -//! Regular code spans and fenced code blocks (without `!`) are left untouched. - -use std::ops::Range; -use std::time::Duration; - -use pulldown_cmark::{CodeBlockKind, Event, Options, Parser, Tag, TagEnd}; - -/// A command to execute, with its byte range in the source for replacement. -#[derive(Debug)] -struct Command { - /// Byte range in the source to replace (includes the `!` for inline, or - /// the full ``` fence for blocks). - range: Range<usize>, - /// The command string to execute. - body: String, -} - -/// Maximum time for a single command. -const COMMAND_TIMEOUT: Duration = Duration::from_secs(5); - -/// Maximum bytes of stdout to capture from a single command. -const MAX_OUTPUT_BYTES: usize = 64 * 1024; - -/// Parse a context file for interpolation commands. -fn parse_commands(source: &str) -> Vec<Command> { - let parser = Parser::new_ext(source, Options::empty()); - let mut commands = Vec::new(); - - // Block state: accumulate text across multiple Text events, finalize on End. - let mut block_start: Option<usize> = None; - let mut block_body = String::new(); - - for (event, range) in parser.into_offset_iter() { - match event { - // Inline: !`command` - Event::Code(code) if range.start > 0 && source.as_bytes()[range.start - 1] == b'!' => { - commands.push(Command { - range: (range.start - 1)..range.end, - body: code.to_string(), - }); - } - - // Block: ```! ... ``` - Event::Start(Tag::CodeBlock(CodeBlockKind::Fenced(info))) if info.as_ref() == "!" => { - block_start = Some(range.start); - block_body.clear(); - } - Event::Text(text) if block_start.is_some() => { - block_body.push_str(&text); - } - Event::End(TagEnd::CodeBlock) if block_start.is_some() => { - let start = block_start.take().unwrap(); - let trimmed = block_body.trim(); - if !trimmed.is_empty() { - commands.push(Command { - range: start..range.end, - body: trimmed.to_string(), - }); - } - block_body.clear(); - } - - _ => {} - } - } - - commands -} - -/// Execute all commands in a context file and return the interpolated content. -/// -/// Commands are executed in parallel. Failed commands are replaced with an -/// error marker so the AI has visibility into what went wrong. -pub(crate) async fn interpolate(source: &str, shell: &str) -> String { - let commands = parse_commands(source); - if commands.is_empty() { - return source.to_string(); - } - - // Execute all commands in parallel. - let mut handles = Vec::with_capacity(commands.len()); - for cmd in &commands { - let shell = shell.to_string(); - let body = cmd.body.clone(); - handles.push(tokio::spawn( - async move { run_command(&shell, &body).await }, - )); - } - - // Collect results. - let mut results = Vec::with_capacity(handles.len()); - for handle in handles { - let output = match handle.await { - Ok(output) => output, - Err(e) => format!("[error: task panicked: {e}]"), - }; - results.push(output); - } - - // Rebuild the source, replacing command ranges with their output. - // Commands are in source order from the parser, but let's sort to be safe. - let mut replacements: Vec<(Range<usize>, &str)> = commands - .iter() - .zip(results.iter()) - .map(|(cmd, output)| (cmd.range.clone(), output.as_str())) - .collect(); - replacements.sort_by_key(|(range, _)| range.start); - - let mut out = String::with_capacity(source.len()); - let mut cursor = 0; - for (range, output) in &replacements { - out.push_str(&source[cursor..range.start]); - out.push_str(output); - cursor = range.end; - } - out.push_str(&source[cursor..]); - - out -} - -async fn run_command(shell: &str, body: &str) -> String { - let result = tokio::time::timeout( - COMMAND_TIMEOUT, - tokio::process::Command::new(shell) - .arg("-c") - .arg(body) - .output(), - ) - .await; - - match result { - Ok(Ok(output)) => { - if output.status.success() { - if output.stdout.len() > MAX_OUTPUT_BYTES { - let truncated = String::from_utf8_lossy(&output.stdout[..MAX_OUTPUT_BYTES]); - format!( - "{}\n[output truncated at {}KB]", - truncated.trim(), - MAX_OUTPUT_BYTES / 1024 - ) - } else { - String::from_utf8_lossy(&output.stdout).trim().to_string() - } - } else { - let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); - let code = output.status.code().unwrap_or(-1); - format!("[error: exit code {code}: {stderr}]") - } - } - Ok(Err(e)) => format!("[error: {e}]"), - Err(_) => format!( - "[error: command timed out after {}s]", - COMMAND_TIMEOUT.as_secs() - ), - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn parse_inline_command() { - let source = "Branch: !`git branch --show-current`"; - let cmds = parse_commands(source); - assert_eq!(cmds.len(), 1); - assert_eq!(cmds[0].body, "git branch --show-current"); - assert_eq!( - &source[cmds[0].range.clone()], - "!`git branch --show-current`" - ); - } - - #[test] - fn parse_inline_double_backtick() { - let source = r#"Host: !``echo `hostname` ``"#; - let cmds = parse_commands(source); - assert_eq!(cmds.len(), 1); - assert_eq!(cmds[0].body, "echo `hostname` "); - } - - #[test] - fn parse_block_command() { - let source = "Before\n\n```!\necho hello\npython3 --version\n```\n\nAfter"; - let cmds = parse_commands(source); - assert_eq!(cmds.len(), 1); - assert_eq!(cmds[0].body, "echo hello\npython3 --version"); - } - - #[test] - fn regular_code_not_matched() { - let source = "Normal `code span` and ```bash\necho hi\n```"; - let cmds = parse_commands(source); - assert_eq!(cmds.len(), 0); - } - - #[test] - fn bang_not_adjacent_not_matched() { - let source = "Exclaim! Then `code` here."; - let cmds = parse_commands(source); - // The `!` and backtick are separated by " Then ", not adjacent. - assert_eq!(cmds.len(), 0); - } - - #[test] - fn mixed_content() { - let source = "\ -# Project Context - -Branch: !`git branch --show-current` - -Regular code: `not a command` - -```! -echo $VIRTUAL_ENV -``` - -```bash -echo not interpolated -``` - -End."; - let cmds = parse_commands(source); - assert_eq!(cmds.len(), 2); - assert_eq!(cmds[0].body, "git branch --show-current"); - assert_eq!(cmds[1].body, "echo $VIRTUAL_ENV"); - } - - #[tokio::test] - async fn interpolate_replaces_inline_command() { - let source = "Branch: !`echo main`"; - let result = interpolate(source, "sh").await; - assert_eq!(result, "Branch: main"); - } - - #[tokio::test] - async fn interpolate_replaces_block_command() { - let source = "Before\n\n```!\necho hello world\n```\n\nAfter"; - let result = interpolate(source, "sh").await; - assert_eq!(result, "Before\n\nhello world\n\nAfter"); - } - - #[tokio::test] - async fn interpolate_preserves_non_command_content() { - let source = "Just plain markdown with `code` and no bangs."; - let result = interpolate(source, "sh").await; - assert_eq!(result, source); - } - - #[tokio::test] - async fn interpolate_failed_command_shows_error() { - let source = "Result: !`exit 1`"; - let result = interpolate(source, "sh").await; - assert!(result.starts_with("Result: [error:")); - } - - #[tokio::test] - async fn interpolate_multiple_commands() { - let source = "A: !`echo one` B: !`echo two`"; - let result = interpolate(source, "sh").await; - assert_eq!(result, "A: one B: two"); - } -} diff --git a/crates/atuin-ai/src/user_context/mod.rs b/crates/atuin-ai/src/user_context/mod.rs deleted file mode 100644 index fdeb890b..00000000 --- a/crates/atuin-ai/src/user_context/mod.rs +++ /dev/null @@ -1,68 +0,0 @@ -//! User-authored context files (`TERMINAL.md`). -//! -//! Context files are markdown documents that can embed shell commands for -//! dynamic content. Before each API request, context files are discovered -//! by walking the filesystem, commands are executed, and the interpolated -//! content is sent to the server as `config.user_contexts`. - -pub(crate) mod interpolate; -mod walker; - -use std::path::Path; - -pub(crate) use walker::global_context_path; - -/// A fully resolved user context, ready to include in an API request. -#[derive(Debug, Clone, serde::Serialize)] -pub(crate) struct UserContext { - /// The path to the context file on disk. - pub path: String, - /// The interpolated content. - pub data: String, -} - -/// Discover context files and interpolate embedded commands. -/// -/// Walks from `start` up to the filesystem root looking for -/// `.atuin/ai-context.md`, then checks `global_path`. Returns contexts -/// ordered from most general (global/root) to most specific (deepest). -pub(crate) async fn gather( - start: &Path, - global_path: Option<&Path>, - shell: &str, -) -> Vec<UserContext> { - let raw_files = match walker::walk(start, global_path).await { - Ok(files) => files, - Err(e) => { - tracing::warn!("Failed to walk for context files: {e}"); - return Vec::new(); - } - }; - - if raw_files.is_empty() { - return Vec::new(); - } - - // Interpolate all files in parallel. - let mut handles = Vec::with_capacity(raw_files.len()); - for file in raw_files { - let shell = shell.to_string(); - handles.push(tokio::spawn(async move { - let data = interpolate::interpolate(&file.content, &shell).await; - UserContext { - path: file.path.to_string_lossy().to_string(), - data, - } - })); - } - - let mut contexts = Vec::with_capacity(handles.len()); - for handle in handles { - match handle.await { - Ok(ctx) => contexts.push(ctx), - Err(e) => tracing::warn!("Context interpolation task failed: {e}"), - } - } - - contexts -} diff --git a/crates/atuin-ai/src/user_context/walker.rs b/crates/atuin-ai/src/user_context/walker.rs deleted file mode 100644 index 117bbd33..00000000 --- a/crates/atuin-ai/src/user_context/walker.rs +++ /dev/null @@ -1,90 +0,0 @@ -//! Filesystem traversal for `TERMINAL.md` context files. -//! -//! Walks from the starting directory up to the filesystem root, checking for -//! `.atuin/TERMINAL.md` and `TERMINAL.md` at each level. Then checks the global -//! config directory. Returns files ordered from shallowest (global/root) to -//! deepest (most project-specific), so that context layers naturally from -//! general to specific. - -use std::path::{Path, PathBuf}; - -use eyre::Result; -use tokio::task::JoinSet; - -const CONTEXT_FILENAME: &str = "TERMINAL.md"; - -/// A context file found on disk, before interpolation. -#[derive(Debug)] -pub(crate) struct RawContextFile { - pub path: PathBuf, - pub content: String, -} - -struct FoundFile { - depth: usize, - file: RawContextFile, -} - -/// Walk from `start` up to the filesystem root collecting `TERMINAL.md` -/// context files, then check the global path. Returns files shallowest-first. -/// -/// At each ancestor directory, checks two locations: -/// - `.atuin/TERMINAL.md` (dotdir-scoped) -/// - `TERMINAL.md` (project root) -pub(crate) async fn walk(start: &Path, global_path: Option<&Path>) -> Result<Vec<RawContextFile>> { - let dirs: Vec<PathBuf> = start.ancestors().map(PathBuf::from).collect(); - let dir_count = dirs.len(); - - let mut set: JoinSet<Result<Option<FoundFile>>> = JoinSet::new(); - - for (index, dir) in dirs.into_iter().enumerate() { - let dir2 = dir.clone(); - set.spawn(async move { - load_context_file(&dir.join(".atuin").join(CONTEXT_FILENAME), index).await - }); - set.spawn(async move { load_context_file(&dir2.join(CONTEXT_FILENAME), index).await }); - } - - if let Some(global) = global_path { - let global = global.to_path_buf(); - let depth = dir_count; - set.spawn(async move { load_context_file(&global, depth).await }); - } - - let mut found = Vec::new(); - while let Some(result) = set.join_next().await { - match result? { - Ok(Some(f)) => found.push(f), - Ok(None) => {} - Err(e) => { - tracing::warn!("Error reading context file, skipping: {e}"); - } - } - } - - // Sort shallowest-first (highest depth index = shallowest ancestor). - // The global file has the highest depth index so it sorts last... but we - // actually want global first, then root → cwd. Reverse the depth ordering. - found.sort_by_key(|b| std::cmp::Reverse(b.depth)); - - Ok(found.into_iter().map(|f| f.file).collect()) -} - -/// The default global context file path (`~/.config/atuin/TERMINAL.md`). -pub(crate) fn global_context_path() -> PathBuf { - atuin_common::utils::config_dir().join(CONTEXT_FILENAME) -} - -async fn load_context_file(path: &Path, depth: usize) -> Result<Option<FoundFile>> { - match tokio::fs::read_to_string(path).await { - Ok(content) => Ok(Some(FoundFile { - depth, - file: RawContextFile { - path: path.to_path_buf(), - content, - }, - })), - Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None), - Err(e) => Err(e.into()), - } -} diff --git a/crates/atuin-ai/test-renders.json b/crates/atuin-ai/test-renders.json deleted file mode 100644 index 31c180fa..00000000 --- a/crates/atuin-ai/test-renders.json +++ /dev/null @@ -1,295 +0,0 @@ -[ - { - "name": "01_empty_input", - "description": "Initial state with empty input prompt", - "state": { - "events": [], - "mode": "Input", - "input": "", - "cursor_pos": 0 - } - }, - { - "name": "02_typing_input", - "description": "User typing in input field", - "state": { - "events": [], - "mode": "Input", - "input": "list all files", - "cursor_pos": 14 - } - }, - { - "name": "03_generating_spinner", - "description": "Waiting for API response (spinner)", - "state": { - "events": [ - {"type": "user_message", "content": "list all files"} - ], - "mode": "Generating", - "spinner_frame": 0 - } - }, - { - "name": "04_streaming_text", - "description": "Text streaming in from API", - "state": { - "events": [ - {"type": "user_message", "content": "what is rust?"} - ], - "mode": "Streaming", - "streaming_text": "Rust is a systems programming language focused on safety, speed, and", - "spinner_frame": 2 - } - }, - { - "name": "05_simple_command", - "description": "Simple command suggestion", - "state": { - "events": [ - {"type": "user_message", "content": "list all files"}, - {"type": "tool_call", "id": "1", "name": "suggest_command", "input": {"command": "ls -la"}} - ], - "mode": "Review" - } - }, - { - "name": "06_command_with_long_text", - "description": "Command that wraps to multiple lines", - "state": { - "events": [ - {"type": "user_message", "content": "find large files"}, - {"type": "tool_call", "id": "1", "name": "suggest_command", "input": {"command": "find /home -type f -size +100M -exec ls -lh {} \\; 2>/dev/null | sort -k5 -h"}} - ], - "mode": "Review" - } - }, - { - "name": "07_conversation_only_response", - "description": "Response without command (conversation mode)", - "state": { - "events": [ - {"type": "user_message", "content": "what does the -la flag do?"}, - {"type": "text", "content": "The `-la` flags combine two options:\n\n- `-l` shows long format with permissions, owner, size, and date\n- `-a` shows all files including hidden ones (starting with .)"} - ], - "mode": "Review" - } - }, - { - "name": "08_multi_turn_conversation", - "description": "Multiple turns of conversation", - "state": { - "events": [ - {"type": "user_message", "content": "list all files"}, - {"type": "tool_call", "id": "1", "name": "suggest_command", "input": {"command": "ls -la"}}, - {"type": "user_message", "content": "can you explain those flags?"}, - {"type": "text", "content": "The -l flag shows long format with permissions, -a shows hidden files."} - ], - "mode": "Review" - } - }, - { - "name": "09_tool_call_in_progress", - "description": "Tool being executed (spinner)", - "state": { - "events": [ - {"type": "user_message", "content": "what is the latest version of node?"}, - {"type": "tool_call", "id": "1", "name": "web_search", "input": {"query": "nodejs latest version"}} - ], - "mode": "Streaming", - "streaming_text": "", - "spinner_frame": 1 - } - }, - { - "name": "10_tool_calls_completed_with_text", - "description": "Tools finished, text streaming", - "state": { - "events": [ - {"type": "user_message", "content": "what is the latest version of node?"}, - {"type": "tool_call", "id": "1", "name": "web_search", "input": {"query": "nodejs latest version"}}, - {"type": "tool_result", "tool_use_id": "1", "content": "Node.js v22.0.0"} - ], - "mode": "Streaming", - "streaming_text": "The latest version of Node.js is v22.0.0, released in April 2024.", - "spinner_frame": 0 - } - }, - { - "name": "11_tool_calls_in_review", - "description": "Completed tools shown in review mode", - "state": { - "events": [ - {"type": "user_message", "content": "what is the latest version of node?"}, - {"type": "tool_call", "id": "1", "name": "web_search", "input": {"query": "nodejs latest version"}}, - {"type": "tool_result", "tool_use_id": "1", "content": "Node.js v22.0.0"}, - {"type": "tool_call", "id": "2", "name": "web_fetch", "input": {"url": "https://nodejs.org"}}, - {"type": "tool_result", "tool_use_id": "2", "content": "..."}, - {"type": "text", "content": "The latest version of Node.js is **v22.0.0**, released in April 2024. Key features include:\n\n- Native WebSocket client\n- Improved ES modules support\n- Better performance"} - ], - "mode": "Review" - } - }, - { - "name": "12_error_state", - "description": "Error message displayed", - "state": { - "events": [ - {"type": "user_message", "content": "do something"} - ], - "mode": "Error", - "error": "Failed to connect to API: connection timeout" - } - }, - { - "name": "13_dangerous_command", - "description": "Dangerous command with warning", - "state": { - "events": [ - {"type": "user_message", "content": "delete all files in home"}, - {"type": "tool_call", "id": "1", "name": "suggest_command", "input": { - "command": "rm -rf ~/*", - "dangerous": true, - "warning": "This will permanently delete all files in your home directory including documents, configurations, and SSH keys." - }} - ], - "mode": "Review", - "confirmation_pending": false - } - }, - { - "name": "14_dangerous_command_confirming", - "description": "Dangerous command awaiting second Enter", - "state": { - "events": [ - {"type": "user_message", "content": "delete all files in home"}, - {"type": "tool_call", "id": "1", "name": "suggest_command", "input": { - "command": "rm -rf ~/*", - "dangerous": true, - "warning": "This will permanently delete all files in your home directory." - }} - ], - "mode": "Review", - "confirmation_pending": true - } - }, - { - "name": "15_low_confidence", - "description": "Low confidence command with warning", - "state": { - "events": [ - {"type": "user_message", "content": "do that thing with the files"}, - {"type": "tool_call", "id": "1", "name": "suggest_command", "input": { - "command": "ls -la", - "confidence": "low", - "warning": "I'm not entirely sure what you mean by 'that thing'. This lists files - is that what you wanted?" - }} - ], - "mode": "Review" - } - }, - { - "name": "16_long_user_input", - "description": "User input that wraps", - "state": { - "events": [ - {"type": "user_message", "content": "I need a command that will find all JavaScript files in my project, excluding node_modules, and count the total lines of code"} - ], - "mode": "Generating", - "spinner_frame": 0 - } - }, - { - "name": "17_long_text_response", - "description": "Long text response that wraps multiple times", - "state": { - "events": [ - {"type": "user_message", "content": "explain git"}, - {"type": "text", "content": "Git is a distributed version control system created by Linus Torvalds in 2005. It tracks changes to files and enables collaboration between developers. Key concepts include:\n\n- **Repository**: A directory containing your project and its history\n- **Commit**: A snapshot of your changes with a message\n- **Branch**: An independent line of development\n- **Merge**: Combining changes from different branches\n- **Remote**: A version of your repository hosted elsewhere (like GitHub)"} - ], - "mode": "Review" - } - }, - { - "name": "18_streaming_with_tool_in_progress", - "description": "Tool in progress while streaming", - "state": { - "events": [ - {"type": "user_message", "content": "search for rust async patterns"}, - {"type": "text", "content": "Let me search for that..."}, - {"type": "tool_call", "id": "1", "name": "web_search", "input": {"query": "rust async patterns"}} - ], - "mode": "Streaming", - "streaming_text": "", - "spinner_frame": 2 - } - }, - { - "name": "19_multiple_commands_in_conversation", - "description": "Multiple command suggestions across turns", - "state": { - "events": [ - {"type": "user_message", "content": "create a new directory called test"}, - {"type": "tool_call", "id": "1", "name": "suggest_command", "input": {"command": "mkdir test"}}, - {"type": "user_message", "content": "now cd into it"}, - {"type": "tool_call", "id": "2", "name": "suggest_command", "input": {"command": "cd test"}}, - {"type": "user_message", "content": "create a file"}, - {"type": "tool_call", "id": "3", "name": "suggest_command", "input": {"command": "touch file.txt"}} - ], - "mode": "Review" - } - }, - { - "name": "20_empty_command_with_description", - "description": "Tool call with null command (conversation only)", - "state": { - "events": [ - {"type": "user_message", "content": "what's the weather like?"}, - {"type": "tool_call", "id": "1", "name": "suggest_command", "input": { - "command": null, - "description": "I can't check the weather directly, but you could use: curl wttr.in" - }} - ], - "mode": "Review" - } - }, - { - "name": "21_status_processing", - "description": "Streaming with Processing status", - "state": { - "events": [ - {"type": "user_message", "content": "analyze this code"} - ], - "mode": "Streaming", - "streaming_text": "", - "streaming_status": "Processing", - "spinner_frame": 0 - } - }, - { - "name": "22_status_thinking", - "description": "Streaming with Thinking status", - "state": { - "events": [ - {"type": "user_message", "content": "how do I optimize this query?"} - ], - "mode": "Streaming", - "streaming_text": "", - "streaming_status": "Thinking", - "spinner_frame": 1 - } - }, - { - "name": "23_follow_up_input", - "description": "Follow-up input after command", - "state": { - "events": [ - {"type": "user_message", "content": "list files"}, - {"type": "tool_call", "id": "1", "name": "suggest_command", "input": {"command": "ls -la"}} - ], - "mode": "Input", - "input": "but only show directories", - "cursor_pos": 24 - } - } -] diff --git a/crates/atuin-daemon/Cargo.toml b/crates/atuin-daemon/Cargo.toml index 97bda630..e767d3c9 100644 --- a/crates/atuin-daemon/Cargo.toml +++ b/crates/atuin-daemon/Cargo.toml @@ -16,7 +16,6 @@ readme.workspace = true [dependencies] atuin-client = { path = "../atuin-client", version = "18.16.1" } atuin-common = { path = "../atuin-common", version = "18.16.1" } -atuin-dotfiles = { path = "../atuin-dotfiles", version = "18.16.1" } atuin-history = { path = "../atuin-history", version = "18.16.1" } time = { workspace = true } diff --git a/crates/atuin-daemon/src/components/sync.rs b/crates/atuin-daemon/src/components/sync.rs index a342f700..6e486250 100644 --- a/crates/atuin-daemon/src/components/sync.rs +++ b/crates/atuin-daemon/src/components/sync.rs @@ -10,7 +10,6 @@ use tokio::sync::mpsc; use tokio::time::{self, MissedTickBehavior}; use atuin_client::{history::store::HistoryStore, record::sync, settings::Settings}; -use atuin_dotfiles::store::{AliasStore, var::VarStore}; use crate::{ daemon::{Component, DaemonHandle}, @@ -123,8 +122,6 @@ async fn sync_loop(handle: DaemonHandle, mut cmd_rx: mpsc::Receiver<SyncCommand> // Create the stores we need let encryption_key = *handle.encryption_key(); let history_store = HistoryStore::new(handle.store().clone(), host_id, encryption_key); - let alias_store = AliasStore::new(handle.store().clone(), host_id, encryption_key); - let var_store = VarStore::new(handle.store().clone(), host_id, encryption_key); // Don't backoff by more than 30 mins (with a random jitter of up to 1 min) let max_interval: f64 = 60.0 * 30.0 + rand::thread_rng().gen_range(0.0..60.0); @@ -152,8 +149,6 @@ async fn sync_loop(handle: DaemonHandle, mut cmd_rx: mpsc::Receiver<SyncCommand> sync_state = do_sync_tick( &handle, &history_store, - &alias_store, - &var_store, &mut ticker, max_interval, &settings, @@ -167,8 +162,6 @@ async fn sync_loop(handle: DaemonHandle, mut cmd_rx: mpsc::Receiver<SyncCommand> sync_state = do_sync_tick( &handle, &history_store, - &alias_store, - &var_store, &mut ticker, max_interval, &settings, @@ -190,8 +183,6 @@ async fn sync_loop(handle: DaemonHandle, mut cmd_rx: mpsc::Receiver<SyncCommand> async fn do_sync_tick( handle: &DaemonHandle, history_store: &HistoryStore, - alias_store: &AliasStore, - var_store: &VarStore, ticker: &mut time::Interval, max_interval: f64, settings: &Settings, @@ -267,14 +258,6 @@ async fn do_sync_tick( downloaded: downloaded_records.len(), }); - // Rebuild alias and var stores - if let Err(e) = alias_store.build().await { - tracing::error!("failed to rebuild alias store: {e}"); - } - if let Err(e) = var_store.build().await { - tracing::error!("failed to rebuild var store: {e}"); - } - // Reset backoff on success if ticker.period().as_secs() != settings.daemon.sync_frequency { *ticker = time::interval_at( diff --git a/crates/atuin-dotfiles/Cargo.toml b/crates/atuin-dotfiles/Cargo.toml deleted file mode 100644 index 183091b3..00000000 --- a/crates/atuin-dotfiles/Cargo.toml +++ /dev/null @@ -1,25 +0,0 @@ -[package] -name = "atuin-dotfiles" -description = "The dotfiles crate for Atuin" -edition = "2024" -version = { workspace = true } - -authors.workspace = true -rust-version.workspace = true -license.workspace = true -homepage.workspace = true -repository.workspace = true -readme.workspace = true - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -atuin-common = { path = "../atuin-common", version = "18.16.1" } -atuin-client = { path = "../atuin-client", version = "18.16.1" } - -eyre = { workspace = true } -tokio = { workspace = true } -rmp = { version = "0.8.14" } -rand = { workspace = true } -serde = { workspace = true } -crypto_secretbox = "0.1.1" diff --git a/crates/atuin-dotfiles/src/lib.rs b/crates/atuin-dotfiles/src/lib.rs deleted file mode 100644 index 74daf8ef..00000000 --- a/crates/atuin-dotfiles/src/lib.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod shell; -pub mod store; diff --git a/crates/atuin-dotfiles/src/shell.rs b/crates/atuin-dotfiles/src/shell.rs deleted file mode 100644 index 73a9ce8c..00000000 --- a/crates/atuin-dotfiles/src/shell.rs +++ /dev/null @@ -1,241 +0,0 @@ -use eyre::{Result, ensure, eyre}; -use rmp::{decode, encode}; -use serde::Serialize; - -use atuin_common::shell::{Shell, ShellError}; - -use crate::store::AliasStore; - -pub mod bash; -pub mod fish; -pub mod powershell; -pub mod xonsh; -pub mod zsh; - -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] -pub struct Alias { - pub name: String, - pub value: String, -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] -pub struct Var { - pub name: String, - pub value: String, - - // False? This is a _shell var_ - // True? This is an _env var_ - pub export: bool, -} - -impl Var { - /// Serialize into the given vec - /// This is intended to be called by the store - pub fn serialize(&self, output: &mut Vec<u8>) -> Result<()> { - encode::write_array_len(output, 3)?; // 3 fields - - encode::write_str(output, self.name.as_str())?; - encode::write_str(output, self.value.as_str())?; - encode::write_bool(output, self.export)?; - - Ok(()) - } - - pub fn deserialize(bytes: &mut decode::Bytes) -> Result<Self> { - fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { - eyre!("{err:?}") - } - - let nfields = decode::read_array_len(bytes).map_err(error_report)?; - - ensure!( - nfields == 3, - "too many entries in v0 dotfiles env create record, got {}, expected {}", - nfields, - 3 - ); - - let bytes = bytes.remaining_slice(); - - let (key, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (value, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - - let mut bytes = decode::Bytes::new(bytes); - let export = decode::read_bool(&mut bytes).map_err(error_report)?; - - ensure!( - bytes.remaining_slice().is_empty(), - "trailing bytes in encoded dotfiles env record, malformed" - ); - - Ok(Var { - name: key.to_owned(), - value: value.to_owned(), - export, - }) - } -} - -pub fn parse_alias(line: &str) -> Option<Alias> { - // consider the fact we might be importing a fish alias - // 'alias' output - // fish: alias foo bar - // posix: foo=bar - - let is_fish = line.split(' ').next().unwrap_or("") == "alias"; - - let parts: Vec<&str> = if is_fish { - line.split(' ') - .enumerate() - .filter_map(|(n, i)| if n == 0 { None } else { Some(i) }) - .collect() - } else { - line.split('=').collect() - }; - - if parts.len() <= 1 { - return None; - } - - let mut parts = parts.iter().map(|s| s.to_string()); - - let name = parts.next().unwrap(); - - let remaining = if is_fish { - parts.collect::<Vec<String>>().join(" ") - } else { - parts.collect::<Vec<String>>().join("=") - }; - - Some(Alias { - name, - value: remaining.trim().to_string(), - }) -} - -pub fn existing_aliases(shell: Option<Shell>) -> Result<Vec<Alias>, ShellError> { - let shell = if let Some(shell) = shell { - shell - } else { - Shell::current() - }; - - // this only supports posix-y shells atm - if !shell.is_posixish() { - return Err(ShellError::NotSupported); - } - - // This will return a list of aliases, each on its own line - // They will be in the form foo=bar - let aliases = shell.run_interactive(["alias"])?; - - let aliases: Vec<Alias> = aliases.lines().filter_map(parse_alias).collect(); - - Ok(aliases) -} - -/// Import aliases from the current shell -/// This will not import aliases already in the store -/// Returns aliases that were set -pub async fn import_aliases(store: &AliasStore) -> Result<Vec<Alias>> { - let shell_aliases = existing_aliases(None)?; - let store_aliases = store.aliases().await?; - - let mut res = Vec::new(); - - for alias in shell_aliases { - // O(n), but n is small, and imports infrequent - // can always make a map - if store_aliases.contains(&alias) { - continue; - } - - res.push(alias.clone()); - store.set(&alias.name, &alias.value).await?; - } - - Ok(res) -} - -#[cfg(test)] -mod tests { - use crate::shell::{Alias, parse_alias}; - - #[test] - fn test_parse_simple_alias() { - let alias = super::parse_alias("foo=bar").expect("failed to parse alias"); - assert_eq!(alias.name, "foo"); - assert_eq!(alias.value, "bar"); - } - - #[test] - fn test_parse_quoted_alias() { - let alias = super::parse_alias("emacs='TERM=xterm-24bits emacs -nw'") - .expect("failed to parse alias"); - - assert_eq!(alias.name, "emacs"); - assert_eq!(alias.value, "'TERM=xterm-24bits emacs -nw'"); - - let git_alias = super::parse_alias("gwip='git add -A; git rm $(git ls-files --deleted) 2> /dev/null; git commit --no-verify --no-gpg-sign --message \"--wip-- [skip ci]\"'").expect("failed to parse alias"); - assert_eq!(git_alias.name, "gwip"); - assert_eq!( - git_alias.value, - "'git add -A; git rm $(git ls-files --deleted) 2> /dev/null; git commit --no-verify --no-gpg-sign --message \"--wip-- [skip ci]\"'" - ); - } - - #[test] - fn test_parse_quoted_alias_equals() { - let alias = super::parse_alias("emacs='TERM=xterm-24bits emacs -nw --foo=bar'") - .expect("failed to parse alias"); - assert_eq!(alias.name, "emacs"); - assert_eq!(alias.value, "'TERM=xterm-24bits emacs -nw --foo=bar'"); - } - - #[test] - fn test_parse_fish() { - let alias = super::parse_alias("alias foo bar").expect("failed to parse alias"); - assert_eq!(alias.name, "foo"); - assert_eq!(alias.value, "bar"); - - let alias = - super::parse_alias("alias x 'exa --icons --git --classify --group-directories-first'") - .expect("failed to parse alias"); - - assert_eq!(alias.name, "x"); - assert_eq!( - alias.value, - "'exa --icons --git --classify --group-directories-first'" - ); - } - - #[test] - fn test_parse_with_fortune() { - // Because we run the alias command in an interactive subshell - // there may be other output. - // Ensure that the parser can handle it - // Annoyingly not all aliases are picked up all the time if we use - // a non-interactive subshell. Boo. - let shell = " -/ In a consumer society there are \\ -| inevitably two kinds of slaves: the | -| prisoners of addiction and the | -\\ prisoners of envy. / - ------------------------------------- - \\ ^__^ - \\ (oo)\\_______ - (__)\\ )\\/\\ - ||----w | - || || -emacs='TERM=xterm-24bits emacs -nw --foo=bar' -k=kubectl -"; - - let aliases: Vec<Alias> = shell.lines().filter_map(parse_alias).collect(); - assert_eq!(aliases[0].name, "emacs"); - assert_eq!(aliases[0].value, "'TERM=xterm-24bits emacs -nw --foo=bar'"); - - assert_eq!(aliases[1].name, "k"); - assert_eq!(aliases[1].value, "kubectl"); - } -} diff --git a/crates/atuin-dotfiles/src/shell/bash.rs b/crates/atuin-dotfiles/src/shell/bash.rs deleted file mode 100644 index 2b9b4c88..00000000 --- a/crates/atuin-dotfiles/src/shell/bash.rs +++ /dev/null @@ -1,68 +0,0 @@ -use std::path::PathBuf; - -use crate::store::{AliasStore, var::VarStore}; - -async fn cached_aliases(path: PathBuf, store: &AliasStore) -> String { - match tokio::fs::read_to_string(path).await { - Ok(aliases) => aliases, - Err(r) => { - // we failed to read the file for some reason, but the file does exist - // fallback to generating new aliases on the fly - - store.posix().await.unwrap_or_else(|e| { - format!("echo 'Atuin: failed to read and generate aliases: \n{r}\n{e}'",) - }) - } - } -} - -async fn cached_vars(path: PathBuf, store: &VarStore) -> String { - match tokio::fs::read_to_string(path).await { - Ok(vars) => vars, - Err(r) => { - // we failed to read the file for some reason, but the file does exist - // fallback to generating new vars on the fly - - store.posix().await.unwrap_or_else(|e| { - format!("echo 'Atuin: failed to read and generate vars: \n{r}\n{e}'",) - }) - } - } -} - -/// Return bash dotfile config -/// -/// Do not return an error. We should not prevent the shell from starting. -/// -/// In the worst case, Atuin should not function but the shell should start correctly. -/// -/// While currently this only returns aliases, it will be extended to also return other synced dotfiles -pub async fn alias_config(store: &AliasStore) -> String { - // First try to read the cached config - let aliases = atuin_common::utils::dotfiles_cache_dir().join("aliases.bash"); - - if aliases.exists() { - return cached_aliases(aliases, store).await; - } - - if let Err(e) = store.build().await { - return format!("echo 'Atuin: failed to generate aliases: {e}'"); - } - - cached_aliases(aliases, store).await -} - -pub async fn var_config(store: &VarStore) -> String { - // First try to read the cached config - let vars = atuin_common::utils::dotfiles_cache_dir().join("vars.bash"); - - if vars.exists() { - return cached_vars(vars, store).await; - } - - if let Err(e) = store.build().await { - return format!("echo 'Atuin: failed to generate vars: {e}'"); - } - - cached_vars(vars, store).await -} diff --git a/crates/atuin-dotfiles/src/shell/fish.rs b/crates/atuin-dotfiles/src/shell/fish.rs deleted file mode 100644 index 6d472f67..00000000 --- a/crates/atuin-dotfiles/src/shell/fish.rs +++ /dev/null @@ -1,69 +0,0 @@ -// Configuration for fish -use std::path::PathBuf; - -use crate::store::{AliasStore, var::VarStore}; - -async fn cached_aliases(path: PathBuf, store: &AliasStore) -> String { - match tokio::fs::read_to_string(path).await { - Ok(aliases) => aliases, - Err(r) => { - // we failed to read the file for some reason, but the file does exist - // fallback to generating new aliases on the fly - - store.posix().await.unwrap_or_else(|e| { - format!("echo 'Atuin: failed to read and generate aliases: \n{r}\n{e}'",) - }) - } - } -} - -async fn cached_vars(path: PathBuf, store: &VarStore) -> String { - match tokio::fs::read_to_string(path).await { - Ok(vars) => vars, - Err(r) => { - // we failed to read the file for some reason, but the file does exist - // fallback to generating new vars on the fly - - store.posix().await.unwrap_or_else(|e| { - format!("echo 'Atuin: failed to read and generate vars: \n{r}\n{e}'",) - }) - } - } -} - -/// Return fish dotfile config -/// -/// Do not return an error. We should not prevent the shell from starting. -/// -/// In the worst case, Atuin should not function but the shell should start correctly. -/// -/// While currently this only returns aliases, it will be extended to also return other synced dotfiles -pub async fn alias_config(store: &AliasStore) -> String { - // First try to read the cached config - let aliases = atuin_common::utils::dotfiles_cache_dir().join("aliases.fish"); - - if aliases.exists() { - return cached_aliases(aliases, store).await; - } - - if let Err(e) = store.build().await { - return format!("echo 'Atuin: failed to generate aliases: {e}'"); - } - - cached_aliases(aliases, store).await -} - -pub async fn var_config(store: &VarStore) -> String { - // First try to read the cached config - let vars = atuin_common::utils::dotfiles_cache_dir().join("vars.fish"); - - if vars.exists() { - return cached_vars(vars, store).await; - } - - if let Err(e) = store.build().await { - return format!("echo 'Atuin: failed to generate vars: {e}'"); - } - - cached_vars(vars, store).await -} diff --git a/crates/atuin-dotfiles/src/shell/powershell.rs b/crates/atuin-dotfiles/src/shell/powershell.rs deleted file mode 100644 index 1daee28b..00000000 --- a/crates/atuin-dotfiles/src/shell/powershell.rs +++ /dev/null @@ -1,169 +0,0 @@ -use crate::shell::{Alias, Var}; -use crate::store::{AliasStore, var::VarStore}; -use std::path::PathBuf; - -async fn cached_aliases(path: PathBuf, store: &AliasStore) -> String { - match tokio::fs::read_to_string(path).await { - Ok(aliases) => aliases, - Err(r) => { - // we failed to read the file for some reason, but the file does exist - // fallback to generating new aliases on the fly - - store.powershell().await.unwrap_or_else(|e| { - format!("echo 'Atuin: failed to read and generate aliases: \n{r}\n{e}'",) - }) - } - } -} - -async fn cached_vars(path: PathBuf, store: &VarStore) -> String { - match tokio::fs::read_to_string(path).await { - Ok(vars) => vars, - Err(r) => { - // we failed to read the file for some reason, but the file does exist - // fallback to generating new vars on the fly - - store.powershell().await.unwrap_or_else(|e| { - format!("echo 'Atuin: failed to read and generate vars: \n{r}\n{e}'",) - }) - } - } -} - -/// Return powershell dotfile config -/// -/// Do not return an error. We should not prevent the shell from starting. -/// -/// In the worst case, Atuin should not function but the shell should start correctly. -/// -/// While currently this only returns aliases, it will be extended to also return other synced dotfiles -pub async fn alias_config(store: &AliasStore) -> String { - // First try to read the cached config - let aliases = atuin_common::utils::dotfiles_cache_dir().join("aliases.ps1"); - - if aliases.exists() { - return cached_aliases(aliases, store).await; - } - - if let Err(e) = store.build().await { - return format!("echo 'Atuin: failed to generate aliases: {e}'"); - } - - cached_aliases(aliases, store).await -} - -pub async fn var_config(store: &VarStore) -> String { - // First try to read the cached config - let vars = atuin_common::utils::dotfiles_cache_dir().join("vars.ps1"); - - if vars.exists() { - return cached_vars(vars, store).await; - } - - if let Err(e) = store.build().await { - return format!("echo 'Atuin: failed to generate vars: {e}'"); - } - - cached_vars(vars, store).await -} - -pub fn format_alias(alias: &Alias) -> String { - // Set-Alias doesn't support adding implicit arguments, so use a function. - // See https://github.com/PowerShell/PowerShell/issues/12962 - - let mut result = secure_command(&format!( - "function {} {{\n {}{} @args\n}}", - alias.name, - if alias.value.starts_with(['"', '\'']) { - "& " - } else { - "" - }, - alias.value - )); - - // This makes the file layout prettier - result.insert(0, '\n'); - result -} - -pub fn format_var(var: &Var) -> String { - secure_command(&format!( - "${}{} = '{}'", - if var.export { "env:" } else { "" }, - var.name, - var.value.replace("'", "''") - )) -} - -/// Wraps the given command in an Invoke-Expression to ensure the outer script is not halted -/// if the inner command contains a syntax error. -fn secure_command(command: &str) -> String { - format!( - "Invoke-Expression -ErrorAction Continue -Command '{}'\n", - command.replace("'", "''") - ) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn aliases() { - assert_eq!( - format_alias(&Alias { - name: "gp".to_string(), - value: "git push".to_string(), - }), - "\n".to_string() - + &secure_command( - "function gp { - git push @args -}" - ) - ); - - assert_eq!( - format_alias(&Alias { - name: "spc".to_string(), - value: "\"path with spaces\" arg".to_string(), - }), - "\n".to_string() - + &secure_command( - "function spc { - & \"path with spaces\" arg @args -}" - ) - ); - } - - #[test] - fn vars() { - assert_eq!( - format_var(&Var { - name: "FOO".to_owned(), - value: "bar 'baz'".to_owned(), - export: true, - }), - secure_command("$env:FOO = 'bar ''baz'''") - ); - - assert_eq!( - format_var(&Var { - name: "TEST".to_owned(), - value: "1".to_owned(), - export: false, - }), - secure_command("$TEST = '1'") - ); - } - - #[test] - fn invoke_expression() { - assert_eq!( - secure_command("echo 'foo'"), - "Invoke-Expression -ErrorAction Continue -Command 'echo ''foo'''\n" - ) - } -} diff --git a/crates/atuin-dotfiles/src/shell/xonsh.rs b/crates/atuin-dotfiles/src/shell/xonsh.rs deleted file mode 100644 index 1e56fc1d..00000000 --- a/crates/atuin-dotfiles/src/shell/xonsh.rs +++ /dev/null @@ -1,68 +0,0 @@ -use std::path::PathBuf; - -use crate::store::{AliasStore, var::VarStore}; - -async fn cached_aliases(path: PathBuf, store: &AliasStore) -> String { - match tokio::fs::read_to_string(path).await { - Ok(aliases) => aliases, - Err(r) => { - // we failed to read the file for some reason, but the file does exist - // fallback to generating new aliases on the fly - - store.xonsh().await.unwrap_or_else(|e| { - format!("echo 'Atuin: failed to read and generate aliases: \n{r}\n{e}'",) - }) - } - } -} - -async fn cached_vars(path: PathBuf, store: &VarStore) -> String { - match tokio::fs::read_to_string(path).await { - Ok(vars) => vars, - Err(r) => { - // we failed to read the file for some reason, but the file does exist - // fallback to generating new vars on the fly - - store.xonsh().await.unwrap_or_else(|e| { - format!("echo 'Atuin: failed to read and generate vars: \n{r}\n{e}'",) - }) - } - } -} - -/// Return xonsh dotfile config -/// -/// Do not return an error. We should not prevent the shell from starting. -/// -/// In the worst case, Atuin should not function but the shell should start correctly. -/// -/// While currently this only returns aliases, it will be extended to also return other synced dotfiles -pub async fn alias_config(store: &AliasStore) -> String { - // First try to read the cached config - let aliases = atuin_common::utils::dotfiles_cache_dir().join("aliases.xsh"); - - if aliases.exists() { - return cached_aliases(aliases, store).await; - } - - if let Err(e) = store.build().await { - return format!("echo 'Atuin: failed to generate aliases: {e}'"); - } - - cached_aliases(aliases, store).await -} - -pub async fn var_config(store: &VarStore) -> String { - // First try to read the cached config - let vars = atuin_common::utils::dotfiles_cache_dir().join("vars.xsh"); - - if vars.exists() { - return cached_vars(vars, store).await; - } - - if let Err(e) = store.build().await { - return format!("echo 'Atuin: failed to generate vars: {e}'"); - } - - cached_vars(vars, store).await -} diff --git a/crates/atuin-dotfiles/src/shell/zsh.rs b/crates/atuin-dotfiles/src/shell/zsh.rs deleted file mode 100644 index 117e9403..00000000 --- a/crates/atuin-dotfiles/src/shell/zsh.rs +++ /dev/null @@ -1,68 +0,0 @@ -use std::path::PathBuf; - -use crate::store::{AliasStore, var::VarStore}; - -async fn cached_aliases(path: PathBuf, store: &AliasStore) -> String { - match tokio::fs::read_to_string(path).await { - Ok(aliases) => aliases, - Err(r) => { - // we failed to read the file for some reason, but the file does exist - // fallback to generating new aliases on the fly - - store.posix().await.unwrap_or_else(|e| { - format!("echo 'Atuin: failed to read and generate aliases: \n{r}\n{e}'",) - }) - } - } -} - -async fn cached_vars(path: PathBuf, store: &VarStore) -> String { - match tokio::fs::read_to_string(path).await { - Ok(aliases) => aliases, - Err(r) => { - // we failed to read the file for some reason, but the file does exist - // fallback to generating new vars on the fly - - store.posix().await.unwrap_or_else(|e| { - format!("echo 'Atuin: failed to read and generate aliases: \n{r}\n{e}'",) - }) - } - } -} - -/// Return zsh dotfile config -/// -/// Do not return an error. We should not prevent the shell from starting. -/// -/// In the worst case, Atuin should not function but the shell should start correctly. -/// -/// While currently this only returns aliases, it will be extended to also return other synced dotfiles -pub async fn alias_config(store: &AliasStore) -> String { - // First try to read the cached config - let aliases = atuin_common::utils::dotfiles_cache_dir().join("aliases.zsh"); - - if aliases.exists() { - return cached_aliases(aliases, store).await; - } - - if let Err(e) = store.build().await { - return format!("echo 'Atuin: failed to generate aliases: {e}'"); - } - - cached_aliases(aliases, store).await -} - -pub async fn var_config(store: &VarStore) -> String { - // First try to read the cached config - let vars = atuin_common::utils::dotfiles_cache_dir().join("vars.zsh"); - - if vars.exists() { - return cached_vars(vars, store).await; - } - - if let Err(e) = store.build().await { - return format!("echo 'Atuin: failed to generate aliases: {e}'"); - } - - cached_vars(vars, store).await -} diff --git a/crates/atuin-dotfiles/src/store.rs b/crates/atuin-dotfiles/src/store.rs deleted file mode 100644 index 17597065..00000000 --- a/crates/atuin-dotfiles/src/store.rs +++ /dev/null @@ -1,421 +0,0 @@ -use std::collections::BTreeMap; - -use atuin_client::record::sqlite_store::SqliteStore; -// Sync aliases -// This will be noticeable similar to the kv store, though I expect the two shall diverge -// While we will support a range of shell config, I'd rather have a larger number of small records -// + stores, rather than one mega config store. -use atuin_common::record::{DecryptedData, Host, HostId}; -use atuin_common::utils::unquote; -use eyre::{Result, bail, ensure, eyre}; - -use atuin_client::record::encryption::PASETO_V4; -use atuin_client::record::store::Store; - -use crate::shell::Alias; - -const CONFIG_SHELL_ALIAS_VERSION: &str = "v0"; -const CONFIG_SHELL_ALIAS_TAG: &str = "config-shell-alias"; -const CONFIG_SHELL_ALIAS_FIELD_MAX_LEN: usize = 20000; // 20kb max total len, way more than should be needed. - -mod alias; -pub mod var; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum AliasRecord { - Create(Alias), // create a full record - Delete(String), // delete by name -} - -impl AliasRecord { - pub fn serialize(&self) -> Result<DecryptedData> { - use rmp::encode; - - let mut output = vec![]; - - match self { - AliasRecord::Create(alias) => { - encode::write_u8(&mut output, 0)?; // create - encode::write_array_len(&mut output, 2)?; // 2 fields - - encode::write_str(&mut output, alias.name.as_str())?; - encode::write_str(&mut output, alias.value.as_str())?; - } - AliasRecord::Delete(name) => { - encode::write_u8(&mut output, 1)?; // delete - encode::write_array_len(&mut output, 1)?; // 1 field - - encode::write_str(&mut output, name.as_str())?; - } - } - - Ok(DecryptedData(output)) - } - - pub fn deserialize(data: &DecryptedData, version: &str) -> Result<Self> { - use rmp::decode; - - fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { - eyre!("{err:?}") - } - - match version { - CONFIG_SHELL_ALIAS_VERSION => { - let mut bytes = decode::Bytes::new(&data.0); - - let record_type = decode::read_u8(&mut bytes).map_err(error_report)?; - - match record_type { - // create - 0 => { - let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; - ensure!( - nfields == 2, - "too many entries in v0 shell alias create record" - ); - - let bytes = bytes.remaining_slice(); - - let (key, bytes) = - decode::read_str_from_slice(bytes).map_err(error_report)?; - let (value, bytes) = - decode::read_str_from_slice(bytes).map_err(error_report)?; - - if !bytes.is_empty() { - bail!("trailing bytes in encoded shell alias record. malformed") - } - - Ok(AliasRecord::Create(Alias { - name: key.to_owned(), - value: value.to_owned(), - })) - } - - // delete - 1 => { - let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; - ensure!( - nfields == 1, - "too many entries in v0 shell alias delete record" - ); - - let bytes = bytes.remaining_slice(); - - let (key, bytes) = - decode::read_str_from_slice(bytes).map_err(error_report)?; - - if !bytes.is_empty() { - bail!("trailing bytes in encoded shell alias record. malformed") - } - - Ok(AliasRecord::Delete(key.to_owned())) - } - - n => { - bail!("unknown AliasRecord type {n}") - } - } - } - _ => { - bail!("unknown version {version:?}") - } - } - } -} - -#[derive(Debug, Clone)] -pub struct AliasStore { - pub store: SqliteStore, - pub host_id: HostId, - pub encryption_key: [u8; 32], -} - -impl AliasStore { - // will want to init the actual kv store when that is done - pub fn new(store: SqliteStore, host_id: HostId, encryption_key: [u8; 32]) -> AliasStore { - AliasStore { - store, - host_id, - encryption_key, - } - } - - pub async fn posix(&self) -> Result<String> { - let aliases = self.aliases().await?; - Ok(Self::format_posix(&aliases)) - } - - pub async fn xonsh(&self) -> Result<String> { - let aliases = self.aliases().await?; - Ok(Self::format_xonsh(&aliases)) - } - - pub async fn powershell(&self) -> Result<String> { - let aliases = self.aliases().await?; - Ok(Self::format_powershell(&aliases)) - } - - fn format_posix(aliases: &[Alias]) -> String { - let mut config = String::new(); - - for alias in aliases { - // If it's quoted, remove the quotes. If it's not quoted, do nothing. - let value = unquote(alias.value.as_str()).unwrap_or(alias.value.clone()); - - // we're about to quote it ourselves anyway! - config.push_str(&format!("alias {}='{}'\n", alias.name, value)); - } - - config - } - - fn format_xonsh(aliases: &[Alias]) -> String { - let mut config = String::new(); - - for alias in aliases { - config.push_str(&format!("aliases['{}'] ='{}'\n", alias.name, alias.value)); - } - - config - } - - fn format_powershell(aliases: &[Alias]) -> String { - let mut config = String::new(); - - for alias in aliases { - config.push_str(&crate::shell::powershell::format_alias(alias)); - } - - config - } - - pub async fn build(&self) -> Result<()> { - let dir = atuin_common::utils::dotfiles_cache_dir(); - tokio::fs::create_dir_all(dir.clone()).await?; - - let aliases = self.aliases().await?; - - // Build for all supported shells - let posix = Self::format_posix(&aliases); - let xonsh = Self::format_xonsh(&aliases); - let powershell = Self::format_powershell(&aliases); - - // All the same contents, maybe optimize in the future or perhaps there will be quirks - // per-shell - // I'd prefer separation atm - let zsh = dir.join("aliases.zsh"); - let bash = dir.join("aliases.bash"); - let fish = dir.join("aliases.fish"); - let xsh = dir.join("aliases.xsh"); - let ps1 = dir.join("aliases.ps1"); - - tokio::fs::write(zsh, &posix).await?; - tokio::fs::write(bash, &posix).await?; - tokio::fs::write(fish, &posix).await?; - tokio::fs::write(xsh, &xonsh).await?; - tokio::fs::write(ps1, &powershell).await?; - - Ok(()) - } - - pub async fn set(&self, name: &str, value: &str) -> Result<()> { - if name.len() + value.len() > CONFIG_SHELL_ALIAS_FIELD_MAX_LEN { - return Err(eyre!( - "alias record too large: max len {} bytes", - CONFIG_SHELL_ALIAS_FIELD_MAX_LEN - )); - } - - let record = AliasRecord::Create(Alias { - name: name.to_string(), - value: value.to_string(), - }); - - let bytes = record.serialize()?; - - let idx = self - .store - .last(self.host_id, CONFIG_SHELL_ALIAS_TAG) - .await? - .map_or(0, |entry| entry.idx + 1); - - let record = atuin_common::record::Record::builder() - .host(Host::new(self.host_id)) - .version(CONFIG_SHELL_ALIAS_VERSION.to_string()) - .tag(CONFIG_SHELL_ALIAS_TAG.to_string()) - .idx(idx) - .data(bytes) - .build(); - - self.store - .push(&record.encrypt::<PASETO_V4>(&self.encryption_key)) - .await?; - - // set mutates shell config, so build again - self.build().await?; - - Ok(()) - } - - pub async fn delete(&self, name: &str) -> Result<()> { - if name.len() > CONFIG_SHELL_ALIAS_FIELD_MAX_LEN { - return Err(eyre!( - "alias record too large: max len {} bytes", - CONFIG_SHELL_ALIAS_FIELD_MAX_LEN - )); - } - - let record = AliasRecord::Delete(name.to_string()); - - let bytes = record.serialize()?; - - let idx = self - .store - .last(self.host_id, CONFIG_SHELL_ALIAS_TAG) - .await? - .map_or(0, |entry| entry.idx + 1); - - let record = atuin_common::record::Record::builder() - .host(Host::new(self.host_id)) - .version(CONFIG_SHELL_ALIAS_VERSION.to_string()) - .tag(CONFIG_SHELL_ALIAS_TAG.to_string()) - .idx(idx) - .data(bytes) - .build(); - - self.store - .push(&record.encrypt::<PASETO_V4>(&self.encryption_key)) - .await?; - - // delete mutates shell config, so build again - self.build().await?; - - Ok(()) - } - - pub async fn aliases(&self) -> Result<Vec<Alias>> { - let mut build = BTreeMap::new(); - - // this is sorted, oldest to newest - let tagged = self.store.all_tagged(CONFIG_SHELL_ALIAS_TAG).await?; - - for record in tagged { - let version = record.version.clone(); - - let decrypted = match version.as_str() { - CONFIG_SHELL_ALIAS_VERSION => record.decrypt::<PASETO_V4>(&self.encryption_key)?, - version => bail!("unknown version {version:?}"), - }; - - let ar = AliasRecord::deserialize(&decrypted.data, version.as_str())?; - - match ar { - AliasRecord::Create(a) => { - build.insert(a.name.clone(), a); - } - AliasRecord::Delete(d) => { - build.remove(&d); - } - } - } - - Ok(build.into_values().collect()) - } -} - -#[cfg(test)] -pub(crate) fn test_local_timeout() -> f64 { - std::env::var("ATUIN_TEST_LOCAL_TIMEOUT") - .ok() - .and_then(|x| x.parse().ok()) - // this hardcoded value should be replaced by a simple way to get the - // default local_timeout of Settings if possible - .unwrap_or(2.0) -} - -#[cfg(test)] -mod tests { - use rand::rngs::OsRng; - - use atuin_client::record::sqlite_store::SqliteStore; - - use crate::shell::Alias; - - use super::{AliasRecord, AliasStore, CONFIG_SHELL_ALIAS_VERSION, test_local_timeout}; - use crypto_secretbox::{KeyInit, XSalsa20Poly1305}; - - #[test] - fn encode_decode() { - let record = Alias { - name: "k".to_owned(), - value: "kubectl".to_owned(), - }; - let record = AliasRecord::Create(record); - - let snapshot = [204, 0, 146, 161, 107, 167, 107, 117, 98, 101, 99, 116, 108]; - - let encoded = record.serialize().unwrap(); - let decoded = AliasRecord::deserialize(&encoded, CONFIG_SHELL_ALIAS_VERSION).unwrap(); - - assert_eq!(encoded.0, &snapshot); - assert_eq!(decoded, record); - } - - #[tokio::test] - async fn build_aliases() { - let store = SqliteStore::new(":memory:", test_local_timeout()) - .await - .unwrap(); - let key: [u8; 32] = XSalsa20Poly1305::generate_key(&mut OsRng).into(); - let host_id = atuin_common::record::HostId(atuin_common::utils::uuid_v7()); - - let alias = AliasStore::new(store, host_id, key); - - alias.set("k", "kubectl").await.unwrap(); - alias.set("gp", "git push").await.unwrap(); - alias - .set("kgap", "'kubectl get pods --all-namespaces'") - .await - .unwrap(); - - let mut aliases = alias.aliases().await.unwrap(); - - aliases.sort_by_key(|a| a.name.clone()); - - assert_eq!(aliases.len(), 3); - - assert_eq!( - aliases[0], - Alias { - name: String::from("gp"), - value: String::from("git push") - } - ); - - assert_eq!( - aliases[1], - Alias { - name: String::from("k"), - value: String::from("kubectl") - } - ); - - assert_eq!( - aliases[2], - Alias { - name: String::from("kgap"), - value: String::from("'kubectl get pods --all-namespaces'") - } - ); - - let build = alias.posix().await.expect("failed to build aliases"); - - assert_eq!( - build, - "alias gp='git push' -alias k='kubectl' -alias kgap='kubectl get pods --all-namespaces' -" - ) - } -} diff --git a/crates/atuin-dotfiles/src/store/alias.rs b/crates/atuin-dotfiles/src/store/alias.rs deleted file mode 100644 index 8b137891..00000000 --- a/crates/atuin-dotfiles/src/store/alias.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/crates/atuin-dotfiles/src/store/var.rs b/crates/atuin-dotfiles/src/store/var.rs deleted file mode 100644 index 9d25b85d..00000000 --- a/crates/atuin-dotfiles/src/store/var.rs +++ /dev/null @@ -1,542 +0,0 @@ -/// Store for shell vars -/// I should abstract this and reuse code between the alias/env stores -/// This is easier for now -/// Once I have two implementations, building a common base is much easier. -use std::collections::BTreeMap; - -use atuin_client::record::sqlite_store::SqliteStore; -use atuin_common::record::{DecryptedData, Host, HostId}; -use eyre::{Result, bail, ensure, eyre}; - -use atuin_client::record::encryption::PASETO_V4; -use atuin_client::record::store::Store; - -use crate::shell::Var; - -const DOTFILES_VAR_VERSION: &str = "v0"; -const DOTFILES_VAR_TAG: &str = "dotfiles-var"; -const DOTFILES_VAR_LEN: usize = 20000; // 20kb max total len, way more than should be needed. - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum VarRecord { - Create(Var), // create a full record - Delete(String), // delete by name -} - -impl VarRecord { - pub fn serialize(&self) -> Result<DecryptedData> { - use rmp::encode; - - let mut output = vec![]; - - match self { - VarRecord::Create(env) => { - encode::write_u8(&mut output, 0)?; // create - - env.serialize(&mut output)?; - } - VarRecord::Delete(env) => { - encode::write_u8(&mut output, 1)?; // delete - encode::write_array_len(&mut output, 1)?; // 1 field - - encode::write_str(&mut output, env.as_str())?; - } - } - - Ok(DecryptedData(output)) - } - - pub fn deserialize(data: &DecryptedData, version: &str) -> Result<Self> { - use rmp::decode; - - fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { - eyre!("{err:?}") - } - - match version { - DOTFILES_VAR_VERSION => { - let mut bytes = decode::Bytes::new(&data.0); - - let record_type = decode::read_u8(&mut bytes).map_err(error_report)?; - - match record_type { - // create - 0 => { - let env = Var::deserialize(&mut bytes)?; - Ok(VarRecord::Create(env)) - } - - // delete - 1 => { - let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; - ensure!( - nfields == 1, - "too many entries in v0 dotfiles var delete record" - ); - - let bytes = bytes.remaining_slice(); - - let (key, bytes) = - decode::read_str_from_slice(bytes).map_err(error_report)?; - - if !bytes.is_empty() { - bail!("trailing bytes in encoded dotfiles var record. malformed") - } - - Ok(VarRecord::Delete(key.to_owned())) - } - - n => { - bail!("unknown Dotfiles var record type {n}") - } - } - } - _ => { - bail!("unknown version {version:?}") - } - } - } -} - -#[derive(Debug, Clone)] -pub struct VarStore { - pub store: SqliteStore, - pub host_id: HostId, - pub encryption_key: [u8; 32], -} - -impl VarStore { - // will want to init the actual kv store when that is done - pub fn new(store: SqliteStore, host_id: HostId, encryption_key: [u8; 32]) -> VarStore { - VarStore { - store, - host_id, - encryption_key, - } - } - - /// Escape a value for use in POSIX shells (bash, zsh) - /// This adds double quotes around the value and escapes any embedded double quotes - fn escape_posix_value(value: &str) -> String { - // If the value contains no special characters, we can use it unquoted - if value - .chars() - .all(|c| c.is_alphanumeric() || c == '_' || c == '-' || c == '/' || c == '.') - { - value.to_string() - } else { - // Otherwise, wrap in double quotes and escape any special characters - format!( - "\"{}\"", - value - .replace('\\', "\\\\") - .replace('"', "\\\"") - .replace('$', "\\$") - .replace('`', "\\`") - ) - } - } - - /// Escape a value for use in fish shell - /// Fish uses single quotes for literal strings, but we need to handle embedded single quotes - fn escape_fish_value(value: &str) -> String { - // If the value contains no special characters, we can use it unquoted - if value - .chars() - .all(|c| c.is_alphanumeric() || c == '_' || c == '-' || c == '/' || c == '.') - { - value.to_string() - } else { - // Use single quotes and escape any embedded single quotes - format!("'{}'", value.replace('\'', "\\'")) - } - } - - /// Escape a value for use in xonsh - /// Xonsh uses Python-style string literals - fn escape_xonsh_value(value: &str) -> String { - // If the value contains no special characters, we can use it unquoted - if value - .chars() - .all(|c| c.is_alphanumeric() || c == '_' || c == '-' || c == '/' || c == '.') - { - value.to_string() - } else { - // Use double quotes and escape appropriately for Python strings - format!("\"{}\"", value.replace('\\', "\\\\").replace('"', "\\\"")) - } - } - - pub async fn xonsh(&self) -> Result<String> { - let env = self.vars().await?; - Ok(Self::format_xonsh(&env)) - } - - pub async fn fish(&self) -> Result<String> { - let env = self.vars().await?; - Ok(Self::format_fish(&env)) - } - - pub async fn posix(&self) -> Result<String> { - let env = self.vars().await?; - Ok(Self::format_posix(&env)) - } - - pub async fn powershell(&self) -> Result<String> { - let env = self.vars().await?; - Ok(Self::format_powershell(&env)) - } - - fn format_xonsh(env: &[Var]) -> String { - let mut config = String::new(); - - for env in env { - let escaped_value = Self::escape_xonsh_value(&env.value); - config.push_str(&format!("${}={}\n", env.name, escaped_value)); - } - - config - } - - fn format_fish(env: &[Var]) -> String { - let mut config = String::new(); - - for env in env { - let escaped_value = Self::escape_fish_value(&env.value); - config.push_str(&format!("set -gx {} {}\n", env.name, escaped_value)); - } - - config - } - - fn format_posix(env: &[Var]) -> String { - let mut config = String::new(); - - for env in env { - let escaped_value = Self::escape_posix_value(&env.value); - if env.export { - config.push_str(&format!("export {}={}\n", env.name, escaped_value)); - } else { - config.push_str(&format!("{}={}\n", env.name, escaped_value)); - } - } - - config - } - - fn format_powershell(env: &[Var]) -> String { - let mut config = String::new(); - - for var in env { - config.push_str(&crate::shell::powershell::format_var(var)); - } - - config - } - - pub async fn build(&self) -> Result<()> { - let dir = atuin_common::utils::dotfiles_cache_dir(); - tokio::fs::create_dir_all(dir.clone()).await?; - - let env = self.vars().await?; - - // Build for all supported shells - let posix = Self::format_posix(&env); - let xonsh = Self::format_xonsh(&env); - let fsh = Self::format_fish(&env); - let powershell = Self::format_powershell(&env); - - // All the same contents, maybe optimize in the future or perhaps there will be quirks - // per-shell - // I'd prefer separation atm - let zsh = dir.join("vars.zsh"); - let bash = dir.join("vars.bash"); - let fish = dir.join("vars.fish"); - let xsh = dir.join("vars.xsh"); - let ps1 = dir.join("vars.ps1"); - - tokio::fs::write(zsh, &posix).await?; - tokio::fs::write(bash, &posix).await?; - tokio::fs::write(fish, &fsh).await?; - tokio::fs::write(xsh, &xonsh).await?; - tokio::fs::write(ps1, &powershell).await?; - - Ok(()) - } - - pub async fn set(&self, name: &str, value: &str, export: bool) -> Result<()> { - if name.len() + value.len() > DOTFILES_VAR_LEN { - return Err(eyre!( - "var record too large: max len {} bytes", - DOTFILES_VAR_LEN - )); - } - - let record = VarRecord::Create(Var { - name: name.to_string(), - value: value.to_string(), - export, - }); - - let bytes = record.serialize()?; - - let idx = self - .store - .last(self.host_id, DOTFILES_VAR_TAG) - .await? - .map_or(0, |entry| entry.idx + 1); - - let record = atuin_common::record::Record::builder() - .host(Host::new(self.host_id)) - .version(DOTFILES_VAR_VERSION.to_string()) - .tag(DOTFILES_VAR_TAG.to_string()) - .idx(idx) - .data(bytes) - .build(); - - self.store - .push(&record.encrypt::<PASETO_V4>(&self.encryption_key)) - .await?; - - // set mutates shell config, so build again - self.build().await?; - - Ok(()) - } - - pub async fn delete(&self, name: &str) -> Result<()> { - if name.len() > DOTFILES_VAR_LEN { - return Err(eyre!( - "var record too large: max len {} bytes", - DOTFILES_VAR_LEN, - )); - } - - let record = VarRecord::Delete(name.to_string()); - - let bytes = record.serialize()?; - - let idx = self - .store - .last(self.host_id, DOTFILES_VAR_TAG) - .await? - .map_or(0, |entry| entry.idx + 1); - - let record = atuin_common::record::Record::builder() - .host(Host::new(self.host_id)) - .version(DOTFILES_VAR_VERSION.to_string()) - .tag(DOTFILES_VAR_TAG.to_string()) - .idx(idx) - .data(bytes) - .build(); - - self.store - .push(&record.encrypt::<PASETO_V4>(&self.encryption_key)) - .await?; - - // delete mutates shell config, so build again - self.build().await?; - - Ok(()) - } - - pub async fn vars(&self) -> Result<Vec<Var>> { - let mut build = BTreeMap::new(); - - // this is sorted, oldest to newest - let tagged = self.store.all_tagged(DOTFILES_VAR_TAG).await?; - - for record in tagged { - let version = record.version.clone(); - - let decrypted = match version.as_str() { - DOTFILES_VAR_VERSION => record.decrypt::<PASETO_V4>(&self.encryption_key)?, - version => bail!("unknown version {version:?}"), - }; - - let ar = VarRecord::deserialize(&decrypted.data, version.as_str())?; - - match ar { - VarRecord::Create(a) => { - build.insert(a.name.clone(), a); - } - VarRecord::Delete(d) => { - build.remove(&d); - } - } - } - - Ok(build.into_values().collect()) - } -} - -#[cfg(test)] -mod tests { - use rand::rngs::OsRng; - - use atuin_client::record::sqlite_store::SqliteStore; - - use crate::{shell::Var, store::test_local_timeout}; - - use super::{DOTFILES_VAR_VERSION, VarRecord, VarStore}; - use crypto_secretbox::{KeyInit, XSalsa20Poly1305}; - - #[test] - fn encode_decode() { - let record = Var { - name: "BEEP".to_owned(), - value: "boop".to_owned(), - export: false, - }; - let record = VarRecord::Create(record); - - let snapshot = [ - 204, 0, 147, 164, 66, 69, 69, 80, 164, 98, 111, 111, 112, 194, - ]; - - let encoded = record.serialize().unwrap(); - let decoded = VarRecord::deserialize(&encoded, DOTFILES_VAR_VERSION).unwrap(); - - assert_eq!(encoded.0, &snapshot); - assert_eq!(decoded, record); - } - - #[test] - fn test_escape_posix_value() { - // Simple values should not be quoted - assert_eq!(VarStore::escape_posix_value("simple"), "simple"); - assert_eq!(VarStore::escape_posix_value("path/to/file"), "path/to/file"); - assert_eq!( - VarStore::escape_posix_value("value_with_underscores"), - "value_with_underscores" - ); - - // Values with spaces should be quoted - assert_eq!( - VarStore::escape_posix_value("hello world"), - "\"hello world\"" - ); - assert_eq!(VarStore::escape_posix_value("bar baz"), "\"bar baz\""); - - // Values with special characters should be quoted and escaped - assert_eq!( - VarStore::escape_posix_value("say \"hello\""), - "\"say \\\"hello\\\"\"" - ); - assert_eq!( - VarStore::escape_posix_value("path\\with\\backslashes"), - "\"path\\\\with\\\\backslashes\"" - ); - assert_eq!( - VarStore::escape_posix_value("say $hello"), - "\"say \\$hello\"" - ); - assert_eq!( - VarStore::escape_posix_value("see `example.md`"), - "\"see \\`example.md\\`\"" - ); - } - - #[test] - fn test_escape_fish_value() { - // Simple values should not be quoted - assert_eq!(VarStore::escape_fish_value("simple"), "simple"); - assert_eq!(VarStore::escape_fish_value("path/to/file"), "path/to/file"); - - // Values with spaces should be single-quoted - assert_eq!(VarStore::escape_fish_value("hello world"), "'hello world'"); - assert_eq!(VarStore::escape_fish_value("bar baz"), "'bar baz'"); - - // Values with single quotes should be escaped - assert_eq!(VarStore::escape_fish_value("don't"), "'don\\'t'"); - } - - #[test] - fn test_escape_xonsh_value() { - // Simple values should not be quoted - assert_eq!(VarStore::escape_xonsh_value("simple"), "simple"); - assert_eq!(VarStore::escape_xonsh_value("path/to/file"), "path/to/file"); - - // Values with spaces should be quoted - assert_eq!( - VarStore::escape_xonsh_value("hello world"), - "\"hello world\"" - ); - assert_eq!(VarStore::escape_xonsh_value("bar baz"), "\"bar baz\""); - - // Values with special characters should be quoted and escaped - assert_eq!( - VarStore::escape_xonsh_value("say \"hello\""), - "\"say \\\"hello\\\"\"" - ); - assert_eq!( - VarStore::escape_xonsh_value("path\\with\\backslashes"), - "\"path\\\\with\\\\backslashes\"" - ); - } - - #[tokio::test] - async fn build_vars() { - let store = SqliteStore::new(":memory:", test_local_timeout()) - .await - .unwrap(); - let key: [u8; 32] = XSalsa20Poly1305::generate_key(&mut OsRng).into(); - let host_id = atuin_common::record::HostId(atuin_common::utils::uuid_v7()); - - let env = VarStore::new(store, host_id, key); - - env.set("BEEP", "boop", false).await.unwrap(); - env.set("HOMEBREW_NO_AUTO_UPDATE", "1", true).await.unwrap(); - - let mut env_vars = env.vars().await.unwrap(); - - env_vars.sort_by_key(|a| a.name.clone()); - - assert_eq!(env_vars.len(), 2); - - assert_eq!( - env_vars[0], - Var { - name: String::from("BEEP"), - value: String::from("boop"), - export: false, - } - ); - - assert_eq!( - env_vars[1], - Var { - name: String::from("HOMEBREW_NO_AUTO_UPDATE"), - value: String::from("1"), - export: true, - } - ); - } - - #[tokio::test] - async fn test_var_generation_with_spaces() { - let store = SqliteStore::new(":memory:", test_local_timeout()) - .await - .unwrap(); - let key: [u8; 32] = XSalsa20Poly1305::generate_key(&mut OsRng).into(); - let host_id = atuin_common::record::HostId(atuin_common::utils::uuid_v7()); - - let env = VarStore::new(store, host_id, key); - - // Test the exact scenario from the bug report - env.set("FOO", "bar baz", true).await.unwrap(); - - let posix_output = env.posix().await.unwrap(); - let fish_output = env.fish().await.unwrap(); - let xonsh_output = env.xonsh().await.unwrap(); - - // POSIX should quote the value - assert_eq!(posix_output, "export FOO=\"bar baz\"\n"); - - // Fish should quote the value - assert_eq!(fish_output, "set -gx FOO 'bar baz'\n"); - - // Xonsh should quote the value - assert_eq!(xonsh_output, "$FOO=\"bar baz\"\n"); - } -} diff --git a/crates/atuin-scripts/Cargo.toml b/crates/atuin-scripts/Cargo.toml deleted file mode 100644 index 6c168ecd..00000000 --- a/crates/atuin-scripts/Cargo.toml +++ /dev/null @@ -1,34 +0,0 @@ -[package] -name = "atuin-scripts" -edition = "2024" -version = { workspace = true } -description = "The scripts crate for Atuin" - -authors.workspace = true -rust-version.workspace = true -license.workspace = true -homepage.workspace = true -repository.workspace = true -readme.workspace = true - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -atuin-client = { path = "../atuin-client", version = "18.16.1" } -atuin-common = { path = "../atuin-common", version = "18.16.1" } - -tracing = { workspace = true } -tracing-subscriber = { workspace = true } -rmp = { version = "0.8.14" } -uuid = { workspace = true } -eyre = { workspace = true } -tokio = { workspace = true } -serde = { workspace = true } -typed-builder = { workspace = true } -pretty_assertions = { workspace = true } -sql-builder = { workspace = true } -sqlx = { workspace = true } -tempfile = { workspace = true } -minijinja = { workspace = true } -serde_json = { workspace = true } - diff --git a/crates/atuin-scripts/migrations/20250326160051_create_scripts.down.sql b/crates/atuin-scripts/migrations/20250326160051_create_scripts.down.sql deleted file mode 100644 index b2c5a363..00000000 --- a/crates/atuin-scripts/migrations/20250326160051_create_scripts.down.sql +++ /dev/null @@ -1,2 +0,0 @@ -DROP TABLE scripts; -DROP TABLE script_tags;
\ No newline at end of file diff --git a/crates/atuin-scripts/migrations/20250326160051_create_scripts.up.sql b/crates/atuin-scripts/migrations/20250326160051_create_scripts.up.sql deleted file mode 100644 index 1b2f3688..00000000 --- a/crates/atuin-scripts/migrations/20250326160051_create_scripts.up.sql +++ /dev/null @@ -1,17 +0,0 @@ --- Add up migration script here -CREATE TABLE scripts ( - id TEXT PRIMARY KEY, - name TEXT NOT NULL, - description TEXT NOT NULL, - shebang TEXT NOT NULL, - script TEXT NOT NULL, - inserted_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')) -); - -CREATE TABLE script_tags ( - id INTEGER PRIMARY KEY, - script_id TEXT NOT NULL, - tag TEXT NOT NULL -); - -CREATE UNIQUE INDEX idx_script_tags ON script_tags (script_id, tag);
\ No newline at end of file diff --git a/crates/atuin-scripts/migrations/20250402170430_unique_names.down.sql b/crates/atuin-scripts/migrations/20250402170430_unique_names.down.sql deleted file mode 100644 index 269b8cd9..00000000 --- a/crates/atuin-scripts/migrations/20250402170430_unique_names.down.sql +++ /dev/null @@ -1,2 +0,0 @@ --- Add down migration script here -alter table scripts drop index name_uniq_idx;
\ No newline at end of file diff --git a/crates/atuin-scripts/migrations/20250402170430_unique_names.up.sql b/crates/atuin-scripts/migrations/20250402170430_unique_names.up.sql deleted file mode 100644 index d2cdd02f..00000000 --- a/crates/atuin-scripts/migrations/20250402170430_unique_names.up.sql +++ /dev/null @@ -1,2 +0,0 @@ --- Add up migration script here -create unique index name_uniq_idx ON scripts(name);
\ No newline at end of file diff --git a/crates/atuin-scripts/src/database.rs b/crates/atuin-scripts/src/database.rs deleted file mode 100644 index be113526..00000000 --- a/crates/atuin-scripts/src/database.rs +++ /dev/null @@ -1,371 +0,0 @@ -use std::{path::Path, str::FromStr, time::Duration}; - -use atuin_common::utils; -use sqlx::{ - Result, Row, - sqlite::{ - SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow, - SqliteSynchronous, - }, -}; -use tokio::fs; -use tracing::debug; -use uuid::Uuid; - -use crate::store::script::Script; - -#[derive(Debug, Clone)] -pub struct Database { - pub pool: SqlitePool, -} - -impl Database { - pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> { - let path = path.as_ref(); - debug!("opening script sqlite database at {:?}", path); - - if utils::broken_symlink(path) { - eprintln!( - "Atuin: Script sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement." - ); - std::process::exit(1); - } - - if !path.exists() - && let Some(dir) = path.parent() - { - fs::create_dir_all(dir).await?; - } - - let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())? - .journal_mode(SqliteJournalMode::Wal) - .optimize_on_close(true, None) - .synchronous(SqliteSynchronous::Normal) - .with_regexp() - .foreign_keys(true) - .create_if_missing(true); - - let pool = SqlitePoolOptions::new() - .acquire_timeout(Duration::from_secs_f64(timeout)) - .connect_with(opts) - .await?; - - Self::setup_db(&pool).await?; - Ok(Self { pool }) - } - - pub async fn sqlite_version(&self) -> Result<String> { - sqlx::query_scalar("SELECT sqlite_version()") - .fetch_one(&self.pool) - .await - } - - async fn setup_db(pool: &SqlitePool) -> Result<()> { - debug!("running sqlite database setup"); - - sqlx::migrate!("./migrations").run(pool).await?; - - Ok(()) - } - - async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, s: &Script) -> Result<()> { - sqlx::query( - "insert or ignore into scripts(id, name, description, shebang, script) - values(?1, ?2, ?3, ?4, ?5)", - ) - .bind(s.id.to_string()) - .bind(s.name.as_str()) - .bind(s.description.as_str()) - .bind(s.shebang.as_str()) - .bind(s.script.as_str()) - .execute(&mut **tx) - .await?; - - for tag in s.tags.iter() { - sqlx::query( - "insert or ignore into script_tags(script_id, tag) - values(?1, ?2)", - ) - .bind(s.id.to_string()) - .bind(tag) - .execute(&mut **tx) - .await?; - } - - Ok(()) - } - - pub async fn save(&self, s: &Script) -> Result<()> { - debug!("saving script to sqlite"); - let mut tx = self.pool.begin().await?; - Self::save_raw(&mut tx, s).await?; - tx.commit().await?; - - Ok(()) - } - - pub async fn save_bulk(&self, s: &[Script]) -> Result<()> { - debug!("saving scripts to sqlite"); - - let mut tx = self.pool.begin().await?; - - for i in s { - Self::save_raw(&mut tx, i).await?; - } - - tx.commit().await?; - - Ok(()) - } - - fn query_script(row: SqliteRow) -> Script { - let id = row.get("id"); - let name = row.get("name"); - let description = row.get("description"); - let shebang = row.get("shebang"); - let script = row.get("script"); - - let id = Uuid::parse_str(id).unwrap(); - - Script { - id, - name, - description, - shebang, - script, - tags: vec![], - } - } - - fn query_script_tags(row: SqliteRow) -> String { - row.get("tag") - } - - #[allow(dead_code)] - async fn load(&self, id: &str) -> Result<Option<Script>> { - debug!("loading script item {}", id); - - let res = sqlx::query("select * from scripts where id = ?1") - .bind(id) - .map(Self::query_script) - .fetch_optional(&self.pool) - .await?; - - // intentionally not joining, don't want to duplicate the script data in memory a whole bunch. - if let Some(mut script) = res { - let tags = sqlx::query("select tag from script_tags where script_id = ?1") - .bind(id) - .map(Self::query_script_tags) - .fetch_all(&self.pool) - .await?; - - script.tags = tags; - Ok(Some(script)) - } else { - Ok(None) - } - } - - pub async fn list(&self) -> Result<Vec<Script>> { - debug!("listing scripts"); - - let mut res = sqlx::query("select * from scripts") - .map(Self::query_script) - .fetch_all(&self.pool) - .await?; - - // Fetch all the tags for each script - for script in res.iter_mut() { - let tags = sqlx::query("select tag from script_tags where script_id = ?1") - .bind(script.id.to_string()) - .map(Self::query_script_tags) - .fetch_all(&self.pool) - .await?; - - script.tags = tags; - } - - Ok(res) - } - - pub async fn clear(&self) -> Result<()> { - debug!("clearing all scripts from sqlite"); - - sqlx::query("delete from script_tags") - .execute(&self.pool) - .await?; - sqlx::query("delete from scripts") - .execute(&self.pool) - .await?; - - Ok(()) - } - - pub async fn delete(&self, id: &str) -> Result<()> { - debug!("deleting script {}", id); - - sqlx::query("delete from scripts where id = ?1") - .bind(id) - .execute(&self.pool) - .await?; - - // delete all the tags for the script - sqlx::query("delete from script_tags where script_id = ?1") - .bind(id) - .execute(&self.pool) - .await?; - - Ok(()) - } - - pub async fn update(&self, s: &Script) -> Result<()> { - debug!("updating script {:?}", s); - - let mut tx = self.pool.begin().await?; - - // Update the script's base fields - sqlx::query("update scripts set name = ?1, description = ?2, shebang = ?3, script = ?4 where id = ?5") - .bind(s.name.as_str()) - .bind(s.description.as_str()) - .bind(s.shebang.as_str()) - .bind(s.script.as_str()) - .bind(s.id.to_string()) - .execute(&mut *tx) - .await?; - - // Delete all existing tags for this script - sqlx::query("delete from script_tags where script_id = ?1") - .bind(s.id.to_string()) - .execute(&mut *tx) - .await?; - - // Insert new tags - for tag in s.tags.iter() { - sqlx::query( - "insert or ignore into script_tags(script_id, tag) - values(?1, ?2)", - ) - .bind(s.id.to_string()) - .bind(tag) - .execute(&mut *tx) - .await?; - } - - tx.commit().await?; - - Ok(()) - } - - pub async fn get_by_name(&self, name: &str) -> Result<Option<Script>> { - let res = sqlx::query("select * from scripts where name = ?1") - .bind(name) - .map(Self::query_script) - .fetch_optional(&self.pool) - .await?; - - let script = if let Some(mut script) = res { - let tags = sqlx::query("select tag from script_tags where script_id = ?1") - .bind(script.id.to_string()) - .map(Self::query_script_tags) - .fetch_all(&self.pool) - .await?; - - script.tags = tags; - Some(script) - } else { - None - }; - - Ok(script) - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[tokio::test] - async fn test_list() { - let db = Database::new("sqlite::memory:", 1.0).await.unwrap(); - let scripts = db.list().await.unwrap(); - assert_eq!(scripts.len(), 0); - - let script = Script::builder() - .name("test".to_string()) - .description("test".to_string()) - .shebang("test".to_string()) - .script("test".to_string()) - .build(); - - db.save(&script).await.unwrap(); - - let scripts = db.list().await.unwrap(); - assert_eq!(scripts.len(), 1); - assert_eq!(scripts[0].name, "test"); - } - - #[tokio::test] - async fn test_save_load() { - let db = Database::new("sqlite::memory:", 1.0).await.unwrap(); - - let script = Script::builder() - .name("test name".to_string()) - .description("test description".to_string()) - .shebang("test shebang".to_string()) - .script("test script".to_string()) - .build(); - - db.save(&script).await.unwrap(); - - let loaded = db.load(&script.id.to_string()).await.unwrap().unwrap(); - - assert_eq!(loaded, script); - } - - #[tokio::test] - async fn test_save_bulk() { - let db = Database::new("sqlite::memory:", 1.0).await.unwrap(); - - let scripts = vec![ - Script::builder() - .name("test name".to_string()) - .description("test description".to_string()) - .shebang("test shebang".to_string()) - .script("test script".to_string()) - .build(), - Script::builder() - .name("test name 2".to_string()) - .description("test description 2".to_string()) - .shebang("test shebang 2".to_string()) - .script("test script 2".to_string()) - .build(), - ]; - - db.save_bulk(&scripts).await.unwrap(); - - let loaded = db.list().await.unwrap(); - assert_eq!(loaded.len(), 2); - assert_eq!(loaded[0].name, "test name"); - assert_eq!(loaded[1].name, "test name 2"); - } - - #[tokio::test] - async fn test_delete() { - let db = Database::new("sqlite::memory:", 1.0).await.unwrap(); - - let script = Script::builder() - .name("test name".to_string()) - .description("test description".to_string()) - .shebang("test shebang".to_string()) - .script("test script".to_string()) - .build(); - - db.save(&script).await.unwrap(); - - assert_eq!(db.list().await.unwrap().len(), 1); - db.delete(&script.id.to_string()).await.unwrap(); - - let loaded = db.list().await.unwrap(); - assert_eq!(loaded.len(), 0); - } -} diff --git a/crates/atuin-scripts/src/execution.rs b/crates/atuin-scripts/src/execution.rs deleted file mode 100644 index 5bf94aaa..00000000 --- a/crates/atuin-scripts/src/execution.rs +++ /dev/null @@ -1,286 +0,0 @@ -use crate::store::script::Script; -use eyre::Result; -use std::collections::{HashMap, HashSet}; -use std::process::Stdio; -use tempfile::NamedTempFile; -use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader}; -use tokio::sync::mpsc; -use tokio::task; -use tracing::debug; - -// Helper function to build a complete script with shebang -pub fn build_executable_script(script: String, shebang: String) -> String { - if shebang.is_empty() { - // Default to bash if no shebang is provided - format!("#!/usr/bin/env bash\n{script}") - } else if script.starts_with("#!") { - format!("{shebang}\n{script}") - } else { - format!("#!{shebang}\n{script}") - } -} - -/// Represents the communication channels for an interactive script -pub struct ScriptSession { - /// Channel to send input to the script - pub stdin_tx: mpsc::Sender<String>, - /// Exit code of the process once it completes - pub exit_code_rx: mpsc::Receiver<i32>, -} - -impl ScriptSession { - /// Send input to the running script - pub async fn send_input(&self, input: String) -> Result<(), mpsc::error::SendError<String>> { - self.stdin_tx.send(input).await - } - - /// Wait for the script to complete and get the exit code - pub async fn wait_for_exit(&mut self) -> Option<i32> { - self.exit_code_rx.recv().await - } -} - -fn setup_template(script: &Script) -> Result<minijinja::Environment<'_>> { - let mut env = minijinja::Environment::new(); - env.set_trim_blocks(true); - env.add_template("script", script.script.as_str())?; - - Ok(env) -} - -/// Template a script with the given context -pub fn template_script( - script: &Script, - context: &HashMap<String, serde_json::Value>, -) -> Result<String> { - let env = setup_template(script)?; - let template = env.get_template("script")?; - let rendered = template.render(context)?; - - Ok(rendered) -} - -/// Get the variables that need to be templated in a script -pub fn template_variables(script: &Script) -> Result<HashSet<String>> { - let env = setup_template(script)?; - let template = env.get_template("script")?; - - Ok(template.undeclared_variables(true)) -} - -/// Execute a script interactively, allowing for ongoing stdin/stdout interaction -pub async fn execute_script_interactive( - script: String, - shebang: String, -) -> Result<ScriptSession, Box<dyn std::error::Error + Send + Sync>> { - // Create a temporary file for the script - let temp_file = NamedTempFile::new()?; - let temp_path = temp_file.path().to_path_buf(); - - debug!("creating temp file at {}", temp_path.display()); - - // Extract interpreter from shebang for fallback execution - let interpreter = if !shebang.is_empty() { - shebang.trim_start_matches("#!").trim().to_string() - } else { - "/usr/bin/env bash".to_string() - }; - - // Write script content to the temp file, including the shebang - let full_script_content = build_executable_script(script.clone(), shebang.clone()); - - debug!("writing script content to temp file"); - tokio::fs::write(&temp_path, &full_script_content).await?; - - // Make it executable on Unix systems - #[cfg(unix)] - { - debug!("making script executable"); - use std::os::unix::fs::PermissionsExt; - let mut perms = std::fs::metadata(&temp_path)?.permissions(); - perms.set_mode(0o755); - std::fs::set_permissions(&temp_path, perms)?; - } - - // Store the temp_file to prevent it from being dropped - // This ensures it won't be deleted while the script is running - let _keep_temp_file = temp_file; - - debug!("attempting direct script execution"); - let mut child_result = tokio::process::Command::new(temp_path.to_str().unwrap()) - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn(); - - // If direct execution fails, try using the interpreter - if let Err(e) = &child_result { - debug!("direct execution failed: {}, trying with interpreter", e); - - // When falling back to interpreter, remove the shebang from the file - // Some interpreters don't handle scripts with shebangs well - debug!("writing script content without shebang for interpreter execution"); - tokio::fs::write(&temp_path, &script).await?; - - // Parse the interpreter command - let parts: Vec<&str> = interpreter.split_whitespace().collect(); - if !parts.is_empty() { - let mut cmd = tokio::process::Command::new(parts[0]); - - // Add any interpreter args - for i in parts.iter().skip(1) { - cmd.arg(i); - } - - // Add the script path - cmd.arg(temp_path.to_str().unwrap()); - - // Try with the interpreter - child_result = cmd - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn(); - } - } - - // If it still fails, return the error - let mut child = match child_result { - Ok(child) => child, - Err(e) => { - return Err(format!("Failed to execute script: {e}").into()); - } - }; - - // Get handles to stdin, stdout, stderr - let mut stdin = child - .stdin - .take() - .ok_or_else(|| "Failed to open child process stdin".to_string())?; - let stdout = child - .stdout - .take() - .ok_or_else(|| "Failed to open child process stdout".to_string())?; - let stderr = child - .stderr - .take() - .ok_or_else(|| "Failed to open child process stderr".to_string())?; - - // Create channels for the interactive session - let (stdin_tx, mut stdin_rx) = mpsc::channel::<String>(32); - let (exit_code_tx, exit_code_rx) = mpsc::channel::<i32>(1); - - // handle user stdin - debug!("spawning stdin handler"); - tokio::spawn(async move { - while let Some(input) = stdin_rx.recv().await { - if let Err(e) = stdin.write_all(input.as_bytes()).await { - eprintln!("Error writing to stdin: {e}"); - break; - } - if let Err(e) = stdin.flush().await { - eprintln!("Error flushing stdin: {e}"); - break; - } - } - // when the channel closes (sender dropped), we let stdin close naturally - }); - - // handle stdout - debug!("spawning stdout handler"); - let stdout_handle = task::spawn(async move { - let mut stdout_reader = BufReader::new(stdout); - let mut buffer = [0u8; 1024]; - let mut stdout_writer = tokio::io::stdout(); - - loop { - match stdout_reader.read(&mut buffer).await { - Ok(0) => break, // End of stdout - Ok(n) => { - if let Err(e) = stdout_writer.write_all(&buffer[0..n]).await { - eprintln!("Error writing to stdout: {e}"); - break; - } - if let Err(e) = stdout_writer.flush().await { - eprintln!("Error flushing stdout: {e}"); - break; - } - } - Err(e) => { - eprintln!("Error reading from process stdout: {e}"); - break; - } - } - } - }); - - // Process stderr in a separate task - debug!("spawning stderr handler"); - let stderr_handle = task::spawn(async move { - let mut stderr_reader = BufReader::new(stderr); - let mut buffer = [0u8; 1024]; - let mut stderr_writer = tokio::io::stderr(); - - loop { - match stderr_reader.read(&mut buffer).await { - Ok(0) => break, // End of stderr - Ok(n) => { - if let Err(e) = stderr_writer.write_all(&buffer[0..n]).await { - eprintln!("Error writing to stderr: {e}"); - break; - } - if let Err(e) = stderr_writer.flush().await { - eprintln!("Error flushing stderr: {e}"); - break; - } - } - Err(e) => { - eprintln!("Error reading from process stderr: {e}"); - break; - } - } - } - }); - - // Spawn a task to wait for the child process to complete - debug!("spawning exit code handler"); - let _keep_temp_file_clone = _keep_temp_file; - tokio::spawn(async move { - // Keep the temp file alive until the process completes - let _temp_file_ref = _keep_temp_file_clone; - - // Wait for the child process to complete - let status = match child.wait().await { - Ok(status) => { - debug!("Process exited with status: {:?}", status); - status - } - Err(e) => { - eprintln!("Error waiting for child process: {e}"); - // Send a default error code - let _ = exit_code_tx.send(-1).await; - return; - } - }; - - // Wait for stdout/stderr tasks to complete - if let Err(e) = stdout_handle.await { - eprintln!("Error joining stdout task: {e}"); - } - - if let Err(e) = stderr_handle.await { - eprintln!("Error joining stderr task: {e}"); - } - - // Send the exit code - let exit_code = status.code().unwrap_or(-1); - debug!("Sending exit code: {}", exit_code); - let _ = exit_code_tx.send(exit_code).await; - }); - - // Return the communication channels as a ScriptSession - Ok(ScriptSession { - stdin_tx, - exit_code_rx, - }) -} diff --git a/crates/atuin-scripts/src/lib.rs b/crates/atuin-scripts/src/lib.rs deleted file mode 100644 index c79c7089..00000000 --- a/crates/atuin-scripts/src/lib.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub mod database; -pub mod execution; -pub mod settings; -pub mod store; diff --git a/crates/atuin-scripts/src/settings.rs b/crates/atuin-scripts/src/settings.rs deleted file mode 100644 index 8b137891..00000000 --- a/crates/atuin-scripts/src/settings.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/crates/atuin-scripts/src/store.rs b/crates/atuin-scripts/src/store.rs deleted file mode 100644 index e70f6909..00000000 --- a/crates/atuin-scripts/src/store.rs +++ /dev/null @@ -1,114 +0,0 @@ -use eyre::{Result, bail}; - -use atuin_client::record::sqlite_store::SqliteStore; -use atuin_client::record::{encryption::PASETO_V4, store::Store}; -use atuin_common::record::{Host, HostId, Record, RecordId, RecordIdx}; -use record::ScriptRecord; -use script::{SCRIPT_TAG, SCRIPT_VERSION, Script}; - -use crate::database::Database; - -pub mod record; -pub mod script; - -#[derive(Debug, Clone)] -pub struct ScriptStore { - pub store: SqliteStore, - pub host_id: HostId, - pub encryption_key: [u8; 32], -} - -impl ScriptStore { - pub fn new(store: SqliteStore, host_id: HostId, encryption_key: [u8; 32]) -> Self { - ScriptStore { - store, - host_id, - encryption_key, - } - } - - async fn push_record(&self, record: ScriptRecord) -> Result<(RecordId, RecordIdx)> { - let bytes = record.serialize()?; - let idx = self - .store - .last(self.host_id, SCRIPT_TAG) - .await? - .map_or(0, |p| p.idx + 1); - - let record = Record::builder() - .host(Host::new(self.host_id)) - .version(SCRIPT_VERSION.to_string()) - .tag(SCRIPT_TAG.to_string()) - .idx(idx) - .data(bytes) - .build(); - - let id = record.id; - - self.store - .push(&record.encrypt::<PASETO_V4>(&self.encryption_key)) - .await?; - - Ok((id, idx)) - } - - pub async fn create(&self, script: Script) -> Result<()> { - let record = ScriptRecord::Create(script); - self.push_record(record).await?; - Ok(()) - } - - pub async fn update(&self, script: Script) -> Result<()> { - let record = ScriptRecord::Update(script); - self.push_record(record).await?; - Ok(()) - } - - pub async fn delete(&self, script_id: uuid::Uuid) -> Result<()> { - let record = ScriptRecord::Delete(script_id); - self.push_record(record).await?; - Ok(()) - } - - pub async fn scripts(&self) -> Result<Vec<ScriptRecord>> { - let records = self.store.all_tagged(SCRIPT_TAG).await?; - let mut ret = Vec::with_capacity(records.len()); - - for record in records.into_iter() { - let script = match record.version.as_str() { - SCRIPT_VERSION => { - let decrypted = record.decrypt::<PASETO_V4>(&self.encryption_key)?; - - ScriptRecord::deserialize(&decrypted.data, SCRIPT_VERSION) - } - version => bail!("unknown history version {version:?}"), - }?; - - ret.push(script); - } - - Ok(ret) - } - - pub async fn build(&self, database: Database) -> Result<()> { - // Clear existing data before replaying all records from the store. - // Without this, stale rows can cause unique constraint violations - // when records are replayed (eg name conflicts from renamed scripts). - database.clear().await?; - - // Get all the scripts from the store - they are already sorted by timestamp - let scripts = self.scripts().await?; - - for script in scripts { - match script { - ScriptRecord::Create(script) => { - database.save(&script).await?; - } - ScriptRecord::Update(script) => database.update(&script).await?, - ScriptRecord::Delete(id) => database.delete(&id.to_string()).await?, - } - } - - Ok(()) - } -} diff --git a/crates/atuin-scripts/src/store/record.rs b/crates/atuin-scripts/src/store/record.rs deleted file mode 100644 index 4c925be3..00000000 --- a/crates/atuin-scripts/src/store/record.rs +++ /dev/null @@ -1,215 +0,0 @@ -use atuin_common::record::DecryptedData; -use eyre::{Result, eyre}; -use uuid::Uuid; - -use crate::store::script::SCRIPT_VERSION; - -use super::script::Script; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum ScriptRecord { - Create(Script), - Update(Script), - Delete(Uuid), -} - -impl ScriptRecord { - pub fn serialize(&self) -> Result<DecryptedData> { - use rmp::encode; - - let mut output = vec![]; - - match self { - ScriptRecord::Create(script) => { - // 0 -> a script create - encode::write_u8(&mut output, 0)?; - - let bytes = script.serialize()?; - - encode::write_bin(&mut output, &bytes.0)?; - } - - ScriptRecord::Delete(id) => { - // 1 -> a script delete - encode::write_u8(&mut output, 1)?; - encode::write_str(&mut output, id.to_string().as_str())?; - } - - ScriptRecord::Update(script) => { - // 2 -> a script update - encode::write_u8(&mut output, 2)?; - let bytes = script.serialize()?; - encode::write_bin(&mut output, &bytes.0)?; - } - }; - - Ok(DecryptedData(output)) - } - - pub fn deserialize(data: &DecryptedData, version: &str) -> Result<Self> { - use rmp::decode; - - fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { - eyre!("{err:?}") - } - - match version { - SCRIPT_VERSION => { - let mut bytes = decode::Bytes::new(&data.0); - - let record_type = decode::read_u8(&mut bytes).map_err(error_report)?; - - match record_type { - // create - 0 => { - // written by encode::write_bin above - let _ = decode::read_bin_len(&mut bytes).map_err(error_report)?; - let script = Script::deserialize(bytes.remaining_slice())?; - Ok(ScriptRecord::Create(script)) - } - - // delete - 1 => { - let bytes = bytes.remaining_slice(); - let (id, _) = decode::read_str_from_slice(bytes).map_err(error_report)?; - Ok(ScriptRecord::Delete(Uuid::parse_str(id)?)) - } - - // update - 2 => { - // written by encode::write_bin above - let _ = decode::read_bin_len(&mut bytes).map_err(error_report)?; - let script = Script::deserialize(bytes.remaining_slice())?; - Ok(ScriptRecord::Update(script)) - } - - _ => Err(eyre!("unknown script record type {record_type}")), - } - } - _ => Err(eyre!("unknown version {version:?}")), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_serialize_create() { - let script = Script::builder() - .id(uuid::Uuid::parse_str("0195c825a35f7982bdb016168881cbc6").unwrap()) - .name("test".to_string()) - .description("test".to_string()) - .shebang("test".to_string()) - .tags(vec!["test".to_string()]) - .script("test".to_string()) - .build(); - - let record = ScriptRecord::Create(script); - - let serialized = record.serialize().unwrap(); - - assert_eq!( - serialized.0, - vec![ - 204, 0, 196, 65, 150, 217, 36, 48, 49, 57, 53, 99, 56, 50, 53, 45, 97, 51, 53, 102, - 45, 55, 57, 56, 50, 45, 98, 100, 98, 48, 45, 49, 54, 49, 54, 56, 56, 56, 49, 99, - 98, 99, 54, 164, 116, 101, 115, 116, 164, 116, 101, 115, 116, 164, 116, 101, 115, - 116, 145, 164, 116, 101, 115, 116, 164, 116, 101, 115, 116 - ] - ); - } - - #[test] - fn test_serialize_delete() { - let record = ScriptRecord::Delete( - uuid::Uuid::parse_str("0195c825a35f7982bdb016168881cbc6").unwrap(), - ); - - let serialized = record.serialize().unwrap(); - - assert_eq!( - serialized.0, - vec![ - 204, 1, 217, 36, 48, 49, 57, 53, 99, 56, 50, 53, 45, 97, 51, 53, 102, 45, 55, 57, - 56, 50, 45, 98, 100, 98, 48, 45, 49, 54, 49, 54, 56, 56, 56, 49, 99, 98, 99, 54 - ] - ); - } - - #[test] - fn test_serialize_update() { - let script = Script::builder() - .id(uuid::Uuid::parse_str("0195c825a35f7982bdb016168881cbc6").unwrap()) - .name(String::from("test")) - .description(String::from("test")) - .shebang(String::from("test")) - .tags(vec![String::from("test"), String::from("test2")]) - .script(String::from("test")) - .build(); - - let record = ScriptRecord::Update(script); - - let serialized = record.serialize().unwrap(); - - assert_eq!( - serialized.0, - vec![ - 204, 2, 196, 71, 150, 217, 36, 48, 49, 57, 53, 99, 56, 50, 53, 45, 97, 51, 53, 102, - 45, 55, 57, 56, 50, 45, 98, 100, 98, 48, 45, 49, 54, 49, 54, 56, 56, 56, 49, 99, - 98, 99, 54, 164, 116, 101, 115, 116, 164, 116, 101, 115, 116, 164, 116, 101, 115, - 116, 146, 164, 116, 101, 115, 116, 165, 116, 101, 115, 116, 50, 164, 116, 101, 115, - 116 - ], - ); - } - - #[test] - fn test_serialize_deserialize_create() { - let script = Script::builder() - .name("test".to_string()) - .description("test".to_string()) - .shebang("test".to_string()) - .tags(vec!["test".to_string()]) - .script("test".to_string()) - .build(); - - let record = ScriptRecord::Create(script); - - let serialized = record.serialize().unwrap(); - let deserialized = ScriptRecord::deserialize(&serialized, SCRIPT_VERSION).unwrap(); - - assert_eq!(record, deserialized); - } - - #[test] - fn test_serialize_deserialize_delete() { - let record = ScriptRecord::Delete( - uuid::Uuid::parse_str("0195c825a35f7982bdb016168881cbc6").unwrap(), - ); - - let serialized = record.serialize().unwrap(); - let deserialized = ScriptRecord::deserialize(&serialized, SCRIPT_VERSION).unwrap(); - - assert_eq!(record, deserialized); - } - - #[test] - fn test_serialize_deserialize_update() { - let script = Script::builder() - .name("test".to_string()) - .description("test".to_string()) - .shebang("test".to_string()) - .tags(vec!["test".to_string()]) - .script("test".to_string()) - .build(); - - let record = ScriptRecord::Update(script); - - let serialized = record.serialize().unwrap(); - let deserialized = ScriptRecord::deserialize(&serialized, SCRIPT_VERSION).unwrap(); - - assert_eq!(record, deserialized); - } -} diff --git a/crates/atuin-scripts/src/store/script.rs b/crates/atuin-scripts/src/store/script.rs deleted file mode 100644 index af180320..00000000 --- a/crates/atuin-scripts/src/store/script.rs +++ /dev/null @@ -1,151 +0,0 @@ -use atuin_common::record::DecryptedData; -use eyre::{Result, bail, ensure}; -use uuid::Uuid; - -use rmp::{ - decode::{self, Bytes}, - encode, -}; -use typed_builder::TypedBuilder; - -pub const SCRIPT_VERSION: &str = "v0"; -pub const SCRIPT_TAG: &str = "script"; -pub const SCRIPT_LEN: usize = 20000; // 20kb max total len - -#[derive(Debug, Clone, PartialEq, Eq, TypedBuilder)] -/// A script is a set of commands that can be run, with the specified shebang -pub struct Script { - /// The id of the script - #[builder(default = uuid::Uuid::new_v4())] - pub id: Uuid, - - /// The name of the script - pub name: String, - - /// The description of the script - #[builder(default = String::new())] - pub description: String, - - /// The interpreter of the script - #[builder(default = String::new())] - pub shebang: String, - - /// The tags of the script - #[builder(default = Vec::new())] - pub tags: Vec<String>, - - /// The script content - pub script: String, -} - -impl Script { - pub fn serialize(&self) -> Result<DecryptedData> { - // sort the tags first, to ensure consistent ordering - let mut tags = self.tags.clone(); - tags.sort(); - - let mut output = vec![]; - - encode::write_array_len(&mut output, 6)?; - encode::write_str(&mut output, &self.id.to_string())?; - encode::write_str(&mut output, &self.name)?; - encode::write_str(&mut output, &self.description)?; - encode::write_str(&mut output, &self.shebang)?; - encode::write_array_len(&mut output, self.tags.len() as u32)?; - - for tag in &tags { - encode::write_str(&mut output, tag)?; - } - - encode::write_str(&mut output, &self.script)?; - - Ok(DecryptedData(output)) - } - - pub fn deserialize(bytes: &[u8]) -> Result<Self> { - let mut bytes = decode::Bytes::new(bytes); - let nfields = decode::read_array_len(&mut bytes).unwrap(); - - ensure!(nfields == 6, "too many entries in v0 script record"); - - let bytes = bytes.remaining_slice(); - - let (id, bytes) = decode::read_str_from_slice(bytes).unwrap(); - let (name, bytes) = decode::read_str_from_slice(bytes).unwrap(); - let (description, bytes) = decode::read_str_from_slice(bytes).unwrap(); - let (shebang, bytes) = decode::read_str_from_slice(bytes).unwrap(); - - let mut bytes = Bytes::new(bytes); - let tags_len = decode::read_array_len(&mut bytes).unwrap(); - - let mut bytes = bytes.remaining_slice(); - - let mut tags = Vec::new(); - for _ in 0..tags_len { - let (tag, remaining) = decode::read_str_from_slice(bytes).unwrap(); - tags.push(tag.to_owned()); - bytes = remaining; - } - - let (script, bytes) = decode::read_str_from_slice(bytes).unwrap(); - - if !bytes.is_empty() { - bail!("trailing bytes in encoded script record. malformed") - } - - Ok(Script { - id: Uuid::parse_str(id).unwrap(), - name: name.to_owned(), - description: description.to_owned(), - shebang: shebang.to_owned(), - tags, - script: script.to_owned(), - }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_serialize() { - let script = Script { - id: uuid::Uuid::parse_str("0195c825a35f7982bdb016168881cbc6").unwrap(), - name: "test".to_string(), - description: "test".to_string(), - shebang: "test".to_string(), - tags: vec!["test".to_string()], - script: "test".to_string(), - }; - - let serialized = script.serialize().unwrap(); - assert_eq!( - serialized.0, - vec![ - 150, 217, 36, 48, 49, 57, 53, 99, 56, 50, 53, 45, 97, 51, 53, 102, 45, 55, 57, 56, - 50, 45, 98, 100, 98, 48, 45, 49, 54, 49, 54, 56, 56, 56, 49, 99, 98, 99, 54, 164, - 116, 101, 115, 116, 164, 116, 101, 115, 116, 164, 116, 101, 115, 116, 145, 164, - 116, 101, 115, 116, 164, 116, 101, 115, 116 - ] - ); - } - - #[test] - fn test_serialize_deserialize() { - let script = Script { - id: uuid::Uuid::new_v4(), - name: "test".to_string(), - description: "test".to_string(), - shebang: "test".to_string(), - tags: vec!["test".to_string()], - script: "test".to_string(), - }; - - let serialized = script.serialize().unwrap(); - - let deserialized = Script::deserialize(&serialized.0).unwrap(); - - assert_eq!(script, deserialized); - } -} diff --git a/crates/atuin/Cargo.toml b/crates/atuin/Cargo.toml index 8e425232..198db48b 100644 --- a/crates/atuin/Cargo.toml +++ b/crates/atuin/Cargo.toml @@ -11,47 +11,23 @@ license = { workspace = true } homepage = { workspace = true } repository = { workspace = true } -[package.metadata.binstall] -pkg-url = "{ repo }/releases/download/v{ version }/{ name }-{ target }.tar.gz" -bin-dir = "{ name }-{ target }/{ bin }{ binary-ext }" -pkg-fmt = "tgz" - -[package.metadata.deb] -maintainer = "Ellie Huxtable <ellie@elliehuxtable.com>" -copyright = "2021, Ellie Huxtable <ellie@elliehuxtable.com>" -license-file = ["LICENSE"] -depends = "$auto" -section = "utility" - -[package.metadata.rpm] -package = "atuin" - -[package.metadata.rpm.cargo] -buildflags = ["--release"] - -[package.metadata.rpm.targets] -atuin = { path = "/usr/bin/atuin" } - [features] -default = ["client", "sync", "clipboard", "check-update", "daemon", "ai", "pty-proxy"] +default = [ + "client", "sync", "clipboard", "daemon", "pty-proxy" +] client = ["atuin-client"] sync = ["atuin-client/sync"] -daemon = ["atuin-client/daemon", "atuin-daemon", "atuin-ai?/daemon"] -ai = ["atuin-ai"] +daemon = ["atuin-client/daemon", "atuin-daemon"] pty-proxy = ["dep:atuin-pty-proxy"] hex = ["pty-proxy"] clipboard = ["arboard"] -check-update = ["atuin-client/check-update"] [dependencies] -atuin-ai = { path = "../atuin-ai", version = "18.16.1", optional = true, default-features = false } atuin-client = { path = "../atuin-client", version = "18.16.1", optional = true, default-features = false } atuin-common = { workspace = true } -atuin-dotfiles = { workspace = true } atuin-history = { workspace = true } atuin-daemon = { path = "../atuin-daemon", version = "18.16.1", optional = true, default-features = false } atuin-pty-proxy = { path = "../atuin-pty-proxy", version = "18.16.1", optional = true, default-features = false } -atuin-scripts = { workspace = true } atuin-kv = { workspace = true } log = { workspace = true } @@ -95,9 +71,6 @@ shlex = "1.3.0" # settings editor with comment and relative ordering preservation toml_edit = { workspace = true } -[target.'cfg(any(target_os = "windows", target_os = "macos"))'.dependencies] -arboard = { version = "3.4", optional = true, default-features = false } - [target.'cfg(target_os = "linux")'.dependencies] arboard = { version = "3.4", optional = true, default-features = false, features = [ "wayland-data-control", @@ -106,15 +79,6 @@ arboard = { version = "3.4", optional = true, default-features = false, features [target.'cfg(unix)'.dependencies] daemonize = "0.5.0" -# Enable tree-sitter shell parsing on platforms where tree-sitter's bundled C -# compiles cleanly. tree-sitter 0.26's portable/endian.h fails on illumos, -# Windows cross-compiles, and potentially other exotic targets. -[target.'cfg(any(target_os = "linux", target_os = "macos"))'.dependencies] -atuin-ai = { path = "../atuin-ai", version = "18.16.1", optional = true, default-features = false, features = ["tree-sitter"] } - -[target.'cfg(windows)'.dependencies] -windows-sys = { version = "0.61.2", features = ["Win32_System_Console"] } - [dev-dependencies] tracing-tree = "0.4" diff --git a/crates/atuin/contrib/pi/atuin.ts b/crates/atuin/contrib/pi/atuin.ts deleted file mode 100644 index 55c17cb8..00000000 --- a/crates/atuin/contrib/pi/atuin.ts +++ /dev/null @@ -1,87 +0,0 @@ -/** - * Atuin extension for pi. - * - * Tracks bash commands executed by pi in Atuin history with author `pi`. - * - * Install with: - * atuin hook install pi - * - * Then restart pi or run /reload. - */ - -import type { BashOperations, ExtensionAPI } from "@mariozechner/pi-coding-agent"; -import { createBashTool, createLocalBashOperations } from "@mariozechner/pi-coding-agent"; - -const ATUIN_AUTHOR = "pi"; -const ATUIN_TIMEOUT_MS = 10_000; - -async function startHistory( - pi: ExtensionAPI, - cwd: string, - command: string, -): Promise<string | undefined> { - try { - const result = await pi.exec( - "atuin", - ["history", "start", "--author", ATUIN_AUTHOR, "--", command], - { cwd, timeout: ATUIN_TIMEOUT_MS }, - ); - - if (result.code !== 0) return undefined; - - const id = result.stdout.trim(); - return id.length > 0 ? id : undefined; - } catch { - return undefined; - } -} - -async function endHistory( - pi: ExtensionAPI, - cwd: string, - historyId: string, - exitCode: number, -): Promise<void> { - try { - await pi.exec( - "atuin", - ["history", "end", historyId, "--exit", String(exitCode)], - { cwd, timeout: ATUIN_TIMEOUT_MS }, - ); - } catch { - // Ignore Atuin failures so command execution is never blocked. - } -} - -export default function atuinPiExtension(pi: ExtensionAPI) { - const cwd = process.cwd(); - const local = createLocalBashOperations(); - - const trackedOperations: BashOperations = { - async exec(command, commandCwd, options) { - const historyId = await startHistory(pi, commandCwd, command); - let exitCode: number | null = null; - - try { - const result = await local.exec(command, commandCwd, options); - exitCode = result.exitCode; - return result; - } finally { - if (historyId) { - await endHistory( - pi, - commandCwd, - historyId, - exitCode ?? (options.signal?.aborted ? 130 : 1), - ); - } - } - }, - }; - - pi.registerTool( - createBashTool(cwd, { - operations: trackedOperations, - }), - ); -} diff --git a/crates/atuin/src/command/client.rs b/crates/atuin/src/command/client.rs index e3ec01d9..3b0ef8a9 100644 --- a/crates/atuin/src/command/client.rs +++ b/crates/atuin/src/command/client.rs @@ -52,14 +52,11 @@ mod daemon; mod config; mod default_config; mod doctor; -mod dotfiles; mod history; -mod hook; mod import; mod info; mod init; mod kv; -mod scripts; mod search; mod setup; mod stats; @@ -77,9 +74,6 @@ pub enum Cmd { #[command(subcommand)] History(history::Cmd), - /// Manage AI-agent shell hooks - Hook(hook::Cmd), - /// Import shell history from file #[command(subcommand)] Import(import::Cmd), @@ -106,14 +100,6 @@ pub enum Cmd { #[command(subcommand)] Store(store::Cmd), - /// Manage your dotfiles with Atuin - #[command(subcommand)] - Dotfiles(dotfiles::Cmd), - - /// Manage your scripts with Atuin - #[command(subcommand)] - Scripts(scripts::Cmd), - /// Print Atuin's shell init script #[command()] Init(init::Cmd), @@ -140,11 +126,6 @@ pub enum Cmd { #[command(subcommand)] Config(config::Cmd), - - /// Run the AI assistant - #[cfg(feature = "ai")] - #[command(subcommand)] - Ai(atuin_ai::commands::Commands), } impl Cmd { @@ -158,14 +139,6 @@ impl Cmd { daemon::daemonize_current_process()?; } - #[cfg(feature = "ai")] - let mut runtime = if matches!(&self, Self::Ai(_)) { - tokio::runtime::Builder::new_multi_thread() - } else { - tokio::runtime::Builder::new_current_thread() - }; - - #[cfg(not(feature = "ai"))] let mut runtime = tokio::runtime::Builder::new_current_thread(); let runtime = runtime.enable_all().build().unwrap(); @@ -341,7 +314,6 @@ impl Cmd { // runs match self { Self::History(history) => return history.run(&settings).await, - Self::Hook(hook) => return hook.run(&settings).await, Self::Init(init) => return init.run(&settings).await, Self::Doctor => return doctor::run(&settings).await, Self::Config(config) => return config.run(&settings).await, @@ -373,10 +345,6 @@ impl Cmd { Self::Store(store) => store.run(&settings, &db, sqlite_store).await, - Self::Dotfiles(dotfiles) => dotfiles.run(&settings, sqlite_store).await, - - Self::Scripts(scripts) => scripts.run(&settings, sqlite_store, &db).await, - Self::Info => { info::run(&settings); Ok(()) @@ -387,17 +355,14 @@ impl Cmd { Ok(()) } - Self::Wrapped { year } => wrapped::run(year, &db, &settings, sqlite_store, theme).await, + Self::Wrapped { year } => wrapped::run(year, &db, &settings, theme).await, #[cfg(feature = "daemon")] Self::Daemon(cmd) => cmd.run(settings, sqlite_store, db).await, - Self::History(_) | Self::Hook(_) | Self::Init(_) | Self::Doctor | Self::Config(_) => { + Self::History(_) | Self::Init(_) | Self::Doctor | Self::Config(_) => { unreachable!() } - - #[cfg(feature = "ai")] - Self::Ai(cli) => atuin_ai::commands::run(cli, &settings).await, } } } diff --git a/crates/atuin/src/command/client/dotfiles.rs b/crates/atuin/src/command/client/dotfiles.rs deleted file mode 100644 index f42b18f2..00000000 --- a/crates/atuin/src/command/client/dotfiles.rs +++ /dev/null @@ -1,28 +0,0 @@ -use clap::Subcommand; -use eyre::Result; - -use atuin_client::{record::sqlite_store::SqliteStore, settings::Settings}; - -mod alias; -mod var; - -#[derive(Subcommand, Debug)] -#[command(infer_subcommands = true)] -pub enum Cmd { - /// Manage shell aliases with Atuin - #[command(subcommand)] - Alias(alias::Cmd), - - /// Manage shell and environment variables with Atuin - #[command(subcommand)] - Var(var::Cmd), -} - -impl Cmd { - pub async fn run(self, settings: &Settings, store: SqliteStore) -> Result<()> { - match self { - Self::Alias(cmd) => cmd.run(settings, store).await, - Self::Var(cmd) => cmd.run(settings, store).await, - } - } -} diff --git a/crates/atuin/src/command/client/dotfiles/alias.rs b/crates/atuin/src/command/client/dotfiles/alias.rs deleted file mode 100644 index 61f8601d..00000000 --- a/crates/atuin/src/command/client/dotfiles/alias.rs +++ /dev/null @@ -1,187 +0,0 @@ -use clap::{Subcommand, ValueEnum}; -use eyre::{Context, Result, eyre}; - -use atuin_client::{encryption, record::sqlite_store::SqliteStore, settings::Settings}; - -use atuin_dotfiles::{shell::Alias, store::AliasStore}; - -#[derive(Clone, Copy, Debug, Default, ValueEnum)] -pub enum SortBy { - /// Sort by alias name - #[default] - Name, - /// Sort by alias value - Value, -} - -#[derive(Subcommand, Debug)] -#[command(infer_subcommands = true)] -pub enum Cmd { - /// Set an alias - Set { name: String, value: String }, - - /// Delete an alias - Delete { name: String }, - - /// List all aliases - List { - /// Sort results by field - #[arg(long, value_enum, default_value_t = SortBy::Name)] - sort_by: SortBy, - - /// Sort in reverse (descending) order - #[arg(long, short)] - reverse: bool, - - /// Filter aliases by name (substring match) - #[arg(long, short)] - name: Option<String>, - - /// Filter aliases by value (substring match) - #[arg(long, short)] - value: Option<String>, - }, - - /// Delete all aliases - Clear, - // There are too many edge cases to parse at the moment. Disable for now. - // Import, -} - -impl Cmd { - async fn set(&self, store: &AliasStore, name: String, value: String) -> Result<()> { - let illegal_char = regex::Regex::new("[ \t\n&();<>|\\\"'`$/]").unwrap(); - if illegal_char.is_match(name.as_str()) { - return Err(eyre!("Illegal character in alias name")); - } - - let aliases = store.aliases().await?; - let found: Vec<Alias> = aliases.into_iter().filter(|a| a.name == name).collect(); - - if found.is_empty() { - println!("Aliasing '{name}={value}'."); - } else { - println!( - "Overwriting alias '{name}={}' with '{name}={value}'.", - found[0].value - ); - } - - store.set(&name, &value).await?; - - Ok(()) - } - - async fn list( - &self, - store: &AliasStore, - sort_by: SortBy, - reverse: bool, - name_filter: Option<String>, - value_filter: Option<String>, - ) -> Result<()> { - let mut aliases = store.aliases().await?; - - // Apply filters - if let Some(ref name_pattern) = name_filter { - let pattern = name_pattern.to_lowercase(); - aliases.retain(|a| a.name.to_lowercase().contains(&pattern)); - } - if let Some(ref value_pattern) = value_filter { - let pattern = value_pattern.to_lowercase(); - aliases.retain(|a| a.value.to_lowercase().contains(&pattern)); - } - - // Apply sorting - match sort_by { - SortBy::Name => { - aliases.sort_by_key(|a| a.name.to_lowercase()); - } - SortBy::Value => { - aliases.sort_by_key(|a| a.value.to_lowercase()); - } - } - - // Apply reverse if requested - if reverse { - aliases.reverse(); - } - - for i in aliases { - println!("{}={}", i.name, i.value); - } - - Ok(()) - } - - async fn clear(&self, store: &AliasStore) -> Result<()> { - let aliases = store.aliases().await?; - - for i in aliases { - self.delete(store, i.name).await?; - } - - Ok(()) - } - - async fn delete(&self, store: &AliasStore, name: String) -> Result<()> { - let mut aliases = store.aliases().await?.into_iter(); - if let Some(alias) = aliases.find(|alias| alias.name == name) { - println!("Deleting '{name}={}'.", alias.value); - store.delete(&name).await?; - } else { - eprintln!("Cannot delete '{name}': Alias not set."); - } - Ok(()) - } - - /* - async fn import(&self, store: &AliasStore) -> Result<()> { - let aliases = atuin_dotfiles::shell::import_aliases(store).await?; - - for i in aliases { - println!("Importing {}={}", i.name, i.value); - } - - Ok(()) - } - */ - - pub async fn run(&self, settings: &Settings, store: SqliteStore) -> Result<()> { - if !settings.dotfiles.enabled { - eprintln!( - "Dotfiles are not enabled. Add\n\n[dotfiles]\nenabled = true\n\nto your configuration file to enable them.\n" - ); - eprintln!("The default configuration file is located at ~/.config/atuin/config.toml."); - return Ok(()); - } - - let encryption_key: [u8; 32] = encryption::load_key(settings) - .context("could not load encryption key")? - .into(); - let host_id = Settings::host_id().await?; - - let alias_store = AliasStore::new(store, host_id, encryption_key); - - match self { - Self::Set { name, value } => self.set(&alias_store, name.clone(), value.clone()).await, - Self::Delete { name } => self.delete(&alias_store, name.clone()).await, - Self::List { - sort_by, - reverse, - name, - value, - } => { - self.list( - &alias_store, - *sort_by, - *reverse, - name.clone(), - value.clone(), - ) - .await - } - Self::Clear => self.clear(&alias_store).await, - } - } -} diff --git a/crates/atuin/src/command/client/dotfiles/var.rs b/crates/atuin/src/command/client/dotfiles/var.rs deleted file mode 100644 index 94f75d57..00000000 --- a/crates/atuin/src/command/client/dotfiles/var.rs +++ /dev/null @@ -1,197 +0,0 @@ -use clap::{Subcommand, ValueEnum}; -use eyre::{Context, Result}; - -use atuin_client::{encryption, record::sqlite_store::SqliteStore, settings::Settings}; - -use atuin_dotfiles::{shell::Var, store::var::VarStore}; - -#[derive(Clone, Copy, Debug, Default, ValueEnum)] -pub enum SortBy { - /// Sort by variable name - #[default] - Name, - /// Sort by variable value - Value, -} - -#[derive(Subcommand, Debug)] -#[command(infer_subcommands = true)] -pub enum Cmd { - /// Set a variable - Set { - name: String, - value: String, - - #[clap(long, short, action)] - no_export: bool, - }, - - /// Delete a variable - Delete { name: String }, - - /// List all variables - List { - /// Sort results by field - #[arg(long, value_enum, default_value_t = SortBy::Name)] - sort_by: SortBy, - - /// Sort in reverse (descending) order - #[arg(long, short)] - reverse: bool, - - /// Filter variables by name (substring match) - #[arg(long, short)] - name: Option<String>, - - /// Filter variables by value (substring match) - #[arg(long, short)] - value: Option<String>, - - /// Show only exported variables - #[arg(long, conflicts_with = "shell_only")] - exports_only: bool, - - /// Show only non-exported (shell) variables - #[arg(long, conflicts_with = "exports_only")] - shell_only: bool, - }, -} - -impl Cmd { - async fn set(&self, store: VarStore, name: String, value: String, export: bool) -> Result<()> { - let vars = store.vars().await?; - let found: Vec<Var> = vars.into_iter().filter(|a| a.name == name).collect(); - let show_export = if export { "export " } else { "" }; - - if found.is_empty() { - println!("Setting '{show_export}{name}={value}'."); - } else { - println!( - "Overwriting var '{show_export}{name}={}' with '{name}={value}'.", - found[0].value - ); - } - - store.set(&name, &value, export).await?; - - Ok(()) - } - - #[allow(clippy::too_many_arguments)] - async fn list( - &self, - store: VarStore, - sort_by: SortBy, - reverse: bool, - name_filter: Option<String>, - value_filter: Option<String>, - exports_only: bool, - shell_only: bool, - ) -> Result<()> { - let mut vars = store.vars().await?; - - // Apply export/shell filters - if exports_only { - vars.retain(|v| v.export); - } - if shell_only { - vars.retain(|v| !v.export); - } - - // Apply name/value filters - if let Some(ref name_pattern) = name_filter { - let pattern = name_pattern.to_lowercase(); - vars.retain(|v| v.name.to_lowercase().contains(&pattern)); - } - if let Some(ref value_pattern) = value_filter { - let pattern = value_pattern.to_lowercase(); - vars.retain(|v| v.value.to_lowercase().contains(&pattern)); - } - - // Apply sorting - match sort_by { - SortBy::Name => { - vars.sort_by_key(|a| a.name.to_lowercase()); - } - SortBy::Value => { - vars.sort_by_key(|a| a.value.to_lowercase()); - } - } - - // Apply reverse if requested - if reverse { - vars.reverse(); - } - - for i in vars { - if i.export { - println!("export {}={}", i.name, i.value); - } else { - println!("{}={}", i.name, i.value); - } - } - - Ok(()) - } - - async fn delete(&self, store: VarStore, name: String) -> Result<()> { - let mut vars = store.vars().await?.into_iter(); - - if let Some(var) = vars.find(|var| var.name == name) { - println!("Deleting '{name}={}'.", var.value); - store.delete(&name).await?; - } else { - eprintln!("Cannot delete '{name}': Var not set."); - } - - Ok(()) - } - - pub async fn run(&self, settings: &Settings, store: SqliteStore) -> Result<()> { - if !settings.dotfiles.enabled { - eprintln!( - "Dotfiles are not enabled. Add\n\n[dotfiles]\nenabled = true\n\nto your configuration file to enable them.\n" - ); - eprintln!("The default configuration file is located at ~/.config/atuin/config.toml."); - return Ok(()); - } - - let encryption_key: [u8; 32] = encryption::load_key(settings) - .context("could not load encryption key")? - .into(); - let host_id = Settings::host_id().await?; - - let var_store = VarStore::new(store, host_id, encryption_key); - - match self { - Self::Set { - name, - value, - no_export, - } => { - self.set(var_store, name.clone(), value.clone(), !no_export) - .await - } - Self::Delete { name } => self.delete(var_store, name.clone()).await, - Self::List { - sort_by, - reverse, - name, - value, - exports_only, - shell_only, - } => { - self.list( - var_store, - *sort_by, - *reverse, - name.clone(), - value.clone(), - *exports_only, - *shell_only, - ) - .await - } - } - } -} diff --git a/crates/atuin/src/command/client/hook.rs b/crates/atuin/src/command/client/hook.rs deleted file mode 100644 index 0abae575..00000000 --- a/crates/atuin/src/command/client/hook.rs +++ /dev/null @@ -1,479 +0,0 @@ -use std::io::Read; -use std::path::PathBuf; - -use atuin_client::settings::Settings; -use atuin_common::utils::home_dir; -use clap::{Parser, Subcommand}; -use eyre::{Result, bail}; -use serde_json::Value; - -use super::history; - -const HOOK_EVENT_TYPES: &[&str] = &["PreToolUse", "PostToolUse", "PostToolUseFailure"]; -const PI_EXTENSION_SOURCE: &str = include_str!("../../../contrib/pi/atuin.ts"); - -enum InstallKind { - JsonHooks { - config_path: &'static [&'static str], - hook_command: &'static str, - matcher: &'static str, - }, - PiExtension { - extension_path: &'static [&'static str], - }, -} - -struct AgentSpec { - aliases: &'static [&'static str], - actor_name: &'static str, - install_kind: InstallKind, -} - -const CLAUDE_CODE: AgentSpec = AgentSpec { - aliases: &["claude-code", "claude"], - actor_name: "claude-code", - install_kind: InstallKind::JsonHooks { - config_path: &[".claude", "settings.json"], - hook_command: "atuin hook claude-code", - matcher: "Bash", - }, -}; - -const CODEX: AgentSpec = AgentSpec { - aliases: &["codex"], - actor_name: "codex", - install_kind: InstallKind::JsonHooks { - config_path: &[".codex", "hooks.json"], - hook_command: "atuin hook codex", - matcher: "^Bash$", - }, -}; - -const PI: AgentSpec = AgentSpec { - aliases: &["pi"], - actor_name: "pi", - install_kind: InstallKind::PiExtension { - extension_path: &[".pi", "agent", "extensions", "atuin.ts"], - }, -}; - -const AGENTS: &[&AgentSpec] = &[&CLAUDE_CODE, &CODEX, &PI]; - -struct Agent(&'static AgentSpec); - -impl Agent { - fn from_name(name: &str) -> Result<Self> { - AGENTS - .iter() - .copied() - .find(|spec| spec.aliases.contains(&name)) - .map(Self) - .ok_or_else(|| { - eyre::eyre!("unknown agent: {name}. Supported agents: claude-code, codex, pi") - }) - } - - fn actor_name(&self) -> &'static str { - self.0.actor_name - } - - fn path(path: &'static [&'static str]) -> PathBuf { - path.iter() - .fold(home_dir(), |path, segment| path.join(segment)) - } - - fn install_kind(&self) -> &InstallKind { - &self.0.install_kind - } -} - -#[derive(Subcommand, Debug)] -enum Action { - /// Install hooks for an AI agent to capture commands in atuin history - Install { - /// Agent to install hooks for (e.g., "claude-code") - #[arg(value_name = "AGENT")] - agent: String, - }, -} - -#[derive(Parser, Debug)] -#[command(infer_subcommands = true, args_conflicts_with_subcommands = true)] -pub struct Cmd { - #[command(subcommand)] - action: Option<Action>, - - /// Which agent's hook format to parse (e.g., "claude-code") - #[arg(value_name = "AGENT", hide = true)] - agent: Option<String>, -} - -impl Cmd { - pub async fn run(self, settings: &Settings) -> Result<()> { - match (self.action, self.agent) { - (Some(Action::Install { agent }), None) => install(&agent), - (None, Some(agent)) => handle(&agent, settings).await, - (None, None) => bail!("expected `atuin hook <agent>` or `atuin hook install <agent>`"), - (Some(_), Some(_)) => bail!("hook action cannot be combined with a positional agent"), - } - } -} - -#[derive(Debug)] -enum HookEvent { - Start { - command: String, - intent: Option<String>, - tool_use_id: String, - }, - End { - tool_use_id: String, - exit: i64, - }, - Skip, -} - -fn parse_hook_stdin(input: &str) -> Result<HookEvent> { - let v: Value = serde_json::from_str(input)?; - - if v.get("tool_name").and_then(|t| t.as_str()) != Some("Bash") { - return Ok(HookEvent::Skip); - } - - let tool_use_id = match v.get("tool_use_id").and_then(|t| t.as_str()) { - Some(id) if !id.is_empty() => id.to_string(), - _ => return Ok(HookEvent::Skip), - }; - - match v.get("hook_event_name").and_then(|e| e.as_str()) { - Some("PreToolUse") => { - let tool_input = v.get("tool_input"); - let command = tool_input - .and_then(|ti| ti.get("command")) - .and_then(|c| c.as_str()) - .unwrap_or(""); - - if command.is_empty() { - return Ok(HookEvent::Skip); - } - - let intent = tool_input - .and_then(|ti| ti.get("description")) - .and_then(|d| d.as_str()) - .map(String::from); - - Ok(HookEvent::Start { - command: command.to_string(), - intent, - tool_use_id, - }) - } - Some(event @ ("PostToolUse" | "PostToolUseFailure")) => { - let exit = if event == "PostToolUseFailure" { - 1 - } else { - v.get("tool_response") - .and_then(|tr| tr.get("exitCode")) - .and_then(Value::as_i64) - .unwrap_or(0) - }; - - Ok(HookEvent::End { tool_use_id, exit }) - } - _ => Ok(HookEvent::Skip), - } -} - -fn id_file_path(tool_use_id: &str) -> PathBuf { - std::env::temp_dir().join(format!("atuin-hook-{tool_use_id}")) -} - -async fn handle(agent_name: &str, settings: &Settings) -> Result<()> { - let agent = Agent::from_name(agent_name)?; - - if matches!(agent.install_kind(), InstallKind::PiExtension { .. }) { - bail!("`atuin hook pi` is not supported. Use `atuin hook install pi` and reload pi."); - } - - let mut input = String::new(); - std::io::stdin().read_to_string(&mut input)?; - - if input.trim().is_empty() { - return Ok(()); - } - - match parse_hook_stdin(&input)? { - HookEvent::Start { - command, - intent, - tool_use_id, - } => { - if let Some(history_id) = history::start_history_entry( - settings, - &command, - Some(agent.actor_name()), - intent.as_deref(), - ) - .await? - { - std::fs::write(id_file_path(&tool_use_id), &history_id)?; - } - } - HookEvent::End { tool_use_id, exit } => { - let id_path = id_file_path(&tool_use_id); - - if let Ok(history_id) = std::fs::read_to_string(&id_path) { - let history_id = history_id.trim(); - if !history_id.is_empty() { - let _ = history::end_history_entry(settings, history_id, exit, None).await; - } - let _ = std::fs::remove_file(&id_path); - } - } - HookEvent::Skip => {} - } - - Ok(()) -} - -fn install(agent_name: &str) -> Result<()> { - let agent = Agent::from_name(agent_name)?; - - match agent.install_kind() { - InstallKind::JsonHooks { - config_path, - hook_command: _, - matcher: _, - } => { - let config_path = Agent::path(config_path); - - if let Some(parent) = config_path.parent() { - std::fs::create_dir_all(parent)?; - } - - let mut root: Value = if config_path.exists() { - let content = std::fs::read_to_string(&config_path)?; - serde_json::from_str(&content)? - } else { - Value::Object(serde_json::Map::new()) - }; - - let hooks = root - .as_object_mut() - .ok_or_else(|| eyre::eyre!("config is not a JSON object"))? - .entry("hooks") - .or_insert_with(|| Value::Object(serde_json::Map::new())); - - add_hook_entries(hooks, &agent)?; - - let content = serde_json::to_string_pretty(&root)?; - std::fs::write(&config_path, content)?; - - eprintln!( - "\nAtuin hooks installed for {}. Config: {}", - agent.actor_name(), - config_path.display() - ); - } - InstallKind::PiExtension { extension_path } => { - let extension_path = Agent::path(extension_path); - - if let Some(parent) = extension_path.parent() { - std::fs::create_dir_all(parent)?; - } - - let already_installed = std::fs::read_to_string(&extension_path) - .is_ok_and(|existing| existing == PI_EXTENSION_SOURCE); - - if already_installed { - eprintln!("pi extension: already installed, skipping"); - } else { - std::fs::write(&extension_path, PI_EXTENSION_SOURCE)?; - eprintln!("pi extension: installed atuin extension"); - } - - eprintln!( - "\nAtuin extension installed for {}. Extension: {}\nReload pi with `/reload` or restart pi.", - agent.actor_name(), - extension_path.display() - ); - } - } - - Ok(()) -} - -fn add_hook_entries(hooks: &mut Value, agent: &Agent) -> Result<()> { - let InstallKind::JsonHooks { - config_path: _, - hook_command, - matcher, - } = agent.install_kind() - else { - bail!("agent does not use JSON hooks") - }; - - for event_type in HOOK_EVENT_TYPES { - let event_hooks = hooks - .as_object_mut() - .ok_or_else(|| eyre::eyre!("hooks is not a JSON object"))? - .entry(*event_type) - .or_insert_with(|| Value::Array(Vec::new())); - - let arr = event_hooks - .as_array_mut() - .ok_or_else(|| eyre::eyre!("hooks.{event_type} is not an array"))?; - - let already_installed = arr.iter().any(|entry| { - entry["hooks"].as_array().is_some_and(|h| { - h.iter() - .any(|hook| hook["command"].as_str() == Some(hook_command)) - }) - }); - - if already_installed { - eprintln!("hooks.{event_type}: already installed, skipping"); - continue; - } - - arr.push(serde_json::json!({ - "matcher": matcher, - "hooks": [{"type": "command", "command": hook_command}] - })); - eprintln!("hooks.{event_type}: installed atuin hook"); - } - - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - Atuin, - command::{AtuinCmd, client}, - }; - use clap::Parser; - - #[test] - fn parse_hook_agent_command() { - let cmd = Cmd::try_parse_from(["hook", "codex"]).unwrap(); - - assert!(matches!( - (cmd.action, cmd.agent.as_deref()), - (None, Some("codex")) - )); - } - - #[test] - fn parse_hook_install_command() { - let cmd = Cmd::try_parse_from(["hook", "install", "codex"]).unwrap(); - - match (cmd.action, cmd.agent) { - (Some(Action::Install { agent }), None) => assert_eq!(agent, "codex"), - other => panic!("unexpected parsed command: {other:?}"), - } - } - - #[test] - fn parse_hook_install_pi_command() { - let cmd = Cmd::try_parse_from(["hook", "install", "pi"]).unwrap(); - - match (cmd.action, cmd.agent) { - (Some(Action::Install { agent }), None) => assert_eq!(agent, "pi"), - other => panic!("unexpected parsed command: {other:?}"), - } - } - - #[test] - fn agent_from_name_supports_pi() { - let agent = Agent::from_name("pi").unwrap(); - assert_eq!(agent.actor_name(), "pi"); - assert!(matches!( - agent.install_kind(), - InstallKind::PiExtension { .. } - )); - } - - #[test] - fn parse_top_level_hook_command() { - let cmd = Atuin::try_parse_from(["atuin", "hook", "codex"]).unwrap(); - - assert!(matches!( - cmd.atuin, - AtuinCmd::Client(client::Cmd::Hook(Cmd { action: None, agent: Some(agent) })) - if agent == "codex" - )); - } - - #[test] - fn test_parse_pre_tool_use() { - let input = r#"{ - "hook_event_name": "PreToolUse", - "tool_name": "Bash", - "tool_input": {"command": "echo hello", "description": "Test greeting"}, - "tool_use_id": "toolu_abc123", - "session_id": "sess1", - "cwd": "/tmp" - }"#; - - match parse_hook_stdin(input).unwrap() { - HookEvent::Start { - command, - intent, - tool_use_id, - } => { - assert_eq!(command, "echo hello"); - assert_eq!(intent.as_deref(), Some("Test greeting")); - assert_eq!(tool_use_id, "toolu_abc123"); - } - _ => panic!("expected Start event"), - } - } - - #[test] - fn test_parse_post_tool_use() { - let input = r#"{ - "hook_event_name": "PostToolUse", - "tool_name": "Bash", - "tool_input": {"command": "echo hello"}, - "tool_response": {"exitCode": 0}, - "tool_use_id": "toolu_abc123" - }"#; - - match parse_hook_stdin(input).unwrap() { - HookEvent::End { tool_use_id, exit } => { - assert_eq!(tool_use_id, "toolu_abc123"); - assert_eq!(exit, 0); - } - _ => panic!("expected End event"), - } - } - - #[test] - fn test_parse_non_bash_tool_skipped() { - let input = r#"{ - "hook_event_name": "PreToolUse", - "tool_name": "Write", - "tool_input": {"file_path": "/tmp/test.txt", "content": "hello"}, - "tool_use_id": "toolu_abc123" - }"#; - - assert!(matches!(parse_hook_stdin(input).unwrap(), HookEvent::Skip)); - } - - #[test] - fn test_parse_failure_event() { - let input = r#"{ - "hook_event_name": "PostToolUseFailure", - "tool_name": "Bash", - "tool_input": {"command": "false"}, - "tool_use_id": "toolu_abc123" - }"#; - - match parse_hook_stdin(input).unwrap() { - HookEvent::End { exit, .. } => assert_eq!(exit, 1), - _ => panic!("expected End event"), - } - } -} diff --git a/crates/atuin/src/command/client/init.rs b/crates/atuin/src/command/client/init.rs index 798cc22b..98ef5c80 100644 --- a/crates/atuin/src/command/client/init.rs +++ b/crates/atuin/src/command/client/init.rs @@ -1,13 +1,6 @@ -use std::path::PathBuf; - -use atuin_client::{ - encryption, - record::sqlite_store::SqliteStore, - settings::{Settings, Tmux}, -}; -use atuin_dotfiles::store::{AliasStore, var::VarStore}; +use atuin_client::settings::{Settings, Tmux}; use clap::{Parser, ValueEnum}; -use eyre::{Result, WrapErr}; +use eyre::Result; mod bash; mod fish; @@ -101,17 +94,15 @@ $env.config = ( fn static_init(&self, settings: &Settings) { let tmux = &settings.tmux; - let disable_ai = self.disable_ai || matches!(settings.ai.enabled, Some(false)); - match self.shell { Shell::Zsh => { - zsh::init_static(self.disable_up_arrow, self.disable_ctrl_r, disable_ai, tmux); + zsh::init_static(self.disable_up_arrow, self.disable_ctrl_r, tmux); } Shell::Bash => { - bash::init_static(self.disable_up_arrow, self.disable_ctrl_r, disable_ai, tmux); + bash::init_static(self.disable_up_arrow, self.disable_ctrl_r, tmux); } Shell::Fish => { - fish::init_static(self.disable_up_arrow, self.disable_ctrl_r, disable_ai, tmux); + fish::init_static(self.disable_up_arrow, self.disable_ctrl_r, tmux); } Shell::Nu => { self.init_nu(tmux); @@ -126,73 +117,23 @@ $env.config = ( } async fn dotfiles_init(&self, settings: &Settings) -> Result<()> { - let record_store_path = PathBuf::from(settings.record_store_path.as_str()); - let sqlite_store = SqliteStore::new(record_store_path, settings.local_timeout).await?; - - let encryption_key: [u8; 32] = encryption::load_key(settings) - .context("could not load encryption key")? - .into(); - let host_id = Settings::host_id().await?; - - let alias_store = AliasStore::new(sqlite_store.clone(), host_id, encryption_key); - let var_store = VarStore::new(sqlite_store.clone(), host_id, encryption_key); - - let disable_ai = self.disable_ai || matches!(settings.ai.enabled, Some(false)); - match self.shell { Shell::Zsh => { - zsh::init( - alias_store, - var_store, - self.disable_up_arrow, - self.disable_ctrl_r, - disable_ai, - &settings.tmux, - ) - .await?; + zsh::init(self.disable_up_arrow, self.disable_ctrl_r, &settings.tmux).await?; } Shell::Bash => { - bash::init( - alias_store, - var_store, - self.disable_up_arrow, - self.disable_ctrl_r, - disable_ai, - &settings.tmux, - ) - .await?; + bash::init(self.disable_up_arrow, self.disable_ctrl_r, &settings.tmux).await?; } Shell::Fish => { - fish::init( - alias_store, - var_store, - self.disable_up_arrow, - self.disable_ctrl_r, - disable_ai, - &settings.tmux, - ) - .await?; + fish::init(self.disable_up_arrow, self.disable_ctrl_r, &settings.tmux).await?; } Shell::Nu => self.init_nu(&settings.tmux), Shell::Xonsh => { - xonsh::init( - alias_store, - var_store, - self.disable_up_arrow, - self.disable_ctrl_r, - &settings.tmux, - ) - .await?; + xonsh::init(self.disable_up_arrow, self.disable_ctrl_r, &settings.tmux).await?; } Shell::PowerShell => { - powershell::init( - alias_store, - var_store, - self.disable_up_arrow, - self.disable_ctrl_r, - &settings.tmux, - ) - .await?; + powershell::init(self.disable_up_arrow, self.disable_ctrl_r, &settings.tmux) + .await?; } } diff --git a/crates/atuin/src/command/client/init/bash.rs b/crates/atuin/src/command/client/init/bash.rs index 745c239a..7fe57b33 100644 --- a/crates/atuin/src/command/client/init/bash.rs +++ b/crates/atuin/src/command/client/init/bash.rs @@ -1,5 +1,4 @@ use atuin_client::settings::Tmux; -use atuin_dotfiles::store::{AliasStore, var::VarStore}; use eyre::Result; fn print_tmux_config(tmux: &Tmux) { @@ -11,7 +10,7 @@ fn print_tmux_config(tmux: &Tmux) { } } -pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, disable_ai: bool, tmux: &Tmux) { +pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, tmux: &Tmux) { let base = include_str!("../../../shell/atuin.bash"); let (bind_ctrl_r, bind_up_arrow) = if std::env::var("ATUIN_NOBIND").is_ok() { @@ -24,29 +23,10 @@ pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, disable_ai: boo println!("__atuin_bind_ctrl_r={bind_ctrl_r}"); println!("__atuin_bind_up_arrow={bind_up_arrow}"); println!("{base}"); - - #[cfg(feature = "ai")] - if !disable_ai { - let bind_ai = atuin_ai::commands::init::generate_bash_integration(); - println!("{bind_ai}"); - } } -pub async fn init( - aliases: AliasStore, - vars: VarStore, - disable_up_arrow: bool, - disable_ctrl_r: bool, - disable_ai: bool, - tmux: &Tmux, -) -> Result<()> { - init_static(disable_up_arrow, disable_ctrl_r, disable_ai, tmux); - - let aliases = atuin_dotfiles::shell::bash::alias_config(&aliases).await; - let vars = atuin_dotfiles::shell::bash::var_config(&vars).await; - - println!("{aliases}"); - println!("{vars}"); +pub async fn init(disable_up_arrow: bool, disable_ctrl_r: bool, tmux: &Tmux) -> Result<()> { + init_static(disable_up_arrow, disable_ctrl_r, tmux); Ok(()) } diff --git a/crates/atuin/src/command/client/init/fish.rs b/crates/atuin/src/command/client/init/fish.rs index 6d6c8c23..e477faed 100644 --- a/crates/atuin/src/command/client/init/fish.rs +++ b/crates/atuin/src/command/client/init/fish.rs @@ -1,5 +1,4 @@ use atuin_client::settings::Tmux; -use atuin_dotfiles::store::{AliasStore, var::VarStore}; use eyre::Result; fn print_tmux_config(tmux: &Tmux) { @@ -37,7 +36,7 @@ fn print_bindings( println!("{indent}end"); } -pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, disable_ai: bool, tmux: &Tmux) { +pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, tmux: &Tmux) { let indent = " ".repeat(4); let base = include_str!("../../../shell/atuin.fish"); @@ -84,30 +83,15 @@ pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, disable_ai: boo ); println!("end"); - - #[cfg(feature = "ai")] - if !disable_ai { - let bind_ai = atuin_ai::commands::init::generate_fish_integration(); - println!("{bind_ai}"); - } } } pub async fn init( - aliases: AliasStore, - vars: VarStore, disable_up_arrow: bool, disable_ctrl_r: bool, - disable_ai: bool, tmux: &Tmux, ) -> Result<()> { - init_static(disable_up_arrow, disable_ctrl_r, disable_ai, tmux); - - let aliases = atuin_dotfiles::shell::fish::alias_config(&aliases).await; - let vars = atuin_dotfiles::shell::fish::var_config(&vars).await; - - println!("{aliases}"); - println!("{vars}"); + init_static(disable_up_arrow, disable_ctrl_r, tmux); Ok(()) } diff --git a/crates/atuin/src/command/client/init/powershell.rs b/crates/atuin/src/command/client/init/powershell.rs index d3399404..a36b8e67 100644 --- a/crates/atuin/src/command/client/init/powershell.rs +++ b/crates/atuin/src/command/client/init/powershell.rs @@ -1,5 +1,4 @@ use atuin_client::settings::Tmux; -use atuin_dotfiles::store::{AliasStore, var::VarStore}; pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, _tmux: &Tmux) { let base = include_str!("../../../shell/atuin.ps1"); @@ -19,21 +18,9 @@ pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, _tmux: &Tmux) { ); } -pub async fn init( - aliases: AliasStore, - vars: VarStore, - disable_up_arrow: bool, - disable_ctrl_r: bool, - tmux: &Tmux, -) -> eyre::Result<()> { +pub async fn init(disable_up_arrow: bool, disable_ctrl_r: bool, tmux: &Tmux) -> eyre::Result<()> { init_static(disable_up_arrow, disable_ctrl_r, tmux); - let aliases = atuin_dotfiles::shell::powershell::alias_config(&aliases).await; - let vars = atuin_dotfiles::shell::powershell::var_config(&vars).await; - - println!("{aliases}"); - println!("{vars}"); - Ok(()) } diff --git a/crates/atuin/src/command/client/init/xonsh.rs b/crates/atuin/src/command/client/init/xonsh.rs index 8b9f1595..f14da3d8 100644 --- a/crates/atuin/src/command/client/init/xonsh.rs +++ b/crates/atuin/src/command/client/init/xonsh.rs @@ -1,5 +1,4 @@ use atuin_client::settings::Tmux; -use atuin_dotfiles::store::{AliasStore, var::VarStore}; use eyre::Result; pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, _tmux: &Tmux) { @@ -23,20 +22,8 @@ pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, _tmux: &Tmux) { println!("{base}"); } -pub async fn init( - aliases: AliasStore, - vars: VarStore, - disable_up_arrow: bool, - disable_ctrl_r: bool, - tmux: &Tmux, -) -> Result<()> { +pub async fn init(disable_up_arrow: bool, disable_ctrl_r: bool, tmux: &Tmux) -> Result<()> { init_static(disable_up_arrow, disable_ctrl_r, tmux); - let aliases = atuin_dotfiles::shell::xonsh::alias_config(&aliases).await; - let vars = atuin_dotfiles::shell::xonsh::var_config(&vars).await; - - println!("{aliases}"); - println!("{vars}"); - Ok(()) } diff --git a/crates/atuin/src/command/client/init/zsh.rs b/crates/atuin/src/command/client/init/zsh.rs index 5d588aa0..392e987c 100644 --- a/crates/atuin/src/command/client/init/zsh.rs +++ b/crates/atuin/src/command/client/init/zsh.rs @@ -1,5 +1,4 @@ use atuin_client::settings::Tmux; -use atuin_dotfiles::store::{AliasStore, var::VarStore}; use eyre::Result; fn print_tmux_config(tmux: &Tmux) { @@ -11,7 +10,7 @@ fn print_tmux_config(tmux: &Tmux) { } } -pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, disable_ai: bool, tmux: &Tmux) { +pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, tmux: &Tmux) { let base = include_str!("../../../shell/atuin.zsh"); print_tmux_config(tmux); @@ -36,31 +35,15 @@ bindkey -M vicmd 'k' atuin-up-search-vicmd"; if !disable_up_arrow { println!("{BIND_UP_ARROW}"); } - - #[cfg(feature = "ai")] - if !disable_ai { - let bind_ai = atuin_ai::commands::init::generate_zsh_integration(); - - println!("{bind_ai}"); - } } } pub async fn init( - aliases: AliasStore, - vars: VarStore, disable_up_arrow: bool, disable_ctrl_r: bool, - disable_ai: bool, tmux: &Tmux, ) -> Result<()> { - init_static(disable_up_arrow, disable_ctrl_r, disable_ai, tmux); - - let aliases = atuin_dotfiles::shell::zsh::alias_config(&aliases).await; - let vars = atuin_dotfiles::shell::zsh::var_config(&vars).await; - - println!("{aliases}"); - println!("{vars}"); + init_static(disable_up_arrow, disable_ctrl_r, tmux); Ok(()) } diff --git a/crates/atuin/src/command/client/scripts.rs b/crates/atuin/src/command/client/scripts.rs deleted file mode 100644 index e5adacc4..00000000 --- a/crates/atuin/src/command/client/scripts.rs +++ /dev/null @@ -1,590 +0,0 @@ -use std::collections::HashMap; -use std::collections::HashSet; -use std::io::IsTerminal; -use std::io::Read; -use std::path::PathBuf; - -use atuin_scripts::execution::template_script; -use atuin_scripts::{ - execution::{build_executable_script, execute_script_interactive, template_variables}, - store::{ScriptStore, script::Script}, -}; -use clap::{Parser, Subcommand}; -use eyre::OptionExt; -use eyre::{Result, bail}; -use tempfile::NamedTempFile; - -use atuin_client::{database::Database, record::sqlite_store::SqliteStore, settings::Settings}; -use tracing::debug; - -#[derive(Parser, Debug)] -pub struct NewScript { - pub name: String, - - #[arg(short, long)] - pub description: Option<String>, - - #[arg(short, long)] - pub tags: Vec<String>, - - #[arg(short, long)] - pub shebang: Option<String>, - - #[arg(long)] - pub script: Option<PathBuf>, - - #[allow(clippy::option_option)] - #[arg(long)] - /// Use the last command as the script content - /// Optionally specify a number to use the last N commands - pub last: Option<Option<usize>>, - - #[arg(long)] - /// Skip opening editor when using --last - pub no_edit: bool, -} - -#[derive(Parser, Debug)] -pub struct Run { - pub name: String, - - /// Specify template variables in the format KEY=VALUE - /// Example: -v name=John -v greeting="Hello there" - #[arg(short, long = "var")] - pub var: Vec<String>, -} - -#[derive(Parser, Debug)] -pub struct List {} - -#[derive(Parser, Debug)] -pub struct Get { - pub name: String, - - #[arg(short, long)] - /// Display only the executable script with shebang - pub script: bool, -} - -#[derive(Parser, Debug)] -pub struct Edit { - pub name: String, - - #[arg(short, long)] - pub description: Option<String>, - - /// Replace all existing tags with these new tags - #[arg(short, long)] - pub tags: Vec<String>, - - /// Remove all tags from the script - #[arg(long)] - pub no_tags: bool, - - /// Rename the script - #[arg(long)] - pub rename: Option<String>, - - #[arg(short, long)] - pub shebang: Option<String>, - - #[arg(long)] - pub script: Option<PathBuf>, - - #[allow(clippy::struct_field_names)] - /// Skip opening editor - #[arg(long)] - pub no_edit: bool, -} - -#[derive(Parser, Debug)] -pub struct Delete { - pub name: String, - - #[arg(short, long)] - pub force: bool, -} - -#[derive(Subcommand, Debug)] -#[command(infer_subcommands = true)] -pub enum Cmd { - New(NewScript), - Run(Run), - #[command(alias = "ls")] - List(List), - - Get(Get), - Edit(Edit), - #[command(alias = "rm")] - Delete(Delete), -} - -impl Cmd { - // Helper function to open an editor with optional initial content - fn open_editor(initial_content: Option<&str>) -> Result<String> { - // Create a temporary file - let temp_file = NamedTempFile::new()?; - let path = temp_file.into_temp_path(); - - // Write initial content to the temp file if provided - if let Some(content) = initial_content { - std::fs::write(&path, content)?; - } - - // Open the file in the user's preferred editor - let editor_str = std::env::var("EDITOR").unwrap_or_else(|_| "vi".to_string()); - - // Use shlex to safely split the string into shell-like parts. - let parts = shlex::split(&editor_str).ok_or_eyre("Failed to parse editor command")?; - let (command, args) = parts.split_first().ok_or_eyre("No editor command found")?; - - let status = std::process::Command::new(command) - .args(args) - .arg(&path) - .status()?; - if !status.success() { - bail!("failed to open editor"); - } - - // Read back the edited content - let content = std::fs::read_to_string(&path)?; - path.close()?; - - Ok(content) - } - - // Helper function to execute a script and manage stdin/stdout/stderr - async fn execute_script(script_content: String, shebang: String) -> Result<i32> { - let mut session = execute_script_interactive(script_content, shebang) - .await - .expect("failed to execute script"); - - // Create a channel to signal when the process exits - let (exit_tx, mut exit_rx) = tokio::sync::oneshot::channel(); - - // Set up a task to read from stdin and forward to the script - let sender = session.stdin_tx.clone(); - let stdin_task = tokio::spawn(async move { - use tokio::io::AsyncReadExt; - use tokio::select; - - let stdin = tokio::io::stdin(); - let mut reader = tokio::io::BufReader::new(stdin); - let mut buffer = vec![0u8; 1024]; // Read in chunks for efficiency - - loop { - // Use select to either read from stdin or detect when the process exits - select! { - // Check if the script process has exited - _ = &mut exit_rx => { - break; - } - // Try to read from stdin - read_result = reader.read(&mut buffer) => { - match read_result { - Ok(0) => break, // EOF - Ok(n) => { - // Convert the bytes to a string and forward to script - let input = String::from_utf8_lossy(&buffer[0..n]).to_string(); - if let Err(e) = sender.send(input).await { - eprintln!("Error sending input to script: {e}"); - break; - } - }, - Err(e) => { - eprintln!("Error reading from stdin: {e}"); - break; - } - } - } - } - } - }); - - // Wait for the script to complete - let exit_code = session.wait_for_exit().await; - - // Signal the stdin task to stop - let _ = exit_tx.send(()); - let _ = stdin_task.await; - - let code = exit_code.unwrap_or(-1); - if code != 0 { - eprintln!("Script exited with code {code}"); - } - - Ok(code) - } - - async fn handle_new_script( - settings: &Settings, - new_script: NewScript, - script_store: ScriptStore, - script_db: atuin_scripts::database::Database, - history_db: &impl Database, - ) -> Result<()> { - let mut stdin = std::io::stdin(); - let script_content = if let Some(count_opt) = new_script.last { - // Get the last N commands from history, plus 1 to exclude the command that runs this script - let count = count_opt.unwrap_or(1) + 1; // Add 1 to the count to exclude the current command - let context = atuin_client::database::current_context().await?; - - // Get the last N+1 commands, filtering by the default mode - let filters = [settings.default_filter_mode(context.git_root.is_some())]; - - let mut history = history_db - .list(&filters, &context, Some(count), false, false) - .await?; - - // Reverse to get chronological order - history.reverse(); - - // Skip the most recent command (which would be the atuin scripts new command itself) - if !history.is_empty() { - history.pop(); // Remove the most recent command - } - - // Format the commands into a script - let commands: Vec<String> = history.iter().map(|h| h.command.clone()).collect(); - - if commands.is_empty() { - bail!("No commands found in history"); - } - - let script_text = commands.join("\n"); - - // Only open editor if --no-edit is not specified - if new_script.no_edit { - Some(script_text) - } else { - // Open the editor with the commands pre-loaded - Some(Self::open_editor(Some(&script_text))?) - } - } else if let Some(script_path) = new_script.script { - let script_content = std::fs::read_to_string(script_path)?; - Some(script_content) - } else if !stdin.is_terminal() { - let mut buffer = String::new(); - stdin.read_to_string(&mut buffer)?; - Some(buffer) - } else { - // Open editor with empty file - Some(Self::open_editor(None)?) - }; - - let script = Script::builder() - .name(new_script.name) - .description(new_script.description.unwrap_or_default()) - .shebang(new_script.shebang.unwrap_or_default()) - .tags(new_script.tags) - .script(script_content.unwrap_or_default()) - .build(); - - script_store.create(script).await?; - - script_store.build(script_db).await?; - - Ok(()) - } - - async fn handle_run( - _settings: &Settings, - run: Run, - script_db: atuin_scripts::database::Database, - ) -> Result<()> { - let script = script_db.get_by_name(&run.name).await?; - - if let Some(script) = script { - // Get variables used in the template - let variables = template_variables(&script)?; - - // Create a hashmap to store variable values - let mut variable_values: HashMap<String, serde_json::Value> = HashMap::new(); - - // Parse variables from command-line arguments first - for var_str in &run.var { - if let Some((key, value)) = var_str.split_once('=') { - // Add to variable values - variable_values.insert( - key.to_string(), - serde_json::Value::String(value.to_string()), - ); - debug!("Using CLI variable: {}={}", key, value); - } else { - eprintln!("Warning: Ignoring malformed variable specification: {var_str}"); - eprintln!("Variables should be specified as KEY=VALUE"); - } - } - - // Collect variables that are still needed (not specified via CLI) - let remaining_vars: HashSet<String> = variables - .into_iter() - .filter(|var| !variable_values.contains_key(var)) - .collect(); - - // If there are variables in the template that weren't specified on the command line, prompt for them - if !remaining_vars.is_empty() { - println!("This script contains template variables that need values:"); - - let stdin = std::io::stdin(); - let mut input = String::new(); - - for var in remaining_vars { - input.clear(); - - println!("Enter value for '{var}': "); - - if stdin.read_line(&mut input).is_err() { - eprintln!("Failed to read input for variable '{var}'"); - // Provide an empty string as fallback - variable_values.insert(var, serde_json::Value::String(String::new())); - continue; - } - - let value = input.trim().to_string(); - variable_values.insert(var, serde_json::Value::String(value)); - } - } - - let final_script = if variable_values.is_empty() { - // No variables to template, just use the original script - script.script.clone() - } else { - // If we have variables, we need to template the script - debug!("Templating script with variables: {:?}", variable_values); - template_script(&script, &variable_values)? - }; - - // Execute the script (either templated or original) - Self::execute_script(final_script, script.shebang.clone()).await?; - } else { - bail!("script not found"); - } - Ok(()) - } - - async fn handle_list( - _settings: &Settings, - _list: List, - script_db: atuin_scripts::database::Database, - ) -> Result<()> { - let scripts = script_db.list().await?; - - if scripts.is_empty() { - println!("No scripts found"); - } else { - println!("Available scripts:"); - for script in scripts { - if script.tags.is_empty() { - println!("- {} ", script.name); - } else { - println!("- {} [tags: {}]", script.name, script.tags.join(", ")); - } - - // Print description if it's not empty - if !script.description.is_empty() { - println!(" Description: {}", script.description); - } - } - } - - Ok(()) - } - - async fn handle_get( - _settings: &Settings, - get: Get, - script_db: atuin_scripts::database::Database, - ) -> Result<()> { - let script = script_db.get_by_name(&get.name).await?; - - if let Some(script) = script { - if get.script { - // Just print the executable script with shebang - print!( - "{}", - build_executable_script(script.script.clone(), script.shebang) - ); - return Ok(()); - } - - // Create a YAML representation of the script - println!("---"); - println!("name: {}", script.name); - println!("id: {}", script.id); - - if script.description.is_empty() { - println!("description: \"\""); - } else { - println!("description: |"); - // Indent multiline descriptions properly for YAML - for line in script.description.lines() { - println!(" {line}"); - } - } - - if script.tags.is_empty() { - println!("tags: []"); - } else { - println!("tags:"); - for tag in &script.tags { - println!(" - {tag}"); - } - } - - println!("shebang: {}", script.shebang); - - println!("script: |"); - // Indent the script content for proper YAML multiline format - for line in script.script.lines() { - println!(" {line}"); - } - - Ok(()) - } else { - bail!("script '{}' not found", get.name); - } - } - - #[allow(clippy::cognitive_complexity)] - async fn handle_edit( - _settings: &Settings, - edit: Edit, - script_store: ScriptStore, - script_db: atuin_scripts::database::Database, - ) -> Result<()> { - debug!("editing script {:?}", edit); - // Find the existing script - let existing_script = script_db.get_by_name(&edit.name).await?; - debug!("existing script {:?}", existing_script); - - if let Some(mut script) = existing_script { - // Update the script with new values if provided - if let Some(description) = edit.description { - script.description = description; - } - - // Handle renaming if requested - if let Some(new_name) = edit.rename { - // Check if a script with the new name already exists - if (script_db.get_by_name(&new_name).await?).is_some() { - bail!("A script named '{}' already exists", new_name); - } - - // Update the name - script.name = new_name; - } - - // Handle tag updates with priority: - // 1. If --no-tags is provided, clear all tags - // 2. If --tags is provided, replace all tags - // 3. If neither is provided, tags remain unchanged - if edit.no_tags { - // Clear all tags - script.tags.clear(); - } else if !edit.tags.is_empty() { - // Replace all tags - script.tags = edit.tags; - } - // If none of the above conditions are met, tags remain unchanged - - if let Some(shebang) = edit.shebang { - script.shebang = shebang; - } - - // Handle script content update - let script_content = if let Some(script_path) = edit.script { - // Load script from provided file - std::fs::read_to_string(script_path)? - } else if !edit.no_edit { - // Open the script in editor for interactive editing if --no-edit is not specified - Self::open_editor(Some(&script.script))? - } else { - // If --no-edit is specified, keep the existing script content - script.script.clone() - }; - - // Update the script content - script.script = script_content; - - // Update the script in the store - script_store.update(script).await?; - - // Rebuild the database to apply changes - script_store.build(script_db).await?; - - println!("Script '{}' updated successfully!", edit.name); - - Ok(()) - } else { - bail!("script '{}' not found", edit.name); - } - } - - async fn handle_delete( - _settings: &Settings, - delete: Delete, - script_store: ScriptStore, - script_db: atuin_scripts::database::Database, - ) -> Result<()> { - // Find the script by name - let script = script_db.get_by_name(&delete.name).await?; - - if let Some(script) = script { - // If not force, confirm deletion - if !delete.force { - println!( - "Are you sure you want to delete script '{}'? [y/N]", - delete.name - ); - let mut input = String::new(); - std::io::stdin().read_line(&mut input)?; - - let input = input.trim().to_lowercase(); - if input != "y" && input != "yes" { - println!("Deletion cancelled"); - return Ok(()); - } - } - - // Delete the script - script_store.delete(script.id).await?; - - // Rebuild the database to apply changes - script_store.build(script_db).await?; - - println!("Script '{}' deleted successfully", delete.name); - Ok(()) - } else { - bail!("script '{}' not found", delete.name); - } - } - - pub async fn run( - self, - settings: &Settings, - store: SqliteStore, - history_db: &impl Database, - ) -> Result<()> { - let host_id = Settings::host_id().await?; - let encryption_key: [u8; 32] = atuin_client::encryption::load_key(settings)?.into(); - - let script_store = ScriptStore::new(store, host_id, encryption_key); - let script_db = - atuin_scripts::database::Database::new(settings.scripts.db_path.clone(), 1.0).await?; - - match self { - Self::New(new_script) => { - Self::handle_new_script(settings, new_script, script_store, script_db, history_db) - .await - } - Self::Run(run) => Self::handle_run(settings, run, script_db).await, - Self::List(list) => Self::handle_list(settings, list, script_db).await, - Self::Get(get) => Self::handle_get(settings, get, script_db).await, - Self::Edit(edit) => Self::handle_edit(settings, edit, script_store, script_db).await, - Self::Delete(delete) => { - Self::handle_delete(settings, delete, script_store, script_db).await - } - } - } -} diff --git a/crates/atuin/src/command/client/store/rebuild.rs b/crates/atuin/src/command/client/store/rebuild.rs index 8b334ced..b9f2837b 100644 --- a/crates/atuin/src/command/client/store/rebuild.rs +++ b/crates/atuin/src/command/client/store/rebuild.rs @@ -1,5 +1,3 @@ -use atuin_dotfiles::store::{AliasStore, var::VarStore}; -use atuin_scripts::store::ScriptStore; use clap::Args; use eyre::{Result, bail}; @@ -33,14 +31,6 @@ impl Rebuild { .await?; } - "dotfiles" => { - self.rebuild_dotfiles(settings, store.clone()).await?; - } - - "scripts" => { - self.rebuild_scripts(settings, store.clone()).await?; - } - tag => bail!("unknown tag: {tag}"), } @@ -65,30 +55,4 @@ impl Rebuild { Ok(()) } - - async fn rebuild_dotfiles(&self, settings: &Settings, store: SqliteStore) -> Result<()> { - let encryption_key: [u8; 32] = encryption::load_key(settings)?.into(); - - let host_id = Settings::host_id().await?; - - let alias_store = AliasStore::new(store.clone(), host_id, encryption_key); - let var_store = VarStore::new(store.clone(), host_id, encryption_key); - - alias_store.build().await?; - var_store.build().await?; - - Ok(()) - } - - async fn rebuild_scripts(&self, settings: &Settings, store: SqliteStore) -> Result<()> { - let encryption_key: [u8; 32] = encryption::load_key(settings)?.into(); - let host_id = Settings::host_id().await?; - let script_store = ScriptStore::new(store, host_id, encryption_key); - let database = - atuin_scripts::database::Database::new(settings.scripts.db_path.clone(), 1.0).await?; - - script_store.build(database).await?; - - Ok(()) - } } diff --git a/crates/atuin/src/command/client/wrapped.rs b/crates/atuin/src/command/client/wrapped.rs index 8d2b5e51..82b4cd5b 100644 --- a/crates/atuin/src/command/client/wrapped.rs +++ b/crates/atuin/src/command/client/wrapped.rs @@ -3,11 +3,7 @@ use eyre::Result; use std::collections::{HashMap, HashSet}; use time::{Date, Duration, Month, OffsetDateTime, Time}; -use atuin_client::{ - database::Database, encryption, record::sqlite_store::SqliteStore, settings::Settings, - theme::Theme, -}; -use atuin_dotfiles::store::AliasStore; +use atuin_client::{database::Database, settings::Settings, theme::Theme}; use atuin_history::stats::{Stats, compute}; @@ -24,26 +20,7 @@ struct WrappedStats { impl WrappedStats { #[allow(clippy::too_many_lines, clippy::cast_precision_loss)] - fn new( - settings: &Settings, - stats: &Stats, - history: &[atuin_client::history::History], - alias_map: &HashMap<String, String>, - ) -> Self { - // Helper to expand alias to its first command word - let expand_alias = |cmd: &str| -> String { - alias_map.get(cmd).map_or_else( - || cmd.to_string(), - |expanded| { - expanded - .split_whitespace() - .next() - .unwrap_or(cmd) - .to_string() - }, - ) - }; - + fn new(settings: &Settings, stats: &Stats, history: &[atuin_client::history::History]) -> Self { let nav_commands = stats .top .iter() @@ -119,13 +96,12 @@ impl WrappedStats { let mut hours: HashMap<String, usize> = HashMap::new(); for entry in history { - let raw_cmd = entry + let cmd = entry .command .split_whitespace() .next() .unwrap_or("") .to_string(); - let cmd = expand_alias(&raw_cmd); let (total, errors) = command_errors.entry(cmd.clone()).or_insert((0, 0)); *total += 1; if entry.exit != 0 { @@ -290,7 +266,6 @@ pub async fn run( year: Option<i32>, db: &impl Database, settings: &Settings, - store: SqliteStore, theme: &Theme, ) -> Result<()> { let now = OffsetDateTime::now_utc().to_offset(settings.timezone.0); @@ -324,30 +299,9 @@ pub async fn run( return Ok(()); } - // Load aliases for expansion - let alias_map: HashMap<String, String> = if settings.dotfiles.enabled { - if let Ok(encryption_key) = encryption::load_key(settings) { - let encryption_key: [u8; 32] = encryption_key.into(); - let host_id = Settings::host_id().await?; - let alias_store = AliasStore::new(store, host_id, encryption_key); - - alias_store - .aliases() - .await - .unwrap_or_default() - .into_iter() - .map(|a| (a.name, a.value)) - .collect() - } else { - HashMap::new() - } - } else { - HashMap::new() - }; - // Compute overall stats using existing functionality let stats = compute(settings, &history, 10, 1).expect("Failed to compute stats"); - let wrapped_stats = WrappedStats::new(settings, &stats, &history, &alias_map); + let wrapped_stats = WrappedStats::new(settings, &stats, &history); // Print wrapped format print_wrapped_header(year); diff --git a/crates/atuin/src/sync.rs b/crates/atuin/src/sync.rs index 26004130..14982300 100644 --- a/crates/atuin/src/sync.rs +++ b/crates/atuin/src/sync.rs @@ -1,5 +1,3 @@ -use atuin_dotfiles::store::{AliasStore, var::VarStore}; -use atuin_scripts::store::ScriptStore; use eyre::{Context, Result}; use atuin_client::{ @@ -32,19 +30,11 @@ pub async fn build( let kv_db = atuin_kv::database::Database::new(settings.kv.db_path.clone(), 1.0).await?; let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); - let alias_store = AliasStore::new(store.clone(), host_id, encryption_key); - let var_store = VarStore::new(store.clone(), host_id, encryption_key); let kv_store = KvStore::new(store.clone(), kv_db, host_id, encryption_key); - let script_store = ScriptStore::new(store.clone(), host_id, encryption_key); history_store.incremental_build(db, downloaded).await?; - alias_store.build().await?; - var_store.build().await?; kv_store.build().await?; - let script_db = - atuin_scripts::database::Database::new(settings.scripts.db_path.clone(), 1.0).await?; - script_store.build(script_db).await?; Ok(()) } @@ -1,115 +1,25 @@ { "nodes": { - "fenix": { - "inputs": { - "nixpkgs": [ - "nixpkgs" - ], - "rust-analyzer-src": "rust-analyzer-src" - }, - "locked": { - "lastModified": 1758609765, - "narHash": "sha256-VIYu7R9Yc/CItjmzLSm21Lr9DgpEsKL5H+JUu8KDTn4=", - "owner": "nix-community", - "repo": "fenix", - "rev": "05545a7f3cd5cd5628b195520758e56e6734b90a", - "type": "github" - }, - "original": { - "owner": "nix-community", - "repo": "fenix", - "type": "github" - } - }, - "flake-compat": { - "flake": false, - "locked": { - "lastModified": 1747046372, - "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", - "owner": "edolstra", - "repo": "flake-compat", - "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", - "type": "github" - }, - "original": { - "owner": "edolstra", - "repo": "flake-compat", - "type": "github" - } - }, - "flake-utils": { - "inputs": { - "systems": "systems" - }, - "locked": { - "lastModified": 1731533236, - "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", - "owner": "numtide", - "repo": "flake-utils", - "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", - "type": "github" - }, - "original": { - "owner": "numtide", - "repo": "flake-utils", - "type": "github" - } - }, "nixpkgs": { "locked": { - "lastModified": 1758446476, - "narHash": "sha256-5rdAi7CTvM/kSs6fHe1bREIva5W3TbImsto+dxG4mBo=", + "lastModified": 1781074563, + "narHash": "sha256-md8WlXOlfnIeHeOScMTTHFyf2d6iaTwPl2apR5EQ3P4=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "a1f79a1770d05af18111fbbe2a3ab2c42c0f6cd0", + "rev": "9ae611a455b90cf061d8f332b977e387bda8e1ca", "type": "github" }, "original": { "owner": "NixOS", - "ref": "nixpkgs-unstable", + "ref": "nixos-unstable-small", "repo": "nixpkgs", "type": "github" } }, "root": { "inputs": { - "fenix": "fenix", - "flake-compat": "flake-compat", - "flake-utils": "flake-utils", "nixpkgs": "nixpkgs" } - }, - "rust-analyzer-src": { - "flake": false, - "locked": { - "lastModified": 1758556272, - "narHash": "sha256-9amq6LAd0CFF3dLrJUItPiG64MQOG4QPrvjbjpa6NFc=", - "owner": "rust-lang", - "repo": "rust-analyzer", - "rev": "d05355db16dc526bb16bd84769ea840668d7015e", - "type": "github" - }, - "original": { - "owner": "rust-lang", - "ref": "nightly", - "repo": "rust-analyzer", - "type": "github" - } - }, - "systems": { - "locked": { - "lastModified": 1681028828, - "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", - "owner": "nix-systems", - "repo": "default", - "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", - "type": "github" - }, - "original": { - "owner": "nix-systems", - "repo": "default", - "type": "github" - } } }, "root": "root", @@ -1,74 +1,51 @@ { inputs = { - nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable"; - flake-utils.url = "github:numtide/flake-utils"; - flake-compat = { - url = "github:edolstra/flake-compat"; - flake = false; - }; - fenix = { - url = "github:nix-community/fenix"; - inputs.nixpkgs.follows = "nixpkgs"; - }; + nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable-small"; }; - outputs = - { self - , nixpkgs - , flake-utils - , fenix - , ... - }: - flake-utils.lib.eachDefaultSystem - (system: - let - pkgs = nixpkgs.outputs.legacyPackages.${system}; - in - { - packages.atuin = pkgs.callPackage ./atuin.nix { - rustPlatform = - let - toolchain = - fenix.packages.${system}.fromToolchainFile - { - file = ./rust-toolchain.toml; - sha256 = "sha256-mvUGEOHYJpn3ikC5hckneuGixaC+yGrkMM/liDIDgoU="; - }; - in - pkgs.makeRustPlatform { - cargo = toolchain; - rustc = toolchain; - }; - }; - packages.default = self.outputs.packages.${system}.atuin; - devShells.default = self.packages.${system}.default.overrideAttrs (super: { - nativeBuildInputs = with pkgs; - super.nativeBuildInputs - ++ [ - cargo-edit - clippy - rustfmt - ]; - RUST_SRC_PATH = "${pkgs.rustPlatform.rustLibSrc}"; + outputs = { + self, + nixpkgs, + ... + }: let + system = "x86_64-linux"; + pkgs = nixpkgs.outputs.legacyPackages.${system}; + in { + packages."${system}" = { + atuin = pkgs.callPackage ./atuin.nix {}; + default = self.outputs.packages.${system}.atuin; + }; - shellHook = '' - echo >&2 "Setting development database path" - export ATUIN_DB_PATH="/tmp/atuin_dev.db" - export ATUIN_RECORD_STORE_PATH="/tmp/atuin_records.db" + devShells."${system}".default = self.packages.${system}.default.overrideAttrs (super: { + nativeBuildInputs = + super.nativeBuildInputs + ++ [ + # rust stuff + pkgs.cargo + pkgs.clippy + pkgs.rustc + pkgs.rustfmt + pkgs.mold - if [ -e "''${ATUIN_DB_PATH}" ]; then - echo >&2 "''${ATUIN_DB_PATH} already exists, you might want to double-check that" - fi + pkgs.cargo-edit + pkgs.cargo-expand + pkgs.cargo-flamegraph + ]; + RUST_SRC_PATH = "${pkgs.rustPlatform.rustLibSrc}"; - if [ -e "''${ATUIN_RECORD_STORE_PATH}" ]; then - echo >&2 "''${ATUIN_RECORD_STORE_PATH} already exists, you might want to double-check that" - fi - ''; - }); - }) - // { - overlays.default = final: prev: { - inherit (self.packages.${final.stdenv.hostPlatform.system}) atuin; - }; - }; + shellHook = '' + echo >&2 "Setting development database path" + export ATUIN_DB_PATH="/tmp/atuin_dev.db" + export ATUIN_RECORD_STORE_PATH="/tmp/atuin_records.db" + + if [ -e "''${ATUIN_DB_PATH}" ]; then + echo >&2 "''${ATUIN_DB_PATH} already exists, you might want to double-check that" + fi + + if [ -e "''${ATUIN_RECORD_STORE_PATH}" ]; then + echo >&2 "''${ATUIN_RECORD_STORE_PATH} already exists, you might want to double-check that" + fi + ''; + }); + }; } |
