From 5c39e7cf284a1f6e9a1657f2deb44e359fc47eb8 Mon Sep 17 00:00:00 2001 From: Benedikt Peetz Date: Thu, 11 Jun 2026 00:54:30 +0200 Subject: chore: Move everything into one big crate That helps remove duplicated code and rustc/cargo will now also show dead code correctly. --- Cargo.lock | 284 +- Cargo.toml | 66 +- crates/atuin-client/Cargo.toml | 82 - crates/atuin-client/config.toml | 371 --- .../meta-migrations/20260203030924_create_meta.sql | 5 - .../migrations/20210422143411_create_history.sql | 16 - .../migrations/20220505083406_create-events.sql | 11 - .../20220806155627_interactive_search_index.sql | 6 - .../migrations/20230315220114_drop-events.sql | 2 - .../migrations/20230319185725_deleted_at.sql | 2 - .../20260224000100_history_author_intent.sql | 2 - .../20230531212437_create-records.sql | 16 - .../20231127090831_create-store.sql | 15 - crates/atuin-client/src/api_client.rs | 437 --- crates/atuin-client/src/auth.rs | 230 -- crates/atuin-client/src/database.rs | 1525 ---------- crates/atuin-client/src/distro.rs | 89 - crates/atuin-client/src/encryption.rs | 440 --- crates/atuin-client/src/history.rs | 756 ----- crates/atuin-client/src/history/builder.rs | 154 - crates/atuin-client/src/history/store.rs | 434 --- crates/atuin-client/src/import/bash.rs | 220 -- crates/atuin-client/src/import/fish.rs | 179 -- crates/atuin-client/src/import/mod.rs | 140 - crates/atuin-client/src/import/nu.rs | 67 - crates/atuin-client/src/import/nu_histdb.rs | 113 - crates/atuin-client/src/import/powershell.rs | 202 -- crates/atuin-client/src/import/replxx.rs | 137 - crates/atuin-client/src/import/resh.rs | 140 - crates/atuin-client/src/import/xonsh.rs | 234 -- crates/atuin-client/src/import/xonsh_sqlite.rs | 217 -- crates/atuin-client/src/import/zsh.rs | 230 -- crates/atuin-client/src/import/zsh_histdb.rs | 249 -- crates/atuin-client/src/lib.rs | 31 - crates/atuin-client/src/login.rs | 68 - crates/atuin-client/src/logout.rs | 16 - crates/atuin-client/src/meta.rs | 365 --- crates/atuin-client/src/ordering.rs | 32 - crates/atuin-client/src/plugin.rs | 150 - crates/atuin-client/src/record/encryption.rs | 373 --- crates/atuin-client/src/record/mod.rs | 6 - crates/atuin-client/src/record/sqlite_store.rs | 642 ---- crates/atuin-client/src/record/store.rs | 60 - crates/atuin-client/src/record/sync.rs | 663 ----- crates/atuin-client/src/register.rs | 20 - crates/atuin-client/src/secrets.rs | 194 -- crates/atuin-client/src/settings.rs | 1855 ------------ crates/atuin-client/src/settings/meta.rs | 17 - crates/atuin-client/src/settings/watcher.rs | 256 -- crates/atuin-client/src/sync.rs | 213 -- crates/atuin-client/src/theme.rs | 831 ------ crates/atuin-client/src/utils.rs | 14 - .../atuin-client/tests/data/xonsh-history.sqlite | Bin 12288 -> 0 bytes ...xonsh-82eafbf5-9f43-489a-80d2-61c7dc6ef542.json | 12 - ...xonsh-de16af90-9148-4461-8df3-5b5659c6420d.json | 12 - crates/atuin-common/Cargo.toml | 31 - crates/atuin-common/src/api.rs | 144 - crates/atuin-common/src/calendar.rs | 16 - crates/atuin-common/src/lib.rs | 60 - crates/atuin-common/src/record.rs | 426 --- crates/atuin-common/src/shell.rs | 183 -- crates/atuin-common/src/tls.rs | 15 - crates/atuin-common/src/utils.rs | 383 --- crates/atuin-daemon/Cargo.toml | 52 - crates/atuin-daemon/build.rs | 25 - crates/atuin-daemon/proto/control.proto | 62 - crates/atuin-daemon/proto/history.proto | 81 - crates/atuin-daemon/proto/search.proto | 35 - crates/atuin-daemon/proto/semantic.proto | 47 - crates/atuin-daemon/src/client.rs | 518 ---- crates/atuin-daemon/src/components/history.rs | 327 --- crates/atuin-daemon/src/components/mod.rs | 25 - crates/atuin-daemon/src/components/search.rs | 413 --- crates/atuin-daemon/src/components/semantic.rs | 900 ------ crates/atuin-daemon/src/components/sync.rs | 279 -- crates/atuin-daemon/src/control/mod.rs | 12 - crates/atuin-daemon/src/control/service.rs | 71 - crates/atuin-daemon/src/daemon.rs | 458 --- crates/atuin-daemon/src/events.rs | 74 - crates/atuin-daemon/src/history/mod.rs | 6 - crates/atuin-daemon/src/lib.rs | 136 - crates/atuin-daemon/src/search/index.rs | 683 ----- crates/atuin-daemon/src/search/mod.rs | 11 - crates/atuin-daemon/src/semantic/mod.rs | 3 - crates/atuin-daemon/src/server.rs | 170 -- crates/atuin-daemon/tests/lifecycle.rs | 222 -- crates/atuin-history/Cargo.toml | 30 - crates/atuin-history/benches/smart_sort.rs | 35 - crates/atuin-history/src/lib.rs | 2 - crates/atuin-history/src/sort.rs | 46 - crates/atuin-history/src/stats.rs | 548 ---- crates/atuin-pty-proxy/Cargo.toml | 21 - crates/atuin-pty-proxy/src/capture.rs | 467 --- crates/atuin-pty-proxy/src/debug.rs | 53 - crates/atuin-pty-proxy/src/lib.rs | 48 - crates/atuin-pty-proxy/src/osc133.rs | 900 ------ crates/atuin-pty-proxy/src/pty_proxy.rs | 231 -- crates/atuin-pty-proxy/src/runtime.rs | 184 -- crates/atuin-pty-proxy/src/screen.rs | 104 - crates/atuin-server-database/Cargo.toml | 21 - crates/atuin-server-database/src/calendar.rs | 18 - crates/atuin-server-database/src/lib.rs | 268 -- crates/atuin-server-database/src/models.rs | 52 - crates/atuin-server-postgres/Cargo.toml | 25 - crates/atuin-server-postgres/build.rs | 5 - .../migrations/20210425153745_create_history.sql | 11 - .../migrations/20210425153757_create_users.sql | 10 - .../migrations/20210425153800_create_sessions.sql | 6 - .../20220419082412_add_count_trigger.sql | 51 - .../20220421073605_fix_count_trigger_delete.sql | 35 - .../migrations/20220421174016_larger-commands.sql | 3 - .../migrations/20220426172813_user-created-at.sql | 1 - .../migrations/20220505082442_create-events.sql | 14 - .../migrations/20220610074049_history-length.sql | 2 - .../migrations/20230315220537_drop-events.sql | 2 - .../migrations/20230315224203_create-deleted.sql | 5 - .../20230515221038_trigger-delete-only.sql | 30 - .../migrations/20230623070418_records.sql | 15 - .../migrations/20231202170508_create-store.sql | 15 - .../migrations/20231203124112_create-store-idx.sql | 2 - .../20240108124837_drop-some-defaults.sql | 4 - .../migrations/20240614104159_idx-cache.sql | 8 - .../migrations/20240621110731_user-verified.sql | 8 - .../migrations/20240702094825_idx_cache_index.sql | 1 - .../20260127000000_remove-email-verification.sql | 2 - crates/atuin-server-postgres/src/lib.rs | 581 ---- crates/atuin-server-postgres/src/wrappers.rs | 77 - crates/atuin-server-sqlite/Cargo.toml | 24 - crates/atuin-server-sqlite/build.rs | 5 - .../migrations/20231203124112_create-store.sql | 17 - .../migrations/20240108124830_create-history.sql | 15 - .../migrations/20240108124831_create-sessions.sql | 6 - .../migrations/20240621110730_create-users.sql | 12 - ...240621110731_create-user-verification-token.sql | 6 - .../20240702094825_create-store-idx-cache.sql | 10 - .../20260127000000_remove-email-verification.sql | 2 - crates/atuin-server-sqlite/src/lib.rs | 430 --- crates/atuin-server-sqlite/src/wrappers.rs | 72 - crates/atuin-server/CHANGELOG.md | 1 - crates/atuin-server/Cargo.toml | 45 - crates/atuin-server/server.toml | 38 - crates/atuin-server/src/bin/main.rs | 73 - crates/atuin-server/src/handlers/health.rs | 15 - crates/atuin-server/src/handlers/history.rs | 237 -- crates/atuin-server/src/handlers/mod.rs | 60 - crates/atuin-server/src/handlers/record.rs | 42 - crates/atuin-server/src/handlers/status.rs | 45 - crates/atuin-server/src/handlers/user.rs | 269 -- crates/atuin-server/src/handlers/v0/me.rs | 16 - crates/atuin-server/src/handlers/v0/mod.rs | 3 - crates/atuin-server/src/handlers/v0/record.rs | 114 - crates/atuin-server/src/handlers/v0/store.rs | 37 - crates/atuin-server/src/lib.rs | 89 - crates/atuin-server/src/metrics.rs | 55 - crates/atuin-server/src/router.rs | 155 - crates/atuin-server/src/settings.rs | 113 - crates/atuin-server/src/utils.rs | 15 - crates/atuin/CHANGELOG.md | 1 - crates/atuin/Cargo.toml | 87 - crates/atuin/LICENSE | 21 - crates/atuin/README.md | 1 - crates/atuin/build.rs | 11 - crates/atuin/src/command/CONTRIBUTORS | 1 - crates/atuin/src/command/client.rs | 364 --- crates/atuin/src/command/client/account.rs | 47 - .../src/command/client/account/change_password.rs | 67 - crates/atuin/src/command/client/account/delete.rs | 57 - crates/atuin/src/command/client/account/login.rs | 206 -- crates/atuin/src/command/client/account/logout.rs | 5 - .../atuin/src/command/client/account/register.rs | 67 - crates/atuin/src/command/client/config.rs | 352 --- crates/atuin/src/command/client/daemon.rs | 784 ----- crates/atuin/src/command/client/default_config.rs | 5 - crates/atuin/src/command/client/doctor.rs | 412 --- crates/atuin/src/command/client/history.rs | 1337 --------- crates/atuin/src/command/client/import.rs | 186 -- crates/atuin/src/command/client/info.rs | 31 - crates/atuin/src/command/client/init.rs | 127 - crates/atuin/src/command/client/init/bash.rs | 25 - crates/atuin/src/command/client/init/fish.rs | 86 - crates/atuin/src/command/client/init/powershell.rs | 23 - crates/atuin/src/command/client/init/xonsh.rs | 22 - crates/atuin/src/command/client/init/zsh.rs | 38 - crates/atuin/src/command/client/search.rs | 375 --- crates/atuin/src/command/client/search/cursor.rs | 405 --- crates/atuin/src/command/client/search/duration.rs | 65 - crates/atuin/src/command/client/search/engines.rs | 95 - .../src/command/client/search/engines/daemon.rs | 249 -- .../atuin/src/command/client/search/engines/db.rs | 110 - .../src/command/client/search/engines/skim.rs | 229 -- .../src/command/client/search/history_list.rs | 429 --- .../atuin/src/command/client/search/inspector.rs | 421 --- .../atuin/src/command/client/search/interactive.rs | 3099 -------------------- .../command/client/search/keybindings/actions.rs | 322 -- .../client/search/keybindings/conditions.rs | 801 ----- .../command/client/search/keybindings/defaults.rs | 1286 -------- .../src/command/client/search/keybindings/key.rs | 629 ---- .../command/client/search/keybindings/keymap.rs | 233 -- .../src/command/client/search/keybindings/mod.rs | 14 - crates/atuin/src/command/client/setup.rs | 81 - crates/atuin/src/command/client/stats.rs | 85 - crates/atuin/src/command/client/store.rs | 120 - crates/atuin/src/command/client/store/pull.rs | 94 - crates/atuin/src/command/client/store/purge.rs | 26 - crates/atuin/src/command/client/store/push.rs | 112 - crates/atuin/src/command/client/store/rebuild.rs | 58 - crates/atuin/src/command/client/store/rekey.rs | 41 - crates/atuin/src/command/client/store/verify.rs | 26 - crates/atuin/src/command/client/sync.rs | 120 - crates/atuin/src/command/client/sync/status.rs | 37 - crates/atuin/src/command/client/wrapped.rs | 322 -- crates/atuin/src/command/contributors.rs | 5 - crates/atuin/src/command/external.rs | 102 - crates/atuin/src/command/gen_completions.rs | 84 - crates/atuin/src/command/mod.rs | 162 - crates/atuin/src/main.rs | 61 - crates/atuin/src/print_error.rs | 123 - crates/atuin/src/shell/.gitattributes | 1 - crates/atuin/src/shell/atuin.bash | 725 ----- crates/atuin/src/shell/atuin.fish | 178 -- crates/atuin/src/shell/atuin.nu | 121 - crates/atuin/src/shell/atuin.ps1 | 240 -- crates/atuin/src/shell/atuin.xsh | 86 - crates/atuin/src/shell/atuin.zsh | 221 -- crates/atuin/src/sync.rs | 34 - crates/atuin/tests/common/mod.rs | 117 - crates/atuin/tests/sync.rs | 45 - crates/atuin/tests/users.rs | 121 - crates/turtle/Cargo.toml | 142 + crates/turtle/build.rs | 39 + crates/turtle/proto/control.proto | 62 + crates/turtle/proto/history.proto | 81 + crates/turtle/proto/search.proto | 35 + crates/turtle/proto/semantic.proto | 47 + crates/turtle/src/atuin_client/api_client.rs | 438 +++ crates/turtle/src/atuin_client/auth.rs | 223 ++ crates/turtle/src/atuin_client/database.rs | 1526 ++++++++++ crates/turtle/src/atuin_client/distro.rs | 89 + crates/turtle/src/atuin_client/encryption.rs | 440 +++ crates/turtle/src/atuin_client/history.rs | 756 +++++ crates/turtle/src/atuin_client/history/builder.rs | 154 + crates/turtle/src/atuin_client/history/store.rs | 435 +++ crates/turtle/src/atuin_client/import/bash.rs | 221 ++ crates/turtle/src/atuin_client/import/fish.rs | 179 ++ crates/turtle/src/atuin_client/import/mod.rs | 140 + crates/turtle/src/atuin_client/import/nu.rs | 67 + crates/turtle/src/atuin_client/import/nu_histdb.rs | 113 + .../turtle/src/atuin_client/import/powershell.rs | 202 ++ crates/turtle/src/atuin_client/import/replxx.rs | 137 + crates/turtle/src/atuin_client/import/resh.rs | 140 + crates/turtle/src/atuin_client/import/xonsh.rs | 234 ++ .../turtle/src/atuin_client/import/xonsh_sqlite.rs | 217 ++ crates/turtle/src/atuin_client/import/zsh.rs | 230 ++ .../turtle/src/atuin_client/import/zsh_histdb.rs | 249 ++ crates/turtle/src/atuin_client/login.rs | 68 + crates/turtle/src/atuin_client/logout.rs | 16 + crates/turtle/src/atuin_client/meta.rs | 366 +++ crates/turtle/src/atuin_client/mod.rs | 26 + crates/turtle/src/atuin_client/ordering.rs | 32 + crates/turtle/src/atuin_client/plugin.rs | 150 + .../turtle/src/atuin_client/record/encryption.rs | 373 +++ crates/turtle/src/atuin_client/record/mod.rs | 6 + .../turtle/src/atuin_client/record/sqlite_store.rs | 643 ++++ crates/turtle/src/atuin_client/record/store.rs | 60 + crates/turtle/src/atuin_client/record/sync.rs | 664 +++++ crates/turtle/src/atuin_client/register.rs | 20 + crates/turtle/src/atuin_client/secrets.rs | 194 ++ crates/turtle/src/atuin_client/settings.rs | 1851 ++++++++++++ crates/turtle/src/atuin_client/settings/meta.rs | 17 + crates/turtle/src/atuin_client/settings/watcher.rs | 256 ++ crates/turtle/src/atuin_client/sync.rs | 214 ++ crates/turtle/src/atuin_client/theme.rs | 831 ++++++ crates/turtle/src/atuin_client/utils.rs | 14 + crates/turtle/src/atuin_common/api.rs | 144 + crates/turtle/src/atuin_common/calendar.rs | 16 + crates/turtle/src/atuin_common/mod.rs | 58 + crates/turtle/src/atuin_common/record.rs | 426 +++ crates/turtle/src/atuin_common/shell.rs | 183 ++ crates/turtle/src/atuin_common/tls.rs | 15 + crates/turtle/src/atuin_common/utils.rs | 383 +++ crates/turtle/src/atuin_daemon/client.rs | 418 +++ .../turtle/src/atuin_daemon/components/history.rs | 327 +++ crates/turtle/src/atuin_daemon/components/mod.rs | 25 + .../turtle/src/atuin_daemon/components/search.rs | 413 +++ .../turtle/src/atuin_daemon/components/semantic.rs | 903 ++++++ crates/turtle/src/atuin_daemon/components/sync.rs | 279 ++ crates/turtle/src/atuin_daemon/control/mod.rs | 12 + crates/turtle/src/atuin_daemon/control/service.rs | 71 + crates/turtle/src/atuin_daemon/daemon.rs | 458 +++ crates/turtle/src/atuin_daemon/events.rs | 74 + crates/turtle/src/atuin_daemon/history/mod.rs | 6 + crates/turtle/src/atuin_daemon/mod.rs | 128 + crates/turtle/src/atuin_daemon/search/index.rs | 684 +++++ crates/turtle/src/atuin_daemon/search/mod.rs | 11 + crates/turtle/src/atuin_daemon/semantic/mod.rs | 3 + crates/turtle/src/atuin_daemon/server.rs | 115 + crates/turtle/src/atuin_history/mod.rs | 2 + crates/turtle/src/atuin_history/sort.rs | 46 + crates/turtle/src/atuin_history/stats.rs | 548 ++++ crates/turtle/src/atuin_pty_proxy/capture.rs | 467 +++ crates/turtle/src/atuin_pty_proxy/debug.rs | 53 + crates/turtle/src/atuin_pty_proxy/mod.rs | 17 + crates/turtle/src/atuin_pty_proxy/osc133.rs | 900 ++++++ crates/turtle/src/atuin_pty_proxy/pty_proxy.rs | 231 ++ crates/turtle/src/atuin_pty_proxy/runtime.rs | 184 ++ crates/turtle/src/atuin_pty_proxy/screen.rs | 104 + crates/turtle/src/atuin_server/handlers/health.rs | 15 + crates/turtle/src/atuin_server/handlers/history.rs | 237 ++ crates/turtle/src/atuin_server/handlers/mod.rs | 60 + crates/turtle/src/atuin_server/handlers/record.rs | 42 + crates/turtle/src/atuin_server/handlers/status.rs | 45 + crates/turtle/src/atuin_server/handlers/user.rs | 269 ++ crates/turtle/src/atuin_server/handlers/v0/me.rs | 16 + crates/turtle/src/atuin_server/handlers/v0/mod.rs | 3 + .../turtle/src/atuin_server/handlers/v0/record.rs | 114 + .../turtle/src/atuin_server/handlers/v0/store.rs | 37 + crates/turtle/src/atuin_server/metrics.rs | 55 + crates/turtle/src/atuin_server/mod.rs | 86 + crates/turtle/src/atuin_server/router.rs | 155 + crates/turtle/src/atuin_server/settings.rs | 110 + crates/turtle/src/atuin_server/utils.rs | 15 + .../turtle/src/atuin_server_database/calendar.rs | 18 + crates/turtle/src/atuin_server_database/mod.rs | 266 ++ crates/turtle/src/atuin_server_database/models.rs | 52 + crates/turtle/src/atuin_server_postgres/mod.rs | 583 ++++ .../turtle/src/atuin_server_postgres/wrappers.rs | 77 + crates/turtle/src/atuin_server_sqlite/mod.rs | 430 +++ crates/turtle/src/atuin_server_sqlite/wrappers.rs | 72 + crates/turtle/src/command/CONTRIBUTORS | 1 + crates/turtle/src/command/client.rs | 371 +++ crates/turtle/src/command/client/account.rs | 47 + .../src/command/client/account/change_password.rs | 67 + crates/turtle/src/command/client/account/delete.rs | 57 + crates/turtle/src/command/client/account/login.rs | 206 ++ crates/turtle/src/command/client/account/logout.rs | 5 + .../turtle/src/command/client/account/register.rs | 67 + crates/turtle/src/command/client/config.rs | 352 +++ crates/turtle/src/command/client/daemon.rs | 769 +++++ crates/turtle/src/command/client/default_config.rs | 4 + crates/turtle/src/command/client/doctor.rs | 412 +++ crates/turtle/src/command/client/history.rs | 1340 +++++++++ crates/turtle/src/command/client/import.rs | 186 ++ crates/turtle/src/command/client/info.rs | 31 + crates/turtle/src/command/client/init.rs | 127 + crates/turtle/src/command/client/init/bash.rs | 25 + crates/turtle/src/command/client/init/fish.rs | 86 + .../turtle/src/command/client/init/powershell.rs | 23 + crates/turtle/src/command/client/init/xonsh.rs | 22 + crates/turtle/src/command/client/init/zsh.rs | 38 + crates/turtle/src/command/client/search.rs | 375 +++ crates/turtle/src/command/client/search/cursor.rs | 405 +++ .../turtle/src/command/client/search/duration.rs | 65 + crates/turtle/src/command/client/search/engines.rs | 95 + .../src/command/client/search/engines/daemon.rs | 242 ++ .../turtle/src/command/client/search/engines/db.rs | 110 + .../src/command/client/search/engines/skim.rs | 229 ++ .../src/command/client/search/history_list.rs | 429 +++ .../turtle/src/command/client/search/inspector.rs | 421 +++ .../src/command/client/search/interactive.rs | 3041 +++++++++++++++++++ .../command/client/search/keybindings/actions.rs | 322 ++ .../client/search/keybindings/conditions.rs | 801 +++++ .../command/client/search/keybindings/defaults.rs | 1286 ++++++++ .../src/command/client/search/keybindings/key.rs | 629 ++++ .../command/client/search/keybindings/keymap.rs | 233 ++ .../src/command/client/search/keybindings/mod.rs | 14 + crates/turtle/src/command/client/server.rs | 61 + crates/turtle/src/command/client/setup.rs | 81 + crates/turtle/src/command/client/stats.rs | 85 + crates/turtle/src/command/client/store.rs | 120 + crates/turtle/src/command/client/store/pull.rs | 94 + crates/turtle/src/command/client/store/purge.rs | 26 + crates/turtle/src/command/client/store/push.rs | 112 + crates/turtle/src/command/client/store/rebuild.rs | 58 + crates/turtle/src/command/client/store/rekey.rs | 41 + crates/turtle/src/command/client/store/verify.rs | 26 + crates/turtle/src/command/client/sync.rs | 120 + crates/turtle/src/command/client/sync/status.rs | 37 + crates/turtle/src/command/client/wrapped.rs | 326 ++ crates/turtle/src/command/contributors.rs | 5 + crates/turtle/src/command/external.rs | 102 + crates/turtle/src/command/gen_completions.rs | 84 + crates/turtle/src/command/mod.rs | 156 + crates/turtle/src/main.rs | 73 + crates/turtle/src/print_error.rs | 123 + crates/turtle/src/shell/.gitattributes | 1 + crates/turtle/src/shell/atuin.bash | 725 +++++ crates/turtle/src/shell/atuin.fish | 178 ++ crates/turtle/src/shell/atuin.nu | 121 + crates/turtle/src/shell/atuin.ps1 | 240 ++ crates/turtle/src/shell/atuin.xsh | 86 + crates/turtle/src/shell/atuin.zsh | 221 ++ crates/turtle/src/sync.rs | 34 + 392 files changed, 39213 insertions(+), 41318 deletions(-) delete mode 100644 crates/atuin-client/Cargo.toml delete mode 100644 crates/atuin-client/config.toml delete mode 100644 crates/atuin-client/meta-migrations/20260203030924_create_meta.sql delete mode 100644 crates/atuin-client/migrations/20210422143411_create_history.sql delete mode 100644 crates/atuin-client/migrations/20220505083406_create-events.sql delete mode 100644 crates/atuin-client/migrations/20220806155627_interactive_search_index.sql delete mode 100644 crates/atuin-client/migrations/20230315220114_drop-events.sql delete mode 100644 crates/atuin-client/migrations/20230319185725_deleted_at.sql delete mode 100644 crates/atuin-client/migrations/20260224000100_history_author_intent.sql delete mode 100644 crates/atuin-client/record-migrations/20230531212437_create-records.sql delete mode 100644 crates/atuin-client/record-migrations/20231127090831_create-store.sql delete mode 100644 crates/atuin-client/src/api_client.rs delete mode 100644 crates/atuin-client/src/auth.rs delete mode 100644 crates/atuin-client/src/database.rs delete mode 100644 crates/atuin-client/src/distro.rs delete mode 100644 crates/atuin-client/src/encryption.rs delete mode 100644 crates/atuin-client/src/history.rs delete mode 100644 crates/atuin-client/src/history/builder.rs delete mode 100644 crates/atuin-client/src/history/store.rs delete mode 100644 crates/atuin-client/src/import/bash.rs delete mode 100644 crates/atuin-client/src/import/fish.rs delete mode 100644 crates/atuin-client/src/import/mod.rs delete mode 100644 crates/atuin-client/src/import/nu.rs delete mode 100644 crates/atuin-client/src/import/nu_histdb.rs delete mode 100644 crates/atuin-client/src/import/powershell.rs delete mode 100644 crates/atuin-client/src/import/replxx.rs delete mode 100644 crates/atuin-client/src/import/resh.rs delete mode 100644 crates/atuin-client/src/import/xonsh.rs delete mode 100644 crates/atuin-client/src/import/xonsh_sqlite.rs delete mode 100644 crates/atuin-client/src/import/zsh.rs delete mode 100644 crates/atuin-client/src/import/zsh_histdb.rs delete mode 100644 crates/atuin-client/src/lib.rs delete mode 100644 crates/atuin-client/src/login.rs delete mode 100644 crates/atuin-client/src/logout.rs delete mode 100644 crates/atuin-client/src/meta.rs delete mode 100644 crates/atuin-client/src/ordering.rs delete mode 100644 crates/atuin-client/src/plugin.rs delete mode 100644 crates/atuin-client/src/record/encryption.rs delete mode 100644 crates/atuin-client/src/record/mod.rs delete mode 100644 crates/atuin-client/src/record/sqlite_store.rs delete mode 100644 crates/atuin-client/src/record/store.rs delete mode 100644 crates/atuin-client/src/record/sync.rs delete mode 100644 crates/atuin-client/src/register.rs delete mode 100644 crates/atuin-client/src/secrets.rs delete mode 100644 crates/atuin-client/src/settings.rs delete mode 100644 crates/atuin-client/src/settings/meta.rs delete mode 100644 crates/atuin-client/src/settings/watcher.rs delete mode 100644 crates/atuin-client/src/sync.rs delete mode 100644 crates/atuin-client/src/theme.rs delete mode 100644 crates/atuin-client/src/utils.rs delete mode 100644 crates/atuin-client/tests/data/xonsh-history.sqlite delete mode 100644 crates/atuin-client/tests/data/xonsh/xonsh-82eafbf5-9f43-489a-80d2-61c7dc6ef542.json delete mode 100644 crates/atuin-client/tests/data/xonsh/xonsh-de16af90-9148-4461-8df3-5b5659c6420d.json delete mode 100644 crates/atuin-common/Cargo.toml delete mode 100644 crates/atuin-common/src/api.rs delete mode 100644 crates/atuin-common/src/calendar.rs delete mode 100644 crates/atuin-common/src/lib.rs delete mode 100644 crates/atuin-common/src/record.rs delete mode 100644 crates/atuin-common/src/shell.rs delete mode 100644 crates/atuin-common/src/tls.rs delete mode 100644 crates/atuin-common/src/utils.rs delete mode 100644 crates/atuin-daemon/Cargo.toml delete mode 100644 crates/atuin-daemon/build.rs delete mode 100644 crates/atuin-daemon/proto/control.proto delete mode 100644 crates/atuin-daemon/proto/history.proto delete mode 100644 crates/atuin-daemon/proto/search.proto delete mode 100644 crates/atuin-daemon/proto/semantic.proto delete mode 100644 crates/atuin-daemon/src/client.rs delete mode 100644 crates/atuin-daemon/src/components/history.rs delete mode 100644 crates/atuin-daemon/src/components/mod.rs delete mode 100644 crates/atuin-daemon/src/components/search.rs delete mode 100644 crates/atuin-daemon/src/components/semantic.rs delete mode 100644 crates/atuin-daemon/src/components/sync.rs delete mode 100644 crates/atuin-daemon/src/control/mod.rs delete mode 100644 crates/atuin-daemon/src/control/service.rs delete mode 100644 crates/atuin-daemon/src/daemon.rs delete mode 100644 crates/atuin-daemon/src/events.rs delete mode 100644 crates/atuin-daemon/src/history/mod.rs delete mode 100644 crates/atuin-daemon/src/lib.rs delete mode 100644 crates/atuin-daemon/src/search/index.rs delete mode 100644 crates/atuin-daemon/src/search/mod.rs delete mode 100644 crates/atuin-daemon/src/semantic/mod.rs delete mode 100644 crates/atuin-daemon/src/server.rs delete mode 100644 crates/atuin-daemon/tests/lifecycle.rs delete mode 100644 crates/atuin-history/Cargo.toml delete mode 100644 crates/atuin-history/benches/smart_sort.rs delete mode 100644 crates/atuin-history/src/lib.rs delete mode 100644 crates/atuin-history/src/sort.rs delete mode 100644 crates/atuin-history/src/stats.rs delete mode 100644 crates/atuin-pty-proxy/Cargo.toml delete mode 100644 crates/atuin-pty-proxy/src/capture.rs delete mode 100644 crates/atuin-pty-proxy/src/debug.rs delete mode 100644 crates/atuin-pty-proxy/src/lib.rs delete mode 100644 crates/atuin-pty-proxy/src/osc133.rs delete mode 100644 crates/atuin-pty-proxy/src/pty_proxy.rs delete mode 100644 crates/atuin-pty-proxy/src/runtime.rs delete mode 100644 crates/atuin-pty-proxy/src/screen.rs delete mode 100644 crates/atuin-server-database/Cargo.toml delete mode 100644 crates/atuin-server-database/src/calendar.rs delete mode 100644 crates/atuin-server-database/src/lib.rs delete mode 100644 crates/atuin-server-database/src/models.rs delete mode 100644 crates/atuin-server-postgres/Cargo.toml delete mode 100644 crates/atuin-server-postgres/build.rs delete mode 100644 crates/atuin-server-postgres/migrations/20210425153745_create_history.sql delete mode 100644 crates/atuin-server-postgres/migrations/20210425153757_create_users.sql delete mode 100644 crates/atuin-server-postgres/migrations/20210425153800_create_sessions.sql delete mode 100644 crates/atuin-server-postgres/migrations/20220419082412_add_count_trigger.sql delete mode 100644 crates/atuin-server-postgres/migrations/20220421073605_fix_count_trigger_delete.sql delete mode 100644 crates/atuin-server-postgres/migrations/20220421174016_larger-commands.sql delete mode 100644 crates/atuin-server-postgres/migrations/20220426172813_user-created-at.sql delete mode 100644 crates/atuin-server-postgres/migrations/20220505082442_create-events.sql delete mode 100644 crates/atuin-server-postgres/migrations/20220610074049_history-length.sql delete mode 100644 crates/atuin-server-postgres/migrations/20230315220537_drop-events.sql delete mode 100644 crates/atuin-server-postgres/migrations/20230315224203_create-deleted.sql delete mode 100644 crates/atuin-server-postgres/migrations/20230515221038_trigger-delete-only.sql delete mode 100644 crates/atuin-server-postgres/migrations/20230623070418_records.sql delete mode 100644 crates/atuin-server-postgres/migrations/20231202170508_create-store.sql delete mode 100644 crates/atuin-server-postgres/migrations/20231203124112_create-store-idx.sql delete mode 100644 crates/atuin-server-postgres/migrations/20240108124837_drop-some-defaults.sql delete mode 100644 crates/atuin-server-postgres/migrations/20240614104159_idx-cache.sql delete mode 100644 crates/atuin-server-postgres/migrations/20240621110731_user-verified.sql delete mode 100644 crates/atuin-server-postgres/migrations/20240702094825_idx_cache_index.sql delete mode 100644 crates/atuin-server-postgres/migrations/20260127000000_remove-email-verification.sql delete mode 100644 crates/atuin-server-postgres/src/lib.rs delete mode 100644 crates/atuin-server-postgres/src/wrappers.rs delete mode 100644 crates/atuin-server-sqlite/Cargo.toml delete mode 100644 crates/atuin-server-sqlite/build.rs delete mode 100644 crates/atuin-server-sqlite/migrations/20231203124112_create-store.sql delete mode 100644 crates/atuin-server-sqlite/migrations/20240108124830_create-history.sql delete mode 100644 crates/atuin-server-sqlite/migrations/20240108124831_create-sessions.sql delete mode 100644 crates/atuin-server-sqlite/migrations/20240621110730_create-users.sql delete mode 100644 crates/atuin-server-sqlite/migrations/20240621110731_create-user-verification-token.sql delete mode 100644 crates/atuin-server-sqlite/migrations/20240702094825_create-store-idx-cache.sql delete mode 100644 crates/atuin-server-sqlite/migrations/20260127000000_remove-email-verification.sql delete mode 100644 crates/atuin-server-sqlite/src/lib.rs delete mode 100644 crates/atuin-server-sqlite/src/wrappers.rs delete mode 120000 crates/atuin-server/CHANGELOG.md delete mode 100644 crates/atuin-server/Cargo.toml delete mode 100644 crates/atuin-server/server.toml delete mode 100644 crates/atuin-server/src/bin/main.rs delete mode 100644 crates/atuin-server/src/handlers/health.rs delete mode 100644 crates/atuin-server/src/handlers/history.rs delete mode 100644 crates/atuin-server/src/handlers/mod.rs delete mode 100644 crates/atuin-server/src/handlers/record.rs delete mode 100644 crates/atuin-server/src/handlers/status.rs delete mode 100644 crates/atuin-server/src/handlers/user.rs delete mode 100644 crates/atuin-server/src/handlers/v0/me.rs delete mode 100644 crates/atuin-server/src/handlers/v0/mod.rs delete mode 100644 crates/atuin-server/src/handlers/v0/record.rs delete mode 100644 crates/atuin-server/src/handlers/v0/store.rs delete mode 100644 crates/atuin-server/src/lib.rs delete mode 100644 crates/atuin-server/src/metrics.rs delete mode 100644 crates/atuin-server/src/router.rs delete mode 100644 crates/atuin-server/src/settings.rs delete mode 100644 crates/atuin-server/src/utils.rs delete mode 120000 crates/atuin/CHANGELOG.md delete mode 100644 crates/atuin/Cargo.toml delete mode 100644 crates/atuin/LICENSE delete mode 120000 crates/atuin/README.md delete mode 100644 crates/atuin/build.rs delete mode 120000 crates/atuin/src/command/CONTRIBUTORS delete mode 100644 crates/atuin/src/command/client.rs delete mode 100644 crates/atuin/src/command/client/account.rs delete mode 100644 crates/atuin/src/command/client/account/change_password.rs delete mode 100644 crates/atuin/src/command/client/account/delete.rs delete mode 100644 crates/atuin/src/command/client/account/login.rs delete mode 100644 crates/atuin/src/command/client/account/logout.rs delete mode 100644 crates/atuin/src/command/client/account/register.rs delete mode 100644 crates/atuin/src/command/client/config.rs delete mode 100644 crates/atuin/src/command/client/daemon.rs delete mode 100644 crates/atuin/src/command/client/default_config.rs delete mode 100644 crates/atuin/src/command/client/doctor.rs delete mode 100644 crates/atuin/src/command/client/history.rs delete mode 100644 crates/atuin/src/command/client/import.rs delete mode 100644 crates/atuin/src/command/client/info.rs delete mode 100644 crates/atuin/src/command/client/init.rs delete mode 100644 crates/atuin/src/command/client/init/bash.rs delete mode 100644 crates/atuin/src/command/client/init/fish.rs delete mode 100644 crates/atuin/src/command/client/init/powershell.rs delete mode 100644 crates/atuin/src/command/client/init/xonsh.rs delete mode 100644 crates/atuin/src/command/client/init/zsh.rs delete mode 100644 crates/atuin/src/command/client/search.rs delete mode 100644 crates/atuin/src/command/client/search/cursor.rs delete mode 100644 crates/atuin/src/command/client/search/duration.rs delete mode 100644 crates/atuin/src/command/client/search/engines.rs delete mode 100644 crates/atuin/src/command/client/search/engines/daemon.rs delete mode 100644 crates/atuin/src/command/client/search/engines/db.rs delete mode 100644 crates/atuin/src/command/client/search/engines/skim.rs delete mode 100644 crates/atuin/src/command/client/search/history_list.rs delete mode 100644 crates/atuin/src/command/client/search/inspector.rs delete mode 100644 crates/atuin/src/command/client/search/interactive.rs delete mode 100644 crates/atuin/src/command/client/search/keybindings/actions.rs delete mode 100644 crates/atuin/src/command/client/search/keybindings/conditions.rs delete mode 100644 crates/atuin/src/command/client/search/keybindings/defaults.rs delete mode 100644 crates/atuin/src/command/client/search/keybindings/key.rs delete mode 100644 crates/atuin/src/command/client/search/keybindings/keymap.rs delete mode 100644 crates/atuin/src/command/client/search/keybindings/mod.rs delete mode 100644 crates/atuin/src/command/client/setup.rs delete mode 100644 crates/atuin/src/command/client/stats.rs delete mode 100644 crates/atuin/src/command/client/store.rs delete mode 100644 crates/atuin/src/command/client/store/pull.rs delete mode 100644 crates/atuin/src/command/client/store/purge.rs delete mode 100644 crates/atuin/src/command/client/store/push.rs delete mode 100644 crates/atuin/src/command/client/store/rebuild.rs delete mode 100644 crates/atuin/src/command/client/store/rekey.rs delete mode 100644 crates/atuin/src/command/client/store/verify.rs delete mode 100644 crates/atuin/src/command/client/sync.rs delete mode 100644 crates/atuin/src/command/client/sync/status.rs delete mode 100644 crates/atuin/src/command/client/wrapped.rs delete mode 100644 crates/atuin/src/command/contributors.rs delete mode 100644 crates/atuin/src/command/external.rs delete mode 100644 crates/atuin/src/command/gen_completions.rs delete mode 100644 crates/atuin/src/command/mod.rs delete mode 100644 crates/atuin/src/main.rs delete mode 100644 crates/atuin/src/print_error.rs delete mode 100644 crates/atuin/src/shell/.gitattributes delete mode 100644 crates/atuin/src/shell/atuin.bash delete mode 100644 crates/atuin/src/shell/atuin.fish delete mode 100644 crates/atuin/src/shell/atuin.nu delete mode 100644 crates/atuin/src/shell/atuin.ps1 delete mode 100644 crates/atuin/src/shell/atuin.xsh delete mode 100644 crates/atuin/src/shell/atuin.zsh delete mode 100644 crates/atuin/src/sync.rs delete mode 100644 crates/atuin/tests/common/mod.rs delete mode 100644 crates/atuin/tests/sync.rs delete mode 100644 crates/atuin/tests/users.rs create mode 100644 crates/turtle/Cargo.toml create mode 100644 crates/turtle/build.rs create mode 100644 crates/turtle/proto/control.proto create mode 100644 crates/turtle/proto/history.proto create mode 100644 crates/turtle/proto/search.proto create mode 100644 crates/turtle/proto/semantic.proto create mode 100644 crates/turtle/src/atuin_client/api_client.rs create mode 100644 crates/turtle/src/atuin_client/auth.rs create mode 100644 crates/turtle/src/atuin_client/database.rs create mode 100644 crates/turtle/src/atuin_client/distro.rs create mode 100644 crates/turtle/src/atuin_client/encryption.rs create mode 100644 crates/turtle/src/atuin_client/history.rs create mode 100644 crates/turtle/src/atuin_client/history/builder.rs create mode 100644 crates/turtle/src/atuin_client/history/store.rs create mode 100644 crates/turtle/src/atuin_client/import/bash.rs create mode 100644 crates/turtle/src/atuin_client/import/fish.rs create mode 100644 crates/turtle/src/atuin_client/import/mod.rs create mode 100644 crates/turtle/src/atuin_client/import/nu.rs create mode 100644 crates/turtle/src/atuin_client/import/nu_histdb.rs create mode 100644 crates/turtle/src/atuin_client/import/powershell.rs create mode 100644 crates/turtle/src/atuin_client/import/replxx.rs create mode 100644 crates/turtle/src/atuin_client/import/resh.rs create mode 100644 crates/turtle/src/atuin_client/import/xonsh.rs create mode 100644 crates/turtle/src/atuin_client/import/xonsh_sqlite.rs create mode 100644 crates/turtle/src/atuin_client/import/zsh.rs create mode 100644 crates/turtle/src/atuin_client/import/zsh_histdb.rs create mode 100644 crates/turtle/src/atuin_client/login.rs create mode 100644 crates/turtle/src/atuin_client/logout.rs create mode 100644 crates/turtle/src/atuin_client/meta.rs create mode 100644 crates/turtle/src/atuin_client/mod.rs create mode 100644 crates/turtle/src/atuin_client/ordering.rs create mode 100644 crates/turtle/src/atuin_client/plugin.rs create mode 100644 crates/turtle/src/atuin_client/record/encryption.rs create mode 100644 crates/turtle/src/atuin_client/record/mod.rs create mode 100644 crates/turtle/src/atuin_client/record/sqlite_store.rs create mode 100644 crates/turtle/src/atuin_client/record/store.rs create mode 100644 crates/turtle/src/atuin_client/record/sync.rs create mode 100644 crates/turtle/src/atuin_client/register.rs create mode 100644 crates/turtle/src/atuin_client/secrets.rs create mode 100644 crates/turtle/src/atuin_client/settings.rs create mode 100644 crates/turtle/src/atuin_client/settings/meta.rs create mode 100644 crates/turtle/src/atuin_client/settings/watcher.rs create mode 100644 crates/turtle/src/atuin_client/sync.rs create mode 100644 crates/turtle/src/atuin_client/theme.rs create mode 100644 crates/turtle/src/atuin_client/utils.rs create mode 100644 crates/turtle/src/atuin_common/api.rs create mode 100644 crates/turtle/src/atuin_common/calendar.rs create mode 100644 crates/turtle/src/atuin_common/mod.rs create mode 100644 crates/turtle/src/atuin_common/record.rs create mode 100644 crates/turtle/src/atuin_common/shell.rs create mode 100644 crates/turtle/src/atuin_common/tls.rs create mode 100644 crates/turtle/src/atuin_common/utils.rs create mode 100644 crates/turtle/src/atuin_daemon/client.rs create mode 100644 crates/turtle/src/atuin_daemon/components/history.rs create mode 100644 crates/turtle/src/atuin_daemon/components/mod.rs create mode 100644 crates/turtle/src/atuin_daemon/components/search.rs create mode 100644 crates/turtle/src/atuin_daemon/components/semantic.rs create mode 100644 crates/turtle/src/atuin_daemon/components/sync.rs create mode 100644 crates/turtle/src/atuin_daemon/control/mod.rs create mode 100644 crates/turtle/src/atuin_daemon/control/service.rs create mode 100644 crates/turtle/src/atuin_daemon/daemon.rs create mode 100644 crates/turtle/src/atuin_daemon/events.rs create mode 100644 crates/turtle/src/atuin_daemon/history/mod.rs create mode 100644 crates/turtle/src/atuin_daemon/mod.rs create mode 100644 crates/turtle/src/atuin_daemon/search/index.rs create mode 100644 crates/turtle/src/atuin_daemon/search/mod.rs create mode 100644 crates/turtle/src/atuin_daemon/semantic/mod.rs create mode 100644 crates/turtle/src/atuin_daemon/server.rs create mode 100644 crates/turtle/src/atuin_history/mod.rs create mode 100644 crates/turtle/src/atuin_history/sort.rs create mode 100644 crates/turtle/src/atuin_history/stats.rs create mode 100644 crates/turtle/src/atuin_pty_proxy/capture.rs create mode 100644 crates/turtle/src/atuin_pty_proxy/debug.rs create mode 100644 crates/turtle/src/atuin_pty_proxy/mod.rs create mode 100644 crates/turtle/src/atuin_pty_proxy/osc133.rs create mode 100644 crates/turtle/src/atuin_pty_proxy/pty_proxy.rs create mode 100644 crates/turtle/src/atuin_pty_proxy/runtime.rs create mode 100644 crates/turtle/src/atuin_pty_proxy/screen.rs create mode 100644 crates/turtle/src/atuin_server/handlers/health.rs create mode 100644 crates/turtle/src/atuin_server/handlers/history.rs create mode 100644 crates/turtle/src/atuin_server/handlers/mod.rs create mode 100644 crates/turtle/src/atuin_server/handlers/record.rs create mode 100644 crates/turtle/src/atuin_server/handlers/status.rs create mode 100644 crates/turtle/src/atuin_server/handlers/user.rs create mode 100644 crates/turtle/src/atuin_server/handlers/v0/me.rs create mode 100644 crates/turtle/src/atuin_server/handlers/v0/mod.rs create mode 100644 crates/turtle/src/atuin_server/handlers/v0/record.rs create mode 100644 crates/turtle/src/atuin_server/handlers/v0/store.rs create mode 100644 crates/turtle/src/atuin_server/metrics.rs create mode 100644 crates/turtle/src/atuin_server/mod.rs create mode 100644 crates/turtle/src/atuin_server/router.rs create mode 100644 crates/turtle/src/atuin_server/settings.rs create mode 100644 crates/turtle/src/atuin_server/utils.rs create mode 100644 crates/turtle/src/atuin_server_database/calendar.rs create mode 100644 crates/turtle/src/atuin_server_database/mod.rs create mode 100644 crates/turtle/src/atuin_server_database/models.rs create mode 100644 crates/turtle/src/atuin_server_postgres/mod.rs create mode 100644 crates/turtle/src/atuin_server_postgres/wrappers.rs create mode 100644 crates/turtle/src/atuin_server_sqlite/mod.rs create mode 100644 crates/turtle/src/atuin_server_sqlite/wrappers.rs create mode 120000 crates/turtle/src/command/CONTRIBUTORS create mode 100644 crates/turtle/src/command/client.rs create mode 100644 crates/turtle/src/command/client/account.rs create mode 100644 crates/turtle/src/command/client/account/change_password.rs create mode 100644 crates/turtle/src/command/client/account/delete.rs create mode 100644 crates/turtle/src/command/client/account/login.rs create mode 100644 crates/turtle/src/command/client/account/logout.rs create mode 100644 crates/turtle/src/command/client/account/register.rs create mode 100644 crates/turtle/src/command/client/config.rs create mode 100644 crates/turtle/src/command/client/daemon.rs create mode 100644 crates/turtle/src/command/client/default_config.rs create mode 100644 crates/turtle/src/command/client/doctor.rs create mode 100644 crates/turtle/src/command/client/history.rs create mode 100644 crates/turtle/src/command/client/import.rs create mode 100644 crates/turtle/src/command/client/info.rs create mode 100644 crates/turtle/src/command/client/init.rs create mode 100644 crates/turtle/src/command/client/init/bash.rs create mode 100644 crates/turtle/src/command/client/init/fish.rs create mode 100644 crates/turtle/src/command/client/init/powershell.rs create mode 100644 crates/turtle/src/command/client/init/xonsh.rs create mode 100644 crates/turtle/src/command/client/init/zsh.rs create mode 100644 crates/turtle/src/command/client/search.rs create mode 100644 crates/turtle/src/command/client/search/cursor.rs create mode 100644 crates/turtle/src/command/client/search/duration.rs create mode 100644 crates/turtle/src/command/client/search/engines.rs create mode 100644 crates/turtle/src/command/client/search/engines/daemon.rs create mode 100644 crates/turtle/src/command/client/search/engines/db.rs create mode 100644 crates/turtle/src/command/client/search/engines/skim.rs create mode 100644 crates/turtle/src/command/client/search/history_list.rs create mode 100644 crates/turtle/src/command/client/search/inspector.rs create mode 100644 crates/turtle/src/command/client/search/interactive.rs create mode 100644 crates/turtle/src/command/client/search/keybindings/actions.rs create mode 100644 crates/turtle/src/command/client/search/keybindings/conditions.rs create mode 100644 crates/turtle/src/command/client/search/keybindings/defaults.rs create mode 100644 crates/turtle/src/command/client/search/keybindings/key.rs create mode 100644 crates/turtle/src/command/client/search/keybindings/keymap.rs create mode 100644 crates/turtle/src/command/client/search/keybindings/mod.rs create mode 100644 crates/turtle/src/command/client/server.rs create mode 100644 crates/turtle/src/command/client/setup.rs create mode 100644 crates/turtle/src/command/client/stats.rs create mode 100644 crates/turtle/src/command/client/store.rs create mode 100644 crates/turtle/src/command/client/store/pull.rs create mode 100644 crates/turtle/src/command/client/store/purge.rs create mode 100644 crates/turtle/src/command/client/store/push.rs create mode 100644 crates/turtle/src/command/client/store/rebuild.rs create mode 100644 crates/turtle/src/command/client/store/rekey.rs create mode 100644 crates/turtle/src/command/client/store/verify.rs create mode 100644 crates/turtle/src/command/client/sync.rs create mode 100644 crates/turtle/src/command/client/sync/status.rs create mode 100644 crates/turtle/src/command/client/wrapped.rs create mode 100644 crates/turtle/src/command/contributors.rs create mode 100644 crates/turtle/src/command/external.rs create mode 100644 crates/turtle/src/command/gen_completions.rs create mode 100644 crates/turtle/src/command/mod.rs create mode 100644 crates/turtle/src/main.rs create mode 100644 crates/turtle/src/print_error.rs create mode 100644 crates/turtle/src/shell/.gitattributes create mode 100644 crates/turtle/src/shell/atuin.bash create mode 100644 crates/turtle/src/shell/atuin.fish create mode 100644 crates/turtle/src/shell/atuin.nu create mode 100644 crates/turtle/src/shell/atuin.ps1 create mode 100644 crates/turtle/src/shell/atuin.xsh create mode 100644 crates/turtle/src/shell/atuin.zsh create mode 100644 crates/turtle/src/sync.rs diff --git a/Cargo.lock b/Cargo.lock index d10c4402..cfe35268 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -189,86 +189,65 @@ name = "atuin" version = "18.16.1" dependencies = [ "arboard", + "argon2", "async-trait", - "atuin-client", - "atuin-common", - "atuin-daemon", - "atuin-history", + "atuin-nucleo", "atuin-nucleo-matcher", - "atuin-pty-proxy", - "atuin-server", - "atuin-server-database", - "atuin-server-postgres", + "axum", + "base64", "clap", "clap_complete", "clap_complete_nushell", "colored", - "crossterm", - "daemonize", - "eyre", - "fs-err", - "fs4", - "futures-util", - "fuzzy-matcher", - "indicatif", - "interim", - "itertools", - "log", - "norm", - "open", - "ratatui", - "regex", - "rpassword", - "runtime-format", - "rustix", - "semver", - "serde", - "serde_json", - "shlex", - "sysinfo", - "tempfile", - "time", - "tokio", - "toml_edit", - "tracing", - "tracing-appender", - "tracing-subscriber", - "tracing-tree", - "unicode-width 0.2.2", - "uuid", -] - -[[package]] -name = "atuin-client" -version = "18.16.1" -dependencies = [ - "async-trait", - "atuin-common", - "base64", - "clap", "config", "crossterm", "crypto_secretbox", + "daemonize", + "dashmap", "directories", + "divan", "eyre", "fs-err", + "fs4", "futures", + "futures-util", + "fuzzy-matcher", "generic-array", + "getrandom 0.2.17", + "glob-match", "hex", "humantime", + "hyper-util", + "imara-diff", "indicatif", "interim", "itertools", + "lasso", + "listenfd", "log", "memchr", + "metrics", + "metrics-exporter-prometheus", + "minijinja", "minspan", + "norm", "notify", + "open", "palette", + "portable-pty", "pretty_assertions", + "prost", + "prost-types", + "protox", "rand 0.8.5", + "ratatui", "regex", "reqwest", "rmp", + "rpassword", + "runtime-format", + "rustix", + "rustls", "rusty_paserk", "rusty_paseto", "semver", @@ -278,83 +257,40 @@ dependencies = [ "serde_with", "sha2", "shellexpand", + "shlex", + "signal-hook", "sql-builder", "sqlx", "strum", "strum_macros", - "testing_logger", - "thiserror 2.0.18", - "time", - "tokio", - "typed-builder", - "urlencoding", - "uuid", - "whoami 2.1.1", -] - -[[package]] -name = "atuin-common" -version = "18.16.1" -dependencies = [ - "base64", - "directories", - "eyre", - "getrandom 0.2.17", - "pretty_assertions", - "rustls", - "semver", - "serde", - "sqlx", "sysinfo", - "thiserror 2.0.18", - "time", - "typed-builder", - "uuid", -] - -[[package]] -name = "atuin-daemon" -version = "18.16.1" -dependencies = [ - "atuin-client", - "atuin-common", - "atuin-history", - "atuin-nucleo", - "dashmap", - "eyre", - "hyper-util", - "lasso", - "listenfd", - "prost", - "prost-types", - "protox", - "rand 0.8.5", "tempfile", + "testing_logger", + "thiserror 2.0.18", "time", "tokio", "tokio-stream", + "toml_edit", "tonic", "tonic-build", "tonic-prost", "tonic-prost-build", "tonic-types", "tower", + "tower-http", "tracing", + "tracing-appender", "tracing-subscriber", - "uuid", -] - -[[package]] -name = "atuin-history" -version = "18.16.1" -dependencies = [ - "atuin-client", - "crossterm", - "divan", - "rand 0.8.5", - "serde", - "time", + "tracing-tree", + "typed-builder", "unicode-segmentation", + "unicode-width 0.2.2", + "url", + "urlencoding", + "uuid", + "vt100", + "whoami 2.1.1", + "xxhash-rust", ] [[package]] @@ -384,96 +320,6 @@ dependencies = [ "unicode-segmentation", ] -[[package]] -name = "atuin-pty-proxy" -version = "18.16.1" -dependencies = [ - "clap", - "crossterm", - "eyre", - "portable-pty", - "signal-hook", - "vt100", -] - -[[package]] -name = "atuin-server" -version = "18.16.1" -dependencies = [ - "argon2", - "atuin-common", - "atuin-server-database", - "atuin-server-postgres", - "atuin-server-sqlite", - "axum", - "clap", - "config", - "eyre", - "fs-err", - "metrics", - "metrics-exporter-prometheus", - "rand 0.8.5", - "reqwest", - "semver", - "serde", - "serde_json", - "time", - "tokio", - "tower", - "tower-http", - "tracing", - "tracing-subscriber", -] - -[[package]] -name = "atuin-server-database" -version = "18.16.1" -dependencies = [ - "async-trait", - "atuin-common", - "eyre", - "serde", - "sqlx", - "time", - "tracing", - "url", -] - -[[package]] -name = "atuin-server-postgres" -version = "18.16.1" -dependencies = [ - "async-trait", - "atuin-common", - "atuin-server-database", - "eyre", - "futures-util", - "metrics", - "rand 0.8.5", - "serde", - "sqlx", - "time", - "tracing", - "uuid", -] - -[[package]] -name = "atuin-server-sqlite" -version = "18.16.1" -dependencies = [ - "async-trait", - "atuin-common", - "atuin-server-database", - "eyre", - "futures-util", - "metrics", - "serde", - "sqlx", - "time", - "tracing", - "uuid", -] - [[package]] name = "autocfg" version = "1.5.0" @@ -1694,6 +1540,12 @@ dependencies = [ "wasip3", ] +[[package]] +name = "glob-match" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9985c9503b412198aa4197559e9a318524ebc4519c229bfa05a535828c950b9d" + [[package]] name = "h2" version = "0.4.13" @@ -2063,6 +1915,16 @@ dependencies = [ "icu_properties", ] +[[package]] +name = "imara-diff" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f01d462f766df78ab820dd06f5eb700233c51f0f4c2e846520eaf4ba6aa5c5c" +dependencies = [ + "hashbrown 0.15.5", + "memchr", +] + [[package]] name = "indenter" version = "0.3.4" @@ -2512,6 +2374,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a64a92489e2744ce060c349162be1c5f33c6969234104dbd99ddb5feb08b8c15" +[[package]] +name = "memo-map" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38d1115007560874e373613744c6fba374c17688327a71c1476d1a5954cc857b" + [[package]] name = "memoffset" version = "0.9.1" @@ -2589,6 +2457,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "minijinja" +version = "2.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2929e494b2280e1e18959bb2e121da03347ae896896fdfaceaab43c88a02803f" +dependencies = [ + "memo-map", + "serde", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -6376,6 +6254,12 @@ version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ea6fc2961e4ef194dcbfe56bb845534d0dc8098940c7e5c012a258bfec6701bd" +[[package]] +name = "xxhash-rust" +version = "0.8.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3" + [[package]] name = "yansi" version = "1.0.1" diff --git a/Cargo.toml b/Cargo.toml index cc88e07f..4c87b914 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,12 +1,12 @@ [workspace] members = [ - "crates/*", "crates/atuin-nucleo/matcher", "crates/atuin-nucleo/bench", + "crates/turtle" ] resolver = "2" -exclude = ["ui/backend", "crates/atuin-nucleo/matcher/fuzz"] +exclude = ["crates/atuin-nucleo/matcher/fuzz"] [workspace.package] version = "18.16.1" @@ -18,67 +18,5 @@ repository = "https://github.com/atuinsh/atuin" readme = "README.md" [workspace.dependencies] -async-trait = "0.1.58" -atuin-client = { path = "crates/atuin-client", version = "18.16.1" } -atuin-common = { path = "crates/atuin-common", version = "18.16.1" } -atuin-daemon = { path = "crates/atuin-daemon", version = "18.16.1" } -atuin-history = { path = "crates/atuin-history", version = "18.16.1" } -atuin-server = { path = "crates/atuin-server", version = "18.16.1" } -atuin-server-database = { path = "crates/atuin-server-database", version = "18.16.1" } -atuin-server-postgres = { path = "crates/atuin-server-postgres", version = "18.16.1" } -atuin-server-sqlite = { path = "crates/atuin-server-sqlite", version = "18.16.1" } atuin-nucleo = { path = "crates/atuin-nucleo", version = "0.6.0" } atuin-nucleo-matcher = { path = "crates/atuin-nucleo/matcher", version = "0.3.1" } -base64 = "0.22" -crossterm = "0.29.0" -log = "0.4" -time = { version = "0.3.47", features = [ - "serde-human-readable", - "macros", - "local-offset", -] } -clap = { version = "4.5.7", features = ["derive"] } -config = { version = "0.15.8", default-features = false, features = ["toml"] } -directories = "6.0.0" -eyre = "0.6" -fs-err = "3.1" -interim = { version = "0.2.0", features = ["time_0_3"] } -itertools = "0.14.0" -rand = { version = "0.8.5", features = ["std"] } -semver = "1.0.20" -serde = { version = "1.0.202", features = ["derive"] } -serde_json = "1.0.119" -shellexpand = "3" -tokio = { version = "1", features = ["full"] } -uuid = { version = "1.9", features = ["v4", "v7", "serde"] } -whoami = "2.1.0" -typed-builder = "0.18.2" -pretty_assertions = "1.3.0" -thiserror = "2" -rustix = { version = "1.1.4", features = ["process", "fs"] } -tower = "0.5" -tracing = "0.1" -ratatui = "0.30.0" -sql-builder = "3" -tempfile = { version = "3.19" } -minijinja = "2.9.0" -rustls = { version = "0.23", default-features = false, features = [ - "ring", - "std", - "tls12", -] } -glob-match = "0.2.1" -imara-diff = "0.2" -xxhash-rust = { version = "0.8", features = ["xxh3"] } -vt100 = "0.16" -regex = "1.10.5" -toml_edit = "0.25.4" -tracing-subscriber = { version = "0.3", features = ["ansi", "fmt", "registry", "env-filter", "json"] } -reqwest = { version = "0.13", features = ["json", "rustls-no-provider", "stream"], default-features = false } -sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "time", "postgres", "uuid"] } - -# The profile that 'cargo dist' will build with -[profile.dist] -inherits = "release" -lto = "thin" -strip = "symbols" diff --git a/crates/atuin-client/Cargo.toml b/crates/atuin-client/Cargo.toml deleted file mode 100644 index c6a0f261..00000000 --- a/crates/atuin-client/Cargo.toml +++ /dev/null @@ -1,82 +0,0 @@ -[package] -name = "atuin-client" -edition = "2024" -description = "client library for atuin" - -rust-version = { workspace = true } -version = { workspace = true } -authors = { workspace = true } -license = { workspace = true } -homepage = { workspace = true } -repository = { workspace = true } - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[features] -default = ["sync", "daemon"] -sync = ["urlencoding", "reqwest", "sha2", "hex"] -daemon = [] -check-update = [] - -[dependencies] -atuin-common = { path = "../atuin-common", version = "18.16.1" } - -log = { workspace = true } -base64 = { workspace = true } -time = { workspace = true, features = ["macros", "formatting", "parsing"] } -clap = { workspace = true } -eyre = { workspace = true } -directories = { workspace = true } -uuid = { workspace = true } -whoami = { workspace = true } -interim = { workspace = true } -config = { workspace = true } -serde = { workspace = true } -serde_json = { workspace = true } -humantime = "2.1.0" -async-trait = { workspace = true } -itertools = { workspace = true } -rand = { workspace = true } -shellexpand = { workspace = true } -sqlx = { workspace = true, features = ["sqlite", "regexp"] } -minspan = "0.1.5" -regex = { workspace = true } -serde_regex = "1.1.0" -fs-err = { workspace = true } -sql-builder = { workspace = true } -memchr = "2.7" -rmp = { version = "0.8.14" } -typed-builder = { workspace = true } -tokio = { workspace = true } -semver = { workspace = true } -thiserror = { workspace = true } -futures = "0.3" -notify = "7" -crypto_secretbox = "0.1.1" -generic-array = { version = "0.14", features = ["serde"] } -serde_with = "3.8.1" - -# encryption -rusty_paseto = { version = "0.8.0", default-features = false } -rusty_paserk = { version = "0.5.0", default-features = false, features = [ - "v4", - "serde", -] } - -# sync -urlencoding = { version = "2.1.0", optional = true } -reqwest = { workspace = true, optional = true } -hex = { version = "0.4", optional = true } -sha2 = { version = "0.10", optional = true } -indicatif = "0.18.0" - -# theme -crossterm = { workspace = true, features = ["serde"] } -palette = { version = "0.7.5", features = ["serializing"] } -strum_macros = "0.27" -strum = { version = "0.27", features = ["strum_macros"] } - -[dev-dependencies] -tokio = { version = "1", features = ["full"] } -pretty_assertions = { workspace = true } -testing_logger = "0.1.1" diff --git a/crates/atuin-client/config.toml b/crates/atuin-client/config.toml deleted file mode 100644 index 0d0672bf..00000000 --- a/crates/atuin-client/config.toml +++ /dev/null @@ -1,371 +0,0 @@ -## Base directory for Atuin data files (databases, keys, session, etc.) -## All data file paths default to being relative to this directory. -## linux/mac: ~/.local/share/atuin (or XDG_DATA_HOME/atuin) -## windows: %USERPROFILE%/.local/share/atuin -# data_dir = "~/.local/share/atuin" - -## where to store your database, default is your system data directory -## linux/mac: ~/.local/share/atuin/history.db -## windows: %USERPROFILE%/.local/share/atuin/history.db -# db_path = "~/.history.db" - -## where to store your encryption key, default is your system data directory -## linux/mac: ~/.local/share/atuin/key -## windows: %USERPROFILE%/.local/share/atuin/key -# key_path = "~/.key" - -## where to store your auth session token, default is your system data directory -## linux/mac: ~/.local/share/atuin/session -## windows: %USERPROFILE%/.local/share/atuin/session -# session_path = "~/.session" - -## date format used, either "us" or "uk" -# dialect = "us" - -## default timezone to use when displaying time -## either "l", "local" to use the system's current local timezone, or an offset -## from UTC in the format of "<+|->H[H][:M[M][:S[S]]]" -## for example: "+9", "-05", "+03:30", "-01:23:45", etc. -# timezone = "local" - -## enable or disable automatic sync -# auto_sync = true - -## enable or disable automatic update checks -# update_check = true - -## address of the sync server -# sync_address = "https://api.atuin.sh" - -## how often to sync history. note that this is only triggered when a command -## is ran, so sync intervals may well be longer -## set it to 0 to sync after every command -# sync_frequency = "10m" - -## which search mode to use -## possible values: prefix, fulltext, fuzzy, skim -# search_mode = "fuzzy" - -## which filter mode to use by default -## possible values: "global", "host", "session", "session-preload", "directory", "workspace" -## consider using search.filters to customize the enablement and order of filter modes -# filter_mode = "global" - -## With workspace filtering enabled, Atuin will filter for commands executed -## in any directory within a git repository tree (default: false). -## -## To use workspace mode by default when available, set this to true and -## set filter_mode to "workspace" or leave it unspecified and -## set search.filters to include "workspace" before other filter modes. -# workspaces = false - -## which filter mode to use when atuin is invoked from a shell up-key binding -## the accepted values are identical to those of "filter_mode" -## leave unspecified to use same mode set in "filter_mode" -# filter_mode_shell_up_key_binding = "global" - -## which search mode to use when atuin is invoked from a shell up-key binding -## the accepted values are identical to those of "search_mode" -## leave unspecified to use same mode set in "search_mode" -# search_mode_shell_up_key_binding = "fuzzy" - -## which style to use -## possible values: auto, full, compact -# style = "auto" - -## the maximum number of lines the interface should take up -## set it to 0 to always go full screen -# inline_height = 0 - -## the maximum number of lines the interface should take up -## when atuin is invoked from a shell up-key binding -## the accepted values are identical to those of "inline_height" -# inline_height_shell_up_key_binding = 0 - -## Invert the UI - put the search bar at the top , Default to `false` -# invert = false - -## enable or disable showing a preview of the selected command -## useful when the command is longer than the terminal width and is cut off -# show_preview = true - -## what to do when the escape key is pressed when searching -## possible values: return-original, return-query -# exit_mode = "return-original" - -## possible values: emacs, subl -# word_jump_mode = "emacs" - -## characters that count as a part of a word -# word_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - -## number of context lines to show when scrolling by pages -# scroll_context_lines = 1 - -## use ctrl instead of alt as the shortcut modifier key for numerical UI shortcuts -## alt-0 .. alt-9 -# ctrl_n_shortcuts = false - -## Show numeric shortcuts (1..9) beside list items in the TUI -## set to false to hide the moving numbers if you find them distracting -# show_numeric_shortcuts = true - -## default history list format - can also be specified with the --format arg -# history_format = "{time}\t{command}\t{duration}" - -## Defaults to true. If enabled, strip trailing spaces and tabs from commands -## before saving them to history. -## Escaped trailing spaces (for example `printf foo\\ `) are preserved. -# strip_trailing_whitespace = true - -## prevent commands matching any of these regexes from being written to history. -## Note that these regular expressions are unanchored, i.e. if they don't start -## with ^ or end with $, they'll match anywhere in the command. -## For details on the supported regular expression syntax, see -## https://docs.rs/regex/latest/regex/#syntax -# history_filter = [ -# "^secret-cmd", -# "^innocuous-cmd .*--secret=.+", -# ] - -## prevent commands run with cwd matching any of these regexes from being written -## to history. Note that these regular expressions are unanchored, i.e. if they don't -## start with ^ or end with $, they'll match anywhere in CWD. -## For details on the supported regular expression syntax, see -## https://docs.rs/regex/latest/regex/#syntax -# cwd_filter = [ -# "^/very/secret/area", -# ] - -## Configure the maximum height of the preview to show. -## Useful when you have long scripts in your history that you want to distinguish -## by more than the first few lines. -# max_preview_height = 4 - -## Configure whether or not to show the help row, which includes the current Atuin -## version (and whether an update is available), a keymap hint, and the total -## amount of commands in your history. -# show_help = true - -## Configure whether or not to show tabs for search and inspect -# show_tabs = true - -## Configure whether or not the tabs row may be auto-hidden, which includes the current Atuin -## tab, such as Search or Inspector, and other tabs you may wish to see. This will -## only be hidden if there are fewer than this count of lines available, and does not affect the use -## of keyboard shortcuts to switch tab. 0 to never auto-hide, default is 8 (lines). -## This is ignored except in `compact` mode. -# auto_hide_height = 8 - -## Defaults to true. This matches history against a set of default regex, and will not save it if we get a match. Defaults include -## 1. AWS key id -## 2. Github pat (old and new) -## 3. Slack oauth tokens (bot, user) -## 4. Slack webhooks -## 5. Stripe live/test keys -# secrets_filter = true - -## Defaults to true. If enabled, upon hitting enter Atuin will immediately execute the command, -## whereas tab will put the command in the prompt for editing. -## If set to false, both enter and tab will place the command in the prompt for editing. -## This applies for new installs. Old installs will keep the old behaviour unless configured otherwise. -enter_accept = true - -## Defaults to false. If enabled, when triggered after &&, || or |, Atuin will complete commands to chain rather than replace the current line. -# command_chaining = false - -## Defaults to "emacs". This specifies the keymap on the startup of `atuin -## search`. If this is set to "auto", the startup keymap mode in the Atuin -## search is automatically selected based on the shell's keymap where the -## keybinding is defined. If this is set to "emacs", "vim-insert", or -## "vim-normal", the startup keymap mode in the Atuin search is forced to be -## the specified one. -# keymap_mode = "auto" - -## Cursor style in each keymap mode. If specified, the cursor style is changed -## in entering the cursor shape. Available values are "default" and -## "{blink,steady}-{block,underline,bar}". -# keymap_cursor = { emacs = "blink-block", vim_insert = "blink-block", vim_normal = "steady-block" } - -# network_connect_timeout = 5 -# network_timeout = 5 - -## Timeout (in seconds) for acquiring a local database connection (sqlite) -# local_timeout = 5 - -## Set this to true and Atuin will minimize motion in the UI - timers will not update live, etc. -## Alternatively, set env NO_MOTION=true -# prefers_reduced_motion = false - -[stats] -## Set commands where we should consider the subcommand for statistics. Eg, kubectl get vs just kubectl -# common_subcommands = [ -# "apt", -# "cargo", -# "composer", -# "dnf", -# "docker", -# "dotnet", -# "git", -# "go", -# "ip", -# "jj", -# "kubectl", -# "nix", -# "nmcli", -# "npm", -# "pecl", -# "pnpm", -# "podman", -# "port", -# "systemctl", -# "tmux", -# "yarn", -# ] - -## Set commands that should be totally stripped and ignored from stats -# common_prefix = ["sudo"] - -## Set commands that will be completely ignored from stats -# ignored_commands = [ -# "cd", -# "ls", -# "vi" -# ] - -[keys] -# Defaults to true. If disabled, using the up/down key won't exit the TUI when scrolled past the first/last entry. -# scroll_exits = true - -# Defaults to true. The left arrow key will exit the TUI when scrolling before the first character -# exit_past_line_start = true - -# Defaults to true. The right arrow key performs the same functionality as Tab and copies the selected line to the command line to be modified. -# accept_past_line_end = true - -# Defaults to false. The left arrow key performs the same functionality as Tab and copies the selected line to the command line to be modified. -# accept_past_line_start = false - -# Defaults to false. The backspace key performs the same functionality as Tab and copies the selected line to the command line to be modified when at the start of the line. -# accept_with_backspace = false - -[sync] -# Enable sync v2 by default -# This ensures that sync v2 is enabled for new installs only -# In a later release it will become the default across the board -records = true - -[preview] -## which preview strategy to use to calculate the preview height (respects max_preview_height). -## possible values: auto, static -## auto: length of the selected command. -## static: length of the longest command stored in the history. -## fixed: use max_preview_height as fixed height. -# strategy = "auto" - -[daemon] -## Enables using the daemon to sync. -# enabled = false - -## Automatically start and manage the daemon when needed. -## Not compatible with `systemd_socket = true`. -# autostart = false - -## How often the daemon should sync in seconds -# sync_frequency = 300 - -## The path to the unix socket used by the daemon (on unix systems) -## linux/mac: ~/.local/share/atuin/atuin.sock -## windows: Not Supported -# socket_path = "~/.local/share/atuin/atuin.sock" - -## The daemon pidfile used for lifecycle management. -## Defaults to the Atuin data directory. -# pidfile_path = "~/.local/share/atuin/atuin-daemon.pid" - -## Use systemd socket activation rather than opening the given path (the path must still be correct for the client) -## linux: false -## mac/windows: Not Supported -# systemd_socket = false - -## The port that should be used for TCP on non unix systems -# tcp_port = 8889 - -# [theme] -## Color theme to use for rendering in the terminal. -## There are some built-in themes, including the base theme ("default"), -## "autumn" and "marine". You can add your own themes to the "./themes" subdirectory of your -## Atuin config (or ATUIN_THEME_DIR, if provided) as TOML files whose keys should be one or -## more of AlertInfo, AlertWarn, AlertError, Annotation, Base, Guidance, Important, and -## the string values as lowercase entries from this list: -## https://ogeon.github.io/docs/palette/master/palette/named/index.html -## If you provide a custom theme file, it should be called "NAME.toml" and the theme below -## should be the stem, i.e. `theme = "NAME"` for your chosen NAME. -# name = "autumn" - -## Whether the theme manager should output normal or extra information to help fix themes. -## Boolean, true or false. If unset, left up to the theme manager. -# debug = true - -[search] -## The list of enabled filter modes, in order of priority. -## The "workspace" mode is skipped when not in a workspace or workspaces = false. -## Default filter mode can be overridden with the filter_mode setting. -# filters = [ "global", "host", "session", "session-preload", "workspace", "directory" ] - -[tmux] -## Enable using atuin with tmux popup (requires tmux >= 3.2) -## When enabled and running inside tmux, Atuin will use a popup window for interactive search. -## Set to false to disable the popup. -## This can also be controlled with the ATUIN_TMUX_POPUP environment variable. -## Note: The tmux popup is currently supported in zsh, bash, and fish shells. This currently doesn't work with iTerm native tmux integration. -# enabled = false - -## Width of the tmux popup window -## Can be a percentage, or integer (e.g. "100" means 100 characters wide) -# width = "80%" - -## Height of the tmux popup window -## Can be a percentage, or integer (e.g. "100" means 100 lines tall) -# height = "60%" - -[ui] -## Columns to display in the interactive search, from left to right. -## The selection indicator (" > ") is always shown first implicitly. -## -## Each column can be specified as a simple string (uses default width) -## or as an object with type, width, and expand: -## { type = "directory", width = 30, expand = true } -## -## Available column types (with default widths): -## duration (5) - Command execution duration (e.g., "123ms") -## time (8) - Relative time since execution (e.g., "59m ago") -## datetime (16) - Absolute timestamp (e.g., "2025-01-22 14:35") -## directory (20) - Working directory (truncated if too long) -## host (15) - Hostname where command was run -## user (10) - Username -## exit (3) - Exit code (colored by success/failure) -## command (*) - The command itself (expands by default) -## -## The "expand" option (default: true for command, false for others) makes a -## column fill remaining space. Only one column should have expand = true. -## -## Default: -# columns = ["duration", "time", "command"] -## -## Examples: -## -## Minimal - more space for commands: -# columns = ["duration", "command"] -## -## With wider directory column: -# columns = ["duration", { type = "directory", width = 30 }, "command"] -## -## Show host for multi-machine sync users: -# columns = ["duration", "time", "host", "command"] -## -## Show exit codes prominently: -# columns = ["exit", "duration", "command"] -## -## Make directory expand instead of command: -# columns = ["duration", "time", { type = "directory", expand = true }, { type = "command", expand = false }] diff --git a/crates/atuin-client/meta-migrations/20260203030924_create_meta.sql b/crates/atuin-client/meta-migrations/20260203030924_create_meta.sql deleted file mode 100644 index 26c3c142..00000000 --- a/crates/atuin-client/meta-migrations/20260203030924_create_meta.sql +++ /dev/null @@ -1,5 +0,0 @@ -create table if not exists meta ( - key text not null primary key, - value text not null, - updated_at integer not null default (strftime('%s', 'now')) -); diff --git a/crates/atuin-client/migrations/20210422143411_create_history.sql b/crates/atuin-client/migrations/20210422143411_create_history.sql deleted file mode 100644 index 1f3f8686..00000000 --- a/crates/atuin-client/migrations/20210422143411_create_history.sql +++ /dev/null @@ -1,16 +0,0 @@ --- Add migration script here -create table if not exists history ( - id text primary key, - timestamp integer not null, - duration integer not null, - exit integer not null, - command text not null, - cwd text not null, - session text not null, - hostname text not null, - - unique(timestamp, cwd, command) -); - -create index if not exists idx_history_timestamp on history(timestamp); -create index if not exists idx_history_command on history(command); diff --git a/crates/atuin-client/migrations/20220505083406_create-events.sql b/crates/atuin-client/migrations/20220505083406_create-events.sql deleted file mode 100644 index f6cafeba..00000000 --- a/crates/atuin-client/migrations/20220505083406_create-events.sql +++ /dev/null @@ -1,11 +0,0 @@ -create table if not exists events ( - id text primary key, - timestamp integer not null, - hostname text not null, - event_type text not null, - - history_id text not null -); - --- Ensure there is only ever one of each event type per history item -create unique index history_event_idx ON events(event_type, history_id); diff --git a/crates/atuin-client/migrations/20220806155627_interactive_search_index.sql b/crates/atuin-client/migrations/20220806155627_interactive_search_index.sql deleted file mode 100644 index b5770e62..00000000 --- a/crates/atuin-client/migrations/20220806155627_interactive_search_index.sql +++ /dev/null @@ -1,6 +0,0 @@ --- Interactive search filters by command then by the max(timestamp) for that --- command. Create an index that covers those -create index if not exists idx_history_command_timestamp on history( - command, - timestamp -); diff --git a/crates/atuin-client/migrations/20230315220114_drop-events.sql b/crates/atuin-client/migrations/20230315220114_drop-events.sql deleted file mode 100644 index fe3cae17..00000000 --- a/crates/atuin-client/migrations/20230315220114_drop-events.sql +++ /dev/null @@ -1,2 +0,0 @@ --- Add migration script here -drop table events; diff --git a/crates/atuin-client/migrations/20230319185725_deleted_at.sql b/crates/atuin-client/migrations/20230319185725_deleted_at.sql deleted file mode 100644 index 6c422abc..00000000 --- a/crates/atuin-client/migrations/20230319185725_deleted_at.sql +++ /dev/null @@ -1,2 +0,0 @@ --- Add migration script here -alter table history add column deleted_at integer; diff --git a/crates/atuin-client/migrations/20260224000100_history_author_intent.sql b/crates/atuin-client/migrations/20260224000100_history_author_intent.sql deleted file mode 100644 index 2bed17e9..00000000 --- a/crates/atuin-client/migrations/20260224000100_history_author_intent.sql +++ /dev/null @@ -1,2 +0,0 @@ -alter table history add column author text; -alter table history add column intent text; diff --git a/crates/atuin-client/record-migrations/20230531212437_create-records.sql b/crates/atuin-client/record-migrations/20230531212437_create-records.sql deleted file mode 100644 index 4f4b304a..00000000 --- a/crates/atuin-client/record-migrations/20230531212437_create-records.sql +++ /dev/null @@ -1,16 +0,0 @@ --- Add migration script here -create table if not exists records ( - id text primary key, - parent text unique, -- null if this is the first one - host text not null, - - timestamp integer not null, - tag text not null, - version text not null, - data blob not null, - cek blob not null -); - -create index host_idx on records (host); -create index tag_idx on records (tag); -create index host_tag_idx on records (host, tag); diff --git a/crates/atuin-client/record-migrations/20231127090831_create-store.sql b/crates/atuin-client/record-migrations/20231127090831_create-store.sql deleted file mode 100644 index 53d78860..00000000 --- a/crates/atuin-client/record-migrations/20231127090831_create-store.sql +++ /dev/null @@ -1,15 +0,0 @@ --- Add migration script here -create table if not exists store ( - id text primary key, -- globally unique ID - - idx integer, -- incrementing integer ID unique per (host, tag) - host text not null, -- references the host row - tag text not null, - - timestamp integer not null, - version text not null, - data blob not null, - cek blob not null -); - -create unique index record_uniq ON store(host, tag, idx); diff --git a/crates/atuin-client/src/api_client.rs b/crates/atuin-client/src/api_client.rs deleted file mode 100644 index ca2fc661..00000000 --- a/crates/atuin-client/src/api_client.rs +++ /dev/null @@ -1,437 +0,0 @@ -use std::collections::HashMap; -use std::env; -use std::time::Duration; - -use eyre::{Result, bail, eyre}; -use reqwest::{ - Response, StatusCode, Url, - header::{AUTHORIZATION, HeaderMap, USER_AGENT}, -}; - -use atuin_common::{ - api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ATUIN_VERSION}, - record::{EncryptedData, HostId, Record, RecordIdx}, - tls::ensure_crypto_provider, -}; -use atuin_common::{ - api::{ - AddHistoryRequest, ChangePasswordRequest, CountResponse, DeleteHistoryRequest, - ErrorResponse, LoginRequest, LoginResponse, MeResponse, RegisterResponse, StatusResponse, - SyncHistoryResponse, - }, - record::RecordStatus, -}; - -use semver::Version; -use time::OffsetDateTime; -use time::format_description::well_known::Rfc3339; - -use crate::{history::History, sync::hash_str, utils::get_host_user}; - -static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION"),); - -/// Authentication token for sync API requests. -/// -/// The sync API supports two authentication methods: -/// - `Bearer`: Hub API tokens (for users authenticated via Atuin Hub) -/// - `Token`: Legacy CLI session tokens (for users registered via CLI or self-hosted) -/// -/// When both are available, Hub tokens are preferred as they provide unified -/// authentication across CLI and Hub features. -#[derive(Debug, Clone)] -pub enum AuthToken { - /// Legacy CLI session token, used with "Token {token}" header - Token(String), -} - -impl AuthToken { - /// Format the token as an Authorization header value - fn to_header_value(&self) -> String { - match self { - AuthToken::Token(token) => format!("Token {token}"), - } - } -} - -pub struct Client<'a> { - sync_addr: &'a str, - client: reqwest::Client, -} - -fn make_url(address: &str, path: &str) -> Result { - // `join()` expects a trailing `/` in order to join paths - // e.g. it treats `http://host:port/subdir` as a file called `subdir` - let address = if address.ends_with("/") { - address - } else { - &format!("{address}/") - }; - - // passing a path with a leading `/` will cause `join()` to replace the entire URL path - let path = path.strip_prefix("/").unwrap_or(path); - - let url = Url::parse(address) - .map(|url| url.join(path))? - .map_err(|_| eyre!("invalid address"))?; - - Ok(url.to_string()) -} - -pub async fn register( - address: &str, - username: &str, - email: &str, - password: &str, -) -> Result { - ensure_crypto_provider(); - let mut map = HashMap::new(); - map.insert("username", username); - map.insert("email", email); - map.insert("password", password); - - let url = make_url(address, &format!("/user/{username}"))?; - let resp = reqwest::get(url).await?; - - if resp.status().is_success() { - bail!("username already in use"); - } - - let url = make_url(address, "/register")?; - let client = reqwest::Client::new(); - let resp = client - .post(url) - .header(USER_AGENT, APP_USER_AGENT) - .header(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION) - .json(&map) - .send() - .await?; - let resp = handle_resp_error(resp).await?; - - if !ensure_version(&resp)? { - bail!("could not register user due to version mismatch"); - } - - let session = resp.json::().await?; - Ok(session) -} - -pub async fn login(address: &str, req: LoginRequest) -> Result { - ensure_crypto_provider(); - let url = make_url(address, "/login")?; - let client = reqwest::Client::new(); - - let resp = client - .post(url) - .header(USER_AGENT, APP_USER_AGENT) - .json(&req) - .send() - .await?; - let resp = handle_resp_error(resp).await?; - - if !ensure_version(&resp)? { - bail!("Could not login due to version mismatch"); - } - - let session = resp.json::().await?; - Ok(session) -} - -pub fn ensure_version(response: &Response) -> Result { - let version = response.headers().get(ATUIN_HEADER_VERSION); - - let version = if let Some(version) = version { - match version.to_str() { - Ok(v) => Version::parse(v), - Err(e) => bail!("failed to parse server version: {:?}", e), - } - } else { - bail!("Server not reporting its version: it is either too old or unhealthy"); - }?; - - // If the client is newer than the server - if version.major < ATUIN_VERSION.major { - println!( - "Atuin version mismatch! In order to successfully sync, the server needs to run a newer version of Atuin" - ); - println!("Client: {ATUIN_CARGO_VERSION}"); - println!("Server: {version}"); - - return Ok(false); - } - - Ok(true) -} - -async fn handle_resp_error(resp: Response) -> Result { - let status = resp.status(); - let url = resp.url().to_string(); - - if status == StatusCode::SERVICE_UNAVAILABLE { - bail!( - "Service unavailable: check https://status.atuin.sh (or get in touch with your host)" - ); - } - - if status == StatusCode::TOO_MANY_REQUESTS { - bail!("Rate limited; please wait before doing that again"); - } - - if !status.is_success() { - if let Ok(error) = resp.json::().await { - let reason = error.reason; - - if status.is_client_error() { - bail!("Invalid request to the service at {url}, {status} - {reason}.") - } - - bail!( - "There was an error with the atuin sync service at {url}, server error {status}: {reason}.\nIf the problem persists, contact the host" - ) - } - - bail!( - "There was an error with the atuin sync service at {url}, Status {status:?}.\nIf the problem persists, contact the host" - ) - } - - Ok(resp) -} - -impl<'a> Client<'a> { - pub fn new( - sync_addr: &'a str, - auth: AuthToken, - connect_timeout: u64, - timeout: u64, - ) -> Result { - ensure_crypto_provider(); - let mut headers = HeaderMap::new(); - headers.insert(AUTHORIZATION, auth.to_header_value().parse()?); - - // used for semver server check - headers.insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse()?); - - Ok(Client { - sync_addr, - client: reqwest::Client::builder() - .user_agent(APP_USER_AGENT) - .default_headers(headers) - .connect_timeout(Duration::new(connect_timeout, 0)) - .timeout(Duration::new(timeout, 0)) - .build()?, - }) - } - - pub async fn count(&self) -> Result { - let url = make_url(self.sync_addr, "/sync/count")?; - let url = Url::parse(url.as_str())?; - - let resp = self.client.get(url).send().await?; - let resp = handle_resp_error(resp).await?; - - if !ensure_version(&resp)? { - bail!("could not sync due to version mismatch"); - } - - if resp.status() != StatusCode::OK { - bail!("failed to get count (are you logged in?)"); - } - - let count = resp.json::().await?; - - Ok(count.count) - } - - pub async fn status(&self) -> Result { - let url = make_url(self.sync_addr, "/sync/status")?; - let url = Url::parse(url.as_str())?; - - let resp = self.client.get(url).send().await?; - let resp = handle_resp_error(resp).await?; - - if !ensure_version(&resp)? { - bail!("could not sync due to version mismatch"); - } - - let status = resp.json::().await?; - - Ok(status) - } - - pub async fn me(&self) -> Result { - let url = make_url(self.sync_addr, "/api/v0/me")?; - let url = Url::parse(url.as_str())?; - - let resp = self.client.get(url).send().await?; - let resp = handle_resp_error(resp).await?; - - let status = resp.json::().await?; - - Ok(status) - } - - pub async fn get_history( - &self, - sync_ts: OffsetDateTime, - history_ts: OffsetDateTime, - host: Option, - ) -> Result { - let host = host.unwrap_or_else(|| hash_str(&get_host_user())); - - let url = make_url( - self.sync_addr, - &format!( - "/sync/history?sync_ts={}&history_ts={}&host={}", - urlencoding::encode(sync_ts.format(&Rfc3339)?.as_str()), - urlencoding::encode(history_ts.format(&Rfc3339)?.as_str()), - host, - ), - )?; - - let resp = self.client.get(url).send().await?; - let resp = handle_resp_error(resp).await?; - - let history = resp.json::().await?; - Ok(history) - } - - pub async fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> { - let url = make_url(self.sync_addr, "/history")?; - let url = Url::parse(url.as_str())?; - - let resp = self.client.post(url).json(history).send().await?; - handle_resp_error(resp).await?; - - Ok(()) - } - - pub async fn delete_history(&self, h: History) -> Result<()> { - let url = make_url(self.sync_addr, "/history")?; - let url = Url::parse(url.as_str())?; - - let resp = self - .client - .delete(url) - .json(&DeleteHistoryRequest { - client_id: h.id.to_string(), - }) - .send() - .await?; - - handle_resp_error(resp).await?; - - Ok(()) - } - - pub async fn delete_store(&self) -> Result<()> { - let url = make_url(self.sync_addr, "/api/v0/store")?; - let url = Url::parse(url.as_str())?; - - let resp = self.client.delete(url).send().await?; - - handle_resp_error(resp).await?; - - Ok(()) - } - - pub async fn post_records(&self, records: &[Record]) -> Result<()> { - let url = make_url(self.sync_addr, "/api/v0/record")?; - let url = Url::parse(url.as_str())?; - - debug!("uploading {} records to {url}", records.len()); - - let resp = self.client.post(url).json(records).send().await?; - handle_resp_error(resp).await?; - - Ok(()) - } - - pub async fn next_records( - &self, - host: HostId, - tag: String, - start: RecordIdx, - count: u64, - ) -> Result>> { - debug!("fetching record/s from host {}/{}/{}", host.0, tag, start); - - let url = make_url( - self.sync_addr, - &format!( - "/api/v0/record/next?host={}&tag={}&count={}&start={}", - host.0, tag, count, start - ), - )?; - - let url = Url::parse(url.as_str())?; - - let resp = self.client.get(url).send().await?; - let resp = handle_resp_error(resp).await?; - - let records = resp.json::>>().await?; - - Ok(records) - } - - pub async fn record_status(&self) -> Result { - let url = make_url(self.sync_addr, "/api/v0/record")?; - let url = Url::parse(url.as_str())?; - - let resp = self.client.get(url).send().await?; - let resp = handle_resp_error(resp).await?; - - if !ensure_version(&resp)? { - bail!("could not sync records due to version mismatch"); - } - - let index = resp.json().await?; - - debug!("got remote index {index:?}"); - - Ok(index) - } - - pub async fn delete(&self) -> Result<()> { - let url = make_url(self.sync_addr, "/account")?; - let url = Url::parse(url.as_str())?; - - let resp = self.client.delete(url).send().await?; - - if resp.status() == 403 { - bail!("invalid login details"); - } else if resp.status() == 200 { - Ok(()) - } else { - bail!("Unknown error"); - } - } - - pub async fn change_password( - &self, - current_password: String, - new_password: String, - ) -> Result<()> { - let url = make_url(self.sync_addr, "/account/password")?; - let url = Url::parse(url.as_str())?; - - let resp = self - .client - .patch(url) - .json(&ChangePasswordRequest { - current_password, - new_password, - }) - .send() - .await?; - - if resp.status() == 401 { - bail!("current password is incorrect") - } else if resp.status() == 403 { - bail!("invalid login details"); - } else if resp.status() == 200 { - Ok(()) - } else { - bail!("Unknown error"); - } - } -} diff --git a/crates/atuin-client/src/auth.rs b/crates/atuin-client/src/auth.rs deleted file mode 100644 index 1031c11f..00000000 --- a/crates/atuin-client/src/auth.rs +++ /dev/null @@ -1,230 +0,0 @@ -use async_trait::async_trait; -use eyre::{Context, Result, bail}; -use reqwest::{Url, header::USER_AGENT}; - -use atuin_common::{ - api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ChangePasswordRequest, LoginRequest}, - tls::ensure_crypto_provider, -}; - -use crate::settings::Settings; - -static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION")); - -/// Result of an auth operation that may require 2FA. -pub enum AuthResponse { - /// Operation succeeded; for login/register, contains the session token. - /// `auth_type` indicates the kind of token: `Some("hub")` for Hub API - /// tokens (prefixed `atapi_`), `Some("cli")` for legacy CLI session - /// tokens. `None` when the server didn't include the field (old servers). - Success { - session: String, - auth_type: Option, - }, - /// Two-factor authentication is required; the caller should prompt for a - /// TOTP code and retry with it. - TwoFactorRequired, -} - -/// Result of a mutating account operation that may require 2FA. -pub enum MutateResponse { - /// Operation completed successfully. - Success, - /// Two-factor authentication is required; the caller should prompt for a - /// TOTP code and retry. - TwoFactorRequired, -} - -/// Abstraction over the legacy (Rust sync server) and Hub auth APIs. -/// -/// CLI commands use this trait so they don't need to know which backend is -/// active — they just prompt for input and call these methods. -#[async_trait] -pub trait AuthClient: Send + Sync { - /// Log in with username + password, optionally providing a TOTP code. - async fn login( - &self, - username: &str, - password: &str, - totp_code: Option<&str>, - ) -> Result; - - /// Register a new account. - async fn register(&self, username: &str, email: &str, password: &str) -> Result; - - /// Change the account password, optionally providing a TOTP code. - async fn change_password( - &self, - current_password: &str, - new_password: &str, - totp_code: Option<&str>, - ) -> Result; - - /// Delete the account, requiring the current password and optionally a TOTP code. - async fn delete_account( - &self, - password: &str, - totp_code: Option<&str>, - ) -> Result; -} - -/// Resolve the appropriate [`AuthClient`] for the current settings. -pub async fn auth_client(settings: &Settings) -> Box { - Box::new(LegacyAuthClient::new( - &settings.sync_address, - settings.session_token().await.ok(), - settings.network_connect_timeout, - settings.network_timeout, - )) as Box -} - -// --------------------------------------------------------------------------- -// Legacy backend — talks to the Rust sync server -// --------------------------------------------------------------------------- - -pub struct LegacyAuthClient { - address: String, - session_token: Option, - connect_timeout: u64, - timeout: u64, -} - -impl LegacyAuthClient { - pub fn new( - address: &str, - session_token: Option, - connect_timeout: u64, - timeout: u64, - ) -> Self { - Self { - address: address.to_string(), - session_token, - connect_timeout, - timeout, - } - } - - fn authenticated_client(&self) -> Result { - let token = self - .session_token - .as_deref() - .ok_or_else(|| eyre::eyre!("Not logged in"))?; - - ensure_crypto_provider(); - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert( - reqwest::header::AUTHORIZATION, - format!("Token {token}").parse()?, - ); - headers.insert(USER_AGENT, APP_USER_AGENT.parse()?); - headers.insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse()?); - - Ok(reqwest::Client::builder() - .default_headers(headers) - .connect_timeout(std::time::Duration::new(self.connect_timeout, 0)) - .timeout(std::time::Duration::new(self.timeout, 0)) - .build()?) - } -} - -#[async_trait] -impl AuthClient for LegacyAuthClient { - async fn login( - &self, - username: &str, - password: &str, - _totp_code: Option<&str>, - ) -> Result { - // The legacy server has no 2FA support; totp_code is ignored. - let resp = crate::api_client::login( - &self.address, - LoginRequest { - username: username.to_string(), - password: password.to_string(), - }, - ) - .await?; - - Ok(AuthResponse::Success { - session: resp.session, - auth_type: resp.auth.or(Some("cli".into())), - }) - } - - async fn register(&self, username: &str, email: &str, password: &str) -> Result { - let resp = crate::api_client::register(&self.address, username, email, password).await?; - Ok(AuthResponse::Success { - session: resp.session, - auth_type: resp.auth.or(Some("cli".into())), - }) - } - - async fn change_password( - &self, - current_password: &str, - new_password: &str, - _totp_code: Option<&str>, - ) -> Result { - let client = self.authenticated_client()?; - let url = make_url(&self.address, "/account/password")?; - - let resp = client - .patch(&url) - .json(&ChangePasswordRequest { - current_password: current_password.to_string(), - new_password: new_password.to_string(), - }) - .send() - .await?; - - match resp.status().as_u16() { - 200 => Ok(MutateResponse::Success), - 401 => bail!("current password is incorrect"), - 403 => bail!("invalid login details"), - _ => bail!("unknown error"), - } - } - - async fn delete_account( - &self, - password: &str, - _totp_code: Option<&str>, - ) -> Result { - let client = self.authenticated_client()?; - let url = make_url(&self.address, "/account")?; - - let resp = client - .delete(&url) - .json(&serde_json::json!({ "password": password })) - .send() - .await?; - - match resp.status().as_u16() { - 200 => Ok(MutateResponse::Success), - 401 => bail!("password is incorrect"), - 403 => bail!("invalid login details"), - _ => bail!("unknown error"), - } - } -} - -// --------------------------------------------------------------------------- -// Shared helpers -// --------------------------------------------------------------------------- - -fn make_url(address: &str, path: &str) -> Result { - let address = if address.ends_with('/') { - address.to_string() - } else { - format!("{address}/") - }; - - let path = path.strip_prefix('/').unwrap_or(path); - - let url = Url::parse(&address) - .context("failed to parse server address")? - .join(path) - .context("failed to join URL path")?; - - Ok(url.to_string()) -} diff --git a/crates/atuin-client/src/database.rs b/crates/atuin-client/src/database.rs deleted file mode 100644 index 946c1eb0..00000000 --- a/crates/atuin-client/src/database.rs +++ /dev/null @@ -1,1525 +0,0 @@ -use std::{ - env, - path::{Path, PathBuf}, - str::FromStr, - time::Duration, -}; - -use crate::history::{AUTHOR_FILTER_ALL_AGENT, AUTHOR_FILTER_ALL_USER, KNOWN_AGENTS}; -use async_trait::async_trait; -use atuin_common::utils; -use fs_err as fs; -use itertools::Itertools; -use rand::{Rng, distributions::Alphanumeric}; -use sql_builder::{SqlBuilder, SqlName, bind::Bind, esc, quote}; -use sqlx::{ - Result, Row, - sqlite::{ - SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow, - SqliteSynchronous, - }, -}; -use time::OffsetDateTime; -use uuid::Uuid; - -use crate::{ - history::{HistoryId, HistoryStats}, - utils::get_host_user, -}; - -use super::{ - history::History, - ordering, - settings::{FilterMode, SearchMode, Settings}, -}; - -#[derive(Clone)] -pub struct Context { - pub session: String, - pub cwd: String, - pub hostname: String, - pub host_id: String, - pub git_root: Option, -} - -#[derive(Default, Clone)] -pub struct OptFilters { - pub exit: Option, - pub exclude_exit: Option, - pub cwd: Option, - pub exclude_cwd: Option, - pub before: Option, - pub after: Option, - pub limit: Option, - pub offset: Option, - pub reverse: bool, - pub include_duplicates: bool, - /// Author filter. Supports special values `$all-user` and `$all-agent`. - pub authors: Vec, -} - -pub async fn current_context() -> eyre::Result { - let session = env::var("ATUIN_SESSION").map_err(|_| { - eyre::eyre!("Failed to find $ATUIN_SESSION in the environment. Check that you have correctly set up your shell.") - })?; - let hostname = get_host_user(); - let cwd = utils::get_current_dir(); - let host_id = Settings::host_id().await?; - let git_root = utils::in_git_repo(cwd.as_str()); - - Ok(Context { - session, - hostname, - cwd, - git_root, - host_id: host_id.0.as_simple().to_string(), - }) -} - -impl Context { - pub fn from_history(entry: &History) -> Self { - Context { - session: entry.session.to_string(), - cwd: entry.cwd.to_string(), - hostname: entry.hostname.to_string(), - host_id: String::new(), - git_root: utils::in_git_repo(entry.cwd.as_str()), - } - } -} - -/// Each entry is OR'd: `$all-user` → NOT IN agents, `$all-agent` → IN agents, literal → exact match. -fn apply_author_filter(sql: &mut SqlBuilder, authors: &[String]) { - let mut conditions: Vec = Vec::new(); - let agent_list: String = KNOWN_AGENTS.iter().map(quote).join(", "); - let author_expr = "CASE \ - WHEN author IS NULL OR trim(author) = '' THEN \ - CASE \ - WHEN instr(hostname, ':') > 0 THEN substr(hostname, instr(hostname, ':') + 1) \ - ELSE hostname \ - END \ - ELSE author \ - END"; - - for author in authors { - match author.as_str() { - AUTHOR_FILTER_ALL_USER => { - conditions.push(format!("{author_expr} NOT IN ({agent_list})")); - } - AUTHOR_FILTER_ALL_AGENT => { - conditions.push(format!("{author_expr} IN ({agent_list})")); - } - literal => { - conditions.push(format!("{author_expr} = {}", quote(literal))); - } - } - } - - if !conditions.is_empty() { - sql.and_where(format!("({})", conditions.join(" OR "))); - } -} - -fn get_session_start_time(session_id: &str) -> Option { - if let Ok(uuid) = Uuid::parse_str(session_id) - && let Some(timestamp) = uuid.get_timestamp() - { - let (seconds, nanos) = timestamp.to_unix(); - return Some(seconds as i64 * 1_000_000_000 + nanos as i64); - } - None -} - -#[async_trait] -pub trait Database: Send + Sync + 'static { - async fn save(&self, h: &History) -> Result<()>; - async fn save_bulk(&self, h: &[History]) -> Result<()>; - - async fn load(&self, id: &str) -> Result>; - async fn list( - &self, - filters: &[FilterMode], - context: &Context, - max: Option, - unique: bool, - include_deleted: bool, - ) -> Result>; - async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result>; - - async fn update(&self, h: &History) -> Result<()>; - async fn history_count(&self, include_deleted: bool) -> Result; - - async fn last(&self) -> Result>; - async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result>; - - async fn delete(&self, h: History) -> Result<()>; - async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()>; - async fn deleted(&self) -> Result>; - - // Yes I know, it's a lot. - // Could maybe break it down to a searchparams struct or smth but that feels a little... pointless. - // Been debating maybe a DSL for search? eg "before:time limit:1 the query" - #[expect(clippy::too_many_arguments)] - async fn search( - &self, - search_mode: SearchMode, - filter: FilterMode, - context: &Context, - query: &str, - filter_options: OptFilters, - ) -> Result>; - - async fn query_history(&self, query: &str) -> Result>; - - async fn all_with_count(&self) -> Result>; - - fn all_paged(&self, page_size: usize, include_deleted: bool, unique: bool) -> Paged; - - async fn stats(&self, h: &History) -> Result; - - async fn get_dups(&self, before: i64, dupkeep: u32) -> Result>; - - fn clone_boxed(&self) -> Box; -} - -// Intended for use on a developer machine and not a sync server. -// TODO: implement IntoIterator -#[derive(Debug, Clone)] -pub struct Sqlite { - pub pool: SqlitePool, -} - -impl Sqlite { - pub async fn new(path: impl AsRef, timeout: f64) -> Result { - let path = path.as_ref(); - debug!("opening sqlite database at {path:?}"); - - if utils::broken_symlink(path) { - eprintln!( - "Atuin: Sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement." - ); - std::process::exit(1); - } - - if !path.exists() - && let Some(dir) = path.parent() - { - fs::create_dir_all(dir)?; - } - - let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())? - .journal_mode(SqliteJournalMode::Wal) - .optimize_on_close(true, None) - .synchronous(SqliteSynchronous::Normal) - .with_regexp() - .create_if_missing(true); - - let pool = SqlitePoolOptions::new() - .acquire_timeout(Duration::from_secs_f64(timeout)) - .connect_with(opts) - .await?; - - Self::setup_db(&pool).await?; - Ok(Self { pool }) - } - - pub async fn sqlite_version(&self) -> Result { - sqlx::query_scalar("SELECT sqlite_version()") - .fetch_one(&self.pool) - .await - } - - async fn setup_db(pool: &SqlitePool) -> Result<()> { - debug!("running sqlite database setup"); - - sqlx::migrate!("./migrations").run(pool).await?; - - Ok(()) - } - - async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, h: &History) -> Result<()> { - sqlx::query( - "insert or ignore into history(id, timestamp, duration, exit, command, cwd, session, hostname, author, intent, deleted_at) - values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)", - ) - .bind(h.id.0.as_str()) - .bind(h.timestamp.unix_timestamp_nanos() as i64) - .bind(h.duration) - .bind(h.exit) - .bind(h.command.as_str()) - .bind(h.cwd.as_str()) - .bind(h.session.as_str()) - .bind(h.hostname.as_str()) - .bind(h.author.as_str()) - .bind(h.intent.as_deref()) - .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64)) - .execute(&mut **tx) - .await?; - - Ok(()) - } - - async fn delete_row_raw( - tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, - id: HistoryId, - ) -> Result<()> { - sqlx::query("delete from history where id = ?1") - .bind(id.0.as_str()) - .execute(&mut **tx) - .await?; - - Ok(()) - } - - fn query_history(row: SqliteRow) -> History { - let deleted_at: Option = row.get("deleted_at"); - let hostname: String = row.get("hostname"); - let author: Option = row.try_get("author").ok().flatten(); - let author = author - .filter(|author| !author.trim().is_empty()) - .unwrap_or_else(|| History::author_from_hostname(hostname.as_str())); - let intent: Option = row.try_get("intent").ok().flatten(); - let intent = intent.filter(|intent| !intent.trim().is_empty()); - - History::from_db() - .id(row.get("id")) - .timestamp( - OffsetDateTime::from_unix_timestamp_nanos(row.get::("timestamp") as i128) - .unwrap(), - ) - .duration(row.get("duration")) - .exit(row.get("exit")) - .command(row.get("command")) - .cwd(row.get("cwd")) - .session(row.get("session")) - .hostname(hostname) - .author(author) - .intent(intent) - .deleted_at( - deleted_at.and_then(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128).ok()), - ) - .build() - .into() - } -} - -#[async_trait] -impl Database for Sqlite { - async fn save(&self, h: &History) -> Result<()> { - debug!("saving history to sqlite"); - let mut tx = self.pool.begin().await?; - Self::save_raw(&mut tx, h).await?; - tx.commit().await?; - - Ok(()) - } - - async fn save_bulk(&self, h: &[History]) -> Result<()> { - debug!("saving history to sqlite"); - - let mut tx = self.pool.begin().await?; - - for i in h { - Self::save_raw(&mut tx, i).await?; - } - - tx.commit().await?; - - Ok(()) - } - - async fn load(&self, id: &str) -> Result> { - debug!("loading history item {}", id); - - let res = sqlx::query("select * from history where id = ?1") - .bind(id) - .map(Self::query_history) - .fetch_optional(&self.pool) - .await?; - - Ok(res) - } - - async fn update(&self, h: &History) -> Result<()> { - debug!("updating sqlite history"); - - sqlx::query( - "update history - set timestamp = ?2, duration = ?3, exit = ?4, command = ?5, cwd = ?6, session = ?7, hostname = ?8, author = ?9, intent = ?10, deleted_at = ?11 - where id = ?1", - ) - .bind(h.id.0.as_str()) - .bind(h.timestamp.unix_timestamp_nanos() as i64) - .bind(h.duration) - .bind(h.exit) - .bind(h.command.as_str()) - .bind(h.cwd.as_str()) - .bind(h.session.as_str()) - .bind(h.hostname.as_str()) - .bind(h.author.as_str()) - .bind(h.intent.as_deref()) - .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64)) - .execute(&self.pool) - .await?; - - Ok(()) - } - - // make a unique list, that only shows the *newest* version of things - async fn list( - &self, - filters: &[FilterMode], - context: &Context, - max: Option, - unique: bool, - include_deleted: bool, - ) -> Result> { - debug!("listing history"); - - let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted()); - query.field("*").order_desc("timestamp"); - if !include_deleted { - query.and_where_is_null("deleted_at"); - } - - let git_root = if let Some(git_root) = context.git_root.clone() { - git_root.to_str().unwrap_or("/").to_string() - } else { - context.cwd.clone() - }; - - let session_start = get_session_start_time(&context.session); - - for filter in filters { - match filter { - FilterMode::Global => &mut query, - FilterMode::Host => query.and_where_eq("hostname", quote(&context.hostname)), - FilterMode::Session => query.and_where_eq("session", quote(&context.session)), - FilterMode::SessionPreload => { - query.and_where_eq("session", quote(&context.session)); - if let Some(session_start) = session_start { - query.or_where_lt("timestamp", session_start); - } - &mut query - } - FilterMode::Directory => query.and_where_eq("cwd", quote(&context.cwd)), - FilterMode::Workspace => query.and_where_like_left("cwd", &git_root), - }; - } - - if unique { - query.group_by("command").having("max(timestamp)"); - } - - if let Some(max) = max { - query.limit(max); - } - - let query = query.sql().expect("bug in list query. please report"); - - let res = sqlx::query(&query) - .map(Self::query_history) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } - - async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result> { - debug!("listing history from {:?} to {:?}", from, to); - - let res = sqlx::query( - "select * from history where timestamp >= ?1 and timestamp <= ?2 order by timestamp asc", - ) - .bind(from.unix_timestamp_nanos() as i64) - .bind(to.unix_timestamp_nanos() as i64) - .map(Self::query_history) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } - - async fn last(&self) -> Result> { - let res = sqlx::query( - "select * from history where duration >= 0 order by timestamp desc limit 1", - ) - .map(Self::query_history) - .fetch_optional(&self.pool) - .await?; - - Ok(res) - } - - async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result> { - let res = sqlx::query( - "select * from history where timestamp < ?1 order by timestamp desc limit ?2", - ) - .bind(timestamp.unix_timestamp_nanos() as i64) - .bind(count) - .map(Self::query_history) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } - - async fn deleted(&self) -> Result> { - let res = sqlx::query("select * from history where deleted_at is not null") - .map(Self::query_history) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } - - async fn history_count(&self, include_deleted: bool) -> Result { - let query = if include_deleted { - "select count(1) from history" - } else { - "select count(1) from history where deleted_at is null" - }; - - let res: (i64,) = sqlx::query_as(query).fetch_one(&self.pool).await?; - Ok(res.0) - } - - async fn search( - &self, - search_mode: SearchMode, - filter: FilterMode, - context: &Context, - query: &str, - filter_options: OptFilters, - ) -> Result> { - let mut sql = SqlBuilder::select_from("history"); - - if !filter_options.include_duplicates { - sql.group_by("command").having("max(timestamp)"); - } - - if let Some(limit) = filter_options.limit { - sql.limit(limit); - } - - if let Some(offset) = filter_options.offset { - sql.offset(offset); - } - - if filter_options.reverse { - sql.order_asc("timestamp"); - } else { - sql.order_desc("timestamp"); - } - - let git_root = if let Some(git_root) = context.git_root.clone() { - git_root.to_str().unwrap_or("/").to_string() - } else { - context.cwd.clone() - }; - - let session_start = get_session_start_time(&context.session); - - match filter { - FilterMode::Global => &mut sql, - FilterMode::Host => { - sql.and_where_eq("lower(hostname)", quote(context.hostname.to_lowercase())) - } - FilterMode::Session => sql.and_where_eq("session", quote(&context.session)), - FilterMode::SessionPreload => { - sql.and_where_eq("session", quote(&context.session)); - if let Some(session_start) = session_start { - sql.or_where_lt("timestamp", session_start); - } - &mut sql - } - FilterMode::Directory => sql.and_where_eq("cwd", quote(&context.cwd)), - FilterMode::Workspace => sql.and_where_like_left("cwd", git_root), - }; - - let orig_query = query; - - let mut regexes = Vec::new(); - match search_mode { - SearchMode::Prefix => sql.and_where_like_left("command", query.replace('*', "%")), - _ => { - let mut is_or = false; - for token in QueryTokenizer::new(query) { - // TODO smart case mode could be made configurable like in fzf - let (is_glob, glob) = if token.has_uppercase() { - (true, "*") - } else { - (false, "%") - }; - let param = match token { - QueryToken::Regex(r) => { - regexes.push(String::from(r)); - continue; - } - QueryToken::Or => { - if !is_or { - is_or = true; - continue; - } else { - format!("{glob}|{glob}") - } - } - QueryToken::MatchStart(term, _) => { - format!("{term}{glob}") - } - QueryToken::MatchEnd(term, _) => { - format!("{glob}{term}") - } - QueryToken::MatchFull(term, _) => { - format!("{glob}{term}{glob}") - } - QueryToken::Match(term, _) => { - if search_mode == SearchMode::FullText { - format!("{glob}{term}{glob}") - } else { - term.split("").join(glob) - } - } - }; - - sql.fuzzy_condition("command", param, token.is_inverse(), is_glob, is_or); - is_or = false; - } - - &mut sql - } - }; - - for regex in regexes { - sql.and_where("command regexp ?".bind(®ex)); - } - - filter_options - .exit - .map(|exit| sql.and_where_eq("exit", exit)); - - filter_options - .exclude_exit - .map(|exclude_exit| sql.and_where_ne("exit", exclude_exit)); - - filter_options - .cwd - .map(|cwd| sql.and_where_eq("cwd", quote(cwd))); - - filter_options - .exclude_cwd - .map(|exclude_cwd| sql.and_where_ne("cwd", quote(exclude_cwd))); - - filter_options.before.map(|before| { - interim::parse_date_string( - before.as_str(), - OffsetDateTime::now_utc(), - interim::Dialect::Uk, - ) - .map(|before| { - sql.and_where_lt("timestamp", quote(before.unix_timestamp_nanos() as i64)) - }) - }); - - filter_options.after.map(|after| { - interim::parse_date_string( - after.as_str(), - OffsetDateTime::now_utc(), - interim::Dialect::Uk, - ) - .map(|after| sql.and_where_gt("timestamp", quote(after.unix_timestamp_nanos() as i64))) - }); - - if !filter_options.authors.is_empty() { - apply_author_filter(&mut sql, &filter_options.authors); - } - - sql.and_where_is_null("deleted_at"); - - let query = sql.sql().expect("bug in search query. please report"); - - let res = sqlx::query(&query) - .map(Self::query_history) - .fetch_all(&self.pool) - .await?; - - Ok(ordering::reorder_fuzzy(search_mode, orig_query, res)) - } - - async fn query_history(&self, query: &str) -> Result> { - let res = sqlx::query(query) - .map(Self::query_history) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } - - async fn all_with_count(&self) -> Result> { - debug!("listing history"); - - let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted()); - - query - .fields(&[ - "id", - "max(timestamp) as timestamp", - "max(duration) as duration", - "exit", - "command", - "deleted_at", - "null as author", - "null as intent", - "group_concat(cwd, ':') as cwd", - "group_concat(session) as session", - "group_concat(hostname, ',') as hostname", - "count(*) as count", - ]) - .group_by("command") - .group_by("exit") - .and_where("deleted_at is null") - .order_desc("timestamp"); - - let query = query.sql().expect("bug in list query. please report"); - - let res = sqlx::query(&query) - .map(|row: SqliteRow| { - let count: i32 = row.get("count"); - (Self::query_history(row), count) - }) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } - - fn all_paged(&self, page_size: usize, include_deleted: bool, unique: bool) -> Paged { - Paged::new(Box::new(self.clone()), page_size, include_deleted, unique) - } - - // deleted_at doesn't mean the actual time that the user deleted it, - // but the time that the system marks it as deleted - async fn delete(&self, mut h: History) -> Result<()> { - let now = OffsetDateTime::now_utc(); - h.command = rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(32) - .map(char::from) - .collect(); // overwrite with random string - h.deleted_at = Some(now); // delete it - - self.update(&h).await?; // save it - - Ok(()) - } - - async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()> { - let mut tx = self.pool.begin().await?; - - for id in ids { - Self::delete_row_raw(&mut tx, id.clone()).await?; - } - - tx.commit().await?; - - Ok(()) - } - - async fn stats(&self, h: &History) -> Result { - // We select the previous in the session by time - let mut prev = SqlBuilder::select_from("history"); - prev.field("*") - .and_where("timestamp < ?1") - .and_where("session = ?2") - .order_by("timestamp", true) - .limit(1); - - let mut next = SqlBuilder::select_from("history"); - next.field("*") - .and_where("timestamp > ?1") - .and_where("session = ?2") - .order_by("timestamp", false) - .limit(1); - - let mut total = SqlBuilder::select_from("history"); - total.field("count(1)").and_where("command = ?1"); - - let mut average = SqlBuilder::select_from("history"); - average.field("avg(duration)").and_where("command = ?1"); - - let mut exits = SqlBuilder::select_from("history"); - exits - .fields(&["exit", "count(1) as count"]) - .and_where("command = ?1") - .group_by("exit"); - - // rewrite the following with sqlbuilder - let mut day_of_week = SqlBuilder::select_from("history"); - day_of_week - .fields(&[ - "strftime('%w', ROUND(timestamp / 1000000000), 'unixepoch') AS day_of_week", - "count(1) as count", - ]) - .and_where("command = ?1") - .group_by("day_of_week"); - - // Intentionally format the string with 01 hardcoded. We want the average runtime for the - // _entire month_, but will later parse it as a datetime for sorting - // Sqlite has no datetime so we cannot do it there, and otherwise sorting will just be a - // string sort, which won't be correct. - let mut duration_over_time = SqlBuilder::select_from("history"); - duration_over_time - .fields(&[ - "strftime('01-%m-%Y', ROUND(timestamp / 1000000000), 'unixepoch') AS month_year", - "avg(duration) as duration", - ]) - .and_where("command = ?1") - .group_by("month_year") - .having("duration > 0"); - - let prev = prev.sql().expect("issue in stats previous query"); - let next = next.sql().expect("issue in stats next query"); - let total = total.sql().expect("issue in stats average query"); - let average = average.sql().expect("issue in stats previous query"); - let exits = exits.sql().expect("issue in stats exits query"); - let day_of_week = day_of_week.sql().expect("issue in stats day of week query"); - let duration_over_time = duration_over_time - .sql() - .expect("issue in stats duration over time query"); - - let prev = sqlx::query(&prev) - .bind(h.timestamp.unix_timestamp_nanos() as i64) - .bind(&h.session) - .map(Self::query_history) - .fetch_optional(&self.pool) - .await?; - - let next = sqlx::query(&next) - .bind(h.timestamp.unix_timestamp_nanos() as i64) - .bind(&h.session) - .map(Self::query_history) - .fetch_optional(&self.pool) - .await?; - - let total: (i64,) = sqlx::query_as(&total) - .bind(&h.command) - .fetch_one(&self.pool) - .await?; - - let average: (f64,) = sqlx::query_as(&average) - .bind(&h.command) - .fetch_one(&self.pool) - .await?; - - let exits: Vec<(i64, i64)> = sqlx::query_as(&exits) - .bind(&h.command) - .fetch_all(&self.pool) - .await?; - - let day_of_week: Vec<(String, i64)> = sqlx::query_as(&day_of_week) - .bind(&h.command) - .fetch_all(&self.pool) - .await?; - - let duration_over_time: Vec<(String, f64)> = sqlx::query_as(&duration_over_time) - .bind(&h.command) - .fetch_all(&self.pool) - .await?; - - let duration_over_time = duration_over_time - .iter() - .map(|f| (f.0.clone(), f.1.round() as i64)) - .collect(); - - Ok(HistoryStats { - next, - previous: prev, - total: total.0 as u64, - average_duration: average.0 as u64, - exits, - day_of_week, - duration_over_time, - }) - } - - async fn get_dups(&self, before: i64, dupkeep: u32) -> Result> { - let res = sqlx::query( - "SELECT * FROM ( - SELECT *, ROW_NUMBER() - OVER (PARTITION BY command, cwd, hostname ORDER BY timestamp DESC) - AS rn - FROM history - ) sub - WHERE rn > ?1 and timestamp < ?2; - ", - ) - .bind(dupkeep) - .bind(before) - .map(Self::query_history) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } - - fn clone_boxed(&self) -> Box { - Box::new(self.clone()) - } -} - -pub struct Paged { - database: Box, - page_size: usize, - last_id: Option, - include_deleted: bool, - unique: bool, -} - -impl Paged { - pub fn new( - database: Box, - page_size: usize, - include_deleted: bool, - unique: bool, - ) -> Self { - Self { - database, - page_size, - last_id: None, - include_deleted, - unique, - } - } - - pub async fn next(&mut self) -> Result>> { - let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted()); - - query.field("*").order_desc("id"); - - if !self.include_deleted { - query.and_where_is_null("deleted_at"); - } - - if self.unique { - // We want to deduplicate on command, but the user can search via cwd, hostname, and session. - // Without those fields, filter modes won't work right. With those fields, we get duplicates. - // This must be handled upstream. - query - .group_by("command, cwd, hostname, session") - .having("max(timestamp)"); - } - - query.limit(self.page_size); - - if let Some(last_id) = &self.last_id { - query.and_where_lt("id", quote(last_id)); - } - - let query = query.sql().expect("bug in list query. please report"); - let res = self.database.query_history(&query).await?; - - if res.is_empty() { - Ok(None) - } else { - self.last_id = Some(res.last().unwrap().id.0.clone()); - Ok(Some(res)) - } - } -} - -trait SqlBuilderExt { - fn fuzzy_condition( - &mut self, - field: S, - mask: T, - inverse: bool, - glob: bool, - is_or: bool, - ) -> &mut Self; -} - -impl SqlBuilderExt for SqlBuilder { - /// adapted from the sql-builder *like functions - fn fuzzy_condition( - &mut self, - field: S, - mask: T, - inverse: bool, - glob: bool, - is_or: bool, - ) -> &mut Self { - let mut cond = field.to_string(); - if inverse { - cond.push_str(" NOT"); - } - if glob { - cond.push_str(" GLOB '"); - } else { - cond.push_str(" LIKE '"); - } - cond.push_str(&esc(mask.to_string())); - cond.push('\''); - if is_or { - self.or_where(cond) - } else { - self.and_where(cond) - } - } -} - -#[cfg(test)] -mod test { - use crate::settings::test_local_timeout; - - use super::*; - use std::time::{Duration, Instant}; - - async fn assert_search_eq( - db: &impl Database, - mode: SearchMode, - filter_mode: FilterMode, - query: &str, - expected: usize, - ) -> Result> { - let context = Context { - hostname: "test:host".to_string(), - session: "beepboopiamasession".to_string(), - cwd: "/home/ellie".to_string(), - host_id: "test-host".to_string(), - git_root: None, - }; - - let results = db - .search( - mode, - filter_mode, - &context, - query, - OptFilters { - ..Default::default() - }, - ) - .await?; - - assert_eq!( - results.len(), - expected, - "query \"{}\", commands: {:?}", - query, - results.iter().map(|a| &a.command).collect::>() - ); - Ok(results) - } - - async fn assert_search_commands( - db: &impl Database, - mode: SearchMode, - filter_mode: FilterMode, - query: &str, - expected_commands: Vec<&str>, - ) { - let results = assert_search_eq(db, mode, filter_mode, query, expected_commands.len()) - .await - .unwrap(); - let commands: Vec<&str> = results.iter().map(|a| a.command.as_str()).collect(); - assert_eq!(commands, expected_commands); - } - - async fn new_history_item(db: &mut impl Database, cmd: &str) -> Result<()> { - let mut captured: History = History::capture() - .timestamp(OffsetDateTime::now_utc()) - .command(cmd) - .cwd("/home/ellie") - .build() - .into(); - - captured.exit = 0; - captured.duration = 1; - captured.session = "beep boop".to_string(); - captured.hostname = "booop".to_string(); - - db.save(&captured).await - } - - #[tokio::test(flavor = "multi_thread")] - async fn test_search_prefix() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) - .await - .unwrap(); - new_history_item(&mut db, "ls /home/ellie").await.unwrap(); - - assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls", 1) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "/home", 0) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls ", 0) - .await - .unwrap(); - } - - #[tokio::test(flavor = "multi_thread")] - async fn test_search_fulltext() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) - .await - .unwrap(); - new_history_item(&mut db, "ls /home/ellie").await.unwrap(); - - assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls", 1) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home", 1) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls ho", 1) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "hm", 0) - .await - .unwrap(); - - // regex - assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r/^ls ", 1) - .await - .unwrap(); - assert_search_eq( - &db, - SearchMode::FullText, - FilterMode::Global, - "r/ls / ie$", - 1, - ) - .await - .unwrap(); - assert_search_eq( - &db, - SearchMode::FullText, - FilterMode::Global, - "r/ls / !ie", - 0, - ) - .await - .unwrap(); - assert_search_eq( - &db, - SearchMode::FullText, - FilterMode::Global, - "meow r/ls/", - 0, - ) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r//hom/", 1) - .await - .unwrap(); - assert_search_eq( - &db, - SearchMode::FullText, - FilterMode::Global, - "r//home//", - 1, - ) - .await - .unwrap(); - assert_search_eq( - &db, - SearchMode::FullText, - FilterMode::Global, - "r//home///", - 0, - ) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home.*e", 0) - .await - .unwrap(); - assert_search_eq( - &db, - SearchMode::FullText, - FilterMode::Global, - "r/home.*e", - 1, - ) - .await - .unwrap(); - } - - #[tokio::test(flavor = "multi_thread")] - async fn test_search_fuzzy() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) - .await - .unwrap(); - new_history_item(&mut db, "ls /home/ellie").await.unwrap(); - new_history_item(&mut db, "ls /home/frank").await.unwrap(); - new_history_item(&mut db, "cd /home/Ellie").await.unwrap(); - new_history_item(&mut db, "/home/ellie/.bin/rustup") - .await - .unwrap(); - - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls /", 3) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls/", 2) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "l/h/", 2) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e", 3) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/hmoe/", 0) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie/home", 0) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "lsellie", 1) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, " ", 4) - .await - .unwrap(); - - // single term operators - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls", 2) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "'ls", 2) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie$", 2) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!^ls", 2) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie", 1) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie$", 2) - .await - .unwrap(); - - // multiple terms - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls !ellie", 1) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls !e$", 1) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "home !^ls", 2) - .await - .unwrap(); - assert_search_eq( - &db, - SearchMode::Fuzzy, - FilterMode::Global, - "'frank | 'rustup", - 2, - ) - .await - .unwrap(); - assert_search_eq( - &db, - SearchMode::Fuzzy, - FilterMode::Global, - "'frank | 'rustup 'ls", - 1, - ) - .await - .unwrap(); - - // case matching - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "Ellie", 1) - .await - .unwrap(); - - // regex - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/^ls ", 2) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/[Ee]llie", 3) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e r/^ls ", 1) - .await - .unwrap(); - } - - #[tokio::test(flavor = "multi_thread")] - async fn test_search_reordered_fuzzy() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) - .await - .unwrap(); - // test ordering of results: we should choose the first, even though it happened longer ago. - - new_history_item(&mut db, "curl").await.unwrap(); - new_history_item(&mut db, "corburl").await.unwrap(); - - // if fuzzy reordering is on, it should come back in a more sensible order - assert_search_commands( - &db, - SearchMode::Fuzzy, - FilterMode::Global, - "curl", - vec!["curl", "corburl"], - ) - .await; - - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "xxxx", 0) - .await - .unwrap(); - assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "", 2) - .await - .unwrap(); - } - - #[tokio::test(flavor = "multi_thread")] - async fn test_paged_basic() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) - .await - .unwrap(); - - // Add 5 history items - for i in 0..5 { - new_history_item(&mut db, &format!("command{}", i)) - .await - .unwrap(); - } - - // Create a paged iterator with page_size of 2 - let mut paged = db.all_paged(2, false, false); - - // First page should have 2 items - let page1 = paged.next().await.unwrap(); - assert!(page1.is_some()); - assert_eq!(page1.unwrap().len(), 2); - - // Second page should have 2 items - let page2 = paged.next().await.unwrap(); - assert!(page2.is_some()); - assert_eq!(page2.unwrap().len(), 2); - - // Third page should have 1 item - let page3 = paged.next().await.unwrap(); - assert!(page3.is_some()); - assert_eq!(page3.unwrap().len(), 1); - - // Fourth page should be None (exhausted) - let page4 = paged.next().await.unwrap(); - assert!(page4.is_none()); - } - - #[tokio::test(flavor = "multi_thread")] - async fn test_paged_empty() { - let db = Sqlite::new("sqlite::memory:", test_local_timeout()) - .await - .unwrap(); - - // Create a paged iterator on empty database - let mut paged = db.all_paged(10, false, false); - - // Should return None immediately - let page = paged.next().await.unwrap(); - assert!(page.is_none()); - } - - #[tokio::test(flavor = "multi_thread")] - async fn test_paged_unique() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) - .await - .unwrap(); - - // Add duplicate commands - new_history_item(&mut db, "duplicate").await.unwrap(); - new_history_item(&mut db, "duplicate").await.unwrap(); - new_history_item(&mut db, "unique1").await.unwrap(); - new_history_item(&mut db, "unique2").await.unwrap(); - - // Without unique flag - should get all 4 - let mut paged = db.all_paged(10, false, false); - let page = paged.next().await.unwrap().unwrap(); - assert_eq!(page.len(), 4); - - // With unique flag - should get 3 (duplicates collapsed) - let mut paged_unique = db.all_paged(10, false, true); - let page_unique = paged_unique.next().await.unwrap().unwrap(); - assert_eq!(page_unique.len(), 3); - } - - #[tokio::test(flavor = "multi_thread")] - async fn test_paged_include_deleted() { - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) - .await - .unwrap(); - - // Add items - new_history_item(&mut db, "keep1").await.unwrap(); - new_history_item(&mut db, "keep2").await.unwrap(); - new_history_item(&mut db, "delete_me").await.unwrap(); - - // Delete one item - let all = db - .list( - &[], - &Context { - hostname: "".to_string(), - session: "".to_string(), - cwd: "".to_string(), - host_id: "".to_string(), - git_root: None, - }, - None, - false, - false, - ) - .await - .unwrap(); - - let to_delete = all - .iter() - .find(|h| h.command == "delete_me") - .unwrap() - .clone(); - db.delete(to_delete).await.unwrap(); - - // Without include_deleted - should get 2 - let mut paged = db.all_paged(10, false, false); - let page = paged.next().await.unwrap().unwrap(); - assert_eq!(page.len(), 2); - - // With include_deleted - should get 3 - let mut paged_deleted = db.all_paged(10, true, false); - let page_deleted = paged_deleted.next().await.unwrap().unwrap(); - assert_eq!(page_deleted.len(), 3); - } - - #[tokio::test(flavor = "multi_thread")] - async fn test_search_bench_dupes() { - let context = Context { - hostname: "test:host".to_string(), - session: "beepboopiamasession".to_string(), - cwd: "/home/ellie".to_string(), - host_id: "test-host".to_string(), - git_root: None, - }; - - let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) - .await - .unwrap(); - for _i in 1..10000 { - new_history_item(&mut db, "i am a duplicated command") - .await - .unwrap(); - } - let start = Instant::now(); - let _results = db - .search( - SearchMode::Fuzzy, - FilterMode::Global, - &context, - "", - OptFilters { - ..Default::default() - }, - ) - .await - .unwrap(); - let duration = start.elapsed(); - - assert!(duration < Duration::from_secs(15)); - } -} - -pub struct QueryTokenizer<'a> { - query: &'a str, - last_pos: usize, -} - -pub enum QueryToken<'a> { - Match(&'a str, bool), - MatchStart(&'a str, bool), - MatchEnd(&'a str, bool), - MatchFull(&'a str, bool), - Or, - Regex(&'a str), -} - -impl<'a> QueryToken<'a> { - pub fn has_uppercase(&self) -> bool { - match self { - Self::Match(term, _) - | Self::MatchStart(term, _) - | Self::MatchEnd(term, _) - | Self::MatchFull(term, _) => term.contains(char::is_uppercase), - _ => false, - } - } - - pub fn is_inverse(&self) -> bool { - match self { - Self::Match(_, inv) - | Self::MatchStart(_, inv) - | Self::MatchEnd(_, inv) - | Self::MatchFull(_, inv) => *inv, - _ => false, - } - } -} - -impl<'a> QueryTokenizer<'a> { - pub fn new(query: &'a str) -> Self { - Self { query, last_pos: 0 } - } -} - -impl<'a> Iterator for QueryTokenizer<'a> { - type Item = QueryToken<'a>; - fn next(&mut self) -> Option { - let remaining = &self.query[self.last_pos..]; - if remaining.is_empty() { - return None; - } - - if let Some(remaining) = remaining.strip_prefix("r/") { - let (regex, next_pos) = if let Some(end) = remaining.find("/ ") { - (&remaining[..end], self.last_pos + 2 + end + 2) - } else if let Some(remaining) = remaining.strip_suffix('/') { - (remaining, self.query.len()) - } else { - (remaining, self.query.len()) - }; - self.last_pos = next_pos; - Some(QueryToken::Regex(regex)) - } else { - let (mut part, next_pos) = if let Some(sp) = remaining.find(' ') { - (&remaining[..sp], self.last_pos + sp + 1) - } else { - (remaining, self.query.len()) - }; - self.last_pos = next_pos; - - if part == "|" { - return Some(QueryToken::Or); - } - - let mut is_inverse = false; - if let Some(s) = part.strip_prefix('!') { - part = s; - is_inverse = true; - } - let token = if let Some(s) = part.strip_prefix('^') { - QueryToken::MatchStart(s, is_inverse) - } else if let Some(s) = part.strip_suffix('$') { - QueryToken::MatchEnd(s, is_inverse) - } else if let Some(s) = part.strip_prefix('\'') { - QueryToken::MatchFull(s, is_inverse) - } else { - QueryToken::Match(part, is_inverse) - }; - Some(token) - } - } -} diff --git a/crates/atuin-client/src/distro.rs b/crates/atuin-client/src/distro.rs deleted file mode 100644 index dead8355..00000000 --- a/crates/atuin-client/src/distro.rs +++ /dev/null @@ -1,89 +0,0 @@ -use std::process::Command; - -/// Detect the Linux distribution from the system, -/// using system-specific release files and falling -/// back to lsb_release. -pub fn detect_linux_distribution() -> String { - detect_from_os_release() - .or_else(detect_from_debian_version) - .or_else(detect_from_centos_release) - .or_else(detect_from_redhat_release) - .or_else(detect_from_fedora_release) - .or_else(detect_from_arch_release) - .or_else(detect_from_alpine_release) - .or_else(detect_from_suse_release) - .or_else(detect_from_lsb_release) - .unwrap_or_else(|| "Unknown".to_string()) -} - -fn detect_from_os_release() -> Option { - let content = std::fs::read_to_string("/etc/os-release").ok()?; - - content - .lines() - .find(|l| l.starts_with("PRETTY_NAME=")) - .and_then(|l| l.split_once('=').map(|s| s.1)) - .map(|s| s.trim_matches('"').to_string()) -} - -fn detect_from_debian_version() -> Option { - std::fs::read_to_string("/etc/debian_version") - .ok() - .map(|v| format!("Debian {}", v.trim())) -} - -fn detect_from_centos_release() -> Option { - std::fs::read_to_string("/etc/centos-release") - .ok() - .map(|v| v.trim().to_string()) -} - -fn detect_from_redhat_release() -> Option { - std::fs::read_to_string("/etc/redhat-release") - .ok() - .map(|v| v.trim().to_string()) -} - -fn detect_from_fedora_release() -> Option { - std::fs::read_to_string("/etc/fedora-release") - .ok() - .map(|v| v.trim().to_string()) -} - -fn detect_from_arch_release() -> Option { - std::fs::read_to_string("/etc/arch-release") - .ok() - .filter(|v| !v.trim().is_empty()) - .map(|_| "Arch Linux".to_string()) -} - -fn detect_from_alpine_release() -> Option { - std::fs::read_to_string("/etc/alpine-release") - .ok() - .map(|v| format!("Alpine {}", v.trim())) -} - -fn detect_from_suse_release() -> Option { - std::fs::read_to_string("/etc/SuSE-release") - .ok() - .and_then(|content| content.lines().next().map(|l| l.trim().to_string())) -} - -fn detect_from_lsb_release() -> Option { - let output = Command::new("lsb_release").arg("-a").output().ok()?; - - if !output.status.success() { - return None; - } - - let output = String::from_utf8(output.stdout).ok()?; - linux_distro_from_lsb_release(&output) -} - -fn linux_distro_from_lsb_release(output: &str) -> Option { - output - .lines() - .find(|line| line.starts_with("Description:")) - .and_then(|line| line.split_once(':').map(|s| s.1)) - .map(|s| s.trim().to_string()) -} diff --git a/crates/atuin-client/src/encryption.rs b/crates/atuin-client/src/encryption.rs deleted file mode 100644 index f2032482..00000000 --- a/crates/atuin-client/src/encryption.rs +++ /dev/null @@ -1,440 +0,0 @@ -// The general idea is that we NEVER send cleartext history to the server -// This way the odds of anything private ending up where it should not are -// very low -// The server authenticates via the usual username and password. This has -// nothing to do with the encryption, and is purely authentication! The client -// generates its own secret key, and encrypts all shell history with libsodium's -// secretbox. The data is then sent to the server, where it is stored. All -// clients must share the secret in order to be able to sync, as it is needed -// to decrypt - -use std::{io::prelude::*, path::PathBuf}; - -use base64::prelude::{BASE64_STANDARD, Engine}; -pub use crypto_secretbox::Key; -use crypto_secretbox::{ - AeadCore, AeadInPlace, KeyInit, XSalsa20Poly1305, - aead::{Nonce, OsRng}, -}; -use eyre::{Context, Result, bail, ensure, eyre}; -use fs_err as fs; -use rmp::{Marker, decode::Bytes}; -use serde::{Deserialize, Serialize}; -use time::{OffsetDateTime, format_description::well_known::Rfc3339, macros::format_description}; - -use crate::{history::History, settings::Settings}; - -#[derive(Debug, Serialize, Deserialize)] -pub struct EncryptedHistory { - pub ciphertext: Vec, - pub nonce: Nonce, -} - -pub fn generate_encoded_key() -> Result<(Key, String)> { - let key = XSalsa20Poly1305::generate_key(&mut OsRng); - let encoded = encode_key(&key)?; - - Ok((key, encoded)) -} - -pub fn new_key(settings: &Settings) -> Result { - let path = settings.key_path.as_str(); - let path = PathBuf::from(path); - - if path.exists() { - bail!("key already exists! cannot overwrite"); - } - - let (key, encoded) = generate_encoded_key()?; - - let mut file = fs::File::create(path)?; - file.write_all(encoded.as_bytes())?; - - Ok(key) -} - -// Loads the secret key, will create + save if it doesn't exist -pub fn load_key(settings: &Settings) -> Result { - let path = settings.key_path.as_str(); - - let key = if PathBuf::from(path).exists() { - let key = fs_err::read_to_string(path)?; - decode_key(key)? - } else { - new_key(settings)? - }; - - Ok(key) -} - -pub fn encode_key(key: &Key) -> Result { - let mut buf = vec![]; - rmp::encode::write_array_len(&mut buf, key.len() as u32) - .wrap_err("could not encode key to message pack")?; - for b in key { - rmp::encode::write_uint(&mut buf, *b as u64) - .wrap_err("could not encode key to message pack")?; - } - let buf = BASE64_STANDARD.encode(buf); - - Ok(buf) -} - -pub fn decode_key(key: String) -> Result { - use rmp::decode; - - let buf = BASE64_STANDARD - .decode(key.trim_end()) - .wrap_err("encryption key is not a valid base64 encoding")?; - - // old code wrote the key as a fixed length array of 32 bytes - // new code writes the key with a length prefix - match <[u8; 32]>::try_from(&*buf) { - Ok(key) => Ok(key.into()), - Err(_) => { - let mut bytes = rmp::decode::Bytes::new(&buf); - - match Marker::from_u8(buf[0]) { - Marker::Bin8 => { - let len = decode::read_bin_len(&mut bytes).map_err(|err| eyre!("{err:?}"))?; - ensure!(len == 32, "encryption key is not the correct size"); - let key = <[u8; 32]>::try_from(bytes.remaining_slice()) - .context("could not decode encryption key")?; - Ok(key.into()) - } - Marker::Array16 => { - let len = decode::read_array_len(&mut bytes).map_err(|err| eyre!("{err:?}"))?; - ensure!(len == 32, "encryption key is not the correct size"); - - let mut key = Key::default(); - for i in &mut key { - *i = rmp::decode::read_int(&mut bytes).map_err(|err| eyre!("{err:?}"))?; - } - Ok(key) - } - _ => bail!("could not decode encryption key"), - } - } - } -} - -pub fn encrypt(history: &History, key: &Key) -> Result { - // serialize with msgpack - let mut buf = encode(history)?; - - let nonce = XSalsa20Poly1305::generate_nonce(&mut OsRng); - XSalsa20Poly1305::new(key) - .encrypt_in_place(&nonce, &[], &mut buf) - .map_err(|_| eyre!("could not encrypt"))?; - - Ok(EncryptedHistory { - ciphertext: buf, - nonce, - }) -} - -pub fn decrypt(mut encrypted_history: EncryptedHistory, key: &Key) -> Result { - XSalsa20Poly1305::new(key) - .decrypt_in_place( - &encrypted_history.nonce, - &[], - &mut encrypted_history.ciphertext, - ) - .map_err(|_| eyre!("could not decrypt history"))?; - let plaintext = encrypted_history.ciphertext; - - let history = decode(&plaintext)?; - - Ok(history) -} - -fn format_rfc3339(ts: OffsetDateTime) -> Result { - // horrible hack. chrono AutoSI limits to 0, 3, 6, or 9 decimal places for nanoseconds. - // time does not have this functionality. - static PARTIAL_RFC3339_0: &[time::format_description::FormatItem<'static>] = - format_description!("[year]-[month]-[day]T[hour]:[minute]:[second]Z"); - static PARTIAL_RFC3339_3: &[time::format_description::FormatItem<'static>] = - format_description!("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"); - static PARTIAL_RFC3339_6: &[time::format_description::FormatItem<'static>] = - format_description!("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:6]Z"); - static PARTIAL_RFC3339_9: &[time::format_description::FormatItem<'static>] = - format_description!("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:9]Z"); - - let fmt = match ts.nanosecond() { - 0 => PARTIAL_RFC3339_0, - ns if ns % 1_000_000 == 0 => PARTIAL_RFC3339_3, - ns if ns % 1_000 == 0 => PARTIAL_RFC3339_6, - _ => PARTIAL_RFC3339_9, - }; - - Ok(ts.format(fmt)?) -} - -fn encode(h: &History) -> Result> { - use rmp::encode; - - let mut output = vec![]; - // INFO: ensure this is updated when adding new fields - encode::write_array_len(&mut output, 9)?; - - encode::write_str(&mut output, &h.id.0)?; - encode::write_str(&mut output, &(format_rfc3339(h.timestamp)?))?; - encode::write_sint(&mut output, h.duration)?; - encode::write_sint(&mut output, h.exit)?; - encode::write_str(&mut output, &h.command)?; - encode::write_str(&mut output, &h.cwd)?; - encode::write_str(&mut output, &h.session)?; - encode::write_str(&mut output, &h.hostname)?; - match h.deleted_at { - Some(d) => encode::write_str(&mut output, &format_rfc3339(d)?)?, - None => encode::write_nil(&mut output)?, - } - - Ok(output) -} - -fn decode(bytes: &[u8]) -> Result { - use rmp::decode::{self, DecodeStringError}; - - let mut bytes = Bytes::new(bytes); - - let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; - if nfields < 8 { - bail!("malformed decrypted history") - } - if nfields > 9 { - bail!("cannot decrypt history from a newer version of atuin"); - } - - let bytes = bytes.remaining_slice(); - let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (timestamp, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - - let mut bytes = Bytes::new(bytes); - let duration = decode::read_int(&mut bytes).map_err(error_report)?; - let exit = decode::read_int(&mut bytes).map_err(error_report)?; - - let bytes = bytes.remaining_slice(); - let (command, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (cwd, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (session, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (hostname, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - - // if we have more fields, try and get the deleted_at - let mut deleted_at = None; - let mut bytes = bytes; - if nfields > 8 { - bytes = match decode::read_str_from_slice(bytes) { - Ok((d, b)) => { - deleted_at = Some(d); - b - } - // we accept null here - Err(DecodeStringError::TypeMismatch(Marker::Null)) => { - // consume the null marker - let mut c = Bytes::new(bytes); - decode::read_nil(&mut c).map_err(error_report)?; - c.remaining_slice() - } - Err(err) => return Err(error_report(err)), - }; - } - - if !bytes.is_empty() { - bail!("trailing bytes in encoded history. malformed") - } - - Ok(History { - id: id.to_owned().into(), - timestamp: OffsetDateTime::parse(timestamp, &Rfc3339)?, - duration, - exit, - command: command.to_owned(), - cwd: cwd.to_owned(), - session: session.to_owned(), - hostname: hostname.to_owned(), - author: History::author_from_hostname(hostname), - intent: None, - deleted_at: deleted_at - .map(|t| OffsetDateTime::parse(t, &Rfc3339)) - .transpose()?, - }) -} - -fn error_report(err: E) -> eyre::Report { - eyre!("{err:?}") -} - -#[cfg(test)] -mod test { - use crypto_secretbox::{KeyInit, XSalsa20Poly1305, aead::OsRng}; - use pretty_assertions::assert_eq; - use time::{OffsetDateTime, macros::datetime}; - - use crate::history::History; - - use super::{decode, decrypt, encode, encrypt}; - - #[test] - fn test_encrypt_decrypt() { - let key1 = XSalsa20Poly1305::generate_key(&mut OsRng); - let key2 = XSalsa20Poly1305::generate_key(&mut OsRng); - - let history = History::from_db() - .id("1".into()) - .timestamp(OffsetDateTime::now_utc()) - .command("ls".into()) - .cwd("/home/ellie".into()) - .exit(0) - .duration(1) - .session("beep boop".into()) - .hostname("booop".into()) - .author("booop".into()) - .intent(None) - .deleted_at(None) - .build() - .into(); - - let e1 = encrypt(&history, &key1).unwrap(); - let e2 = encrypt(&history, &key2).unwrap(); - - assert_ne!(e1.ciphertext, e2.ciphertext); - assert_ne!(e1.nonce, e2.nonce); - - // test decryption works - // this should pass - match decrypt(e1, &key1) { - Err(e) => panic!("failed to decrypt, got {e}"), - Ok(h) => assert_eq!(h, history), - }; - - // this should err - let _ = decrypt(e2, &key1).expect_err("expected an error decrypting with invalid key"); - } - - #[test] - fn test_decode() { - let bytes = [ - 0x99, 0xD9, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55, 53, 51, 56, - 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 187, 50, 48, 50, 51, 45, - 48, 53, 45, 50, 56, 84, 49, 56, 58, 51, 53, 58, 52, 48, 46, 54, 51, 51, 56, 55, 50, 90, - 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, 97, 116, 117, 115, 217, 42, - 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, - 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, 47, 99, 111, 100, 101, 47, 97, - 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, 51, 48, 54, 102, 50, 55, 52, 52, - 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, 102, 57, 52, 53, 55, 187, 102, - 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, 99, 111, 110, 114, 97, 100, 46, - 108, 117, 100, 103, 97, 116, 101, 192, - ]; - let history = History { - id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), - timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), - duration: 49206000, - exit: 0, - command: "git status".to_owned(), - cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), - session: "b97d9a306f274473a203d2eba41f9457".to_owned(), - hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), - author: "conrad.ludgate".to_owned(), - intent: None, - deleted_at: None, - }; - - let h = decode(&bytes).unwrap(); - assert_eq!(history, h); - - let b = encode(&h).unwrap(); - assert_eq!(&bytes, &*b); - } - - #[test] - fn test_decode_deleted() { - let history = History { - id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), - timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), - duration: 49206000, - exit: 0, - command: "git status".to_owned(), - cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), - session: "b97d9a306f274473a203d2eba41f9457".to_owned(), - hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), - author: "conrad.ludgate".to_owned(), - intent: None, - deleted_at: Some(datetime!(2023-05-28 18:35:40.633872 +00:00)), - }; - - let b = encode(&history).unwrap(); - let h = decode(&b).unwrap(); - assert_eq!(history, h); - } - - #[test] - fn test_decode_old() { - let bytes = [ - 0x98, 0xD9, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55, 53, 51, 56, - 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 187, 50, 48, 50, 51, 45, - 48, 53, 45, 50, 56, 84, 49, 56, 58, 51, 53, 58, 52, 48, 46, 54, 51, 51, 56, 55, 50, 90, - 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, 97, 116, 117, 115, 217, 42, - 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, - 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, 47, 99, 111, 100, 101, 47, 97, - 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, 51, 48, 54, 102, 50, 55, 52, 52, - 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, 102, 57, 52, 53, 55, 187, 102, - 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, 99, 111, 110, 114, 97, 100, 46, - 108, 117, 100, 103, 97, 116, 101, - ]; - let history = History { - id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), - timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), - duration: 49206000, - exit: 0, - command: "git status".to_owned(), - cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), - session: "b97d9a306f274473a203d2eba41f9457".to_owned(), - hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), - author: "conrad.ludgate".to_owned(), - intent: None, - deleted_at: None, - }; - - let h = decode(&bytes).unwrap(); - assert_eq!(history, h); - } - - #[test] - fn key_encodings() { - use super::{Key, decode_key, encode_key}; - - // a history of our key encodings. - // v11.0.0 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== - // v12.0.0 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== - // v13.0.0 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== - // v13.0.1 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== - // v14.0.0 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== - // v14.0.1 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== - // c7d89c1 3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q== (https://github.com/ellie/atuin/pull/805) - // b53ca35 3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q== (https://github.com/ellie/atuin/pull/974) - // v15.0.0 3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q== - // b8b57c8 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== (https://github.com/ellie/atuin/pull/1057) - // 8c94d79 3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q== (https://github.com/ellie/atuin/pull/1089) - - let key = Key::from([ - 27, 91, 42, 91, 210, 107, 9, 216, 170, 190, 242, 62, 6, 84, 69, 148, 148, 53, 251, 117, - 226, 167, 173, 52, 82, 34, 138, 110, 169, 124, 92, 229, - ]); - - assert_eq!( - encode_key(&key).unwrap(), - "3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q==" - ); - - // key encodings we have to support - let valid_encodings = [ - "xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q==", - "3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q==", - ]; - - for k in valid_encodings { - assert_eq!(decode_key(k.to_owned()).expect(k), key); - } - } -} diff --git a/crates/atuin-client/src/history.rs b/crates/atuin-client/src/history.rs deleted file mode 100644 index aa0d84d5..00000000 --- a/crates/atuin-client/src/history.rs +++ /dev/null @@ -1,756 +0,0 @@ -use core::fmt::Formatter; -use rmp::decode::DecodeStringError; -use rmp::decode::ValueReadError; -use rmp::{Marker, decode::Bytes}; -use std::env; -use std::fmt::Display; - -use atuin_common::record::DecryptedData; -use atuin_common::utils::uuid_v7; - -use eyre::{Result, bail, eyre}; - -use crate::secrets::SECRET_PATTERNS_RE; -use crate::settings::Settings; -use crate::utils::get_host_user; -use time::OffsetDateTime; - -mod builder; -pub mod store; - -/// Known AI agent author values. Used to expand `$all-agent` and `$all-user` filters. -pub const KNOWN_AGENTS: &[&str] = &["claude-code", "codex", "copilot", "pi"]; -pub const AUTHOR_FILTER_ALL_USER: &str = "$all-user"; -pub const AUTHOR_FILTER_ALL_AGENT: &str = "$all-agent"; - -pub fn is_known_agent(author: &str) -> bool { - KNOWN_AGENTS.contains(&author) -} - -pub fn author_matches_filters(author: &str, filters: &[String]) -> bool { - filters.is_empty() - || filters.iter().any(|filter| match filter.as_str() { - AUTHOR_FILTER_ALL_USER => !is_known_agent(author), - AUTHOR_FILTER_ALL_AGENT => is_known_agent(author), - literal => author == literal, - }) -} - -pub(crate) const HISTORY_VERSION_V0: &str = "v0"; -pub(crate) const HISTORY_VERSION_V1: &str = "v1"; -const HISTORY_RECORD_VERSION_V0: u16 = 0; -const HISTORY_RECORD_VERSION_V1: u16 = 1; -pub(crate) const HISTORY_VERSION: &str = HISTORY_VERSION_V1; -pub const HISTORY_TAG: &str = "history"; -const HISTORY_AUTHOR_ENV: &str = "ATUIN_HISTORY_AUTHOR"; -const HISTORY_INTENT_ENV: &str = "ATUIN_HISTORY_INTENT"; - -#[derive(Clone, Debug, Eq, PartialEq, Hash)] -pub struct HistoryId(pub String); - -impl Display for HistoryId { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -impl From for HistoryId { - fn from(s: String) -> Self { - Self(s) - } -} - -/// Client-side history entry. -/// -/// Client stores data unencrypted, and only encrypts it before sending to the server. -/// -/// To create a new history entry, use one of the builders: -/// - [`History::import()`] to import an entry from the shell history file -/// - [`History::capture()`] to capture an entry via hook -/// - [`History::from_db()`] to create an instance from the database entry -// -// ## Implementation Notes -// -// New fields must be added to `History::{serialize,deserialize}` in a backwards -// compatible way (sensible defaults and careful `nfields` handling). -#[derive(Debug, Clone, PartialEq, Eq, sqlx::FromRow)] -pub struct History { - /// A client-generated ID, used to identify the entry when syncing. - /// - /// Stored as `client_id` in the database. - pub id: HistoryId, - /// When the command was run. - pub timestamp: OffsetDateTime, - /// How long the command took to run. - pub duration: i64, - /// The exit code of the command. - pub exit: i64, - /// The command that was run. - pub command: String, - /// The current working directory when the command was run. - pub cwd: String, - /// The session ID, associated with a terminal session. - pub session: String, - /// The hostname of the machine the command was run on. - pub hostname: String, - /// Who wrote this command (human user or automation/agent identity). - pub author: String, - /// Optional rationale for why the command was executed. - pub intent: Option, - /// Timestamp, which is set when the entry is deleted, allowing a soft delete. - pub deleted_at: Option, -} - -#[derive(Debug, Clone, PartialEq, Eq, sqlx::FromRow)] -pub struct HistoryStats { - /// The command that was ran after this one in the session - pub next: Option, - /// - /// The command that was ran before this one in the session - pub previous: Option, - - /// How many times has this command been ran? - pub total: u64, - - pub average_duration: u64, - - pub exits: Vec<(i64, i64)>, - - pub day_of_week: Vec<(String, i64)>, - - pub duration_over_time: Vec<(String, i64)>, -} - -impl History { - pub(crate) fn author_from_hostname(hostname: &str) -> String { - hostname - .split_once(':') - .map_or_else(|| hostname.to_owned(), |(_, user)| user.to_owned()) - } - - fn normalize_optional_field(field: Option) -> Option { - field.and_then(|value| { - let trimmed = value.trim(); - if trimmed.is_empty() { - None - } else { - Some(trimmed.to_owned()) - } - }) - } - - #[expect(clippy::too_many_arguments)] - fn new( - timestamp: OffsetDateTime, - command: String, - cwd: String, - exit: i64, - duration: i64, - session: Option, - hostname: Option, - author: Option, - intent: Option, - deleted_at: Option, - ) -> Self { - let session = session - .or_else(|| env::var("ATUIN_SESSION").ok()) - .unwrap_or_else(|| uuid_v7().as_simple().to_string()); - let hostname = hostname.unwrap_or_else(get_host_user); - let author = Self::normalize_optional_field(author) - .or_else(|| Self::normalize_optional_field(env::var(HISTORY_AUTHOR_ENV).ok())) - .unwrap_or_else(|| Self::author_from_hostname(hostname.as_str())); - let intent = Self::normalize_optional_field(intent) - .or_else(|| Self::normalize_optional_field(env::var(HISTORY_INTENT_ENV).ok())); - - Self { - id: uuid_v7().as_simple().to_string().into(), - timestamp, - command, - cwd, - exit, - duration, - session, - hostname, - author, - intent, - deleted_at, - } - } - - pub fn serialize(&self) -> Result { - // This is pretty much the same as what we used for the old history, with one difference - - // it uses integers for timestamps rather than a string format. - - use rmp::encode; - - let mut output = vec![]; - - // write the version - encode::write_u16(&mut output, HISTORY_RECORD_VERSION_V1)?; - let include_intent = self.intent.is_some(); - encode::write_array_len(&mut output, 10 + u32::from(include_intent))?; - - encode::write_str(&mut output, &self.id.0)?; - encode::write_u64(&mut output, self.timestamp.unix_timestamp_nanos() as u64)?; - encode::write_sint(&mut output, self.duration)?; - encode::write_sint(&mut output, self.exit)?; - encode::write_str(&mut output, &self.command)?; - encode::write_str(&mut output, &self.cwd)?; - encode::write_str(&mut output, &self.session)?; - encode::write_str(&mut output, &self.hostname)?; - - match self.deleted_at { - Some(d) => encode::write_u64(&mut output, d.unix_timestamp_nanos() as u64)?, - None => encode::write_nil(&mut output)?, - } - - encode::write_str(&mut output, self.author.as_str())?; - if let Some(intent) = &self.intent { - encode::write_str(&mut output, intent.as_str())?; - } - - Ok(DecryptedData(output)) - } - - fn read_optional_string(bytes: &[u8]) -> Result<(Option, &[u8])> { - use rmp::decode; - - fn error_report(err: E) -> eyre::Report { - eyre!("{err:?}") - } - - match decode::read_str_from_slice(bytes) { - Ok((value, bytes)) => Ok((Some(value.to_owned()), bytes)), - Err(DecodeStringError::TypeMismatch(Marker::Null)) => { - let mut cursor = Bytes::new(bytes); - decode::read_nil(&mut cursor).map_err(error_report)?; - - Ok((None, cursor.remaining_slice())) - } - Err(err) => Err(error_report(err)), - } - } - - fn deserialize_v0(bytes: &[u8]) -> Result { - use rmp::decode; - - fn error_report(err: E) -> eyre::Report { - eyre!("{err:?}") - } - - let mut bytes = Bytes::new(bytes); - - let version = decode::read_u16(&mut bytes).map_err(error_report)?; - - if version != HISTORY_RECORD_VERSION_V0 { - bail!("expected decoding v0 record, found v{version}"); - } - - let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; - - if nfields != 9 { - bail!("cannot decrypt history from a different version of Atuin"); - } - - let bytes = bytes.remaining_slice(); - let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - - let mut bytes = Bytes::new(bytes); - let timestamp = decode::read_u64(&mut bytes).map_err(error_report)?; - let duration = decode::read_int(&mut bytes).map_err(error_report)?; - let exit = decode::read_int(&mut bytes).map_err(error_report)?; - - let bytes = bytes.remaining_slice(); - let (command, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (cwd, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (session, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (hostname, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - - let mut bytes = Bytes::new(bytes); - - let (deleted_at, bytes) = match decode::read_u64(&mut bytes) { - Ok(unix) => (Some(unix), bytes.remaining_slice()), - // we accept null here - Err(ValueReadError::TypeMismatch(Marker::Null)) => (None, bytes.remaining_slice()), - Err(err) => return Err(error_report(err)), - }; - if !bytes.is_empty() { - bail!("trailing bytes in encoded history. malformed") - } - - Ok(History { - id: id.to_owned().into(), - timestamp: OffsetDateTime::from_unix_timestamp_nanos(timestamp as i128)?, - duration, - exit, - command: command.to_owned(), - cwd: cwd.to_owned(), - session: session.to_owned(), - hostname: hostname.to_owned(), - author: Self::author_from_hostname(hostname), - intent: None, - deleted_at: deleted_at - .map(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128)) - .transpose()?, - }) - } - - fn deserialize_v1(bytes: &[u8]) -> Result { - use rmp::decode; - - fn error_report(err: E) -> eyre::Report { - eyre!("{err:?}") - } - - let mut bytes = Bytes::new(bytes); - - let version = decode::read_u16(&mut bytes).map_err(error_report)?; - - if version != HISTORY_RECORD_VERSION_V1 { - bail!("expected decoding v1 record, found v{version}"); - } - - let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; - - if !(10..=11).contains(&nfields) { - bail!("cannot decrypt history from a different version of Atuin"); - } - - let bytes = bytes.remaining_slice(); - let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - - let mut bytes = Bytes::new(bytes); - let timestamp = decode::read_u64(&mut bytes).map_err(error_report)?; - let duration = decode::read_int(&mut bytes).map_err(error_report)?; - let exit = decode::read_int(&mut bytes).map_err(error_report)?; - - let bytes = bytes.remaining_slice(); - let (command, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (cwd, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (session, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - let (hostname, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - - let mut bytes = Bytes::new(bytes); - - let (deleted_at, bytes) = match decode::read_u64(&mut bytes) { - Ok(unix) => (Some(unix), bytes.remaining_slice()), - // we accept null here - Err(ValueReadError::TypeMismatch(Marker::Null)) => (None, bytes.remaining_slice()), - Err(err) => return Err(error_report(err)), - }; - let (author, bytes) = Self::read_optional_string(bytes)?; - let (intent, bytes) = if nfields > 10 { - Self::read_optional_string(bytes)? - } else { - (None, bytes) - }; - - if !bytes.is_empty() { - bail!("trailing bytes in encoded history. malformed") - } - - Ok(History { - id: id.to_owned().into(), - timestamp: OffsetDateTime::from_unix_timestamp_nanos(timestamp as i128)?, - duration, - exit, - command: command.to_owned(), - cwd: cwd.to_owned(), - session: session.to_owned(), - hostname: hostname.to_owned(), - author: author.unwrap_or_else(|| Self::author_from_hostname(hostname)), - intent, - deleted_at: deleted_at - .map(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128)) - .transpose()?, - }) - } - - pub fn deserialize(bytes: &[u8], version: &str) -> Result { - match version { - HISTORY_VERSION_V0 => Self::deserialize_v0(bytes), - HISTORY_VERSION_V1 => Self::deserialize_v1(bytes), - - _ => bail!("unknown version {version:?}"), - } - } - - /// Builder for a history entry that is imported from shell history. - /// - /// The only two required fields are `timestamp` and `command`. - /// - /// ## Examples - /// ``` - /// use atuin_client::history::History; - /// - /// let history: History = History::import() - /// .timestamp(time::OffsetDateTime::now_utc()) - /// .command("ls -la") - /// .build() - /// .into(); - /// ``` - /// - /// If shell history contains more information, it can be added to the builder: - /// ``` - /// use atuin_client::history::History; - /// - /// let history: History = History::import() - /// .timestamp(time::OffsetDateTime::now_utc()) - /// .command("ls -la") - /// .cwd("/home/user") - /// .exit(0) - /// .duration(100) - /// .build() - /// .into(); - /// ``` - /// - /// Unknown command or command without timestamp cannot be imported, which - /// is forced at compile time: - /// - /// ```compile_fail - /// use atuin_client::history::History; - /// - /// // this will not compile because timestamp is missing - /// let history: History = History::import() - /// .command("ls -la") - /// .build() - /// .into(); - /// ``` - pub fn import() -> builder::HistoryImportedBuilder { - builder::HistoryImported::builder() - } - - /// Builder for a history entry that is captured via hook. - /// - /// This builder is used only at the `start` step of the hook, - /// so it doesn't have any fields which are known only after - /// the command is finished, such as `exit` or `duration`. - /// - /// ## Examples - /// ```rust - /// use atuin_client::history::History; - /// - /// let history: History = History::capture() - /// .timestamp(time::OffsetDateTime::now_utc()) - /// .command("ls -la") - /// .cwd("/home/user") - /// .build() - /// .into(); - /// ``` - /// - /// Command without any required info cannot be captured, which is forced at compile time: - /// - /// ```compile_fail - /// use atuin_client::history::History; - /// - /// // this will not compile because `cwd` is missing - /// let history: History = History::capture() - /// .timestamp(time::OffsetDateTime::now_utc()) - /// .command("ls -la") - /// .build() - /// .into(); - /// ``` - pub fn capture() -> builder::HistoryCapturedBuilder { - builder::HistoryCaptured::builder() - } - - /// Builder for a history entry that is captured via hook, and sent to the daemon. - /// - /// This builder is used only at the `start` step of the hook, - /// so it doesn't have any fields which are known only after - /// the command is finished, such as `exit` or `duration`. - /// - /// It does, however, include information that can usually be inferred. - /// - /// This is because the daemon we are sending a request to lacks the context of the command - /// - /// ## Examples - /// ```rust - /// use atuin_client::history::History; - /// - /// let history: History = History::daemon() - /// .timestamp(time::OffsetDateTime::now_utc()) - /// .command("ls -la") - /// .cwd("/home/user") - /// .session("018deb6e8287781f9973ef40e0fde76b") - /// .hostname("computer:ellie") - /// .build() - /// .into(); - /// ``` - /// - /// Command without any required info cannot be captured, which is forced at compile time: - /// - /// ```compile_fail - /// use atuin_client::history::History; - /// - /// // this will not compile because `hostname` is missing - /// let history: History = History::daemon() - /// .timestamp(time::OffsetDateTime::now_utc()) - /// .command("ls -la") - /// .cwd("/home/user") - /// .session("018deb6e8287781f9973ef40e0fde76b") - /// .build() - /// .into(); - /// ``` - pub fn daemon() -> builder::HistoryDaemonCaptureBuilder { - builder::HistoryDaemonCapture::builder() - } - - /// Builder for a history entry that is imported from the database. - /// - /// All fields are required, as they are all present in the database. - /// - /// ```compile_fail - /// use atuin_client::history::History; - /// - /// // this will not compile because `id` field is missing - /// let history: History = History::from_db() - /// .timestamp(time::OffsetDateTime::now_utc()) - /// .command("ls -la".to_string()) - /// .cwd("/home/user".to_string()) - /// .exit(0) - /// .duration(100) - /// .session("somesession".to_string()) - /// .hostname("localhost".to_string()) - /// .author("user".to_string()) - /// .intent(None) - /// .deleted_at(None) - /// .build() - /// .into(); - /// ``` - pub fn from_db() -> builder::HistoryFromDbBuilder { - builder::HistoryFromDb::builder() - } - - pub fn success(&self) -> bool { - self.exit == 0 || self.duration == -1 - } - - pub fn should_save(&self, settings: &Settings) -> bool { - !(self.command.starts_with(' ') - || self.command.is_empty() - || settings.history_filter.is_match(&self.command) - || settings.cwd_filter.is_match(&self.cwd) - || (settings.secrets_filter && SECRET_PATTERNS_RE.is_match(&self.command))) - } -} - -#[cfg(test)] -mod tests { - use regex::RegexSet; - use time::macros::datetime; - - use crate::{ - history::{AUTHOR_FILTER_ALL_AGENT, AUTHOR_FILTER_ALL_USER, HISTORY_VERSION}, - settings::Settings, - }; - - use super::{History, author_matches_filters, is_known_agent}; - - // Test that we don't save history where necessary - #[test] - fn privacy_test() { - let settings = Settings { - cwd_filter: RegexSet::new(["^/supasecret"]).unwrap(), - history_filter: RegexSet::new(["^psql"]).unwrap(), - ..Settings::utc() - }; - - let normal_command: History = History::capture() - .timestamp(time::OffsetDateTime::now_utc()) - .command("echo foo") - .cwd("/") - .build() - .into(); - - let with_space: History = History::capture() - .timestamp(time::OffsetDateTime::now_utc()) - .command(" echo bar") - .cwd("/") - .build() - .into(); - - let empty: History = History::capture() - .timestamp(time::OffsetDateTime::now_utc()) - .command("") - .cwd("/") - .build() - .into(); - - let stripe_key: History = History::capture() - .timestamp(time::OffsetDateTime::now_utc()) - .command("curl foo.com/bar?key=sk_test_1234567890abcdefghijklmnop") - .cwd("/") - .build() - .into(); - - let secret_dir: History = History::capture() - .timestamp(time::OffsetDateTime::now_utc()) - .command("echo ohno") - .cwd("/supasecret") - .build() - .into(); - - let with_psql: History = History::capture() - .timestamp(time::OffsetDateTime::now_utc()) - .command("psql") - .cwd("/supasecret") - .build() - .into(); - - assert!(normal_command.should_save(&settings)); - assert!(!with_space.should_save(&settings)); - assert!(!empty.should_save(&settings)); - assert!(!stripe_key.should_save(&settings)); - assert!(!secret_dir.should_save(&settings)); - assert!(!with_psql.should_save(&settings)); - } - - #[test] - fn known_agents_include_pi() { - assert!(is_known_agent("pi")); - assert!(author_matches_filters( - "pi", - &[AUTHOR_FILTER_ALL_AGENT.to_string()] - )); - assert!(!author_matches_filters( - "pi", - &[AUTHOR_FILTER_ALL_USER.to_string()] - )); - } - - #[test] - fn disable_secrets() { - let settings = Settings { - secrets_filter: false, - ..Settings::utc() - }; - - let stripe_key: History = History::capture() - .timestamp(time::OffsetDateTime::now_utc()) - .command("curl foo.com/bar?key=sk_test_1234567890abcdefghijklmnop") - .cwd("/") - .build() - .into(); - - assert!(stripe_key.should_save(&settings)); - } - - #[test] - fn test_serialize_deserialize() { - let history = History { - id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), - timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), - duration: 49206000, - exit: 0, - command: "git status".to_owned(), - cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), - session: "b97d9a306f274473a203d2eba41f9457".to_owned(), - hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), - author: "conrad.ludgate".to_owned(), - intent: None, - deleted_at: None, - }; - - let serialized = history.serialize().expect("failed to serialize history"); - assert_eq!( - &serialized.0[0..3], - [205, 0, 1], - "should encode as history v1" - ); - - let deserialized = History::deserialize(&serialized.0, HISTORY_VERSION) - .expect("failed to deserialize history"); - assert_eq!(history, deserialized); - } - - #[test] - fn test_serialize_deserialize_deleted() { - let history = History { - id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), - timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), - duration: 49206000, - exit: 0, - command: "git status".to_owned(), - cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), - session: "b97d9a306f274473a203d2eba41f9457".to_owned(), - hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), - author: "conrad.ludgate".to_owned(), - intent: None, - deleted_at: Some(datetime!(2023-11-19 20:18 +00:00)), - }; - - let serialized = history.serialize().expect("failed to serialize history"); - - let deserialized = History::deserialize(&serialized.0, HISTORY_VERSION) - .expect("failed to deserialize history"); - - assert_eq!(history, deserialized); - } - - #[test] - fn test_serialize_deserialize_with_author_and_intent() { - let history = History { - id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), - timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), - duration: 49206000, - exit: 0, - command: "git status".to_owned(), - cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), - session: "b97d9a306f274473a203d2eba41f9457".to_owned(), - hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), - author: "claude".to_owned(), - intent: Some("check repository status".to_owned()), - deleted_at: None, - }; - - let serialized = history.serialize().expect("failed to serialize history"); - let deserialized = History::deserialize(&serialized.0, HISTORY_VERSION) - .expect("failed to deserialize history"); - - assert_eq!(history, deserialized); - } - - #[test] - fn test_serialize_deserialize_version() { - // v0 - let bytes_v0 = [ - 205, 0, 0, 153, 217, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55, - 53, 51, 56, 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 207, 23, 99, - 98, 117, 24, 210, 246, 128, 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, - 97, 116, 117, 115, 217, 42, 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, - 46, 108, 117, 100, 103, 97, 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, - 47, 99, 111, 100, 101, 47, 97, 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, - 51, 48, 54, 102, 50, 55, 52, 52, 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, - 102, 57, 52, 53, 55, 187, 102, 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, - 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, 116, 101, 192, - ]; - - let deserialized = History::deserialize(&bytes_v0, "v0"); - assert!(deserialized.is_ok()); - - let deserialized = History::deserialize(&bytes_v0, HISTORY_VERSION); - assert!(deserialized.is_err()); - - let current = History { - id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), - timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), - duration: 49206000, - exit: 0, - command: "git status".to_owned(), - cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), - session: "b97d9a306f274473a203d2eba41f9457".to_owned(), - hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), - author: "conrad.ludgate".to_owned(), - intent: None, - deleted_at: None, - }; - - let bytes_v1 = current.serialize().expect("failed to serialize history"); - let deserialized = History::deserialize(&bytes_v1.0, HISTORY_VERSION); - assert!(deserialized.is_ok()); - - let deserialized = History::deserialize(&bytes_v1.0, "v0"); - assert!(deserialized.is_err()); - } -} diff --git a/crates/atuin-client/src/history/builder.rs b/crates/atuin-client/src/history/builder.rs deleted file mode 100644 index 72a505fd..00000000 --- a/crates/atuin-client/src/history/builder.rs +++ /dev/null @@ -1,154 +0,0 @@ -use typed_builder::TypedBuilder; - -use super::History; - -/// Builder for a history entry that is imported from shell history. -/// -/// The only two required fields are `timestamp` and `command`. -#[derive(Debug, Clone, TypedBuilder)] -pub struct HistoryImported { - timestamp: time::OffsetDateTime, - #[builder(setter(into))] - command: String, - #[builder(default = "unknown".into(), setter(into))] - cwd: String, - #[builder(default = -1)] - exit: i64, - #[builder(default = -1)] - duration: i64, - #[builder(default, setter(strip_option, into))] - session: Option, - #[builder(default, setter(strip_option, into))] - hostname: Option, - #[builder(default, setter(strip_option, into))] - author: Option, - #[builder(default, setter(strip_option, into))] - intent: Option, -} - -impl From for History { - fn from(imported: HistoryImported) -> Self { - History::new( - imported.timestamp, - imported.command, - imported.cwd, - imported.exit, - imported.duration, - imported.session, - imported.hostname, - imported.author, - imported.intent, - None, - ) - } -} - -/// Builder for a history entry that is captured via hook. -/// -/// This builder is used only at the `start` step of the hook, -/// so it doesn't have any fields which are known only after -/// the command is finished, such as `exit` or `duration`. -#[derive(Debug, Clone, TypedBuilder)] -pub struct HistoryCaptured { - timestamp: time::OffsetDateTime, - #[builder(setter(into))] - command: String, - #[builder(setter(into))] - cwd: String, - #[builder(default, setter(strip_option, into))] - author: Option, - #[builder(default, setter(strip_option, into))] - intent: Option, -} - -impl From for History { - fn from(captured: HistoryCaptured) -> Self { - History::new( - captured.timestamp, - captured.command, - captured.cwd, - -1, - -1, - None, - None, - captured.author, - captured.intent, - None, - ) - } -} - -/// Builder for a history entry that is loaded from the database. -/// -/// All fields are required, as they are all present in the database. -#[derive(Debug, Clone, TypedBuilder)] -pub struct HistoryFromDb { - id: String, - timestamp: time::OffsetDateTime, - command: String, - cwd: String, - exit: i64, - duration: i64, - session: String, - hostname: String, - author: String, - intent: Option, - deleted_at: Option, -} - -impl From for History { - fn from(from_db: HistoryFromDb) -> Self { - History { - id: from_db.id.into(), - timestamp: from_db.timestamp, - exit: from_db.exit, - command: from_db.command, - cwd: from_db.cwd, - duration: from_db.duration, - session: from_db.session, - hostname: from_db.hostname, - author: from_db.author, - intent: from_db.intent, - deleted_at: from_db.deleted_at, - } - } -} - -/// Builder for a history entry that is captured via hook and sent to the daemon -/// -/// This builder is similar to Capture, but we just require more information up front. -/// For the old setup, we could just rely on History::new to read some of the missing -/// data. This is no longer the case. -#[derive(Debug, Clone, TypedBuilder)] -pub struct HistoryDaemonCapture { - timestamp: time::OffsetDateTime, - #[builder(setter(into))] - command: String, - #[builder(setter(into))] - cwd: String, - #[builder(setter(into))] - session: String, - #[builder(setter(into))] - hostname: String, - #[builder(default, setter(strip_option, into))] - author: Option, - #[builder(default, setter(strip_option, into))] - intent: Option, -} - -impl From for History { - fn from(captured: HistoryDaemonCapture) -> Self { - History::new( - captured.timestamp, - captured.command, - captured.cwd, - -1, - -1, - Some(captured.session), - Some(captured.hostname), - captured.author, - captured.intent, - None, - ) - } -} diff --git a/crates/atuin-client/src/history/store.rs b/crates/atuin-client/src/history/store.rs deleted file mode 100644 index ce7b43a1..00000000 --- a/crates/atuin-client/src/history/store.rs +++ /dev/null @@ -1,434 +0,0 @@ -use std::{collections::HashSet, fmt::Write, time::Duration}; - -use eyre::{Result, bail, eyre}; -use indicatif::{ProgressBar, ProgressState, ProgressStyle}; -use rmp::decode::Bytes; - -use crate::{ - database::{Database, current_context}, - record::{encryption::PASETO_V4, sqlite_store::SqliteStore, store::Store}, -}; -use atuin_common::record::{DecryptedData, Host, HostId, Record, RecordId, RecordIdx}; - -use super::{HISTORY_TAG, HISTORY_VERSION, HISTORY_VERSION_V0, History, HistoryId}; - -#[derive(Debug, Clone)] -pub struct HistoryStore { - pub store: SqliteStore, - pub host_id: HostId, - pub encryption_key: [u8; 32], -} - -#[derive(Debug, Eq, PartialEq, Clone)] -pub enum HistoryRecord { - Create(History), // Create a history record - Delete(HistoryId), // Delete a history record, identified by ID -} - -impl HistoryRecord { - /// Serialize a history record, returning DecryptedData - /// The record will be of a certain type - /// We map those like so: - /// - /// HistoryRecord::Create -> 0 - /// HistoryRecord::Delete-> 1 - /// - /// This numeric identifier is then written as the first byte to the buffer. For history, we - /// append the serialized history right afterwards, to avoid having to handle serialization - /// twice. - /// - /// Deletion simply refers to the history by ID - pub fn serialize(&self) -> Result { - // probably don't actually need to use rmp here, but if we ever need to extend it, it's a - // nice wrapper around raw byte stuff - use rmp::encode; - - let mut output = vec![]; - - match self { - HistoryRecord::Create(history) => { - // 0 -> a history create - encode::write_u8(&mut output, 0)?; - - let bytes = history.serialize()?; - - encode::write_bin(&mut output, &bytes.0)?; - } - HistoryRecord::Delete(id) => { - // 1 -> a history delete - encode::write_u8(&mut output, 1)?; - encode::write_str(&mut output, id.0.as_str())?; - } - }; - - Ok(DecryptedData(output)) - } - - pub fn deserialize(bytes: &DecryptedData, version: &str) -> Result { - use rmp::decode; - - fn error_report(err: E) -> eyre::Report { - eyre!("{err:?}") - } - - let mut bytes = Bytes::new(&bytes.0); - - let record_type = decode::read_u8(&mut bytes).map_err(error_report)?; - - match record_type { - // 0 -> HistoryRecord::Create - 0 => { - // not super useful to us atm, but perhaps in the future - // written by write_bin above - let _ = decode::read_bin_len(&mut bytes).map_err(error_report)?; - - let record = History::deserialize(bytes.remaining_slice(), version)?; - - Ok(HistoryRecord::Create(record)) - } - - // 1 -> HistoryRecord::Delete - 1 => { - let bytes = bytes.remaining_slice(); - let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; - - if !bytes.is_empty() { - bail!( - "trailing bytes decoding HistoryRecord::Delete - malformed? got {bytes:?}" - ); - } - - Ok(HistoryRecord::Delete(id.to_string().into())) - } - - n => { - bail!("unknown HistoryRecord type {n}") - } - } - } -} - -impl HistoryStore { - pub fn new(store: SqliteStore, host_id: HostId, encryption_key: [u8; 32]) -> Self { - HistoryStore { - store, - host_id, - encryption_key, - } - } - - async fn push_record(&self, record: HistoryRecord) -> Result<(RecordId, RecordIdx)> { - let bytes = record.serialize()?; - let idx = self - .store - .last(self.host_id, HISTORY_TAG) - .await? - .map_or(0, |p| p.idx + 1); - - let record = Record::builder() - .host(Host::new(self.host_id)) - .version(HISTORY_VERSION.to_string()) - .tag(HISTORY_TAG.to_string()) - .idx(idx) - .data(bytes) - .build(); - - let id = record.id; - - self.store - .push(&record.encrypt::(&self.encryption_key)) - .await?; - - Ok((id, idx)) - } - - async fn push_batch(&self, records: impl Iterator) -> Result<()> { - let mut ret = Vec::new(); - - let idx = self - .store - .last(self.host_id, HISTORY_TAG) - .await? - .map_or(0, |p| p.idx + 1); - - // Could probably _also_ do this as an iterator, but let's see how this is for now. - // optimizing for minimal sqlite transactions, this code can be optimised later - for (n, record) in records.enumerate() { - let bytes = record.serialize()?; - - let record = Record::builder() - .host(Host::new(self.host_id)) - .version(HISTORY_VERSION.to_string()) - .tag(HISTORY_TAG.to_string()) - .idx(idx + n as u64) - .data(bytes) - .build(); - - let record = record.encrypt::(&self.encryption_key); - - ret.push(record); - } - - self.store.push_batch(ret.iter()).await?; - - Ok(()) - } - - pub async fn delete(&self, id: HistoryId) -> Result<(RecordId, RecordIdx)> { - let record = HistoryRecord::Delete(id); - - self.push_record(record).await - } - - /// Delete a batch of history entries via the record store. - /// Returns the record IDs so the caller can run incremental_build when ready. - pub async fn delete_entries( - &self, - entries: impl IntoIterator, - ) -> Result> { - let mut record_ids = Vec::new(); - for entry in entries { - let (id, _) = self.delete(entry.id).await?; - record_ids.push(id); - } - Ok(record_ids) - } - - pub async fn push(&self, history: History) -> Result<(RecordId, RecordIdx)> { - // TODO(ellie): move the history store to its own file - // it's tiny rn so fine as is - let record = HistoryRecord::Create(history); - - self.push_record(record).await - } - - pub async fn history(&self) -> Result> { - // Atm this loads all history into memory - // Not ideal as that is potentially quite a lot, although history will be small. - let records = self.store.all_tagged(HISTORY_TAG).await?; - let mut ret = Vec::with_capacity(records.len()); - - for record in records.into_iter() { - let hist = match record.version.as_str() { - HISTORY_VERSION_V0 | HISTORY_VERSION => { - let version = record.version.clone(); - let decrypted = record.decrypt::(&self.encryption_key)?; - - HistoryRecord::deserialize(&decrypted.data, version.as_str()) - } - version => bail!("unknown history version {version:?}"), - }?; - - ret.push(hist); - } - - Ok(ret) - } - - pub async fn build(&self, database: &dyn Database) -> Result<()> { - // I'd like to change how we rebuild and not couple this with the database, but need to - // consider the structure more deeply. This will be easy to change. - - // TODO(ellie): page or iterate this - let history = self.history().await?; - - // In theory we could flatten this here - // The current issue is that the database may have history in it already, from the old sync - // This didn't actually delete old history - // If we're sure we have a DB only maintained by the new store, we can flatten - // create/delete before we even get to sqlite - let mut creates = Vec::new(); - let mut deletes = Vec::new(); - - for i in history { - match i { - HistoryRecord::Create(h) => { - creates.push(h); - } - HistoryRecord::Delete(id) => { - deletes.push(id); - } - } - } - - database.save_bulk(&creates).await?; - database.delete_rows(&deletes).await?; - - Ok(()) - } - - pub async fn incremental_build(&self, database: &dyn Database, ids: &[RecordId]) -> Result<()> { - for id in ids { - let record = self.store.get(*id).await; - - let record = match record { - Ok(record) => record, - _ => { - continue; - } - }; - - if record.tag != HISTORY_TAG { - continue; - } - - let version = record.version.clone(); - let decrypted = record.decrypt::(&self.encryption_key)?; - let record = match version.as_str() { - HISTORY_VERSION_V0 | HISTORY_VERSION => { - HistoryRecord::deserialize(&decrypted.data, version.as_str())? - } - version => bail!("unknown history version {version:?}"), - }; - - match record { - HistoryRecord::Create(h) => { - // TODO: benchmark CPU time/memory tradeoff of batch commit vs one at a time - database.save(&h).await?; - } - HistoryRecord::Delete(id) => { - database.delete_rows(&[id]).await?; - } - } - } - - Ok(()) - } - - /// Get a list of history IDs that exist in the store - /// Note: This currently involves loading all history into memory. This is not going to be a - /// large amount in absolute terms, but do not all it in a hot loop. - pub async fn history_ids(&self) -> Result> { - let history = self.history().await?; - - let ret = HashSet::from_iter(history.iter().map(|h| match h { - HistoryRecord::Create(h) => h.id.clone(), - HistoryRecord::Delete(id) => id.clone(), - })); - - Ok(ret) - } - - pub async fn init_store(&self, db: &impl Database) -> Result<()> { - let pb = ProgressBar::new_spinner(); - pb.set_style( - ProgressStyle::with_template("{spinner:.blue} {msg}") - .unwrap() - .with_key("eta", |state: &ProgressState, w: &mut dyn Write| { - write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap() - }) - .progress_chars("#>-"), - ); - pb.enable_steady_tick(Duration::from_millis(500)); - - pb.set_message("Fetching history from old database"); - - let context = current_context().await?; - let history = db.list(&[], &context, None, false, true).await?; - - pb.set_message("Fetching history already in store"); - let store_ids = self.history_ids().await?; - - pb.set_message("Converting old history to new store"); - let mut records = Vec::new(); - - for i in history { - debug!("loaded {}", i.id); - - if store_ids.contains(&i.id) { - debug!("skipping {} - already exists", i.id); - continue; - } - - if i.deleted_at.is_some() { - records.push(HistoryRecord::Delete(i.id)); - } else { - records.push(HistoryRecord::Create(i)); - } - } - - pb.set_message("Writing to db"); - - if !records.is_empty() { - self.push_batch(records.into_iter()).await?; - } - - pb.finish_with_message("Import complete"); - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use atuin_common::record::DecryptedData; - use time::macros::datetime; - - use crate::history::{HISTORY_VERSION, store::HistoryRecord}; - - use super::History; - - #[test] - fn test_serialize_deserialize_create() { - let bytes = [ - 204, 0, 196, 147, 205, 0, 1, 154, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 56, 49, - 55, 53, 55, 99, 100, 50, 97, 101, 101, 54, 53, 99, 100, 55, 56, 54, 49, 102, 57, 99, - 56, 49, 207, 23, 166, 251, 212, 181, 82, 0, 0, 100, 0, 162, 108, 115, 217, 41, 47, 85, - 115, 101, 114, 115, 47, 101, 108, 108, 105, 101, 47, 115, 114, 99, 47, 103, 105, 116, - 104, 117, 98, 46, 99, 111, 109, 47, 97, 116, 117, 105, 110, 115, 104, 47, 97, 116, 117, - 105, 110, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 97, 100, 56, 57, 55, 53, 57, 55, - 56, 53, 50, 53, 50, 55, 97, 51, 49, 99, 57, 57, 56, 48, 53, 57, 170, 98, 111, 111, 112, - 58, 101, 108, 108, 105, 101, 192, 165, 101, 108, 108, 105, 101, - ]; - - let history = History { - id: "018cd4fe81757cd2aee65cd7861f9c81".to_owned().into(), - timestamp: datetime!(2024-01-04 00:00:00.000000 +00:00), - duration: 100, - exit: 0, - command: "ls".to_owned(), - cwd: "/Users/ellie/src/github.com/atuinsh/atuin".to_owned(), - session: "018cd4fead897597852527a31c998059".to_owned(), - hostname: "boop:ellie".to_owned(), - author: "ellie".to_owned(), - intent: None, - deleted_at: None, - }; - - let record = HistoryRecord::Create(history); - - let serialized = record.serialize().expect("failed to serialize history"); - assert_eq!(serialized.0, bytes); - - let deserialized = HistoryRecord::deserialize(&serialized, HISTORY_VERSION) - .expect("failed to deserialize HistoryRecord"); - assert_eq!(deserialized, record); - - // check the snapshot too - let deserialized = - HistoryRecord::deserialize(&DecryptedData(Vec::from(bytes)), HISTORY_VERSION) - .expect("failed to deserialize HistoryRecord"); - assert_eq!(deserialized, record); - } - - #[test] - fn test_serialize_deserialize_delete() { - let bytes = [ - 204, 1, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 56, 49, 55, 53, 55, 99, 100, 50, - 97, 101, 101, 54, 53, 99, 100, 55, 56, 54, 49, 102, 57, 99, 56, 49, - ]; - let record = HistoryRecord::Delete("018cd4fe81757cd2aee65cd7861f9c81".to_string().into()); - - let serialized = record.serialize().expect("failed to serialize history"); - assert_eq!(serialized.0, bytes); - - let deserialized = HistoryRecord::deserialize(&serialized, HISTORY_VERSION) - .expect("failed to deserialize HistoryRecord"); - assert_eq!(deserialized, record); - - let deserialized = - HistoryRecord::deserialize(&DecryptedData(Vec::from(bytes)), HISTORY_VERSION) - .expect("failed to deserialize HistoryRecord"); - assert_eq!(deserialized, record); - } -} diff --git a/crates/atuin-client/src/import/bash.rs b/crates/atuin-client/src/import/bash.rs deleted file mode 100644 index 99a44a58..00000000 --- a/crates/atuin-client/src/import/bash.rs +++ /dev/null @@ -1,220 +0,0 @@ -use std::{path::PathBuf, str}; - -use async_trait::async_trait; -use directories::UserDirs; -use eyre::{Result, eyre}; -use itertools::Itertools; -use time::{Duration, OffsetDateTime}; - -use super::{Importer, Loader, get_histfile_path, unix_byte_lines}; -use crate::history::History; -use crate::import::read_to_end; - -#[derive(Debug)] -pub struct Bash { - bytes: Vec, -} - -fn default_histpath() -> Result { - let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; - let home_dir = user_dirs.home_dir(); - - Ok(home_dir.join(".bash_history")) -} - -#[async_trait] -impl Importer for Bash { - const NAME: &'static str = "bash"; - - async fn new() -> Result { - let bytes = read_to_end(get_histfile_path(default_histpath)?)?; - Ok(Self { bytes }) - } - - async fn entries(&mut self) -> Result { - let count = unix_byte_lines(&self.bytes) - .map(LineType::from) - .filter(|line| matches!(line, LineType::Command(_))) - .count(); - Ok(count) - } - - async fn load(self, h: &mut impl Loader) -> Result<()> { - let lines = unix_byte_lines(&self.bytes) - .map(LineType::from) - .filter(|line| !matches!(line, LineType::NotUtf8)) // invalid utf8 are ignored - .collect_vec(); - - let (commands_before_first_timestamp, first_timestamp) = lines - .iter() - .enumerate() - .find_map(|(i, line)| match line { - LineType::Timestamp(t) => Some((i, *t)), - _ => None, - }) - // if no known timestamps, use now as base - .unwrap_or((lines.len(), OffsetDateTime::now_utc())); - - // if no timestamp is recorded, then use this increment to set an arbitrary timestamp - // to preserve ordering - // this increment is deliberately very small to prevent particularly fast fingers - // causing ordering issues; it also helps in handling the "here document" syntax, - // where several lines are recorded in succession without individual timestamps - let timestamp_increment = Duration::milliseconds(1); - - // make sure there is a minimum amount of time before the first known timestamp - // to fit all commands, given the default increment - let mut next_timestamp = - first_timestamp - timestamp_increment * commands_before_first_timestamp as i32; - - for line in lines.into_iter() { - match line { - LineType::NotUtf8 => unreachable!(), // already filtered - LineType::Empty => {} // do nothing - LineType::Timestamp(t) => { - if t < next_timestamp { - warn!( - "Time reversal detected in Bash history! Commands may be ordered incorrectly." - ); - } - next_timestamp = t; - } - LineType::Command(c) => { - let imported = History::import().timestamp(next_timestamp).command(c); - - h.push(imported.build().into()).await?; - next_timestamp += timestamp_increment; - } - } - } - - Ok(()) - } -} - -#[derive(Debug, Clone)] -enum LineType<'a> { - NotUtf8, - /// Can happen when using the "here document" syntax. - Empty, - /// A timestamp line start with a '#', followed immediately by an integer - /// that represents seconds since UNIX epoch. - Timestamp(OffsetDateTime), - /// Anything else. - Command(&'a str), -} -impl<'a> From<&'a [u8]> for LineType<'a> { - fn from(bytes: &'a [u8]) -> Self { - let Ok(line) = str::from_utf8(bytes) else { - return LineType::NotUtf8; - }; - if line.is_empty() { - return LineType::Empty; - } - - match try_parse_line_as_timestamp(line) { - Some(time) => LineType::Timestamp(time), - None => LineType::Command(line), - } - } -} - -fn try_parse_line_as_timestamp(line: &str) -> Option { - let seconds = line.strip_prefix('#')?.parse().ok()?; - OffsetDateTime::from_unix_timestamp(seconds).ok() -} - -#[cfg(test)] -mod test { - use std::cmp::Ordering; - - use itertools::{Itertools, assert_equal}; - - use crate::import::{Importer, tests::TestLoader}; - - use super::Bash; - - #[tokio::test] - async fn parse_no_timestamps() { - let bytes = r"cargo install atuin -cargo update -cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷ -" - .as_bytes() - .to_owned(); - - let mut bash = Bash { bytes }; - assert_eq!(bash.entries().await.unwrap(), 3); - - let mut loader = TestLoader::default(); - bash.load(&mut loader).await.unwrap(); - - assert_equal( - loader.buf.iter().map(|h| h.command.as_str()), - [ - "cargo install atuin", - "cargo update", - "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", - ], - ); - assert!(is_strictly_sorted(loader.buf.iter().map(|h| h.timestamp))) - } - - #[tokio::test] - async fn parse_with_timestamps() { - let bytes = b"#1672918999 -git reset -#1672919006 -git clean -dxf -#1672919020 -cd ../ -" - .to_vec(); - - let mut bash = Bash { bytes }; - assert_eq!(bash.entries().await.unwrap(), 3); - - let mut loader = TestLoader::default(); - bash.load(&mut loader).await.unwrap(); - - assert_equal( - loader.buf.iter().map(|h| h.command.as_str()), - ["git reset", "git clean -dxf", "cd ../"], - ); - assert_equal( - loader.buf.iter().map(|h| h.timestamp.unix_timestamp()), - [1672918999, 1672919006, 1672919020], - ) - } - - #[tokio::test] - async fn parse_with_partial_timestamps() { - let bytes = b"git reset -#1672919006 -git clean -dxf -cd ../ -" - .to_vec(); - - let mut bash = Bash { bytes }; - assert_eq!(bash.entries().await.unwrap(), 3); - - let mut loader = TestLoader::default(); - bash.load(&mut loader).await.unwrap(); - - assert_equal( - loader.buf.iter().map(|h| h.command.as_str()), - ["git reset", "git clean -dxf", "cd ../"], - ); - assert!(is_strictly_sorted(loader.buf.iter().map(|h| h.timestamp))) - } - - fn is_strictly_sorted(iter: impl IntoIterator) -> bool - where - T: Clone + PartialOrd, - { - iter.into_iter() - .tuple_windows() - .all(|(a, b)| matches!(a.partial_cmp(&b), Some(Ordering::Less))) - } -} diff --git a/crates/atuin-client/src/import/fish.rs b/crates/atuin-client/src/import/fish.rs deleted file mode 100644 index 9fcf624c..00000000 --- a/crates/atuin-client/src/import/fish.rs +++ /dev/null @@ -1,179 +0,0 @@ -// import old shell history! -// automatically hoover up all that we can find - -use std::path::PathBuf; - -use async_trait::async_trait; -use directories::BaseDirs; -use eyre::{Result, eyre}; -use time::OffsetDateTime; - -use super::{Importer, Loader, unix_byte_lines}; -use crate::history::History; -use crate::import::read_to_end; - -#[derive(Debug)] -pub struct Fish { - bytes: Vec, -} - -/// see https://fishshell.com/docs/current/interactive.html#searchable-command-history -fn default_histpath() -> Result { - let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; - let data = std::env::var("XDG_DATA_HOME").map_or_else( - |_| base.home_dir().join(".local").join("share"), - PathBuf::from, - ); - - // fish supports multiple history sessions - // If `fish_history` var is missing, or set to `default`, use `fish` as the session - let session = std::env::var("fish_history").unwrap_or_else(|_| String::from("fish")); - let session = if session == "default" { - String::from("fish") - } else { - session - }; - - let mut histpath = data.join("fish"); - histpath.push(format!("{session}_history")); - - if histpath.exists() { - Ok(histpath) - } else { - Err(eyre!("Could not find history file.")) - } -} - -#[async_trait] -impl Importer for Fish { - const NAME: &'static str = "fish"; - - async fn new() -> Result { - let bytes = read_to_end(default_histpath()?)?; - Ok(Self { bytes }) - } - - async fn entries(&mut self) -> Result { - Ok(super::count_lines(&self.bytes)) - } - - async fn load(self, loader: &mut impl Loader) -> Result<()> { - let now = OffsetDateTime::now_utc(); - let mut time: Option = None; - let mut cmd: Option = None; - - for b in unix_byte_lines(&self.bytes) { - let s = match std::str::from_utf8(b) { - Ok(s) => s, - Err(_) => continue, // we can skip past things like invalid utf8 - }; - - if let Some(c) = s.strip_prefix("- cmd: ") { - // first, we must deal with the prev cmd - if let Some(cmd) = cmd.take() { - let time = time.unwrap_or(now); - let entry = History::import().timestamp(time).command(cmd); - - loader.push(entry.build().into()).await?; - } - - // using raw strings to avoid needing escaping. - // replaces double backslashes with single backslashes - let c = c.replace(r"\\", r"\"); - // replaces escaped newlines - let c = c.replace(r"\n", "\n"); - // TODO: any other escape characters? - - cmd = Some(c); - } else if let Some(t) = s.strip_prefix(" when: ") { - // if t is not an int, just ignore this line - if let Ok(t) = t.parse::() { - time = Some(OffsetDateTime::from_unix_timestamp(t)?); - } - } else { - // ... ignore paths lines - } - } - - // we might have a trailing cmd - if let Some(cmd) = cmd.take() { - let time = time.unwrap_or(now); - let entry = History::import().timestamp(time).command(cmd); - - loader.push(entry.build().into()).await?; - } - - Ok(()) - } -} - -#[cfg(test)] -mod test { - - use crate::import::{Importer, tests::TestLoader}; - - use super::Fish; - - #[tokio::test] - async fn parse_complex() { - // complicated input with varying contents and escaped strings. - let bytes = r#"- cmd: history --help - when: 1639162832 -- cmd: cat ~/.bash_history - when: 1639162851 - paths: - - ~/.bash_history -- cmd: ls ~/.local/share/fish/fish_history - when: 1639162890 - paths: - - ~/.local/share/fish/fish_history -- cmd: cat ~/.local/share/fish/fish_history - when: 1639162893 - paths: - - ~/.local/share/fish/fish_history -ERROR -- CORRUPTED: ENTRY - CONTINUE: - - AS - - NORMAL -- cmd: echo "foo" \\\n'bar' baz - when: 1639162933 -- cmd: cat ~/.local/share/fish/fish_history - when: 1639162939 - paths: - - ~/.local/share/fish/fish_history -- cmd: echo "\\"" \\\\ "\\\\" - when: 1639163063 -- cmd: cat ~/.local/share/fish/fish_history - when: 1639163066 - paths: - - ~/.local/share/fish/fish_history -"# - .as_bytes() - .to_owned(); - - let fish = Fish { bytes }; - - let mut loader = TestLoader::default(); - fish.load(&mut loader).await.unwrap(); - let mut history = loader.buf.into_iter(); - - // simple wrapper for fish history entry - macro_rules! fishtory { - ($timestamp:expr_2021, $command:expr_2021) => { - let h = history.next().expect("missing entry in history"); - assert_eq!(h.command.as_str(), $command); - assert_eq!(h.timestamp.unix_timestamp(), $timestamp); - }; - } - - fishtory!(1639162832, "history --help"); - fishtory!(1639162851, "cat ~/.bash_history"); - fishtory!(1639162890, "ls ~/.local/share/fish/fish_history"); - fishtory!(1639162893, "cat ~/.local/share/fish/fish_history"); - fishtory!(1639162933, "echo \"foo\" \\\n'bar' baz"); - fishtory!(1639162939, "cat ~/.local/share/fish/fish_history"); - fishtory!(1639163063, r#"echo "\"" \\ "\\""#); - fishtory!(1639163066, "cat ~/.local/share/fish/fish_history"); - } -} diff --git a/crates/atuin-client/src/import/mod.rs b/crates/atuin-client/src/import/mod.rs deleted file mode 100644 index 4a1c6af6..00000000 --- a/crates/atuin-client/src/import/mod.rs +++ /dev/null @@ -1,140 +0,0 @@ -use std::fs::File; -use std::io::Read; -use std::path::PathBuf; - -use async_trait::async_trait; -use eyre::{Result, bail}; -use memchr::Memchr; - -use crate::history::History; - -pub mod bash; -pub mod fish; -pub mod nu; -pub mod nu_histdb; -pub mod powershell; -pub mod replxx; -pub mod resh; -pub mod xonsh; -pub mod xonsh_sqlite; -pub mod zsh; -pub mod zsh_histdb; - -#[async_trait] -pub trait Importer: Sized { - const NAME: &'static str; - async fn new() -> Result; - async fn entries(&mut self) -> Result; - async fn load(self, loader: &mut impl Loader) -> Result<()>; -} - -#[async_trait] -pub trait Loader: Sync + Send { - async fn push(&mut self, hist: History) -> eyre::Result<()>; -} - -fn unix_byte_lines(input: &[u8]) -> impl Iterator { - UnixByteLines { - iter: memchr::memchr_iter(b'\n', input), - bytes: input, - i: 0, - } -} - -struct UnixByteLines<'a> { - iter: Memchr<'a>, - bytes: &'a [u8], - i: usize, -} - -impl<'a> Iterator for UnixByteLines<'a> { - type Item = &'a [u8]; - - fn next(&mut self) -> Option { - let j = self.iter.next()?; - let out = &self.bytes[self.i..j]; - self.i = j + 1; - Some(out) - } - - fn count(self) -> usize - where - Self: Sized, - { - self.iter.count() - } -} - -fn count_lines(input: &[u8]) -> usize { - unix_byte_lines(input).count() -} - -fn get_histpath(def: D) -> Result -where - D: FnOnce() -> Result, -{ - if let Ok(p) = std::env::var("HISTFILE") { - Ok(PathBuf::from(p)) - } else { - def() - } -} - -fn get_histfile_path(def: D) -> Result -where - D: FnOnce() -> Result, -{ - get_histpath(def).and_then(is_file) -} - -fn get_histdir_path(def: D) -> Result -where - D: FnOnce() -> Result, -{ - get_histpath(def).and_then(is_dir) -} - -fn read_to_end(path: PathBuf) -> Result> { - let mut bytes = Vec::new(); - let mut f = File::open(path)?; - f.read_to_end(&mut bytes)?; - Ok(bytes) -} -fn is_file(p: PathBuf) -> Result { - if p.is_file() { - Ok(p) - } else { - bail!( - "Could not find history file {:?}. Try setting and exporting $HISTFILE", - p - ) - } -} -fn is_dir(p: PathBuf) -> Result { - if p.is_dir() { - Ok(p) - } else { - bail!( - "Could not find history directory {:?}. Try setting and exporting $HISTFILE", - p - ) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[derive(Default)] - pub struct TestLoader { - pub buf: Vec, - } - - #[async_trait] - impl Loader for TestLoader { - async fn push(&mut self, hist: History) -> Result<()> { - self.buf.push(hist); - Ok(()) - } - } -} diff --git a/crates/atuin-client/src/import/nu.rs b/crates/atuin-client/src/import/nu.rs deleted file mode 100644 index cae90ac4..00000000 --- a/crates/atuin-client/src/import/nu.rs +++ /dev/null @@ -1,67 +0,0 @@ -// import old shell history! -// automatically hoover up all that we can find - -use std::path::PathBuf; - -use async_trait::async_trait; -use directories::BaseDirs; -use eyre::{Result, eyre}; -use time::OffsetDateTime; - -use super::{Importer, Loader, unix_byte_lines}; -use crate::history::History; -use crate::import::read_to_end; - -#[derive(Debug)] -pub struct Nu { - bytes: Vec, -} - -fn get_histpath() -> Result { - let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; - let config_dir = base.config_dir().join("nushell"); - - let histpath = config_dir.join("history.txt"); - if histpath.exists() { - Ok(histpath) - } else { - Err(eyre!("Could not find history file.")) - } -} - -#[async_trait] -impl Importer for Nu { - const NAME: &'static str = "nu"; - - async fn new() -> Result { - let bytes = read_to_end(get_histpath()?)?; - Ok(Self { bytes }) - } - - async fn entries(&mut self) -> Result { - Ok(super::count_lines(&self.bytes)) - } - - async fn load(self, h: &mut impl Loader) -> Result<()> { - let now = OffsetDateTime::now_utc(); - - let mut counter = 0; - for b in unix_byte_lines(&self.bytes) { - let s = match std::str::from_utf8(b) { - Ok(s) => s, - Err(_) => continue, // we can skip past things like invalid utf8 - }; - - let cmd: String = s.replace("<\\n>", "\n"); - - let offset = time::Duration::nanoseconds(counter); - counter += 1; - - let entry = History::import().timestamp(now - offset).command(cmd); - - h.push(entry.build().into()).await?; - } - - Ok(()) - } -} diff --git a/crates/atuin-client/src/import/nu_histdb.rs b/crates/atuin-client/src/import/nu_histdb.rs deleted file mode 100644 index a13cb2b4..00000000 --- a/crates/atuin-client/src/import/nu_histdb.rs +++ /dev/null @@ -1,113 +0,0 @@ -// import old shell history! -// automatically hoover up all that we can find - -use std::path::PathBuf; - -use async_trait::async_trait; -use directories::BaseDirs; -use eyre::{Result, eyre}; -use sqlx::{Pool, sqlite::SqlitePool}; -use time::{Duration, OffsetDateTime}; - -use super::Importer; -use crate::history::History; -use crate::import::Loader; - -#[derive(sqlx::FromRow, Debug)] -pub struct HistDbEntry { - pub id: i64, - pub command_line: Vec, - pub start_timestamp: i64, - pub session_id: i64, - pub hostname: Vec, - pub cwd: Vec, - pub duration_ms: i64, - pub exit_status: i64, - pub more_info: Vec, -} - -impl From for History { - fn from(histdb_item: HistDbEntry) -> Self { - let ts_secs = histdb_item.start_timestamp / 1000; - let ts_ns = (histdb_item.start_timestamp % 1000) * 1_000_000; - let imported = History::import() - .timestamp( - OffsetDateTime::from_unix_timestamp(ts_secs).unwrap() - + Duration::nanoseconds(ts_ns), - ) - .command(String::from_utf8(histdb_item.command_line).unwrap()) - .cwd(String::from_utf8(histdb_item.cwd).unwrap()) - .exit(histdb_item.exit_status) - .duration(histdb_item.duration_ms) - .session(format!("{:x}", histdb_item.session_id)) - .hostname(String::from_utf8(histdb_item.hostname).unwrap()); - - imported.build().into() - } -} - -#[derive(Debug)] -pub struct NuHistDb { - histdb: Vec, -} - -/// Read db at given file, return vector of entries. -async fn hist_from_db(dbpath: PathBuf) -> Result> { - let pool = SqlitePool::connect(dbpath.to_str().unwrap()).await?; - hist_from_db_conn(pool).await -} - -async fn hist_from_db_conn(pool: Pool) -> Result> { - let query = r#" - SELECT - id, command_line, start_timestamp, session_id, hostname, cwd, duration_ms, exit_status, - more_info - FROM history - ORDER BY start_timestamp - "#; - let histdb_vec: Vec = sqlx::query_as::<_, HistDbEntry>(query) - .fetch_all(&pool) - .await?; - Ok(histdb_vec) -} - -impl NuHistDb { - pub fn histpath() -> Result { - let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; - let config_dir = base.config_dir().join("nushell"); - - let histdb_path = config_dir.join("history.sqlite3"); - if histdb_path.exists() { - Ok(histdb_path) - } else { - Err(eyre!("Could not find history file.")) - } - } -} - -#[async_trait] -impl Importer for NuHistDb { - // Not sure how this is used - const NAME: &'static str = "nu_histdb"; - - /// Creates a new NuHistDb and populates the history based on the pre-populated data - /// structure. - async fn new() -> Result { - let dbpath = NuHistDb::histpath()?; - let histdb_entry_vec = hist_from_db(dbpath).await?; - Ok(Self { - histdb: histdb_entry_vec, - }) - } - - async fn entries(&mut self) -> Result { - Ok(self.histdb.len()) - } - - async fn load(self, h: &mut impl Loader) -> Result<()> { - for i in self.histdb { - h.push(i.into()).await?; - } - Ok(()) - } -} diff --git a/crates/atuin-client/src/import/powershell.rs b/crates/atuin-client/src/import/powershell.rs deleted file mode 100644 index 86fd007d..00000000 --- a/crates/atuin-client/src/import/powershell.rs +++ /dev/null @@ -1,202 +0,0 @@ -use async_trait::async_trait; -use directories::BaseDirs; -use eyre::{Result, eyre}; -use std::path::PathBuf; -use time::{Duration, OffsetDateTime}; - -use super::{Importer, Loader, count_lines, unix_byte_lines}; -use crate::history::History; -use crate::import::read_to_end; - -#[derive(Debug)] -pub struct PowerShell { - bytes: Vec, - line_count: Option, -} - -fn get_history_path() -> Result { - let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; - - // The command line history in PowerShell is maintained by the PSReadLine module: - // https://learn.microsoft.com/en-us/powershell/module/psreadline/about/about_psreadline#command-history - // - // > PSReadLine maintains a history file containing all the commands and data you've entered from the command line. - // > The history files are a file named `$($Host.Name)_history.txt`. - // > On Windows systems the history file is stored at `$Env:APPDATA\Microsoft\Windows\PowerShell\PSReadLine`. - // > On non-Windows systems, the history files are stored at `$Env:XDG_DATA_HOME/powershell/PSReadLine` - // > or `$Env:HOME/.local/share/powershell/PSReadLine`. - - let dir = if cfg!(windows) { - base.data_dir() - .join("Microsoft") - .join("Windows") - .join("PowerShell") - .join("PSReadLine") - } else { - std::env::var("XDG_DATA_HOME") - .map_or_else( - |_| base.home_dir().join(".local").join("share"), - PathBuf::from, - ) - .join("powershell") - .join("PSReadLine") - }; - - // The history is stored in a file named `$($Host.Name)_history.txt`. - // For the default console host shipped by Microsoft,`$Host.Name` is `ConsoleHost`: - // https://learn.microsoft.com/en-us/dotnet/api/system.management.automation.host.pshost.name#remarks - - let file = dir.join("ConsoleHost_history.txt"); - - if file.is_file() { - Ok(file) - } else { - Err(eyre!("Could not find history file: {}", file.display())) - } -} - -#[async_trait] -impl Importer for PowerShell { - const NAME: &'static str = "PowerShell"; - - async fn new() -> Result { - let bytes = read_to_end(get_history_path()?)?; - Ok(Self { - bytes, - line_count: None, - }) - } - - async fn entries(&mut self) -> Result { - // Commands can be split over multiple lines, - // but this is only used for a progress bar, and multi-line commands - // should be quite rare, so this is not an issue in practice. - if self.line_count.is_none() { - self.line_count = Some(count_lines(&self.bytes)); - } - Ok(self.line_count.unwrap()) - } - - async fn load(mut self, h: &mut impl Loader) -> Result<()> { - let line_count = self.entries().await?; - let start = OffsetDateTime::now_utc() - Duration::milliseconds(line_count as i64); - - let mut counter = 0; - let mut iter = unix_byte_lines(&self.bytes); - - while let Some(s) = iter.next() { - let Ok(s) = read_line(s) else { - continue; // We can skip past things like invalid utf8 - }; - - let mut cmd = s.to_string(); - - // Multi-line commands end with a backtick, append the following lines. - while cmd.ends_with('`') { - cmd.pop(); - - let Some(next) = iter.next() else { - break; - }; - let Ok(next) = read_line(next) else { - break; - }; - - cmd.push('\n'); - cmd.push_str(next); - } - - if cmd.is_empty() { - continue; - } - - let offset = Duration::milliseconds(counter); - counter += 1; - - let entry = History::import().timestamp(start + offset).command(cmd); - h.push(entry.build().into()).await?; - } - - Ok(()) - } -} - -fn read_line(s: &[u8]) -> Result<&str> { - let s = str::from_utf8(s)?; - - // History is stored in CRLF on Windows, normalize the input to LF on all platforms. - let s = s.strip_suffix('\r').unwrap_or(s); - - Ok(s) -} - -#[cfg(test)] -mod test { - use super::*; - use crate::import::tests::TestLoader; - use itertools::assert_equal; - - const INPUT: &str = r#"cargo install atuin -cargo update -echo "first line` -second line` -` -last line" -echo foo - -echo bar -echo baz -"#; - - const EXPECTED: &[&str] = &[ - "cargo install atuin", - "cargo update", - "echo \"first line\nsecond line\n\nlast line\"", - "echo foo", - "echo bar", - "echo baz", - ]; - - #[tokio::test] - async fn test_import() { - let loader = import(INPUT).await; - - let actual = loader.buf.iter().map(|h| h.command.clone()); - let expected = EXPECTED.iter().map(|s| s.to_string()); - - assert_equal(actual, expected); - } - - #[tokio::test] - async fn test_crlf() { - let input = INPUT.replace("\n", "\r\n"); - let loader = import(input.as_str()).await; - - let actual = loader.buf.iter().map(|h| h.command.clone()); - let expected = EXPECTED.iter().map(|s| s.to_string()); - - assert_equal(actual, expected); - } - - #[tokio::test] - async fn test_timestamps() { - let loader = import(INPUT).await; - - let mut prev = loader.buf.first().unwrap().timestamp; - for current in loader.buf.iter().skip(1).map(|h| h.timestamp) { - assert!(current > prev); - prev = current; - } - } - - async fn import(input: &str) -> TestLoader { - let powershell = PowerShell { - bytes: input.as_bytes().to_vec(), - line_count: None, - }; - - let mut loader = TestLoader::default(); - powershell.load(&mut loader).await.unwrap(); - loader - } -} diff --git a/crates/atuin-client/src/import/replxx.rs b/crates/atuin-client/src/import/replxx.rs deleted file mode 100644 index 47d566cf..00000000 --- a/crates/atuin-client/src/import/replxx.rs +++ /dev/null @@ -1,137 +0,0 @@ -use std::{path::PathBuf, str}; - -use async_trait::async_trait; -use directories::UserDirs; -use eyre::{Result, eyre}; -use time::{OffsetDateTime, PrimitiveDateTime, macros::format_description}; - -use super::{Importer, Loader, get_histfile_path, unix_byte_lines}; -use crate::history::History; -use crate::import::read_to_end; - -#[derive(Debug)] -pub struct Replxx { - bytes: Vec, -} - -fn default_histpath() -> Result { - let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; - let home_dir = user_dirs.home_dir(); - - // There is no default histfile for replxx. - // Here we try a couple of common names. - let mut candidates = ["replxx_history.txt", ".histfile"].iter(); - loop { - match candidates.next() { - Some(candidate) => { - let histpath = home_dir.join(candidate); - if histpath.exists() { - break Ok(histpath); - } - } - None => { - break Err(eyre!( - "Could not find history file. Try setting and exporting $HISTFILE" - )); - } - } - } -} - -#[async_trait] -impl Importer for Replxx { - const NAME: &'static str = "replxx"; - - async fn new() -> Result { - let bytes = read_to_end(get_histfile_path(default_histpath)?)?; - Ok(Self { bytes }) - } - - async fn entries(&mut self) -> Result { - Ok(super::count_lines(&self.bytes) / 2) - } - - async fn load(self, h: &mut impl Loader) -> Result<()> { - let mut timestamp = OffsetDateTime::UNIX_EPOCH; - - for b in unix_byte_lines(&self.bytes) { - let s = std::str::from_utf8(b)?; - match try_parse_line_as_timestamp(s) { - Some(t) => timestamp = t, - None => { - // replxx uses ETB character (0x17) as line breaker - let cmd = s.replace('\u{0017}', "\n"); - let imported = History::import().timestamp(timestamp).command(cmd); - - h.push(imported.build().into()).await?; - } - } - } - - Ok(()) - } -} - -fn try_parse_line_as_timestamp(line: &str) -> Option { - // replxx history date time format: ### yyyy-mm-dd hh:mm:ss.xxx - let date_time_str = line.strip_prefix("### ")?; - let format = - format_description!("[year]-[month]-[day] [hour]:[minute]:[second].[subsecond digits:3]"); - - let primitive_date_time = PrimitiveDateTime::parse(date_time_str, format).ok()?; - // There is no safe way to get local time offset. - // For simplicity let's just assume UTC. - Some(primitive_date_time.assume_utc()) -} - -#[cfg(test)] -mod test { - - use crate::import::{Importer, tests::TestLoader}; - - use super::Replxx; - - #[tokio::test] - async fn parse_complex() { - let bytes = r#"### 2024-02-10 22:16:28.302 -select * from remote('127.0.0.1:20222', view(select 1)) -### 2024-02-10 22:16:36.919 -select * from numbers(10) -### 2024-02-10 22:16:41.710 -select * from system.numbers -### 2024-02-10 22:19:28.655 -select 1 -### 2024-02-22 11:15:33.046 -CREATE TABLE test( stamp DateTime('UTC'))ENGINE = MergeTreePARTITION BY toDate(stamp)order by tuple() as select toDateTime('2020-01-01')+number*60 from numbers(80000); -"# - .as_bytes() - .to_owned(); - - let replxx = Replxx { bytes }; - - let mut loader = TestLoader::default(); - replxx.load(&mut loader).await.unwrap(); - let mut history = loader.buf.into_iter(); - - // simple wrapper for replxx history entry - macro_rules! history { - ($timestamp:expr_2021, $command:expr_2021) => { - let h = history.next().expect("missing entry in history"); - assert_eq!(h.command.as_str(), $command); - assert_eq!(h.timestamp.unix_timestamp(), $timestamp); - }; - } - - history!( - 1707603388, - "select * from remote('127.0.0.1:20222', view(select 1))" - ); - history!(1707603396, "select * from numbers(10)"); - history!(1707603401, "select * from system.numbers"); - history!(1707603568, "select 1"); - history!( - 1708600533, - "CREATE TABLE test\n( stamp DateTime('UTC'))\nENGINE = MergeTree\nPARTITION BY toDate(stamp)\norder by tuple() as select toDateTime('2020-01-01')+number*60 from numbers(80000);" - ); - } -} diff --git a/crates/atuin-client/src/import/resh.rs b/crates/atuin-client/src/import/resh.rs deleted file mode 100644 index df15f5b4..00000000 --- a/crates/atuin-client/src/import/resh.rs +++ /dev/null @@ -1,140 +0,0 @@ -use std::path::PathBuf; - -use async_trait::async_trait; -use directories::UserDirs; -use eyre::{Result, eyre}; -use serde::Deserialize; - -use atuin_common::utils::uuid_v7; -use time::OffsetDateTime; - -use super::{Importer, Loader, get_histfile_path, unix_byte_lines}; -use crate::history::History; -use crate::import::read_to_end; - -#[derive(Deserialize, Debug)] -#[serde(rename_all = "camelCase")] -pub struct ReshEntry { - pub cmd_line: String, - pub exit_code: i64, - pub shell: String, - pub uname: String, - pub session_id: String, - pub home: String, - pub lang: String, - pub lc_all: String, - pub login: String, - pub pwd: String, - pub pwd_after: String, - pub shell_env: String, - pub term: String, - pub real_pwd: String, - pub real_pwd_after: String, - pub pid: i64, - pub session_pid: i64, - pub host: String, - pub hosttype: String, - pub ostype: String, - pub machtype: String, - pub shlvl: i64, - pub timezone_before: String, - pub timezone_after: String, - pub realtime_before: f64, - pub realtime_after: f64, - pub realtime_before_local: f64, - pub realtime_after_local: f64, - pub realtime_duration: f64, - pub realtime_since_session_start: f64, - pub realtime_since_boot: f64, - pub git_dir: String, - pub git_real_dir: String, - pub git_origin_remote: String, - pub git_dir_after: String, - pub git_real_dir_after: String, - pub git_origin_remote_after: String, - pub machine_id: String, - pub os_release_id: String, - pub os_release_version_id: String, - pub os_release_id_like: String, - pub os_release_name: String, - pub os_release_pretty_name: String, - pub resh_uuid: String, - pub resh_version: String, - pub resh_revision: String, - pub parts_merged: bool, - pub recalled: bool, - pub recall_last_cmd_line: String, - pub cols: String, - pub lines: String, -} - -#[derive(Debug)] -pub struct Resh { - bytes: Vec, -} - -fn default_histpath() -> Result { - let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; - let home_dir = user_dirs.home_dir(); - - Ok(home_dir.join(".resh_history.json")) -} - -#[async_trait] -impl Importer for Resh { - const NAME: &'static str = "resh"; - - async fn new() -> Result { - let bytes = read_to_end(get_histfile_path(default_histpath)?)?; - Ok(Self { bytes }) - } - - async fn entries(&mut self) -> Result { - Ok(super::count_lines(&self.bytes)) - } - - async fn load(self, h: &mut impl Loader) -> Result<()> { - for b in unix_byte_lines(&self.bytes) { - let s = match std::str::from_utf8(b) { - Ok(s) => s, - Err(_) => continue, // we can skip past things like invalid utf8 - }; - let entry = match serde_json::from_str::(s) { - Ok(e) => e, - Err(_) => continue, // skip invalid json :shrug: - }; - - #[expect(clippy::cast_possible_truncation)] - #[expect(clippy::cast_sign_loss)] - let timestamp = { - let secs = entry.realtime_before.floor() as i64; - let nanosecs = (entry.realtime_before.fract() * 1_000_000_000_f64).round() as i64; - OffsetDateTime::from_unix_timestamp(secs)? + time::Duration::nanoseconds(nanosecs) - }; - #[expect(clippy::cast_possible_truncation)] - #[expect(clippy::cast_sign_loss)] - let duration = { - let secs = entry.realtime_after.floor() as i64; - let nanosecs = (entry.realtime_after.fract() * 1_000_000_000_f64).round() as i64; - let base = OffsetDateTime::from_unix_timestamp(secs)? - + time::Duration::nanoseconds(nanosecs); - let difference = base - timestamp; - difference.whole_nanoseconds() as i64 - }; - - let imported = History::import() - .command(entry.cmd_line) - .timestamp(timestamp) - .duration(duration) - .exit(entry.exit_code) - .cwd(entry.pwd) - .hostname(entry.host) - // CHECK: should we add uuid here? It's not set in the other importers - .session(uuid_v7().as_simple().to_string()); - - h.push(imported.build().into()).await?; - } - - Ok(()) - } -} diff --git a/crates/atuin-client/src/import/xonsh.rs b/crates/atuin-client/src/import/xonsh.rs deleted file mode 100644 index 6f38de68..00000000 --- a/crates/atuin-client/src/import/xonsh.rs +++ /dev/null @@ -1,234 +0,0 @@ -use std::env; -use std::fs::{self, File}; -use std::path::{Path, PathBuf}; - -use async_trait::async_trait; -use directories::BaseDirs; -use eyre::{Result, eyre}; -use serde::Deserialize; -use time::OffsetDateTime; -use uuid::Uuid; -use uuid::timestamp::{Timestamp, context::NoContext}; - -use super::{Importer, Loader, get_histdir_path}; -use crate::history::History; -use crate::utils::get_host_user; - -// Note: both HistoryFile and HistoryData have other keys present in the JSON, we don't -// care about them so we leave them unspecified so as to avoid deserializing unnecessarily. -#[derive(Debug, Deserialize)] -struct HistoryFile { - data: HistoryData, -} - -#[derive(Debug, Deserialize)] -struct HistoryData { - sessionid: String, - cmds: Vec, -} - -#[derive(Debug, Deserialize)] -struct HistoryCmd { - cwd: String, - inp: String, - rtn: Option, - ts: (f64, f64), -} - -#[derive(Debug)] -pub struct Xonsh { - // history is stored as a bunch of json files, one per session - sessions: Vec, - hostname: String, -} - -fn xonsh_hist_dir(xonsh_data_dir: Option) -> Result { - // if running within xonsh, this will be available - if let Some(d) = xonsh_data_dir { - let mut path = PathBuf::from(d); - path.push("history_json"); - return Ok(path); - } - - // otherwise, fall back to default - let base = BaseDirs::new().ok_or_else(|| eyre!("Could not determine home directory"))?; - - let hist_dir = base.data_dir().join("xonsh/history_json"); - if hist_dir.exists() || cfg!(test) { - Ok(hist_dir) - } else { - Err(eyre!("Could not find xonsh history files")) - } -} - -fn load_sessions(hist_dir: &Path) -> Result> { - let mut sessions = vec![]; - for entry in fs::read_dir(hist_dir)? { - let p = entry?.path(); - let ext = p.extension().and_then(|e| e.to_str()); - if p.is_file() - && ext == Some("json") - && let Some(data) = load_session(&p)? - { - sessions.push(data); - } - } - Ok(sessions) -} - -fn load_session(path: &Path) -> Result> { - let file = File::open(path)?; - // empty files are not valid json, so we can't deserialize them - if file.metadata()?.len() == 0 { - return Ok(None); - } - - let mut hist_file: HistoryFile = serde_json::from_reader(file)?; - - // if there are commands in this session, replace the existing UUIDv4 - // with a UUIDv7 generated from the timestamp of the first command - if let Some(cmd) = hist_file.data.cmds.first() { - let seconds = cmd.ts.0.trunc() as u64; - let nanos = (cmd.ts.0.fract() * 1_000_000_000_f64) as u32; - let ts = Timestamp::from_unix(NoContext, seconds, nanos); - hist_file.data.sessionid = Uuid::new_v7(ts).to_string(); - } - Ok(Some(hist_file.data)) -} - -#[async_trait] -impl Importer for Xonsh { - const NAME: &'static str = "xonsh"; - - async fn new() -> Result { - // wrap xonsh-specific path resolver in general one so that it respects $HISTPATH - let xonsh_data_dir = env::var("XONSH_DATA_DIR").ok(); - let hist_dir = get_histdir_path(|| xonsh_hist_dir(xonsh_data_dir))?; - let sessions = load_sessions(&hist_dir)?; - let hostname = get_host_user(); - Ok(Xonsh { sessions, hostname }) - } - - async fn entries(&mut self) -> Result { - let total = self.sessions.iter().map(|s| s.cmds.len()).sum(); - Ok(total) - } - - async fn load(self, loader: &mut impl Loader) -> Result<()> { - for session in self.sessions { - for cmd in session.cmds { - let (start, end) = cmd.ts; - let ts_nanos = (start * 1_000_000_000_f64) as i128; - let timestamp = OffsetDateTime::from_unix_timestamp_nanos(ts_nanos)?; - - let duration = (end - start) * 1_000_000_000_f64; - - match cmd.rtn { - Some(exit) => { - let entry = History::import() - .timestamp(timestamp) - .duration(duration.trunc() as i64) - .exit(exit) - .command(cmd.inp.trim()) - .cwd(cmd.cwd) - .session(session.sessionid.clone()) - .hostname(self.hostname.clone()); - loader.push(entry.build().into()).await?; - } - None => { - let entry = History::import() - .timestamp(timestamp) - .duration(duration.trunc() as i64) - .command(cmd.inp.trim()) - .cwd(cmd.cwd) - .session(session.sessionid.clone()) - .hostname(self.hostname.clone()); - loader.push(entry.build().into()).await?; - } - } - } - } - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use time::macros::datetime; - - use super::*; - - use crate::history::History; - use crate::import::tests::TestLoader; - - #[test] - fn test_hist_dir_xonsh() { - let hist_dir = xonsh_hist_dir(Some("/home/user/xonsh_data".to_string())).unwrap(); - assert_eq!( - hist_dir, - PathBuf::from("/home/user/xonsh_data/history_json") - ); - } - - #[tokio::test] - async fn test_import() { - let dir = PathBuf::from("tests/data/xonsh"); - let sessions = load_sessions(&dir).unwrap(); - let hostname = "box:user".to_string(); - let xonsh = Xonsh { sessions, hostname }; - - let mut loader = TestLoader::default(); - xonsh.load(&mut loader).await.unwrap(); - // order in buf will depend on filenames, so sort by timestamp for consistency - loader.buf.sort_by_key(|h| h.timestamp); - for (actual, expected) in loader.buf.iter().zip(expected_hist_entries().iter()) { - assert_eq!(actual.timestamp, expected.timestamp); - assert_eq!(actual.command, expected.command); - assert_eq!(actual.cwd, expected.cwd); - assert_eq!(actual.exit, expected.exit); - assert_eq!(actual.duration, expected.duration); - assert_eq!(actual.hostname, expected.hostname); - } - } - - fn expected_hist_entries() -> [History; 4] { - [ - History::import() - .timestamp(datetime!(2024-02-6 04:17:59.478272256 +00:00:00)) - .command("echo hello world!".to_string()) - .cwd("/home/user/Documents/code/atuin".to_string()) - .exit(0) - .duration(4651069) - .hostname("box:user".to_string()) - .build() - .into(), - History::import() - .timestamp(datetime!(2024-02-06 04:18:01.70632832 +00:00:00)) - .command("ls -l".to_string()) - .cwd("/home/user/Documents/code/atuin".to_string()) - .exit(0) - .duration(21288633) - .hostname("box:user".to_string()) - .build() - .into(), - History::import() - .timestamp(datetime!(2024-02-06 17:41:31.142515968 +00:00:00)) - .command("false".to_string()) - .cwd("/home/user/Documents/code/atuin/atuin-client".to_string()) - .exit(1) - .duration(10269403) - .hostname("box:user".to_string()) - .build() - .into(), - History::import() - .timestamp(datetime!(2024-02-06 17:41:32.271584 +00:00:00)) - .command("exit".to_string()) - .cwd("/home/user/Documents/code/atuin/atuin-client".to_string()) - .exit(0) - .duration(4259347) - .hostname("box:user".to_string()) - .build() - .into(), - ] - } -} diff --git a/crates/atuin-client/src/import/xonsh_sqlite.rs b/crates/atuin-client/src/import/xonsh_sqlite.rs deleted file mode 100644 index 7d50ac84..00000000 --- a/crates/atuin-client/src/import/xonsh_sqlite.rs +++ /dev/null @@ -1,217 +0,0 @@ -use std::env; -use std::path::PathBuf; - -use async_trait::async_trait; -use directories::BaseDirs; -use eyre::{Result, eyre}; -use futures::TryStreamExt; -use sqlx::{FromRow, Row, sqlite::SqlitePool}; -use time::OffsetDateTime; -use uuid::Uuid; -use uuid::timestamp::{Timestamp, context::NoContext}; - -use super::{Importer, Loader, get_histfile_path}; -use crate::history::History; -use crate::utils::get_host_user; - -#[derive(Debug, FromRow)] -struct HistDbEntry { - inp: String, - rtn: Option, - tsb: f64, - tse: f64, - cwd: String, - session_start: f64, -} - -impl HistDbEntry { - fn into_hist_with_hostname(self, hostname: String) -> History { - let ts_nanos = (self.tsb * 1_000_000_000_f64) as i128; - let timestamp = OffsetDateTime::from_unix_timestamp_nanos(ts_nanos).unwrap(); - - let session_ts_seconds = self.session_start.trunc() as u64; - let session_ts_nanos = (self.session_start.fract() * 1_000_000_000_f64) as u32; - let session_ts = Timestamp::from_unix(NoContext, session_ts_seconds, session_ts_nanos); - let session_id = Uuid::new_v7(session_ts).to_string(); - let duration = (self.tse - self.tsb) * 1_000_000_000_f64; - - if let Some(exit) = self.rtn { - let imported = History::import() - .timestamp(timestamp) - .duration(duration.trunc() as i64) - .exit(exit) - .command(self.inp) - .cwd(self.cwd) - .session(session_id) - .hostname(hostname); - imported.build().into() - } else { - let imported = History::import() - .timestamp(timestamp) - .duration(duration.trunc() as i64) - .command(self.inp) - .cwd(self.cwd) - .session(session_id) - .hostname(hostname); - imported.build().into() - } - } -} - -fn xonsh_db_path(xonsh_data_dir: Option) -> Result { - // if running within xonsh, this will be available - if let Some(d) = xonsh_data_dir { - let mut path = PathBuf::from(d); - path.push("xonsh-history.sqlite"); - return Ok(path); - } - - // otherwise, fall back to default - let base = BaseDirs::new().ok_or_else(|| eyre!("Could not determine home directory"))?; - - let hist_file = base.data_dir().join("xonsh/xonsh-history.sqlite"); - if hist_file.exists() || cfg!(test) { - Ok(hist_file) - } else { - Err(eyre!( - "Could not find xonsh history db at: {}", - hist_file.to_string_lossy() - )) - } -} - -#[derive(Debug)] -pub struct XonshSqlite { - pool: SqlitePool, - hostname: String, -} - -#[async_trait] -impl Importer for XonshSqlite { - const NAME: &'static str = "xonsh_sqlite"; - - async fn new() -> Result { - // wrap xonsh-specific path resolver in general one so that it respects $HISTPATH - let xonsh_data_dir = env::var("XONSH_DATA_DIR").ok(); - let db_path = get_histfile_path(|| xonsh_db_path(xonsh_data_dir))?; - let connection_str = db_path.to_str().ok_or_else(|| { - eyre!( - "Invalid path for SQLite database: {}", - db_path.to_string_lossy() - ) - })?; - - let pool = SqlitePool::connect(connection_str).await?; - let hostname = get_host_user(); - Ok(XonshSqlite { pool, hostname }) - } - - async fn entries(&mut self) -> Result { - let query = "SELECT COUNT(*) FROM xonsh_history"; - let row = sqlx::query(query).fetch_one(&self.pool).await?; - let count: u32 = row.get(0); - Ok(count as usize) - } - - async fn load(self, loader: &mut impl Loader) -> Result<()> { - let query = r#" - SELECT inp, rtn, tsb, tse, cwd, - MIN(tsb) OVER (PARTITION BY sessionid) AS session_start - FROM xonsh_history - ORDER BY rowid - "#; - - let mut entries = sqlx::query_as::<_, HistDbEntry>(query).fetch(&self.pool); - - let mut count = 0; - while let Some(entry) = entries.try_next().await? { - let hist = entry.into_hist_with_hostname(self.hostname.clone()); - loader.push(hist).await?; - count += 1; - } - - println!("Loaded: {count}"); - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use time::macros::datetime; - - use super::*; - - use crate::history::History; - use crate::import::tests::TestLoader; - - #[test] - fn test_db_path_xonsh() { - let db_path = xonsh_db_path(Some("/home/user/xonsh_data".to_string())).unwrap(); - assert_eq!( - db_path, - PathBuf::from("/home/user/xonsh_data/xonsh-history.sqlite") - ); - } - - #[tokio::test] - async fn test_import() { - let connection_str = "tests/data/xonsh-history.sqlite"; - let xonsh_sqlite = XonshSqlite { - pool: SqlitePool::connect(connection_str).await.unwrap(), - hostname: "box:user".to_string(), - }; - - let mut loader = TestLoader::default(); - xonsh_sqlite.load(&mut loader).await.unwrap(); - - for (actual, expected) in loader.buf.iter().zip(expected_hist_entries().iter()) { - assert_eq!(actual.timestamp, expected.timestamp); - assert_eq!(actual.command, expected.command); - assert_eq!(actual.cwd, expected.cwd); - assert_eq!(actual.exit, expected.exit); - assert_eq!(actual.duration, expected.duration); - assert_eq!(actual.hostname, expected.hostname); - } - } - - fn expected_hist_entries() -> [History; 4] { - [ - History::import() - .timestamp(datetime!(2024-02-6 17:56:21.130956288 +00:00:00)) - .command("echo hello world!".to_string()) - .cwd("/home/user/Documents/code/atuin".to_string()) - .exit(0) - .duration(2628564) - .hostname("box:user".to_string()) - .build() - .into(), - History::import() - .timestamp(datetime!(2024-02-06 17:56:28.190406144 +00:00:00)) - .command("ls -l".to_string()) - .cwd("/home/user/Documents/code/atuin".to_string()) - .exit(0) - .duration(9371519) - .hostname("box:user".to_string()) - .build() - .into(), - History::import() - .timestamp(datetime!(2024-02-06 17:56:46.989020928 +00:00:00)) - .command("false".to_string()) - .cwd("/home/user/Documents/code/atuin".to_string()) - .exit(1) - .duration(17337560) - .hostname("box:user".to_string()) - .build() - .into(), - History::import() - .timestamp(datetime!(2024-02-06 17:56:48.218384128 +00:00:00)) - .command("exit".to_string()) - .cwd("/home/user/Documents/code/atuin".to_string()) - .exit(0) - .duration(4599094) - .hostname("box:user".to_string()) - .build() - .into(), - ] - } -} diff --git a/crates/atuin-client/src/import/zsh.rs b/crates/atuin-client/src/import/zsh.rs deleted file mode 100644 index 11e2f371..00000000 --- a/crates/atuin-client/src/import/zsh.rs +++ /dev/null @@ -1,230 +0,0 @@ -// import old shell history! -// automatically hoover up all that we can find - -use std::borrow::Cow; -use std::path::PathBuf; - -use async_trait::async_trait; -use directories::UserDirs; -use eyre::{Result, eyre}; -use time::OffsetDateTime; - -use super::{Importer, Loader, get_histfile_path, unix_byte_lines}; -use crate::history::History; -use crate::import::read_to_end; - -#[derive(Debug)] -pub struct Zsh { - bytes: Vec, -} - -fn default_histpath() -> Result { - // oh-my-zsh sets HISTFILE=~/.zhistory - // zsh has no default value for this var, but uses ~/.zhistory. - // zsh-newuser-install propose as default .histfile https://github.com/zsh-users/zsh/blob/master/Functions/Newuser/zsh-newuser-install#L794 - // we could maybe be smarter about this in the future :) - let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; - let home_dir = user_dirs.home_dir(); - - let mut candidates = [".zhistory", ".zsh_history", ".histfile"].iter(); - loop { - match candidates.next() { - Some(candidate) => { - let histpath = home_dir.join(candidate); - if histpath.exists() { - break Ok(histpath); - } - } - None => { - break Err(eyre!( - "Could not find history file. Try setting and exporting $HISTFILE" - )); - } - } - } -} - -#[async_trait] -impl Importer for Zsh { - const NAME: &'static str = "zsh"; - - async fn new() -> Result { - let bytes = read_to_end(get_histfile_path(default_histpath)?)?; - Ok(Self { bytes }) - } - - async fn entries(&mut self) -> Result { - Ok(super::count_lines(&self.bytes)) - } - - async fn load(self, h: &mut impl Loader) -> Result<()> { - let now = OffsetDateTime::now_utc(); - let mut line = String::new(); - - let mut counter = 0; - for b in unix_byte_lines(&self.bytes) { - let s = match unmetafy(b) { - Some(s) => s, - _ => continue, // we can skip past things like invalid utf8 - }; - - if let Some(s) = s.strip_suffix('\\') { - line.push_str(s); - line.push('\n'); - } else { - line.push_str(&s); - let command = std::mem::take(&mut line); - - if let Some(command) = command.strip_prefix(": ") { - counter += 1; - h.push(parse_extended(command, counter)).await?; - } else { - let offset = time::Duration::seconds(counter); - counter += 1; - - let imported = History::import() - // preserve ordering - .timestamp(now - offset) - .command(command.trim_end().to_string()); - - h.push(imported.build().into()).await?; - } - } - } - - Ok(()) - } -} - -fn parse_extended(line: &str, counter: i64) -> History { - let (time, duration) = line.split_once(':').unwrap(); - let (duration, command) = duration.split_once(';').unwrap(); - - let time = time - .parse::() - .ok() - .and_then(|t| OffsetDateTime::from_unix_timestamp(t).ok()) - .unwrap_or_else(OffsetDateTime::now_utc) - + time::Duration::milliseconds(counter); - - // use nanos, because why the hell not? we won't display them. - let duration = duration.parse::().map_or(-1, |t| t * 1_000_000_000); - - let imported = History::import() - .timestamp(time) - .command(command.trim_end().to_string()) - .duration(duration); - - imported.build().into() -} - -fn unmetafy(line: &[u8]) -> Option> { - if line.contains(&0x83) { - let mut s = Vec::with_capacity(line.len()); - let mut is_meta = false; - for ch in line { - if *ch == 0x83 { - is_meta = true; - } else if is_meta { - is_meta = false; - s.push(*ch ^ 32); - } else { - s.push(*ch) - } - } - String::from_utf8(s).ok().map(Cow::Owned) - } else { - std::str::from_utf8(line).ok().map(Cow::Borrowed) - } -} - -#[cfg(test)] -mod test { - use itertools::assert_equal; - - use crate::import::tests::TestLoader; - - use super::*; - - #[test] - fn test_parse_extended_simple() { - let parsed = parse_extended("1613322469:0;cargo install atuin", 0); - - assert_eq!(parsed.command, "cargo install atuin"); - assert_eq!(parsed.duration, 0); - assert_eq!( - parsed.timestamp, - OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() - ); - - let parsed = parse_extended("1613322469:10;cargo install atuin;cargo update", 0); - - assert_eq!(parsed.command, "cargo install atuin;cargo update"); - assert_eq!(parsed.duration, 10_000_000_000); - assert_eq!( - parsed.timestamp, - OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() - ); - - let parsed = parse_extended("1613322469:10;cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", 0); - - assert_eq!(parsed.command, "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷"); - assert_eq!(parsed.duration, 10_000_000_000); - assert_eq!( - parsed.timestamp, - OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() - ); - - let parsed = parse_extended("1613322469:10;cargo install \\n atuin\n", 0); - - assert_eq!(parsed.command, "cargo install \\n atuin"); - assert_eq!(parsed.duration, 10_000_000_000); - assert_eq!( - parsed.timestamp, - OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() - ); - } - - #[tokio::test] - async fn test_parse_file() { - let bytes = r": 1613322469:0;cargo install atuin -: 1613322469:10;cargo install atuin; \\ -cargo update -: 1613322469:10;cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷ -" - .as_bytes() - .to_owned(); - - let mut zsh = Zsh { bytes }; - assert_eq!(zsh.entries().await.unwrap(), 4); - - let mut loader = TestLoader::default(); - zsh.load(&mut loader).await.unwrap(); - - assert_equal( - loader.buf.iter().map(|h| h.command.as_str()), - [ - "cargo install atuin", - "cargo install atuin; \\\ncargo update", - "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", - ], - ); - } - - #[tokio::test] - async fn test_parse_metafied() { - let bytes = - b"echo \xe4\xbd\x83\x80\xe5\xa5\xbd\nls ~/\xe9\x83\xbf\xb3\xe4\xb9\x83\xb0\n".to_vec(); - - let mut zsh = Zsh { bytes }; - assert_eq!(zsh.entries().await.unwrap(), 2); - - let mut loader = TestLoader::default(); - zsh.load(&mut loader).await.unwrap(); - - assert_equal( - loader.buf.iter().map(|h| h.command.as_str()), - ["echo 你好", "ls ~/音乐"], - ); - } -} diff --git a/crates/atuin-client/src/import/zsh_histdb.rs b/crates/atuin-client/src/import/zsh_histdb.rs deleted file mode 100644 index bf44c3ad..00000000 --- a/crates/atuin-client/src/import/zsh_histdb.rs +++ /dev/null @@ -1,249 +0,0 @@ -// import old shell history from zsh-histdb! -// automatically hoover up all that we can find - -// As far as i can tell there are no version numbers in the histdb sqlite DB, so we're going based -// on the schema from 2022-05-01 -// -// I have run into some histories that will not import b/c of non UTF-8 characters. -// - -// -// An Example sqlite query for hsitdb data: -// -//id|session|command_id|place_id|exit_status|start_time|duration|id|argv|id|host|dir -// -// -// select -// history.id, -// history.start_time, -// places.host, -// places.dir, -// commands.argv -// from history -// left join commands on history.command_id = commands.id -// left join places on history.place_id = places.id ; -// -// CREATE TABLE history (id integer primary key autoincrement, -// session int, -// command_id int references commands (id), -// place_id int references places (id), -// exit_status int, -// start_time int, -// duration int); -// - -use std::collections::HashMap; -use std::path::{Path, PathBuf}; - -use async_trait::async_trait; -use atuin_common::utils::uuid_v7; -use directories::UserDirs; -use eyre::{Result, eyre}; -use sqlx::{Pool, sqlite::SqlitePool}; -use time::PrimitiveDateTime; - -use super::Importer; -use crate::history::History; -use crate::import::Loader; -use crate::utils::{get_hostname, get_username}; - -#[derive(sqlx::FromRow, Debug)] -pub struct HistDbEntryCount { - pub count: usize, -} - -#[derive(sqlx::FromRow, Debug)] -pub struct HistDbEntry { - pub id: i64, - pub start_time: PrimitiveDateTime, - pub host: Vec, - pub dir: Vec, - pub argv: Vec, - pub duration: i64, - pub exit_status: i64, - pub session: i64, -} - -#[derive(Debug)] -pub struct ZshHistDb { - histdb: Vec, - username: String, -} - -/// Read db at given file, return vector of entries. -async fn hist_from_db(dbpath: PathBuf) -> Result> { - let pool = SqlitePool::connect(dbpath.to_str().unwrap()).await?; - hist_from_db_conn(pool).await -} - -async fn hist_from_db_conn(pool: Pool) -> Result> { - let query = r#" - SELECT - history.id, history.start_time, history.duration, places.host, places.dir, - commands.argv, history.exit_status, history.session - FROM history - LEFT JOIN commands ON history.command_id = commands.id - LEFT JOIN places ON history.place_id = places.id - ORDER BY history.start_time - "#; - let histdb_vec: Vec = sqlx::query_as::<_, HistDbEntry>(query) - .fetch_all(&pool) - .await?; - Ok(histdb_vec) -} - -impl ZshHistDb { - pub fn histpath_candidate() -> PathBuf { - // By default histdb database is `${HOME}/.histdb/zsh-history.db` - // This can be modified by ${HISTDB_FILE} - // - // if [[ -z ${HISTDB_FILE} ]]; then - // typeset -g HISTDB_FILE="${HOME}/.histdb/zsh-history.db" - let user_dirs = UserDirs::new().unwrap(); // should catch error here? - let home_dir = user_dirs.home_dir(); - std::env::var("HISTDB_FILE") - .as_ref() - .map(|x| Path::new(x).to_path_buf()) - .unwrap_or_else(|_err| home_dir.join(".histdb/zsh-history.db")) - } - pub fn histpath() -> Result { - let histdb_path = ZshHistDb::histpath_candidate(); - if histdb_path.exists() { - Ok(histdb_path) - } else { - Err(eyre!( - "Could not find history file. Try setting $HISTDB_FILE" - )) - } - } -} - -#[async_trait] -impl Importer for ZshHistDb { - // Not sure how this is used - const NAME: &'static str = "zsh_histdb"; - - /// Creates a new ZshHistDb and populates the history based on the pre-populated data - /// structure. - async fn new() -> Result { - let dbpath = ZshHistDb::histpath()?; - let histdb_entry_vec = hist_from_db(dbpath).await?; - Ok(Self { - histdb: histdb_entry_vec, - username: get_username(), - }) - } - - async fn entries(&mut self) -> Result { - Ok(self.histdb.len()) - } - - async fn load(self, h: &mut impl Loader) -> Result<()> { - let mut session_map = HashMap::new(); - for entry in self.histdb { - let command = match std::str::from_utf8(&entry.argv) { - Ok(s) => s.trim_end(), - Err(_) => continue, // we can skip past things like invalid utf8 - }; - let cwd = match std::str::from_utf8(&entry.dir) { - Ok(s) => s.trim_end(), - Err(_) => continue, // we can skip past things like invalid utf8 - }; - let hostname = format!( - "{}:{}", - String::from_utf8(entry.host).unwrap_or_else(|_e| get_hostname()), - self.username - ); - let session = session_map.entry(entry.session).or_insert_with(uuid_v7); - - let imported = History::import() - .timestamp(entry.start_time.assume_utc()) - .command(command) - .cwd(cwd) - .duration(entry.duration * 1_000_000_000) - .exit(entry.exit_status) - .session(session.as_simple().to_string()) - .hostname(hostname) - .build(); - h.push(imported.into()).await?; - } - Ok(()) - } -} - -#[cfg(test)] -mod test { - - use super::*; - use sqlx::sqlite::SqlitePoolOptions; - use std::env; - #[tokio::test(flavor = "multi_thread")] - #[expect(unsafe_code)] - async fn test_env_vars() { - let test_env_db = "nonstd-zsh-history.db"; - let key = "HISTDB_FILE"; - // TODO: Audit that the environment access only happens in single-threaded code. - unsafe { env::set_var(key, test_env_db) }; - - // test the env got set - assert_eq!(env::var(key).unwrap(), test_env_db.to_string()); - - // test histdb returns the proper db from previous step - let histdb_path = ZshHistDb::histpath_candidate(); - assert_eq!(histdb_path.to_str().unwrap(), test_env_db); - } - - #[tokio::test(flavor = "multi_thread")] - async fn test_import() { - let pool: SqlitePool = SqlitePoolOptions::new() - .min_connections(2) - .connect(":memory:") - .await - .unwrap(); - - // sql dump directly from a test database. - let db_sql = r#" - PRAGMA foreign_keys=OFF; - BEGIN TRANSACTION; - CREATE TABLE commands (id integer primary key autoincrement, argv text, unique(argv) on conflict ignore); - INSERT INTO commands VALUES(1,'pwd'); - INSERT INTO commands VALUES(2,'curl google.com'); - INSERT INTO commands VALUES(3,'bash'); - CREATE TABLE places (id integer primary key autoincrement, host text, dir text, unique(host, dir) on conflict ignore); - INSERT INTO places VALUES(1,'mbp16.local','/home/noyez'); - CREATE TABLE history (id integer primary key autoincrement, - session int, - command_id int references commands (id), - place_id int references places (id), - exit_status int, - start_time int, - duration int); - INSERT INTO history VALUES(1,0,1,1,0,1651497918,1); - INSERT INTO history VALUES(2,0,2,1,0,1651497923,1); - INSERT INTO history VALUES(3,0,3,1,NULL,1651497930,NULL); - DELETE FROM sqlite_sequence; - INSERT INTO sqlite_sequence VALUES('commands',3); - INSERT INTO sqlite_sequence VALUES('places',3); - INSERT INTO sqlite_sequence VALUES('history',3); - CREATE INDEX hist_time on history(start_time); - CREATE INDEX place_dir on places(dir); - CREATE INDEX place_host on places(host); - CREATE INDEX history_command_place on history(command_id, place_id); - COMMIT; "#; - - sqlx::query(db_sql).execute(&pool).await.unwrap(); - - // test histdb iterator - let histdb_vec = hist_from_db_conn(pool).await.unwrap(); - let histdb = ZshHistDb { - histdb: histdb_vec, - username: get_username(), - }; - - println!("h: {:#?}", histdb.histdb); - println!("counter: {:?}", histdb.histdb.len()); - for i in histdb.histdb { - println!("{i:?}"); - } - } -} diff --git a/crates/atuin-client/src/lib.rs b/crates/atuin-client/src/lib.rs deleted file mode 100644 index cd7785e1..00000000 --- a/crates/atuin-client/src/lib.rs +++ /dev/null @@ -1,31 +0,0 @@ -#![deny(unsafe_code)] - -#[macro_use] -extern crate log; - -#[cfg(feature = "sync")] -pub mod api_client; -#[cfg(feature = "sync")] -pub mod auth; -#[cfg(feature = "sync")] -pub mod login; -#[cfg(feature = "sync")] -pub mod register; -#[cfg(feature = "sync")] -pub mod sync; - -pub mod database; -pub mod distro; -pub mod encryption; -pub mod history; -pub mod import; -pub mod logout; -pub mod meta; -pub mod ordering; -pub mod plugin; -pub mod record; -pub mod secrets; -pub mod settings; -pub mod theme; - -mod utils; diff --git a/crates/atuin-client/src/login.rs b/crates/atuin-client/src/login.rs deleted file mode 100644 index 2545e890..00000000 --- a/crates/atuin-client/src/login.rs +++ /dev/null @@ -1,68 +0,0 @@ -use std::path::PathBuf; - -use atuin_common::api::LoginRequest; -use eyre::{Context, Result, bail}; -use tokio::fs::File; -use tokio::io::AsyncWriteExt; - -use crate::{ - api_client, - encryption::{decode_key, load_key}, - record::{sqlite_store::SqliteStore, store::Store}, - settings::Settings, -}; - -pub async fn login( - settings: &Settings, - store: &SqliteStore, - username: String, - password: String, - key: String, -) -> Result { - let key_path = settings.key_path.as_str(); - let key_path = PathBuf::from(key_path); - - if !key_path.exists() { - if decode_key(key.clone()).is_err() { - bail!("the specified key was invalid"); - } - - let mut file = File::create(&key_path).await?; - file.write_all(key.as_bytes()).await?; - } else { - // we now know that the user has logged in specifying a key, AND that the key path - // exists - - // 1. check if the saved key and the provided key match. if so, nothing to do. - // 2. if not, re-encrypt the local history and overwrite the key - let current_key: [u8; 32] = load_key(settings)?.into(); - - let encoded = key.clone(); // gonna want to save it in a bit - let new_key: [u8; 32] = decode_key(key) - .context("could not decode provided key - is not valid base64")? - .into(); - - if new_key != current_key { - println!("\nRe-encrypting local store with new key"); - - store.re_encrypt(¤t_key, &new_key).await?; - - println!("Writing new key"); - let mut file = File::create(&key_path).await?; - file.write_all(encoded.as_bytes()).await?; - } - } - - let session = api_client::login( - settings.sync_address.as_str(), - LoginRequest { username, password }, - ) - .await?; - - Settings::meta_store() - .await? - .save_session(&session.session) - .await?; - - Ok(session.session) -} diff --git a/crates/atuin-client/src/logout.rs b/crates/atuin-client/src/logout.rs deleted file mode 100644 index f720b302..00000000 --- a/crates/atuin-client/src/logout.rs +++ /dev/null @@ -1,16 +0,0 @@ -use eyre::Result; - -use crate::settings::Settings; - -pub async fn logout() -> Result<()> { - let meta = Settings::meta_store().await?; - - if meta.logged_in().await? { - meta.delete_session().await?; - println!("You have logged out!"); - } else { - println!("You are not logged in"); - } - - Ok(()) -} diff --git a/crates/atuin-client/src/meta.rs b/crates/atuin-client/src/meta.rs deleted file mode 100644 index 870f36d0..00000000 --- a/crates/atuin-client/src/meta.rs +++ /dev/null @@ -1,365 +0,0 @@ -use std::path::Path; -use std::str::FromStr; -use std::time::Duration; - -use atuin_common::record::HostId; -use eyre::{Result, eyre}; -use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions}; -use time::{OffsetDateTime, format_description::well_known::Rfc3339}; -use tokio::sync::OnceCell; -use uuid::Uuid; - -// Filenames for the legacy plain-text files that we migrate from. -const LEGACY_HOST_ID_FILENAME: &str = "host_id"; -const LEGACY_LAST_SYNC_FILENAME: &str = "last_sync_time"; -const LEGACY_LAST_VERSION_CHECK_FILENAME: &str = "last_version_check_time"; -const LEGACY_LATEST_VERSION_FILENAME: &str = "latest_version"; -const LEGACY_SESSION_FILENAME: &str = "session"; - -const KEY_HOST_ID: &str = "host_id"; -const KEY_LAST_SYNC: &str = "last_sync_time"; -const KEY_LAST_VERSION_CHECK: &str = "last_version_check_time"; -const KEY_LATEST_VERSION: &str = "latest_version"; -const KEY_SESSION: &str = "session"; -const KEY_FILES_MIGRATED: &str = "files_migrated"; - -pub struct MetaStore { - pool: SqlitePool, - cached_host_id: OnceCell, -} - -impl MetaStore { - pub async fn new(path: impl AsRef, timeout: f64) -> Result { - let path = path.as_ref(); - let path_str = path - .as_os_str() - .to_str() - .ok_or_else(|| eyre!("meta database path is not valid UTF-8: {path:?}"))?; - debug!("opening meta sqlite database at {path:?}"); - - let is_memory = path_str.contains(":memory:"); - - if !is_memory - && !path.exists() - && let Some(dir) = path.parent() - { - fs_err::create_dir_all(dir)?; - } - - // Use DELETE journal mode instead of WAL. This is a small, infrequently- - // written KV store — WAL's concurrency benefits aren't needed, and DELETE - // mode avoids creating auxiliary -wal/-shm files that complicate - // permission handling. - let opts = SqliteConnectOptions::from_str(path_str)? - .journal_mode(SqliteJournalMode::Delete) - .optimize_on_close(true, None) - .create_if_missing(true); - - let pool = SqlitePoolOptions::new() - .acquire_timeout(Duration::from_secs_f64(timeout)) - .connect_with(opts) - .await?; - - sqlx::migrate!("./meta-migrations").run(&pool).await?; - - // Session tokens are stored in this database, so restrict permissions. - #[cfg(unix)] - if !is_memory { - use std::os::unix::fs::PermissionsExt; - std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o600))?; - } - - let store = Self { - pool, - cached_host_id: OnceCell::const_new(), - }; - - if !is_memory { - store.migrate_files().await?; - } - - Ok(store) - } - - // Generic key-value operations - - pub async fn get(&self, key: &str) -> Result> { - let row: Option<(String,)> = sqlx::query_as("SELECT value FROM meta WHERE key = ?1") - .bind(key) - .fetch_optional(&self.pool) - .await?; - - Ok(row.map(|r| r.0)) - } - - pub async fn set(&self, key: &str, value: &str) -> Result<()> { - sqlx::query( - "INSERT INTO meta (key, value, updated_at) VALUES (?1, ?2, strftime('%s', 'now')) - ON CONFLICT(key) DO UPDATE SET value = ?2, updated_at = strftime('%s', 'now')", - ) - .bind(key) - .bind(value) - .execute(&self.pool) - .await?; - - Ok(()) - } - - pub async fn delete(&self, key: &str) -> Result<()> { - sqlx::query("DELETE FROM meta WHERE key = ?1") - .bind(key) - .execute(&self.pool) - .await?; - - Ok(()) - } - - // Typed accessors - - pub async fn host_id(&self) -> Result { - self.cached_host_id - .get_or_try_init(|| async { - if let Some(id) = self.get(KEY_HOST_ID).await? { - let parsed = Uuid::from_str(id.as_str()) - .map_err(|e| eyre!("failed to parse host ID: {e}"))?; - return Ok(HostId(parsed)); - } - - let uuid = atuin_common::utils::uuid_v7(); - self.set(KEY_HOST_ID, uuid.as_simple().to_string().as_ref()) - .await?; - - Ok(HostId(uuid)) - }) - .await - .copied() - } - - pub async fn last_sync(&self) -> Result { - match self.get(KEY_LAST_SYNC).await? { - Some(v) => Ok(OffsetDateTime::parse(v.as_str(), &Rfc3339)?), - None => Ok(OffsetDateTime::UNIX_EPOCH), - } - } - - pub async fn save_sync_time(&self) -> Result<()> { - self.set( - KEY_LAST_SYNC, - OffsetDateTime::now_utc().format(&Rfc3339)?.as_str(), - ) - .await - } - - pub async fn last_version_check(&self) -> Result { - match self.get(KEY_LAST_VERSION_CHECK).await? { - Some(v) => Ok(OffsetDateTime::parse(v.as_str(), &Rfc3339)?), - None => Ok(OffsetDateTime::UNIX_EPOCH), - } - } - - pub async fn save_version_check_time(&self) -> Result<()> { - self.set( - KEY_LAST_VERSION_CHECK, - OffsetDateTime::now_utc().format(&Rfc3339)?.as_str(), - ) - .await - } - - pub async fn latest_version(&self) -> Result> { - self.get(KEY_LATEST_VERSION).await - } - - pub async fn save_latest_version(&self, version: &str) -> Result<()> { - self.set(KEY_LATEST_VERSION, version).await - } - - pub async fn session_token(&self) -> Result> { - self.get(KEY_SESSION).await - } - - pub async fn save_session(&self, token: &str) -> Result<()> { - self.set(KEY_SESSION, token).await - } - - pub async fn delete_session(&self) -> Result<()> { - self.delete(KEY_SESSION).await - } - - pub async fn logged_in(&self) -> Result { - Ok(self.session_token().await?.is_some()) - } - - // File migration: on first open, migrate old plain-text files into the database. - // Old files are left in place for safe downgrades. - - async fn migrate_files(&self) -> Result<()> { - if self.get(KEY_FILES_MIGRATED).await?.is_some() { - return Ok(()); - } - - let data_dir = crate::settings::Settings::effective_data_dir(); - - // host_id — validate as UUID - let host_id_path = data_dir.join(LEGACY_HOST_ID_FILENAME); - if host_id_path.exists() - && let Ok(value) = fs_err::read_to_string(&host_id_path) - { - let value = value.trim(); - if !value.is_empty() { - if Uuid::from_str(value).is_ok() { - self.set(KEY_HOST_ID, value).await?; - } else { - warn!("skipping migration of host_id: invalid UUID {value:?}"); - } - } - } - - // last_sync_time — validate as RFC3339 - let sync_path = data_dir.join(LEGACY_LAST_SYNC_FILENAME); - if sync_path.exists() - && let Ok(value) = fs_err::read_to_string(&sync_path) - { - let value = value.trim(); - if !value.is_empty() { - if OffsetDateTime::parse(value, &Rfc3339).is_ok() { - self.set(KEY_LAST_SYNC, value).await?; - } else { - warn!("skipping migration of last_sync_time: invalid RFC3339 {value:?}"); - } - } - } - - // last_version_check_time — validate as RFC3339 - let version_check_path = data_dir.join(LEGACY_LAST_VERSION_CHECK_FILENAME); - if version_check_path.exists() - && let Ok(value) = fs_err::read_to_string(&version_check_path) - { - let value = value.trim(); - if !value.is_empty() { - if OffsetDateTime::parse(value, &Rfc3339).is_ok() { - self.set(KEY_LAST_VERSION_CHECK, value).await?; - } else { - warn!( - "skipping migration of last_version_check_time: invalid RFC3339 {value:?}" - ); - } - } - } - - // latest_version — no strict validation, just non-empty - let latest_version_path = data_dir.join(LEGACY_LATEST_VERSION_FILENAME); - if latest_version_path.exists() - && let Ok(value) = fs_err::read_to_string(&latest_version_path) - { - let value = value.trim(); - if !value.is_empty() { - self.set(KEY_LATEST_VERSION, value).await?; - } - } - - // session token — no strict validation, just non-empty - let session_path = data_dir.join(LEGACY_SESSION_FILENAME); - if session_path.exists() - && let Ok(value) = fs_err::read_to_string(&session_path) - { - let value = value.trim(); - if !value.is_empty() { - self.set(KEY_SESSION, value).await?; - } - } - - self.set(KEY_FILES_MIGRATED, "true").await?; - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - async fn new_test_store() -> MetaStore { - MetaStore::new("sqlite::memory:", 2.0).await.unwrap() - } - - #[tokio::test] - async fn test_get_set_delete() { - let store = new_test_store().await; - - assert_eq!(store.get("foo").await.unwrap(), None); - - store.set("foo", "bar").await.unwrap(); - assert_eq!(store.get("foo").await.unwrap(), Some("bar".to_string())); - - store.set("foo", "baz").await.unwrap(); - assert_eq!(store.get("foo").await.unwrap(), Some("baz".to_string())); - - store.delete("foo").await.unwrap(); - assert_eq!(store.get("foo").await.unwrap(), None); - } - - #[tokio::test] - async fn test_host_id_generation_and_stability() { - let store = new_test_store().await; - - let id1 = store.host_id().await.unwrap(); - let id2 = store.host_id().await.unwrap(); - - assert_eq!(id1, id2, "host_id should be stable across calls"); - } - - #[tokio::test] - async fn test_sync_time() { - let store = new_test_store().await; - - let t = store.last_sync().await.unwrap(); - assert_eq!(t, OffsetDateTime::UNIX_EPOCH); - - store.save_sync_time().await.unwrap(); - let t = store.last_sync().await.unwrap(); - assert!(t > OffsetDateTime::UNIX_EPOCH); - } - - #[tokio::test] - async fn test_version_check_time() { - let store = new_test_store().await; - - let t = store.last_version_check().await.unwrap(); - assert_eq!(t, OffsetDateTime::UNIX_EPOCH); - - store.save_version_check_time().await.unwrap(); - let t = store.last_version_check().await.unwrap(); - assert!(t > OffsetDateTime::UNIX_EPOCH); - } - - #[tokio::test] - async fn test_session_crud() { - let store = new_test_store().await; - - assert!(!store.logged_in().await.unwrap()); - assert_eq!(store.session_token().await.unwrap(), None); - - store.save_session("tok123").await.unwrap(); - assert!(store.logged_in().await.unwrap()); - assert_eq!( - store.session_token().await.unwrap(), - Some("tok123".to_string()) - ); - - store.delete_session().await.unwrap(); - assert!(!store.logged_in().await.unwrap()); - } - - #[tokio::test] - async fn test_latest_version() { - let store = new_test_store().await; - - assert_eq!(store.latest_version().await.unwrap(), None); - - store.save_latest_version("1.2.3").await.unwrap(); - assert_eq!( - store.latest_version().await.unwrap(), - Some("1.2.3".to_string()) - ); - } -} diff --git a/crates/atuin-client/src/ordering.rs b/crates/atuin-client/src/ordering.rs deleted file mode 100644 index 4e5ec84c..00000000 --- a/crates/atuin-client/src/ordering.rs +++ /dev/null @@ -1,32 +0,0 @@ -use minspan::minspan; - -use super::{history::History, settings::SearchMode}; - -pub fn reorder_fuzzy(mode: SearchMode, query: &str, res: Vec) -> Vec { - match mode { - SearchMode::Fuzzy => reorder(query, |x| &x.command, res), - _ => res, - } -} - -fn reorder(query: &str, f: F, res: Vec) -> Vec -where - F: Fn(&A) -> &String, - A: Clone, -{ - let mut r = res.clone(); - let qvec = &query.chars().collect(); - r.sort_by_cached_key(|h| { - // TODO for fzf search we should sum up scores for each matched term - let (from, to) = match minspan::span(qvec, &(f(h).chars().collect())) { - Some(x) => x, - // this is a little unfortunate: when we are asked to match a query that is found nowhere, - // we don't want to return a None, as the comparison behaviour would put the worst matches - // at the front. therefore, we'll return a set of indices that are one larger than the longest - // possible legitimate match. This is meaningless except as a comparison. - None => (0, res.len()), - }; - 1 + to - from - }); - r -} diff --git a/crates/atuin-client/src/plugin.rs b/crates/atuin-client/src/plugin.rs deleted file mode 100644 index 6f351bf1..00000000 --- a/crates/atuin-client/src/plugin.rs +++ /dev/null @@ -1,150 +0,0 @@ -use std::collections::HashMap; - -#[derive(Debug, Clone)] -pub struct OfficialPlugin { - pub name: String, - pub description: String, - pub install_message: String, -} - -impl OfficialPlugin { - pub fn new(name: &str, description: &str, install_message: &str) -> Self { - Self { - name: name.to_string(), - description: description.to_string(), - install_message: install_message.to_string(), - } - } -} - -pub struct OfficialPluginRegistry { - plugins: HashMap, -} - -impl OfficialPluginRegistry { - pub fn new() -> Self { - let mut registry = Self { - plugins: HashMap::new(), - }; - - // Register official plugins - registry.register_official_plugins(); - - registry - } - - fn register_official_plugins(&mut self) { - // atuin-update plugin - self.plugins.insert( - "update".to_string(), - OfficialPlugin::new( - "update", - "Update atuin to the latest version", - "The 'atuin update' command is provided by the atuin-update plugin.\n\ - It is only installed if you used the install script\n \ - If you used a package manager (brew, apt, etc), please continue to use it for updates", - ), - ); - } - - pub fn get_plugin(&self, name: &str) -> Option<&OfficialPlugin> { - self.plugins.get(name) - } - - pub fn is_official_plugin(&self, name: &str) -> bool { - self.plugins.contains_key(name) - } - - pub fn get_install_message(&self, name: &str) -> Option<&str> { - self.plugins - .get(name) - .map(|plugin| plugin.install_message.as_str()) - } -} - -impl Default for OfficialPluginRegistry { - fn default() -> Self { - Self::new() - } -} - -pub struct PluginContext { - #[cfg(windows)] - _update_on_windows: Option, -} - -impl PluginContext { - pub fn new(_subcommand: &str) -> Self { - PluginContext { - #[cfg(windows)] - _update_on_windows: (_subcommand == "update").then(UpdateOnWindowsContext::new), - } - } -} - -impl Drop for PluginContext { - fn drop(&mut self) {} -} - -#[cfg(windows)] -struct UpdateOnWindowsContext { - initial_exe: Option, -} - -#[cfg(windows)] -impl UpdateOnWindowsContext { - const OLD_FILE_NAME: &'static str = "atuin.old"; - - pub fn new() -> Self { - // Windows doesn't let you overwrite a running exe, but it lets you rename it, - // so make some room for atuin-update to install the new version. - let initial_exe = std::env::current_exe().ok().and_then(|exe| { - std::fs::rename(&exe, exe.with_file_name(Self::OLD_FILE_NAME)).ok()?; - Some(exe) - }); - - Self { initial_exe } - } -} - -#[cfg(windows)] -impl Drop for UpdateOnWindowsContext { - fn drop(&mut self) { - if let Some(exe) = &self.initial_exe - && !exe.exists() - { - // The update failed, roll back the current exe to its initial name. - std::fs::rename(exe.with_file_name(Self::OLD_FILE_NAME), exe).unwrap_or_else(|e| { - eprintln!("Failed to roll back the update, you may need to reinstall Atuin: {e}"); - }); - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_registry_creation() { - let registry = OfficialPluginRegistry::new(); - assert!(registry.is_official_plugin("update")); - assert!(!registry.is_official_plugin("nonexistent")); - } - - #[test] - fn test_get_plugin() { - let registry = OfficialPluginRegistry::new(); - let plugin = registry.get_plugin("update"); - assert!(plugin.is_some()); - assert_eq!(plugin.unwrap().name, "update"); - } - - #[test] - fn test_get_install_message() { - let registry = OfficialPluginRegistry::new(); - let message = registry.get_install_message("update"); - assert!(message.is_some()); - assert!(message.unwrap().contains("atuin-update")); - } -} diff --git a/crates/atuin-client/src/record/encryption.rs b/crates/atuin-client/src/record/encryption.rs deleted file mode 100644 index 1e94d967..00000000 --- a/crates/atuin-client/src/record/encryption.rs +++ /dev/null @@ -1,373 +0,0 @@ -use atuin_common::record::{ - AdditionalData, DecryptedData, EncryptedData, Encryption, HostId, RecordId, RecordIdx, -}; -use base64::{Engine, engine::general_purpose}; -use eyre::{Context, Result, ensure}; -use rusty_paserk::{Key, KeyId, Local, PieWrappedKey}; -use rusty_paseto::core::{ - ImplicitAssertion, Key as DataKey, Local as LocalPurpose, Paseto, PasetoNonce, Payload, V4, -}; -use serde::{Deserialize, Serialize}; - -/// Use PASETO V4 Local encryption using the additional data as an implicit assertion. -#[expect(non_camel_case_types)] -pub struct PASETO_V4; - -/* -Why do we use a random content-encryption key? -Originally I was planning on using a derived key for encryption based on additional data. -This would be a lot more secure than using the master key directly. - -However, there's an established norm of using a random key. This scheme might be otherwise known as -- client-side encryption -- envelope encryption -- key wrapping - -A HSM (Hardware Security Module) provider, eg: AWS, Azure, GCP, or even a physical device like a YubiKey -will have some keys that they keep to themselves. These keys never leave their physical hardware. -If they never leave the hardware, then encrypting large amounts of data means giving them the data and waiting. -This is not a practical solution. Instead, generate a unique key for your data, encrypt that using your HSM -and then store that with your data. - -See - - - - - - - - - - - -Why would we care? In the past we have received some requests for company solutions. If in future we can configure a -KMS service with little effort, then that would solve a lot of issues for their security team. - -Even for personal use, if a user is not comfortable with sharing keys between hosts, -GCP HSM costs $1/month and $0.03 per 10,000 key operations. Assuming an active user runs -1000 atuin records a day, that would only cost them $1 and 10 cent a month. - -Additionally, key rotations are much simpler using this scheme. Rotating a key is as simple as re-encrypting the CEK, and not the message contents. -This makes it very fast to rotate a key in bulk. - -For future reference, with asymmetric encryption, you can encrypt the CEK without the HSM's involvement, but decrypting -will need the HSM. This allows the encryption path to still be extremely fast (no network calls) but downloads/decryption -that happens in the background can make the network calls to the HSM -*/ - -impl Encryption for PASETO_V4 { - fn re_encrypt( - mut data: EncryptedData, - _ad: AdditionalData, - old_key: &[u8; 32], - new_key: &[u8; 32], - ) -> Result { - let cek = Self::decrypt_cek(data.content_encryption_key, old_key)?; - data.content_encryption_key = Self::encrypt_cek(cek, new_key); - Ok(data) - } - - fn encrypt(data: DecryptedData, ad: AdditionalData, key: &[u8; 32]) -> EncryptedData { - // generate a random key for this entry - // aka content-encryption-key (CEK) - let random_key = Key::::new_os_random(); - - // encode the implicit assertions - let assertions = Assertions::from(ad).encode(); - - // build the payload and encrypt the token - let payload = serde_json::to_string(&AtuinPayload { - data: general_purpose::URL_SAFE_NO_PAD.encode(data.0), - }) - .expect("json encoding can't fail"); - let nonce = DataKey::<32>::try_new_random().expect("could not source from random"); - let nonce = PasetoNonce::::from(&nonce); - - let token = Paseto::::builder() - .set_payload(Payload::from(payload.as_str())) - .set_implicit_assertion(ImplicitAssertion::from(assertions.as_str())) - .try_encrypt(&random_key.into(), &nonce) - .expect("error encrypting atuin data"); - - EncryptedData { - data: token, - content_encryption_key: Self::encrypt_cek(random_key, key), - } - } - - fn decrypt(data: EncryptedData, ad: AdditionalData, key: &[u8; 32]) -> Result { - let token = data.data; - let cek = Self::decrypt_cek(data.content_encryption_key, key)?; - - // encode the implicit assertions - let assertions = Assertions::from(ad).encode(); - - // decrypt the payload with the footer and implicit assertions - let payload = Paseto::::try_decrypt( - &token, - &cek.into(), - None, - ImplicitAssertion::from(&*assertions), - ) - .context("could not decrypt entry")?; - - let payload: AtuinPayload = serde_json::from_str(&payload)?; - let data = general_purpose::URL_SAFE_NO_PAD.decode(payload.data)?; - Ok(DecryptedData(data)) - } -} - -impl PASETO_V4 { - fn decrypt_cek(wrapped_cek: String, key: &[u8; 32]) -> Result> { - let wrapping_key = Key::::from_bytes(*key); - - // let wrapping_key = PasetoSymmetricKey::from(Key::from(key)); - - let AtuinFooter { kid, wpk } = serde_json::from_str(&wrapped_cek) - .context("wrapped cek did not contain the correct contents")?; - - // check that the wrapping key matches the required key to decrypt. - // In future, we could support multiple keys and use this key to - // look up the key rather than only allow one key. - // For now though we will only support the one key and key rotation will - // have to be a hard reset - let current_kid = wrapping_key.to_id(); - - ensure!( - current_kid == kid, - "attempting to decrypt with incorrect key. currently using {current_kid}, expecting {kid}" - ); - - // decrypt the random key - Ok(wpk.unwrap_key(&wrapping_key)?) - } - - fn encrypt_cek(cek: Key, key: &[u8; 32]) -> String { - // aka key-encryption-key (KEK) - let wrapping_key = Key::::from_bytes(*key); - - // wrap the random key so we can decrypt it later - let wrapped_cek = AtuinFooter { - wpk: cek.wrap_pie(&wrapping_key), - kid: wrapping_key.to_id(), - }; - serde_json::to_string(&wrapped_cek).expect("could not serialize wrapped cek") - } -} - -#[derive(Serialize, Deserialize)] -struct AtuinPayload { - data: String, -} - -#[derive(Serialize, Deserialize)] -/// Well-known footer claims for decrypting. This is not encrypted but is stored in the record. -/// -struct AtuinFooter { - /// Wrapped key - wpk: PieWrappedKey, - /// ID of the key which was used to wrap - kid: KeyId, -} - -/// Used in the implicit assertions. This is not encrypted and not stored in the data blob. -// This cannot be changed, otherwise it breaks the authenticated encryption. -#[derive(Debug, Copy, Clone, Serialize)] -struct Assertions<'a> { - id: &'a RecordId, - idx: &'a RecordIdx, - version: &'a str, - tag: &'a str, - host: &'a HostId, -} - -impl<'a> From> for Assertions<'a> { - fn from(ad: AdditionalData<'a>) -> Self { - Self { - id: ad.id, - version: ad.version, - tag: ad.tag, - host: ad.host, - idx: ad.idx, - } - } -} - -impl Assertions<'_> { - fn encode(&self) -> String { - serde_json::to_string(self).expect("could not serialize implicit assertions") - } -} - -#[cfg(test)] -mod tests { - use atuin_common::{ - record::{Host, Record}, - utils::uuid_v7, - }; - - use super::*; - - #[test] - fn round_trip() { - let key = Key::::new_os_random(); - - let ad = AdditionalData { - id: &RecordId(uuid_v7()), - version: "v0", - tag: "kv", - host: &HostId(uuid_v7()), - idx: &0, - }; - - let data = DecryptedData(vec![1, 2, 3, 4]); - - let encrypted = PASETO_V4::encrypt(data.clone(), ad, &key.to_bytes()); - let decrypted = PASETO_V4::decrypt(encrypted, ad, &key.to_bytes()).unwrap(); - assert_eq!(decrypted, data); - } - - #[test] - fn same_entry_different_output() { - let key = Key::::new_os_random(); - - let ad = AdditionalData { - id: &RecordId(uuid_v7()), - version: "v0", - tag: "kv", - host: &HostId(uuid_v7()), - idx: &0, - }; - - let data = DecryptedData(vec![1, 2, 3, 4]); - - let encrypted = PASETO_V4::encrypt(data.clone(), ad, &key.to_bytes()); - let encrypted2 = PASETO_V4::encrypt(data, ad, &key.to_bytes()); - - assert_ne!( - encrypted.data, encrypted2.data, - "re-encrypting the same contents should have different output due to key randomization" - ); - } - - #[test] - fn cannot_decrypt_different_key() { - let key = Key::::new_os_random(); - let fake_key = Key::::new_os_random(); - - let ad = AdditionalData { - id: &RecordId(uuid_v7()), - version: "v0", - tag: "kv", - host: &HostId(uuid_v7()), - idx: &0, - }; - - let data = DecryptedData(vec![1, 2, 3, 4]); - - let encrypted = PASETO_V4::encrypt(data, ad, &key.to_bytes()); - let _ = PASETO_V4::decrypt(encrypted, ad, &fake_key.to_bytes()).unwrap_err(); - } - - #[test] - fn cannot_decrypt_different_id() { - let key = Key::::new_os_random(); - - let ad = AdditionalData { - id: &RecordId(uuid_v7()), - version: "v0", - tag: "kv", - host: &HostId(uuid_v7()), - idx: &0, - }; - - let data = DecryptedData(vec![1, 2, 3, 4]); - - let encrypted = PASETO_V4::encrypt(data, ad, &key.to_bytes()); - - let ad = AdditionalData { - id: &RecordId(uuid_v7()), - ..ad - }; - let _ = PASETO_V4::decrypt(encrypted, ad, &key.to_bytes()).unwrap_err(); - } - - #[test] - fn re_encrypt_round_trip() { - let key1 = Key::::new_os_random(); - let key2 = Key::::new_os_random(); - - let ad = AdditionalData { - id: &RecordId(uuid_v7()), - version: "v0", - tag: "kv", - host: &HostId(uuid_v7()), - idx: &0, - }; - - let data = DecryptedData(vec![1, 2, 3, 4]); - - let encrypted1 = PASETO_V4::encrypt(data.clone(), ad, &key1.to_bytes()); - let encrypted2 = - PASETO_V4::re_encrypt(encrypted1.clone(), ad, &key1.to_bytes(), &key2.to_bytes()) - .unwrap(); - - // we only re-encrypt the content keys - assert_eq!(encrypted1.data, encrypted2.data); - assert_ne!( - encrypted1.content_encryption_key, - encrypted2.content_encryption_key - ); - - let decrypted = PASETO_V4::decrypt(encrypted2, ad, &key2.to_bytes()).unwrap(); - - assert_eq!(decrypted, data); - } - - #[test] - fn full_record_round_trip() { - let key = [0x55; 32]; - let record = Record::builder() - .id(RecordId(uuid_v7())) - .version("v0".to_owned()) - .tag("kv".to_owned()) - .host(Host::new(HostId(uuid_v7()))) - .timestamp(1687244806000000) - .data(DecryptedData(vec![1, 2, 3, 4])) - .idx(0) - .build(); - - let encrypted = record.encrypt::(&key); - - assert!(!encrypted.data.data.is_empty()); - assert!(!encrypted.data.content_encryption_key.is_empty()); - - let decrypted = encrypted.decrypt::(&key).unwrap(); - - assert_eq!(decrypted.data.0, [1, 2, 3, 4]); - } - - #[test] - fn full_record_round_trip_fail() { - let key = [0x55; 32]; - let record = Record::builder() - .id(RecordId(uuid_v7())) - .version("v0".to_owned()) - .tag("kv".to_owned()) - .host(Host::new(HostId(uuid_v7()))) - .timestamp(1687244806000000) - .data(DecryptedData(vec![1, 2, 3, 4])) - .idx(0) - .build(); - - let encrypted = record.encrypt::(&key); - - let mut enc1 = encrypted.clone(); - enc1.host = Host::new(HostId(uuid_v7())); - let _ = enc1 - .decrypt::(&key) - .expect_err("tampering with the host should result in auth failure"); - - let mut enc2 = encrypted; - enc2.id = RecordId(uuid_v7()); - let _ = enc2 - .decrypt::(&key) - .expect_err("tampering with the id should result in auth failure"); - } -} diff --git a/crates/atuin-client/src/record/mod.rs b/crates/atuin-client/src/record/mod.rs deleted file mode 100644 index c40fd395..00000000 --- a/crates/atuin-client/src/record/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -pub mod encryption; -pub mod sqlite_store; -pub mod store; - -#[cfg(feature = "sync")] -pub mod sync; diff --git a/crates/atuin-client/src/record/sqlite_store.rs b/crates/atuin-client/src/record/sqlite_store.rs deleted file mode 100644 index ed51f3fd..00000000 --- a/crates/atuin-client/src/record/sqlite_store.rs +++ /dev/null @@ -1,642 +0,0 @@ -// Here we are using sqlite as a pretty dumb store, and will not be running any complex queries. -// Multiple stores of multiple types are all stored in one chonky table (for now), and we just index -// by tag/host - -use std::str::FromStr; -use std::{path::Path, time::Duration}; - -use async_trait::async_trait; -use eyre::{Result, eyre}; -use fs_err as fs; - -use sqlx::{ - Row, - sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow}, -}; - -use atuin_common::record::{ - EncryptedData, Host, HostId, Record, RecordId, RecordIdx, RecordStatus, -}; -use atuin_common::utils; -use uuid::Uuid; - -use super::encryption::PASETO_V4; -use super::store::Store; - -#[derive(Debug, Clone)] -pub struct SqliteStore { - pool: SqlitePool, -} - -impl SqliteStore { - pub async fn new(path: impl AsRef, timeout: f64) -> Result { - let path = path.as_ref(); - - debug!("opening sqlite database at {path:?}"); - - if utils::broken_symlink(path) { - eprintln!( - "Atuin: Sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement." - ); - std::process::exit(1); - } - - if !path.exists() - && let Some(dir) = path.parent() - { - fs::create_dir_all(dir)?; - } - - let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())? - .journal_mode(SqliteJournalMode::Wal) - .foreign_keys(true) - .create_if_missing(true); - - let pool = SqlitePoolOptions::new() - .acquire_timeout(Duration::from_secs_f64(timeout)) - .connect_with(opts) - .await?; - - Self::setup_db(&pool).await?; - - Ok(Self { pool }) - } - - async fn setup_db(pool: &SqlitePool) -> Result<()> { - debug!("running sqlite database setup"); - - sqlx::migrate!("./record-migrations").run(pool).await?; - - Ok(()) - } - - async fn save_raw( - tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, - r: &Record, - ) -> Result<()> { - // In sqlite, we are "limited" to i64. But that is still fine, until 2262. - sqlx::query( - "insert or ignore into store(id, idx, host, tag, timestamp, version, data, cek) - values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", - ) - .bind(r.id.0.as_hyphenated().to_string()) - .bind(r.idx as i64) - .bind(r.host.id.0.as_hyphenated().to_string()) - .bind(r.tag.as_str()) - .bind(r.timestamp as i64) - .bind(r.version.as_str()) - .bind(r.data.data.as_str()) - .bind(r.data.content_encryption_key.as_str()) - .execute(&mut **tx) - .await?; - - Ok(()) - } - - fn query_row(row: SqliteRow) -> Record { - let idx: i64 = row.get("idx"); - let timestamp: i64 = row.get("timestamp"); - - // tbh at this point things are pretty fucked so just panic - let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB"); - let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB"); - - Record { - id: RecordId(id), - idx: idx as u64, - host: Host::new(HostId(host)), - timestamp: timestamp as u64, - tag: row.get("tag"), - version: row.get("version"), - data: EncryptedData { - data: row.get("data"), - content_encryption_key: row.get("cek"), - }, - } - } - - async fn load_all(&self) -> Result>> { - let res = sqlx::query("select * from store ") - .map(Self::query_row) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } -} - -#[async_trait] -impl Store for SqliteStore { - async fn push_batch( - &self, - records: impl Iterator> + Send + Sync, - ) -> Result<()> { - let mut tx = self.pool.begin().await?; - - for record in records { - Self::save_raw(&mut tx, record).await?; - } - - tx.commit().await?; - - Ok(()) - } - - async fn get(&self, id: RecordId) -> Result> { - let res = sqlx::query("select * from store where store.id = ?1") - .bind(id.0.as_hyphenated().to_string()) - .map(Self::query_row) - .fetch_one(&self.pool) - .await?; - - Ok(res) - } - - async fn delete(&self, id: RecordId) -> Result<()> { - sqlx::query("delete from store where id = ?1") - .bind(id.0.as_hyphenated().to_string()) - .execute(&self.pool) - .await?; - - Ok(()) - } - - async fn delete_all(&self) -> Result<()> { - sqlx::query("delete from store").execute(&self.pool).await?; - - Ok(()) - } - - async fn last(&self, host: HostId, tag: &str) -> Result>> { - let res = - sqlx::query("select * from store where host=?1 and tag=?2 order by idx desc limit 1") - .bind(host.0.as_hyphenated().to_string()) - .bind(tag) - .map(Self::query_row) - .fetch_one(&self.pool) - .await; - - match res { - Err(sqlx::Error::RowNotFound) => Ok(None), - Err(e) => Err(eyre!("an error occurred: {}", e)), - Ok(record) => Ok(Some(record)), - } - } - - async fn first(&self, host: HostId, tag: &str) -> Result>> { - self.idx(host, tag, 0).await - } - - async fn len_all(&self) -> Result { - let res: Result<(i64,), sqlx::Error> = sqlx::query_as("select count(*) from store") - .fetch_one(&self.pool) - .await; - match res { - Err(e) => Err(eyre!("failed to fetch local store len: {}", e)), - Ok(v) => Ok(v.0 as u64), - } - } - - async fn len_tag(&self, tag: &str) -> Result { - let res: Result<(i64,), sqlx::Error> = - sqlx::query_as("select count(*) from store where tag=?1") - .bind(tag) - .fetch_one(&self.pool) - .await; - match res { - Err(e) => Err(eyre!("failed to fetch local store len: {}", e)), - Ok(v) => Ok(v.0 as u64), - } - } - - async fn len(&self, host: HostId, tag: &str) -> Result { - let last = self.last(host, tag).await?; - - if let Some(last) = last { - return Ok(last.idx + 1); - } - - return Ok(0); - } - - async fn next( - &self, - host: HostId, - tag: &str, - idx: RecordIdx, - limit: u64, - ) -> Result>> { - let res = sqlx::query( - "select * from store where idx >= ?1 and host = ?2 and tag = ?3 order by idx asc limit ?4", - ) - .bind(idx as i64) - .bind(host.0.as_hyphenated().to_string()) - .bind(tag) - .bind(limit as i64) - .map(Self::query_row) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } - - async fn idx( - &self, - host: HostId, - tag: &str, - idx: RecordIdx, - ) -> Result>> { - let res = sqlx::query("select * from store where idx = ?1 and host = ?2 and tag = ?3") - .bind(idx as i64) - .bind(host.0.as_hyphenated().to_string()) - .bind(tag) - .map(Self::query_row) - .fetch_one(&self.pool) - .await; - - match res { - Err(sqlx::Error::RowNotFound) => Ok(None), - Err(e) => Err(eyre!("an error occurred: {}", e)), - Ok(v) => Ok(Some(v)), - } - } - - async fn status(&self) -> Result { - let mut status = RecordStatus::new(); - - let res: Result, sqlx::Error> = - sqlx::query_as("select host, tag, max(idx) from store group by host, tag") - .fetch_all(&self.pool) - .await; - - let res = match res { - Err(e) => return Err(eyre!("failed to fetch local store status: {}", e)), - Ok(v) => v, - }; - - for i in res { - let host = HostId( - Uuid::from_str(i.0.as_str()).expect("failed to parse uuid for local store status"), - ); - - status.set_raw(host, i.1, i.2 as u64); - } - - Ok(status) - } - - async fn all_tagged(&self, tag: &str) -> Result>> { - let res = sqlx::query("select * from store where tag = ?1 order by timestamp asc") - .bind(tag) - .map(Self::query_row) - .fetch_all(&self.pool) - .await?; - - Ok(res) - } - - /// Reencrypt every single item in this store with a new key - /// Be careful - this may mess with sync. - async fn re_encrypt(&self, old_key: &[u8; 32], new_key: &[u8; 32]) -> Result<()> { - // Load all the records - // In memory like some of the other code here - // This will never be called in a hot loop, and only under the following circumstances - // 1. The user has logged into a new account, with a new key. They are unlikely to have a - // lot of data - // 2. The user has encountered some sort of issue, and runs a maintenance command that - // invokes this - let all = self.load_all().await?; - - let re_encrypted = all - .into_iter() - .map(|record| record.re_encrypt::(old_key, new_key)) - .collect::>>()?; - - // next up, we delete all the old data and reinsert the new stuff - // do it in one transaction, so if anything fails we rollback OK - - let mut tx = self.pool.begin().await?; - - let res = sqlx::query("delete from store").execute(&mut *tx).await?; - - let rows = res.rows_affected(); - debug!("deleted {rows} rows"); - - // don't call push_batch, as it will start its own transaction - // call the underlying save_raw - - for record in re_encrypted { - Self::save_raw(&mut tx, &record).await?; - } - - tx.commit().await?; - - Ok(()) - } - - /// Verify that every record in this store can be decrypted with the current key - /// Someday maybe also check each tag/record can be deserialized, but not for now. - async fn verify(&self, key: &[u8; 32]) -> Result<()> { - let all = self.load_all().await?; - - all.into_iter() - .map(|record| record.decrypt::(key)) - .collect::>>()?; - - Ok(()) - } - - /// Verify that every record in this store can be decrypted with the current key - /// Someday maybe also check each tag/record can be deserialized, but not for now. - async fn purge(&self, key: &[u8; 32]) -> Result<()> { - let all = self.load_all().await?; - - for record in all.iter() { - match record.clone().decrypt::(key) { - Ok(_) => continue, - Err(_) => { - println!( - "Failed to decrypt {}, deleting", - record.id.0.as_hyphenated() - ); - - self.delete(record.id).await?; - } - } - } - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use atuin_common::{ - record::{DecryptedData, EncryptedData, Host, HostId, Record}, - utils::uuid_v7, - }; - - use crate::{ - encryption::generate_encoded_key, - record::{encryption::PASETO_V4, store::Store}, - settings::test_local_timeout, - }; - - use super::SqliteStore; - - fn test_record() -> Record { - Record::builder() - .host(Host::new(HostId(atuin_common::utils::uuid_v7()))) - .version("v1".into()) - .tag(atuin_common::utils::uuid_v7().simple().to_string()) - .data(EncryptedData { - data: "1234".into(), - content_encryption_key: "1234".into(), - }) - .idx(0) - .build() - } - - #[tokio::test] - async fn create_db() { - let db = SqliteStore::new(":memory:", test_local_timeout()).await; - - assert!( - db.is_ok(), - "db could not be created, {:?}", - db.err().unwrap() - ); - } - - #[tokio::test] - async fn push_record() { - let db = SqliteStore::new(":memory:", test_local_timeout()) - .await - .unwrap(); - let record = test_record(); - - db.push(&record).await.expect("failed to insert record"); - } - - #[tokio::test] - async fn get_record() { - let db = SqliteStore::new(":memory:", test_local_timeout()) - .await - .unwrap(); - let record = test_record(); - db.push(&record).await.unwrap(); - - let new_record = db.get(record.id).await.expect("failed to fetch record"); - - assert_eq!(record, new_record, "records are not equal"); - } - - #[tokio::test] - async fn last() { - let db = SqliteStore::new(":memory:", test_local_timeout()) - .await - .unwrap(); - let record = test_record(); - db.push(&record).await.unwrap(); - - let last = db - .last(record.host.id, record.tag.as_str()) - .await - .expect("failed to get store len"); - - assert_eq!( - last.unwrap().id, - record.id, - "expected to get back the same record that was inserted" - ); - } - - #[tokio::test] - async fn first() { - let db = SqliteStore::new(":memory:", test_local_timeout()) - .await - .unwrap(); - let record = test_record(); - db.push(&record).await.unwrap(); - - let first = db - .first(record.host.id, record.tag.as_str()) - .await - .expect("failed to get store len"); - - assert_eq!( - first.unwrap().id, - record.id, - "expected to get back the same record that was inserted" - ); - } - - #[tokio::test] - async fn len() { - let db = SqliteStore::new(":memory:", test_local_timeout()) - .await - .unwrap(); - let record = test_record(); - db.push(&record).await.unwrap(); - - let len = db - .len(record.host.id, record.tag.as_str()) - .await - .expect("failed to get store len"); - - assert_eq!(len, 1, "expected length of 1 after insert"); - } - - #[tokio::test] - async fn len_tag() { - let db = SqliteStore::new(":memory:", test_local_timeout()) - .await - .unwrap(); - let record = test_record(); - db.push(&record).await.unwrap(); - - let len = db - .len_tag(record.tag.as_str()) - .await - .expect("failed to get store len"); - - assert_eq!(len, 1, "expected length of 1 after insert"); - } - - #[tokio::test] - async fn len_different_tags() { - let db = SqliteStore::new(":memory:", test_local_timeout()) - .await - .unwrap(); - - // these have different tags, so the len should be the same - // we model multiple stores within one database - // new store = new tag = independent length - let first = test_record(); - let second = test_record(); - - db.push(&first).await.unwrap(); - db.push(&second).await.unwrap(); - - let first_len = db.len(first.host.id, first.tag.as_str()).await.unwrap(); - let second_len = db.len(second.host.id, second.tag.as_str()).await.unwrap(); - - assert_eq!(first_len, 1, "expected length of 1 after insert"); - assert_eq!(second_len, 1, "expected length of 1 after insert"); - } - - #[tokio::test] - async fn append_a_bunch() { - let db = SqliteStore::new(":memory:", test_local_timeout()) - .await - .unwrap(); - - let mut tail = test_record(); - db.push(&tail).await.expect("failed to push record"); - - for _ in 1..100 { - tail = tail.append(vec![1, 2, 3, 4]).encrypt::(&[0; 32]); - db.push(&tail).await.unwrap(); - } - - assert_eq!( - db.len(tail.host.id, tail.tag.as_str()).await.unwrap(), - 100, - "failed to insert 100 records" - ); - - assert_eq!( - db.len_tag(tail.tag.as_str()).await.unwrap(), - 100, - "failed to insert 100 records" - ); - } - - #[tokio::test] - async fn append_a_big_bunch() { - let db = SqliteStore::new(":memory:", test_local_timeout()) - .await - .unwrap(); - - let mut records: Vec> = Vec::with_capacity(10000); - - let mut tail = test_record(); - records.push(tail.clone()); - - for _ in 1..10000 { - tail = tail.append(vec![1, 2, 3]).encrypt::(&[0; 32]); - records.push(tail.clone()); - } - - db.push_batch(records.iter()).await.unwrap(); - - assert_eq!( - db.len(tail.host.id, tail.tag.as_str()).await.unwrap(), - 10000, - "failed to insert 10k records" - ); - } - - #[tokio::test] - async fn re_encrypt() { - let store = SqliteStore::new(":memory:", test_local_timeout()) - .await - .unwrap(); - let (key, _) = generate_encoded_key().unwrap(); - let data = vec![0u8, 1u8, 2u8, 3u8]; - let host_id = HostId(uuid_v7()); - - for i in 0..10 { - let record = Record::builder() - .host(Host::new(host_id)) - .version(String::from("test")) - .tag(String::from("test")) - .idx(i) - .data(DecryptedData(data.clone())) - .build(); - - let record = record.encrypt::(&key.into()); - store - .push(&record) - .await - .expect("failed to push encrypted record"); - } - - // first, check that we can decrypt the data with the current key - let all = store.all_tagged("test").await.unwrap(); - - assert_eq!(all.len(), 10, "failed to fetch all records"); - - for record in all { - let decrypted = record.decrypt::(&key.into()).unwrap(); - assert_eq!(decrypted.data.0, data); - } - - // reencrypt the store, then check if - // 1) it cannot be decrypted with the old key - // 2) it can be decrypted with the new key - - let (new_key, _) = generate_encoded_key().unwrap(); - store - .re_encrypt(&key.into(), &new_key.into()) - .await - .expect("failed to re-encrypt store"); - - let all = store.all_tagged("test").await.unwrap(); - - for record in all.iter() { - let decrypted = record.clone().decrypt::(&key.into()); - assert!( - decrypted.is_err(), - "did not get error decrypting with old key after re-encrypt" - ) - } - - for record in all { - let decrypted = record.decrypt::(&new_key.into()).unwrap(); - assert_eq!(decrypted.data.0, data); - } - - assert_eq!(store.len(host_id, "test").await.unwrap(), 10); - } -} diff --git a/crates/atuin-client/src/record/store.rs b/crates/atuin-client/src/record/store.rs deleted file mode 100644 index 49ca4968..00000000 --- a/crates/atuin-client/src/record/store.rs +++ /dev/null @@ -1,60 +0,0 @@ -use async_trait::async_trait; -use eyre::Result; - -use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIdx, RecordStatus}; - -/// A record store stores records -/// In more detail - we tend to need to process this into _another_ format to actually query it. -/// As is, the record store is intended as the source of truth for arbitrary data, which could -/// be shell history, kvs, etc. -#[async_trait] -pub trait Store { - // Push a record - async fn push(&self, record: &Record) -> Result<()> { - self.push_batch(std::iter::once(record)).await - } - - // Push a batch of records, all in one transaction - async fn push_batch( - &self, - records: impl Iterator> + Send + Sync, - ) -> Result<()>; - - async fn get(&self, id: RecordId) -> Result>; - - async fn delete(&self, id: RecordId) -> Result<()>; - async fn delete_all(&self) -> Result<()>; - - async fn len_all(&self) -> Result; - async fn len(&self, host: HostId, tag: &str) -> Result; - async fn len_tag(&self, tag: &str) -> Result; - - async fn last(&self, host: HostId, tag: &str) -> Result>>; - async fn first(&self, host: HostId, tag: &str) -> Result>>; - - async fn re_encrypt(&self, old_key: &[u8; 32], new_key: &[u8; 32]) -> Result<()>; - async fn verify(&self, key: &[u8; 32]) -> Result<()>; - async fn purge(&self, key: &[u8; 32]) -> Result<()>; - - /// Get the next `limit` records, after and including the given index - async fn next( - &self, - host: HostId, - tag: &str, - idx: RecordIdx, - limit: u64, - ) -> Result>>; - - /// Get the first record for a given host and tag - async fn idx( - &self, - host: HostId, - tag: &str, - idx: RecordIdx, - ) -> Result>>; - - async fn status(&self) -> Result; - - /// Get all records for a given tag - async fn all_tagged(&self, tag: &str) -> Result>>; -} diff --git a/crates/atuin-client/src/record/sync.rs b/crates/atuin-client/src/record/sync.rs deleted file mode 100644 index b785b5dc..00000000 --- a/crates/atuin-client/src/record/sync.rs +++ /dev/null @@ -1,663 +0,0 @@ -// do a sync :O -use std::{cmp::Ordering, fmt::Write}; - -use eyre::Result; -use thiserror::Error; - -use super::{encryption::PASETO_V4, store::Store}; -use crate::{api_client::Client, settings::Settings}; - -use atuin_common::record::{Diff, HostId, RecordId, RecordIdx, RecordStatus}; -use indicatif::{ProgressBar, ProgressState, ProgressStyle}; - -#[derive(Error, Debug)] -pub enum SyncError { - #[error("the local store is ahead of the remote, but for another host. has remote lost data?")] - LocalAheadOtherHost, - - #[error("an issue with the local database occurred: {msg:?}")] - LocalStoreError { msg: String }, - - #[error("something has gone wrong with the sync logic: {msg:?}")] - SyncLogicError { msg: String }, - - #[error("operational error: {msg:?}")] - OperationalError { msg: String }, - - #[error("a request to the sync server failed: {msg:?}")] - RemoteRequestError { msg: String }, - - #[error( - "the encryption key on this machine does not match the data on the server. \ - this usually means a new machine was set up without copying the existing key. \ - to fix: run `atuin key` on a machine that already syncs correctly, then run \ - `atuin store rekey ` on this machine with the value from the other machine" - )] - WrongKey, -} - -#[derive(Debug, Eq, PartialEq)] -pub enum Operation { - // Either upload or download until the states matches the below - Upload { - local: RecordIdx, - remote: Option, - host: HostId, - tag: String, - }, - Download { - local: Option, - remote: RecordIdx, - host: HostId, - tag: String, - }, - Noop { - host: HostId, - tag: String, - }, -} - -pub async fn build_client(settings: &Settings) -> Result, SyncError> { - Client::new( - &settings.sync_address, - settings - .sync_auth_token() - .await - .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?, - settings.network_connect_timeout, - settings.network_timeout, - ) - .map_err(|e| SyncError::OperationalError { msg: e.to_string() }) -} - -pub async fn diff( - client: &Client<'_>, - store: &impl Store, -) -> Result<(Vec, RecordStatus), SyncError> { - let local_index = store - .status() - .await - .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?; - - let remote_index = client - .record_status() - .await - .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?; - - let diff = local_index.diff(&remote_index); - - Ok((diff, remote_index)) -} - -// Take a diff, along with a local store, and resolve it into a set of operations. -// With the store as context, we can determine if a tail exists locally or not and therefore if it needs uploading or download. -// In theory this could be done as a part of the diffing stage, but it's easier to reason -// about and test this way -pub async fn operations( - diffs: Vec, - _store: &impl Store, -) -> Result, SyncError> { - let mut operations = Vec::with_capacity(diffs.len()); - - for diff in diffs { - let op = match (diff.local, diff.remote) { - // We both have it! Could be either. Compare. - (Some(local), Some(remote)) => match local.cmp(&remote) { - Ordering::Equal => Operation::Noop { - host: diff.host, - tag: diff.tag, - }, - Ordering::Greater => Operation::Upload { - local, - remote: Some(remote), - host: diff.host, - tag: diff.tag, - }, - Ordering::Less => Operation::Download { - local: Some(local), - remote, - host: diff.host, - tag: diff.tag, - }, - }, - - // Remote has it, we don't. Gotta be download - (None, Some(remote)) => Operation::Download { - local: None, - remote, - host: diff.host, - tag: diff.tag, - }, - - // We have it, remote doesn't. Gotta be upload. - (Some(local), None) => Operation::Upload { - local, - remote: None, - host: diff.host, - tag: diff.tag, - }, - - // something is pretty fucked. - (None, None) => { - return Err(SyncError::SyncLogicError { - msg: String::from( - "diff has nothing for local or remote - (host, tag) does not exist", - ), - }); - } - }; - - operations.push(op); - } - - // sort them - purely so we have a stable testing order, and can rely on - // same input = same output - // We can sort by ID so long as we continue to use UUIDv7 or something - // with the same properties - - operations.sort_by_key(|op| match op { - Operation::Noop { host, tag } => (0, *host, tag.clone()), - - Operation::Upload { host, tag, .. } => (1, *host, tag.clone()), - - Operation::Download { host, tag, .. } => (2, *host, tag.clone()), - }); - - Ok(operations) -} - -async fn sync_upload( - store: &impl Store, - client: &Client<'_>, - host: HostId, - tag: String, - local: RecordIdx, - remote: Option, - page_size: u64, -) -> Result { - let remote = remote.unwrap_or(0); - let expected = local - remote; - let mut progress = 0; - - let pb = ProgressBar::new(expected); - pb.set_style(ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {human_pos}/{human_len} ({eta})") - .unwrap() - .with_key("eta", |state: &ProgressState, w: &mut dyn Write| write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()) - .progress_chars("#>-")); - - println!( - "Uploading {} records to {}/{}", - expected, - host.0.as_simple(), - tag - ); - - loop { - let page = store - .next(host, tag.as_str(), remote + progress, page_size) - .await - .map_err(|e| { - error!("failed to read upload page: {e:?}"); - - SyncError::LocalStoreError { msg: e.to_string() } - })?; - - if page.is_empty() { - break; - } - - client.post_records(&page).await.map_err(|e| { - error!("failed to post records: {e:?}"); - - SyncError::RemoteRequestError { msg: e.to_string() } - })?; - - progress += page.len() as u64; - pb.set_position(progress); - - if progress >= expected { - break; - } - } - - pb.finish_with_message("Uploaded records"); - - Ok(progress as i64) -} - -async fn sync_download( - store: &impl Store, - client: &Client<'_>, - host: HostId, - tag: String, - local: Option, - remote: RecordIdx, - page_size: u64, -) -> Result, SyncError> { - let local = local.unwrap_or(0); - let expected = remote - local; - let mut progress = 0; - let mut ret = Vec::new(); - - println!( - "Downloading {} records from {}/{}", - expected, - host.0.as_simple(), - tag - ); - - let pb = ProgressBar::new(expected); - pb.set_style(ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {human_pos}/{human_len} ({eta})") - .unwrap() - .with_key("eta", |state: &ProgressState, w: &mut dyn Write| write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()) - .progress_chars("#>-")); - - loop { - let page = client - .next_records(host, tag.clone(), local + progress, page_size) - .await - .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?; - - if page.is_empty() { - break; - } - - store - .push_batch(page.iter()) - .await - .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?; - - ret.extend(page.iter().map(|f| f.id)); - - progress += page.len() as u64; - pb.set_position(progress); - - if progress >= expected { - break; - } - } - - pb.finish_with_message("Downloaded records"); - - Ok(ret) -} - -pub async fn sync_remote( - client: &Client<'_>, - operations: Vec, - local_store: &impl Store, - page_size: u64, -) -> Result<(i64, Vec), SyncError> { - let mut uploaded = 0; - let mut downloaded = Vec::new(); - - // this can totally run in parallel, but lets get it working first - for i in operations { - match i { - Operation::Upload { - host, - tag, - local, - remote, - } => { - uploaded += - sync_upload(local_store, client, host, tag, local, remote, page_size).await? - } - - Operation::Download { - host, - tag, - local, - remote, - } => { - let mut d = - sync_download(local_store, client, host, tag, local, remote, page_size).await?; - downloaded.append(&mut d) - } - - Operation::Noop { .. } => continue, - } - } - - Ok((uploaded, downloaded)) -} - -pub async fn check_encryption_key( - client: &Client<'_>, - remote_index: &RecordStatus, - encryption_key: &[u8; 32], -) -> Result<(), SyncError> { - let sample = remote_index - .hosts - .iter() - .flat_map(|(host, tags)| tags.keys().map(move |tag| (*host, tag.clone()))) - .next(); - - let Some((host, tag)) = sample else { - return Ok(()); - }; - - let records = client - .next_records(host, tag, 0, 1) - .await - .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?; - - let Some(record) = records.into_iter().next() else { - return Ok(()); - }; - - record - .decrypt::(encryption_key) - .map_err(|_| SyncError::WrongKey)?; - - Ok(()) -} - -pub async fn sync( - settings: &Settings, - store: &impl Store, - encryption_key: &[u8; 32], -) -> Result<(i64, Vec), SyncError> { - let client = build_client(settings).await?; - let (diff, remote_index) = diff(&client, store).await?; - - // Bail before mutating either side if the local key can't read the remote. - check_encryption_key(&client, &remote_index, encryption_key).await?; - - let operations = operations(diff, store).await?; - let (uploaded, downloaded) = sync_remote(&client, operations, store, 100).await?; - - Ok((uploaded, downloaded)) -} - -#[cfg(test)] -mod tests { - use atuin_common::record::{Diff, EncryptedData, HostId, Record}; - use pretty_assertions::assert_eq; - - use crate::{ - record::{ - encryption::PASETO_V4, - sqlite_store::SqliteStore, - store::Store, - sync::{self, Operation}, - }, - settings::test_local_timeout, - }; - - fn test_record() -> Record { - Record::builder() - .host(atuin_common::record::Host::new(HostId( - atuin_common::utils::uuid_v7(), - ))) - .version("v1".into()) - .tag(atuin_common::utils::uuid_v7().simple().to_string()) - .data(EncryptedData { - data: String::new(), - content_encryption_key: String::new(), - }) - .idx(0) - .build() - } - - // Take a list of local records, and a list of remote records. - // Return the local database, and a diff of local/remote, ready to build - // ops - async fn build_test_diff( - local_records: Vec>, - remote_records: Vec>, - ) -> (SqliteStore, Vec) { - let local_store = SqliteStore::new(":memory:", test_local_timeout()) - .await - .expect("failed to open in memory sqlite"); - let remote_store = SqliteStore::new(":memory:", test_local_timeout()) - .await - .expect("failed to open in memory sqlite"); // "remote" - - for i in local_records { - local_store.push(&i).await.unwrap(); - } - - for i in remote_records { - remote_store.push(&i).await.unwrap(); - } - - let local_index = local_store.status().await.unwrap(); - let remote_index = remote_store.status().await.unwrap(); - - let diff = local_index.diff(&remote_index); - - (local_store, diff) - } - - #[tokio::test] - async fn test_basic_diff() { - // a diff where local is ahead of remote. nothing else. - - let record = test_record(); - let (store, diff) = build_test_diff(vec![record.clone()], vec![]).await; - - assert_eq!(diff.len(), 1); - - let operations = sync::operations(diff, &store).await.unwrap(); - - assert_eq!(operations.len(), 1); - - assert_eq!( - operations[0], - Operation::Upload { - host: record.host.id, - tag: record.tag, - local: record.idx, - remote: None, - } - ); - } - - #[tokio::test] - async fn build_two_way_diff() { - // a diff where local is ahead of remote for one, and remote for - // another. One upload, one download - - let shared_record = test_record(); - let remote_ahead = test_record(); - - let local_ahead = shared_record - .append(vec![1, 2, 3]) - .encrypt::(&[0; 32]); - - assert_eq!(local_ahead.idx, 1); - - let local = vec![shared_record.clone(), local_ahead.clone()]; // local knows about the already synced, and something newer in the same store - let remote = vec![shared_record.clone(), remote_ahead.clone()]; // remote knows about the already-synced, and one new record in a new store - - let (store, diff) = build_test_diff(local, remote).await; - let operations = sync::operations(diff, &store).await.unwrap(); - - assert_eq!(operations.len(), 2); - - assert_eq!( - operations, - vec![ - // Or in otherwords, local is ahead by one - Operation::Upload { - host: local_ahead.host.id, - tag: local_ahead.tag, - local: 1, - remote: Some(0), - }, - // Or in other words, remote knows of a record in an entirely new store (tag) - Operation::Download { - host: remote_ahead.host.id, - tag: remote_ahead.tag, - local: None, - remote: 0, - }, - ] - ); - } - - #[tokio::test] - async fn build_complex_diff() { - // One shared, ahead but known only by remote - // One known only by local - // One known only by remote - - let shared_record = test_record(); - let local_only = test_record(); - - let local_only_20 = test_record(); - let local_only_21 = local_only_20 - .append(vec![1, 2, 3]) - .encrypt::(&[0; 32]); - let local_only_22 = local_only_21 - .append(vec![1, 2, 3]) - .encrypt::(&[0; 32]); - let local_only_23 = local_only_22 - .append(vec![1, 2, 3]) - .encrypt::(&[0; 32]); - - let remote_only = test_record(); - - let remote_only_20 = test_record(); - let remote_only_21 = remote_only_20 - .append(vec![2, 3, 2]) - .encrypt::(&[0; 32]); - let remote_only_22 = remote_only_21 - .append(vec![2, 3, 2]) - .encrypt::(&[0; 32]); - let remote_only_23 = remote_only_22 - .append(vec![2, 3, 2]) - .encrypt::(&[0; 32]); - let remote_only_24 = remote_only_23 - .append(vec![2, 3, 2]) - .encrypt::(&[0; 32]); - - let second_shared = test_record(); - let second_shared_remote_ahead = second_shared - .append(vec![1, 2, 3]) - .encrypt::(&[0; 32]); - let second_shared_remote_ahead2 = second_shared_remote_ahead - .append(vec![1, 2, 3]) - .encrypt::(&[0; 32]); - - let third_shared = test_record(); - let third_shared_local_ahead = third_shared - .append(vec![1, 2, 3]) - .encrypt::(&[0; 32]); - let third_shared_local_ahead2 = third_shared_local_ahead - .append(vec![1, 2, 3]) - .encrypt::(&[0; 32]); - - let fourth_shared = test_record(); - let fourth_shared_remote_ahead = fourth_shared - .append(vec![1, 2, 3]) - .encrypt::(&[0; 32]); - let fourth_shared_remote_ahead2 = fourth_shared_remote_ahead - .append(vec![1, 2, 3]) - .encrypt::(&[0; 32]); - - let local = vec![ - shared_record.clone(), - second_shared.clone(), - third_shared.clone(), - fourth_shared.clone(), - fourth_shared_remote_ahead.clone(), - // single store, only local has it - local_only.clone(), - // bigger store, also only known by local - local_only_20.clone(), - local_only_21.clone(), - local_only_22.clone(), - local_only_23.clone(), - // another shared store, but local is ahead on this one - third_shared_local_ahead.clone(), - third_shared_local_ahead2.clone(), - ]; - - let remote = vec![ - remote_only.clone(), - remote_only_20.clone(), - remote_only_21.clone(), - remote_only_22.clone(), - remote_only_23.clone(), - remote_only_24.clone(), - shared_record.clone(), - second_shared.clone(), - third_shared.clone(), - second_shared_remote_ahead.clone(), - second_shared_remote_ahead2.clone(), - fourth_shared.clone(), - fourth_shared_remote_ahead.clone(), - fourth_shared_remote_ahead2.clone(), - ]; // remote knows about the already-synced, and one new record in a new store - - let (store, diff) = build_test_diff(local, remote).await; - let operations = sync::operations(diff, &store).await.unwrap(); - - assert_eq!(operations.len(), 7); - - let mut result_ops = vec![ - // We started with a shared record, but the remote knows of two newer records in the - // same store - Operation::Download { - local: Some(0), - remote: 2, - host: second_shared_remote_ahead.host.id, - tag: second_shared_remote_ahead.tag, - }, - // We have a shared record, local knows of the first two but not the last - Operation::Download { - local: Some(1), - remote: 2, - host: fourth_shared_remote_ahead2.host.id, - tag: fourth_shared_remote_ahead2.tag, - }, - // Remote knows of a store with a single record that local does not have - Operation::Download { - local: None, - remote: 0, - host: remote_only.host.id, - tag: remote_only.tag, - }, - // Remote knows of a store with a bunch of records that local does not have - Operation::Download { - local: None, - remote: 4, - host: remote_only_20.host.id, - tag: remote_only_20.tag, - }, - // Local knows of a record in a store that remote does not have - Operation::Upload { - local: 0, - remote: None, - host: local_only.host.id, - tag: local_only.tag, - }, - // Local knows of 4 records in a store that remote does not have - Operation::Upload { - local: 3, - remote: None, - host: local_only_20.host.id, - tag: local_only_20.tag, - }, - // Local knows of 2 more records in a shared store that remote only has one of - Operation::Upload { - local: 2, - remote: Some(0), - host: third_shared.host.id, - tag: third_shared.tag, - }, - ]; - - result_ops.sort_by_key(|op| match op { - Operation::Noop { host, tag } => (0, *host, tag.clone()), - - Operation::Upload { host, tag, .. } => (1, *host, tag.clone()), - - Operation::Download { host, tag, .. } => (2, *host, tag.clone()), - }); - - assert_eq!(result_ops, operations); - } -} diff --git a/crates/atuin-client/src/register.rs b/crates/atuin-client/src/register.rs deleted file mode 100644 index ad077dd1..00000000 --- a/crates/atuin-client/src/register.rs +++ /dev/null @@ -1,20 +0,0 @@ -use eyre::Result; - -use crate::{api_client, settings::Settings}; - -pub async fn register_classic( - settings: &Settings, - username: String, - email: String, - password: String, -) -> Result { - let session = - api_client::register(settings.sync_address.as_str(), &username, &email, &password).await?; - - let meta = Settings::meta_store().await?; - meta.save_session(&session.session).await?; - - let _key = crate::encryption::load_key(settings)?; - - Ok(session.session) -} diff --git a/crates/atuin-client/src/secrets.rs b/crates/atuin-client/src/secrets.rs deleted file mode 100644 index e8a6ab62..00000000 --- a/crates/atuin-client/src/secrets.rs +++ /dev/null @@ -1,194 +0,0 @@ -// This file will probably trigger a lot of scanners. Sorry. - -use regex::RegexSet; -use std::sync::LazyLock; - -pub enum TestValue<'a> { - Single(&'a str), - Multiple(&'a [&'a str]), -} - -/// A list of `(name, regex, test)`, where `test` should match against `regex`. -pub static SECRET_PATTERNS: &[(&str, &str, TestValue)] = &[ - ( - "AWS Access Key ID", - "A[KS]IA[0-9A-Z]{16}", - TestValue::Single("AKIAIOSFODNN7EXAMPLE"), - ), - ( - "AWS Secret Access Key env var", - "AWS_SECRET_ACCESS_KEY", - TestValue::Single("AWS_SECRET_ACCESS_KEY=KEYDATA"), - ), - ( - "AWS Session Token env var", - "AWS_SESSION_TOKEN", - TestValue::Single("AWS_SESSION_TOKEN=KEYDATA"), - ), - ( - "Microsoft Azure secret access key env var", - "AZURE_.*_KEY", - TestValue::Single("export AZURE_STORAGE_ACCOUNT_KEY=KEYDATA"), - ), - ( - "Google cloud platform key env var", - "GOOGLE_SERVICE_ACCOUNT_KEY", - TestValue::Single("export GOOGLE_SERVICE_ACCOUNT_KEY=KEYDATA"), - ), - ( - "Atuin login", - r"atuin\s+login", - TestValue::Single( - "atuin login -u mycoolusername -p mycoolpassword -k \"lots of random words\"", - ), - ), - ( - "GitHub PAT (old)", - "ghp_[a-zA-Z0-9]{36}", - TestValue::Single("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"), // legit, I expired it - ), - ( - "GitHub PAT (new)", - "gh1_[A-Za-z0-9]{21}_[A-Za-z0-9]{59}|github_pat_[0-9][A-Za-z0-9]{21}_[A-Za-z0-9]{59}", - TestValue::Multiple(&[ - "gh1_1234567890abcdefghijk_1234567890abcdefghijklmnopqrstuvwxyz1234567890abcdefghijklm", - "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", // also legit, also expired - ]), - ), - ( - "GitHub OAuth Access Token", - "gho_[A-Za-z0-9]{36}", - TestValue::Single("gho_1234567890abcdefghijklmnopqrstuvwx000"), // not a real token - ), - ( - "GitHub OAuth Access Token (user)", - "ghu_[A-Za-z0-9]{36}", - TestValue::Single("ghu_1234567890abcdefghijklmnopqrstuvwx000"), // not a real token - ), - ( - "GitHub App Installation Access Token", - "ghs_[A-Za-z0-9._-]{36,}", - TestValue::Multiple(&[ - "ghs_1234567890abcdefghijklmnopqrstuvwx000", // not a real token - "ghs_abc-def.ghi_jklMNOP0123456789qrstuv-wxyzABCD", // new token format, fake data - ]), - ), - ( - "GitHub Refresh Token", - "ghr_[A-Za-z0-9]{76}", - TestValue::Single( - "ghr_1234567890abcdefghijklmnopqrstuvwx1234567890abcdefghijklmnopqrstuvwx1234567890abcdefghijklmnopqrstuvwx", - ), // not a real token - ), - ( - "GitHub App Installation Access Token v1", - "v1\\.[0-9A-Fa-f]{40}", - TestValue::Single("v1.1234567890abcdef1234567890abcdef12345678"), // not a real token - ), - ( - "GitLab PAT", - "glpat-[a-zA-Z0-9_]{20}", - TestValue::Single("glpat-RkE_BG5p_bbjML21WSfy"), - ), - ( - "Slack OAuth v2 bot", - "xoxb-[0-9]{11}-[0-9]{11}-[0-9a-zA-Z]{24}", - TestValue::Single("xoxb-17653672481-19874698323-pdFZKVeTuE8sk7oOcBrzbqgy"), - ), - ( - "Slack OAuth v2 user token", - "xoxp-[0-9]{11}-[0-9]{11}-[0-9a-zA-Z]{24}", - TestValue::Single("xoxp-17653672481-19874698323-pdFZKVeTuE8sk7oOcBrzbqgy"), - ), - ( - "Slack webhook", - "T[a-zA-Z0-9_]{8}/B[a-zA-Z0-9_]{8}/[a-zA-Z0-9_]{24}", - TestValue::Single( - "https://hooks.slack.com/services/T00000000/B00000000/XXXXXXXXXXXXXXXXXXXXXXXX", - ), - ), - ( - "Stripe test key", - "sk_test_[0-9a-zA-Z]{24}", - TestValue::Single("sk_test_1234567890abcdefghijklmnop"), - ), - ( - "Stripe live key", - "sk_live_[0-9a-zA-Z]{24}", - TestValue::Single("sk_live_1234567890abcdefghijklmnop"), - ), - ( - "Netlify authentication token", - "nf[pcoub]_[0-9a-zA-Z]{36}", - TestValue::Single("nfp_nBh7BdJxUwyaBBwFzpyD29MMFT6pZ9wq5634"), - ), - ( - "npm token", - "npm_[A-Za-z0-9]{36}", - TestValue::Single("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"), - ), - ( - "Pulumi personal access token", - "pul-[0-9a-f]{40}", - TestValue::Single("pul-683c2770662c51d960d72ec27613be7653c5cb26"), - ), -]; - -/// The `regex` expressions from [`SECRET_PATTERNS`] compiled into a `RegexSet`. -pub static SECRET_PATTERNS_RE: LazyLock = LazyLock::new(|| { - let exprs = SECRET_PATTERNS.iter().map(|f| f.1); - RegexSet::new(exprs).expect("Failed to build secrets regex") -}); - -#[cfg(test)] -mod tests { - use regex::Regex; - - use crate::secrets::{SECRET_PATTERNS, TestValue}; - - #[test] - fn test_secrets() { - for (name, regex, test) in SECRET_PATTERNS { - let re = - Regex::new(regex).unwrap_or_else(|_| panic!("Failed to compile regex for {name}")); - - match test { - TestValue::Single(test) => { - assert!(re.is_match(test), "{name} test failed!"); - } - TestValue::Multiple(tests) => { - for test_str in tests.iter() { - assert!( - re.is_match(test_str), - "{name} test with value \"{test_str}\" failed!" - ); - } - } - } - } - } - - #[test] - fn test_secrets_embedded() { - for (name, regex, test) in SECRET_PATTERNS { - let re = - Regex::new(regex).unwrap_or_else(|_| panic!("Failed to compile regex for {name}")); - - match test { - TestValue::Single(test) => { - let embedded = format!("some random text {test} some more random text"); - assert!(re.is_match(&embedded), "{name} embedded test failed!"); - } - TestValue::Multiple(tests) => { - for test_str in tests.iter() { - let embedded = format!("some random text {test_str} some more random text"); - assert!( - re.is_match(&embedded), - "{name} embedded test with value \"{test_str}\" failed!" - ); - } - } - } - } - } -} diff --git a/crates/atuin-client/src/settings.rs b/crates/atuin-client/src/settings.rs deleted file mode 100644 index 5fb65c17..00000000 --- a/crates/atuin-client/src/settings.rs +++ /dev/null @@ -1,1855 +0,0 @@ -use std::{collections::HashMap, fmt, io::prelude::*, path::PathBuf, str::FromStr, sync::OnceLock}; -use tokio::sync::OnceCell; - -use atuin_common::record::HostId; -use atuin_common::utils; -use clap::ValueEnum; -use config::{ - Config, ConfigBuilder, Environment, File as ConfigFile, FileFormat, builder::DefaultState, -}; -use eyre::{Context, Error, Result, bail, eyre}; -use fs_err::{File, create_dir_all}; -use humantime::parse_duration; -use regex::RegexSet; -use serde::{Deserialize, Serialize}; -use serde_with::DeserializeFromStr; -use time::{OffsetDateTime, UtcOffset, format_description::FormatItem, macros::format_description}; - -pub const HISTORY_PAGE_SIZE: i64 = 100; -static EXAMPLE_CONFIG: &str = include_str!("../config.toml"); - -static DATA_DIR: OnceLock = OnceLock::new(); -static META_CONFIG: OnceLock<(String, f64)> = OnceLock::new(); -static META_STORE: OnceCell = OnceCell::const_new(); - -pub(crate) mod meta; -pub mod watcher; - -/// Default sync address for Atuin's hosted service -pub const DEFAULT_SYNC_ADDRESS: &str = "https://api.atuin.sh"; - -#[derive(Clone, Debug, Deserialize, Copy, ValueEnum, PartialEq, Serialize)] -pub enum SearchMode { - #[serde(rename = "prefix")] - Prefix, - - #[serde(rename = "fulltext")] - #[clap(aliases = &["fulltext"])] - FullText, - - #[serde(rename = "fuzzy")] - Fuzzy, - - #[serde(rename = "skim")] - Skim, - - #[serde(rename = "daemon-fuzzy")] - #[clap(aliases = &["daemon-fuzzy"])] - DaemonFuzzy, -} - -impl SearchMode { - pub fn as_str(&self) -> &'static str { - match self { - SearchMode::Prefix => "PREFIX", - SearchMode::FullText => "FULLTXT", - SearchMode::Fuzzy => "FUZZY", - SearchMode::Skim => "SKIM", - SearchMode::DaemonFuzzy => "DAEMON", - } - } - pub fn next(&self, settings: &Settings) -> Self { - match self { - SearchMode::Prefix => SearchMode::FullText, - // if the user is using skim, we go to skim - SearchMode::FullText if settings.search_mode == SearchMode::Skim => SearchMode::Skim, - // if the user is using daemon-fuzzy, we go to daemon-fuzzy - SearchMode::FullText if settings.search_mode == SearchMode::DaemonFuzzy => { - SearchMode::DaemonFuzzy - } - // otherwise fuzzy. - SearchMode::FullText => SearchMode::Fuzzy, - SearchMode::Fuzzy | SearchMode::Skim | SearchMode::DaemonFuzzy => SearchMode::Prefix, - } - } -} - -#[derive(Clone, Debug, Deserialize, Copy, PartialEq, Eq, ValueEnum, Serialize)] -pub enum FilterMode { - #[serde(rename = "global")] - Global = 0, - - #[serde(rename = "host")] - Host = 1, - - #[serde(rename = "session")] - Session = 2, - - #[serde(rename = "directory")] - Directory = 3, - - #[serde(rename = "workspace")] - Workspace = 4, - - #[serde(rename = "session-preload")] - SessionPreload = 5, -} - -impl FilterMode { - pub fn as_str(&self) -> &'static str { - match self { - FilterMode::Global => "GLOBAL", - FilterMode::Host => "HOST", - FilterMode::Session => "SESSION", - FilterMode::Directory => "DIRECTORY", - FilterMode::Workspace => "WORKSPACE", - FilterMode::SessionPreload => "SESSION+", - } - } -} - -#[derive(Clone, Debug, Deserialize, Copy, Serialize)] -pub enum ExitMode { - #[serde(rename = "return-original")] - ReturnOriginal, - - #[serde(rename = "return-query")] - ReturnQuery, -} - -// FIXME: Can use upstream Dialect enum if https://github.com/stevedonovan/chrono-english/pull/16 is merged -// FIXME: Above PR was merged, but dependency was changed to interim (fork of chrono-english) in the ... interim -#[derive(Clone, Debug, Deserialize, Copy, Serialize)] -pub enum Dialect { - #[serde(rename = "us")] - Us, - - #[serde(rename = "uk")] - Uk, -} - -impl From for interim::Dialect { - fn from(d: Dialect) -> interim::Dialect { - match d { - Dialect::Uk => interim::Dialect::Uk, - Dialect::Us => interim::Dialect::Us, - } - } -} - -/// Type wrapper around `time::UtcOffset` to support a wider variety of timezone formats. -/// -/// Note that the parsing of this struct needs to be done before starting any -/// multithreaded runtime, otherwise it will fail on most Unix systems. -/// -/// See: -#[derive(Clone, Copy, Debug, Eq, PartialEq, DeserializeFromStr, Serialize)] -pub struct Timezone(pub UtcOffset); -impl fmt::Display for Timezone { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} -/// format: <+|->[:[:]] -static OFFSET_FMT: &[FormatItem<'_>] = format_description!( - "[offset_hour sign:mandatory padding:none][optional [:[offset_minute padding:none][optional [:[offset_second padding:none]]]]]" -); -impl FromStr for Timezone { - type Err = Error; - - fn from_str(s: &str) -> Result { - // local timezone - if matches!(s.to_lowercase().as_str(), "l" | "local") { - // There have been some timezone issues, related to errors fetching it on some - // platforms - // Rather than fail to start, fallback to UTC. The user should still be able to specify - // their timezone manually in the config file. - let offset = UtcOffset::current_local_offset().unwrap_or(UtcOffset::UTC); - return Ok(Self(offset)); - } - - if matches!(s.to_lowercase().as_str(), "0" | "utc") { - let offset = UtcOffset::UTC; - return Ok(Self(offset)); - } - - // offset from UTC - if let Ok(offset) = UtcOffset::parse(s, OFFSET_FMT) { - return Ok(Self(offset)); - } - - // IDEA: Currently named timezones are not supported, because the well-known crate - // for this is `chrono_tz`, which is not really interoperable with the datetime crate - // that we currently use - `time`. If ever we migrate to using `chrono`, this would - // be a good feature to add. - - bail!(r#""{s}" is not a valid timezone spec"#) - } -} - -#[derive(Clone, Debug, Deserialize, Copy, Serialize)] -pub enum Style { - #[serde(rename = "auto")] - Auto, - - #[serde(rename = "full")] - Full, - - #[serde(rename = "compact")] - Compact, -} - -#[derive(Clone, Debug, Deserialize, Copy, Serialize)] -pub enum WordJumpMode { - #[serde(rename = "emacs")] - Emacs, - - #[serde(rename = "subl")] - Subl, -} - -#[derive(Clone, Debug, Deserialize, Copy, PartialEq, Eq, ValueEnum, Serialize)] -pub enum KeymapMode { - #[serde(rename = "emacs")] - Emacs, - - #[serde(rename = "vim-normal")] - VimNormal, - - #[serde(rename = "vim-insert")] - VimInsert, - - #[serde(rename = "auto")] - Auto, -} - -impl KeymapMode { - pub fn as_str(&self) -> &'static str { - match self { - KeymapMode::Emacs => "EMACS", - KeymapMode::VimNormal => "VIMNORMAL", - KeymapMode::VimInsert => "VIMINSERT", - KeymapMode::Auto => "AUTO", - } - } -} - -// We want to translate the config to crossterm::cursor::SetCursorStyle, but -// the original type does not implement trait serde::Deserialize unfortunately. -// It seems impossible to implement Deserialize for external types when it is -// used in HashMap (https://stackoverflow.com/questions/67142663). We instead -// define an adapter type. -#[derive(Clone, Debug, Deserialize, Copy, PartialEq, Eq, ValueEnum, Serialize)] -pub enum CursorStyle { - #[serde(rename = "default")] - DefaultUserShape, - - #[serde(rename = "blink-block")] - BlinkingBlock, - - #[serde(rename = "steady-block")] - SteadyBlock, - - #[serde(rename = "blink-underline")] - BlinkingUnderScore, - - #[serde(rename = "steady-underline")] - SteadyUnderScore, - - #[serde(rename = "blink-bar")] - BlinkingBar, - - #[serde(rename = "steady-bar")] - SteadyBar, -} - -impl CursorStyle { - pub fn as_str(&self) -> &'static str { - match self { - CursorStyle::DefaultUserShape => "DEFAULT", - CursorStyle::BlinkingBlock => "BLINKBLOCK", - CursorStyle::SteadyBlock => "STEADYBLOCK", - CursorStyle::BlinkingUnderScore => "BLINKUNDERLINE", - CursorStyle::SteadyUnderScore => "STEADYUNDERLINE", - CursorStyle::BlinkingBar => "BLINKBAR", - CursorStyle::SteadyBar => "STEADYBAR", - } - } -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Stats { - #[serde(default = "Stats::common_prefix_default")] - pub common_prefix: Vec, // sudo, etc. commands we want to strip off - #[serde(default = "Stats::common_subcommands_default")] - pub common_subcommands: Vec, // kubectl, commands we should consider subcommands for - #[serde(default = "Stats::ignored_commands_default")] - pub ignored_commands: Vec, // cd, ls, etc. commands we want to completely hide from stats -} - -impl Stats { - fn common_prefix_default() -> Vec { - vec!["sudo", "doas"].into_iter().map(String::from).collect() - } - - fn common_subcommands_default() -> Vec { - vec![ - "apt", - "cargo", - "composer", - "dnf", - "docker", - "dotnet", - "git", - "go", - "ip", - "jj", - "kubectl", - "nix", - "nmcli", - "npm", - "pecl", - "pnpm", - "podman", - "port", - "systemctl", - "tmux", - "yarn", - ] - .into_iter() - .map(String::from) - .collect() - } - - fn ignored_commands_default() -> Vec { - vec![] - } -} - -impl Default for Stats { - fn default() -> Self { - Self { - common_prefix: Self::common_prefix_default(), - common_subcommands: Self::common_subcommands_default(), - ignored_commands: Self::ignored_commands_default(), - } - } -} - -/// Sync protocol type for authentication. -/// -/// This setting is primarily for development/testing. When not explicitly set, -/// the protocol is inferred from the sync_address: -/// - Default sync address (api.atuin.sh) → Hub protocol -/// - Custom sync address → Legacy protocol -/// -/// Set explicitly to "hub" to use Hub authentication with a custom sync_address -/// (useful for local development against a Hub instance). -#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Eq, Serialize, Default)] -#[serde(rename_all = "lowercase")] -pub enum SyncProtocol { - /// Use legacy CLI authentication (Token from CLI register/login) - #[default] - Legacy, -} - -/// Resolved authentication state for sync operations. -/// -/// Determined at runtime by examining which tokens are available and what -/// server the client is configured to talk to. Operations use this to pick -/// the right auth header and endpoint style. -#[cfg(feature = "sync")] -#[derive(Debug, Clone)] -pub enum SyncAuth { - /// Self-hosted Rust server. Uses `Authorization: Token ` and - /// legacy endpoints. - Legacy { token: String }, - - /// Not authenticated at all. Contains an actionable user-facing message. - NotLoggedIn { reason: String }, -} - -#[cfg(feature = "sync")] -impl SyncAuth { - /// Convert into the auth token type used by the API client. - /// - /// Returns an error with an actionable message for `NotLoggedIn`. - pub fn into_auth_token(self) -> Result { - use crate::api_client::AuthToken; - match self { - SyncAuth::Legacy { token } => Ok(AuthToken::Token(token)), - SyncAuth::NotLoggedIn { reason } => Err(eyre!(reason)), - } - } -} - -#[derive(Clone, Debug, Deserialize, Default, Serialize)] -pub struct Keys { - pub scroll_exits: bool, - pub exit_past_line_start: bool, - pub accept_past_line_end: bool, - pub accept_past_line_start: bool, - pub accept_with_backspace: bool, - pub prefix: String, -} - -impl Keys { - /// The standard default values for all `[keys]` options. - /// These match the config defaults set in `builder_with_data_dir()`. - pub fn standard_defaults() -> Self { - Keys { - scroll_exits: true, - exit_past_line_start: true, - accept_past_line_end: true, - accept_past_line_start: false, - accept_with_backspace: false, - prefix: "a".to_string(), - } - } - - /// Returns true if any value differs from the standard defaults. - pub fn has_non_default_values(&self) -> bool { - let d = Self::standard_defaults(); - self.scroll_exits != d.scroll_exits - || self.exit_past_line_start != d.exit_past_line_start - || self.accept_past_line_end != d.accept_past_line_end - || self.accept_past_line_start != d.accept_past_line_start - || self.accept_with_backspace != d.accept_with_backspace - || self.prefix != d.prefix - } -} - -/// A single rule within a conditional keybinding config. -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct KeyRuleConfig { - /// Optional condition expression (e.g. "cursor-at-start", "input-empty && no-results"). - /// If absent, the rule always matches. - #[serde(default)] - pub when: Option, - /// The action to perform (e.g. "exit", "cursor-left", "accept"). - pub action: String, -} - -/// A keybinding config value: either a simple action string or an ordered list of conditional rules. -#[derive(Clone, Debug, Deserialize, Serialize)] -#[serde(untagged)] -pub enum KeyBindingConfig { - /// Simple unconditional binding: `"ctrl-c" = "return-original"` - Simple(String), - /// Conditional binding: `"left" = [{ when = "cursor-at-start", action = "exit" }, { action = "cursor-left" }]` - Rules(Vec), -} - -/// User-facing keymap configuration. Each mode maps key strings to bindings. -/// Keys present here override the defaults for that key; unmentioned keys keep defaults. -#[derive(Clone, Debug, Deserialize, Serialize, Default)] -pub struct KeymapConfig { - #[serde(default)] - pub emacs: HashMap, - #[serde(default, rename = "vim-normal")] - pub vim_normal: HashMap, - #[serde(default, rename = "vim-insert")] - pub vim_insert: HashMap, - #[serde(default)] - pub inspector: HashMap, - #[serde(default)] - pub prefix: HashMap, -} - -impl KeymapConfig { - /// Returns true if no keybinding overrides are configured in any mode. - pub fn is_empty(&self) -> bool { - self.emacs.is_empty() - && self.vim_normal.is_empty() - && self.vim_insert.is_empty() - && self.inspector.is_empty() - && self.prefix.is_empty() - } -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Preview { - pub strategy: PreviewStrategy, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Theme { - /// Name of desired theme ("default" for base) - pub name: String, - - /// Whether any available additional theme debug should be shown - pub debug: Option, - - /// How many levels of parenthood will be traversed if needed - pub max_depth: Option, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Daemon { - /// Use the daemon to sync - /// If enabled, history hooks are routed through the daemon. - #[serde(alias = "enable")] - pub enabled: bool, - - /// Automatically start and manage a local daemon when needed. - pub autostart: bool, - - /// The daemon will handle sync on an interval. How often to sync, in seconds. - pub sync_frequency: u64, - - /// The path to the unix socket used by the daemon - pub socket_path: String, - - /// Path to the daemon pidfile used for process coordination. - pub pidfile_path: String, - - /// Use a socket passed via systemd's socket activation protocol, instead of the path - pub systemd_socket: bool, - - /// The port that should be used for TCP on non unix systems - pub tcp_port: u64, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Search { - /// The list of enabled filter modes, in order of priority. - pub filters: Vec, - - /// The recency score multiplier for the search index (default: 1.0). - /// Values < 1.0 reduce weight, > 1.0 increase weight, 0.0 disables. - pub recency_score_multiplier: f64, - - /// The frequency score multiplier for the search index (default: 1.0). - /// Values < 1.0 reduce weight, > 1.0 increase weight, 0.0 disables. - pub frequency_score_multiplier: f64, - - /// The overall frecency score multiplier for the search index (default: 1.0). - /// Applied after combining recency and frequency scores. - pub frecency_score_multiplier: f64, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Tmux { - /// Enable using atuin with tmux popup (tmux >= 3.2) - pub enabled: bool, - - /// Width of the tmux popup (percentage) - pub width: String, - - /// Height of the tmux popup (percentage) - pub height: String, -} - -/// Log level for file logging. Maps to tracing's LevelFilter. -#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Deserialize, Serialize)] -#[serde(rename_all = "lowercase")] -pub enum LogLevel { - Trace, - Debug, - #[default] - Info, - Warn, - Error, -} - -impl LogLevel { - /// Convert to a tracing directive string for use with EnvFilter. - pub fn as_directive(&self) -> &'static str { - match self { - LogLevel::Trace => "trace", - LogLevel::Debug => "debug", - LogLevel::Info => "info", - LogLevel::Warn => "warn", - LogLevel::Error => "error", - } - } -} - -/// Configuration for a specific log type (search or daemon). -#[derive(Clone, Debug, Default, Deserialize, Serialize)] -pub struct LogConfig { - /// Log file name (relative to dir) or absolute path. - pub file: String, - - /// Override global enabled setting for this log type. - pub enabled: Option, - - /// Override global level setting for this log type. - pub level: Option, - - /// Override global retention days setting for this log type. - pub retention: Option, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Logs { - /// Enable file logging globally. Defaults to true. - #[serde(default = "Logs::default_enabled")] - pub enabled: bool, - - /// Directory for log files. Defaults to ~/.atuin/logs - pub dir: String, - - /// Default log level for file logging. Defaults to "info". - /// Note: ATUIN_LOG environment variable overrides this. - #[serde(default)] - pub level: LogLevel, - - /// Default retention days for log files. Defaults to 4. - #[serde(default = "Logs::default_retention")] - pub retention: u64, - - /// Search log settings - #[serde(default)] - pub search: LogConfig, - - /// Daemon log settings - #[serde(default)] - pub daemon: LogConfig, - - /// AI log settings - #[serde(default)] - pub ai: LogConfig, -} - -#[derive(Default, Clone, Debug, Deserialize, Serialize)] -pub struct Ai { - /// Whether or not the AI features are enabled. - pub enabled: Option, - - /// The address of the Atuin AI endpoint. Used for AI features like command generation. - /// Only necessary for custom AI endpoints. - pub endpoint: Option, - - /// The API token for the Atuin AI endpoint. Used for AI features like command generation. - /// Only necessary for custom AI endpoints. - pub api_token: Option, - - /// Path to the AI sessions database. - pub db_path: String, - - /// The maximum time in minutes that an AI session can be automatically resumed. - pub session_continue_minutes: i64, - - /// Deprecated: use opening.send_cwd instead. Kept for backwards compatibility. - #[serde(default)] - pub send_cwd: Option, - - /// Configuration for what context is sent in the opening AI request. - #[serde(default)] - pub opening: AiOpening, - - /// Tool capability flags. - #[serde(default)] - pub capabilities: AiCapabilities, -} - -#[derive(Default, Clone, Debug, Deserialize, Serialize)] -pub struct AiCapabilities { - /// Whether the AI can request to search Atuin history. `None` = unset (defaults to enabled, and the ai will ask for permission). - pub enable_history_search: Option, - /// Whether the AI can request to view the stored output, if any, for Atuin history entries. `None` = unset (defaults to enabled, and the ai will ask for permission). - pub enable_history_output: Option, - /// Whether the AI can request to read and write files. `None` = unset (defaults to enabled, and the ai will ask for permission). - pub enable_file_tools: Option, - /// Whether the AI can request to execute bash commands. `None` = unset (defaults to enabled, and the ai will ask for permission). - pub enable_command_execution: Option, -} - -#[derive(Default, Clone, Debug, Deserialize, Serialize)] -pub struct AiOpening { - /// Whether or not to send the current working directory to the AI endpoint. - pub send_cwd: Option, - - /// Whether or not to send the last command as context in the opening AI request. - pub send_last_command: Option, -} - -impl Default for Preview { - fn default() -> Self { - Self { - strategy: PreviewStrategy::Auto, - } - } -} - -impl Default for Theme { - fn default() -> Self { - Self { - name: "".to_string(), - debug: None::, - max_depth: Some(10), - } - } -} - -impl Default for Daemon { - fn default() -> Self { - Self { - enabled: false, - autostart: false, - sync_frequency: 300, - socket_path: "".to_string(), - pidfile_path: "".to_string(), - systemd_socket: false, - tcp_port: 8889, - } - } -} - -impl Default for Logs { - fn default() -> Self { - Self { - enabled: true, - dir: "".to_string(), - level: LogLevel::default(), - retention: Self::default_retention(), - search: LogConfig { - file: "search.log".to_string(), - ..Default::default() - }, - daemon: LogConfig { - file: "daemon.log".to_string(), - ..Default::default() - }, - ai: LogConfig { - file: "ai.log".to_string(), - ..Default::default() - }, - } - } -} - -impl Logs { - fn default_enabled() -> bool { - true - } - - fn default_retention() -> u64 { - 4 - } - - /// Returns whether search logging is enabled. - /// Uses search-specific setting if set, otherwise falls back to global. - pub fn search_enabled(&self) -> bool { - self.search.enabled.unwrap_or(self.enabled) - } - - /// Returns whether daemon logging is enabled. - /// Uses daemon-specific setting if set, otherwise falls back to global. - pub fn daemon_enabled(&self) -> bool { - self.daemon.enabled.unwrap_or(self.enabled) - } - - /// Returns whether AI logging is enabled. - /// Uses AI-specific setting if set, otherwise falls back to global. - pub fn ai_enabled(&self) -> bool { - self.ai.enabled.unwrap_or(self.enabled) - } - - /// Returns the log level for search logging. - /// Uses search-specific setting if set, otherwise falls back to global. - pub fn search_level(&self) -> LogLevel { - self.search.level.unwrap_or(self.level) - } - - /// Returns the log level for daemon logging. - /// Uses daemon-specific setting if set, otherwise falls back to global. - pub fn daemon_level(&self) -> LogLevel { - self.daemon.level.unwrap_or(self.level) - } - - /// Returns the log level for AI logging. - /// Uses AI-specific setting if set, otherwise falls back to global. - pub fn ai_level(&self) -> LogLevel { - self.ai.level.unwrap_or(self.level) - } - - /// Returns the retention days for search logging. - /// Uses search-specific setting if set, otherwise falls back to global. - pub fn search_retention(&self) -> u64 { - self.search.retention.unwrap_or(self.retention) - } - - /// Returns the retention days for daemon logging. - /// Uses daemon-specific setting if set, otherwise falls back to global. - pub fn daemon_retention(&self) -> u64 { - self.daemon.retention.unwrap_or(self.retention) - } - - /// Returns the retention days for AI logging. - /// Uses AI-specific setting if set, otherwise falls back to global. - pub fn ai_retention(&self) -> u64 { - self.ai.retention.unwrap_or(self.retention) - } - - /// Returns the full path for the search log file. - pub fn search_path(&self) -> PathBuf { - let path = PathBuf::from(&self.search.file); - PathBuf::from(&self.dir).join(path) - } - - /// Returns the full path for the daemon log file. - pub fn daemon_path(&self) -> PathBuf { - let path = PathBuf::from(&self.daemon.file); - PathBuf::from(&self.dir).join(path) - } - - /// Returns the full path for the AI log file. - pub fn ai_path(&self) -> PathBuf { - let path = PathBuf::from(&self.ai.file); - PathBuf::from(&self.dir).join(path) - } -} - -impl Default for Search { - fn default() -> Self { - Self { - filters: vec![ - FilterMode::Global, - FilterMode::Host, - FilterMode::Session, - FilterMode::SessionPreload, - FilterMode::Workspace, - FilterMode::Directory, - ], - - recency_score_multiplier: 1.0, - frequency_score_multiplier: 1.0, - frecency_score_multiplier: 1.0, - } - } -} - -impl Default for Tmux { - fn default() -> Self { - Self { - enabled: false, - width: "80%".to_string(), - height: "60%".to_string(), - } - } -} - -// The preview height strategy also takes max_preview_height into account. -#[derive(Clone, Debug, Deserialize, Copy, PartialEq, Eq, ValueEnum, Serialize)] -pub enum PreviewStrategy { - // Preview height is calculated for the length of the selected command. - #[serde(rename = "auto")] - Auto, - - // Preview height is calculated for the length of the longest command stored in the history. - #[serde(rename = "static")] - Static, - - // max_preview_height is used as fixed height. - #[serde(rename = "fixed")] - Fixed, -} - -/// Column types available for the interactive search UI. -#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Eq, Serialize)] -#[serde(rename_all = "lowercase")] -pub enum UiColumnType { - /// Command execution duration (e.g., "123ms") - Duration, - /// Relative time since execution (e.g., "59s ago") - Time, - /// Absolute timestamp (e.g., "2025-01-22 14:35") - Datetime, - /// Working directory - Directory, - /// Hostname - Host, - /// Username - User, - /// Exit code - Exit, - /// The command itself (should be last, expands to fill) - Command, -} - -impl UiColumnType { - /// Returns the default width for this column type (in characters). - /// The Command column returns 0 as it expands to fill remaining space. - pub fn default_width(&self) -> u16 { - match self { - UiColumnType::Duration => 5, // "814ms" - UiColumnType::Time => 9, // "459ms ago" - UiColumnType::Datetime => 16, // "2025-01-22 14:35" - UiColumnType::Directory => 20, - UiColumnType::Host => 15, - UiColumnType::User => 10, - UiColumnType::Exit => { - if cfg!(windows) { - 11 // 32-bit integer on Windows: "-1978335212" - } else { - 3 // Usually a byte on Unix - } - } - UiColumnType::Command => 0, // Expands to fill - } - } -} - -/// A column configuration with type and optional custom width. -/// Can be specified as just a string (uses default width) or as an object with type and width. -#[derive(Clone, Debug, Serialize)] -pub struct UiColumn { - pub column_type: UiColumnType, - pub width: u16, - /// If true, this column expands to fill remaining space. Only one column should expand. - pub expand: bool, -} - -impl UiColumn { - pub fn new(column_type: UiColumnType) -> Self { - Self { - width: column_type.default_width(), - expand: column_type == UiColumnType::Command, - column_type, - } - } - - pub fn with_width(column_type: UiColumnType, width: u16) -> Self { - Self { - column_type, - width, - expand: column_type == UiColumnType::Command, - } - } -} - -// Custom deserialize to handle both string and object formats: -// "duration" or { type = "duration", width = 8, expand = true } -impl<'de> serde::Deserialize<'de> for UiColumn { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - use serde::de::{self, MapAccess, Visitor}; - - struct UiColumnVisitor; - - impl<'de> Visitor<'de> for UiColumnVisitor { - type Value = UiColumn; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str( - "a column type string or an object with 'type' and optional 'width'/'expand'", - ) - } - - fn visit_str(self, value: &str) -> Result - where - E: de::Error, - { - let column_type: UiColumnType = - serde::Deserialize::deserialize(serde::de::value::StrDeserializer::new(value))?; - Ok(UiColumn::new(column_type)) - } - - fn visit_map(self, mut map: M) -> Result - where - M: MapAccess<'de>, - { - let mut column_type: Option = None; - let mut width: Option = None; - let mut expand: Option = None; - - while let Some(key) = map.next_key::()? { - match key.as_str() { - "type" => { - column_type = Some(map.next_value()?); - } - "width" => { - width = Some(map.next_value()?); - } - "expand" => { - expand = Some(map.next_value()?); - } - _ => { - let _: serde::de::IgnoredAny = map.next_value()?; - } - } - } - - let column_type = column_type.ok_or_else(|| de::Error::missing_field("type"))?; - let width = width.unwrap_or_else(|| column_type.default_width()); - let expand = expand.unwrap_or(column_type == UiColumnType::Command); - Ok(UiColumn { - column_type, - width, - expand, - }) - } - } - - deserializer.deserialize_any(UiColumnVisitor) - } -} - -/// UI-specific settings for the interactive search. -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Ui { - /// Columns to display in interactive search, from left to right. - /// The indicator column (" > ") is always shown first implicitly. - /// The "command" column should be last as it expands to fill remaining space. - /// Can be simple strings or objects with type and width. - #[serde(default = "Ui::default_columns")] - pub columns: Vec, -} - -impl Ui { - fn default_columns() -> Vec { - vec![ - UiColumn::new(UiColumnType::Duration), - UiColumn::new(UiColumnType::Time), - UiColumn::new(UiColumnType::Command), - ] - } - - /// Validate the UI configuration. - /// Returns an error if more than one column has expand = true. - pub fn validate(&self) -> Result<()> { - let expand_count = self.columns.iter().filter(|c| c.expand).count(); - if expand_count > 1 { - bail!( - "Only one column can have expand = true, but {} columns are set to expand", - expand_count - ); - } - Ok(()) - } -} - -impl Default for Ui { - fn default() -> Self { - Self { - columns: Self::default_columns(), - } - } -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Settings { - pub data_dir: Option, - pub dialect: Dialect, - pub timezone: Timezone, - pub style: Style, - pub auto_sync: bool, - pub update_check: bool, - - /// The sync address for atuin. - pub sync_address: String, - - #[serde(default)] - pub sync_protocol: SyncProtocol, - - pub sync_frequency: String, - pub db_path: String, - pub record_store_path: String, - pub key_path: String, - pub search_mode: SearchMode, - pub filter_mode: Option, - pub filter_mode_shell_up_key_binding: Option, - pub search_mode_shell_up_key_binding: Option, - pub shell_up_key_binding: bool, - pub inline_height: u16, - pub inline_height_shell_up_key_binding: Option, - pub invert: bool, - pub show_preview: bool, - pub max_preview_height: u16, - pub show_help: bool, - pub show_tabs: bool, - pub show_numeric_shortcuts: bool, - pub auto_hide_height: u16, - pub exit_mode: ExitMode, - pub keymap_mode: KeymapMode, - pub keymap_mode_shell: KeymapMode, - pub keymap_cursor: HashMap, - pub word_jump_mode: WordJumpMode, - pub word_chars: String, - pub scroll_context_lines: usize, - pub history_format: String, - pub strip_trailing_whitespace: bool, - pub prefers_reduced_motion: bool, - pub store_failed: bool, - pub no_mouse: bool, - - #[serde(with = "serde_regex", default = "RegexSet::empty", skip_serializing)] - pub history_filter: RegexSet, - - #[serde(with = "serde_regex", default = "RegexSet::empty", skip_serializing)] - pub cwd_filter: RegexSet, - - pub secrets_filter: bool, - pub workspaces: bool, - pub ctrl_n_shortcuts: bool, - - pub network_connect_timeout: u64, - pub network_timeout: u64, - pub local_timeout: f64, - pub enter_accept: bool, - pub smart_sort: bool, - pub command_chaining: bool, - - #[serde(default)] - pub stats: Stats, - - #[serde(default)] - pub keys: Keys, - - #[serde(default)] - pub keymap: KeymapConfig, - - #[serde(default)] - pub preview: Preview, - - #[serde(default)] - pub daemon: Daemon, - - #[serde(default)] - pub search: Search, - - #[serde(default)] - pub theme: Theme, - - #[serde(default)] - pub ui: Ui, - - #[serde(default)] - pub tmux: Tmux, - - #[serde(default)] - pub logs: Logs, - - #[serde(default)] - pub meta: meta::Settings, -} - -impl Settings { - pub fn utc() -> Self { - Self::builder() - .expect("Could not build default") - .set_override("timezone", "0") - .expect("failed to override timezone with UTC") - .build() - .expect("Could not build config") - .try_deserialize() - .expect("Could not deserialize config") - } - - pub(crate) fn effective_data_dir() -> PathBuf { - DATA_DIR - .get() - .cloned() - .unwrap_or_else(atuin_common::utils::data_dir) - } - - // -- Meta store: lazily initialized on first access -- - - pub async fn meta_store() -> Result<&'static crate::meta::MetaStore> { - META_STORE - .get_or_try_init(|| async { - let (db_path, timeout) = META_CONFIG.get().ok_or_else(|| { - eyre!("meta store config not set — Settings::new() has not been called") - })?; - crate::meta::MetaStore::new(db_path, *timeout).await - }) - .await - } - - pub async fn host_id() -> Result { - Self::meta_store().await?.host_id().await - } - - pub async fn last_sync() -> Result { - Self::meta_store().await?.last_sync().await - } - - pub async fn save_sync_time() -> Result<()> { - Self::meta_store().await?.save_sync_time().await - } - - pub async fn last_version_check() -> Result { - Self::meta_store().await?.last_version_check().await - } - - pub async fn save_version_check_time() -> Result<()> { - Self::meta_store().await?.save_version_check_time().await - } - - pub async fn should_sync(&self) -> Result { - if !self.auto_sync || !Self::meta_store().await?.logged_in().await? { - return Ok(false); - } - - if self.sync_frequency == "0" { - return Ok(true); - } - - match parse_duration(self.sync_frequency.as_str()) { - Ok(d) => { - let d = time::Duration::try_from(d)?; - Ok(OffsetDateTime::now_utc() - Settings::last_sync().await? >= d) - } - Err(e) => Err(eyre!("failed to check sync: {}", e)), - } - } - - pub async fn logged_in(&self) -> Result { - Self::meta_store().await?.logged_in().await - } - - pub async fn session_token(&self) -> Result { - match Self::meta_store().await?.session_token().await? { - Some(token) => Ok(token), - None => Err(eyre!("Tried to load session; not logged in")), - } - } - - /// Examines the configured sync target and available tokens to determine - /// the correct auth strategy. Also performs cleanup of mis-stored tokens - /// (e.g. a CLI token incorrectly saved in the Hub session slot). - #[cfg(feature = "sync")] - pub async fn resolve_sync_auth(&self) -> SyncAuth { - let meta = match Self::meta_store().await { - Ok(m) => m, - Err(e) => { - return SyncAuth::NotLoggedIn { - reason: format!("Failed to open meta store: {e}"), - }; - } - }; - - // Self-hosted / legacy server - match meta.session_token().await { - Ok(Some(token)) => SyncAuth::Legacy { token }, - _ => SyncAuth::NotLoggedIn { - reason: "Not logged in. Run 'atuin login' to authenticate \ - with your sync server." - .into(), - }, - } - } - - /// Returns the appropriate auth token for sync operations. - /// - /// Delegates to [`resolve_sync_auth`] and converts the result to an - /// `AuthToken`. Callers that need to distinguish between auth states - /// (e.g. to show different UI) should call `resolve_sync_auth` directly. - #[cfg(feature = "sync")] - pub async fn sync_auth_token(&self) -> Result { - self.resolve_sync_auth().await.into_auth_token() - } - - pub fn default_filter_mode(&self, git_root: bool) -> FilterMode { - self.filter_mode - .filter(|x| self.search.filters.contains(x)) - .or_else(|| { - self.search - .filters - .iter() - .find(|x| match (x, git_root, self.workspaces) { - (FilterMode::Workspace, true, true) => true, - (FilterMode::Workspace, _, _) => false, - (_, _, _) => true, - }) - .copied() - }) - .unwrap_or(FilterMode::Global) - } - - pub fn builder() -> Result> { - Self::builder_with_data_dir(&atuin_common::utils::data_dir()) - } - - fn builder_with_data_dir(data_dir: &std::path::Path) -> Result> { - let db_path = data_dir.join("history.db"); - let record_store_path = data_dir.join("records.db"); - let kv_path = data_dir.join("kv.db"); - let scripts_path = data_dir.join("scripts.db"); - let ai_sessions_path = data_dir.join("ai_sessions.db"); - let socket_path = atuin_common::utils::runtime_dir().join("atuin.sock"); - let pidfile_path = data_dir.join("atuin-daemon.pid"); - let logs_dir = atuin_common::utils::logs_dir(); - - let key_path = data_dir.join("key"); - let meta_path = data_dir.join("meta.db"); - - Ok(Config::builder() - .set_default("history_format", "{time}\t{command}\t{duration}")? - .set_default("db_path", db_path.to_str())? - .set_default("record_store_path", record_store_path.to_str())? - .set_default("key_path", key_path.to_str())? - .set_default("dialect", "us")? - .set_default("timezone", "local")? - .set_default("auto_sync", true)? - .set_default("update_check", cfg!(feature = "check-update"))? - .set_default("sync_address", "https://api.atuin.sh")? - .set_default("sync_frequency", "5m")? - .set_default("search_mode", "fuzzy")? - .set_default("filter_mode", None::)? - .set_default("style", "compact")? - .set_default("inline_height", 40)? - .set_default("show_preview", true)? - .set_default("preview.strategy", "auto")? - .set_default("max_preview_height", 4)? - .set_default("show_help", true)? - .set_default("show_tabs", true)? - .set_default("show_numeric_shortcuts", true)? - .set_default("auto_hide_height", 8)? - .set_default("invert", false)? - .set_default("exit_mode", "return-original")? - .set_default("word_jump_mode", "emacs")? - .set_default( - "word_chars", - "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789", - )? - .set_default("scroll_context_lines", 1)? - .set_default("shell_up_key_binding", false)? - .set_default("workspaces", false)? - .set_default("ctrl_n_shortcuts", false)? - .set_default("secrets_filter", true)? - .set_default("strip_trailing_whitespace", true)? - .set_default("network_connect_timeout", 5)? - .set_default("network_timeout", 30)? - .set_default("local_timeout", 2.0)? - // enter_accept defaults to false here, but true in the default config file. The dissonance is - // intentional! - // Existing users will get the default "False", so we don't mess with any potential - // muscle memory. - // New users will get the new default, that is more similar to what they are used to. - .set_default("enter_accept", false)? - .set_default("keys.scroll_exits", true)? - .set_default("keys.accept_past_line_end", true)? - .set_default("keys.exit_past_line_start", true)? - .set_default("keys.accept_past_line_start", false)? - .set_default("keys.accept_with_backspace", false)? - .set_default("keys.prefix", "a")? - .set_default("keymap_mode", "emacs")? - .set_default("keymap_mode_shell", "auto")? - .set_default("keymap_cursor", HashMap::::new())? - .set_default("smart_sort", false)? - .set_default("command_chaining", false)? - .set_default("store_failed", true)? - .set_default("daemon.sync_frequency", 300)? - .set_default("daemon.enabled", false)? - .set_default("daemon.autostart", false)? - .set_default("daemon.socket_path", socket_path.to_str())? - .set_default("daemon.pidfile_path", pidfile_path.to_str())? - .set_default("daemon.systemd_socket", false)? - .set_default("daemon.tcp_port", 8889)? - .set_default("logs.enabled", true)? - .set_default("logs.dir", logs_dir.to_str())? - .set_default("logs.level", "info")? - .set_default("logs.search.file", "search.log")? - .set_default("logs.daemon.file", "daemon.log")? - .set_default("logs.ai.file", "ai.log")? - .set_default("kv.db_path", kv_path.to_str())? - .set_default("scripts.db_path", scripts_path.to_str())? - .set_default("search.recency_score_multiplier", 1.0)? - .set_default("search.frequency_score_multiplier", 1.0)? - .set_default("search.frecency_score_multiplier", 1.0)? - .set_default("meta.db_path", meta_path.to_str())? - .set_default("ai.db_path", ai_sessions_path.to_str())? - .set_default("ai.session_continue_minutes", 60)? - .set_default("ai.send_cwd", false)? - .set_default("ai.opening.send_cwd", false)? - .set_default("ai.opening.send_last_command", false)? - .set_default( - "search.filters", - vec![ - "global", - "host", - "session", - "workspace", - "directory", - "session-preload", - ], - )? - .set_default("theme.name", "default")? - .set_default("theme.debug", None::)? - .set_default("tmux.enabled", false)? - .set_default("tmux.width", "80%")? - .set_default("tmux.height", "60%")? - .set_default( - "prefers_reduced_motion", - std::env::var("NO_MOTION") - .ok() - .map(|_| config::Value::new(None, config::ValueKind::Boolean(true))) - .unwrap_or_else(|| config::Value::new(None, config::ValueKind::Boolean(false))), - )? - .set_default("no_mouse", false)? - .add_source( - Environment::with_prefix("atuin") - .prefix_separator("_") - .separator("__"), - )) - } - - pub fn get_config_path() -> Result { - let config_dir = atuin_common::utils::config_dir(); - - create_dir_all(&config_dir) - .wrap_err_with(|| format!("could not create dir {config_dir:?}"))?; - - let mut config_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { - PathBuf::from(p) - } else { - let mut config_file = PathBuf::new(); - config_file.push(config_dir); - config_file - }; - - config_file.push("config.toml"); - - Ok(config_file) - } - - /// Build a merged `Config` from defaults, config file, and environment. - /// - /// This resolves `data_dir`, initializes the data directory on disk, - /// and layers defaults → config file → env overrides. Both `new()` and - /// `get_config_value()` use this so the resolution logic lives in one place. - fn build_config() -> Result { - let config_file = Self::get_config_path()?; - - // extract data_dir first so we can use it as the base for other path defaults - let effective_data_dir = if config_file.exists() { - #[derive(Deserialize, Default)] - struct DataDirOnly { - data_dir: Option, - } - - let config_file_str = config_file - .to_str() - .ok_or_else(|| eyre!("config file path is not valid UTF-8"))?; - - let partial_config = Config::builder() - .add_source(ConfigFile::new(config_file_str, FileFormat::Toml)) - .add_source( - Environment::with_prefix("atuin") - .prefix_separator("_") - .separator("__"), - ) - .build() - .ok(); - - let custom_data_dir = partial_config - .and_then(|c| c.try_deserialize::().ok()) - .and_then(|d| d.data_dir); - - match custom_data_dir { - Some(dir) => { - let expanded = shellexpand::full(&dir) - .map_err(|e| eyre!("failed to expand data_dir path: {}", e))?; - PathBuf::from(expanded.as_ref()) - } - None => atuin_common::utils::data_dir(), - } - } else { - atuin_common::utils::data_dir() - }; - - DATA_DIR.set(effective_data_dir.clone()).ok(); - - create_dir_all(&effective_data_dir) - .wrap_err_with(|| format!("could not create dir {effective_data_dir:?}"))?; - - let mut config_builder = Self::builder_with_data_dir(&effective_data_dir)?; - - config_builder = if config_file.exists() { - let config_file_str = config_file - .to_str() - .ok_or_else(|| eyre!("config file path is not valid UTF-8"))?; - config_builder.add_source(ConfigFile::new(config_file_str, FileFormat::Toml)) - } else { - let mut file = File::create(config_file).wrap_err("could not create config file")?; - file.write_all(EXAMPLE_CONFIG.as_bytes()) - .wrap_err("could not write default config file")?; - - config_builder - }; - - // all paths should be expanded - let built = config_builder.build_cloned()?; - config_builder = [ - "db_path", - "record_store_path", - "key_path", - "daemon.socket_path", - "daemon.pidfile_path", - "logs.dir", - "logs.search.file", - "logs.daemon.file", - ] - .iter() - .map(|key| (key, built.get_string(key).unwrap_or_default())) - .filter_map(|(key, value)| match Self::expand_path(value) { - Ok(expanded) => Some((key, expanded)), - Err(e) => { - log::warn!("failed to expand path for {key}: {e}"); - None - } - }) - .fold(config_builder, |builder, (key, value)| { - builder - .set_override(key, value) - .unwrap_or_else(|_| panic!("failed to set absolute path override for {key}")) - }); - - config_builder.build().map_err(Into::into) - } - - /// Look up a single config value by dotted key (e.g. `"daemon.sync_frequency"`). - /// - /// Returns the effective value after merging defaults, config file, and - /// environment — without the side-effects of full `Settings` construction - /// (meta store init, path expansion, etc.). - pub fn get_config_value(key: &str) -> Result { - let config = Self::build_config()?; - let value: config::Value = config - .get(key) - .map_err(|e| eyre!("failed to get config value '{}': {}", key, e))?; - Ok(Self::format_resolved_value(&value, key)) - } - - fn format_resolved_value(value: &config::Value, prefix: &str) -> String { - use config::ValueKind; - - match &value.kind { - ValueKind::Nil => String::new(), - ValueKind::Boolean(b) => b.to_string(), - ValueKind::I64(i) => i.to_string(), - ValueKind::I128(i) => i.to_string(), - ValueKind::U64(u) => u.to_string(), - ValueKind::U128(u) => u.to_string(), - ValueKind::Float(f) => f.to_string(), - ValueKind::String(s) => s.clone(), - ValueKind::Array(arr) => { - let items: Vec = arr - .iter() - .map(|v| Self::format_resolved_value(v, "")) - .collect(); - format!("[{}]", items.join(", ")) - } - ValueKind::Table(map) => { - let mut lines = Vec::new(); - let mut keys: Vec<_> = map.keys().collect(); - keys.sort(); - - for k in keys { - let v = &map[k]; - let full_key = if prefix.is_empty() { - k.clone() - } else { - format!("{}.{}", prefix, k) - }; - - match &v.kind { - ValueKind::Table(_) => { - lines.push(Self::format_resolved_value(v, &full_key)); - } - _ => { - lines.push(format!( - "{} = {}", - full_key, - Self::format_resolved_value(v, "") - )); - } - } - } - - lines.join("\n") - } - } - } - - pub fn new() -> Result { - let config = Self::build_config()?; - let settings: Settings = config - .try_deserialize() - .map_err(|e| eyre!("failed to deserialize: {}", e))?; - - // Validate UI settings - settings.ui.validate()?; - - // Register meta store config for lazy initialization on first access - META_CONFIG - .set((settings.meta.db_path.clone(), settings.local_timeout)) - .ok(); - - Ok(settings) - } - - fn expand_path(path: String) -> Result { - shellexpand::full(&path) - .map(|p| p.to_string()) - .map_err(|e| eyre!("failed to expand path: {}", e)) - } - - pub fn example_config() -> &'static str { - EXAMPLE_CONFIG - } - - pub fn paths_ok(&self) -> bool { - let paths = [ - &self.db_path, - &self.record_store_path, - &self.key_path, - &self.meta.db_path, - ]; - paths.iter().all(|p| !utils::broken_symlink(p)) - } -} - -impl Default for Settings { - fn default() -> Self { - // if this panics something is very wrong, as the default config - // does not build or deserialize into the settings struct - Self::builder() - .expect("Could not build default") - .build() - .expect("Could not build config") - .try_deserialize() - .expect("Could not deserialize config") - } -} - -/// Initialize the meta store configuration for testing. -/// -/// This should only be used in tests. It allows tests to bypass the normal -/// Settings::new() flow while still being able to use Settings::host_id() -/// and other meta store dependent functions. -/// -/// # Safety -/// This function is not thread-safe with concurrent calls to Settings::new() -/// or other meta store initialization. Only call from tests. -#[doc(hidden)] -pub fn init_meta_config_for_testing(meta_db_path: impl Into, local_timeout: f64) { - META_CONFIG.set((meta_db_path.into(), local_timeout)).ok(); -} - -#[cfg(test)] -pub(crate) fn test_local_timeout() -> f64 { - std::env::var("ATUIN_TEST_LOCAL_TIMEOUT") - .ok() - .and_then(|x| x.parse().ok()) - // this hardcoded value should be replaced by a simple way to get the - // default local_timeout of Settings if possible - .unwrap_or(2.0) -} - -#[cfg(test)] -mod tests { - use std::str::FromStr; - - use eyre::Result; - - use super::Timezone; - - #[test] - fn can_parse_offset_timezone_spec() -> Result<()> { - assert_eq!(Timezone::from_str("+02")?.0.as_hms(), (2, 0, 0)); - assert_eq!(Timezone::from_str("-04")?.0.as_hms(), (-4, 0, 0)); - assert_eq!(Timezone::from_str("+05:30")?.0.as_hms(), (5, 30, 0)); - assert_eq!(Timezone::from_str("-09:30")?.0.as_hms(), (-9, -30, 0)); - - // single digit hours are allowed - assert_eq!(Timezone::from_str("+2")?.0.as_hms(), (2, 0, 0)); - assert_eq!(Timezone::from_str("-4")?.0.as_hms(), (-4, 0, 0)); - assert_eq!(Timezone::from_str("+5:30")?.0.as_hms(), (5, 30, 0)); - assert_eq!(Timezone::from_str("-9:30")?.0.as_hms(), (-9, -30, 0)); - - // fully qualified form - assert_eq!(Timezone::from_str("+09:30:00")?.0.as_hms(), (9, 30, 0)); - assert_eq!(Timezone::from_str("-09:30:00")?.0.as_hms(), (-9, -30, 0)); - - // these offsets don't really exist but are supported anyway - assert_eq!(Timezone::from_str("+0:5")?.0.as_hms(), (0, 5, 0)); - assert_eq!(Timezone::from_str("-0:5")?.0.as_hms(), (0, -5, 0)); - assert_eq!(Timezone::from_str("+01:23:45")?.0.as_hms(), (1, 23, 45)); - assert_eq!(Timezone::from_str("-01:23:45")?.0.as_hms(), (-1, -23, -45)); - - // require a leading sign for clarity - assert!(Timezone::from_str("5").is_err()); - assert!(Timezone::from_str("10:30").is_err()); - - Ok(()) - } - - #[test] - fn can_choose_workspace_filters_when_in_git_context() -> Result<()> { - let mut settings = super::Settings::default(); - settings.search.filters = vec![ - super::FilterMode::Workspace, - super::FilterMode::Host, - super::FilterMode::Directory, - super::FilterMode::Session, - super::FilterMode::Global, - ]; - settings.workspaces = true; - - assert_eq!( - settings.default_filter_mode(true), - super::FilterMode::Workspace, - ); - - Ok(()) - } - - #[test] - fn wont_choose_workspace_filters_when_not_in_git_context() -> Result<()> { - let mut settings = super::Settings::default(); - settings.search.filters = vec![ - super::FilterMode::Workspace, - super::FilterMode::Host, - super::FilterMode::Directory, - super::FilterMode::Session, - super::FilterMode::Global, - ]; - settings.workspaces = true; - - assert_eq!(settings.default_filter_mode(false), super::FilterMode::Host,); - - Ok(()) - } - - #[test] - fn wont_choose_workspace_filters_when_workspaces_disabled() -> Result<()> { - let mut settings = super::Settings::default(); - settings.search.filters = vec![ - super::FilterMode::Workspace, - super::FilterMode::Host, - super::FilterMode::Directory, - super::FilterMode::Session, - super::FilterMode::Global, - ]; - settings.workspaces = false; - - assert_eq!(settings.default_filter_mode(true), super::FilterMode::Host,); - - Ok(()) - } - - #[test] - fn builder_with_data_dir_uses_custom_paths() -> Result<()> { - use std::path::PathBuf; - - let custom_dir = PathBuf::from("/custom/data/dir"); - let builder = super::Settings::builder_with_data_dir(&custom_dir)?; - let config = builder.build()?; - - let db_path: String = config.get("db_path")?; - let key_path: String = config.get("key_path")?; - let record_store_path: String = config.get("record_store_path")?; - let kv_db_path: String = config.get("kv.db_path")?; - let scripts_db_path: String = config.get("scripts.db_path")?; - let meta_db_path: String = config.get("meta.db_path")?; - let daemon_socket_path: String = config.get("daemon.socket_path")?; - let daemon_pidfile_path: String = config.get("daemon.pidfile_path")?; - let daemon_autostart: bool = config.get("daemon.autostart")?; - - assert_eq!(db_path, custom_dir.join("history.db").to_str().unwrap()); - assert_eq!(key_path, custom_dir.join("key").to_str().unwrap()); - assert_eq!( - record_store_path, - custom_dir.join("records.db").to_str().unwrap() - ); - assert_eq!(kv_db_path, custom_dir.join("kv.db").to_str().unwrap()); - assert_eq!( - scripts_db_path, - custom_dir.join("scripts.db").to_str().unwrap() - ); - assert_eq!(meta_db_path, custom_dir.join("meta.db").to_str().unwrap()); - assert_eq!( - daemon_socket_path, - atuin_common::utils::runtime_dir() - .join("atuin.sock") - .to_str() - .unwrap() - ); - assert_eq!( - daemon_pidfile_path, - custom_dir.join("atuin-daemon.pid").to_str().unwrap() - ); - assert!(!daemon_autostart); - - Ok(()) - } - - #[test] - fn effective_data_dir_returns_default_when_not_set() { - let effective = super::Settings::effective_data_dir(); - let default = atuin_common::utils::data_dir(); - - assert!(effective.to_str().is_some()); - assert!(effective.ends_with("atuin") || effective == default); - } - - #[test] - fn keymap_config_deserializes_simple_binding() { - let json = r#"{"emacs": {"ctrl-c": "exit"}}"#; - let config: super::KeymapConfig = serde_json::from_str(json).unwrap(); - assert_eq!(config.emacs.len(), 1); - match &config.emacs["ctrl-c"] { - super::KeyBindingConfig::Simple(s) => assert_eq!(s, "exit"), - _ => panic!("expected Simple variant"), - } - } - - #[test] - fn keymap_config_deserializes_conditional_binding() { - let json = r#"{ - "emacs": { - "left": [ - {"when": "cursor-at-start", "action": "exit"}, - {"action": "cursor-left"} - ] - } - }"#; - let config: super::KeymapConfig = serde_json::from_str(json).unwrap(); - match &config.emacs["left"] { - super::KeyBindingConfig::Rules(rules) => { - assert_eq!(rules.len(), 2); - assert_eq!(rules[0].when.as_deref(), Some("cursor-at-start")); - assert_eq!(rules[0].action, "exit"); - assert!(rules[1].when.is_none()); - assert_eq!(rules[1].action, "cursor-left"); - } - _ => panic!("expected Rules variant"), - } - } - - #[test] - fn keymap_config_deserializes_vim_normal() { - let json = r#"{"vim-normal": {"j": "select-next", "k": "select-previous"}}"#; - let config: super::KeymapConfig = serde_json::from_str(json).unwrap(); - assert_eq!(config.vim_normal.len(), 2); - assert!(config.emacs.is_empty()); - } - - #[test] - fn keymap_config_is_empty_when_default() { - let config = super::KeymapConfig::default(); - assert!(config.is_empty()); - } - - #[test] - fn keymap_config_mixed_modes() { - let json = r#"{ - "emacs": {"ctrl-c": "exit"}, - "vim-normal": {"q": "exit"}, - "inspector": {"d": "delete"} - }"#; - let config: super::KeymapConfig = serde_json::from_str(json).unwrap(); - assert!(!config.is_empty()); - assert_eq!(config.emacs.len(), 1); - assert_eq!(config.vim_normal.len(), 1); - assert_eq!(config.inspector.len(), 1); - assert!(config.vim_insert.is_empty()); - assert!(config.prefix.is_empty()); - } -} diff --git a/crates/atuin-client/src/settings/meta.rs b/crates/atuin-client/src/settings/meta.rs deleted file mode 100644 index 108d74ec..00000000 --- a/crates/atuin-client/src/settings/meta.rs +++ /dev/null @@ -1,17 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct Settings { - pub db_path: String, -} - -impl Default for Settings { - fn default() -> Self { - let dir = atuin_common::utils::data_dir(); - let path = dir.join("meta.db"); - - Self { - db_path: path.to_string_lossy().to_string(), - } - } -} diff --git a/crates/atuin-client/src/settings/watcher.rs b/crates/atuin-client/src/settings/watcher.rs deleted file mode 100644 index 740b8d12..00000000 --- a/crates/atuin-client/src/settings/watcher.rs +++ /dev/null @@ -1,256 +0,0 @@ -//! Config file watching for automatic settings reload. -//! -//! This module provides a `SettingsWatcher` that monitors the config file -//! for changes and broadcasts updated settings via a `tokio::sync::watch` channel. -//! -//! # Example -//! -//! ```no_run -//! use atuin_client::settings::watcher::global_settings_watcher; -//! -//! async fn example() -> eyre::Result<()> { -//! let watcher = global_settings_watcher()?; -//! let mut rx = watcher.subscribe(); -//! -//! // React to settings changes -//! while rx.changed().await.is_ok() { -//! let settings = rx.borrow(); -//! println!("Settings updated!"); -//! } -//! Ok(()) -//! } -//! ``` - -use std::{ - path::PathBuf, - sync::{Arc, OnceLock}, - time::Duration, -}; - -use eyre::{Result, WrapErr}; -use log::{debug, error, info, warn}; -use notify::{ - Config as NotifyConfig, RecommendedWatcher, RecursiveMode, Watcher, - event::{EventKind, ModifyKind}, -}; -use tokio::sync::watch; - -use super::Settings; - -/// Global singleton for the settings watcher. -static SETTINGS_WATCHER: OnceLock> = OnceLock::new(); - -/// Get the global settings watcher singleton. -/// -/// Initializes the watcher on first call. Subsequent calls return the same instance. -/// The watcher monitors the config file for changes and broadcasts updates. -pub fn global_settings_watcher() -> Result<&'static SettingsWatcher> { - let result = SETTINGS_WATCHER.get_or_init(|| SettingsWatcher::new().map_err(|e| e.to_string())); - - match result { - Ok(watcher) => Ok(watcher), - Err(e) => Err(eyre::eyre!("{}", e)), - } -} - -/// Watches the config file for changes and broadcasts updated settings. -/// -/// Uses `notify` for cross-platform file watching and `tokio::sync::watch` -/// for efficient broadcast to multiple subscribers. -pub struct SettingsWatcher { - /// Receiver for settings updates. Clone this to subscribe. - rx: watch::Receiver>, - /// Keeps the file watcher alive for the lifetime of this struct. - _watcher: RecommendedWatcher, -} - -impl SettingsWatcher { - /// Create a new settings watcher. - /// - /// Loads initial settings and starts watching the config file for changes. - /// Changes are debounced (500ms) to avoid multiple reloads during saves. - pub fn new() -> Result { - let initial_settings = Arc::new(Settings::new()?); - let (tx, rx) = watch::channel(initial_settings); - - let config_path = Self::config_path(); - info!("starting config file watcher: {:?}", config_path); - - let watcher = Self::create_watcher(tx, config_path)?; - - Ok(Self { - rx, - _watcher: watcher, - }) - } - - /// Subscribe to settings updates. - /// - /// Returns a receiver that will be notified when settings change. - /// Use `changed().await` to wait for the next update, then `borrow()` - /// to access the current settings. - pub fn subscribe(&self) -> watch::Receiver> { - self.rx.clone() - } - - /// Get the current settings without subscribing to updates. - pub fn current(&self) -> Arc { - self.rx.borrow().clone() - } - - /// Get the config file path. - fn config_path() -> PathBuf { - let config_dir = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { - PathBuf::from(p) - } else { - atuin_common::utils::config_dir() - }; - config_dir.join("config.toml") - } - - /// Create the file watcher with debouncing. - fn create_watcher( - tx: watch::Sender>, - config_path: PathBuf, - ) -> Result { - // Channel for debouncing file events - let (debounce_tx, debounce_rx) = std::sync::mpsc::channel::<()>(); - - // Spawn debounce thread - let config_path_clone = config_path.clone(); - std::thread::spawn(move || { - Self::debounce_loop(debounce_rx, tx, config_path_clone); - }); - - // Clone config_path for use in the watcher callback - let config_path_for_watcher = config_path.clone(); - - // Canonicalize config path for reliable comparison on macOS - // (handles symlinks like /var -> /private/var) - let canonical_config_path = config_path_for_watcher - .canonicalize() - .unwrap_or_else(|_| config_path_for_watcher.clone()); - - // Create file watcher - let mut watcher = RecommendedWatcher::new( - move |res: Result| { - match res { - Ok(event) => { - // Defensive: if paths is empty, we can't filter, so assume - // it might be our config file and trigger a reload to be safe - if event.paths.is_empty() { - warn!( - "config watcher: event has no paths, triggering reload to be safe" - ); - let _ = debounce_tx.send(()); - return; - } - - // Only react to events for our specific config file - // (filter out editor temp files, backups, etc.) - let is_config_file = event.paths.iter().any(|path| { - // Canonicalize for reliable comparison (handles macOS symlinks) - let canonical_event_path = - path.canonicalize().unwrap_or_else(|_| path.clone()); - - // Check if this event is for our config file - // (either exact match or the file was renamed to our config) - canonical_event_path == canonical_config_path - || path.file_name() == config_path_for_watcher.file_name() - }); - - if !is_config_file { - return; - } - - // Only react to modify events (content changes) or creates - if matches!( - event.kind, - EventKind::Modify(ModifyKind::Data(_) | ModifyKind::Any) - | EventKind::Create(_) - ) { - debug!("config file event detected: {:?}", event); - // Send to debounce channel (ignore send errors - receiver might be gone) - let _ = debounce_tx.send(()); - } - } - Err(e) => { - error!("file watcher error: {}", e); - } - } - }, - NotifyConfig::default(), - ) - .wrap_err("failed to create file watcher")?; - - // Watch the config file's parent directory (some editors create new files) - let watch_path = config_path.parent().unwrap_or(&config_path); - - // Defensive: ensure watch path exists before trying to watch - if !watch_path.exists() { - warn!( - "config directory does not exist, creating it: {:?}", - watch_path - ); - std::fs::create_dir_all(watch_path) - .wrap_err_with(|| format!("failed to create config directory: {:?}", watch_path))?; - } - - watcher - .watch(watch_path, RecursiveMode::NonRecursive) - .wrap_err_with(|| format!("failed to watch config directory: {:?}", watch_path))?; - - info!("config file watcher initialized for: {:?}", watch_path); - Ok(watcher) - } - - /// Debounce loop that batches file events and reloads settings. - fn debounce_loop( - rx: std::sync::mpsc::Receiver<()>, - tx: watch::Sender>, - config_path: PathBuf, - ) { - const DEBOUNCE_DURATION: Duration = Duration::from_millis(500); - - loop { - // Wait for first event - if rx.recv().is_err() { - // Channel closed, watcher was dropped - debug!("config watcher debounce loop exiting"); - return; - } - - // Drain any additional events within debounce window - while rx.recv_timeout(DEBOUNCE_DURATION).is_ok() { - // Keep draining - } - - // Defensive: check if config file exists before reloading - // (handles case where file was deleted - we'll get notified when it's recreated) - if !config_path.exists() { - debug!( - "config file does not exist, skipping reload: {:?}", - config_path - ); - continue; - } - - // Now reload settings - info!("config file changed, reloading settings: {:?}", config_path); - match Settings::new() { - Ok(settings) => { - if tx.send(Arc::new(settings)).is_err() { - // All receivers dropped - debug!("all settings subscribers dropped, exiting"); - return; - } - info!("settings reloaded successfully"); - } - Err(e) => { - warn!("failed to reload settings: {}", e); - // Keep the old settings, don't broadcast the error - } - } - } - } -} diff --git a/crates/atuin-client/src/sync.rs b/crates/atuin-client/src/sync.rs deleted file mode 100644 index 2c902794..00000000 --- a/crates/atuin-client/src/sync.rs +++ /dev/null @@ -1,213 +0,0 @@ -use std::collections::HashSet; -use std::iter::FromIterator; - -use eyre::Result; - -use atuin_common::api::AddHistoryRequest; -use crypto_secretbox::Key; -use time::OffsetDateTime; - -use crate::{ - api_client, - database::Database, - encryption::{decrypt, encrypt, load_key}, - settings::Settings, -}; - -pub fn hash_str(string: &str) -> String { - use sha2::{Digest, Sha256}; - let mut hasher = Sha256::new(); - hasher.update(string.as_bytes()); - hex::encode(hasher.finalize()) -} - -// Currently sync is kinda naive, and basically just pages backwards through -// history. This means newly added stuff shows up properly! We also just use -// the total count in each database to indicate whether a sync is needed. -// I think this could be massively improved! If we had a way of easily -// indicating count per time period (hour, day, week, year, etc) then we can -// easily pinpoint where we are missing data and what needs downloading. Start -// with year, then find the week, then the day, then the hour, then download it -// all! The current naive approach will do for now. - -// Check if remote has things we don't, and if so, download them. -// Returns (num downloaded, total local) -async fn sync_download( - key: &Key, - force: bool, - client: &api_client::Client<'_>, - db: &impl Database, -) -> Result<(i64, i64)> { - debug!("starting sync download"); - - let remote_status = client.status().await?; - let remote_count = remote_status.count; - - // useful to ensure we don't even save something that hasn't yet been synced + deleted - let remote_deleted = - HashSet::<&str>::from_iter(remote_status.deleted.iter().map(String::as_str)); - - let initial_local = db.history_count(true).await?; - let mut local_count = initial_local; - - let mut last_sync = if force { - OffsetDateTime::UNIX_EPOCH - } else { - Settings::last_sync().await? - }; - - let mut last_timestamp = OffsetDateTime::UNIX_EPOCH; - - let host = if force { Some(String::from("")) } else { None }; - - while remote_count > local_count { - let page = client - .get_history(last_sync, last_timestamp, host.clone()) - .await?; - - let history: Vec<_> = page - .history - .iter() - // TODO: handle deletion earlier in this chain - .map(|h| serde_json::from_str(h).expect("invalid base64")) - .map(|h| decrypt(h, key).expect("failed to decrypt history! check your key")) - .map(|mut h| { - if remote_deleted.contains(h.id.0.as_str()) { - h.deleted_at = Some(time::OffsetDateTime::now_utc()); - h.command = String::from(""); - } - - h - }) - .collect(); - - db.save_bulk(&history).await?; - - local_count = db.history_count(true).await?; - let remote_page_size = std::cmp::max(remote_status.page_size, 0) as usize; - - if history.len() < remote_page_size { - break; - } - - let page_last = history - .last() - .expect("could not get last element of page") - .timestamp; - - // in the case of a small sync frequency, it's possible for history to - // be "lost" between syncs. In this case we need to rewind the sync - // timestamps - if page_last == last_timestamp { - last_timestamp = OffsetDateTime::UNIX_EPOCH; - last_sync -= time::Duration::hours(1); - } else { - last_timestamp = page_last; - } - } - - for i in remote_status.deleted { - // we will update the stored history to have this data - // pretty much everything can be nullified - match db.load(i.as_str()).await? { - Some(h) => { - db.delete(h).await?; - } - _ => { - info!( - "could not delete history with id {}, not found locally", - i.as_str() - ); - } - } - } - - Ok((local_count - initial_local, local_count)) -} - -// Check if we have things remote doesn't, and if so, upload them -async fn sync_upload( - key: &Key, - _force: bool, - client: &api_client::Client<'_>, - db: &impl Database, -) -> Result<()> { - debug!("starting sync upload"); - - let remote_status = client.status().await?; - let remote_deleted: HashSet = HashSet::from_iter(remote_status.deleted.clone()); - - let initial_remote_count = client.count().await?; - let mut remote_count = initial_remote_count; - - let local_count = db.history_count(true).await?; - - debug!("remote has {remote_count}, we have {local_count}"); - - // first just try the most recent set - let mut cursor = OffsetDateTime::now_utc(); - - while local_count > remote_count { - let last = db.before(cursor, remote_status.page_size).await?; - let mut buffer = Vec::new(); - - if last.is_empty() { - break; - } - - for i in last { - let data = encrypt(&i, key)?; - let data = serde_json::to_string(&data)?; - - let add_hist = AddHistoryRequest { - id: i.id.to_string(), - timestamp: i.timestamp, - data, - hostname: hash_str(&i.hostname), - }; - - buffer.push(add_hist); - } - - // anything left over outside of the 100 block size - client.post_history(&buffer).await?; - cursor = buffer.last().unwrap().timestamp; - remote_count = client.count().await?; - - debug!("upload cursor: {cursor:?}"); - } - - let deleted = db.deleted().await?; - - for i in deleted { - if remote_deleted.contains(&i.id.to_string()) { - continue; - } - - info!("deleting {} on remote", i.id); - client.delete_history(i).await?; - } - - Ok(()) -} - -pub async fn sync(settings: &Settings, force: bool, db: &impl Database) -> Result<()> { - let client = api_client::Client::new( - &settings.sync_address, - settings.sync_auth_token().await?, - settings.network_connect_timeout, - settings.network_timeout, - )?; - - Settings::save_sync_time().await?; - - let key = load_key(settings)?; // encryption key - - sync_upload(&key, force, &client, db).await?; - - let download = sync_download(&key, force, &client, db).await?; - - debug!("sync downloaded {}", download.0); - - Ok(()) -} diff --git a/crates/atuin-client/src/theme.rs b/crates/atuin-client/src/theme.rs deleted file mode 100644 index a277ac13..00000000 --- a/crates/atuin-client/src/theme.rs +++ /dev/null @@ -1,831 +0,0 @@ -use config::{Config, File as ConfigFile, FileFormat}; -use log; -use palette::named; -use serde::{Deserialize, Serialize}; -use serde_json; -use std::collections::HashMap; -use std::error; -use std::io::{Error, ErrorKind}; -use std::path::PathBuf; -use std::sync::LazyLock; -use strum_macros; - -static DEFAULT_MAX_DEPTH: u8 = 10; - -// Collection of settable "meanings" that can have colors set. -// NOTE: You can add a new meaning here without breaking backwards compatibility but please: -// - update the atuin/docs repository, which has a list of available meanings -// - add a fallback in the MEANING_FALLBACKS below, so that themes which do not have it -// get a sensible fallback (see Title as an example) -#[derive( - Serialize, Deserialize, Copy, Clone, Hash, Debug, Eq, PartialEq, strum_macros::Display, -)] -#[strum(serialize_all = "camel_case")] -pub enum Meaning { - AlertInfo, - AlertWarn, - AlertError, - Annotation, - Base, - Guidance, - Important, - Title, - Muted, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct ThemeConfig { - // Definition of the theme - pub theme: ThemeDefinitionConfigBlock, - - // Colors - pub colors: HashMap, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct ThemeDefinitionConfigBlock { - /// Name of theme ("default" for base) - pub name: String, - - /// Whether any theme should be treated as a parent _if available_ - pub parent: Option, -} - -use crossterm::style::{Attribute, Attributes, Color, ContentStyle}; - -// For now, a theme is loaded as a mapping of meanings to colors, but it may be desirable to -// expand that in the future to general styles, so we populate a Meaning->ContentStyle hashmap. -pub struct Theme { - pub name: String, - pub parent: Option, - pub styles: HashMap, -} - -// Themes have a number of convenience functions for the most commonly used meanings. -// The general purpose `as_style` routine gives back a style, but for ease-of-use and to keep -// theme-related boilerplate minimal, the convenience functions give a color. -impl Theme { - // This is the base "default" color, for general text - pub fn get_base(&self) -> ContentStyle { - self.styles[&Meaning::Base] - } - - pub fn get_info(&self) -> ContentStyle { - self.get_alert(log::Level::Info) - } - - pub fn get_warning(&self) -> ContentStyle { - self.get_alert(log::Level::Warn) - } - - pub fn get_error(&self) -> ContentStyle { - self.get_alert(log::Level::Error) - } - - // The alert meanings may be chosen by the Level enum, rather than the methods above - // or the full Meaning enum, to simplify programmatic selection of a log-level. - pub fn get_alert(&self, severity: log::Level) -> ContentStyle { - self.styles[ALERT_TYPES.get(&severity).unwrap()] - } - - pub fn new( - name: String, - parent: Option, - styles: HashMap, - ) -> Theme { - Theme { - name, - parent, - styles, - } - } - - pub fn closest_meaning<'a>(&self, meaning: &'a Meaning) -> &'a Meaning { - if self.styles.contains_key(meaning) { - meaning - } else if MEANING_FALLBACKS.contains_key(meaning) { - self.closest_meaning(&MEANING_FALLBACKS[meaning]) - } else { - &Meaning::Base - } - } - - // General access - if you have a meaning, this will give you a (crossterm) style - pub fn as_style(&self, meaning: Meaning) -> ContentStyle { - self.styles[self.closest_meaning(&meaning)] - } - - // Turns a map of meanings to colornames into a theme - // If theme-debug is on, then we will print any colornames that we cannot load, - // but we do not have this on in general, as it could print unfiltered text to the terminal - // from a theme TOML file. However, it will always return a theme, falling back to - // defaults on error, so that a TOML file does not break loading - pub fn from_foreground_colors( - name: String, - parent: Option<&Theme>, - foreground_colors: HashMap, - debug: bool, - ) -> Theme { - let styles: HashMap = foreground_colors - .iter() - .map(|(name, color)| { - ( - *name, - StyleFactory::from_fg_string(color).unwrap_or_else(|err| { - if debug { - log::warn!("Tried to load string as a color unsuccessfully: ({name}={color}) {err}"); - } - ContentStyle::default() - }), - ) - }) - .collect(); - Theme::from_map(name, parent, &styles) - } - - // Boil down a meaning-color hashmap into a theme, by taking the defaults - // for any unknown colors - fn from_map( - name: String, - parent: Option<&Theme>, - overrides: &HashMap, - ) -> Theme { - let styles = match parent { - Some(theme) => Box::new(theme.styles.clone()), - None => Box::new(DEFAULT_THEME.styles.clone()), - } - .iter() - .map(|(name, color)| match overrides.get(name) { - Some(value) => (*name, *value), - None => (*name, *color), - }) - .collect(); - Theme::new(name, parent.map(|p| p.name.clone()), styles) - } -} - -// Use palette to get a color from a string name, if possible -fn from_string(name: &str) -> Result { - if name.is_empty() { - return Err("Empty string".into()); - } - let first_char = name.chars().next().unwrap(); - match first_char { - '#' => { - let hexcode = &name[1..]; - let vec: Vec = hexcode - .chars() - .collect::>() - .chunks(2) - .map(|pair| u8::from_str_radix(pair.iter().collect::().as_str(), 16)) - .filter_map(|n| n.ok()) - .collect(); - if vec.len() != 3 { - return Err("Could not parse 3 hex values from string".into()); - } - Ok(Color::Rgb { - r: vec[0], - g: vec[1], - b: vec[2], - }) - } - '@' => { - // For full flexibility, we need to use serde_json, given - // crossterm's approach. - serde_json::from_str::(format!("\"{}\"", &name[1..]).as_str()) - .map_err(|_| format!("Could not convert color name {name} to Crossterm color")) - } - _ => { - let srgb = named::from_str(name).ok_or("No such color in palette")?; - Ok(Color::Rgb { - r: srgb.red, - g: srgb.green, - b: srgb.blue, - }) - } - } -} - -pub struct StyleFactory {} - -impl StyleFactory { - fn from_fg_string(name: &str) -> Result { - match from_string(name) { - Ok(color) => Ok(Self::from_fg_color(color)), - Err(err) => Err(err), - } - } - - // For succinctness, if we are confident that the name will be known, - // this routine is available to keep the code readable - fn known_fg_string(name: &str) -> ContentStyle { - Self::from_fg_string(name).unwrap() - } - - fn from_fg_color(color: Color) -> ContentStyle { - ContentStyle { - foreground_color: Some(color), - ..ContentStyle::default() - } - } - - fn from_fg_color_and_attributes(color: Color, attributes: Attributes) -> ContentStyle { - ContentStyle { - foreground_color: Some(color), - attributes, - ..ContentStyle::default() - } - } -} - -// Built-in themes. Rather than having extra files added before any theming -// is available, this gives a couple of basic options, demonstrating the use -// of themes: autumn and marine -static ALERT_TYPES: LazyLock> = LazyLock::new(|| { - HashMap::from([ - (log::Level::Info, Meaning::AlertInfo), - (log::Level::Warn, Meaning::AlertWarn), - (log::Level::Error, Meaning::AlertError), - ]) -}); - -static MEANING_FALLBACKS: LazyLock> = LazyLock::new(|| { - HashMap::from([ - (Meaning::Guidance, Meaning::AlertInfo), - (Meaning::Annotation, Meaning::AlertInfo), - (Meaning::Title, Meaning::Important), - ]) -}); - -static DEFAULT_THEME: LazyLock = LazyLock::new(|| { - Theme::new( - "default".to_string(), - None, - HashMap::from([ - ( - Meaning::AlertError, - StyleFactory::from_fg_color(Color::DarkRed), - ), - ( - Meaning::AlertWarn, - StyleFactory::from_fg_color(Color::DarkYellow), - ), - ( - Meaning::AlertInfo, - StyleFactory::from_fg_color(Color::DarkGreen), - ), - ( - Meaning::Annotation, - StyleFactory::from_fg_color(Color::DarkGrey), - ), - ( - Meaning::Guidance, - StyleFactory::from_fg_color(Color::DarkBlue), - ), - ( - Meaning::Important, - StyleFactory::from_fg_color_and_attributes( - Color::White, - Attributes::from(Attribute::Bold), - ), - ), - (Meaning::Muted, StyleFactory::from_fg_color(Color::Grey)), - (Meaning::Base, ContentStyle::default()), - ]), - ) -}); - -static BUILTIN_THEMES: LazyLock> = LazyLock::new(|| { - HashMap::from([ - ("default", HashMap::new()), - ( - "(none)", - HashMap::from([ - (Meaning::AlertError, ContentStyle::default()), - (Meaning::AlertWarn, ContentStyle::default()), - (Meaning::AlertInfo, ContentStyle::default()), - (Meaning::Annotation, ContentStyle::default()), - (Meaning::Guidance, ContentStyle::default()), - (Meaning::Important, ContentStyle::default()), - (Meaning::Muted, ContentStyle::default()), - (Meaning::Base, ContentStyle::default()), - ]), - ), - ( - "autumn", - HashMap::from([ - ( - Meaning::AlertError, - StyleFactory::known_fg_string("saddlebrown"), - ), - ( - Meaning::AlertWarn, - StyleFactory::known_fg_string("darkorange"), - ), - (Meaning::AlertInfo, StyleFactory::known_fg_string("gold")), - ( - Meaning::Annotation, - StyleFactory::from_fg_color(Color::DarkGrey), - ), - (Meaning::Guidance, StyleFactory::known_fg_string("brown")), - ]), - ), - ( - "marine", - HashMap::from([ - ( - Meaning::AlertError, - StyleFactory::known_fg_string("yellowgreen"), - ), - (Meaning::AlertWarn, StyleFactory::known_fg_string("cyan")), - ( - Meaning::AlertInfo, - StyleFactory::known_fg_string("turquoise"), - ), - ( - Meaning::Annotation, - StyleFactory::known_fg_string("steelblue"), - ), - ( - Meaning::Base, - StyleFactory::known_fg_string("lightsteelblue"), - ), - (Meaning::Guidance, StyleFactory::known_fg_string("teal")), - ]), - ), - ]) - .iter() - .map(|(name, theme)| (*name, Theme::from_map(name.to_string(), None, theme))) - .collect() -}); - -// To avoid themes being repeatedly loaded, we store them in a theme manager -pub struct ThemeManager { - loaded_themes: HashMap, - debug: bool, - override_theme_dir: Option, -} - -// Theme-loading logic -impl ThemeManager { - pub fn new(debug: Option, theme_dir: Option) -> Self { - Self { - loaded_themes: HashMap::new(), - debug: debug.unwrap_or(false), - override_theme_dir: match theme_dir { - Some(theme_dir) => Some(theme_dir), - None => std::env::var("ATUIN_THEME_DIR").ok(), - }, - } - } - - // Try to load a theme from a `{name}.toml` file in the theme directory. If an override is set - // for the theme dir (via ATUIN_THEME_DIR env) we should load the theme from there - pub fn load_theme_from_file( - &mut self, - name: &str, - max_depth: u8, - ) -> Result<&Theme, Box> { - let mut theme_file = if let Some(p) = &self.override_theme_dir { - if p.is_empty() { - return Err(Box::new(Error::new( - ErrorKind::NotFound, - "Empty theme directory override and could not find theme elsewhere", - ))); - } - PathBuf::from(p) - } else { - let config_dir = atuin_common::utils::config_dir(); - let mut theme_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { - PathBuf::from(p) - } else { - let mut theme_file = PathBuf::new(); - theme_file.push(config_dir); - theme_file - }; - theme_file.push("themes"); - theme_file - }; - - let theme_toml = format!["{name}.toml"]; - theme_file.push(theme_toml); - - let mut config_builder = Config::builder(); - - config_builder = config_builder.add_source(ConfigFile::new( - theme_file.to_str().unwrap(), - FileFormat::Toml, - )); - - let config = config_builder.build()?; - self.load_theme_from_config(name, config, max_depth) - } - - pub fn load_theme_from_config( - &mut self, - name: &str, - config: Config, - max_depth: u8, - ) -> Result<&Theme, Box> { - let debug = self.debug; - let theme_config: ThemeConfig = match config.try_deserialize() { - Ok(tc) => tc, - Err(e) => { - return Err(Box::new(Error::new( - ErrorKind::InvalidInput, - format!( - "Failed to deserialize theme: {}", - if debug { - e.to_string() - } else { - "set theme debug on for more info".to_string() - } - ), - ))); - } - }; - let colors: HashMap = theme_config.colors; - let parent: Option<&Theme> = match theme_config.theme.parent { - Some(parent_name) => { - if max_depth == 0 { - return Err(Box::new(Error::new( - ErrorKind::InvalidInput, - "Parent requested but we hit the recursion limit", - ))); - } - Some(self.load_theme(parent_name.as_str(), Some(max_depth - 1))) - } - None => Some(self.load_theme("default", Some(max_depth - 1))), - }; - - if debug && name != theme_config.theme.name { - log::warn!( - "Your theme config name is not the name of your loaded theme {} != {}", - name, - theme_config.theme.name - ); - } - - let theme = Theme::from_foreground_colors(theme_config.theme.name, parent, colors, debug); - let name = name.to_string(); - self.loaded_themes.insert(name.clone(), theme); - let theme = self.loaded_themes.get(&name).unwrap(); - Ok(theme) - } - - // Check if the requested theme is loaded and, if not, then attempt to get it - // from the builtins or, if not there, from file - pub fn load_theme(&mut self, name: &str, max_depth: Option) -> &Theme { - if self.loaded_themes.contains_key(name) { - return self.loaded_themes.get(name).unwrap(); - } - let built_ins = &BUILTIN_THEMES; - match built_ins.get(name) { - Some(theme) => theme, - None => match self.load_theme_from_file(name, max_depth.unwrap_or(DEFAULT_MAX_DEPTH)) { - Ok(theme) => theme, - Err(err) => { - log::warn!("Could not load theme {name}: {err}"); - built_ins.get("(none)").unwrap() - } - }, - } - } -} - -#[cfg(test)] -mod theme_tests { - use super::*; - - #[test] - fn test_can_load_builtin_theme() { - let mut manager = ThemeManager::new(Some(false), Some("".to_string())); - let theme = manager.load_theme("autumn", None); - assert_eq!( - theme.as_style(Meaning::Guidance).foreground_color, - from_string("brown").ok() - ); - } - - #[test] - fn test_can_create_theme() { - let mut manager = ThemeManager::new(Some(false), Some("".to_string())); - let mytheme = Theme::new( - "mytheme".to_string(), - None, - HashMap::from([( - Meaning::AlertError, - StyleFactory::known_fg_string("yellowgreen"), - )]), - ); - manager.loaded_themes.insert("mytheme".to_string(), mytheme); - let theme = manager.load_theme("mytheme", None); - assert_eq!( - theme.as_style(Meaning::AlertError).foreground_color, - from_string("yellowgreen").ok() - ); - } - - #[test] - fn test_can_fallback_when_meaning_missing() { - let mut manager = ThemeManager::new(Some(false), Some("".to_string())); - - // We use title as an example of a meaning that is not defined - // even in the base theme. - assert!(!DEFAULT_THEME.styles.contains_key(&Meaning::Title)); - - let config = Config::builder() - .add_source(ConfigFile::from_str( - " - [theme] - name = \"title_theme\" - - [colors] - Guidance = \"white\" - AlertInfo = \"zomp\" - ", - FileFormat::Toml, - )) - .build() - .unwrap(); - let theme = manager - .load_theme_from_config("config_theme", config, 1) - .unwrap(); - - // Correctly picks overridden color. - assert_eq!( - theme.as_style(Meaning::Guidance).foreground_color, - from_string("white").ok() - ); - - // Does not fall back to any color. - assert_eq!(theme.as_style(Meaning::AlertInfo).foreground_color, None); - - // Even for the base. - assert_eq!(theme.as_style(Meaning::Base).foreground_color, None); - - // Falls back to red as meaning missing from theme, so picks base default. - assert_eq!( - theme.as_style(Meaning::AlertError).foreground_color, - Some(Color::DarkRed) - ); - - // Falls back to Important as Title not available. - assert_eq!( - theme.as_style(Meaning::Title).foreground_color, - theme.as_style(Meaning::Important).foreground_color, - ); - - let title_config = Config::builder() - .add_source(ConfigFile::from_str( - " - [theme] - name = \"title_theme\" - - [colors] - Title = \"white\" - AlertInfo = \"zomp\" - ", - FileFormat::Toml, - )) - .build() - .unwrap(); - let title_theme = manager - .load_theme_from_config("title_theme", title_config, 1) - .unwrap(); - - assert_eq!( - title_theme.as_style(Meaning::Title).foreground_color, - Some(Color::White) - ); - } - - #[test] - fn test_no_fallbacks_are_circular() { - let mytheme = Theme::new("mytheme".to_string(), None, HashMap::from([])); - MEANING_FALLBACKS - .iter() - .for_each(|pair| assert_eq!(mytheme.closest_meaning(pair.0), &Meaning::Base)) - } - - #[test] - fn test_can_get_colors_via_convenience_functions() { - let mut manager = ThemeManager::new(Some(true), Some("".to_string())); - let theme = manager.load_theme("default", None); - assert_eq!(theme.get_error().foreground_color.unwrap(), Color::DarkRed); - assert_eq!( - theme.get_warning().foreground_color.unwrap(), - Color::DarkYellow - ); - assert_eq!(theme.get_info().foreground_color.unwrap(), Color::DarkGreen); - assert_eq!(theme.get_base().foreground_color, None); - assert_eq!( - theme.get_alert(log::Level::Error).foreground_color.unwrap(), - Color::DarkRed - ) - } - - #[test] - fn test_can_use_parent_theme_for_fallbacks() { - testing_logger::setup(); - - let mut manager = ThemeManager::new(Some(false), Some("".to_string())); - - // First, we introduce a base theme - let solarized = Config::builder() - .add_source(ConfigFile::from_str( - " - [theme] - name = \"solarized\" - - [colors] - Guidance = \"white\" - AlertInfo = \"pink\" - ", - FileFormat::Toml, - )) - .build() - .unwrap(); - let solarized_theme = manager - .load_theme_from_config("solarized", solarized, 1) - .unwrap(); - - assert_eq!( - solarized_theme - .as_style(Meaning::AlertInfo) - .foreground_color, - from_string("pink").ok() - ); - - // Then we introduce a derived theme - let unsolarized = Config::builder() - .add_source(ConfigFile::from_str( - " - [theme] - name = \"unsolarized\" - parent = \"solarized\" - - [colors] - AlertInfo = \"red\" - ", - FileFormat::Toml, - )) - .build() - .unwrap(); - let unsolarized_theme = manager - .load_theme_from_config("unsolarized", unsolarized, 1) - .unwrap(); - - // It will take its own values - assert_eq!( - unsolarized_theme - .as_style(Meaning::AlertInfo) - .foreground_color, - from_string("red").ok() - ); - - // ...or fall back to the parent - assert_eq!( - unsolarized_theme - .as_style(Meaning::Guidance) - .foreground_color, - from_string("white").ok() - ); - - testing_logger::validate(|captured_logs| assert_eq!(captured_logs.len(), 0)); - - // If the parent is not found, we end up with the no theme colors or styling - // as this is considered a (soft) error state. - let nunsolarized = Config::builder() - .add_source(ConfigFile::from_str( - " - [theme] - name = \"nunsolarized\" - parent = \"nonsolarized\" - - [colors] - AlertInfo = \"red\" - ", - FileFormat::Toml, - )) - .build() - .unwrap(); - let nunsolarized_theme = manager - .load_theme_from_config("nunsolarized", nunsolarized, 1) - .unwrap(); - - assert_eq!( - nunsolarized_theme - .as_style(Meaning::Guidance) - .foreground_color, - None - ); - - testing_logger::validate(|captured_logs| { - assert_eq!(captured_logs.len(), 1); - assert_eq!( - captured_logs[0].body, - "Could not load theme nonsolarized: Empty theme directory override and could not find theme elsewhere" - ); - assert_eq!(captured_logs[0].level, log::Level::Warn) - }); - } - - #[test] - fn test_can_debug_theme() { - testing_logger::setup(); - [true, false].iter().for_each(|debug| { - let mut manager = ThemeManager::new(Some(*debug), Some("".to_string())); - let config = Config::builder() - .add_source(ConfigFile::from_str( - " - [theme] - name = \"mytheme\" - - [colors] - Guidance = \"white\" - AlertInfo = \"xinetic\" - ", - FileFormat::Toml, - )) - .build() - .unwrap(); - manager - .load_theme_from_config("config_theme", config, 1) - .unwrap(); - testing_logger::validate(|captured_logs| { - if *debug { - assert_eq!(captured_logs.len(), 2); - assert_eq!( - captured_logs[0].body, - "Your theme config name is not the name of your loaded theme config_theme != mytheme" - ); - assert_eq!(captured_logs[0].level, log::Level::Warn); - assert_eq!( - captured_logs[1].body, - "Tried to load string as a color unsuccessfully: (AlertInfo=xinetic) No such color in palette" - ); - assert_eq!(captured_logs[1].level, log::Level::Warn) - } else { - assert_eq!(captured_logs.len(), 0) - } - }) - }) - } - - #[test] - fn test_can_parse_color_strings_correctly() { - assert_eq!( - from_string("brown").unwrap(), - Color::Rgb { - r: 165, - g: 42, - b: 42 - } - ); - - assert_eq!(from_string(""), Err("Empty string".into())); - - ["manatee", "caput mortuum", "123456"] - .iter() - .for_each(|inp| { - assert_eq!(from_string(inp), Err("No such color in palette".into())); - }); - - assert_eq!( - from_string("#ff1122").unwrap(), - Color::Rgb { - r: 255, - g: 17, - b: 34 - } - ); - ["#1122", "#ffaa112", "#brown"].iter().for_each(|inp| { - assert_eq!( - from_string(inp), - Err("Could not parse 3 hex values from string".into()) - ); - }); - - assert_eq!(from_string("@dark_grey").unwrap(), Color::DarkGrey); - assert_eq!( - from_string("@rgb_(255,255,255)").unwrap(), - Color::Rgb { - r: 255, - g: 255, - b: 255 - } - ); - assert_eq!(from_string("@ansi_(255)").unwrap(), Color::AnsiValue(255)); - ["@", "@DarkGray", "@Dark 4ay", "@ansi(256)"] - .iter() - .for_each(|inp| { - assert_eq!( - from_string(inp), - Err(format!( - "Could not convert color name {inp} to Crossterm color" - )) - ); - }); - } -} diff --git a/crates/atuin-client/src/utils.rs b/crates/atuin-client/src/utils.rs deleted file mode 100644 index 35d7db26..00000000 --- a/crates/atuin-client/src/utils.rs +++ /dev/null @@ -1,14 +0,0 @@ -pub(crate) fn get_hostname() -> String { - std::env::var("ATUIN_HOST_NAME") - .unwrap_or_else(|_| whoami::hostname().unwrap_or_else(|_| "unknown-host".to_string())) -} - -pub(crate) fn get_username() -> String { - std::env::var("ATUIN_HOST_USER") - .unwrap_or_else(|_| whoami::username().unwrap_or_else(|_| "unknown-user".to_string())) -} - -/// Returns a pair of the hostname and username, separated by a colon. -pub(crate) fn get_host_user() -> String { - format!("{}:{}", get_hostname(), get_username()) -} diff --git a/crates/atuin-client/tests/data/xonsh-history.sqlite b/crates/atuin-client/tests/data/xonsh-history.sqlite deleted file mode 100644 index 744fcf86..00000000 Binary files a/crates/atuin-client/tests/data/xonsh-history.sqlite and /dev/null differ diff --git a/crates/atuin-client/tests/data/xonsh/xonsh-82eafbf5-9f43-489a-80d2-61c7dc6ef542.json b/crates/atuin-client/tests/data/xonsh/xonsh-82eafbf5-9f43-489a-80d2-61c7dc6ef542.json deleted file mode 100644 index 339a09f1..00000000 --- a/crates/atuin-client/tests/data/xonsh/xonsh-82eafbf5-9f43-489a-80d2-61c7dc6ef542.json +++ /dev/null @@ -1,12 +0,0 @@ -{"locs": [ 69, 3371, 3451, 3978], - "index": {"offsets":{"__total__":0,"cmds":[{"__total__":10,"cwd":18,"inp":78,"rtn":96,"ts":[106,125,105]},{"__total__":149,"cwd":157,"inp":217,"rtn":234,"ts":[244,263,243]},9],"env":{"ATUIN_SESSION":314,"BASH_COMPLETIONS":370,"COLORTERM":433,"DBUS_SESSION_BUS_ADDRESS":474,"DESKTOP_SESSION":529,"DISPLAY":550,"GDMSESSION":570,"GIO_LAUNCHED_DESKTOP_FILE":609,"GIO_LAUNCHED_DESKTOP_FILE_PID":704,"GJS_DEBUG_OUTPUT":734,"GJS_DEBUG_TOPICS":764,"GNOME_DESKTOP_SESSION_ID":811,"GNOME_SETUP_DISPLAY":856,"GNOME_SHELL_SESSION_MODE":890,"GTK_MODULES":915,"HOME":942,"IM_CONFIG_PHASE":976,"INVOCATION_ID":998,"JOURNAL_STREAM":1052,"LANG":1071,"LOGNAME":1097,"MANAGERPID":1118,"MOZ_ENABLE_WAYLAND":1148,"PATH":1161,"PWD":1736,"PYENV_DIR":1802,"PYENV_HOOK_PATH":1874,"PYENV_ROOT":2048,"PYENV_SHELL":2086,"PYENV_VERSION":2111,"QT_ACCESSIBILITY":2141,"QT_IM_MODULE":2162,"SESSION_MANAGER":2189,"SHELL":2279,"SHLVL":2303,"SSH_AGENT_LAUNCHER":2330,"SSH_AUTH_SOCK":2364,"SSL_CERT_DIR":2415,"SSL_CERT_FILE":2458,"SYSTEMD_EXEC_PID":2525,"TERM":2541,"TERM_PROGRAM":2575,"TERM_PROGRAM_VERSION":2610,"THREAD_SUBPROCS":2657,"USER":2670,"USERNAME":2689,"WAYLAND_DISPLAY":2715,"WEZTERM_CONFIG_DIR":2750,"WEZTERM_CONFIG_FILE":2806,"WEZTERM_EXECUTABLE":2874,"WEZTERM_EXECUTABLE_DIR":2927,"WEZTERM_PANE":2957,"WEZTERM_UNIX_SOCKET":2986,"XAUTHORITY":3047,"XDG_CONFIG_DIRS":3116,"XDG_CURRENT_DESKTOP":3176,"XDG_DATA_DIRS":3209,"XDG_MENU_PREFIX":3316,"XDG_RUNTIME_DIR":3345,"XDG_SESSION_CLASS":3387,"XDG_SESSION_DESKTOP":3418,"XDG_SESSION_TYPE":3448,"XMODIFIERS":3473,"XONSHRC":3496,"XONSHRC_DIR":3594,"XONSH_CAPTURE_ALWAYS":3674,"XONSH_CONFIG_DIR":3698,"XONSH_DATA_DIR":3747,"XONSH_INTERACTIVE":3805,"XONSH_LOGIN":3825,"XONSH_VERSION":3847,"__total__":296},"locked":3869,"sessionid":3889,"ts":[3936,3956,3935]},"sizes":{"__total__":3978,"cmds":[{"__total__":137,"cwd":51,"inp":9,"rtn":1,"ts":[17,18,40]},{"__total__":136,"cwd":51,"inp":8,"rtn":1,"ts":[17,18,40]},278],"env":{"ATUIN_SESSION":34,"BASH_COMPLETIONS":48,"COLORTERM":11,"DBUS_SESSION_BUS_ADDRESS":34,"DESKTOP_SESSION":8,"DISPLAY":4,"GDMSESSION":8,"GIO_LAUNCHED_DESKTOP_FILE":60,"GIO_LAUNCHED_DESKTOP_FILE_PID":8,"GJS_DEBUG_OUTPUT":8,"GJS_DEBUG_TOPICS":17,"GNOME_DESKTOP_SESSION_ID":20,"GNOME_SETUP_DISPLAY":4,"GNOME_SHELL_SESSION_MODE":8,"GTK_MODULES":17,"HOME":13,"IM_CONFIG_PHASE":3,"INVOCATION_ID":34,"JOURNAL_STREAM":9,"LANG":13,"LOGNAME":5,"MANAGERPID":6,"MOZ_ENABLE_WAYLAND":3,"PATH":566,"PWD":51,"PYENV_DIR":51,"PYENV_HOOK_PATH":158,"PYENV_ROOT":21,"PYENV_SHELL":6,"PYENV_VERSION":8,"QT_ACCESSIBILITY":3,"QT_IM_MODULE":6,"SESSION_MANAGER":79,"SHELL":13,"SHLVL":3,"SSH_AGENT_LAUNCHER":15,"SSH_AUTH_SOCK":33,"SSL_CERT_DIR":24,"SSL_CERT_FILE":45,"SYSTEMD_EXEC_PID":6,"TERM":16,"TERM_PROGRAM":9,"TERM_PROGRAM_VERSION":26,"THREAD_SUBPROCS":3,"USER":5,"USERNAME":5,"WAYLAND_DISPLAY":11,"WEZTERM_CONFIG_DIR":31,"WEZTERM_CONFIG_FILE":44,"WEZTERM_EXECUTABLE":25,"WEZTERM_EXECUTABLE_DIR":12,"WEZTERM_PANE":4,"WEZTERM_UNIX_SOCKET":45,"XAUTHORITY":48,"XDG_CONFIG_DIRS":35,"XDG_CURRENT_DESKTOP":14,"XDG_DATA_DIRS":86,"XDG_MENU_PREFIX":8,"XDG_RUNTIME_DIR":19,"XDG_SESSION_CLASS":6,"XDG_SESSION_DESKTOP":8,"XDG_SESSION_TYPE":9,"XMODIFIERS":10,"XONSHRC":81,"XONSHRC_DIR":54,"XONSH_CAPTURE_ALWAYS":2,"XONSH_CONFIG_DIR":29,"XONSH_DATA_DIR":35,"XONSH_INTERACTIVE":3,"XONSH_LOGIN":3,"XONSH_VERSION":8,"__total__":3561},"locked":5,"sessionid":38,"ts":[18,18,41]}}, - "data": {"cmds": [{"cwd": "\/home\/user\/Documents\/code\/atuin\/atuin-client", "inp": "false\n", "rtn": 1, "ts": [1707241291.142516, 1707241291.1527853] -} -, {"cwd": "\/home\/user\/Documents\/code\/atuin\/atuin-client", "inp": "exit\n", "rtn": 0, "ts": [1707241292.271584, 1707241292.2758434] -} -] -, "env": {"ATUIN_SESSION": "018d7f82ad167dc4888ca0bf294d2bfd", "BASH_COMPLETIONS": "\/usr\/share\/bash-completion\/bash_completion", "COLORTERM": "truecolor", "DBUS_SESSION_BUS_ADDRESS": "unix:path=\/run\/user\/1000\/bus", "DESKTOP_SESSION": "ubuntu", "DISPLAY": ":0", "GDMSESSION": "ubuntu", "GIO_LAUNCHED_DESKTOP_FILE": "\/usr\/share\/applications\/org.wezfurlong.wezterm.desktop", "GIO_LAUNCHED_DESKTOP_FILE_PID": "196859", "GJS_DEBUG_OUTPUT": "stderr", "GJS_DEBUG_TOPICS": "JS ERROR;JS LOG", "GNOME_DESKTOP_SESSION_ID": "this-is-deprecated", "GNOME_SETUP_DISPLAY": ":1", "GNOME_SHELL_SESSION_MODE": "ubuntu", "GTK_MODULES": "gail:atk-bridge", "HOME": "\/home\/user", "IM_CONFIG_PHASE": "1", "INVOCATION_ID": "4f121e7ad56c41a6b84aa3cbe1ad61fa", "JOURNAL_STREAM": "8:37187", "LANG": "en_US.UTF-8", "LOGNAME": "user", "MANAGERPID": "2118", "MOZ_ENABLE_WAYLAND": "1", "PATH": "\/home\/user\/.pyenv\/versions\/3.12.0\/bin:\/home\/user\/.pyenv\/libexec:\/home\/user\/.pyenv\/plugins\/python-build\/bin:\/home\/user\/.pyenv\/plugins\/pyenv-virtualenv\/bin:\/home\/user\/.pyenv\/plugins\/pyenv-update\/bin:\/home\/user\/.pyenv\/plugins\/pyenv-doctor\/bin:\/home\/user\/.cargo\/bin:\/home\/user\/.pyenv\/shims:\/home\/user\/.pyenv\/bin:\/home\/user\/bin:\/home\/user\/bin:\/usr\/local\/sbin:\/usr\/local\/bin:\/usr\/sbin:\/usr\/bin:\/sbin:\/bin:\/usr\/games:\/usr\/local\/games:\/snap\/bin:\/snap\/bin:\/home\/user\/.local\/share\/JetBrains\/Toolbox\/scripts", "PWD": "\/home\/user\/Documents\/code\/atuin\/atuin-client", "PYENV_DIR": "\/home\/user\/Documents\/code\/atuin\/atuin-client", "PYENV_HOOK_PATH": "\/home\/user\/.pyenv\/pyenv.d:\/usr\/local\/etc\/pyenv.d:\/etc\/pyenv.d:\/usr\/lib\/pyenv\/hooks:\/home\/user\/.pyenv\/plugins\/pyenv-virtualenv\/etc\/pyenv.d", "PYENV_ROOT": "\/home\/user\/.pyenv", "PYENV_SHELL": "bash", "PYENV_VERSION": "3.12.0", "QT_ACCESSIBILITY": "1", "QT_IM_MODULE": "ibus", "SESSION_MANAGER": "local\/box:@\/tmp\/.ICE-unix\/2452,unix\/box:\/tmp\/.ICE-unix\/2452", "SHELL": "\/bin\/bash", "SHLVL": "1", "SSH_AGENT_LAUNCHER": "gnome-keyring", "SSH_AUTH_SOCK": "\/run\/user\/1000\/keyring\/ssh", "SSL_CERT_DIR": "\/usr\/lib\/ssl\/certs", "SSL_CERT_FILE": "\/usr\/lib\/ssl\/certs\/ca-certificates.crt", "SYSTEMD_EXEC_PID": "2470", "TERM": "xterm-256color", "TERM_PROGRAM": "WezTerm", "TERM_PROGRAM_VERSION": "20240127-113634-bbcac864", "THREAD_SUBPROCS": "1", "USER": "user", "USERNAME": "user", "WAYLAND_DISPLAY": "wayland-0", "WEZTERM_CONFIG_DIR": "\/home\/user\/.config\/wezterm", "WEZTERM_CONFIG_FILE": "\/home\/user\/.config\/wezterm\/wezterm.lua", "WEZTERM_EXECUTABLE": "\/usr\/bin\/wezterm-gui", "WEZTERM_EXECUTABLE_DIR": "\/usr\/bin", "WEZTERM_PANE": "41", "WEZTERM_UNIX_SOCKET": "\/run\/user\/1000\/wezterm\/gui-sock-196859", "XAUTHORITY": "\/run\/user\/1000\/.mutter-Xwaylandauth.T986H2", "XDG_CONFIG_DIRS": "\/etc\/xdg\/xdg-ubuntu:\/etc\/xdg", "XDG_CURRENT_DESKTOP": "ubuntu:GNOME", "XDG_DATA_DIRS": "\/usr\/share\/ubuntu:\/usr\/local\/share\/:\/usr\/share\/:\/var\/lib\/snapd\/desktop", "XDG_MENU_PREFIX": "gnome-", "XDG_RUNTIME_DIR": "\/run\/user\/1000", "XDG_SESSION_CLASS": "user", "XDG_SESSION_DESKTOP": "ubuntu", "XDG_SESSION_TYPE": "wayland", "XMODIFIERS": "@im=ibus", "XONSHRC": "\/etc\/xonsh\/xonshrc:\/home\/user\/.config\/xonsh\/rc.xsh:\/home\/user\/.xonshrc", "XONSHRC_DIR": "\/etc\/xonsh\/rc.d:\/home\/user\/.config\/xonsh\/rc.d", "XONSH_CAPTURE_ALWAYS": "", "XONSH_CONFIG_DIR": "\/home\/user\/.config\/xonsh", "XONSH_DATA_DIR": "\/home\/user\/.local\/share\/xonsh", "XONSH_INTERACTIVE": "1", "XONSH_LOGIN": "1", "XONSH_VERSION": "0.14.2"} -, "locked": false, "sessionid": "82eafbf5-9f43-489a-80d2-61c7dc6ef542", "ts": [1707241286.9361255, 1707241292.3081477] -} - -} diff --git a/crates/atuin-client/tests/data/xonsh/xonsh-de16af90-9148-4461-8df3-5b5659c6420d.json b/crates/atuin-client/tests/data/xonsh/xonsh-de16af90-9148-4461-8df3-5b5659c6420d.json deleted file mode 100644 index 72694f04..00000000 --- a/crates/atuin-client/tests/data/xonsh/xonsh-de16af90-9148-4461-8df3-5b5659c6420d.json +++ /dev/null @@ -1,12 +0,0 @@ -{"locs": [ 69, 3372, 3452, 3936], - "index": {"offsets":{"__total__":0,"cmds":[{"__total__":10,"cwd":18,"inp":64,"rtn":94,"ts":[104,124,103]},{"__total__":148,"cwd":156,"inp":202,"rtn":220,"ts":[230,250,229]},9],"env":{"ATUIN_SESSION":300,"BASH_COMPLETIONS":356,"COLORTERM":419,"DBUS_SESSION_BUS_ADDRESS":460,"DESKTOP_SESSION":515,"DISPLAY":536,"GDMSESSION":556,"GIO_LAUNCHED_DESKTOP_FILE":595,"GIO_LAUNCHED_DESKTOP_FILE_PID":690,"GJS_DEBUG_OUTPUT":720,"GJS_DEBUG_TOPICS":750,"GNOME_DESKTOP_SESSION_ID":797,"GNOME_SETUP_DISPLAY":842,"GNOME_SHELL_SESSION_MODE":876,"GTK_MODULES":901,"HOME":928,"IM_CONFIG_PHASE":962,"INVOCATION_ID":984,"JOURNAL_STREAM":1038,"LANG":1057,"LOGNAME":1083,"MANAGERPID":1104,"MOZ_ENABLE_WAYLAND":1134,"PATH":1147,"PWD":1722,"PYENV_DIR":1774,"PYENV_HOOK_PATH":1832,"PYENV_ROOT":2006,"PYENV_SHELL":2044,"PYENV_VERSION":2069,"QT_ACCESSIBILITY":2099,"QT_IM_MODULE":2120,"SESSION_MANAGER":2147,"SHELL":2237,"SHLVL":2261,"SSH_AGENT_LAUNCHER":2288,"SSH_AUTH_SOCK":2322,"SSL_CERT_DIR":2373,"SSL_CERT_FILE":2416,"SYSTEMD_EXEC_PID":2483,"TERM":2499,"TERM_PROGRAM":2533,"TERM_PROGRAM_VERSION":2568,"THREAD_SUBPROCS":2615,"USER":2628,"USERNAME":2647,"WAYLAND_DISPLAY":2673,"WEZTERM_CONFIG_DIR":2708,"WEZTERM_CONFIG_FILE":2764,"WEZTERM_EXECUTABLE":2832,"WEZTERM_EXECUTABLE_DIR":2885,"WEZTERM_PANE":2915,"WEZTERM_UNIX_SOCKET":2944,"XAUTHORITY":3005,"XDG_CONFIG_DIRS":3074,"XDG_CURRENT_DESKTOP":3134,"XDG_DATA_DIRS":3167,"XDG_MENU_PREFIX":3274,"XDG_RUNTIME_DIR":3303,"XDG_SESSION_CLASS":3345,"XDG_SESSION_DESKTOP":3376,"XDG_SESSION_TYPE":3406,"XMODIFIERS":3431,"XONSHRC":3454,"XONSHRC_DIR":3552,"XONSH_CAPTURE_ALWAYS":3632,"XONSH_CONFIG_DIR":3656,"XONSH_DATA_DIR":3705,"XONSH_INTERACTIVE":3763,"XONSH_LOGIN":3783,"XONSH_VERSION":3805,"__total__":282},"locked":3827,"sessionid":3847,"ts":[3894,3914,3893]},"sizes":{"__total__":3936,"cmds":[{"__total__":136,"cwd":37,"inp":21,"rtn":1,"ts":[18,18,41]},{"__total__":123,"cwd":37,"inp":9,"rtn":1,"ts":[18,17,40]},264],"env":{"ATUIN_SESSION":34,"BASH_COMPLETIONS":48,"COLORTERM":11,"DBUS_SESSION_BUS_ADDRESS":34,"DESKTOP_SESSION":8,"DISPLAY":4,"GDMSESSION":8,"GIO_LAUNCHED_DESKTOP_FILE":60,"GIO_LAUNCHED_DESKTOP_FILE_PID":8,"GJS_DEBUG_OUTPUT":8,"GJS_DEBUG_TOPICS":17,"GNOME_DESKTOP_SESSION_ID":20,"GNOME_SETUP_DISPLAY":4,"GNOME_SHELL_SESSION_MODE":8,"GTK_MODULES":17,"HOME":13,"IM_CONFIG_PHASE":3,"INVOCATION_ID":34,"JOURNAL_STREAM":9,"LANG":13,"LOGNAME":5,"MANAGERPID":6,"MOZ_ENABLE_WAYLAND":3,"PATH":566,"PWD":37,"PYENV_DIR":37,"PYENV_HOOK_PATH":158,"PYENV_ROOT":21,"PYENV_SHELL":6,"PYENV_VERSION":8,"QT_ACCESSIBILITY":3,"QT_IM_MODULE":6,"SESSION_MANAGER":79,"SHELL":13,"SHLVL":3,"SSH_AGENT_LAUNCHER":15,"SSH_AUTH_SOCK":33,"SSL_CERT_DIR":24,"SSL_CERT_FILE":45,"SYSTEMD_EXEC_PID":6,"TERM":16,"TERM_PROGRAM":9,"TERM_PROGRAM_VERSION":26,"THREAD_SUBPROCS":3,"USER":5,"USERNAME":5,"WAYLAND_DISPLAY":11,"WEZTERM_CONFIG_DIR":31,"WEZTERM_CONFIG_FILE":44,"WEZTERM_EXECUTABLE":25,"WEZTERM_EXECUTABLE_DIR":12,"WEZTERM_PANE":4,"WEZTERM_UNIX_SOCKET":45,"XAUTHORITY":48,"XDG_CONFIG_DIRS":35,"XDG_CURRENT_DESKTOP":14,"XDG_DATA_DIRS":86,"XDG_MENU_PREFIX":8,"XDG_RUNTIME_DIR":19,"XDG_SESSION_CLASS":6,"XDG_SESSION_DESKTOP":8,"XDG_SESSION_TYPE":9,"XMODIFIERS":10,"XONSHRC":81,"XONSHRC_DIR":54,"XONSH_CAPTURE_ALWAYS":2,"XONSH_CONFIG_DIR":29,"XONSH_DATA_DIR":35,"XONSH_INTERACTIVE":3,"XONSH_LOGIN":3,"XONSH_VERSION":8,"__total__":3533},"locked":5,"sessionid":38,"ts":[18,18,41]}}, - "data": {"cmds": [{"cwd": "\/home\/user\/Documents\/code\/atuin", "inp": "echo hello world!\n", "rtn": 0, "ts": [1707193079.4782722, 1707193079.4829233] -} -, {"cwd": "\/home\/user\/Documents\/code\/atuin", "inp": "ls -l\n", "rtn": 0, "ts": [1707193081.7063284, 1707193081.727617] -} -] -, "env": {"ATUIN_SESSION": "018d7ca2e953742e9826012f30115040", "BASH_COMPLETIONS": "\/usr\/share\/bash-completion\/bash_completion", "COLORTERM": "truecolor", "DBUS_SESSION_BUS_ADDRESS": "unix:path=\/run\/user\/1000\/bus", "DESKTOP_SESSION": "ubuntu", "DISPLAY": ":0", "GDMSESSION": "ubuntu", "GIO_LAUNCHED_DESKTOP_FILE": "\/usr\/share\/applications\/org.wezfurlong.wezterm.desktop", "GIO_LAUNCHED_DESKTOP_FILE_PID": "196859", "GJS_DEBUG_OUTPUT": "stderr", "GJS_DEBUG_TOPICS": "JS ERROR;JS LOG", "GNOME_DESKTOP_SESSION_ID": "this-is-deprecated", "GNOME_SETUP_DISPLAY": ":1", "GNOME_SHELL_SESSION_MODE": "ubuntu", "GTK_MODULES": "gail:atk-bridge", "HOME": "\/home\/user", "IM_CONFIG_PHASE": "1", "INVOCATION_ID": "4f121e7ad56c41a6b84aa3cbe1ad61fa", "JOURNAL_STREAM": "8:37187", "LANG": "en_US.UTF-8", "LOGNAME": "user", "MANAGERPID": "2118", "MOZ_ENABLE_WAYLAND": "1", "PATH": "\/home\/user\/.pyenv\/versions\/3.12.0\/bin:\/home\/user\/.pyenv\/libexec:\/home\/user\/.pyenv\/plugins\/python-build\/bin:\/home\/user\/.pyenv\/plugins\/pyenv-virtualenv\/bin:\/home\/user\/.pyenv\/plugins\/pyenv-update\/bin:\/home\/user\/.pyenv\/plugins\/pyenv-doctor\/bin:\/home\/user\/.cargo\/bin:\/home\/user\/.pyenv\/shims:\/home\/user\/.pyenv\/bin:\/home\/user\/bin:\/home\/user\/bin:\/usr\/local\/sbin:\/usr\/local\/bin:\/usr\/sbin:\/usr\/bin:\/sbin:\/bin:\/usr\/games:\/usr\/local\/games:\/snap\/bin:\/snap\/bin:\/home\/user\/.local\/share\/JetBrains\/Toolbox\/scripts", "PWD": "\/home\/user\/Documents\/code\/atuin", "PYENV_DIR": "\/home\/user\/Documents\/code\/atuin", "PYENV_HOOK_PATH": "\/home\/user\/.pyenv\/pyenv.d:\/usr\/local\/etc\/pyenv.d:\/etc\/pyenv.d:\/usr\/lib\/pyenv\/hooks:\/home\/user\/.pyenv\/plugins\/pyenv-virtualenv\/etc\/pyenv.d", "PYENV_ROOT": "\/home\/user\/.pyenv", "PYENV_SHELL": "bash", "PYENV_VERSION": "3.12.0", "QT_ACCESSIBILITY": "1", "QT_IM_MODULE": "ibus", "SESSION_MANAGER": "local\/box:@\/tmp\/.ICE-unix\/2452,unix\/box:\/tmp\/.ICE-unix\/2452", "SHELL": "\/bin\/bash", "SHLVL": "1", "SSH_AGENT_LAUNCHER": "gnome-keyring", "SSH_AUTH_SOCK": "\/run\/user\/1000\/keyring\/ssh", "SSL_CERT_DIR": "\/usr\/lib\/ssl\/certs", "SSL_CERT_FILE": "\/usr\/lib\/ssl\/certs\/ca-certificates.crt", "SYSTEMD_EXEC_PID": "2470", "TERM": "xterm-256color", "TERM_PROGRAM": "WezTerm", "TERM_PROGRAM_VERSION": "20240127-113634-bbcac864", "THREAD_SUBPROCS": "1", "USER": "user", "USERNAME": "user", "WAYLAND_DISPLAY": "wayland-0", "WEZTERM_CONFIG_DIR": "\/home\/user\/.config\/wezterm", "WEZTERM_CONFIG_FILE": "\/home\/user\/.config\/wezterm\/wezterm.lua", "WEZTERM_EXECUTABLE": "\/usr\/bin\/wezterm-gui", "WEZTERM_EXECUTABLE_DIR": "\/usr\/bin", "WEZTERM_PANE": "38", "WEZTERM_UNIX_SOCKET": "\/run\/user\/1000\/wezterm\/gui-sock-196859", "XAUTHORITY": "\/run\/user\/1000\/.mutter-Xwaylandauth.T986H2", "XDG_CONFIG_DIRS": "\/etc\/xdg\/xdg-ubuntu:\/etc\/xdg", "XDG_CURRENT_DESKTOP": "ubuntu:GNOME", "XDG_DATA_DIRS": "\/usr\/share\/ubuntu:\/usr\/local\/share\/:\/usr\/share\/:\/var\/lib\/snapd\/desktop", "XDG_MENU_PREFIX": "gnome-", "XDG_RUNTIME_DIR": "\/run\/user\/1000", "XDG_SESSION_CLASS": "user", "XDG_SESSION_DESKTOP": "ubuntu", "XDG_SESSION_TYPE": "wayland", "XMODIFIERS": "@im=ibus", "XONSHRC": "\/etc\/xonsh\/xonshrc:\/home\/user\/.config\/xonsh\/rc.xsh:\/home\/user\/.xonshrc", "XONSHRC_DIR": "\/etc\/xonsh\/rc.d:\/home\/user\/.config\/xonsh\/rc.d", "XONSH_CAPTURE_ALWAYS": "", "XONSH_CONFIG_DIR": "\/home\/user\/.config\/xonsh", "XONSH_DATA_DIR": "\/home\/user\/.local\/share\/xonsh", "XONSH_INTERACTIVE": "1", "XONSH_LOGIN": "1", "XONSH_VERSION": "0.14.2"} -, "locked": false, "sessionid": "de16af90-9148-4461-8df3-5b5659c6420d", "ts": [1707193067.8615997, 1707193089.2513068] -} - -} diff --git a/crates/atuin-common/Cargo.toml b/crates/atuin-common/Cargo.toml deleted file mode 100644 index 811b0bdb..00000000 --- a/crates/atuin-common/Cargo.toml +++ /dev/null @@ -1,31 +0,0 @@ -[package] -name = "atuin-common" -edition = "2024" -description = "common library for atuin" - -rust-version = { workspace = true } -version = { workspace = true } -authors = { workspace = true } -license = { workspace = true } -homepage = { workspace = true } -repository = { workspace = true } - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -time = { workspace = true } -serde = { workspace = true } -uuid = { workspace = true } -typed-builder = { workspace = true } -eyre = { workspace = true } -sqlx = { workspace = true } -semver = { workspace = true } -thiserror = { workspace = true } -directories = { workspace = true } -sysinfo = "0.30.7" -base64 = { workspace = true } -getrandom = "0.2" -rustls = { workspace = true } - -[dev-dependencies] -pretty_assertions = { workspace = true } diff --git a/crates/atuin-common/src/api.rs b/crates/atuin-common/src/api.rs deleted file mode 100644 index 1a9f348c..00000000 --- a/crates/atuin-common/src/api.rs +++ /dev/null @@ -1,144 +0,0 @@ -use semver::Version; -use serde::{Deserialize, Serialize}; -use std::borrow::Cow; -use std::sync::LazyLock; -use time::OffsetDateTime; - -// the usage of X- has been deprecated for quite along time, it turns out -pub static ATUIN_HEADER_VERSION: &str = "Atuin-Version"; -pub static ATUIN_CARGO_VERSION: &str = env!("CARGO_PKG_VERSION"); - -pub static ATUIN_VERSION: LazyLock = - LazyLock::new(|| Version::parse(ATUIN_CARGO_VERSION).expect("failed to parse self semver")); - -#[derive(Debug, Serialize, Deserialize)] -pub struct UserResponse { - pub username: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct RegisterRequest { - pub email: String, - pub username: String, - pub password: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct RegisterResponse { - pub session: String, - /// Auth type: "hub" for Hub API tokens, "cli" for legacy CLI session tokens. - /// Old servers that don't return this field will deserialize as None. - #[serde(default)] - pub auth: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct DeleteUserResponse {} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ChangePasswordRequest { - pub current_password: String, - pub new_password: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ChangePasswordResponse {} - -#[derive(Debug, Serialize, Deserialize)] -pub struct LoginRequest { - pub username: String, - pub password: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct LoginResponse { - pub session: String, - /// Auth type: "hub" for Hub API tokens, "cli" for legacy CLI session tokens. - /// Old servers that don't return this field will deserialize as None. - #[serde(default)] - pub auth: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct AddHistoryRequest { - pub id: String, - #[serde(with = "time::serde::rfc3339")] - pub timestamp: OffsetDateTime, - pub data: String, - pub hostname: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct CountResponse { - pub count: i64, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct SyncHistoryRequest { - #[serde(with = "time::serde::rfc3339")] - pub sync_ts: OffsetDateTime, - #[serde(with = "time::serde::rfc3339")] - pub history_ts: OffsetDateTime, - pub host: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct SyncHistoryResponse { - pub history: Vec, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ErrorResponse<'a> { - pub reason: Cow<'a, str>, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct IndexResponse { - pub homage: String, - pub version: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct StatusResponse { - pub count: i64, - pub username: String, - pub deleted: Vec, - - // These could/should also go on the index of the server - // However, we do not request the server index as a part of normal sync - // I'd rather slightly increase the size of this response, than add an extra HTTP request - pub page_size: i64, // max page size supported by the server - pub version: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct DeleteHistoryRequest { - pub client_id: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct MessageResponse { - pub message: String, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct MeResponse { - pub username: String, -} - -// Hub CLI authentication types - -/// Response from POST /auth/cli/code - generates a code for CLI auth -#[derive(Debug, Serialize, Deserialize)] -pub struct CliCodeResponse { - pub code: String, -} - -/// Response from GET /auth/cli/verify?code= - polls for authorization -#[derive(Debug, Serialize, Deserialize)] -pub struct CliVerifyResponse { - /// Session token, present only when authorization is complete - pub token: Option, - pub success: Option, - pub error: Option, -} diff --git a/crates/atuin-common/src/calendar.rs b/crates/atuin-common/src/calendar.rs deleted file mode 100644 index d3b1d921..00000000 --- a/crates/atuin-common/src/calendar.rs +++ /dev/null @@ -1,16 +0,0 @@ -// Calendar data -use serde::{Serialize, Deserialize}; - -pub enum TimePeriod { - YEAR, - MONTH, - DAY, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct TimePeriodInfo { - pub count: u64, - - // TODO: Use this for merkle tree magic - pub hash: String, -} diff --git a/crates/atuin-common/src/lib.rs b/crates/atuin-common/src/lib.rs deleted file mode 100644 index 91164a82..00000000 --- a/crates/atuin-common/src/lib.rs +++ /dev/null @@ -1,60 +0,0 @@ -#![deny(unsafe_code)] - -/// Defines a new UUID type wrapper -macro_rules! new_uuid { - ($name:ident) => { - #[derive( - Debug, - Copy, - Clone, - PartialEq, - Eq, - Hash, - PartialOrd, - Ord, - serde::Serialize, - serde::Deserialize, - )] - #[serde(transparent)] - pub struct $name(pub Uuid); - - impl sqlx::Type for $name - where - Uuid: sqlx::Type, - { - fn type_info() -> ::TypeInfo { - Uuid::type_info() - } - } - - impl<'r, DB: sqlx::Database> sqlx::Decode<'r, DB> for $name - where - Uuid: sqlx::Decode<'r, DB>, - { - fn decode( - value: DB::ValueRef<'r>, - ) -> std::result::Result { - Uuid::decode(value).map(Self) - } - } - - impl<'q, DB: sqlx::Database> sqlx::Encode<'q, DB> for $name - where - Uuid: sqlx::Encode<'q, DB>, - { - fn encode_by_ref( - &self, - buf: &mut DB::ArgumentBuffer<'q>, - ) -> Result> - { - self.0.encode_by_ref(buf) - } - } - }; -} - -pub mod api; -pub mod record; -pub mod shell; -pub mod tls; -pub mod utils; diff --git a/crates/atuin-common/src/record.rs b/crates/atuin-common/src/record.rs deleted file mode 100644 index e6ce2647..00000000 --- a/crates/atuin-common/src/record.rs +++ /dev/null @@ -1,426 +0,0 @@ -use std::collections::HashMap; - -use eyre::Result; -use serde::{Deserialize, Serialize}; -use typed_builder::TypedBuilder; -use uuid::Uuid; - -#[derive(Clone, Debug, PartialEq)] -pub struct DecryptedData(pub Vec); - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct EncryptedData { - pub data: String, - pub content_encryption_key: String, -} - -#[derive(Debug, PartialEq, PartialOrd, Ord, Eq)] -pub struct Diff { - pub host: HostId, - pub tag: String, - pub local: Option, - pub remote: Option, -} - -#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] -pub struct Host { - pub id: HostId, - pub name: String, -} - -impl Host { - pub fn new(id: HostId) -> Self { - Host { - id, - name: String::new(), - } - } -} - -new_uuid!(RecordId); -new_uuid!(HostId); - -pub type RecordIdx = u64; - -/// A single record stored inside of our local database -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TypedBuilder)] -pub struct Record { - /// a unique ID - #[builder(default = RecordId(crate::utils::uuid_v7()))] - pub id: RecordId, - - /// The integer record ID. This is only unique per (host, tag). - pub idx: RecordIdx, - - /// The unique ID of the host. - // TODO(ellie): Optimize the storage here. We use a bunch of IDs, and currently store - // as strings. I would rather avoid normalization, so store as UUID binary instead of - // encoding to a string and wasting much more storage. - pub host: Host, - - /// The creation time in nanoseconds since unix epoch - #[builder(default = time::OffsetDateTime::now_utc().unix_timestamp_nanos() as u64)] - pub timestamp: u64, - - /// The version the data in the entry conforms to - // However we want to track versions for this tag, eg v2 - pub version: String, - - /// The type of data we are storing here. Eg, "history" - pub tag: String, - - /// Some data. This can be anything you wish to store. Use the tag field to know how to handle it. - pub data: Data, -} - -/// Extra data from the record that should be encoded in the data -#[derive(Debug, Copy, Clone)] -pub struct AdditionalData<'a> { - pub id: &'a RecordId, - pub idx: &'a u64, - pub version: &'a str, - pub tag: &'a str, - pub host: &'a HostId, -} - -impl Record { - pub fn append(&self, data: Vec) -> Record { - Record::builder() - .host(self.host.clone()) - .version(self.version.clone()) - .idx(self.idx + 1) - .tag(self.tag.clone()) - .data(DecryptedData(data)) - .build() - } -} - -/// An index representing the current state of the record stores -/// This can be both remote, or local, and compared in either direction -#[derive(Debug, Serialize, Deserialize)] -pub struct RecordStatus { - // A map of host -> tag -> max(idx) - pub hosts: HashMap>, -} - -impl Default for RecordStatus { - fn default() -> Self { - Self::new() - } -} - -impl Extend<(HostId, String, RecordIdx)> for RecordStatus { - fn extend>(&mut self, iter: T) { - for (host, tag, tail_idx) in iter { - self.set_raw(host, tag, tail_idx); - } - } -} - -impl RecordStatus { - pub fn new() -> RecordStatus { - RecordStatus { - hosts: HashMap::new(), - } - } - - /// Insert a new tail record into the store - pub fn set(&mut self, tail: Record) { - self.set_raw(tail.host.id, tail.tag, tail.idx) - } - - pub fn set_raw(&mut self, host: HostId, tag: String, tail_id: RecordIdx) { - self.hosts.entry(host).or_default().insert(tag, tail_id); - } - - pub fn get(&self, host: HostId, tag: String) -> Option { - self.hosts.get(&host).and_then(|v| v.get(&tag)).cloned() - } - - /// Diff this index with another, likely remote index. - /// The two diffs can then be reconciled, and the optimal change set calculated - /// Returns a tuple, with (host, tag, Option(OTHER)) - /// OTHER is set to the value of the idx on the other machine. If it is greater than our index, - /// then we need to do some downloading. If it is smaller, then we need to do some uploading - /// Note that we cannot upload if we are not the owner of the record store - hosts can only - /// write to their own store. - pub fn diff(&self, other: &Self) -> Vec { - let mut ret = Vec::new(); - - // First, we check if other has everything that self has - for (host, tag_map) in self.hosts.iter() { - for (tag, idx) in tag_map.iter() { - match other.get(*host, tag.clone()) { - // The other store is all up to date! No diff. - Some(t) if t.eq(idx) => continue, - - // The other store does exist, and it is either ahead or behind us. A diff regardless - Some(t) => ret.push(Diff { - host: *host, - tag: tag.clone(), - local: Some(*idx), - remote: Some(t), - }), - - // The other store does not exist :O - None => ret.push(Diff { - host: *host, - tag: tag.clone(), - local: Some(*idx), - remote: None, - }), - }; - } - } - - // At this point, there is a single case we have not yet considered. - // If the other store knows of a tag that we are not yet aware of, then the diff will be missed - - // account for that! - for (host, tag_map) in other.hosts.iter() { - for (tag, idx) in tag_map.iter() { - match self.get(*host, tag.clone()) { - // If we have this host/tag combo, the comparison and diff will have already happened above - Some(_) => continue, - - None => ret.push(Diff { - host: *host, - tag: tag.clone(), - remote: Some(*idx), - local: None, - }), - }; - } - } - - // Stability is a nice property to have - ret.sort(); - ret - } -} - -pub trait Encryption { - fn re_encrypt( - data: EncryptedData, - ad: AdditionalData, - old_key: &[u8; 32], - new_key: &[u8; 32], - ) -> Result { - let data = Self::decrypt(data, ad, old_key)?; - Ok(Self::encrypt(data, ad, new_key)) - } - fn encrypt(data: DecryptedData, ad: AdditionalData, key: &[u8; 32]) -> EncryptedData; - fn decrypt(data: EncryptedData, ad: AdditionalData, key: &[u8; 32]) -> Result; -} - -impl Record { - pub fn encrypt(self, key: &[u8; 32]) -> Record { - let ad = AdditionalData { - id: &self.id, - version: &self.version, - tag: &self.tag, - host: &self.host.id, - idx: &self.idx, - }; - Record { - data: E::encrypt(self.data, ad, key), - id: self.id, - host: self.host, - idx: self.idx, - timestamp: self.timestamp, - version: self.version, - tag: self.tag, - } - } -} - -impl Record { - pub fn decrypt(self, key: &[u8; 32]) -> Result> { - let ad = AdditionalData { - id: &self.id, - version: &self.version, - tag: &self.tag, - host: &self.host.id, - idx: &self.idx, - }; - Ok(Record { - data: E::decrypt(self.data, ad, key)?, - id: self.id, - host: self.host, - idx: self.idx, - timestamp: self.timestamp, - version: self.version, - tag: self.tag, - }) - } - - pub fn re_encrypt( - self, - old_key: &[u8; 32], - new_key: &[u8; 32], - ) -> Result> { - let ad = AdditionalData { - id: &self.id, - version: &self.version, - tag: &self.tag, - host: &self.host.id, - idx: &self.idx, - }; - Ok(Record { - data: E::re_encrypt(self.data, ad, old_key, new_key)?, - id: self.id, - host: self.host, - idx: self.idx, - timestamp: self.timestamp, - version: self.version, - tag: self.tag, - }) - } -} - -#[cfg(test)] -mod tests { - use crate::record::{Host, HostId}; - - use super::{DecryptedData, Diff, Record, RecordStatus}; - use pretty_assertions::assert_eq; - - fn test_record() -> Record { - Record::builder() - .host(Host::new(HostId(crate::utils::uuid_v7()))) - .version("v1".into()) - .tag(crate::utils::uuid_v7().simple().to_string()) - .data(DecryptedData(vec![0, 1, 2, 3])) - .idx(0) - .build() - } - - #[test] - fn record_index() { - let mut index = RecordStatus::new(); - let record = test_record(); - - index.set(record.clone()); - - let tail = index.get(record.host.id, record.tag); - - assert_eq!( - record.idx, - tail.expect("tail not in store"), - "tail in store did not match" - ); - } - - #[test] - fn record_index_overwrite() { - let mut index = RecordStatus::new(); - let record = test_record(); - let child = record.append(vec![1, 2, 3]); - - index.set(record.clone()); - index.set(child.clone()); - - let tail = index.get(record.host.id, record.tag); - - assert_eq!( - child.idx, - tail.expect("tail not in store"), - "tail in store did not match" - ); - } - - #[test] - fn record_index_no_diff() { - // Here, they both have the same version and should have no diff - - let mut index1 = RecordStatus::new(); - let mut index2 = RecordStatus::new(); - - let record1 = test_record(); - - index1.set(record1.clone()); - index2.set(record1); - - let diff = index1.diff(&index2); - - assert_eq!(0, diff.len(), "expected empty diff"); - } - - #[test] - fn record_index_single_diff() { - // Here, they both have the same stores, but one is ahead by a single record - - let mut index1 = RecordStatus::new(); - let mut index2 = RecordStatus::new(); - - let record1 = test_record(); - let record2 = record1.append(vec![1, 2, 3]); - - index1.set(record1); - index2.set(record2.clone()); - - let diff = index1.diff(&index2); - - assert_eq!(1, diff.len(), "expected single diff"); - assert_eq!( - diff[0], - Diff { - host: record2.host.id, - tag: record2.tag, - remote: Some(1), - local: Some(0) - } - ); - } - - #[test] - fn record_index_multi_diff() { - // A much more complex case, with a bunch more checks - let mut index1 = RecordStatus::new(); - let mut index2 = RecordStatus::new(); - - let store1record1 = test_record(); - let store1record2 = store1record1.append(vec![1, 2, 3]); - - let store2record1 = test_record(); - let store2record2 = store2record1.append(vec![1, 2, 3]); - - let store3record1 = test_record(); - - let store4record1 = test_record(); - - // index1 only knows about the first two entries of the first two stores - index1.set(store1record1); - index1.set(store2record1); - - // index2 is fully up to date with the first two stores, and knows of a third - index2.set(store1record2); - index2.set(store2record2); - index2.set(store3record1); - - // index1 knows of a 4th store - index1.set(store4record1); - - let diff1 = index1.diff(&index2); - let diff2 = index2.diff(&index1); - - // both diffs the same length - assert_eq!(4, diff1.len()); - assert_eq!(4, diff2.len()); - - dbg!(&diff1, &diff2); - - // both diffs should be ALMOST the same. They will agree on which hosts and tags - // require updating, but the "other" value will not be the same. - let smol_diff_1: Vec<(HostId, String)> = - diff1.iter().map(|v| (v.host, v.tag.clone())).collect(); - let smol_diff_2: Vec<(HostId, String)> = - diff1.iter().map(|v| (v.host, v.tag.clone())).collect(); - - assert_eq!(smol_diff_1, smol_diff_2); - - // diffing with yourself = no diff - assert_eq!(index1.diff(&index1).len(), 0); - assert_eq!(index2.diff(&index2).len(), 0); - } -} diff --git a/crates/atuin-common/src/shell.rs b/crates/atuin-common/src/shell.rs deleted file mode 100644 index 7f9a7b8f..00000000 --- a/crates/atuin-common/src/shell.rs +++ /dev/null @@ -1,183 +0,0 @@ -use std::{ffi::OsStr, path::Path, process::Command}; - -use serde::Serialize; -use sysinfo::{Process, System, get_current_pid}; -use thiserror::Error; - -#[derive(PartialEq)] -pub enum Shell { - Sh, - Bash, - Fish, - Zsh, - Xonsh, - Nu, - Powershell, - - Unknown, -} - -impl std::fmt::Display for Shell { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let shell = match self { - Shell::Bash => "bash", - Shell::Fish => "fish", - Shell::Zsh => "zsh", - Shell::Nu => "nu", - Shell::Xonsh => "xonsh", - Shell::Sh => "sh", - Shell::Powershell => "powershell", - - Shell::Unknown => "unknown", - }; - - write!(f, "{shell}") - } -} - -#[derive(Debug, Error, Serialize)] -pub enum ShellError { - #[error("shell not supported")] - NotSupported, - - #[error("failed to execute shell command: {0}")] - ExecError(String), -} - -impl Shell { - pub fn current() -> Shell { - let sys = System::new_all(); - - let process = sys - .process(get_current_pid().expect("Failed to get current PID")) - .expect("Process with current pid does not exist"); - - let parent = sys - .process(process.parent().expect("Atuin running with no parent!")) - .expect("Process with parent pid does not exist"); - - let shell = parent.name().trim().to_lowercase(); - let shell = shell.strip_prefix('-').unwrap_or(&shell); - - Shell::from_string(shell.to_string()) - } - - pub fn from_env() -> Shell { - std::env::var("ATUIN_SHELL").map_or(Shell::Unknown, |shell| { - Shell::from_string(shell.trim().to_lowercase()) - }) - } - - pub fn config_file(&self) -> Option { - let mut path = if let Some(base) = directories::BaseDirs::new() { - base.home_dir().to_owned() - } else { - return None; - }; - - // TODO: handle all shells - match self { - Shell::Bash => path.push(".bashrc"), - Shell::Zsh => path.push(".zshrc"), - Shell::Fish => path.push(".config/fish/config.fish"), - - _ => return None, - }; - - Some(path) - } - - /// Best-effort attempt to determine the default shell - /// This implementation will be different across different platforms - /// Caller should ensure to handle Shell::Unknown correctly - pub fn default_shell() -> Result { - let sys = System::name().unwrap_or("".to_string()).to_lowercase(); - - // TODO: Support Linux - // I'm pretty sure we can use /etc/passwd there, though there will probably be some issues - let path = if sys.contains("darwin") { - // This works in my testing so far - Shell::Sh.run_interactive([ - "dscl localhost -read \"/Local/Default/Users/$USER\" shell | awk '{print $2}'", - ])? - } else if cfg!(windows) { - return Ok(Shell::Powershell); - } else { - Shell::Sh.run_interactive(["getent passwd $LOGNAME | cut -d: -f7"])? - }; - - let path = Path::new(path.trim()); - let shell = path.file_name(); - - if shell.is_none() { - return Err(ShellError::NotSupported); - } - - Ok(Shell::from_string( - shell.unwrap().to_string_lossy().to_string(), - )) - } - - pub fn from_string(name: String) -> Shell { - match name.as_str() { - "bash" => Shell::Bash, - "fish" => Shell::Fish, - "zsh" => Shell::Zsh, - "xonsh" => Shell::Xonsh, - "nu" => Shell::Nu, - "sh" => Shell::Sh, - "powershell" => Shell::Powershell, - - _ => Shell::Unknown, - } - } - - /// Returns true if the shell is posix-like - /// Note that while fish is not posix compliant, it behaves well enough for our current - /// featureset that this does not matter. - pub fn is_posixish(&self) -> bool { - matches!(self, Shell::Bash | Shell::Fish | Shell::Zsh) - } - - pub fn run_interactive(&self, args: I) -> Result - where - I: IntoIterator, - S: AsRef, - { - let shell = self.to_string(); - let output = if self == &Self::Powershell { - Command::new(shell) - .args(args) - .output() - .map_err(|e| ShellError::ExecError(e.to_string()))? - } else { - Command::new(shell) - .arg("-ic") - .args(args) - .output() - .map_err(|e| ShellError::ExecError(e.to_string()))? - }; - - Ok(String::from_utf8(output.stdout).unwrap()) - } -} - -pub fn shell_name(parent: Option<&Process>) -> String { - let sys = System::new_all(); - - let parent = if let Some(parent) = parent { - parent - } else { - let process = sys - .process(get_current_pid().expect("Failed to get current PID")) - .expect("Process with current pid does not exist"); - - sys.process(process.parent().expect("Atuin running with no parent!")) - .expect("Process with parent pid does not exist") - }; - - let shell = parent.name().trim().to_lowercase(); - let shell = shell.strip_prefix('-').unwrap_or(&shell); - - shell.to_string() -} diff --git a/crates/atuin-common/src/tls.rs b/crates/atuin-common/src/tls.rs deleted file mode 100644 index e8c840e0..00000000 --- a/crates/atuin-common/src/tls.rs +++ /dev/null @@ -1,15 +0,0 @@ -use std::sync::Once; - -static INIT: Once = Once::new(); - -/// Ensure the rustls crypto provider (ring) is installed. -/// -/// Must be called before creating any reqwest clients. Safe to call -/// multiple times — only the first call installs the provider. -pub fn ensure_crypto_provider() { - INIT.call_once(|| { - rustls::crypto::ring::default_provider() - .install_default() - .expect("Failed to install rustls crypto provider"); - }); -} diff --git a/crates/atuin-common/src/utils.rs b/crates/atuin-common/src/utils.rs deleted file mode 100644 index d7382fb2..00000000 --- a/crates/atuin-common/src/utils.rs +++ /dev/null @@ -1,383 +0,0 @@ -use std::borrow::Cow; -use std::env; -use std::path::{Path, PathBuf}; - -use eyre::{Result, eyre}; - -use base64::prelude::{BASE64_URL_SAFE_NO_PAD, Engine}; -use getrandom::getrandom; -use uuid::Uuid; - -/// Generate N random bytes, using a cryptographically secure source -pub fn crypto_random_bytes() -> [u8; N] { - // rand say they are in principle safe for crypto purposes, but that it is perhaps a better - // idea to use getrandom for things such as passwords. - let mut ret = [0u8; N]; - - getrandom(&mut ret).expect("Failed to generate random bytes!"); - - ret -} - -/// Generate N random bytes using a cryptographically secure source, return encoded as a string -pub fn crypto_random_string() -> String { - let bytes = crypto_random_bytes::(); - - // We only use this to create a random string, and won't be reversing it to find the original - // data - no padding is OK there. It may be in URLs. - BASE64_URL_SAFE_NO_PAD.encode(bytes) -} - -pub fn uuid_v7() -> Uuid { - Uuid::now_v7() -} - -pub fn uuid_v4() -> String { - Uuid::new_v4().as_simple().to_string() -} - -pub fn has_git_dir(path: &str) -> bool { - let mut gitdir = PathBuf::from(path); - gitdir.push(".git"); - - gitdir.exists() -} - -// in a git worktree, .git is a file containing "gitdir: " pointing -// to the main repo's .git/worktrees/ directory. follow the pointer -// back to the main repo root so all worktrees share a workspace. -fn resolve_git_worktree(path: &Path) -> Option { - let git_path = path.join(".git"); - - if !git_path.is_file() { - return None; - } - - let contents = std::fs::read_to_string(&git_path).ok()?; - let gitdir_str = contents.strip_prefix("gitdir: ")?.trim(); - - let gitdir = PathBuf::from(gitdir_str); - let gitdir = if gitdir.is_absolute() { - gitdir - } else { - path.join(gitdir_str) - }; - - // walk up from e.g. /repo/.git/worktrees/feature to find /repo - let mut candidate = gitdir.as_path(); - while let Some(parent) = candidate.parent() { - if parent.join(".git").is_dir() { - return Some(parent.to_path_buf()); - } - candidate = parent; - } - - None -} - -// detect if any parent dir has a git repo in it -// I really don't want to bring in libgit for something simple like this -// If we start to do anything more advanced, then perhaps -pub fn in_git_repo(path: &str) -> Option { - let mut gitdir = PathBuf::from(path); - - while gitdir.parent().is_some() && !has_git_dir(gitdir.to_str().unwrap()) { - gitdir.pop(); - } - - // No parent? then we hit root, finding no git - if gitdir.parent().is_some() { - // if .git is a file (worktree), resolve to the main repo root - if let Some(main_repo) = resolve_git_worktree(&gitdir) { - return Some(main_repo); - } - return Some(gitdir); - } - - None -} - -// TODO: more reliable, more tested -// I don't want to use ProjectDirs, it puts config in awkward places on -// mac. Data too. Seems to be more intended for GUI apps. - -pub fn home_dir() -> PathBuf { - directories::BaseDirs::new() - .map(|d| d.home_dir().to_path_buf()) - .expect("could not determine home directory") -} - -pub fn config_dir() -> PathBuf { - let config_dir = - std::env::var("XDG_CONFIG_HOME").map_or_else(|_| home_dir().join(".config"), PathBuf::from); - config_dir.join("atuin") -} - -pub fn data_dir() -> PathBuf { - let data_dir = std::env::var("XDG_DATA_HOME") - .map_or_else(|_| home_dir().join(".local").join("share"), PathBuf::from); - - data_dir.join("atuin") -} - -pub fn runtime_dir() -> PathBuf { - std::env::var("XDG_RUNTIME_DIR").map_or_else(|_| data_dir(), PathBuf::from) -} - -pub fn logs_dir() -> PathBuf { - home_dir().join(".atuin").join("logs") -} - -pub fn dotfiles_cache_dir() -> PathBuf { - // In most cases, this will be ~/.local/share/atuin/dotfiles/cache - let data_dir = std::env::var("XDG_DATA_HOME") - .map_or_else(|_| home_dir().join(".local").join("share"), PathBuf::from); - - data_dir.join("atuin").join("dotfiles").join("cache") -} - -pub fn get_current_dir() -> String { - // Prefer PWD environment variable over cwd if available to better support symbolic links - match env::var("PWD") { - Ok(v) => v, - Err(_) => match env::current_dir() { - Ok(dir) => dir.display().to_string(), - Err(_) => String::from(""), - }, - } -} - -pub fn broken_symlink>(path: P) -> bool { - let path = path.into(); - path.is_symlink() && !path.exists() -} - -/// Extension trait for anything that can behave like a string to make it easy to escape control -/// characters. -/// -/// Intended to help prevent control characters being printed and interpreted by the terminal when -/// printing history as well as to ensure the commands that appear in the interactive search -/// reflect the actual command run rather than just the printable characters. -pub trait Escapable: AsRef { - fn escape_control(&self) -> Cow<'_, str> { - if !self.as_ref().contains(|c: char| c.is_ascii_control()) { - self.as_ref().into() - } else { - let mut remaining = self.as_ref(); - // Not a perfect way to reserve space but should reduce the allocations - let mut buf = String::with_capacity(remaining.len()); - while let Some(i) = remaining.find(|c: char| c.is_ascii_control()) { - // safe to index with `..i`, `i` and `i+1..` as part[i] is a single byte ascii char - buf.push_str(&remaining[..i]); - buf.push('^'); - buf.push(match remaining.as_bytes()[i] { - 0x7F => '?', - code => char::from_u32(u32::from(code) + 64).unwrap(), - }); - remaining = &remaining[i + 1..]; - } - buf.push_str(remaining); - buf.into() - } - } -} - -pub fn unquote(s: &str) -> Result { - if s.chars().count() < 2 { - return Err(eyre!("not enough chars")); - } - - let quote = s.chars().next().unwrap(); - - // not quoted, do nothing - if quote != '"' && quote != '\'' && quote != '`' { - return Ok(s.to_string()); - } - - if s.chars().last().unwrap() != quote { - return Err(eyre!("unexpected eof, quotes do not match")); - } - - // removes quote characters - // the sanity checks performed above ensure that the quotes will be ASCII and this will not - // panic - let s = &s[1..s.len() - 1]; - - Ok(s.to_string()) -} - -impl> Escapable for T {} - -#[expect(unsafe_code)] -#[cfg(test)] -mod tests { - use pretty_assertions::assert_ne; - - use super::*; - - use std::collections::HashSet; - - #[cfg(not(windows))] - #[test] - fn test_dirs() { - // these tests need to be run sequentially to prevent race condition - test_config_dir_xdg(); - test_config_dir(); - test_data_dir_xdg(); - test_data_dir(); - } - - #[cfg(not(windows))] - fn test_config_dir_xdg() { - // TODO: Audit that the environment access only happens in single-threaded code. - unsafe { env::remove_var("HOME") }; - // TODO: Audit that the environment access only happens in single-threaded code. - unsafe { env::set_var("XDG_CONFIG_HOME", "/home/user/custom_config") }; - assert_eq!( - config_dir(), - PathBuf::from("/home/user/custom_config/atuin") - ); - // TODO: Audit that the environment access only happens in single-threaded code. - unsafe { env::remove_var("XDG_CONFIG_HOME") }; - } - - #[cfg(not(windows))] - fn test_config_dir() { - // TODO: Audit that the environment access only happens in single-threaded code. - unsafe { env::set_var("HOME", "/home/user") }; - // TODO: Audit that the environment access only happens in single-threaded code. - unsafe { env::remove_var("XDG_CONFIG_HOME") }; - - assert_eq!(config_dir(), PathBuf::from("/home/user/.config/atuin")); - - // TODO: Audit that the environment access only happens in single-threaded code. - unsafe { env::remove_var("HOME") }; - } - - #[cfg(not(windows))] - fn test_data_dir_xdg() { - // TODO: Audit that the environment access only happens in single-threaded code. - unsafe { env::remove_var("HOME") }; - // TODO: Audit that the environment access only happens in single-threaded code. - unsafe { env::set_var("XDG_DATA_HOME", "/home/user/custom_data") }; - assert_eq!(data_dir(), PathBuf::from("/home/user/custom_data/atuin")); - // TODO: Audit that the environment access only happens in single-threaded code. - unsafe { env::remove_var("XDG_DATA_HOME") }; - } - - #[cfg(not(windows))] - fn test_data_dir() { - // TODO: Audit that the environment access only happens in single-threaded code. - unsafe { env::set_var("HOME", "/home/user") }; - // TODO: Audit that the environment access only happens in single-threaded code. - unsafe { env::remove_var("XDG_DATA_HOME") }; - assert_eq!(data_dir(), PathBuf::from("/home/user/.local/share/atuin")); - // TODO: Audit that the environment access only happens in single-threaded code. - unsafe { env::remove_var("HOME") }; - } - - #[test] - fn uuid_is_unique() { - let how_many: usize = 1000000; - - // for peace of mind - let mut uuids: HashSet = HashSet::with_capacity(how_many); - - // there will be many in the same millisecond - for _ in 0..how_many { - let uuid = uuid_v7(); - uuids.insert(uuid); - } - - assert_eq!(uuids.len(), how_many); - } - - #[test] - fn escape_control_characters() { - use super::Escapable; - // CSI colour sequence - assert_eq!("\x1b[31mfoo".escape_control(), "^[[31mfoo"); - - // Tabs count as control chars - assert_eq!("foo\tbar".escape_control(), "foo^Ibar"); - - // space is in control char range but should be excluded - assert_eq!("two words".escape_control(), "two words"); - - // unicode multi-byte characters - let s = "🐢\x1b[32m🦀"; - assert_eq!(s.escape_control(), s.replace("\x1b", "^[")); - } - - #[test] - fn escape_no_control_characters() { - use super::Escapable as _; - assert!(matches!( - "no control characters".escape_control(), - Cow::Borrowed(_) - )); - assert!(matches!( - "with \x1b[31mcontrol\x1b[0m characters".escape_control(), - Cow::Owned(_) - )); - } - - #[cfg(not(windows))] - #[test] - fn in_git_repo_regular() { - // regular git repo should resolve to the directory containing .git - let tmp = std::env::temp_dir().join("atuin-test-regular-git"); - let _ = std::fs::remove_dir_all(&tmp); - let subdir = tmp.join("src").join("deep"); - std::fs::create_dir_all(&subdir).unwrap(); - std::fs::create_dir_all(tmp.join(".git")).unwrap(); - - let result = in_git_repo(subdir.to_str().unwrap()); - assert_eq!(result, Some(tmp.clone())); - - std::fs::remove_dir_all(&tmp).unwrap(); - } - - #[cfg(not(windows))] - #[test] - fn in_git_repo_worktree_resolves_to_main_repo() { - // worktree .git is a file pointing back to the main repo — - // in_git_repo should follow it so all worktrees share a workspace - let tmp = std::env::temp_dir().join("atuin-test-worktree-git"); - let _ = std::fs::remove_dir_all(&tmp); - - // main repo at tmp/main with a real .git directory - let main_repo = tmp.join("main"); - let worktree_git_dir = main_repo.join(".git").join("worktrees").join("feature"); - std::fs::create_dir_all(&worktree_git_dir).unwrap(); - - // worktree at tmp/worktree with a .git file - let worktree = tmp.join("worktree"); - let worktree_subdir = worktree.join("src"); - std::fs::create_dir_all(&worktree_subdir).unwrap(); - std::fs::write( - worktree.join(".git"), - format!("gitdir: {}", worktree_git_dir.to_str().unwrap()), - ) - .unwrap(); - - // should resolve to the main repo root, not the worktree root - let result = in_git_repo(worktree_subdir.to_str().unwrap()); - assert_eq!(result, Some(main_repo.clone())); - - std::fs::remove_dir_all(&tmp).unwrap(); - } - - #[test] - fn dumb_random_test() { - // Obviously not a test of randomness, but make sure we haven't made some - // catastrophic error - - assert_ne!(crypto_random_string::<1>(), crypto_random_string::<1>()); - assert_ne!(crypto_random_string::<2>(), crypto_random_string::<2>()); - assert_ne!(crypto_random_string::<4>(), crypto_random_string::<4>()); - assert_ne!(crypto_random_string::<8>(), crypto_random_string::<8>()); - assert_ne!(crypto_random_string::<16>(), crypto_random_string::<16>()); - assert_ne!(crypto_random_string::<32>(), crypto_random_string::<32>()); - } -} diff --git a/crates/atuin-daemon/Cargo.toml b/crates/atuin-daemon/Cargo.toml deleted file mode 100644 index e767d3c9..00000000 --- a/crates/atuin-daemon/Cargo.toml +++ /dev/null @@ -1,52 +0,0 @@ -[package] -name = "atuin-daemon" -edition = "2024" -version = { workspace = true } -description = "The daemon crate for Atuin" - -authors.workspace = true -rust-version.workspace = true -license.workspace = true -homepage.workspace = true -repository.workspace = true -readme.workspace = true - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -atuin-client = { path = "../atuin-client", version = "18.16.1" } -atuin-common = { path = "../atuin-common", version = "18.16.1" } -atuin-history = { path = "../atuin-history", version = "18.16.1" } - -time = { workspace = true } -uuid = { workspace = true } -tokio = { workspace = true } -tower = { workspace = true } -eyre = { workspace = true } -tracing = { workspace = true } -tracing-subscriber = { workspace = true } - -dashmap = "6.1.0" -lasso = { version = "0.7", features = ["multi-threaded"] } -tonic-types = "0.14" -tonic = "0.14" -tonic-prost = "0.14" -prost = "0.14" -prost-types = "0.14" -tokio-stream = { version = "0.1.14", features = ["net"] } -hyper-util = "0.1" - -rand.workspace = true -atuin-nucleo = { workspace = true } - - -[target.'cfg(target_os = "linux")'.dependencies] -listenfd = "1.0.1" - -[dev-dependencies] -tempfile = { workspace = true } - -[build-dependencies] -protox = "0.9" -tonic-build = "0.14" -tonic-prost-build = "0.14" diff --git a/crates/atuin-daemon/build.rs b/crates/atuin-daemon/build.rs deleted file mode 100644 index 7808a07b..00000000 --- a/crates/atuin-daemon/build.rs +++ /dev/null @@ -1,25 +0,0 @@ -use std::{env, fs, path::PathBuf}; - -use protox::prost::Message; - -fn main() -> std::io::Result<()> { - let proto_paths = [ - "proto/history.proto", - "proto/search.proto", - "proto/control.proto", - "proto/semantic.proto", - ]; - let proto_include_dirs = ["proto"]; - - let file_descriptors = protox::compile(proto_paths, proto_include_dirs).unwrap(); - - let file_descriptor_path = PathBuf::from(env::var_os("OUT_DIR").expect("OUT_DIR not set")) - .join("file_descriptor_set.bin"); - fs::write(&file_descriptor_path, file_descriptors.encode_to_vec()).unwrap(); - - tonic_prost_build::configure() - .build_server(true) - .file_descriptor_set_path(&file_descriptor_path) - .skip_protoc_run() - .compile_protos(&proto_paths, &proto_include_dirs) -} diff --git a/crates/atuin-daemon/proto/control.proto b/crates/atuin-daemon/proto/control.proto deleted file mode 100644 index 06347902..00000000 --- a/crates/atuin-daemon/proto/control.proto +++ /dev/null @@ -1,62 +0,0 @@ -syntax = "proto3"; -package control; - -// The Control service allows external processes (CLI commands, etc.) -// to inject events into the running daemon. -service Control { - // Send an event to the daemon's event bus - rpc SendEvent(SendEventRequest) returns (SendEventResponse); -} - -message SendEventRequest { - oneof event { - // History was pruned - search index needs full rebuild - HistoryPrunedEvent history_pruned = 1; - - // Specific history items were deleted - HistoryDeletedEvent history_deleted = 2; - - // Request immediate sync - ForceSyncEvent force_sync = 3; - - // Settings have changed, reload if needed - SettingsReloadedEvent settings_reloaded = 4; - - // Request graceful shutdown - ShutdownEvent shutdown = 5; - - // History was rebuilt - search index needs full rebuild - HistoryRebuiltEvent history_rebuilt = 6; - } -} - -message SendEventResponse { - // Empty on success; errors come through gRPC status -} - -// Individual event message types - -message HistoryPrunedEvent { - // No fields needed - just signals that pruning happened -} - -message HistoryRebuiltEvent { - // No fields needed - just signals that rebuilding happened -} - -message HistoryDeletedEvent { - // IDs of deleted history items (UUIDs as strings) - repeated string ids = 1; -} - -message ForceSyncEvent { - // No fields needed - just triggers sync -} - -message SettingsReloadedEvent { - // No fields needed - components should re-read settings -} - -message ShutdownEvent { - // No fields needed - triggers graceful shutdown -} diff --git a/crates/atuin-daemon/proto/history.proto b/crates/atuin-daemon/proto/history.proto deleted file mode 100644 index 59c12471..00000000 --- a/crates/atuin-daemon/proto/history.proto +++ /dev/null @@ -1,81 +0,0 @@ -syntax = "proto3"; -package history; - -message StartHistoryRequest { - // If people are still using my software in ~530 years, they can figure out a u128 migration - uint64 timestamp = 1; // nanosecond unix epoch - string command = 2; - string cwd = 3; - string session = 4; - string hostname = 5; - string author = 6; - string intent = 7; -} - -message EndHistoryRequest { - string id = 1; - int64 exit = 2; - uint64 duration = 3; -} - -message StartHistoryReply { - string id = 1; - string version = 2; - uint32 protocol = 3; -} - -message EndHistoryReply { - string id = 1; - uint64 idx = 2; - string version = 3; - uint32 protocol = 4; -} - -message StatusRequest {} - -message StatusReply { - bool healthy = 1; - string version = 2; - uint32 pid = 3; - uint32 protocol = 4; -} - -message ShutdownRequest {} - -message ShutdownReply { - bool accepted = 1; -} - -message TailHistoryRequest {} - -enum HistoryEventKind { - HISTORY_EVENT_KIND_UNSPECIFIED = 0; - HISTORY_EVENT_KIND_STARTED = 1; - HISTORY_EVENT_KIND_ENDED = 2; -} - -message HistoryEntry { - uint64 timestamp = 1; // nanosecond unix epoch - string id = 2; - string command = 3; - string cwd = 4; - string session = 5; - string hostname = 6; - string author = 7; - string intent = 8; - int64 exit = 9; - int64 duration = 10; -} - -message TailHistoryReply { - HistoryEventKind kind = 1; - HistoryEntry history = 2; -} - -service History { - rpc StartHistory(StartHistoryRequest) returns (StartHistoryReply); - rpc EndHistory(EndHistoryRequest) returns (EndHistoryReply); - rpc TailHistory(TailHistoryRequest) returns (stream TailHistoryReply); - rpc Status(StatusRequest) returns (StatusReply); - rpc Shutdown(ShutdownRequest) returns (ShutdownReply); -} diff --git a/crates/atuin-daemon/proto/search.proto b/crates/atuin-daemon/proto/search.proto deleted file mode 100644 index 6b84acbd..00000000 --- a/crates/atuin-daemon/proto/search.proto +++ /dev/null @@ -1,35 +0,0 @@ -syntax = "proto3"; -package search; - -enum FilterMode { - GLOBAL = 0; - HOST = 1; - SESSION = 2; - DIRECTORY = 3; - WORKSPACE = 4; - SESSION_PRELOAD = 5; -} - -message SearchContext { - string session_id = 1; - string cwd = 2; - string hostname = 3; - string host_id = 4; - optional string git_root = 5; -} - -message SearchRequest { - string query = 1; - uint64 query_id = 2; // Incrementing ID to match responses to queries - FilterMode filter_mode = 3; - SearchContext context = 4; -} - -message SearchResponse { - uint64 query_id = 1; // Echo back the query ID - repeated bytes ids = 2; -} - -service Search { - rpc Search(stream SearchRequest) returns (stream SearchResponse); -} diff --git a/crates/atuin-daemon/proto/semantic.proto b/crates/atuin-daemon/proto/semantic.proto deleted file mode 100644 index 07e550c8..00000000 --- a/crates/atuin-daemon/proto/semantic.proto +++ /dev/null @@ -1,47 +0,0 @@ -syntax = "proto3"; -package semantic; - -service Semantic { - rpc RecordCommands(stream CommandCapture) returns (RecordCommandsReply); - rpc CommandOutput(CommandOutputRequest) returns (CommandOutputReply); -} - -message CommandCapture { - string prompt = 1; - string command = 2; - string output = 3; - optional int32 exit_code = 4; - optional string history_id = 5; - optional string session_id = 6; - bool output_truncated = 7; - uint64 output_observed_bytes = 8; -} - -message RecordCommandsReply { - uint64 accepted = 1; -} - -message CommandOutputRequest { - string history_id = 1; - repeated OutputRange ranges = 2; -} - -message OutputRange { - int64 start = 1; - int64 end = 2; -} - -message OutputLine { - uint64 line_number = 1; - string content = 2; -} - -message CommandOutputReply { - bool found = 1; - string output = 2; - uint64 total_bytes = 3; - uint64 total_lines = 4; - repeated OutputLine lines = 5; - bool output_truncated = 6; - uint64 output_observed_bytes = 7; -} diff --git a/crates/atuin-daemon/src/client.rs b/crates/atuin-daemon/src/client.rs deleted file mode 100644 index c18e0e46..00000000 --- a/crates/atuin-daemon/src/client.rs +++ /dev/null @@ -1,518 +0,0 @@ -use atuin_client::database::Context; -use atuin_client::settings::{FilterMode, Settings}; -use eyre::{Context as EyreContext, Result}; -#[cfg(windows)] -use tokio::net::TcpStream; -use tonic::Code; -use tonic::transport::{Channel, Endpoint, Uri}; -use tower::service_fn; - -use hyper_util::rt::TokioIo; - -#[cfg(unix)] -use tokio::net::UnixStream; - -use atuin_client::history::History; -use tracing::{Level, instrument, span}; - -use crate::control::HistoryRebuiltEvent; -use crate::control::{ - ForceSyncEvent, HistoryDeletedEvent, HistoryPrunedEvent, SendEventRequest, - SettingsReloadedEvent, ShutdownEvent, control_client::ControlClient as ControlServiceClient, -}; -use crate::events::DaemonEvent; -use crate::history::{ - EndHistoryReply, EndHistoryRequest, ShutdownRequest, StartHistoryReply, StartHistoryRequest, - StatusReply, StatusRequest, TailHistoryReply, TailHistoryRequest, - history_client::HistoryClient as HistoryServiceClient, -}; -use crate::search::{ - FilterMode as RpcFilterMode, SearchContext as RpcSearchContext, SearchRequest, SearchResponse, - search_client::SearchClient as SearchServiceClient, -}; -use crate::semantic::{ - CommandCapture, CommandOutputReply, CommandOutputRequest, OutputRange, RecordCommandsReply, - semantic_client::SemanticClient as SemanticServiceClient, -}; - -pub struct HistoryClient { - client: HistoryServiceClient, -} - -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub enum DaemonClientErrorKind { - Connect, - Unavailable, - Unimplemented, - Other, -} - -#[must_use] -pub fn classify_error(error: &eyre::Report) -> DaemonClientErrorKind { - for cause in error.chain() { - if cause.downcast_ref::().is_some() { - return DaemonClientErrorKind::Connect; - } - - if let Some(status) = cause.downcast_ref::() { - return match status.code() { - Code::Unavailable => DaemonClientErrorKind::Unavailable, - Code::Unimplemented => DaemonClientErrorKind::Unimplemented, - _ => DaemonClientErrorKind::Other, - }; - } - } - - DaemonClientErrorKind::Other -} - -// Wrap the grpc client -impl HistoryClient { - #[cfg(unix)] - pub async fn new(path: String) -> Result { - use eyre::Context; - - let log_path = path.clone(); - let channel = Endpoint::try_from("http://atuin_local_daemon:0")? - .connect_with_connector(service_fn(move |_: Uri| { - let path = path.clone(); - - async move { - Ok::<_, std::io::Error>(TokioIo::new(UnixStream::connect(path.clone()).await?)) - } - })) - .await - .wrap_err_with(|| { - format!( - "failed to connect to local atuin daemon at {}. Is it running?", - &log_path - ) - })?; - - let client = HistoryServiceClient::new(channel); - - Ok(HistoryClient { client }) - } - - #[cfg(not(unix))] - pub async fn new(port: u64) -> Result { - let channel = Endpoint::try_from("http://atuin_local_daemon:0")? - .connect_with_connector(service_fn(move |_: Uri| { - let url = format!("127.0.0.1:{port}"); - - async move { - Ok::<_, std::io::Error>(TokioIo::new(TcpStream::connect(url.clone()).await?)) - } - })) - .await - .wrap_err_with(|| { - format!( - "failed to connect to local atuin daemon at 127.0.0.1:{port}. Is it running?" - ) - })?; - - let client = HistoryServiceClient::new(channel); - - Ok(HistoryClient { client }) - } - - pub async fn start_history(&mut self, h: History) -> Result { - let req = StartHistoryRequest { - command: h.command, - cwd: h.cwd, - hostname: h.hostname, - session: h.session, - timestamp: h.timestamp.unix_timestamp_nanos() as u64, - author: h.author, - intent: h.intent.unwrap_or_default(), - }; - - Ok(self.client.start_history(req).await?.into_inner()) - } - - pub async fn end_history( - &mut self, - id: String, - duration: u64, - exit: i64, - ) -> Result { - let req = EndHistoryRequest { id, duration, exit }; - - Ok(self.client.end_history(req).await?.into_inner()) - } - - pub async fn status(&mut self) -> Result { - Ok(self.client.status(StatusRequest {}).await?.into_inner()) - } - - pub async fn tail_history(&mut self) -> Result> { - Ok(self - .client - .tail_history(TailHistoryRequest {}) - .await? - .into_inner()) - } - - pub async fn shutdown(&mut self) -> Result { - let resp = self.client.shutdown(ShutdownRequest {}).await?.into_inner(); - Ok(resp.accepted) - } -} - -pub struct SearchClient { - client: SearchServiceClient, -} - -impl SearchClient { - #[cfg(unix)] - pub async fn new(path: String) -> Result { - let log_path = path.clone(); - let channel = Endpoint::try_from("http://atuin_local_daemon:0")? - .connect_with_connector(service_fn(move |_: Uri| { - let path = path.clone(); - - async move { - Ok::<_, std::io::Error>(TokioIo::new(UnixStream::connect(path.clone()).await?)) - } - })) - .await - .wrap_err_with(|| { - format!( - "failed to connect to local atuin daemon at {}. Is it running?", - &log_path - ) - })?; - - let client = SearchServiceClient::new(channel); - - Ok(SearchClient { client }) - } - - #[cfg(not(unix))] - pub async fn new(port: u64) -> Result { - let channel = Endpoint::try_from("http://atuin_local_daemon:0")? - .connect_with_connector(service_fn(move |_: Uri| { - let url = format!("127.0.0.1:{port}"); - - async move { - Ok::<_, std::io::Error>(TokioIo::new(TcpStream::connect(url.clone()).await?)) - } - })) - .await - .wrap_err_with(|| { - format!( - "failed to connect to local atuin daemon at 127.0.0.1:{port}. Is it running?" - ) - })?; - - let client = SearchServiceClient::new(channel); - - Ok(SearchClient { client }) - } - - #[instrument(skip_all, level = Level::TRACE, name = "daemon_client_search", fields(query = %query, query_id = query_id))] - pub async fn search( - &mut self, - query: String, - query_id: u64, - filter_mode: FilterMode, - context: Option, - ) -> Result> { - let request = SearchRequest { - query, - query_id, - filter_mode: RpcFilterMode::from(filter_mode).into(), - context: context.map(RpcSearchContext::from), - }; - let request_stream = tokio_stream::once(request); - let response = span!(Level::TRACE, "daemon_client_search.request") - .in_scope(async || self.client.search(request_stream).await) - .await?; - - Ok(response.into_inner()) - } -} - -impl From for RpcFilterMode { - fn from(filter_mode: FilterMode) -> Self { - match filter_mode { - FilterMode::Global => RpcFilterMode::Global, - FilterMode::Host => RpcFilterMode::Host, - FilterMode::Session => RpcFilterMode::Session, - FilterMode::Directory => RpcFilterMode::Directory, - FilterMode::Workspace => RpcFilterMode::Workspace, - FilterMode::SessionPreload => RpcFilterMode::SessionPreload, - } - } -} - -impl From for RpcSearchContext { - fn from(context: Context) -> Self { - RpcSearchContext { - session_id: context.session, - cwd: context.cwd, - hostname: context.hostname, - host_id: context.host_id, - git_root: context - .git_root - .map(|path| path.to_string_lossy().to_string()), - } - } -} - -pub struct SemanticClient { - client: SemanticServiceClient, -} - -impl SemanticClient { - #[cfg(unix)] - pub async fn new(path: String) -> Result { - let log_path = path.clone(); - let channel = Endpoint::try_from("http://atuin_local_daemon:0")? - .connect_with_connector(service_fn(move |_: Uri| { - let path = path.clone(); - - async move { - Ok::<_, std::io::Error>(TokioIo::new(UnixStream::connect(path.clone()).await?)) - } - })) - .await - .wrap_err_with(|| { - format!( - "failed to connect to local atuin daemon at {}. Is it running?", - &log_path - ) - })?; - - let client = SemanticServiceClient::new(channel); - - Ok(SemanticClient { client }) - } - - #[cfg(not(unix))] - pub async fn new(port: u64) -> Result { - let channel = Endpoint::try_from("http://atuin_local_daemon:0")? - .connect_with_connector(service_fn(move |_: Uri| { - let url = format!("127.0.0.1:{port}"); - - async move { - Ok::<_, std::io::Error>(TokioIo::new(TcpStream::connect(url.clone()).await?)) - } - })) - .await - .wrap_err_with(|| { - format!( - "failed to connect to local atuin daemon at 127.0.0.1:{port}. Is it running?" - ) - })?; - - let client = SemanticServiceClient::new(channel); - - Ok(SemanticClient { client }) - } - - #[cfg(unix)] - pub async fn from_settings(settings: &Settings) -> Result { - Self::new(settings.daemon.socket_path.clone()).await - } - - #[cfg(not(unix))] - pub async fn from_settings(settings: &Settings) -> Result { - Self::new(settings.daemon.tcp_port).await - } - - pub async fn record_commands( - &mut self, - captures: Vec, - ) -> Result { - let stream = tokio_stream::iter(captures); - Ok(self.client.record_commands(stream).await?.into_inner()) - } - - pub async fn command_output( - &mut self, - history_id: String, - ranges: Vec<(i64, i64)>, - ) -> Result { - let request = CommandOutputRequest { - history_id, - ranges: ranges - .into_iter() - .map(|(start, end)| OutputRange { start, end }) - .collect(), - }; - - Ok(self.client.command_output(request).await?.into_inner()) - } -} - -// ============================================================================ -// Control Client -// ============================================================================ - -/// Client for the Control gRPC service. -/// -/// Used to inject events into a running daemon from external processes. -pub struct ControlClient { - client: ControlServiceClient, -} - -impl ControlClient { - /// Connect to the daemon's control service. - #[cfg(unix)] - pub async fn new(path: String) -> Result { - let log_path = path.clone(); - let channel = Endpoint::try_from("http://atuin_local_daemon:0")? - .connect_with_connector(service_fn(move |_: Uri| { - let path = path.clone(); - - async move { - Ok::<_, std::io::Error>(TokioIo::new(UnixStream::connect(path.clone()).await?)) - } - })) - .await - .wrap_err_with(|| { - format!( - "failed to connect to local atuin daemon at {}. Is it running?", - &log_path - ) - })?; - - let client = ControlServiceClient::new(channel); - - Ok(ControlClient { client }) - } - - /// Connect to the daemon's control service. - #[cfg(not(unix))] - pub async fn new(port: u64) -> Result { - let channel = Endpoint::try_from("http://atuin_local_daemon:0")? - .connect_with_connector(service_fn(move |_: Uri| { - let url = format!("127.0.0.1:{port}"); - - async move { - Ok::<_, std::io::Error>(TokioIo::new(TcpStream::connect(url.clone()).await?)) - } - })) - .await - .wrap_err_with(|| { - format!( - "failed to connect to local atuin daemon at 127.0.0.1:{port}. Is it running?" - ) - })?; - - let client = ControlServiceClient::new(channel); - - Ok(ControlClient { client }) - } - - /// Connect using settings. - #[cfg(unix)] - pub async fn from_settings(settings: &Settings) -> Result { - Self::new(settings.daemon.socket_path.clone()).await - } - - /// Connect using settings. - #[cfg(not(unix))] - pub async fn from_settings(settings: &Settings) -> Result { - Self::new(settings.daemon.tcp_port).await - } - - /// Send an event to the daemon. - pub async fn send_event(&mut self, event: DaemonEvent) -> Result<()> { - let proto_event = daemon_event_to_proto(event); - let request = SendEventRequest { - event: Some(proto_event), - }; - self.client.send_event(request).await?; - Ok(()) - } -} - -/// Convert a daemon event to its proto representation. -fn daemon_event_to_proto(event: DaemonEvent) -> crate::control::send_event_request::Event { - use crate::control::send_event_request::Event; - - match event { - DaemonEvent::HistoryPruned => Event::HistoryPruned(HistoryPrunedEvent {}), - DaemonEvent::HistoryRebuilt => Event::HistoryRebuilt(HistoryRebuiltEvent {}), - DaemonEvent::HistoryDeleted { ids } => Event::HistoryDeleted(HistoryDeletedEvent { - ids: ids.into_iter().map(|id| id.0).collect(), - }), - DaemonEvent::ForceSync => Event::ForceSync(ForceSyncEvent {}), - DaemonEvent::SettingsReloaded => Event::SettingsReloaded(SettingsReloadedEvent {}), - DaemonEvent::ShutdownRequested => Event::Shutdown(ShutdownEvent {}), - // These events are internal and not sent via the control service - DaemonEvent::HistoryStarted(_) - | DaemonEvent::HistoryEnded(_) - | DaemonEvent::RecordsAdded(_) - | DaemonEvent::SyncCompleted { .. } - | DaemonEvent::SyncFailed { .. } => { - // Use shutdown as a fallback, though this shouldn't happen - tracing::warn!("attempted to send internal event via control service"); - Event::Shutdown(ShutdownEvent {}) - } - } -} - -// ============================================================================ -// Convenience Functions -// ============================================================================ - -/// Emit an event to the daemon. -/// -/// This is a fire-and-forget helper for sending events to the daemon from -/// external processes like CLI commands. If the daemon isn't running, this -/// will silently succeed (returns Ok). -/// -/// # Example -/// -/// ```ignore -/// // After pruning history -/// emit_event(DaemonEvent::HistoryPruned).await?; -/// -/// // After deleting specific history items -/// emit_event(DaemonEvent::HistoryDeleted { ids: vec![...] }).await?; -/// -/// // Request immediate sync -/// emit_event(DaemonEvent::ForceSync).await?; -/// ``` -pub async fn emit_event(event: DaemonEvent) -> Result<()> { - emit_event_with_settings(event, None).await -} - -/// Emit an event to the daemon with explicit settings. -/// -/// If settings are not provided, they will be loaded from the default location. -/// If the daemon isn't running, this will silently succeed. -pub async fn emit_event_with_settings( - event: DaemonEvent, - settings: Option<&Settings>, -) -> Result<()> { - // Load settings if not provided - let owned_settings; - let settings = match settings { - Some(s) => s, - None => { - owned_settings = Settings::new()?; - &owned_settings - } - }; - - // Try to connect - if daemon isn't running, that's fine - let mut client = match ControlClient::from_settings(settings).await { - Ok(c) => c, - Err(e) => { - tracing::debug!(?e, "daemon not running, skipping event emission"); - return Ok(()); - } - }; - - // Send the event - if let Err(e) = client.send_event(event).await { - tracing::debug!(?e, "failed to send event to daemon"); - // Don't fail - this is fire-and-forget - } - - Ok(()) -} diff --git a/crates/atuin-daemon/src/components/history.rs b/crates/atuin-daemon/src/components/history.rs deleted file mode 100644 index c82c8f94..00000000 --- a/crates/atuin-daemon/src/components/history.rs +++ /dev/null @@ -1,327 +0,0 @@ -//! History component. -//! -//! Handles command history lifecycle (start/end) and provides the History gRPC service. - -use std::{pin::Pin, sync::Arc}; - -use atuin_client::{ - database::Database, - history::{History, HistoryId, store::HistoryStore}, - settings::Settings, -}; -use dashmap::DashMap; -use eyre::Result; -use time::OffsetDateTime; -use tokio_stream::Stream; -use tonic::{Request, Response, Status}; -use tracing::{Level, instrument}; - -use crate::{ - daemon::{Component, DaemonHandle}, - events::DaemonEvent, - history::{ - EndHistoryReply, EndHistoryRequest, HistoryEntry, HistoryEventKind, ShutdownReply, - ShutdownRequest, StartHistoryReply, StartHistoryRequest, StatusReply, StatusRequest, - TailHistoryReply, TailHistoryRequest, - history_server::{History as HistorySvc, HistoryServer}, - }, -}; - -const DAEMON_PROTOCOL_VERSION: u32 = 1; - -/// History component - manages command history lifecycle. -/// -/// This component: -/// - Tracks currently running commands (stored in memory) -/// - Saves completed commands to the database and record store -/// - Emits history events for other components (e.g., search indexing) -/// - Provides the History gRPC service -pub struct HistoryComponent { - inner: Arc, -} - -struct HistoryComponentInner { - /// Commands currently running (not yet completed). - running: DashMap, - - /// Handle to the daemon (set during start). - handle: tokio::sync::RwLock>, - - /// History store for pushing records (set during start). - history_store: tokio::sync::RwLock>, -} - -impl HistoryComponent { - /// Create a new history component. - pub fn new() -> Self { - Self { - inner: Arc::new(HistoryComponentInner { - running: DashMap::new(), - handle: tokio::sync::RwLock::new(None), - history_store: tokio::sync::RwLock::new(None), - }), - } - } - - /// Get the gRPC service for this component. - /// - /// This returns a tonic service that can be added to a gRPC server. - pub fn grpc_service(&self) -> HistoryServer { - HistoryServer::new(HistoryGrpcService { - inner: self.inner.clone(), - }) - } -} - -impl Default for HistoryComponent { - fn default() -> Self { - Self::new() - } -} - -#[tonic::async_trait] -impl Component for HistoryComponent { - fn name(&self) -> &'static str { - "history" - } - - async fn start(&mut self, handle: DaemonHandle) -> Result<()> { - // Create the history store - let host_id = Settings::host_id().await?; - let history_store = - HistoryStore::new(handle.store().clone(), host_id, *handle.encryption_key()); - - *self.inner.history_store.write().await = Some(history_store); - *self.inner.handle.write().await = Some(handle); - - tracing::info!("history component started"); - Ok(()) - } - - async fn handle_event(&mut self, _event: &DaemonEvent) -> Result<()> { - // History component produces events but doesn't need to react to them - Ok(()) - } - - async fn stop(&mut self) -> Result<()> { - tracing::info!("history component stopped"); - Ok(()) - } -} - -/// The gRPC service implementation. -/// -/// This is a thin wrapper that delegates to the component's shared state. -pub struct HistoryGrpcService { - inner: Arc, -} - -fn history_to_tail_reply(kind: HistoryEventKind, history: History) -> TailHistoryReply { - TailHistoryReply { - kind: kind as i32, - history: Some(HistoryEntry { - timestamp: history.timestamp.unix_timestamp_nanos() as u64, - id: history.id.0, - command: history.command, - cwd: history.cwd, - session: history.session, - hostname: history.hostname, - author: history.author, - intent: history.intent.unwrap_or_default(), - exit: history.exit, - duration: history.duration, - }), - } -} - -#[tonic::async_trait] -impl HistorySvc for HistoryGrpcService { - type TailHistoryStream = Pin> + Send>>; - - #[instrument(skip_all, level = Level::INFO)] - async fn start_history( - &self, - request: Request, - ) -> Result, Status> { - let req = request.into_inner(); - - let timestamp = - OffsetDateTime::from_unix_timestamp_nanos(req.timestamp as i128).map_err(|_| { - Status::invalid_argument( - "failed to parse timestamp as unix time (expected nanos since epoch)", - ) - })?; - - let h: History = History::daemon() - .timestamp(timestamp) - .command(req.command) - .cwd(req.cwd) - .session(req.session) - .hostname(req.hostname) - .author(req.author) - .intent(req.intent) - .build() - .into(); - - // Emit the event - if let Some(handle) = self.inner.handle.read().await.as_ref() { - handle.emit(DaemonEvent::HistoryStarted(h.clone())); - } - - let id = h.id.clone(); - tracing::info!(id = id.to_string(), "start history"); - self.inner.running.insert(id.clone(), h); - - let reply = StartHistoryReply { - id: id.to_string(), - version: env!("CARGO_PKG_VERSION").to_string(), - protocol: DAEMON_PROTOCOL_VERSION, - }; - - Ok(Response::new(reply)) - } - - #[instrument(skip_all, level = Level::INFO)] - async fn end_history( - &self, - request: Request, - ) -> Result, Status> { - let req = request.into_inner(); - let id = HistoryId(req.id); - - if let Some((_, mut history)) = self.inner.running.remove(&id) { - history.exit = req.exit; - history.duration = match req.duration { - 0 => i64::try_from( - (OffsetDateTime::now_utc() - history.timestamp).whole_nanoseconds(), - ) - .expect("failed to convert calculated duration to i64"), - value => i64::try_from(value).expect("failed to get i64 duration"), - }; - - // Get the handle and store to save the history - let handle_guard = self.inner.handle.read().await; - let handle = handle_guard - .as_ref() - .ok_or_else(|| Status::internal("component not initialized"))?; - - let store_guard = self.inner.history_store.read().await; - let history_store = store_guard - .as_ref() - .ok_or_else(|| Status::internal("component not initialized"))?; - - // Save to database - handle - .history_db() - .save(&history) - .await - .map_err(|e| Status::internal(format!("failed to write to db: {e:?}")))?; - - tracing::info!( - id = id.0.to_string(), - duration = history.duration, - "end history" - ); - - // Push to record store - let (record_id, idx) = history_store - .push(history.clone()) - .await - .map_err(|e| Status::internal(format!("failed to push record to store: {e:?}")))?; - - // Emit the event - handle.emit(DaemonEvent::HistoryEnded(history)); - - let reply = EndHistoryReply { - id: record_id.0.to_string(), - idx, - version: env!("CARGO_PKG_VERSION").to_string(), - protocol: DAEMON_PROTOCOL_VERSION, - }; - - return Ok(Response::new(reply)); - } - - Err(Status::not_found(format!( - "could not find history with id: {id}" - ))) - } - - #[instrument(skip_all, level = Level::INFO)] - async fn tail_history( - &self, - _request: Request, - ) -> Result, Status> { - let handle_guard = self.inner.handle.read().await; - let handle = handle_guard - .as_ref() - .cloned() - .ok_or_else(|| Status::internal("component not initialized"))?; - - let mut rx = handle.subscribe(); - let (tx, out_rx) = tokio::sync::mpsc::channel::>(128); - - tokio::spawn(async move { - loop { - let event = match rx.recv().await { - Ok(event) => event, - Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => { - let _ = tx - .send(Err(Status::resource_exhausted(format!( - "tail stream lagged behind and dropped {skipped} events" - )))) - .await; - break; - } - Err(tokio::sync::broadcast::error::RecvError::Closed) => break, - }; - - let reply = match event { - DaemonEvent::HistoryStarted(history) => { - Some(history_to_tail_reply(HistoryEventKind::Started, history)) - } - DaemonEvent::HistoryEnded(history) => { - Some(history_to_tail_reply(HistoryEventKind::Ended, history)) - } - _ => None, - }; - - if let Some(reply) = reply - && tx.send(Ok(reply)).await.is_err() - { - break; - } - } - }); - - let stream = tokio_stream::wrappers::ReceiverStream::new(out_rx); - Ok(Response::new(Box::pin(stream))) - } - - #[instrument(skip_all, level = Level::INFO)] - async fn status( - &self, - _request: Request, - ) -> Result, Status> { - let reply = StatusReply { - healthy: true, - version: env!("CARGO_PKG_VERSION").to_string(), - pid: std::process::id(), - protocol: DAEMON_PROTOCOL_VERSION, - }; - - Ok(Response::new(reply)) - } - - #[instrument(skip_all, level = Level::INFO)] - async fn shutdown( - &self, - _request: Request, - ) -> Result, Status> { - // Use the daemon handle to request shutdown - if let Some(handle) = self.inner.handle.read().await.as_ref() { - handle.shutdown(); - } - Ok(Response::new(ShutdownReply { accepted: true })) - } -} diff --git a/crates/atuin-daemon/src/components/mod.rs b/crates/atuin-daemon/src/components/mod.rs deleted file mode 100644 index 447e31df..00000000 --- a/crates/atuin-daemon/src/components/mod.rs +++ /dev/null @@ -1,25 +0,0 @@ -//! Daemon components. -//! -//! Components are the building blocks of the daemon. Each component handles -//! a specific domain and can: -//! -//! - Expose gRPC services -//! - React to events -//! - Spawn background tasks -//! -//! Available components: -//! -//! - [`history::HistoryComponent`]: Command history lifecycle management -//! - [`search::SearchComponent`]: Fuzzy search over history -//! - [`semantic::SemanticComponent`]: In-memory semantic command captures -//! - [`sync::SyncComponent`]: Cloud sync - -pub mod history; -pub mod search; -pub mod semantic; -pub mod sync; - -pub use history::HistoryComponent; -pub use search::SearchComponent; -pub use semantic::SemanticComponent; -pub use sync::SyncComponent; diff --git a/crates/atuin-daemon/src/components/search.rs b/crates/atuin-daemon/src/components/search.rs deleted file mode 100644 index 9fc87fae..00000000 --- a/crates/atuin-daemon/src/components/search.rs +++ /dev/null @@ -1,413 +0,0 @@ -//! Search component. -//! -//! Provides fuzzy search over command history using the Nucleo search library -//! with frecency-based ranking and dynamic filtering. - -use std::{pin::Pin, sync::Arc}; - -use atuin_client::database::Database; -use eyre::Result; -use tokio::sync::RwLock; -use tokio_stream::Stream; -use tonic::{Request, Response, Status, Streaming}; -use tracing::{Level, debug, info, instrument, span, trace}; -use uuid::Uuid; - -use crate::{ - daemon::{Component, DaemonHandle}, - events::DaemonEvent, - search::{ - FilterMode, IndexFilterMode, QueryContext, SearchIndex, SearchRequest, SearchResponse, - search_server::{Search as SearchSvc, SearchServer}, - }, -}; - -const PAGE_SIZE: usize = 5000; -const RESULTS_LIMIT: u32 = 200; -/// How often to rebuild the frecency map (in seconds). -const FRECENCY_REFRESH_INTERVAL_SECS: u64 = 60; - -/// Search component - provides fuzzy search over command history. -/// -/// This component: -/// - Maintains a deduplicated search index with frecency ranking -/// - Loads history from the database on startup -/// - Updates the index when history events occur -/// - Provides the Search gRPC service -pub struct SearchComponent { - index: Arc>, - handle: tokio::sync::RwLock>, - loader_handle: Option>, - frecency_handle: Option>, -} - -impl SearchComponent { - /// Create a new search component. - pub fn new() -> Self { - Self { - index: Arc::new(RwLock::new(SearchIndex::new())), - handle: tokio::sync::RwLock::new(None), - loader_handle: None, - frecency_handle: None, - } - } - - /// Get the gRPC service for this component. - pub fn grpc_service(&self) -> SearchServer { - SearchServer::new(SearchGrpcService { - index: self.index.clone(), - }) - } - - /// Rebuild the entire search index from the database. - async fn rebuild_index(&self) -> Result<()> { - let handle_guard = self.handle.read().await; - let handle = handle_guard - .as_ref() - .ok_or_else(|| eyre::eyre!("component not initialized"))?; - - info!("Rebuilding search index from database"); - - // Create a new index - let new_index = SearchIndex::new(); - - // Load all history into the new index - let db = handle.history_db().clone(); - let mut pager = db.all_paged(PAGE_SIZE, false, true); - loop { - match pager.next().await { - Ok(Some(histories)) => { - info!( - "Loading {} history entries into search index", - histories.len() - ); - new_index.add_histories(&histories); - } - Ok(None) => break, - Err(e) => { - tracing::error!("Failed to load history during rebuild: {}", e); - break; - } - } - } - - info!( - "Search index rebuild complete; {} unique commands", - new_index.command_count() - ); - - // Replace the old index with the new one - *self.index.write().await = new_index; - Ok(()) - } -} - -impl Default for SearchComponent { - fn default() -> Self { - Self::new() - } -} - -#[tonic::async_trait] -impl Component for SearchComponent { - fn name(&self) -> &'static str { - "search" - } - - async fn start(&mut self, handle: DaemonHandle) -> Result<()> { - *self.handle.write().await = Some(handle.clone()); - - // Spawn background task to load history into index - let index = self.index.clone(); - let db = handle.history_db().clone(); - let handle_for_loader = handle.clone(); - - self.loader_handle = Some(tokio::spawn(async move { - info!( - "Loading history into search index; page size = {}", - PAGE_SIZE - ); - let mut pager = db.all_paged(PAGE_SIZE, false, true); - loop { - match pager.next().await { - Ok(Some(histories)) => { - info!( - "Loading {} history entries into search index", - histories.len() - ); - index.read().await.add_histories(&histories); - } - Ok(None) => { - info!( - "Initial history load complete; {} unique commands indexed", - index.read().await.command_count() - ); - // Build initial frecency map with current settings - let settings = handle_for_loader.settings().await; - index.read().await.rebuild_frecency(&settings.search).await; - info!("Initial frecency map built"); - break; - } - Err(e) => { - tracing::error!("Failed to load history: {}", e); - break; - } - } - } - })); - - // Spawn background task to periodically refresh frecency - let index_for_frecency = self.index.clone(); - let handle_for_frecency = handle.clone(); - self.frecency_handle = Some(tokio::spawn(async move { - let mut interval = tokio::time::interval(std::time::Duration::from_secs( - FRECENCY_REFRESH_INTERVAL_SECS, - )); - loop { - interval.tick().await; - trace!("Refreshing frecency map"); - let settings = handle_for_frecency.settings().await; - index_for_frecency - .read() - .await - .rebuild_frecency(&settings.search) - .await; - } - })); - - tracing::info!("search component started"); - Ok(()) - } - - async fn handle_event(&mut self, event: &DaemonEvent) -> Result<()> { - match event { - DaemonEvent::RecordsAdded(records) => { - debug!( - count = records.len(), - "Processing added records for search index" - ); - - let handle_guard = self.handle.read().await; - if let Some(handle) = handle_guard.as_ref() { - let histories: Vec<_> = handle - .history_db() - .query_history( - format!( - "select * from history where id in ({})", - records - .iter() - .map(|record| record.0.to_string()) - .collect::>() - .join(",") - ) - .as_str(), - ) - .await - .unwrap_or_default(); - - span!(Level::TRACE, "inject_records", count = histories.len()) - .in_scope(async || { - self.index.read().await.add_histories(&histories); - }) - .await; - } - } - DaemonEvent::HistoryStarted(history) => { - debug!(id = %history.id, command = %history.command, "History started (no index action)"); - } - DaemonEvent::HistoryEnded(history) => { - span!(Level::TRACE, "inject_history_ended") - .in_scope(async || { - self.index.read().await.add_history(history); - }) - .await; - } - DaemonEvent::HistoryPruned | DaemonEvent::HistoryRebuilt => { - info!("History store pruned or rebuilt, rebuilding search index"); - if let Err(e) = self.rebuild_index().await { - tracing::error!("Failed to rebuild search index: {}", e); - } - } - DaemonEvent::HistoryDeleted { ids } => { - info!( - count = ids.len(), - "History deleted, rebuilding search index" - ); - // For now, just rebuild the entire index. A more efficient implementation - // would remove specific items from the index. - if let Err(e) = self.rebuild_index().await { - tracing::error!("Failed to rebuild search index: {}", e); - } - } - DaemonEvent::SettingsReloaded => { - info!("Settings reloaded, rebuilding frecency map with new multipliers"); - let handle_guard = self.handle.read().await; - if let Some(handle) = handle_guard.as_ref() { - let settings = handle.settings().await; - self.index - .read() - .await - .rebuild_frecency(&settings.search) - .await; - } - } - // Events we don't care about - DaemonEvent::SyncCompleted { .. } - | DaemonEvent::SyncFailed { .. } - | DaemonEvent::ForceSync - | DaemonEvent::ShutdownRequested => {} - } - Ok(()) - } - - async fn stop(&mut self) -> Result<()> { - if let Some(handle) = self.loader_handle.take() { - handle.abort(); - } - if let Some(handle) = self.frecency_handle.take() { - handle.abort(); - } - tracing::info!("search component stopped"); - Ok(()) - } -} - -/// The gRPC service implementation. -pub struct SearchGrpcService { - index: Arc>, -} - -#[tonic::async_trait] -impl SearchSvc for SearchGrpcService { - type SearchStream = Pin> + Send>>; - - #[instrument(skip_all, level = Level::TRACE, name = "search_rpc")] - async fn search( - &self, - request: Request>, - ) -> Result, Status> { - let mut in_stream = request.into_inner(); - let index = self.index.clone(); - - // Create output channel - let (tx, rx) = tokio::sync::mpsc::channel::>(128); - - // Spawn task to handle incoming requests and send responses - tokio::spawn(async move { - while let Some(req) = in_stream.message().await.transpose() { - match req { - Ok(search_req) => { - let query = search_req.query; - let query_id = search_req.query_id; - let filter_mode: FilterMode = search_req - .filter_mode - .try_into() - .unwrap_or(FilterMode::Global); - let proto_context = search_req.context; - - debug!( - "search request: query = {}, query_id = {}, filter_mode = {}, context = {:?}", - query, - query_id, - filter_mode.as_str_name(), - proto_context - ); - - // Convert proto FilterMode + context to IndexFilterMode - let index_filter = convert_filter_mode(filter_mode, &proto_context); - - // Build QueryContext from proto context - let query_context = proto_context - .map(|ctx| QueryContext { - cwd: Some(with_trailing_slash(&ctx.cwd)), - git_root: ctx.git_root.map(|s| with_trailing_slash(&s)), - hostname: Some(ctx.hostname), - session_id: Some(ctx.session_id), - }) - .unwrap_or_default(); - - // Perform the search - let history_ids = - span!(Level::TRACE, "daemon_search_query", %query, query_id) - .in_scope(|| async { - let index = index.read().await; - index - .search(&query, index_filter, &query_context, RESULTS_LIMIT) - .await - }) - .await; - - // Convert history IDs to bytes - let ids: Vec> = history_ids - .iter() - .filter_map(|id| { - Uuid::parse_str(id) - .ok() - .map(|uuid| uuid.as_bytes().to_vec()) - }) - .collect(); - - if tx.send(Ok(SearchResponse { query_id, ids })).await.is_err() { - break; // Client disconnected - } - } - Err(e) => { - let _ = tx.send(Err(e)).await; - break; - } - } - } - }); - - // Convert receiver to stream - let out_stream = tokio_stream::wrappers::ReceiverStream::new(rx); - Ok(Response::new(Box::pin(out_stream))) - } -} - -/// Convert proto FilterMode and context to IndexFilterMode. -fn convert_filter_mode( - mode: FilterMode, - context: &Option, -) -> IndexFilterMode { - match (mode, context) { - (FilterMode::Global, _) => IndexFilterMode::Global, - (FilterMode::Directory, Some(ctx)) => { - IndexFilterMode::Directory(with_trailing_slash(&ctx.cwd)) - } - (FilterMode::Workspace, Some(ctx)) => { - if let Some(ref git_root) = ctx.git_root { - IndexFilterMode::Workspace(with_trailing_slash(git_root)) - } else { - // Fall back to directory if no git root - IndexFilterMode::Directory(with_trailing_slash(&ctx.cwd)) - } - } - (FilterMode::Host, Some(ctx)) => IndexFilterMode::Host(ctx.hostname.clone()), - (FilterMode::Session, Some(ctx)) => IndexFilterMode::Session(ctx.session_id.clone()), - (FilterMode::SessionPreload, Some(ctx)) => { - // SessionPreload is similar to Session - filter by session - IndexFilterMode::Session(ctx.session_id.clone()) - } - // If no context provided, fall back to global - _ => IndexFilterMode::Global, - } -} - -#[cfg(windows)] -pub fn with_trailing_slash(s: &str) -> String { - if s.ends_with('\\') { - s.to_string() - } else { - format!("{}\\", s) - } -} - -#[cfg(not(windows))] -pub fn with_trailing_slash(s: &str) -> String { - if s.ends_with('/') { - s.to_string() - } else { - format!("{}/", s) - } -} diff --git a/crates/atuin-daemon/src/components/semantic.rs b/crates/atuin-daemon/src/components/semantic.rs deleted file mode 100644 index dff38fd3..00000000 --- a/crates/atuin-daemon/src/components/semantic.rs +++ /dev/null @@ -1,900 +0,0 @@ -//! Semantic command capture component. -//! -//! This is a prototype in-memory store for completed command captures emitted -//! by atuin-pty-proxy. It keeps recent captures per Atuin session and indexes -//! them by history ID for AI tool lookup. - -use std::collections::{HashMap, VecDeque}; -use std::fmt::{Display, Formatter}; -use std::sync::Arc; - -use atuin_client::history::{History, HistoryId}; -use eyre::Result; -use tokio::sync::Mutex; -use tonic::{Request, Response, Status, Streaming}; -use tracing::{Level, instrument}; - -use crate::{ - daemon::{Component, DaemonHandle}, - events::DaemonEvent, - semantic::{ - CommandCapture, CommandOutputReply, CommandOutputRequest, OutputLine, RecordCommandsReply, - semantic_server::{Semantic as SemanticSvc, SemanticServer}, - }, -}; - -const MAX_SESSIONS: usize = 20; -const MAX_COMMANDS_PER_SESSION: usize = 128; -const MAX_BYTES_PER_SESSION: usize = 32 * 1024 * 1024; -const MAX_PENDING_HISTORIES: usize = 128; - -/// Stores completed command captures and associates them with history events. -pub struct SemanticComponent { - inner: Arc, -} - -struct SemanticComponentInner { - state: Mutex, -} - -#[derive(Default)] -struct SemanticState { - sessions: HashMap, - session_lru: VecDeque, - history_index: HashMap, - pending_histories: VecDeque, -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -struct SessionId(String); - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -struct CaptureId(u64); - -#[derive(Debug, Clone, PartialEq, Eq)] -struct CaptureRef { - session_id: SessionId, - capture_id: CaptureId, -} - -#[derive(Default)] -struct SessionCaptures { - next_id: u64, - records: VecDeque, - output_bytes: usize, -} - -struct StoredCapture { - id: CaptureId, - history_id: HistoryId, - output_bytes: usize, - record: SemanticCommandRecord, -} - -struct EvictedCapture { - history_id: HistoryId, - capture_id: CaptureId, -} - -#[derive(Debug, Clone)] -struct SemanticCommandRecord { - capture: CommandCapture, - history: Option, -} - -impl SemanticComponent { - pub fn new() -> Self { - Self { - inner: Arc::new(SemanticComponentInner { - state: Mutex::new(SemanticState::default()), - }), - } - } - - pub fn grpc_service(&self) -> SemanticServer { - SemanticServer::new(SemanticGrpcService { - inner: self.inner.clone(), - }) - } -} - -impl Default for SemanticComponent { - fn default() -> Self { - Self::new() - } -} - -#[tonic::async_trait] -impl Component for SemanticComponent { - fn name(&self) -> &'static str { - "semantic" - } - - async fn start(&mut self, _handle: DaemonHandle) -> Result<()> { - tracing::info!("semantic component started"); - Ok(()) - } - - async fn handle_event(&mut self, event: &DaemonEvent) -> Result<()> { - if let DaemonEvent::HistoryEnded(history) = event { - self.inner.record_history(history.clone()).await; - } - - Ok(()) - } - - async fn stop(&mut self) -> Result<()> { - let state = self.inner.state.lock().await; - tracing::info!( - sessions = state.sessions.len(), - records = state.record_count(), - indexed_histories = state.history_index.len(), - pending_histories = state.pending_histories.len(), - "semantic component stopped" - ); - Ok(()) - } -} - -impl SemanticComponentInner { - async fn record_capture(&self, capture: CommandCapture) -> bool { - let mut state = self.state.lock().await; - state.record_capture(capture) - } - - async fn record_history(&self, history: History) { - let mut state = self.state.lock().await; - state.record_history(history); - } - - async fn command_output(&self, request: &CommandOutputRequest) -> CommandOutputReply { - let mut state = self.state.lock().await; - state.command_output(request) - } -} - -impl SemanticState { - fn record_capture(&mut self, mut capture: CommandCapture) -> bool { - let Some(history_id) = history_id_from_str(capture.history_id.as_deref()) else { - tracing::debug!( - command_bytes = capture.command.len(), - prompt_bytes = capture.prompt.len(), - output_bytes = capture.output.len(), - output_truncated = capture.output_truncated, - "dropping semantic command capture without history id" - ); - return false; - }; - - let history = take_pending_history(&mut self.pending_histories, &history_id); - let Some(session_id) = capture - .session_id - .as_deref() - .and_then(|session_id| SessionId::try_from(session_id).ok()) - .or_else(|| { - history - .as_ref() - .and_then(|history| SessionId::try_from(history.session.as_str()).ok()) - }) - else { - tracing::debug!( - history_id = %history_id, - command_bytes = capture.command.len(), - prompt_bytes = capture.prompt.len(), - output_bytes = capture.output.len(), - output_truncated = capture.output_truncated, - "dropping semantic command capture without session id" - ); - return false; - }; - - capture.history_id = Some(history_id.to_string()); - capture.session_id = Some(session_id.to_string()); - if capture.output_observed_bytes == 0 { - capture.output_observed_bytes = capture.output.len() as u64; - } - - let record = SemanticCommandRecord { capture, history }; - log_record(&record, "recorded semantic command capture"); - self.push_record(session_id, history_id, record); - true - } - - fn record_history(&mut self, history: History) { - let history_id = history.id.clone(); - - if let Some(capture_ref) = self.history_index.get(&history_id).cloned() { - if let Some(stored) = self.stored_capture_mut(&capture_ref) { - stored.record.history = Some(history); - log_record( - &stored.record, - "associated semantic command capture with history", - ); - return; - } - - self.history_index.remove(&history_id); - } - - tracing::debug!( - id = %history.id, - command_bytes = history.command.len(), - "history ended before semantic capture arrived" - ); - push_pending_history(&mut self.pending_histories, history); - } - - fn command_output(&mut self, request: &CommandOutputRequest) -> CommandOutputReply { - let Some(history_id) = history_id_from_str(Some(&request.history_id)) else { - return command_output_not_found(); - }; - let Some(capture_ref) = self.history_index.get(&history_id).cloned() else { - return command_output_not_found(); - }; - - let Some(reply) = self.command_output_for_ref(&capture_ref, &request.ranges) else { - self.history_index.remove(&history_id); - return command_output_not_found(); - }; - - self.touch_session(&capture_ref.session_id); - reply - } - - fn command_output_for_ref( - &self, - capture_ref: &CaptureRef, - ranges: &[crate::semantic::OutputRange], - ) -> Option { - let stored = self - .sessions - .get(&capture_ref.session_id)? - .stored_capture(capture_ref.capture_id)?; - let output = &stored.record.capture.output; - let output_observed_bytes = stored - .record - .capture - .output_observed_bytes - .max(output.len() as u64); - - Some(CommandOutputReply { - found: true, - output: String::new(), - total_bytes: output.len() as u64, - total_lines: output.lines().count() as u64, - lines: select_output_ranges(output, ranges), - output_truncated: stored.record.capture.output_truncated, - output_observed_bytes, - }) - } - - fn push_record( - &mut self, - session_id: SessionId, - history_id: HistoryId, - record: SemanticCommandRecord, - ) { - self.touch_session(&session_id); - - let (capture_id, evicted) = { - let session = self.sessions.entry(session_id.clone()).or_default(); - session.push(history_id.clone(), record) - }; - - let capture_ref = CaptureRef { - session_id: session_id.clone(), - capture_id, - }; - self.history_index.insert(history_id, capture_ref); - - for evicted in evicted { - self.remove_history_index_if_matches( - &session_id, - &evicted.history_id, - evicted.capture_id, - ); - } - - self.expire_lru_sessions(); - } - - fn touch_session(&mut self, session_id: &SessionId) { - if let Some(index) = self.session_lru.iter().position(|id| id == session_id) { - self.session_lru.remove(index); - } - self.session_lru.push_back(session_id.clone()); - } - - fn expire_lru_sessions(&mut self) { - while self.session_lru.len() > MAX_SESSIONS { - let Some(session_id) = self.session_lru.pop_front() else { - break; - }; - let Some(session) = self.sessions.remove(&session_id) else { - continue; - }; - - for stored in session.records { - self.remove_history_index_if_matches(&session_id, &stored.history_id, stored.id); - } - } - } - - fn remove_history_index_if_matches( - &mut self, - session_id: &SessionId, - history_id: &HistoryId, - capture_id: CaptureId, - ) { - if self - .history_index - .get(history_id) - .is_some_and(|capture_ref| { - &capture_ref.session_id == session_id && capture_ref.capture_id == capture_id - }) - { - self.history_index.remove(history_id); - } - } - - fn stored_capture_mut(&mut self, capture_ref: &CaptureRef) -> Option<&mut StoredCapture> { - self.sessions - .get_mut(&capture_ref.session_id)? - .stored_capture_mut(capture_ref.capture_id) - } - - fn record_count(&self) -> usize { - self.sessions - .values() - .map(|session| session.records.len()) - .sum() - } -} - -impl SessionCaptures { - fn push( - &mut self, - history_id: HistoryId, - record: SemanticCommandRecord, - ) -> (CaptureId, Vec) { - self.push_with_limits( - history_id, - record, - MAX_COMMANDS_PER_SESSION, - MAX_BYTES_PER_SESSION, - ) - } - - fn push_with_limits( - &mut self, - history_id: HistoryId, - record: SemanticCommandRecord, - max_commands: usize, - max_output_bytes: usize, - ) -> (CaptureId, Vec) { - let capture_id = CaptureId(self.next_id); - self.next_id = self.next_id.saturating_add(1); - let output_bytes = record.capture.output.len(); - self.output_bytes = self.output_bytes.saturating_add(output_bytes); - self.records.push_back(StoredCapture { - id: capture_id, - history_id, - output_bytes, - record, - }); - - ( - capture_id, - self.evict_to_limits(max_commands, max_output_bytes), - ) - } - - fn evict_to_limits( - &mut self, - max_commands: usize, - max_output_bytes: usize, - ) -> Vec { - let mut evicted = Vec::new(); - while self.records.len() > max_commands || self.output_bytes > max_output_bytes { - let Some(record) = self.records.pop_front() else { - break; - }; - self.output_bytes = self.output_bytes.saturating_sub(record.output_bytes); - evicted.push(EvictedCapture { - history_id: record.history_id, - capture_id: record.id, - }); - } - evicted - } - - fn stored_capture(&self, capture_id: CaptureId) -> Option<&StoredCapture> { - self.records.iter().find(|record| record.id == capture_id) - } - - fn stored_capture_mut(&mut self, capture_id: CaptureId) -> Option<&mut StoredCapture> { - self.records - .iter_mut() - .find(|record| record.id == capture_id) - } -} - -impl TryFrom<&str> for SessionId { - type Error = (); - - fn try_from(value: &str) -> std::result::Result { - let value = value.trim(); - if value.is_empty() { - return Err(()); - } - - Ok(Self(value.to_string())) - } -} - -impl TryFrom for SessionId { - type Error = (); - - fn try_from(value: String) -> std::result::Result { - Self::try_from(value.as_str()) - } -} - -impl AsRef for SessionId { - fn as_ref(&self) -> &str { - &self.0 - } -} - -impl Display for SessionId { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.write_str(&self.0) - } -} - -pub struct SemanticGrpcService { - inner: Arc, -} - -#[tonic::async_trait] -impl SemanticSvc for SemanticGrpcService { - #[instrument(skip_all, level = Level::INFO)] - async fn record_commands( - &self, - request: Request>, - ) -> Result, Status> { - let mut stream = request.into_inner(); - let mut accepted = 0_u64; - - while let Some(capture) = stream.message().await? { - if self.inner.record_capture(capture).await { - accepted += 1; - } - } - - Ok(Response::new(RecordCommandsReply { accepted })) - } - - #[instrument(skip_all, level = Level::INFO)] - async fn command_output( - &self, - request: Request, - ) -> Result, Status> { - let request = request.into_inner(); - if request.history_id.trim().is_empty() { - return Err(Status::invalid_argument("history_id is required")); - } - - Ok(Response::new(self.inner.command_output(&request).await)) - } -} - -fn history_id_from_str(value: Option<&str>) -> Option { - let value = value?.trim(); - (!value.is_empty()).then(|| HistoryId(value.to_string())) -} - -fn take_pending_history( - histories: &mut VecDeque, - history_id: &HistoryId, -) -> Option { - let index = histories - .iter() - .position(|history| &history.id == history_id)?; - histories.remove(index) -} - -fn push_pending_history(histories: &mut VecDeque, history: History) { - if let Some(index) = histories - .iter() - .position(|pending| pending.id == history.id) - { - histories.remove(index); - } - - histories.push_back(history); - trim_front(histories, MAX_PENDING_HISTORIES); -} - -fn trim_front(records: &mut VecDeque, max_len: usize) { - while records.len() > max_len { - records.pop_front(); - } -} - -fn command_output_not_found() -> CommandOutputReply { - CommandOutputReply { - found: false, - output: String::new(), - total_bytes: 0, - total_lines: 0, - lines: Vec::new(), - output_truncated: false, - output_observed_bytes: 0, - } -} - -fn select_output_ranges(output: &str, ranges: &[crate::semantic::OutputRange]) -> Vec { - let lines: Vec<&str> = output.lines().collect(); - if lines.is_empty() { - return Vec::new(); - } - - let ranges = if ranges.is_empty() { - vec![crate::semantic::OutputRange { start: 0, end: 999 }] - } else { - ranges.to_vec() - }; - - let mut ranges = ranges - .into_iter() - .filter_map(|range| normalize_line_range(range.start, range.end, lines.len())) - .collect::>(); - ranges.sort_unstable_by_key(|(start, _)| *start); - - let mut merged: Vec<(usize, usize)> = Vec::new(); - for (start, end) in ranges { - match merged.last_mut() { - Some((_, merged_end)) if start <= merged_end.saturating_add(1) => { - *merged_end = (*merged_end).max(end); - } - _ => merged.push((start, end)), - } - } - - merged - .into_iter() - .flat_map(|(start, end)| { - lines[start..=end] - .iter() - .enumerate() - .map(move |(offset, line)| OutputLine { - line_number: (start + offset + 1) as u64, - content: (*line).to_string(), - }) - }) - .collect() -} - -fn normalize_line_range(start: i64, end: i64, line_count: usize) -> Option<(usize, usize)> { - let line_count = i64::try_from(line_count).ok()?; - let start = if start < 0 { line_count + start } else { start }; - let end = if end < 0 { line_count + end } else { end }; - - if end < 0 || start >= line_count { - return None; - } - - let start = start.max(0); - let end = end.min(line_count - 1); - - (start <= end).then_some((start as usize, end as usize)) -} - -fn log_record(record: &SemanticCommandRecord, message: &'static str) { - let history_id = record.capture.history_id.as_deref().unwrap_or(""); - let associated_history_id = record - .history - .as_ref() - .map(|history| history.id.to_string()); - let exit = record.history.as_ref().map(|history| history.exit); - let duration = record.history.as_ref().map(|history| history.duration); - let author = record - .history - .as_ref() - .map(|history| history.author.as_str()); - let session_id = record.capture.session_id.as_deref(); - - tracing::debug!( - history_id = %history_id, - associated_history_id = ?associated_history_id, - session_id = ?session_id, - command_bytes = record.capture.command.len(), - prompt_bytes = record.capture.prompt.len(), - output_bytes = record.capture.output.len(), - output_truncated = record.capture.output_truncated, - output_observed_bytes = record.capture.output_observed_bytes, - capture_exit_code = ?record.capture.exit_code, - history_exit = ?exit, - duration = ?duration, - author = ?author, - "{message}" - ); -} - -#[cfg(test)] -mod tests { - use super::*; - use time::OffsetDateTime; - - fn history(id: &str, session: &str, command: &str) -> History { - History { - id: HistoryId(id.to_string()), - timestamp: OffsetDateTime::UNIX_EPOCH, - duration: 0, - exit: 0, - command: command.to_string(), - cwd: String::new(), - session: session.to_string(), - hostname: String::new(), - author: String::new(), - intent: None, - deleted_at: None, - } - } - - fn capture(history_id: Option<&str>, session_id: Option<&str>, output: &str) -> CommandCapture { - CommandCapture { - prompt: String::new(), - command: String::new(), - output: output.to_string(), - exit_code: None, - history_id: history_id.map(str::to_string), - session_id: session_id.map(str::to_string), - output_truncated: false, - output_observed_bytes: output.len() as u64, - } - } - - fn command_output(state: &mut SemanticState, history_id: &str) -> CommandOutputReply { - state.command_output(&CommandOutputRequest { - history_id: history_id.to_string(), - ranges: Vec::new(), - }) - } - - fn output_line(line_number: u64, content: &str) -> OutputLine { - OutputLine { - line_number, - content: content.to_string(), - } - } - - #[test] - fn drops_capture_without_history_id() { - let mut state = SemanticState::default(); - - assert!(!state.record_capture(capture(None, Some("session-1"), "output"))); - assert!(!command_output(&mut state, "id-1").found); - assert_eq!(state.record_count(), 0); - } - - #[test] - fn stores_capture_by_session_and_history_id() { - let mut state = SemanticState::default(); - - assert!(state.record_capture(capture(Some("id-1"), Some("session-1"), "output"))); - - let reply = command_output(&mut state, "id-1"); - assert!(reply.found); - assert_eq!(reply.total_bytes, 6); - assert_eq!(reply.output_observed_bytes, 6); - assert_eq!(reply.lines, vec![output_line(1, "output")]); - } - - #[test] - fn uses_pending_history_session_when_capture_session_is_missing() { - let mut state = SemanticState::default(); - - state.record_history(history("id-1", "session-from-history", "cargo test")); - assert!(state.record_capture(capture(Some("id-1"), None, "output"))); - - assert!( - state - .sessions - .contains_key(&SessionId("session-from-history".to_string())) - ); - assert!(command_output(&mut state, "id-1").found); - } - - #[test] - fn associates_history_by_id_after_capture_arrives() { - let mut state = SemanticState::default(); - - assert!(state.record_capture(capture(Some("id-1"), Some("session-1"), "output"))); - state.record_history(history("id-1", "session-1", "different command")); - - let capture_ref = state - .history_index - .get(&HistoryId("id-1".to_string())) - .unwrap(); - let stored = state - .sessions - .get(&capture_ref.session_id) - .unwrap() - .stored_capture(capture_ref.capture_id) - .unwrap(); - assert!(stored.record.history.is_some()); - } - - #[test] - fn evicts_oldest_command_when_session_ring_is_full() { - let mut state = SemanticState::default(); - - for index in 0..=MAX_COMMANDS_PER_SESSION { - assert!(state.record_capture(capture( - Some(&format!("id-{index}")), - Some("session-1"), - "output", - ))); - } - - assert!(!command_output(&mut state, "id-0").found); - assert!(command_output(&mut state, &format!("id-{MAX_COMMANDS_PER_SESSION}")).found); - assert_eq!(state.record_count(), MAX_COMMANDS_PER_SESSION); - } - - #[test] - fn evicts_oldest_session_after_lru_limit() { - let mut state = SemanticState::default(); - - for index in 0..MAX_SESSIONS { - assert!(state.record_capture(capture( - Some(&format!("id-{index}")), - Some(&format!("session-{index}")), - "output", - ))); - } - assert!(command_output(&mut state, "id-0").found); - - assert!(state.record_capture(capture(Some("new-id"), Some("new-session"), "output",))); - - assert!(command_output(&mut state, "id-0").found); - assert!(!command_output(&mut state, "id-1").found); - assert!(command_output(&mut state, "new-id").found); - assert_eq!(state.sessions.len(), MAX_SESSIONS); - } - - #[test] - fn evicts_by_session_byte_limit() { - let mut session = SessionCaptures::default(); - let first_output = "x".repeat(10); - let second_output = "y"; - let (_, evicted_first) = session.push_with_limits( - HistoryId("first".to_string()), - SemanticCommandRecord { - capture: capture(Some("first"), Some("session-1"), &first_output), - history: None, - }, - MAX_COMMANDS_PER_SESSION, - 10, - ); - assert!(evicted_first.is_empty()); - - let (_, evicted_second) = session.push_with_limits( - HistoryId("second".to_string()), - SemanticCommandRecord { - capture: capture(Some("second"), Some("session-1"), second_output), - history: None, - }, - MAX_COMMANDS_PER_SESSION, - 10, - ); - - assert_eq!(evicted_second.len(), 1); - assert_eq!(evicted_second[0].history_id, HistoryId("first".to_string())); - assert_eq!(session.records.len(), 1); - assert_eq!(session.output_bytes, 1); - } - - #[test] - fn command_output_reports_truncation_metadata() { - let mut state = SemanticState::default(); - let mut capture = capture(Some("id-1"), Some("session-1"), "partial"); - capture.output_truncated = true; - capture.output_observed_bytes = 1024; - - assert!(state.record_capture(capture)); - - let reply = command_output(&mut state, "id-1"); - assert!(reply.output_truncated); - assert_eq!(reply.total_bytes, 7); - assert_eq!(reply.output_observed_bytes, 1024); - } - - #[test] - fn output_ranges_are_line_based_inclusive_and_support_negative_offsets() { - let output = "zero\none\ntwo\nthree\nfour"; - let ranges = vec![ - crate::semantic::OutputRange { start: 1, end: 2 }, - crate::semantic::OutputRange { start: -2, end: -1 }, - ]; - - assert_eq!( - select_output_ranges(output, &ranges), - vec![ - output_line(2, "one"), - output_line(3, "two"), - output_line(4, "three"), - output_line(5, "four"), - ] - ); - } - - #[test] - fn output_ranges_merge_overlaps_and_adjacent_ranges() { - let output = (0..100) - .map(|n| format!("line {n}")) - .collect::>() - .join("\n"); - let ranges = vec![ - crate::semantic::OutputRange { start: 0, end: 100 }, - crate::semantic::OutputRange { - start: -100, - end: -1, - }, - ]; - - let selected = select_output_ranges(&output, &ranges); - - assert_eq!(selected.len(), 100); - assert_eq!(selected.first(), Some(&output_line(1, "line 0"))); - assert_eq!(selected.last(), Some(&output_line(100, "line 99"))); - } - - #[test] - fn output_ranges_can_leave_gaps_for_client_formatting() { - let output = "zero\none\ntwo\nthree\nfour"; - let ranges = vec![ - crate::semantic::OutputRange { start: 0, end: 1 }, - crate::semantic::OutputRange { start: 4, end: 4 }, - ]; - - assert_eq!( - select_output_ranges(output, &ranges), - vec![ - output_line(1, "zero"), - output_line(2, "one"), - output_line(5, "four"), - ] - ); - } - - #[test] - fn empty_output_ranges_default_to_first_thousand_lines() { - let output = (0..1001) - .map(|n| format!("line {n}")) - .collect::>() - .join("\n"); - - let selected = select_output_ranges(&output, &[]); - - assert_eq!(selected.len(), 1000); - assert_eq!(selected.first(), Some(&output_line(1, "line 0"))); - assert_eq!(selected.last(), Some(&output_line(1000, "line 999"))); - } - - #[test] - fn output_ranges_skip_ranges_fully_outside_output() { - let output = "zero\none\ntwo"; - let ranges = vec![ - crate::semantic::OutputRange { start: 10, end: 20 }, - crate::semantic::OutputRange { - start: -20, - end: -10, - }, - ]; - - assert_eq!(select_output_ranges(output, &ranges), Vec::new()); - } -} diff --git a/crates/atuin-daemon/src/components/sync.rs b/crates/atuin-daemon/src/components/sync.rs deleted file mode 100644 index 6e486250..00000000 --- a/crates/atuin-daemon/src/components/sync.rs +++ /dev/null @@ -1,279 +0,0 @@ -//! Sync component. -//! -//! Handles periodic synchronization with the Atuin cloud server. - -use std::time::Duration; - -use eyre::Result; -use rand::Rng; -use tokio::sync::mpsc; -use tokio::time::{self, MissedTickBehavior}; - -use atuin_client::{history::store::HistoryStore, record::sync, settings::Settings}; - -use crate::{ - daemon::{Component, DaemonHandle}, - events::DaemonEvent, -}; - -/// Commands that can be sent to the sync task. -enum SyncCommand { - /// Trigger an immediate sync. - ForceSync, - /// Stop the sync loop. - Stop, -} - -/// Sync state - tracks whether we're in normal operation or retrying after failure. -#[derive(Clone, Copy, PartialEq, Eq)] -enum SyncState { - /// Normal operation. Periodic syncs only run if auto_sync is enabled. - Idle, - /// Retrying after a sync failure. Retries continue regardless of auto_sync - /// until the sync succeeds. - Retrying, -} - -/// Sync component - handles periodic cloud synchronization. -/// -/// This component: -/// - Runs a background sync loop on a configurable interval -/// - Implements exponential backoff on sync failures -/// - Responds to ForceSync events for immediate sync -/// - Emits SyncCompleted/SyncFailed events -pub struct SyncComponent { - task_handle: Option>, - command_tx: Option>, -} - -impl SyncComponent { - /// Create a new sync component. - pub fn new() -> Self { - Self { - task_handle: None, - command_tx: None, - } - } -} - -impl Default for SyncComponent { - fn default() -> Self { - Self::new() - } -} - -#[tonic::async_trait] -impl Component for SyncComponent { - fn name(&self) -> &'static str { - "sync" - } - - async fn start(&mut self, handle: DaemonHandle) -> Result<()> { - let (cmd_tx, cmd_rx) = mpsc::channel(16); - self.command_tx = Some(cmd_tx); - - // Spawn the sync loop with its own copy of the handle - self.task_handle = Some(tokio::spawn(sync_loop(handle, cmd_rx))); - - tracing::info!("sync component started"); - Ok(()) - } - - async fn handle_event(&mut self, event: &DaemonEvent) -> Result<()> { - if let DaemonEvent::ForceSync = event { - tracing::info!("force sync requested"); - if let Some(tx) = &self.command_tx { - let _ = tx.send(SyncCommand::ForceSync).await; - } - } - Ok(()) - } - - async fn stop(&mut self) -> Result<()> { - if let Some(tx) = &self.command_tx { - let _ = tx.send(SyncCommand::Stop).await; - } - if let Some(handle) = self.task_handle.take() { - // Give the task a moment to shut down gracefully - let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await; - } - tracing::info!("sync component stopped"); - Ok(()) - } -} - -/// The main sync loop. -/// -/// This runs in a spawned task and handles periodic sync as well as -/// force sync requests. -async fn sync_loop(handle: DaemonHandle, mut cmd_rx: mpsc::Receiver) { - tracing::info!("sync loop starting"); - - // Clone settings since we need them across await points - let settings = handle.settings().await.clone(); - let host_id = match Settings::host_id().await { - Ok(id) => id, - Err(e) => { - tracing::error!("failed to get host id, sync disabled: {e}"); - return; - } - }; - - // Create the stores we need - let encryption_key = *handle.encryption_key(); - let history_store = HistoryStore::new(handle.store().clone(), host_id, encryption_key); - - // Don't backoff by more than 30 mins (with a random jitter of up to 1 min) - let max_interval: f64 = 60.0 * 30.0 + rand::thread_rng().gen_range(0.0..60.0); - - let mut ticker = time::interval(time::Duration::from_secs(settings.daemon.sync_frequency)); - - // IMPORTANT: without this, if we miss ticks because a sync takes ages or is otherwise delayed, - // we may end up running a lot of syncs in a hot loop. - ticker.set_missed_tick_behavior(MissedTickBehavior::Skip); - - let mut sync_state = SyncState::Idle; - - loop { - tokio::select! { - _ = ticker.tick() => { - let settings = handle.settings().await; - - // Skip periodic ticks if auto_sync is disabled AND we're not retrying - // a previous failure. Retries must continue regardless of auto_sync. - if !settings.auto_sync && sync_state == SyncState::Idle { - tracing::debug!("auto_sync disabled, skipping periodic sync tick"); - continue; - } - - sync_state = do_sync_tick( - &handle, - &history_store, - &mut ticker, - max_interval, - &settings, - ).await; - } - cmd = cmd_rx.recv() => { - match cmd { - Some(SyncCommand::ForceSync) => { - tracing::info!("executing force sync"); - let settings = handle.settings().await; - sync_state = do_sync_tick( - &handle, - &history_store, - &mut ticker, - max_interval, - &settings, - ).await; - } - Some(SyncCommand::Stop) | None => { - tracing::info!("sync loop stopping"); - break; - } - } - } - } - } -} - -/// Execute a single sync tick. -/// -/// Returns the new sync state: `Idle` on success, `Retrying` on failure. -async fn do_sync_tick( - handle: &DaemonHandle, - history_store: &HistoryStore, - ticker: &mut time::Interval, - max_interval: f64, - settings: &Settings, -) -> SyncState { - tracing::info!("sync tick"); - - // Check if logged in - let logged_in = match settings.logged_in().await { - Ok(v) => v, - Err(e) => { - tracing::warn!("failed to check login status, skipping sync tick: {e}"); - return SyncState::Idle; - } - }; - - if !logged_in { - tracing::debug!("not logged in, skipping sync tick"); - return SyncState::Idle; - } - - // Perform the sync - let res = sync::sync(settings, handle.store(), handle.encryption_key()).await; - - match res { - Err(e) => { - tracing::error!("sync tick failed with {e}"); - - // Emit failure event - handle.emit(DaemonEvent::SyncFailed { - error: e.to_string(), - }); - - // Exponential backoff - let mut rng = rand::thread_rng(); - let mut new_interval = ticker.period().as_secs_f64() * rng.gen_range(2.0..2.2); - - if new_interval > max_interval { - new_interval = max_interval; - } - - *ticker = time::interval_at( - tokio::time::Instant::now() + Duration::from_secs(new_interval as u64), - time::Duration::from_secs(new_interval as u64), - ); - ticker.reset_after(time::Duration::from_secs(new_interval as u64)); - ticker.set_missed_tick_behavior(MissedTickBehavior::Skip); - - tracing::error!("backing off, next sync tick in {new_interval}"); - - SyncState::Retrying - } - Ok((uploaded_count, downloaded_records)) => { - tracing::info!( - uploaded = uploaded_count, - downloaded = downloaded_records.len(), - "sync complete" - ); - - // Build history from downloaded records - if let Err(e) = history_store - .incremental_build(handle.history_db(), &downloaded_records) - .await - { - tracing::error!("failed to build history from downloaded records: {e}"); - } - - // Emit the records added event (for search indexing) - handle.emit(DaemonEvent::RecordsAdded(downloaded_records.clone())); - - // Emit sync completed event - handle.emit(DaemonEvent::SyncCompleted { - uploaded: uploaded_count as usize, - downloaded: downloaded_records.len(), - }); - - // Reset backoff on success - if ticker.period().as_secs() != settings.daemon.sync_frequency { - *ticker = time::interval_at( - tokio::time::Instant::now() - + Duration::from_secs(settings.daemon.sync_frequency), - time::Duration::from_secs(settings.daemon.sync_frequency), - ); - ticker.set_missed_tick_behavior(MissedTickBehavior::Skip); - } - - // Store sync time - if let Err(e) = Settings::save_sync_time().await { - tracing::error!("failed to save sync time: {e}"); - } - - SyncState::Idle - } - } -} diff --git a/crates/atuin-daemon/src/control/mod.rs b/crates/atuin-daemon/src/control/mod.rs deleted file mode 100644 index afb29c57..00000000 --- a/crates/atuin-daemon/src/control/mod.rs +++ /dev/null @@ -1,12 +0,0 @@ -//! Control module for external event injection. -//! -//! This module provides the gRPC service that allows external processes -//! (like CLI commands) to inject events into the daemon's event bus. - -mod service; - -// Include the generated proto code -tonic::include_proto!("control"); - -// Re-export the service -pub use service::ControlService; diff --git a/crates/atuin-daemon/src/control/service.rs b/crates/atuin-daemon/src/control/service.rs deleted file mode 100644 index 2e7403ce..00000000 --- a/crates/atuin-daemon/src/control/service.rs +++ /dev/null @@ -1,71 +0,0 @@ -//! Control service implementation. -//! -//! This gRPC service allows external processes (like CLI commands) to inject -//! events into the daemon's event bus. - -use atuin_client::history::HistoryId; -use tonic::{Request, Response, Status}; -use tracing::{Level, info, instrument}; - -use super::{ - SendEventRequest, SendEventResponse, - control_server::{Control, ControlServer}, - send_event_request::Event, -}; -use crate::{daemon::DaemonHandle, events::DaemonEvent}; - -/// The Control gRPC service. -/// -/// This service is used by external processes to inject events into the daemon. -/// It's not a component - it's part of the daemon's core infrastructure. -pub struct ControlService { - handle: DaemonHandle, -} - -impl ControlService { - /// Create a new control service with the given daemon handle. - pub fn new(handle: DaemonHandle) -> Self { - Self { handle } - } - - /// Get a tonic server for this service. - pub fn into_server(self) -> ControlServer { - ControlServer::new(self) - } -} - -#[tonic::async_trait] -impl Control for ControlService { - #[instrument(skip_all, level = Level::INFO, name = "control_send_event")] - async fn send_event( - &self, - request: Request, - ) -> Result, Status> { - let req = request.into_inner(); - - let event = req - .event - .ok_or_else(|| Status::invalid_argument("event is required"))?; - - let daemon_event = proto_event_to_daemon_event(event)?; - - info!(?daemon_event, "received control event"); - self.handle.emit(daemon_event); - - Ok(Response::new(SendEventResponse {})) - } -} - -/// Convert a proto event to a daemon event. -fn proto_event_to_daemon_event(event: Event) -> Result { - match event { - Event::HistoryPruned(_) => Ok(DaemonEvent::HistoryPruned), - Event::HistoryRebuilt(_) => Ok(DaemonEvent::HistoryRebuilt), - Event::HistoryDeleted(e) => Ok(DaemonEvent::HistoryDeleted { - ids: e.ids.into_iter().map(HistoryId).collect(), - }), - Event::ForceSync(_) => Ok(DaemonEvent::ForceSync), - Event::SettingsReloaded(_) => Ok(DaemonEvent::SettingsReloaded), - Event::Shutdown(_) => Ok(DaemonEvent::ShutdownRequested), - } -} diff --git a/crates/atuin-daemon/src/daemon.rs b/crates/atuin-daemon/src/daemon.rs deleted file mode 100644 index 625ca205..00000000 --- a/crates/atuin-daemon/src/daemon.rs +++ /dev/null @@ -1,458 +0,0 @@ -//! Core daemon infrastructure. -//! -//! This module provides the foundational types for building the atuin daemon: -//! -//! - [`DaemonState`]: Shared state owned by the daemon -//! - [`DaemonHandle`]: A lightweight, cloneable handle for accessing daemon state -//! - [`Component`]: A trait for implementing daemon components -//! - [`Daemon`]: The main daemon orchestrator -//! - [`DaemonBuilder`]: Builder for constructing and configuring the daemon - -use std::sync::Arc; - -use atuin_client::{ - database::Sqlite as HistoryDatabase, encryption, record::sqlite_store::SqliteStore, - settings::Settings, -}; -use eyre::{Context, Result}; -use tokio::sync::{RwLock, broadcast}; - -use crate::events::DaemonEvent; - -// ============================================================================ -// DaemonState -// ============================================================================ - -/// Shared state owned by the daemon. -/// -/// This contains all the resources that components and services need access to. -/// The state is wrapped in an `Arc` and accessed via [`DaemonHandle`]. -pub struct DaemonState { - // Event bus - event_tx: broadcast::Sender, - - // Configuration (mutable - can be reloaded) - settings: RwLock, - - // Encryption key (immutable - derived at startup) - encryption_key: [u8; 32], - - // Database handles - history_db: HistoryDatabase, - store: SqliteStore, -} - -// ============================================================================ -// DaemonHandle -// ============================================================================ - -/// A lightweight handle to the daemon's shared state. -/// -/// This is the primary way for components, gRPC services, and spawned tasks to -/// interact with the daemon. It provides access to: -/// -/// - Event emission and subscription -/// - Configuration (settings, encryption key) -/// - Database handles -/// -/// The handle is cheaply cloneable (wraps an `Arc`) and can be freely passed -/// around to any code that needs daemon access. -/// -/// # Example -/// -/// ```ignore -/// // Emit an event -/// handle.emit(DaemonEvent::HistoryPruned); -/// -/// // Access settings -/// let settings = handle.settings().await; -/// let sync_freq = settings.daemon.sync_frequency; -/// -/// // Access database -/// let history = handle.history_db().load(id).await?; -/// ``` -#[derive(Clone)] -pub struct DaemonHandle { - state: Arc, -} - -impl DaemonHandle { - // ---- Events ---- - - /// Emit an event to the daemon's event bus. - /// - /// This is fire-and-forget - if no receivers are listening (which shouldn't - /// happen in normal operation), the event is dropped silently. - pub fn emit(&self, event: DaemonEvent) { - if let Err(e) = self.state.event_tx.send(event) { - tracing::warn!("failed to emit event (no receivers?): {e}"); - } - } - - /// Subscribe to the event bus. - /// - /// Returns a receiver that will receive all events emitted after this call. - /// Useful for components that need to listen for events outside of the - /// normal `handle_event` callback flow. - pub fn subscribe(&self) -> broadcast::Receiver { - self.state.event_tx.subscribe() - } - - /// Request graceful shutdown of the daemon. - pub fn shutdown(&self) { - self.emit(DaemonEvent::ShutdownRequested); - } - - // ---- Configuration ---- - - /// Get the current settings. - /// - /// This acquires a read lock on the settings. For most use cases, clone - /// the settings if you need to hold onto them. - pub async fn settings(&self) -> tokio::sync::RwLockReadGuard<'_, Settings> { - self.state.settings.read().await - } - - /// Reload settings from disk and emit a SettingsReloaded event. - /// - /// Components listening for `SettingsReloaded` can then re-read settings - /// via `handle.settings()` to pick up the changes. - pub async fn reload_settings(&self) -> Result<()> { - let new_settings = Settings::new()?; - self.apply_settings(new_settings).await; - Ok(()) - } - - /// Apply already-loaded settings and emit a SettingsReloaded event. - /// - /// Use this when settings have already been loaded (e.g., from a file watcher) - /// to avoid parsing the config file twice. - pub async fn apply_settings(&self, settings: Settings) { - *self.state.settings.write().await = settings; - self.emit(DaemonEvent::SettingsReloaded); - tracing::info!("settings applied"); - } - - /// Get the encryption key. - pub fn encryption_key(&self) -> &[u8; 32] { - &self.state.encryption_key - } - - // ---- Database ---- - - /// Get a reference to the history database. - pub fn history_db(&self) -> &HistoryDatabase { - &self.state.history_db - } - - /// Get a reference to the record store. - pub fn store(&self) -> &SqliteStore { - &self.state.store - } -} - -impl std::fmt::Debug for DaemonHandle { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("DaemonHandle").finish_non_exhaustive() - } -} - -// ============================================================================ -// Component Trait -// ============================================================================ - -/// A daemon component that handles a specific domain. -/// -/// Components are the building blocks of the daemon. Each component: -/// -/// - Has a unique name for logging and debugging -/// - Can optionally expose gRPC services -/// - Receives a [`DaemonHandle`] on startup for accessing daemon resources -/// - Handles events from the event bus -/// - Performs cleanup on shutdown -/// -/// # Lifecycle -/// -/// 1. **Construction**: Component is created (usually via `new()`) -/// 2. **Start**: `start()` is called with a [`DaemonHandle`] -/// 3. **Running**: `handle_event()` is called for each event on the bus -/// 4. **Shutdown**: `stop()` is called for cleanup -/// -/// # Example -/// -/// ```ignore -/// pub struct MyComponent { -/// handle: Option, -/// } -/// -/// #[async_trait] -/// impl Component for MyComponent { -/// fn name(&self) -> &'static str { "my-component" } -/// -/// async fn start(&mut self, handle: DaemonHandle) -> Result<()> { -/// self.handle = Some(handle); -/// Ok(()) -/// } -/// -/// async fn handle_event(&mut self, event: &DaemonEvent) -> Result<()> { -/// match event { -/// DaemonEvent::SomeEvent => { -/// // Handle the event -/// if let Some(handle) = &self.handle { -/// handle.emit(DaemonEvent::ResponseEvent); -/// } -/// } -/// _ => {} -/// } -/// Ok(()) -/// } -/// -/// async fn stop(&mut self) -> Result<()> { -/// Ok(()) -/// } -/// } -/// ``` -#[tonic::async_trait] -pub trait Component: Send + Sync { - /// Human-readable name for logging and debugging. - fn name(&self) -> &'static str; - - /// Called once at startup. - /// - /// Store the handle if you need to emit events or access daemon resources - /// later. The handle is cheaply cloneable, so feel free to clone it for - /// spawned tasks. - async fn start(&mut self, handle: DaemonHandle) -> Result<()>; - - /// Handle an incoming event. - /// - /// Called for every event on the bus. To emit new events in response, - /// use the handle stored during `start()`. Events emitted here will be - /// processed in subsequent event loop iterations. - async fn handle_event(&mut self, event: &DaemonEvent) -> Result<()>; - - /// Called on graceful shutdown. - /// - /// Use this to clean up resources, abort spawned tasks, etc. - async fn stop(&mut self) -> Result<()>; -} - -// ============================================================================ -// Daemon -// ============================================================================ - -/// The main daemon orchestrator. -/// -/// The daemon manages components, runs the event loop, and coordinates startup -/// and shutdown. It is constructed via [`DaemonBuilder`]. -/// -/// # Event Loop -/// -/// The daemon runs a simple event loop: -/// -/// 1. Wait for an event on the bus -/// 2. Dispatch the event to all components (in registration order) -/// 3. Components may emit new events in response -/// 4. Repeat until `ShutdownRequested` is received -/// -/// Events emitted during handling are queued and processed in subsequent -/// iterations, ensuring the loop eventually drains. -pub struct Daemon { - components: Vec>, - handle: DaemonHandle, -} - -impl Daemon { - /// Create a new daemon builder. - pub fn builder(settings: Settings) -> DaemonBuilder { - DaemonBuilder::new(settings) - } - - /// Get a clone of the daemon handle. - /// - /// The handle can be used to emit events, access settings, etc. - pub fn handle(&self) -> DaemonHandle { - self.handle.clone() - } - - /// Start all components. - /// - /// This must be called before `run_event_loop()`. It initializes all - /// registered components with the daemon handle. - pub async fn start_components(&mut self) -> Result<()> { - for component in &mut self.components { - tracing::info!(component = component.name(), "starting component"); - component - .start(self.handle.clone()) - .await - .with_context(|| format!("failed to start component: {}", component.name()))?; - } - Ok(()) - } - - /// Run the daemon event loop. - /// - /// This processes events until a ShutdownRequested event is received. - /// Components must be started first via `start_components()`. - pub async fn run_event_loop(&mut self) -> Result<()> { - let mut event_rx = self.handle.subscribe(); - loop { - match event_rx.recv().await { - Ok(DaemonEvent::ShutdownRequested) => { - tracing::info!("shutdown requested, stopping daemon"); - break; - } - Ok(event) => { - tracing::debug!(?event, "processing event"); - self.dispatch_event(&event).await; - } - Err(broadcast::error::RecvError::Lagged(n)) => { - tracing::warn!( - skipped = n, - "event receiver lagged, some events were dropped" - ); - } - Err(broadcast::error::RecvError::Closed) => { - tracing::info!("event bus closed, stopping daemon"); - break; - } - } - } - Ok(()) - } - - /// Stop all components. - /// - /// This performs graceful shutdown of all components. - pub async fn stop_components(&mut self) { - for component in &mut self.components { - tracing::info!(component = component.name(), "stopping component"); - if let Err(e) = component.stop().await { - tracing::error!( - component = component.name(), - error = ?e, - "error stopping component" - ); - } - } - tracing::info!("all components stopped"); - } - - /// Run the daemon. - /// - /// This is a convenience method that starts components, runs the event loop, - /// and handles shutdown. It does not return until the daemon is shut down. - pub async fn run(mut self) -> Result<()> { - self.start_components().await?; - self.run_event_loop().await?; - self.stop_components().await; - tracing::info!("daemon stopped"); - Ok(()) - } - - async fn dispatch_event(&mut self, event: &DaemonEvent) { - for component in &mut self.components { - if let Err(e) = component.handle_event(event).await { - tracing::error!( - component = component.name(), - error = ?e, - "error handling event" - ); - } - } - } -} - -// ============================================================================ -// DaemonBuilder -// ============================================================================ - -/// Builder for constructing a [`Daemon`]. -/// -/// # Example -/// -/// ```ignore -/// let daemon = Daemon::builder(settings) -/// .store(store) -/// .history_db(history_db) -/// .component(HistoryComponent::new()) -/// .component(SearchComponent::new()) -/// .component(SyncComponent::new()) -/// .build() -/// .await?; -/// -/// daemon.run().await?; -/// ``` -pub struct DaemonBuilder { - settings: Settings, - store: Option, - history_db: Option, - components: Vec>, -} - -impl DaemonBuilder { - /// Create a new daemon builder with the given settings. - pub fn new(settings: Settings) -> Self { - Self { - settings, - store: None, - history_db: None, - components: Vec::new(), - } - } - - /// Set the record store. - pub fn store(mut self, store: SqliteStore) -> Self { - self.store = Some(store); - self - } - - /// Set the history database. - pub fn history_db(mut self, db: HistoryDatabase) -> Self { - self.history_db = Some(db); - self - } - - /// Register a component. - /// - /// Components are started in registration order and stopped in reverse order. - pub fn component(mut self, component: impl Component + 'static) -> Self { - self.components.push(Box::new(component)); - self - } - - /// Build the daemon. - /// - /// This loads the encryption key and creates the daemon state. - pub async fn build(self) -> Result { - let store = self.store.ok_or_else(|| eyre::eyre!("store is required"))?; - let history_db = self - .history_db - .ok_or_else(|| eyre::eyre!("history_db is required"))?; - - // Load encryption key - let encryption_key: [u8; 32] = encryption::load_key(&self.settings) - .context("could not load encryption key")? - .into(); - - // Create the event bus - let (event_tx, _) = broadcast::channel(64); - - // Create the shared state - let state = Arc::new(DaemonState { - event_tx, - settings: RwLock::new(self.settings), - encryption_key, - history_db, - store, - }); - - // Create the handle (just a reference to the state) - let handle = DaemonHandle { state }; - - Ok(Daemon { - components: self.components, - handle, - }) - } -} diff --git a/crates/atuin-daemon/src/events.rs b/crates/atuin-daemon/src/events.rs deleted file mode 100644 index 4e6c6ff3..00000000 --- a/crates/atuin-daemon/src/events.rs +++ /dev/null @@ -1,74 +0,0 @@ -//! Daemon events. -//! -//! Events are the primary communication mechanism within the daemon. -//! Components emit events to notify others of state changes, and handle -//! events to react to changes elsewhere in the system. -//! -//! External processes (like CLI commands) can also inject events via the -//! Control gRPC service. - -use atuin_client::history::{History, HistoryId}; -use atuin_common::record::RecordId; - -/// Events that flow through the daemon's event bus. -/// -/// Events are broadcast to all components. Each component decides which -/// events it cares about in its `handle_event` implementation. -#[derive(Debug, Clone)] -pub enum DaemonEvent { - // ---- History lifecycle ---- - /// A command has started running. - HistoryStarted(History), - - /// A command has finished running. - HistoryEnded(History), - - // ---- Sync ---- - /// Records were synced from the server. - /// - /// The search component uses this to update its index with new history. - RecordsAdded(Vec), - - /// Sync completed successfully. - SyncCompleted { - /// Number of records uploaded. - uploaded: usize, - /// Number of records downloaded. - downloaded: usize, - }, - - /// Sync failed. - SyncFailed { - /// Error message describing what went wrong. - error: String, - }, - - /// Request an immediate sync (external trigger). - ForceSync, - - // ---- External commands ---- - /// History was pruned - search index needs a full rebuild. - /// - /// Emitted when the user runs `atuin history prune` or similar. - HistoryPruned, - - /// History was rebuilt - search index needs a full rebuild. - /// - /// Emitted when the user runs `atuin store rebuild history` or similar. - HistoryRebuilt, - - /// Specific history items were deleted. - /// - /// The search component should remove these from its index. - HistoryDeleted { - /// IDs of the deleted history entries. - ids: Vec, - }, - - /// Settings have changed, components should reload if needed. - SettingsReloaded, - - // ---- Lifecycle ---- - /// Request graceful shutdown of the daemon. - ShutdownRequested, -} diff --git a/crates/atuin-daemon/src/history/mod.rs b/crates/atuin-daemon/src/history/mod.rs deleted file mode 100644 index b71853df..00000000 --- a/crates/atuin-daemon/src/history/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -//! History module for the daemon gRPC history service. -//! -//! This module contains the proto-generated types for the history gRPC service. - -// Include the generated proto code -tonic::include_proto!("history"); diff --git a/crates/atuin-daemon/src/lib.rs b/crates/atuin-daemon/src/lib.rs deleted file mode 100644 index 27d3932b..00000000 --- a/crates/atuin-daemon/src/lib.rs +++ /dev/null @@ -1,136 +0,0 @@ -use atuin_client::database::Sqlite as HistoryDatabase; -use atuin_client::record::sqlite_store::SqliteStore; -use atuin_client::settings::{Settings, watcher::global_settings_watcher}; -use eyre::Result; - -pub mod client; -pub mod components; -pub mod control; -pub mod daemon; -pub mod events; -pub mod history; -pub mod search; -pub mod semantic; -pub mod server; - -// Re-export core daemon types for convenience -pub use daemon::{Component, Daemon, DaemonBuilder, DaemonHandle}; -pub use events::DaemonEvent; - -// Re-export components -pub use components::{HistoryComponent, SearchComponent, SemanticComponent, SyncComponent}; - -// Re-export client helpers -pub use client::{ControlClient, SemanticClient, emit_event, emit_event_with_settings}; - -/// Boot the daemon using the new component-based architecture. -/// -/// This creates a daemon with the standard components (history, search, sync), -/// starts the gRPC server with their services, and runs the event loop. -pub async fn boot( - settings: Settings, - store: SqliteStore, - history_db: HistoryDatabase, -) -> Result<()> { - // Create the components - let history_component = HistoryComponent::new(); - let search_component = SearchComponent::new(); - let semantic_component = SemanticComponent::new(); - let sync_component = SyncComponent::new(); - - // Get the gRPC services before moving components into the daemon - // (The services share state with the components via Arc) - let history_service = history_component.grpc_service(); - let search_service = search_component.grpc_service(); - let semantic_service = semantic_component.grpc_service(); - - // Build the daemon - let mut daemon = Daemon::builder(settings.clone()) - .store(store) - .history_db(history_db) - .component(history_component) - .component(search_component) - .component(semantic_component) - .component(sync_component) - .build() - .await?; - - // Get a handle for the control service and gRPC server shutdown - let handle = daemon.handle(); - - // Create the control service - let control_service = control::ControlService::new(handle.clone()); - - // Start all components first (so gRPC services can work) - daemon.start_components().await?; - - // Spawn config file watcher to reload settings on changes - if let Ok(watcher) = global_settings_watcher() { - let mut settings_rx = watcher.subscribe(); - let watcher_handle = handle.clone(); - tokio::spawn(async move { - tracing::info!("config file watcher started"); - while settings_rx.changed().await.is_ok() { - // Use the already-loaded settings from the watcher - // (avoids parsing the config file twice) - let new_settings = (*settings_rx.borrow()).clone(); - watcher_handle.apply_settings((*new_settings).clone()).await; - } - tracing::debug!("config file watcher stopped"); - }); - } else { - tracing::warn!( - "failed to start config file watcher; settings changes will require daemon restart" - ); - } - - // Spawn signal handler to emit ShutdownRequested on Ctrl+C/SIGTERM - let signal_handle = handle.clone(); - tokio::spawn(async move { - shutdown_signal().await; - tracing::info!("received shutdown signal"); - signal_handle.shutdown(); - }); - - // Start the gRPC server in the background - server::run_grpc_server( - settings, - history_service, - search_service, - semantic_service, - control_service.into_server(), - handle, - ) - .await?; - - // Run the daemon event loop - daemon.run_event_loop().await?; - - // Stop all components on shutdown - daemon.stop_components().await; - - tracing::info!("daemon shut down complete"); - Ok(()) -} - -/// Wait for a shutdown signal (Ctrl+C or SIGTERM). -#[cfg(unix)] -async fn shutdown_signal() { - let mut term = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) - .expect("failed to register sigterm handler"); - let mut int = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt()) - .expect("failed to register sigint handler"); - - tokio::select! { - _ = term.recv() => {}, - _ = int.recv() => {}, - } -} - -/// Wait for a shutdown signal (Ctrl+C). -#[cfg(not(unix))] -async fn shutdown_signal() { - tokio::signal::ctrl_c() - .await - .expect("failed to listen for ctrl+c"); -} diff --git a/crates/atuin-daemon/src/search/index.rs b/crates/atuin-daemon/src/search/index.rs deleted file mode 100644 index bb155979..00000000 --- a/crates/atuin-daemon/src/search/index.rs +++ /dev/null @@ -1,683 +0,0 @@ -//! Search index with frecency-based ranking. -//! -//! This module provides a deduplicated search index where each unique command -//! is stored once, with metadata about all its invocations. This enables: -//! -//! - Efficient fuzzy matching (fewer items to match) -//! - Frecency-based ranking (frequency + recency) -//! - Dynamic filtering by directory, host, session, etc. - -use std::{ - collections::{HashMap, HashSet}, - sync::Arc, -}; - -use atuin_client::history::{History, is_known_agent}; -use atuin_client::settings::Search; -use atuin_nucleo::{Injector, Nucleo, pattern}; -use dashmap::DashMap; -use lasso::{Spur, ThreadedRodeo}; -use time::OffsetDateTime; -use tokio::sync::RwLock; -use tracing::{Level, instrument}; -use uuid::Uuid; - -use crate::components::search::with_trailing_slash; - -/// Parse a UUID string into a 16-byte array. -/// Returns None if the string is not a valid UUID. -fn parse_uuid_bytes(s: &str) -> Option<[u8; 16]> { - Uuid::parse_str(s).ok().map(|u| *u.as_bytes()) -} - -/// Format a 16-byte array as a UUID string. -fn format_uuid_bytes(bytes: &[u8; 16]) -> String { - Uuid::from_bytes(*bytes).to_string() -} - -/// Pre-computed frecency data for O(1) lookup. -#[derive(Debug, Clone, Default)] -pub struct FrecencyData { - /// Total number of times this command was used. - pub count: u32, - /// Most recent usage timestamp (unix seconds). - pub last_used: i64, -} - -impl FrecencyData { - /// Record a new usage of this command. - pub fn record_use(&mut self, timestamp: i64) { - self.count += 1; - if timestamp > self.last_used { - self.last_used = timestamp; - } - } - - /// Compute frecency score based on count and recency. - /// - /// Uses a decay function where more recent commands score higher. - /// The formula balances frequency (how often) with recency (how recent). - /// - /// Multipliers allow tuning the relative weights: - /// - `recency_mul`: Multiplier for recency score (default: 1.0) - /// - `frequency_mul`: Multiplier for frequency score (default: 1.0) - /// - /// A multiplier of 0.0 disables that component, 1.0 is unchanged, 2.0 doubles weight. - /// Values like 0.5 reduce weight by half, 1.5 increases by 50%, etc. - #[instrument(level = tracing::Level::TRACE, name = "index_frecency_compute")] - pub fn compute(&self, now: i64, recency_mul: f64, frequency_mul: f64) -> u32 { - if self.count == 0 { - return 0; - } - - // Time-based decay: score decreases as time passes - let age_seconds = (now - self.last_used).max(0) as u64; - let age_hours = age_seconds / 3600; - - // Decay factor: recent commands get higher scores - // - Last hour: multiplier ~1.0 - // - Last day: multiplier ~0.5 - // - Last week: multiplier ~0.1 - // - Older: multiplier approaches 0 - let recency_score: f64 = match age_hours { - 0 => 100.0, - 1..=6 => 90.0, - 7..=24 => 70.0, - 25..=72 => 50.0, - 73..=168 => 30.0, - 169..=720 => 15.0, - _ => 5.0, - }; - - // Frequency boost: more uses = higher score (with diminishing returns) - let frequency_score = ((self.count as f64).ln() * 20.0).min(100.0); - - // Apply multipliers and combine scores, then round to u32 - ((recency_score * recency_mul) + (frequency_score * frequency_mul)).round() as u32 - } -} - -/// Data for a unique command. -pub struct CommandData { - /// History ID of the most recent invocation (16-byte UUID). - most_recent_id: [u8; 16], - /// Timestamp of the most recent invocation. - most_recent_timestamp: i64, - /// Pre-computed global frecency. - pub global_frecency: FrecencyData, - - // Pre-computed indexes for O(1) filter lookups - // Using HashSet instead of DashSet since CommandData lives inside DashMap (already synchronized) - /// All directories where this command has been run (interned keys). - directories: HashSet, - /// All hostnames where this command has been run (interned keys). - hosts: HashSet, - /// All sessions where this command has been run (as 16-byte UUIDs). - sessions: HashSet<[u8; 16]>, -} - -impl CommandData { - /// Create a new CommandData from a history entry. - /// Returns None if the history entry has invalid UUIDs. - pub fn new(history: &History, interner: &ThreadedRodeo) -> Option { - let history_id = parse_uuid_bytes(&history.id.0)?; - let session = parse_uuid_bytes(&history.session)?; - let timestamp = history.timestamp.unix_timestamp(); - - let dir_key = interner.get_or_intern(with_trailing_slash(&history.cwd)); - let host_key = interner.get_or_intern(&history.hostname); - - let mut directories = HashSet::new(); - directories.insert(dir_key); - - let mut hosts = HashSet::new(); - hosts.insert(host_key); - - let mut sessions = HashSet::new(); - sessions.insert(session); - - let mut global_frecency = FrecencyData::default(); - global_frecency.record_use(timestamp); - - Some(Self { - most_recent_id: history_id, - most_recent_timestamp: timestamp, - global_frecency, - directories, - hosts, - sessions, - }) - } - - /// Add an invocation from a history entry. - /// Returns false if the history entry has invalid UUIDs. - pub fn add_invocation(&mut self, history: &History, interner: &ThreadedRodeo) -> bool { - let Some(history_id) = parse_uuid_bytes(&history.id.0) else { - return false; - }; - let Some(session) = parse_uuid_bytes(&history.session) else { - return false; - }; - - let timestamp = history.timestamp.unix_timestamp(); - - // Update global frecency - self.global_frecency.record_use(timestamp); - - // Update pre-computed indexes for O(1) filter lookups - let dir_key = interner.get_or_intern(with_trailing_slash(&history.cwd)); - self.directories.insert(dir_key); - self.hosts.insert(interner.get_or_intern(&history.hostname)); - self.sessions.insert(session); - - // Update most recent if this invocation is newer - if timestamp > self.most_recent_timestamp { - self.most_recent_id = history_id; - self.most_recent_timestamp = timestamp; - } - - true - } - - /// Get the most recent history ID for this command. - pub fn most_recent_id(&self) -> String { - format_uuid_bytes(&self.most_recent_id) - } - - /// Check if any invocation matches a directory filter (exact match). - /// O(1) lookup using pre-computed index. - pub fn has_invocation_in_dir(&self, dir: &str, interner: &ThreadedRodeo) -> bool { - interner - .get(dir) - .is_some_and(|spur| self.directories.contains(&spur)) - } - - /// Check if any invocation matches a directory prefix (workspace/git root). - /// O(n) where n = number of unique directories for this command. - pub fn has_invocation_in_workspace(&self, prefix: &str, interner: &ThreadedRodeo) -> bool { - self.directories - .iter() - .any(|&spur| interner.resolve(&spur).starts_with(prefix)) - } - - /// Check if any invocation matches a hostname. - /// O(1) lookup using pre-computed index. - pub fn has_invocation_on_host(&self, hostname: &str, interner: &ThreadedRodeo) -> bool { - interner - .get(hostname) - .is_some_and(|spur| self.hosts.contains(&spur)) - } - - /// Check if any invocation matches a session. - /// O(1) lookup using pre-computed index. - pub fn has_invocation_in_session(&self, session: &str) -> bool { - parse_uuid_bytes(session).is_some_and(|bytes| self.sessions.contains(&bytes)) - } -} - -/// Filter mode for search queries. -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum IndexFilterMode { - /// No filtering - search all commands. - Global, - /// Filter to commands run in a specific directory. - Directory(String), - /// Filter to commands run in a workspace (directory prefix). - Workspace(String), - /// Filter to commands run on a specific host. - Host(String), - /// Filter to commands run in a specific session. - Session(String), -} - -/// Context for search queries. -#[derive(Debug, Clone, Default)] -pub struct QueryContext { - pub cwd: Option, - pub git_root: Option, - pub hostname: Option, - pub session_id: Option, -} - -/// Shareable frecency map: command -> frecency score. -/// Wrapped in Arc for zero-copy sharing with scorer callbacks. -type FrecencyMap = Arc, u32>>; - -/// A deduplicated search index with frecency-based ranking. -/// -/// Commands are stored by their text, with metadata about all invocations. -/// Nucleo handles fuzzy matching, while frecency is computed via scorer callback. -/// -/// Global frecency is precomputed by a background task and used for scoring. -/// If frecency data is not available, search still works but without frecency ranking; -/// although this should never happen due to precomputing the frecency map. -pub struct SearchIndex { - /// Map from command text to command data. - /// Using DashMap for concurrent read/write access, wrapped in Arc for sharing with scorer. - /// Keys are Arc to enable zero-copy sharing with frecency_map. - commands: Arc, CommandData>>, - /// Nucleo fuzzy matcher - items are command strings. - nucleo: RwLock>, - /// Injector for adding new commands to Nucleo. - injector: Injector, - /// Precomputed global frecency map. Updated by background task. - frecency_map: RwLock>, - /// String interner for deduplicating cwd, hostname, and directory paths. - interner: Arc, -} - -impl SearchIndex { - /// Create a new empty search index. - pub fn new() -> Self { - let nucleo_config = atuin_nucleo::Config::DEFAULT; - // Single column for command text - let nucleo = Nucleo::::new(nucleo_config, Arc::new(|| {}), None, 1); - let injector = nucleo.injector(); - - Self { - commands: Arc::new(DashMap::new()), - nucleo: RwLock::new(nucleo), - injector, - frecency_map: RwLock::new(None), - interner: Arc::new(ThreadedRodeo::new()), - } - } - - /// Add a history entry to the index. - /// - /// If the command already exists, updates its invocation data. - /// If it's a new command, adds it to both the map and Nucleo. - pub fn add_history(&self, history: &History) { - if is_known_agent(&history.author) { - return; - } - - let command = history.command.as_str(); - - // DashMap with Arc keys can be looked up with &str via Borrow trait - if let Some(mut entry) = self.commands.get_mut(command) { - // Existing command - just update invocations - entry.add_invocation(history, &self.interner); - } else { - // New command - create Arc once and share it - let Some(data) = CommandData::new(history, &self.interner) else { - return; // Invalid UUIDs, skip this entry - }; - let command_arc: Arc = command.into(); - self.commands.insert(Arc::clone(&command_arc), data); - // Nucleo still needs String (unavoidable copy for fuzzy matching) - self.injector.push(command_arc.to_string(), |cmd, cols| { - cols[0] = cmd.clone().into(); - }); - } - // Note: frecency_map is rebuilt by background task, not invalidated here - } - - /// Add multiple history entries to the index. - pub fn add_histories(&self, histories: &[History]) { - for history in histories { - self.add_history(history); - } - } - - /// Get the number of unique commands in the index. - pub fn command_count(&self) -> usize { - self.commands.len() - } - - /// Get the number of items in Nucleo (should match command_count). - pub async fn nucleo_item_count(&self) -> u32 { - self.nucleo.read().await.snapshot().item_count() - } - - /// Search for commands matching a query. - /// - /// Returns a list of history IDs (most recent invocation per command). - /// Uses precomputed global frecency for scoring if available. - #[instrument(skip_all, level = tracing::Level::TRACE, name = "index_search", fields(query = %query))] - pub async fn search( - &self, - query: &str, - filter_mode: IndexFilterMode, - _context: &QueryContext, - limit: u32, - ) -> Vec { - let mut nucleo = self.nucleo.write().await; - - // Get precomputed frecency map (may be None if not yet computed) - let frecency_map = self.frecency_map.read().await.clone(); - - // Build filter based on mode - let filter = self.build_filter(&filter_mode); - nucleo.set_filter(filter); - - // Build scorer from precomputed frecency (or None if not available) - let scorer = Self::build_scorer(frecency_map); - nucleo.set_scorer(scorer); - - // Update pattern - nucleo.pattern.reparse( - 0, - query, - pattern::CaseMatching::Smart, - pattern::Normalization::Smart, - false, - ); - - tracing::span!(Level::TRACE, "index_search_tick").in_scope(|| { - // Tick until complete - while nucleo.tick(10).running {} - }); - - // Collect results - let snapshot = nucleo.snapshot(); - let matched_count = snapshot.matched_item_count().min(limit); - - tracing::span!(Level::TRACE, "index_search_results").in_scope(|| { - snapshot - .matched_items(..matched_count) - .filter_map(|item| { - let cmd = item.data; - // DashMap, _>::get accepts &str via Borrow trait - self.commands - .get(cmd.as_str()) - .map(|data| data.most_recent_id()) - }) - .collect() - }) - } - - /// Rebuild the global frecency map. - /// - /// This should be called by a background task periodically. - /// The map is used for scoring search results. - /// - /// Uses multipliers from search settings: - /// - `recency_score_multiplier`: Weight for recency component - /// - `frequency_score_multiplier`: Weight for frequency component - /// - `frecency_score_multiplier`: Overall multiplier for final score - #[instrument(skip_all, level = tracing::Level::DEBUG, name = "rebuild_frecency")] - pub async fn rebuild_frecency(&self, search_settings: &Search) { - let now = OffsetDateTime::now_utc().unix_timestamp(); - let mut frecency_map: HashMap, u32> = HashMap::new(); - - // Clamp multipliers to non-negative values to prevent broken frecency ranking - // (negative values would produce unexpected results when cast to u32) - let recency_mul = search_settings.recency_score_multiplier.max(0.0); - let frequency_mul = search_settings.frequency_score_multiplier.max(0.0); - let frecency_mul = search_settings.frecency_score_multiplier.max(0.0); - - for entry in self.commands.iter() { - let frecency = entry - .global_frecency - .compute(now, recency_mul, frequency_mul); - // Apply overall frecency multiplier and round to u32 - let frecency = (frecency as f64 * frecency_mul).round() as u32; - // Arc::clone is cheap - just increments reference count - frecency_map.insert(Arc::clone(entry.key()), frecency); - } - - *self.frecency_map.write().await = Some(Arc::new(frecency_map)); - } - - /// Build filter predicate for the given mode. - fn build_filter(&self, mode: &IndexFilterMode) -> Option> { - // For Global mode, no filter needed - if matches!(mode, IndexFilterMode::Global) { - return None; - } - - // Pre-compute which commands pass the filter - // Use HashSet for the short-lived filter (simpler than Arc lookup) - let passing_commands: Arc> = { - let mut set = HashSet::new(); - for entry in self.commands.iter() { - let passes = match mode { - IndexFilterMode::Global => unreachable!(), - IndexFilterMode::Directory(dir) => { - entry.has_invocation_in_dir(dir, &self.interner) - } - IndexFilterMode::Workspace(prefix) => { - entry.has_invocation_in_workspace(prefix, &self.interner) - } - IndexFilterMode::Host(hostname) => { - entry.has_invocation_on_host(hostname, &self.interner) - } - IndexFilterMode::Session(session) => entry.has_invocation_in_session(session), - }; - if passes { - // Convert Arc to String for filter lookup - set.insert(entry.key().to_string()); - } - } - Arc::new(set) - }; - - Some(Arc::new(move |cmd: &String| passing_commands.contains(cmd))) - } - - /// Build scorer from precomputed frecency map. - /// - /// Returns None if frecency map is not available (search still works, just without frecency ranking). - fn build_scorer(frecency_map: Option) -> Option> { - let map = frecency_map?; - Some(Arc::new(move |cmd: &String, fuzzy_score: u32| { - // HashMap, _>::get accepts &str via Borrow trait - let frecency = map.get(cmd.as_str()).copied().unwrap_or(0); - fuzzy_score + frecency - })) - } -} - -impl Default for SearchIndex { - fn default() -> Self { - Self::new() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use time::macros::datetime; - - fn make_history(command: &str, cwd: &str, timestamp: OffsetDateTime) -> History { - History::import() - .timestamp(timestamp) - .command(command) - .cwd(cwd) - .build() - .into() - } - - #[test] - fn frecency_data_compute() { - let now = 1000000i64; - - // Recent command (with default multipliers of 1.0) - let recent = FrecencyData { - count: 5, - last_used: now - 60, // 1 minute ago - }; - assert!(recent.compute(now, 1.0, 1.0) > 100); // High score - - // Old command - let old = FrecencyData { - count: 5, - last_used: now - 86400 * 30, // 30 days ago - }; - assert!(old.compute(now, 1.0, 1.0) < recent.compute(now, 1.0, 1.0)); - - // Frequently used old command - let frequent_old = FrecencyData { - count: 100, - last_used: now - 86400 * 7, // 1 week ago - }; - // Should still have decent score due to frequency - assert!(frequent_old.compute(now, 1.0, 1.0) > 50); - } - - #[test] - fn frecency_data_compute_with_multipliers() { - let now = 1000000i64; - - let data = FrecencyData { - count: 5, - last_used: now - 60, // 1 minute ago (recency_score = 100) - }; - - // Default multipliers (1.0, 1.0) - let default_score = data.compute(now, 1.0, 1.0); - - // Double recency weight - let double_recency = data.compute(now, 2.0, 1.0); - assert!(double_recency > default_score); - - // Double frequency weight - let double_frequency = data.compute(now, 1.0, 2.0); - assert!(double_frequency > default_score); - - // Zero out recency (only frequency counts) - let no_recency = data.compute(now, 0.0, 1.0); - assert!(no_recency < default_score); - - // Zero out frequency (only recency counts) - let no_frequency = data.compute(now, 1.0, 0.0); - assert!(no_frequency < default_score); - - // Zero both (should be zero) - let no_score = data.compute(now, 0.0, 0.0); - assert_eq!(no_score, 0); - - // Fractional multipliers - let half_recency = data.compute(now, 0.5, 1.0); - assert!(half_recency < default_score); - assert!(half_recency > no_recency); - - // 1.5x multiplier - let boost_recency = data.compute(now, 1.5, 1.0); - assert!(boost_recency > default_score); - assert!(boost_recency < double_recency); - } - - #[test] - fn command_data_add_invocation() { - let interner = ThreadedRodeo::new(); - - let (dir1, dir2) = if cfg!(windows) { - ("C:\\Users\\User\\project", "C:\\Users\\User\\other") - } else { - ("/home/user/project", "/home/user/other") - }; - - let history1 = make_history("git status", dir1, datetime!(2024-01-01 10:00 UTC)); - let history2 = make_history("git status", dir2, datetime!(2024-01-01 12:00 UTC)); - - let mut data = CommandData::new(&history1, &interner).unwrap(); - assert_eq!(data.global_frecency.count, 1); - let id1 = data.most_recent_id(); - - data.add_invocation(&history2, &interner); - assert_eq!(data.global_frecency.count, 2); - - // Most recent ID should update to history2 (newer timestamp) - let id2 = data.most_recent_id(); - assert_ne!(id1, id2); - } - - #[test] - fn command_data_filters() { - let interner = ThreadedRodeo::new(); - - let (dir1, dir2) = if cfg!(windows) { - ("C:\\Users\\User\\project", "C:\\Users\\User\\other") - } else { - ("/home/user/project", "/home/user/other") - }; - - let h1 = make_history("git status", dir1, datetime!(2024-01-01 10:00 UTC)); - let h2 = make_history("git status", dir2, datetime!(2024-01-01 12:00 UTC)); - - let mut data = CommandData::new(&h1, &interner).unwrap(); - data.add_invocation(&h2, &interner); - - let (check1, check2, check3) = if cfg!(windows) { - ( - with_trailing_slash("C:\\Users\\User\\project"), - with_trailing_slash("C:\\Users\\User\\other"), - with_trailing_slash("C:\\Users\\User\\missing"), - ) - } else { - ( - with_trailing_slash("/home/user/project"), - with_trailing_slash("/home/user/other"), - with_trailing_slash("/home/user/missing"), - ) - }; - - assert!(data.has_invocation_in_dir(&check1, &interner)); - assert!(data.has_invocation_in_dir(&check2, &interner)); - assert!(!data.has_invocation_in_dir(&check3, &interner)); - - let (check1, check2, check3) = if cfg!(windows) { - ( - with_trailing_slash("C:\\Users\\User"), - with_trailing_slash("C:\\Users"), - with_trailing_slash("C:\\Users\\User\\var"), - ) - } else { - ( - with_trailing_slash("/home/user"), - with_trailing_slash("/home"), - with_trailing_slash("/var"), - ) - }; - - assert!(data.has_invocation_in_workspace(&check1, &interner)); - assert!(data.has_invocation_in_workspace(&check2, &interner)); - assert!(!data.has_invocation_in_workspace(&check3, &interner)); - } - - #[tokio::test] - async fn search_index_add_and_search() { - let index = SearchIndex::new(); - - let h1 = make_history( - "git status", - "/home/user/project", - datetime!(2024-01-01 10:00 UTC), - ); - let h2 = make_history( - "git commit -m 'test'", - "/home/user/project", - datetime!(2024-01-01 10:05 UTC), - ); - let h3 = make_history( - "ls -la", - "/home/user/other", - datetime!(2024-01-01 10:10 UTC), - ); - - index.add_history(&h1); - index.add_history(&h2); - index.add_history(&h3); - - assert_eq!(index.command_count(), 3); - - // Search for "git" - should match 2 commands - let results = index - .search("git", IndexFilterMode::Global, &QueryContext::default(), 10) - .await; - assert_eq!(results.len(), 2); - - // Search with directory filter - let results = index - .search( - "", - IndexFilterMode::Directory(with_trailing_slash("/home/user/project")), - &QueryContext::default(), - 10, - ) - .await; - assert_eq!(results.len(), 2); // git status and git commit - } -} diff --git a/crates/atuin-daemon/src/search/mod.rs b/crates/atuin-daemon/src/search/mod.rs deleted file mode 100644 index 4d261956..00000000 --- a/crates/atuin-daemon/src/search/mod.rs +++ /dev/null @@ -1,11 +0,0 @@ -//! Search module for the daemon gRPC search service. -//! -//! This module provides fuzzy search over command history using Nucleo. - -mod index; - -// Include the generated proto code -tonic::include_proto!("search"); - -// Re-export the service and index -pub use index::{IndexFilterMode, QueryContext, SearchIndex}; diff --git a/crates/atuin-daemon/src/semantic/mod.rs b/crates/atuin-daemon/src/semantic/mod.rs deleted file mode 100644 index c3511676..00000000 --- a/crates/atuin-daemon/src/semantic/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -//! Semantic command capture gRPC service types. - -tonic::include_proto!("semantic"); diff --git a/crates/atuin-daemon/src/server.rs b/crates/atuin-daemon/src/server.rs deleted file mode 100644 index b823cff2..00000000 --- a/crates/atuin-daemon/src/server.rs +++ /dev/null @@ -1,170 +0,0 @@ -use eyre::Result; - -use crate::components::history::HistoryGrpcService; -use crate::components::search::SearchGrpcService; -use crate::components::semantic::SemanticGrpcService; -use crate::control::{ControlService, control_server::ControlServer}; -use crate::daemon::DaemonHandle; -use crate::history::history_server::HistoryServer; -use crate::search::search_server::SearchServer; -use crate::semantic::semantic_server::SemanticServer; - -use atuin_client::settings::Settings; - -/// Run the gRPC server with the given services. -/// -/// This starts the gRPC server in the background and returns immediately. -/// The server will shut down when a ShutdownRequested event is received. -#[cfg(unix)] -pub async fn run_grpc_server( - settings: Settings, - history_service: HistoryServer, - search_service: SearchServer, - semantic_service: SemanticServer, - control_service: ControlServer, - handle: DaemonHandle, -) -> Result<()> { - use tokio::net::UnixListener; - use tokio_stream::wrappers::UnixListenerStream; - - let socket_path = settings.daemon.socket_path.clone(); - - let (uds, cleanup) = if cfg!(target_os = "linux") && settings.daemon.systemd_socket { - #[cfg(target_os = "linux")] - { - use eyre::{OptionExt, WrapErr}; - use std::os::unix::net::SocketAddr; - use std::path::PathBuf; - tracing::info!("getting systemd socket"); - let listener = listenfd::ListenFd::from_env() - .take_unix_listener(0)? - .ok_or_eyre("missing systemd socket")?; - listener.set_nonblocking(true)?; - let actual_path: Result = listener - .local_addr() - .context("getting systemd socket's path") - .and_then(|addr: SocketAddr| { - addr.as_pathname() - .ok_or_eyre("systemd socket missing path") - .map(|path: &std::path::Path| path.to_owned()) - }); - match actual_path { - Ok(actual_path) => { - tracing::info!("listening on systemd socket: {actual_path:?}"); - if actual_path != std::path::Path::new(&socket_path) { - tracing::warn!( - "systemd socket is not at configured client path: {socket_path:?}" - ); - } - } - Err(err) => { - tracing::warn!( - "could not detect systemd socket path, ensure that it's at the configured path: {socket_path:?}, error: {err:?}" - ); - } - } - (UnixListener::from_std(listener)?, false) - } - #[cfg(not(target_os = "linux"))] - unreachable!() - } else { - tracing::info!("listening on unix socket {socket_path:?}"); - (UnixListener::bind(socket_path.clone())?, true) - }; - - let uds_stream = UnixListenerStream::new(uds); - - // Create shutdown signal from daemon handle - let shutdown_signal = async move { - let mut rx = handle.subscribe(); - loop { - use crate::DaemonEvent; - - match rx.recv().await { - Ok(DaemonEvent::ShutdownRequested) => break, - Ok(_) => continue, - Err(_) => break, // Channel closed - } - } - if cleanup { - eprintln!("Removing socket..."); - if let Err(e) = std::fs::remove_file(&socket_path) - && e.kind() != std::io::ErrorKind::NotFound - { - eprintln!("failed to remove socket: {e}"); - } - } - eprintln!("Shutting down gRPC server..."); - }; - - // Spawn the server in the background - tokio::spawn(async move { - use tonic::transport::Server; - - if let Err(e) = Server::builder() - .add_service(history_service) - .add_service(search_service) - .add_service(semantic_service) - .add_service(control_service) - .serve_with_incoming_shutdown(uds_stream, shutdown_signal) - .await - { - tracing::error!("gRPC server error: {e}"); - } - }); - - Ok(()) -} - -/// Run the gRPC server with the given services (Windows/TCP version). -#[cfg(not(unix))] -pub async fn run_grpc_server( - settings: Settings, - history_service: HistoryServer, - search_service: SearchServer, - semantic_service: SemanticServer, - control_service: ControlServer, - handle: DaemonHandle, -) -> Result<()> { - use tokio::net::TcpListener; - use tokio_stream::wrappers::TcpListenerStream; - use tonic::transport::Server; - - let port = settings.daemon.tcp_port; - let url = format!("127.0.0.1:{port}"); - let tcp = TcpListener::bind(&url).await?; - let tcp_stream = TcpListenerStream::new(tcp); - - tracing::info!("listening on tcp port {:?}", port); - - // Create shutdown signal from daemon handle - let shutdown_signal = async move { - use crate::DaemonEvent; - - let mut rx = handle.subscribe(); - loop { - match rx.recv().await { - Ok(DaemonEvent::ShutdownRequested) => break, - Ok(_) => continue, - Err(_) => break, // Channel closed - } - } - eprintln!("Shutting down gRPC server..."); - }; - - // Spawn the server in the background - tokio::spawn(async move { - if let Err(e) = Server::builder() - .add_service(history_service) - .add_service(search_service) - .add_service(semantic_service) - .add_service(control_service) - .serve_with_incoming_shutdown(tcp_stream, shutdown_signal) - .await - { - tracing::error!("gRPC server error: {e}"); - } - }); - - Ok(()) -} diff --git a/crates/atuin-daemon/tests/lifecycle.rs b/crates/atuin-daemon/tests/lifecycle.rs deleted file mode 100644 index 4a91e5cb..00000000 --- a/crates/atuin-daemon/tests/lifecycle.rs +++ /dev/null @@ -1,222 +0,0 @@ -//! Integration tests for the daemon server lifecycle. -//! -//! Each test spins up a real gRPC server on a temporary unix socket, -//! connects a client, and exercises the daemon RPCs. - -#[cfg(unix)] -mod unix { - use std::time::Duration; - - use atuin_client::database::Sqlite; - use atuin_client::record::sqlite_store::SqliteStore; - use atuin_client::settings::{Settings, init_meta_config_for_testing}; - use atuin_daemon::client::HistoryClient; - use atuin_daemon::components::HistoryComponent; - use atuin_daemon::{Daemon, DaemonHandle}; - use tempfile::TempDir; - use tokio::net::UnixListener; - use tokio_stream::wrappers::UnixListenerStream; - use tonic::transport::Server; - - /// Spins up a daemon server on a temp socket and returns a connected client, - /// the daemon handle (for shutdown), and the temp dir (must be held to keep paths alive). - async fn start_test_daemon() -> (HistoryClient, DaemonHandle, TempDir) { - let tmp = tempfile::tempdir().unwrap(); - - let db_path = tmp.path().join("history.db"); - let record_path = tmp.path().join("records.db"); - let key_path = tmp.path().join("key"); - let socket_path = tmp.path().join("test.sock"); - let meta_path = tmp.path().join("meta.db"); - - // Initialize the meta store config for testing (required for Settings::host_id()) - init_meta_config_for_testing(meta_path.to_str().unwrap(), 5.0); - - // Build settings with test paths - let settings: Settings = Settings::builder() - .expect("could not build settings builder") - .set_override("db_path", db_path.to_str().unwrap()) - .expect("failed to set db_path") - .set_override("record_store_path", record_path.to_str().unwrap()) - .expect("failed to set record_store_path") - .set_override("key_path", key_path.to_str().unwrap()) - .expect("failed to set key_path") - .set_override("daemon.socket_path", socket_path.to_str().unwrap()) - .expect("failed to set socket_path") - .set_override("meta.db_path", meta_path.to_str().unwrap()) - .expect("failed to set meta.db_path") - .build() - .expect("could not build settings") - .try_deserialize() - .expect("could not deserialize settings"); - - // Create databases - let history_db = Sqlite::new(&db_path, 5.0).await.unwrap(); - let store = SqliteStore::new(&record_path, 5.0).await.unwrap(); - - // Create the history component and get its gRPC service - let history_component = HistoryComponent::new(); - let history_service = history_component.grpc_service(); - - // Build and start the daemon - let mut daemon = Daemon::builder(settings) - .store(store) - .history_db(history_db) - .component(history_component) - .build() - .await - .unwrap(); - - let handle = daemon.handle(); - - // Start components (this initializes the history component with the handle) - daemon.start_components().await.unwrap(); - - // Start the gRPC server - let uds = UnixListener::bind(&socket_path).unwrap(); - let stream = UnixListenerStream::new(uds); - - let server_handle = handle.clone(); - tokio::spawn(async move { - let mut rx = server_handle.subscribe(); - Server::builder() - .add_service(history_service) - .serve_with_incoming_shutdown(stream, async move { - loop { - match rx.recv().await { - Ok(atuin_daemon::DaemonEvent::ShutdownRequested) => break, - Ok(_) => continue, - Err(_) => break, - } - } - }) - .await - .unwrap(); - }); - - // Spawn the daemon event loop in the background - tokio::spawn(async move { - daemon.run_event_loop().await.unwrap(); - }); - - // Give the server a moment to bind. - tokio::time::sleep(Duration::from_millis(50)).await; - - let client = HistoryClient::new(socket_path.to_string_lossy().to_string()) - .await - .unwrap(); - - (client, handle, tmp) - } - - #[tokio::test] - async fn test_status() { - let (mut client, _handle, _tmp) = start_test_daemon().await; - - let status = client.status().await.unwrap(); - assert!(status.healthy); - assert_eq!(status.version, env!("CARGO_PKG_VERSION")); - assert_eq!(status.protocol, 1); - assert!(status.pid > 0); - } - - #[tokio::test] - async fn test_start_end_history() { - use atuin_client::history::History; - - let (mut client, _handle, _tmp) = start_test_daemon().await; - - let history = History::daemon() - .timestamp(time::OffsetDateTime::now_utc()) - .command("echo hello".to_string()) - .cwd("/tmp".to_string()) - .session("test-session".to_string()) - .hostname("test-host".to_string()) - .build() - .into(); - - let start_reply = client.start_history(history).await.unwrap(); - assert!(!start_reply.id.is_empty()); - - let end_reply = client - .end_history(start_reply.id, 1_000_000, 0) - .await - .unwrap(); - assert!(!end_reply.id.is_empty()); - } - - #[tokio::test] - async fn test_tail_history_streams_started_and_ended_events() { - use atuin_client::history::History; - use atuin_daemon::history::HistoryEventKind; - - let (mut client, _handle, _tmp) = start_test_daemon().await; - let mut stream = client.tail_history().await.unwrap(); - - let history = History::daemon() - .timestamp(time::OffsetDateTime::now_utc()) - .command("git status".to_string()) - .cwd("/tmp/repo".to_string()) - .session("tail-session".to_string()) - .hostname("test-host:ellie".to_string()) - .author("claude".to_string()) - .intent("inspect repository state".to_string()) - .build() - .into(); - - let start_reply = client.start_history(history).await.unwrap(); - - let started = stream.message().await.unwrap().unwrap(); - assert_eq!( - HistoryEventKind::try_from(started.kind).unwrap(), - HistoryEventKind::Started - ); - let started_history = started.history.unwrap(); - assert_eq!(started_history.id, start_reply.id); - assert_eq!(started_history.command, "git status"); - assert_eq!(started_history.cwd, "/tmp/repo"); - assert_eq!(started_history.hostname, "test-host:ellie"); - assert_eq!(started_history.author, "claude"); - assert_eq!(started_history.intent, "inspect repository state"); - - client - .end_history(start_reply.id.clone(), 1_000_000, 0) - .await - .unwrap(); - - let ended = stream.message().await.unwrap().unwrap(); - assert_eq!( - HistoryEventKind::try_from(ended.kind).unwrap(), - HistoryEventKind::Ended - ); - let ended_history = ended.history.unwrap(); - assert_eq!(ended_history.id, start_reply.id); - assert_eq!(ended_history.exit, 0); - assert_eq!(ended_history.duration, 1_000_000); - } - - #[tokio::test] - async fn test_end_unknown_history_fails() { - let (mut client, _handle, _tmp) = start_test_daemon().await; - - let result = client - .end_history("nonexistent-id".to_string(), 1000, 0) - .await; - assert!(result.is_err()); - } - - #[tokio::test] - async fn test_shutdown() { - let (mut client, _handle, _tmp) = start_test_daemon().await; - - let accepted = client.shutdown().await.unwrap(); - assert!(accepted); - - // Give server time to shut down. - tokio::time::sleep(Duration::from_millis(100)).await; - - // Subsequent calls should fail since the server is gone. - let result = client.status().await; - assert!(result.is_err()); - } -} diff --git a/crates/atuin-history/Cargo.toml b/crates/atuin-history/Cargo.toml deleted file mode 100644 index 50831a0b..00000000 --- a/crates/atuin-history/Cargo.toml +++ /dev/null @@ -1,30 +0,0 @@ -[package] -name = "atuin-history" -description = "The history crate for Atuin" -edition = "2024" -version = { workspace = true } - -authors.workspace = true -rust-version.workspace = true -license.workspace = true -homepage.workspace = true -repository.workspace = true -readme.workspace = true - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -atuin-client = { path = "../atuin-client", version = "18.16.1" } - -time = { workspace = true } -serde = { workspace = true } -crossterm = { workspace = true, features = ["use-dev-tty"] } -unicode-segmentation = "1.11.0" - -[dev-dependencies] -divan = "0.1.14" -rand = { workspace = true } - -[[bench]] -name = "smart_sort" -harness = false diff --git a/crates/atuin-history/benches/smart_sort.rs b/crates/atuin-history/benches/smart_sort.rs deleted file mode 100644 index a78064de..00000000 --- a/crates/atuin-history/benches/smart_sort.rs +++ /dev/null @@ -1,35 +0,0 @@ -use atuin_client::history::History; -use atuin_history::sort::sort; - -use rand::Rng; - -fn main() { - // Run registered benchmarks. - divan::main(); -} - -// Smart sort usually runs on 200 entries, test on a few sizes -#[divan::bench(args=[100, 200, 400, 800, 1600, 10000])] -fn smart_sort(lines: usize) { - // benchmark a few different sizes of "history" - // first we need to generate some history. This will use a whole bunch of memory, sorry - let mut rng = rand::thread_rng(); - let now = time::OffsetDateTime::now_utc().unix_timestamp(); - - let possible_commands = ["echo", "ls", "cd", "grep", "atuin", "curl"]; - let mut commands = Vec::::with_capacity(lines); - - for _ in 0..lines { - let command = possible_commands[rng.gen_range(0..possible_commands.len())]; - - let command = History::import() - .command(command) - .timestamp(time::OffsetDateTime::from_unix_timestamp(rng.gen_range(0..now)).unwrap()) - .build() - .into(); - - commands.push(command); - } - - let _ = sort("curl", commands); -} diff --git a/crates/atuin-history/src/lib.rs b/crates/atuin-history/src/lib.rs deleted file mode 100644 index e7b33916..00000000 --- a/crates/atuin-history/src/lib.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod sort; -pub mod stats; diff --git a/crates/atuin-history/src/sort.rs b/crates/atuin-history/src/sort.rs deleted file mode 100644 index 022865a2..00000000 --- a/crates/atuin-history/src/sort.rs +++ /dev/null @@ -1,46 +0,0 @@ -use atuin_client::history::History; - -type ScoredHistory = (f64, History); - -// Fuzzy search already comes sorted by minspan -// This sorting should be applicable to all search modes, and solve the more "obvious" issues -// first. -// Later on, we can pass in context and do some boosts there too. -pub fn sort(query: &str, input: Vec) -> Vec { - // This can totally be extended. We need to be _careful_ that it's not slow. - // We also need to balance sorting db-side with sorting here. SQLite can do a lot, - // but some things are just much easier/more doable in Rust. - - let mut scored = input - .into_iter() - .map(|h| { - // If history is _prefixed_ with the query, score it more highly - let score = if h.command.starts_with(query) { - 2.0 - } else if h.command.contains(query) { - 1.75 - } else { - 1.0 - }; - - // calculate how long ago the history was, in seconds - let now = time::OffsetDateTime::now_utc().unix_timestamp(); - let time = h.timestamp.unix_timestamp(); - let diff = std::cmp::max(1, now - time); // no /0 please - - // prefer newer history, but not hugely so as to offset the other scoring - // the numbers will get super small over time, but I don't want time to overpower other - // scoring - #[expect(clippy::cast_precision_loss)] - let time_score = 1.0 + (1.0 / diff as f64); - let score = score * time_score; - - (score, h) - }) - .collect::>(); - - scored.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap().reverse()); - - // Remove the scores and return the history - scored.into_iter().map(|(_, h)| h).collect::>() -} diff --git a/crates/atuin-history/src/stats.rs b/crates/atuin-history/src/stats.rs deleted file mode 100644 index fedb1487..00000000 --- a/crates/atuin-history/src/stats.rs +++ /dev/null @@ -1,548 +0,0 @@ -use std::collections::{HashMap, HashSet}; - -use crossterm::style::{Color, ResetColor, SetAttribute, SetForegroundColor}; -use serde::{Deserialize, Serialize}; -use unicode_segmentation::UnicodeSegmentation; - -use atuin_client::{history::History, settings::Settings, theme::Meaning, theme::Theme}; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Stats { - pub total_commands: usize, - pub unique_commands: usize, - pub top: Vec<(Vec, usize)>, -} - -fn first_non_whitespace(s: &str) -> Option { - s.char_indices() - // find the first non whitespace char - .find(|(_, c)| !c.is_ascii_whitespace()) - // return the index of that char - .map(|(i, _)| i) -} - -fn first_whitespace(s: &str) -> usize { - s.char_indices() - // find the first whitespace char - .find(|(_, c)| c.is_ascii_whitespace()) - // return the index of that char, (or the max length of the string) - .map_or(s.len(), |(i, _)| i) -} - -fn interesting_command<'a>(settings: &Settings, mut command: &'a str) -> &'a str { - // Sort by length so that we match the longest prefix first - let mut common_prefix = settings.stats.common_prefix.clone(); - common_prefix.sort_by_key(|b| std::cmp::Reverse(b.len())); - - // Trim off the common prefix, if it exists - for p in &common_prefix { - if command.starts_with(p) { - let i = p.len(); - let prefix = &command[..i]; - command = command[i..].trim_start(); - if command.is_empty() { - // no commands following, just use the prefix - return prefix; - } - break; - } - } - - // Sort the common_subcommands by length so that we match the longest subcommand first - let mut common_subcommands = settings.stats.common_subcommands.clone(); - common_subcommands.sort_by_key(|b| std::cmp::Reverse(b.len())); - - // Check for a common subcommand - for p in &common_subcommands { - if command.starts_with(p) { - // if the subcommand is the same length as the command, then we just use the subcommand - if p.len() == command.len() { - return command; - } - // otherwise we need to use the subcommand + the next word - let non_whitespace = first_non_whitespace(&command[p.len()..]).unwrap_or(0); - let j = - p.len() + non_whitespace + first_whitespace(&command[p.len() + non_whitespace..]); - return &command[..j]; - } - } - // Return the first word if there is no subcommand - &command[..first_whitespace(command)] -} - -fn split_at_pipe(command: &str) -> Vec<&str> { - let mut result = vec![]; - let mut quoted = false; - let mut start = 0; - let mut graphemes = UnicodeSegmentation::grapheme_indices(command, true); - - while let Some((i, c)) = graphemes.next() { - let current = i; - match c { - "\"" if command[start..current] != *"\"" => { - quoted = !quoted; - } - "'" if command[start..current] != *"'" => { - quoted = !quoted; - } - "\\" if graphemes.next().is_some() => {} - "|" if !quoted => { - if current > start && command[start..].starts_with('|') { - start += 1; - } - result.push(&command[start..current]); - start = current; - } - _ => {} - } - } - if command[start..].starts_with('|') { - start += 1; - } - result.push(&command[start..]); - result -} - -fn strip_leading_env_vars(command: &str) -> &str { - // fast path: no equals sign, no environment variable - if !command.contains('=') { - return command; - } - - let mut in_token = false; - let mut token_start_pos = 0; - let mut in_single_quotes = false; - let mut in_double_quotes = false; - let mut escape_next = false; - let mut has_equals_outside_quotes = false; - - for (i, g) in UnicodeSegmentation::grapheme_indices(command, true) { - if escape_next { - escape_next = false; - continue; - } - - if !in_token { - token_start_pos = i; - } - - match g { - "\\" => { - escape_next = true; - in_token = true; - } - "'" if !in_double_quotes => { - in_single_quotes = !in_single_quotes; - in_token = true; - } - "\"" if !in_single_quotes => { - in_double_quotes = !in_double_quotes; - in_token = true; - } - "=" if !in_single_quotes && !in_double_quotes => { - has_equals_outside_quotes = true; - in_token = true; - } - " " | "\t" if !in_single_quotes && !in_double_quotes => { - if in_token { - if !has_equals_outside_quotes { - // if we're not in an env var, we can break early - break; - } - in_token = false; - has_equals_outside_quotes = false; - } - } - _ => { - in_token = true; - } - } - } - - command[token_start_pos..].trim() -} - -pub fn pretty_print(stats: Stats, ngram_size: usize, theme: &Theme) { - let max = stats.top.iter().map(|x| x.1).max().unwrap(); - let num_pad = max.ilog10() as usize + 1; - - // Find the length of the longest command name for each column - let column_widths = stats - .top - .iter() - .map(|(commands, _)| commands.iter().map(|c| c.len()).collect::>()) - .fold(vec![0; ngram_size], |acc, item| { - acc.iter() - .zip(item.iter()) - .map(|(a, i)| *std::cmp::max(a, i)) - .collect() - }); - - for (command, count) in stats.top { - let gray = SetForegroundColor(match theme.as_style(Meaning::Muted).foreground_color { - Some(color) => color, - None => Color::Grey, - }); - let bold = SetAttribute(crossterm::style::Attribute::Bold); - - let in_ten = 10 * count / max; - - print!("["); - print!( - "{}", - SetForegroundColor(match theme.get_error().foreground_color { - Some(color) => color, - None => Color::Red, - }) - ); - - for i in 0..in_ten { - if i == 2 { - print!( - "{}", - SetForegroundColor(match theme.get_warning().foreground_color { - Some(color) => color, - None => Color::Yellow, - }) - ); - } - - if i == 5 { - print!( - "{}", - SetForegroundColor(match theme.get_info().foreground_color { - Some(color) => color, - None => Color::Green, - }) - ); - } - - print!("▮"); - } - - for _ in in_ten..10 { - print!(" "); - } - - let formatted_command = command - .iter() - .zip(column_widths.iter()) - .map(|(cmd, width)| format!("{cmd:width$}")) - .collect::>() - .join(" | "); - - println!( - "{ResetColor}] {gray}{count:num_pad$}{ResetColor} {bold}{formatted_command}{ResetColor}" - ); - } - println!("Total commands: {}", stats.total_commands); - println!("Unique commands: {}", stats.unique_commands); -} - -pub fn compute( - settings: &Settings, - history: &[History], - count: usize, - ngram_size: usize, -) -> Option { - let mut commands = HashSet::<&str>::with_capacity(history.len()); - let mut total_unignored = 0; - let mut prefixes = HashMap::, usize>::with_capacity(history.len()); - - for i in history { - // just in case it somehow has a leading tab or space or something (legacy atuin didn't ignore space prefixes) - let command = strip_leading_env_vars(i.command.trim()); - let prefix = interesting_command(settings, command); - - if settings.stats.ignored_commands.iter().any(|c| c == prefix) { - continue; - } - - total_unignored += 1; - commands.insert(command); - - split_at_pipe(command) - .iter() - .map(|l| { - let command = l.trim(); - commands.insert(command); - command - }) - .collect::>() - .windows(ngram_size) - .for_each(|w| { - *prefixes - .entry(w.iter().map(|c| interesting_command(settings, c)).collect()) - .or_default() += 1; - }); - } - - let unique = commands.len(); - let mut top = prefixes.into_iter().collect::>(); - - top.sort_unstable_by_key(|x| std::cmp::Reverse(x.1)); - top.truncate(count); - - if top.is_empty() { - return None; - } - - Some(Stats { - unique_commands: unique, - total_commands: total_unignored, - top: top - .into_iter() - .map(|t| (t.0.into_iter().map(|s| s.to_string()).collect(), t.1)) - .collect(), - }) -} - -#[cfg(test)] -mod tests { - use atuin_client::history::History; - use atuin_client::settings::Settings; - use time::OffsetDateTime; - - use super::compute; - use super::{interesting_command, split_at_pipe, strip_leading_env_vars}; - - #[test] - fn ignored_env_vars() { - let settings = Settings::utc(); - - let history: History = History::capture() - .timestamp(time::OffsetDateTime::now_utc()) - .command("FOO='BAR=🚀' echo foo") - .cwd("/") - .build() - .into(); - - let stats = compute(&settings, &[history], 10, 1).expect("failed to compute stats"); - assert_eq!(stats.top.first().unwrap().0, vec!["echo"]); - } - - #[test] - fn ignored_commands() { - let mut settings = Settings::utc(); - settings.stats.ignored_commands.push("cd".to_string()); - - let history = [ - History::import() - .timestamp(OffsetDateTime::now_utc()) - .command("cd foo") - .build() - .into(), - History::import() - .timestamp(OffsetDateTime::now_utc()) - .command("cargo build stuff") - .build() - .into(), - ]; - - let stats = compute(&settings, &history, 10, 1).expect("failed to compute stats"); - assert_eq!(stats.total_commands, 1); - assert_eq!(stats.unique_commands, 1); - } - - #[test] - fn interesting_commands() { - let settings = Settings::utc(); - - assert_eq!(interesting_command(&settings, "cargo"), "cargo"); - assert_eq!( - interesting_command(&settings, "cargo build foo bar"), - "cargo build" - ); - assert_eq!( - interesting_command(&settings, "sudo cargo build foo bar"), - "cargo build" - ); - assert_eq!(interesting_command(&settings, "sudo"), "sudo"); - } - - // Test with spaces in the common_prefix - #[test] - fn interesting_commands_spaces() { - let mut settings = Settings::utc(); - settings.stats.common_prefix.push("sudo test".to_string()); - - assert_eq!(interesting_command(&settings, "sudo test"), "sudo test"); - assert_eq!(interesting_command(&settings, "sudo test "), "sudo test"); - assert_eq!(interesting_command(&settings, "sudo test foo bar"), "foo"); - assert_eq!( - interesting_command(&settings, "sudo test foo bar"), - "foo" - ); - - // Works with a common_subcommand as well - assert_eq!( - interesting_command(&settings, "sudo test cargo build foo bar"), - "cargo build" - ); - - // We still match on just the sudo prefix - assert_eq!(interesting_command(&settings, "sudo"), "sudo"); - assert_eq!(interesting_command(&settings, "sudo foo"), "foo"); - } - - // Test with spaces in the common_subcommand - #[test] - fn interesting_commands_spaces_subcommand() { - let mut settings = Settings::utc(); - settings - .stats - .common_subcommands - .push("cargo build".to_string()); - - assert_eq!(interesting_command(&settings, "cargo build"), "cargo build"); - assert_eq!( - interesting_command(&settings, "cargo build "), - "cargo build" - ); - assert_eq!( - interesting_command(&settings, "cargo build foo bar"), - "cargo build foo" - ); - - // Works with a common_prefix as well - assert_eq!( - interesting_command(&settings, "sudo cargo build foo bar"), - "cargo build foo" - ); - - // We still match on just cargo as a subcommand - assert_eq!(interesting_command(&settings, "cargo"), "cargo"); - assert_eq!(interesting_command(&settings, "cargo foo"), "cargo foo"); - } - - // Test with spaces in the common_prefix and common_subcommand - #[test] - fn interesting_commands_spaces_both() { - let mut settings = Settings::utc(); - settings.stats.common_prefix.push("sudo test".to_string()); - settings - .stats - .common_subcommands - .push("cargo build".to_string()); - - assert_eq!( - interesting_command(&settings, "sudo test cargo build"), - "cargo build" - ); - assert_eq!( - interesting_command(&settings, "sudo test cargo build"), - "cargo build" - ); - assert_eq!( - interesting_command(&settings, "sudo test cargo build "), - "cargo build" - ); - assert_eq!( - interesting_command(&settings, "sudo test cargo build foo bar"), - "cargo build foo" - ); - } - - #[test] - fn split_simple() { - assert_eq!(split_at_pipe("fd | rg"), ["fd ", " rg"]); - } - - #[test] - fn split_multi() { - assert_eq!( - split_at_pipe("kubectl | jq | rg"), - ["kubectl ", " jq ", " rg"] - ); - } - - #[test] - fn split_simple_quoted() { - assert_eq!( - split_at_pipe("foo | bar 'baz {} | quux' | xyzzy"), - ["foo ", " bar 'baz {} | quux' ", " xyzzy"] - ); - } - - #[test] - fn split_multi_quoted() { - assert_eq!( - split_at_pipe("foo | bar 'baz \"{}\" | quux' | xyzzy"), - ["foo ", " bar 'baz \"{}\" | quux' ", " xyzzy"] - ); - } - - #[test] - fn escaped_pipes() { - assert_eq!( - split_at_pipe("foo | bar baz \\| quux"), - ["foo ", " bar baz \\| quux"] - ); - } - - #[test] - fn emoji() { - assert_eq!( - split_at_pipe("git commit -m \"🚀\""), - ["git commit -m \"🚀\""] - ); - } - - #[test] - fn starts_with_pipe() { - assert_eq!( - split_at_pipe("| sed 's/[0-9a-f]//g'"), - ["", " sed 's/[0-9a-f]//g'"] - ); - } - - #[test] - fn starts_with_spaces_and_pipe() { - assert_eq!( - split_at_pipe(" | sed 's/[0-9a-f]//g'"), - [" ", " sed 's/[0-9a-f]//g'"] - ); - } - - #[test] - fn strip_leading_env_vars_simple() { - assert_eq!( - strip_leading_env_vars("FOO=bar BAZ=quux echo foo"), - "echo foo" - ); - } - - #[test] - fn strip_leading_env_vars_quoted_single() { - assert_eq!(strip_leading_env_vars("FOO='BAR=baz' echo foo"), "echo foo"); - } - - #[test] - fn strip_leading_env_vars_quoted_double() { - assert_eq!( - strip_leading_env_vars("FOO=\"BAR=baz\" echo foo"), - "echo foo" - ); - } - - #[test] - fn strip_leading_env_vars_quoted_single_and_double() { - assert_eq!( - strip_leading_env_vars("FOO='BAR=\"baz\"' echo foo \"BAR=quux\""), - "echo foo \"BAR=quux\"" - ); - } - - #[test] - fn strip_leading_env_vars_emojis() { - assert_eq!( - strip_leading_env_vars("FOO='BAR=🚀' echo foo \"BAR=quux\" foo"), - "echo foo \"BAR=quux\" foo" - ); - } - - #[test] - fn strip_leading_env_vars_name_same_as_command() { - assert_eq!(strip_leading_env_vars("FOO='bar' bar baz"), "bar baz"); - } -} diff --git a/crates/atuin-pty-proxy/Cargo.toml b/crates/atuin-pty-proxy/Cargo.toml deleted file mode 100644 index baacf776..00000000 --- a/crates/atuin-pty-proxy/Cargo.toml +++ /dev/null @@ -1,21 +0,0 @@ -[package] -name = "atuin-pty-proxy" -edition = "2024" -description = "a PTY proxy for atuin" - -version = { workspace = true } -authors = { workspace = true } -rust-version = { workspace = true } -license = { workspace = true } -homepage = { workspace = true } -repository = { workspace = true } - -[dependencies] -clap = { workspace = true } - -[target.'cfg(unix)'.dependencies] -crossterm = { workspace = true } -eyre = { workspace = true } -portable-pty = "0.9" -signal-hook = "0.3" -vt100 = { workspace = true } diff --git a/crates/atuin-pty-proxy/src/capture.rs b/crates/atuin-pty-proxy/src/capture.rs deleted file mode 100644 index 6426035b..00000000 --- a/crates/atuin-pty-proxy/src/capture.rs +++ /dev/null @@ -1,467 +0,0 @@ -use std::sync::Arc; -use std::sync::atomic::{AtomicU16, Ordering}; - -use crate::osc133::{Event, Params, Parser, Zone}; - -const HISTORY_ID_PARAM: &str = "history_id"; -const SESSION_ID_PARAM: &str = "session_id"; -const MAX_OUTPUT_CAPTURE_BYTES: usize = 1024 * 1024; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct CommandCapture { - pub prompt: String, - pub command: String, - pub output: String, - pub exit_code: Option, - pub history_id: Option, - pub session_id: Option, - pub output_truncated: bool, - pub output_observed_bytes: u64, -} - -pub type CommandCaptureSink = Box; - -#[derive(Default)] -struct CaptureBuffers { - prompt: Vec, - command: Vec, - output: Vec, - output_observed_bytes: u64, - output_truncated: bool, - exit_code: Option, - history_id: Option, - session_id: Option, -} - -pub(crate) struct CommandCaptureTracker { - parser: Parser, - zone: Zone, - buffers: CaptureBuffers, - cols: Arc, -} - -impl CommandCaptureTracker { - pub(crate) fn new(cols: Arc) -> Self { - Self { - parser: Parser::new(), - zone: Zone::Unknown, - buffers: CaptureBuffers::default(), - cols, - } - } - - pub(crate) fn push(&mut self, data: &[u8], mut on_capture: impl FnMut(CommandCapture)) { - let mut events = Vec::new(); - self.parser - .push_located(data, |located| events.push(located)); - - let mut start = 0; - for located in events { - let marker_start = located.start_offset.min(data.len()).max(start); - let offset = located.offset.min(data.len()); - self.append(&data[start..marker_start]); - self.handle_event(located.event, &located.params, &mut on_capture); - self.zone = located.zone; - start = offset; - } - - let append_end = self - .parser - .incomplete_osc_sequence_start() - .map_or(data.len(), |sequence_start| { - sequence_start.min(data.len()).max(start) - }); - if start < append_end { - self.append(&data[start..append_end]); - } - } - - fn append(&mut self, data: &[u8]) { - match self.zone { - Zone::Prompt => self.buffers.prompt.extend_from_slice(data), - Zone::Input => self.buffers.command.extend_from_slice(data), - Zone::Output => self.append_output(data), - Zone::Unknown => {} - } - } - - fn append_output(&mut self, data: &[u8]) { - self.buffers.output_observed_bytes = self - .buffers - .output_observed_bytes - .saturating_add(data.len() as u64); - - if self.buffers.output_truncated { - return; - } - - let remaining = MAX_OUTPUT_CAPTURE_BYTES.saturating_sub(self.buffers.output.len()); - let retained = data.len().min(remaining); - self.buffers.output_truncated = retained < data.len(); - - if retained > 0 { - self.buffers.output.extend_from_slice(&data[..retained]); - } - } - - fn handle_event( - &mut self, - event: Event, - params: &Params, - on_capture: &mut impl FnMut(CommandCapture), - ) { - match event { - Event::PromptStart => { - if self.zone != Zone::Prompt { - self.buffers = CaptureBuffers::default(); - } - } - Event::CommandStart | Event::CommandExecuted => {} - Event::CommandFinished { exit_code } => { - let Some(history_id) = params.get(HISTORY_ID_PARAM).map(str::to_owned) else { - return; - }; - - if exit_code.is_some() || self.buffers.exit_code.is_none() { - self.buffers.exit_code = exit_code; - } - self.buffers.history_id = Some(history_id); - self.buffers.session_id = params.get(SESSION_ID_PARAM).map(str::to_owned); - - if let Some(capture) = self.finish_capture() { - on_capture(capture); - } - } - } - } - - fn finish_capture(&mut self) -> Option { - let buffers = std::mem::take(&mut self.buffers); - let cols = self.cols.load(Ordering::Relaxed).max(1); - let prompt = render_plain_text(&buffers.prompt, cols); - let command = render_plain_text(&buffers.command, cols) - .trim_matches(|c| c == '\r' || c == '\n') - .to_string(); - let output = render_plain_text(&buffers.output, cols); - let output_truncated = buffers.output_truncated; - let output_observed_bytes = buffers.output_observed_bytes; - let exit_code = buffers.exit_code; - let history_id = buffers.history_id; - let session_id = buffers.session_id; - - if command.is_empty() && output.is_empty() { - return None; - } - - Some(CommandCapture { - prompt, - command, - output, - exit_code, - history_id, - session_id, - output_truncated, - output_observed_bytes, - }) - } -} - -const CLEAN_TEXT_MAX_ROWS: usize = 10_000; - -fn render_plain_text(bytes: &[u8], cols: u16) -> String { - if bytes.is_empty() { - return String::new(); - } - - let cols = cols.max(1); - let mut parser = vt100::Parser::new(estimated_rows(bytes, cols), cols, 0); - parser.process(bytes); - normalize_screen_contents(&parser.screen().contents()) -} - -fn normalize_screen_contents(contents: &str) -> String { - let mut lines = contents.lines().map(str::trim_end).collect::>(); - while lines.last().is_some_and(|line| line.is_empty()) { - lines.pop(); - } - lines.join("\n") -} - -fn estimated_rows(bytes: &[u8], cols: u16) -> u16 { - let newline_rows = bytes.iter().filter(|byte| **byte == b'\n').count() + 1; - let wrapped_rows = bytes.len() / cols as usize; - newline_rows - .saturating_add(wrapped_rows) - .saturating_add(1) - .clamp(1, CLEAN_TEXT_MAX_ROWS) as u16 -} - -#[cfg(test)] -mod tests { - use super::*; - - fn tracker(cols: u16) -> CommandCaptureTracker { - CommandCaptureTracker::new(Arc::new(AtomicU16::new(cols))) - } - - fn assert_no_terminal_controls(text: &str) { - assert!( - !text - .chars() - .any(|ch| ch.is_control() && ch != '\n' && ch != '\t'), - "text still contains terminal controls: {text:?}" - ); - } - - #[test] - fn command_text_collapses_terminal_echo_edits() { - assert_eq!(render_plain_text(b"e\x08echo hi", 80), "echo hi"); - assert_eq!( - render_plain_text( - b"e\x08echo\x08 \x08\x08 \x08\x08\x08e \x08\x08 \x08e\x08echo hi", - 80 - ), - "echo hi" - ); - assert_eq!(render_plain_text(b"echo hi", 80), "echo hi"); - } - - #[test] - fn text_cleaning_strips_ansi_and_terminal_controls() { - let text = render_plain_text( - b"\x1b[32mhi\x1b[0m\r\n% \r \r", - 80, - ); - - assert_eq!(text, "hi"); - assert_no_terminal_controls(&text); - } - - #[test] - fn text_cleaning_preserves_valid_utf8_after_backspace() { - let text = render_plain_text("🦀x\x08 \x08 crab".as_bytes(), 80); - - assert_eq!(text, "🦀 crab"); - assert_no_terminal_controls(&text); - } - - #[test] - fn command_text_replays_backspaces() { - let mut tracker = tracker(80); - let mut captures = Vec::new(); - - let input = - b"\x1b]133;A\x07$ \x1b]133;B\x07e\x08echo hi\r\n\x1b]133;C\x07hi\r\n\x1b]133;D;0;history_id=hist;session_id=sess\x07\x1b]133;A\x07$ "; - tracker.push(input, |capture| captures.push(capture)); - - assert_eq!(captures.len(), 1); - assert_eq!(captures[0].command, "echo hi"); - assert_eq!(captures[0].output, "hi"); - assert_no_terminal_controls(&captures[0].command); - assert_no_terminal_controls(&captures[0].output); - } - - #[test] - fn captures_complete_command() { - let mut tracker = tracker(80); - let mut captures = Vec::new(); - - tracker.push( - b"\x1b]133;A\x07$ \x1b]133;B\x07echo hi\r\n\x1b]133;C\x07hi\r\n\x1b]133;D;0;history_id=hist;session_id=sess\x07\x1b]133;A\x07$ ", - |capture| captures.push(capture), - ); - - assert_eq!( - captures, - vec![CommandCapture { - prompt: "$".to_string(), - command: "echo hi".to_string(), - output: "hi".to_string(), - exit_code: Some(0), - history_id: Some("hist".to_string()), - session_id: Some("sess".to_string()), - output_truncated: false, - output_observed_bytes: 4, - }] - ); - } - - #[test] - fn strips_ansi_and_split_markers() { - let mut tracker = tracker(80); - let mut captures = Vec::new(); - - tracker.push(b"\x1b]133;A\x07\x1b[32m%\x1b[0m ", |_| {}); - tracker.push(b"\x1b]133;B\x07ls\x1b]133;C", |_| {}); - tracker.push( - b"\x07\x1b[31mfile\x1b[0m\r\n\x1b]133;D;1;history_id=hist;session_id=sess\x07\x1b]133;A\x07% ", - |capture| { - captures.push(capture); - }, - ); - - assert_eq!( - captures, - vec![CommandCapture { - prompt: "%".to_string(), - command: "ls".to_string(), - output: "file".to_string(), - exit_code: Some(1), - history_id: Some("hist".to_string()), - session_id: Some("sess".to_string()), - output_truncated: false, - output_observed_bytes: 15, - }] - ); - } - - #[test] - fn duplicate_prompt_start_does_not_reset_prompt_capture() { - let mut tracker = tracker(80); - let mut captures = Vec::new(); - - tracker.push( - b"\x1b]133;A\x07$ \x1b]133;A\x07continued \x1b]133;B\x07echo hi\r\n\x1b]133;C\x07hi\r\n\x1b]133;D;0;history_id=hist;session_id=sess\x07\x1b]133;A\x07$ ", - |capture| captures.push(capture), - ); - - assert_eq!(captures.len(), 1); - assert_eq!(captures[0].prompt, "$ continued"); - assert_eq!(captures[0].command, "echo hi"); - assert_eq!(captures[0].output, "hi"); - } - - #[test] - fn bare_finish_without_metadata_is_ignored() { - let mut tracker = tracker(80); - let mut captures = Vec::new(); - - tracker.push(b"\x1b]133;C\x07line one\r\n\x1b]133;D;0\x07", |capture| { - captures.push(capture); - }); - - tracker.push(b"\x1b]133;A\x07$ ", |capture| captures.push(capture)); - - assert!(captures.is_empty()); - } - - #[test] - fn bare_finish_before_metadata_in_same_push_ignored() { - let mut tracker = tracker(80); - let mut captures = Vec::new(); - - tracker.push( - b"\x1b]133;C\x07line one\r\n\x1b]133;D;1\x07\x1b]133;D;0;history_id=018f;session_id=abcd\x07", - |capture| captures.push(capture), - ); - - assert_eq!(captures.len(), 1); - assert_eq!(captures[0].output, "line one"); - assert_eq!(captures[0].exit_code, Some(0)); - assert_eq!(captures[0].history_id.as_deref(), Some("018f")); - assert_eq!(captures[0].session_id.as_deref(), Some("abcd")); - } - - #[test] - fn metadata_arriving_after_bare_finish_across_pushes() { - let mut tracker = tracker(80); - let mut captures = Vec::new(); - - tracker.push(b"\x1b]133;C\x07line one\r\n\x1b]133;D;0\x07", |capture| { - captures.push(capture); - }); - tracker.push(b"\x1b]133;D;0;history_id=018f", |capture| { - captures.push(capture) - }); - - assert!(captures.is_empty()); - - tracker.push(b";session_id=abcd\x07", |capture| captures.push(capture)); - - assert_eq!(captures.len(), 1); - assert_eq!(captures[0].output, "line one"); - assert_eq!(captures[0].exit_code, Some(0)); - assert_eq!(captures[0].history_id.as_deref(), Some("018f")); - assert_eq!(captures[0].session_id.as_deref(), Some("abcd")); - } - - #[test] - fn split_finish_marker_is_not_counted_as_output() { - let mut tracker = tracker(80); - let mut captures = Vec::new(); - - tracker.push( - b"\x1b]133;C\x07line one\r\n\x1b]133;D;0;history_id=018f", - |capture| { - captures.push(capture); - }, - ); - assert!(captures.is_empty()); - - tracker.push(b";session_id=abcd\x07", |capture| captures.push(capture)); - - assert_eq!(captures.len(), 1); - assert_eq!(captures[0].output, "line one"); - assert_eq!(captures[0].output_observed_bytes, 10); - } - - #[test] - fn captures_output_with_history_metadata_from_d_marker() { - let mut tracker = tracker(80); - let mut captures = Vec::new(); - - tracker.push( - b"\x1b]133;C\x07line one\r\n\x1b]133;D;0;history_id=018f;session_id=abcd\x07", - |capture| captures.push(capture), - ); - - assert_eq!( - captures, - vec![CommandCapture { - prompt: String::new(), - command: String::new(), - output: "line one".to_string(), - exit_code: Some(0), - history_id: Some("018f".to_string()), - session_id: Some("abcd".to_string()), - output_truncated: false, - output_observed_bytes: 10, - }] - ); - } - - #[test] - fn output_capture_is_capped_and_reports_observed_bytes() { - let mut tracker = tracker(80); - let mut captures = Vec::new(); - let mut input = b"\x1b]133;C\x07".to_vec(); - input.extend(std::iter::repeat_n(b'x', MAX_OUTPUT_CAPTURE_BYTES + 10)); - input.extend_from_slice(b"\x1b]133;D;0;history_id=big;session_id=session-1\x07"); - - tracker.push(&input, |capture| captures.push(capture)); - - assert_eq!(captures.len(), 1); - assert!(captures[0].output_truncated); - assert_eq!( - captures[0].output_observed_bytes, - (MAX_OUTPUT_CAPTURE_BYTES + 10) as u64 - ); - } - - #[test] - fn resets_buffers_between_c_d_only_captures() { - let mut tracker = tracker(80); - let mut captures = Vec::new(); - - tracker.push( - b"\x1b]133;C\x07first\r\n\x1b]133;D;0;history_id=one\x07\x1b]133;C\x07second\r\n\x1b]133;D;1;history_id=two\x07", - |capture| captures.push(capture), - ); - - assert_eq!(captures.len(), 2); - assert_eq!(captures[0].output, "first"); - assert_eq!(captures[0].history_id.as_deref(), Some("one")); - assert_eq!(captures[1].output, "second"); - assert_eq!(captures[1].history_id.as_deref(), Some("two")); - } -} diff --git a/crates/atuin-pty-proxy/src/debug.rs b/crates/atuin-pty-proxy/src/debug.rs deleted file mode 100644 index 806bde90..00000000 --- a/crates/atuin-pty-proxy/src/debug.rs +++ /dev/null @@ -1,53 +0,0 @@ -use crate::osc133::{Event, Parser}; - -pub(crate) const RESET: &[u8] = b"\x1b[0m"; - -pub(crate) struct Osc133DebugHighlighter { - parser: Parser, -} - -impl Osc133DebugHighlighter { - pub(crate) fn new() -> Self { - Self { - parser: Parser::new(), - } - } - - pub(crate) fn render(&mut self, data: &[u8]) -> Vec { - let mut events = Vec::new(); - self.parser - .push_located(data, |located| events.push(located)); - - if events.is_empty() { - return data.to_vec(); - } - - let mut rendered = Vec::with_capacity(data.len() + (events.len() * 64)); - let mut start = 0; - - for located in events { - let offset = located.offset.min(data.len()); - if offset > start { - rendered.extend_from_slice(&data[start..offset]); - } - - rendered.extend_from_slice(event_label(&located.event)); - rendered.extend_from_slice(RESET); - start = offset; - } - - rendered.extend_from_slice(&data[start..]); - rendered - } -} - -fn event_label(event: &Event) -> &'static [u8] { - match event { - Event::PromptStart => b"\x1b[1;37;45m[OSC133:A prompt]\x1b[0m", - Event::CommandStart => b"\x1b[1;30;43m[OSC133:B input]\x1b[0m", - Event::CommandExecuted => b"\x1b[1;30;46m[OSC133:C output]\x1b[0m", - Event::CommandFinished { exit_code: Some(0) } => b"\x1b[1;37;42m[OSC133:D exit=0]\x1b[0m", - Event::CommandFinished { exit_code: Some(_) } => b"\x1b[1;37;41m[OSC133:D exit!=0]\x1b[0m", - Event::CommandFinished { exit_code: None } => b"\x1b[1;37;44m[OSC133:D exit=?]\x1b[0m", - } -} diff --git a/crates/atuin-pty-proxy/src/lib.rs b/crates/atuin-pty-proxy/src/lib.rs deleted file mode 100644 index d1571079..00000000 --- a/crates/atuin-pty-proxy/src/lib.rs +++ /dev/null @@ -1,48 +0,0 @@ -#[cfg(unix)] -mod capture; -#[cfg(unix)] -mod debug; -#[cfg(unix)] -mod osc133; -#[cfg(unix)] -mod pty_proxy; -#[cfg(unix)] -mod runtime; -#[cfg(unix)] -mod screen; - -#[cfg(unix)] -pub use capture::{CommandCapture, CommandCaptureSink}; -#[cfg(unix)] -pub use pty_proxy::PtyProxy; - -#[cfg(not(unix))] -#[expect(dead_code)] -mod unsupported { - use clap::{Args, Subcommand}; - - #[derive(Args, Debug)] - pub struct PtyProxy { - /// Highlight OSC 133 prompt, input, output, and exit-code regions - #[arg(long)] - debug_osc133: bool, - - #[command(subcommand)] - cmd: Option, - } - - #[derive(Subcommand, Debug)] - enum Cmd { - /// Print shell code to initialize atuin pty-proxy on shell startup - Init(Init), - } - - #[derive(Args, Debug)] - struct Init { - /// Shell to generate init for. If omitted, attempt auto-detection - shell: Option, - } -} - -#[cfg(not(unix))] -pub use unsupported::PtyProxy; diff --git a/crates/atuin-pty-proxy/src/osc133.rs b/crates/atuin-pty-proxy/src/osc133.rs deleted file mode 100644 index 5b70f0aa..00000000 --- a/crates/atuin-pty-proxy/src/osc133.rs +++ /dev/null @@ -1,900 +0,0 @@ -//! Streaming parser for OSC 133 (FinalTerm semantic prompt) escape sequences. -//! -//! OSC 133 marks four regions of a shell interaction: -//! -//! | Marker | Meaning | -//! |--------|--------------------------------------| -//! | A | Prompt is about to be printed | -//! | B | Prompt ended — command input begins | -//! | C | Command submitted — output begins | -//! | D[;n] | Command finished with exit code *n* | -//! -//! The wire format is `ESC ] 133 ; [; ] ST` where ST is BEL -//! (0x07), ESC \ (0x1B 0x5C), or C1 ST (0x9C). -//! -//! # Design goals -//! -//! * **Transparent** — the parser observes the byte stream without modifying it; -//! the caller remains responsible for forwarding bytes to their destination. -//! * **Bounded** — OSC parameter buffering is capped so malformed output cannot -//! grow memory without limit. -//! * **Non-blocking** — [`Parser::push`] processes whatever bytes are available -//! and returns immediately. -//! * **Extensible** — marker parameters are preserved so Atuin-specific metadata -//! can ride alongside standard OSC 133 markers. - -/// Events emitted when an OSC 133 marker is detected. -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Event { - /// `ESC ] 133 ; A ST` — the shell is about to display its prompt. - PromptStart, - /// `ESC ] 133 ; B ST` — the prompt has ended; the user may type a command. - CommandStart, - /// `ESC ] 133 ; C ST` — the command has been submitted for execution. - CommandExecuted, - /// `ESC ] 133 ; D [; ] ST` — command output is complete. - CommandFinished { - /// The exit code reported after the `;`, if present and valid. - exit_code: Option, - }, -} - -/// Parameters attached to an OSC 133 marker. -#[derive(Debug, Default, Clone, PartialEq, Eq)] -pub struct Params { - items: Vec, -} - -impl Params { - /// Iterate over all marker parameters in order. - #[cfg(test)] - #[inline] - pub fn iter(&self) -> impl Iterator { - self.items.iter() - } - - /// Return the value for the first `key=value` parameter with this key. - #[inline] - pub fn get(&self, key: &str) -> Option<&str> { - self.items.iter().find_map(|item| match item { - Param::KeyValue { - key: item_key, - value, - } if item_key == key => Some(value.as_str()), - Param::Value(_) | Param::KeyValue { .. } => None, - }) - } -} - -/// A single OSC 133 marker parameter. -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Param { - /// A positional parameter without an equals sign. - Value(String), - /// A `key=value` parameter. - KeyValue { key: String, value: String }, -} - -/// An OSC 133 event with its position in the most recent input chunk. -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct LocatedEvent { - /// The OSC 133 event that was parsed. - pub event: Event, - /// Offset where this marker starts in the current chunk. - /// - /// If a marker started in an earlier [`Parser::push_located`] call, this is - /// `0` in the chunk that completed the marker. - pub start_offset: usize, - /// Offset immediately after this marker's terminator in the current chunk. - /// - /// If a marker spans multiple [`Parser::push_located`] calls, this is still - /// the offset in the chunk that completed the marker. - pub offset: usize, - /// The semantic zone after applying this event. - pub zone: Zone, - /// Metadata parameters attached to this marker. - pub params: Params, -} - -/// The current semantic zone as determined by the most recent OSC 133 marker. -#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] -#[expect(dead_code)] -pub enum Zone { - /// No marker seen yet, or after a `D` marker (between commands). - #[default] - Unknown, - /// Between `A` and `B` — the shell is rendering its prompt. - Prompt, - /// Between `B` and `C` — the user is editing a command line. - Input, - /// Between `C` and `D` — command output is being produced. - Output, -} - -// --------------------------------------------------------------------------- -// Internal constants -// --------------------------------------------------------------------------- - -const ESC: u8 = 0x1B; -const BEL: u8 = 0x07; -const C1_ST: u8 = 0x9C; -const BACKSLASH: u8 = b'\\'; -const RIGHT_BRACKET: u8 = b']'; - -/// Maximum bytes we'll buffer for the OSC parameter string. This is large enough -/// for Atuin metadata such as history/session IDs while still bounding malformed -/// OSC sequences. -const PARAM_BUF_CAP: usize = 512; - -// --------------------------------------------------------------------------- -// State machine -// --------------------------------------------------------------------------- - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum State { - /// Normal pass-through. - Ground, - /// Saw ESC (0x1B). - Esc, - /// Inside an OSC sequence (`ESC ]`), accumulating parameter bytes. - OscParam, - /// Inside an OSC sequence, saw ESC — next byte decides if this is `ESC \` - /// (string terminator) or something else. - OscEsc, -} - -/// A streaming, zero-allocation parser for OSC 133 escape sequences. -/// -/// Feed arbitrary byte slices into [`Parser::push`]. The parser detects -/// OSC 133 markers and reports [`Event`]s through a caller-supplied callback -/// without modifying the data. It can sit transparently between a PTY reader -/// and stdout. -pub struct Parser { - state: State, - zone: Zone, - sequence_start: Option, - param_buf: [u8; PARAM_BUF_CAP], - param_len: usize, -} - -impl Default for Parser { - fn default() -> Self { - Self::new() - } -} - -impl Parser { - /// Create a new parser in the initial (ground / unknown-zone) state. - #[inline] - pub fn new() -> Self { - Self { - state: State::Ground, - zone: Zone::Unknown, - sequence_start: None, - param_buf: [0u8; PARAM_BUF_CAP], - param_len: 0, - } - } - - /// The current semantic zone based on markers seen so far. - #[inline] - #[expect(dead_code)] - pub fn zone(&self) -> Zone { - self.zone - } - - /// Start offset of an incomplete OSC sequence in the most recent chunk. - #[inline] - pub(crate) fn incomplete_osc_sequence_start(&self) -> Option { - matches!(self.state, State::OscParam | State::OscEsc) - .then(|| self.sequence_start.unwrap_or(0)) - } - - /// Process a chunk of bytes, calling `on_event` for every OSC 133 marker - /// found. - /// - /// All bytes in `data` should still be forwarded to the terminal by the - /// caller — this method only *observes* the stream. - #[cfg(test)] - #[inline] - pub fn push(&mut self, data: &[u8], mut on_event: impl FnMut(Event)) { - self.push_located(data, |located| on_event(located.event)); - } - - /// Process a chunk of bytes, calling `on_event` for every OSC 133 marker - /// found with its byte offset in this chunk. - /// - /// The offset points to the first byte after the marker terminator, making - /// it suitable for callers that need to split the original chunk at marker - /// boundaries. - #[inline] - pub fn push_located(&mut self, data: &[u8], mut on_event: impl FnMut(LocatedEvent)) { - self.sequence_start = (self.state != State::Ground).then_some(0); - - for (offset, &byte) in data.iter().enumerate() { - match self.state { - State::Ground => { - if byte == ESC { - self.state = State::Esc; - self.sequence_start = Some(offset); - } - } - State::Esc => { - if byte == RIGHT_BRACKET { - self.state = State::OscParam; - self.param_len = 0; - } else { - self.state = State::Ground; - self.sequence_start = None; - } - } - State::OscParam => { - if byte == BEL || byte == C1_ST { - self.dispatch(offset + 1, &mut on_event); - self.state = State::Ground; - self.sequence_start = None; - } else if byte == ESC { - self.state = State::OscEsc; - } else if self.param_len < PARAM_BUF_CAP { - self.param_buf[self.param_len] = byte; - self.param_len += 1; - } - // If param_len == PARAM_BUF_CAP we silently stop - // accumulating — dispatch will ignore non-133 sequences. - } - State::OscEsc => { - if byte == BACKSLASH { - self.dispatch(offset + 1, &mut on_event); - } - // Whether we got a valid ST or not, return to ground. - // (A new ESC ] would restart accumulation via the Ground - // -> Esc -> OscParam path on the *next* byte.) - self.state = State::Ground; - self.sequence_start = None; - } - } - } - } - - /// Inspect the accumulated parameter buffer. If it holds an OSC 133 - /// payload, emit the corresponding [`Event`] and update the zone. - #[inline] - fn dispatch(&mut self, offset: usize, on_event: &mut impl FnMut(LocatedEvent)) { - let payload = &self.param_buf[..self.param_len]; - - if payload.len() < 5 || &payload[..4] != b"133;" { - return; - } - - if payload.len() > 5 && payload[5] != b';' { - return; - } - - let metadata = payload.get(6..).unwrap_or_default(); - let cmd = payload[4]; - let (event, params) = match cmd { - b'A' => { - self.zone = Zone::Prompt; - (Event::PromptStart, parse_params(metadata)) - } - b'B' => { - self.zone = Zone::Input; - (Event::CommandStart, parse_params(metadata)) - } - b'C' => { - self.zone = Zone::Output; - (Event::CommandExecuted, parse_params(metadata)) - } - b'D' => { - let (exit_code, params) = parse_command_finished_params(metadata); - self.zone = Zone::Unknown; - (Event::CommandFinished { exit_code }, params) - } - _ => return, - }; - - on_event(LocatedEvent { - event, - start_offset: self.sequence_start.unwrap_or(0), - offset, - zone: self.zone, - params, - }); - } -} - -fn parse_command_finished_params(metadata: &[u8]) -> (Option, Params) { - if metadata.is_empty() { - return (None, Params::default()); - } - - let Some(separator) = metadata.iter().position(|byte| *byte == b';') else { - return parse_exit_code(metadata).map_or_else( - || (None, parse_params(metadata)), - |exit_code| (Some(exit_code), Params::default()), - ); - }; - - let (first, rest) = metadata.split_at(separator); - let rest = &rest[1..]; - - parse_exit_code(first).map_or_else( - || (None, parse_params(metadata)), - |exit_code| (Some(exit_code), parse_params(rest)), - ) -} - -fn parse_exit_code(code: &[u8]) -> Option { - if code.is_empty() { - return None; - } - - std::str::from_utf8(code) - .ok() - .and_then(|code| code.parse::().ok()) -} - -fn parse_params(metadata: &[u8]) -> Params { - let items = metadata - .split(|byte| *byte == b';') - .filter(|part| !part.is_empty()) - .map(parse_param) - .collect(); - - Params { items } -} - -fn parse_param(param: &[u8]) -> Param { - let param = String::from_utf8_lossy(param); - - if let Some((key, value)) = param.split_once('=') { - return Param::KeyValue { - key: key.to_string(), - value: value.to_string(), - }; - } - - Param::Value(param.into_owned()) -} - -// --------------------------------------------------------------------------- -// Tests -// --------------------------------------------------------------------------- - -#[cfg(test)] -mod tests { - use super::*; - - /// Collect all events from a single `push` call. - fn parse_events(data: &[u8]) -> Vec { - let mut parser = Parser::new(); - let mut events = Vec::new(); - parser.push(data, |e| events.push(e)); - events - } - - // -- Basic event detection ------------------------------------------------ - - #[test] - fn detect_prompt_start_bel() { - let data = b"\x1b]133;A\x07"; - assert_eq!(parse_events(data), vec![Event::PromptStart]); - } - - #[test] - fn detect_prompt_start_st() { - let data = b"\x1b]133;A\x1b\\"; - assert_eq!(parse_events(data), vec![Event::PromptStart]); - } - - #[test] - fn detect_command_start_bel() { - let data = b"\x1b]133;B\x07"; - assert_eq!(parse_events(data), vec![Event::CommandStart]); - } - - #[test] - fn detect_command_start_st() { - let data = b"\x1b]133;B\x1b\\"; - assert_eq!(parse_events(data), vec![Event::CommandStart]); - } - - #[test] - fn detect_command_executed_bel() { - let data = b"\x1b]133;C\x07"; - assert_eq!(parse_events(data), vec![Event::CommandExecuted]); - } - - #[test] - fn detect_command_executed_st() { - let data = b"\x1b]133;C\x1b\\"; - assert_eq!(parse_events(data), vec![Event::CommandExecuted]); - } - - #[test] - fn detect_command_finished_no_exit_code() { - let data = b"\x1b]133;D\x07"; - assert_eq!( - parse_events(data), - vec![Event::CommandFinished { exit_code: None }] - ); - } - - #[test] - fn detect_command_finished_exit_zero() { - let data = b"\x1b]133;D;0\x07"; - assert_eq!( - parse_events(data), - vec![Event::CommandFinished { exit_code: Some(0) }] - ); - } - - #[test] - fn detect_command_finished_exit_nonzero() { - let data = b"\x1b]133;D;127\x07"; - assert_eq!( - parse_events(data), - vec![Event::CommandFinished { - exit_code: Some(127) - }] - ); - } - - #[test] - fn detect_command_finished_negative_exit_code() { - let data = b"\x1b]133;D;-1\x07"; - assert_eq!( - parse_events(data), - vec![Event::CommandFinished { - exit_code: Some(-1) - }] - ); - } - - #[test] - fn detect_command_finished_exit_code_st() { - let data = b"\x1b]133;D;42\x1b\\"; - assert_eq!( - parse_events(data), - vec![Event::CommandFinished { - exit_code: Some(42) - }] - ); - } - - #[test] - fn invalid_exit_code_yields_none() { - let data = b"\x1b]133;D;abc\x07"; - assert_eq!( - parse_events(data), - vec![Event::CommandFinished { exit_code: None }] - ); - } - - // -- Zone tracking -------------------------------------------------------- - - #[test] - fn zone_starts_unknown() { - let parser = Parser::new(); - assert_eq!(parser.zone(), Zone::Unknown); - } - - #[test] - fn full_zone_cycle() { - let mut parser = Parser::new(); - let mut events = Vec::new(); - - parser.push(b"\x1b]133;A\x07", |e| events.push(e)); - assert_eq!(parser.zone(), Zone::Prompt); - - parser.push(b"\x1b]133;B\x07", |e| events.push(e)); - assert_eq!(parser.zone(), Zone::Input); - - parser.push(b"\x1b]133;C\x07", |e| events.push(e)); - assert_eq!(parser.zone(), Zone::Output); - - parser.push(b"\x1b]133;D;0\x07", |e| events.push(e)); - assert_eq!(parser.zone(), Zone::Unknown); - - assert_eq!( - events, - vec![ - Event::PromptStart, - Event::CommandStart, - Event::CommandExecuted, - Event::CommandFinished { exit_code: Some(0) }, - ] - ); - } - - // -- Multiple events in one push ------------------------------------------ - - #[test] - fn multiple_events_single_push() { - let data = b"\x1b]133;A\x07$ \x1b]133;B\x07ls\n\x1b]133;C\x07file.txt\n\x1b]133;D;0\x07"; - let events = parse_events(data); - assert_eq!( - events, - vec![ - Event::PromptStart, - Event::CommandStart, - Event::CommandExecuted, - Event::CommandFinished { exit_code: Some(0) }, - ] - ); - } - - // -- Split across push boundaries ----------------------------------------- - - #[test] - fn split_esc_and_bracket() { - let mut parser = Parser::new(); - let mut events = Vec::new(); - - parser.push(b"\x1b", |e| events.push(e)); - assert!(events.is_empty()); - - parser.push(b"]133;A\x07", |e| events.push(e)); - assert_eq!(events, vec![Event::PromptStart]); - } - - #[test] - fn split_mid_param() { - let mut parser = Parser::new(); - let mut events = Vec::new(); - - parser.push(b"\x1b]13", |e| events.push(e)); - assert!(events.is_empty()); - - parser.push(b"3;D;42\x07", |e| events.push(e)); - assert_eq!( - events, - vec![Event::CommandFinished { - exit_code: Some(42) - }] - ); - } - - #[test] - fn split_before_terminator() { - let mut parser = Parser::new(); - let mut events = Vec::new(); - - parser.push(b"\x1b]133;B", |e| events.push(e)); - assert!(events.is_empty()); - - parser.push(b"\x07", |e| events.push(e)); - assert_eq!(events, vec![Event::CommandStart]); - } - - #[test] - fn split_esc_backslash_terminator() { - let mut parser = Parser::new(); - let mut events = Vec::new(); - - parser.push(b"\x1b]133;C\x1b", |e| events.push(e)); - assert!(events.is_empty()); - - parser.push(b"\\", |e| events.push(e)); - assert_eq!(events, vec![Event::CommandExecuted]); - } - - // -- Interleaved normal text ---------------------------------------------- - - #[test] - fn normal_text_before_and_after() { - let data = b"hello world\x1b]133;A\x07prompt text\x1b]133;B\x07command"; - let events = parse_events(data); - assert_eq!(events, vec![Event::PromptStart, Event::CommandStart]); - } - - // -- Non-133 OSC sequences (should be ignored) ---------------------------- - - #[test] - fn non_133_osc_ignored() { - let data = b"\x1b]0;window title\x07\x1b]133;A\x07"; - let events = parse_events(data); - assert_eq!(events, vec![Event::PromptStart]); - } - - #[test] - fn osc_7_ignored() { - let data = b"\x1b]7;file:///home/user\x07"; - assert!(parse_events(data).is_empty()); - } - - // -- Unknown command letter ----------------------------------------------- - - #[test] - fn unknown_command_ignored() { - let data = b"\x1b]133;Z\x07"; - assert!(parse_events(data).is_empty()); - } - - #[test] - fn marker_with_unexpected_trailing_bytes_ignored() { - let data = b"\x1b]133;ABC\x07"; - assert!(parse_events(data).is_empty()); - } - - // -- Malformed sequences -------------------------------------------------- - - #[test] - fn esc_followed_by_non_bracket() { - let data = b"\x1b[31m\x1b]133;A\x07"; - let events = parse_events(data); - assert_eq!(events, vec![Event::PromptStart]); - } - - #[test] - fn lone_esc_at_end_of_chunk() { - let mut parser = Parser::new(); - let mut events = Vec::new(); - - parser.push(b"\x1b", |e| events.push(e)); - assert!(events.is_empty()); - - // Feed non-bracket to abort the escape, then a real sequence. - parser.push(b"x\x1b]133;A\x07", |e| events.push(e)); - assert_eq!(events, vec![Event::PromptStart]); - } - - #[test] - fn truncated_133_prefix() { - // "13" followed by terminator — not "133;" so no event. - let data = b"\x1b]13\x07"; - assert!(parse_events(data).is_empty()); - } - - #[test] - fn empty_osc() { - let data = b"\x1b]\x07"; - assert!(parse_events(data).is_empty()); - } - - // -- Buffer overflow (very long non-133 OSC) ------------------------------ - - #[test] - fn very_long_osc_does_not_panic() { - let mut data = Vec::new(); - data.extend_from_slice(b"\x1b]"); - data.extend(std::iter::repeat_n(b'x', 1000)); - data.push(BEL); - // Should not panic and should produce no event. - assert!(parse_events(&data).is_empty()); - } - - // -- Empty input ---------------------------------------------------------- - - #[test] - fn empty_input() { - assert!(parse_events(b"").is_empty()); - } - - #[test] - fn only_normal_text() { - let data = b"just some regular terminal output\r\n"; - assert!(parse_events(data).is_empty()); - } - - // -- Repeated prompts (empty command) ------------------------------------ - - #[test] - fn repeated_prompt_cycle() { - let mut parser = Parser::new(); - let mut events = Vec::new(); - - // User hits enter on an empty prompt twice. - let data = b"\x1b]133;A\x07$ \x1b]133;B\x07\x1b]133;D\x07\x1b]133;A\x07$ \x1b]133;B\x07"; - parser.push(data, |e| events.push(e)); - - assert_eq!( - events, - vec![ - Event::PromptStart, - Event::CommandStart, - Event::CommandFinished { exit_code: None }, - Event::PromptStart, - Event::CommandStart, - ] - ); - assert_eq!(parser.zone(), Zone::Input); - } - - // -- Byte-at-a-time feeding ----------------------------------------------- - - #[test] - fn byte_at_a_time() { - let data = b"\x1b]133;D;99\x07"; - let mut parser = Parser::new(); - let mut events = Vec::new(); - - for &byte in data { - parser.push(&[byte], |e| events.push(e)); - } - - assert_eq!( - events, - vec![Event::CommandFinished { - exit_code: Some(99) - }] - ); - } - - // -- Mixed terminators ---------------------------------------------------- - - #[test] - fn mixed_bel_and_st_terminators() { - let data = b"\x1b]133;A\x07\x1b]133;B\x1b\\\x1b]133;C\x07\x1b]133;D;1\x1b\\"; - let events = parse_events(data); - assert_eq!( - events, - vec![ - Event::PromptStart, - Event::CommandStart, - Event::CommandExecuted, - Event::CommandFinished { exit_code: Some(1) }, - ] - ); - } - - #[test] - fn detects_c1_st_terminator() { - let data = b"\x1b]133;A\x9c"; - assert_eq!(parse_events(data), vec![Event::PromptStart]); - } - - // -- Located event offsets ------------------------------------------------ - - #[test] - fn located_event_reports_offset_after_marker() { - let data = b"before\x1b]133;A\x07prompt"; - let mut parser = Parser::new(); - let mut events = Vec::new(); - - parser.push_located(data, |e| events.push(e)); - - assert_eq!( - events, - vec![LocatedEvent { - event: Event::PromptStart, - start_offset: b"before".len(), - offset: b"before\x1b]133;A\x07".len(), - zone: Zone::Prompt, - params: Params::default(), - }] - ); - } - - #[test] - fn located_event_offset_is_relative_to_completing_chunk() { - let mut parser = Parser::new(); - let mut events = Vec::new(); - - parser.push_located(b"\x1b]133;", |e| events.push(e)); - parser.push_located(b"D;42\x07after", |e| events.push(e)); - - assert_eq!( - events, - vec![LocatedEvent { - event: Event::CommandFinished { - exit_code: Some(42) - }, - start_offset: 0, - offset: b"D;42\x07".len(), - zone: Zone::Unknown, - params: Params::default(), - }] - ); - } - - #[test] - fn located_event_preserves_metadata_params() { - let mut parser = Parser::new(); - let mut events = Vec::new(); - - parser.push_located( - b"\x1b]133;D;127;history_id=018f;session_id=abcd;flag\x07", - |event| events.push(event), - ); - - assert_eq!(events.len(), 1); - let event = &events[0]; - assert_eq!( - event.event, - Event::CommandFinished { - exit_code: Some(127) - } - ); - assert_eq!(event.params.get("history_id"), Some("018f")); - assert_eq!(event.params.get("session_id"), Some("abcd")); - assert!( - event - .params - .iter() - .any(|param| param == &Param::Value("flag".to_string())) - ); - } - - #[test] - fn command_finished_metadata_without_exit_code_is_preserved() { - let mut parser = Parser::new(); - let mut events = Vec::new(); - - parser.push_located(b"\x1b]133;D;history_id=018f;session_id=abcd\x07", |event| { - events.push(event); - }); - - assert_eq!(events.len(), 1); - let event = &events[0]; - assert_eq!(event.event, Event::CommandFinished { exit_code: None }); - assert_eq!(event.params.get("history_id"), Some("018f")); - assert_eq!(event.params.get("session_id"), Some("abcd")); - } - - // -- Default trait -------------------------------------------------------- - - #[test] - fn parser_default() { - let parser = Parser::default(); - assert_eq!(parser.zone(), Zone::Unknown); - } - - #[test] - fn zone_default() { - assert_eq!(Zone::default(), Zone::Unknown); - } - - // -- D with empty exit code field ----------------------------------------- - - #[test] - fn d_with_semicolon_but_empty_code() { - // "133;D;" — semicolon present but no digits. - let data = b"\x1b]133;D;\x07"; - assert_eq!( - parse_events(data), - vec![Event::CommandFinished { exit_code: None }] - ); - } - - // -- Consecutive OSC sequences without gap -------------------------------- - - #[test] - fn back_to_back_osc_no_gap() { - let data = b"\x1b]133;A\x07\x1b]133;B\x07"; - let events = parse_events(data); - assert_eq!(events, vec![Event::PromptStart, Event::CommandStart]); - } - - // -- CSI sequences interleaved (should not confuse parser) ---------------- - - #[test] - fn csi_sequences_ignored() { - // CSI (ESC [) color codes mixed with OSC 133. - let data = b"\x1b[32m\x1b]133;A\x07\x1b[0m$ \x1b]133;B\x07"; - let events = parse_events(data); - assert_eq!(events, vec![Event::PromptStart, Event::CommandStart]); - } - - // -- Large exit codes ----------------------------------------------------- - - #[test] - fn large_exit_code() { - let data = b"\x1b]133;D;2147483647\x07"; - assert_eq!( - parse_events(data), - vec![Event::CommandFinished { - exit_code: Some(i32::MAX) - }] - ); - } - - #[test] - fn overflow_exit_code_yields_none() { - let data = b"\x1b]133;D;9999999999999\x07"; - assert_eq!( - parse_events(data), - vec![Event::CommandFinished { exit_code: None }] - ); - } -} diff --git a/crates/atuin-pty-proxy/src/pty_proxy.rs b/crates/atuin-pty-proxy/src/pty_proxy.rs deleted file mode 100644 index 19ccd274..00000000 --- a/crates/atuin-pty-proxy/src/pty_proxy.rs +++ /dev/null @@ -1,231 +0,0 @@ -use clap::{Args, Subcommand, ValueEnum}; - -use crate::{CommandCaptureSink, runtime}; - -#[derive(Args, Debug)] -pub struct PtyProxy { - /// Highlight OSC 133 prompt, input, output, and exit-code regions - #[arg(long)] - debug_osc133: bool, - - #[command(subcommand)] - cmd: Option, -} - -#[derive(Subcommand, Debug)] -pub enum Cmd { - /// Print shell code to initialize atuin pty-proxy on shell startup - Init(Init), -} - -#[derive(Args, Debug)] -pub struct Init { - /// Shell to generate init for. If omitted, attempt auto-detection - #[arg(value_enum)] - shell: Option, -} - -#[derive(Clone, Copy, Debug, Eq, PartialEq, ValueEnum)] -#[value(rename_all = "lower")] -#[expect(clippy::enum_variant_names, clippy::doc_markdown)] -enum Shell { - /// Zsh setup - Zsh, - /// Bash setup - Bash, - /// Fish setup - Fish, - /// Nu setup - Nu, -} - -pub(crate) struct RuntimeOptions { - pub(crate) debug_osc133: bool, - pub(crate) command_capture_sink: Option, -} - -impl RuntimeOptions { - fn new(debug_osc133: bool, command_capture_sink: Option) -> Self { - Self { - debug_osc133: debug_osc133 || env_flag("ATUIN_PTY_PROXY_DEBUG"), - command_capture_sink, - } - } -} - -impl PtyProxy { - pub fn run(self, command_capture_sink: Option) { - match self.cmd { - Some(Cmd::Init(init)) => { - if let Err(err) = init.run() { - eprintln!("atuin pty-proxy: {err}"); - std::process::exit(1); - } - } - None => runtime::main(RuntimeOptions::new(self.debug_osc133, command_capture_sink)), - } - } -} - -impl Init { - fn run(self) -> Result<(), String> { - let shell = detect_shell(self.shell)?; - let script = render_init(shell); - print!("{script}"); - Ok(()) - } -} - -fn detect_shell(cli_shell: Option) -> Result { - if let Some(shell) = cli_shell { - return Ok(shell); - } - - if let Ok(shell) = std::env::var("ATUIN_SHELL") - && let Some(shell) = shell_from_name(&shell) - { - return Ok(shell); - } - - if let Ok(shell) = std::env::var("SHELL") - && let Some(shell) = shell_from_name(&shell) - { - return Ok(shell); - } - - Err( - "could not detect a supported shell. Please specify one explicitly: bash, zsh, fish, or nu" - .to_string(), - ) -} - -fn shell_from_name(name: &str) -> Option { - let shell = name - .trim() - .rsplit('/') - .next() - .unwrap_or(name) - .trim_start_matches('-') - .to_ascii_lowercase(); - - match shell.as_str() { - "bash" => Some(Shell::Bash), - "zsh" => Some(Shell::Zsh), - "fish" => Some(Shell::Fish), - "nu" => Some(Shell::Nu), - _ => None, - } -} - -fn env_flag(name: &str) -> bool { - std::env::var(name).is_ok_and(|value| { - matches!( - value.trim().to_ascii_lowercase().as_str(), - "1" | "true" | "yes" | "on" - ) - }) -} - -fn render_init(shell: Shell) -> &'static str { - match shell { - Shell::Bash | Shell::Zsh => { - r#"if [[ "$-" == *i* ]] && [[ -t 0 ]] && [[ -t 1 ]]; then - _atuin_pty_proxy_tmux_current="${TMUX:-}" - _atuin_pty_proxy_tmux_previous="${ATUIN_PTY_PROXY_TMUX:-}" - - if [[ -z "${ATUIN_PTY_PROXY_ACTIVE:-}" ]] || [[ "$_atuin_pty_proxy_tmux_current" != "$_atuin_pty_proxy_tmux_previous" ]]; then - export ATUIN_PTY_PROXY_ACTIVE=1 - export ATUIN_PTY_PROXY_TMUX="$_atuin_pty_proxy_tmux_current" - exec atuin pty-proxy - fi - - unset _atuin_pty_proxy_tmux_current _atuin_pty_proxy_tmux_previous -fi -"# - } - Shell::Fish => { - r#"if status is-interactive; and test -t 0; and test -t 1 - set -l _atuin_pty_proxy_tmux_current "" - if set -q TMUX - set _atuin_pty_proxy_tmux_current "$TMUX" - end - - set -l _atuin_pty_proxy_tmux_previous "" - if set -q ATUIN_PTY_PROXY_TMUX - set _atuin_pty_proxy_tmux_previous "$ATUIN_PTY_PROXY_TMUX" - end - - if not set -q ATUIN_PTY_PROXY_ACTIVE - set -gx ATUIN_PTY_PROXY_ACTIVE 1 - set -gx ATUIN_PTY_PROXY_TMUX "$_atuin_pty_proxy_tmux_current" - exec atuin pty-proxy - else if test "$_atuin_pty_proxy_tmux_current" != "$_atuin_pty_proxy_tmux_previous" - set -gx ATUIN_PTY_PROXY_ACTIVE 1 - set -gx ATUIN_PTY_PROXY_TMUX "$_atuin_pty_proxy_tmux_current" - exec atuin pty-proxy - end -end -"# - } - // Nushell cannot dynamically source the output of `atuin init nu`, - // so we only output the pty-proxy preamble here. Users must also set up - // `atuin init nu` separately. - Shell::Nu => { - r#"if (is-terminal --stdin) and (is-terminal --stdout) { - let tmux_current = ($env.TMUX? | default "") - let tmux_previous = ($env.ATUIN_PTY_PROXY_TMUX? | default "") - - if (($env.ATUIN_PTY_PROXY_ACTIVE? | default "") | is-empty) or ($tmux_current != $tmux_previous) { - $env.ATUIN_PTY_PROXY_ACTIVE = "1" - $env.ATUIN_PTY_PROXY_TMUX = $tmux_current - exec atuin pty-proxy - } -} -"# - } - } -} - -#[cfg(test)] -mod tests { - use super::{Shell, render_init, shell_from_name}; - - #[test] - fn shell_from_name_handles_paths() { - assert_eq!(shell_from_name("/bin/zsh"), Some(Shell::Zsh)); - assert_eq!(shell_from_name("/usr/local/bin/bash"), Some(Shell::Bash)); - assert_eq!(shell_from_name("fish"), Some(Shell::Fish)); - assert_eq!(shell_from_name("nu"), Some(Shell::Nu)); - } - - #[test] - fn posix_init_uses_exec_and_tmux_guard() { - let script = render_init(Shell::Bash); - assert!(script.contains("exec atuin pty-proxy")); - assert!(script.contains("ATUIN_PTY_PROXY_TMUX")); - assert!(!script.contains("eval \"$(atuin init bash)\"")); - } - - #[test] - fn posix_init_has_no_double_braces() { - let script = render_init(Shell::Bash); - assert!(!script.contains("${{"), "double braces in bash init script"); - } - - #[test] - fn fish_init_uses_source() { - let script = render_init(Shell::Fish); - assert!(script.contains("exec atuin pty-proxy")); - assert!(!script.contains("atuin init fish | source")); - } - - #[test] - fn nu_init_uses_exec_and_tty_guard() { - let script = render_init(Shell::Nu); - assert!(script.contains("exec atuin pty-proxy")); - assert!(script.contains("ATUIN_PTY_PROXY_TMUX")); - assert!(script.contains("is-terminal --stdin")); - assert!(script.contains("is-terminal --stdout")); - assert!(script.contains("ATUIN_PTY_PROXY_ACTIVE")); - } -} diff --git a/crates/atuin-pty-proxy/src/runtime.rs b/crates/atuin-pty-proxy/src/runtime.rs deleted file mode 100644 index 2b34fbb7..00000000 --- a/crates/atuin-pty-proxy/src/runtime.rs +++ /dev/null @@ -1,184 +0,0 @@ -use std::io::{Read, Write}; -use std::sync::Arc; -use std::sync::atomic::{AtomicU16, Ordering}; -use std::sync::mpsc; - -use crossterm::terminal; -use portable_pty::{CommandBuilder, PtySize, native_pty_system}; - -use crate::capture::CommandCaptureTracker; -use crate::debug::{Osc133DebugHighlighter, RESET}; -use crate::pty_proxy::RuntimeOptions; -use crate::screen::{self, Msg}; - -pub(crate) fn main(options: RuntimeOptions) { - if let Err(e) = run(options) { - let _ = terminal::disable_raw_mode(); - eprintln!("atuin pty-proxy: {e:#}"); - std::process::exit(1); - } -} - -fn run(options: RuntimeOptions) -> eyre::Result<()> { - let (cols, rows) = terminal::size()?; - - let pty_system = native_pty_system(); - let pair = pty_system - .openpty(PtySize { - rows, - cols, - pixel_width: 0, - pixel_height: 0, - }) - .map_err(|e| eyre::eyre!("{e:#}"))?; - - let sock_path = screen::socket_path(); - let _ = std::fs::remove_file(&sock_path); - - let mut cmd = CommandBuilder::new_default_prog(); - cmd.cwd(std::env::current_dir()?); - cmd.env("ATUIN_PTY_PROXY_SOCKET", sock_path.as_os_str()); - cmd.env("ATUIN_PTY_PROXY_ACTIVE", "1"); - - let mut child = pair - .slave - .spawn_command(cmd) - .map_err(|e| eyre::eyre!("{e:#}"))?; - - drop(pair.slave); - - let mut pty_reader = pair - .master - .try_clone_reader() - .map_err(|e| eyre::eyre!("{e:#}"))?; - let mut pty_writer = pair - .master - .take_writer() - .map_err(|e| eyre::eyre!("{e:#}"))?; - - let (msg_tx, msg_rx) = mpsc::sync_channel::(64); - let current_cols = Arc::new(AtomicU16::new(cols.max(1))); - - screen::spawn_parser_thread(rows, cols, msg_rx); - screen::spawn_socket_server(sock_path.clone(), msg_tx.clone()); - spawn_resize_handler(pair.master, msg_tx.clone(), current_cols.clone())?; - - terminal::enable_raw_mode()?; - - let stdout_thread = std::thread::spawn(move || { - let mut stdout = std::io::stdout(); - let mut highlighter = options.debug_osc133.then(Osc133DebugHighlighter::new); - let mut capture_tracker = options - .command_capture_sink - .as_ref() - .map(|_| CommandCaptureTracker::new(current_cols)); - let mut buf = [0u8; 8192]; - - loop { - match pty_reader.read(&mut buf) { - Ok(0) | Err(_) => break, - Ok(n) => { - if let (Some(tracker), Some(sink)) = ( - capture_tracker.as_mut(), - options.command_capture_sink.as_ref(), - ) { - tracker.push(&buf[..n], sink); - } - - if let Some(highlighter) = highlighter.as_mut() { - let rendered = highlighter.render(&buf[..n]); - let _ = msg_tx.try_send(Msg::Data(rendered.clone())); - - if stdout.write_all(&rendered).is_err() { - break; - } - } else { - let _ = msg_tx.try_send(Msg::Data(buf[..n].to_vec())); - - if stdout.write_all(&buf[..n]).is_err() { - break; - } - } - let _ = stdout.flush(); - } - } - } - - if highlighter.is_some() { - let _ = stdout.write_all(RESET); - let _ = stdout.flush(); - } - }); - - std::thread::spawn(move || { - let mut stdin = std::io::stdin(); - let mut buf = [0u8; 8192]; - loop { - match stdin.read(&mut buf) { - Ok(0) | Err(_) => break, - Ok(n) => { - if pty_writer.write_all(&buf[..n]).is_err() { - break; - } - } - } - } - }); - - let status = child.wait()?; - let _ = stdout_thread.join(); - - let _ = terminal::disable_raw_mode(); - let _ = std::fs::remove_file(&sock_path); - - std::process::exit(process_exit_code(status.exit_code())); -} - -fn spawn_resize_handler( - master: Box, - resize_tx: mpsc::SyncSender, - current_cols: Arc, -) -> eyre::Result<()> { - use signal_hook::consts::SIGWINCH; - use signal_hook::iterator::Signals; - - let mut signals = Signals::new([SIGWINCH])?; - - std::thread::spawn(move || { - for _ in signals.forever() { - if let Ok((cols, rows)) = terminal::size() { - current_cols.store(cols.max(1), Ordering::Relaxed); - let _ = master.resize(PtySize { - rows, - cols, - pixel_width: 0, - pixel_height: 0, - }); - let _ = resize_tx.try_send(Msg::Resize { rows, cols }); - } - } - }); - - Ok(()) -} - -fn process_exit_code(code: u32) -> i32 { - i32::try_from(code).unwrap_or(1) -} - -#[cfg(test)] -mod tests { - use super::process_exit_code; - - #[test] - fn process_exit_code_preserves_valid_values() { - assert_eq!(process_exit_code(0), 0); - assert_eq!(process_exit_code(127), 127); - assert_eq!(process_exit_code(i32::MAX as u32), i32::MAX); - } - - #[test] - fn process_exit_code_defaults_when_out_of_range() { - assert_eq!(process_exit_code(i32::MAX as u32 + 1), 1); - } -} diff --git a/crates/atuin-pty-proxy/src/screen.rs b/crates/atuin-pty-proxy/src/screen.rs deleted file mode 100644 index 5b892e21..00000000 --- a/crates/atuin-pty-proxy/src/screen.rs +++ /dev/null @@ -1,104 +0,0 @@ -use std::io::Write; -use std::os::unix::net::UnixListener; -use std::path::PathBuf; -use std::sync::mpsc::{self, Receiver, SyncSender}; - -pub(crate) enum Msg { - Data(Vec), - Resize { rows: u16, cols: u16 }, - ScreenRequest(mpsc::Sender>), -} - -pub(crate) fn socket_path() -> PathBuf { - let dir = std::env::temp_dir(); - dir.join(format!("atuin-pty-proxy-{}.sock", std::process::id())) -} - -pub(crate) fn spawn_parser_thread(rows: u16, cols: u16, msg_rx: Receiver) { - std::thread::spawn(move || { - let mut parser = vt100::Parser::new(rows, cols, 0); - - loop { - let first = match msg_rx.recv() { - Ok(msg) => msg, - Err(_) => break, - }; - - handle_parser_msg(&mut parser, first); - - while let Ok(msg) = msg_rx.try_recv() { - handle_parser_msg(&mut parser, msg); - } - } - }); -} - -pub(crate) fn spawn_socket_server(sock_path: PathBuf, screen_tx: SyncSender) { - std::thread::spawn(move || { - let listener = match UnixListener::bind(&sock_path) { - Ok(l) => l, - Err(e) => { - eprintln!("atuin pty-proxy: failed to bind socket: {e}"); - return; - } - }; - - for stream in listener.incoming() { - let mut stream = match stream { - Ok(s) => s, - Err(_) => break, - }; - - let (reply_tx, reply_rx) = mpsc::channel(); - if screen_tx.send(Msg::ScreenRequest(reply_tx)).is_err() { - break; - } - if let Ok(data) = reply_rx.recv() { - let _ = stream.write_all(&data); - let _ = stream.flush(); - } - } - }); -} - -/// Wire format written to the Unix socket: -/// -/// ```text -/// [rows: u16 BE][cols: u16 BE][cursor_row: u16 BE][cursor_col: u16 BE] -/// [row_0_len: u32 BE][row_0_bytes...] -/// [row_1_len: u32 BE][row_1_bytes...] -/// ... -/// ``` -/// -/// Each row's bytes come from `screen.rows_formatted(0, cols)` and contain -/// pre-built ANSI escape sequences. The client can write them directly to -/// stdout without needing its own vt100 parser. -fn encode_screen(parser: &vt100::Parser) -> Vec { - let screen = parser.screen(); - let (rows, cols) = screen.size(); - let (cursor_row, cursor_col) = screen.cursor_position(); - - let mut buf: Vec = Vec::with_capacity(256 + (rows as usize * cols as usize)); - buf.extend_from_slice(&rows.to_be_bytes()); - buf.extend_from_slice(&cols.to_be_bytes()); - buf.extend_from_slice(&cursor_row.to_be_bytes()); - buf.extend_from_slice(&cursor_col.to_be_bytes()); - - for row_bytes in screen.rows_formatted(0, cols) { - let len = row_bytes.len() as u32; - buf.extend_from_slice(&len.to_be_bytes()); - buf.extend_from_slice(&row_bytes); - } - - buf -} - -fn handle_parser_msg(parser: &mut vt100::Parser, msg: Msg) { - match msg { - Msg::Data(data) => parser.process(&data), - Msg::Resize { rows, cols } => parser.screen_mut().set_size(rows, cols), - Msg::ScreenRequest(reply_tx) => { - let _ = reply_tx.send(encode_screen(parser)); - } - } -} diff --git a/crates/atuin-server-database/Cargo.toml b/crates/atuin-server-database/Cargo.toml deleted file mode 100644 index 52ccbf97..00000000 --- a/crates/atuin-server-database/Cargo.toml +++ /dev/null @@ -1,21 +0,0 @@ -[package] -name = "atuin-server-database" -edition = "2024" -description = "server database library for atuin" - -version = { workspace = true } -authors = { workspace = true } -license = { workspace = true } -homepage = { workspace = true } -repository = { workspace = true } - -[dependencies] -atuin-common = { path = "../atuin-common", version = "18.16.1" } - -async-trait = { workspace = true } -eyre = { workspace = true } -serde = { workspace = true } -sqlx = { workspace = true } -time = { workspace = true } -tracing = { workspace = true } -url = "2.5.2" diff --git a/crates/atuin-server-database/src/calendar.rs b/crates/atuin-server-database/src/calendar.rs deleted file mode 100644 index 2229667b..00000000 --- a/crates/atuin-server-database/src/calendar.rs +++ /dev/null @@ -1,18 +0,0 @@ -// Calendar data - -use serde::{Deserialize, Serialize}; -use time::Month; - -pub enum TimePeriod { - Year, - Month { year: i32 }, - Day { year: i32, month: Month }, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct TimePeriodInfo { - pub count: u64, - - // TODO: Use this for merkle tree magic - pub hash: String, -} diff --git a/crates/atuin-server-database/src/lib.rs b/crates/atuin-server-database/src/lib.rs deleted file mode 100644 index 5437fc15..00000000 --- a/crates/atuin-server-database/src/lib.rs +++ /dev/null @@ -1,268 +0,0 @@ -#![forbid(unsafe_code)] - -pub mod calendar; -pub mod models; - -use std::{ - collections::HashMap, - fmt::{Debug, Display}, - ops::Range, -}; - -use self::{ - calendar::{TimePeriod, TimePeriodInfo}, - models::{History, NewHistory, NewSession, NewUser, Session, User}, -}; -use async_trait::async_trait; -use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}; -use serde::{Deserialize, Serialize}; -use time::{Date, Duration, Month, OffsetDateTime, PrimitiveDateTime, Time, UtcOffset}; -use tracing::instrument; - -#[derive(Debug)] -pub enum DbError { - NotFound, - Other(eyre::Report), -} - -impl Display for DbError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{self:?}") - } -} - -impl From for DbError { - fn from(error: time::error::ComponentRange) -> Self { - DbError::Other(error.into()) - } -} - -impl From for DbError { - fn from(error: time::error::Error) -> Self { - DbError::Other(error.into()) - } -} - -impl From for DbError { - fn from(error: sqlx::Error) -> Self { - match error { - sqlx::Error::RowNotFound => DbError::NotFound, - error => DbError::Other(error.into()), - } - } -} - -impl std::error::Error for DbError {} - -pub type DbResult = Result; - -#[derive(Debug, PartialEq)] -pub enum DbType { - Postgres, - Sqlite, - Unknown, -} - -#[derive(Clone, Deserialize, Serialize)] -pub struct DbSettings { - pub db_uri: String, - /// Optional URI for read replicas. If set, read-only queries will use this connection. - pub read_db_uri: Option, -} - -impl DbSettings { - pub fn db_type(&self) -> DbType { - if self.db_uri.starts_with("postgres://") || self.db_uri.starts_with("postgresql://") { - DbType::Postgres - } else if self.db_uri.starts_with("sqlite://") { - DbType::Sqlite - } else { - DbType::Unknown - } - } -} - -fn redact_db_uri(uri: &str) -> String { - url::Url::parse(uri) - .map(|mut url| { - let _ = url.set_password(Some("****")); - url.to_string() - }) - .unwrap_or_else(|_| uri.to_string()) -} - -// Do our best to redact passwords so they're not logged in the event of an error. -impl Debug for DbSettings { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.db_type() == DbType::Postgres { - let redacted_uri = redact_db_uri(&self.db_uri); - let redacted_read_uri = self.read_db_uri.as_ref().map(|uri| redact_db_uri(uri)); - f.debug_struct("DbSettings") - .field("db_uri", &redacted_uri) - .field("read_db_uri", &redacted_read_uri) - .finish() - } else { - f.debug_struct("DbSettings") - .field("db_uri", &self.db_uri) - .field("read_db_uri", &self.read_db_uri) - .finish() - } - } -} - -#[async_trait] -pub trait Database: Sized + Clone + Send + Sync + 'static { - async fn new(settings: &DbSettings) -> DbResult; - - async fn get_session(&self, token: &str) -> DbResult; - async fn get_session_user(&self, token: &str) -> DbResult; - async fn add_session(&self, session: &NewSession) -> DbResult<()>; - - async fn get_user(&self, username: &str) -> DbResult; - async fn get_user_session(&self, u: &User) -> DbResult; - async fn add_user(&self, user: &NewUser) -> DbResult; - - async fn update_user_password(&self, u: &User) -> DbResult<()>; - - async fn count_history(&self, user: &User) -> DbResult; - async fn count_history_cached(&self, user: &User) -> DbResult; - - async fn delete_user(&self, u: &User) -> DbResult<()>; - async fn delete_history(&self, user: &User, id: String) -> DbResult<()>; - async fn deleted_history(&self, user: &User) -> DbResult>; - async fn delete_store(&self, user: &User) -> DbResult<()>; - - async fn add_records(&self, user: &User, record: &[Record]) -> DbResult<()>; - async fn next_records( - &self, - user: &User, - host: HostId, - tag: String, - start: Option, - count: u64, - ) -> DbResult>>; - - // Return the tail record ID for each store, so (HostID, Tag, TailRecordID) - async fn status(&self, user: &User) -> DbResult; - - async fn count_history_range(&self, user: &User, range: Range) - -> DbResult; - - async fn list_history( - &self, - user: &User, - created_after: OffsetDateTime, - since: OffsetDateTime, - host: &str, - page_size: i64, - ) -> DbResult>; - - async fn add_history(&self, history: &[NewHistory]) -> DbResult<()>; - - async fn oldest_history(&self, user: &User) -> DbResult; - - #[instrument(skip_all)] - async fn calendar( - &self, - user: &User, - period: TimePeriod, - tz: UtcOffset, - ) -> DbResult> { - let mut ret = HashMap::new(); - let iter: Box)>> + Send> = match period { - TimePeriod::Year => { - // First we need to work out how far back to calculate. Get the - // oldest history item - let oldest = self - .oldest_history(user) - .await? - .timestamp - .to_offset(tz) - .year(); - let current_year = OffsetDateTime::now_utc().to_offset(tz).year(); - - // All the years we need to get data for - // The upper bound is exclusive, so include current +1 - let years = oldest..current_year + 1; - - Box::new(years.map(|year| { - let start = Date::from_calendar_date(year, time::Month::January, 1)?; - let end = Date::from_calendar_date(year + 1, time::Month::January, 1)?; - - Ok((year as u64, start..end)) - })) - } - - TimePeriod::Month { year } => { - let months = - std::iter::successors(Some(Month::January), |m| Some(m.next())).take(12); - - Box::new(months.map(move |month| { - let start = Date::from_calendar_date(year, month, 1)?; - let days = start.month().length(year); - let end = start + Duration::days(days as i64); - - Ok((month as u64, start..end)) - })) - } - - TimePeriod::Day { year, month } => { - let days = 1..month.length(year); - Box::new(days.map(move |day| { - let start = Date::from_calendar_date(year, month, day)?; - let end = start - .next_day() - .ok_or_else(|| DbError::Other(eyre::eyre!("no next day?")))?; - - Ok((day as u64, start..end)) - })) - } - }; - - for x in iter { - let (index, range) = x?; - - let start = range.start.with_time(Time::MIDNIGHT).assume_offset(tz); - let end = range.end.with_time(Time::MIDNIGHT).assume_offset(tz); - - let count = self.count_history_range(user, start..end).await?; - - ret.insert( - index, - TimePeriodInfo { - count: count as u64, - hash: "".to_string(), - }, - ); - } - - Ok(ret) - } -} - -pub fn into_utc(x: OffsetDateTime) -> PrimitiveDateTime { - let x = x.to_offset(UtcOffset::UTC); - PrimitiveDateTime::new(x.date(), x.time()) -} - -#[cfg(test)] -mod tests { - use time::macros::datetime; - - use crate::into_utc; - - #[test] - fn utc() { - let dt = datetime!(2023-09-26 15:11:02 +05:30); - assert_eq!(into_utc(dt), datetime!(2023-09-26 09:41:02)); - assert_eq!(into_utc(dt).assume_utc(), dt); - - let dt = datetime!(2023-09-26 15:11:02 -07:00); - assert_eq!(into_utc(dt), datetime!(2023-09-26 22:11:02)); - assert_eq!(into_utc(dt).assume_utc(), dt); - - let dt = datetime!(2023-09-26 15:11:02 +00:00); - assert_eq!(into_utc(dt), datetime!(2023-09-26 15:11:02)); - assert_eq!(into_utc(dt).assume_utc(), dt); - } -} diff --git a/crates/atuin-server-database/src/models.rs b/crates/atuin-server-database/src/models.rs deleted file mode 100644 index b71a9bc9..00000000 --- a/crates/atuin-server-database/src/models.rs +++ /dev/null @@ -1,52 +0,0 @@ -use time::OffsetDateTime; - -pub struct History { - pub id: i64, - pub client_id: String, // a client generated ID - pub user_id: i64, - pub hostname: String, - pub timestamp: OffsetDateTime, - - /// All the data we have about this command, encrypted. - /// - /// Currently this is an encrypted msgpack object, but this may change in the future. - pub data: String, - - pub created_at: OffsetDateTime, -} - -pub struct NewHistory { - pub client_id: String, - pub user_id: i64, - pub hostname: String, - pub timestamp: OffsetDateTime, - - /// All the data we have about this command, encrypted. - /// - /// Currently this is an encrypted msgpack object, but this may change in the future. - pub data: String, -} - -pub struct User { - pub id: i64, - pub username: String, - pub email: String, - pub password: String, -} - -pub struct Session { - pub id: i64, - pub user_id: i64, - pub token: String, -} - -pub struct NewUser { - pub username: String, - pub email: String, - pub password: String, -} - -pub struct NewSession { - pub user_id: i64, - pub token: String, -} diff --git a/crates/atuin-server-postgres/Cargo.toml b/crates/atuin-server-postgres/Cargo.toml deleted file mode 100644 index ea19899e..00000000 --- a/crates/atuin-server-postgres/Cargo.toml +++ /dev/null @@ -1,25 +0,0 @@ -[package] -name = "atuin-server-postgres" -edition = "2024" -description = "server postgres database library for atuin" - -version = { workspace = true } -authors = { workspace = true } -license = { workspace = true } -homepage = { workspace = true } -repository = { workspace = true } - -[dependencies] -atuin-common = { path = "../atuin-common", version = "18.16.1" } -atuin-server-database = { path = "../atuin-server-database", version = "18.16.1" } - -eyre = { workspace = true } -tracing = { workspace = true } -time = { workspace = true } -serde = { workspace = true } -sqlx = { workspace = true } -async-trait = { workspace = true } -uuid = { workspace = true } -metrics = "0.24" -futures-util = "0.3" -rand.workspace = true \ No newline at end of file diff --git a/crates/atuin-server-postgres/build.rs b/crates/atuin-server-postgres/build.rs deleted file mode 100644 index d5068697..00000000 --- a/crates/atuin-server-postgres/build.rs +++ /dev/null @@ -1,5 +0,0 @@ -// generated by `sqlx migrate build-script` -fn main() { - // trigger recompilation when a new migration is added - println!("cargo:rerun-if-changed=migrations"); -} diff --git a/crates/atuin-server-postgres/migrations/20210425153745_create_history.sql b/crates/atuin-server-postgres/migrations/20210425153745_create_history.sql deleted file mode 100644 index 2c2d17b0..00000000 --- a/crates/atuin-server-postgres/migrations/20210425153745_create_history.sql +++ /dev/null @@ -1,11 +0,0 @@ -create table history ( - id bigserial primary key, - client_id text not null unique, -- the client-generated ID - user_id bigserial not null, -- allow multiple users - hostname text not null, -- a unique identifier from the client (can be hashed, random, whatever) - timestamp timestamp not null, -- one of the few non-encrypted metadatas - - data varchar(8192) not null, -- store the actual history data, encrypted. I don't wanna know! - - created_at timestamp not null default current_timestamp -); diff --git a/crates/atuin-server-postgres/migrations/20210425153757_create_users.sql b/crates/atuin-server-postgres/migrations/20210425153757_create_users.sql deleted file mode 100644 index a25dcced..00000000 --- a/crates/atuin-server-postgres/migrations/20210425153757_create_users.sql +++ /dev/null @@ -1,10 +0,0 @@ -create table users ( - id bigserial primary key, -- also store our own ID - username varchar(32) not null unique, -- being able to contact users is useful - email varchar(128) not null unique, -- being able to contact users is useful - password varchar(128) not null unique -); - --- the prior index is case sensitive :( -CREATE UNIQUE INDEX email_unique_idx on users (LOWER(email)); -CREATE UNIQUE INDEX username_unique_idx on users (LOWER(username)); diff --git a/crates/atuin-server-postgres/migrations/20210425153800_create_sessions.sql b/crates/atuin-server-postgres/migrations/20210425153800_create_sessions.sql deleted file mode 100644 index c2fb6559..00000000 --- a/crates/atuin-server-postgres/migrations/20210425153800_create_sessions.sql +++ /dev/null @@ -1,6 +0,0 @@ --- Add migration script here -create table sessions ( - id bigserial primary key, - user_id bigserial, - token varchar(128) unique not null -); diff --git a/crates/atuin-server-postgres/migrations/20220419082412_add_count_trigger.sql b/crates/atuin-server-postgres/migrations/20220419082412_add_count_trigger.sql deleted file mode 100644 index dd1afa88..00000000 --- a/crates/atuin-server-postgres/migrations/20220419082412_add_count_trigger.sql +++ /dev/null @@ -1,51 +0,0 @@ --- Prior to this, the count endpoint was super naive and just ran COUNT(1). --- This is slow asf. Now that we have an amount of actual traffic, --- stop doing that! --- This basically maintains a count, so we can read ONE row, instead of ALL the --- rows. Much better. --- Future optimisation could use some sort of cache so we don't even need to hit --- postgres at all. - -create table total_history_count_user( - id bigserial primary key, - user_id bigserial, - total integer -- try and avoid using keywords - hence total, not count -); - -create or replace function user_history_count() -returns trigger as -$func$ -begin - if (TG_OP='INSERT') then - update total_history_count_user set total = total + 1 where user_id = new.user_id; - - if not found then - insert into total_history_count_user(user_id, total) - values ( - new.user_id, - (select count(1) from history where user_id = new.user_id) - ); - end if; - - elsif (TG_OP='DELETE') then - update total_history_count_user set total = total - 1 where user_id = new.user_id; - - if not found then - insert into total_history_count_user(user_id, total) - values ( - new.user_id, - (select count(1) from history where user_id = new.user_id) - ); - end if; - end if; - - return NEW; -- this is actually ignored for an after trigger, but oh well -end; -$func$ -language plpgsql volatile -- pldfplplpflh -cost 100; -- default value - -create trigger tg_user_history_count - after insert or delete on history - for each row - execute procedure user_history_count(); diff --git a/crates/atuin-server-postgres/migrations/20220421073605_fix_count_trigger_delete.sql b/crates/atuin-server-postgres/migrations/20220421073605_fix_count_trigger_delete.sql deleted file mode 100644 index 6198f300..00000000 --- a/crates/atuin-server-postgres/migrations/20220421073605_fix_count_trigger_delete.sql +++ /dev/null @@ -1,35 +0,0 @@ --- the old version of this function used NEW in the delete part when it should --- use OLD - -create or replace function user_history_count() -returns trigger as -$func$ -begin - if (TG_OP='INSERT') then - update total_history_count_user set total = total + 1 where user_id = new.user_id; - - if not found then - insert into total_history_count_user(user_id, total) - values ( - new.user_id, - (select count(1) from history where user_id = new.user_id) - ); - end if; - - elsif (TG_OP='DELETE') then - update total_history_count_user set total = total - 1 where user_id = old.user_id; - - if not found then - insert into total_history_count_user(user_id, total) - values ( - old.user_id, - (select count(1) from history where user_id = old.user_id) - ); - end if; - end if; - - return NEW; -- this is actually ignored for an after trigger, but oh well -end; -$func$ -language plpgsql volatile -- pldfplplpflh -cost 100; -- default value diff --git a/crates/atuin-server-postgres/migrations/20220421174016_larger-commands.sql b/crates/atuin-server-postgres/migrations/20220421174016_larger-commands.sql deleted file mode 100644 index 0ac43433..00000000 --- a/crates/atuin-server-postgres/migrations/20220421174016_larger-commands.sql +++ /dev/null @@ -1,3 +0,0 @@ --- Make it 4x larger. Most commands are less than this, but as it's base64 --- SOME are more than 8192. Should be enough for now. -ALTER TABLE history ALTER COLUMN data TYPE varchar(32768); diff --git a/crates/atuin-server-postgres/migrations/20220426172813_user-created-at.sql b/crates/atuin-server-postgres/migrations/20220426172813_user-created-at.sql deleted file mode 100644 index a9138194..00000000 --- a/crates/atuin-server-postgres/migrations/20220426172813_user-created-at.sql +++ /dev/null @@ -1 +0,0 @@ -alter table users add column created_at timestamp not null default now(); diff --git a/crates/atuin-server-postgres/migrations/20220505082442_create-events.sql b/crates/atuin-server-postgres/migrations/20220505082442_create-events.sql deleted file mode 100644 index 57e16ec7..00000000 --- a/crates/atuin-server-postgres/migrations/20220505082442_create-events.sql +++ /dev/null @@ -1,14 +0,0 @@ -create type event_type as enum ('create', 'delete'); - -create table events ( - id bigserial primary key, - client_id text not null unique, -- the client-generated ID - user_id bigserial not null, -- allow multiple users - hostname text not null, -- a unique identifier from the client (can be hashed, random, whatever) - timestamp timestamp not null, -- one of the few non-encrypted metadatas - - event_type event_type, - data text not null, -- store the actual history data, encrypted. I don't wanna know! - - created_at timestamp not null default current_timestamp -); diff --git a/crates/atuin-server-postgres/migrations/20220610074049_history-length.sql b/crates/atuin-server-postgres/migrations/20220610074049_history-length.sql deleted file mode 100644 index b1c23016..00000000 --- a/crates/atuin-server-postgres/migrations/20220610074049_history-length.sql +++ /dev/null @@ -1,2 +0,0 @@ --- Add migration script here -alter table history alter column data type text; diff --git a/crates/atuin-server-postgres/migrations/20230315220537_drop-events.sql b/crates/atuin-server-postgres/migrations/20230315220537_drop-events.sql deleted file mode 100644 index fe3cae17..00000000 --- a/crates/atuin-server-postgres/migrations/20230315220537_drop-events.sql +++ /dev/null @@ -1,2 +0,0 @@ --- Add migration script here -drop table events; diff --git a/crates/atuin-server-postgres/migrations/20230315224203_create-deleted.sql b/crates/atuin-server-postgres/migrations/20230315224203_create-deleted.sql deleted file mode 100644 index 9a9e6263..00000000 --- a/crates/atuin-server-postgres/migrations/20230315224203_create-deleted.sql +++ /dev/null @@ -1,5 +0,0 @@ --- Add migration script here -alter table history add column if not exists deleted_at timestamp; - --- queries will all be selecting the ids of history for a user, that has been deleted -create index if not exists history_deleted_index on history(client_id, user_id, deleted_at); diff --git a/crates/atuin-server-postgres/migrations/20230515221038_trigger-delete-only.sql b/crates/atuin-server-postgres/migrations/20230515221038_trigger-delete-only.sql deleted file mode 100644 index 3d0bba52..00000000 --- a/crates/atuin-server-postgres/migrations/20230515221038_trigger-delete-only.sql +++ /dev/null @@ -1,30 +0,0 @@ --- We do not need to run the trigger on deletes, as the only time we are deleting history is when the user --- has already been deleted --- This actually slows down deleting all the history a good bit! - -create or replace function user_history_count() -returns trigger as -$func$ -begin - if (TG_OP='INSERT') then - update total_history_count_user set total = total + 1 where user_id = new.user_id; - - if not found then - insert into total_history_count_user(user_id, total) - values ( - new.user_id, - (select count(1) from history where user_id = new.user_id) - ); - end if; - end if; - - return NEW; -- this is actually ignored for an after trigger, but oh well -end; -$func$ -language plpgsql volatile -- pldfplplpflh -cost 100; -- default value - -create or replace trigger tg_user_history_count - after insert on history - for each row - execute procedure user_history_count(); diff --git a/crates/atuin-server-postgres/migrations/20230623070418_records.sql b/crates/atuin-server-postgres/migrations/20230623070418_records.sql deleted file mode 100644 index 22437595..00000000 --- a/crates/atuin-server-postgres/migrations/20230623070418_records.sql +++ /dev/null @@ -1,15 +0,0 @@ --- Add migration script here -create table records ( - id uuid primary key, -- remember to use uuidv7 for happy indices <3 - client_id uuid not null, -- I am too uncomfortable with the idea of a client-generated primary key - host uuid not null, -- a unique identifier for the host - parent uuid default null, -- the ID of the parent record, bearing in mind this is a linked list - timestamp bigint not null, -- not a timestamp type, as those do not have nanosecond precision - version text not null, - tag text not null, -- what is this? history, kv, whatever. Remember clients get a log per tag per host - data text not null, -- store the actual history data, encrypted. I don't wanna know! - cek text not null, - - user_id bigint not null, -- allow multiple users - created_at timestamp not null default current_timestamp -); diff --git a/crates/atuin-server-postgres/migrations/20231202170508_create-store.sql b/crates/atuin-server-postgres/migrations/20231202170508_create-store.sql deleted file mode 100644 index ffb57966..00000000 --- a/crates/atuin-server-postgres/migrations/20231202170508_create-store.sql +++ /dev/null @@ -1,15 +0,0 @@ --- Add migration script here -create table store ( - id uuid primary key, -- remember to use uuidv7 for happy indices <3 - client_id uuid not null, -- I am too uncomfortable with the idea of a client-generated primary key, even though it's fine mathematically - host uuid not null, -- a unique identifier for the host - idx bigint not null, -- the index of the record in this store, identified by (host, tag) - timestamp bigint not null, -- not a timestamp type, as those do not have nanosecond precision - version text not null, - tag text not null, -- what is this? history, kv, whatever. Remember clients get a log per tag per host - data text not null, -- store the actual history data, encrypted. I don't wanna know! - cek text not null, - - user_id bigint not null, -- allow multiple users - created_at timestamp not null default current_timestamp -); diff --git a/crates/atuin-server-postgres/migrations/20231203124112_create-store-idx.sql b/crates/atuin-server-postgres/migrations/20231203124112_create-store-idx.sql deleted file mode 100644 index 56d67145..00000000 --- a/crates/atuin-server-postgres/migrations/20231203124112_create-store-idx.sql +++ /dev/null @@ -1,2 +0,0 @@ --- Add migration script here -create unique index record_uniq ON store(user_id, host, tag, idx); diff --git a/crates/atuin-server-postgres/migrations/20240108124837_drop-some-defaults.sql b/crates/atuin-server-postgres/migrations/20240108124837_drop-some-defaults.sql deleted file mode 100644 index ad2af5a1..00000000 --- a/crates/atuin-server-postgres/migrations/20240108124837_drop-some-defaults.sql +++ /dev/null @@ -1,4 +0,0 @@ --- Add migration script here -alter table history alter column user_id drop default; -alter table sessions alter column user_id drop default; -alter table total_history_count_user alter column user_id drop default; diff --git a/crates/atuin-server-postgres/migrations/20240614104159_idx-cache.sql b/crates/atuin-server-postgres/migrations/20240614104159_idx-cache.sql deleted file mode 100644 index 76425ed7..00000000 --- a/crates/atuin-server-postgres/migrations/20240614104159_idx-cache.sql +++ /dev/null @@ -1,8 +0,0 @@ -create table store_idx_cache( - id bigserial primary key, - user_id bigint, - - host uuid, - tag text, - idx bigint -); diff --git a/crates/atuin-server-postgres/migrations/20240621110731_user-verified.sql b/crates/atuin-server-postgres/migrations/20240621110731_user-verified.sql deleted file mode 100644 index 6eba02ec..00000000 --- a/crates/atuin-server-postgres/migrations/20240621110731_user-verified.sql +++ /dev/null @@ -1,8 +0,0 @@ -alter table users add verified_at timestamp with time zone default null; - -create table user_verification_token( - id bigserial primary key, - user_id bigint unique references users(id), - token text, - valid_until timestamp with time zone -); diff --git a/crates/atuin-server-postgres/migrations/20240702094825_idx_cache_index.sql b/crates/atuin-server-postgres/migrations/20240702094825_idx_cache_index.sql deleted file mode 100644 index d1a7b194..00000000 --- a/crates/atuin-server-postgres/migrations/20240702094825_idx_cache_index.sql +++ /dev/null @@ -1 +0,0 @@ -create unique index store_idx_cache_uniq on store_idx_cache(user_id, host, tag); diff --git a/crates/atuin-server-postgres/migrations/20260127000000_remove-email-verification.sql b/crates/atuin-server-postgres/migrations/20260127000000_remove-email-verification.sql deleted file mode 100644 index 15309920..00000000 --- a/crates/atuin-server-postgres/migrations/20260127000000_remove-email-verification.sql +++ /dev/null @@ -1,2 +0,0 @@ -drop table if exists user_verification_token; -alter table users drop column if exists verified_at; diff --git a/crates/atuin-server-postgres/src/lib.rs b/crates/atuin-server-postgres/src/lib.rs deleted file mode 100644 index 2e69c7f2..00000000 --- a/crates/atuin-server-postgres/src/lib.rs +++ /dev/null @@ -1,581 +0,0 @@ -use std::collections::HashMap; -use std::ops::Range; - -use rand::Rng; - -use async_trait::async_trait; -use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}; -use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User}; -use atuin_server_database::{Database, DbError, DbResult, DbSettings, into_utc}; -use futures_util::TryStreamExt; -use sqlx::Row; -use sqlx::postgres::PgPoolOptions; - -use time::OffsetDateTime; -use tracing::instrument; -use uuid::Uuid; -use wrappers::{DbHistory, DbRecord, DbSession, DbUser}; - -mod wrappers; - -const MIN_PG_VERSION: u32 = 14; - -#[derive(Clone)] -pub struct Postgres { - pool: sqlx::Pool, - /// Optional read replica pool for read-only queries - read_pool: Option>, -} - -impl Postgres { - /// Returns the appropriate pool for read operations. - /// Uses read_pool if available, otherwise falls back to the primary pool. - fn read_pool(&self) -> &sqlx::Pool { - self.read_pool.as_ref().unwrap_or(&self.pool) - } -} - -#[async_trait] -impl Database for Postgres { - async fn new(settings: &DbSettings) -> DbResult { - let pool = PgPoolOptions::new() - .max_connections(100) - .connect(settings.db_uri.as_str()) - .await?; - - // Call server_version_num to get the DB server's major version number - // The call returns None for servers older than 8.x. - let pg_major_version: u32 = - pool.acquire() - .await? - .server_version_num() - .ok_or(DbError::Other(eyre::Report::msg( - "could not get PostgreSQL version", - )))? - / 10000; - - if pg_major_version < MIN_PG_VERSION { - return Err(DbError::Other(eyre::Report::msg(format!( - "unsupported PostgreSQL version {pg_major_version}, minimum required is {MIN_PG_VERSION}" - )))); - } - - sqlx::migrate!("./migrations") - .run(&pool) - .await - .map_err(|error| DbError::Other(error.into()))?; - - // Create read replica pool if configured - let read_pool = if let Some(read_db_uri) = &settings.read_db_uri { - tracing::info!("Connecting to read replica database"); - let read_pool = PgPoolOptions::new() - .max_connections(100) - .connect(read_db_uri.as_str()) - .await?; - - // Verify the read replica is also a supported PostgreSQL version - let read_pg_major_version: u32 = read_pool - .acquire() - .await? - .server_version_num() - .ok_or(DbError::Other(eyre::Report::msg( - "could not get PostgreSQL version from read replica", - )))? - / 10000; - - if read_pg_major_version < MIN_PG_VERSION { - return Err(DbError::Other(eyre::Report::msg(format!( - "unsupported PostgreSQL version {read_pg_major_version} on read replica, minimum required is {MIN_PG_VERSION}" - )))); - } - - Some(read_pool) - } else { - None - }; - - Ok(Self { pool, read_pool }) - } - - #[instrument(skip_all)] - async fn get_session(&self, token: &str) -> DbResult { - sqlx::query_as("select id, user_id, token from sessions where token = $1") - .bind(token) - .fetch_one(self.read_pool()) - .await - .map_err(Into::into) - .map(|DbSession(session)| session) - } - - #[instrument(skip_all)] - async fn get_user(&self, username: &str) -> DbResult { - sqlx::query_as("select id, username, email, password from users where username = $1") - .bind(username) - .fetch_one(self.read_pool()) - .await - .map_err(Into::into) - .map(|DbUser(user)| user) - } - - #[instrument(skip_all)] - async fn get_session_user(&self, token: &str) -> DbResult { - sqlx::query_as( - "select users.id, users.username, users.email, users.password from users - inner join sessions - on users.id = sessions.user_id - and sessions.token = $1", - ) - .bind(token) - .fetch_one(self.read_pool()) - .await - .map_err(Into::into) - .map(|DbUser(user)| user) - } - - #[instrument(skip_all)] - async fn count_history(&self, user: &User) -> DbResult { - // The cache is new, and the user might not yet have a cache value. - // They will have one as soon as they post up some new history, but handle that - // edge case. - - let res: (i64,) = sqlx::query_as( - "select count(1) from history - where user_id = $1", - ) - .bind(user.id) - .fetch_one(self.read_pool()) - .await?; - - Ok(res.0) - } - - #[instrument(skip_all)] - async fn count_history_cached(&self, user: &User) -> DbResult { - let res: (i32,) = sqlx::query_as( - "select total from total_history_count_user - where user_id = $1", - ) - .bind(user.id) - .fetch_one(self.read_pool()) - .await?; - - Ok(res.0 as i64) - } - - async fn delete_store(&self, user: &User) -> DbResult<()> { - let mut tx = self.pool.begin().await?; - - sqlx::query( - "delete from store - where user_id = $1", - ) - .bind(user.id) - .execute(&mut *tx) - .await?; - - sqlx::query( - "delete from store_idx_cache - where user_id = $1", - ) - .bind(user.id) - .execute(&mut *tx) - .await?; - - tx.commit().await?; - - Ok(()) - } - - async fn delete_history(&self, user: &User, id: String) -> DbResult<()> { - sqlx::query( - "update history - set deleted_at = $3 - where user_id = $1 - and client_id = $2 - and deleted_at is null", // don't just keep setting it - ) - .bind(user.id) - .bind(id) - .bind(OffsetDateTime::now_utc()) - .fetch_all(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn deleted_history(&self, user: &User) -> DbResult> { - // The cache is new, and the user might not yet have a cache value. - // They will have one as soon as they post up some new history, but handle that - // edge case. - - let res = sqlx::query( - "select client_id from history - where user_id = $1 - and deleted_at is not null", - ) - .bind(user.id) - .fetch_all(self.read_pool()) - .await?; - - let res = res - .iter() - .map(|row| row.get::("client_id")) - .collect(); - - Ok(res) - } - - #[instrument(skip_all)] - async fn count_history_range( - &self, - user: &User, - range: Range, - ) -> DbResult { - let res: (i64,) = sqlx::query_as( - "select count(1) from history - where user_id = $1 - and timestamp >= $2::date - and timestamp < $3::date", - ) - .bind(user.id) - .bind(into_utc(range.start)) - .bind(into_utc(range.end)) - .fetch_one(self.read_pool()) - .await?; - - Ok(res.0) - } - - #[instrument(skip_all)] - async fn list_history( - &self, - user: &User, - created_after: OffsetDateTime, - since: OffsetDateTime, - host: &str, - page_size: i64, - ) -> DbResult> { - let res = sqlx::query_as( - "select id, client_id, user_id, hostname, timestamp, data, created_at from history - where user_id = $1 - and hostname != $2 - and created_at >= $3 - and timestamp >= $4 - order by timestamp asc - limit $5", - ) - .bind(user.id) - .bind(host) - .bind(into_utc(created_after)) - .bind(into_utc(since)) - .bind(page_size) - .fetch(self.read_pool()) - .map_ok(|DbHistory(h)| h) - .try_collect() - .await?; - - Ok(res) - } - - #[instrument(skip_all)] - async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> { - let mut tx = self.pool.begin().await?; - - for i in history { - let client_id: &str = &i.client_id; - let hostname: &str = &i.hostname; - let data: &str = &i.data; - - sqlx::query( - "insert into history - (client_id, user_id, hostname, timestamp, data) - values ($1, $2, $3, $4, $5) - on conflict do nothing - ", - ) - .bind(client_id) - .bind(i.user_id) - .bind(hostname) - .bind(i.timestamp) - .bind(data) - .execute(&mut *tx) - .await?; - } - - tx.commit().await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn delete_user(&self, u: &User) -> DbResult<()> { - sqlx::query("delete from sessions where user_id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - sqlx::query("delete from history where user_id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - sqlx::query("delete from store where user_id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - sqlx::query("delete from total_history_count_user where user_id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - sqlx::query("delete from users where id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn update_user_password(&self, user: &User) -> DbResult<()> { - sqlx::query( - "update users - set password = $1 - where id = $2", - ) - .bind(&user.password) - .bind(user.id) - .execute(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn add_user(&self, user: &NewUser) -> DbResult { - let email: &str = &user.email; - let username: &str = &user.username; - let password: &str = &user.password; - - let res: (i64,) = sqlx::query_as( - "insert into users - (username, email, password) - values($1, $2, $3) - returning id", - ) - .bind(username) - .bind(email) - .bind(password) - .fetch_one(&self.pool) - .await?; - - Ok(res.0) - } - - #[instrument(skip_all)] - async fn add_session(&self, session: &NewSession) -> DbResult<()> { - let token: &str = &session.token; - - sqlx::query( - "insert into sessions - (user_id, token) - values($1, $2)", - ) - .bind(session.user_id) - .bind(token) - .execute(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn get_user_session(&self, u: &User) -> DbResult { - sqlx::query_as("select id, user_id, token from sessions where user_id = $1") - .bind(u.id) - .fetch_one(self.read_pool()) - .await - .map_err(Into::into) - .map(|DbSession(session)| session) - } - - #[instrument(skip_all)] - async fn oldest_history(&self, user: &User) -> DbResult { - sqlx::query_as( - "select id, client_id, user_id, hostname, timestamp, data, created_at from history - where user_id = $1 - order by timestamp asc - limit 1", - ) - .bind(user.id) - .fetch_one(self.read_pool()) - .await - .map_err(Into::into) - .map(|DbHistory(h)| h) - } - - #[instrument(skip_all)] - async fn add_records(&self, user: &User, records: &[Record]) -> DbResult<()> { - let mut tx = self.pool.begin().await?; - - // We won't have uploaded this data if it wasn't the max. Therefore, we can deduce the max - // idx without having to make further database queries. Doing the query on this small - // amount of data should be much, much faster. - // - // Worst case, say we get this wrong. We end up caching data that isn't actually the max - // idx, so clients upload again. The cache logic can be verified with a sql query anyway :) - - let mut heads = HashMap::<(HostId, &str), u64>::new(); - - for i in records { - let id = atuin_common::utils::uuid_v7(); - - let result = sqlx::query( - "insert into store - (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id) - values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) - on conflict do nothing - ", - ) - .bind(id) - .bind(i.id) - .bind(i.host.id) - .bind(i.idx as i64) - .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time - .bind(&i.version) - .bind(&i.tag) - .bind(&i.data.data) - .bind(&i.data.content_encryption_key) - .bind(user.id) - .execute(&mut *tx) - .await?; - - // Only update heads if we actually inserted the record - if result.rows_affected() > 0 { - heads - .entry((i.host.id, &i.tag)) - .and_modify(|e| { - if i.idx > *e { - *e = i.idx - } - }) - .or_insert(i.idx); - } - } - - // we've built the map of heads for this push, so commit it to the database - for ((host, tag), idx) in heads { - sqlx::query( - "insert into store_idx_cache - (user_id, host, tag, idx) - values ($1, $2, $3, $4) - on conflict(user_id, host, tag) do update set idx = greatest(store_idx_cache.idx, $4) - ", - ) - .bind(user.id) - .bind(host) - .bind(tag) - .bind(idx as i64) - .execute(&mut *tx) - .await - ?; - } - - tx.commit().await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn next_records( - &self, - user: &User, - host: HostId, - tag: String, - start: Option, - count: u64, - ) -> DbResult>> { - tracing::debug!("{:?} - {:?} - {:?}", host, tag, start); - let start = start.unwrap_or(0); - - let records: Result, DbError> = sqlx::query_as( - "select client_id, host, idx, timestamp, version, tag, data, cek from store - where user_id = $1 - and tag = $2 - and host = $3 - and idx >= $4 - order by idx asc - limit $5", - ) - .bind(user.id) - .bind(tag.clone()) - .bind(host) - .bind(start as i64) - .bind(count as i64) - .fetch_all(self.read_pool()) - .await - .map_err(Into::into); - - let ret = match records { - Ok(records) => { - let records: Vec> = records - .into_iter() - .map(|f| { - let record: Record = f.into(); - record - }) - .collect(); - - records - } - Err(DbError::NotFound) => { - tracing::debug!("no records found in store: {:?}/{}", host, tag); - return Ok(vec![]); - } - Err(e) => return Err(e), - }; - - Ok(ret) - } - - async fn status(&self, user: &User) -> DbResult { - const STATUS_SQL: &str = - "select host, tag, max(idx) from store where user_id = $1 group by host, tag"; - - // If IDX_CACHE_ROLLOUT is set, then we - // 1. Read the value of the var, use it as a % chance of using the cache - // 2. If we use the cache, just read from the cache table - // 3. If we don't use the cache, read from the store table - // IDX_CACHE_ROLLOUT should be between 0 and 100. - - let idx_cache_rollout = std::env::var("IDX_CACHE_ROLLOUT").unwrap_or("0".to_string()); - let idx_cache_rollout = idx_cache_rollout.parse::().unwrap_or(0.0); - let use_idx_cache = rand::thread_rng().gen_bool(idx_cache_rollout / 100.0); - - let mut res: Vec<(Uuid, String, i64)> = if use_idx_cache { - tracing::debug!("using idx cache for user {}", user.id); - sqlx::query_as("select host, tag, idx from store_idx_cache where user_id = $1") - .bind(user.id) - .fetch_all(self.read_pool()) - .await? - } else { - tracing::debug!("using aggregate query for user {}", user.id); - sqlx::query_as(STATUS_SQL) - .bind(user.id) - .fetch_all(self.read_pool()) - .await? - }; - - res.sort(); - - let mut status = RecordStatus::new(); - - for i in res.iter() { - status.set_raw(HostId(i.0), i.1.clone(), i.2 as u64); - } - - Ok(status) - } -} diff --git a/crates/atuin-server-postgres/src/wrappers.rs b/crates/atuin-server-postgres/src/wrappers.rs deleted file mode 100644 index cde4134c..00000000 --- a/crates/atuin-server-postgres/src/wrappers.rs +++ /dev/null @@ -1,77 +0,0 @@ -use ::sqlx::{FromRow, Result}; -use atuin_common::record::{EncryptedData, Host, Record}; -use atuin_server_database::models::{History, Session, User}; -use sqlx::{Row, postgres::PgRow}; -use time::PrimitiveDateTime; - -pub struct DbUser(pub User); -pub struct DbSession(pub Session); -pub struct DbHistory(pub History); -pub struct DbRecord(pub Record); - -impl<'a> FromRow<'a, PgRow> for DbUser { - fn from_row(row: &'a PgRow) -> Result { - Ok(Self(User { - id: row.try_get("id")?, - username: row.try_get("username")?, - email: row.try_get("email")?, - password: row.try_get("password")?, - })) - } -} - -impl<'a> ::sqlx::FromRow<'a, PgRow> for DbSession { - fn from_row(row: &'a PgRow) -> ::sqlx::Result { - Ok(Self(Session { - id: row.try_get("id")?, - user_id: row.try_get("user_id")?, - token: row.try_get("token")?, - })) - } -} - -impl<'a> ::sqlx::FromRow<'a, PgRow> for DbHistory { - fn from_row(row: &'a PgRow) -> ::sqlx::Result { - Ok(Self(History { - id: row.try_get("id")?, - client_id: row.try_get("client_id")?, - user_id: row.try_get("user_id")?, - hostname: row.try_get("hostname")?, - timestamp: row - .try_get::("timestamp")? - .assume_utc(), - data: row.try_get("data")?, - created_at: row - .try_get::("created_at")? - .assume_utc(), - })) - } -} - -impl<'a> ::sqlx::FromRow<'a, PgRow> for DbRecord { - fn from_row(row: &'a PgRow) -> ::sqlx::Result { - let timestamp: i64 = row.try_get("timestamp")?; - let idx: i64 = row.try_get("idx")?; - - let data = EncryptedData { - data: row.try_get("data")?, - content_encryption_key: row.try_get("cek")?, - }; - - Ok(Self(Record { - id: row.try_get("client_id")?, - host: Host::new(row.try_get("host")?), - idx: idx as u64, - timestamp: timestamp as u64, - version: row.try_get("version")?, - tag: row.try_get("tag")?, - data, - })) - } -} - -impl From for Record { - fn from(other: DbRecord) -> Record { - Record { ..other.0 } - } -} diff --git a/crates/atuin-server-sqlite/Cargo.toml b/crates/atuin-server-sqlite/Cargo.toml deleted file mode 100644 index 579a5e7e..00000000 --- a/crates/atuin-server-sqlite/Cargo.toml +++ /dev/null @@ -1,24 +0,0 @@ -[package] -name = "atuin-server-sqlite" -edition = "2024" -description = "server sqlite database library for atuin" - -version = { workspace = true } -authors = { workspace = true } -license = { workspace = true } -homepage = { workspace = true } -repository = { workspace = true } - -[dependencies] -atuin-common = { path = "../atuin-common", version = "18.16.1" } -atuin-server-database = { path = "../atuin-server-database", version = "18.16.1" } - -eyre = { workspace = true } -tracing = { workspace = true } -time = { workspace = true } -serde = { workspace = true } -sqlx = { workspace = true, features = ["sqlite", "regexp"] } -async-trait = { workspace = true } -uuid = { workspace = true } -metrics = "0.24" -futures-util = "0.3" diff --git a/crates/atuin-server-sqlite/build.rs b/crates/atuin-server-sqlite/build.rs deleted file mode 100644 index d5068697..00000000 --- a/crates/atuin-server-sqlite/build.rs +++ /dev/null @@ -1,5 +0,0 @@ -// generated by `sqlx migrate build-script` -fn main() { - // trigger recompilation when a new migration is added - println!("cargo:rerun-if-changed=migrations"); -} diff --git a/crates/atuin-server-sqlite/migrations/20231203124112_create-store.sql b/crates/atuin-server-sqlite/migrations/20231203124112_create-store.sql deleted file mode 100644 index ca19ed62..00000000 --- a/crates/atuin-server-sqlite/migrations/20231203124112_create-store.sql +++ /dev/null @@ -1,17 +0,0 @@ -create table store ( - id text primary key, -- remember to use uuidv7 for happy indices <3 - client_id text not null, -- I am too uncomfortable with the idea of a client-generated primary key, even though it's fine mathematically - host text not null, -- a unique identifier for the host - idx bigint not null, -- the index of the record in this store, identified by (host, tag) - timestamp bigint not null, -- not a timestamp type, as those do not have nanosecond precision - version text not null, - tag text not null, -- what is this? history, kv, whatever. Remember clients get a log per tag per host - data text not null, -- store the actual history data, encrypted. I don't wanna know! - cek text not null, - - user_id bigint not null, -- allow multiple users - created_at timestamp not null default current_timestamp -); - -create unique index record_uniq ON store(user_id, host, tag, idx); - diff --git a/crates/atuin-server-sqlite/migrations/20240108124830_create-history.sql b/crates/atuin-server-sqlite/migrations/20240108124830_create-history.sql deleted file mode 100644 index 7bd653ba..00000000 --- a/crates/atuin-server-sqlite/migrations/20240108124830_create-history.sql +++ /dev/null @@ -1,15 +0,0 @@ -create table history ( - id integer primary key autoincrement, - client_id text not null unique, -- the client-generated ID - user_id bigserial not null, -- allow multiple users - hostname text not null, -- a unique identifier from the client (can be hashed, random, whatever) - timestamp timestamp not null, -- one of the few non-encrypted metadatas - - data text not null, -- store the actual history data, encrypted. I don't wanna know! - - created_at timestamp not null default current_timestamp, - deleted_at timestamp -); - -create unique index history_deleted_index on history(client_id, user_id, deleted_at); - diff --git a/crates/atuin-server-sqlite/migrations/20240108124831_create-sessions.sql b/crates/atuin-server-sqlite/migrations/20240108124831_create-sessions.sql deleted file mode 100644 index 3120c35d..00000000 --- a/crates/atuin-server-sqlite/migrations/20240108124831_create-sessions.sql +++ /dev/null @@ -1,6 +0,0 @@ -create table sessions ( - id integer primary key autoincrement, - user_id integer, - token text unique not null -); - diff --git a/crates/atuin-server-sqlite/migrations/20240621110730_create-users.sql b/crates/atuin-server-sqlite/migrations/20240621110730_create-users.sql deleted file mode 100644 index 852c159d..00000000 --- a/crates/atuin-server-sqlite/migrations/20240621110730_create-users.sql +++ /dev/null @@ -1,12 +0,0 @@ -create table users ( - id integer primary key autoincrement, -- also store our own ID - username text not null unique, -- being able to contact users is useful - email text not null unique, -- being able to contact users is useful - password text not null unique, - created_at timestamp not null default (datetime('now','localtime')), - verified_at timestamp with time zone default null -); - --- the prior index is case sensitive :( -CREATE UNIQUE INDEX email_unique_idx on users (LOWER(email)); -CREATE UNIQUE INDEX username_unique_idx on users (LOWER(username)); diff --git a/crates/atuin-server-sqlite/migrations/20240621110731_create-user-verification-token.sql b/crates/atuin-server-sqlite/migrations/20240621110731_create-user-verification-token.sql deleted file mode 100644 index 36eb14de..00000000 --- a/crates/atuin-server-sqlite/migrations/20240621110731_create-user-verification-token.sql +++ /dev/null @@ -1,6 +0,0 @@ -create table user_verification_token( - id integer primary key autoincrement, - user_id bigint unique references users(id), - token text, - valid_until timestamp with time zone -); diff --git a/crates/atuin-server-sqlite/migrations/20240702094825_create-store-idx-cache.sql b/crates/atuin-server-sqlite/migrations/20240702094825_create-store-idx-cache.sql deleted file mode 100644 index cd54cb18..00000000 --- a/crates/atuin-server-sqlite/migrations/20240702094825_create-store-idx-cache.sql +++ /dev/null @@ -1,10 +0,0 @@ -create table store_idx_cache( - id integer primary key autoincrement, - user_id bigint, - - host uuid, - tag text, - idx bigint -); - -create unique index store_idx_cache_uniq on store_idx_cache(user_id, host, tag); diff --git a/crates/atuin-server-sqlite/migrations/20260127000000_remove-email-verification.sql b/crates/atuin-server-sqlite/migrations/20260127000000_remove-email-verification.sql deleted file mode 100644 index 0bde89d7..00000000 --- a/crates/atuin-server-sqlite/migrations/20260127000000_remove-email-verification.sql +++ /dev/null @@ -1,2 +0,0 @@ -drop table if exists user_verification_token; -alter table users drop column verified_at; diff --git a/crates/atuin-server-sqlite/src/lib.rs b/crates/atuin-server-sqlite/src/lib.rs deleted file mode 100644 index 56ed9b6c..00000000 --- a/crates/atuin-server-sqlite/src/lib.rs +++ /dev/null @@ -1,430 +0,0 @@ -use std::str::FromStr; - -use async_trait::async_trait; -use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}; -use atuin_server_database::{ - Database, DbError, DbResult, DbSettings, into_utc, - models::{History, NewHistory, NewSession, NewUser, Session, User}, -}; -use futures_util::TryStreamExt; -use sqlx::{ - Row, - sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions}, - types::Uuid, -}; -use tracing::instrument; -use wrappers::{DbHistory, DbRecord, DbSession, DbUser}; - -mod wrappers; - -#[derive(Clone)] -pub struct Sqlite { - pool: sqlx::Pool, -} - -#[async_trait] -impl Database for Sqlite { - async fn new(settings: &DbSettings) -> DbResult { - let opts = SqliteConnectOptions::from_str(&settings.db_uri)? - .journal_mode(SqliteJournalMode::Wal) - .create_if_missing(true); - - let pool = SqlitePoolOptions::new().connect_with(opts).await?; - - sqlx::migrate!("./migrations") - .run(&pool) - .await - .map_err(|error| DbError::Other(error.into()))?; - - Ok(Self { pool }) - } - - #[instrument(skip_all)] - async fn get_session(&self, token: &str) -> DbResult { - sqlx::query_as("select id, user_id, token from sessions where token = $1") - .bind(token) - .fetch_one(&self.pool) - .await - .map_err(Into::into) - .map(|DbSession(session)| session) - } - - #[instrument(skip_all)] - async fn get_session_user(&self, token: &str) -> DbResult { - sqlx::query_as( - "select users.id, users.username, users.email, users.password from users - inner join sessions - on users.id = sessions.user_id - and sessions.token = $1", - ) - .bind(token) - .fetch_one(&self.pool) - .await - .map_err(Into::into) - .map(|DbUser(user)| user) - } - - #[instrument(skip_all)] - async fn add_session(&self, session: &NewSession) -> DbResult<()> { - let token: &str = &session.token; - - sqlx::query( - "insert into sessions - (user_id, token) - values($1, $2)", - ) - .bind(session.user_id) - .bind(token) - .execute(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn get_user(&self, username: &str) -> DbResult { - sqlx::query_as("select id, username, email, password from users where username = $1") - .bind(username) - .fetch_one(&self.pool) - .await - .map_err(Into::into) - .map(|DbUser(user)| user) - } - - #[instrument(skip_all)] - async fn get_user_session(&self, u: &User) -> DbResult { - sqlx::query_as("select id, user_id, token from sessions where user_id = $1") - .bind(u.id) - .fetch_one(&self.pool) - .await - .map_err(Into::into) - .map(|DbSession(session)| session) - } - - #[instrument(skip_all)] - async fn add_user(&self, user: &NewUser) -> DbResult { - let email: &str = &user.email; - let username: &str = &user.username; - let password: &str = &user.password; - - let res: (i64,) = sqlx::query_as( - "insert into users - (username, email, password) - values($1, $2, $3) - returning id", - ) - .bind(username) - .bind(email) - .bind(password) - .fetch_one(&self.pool) - .await?; - - Ok(res.0) - } - - #[instrument(skip_all)] - async fn update_user_password(&self, user: &User) -> DbResult<()> { - sqlx::query( - "update users - set password = $1 - where id = $2", - ) - .bind(&user.password) - .bind(user.id) - .execute(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn count_history(&self, user: &User) -> DbResult { - // The cache is new, and the user might not yet have a cache value. - // They will have one as soon as they post up some new history, but handle that - // edge case. - - let res: (i64,) = sqlx::query_as( - "select count(1) from history - where user_id = $1", - ) - .bind(user.id) - .fetch_one(&self.pool) - .await?; - - Ok(res.0) - } - - #[instrument(skip_all)] - async fn count_history_cached(&self, _user: &User) -> DbResult { - Err(DbError::NotFound) - } - - #[instrument(skip_all)] - async fn delete_user(&self, u: &User) -> DbResult<()> { - sqlx::query("delete from sessions where user_id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - sqlx::query("delete from users where id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - sqlx::query("delete from history where user_id = $1") - .bind(u.id) - .execute(&self.pool) - .await?; - - Ok(()) - } - - async fn delete_history(&self, user: &User, id: String) -> DbResult<()> { - sqlx::query( - "update history - set deleted_at = $3 - where user_id = $1 - and client_id = $2 - and deleted_at is null", // don't just keep setting it - ) - .bind(user.id) - .bind(id) - .bind(time::OffsetDateTime::now_utc()) - .fetch_all(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn deleted_history(&self, user: &User) -> DbResult> { - // The cache is new, and the user might not yet have a cache value. - // They will have one as soon as they post up some new history, but handle that - // edge case. - - let res = sqlx::query( - "select client_id from history - where user_id = $1 - and deleted_at is not null", - ) - .bind(user.id) - .fetch_all(&self.pool) - .await?; - - let res = res.iter().map(|row| row.get("client_id")).collect(); - - Ok(res) - } - - async fn delete_store(&self, user: &User) -> DbResult<()> { - sqlx::query( - "delete from store - where user_id = $1", - ) - .bind(user.id) - .execute(&self.pool) - .await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn add_records(&self, user: &User, records: &[Record]) -> DbResult<()> { - let mut tx = self.pool.begin().await?; - - for i in records { - let id = atuin_common::utils::uuid_v7(); - - sqlx::query( - "insert into store - (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id) - values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) - on conflict do nothing - ", - ) - .bind(id) - .bind(i.id) - .bind(i.host.id) - .bind(i.idx as i64) - .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time - .bind(&i.version) - .bind(&i.tag) - .bind(&i.data.data) - .bind(&i.data.content_encryption_key) - .bind(user.id) - .execute(&mut *tx) - .await?; - } - - tx.commit().await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn next_records( - &self, - user: &User, - host: HostId, - tag: String, - start: Option, - count: u64, - ) -> DbResult>> { - tracing::debug!("{:?} - {:?} - {:?}", host, tag, start); - let start = start.unwrap_or(0); - - let records: Result, DbError> = sqlx::query_as( - "select client_id, host, idx, timestamp, version, tag, data, cek from store - where user_id = $1 - and tag = $2 - and host = $3 - and idx >= $4 - order by idx asc - limit $5", - ) - .bind(user.id) - .bind(tag.clone()) - .bind(host) - .bind(start as i64) - .bind(count as i64) - .fetch_all(&self.pool) - .await - .map_err(Into::into); - - let ret = match records { - Ok(records) => { - let records: Vec> = records - .into_iter() - .map(|f| { - let record: Record = f.into(); - record - }) - .collect(); - - records - } - Err(DbError::NotFound) => { - tracing::debug!("no records found in store: {:?}/{}", host, tag); - return Ok(vec![]); - } - Err(e) => return Err(e), - }; - - Ok(ret) - } - - async fn status(&self, user: &User) -> DbResult { - const STATUS_SQL: &str = - "select host, tag, max(idx) from store where user_id = $1 group by host, tag"; - - let res: Vec<(Uuid, String, i64)> = sqlx::query_as(STATUS_SQL) - .bind(user.id) - .fetch_all(&self.pool) - .await?; - - let mut status = RecordStatus::new(); - - for i in res { - status.set_raw(HostId(i.0), i.1, i.2 as u64); - } - - Ok(status) - } - - #[instrument(skip_all)] - async fn count_history_range( - &self, - user: &User, - range: std::ops::Range, - ) -> DbResult { - let res: (i64,) = sqlx::query_as( - "select count(1) from history - where user_id = $1 - and timestamp >= $2::date - and timestamp < $3::date", - ) - .bind(user.id) - .bind(into_utc(range.start)) - .bind(into_utc(range.end)) - .fetch_one(&self.pool) - .await?; - - Ok(res.0) - } - - #[instrument(skip_all)] - async fn list_history( - &self, - user: &User, - created_after: time::OffsetDateTime, - since: time::OffsetDateTime, - host: &str, - page_size: i64, - ) -> DbResult> { - let res = sqlx::query_as( - "select id, client_id, user_id, hostname, timestamp, data, created_at from history - where user_id = $1 - and hostname != $2 - and created_at >= $3 - and timestamp >= $4 - order by timestamp asc - limit $5", - ) - .bind(user.id) - .bind(host) - .bind(into_utc(created_after)) - .bind(into_utc(since)) - .bind(page_size) - .fetch(&self.pool) - .map_ok(|DbHistory(h)| h) - .try_collect() - .await?; - - Ok(res) - } - - #[instrument(skip_all)] - async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> { - let mut tx = self.pool.begin().await?; - - for i in history { - let client_id: &str = &i.client_id; - let hostname: &str = &i.hostname; - let data: &str = &i.data; - - sqlx::query( - "insert into history - (client_id, user_id, hostname, timestamp, data) - values ($1, $2, $3, $4, $5) - on conflict do nothing - ", - ) - .bind(client_id) - .bind(i.user_id) - .bind(hostname) - .bind(i.timestamp) - .bind(data) - .execute(&mut *tx) - .await?; - } - - tx.commit().await?; - - Ok(()) - } - - #[instrument(skip_all)] - async fn oldest_history(&self, user: &User) -> DbResult { - sqlx::query_as( - "select id, client_id, user_id, hostname, timestamp, data, created_at from history - where user_id = $1 - order by timestamp asc - limit 1", - ) - .bind(user.id) - .fetch_one(&self.pool) - .await - .map_err(Into::into) - .map(|DbHistory(h)| h) - } -} diff --git a/crates/atuin-server-sqlite/src/wrappers.rs b/crates/atuin-server-sqlite/src/wrappers.rs deleted file mode 100644 index 2f1230c2..00000000 --- a/crates/atuin-server-sqlite/src/wrappers.rs +++ /dev/null @@ -1,72 +0,0 @@ -use ::sqlx::{FromRow, Result}; -use atuin_common::record::{EncryptedData, Host, Record}; -use atuin_server_database::models::{History, Session, User}; -use sqlx::{Row, sqlite::SqliteRow}; - -pub struct DbUser(pub User); -pub struct DbSession(pub Session); -pub struct DbHistory(pub History); -pub struct DbRecord(pub Record); - -impl<'a> FromRow<'a, SqliteRow> for DbUser { - fn from_row(row: &'a SqliteRow) -> Result { - Ok(Self(User { - id: row.try_get("id")?, - username: row.try_get("username")?, - email: row.try_get("email")?, - password: row.try_get("password")?, - })) - } -} - -impl<'a> ::sqlx::FromRow<'a, SqliteRow> for DbSession { - fn from_row(row: &'a SqliteRow) -> ::sqlx::Result { - Ok(Self(Session { - id: row.try_get("id")?, - user_id: row.try_get("user_id")?, - token: row.try_get("token")?, - })) - } -} - -impl<'a> ::sqlx::FromRow<'a, SqliteRow> for DbHistory { - fn from_row(row: &'a SqliteRow) -> ::sqlx::Result { - Ok(Self(History { - id: row.try_get("id")?, - client_id: row.try_get("client_id")?, - user_id: row.try_get("user_id")?, - hostname: row.try_get("hostname")?, - timestamp: row.try_get("timestamp")?, - data: row.try_get("data")?, - created_at: row.try_get("created_at")?, - })) - } -} - -impl<'a> ::sqlx::FromRow<'a, SqliteRow> for DbRecord { - fn from_row(row: &'a SqliteRow) -> ::sqlx::Result { - let idx: i64 = row.try_get("idx")?; - let timestamp: i64 = row.try_get("timestamp")?; - - let data = EncryptedData { - data: row.try_get("data")?, - content_encryption_key: row.try_get("cek")?, - }; - - Ok(Self(Record { - id: row.try_get("client_id")?, - host: Host::new(row.try_get("host")?), - idx: idx as u64, - timestamp: timestamp as u64, - version: row.try_get("version")?, - tag: row.try_get("tag")?, - data, - })) - } -} - -impl From for Record { - fn from(other: DbRecord) -> Record { - Record { ..other.0 } - } -} diff --git a/crates/atuin-server/CHANGELOG.md b/crates/atuin-server/CHANGELOG.md deleted file mode 120000 index 699cc9e7..00000000 --- a/crates/atuin-server/CHANGELOG.md +++ /dev/null @@ -1 +0,0 @@ -../../CHANGELOG.md \ No newline at end of file diff --git a/crates/atuin-server/Cargo.toml b/crates/atuin-server/Cargo.toml deleted file mode 100644 index b7779899..00000000 --- a/crates/atuin-server/Cargo.toml +++ /dev/null @@ -1,45 +0,0 @@ -[package] -name = "atuin-server" -edition = "2024" -description = "server library for atuin" - -rust-version = { workspace = true } -version = { workspace = true } -authors = { workspace = true } -license = { workspace = true } -homepage = { workspace = true } -repository = { workspace = true } - -[lib] -name = "atuin_server" -path = "src/lib.rs" - -[[bin]] -name = "atuin-server" -path = "src/bin/main.rs" - -[dependencies] -atuin-common = { workspace = true } -atuin-server-database = { workspace = true } -atuin-server-postgres = { workspace = true } -atuin-server-sqlite = { workspace = true } - -tracing = { workspace = true } -time = { workspace = true } -eyre = { workspace = true } -config = { workspace = true } -serde = { workspace = true } -serde_json = { workspace = true } -rand = { workspace = true } -tokio = { workspace = true } -axum = "0.8" -fs-err = { workspace = true } -tower = { workspace = true } -tower-http = { version = "0.6", features = ["trace"] } -reqwest = { workspace = true } -argon2 = "0.5" -semver = { workspace = true } -metrics-exporter-prometheus = { version = "0.18", default-features = false } -metrics = "0.24" -clap = { workspace = true } -tracing-subscriber = { workspace = true } diff --git a/crates/atuin-server/server.toml b/crates/atuin-server/server.toml deleted file mode 100644 index 9ff95890..00000000 --- a/crates/atuin-server/server.toml +++ /dev/null @@ -1,38 +0,0 @@ -## host to bind, can also be passed via CLI args -# host = "127.0.0.1" - -## port to bind, can also be passed via CLI args -# port = 8888 - -## whether to allow anyone to register an account -# open_registration = false - -## URI for postgres (using development creds here) -# db_uri="postgres://username:password@localhost/atuin" -# db_uri="sqlite:///config/atuin-server.db" - -## Optional: URI for read replica database -## If set, read-only queries will be routed to this database -# read_db_uri="postgres://username:password@localhost-replica/atuin" - -## Maximum size for one history entry -# max_history_length = 8192 - -## Maximum size for one record entry -## 1024 * 1024 * 1024 -# max_record_size = 1073741824 - -## Webhook to be called when user registers on the servers -# register_webhook_username = "" - -## Default page size for requests -# page_size = 1100 - -# [metrics] -# enable = false -# host = 127.0.0.1 -# port = 9001 - -## Enable legacy sync v1 routes (history-based sync) -## Set to false to disable and use only the newer record-based sync -# sync_v1_enabled = true diff --git a/crates/atuin-server/src/bin/main.rs b/crates/atuin-server/src/bin/main.rs deleted file mode 100644 index 960bdf6e..00000000 --- a/crates/atuin-server/src/bin/main.rs +++ /dev/null @@ -1,73 +0,0 @@ -#![forbid(unsafe_code)] - -use std::net::SocketAddr; - -use atuin_server::{Settings, example_config, launch, launch_metrics_server}; -use atuin_server_database::DbType; -use atuin_server_postgres::Postgres; -use atuin_server_sqlite::Sqlite; - -use clap::Parser; -use eyre::{Context, Result, eyre}; -use tracing_subscriber::{EnvFilter, fmt, prelude::*}; - -#[derive(Parser, Debug)] -#[clap( - name = "atuin-server", - about = "Atuin sync server", - version, - infer_subcommands = true -)] -enum Cmd { - /// Start the server - Start { - /// The host address to bind - #[clap(long)] - host: Option, - - /// The port to bind - #[clap(long, short)] - port: Option, - }, - - /// Print server example configuration - DefaultConfig, -} - -#[tokio::main] -async fn main() -> Result<()> { - let cmd = Cmd::parse(); - - tracing_subscriber::registry() - .with(fmt::layer()) - .with(EnvFilter::from_default_env()) - .init(); - - tracing::trace!(command = ?cmd, "server command"); - - match cmd { - Cmd::Start { host, port } => { - let settings = Settings::new().wrap_err("could not load server settings")?; - let host = host.as_ref().unwrap_or(&settings.host).clone(); - let port = port.unwrap_or(settings.port); - let addr = SocketAddr::new(host.parse()?, port); - - if settings.metrics.enable { - tokio::spawn(launch_metrics_server( - settings.metrics.host.clone(), - settings.metrics.port, - )); - } - - match settings.db_settings.db_type() { - DbType::Postgres => launch::(settings, addr).await, - DbType::Sqlite => launch::(settings, addr).await, - DbType::Unknown => Err(eyre!("db_uri must start with postgres:// or sqlite://")), - } - } - Cmd::DefaultConfig => { - println!("{}", example_config()); - Ok(()) - } - } -} diff --git a/crates/atuin-server/src/handlers/health.rs b/crates/atuin-server/src/handlers/health.rs deleted file mode 100644 index aebd1e8f..00000000 --- a/crates/atuin-server/src/handlers/health.rs +++ /dev/null @@ -1,15 +0,0 @@ -use axum::{Json, http, response::IntoResponse}; - -use serde::Serialize; - -#[derive(Serialize)] -pub struct HealthResponse { - pub status: &'static str, -} - -pub async fn health_check() -> impl IntoResponse { - ( - http::StatusCode::OK, - Json(HealthResponse { status: "healthy" }), - ) -} diff --git a/crates/atuin-server/src/handlers/history.rs b/crates/atuin-server/src/handlers/history.rs deleted file mode 100644 index bdafcc60..00000000 --- a/crates/atuin-server/src/handlers/history.rs +++ /dev/null @@ -1,237 +0,0 @@ -use std::{collections::HashMap, convert::TryFrom}; - -use axum::{ - Json, - extract::{Path, Query, State}, - http::{HeaderMap, StatusCode}, -}; -use metrics::counter; -use time::{Month, UtcOffset}; -use tracing::{debug, error, instrument}; - -use super::{ErrorResponse, ErrorResponseStatus, RespExt}; -use crate::{ - router::{AppState, UserAuth}, - utils::client_version_min, -}; -use atuin_server_database::{ - Database, - calendar::{TimePeriod, TimePeriodInfo}, - models::NewHistory, -}; - -use atuin_common::api::*; - -#[instrument(skip_all, fields(user.id = user.id))] -pub async fn count( - UserAuth(user): UserAuth, - state: State>, -) -> Result, ErrorResponseStatus<'static>> { - let db = &state.0.database; - match db.count_history_cached(&user).await { - // By default read out the cached value - Ok(count) => Ok(Json(CountResponse { count })), - - // If that fails, fallback on a full COUNT. Cache is built on a POST - // only - Err(_) => match db.count_history(&user).await { - Ok(count) => Ok(Json(CountResponse { count })), - Err(_) => Err(ErrorResponse::reply("failed to query history count") - .with_status(StatusCode::INTERNAL_SERVER_ERROR)), - }, - } -} - -#[instrument(skip_all, fields(user.id = user.id))] -pub async fn list( - req: Query, - UserAuth(user): UserAuth, - headers: HeaderMap, - state: State>, -) -> Result, ErrorResponseStatus<'static>> { - let db = &state.0.database; - - let agent = headers - .get("user-agent") - .map_or("", |v| v.to_str().unwrap_or("")); - - let variable_page_size = client_version_min(agent, ">=15.0.0").unwrap_or(false); - - let page_size = if variable_page_size { - state.settings.page_size - } else { - 100 - }; - - if req.sync_ts.unix_timestamp_nanos() < 0 || req.history_ts.unix_timestamp_nanos() < 0 { - error!("client asked for history from < epoch 0"); - counter!("atuin_history_epoch_before_zero").increment(1); - - return Err( - ErrorResponse::reply("asked for history from before epoch 0") - .with_status(StatusCode::BAD_REQUEST), - ); - } - - let history = db - .list_history(&user, req.sync_ts, req.history_ts, &req.host, page_size) - .await; - - if let Err(e) = history { - error!("failed to load history: {}", e); - return Err(ErrorResponse::reply("failed to load history") - .with_status(StatusCode::INTERNAL_SERVER_ERROR)); - } - - let history: Vec = history - .unwrap() - .iter() - .map(|i| i.data.to_string()) - .collect(); - - debug!( - "loaded {} items of history for user {}", - history.len(), - user.id - ); - - counter!("atuin_history_returned").increment(history.len() as u64); - - Ok(Json(SyncHistoryResponse { history })) -} - -#[instrument(skip_all, fields(user.id = user.id))] -pub async fn delete( - UserAuth(user): UserAuth, - state: State>, - Json(req): Json, -) -> Result, ErrorResponseStatus<'static>> { - let db = &state.0.database; - - // user_id is the ID of the history, as set by the user (the server has its own ID) - let deleted = db.delete_history(&user, req.client_id).await; - - if let Err(e) = deleted { - error!("failed to delete history: {}", e); - return Err(ErrorResponse::reply("failed to delete history") - .with_status(StatusCode::INTERNAL_SERVER_ERROR)); - } - - Ok(Json(MessageResponse { - message: String::from("deleted OK"), - })) -} - -#[instrument(skip_all, fields(user.id = user.id))] -pub async fn add( - UserAuth(user): UserAuth, - state: State>, - Json(req): Json>, -) -> Result<(), ErrorResponseStatus<'static>> { - let State(AppState { database, settings }) = state; - - debug!("request to add {} history items", req.len()); - counter!("atuin_history_uploaded").increment(req.len() as u64); - - let mut history: Vec = req - .into_iter() - .map(|h| NewHistory { - client_id: h.id, - user_id: user.id, - hostname: h.hostname, - timestamp: h.timestamp, - data: h.data, - }) - .collect(); - - history.retain(|h| { - // keep if within limit, or limit is 0 (unlimited) - let keep = h.data.len() <= settings.max_history_length || settings.max_history_length == 0; - - // Don't return an error here. We want to insert as much of the - // history list as we can, so log the error and continue going. - if !keep { - counter!("atuin_history_too_long").increment(1); - - tracing::warn!( - "history too long, got length {}, max {}", - h.data.len(), - settings.max_history_length - ); - } - - keep - }); - - if let Err(e) = database.add_history(&history).await { - error!("failed to add history: {}", e); - - return Err(ErrorResponse::reply("failed to add history") - .with_status(StatusCode::INTERNAL_SERVER_ERROR)); - }; - - Ok(()) -} - -#[derive(serde::Deserialize, Debug)] -pub struct CalendarQuery { - #[serde(default = "serde_calendar::zero")] - year: i32, - #[serde(default = "serde_calendar::one")] - month: u8, - - #[serde(default = "serde_calendar::utc")] - tz: UtcOffset, -} - -mod serde_calendar { - use time::UtcOffset; - - pub fn zero() -> i32 { - 0 - } - - pub fn one() -> u8 { - 1 - } - - pub fn utc() -> UtcOffset { - UtcOffset::UTC - } -} - -#[instrument(skip_all, fields(user.id = user.id))] -pub async fn calendar( - Path(focus): Path, - Query(params): Query, - UserAuth(user): UserAuth, - state: State>, -) -> Result>, ErrorResponseStatus<'static>> { - let focus = focus.as_str(); - - let year = params.year; - let month = Month::try_from(params.month).map_err(|e| ErrorResponseStatus { - error: ErrorResponse { - reason: e.to_string().into(), - }, - status: StatusCode::BAD_REQUEST, - })?; - - let period = match focus { - "year" => TimePeriod::Year, - "month" => TimePeriod::Month { year }, - "day" => TimePeriod::Day { year, month }, - _ => { - return Err(ErrorResponse::reply("invalid focus: use year/month/day") - .with_status(StatusCode::BAD_REQUEST)); - } - }; - - let db = &state.0.database; - let focus = db.calendar(&user, period, params.tz).await.map_err(|_| { - ErrorResponse::reply("failed to query calendar") - .with_status(StatusCode::INTERNAL_SERVER_ERROR) - })?; - - Ok(Json(focus)) -} diff --git a/crates/atuin-server/src/handlers/mod.rs b/crates/atuin-server/src/handlers/mod.rs deleted file mode 100644 index 2176ac5e..00000000 --- a/crates/atuin-server/src/handlers/mod.rs +++ /dev/null @@ -1,60 +0,0 @@ -use atuin_common::api::{ErrorResponse, IndexResponse}; -use atuin_server_database::Database; -use axum::{Json, extract::State, http, response::IntoResponse}; - -use crate::router::AppState; - -pub mod health; -pub mod history; -pub mod record; -pub mod status; -pub mod user; -pub mod v0; - -const VERSION: &str = env!("CARGO_PKG_VERSION"); - -pub async fn index(state: State>) -> Json { - let homage = r#""Through the fathomless deeps of space swims the star turtle Great A'Tuin, bearing on its back the four giant elephants who carry on their shoulders the mass of the Discworld." -- Sir Terry Pratchett"#; - - let version = state - .settings - .fake_version - .clone() - .unwrap_or(VERSION.to_string()); - - Json(IndexResponse { - homage: homage.to_string(), - version, - }) -} - -impl IntoResponse for ErrorResponseStatus<'_> { - fn into_response(self) -> axum::response::Response { - (self.status, Json(self.error)).into_response() - } -} - -pub struct ErrorResponseStatus<'a> { - pub error: ErrorResponse<'a>, - pub status: http::StatusCode, -} - -pub trait RespExt<'a> { - fn with_status(self, status: http::StatusCode) -> ErrorResponseStatus<'a>; - fn reply(reason: &'a str) -> Self; -} - -impl<'a> RespExt<'a> for ErrorResponse<'a> { - fn with_status(self, status: http::StatusCode) -> ErrorResponseStatus<'a> { - ErrorResponseStatus { - error: self, - status, - } - } - - fn reply(reason: &'a str) -> ErrorResponse<'a> { - Self { - reason: reason.into(), - } - } -} diff --git a/crates/atuin-server/src/handlers/record.rs b/crates/atuin-server/src/handlers/record.rs deleted file mode 100644 index 410c54bd..00000000 --- a/crates/atuin-server/src/handlers/record.rs +++ /dev/null @@ -1,42 +0,0 @@ -use axum::{Json, http::StatusCode, response::IntoResponse}; -use serde_json::json; -use tracing::instrument; - -use super::{ErrorResponse, ErrorResponseStatus, RespExt}; -use crate::router::UserAuth; - -use atuin_common::record::{EncryptedData, Record}; - -#[instrument(skip_all, fields(user.id = user.id))] -pub async fn post(UserAuth(user): UserAuth) -> Result<(), ErrorResponseStatus<'static>> { - // anyone who has actually used the old record store (a very small number) will see this error - // upon trying to sync. - // 1. The status endpoint will say that the server has nothing - // 2. The client will try to upload local records - // 3. Sync will fail with this error - - // If the client has no local records, they will see the empty index and do nothing. For the - // vast majority of users, this is the case. - return Err( - ErrorResponse::reply("record store deprecated; please upgrade") - .with_status(StatusCode::BAD_REQUEST), - ); -} - -#[instrument(skip_all, fields(user.id = user.id))] -pub async fn index(UserAuth(user): UserAuth) -> axum::response::Response { - let ret = json!({ - "hosts": {} - }); - - ret.to_string().into_response() -} - -#[instrument(skip_all, fields(user.id = user.id))] -pub async fn next( - UserAuth(user): UserAuth, -) -> Result>>, ErrorResponseStatus<'static>> { - let records = Vec::new(); - - Ok(Json(records)) -} diff --git a/crates/atuin-server/src/handlers/status.rs b/crates/atuin-server/src/handlers/status.rs deleted file mode 100644 index 9c152d51..00000000 --- a/crates/atuin-server/src/handlers/status.rs +++ /dev/null @@ -1,45 +0,0 @@ -use axum::{Json, extract::State, http::StatusCode}; -use tracing::instrument; - -use super::{ErrorResponse, ErrorResponseStatus, RespExt}; -use crate::router::{AppState, UserAuth}; -use atuin_server_database::Database; - -use atuin_common::api::*; - -const VERSION: &str = env!("CARGO_PKG_VERSION"); - -#[instrument(skip_all, fields(user.id = user.id))] -pub async fn status( - UserAuth(user): UserAuth, - state: State>, -) -> Result, ErrorResponseStatus<'static>> { - let db = &state.0.database; - - let deleted = db.deleted_history(&user).await.unwrap_or(vec![]); - - let count = match db.count_history_cached(&user).await { - // By default read out the cached value - Ok(count) => count, - - // If that fails, fallback on a full COUNT. Cache is built on a POST - // only - Err(_) => match db.count_history(&user).await { - Ok(count) => count, - Err(_) => { - return Err(ErrorResponse::reply("failed to query history count") - .with_status(StatusCode::INTERNAL_SERVER_ERROR)); - } - }, - }; - - tracing::debug!(user = user.username, "requested sync status"); - - Ok(Json(StatusResponse { - count, - deleted, - username: user.username, - version: VERSION.to_string(), - page_size: state.settings.page_size, - })) -} diff --git a/crates/atuin-server/src/handlers/user.rs b/crates/atuin-server/src/handlers/user.rs deleted file mode 100644 index dda7a381..00000000 --- a/crates/atuin-server/src/handlers/user.rs +++ /dev/null @@ -1,269 +0,0 @@ -use std::borrow::Borrow; -use std::collections::HashMap; -use std::time::Duration; - -use argon2::{ - Algorithm, Argon2, Params, PasswordHash, PasswordHasher, PasswordVerifier, Version, - password_hash::SaltString, -}; -use axum::{ - Json, - extract::{Path, State}, - http::StatusCode, -}; -use metrics::counter; - -use rand::rngs::OsRng; -use tracing::{debug, error, info, instrument}; - -use atuin_common::tls::ensure_crypto_provider; - -use super::{ErrorResponse, ErrorResponseStatus, RespExt}; -use crate::router::{AppState, UserAuth}; -use atuin_server_database::{ - Database, DbError, - models::{NewSession, NewUser}, -}; - -use reqwest::header::CONTENT_TYPE; - -use atuin_common::{api::*, utils::crypto_random_string}; - -pub fn verify_str(hash: &str, password: &str) -> bool { - let arg2 = Argon2::new(Algorithm::Argon2id, Version::V0x13, Params::default()); - let Ok(hash) = PasswordHash::new(hash) else { - return false; - }; - arg2.verify_password(password.as_bytes(), &hash).is_ok() -} - -// Try to send a Discord webhook once - if it fails, we don't retry. "At most once", and best effort. -// Don't return the status because if this fails, we don't really care. -async fn send_register_hook(url: &str, username: String, registered: String) { - ensure_crypto_provider(); - let hook = HashMap::from([ - ("username", username), - ("content", format!("{registered} has just signed up!")), - ]); - - let client = reqwest::Client::new(); - - let resp = client - .post(url) - .timeout(Duration::new(5, 0)) - .header(CONTENT_TYPE, "application/json") - .json(&hook) - .send() - .await; - - match resp { - Ok(_) => info!("register webhook sent ok!"), - Err(e) => error!("failed to send register webhook: {}", e), - } -} - -#[instrument(skip_all, fields(user.username = username.as_str()))] -pub async fn get( - Path(username): Path, - state: State>, -) -> Result, ErrorResponseStatus<'static>> { - let db = &state.0.database; - let user = match db.get_user(username.as_ref()).await { - Ok(user) => user, - Err(DbError::NotFound) => { - debug!("user not found: {}", username); - return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); - } - Err(DbError::Other(err)) => { - error!("database error: {}", err); - return Err(ErrorResponse::reply("database error") - .with_status(StatusCode::INTERNAL_SERVER_ERROR)); - } - }; - - Ok(Json(UserResponse { - username: user.username, - })) -} - -#[instrument(skip_all)] -pub async fn register( - state: State>, - Json(register): Json, -) -> Result, ErrorResponseStatus<'static>> { - if !state.settings.open_registration { - return Err( - ErrorResponse::reply("this server is not open for registrations") - .with_status(StatusCode::BAD_REQUEST), - ); - } - - for c in register.username.chars() { - match c { - 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' => {} - _ => { - return Err(ErrorResponse::reply( - "Only alphanumeric and hyphens (-) are allowed in usernames", - ) - .with_status(StatusCode::BAD_REQUEST)); - } - } - } - - let hashed = hash_secret(®ister.password); - - let new_user = NewUser { - email: register.email.clone(), - username: register.username.clone(), - password: hashed, - }; - - let db = &state.0.database; - let user_id = match db.add_user(&new_user).await { - Ok(id) => id, - Err(e) => { - error!("failed to add user: {}", e); - return Err( - ErrorResponse::reply("failed to add user").with_status(StatusCode::BAD_REQUEST) - ); - } - }; - - // 24 bytes encoded as base64 - let token = crypto_random_string::<24>(); - - let new_session = NewSession { - user_id, - token: (&token).into(), - }; - - if let Some(url) = &state.settings.register_webhook_url { - // Could probs be run on another thread, but it's ok atm - send_register_hook( - url, - state.settings.register_webhook_username.clone(), - register.username, - ) - .await; - } - - counter!("atuin_users_registered").increment(1); - - match db.add_session(&new_session).await { - Ok(_) => Ok(Json(RegisterResponse { - session: token, - auth: Some("cli".into()), - })), - Err(e) => { - error!("failed to add session: {}", e); - Err(ErrorResponse::reply("failed to register user") - .with_status(StatusCode::BAD_REQUEST)) - } - } -} - -#[instrument(skip_all, fields(user.id = user.id))] -pub async fn delete( - UserAuth(user): UserAuth, - state: State>, -) -> Result, ErrorResponseStatus<'static>> { - debug!("request to delete user {}", user.id); - - let db = &state.0.database; - if let Err(e) = db.delete_user(&user).await { - error!("failed to delete user: {}", e); - - return Err(ErrorResponse::reply("failed to delete user") - .with_status(StatusCode::INTERNAL_SERVER_ERROR)); - }; - - counter!("atuin_users_deleted").increment(1); - - Ok(Json(DeleteUserResponse {})) -} - -#[instrument(skip_all, fields(user.id = user.id, change_password))] -pub async fn change_password( - UserAuth(mut user): UserAuth, - state: State>, - Json(change_password): Json, -) -> Result, ErrorResponseStatus<'static>> { - let db = &state.0.database; - - let verified = verify_str( - user.password.as_str(), - change_password.current_password.borrow(), - ); - if !verified { - return Err( - ErrorResponse::reply("password is not correct").with_status(StatusCode::UNAUTHORIZED) - ); - } - - let hashed = hash_secret(&change_password.new_password); - user.password = hashed; - - if let Err(e) = db.update_user_password(&user).await { - error!("failed to change user password: {}", e); - - return Err(ErrorResponse::reply("failed to change user password") - .with_status(StatusCode::INTERNAL_SERVER_ERROR)); - }; - Ok(Json(ChangePasswordResponse {})) -} - -#[instrument(skip_all, fields(user.username = login.username.as_str()))] -pub async fn login( - state: State>, - login: Json, -) -> Result, ErrorResponseStatus<'static>> { - let db = &state.0.database; - let user = match db.get_user(login.username.borrow()).await { - Ok(u) => u, - Err(DbError::NotFound) => { - return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); - } - Err(DbError::Other(e)) => { - error!("failed to get user {}: {}", login.username.clone(), e); - - return Err(ErrorResponse::reply("database error") - .with_status(StatusCode::INTERNAL_SERVER_ERROR)); - } - }; - - let session = match db.get_user_session(&user).await { - Ok(u) => u, - Err(DbError::NotFound) => { - debug!("user session not found for user id={}", user.id); - return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); - } - Err(DbError::Other(err)) => { - error!("database error for user {}: {}", login.username, err); - return Err(ErrorResponse::reply("database error") - .with_status(StatusCode::INTERNAL_SERVER_ERROR)); - } - }; - - let verified = verify_str(user.password.as_str(), login.password.borrow()); - - if !verified { - debug!(user = user.username, "login failed"); - return Err( - ErrorResponse::reply("password is not correct").with_status(StatusCode::UNAUTHORIZED) - ); - } - - debug!(user = user.username, "login success"); - - Ok(Json(LoginResponse { - session: session.token, - auth: Some("cli".into()), - })) -} - -fn hash_secret(password: &str) -> String { - let arg2 = Argon2::new(Algorithm::Argon2id, Version::V0x13, Params::default()); - let salt = SaltString::generate(&mut OsRng); - let hash = arg2.hash_password(password.as_bytes(), &salt).unwrap(); - hash.to_string() -} diff --git a/crates/atuin-server/src/handlers/v0/me.rs b/crates/atuin-server/src/handlers/v0/me.rs deleted file mode 100644 index 7960b479..00000000 --- a/crates/atuin-server/src/handlers/v0/me.rs +++ /dev/null @@ -1,16 +0,0 @@ -use axum::Json; -use tracing::instrument; - -use crate::handlers::ErrorResponseStatus; -use crate::router::UserAuth; - -use atuin_common::api::*; - -#[instrument(skip_all, fields(user.id = user.id))] -pub async fn get( - UserAuth(user): UserAuth, -) -> Result, ErrorResponseStatus<'static>> { - Ok(Json(MeResponse { - username: user.username, - })) -} diff --git a/crates/atuin-server/src/handlers/v0/mod.rs b/crates/atuin-server/src/handlers/v0/mod.rs deleted file mode 100644 index d6f880f2..00000000 --- a/crates/atuin-server/src/handlers/v0/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub(crate) mod me; -pub(crate) mod record; -pub(crate) mod store; diff --git a/crates/atuin-server/src/handlers/v0/record.rs b/crates/atuin-server/src/handlers/v0/record.rs deleted file mode 100644 index 5c57910b..00000000 --- a/crates/atuin-server/src/handlers/v0/record.rs +++ /dev/null @@ -1,114 +0,0 @@ -use axum::{Json, extract::Query, extract::State, http::StatusCode}; -use metrics::counter; -use serde::Deserialize; -use tracing::{error, instrument}; - -use crate::{ - handlers::{ErrorResponse, ErrorResponseStatus, RespExt}, - router::{AppState, UserAuth}, -}; -use atuin_server_database::Database; - -use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}; - -#[instrument(skip_all, fields(user.id = user.id))] -pub async fn post( - UserAuth(user): UserAuth, - state: State>, - Json(records): Json>>, -) -> Result<(), ErrorResponseStatus<'static>> { - let State(AppState { database, settings }) = state; - - tracing::debug!( - count = records.len(), - user = user.username, - "request to add records" - ); - - counter!("atuin_record_uploaded").increment(records.len() as u64); - - let keep = records - .iter() - .all(|r| r.data.data.len() <= settings.max_record_size || settings.max_record_size == 0); - - if !keep { - counter!("atuin_record_too_large").increment(1); - - return Err( - ErrorResponse::reply("could not add records; record too large") - .with_status(StatusCode::BAD_REQUEST), - ); - } - - if let Err(e) = database.add_records(&user, &records).await { - error!("failed to add record: {}", e); - - return Err(ErrorResponse::reply("failed to add record") - .with_status(StatusCode::INTERNAL_SERVER_ERROR)); - }; - - Ok(()) -} - -#[instrument(skip_all, fields(user.id = user.id))] -pub async fn index( - UserAuth(user): UserAuth, - state: State>, -) -> Result, ErrorResponseStatus<'static>> { - let State(AppState { - database, - settings: _, - }) = state; - - let record_index = match database.status(&user).await { - Ok(index) => index, - Err(e) => { - error!("failed to get record index: {}", e); - - return Err(ErrorResponse::reply("failed to calculate record index") - .with_status(StatusCode::INTERNAL_SERVER_ERROR)); - } - }; - - tracing::debug!(user = user.username, "record index request"); - - Ok(Json(record_index)) -} - -#[derive(Deserialize)] -pub struct NextParams { - host: HostId, - tag: String, - start: Option, - count: u64, -} - -#[instrument(skip_all, fields(user.id = user.id))] -pub async fn next( - params: Query, - UserAuth(user): UserAuth, - state: State>, -) -> Result>>, ErrorResponseStatus<'static>> { - let State(AppState { - database, - settings: _, - }) = state; - let params = params.0; - - let records = match database - .next_records(&user, params.host, params.tag, params.start, params.count) - .await - { - Ok(records) => records, - Err(e) => { - error!("failed to get record index: {}", e); - - return Err(ErrorResponse::reply("failed to calculate record index") - .with_status(StatusCode::INTERNAL_SERVER_ERROR)); - } - }; - - counter!("atuin_record_downloaded").increment(records.len() as u64); - - Ok(Json(records)) -} diff --git a/crates/atuin-server/src/handlers/v0/store.rs b/crates/atuin-server/src/handlers/v0/store.rs deleted file mode 100644 index 6ca455d7..00000000 --- a/crates/atuin-server/src/handlers/v0/store.rs +++ /dev/null @@ -1,37 +0,0 @@ -use axum::{extract::Query, extract::State, http::StatusCode}; -use metrics::counter; -use serde::Deserialize; -use tracing::{error, instrument}; - -use crate::{ - handlers::{ErrorResponse, ErrorResponseStatus, RespExt}, - router::{AppState, UserAuth}, -}; -use atuin_server_database::Database; - -#[derive(Deserialize)] -pub struct DeleteParams {} - -#[instrument(skip_all, fields(user.id = user.id))] -pub async fn delete( - _params: Query, - UserAuth(user): UserAuth, - state: State>, -) -> Result<(), ErrorResponseStatus<'static>> { - let State(AppState { - database, - settings: _, - }) = state; - - if let Err(e) = database.delete_store(&user).await { - counter!("atuin_store_delete_failed").increment(1); - error!("failed to delete store {e:?}"); - - return Err(ErrorResponse::reply("failed to delete store") - .with_status(StatusCode::INTERNAL_SERVER_ERROR)); - } - - counter!("atuin_store_deleted").increment(1); - - Ok(()) -} diff --git a/crates/atuin-server/src/lib.rs b/crates/atuin-server/src/lib.rs deleted file mode 100644 index 02e50e1e..00000000 --- a/crates/atuin-server/src/lib.rs +++ /dev/null @@ -1,89 +0,0 @@ -#![forbid(unsafe_code)] - -use std::future::Future; -use std::net::SocketAddr; - -use atuin_server_database::Database; -use axum::{Router, serve}; -use eyre::{Context, Result}; - -mod handlers; -mod metrics; -mod router; -mod utils; - -pub use settings::Settings; -pub use settings::example_config; - -pub mod settings; - -use tokio::net::TcpListener; -use tokio::signal; - -#[cfg(target_family = "unix")] -async fn shutdown_signal() { - let mut term = signal::unix::signal(signal::unix::SignalKind::terminate()) - .expect("failed to register signal handler"); - let mut interrupt = signal::unix::signal(signal::unix::SignalKind::interrupt()) - .expect("failed to register signal handler"); - - tokio::select! { - _ = term.recv() => {}, - _ = interrupt.recv() => {}, - }; - eprintln!("Shutting down gracefully..."); -} - -pub async fn launch(settings: Settings, addr: SocketAddr) -> Result<()> { - launch_with_tcp_listener::( - settings, - TcpListener::bind(addr) - .await - .context("could not connect to socket")?, - shutdown_signal(), - ) - .await -} - -pub async fn launch_with_tcp_listener( - settings: Settings, - listener: TcpListener, - shutdown: impl Future + Send + 'static, -) -> Result<()> { - let r = make_router::(settings).await?; - - serve(listener, r.into_make_service()) - .with_graceful_shutdown(shutdown) - .await?; - - Ok(()) -} - -// The separate listener means it's much easier to ensure metrics are not accidentally exposed to -// the public. -pub async fn launch_metrics_server(host: String, port: u16) -> Result<()> { - let listener = TcpListener::bind((host, port)) - .await - .context("failed to bind metrics tcp")?; - - let recorder_handle = metrics::setup_metrics_recorder(); - - let router = Router::new().route( - "/metrics", - axum::routing::get(move || std::future::ready(recorder_handle.render())), - ); - - serve(listener, router.into_make_service()) - .with_graceful_shutdown(shutdown_signal()) - .await?; - - Ok(()) -} - -async fn make_router(settings: Settings) -> Result { - let db = Db::new(&settings.db_settings) - .await - .wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?; - let r = router::router(db, settings); - Ok(r) -} diff --git a/crates/atuin-server/src/metrics.rs b/crates/atuin-server/src/metrics.rs deleted file mode 100644 index ebd0dd2d..00000000 --- a/crates/atuin-server/src/metrics.rs +++ /dev/null @@ -1,55 +0,0 @@ -use std::time::Instant; - -use axum::{ - extract::{MatchedPath, Request}, - middleware::Next, - response::IntoResponse, -}; -use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; - -pub fn setup_metrics_recorder() -> PrometheusHandle { - const EXPONENTIAL_SECONDS: &[f64] = &[ - 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, - ]; - - PrometheusBuilder::new() - .set_buckets_for_metric( - Matcher::Full("http_requests_duration_seconds".to_string()), - EXPONENTIAL_SECONDS, - ) - .unwrap() - .install_recorder() - .unwrap() -} - -/// Middleware to record some common HTTP metrics -/// Generic over B to allow for arbitrary body types (eg Vec, Streams, a deserialized thing, etc) -/// Someday tower-http might provide a metrics middleware: https://github.com/tower-rs/tower-http/issues/57 -pub async fn track_metrics(req: Request, next: Next) -> impl IntoResponse { - let start = Instant::now(); - - let path = match req.extensions().get::() { - Some(matched_path) => matched_path.as_str().to_owned(), - _ => req.uri().path().to_owned(), - }; - - let method = req.method().clone(); - - // Run the rest of the request handling first, so we can measure it and get response - // codes. - let response = next.run(req).await; - - let latency = start.elapsed().as_secs_f64(); - let status = response.status().as_u16().to_string(); - - let labels = [ - ("method", method.to_string()), - ("path", path), - ("status", status), - ]; - - metrics::counter!("http_requests_total", &labels).increment(1); - metrics::histogram!("http_requests_duration_seconds", &labels).record(latency); - - response -} diff --git a/crates/atuin-server/src/router.rs b/crates/atuin-server/src/router.rs deleted file mode 100644 index 2d679759..00000000 --- a/crates/atuin-server/src/router.rs +++ /dev/null @@ -1,155 +0,0 @@ -use atuin_common::api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ErrorResponse}; -use axum::{ - Router, - extract::{FromRequestParts, Request}, - http::{self, request::Parts}, - middleware::Next, - response::{IntoResponse, Response}, - routing::{delete, get, patch, post}, -}; -use eyre::Result; -use tower::ServiceBuilder; -use tower_http::trace::TraceLayer; - -use super::handlers; -use crate::{ - handlers::{ErrorResponseStatus, RespExt}, - metrics, - settings::Settings, -}; -use atuin_server_database::{Database, DbError, models::User}; - -pub struct UserAuth(pub User); - -impl FromRequestParts> for UserAuth -where - DB: Database, -{ - type Rejection = ErrorResponseStatus<'static>; - - async fn from_request_parts( - req: &mut Parts, - state: &AppState, - ) -> Result { - let auth_header = req - .headers - .get(http::header::AUTHORIZATION) - .ok_or_else(|| { - ErrorResponse::reply("missing authorization header") - .with_status(http::StatusCode::BAD_REQUEST) - })?; - let auth_header = auth_header.to_str().map_err(|_| { - ErrorResponse::reply("invalid authorization header encoding") - .with_status(http::StatusCode::BAD_REQUEST) - })?; - let (typ, token) = auth_header.split_once(' ').ok_or_else(|| { - ErrorResponse::reply("invalid authorization header encoding") - .with_status(http::StatusCode::BAD_REQUEST) - })?; - - if typ != "Token" { - return Err( - ErrorResponse::reply("invalid authorization header encoding") - .with_status(http::StatusCode::BAD_REQUEST), - ); - } - - let user = state - .database - .get_session_user(token) - .await - .map_err(|e| match e { - DbError::NotFound => ErrorResponse::reply("session not found") - .with_status(http::StatusCode::FORBIDDEN), - DbError::Other(e) => { - tracing::error!(error = ?e, "could not query user session"); - ErrorResponse::reply("could not query user session") - .with_status(http::StatusCode::INTERNAL_SERVER_ERROR) - } - })?; - - Ok(UserAuth(user)) - } -} - -async fn teapot() -> impl IntoResponse { - // This used to return 418: 🫖 - // Much as it was fun, it wasn't as useful or informative as it should be - (http::StatusCode::NOT_FOUND, "404 not found") -} - -async fn clacks_overhead(request: Request, next: Next) -> Response { - let mut response = next.run(request).await; - - let gnu_terry_value = "GNU Terry Pratchett, Kris Nova"; - let gnu_terry_header = "X-Clacks-Overhead"; - - response - .headers_mut() - .insert(gnu_terry_header, gnu_terry_value.parse().unwrap()); - response -} - -/// Ensure that we only try and sync with clients on the same major version -async fn semver(request: Request, next: Next) -> Response { - let mut response = next.run(request).await; - response - .headers_mut() - .insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse().unwrap()); - - response -} - -#[derive(Clone)] -pub struct AppState { - pub database: DB, - pub settings: Settings, -} - -pub fn router(database: DB, settings: Settings) -> Router { - let mut routes = Router::new() - .route("/", get(handlers::index)) - .route("/healthz", get(handlers::health::health_check)); - - // Sync v1 routes - can be disabled in favor of record-based sync - if settings.sync_v1_enabled { - routes = routes - .route("/sync/count", get(handlers::history::count)) - .route("/sync/history", get(handlers::history::list)) - .route("/sync/calendar/{focus}", get(handlers::history::calendar)) - .route("/sync/status", get(handlers::status::status)) - .route("/history", post(handlers::history::add)) - .route("/history", delete(handlers::history::delete)); - } - - let routes = routes - .route("/user/{username}", get(handlers::user::get)) - .route("/account", delete(handlers::user::delete)) - .route("/account/password", patch(handlers::user::change_password)) - .route("/register", post(handlers::user::register)) - .route("/login", post(handlers::user::login)) - .route("/record", post(handlers::record::post)) - .route("/record", get(handlers::record::index)) - .route("/record/next", get(handlers::record::next)) - .route("/api/v0/me", get(handlers::v0::me::get)) - .route("/api/v0/record", post(handlers::v0::record::post)) - .route("/api/v0/record", get(handlers::v0::record::index)) - .route("/api/v0/record/next", get(handlers::v0::record::next)) - .route("/api/v0/store", delete(handlers::v0::store::delete)); - - let path = settings.path.as_str(); - if path.is_empty() { - routes - } else { - Router::new().nest(path, routes) - } - .fallback(teapot) - .with_state(AppState { database, settings }) - .layer( - ServiceBuilder::new() - .layer(axum::middleware::from_fn(clacks_overhead)) - .layer(TraceLayer::new_for_http()) - .layer(axum::middleware::from_fn(metrics::track_metrics)) - .layer(axum::middleware::from_fn(semver)), - ) -} diff --git a/crates/atuin-server/src/settings.rs b/crates/atuin-server/src/settings.rs deleted file mode 100644 index 3a612be9..00000000 --- a/crates/atuin-server/src/settings.rs +++ /dev/null @@ -1,113 +0,0 @@ -use std::{io::prelude::*, path::PathBuf}; - -use atuin_server_database::DbSettings; -use config::{Config, Environment, File as ConfigFile, FileFormat}; -use eyre::{Result, eyre}; -use fs_err::{File, create_dir_all}; -use serde::{Deserialize, Serialize}; - -static EXAMPLE_CONFIG: &str = include_str!("../server.toml"); - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Metrics { - #[serde(alias = "enabled")] - pub enable: bool, - pub host: String, - pub port: u16, -} - -impl Default for Metrics { - fn default() -> Self { - Self { - enable: false, - host: String::from("127.0.0.1"), - port: 9001, - } - } -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct Settings { - pub host: String, - pub port: u16, - pub path: String, - pub open_registration: bool, - pub max_history_length: usize, - pub max_record_size: usize, - pub page_size: i64, - pub register_webhook_url: Option, - pub register_webhook_username: String, - pub metrics: Metrics, - - /// Enable legacy sync v1 routes (history-based sync) - /// Set to false to use only the newer record-based sync - pub sync_v1_enabled: bool, - - /// Advertise a version that is not what we are _actually_ running - /// Many clients compare their version with api.atuin.sh, and if they differ, notify the user - /// that an update is available. - /// Now that we take beta releases, we should be able to advertise a different version to avoid - /// notifying users when the server runs something that is not a stable release. - pub fake_version: Option, - - #[serde(flatten)] - pub db_settings: DbSettings, -} - -impl Settings { - pub fn new() -> Result { - let mut config_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { - PathBuf::from(p) - } else { - let mut config_file = PathBuf::new(); - let config_dir = atuin_common::utils::config_dir(); - config_file.push(config_dir); - config_file - }; - - config_file.push("server.toml"); - - // create the config file if it does not exist - let mut config_builder = Config::builder() - .set_default("host", "127.0.0.1")? - .set_default("port", 8888)? - .set_default("open_registration", false)? - .set_default("max_history_length", 8192)? - .set_default("max_record_size", 1024 * 1024 * 1024)? // pretty chonky - .set_default("path", "")? - .set_default("register_webhook_username", "")? - .set_default("page_size", 1100)? - .set_default("metrics.enable", false)? - .set_default("metrics.host", "127.0.0.1")? - .set_default("metrics.port", 9001)? - .set_default("sync_v1_enabled", true)? - .add_source( - Environment::with_prefix("atuin") - .prefix_separator("_") - .separator("__"), - ); - - config_builder = if config_file.exists() { - config_builder.add_source(ConfigFile::new( - config_file.to_str().unwrap(), - FileFormat::Toml, - )) - } else { - create_dir_all(config_file.parent().unwrap())?; - let mut file = File::create(config_file)?; - file.write_all(EXAMPLE_CONFIG.as_bytes())?; - - config_builder - }; - - let config = config_builder.build()?; - - config - .try_deserialize() - .map_err(|e| eyre!("failed to deserialize: {}", e)) - } -} - -pub fn example_config() -> &'static str { - EXAMPLE_CONFIG -} diff --git a/crates/atuin-server/src/utils.rs b/crates/atuin-server/src/utils.rs deleted file mode 100644 index 12e9ac1b..00000000 --- a/crates/atuin-server/src/utils.rs +++ /dev/null @@ -1,15 +0,0 @@ -use eyre::Result; -use semver::{Version, VersionReq}; - -pub fn client_version_min(user_agent: &str, req: &str) -> Result { - if user_agent.is_empty() { - return Ok(false); - } - - let version = user_agent.replace("atuin/", ""); - - let req = VersionReq::parse(req)?; - let version = Version::parse(version.as_str())?; - - Ok(req.matches(&version)) -} diff --git a/crates/atuin/CHANGELOG.md b/crates/atuin/CHANGELOG.md deleted file mode 120000 index 699cc9e7..00000000 --- a/crates/atuin/CHANGELOG.md +++ /dev/null @@ -1 +0,0 @@ -../../CHANGELOG.md \ No newline at end of file diff --git a/crates/atuin/Cargo.toml b/crates/atuin/Cargo.toml deleted file mode 100644 index 6cdc57fa..00000000 --- a/crates/atuin/Cargo.toml +++ /dev/null @@ -1,87 +0,0 @@ -[package] -name = "atuin" -edition = "2024" -description = "atuin - magical shell history" -readme = "./README.md" - -rust-version = { workspace = true } -version = { workspace = true } -authors = { workspace = true } -license = { workspace = true } -homepage = { workspace = true } -repository = { workspace = true } - -[features] -default = [ - "client", "sync", "clipboard", "daemon", "pty-proxy" -] -client = ["atuin-client"] -sync = ["atuin-client/sync"] -daemon = ["atuin-client/daemon", "atuin-daemon"] -pty-proxy = ["dep:atuin-pty-proxy"] -hex = ["pty-proxy"] -clipboard = ["arboard"] - -[dependencies] -atuin-client = { path = "../atuin-client", version = "18.16.1", optional = true, default-features = false } -atuin-common = { workspace = true } -atuin-history = { workspace = true } -atuin-daemon = { path = "../atuin-daemon", version = "18.16.1", optional = true, default-features = false } -atuin-pty-proxy = { path = "../atuin-pty-proxy", version = "18.16.1", optional = true, default-features = false } - -log = { workspace = true } -time = { workspace = true } -eyre = { workspace = true } -indicatif = "0.18.0" -serde = { workspace = true } -serde_json = { workspace = true } -crossterm = { workspace = true, features = ["use-dev-tty"] } -unicode-width = "0.2" -itertools = { workspace = true } -tokio = { workspace = true } -async-trait = { workspace = true } -interim = { workspace = true } -clap = { workspace = true } -clap_complete = "4.5.8" -clap_complete_nushell = "4.5.4" -fs-err = { workspace = true } -fs4 = "0.13.1" -rpassword = "7.0" -semver = { workspace = true } -rustix = { workspace = true } -runtime-format = "0.1.3" -futures-util = "0.3" -fuzzy-matcher = "0.3.7" -colored = "2.0.4" -open = "5" -ratatui = { workspace = true } -tracing = "0.1" -tracing-subscriber = { workspace = true } -tracing-appender = "0.2" -uuid = { workspace = true } -sysinfo = "0.30.7" -regex = "1.10.5" -norm = { version = "0.1.1", features = ["fzf-v2"] } -atuin-nucleo-matcher = { workspace = true } -tempfile = { workspace = true } -shlex = "1.3.0" - -# settings editor with comment and relative ordering preservation -toml_edit = { workspace = true } - -[target.'cfg(target_os = "linux")'.dependencies] -arboard = { version = "3.4", optional = true, default-features = false, features = [ - "wayland-data-control", -] } - -[target.'cfg(unix)'.dependencies] -daemonize = "0.5.0" - -[dev-dependencies] -tracing-tree = "0.4" - -# Integration tests in tests/ spin up a test server to verify sync functionality. -# TODO: Consider moving these tests to atuin-server crate instead (client would become a dev dep there) -atuin-server = { workspace = true } -atuin-server-database = { workspace = true } -atuin-server-postgres = { workspace = true } diff --git a/crates/atuin/LICENSE b/crates/atuin/LICENSE deleted file mode 100644 index 7dfc9b58..00000000 --- a/crates/atuin/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 Ellie Huxtable - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/crates/atuin/README.md b/crates/atuin/README.md deleted file mode 120000 index fe840054..00000000 --- a/crates/atuin/README.md +++ /dev/null @@ -1 +0,0 @@ -../../README.md \ No newline at end of file diff --git a/crates/atuin/build.rs b/crates/atuin/build.rs deleted file mode 100644 index 75d53ee0..00000000 --- a/crates/atuin/build.rs +++ /dev/null @@ -1,11 +0,0 @@ -use std::process::Command; -fn main() { - let output = Command::new("git").args(["rev-parse", "HEAD"]).output(); - - let sha = match output { - Ok(sha) => String::from_utf8(sha.stdout).unwrap(), - Err(_) => String::from("NO_GIT"), - }; - - println!("cargo:rustc-env=GIT_HASH={sha}"); -} diff --git a/crates/atuin/src/command/CONTRIBUTORS b/crates/atuin/src/command/CONTRIBUTORS deleted file mode 120000 index 1ca4115a..00000000 --- a/crates/atuin/src/command/CONTRIBUTORS +++ /dev/null @@ -1 +0,0 @@ -../../../../CONTRIBUTORS \ No newline at end of file diff --git a/crates/atuin/src/command/client.rs b/crates/atuin/src/command/client.rs deleted file mode 100644 index 6c1bac29..00000000 --- a/crates/atuin/src/command/client.rs +++ /dev/null @@ -1,364 +0,0 @@ -use std::fs::{self, OpenOptions}; -use std::path::{Path, PathBuf}; - -use clap::Subcommand; -use eyre::{Result, WrapErr}; - -use atuin_client::{ - database::Sqlite, record::sqlite_store::SqliteStore, settings::Settings, theme, -}; -use tracing_appender::rolling::{RollingFileAppender, Rotation}; -use tracing_subscriber::{ - Layer, filter::EnvFilter, filter::LevelFilter, fmt, fmt::format::FmtSpan, prelude::*, -}; - -fn cleanup_old_logs(log_dir: &Path, prefix: &str, retention_days: u64) { - let cutoff = std::time::SystemTime::now() - - std::time::Duration::from_secs(retention_days * 24 * 60 * 60); - - let Ok(entries) = fs::read_dir(log_dir) else { - return; - }; - - for entry in entries.flatten() { - let path = entry.path(); - let Some(name) = path.file_name().and_then(|n| n.to_str()) else { - continue; - }; - - // Match files like "search.log.2024-02-23" or "daemon.log.2024-02-23" - if !name.starts_with(prefix) || name == prefix { - continue; - } - - if let Ok(metadata) = entry.metadata() - && let Ok(modified) = metadata.modified() - && modified < cutoff - { - let _ = fs::remove_file(&path); - } - } -} - -#[cfg(feature = "sync")] -mod sync; - -#[cfg(feature = "sync")] -mod account; - -#[cfg(feature = "daemon")] -mod daemon; - -mod config; -mod default_config; -mod doctor; -mod history; -mod import; -mod info; -mod init; -mod search; -mod setup; -mod stats; -mod store; -mod wrapped; - -#[derive(Subcommand, Debug)] -#[command(infer_subcommands = true)] -pub enum Cmd { - /// Setup Atuin features - #[command()] - Setup, - - /// Manipulate shell history - #[command(subcommand)] - History(history::Cmd), - - /// Import shell history from file - #[command(subcommand)] - Import(import::Cmd), - - /// Calculate statistics for your history - Stats(stats::Cmd), - - /// Interactive history search - Search(search::Cmd), - - #[cfg(feature = "sync")] - #[command(flatten)] - Sync(sync::Cmd), - - /// Manage your sync account - #[cfg(feature = "sync")] - Account(account::Cmd), - - /// Manage the atuin data store - #[command(subcommand)] - Store(store::Cmd), - - /// Print Atuin's shell init script - #[command()] - Init(init::Cmd), - - /// Information about dotfiles locations and ENV vars - #[command()] - Info, - - /// Run the doctor to check for common issues - #[command()] - Doctor, - - #[command()] - Wrapped { year: Option }, - - /// *Experimental* Manage the background daemon - #[cfg(feature = "daemon")] - #[command()] - Daemon(daemon::Cmd), - - /// Print the default atuin configuration (config.toml) - #[command()] - DefaultConfig, - - #[command(subcommand)] - Config(config::Cmd), -} - -impl Cmd { - pub fn run(self) -> Result<()> { - // Daemonize before creating the async runtime – fork() inside a live - // tokio runtime corrupts its internal state. - #[cfg(all(unix, feature = "daemon"))] - if let Self::Daemon(ref cmd) = self - && cmd.should_daemonize() - { - daemon::daemonize_current_process()?; - } - - let mut runtime = tokio::runtime::Builder::new_current_thread(); - - let runtime = runtime.enable_all().build().unwrap(); - - // For non-history commands, we want to initialize logging and the theme manager before - // doing anything else. History commands are performance-sensitive and run before and after - // every shell command, so we want to skip any unnecessary initialization for them. - let settings = Settings::new().wrap_err("could not load client settings")?; - let theme_manager = theme::ThemeManager::new(settings.theme.debug, None); - let res = runtime.block_on(self.run_inner(settings, theme_manager)); - - runtime.shutdown_timeout(std::time::Duration::from_millis(50)); - - res - } - - #[expect(clippy::too_many_lines, clippy::future_not_send)] - async fn run_inner( - self, - mut settings: Settings, - mut theme_manager: theme::ThemeManager, - ) -> Result<()> { - // ATUIN_LOG env var overrides config file level settings - let env_log_set = std::env::var("ATUIN_LOG").is_ok(); - - // Base filter from env var (or empty if not set) - let base_filter = - EnvFilter::from_env("ATUIN_LOG").add_directive("sqlx_sqlite::regexp=off".parse()?); - - let is_interactive_search = matches!(&self, Self::Search(cmd) if cmd.is_interactive()); - // Use file-based logging for interactive search (TUI mode) - let use_search_logging = is_interactive_search && settings.logs.search_enabled(); - - // Use file-based logging for daemon - #[cfg(feature = "daemon")] - let use_daemon_logging = matches!(&self, Self::Daemon(_)) && settings.logs.daemon_enabled(); - - #[cfg(not(feature = "daemon"))] - let use_daemon_logging = false; - - // Check if daemon should also log to console - #[cfg(feature = "daemon")] - let daemon_show_logs = matches!(&self, Self::Daemon(cmd) if cmd.show_logs()); - - #[cfg(not(feature = "daemon"))] - let daemon_show_logs = false; - - // Set up span timing JSON logs if ATUIN_SPAN is set - let span_path = std::env::var("ATUIN_SPAN").ok().map(|p| { - if p.is_empty() { - "atuin-spans.json".to_string() - } else { - p - } - }); - - // Helper to create span timing layer - macro_rules! make_span_layer { - ($path:expr) => {{ - let span_file = OpenOptions::new() - .create(true) - .truncate(true) - .write(true) - .open($path)?; - Some( - fmt::layer() - .json() - .with_writer(span_file) - .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE) - .with_filter(LevelFilter::TRACE), - ) - }}; - } - - // Build the subscriber with all configured layers - if use_search_logging { - let search_filename = settings.logs.search.file.clone(); - let log_dir = PathBuf::from(&settings.logs.dir); - fs::create_dir_all(&log_dir)?; - - // Clean up old log files - cleanup_old_logs(&log_dir, &search_filename, settings.logs.search_retention()); - - let file_appender = - RollingFileAppender::new(Rotation::DAILY, &log_dir, &search_filename); - - // Use config level unless ATUIN_LOG is set - let filter = if env_log_set { - base_filter - } else { - EnvFilter::default() - .add_directive(settings.logs.search_level().as_directive().parse()?) - .add_directive("sqlx_sqlite::regexp=off".parse()?) - }; - - let base = tracing_subscriber::registry().with( - fmt::layer() - .with_writer(file_appender) - .with_ansi(false) - .with_filter(filter), - ); - - match &span_path { - Some(sp) => { - base.with(make_span_layer!(sp)).init(); - } - None => { - base.init(); - } - } - } else if use_daemon_logging { - let daemon_filename = settings.logs.daemon.file.clone(); - let log_dir = PathBuf::from(&settings.logs.dir); - fs::create_dir_all(&log_dir)?; - - // Clean up old log files - cleanup_old_logs(&log_dir, &daemon_filename, settings.logs.daemon_retention()); - - let file_appender = - RollingFileAppender::new(Rotation::DAILY, &log_dir, &daemon_filename); - - // Use config level unless ATUIN_LOG is set - let file_filter = if env_log_set { - base_filter - } else { - EnvFilter::default() - .add_directive(settings.logs.daemon_level().as_directive().parse()?) - .add_directive("sqlx_sqlite::regexp=off".parse()?) - }; - - let file_layer = fmt::layer() - .with_writer(file_appender) - .with_ansi(false) - .with_filter(file_filter); - - // Optionally add console layer for --show-logs - if daemon_show_logs { - let console_filter = EnvFilter::from_env("ATUIN_LOG") - .add_directive("sqlx_sqlite::regexp=off".parse()?); - - let console_layer = fmt::layer().with_filter(console_filter); - - let base = tracing_subscriber::registry() - .with(file_layer) - .with(console_layer); - - match &span_path { - Some(sp) => { - base.with(make_span_layer!(sp)).init(); - } - None => { - base.init(); - } - } - } else { - let base = tracing_subscriber::registry().with(file_layer); - - match &span_path { - Some(sp) => { - base.with(make_span_layer!(sp)).init(); - } - None => { - base.init(); - } - } - } - } - - tracing::trace!(command = ?self, "client command"); - - // Skip initializing any databases for history - // This is a pretty hot path, as it runs before and after every single command the user - // runs - match self { - Self::History(history) => return history.run(&settings).await, - Self::Init(init) => { - init.run(&settings); - return Ok(()); - } - Self::Doctor => return doctor::run(&settings).await, - Self::Config(config) => return config.run(&settings).await, - _ => {} - } - - let db_path = PathBuf::from(settings.db_path.as_str()); - let record_store_path = PathBuf::from(settings.record_store_path.as_str()); - - let db = Sqlite::new(db_path, settings.local_timeout).await?; - let sqlite_store = SqliteStore::new(record_store_path, settings.local_timeout).await?; - - let theme_name = settings.theme.name.clone(); - let theme = theme_manager.load_theme(theme_name.as_str(), settings.theme.max_depth); - - match self { - Self::Setup => setup::run(&settings).await, - Self::Import(import) => import.run(&db).await, - Self::Stats(stats) => stats.run(&db, &settings, theme).await, - Self::Search(search) => search.run(db, &mut settings, sqlite_store, theme).await, - - #[cfg(feature = "sync")] - Self::Sync(sync) => sync.run(settings, &db, sqlite_store).await, - - #[cfg(feature = "sync")] - Self::Account(account) => account.run(settings, sqlite_store).await, - - Self::Store(store) => store.run(&settings, &db, sqlite_store).await, - - Self::Info => { - info::run(&settings); - Ok(()) - } - - Self::DefaultConfig => { - default_config::run(); - Ok(()) - } - - Self::Wrapped { year } => wrapped::run(year, &db, &settings, theme).await, - - #[cfg(feature = "daemon")] - Self::Daemon(cmd) => cmd.run(settings, sqlite_store, db).await, - - Self::History(_) | Self::Init(_) | Self::Doctor | Self::Config(_) => { - unreachable!() - } - } - } -} diff --git a/crates/atuin/src/command/client/account.rs b/crates/atuin/src/command/client/account.rs deleted file mode 100644 index fc1c9343..00000000 --- a/crates/atuin/src/command/client/account.rs +++ /dev/null @@ -1,47 +0,0 @@ -use clap::{Args, Subcommand}; -use eyre::Result; - -use atuin_client::record::sqlite_store::SqliteStore; -use atuin_client::settings::Settings; - -pub mod change_password; -pub mod delete; -pub mod login; -pub mod logout; -pub mod register; - -#[derive(Args, Debug)] -pub struct Cmd { - #[command(subcommand)] - command: Commands, -} - -#[derive(Subcommand, Debug)] -pub enum Commands { - /// Login to the configured server - Login(login::Cmd), - - /// Register a new account - Register(register::Cmd), - - /// Log out - Logout, - - /// Delete your account, and all synced data - Delete(delete::Cmd), - - /// Change your password - ChangePassword(change_password::Cmd), -} - -impl Cmd { - pub async fn run(self, settings: Settings, store: SqliteStore) -> Result<()> { - match self.command { - Commands::Login(l) => l.run(&settings, &store).await, - Commands::Register(r) => r.run(&settings).await, - Commands::Logout => logout::run().await, - Commands::Delete(d) => d.run(&settings).await, - Commands::ChangePassword(c) => c.run(&settings).await, - } - } -} diff --git a/crates/atuin/src/command/client/account/change_password.rs b/crates/atuin/src/command/client/account/change_password.rs deleted file mode 100644 index 234d4dc0..00000000 --- a/crates/atuin/src/command/client/account/change_password.rs +++ /dev/null @@ -1,67 +0,0 @@ -use clap::Parser; -use eyre::{Result, bail}; - -use atuin_client::{ - auth::{self, MutateResponse}, - settings::Settings, -}; -use rpassword::prompt_password; - -#[derive(Parser, Debug)] -pub struct Cmd { - #[clap(long, short)] - pub current_password: Option, - - #[clap(long, short)] - pub new_password: Option, - - /// The two-factor authentication code for your account, if any - #[clap(long, short)] - pub totp_code: Option, -} - -impl Cmd { - pub async fn run(&self, settings: &Settings) -> Result<()> { - if !settings.logged_in().await? { - bail!("You are not logged in"); - } - - let client = auth::auth_client(settings).await; - - let current_password = self.current_password.clone().unwrap_or_else(|| { - prompt_password("Please enter the current password: ") - .expect("Failed to read from input") - }); - - if current_password.is_empty() { - bail!("please provide the current password"); - } - - let new_password = self.new_password.clone().unwrap_or_else(|| { - prompt_password("Please enter the new password: ").expect("Failed to read from input") - }); - - if new_password.is_empty() { - bail!("please provide a new password"); - } - - let mut totp_code = self.totp_code.clone(); - - loop { - let response = client - .change_password(¤t_password, &new_password, totp_code.as_deref()) - .await?; - - match response { - MutateResponse::Success => break, - MutateResponse::TwoFactorRequired => { - totp_code = Some(super::login::or_user_input(None, "two-factor code")); - } - } - } - - println!("Account password successfully changed!"); - - Ok(()) - } -} diff --git a/crates/atuin/src/command/client/account/delete.rs b/crates/atuin/src/command/client/account/delete.rs deleted file mode 100644 index a5e7f0dd..00000000 --- a/crates/atuin/src/command/client/account/delete.rs +++ /dev/null @@ -1,57 +0,0 @@ -use atuin_client::{ - auth::{self, MutateResponse}, - settings::Settings, -}; -use clap::Parser; -use eyre::{Result, bail}; - -use super::login::{or_user_input, read_user_password}; - -#[derive(Parser, Debug)] -pub struct Cmd { - #[clap(long, short)] - pub password: Option, - - /// The two-factor authentication code for your account, if any - #[clap(long, short)] - pub totp_code: Option, -} - -impl Cmd { - pub async fn run(&self, settings: &Settings) -> Result<()> { - if !settings.logged_in().await? { - bail!("You are not logged in"); - } - - let client = auth::auth_client(settings).await; - - let password = self.password.clone().unwrap_or_else(read_user_password); - - if password.is_empty() { - bail!("please provide your password"); - } - - let mut totp_code = self.totp_code.clone(); - - loop { - let response = client - .delete_account(&password, totp_code.as_deref()) - .await?; - - match response { - MutateResponse::Success => break, - MutateResponse::TwoFactorRequired => { - totp_code = Some(or_user_input(None, "two-factor code")); - } - } - } - - // Clean up sessions from meta store - let meta = Settings::meta_store().await?; - meta.delete_session().await?; - - println!("Your account is deleted"); - - Ok(()) - } -} diff --git a/crates/atuin/src/command/client/account/login.rs b/crates/atuin/src/command/client/account/login.rs deleted file mode 100644 index e320e80b..00000000 --- a/crates/atuin/src/command/client/account/login.rs +++ /dev/null @@ -1,206 +0,0 @@ -use std::{io, path::PathBuf}; - -use clap::Parser; -use eyre::{Context, Result, bail}; -use tokio::{fs::File, io::AsyncWriteExt}; - -use atuin_client::{ - auth::{self, AuthResponse}, - encryption::{decode_key, load_key}, - record::sqlite_store::SqliteStore, - record::store::Store, - record::sync::{self, SyncError}, - settings::{Settings, SyncAuth}, -}; -use rpassword::prompt_password; - -#[derive(Parser, Debug)] -pub struct Cmd { - #[clap(long, short)] - pub username: Option, - - #[clap(long, short)] - pub password: Option, - - /// The encryption key for your account - #[clap(long, short)] - pub key: Option, - - /// The two-factor authentication code for your account, if any - #[clap(long, short)] - pub totp_code: Option, - - #[clap(long, hide = true)] - pub from_registration: bool, -} - -fn get_input() -> Result { - let mut input = String::new(); - io::stdin().read_line(&mut input)?; - Ok(input.trim_end_matches(&['\r', '\n'][..]).to_string()) -} - -impl Cmd { - pub async fn run(&self, settings: &Settings, store: &SqliteStore) -> Result<()> { - match settings.resolve_sync_auth().await { - SyncAuth::Legacy { .. } => { - println!("You are logged in to your sync server."); - println!("Run 'atuin logout' to log out."); - return Ok(()); - } - SyncAuth::NotLoggedIn { .. } => {} - } - - self.run_legacy_login(settings, store).await?; - - verify_key_against_remote(settings).await - } - - /// Legacy login: always prompt for username/password interactively - /// (or accept them via flags). - async fn run_legacy_login(&self, settings: &Settings, store: &SqliteStore) -> Result<()> { - let username = or_user_input(self.username.clone(), "username"); - let password = self.password.clone().unwrap_or_else(read_user_password); - - self.prompt_and_store_key(settings, store).await?; - - let client = auth::auth_client(settings).await; - let response = client.login(&username, &password, None).await?; - - match response { - AuthResponse::Success { session, .. } => { - Settings::meta_store().await?.save_session(&session).await?; - } - AuthResponse::TwoFactorRequired => { - // Legacy server doesn't support 2FA, so this shouldn't happen. - bail!("unexpected two-factor requirement from legacy server"); - } - } - - println!("Logged in!"); - Ok(()) - } - - async fn prompt_and_store_key(&self, settings: &Settings, store: &SqliteStore) -> Result<()> { - let key_path = settings.key_path.as_str(); - let key_path = PathBuf::from(key_path); - - println!("IMPORTANT"); - println!( - "If you are already logged in on another machine, you must ensure that the key you use here is the same as the key you used there." - ); - println!("You can find your key by running 'atuin key' on the other machine."); - println!("Do not share this key with anyone."); - println!("\nRead more here: https://docs.atuin.sh/guide/sync/#login \n"); - - let key = or_user_input( - self.key.clone(), - "encryption key [blank to use existing key file]", - ); - - if key.is_empty() { - if key_path.exists() { - let bytes = fs_err::read_to_string(&key_path).context(format!( - "Existing key file at '{}' could not be read", - key_path.to_string_lossy() - ))?; - if decode_key(bytes).is_err() { - bail!(format!( - "The key in existing key file at '{}' is invalid", - key_path.to_string_lossy() - )); - } - } else { - panic!( - "No key provided and no existing key file found. Please use 'atuin key' on your other machine, or recover your key from a backup" - ) - } - } else if !key_path.exists() { - if decode_key(key.clone()).is_err() { - bail!("The specified key is invalid"); - } - - let mut file = File::create(&key_path).await?; - file.write_all(key.as_bytes()).await?; - } else { - // we now know that the user has logged in specifying a key, AND that the key path - // exists - - // 1. check if the saved key and the provided key match. if so, nothing to do. - // 2. if not, re-encrypt the local history and overwrite the key - let current_key: [u8; 32] = load_key(settings)?.into(); - - let encoded = key.clone(); // gonna want to save it in a bit - let new_key: [u8; 32] = decode_key(key) - .context("Could not decode provided key; is not valid base64-encoded key")? - .into(); - - if new_key != current_key { - println!("\nRe-encrypting local store with new key"); - - store.re_encrypt(¤t_key, &new_key).await?; - - println!("Writing new key"); - let mut file = File::create(&key_path).await?; - file.write_all(encoded.as_bytes()).await?; - } - } - - Ok(()) - } -} - -async fn verify_key_against_remote(settings: &Settings) -> Result<()> { - let key: [u8; 32] = load_key(settings) - .context("could not load encryption key for verification")? - .into(); - - let client = sync::build_client(settings).await?; - let remote_index = match client.record_status().await { - Ok(idx) => idx, - Err(e) => { - tracing::warn!("could not fetch remote status to verify key: {e}"); - return Ok(()); - } - }; - - match sync::check_encryption_key(&client, &remote_index, &key).await { - Ok(()) => Ok(()), - Err(SyncError::WrongKey) => { - // Roll back the saved session so the user is not left in a - // half-authenticated state with a key that can't read the data. - if let Ok(meta) = Settings::meta_store().await { - let _ = meta.delete_session().await; - } - crate::print_error::print_error( - "Wrong encryption key", - "The encryption key on this machine does not match the data on the server. \ - You have been logged out.\n\n\ - To fix this, find your existing key by running `atuin key` on a machine that \ - already syncs successfully, then run `atuin login` again here with that key.", - ); - std::process::exit(1); - } - Err(e) => { - // Non-key error (e.g. transient network issue). Don't fail the - // login — the user is authenticated and can sync later when the - // network recovers. - tracing::warn!("could not verify encryption key against remote: {e}"); - Ok(()) - } - } -} - -pub(super) fn or_user_input(value: Option, name: &'static str) -> String { - value.unwrap_or_else(|| read_user_input(name)) -} - -pub(super) fn read_user_password() -> String { - let password = prompt_password("Please enter password: "); - password.expect("Failed to read from input") -} - -fn read_user_input(name: &'static str) -> String { - eprint!("Please enter {name}: "); - get_input().expect("Failed to read from input") -} diff --git a/crates/atuin/src/command/client/account/logout.rs b/crates/atuin/src/command/client/account/logout.rs deleted file mode 100644 index b958e65a..00000000 --- a/crates/atuin/src/command/client/account/logout.rs +++ /dev/null @@ -1,5 +0,0 @@ -use eyre::Result; - -pub async fn run() -> Result<()> { - atuin_client::logout::logout().await -} diff --git a/crates/atuin/src/command/client/account/register.rs b/crates/atuin/src/command/client/account/register.rs deleted file mode 100644 index bd836e7b..00000000 --- a/crates/atuin/src/command/client/account/register.rs +++ /dev/null @@ -1,67 +0,0 @@ -use clap::Parser; -use eyre::{Result, bail}; - -use super::login::or_user_input; -use atuin_client::settings::{Settings, SyncAuth}; - -#[derive(Parser, Debug)] -pub struct Cmd { - #[clap(long, short)] - pub username: Option, - - #[clap(long, short)] - pub password: Option, - - #[clap(long, short)] - pub email: Option, -} - -impl Cmd { - pub async fn run(&self, settings: &Settings) -> Result<()> { - match settings.resolve_sync_auth().await { - SyncAuth::Legacy { .. } => { - println!("You are already logged in."); - println!("Run 'atuin logout' to log out."); - return Ok(()); - } - - SyncAuth::NotLoggedIn { .. } => {} - } - - // Legacy registration flow - println!("Registering for an Atuin Sync account"); - - let username = or_user_input(self.username.clone(), "username"); - let email = or_user_input(self.email.clone(), "email"); - let password = self - .password - .clone() - .unwrap_or_else(super::login::read_user_password); - - if password.is_empty() { - bail!("please provide a password"); - } - - let session = atuin_client::api_client::register( - settings.sync_address.as_str(), - &username, - &email, - &password, - ) - .await?; - - let meta = Settings::meta_store().await?; - meta.save_session(&session.session).await?; - - let _key = atuin_client::encryption::load_key(settings)?; - - println!( - "Registration successful! Please make a note of your key (run 'atuin key') and keep it safe." - ); - println!( - "You will need it to log in on other devices, and we cannot help recover it if you lose it." - ); - - Ok(()) - } -} diff --git a/crates/atuin/src/command/client/config.rs b/crates/atuin/src/command/client/config.rs deleted file mode 100644 index 5ec5f7f3..00000000 --- a/crates/atuin/src/command/client/config.rs +++ /dev/null @@ -1,352 +0,0 @@ -use atuin_client::settings::Settings; -use clap::{Args, Subcommand, ValueEnum}; -use eyre::Result; -use toml_edit::{Document, DocumentMut, Item, Table, TableLike, Value}; - -#[derive(Subcommand, Debug)] -#[command(infer_subcommands = true)] -pub enum Cmd { - /// Get a configuration value from your config.toml file - /// or after defaults and overrides are applied - #[command()] - Get(GetCmd), - - /// Set a configuration value in your config.toml file - #[command()] - Set(SetCmd), - - /// Print all configuration values from your config.toml file - /// in TOML format - /// - /// If a key is provided, only print the value of that key and all its children - #[command()] - Print(PrintCmd), -} - -impl Cmd { - pub async fn run(self, settings: &Settings) -> Result<()> { - match self { - Self::Get(get) => get.run(settings).await, - Self::Set(set) => set.run(settings).await, - Self::Print(print) => print.run(settings).await, - } - } -} - -/// Get a configuration value from your config.toml file, -/// or optionally the effective value after defaults and overrides are applied. -#[derive(Args, Debug)] -pub struct GetCmd { - /// The configuration key to get - pub key: String, - - /// Print the value after defaults and overrides are applied - #[arg(long, short)] - pub resolved: bool, - - /// Print both the config file value and the resolved value - #[arg(long, short)] - pub verbose: bool, -} - -impl GetCmd { - pub async fn run(&self, _settings: &Settings) -> Result<()> { - let key = self.key.trim(); - if key.is_empty() || key.contains(char::is_whitespace) { - eyre::bail!("Config key must be non-empty and must not contain whitespace"); - } - - if self.verbose { - println!("Config file:"); - self.print_current_value(key, " ").await?; - println!("\nResolved:"); - Self::print_effective_value(key, " "); - return Ok(()); - } - - if self.resolved { - Self::print_effective_value(key, ""); - } else { - self.print_current_value(key, "").await?; - } - - Ok(()) - } - - async fn print_current_value(&self, key: &str, prefix: &str) -> Result<()> { - let config_file = Settings::get_config_path()?; - let config_str = tokio::fs::read_to_string(&config_file).await?; - let doc = config_str.parse::>()?; - - let current = get_deep_key(&doc, key); - - match current { - Some(item) if item.is_table() || item.is_inline_table() => { - let table = item - .as_table_like() - .expect("is_table()/is_inline_table() but no table"); - println!("{prefix}[{key}]"); - dump_table(table, prefix, &mut vec![key.to_string()])?; - } - Some(item) => { - let val = item.to_string(); - let val = val.trim().trim_matches('"'); - println!("{prefix}{val}"); - } - None => { - println!("{prefix}(not set in config file)"); - } - } - - Ok(()) - } - - fn print_effective_value(key: &str, prefix: &str) { - match Settings::get_config_value(key) { - Ok(value) => { - for line in value.lines() { - println!("{prefix}{line}"); - } - } - Err(_) => { - println!("{prefix}(unknown key)"); - } - } - } -} - -#[derive(Args, Debug)] -pub struct SetCmd { - /// The configuration key to set - pub key: String, - - /// The value to set - pub value: String, - - /// Store value as an explicit type - #[arg(long = "type", short, value_enum, default_value_t = ValueType::Auto, value_name = "TYPE")] - pub the_type: ValueType, -} - -#[derive(ValueEnum, Debug, Clone, PartialEq, Eq)] -pub enum ValueType { - /// Automatically determine the type of the value - Auto, - /// Store value as a string - String, - /// Store value as a boolean - Boolean, - /// Store value as an integer - Integer, - /// Store the value as a float - Float, -} - -impl SetCmd { - pub async fn run(self, _settings: &Settings) -> Result<()> { - let key = self.key.trim(); - if key.is_empty() || key.contains(char::is_whitespace) { - eyre::bail!("Config key must be non-empty and must not contain whitespace"); - } - - let config_file = Settings::get_config_path()?; - let config_str = tokio::fs::read_to_string(&config_file).await?; - let mut doc: DocumentMut = config_str.parse()?; - - // When using auto type detection, try to match the existing value's type - // so we don't accidentally change e.g. "300" (string) to 300 (integer) - let existing_type = detect_existing_type(&doc, key); - let value = self.parse_value(existing_type.as_ref())?; - set_deep_key(&mut doc, key, value)?; - - tokio::fs::write(&config_file, doc.to_string()).await?; - - Ok(()) - } - - fn parse_value(&self, existing_type: Option<&ValueType>) -> Result { - let raw = &self.value; - - // Explicit --type takes priority, then existing value type, then auto-detect - let effective_type = if self.the_type != ValueType::Auto { - &self.the_type - } else if let Some(existing) = existing_type { - existing - } else { - &ValueType::Auto - }; - - match effective_type { - ValueType::String => Ok(Value::from(raw.as_str())), - ValueType::Boolean => { - let b: bool = raw - .parse() - .map_err(|_| eyre::eyre!("invalid boolean value: {raw}"))?; - Ok(Value::from(b)) - } - ValueType::Integer => { - let i: i64 = raw - .parse() - .map_err(|_| eyre::eyre!("invalid integer value: {raw}"))?; - Ok(Value::from(i)) - } - ValueType::Float => { - let f: f64 = raw - .parse() - .map_err(|_| eyre::eyre!("invalid float value: {raw}"))?; - Ok(Value::from(f)) - } - ValueType::Auto => { - if raw == "true" || raw == "false" { - return Ok(Value::from(raw == "true")); - } - if let Ok(i) = raw.parse::() { - return Ok(Value::from(i)); - } - if let Ok(f) = raw.parse::() { - return Ok(Value::from(f)); - } - Ok(Value::from(raw.as_str())) - } - } - } -} - -#[derive(Args, Debug)] -pub struct PrintCmd { - /// Print the value of a specific key and all its children - pub key: Option, -} - -impl PrintCmd { - pub async fn run(&self, _settings: &Settings) -> Result<()> { - let config_file = Settings::get_config_path()?; - let config_str = tokio::fs::read_to_string(&config_file).await?; - let doc = config_str.parse::>()?; - - if let Some(key) = &self.key { - let current = get_deep_key(&doc, key); - - if let Some(current) = current { - if current.is_table() || current.is_inline_table() { - println!("[{key}]"); - dump_table( - current - .as_table_like() - .expect("is_table()/is_inline_table() but no table"), - "", - &mut vec![key.clone()], - )?; - } else { - println!("{}", current.to_string().trim().trim_matches('"')); - } - } else { - println!("key not found"); - } - } else { - dump_table(doc.as_table(), "", &mut Vec::new())?; - } - - Ok(()) - } -} - -fn dump_table(table: &dyn TableLike, prefix: &str, stack: &mut Vec) -> Result<()> { - for (key, value) in table.iter() { - if value.is_table() || value.is_inline_table() { - stack.push(key.to_string()); - - let table = value - .as_table_like() - .expect("is_table()/is_inline_table() but no table"); - - println!("\n{}[{}]", prefix, stack.join(".")); - - dump_table(table, prefix, stack)?; - - stack.pop(); - } else { - println!("{prefix}{key} = {value}"); - } - } - - Ok(()) -} - -fn get_deep_key<'doc>(doc: &'doc Document, key: &str) -> Option<&'doc Item> { - let parts = key.split('.'); - let mut current: Option<&Item> = Some(doc.as_item()); - - for part in parts { - current = current - .and_then(|item| item.as_table_like()) - .and_then(|table| table.get(part)); - } - - current -} - -/// Detect the TOML type of an existing key in the document, so `set` with auto -/// type detection preserves the original type rather than guessing from the value string. -fn detect_existing_type(doc: &DocumentMut, key: &str) -> Option { - let parts: Vec<&str> = key.split('.').collect(); - let mut current: &dyn TableLike = doc.as_table(); - - for &part in &parts[..parts.len().saturating_sub(1)] { - current = current.get(part)?.as_table_like()?; - } - - let last = parts.last()?; - let v = current.get(last)?.as_value()?; - - if v.is_str() { - Some(ValueType::String) - } else if v.is_bool() { - Some(ValueType::Boolean) - } else if v.is_integer() { - Some(ValueType::Integer) - } else if v.is_float() { - Some(ValueType::Float) - } else { - None - } -} - -fn set_deep_key(doc: &mut DocumentMut, key: &str, value: Value) -> Result<()> { - let parts: Vec<&str> = key.split('.').collect(); - - if parts.is_empty() { - eyre::bail!("empty config key"); - } - - let mut current: &mut dyn TableLike = doc.as_table_mut(); - - // Navigate/create intermediate tables - for &part in &parts[..parts.len() - 1] { - if !current.contains_key(part) { - current.insert(part, Item::Table(Table::new())); - } - current = current - .get_mut(part) - .expect("just inserted or already exists") - .as_table_like_mut() - .ok_or_else(|| eyre::eyre!("'{}' exists but is not a table", part))?; - } - - let last = *parts.last().unwrap(); - - // Don't silently overwrite a table with a scalar value - if let Some(existing) = current.get(last) - && (existing.is_table() || existing.is_inline_table()) - { - eyre::bail!( - "'{}' is a table; use a dotted key like '{}.key' to set a value within it", - key, - key - ); - } - - current.insert(last, Item::Value(value)); - - Ok(()) -} diff --git a/crates/atuin/src/command/client/daemon.rs b/crates/atuin/src/command/client/daemon.rs deleted file mode 100644 index c3dcf9d0..00000000 --- a/crates/atuin/src/command/client/daemon.rs +++ /dev/null @@ -1,784 +0,0 @@ -use std::fs::{self, File, OpenOptions}; -use std::io::{ErrorKind, Write}; -#[cfg(unix)] -use std::os::unix::net::UnixStream as StdUnixStream; -use std::path::{Path, PathBuf}; -use std::process::{Command, Stdio}; -use std::time::{Duration, Instant}; - -use atuin_client::{ - database::Sqlite, history::History, record::sqlite_store::SqliteStore, settings::Settings, -}; -use atuin_daemon::DaemonEvent; -use atuin_daemon::client::{ControlClient, DaemonClientErrorKind, HistoryClient, classify_error}; -use clap::Subcommand; -#[cfg(unix)] -use daemonize::Daemonize; -use eyre::{Result, WrapErr, bail, eyre}; -use fs4::fs_std::FileExt; -use tokio::time::sleep; - -#[derive(clap::Args, Debug)] -pub struct Cmd { - /// Internal flag for daemonization - #[arg(long, hide = true)] - daemonize: bool, - - /// Also write daemon logs to the console (useful for debugging) - #[arg(long)] - show_logs: bool, - - #[command(subcommand)] - subcmd: Option, -} - -#[derive(Subcommand, Debug)] -#[command(infer_subcommands = true)] -pub enum SubCmd { - /// Start the daemon server - Start { - #[arg(long, hide = true)] - daemonize: bool, - - /// Also write daemon logs to the console (useful for debugging) - #[arg(long)] - show_logs: bool, - - /// Force start: kill existing daemon process and reset the socket - #[arg(long)] - force: bool, - }, - - /// Show the daemon's current status - Status, - - /// Stop the daemon gracefully - Stop, - - /// Restart the daemon (stop, then start in background) - Restart, -} - -impl Cmd { - /// Returns `true` when the process should daemonize before creating the - /// async runtime or opening any database connections. - #[cfg(unix)] - pub fn should_daemonize(&self) -> bool { - match &self.subcmd { - Some(SubCmd::Start { daemonize, .. }) => *daemonize, - None => self.daemonize, - _ => false, - } - } - - /// Returns `true` when logs should also be written to the console. - pub fn show_logs(&self) -> bool { - match &self.subcmd { - Some(SubCmd::Start { show_logs, .. }) => *show_logs, - None => self.show_logs, - _ => false, - } - } - - pub async fn run( - self, - settings: Settings, - store: SqliteStore, - history_db: Sqlite, - ) -> Result<()> { - match self.subcmd { - None => { - eprintln!("Warning: `atuin daemon` is deprecated, use `atuin daemon start`"); - run(settings, store, history_db, false).await - } - Some(SubCmd::Start { force, .. }) => run(settings, store, history_db, force).await, - Some(SubCmd::Status) => status_cmd(&settings).await, - Some(SubCmd::Stop) => stop_cmd(&settings).await, - Some(SubCmd::Restart) => restart_cmd(&settings).await, - } - } -} - -const DAEMON_VERSION: &str = env!("CARGO_PKG_VERSION"); -const DAEMON_PROTOCOL_VERSION: u32 = 1; -const STARTUP_POLL: Duration = Duration::from_millis(40); -const LOCK_POLL: Duration = Duration::from_millis(20); -const LEGACY_DAEMON_RESTART_MESSAGE: &str = "legacy daemon detected; restart daemon manually"; - -struct PidfileGuard { - file: File, -} - -impl PidfileGuard { - fn acquire(path: &Path) -> Result { - let mut file = open_lock_file(path)?; - - if !file.try_lock_exclusive()? { - bail!( - "daemon already running (pidfile lock busy at {})", - path.display() - ); - } - - file.set_len(0) - .wrap_err_with(|| format!("could not truncate daemon pidfile {}", path.display()))?; - writeln!(file, "{}", std::process::id()) - .and_then(|()| writeln!(file, "{DAEMON_VERSION}")) - .wrap_err_with(|| format!("could not write daemon pidfile {}", path.display()))?; - - Ok(Self { file }) - } -} - -impl Drop for PidfileGuard { - fn drop(&mut self) { - let _ = self.file.unlock(); - } -} - -enum Probe { - Ready(HistoryClient), - NeedsRestart(String), - Unreachable(eyre::Report), -} - -fn daemon_matches_expected(version: &str, protocol: u32) -> bool { - version == DAEMON_VERSION && protocol == DAEMON_PROTOCOL_VERSION -} - -fn daemon_mismatch_message(version: &str, protocol: u32) -> String { - if protocol == DAEMON_PROTOCOL_VERSION { - format!("daemon is out of date: expected {DAEMON_VERSION}, got {version}") - } else { - format!("daemon protocol mismatch: expected {DAEMON_PROTOCOL_VERSION}, got {protocol}") - } -} - -fn is_legacy_daemon_error(err: &eyre::Report) -> bool { - matches!(classify_error(err), DaemonClientErrorKind::Unimplemented) -} - -fn should_retry_after_error(err: &eyre::Report) -> bool { - matches!( - classify_error(err), - DaemonClientErrorKind::Connect - | DaemonClientErrorKind::Unavailable - | DaemonClientErrorKind::Unimplemented - ) -} - -fn daemon_startup_lock_path(pidfile_path: &Path) -> PathBuf { - let mut os = pidfile_path.as_os_str().to_os_string(); - os.push(".startup.lock"); - PathBuf::from(os) -} - -fn open_lock_file(path: &Path) -> Result { - if let Some(parent) = path.parent() { - fs::create_dir_all(parent) - .wrap_err_with(|| format!("could not create lock directory {}", parent.display()))?; - } - - OpenOptions::new() - .read(true) - .write(true) - .create(true) - .truncate(false) - .open(path) - .wrap_err_with(|| format!("could not open lock file {}", path.display())) -} - -async fn wait_for_lock(path: &Path, timeout: Duration) -> Result { - let file = open_lock_file(path)?; - let start = Instant::now(); - - loop { - match file.try_lock_exclusive() { - Ok(true) => return Ok(file), - Ok(false) => { - if start.elapsed() >= timeout { - bail!("timed out waiting for lock at {}", path.display()); - } - - sleep(LOCK_POLL).await; - } - Err(err) => { - return Err(eyre!("could not lock {}: {err}", path.display())); - } - } - } -} - -async fn wait_for_pidfile_available(path: &Path, timeout: Duration) -> Result<()> { - let file = wait_for_lock(path, timeout).await?; - file.unlock() - .wrap_err_with(|| format!("failed to unlock {}", path.display()))?; - Ok(()) -} - -async fn connect_client(settings: &Settings) -> Result { - HistoryClient::new( - #[cfg(not(unix))] - settings.daemon.tcp_port, - #[cfg(unix)] - settings.daemon.socket_path.clone(), - ) - .await -} - -async fn probe(settings: &Settings) -> Probe { - let mut client = match connect_client(settings).await { - Ok(client) => client, - Err(err) => return Probe::Unreachable(err), - }; - - match client.status().await { - Ok(status) => { - if daemon_matches_expected(&status.version, status.protocol) { - Probe::Ready(client) - } else { - Probe::NeedsRestart(daemon_mismatch_message(&status.version, status.protocol)) - } - } - Err(err) => Probe::Unreachable(err), - } -} - -async fn request_shutdown(settings: &Settings) { - if let Ok(mut client) = connect_client(settings).await { - let _ = client.shutdown().await; - } -} - -fn spawn_daemon_process() -> Result<()> { - let exe = std::env::current_exe().wrap_err("could not locate atuin executable")?; - - let mut cmd = Command::new(exe); - cmd.arg("daemon") - .arg("start") - .stdin(Stdio::null()) - .stdout(Stdio::null()) - .stderr(Stdio::null()); - - #[cfg(unix)] - cmd.arg("--daemonize"); - - cmd.spawn().wrap_err("failed to spawn daemon process")?; - - Ok(()) -} - -fn startup_timeout(settings: &Settings) -> Duration { - Duration::from_secs_f64(settings.local_timeout.max(0.5) + 2.0) -} - -#[cfg(unix)] -fn remove_stale_socket_if_present(settings: &Settings) -> Result<()> { - if settings.daemon.systemd_socket { - return Ok(()); - } - - let socket_path = Path::new(&settings.daemon.socket_path); - if !socket_path.exists() { - return Ok(()); - } - - match StdUnixStream::connect(socket_path) { - Ok(stream) => { - drop(stream); - Ok(()) - } - Err(err) if err.kind() == ErrorKind::ConnectionRefused => { - fs::remove_file(socket_path).wrap_err_with(|| { - format!( - "failed to remove stale daemon socket {}", - socket_path.display() - ) - })?; - Ok(()) - } - Err(err) if err.kind() == ErrorKind::NotFound => Ok(()), - Err(_) => Ok(()), - } -} - -async fn wait_until_ready(settings: &Settings, timeout: Duration) -> Result { - let start = Instant::now(); - let mut last_error = eyre!("daemon did not become ready"); - - loop { - match probe(settings).await { - Probe::Ready(client) => return Ok(client), - Probe::NeedsRestart(reason) => { - last_error = eyre!(reason); - } - Probe::Unreachable(err) => { - if is_legacy_daemon_error(&err) { - return Err(err.wrap_err(LEGACY_DAEMON_RESTART_MESSAGE)); - } - last_error = err; - } - } - - if start.elapsed() >= timeout { - return Err(last_error.wrap_err(format!( - "timed out waiting for daemon startup after {}ms", - timeout.as_millis() - ))); - } - - sleep(STARTUP_POLL).await; - } -} - -#[expect(clippy::unnecessary_wraps)] -fn ensure_autostart_supported(settings: &Settings) -> Result<()> { - #[cfg(unix)] - if settings.daemon.systemd_socket { - bail!( - "daemon autostart is incompatible with `daemon.systemd_socket = true`; use systemd to manage the daemon" - ); - } - #[cfg(not(unix))] - let _ = settings; - - Ok(()) -} - -/// Ensure the daemon is running, starting it if necessary. -/// -/// If the daemon is already running and up-to-date, this is a no-op. -/// If it is not running or needs a restart, this will spawn a new daemon -/// process and wait for it to become ready. -/// -/// Returns an error if the daemon could not be started. -pub async fn ensure_daemon_running(settings: &Settings) -> Result<()> { - ensure_autostart_supported(settings)?; - - let timeout = startup_timeout(settings); - let pidfile_path = PathBuf::from(&settings.daemon.pidfile_path); - let startup_lock_path = daemon_startup_lock_path(&pidfile_path); - let startup_lock = wait_for_lock(&startup_lock_path, timeout).await?; - - match probe(settings).await { - Probe::Ready(_) => { - drop(startup_lock); - return Ok(()); - } - Probe::NeedsRestart(_) => { - request_shutdown(settings).await; - } - Probe::Unreachable(err) => { - if is_legacy_daemon_error(&err) { - return Err(err.wrap_err(LEGACY_DAEMON_RESTART_MESSAGE)); - } - } - } - - // This prevents rapid-fire hook invocations from racing daemon restart. - wait_for_pidfile_available(&pidfile_path, timeout).await?; - - #[cfg(unix)] - remove_stale_socket_if_present(settings)?; - - spawn_daemon_process()?; - let _ = wait_until_ready(settings, timeout).await?; - - drop(startup_lock); - Ok(()) -} - -async fn restart_daemon(settings: &Settings) -> Result { - ensure_daemon_running(settings).await?; - connect_client(settings).await -} - -fn ensure_reply_compatible(settings: &Settings, version: &str, protocol: u32) -> Result<()> { - if daemon_matches_expected(version, protocol) { - return Ok(()); - } - - let message = daemon_mismatch_message(version, protocol); - if settings.daemon.autostart { - bail!("{message}"); - } - - bail!("{message}. Enable `daemon.autostart = true` or restart the daemon manually"); -} - -pub async fn start_history(settings: &Settings, history: History) -> Result { - match async { - connect_client(settings) - .await? - .start_history(history.clone()) - .await - } - .await - { - Ok(resp) => { - if daemon_matches_expected(&resp.version, resp.protocol) { - return Ok(resp.id); - } - - if !settings.daemon.autostart { - return Err(eyre!( - "{}. Enable `daemon.autostart = true` or restart the daemon manually", - daemon_mismatch_message(&resp.version, resp.protocol) - )); - } - } - Err(err) if !settings.daemon.autostart => return Err(err), - Err(err) if !should_retry_after_error(&err) => return Err(err), - Err(_) => {} - } - - let resp = restart_daemon(settings) - .await? - .start_history(history) - .await?; - ensure_reply_compatible(settings, &resp.version, resp.protocol)?; - Ok(resp.id) -} - -pub async fn end_history(settings: &Settings, id: String, duration: u64, exit: i64) -> Result<()> { - match async { - connect_client(settings) - .await? - .end_history(id.clone(), duration, exit) - .await - } - .await - { - Ok(resp) => { - if daemon_matches_expected(&resp.version, resp.protocol) { - return Ok(()); - } - - if !settings.daemon.autostart { - return Err(eyre!( - "{}. Enable `daemon.autostart = true` or restart the daemon manually", - daemon_mismatch_message(&resp.version, resp.protocol) - )); - } - - // End succeeded on the running daemon, so avoid replaying it. - // We only restart to make subsequent hook calls target the expected version. - let _ = restart_daemon(settings).await; - return Ok(()); - } - Err(err) if !settings.daemon.autostart => return Err(err), - Err(err) if !should_retry_after_error(&err) => return Err(err), - Err(_) => {} - } - - let resp = restart_daemon(settings) - .await? - .end_history(id, duration, exit) - .await?; - ensure_reply_compatible(settings, &resp.version, resp.protocol)?; - Ok(()) -} - -/// Emit a daemon event, auto-starting the daemon if it is not running. -/// -/// If the daemon is not reachable and `daemon.autostart` is enabled, this -/// will start the daemon and retry the event. If the daemon cannot be -/// started or the retry fails, a warning is printed to stderr. -pub async fn emit_event(settings: &Settings, event: DaemonEvent) { - // Try to connect and send - match ControlClient::from_settings(settings).await { - Ok(mut client) => { - if let Err(e) = client.send_event(event).await { - tracing::debug!(?e, "failed to send event to daemon"); - } - return; - } - Err(e) if !settings.daemon.autostart || !should_retry_after_error(&e) => { - tracing::debug!(?e, "daemon not available, skipping event emission"); - return; - } - Err(_) => {} - } - - // Auto-start the daemon and retry - if let Err(e) = ensure_daemon_running(settings).await { - eprintln!("Could not start daemon: {e}"); - return; - } - - match ControlClient::from_settings(settings).await { - Ok(mut client) => { - if let Err(e) = client.send_event(event).await { - eprintln!("Daemon started but failed to send event: {e}"); - } - } - Err(e) => { - eprintln!("Daemon started but failed to connect: {e}"); - } - } -} - -pub async fn tail_client(settings: &Settings) -> Result { - match probe(settings).await { - Probe::Ready(client) => return Ok(client), - Probe::NeedsRestart(reason) if !settings.daemon.autostart => { - bail!("{reason}. Enable `daemon.autostart = true` or restart the daemon manually"); - } - Probe::Unreachable(err) if is_legacy_daemon_error(&err) => { - return Err(err.wrap_err(LEGACY_DAEMON_RESTART_MESSAGE)); - } - Probe::Unreachable(err) if !settings.daemon.autostart => return Err(err), - Probe::Unreachable(err) if !should_retry_after_error(&err) => return Err(err), - Probe::NeedsRestart(_) | Probe::Unreachable(_) => {} - } - - restart_daemon(settings).await -} - -async fn status_cmd(settings: &Settings) -> Result<()> { - match probe(settings).await { - Probe::Ready(mut client) => { - let status = client.status().await?; - println!("Daemon running"); - println!(" PID: {}", status.pid); - println!(" Version: {}", status.version); - println!(" Protocol: {}", status.protocol); - println!(" Healthy: {}", status.healthy); - #[cfg(unix)] - println!(" Socket: {}", settings.daemon.socket_path); - #[cfg(not(unix))] - println!(" Port: {}", settings.daemon.tcp_port); - } - Probe::NeedsRestart(reason) => { - println!("Daemon running (needs restart)"); - println!(" Reason: {reason}"); - } - Probe::Unreachable(_) => { - println!("Daemon is not running"); - } - } - - Ok(()) -} - -async fn stop_cmd(settings: &Settings) -> Result<()> { - let Ok(mut client) = connect_client(settings).await else { - println!("Daemon is not running"); - return Ok(()); - }; - - match client.shutdown().await { - Ok(true) => { - println!("Shutdown requested"); - - let pidfile_path = PathBuf::from(&settings.daemon.pidfile_path); - let timeout = Duration::from_secs(5); - match wait_for_pidfile_available(&pidfile_path, timeout).await { - Ok(()) => println!("Daemon stopped"), - Err(_) => println!("Daemon may still be shutting down"), - } - - Ok(()) - } - Ok(false) => bail!("Daemon rejected shutdown request"), - Err(err) => Err(err.wrap_err("Failed to send shutdown request")), - } -} - -async fn restart_cmd(settings: &Settings) -> Result<()> { - // Stop if running - match probe(settings).await { - Probe::Ready(_) | Probe::NeedsRestart(_) => { - request_shutdown(settings).await; - println!("Stopping daemon..."); - - let pidfile_path = PathBuf::from(&settings.daemon.pidfile_path); - let timeout = Duration::from_secs(5); - wait_for_pidfile_available(&pidfile_path, timeout) - .await - .wrap_err("Timed out waiting for old daemon to stop")?; - } - Probe::Unreachable(_) => { - println!("No daemon running"); - } - } - - #[cfg(unix)] - remove_stale_socket_if_present(settings)?; - - spawn_daemon_process()?; - println!("Starting daemon..."); - - let timeout = startup_timeout(settings); - let status = wait_until_ready(settings, timeout).await?.status().await?; - - println!("Daemon restarted"); - println!(" PID: {}", status.pid); - println!(" Version: {}", status.version); - - Ok(()) -} - -/// Daemonize the current process. Must be called before creating the tokio -/// runtime or opening database connections, since `fork()` inside an async -/// runtime corrupts its internal state. -#[cfg(unix)] -pub fn daemonize_current_process() -> Result<()> { - let cwd = - std::env::current_dir().wrap_err("could not determine current directory for daemon")?; - - Daemonize::new() - .working_directory(cwd) - .start() - .wrap_err("failed to daemonize process")?; - - Ok(()) -} - -async fn run( - settings: Settings, - store: SqliteStore, - history_db: Sqlite, - force: bool, -) -> Result<()> { - if force { - force_cleanup(&settings); - } - - let pidfile_path = PathBuf::from(&settings.daemon.pidfile_path); - let _pidfile_guard = PidfileGuard::acquire(&pidfile_path)?; - - atuin_daemon::boot(settings, store, history_db).await?; - - Ok(()) -} - -/// Force cleanup: kill existing daemon process and remove socket. -fn force_cleanup(settings: &Settings) { - let pidfile_path = Path::new(&settings.daemon.pidfile_path); - - // Read and kill the existing process if pidfile exists - if pidfile_path.exists() { - if let Ok(contents) = fs::read_to_string(pidfile_path) - && let Some(pid_str) = contents.lines().next() - && let Ok(pid) = pid_str.parse::() - { - kill_process(pid); - // Give it a moment to release resources - std::thread::sleep(Duration::from_millis(100)); - } - - // Remove the pidfile - if let Err(e) = fs::remove_file(pidfile_path) - && e.kind() != ErrorKind::NotFound - { - tracing::warn!("failed to remove pidfile: {e}"); - } - } - - // Remove the socket file - #[cfg(unix)] - { - let socket_path = Path::new(&settings.daemon.socket_path); - if socket_path.exists() - && let Err(e) = fs::remove_file(socket_path) - && e.kind() != ErrorKind::NotFound - { - tracing::warn!("failed to remove socket: {e}"); - } - } -} - -/// Kill a process by PID. -#[cfg(unix)] -fn kill_process(pid: u32) { - // Use kill command to send SIGTERM for graceful shutdown - let _ = Command::new("kill") - .args(["-TERM", &pid.to_string()]) - .stdout(Stdio::null()) - .stderr(Stdio::null()) - .status(); -} - -/// Kill a process by PID. -#[cfg(not(unix))] -fn kill_process(pid: u32) { - // On Windows, use taskkill - let _ = Command::new("taskkill") - .args(["/PID", &pid.to_string(), "/F"]) - .stdout(Stdio::null()) - .stderr(Stdio::null()) - .status(); -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_version_matches() { - assert!(daemon_matches_expected( - DAEMON_VERSION, - DAEMON_PROTOCOL_VERSION - )); - } - - #[test] - fn test_version_mismatch() { - assert!(!daemon_matches_expected("0.0.0", DAEMON_PROTOCOL_VERSION)); - assert!(!daemon_matches_expected(DAEMON_VERSION, 999)); - assert!(!daemon_matches_expected("0.0.0", 999)); - } - - #[test] - fn test_mismatch_message_version() { - let msg = daemon_mismatch_message("0.0.0", DAEMON_PROTOCOL_VERSION); - assert!(msg.contains("out of date"), "got: {msg}"); - assert!(msg.contains("0.0.0")); - assert!(msg.contains(DAEMON_VERSION)); - } - - #[test] - fn test_mismatch_message_protocol() { - let msg = daemon_mismatch_message(DAEMON_VERSION, 999); - assert!(msg.contains("protocol mismatch"), "got: {msg}"); - } - - #[test] - fn test_startup_lock_path() { - let pidfile = Path::new("/tmp/atuin-daemon.pid"); - let lock = daemon_startup_lock_path(pidfile); - assert_eq!(lock, PathBuf::from("/tmp/atuin-daemon.pid.startup.lock")); - } - - #[test] - fn test_pidfile_guard_acquire_and_drop() { - let tmp = tempfile::tempdir().unwrap(); - let pidfile = tmp.path().join("daemon.pid"); - - { - let _guard = PidfileGuard::acquire(&pidfile).unwrap(); - // Guard holds an exclusive lock — on Windows other handles cannot - // read the file, so we verify contents after the guard is dropped. - } - - let contents = std::fs::read_to_string(&pidfile).unwrap(); - let lines: Vec<&str> = contents.lines().collect(); - assert_eq!(lines.len(), 2); - assert_eq!(lines[0], std::process::id().to_string()); - assert_eq!(lines[1], DAEMON_VERSION); - - // After guard is dropped, lock should be released — acquiring again must succeed. - let _guard2 = PidfileGuard::acquire(&pidfile).unwrap(); - } - - #[test] - fn test_pidfile_guard_prevents_double_acquire() { - let tmp = tempfile::tempdir().unwrap(); - let pidfile = tmp.path().join("daemon.pid"); - - let _guard = PidfileGuard::acquire(&pidfile).unwrap(); - let result = PidfileGuard::acquire(&pidfile); - assert!(result.is_err()); - } -} diff --git a/crates/atuin/src/command/client/default_config.rs b/crates/atuin/src/command/client/default_config.rs deleted file mode 100644 index f51e45c2..00000000 --- a/crates/atuin/src/command/client/default_config.rs +++ /dev/null @@ -1,5 +0,0 @@ -use atuin_client::settings::Settings; - -pub fn run() { - println!("{}", Settings::example_config()); -} diff --git a/crates/atuin/src/command/client/doctor.rs b/crates/atuin/src/command/client/doctor.rs deleted file mode 100644 index 1bf003db..00000000 --- a/crates/atuin/src/command/client/doctor.rs +++ /dev/null @@ -1,412 +0,0 @@ -use std::process::Command; -use std::{env, str::FromStr}; - -use atuin_client::database::Sqlite; -use atuin_client::settings::Settings; -use atuin_common::shell::{Shell, shell_name}; -use atuin_common::utils; -use colored::Colorize; -use eyre::Result; -use serde::Serialize; - -use sysinfo::{Disks, System, get_current_pid}; - -#[derive(Debug, Serialize)] -struct ShellInfo { - pub name: String, - - // best-effort, not supported on all OSes - pub default: String, - - // Detect some shell plugins that the user has installed. - // I'm just going to start with preexec/blesh - pub plugins: Vec, - - // The preexec framework used in the current session, if Atuin is loaded. - pub preexec: Option, -} - -impl ShellInfo { - // HACK ALERT! - // Many of the shell vars we need to detect are not exported :( - // So, we're going to run a interactive session and directly check the - // variable. There's a chance this won't work, so it should not be fatal. - // - // Every shell we support handles `shell -ic 'command'` - fn shellvar_exists(shell: &str, var: &str) -> bool { - let cmd = Command::new(shell) - .args([ - "-ic", - format!("[ -z ${var} ] || echo ATUIN_DOCTOR_ENV_FOUND").as_str(), - ]) - .output() - .map_or(String::new(), |v| { - let out = v.stdout; - String::from_utf8(out).unwrap_or_default() - }); - - cmd.contains("ATUIN_DOCTOR_ENV_FOUND") - } - - fn detect_preexec_framework(shell: &str) -> Option { - if env::var("ATUIN_SESSION").ok().is_none() { - None - } else if shell.starts_with("bash") || shell == "sh" { - env::var("ATUIN_PREEXEC_BACKEND") - .ok() - .filter(|value| !value.is_empty()) - .and_then(|atuin_preexec_backend| { - atuin_preexec_backend.rfind(':').and_then(|pos_colon| { - u32::from_str(&atuin_preexec_backend[..pos_colon]) - .ok() - .is_some_and(|preexec_shlvl| { - env::var("SHLVL") - .ok() - .and_then(|shlvl| u32::from_str(&shlvl).ok()) - .is_some_and(|shlvl| shlvl == preexec_shlvl) - }) - .then(|| atuin_preexec_backend[pos_colon + 1..].to_string()) - }) - }) - } else { - Some("built-in".to_string()) - } - } - - fn validate_plugin_blesh( - _shell: &str, - shell_process: &sysinfo::Process, - ble_session_id: &str, - ) -> Option { - ble_session_id - .split('/') - .nth(1) - .and_then(|field| u32::from_str(field).ok()) - .filter(|&blesh_pid| blesh_pid == shell_process.pid().as_u32()) - .map(|_| "blesh".to_string()) - } - - pub fn plugins(shell: &str, shell_process: &sysinfo::Process) -> Vec { - // consider a different detection approach if there are plugins - // that don't set shell vars - - enum PluginShellType { - Any, - Bash, - - // Note: these are currently unused - #[expect(dead_code)] - Zsh, - #[expect(dead_code)] - Fish, - #[expect(dead_code)] - Nushell, - #[expect(dead_code)] - Xonsh, - } - - enum PluginProbeType { - EnvironmentVariable(&'static str), - InteractiveShellVariable(&'static str), - } - - type PluginValidator = fn(&str, &sysinfo::Process, &str) -> Option; - - let plugin_list: [( - &str, - PluginShellType, - PluginProbeType, - Option, - ); 3] = [ - ( - "atuin", - PluginShellType::Any, - PluginProbeType::EnvironmentVariable("ATUIN_SESSION"), - None, - ), - ( - "blesh", - PluginShellType::Bash, - PluginProbeType::EnvironmentVariable("BLE_SESSION_ID"), - Some(Self::validate_plugin_blesh), - ), - ( - "bash-preexec", - PluginShellType::Bash, - PluginProbeType::InteractiveShellVariable("bash_preexec_imported"), - None, - ), - ]; - - plugin_list - .into_iter() - .filter(|(_, shell_type, _, _)| match shell_type { - PluginShellType::Any => true, - PluginShellType::Bash => shell.starts_with("bash") || shell == "sh", - PluginShellType::Zsh => shell.starts_with("zsh"), - PluginShellType::Fish => shell.starts_with("fish"), - PluginShellType::Nushell => shell.starts_with("nu"), - PluginShellType::Xonsh => shell.starts_with("xonsh"), - }) - .filter_map(|(plugin, _, probe_type, validator)| -> Option { - match probe_type { - PluginProbeType::EnvironmentVariable(env) => { - env::var(env).ok().filter(|value| !value.is_empty()) - } - PluginProbeType::InteractiveShellVariable(shellvar) => { - ShellInfo::shellvar_exists(shell, shellvar).then_some(String::default()) - } - } - .and_then(|value| { - validator.map_or_else( - || Some(plugin.to_string()), - |validator| validator(shell, shell_process, &value), - ) - }) - }) - .collect() - } - - pub fn new() -> Self { - // TODO: rework to use atuin_common::Shell - - let sys = System::new_all(); - - let process = sys - .process(get_current_pid().expect("Failed to get current PID")) - .expect("Process with current pid does not exist"); - - let parent = sys - .process(process.parent().expect("Atuin running with no parent!")) - .expect("Process with parent pid does not exist"); - - let name = shell_name(Some(parent)); - - let plugins = ShellInfo::plugins(name.as_str(), parent); - - let default = Shell::default_shell().unwrap_or(Shell::Unknown).to_string(); - - let preexec = Self::detect_preexec_framework(name.as_str()); - - Self { - name, - default, - plugins, - preexec, - } - } -} - -#[derive(Debug, Serialize)] -struct DiskInfo { - pub name: String, - pub filesystem: String, -} - -#[derive(Debug, Serialize)] -struct SystemInfo { - pub os: String, - - pub arch: String, - - pub version: String, - pub disks: Vec, -} - -impl SystemInfo { - pub fn new() -> Self { - let disks = Disks::new_with_refreshed_list(); - let disks = disks - .list() - .iter() - .map(|d| DiskInfo { - name: d.name().to_os_string().into_string().unwrap(), - filesystem: d.file_system().to_os_string().into_string().unwrap(), - }) - .collect(); - - Self { - os: System::name().unwrap_or_else(|| "unknown".to_string()), - arch: System::cpu_arch().unwrap_or_else(|| "unknown".to_string()), - version: System::os_version().unwrap_or_else(|| "unknown".to_string()), - disks, - } - } -} - -#[derive(Debug, Serialize)] -struct SyncInfo { - pub auth_state: String, - pub auto_sync: bool, - - pub last_sync: String, -} - -impl SyncInfo { - pub async fn new(settings: &Settings) -> Self { - // Build auth state description from raw token state without calling - // resolve_sync_auth(), which has side effects (token migration cleanup) - // that a diagnostic command should not trigger. - let meta = Settings::meta_store().await.ok(); - let has_cli_token = match &meta { - Some(m) => m.session_token().await.ok().flatten().is_some(), - None => false, - }; - - let auth_state = if has_cli_token { - "Self-hosted (authenticated)".into() - } else { - "Not authenticated".into() - }; - - Self { - auth_state, - auto_sync: settings.auto_sync, - last_sync: Settings::last_sync() - .await - .map_or_else(|_| "no last sync".to_string(), |v| v.to_string()), - } - } -} - -#[derive(Debug)] -struct SettingPaths { - db: String, - record_store: String, - key: String, -} - -impl SettingPaths { - pub fn new(settings: &Settings) -> Self { - Self { - db: settings.db_path.clone(), - record_store: settings.record_store_path.clone(), - key: settings.key_path.clone(), - } - } - - pub fn verify(&self) { - let paths = vec![ - ("ATUIN_DB_PATH", &self.db), - ("ATUIN_RECORD_STORE", &self.record_store), - ("ATUIN_KEY", &self.key), - ]; - - for (path_env_var, path) in paths { - if utils::broken_symlink(path) { - eprintln!( - "{path} (${path_env_var}) is a broken symlink. This may cause issues with Atuin." - ); - } - } - } -} - -#[derive(Debug, Serialize)] -struct AtuinInfo { - pub version: String, - pub commit: String, - - /// Whether the main Atuin sync server is in use - /// I'm just calling it Atuin Cloud for lack of a better name atm - pub sync: Option, - - pub sqlite_version: String, - - #[serde(skip)] // probably unnecessary to expose this - pub setting_paths: SettingPaths, -} - -impl AtuinInfo { - pub async fn new(settings: &Settings) -> Self { - let logged_in = settings.logged_in().await.unwrap_or(false); - - let sync = if logged_in { - Some(SyncInfo::new(settings).await) - } else { - None - }; - - let sqlite_version = match Sqlite::new("sqlite::memory:", 0.1).await { - Ok(db) => db - .sqlite_version() - .await - .unwrap_or_else(|_| "unknown".to_string()), - Err(_) => "error".to_string(), - }; - - Self { - version: crate::VERSION.to_string(), - commit: crate::SHA.to_string(), - sync, - sqlite_version, - setting_paths: SettingPaths::new(settings), - } - } -} - -#[derive(Debug, Serialize)] -struct DoctorDump { - pub atuin: AtuinInfo, - pub shell: ShellInfo, - pub system: SystemInfo, -} - -impl DoctorDump { - pub async fn new(settings: &Settings) -> Self { - Self { - atuin: AtuinInfo::new(settings).await, - shell: ShellInfo::new(), - system: SystemInfo::new(), - } - } -} - -fn checks(info: &DoctorDump) { - println!(); // spacing - // - let zfs_error = "[Filesystem] ZFS is known to have some issues with SQLite. Atuin uses SQLite heavily. If you are having poor performance, there are some workarounds here: https://github.com/atuinsh/atuin/issues/952".bold().red(); - let bash_plugin_error = "[Shell] If you are using Bash, Atuin requires that either bash-preexec or ble.sh (>= 0.4) be installed. An older ble.sh may not be detected. so ignore this if you have ble.sh >= 0.4 set up! Read more here: https://docs.atuin.sh/guide/installation/#bash".bold().red(); - let blesh_integration_error = "[Shell] Atuin and ble.sh seem to be loaded in the session, but the integration does not seem to be working. Please check the setup in .bashrc.".bold().red(); - - // ZFS: https://github.com/atuinsh/atuin/issues/952 - if info.system.disks.iter().any(|d| d.filesystem == "zfs") { - println!("{zfs_error}"); - } - - info.atuin.setting_paths.verify(); - - // Shell - if info.shell.name == "bash" { - if !info - .shell - .plugins - .iter() - .any(|p| p == "blesh" || p == "bash-preexec") - { - println!("{bash_plugin_error}"); - } - - if info.shell.plugins.iter().any(|plugin| plugin == "atuin") - && info.shell.plugins.iter().any(|plugin| plugin == "blesh") - && info.shell.preexec.as_ref().is_some_and(|val| val == "none") - { - println!("{blesh_integration_error}"); - } - } -} - -pub async fn run(settings: &Settings) -> Result<()> { - println!("{}", "Atuin Doctor".bold()); - println!("Checking for diagnostics"); - let dump = DoctorDump::new(settings).await; - - checks(&dump); - - let dump = serde_json::to_string_pretty(&dump)?; - - println!("\nPlease include the output below with any bug reports or issues\n"); - println!("{dump}"); - - Ok(()) -} diff --git a/crates/atuin/src/command/client/history.rs b/crates/atuin/src/command/client/history.rs deleted file mode 100644 index abf39cc2..00000000 --- a/crates/atuin/src/command/client/history.rs +++ /dev/null @@ -1,1337 +0,0 @@ -use std::{ - fmt::{self, Display}, - io::{self, IsTerminal, Write}, - path::PathBuf, - time::Duration, -}; - -use atuin_common::utils::{self, Escapable as _}; -use clap::Subcommand; -use eyre::{Context, Result, bail}; -use runtime_format::{FormatKey, FormatKeyError, ParseSegment, ParsedFmt}; - -#[cfg(feature = "daemon")] -use super::daemon as daemon_cmd; -#[cfg(feature = "daemon")] -use colored::Colorize; -#[cfg(feature = "daemon")] -use serde::Serialize; - -#[cfg(feature = "daemon")] -use atuin_daemon::history::{HistoryEventKind, TailHistoryReply}; - -use atuin_client::{ - database::{Database, Sqlite, current_context}, - encryption, - history::{History, store::HistoryStore}, - record::sqlite_store::SqliteStore, - settings::{ - FilterMode::{Directory, Global, Session}, - Settings, Timezone, - }, -}; - -#[cfg(feature = "sync")] -use atuin_client::{record, sync}; - -use log::{debug, warn}; -use time::{OffsetDateTime, macros::format_description}; - -#[cfg(feature = "daemon")] -use super::daemon; -use super::search::format_duration_into; - -#[derive(Subcommand, Debug)] -#[command(infer_subcommands = true)] -pub enum Cmd { - /// Begins a new command in the history - Start { - /// Collects the command from the `ATUIN_COMMAND_LINE` environment variable, - /// which does not need escaping and is more compatible between OS and shells - #[arg(long = "command-from-env", hide = true)] - cmd_env: bool, - - /// Author of this command, eg `ellie`, `claude`, or `copilot` - #[arg(long)] - author: Option, - - /// Optional intent/rationale for running this command - #[arg(long)] - intent: Option, - - command: Vec, - }, - - /// Finishes a new command in the history (adds time, exit code) - End { - id: String, - #[arg(long, short)] - exit: i64, - #[arg(long, short)] - duration: Option, - }, - - /// Stream history events from the daemon as they are received - Tail, - - /// List all items in history - List { - #[arg(long, short)] - cwd: bool, - - #[arg(long, short)] - session: bool, - - #[arg(long)] - human: bool, - - /// Show only the text of the command - #[arg(long)] - cmd_only: bool, - - /// Terminate the output with a null, for better multiline support - #[arg(long)] - print0: bool, - - #[arg(long, short, default_value = "true")] - // accept no value - #[arg(num_args(0..=1), default_missing_value("true"))] - // accept a value - #[arg(action = clap::ArgAction::Set)] - reverse: bool, - - /// Display the command time in another timezone other than the configured default. - /// - /// This option takes one of the following kinds of values: - /// - the special value "local" (or "l") which refers to the system time zone - /// - an offset from UTC (e.g. "+9", "-2:30") - #[arg(long, visible_alias = "tz")] - timezone: Option, - - /// Available variables: {command}, {directory}, {duration}, {user}, {host}, {author}, {intent}, {exit}, {time}, {session}, and {uuid} - /// Example: --format "{time} - [{duration}] - {directory}$\t{command}" - #[arg(long, short)] - format: Option, - }, - - /// Get the last command ran - Last { - #[arg(long)] - human: bool, - - /// Show only the text of the command - #[arg(long)] - cmd_only: bool, - - /// Display the command time in another timezone other than the configured default. - /// - /// This option takes one of the following kinds of values: - /// - the special value "local" (or "l") which refers to the system time zone - /// - an offset from UTC (e.g. "+9", "-2:30") - #[arg(long, visible_alias = "tz")] - timezone: Option, - - /// Available variables: {command}, {directory}, {duration}, {user}, {host}, {author}, {intent}, {time}, {session}, {uuid} and {relativetime}. - /// Example: --format "{time} - [{duration}] - {directory}$\t{command}" - #[arg(long, short)] - format: Option, - }, - - InitStore, - - /// Delete history entries matching the configured exclusion filters - Prune { - /// List matching history lines without performing the actual deletion. - #[arg(short = 'n', long)] - dry_run: bool, - }, - - /// Delete duplicate history entries (that have the same command, cwd and hostname) - Dedup { - /// List matching history lines without performing the actual deletion. - #[arg(short = 'n', long)] - dry_run: bool, - - /// Only delete results added before this date - #[arg(long, short)] - before: String, - - /// How many recent duplicates to keep - #[arg(long)] - dupkeep: u32, - }, -} - -#[derive(Clone, Copy, Debug)] -pub enum ListMode { - Human, - CmdOnly, - Regular, -} - -impl ListMode { - pub const fn from_flags(human: bool, cmd_only: bool) -> Self { - if human { - ListMode::Human - } else if cmd_only { - ListMode::CmdOnly - } else { - ListMode::Regular - } - } -} - -#[expect(clippy::cast_sign_loss)] -pub fn print_list( - h: &[History], - list_mode: ListMode, - format: Option<&str>, - print0: bool, - reverse: bool, - tz: Timezone, -) { - let w = std::io::stdout(); - let mut w = w.lock(); - - let fmt_str = match list_mode { - ListMode::Human => format - .unwrap_or("{time} · {duration}\t{command}") - .replace("\\t", "\t"), - ListMode::Regular => format - .unwrap_or("{time}\t{command}\t{duration}") - .replace("\\t", "\t"), - // not used - ListMode::CmdOnly => String::new(), - }; - - let parsed_fmt = match list_mode { - ListMode::Human | ListMode::Regular => parse_fmt(&fmt_str), - ListMode::CmdOnly => std::iter::once(ParseSegment::Key("command")).collect(), - }; - - let iterator = if reverse { - Box::new(h.iter().rev()) as Box> - } else { - Box::new(h.iter()) as Box> - }; - - let entry_terminator = if print0 { "\0" } else { "\n" }; - let flush_each_line = print0; - - for history in iterator { - let fh = FmtHistory { - history, - cmd_format: CmdFormat::for_output(&w), - tz: &tz, - }; - let args = parsed_fmt.with_args(&fh); - - // Check for formatting errors before attempting to write - if let Err(err) = args.status() { - eprintln!("ERROR: history output failed with: {err}"); - std::process::exit(1); - } - - let write_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { - write!(w, "{args}{entry_terminator}") - })); - - match write_result { - Ok(Ok(())) => { - // Write succeeded - } - Ok(Err(err)) => { - if err.kind() != io::ErrorKind::BrokenPipe { - eprintln!("ERROR: Failed to write history output: {err}"); - std::process::exit(1); - } - } - Err(_) => { - eprintln!("ERROR: Format string caused a formatting error."); - eprintln!( - "This may be due to an unsupported format string containing special characters." - ); - eprintln!( - "Please check your format string syntax and ensure literal braces are properly escaped." - ); - std::process::exit(1); - } - } - if flush_each_line { - check_for_write_errors(w.flush()); - } - } - - if !flush_each_line { - check_for_write_errors(w.flush()); - } -} - -fn check_for_write_errors(write: Result<(), io::Error>) { - if let Err(err) = write { - // Ignore broken pipe (issue #626) - if err.kind() != io::ErrorKind::BrokenPipe { - eprintln!("ERROR: History output failed with the following error: {err}"); - std::process::exit(1); - } - } -} - -/// Type wrapper around `History` with formatting settings. -#[derive(Clone, Copy, Debug)] -struct FmtHistory<'a> { - history: &'a History, - cmd_format: CmdFormat, - tz: &'a Timezone, -} - -#[derive(Clone, Copy, Debug)] -enum CmdFormat { - Literal, - Escaped, -} -impl CmdFormat { - fn for_output(out: &O) -> Self { - if out.is_terminal() { - Self::Escaped - } else { - Self::Literal - } - } -} - -static TIME_FMT: &[time::format_description::FormatItem<'static>] = - format_description!("[year]-[month]-[day] [hour repr:24]:[minute]:[second]"); - -/// defines how to format the history -impl FormatKey for FmtHistory<'_> { - #[expect(clippy::cast_sign_loss)] - fn fmt(&self, key: &str, f: &mut fmt::Formatter<'_>) -> Result<(), FormatKeyError> { - match key { - "command" => match self.cmd_format { - CmdFormat::Literal => f.write_str(self.history.command.trim()), - CmdFormat::Escaped => f.write_str(&self.history.command.trim().escape_control()), - }?, - "directory" => f.write_str(self.history.cwd.trim())?, - "exit" => f.write_str(&self.history.exit.to_string())?, - "duration" => { - let dur = Duration::from_nanos(std::cmp::max(self.history.duration, 0) as u64); - format_duration_into(dur, f)?; - } - "time" => { - self.history - .timestamp - .to_offset(self.tz.0) - .format(TIME_FMT) - .map_err(|_| fmt::Error)? - .fmt(f)?; - } - "relativetime" => { - let since = OffsetDateTime::now_utc() - self.history.timestamp; - let d = Duration::try_from(since).unwrap_or_default(); - format_duration_into(d, f)?; - } - "host" => f.write_str( - self.history - .hostname - .split_once(':') - .map_or(&self.history.hostname, |(host, _)| host), - )?, - "author" => f.write_str(&self.history.author)?, - "intent" => f.write_str(self.history.intent.as_deref().unwrap_or_default())?, - "user" => f.write_str( - self.history - .hostname - .split_once(':') - .map_or("", |(_, user)| user), - )?, - "session" => f.write_str(&self.history.session)?, - "uuid" => f.write_str(&self.history.id.0)?, - _ => return Err(FormatKeyError::UnknownKey), - } - Ok(()) - } -} - -fn parse_fmt(format: &str) -> ParsedFmt<'_> { - match ParsedFmt::new(format) { - Ok(fmt) => fmt, - Err(err) => { - eprintln!("ERROR: History formatting failed with the following error: {err}"); - - if format.contains('"') && (format.contains(":{") || format.contains(",{")) { - eprintln!("It looks like you're trying to create JSON output."); - eprintln!("For JSON, you need to escape literal braces by doubling them:"); - eprintln!("Example: '{{\"command\":\"{{command}}\",\"time\":\"{{time}}\"}}'"); - } else { - eprintln!( - "If your formatting string contains literal curly braces, you need to escape them by doubling:" - ); - eprintln!("Use {{{{ for literal {{ and }}}} for literal }}"); - } - std::process::exit(1) - } - } -} - -fn apply_start_metadata(history: &mut History, author: Option<&str>, intent: Option<&str>) { - if let Some(author) = author.map(str::trim).filter(|author| !author.is_empty()) { - author.clone_into(&mut history.author); - } - - if let Some(intent) = intent.map(str::trim).filter(|intent| !intent.is_empty()) { - history.intent = Some(intent.to_owned()); - } else if intent.is_some() { - history.intent = None; - } -} - -fn normalize_command_for_storage<'a>(command: &'a str, settings: &Settings) -> &'a str { - if !settings.strip_trailing_whitespace { - return command; - } - - let trimmed = command.trim_end_matches([' ', '\t']); - if trimmed.len() == command.len() { - return command; - } - - let trailing_backslashes = trimmed - .as_bytes() - .iter() - .rev() - .take_while(|&&byte| byte == b'\\') - .count(); - - if trailing_backslashes % 2 == 1 { - command - } else { - trimmed - } -} - -async fn handle_start( - db: &impl Database, - settings: &Settings, - command: &str, - author: Option<&str>, - intent: Option<&str>, -) -> Result> { - // It's better for atuin to silently fail here and attempt to - // store whatever is ran, than to throw an error to the terminal - let cwd = utils::get_current_dir(); - let command = normalize_command_for_storage(command, settings); - - let mut h: History = History::capture() - .timestamp(OffsetDateTime::now_utc()) - .command(command) - .cwd(cwd) - .build() - .into(); - apply_start_metadata(&mut h, author, intent); - - if !h.should_save(settings) { - return Ok(None); - } - - let id = h.id.0.clone(); - - // Silently ignore database errors to avoid breaking the shell - // This is important when disk is full or database is locked - if let Err(e) = db.save(&h).await { - debug!("failed to save history: {e}"); - } - - Ok(Some(id)) -} - -#[cfg(feature = "daemon")] -async fn handle_daemon_start( - settings: &Settings, - command: &str, - author: Option<&str>, - intent: Option<&str>, -) -> Result> { - // It's better for atuin to silently fail here and attempt to - // store whatever is ran, than to throw an error to the terminal - let cwd = utils::get_current_dir(); - let command = normalize_command_for_storage(command, settings); - - let mut h: History = History::capture() - .timestamp(OffsetDateTime::now_utc()) - .command(command) - .cwd(cwd) - .build() - .into(); - apply_start_metadata(&mut h, author, intent); - - if !h.should_save(settings) { - return Ok(None); - } - - // Attempt to start history via daemon, but silently ignore errors - // to avoid breaking the shell when the daemon is unavailable or disk is full - let resp = match daemon::start_history(settings, h.clone()).await { - Ok(id) => id, - Err(e) => { - debug!("failed to start history via daemon: {e}"); - h.id.0.clone() - } - }; - - Ok(Some(resp)) -} - -#[expect(unused_variables)] -async fn handle_end( - db: &impl Database, - store: SqliteStore, - history_store: HistoryStore, - settings: &Settings, - id: &str, - exit: i64, - duration: Option, -) -> Result<()> { - if id.trim() == "" { - return Ok(()); - } - - let Some(mut h) = db.load(id).await? else { - warn!("history entry is missing"); - return Ok(()); - }; - - if h.duration > 0 { - debug!("cannot end history - already has duration"); - - // returning OK as this can occur if someone Ctrl-c a prompt - return Ok(()); - } - - if !settings.store_failed && exit > 0 { - debug!("history has non-zero exit code, and store_failed is false"); - - // the history has already been inserted half complete. remove it - db.delete(h).await?; - - return Ok(()); - } - - h.exit = exit; - h.duration = match duration { - Some(value) => i64::try_from(value).context("command took over 292 years")?, - None => i64::try_from((OffsetDateTime::now_utc() - h.timestamp).whole_nanoseconds()) - .context("command took over 292 years")?, - }; - - db.update(&h).await?; - history_store.push(h).await?; - - if settings.should_sync().await? { - let (_, downloaded) = - record::sync::sync(settings, &store, &history_store.encryption_key).await?; - Settings::save_sync_time().await?; - - crate::sync::build(settings, &store, db, Some(&downloaded)).await?; - } else { - debug!("sync disabled! not syncing"); - } - - Ok(()) -} - -#[cfg(feature = "daemon")] -async fn handle_daemon_end( - settings: &Settings, - id: &str, - exit: i64, - duration: Option, -) -> Result<()> { - daemon::end_history(settings, id.to_string(), duration.unwrap_or(0), exit).await?; - - Ok(()) -} - -pub(super) async fn start_history_entry( - settings: &Settings, - command: &str, - author: Option<&str>, - intent: Option<&str>, -) -> Result> { - #[cfg(feature = "daemon")] - if settings.daemon.enabled { - return handle_daemon_start(settings, command, author, intent).await; - } - - let db_path = PathBuf::from(settings.db_path.as_str()); - let db = Sqlite::new(db_path, settings.local_timeout).await?; - handle_start(&db, settings, command, author, intent).await -} - -pub(super) async fn end_history_entry( - settings: &Settings, - id: &str, - exit: i64, - duration: Option, -) -> Result<()> { - #[cfg(feature = "daemon")] - if settings.daemon.enabled { - return handle_daemon_end(settings, id, exit, duration).await; - } - - let db_path = PathBuf::from(settings.db_path.as_str()); - let record_store_path = PathBuf::from(settings.record_store_path.as_str()); - - let db = Sqlite::new(db_path, settings.local_timeout).await?; - let store = SqliteStore::new(record_store_path, settings.local_timeout).await?; - - let encryption_key: [u8; 32] = encryption::load_key(settings) - .context("could not load encryption key")? - .into(); - let host_id = Settings::host_id().await?; - let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); - - handle_end(&db, store, history_store, settings, id, exit, duration).await -} - -#[cfg(feature = "daemon")] -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -enum TailKind { - Started, - Ended, -} - -#[cfg(feature = "daemon")] -#[derive(Clone, Debug, Eq, PartialEq)] -struct TailEvent { - kind: TailKind, - history: History, -} - -#[cfg(feature = "daemon")] -#[derive(Serialize)] -struct TailJsonEvent<'a> { - event: &'static str, - history: TailJsonHistory<'a>, -} - -#[cfg(feature = "daemon")] -#[derive(Serialize)] -struct TailJsonHistory<'a> { - id: &'a str, - timestamp: String, - timestamp_unix_ns: u64, - command: &'a str, - cwd: &'a str, - session: &'a str, - hostname: &'a str, - host: &'a str, - user: &'a str, - author: &'a str, - #[serde(skip_serializing_if = "Option::is_none")] - intent: Option<&'a str>, - #[serde(skip_serializing_if = "Option::is_none")] - exit: Option, - #[serde(skip_serializing_if = "Option::is_none")] - duration_ns: Option, - #[serde(skip_serializing_if = "Option::is_none")] - duration: Option, - #[serde(skip_serializing_if = "Option::is_none")] - success: Option, - #[serde(skip_serializing_if = "Option::is_none")] - finished_at: Option, -} - -#[cfg(feature = "daemon")] -impl TailEvent { - fn from_proto(reply: TailHistoryReply) -> Result { - let history = reply - .history - .ok_or_else(|| eyre::eyre!("daemon sent a history tail event without history"))?; - let timestamp = OffsetDateTime::from_unix_timestamp_nanos(i128::from(history.timestamp)) - .context("invalid daemon history timestamp")?; - let kind = match HistoryEventKind::try_from(reply.kind) - .unwrap_or(HistoryEventKind::Unspecified) - { - HistoryEventKind::Started => TailKind::Started, - HistoryEventKind::Ended => TailKind::Ended, - HistoryEventKind::Unspecified => bail!("daemon sent an unspecified history tail event"), - }; - - Ok(Self { - kind, - history: History { - id: history.id.into(), - timestamp, - duration: history.duration, - exit: history.exit, - command: history.command, - cwd: history.cwd, - session: history.session, - hostname: history.hostname, - author: history.author, - intent: normalize_optional_field(&history.intent), - deleted_at: None, - }, - }) - } - - fn render(&self, tty: bool, tz: Timezone) -> Result { - if tty { - Ok(self.render_pretty(tz)) - } else { - let mut json = self.render_json(tz)?; - json.push('\n'); - Ok(json) - } - } - - fn render_json(&self, tz: Timezone) -> Result { - let payload = TailJsonEvent { - event: self.kind.as_str(), - history: TailJsonHistory { - id: &self.history.id.0, - timestamp: format_history_time(self.history.timestamp, tz)?, - timestamp_unix_ns: u64::try_from(self.history.timestamp.unix_timestamp_nanos()) - .context("history timestamp predates unix epoch")?, - command: &self.history.command, - cwd: &self.history.cwd, - session: &self.history.session, - hostname: &self.history.hostname, - host: self.host(), - user: self.user(), - author: &self.history.author, - intent: self.history.intent.as_deref(), - exit: self.exit_value(), - duration_ns: self.duration_value(), - duration: self.duration_value().map(format_duration_ns), - success: self.success_value(), - finished_at: self - .finished_at() - .map(|time| format_history_time(time, tz)) - .transpose()?, - }, - }; - - Ok(serde_json::to_string(&payload)?) - } - - fn render_pretty(&self, tz: Timezone) -> String { - let mut out = String::new(); - let border = match self.kind { - TailKind::Started => "-".repeat(72).bright_blue().to_string(), - TailKind::Ended if self.history.exit == 0 => "-".repeat(72).bright_green().to_string(), - TailKind::Ended => "-".repeat(72).bright_red().to_string(), - }; - - out.push_str(&border); - out.push('\n'); - - let command = self.history.command.trim(); - let escaped_command = command.escape_control(); - let mut command_lines = escaped_command.lines(); - let header = format!( - "{} {}", - self.kind.badge(self.history.exit), - command_lines.next().unwrap_or_default().bold() - ); - out.push_str(&header); - out.push('\n'); - - for line in command_lines { - out.push_str(" "); - out.push_str(line); - out.push('\n'); - } - - push_pretty_field( - &mut out, - "start", - &format_history_time(self.history.timestamp, tz) - .unwrap_or_else(|_| "invalid".to_owned()), - ); - push_pretty_field(&mut out, "history", &self.history.id.0); - push_pretty_field(&mut out, "session", &self.history.session); - push_pretty_field(&mut out, "exit", &self.exit_display()); - push_pretty_field(&mut out, "duration", &self.duration_display()); - - out.push('\n'); - - push_pretty_field(&mut out, "cwd", &self.history.cwd); - push_pretty_field(&mut out, "hostname", &self.history.hostname); - push_pretty_field(&mut out, "host", self.host()); - push_pretty_field(&mut out, "user", self.user()); - push_pretty_field(&mut out, "author", &self.history.author); - - if let Some(intent) = self.history.intent.as_deref() { - push_pretty_field(&mut out, "intent", intent); - } - - if let Some(finished) = self.finished_at() { - let finished = - format_history_time(finished, tz).unwrap_or_else(|_| "invalid".to_owned()); - push_pretty_field(&mut out, "finished", &finished); - } - - out.push_str(&border); - out.push_str("\n\n"); - out - } - - fn host(&self) -> &str { - self.history - .hostname - .split_once(':') - .map_or(self.history.hostname.as_str(), |(host, _)| host) - } - - fn user(&self) -> &str { - self.history - .hostname - .split_once(':') - .map_or("", |(_, user)| user) - } - - fn exit_value(&self) -> Option { - matches!(self.kind, TailKind::Ended).then_some(self.history.exit) - } - - fn duration_value(&self) -> Option { - matches!(self.kind, TailKind::Ended).then_some(self.history.duration) - } - - fn success_value(&self) -> Option { - matches!(self.kind, TailKind::Ended).then_some(self.history.exit == 0) - } - - fn finished_at(&self) -> Option { - self.duration_value() - .filter(|duration| *duration >= 0) - .map(time::Duration::nanoseconds) - .and_then(|duration| self.history.timestamp.checked_add(duration)) - } - - fn exit_display(&self) -> String { - match self.exit_value() { - Some(0) => "0 (success)".bright_green().to_string(), - Some(code) => format!("{code} (failure)").bright_red().to_string(), - None => "pending".bright_yellow().to_string(), - } - } - - fn duration_display(&self) -> String { - match self.duration_value() { - Some(duration) if duration >= 0 => format_duration_ns(duration), - Some(_) => "unknown".bright_yellow().to_string(), - None => "running".bright_yellow().to_string(), - } - } -} - -#[cfg(feature = "daemon")] -impl TailKind { - const fn as_str(self) -> &'static str { - match self { - Self::Started => "started", - Self::Ended => "ended", - } - } - - fn badge(self, exit: i64) -> colored::ColoredString { - match self { - Self::Started => "STARTED".bold().bright_blue(), - Self::Ended if exit == 0 => "ENDED".bold().bright_green(), - Self::Ended => "ENDED".bold().bright_red(), - } - } -} - -#[cfg(feature = "daemon")] -fn format_history_time(timestamp: OffsetDateTime, tz: Timezone) -> Result { - Ok(timestamp.to_offset(tz.0).format(TIME_FMT)?) -} - -#[cfg(feature = "daemon")] -fn format_duration_ns(duration_ns: i64) -> String { - struct F(Duration); - impl Display for F { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - format_duration_into(self.0, f) - } - } - - F(Duration::from_nanos(duration_ns.max(0).cast_unsigned())).to_string() -} - -#[cfg(feature = "daemon")] -fn push_pretty_field(out: &mut String, label: &str, value: &str) { - out.push_str(" "); - let label = format!("{label}:"); - out.push_str(&label.bright_cyan().bold().to_string()); - if label.len() < 10 { - out.push_str(&" ".repeat(10 - label.len())); - } - - let mut lines = value.lines(); - if let Some(first) = lines.next() { - out.push_str(first); - } - out.push('\n'); - - for line in lines { - out.push_str(" "); - out.push_str(line); - out.push('\n'); - } -} - -#[cfg(feature = "daemon")] -fn normalize_optional_field(value: &str) -> Option { - let trimmed = value.trim(); - if trimmed.is_empty() { - None - } else { - Some(trimmed.to_owned()) - } -} - -impl Cmd { - #[cfg(feature = "daemon")] - async fn handle_tail(settings: &Settings) -> Result<()> { - let tty = std::io::stdout().is_terminal(); - let mut client = daemon::tail_client(settings).await?; - let mut stream = client.tail_history().await?; - let stdout = std::io::stdout(); - - while let Some(reply) = stream.message().await? { - let event = TailEvent::from_proto(reply)?; - let rendered = event.render(tty, settings.timezone)?; - let mut out = stdout.lock(); - - match out.write_all(rendered.as_bytes()) { - Ok(()) => out.flush()?, - Err(err) if err.kind() == io::ErrorKind::BrokenPipe => break, - Err(err) => return Err(err.into()), - } - } - - Ok(()) - } - - #[expect(clippy::too_many_lines, clippy::cast_possible_truncation)] - #[expect(clippy::too_many_arguments)] - #[expect(clippy::fn_params_excessive_bools)] - async fn handle_list( - db: &impl Database, - settings: &Settings, - context: atuin_client::database::Context, - session: bool, - cwd: bool, - mode: ListMode, - format: Option, - include_deleted: bool, - print0: bool, - reverse: bool, - tz: Timezone, - ) -> Result<()> { - let filters = match (session, cwd) { - (true, true) => [Session, Directory], - (true, false) => [Session, Global], - (false, true) => [Global, Directory], - (false, false) => [ - settings.default_filter_mode(context.git_root.is_some()), - Global, - ], - }; - - let history = db - .list(&filters, &context, None, false, include_deleted) - .await?; - - print_list( - &history, - mode, - match format { - None => Some(settings.history_format.as_str()), - _ => format.as_deref(), - }, - print0, - reverse, - tz, - ); - - Ok(()) - } - - async fn handle_prune( - db: &impl Database, - settings: &Settings, - store: SqliteStore, - context: atuin_client::database::Context, - dry_run: bool, - ) -> Result<()> { - // Grab all executed commands and filter them using History::should_save. - // We could iterate or paginate here if memory usage becomes an issue. - let matches: Vec = db - .list(&[Global], &context, None, false, false) - .await? - .into_iter() - .filter(|h| !h.should_save(settings)) - .collect(); - - match matches.len() { - 0 => { - println!("No entries to prune."); - return Ok(()); - } - 1 => println!("Found 1 entry to prune."), - n => println!("Found {n} entries to prune."), - } - - if dry_run { - print_list( - &matches, - ListMode::Human, - Some(settings.history_format.as_str()), - false, - false, - settings.timezone, - ); - } else { - let encryption_key: [u8; 32] = encryption::load_key(settings) - .context("could not load encryption key")? - .into(); - let host_id = Settings::host_id().await?; - let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); - - for entry in matches { - eprintln!("deleting {}", entry.id); - let (id, _) = history_store.delete(entry.id.clone()).await?; - history_store.incremental_build(db, &[id]).await?; - } - - #[cfg(feature = "daemon")] - daemon_cmd::emit_event(settings, atuin_daemon::DaemonEvent::HistoryPruned).await; - } - Ok(()) - } - - async fn handle_dedup( - db: &impl Database, - settings: &Settings, - store: SqliteStore, - before: i64, - dupkeep: u32, - dry_run: bool, - ) -> Result<()> { - if dupkeep == 0 { - eprintln!( - "\"--dupkeep 0\" would keep 0 copies of duplicate commands and thus delete all of them! Use \"atuin search --delete ...\" if you really want that." - ); - std::process::exit(1); - } - - let matches: Vec = db.get_dups(before, dupkeep).await?; - - match matches.len() { - 0 => { - println!("No duplicates to delete."); - return Ok(()); - } - 1 => println!("Found 1 duplicate to delete."), - n => println!("Found {n} duplicates to delete."), - } - - if dry_run { - print_list( - &matches, - ListMode::Human, - Some(settings.history_format.as_str()), - false, - false, - settings.timezone, - ); - } else { - let encryption_key: [u8; 32] = encryption::load_key(settings) - .context("could not load encryption key")? - .into(); - let host_id = Settings::host_id().await?; - let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); - - #[cfg(feature = "daemon")] - let ids = matches.iter().map(|h| h.id.clone()).collect::>(); - - for entry in matches { - eprintln!("deleting {}", entry.id); - let (id, _) = history_store.delete(entry.id).await?; - history_store.incremental_build(db, &[id]).await?; - } - - #[cfg(feature = "daemon")] - daemon_cmd::emit_event(settings, atuin_daemon::DaemonEvent::HistoryDeleted { ids }) - .await; - } - Ok(()) - } - - #[expect(clippy::too_many_lines)] - pub async fn run(self, settings: &Settings) -> Result<()> { - match self { - Self::Start { - cmd_env, - author, - intent, - command, - } => { - let command = if cmd_env { - std::env::var("ATUIN_COMMAND_LINE").unwrap_or_default() - } else { - command.join(" ") - }; - - if let Some(id) = - start_history_entry(settings, &command, author.as_deref(), intent.as_deref()) - .await? - { - println!("{id}"); - } - - Ok(()) - } - Self::End { id, exit, duration } => { - end_history_entry(settings, &id, exit, duration).await - } - Self::Tail => { - #[cfg(feature = "daemon")] - { - return Self::handle_tail(settings).await; - } - - #[cfg(not(feature = "daemon"))] - bail!("`atuin history tail` requires Atuin to be built with the `daemon` feature"); - } - cmd => { - let context = current_context().await?; - - let db_path = PathBuf::from(settings.db_path.as_str()); - let record_store_path = PathBuf::from(settings.record_store_path.as_str()); - - let db = Sqlite::new(db_path, settings.local_timeout).await?; - let store = SqliteStore::new(record_store_path, settings.local_timeout).await?; - - let encryption_key: [u8; 32] = encryption::load_key(settings) - .context("could not load encryption key")? - .into(); - - let host_id = Settings::host_id().await?; - let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); - - match cmd { - Self::List { - session, - cwd, - human, - cmd_only, - print0, - reverse, - timezone, - format, - } => { - let mode = ListMode::from_flags(human, cmd_only); - let tz = timezone.unwrap_or(settings.timezone); - Self::handle_list( - &db, settings, context, session, cwd, mode, format, false, print0, - reverse, tz, - ) - .await - } - - Self::Last { - human, - cmd_only, - timezone, - format, - } => { - let last = db.last().await?; - let last = last.as_slice(); - let tz = timezone.unwrap_or(settings.timezone); - print_list( - last, - ListMode::from_flags(human, cmd_only), - match format { - None => Some(settings.history_format.as_str()), - _ => format.as_deref(), - }, - false, - true, - tz, - ); - - Ok(()) - } - - Self::InitStore => history_store.init_store(&db).await, - - Self::Prune { dry_run } => { - Self::handle_prune(&db, settings, store, context, dry_run).await - } - - Self::Dedup { - dry_run, - before, - dupkeep, - } => { - let before = i64::try_from( - interim::parse_date_string( - before.as_str(), - OffsetDateTime::now_utc(), - interim::Dialect::Uk, - )? - .unix_timestamp_nanos(), - )?; - Self::handle_dedup(&db, settings, store, before, dupkeep, dry_run).await - } - - Self::Start { .. } | Self::End { .. } | Self::Tail => unreachable!(), - } - } - } - } -} - -#[cfg(test)] -mod tests { - #[cfg(feature = "daemon")] - use time::macros::datetime; - - use super::*; - - #[test] - fn normalize_command_strips_trailing_spaces_and_tabs() { - let settings = Settings::utc(); - - assert!(settings.strip_trailing_whitespace); - assert_eq!(normalize_command_for_storage("ls \t", &settings), "ls"); - } - - #[test] - fn normalize_command_preserves_escaped_trailing_space() { - let settings = Settings::utc(); - - assert_eq!( - normalize_command_for_storage("printf foo\\ ", &settings), - "printf foo\\ " - ); - assert_eq!( - normalize_command_for_storage("printf foo\\\\ ", &settings), - "printf foo\\\\" - ); - } - - #[tokio::test] - async fn handle_start_saves_trimmed_command() { - let db = Sqlite::new("sqlite::memory:", 2.0).await.unwrap(); - let settings = Settings::utc(); - - handle_start(&db, &settings, "ls \t", None, None) - .await - .unwrap(); - - let history = db - .before(OffsetDateTime::now_utc() + time::Duration::SECOND, 1) - .await - .unwrap() - .pop() - .unwrap(); - assert_eq!(history.command, "ls"); - } - - #[tokio::test] - async fn handle_start_can_keep_trailing_whitespace() { - let db = Sqlite::new("sqlite::memory:", 2.0).await.unwrap(); - let settings = Settings { - strip_trailing_whitespace: false, - ..Settings::utc() - }; - - handle_start(&db, &settings, "ls \t", None, None) - .await - .unwrap(); - - let history = db - .before(OffsetDateTime::now_utc() + time::Duration::SECOND, 1) - .await - .unwrap() - .pop() - .unwrap(); - assert_eq!(history.command, "ls \t"); - } - - #[test] - fn test_format_string_no_panic() { - // Don't panic but provide helpful output (issue #2776) - let malformed_json = r#"{"command":"{command}","key":"value"}"#; - - let result = std::panic::catch_unwind(|| parse_fmt(malformed_json)); - - assert!(result.is_ok()); - } - - #[test] - fn test_valid_formats_still_work() { - assert!(std::panic::catch_unwind(|| parse_fmt("{command}")).is_ok()); - assert!(std::panic::catch_unwind(|| parse_fmt("{time} - {command}")).is_ok()); - } - - #[cfg(feature = "daemon")] - fn sample_tail_event(kind: TailKind) -> TailEvent { - TailEvent { - kind, - history: History { - id: "history-id".to_owned().into(), - timestamp: datetime!(2026-04-09 17:18:19 UTC), - duration: 12_345_678, - exit: 0, - command: "git status".to_owned(), - cwd: "/tmp/repo".to_owned(), - session: "session-id".to_owned(), - hostname: "host:ellie".to_owned(), - author: "claude".to_owned(), - intent: Some("inspect repository state".to_owned()), - deleted_at: None, - }, - } - } - - #[cfg(feature = "daemon")] - #[test] - fn test_tail_json_output_contains_history_fields() { - let json = sample_tail_event(TailKind::Ended) - .render(false, Timezone(time::UtcOffset::UTC)) - .unwrap(); - let value: serde_json::Value = serde_json::from_str(&json).unwrap(); - - assert_eq!(value["event"], "ended"); - assert_eq!(value["history"]["id"], "history-id"); - assert_eq!(value["history"]["duration_ns"], 12_345_678); - assert_eq!(value["history"]["success"], true); - assert!(value.get("record").is_none()); - } - - #[cfg(feature = "daemon")] - #[test] - fn test_tail_pretty_output_shows_pending_fields_for_started_events() { - let rendered = sample_tail_event(TailKind::Started) - .render(true, Timezone(time::UtcOffset::UTC)) - .unwrap(); - let plain = regex::Regex::new(r"\x1b\[[0-9;]*m") - .unwrap() - .replace_all(&rendered, ""); - - assert!(plain.contains("STARTED git status")); - assert!(plain.contains("exit:")); - assert!(plain.contains("pending")); - assert!(plain.contains("duration:")); - assert!(plain.contains("running")); - } -} diff --git a/crates/atuin/src/command/client/import.rs b/crates/atuin/src/command/client/import.rs deleted file mode 100644 index 21ac76b4..00000000 --- a/crates/atuin/src/command/client/import.rs +++ /dev/null @@ -1,186 +0,0 @@ -use std::env; - -use async_trait::async_trait; -use clap::Parser; -use eyre::Result; -use indicatif::ProgressBar; - -use atuin_client::{ - database::Database, - history::History, - import::{ - Importer, Loader, bash::Bash, fish::Fish, nu::Nu, nu_histdb::NuHistDb, - powershell::PowerShell, replxx::Replxx, resh::Resh, xonsh::Xonsh, - xonsh_sqlite::XonshSqlite, zsh::Zsh, zsh_histdb::ZshHistDb, - }, -}; - -#[derive(Parser, Debug)] -#[command(infer_subcommands = true)] -pub enum Cmd { - /// Import history for the current shell - Auto, - - /// Import history from the zsh history file - Zsh, - /// Import history from the zsh history file - ZshHistDb, - /// Import history from the bash history file - Bash, - /// Import history from the replxx history file - Replxx, - /// Import history from the resh history file - Resh, - /// Import history from the fish history file - Fish, - /// Import history from the nu history file - Nu, - /// Import history from the nu history file - NuHistDb, - /// Import history from xonsh json files - Xonsh, - /// Import history from xonsh sqlite db - XonshSqlite, - /// Import history from the powershell history file - Powershell, -} - -const BATCH_SIZE: usize = 100; - -impl Cmd { - #[expect(clippy::cognitive_complexity)] - pub async fn run(&self, db: &DB) -> Result<()> { - println!(" Atuin "); - println!("======================"); - println!(" \u{1f30d} "); - println!(" \u{1f418}\u{1f418}\u{1f418}\u{1f418} "); - println!(" \u{1f422} "); - println!("======================"); - println!("Importing history..."); - - match self { - Self::Auto => { - if cfg!(windows) { - return if env::var("PSModulePath").is_ok() { - println!("Detected PowerShell"); - import::(db).await - } else { - println!("Could not detect the current shell."); - println!("Please run atuin import ."); - println!("To view a list of shells, run atuin import."); - Ok(()) - }; - } - - // $XONSH_HISTORY_BACKEND isn't always set, but $XONSH_HISTORY_FILE is - let xonsh_histfile = - env::var("XONSH_HISTORY_FILE").unwrap_or_else(|_| String::new()); - let shell = env::var("SHELL").unwrap_or_else(|_| String::from("NO_SHELL")); - - if xonsh_histfile.to_lowercase().ends_with(".json") { - println!("Detected Xonsh"); - import::(db).await - } else if xonsh_histfile.to_lowercase().ends_with(".sqlite") { - println!("Detected Xonsh (SQLite backend)"); - import::(db).await - } else if shell.ends_with("/zsh") { - if ZshHistDb::histpath().is_ok() { - println!( - "Detected Zsh-HistDb, using :{}", - ZshHistDb::histpath().unwrap().to_str().unwrap() - ); - import::(db).await - } else { - println!("Detected ZSH"); - import::(db).await - } - } else if shell.ends_with("/fish") { - println!("Detected Fish"); - import::(db).await - } else if shell.ends_with("/bash") { - println!("Detected Bash"); - import::(db).await - } else if shell.ends_with("/nu") { - if NuHistDb::histpath().is_ok() { - println!( - "Detected Nu-HistDb, using :{}", - NuHistDb::histpath().unwrap().to_str().unwrap() - ); - import::(db).await - } else { - println!("Detected Nushell"); - import::(db).await - } - } else if shell.ends_with("/pwsh") { - println!("Detected PowerShell"); - import::(db).await - } else { - println!("cannot import {shell} history"); - Ok(()) - } - } - - Self::Zsh => import::(db).await, - Self::ZshHistDb => import::(db).await, - Self::Bash => import::(db).await, - Self::Replxx => import::(db).await, - Self::Resh => import::(db).await, - Self::Fish => import::(db).await, - Self::Nu => import::(db).await, - Self::NuHistDb => import::(db).await, - Self::Xonsh => import::(db).await, - Self::XonshSqlite => import::(db).await, - Self::Powershell => import::(db).await, - } - } -} - -pub struct HistoryImporter<'db, DB: Database> { - pb: ProgressBar, - buf: Vec, - db: &'db DB, -} - -impl<'db, DB: Database> HistoryImporter<'db, DB> { - fn new(db: &'db DB, len: usize) -> Self { - Self { - pb: ProgressBar::new(len as u64), - buf: Vec::with_capacity(BATCH_SIZE), - db, - } - } - - async fn flush(self) -> Result<()> { - if !self.buf.is_empty() { - self.db.save_bulk(&self.buf).await?; - } - self.pb.finish(); - Ok(()) - } -} - -#[async_trait] -impl Loader for HistoryImporter<'_, DB> { - async fn push(&mut self, hist: History) -> Result<()> { - self.pb.inc(1); - self.buf.push(hist); - if self.buf.len() == self.buf.capacity() { - self.db.save_bulk(&self.buf).await?; - self.buf.clear(); - } - Ok(()) - } -} - -async fn import(db: &DB) -> Result<()> { - println!("Importing history from {}", I::NAME); - - let mut importer = I::new().await?; - let len = importer.entries().await.unwrap(); - let mut loader = HistoryImporter::new(db, len); - importer.load(&mut loader).await?; - loader.flush().await?; - - println!("Import complete!"); - Ok(()) -} diff --git a/crates/atuin/src/command/client/info.rs b/crates/atuin/src/command/client/info.rs deleted file mode 100644 index a69f9b2f..00000000 --- a/crates/atuin/src/command/client/info.rs +++ /dev/null @@ -1,31 +0,0 @@ -use atuin_client::settings::Settings; - -use crate::{SHA, VERSION}; - -pub fn run(settings: &Settings) { - let config = atuin_common::utils::config_dir(); - let mut config_file = config.clone(); - config_file.push("config.toml"); - let mut sever_config = config; - sever_config.push("server.toml"); - - let config_paths = format!( - "Config files:\nclient config: {:?}\nserver config: {:?}\nclient db path: {:?}\nkey path: {:?}\nmeta db path: {:?}", - config_file.to_string_lossy(), - sever_config.to_string_lossy(), - settings.db_path, - settings.key_path, - settings.meta.db_path - ); - - let env_vars = format!( - "Env Vars:\nATUIN_CONFIG_DIR = {:?}", - std::env::var("ATUIN_CONFIG_DIR").unwrap_or_else(|_| "None".into()) - ); - - let general_info = format!("Version info:\nversion: {VERSION}\ncommit: {SHA}"); - - let print_out = format!("{config_paths}\n\n{env_vars}\n\n{general_info}"); - - println!("{print_out}"); -} diff --git a/crates/atuin/src/command/client/init.rs b/crates/atuin/src/command/client/init.rs deleted file mode 100644 index 39cd1247..00000000 --- a/crates/atuin/src/command/client/init.rs +++ /dev/null @@ -1,127 +0,0 @@ -use atuin_client::settings::{Settings, Tmux}; -use clap::{Parser, ValueEnum}; - -mod bash; -mod fish; -mod powershell; -mod xonsh; -mod zsh; - -#[derive(Parser, Debug)] -pub struct Cmd { - shell: Shell, - - /// Disable the binding of CTRL-R to atuin - #[clap(long)] - disable_ctrl_r: bool, - - /// Disable the binding of the Up Arrow key to atuin - #[clap(long)] - disable_up_arrow: bool, - - /// Disable the binding of ? to Atuin AI - #[clap(long)] - disable_ai: bool, -} - -#[derive(Clone, Copy, ValueEnum, Debug)] -#[value(rename_all = "lower")] -#[expect(clippy::enum_variant_names, clippy::doc_markdown)] -pub enum Shell { - /// Zsh setup - Zsh, - /// Bash setup - Bash, - /// Fish setup - Fish, - /// Nu setup - Nu, - /// Xonsh setup - Xonsh, - /// PowerShell setup - PowerShell, -} - -impl Cmd { - fn init_nu(&self, _tmux: &Tmux) { - let full = include_str!("../../shell/atuin.nu"); - - // TODO: tmux popup for Nu - println!("{full}"); - - if std::env::var("ATUIN_NOBIND").is_err() { - const BIND_CTRL_R: &str = r"$env.config = ( - $env.config | upsert keybindings ( - $env.config.keybindings - | append { - name: atuin - modifier: control - keycode: char_r - mode: [emacs, vi_normal, vi_insert] - event: { send: executehostcommand cmd: (_atuin_search_cmd) } - } - ) -)"; - const BIND_UP_ARROW: &str = r" -$env.config = ( - $env.config | upsert keybindings ( - $env.config.keybindings - | append { - name: atuin - modifier: none - keycode: up - mode: [emacs, vi_normal, vi_insert] - event: { - until: [ - {send: menuup} - {send: executehostcommand cmd: (_atuin_search_cmd '--shell-up-key-binding') } - ] - } - } - ) -) -"; - if !self.disable_ctrl_r { - println!("{BIND_CTRL_R}"); - } - if !self.disable_up_arrow { - println!("{BIND_UP_ARROW}"); - } - } - } - - fn static_init(&self, settings: &Settings) { - let tmux = &settings.tmux; - - match self.shell { - Shell::Zsh => { - zsh::init_static(self.disable_up_arrow, self.disable_ctrl_r, tmux); - } - Shell::Bash => { - bash::init_static(self.disable_up_arrow, self.disable_ctrl_r, tmux); - } - Shell::Fish => { - fish::init_static(self.disable_up_arrow, self.disable_ctrl_r, tmux); - } - Shell::Nu => { - self.init_nu(tmux); - } - Shell::Xonsh => { - xonsh::init_static(self.disable_up_arrow, self.disable_ctrl_r, tmux); - } - Shell::PowerShell => { - powershell::init_static(self.disable_up_arrow, self.disable_ctrl_r, tmux); - } - } - } - - pub fn run(self, settings: &Settings) { - if !settings.paths_ok() { - eprintln!( - "Atuin settings paths are broken. Disabling atuin shell hooks. Run `atuin doctor` to diagnose." - ); - } - - self.static_init(settings); - } -} diff --git a/crates/atuin/src/command/client/init/bash.rs b/crates/atuin/src/command/client/init/bash.rs deleted file mode 100644 index 2280dc3d..00000000 --- a/crates/atuin/src/command/client/init/bash.rs +++ /dev/null @@ -1,25 +0,0 @@ -use atuin_client::settings::Tmux; - -fn print_tmux_config(tmux: &Tmux) { - if tmux.enabled { - println!("export ATUIN_TMUX_POPUP_WIDTH='{}'", tmux.width); - println!("export ATUIN_TMUX_POPUP_HEIGHT='{}'", tmux.height); - } else { - println!("export ATUIN_TMUX_POPUP=false"); - } -} - -pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, tmux: &Tmux) { - let base = include_str!("../../../shell/atuin.bash"); - - let (bind_ctrl_r, bind_up_arrow) = if std::env::var("ATUIN_NOBIND").is_ok() { - (false, false) - } else { - (!disable_ctrl_r, !disable_up_arrow) - }; - - print_tmux_config(tmux); - println!("__atuin_bind_ctrl_r={bind_ctrl_r}"); - println!("__atuin_bind_up_arrow={bind_up_arrow}"); - println!("{base}"); -} diff --git a/crates/atuin/src/command/client/init/fish.rs b/crates/atuin/src/command/client/init/fish.rs deleted file mode 100644 index 07c6a5ba..00000000 --- a/crates/atuin/src/command/client/init/fish.rs +++ /dev/null @@ -1,86 +0,0 @@ -use atuin_client::settings::Tmux; - -fn print_tmux_config(tmux: &Tmux) { - if tmux.enabled { - println!("set -gx ATUIN_TMUX_POPUP_WIDTH '{}'", tmux.width); - println!("set -gx ATUIN_TMUX_POPUP_HEIGHT '{}'", tmux.height); - } else { - println!("set -gx ATUIN_TMUX_POPUP false"); - } -} - -fn print_bindings( - indent: &str, - disable_up_arrow: bool, - disable_ctrl_r: bool, - bind_ctrl_r: &str, - bind_up_arrow: &str, - bind_ctrl_r_ins: &str, - bind_up_arrow_ins: &str, -) { - if !disable_ctrl_r { - println!("{indent}{bind_ctrl_r}"); - } - if !disable_up_arrow { - println!("{indent}{bind_up_arrow}"); - } - - println!("{indent}if bind -M insert >/dev/null 2>&1"); - if !disable_ctrl_r { - println!("{indent}{indent}{bind_ctrl_r_ins}"); - } - if !disable_up_arrow { - println!("{indent}{indent}{bind_up_arrow_ins}"); - } - println!("{indent}end"); -} - -pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, tmux: &Tmux) { - let indent = " ".repeat(4); - - let base = include_str!("../../../shell/atuin.fish"); - - print_tmux_config(tmux); - println!("{base}"); - - if std::env::var("ATUIN_NOBIND").is_err() { - println!("if string match -q '4.*' $version"); - - // In fish 4.0 and above the option bind -k doesn't exist anymore, - // instead we can use key names and modifiers directly. - print_bindings( - &indent, - disable_up_arrow, - disable_ctrl_r, - "bind ctrl-r _atuin_search", - "bind up _atuin_bind_up", - "bind -M insert ctrl-r _atuin_search", - "bind -M insert up _atuin_bind_up", - ); - - println!("else"); - - // We keep these for compatibility with fish 3.x - print_bindings( - &indent, - disable_up_arrow, - disable_ctrl_r, - r"bind \cr _atuin_search", - &[ - r"bind -k up _atuin_bind_up", - r"bind \eOA _atuin_bind_up", - r"bind \e\[A _atuin_bind_up", - ] - .join("; "), - r"bind -M insert \cr _atuin_search", - &[ - r"bind -M insert -k up _atuin_bind_up", - r"bind -M insert \eOA _atuin_bind_up", - r"bind -M insert \e\[A _atuin_bind_up", - ] - .join("; "), - ); - - println!("end"); - } -} diff --git a/crates/atuin/src/command/client/init/powershell.rs b/crates/atuin/src/command/client/init/powershell.rs deleted file mode 100644 index f92f1cbe..00000000 --- a/crates/atuin/src/command/client/init/powershell.rs +++ /dev/null @@ -1,23 +0,0 @@ -use atuin_client::settings::Tmux; - -pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, _tmux: &Tmux) { - let base = include_str!("../../../shell/atuin.ps1"); - - let (bind_ctrl_r, bind_up_arrow) = if std::env::var("ATUIN_NOBIND").is_ok() { - (false, false) - } else { - (!disable_ctrl_r, !disable_up_arrow) - }; - - // TODO: tmux popup for Powershell - println!("{base}"); - println!( - "Enable-AtuinSearchKeys -CtrlR {} -UpArrow {}", - ps_bool(bind_ctrl_r), - ps_bool(bind_up_arrow) - ); -} - -fn ps_bool(value: bool) -> &'static str { - if value { "$true" } else { "$false" } -} diff --git a/crates/atuin/src/command/client/init/xonsh.rs b/crates/atuin/src/command/client/init/xonsh.rs deleted file mode 100644 index 9fb5730d..00000000 --- a/crates/atuin/src/command/client/init/xonsh.rs +++ /dev/null @@ -1,22 +0,0 @@ -use atuin_client::settings::Tmux; - -pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, _tmux: &Tmux) { - let base = include_str!("../../../shell/atuin.xsh"); - - let (bind_ctrl_r, bind_up_arrow) = if std::env::var("ATUIN_NOBIND").is_ok() { - (false, false) - } else { - (!disable_ctrl_r, !disable_up_arrow) - }; - - // TODO: tmux popup for xonsh - println!( - "_ATUIN_BIND_CTRL_R={}", - if bind_ctrl_r { "True" } else { "False" } - ); - println!( - "_ATUIN_BIND_UP_ARROW={}", - if bind_up_arrow { "True" } else { "False" } - ); - println!("{base}"); -} diff --git a/crates/atuin/src/command/client/init/zsh.rs b/crates/atuin/src/command/client/init/zsh.rs deleted file mode 100644 index 3f325167..00000000 --- a/crates/atuin/src/command/client/init/zsh.rs +++ /dev/null @@ -1,38 +0,0 @@ -use atuin_client::settings::Tmux; - -fn print_tmux_config(tmux: &Tmux) { - if tmux.enabled { - println!("export ATUIN_TMUX_POPUP_WIDTH='{}'", tmux.width); - println!("export ATUIN_TMUX_POPUP_HEIGHT='{}'", tmux.height); - } else { - println!("export ATUIN_TMUX_POPUP=false"); - } -} - -pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, tmux: &Tmux) { - let base = include_str!("../../../shell/atuin.zsh"); - - print_tmux_config(tmux); - println!("{base}"); - - if std::env::var("ATUIN_NOBIND").is_err() { - const BIND_CTRL_R: &str = r"bindkey -M emacs '^r' atuin-search -bindkey -M viins '^r' atuin-search-viins -bindkey -M vicmd '/' atuin-search"; - - const BIND_UP_ARROW: &str = r"bindkey -M emacs '^[[A' atuin-up-search -bindkey -M vicmd '^[[A' atuin-up-search-vicmd -bindkey -M viins '^[[A' atuin-up-search-viins -bindkey -M emacs '^[OA' atuin-up-search -bindkey -M vicmd '^[OA' atuin-up-search-vicmd -bindkey -M viins '^[OA' atuin-up-search-viins -bindkey -M vicmd 'k' atuin-up-search-vicmd"; - - if !disable_ctrl_r { - println!("{BIND_CTRL_R}"); - } - if !disable_up_arrow { - println!("{BIND_UP_ARROW}"); - } - } -} diff --git a/crates/atuin/src/command/client/search.rs b/crates/atuin/src/command/client/search.rs deleted file mode 100644 index a9dc9a68..00000000 --- a/crates/atuin/src/command/client/search.rs +++ /dev/null @@ -1,375 +0,0 @@ -use std::fs::File; -use std::io::{IsTerminal as _, Write, stderr, stdout}; - -use atuin_common::utils::{self, Escapable as _}; -use clap::Parser; -use eyre::Result; - -use atuin_client::{ - database::Database, - database::{OptFilters, current_context}, - encryption, - history::{History, store::HistoryStore}, - record::sqlite_store::SqliteStore, - settings::{FilterMode, KeymapMode, SearchMode, Settings, Timezone}, - theme::Theme, -}; - -use super::history::ListMode; - -mod cursor; -mod duration; -mod engines; -mod history_list; -mod inspector; -mod interactive; -pub mod keybindings; - -pub use duration::format_duration_into; - -#[expect(clippy::struct_excessive_bools, clippy::struct_field_names)] -#[derive(Parser, Debug)] -pub struct Cmd { - /// Filter search result by directory - #[arg(long, short)] - cwd: Option, - - /// Exclude directory from results - #[arg(long = "exclude-cwd")] - exclude_cwd: Option, - - /// Filter search result by exit code - #[arg(long, short)] - exit: Option, - - /// Exclude results with this exit code - #[arg(long = "exclude-exit")] - exclude_exit: Option, - - /// Only include results added before this date - #[arg(long, short)] - before: Option, - - /// Only include results after this date - #[arg(long)] - after: Option, - - /// How many entries to return at most - #[arg(long)] - limit: Option, - - /// Offset from the start of the results - #[arg(long)] - offset: Option, - - /// Open interactive search UI - #[arg(long, short)] - interactive: bool, - - /// Allow overriding filter mode over config - #[arg(long = "filter-mode")] - filter_mode: Option, - - /// Allow overriding search mode over config - #[arg(long = "search-mode")] - search_mode: Option, - - /// Marker argument used to inform atuin that it was invoked from a shell up-key binding (hidden from help to avoid confusion) - #[arg(long = "shell-up-key-binding", hide = true)] - shell_up_key_binding: bool, - - /// Notify the keymap at the shell's side - #[arg(long = "keymap-mode", default_value = "auto")] - keymap_mode: KeymapMode, - - /// Use human-readable formatting for time - #[arg(long)] - human: bool, - - #[arg(allow_hyphen_values = true)] - query: Option>, - - /// Show only the text of the command - #[arg(long)] - cmd_only: bool, - - /// Terminate the output with a null, for better multiline handling - #[arg(long)] - print0: bool, - - /// Delete anything matching this query. Will not print out the match - #[arg(long)] - delete: bool, - - /// Delete EVERYTHING! - #[arg(long)] - delete_it_all: bool, - - /// Reverse the order of results, oldest first - #[arg(long, short)] - reverse: bool, - - /// Display the command time in another timezone other than the configured default. - /// - /// This option takes one of the following kinds of values: - /// - the special value "local" (or "l") which refers to the system time zone - /// - an offset from UTC (e.g. "+9", "-2:30") - #[arg(long, visible_alias = "tz")] - #[arg(allow_hyphen_values = true)] - // Clippy warns about `Option>`, but we suppress it because we need - // this distinction for proper argument handling. - #[expect(clippy::option_option)] - timezone: Option>, - - /// Available variables: {command}, {directory}, {duration}, {user}, {host}, {time}, {exit} and - /// {relativetime}. - /// Example: --format "{time} - [{duration}] - {directory}$\t{command}" - #[arg(long, short)] - format: Option, - - /// Set the maximum number of lines Atuin's interface should take up. - #[arg(long = "inline-height")] - inline_height: Option, - - /// Filter by author. Supports $all-user (non-agents), $all-agent, or literal names. - /// Can be specified multiple times. - #[arg(long)] - author: Option>, - - /// Include duplicate commands in the output (non-interactive only) - #[arg(long)] - include_duplicates: bool, - - /// File name to write the result to (hidden from help as this is meant to be used from a script) - #[arg(long = "result-file", hide = true)] - result_file: Option, -} - -impl Cmd { - /// Returns true if this search command will run in interactive (TUI) mode - pub fn is_interactive(&self) -> bool { - self.interactive - } - - // clippy: please write this instead - // clippy: now it has too many lines - // me: I'll do it later OKAY - #[expect(clippy::too_many_lines)] - pub async fn run( - self, - db: impl Database, - settings: &mut Settings, - store: SqliteStore, - theme: &Theme, - ) -> Result<()> { - let query = self.query.unwrap_or_else(|| { - std::env::var("ATUIN_QUERY").map_or_else( - |_| vec![], - |query| { - query - .split(' ') - .map(std::string::ToString::to_string) - .collect() - }, - ) - }); - - if (self.delete_it_all || self.delete) && self.limit.is_some() { - // Because of how deletion is implemented, it will always delete all matches - // and disregard the limit option. It is also not clear what deletion with a - // limit would even mean. Deleting the LIMIT most recent entries that match - // the search query would make sense, but that wouldn't match what's displayed - // when running the equivalent search, but deleting those entries that are - // displayed with the search would leave any duplicates of those lines which may - // or may not have been intended to be deleted. - eprintln!("\"--limit\" is not compatible with deletion."); - return Ok(()); - } - - if self.delete && query.is_empty() { - eprintln!( - "Please specify a query to match the items you wish to delete. If you wish to delete all history, pass --delete-it-all" - ); - return Ok(()); - } - - if self.delete_it_all && !query.is_empty() { - eprintln!( - "--delete-it-all will delete ALL of your history! It does not require a query." - ); - return Ok(()); - } - - if let Some(search_mode) = self.search_mode { - settings.search_mode = search_mode; - } - if let Some(filter_mode) = self.filter_mode { - settings.filter_mode = Some(filter_mode); - } - if let Some(inline_height) = self.inline_height { - settings.inline_height = inline_height; - } - - settings.shell_up_key_binding = self.shell_up_key_binding; - - // `keymap_mode` specified in config.toml overrides the `--keymap-mode` - // option specified in the keybindings. - settings.keymap_mode = match settings.keymap_mode { - KeymapMode::Auto => self.keymap_mode, - value => value, - }; - settings.keymap_mode_shell = self.keymap_mode; - - let encryption_key: [u8; 32] = encryption::load_key(settings)?.into(); - - let host_id = Settings::host_id().await?; - let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); - - if self.interactive { - let item = interactive::history(&query, settings, db, &history_store, theme).await?; - - if let Some(result_file) = self.result_file { - let mut file = File::create(result_file)?; - write!(file, "{item}")?; - } else if !stdout().is_terminal() { - // stdout is not a terminal - likely command substitution like VAR=$(atuin search -i) - // Write to stdout so it gets captured. This requires some care on Windows, as the current - // console code page or `[Console]::OutputEncoding` on PowerShell may be different from UTF-8. - println!("{item}"); - } else if stderr().is_terminal() { - eprintln!("{}", item.escape_control()); - } else { - eprintln!("{item}"); - } - } else { - let opt_filter = OptFilters { - exit: self.exit, - exclude_exit: self.exclude_exit, - cwd: self.cwd, - exclude_cwd: self.exclude_cwd, - before: self.before, - after: self.after, - limit: self.limit, - offset: self.offset, - reverse: self.reverse, - include_duplicates: self.include_duplicates, - authors: self.author.clone().unwrap_or_default(), - }; - - let mut entries = - run_non_interactive(settings, opt_filter.clone(), &query, &db).await?; - - if entries.is_empty() { - std::process::exit(1) - } - - // if we aren't deleting, print it all - if self.delete || self.delete_it_all { - // delete it - // it only took me _years_ to add this - // sorry - while !entries.is_empty() { - for entry in &entries { - eprintln!("deleting {}", entry.id); - } - - let ids = history_store.delete_entries(entries).await?; - history_store.incremental_build(&db, &ids).await?; - - entries = - run_non_interactive(settings, opt_filter.clone(), &query, &db).await?; - } - } else { - let format = match self.format { - None => Some(settings.history_format.as_str()), - _ => self.format.as_deref(), - }; - let tz = match self.timezone { - Some(Some(tz)) => tz, // User provided a value - Some(None) | None => settings.timezone, // No value was provided - }; - - super::history::print_list( - &entries, - ListMode::from_flags(self.human, self.cmd_only), - format, - self.print0, - true, - tz, - ); - } - } - Ok(()) - } -} - -// This is supposed to more-or-less mirror the command line version, so ofc -// it is going to have a lot of args -#[expect(clippy::too_many_arguments, clippy::cast_possible_truncation)] -async fn run_non_interactive( - settings: &Settings, - filter_options: OptFilters, - query: &[String], - db: &impl Database, -) -> Result> { - let dir = if filter_options.cwd.as_deref() == Some(".") { - Some(utils::get_current_dir()) - } else { - filter_options.cwd - }; - - let context = current_context().await?; - - let opt_filter = OptFilters { - cwd: dir.clone(), - ..filter_options - }; - - let filter_mode = settings.default_filter_mode(context.git_root.is_some()); - - let results = db - .search( - settings.search_mode, - filter_mode, - &context, - query.join(" ").as_str(), - opt_filter, - ) - .await?; - - Ok(results) -} - -#[cfg(test)] -mod tests { - use super::Cmd; - use clap::Parser; - - #[test] - fn search_for_triple_dash() { - // Issue #3028: searching for `---` should not be treated as a CLI flag - let cmd = Cmd::try_parse_from(["search", "---"]); - assert!(cmd.is_ok(), "Failed to parse '---' as a query: {cmd:?}"); - let cmd = cmd.unwrap(); - assert_eq!(cmd.query, Some(vec!["---".to_string()])); - } - - #[test] - fn search_for_double_dash_value() { - // Searching for strings starting with -- should also work - let cmd = Cmd::try_parse_from(["search", "--", "--foo"]); - assert!(cmd.is_ok()); - let cmd = cmd.unwrap(); - assert_eq!(cmd.query, Some(vec!["--foo".to_string()])); - } - - #[test] - fn search_author_cli_flag() { - let cmd = - Cmd::try_parse_from(["search", "--author", "codex", "--author", "ellie"]).unwrap(); - assert_eq!( - cmd.author, - Some(vec!["codex".to_string(), "ellie".to_string()]) - ); - } -} diff --git a/crates/atuin/src/command/client/search/cursor.rs b/crates/atuin/src/command/client/search/cursor.rs deleted file mode 100644 index c1cdfee4..00000000 --- a/crates/atuin/src/command/client/search/cursor.rs +++ /dev/null @@ -1,405 +0,0 @@ -use atuin_client::settings::WordJumpMode; - -pub struct Cursor { - source: String, - index: usize, -} - -impl From for Cursor { - fn from(source: String) -> Self { - Self { source, index: 0 } - } -} - -pub struct WordJumper<'a> { - word_chars: &'a str, - word_jump_mode: WordJumpMode, -} - -impl WordJumper<'_> { - fn is_word_boundary(&self, c: char, next_c: char) -> bool { - (c.is_whitespace() && !next_c.is_whitespace()) - || (!c.is_whitespace() && next_c.is_whitespace()) - || (self.word_chars.contains(c) && !self.word_chars.contains(next_c)) - || (!self.word_chars.contains(c) && self.word_chars.contains(next_c)) - } - - fn emacs_get_next_word_pos(&self, source: &str, index: usize) -> usize { - let index = (index + 1..source.len().saturating_sub(1)) - .find(|&i| self.word_chars.contains(source.chars().nth(i).unwrap())) - .unwrap_or(source.len()); - (index + 1..source.len().saturating_sub(1)) - .find(|&i| !self.word_chars.contains(source.chars().nth(i).unwrap())) - .unwrap_or(source.len()) - } - - fn emacs_get_prev_word_pos(&self, source: &str, index: usize) -> usize { - let index = (1..index) - .rev() - .find(|&i| self.word_chars.contains(source.chars().nth(i).unwrap())) - .unwrap_or(0); - (1..index) - .rev() - .find(|&i| !self.word_chars.contains(source.chars().nth(i).unwrap())) - .map_or(0, |i| i + 1) - } - - fn subl_get_next_word_pos(&self, source: &str, index: usize) -> usize { - let index = (index..source.len().saturating_sub(1)).find(|&i| { - self.is_word_boundary( - source.chars().nth(i).unwrap(), - source.chars().nth(i + 1).unwrap(), - ) - }); - if index.is_none() { - return source.len(); - } - (index.unwrap() + 1..source.len()) - .find(|&i| !source.chars().nth(i).unwrap().is_whitespace()) - .unwrap_or(source.len()) - } - - fn subl_get_prev_word_pos(&self, source: &str, index: usize) -> usize { - let index = (1..index) - .rev() - .find(|&i| !source.chars().nth(i).unwrap().is_whitespace()); - if index.is_none() { - return 0; - } - (1..index.unwrap()) - .rev() - .find(|&i| { - self.is_word_boundary( - source.chars().nth(i - 1).unwrap(), - source.chars().nth(i).unwrap(), - ) - }) - .unwrap_or(0) - } - - fn get_next_word_pos(&self, source: &str, index: usize) -> usize { - match self.word_jump_mode { - WordJumpMode::Emacs => self.emacs_get_next_word_pos(source, index), - WordJumpMode::Subl => self.subl_get_next_word_pos(source, index), - } - } - - fn get_prev_word_pos(&self, source: &str, index: usize) -> usize { - match self.word_jump_mode { - WordJumpMode::Emacs => self.emacs_get_prev_word_pos(source, index), - WordJumpMode::Subl => self.subl_get_prev_word_pos(source, index), - } - } -} - -impl Cursor { - pub fn as_str(&self) -> &str { - self.source.as_str() - } - - pub fn into_inner(self) -> String { - self.source - } - - /// Returns the string before the cursor - pub fn substring(&self) -> &str { - &self.source[..self.index] - } - - /// Returns the currently selected [`char`] - pub fn char(&self) -> Option { - self.source[self.index..].chars().next() - } - - pub fn right(&mut self) { - if self.index < self.source.len() { - loop { - self.index += 1; - if self.source.is_char_boundary(self.index) { - break; - } - } - } - } - - pub fn left(&mut self) -> bool { - if self.index > 0 { - loop { - self.index -= 1; - if self.source.is_char_boundary(self.index) { - break true; - } - } - } else { - false - } - } - - pub fn next_word(&mut self, word_chars: &str, word_jump_mode: WordJumpMode) { - let word_jumper = WordJumper { - word_chars, - word_jump_mode, - }; - self.index = word_jumper.get_next_word_pos(&self.source, self.index); - } - - pub fn prev_word(&mut self, word_chars: &str, word_jump_mode: WordJumpMode) { - let word_jumper = WordJumper { - word_chars, - word_jump_mode, - }; - self.index = word_jumper.get_prev_word_pos(&self.source, self.index); - } - - /// Move cursor to the end of the current/next word (vim `e` motion). - /// - /// If cursor is in the middle of a word, moves to the end of that word. - /// If cursor is at the end of a word (or on whitespace), moves to the - /// end of the next word. - pub fn word_end(&mut self, word_chars: &str) { - let len = self.source.len(); - if self.index >= len { - return; - } - - let chars: Vec = self.source.chars().collect(); - let mut char_idx = self.source[..self.index].chars().count(); - - if char_idx >= chars.len() { - return; - } - - let current = chars[char_idx]; - - // Check if we're at a word boundary (end of current word or on whitespace) - let at_word_boundary = current.is_whitespace() || char_idx + 1 >= chars.len() || { - let next = chars[char_idx + 1]; - next.is_whitespace() || (word_chars.contains(current) != word_chars.contains(next)) - }; - - // If at word boundary, advance past it and skip whitespace to find next word - if at_word_boundary { - char_idx += 1; - while char_idx < chars.len() && chars[char_idx].is_whitespace() { - char_idx += 1; - } - } - - // If we've gone past end, go to end of string - if char_idx >= chars.len() { - self.index = len; - return; - } - - // Find end of word: advance until next char is whitespace or different word type - let in_word_chars = word_chars.contains(chars[char_idx]); - while char_idx < chars.len() { - let next_idx = char_idx + 1; - if next_idx >= chars.len() { - // At last char, move past it - char_idx = next_idx; - break; - } - let next_c = chars[next_idx]; - if next_c.is_whitespace() || (word_chars.contains(next_c) != in_word_chars) { - // Next char is start of new word/whitespace, so current char is end - char_idx = next_idx; - break; - } - char_idx += 1; - } - - // Convert char index back to byte index - self.index = chars.iter().take(char_idx).map(|c| c.len_utf8()).sum(); - } - - pub fn insert(&mut self, c: char) { - self.source.insert(self.index, c); - self.index += c.len_utf8(); - } - - pub fn remove(&mut self) -> Option { - if self.index < self.source.len() { - Some(self.source.remove(self.index)) - } else { - None - } - } - - pub fn remove_next_word(&mut self, word_chars: &str, word_jump_mode: WordJumpMode) { - let word_jumper = WordJumper { - word_chars, - word_jump_mode, - }; - let next_index = word_jumper.get_next_word_pos(&self.source, self.index); - self.source.replace_range(self.index..next_index, ""); - } - - pub fn remove_prev_word(&mut self, word_chars: &str, word_jump_mode: WordJumpMode) { - let word_jumper = WordJumper { - word_chars, - word_jump_mode, - }; - let next_index = word_jumper.get_prev_word_pos(&self.source, self.index); - self.source.replace_range(next_index..self.index, ""); - self.index = next_index; - } - - pub fn back(&mut self) -> Option { - if self.left() { self.remove() } else { None } - } - - pub fn clear(&mut self) { - self.source.clear(); - self.index = 0; - } - - pub fn clear_to_start(&mut self) { - self.source.replace_range(..self.index, ""); - self.index = 0; - } - - pub fn clear_to_end(&mut self) { - self.source.replace_range(self.index.., ""); - self.index = self.source.len(); - } - - pub fn end(&mut self) { - self.index = self.source.len(); - } - - pub fn start(&mut self) { - self.index = 0; - } - - pub fn position(&self) -> usize { - self.index - } -} - -#[cfg(test)] -mod cursor_tests { - use super::Cursor; - use super::*; - - static EMACS_WORD_JUMPER: WordJumper = WordJumper { - word_chars: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789", - word_jump_mode: WordJumpMode::Emacs, - }; - - static SUBL_WORD_JUMPER: WordJumper = WordJumper { - word_chars: "./\\()\"'-:,.;<>~!@#$%^&*|+=[]{}`~?", - word_jump_mode: WordJumpMode::Subl, - }; - - #[test] - fn right() { - // ö is 2 bytes - let mut c = Cursor::from(String::from("öaöböcödöeöfö")); - let indices = [0, 2, 3, 5, 6, 8, 9, 11, 12, 14, 15, 17, 18, 20, 20, 20, 20]; - for i in indices { - assert_eq!(c.index, i); - c.right(); - } - } - - #[test] - fn left() { - // ö is 2 bytes - let mut c = Cursor::from(String::from("öaöböcödöeöfö")); - c.end(); - let indices = [20, 18, 17, 15, 14, 12, 11, 9, 8, 6, 5, 3, 2, 0, 0, 0, 0]; - for i in indices { - assert_eq!(c.index, i); - c.left(); - } - } - - #[test] - fn test_emacs_get_next_word_pos() { - let s = String::from(" aaa ((()))bbb ((())) "); - let indices = [(0, 6), (3, 6), (7, 18), (19, 30)]; - for (i_src, i_dest) in indices { - assert_eq!(EMACS_WORD_JUMPER.get_next_word_pos(&s, i_src), i_dest); - } - assert_eq!(EMACS_WORD_JUMPER.get_next_word_pos("", 0), 0); - } - - #[test] - fn test_emacs_get_prev_word_pos() { - let s = String::from(" aaa ((()))bbb ((())) "); - let indices = [(30, 15), (29, 15), (15, 3), (3, 0)]; - for (i_src, i_dest) in indices { - assert_eq!(EMACS_WORD_JUMPER.get_prev_word_pos(&s, i_src), i_dest); - } - assert_eq!(EMACS_WORD_JUMPER.get_prev_word_pos("", 0), 0); - } - - #[test] - fn test_subl_get_next_word_pos() { - let s = String::from(" aaa ((()))bbb ((())) "); - let indices = [(0, 3), (1, 3), (3, 9), (9, 15), (15, 21), (21, 30)]; - for (i_src, i_dest) in indices { - assert_eq!(SUBL_WORD_JUMPER.get_next_word_pos(&s, i_src), i_dest); - } - assert_eq!(SUBL_WORD_JUMPER.get_next_word_pos("", 0), 0); - } - - #[test] - fn test_subl_get_prev_word_pos() { - let s = String::from(" aaa ((()))bbb ((())) "); - let indices = [(30, 21), (21, 15), (15, 9), (9, 3), (3, 0)]; - for (i_src, i_dest) in indices { - assert_eq!(SUBL_WORD_JUMPER.get_prev_word_pos(&s, i_src), i_dest); - } - assert_eq!(SUBL_WORD_JUMPER.get_prev_word_pos("", 0), 0); - } - - #[test] - fn pop() { - let mut s = String::from("öaöböcödöeöfö"); - let mut c = Cursor::from(s.clone()); - c.end(); - while !s.is_empty() { - let c1 = s.pop(); - let c2 = c.back(); - assert_eq!(c1, c2); - assert_eq!(s.as_str(), c.substring()); - } - let c1 = s.pop(); - let c2 = c.back(); - assert_eq!(c1, c2); - } - - #[test] - fn back() { - let mut c = Cursor::from(String::from("öaöböcödöeöfö")); - // move to ^ - for _ in 0..4 { - c.right(); - } - assert_eq!(c.substring(), "öaöb"); - assert_eq!(c.back(), Some('b')); - assert_eq!(c.back(), Some('ö')); - assert_eq!(c.back(), Some('a')); - assert_eq!(c.back(), Some('ö')); - assert_eq!(c.back(), None); - assert_eq!(c.as_str(), "öcödöeöfö"); - } - - #[test] - fn insert() { - let mut c = Cursor::from(String::from("öaöböcödöeöfö")); - // move to ^ - for _ in 0..4 { - c.right(); - } - assert_eq!(c.substring(), "öaöb"); - c.insert('ö'); - c.insert('g'); - c.insert('ö'); - c.insert('h'); - assert_eq!(c.substring(), "öaöbögöh"); - assert_eq!(c.as_str(), "öaöbögöhöcödöeöfö"); - } -} diff --git a/crates/atuin/src/command/client/search/duration.rs b/crates/atuin/src/command/client/search/duration.rs deleted file mode 100644 index 54856c87..00000000 --- a/crates/atuin/src/command/client/search/duration.rs +++ /dev/null @@ -1,65 +0,0 @@ -use core::fmt; -use std::{ops::ControlFlow, time::Duration}; - -#[expect(clippy::module_name_repetitions)] -pub fn format_duration_into(dur: Duration, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fn item(unit: &'static str, value: u64) -> ControlFlow<(&'static str, u64)> { - if value > 0 { - ControlFlow::Break((unit, value)) - } else { - ControlFlow::Continue(()) - } - } - - // impl taken and modified from - // https://github.com/tailhook/humantime/blob/master/src/duration.rs#L295-L331 - // Copyright (c) 2016 The humantime Developers - fn fmt(f: Duration) -> ControlFlow<(&'static str, u64), ()> { - let secs = f.as_secs(); - let nanos = f.subsec_nanos(); - - let years = secs / 31_557_600; // 365.25d - let year_days = secs % 31_557_600; - let months = year_days / 2_630_016; // 30.44d - let month_days = year_days % 2_630_016; - let days = month_days / 86400; - let day_secs = month_days % 86400; - let hours = day_secs / 3600; - let minutes = day_secs % 3600 / 60; - let seconds = day_secs % 60; - - let millis = nanos / 1_000_000; - let micros = nanos / 1_000; - - // a difference from our impl than the original is that - // we only care about the most-significant segment of the duration. - // If the item call returns `Break`, then the `?` will early-return. - // This allows for a very consise impl - item("y", years)?; - item("mo", months)?; - item("d", days)?; - item("h", hours)?; - item("m", minutes)?; - item("s", seconds)?; - item("ms", u64::from(millis))?; - item("us", u64::from(micros))?; - item("ns", u64::from(nanos))?; - ControlFlow::Continue(()) - } - - match fmt(dur) { - ControlFlow::Break((unit, value)) => write!(f, "{value}{unit}"), - ControlFlow::Continue(()) => write!(f, "0s"), - } -} - -#[expect(clippy::module_name_repetitions)] -pub fn format_duration(f: Duration) -> String { - struct F(Duration); - impl fmt::Display for F { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - format_duration_into(self.0, f) - } - } - F(f).to_string() -} diff --git a/crates/atuin/src/command/client/search/engines.rs b/crates/atuin/src/command/client/search/engines.rs deleted file mode 100644 index 886f0171..00000000 --- a/crates/atuin/src/command/client/search/engines.rs +++ /dev/null @@ -1,95 +0,0 @@ -use async_trait::async_trait; -use atuin_client::{ - database::{Context, Database, OptFilters}, - history::{AUTHOR_FILTER_ALL_USER, History, HistoryId}, - settings::{FilterMode, SearchMode, Settings}, -}; -use eyre::Result; - -use super::cursor::Cursor; - -#[cfg(feature = "daemon")] -pub mod daemon; -pub mod db; -pub mod skim; - -#[expect(unused)] // settings is only used if daemon feature is enabled -pub fn engine(search_mode: SearchMode, settings: &Settings) -> Box { - match search_mode { - SearchMode::Skim => Box::new(skim::Search::new()) as Box<_>, - #[cfg(feature = "daemon")] - SearchMode::DaemonFuzzy => Box::new(daemon::Search::new(settings)) as Box<_>, - #[cfg(not(feature = "daemon"))] - SearchMode::DaemonFuzzy => { - // Fall back to fuzzy mode if daemon feature is not enabled - Box::new(db::Search(SearchMode::Fuzzy)) as Box<_> - } - mode => Box::new(db::Search(mode)) as Box<_>, - } -} - -pub struct SearchState { - pub input: Cursor, - pub filter_mode: FilterMode, - pub context: Context, - pub custom_context: Option, -} - -impl SearchState { - pub(crate) fn rotate_filter_mode(&mut self, settings: &Settings, offset: isize) { - let mut i = settings - .search - .filters - .iter() - .position(|&m| m == self.filter_mode) - .unwrap_or_default(); - for _ in 0..settings.search.filters.len() { - i = (i.wrapping_add_signed(offset)) % settings.search.filters.len(); - let mode = settings.search.filters[i]; - if self.filter_mode_available(mode, settings) { - self.filter_mode = mode; - break; - } - } - } - - fn filter_mode_available(&self, mode: FilterMode, settings: &Settings) -> bool { - match mode { - FilterMode::Global | FilterMode::SessionPreload => self.custom_context.is_none(), - FilterMode::Workspace => settings.workspaces && self.context.git_root.is_some(), - _ => true, - } - } -} - -#[async_trait] -pub trait SearchEngine: Send + Sync + 'static { - async fn full_query( - &mut self, - state: &SearchState, - db: &mut dyn Database, - ) -> Result>; - - async fn query(&mut self, state: &SearchState, db: &mut dyn Database) -> Result> { - if state.input.as_str().is_empty() { - Ok(db - .search( - SearchMode::FullText, - state.filter_mode, - &state.context, - "", - OptFilters { - limit: Some(200), - authors: vec![AUTHOR_FILTER_ALL_USER.to_string()], - ..Default::default() - }, - ) - .await? - .into_iter() - .collect::>()) - } else { - self.full_query(state, db).await - } - } - fn get_highlight_indices(&self, command: &str, search_input: &str) -> Vec; -} diff --git a/crates/atuin/src/command/client/search/engines/daemon.rs b/crates/atuin/src/command/client/search/engines/daemon.rs deleted file mode 100644 index 8b15c180..00000000 --- a/crates/atuin/src/command/client/search/engines/daemon.rs +++ /dev/null @@ -1,249 +0,0 @@ -use async_trait::async_trait; -use atuin_client::{ - database::{Database, OptFilters}, - history::{AUTHOR_FILTER_ALL_USER, History}, - settings::{SearchMode, Settings}, -}; -use atuin_daemon::client::{DaemonClientErrorKind, SearchClient, classify_error}; -use atuin_nucleo_matcher::{ - Config, Matcher, Utf32Str, - pattern::{CaseMatching, Normalization, Pattern}, -}; -use eyre::Result; -use tracing::{Level, debug, instrument, span}; -use uuid::Uuid; - -use super::{SearchEngine, SearchState}; -use crate::command::client::daemon; - -pub struct Search { - client: Option, - query_id: u64, - settings: Settings, - #[cfg(unix)] - socket_path: String, - #[cfg(not(unix))] - tcp_port: u64, -} - -impl Search { - pub fn new(settings: &Settings) -> Self { - Search { - client: None, - query_id: 0, - settings: settings.clone(), - #[cfg(unix)] - socket_path: settings.daemon.socket_path.clone(), - #[cfg(not(unix))] - tcp_port: settings.daemon.tcp_port, - } - } - - #[instrument(skip_all, level = Level::TRACE, name = "get_daemon_client")] - async fn get_client(&mut self) -> Result<&mut SearchClient> { - if self.client.is_none() { - self.connect().await?; - } - Ok(self.client.as_mut().unwrap()) - } - - async fn connect(&mut self) -> Result<()> { - #[cfg(unix)] - let client = SearchClient::new(self.socket_path.clone()).await?; - - #[cfg(not(unix))] - let client = SearchClient::new(self.tcp_port).await?; - - self.client = Some(client); - Ok(()) - } - - fn should_retry(err: &eyre::Report) -> bool { - matches!( - classify_error(err), - DaemonClientErrorKind::Connect - | DaemonClientErrorKind::Unavailable - | DaemonClientErrorKind::Unimplemented - ) - } - - fn next_query_id(&mut self) -> u64 { - self.query_id += 1; - self.query_id - } - - /// Check if query contains regex pattern (r/.../) - /// Nucleo doesn't support regex, so we fall back to database search - fn contains_regex_pattern(query: &str) -> bool { - query.starts_with("r/") || query.contains(" r/") - } - - #[instrument(skip_all, level = Level::TRACE, name = "daemon_db_fallback")] - async fn fallback_to_db_search( - &self, - state: &SearchState, - db: &dyn Database, - ) -> Result> { - let results = db - .search( - SearchMode::FullText, - state.filter_mode, - &state.context, - state.input.as_str(), - OptFilters { - limit: Some(200), - authors: vec![AUTHOR_FILTER_ALL_USER.to_string()], - ..Default::default() - }, - ) - .await - .map_or(Vec::new(), |r| r.into_iter().collect()); - Ok(results) - } - - #[instrument(skip_all, level = Level::TRACE, name = "hydrate_from_db", fields(count = ids.len()))] - async fn hydrate_from_db(&self, db: &dyn Database, ids: &[String]) -> Result> { - let placeholders: Vec = ids.iter().map(|id| format!("'{id}'")).collect(); - let sql_query = format!( - "SELECT * FROM history WHERE id IN ({}) ORDER BY timestamp DESC", - placeholders.join(",") - ); - Ok(db.query_history(&sql_query).await?) - } -} - -#[async_trait] -impl SearchEngine for Search { - #[instrument(skip_all, level = Level::TRACE, name = "daemon_search", fields(query = %state.input.as_str()))] - async fn full_query( - &mut self, - state: &SearchState, - db: &mut dyn Database, - ) -> Result> { - let query = state.input.as_str().to_string(); - - // Fall back to database for regex queries (Nucleo doesn't support regex) - if Self::contains_regex_pattern(&query) { - debug!(query = %query, "[daemon-client] regex detected, falling back to db"); - return self.fallback_to_db_search(state, db).await; - } - - let query_id = self.next_query_id(); - - let span = - span!(Level::TRACE, "daemon_search.req_resp", query = %query, query_id = query_id); - - // Try to connect and search; if it fails with a retriable error, - // auto-start the daemon and retry once. - let first_attempt = async { - let client = self.get_client().await?; - client - .search( - query.clone(), - query_id, - state.filter_mode, - Some(state.context.clone()), - ) - .await - } - .await; - - let mut stream = match first_attempt { - Ok(stream) => stream, - Err(err) if self.settings.daemon.autostart && Self::should_retry(&err) => { - debug!("daemon not available, attempting auto-start"); - self.client = None; - - daemon::ensure_daemon_running(&self.settings).await?; - - let client = self.get_client().await?; - client - .search( - query.clone(), - query_id, - state.filter_mode, - Some(state.context.clone()), - ) - .await? - } - Err(err) => return Err(err), - }; - - let mut ids = Vec::with_capacity(200); - span!(Level::TRACE, "daemon_search.resp") - .in_scope(async || { - while let Ok(Some(response)) = stream.message().await { - let span2 = span!( - Level::TRACE, - "daemon_search.resp.item", - query_id = response.query_id - ); - let _span2 = span2.enter(); - // Only process if the query_id matches (prevents stale responses) - if response.query_id == query_id { - let uuids = response - .ids - .iter() - .map(|id| { - let bytes: [u8; 16] = - id.as_slice().try_into().expect("id should be 16 bytes"); - Uuid::from_bytes(bytes).as_simple().to_string() - }) - .collect::>(); - ids.extend(uuids); - } - drop(_span2); - drop(span2); - } - }) - .await; - drop(span); - - if ids.is_empty() { - debug!(query = %query, results = 0, "[daemon-client] empty results"); - return Ok(Vec::new()); - } - - // // Hydrate from local database - let results = self.hydrate_from_db(db, &ids).await?; - - // // Reorder results to match the order from the daemon (which is ranked by relevance) - let ordered_results = span!(Level::TRACE, "reorder_results").in_scope(|| { - let mut ordered_results = Vec::with_capacity(results.len()); - for id in &ids { - if let Some(history) = results.iter().find(|h| h.id.0 == *id) { - ordered_results.push(history.clone()); - } - } - ordered_results - }); - - debug!( - query = %query, - results = results.len(), - "[daemon-client]" - ); - - Ok(ordered_results) - } - - #[instrument(skip_all, level = Level::TRACE, name = "daemon_highlight")] - fn get_highlight_indices(&self, command: &str, search_input: &str) -> Vec { - // Use fulltext highlighting for regex queries - if Self::contains_regex_pattern(search_input) { - return super::db::get_highlight_indices_fulltext(command, search_input); - } - - let mut matcher = Matcher::new(Config::DEFAULT); - let pattern = Pattern::parse(search_input, CaseMatching::Smart, Normalization::Smart); - - let mut indices: Vec = Vec::new(); - let mut haystack_buf = Vec::new(); - - let haystack = Utf32Str::new(command, &mut haystack_buf); - pattern.indices(haystack, &mut matcher, &mut indices); - - // Convert u32 indices to usize - indices.into_iter().map(|i| i as usize).collect() - } -} diff --git a/crates/atuin/src/command/client/search/engines/db.rs b/crates/atuin/src/command/client/search/engines/db.rs deleted file mode 100644 index b15aabd8..00000000 --- a/crates/atuin/src/command/client/search/engines/db.rs +++ /dev/null @@ -1,110 +0,0 @@ -use super::{SearchEngine, SearchState}; -use async_trait::async_trait; -use atuin_client::{ - database::Database, - database::OptFilters, - database::{QueryToken, QueryTokenizer}, - history::{AUTHOR_FILTER_ALL_USER, History}, - settings::SearchMode, -}; -use eyre::Result; -use norm::Metric; -use norm::fzf::{FzfParser, FzfV2}; -use std::ops::Range; -use tracing::{Level, instrument}; - -pub struct Search(pub SearchMode); - -#[async_trait] -impl SearchEngine for Search { - #[instrument(skip_all, level = Level::TRACE, name = "db_search", fields(mode = ?self.0, query = %state.input.as_str()))] - async fn full_query( - &mut self, - state: &SearchState, - db: &mut dyn Database, - ) -> Result> { - let results = db - .search( - self.0, - state.filter_mode, - &state.context, - state.input.as_str(), - OptFilters { - limit: Some(200), - authors: vec![AUTHOR_FILTER_ALL_USER.to_string()], - ..Default::default() - }, - ) - .await - // ignore errors as it may be caused by incomplete regex - .map_or(Vec::new(), |r| r.into_iter().collect()); - Ok(results) - } - - #[instrument(skip_all, level = Level::TRACE, name = "db_highlight")] - fn get_highlight_indices(&self, command: &str, search_input: &str) -> Vec { - if self.0 == SearchMode::Prefix { - return vec![]; - } else if self.0 == SearchMode::FullText { - return get_highlight_indices_fulltext(command, search_input); - } - let mut fzf = FzfV2::new(); - let mut parser = FzfParser::new(); - let query = parser.parse(search_input); - let mut ranges: Vec> = Vec::new(); - let _ = fzf.distance_and_ranges(query, command, &mut ranges); - - // convert ranges to all indices - ranges.into_iter().flatten().collect() - } -} - -#[instrument(skip_all, level = Level::TRACE, name = "db_highlight_fulltext")] -pub fn get_highlight_indices_fulltext(command: &str, search_input: &str) -> Vec { - let mut ranges = vec![]; - let lower_command = command.to_ascii_lowercase(); - - for token in QueryTokenizer::new(search_input) { - let matchee = if token.has_uppercase() { - command - } else { - &lower_command - }; - - if token.is_inverse() { - continue; - } - - match token { - QueryToken::Or => {} - QueryToken::Regex(r) => { - if let Ok(re) = regex::Regex::new(r) { - for m in re.find_iter(command) { - ranges.push(m.range()); - } - } - } - QueryToken::MatchStart(term, _) => { - if matchee.starts_with(term) { - ranges.push(0..term.len()); - } - } - QueryToken::MatchEnd(term, _) => { - if matchee.ends_with(term) { - let l = matchee.len(); - ranges.push((l - term.len())..l); - } - } - QueryToken::Match(term, _) | QueryToken::MatchFull(term, _) => { - for (idx, m) in matchee.match_indices(term) { - ranges.push(idx..(idx + m.len())); - } - } - } - } - - let mut ret: Vec<_> = ranges.into_iter().flatten().collect(); - ret.sort_unstable(); - ret.dedup(); - ret -} diff --git a/crates/atuin/src/command/client/search/engines/skim.rs b/crates/atuin/src/command/client/search/engines/skim.rs deleted file mode 100644 index fe05fd09..00000000 --- a/crates/atuin/src/command/client/search/engines/skim.rs +++ /dev/null @@ -1,229 +0,0 @@ -use std::path::Path; - -use async_trait::async_trait; -use atuin_client::{ - database::Database, - history::{History, is_known_agent}, - settings::FilterMode, -}; -use eyre::Result; -use fuzzy_matcher::{FuzzyMatcher, skim::SkimMatcherV2}; -use itertools::Itertools; -use time::OffsetDateTime; -use tokio::task::yield_now; -use tracing::{Level, instrument, warn}; -use uuid; - -use super::{SearchEngine, SearchState}; - -pub struct Search { - all_history: Vec<(History, i32)>, - engine: SkimMatcherV2, -} - -impl Search { - pub fn new() -> Self { - Search { - all_history: vec![], - engine: SkimMatcherV2::default(), - } - } -} - -#[async_trait] -impl SearchEngine for Search { - #[instrument(skip_all, level = Level::TRACE, name = "skim_search", fields(query = %state.input.as_str()))] - async fn full_query( - &mut self, - state: &SearchState, - db: &mut dyn Database, - ) -> Result> { - if self.all_history.is_empty() { - self.all_history = load_all_history(db).await; - } - - Ok(fuzzy_search(&self.engine, state, &self.all_history).await) - } - - #[instrument(skip_all, level = Level::TRACE, name = "skim_highlight")] - fn get_highlight_indices(&self, command: &str, search_input: &str) -> Vec { - let (_, indices) = self - .engine - .fuzzy_indices(command, search_input) - .unwrap_or_default(); - indices - } -} - -#[instrument(skip_all, level = Level::TRACE, name = "load_all_history")] -async fn load_all_history(db: &dyn Database) -> Vec<(History, i32)> { - db.all_with_count().await.unwrap() -} - -#[expect(clippy::too_many_lines)] -#[instrument(skip_all, level = Level::TRACE, name = "fuzzy_match", fields(history_count = all_history.len()))] -async fn fuzzy_search( - engine: &SkimMatcherV2, - state: &SearchState, - all_history: &[(History, i32)], -) -> Vec { - let mut set = Vec::with_capacity(200); - let mut ranks = Vec::with_capacity(200); - let query = state.input.as_str(); - let now = OffsetDateTime::now_utc(); - - for (i, (history, count)) in all_history.iter().enumerate() { - if i % 256 == 0 { - yield_now().await; - } - if is_known_agent(&history.author) { - continue; - } - let context = &state.context; - let git_root = context - .git_root - .as_ref() - .and_then(|git_root| git_root.to_str()) - .unwrap_or(&context.cwd); - match state.filter_mode { - FilterMode::Global => {} - // we aggregate host by ',' separating them - FilterMode::Host - if history - .hostname - .split(',') - .contains(&context.hostname.as_str()) => {} - // we aggregate session by concattenating them. - // sessions are 32 byte simple uuid formats - FilterMode::Session - if history - .session - .as_bytes() - .chunks(32) - .contains(&context.session.as_bytes()) => {} - // SessionPreload: include current session + global history from before session start - FilterMode::SessionPreload => { - let is_current_session = { - history - .session - .as_bytes() - .chunks(32) - .any(|chunk| chunk == context.session.as_bytes()) - }; - - if !is_current_session { - let Ok(uuid) = uuid::Uuid::parse_str(&context.session) else { - warn!("failed to parse session id '{}'", context.session); - continue; - }; - let Some(timestamp) = uuid.get_timestamp() else { - warn!( - "failed to get timestamp from uuid '{}'", - uuid.as_hyphenated() - ); - continue; - }; - let (seconds, nanos) = timestamp.to_unix(); - let Ok(session_start) = time::OffsetDateTime::from_unix_timestamp_nanos( - i128::from(seconds) * 1_000_000_000 + i128::from(nanos), - ) else { - warn!( - "failed to create OffsetDateTime from second: {seconds}, nanosecond: {nanos}" - ); - continue; - }; - - if history.timestamp >= session_start { - continue; - } - } - } - // we aggregate directory by ':' separating them - FilterMode::Directory if history.cwd.split(':').contains(&context.cwd.as_str()) => {} - FilterMode::Workspace if history.cwd.split(':').contains(&git_root) => {} - _ => continue, - } - #[expect(clippy::cast_lossless, clippy::cast_precision_loss)] - if let Some((score, indices)) = engine.fuzzy_indices(&history.command, query) { - let begin = indices.first().copied().unwrap_or_default(); - - let mut duration = (now - history.timestamp).as_seconds_f64().log2(); - if !duration.is_finite() || duration <= 1.0 { - duration = 1.0; - } - // these + X.0 just make the log result a bit smoother. - // log is very spiky towards 1-4, but I want a gradual decay. - // eg: - // log2(4) = 2, log2(5) = 2.3 (16% increase) - // log2(8) = 3, log2(9) = 3.16 (5% increase) - // log2(16) = 4, log2(17) = 4.08 (2% increase) - let count = (*count as f64 + 8.0).log2(); - let begin = (begin as f64 + 16.0).log2(); - let path = path_dist(history.cwd.as_ref(), state.context.cwd.as_ref()); - let path = (path as f64 + 8.0).log2(); - - // reduce longer durations, raise higher counts, raise matches close to the start - let score = (-score as f64) * count / path / duration / begin; - - 'insert: { - // algorithm: - // 1. find either the position that this command ranks - // 2. find the same command positioned better than our rank. - for i in 0..set.len() { - // do we out score the current position? - if ranks[i] > score { - ranks.insert(i, score); - set.insert(i, history.clone()); - let mut j = i + 1; - while j < set.len() { - // remove duplicates that have a worse score - if set[j].command == history.command { - ranks.remove(j); - set.remove(j); - - // break this while loop because there won't be any other - // duplicates. - break; - } - j += 1; - } - - // keep it limited - if ranks.len() > 200 { - ranks.pop(); - set.pop(); - } - - break 'insert; - } - // don't continue if this command has a better score already - if set[i].command == history.command { - break 'insert; - } - } - - if set.len() < 200 { - ranks.push(score); - set.push(history.clone()); - } - } - } - } - - set -} - -fn path_dist(a: &Path, b: &Path) -> usize { - let mut a: Vec<_> = a.components().collect(); - let b: Vec<_> = b.components().collect(); - - let mut dist = 0; - - // pop a until there's a common ancestor - while !b.starts_with(&a) { - dist += 1; - a.pop(); - } - - b.len() - a.len() + dist -} diff --git a/crates/atuin/src/command/client/search/history_list.rs b/crates/atuin/src/command/client/search/history_list.rs deleted file mode 100644 index 7af324b4..00000000 --- a/crates/atuin/src/command/client/search/history_list.rs +++ /dev/null @@ -1,429 +0,0 @@ -use std::time::Duration; - -use super::duration::format_duration; -use super::engines::SearchEngine; -use atuin_client::{ - history::History, - settings::{UiColumn, UiColumnType}, - theme::{Meaning, Theme}, -}; -use atuin_common::utils::Escapable as _; -use itertools::Itertools; -use ratatui::{ - backend::FromCrossterm, - buffer::Buffer, - crossterm::style, - layout::Rect, - style::{Modifier, Style}, - widgets::{Block, StatefulWidget, Widget}, -}; -use time::OffsetDateTime; - -pub struct HistoryHighlighter<'a> { - pub engine: &'a dyn SearchEngine, - pub search_input: &'a str, -} - -impl HistoryHighlighter<'_> { - pub fn get_highlight_indices(&self, command: &str) -> Vec { - self.engine - .get_highlight_indices(command, self.search_input) - } -} - -pub struct HistoryList<'a> { - history: &'a [History], - block: Option>, - inverted: bool, - /// Apply an alternative highlighting to the selected row - alternate_highlight: bool, - now: &'a dyn Fn() -> OffsetDateTime, - indicator: &'a str, - theme: &'a Theme, - history_highlighter: HistoryHighlighter<'a>, - show_numeric_shortcuts: bool, - /// Columns to display (in order, after the indicator) - columns: &'a [UiColumn], -} - -#[derive(Default)] -pub struct ListState { - offset: usize, - selected: usize, - max_entries: usize, -} - -impl ListState { - pub fn selected(&self) -> usize { - self.selected - } - - pub fn max_entries(&self) -> usize { - self.max_entries - } - - pub fn offset(&self) -> usize { - self.offset - } - - pub fn select(&mut self, index: usize) { - self.selected = index; - } -} - -impl StatefulWidget for HistoryList<'_> { - type State = ListState; - - fn render(mut self, area: Rect, buf: &mut Buffer, state: &mut Self::State) { - let list_area = self.block.take().map_or(area, |b| { - let inner_area = b.inner(area); - b.render(area, buf); - inner_area - }); - - if list_area.width < 1 || list_area.height < 1 || self.history.is_empty() { - return; - } - let list_height = list_area.height as usize; - - let (start, end) = self.get_items_bounds(state.selected, state.offset, list_height); - state.offset = start; - state.max_entries = end - start; - - let mut s = DrawState { - buf, - list_area, - x: 0, - y: 0, - state, - inverted: self.inverted, - alternate_highlight: self.alternate_highlight, - now: &self.now, - indicator: self.indicator, - theme: self.theme, - history_highlighter: self.history_highlighter, - show_numeric_shortcuts: self.show_numeric_shortcuts, - columns: self.columns, - }; - - for item in self.history.iter().skip(state.offset).take(end - start) { - s.render_row(item); - - // reset line - s.y += 1; - s.x = 0; - } - } -} - -impl<'a> HistoryList<'a> { - #[expect(clippy::too_many_arguments)] - pub fn new( - history: &'a [History], - inverted: bool, - alternate_highlight: bool, - now: &'a dyn Fn() -> OffsetDateTime, - indicator: &'a str, - theme: &'a Theme, - history_highlighter: HistoryHighlighter<'a>, - show_numeric_shortcuts: bool, - columns: &'a [UiColumn], - ) -> Self { - Self { - history, - block: None, - inverted, - alternate_highlight, - now, - indicator, - theme, - history_highlighter, - show_numeric_shortcuts, - columns, - } - } - - pub fn block(mut self, block: Block<'a>) -> Self { - self.block = Some(block); - self - } - - fn get_items_bounds(&self, selected: usize, offset: usize, height: usize) -> (usize, usize) { - let offset = offset.min(self.history.len().saturating_sub(1)); - - let max_scroll_space = height.min(10).min(self.history.len() - selected); - if offset + height < selected + max_scroll_space { - let end = selected + max_scroll_space; - (end - height, end) - } else if selected < offset { - (selected, selected + height) - } else { - (offset, offset + height) - } - } -} - -struct DrawState<'a> { - buf: &'a mut Buffer, - list_area: Rect, - x: u16, - y: u16, - state: &'a ListState, - inverted: bool, - alternate_highlight: bool, - now: &'a dyn Fn() -> OffsetDateTime, - indicator: &'a str, - theme: &'a Theme, - history_highlighter: HistoryHighlighter<'a>, - show_numeric_shortcuts: bool, - columns: &'a [UiColumn], -} - -// these encode the slices of `" > "`, `" {n} "`, or `" "` in a compact form. -// Yes, this is a hack, but it makes me feel happy -static SLICES: &str = " > 1 2 3 4 5 6 7 8 9 "; - -impl DrawState<'_> { - /// Render a complete row for a history item based on configured columns. - fn render_row(&mut self, h: &History) { - // Always render the indicator first (width 3) - self.index(); - - // Calculate the width for the expanding column - // Fixed columns use their configured width + 1 (trailing space) - let indicator_width: u16 = 3; - let fixed_width: u16 = self - .columns - .iter() - .filter(|c| !c.expand) - .map(|c| c.width + 1) - .sum(); - let expand_width = self - .list_area - .width - .saturating_sub(indicator_width + fixed_width); - - let style = self.theme.as_style(Meaning::Base); - // Render each configured column - for (idx, column) in self.columns.iter().enumerate() { - if idx != 0 { - self.draw(" ", Style::from_crossterm(style)); - } - let width = if column.expand { - expand_width - } else { - column.width - }; - match column.column_type { - UiColumnType::Duration => self.duration(h, width), - UiColumnType::Time => self.time(h, width), - UiColumnType::Datetime => self.datetime(h, width), - UiColumnType::Directory => self.directory(h, width), - UiColumnType::Host => self.host(h, width), - UiColumnType::User => self.user(h, width), - UiColumnType::Exit => self.exit_code(h, width), - UiColumnType::Command => self.command(h), - } - } - } - - fn index(&mut self) { - if !self.show_numeric_shortcuts { - let i = self.y as usize + self.state.offset; - let is_selected = i == self.state.selected(); - let prompt: &str = if is_selected { self.indicator } else { " " }; - self.draw(prompt, Style::default()); - return; - } - - // these encode the slices of `" > "`, `" {n} "`, or `" "` in a compact form. - // Yes, this is a hack, but it makes me feel happy - - let i = self.y as usize + self.state.offset; - let i = i.checked_sub(self.state.selected); - let i = i.unwrap_or(10).min(10) * 2; - let prompt: &str = if i == 0 { - self.indicator - } else { - &SLICES[i..i + 3] - }; - self.draw(prompt, Style::default()); - } - - fn duration(&mut self, h: &History, width: u16) { - let style = self.theme.as_style(if h.success() { - Meaning::AlertInfo - } else { - Meaning::AlertError - }); - let duration = Duration::from_nanos(u64::try_from(h.duration).unwrap_or(0)); - let formatted = format_duration(duration); - let w = width as usize; - // Right-align duration within its column width, plus trailing space - let display = format!("{formatted:>w$}"); - self.draw(&display, Style::from_crossterm(style)); - } - - fn time(&mut self, h: &History, width: u16) { - let style = self.theme.as_style(Meaning::Guidance); - - // Account for the chance that h.timestamp is "in the future" - // This would mean that "since" is negative, and the unwrap here - // would fail. - // If the timestamp would otherwise be in the future, display - // the time since as 0. - let since = (self.now)() - h.timestamp; - let time = format_duration(since.try_into().unwrap_or_default()); - - // Format as "Xs ago" right-aligned within column width - let w = width as usize; - let time_str = format!("{time} ago"); - - let display = format!("{time_str:>w$}"); - self.draw(&display, Style::from_crossterm(style)); - } - - fn command(&mut self, h: &History) { - let mut style = self.theme.as_style(Meaning::Base); - let mut row_highlighted = false; - if !self.alternate_highlight && (self.y as usize + self.state.offset == self.state.selected) - { - row_highlighted = true; - // if not applying alternative highlighting to the whole row, color the command - style = self.theme.as_style(Meaning::AlertError); - style.attributes.set(style::Attribute::Bold); - } - - let highlight_indices = self.history_highlighter.get_highlight_indices( - h.command - .escape_control() - .split_ascii_whitespace() - .join(" ") - .as_str(), - ); - - let mut pos = 0; - for section in h.command.escape_control().split_ascii_whitespace() { - if pos != 0 { - self.draw(" ", Style::from_crossterm(style)); - } - for ch in section.chars() { - if self.x > self.list_area.width { - // Avoid attempting to draw a command section beyond the width - // of the list - return; - } - let mut style = style; - if highlight_indices.contains(&pos) { - if row_highlighted { - // if the row is highlighted bold is not enough as the whole row is bold - // change the color too - style = self.theme.as_style(Meaning::AlertWarn); - } - style.attributes.set(style::Attribute::Bold); - } - let s = ch.to_string(); - self.draw(&s, Style::from_crossterm(style)); - pos += s.len(); - } - pos += 1; - } - } - - /// Render the absolute datetime column (e.g., "2025-01-22 14:35") - fn datetime(&mut self, h: &History, width: u16) { - let style = self.theme.as_style(Meaning::Annotation); - // Format: YYYY-MM-DD HH:MM - let formatted = h - .timestamp - .format( - &time::format_description::parse("[year]-[month]-[day] [hour]:[minute]") - .expect("valid format"), - ) - .unwrap_or_else(|_| "????-??-?? ??:??".to_string()); - let w = width as usize; - let display = format!("{formatted:w$}"); - self.draw(&display, Style::from_crossterm(style)); - } - - /// Render the directory column (working directory, truncated) - fn directory(&mut self, h: &History, width: u16) { - let style = self.theme.as_style(Meaning::Annotation); - let w = width as usize; - let cwd = &h.cwd; - let char_count = cwd.chars().count(); - // Truncate from the left with "..." if too long, plus trailing space - // Use character count for comparison and skip for UTF-8 safety - let display = if char_count > w && w >= 4 { - let truncated: String = cwd.chars().skip(char_count - (w - 3)).collect(); - format!("...{truncated}") - } else { - format!("{cwd:w$}") - }; - self.draw(&display, Style::from_crossterm(style)); - } - - /// Render the host column (just the hostname) - fn host(&mut self, h: &History, width: u16) { - let style = self.theme.as_style(Meaning::Annotation); - let w = width as usize; - // Database stores hostname as "hostname:username" - let host = h.hostname.split(':').next().unwrap_or(&h.hostname); - let char_count = host.chars().count(); - // Use character count for comparison and take for UTF-8 safety - let display = if char_count > w && w >= 4 { - let truncated: String = host.chars().take(w.saturating_sub(4)).collect(); - format!("{truncated}...") - } else { - format!("{host:w$}") - }; - self.draw(&display, Style::from_crossterm(style)); - } - - /// Render the user column - fn user(&mut self, h: &History, width: u16) { - let style = self.theme.as_style(Meaning::Annotation); - let w = width as usize; - // Database stores hostname as "hostname:username" - let user = h.hostname.split(':').nth(1).unwrap_or(""); - let char_count = user.chars().count(); - // Use character count for comparison and take for UTF-8 safety - let display = if char_count > w && w >= 4 { - let truncated: String = user.chars().take(w.saturating_sub(4)).collect(); - format!("{truncated}...") - } else { - format!("{user:w$}") - }; - self.draw(&display, Style::from_crossterm(style)); - } - - /// Render the exit code column - fn exit_code(&mut self, h: &History, width: u16) { - let style = if h.success() { - self.theme.as_style(Meaning::AlertInfo) - } else { - self.theme.as_style(Meaning::AlertError) - }; - let w = width as usize; - let display = format!("{:>w$}", h.exit); - self.draw(&display, Style::from_crossterm(style)); - } - - fn draw(&mut self, s: &str, mut style: Style) { - let cx = self.list_area.left() + self.x; - - let cy = if self.inverted { - self.list_area.top() + self.y - } else { - self.list_area.bottom() - self.y - 1 - }; - - if self.alternate_highlight && (self.y as usize + self.state.offset == self.state.selected) - { - style = style.add_modifier(Modifier::REVERSED); - } - - let w = (self.list_area.width - self.x) as usize; - self.x += self.buf.set_stringn(cx, cy, s, w, style).0 - cx; - } -} diff --git a/crates/atuin/src/command/client/search/inspector.rs b/crates/atuin/src/command/client/search/inspector.rs deleted file mode 100644 index e2cdabe5..00000000 --- a/crates/atuin/src/command/client/search/inspector.rs +++ /dev/null @@ -1,421 +0,0 @@ -use std::time::Duration; -use time::macros::format_description; - -use atuin_client::{ - history::{History, HistoryStats}, - settings::{Settings, Timezone}, -}; -use ratatui::{ - Frame, - backend::FromCrossterm, - layout::Rect, - prelude::{Constraint, Direction, Layout}, - style::Style, - text::{Span, Text}, - widgets::{Bar, BarChart, BarGroup, Block, Borders, Padding, Paragraph, Row, Table}, -}; - -use super::duration::format_duration; - -use super::super::theme::{Meaning, Theme}; -use super::interactive::{Compactness, to_compactness}; - -#[expect(clippy::cast_sign_loss)] -fn u64_or_zero(num: i64) -> u64 { - if num < 0 { 0 } else { num as u64 } -} - -pub fn draw_commands( - f: &mut Frame<'_>, - parent: Rect, - history: &History, - stats: &HistoryStats, - compact: bool, - theme: &Theme, -) { - let commands = Layout::default() - .direction(if compact { - Direction::Vertical - } else { - Direction::Horizontal - }) - .constraints(if compact { - [ - Constraint::Length(1), - Constraint::Length(1), - Constraint::Min(0), - ] - } else { - [ - Constraint::Ratio(1, 4), - Constraint::Ratio(1, 2), - Constraint::Ratio(1, 4), - ] - }) - .split(parent); - - let command = Paragraph::new(Text::from(Span::styled( - history.command.clone(), - Style::from_crossterm(theme.as_style(Meaning::Important)), - ))) - .block(if compact { - Block::new() - .borders(Borders::NONE) - .style(Style::from_crossterm(theme.as_style(Meaning::Base))) - } else { - Block::new() - .borders(Borders::ALL) - .style(Style::from_crossterm(theme.as_style(Meaning::Base))) - .title("Command") - .padding(Padding::horizontal(1)) - }); - - let previous = Paragraph::new( - stats - .previous - .clone() - .map_or_else(|| "[No previous command]".to_string(), |prev| prev.command), - ) - .block(if compact { - Block::new() - .borders(Borders::NONE) - .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))) - } else { - Block::new() - .borders(Borders::ALL) - .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))) - .title("Previous command") - .padding(Padding::horizontal(1)) - }); - - // Add [] around blank text, as when this is shown in a list - // compacted, it makes it more obviously control text. - let next = Paragraph::new( - stats - .next - .clone() - .map_or_else(|| "[No next command]".to_string(), |next| next.command), - ) - .block(if compact { - Block::new() - .borders(Borders::NONE) - .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))) - } else { - Block::new() - .borders(Borders::ALL) - .title("Next command") - .padding(Padding::horizontal(1)) - .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))) - }); - - f.render_widget(previous, commands[0]); - f.render_widget(command, commands[1]); - f.render_widget(next, commands[2]); -} - -pub fn draw_stats_table( - f: &mut Frame<'_>, - parent: Rect, - history: &History, - tz: Timezone, - stats: &HistoryStats, - theme: &Theme, -) { - let duration = Duration::from_nanos(u64_or_zero(history.duration)); - let avg_duration = Duration::from_nanos(stats.average_duration); - let (host, user) = history.hostname.split_once(':').unwrap_or(("", "")); - - let rows = [ - Row::new(vec!["Host".to_string(), host.to_string()]), - Row::new(vec!["User".to_string(), user.to_string()]), - Row::new(vec![ - "Time".to_string(), - history.timestamp.to_offset(tz.0).to_string(), - ]), - Row::new(vec!["Duration".to_string(), format_duration(duration)]), - Row::new(vec![ - "Avg duration".to_string(), - format_duration(avg_duration), - ]), - Row::new(vec!["Exit".to_string(), history.exit.to_string()]), - Row::new(vec!["Directory".to_string(), history.cwd.clone()]), - Row::new(vec!["Session".to_string(), history.session.clone()]), - Row::new(vec!["Total runs".to_string(), stats.total.to_string()]), - ]; - - let widths = [Constraint::Ratio(1, 5), Constraint::Ratio(4, 5)]; - - let table = Table::new(rows, widths).column_spacing(1).block( - Block::default() - .title("Command stats") - .borders(Borders::ALL) - .style(Style::from_crossterm(theme.as_style(Meaning::Base))) - .padding(Padding::vertical(1)), - ); - - f.render_widget(table, parent); -} - -fn num_to_day(num: &str) -> String { - match num { - "0" => "Sunday".to_string(), - "1" => "Monday".to_string(), - "2" => "Tuesday".to_string(), - "3" => "Wednesday".to_string(), - "4" => "Thursday".to_string(), - "5" => "Friday".to_string(), - "6" => "Saturday".to_string(), - _ => "Invalid day".to_string(), - } -} - -fn sort_duration_over_time(durations: &[(String, i64)]) -> Vec<(String, i64)> { - let format = format_description!("[day]-[month]-[year]"); - let output = format_description!("[month]/[year repr:last_two]"); - - let mut durations: Vec<(time::Date, i64)> = durations - .iter() - .map(|d| { - ( - time::Date::parse(d.0.as_str(), &format).expect("invalid date string from sqlite"), - d.1, - ) - }) - .collect(); - - durations.sort_by_key(|a| a.0); - - durations - .iter() - .map(|(date, duration)| { - ( - date.format(output).expect("failed to format sqlite date"), - *duration, - ) - }) - .collect() -} - -fn draw_stats_charts(f: &mut Frame<'_>, parent: Rect, stats: &HistoryStats, theme: &Theme) { - let exits: Vec = stats - .exits - .iter() - .map(|(exit, count)| { - Bar::default() - .label(exit.to_string()) - .value(u64_or_zero(*count)) - }) - .collect(); - - let exits = BarChart::default() - .block( - Block::default() - .title("Exit code distribution") - .style(Style::from_crossterm(theme.as_style(Meaning::Base))) - .borders(Borders::ALL), - ) - .bar_width(3) - .bar_gap(1) - .bar_style(Style::default()) - .value_style(Style::default()) - .label_style(Style::default()) - .data(BarGroup::default().bars(&exits)); - - let day_of_week: Vec = stats - .day_of_week - .iter() - .map(|(day, count)| { - Bar::default() - .label(num_to_day(day.as_str())) - .value(u64_or_zero(*count)) - }) - .collect(); - - let day_of_week = BarChart::default() - .block( - Block::default() - .title("Runs per day") - .style(Style::from_crossterm(theme.as_style(Meaning::Base))) - .borders(Borders::ALL), - ) - .bar_width(3) - .bar_gap(1) - .bar_style(Style::default()) - .value_style(Style::default()) - .label_style(Style::default()) - .data(BarGroup::default().bars(&day_of_week)); - - let duration_over_time = sort_duration_over_time(&stats.duration_over_time); - let duration_over_time: Vec = duration_over_time - .iter() - .map(|(date, duration)| { - let d = Duration::from_nanos(u64_or_zero(*duration)); - Bar::default() - .label(date.clone()) - .value(u64_or_zero(*duration)) - .text_value(format_duration(d)) - }) - .collect(); - - let duration_over_time = BarChart::default() - .block( - Block::default() - .title("Duration over time") - .style(Style::from_crossterm(theme.as_style(Meaning::Base))) - .borders(Borders::ALL), - ) - .bar_width(5) - .bar_gap(1) - .bar_style(Style::default()) - .value_style(Style::default()) - .label_style(Style::default()) - .data(BarGroup::default().bars(&duration_over_time)); - - let layout = Layout::default() - .direction(Direction::Vertical) - .constraints([ - Constraint::Ratio(1, 3), - Constraint::Ratio(1, 3), - Constraint::Ratio(1, 3), - ]) - .split(parent); - - f.render_widget(exits, layout[0]); - f.render_widget(day_of_week, layout[1]); - f.render_widget(duration_over_time, layout[2]); -} - -pub fn draw( - f: &mut Frame<'_>, - chunk: Rect, - history: &History, - stats: &HistoryStats, - settings: &Settings, - theme: &Theme, - tz: Timezone, -) { - let compactness = to_compactness(f, settings); - - match compactness { - Compactness::Ultracompact => draw_ultracompact(f, chunk, history, stats, theme), - _ => draw_full(f, chunk, history, stats, theme, tz), - } -} - -pub fn draw_ultracompact( - f: &mut Frame<'_>, - chunk: Rect, - history: &History, - stats: &HistoryStats, - theme: &Theme, -) { - draw_commands(f, chunk, history, stats, true, theme); -} - -pub fn draw_full( - f: &mut Frame<'_>, - chunk: Rect, - history: &History, - stats: &HistoryStats, - theme: &Theme, - tz: Timezone, -) { - let vert_layout = Layout::default() - .direction(Direction::Vertical) - .constraints([Constraint::Ratio(1, 5), Constraint::Ratio(4, 5)]) - .split(chunk); - - let stats_layout = Layout::default() - .direction(Direction::Horizontal) - .constraints([Constraint::Ratio(1, 3), Constraint::Ratio(2, 3)]) - .split(vert_layout[1]); - - draw_commands(f, vert_layout[0], history, stats, false, theme); - draw_stats_table(f, stats_layout[0], history, tz, stats, theme); - draw_stats_charts(f, stats_layout[1], stats, theme); -} - -#[cfg(test)] -mod tests { - use super::draw_ultracompact; - use atuin_client::{ - history::{History, HistoryId, HistoryStats}, - theme::ThemeManager, - }; - use ratatui::{backend::TestBackend, prelude::*}; - use time::OffsetDateTime; - - fn mock_history_stats() -> (History, HistoryStats) { - let history = History { - id: HistoryId::from("test1".to_string()), - timestamp: OffsetDateTime::now_utc(), - duration: 3, - exit: 0, - command: "/bin/cmd".to_string(), - cwd: "/toot".to_string(), - session: "sesh1".to_string(), - hostname: "hostn".to_string(), - author: "hostn".to_string(), - intent: None, - deleted_at: None, - }; - let next = History { - id: HistoryId::from("test2".to_string()), - timestamp: OffsetDateTime::now_utc(), - duration: 2, - exit: 0, - command: "/bin/cmd -os".to_string(), - cwd: "/toot".to_string(), - session: "sesh1".to_string(), - hostname: "hostn".to_string(), - author: "hostn".to_string(), - intent: None, - deleted_at: None, - }; - let prev = History { - id: HistoryId::from("test3".to_string()), - timestamp: OffsetDateTime::now_utc(), - duration: 1, - exit: 0, - command: "/bin/cmd -a".to_string(), - cwd: "/toot".to_string(), - session: "sesh1".to_string(), - hostname: "hostn".to_string(), - author: "hostn".to_string(), - intent: None, - deleted_at: None, - }; - let stats = HistoryStats { - next: Some(next.clone()), - previous: Some(prev.clone()), - total: 2, - average_duration: 3, - exits: Vec::new(), - day_of_week: Vec::new(), - duration_over_time: Vec::new(), - }; - (history, stats) - } - - #[test] - fn test_output_looks_correct_for_ultracompact() { - let backend = TestBackend::new(22, 5); - let mut terminal = Terminal::new(backend).expect("Could not create terminal"); - let chunk = Rect::new(0, 0, 22, 5); - let (history, stats) = mock_history_stats(); - let prev = stats.previous.clone().unwrap(); - let next = stats.next.clone().unwrap(); - - let mut manager = ThemeManager::new(Some(true), Some("".to_string())); - let theme = manager.load_theme("(none)", None); - let _ = terminal.draw(|f| draw_ultracompact(f, chunk, &history, &stats, &theme)); - let mut lines = [" "; 5].map(|l| Line::from(l)); - for (n, entry) in [prev, history, next].iter().enumerate() { - let mut l = lines[n].to_string(); - l.replace_range(0..entry.command.len(), &entry.command); - lines[n] = Line::from(l); - } - - terminal.backend().assert_buffer_lines(lines); - } -} diff --git a/crates/atuin/src/command/client/search/interactive.rs b/crates/atuin/src/command/client/search/interactive.rs deleted file mode 100644 index 4efba803..00000000 --- a/crates/atuin/src/command/client/search/interactive.rs +++ /dev/null @@ -1,3099 +0,0 @@ -use std::{ - io::{IsTerminal, Write, stdout}, - time::Duration, -}; - -#[cfg(unix)] -use std::io::Read as _; - -use atuin_common::{shell::Shell, utils::Escapable as _}; -use eyre::Result; -use time::OffsetDateTime; -use unicode_width::{UnicodeWidthChar, UnicodeWidthStr}; - -use super::{ - cursor::Cursor, - engines::{SearchEngine, SearchState}, - history_list::{HistoryList, ListState}, -}; -use atuin_client::{ - database::{Context, Database, current_context}, - history::{History, HistoryId, HistoryStats, store::HistoryStore}, - settings::{ - CursorStyle, ExitMode, FilterMode, KeymapMode, PreviewStrategy, SearchMode, Settings, - UiColumn, - }, -}; - -use crate::command::client::search::history_list::HistoryHighlighter; -use crate::command::client::search::keybindings::KeymapSet; -use crate::command::client::theme::{Meaning, Theme}; -use crate::{VERSION, command::client::search::engines}; - -use ratatui::{ - Frame, Terminal, TerminalOptions, Viewport, - backend::{CrosstermBackend, FromCrossterm}, - crossterm::{ - cursor::SetCursorStyle, - event::{self, Event, KeyEvent, MouseEvent}, - execute, queue, terminal, - }, - layout::{Alignment, Constraint, Direction, Layout}, - prelude::*, - style::{Modifier, Style}, - text::{Line, Span, Text}, - widgets::{Block, BorderType, Borders, Clear, Padding, Paragraph, Tabs}, -}; - -#[cfg(not(target_os = "windows"))] -use ratatui::crossterm::event::{ - KeyboardEnhancementFlags, PopKeyboardEnhancementFlags, PushKeyboardEnhancementFlags, -}; - -#[cfg(windows)] -use windows_sys::Win32::System::Console::{GetConsoleOutputCP, SetConsoleOutputCP}; - -const TAB_TITLES: [&str; 2] = ["Search", "Inspect"]; - -pub enum InputAction { - Accept(usize), - AcceptInspecting, - Copy(usize), - Delete(usize), - DeleteAllMatching(usize), - ReturnOriginal, - ReturnQuery, - Continue, - Redraw, - SwitchContext(Option), -} - -#[derive(Clone)] -pub struct InspectingState { - current: Option, - next: Option, - previous: Option, -} - -impl InspectingState { - pub fn move_to_previous(&mut self) { - let previous = self.previous.clone(); - self.reset(); - self.current = previous; - } - - pub fn move_to_next(&mut self) { - let next = self.next.clone(); - self.reset(); - self.current = next; - } - - pub fn reset(&mut self) { - self.current = None; - self.next = None; - self.previous = None; - } -} - -pub fn to_compactness(f: &Frame, settings: &Settings) -> Compactness { - if match settings.style { - atuin_client::settings::Style::Auto => f.area().height < 14, - atuin_client::settings::Style::Compact => true, - atuin_client::settings::Style::Full => false, - } { - if settings.auto_hide_height != 0 && f.area().height <= settings.auto_hide_height { - Compactness::Ultracompact - } else { - Compactness::Compact - } - } else { - Compactness::Full - } -} - -#[expect(clippy::struct_field_names)] -#[expect(clippy::struct_excessive_bools)] -pub struct State { - history_count: i64, - results_state: ListState, - switched_search_mode: bool, - search_mode: SearchMode, - results_len: usize, - accept: bool, - keymap_mode: KeymapMode, - prefix: bool, - current_cursor: Option, - tab_index: usize, - pending_vim_key: Option, - original_input_empty: bool, - - pub inspecting_state: InspectingState, - - keymaps: KeymapSet, - search: SearchState, - engine: Box, - now: Box OffsetDateTime + Send>, -} - -#[derive(Clone, Copy)] -pub enum Compactness { - Ultracompact, - Compact, - Full, -} - -#[derive(Clone, Copy)] -struct StyleState { - compactness: Compactness, - invert: bool, - inner_width: usize, -} - -impl State { - async fn query_results( - &mut self, - db: &mut dyn Database, - smart_sort: bool, - ) -> Result> { - let results = self.engine.query(&self.search, db).await?; - - self.inspecting_state = InspectingState { - current: None, - next: None, - previous: None, - }; - self.results_state.select(0); - self.results_len = results.len(); - - if smart_sort { - Ok(atuin_history::sort::sort( - self.search.input.as_str(), - results, - )) - } else { - Ok(results) - } - } - - fn handle_input(&mut self, settings: &Settings, input: &Event) -> InputAction { - match input { - Event::Key(k) => self.handle_key_input(settings, k), - Event::Mouse(m) => self.handle_mouse_input(*m, settings.invert), - Event::Paste(d) => self.handle_paste_input(d), - _ => InputAction::Continue, - } - } - - fn handle_mouse_input(&mut self, input: MouseEvent, inverted: bool) -> InputAction { - match (input.kind, inverted) { - (event::MouseEventKind::ScrollDown, false) - | (event::MouseEventKind::ScrollUp, true) => { - self.scroll_down(1); - } - (event::MouseEventKind::ScrollDown, true) - | (event::MouseEventKind::ScrollUp, false) => { - self.scroll_up(1); - } - _ => {} - } - InputAction::Continue - } - - fn handle_paste_input(&mut self, input: &str) -> InputAction { - for i in input.chars() { - self.search.input.insert(i); - } - InputAction::Continue - } - - fn cast_cursor_style(style: CursorStyle) -> SetCursorStyle { - match style { - CursorStyle::DefaultUserShape => SetCursorStyle::DefaultUserShape, - CursorStyle::BlinkingBlock => SetCursorStyle::BlinkingBlock, - CursorStyle::SteadyBlock => SetCursorStyle::SteadyBlock, - CursorStyle::BlinkingUnderScore => SetCursorStyle::BlinkingUnderScore, - CursorStyle::SteadyUnderScore => SetCursorStyle::SteadyUnderScore, - CursorStyle::BlinkingBar => SetCursorStyle::BlinkingBar, - CursorStyle::SteadyBar => SetCursorStyle::SteadyBar, - } - } - - fn set_keymap_cursor(&mut self, settings: &Settings, keymap_name: &str) { - let cursor_style = if keymap_name == "__clear__" { - None - } else { - settings.keymap_cursor.get(keymap_name).copied() - } - .or_else(|| self.current_cursor.map(|_| CursorStyle::DefaultUserShape)); - - if cursor_style != self.current_cursor - && let Some(style) = cursor_style - { - self.current_cursor = cursor_style; - let _ = execute!(stdout(), Self::cast_cursor_style(style)); - } - } - - pub fn initialize_keymap_cursor(&mut self, settings: &Settings) { - match self.keymap_mode { - KeymapMode::Emacs => self.set_keymap_cursor(settings, "emacs"), - KeymapMode::VimNormal => self.set_keymap_cursor(settings, "vim_normal"), - KeymapMode::VimInsert => self.set_keymap_cursor(settings, "vim_insert"), - KeymapMode::Auto => {} - } - } - - pub fn finalize_keymap_cursor(&mut self, settings: &Settings) { - match settings.keymap_mode_shell { - KeymapMode::Emacs => self.set_keymap_cursor(settings, "emacs"), - KeymapMode::VimNormal => self.set_keymap_cursor(settings, "vim_normal"), - KeymapMode::VimInsert => self.set_keymap_cursor(settings, "vim_insert"), - KeymapMode::Auto => self.set_keymap_cursor(settings, "__clear__"), - } - } - - fn handle_key_exit(settings: &Settings) -> InputAction { - match settings.exit_mode { - ExitMode::ReturnOriginal => InputAction::ReturnOriginal, - ExitMode::ReturnQuery => InputAction::ReturnQuery, - } - } - - /// Select the keymap for the current mode (ignoring prefix). - fn mode_keymap(&self) -> &super::keybindings::Keymap { - if self.tab_index == 1 { - &self.keymaps.inspector - } else { - match self.keymap_mode { - KeymapMode::Emacs | KeymapMode::Auto => &self.keymaps.emacs, - KeymapMode::VimNormal => &self.keymaps.vim_normal, - KeymapMode::VimInsert => &self.keymaps.vim_insert, - } - } - } - - /// Whether the current mode supports character insertion on unmatched keys. - fn is_insert_mode(&self) -> bool { - matches!( - self.keymap_mode, - KeymapMode::Emacs | KeymapMode::Auto | KeymapMode::VimInsert - ) - } - - fn handle_key_input(&mut self, settings: &Settings, input: &KeyEvent) -> InputAction { - use super::keybindings::Action; - use super::keybindings::EvalContext; - use super::keybindings::key::{KeyCodeValue, KeyInput, SingleKey}; - - // Skip release events - if input.kind == event::KeyEventKind::Release { - return InputAction::Continue; - } - - // Reset switched_search_mode at start of each key event - self.switched_search_mode = false; - - // Build evaluation context from current state - let ctx = EvalContext { - cursor_position: self.search.input.position(), - input_width: UnicodeWidthStr::width(self.search.input.as_str()), - input_byte_len: self.search.input.as_str().len(), - selected_index: self.results_state.selected(), - results_len: self.results_len, - original_input_empty: self.original_input_empty, - has_context: self.search.custom_context.is_some(), - }; - - // Convert KeyEvent to SingleKey - let Some(single) = SingleKey::from_event(input) else { - return InputAction::Continue; - }; - - // --- Phase 1: Resolve (take pending key first, then immutable borrows) --- - - // Take pending key before any immutable borrows of self - let pending = self.pending_vim_key.take(); - - // If in prefix mode, try prefix keymap first (single keys only) - let prefix_action = if self.prefix { - let ki = KeyInput::Single(single.clone()); - self.keymaps.prefix.resolve(&ki, &ctx) - } else { - None - }; - - // The if-let/else-if chain here is clearer than map_or_else with nested closures. - #[expect(clippy::option_if_let_else)] - let (action, new_pending) = if prefix_action.is_some() { - (prefix_action, None) - } else { - // Use mode keymap (handles both single and multi-key sequences) - let keymap = self.mode_keymap(); - - if let Some(pending_char) = pending { - // We have a pending key from a previous press (e.g., first 'g' of 'gg') - let pending_single = SingleKey { - code: KeyCodeValue::Char(pending_char), - ctrl: false, - alt: false, - shift: false, - super_key: false, - }; - let seq = KeyInput::Sequence(vec![pending_single, single.clone()]); - let action = keymap - .resolve(&seq, &ctx) - .or_else(|| keymap.resolve(&KeyInput::Single(single.clone()), &ctx)); - (action, None) - } else if keymap.has_sequence_starting_with(&single) - && matches!(single.code, KeyCodeValue::Char(_)) - && !single.ctrl - && !single.alt - { - // This key starts a multi-key sequence; wait for next key - let KeyCodeValue::Char(c) = single.code else { - unreachable!() - }; - (Some(Action::Noop), Some(c)) - } else { - ( - keymap.resolve(&KeyInput::Single(single.clone()), &ctx), - None, - ) - } - }; - - // --- Phase 2: Apply mutations --- - self.pending_vim_key = new_pending; - - // Reset prefix (before execute, so EnterPrefixMode can re-set it) - self.prefix = false; - - if let Some(action) = action { - self.execute_action(&action, settings) - } else { - // No action matched. In insert-capable modes, insert the character. - if self.is_insert_mode() && !single.ctrl && !single.alt { - match single.code { - KeyCodeValue::Char(c) => { - self.search.input.insert(c); - } - KeyCodeValue::Space => { - self.search.input.insert(' '); - } - _ => {} - } - } - InputAction::Continue - } - } - - fn scroll_down(&mut self, scroll_len: usize) { - let i = self.results_state.selected().saturating_sub(scroll_len); - self.inspecting_state.reset(); - self.results_state.select(i); - } - - fn scroll_up(&mut self, scroll_len: usize) { - let i = self.results_state.selected() + scroll_len; - self.results_state - .select(i.min(self.results_len.saturating_sub(1))); - self.inspecting_state.reset(); - } - - /// Execute a resolved action, performing all side effects and returning the - /// appropriate `InputAction` for the event loop. - /// - /// This is the "do it" half of the resolve+execute pipeline. The resolver - /// decides *what* to do (which `Action`), and this function carries it out. - /// - /// Invert handling: scroll actions (`SelectNext`, `ScrollPageDown`, etc.) account - /// for `settings.invert` so that keybindings are always in "visual" terms — - /// users never need to think about invert in their keybinding config. - #[expect(clippy::too_many_lines)] - pub(crate) fn execute_action( - &mut self, - action: &super::keybindings::Action, - settings: &Settings, - ) -> InputAction { - use crate::command::client::search::keybindings::Action; - - match action { - // -- Cursor movement -- - Action::CursorLeft => { - self.search.input.left(); - InputAction::Continue - } - Action::CursorRight => { - self.search.input.right(); - InputAction::Continue - } - Action::CursorWordLeft => { - self.search - .input - .prev_word(&settings.word_chars, settings.word_jump_mode); - InputAction::Continue - } - Action::CursorWordRight => { - self.search - .input - .next_word(&settings.word_chars, settings.word_jump_mode); - InputAction::Continue - } - Action::CursorWordEnd => { - self.search.input.word_end(&settings.word_chars); - InputAction::Continue - } - Action::CursorStart => { - self.search.input.start(); - InputAction::Continue - } - Action::CursorEnd => { - self.search.input.end(); - InputAction::Continue - } - - // -- Editing -- - Action::DeleteCharBefore => { - self.search.input.back(); - InputAction::Continue - } - Action::DeleteCharAfter => { - self.search.input.remove(); - InputAction::Continue - } - Action::DeleteWordBefore => { - self.search - .input - .remove_prev_word(&settings.word_chars, settings.word_jump_mode); - InputAction::Continue - } - Action::DeleteWordAfter => { - self.search - .input - .remove_next_word(&settings.word_chars, settings.word_jump_mode); - InputAction::Continue - } - Action::DeleteToWordBoundary => { - // ctrl-w: remove trailing whitespace, then delete to word boundary - while matches!(self.search.input.back(), Some(c) if c.is_whitespace()) {} - while self.search.input.left() { - if self.search.input.char().unwrap().is_whitespace() { - self.search.input.right(); - break; - } - self.search.input.remove(); - } - InputAction::Continue - } - Action::ClearLine => { - self.search.input.clear(); - InputAction::Continue - } - Action::ClearToStart => { - self.search.input.clear_to_start(); - InputAction::Continue - } - Action::ClearToEnd => { - self.search.input.clear_to_end(); - InputAction::Continue - } - - // -- List navigation (invert-aware) -- - Action::SelectNext => { - if settings.invert { - self.scroll_up(1); - } else { - self.scroll_down(1); - } - InputAction::Continue - } - Action::SelectPrevious => { - if settings.invert { - self.scroll_down(1); - } else { - self.scroll_up(1); - } - InputAction::Continue - } - // -- Page/half-page scroll (invert-aware) -- - Action::ScrollHalfPageUp => { - let scroll_len = self - .results_state - .max_entries() - .saturating_sub(settings.scroll_context_lines) - / 2; - if settings.invert { - self.scroll_down(scroll_len); - } else { - self.scroll_up(scroll_len); - } - InputAction::Continue - } - Action::ScrollHalfPageDown => { - let scroll_len = self - .results_state - .max_entries() - .saturating_sub(settings.scroll_context_lines) - / 2; - if settings.invert { - self.scroll_up(scroll_len); - } else { - self.scroll_down(scroll_len); - } - InputAction::Continue - } - Action::ScrollPageUp => { - let scroll_len = self - .results_state - .max_entries() - .saturating_sub(settings.scroll_context_lines); - if settings.invert { - self.scroll_down(scroll_len); - } else { - self.scroll_up(scroll_len); - } - InputAction::Continue - } - Action::ScrollPageDown => { - let scroll_len = self - .results_state - .max_entries() - .saturating_sub(settings.scroll_context_lines); - if settings.invert { - self.scroll_up(scroll_len); - } else { - self.scroll_down(scroll_len); - } - InputAction::Continue - } - - // -- Absolute jumps (invert-aware) -- - Action::ScrollToTop => { - // Visual top of history - if settings.invert { - self.results_state.select(0); - } else { - let last_idx = self.results_len.saturating_sub(1); - self.results_state.select(last_idx); - } - self.inspecting_state.reset(); - InputAction::Continue - } - Action::ScrollToBottom => { - // Visual bottom of history - if settings.invert { - let last_idx = self.results_len.saturating_sub(1); - self.results_state.select(last_idx); - } else { - self.results_state.select(0); - } - self.inspecting_state.reset(); - InputAction::Continue - } - Action::ScrollToScreenTop => { - // H — jump to top of visible screen - let top = self.results_state.offset(); - let visible = self.results_state.max_entries().min(self.results_len); - let bottom = top + visible.saturating_sub(1); - self.results_state - .select(bottom.min(self.results_len.saturating_sub(1))); - self.inspecting_state.reset(); - InputAction::Continue - } - Action::ScrollToScreenMiddle => { - // M — jump to middle of visible screen - let top = self.results_state.offset(); - let visible = self.results_state.max_entries().min(self.results_len); - let middle = top + visible / 2; - self.results_state - .select(middle.min(self.results_len.saturating_sub(1))); - self.inspecting_state.reset(); - InputAction::Continue - } - Action::ScrollToScreenBottom => { - // L — jump to bottom of visible screen - let top_visible = self.results_state.offset(); - self.results_state.select(top_visible); - self.inspecting_state.reset(); - InputAction::Continue - } - - // -- Commands -- - Action::Accept => { - if self.tab_index == 1 { - return InputAction::AcceptInspecting; - } - self.accept = true; - InputAction::Accept(self.results_state.selected()) - } - Action::AcceptNth(n) => { - self.accept = true; - InputAction::Accept(self.results_state.selected() + *n as usize) - } - Action::ReturnSelection => { - if self.tab_index == 1 { - return InputAction::AcceptInspecting; - } - InputAction::Accept(self.results_state.selected()) - } - Action::ReturnSelectionNth(n) => { - InputAction::Accept(self.results_state.selected() + *n as usize) - } - Action::Copy => InputAction::Copy(self.results_state.selected()), - Action::Delete => InputAction::Delete(self.results_state.selected()), - Action::DeleteAll => InputAction::DeleteAllMatching(self.results_state.selected()), - Action::ReturnOriginal => InputAction::ReturnOriginal, - Action::ReturnQuery => InputAction::ReturnQuery, - Action::Exit => Self::handle_key_exit(settings), - Action::Redraw => InputAction::Redraw, - Action::CycleFilterMode => { - self.search.rotate_filter_mode(settings, 1); - InputAction::Continue - } - Action::CycleSearchMode => { - self.switched_search_mode = true; - self.search_mode = self.search_mode.next(settings); - self.engine = engines::engine(self.search_mode, settings); - InputAction::Continue - } - Action::SwitchContext => { - InputAction::SwitchContext(Some(self.results_state.selected())) - } - Action::ClearContext => InputAction::SwitchContext(None), - Action::ToggleTab => { - self.tab_index = (self.tab_index + 1) % TAB_TITLES.len(); - InputAction::Continue - } - - // -- Mode changes -- - Action::VimEnterNormal => { - self.set_keymap_cursor(settings, "vim_normal"); - self.keymap_mode = KeymapMode::VimNormal; - InputAction::Continue - } - Action::VimEnterInsert => { - self.set_keymap_cursor(settings, "vim_insert"); - self.keymap_mode = KeymapMode::VimInsert; - InputAction::Continue - } - Action::VimEnterInsertAfter => { - self.search.input.right(); - self.set_keymap_cursor(settings, "vim_insert"); - self.keymap_mode = KeymapMode::VimInsert; - InputAction::Continue - } - Action::VimEnterInsertAtStart => { - self.search.input.start(); - self.set_keymap_cursor(settings, "vim_insert"); - self.keymap_mode = KeymapMode::VimInsert; - InputAction::Continue - } - Action::VimEnterInsertAtEnd => { - self.search.input.end(); - self.set_keymap_cursor(settings, "vim_insert"); - self.keymap_mode = KeymapMode::VimInsert; - InputAction::Continue - } - Action::VimSearchInsert => { - self.search.input.clear(); - self.set_keymap_cursor(settings, "vim_insert"); - self.keymap_mode = KeymapMode::VimInsert; - InputAction::Continue - } - Action::VimChangeToEnd => { - self.search.input.clear_to_end(); - self.set_keymap_cursor(settings, "vim_insert"); - self.keymap_mode = KeymapMode::VimInsert; - InputAction::Continue - } - Action::EnterPrefixMode => { - self.prefix = true; - InputAction::Continue - } - - // -- Inspector -- - Action::InspectPrevious => { - self.inspecting_state.move_to_previous(); - InputAction::Redraw - } - Action::InspectNext => { - self.inspecting_state.move_to_next(); - InputAction::Redraw - } - - // -- Special -- - Action::Noop => InputAction::Continue, - } - } - - #[expect(clippy::cast_possible_truncation)] - #[expect(clippy::bool_to_int_with_if)] - fn calc_preview_height( - settings: &Settings, - results: &[History], - selected: usize, - tab_index: usize, - compactness: Compactness, - border_size: u16, - preview_width: u16, - ) -> u16 { - if settings.show_preview - && settings.preview.strategy == PreviewStrategy::Auto - && tab_index == 0 - && !results.is_empty() - { - let length_current_cmd = results[selected].command.len() as u16; - // calculate the number of newlines in the command - let num_newlines = results[selected] - .command - .chars() - .filter(|&c| c == '\n') - .count() as u16; - if num_newlines > 0 { - std::cmp::min( - settings.max_preview_height, - results[selected] - .command - .split('\n') - .map(|line| { - (line.len() as u16 + preview_width - 1 - border_size) - / (preview_width - border_size) - }) - .sum(), - ) + border_size * 2 - } - // The '- 19' takes the characters before the command (duration and time) into account - else if length_current_cmd > preview_width - 19 { - std::cmp::min( - settings.max_preview_height, - (length_current_cmd + preview_width - 1 - border_size) - / (preview_width - border_size), - ) + border_size * 2 - } else { - 1 - } - } else if settings.show_preview - && settings.preview.strategy == PreviewStrategy::Static - && tab_index == 0 - { - let longest_command = results - .iter() - .max_by(|h1, h2| h1.command.len().cmp(&h2.command.len())); - longest_command.map_or(0, |v| { - std::cmp::min( - settings.max_preview_height, - v.command - .split('\n') - .map(|line| { - (line.len() as u16 + preview_width - 1 - border_size) - / (preview_width - border_size) - }) - .sum(), - ) - }) + border_size * 2 - } else if settings.show_preview && settings.preview.strategy == PreviewStrategy::Fixed { - settings.max_preview_height + border_size * 2 - } else if !matches!(compactness, Compactness::Full) || tab_index == 1 { - 0 - } else { - 1 - } - } - - #[expect(clippy::bool_to_int_with_if)] - #[expect(clippy::too_many_lines)] - #[expect(clippy::too_many_arguments)] - fn draw( - &mut self, - f: &mut Frame, - results: &[History], - stats: Option, - inspecting: Option<&History>, - settings: &Settings, - theme: &Theme, - popup_mode: bool, - ) { - let area = f.area(); - if popup_mode { - f.render_widget(Clear, area); - } - self.draw_inner(f, area, results, stats, inspecting, settings, theme); - } - - #[expect(clippy::too_many_arguments)] - #[expect(clippy::too_many_lines)] - #[expect(clippy::bool_to_int_with_if)] - fn draw_inner( - &mut self, - f: &mut Frame, - area: Rect, - results: &[History], - stats: Option, - inspecting: Option<&History>, - settings: &Settings, - theme: &Theme, - ) { - let compactness = to_compactness(f, settings); - let invert = settings.invert; - let border_size = match compactness { - Compactness::Full => 1, - _ => 0, - }; - let preview_width = area.width.saturating_sub(2); - let preview_height = Self::calc_preview_height( - settings, - results, - self.results_state.selected(), - self.tab_index, - compactness, - border_size, - preview_width, - ); - let show_help = - settings.show_help && (matches!(compactness, Compactness::Full) || area.height > 1); - // This is an OR, as it seems more likely for someone to wish to override - // tabs unexpectedly being missed, than unexpectedly present. - let show_tabs = settings.show_tabs && !matches!(compactness, Compactness::Ultracompact); - let chunks = Layout::default() - .direction(Direction::Vertical) - .margin(0) - .horizontal_margin(1) - .constraints::<&[Constraint]>( - if invert { - [ - Constraint::Length(1 + border_size), // input - Constraint::Min(1), // results list - Constraint::Length(preview_height), // preview - Constraint::Length(if show_tabs { 1 } else { 0 }), // tabs - Constraint::Length(if show_help { 1 } else { 0 }), // header (sic) - ] - } else { - match compactness { - Compactness::Ultracompact => [ - Constraint::Length(if show_help { 1 } else { 0 }), // header - Constraint::Length(0), // tabs - Constraint::Min(1), // results list - Constraint::Length(0), - Constraint::Length(0), - ], - _ => [ - Constraint::Length(if show_help { 1 } else { 0 }), // header - Constraint::Length(if show_tabs { 1 } else { 0 }), // tabs - Constraint::Min(1), // results list - Constraint::Length(1 + border_size), // input - Constraint::Length(preview_height), // preview - ], - } - } - .as_ref(), - ) - .split(area); - - let input_chunk = if invert { chunks[0] } else { chunks[3] }; - let results_list_chunk = if invert { chunks[1] } else { chunks[2] }; - let preview_chunk = if invert { chunks[2] } else { chunks[4] }; - let tabs_chunk = if invert { chunks[3] } else { chunks[1] }; - let header_chunk = if invert { chunks[4] } else { chunks[0] }; - - // TODO: this should be split so that we have one interactive search container that is - // EITHER a search box or an inspector. But I'm not doing that now, way too much atm. - // also allocate less 🙈 - let titles: Vec<_> = TAB_TITLES.iter().copied().map(Line::from).collect(); - - if show_tabs { - let tabs = Tabs::new(titles) - .block(Block::default().borders(Borders::NONE)) - .select(self.tab_index) - .style(Style::default()) - .highlight_style(Style::from_crossterm(theme.as_style(Meaning::Important))); - - f.render_widget(tabs, tabs_chunk); - } - - let style = StyleState { - compactness, - invert, - inner_width: input_chunk.width.into(), - }; - - let header_chunks = Layout::default() - .direction(Direction::Horizontal) - .constraints::<&[Constraint]>( - [ - Constraint::Ratio(1, 5), - Constraint::Ratio(3, 5), - Constraint::Ratio(1, 5), - ] - .as_ref(), - ) - .split(header_chunk); - - let title = Self::build_title(theme); - f.render_widget(title, header_chunks[0]); - - let help = self.build_help(settings, theme); - f.render_widget(help, header_chunks[1]); - - let stats_tab = self.build_stats(theme); - f.render_widget(stats_tab, header_chunks[2]); - - let indicator: String = match compactness { - Compactness::Ultracompact => { - if self.switched_search_mode { - format!("S{}>", self.search_mode.as_str().chars().next().unwrap()) - } else if self.search.custom_context.is_some() { - format!( - "C{}>", - self.search.filter_mode.as_str().chars().next().unwrap() - ) - } else { - format!( - "{}> ", - self.search.filter_mode.as_str().chars().next().unwrap() - ) - } - } - _ => " > ".to_string(), - }; - - match self.tab_index { - 0 => { - let history_highlighter = HistoryHighlighter { - engine: self.engine.as_ref(), - search_input: self.search.input.as_str(), - }; - let results_list = Self::build_results_list( - style, - results, - self.keymap_mode, - &self.now, - indicator.as_str(), - theme, - history_highlighter, - settings.show_numeric_shortcuts, - &settings.ui.columns, - ); - f.render_stateful_widget(results_list, results_list_chunk, &mut self.results_state); - } - - 1 => { - if results.is_empty() { - let message = Paragraph::new("Nothing to inspect") - .block( - Block::new() - .title(Line::from(" Info ".to_string())) - .title_alignment(Alignment::Center) - .borders(Borders::ALL) - .padding(Padding::vertical(2)), - ) - .alignment(Alignment::Center); - f.render_widget(message, results_list_chunk); - } else { - let inspecting = match inspecting { - Some(inspecting) => inspecting, - None => &results[self.results_state.selected()], - }; - super::inspector::draw( - f, - results_list_chunk, - inspecting, - &stats.expect("Drawing inspector, but no stats"), - settings, - theme, - settings.timezone, - ); - } - - // HACK: I'm following up with abstracting this into the UI container, with a - // sub-widget for search + for inspector - let feedback = Paragraph::new( - "The inspector is new - please give feedback (good, or bad) at https://forum.atuin.sh", - ); - f.render_widget(feedback, input_chunk); - - return; - } - - _ => { - panic!("invalid tab index"); - } - } - - if !matches!(compactness, Compactness::Ultracompact) { - let preview_width = match compactness { - Compactness::Full => preview_width - 2, - _ => preview_width, - }; - let preview = self.build_preview( - results, - compactness, - preview_width, - preview_chunk.width.into(), - theme, - ); - #[expect(clippy::cast_possible_truncation)] - let prefix_width = settings - .ui - .columns - .iter() - .take_while(|col| !col.expand) - .map(|col| col.width + 1) - .sum::() - + " > ".len() as u16; - #[expect(clippy::cast_possible_truncation)] - let min_prefix_width = "[ SRCH: FULLTXT ] ".len() as u16; - self.draw_preview( - f, - style, - input_chunk, - compactness, - preview_chunk, - preview, - std::cmp::max(prefix_width, min_prefix_width), - ); - } - } - - #[expect(clippy::cast_possible_truncation, clippy::too_many_arguments)] - fn draw_preview( - &self, - f: &mut Frame, - style: StyleState, - input_chunk: Rect, - compactness: Compactness, - preview_chunk: Rect, - preview: Paragraph, - prefix_width: u16, - ) { - let input = self.build_input(style, prefix_width); - f.render_widget(input, input_chunk); - - f.render_widget(preview, preview_chunk); - - let extra_width = UnicodeWidthStr::width(self.search.input.substring()); - - let cursor_offset = match compactness { - Compactness::Full => 1, - _ => 0, - }; - f.set_cursor_position(( - // Put cursor past the end of the input text - input_chunk.x + extra_width as u16 + prefix_width + cursor_offset, - input_chunk.y + cursor_offset, - )); - } - - fn build_title(theme: &Theme) -> Paragraph<'_> { - let title = { - let style: Style = Style::from_crossterm(theme.as_style(Meaning::Base)); - Paragraph::new(Text::from(Span::styled( - format!("Atuin v{VERSION}"), - style.add_modifier(Modifier::BOLD), - ))) - }; - title.alignment(Alignment::Left) - } - - #[expect(clippy::unused_self)] - fn build_help(&self, settings: &Settings, theme: &Theme) -> Paragraph<'_> { - match self.tab_index { - // search - 0 => Paragraph::new(Text::from(Line::from(vec![ - Span::styled("", Style::default().add_modifier(Modifier::BOLD)), - Span::raw(": exit"), - Span::raw(", "), - Span::styled("", Style::default().add_modifier(Modifier::BOLD)), - Span::raw(": edit"), - Span::raw(", "), - Span::styled("", Style::default().add_modifier(Modifier::BOLD)), - Span::raw(if settings.enter_accept { - ": run" - } else { - ": edit" - }), - Span::raw(", "), - Span::styled("", Style::default().add_modifier(Modifier::BOLD)), - Span::raw(": inspect"), - ]))), - - 1 => Paragraph::new(Text::from(Line::from(vec![ - Span::styled("", Style::default().add_modifier(Modifier::BOLD)), - Span::raw(": exit"), - Span::raw(", "), - Span::styled("", Style::default().add_modifier(Modifier::BOLD)), - Span::raw(": search"), - Span::raw(", "), - Span::styled("", Style::default().add_modifier(Modifier::BOLD)), - Span::raw(": delete"), - ]))), - - _ => unreachable!("invalid tab index"), - } - .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))) - .alignment(Alignment::Center) - } - - fn build_stats(&self, theme: &Theme) -> Paragraph<'_> { - Paragraph::new(Text::from(Span::raw(format!( - "history count: {}", - self.history_count, - )))) - .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))) - .alignment(Alignment::Right) - } - - #[expect(clippy::too_many_arguments)] - fn build_results_list<'a>( - style: StyleState, - results: &'a [History], - keymap_mode: KeymapMode, - now: &'a dyn Fn() -> OffsetDateTime, - indicator: &'a str, - theme: &'a Theme, - history_highlighter: HistoryHighlighter<'a>, - show_numeric_shortcuts: bool, - columns: &'a [UiColumn], - ) -> HistoryList<'a> { - let results_list = HistoryList::new( - results, - style.invert, - keymap_mode == KeymapMode::VimNormal, - now, - indicator, - theme, - history_highlighter, - show_numeric_shortcuts, - columns, - ); - - match style.compactness { - Compactness::Full => { - if style.invert { - results_list.block( - Block::default() - .borders(Borders::LEFT | Borders::RIGHT) - .border_type(BorderType::Rounded) - .title(format!("{:─>width$}", "", width = style.inner_width - 2)), - ) - } else { - results_list.block( - Block::default() - .borders(Borders::TOP | Borders::LEFT | Borders::RIGHT) - .border_type(BorderType::Rounded), - ) - } - } - _ => results_list, - } - } - - fn build_input(&self, style: StyleState, prefix_width: u16) -> Paragraph<'_> { - let (pref, mode) = if self.switched_search_mode { - (" SRCH:", self.search_mode.as_str()) - } else if self.search.custom_context.is_some() { - (" CTX:", self.search.filter_mode.as_str()) - } else { - ("", self.search.filter_mode.as_str()) - }; - // 3: surrounding "[" "] " - let mode_width = usize::from(prefix_width) - pref.len() - 3; - // sanity check to ensure we don't exceed the layout limits - debug_assert!(mode_width >= mode.len(), "mode name '{mode}' is too long!"); - let input = format!("[{pref}{mode:^mode_width$}] {}", self.search.input.as_str()); - let input = Paragraph::new(input); - match style.compactness { - Compactness::Full => { - if style.invert { - input.block( - Block::default() - .borders(Borders::LEFT | Borders::RIGHT | Borders::TOP) - .border_type(BorderType::Rounded), - ) - } else { - input.block( - Block::default() - .borders(Borders::LEFT | Borders::RIGHT) - .border_type(BorderType::Rounded) - .title(format!("{:─>width$}", "", width = style.inner_width - 2)), - ) - } - } - _ => input, - } - } - - fn build_preview( - &self, - results: &[History], - compactness: Compactness, - preview_width: u16, - chunk_width: usize, - theme: &Theme, - ) -> Paragraph<'_> { - let selected = self.results_state.selected(); - let command = if results.is_empty() { - String::new() - } else { - let s = &results[selected].command; - let mut lines = Vec::new(); - for line in s.split('\n') { - let line = line.escape_control(); - let mut width = 0; - let mut start = 0; - for (idx, ch) in line.char_indices() { - let w = ch.width().unwrap_or(0); // None for control chars which should not happen - if width + w > preview_width.into() { - lines.push(line[start..idx].to_owned()); - start = idx; - width = w; - } else { - width += w; - } - } - if width != 0 { - lines.push(line[start..].to_owned()); - } - } - lines.join("\n") - }; - - match compactness { - Compactness::Full => Paragraph::new(command).block( - Block::default() - .borders(Borders::BOTTOM | Borders::LEFT | Borders::RIGHT) - .border_type(BorderType::Rounded) - .title(format!("{:─>width$}", "", width = chunk_width - 2)), - ), - _ => Paragraph::new(command) - .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))), - } - } -} - -/// The writer used for terminal output - either stdout or /dev/tty -enum TerminalWriter { - Stdout(std::io::Stdout), - #[cfg(unix)] - Tty(std::fs::File), - #[cfg(windows)] - ConOut(std::io::LineWriter, u32), -} - -impl TerminalWriter { - #[cfg(windows)] - const CP_UTF8: u32 = 65001; - - fn new() -> std::io::Result { - let stdout = stdout(); - if stdout.is_terminal() { - return Ok(TerminalWriter::Stdout(stdout)); - } - - // If stdout is not a terminal (e.g., captured by command substitution), - // fall back to /dev/tty so the TUI can still render. - // This allows usage like: VAR=$(atuin search -i) - #[cfg(unix)] - { - Ok(TerminalWriter::Tty( - std::fs::File::options() - .read(true) - .write(true) - .open("/dev/tty")?, - )) - } - - // On Windows, use CONOUT$ which is the equivalent of /dev/tty, but this - // requires setting the current console output code page to UTF-8 for the - // TUI to render properly. We'll set it back to its previous value upon exit. - #[cfg(windows)] - { - let file = std::fs::File::options() - .read(true) - .write(true) - .open("CONOUT$")?; - - let initial_console_output_cp = unsafe { GetConsoleOutputCP() }; - if initial_console_output_cp != Self::CP_UTF8 { - unsafe { - SetConsoleOutputCP(Self::CP_UTF8); - } - } - - Ok(TerminalWriter::ConOut( - std::io::LineWriter::new(file), - initial_console_output_cp, - )) - } - - #[cfg(not(any(unix, windows)))] - Err(std::io::Error::new( - std::io::ErrorKind::Unsupported, - "Interactive mode requires a terminal", - )) - } -} - -impl Write for TerminalWriter { - fn write(&mut self, buf: &[u8]) -> std::io::Result { - match self { - TerminalWriter::Stdout(stdout) => stdout.write(buf), - #[cfg(unix)] - TerminalWriter::Tty(file) => file.write(buf), - #[cfg(windows)] - TerminalWriter::ConOut(writer, _) => writer.write(buf), - } - } - - fn flush(&mut self) -> std::io::Result<()> { - match self { - TerminalWriter::Stdout(stdout) => stdout.flush(), - #[cfg(unix)] - TerminalWriter::Tty(file) => file.flush(), - #[cfg(windows)] - TerminalWriter::ConOut(writer, _) => writer.flush(), - } - } -} - -impl Drop for TerminalWriter { - fn drop(&mut self) { - #[cfg(windows)] - if let TerminalWriter::ConOut(_, initial_console_output_cp) = self - && *initial_console_output_cp != Self::CP_UTF8 - { - unsafe { - SetConsoleOutputCP(*initial_console_output_cp); - } - } - } -} - -/// Screen state captured from atuin pty-proxy's screen server. -#[cfg(unix)] -struct SavedScreen { - #[expect(dead_code)] - rows: u16, - #[expect(dead_code)] - cols: u16, - cursor_row: u16, - cursor_col: u16, - /// Pre-formatted ANSI bytes for each screen row, ready to write to stdout. - rows_data: Vec>, -} - -/// Connect to atuin pty-proxy's Unix socket and fetch the current screen state. -/// -/// The wire format is: -/// ```text -/// [rows: u16 BE][cols: u16 BE][cursor_row: u16 BE][cursor_col: u16 BE] -/// [row_0_len: u32 BE][row_0_bytes...] -/// [row_1_len: u32 BE][row_1_bytes...] -/// ... -/// ``` -#[cfg(unix)] -fn fetch_screen_state(socket_path: &str) -> Option { - use std::os::unix::net::UnixStream; - - let mut stream = UnixStream::connect(socket_path).ok()?; - stream.set_read_timeout(Some(Duration::from_secs(2))).ok()?; - - let mut data = Vec::new(); - stream.read_to_end(&mut data).ok()?; - - if data.len() < 8 { - return None; - } - - let rows = u16::from_be_bytes([data[0], data[1]]); - let cols = u16::from_be_bytes([data[2], data[3]]); - let cursor_row = u16::from_be_bytes([data[4], data[5]]); - let cursor_col = u16::from_be_bytes([data[6], data[7]]); - - // Parse length-prefixed rows - let mut rows_data = Vec::with_capacity(rows as usize); - let mut offset = 8; - while offset + 4 <= data.len() { - let row_len = u32::from_be_bytes([ - data[offset], - data[offset + 1], - data[offset + 2], - data[offset + 3], - ]) as usize; - offset += 4; - if offset + row_len > data.len() { - break; - } - rows_data.push(data[offset..offset + row_len].to_vec()); - offset += row_len; - } - - Some(SavedScreen { - rows, - cols, - cursor_row, - cursor_col, - rows_data, - }) -} - -/// Restore the screen area that was covered by the popup. -/// -/// Writes the pre-formatted per-row ANSI bytes received from atuin pty-proxy -/// directly to stdout, which correctly handles wide characters, colors, and -/// all text attributes without needing a client-side vt100 parser. -#[cfg(unix)] -fn restore_popup_area(saved: &SavedScreen, popup_rect: Rect, scroll_offset: u16) { - use ratatui::crossterm::cursor::MoveTo; - - let mut stdout = stdout(); - - for dy in 0..popup_rect.height { - let target_row = popup_rect.y + dy; - let source_row = (target_row + scroll_offset) as usize; - - // Clear only the popup region. The server-side rows_formatted() skips - // default cells (spaces with default attributes) using cursor jumps, so - // any popup content at those positions would remain if not cleared - // beforehand. We write `popup_rect.width` spaces instead of - // ClearType::CurrentLine so that only the popup area is cleared, not - // the entire terminal line. - let _ = execute!( - stdout, - MoveTo(popup_rect.x, target_row), - ratatui::crossterm::style::SetAttribute(ratatui::crossterm::style::Attribute::Reset), - ); - let _ = write!(stdout, "{:width$}", "", width = popup_rect.width as usize); - let _ = execute!(stdout, MoveTo(popup_rect.x, target_row)); - - if let Some(row_bytes) = saved.rows_data.get(source_row) { - let _ = stdout.write_all(row_bytes); - } - } - - let _ = execute!( - stdout, - MoveTo( - saved.cursor_col, - saved.cursor_row.saturating_sub(scroll_offset) - ) - ); - let _ = stdout.flush(); -} - -struct Stdout { - writer: TerminalWriter, - inline_mode: bool, - no_mouse: bool, -} - -impl Stdout { - pub fn new(inline_mode: bool, no_mouse: bool) -> std::io::Result { - terminal::enable_raw_mode()?; - - let mut writer = TerminalWriter::new()?; - - if !inline_mode { - execute!(writer, terminal::EnterAlternateScreen)?; - } - - if !no_mouse { - execute!(writer, event::EnableMouseCapture)?; - } - - execute!(writer, event::EnableBracketedPaste)?; - - #[cfg(not(target_os = "windows"))] - execute!( - writer, - PushKeyboardEnhancementFlags( - KeyboardEnhancementFlags::DISAMBIGUATE_ESCAPE_CODES - | KeyboardEnhancementFlags::REPORT_ALL_KEYS_AS_ESCAPE_CODES - | KeyboardEnhancementFlags::REPORT_ALTERNATE_KEYS - ), - )?; - - Ok(Self { - writer, - inline_mode, - no_mouse, - }) - } -} - -impl Drop for Stdout { - fn drop(&mut self) { - #[cfg(not(target_os = "windows"))] - if let Err(e) = execute!(self.writer, PopKeyboardEnhancementFlags) { - tracing::error!(?e, "Failed to pop keyboard enhancement flags"); - } - - if !self.inline_mode - && let Err(e) = execute!(self.writer, terminal::LeaveAlternateScreen) - { - tracing::error!(?e, "Failed to leave alt screen mode"); - } - - if !self.no_mouse - && let Err(e) = execute!(self.writer, event::DisableMouseCapture) - { - tracing::error!(?e, "Failed to disable mouse capture"); - } - - if let Err(e) = execute!(self.writer, event::DisableBracketedPaste) { - tracing::error!(?e, "Failed to disable bracketed paste"); - } - - if let Err(e) = terminal::disable_raw_mode() { - tracing::error!(?e, "Failed to disable raw mode"); - } - } -} - -impl Write for Stdout { - fn write(&mut self, buf: &[u8]) -> std::io::Result { - self.writer.write(buf) - } - - fn flush(&mut self) -> std::io::Result<()> { - self.writer.flush() - } -} - -// this is a big blob of horrible! clean it up! -/// Compute the popup position and any scroll offset needed to make room. -/// -/// Given the cursor row, terminal dimensions, and desired popup height, -/// returns `(popup_rect, scroll_offset)` where `scroll_offset` is the number -/// of lines the caller should scroll the terminal up before rendering. -/// -/// This function performs no I/O — it is a pure computation. -#[cfg(unix)] -fn compute_popup_placement( - cursor_row: u16, - term_rows: u16, - term_cols: u16, - inline_height: u16, -) -> (Rect, u16) { - let popup_w = term_cols; - let popup_h = inline_height.min(term_rows); - let space_below = term_rows.saturating_sub(cursor_row); - - let (popup_y, scroll) = if popup_h <= space_below { - // Fits below cursor - (cursor_row, 0u16) - } else if cursor_row >= term_rows / 2 { - // Bottom half — render above cursor (overlay on existing text) - (cursor_row.saturating_sub(popup_h), 0u16) - } else { - // Top half, not enough space — scroll terminal to make room - let scroll = popup_h.saturating_sub(space_below); - let popup_y = cursor_row.saturating_sub(scroll); - (popup_y, scroll) - }; - - (Rect::new(0, popup_y, popup_w, popup_h), scroll) -} - -// for now, it works. But it'd be great if it were more easily readable, and -// modular. I'd like to add some more stats and stuff at some point -#[expect( - clippy::cast_possible_truncation, - clippy::too_many_lines, - clippy::cognitive_complexity -)] -pub async fn history( - query: &[String], - settings: &Settings, - mut db: impl Database, - history_store: &HistoryStore, - theme: &Theme, -) -> Result { - let inline_height = if settings.shell_up_key_binding { - settings - .inline_height_shell_up_key_binding - .unwrap_or(settings.inline_height) - } else { - settings.inline_height - }; - - // Use fullscreen mode if the inline height doesn't fit in the terminal, - // this will preserve the scroll position upon exit. - // Also force fullscreen when stdout isn't a terminal (e.g., command substitution - // like VAR=$(atuin search -i)). In that case, we need to use /dev/tty for the TUI and force - // fullscreen mode (inline mode won't work as it requires cursor position queries - // that don't work when stdout is captured). - let inline_height = if !stdout().is_terminal() { - 0 - } else if let Ok(size) = terminal::size() - && inline_height >= size.1 - { - 0 - } else { - inline_height - }; - - // Popup mode: if running under atuin pty-proxy and inline mode is requested, - // fetch the screen state and render as a centered overlay. - #[cfg(unix)] - let (saved_screen, popup_rect, popup_scroll_offset) = { - let socket_path = std::env::var("ATUIN_PTY_PROXY_SOCKET") - .or_else(|_| std::env::var("ATUIN_HEX_SOCKET")) - .ok(); - if let Some(ref path) = socket_path - && inline_height > 0 - { - let saved = fetch_screen_state(path); - if let Some(ref s) = saved { - let (term_cols, term_rows) = terminal::size().unwrap_or((s.cols, s.rows)); - let (popup_rect, scroll) = - compute_popup_placement(s.cursor_row, term_rows, term_cols, inline_height); - - // Scroll terminal content up to make room if needed - if scroll > 0 { - use ratatui::crossterm::cursor::MoveTo; - let mut stdout = stdout(); - let _ = execute!(stdout, MoveTo(0, term_rows - 1)); - for _ in 0..scroll { - let _ = writeln!(stdout); - } - let _ = stdout.flush(); - } - - (saved, popup_rect, scroll) - } else { - (None, Rect::default(), 0u16) - } - } else { - (None, Rect::default(), 0u16) - } - }; - - #[cfg(not(unix))] - let (saved_screen, popup_rect, _popup_scroll_offset): (Option<()>, Rect, u16) = - (None, Rect::default(), 0); - - let popup_mode = saved_screen.is_some(); - - let stdout = Stdout::new(inline_height > 0, settings.no_mouse)?; - - // In popup mode, clear the popup region on the physical terminal before - // ratatui takes over. Ratatui's diff-based rendering compares against an - // initially-empty buffer, so cells that remain "empty" (spaces with default - // style) won't be written — leaving underlying terminal text visible. - // By pre-clearing with spaces, those cells are already correct on screen. - if popup_mode { - use ratatui::crossterm::cursor::MoveTo; - let mut raw_stdout = std::io::stdout(); - // Queue all commands without flushing so the terminal receives them - // as a single write — no intermediate cursor positions are visible. - let _ = queue!( - raw_stdout, - ratatui::crossterm::style::SetAttribute(ratatui::crossterm::style::Attribute::Reset) - ); - for row in popup_rect.y..popup_rect.y.saturating_add(popup_rect.height) { - let _ = queue!(raw_stdout, MoveTo(popup_rect.x, row)); - let _ = write!( - raw_stdout, - "{:width$}", - "", - width = popup_rect.width as usize - ); - } - let _ = raw_stdout.flush(); - } - - let backend = CrosstermBackend::new(stdout); - let mut terminal = Terminal::with_options( - backend, - TerminalOptions { - viewport: if popup_mode { - Viewport::Fixed(popup_rect) - } else if inline_height > 0 { - Viewport::Inline(inline_height) - } else { - Viewport::Fullscreen - }, - }, - )?; - - let original_query = query.join(" "); - - // Check if this is a command chaining scenario - let is_command_chaining = if settings.command_chaining { - let trimmed = original_query.trim_end(); - trimmed.ends_with("&&") || trimmed.ends_with('|') - } else { - false - }; - - // For command chaining, start with empty input to allow searching for new commands - let search_input = if is_command_chaining { - String::new() - } else { - original_query.clone() - }; - - let mut input = Cursor::from(search_input); - // Put the cursor at the end of the query by default - input.end(); - - let initial_context = current_context().await?; - - let history_count = db.history_count(false).await?; - let search_mode = if settings.shell_up_key_binding { - settings - .search_mode_shell_up_key_binding - .unwrap_or(settings.search_mode) - } else { - settings.search_mode - }; - let default_filter_mode = settings - .filter_mode_shell_up_key_binding - .filter(|_| settings.shell_up_key_binding) - .unwrap_or_else(|| settings.default_filter_mode(initial_context.git_root.is_some())); - let mut app = State { - history_count, - results_state: ListState::default(), - switched_search_mode: false, - search_mode, - tab_index: 0, - inspecting_state: InspectingState { - current: None, - next: None, - previous: None, - }, - keymaps: KeymapSet::from_settings(settings), - search: SearchState { - input, - filter_mode: default_filter_mode, - context: initial_context.clone(), - custom_context: None, - }, - engine: engines::engine(search_mode, settings), - results_len: 0, - accept: false, - keymap_mode: match settings.keymap_mode { - KeymapMode::Auto => KeymapMode::Emacs, - value => value, - }, - current_cursor: None, - now: if settings.prefers_reduced_motion { - let now = OffsetDateTime::now_utc(); - Box::new(move || now) - } else { - Box::new(OffsetDateTime::now_utc) - }, - prefix: false, - pending_vim_key: None, - original_input_empty: original_query.is_empty(), - }; - - app.initialize_keymap_cursor(settings); - - let mut results = app.query_results(&mut db, settings.smart_sort).await?; - - if inline_height > 0 && !popup_mode { - terminal.clear()?; - } - - let mut stats: Option = None; - let mut inspecting: Option = None; - let accept; - let result = 'render: loop { - terminal.draw(|f| { - app.draw( - f, - &results, - stats.clone(), - inspecting.as_ref(), - settings, - theme, - popup_mode, - ); - })?; - - let initial_input = app.search.input.as_str().to_owned(); - let initial_filter_mode = app.search.filter_mode; - let initial_search_mode = app.search_mode; - let initial_custom_context = app.search.custom_context.clone(); - - let event_ready = tokio::task::spawn_blocking(|| event::poll(Duration::from_millis(250))); - - tokio::select! { - event_ready = event_ready => { - if event_ready?? { - loop { - match app.handle_input(settings, &event::read()?) { - InputAction::Continue => {}, - InputAction::Delete(index) => { - if results.is_empty() { - break; - } - app.results_len -= 1; - let selected = app.results_state.selected(); - if selected == app.results_len { - app.inspecting_state.reset(); - app.results_state.select(selected - 1); - } - - let entry = results.remove(index); - - let ids = history_store.delete_entries([entry]).await?; - history_store.incremental_build(&db, &ids).await?; - - app.tab_index = 0; - }, - InputAction::DeleteAllMatching(index) => { - if results.is_empty() { - break; - } - - let command = results[index].command.clone(); - - // Remove matching entries from the visible results - results.retain(|e| e.command != command); - - // Query the DB for ALL entries with this command and delete them - let all_matching = db.query_history( - &format!( - "select * from history where command = '{}' and deleted_at is null", - command.replace('\'', "''") - ) - ).await?; - - let ids = history_store.delete_entries(all_matching).await?; - history_store.incremental_build(&db, &ids).await?; - - app.results_len = results.len(); - app.results_state = ListState::default(); - app.inspecting_state.reset(); - app.tab_index = 0; - }, - InputAction::SwitchContext(index) => { - if let Some(index) = index && let Some(entry) = results.get(index) { - app.search.custom_context = Some(entry.id.clone()); - app.search.context = Context::from_history(entry); - app.search.filter_mode = FilterMode::Session; - app.search.input = Cursor::from(String::new()); - app.results_state = ListState::default(); - } else { - app.search.custom_context = None; - app.search.context = initial_context.clone(); - app.search.filter_mode = default_filter_mode; - } - }, - InputAction::Redraw => { - if !popup_mode { - terminal.clear()?; - } - terminal.draw(|f| { - app.draw(f, &results, stats.clone(), inspecting.as_ref(), settings, theme, popup_mode); - })?; - }, - r => { - accept = app.accept; - break 'render r; - }, - } - if !event::poll(Duration::ZERO)? { - break; - } - } - } - } - } - - if initial_input != app.search.input.as_str() - || initial_filter_mode != app.search.filter_mode - || initial_search_mode != app.search_mode - || initial_custom_context != app.search.custom_context - { - results = app.query_results(&mut db, settings.smart_sort).await?; - } - - // In custom context mode, when no filter is applied, highlight the entry which was used - // to enter the context when changing modes. This helps to find your way around. - if app.search.custom_context.is_some() - && app.search.input.as_str().is_empty() - && (initial_custom_context != app.search.custom_context - || initial_filter_mode != app.search.filter_mode) - && let Some(history_id) = app.search.custom_context.clone() - && let Some(pos) = results.iter().position(|entry| entry.id == history_id) - { - app.results_state.select(pos); - } - - let inspecting_id = app.inspecting_state.clone().current; - // If inspecting ID is not the current inspecting History, update it. - match inspecting_id { - Some(inspecting_id) => { - if inspecting.is_none() || inspecting_id != inspecting.clone().unwrap().id { - inspecting = db.load(inspecting_id.0.as_str()).await?; - } - } - _ => { - inspecting = None; - } - } - - stats = if app.tab_index == 0 { - None - } else if !results.is_empty() { - // If we have stats, then we can indicate next available IDs. This avoids passing - // around a database object, or a full stats object. - let selected = match inspecting.clone() { - Some(insp) => insp, - None => results[app.results_state.selected()].clone(), - }; - let stats = db.stats(&selected).await?; - app.inspecting_state.current = Some(selected.id); - app.inspecting_state.previous = match stats.previous.clone() { - Some(p) => Some(p.id), - _ => None, - }; - app.inspecting_state.next = match stats.next.clone() { - Some(p) => Some(p.id), - _ => None, - }; - Some(stats) - } else { - None - }; - }; - - app.finalize_keymap_cursor(settings); - - if popup_mode { - // In popup mode, restore the screen area that was covered by the popup. - // This must happen before Stdout is dropped (which disables raw mode). - #[cfg(unix)] - if let Some(ref saved) = saved_screen { - restore_popup_area(saved, popup_rect, popup_scroll_offset); - } - } else if inline_height > 0 { - terminal.clear()?; - } - - let accept = accept - && matches!( - Shell::from_env(), - Shell::Zsh | Shell::Fish | Shell::Bash | Shell::Xonsh | Shell::Nu | Shell::Powershell - ); - - let accept_prefix = "__atuin_accept__:"; - - match result { - InputAction::AcceptInspecting => { - match inspecting { - Some(result) => { - let mut command = result.command; - - if accept { - command = String::from(accept_prefix) + &command; - } - - // index is in bounds so we return that entry - Ok(command) - } - None => Ok(String::new()), - } - } - InputAction::Accept(index) if index < results.len() => { - let mut command = results.swap_remove(index).command; - - if is_command_chaining { - command = format!("{} {}", original_query.trim_end(), command); - } else if accept { - command = String::from(accept_prefix) + &command; - } - - // index is in bounds so we return that entry - Ok(command) - } - InputAction::ReturnOriginal => Ok(String::new()), - InputAction::Copy(index) => { - let cmd = results.swap_remove(index).command; - set_clipboard(cmd); - Ok(String::new()) - } - InputAction::ReturnQuery | InputAction::Accept(_) => { - // Either: - // * index == RETURN_QUERY, in which case we should return the input - // * out of bounds -> usually implies no selected entry so we return the input - Ok(app.search.input.into_inner()) - } - InputAction::Continue - | InputAction::Redraw - | InputAction::Delete(_) - | InputAction::DeleteAllMatching(_) - | InputAction::SwitchContext(_) => { - unreachable!("should have been handled!") - } - } -} - -// cli-clipboard only works on Windows, Mac, and Linux. - -#[cfg(all( - feature = "clipboard", - any(target_os = "windows", target_os = "macos", target_os = "linux") -))] -fn set_clipboard(s: String) { - let mut ctx = arboard::Clipboard::new().unwrap(); - ctx.set_text(s).unwrap(); - // Use the clipboard context to make sure it is saved - ctx.get_text().unwrap(); -} - -#[cfg(not(all( - feature = "clipboard", - any(target_os = "windows", target_os = "macos", target_os = "linux") -)))] -fn set_clipboard(_s: String) {} - -#[cfg(test)] -mod tests { - use atuin_client::database::Context; - use atuin_client::history::History; - use atuin_client::settings::{ - FilterMode, KeymapMode, Preview, PreviewStrategy, SearchMode, Settings, - }; - use time::OffsetDateTime; - - use crate::command::client::search::engines::{self, SearchState}; - use crate::command::client::search::history_list::ListState; - - use super::{Compactness, InspectingState, KeymapSet, State}; - - #[test] - #[expect(clippy::too_many_lines)] - fn calc_preview_height_test() { - let settings_preview_auto = Settings { - preview: Preview { - strategy: PreviewStrategy::Auto, - }, - show_preview: true, - ..Settings::utc() - }; - - let settings_preview_auto_h2 = Settings { - preview: Preview { - strategy: PreviewStrategy::Auto, - }, - show_preview: true, - max_preview_height: 2, - ..Settings::utc() - }; - - let settings_preview_h4 = Settings { - preview: Preview { - strategy: PreviewStrategy::Static, - }, - show_preview: true, - max_preview_height: 4, - ..Settings::utc() - }; - - let settings_preview_fixed = Settings { - preview: Preview { - strategy: PreviewStrategy::Fixed, - }, - show_preview: true, - max_preview_height: 15, - ..Settings::utc() - }; - - let cmd_60: History = History::capture() - .timestamp(time::OffsetDateTime::now_utc()) - .command("for i in $(seq -w 10); do echo \"item number $i - abcd\"; done") - .cwd("/") - .build() - .into(); - - let cmd_124: History = History::capture() - .timestamp(time::OffsetDateTime::now_utc()) - .command("echo 'Aurea prima sata est aetas, quae vindice nullo, sponte sua, sine lege fidem rectumque colebat. Poena metusque aberant'") - .cwd("/") - .build() - .into(); - - let cmd_200: History = History::capture() - .timestamp(time::OffsetDateTime::now_utc()) - .command("CREATE USER atuin WITH ENCRYPTED PASSWORD 'supersecretpassword'; CREATE DATABASE atuin WITH OWNER = atuin; \\c atuin; REVOKE ALL PRIVILEGES ON SCHEMA public FROM PUBLIC; echo 'All done. 200 characters'") - .cwd("/") - .build() - .into(); - - let results: Vec = vec![cmd_60, cmd_124, cmd_200]; - - // the selected command does not require a preview - let no_preview = State::calc_preview_height( - &settings_preview_auto, - &results, - 0_usize, - 0_usize, - Compactness::Full, - 1, - 80, - ); - // the selected command requires 2 lines - let preview_h2 = State::calc_preview_height( - &settings_preview_auto, - &results, - 1_usize, - 0_usize, - Compactness::Full, - 1, - 80, - ); - // the selected command requires 3 lines - let preview_h3 = State::calc_preview_height( - &settings_preview_auto, - &results, - 2_usize, - 0_usize, - Compactness::Full, - 1, - 80, - ); - // the selected command requires a preview of 1 line (happens when the command is between preview_width-19 and preview_width) - let preview_one_line = State::calc_preview_height( - &settings_preview_auto, - &results, - 0_usize, - 0_usize, - Compactness::Full, - 1, - 66, - ); - // the selected command requires 3 lines, but we have a max preview height limit of 2 - let preview_limit_at_2 = State::calc_preview_height( - &settings_preview_auto_h2, - &results, - 2_usize, - 0_usize, - Compactness::Full, - 1, - 80, - ); - // the longest command requires 3 lines - let preview_static_h3 = State::calc_preview_height( - &settings_preview_h4, - &results, - 1_usize, - 0_usize, - Compactness::Full, - 1, - 80, - ); - // the longest command requires 10 lines, but we have a max preview height limit of 4 - let preview_static_limit_at_4 = State::calc_preview_height( - &settings_preview_h4, - &results, - 1_usize, - 0_usize, - Compactness::Full, - 1, - 20, - ); - // the longest command requires 10 lines, but we have a max preview height of 15 and a fixed preview strategy - let settings_preview_fixed = State::calc_preview_height( - &settings_preview_fixed, - &results, - 1_usize, - 0_usize, - Compactness::Full, - 1, - 20, - ); - - assert_eq!(no_preview, 1); - // 1 * 2 is the space for the border - let border_space = 2; - assert_eq!(preview_h2, 2 + border_space); - assert_eq!(preview_h3, 3 + border_space); - assert_eq!(preview_one_line, 1 + border_space); - assert_eq!(preview_limit_at_2, 2 + border_space); - assert_eq!(preview_static_h3, 3 + border_space); - assert_eq!(preview_static_limit_at_4, 4 + border_space); - assert_eq!(settings_preview_fixed, 15 + border_space); - } - - // Test when there's no results, scrolling up or down doesn't underflow - #[test] - fn state_scroll_up_underflow() { - let settings = Settings::utc(); - let mut state = State { - history_count: 0, - results_state: ListState::default(), - switched_search_mode: false, - search_mode: SearchMode::Fuzzy, - results_len: 0, - accept: false, - keymap_mode: KeymapMode::Auto, - prefix: false, - current_cursor: None, - tab_index: 0, - pending_vim_key: None, - original_input_empty: false, - inspecting_state: InspectingState { - current: None, - next: None, - previous: None, - }, - keymaps: KeymapSet::defaults(&settings), - search: SearchState { - input: String::new().into(), - filter_mode: FilterMode::Directory, - context: Context { - session: String::new(), - cwd: String::new(), - hostname: String::new(), - host_id: String::new(), - git_root: None, - }, - custom_context: None, - }, - engine: engines::engine(SearchMode::Fuzzy, &settings), - now: Box::new(OffsetDateTime::now_utc), - }; - - state.scroll_up(1); - state.scroll_down(1); - } - - #[test] - fn test_accept_keybindings() { - use atuin_client::settings::Keys; - use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; - - let mut settings = Settings::utc(); - settings.keys = Keys { - scroll_exits: true, - exit_past_line_start: false, - accept_past_line_end: true, - accept_past_line_start: false, - accept_with_backspace: false, - prefix: "a".to_string(), - }; - - let mut state = State { - history_count: 1, - results_state: ListState::default(), - switched_search_mode: false, - search_mode: SearchMode::Fuzzy, - results_len: 1, - accept: false, - keymap_mode: KeymapMode::Emacs, - prefix: false, - current_cursor: None, - tab_index: 0, - pending_vim_key: None, - original_input_empty: false, - inspecting_state: InspectingState { - current: None, - next: None, - previous: None, - }, - keymaps: KeymapSet::defaults(&settings), - search: SearchState { - input: String::new().into(), - filter_mode: FilterMode::Global, - context: Context { - session: String::new(), - cwd: String::new(), - hostname: String::new(), - host_id: String::new(), - git_root: None, - }, - custom_context: None, - }, - engine: engines::engine(SearchMode::Fuzzy, &settings), - now: Box::new(OffsetDateTime::now_utc), - }; - - let tab_event = KeyEvent::new(KeyCode::Tab, KeyModifiers::NONE); - let result = state.handle_key_input(&settings, &tab_event); - assert!( - matches!(result, super::InputAction::Accept(_)), - "Tab should always accept" - ); - - // Test left arrow with accept_past_line_start disabled (should continue) - let left_event = KeyEvent::new(KeyCode::Left, KeyModifiers::NONE); - let result = state.handle_key_input(&settings, &left_event); - assert!( - matches!(result, super::InputAction::Continue), - "Left arrow should continue when disabled" - ); - - // Test left arrow with accept_past_line_start enabled (should accept at start of line) - settings.keys.accept_past_line_start = true; - state.keymaps = KeymapSet::defaults(&settings); - let result = state.handle_key_input(&settings, &left_event); - assert!( - matches!(result, super::InputAction::Accept(_)), - "Left arrow should accept at start of line when enabled" - ); - settings.keys.accept_past_line_start = false; - state.keymaps = KeymapSet::defaults(&settings); - - let backspace_event = KeyEvent::new(KeyCode::Backspace, KeyModifiers::NONE); - let result = state.handle_key_input(&settings, &backspace_event); - assert!( - matches!(result, super::InputAction::Continue), - "Backspace should continue when disabled" - ); - - settings.keys.accept_with_backspace = true; - state.keymaps = KeymapSet::defaults(&settings); - let result = state.handle_key_input(&settings, &backspace_event); - assert!( - matches!(result, super::InputAction::Accept(_)), - "Backspace should accept at start of line when enabled" - ); - - state.search.input.insert('t'); - state.search.input.insert('e'); - state.search.input.insert('s'); - state.search.input.insert('t'); - state.search.input.end(); - - let right_event = KeyEvent::new(KeyCode::Right, KeyModifiers::NONE); - let result = state.handle_key_input(&settings, &right_event); - assert!( - matches!(result, super::InputAction::Accept(_)), - "Right arrow should accept at end of line when enabled" - ); - - settings.keys.accept_past_line_start = true; - state.keymaps = KeymapSet::defaults(&settings); - let left_event = KeyEvent::new(KeyCode::Left, KeyModifiers::NONE); - let result = state.handle_key_input(&settings, &left_event); - assert!( - matches!(result, super::InputAction::Continue), - "Left arrow should continue and end of line, even when enabled" - ); - settings.keys.accept_past_line_start = false; - state.keymaps = KeymapSet::defaults(&settings); - - settings.keys.accept_with_backspace = true; - state.keymaps = KeymapSet::defaults(&settings); - let backspace_event = KeyEvent::new(KeyCode::Backspace, KeyModifiers::NONE); - let result = state.handle_key_input(&settings, &backspace_event); - assert!( - matches!(result, super::InputAction::Continue), - "Backspace should continue at end of line, even when enabled" - ); - settings.keys.accept_with_backspace = false; - state.keymaps = KeymapSet::defaults(&settings); - } - - #[test] - fn test_vim_gg_multikey_sequence() { - use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; - - let settings = Settings::utc(); - - let mut state = State { - history_count: 100, - results_state: ListState::default(), - switched_search_mode: false, - search_mode: SearchMode::Fuzzy, - results_len: 100, - accept: false, - keymap_mode: KeymapMode::VimNormal, - prefix: false, - current_cursor: None, - tab_index: 0, - pending_vim_key: None, - original_input_empty: false, - inspecting_state: InspectingState { - current: None, - next: None, - previous: None, - }, - keymaps: KeymapSet::defaults(&settings), - search: SearchState { - input: String::new().into(), - filter_mode: FilterMode::Global, - context: Context { - session: String::new(), - cwd: String::new(), - hostname: String::new(), - host_id: String::new(), - git_root: None, - }, - custom_context: None, - }, - engine: engines::engine(SearchMode::Fuzzy, &settings), - now: Box::new(OffsetDateTime::now_utc), - }; - - // Start in the middle of the list - state.results_state.select(50); - - // First 'g' should set pending state - let g_event = KeyEvent::new(KeyCode::Char('g'), KeyModifiers::NONE); - let result = state.handle_key_input(&settings, &g_event); - assert!(matches!(result, super::InputAction::Continue)); - assert_eq!(state.pending_vim_key, Some('g')); - assert_eq!(state.results_state.selected(), 50); // Position unchanged - - // Second 'g' should jump to end (visual top in non-inverted mode) - let result = state.handle_key_input(&settings, &g_event); - assert!(matches!(result, super::InputAction::Continue)); - assert_eq!(state.pending_vim_key, None); - assert_eq!(state.results_state.selected(), 99); // Jumped to last index (visual top) - } - - #[test] - fn test_vim_g_key_clears_on_other_input() { - use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; - - let settings = Settings::utc(); - - let mut state = State { - history_count: 100, - results_state: ListState::default(), - switched_search_mode: false, - search_mode: SearchMode::Fuzzy, - results_len: 100, - accept: false, - keymap_mode: KeymapMode::VimNormal, - prefix: false, - current_cursor: None, - tab_index: 0, - pending_vim_key: None, - original_input_empty: false, - inspecting_state: InspectingState { - current: None, - next: None, - previous: None, - }, - keymaps: KeymapSet::defaults(&settings), - search: SearchState { - input: String::new().into(), - filter_mode: FilterMode::Global, - context: Context { - session: String::new(), - cwd: String::new(), - hostname: String::new(), - host_id: String::new(), - git_root: None, - }, - custom_context: None, - }, - engine: engines::engine(SearchMode::Fuzzy, &settings), - now: Box::new(OffsetDateTime::now_utc), - }; - - state.results_state.select(50); - - // Press 'g' to set pending state - let g_event = KeyEvent::new(KeyCode::Char('g'), KeyModifiers::NONE); - state.handle_key_input(&settings, &g_event); - assert_eq!(state.pending_vim_key, Some('g')); - - // Press 'j' - should clear pending state - let j_event = KeyEvent::new(KeyCode::Char('j'), KeyModifiers::NONE); - state.handle_key_input(&settings, &j_event); - assert_eq!(state.pending_vim_key, None); - } - - #[test] - fn test_vim_big_g_jump_to_bottom() { - use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; - - let settings = Settings::utc(); - - let mut state = State { - history_count: 100, - results_state: ListState::default(), - switched_search_mode: false, - search_mode: SearchMode::Fuzzy, - results_len: 100, - accept: false, - keymap_mode: KeymapMode::VimNormal, - prefix: false, - current_cursor: None, - tab_index: 0, - pending_vim_key: None, - original_input_empty: false, - inspecting_state: InspectingState { - current: None, - next: None, - previous: None, - }, - keymaps: KeymapSet::defaults(&settings), - search: SearchState { - input: String::new().into(), - filter_mode: FilterMode::Global, - context: Context { - session: String::new(), - cwd: String::new(), - hostname: String::new(), - host_id: String::new(), - git_root: None, - }, - custom_context: None, - }, - engine: engines::engine(SearchMode::Fuzzy, &settings), - now: Box::new(OffsetDateTime::now_utc), - }; - - state.results_state.select(50); - - // 'G' should jump to visual bottom (index 0 in non-inverted mode) - let big_g_event = KeyEvent::new(KeyCode::Char('G'), KeyModifiers::NONE); - let result = state.handle_key_input(&settings, &big_g_event); - assert!(matches!(result, super::InputAction::Continue)); - assert_eq!(state.results_state.selected(), 0); - } - - #[test] - fn test_vim_ctrl_u_d_half_page_scroll() { - use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; - - let settings = Settings::utc(); - - let mut state = State { - history_count: 100, - results_state: ListState::default(), - switched_search_mode: false, - search_mode: SearchMode::Fuzzy, - results_len: 100, - accept: false, - keymap_mode: KeymapMode::VimNormal, - prefix: false, - current_cursor: None, - tab_index: 0, - pending_vim_key: None, - original_input_empty: false, - inspecting_state: InspectingState { - current: None, - next: None, - previous: None, - }, - keymaps: KeymapSet::defaults(&settings), - search: SearchState { - input: String::new().into(), - filter_mode: FilterMode::Global, - context: Context { - session: String::new(), - cwd: String::new(), - hostname: String::new(), - host_id: String::new(), - git_root: None, - }, - custom_context: None, - }, - engine: engines::engine(SearchMode::Fuzzy, &settings), - now: Box::new(OffsetDateTime::now_utc), - }; - - state.results_state.select(50); - - // Ctrl+d should return Continue and clear pending key - // (scroll amount depends on max_entries which is 0 in tests) - state.pending_vim_key = Some('g'); - let ctrl_d_event = KeyEvent::new(KeyCode::Char('d'), KeyModifiers::CONTROL); - let result = state.handle_key_input(&settings, &ctrl_d_event); - assert!(matches!(result, super::InputAction::Continue)); - assert_eq!(state.pending_vim_key, None); - - // Ctrl+u should return Continue and clear pending key - state.pending_vim_key = Some('g'); - let ctrl_u_event = KeyEvent::new(KeyCode::Char('u'), KeyModifiers::CONTROL); - let result = state.handle_key_input(&settings, &ctrl_u_event); - assert!(matches!(result, super::InputAction::Continue)); - assert_eq!(state.pending_vim_key, None); - } - - #[test] - fn test_vim_ctrl_f_b_full_page_scroll() { - use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; - - let settings = Settings::utc(); - - let mut state = State { - history_count: 100, - results_state: ListState::default(), - switched_search_mode: false, - search_mode: SearchMode::Fuzzy, - results_len: 100, - accept: false, - keymap_mode: KeymapMode::VimNormal, - prefix: false, - current_cursor: None, - tab_index: 0, - pending_vim_key: None, - original_input_empty: false, - inspecting_state: InspectingState { - current: None, - next: None, - previous: None, - }, - keymaps: KeymapSet::defaults(&settings), - search: SearchState { - input: String::new().into(), - filter_mode: FilterMode::Global, - context: Context { - session: String::new(), - cwd: String::new(), - hostname: String::new(), - host_id: String::new(), - git_root: None, - }, - custom_context: None, - }, - engine: engines::engine(SearchMode::Fuzzy, &settings), - now: Box::new(OffsetDateTime::now_utc), - }; - - state.results_state.select(50); - - // Ctrl+f should return Continue and clear pending key - // (scroll amount depends on max_entries which is 0 in tests) - state.pending_vim_key = Some('g'); - let ctrl_f_event = KeyEvent::new(KeyCode::Char('f'), KeyModifiers::CONTROL); - let result = state.handle_key_input(&settings, &ctrl_f_event); - assert!(matches!(result, super::InputAction::Continue)); - assert_eq!(state.pending_vim_key, None); - - // Ctrl+b should return Continue and clear pending key - state.pending_vim_key = Some('g'); - let ctrl_b_event = KeyEvent::new(KeyCode::Char('b'), KeyModifiers::CONTROL); - let result = state.handle_key_input(&settings, &ctrl_b_event); - assert!(matches!(result, super::InputAction::Continue)); - assert_eq!(state.pending_vim_key, None); - } - - // ----------------------------------------------------------------------- - // Executor tests (execute_action) - // ----------------------------------------------------------------------- - - /// Helper to build a State for executor tests. - fn make_executor_state(results_len: usize, selected: usize) -> State { - let settings = Settings::utc(); - let mut state = State { - history_count: results_len as i64, - results_state: ListState::default(), - switched_search_mode: false, - search_mode: SearchMode::Fuzzy, - results_len, - accept: false, - keymap_mode: KeymapMode::Emacs, - prefix: false, - current_cursor: None, - tab_index: 0, - pending_vim_key: None, - original_input_empty: false, - inspecting_state: InspectingState { - current: None, - next: None, - previous: None, - }, - keymaps: KeymapSet::defaults(&settings), - search: SearchState { - input: String::new().into(), - filter_mode: FilterMode::Global, - context: Context { - session: String::new(), - cwd: String::new(), - hostname: String::new(), - host_id: String::new(), - git_root: None, - }, - custom_context: None, - }, - engine: engines::engine(SearchMode::Fuzzy, &settings), - now: Box::new(OffsetDateTime::now_utc), - }; - state.results_state.select(selected); - state - } - - #[test] - fn execute_select_next_no_invert() { - use crate::command::client::search::keybindings::Action; - - let mut state = make_executor_state(100, 50); - let settings = Settings::utc(); - let result = state.execute_action(&Action::SelectNext, &settings); - assert!(matches!(result, super::InputAction::Continue)); - // Non-inverted: SelectNext = scroll_down = selected - 1 - assert_eq!(state.results_state.selected(), 49); - } - - #[test] - fn execute_select_next_with_invert() { - use crate::command::client::search::keybindings::Action; - - let mut state = make_executor_state(100, 50); - let mut settings = Settings::utc(); - settings.invert = true; - let result = state.execute_action(&Action::SelectNext, &settings); - assert!(matches!(result, super::InputAction::Continue)); - // Inverted: SelectNext = scroll_up = selected + 1 - assert_eq!(state.results_state.selected(), 51); - } - - #[test] - fn execute_select_previous_no_invert() { - use crate::command::client::search::keybindings::Action; - - let mut state = make_executor_state(100, 50); - let settings = Settings::utc(); - let result = state.execute_action(&Action::SelectPrevious, &settings); - assert!(matches!(result, super::InputAction::Continue)); - // Non-inverted: SelectPrevious = scroll_up = selected + 1 - assert_eq!(state.results_state.selected(), 51); - } - - #[test] - fn execute_vim_enter_normal() { - use crate::command::client::search::keybindings::Action; - - let mut state = make_executor_state(100, 0); - let settings = Settings::utc(); - let result = state.execute_action(&Action::VimEnterNormal, &settings); - assert!(matches!(result, super::InputAction::Continue)); - assert_eq!(state.keymap_mode, KeymapMode::VimNormal); - } - - #[test] - fn execute_vim_enter_insert() { - use crate::command::client::search::keybindings::Action; - - let mut state = make_executor_state(100, 0); - state.keymap_mode = KeymapMode::VimNormal; - let settings = Settings::utc(); - let result = state.execute_action(&Action::VimEnterInsert, &settings); - assert!(matches!(result, super::InputAction::Continue)); - assert_eq!(state.keymap_mode, KeymapMode::VimInsert); - } - - #[test] - fn execute_accept_sets_accept_flag() { - use crate::command::client::search::keybindings::Action; - - let mut state = make_executor_state(100, 5); - let mut settings = Settings::utc(); - settings.enter_accept = true; - let result = state.execute_action(&Action::Accept, &settings); - assert!(matches!(result, super::InputAction::Accept(5))); - assert!(state.accept); - } - - #[test] - fn execute_return_selection_does_not_set_accept() { - use crate::command::client::search::keybindings::Action; - - let mut state = make_executor_state(100, 5); - let settings = Settings::utc(); - let result = state.execute_action(&Action::ReturnSelection, &settings); - assert!(matches!(result, super::InputAction::Accept(5))); - assert!(!state.accept); - } - - #[test] - fn execute_accept_nth() { - use crate::command::client::search::keybindings::Action; - - let mut state = make_executor_state(100, 5); - let settings = Settings::utc(); - let result = state.execute_action(&Action::AcceptNth(3), &settings); - assert!(matches!(result, super::InputAction::Accept(8))); - } - - #[test] - fn execute_scroll_to_top_no_invert() { - use crate::command::client::search::keybindings::Action; - - let mut state = make_executor_state(100, 50); - let settings = Settings::utc(); - let result = state.execute_action(&Action::ScrollToTop, &settings); - assert!(matches!(result, super::InputAction::Continue)); - // Non-inverted: visual top = highest index - assert_eq!(state.results_state.selected(), 99); - } - - #[test] - fn execute_scroll_to_top_with_invert() { - use crate::command::client::search::keybindings::Action; - - let mut state = make_executor_state(100, 50); - let mut settings = Settings::utc(); - settings.invert = true; - let result = state.execute_action(&Action::ScrollToTop, &settings); - assert!(matches!(result, super::InputAction::Continue)); - // Inverted: visual top = index 0 - assert_eq!(state.results_state.selected(), 0); - } - - #[test] - fn execute_scroll_to_bottom_no_invert() { - use crate::command::client::search::keybindings::Action; - - let mut state = make_executor_state(100, 50); - let settings = Settings::utc(); - let result = state.execute_action(&Action::ScrollToBottom, &settings); - assert!(matches!(result, super::InputAction::Continue)); - // Non-inverted: visual bottom = index 0 - assert_eq!(state.results_state.selected(), 0); - } - - #[test] - fn execute_toggle_tab() { - use crate::command::client::search::keybindings::Action; - - let mut state = make_executor_state(100, 0); - let settings = Settings::utc(); - assert_eq!(state.tab_index, 0); - state.execute_action(&Action::ToggleTab, &settings); - assert_eq!(state.tab_index, 1); - state.execute_action(&Action::ToggleTab, &settings); - assert_eq!(state.tab_index, 0); - } - - #[test] - fn execute_enter_prefix_mode() { - use crate::command::client::search::keybindings::Action; - - let mut state = make_executor_state(100, 0); - let settings = Settings::utc(); - assert!(!state.prefix); - state.execute_action(&Action::EnterPrefixMode, &settings); - assert!(state.prefix); - } - - #[test] - fn execute_exit_returns_based_on_exit_mode() { - use crate::command::client::search::keybindings::Action; - use atuin_client::settings::ExitMode; - - let mut state = make_executor_state(100, 0); - let mut settings = Settings::utc(); - - settings.exit_mode = ExitMode::ReturnOriginal; - let result = state.execute_action(&Action::Exit, &settings); - assert!(matches!(result, super::InputAction::ReturnOriginal)); - - settings.exit_mode = ExitMode::ReturnQuery; - let result = state.execute_action(&Action::Exit, &settings); - assert!(matches!(result, super::InputAction::ReturnQuery)); - } - - #[test] - fn execute_return_original() { - use crate::command::client::search::keybindings::Action; - - let mut state = make_executor_state(100, 0); - let settings = Settings::utc(); - let result = state.execute_action(&Action::ReturnOriginal, &settings); - assert!(matches!(result, super::InputAction::ReturnOriginal)); - } - - #[test] - fn execute_copy() { - use crate::command::client::search::keybindings::Action; - - let mut state = make_executor_state(100, 7); - let settings = Settings::utc(); - let result = state.execute_action(&Action::Copy, &settings); - assert!(matches!(result, super::InputAction::Copy(7))); - } - - #[test] - fn execute_delete() { - use crate::command::client::search::keybindings::Action; - - let mut state = make_executor_state(100, 7); - let settings = Settings::utc(); - let result = state.execute_action(&Action::Delete, &settings); - assert!(matches!(result, super::InputAction::Delete(7))); - } - - #[test] - fn execute_switch_context() { - use crate::command::client::search::keybindings::Action; - - let mut state = make_executor_state(100, 7); - let settings = Settings::utc(); - let result = state.execute_action(&Action::SwitchContext, &settings); - assert!(matches!(result, super::InputAction::SwitchContext(Some(7)))); - } - - #[test] - fn execute_clear_context() { - use crate::command::client::search::keybindings::Action; - - let mut state = make_executor_state(100, 7); - let settings = Settings::utc(); - let result = state.execute_action(&Action::ClearContext, &settings); - assert!(matches!(result, super::InputAction::SwitchContext(None))); - } - - #[test] - fn execute_noop() { - use crate::command::client::search::keybindings::Action; - - let mut state = make_executor_state(100, 50); - let settings = Settings::utc(); - let result = state.execute_action(&Action::Noop, &settings); - assert!(matches!(result, super::InputAction::Continue)); - assert_eq!(state.results_state.selected(), 50); - } - - #[test] - fn execute_accept_in_inspector_tab() { - use crate::command::client::search::keybindings::Action; - - let mut state = make_executor_state(100, 5); - state.tab_index = 1; - let settings = Settings::utc(); - let result = state.execute_action(&Action::Accept, &settings); - assert!(matches!(result, super::InputAction::AcceptInspecting)); - } - - #[test] - fn execute_cycle_search_mode() { - use crate::command::client::search::keybindings::Action; - - let mut state = make_executor_state(100, 0); - let settings = Settings::utc(); - let original_mode = state.search_mode; - let result = state.execute_action(&Action::CycleSearchMode, &settings); - assert!(matches!(result, super::InputAction::Continue)); - assert!(state.switched_search_mode); - assert_ne!(state.search_mode, original_mode); - } - - #[test] - fn execute_vim_search_insert() { - use crate::command::client::search::keybindings::Action; - - let mut state = make_executor_state(100, 0); - state.search.input.insert('h'); - state.search.input.insert('i'); - state.keymap_mode = KeymapMode::VimNormal; - let settings = Settings::utc(); - let result = state.execute_action(&Action::VimSearchInsert, &settings); - assert!(matches!(result, super::InputAction::Continue)); - // Should clear input and switch to insert mode - assert_eq!(state.search.input.as_str(), ""); - assert_eq!(state.keymap_mode, KeymapMode::VimInsert); - } - - #[test] - fn execute_cursor_movement() { - use crate::command::client::search::keybindings::Action; - - let mut state = make_executor_state(100, 0); - let settings = Settings::utc(); - - // Insert some text - state.search.input.insert('h'); - state.search.input.insert('e'); - state.search.input.insert('l'); - state.search.input.insert('l'); - state.search.input.insert('o'); - // cursor is at end (position 5) - - // CursorLeft - state.execute_action(&Action::CursorLeft, &settings); - assert_eq!(state.search.input.position(), 4); - - // CursorStart - state.execute_action(&Action::CursorStart, &settings); - assert_eq!(state.search.input.position(), 0); - - // CursorEnd - state.execute_action(&Action::CursorEnd, &settings); - assert_eq!(state.search.input.position(), 5); - - // CursorRight at end does nothing - state.execute_action(&Action::CursorRight, &settings); - assert_eq!(state.search.input.position(), 5); - } - - #[test] - fn execute_editing() { - use crate::command::client::search::keybindings::Action; - - let mut state = make_executor_state(100, 0); - let settings = Settings::utc(); - - // Insert "hello" - state.search.input.insert('h'); - state.search.input.insert('e'); - state.search.input.insert('l'); - state.search.input.insert('l'); - state.search.input.insert('o'); - - // DeleteCharBefore (backspace) - state.execute_action(&Action::DeleteCharBefore, &settings); - assert_eq!(state.search.input.as_str(), "hell"); - - // ClearLine - state.execute_action(&Action::ClearLine, &settings); - assert_eq!(state.search.input.as_str(), ""); - } - - #[test] - fn keymap_config_return_query() { - use atuin_client::settings::KeyBindingConfig; - use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; - use std::collections::HashMap; - - let mut settings = Settings::utc(); - // Configure tab to return-query - settings.keymap.emacs = HashMap::from([( - "tab".to_string(), - KeyBindingConfig::Simple("return-query".to_string()), - )]); - - let mut state = State { - history_count: 100, - results_state: ListState::default(), - switched_search_mode: false, - search_mode: SearchMode::Fuzzy, - results_len: 100, - accept: false, - keymap_mode: KeymapMode::Emacs, - prefix: false, - current_cursor: None, - tab_index: 0, - pending_vim_key: None, - original_input_empty: false, - inspecting_state: InspectingState { - current: None, - next: None, - previous: None, - }, - keymaps: KeymapSet::from_settings(&settings), - search: SearchState { - input: "test query".to_string().into(), - filter_mode: FilterMode::Global, - context: Context { - session: String::new(), - cwd: String::new(), - hostname: String::new(), - host_id: String::new(), - git_root: None, - }, - custom_context: None, - }, - engine: engines::engine(SearchMode::Fuzzy, &settings), - now: Box::new(OffsetDateTime::now_utc), - }; - - let tab_event = KeyEvent::new(KeyCode::Tab, KeyModifiers::NONE); - let result = state.handle_key_input(&settings, &tab_event); - assert!( - matches!(result, super::InputAction::ReturnQuery), - "Tab configured as return-query should return InputAction::ReturnQuery" - ); - } -} diff --git a/crates/atuin/src/command/client/search/keybindings/actions.rs b/crates/atuin/src/command/client/search/keybindings/actions.rs deleted file mode 100644 index ff2ef7de..00000000 --- a/crates/atuin/src/command/client/search/keybindings/actions.rs +++ /dev/null @@ -1,322 +0,0 @@ -use std::fmt; - -use serde::{Deserialize, Deserializer, Serialize, Serializer}; - -/// All possible actions that can be triggered by a keybinding. -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Action { - // Cursor movement - CursorLeft, - CursorRight, - CursorWordLeft, - CursorWordRight, - CursorWordEnd, - CursorStart, - CursorEnd, - - // Editing - DeleteCharBefore, - DeleteCharAfter, - DeleteWordBefore, - DeleteWordAfter, - DeleteToWordBoundary, - ClearLine, - ClearToStart, - ClearToEnd, - - // List navigation - SelectNext, - SelectPrevious, - ScrollHalfPageUp, - ScrollHalfPageDown, - ScrollPageUp, - ScrollPageDown, - ScrollToTop, - ScrollToBottom, - ScrollToScreenTop, - ScrollToScreenMiddle, - ScrollToScreenBottom, - - // Commands — accept selection and execute immediately - Accept, - AcceptNth(u8), - // Commands — return selection to command line without executing - ReturnSelection, - ReturnSelectionNth(u8), - // Commands — other - Copy, - Delete, - DeleteAll, - ReturnOriginal, - ReturnQuery, - Exit, - Redraw, - CycleFilterMode, - CycleSearchMode, - SwitchContext, - ClearContext, - ToggleTab, - - // Mode changes - VimEnterNormal, - VimEnterInsert, - VimEnterInsertAfter, - VimEnterInsertAtStart, - VimEnterInsertAtEnd, - VimSearchInsert, - VimChangeToEnd, - EnterPrefixMode, - - // Inspector - InspectPrevious, - InspectNext, - - // Special - Noop, -} - -impl Action { - /// Convert from a kebab-case string. - pub fn from_str(s: &str) -> Result { - // Handle accept-N and return-selection-N patterns - if let Some(rest) = s.strip_prefix("accept-") - && let Ok(n) = rest.parse::() - && (1..=9).contains(&n) - { - return Ok(Action::AcceptNth(n)); - } - if let Some(rest) = s.strip_prefix("return-selection-") - && let Ok(n) = rest.parse::() - && (1..=9).contains(&n) - { - return Ok(Action::ReturnSelectionNth(n)); - } - - match s { - "cursor-left" => Ok(Action::CursorLeft), - "cursor-right" => Ok(Action::CursorRight), - "cursor-word-left" => Ok(Action::CursorWordLeft), - "cursor-word-right" => Ok(Action::CursorWordRight), - "cursor-word-end" => Ok(Action::CursorWordEnd), - "cursor-start" => Ok(Action::CursorStart), - "cursor-end" => Ok(Action::CursorEnd), - - "delete-char-before" => Ok(Action::DeleteCharBefore), - "delete-char-after" => Ok(Action::DeleteCharAfter), - "delete-word-before" => Ok(Action::DeleteWordBefore), - "delete-word-after" => Ok(Action::DeleteWordAfter), - "delete-to-word-boundary" => Ok(Action::DeleteToWordBoundary), - "clear-line" => Ok(Action::ClearLine), - "clear-to-start" => Ok(Action::ClearToStart), - "clear-to-end" => Ok(Action::ClearToEnd), - - "select-next" => Ok(Action::SelectNext), - "select-previous" => Ok(Action::SelectPrevious), - "scroll-half-page-up" => Ok(Action::ScrollHalfPageUp), - "scroll-half-page-down" => Ok(Action::ScrollHalfPageDown), - "scroll-page-up" => Ok(Action::ScrollPageUp), - "scroll-page-down" => Ok(Action::ScrollPageDown), - "scroll-to-top" => Ok(Action::ScrollToTop), - "scroll-to-bottom" => Ok(Action::ScrollToBottom), - "scroll-to-screen-top" => Ok(Action::ScrollToScreenTop), - "scroll-to-screen-middle" => Ok(Action::ScrollToScreenMiddle), - "scroll-to-screen-bottom" => Ok(Action::ScrollToScreenBottom), - - "accept" => Ok(Action::Accept), - "return-selection" => Ok(Action::ReturnSelection), - "copy" => Ok(Action::Copy), - "delete" => Ok(Action::Delete), - "delete-all" => Ok(Action::DeleteAll), - "return-original" => Ok(Action::ReturnOriginal), - "return-query" => Ok(Action::ReturnQuery), - "exit" => Ok(Action::Exit), - "redraw" => Ok(Action::Redraw), - "cycle-filter-mode" => Ok(Action::CycleFilterMode), - "cycle-search-mode" => Ok(Action::CycleSearchMode), - "switch-context" => Ok(Action::SwitchContext), - "clear-context" => Ok(Action::ClearContext), - "toggle-tab" => Ok(Action::ToggleTab), - - "vim-enter-normal" => Ok(Action::VimEnterNormal), - "vim-enter-insert" => Ok(Action::VimEnterInsert), - "vim-enter-insert-after" => Ok(Action::VimEnterInsertAfter), - "vim-enter-insert-at-start" => Ok(Action::VimEnterInsertAtStart), - "vim-enter-insert-at-end" => Ok(Action::VimEnterInsertAtEnd), - "vim-search-insert" => Ok(Action::VimSearchInsert), - "vim-change-to-end" => Ok(Action::VimChangeToEnd), - "enter-prefix-mode" => Ok(Action::EnterPrefixMode), - - "inspect-previous" => Ok(Action::InspectPrevious), - "inspect-next" => Ok(Action::InspectNext), - - "noop" => Ok(Action::Noop), - - _ => Err(format!("unknown action: {s}")), - } - } - - /// Convert to a kebab-case string. - pub fn as_str(&self) -> String { - match self { - Action::CursorLeft => "cursor-left".to_string(), - Action::CursorRight => "cursor-right".to_string(), - Action::CursorWordLeft => "cursor-word-left".to_string(), - Action::CursorWordRight => "cursor-word-right".to_string(), - Action::CursorWordEnd => "cursor-word-end".to_string(), - Action::CursorStart => "cursor-start".to_string(), - Action::CursorEnd => "cursor-end".to_string(), - - Action::DeleteCharBefore => "delete-char-before".to_string(), - Action::DeleteCharAfter => "delete-char-after".to_string(), - Action::DeleteWordBefore => "delete-word-before".to_string(), - Action::DeleteWordAfter => "delete-word-after".to_string(), - Action::DeleteToWordBoundary => "delete-to-word-boundary".to_string(), - Action::ClearLine => "clear-line".to_string(), - Action::ClearToStart => "clear-to-start".to_string(), - Action::ClearToEnd => "clear-to-end".to_string(), - - Action::SelectNext => "select-next".to_string(), - Action::SelectPrevious => "select-previous".to_string(), - Action::ScrollHalfPageUp => "scroll-half-page-up".to_string(), - Action::ScrollHalfPageDown => "scroll-half-page-down".to_string(), - Action::ScrollPageUp => "scroll-page-up".to_string(), - Action::ScrollPageDown => "scroll-page-down".to_string(), - Action::ScrollToTop => "scroll-to-top".to_string(), - Action::ScrollToBottom => "scroll-to-bottom".to_string(), - Action::ScrollToScreenTop => "scroll-to-screen-top".to_string(), - Action::ScrollToScreenMiddle => "scroll-to-screen-middle".to_string(), - Action::ScrollToScreenBottom => "scroll-to-screen-bottom".to_string(), - - Action::Accept => "accept".to_string(), - Action::AcceptNth(n) => format!("accept-{n}"), - Action::ReturnSelection => "return-selection".to_string(), - Action::ReturnSelectionNth(n) => format!("return-selection-{n}"), - Action::Copy => "copy".to_string(), - Action::Delete => "delete".to_string(), - Action::DeleteAll => "delete-all".to_string(), - Action::ReturnOriginal => "return-original".to_string(), - Action::ReturnQuery => "return-query".to_string(), - Action::Exit => "exit".to_string(), - Action::Redraw => "redraw".to_string(), - Action::CycleFilterMode => "cycle-filter-mode".to_string(), - Action::CycleSearchMode => "cycle-search-mode".to_string(), - Action::SwitchContext => "switch-context".to_string(), - Action::ClearContext => "clear-context".to_string(), - Action::ToggleTab => "toggle-tab".to_string(), - - Action::VimEnterNormal => "vim-enter-normal".to_string(), - Action::VimEnterInsert => "vim-enter-insert".to_string(), - Action::VimEnterInsertAfter => "vim-enter-insert-after".to_string(), - Action::VimEnterInsertAtStart => "vim-enter-insert-at-start".to_string(), - Action::VimEnterInsertAtEnd => "vim-enter-insert-at-end".to_string(), - Action::VimSearchInsert => "vim-search-insert".to_string(), - Action::VimChangeToEnd => "vim-change-to-end".to_string(), - Action::EnterPrefixMode => "enter-prefix-mode".to_string(), - - Action::InspectPrevious => "inspect-previous".to_string(), - Action::InspectNext => "inspect-next".to_string(), - - Action::Noop => "noop".to_string(), - } - } -} - -impl fmt::Display for Action { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.as_str()) - } -} - -impl Serialize for Action { - fn serialize(&self, serializer: S) -> Result { - serializer.serialize_str(&self.as_str()) - } -} - -impl<'de> Deserialize<'de> for Action { - fn deserialize>(deserializer: D) -> Result { - let s = String::deserialize(deserializer)?; - Action::from_str(&s).map_err(serde::de::Error::custom) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn parse_basic_actions() { - assert_eq!(Action::from_str("cursor-left").unwrap(), Action::CursorLeft); - assert_eq!(Action::from_str("accept").unwrap(), Action::Accept); - assert_eq!(Action::from_str("exit").unwrap(), Action::Exit); - assert_eq!(Action::from_str("noop").unwrap(), Action::Noop); - assert_eq!( - Action::from_str("vim-enter-normal").unwrap(), - Action::VimEnterNormal - ); - } - - #[test] - fn parse_accept_nth() { - assert_eq!(Action::from_str("accept-1").unwrap(), Action::AcceptNth(1)); - assert_eq!(Action::from_str("accept-9").unwrap(), Action::AcceptNth(9)); - } - - #[test] - fn parse_return_selection() { - assert_eq!( - Action::from_str("return-selection").unwrap(), - Action::ReturnSelection - ); - assert_eq!( - Action::from_str("return-selection-1").unwrap(), - Action::ReturnSelectionNth(1) - ); - assert_eq!( - Action::from_str("return-selection-9").unwrap(), - Action::ReturnSelectionNth(9) - ); - } - - #[test] - fn parse_unknown_action() { - assert!(Action::from_str("unknown-action").is_err()); - assert!(Action::from_str("accept-0").is_err()); - assert!(Action::from_str("accept-10").is_err()); - assert!(Action::from_str("return-selection-0").is_err()); - assert!(Action::from_str("return-selection-10").is_err()); - } - - #[test] - fn round_trip() { - let actions = vec![ - Action::CursorLeft, - Action::Accept, - Action::AcceptNth(5), - Action::ReturnSelection, - Action::ReturnSelectionNth(3), - Action::VimSearchInsert, - Action::ScrollToScreenMiddle, - ]; - for action in actions { - let s = action.as_str(); - let parsed = Action::from_str(&s).unwrap(); - assert_eq!(action, parsed); - } - } - - #[test] - fn serde_round_trip() { - let action = Action::CursorLeft; - let json = serde_json::to_string(&action).unwrap(); - assert_eq!(json, "\"cursor-left\""); - let parsed: Action = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed, Action::CursorLeft); - - let action = Action::AcceptNth(3); - let json = serde_json::to_string(&action).unwrap(); - assert_eq!(json, "\"accept-3\""); - let parsed: Action = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed, Action::AcceptNth(3)); - } -} diff --git a/crates/atuin/src/command/client/search/keybindings/conditions.rs b/crates/atuin/src/command/client/search/keybindings/conditions.rs deleted file mode 100644 index 055ae905..00000000 --- a/crates/atuin/src/command/client/search/keybindings/conditions.rs +++ /dev/null @@ -1,801 +0,0 @@ -use std::fmt; - -use serde::{Deserialize, Deserializer, Serialize, Serializer}; - -/// Atomic (leaf) conditions that can be evaluated against state. -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum ConditionAtom { - CursorAtStart, - CursorAtEnd, - InputEmpty, - OriginalInputEmpty, - ListAtEnd, - ListAtStart, - NoResults, - HasResults, - HasContext, -} - -/// Boolean expression tree over condition atoms. -/// -/// Supports negation, conjunction, and disjunction with standard precedence: -/// `!` binds tightest, then `&&`, then `||`. -/// -/// Examples of valid expression strings: -/// - `"cursor-at-start"` (bare atom) -/// - `"!no-results"` (negation) -/// - `"cursor-at-start && input-empty"` (conjunction) -/// - `"list-at-start || no-results"` (disjunction) -/// - `"(cursor-at-start && !input-empty) || no-results"` (grouping) -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum ConditionExpr { - Atom(ConditionAtom), - Not(Box), - And(Box, Box), - Or(Box, Box), -} - -/// Context needed to evaluate conditions. This is a pure snapshot of state — -/// no references to mutable data. -pub struct EvalContext { - /// Current cursor position (unicode width units). - pub cursor_position: usize, - /// Width of the input string in unicode width units. - pub input_width: usize, - /// Byte length of the input string. - pub input_byte_len: usize, - /// Currently selected index in the results list. - pub selected_index: usize, - /// Total number of results. - pub results_len: usize, - /// Whether the original input (query passed to the TUI) was empty. - pub original_input_empty: bool, - /// Whether we use a search context of a command from the history. - pub has_context: bool, -} - -// --------------------------------------------------------------------------- -// ConditionAtom -// --------------------------------------------------------------------------- - -impl ConditionAtom { - /// Evaluate this atom against the given context. - pub fn evaluate(&self, ctx: &EvalContext) -> bool { - match self { - ConditionAtom::CursorAtStart => ctx.cursor_position == 0, - ConditionAtom::CursorAtEnd => ctx.cursor_position == ctx.input_width, - ConditionAtom::InputEmpty => ctx.input_byte_len == 0, - ConditionAtom::OriginalInputEmpty => ctx.original_input_empty, - ConditionAtom::ListAtEnd => { - ctx.results_len == 0 || ctx.selected_index >= ctx.results_len.saturating_sub(1) - } - ConditionAtom::ListAtStart => ctx.results_len == 0 || ctx.selected_index == 0, - ConditionAtom::NoResults => ctx.results_len == 0, - ConditionAtom::HasResults => ctx.results_len > 0, - ConditionAtom::HasContext => ctx.has_context, - } - } - - /// Parse from a kebab-case string. - pub fn from_str(s: &str) -> Result { - match s { - "cursor-at-start" => Ok(ConditionAtom::CursorAtStart), - "cursor-at-end" => Ok(ConditionAtom::CursorAtEnd), - "input-empty" => Ok(ConditionAtom::InputEmpty), - "original-input-empty" => Ok(ConditionAtom::OriginalInputEmpty), - "list-at-end" => Ok(ConditionAtom::ListAtEnd), - "list-at-start" => Ok(ConditionAtom::ListAtStart), - "no-results" => Ok(ConditionAtom::NoResults), - "has-results" => Ok(ConditionAtom::HasResults), - "has-context" => Ok(ConditionAtom::HasContext), - _ => Err(format!("unknown condition: {s}")), - } - } - - /// Convert to a kebab-case string. - pub fn as_str(&self) -> &'static str { - match self { - ConditionAtom::CursorAtStart => "cursor-at-start", - ConditionAtom::CursorAtEnd => "cursor-at-end", - ConditionAtom::InputEmpty => "input-empty", - ConditionAtom::OriginalInputEmpty => "original-input-empty", - ConditionAtom::ListAtEnd => "list-at-end", - ConditionAtom::ListAtStart => "list-at-start", - ConditionAtom::NoResults => "no-results", - ConditionAtom::HasResults => "has-results", - ConditionAtom::HasContext => "has-context", - } - } -} - -impl fmt::Display for ConditionAtom { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.as_str()) - } -} - -// --------------------------------------------------------------------------- -// ConditionExpr — evaluation -// --------------------------------------------------------------------------- - -impl ConditionExpr { - /// Evaluate this expression against the given context. - pub fn evaluate(&self, ctx: &EvalContext) -> bool { - match self { - ConditionExpr::Atom(atom) => atom.evaluate(ctx), - ConditionExpr::Not(inner) => !inner.evaluate(ctx), - ConditionExpr::And(lhs, rhs) => lhs.evaluate(ctx) && rhs.evaluate(ctx), - ConditionExpr::Or(lhs, rhs) => lhs.evaluate(ctx) || rhs.evaluate(ctx), - } - } -} - -// --------------------------------------------------------------------------- -// ConditionExpr — ergonomic builders -// --------------------------------------------------------------------------- - -impl From for ConditionExpr { - fn from(atom: ConditionAtom) -> Self { - ConditionExpr::Atom(atom) - } -} - -#[expect(dead_code)] -impl ConditionExpr { - /// Negate this expression: `!self`. - pub fn not(self) -> Self { - ConditionExpr::Not(Box::new(self)) - } - - /// Conjoin with another expression: `self && other`. - pub fn and(self, other: ConditionExpr) -> Self { - ConditionExpr::And(Box::new(self), Box::new(other)) - } - - /// Disjoin with another expression: `self || other`. - pub fn or(self, other: ConditionExpr) -> Self { - ConditionExpr::Or(Box::new(self), Box::new(other)) - } -} - -// --------------------------------------------------------------------------- -// ConditionExpr — parser -// --------------------------------------------------------------------------- - -/// Recursive descent parser for boolean condition expressions. -/// -/// Grammar (standard boolean precedence): -/// ```text -/// expr = or_expr -/// or_expr = and_expr ("||" and_expr)* -/// and_expr = unary ("&&" unary)* -/// unary = "!" unary | primary -/// primary = atom | "(" expr ")" -/// atom = [a-z][a-z0-9-]* -/// ``` -struct ExprParser<'a> { - input: &'a str, - pos: usize, -} - -impl<'a> ExprParser<'a> { - fn new(input: &'a str) -> Self { - Self { input, pos: 0 } - } - - fn skip_whitespace(&mut self) { - while self.pos < self.input.len() && self.input.as_bytes()[self.pos].is_ascii_whitespace() { - self.pos += 1; - } - } - - fn starts_with(&mut self, s: &str) -> bool { - self.skip_whitespace(); - self.input[self.pos..].starts_with(s) - } - - fn consume(&mut self, s: &str) -> bool { - self.skip_whitespace(); - if self.input[self.pos..].starts_with(s) { - self.pos += s.len(); - true - } else { - false - } - } - - /// Parse a full expression, expecting to consume all input. - fn parse(mut self) -> Result { - let expr = self.parse_or()?; - self.skip_whitespace(); - if self.pos < self.input.len() { - return Err(format!( - "unexpected input at position {}: {:?}", - self.pos, - &self.input[self.pos..] - )); - } - Ok(expr) - } - - /// `or_expr` = `and_expr` ("||" `and_expr`)* - fn parse_or(&mut self) -> Result { - let mut left = self.parse_and()?; - while self.starts_with("||") { - self.consume("||"); - let right = self.parse_and()?; - left = ConditionExpr::Or(Box::new(left), Box::new(right)); - } - Ok(left) - } - - /// `and_expr` = unary ("&&" unary)* - fn parse_and(&mut self) -> Result { - let mut left = self.parse_unary()?; - while self.starts_with("&&") { - self.consume("&&"); - let right = self.parse_unary()?; - left = ConditionExpr::And(Box::new(left), Box::new(right)); - } - Ok(left) - } - - /// unary = "!" unary | primary - fn parse_unary(&mut self) -> Result { - if self.consume("!") { - let inner = self.parse_unary()?; - Ok(ConditionExpr::Not(Box::new(inner))) - } else { - self.parse_primary() - } - } - - /// primary = "(" expr ")" | atom - fn parse_primary(&mut self) -> Result { - if self.consume("(") { - let expr = self.parse_or()?; - if !self.consume(")") { - return Err(format!("expected ')' at position {}", self.pos)); - } - Ok(expr) - } else { - self.parse_atom() - } - } - - /// atom = [a-z][a-z0-9-]* - fn parse_atom(&mut self) -> Result { - self.skip_whitespace(); - let start = self.pos; - while self.pos < self.input.len() { - let b = self.input.as_bytes()[self.pos]; - if b.is_ascii_lowercase() || b.is_ascii_digit() || b == b'-' { - self.pos += 1; - } else { - break; - } - } - if self.pos == start { - return Err(format!("expected condition name at position {}", self.pos)); - } - let name = &self.input[start..self.pos]; - let atom = ConditionAtom::from_str(name)?; - Ok(ConditionExpr::Atom(atom)) - } -} - -impl ConditionExpr { - /// Parse a condition expression from a string. - pub fn parse(s: &str) -> Result { - let parser = ExprParser::new(s); - parser.parse() - } -} - -// --------------------------------------------------------------------------- -// ConditionExpr — Display -// --------------------------------------------------------------------------- - -/// Precedence levels for minimal-parentheses display. -#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] -enum Prec { - Or = 0, - And = 1, - Not = 2, - Atom = 3, -} - -impl ConditionExpr { - fn prec(&self) -> Prec { - match self { - ConditionExpr::Or(..) => Prec::Or, - ConditionExpr::And(..) => Prec::And, - ConditionExpr::Not(..) => Prec::Not, - ConditionExpr::Atom(..) => Prec::Atom, - } - } - - fn fmt_with_prec(&self, f: &mut fmt::Formatter<'_>, parent_prec: Prec) -> fmt::Result { - let needs_parens = self.prec() < parent_prec; - if needs_parens { - write!(f, "(")?; - } - match self { - ConditionExpr::Atom(atom) => write!(f, "{atom}")?, - ConditionExpr::Not(inner) => { - write!(f, "!")?; - inner.fmt_with_prec(f, Prec::Not)?; - } - ConditionExpr::And(lhs, rhs) => { - lhs.fmt_with_prec(f, Prec::And)?; - write!(f, " && ")?; - rhs.fmt_with_prec(f, Prec::And)?; - } - ConditionExpr::Or(lhs, rhs) => { - lhs.fmt_with_prec(f, Prec::Or)?; - write!(f, " || ")?; - rhs.fmt_with_prec(f, Prec::Or)?; - } - } - if needs_parens { - write!(f, ")")?; - } - Ok(()) - } -} - -impl fmt::Display for ConditionExpr { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.fmt_with_prec(f, Prec::Or) - } -} - -// --------------------------------------------------------------------------- -// Serde -// --------------------------------------------------------------------------- - -impl Serialize for ConditionExpr { - fn serialize(&self, serializer: S) -> Result { - serializer.serialize_str(&self.to_string()) - } -} - -impl<'de> Deserialize<'de> for ConditionExpr { - fn deserialize>(deserializer: D) -> Result { - let s = String::deserialize(deserializer)?; - ConditionExpr::parse(&s).map_err(serde::de::Error::custom) - } -} - -// --------------------------------------------------------------------------- -// Tests -// --------------------------------------------------------------------------- - -#[cfg(test)] -mod tests { - use super::*; - - fn ctx( - cursor: usize, - width: usize, - byte_len: usize, - selected: usize, - len: usize, - ) -> EvalContext { - ctx_with_original(cursor, width, byte_len, selected, len, false) - } - - fn ctx_with_original( - cursor: usize, - width: usize, - byte_len: usize, - selected: usize, - len: usize, - original_input_empty: bool, - ) -> EvalContext { - EvalContext { - cursor_position: cursor, - input_width: width, - input_byte_len: byte_len, - selected_index: selected, - results_len: len, - original_input_empty, - has_context: false, - } - } - - // -- Atom evaluation (carried over from Phase 0) -- - - #[test] - fn atom_cursor_at_start() { - assert!(ConditionAtom::CursorAtStart.evaluate(&ctx(0, 5, 5, 0, 10))); - assert!(!ConditionAtom::CursorAtStart.evaluate(&ctx(3, 5, 5, 0, 10))); - } - - #[test] - fn atom_cursor_at_end() { - assert!(ConditionAtom::CursorAtEnd.evaluate(&ctx(5, 5, 5, 0, 10))); - assert!(!ConditionAtom::CursorAtEnd.evaluate(&ctx(3, 5, 5, 0, 10))); - assert!(ConditionAtom::CursorAtEnd.evaluate(&ctx(0, 0, 0, 0, 10))); - } - - #[test] - fn atom_input_empty() { - assert!(ConditionAtom::InputEmpty.evaluate(&ctx(0, 0, 0, 0, 10))); - assert!(!ConditionAtom::InputEmpty.evaluate(&ctx(0, 5, 5, 0, 10))); - } - - #[test] - fn atom_original_input_empty() { - // original_input_empty = true - assert!( - ConditionAtom::OriginalInputEmpty.evaluate(&ctx_with_original(0, 0, 0, 0, 10, true)) - ); - // original_input_empty = false - assert!( - !ConditionAtom::OriginalInputEmpty.evaluate(&ctx_with_original(0, 0, 0, 0, 10, false)) - ); - // original_input_empty is independent of current input state - assert!( - ConditionAtom::OriginalInputEmpty.evaluate(&ctx_with_original(0, 5, 5, 0, 10, true)) - ); - } - - #[test] - fn atom_list_at_end() { - assert!(ConditionAtom::ListAtEnd.evaluate(&ctx(0, 0, 0, 99, 100))); - assert!(!ConditionAtom::ListAtEnd.evaluate(&ctx(0, 0, 0, 50, 100))); - assert!(ConditionAtom::ListAtEnd.evaluate(&ctx(0, 0, 0, 0, 0))); - } - - #[test] - fn atom_list_at_start() { - assert!(ConditionAtom::ListAtStart.evaluate(&ctx(0, 0, 0, 0, 100))); - assert!(!ConditionAtom::ListAtStart.evaluate(&ctx(0, 0, 0, 50, 100))); - assert!(ConditionAtom::ListAtStart.evaluate(&ctx(0, 0, 0, 0, 0))); - } - - #[test] - fn atom_no_results_and_has_results() { - assert!(ConditionAtom::NoResults.evaluate(&ctx(0, 0, 0, 0, 0))); - assert!(!ConditionAtom::NoResults.evaluate(&ctx(0, 0, 0, 0, 5))); - assert!(ConditionAtom::HasResults.evaluate(&ctx(0, 0, 0, 0, 5))); - assert!(!ConditionAtom::HasResults.evaluate(&ctx(0, 0, 0, 0, 0))); - } - - #[test] - fn atom_has_context() { - let mut context = ctx(0, 0, 0, 0, 0); - assert!(!ConditionAtom::HasContext.evaluate(&context)); - context.has_context = true; - assert!(ConditionAtom::HasContext.evaluate(&context)); - } - - #[test] - fn atom_parse_round_trip() { - let conditions = [ - "cursor-at-start", - "cursor-at-end", - "input-empty", - "original-input-empty", - "list-at-end", - "list-at-start", - "no-results", - "has-results", - ]; - for s in conditions { - let c = ConditionAtom::from_str(s).unwrap(); - assert_eq!(c.as_str(), s); - } - } - - #[test] - fn atom_parse_unknown() { - assert!(ConditionAtom::from_str("unknown-condition").is_err()); - } - - // -- Parser tests -- - - #[test] - fn parse_bare_atom() { - let expr = ConditionExpr::parse("cursor-at-start").unwrap(); - assert_eq!(expr, ConditionExpr::Atom(ConditionAtom::CursorAtStart)); - } - - #[test] - fn parse_negation() { - let expr = ConditionExpr::parse("!no-results").unwrap(); - assert_eq!( - expr, - ConditionExpr::Not(Box::new(ConditionExpr::Atom(ConditionAtom::NoResults))) - ); - } - - #[test] - fn parse_double_negation() { - let expr = ConditionExpr::parse("!!no-results").unwrap(); - assert_eq!( - expr, - ConditionExpr::Not(Box::new(ConditionExpr::Not(Box::new(ConditionExpr::Atom( - ConditionAtom::NoResults - ))))) - ); - } - - #[test] - fn parse_and() { - let expr = ConditionExpr::parse("cursor-at-start && input-empty").unwrap(); - assert_eq!( - expr, - ConditionExpr::And( - Box::new(ConditionExpr::Atom(ConditionAtom::CursorAtStart)), - Box::new(ConditionExpr::Atom(ConditionAtom::InputEmpty)), - ) - ); - } - - #[test] - fn parse_or() { - let expr = ConditionExpr::parse("list-at-start || no-results").unwrap(); - assert_eq!( - expr, - ConditionExpr::Or( - Box::new(ConditionExpr::Atom(ConditionAtom::ListAtStart)), - Box::new(ConditionExpr::Atom(ConditionAtom::NoResults)), - ) - ); - } - - #[test] - fn parse_precedence_and_binds_tighter_than_or() { - // "a || b && c" should parse as "a || (b && c)" - let expr = ConditionExpr::parse("cursor-at-start || input-empty && no-results").unwrap(); - assert_eq!( - expr, - ConditionExpr::Or( - Box::new(ConditionExpr::Atom(ConditionAtom::CursorAtStart)), - Box::new(ConditionExpr::And( - Box::new(ConditionExpr::Atom(ConditionAtom::InputEmpty)), - Box::new(ConditionExpr::Atom(ConditionAtom::NoResults)), - )), - ) - ); - } - - #[test] - fn parse_parens_override_precedence() { - // "(a || b) && c" - let expr = ConditionExpr::parse("(cursor-at-start || input-empty) && no-results").unwrap(); - assert_eq!( - expr, - ConditionExpr::And( - Box::new(ConditionExpr::Or( - Box::new(ConditionExpr::Atom(ConditionAtom::CursorAtStart)), - Box::new(ConditionExpr::Atom(ConditionAtom::InputEmpty)), - )), - Box::new(ConditionExpr::Atom(ConditionAtom::NoResults)), - ) - ); - } - - #[test] - fn parse_complex_nested() { - // "(a && !b) || c" - let expr = ConditionExpr::parse("(cursor-at-start && !input-empty) || no-results").unwrap(); - assert_eq!( - expr, - ConditionExpr::Or( - Box::new(ConditionExpr::And( - Box::new(ConditionExpr::Atom(ConditionAtom::CursorAtStart)), - Box::new(ConditionExpr::Not(Box::new(ConditionExpr::Atom( - ConditionAtom::InputEmpty - )))), - )), - Box::new(ConditionExpr::Atom(ConditionAtom::NoResults)), - ) - ); - } - - #[test] - fn parse_whitespace_tolerance() { - let a = ConditionExpr::parse("cursor-at-start||input-empty").unwrap(); - let b = ConditionExpr::parse("cursor-at-start || input-empty").unwrap(); - let c = ConditionExpr::parse(" cursor-at-start || input-empty ").unwrap(); - assert_eq!(a, b); - assert_eq!(b, c); - } - - #[test] - fn parse_error_unknown_atom() { - assert!(ConditionExpr::parse("unknown-thing").is_err()); - } - - #[test] - fn parse_error_trailing_input() { - assert!(ConditionExpr::parse("cursor-at-start blah").is_err()); - } - - #[test] - fn parse_error_unmatched_paren() { - assert!(ConditionExpr::parse("(cursor-at-start").is_err()); - } - - #[test] - fn parse_error_empty() { - assert!(ConditionExpr::parse("").is_err()); - } - - // -- Expression evaluation -- - - #[test] - fn eval_not() { - let expr = ConditionExpr::parse("!no-results").unwrap(); - // Has results → !no-results is true - assert!(expr.evaluate(&ctx(0, 0, 0, 0, 5))); - // No results → !no-results is false - assert!(!expr.evaluate(&ctx(0, 0, 0, 0, 0))); - } - - #[test] - fn eval_and() { - let expr = ConditionExpr::parse("cursor-at-start && input-empty").unwrap(); - // Both true - assert!(expr.evaluate(&ctx(0, 0, 0, 0, 10))); - // First true, second false (non-empty input) - assert!(!expr.evaluate(&ctx(0, 5, 5, 0, 10))); - // First false (cursor not at start) - assert!(!expr.evaluate(&ctx(3, 5, 5, 0, 10))); - } - - #[test] - fn eval_or() { - let expr = ConditionExpr::parse("list-at-start || no-results").unwrap(); - // list at bottom (selected=0) - assert!(expr.evaluate(&ctx(0, 0, 0, 0, 10))); - // no results - assert!(expr.evaluate(&ctx(0, 0, 0, 0, 0))); - // neither - assert!(!expr.evaluate(&ctx(0, 0, 0, 5, 10))); - } - - #[test] - fn eval_complex_nested() { - // (cursor-at-start && !input-empty) || no-results - let expr = ConditionExpr::parse("(cursor-at-start && !input-empty) || no-results").unwrap(); - - // cursor at start, input not empty → true (left branch) - assert!(expr.evaluate(&ctx(0, 5, 5, 0, 10))); - // no results → true (right branch) - assert!(expr.evaluate(&ctx(3, 5, 5, 0, 0))); - // cursor not at start, has results → false - assert!(!expr.evaluate(&ctx(3, 5, 5, 0, 10))); - // cursor at start, input empty → false (left: && fails; right: has results) - assert!(!expr.evaluate(&ctx(0, 0, 0, 0, 10))); - } - - // -- Display -- - - #[test] - fn display_atom() { - let expr = ConditionExpr::Atom(ConditionAtom::CursorAtStart); - assert_eq!(expr.to_string(), "cursor-at-start"); - } - - #[test] - fn display_not() { - let expr = ConditionExpr::Atom(ConditionAtom::NoResults).not(); - assert_eq!(expr.to_string(), "!no-results"); - } - - #[test] - fn display_and() { - let expr = ConditionExpr::Atom(ConditionAtom::CursorAtStart) - .and(ConditionExpr::Atom(ConditionAtom::InputEmpty)); - assert_eq!(expr.to_string(), "cursor-at-start && input-empty"); - } - - #[test] - fn display_or() { - let expr = ConditionExpr::Atom(ConditionAtom::ListAtStart) - .or(ConditionExpr::Atom(ConditionAtom::NoResults)); - assert_eq!(expr.to_string(), "list-at-start || no-results"); - } - - #[test] - fn display_parens_when_needed() { - // (a || b) && c — the Or inside And needs parens - let expr = ConditionExpr::Atom(ConditionAtom::CursorAtStart) - .or(ConditionExpr::Atom(ConditionAtom::InputEmpty)) - .and(ConditionExpr::Atom(ConditionAtom::NoResults)); - assert_eq!( - expr.to_string(), - "(cursor-at-start || input-empty) && no-results" - ); - } - - #[test] - fn display_no_parens_when_not_needed() { - // a || b && c — no parens needed (and binds tighter) - let inner_and = ConditionExpr::Atom(ConditionAtom::InputEmpty) - .and(ConditionExpr::Atom(ConditionAtom::NoResults)); - let expr = ConditionExpr::Atom(ConditionAtom::CursorAtStart).or(inner_and); - assert_eq!( - expr.to_string(), - "cursor-at-start || input-empty && no-results" - ); - } - - // -- Display round-trip -- - - #[test] - fn display_round_trip() { - let cases = [ - "cursor-at-start", - "!no-results", - "cursor-at-start && input-empty", - "list-at-start || no-results", - "(cursor-at-start || input-empty) && no-results", - "(cursor-at-start && !input-empty) || no-results", - ]; - for s in cases { - let expr = ConditionExpr::parse(s).unwrap(); - let displayed = expr.to_string(); - let reparsed = ConditionExpr::parse(&displayed).unwrap(); - assert_eq!(expr, reparsed, "round-trip failed for: {s}"); - } - } - - // -- Serde -- - - #[test] - fn serde_simple_atom() { - let expr = ConditionExpr::Atom(ConditionAtom::CursorAtStart); - let json = serde_json::to_string(&expr).unwrap(); - assert_eq!(json, "\"cursor-at-start\""); - let parsed: ConditionExpr = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed, expr); - } - - #[test] - fn serde_compound_expression() { - let json = "\"cursor-at-start && !input-empty\""; - let parsed: ConditionExpr = serde_json::from_str(json).unwrap(); - let expected = ConditionExpr::And( - Box::new(ConditionExpr::Atom(ConditionAtom::CursorAtStart)), - Box::new(ConditionExpr::Not(Box::new(ConditionExpr::Atom( - ConditionAtom::InputEmpty, - )))), - ); - assert_eq!(parsed, expected); - } - - #[test] - fn serde_round_trip() { - let expr = ConditionExpr::parse("(cursor-at-start && !input-empty) || no-results").unwrap(); - let json = serde_json::to_string(&expr).unwrap(); - let parsed: ConditionExpr = serde_json::from_str(&json).unwrap(); - assert_eq!(expr, parsed); - } - - // -- From -- - - #[test] - fn from_atom_into_expr() { - let expr: ConditionExpr = ConditionAtom::CursorAtStart.into(); - assert_eq!(expr, ConditionExpr::Atom(ConditionAtom::CursorAtStart)); - } - - // -- Builder helpers -- - - #[test] - fn builder_chain() { - let expr = ConditionExpr::from(ConditionAtom::CursorAtStart) - .and(ConditionExpr::from(ConditionAtom::InputEmpty).not()) - .or(ConditionExpr::from(ConditionAtom::NoResults)); - // And binds tighter than Or, so no parens needed around the And - assert_eq!( - expr.to_string(), - "cursor-at-start && !input-empty || no-results" - ); - } -} diff --git a/crates/atuin/src/command/client/search/keybindings/defaults.rs b/crates/atuin/src/command/client/search/keybindings/defaults.rs deleted file mode 100644 index a76cd4a9..00000000 --- a/crates/atuin/src/command/client/search/keybindings/defaults.rs +++ /dev/null @@ -1,1286 +0,0 @@ -use std::collections::HashMap; - -use atuin_client::settings::{KeyBindingConfig, Settings}; -use tracing::warn; - -use super::actions::Action; -use super::conditions::{ConditionAtom, ConditionExpr}; -use super::key::KeyInput; -use super::keymap::{KeyBinding, KeyRule, Keymap}; - -/// Helper to bind a scroll key with optional exit behavior. -/// -/// When `scroll_exits` is true AND the key scrolls toward index 0 (the newest -/// entry), we add a conditional rule: at `ListAtStart` → `Exit`, otherwise → -/// the scroll action. -/// -/// Whether a key scrolls toward index 0 depends on the `invert` setting: -/// - Non-inverted: "down" / "j" move toward index 0, "up" / "k" move away -/// - Inverted: "up" / "k" move toward index 0, "down" / "j" move away -/// -/// If `toward_index_zero` is false, or `scroll_exits` is false, we just bind -/// the key to the plain scroll action (no exit). -fn bind_scroll_key( - km: &mut Keymap, - key_str: &str, - action: Action, - toward_index_zero: bool, - scroll_exits: bool, -) { - let k = key(key_str); - if scroll_exits && toward_index_zero { - km.bind_conditional( - k, - vec![ - KeyRule::when(ConditionAtom::ListAtStart, Action::Exit), - KeyRule::always(action), - ], - ); - } else { - km.bind(k, action); - } -} - -/// Helper to parse a key string, panicking on invalid keys (these are all -/// compile-time-known strings). -fn key(s: &str) -> KeyInput { - KeyInput::parse(s).unwrap_or_else(|e| panic!("invalid default key {s:?}: {e}")) -} - -/// All five keymaps bundled together. -#[derive(Debug, Clone)] -pub struct KeymapSet { - pub emacs: Keymap, - pub vim_normal: Keymap, - pub vim_insert: Keymap, - pub inspector: Keymap, - pub prefix: Keymap, -} - -// --------------------------------------------------------------------------- -// Common bindings shared across search-tab keymaps -// --------------------------------------------------------------------------- - -/// Add the bindings that are common to all search-tab keymaps: -/// ctrl-c, ctrl-g, ctrl-o, and tab. -/// -/// Note: `esc`/`ctrl-[` are NOT included here because their behavior differs -/// between emacs (exit), vim-normal (exit), and vim-insert (enter normal mode). -fn add_common_bindings(km: &mut Keymap) { - km.bind(key("ctrl-c"), Action::ReturnOriginal); - km.bind(key("ctrl-g"), Action::ReturnOriginal); - km.bind(key("ctrl-o"), Action::ToggleTab); - - // Tab: always returns selection without executing (unlike Enter which respects enter_accept) - km.bind(key("tab"), Action::ReturnSelection); -} - -/// Returns `Accept` or `ReturnSelection` based on the `enter_accept` setting. -fn accept_action(settings: &Settings) -> Action { - if settings.enter_accept { - Action::Accept - } else { - Action::ReturnSelection - } -} - -// --------------------------------------------------------------------------- -// Emacs keymap (also base for vim-insert) -// --------------------------------------------------------------------------- - -/// Build the default emacs keymap. This encodes the behavior from -/// `handle_key_input` common section + `handle_search_input` shared section. -/// -/// The `settings` parameter is used for: -/// - `keys.prefix` — which ctrl-key enters prefix mode -/// - `keys.scroll_exits`, `invert` — scroll-at-boundary exit behavior -/// - `keys.accept_past_line_end` — right arrow at end of line accepts -/// - `keys.exit_past_line_start` — left arrow at start of line exits -/// - `keys.accept_past_line_start` — left arrow at start accepts (overrides exit) -/// - `keys.accept_with_backspace` — backspace at start of line accepts -/// - `ctrl_n_shortcuts` — whether alt or ctrl is used for numeric shortcuts -// Keymap builder that enumerates every default binding; not worth splitting. -#[expect(clippy::too_many_lines)] -pub fn default_emacs_keymap(settings: &Settings) -> Keymap { - let mut km = Keymap::new(); - add_common_bindings(&mut km); - - let accept = accept_action(settings); - - // esc / ctrl-[ → exit - km.bind(key("esc"), Action::Exit); - km.bind(key("ctrl-["), Action::Exit); - - // Prefix key: ctrl- → enter prefix mode - let prefix_char = settings.keys.prefix.chars().next().unwrap_or('a'); - km.bind(key(&format!("ctrl-{prefix_char}")), Action::EnterPrefixMode); - - // --- Accept / navigation edge behaviors (from [keys] settings) --- - - // right: behavior at end of line - if settings.keys.accept_past_line_end { - km.bind_conditional( - key("right"), - vec![ - KeyRule::when(ConditionAtom::CursorAtEnd, Action::ReturnSelection), - KeyRule::always(Action::CursorRight), - ], - ); - } else { - km.bind(key("right"), Action::CursorRight); - } - - // left: behavior at start of line - // accept_past_line_start takes precedence over exit_past_line_start - if settings.keys.accept_past_line_start { - km.bind_conditional( - key("left"), - vec![ - KeyRule::when(ConditionAtom::CursorAtStart, Action::ReturnSelection), - KeyRule::always(Action::CursorLeft), - ], - ); - } else if settings.keys.exit_past_line_start { - km.bind_conditional( - key("left"), - vec![ - KeyRule::when(ConditionAtom::CursorAtStart, Action::Exit), - KeyRule::always(Action::CursorLeft), - ], - ); - } else { - km.bind(key("left"), Action::CursorLeft); - } - - // down/up: scroll with optional exit at boundary. - // Non-inverted: down moves toward index 0 (can exit); up moves away (no exit). - // Inverted: up moves toward index 0 (can exit); down moves away (no exit). - let scroll_exits = settings.keys.scroll_exits; - let invert = settings.invert; - bind_scroll_key(&mut km, "down", Action::SelectNext, !invert, scroll_exits); - bind_scroll_key(&mut km, "up", Action::SelectPrevious, invert, scroll_exits); - - // backspace: behavior at start of line - if settings.keys.accept_with_backspace { - km.bind_conditional( - key("backspace"), - vec![ - KeyRule::when(ConditionAtom::CursorAtStart, Action::ReturnSelection), - KeyRule::always(Action::DeleteCharBefore), - ], - ); - } else { - km.bind(key("backspace"), Action::DeleteCharBefore); - } - - // --- Accept --- - km.bind(key("enter"), accept.clone()); - km.bind(key("ctrl-m"), accept); - - // --- Copy --- - km.bind(key("ctrl-y"), Action::Copy); - - // --- Numeric shortcuts (alt-1..9 by default, ctrl-1..9 if ctrl_n_shortcuts) --- - // These return the selection without executing, regardless of enter_accept. - let num_mod = if settings.ctrl_n_shortcuts { - "ctrl" - } else { - "alt" - }; - for n in 1..=9u8 { - km.bind( - key(&format!("{num_mod}-{n}")), - Action::ReturnSelectionNth(n), - ); - } - - // --- Cursor movement --- - km.bind(key("ctrl-left"), Action::CursorWordLeft); - km.bind(key("alt-b"), Action::CursorWordLeft); - km.bind(key("ctrl-b"), Action::CursorLeft); - km.bind(key("ctrl-right"), Action::CursorWordRight); - km.bind(key("alt-f"), Action::CursorWordRight); - km.bind(key("ctrl-f"), Action::CursorRight); - km.bind(key("home"), Action::CursorStart); - // ctrl-a → CursorStart only if prefix char is NOT 'a' - // (otherwise ctrl-a is already bound to EnterPrefixMode above) - if prefix_char != 'a' { - km.bind(key("ctrl-a"), Action::CursorStart); - } - km.bind(key("ctrl-e"), Action::CursorEnd); - km.bind(key("end"), Action::CursorEnd); - - // --- Editing --- - km.bind(key("ctrl-backspace"), Action::DeleteWordBefore); - km.bind(key("ctrl-h"), Action::DeleteCharBefore); - km.bind(key("ctrl-?"), Action::DeleteCharBefore); - km.bind(key("ctrl-delete"), Action::DeleteWordAfter); - km.bind(key("delete"), Action::DeleteCharAfter); - // ctrl-d: if input empty → return original, otherwise delete char - km.bind_conditional( - key("ctrl-d"), - vec![ - KeyRule::when(ConditionAtom::InputEmpty, Action::ReturnOriginal), - KeyRule::always(Action::DeleteCharAfter), - ], - ); - km.bind(key("ctrl-w"), Action::DeleteToWordBoundary); - km.bind(key("ctrl-u"), Action::ClearLine); - - // --- Search mode --- - km.bind(key("ctrl-r"), Action::CycleFilterMode); - km.bind(key("ctrl-s"), Action::CycleSearchMode); - - // --- Scroll (no exit) --- - km.bind(key("ctrl-n"), Action::SelectNext); - km.bind(key("ctrl-j"), Action::SelectNext); - km.bind(key("ctrl-p"), Action::SelectPrevious); - km.bind(key("ctrl-k"), Action::SelectPrevious); - - // --- Redraw --- - km.bind(key("ctrl-l"), Action::Redraw); - - // --- Page scroll --- - km.bind(key("pagedown"), Action::ScrollPageDown); - km.bind(key("pageup"), Action::ScrollPageUp); - - km -} - -// --------------------------------------------------------------------------- -// Vim Normal keymap -// --------------------------------------------------------------------------- - -/// Build the default vim-normal keymap. -pub fn default_vim_normal_keymap(settings: &Settings) -> Keymap { - let mut km = Keymap::new(); - add_common_bindings(&mut km); - - // esc / ctrl-[ → exit (vim-normal exits, unlike vim-insert) - km.bind(key("esc"), Action::Exit); - km.bind(key("ctrl-["), Action::Exit); - - // Prefix key - let prefix_char = settings.keys.prefix.chars().next().unwrap_or('a'); - km.bind(key(&format!("ctrl-{prefix_char}")), Action::EnterPrefixMode); - - // --- Vim navigation --- - // j/k: scroll with optional exit at boundary. - let scroll_exits = settings.keys.scroll_exits; - let invert = settings.invert; - bind_scroll_key(&mut km, "j", Action::SelectNext, !invert, scroll_exits); - bind_scroll_key(&mut km, "k", Action::SelectPrevious, invert, scroll_exits); - km.bind(key("h"), Action::CursorLeft); - km.bind(key("l"), Action::CursorRight); - - // --- Vim cursor movement --- - km.bind(key("0"), Action::CursorStart); - km.bind(key("$"), Action::CursorEnd); - km.bind(key("w"), Action::CursorWordRight); - km.bind(key("b"), Action::CursorWordLeft); - km.bind(key("e"), Action::CursorWordEnd); - - // --- Vim editing --- - km.bind(key("x"), Action::DeleteCharAfter); - km.bind(key("d d"), Action::ClearLine); - km.bind(key("D"), Action::ClearToEnd); - km.bind(key("C"), Action::VimChangeToEnd); - - // --- Mode switching --- - km.bind(key("?"), Action::VimSearchInsert); - km.bind(key("/"), Action::VimSearchInsert); - km.bind(key("a"), Action::VimEnterInsertAfter); - km.bind(key("A"), Action::VimEnterInsertAtEnd); - km.bind(key("i"), Action::VimEnterInsert); - km.bind(key("I"), Action::VimEnterInsertAtStart); - - // --- Numeric shortcuts (return selection without executing) --- - for n in 1..=9u8 { - km.bind(key(&n.to_string()), Action::ReturnSelectionNth(n)); - } - - // --- Half/full page scroll --- - km.bind(key("ctrl-u"), Action::ScrollHalfPageUp); - km.bind(key("ctrl-d"), Action::ScrollHalfPageDown); - km.bind(key("ctrl-b"), Action::ScrollPageUp); - km.bind(key("ctrl-f"), Action::ScrollPageDown); - - // --- Jump --- - km.bind(key("G"), Action::ScrollToBottom); - km.bind(key("g g"), Action::ScrollToTop); - km.bind(key("H"), Action::ScrollToScreenTop); - km.bind(key("M"), Action::ScrollToScreenMiddle); - km.bind(key("L"), Action::ScrollToScreenBottom); - - // --- Arrow keys (same as emacs for convenience) --- - bind_scroll_key(&mut km, "down", Action::SelectNext, !invert, scroll_exits); - bind_scroll_key(&mut km, "up", Action::SelectPrevious, invert, scroll_exits); - - // --- Page scroll --- - km.bind(key("pagedown"), Action::ScrollPageDown); - km.bind(key("pageup"), Action::ScrollPageUp); - - // --- Accept --- - let accept = accept_action(settings); - km.bind(key("enter"), accept); - - km -} - -// --------------------------------------------------------------------------- -// Vim Insert keymap -// --------------------------------------------------------------------------- - -/// Build the default vim-insert keymap. This clones the emacs keymap and -/// overlays vim-insert-specific bindings (esc → enter normal mode). -pub fn default_vim_insert_keymap(settings: &Settings) -> Keymap { - let mut km = default_emacs_keymap(settings); - - // Override esc and ctrl-[ to enter normal mode instead of exiting - km.bind(key("esc"), Action::VimEnterNormal); - km.bind(key("ctrl-["), Action::VimEnterNormal); - - km -} - -// --------------------------------------------------------------------------- -// Inspector keymap -// --------------------------------------------------------------------------- - -/// Build the default inspector keymap (tab index 1). -/// -/// The inspector shows details about the selected history item and has no -/// text input, so we build a minimal keymap with only inspector-relevant -/// bindings. We respect the user's `keymap_mode` to provide vim-style j/k -/// navigation for vim users. -pub fn default_inspector_keymap(settings: &Settings) -> Keymap { - use atuin_client::settings::KeymapMode; - - let mut km = Keymap::new(); - - // Common bindings (same as search tab) - km.bind(key("ctrl-c"), Action::ReturnOriginal); - km.bind(key("ctrl-g"), Action::ReturnOriginal); - km.bind(key("esc"), Action::Exit); - km.bind(key("ctrl-["), Action::Exit); - km.bind(key("tab"), Action::ReturnSelection); - km.bind(key("ctrl-o"), Action::ToggleTab); - - // Accept behavior respects enter_accept setting - let accept = if settings.enter_accept { - Action::Accept - } else { - Action::ReturnSelection - }; - km.bind(key("enter"), accept); - - // Inspector-specific: delete history entry - km.bind(key("ctrl-d"), Action::Delete); - - // Inspector navigation - km.bind(key("up"), Action::InspectPrevious); - km.bind(key("down"), Action::InspectNext); - km.bind(key("pageup"), Action::InspectPrevious); - km.bind(key("pagedown"), Action::InspectNext); - - // For vim users, add j/k navigation - if matches!( - settings.keymap_mode, - KeymapMode::VimNormal | KeymapMode::VimInsert - ) { - km.bind(key("j"), Action::InspectNext); - km.bind(key("k"), Action::InspectPrevious); - } - - km -} - -// --------------------------------------------------------------------------- -// Prefix keymap -// --------------------------------------------------------------------------- - -/// Build the default prefix keymap (active after ctrl-a prefix). -pub fn default_prefix_keymap() -> Keymap { - let mut km = Keymap::new(); - - km.bind(key("d"), Action::Delete); - km.bind(key("D"), Action::DeleteAll); - km.bind(key("a"), Action::CursorStart); - km.bind_conditional( - key("c"), - vec![ - KeyRule::when(ConditionAtom::HasContext, Action::ClearContext), - KeyRule::always(Action::SwitchContext), - ], - ); - - km -} - -// --------------------------------------------------------------------------- -// KeymapSet construction -// --------------------------------------------------------------------------- - -// --------------------------------------------------------------------------- -// Config → Keymap conversion -// --------------------------------------------------------------------------- - -/// Convert a `KeyBindingConfig` (from TOML) into a `KeyBinding`. -/// Returns `Err` if an action name or condition expression is invalid. -fn parse_binding_config(config: &KeyBindingConfig) -> Result { - match config { - KeyBindingConfig::Simple(action_str) => { - let action = Action::from_str(action_str)?; - Ok(KeyBinding::simple(action)) - } - KeyBindingConfig::Rules(rules) => { - let mut parsed_rules = Vec::with_capacity(rules.len()); - for rule_cfg in rules { - let action = Action::from_str(&rule_cfg.action)?; - let rule = match &rule_cfg.when { - None => KeyRule::always(action), - Some(cond_str) => { - let cond = ConditionExpr::parse(cond_str)?; - KeyRule::when(cond, action) - } - }; - parsed_rules.push(rule); - } - Ok(KeyBinding::conditional(parsed_rules)) - } - } -} - -/// Apply a map of key-string → binding-config overrides to a keymap. -/// Per-key override replaces the entire rule list for that key. -/// Invalid keys or action names are logged and skipped. -fn apply_config_to_keymap(keymap: &mut Keymap, overrides: &HashMap) { - for (key_str, binding_cfg) in overrides { - let key = match KeyInput::parse(key_str) { - Ok(k) => k, - Err(e) => { - warn!("invalid key in keymap config: {key_str:?}: {e}"); - continue; - } - }; - match parse_binding_config(binding_cfg) { - Ok(binding) => { - keymap.bindings.insert(key, binding); - } - Err(e) => { - warn!("invalid binding for {key_str:?} in keymap config: {e}"); - } - } - } -} - -impl KeymapSet { - /// Build the complete set of default keymaps from settings. - pub fn defaults(settings: &Settings) -> Self { - KeymapSet { - emacs: default_emacs_keymap(settings), - vim_normal: default_vim_normal_keymap(settings), - vim_insert: default_vim_insert_keymap(settings), - inspector: default_inspector_keymap(settings), - prefix: default_prefix_keymap(), - } - } - - /// Build keymaps from settings, applying any user `[keymap]` overrides. - /// - /// Precedence rules: - /// - If `[keymap]` has any entries, `[keys]` is **ignored entirely**. - /// Defaults are built with standard `[keys]` values, then `[keymap]` - /// overrides are applied per-key. - /// - If `[keymap]` is empty/absent, `[keys]` customizes the defaults - /// (current behavior for backward compatibility). - pub fn from_settings(settings: &Settings) -> Self { - use atuin_client::settings::Keys; - - if settings.keymap.is_empty() { - // No [keymap] section → use [keys] to customize defaults - Self::defaults(settings) - } else { - // [keymap] present → ignore [keys], use standard defaults as base - let mut base_settings = settings.clone(); - base_settings.keys = Keys::standard_defaults(); - let mut set = Self::defaults(&base_settings); - set.apply_config(settings); - set - } - } - - /// Apply user keymap config overrides to all modes. - fn apply_config(&mut self, settings: &Settings) { - let config = &settings.keymap; - apply_config_to_keymap(&mut self.emacs, &config.emacs); - apply_config_to_keymap(&mut self.vim_normal, &config.vim_normal); - apply_config_to_keymap(&mut self.vim_insert, &config.vim_insert); - apply_config_to_keymap(&mut self.inspector, &config.inspector); - apply_config_to_keymap(&mut self.prefix, &config.prefix); - } -} - -// --------------------------------------------------------------------------- -// Tests -// --------------------------------------------------------------------------- - -#[cfg(test)] -mod tests { - use super::*; - use crate::command::client::search::keybindings::conditions::EvalContext; - - fn make_ctx(cursor: usize, width: usize, selected: usize, len: usize) -> EvalContext { - EvalContext { - cursor_position: cursor, - input_width: width, - input_byte_len: width, - selected_index: selected, - results_len: len, - original_input_empty: false, - has_context: false, - } - } - - fn default_settings() -> Settings { - Settings::utc() - } - - // -- Emacs keymap tests -- - - #[test] - fn emacs_ctrl_c_returns_original() { - let km = default_emacs_keymap(&default_settings()); - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!( - km.resolve(&key("ctrl-c"), &ctx), - Some(Action::ReturnOriginal) - ); - } - - #[test] - fn emacs_esc_exits() { - let km = default_emacs_keymap(&default_settings()); - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!(km.resolve(&key("esc"), &ctx), Some(Action::Exit)); - } - - #[test] - fn emacs_tab_returns_selection() { - // enter_accept=false in test defaults → ReturnSelection - let km = default_emacs_keymap(&default_settings()); - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!(km.resolve(&key("tab"), &ctx), Some(Action::ReturnSelection)); - } - - #[test] - fn emacs_enter_returns_selection() { - // enter_accept=false in test defaults → ReturnSelection - let km = default_emacs_keymap(&default_settings()); - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!( - km.resolve(&key("enter"), &ctx), - Some(Action::ReturnSelection) - ); - } - - #[test] - fn emacs_enter_accept_true_uses_accept() { - let mut settings = default_settings(); - settings.enter_accept = true; - let km = default_emacs_keymap(&settings); - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!(km.resolve(&key("enter"), &ctx), Some(Action::Accept)); - assert_eq!(km.resolve(&key("tab"), &ctx), Some(Action::ReturnSelection)); - } - - #[test] - fn emacs_right_at_end_returns_selection() { - let km = default_emacs_keymap(&default_settings()); - // cursor at end of "hello" (width 5) - let ctx = make_ctx(5, 5, 0, 10); - assert_eq!( - km.resolve(&key("right"), &ctx), - Some(Action::ReturnSelection) - ); - } - - #[test] - fn emacs_right_not_at_end_moves() { - let km = default_emacs_keymap(&default_settings()); - let ctx = make_ctx(2, 5, 0, 10); - assert_eq!(km.resolve(&key("right"), &ctx), Some(Action::CursorRight)); - } - - #[test] - fn emacs_left_at_start_exits() { - let km = default_emacs_keymap(&default_settings()); - let ctx = make_ctx(0, 5, 0, 10); - assert_eq!(km.resolve(&key("left"), &ctx), Some(Action::Exit)); - } - - #[test] - fn emacs_left_not_at_start_moves() { - let km = default_emacs_keymap(&default_settings()); - let ctx = make_ctx(3, 5, 0, 10); - assert_eq!(km.resolve(&key("left"), &ctx), Some(Action::CursorLeft)); - } - - #[test] - fn emacs_down_at_start_exits() { - let km = default_emacs_keymap(&default_settings()); - // selected=0 → ListAtStart → Exit - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!(km.resolve(&key("down"), &ctx), Some(Action::Exit)); - } - - #[test] - fn emacs_down_not_at_start_selects_next() { - let km = default_emacs_keymap(&default_settings()); - // selected=5 → not at start → SelectNext - let ctx = make_ctx(0, 0, 5, 10); - assert_eq!(km.resolve(&key("down"), &ctx), Some(Action::SelectNext)); - } - - #[test] - fn emacs_up_selects_previous() { - let km = default_emacs_keymap(&default_settings()); - // Non-inverted: up never exits (moves away from index 0) - let ctx = make_ctx(0, 0, 5, 10); - assert_eq!(km.resolve(&key("up"), &ctx), Some(Action::SelectPrevious)); - } - - #[test] - fn emacs_ctrl_d_empty_returns_original() { - let km = default_emacs_keymap(&default_settings()); - // input empty (byte_len = 0) - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!( - km.resolve(&key("ctrl-d"), &ctx), - Some(Action::ReturnOriginal) - ); - } - - #[test] - fn emacs_ctrl_d_nonempty_deletes() { - let km = default_emacs_keymap(&default_settings()); - let ctx = make_ctx(2, 5, 0, 10); - assert_eq!( - km.resolve(&key("ctrl-d"), &ctx), - Some(Action::DeleteCharAfter) - ); - } - - #[test] - fn emacs_ctrl_n_selects_next_no_exit_condition() { - let km = default_emacs_keymap(&default_settings()); - // at start, but ctrl-n should NOT exit (no exit condition bound) - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!(km.resolve(&key("ctrl-n"), &ctx), Some(Action::SelectNext)); - } - - #[test] - fn emacs_prefix_key_enters_prefix() { - let km = default_emacs_keymap(&default_settings()); - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!( - km.resolve(&key("ctrl-a"), &ctx), - Some(Action::EnterPrefixMode) - ); - } - - #[test] - fn emacs_home_cursor_start() { - let km = default_emacs_keymap(&default_settings()); - let ctx = make_ctx(5, 10, 0, 10); - assert_eq!(km.resolve(&key("home"), &ctx), Some(Action::CursorStart)); - } - - // -- Vim Normal keymap tests -- - - #[test] - fn vim_normal_j_at_start_exits() { - let km = default_vim_normal_keymap(&default_settings()); - // selected=0 → ListAtStart → Exit (non-inverted: j moves toward index 0) - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!(km.resolve(&key("j"), &ctx), Some(Action::Exit)); - } - - #[test] - fn vim_normal_j_not_at_start_selects_next() { - let km = default_vim_normal_keymap(&default_settings()); - let ctx = make_ctx(0, 0, 5, 10); - assert_eq!(km.resolve(&key("j"), &ctx), Some(Action::SelectNext)); - } - - #[test] - fn vim_normal_k_selects_previous() { - let km = default_vim_normal_keymap(&default_settings()); - // Non-inverted: k never exits (moves away from index 0) - let ctx = make_ctx(0, 0, 5, 10); - assert_eq!(km.resolve(&key("k"), &ctx), Some(Action::SelectPrevious)); - } - - #[test] - fn vim_normal_i_enters_insert() { - let km = default_vim_normal_keymap(&default_settings()); - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!(km.resolve(&key("i"), &ctx), Some(Action::VimEnterInsert)); - } - - #[test] - fn vim_normal_slash_search_insert() { - let km = default_vim_normal_keymap(&default_settings()); - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!(km.resolve(&key("/"), &ctx), Some(Action::VimSearchInsert)); - } - - #[test] - fn vim_normal_gg_scroll_to_top() { - let km = default_vim_normal_keymap(&default_settings()); - let ctx = make_ctx(0, 0, 50, 100); - assert_eq!(km.resolve(&key("g g"), &ctx), Some(Action::ScrollToTop)); - } - - #[test] - fn vim_normal_big_g_scroll_to_bottom() { - let km = default_vim_normal_keymap(&default_settings()); - let ctx = make_ctx(0, 0, 50, 100); - assert_eq!(km.resolve(&key("G"), &ctx), Some(Action::ScrollToBottom)); - } - - #[test] - fn vim_normal_numeric_returns_selection() { - let km = default_vim_normal_keymap(&default_settings()); - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!( - km.resolve(&key("3"), &ctx), - Some(Action::ReturnSelectionNth(3)) - ); - } - - #[test] - fn vim_normal_ctrl_u_half_page_up() { - let km = default_vim_normal_keymap(&default_settings()); - let ctx = make_ctx(0, 0, 50, 100); - assert_eq!( - km.resolve(&key("ctrl-u"), &ctx), - Some(Action::ScrollHalfPageUp) - ); - } - - #[test] - fn vim_normal_screen_jumps() { - let km = default_vim_normal_keymap(&default_settings()); - let ctx = make_ctx(0, 0, 50, 100); - assert_eq!(km.resolve(&key("H"), &ctx), Some(Action::ScrollToScreenTop)); - assert_eq!( - km.resolve(&key("M"), &ctx), - Some(Action::ScrollToScreenMiddle) - ); - assert_eq!( - km.resolve(&key("L"), &ctx), - Some(Action::ScrollToScreenBottom) - ); - } - - #[test] - fn vim_normal_enter_returns_selection() { - // enter_accept=false in test defaults → ReturnSelection - let km = default_vim_normal_keymap(&default_settings()); - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!( - km.resolve(&key("enter"), &ctx), - Some(Action::ReturnSelection) - ); - } - - #[test] - fn vim_normal_enter_accept_true_uses_accept() { - let mut settings = default_settings(); - settings.enter_accept = true; - let km = default_vim_normal_keymap(&settings); - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!(km.resolve(&key("enter"), &ctx), Some(Action::Accept)); - } - - // -- Vim Insert keymap tests -- - - #[test] - fn vim_insert_inherits_emacs_enter() { - let km = default_vim_insert_keymap(&default_settings()); - let ctx = make_ctx(0, 0, 0, 10); - // enter_accept=false → ReturnSelection - assert_eq!( - km.resolve(&key("enter"), &ctx), - Some(Action::ReturnSelection) - ); - } - - #[test] - fn vim_insert_esc_enters_normal() { - let km = default_vim_insert_keymap(&default_settings()); - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!(km.resolve(&key("esc"), &ctx), Some(Action::VimEnterNormal)); - } - - #[test] - fn vim_insert_ctrl_bracket_enters_normal() { - let km = default_vim_insert_keymap(&default_settings()); - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!( - km.resolve(&key("ctrl-["), &ctx), - Some(Action::VimEnterNormal) - ); - } - - #[test] - fn vim_insert_inherits_emacs_ctrl_d() { - let km = default_vim_insert_keymap(&default_settings()); - let ctx = make_ctx(0, 0, 0, 10); - // input empty → return original - assert_eq!( - km.resolve(&key("ctrl-d"), &ctx), - Some(Action::ReturnOriginal) - ); - } - - // -- Inspector keymap tests -- - - #[test] - fn inspector_ctrl_d_deletes() { - let km = default_inspector_keymap(&default_settings()); - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!(km.resolve(&key("ctrl-d"), &ctx), Some(Action::Delete)); - } - - #[test] - fn inspector_up_inspects_previous() { - let km = default_inspector_keymap(&default_settings()); - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!(km.resolve(&key("up"), &ctx), Some(Action::InspectPrevious)); - } - - #[test] - fn inspector_down_inspects_next() { - let km = default_inspector_keymap(&default_settings()); - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!(km.resolve(&key("down"), &ctx), Some(Action::InspectNext)); - } - - #[test] - fn inspector_esc_exits() { - let km = default_inspector_keymap(&default_settings()); - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!(km.resolve(&key("esc"), &ctx), Some(Action::Exit)); - } - - #[test] - fn inspector_tab_returns_selection() { - // enter_accept=false → ReturnSelection - let km = default_inspector_keymap(&default_settings()); - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!(km.resolve(&key("tab"), &ctx), Some(Action::ReturnSelection)); - } - - // -- Prefix keymap tests -- - - #[test] - fn prefix_d_deletes() { - let km = default_prefix_keymap(); - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!(km.resolve(&key("d"), &ctx), Some(Action::Delete)); - } - - #[test] - fn prefix_a_cursor_start() { - let km = default_prefix_keymap(); - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!(km.resolve(&key("a"), &ctx), Some(Action::CursorStart)); - } - - #[test] - fn prefix_unknown_key_returns_none() { - let km = default_prefix_keymap(); - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!(km.resolve(&key("x"), &ctx), None); - } - - // -- KeymapSet tests -- - - #[test] - fn keymap_set_defaults_builds() { - let settings = default_settings(); - let set = KeymapSet::defaults(&settings); - let ctx = make_ctx(0, 0, 0, 10); - - // Sanity check each keymap has bindings - assert!(set.emacs.resolve(&key("ctrl-c"), &ctx).is_some()); - assert!(set.vim_normal.resolve(&key("ctrl-c"), &ctx).is_some()); - assert!(set.vim_insert.resolve(&key("ctrl-c"), &ctx).is_some()); - assert!(set.inspector.resolve(&key("ctrl-c"), &ctx).is_some()); - assert!(set.prefix.resolve(&key("d"), &ctx).is_some()); - } - - // -- Settings-dependent behavior -- - - #[test] - fn custom_prefix_char() { - let mut settings = default_settings(); - settings.keys.prefix = "x".to_string(); - let km = default_emacs_keymap(&settings); - let ctx = make_ctx(0, 0, 0, 10); - - // ctrl-x should be prefix mode - assert_eq!( - km.resolve(&key("ctrl-x"), &ctx), - Some(Action::EnterPrefixMode) - ); - // ctrl-a should now be CursorStart (not prefix) - assert_eq!(km.resolve(&key("ctrl-a"), &ctx), Some(Action::CursorStart)); - } - - #[test] - fn ctrl_n_shortcuts_changes_numeric_modifier() { - let mut settings = default_settings(); - settings.ctrl_n_shortcuts = true; - let km = default_emacs_keymap(&settings); - let ctx = make_ctx(0, 0, 0, 10); - - // ctrl-1 should work - assert_eq!( - km.resolve(&key("ctrl-1"), &ctx), - Some(Action::ReturnSelectionNth(1)) - ); - // alt-1 should NOT be bound - assert_eq!(km.resolve(&key("alt-1"), &ctx), None); - } - - #[test] - fn default_alt_numeric_shortcuts() { - let settings = default_settings(); - let km = default_emacs_keymap(&settings); - let ctx = make_ctx(0, 0, 0, 10); - - // alt-1 should work by default - assert_eq!( - km.resolve(&key("alt-1"), &ctx), - Some(Action::ReturnSelectionNth(1)) - ); - } - - // ----------------------------------------------------------------------- - // Config parsing and merging tests - // ----------------------------------------------------------------------- - - #[test] - fn parse_simple_binding_config() { - use atuin_client::settings::KeyBindingConfig; - let cfg = KeyBindingConfig::Simple("accept".to_string()); - let binding = super::parse_binding_config(&cfg).unwrap(); - assert_eq!(binding.rules.len(), 1); - assert!(binding.rules[0].condition.is_none()); - assert_eq!(binding.rules[0].action, Action::Accept); - } - - #[test] - fn parse_conditional_binding_config() { - use atuin_client::settings::{KeyBindingConfig, KeyRuleConfig}; - let cfg = KeyBindingConfig::Rules(vec![ - KeyRuleConfig { - when: Some("cursor-at-start".to_string()), - action: "exit".to_string(), - }, - KeyRuleConfig { - when: None, - action: "cursor-left".to_string(), - }, - ]); - let binding = super::parse_binding_config(&cfg).unwrap(); - assert_eq!(binding.rules.len(), 2); - assert!(binding.rules[0].condition.is_some()); - assert_eq!(binding.rules[0].action, Action::Exit); - assert!(binding.rules[1].condition.is_none()); - assert_eq!(binding.rules[1].action, Action::CursorLeft); - } - - #[test] - fn parse_binding_config_invalid_action() { - use atuin_client::settings::KeyBindingConfig; - let cfg = KeyBindingConfig::Simple("not-a-real-action".to_string()); - assert!(super::parse_binding_config(&cfg).is_err()); - } - - #[test] - fn parse_binding_config_invalid_condition() { - use atuin_client::settings::{KeyBindingConfig, KeyRuleConfig}; - let cfg = KeyBindingConfig::Rules(vec![KeyRuleConfig { - when: Some("not-a-real-condition".to_string()), - action: "exit".to_string(), - }]); - assert!(super::parse_binding_config(&cfg).is_err()); - } - - #[test] - fn config_override_replaces_key() { - use atuin_client::settings::KeyBindingConfig; - use std::collections::HashMap; - - let mut settings = default_settings(); - let set = KeymapSet::defaults(&settings); - - // Default: ctrl-c → ReturnOriginal - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!( - set.emacs.resolve(&key("ctrl-c"), &ctx), - Some(Action::ReturnOriginal) - ); - - // Override ctrl-c → Exit via config - settings.keymap.emacs = HashMap::from([( - "ctrl-c".to_string(), - KeyBindingConfig::Simple("exit".to_string()), - )]); - - let set = KeymapSet::from_settings(&settings); - assert_eq!(set.emacs.resolve(&key("ctrl-c"), &ctx), Some(Action::Exit)); - } - - #[test] - fn config_override_preserves_unoverridden_keys() { - use atuin_client::settings::KeyBindingConfig; - use std::collections::HashMap; - - let mut settings = default_settings(); - // Override only ctrl-c; enter should keep its default - settings.keymap.emacs = HashMap::from([( - "ctrl-c".to_string(), - KeyBindingConfig::Simple("exit".to_string()), - )]); - - let set = KeymapSet::from_settings(&settings); - let ctx = make_ctx(0, 0, 0, 10); - - // ctrl-c overridden - assert_eq!(set.emacs.resolve(&key("ctrl-c"), &ctx), Some(Action::Exit)); - // enter still has default (enter_accept=false → ReturnSelection) - assert_eq!( - set.emacs.resolve(&key("enter"), &ctx), - Some(Action::ReturnSelection) - ); - } - - #[test] - fn config_conditional_override() { - use atuin_client::settings::{KeyBindingConfig, KeyRuleConfig}; - use std::collections::HashMap; - - let mut settings = default_settings(); - // Override "up" with a custom conditional - settings.keymap.emacs = HashMap::from([( - "up".to_string(), - KeyBindingConfig::Rules(vec![ - KeyRuleConfig { - when: Some("no-results".to_string()), - action: "exit".to_string(), - }, - KeyRuleConfig { - when: None, - action: "select-previous".to_string(), - }, - ]), - )]); - - let set = KeymapSet::from_settings(&settings); - - // With no results → exit - let ctx = make_ctx(0, 0, 0, 0); - assert_eq!(set.emacs.resolve(&key("up"), &ctx), Some(Action::Exit)); - - // With results → select-previous - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!( - set.emacs.resolve(&key("up"), &ctx), - Some(Action::SelectPrevious) - ); - } - - #[test] - fn from_settings_with_empty_config_equals_defaults() { - let settings = default_settings(); - let defaults = KeymapSet::defaults(&settings); - let from_settings = KeymapSet::from_settings(&settings); - - // Verify a sample of keys produce the same results - let ctx = make_ctx(0, 0, 0, 10); - let test_keys = [ - "ctrl-c", "enter", "esc", "tab", "up", "down", "left", "right", - ]; - for k in &test_keys { - assert_eq!( - defaults.emacs.resolve(&key(k), &ctx), - from_settings.emacs.resolve(&key(k), &ctx), - "mismatch for emacs key {k}" - ); - } - } - - // ----------------------------------------------------------------------- - // Phase 5: [keys] vs [keymap] backward compatibility - // ----------------------------------------------------------------------- - - #[test] - fn keymap_overrides_ignore_keys_section() { - use atuin_client::settings::KeyBindingConfig; - - // Set up: [keys] disables scroll_exits, but [keymap] is present - let mut settings = default_settings(); - settings.keys.scroll_exits = false; - - // Without [keymap], scroll_exits=false means no exit condition on down - let set_legacy = KeymapSet::defaults(&settings); - // At list-at-start (selected=0), down should still be SelectNext (no exit) - let ctx_at_boundary = make_ctx(0, 0, 0, 10); - assert_eq!( - set_legacy.emacs.resolve(&key("down"), &ctx_at_boundary), - Some(Action::SelectNext), - "legacy: down at boundary should be SelectNext with scroll_exits=false" - ); - - // With [keymap] present (even just one override), [keys] is ignored - // so the standard defaults (scroll_exits=true) apply - settings.keymap.emacs = HashMap::from([( - "ctrl-c".to_string(), - KeyBindingConfig::Simple("exit".to_string()), - )]); - let set_keymap = KeymapSet::from_settings(&settings); - - // Not at boundary (selected=5): should SelectNext normally - let ctx_not_at_boundary = make_ctx(0, 0, 5, 10); - assert_eq!( - set_keymap.emacs.resolve(&key("down"), &ctx_not_at_boundary), - Some(Action::SelectNext), - "keymap: down not at boundary should SelectNext" - ); - // At list-at-start (selected=0): should Exit (standard scroll_exits=true) - assert_eq!( - set_keymap.emacs.resolve(&key("down"), &ctx_at_boundary), - Some(Action::Exit), - "keymap: down at boundary should Exit (standard defaults restored)" - ); - } - - #[test] - fn keymap_present_resets_to_standard_keys_defaults() { - use atuin_client::settings::KeyBindingConfig; - - let mut settings = default_settings(); - // Disable all [keys] behaviors - settings.keys.exit_past_line_start = false; - settings.keys.accept_past_line_end = false; - - // Without [keymap], left should be plain CursorLeft - let set_legacy = KeymapSet::defaults(&settings); - let ctx_at_start = make_ctx(0, 5, 0, 10); - assert_eq!( - set_legacy.emacs.resolve(&key("left"), &ctx_at_start), - Some(Action::CursorLeft), - "legacy: left should be plain CursorLeft without exit_past_line_start" - ); - - // Add a [keymap] entry (for a different key) - settings.keymap.emacs = HashMap::from([( - "ctrl-c".to_string(), - KeyBindingConfig::Simple("exit".to_string()), - )]); - let set_keymap = KeymapSet::from_settings(&settings); - - // Now left should use standard defaults (exit_past_line_start=true) - // At cursor start → Exit - assert_eq!( - set_keymap.emacs.resolve(&key("left"), &ctx_at_start), - Some(Action::Exit), - "keymap: left at cursor start should exit (standard defaults)" - ); - - // Right at cursor end should return selection (standard defaults: accept_past_line_end=true, enter_accept=false) - let ctx_at_end = make_ctx(5, 5, 0, 10); - assert_eq!( - set_keymap.emacs.resolve(&key("right"), &ctx_at_end), - Some(Action::ReturnSelection), - "keymap: right at cursor end should return selection (standard defaults)" - ); - } - - #[test] - fn keys_has_non_default_values_detection() { - use atuin_client::settings::Keys; - - let standard = Keys::standard_defaults(); - assert!(!standard.has_non_default_values()); - - let mut modified = Keys::standard_defaults(); - modified.scroll_exits = false; - assert!(modified.has_non_default_values()); - - let mut modified = Keys::standard_defaults(); - modified.prefix = "x".to_string(); - assert!(modified.has_non_default_values()); - } - - #[test] - fn original_input_empty_condition_in_config() { - use atuin_client::settings::{KeyBindingConfig, KeyRuleConfig}; - use std::collections::HashMap; - - let mut settings = default_settings(); - // Configure esc to: if original-input-empty -> return-query, else return-original - settings.keymap.emacs = HashMap::from([( - "esc".to_string(), - KeyBindingConfig::Rules(vec![ - KeyRuleConfig { - when: Some("original-input-empty".to_string()), - action: "return-query".to_string(), - }, - KeyRuleConfig { - when: None, - action: "return-original".to_string(), - }, - ]), - )]); - - let set = KeymapSet::from_settings(&settings); - - // When original input was empty, should return-query - let ctx_original_empty = EvalContext { - cursor_position: 0, - input_width: 5, - input_byte_len: 5, - selected_index: 0, - results_len: 10, - original_input_empty: true, - has_context: false, - }; - assert_eq!( - set.emacs.resolve(&key("esc"), &ctx_original_empty), - Some(Action::ReturnQuery), - "esc with original_input_empty=true should return-query" - ); - - // When original input was not empty, should return-original - let ctx_original_not_empty = EvalContext { - cursor_position: 0, - input_width: 5, - input_byte_len: 5, - selected_index: 0, - results_len: 10, - original_input_empty: false, - has_context: false, - }; - assert_eq!( - set.emacs.resolve(&key("esc"), &ctx_original_not_empty), - Some(Action::ReturnOriginal), - "esc with original_input_empty=false should return-original" - ); - } -} diff --git a/crates/atuin/src/command/client/search/keybindings/key.rs b/crates/atuin/src/command/client/search/keybindings/key.rs deleted file mode 100644 index c2eb31c6..00000000 --- a/crates/atuin/src/command/client/search/keybindings/key.rs +++ /dev/null @@ -1,629 +0,0 @@ -use std::fmt; - -use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers, MediaKeyCode}; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; - -/// A single key press with modifiers (e.g. `ctrl-c`, `alt-f`, `enter`). -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -#[expect(clippy::struct_excessive_bools)] -pub struct SingleKey { - pub code: KeyCodeValue, - pub ctrl: bool, - pub alt: bool, - pub shift: bool, - pub super_key: bool, -} - -/// The key code portion of a key press. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum KeyCodeValue { - Char(char), - Enter, - Esc, - Tab, - Backspace, - Delete, - Insert, - Up, - Down, - Left, - Right, - Home, - End, - PageUp, - PageDown, - Space, - F(u8), - Media(MediaKeyCode), -} - -/// A key input that may be a single key or a multi-key sequence (e.g. `g g`). -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum KeyInput { - Single(SingleKey), - Sequence(Vec), -} - -impl SingleKey { - /// Convert a crossterm `KeyEvent` into a `SingleKey`. - pub fn from_event(event: &KeyEvent) -> Option { - let ctrl = event.modifiers.contains(KeyModifiers::CONTROL); - let alt = event.modifiers.contains(KeyModifiers::ALT); - let shift = event.modifiers.contains(KeyModifiers::SHIFT); - let super_key = event.modifiers.contains(KeyModifiers::SUPER); - - let code = match event.code { - KeyCode::Char(' ') => KeyCodeValue::Space, - KeyCode::Char(c) => { - // If shift is the only modifier and it's an uppercase letter, - // we store the uppercase char directly and clear the shift flag - // since the case already encodes it. - if shift && !ctrl && !alt && !super_key && c.is_ascii_uppercase() { - return Some(SingleKey { - code: KeyCodeValue::Char(c), - ctrl: false, - alt: false, - shift: false, - super_key: false, - }); - } - KeyCodeValue::Char(c) - } - KeyCode::Enter => KeyCodeValue::Enter, - KeyCode::Esc => KeyCodeValue::Esc, - KeyCode::Tab => KeyCodeValue::Tab, - // BackTab is sent by many terminals for Shift+Tab - KeyCode::BackTab => { - return Some(SingleKey { - code: KeyCodeValue::Tab, - ctrl, - alt, - shift: true, - super_key, - }); - } - KeyCode::Backspace => KeyCodeValue::Backspace, - KeyCode::Delete => KeyCodeValue::Delete, - KeyCode::Insert => KeyCodeValue::Insert, - KeyCode::Up => KeyCodeValue::Up, - KeyCode::Down => KeyCodeValue::Down, - KeyCode::Left => KeyCodeValue::Left, - KeyCode::Right => KeyCodeValue::Right, - KeyCode::Home => KeyCodeValue::Home, - KeyCode::End => KeyCodeValue::End, - KeyCode::PageUp => KeyCodeValue::PageUp, - KeyCode::PageDown => KeyCodeValue::PageDown, - KeyCode::F(n) => KeyCodeValue::F(n), - KeyCode::Media(m) => KeyCodeValue::Media(m), - _ => return None, - }; - - Some(SingleKey { - code, - ctrl, - alt, - shift: if matches!(code, KeyCodeValue::Char(_)) { - false - } else { - shift - }, - super_key, - }) - } - - /// Parse a key string like `"ctrl-c"`, `"alt-f"`, `"enter"`, `"G"`. - pub fn parse(s: &str) -> Result { - let s = s.trim(); - let parts: Vec<&str> = s.split('-').collect(); - - let mut ctrl = false; - let mut alt = false; - let mut shift = false; - let mut super_key = false; - - // All parts except the last are modifiers - for &part in &parts[..parts.len() - 1] { - match part.to_lowercase().as_str() { - "ctrl" => ctrl = true, - "alt" => alt = true, - "shift" => shift = true, - "super" | "cmd" | "win" => super_key = true, - _ => return Err(format!("unknown modifier: {part}")), - } - } - - let key_part = parts[parts.len() - 1]; - let code = match key_part.to_lowercase().as_str() { - "enter" | "return" => KeyCodeValue::Enter, - "esc" | "escape" => KeyCodeValue::Esc, - "tab" => KeyCodeValue::Tab, - "backspace" => KeyCodeValue::Backspace, - "delete" | "del" => KeyCodeValue::Delete, - "insert" | "ins" => KeyCodeValue::Insert, - "up" => KeyCodeValue::Up, - "down" => KeyCodeValue::Down, - "left" => KeyCodeValue::Left, - "right" => KeyCodeValue::Right, - "home" => KeyCodeValue::Home, - "end" => KeyCodeValue::End, - "pageup" => KeyCodeValue::PageUp, - "pagedown" => KeyCodeValue::PageDown, - "space" => KeyCodeValue::Space, - s if s.starts_with('f') && s.len() > 1 => { - // Parse function keys like "f1", "f12" - if let Ok(n) = s[1..].parse::() { - if (1..=24).contains(&n) { - KeyCodeValue::F(n) - } else { - return Err(format!("function key out of range: {key_part}")); - } - } else { - return Err(format!("unknown key: {key_part}")); - } - } - "[" => KeyCodeValue::Char('['), - "]" => KeyCodeValue::Char(']'), - "?" => KeyCodeValue::Char('?'), - "/" => KeyCodeValue::Char('/'), - "$" => KeyCodeValue::Char('$'), - // Media keys (no dashes - the parser splits on dash for modifiers) - "play" => KeyCodeValue::Media(MediaKeyCode::Play), - "pause" => KeyCodeValue::Media(MediaKeyCode::Pause), - "playpause" => KeyCodeValue::Media(MediaKeyCode::PlayPause), - "stop" => KeyCodeValue::Media(MediaKeyCode::Stop), - "fastforward" => KeyCodeValue::Media(MediaKeyCode::FastForward), - "rewind" => KeyCodeValue::Media(MediaKeyCode::Rewind), - "tracknext" => KeyCodeValue::Media(MediaKeyCode::TrackNext), - "trackprevious" => KeyCodeValue::Media(MediaKeyCode::TrackPrevious), - "record" => KeyCodeValue::Media(MediaKeyCode::Record), - "lowervolume" => KeyCodeValue::Media(MediaKeyCode::LowerVolume), - "raisevolume" => KeyCodeValue::Media(MediaKeyCode::RaiseVolume), - "mutevolume" | "mute" => KeyCodeValue::Media(MediaKeyCode::MuteVolume), - _ => { - let chars: Vec = key_part.chars().collect(); - if chars.len() == 1 { - let c = chars[0]; - // An uppercase letter implies shift (unless shift already specified) - if c.is_ascii_uppercase() && !ctrl && !alt && !super_key { - return Ok(SingleKey { - code: KeyCodeValue::Char(c), - ctrl: false, - alt: false, - shift: false, - super_key: false, - }); - } - KeyCodeValue::Char(c) - } else { - return Err(format!("unknown key: {key_part}")); - } - } - }; - - Ok(SingleKey { - code, - ctrl, - alt, - shift, - super_key, - }) - } -} - -impl fmt::Display for SingleKey { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if self.super_key { - write!(f, "super-")?; - } - if self.ctrl { - write!(f, "ctrl-")?; - } - if self.alt { - write!(f, "alt-")?; - } - if self.shift { - write!(f, "shift-")?; - } - match &self.code { - KeyCodeValue::Char(c) => write!(f, "{c}"), - KeyCodeValue::Enter => write!(f, "enter"), - KeyCodeValue::Esc => write!(f, "esc"), - KeyCodeValue::Tab => write!(f, "tab"), - KeyCodeValue::Backspace => write!(f, "backspace"), - KeyCodeValue::Delete => write!(f, "delete"), - KeyCodeValue::Insert => write!(f, "insert"), - KeyCodeValue::Up => write!(f, "up"), - KeyCodeValue::Down => write!(f, "down"), - KeyCodeValue::Left => write!(f, "left"), - KeyCodeValue::Right => write!(f, "right"), - KeyCodeValue::Home => write!(f, "home"), - KeyCodeValue::End => write!(f, "end"), - KeyCodeValue::PageUp => write!(f, "pageup"), - KeyCodeValue::PageDown => write!(f, "pagedown"), - KeyCodeValue::Space => write!(f, "space"), - KeyCodeValue::F(n) => write!(f, "f{n}"), - KeyCodeValue::Media(m) => match m { - MediaKeyCode::Play => write!(f, "play"), - MediaKeyCode::Pause => write!(f, "media-pause"), - MediaKeyCode::PlayPause => write!(f, "playpause"), - MediaKeyCode::Stop => write!(f, "stop"), - MediaKeyCode::FastForward => write!(f, "fastforward"), - MediaKeyCode::Rewind => write!(f, "rewind"), - MediaKeyCode::TrackNext => write!(f, "tracknext"), - MediaKeyCode::TrackPrevious => write!(f, "trackprevious"), - MediaKeyCode::Record => write!(f, "record"), - MediaKeyCode::LowerVolume => write!(f, "lowervolume"), - MediaKeyCode::RaiseVolume => write!(f, "raisevolume"), - MediaKeyCode::MuteVolume => write!(f, "mutevolume"), - MediaKeyCode::Reverse => write!(f, "reverse"), - }, - } - } -} - -impl KeyInput { - /// Parse a key input string. Supports multi-key sequences separated by spaces - /// (e.g. `"g g"`). - pub fn parse(s: &str) -> Result { - let s = s.trim(); - // Check for space-separated multi-key sequences - // But don't split "space" or modifier combos like "ctrl-a" - let parts: Vec<&str> = s.split_whitespace().collect(); - if parts.len() > 1 { - let keys: Result, String> = - parts.iter().map(|p| SingleKey::parse(p)).collect(); - Ok(KeyInput::Sequence(keys?)) - } else { - Ok(KeyInput::Single(SingleKey::parse(s)?)) - } - } -} - -impl fmt::Display for KeyInput { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - KeyInput::Single(k) => write!(f, "{k}"), - KeyInput::Sequence(keys) => { - for (i, k) in keys.iter().enumerate() { - if i > 0 { - write!(f, " ")?; - } - write!(f, "{k}")?; - } - Ok(()) - } - } - } -} - -impl Serialize for KeyInput { - fn serialize(&self, serializer: S) -> Result { - serializer.serialize_str(&self.to_string()) - } -} - -impl<'de> Deserialize<'de> for KeyInput { - fn deserialize>(deserializer: D) -> Result { - let s = String::deserialize(deserializer)?; - KeyInput::parse(&s).map_err(serde::de::Error::custom) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; - - #[test] - fn parse_simple_keys() { - let k = SingleKey::parse("a").unwrap(); - assert_eq!(k.code, KeyCodeValue::Char('a')); - assert!(!k.ctrl && !k.alt && !k.shift); - - let k = SingleKey::parse("enter").unwrap(); - assert_eq!(k.code, KeyCodeValue::Enter); - - let k = SingleKey::parse("esc").unwrap(); - assert_eq!(k.code, KeyCodeValue::Esc); - - let k = SingleKey::parse("tab").unwrap(); - assert_eq!(k.code, KeyCodeValue::Tab); - - let k = SingleKey::parse("space").unwrap(); - assert_eq!(k.code, KeyCodeValue::Space); - } - - #[test] - fn parse_modifiers() { - let k = SingleKey::parse("ctrl-c").unwrap(); - assert_eq!(k.code, KeyCodeValue::Char('c')); - assert!(k.ctrl); - assert!(!k.alt); - - let k = SingleKey::parse("alt-f").unwrap(); - assert_eq!(k.code, KeyCodeValue::Char('f')); - assert!(k.alt); - assert!(!k.ctrl); - - let k = SingleKey::parse("ctrl-alt-x").unwrap(); - assert_eq!(k.code, KeyCodeValue::Char('x')); - assert!(k.ctrl && k.alt); - } - - #[test] - fn parse_uppercase_implies_no_shift_flag() { - let k = SingleKey::parse("G").unwrap(); - assert_eq!(k.code, KeyCodeValue::Char('G')); - assert!(!k.shift); - assert!(!k.ctrl); - } - - #[test] - fn parse_special_chars() { - let k = SingleKey::parse("ctrl-[").unwrap(); - assert_eq!(k.code, KeyCodeValue::Char('[')); - assert!(k.ctrl); - - let k = SingleKey::parse("?").unwrap(); - assert_eq!(k.code, KeyCodeValue::Char('?')); - - let k = SingleKey::parse("/").unwrap(); - assert_eq!(k.code, KeyCodeValue::Char('/')); - } - - #[test] - fn parse_multi_key_sequence() { - let ki = KeyInput::parse("g g").unwrap(); - match ki { - KeyInput::Sequence(keys) => { - assert_eq!(keys.len(), 2); - assert_eq!(keys[0].code, KeyCodeValue::Char('g')); - assert_eq!(keys[1].code, KeyCodeValue::Char('g')); - } - _ => panic!("expected sequence"), - } - } - - #[test] - fn display_round_trip() { - let cases = ["ctrl-c", "alt-f", "enter", "G", "tab", "pageup"]; - for s in cases { - let k = KeyInput::parse(s).unwrap(); - let display = k.to_string(); - let k2 = KeyInput::parse(&display).unwrap(); - assert_eq!(k, k2, "round-trip failed for {s}"); - } - - let ki = KeyInput::parse("g g").unwrap(); - assert_eq!(ki.to_string(), "g g"); - } - - #[test] - fn from_event_basic() { - let event = KeyEvent::new(KeyCode::Char('c'), KeyModifiers::CONTROL); - let k = SingleKey::from_event(&event).unwrap(); - assert_eq!(k.code, KeyCodeValue::Char('c')); - assert!(k.ctrl); - assert!(!k.alt); - - let event = KeyEvent::new(KeyCode::Enter, KeyModifiers::NONE); - let k = SingleKey::from_event(&event).unwrap(); - assert_eq!(k.code, KeyCodeValue::Enter); - } - - #[test] - fn from_event_uppercase() { - // Crossterm sends uppercase chars with SHIFT modifier - let event = KeyEvent::new(KeyCode::Char('G'), KeyModifiers::SHIFT); - let k = SingleKey::from_event(&event).unwrap(); - assert_eq!(k.code, KeyCodeValue::Char('G')); - // shift flag should be cleared since the case encodes it - assert!(!k.shift); - } - - #[test] - fn from_event_matches_parsed() { - // Verify that from_event and parse produce the same SingleKey - let event = KeyEvent::new(KeyCode::Char('c'), KeyModifiers::CONTROL); - let from_event = SingleKey::from_event(&event).unwrap(); - let parsed = SingleKey::parse("ctrl-c").unwrap(); - assert_eq!(from_event, parsed); - - let event = KeyEvent::new(KeyCode::Char('G'), KeyModifiers::SHIFT); - let from_event = SingleKey::from_event(&event).unwrap(); - let parsed = SingleKey::parse("G").unwrap(); - assert_eq!(from_event, parsed); - } - - #[test] - fn parse_super_modifier() { - let k = SingleKey::parse("super-a").unwrap(); - assert_eq!(k.code, KeyCodeValue::Char('a')); - assert!(k.super_key); - assert!(!k.ctrl && !k.alt && !k.shift); - - // "cmd" is an alias for "super" - let k2 = SingleKey::parse("cmd-a").unwrap(); - assert_eq!(k, k2); - - // "win" is an alias for "super" - let k3 = SingleKey::parse("win-a").unwrap(); - assert_eq!(k, k3); - } - - #[test] - fn parse_super_with_other_modifiers() { - let k = SingleKey::parse("super-ctrl-c").unwrap(); - assert_eq!(k.code, KeyCodeValue::Char('c')); - assert!(k.super_key && k.ctrl); - assert!(!k.alt && !k.shift); - } - - #[test] - fn display_super_modifier() { - let k = SingleKey::parse("super-a").unwrap(); - assert_eq!(k.to_string(), "super-a"); - - let k = SingleKey::parse("super-ctrl-x").unwrap(); - assert_eq!(k.to_string(), "super-ctrl-x"); - } - - #[test] - fn display_round_trip_super() { - let k = KeyInput::parse("super-a").unwrap(); - let display = k.to_string(); - let k2 = KeyInput::parse(&display).unwrap(); - assert_eq!(k, k2, "round-trip failed for super-a"); - } - - #[test] - fn from_event_super() { - let event = KeyEvent::new(KeyCode::Char('a'), KeyModifiers::SUPER); - let k = SingleKey::from_event(&event).unwrap(); - assert_eq!(k.code, KeyCodeValue::Char('a')); - assert!(k.super_key); - assert!(!k.ctrl && !k.alt && !k.shift); - } - - #[test] - fn from_event_super_matches_parsed() { - let event = KeyEvent::new(KeyCode::Char('a'), KeyModifiers::SUPER); - let from_event = SingleKey::from_event(&event).unwrap(); - let parsed = SingleKey::parse("super-a").unwrap(); - assert_eq!(from_event, parsed); - } - - #[test] - fn super_uppercase_preserves_super() { - // super-G should keep the super flag (unlike bare "G" which clears shift) - let k = SingleKey::parse("super-G").unwrap(); - assert_eq!(k.code, KeyCodeValue::Char('G')); - assert!(k.super_key); - } - - #[test] - fn parse_errors() { - assert!(SingleKey::parse("ctrl-alt-shift-xxx").is_err()); - assert!(SingleKey::parse("foobar-a").is_err()); - } - - #[test] - fn parse_function_keys() { - let k = SingleKey::parse("f1").unwrap(); - assert_eq!(k.code, KeyCodeValue::F(1)); - assert!(!k.ctrl && !k.alt && !k.shift); - - let k = SingleKey::parse("F12").unwrap(); - assert_eq!(k.code, KeyCodeValue::F(12)); - - let k = SingleKey::parse("ctrl-f5").unwrap(); - assert_eq!(k.code, KeyCodeValue::F(5)); - assert!(k.ctrl); - - // F24 is valid (some keyboards have extended function keys) - let k = SingleKey::parse("f24").unwrap(); - assert_eq!(k.code, KeyCodeValue::F(24)); - - // F0 and F25+ are invalid - assert!(SingleKey::parse("f0").is_err()); - assert!(SingleKey::parse("f25").is_err()); - } - - #[test] - fn from_event_function_keys() { - let event = KeyEvent::new(KeyCode::F(1), KeyModifiers::NONE); - let k = SingleKey::from_event(&event).unwrap(); - assert_eq!(k.code, KeyCodeValue::F(1)); - - let event = KeyEvent::new(KeyCode::F(12), KeyModifiers::CONTROL); - let k = SingleKey::from_event(&event).unwrap(); - assert_eq!(k.code, KeyCodeValue::F(12)); - assert!(k.ctrl); - } - - #[test] - fn display_function_keys() { - let k = SingleKey::parse("f1").unwrap(); - assert_eq!(k.to_string(), "f1"); - - let k = SingleKey::parse("ctrl-f12").unwrap(); - assert_eq!(k.to_string(), "ctrl-f12"); - } - - #[test] - fn function_key_round_trip() { - let cases = ["f1", "f12", "ctrl-f5", "alt-f10"]; - for s in cases { - let k = KeyInput::parse(s).unwrap(); - let display = k.to_string(); - let k2 = KeyInput::parse(&display).unwrap(); - assert_eq!(k, k2, "round-trip failed for {s}"); - } - } - - #[test] - fn from_event_function_key_matches_parsed() { - let event = KeyEvent::new(KeyCode::F(12), KeyModifiers::NONE); - let from_event = SingleKey::from_event(&event).unwrap(); - let parsed = SingleKey::parse("f12").unwrap(); - assert_eq!(from_event, parsed); - } - - #[test] - fn from_event_backtab_becomes_shift_tab() { - // Many terminals send BackTab for Shift+Tab - let event = KeyEvent::new(KeyCode::BackTab, KeyModifiers::NONE); - let k = SingleKey::from_event(&event).unwrap(); - assert_eq!(k.code, KeyCodeValue::Tab); - assert!(k.shift); - assert!(!k.ctrl && !k.alt); - } - - #[test] - fn from_event_backtab_matches_parsed_shift_tab() { - let event = KeyEvent::new(KeyCode::BackTab, KeyModifiers::NONE); - let from_event = SingleKey::from_event(&event).unwrap(); - let parsed = SingleKey::parse("shift-tab").unwrap(); - assert_eq!(from_event, parsed); - } - - #[test] - fn from_event_backtab_with_ctrl() { - // BackTab with ctrl modifier - let event = KeyEvent::new(KeyCode::BackTab, KeyModifiers::CONTROL); - let k = SingleKey::from_event(&event).unwrap(); - assert_eq!(k.code, KeyCodeValue::Tab); - assert!(k.shift); - assert!(k.ctrl); - } - - #[test] - fn parse_insert_key() { - let k = SingleKey::parse("insert").unwrap(); - assert_eq!(k.code, KeyCodeValue::Insert); - assert!(!k.ctrl && !k.alt && !k.shift); - - let k = SingleKey::parse("ins").unwrap(); - assert_eq!(k.code, KeyCodeValue::Insert); - - let k = SingleKey::parse("ctrl-insert").unwrap(); - assert_eq!(k.code, KeyCodeValue::Insert); - assert!(k.ctrl); - } - - #[test] - fn from_event_insert_key() { - let event = KeyEvent::new(KeyCode::Insert, KeyModifiers::NONE); - let k = SingleKey::from_event(&event).unwrap(); - assert_eq!(k.code, KeyCodeValue::Insert); - } - - #[test] - fn insert_key_round_trip() { - let k = KeyInput::parse("insert").unwrap(); - let display = k.to_string(); - assert_eq!(display, "insert"); - let k2 = KeyInput::parse(&display).unwrap(); - assert_eq!(k, k2); - } -} diff --git a/crates/atuin/src/command/client/search/keybindings/keymap.rs b/crates/atuin/src/command/client/search/keybindings/keymap.rs deleted file mode 100644 index 0d362863..00000000 --- a/crates/atuin/src/command/client/search/keybindings/keymap.rs +++ /dev/null @@ -1,233 +0,0 @@ -use std::collections::HashMap; - -use super::actions::Action; -use super::conditions::{ConditionExpr, EvalContext}; -use super::key::{KeyInput, SingleKey}; - -/// A single rule within a keybinding: an optional condition and an action. -/// If the condition is `None`, the rule always matches. -#[derive(Debug, Clone)] -pub struct KeyRule { - pub condition: Option, - pub action: Action, -} - -/// A keybinding is an ordered list of rules. The first rule whose condition -/// matches (or has no condition) wins. -#[derive(Debug, Clone)] -pub struct KeyBinding { - pub rules: Vec, -} - -/// A keymap is a collection of keybindings indexed by key input. -#[derive(Debug, Clone)] -pub struct Keymap { - pub bindings: HashMap, -} - -impl KeyRule { - /// Create an unconditional rule. - pub fn always(action: Action) -> Self { - KeyRule { - condition: None, - action, - } - } - - /// Create a conditional rule. Accepts any type convertible to `ConditionExpr`, - /// including bare `ConditionAtom` values. - pub fn when(condition: impl Into, action: Action) -> Self { - KeyRule { - condition: Some(condition.into()), - action, - } - } -} - -impl KeyBinding { - /// Create a simple (unconditional) binding. - pub fn simple(action: Action) -> Self { - KeyBinding { - rules: vec![KeyRule::always(action)], - } - } - - /// Create a conditional binding from a list of rules. - pub fn conditional(rules: Vec) -> Self { - KeyBinding { rules } - } -} - -impl Keymap { - /// Create an empty keymap. - pub fn new() -> Self { - Keymap { - bindings: HashMap::new(), - } - } - - /// Bind a key input to a simple (unconditional) action. - pub fn bind(&mut self, key: KeyInput, action: Action) { - self.bindings.insert(key, KeyBinding::simple(action)); - } - - /// Bind a key input to a conditional set of rules. - pub fn bind_conditional(&mut self, key: KeyInput, rules: Vec) { - self.bindings.insert(key, KeyBinding::conditional(rules)); - } - - /// Resolve a key input to an action given the current evaluation context. - /// Returns `None` if the key has no binding or no rule's condition matches. - pub fn resolve(&self, key: &KeyInput, ctx: &EvalContext) -> Option { - let binding = self.bindings.get(key)?; - for rule in &binding.rules { - match &rule.condition { - None => return Some(rule.action.clone()), - Some(cond) if cond.evaluate(ctx) => return Some(rule.action.clone()), - Some(_) => {} - } - } - None - } - - /// Check if any binding starts with the given single key as the first key - /// of a multi-key sequence. Used to detect pending multi-key sequences. - pub fn has_sequence_starting_with(&self, prefix: &SingleKey) -> bool { - self.bindings.keys().any(|ki| match ki { - KeyInput::Sequence(keys) => keys.first() == Some(prefix), - KeyInput::Single(_) => false, - }) - } - - /// Merge another keymap into this one. Keys from `other` override keys in `self`. - #[expect(dead_code)] - pub fn merge(&mut self, other: &Keymap) { - for (key, binding) in &other.bindings { - self.bindings.insert(key.clone(), binding.clone()); - } - } -} - -impl Default for Keymap { - fn default() -> Self { - Self::new() - } -} - -#[cfg(test)] -mod tests { - use super::super::conditions::ConditionAtom; - use super::*; - - fn make_ctx(cursor: usize, width: usize, selected: usize, len: usize) -> EvalContext { - EvalContext { - cursor_position: cursor, - input_width: width, - input_byte_len: width, - selected_index: selected, - results_len: len, - original_input_empty: false, - has_context: false, - } - } - - #[test] - fn simple_binding_resolves() { - let mut keymap = Keymap::new(); - let key = KeyInput::parse("ctrl-c").unwrap(); - keymap.bind(key.clone(), Action::ReturnOriginal); - - let ctx = make_ctx(0, 0, 0, 10); - assert_eq!(keymap.resolve(&key, &ctx), Some(Action::ReturnOriginal)); - } - - #[test] - fn conditional_first_match_wins() { - let mut keymap = Keymap::new(); - let key = KeyInput::parse("left").unwrap(); - keymap.bind_conditional( - key.clone(), - vec![ - KeyRule::when(ConditionAtom::CursorAtStart, Action::Exit), - KeyRule::always(Action::CursorLeft), - ], - ); - - // Cursor at start → Exit - let ctx = make_ctx(0, 5, 0, 10); - assert_eq!(keymap.resolve(&key, &ctx), Some(Action::Exit)); - - // Cursor not at start → CursorLeft - let ctx = make_ctx(3, 5, 0, 10); - assert_eq!(keymap.resolve(&key, &ctx), Some(Action::CursorLeft)); - } - - #[test] - fn no_match_returns_none() { - let keymap = Keymap::new(); - let key = KeyInput::parse("ctrl-c").unwrap(); - let ctx = make_ctx(0, 0, 0, 0); - assert_eq!(keymap.resolve(&key, &ctx), None); - } - - #[test] - fn conditional_no_condition_matches_returns_none() { - let mut keymap = Keymap::new(); - let key = KeyInput::parse("left").unwrap(); - // Only one rule with a condition that won't match - keymap.bind_conditional( - key.clone(), - vec![KeyRule::when(ConditionAtom::CursorAtStart, Action::Exit)], - ); - - // Cursor not at start → no match - let ctx = make_ctx(3, 5, 0, 10); - assert_eq!(keymap.resolve(&key, &ctx), None); - } - - #[test] - fn has_sequence_starting_with() { - let mut keymap = Keymap::new(); - let seq = KeyInput::parse("g g").unwrap(); - keymap.bind(seq, Action::ScrollToTop); - - let g = SingleKey::parse("g").unwrap(); - assert!(keymap.has_sequence_starting_with(&g)); - - let h = SingleKey::parse("h").unwrap(); - assert!(!keymap.has_sequence_starting_with(&h)); - } - - #[test] - fn merge_overrides() { - let mut base = Keymap::new(); - let key = KeyInput::parse("ctrl-c").unwrap(); - base.bind(key.clone(), Action::ReturnOriginal); - - let mut overlay = Keymap::new(); - overlay.bind(key.clone(), Action::Exit); - - base.merge(&overlay); - - let ctx = make_ctx(0, 0, 0, 0); - assert_eq!(base.resolve(&key, &ctx), Some(Action::Exit)); - } - - #[test] - fn merge_preserves_unoverridden() { - let mut base = Keymap::new(); - let key1 = KeyInput::parse("ctrl-c").unwrap(); - let key2 = KeyInput::parse("ctrl-d").unwrap(); - base.bind(key1.clone(), Action::ReturnOriginal); - base.bind(key2.clone(), Action::DeleteCharAfter); - - let mut overlay = Keymap::new(); - overlay.bind(key1.clone(), Action::Exit); - - base.merge(&overlay); - - let ctx = make_ctx(0, 0, 0, 0); - assert_eq!(base.resolve(&key1, &ctx), Some(Action::Exit)); - assert_eq!(base.resolve(&key2, &ctx), Some(Action::DeleteCharAfter)); - } -} diff --git a/crates/atuin/src/command/client/search/keybindings/mod.rs b/crates/atuin/src/command/client/search/keybindings/mod.rs deleted file mode 100644 index 3b6eb2b2..00000000 --- a/crates/atuin/src/command/client/search/keybindings/mod.rs +++ /dev/null @@ -1,14 +0,0 @@ -pub mod actions; -pub mod conditions; -pub mod defaults; -pub mod key; -pub mod keymap; - -pub use actions::Action; -#[expect(unused_imports)] -pub use conditions::{ConditionAtom, ConditionExpr, EvalContext}; -pub use defaults::KeymapSet; -#[expect(unused_imports)] -pub use key::{KeyCodeValue, KeyInput, SingleKey}; -#[expect(unused_imports)] -pub use keymap::{KeyBinding, KeyRule, Keymap}; diff --git a/crates/atuin/src/command/client/setup.rs b/crates/atuin/src/command/client/setup.rs deleted file mode 100644 index 8de73d62..00000000 --- a/crates/atuin/src/command/client/setup.rs +++ /dev/null @@ -1,81 +0,0 @@ -use atuin_client::settings::Settings; - -use colored::Colorize; -use eyre::Result; -use std::io::{self, Write}; -use toml_edit::{DocumentMut, value}; - -pub async fn run(_settings: &Settings) -> Result<()> { - let enable_ai = prompt( - "Atuin AI", - "This will enable command generation and other AI features via the question mark key", - Some( - "By default, Atuin AI only has access to the name and version of your operating system and shell - your shell history is not sent to the AI.", - ), - )?; - - let enable_daemon = prompt( - "Atuin Daemon", - "This will enable improved search and history sync using a persistent background process", - None, - )?; - - let config_file = Settings::get_config_path()?; - let config_str = tokio::fs::read_to_string(&config_file).await?; - let mut doc = config_str.parse::()?; - - let mut changed = false; - if enable_ai { - changed = true; - if !doc.contains_key("ai") { - doc["ai"] = toml_edit::table(); - } - doc["ai"]["enabled"] = value(true); - } - - if enable_daemon { - changed = true; - if !doc.contains_key("daemon") { - doc["daemon"] = toml_edit::table(); - } - doc["daemon"]["enabled"] = value(true); - doc["daemon"]["autostart"] = value(true); - doc["search_mode"] = value("daemon-fuzzy"); - } - - if changed { - tokio::fs::write(config_file, doc.to_string()).await?; - - println!( - "{check} Settings updated successfully", - check = "✓".bold().bright_green() - ); - } else { - println!( - "{check} No settings changed", - check = "✓".bold().bright_green() - ); - } - - Ok(()) -} - -pub fn prompt(feature: &str, description: &str, note: Option<&str>) -> Result { - println!( - "> Enable {feature}?", - feature = feature.bold().bright_blue() - ); - if let Some(note) = note { - println!(" {description}"); - print!(" {note} {q} ", q = "[Y/n]".bold()); - } else { - print!(" {description} {q} ", q = "[Y/n]".bold()); - } - - io::stdout().flush().ok(); - - let mut input = String::new(); - io::stdin().read_line(&mut input)?; - let answer = input.trim().to_lowercase(); - Ok(answer.is_empty() || answer == "y" || answer == "yes") -} diff --git a/crates/atuin/src/command/client/stats.rs b/crates/atuin/src/command/client/stats.rs deleted file mode 100644 index a7fc00ac..00000000 --- a/crates/atuin/src/command/client/stats.rs +++ /dev/null @@ -1,85 +0,0 @@ -use clap::Parser; -use eyre::Result; -use interim::parse_date_string; -use time::{Duration, OffsetDateTime, Time}; - -use atuin_client::{ - database::{Database, current_context}, - settings::Settings, - theme::Theme, -}; - -use atuin_history::stats::{compute, pretty_print}; - -fn parse_ngram_size(s: &str) -> Result { - let value = s - .parse::() - .map_err(|_| format!("'{s}' is not a valid window size"))?; - - if value == 0 { - return Err("ngram window size must be at least 1".to_string()); - } - - Ok(value) -} - -#[derive(Parser, Debug)] -#[command(infer_subcommands = true)] -pub struct Cmd { - /// Compute statistics for the specified period, leave blank for statistics since the beginning. See [this](https://docs.atuin.sh/reference/stats/) for more details. - period: Vec, - - /// How many top commands to list - #[arg(long, short, default_value = "10")] - count: usize, - - /// The number of consecutive commands to consider - #[arg(long, short, default_value = "1", value_parser = parse_ngram_size)] - ngram_size: usize, -} - -impl Cmd { - pub async fn run(&self, db: &impl Database, settings: &Settings, theme: &Theme) -> Result<()> { - let context = current_context().await?; - let words = if self.period.is_empty() { - String::from("all") - } else { - self.period.join(" ") - }; - - let now = OffsetDateTime::now_utc().to_offset(settings.timezone.0); - let last_night = now.replace_time(Time::MIDNIGHT); - - let history = if words.as_str() == "all" { - db.list(&[], &context, None, false, false).await? - } else if words.trim() == "today" { - let start = last_night; - let end = start + Duration::days(1); - db.range(start, end).await? - } else if words.trim() == "month" { - let end = last_night; - let start = end - Duration::days(31); - db.range(start, end).await? - } else if words.trim() == "week" { - let end = last_night; - let start = end - Duration::days(7); - db.range(start, end).await? - } else if words.trim() == "year" { - let end = last_night; - let start = end - Duration::days(365); - db.range(start, end).await? - } else { - let start = parse_date_string(&words, now, settings.dialect.into())?; - let end = start + Duration::days(1); - db.range(start, end).await? - }; - - let stats = compute(settings, &history, self.count, self.ngram_size); - - if let Some(stats) = stats { - pretty_print(stats, self.ngram_size, theme); - } - - Ok(()) - } -} diff --git a/crates/atuin/src/command/client/store.rs b/crates/atuin/src/command/client/store.rs deleted file mode 100644 index 513c404a..00000000 --- a/crates/atuin/src/command/client/store.rs +++ /dev/null @@ -1,120 +0,0 @@ -use clap::Subcommand; -use eyre::Result; - -use atuin_client::{ - database::Database, - record::{sqlite_store::SqliteStore, store::Store}, - settings::Settings, -}; -use itertools::Itertools; -use time::{OffsetDateTime, UtcOffset}; - -#[cfg(feature = "sync")] -mod push; - -#[cfg(feature = "sync")] -mod pull; - -mod purge; -mod rebuild; -mod rekey; -mod verify; - -#[derive(Subcommand, Debug)] -#[command(infer_subcommands = true)] -pub enum Cmd { - /// Print the current status of the record store - Status, - - /// Rebuild a store (eg atuin store rebuild history) - Rebuild(rebuild::Rebuild), - - /// Re-encrypt the store with a new key (potential for data loss!) - Rekey(rekey::Rekey), - - /// Delete all records in the store that cannot be decrypted with the current key - Purge(purge::Purge), - - /// Verify that all records in the store can be decrypted with the current key - Verify(verify::Verify), - - /// Push all records to the remote sync server (one way sync) - #[cfg(feature = "sync")] - Push(push::Push), - - /// Pull records from the remote sync server (one way sync) - #[cfg(feature = "sync")] - Pull(pull::Pull), -} - -impl Cmd { - pub async fn run( - &self, - settings: &Settings, - database: &dyn Database, - store: SqliteStore, - ) -> Result<()> { - match self { - Self::Status => self.status(store).await, - Self::Rebuild(rebuild) => rebuild.run(settings, store, database).await, - Self::Rekey(rekey) => rekey.run(settings, store).await, - Self::Verify(verify) => verify.run(settings, store).await, - Self::Purge(purge) => purge.run(settings, store).await, - - #[cfg(feature = "sync")] - Self::Push(push) => push.run(settings, store).await, - - #[cfg(feature = "sync")] - Self::Pull(pull) => pull.run(settings, store, database).await, - } - } - - pub async fn status(&self, store: SqliteStore) -> Result<()> { - let host_id = Settings::host_id().await?; - let offset = UtcOffset::current_local_offset().unwrap_or(UtcOffset::UTC); - - let status = store.status().await?; - - // TODO: should probs build some data structure and then pretty-print it or smth - for (host, st) in status.hosts.iter().sorted_by_key(|(h, _)| *h) { - let host_string = if host == &host_id { - format!("host: {} <- CURRENT HOST", host.0.as_hyphenated()) - } else { - format!("host: {}", host.0.as_hyphenated()) - }; - - println!("{host_string}"); - - for (tag, idx) in st.iter().sorted_by_key(|(tag, _)| *tag) { - println!("\tstore: {tag}"); - - let first = store.first(*host, tag).await?; - let last = store.last(*host, tag).await?; - - println!("\t\tidx: {idx}"); - - if let Some(first) = first { - println!("\t\tfirst: {}", first.id.0.as_hyphenated()); - - let time = - OffsetDateTime::from_unix_timestamp_nanos(i128::from(first.timestamp))? - .to_offset(offset); - println!("\t\t\tcreated: {time}"); - } - - if let Some(last) = last { - println!("\t\tlast: {}", last.id.0.as_hyphenated()); - - let time = - OffsetDateTime::from_unix_timestamp_nanos(i128::from(last.timestamp))? - .to_offset(offset); - println!("\t\t\tcreated: {time}"); - } - } - - println!(); - } - - Ok(()) - } -} diff --git a/crates/atuin/src/command/client/store/pull.rs b/crates/atuin/src/command/client/store/pull.rs deleted file mode 100644 index 25b925c7..00000000 --- a/crates/atuin/src/command/client/store/pull.rs +++ /dev/null @@ -1,94 +0,0 @@ -use clap::Args; -use eyre::Result; - -use atuin_client::{ - database::Database, - encryption::load_key, - record::store::Store, - record::sync::Operation, - record::{sqlite_store::SqliteStore, sync}, - settings::Settings, -}; - -#[derive(Args, Debug)] -pub struct Pull { - /// The tag to push (eg, 'history'). Defaults to all tags - #[arg(long, short)] - pub tag: Option, - - /// Force push records - /// This will first wipe the local store, and then download all records from the remote - #[arg(long, default_value = "false")] - pub force: bool, - - /// Page Size - /// How many records to download at once. Defaults to 100 - #[arg(long, default_value = "100")] - pub page: u64, -} - -impl Pull { - pub async fn run( - &self, - settings: &Settings, - store: SqliteStore, - db: &dyn Database, - ) -> Result<()> { - if self.force { - println!("Forcing local overwrite!"); - println!("Clearing local store"); - - store.delete_all().await?; - } - - // We can actually just use the existing diff/etc to push - // 1. Diff - // 2. Get operations - // 3. Filter operations by - // a) are they a download op? - // b) are they for the host/tag we are pushing here? - let client = sync::build_client(settings).await?; - let (diff, remote_index) = sync::diff(&client, &store).await?; - - // Skip on --force: local was already wiped above, mismatch is the user's call. - if !self.force { - let key: [u8; 32] = load_key(settings)?.into(); - sync::check_encryption_key(&client, &remote_index, &key) - .await - .map_err(crate::print_error::format_sync_error)?; - } - - let operations = sync::operations(diff, &store).await?; - - let operations = operations - .into_iter() - .filter(|op| match op { - // No noops or downloads thx - Operation::Noop { .. } | Operation::Upload { .. } => false, - - // pull, so yes plz to downloads! - Operation::Download { tag, .. } => { - if self.force { - return true; - } - - if let Some(t) = self.tag.clone() - && t != *tag - { - return false; - } - - true - } - }) - .collect(); - - let (_, downloaded) = sync::sync_remote(&client, operations, &store, self.page).await?; - - println!("Downloaded {} records", downloaded.len()); - - crate::sync::build(settings, &store, db, Some(&downloaded)).await?; - - Ok(()) - } -} diff --git a/crates/atuin/src/command/client/store/purge.rs b/crates/atuin/src/command/client/store/purge.rs deleted file mode 100644 index ad2369ce..00000000 --- a/crates/atuin/src/command/client/store/purge.rs +++ /dev/null @@ -1,26 +0,0 @@ -use clap::Args; -use eyre::Result; - -use atuin_client::{ - encryption::load_key, - record::{sqlite_store::SqliteStore, store::Store}, - settings::Settings, -}; - -#[derive(Args, Debug)] -pub struct Purge {} - -impl Purge { - pub async fn run(&self, settings: &Settings, store: SqliteStore) -> Result<()> { - println!("Purging local records that cannot be decrypted"); - - let key = load_key(settings)?; - - match store.purge(&key.into()).await { - Ok(()) => println!("Local store purge completed OK"), - Err(e) => println!("Failed to purge local store: {e:?}"), - } - - Ok(()) - } -} diff --git a/crates/atuin/src/command/client/store/push.rs b/crates/atuin/src/command/client/store/push.rs deleted file mode 100644 index d8569e1e..00000000 --- a/crates/atuin/src/command/client/store/push.rs +++ /dev/null @@ -1,112 +0,0 @@ -use atuin_common::record::HostId; -use clap::Args; -use eyre::Result; -use uuid::Uuid; - -use atuin_client::{ - api_client::Client, - encryption::load_key, - record::sync::Operation, - record::{sqlite_store::SqliteStore, sync}, - settings::Settings, -}; - -#[derive(Args, Debug)] -pub struct Push { - /// The tag to push (eg, 'history'). Defaults to all tags - #[arg(long, short)] - pub tag: Option, - - /// The host to push, in the form of a UUID host ID. Defaults to the current host. - #[arg(long)] - pub host: Option, - - /// Force push records - /// This will override both host and tag, to be all hosts and all tags. First clear the remote store, then upload all of the - /// local store - #[arg(long, default_value = "false")] - pub force: bool, - - /// Page Size - /// How many records to upload at once. Defaults to 100 - #[arg(long, default_value = "100")] - pub page: u64, -} - -impl Push { - pub async fn run(&self, settings: &Settings, store: SqliteStore) -> Result<()> { - let host_id = Settings::host_id().await?; - - if self.force { - println!("Forcing remote store overwrite!"); - println!("Clearing remote store"); - - let client = Client::new( - &settings.sync_address, - settings.sync_auth_token().await?, - settings.network_connect_timeout, - settings.network_timeout * 10, // we may be deleting a lot of data... so up the - // timeout - ) - .expect("failed to create client"); - - client.delete_store().await?; - } - - // We can actually just use the existing diff/etc to push - // 1. Diff - // 2. Get operations - // 3. Filter operations by - // a) are they an upload op? - // b) are they for the host/tag we are pushing here? - let client = sync::build_client(settings).await?; - let (diff, remote_index) = sync::diff(&client, &store).await?; - - // Skip on --force: that path intentionally replaces remote with local. - if !self.force { - let key: [u8; 32] = load_key(settings)?.into(); - sync::check_encryption_key(&client, &remote_index, &key) - .await - .map_err(crate::print_error::format_sync_error)?; - } - - let operations = sync::operations(diff, &store).await?; - - let operations = operations - .into_iter() - .filter(|op| match op { - // No noops or downloads thx - Operation::Noop { .. } | Operation::Download { .. } => false, - - // push, so yes plz to uploads! - Operation::Upload { host, tag, .. } => { - if self.force { - return true; - } - - if let Some(h) = self.host { - if HostId(h) != *host { - return false; - } - } else if *host != host_id { - return false; - } - - if let Some(t) = self.tag.clone() - && t != *tag - { - return false; - } - - true - } - }) - .collect(); - - let (uploaded, _) = sync::sync_remote(&client, operations, &store, self.page).await?; - - println!("Uploaded {uploaded} records"); - - Ok(()) - } -} diff --git a/crates/atuin/src/command/client/store/rebuild.rs b/crates/atuin/src/command/client/store/rebuild.rs deleted file mode 100644 index b9f2837b..00000000 --- a/crates/atuin/src/command/client/store/rebuild.rs +++ /dev/null @@ -1,58 +0,0 @@ -use clap::Args; -use eyre::{Result, bail}; - -#[cfg(feature = "daemon")] -use crate::command::client::daemon as daemon_cmd; - -use atuin_client::{ - database::Database, encryption, history::store::HistoryStore, - record::sqlite_store::SqliteStore, settings::Settings, -}; - -#[derive(Args, Debug)] -pub struct Rebuild { - pub tag: String, -} - -impl Rebuild { - pub async fn run( - &self, - settings: &Settings, - store: SqliteStore, - database: &dyn Database, - ) -> Result<()> { - // keep it as a string and not an enum atm - // would be super cool to build this dynamically in the future - // eg register handles for rebuilding various tags without having to make this part of the - // binary big - match self.tag.as_str() { - "history" => { - self.rebuild_history(settings, store.clone(), database) - .await?; - } - - tag => bail!("unknown tag: {tag}"), - } - - Ok(()) - } - - async fn rebuild_history( - &self, - settings: &Settings, - store: SqliteStore, - database: &dyn Database, - ) -> Result<()> { - let encryption_key: [u8; 32] = encryption::load_key(settings)?.into(); - - let host_id = Settings::host_id().await?; - let history_store = HistoryStore::new(store, host_id, encryption_key); - - history_store.build(database).await?; - - #[cfg(feature = "daemon")] - daemon_cmd::emit_event(settings, atuin_daemon::DaemonEvent::HistoryRebuilt).await; - - Ok(()) - } -} diff --git a/crates/atuin/src/command/client/store/rekey.rs b/crates/atuin/src/command/client/store/rekey.rs deleted file mode 100644 index c92d2555..00000000 --- a/crates/atuin/src/command/client/store/rekey.rs +++ /dev/null @@ -1,41 +0,0 @@ -use clap::Args; -use eyre::Result; -use tokio::{fs::File, io::AsyncWriteExt}; - -use atuin_client::{ - encryption::{decode_key, generate_encoded_key, load_key}, - record::sqlite_store::SqliteStore, - record::store::Store, - settings::Settings, -}; - -#[derive(Args, Debug)] -pub struct Rekey { - /// The new key to use for encryption. Omit for a randomly-generated key - key: Option, -} - -impl Rekey { - pub async fn run(&self, settings: &Settings, store: SqliteStore) -> Result<()> { - let key = if let Some(key) = self.key.clone() { - println!("Re-encrypting store with specified key"); - - key - } else { - println!("Re-encrypting store with freshly-generated key"); - let (_, encoded) = generate_encoded_key()?; - encoded - }; - - let current_key: [u8; 32] = load_key(settings)?.into(); - let new_key: [u8; 32] = decode_key(key.clone())?.into(); - - store.re_encrypt(¤t_key, &new_key).await?; - - println!("Store rewritten. Saving new key"); - let mut file = File::create(settings.key_path.clone()).await?; - file.write_all(key.as_bytes()).await?; - - Ok(()) - } -} diff --git a/crates/atuin/src/command/client/store/verify.rs b/crates/atuin/src/command/client/store/verify.rs deleted file mode 100644 index 84bec96a..00000000 --- a/crates/atuin/src/command/client/store/verify.rs +++ /dev/null @@ -1,26 +0,0 @@ -use clap::Args; -use eyre::Result; - -use atuin_client::{ - encryption::load_key, - record::{sqlite_store::SqliteStore, store::Store}, - settings::Settings, -}; - -#[derive(Args, Debug)] -pub struct Verify {} - -impl Verify { - pub async fn run(&self, settings: &Settings, store: SqliteStore) -> Result<()> { - println!("Verifying local store can be decrypted with the current key"); - - let key = load_key(settings)?; - - match store.verify(&key.into()).await { - Ok(()) => println!("Local store encryption verified OK"), - Err(e) => println!("Failed to verify local store encryption: {e:?}"), - } - - Ok(()) - } -} diff --git a/crates/atuin/src/command/client/sync.rs b/crates/atuin/src/command/client/sync.rs deleted file mode 100644 index 5b8c2cb7..00000000 --- a/crates/atuin/src/command/client/sync.rs +++ /dev/null @@ -1,120 +0,0 @@ -use clap::Subcommand; -use eyre::{Result, WrapErr}; - -use atuin_client::{ - database::Database, - encryption, - history::store::HistoryStore, - record::{sqlite_store::SqliteStore, store::Store, sync}, - settings::Settings, -}; - -mod status; - -use crate::command::client::account; - -#[derive(Subcommand, Debug)] -#[command(infer_subcommands = true)] -pub enum Cmd { - /// Sync with the configured server - Sync { - /// Force re-download everything - #[arg(long, short)] - force: bool, - }, - - /// Login to the configured server - Login(account::login::Cmd), - - /// Log out - Logout, - - /// Register with the configured server - Register(account::register::Cmd), - - /// Print the encryption key for transfer to another machine - Key {}, - - /// Display the sync status - Status, -} - -impl Cmd { - pub async fn run( - self, - settings: Settings, - db: &impl Database, - store: SqliteStore, - ) -> Result<()> { - match self { - Self::Sync { force } => run(&settings, force, db, store).await, - Self::Login(l) => l.run(&settings, &store).await, - Self::Logout => account::logout::run().await, - Self::Register(r) => r.run(&settings).await, - Self::Status => status::run(&settings).await, - Self::Key {} => { - use atuin_client::encryption::{encode_key, load_key}; - let key = load_key(&settings).wrap_err("could not load encryption key")?; - - let encode = encode_key(&key).wrap_err("could not encode encryption key")?; - println!("{encode}"); - - Ok(()) - } - } - } -} - -async fn run( - settings: &Settings, - force: bool, - db: &impl Database, - store: SqliteStore, -) -> Result<()> { - let encryption_key: [u8; 32] = encryption::load_key(settings) - .context("could not load encryption key")? - .into(); - - let host_id = Settings::host_id().await?; - let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); - - let (uploaded, downloaded) = sync::sync(settings, &store, &encryption_key) - .await - .map_err(crate::print_error::format_sync_error)?; - - crate::sync::build(settings, &store, db, Some(&downloaded)).await?; - - println!("{uploaded}/{} up/down to record store", downloaded.len()); - - let history_length = db.history_count(true).await?; - let store_history_length = store.len_tag("history").await?; - - #[expect(clippy::cast_sign_loss)] - if history_length as u64 > store_history_length { - println!("{history_length} in history index, but {store_history_length} in history store"); - println!("Running automatic history store init..."); - - // Internally we use the global filter mode, so this context is ignored. - // don't recurse or loop here. - history_store.init_store(db).await?; - - println!("Re-running sync due to new records locally"); - - // we'll want to run sync once more, as there will now be stuff to upload - let (uploaded, downloaded) = sync::sync(settings, &store, &encryption_key) - .await - .map_err(crate::print_error::format_sync_error)?; - - crate::sync::build(settings, &store, db, Some(&downloaded)).await?; - - println!("{uploaded}/{} up/down to record store", downloaded.len()); - } - - println!( - "Sync complete! {} items in history database, force: {}", - db.history_count(true).await?, - force - ); - - Ok(()) -} diff --git a/crates/atuin/src/command/client/sync/status.rs b/crates/atuin/src/command/client/sync/status.rs deleted file mode 100644 index c992eb3e..00000000 --- a/crates/atuin/src/command/client/sync/status.rs +++ /dev/null @@ -1,37 +0,0 @@ -use crate::{SHA, VERSION}; -use atuin_client::{api_client, settings::Settings}; -use colored::Colorize; -use eyre::{Result, bail}; - -pub async fn run(settings: &Settings) -> Result<()> { - if !settings.logged_in().await? { - bail!("You are not logged in to a sync server - cannot show sync status"); - } - - let client = api_client::Client::new( - &settings.sync_address, - settings.sync_auth_token().await?, - settings.network_connect_timeout, - settings.network_timeout, - )?; - - let me = client.me().await?; - let last_sync = Settings::last_sync().await?; - - println!("Atuin v{VERSION} - Build rev {SHA}\n"); - - println!("{}", "[Local]".green()); - - if settings.auto_sync { - println!("Sync frequency: {}", settings.sync_frequency); - println!("Last sync: {}", last_sync.to_offset(settings.timezone.0)); - } - - if settings.auto_sync { - println!("{}", "[Remote]".green()); - println!("Address: {}", settings.sync_address); - println!("Username: {}", me.username); - } - - Ok(()) -} diff --git a/crates/atuin/src/command/client/wrapped.rs b/crates/atuin/src/command/client/wrapped.rs deleted file mode 100644 index 0e0c9f14..00000000 --- a/crates/atuin/src/command/client/wrapped.rs +++ /dev/null @@ -1,322 +0,0 @@ -use crossterm::style::{ResetColor, SetAttribute}; -use eyre::Result; -use std::collections::{HashMap, HashSet}; -use time::{Date, Duration, Month, OffsetDateTime, Time}; - -use atuin_client::{database::Database, settings::Settings, theme::Theme}; - -use atuin_history::stats::{Stats, compute}; - -#[derive(Debug)] -struct WrappedStats { - nav_commands: usize, - pkg_commands: usize, - error_rate: f64, - first_half_commands: Vec<(String, usize)>, - second_half_commands: Vec<(String, usize)>, - git_percentage: f64, - busiest_hour: Option<(String, usize)>, -} - -impl WrappedStats { - #[expect(clippy::too_many_lines, clippy::cast_precision_loss)] - fn new(settings: &Settings, stats: &Stats, history: &[atuin_client::history::History]) -> Self { - let nav_commands = stats - .top - .iter() - .filter(|(cmd, _)| { - let cmd = &cmd[0]; - cmd == "cd" || cmd == "ls" || cmd == "pwd" || cmd == "pushd" || cmd == "popd" - }) - .map(|(_, count)| count) - .sum(); - - let pkg_managers = [ - "cargo", - "npm", - "pnpm", - "yarn", - "pip", - "pip3", - "pipenv", - "poetry", - "pipx", - "uv", - "brew", - "apt", - "apt-get", - "apk", - "pacman", - "yay", - "paru", - "yum", - "dnf", - "dnf5", - "rpm", - "rpm-ostree", - "zypper", - "pkg", - "chocolatey", - "choco", - "scoop", - "winget", - "gem", - "bundle", - "shards", - "composer", - "gradle", - "maven", - "mvn", - "go get", - "nuget", - "dotnet", - "mix", - "hex", - "rebar3", - "nix", - "nix-env", - "cabal", - "opam", - ]; - - let pkg_commands = history - .iter() - .filter(|h| { - let cmd = h.command.clone(); - pkg_managers.iter().any(|pm| cmd.starts_with(pm)) - }) - .count(); - - // Error analysis - let mut command_errors: HashMap = HashMap::new(); // (total_uses, errors) - let midyear = history[0].timestamp + Duration::days(182); // Split year in half - - let mut first_half_commands: HashMap = HashMap::new(); - let mut second_half_commands: HashMap = HashMap::new(); - let mut hours: HashMap = HashMap::new(); - - for entry in history { - let cmd = entry - .command - .split_whitespace() - .next() - .unwrap_or("") - .to_string(); - let (total, errors) = command_errors.entry(cmd.clone()).or_insert((0, 0)); - *total += 1; - if entry.exit != 0 { - *errors += 1; - } - - // Track command evolution - if entry.timestamp < midyear { - *first_half_commands.entry(cmd.clone()).or_default() += 1; - } else { - *second_half_commands.entry(cmd).or_default() += 1; - } - - // Track hourly distribution - let local_time = entry - .timestamp - .to_offset(time::UtcOffset::current_local_offset().unwrap_or(settings.timezone.0)); - let hour = format!("{:02}:00", local_time.time().hour()); - *hours.entry(hour).or_default() += 1; - } - - let total_errors: usize = command_errors.values().map(|(_, errors)| errors).sum(); - let total_commands: usize = command_errors.values().map(|(total, _)| total).sum(); - let error_rate = total_errors as f64 / total_commands as f64; - - // Process command evolution data - let mut first_half: Vec<_> = first_half_commands.into_iter().collect(); - let mut second_half: Vec<_> = second_half_commands.into_iter().collect(); - first_half.sort_by_key(|(_, count)| std::cmp::Reverse(*count)); - second_half.sort_by_key(|(_, count)| std::cmp::Reverse(*count)); - first_half.truncate(5); - second_half.truncate(5); - - // Calculate git percentage - let git_commands: usize = stats - .top - .iter() - .filter(|(cmd, _)| cmd[0].starts_with("git")) - .map(|(_, count)| count) - .sum(); - let git_percentage = git_commands as f64 / stats.total_commands as f64; - - // Find busiest hour - let busiest_hour = hours.into_iter().max_by_key(|(_, count)| *count); - - Self { - nav_commands, - pkg_commands, - error_rate, - first_half_commands: first_half, - second_half_commands: second_half, - git_percentage, - busiest_hour, - } - } -} - -pub fn print_wrapped_header(year: i32) { - let reset = ResetColor; - let bold = SetAttribute(crossterm::style::Attribute::Bold); - - println!("{bold}╭────────────────────────────────────╮{reset}"); - println!("{bold}│ ATUIN WRAPPED {year} │{reset}"); - println!("{bold}│ Your Year in Shell History │{reset}"); - println!("{bold}╰────────────────────────────────────╯{reset}"); - println!(); -} - -#[expect(clippy::cast_precision_loss)] -fn print_fun_facts(wrapped_stats: &WrappedStats, stats: &Stats, year: i32) { - let reset = ResetColor; - let bold = SetAttribute(crossterm::style::Attribute::Bold); - - if wrapped_stats.git_percentage > 0.05 { - println!( - "{bold}🌟 You're a Git Power User!{reset} {bold}{:.1}%{reset} of your commands were Git operations\n", - wrapped_stats.git_percentage * 100.0 - ); - } - // Navigation patterns - let nav_percentage = wrapped_stats.nav_commands as f64 / stats.total_commands as f64 * 100.0; - if nav_percentage > 0.05 { - println!( - "{bold}🚀 You're a Navigator!{reset} {bold}{nav_percentage:.1}%{reset} of your time was spent navigating directories\n", - ); - } - - // Command vocabulary - println!( - "{bold}📚 Command Vocabulary{reset}: You know {bold}{}{reset} unique commands\n", - stats.unique_commands - ); - - // Package management - println!( - "{bold}📦 Package Management{reset}: You ran {bold}{}{reset} package-related commands\n", - wrapped_stats.pkg_commands - ); - - // Error patterns - let error_percentage = wrapped_stats.error_rate * 100.0; - println!( - "{bold}🚨 Error Analysis{reset}: Your commands failed {bold}{error_percentage:.1}%{reset} of the time\n", - ); - - // Command evolution - println!("🔍 Command Evolution:"); - - // print stats for each half and compare - println!(" {bold}Top Commands{reset} in the first half of {year}:"); - for (cmd, count) in wrapped_stats.first_half_commands.iter().take(3) { - println!(" {bold}{cmd}{reset} ({count} times)"); - } - - println!(" {bold}Top Commands{reset} in the second half of {year}:"); - for (cmd, count) in wrapped_stats.second_half_commands.iter().take(3) { - println!(" {bold}{cmd}{reset} ({count} times)"); - } - - // Find new favorite commands (in top 5 of second half but not in first half) - let first_half_set: HashSet<_> = wrapped_stats - .first_half_commands - .iter() - .map(|(cmd, _)| cmd) - .collect(); - let new_favorites: Vec<_> = wrapped_stats - .second_half_commands - .iter() - .filter(|(cmd, _)| !first_half_set.contains(cmd)) - .take(2) - .collect(); - - if !new_favorites.is_empty() { - println!(" {bold}New favorites{reset} in the second half:"); - for (cmd, count) in new_favorites { - println!(" {bold}{cmd}{reset} ({count} times)"); - } - } - - // Time patterns - if let Some((hour, count)) = &wrapped_stats.busiest_hour { - println!("\n🕘 Most Productive Hour: {bold}{hour}{reset} ({count} commands)"); - - // Night owl or early bird - let hour_num = hour - .split(':') - .next() - .unwrap_or("0") - .parse::() - .unwrap_or(0); - if hour_num >= 22 || hour_num <= 4 { - println!(" You're quite the night owl! 🦉"); - } else if (5..=7).contains(&hour_num) { - println!(" Early bird gets the worm! 🐦"); - } - } - - println!(); -} - -pub async fn run( - year: Option, - db: &impl Database, - settings: &Settings, - theme: &Theme, -) -> Result<()> { - let now = OffsetDateTime::now_utc().to_offset(settings.timezone.0); - let month = now.month(); - - // If we're in December, then wrapped is for the current year. If not, it's for the previous year - let year = year.unwrap_or_else(|| { - if month == Month::December { - now.year() - } else { - now.year() - 1 - } - }); - - let start = OffsetDateTime::new_in_offset( - Date::from_calendar_date(year, Month::January, 1).unwrap(), - Time::MIDNIGHT, - now.offset(), - ); - let end = OffsetDateTime::new_in_offset( - Date::from_calendar_date(year, Month::December, 31).unwrap(), - Time::MIDNIGHT + Duration::days(1) - Duration::nanoseconds(1), - now.offset(), - ); - - let history = db.range(start, end).await?; - if history.is_empty() { - println!( - "Your history for {year} is empty!\nMaybe 'atuin import' could help you import your previous history 🪄" - ); - return Ok(()); - } - - // Compute overall stats using existing functionality - let stats = compute(settings, &history, 10, 1).expect("Failed to compute stats"); - let wrapped_stats = WrappedStats::new(settings, &stats, &history); - - // Print wrapped format - print_wrapped_header(year); - - println!("🎉 In {year}, you typed {} commands!", stats.total_commands); - println!( - " That's ~{} commands every day\n", - stats.total_commands / 365 - ); - - println!("Your Top Commands:"); - atuin_history::stats::pretty_print(stats.clone(), 1, theme); - println!(); - - print_fun_facts(&wrapped_stats, &stats, year); - - Ok(()) -} diff --git a/crates/atuin/src/command/contributors.rs b/crates/atuin/src/command/contributors.rs deleted file mode 100644 index 452fd335..00000000 --- a/crates/atuin/src/command/contributors.rs +++ /dev/null @@ -1,5 +0,0 @@ -static CONTRIBUTORS: &str = include_str!("CONTRIBUTORS"); - -pub fn run() { - println!("\n{CONTRIBUTORS}"); -} diff --git a/crates/atuin/src/command/external.rs b/crates/atuin/src/command/external.rs deleted file mode 100644 index 5d875e9d..00000000 --- a/crates/atuin/src/command/external.rs +++ /dev/null @@ -1,102 +0,0 @@ -use std::fmt::Write as _; -use std::process::Command; -use std::{io, process}; - -#[cfg(feature = "client")] -use atuin_client::plugin::{OfficialPluginRegistry, PluginContext}; -use clap::CommandFactory; -use clap::builder::{StyledStr, Styles}; -use eyre::Result; - -use crate::Atuin; - -pub fn run(args: &[String]) -> Result<()> { - let subcommand = &args[0]; - let bin = format!("atuin-{subcommand}"); - let mut cmd = Command::new(&bin); - cmd.args(&args[1..]); - - #[cfg(feature = "client")] - let context = PluginContext::new(subcommand); - - let spawn_result = match cmd.spawn() { - Ok(child) => Ok(child), - Err(e) => match e.kind() { - io::ErrorKind::NotFound => { - let output = render_not_found(subcommand, &bin); - Err(output) - } - _ => Err(e.to_string().into()), - }, - }; - - match spawn_result { - Ok(mut child) => { - let status = child.wait()?; - if status.success() { - Ok(()) - } else { - #[cfg(feature = "client")] - drop(context); - - process::exit(status.code().unwrap_or(1)); - } - } - Err(e) => { - eprintln!("{}", e.ansi()); - - #[cfg(feature = "client")] - drop(context); - - process::exit(1); - } - } -} - -fn render_not_found(subcommand: &str, bin: &str) -> StyledStr { - let mut output = StyledStr::new(); - let styles = Styles::styled(); - - let error = styles.get_error(); - let invalid = styles.get_invalid(); - let literal = styles.get_literal(); - - #[cfg(feature = "client")] - { - let registry = OfficialPluginRegistry::new(); - - // Check if this is an official plugin - if let Some(install_message) = registry.get_install_message(subcommand) { - let _ = write!(output, "{error}error:{error:#} "); - let _ = write!( - output, - "'{invalid}{subcommand}{invalid:#}' is an official atuin plugin, but it's not installed" - ); - let _ = write!(output, "\n\n"); - let _ = write!(output, "{install_message}"); - return output; - } - } - - let mut atuin_cmd = Atuin::command(); - let usage = atuin_cmd.render_usage(); - - let _ = write!(output, "{error}error:{error:#} "); - let _ = write!( - output, - "unrecognized subcommand '{invalid}{subcommand}{invalid:#}' " - ); - let _ = write!( - output, - "and no executable named '{invalid}{bin}{invalid:#}' found in your PATH" - ); - let _ = write!(output, "\n\n"); - let _ = write!(output, "{usage}"); - let _ = write!(output, "\n\n"); - let _ = write!( - output, - "For more information, try '{literal}--help{literal:#}'." - ); - - output -} diff --git a/crates/atuin/src/command/gen_completions.rs b/crates/atuin/src/command/gen_completions.rs deleted file mode 100644 index 10d4f689..00000000 --- a/crates/atuin/src/command/gen_completions.rs +++ /dev/null @@ -1,84 +0,0 @@ -use clap::{CommandFactory, Parser, ValueEnum}; -use clap_complete::{Generator, Shell, generate, generate_to}; -use clap_complete_nushell::Nushell; -use eyre::Result; - -// clap put nushell completions into a separate package due to the maintainers -// being a little less committed to support them. -// This means we have to do a tiny bit of legwork to combine these completions -// into one command. -#[derive(Debug, Clone, ValueEnum)] -#[value(rename_all = "lower")] -pub enum GenShell { - Bash, - Elvish, - Fish, - Nushell, - PowerShell, - Zsh, -} - -impl Generator for GenShell { - fn file_name(&self, name: &str) -> String { - match self { - // clap_complete - Self::Bash => Shell::Bash.file_name(name), - Self::Elvish => Shell::Elvish.file_name(name), - Self::Fish => Shell::Fish.file_name(name), - Self::PowerShell => Shell::PowerShell.file_name(name), - Self::Zsh => Shell::Zsh.file_name(name), - - // clap_complete_nushell - Self::Nushell => Nushell.file_name(name), - } - } - - fn generate(&self, cmd: &clap::Command, buf: &mut dyn std::io::prelude::Write) { - match self { - // clap_complete - Self::Bash => Shell::Bash.generate(cmd, buf), - Self::Elvish => Shell::Elvish.generate(cmd, buf), - Self::Fish => Shell::Fish.generate(cmd, buf), - Self::PowerShell => Shell::PowerShell.generate(cmd, buf), - Self::Zsh => Shell::Zsh.generate(cmd, buf), - - // clap_complete_nushell - Self::Nushell => Nushell.generate(cmd, buf), - } - } -} - -#[derive(Debug, Parser)] -pub struct Cmd { - /// Set the shell for generating completions - #[arg(long, short)] - shell: GenShell, - - /// Set the output directory - #[arg(long, short)] - out_dir: Option, -} - -impl Cmd { - pub fn run(self) -> Result<()> { - let Cmd { shell, out_dir } = self; - - let mut cli = crate::Atuin::command(); - - match out_dir { - Some(out_dir) => { - generate_to(shell, &mut cli, env!("CARGO_PKG_NAME"), &out_dir)?; - } - None => { - generate( - shell, - &mut cli, - env!("CARGO_PKG_NAME"), - &mut std::io::stdout(), - ); - } - } - - Ok(()) - } -} diff --git a/crates/atuin/src/command/mod.rs b/crates/atuin/src/command/mod.rs deleted file mode 100644 index 8aac4062..00000000 --- a/crates/atuin/src/command/mod.rs +++ /dev/null @@ -1,162 +0,0 @@ -use clap::Subcommand; -use eyre::Result; - -#[cfg(not(windows))] -use rustix::{fs::Mode, process::umask}; - -#[cfg(feature = "client")] -mod client; - -mod contributors; - -mod gen_completions; - -mod external; - -#[derive(Subcommand)] -#[command(infer_subcommands = true)] -#[expect(clippy::large_enum_variant)] -pub enum AtuinCmd { - #[cfg(feature = "client")] - #[command(flatten)] - Client(client::Cmd), - - /// PTY proxy for atuin - #[cfg(feature = "pty-proxy")] - #[command(alias = "hex")] - PtyProxy(atuin_pty_proxy::PtyProxy), - - /// Generate a UUID - Uuid, - - Contributors, - - /// Generate shell completions - GenCompletions(gen_completions::Cmd), - - #[command(external_subcommand)] - External(Vec), -} - -impl AtuinCmd { - pub fn run(self) -> Result<()> { - #[cfg(not(windows))] - { - // set umask before we potentially open/create files - // or in other words, 077. Do not allow any access to any other user - let mode = Mode::RWXG | Mode::RWXO; - umask(mode); - } - - match self { - #[cfg(feature = "client")] - Self::Client(client) => client.run(), - - #[cfg(feature = "pty-proxy")] - Self::PtyProxy(proxy) => { - run_pty_proxy(proxy); - Ok(()) - } - - Self::Contributors => { - contributors::run(); - Ok(()) - } - Self::Uuid => { - println!("{}", atuin_common::utils::uuid_v7().as_simple()); - Ok(()) - } - Self::GenCompletions(gen_completions) => gen_completions.run(), - Self::External(args) => external::run(&args), - } - } -} - -#[cfg(all(feature = "pty-proxy", unix))] -fn run_pty_proxy(proxy: atuin_pty_proxy::PtyProxy) { - #[cfg(feature = "daemon")] - proxy.run(semantic_command_capture_sink()); - - #[cfg(not(feature = "daemon"))] - proxy.run(None); -} - -#[cfg(all(feature = "pty-proxy", not(unix)))] -fn run_pty_proxy(_proxy: atuin_pty_proxy::PtyProxy) { - eprintln!("atuin pty-proxy currently supports unix platforms"); - std::process::exit(1); -} - -#[cfg(all(feature = "daemon", feature = "pty-proxy", unix))] -fn semantic_command_capture_sink() -> Option { - use std::sync::mpsc; - use std::time::Duration; - - if is_truthy_env("ATUIN_TERMINAL") { - return None; - } - - let settings = atuin_client::settings::Settings::new().ok()?; - let (tx, rx) = mpsc::sync_channel::(128); - - std::thread::spawn(move || { - let Ok(runtime) = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - else { - return; - }; - - while let Ok(first) = rx.recv() { - let mut batch = vec![first]; - - while batch.len() < 64 { - match rx.recv_timeout(Duration::from_millis(25)) { - Ok(capture) => batch.push(capture), - Err(mpsc::RecvTimeoutError::Timeout | mpsc::RecvTimeoutError::Disconnected) => { - break; - } - } - } - - runtime.block_on(send_semantic_command_captures(&settings, batch)); - } - }); - - Some(Box::new(move |capture| { - let _ = tx.try_send(capture); - })) -} - -#[cfg(all(feature = "daemon", feature = "pty-proxy", unix))] -#[inline] -fn is_truthy_env(name: &str) -> bool { - std::env::var(name) - .ok() - .as_ref() - .is_some_and(|value| !value.trim().is_empty() && value.trim() != "false") -} - -#[cfg(all(feature = "daemon", feature = "pty-proxy", unix))] -async fn send_semantic_command_captures( - settings: &atuin_client::settings::Settings, - batch: Vec, -) { - let captures = batch - .into_iter() - .map(|capture| atuin_daemon::semantic::CommandCapture { - prompt: capture.prompt, - command: capture.command, - output: capture.output, - exit_code: capture.exit_code, - history_id: capture.history_id, - session_id: capture.session_id, - output_truncated: capture.output_truncated, - output_observed_bytes: capture.output_observed_bytes, - }) - .collect(); - - if let Ok(mut client) = atuin_daemon::SemanticClient::from_settings(settings).await { - let _ = client.record_commands(captures).await; - } -} diff --git a/crates/atuin/src/main.rs b/crates/atuin/src/main.rs deleted file mode 100644 index 255db36a..00000000 --- a/crates/atuin/src/main.rs +++ /dev/null @@ -1,61 +0,0 @@ -#![warn(clippy::pedantic, clippy::nursery)] -#![allow(clippy::use_self, clippy::missing_const_for_fn)] // not 100% reliable - -use clap::Parser; -use clap::builder::Styles; -use clap::builder::styling::{AnsiColor, Effects}; -use eyre::Result; - -use command::AtuinCmd; - -mod command; - -#[cfg(feature = "sync")] -mod print_error; -#[cfg(feature = "sync")] -mod sync; - -const VERSION: &str = env!("CARGO_PKG_VERSION"); -const SHA: &str = env!("GIT_HASH"); - -const LONG_VERSION: &str = concat!(env!("CARGO_PKG_VERSION"), " (", env!("GIT_HASH"), ")"); - -static HELP_TEMPLATE: &str = "\ -{before-help}{name} {version} -{author} -{about} - -{usage-heading} - {usage} - -{all-args}{after-help}"; - -const STYLES: Styles = Styles::styled() - .header(AnsiColor::Yellow.on_default().effects(Effects::BOLD)) - .usage(AnsiColor::Green.on_default().effects(Effects::BOLD)) - .literal(AnsiColor::Green.on_default().effects(Effects::BOLD)) - .placeholder(AnsiColor::Green.on_default()); - -/// Magical shell history -#[derive(Parser)] -#[command( - author = "Ellie Huxtable ", - version = VERSION, - long_version = LONG_VERSION, - help_template(HELP_TEMPLATE), - styles = STYLES, -)] -struct Atuin { - #[command(subcommand)] - atuin: AtuinCmd, -} - -impl Atuin { - fn run(self) -> Result<()> { - self.atuin.run() - } -} - -fn main() -> Result<()> { - Atuin::parse().run() -} diff --git a/crates/atuin/src/print_error.rs b/crates/atuin/src/print_error.rs deleted file mode 100644 index a6da283d..00000000 --- a/crates/atuin/src/print_error.rs +++ /dev/null @@ -1,123 +0,0 @@ -use std::io::IsTerminal; - -use atuin_client::record::sync::SyncError; -use colored::Colorize; -use crossterm::terminal; - -/// Print a prominent error to stderr. Colored and box-bordered when stderr is -/// a TTY, plain "Error: ..." header otherwise. The description is word-wrapped -/// to the terminal width (capped at 100 columns) so the message stays readable. -pub fn print_error(title: &str, description: &str) { - let is_tty = std::io::stderr().is_terminal(); - let width = if is_tty { - terminal::size().map_or(80, |(w, _)| w as usize) - } else { - 80 - } - .min(100); - - eprintln!(); - if is_tty { - let bar = "━".repeat(width).red().bold().to_string(); - eprintln!("{bar}"); - eprintln!(" {} {}", "✗".red().bold(), title.red().bold()); - eprintln!("{bar}"); - } else { - eprintln!("Error: {title}"); - eprintln!("{}", "-".repeat(width)); - } - eprintln!(); - - for line in wrap_text(description, width.saturating_sub(2)) { - eprintln!(" {line}"); - } - eprintln!(); -} - -/// Convert a `SyncError` into an `eyre::Report`, exiting on `WrongKey` after -/// painting the prominent banner. -pub fn format_sync_error(e: SyncError) -> eyre::Report { - if matches!(e, SyncError::WrongKey) { - print_error( - "Wrong encryption key", - "Your local encryption key cannot decrypt the data on the server. \ - This usually means another machine wrote records with a different key.\n\n\ - To fix this, find the correct key by running `atuin key` on a machine that \ - already syncs successfully, then run `atuin store rekey ` here.", - ); - std::process::exit(1); - } - e.into() -} - -fn wrap_text(text: &str, width: usize) -> Vec { - let mut out = Vec::new(); - for paragraph in text.split('\n') { - let mut line = String::new(); - let mut line_len = 0; - for word in paragraph.split_whitespace() { - let word_len = word.chars().count(); - if !line.is_empty() && line_len + 1 + word_len > width { - out.push(std::mem::take(&mut line)); - line_len = 0; - } - if !line.is_empty() { - line.push(' '); - line_len += 1; - } - line.push_str(word); - line_len += word_len; - } - // Push every paragraph's final line (even empty) so `\n\n` in the - // input becomes a blank line in the output. - out.push(line); - } - while out.first().is_some_and(String::is_empty) { - out.remove(0); - } - while out.last().is_some_and(String::is_empty) { - out.pop(); - } - out -} - -#[cfg(test)] -mod tests { - use super::wrap_text; - - #[test] - fn wraps_long_text() { - let lines = wrap_text("the quick brown fox jumps over the lazy dog", 20); - for line in &lines { - assert!(line.chars().count() <= 20, "line too long: {line:?}"); - } - assert_eq!( - lines.join(" "), - "the quick brown fox jumps over the lazy dog" - ); - } - - #[test] - fn preserves_explicit_newlines() { - let lines = wrap_text("first line\nsecond line", 80); - assert_eq!(lines, vec!["first line", "second line"]); - } - - #[test] - fn handles_word_longer_than_width() { - let lines = wrap_text("short superlongword more", 5); - assert_eq!(lines, vec!["short", "superlongword", "more"]); - } - - #[test] - fn preserves_blank_lines_between_paragraphs() { - let lines = wrap_text("first paragraph\n\nsecond paragraph", 80); - assert_eq!(lines, vec!["first paragraph", "", "second paragraph"]); - } - - #[test] - fn trims_leading_and_trailing_blank_lines() { - let lines = wrap_text("\n\nbody\n\n", 80); - assert_eq!(lines, vec!["body"]); - } -} diff --git a/crates/atuin/src/shell/.gitattributes b/crates/atuin/src/shell/.gitattributes deleted file mode 100644 index fae8897c..00000000 --- a/crates/atuin/src/shell/.gitattributes +++ /dev/null @@ -1 +0,0 @@ -* eol=lf diff --git a/crates/atuin/src/shell/atuin.bash b/crates/atuin/src/shell/atuin.bash deleted file mode 100644 index 8b540bd7..00000000 --- a/crates/atuin/src/shell/atuin.bash +++ /dev/null @@ -1,725 +0,0 @@ -# Include guard -if [[ ${__atuin_initialized-} == true ]]; then - false -elif [[ $- != *i* ]]; then - # Enable only in interactive shells - false -elif ((BASH_VERSINFO[0] < 3 || BASH_VERSINFO[0] == 3 && BASH_VERSINFO[1] < 1)); then - # Require bash >= 3.1 - [[ -t 2 ]] && printf 'atuin: requires bash >= 3.1 for the integration.\n' >&2 - false -else # (include guard) beginning of main content -#------------------------------------------------------------------------------ -__atuin_initialized=true - -if [[ -z "${ATUIN_SESSION:-}" || "${ATUIN_SHLVL:-}" != "$SHLVL" ]]; then - ATUIN_SESSION=$(atuin uuid) - export ATUIN_SESSION - export ATUIN_SHLVL=$SHLVL -fi -ATUIN_STTY=$(stty -g) -ATUIN_HISTORY_ID="" - -__atuin_osc133_command_executed() { - [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]] || return - [[ -n "${ATUIN_HISTORY_ID:-}" && "$ATUIN_HISTORY_ID" != "__bash_preexec_failure__" ]] || return - - printf '\033]133;C\a' -} - -__atuin_osc133_command_finished() { - [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]] || return - [[ -n "${ATUIN_HISTORY_ID:-}" && "$ATUIN_HISTORY_ID" != "__bash_preexec_failure__" ]] || return - - printf '\033]133;D;%s;history_id=%s;session_id=%s\a' "$1" "$ATUIN_HISTORY_ID" "${ATUIN_SESSION:-}" -} - -__atuin_osc133_prompt_start=$'\001\033]133;A;cl=line\a\002' -__atuin_osc133_prompt_end=$'\001\033]133;B\a\002' - -__atuin_osc133_wrap_prompt() { - local __atuin_prompt="${PS1-}" - __atuin_prompt="${__atuin_prompt//$__atuin_osc133_prompt_start/}" - __atuin_prompt="${__atuin_prompt//$__atuin_osc133_prompt_end/}" - - if [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]]; then - PS1="${__atuin_osc133_prompt_start}${__atuin_prompt}${__atuin_osc133_prompt_end}" - else - PS1="$__atuin_prompt" - fi -} - -export ATUIN_PREEXEC_BACKEND=$SHLVL:none -__atuin_update_preexec_backend() { - if [[ ${BLE_ATTACHED-} ]]; then - ATUIN_PREEXEC_BACKEND=$SHLVL:blesh-${BLE_VERSION-} - elif [[ ${bash_preexec_imported-} ]]; then - ATUIN_PREEXEC_BACKEND=$SHLVL:bash-preexec - elif [[ ${__bp_imported-} ]]; then - ATUIN_PREEXEC_BACKEND="$SHLVL:bash-preexec (old)" - else - ATUIN_PREEXEC_BACKEND=$SHLVL:unknown - fi -} - -__atuin_preexec() { - # Workaround for old versions of bash-preexec - if [[ ! ${BLE_ATTACHED-} ]]; then - # In older versions of bash-preexec, the preexec hook may be called - # even for the commands run by keybindings. There is no general and - # robust way to detect the command for keybindings, but at least we - # want to exclude Atuin's keybindings. When the preexec hook is called - # for a keybinding, the preexec hook for the user command will not - # fire, so we instead set a fake ATUIN_HISTORY_ID here to notify - # __atuin_precmd of this failure. - if [[ $BASH_COMMAND != "$1" ]]; then - case $BASH_COMMAND in - '__atuin_history'* | '__atuin_widget_run'* | '__atuin_bash42_dispatch'*) - ATUIN_HISTORY_ID=__bash_preexec_failure__ - return 0 ;; - esac - fi - fi - - # Note: We update ATUIN_PREEXEC_BACKEND on every preexec because blesh's - # attaching state can dynamically change. - __atuin_update_preexec_backend - - local id - id=$(atuin history start -- "$1" 2>/dev/null) - export ATUIN_HISTORY_ID=$id - [[ -n ${__atuin_skip_osc133:-} ]] || __atuin_osc133_command_executed - __atuin_preexec_time=${EPOCHREALTIME-} -} - -__atuin_precmd() { - local EXIT=$? __atuin_precmd_time=${EPOCHREALTIME-} - - __atuin_osc133_wrap_prompt - - [[ ! $ATUIN_HISTORY_ID ]] && return - - # If the previous preexec hook failed, we manually call __atuin_preexec - local __atuin_skip_osc133="" - if [[ $ATUIN_HISTORY_ID == __bash_preexec_failure__ ]]; then - # This is the command extraction code taken from bash-preexec - local previous_command - previous_command=$( - export LC_ALL=C HISTTIMEFORMAT='' - builtin history 1 | sed '1 s/^ *[0-9][0-9]*[* ] //' - ) - __atuin_skip_osc133=1 - __atuin_preexec "$previous_command" - fi - - local duration="" - # shellcheck disable=SC2154,SC2309 - if [[ ${BLE_ATTACHED-} && ${_ble_exec_time_ata-} ]]; then - # With ble.sh, we utilize the shell variable `_ble_exec_time_ata` - # recorded by ble.sh. It is more accurate than the measurements by - # Atuin, which includes the spawn cost of Atuin. ble.sh uses the - # special shell variable `EPOCHREALTIME` in bash >= 5.0 with the - # microsecond resolution, or the builtin `time` in bash < 5.0 with the - # millisecond resolution. - duration=${_ble_exec_time_ata}000 - elif ((BASH_VERSINFO[0] >= 5)); then - # We calculate the high-resolution duration based on EPOCHREALTIME - # (bash >= 5.0) recorded by precmd/preexec, though it might not be as - # accurate as `_ble_exec_time_ata` provided by ble.sh because it - # includes the extra time of the precmd/preexec handling. Since Bash - # does not offer floating-point arithmetic, we remove the non-digit - # characters and perform the integral arithmetic. The fraction part of - # EPOCHREALTIME is fixed to have 6 digits in Bash. We remove all the - # non-digit characters because the decimal point is not necessarily a - # period depending on the locale. - duration=$((${__atuin_precmd_time//[!0-9]} - ${__atuin_preexec_time//[!0-9]})) - if ((duration >= 0)); then - duration=${duration}000 - else - duration="" # clear the result on overflow - fi - fi - - [[ -n ${__atuin_skip_osc133:-} ]] || __atuin_osc133_command_finished "$EXIT" - (ATUIN_LOG=error atuin history end --exit "$EXIT" ${duration:+"--duration=$duration"} -- "$ATUIN_HISTORY_ID" &) >/dev/null 2>&1 - export ATUIN_HISTORY_ID="" -} - -__atuin_set_ret_value() { - return ${1:+"$1"} -} - -#------------------------------------------------------------------------------ -# section: __atuin_accept_line -# -# The function "__atuin_accept_line" is kept for backward compatibility of the -# direct use of __atuin_history in keybindings by users. - -# The shell function `__atuin_evaluate_prompt` evaluates prompt sequences in -# $PS1. We switch the implementation of the shell function -# `__atuin_evaluate_prompt` based on the Bash version because the expansion -# ${PS1@P} is only available in bash >= 4.4. -if ((BASH_VERSINFO[0] >= 5 || BASH_VERSINFO[0] == 4 && BASH_VERSINFO[1] >= 4)); then - __atuin_evaluate_prompt() { - __atuin_set_ret_value "${__bp_last_ret_value-}" "${__bp_last_argument_prev_command-}" - __atuin_prompt=${PS1@P} - - # Note: Strip the control characters ^A (\001) and ^B (\002), which - # Bash internally uses to enclose the escape sequences. They are - # produced by '\[' and '\]', respectively, in $PS1 and used to tell - # Bash that the strings inbetween do not contribute to the prompt - # width. After the prompt width calculation, Bash strips those control - # characters before outputting it to the terminal. We here strip these - # characters following Bash's behavior. - __atuin_prompt=${__atuin_prompt//[$'\001\002']} - - # Count the number of newlines contained in $__atuin_prompt - __atuin_prompt_offset=${__atuin_prompt//[!$'\n']} - __atuin_prompt_offset=${#__atuin_prompt_offset} - } -else - __atuin_evaluate_prompt() { - __atuin_prompt='$ ' - __atuin_prompt_offset=0 - } -fi - -# The shell function `__atuin_clear_prompt N` outputs terminal control -# sequences to clear the contents of the current and N previous lines. After -# clearing, the cursor is placed at the beginning of the N-th previous line. -__atuin_clear_prompt_cache=() -__atuin_clear_prompt() { - local offset=$1 - if [[ ! ${__atuin_clear_prompt_cache[offset]+set} ]]; then - if [[ ! ${__atuin_clear_prompt_cache[0]+set} ]]; then - __atuin_clear_prompt_cache[0]=$'\r'$(tput el 2>/dev/null || tput ce 2>/dev/null) - fi - if ((offset > 0)); then - __atuin_clear_prompt_cache[offset]=${__atuin_clear_prompt_cache[0]}$( - tput cuu "$offset" 2>/dev/null || tput UP "$offset" 2>/dev/null - tput dl "$offset" 2>/dev/null || tput DL "$offset" 2>/dev/null - tput il "$offset" 2>/dev/null || tput AL "$offset" 2>/dev/null - ) - fi - fi - printf '%s' "${__atuin_clear_prompt_cache[offset]}" -} - -__atuin_accept_line() { - local __atuin_command=$1 - - # Reprint the prompt, accounting for multiple lines - local __atuin_prompt __atuin_prompt_offset - __atuin_evaluate_prompt - __atuin_clear_prompt "$__atuin_prompt_offset" - printf '%s\n' "$__atuin_prompt$__atuin_command" - - # Add it to the bash history - history -s "$__atuin_command" - - # Assuming bash-preexec - # Invoke every function in the preexec array - local __atuin_preexec_function - local __atuin_preexec_function_ret_value - local __atuin_preexec_ret_value=0 - for __atuin_preexec_function in "${preexec_functions[@]:-}"; do - if type -t "$__atuin_preexec_function" 1>/dev/null; then - __atuin_set_ret_value "${__bp_last_ret_value:-}" - "$__atuin_preexec_function" "$__atuin_command" - __atuin_preexec_function_ret_value=$? - if [[ $__atuin_preexec_function_ret_value != 0 ]]; then - __atuin_preexec_ret_value=$__atuin_preexec_function_ret_value - fi - fi - done - - # If extdebug is turned on and any preexec function returns non-zero - # exit status, we do not run the user command. - if ! { shopt -q extdebug && ((__atuin_preexec_ret_value)); }; then - # Note: When a child Bash session is started by enter_accept, if the - # environment variable READLINE_POINT is present, bash-preexec in the - # child session does not fire preexec at all because it considers we - # are inside Atuin's keybinding of the current session. To avoid - # propagating the environment variable to the child session, we remove - # the export attribute of READLINE_LINE and READLINE_POINT. - export -n READLINE_LINE READLINE_POINT - - # Juggle the terminal settings so that the command can be interacted - # with - local __atuin_stty_backup - __atuin_stty_backup=$(stty -g) - stty "$ATUIN_STTY" - - # Execute the command. Note: We need to record $? and $_ after the - # user command within the same call of "eval" because $_ is otherwise - # overwritten by the last argument of "eval". - __atuin_set_ret_value "${__bp_last_ret_value-}" "${__bp_last_argument_prev_command-}" - eval -- "$__atuin_command"$'\n__bp_last_ret_value=$? __bp_last_argument_prev_command=$_' - - stty "$__atuin_stty_backup" - fi - - # Execute preprompt commands - local __atuin_prompt_command - for __atuin_prompt_command in "${PROMPT_COMMAND[@]}"; do - __atuin_set_ret_value "${__bp_last_ret_value-}" "${__bp_last_argument_prev_command-}" - eval -- "$__atuin_prompt_command" - done - # Bash will redraw only the line with the prompt after we finish, - # so to work for a multiline prompt we need to print it ourselves, - # then go to the beginning of the last line. - __atuin_evaluate_prompt - printf '%s' "$__atuin_prompt" - __atuin_clear_prompt 0 -} - -#------------------------------------------------------------------------------ - -# Check if tmux popup is available (tmux >= 3.2) -__atuin_tmux_popup_check() { - [[ -n "${TMUX-}" ]] || return 1 - [[ "${ATUIN_TMUX_POPUP:-true}" != "false" ]] || return 1 - - # https://github.com/tmux/tmux/wiki/FAQ#how-often-is-tmux-released-what-is-the-version-number-scheme - local tmux_version - tmux_version=$(tmux -V 2>/dev/null | sed -n 's/^[^0-9]*\([0-9][0-9]*\.[0-9][0-9]*\).*/\1/p') # Could have used grep... - [[ -z "$tmux_version" ]] && return 1 - - local m1 m2 - m1=${tmux_version%%.*} - m2=${tmux_version#*.} - m2=${m2%%.*} - [[ "$m1" =~ ^[0-9]+$ ]] || return 1 - [[ "$m2" =~ ^[0-9]+$ ]] || m2=0 - (( m1 > 3 || (m1 == 3 && m2 >= 2) )) -} - -# Use global variable to fix scope issues with traps -__atuin_popup_tmpdir="" -__atuin_tmux_popup_cleanup() { - [[ -n "$__atuin_popup_tmpdir" && -d "$__atuin_popup_tmpdir" ]] && command rm -rf "$__atuin_popup_tmpdir" - __atuin_popup_tmpdir="" -} - -__atuin_search_cmd() { - local -a search_args=("$@") - - if __atuin_tmux_popup_check; then - __atuin_popup_tmpdir=$(mktemp -d) || return 1 - local result_file="$__atuin_popup_tmpdir/result" - - trap '__atuin_tmux_popup_cleanup' EXIT HUP INT TERM - - local escaped_query escaped_args - escaped_query=$(printf '%s' "$READLINE_LINE" | sed "s/'/'\\\\''/g") - escaped_args="" - for arg in "${search_args[@]}"; do - escaped_args+=" '$(printf '%s' "$arg" | sed "s/'/'\\\\''/g")'" - done - - # In the popup, atuin goes to terminal, stderr goes to file - local cdir popup_width popup_height - cdir=$(pwd) - popup_width="${ATUIN_TMUX_POPUP_WIDTH:-80%}" # Keep default value anyways - popup_height="${ATUIN_TMUX_POPUP_HEIGHT:-60%}" - tmux display-popup -d "$cdir" -w "$popup_width" -h "$popup_height" -E -E -- \ - sh -c "PATH='$PATH' ATUIN_SESSION='$ATUIN_SESSION' ATUIN_SHELL=bash ATUIN_LOG=error ATUIN_QUERY='$escaped_query' atuin search $escaped_args -i 2>'$result_file'" - - if [[ -f "$result_file" ]]; then - cat "$result_file" - fi - - __atuin_tmux_popup_cleanup - trap - EXIT HUP INT TERM - else - ATUIN_SHELL=bash ATUIN_LOG=error ATUIN_QUERY=$READLINE_LINE atuin search "${search_args[@]}" -i 3>&1 1>&2 2>&3 3>&- - fi -} - -__atuin_history() { - # Default action of the up key: When this function is called with the first - # argument `--shell-up-key-binding`, we perform Atuin's history search only - # when the up key is supposed to cause the history movement in the original - # binding. We do this only for ble.sh because the up key always invokes - # the history movement in the plain Bash. - if [[ ${BLE_ATTACHED-} && ${1-} == --shell-up-key-binding ]]; then - # When the current cursor position is not in the first line, the up key - # should move the cursor to the previous line. While the selection is - # performed, the up key should not start the history search. - # shellcheck disable=SC2154 # Note: these variables are set by ble.sh - if [[ ${_ble_edit_str::_ble_edit_ind} == *$'\n'* || $_ble_edit_mark_active ]]; then - ble/widget/@nomarked backward-line - local status=$? - READLINE_LINE=$_ble_edit_str - READLINE_POINT=$_ble_edit_ind - READLINE_MARK=$_ble_edit_mark - return "$status" - fi - fi - - # READLINE_LINE and READLINE_POINT are only supported by bash >= 4.0 or - # ble.sh. When it is not supported, we clear them to suppress strange - # behaviors. - [[ ${BLE_ATTACHED-} ]] || ((BASH_VERSINFO[0] >= 4)) || - READLINE_LINE="" READLINE_POINT=0 - - local __atuin_output - if ! __atuin_output=$(__atuin_search_cmd "$@"); then - [[ $__atuin_output ]] && printf '%s\n' "$__atuin_output" >&2 - return 1 - fi - - # We do nothing when the search is canceled. - [[ $__atuin_output ]] || return 0 - - if [[ $__atuin_output == __atuin_accept__:* ]]; then - __atuin_output=${__atuin_output#__atuin_accept__:} - - if [[ ${BLE_ATTACHED-} ]]; then - ble-edit/content/reset-and-check-dirty "$__atuin_output" - ble/widget/accept-line - READLINE_LINE="" - elif [[ ${__atuin_macro_chain_keymap-} ]]; then - READLINE_LINE=$__atuin_output - bind -m "$__atuin_macro_chain_keymap" '"'"$__atuin_macro_chain"'": '"$__atuin_macro_accept_line" - else - __atuin_accept_line "$__atuin_output" - READLINE_LINE="" - fi - - READLINE_POINT=${#READLINE_LINE} - else - READLINE_LINE=$__atuin_output - READLINE_POINT=${#READLINE_LINE} - if [[ ! ${BLE_ATTACHED-} ]] && ((BASH_VERSINFO[0] < 4)) && [[ ${__atuin_macro_chain_keymap-} ]]; then - bind -m "$__atuin_macro_chain_keymap" '"'"$__atuin_macro_chain"'": '"$__atuin_macro_insert_line" - fi - fi -} - -__atuin_initialize_blesh() { - # shellcheck disable=SC2154 - [[ ${BLE_VERSION-} ]] && ((_ble_version >= 400)) || return 0 - - ble-import contrib/integration/bash-preexec - - # Define and register an autosuggestion source for ble.sh's auto-complete. - # If you'd like to overwrite this, define the same name of shell function - # after the $(atuin init bash) line in your .bashrc. If you do not need - # the auto-complete source by Atuin, please add the following code to - # remove the entry after the $(atuin init bash) line in your .bashrc: - # - # ble/util/import/eval-after-load core-complete ' - # ble/array#remove _ble_complete_auto_source atuin-history' - # - function ble/complete/auto-complete/source:atuin-history { - local suggestion - suggestion=$(ATUIN_QUERY="$_ble_edit_str" atuin search --cmd-only --limit 1 --search-mode prefix 2>/dev/null) - [[ $suggestion == "$_ble_edit_str"?* ]] || return 1 - ble/complete/auto-complete/enter h 0 "${suggestion:${#_ble_edit_str}}" '' "$suggestion" - } - ble/util/import/eval-after-load core-complete ' - ble/array#unshift _ble_complete_auto_source atuin-history' - - # @env BLE_SESSION_ID: `atuin doctor` references the environment variable - # BLE_SESSION_ID. We explicitly export the variable because it was not - # exported in older versions of ble.sh. - [[ ${BLE_SESSION_ID-} ]] && export BLE_SESSION_ID -} -__atuin_initialize_blesh -BLE_ONLOAD+=(__atuin_initialize_blesh) -precmd_functions+=(__atuin_precmd) -preexec_functions+=(__atuin_preexec) - -#------------------------------------------------------------------------------ -# section: atuin-bind - -__atuin_widget=() - -__atuin_widget_save() { - local data=$1 - for REPLY in "${!__atuin_widget[@]}"; do - if [[ ${__atuin_widget[REPLY]} == "$data" ]]; then - return 0 - fi - done - # shellcheck disable=SC2154 - REPLY=${#__atuin_widget[*]} - __atuin_widget[REPLY]=$data -} - -__atuin_widget_run() { - local data=${__atuin_widget[$1]} - local keymap=${data%%:*} widget=${data#*:} - local __atuin_macro_chain_keymap=$keymap - bind -m "$keymap" '"'"$__atuin_macro_chain"'": ""' - builtin eval -- "$widget" -} - -# To realize the enter_accept feature in a robust way, we need to call the -# readline bindable function `accept-line'. However, there is no way to call -# `accept-line' from the shell script. To call the bindable function -# `accept-line', we may utilize string macros of readline. When we bind KEYSEQ -# to a WIDGET that wants to conditionally call `accept-line' at the end, we -# perform two-step dispatching: -# -# 1. [KEYSEQ -> IKEYSEQ1 IKEYSEQ2]---We first translate KEYSEQ to two -# intermediate key sequences IKEYSEQ1 and IKEYSEQ2 using string macros. For -# example, when we bind `__atuin_history` to \C-r, this step can be set up by -# `bind '"\C-r": "IKEYSEQ1IKEYSEQ2"'`. -# -# 2. [IKEYSEQ1 -> WIDGET]---Then, IKEYSEQ1 is bound to the WIDGET, and the -# binding of IKEYSEQ2 is dynamically determined by WIDGET. For example, when -# we bind `__atuin_history` to \C-r, this step can be set up by `bind -x -# '"IKEYSEQ1": WIDGET'`. -# -# 3. [IKEYSEQ2 -> accept-line] or [IKEYSEQ2 -> ""]---To request the execution -# of `accept-line', WIDGET can change the binding of IKEYSEQ2 by running -# `bind '"IKEYSEQ2": accept-line''. Otherwise, WIDGET can change the binding -# of IKEYSEQ2 to no-op by running `bind '"IKEYSEQ2": ""'`. -# -# For the choice of the intermediate key sequences, we want to choose key -# sequences that are unlikely to conflict with others. In addition, we want to -# avoid a key sequence containing \e because keymap "vi-insert" stops -# processing key sequences containing \e in older versions of Bash. We have -# used \e[0;A (a variant of the [up] key with modifier ) in Atuin 3.10.0 -# for intermediate key sequences, but this contains \e and caused a problem. -# Instead, we use \C-x\C-_A\a, which starts with \C-x\C-_ (an unlikely -# two-byte combination) and A (represents the initial letter of Atuin), -# followed by the payload and the terminator \a (BEL, \C-g). - -__atuin_macro_chain='\C-x\C-_A0\a' -for __atuin_keymap in emacs vi-insert vi-command; do - bind -m "$__atuin_keymap" "\"$__atuin_macro_chain\": \"\"" -done -unset -v __atuin_keymap - -if ((BASH_VERSINFO[0] >= 5 || BASH_VERSINFO[0] == 4 && BASH_VERSINFO[1] >= 3)); then - # In Bash >= 4.3 - - __atuin_macro_accept_line=accept-line - - __atuin_bind_impl() { - local keymap=$1 keyseq=$2 command=$3 - - # Note: In Bash <= 5.0, the table for `bind -x` from the keyseq to the - # command is shared by all the keymaps (emacs, vi-insert, and - # vi-command), so one cannot safely bind different command strings to - # the same keyseq in different keymaps. Therefore, the command string - # and the keyseq need to be globally in one-to-one correspondence in - # all the keymaps. - local REPLY - __atuin_widget_save "$keymap:$command" - local widget=$REPLY - local ikeyseq1='\C-x\C-_A'$((1 + widget))'\a' - local ikeyseq2=$__atuin_macro_chain - - if ((BASH_VERSINFO[0] == 5 && BASH_VERSINFO[1] == 1)); then - # Workaround for Bash 5.1: Bash 5.1 has a bug that overwriting an - # existing "bind -x" keybinding breaks other existing "bind -x" - # keybindings [1,2]. To work around the problem, we explicitly - # unbind an existing keybinding before overwriting it. - # - # [1] https://lists.gnu.org/archive/html/bug-bash/2021-04/msg00135.html - # [2] https://github.com/atuinsh/atuin/issues/962#issuecomment-3451132291 - bind -m "$keymap" -r "$keyseq" - fi - - bind -m "$keymap" "\"$keyseq\": \"$ikeyseq1$ikeyseq2\"" - bind -m "$keymap" -x "\"$ikeyseq1\": __atuin_widget_run $widget" - } - - __atuin_bind_blesh_onload() { - # In ble.sh, we need to enable unrecognized CSI sequences like \e[0;0A, - # which are discarded by ble.sh by default. Note: In Bash <= 4.2, we - # do not need to unset "decode_error_cseq_discard" because \e[0;A is - # used only for the macro chaining (which is unused by ble.sh) in Bash - # <= 4.2. - bleopt decode_error_cseq_discard= - } - if [[ ${BLE_VERSION-} ]]; then - __atuin_bind_blesh_onload - fi - BLE_ONLOAD+=(__atuin_bind_blesh_onload) -else - # In Bash <= 4.2, "bind -x" cannot bind a shell command to a keyseq having - # more than two bytes, so we need to work with only two-byte sequences. - # - # However, the number of available combinations of two-byte sequences is - # limited. To minimize the number of key sequences used by Atuin, instead - # of specifying a widget by its own intermediate sequence, we specify a - # widget by a fixed-length sequence of multiple two-byte sequences. More - # specifically, instead of IKEYSEQ1, we use IKS1 IKS2 IKS3 [IKS4 IKS5] - # IKSX, where IKS1..IKS5 just stores its information to a global variable, - # and IKSX collects all the information and determine and call the actual - # widget based on the stored information. Each of IKn (n=1..5) is one of - # the two reserved sequences, $__atuin_bash42_code0 and - # $__atuin_bash42_code1. IKSX is fixed to be $__atuin_bash42_code2. - # - # For the choices of the special key sequences, we consider \C-xQ, \C-xR, - # and \C-xS. In the emacs editing mode of Bash, \C-x is used as a prefix - # key, i.e., it is used for the beginning key of the keybindings with - # multiple keys, so \C-x is unlikely to be used for a single-key binding by - # the user. Also, \C-x is not used in the vi editing mode by default. The - # combinations \C-xQ..\C-xS are also unlikely be used because we need to - # switch the modifier keys from Control to Shift to input these sequences, - # and these are not easy to input. - __atuin_bash42_code0='\C-xQ' - __atuin_bash42_code1='\C-xR' - __atuin_bash42_code2='\C-xS' - - __atuin_bash42_encode() { - REPLY= - local n=$1 min_width=${2-} - while - if ((n % 2 == 0)); then - REPLY=$__atuin_bash42_code0$REPLY - else - REPLY=$__atuin_bash42_code1$REPLY - fi - (((n /= 2) || ${#REPLY} / ${#__atuin_bash42_code0} < min_width)) - do :; done - } - - __atuin_bash42_bind() { - local __atuin_keymap - for __atuin_keymap in emacs vi-insert vi-command; do - bind -m "$__atuin_keymap" -x '"'"$__atuin_bash42_code0"'": __atuin_bash42_dispatch_selector+=0' - bind -m "$__atuin_keymap" -x '"'"$__atuin_bash42_code1"'": __atuin_bash42_dispatch_selector+=1' - bind -m "$__atuin_keymap" -x '"'"$__atuin_bash42_code2"'": __atuin_bash42_dispatch' - done - } - __atuin_bash42_bind - # In Bash <= 4.2, there is no way to read users' "bind -x" settings, so we - # need to explicitly perform "bind -x" when ble.sh is loaded. - BLE_ONLOAD+=(__atuin_bash42_bind) - - if ((BASH_VERSINFO[0] >= 4)); then - __atuin_macro_accept_line=accept-line - else - # Note: We rewrite the command line and invoke `accept-line'. In - # bash <= 3.2, there is no way to rewrite the command line from the - # shell script, so we rewrite it using a macro and - # `shell-expand-line'. - # - # Note: Concerning the key sequences to invoke bindable functions - # such as "\C-x\C-_A1\a", another option is to use - # "\exbegginning-of-line\r", etc. to make it consistent with bash - # >= 5.3. However, an older Bash configuration can still conflict - # on [M-x]. The conflict is more likely than \C-x\C-_A1\a. - for __atuin_keymap in emacs vi-insert vi-command; do - bind -m "$__atuin_keymap" '"\C-x\C-_A1\a": beginning-of-line' - bind -m "$__atuin_keymap" '"\C-x\C-_A2\a": kill-line' - # shellcheck disable=SC2016 - bind -m "$__atuin_keymap" '"\C-x\C-_A3\a": "$READLINE_LINE"' - bind -m "$__atuin_keymap" '"\C-x\C-_A4\a": shell-expand-line' - bind -m "$__atuin_keymap" '"\C-x\C-_A5\a": accept-line' - bind -m "$__atuin_keymap" '"\C-x\C-_A6\a": end-of-line' - done - unset -v __atuin_keymap - - bind -m vi-command '"\C-x\C-_A7\a": vi-insertion-mode' - bind -m vi-insert '"\C-x\C-_A7\a": vi-movement-mode' - - # "\C-x\C-_A10\a": Replace the command line with READLINE_LINE. When we are - # in the vi-command keymap, we go to vi-insert, input - # "$READLINE_LINE", and come back to vi-command. - bind -m emacs '"\C-x\C-_A10\a": "\C-x\C-_A1\a\C-x\C-_A2\a\C-x\C-_A3\a\C-x\C-_A4\a"' - bind -m vi-insert '"\C-x\C-_A10\a": "\C-x\C-_A1\a\C-x\C-_A2\a\C-x\C-_A3\a\C-x\C-_A4\a"' - bind -m vi-command '"\C-x\C-_A10\a": "\C-x\C-_A1\a\C-x\C-_A2\a\C-x\C-_A7\a\C-x\C-_A3\a\C-x\C-_A7\a\C-x\C-_A4\a"' - - __atuin_macro_accept_line='"\C-x\C-_A10\a\C-x\C-_A5\a"' - __atuin_macro_insert_line='"\C-x\C-_A10\a\C-x\C-_A6\a"' - fi - - __atuin_bash42_dispatch_selector= - - __atuin_bash42_dispatch() { - local s=$__atuin_bash42_dispatch_selector - __atuin_bash42_dispatch_selector= - __atuin_widget_run "$((2#0$s))" - } - - __atuin_bind_impl() { - local keymap=$1 keyseq=$2 command=$3 - - __atuin_widget_save "$keymap:$command" - __atuin_bash42_encode "$REPLY" - local macro=$REPLY$__atuin_bash42_code2$__atuin_macro_chain - - bind -m "$keymap" "\"$keyseq\": \"$macro\"" - } -fi - -atuin-bind() { - local keymap= - local OPTIND=1 OPTARG="" OPTERR=0 flag - while getopts ':m:' flag "$@"; do - case $flag in - m) keymap=$OPTARG ;; - *) - printf '%s\n' "atuin-bind: unrecognized option '-$flag'" >&2 - return 2 - ;; - esac - done - shift "$((OPTIND - 1))" - - if (($# != 2)); then - printf '%s\n' 'usage: atuin-bind [-m keymap] keyseq widget' >&2 - return 2 - fi - - local keyseq=$1 - [[ $keymap ]] || keymap=$(bind -v | awk '$2 == "keymap" { print $3 }') - case $keymap in - emacs-meta) keymap=emacs keyseq='\e'$keyseq ;; - emacs-ctlx) keymap=emacs keyseq='\C-x'$keyseq ;; - emacs*) keymap=emacs ;; - vi-insert) ;; - vi*) keymap=vi-command ;; - *) - printf '%s\n' "atuin-bind: unknown keymap '$keymap'" >&2 - return 2 ;; - esac - - local command=$2 widget=${2%%[[:blank:]]*} - case $widget in - atuin-search) command=${2/#"$widget"/__atuin_history} ;; - atuin-search-emacs) command=${2/#"$widget"/__atuin_history --keymap-mode=emacs} ;; - atuin-search-viins) command=${2/#"$widget"/__atuin_history --keymap-mode=vim-insert} ;; - atuin-search-vicmd) command=${2/#"$widget"/__atuin_history --keymap-mode=vim-normal} ;; - atuin-up-search) command=${2/#"$widget"/__atuin_history --shell-up-key-binding} ;; - atuin-up-search-emacs) command=${2/#"$widget"/__atuin_history --shell-up-key-binding --keymap-mode=emacs} ;; - atuin-up-search-viins) command=${2/#"$widget"/__atuin_history --shell-up-key-binding --keymap-mode=vim-insert} ;; - atuin-up-search-vicmd) command=${2/#"$widget"/__atuin_history --shell-up-key-binding --keymap-mode=vim-normal} ;; - esac - - __atuin_bind_impl "$keymap" "$keyseq" "$command" -} - -#------------------------------------------------------------------------------ - -# shellcheck disable=SC2154 -if [[ $__atuin_bind_ctrl_r == true ]]; then - # Note: We do not overwrite [C-r] in the vi-command keymap because we do - # not want to overwrite "redo", which is already bound to [C-r] in the - # vi_nmap keymap in ble.sh. - atuin-bind -m emacs '\C-r' atuin-search-emacs - atuin-bind -m vi-insert '\C-r' atuin-search-viins - atuin-bind -m vi-command '/' atuin-search-emacs -fi - -# shellcheck disable=SC2154 -if [[ $__atuin_bind_up_arrow == true ]]; then - atuin-bind -m emacs '\e[A' atuin-up-search-emacs - atuin-bind -m emacs '\eOA' atuin-up-search-emacs - atuin-bind -m vi-insert '\e[A' atuin-up-search-viins - atuin-bind -m vi-insert '\eOA' atuin-up-search-viins - atuin-bind -m vi-command '\e[A' atuin-up-search-vicmd - atuin-bind -m vi-command '\eOA' atuin-up-search-vicmd - atuin-bind -m vi-command 'k' atuin-up-search-vicmd -fi - -#------------------------------------------------------------------------------ -fi # (include guard) end of main content diff --git a/crates/atuin/src/shell/atuin.fish b/crates/atuin/src/shell/atuin.fish deleted file mode 100644 index 15b33451..00000000 --- a/crates/atuin/src/shell/atuin.fish +++ /dev/null @@ -1,178 +0,0 @@ -if not set -q ATUIN_SESSION; or test "$ATUIN_SHLVL" != "$SHLVL" - set -gx ATUIN_SESSION (atuin uuid) - set -gx ATUIN_SHLVL $SHLVL -end -set --erase ATUIN_HISTORY_ID - -function _atuin_osc133_command_executed - set -q ATUIN_PTY_PROXY_ACTIVE; or return - test -n "$ATUIN_HISTORY_ID"; or return - - printf '\033]133;C\a' -end - -function _atuin_osc133_command_finished --argument-names exit_code - set -q ATUIN_PTY_PROXY_ACTIVE; or return - test -n "$ATUIN_HISTORY_ID"; or return - - printf '\033]133;D;%s;history_id=%s;session_id=%s\a' "$exit_code" "$ATUIN_HISTORY_ID" "$ATUIN_SESSION" -end - -function _atuin_preexec --on-event fish_preexec - if not test -n "$fish_private_mode" - set -g ATUIN_HISTORY_ID (atuin history start -- "$argv[1]" 2>/dev/null) - _atuin_osc133_command_executed - end -end - -function _atuin_postexec --on-event fish_postexec - set -l s $status - - if test -n "$ATUIN_HISTORY_ID" - _atuin_osc133_command_finished $s - ATUIN_LOG=error atuin history end --exit $s -- $ATUIN_HISTORY_ID &>/dev/null & - disown - end - - set --erase ATUIN_HISTORY_ID -end - -# Check if tmux popup is available (tmux >= 3.2) -function _atuin_tmux_popup_check - if not test -n "$TMUX" - echo 0 - return - end - - if test "$ATUIN_TMUX_POPUP" = false - echo 0 - return - end - - set -l tmux_version (tmux -V 2>/dev/null | string match -r '\d+\.\d+') - if not test -n "$tmux_version" - echo 0 - return - end - - set -l parts (string split '.' $tmux_version) - set -l m1 $parts[1] - set -l m2 0 - if test (count $parts) -ge 2 - set m2 $parts[2] - end - - if not string match -rq '^[0-9]+$' -- "$m1" - echo 0 - return - end - - if not string match -rq '^[0-9]+$' -- "$m2" - set m2 0 - end - - if test "$m1" -gt 3 2>/dev/null; or begin - test "$m1" -eq 3 2>/dev/null; and test "$m2" -ge 2 2>/dev/null - end - echo 1 - else - echo 0 - end -end - -function _atuin_search - set -l keymap_mode - switch $fish_key_bindings - case fish_vi_key_bindings fish_hybrid_key_bindings - switch $fish_bind_mode - case default - set keymap_mode vim-normal - case insert - set keymap_mode vim-insert - end - case '*' - set keymap_mode emacs - end - - set -l use_tmux_popup (_atuin_tmux_popup_check) - - set -l ATUIN_H - set -l ATUIN_STATUS 0 - if test "$use_tmux_popup" -eq 1 - set -l tmpdir (mktemp -d) - if not test -d "$tmpdir" - # if mktemp got errors - set ATUIN_H (ATUIN_SHELL=fish ATUIN_LOG=error ATUIN_QUERY=(commandline -b) atuin search --keymap-mode=$keymap_mode $argv -i 3>&1 1>&2 2>&3 3>&- | string collect) - set ATUIN_STATUS $pipestatus[1] - else - set -l result_file "$tmpdir/result" - - set -l query (commandline -b | string replace -a "'" "'\\''") - set -l escaped_args "" - for arg in $argv - set escaped_args "$escaped_args '"(string replace -a "'" "'\\''" -- $arg)"'" - end - - # In the popup, atuin goes to terminal, stderr goes to file - set -l cdir (pwd) - # Keep default value anyways - set -l popup_width (test -n "$ATUIN_TMUX_POPUP_WIDTH" && echo "$ATUIN_TMUX_POPUP_WIDTH" || echo "80%") - set -l popup_height (test -n "$ATUIN_TMUX_POPUP_HEIGHT" && echo "$ATUIN_TMUX_POPUP_HEIGHT" || echo "60%") - tmux display-popup -d "$cdir" -w "$popup_width" -h "$popup_height" -E -E -- \ - sh -c "PATH='$PATH' ATUIN_SESSION='$ATUIN_SESSION' ATUIN_SHELL=fish ATUIN_LOG=error ATUIN_QUERY='$query' atuin search --keymap-mode=$keymap_mode$escaped_args -i 2>'$result_file'" - set ATUIN_STATUS $status - - if test -f "$result_file" - set ATUIN_H (cat "$result_file" | string collect) - end - - command rm -rf "$tmpdir" - end - else - # In fish 3.4 and above we can use `"$(some command)"` to keep multiple lines separate; - # but to support fish 3.3 we need to use `(some command | string collect)`. - # https://fishshell.com/docs/current/relnotes.html#id24 (fish 3.4 "Notable improvements and fixes") - set ATUIN_H (ATUIN_SHELL=fish ATUIN_LOG=error ATUIN_QUERY=(commandline -b) atuin search --keymap-mode=$keymap_mode $argv -i 3>&1 1>&2 2>&3 3>&- | string collect) - set ATUIN_STATUS $pipestatus[1] - end - - if test "$ATUIN_STATUS" -ne 0 - test -n "$ATUIN_H"; and printf '%s\n' "$ATUIN_H" >&2 - commandline -f repaint - return "$ATUIN_STATUS" - end - - set ATUIN_H (string trim -- $ATUIN_H | string collect) # trim whitespace - - if test -n "$ATUIN_H" - if string match --quiet '__atuin_accept__:*' "$ATUIN_H" - set -l ATUIN_HIST (string replace "__atuin_accept__:" "" -- "$ATUIN_H" | string collect) - commandline -r "$ATUIN_HIST" - commandline -f repaint - commandline -f execute - return - else - commandline -r "$ATUIN_H" - end - end - - commandline -f repaint -end - -function _atuin_bind_up - # Fallback to fish's builtin up-or-search if we're in search or paging mode - if commandline --search-mode; or commandline --paging-mode - up-or-search - return - end - - # Only invoke atuin if we're on the top line of the command - set -l lineno (commandline --line) - - switch $lineno - case 1 - _atuin_search --shell-up-key-binding - case '*' - up-or-search - end -end diff --git a/crates/atuin/src/shell/atuin.nu b/crates/atuin/src/shell/atuin.nu deleted file mode 100644 index d37457e4..00000000 --- a/crates/atuin/src/shell/atuin.nu +++ /dev/null @@ -1,121 +0,0 @@ -# Source this in your ~/.config/nushell/config.nu -# minimum supported version = 0.93.0 -module compat { - export def --wrapped "random uuid -v 7" [...rest] { atuin uuid } -} -use (if not ( - (version).major > 0 or - (version).minor >= 103 -) { "compat" }) * - -if 'ATUIN_SESSION' not-in $env or ('ATUIN_SHLVL' not-in $env) or ($env.ATUIN_SHLVL != ($env.SHLVL? | default "")) { - $env.ATUIN_SESSION = (random uuid -v 7 | str replace -a "-" "") - $env.ATUIN_SHLVL = ($env.SHLVL? | default "") -} -hide-env -i ATUIN_HISTORY_ID - -def _atuin_osc133_command_executed [] { - if 'ATUIN_PTY_PROXY_ACTIVE' not-in $env { - return - } - if 'ATUIN_HISTORY_ID' not-in $env or ($env.ATUIN_HISTORY_ID | is-empty) { - return - } - - print -n $"(char esc)]133;C(char bel)" -} - -def _atuin_osc133_command_finished [exit_code: int] { - if 'ATUIN_PTY_PROXY_ACTIVE' not-in $env { - return - } - if 'ATUIN_HISTORY_ID' not-in $env or ($env.ATUIN_HISTORY_ID | is-empty) { - return - } - - print -n $"(char esc)]133;D;($exit_code);history_id=($env.ATUIN_HISTORY_ID);session_id=($env.ATUIN_SESSION)(char bel)" -} - -# Magic token to make sure we don't record commands run by keybindings -let ATUIN_KEYBINDING_TOKEN = $"# (random uuid)" - -let _atuin_pre_execution = {|| - if ($nu | get history-enabled?) == false { - return - } - let cmd = (commandline) - if ($cmd | is-empty) { - return - } - if not ($cmd | str starts-with $ATUIN_KEYBINDING_TOKEN) { - $env.ATUIN_HISTORY_ID = (atuin history start -- $cmd | complete | get stdout | str trim) - _atuin_osc133_command_executed - } -} - -let _atuin_pre_prompt = {|| - let last_exit = $env.LAST_EXIT_CODE - if 'ATUIN_HISTORY_ID' not-in $env { - return - } - _atuin_osc133_command_finished $last_exit - with-env { ATUIN_LOG: error } { - if (version).minor >= 104 or (version).major > 0 { - job spawn { - ^atuin history end $'--exit=($env.LAST_EXIT_CODE)' -- $env.ATUIN_HISTORY_ID | complete - } | ignore - } else { - do { atuin history end $'--exit=($last_exit)' -- $env.ATUIN_HISTORY_ID } | complete - } - - } - hide-env ATUIN_HISTORY_ID -} - -def _atuin_search_cmd [...flags: string] { - if (version).minor >= 106 or (version).major > 0 { - [ - $ATUIN_KEYBINDING_TOKEN, - ([ - `with-env { ATUIN_LOG: error, ATUIN_QUERY: (commandline), ATUIN_SHELL: nu } {`, - ([ - 'let output = (run-external atuin search', - ($flags | append [--interactive] | each {|e| $'"($e)"'}), - 'e>| str trim)', - ] | flatten | str join ' '), - 'if ($output | str starts-with "__atuin_accept__:") {', - 'commandline edit --accept ($output | str replace "__atuin_accept__:" "")', - '} else {', - 'commandline edit $output', - '}', - `}`, - ] | flatten | str join "\n"), - ] - } else { - [ - $ATUIN_KEYBINDING_TOKEN, - ([ - `with-env { ATUIN_LOG: error, ATUIN_QUERY: (commandline) } {`, - 'commandline edit', - '(run-external atuin search', - ($flags | append [--interactive] | each {|e| $'"($e)"'}), - ' e>| str trim)', - `}`, - ] | flatten | str join ' '), - ] - } | str join "\n" -} - -$env.config = ($env | default {} config).config -$env.config = ($env.config | default {} hooks) -$env.config = ( - $env.config | upsert hooks ( - $env.config.hooks - | upsert pre_execution ( - $env.config.hooks | get pre_execution? | default [] | append $_atuin_pre_execution) - | upsert pre_prompt ( - $env.config.hooks | get pre_prompt? | default [] | append $_atuin_pre_prompt) - ) -) - -$env.config = ($env.config | default [] keybindings) diff --git a/crates/atuin/src/shell/atuin.ps1 b/crates/atuin/src/shell/atuin.ps1 deleted file mode 100644 index 431ee2c3..00000000 --- a/crates/atuin/src/shell/atuin.ps1 +++ /dev/null @@ -1,240 +0,0 @@ -# Atuin PowerShell module -# -# This should support PowerShell 5.1 (which is shipped with Windows) and later versions, on Windows and Linux. -# -# Usage: atuin init powershell | Out-String | Invoke-Expression -# -# Settings: -# - $env:ATUIN_POWERSHELL_PROMPT_OFFSET - Number of lines to offset the prompt position after exiting search. -# This is useful when using a multi-line prompt: e.g. set this to -1 when using a 2-line prompt. -# It is initialized from the current prompt line count if not set when the first Atuin search is performed. - -if (Get-Module Atuin -ErrorAction Ignore) { - if ($PSVersionTable.PSVersion.Major -ge 7) { - Write-Warning "The Atuin module is already loaded, replacing it." - Remove-Module Atuin - } else { - Write-Warning "The Atuin module is already loaded, skipping." - return - } -} - -if (!(Get-Command atuin -ErrorAction Ignore)) { - Write-Error "The 'atuin' executable needs to be available in the PATH." - return -} - -if (!(Get-Module PSReadLine -ErrorAction Ignore)) { - Write-Error "Atuin requires the PSReadLine module to be installed." - return -} - -New-Module -Name Atuin -ScriptBlock { - if (-not $env:ATUIN_SESSION -or $env:ATUIN_PID -ne $PID) { - $env:ATUIN_SESSION = atuin uuid - $env:ATUIN_PID = $PID - } - - $script:atuinHistoryId = $null - $script:previousPSConsoleHostReadLine = $Function:PSConsoleHostReadLine - - # The ReadLine overloads changed with breaking changes over time, make sure the one we expect is available. - $script:hasExpectedReadLineOverload = ([Microsoft.PowerShell.PSConsoleReadLine]::ReadLine).OverloadDefinitions.Contains("static string ReadLine(runspace runspace, System.Management.Automation.EngineIntrinsics engineIntrinsics, System.Threading.CancellationToken cancellationToken, System.Nullable[bool] lastRunStatus)") - - function Get-CommandLine { - $commandLine = "" - [Microsoft.PowerShell.PSConsoleReadLine]::GetBufferState([ref]$commandLine, [ref]$null) - return $commandLine - } - - function Set-CommandLine { - param([string]$Text) - - $commandLine = Get-CommandLine - [Microsoft.PowerShell.PSConsoleReadLine]::Replace(0, $commandLine.Length, $Text) - } - - # This function name is called by PSReadLine to read the next command line to execute. - # We replace it with a custom implementation which adds Atuin support. - function PSConsoleHostReadLine { - ## 1. Collect the exit code of the previous command. - - # This needs to be done as the first thing because any script run will flush $?. - $lastRunStatus = $? - - # Exit statuses are maintained separately for native and PowerShell commands, this needs to be taken into account. - $lastNativeExitCode = $global:LASTEXITCODE - $exitCode = if ($lastRunStatus) { 0 } elseif ($lastNativeExitCode) { $lastNativeExitCode } else { 1 } - - ## 2. Report the status of the previous command to Atuin (atuin history end). - - if ($script:atuinHistoryId) { - try { - # The duration is not recorded in old PowerShell versions, let Atuin handle it. $null arguments are ignored. - $duration = (Get-History -Count 1).Duration.Ticks * 100 - $durationArg = if ($duration) { "--duration=$duration" } else { $null } - - # Fire and forget the atuin history end command to avoid blocking the shell during a potential sync. - $process = New-Object System.Diagnostics.Process - $process.StartInfo.FileName = "atuin" - $process.StartInfo.Arguments = "history end --exit=$exitCode $durationArg -- $script:atuinHistoryId" - $process.StartInfo.UseShellExecute = $false - $process.StartInfo.CreateNoWindow = $true - $process.StartInfo.RedirectStandardInput = $true - $process.StartInfo.RedirectStandardOutput = $true - $process.StartInfo.RedirectStandardError = $true - $process.Start() | Out-Null - $process.StandardInput.Close() - $process.BeginOutputReadLine() - $process.BeginErrorReadLine() - } - catch { - # Ignore errors to avoid breaking the shell. - # An error would occur if the user removes atuin from the PATH, for instance. - } - finally { - $script:atuinHistoryId = $null - } - } - - ## 3. Read the next command line to execute. - - # PSConsoleHostReadLine implementation from PSReadLine, adjusted to support old versions. - Microsoft.PowerShell.Core\Set-StrictMode -Off - - $line = if ($script:hasExpectedReadLineOverload) { - # When the overload we expect is available, we can pass $lastRunStatus to it. - [Microsoft.PowerShell.PSConsoleReadLine]::ReadLine($Host.Runspace, $ExecutionContext, [System.Threading.CancellationToken]::None, $lastRunStatus) - } else { - # Either PSReadLine is older than v2.2.0-beta3, or maybe newer than we expect, so use the function from PSReadLine as-is. - & $script:previousPSConsoleHostReadLine - } - - ## 4. Report the next command line to Atuin (atuin history start). - - # PowerShell doesn't handle double quotes in native command line arguments the same way depending on its version, - # and the value of $PSNativeCommandArgumentPassing - see the about_Parsing help page which explains the breaking changes. - # This makes it unreliable, so we go through an environment variable, which should always be consistent across versions. - try { - $env:ATUIN_COMMAND_LINE = $line - $script:atuinHistoryId = atuin history start --command-from-env - } - catch { - # Ignore errors to avoid breaking the shell, see above. - } - finally { - $env:ATUIN_COMMAND_LINE = $null - } - - $global:LASTEXITCODE = $lastNativeExitCode - return $line - } - - function Invoke-AtuinSearch { - param([string]$ExtraArgs = "") - - $previousOutputEncoding = [System.Console]::OutputEncoding - $resultFile = New-TemporaryFile - $suggestion = "" - $errorOutput = "" - - try { - [System.Console]::OutputEncoding = [System.Text.Encoding]::UTF8 - - # Start-Process does some crazy stuff, just use the Process class directly to have more control. - $process = New-Object System.Diagnostics.Process - $process.StartInfo.FileName = "atuin" - $process.StartInfo.Arguments = "search -i --result-file ""$($resultFile.FullName)"" $ExtraArgs" - $process.StartInfo.UseShellExecute = $false - $process.StartInfo.RedirectStandardError = $true - $process.StartInfo.StandardErrorEncoding = [System.Text.Encoding]::UTF8 - $process.StartInfo.EnvironmentVariables["ATUIN_SHELL"] = "powershell" - $process.StartInfo.EnvironmentVariables["ATUIN_QUERY"] = Get-CommandLine - # PowerShell's Set-Location (cd) doesn't update the process-level working directory, set it explicitly - $process.StartInfo.WorkingDirectory = (Get-Location -PSProvider FileSystem).ProviderPath - - try { - $process.Start() | Out-Null - - # A single stream is redirected, so we can read it synchronously, but we have to start reading it - # before waiting for the process to exit, otherwise the buffer could fill up and cause a deadlock. - $errorOutput = $process.StandardError.ReadToEnd().Trim() - $process.WaitForExit() - - $suggestion = (Get-Content -LiteralPath $resultFile.FullName -Raw -Encoding UTF8 | Out-String).Trim() - } - catch { - $errorOutput = $_ - } - - if ($errorOutput) { - Write-Host -ForegroundColor Red "Atuin error:" - Write-Host -ForegroundColor DarkRed $errorOutput - } - - # If no shell prompt offset is set, initialize it from the current prompt line count. - if ($null -eq $env:ATUIN_POWERSHELL_PROMPT_OFFSET) { - try { - $promptLines = (& $Function:prompt | Out-String | Measure-Object -Line).Lines - $env:ATUIN_POWERSHELL_PROMPT_OFFSET = -1 * ($promptLines - 1) - } - catch { - $env:ATUIN_POWERSHELL_PROMPT_OFFSET = 0 - } - } - - # PSReadLine maintains its own cursor position, which will no longer be valid if Atuin scrolls the display in inline mode. - # Fortunately, InvokePrompt can receive a new Y position and reset the internal state. - $y = $Host.UI.RawUI.CursorPosition.Y + [int]$env:ATUIN_POWERSHELL_PROMPT_OFFSET - $y = [System.Math]::Max([System.Math]::Min($y, [System.Console]::BufferHeight - 1), 0) - [Microsoft.PowerShell.PSConsoleReadLine]::InvokePrompt($null, $y) - - if ($suggestion -eq "") { - # The previous input was already rendered by InvokePrompt - return - } - - $acceptPrefix = "__atuin_accept__:" - - if ( $suggestion.StartsWith($acceptPrefix)) { - Set-CommandLine $suggestion.Substring($acceptPrefix.Length) - [Microsoft.PowerShell.PSConsoleReadLine]::AcceptLine() - } else { - Set-CommandLine $suggestion - } - } - finally { - [System.Console]::OutputEncoding = $previousOutputEncoding - $resultFile.Delete() - } - } - - function Enable-AtuinSearchKeys { - param([bool]$CtrlR = $true, [bool]$UpArrow = $true) - - if ($CtrlR) { - Set-PSReadLineKeyHandler -Chord "Ctrl+r" -BriefDescription "Runs Atuin search" -ScriptBlock { - Invoke-AtuinSearch - } - } - - if ($UpArrow) { - Set-PSReadLineKeyHandler -Chord "UpArrow" -BriefDescription "Runs Atuin search" -ScriptBlock { - $line = Get-CommandLine - - if (!$line.Contains("`n")) { - Invoke-AtuinSearch -ExtraArgs "--shell-up-key-binding" - } else { - [Microsoft.PowerShell.PSConsoleReadLine]::PreviousLine() - } - } - } - } - - $ExecutionContext.SessionState.Module.OnRemove += { - $env:ATUIN_SESSION = $null - $Function:PSConsoleHostReadLine = $script:previousPSConsoleHostReadLine - } - - Export-ModuleMember -Function @("Enable-AtuinSearchKeys", "PSConsoleHostReadLine") -} | Import-Module -Global diff --git a/crates/atuin/src/shell/atuin.xsh b/crates/atuin/src/shell/atuin.xsh deleted file mode 100644 index a0283402..00000000 --- a/crates/atuin/src/shell/atuin.xsh +++ /dev/null @@ -1,86 +0,0 @@ -import os -import subprocess - -from prompt_toolkit.application.current import get_app -from prompt_toolkit.filters import Condition -from prompt_toolkit.keys import Keys - - -if "ATUIN_SESSION" not in ${...} or ${...}.get("ATUIN_SHLVL", "") != ${...}.get("SHLVL", ""): - $ATUIN_SESSION=$(atuin uuid).rstrip('\n') - $ATUIN_SHLVL = ${...}.get("SHLVL", "") - -@events.on_precommand -def _atuin_precommand(cmd: str): - cmd = cmd.rstrip("\n") - try: - $ATUIN_HISTORY_ID = $(atuin history start -- @(cmd) 2>@(os.devnull)).rstrip("\n") - except: - $ATUIN_HISTORY_ID = "" - - -@events.on_postcommand -def _atuin_postcommand(cmd: str, rtn: int, out, ts): - if "ATUIN_HISTORY_ID" not in ${...}: - return - - duration = ts[1] - ts[0] - # Duration is float representing seconds, but atuin expects integer of nanoseconds - nanos = round(duration * 10 ** 9) - with ${...}.swap(ATUIN_LOG="error"): - # This causes the entire .xonshrc to be re-executed, which is incredibly slow - # This happens when using a subshell and using output redirection at the same time - # For more details, see https://github.com/xonsh/xonsh/issues/5224 - # (atuin history end --exit @(rtn) -- $ATUIN_HISTORY_ID &) > /dev/null 2>&1 - atuin history end --exit @(rtn) --duration @(nanos) -- $ATUIN_HISTORY_ID > @(os.devnull) 2>&1 - del $ATUIN_HISTORY_ID - - -def _search(event, extra_args: list[str]): - buffer = event.current_buffer - cmd = ["atuin", "search", "--interactive", *extra_args] - # We need to explicitly pass in xonsh env, in case user has set XDG_HOME or something else that matters - env = ${...}.detype() - env["ATUIN_SHELL"] = "xonsh" - env["ATUIN_QUERY"] = buffer.text - - p = subprocess.run(cmd, stderr=subprocess.PIPE, encoding="utf-8", env=env) - result = p.stderr.rstrip("\n") - # redraw prompt - necessary if atuin is configured to run inline, rather than fullscreen - event.cli.renderer.erase() - - if not result: - return - - buffer.reset() - if result.startswith("__atuin_accept__:"): - buffer.insert_text(result[17:]) - buffer.validate_and_handle() - else: - buffer.insert_text(result) - - -@events.on_ptk_create -def _custom_keybindings(bindings, **kw): - if _ATUIN_BIND_CTRL_R: - @bindings.add(Keys.ControlR) - def r_search(event): - _search(event, extra_args=[]) - - if _ATUIN_BIND_UP_ARROW: - @Condition - def should_search(): - buffer = get_app().current_buffer - # disable keybind when there is an active completion, so - # that up arrow can be used to navigate completion menu - if buffer.complete_state is not None: - return False - # similarly, disable when buffer text contains multiple lines - if '\n' in buffer.text: - return False - - return True - - @bindings.add(Keys.Up, filter=should_search) - def up_search(event): - _search(event, extra_args=["--shell-up-key-binding"]) diff --git a/crates/atuin/src/shell/atuin.zsh b/crates/atuin/src/shell/atuin.zsh deleted file mode 100644 index 7a7375aa..00000000 --- a/crates/atuin/src/shell/atuin.zsh +++ /dev/null @@ -1,221 +0,0 @@ -# shellcheck disable=SC2034,SC2153,SC2086,SC2155 - -# Above line is because shellcheck doesn't support zsh, per -# https://github.com/koalaman/shellcheck/wiki/SC1071, and the ignore: param in -# ludeeus/action-shellcheck only supports _directories_, not _files_. So -# instead, we manually add any error the shellcheck step finds in the file to -# the above line ... - -# Source this in your ~/.zshrc -autoload -U add-zsh-hook - -zmodload zsh/datetime 2>/dev/null - -# If zsh-autosuggestions is installed, configure it to use Atuin's search. If -# you'd like to override this, then add your config after the $(atuin init zsh) -# in your .zshrc -_zsh_autosuggest_strategy_atuin() { - # silence errors, since we don't want to spam the terminal prompt while typing. - suggestion=$(ATUIN_QUERY="$1" atuin search --cmd-only --limit 1 --search-mode prefix 2>/dev/null) -} - -if [ -n "${ZSH_AUTOSUGGEST_STRATEGY:-}" ]; then - ZSH_AUTOSUGGEST_STRATEGY=("atuin" "${ZSH_AUTOSUGGEST_STRATEGY[@]}") -else - ZSH_AUTOSUGGEST_STRATEGY=("atuin") -fi - -if [[ -z "${ATUIN_SESSION:-}" || "${ATUIN_SHLVL:-}" != "$SHLVL" ]]; then - export ATUIN_SESSION=$(atuin uuid) - export ATUIN_SHLVL=$SHLVL -fi -ATUIN_HISTORY_ID="" - -__atuin_osc133_command_executed() { - [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]] || return - [[ -n "${ATUIN_HISTORY_ID:-}" ]] || return - - printf '\033]133;C\a' -} - -__atuin_osc133_command_finished() { - [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]] || return - [[ -n "${ATUIN_HISTORY_ID:-}" ]] || return - - printf '\033]133;D;%s;history_id=%s;session_id=%s\a' "$1" "$ATUIN_HISTORY_ID" "${ATUIN_SESSION:-}" -} - -__atuin_osc133_prompt_start=$'%{\033]133;A;cl=line\a%}' -__atuin_osc133_prompt_end=$'%{\033]133;B\a%}' - -__atuin_osc133_wrap_prompt() { - local __atuin_prompt="${PROMPT-}" - local __atuin_rprompt="${RPROMPT-}" - - __atuin_prompt="${__atuin_prompt//$__atuin_osc133_prompt_start/}" - __atuin_prompt="${__atuin_prompt//$__atuin_osc133_prompt_end/}" - __atuin_rprompt="${__atuin_rprompt//$__atuin_osc133_prompt_start/}" - __atuin_rprompt="${__atuin_rprompt//$__atuin_osc133_prompt_end/}" - - if [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]]; then - PROMPT="${__atuin_osc133_prompt_start}${__atuin_prompt}" - RPROMPT="${__atuin_rprompt}${__atuin_osc133_prompt_end}" - else - PROMPT="$__atuin_prompt" - RPROMPT="$__atuin_rprompt" - fi -} - -_atuin_preexec() { - local id - id=$(atuin history start -- "$1" 2>/dev/null) - export ATUIN_HISTORY_ID="$id" - __atuin_osc133_command_executed - __atuin_preexec_time=${EPOCHREALTIME-} -} - -_atuin_precmd() { - local EXIT="$?" __atuin_precmd_time=${EPOCHREALTIME-} - - __atuin_osc133_wrap_prompt - - [[ -z "${ATUIN_HISTORY_ID:-}" ]] && return - - local duration="" - if [[ -n $__atuin_preexec_time && -n $__atuin_precmd_time ]]; then - printf -v duration %.0f $(((__atuin_precmd_time - __atuin_preexec_time) * 1000000000)) - fi - - __atuin_osc133_command_finished "$EXIT" - (ATUIN_LOG=error atuin history end --exit $EXIT ${duration:+--duration=$duration} -- $ATUIN_HISTORY_ID &) >/dev/null 2>&1 - export ATUIN_HISTORY_ID="" -} - -# Check if tmux popup is available (tmux >= 3.2) -__atuin_tmux_popup_check() { - [[ -n "${TMUX-}" ]] || return 1 - [[ "${ATUIN_TMUX_POPUP:-true}" != "false" ]] || return 1 - - # https://github.com/tmux/tmux/wiki/FAQ#how-often-is-tmux-released-what-is-the-version-number-scheme - local tmux_version - tmux_version=$(tmux -V 2>/dev/null | sed -n 's/^[^0-9]*\([0-9][0-9]*\.[0-9][0-9]*\).*/\1/p') # Could have used grep... - [[ -z "$tmux_version" ]] && return 1 - - local m1 m2 - m1=${tmux_version%%.*} - m2=${tmux_version#*.} - m2=${m2%%.*} - [[ "$m1" =~ ^[0-9]+$ ]] || return 1 - [[ "$m2" =~ ^[0-9]+$ ]] || m2=0 - (( m1 > 3 || (m1 == 3 && m2 >= 2) )) -} - -# Use global variable to fix scope issues with traps -__atuin_popup_tmpdir="" -__atuin_tmux_popup_cleanup() { - [[ -n "$__atuin_popup_tmpdir" && -d "$__atuin_popup_tmpdir" ]] && command rm -rf "$__atuin_popup_tmpdir" - __atuin_popup_tmpdir="" -} - -__atuin_search_cmd() { - local -a search_args=("$@") - - if __atuin_tmux_popup_check; then - __atuin_popup_tmpdir=$(mktemp -d) || return 1 - local result_file="$__atuin_popup_tmpdir/result" - - trap '__atuin_tmux_popup_cleanup' EXIT HUP INT TERM - - local escaped_query escaped_args - escaped_query=$(printf '%s' "$BUFFER" | sed "s/'/'\\\\''/g") - escaped_args="" - for arg in "${search_args[@]}"; do - escaped_args+=" '$(printf '%s' "$arg" | sed "s/'/'\\\\''/g")'" - done - - # In the popup, atuin goes to terminal, stderr goes to file - local cdir popup_width popup_height - cdir=$(pwd) - popup_width="${ATUIN_TMUX_POPUP_WIDTH:-80%}" # Keep default value anyways - popup_height="${ATUIN_TMUX_POPUP_HEIGHT:-60%}" - tmux display-popup -d "$cdir" -w "$popup_width" -h "$popup_height" -E -E -- \ - sh -c "PATH='$PATH' ATUIN_SESSION='$ATUIN_SESSION' ATUIN_SHELL=zsh ATUIN_LOG=error ATUIN_QUERY='$escaped_query' atuin search $escaped_args -i 2>'$result_file'" - - if [[ -f "$result_file" ]]; then - cat "$result_file" - fi - - __atuin_tmux_popup_cleanup - trap - EXIT HUP INT TERM - else - ATUIN_SHELL=zsh ATUIN_LOG=error ATUIN_QUERY=$BUFFER atuin search "${search_args[@]}" -i 3>&1 1>&2 2>&3 3>&- - fi -} - -_atuin_search() { - emulate -L zsh - zle -I - - # swap stderr and stdout, so that the tui stuff works - # TODO: not this - local output __atuin_status - # shellcheck disable=SC2048 - output=$(__atuin_search_cmd $*) - __atuin_status=$? - - zle reset-prompt - # re-enable bracketed paste - # shellcheck disable=SC2154 - echo -n ${zle_bracketed_paste[1]} >/dev/tty - - if (( __atuin_status != 0 )); then - [[ -n $output ]] && print -r -- "$output" >/dev/tty - return $__atuin_status - fi - - if [[ -n $output ]]; then - RBUFFER="" - LBUFFER=$output - - if [[ $LBUFFER == __atuin_accept__:* ]] - then - LBUFFER=${LBUFFER#__atuin_accept__:} - zle accept-line - fi - fi -} -_atuin_search_vicmd() { - _atuin_search --keymap-mode=vim-normal -} -_atuin_search_viins() { - _atuin_search --keymap-mode=vim-insert -} - -_atuin_up_search() { - # Only trigger if the buffer is a single line - if [[ ! $BUFFER == *$'\n'* ]]; then - _atuin_search --shell-up-key-binding "$@" - else - zle up-line - fi -} -_atuin_up_search_vicmd() { - _atuin_up_search --keymap-mode=vim-normal -} -_atuin_up_search_viins() { - _atuin_up_search --keymap-mode=vim-insert -} - -add-zsh-hook preexec _atuin_preexec -add-zsh-hook precmd _atuin_precmd - -zle -N atuin-search _atuin_search -zle -N atuin-search-vicmd _atuin_search_vicmd -zle -N atuin-search-viins _atuin_search_viins -zle -N atuin-up-search _atuin_up_search -zle -N atuin-up-search-vicmd _atuin_up_search_vicmd -zle -N atuin-up-search-viins _atuin_up_search_viins - -# These are compatibility widget names for "atuin <= 17.2.1" users. -zle -N _atuin_search_widget _atuin_search -zle -N _atuin_up_search_widget _atuin_up_search diff --git a/crates/atuin/src/sync.rs b/crates/atuin/src/sync.rs deleted file mode 100644 index 02e4db69..00000000 --- a/crates/atuin/src/sync.rs +++ /dev/null @@ -1,34 +0,0 @@ -use eyre::{Context, Result}; - -use atuin_client::{ - database::Database, history::store::HistoryStore, record::sqlite_store::SqliteStore, - settings::Settings, -}; -use atuin_common::record::RecordId; - -// This is the only crate that ties together all other crates. -// Therefore, it's the only crate where functions tying together all stores can live - -/// Rebuild all stores after a sync -/// Note: for history, this only does an _incremental_ sync. Hence the need to specify downloaded -/// records. -pub async fn build( - settings: &Settings, - store: &SqliteStore, - db: &dyn Database, - downloaded: Option<&[RecordId]>, -) -> Result<()> { - let encryption_key: [u8; 32] = atuin_client::encryption::load_key(settings) - .context("could not load encryption key")? - .into(); - - let host_id = Settings::host_id().await?; - - let downloaded = downloaded.unwrap_or(&[]); - - let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); - - history_store.incremental_build(db, downloaded).await?; - - Ok(()) -} diff --git a/crates/atuin/tests/common/mod.rs b/crates/atuin/tests/common/mod.rs deleted file mode 100644 index 228c0d17..00000000 --- a/crates/atuin/tests/common/mod.rs +++ /dev/null @@ -1,117 +0,0 @@ -use std::{env, time::Duration}; - -use atuin_client::api_client; -use atuin_common::utils::uuid_v7; -use atuin_server::{Settings as ServerSettings, launch_with_tcp_listener}; -use atuin_server_database::DbSettings; -use atuin_server_postgres::Postgres; -use futures_util::TryFutureExt; -use tokio::{net::TcpListener, sync::oneshot, task::JoinHandle}; -use tracing::{Dispatch, dispatcher}; -use tracing_subscriber::{EnvFilter, layer::SubscriberExt}; - -pub async fn start_server(path: &str) -> (String, oneshot::Sender<()>, JoinHandle<()>) { - let formatting_layer = tracing_tree::HierarchicalLayer::default() - .with_writer(tracing_subscriber::fmt::TestWriter::new()) - .with_indent_lines(true) - .with_ansi(true) - .with_targets(true) - .with_indent_amount(2); - - let dispatch: Dispatch = tracing_subscriber::registry() - .with(formatting_layer) - .with(EnvFilter::new("atuin_server=debug,atuin_client=debug,info")) - .into(); - - let db_uri = env::var("ATUIN_DB_URI") - .unwrap_or_else(|_| "postgres://atuin:pass@localhost:5432/atuin".to_owned()); - - let server_settings = ServerSettings { - host: "127.0.0.1".to_owned(), - port: 0, - path: path.to_owned(), - sync_v1_enabled: true, - open_registration: true, - max_history_length: 8192, - max_record_size: 1024 * 1024 * 1024, - page_size: 1100, - register_webhook_url: None, - register_webhook_username: String::new(), - db_settings: DbSettings { - db_uri: db_uri, - read_db_uri: None, - }, - metrics: atuin_server::settings::Metrics::default(), - fake_version: None, - }; - - let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - let server = tokio::spawn(async move { - let _tracing_guard = dispatcher::set_default(&dispatch); - - if let Err(e) = launch_with_tcp_listener::( - server_settings, - listener, - shutdown_rx.unwrap_or_else(|_| ()), - ) - .await - { - tracing::error!(error=?e, "server error"); - panic!("error running server: {e:?}"); - } - }); - - // let the server come online - tokio::time::sleep(Duration::from_millis(200)).await; - - (format!("http://{addr}{path}"), shutdown_tx, server) -} - -pub async fn register_inner<'a>( - address: &'a str, - username: &str, - password: &str, -) -> api_client::Client<'a> { - let email = format!("{}@example.com", uuid_v7().as_simple()); - - // registration works - let registration_response = api_client::register(address, username, &email, password) - .await - .unwrap(); - - api_client::Client::new( - address, - api_client::AuthToken::Token(registration_response.session), - 5, - 30, - ) - .unwrap() -} - -#[expect(dead_code)] -pub async fn login(address: &str, username: String, password: String) -> api_client::Client<'_> { - // registration works - let login_response = api_client::login( - address, - atuin_common::api::LoginRequest { username, password }, - ) - .await - .unwrap(); - - api_client::Client::new( - address, - api_client::AuthToken::Token(login_response.session), - 5, - 30, - ) - .unwrap() -} - -#[expect(dead_code)] -pub async fn register(address: &str) -> api_client::Client<'_> { - let username = uuid_v7().as_simple().to_string(); - let password = uuid_v7().as_simple().to_string(); - register_inner(address, &username, &password).await -} diff --git a/crates/atuin/tests/sync.rs b/crates/atuin/tests/sync.rs deleted file mode 100644 index 7e25d1c2..00000000 --- a/crates/atuin/tests/sync.rs +++ /dev/null @@ -1,45 +0,0 @@ -use atuin_common::{api::AddHistoryRequest, utils::uuid_v7}; -use time::OffsetDateTime; - -mod common; - -#[tokio::test] -async fn sync() { - let path = format!("/{}", uuid_v7().as_simple()); - let (address, shutdown, server) = common::start_server(&path).await; - - let client = common::register(&address).await; - let hostname = uuid_v7().as_simple().to_string(); - let now = OffsetDateTime::now_utc(); - - let data1 = uuid_v7().as_simple().to_string(); - let data2 = uuid_v7().as_simple().to_string(); - - client - .post_history(&[ - AddHistoryRequest { - id: uuid_v7().as_simple().to_string(), - timestamp: now, - data: data1.clone(), - hostname: hostname.clone(), - }, - AddHistoryRequest { - id: uuid_v7().as_simple().to_string(), - timestamp: now, - data: data2.clone(), - hostname: hostname.clone(), - }, - ]) - .await - .unwrap(); - - let history = client - .get_history(OffsetDateTime::UNIX_EPOCH, OffsetDateTime::UNIX_EPOCH, None) - .await - .unwrap(); - - assert_eq!(history.history, vec![data1, data2]); - - shutdown.send(()).unwrap(); - server.await.unwrap(); -} diff --git a/crates/atuin/tests/users.rs b/crates/atuin/tests/users.rs deleted file mode 100644 index 95fb533b..00000000 --- a/crates/atuin/tests/users.rs +++ /dev/null @@ -1,121 +0,0 @@ -use atuin_common::utils::uuid_v7; - -mod common; - -#[tokio::test] -async fn registration() { - let path = format!("/{}", uuid_v7().as_simple()); - let (address, shutdown, server) = common::start_server(&path).await; - dbg!(&address); - - // -- REGISTRATION -- - - let username = uuid_v7().as_simple().to_string(); - let password = uuid_v7().as_simple().to_string(); - let client = common::register_inner(&address, &username, &password).await; - - // the session token works - let status = client.status().await.unwrap(); - assert_eq!(status.username, username); - - // -- LOGIN -- - - let client = common::login(&address, username.clone(), password).await; - - // the session token works - let status = client.status().await.unwrap(); - assert_eq!(status.username, username); - - shutdown.send(()).unwrap(); - server.await.unwrap(); -} - -#[tokio::test] -async fn change_password() { - let path = format!("/{}", uuid_v7().as_simple()); - let (address, shutdown, server) = common::start_server(&path).await; - - // -- REGISTRATION -- - - let username = uuid_v7().as_simple().to_string(); - let password = uuid_v7().as_simple().to_string(); - let client = common::register_inner(&address, &username, &password).await; - - // the session token works - let status = client.status().await.unwrap(); - assert_eq!(status.username, username); - - // -- PASSWORD CHANGE -- - - let current_password = password; - let new_password = uuid_v7().as_simple().to_string(); - let result = client - .change_password(current_password, new_password.clone()) - .await; - - // the password change request succeeded - assert!(result.is_ok()); - - // -- LOGIN -- - - let client = common::login(&address, username.clone(), new_password).await; - - // login with new password yields a working token - let status = client.status().await.unwrap(); - assert_eq!(status.username, username); - - shutdown.send(()).unwrap(); - server.await.unwrap(); -} - -#[tokio::test] -async fn multi_user_test() { - let path = format!("/{}", uuid_v7().as_simple()); - let (address, shutdown, server) = common::start_server(&path).await; - dbg!(&address); - - // -- REGISTRATION -- - - let user_one = uuid_v7().as_simple().to_string(); - let password_one = uuid_v7().as_simple().to_string(); - let client_one = common::register_inner(&address, &user_one, &password_one).await; - - // the session token works - let status = client_one.status().await.unwrap(); - assert_eq!(status.username, user_one); - - let user_two = uuid_v7().as_simple().to_string(); - let password_two = uuid_v7().as_simple().to_string(); - let client_two = common::register_inner(&address, &user_two, &password_two).await; - - // the session token works - let status = client_two.status().await.unwrap(); - assert_eq!(status.username, user_two); - - // check that we can change user one's password, and _this does not affect user two_ - - let current_password = password_one; - let new_password = uuid_v7().as_simple().to_string(); - let result = client_one - .change_password(current_password, new_password.clone()) - .await; - - // the password change request succeeded - assert!(result.is_ok()); - - // -- LOGIN -- - - let client_one = common::login(&address, user_one.clone(), new_password).await; - let client_two = common::login(&address, user_two.clone(), password_two).await; - - // login with new password yields a working token - let status = client_one.status().await.unwrap(); - assert_eq!(status.username, user_one); - assert_ne!(status.username, user_two); - - let status = client_two.status().await.unwrap(); - assert_eq!(status.username, user_two); - - shutdown.send(()).unwrap(); - server.await.unwrap(); -} diff --git a/crates/turtle/Cargo.toml b/crates/turtle/Cargo.toml new file mode 100644 index 00000000..87557905 --- /dev/null +++ b/crates/turtle/Cargo.toml @@ -0,0 +1,142 @@ +[package] +name = "atuin" +edition = "2024" +description = "atuin - magical shell history" +readme = "./README.md" + +rust-version = { workspace = true } +version = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } + +[features] +default = [ + "clipboard", + "daemon", + "hex", + "sync", + "client", +] + +clipboard = ["arboard"] +daemon = ["pty-proxy"] +pty-proxy = [] +client = [] +hex = ["dep:hex"] +sync = ["urlencoding", "reqwest", "sha2", "hex"] + +[dependencies] +argon2 = "0.5" +async-trait = "0.1.58" +atuin-nucleo-matcher = { workspace = true } +atuin-nucleo = { workspace = true } +axum = "0.8" +base64 = "0.22" +clap = { version = "4.5.7", features = ["derive"] } +clap_complete = "4.5.8" +clap_complete_nushell = "4.5.4" +colored = "2.0.4" +config = { version = "0.15.8", default-features = false, features = ["toml"] } +crossterm = {version = "0.29.0", features = ["use-dev-tty", "serde"] } +crypto_secretbox = "0.1.1" +dashmap = "6.1.0" +directories = "6.0.0" +eyre = "0.6" +fs-err = "3.1" +fs4 = "0.13.1" +futures = "0.3" +futures-util = "0.3" +fuzzy-matcher = "0.3.7" +generic-array = { version = "0.14", features = ["serde"] } +getrandom = "0.2" +glob-match = "0.2.1" +hex = { version = "0.4", optional = true } +humantime = "2.1.0" +hyper-util = "0.1" +imara-diff = "0.2" +indicatif = "0.18.0" +interim = { version = "0.2.0", features = ["time_0_3"] } +itertools = "0.14.0" +lasso = { version = "0.7", features = ["multi-threaded"] } +log = "0.4" +memchr = "2.7" +metrics = "0.24" +metrics-exporter-prometheus = { version = "0.18", default-features = false } +minijinja = "2.9.0" +minspan = "0.1.5" +norm = { version = "0.1.1", features = ["fzf-v2"] } +notify = "7" +open = "5" +palette = { version = "0.7.5", features = ["serializing"] } +pretty_assertions = "1.3.0" +prost = "0.14" +prost-types = "0.14" +rand = { version = "0.8.5", features = ["std"] } +ratatui = "0.30.0" +regex = "1.10.5" +reqwest = { version = "0.13", optional = true, features = ["json", "rustls-no-provider", "stream"], default-features = false } +rmp = { version = "0.8.14" } +rpassword = "7.0" +runtime-format = "0.1.3" +rustix = { version = "1.1.4", features = ["process", "fs"] } +rustls = { version = "0.23", default-features = false, features = [ "ring", "std", "tls12", ] } +rusty_paserk = { version = "0.5.0", default-features = false, features = [ "v4", "serde", ] } +rusty_paseto = { version = "0.8.0", default-features = false } +semver = "1.0.20" +serde = { version = "1.0.202", features = ["derive"] } +serde_json = "1.0.119" +serde_regex = "1.1.0" +serde_with = "3.8.1" +sha2 = { version = "0.10", optional = true } +shellexpand = "3" +shlex = "1.3.0" +sql-builder = "3" +sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "time", "postgres", "uuid", "sqlite", "regexp"] } +strum = { version = "0.27", features = ["strum_macros"] } +strum_macros = "0.27" +sysinfo = "0.30.7" +tempfile = { version = "3.19" } +thiserror = "2" +time = { version = "0.3.47", features = [ "serde-human-readable", "macros", "local-offset", "macros", "formatting", "parsing"] } +tokio = { version = "1", features = ["full"] } +tokio-stream = { version = "0.1.14", features = ["net"] } +toml_edit = "0.25.4" +tonic = "0.14" +tonic-prost = "0.14" +tonic-types = "0.14" +tower = "0.5" +tower-http = { version = "0.6", features = ["trace"] } +tracing = "0.1" +tracing-appender = "0.2" +tracing-subscriber = { version = "0.3", features = ["ansi", "fmt", "registry", "env-filter", "json"] } +typed-builder = "0.18.2" +unicode-segmentation = "1.11.0" +unicode-width = "0.2" +url = "2.5.2" +urlencoding = { version = "2.1.0", optional = true } +uuid = { version = "1.9", features = ["v4", "v7", "serde"] } +vt100 = "0.16" +whoami = "2.1.0" +xxhash-rust = { version = "0.8", features = ["xxh3"] } + +[target.'cfg(target_os = "linux")'.dependencies] +arboard = { version = "3.4", optional = true, default-features = false, features = [ "wayland-data-control", ] } +listenfd = "1.0.1" + +[target.'cfg(unix)'.dependencies] +daemonize = "0.5.0" +portable-pty = "0.9" +signal-hook = "0.3" + +[dev-dependencies] +tracing-tree = "0.4" +divan = "0.1.14" +tokio = { version = "1", features = ["full"] } +testing_logger = "0.1.1" + +[build-dependencies] +protox = "0.9" +tonic-build = "0.14" +tonic-prost-build = "0.14" diff --git a/crates/turtle/build.rs b/crates/turtle/build.rs new file mode 100644 index 00000000..5f26e26c --- /dev/null +++ b/crates/turtle/build.rs @@ -0,0 +1,39 @@ +use std::process::Command; +use std::{env, fs, path::PathBuf}; + +use protox::prost::Message; + +fn main() -> Result<(), std::io::Error> { + { + let output = Command::new("git").args(["rev-parse", "HEAD"]).output(); + + let sha = match output { + Ok(sha) => String::from_utf8(sha.stdout).unwrap(), + Err(_) => String::from("NO_GIT"), + }; + + println!("cargo:rustc-env=GIT_HASH={sha}"); + } + + { + let proto_paths = [ + "proto/history.proto", + "proto/search.proto", + "proto/control.proto", + "proto/semantic.proto", + ]; + let proto_include_dirs = ["proto"]; + + let file_descriptors = protox::compile(proto_paths, proto_include_dirs).unwrap(); + + let file_descriptor_path = PathBuf::from(env::var_os("OUT_DIR").expect("OUT_DIR not set")) + .join("file_descriptor_set.bin"); + fs::write(&file_descriptor_path, file_descriptors.encode_to_vec()).unwrap(); + + tonic_prost_build::configure() + .build_server(true) + .file_descriptor_set_path(&file_descriptor_path) + .skip_protoc_run() + .compile_protos(&proto_paths, &proto_include_dirs) + } +} diff --git a/crates/turtle/proto/control.proto b/crates/turtle/proto/control.proto new file mode 100644 index 00000000..06347902 --- /dev/null +++ b/crates/turtle/proto/control.proto @@ -0,0 +1,62 @@ +syntax = "proto3"; +package control; + +// The Control service allows external processes (CLI commands, etc.) +// to inject events into the running daemon. +service Control { + // Send an event to the daemon's event bus + rpc SendEvent(SendEventRequest) returns (SendEventResponse); +} + +message SendEventRequest { + oneof event { + // History was pruned - search index needs full rebuild + HistoryPrunedEvent history_pruned = 1; + + // Specific history items were deleted + HistoryDeletedEvent history_deleted = 2; + + // Request immediate sync + ForceSyncEvent force_sync = 3; + + // Settings have changed, reload if needed + SettingsReloadedEvent settings_reloaded = 4; + + // Request graceful shutdown + ShutdownEvent shutdown = 5; + + // History was rebuilt - search index needs full rebuild + HistoryRebuiltEvent history_rebuilt = 6; + } +} + +message SendEventResponse { + // Empty on success; errors come through gRPC status +} + +// Individual event message types + +message HistoryPrunedEvent { + // No fields needed - just signals that pruning happened +} + +message HistoryRebuiltEvent { + // No fields needed - just signals that rebuilding happened +} + +message HistoryDeletedEvent { + // IDs of deleted history items (UUIDs as strings) + repeated string ids = 1; +} + +message ForceSyncEvent { + // No fields needed - just triggers sync +} + +message SettingsReloadedEvent { + // No fields needed - components should re-read settings +} + +message ShutdownEvent { + // No fields needed - triggers graceful shutdown +} diff --git a/crates/turtle/proto/history.proto b/crates/turtle/proto/history.proto new file mode 100644 index 00000000..59c12471 --- /dev/null +++ b/crates/turtle/proto/history.proto @@ -0,0 +1,81 @@ +syntax = "proto3"; +package history; + +message StartHistoryRequest { + // If people are still using my software in ~530 years, they can figure out a u128 migration + uint64 timestamp = 1; // nanosecond unix epoch + string command = 2; + string cwd = 3; + string session = 4; + string hostname = 5; + string author = 6; + string intent = 7; +} + +message EndHistoryRequest { + string id = 1; + int64 exit = 2; + uint64 duration = 3; +} + +message StartHistoryReply { + string id = 1; + string version = 2; + uint32 protocol = 3; +} + +message EndHistoryReply { + string id = 1; + uint64 idx = 2; + string version = 3; + uint32 protocol = 4; +} + +message StatusRequest {} + +message StatusReply { + bool healthy = 1; + string version = 2; + uint32 pid = 3; + uint32 protocol = 4; +} + +message ShutdownRequest {} + +message ShutdownReply { + bool accepted = 1; +} + +message TailHistoryRequest {} + +enum HistoryEventKind { + HISTORY_EVENT_KIND_UNSPECIFIED = 0; + HISTORY_EVENT_KIND_STARTED = 1; + HISTORY_EVENT_KIND_ENDED = 2; +} + +message HistoryEntry { + uint64 timestamp = 1; // nanosecond unix epoch + string id = 2; + string command = 3; + string cwd = 4; + string session = 5; + string hostname = 6; + string author = 7; + string intent = 8; + int64 exit = 9; + int64 duration = 10; +} + +message TailHistoryReply { + HistoryEventKind kind = 1; + HistoryEntry history = 2; +} + +service History { + rpc StartHistory(StartHistoryRequest) returns (StartHistoryReply); + rpc EndHistory(EndHistoryRequest) returns (EndHistoryReply); + rpc TailHistory(TailHistoryRequest) returns (stream TailHistoryReply); + rpc Status(StatusRequest) returns (StatusReply); + rpc Shutdown(ShutdownRequest) returns (ShutdownReply); +} diff --git a/crates/turtle/proto/search.proto b/crates/turtle/proto/search.proto new file mode 100644 index 00000000..6b84acbd --- /dev/null +++ b/crates/turtle/proto/search.proto @@ -0,0 +1,35 @@ +syntax = "proto3"; +package search; + +enum FilterMode { + GLOBAL = 0; + HOST = 1; + SESSION = 2; + DIRECTORY = 3; + WORKSPACE = 4; + SESSION_PRELOAD = 5; +} + +message SearchContext { + string session_id = 1; + string cwd = 2; + string hostname = 3; + string host_id = 4; + optional string git_root = 5; +} + +message SearchRequest { + string query = 1; + uint64 query_id = 2; // Incrementing ID to match responses to queries + FilterMode filter_mode = 3; + SearchContext context = 4; +} + +message SearchResponse { + uint64 query_id = 1; // Echo back the query ID + repeated bytes ids = 2; +} + +service Search { + rpc Search(stream SearchRequest) returns (stream SearchResponse); +} diff --git a/crates/turtle/proto/semantic.proto b/crates/turtle/proto/semantic.proto new file mode 100644 index 00000000..07e550c8 --- /dev/null +++ b/crates/turtle/proto/semantic.proto @@ -0,0 +1,47 @@ +syntax = "proto3"; +package semantic; + +service Semantic { + rpc RecordCommands(stream CommandCapture) returns (RecordCommandsReply); + rpc CommandOutput(CommandOutputRequest) returns (CommandOutputReply); +} + +message CommandCapture { + string prompt = 1; + string command = 2; + string output = 3; + optional int32 exit_code = 4; + optional string history_id = 5; + optional string session_id = 6; + bool output_truncated = 7; + uint64 output_observed_bytes = 8; +} + +message RecordCommandsReply { + uint64 accepted = 1; +} + +message CommandOutputRequest { + string history_id = 1; + repeated OutputRange ranges = 2; +} + +message OutputRange { + int64 start = 1; + int64 end = 2; +} + +message OutputLine { + uint64 line_number = 1; + string content = 2; +} + +message CommandOutputReply { + bool found = 1; + string output = 2; + uint64 total_bytes = 3; + uint64 total_lines = 4; + repeated OutputLine lines = 5; + bool output_truncated = 6; + uint64 output_observed_bytes = 7; +} diff --git a/crates/turtle/src/atuin_client/api_client.rs b/crates/turtle/src/atuin_client/api_client.rs new file mode 100644 index 00000000..7955c2da --- /dev/null +++ b/crates/turtle/src/atuin_client/api_client.rs @@ -0,0 +1,438 @@ +use std::collections::HashMap; +use std::env; +use std::time::Duration; + +use eyre::{Result, bail, eyre}; +use reqwest::{ + Response, StatusCode, Url, + header::{AUTHORIZATION, HeaderMap, USER_AGENT}, +}; +use tracing::debug; + +use crate::atuin_common::{ + api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ATUIN_VERSION}, + record::{EncryptedData, HostId, Record, RecordIdx}, + tls::ensure_crypto_provider, +}; +use crate::atuin_common::{ + api::{ + AddHistoryRequest, ChangePasswordRequest, CountResponse, DeleteHistoryRequest, + ErrorResponse, LoginRequest, LoginResponse, MeResponse, RegisterResponse, StatusResponse, + SyncHistoryResponse, + }, + record::RecordStatus, +}; + +use semver::Version; +use time::OffsetDateTime; +use time::format_description::well_known::Rfc3339; + +use crate::atuin_client::{history::History, sync::hash_str, utils::get_host_user}; + +static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION"),); + +/// Authentication token for sync API requests. +/// +/// The sync API supports two authentication methods: +/// - `Bearer`: Hub API tokens (for users authenticated via Atuin Hub) +/// - `Token`: Legacy CLI session tokens (for users registered via CLI or self-hosted) +/// +/// When both are available, Hub tokens are preferred as they provide unified +/// authentication across CLI and Hub features. +#[derive(Debug, Clone)] +pub enum AuthToken { + /// Legacy CLI session token, used with "Token {token}" header + Token(String), +} + +impl AuthToken { + /// Format the token as an Authorization header value + fn to_header_value(&self) -> String { + match self { + AuthToken::Token(token) => format!("Token {token}"), + } + } +} + +pub struct Client<'a> { + sync_addr: &'a str, + client: reqwest::Client, +} + +fn make_url(address: &str, path: &str) -> Result { + // `join()` expects a trailing `/` in order to join paths + // e.g. it treats `http://host:port/subdir` as a file called `subdir` + let address = if address.ends_with("/") { + address + } else { + &format!("{address}/") + }; + + // passing a path with a leading `/` will cause `join()` to replace the entire URL path + let path = path.strip_prefix("/").unwrap_or(path); + + let url = Url::parse(address) + .map(|url| url.join(path))? + .map_err(|_| eyre!("invalid address"))?; + + Ok(url.to_string()) +} + +pub async fn register( + address: &str, + username: &str, + email: &str, + password: &str, +) -> Result { + ensure_crypto_provider(); + let mut map = HashMap::new(); + map.insert("username", username); + map.insert("email", email); + map.insert("password", password); + + let url = make_url(address, &format!("/user/{username}"))?; + let resp = reqwest::get(url).await?; + + if resp.status().is_success() { + bail!("username already in use"); + } + + let url = make_url(address, "/register")?; + let client = reqwest::Client::new(); + let resp = client + .post(url) + .header(USER_AGENT, APP_USER_AGENT) + .header(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION) + .json(&map) + .send() + .await?; + let resp = handle_resp_error(resp).await?; + + if !ensure_version(&resp)? { + bail!("could not register user due to version mismatch"); + } + + let session = resp.json::().await?; + Ok(session) +} + +pub async fn login(address: &str, req: LoginRequest) -> Result { + ensure_crypto_provider(); + let url = make_url(address, "/login")?; + let client = reqwest::Client::new(); + + let resp = client + .post(url) + .header(USER_AGENT, APP_USER_AGENT) + .json(&req) + .send() + .await?; + let resp = handle_resp_error(resp).await?; + + if !ensure_version(&resp)? { + bail!("Could not login due to version mismatch"); + } + + let session = resp.json::().await?; + Ok(session) +} + +pub fn ensure_version(response: &Response) -> Result { + let version = response.headers().get(ATUIN_HEADER_VERSION); + + let version = if let Some(version) = version { + match version.to_str() { + Ok(v) => Version::parse(v), + Err(e) => bail!("failed to parse server version: {:?}", e), + } + } else { + bail!("Server not reporting its version: it is either too old or unhealthy"); + }?; + + // If the client is newer than the server + if version.major < ATUIN_VERSION.major { + println!( + "Atuin version mismatch! In order to successfully sync, the server needs to run a newer version of Atuin" + ); + println!("Client: {ATUIN_CARGO_VERSION}"); + println!("Server: {version}"); + + return Ok(false); + } + + Ok(true) +} + +async fn handle_resp_error(resp: Response) -> Result { + let status = resp.status(); + let url = resp.url().to_string(); + + if status == StatusCode::SERVICE_UNAVAILABLE { + bail!( + "Service unavailable: check https://status.atuin.sh (or get in touch with your host)" + ); + } + + if status == StatusCode::TOO_MANY_REQUESTS { + bail!("Rate limited; please wait before doing that again"); + } + + if !status.is_success() { + if let Ok(error) = resp.json::().await { + let reason = error.reason; + + if status.is_client_error() { + bail!("Invalid request to the service at {url}, {status} - {reason}.") + } + + bail!( + "There was an error with the atuin sync service at {url}, server error {status}: {reason}.\nIf the problem persists, contact the host" + ) + } + + bail!( + "There was an error with the atuin sync service at {url}, Status {status:?}.\nIf the problem persists, contact the host" + ) + } + + Ok(resp) +} + +impl<'a> Client<'a> { + pub fn new( + sync_addr: &'a str, + auth: AuthToken, + connect_timeout: u64, + timeout: u64, + ) -> Result { + ensure_crypto_provider(); + let mut headers = HeaderMap::new(); + headers.insert(AUTHORIZATION, auth.to_header_value().parse()?); + + // used for semver server check + headers.insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse()?); + + Ok(Client { + sync_addr, + client: reqwest::Client::builder() + .user_agent(APP_USER_AGENT) + .default_headers(headers) + .connect_timeout(Duration::new(connect_timeout, 0)) + .timeout(Duration::new(timeout, 0)) + .build()?, + }) + } + + pub async fn count(&self) -> Result { + let url = make_url(self.sync_addr, "/sync/count")?; + let url = Url::parse(url.as_str())?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + if !ensure_version(&resp)? { + bail!("could not sync due to version mismatch"); + } + + if resp.status() != StatusCode::OK { + bail!("failed to get count (are you logged in?)"); + } + + let count = resp.json::().await?; + + Ok(count.count) + } + + pub async fn status(&self) -> Result { + let url = make_url(self.sync_addr, "/sync/status")?; + let url = Url::parse(url.as_str())?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + if !ensure_version(&resp)? { + bail!("could not sync due to version mismatch"); + } + + let status = resp.json::().await?; + + Ok(status) + } + + pub async fn me(&self) -> Result { + let url = make_url(self.sync_addr, "/api/v0/me")?; + let url = Url::parse(url.as_str())?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + let status = resp.json::().await?; + + Ok(status) + } + + pub async fn get_history( + &self, + sync_ts: OffsetDateTime, + history_ts: OffsetDateTime, + host: Option, + ) -> Result { + let host = host.unwrap_or_else(|| hash_str(&get_host_user())); + + let url = make_url( + self.sync_addr, + &format!( + "/sync/history?sync_ts={}&history_ts={}&host={}", + urlencoding::encode(sync_ts.format(&Rfc3339)?.as_str()), + urlencoding::encode(history_ts.format(&Rfc3339)?.as_str()), + host, + ), + )?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + let history = resp.json::().await?; + Ok(history) + } + + pub async fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> { + let url = make_url(self.sync_addr, "/history")?; + let url = Url::parse(url.as_str())?; + + let resp = self.client.post(url).json(history).send().await?; + handle_resp_error(resp).await?; + + Ok(()) + } + + pub async fn delete_history(&self, h: History) -> Result<()> { + let url = make_url(self.sync_addr, "/history")?; + let url = Url::parse(url.as_str())?; + + let resp = self + .client + .delete(url) + .json(&DeleteHistoryRequest { + client_id: h.id.to_string(), + }) + .send() + .await?; + + handle_resp_error(resp).await?; + + Ok(()) + } + + pub async fn delete_store(&self) -> Result<()> { + let url = make_url(self.sync_addr, "/api/v0/store")?; + let url = Url::parse(url.as_str())?; + + let resp = self.client.delete(url).send().await?; + + handle_resp_error(resp).await?; + + Ok(()) + } + + pub async fn post_records(&self, records: &[Record]) -> Result<()> { + let url = make_url(self.sync_addr, "/api/v0/record")?; + let url = Url::parse(url.as_str())?; + + debug!("uploading {} records to {url}", records.len()); + + let resp = self.client.post(url).json(records).send().await?; + handle_resp_error(resp).await?; + + Ok(()) + } + + pub async fn next_records( + &self, + host: HostId, + tag: String, + start: RecordIdx, + count: u64, + ) -> Result>> { + debug!("fetching record/s from host {}/{}/{}", host.0, tag, start); + + let url = make_url( + self.sync_addr, + &format!( + "/api/v0/record/next?host={}&tag={}&count={}&start={}", + host.0, tag, count, start + ), + )?; + + let url = Url::parse(url.as_str())?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + let records = resp.json::>>().await?; + + Ok(records) + } + + pub async fn record_status(&self) -> Result { + let url = make_url(self.sync_addr, "/api/v0/record")?; + let url = Url::parse(url.as_str())?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + if !ensure_version(&resp)? { + bail!("could not sync records due to version mismatch"); + } + + let index = resp.json().await?; + + debug!("got remote index {index:?}"); + + Ok(index) + } + + pub async fn delete(&self) -> Result<()> { + let url = make_url(self.sync_addr, "/account")?; + let url = Url::parse(url.as_str())?; + + let resp = self.client.delete(url).send().await?; + + if resp.status() == 403 { + bail!("invalid login details"); + } else if resp.status() == 200 { + Ok(()) + } else { + bail!("Unknown error"); + } + } + + pub async fn change_password( + &self, + current_password: String, + new_password: String, + ) -> Result<()> { + let url = make_url(self.sync_addr, "/account/password")?; + let url = Url::parse(url.as_str())?; + + let resp = self + .client + .patch(url) + .json(&ChangePasswordRequest { + current_password, + new_password, + }) + .send() + .await?; + + if resp.status() == 401 { + bail!("current password is incorrect") + } else if resp.status() == 403 { + bail!("invalid login details"); + } else if resp.status() == 200 { + Ok(()) + } else { + bail!("Unknown error"); + } + } +} diff --git a/crates/turtle/src/atuin_client/auth.rs b/crates/turtle/src/atuin_client/auth.rs new file mode 100644 index 00000000..b770c488 --- /dev/null +++ b/crates/turtle/src/atuin_client/auth.rs @@ -0,0 +1,223 @@ +use async_trait::async_trait; +use eyre::{Context, Result, bail}; +use reqwest::{Url, header::USER_AGENT}; + +use crate::{ + atuin_client::api_client, + atuin_common::{ + api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ChangePasswordRequest, LoginRequest}, + tls::ensure_crypto_provider, + }, +}; + +use crate::atuin_client::settings::Settings; + +static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION")); + +/// Result of an auth operation that may require 2FA. +pub enum AuthResponse { + /// Operation succeeded; for login/register, contains the session token. + /// `auth_type` indicates the kind of token: `Some("hub")` for Hub API + /// tokens (prefixed `atapi_`), `Some("cli")` for legacy CLI session + /// tokens. `None` when the server didn't include the field (old servers). + Success { + session: String, + auth_type: Option, + }, + /// Two-factor authentication is required; the caller should prompt for a + /// TOTP code and retry with it. + TwoFactorRequired, +} + +/// Result of a mutating account operation that may require 2FA. +pub enum MutateResponse { + /// Operation completed successfully. + Success, + /// Two-factor authentication is required; the caller should prompt for a + /// TOTP code and retry. + TwoFactorRequired, +} + +/// Abstraction over the legacy (Rust sync server) and Hub auth APIs. +/// +/// CLI commands use this trait so they don't need to know which backend is +/// active — they just prompt for input and call these methods. +#[async_trait] +pub trait AuthClient: Send + Sync { + /// Log in with username + password, optionally providing a TOTP code. + async fn login(&self, username: &str, password: &str) -> Result; + + /// Register a new account. + async fn register(&self, username: &str, email: &str, password: &str) -> Result; + + /// Change the account password, optionally providing a TOTP code. + async fn change_password( + &self, + current_password: &str, + new_password: &str, + totp_code: Option<&str>, + ) -> Result; + + /// Delete the account, requiring the current password and optionally a TOTP code. + async fn delete_account( + &self, + password: &str, + totp_code: Option<&str>, + ) -> Result; +} + +/// Resolve the appropriate [`AuthClient`] for the current settings. +pub async fn auth_client(settings: &Settings) -> Box { + Box::new(LegacyAuthClient::new( + &settings.sync_address, + settings.session_token().await.ok(), + settings.network_connect_timeout, + settings.network_timeout, + )) as Box +} + +// --------------------------------------------------------------------------- +// Legacy backend — talks to the Rust sync server +// --------------------------------------------------------------------------- + +pub struct LegacyAuthClient { + address: String, + session_token: Option, + connect_timeout: u64, + timeout: u64, +} + +impl LegacyAuthClient { + pub fn new( + address: &str, + session_token: Option, + connect_timeout: u64, + timeout: u64, + ) -> Self { + Self { + address: address.to_string(), + session_token, + connect_timeout, + timeout, + } + } + + fn authenticated_client(&self) -> Result { + let token = self + .session_token + .as_deref() + .ok_or_else(|| eyre::eyre!("Not logged in"))?; + + ensure_crypto_provider(); + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + reqwest::header::AUTHORIZATION, + format!("Token {token}").parse()?, + ); + headers.insert(USER_AGENT, APP_USER_AGENT.parse()?); + headers.insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse()?); + + Ok(reqwest::Client::builder() + .default_headers(headers) + .connect_timeout(std::time::Duration::new(self.connect_timeout, 0)) + .timeout(std::time::Duration::new(self.timeout, 0)) + .build()?) + } +} + +#[async_trait] +impl AuthClient for LegacyAuthClient { + async fn login(&self, username: &str, password: &str) -> Result { + // The legacy server has no 2FA support; totp_code is ignored. + let resp = api_client::login( + &self.address, + LoginRequest { + username: username.to_string(), + password: password.to_string(), + }, + ) + .await?; + + Ok(AuthResponse::Success { + session: resp.session, + auth_type: resp.auth.or(Some("cli".into())), + }) + } + + async fn register(&self, username: &str, email: &str, password: &str) -> Result { + let resp = api_client::register(&self.address, username, email, password).await?; + Ok(AuthResponse::Success { + session: resp.session, + auth_type: resp.auth.or(Some("cli".into())), + }) + } + + async fn change_password( + &self, + current_password: &str, + new_password: &str, + _totp_code: Option<&str>, + ) -> Result { + let client = self.authenticated_client()?; + let url = make_url(&self.address, "/account/password")?; + + let resp = client + .patch(&url) + .json(&ChangePasswordRequest { + current_password: current_password.to_string(), + new_password: new_password.to_string(), + }) + .send() + .await?; + + match resp.status().as_u16() { + 200 => Ok(MutateResponse::Success), + 401 => bail!("current password is incorrect"), + 403 => bail!("invalid login details"), + _ => bail!("unknown error"), + } + } + + async fn delete_account( + &self, + password: &str, + _totp_code: Option<&str>, + ) -> Result { + let client = self.authenticated_client()?; + let url = make_url(&self.address, "/account")?; + + let resp = client + .delete(&url) + .json(&serde_json::json!({ "password": password })) + .send() + .await?; + + match resp.status().as_u16() { + 200 => Ok(MutateResponse::Success), + 401 => bail!("password is incorrect"), + 403 => bail!("invalid login details"), + _ => bail!("unknown error"), + } + } +} + +// --------------------------------------------------------------------------- +// Shared helpers +// --------------------------------------------------------------------------- + +fn make_url(address: &str, path: &str) -> Result { + let address = if address.ends_with('/') { + address.to_string() + } else { + format!("{address}/") + }; + + let path = path.strip_prefix('/').unwrap_or(path); + + let url = Url::parse(&address) + .context("failed to parse server address")? + .join(path) + .context("failed to join URL path")?; + + Ok(url.to_string()) +} diff --git a/crates/turtle/src/atuin_client/database.rs b/crates/turtle/src/atuin_client/database.rs new file mode 100644 index 00000000..75b1200c --- /dev/null +++ b/crates/turtle/src/atuin_client/database.rs @@ -0,0 +1,1526 @@ +use std::{ + env, + path::{Path, PathBuf}, + str::FromStr, + time::Duration, +}; + +use crate::atuin_client::history::{AUTHOR_FILTER_ALL_AGENT, AUTHOR_FILTER_ALL_USER, KNOWN_AGENTS}; +use crate::atuin_common::utils; +use async_trait::async_trait; +use fs_err as fs; +use itertools::Itertools; +use rand::{Rng, distributions::Alphanumeric}; +use sql_builder::{SqlBuilder, SqlName, bind::Bind, esc, quote}; +use sqlx::{ + Result, Row, + sqlite::{ + SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow, + SqliteSynchronous, + }, +}; +use time::OffsetDateTime; +use tracing::debug; +use uuid::Uuid; + +use crate::atuin_client::{ + history::{HistoryId, HistoryStats}, + utils::get_host_user, +}; + +use super::{ + history::History, + ordering, + settings::{FilterMode, SearchMode, Settings}, +}; + +#[derive(Clone)] +pub struct Context { + pub session: String, + pub cwd: String, + pub hostname: String, + pub host_id: String, + pub git_root: Option, +} + +#[derive(Default, Clone)] +pub struct OptFilters { + pub exit: Option, + pub exclude_exit: Option, + pub cwd: Option, + pub exclude_cwd: Option, + pub before: Option, + pub after: Option, + pub limit: Option, + pub offset: Option, + pub reverse: bool, + pub include_duplicates: bool, + /// Author filter. Supports special values `$all-user` and `$all-agent`. + pub authors: Vec, +} + +pub async fn current_context() -> eyre::Result { + let session = env::var("ATUIN_SESSION").map_err(|_| { + eyre::eyre!("Failed to find $ATUIN_SESSION in the environment. Check that you have correctly set up your shell.") + })?; + let hostname = get_host_user(); + let cwd = utils::get_current_dir(); + let host_id = Settings::host_id().await?; + let git_root = utils::in_git_repo(cwd.as_str()); + + Ok(Context { + session, + hostname, + cwd, + git_root, + host_id: host_id.0.as_simple().to_string(), + }) +} + +impl Context { + pub fn from_history(entry: &History) -> Self { + Context { + session: entry.session.to_string(), + cwd: entry.cwd.to_string(), + hostname: entry.hostname.to_string(), + host_id: String::new(), + git_root: utils::in_git_repo(entry.cwd.as_str()), + } + } +} + +/// Each entry is OR'd: `$all-user` → NOT IN agents, `$all-agent` → IN agents, literal → exact match. +fn apply_author_filter(sql: &mut SqlBuilder, authors: &[String]) { + let mut conditions: Vec = Vec::new(); + let agent_list: String = KNOWN_AGENTS.iter().map(quote).join(", "); + let author_expr = "CASE \ + WHEN author IS NULL OR trim(author) = '' THEN \ + CASE \ + WHEN instr(hostname, ':') > 0 THEN substr(hostname, instr(hostname, ':') + 1) \ + ELSE hostname \ + END \ + ELSE author \ + END"; + + for author in authors { + match author.as_str() { + AUTHOR_FILTER_ALL_USER => { + conditions.push(format!("{author_expr} NOT IN ({agent_list})")); + } + AUTHOR_FILTER_ALL_AGENT => { + conditions.push(format!("{author_expr} IN ({agent_list})")); + } + literal => { + conditions.push(format!("{author_expr} = {}", quote(literal))); + } + } + } + + if !conditions.is_empty() { + sql.and_where(format!("({})", conditions.join(" OR "))); + } +} + +fn get_session_start_time(session_id: &str) -> Option { + if let Ok(uuid) = Uuid::parse_str(session_id) + && let Some(timestamp) = uuid.get_timestamp() + { + let (seconds, nanos) = timestamp.to_unix(); + return Some(seconds as i64 * 1_000_000_000 + nanos as i64); + } + None +} + +#[async_trait] +pub trait Database: Send + Sync + 'static { + async fn save(&self, h: &History) -> Result<()>; + async fn save_bulk(&self, h: &[History]) -> Result<()>; + + async fn load(&self, id: &str) -> Result>; + async fn list( + &self, + filters: &[FilterMode], + context: &Context, + max: Option, + unique: bool, + include_deleted: bool, + ) -> Result>; + async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result>; + + async fn update(&self, h: &History) -> Result<()>; + async fn history_count(&self, include_deleted: bool) -> Result; + + async fn last(&self) -> Result>; + async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result>; + + async fn delete(&self, h: History) -> Result<()>; + async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()>; + async fn deleted(&self) -> Result>; + + // Yes I know, it's a lot. + // Could maybe break it down to a searchparams struct or smth but that feels a little... pointless. + // Been debating maybe a DSL for search? eg "before:time limit:1 the query" + #[expect(clippy::too_many_arguments)] + async fn search( + &self, + search_mode: SearchMode, + filter: FilterMode, + context: &Context, + query: &str, + filter_options: OptFilters, + ) -> Result>; + + async fn query_history(&self, query: &str) -> Result>; + + async fn all_with_count(&self) -> Result>; + + fn all_paged(&self, page_size: usize, include_deleted: bool, unique: bool) -> Paged; + + async fn stats(&self, h: &History) -> Result; + + async fn get_dups(&self, before: i64, dupkeep: u32) -> Result>; + + fn clone_boxed(&self) -> Box; +} + +// Intended for use on a developer machine and not a sync server. +// TODO: implement IntoIterator +#[derive(Debug, Clone)] +pub struct Sqlite { + pub pool: SqlitePool, +} + +impl Sqlite { + pub async fn new(path: impl AsRef, timeout: f64) -> Result { + let path = path.as_ref(); + debug!("opening sqlite database at {path:?}"); + + if utils::broken_symlink(path) { + eprintln!( + "Atuin: Sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement." + ); + std::process::exit(1); + } + + if !path.exists() + && let Some(dir) = path.parent() + { + fs::create_dir_all(dir)?; + } + + let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())? + .journal_mode(SqliteJournalMode::Wal) + .optimize_on_close(true, None) + .synchronous(SqliteSynchronous::Normal) + .with_regexp() + .create_if_missing(true); + + let pool = SqlitePoolOptions::new() + .acquire_timeout(Duration::from_secs_f64(timeout)) + .connect_with(opts) + .await?; + + Self::setup_db(&pool).await?; + Ok(Self { pool }) + } + + pub async fn sqlite_version(&self) -> Result { + sqlx::query_scalar("SELECT sqlite_version()") + .fetch_one(&self.pool) + .await + } + + async fn setup_db(pool: &SqlitePool) -> Result<()> { + debug!("running sqlite database setup"); + + sqlx::migrate!("./migrations").run(pool).await?; + + Ok(()) + } + + async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, h: &History) -> Result<()> { + sqlx::query( + "insert or ignore into history(id, timestamp, duration, exit, command, cwd, session, hostname, author, intent, deleted_at) + values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)", + ) + .bind(h.id.0.as_str()) + .bind(h.timestamp.unix_timestamp_nanos() as i64) + .bind(h.duration) + .bind(h.exit) + .bind(h.command.as_str()) + .bind(h.cwd.as_str()) + .bind(h.session.as_str()) + .bind(h.hostname.as_str()) + .bind(h.author.as_str()) + .bind(h.intent.as_deref()) + .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64)) + .execute(&mut **tx) + .await?; + + Ok(()) + } + + async fn delete_row_raw( + tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, + id: HistoryId, + ) -> Result<()> { + sqlx::query("delete from history where id = ?1") + .bind(id.0.as_str()) + .execute(&mut **tx) + .await?; + + Ok(()) + } + + fn query_history(row: SqliteRow) -> History { + let deleted_at: Option = row.get("deleted_at"); + let hostname: String = row.get("hostname"); + let author: Option = row.try_get("author").ok().flatten(); + let author = author + .filter(|author| !author.trim().is_empty()) + .unwrap_or_else(|| History::author_from_hostname(hostname.as_str())); + let intent: Option = row.try_get("intent").ok().flatten(); + let intent = intent.filter(|intent| !intent.trim().is_empty()); + + History::from_db() + .id(row.get("id")) + .timestamp( + OffsetDateTime::from_unix_timestamp_nanos(row.get::("timestamp") as i128) + .unwrap(), + ) + .duration(row.get("duration")) + .exit(row.get("exit")) + .command(row.get("command")) + .cwd(row.get("cwd")) + .session(row.get("session")) + .hostname(hostname) + .author(author) + .intent(intent) + .deleted_at( + deleted_at.and_then(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128).ok()), + ) + .build() + .into() + } +} + +#[async_trait] +impl Database for Sqlite { + async fn save(&self, h: &History) -> Result<()> { + debug!("saving history to sqlite"); + let mut tx = self.pool.begin().await?; + Self::save_raw(&mut tx, h).await?; + tx.commit().await?; + + Ok(()) + } + + async fn save_bulk(&self, h: &[History]) -> Result<()> { + debug!("saving history to sqlite"); + + let mut tx = self.pool.begin().await?; + + for i in h { + Self::save_raw(&mut tx, i).await?; + } + + tx.commit().await?; + + Ok(()) + } + + async fn load(&self, id: &str) -> Result> { + debug!("loading history item {}", id); + + let res = sqlx::query("select * from history where id = ?1") + .bind(id) + .map(Self::query_history) + .fetch_optional(&self.pool) + .await?; + + Ok(res) + } + + async fn update(&self, h: &History) -> Result<()> { + debug!("updating sqlite history"); + + sqlx::query( + "update history + set timestamp = ?2, duration = ?3, exit = ?4, command = ?5, cwd = ?6, session = ?7, hostname = ?8, author = ?9, intent = ?10, deleted_at = ?11 + where id = ?1", + ) + .bind(h.id.0.as_str()) + .bind(h.timestamp.unix_timestamp_nanos() as i64) + .bind(h.duration) + .bind(h.exit) + .bind(h.command.as_str()) + .bind(h.cwd.as_str()) + .bind(h.session.as_str()) + .bind(h.hostname.as_str()) + .bind(h.author.as_str()) + .bind(h.intent.as_deref()) + .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64)) + .execute(&self.pool) + .await?; + + Ok(()) + } + + // make a unique list, that only shows the *newest* version of things + async fn list( + &self, + filters: &[FilterMode], + context: &Context, + max: Option, + unique: bool, + include_deleted: bool, + ) -> Result> { + debug!("listing history"); + + let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted()); + query.field("*").order_desc("timestamp"); + if !include_deleted { + query.and_where_is_null("deleted_at"); + } + + let git_root = if let Some(git_root) = context.git_root.clone() { + git_root.to_str().unwrap_or("/").to_string() + } else { + context.cwd.clone() + }; + + let session_start = get_session_start_time(&context.session); + + for filter in filters { + match filter { + FilterMode::Global => &mut query, + FilterMode::Host => query.and_where_eq("hostname", quote(&context.hostname)), + FilterMode::Session => query.and_where_eq("session", quote(&context.session)), + FilterMode::SessionPreload => { + query.and_where_eq("session", quote(&context.session)); + if let Some(session_start) = session_start { + query.or_where_lt("timestamp", session_start); + } + &mut query + } + FilterMode::Directory => query.and_where_eq("cwd", quote(&context.cwd)), + FilterMode::Workspace => query.and_where_like_left("cwd", &git_root), + }; + } + + if unique { + query.group_by("command").having("max(timestamp)"); + } + + if let Some(max) = max { + query.limit(max); + } + + let query = query.sql().expect("bug in list query. please report"); + + let res = sqlx::query(&query) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result> { + debug!("listing history from {:?} to {:?}", from, to); + + let res = sqlx::query( + "select * from history where timestamp >= ?1 and timestamp <= ?2 order by timestamp asc", + ) + .bind(from.unix_timestamp_nanos() as i64) + .bind(to.unix_timestamp_nanos() as i64) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn last(&self) -> Result> { + let res = sqlx::query( + "select * from history where duration >= 0 order by timestamp desc limit 1", + ) + .map(Self::query_history) + .fetch_optional(&self.pool) + .await?; + + Ok(res) + } + + async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result> { + let res = sqlx::query( + "select * from history where timestamp < ?1 order by timestamp desc limit ?2", + ) + .bind(timestamp.unix_timestamp_nanos() as i64) + .bind(count) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn deleted(&self) -> Result> { + let res = sqlx::query("select * from history where deleted_at is not null") + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn history_count(&self, include_deleted: bool) -> Result { + let query = if include_deleted { + "select count(1) from history" + } else { + "select count(1) from history where deleted_at is null" + }; + + let res: (i64,) = sqlx::query_as(query).fetch_one(&self.pool).await?; + Ok(res.0) + } + + async fn search( + &self, + search_mode: SearchMode, + filter: FilterMode, + context: &Context, + query: &str, + filter_options: OptFilters, + ) -> Result> { + let mut sql = SqlBuilder::select_from("history"); + + if !filter_options.include_duplicates { + sql.group_by("command").having("max(timestamp)"); + } + + if let Some(limit) = filter_options.limit { + sql.limit(limit); + } + + if let Some(offset) = filter_options.offset { + sql.offset(offset); + } + + if filter_options.reverse { + sql.order_asc("timestamp"); + } else { + sql.order_desc("timestamp"); + } + + let git_root = if let Some(git_root) = context.git_root.clone() { + git_root.to_str().unwrap_or("/").to_string() + } else { + context.cwd.clone() + }; + + let session_start = get_session_start_time(&context.session); + + match filter { + FilterMode::Global => &mut sql, + FilterMode::Host => { + sql.and_where_eq("lower(hostname)", quote(context.hostname.to_lowercase())) + } + FilterMode::Session => sql.and_where_eq("session", quote(&context.session)), + FilterMode::SessionPreload => { + sql.and_where_eq("session", quote(&context.session)); + if let Some(session_start) = session_start { + sql.or_where_lt("timestamp", session_start); + } + &mut sql + } + FilterMode::Directory => sql.and_where_eq("cwd", quote(&context.cwd)), + FilterMode::Workspace => sql.and_where_like_left("cwd", git_root), + }; + + let orig_query = query; + + let mut regexes = Vec::new(); + match search_mode { + SearchMode::Prefix => sql.and_where_like_left("command", query.replace('*', "%")), + _ => { + let mut is_or = false; + for token in QueryTokenizer::new(query) { + // TODO smart case mode could be made configurable like in fzf + let (is_glob, glob) = if token.has_uppercase() { + (true, "*") + } else { + (false, "%") + }; + let param = match token { + QueryToken::Regex(r) => { + regexes.push(String::from(r)); + continue; + } + QueryToken::Or => { + if !is_or { + is_or = true; + continue; + } else { + format!("{glob}|{glob}") + } + } + QueryToken::MatchStart(term, _) => { + format!("{term}{glob}") + } + QueryToken::MatchEnd(term, _) => { + format!("{glob}{term}") + } + QueryToken::MatchFull(term, _) => { + format!("{glob}{term}{glob}") + } + QueryToken::Match(term, _) => { + if search_mode == SearchMode::FullText { + format!("{glob}{term}{glob}") + } else { + term.split("").join(glob) + } + } + }; + + sql.fuzzy_condition("command", param, token.is_inverse(), is_glob, is_or); + is_or = false; + } + + &mut sql + } + }; + + for regex in regexes { + sql.and_where("command regexp ?".bind(®ex)); + } + + filter_options + .exit + .map(|exit| sql.and_where_eq("exit", exit)); + + filter_options + .exclude_exit + .map(|exclude_exit| sql.and_where_ne("exit", exclude_exit)); + + filter_options + .cwd + .map(|cwd| sql.and_where_eq("cwd", quote(cwd))); + + filter_options + .exclude_cwd + .map(|exclude_cwd| sql.and_where_ne("cwd", quote(exclude_cwd))); + + filter_options.before.map(|before| { + interim::parse_date_string( + before.as_str(), + OffsetDateTime::now_utc(), + interim::Dialect::Uk, + ) + .map(|before| { + sql.and_where_lt("timestamp", quote(before.unix_timestamp_nanos() as i64)) + }) + }); + + filter_options.after.map(|after| { + interim::parse_date_string( + after.as_str(), + OffsetDateTime::now_utc(), + interim::Dialect::Uk, + ) + .map(|after| sql.and_where_gt("timestamp", quote(after.unix_timestamp_nanos() as i64))) + }); + + if !filter_options.authors.is_empty() { + apply_author_filter(&mut sql, &filter_options.authors); + } + + sql.and_where_is_null("deleted_at"); + + let query = sql.sql().expect("bug in search query. please report"); + + let res = sqlx::query(&query) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(ordering::reorder_fuzzy(search_mode, orig_query, res)) + } + + async fn query_history(&self, query: &str) -> Result> { + let res = sqlx::query(query) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn all_with_count(&self) -> Result> { + debug!("listing history"); + + let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted()); + + query + .fields(&[ + "id", + "max(timestamp) as timestamp", + "max(duration) as duration", + "exit", + "command", + "deleted_at", + "null as author", + "null as intent", + "group_concat(cwd, ':') as cwd", + "group_concat(session) as session", + "group_concat(hostname, ',') as hostname", + "count(*) as count", + ]) + .group_by("command") + .group_by("exit") + .and_where("deleted_at is null") + .order_desc("timestamp"); + + let query = query.sql().expect("bug in list query. please report"); + + let res = sqlx::query(&query) + .map(|row: SqliteRow| { + let count: i32 = row.get("count"); + (Self::query_history(row), count) + }) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + fn all_paged(&self, page_size: usize, include_deleted: bool, unique: bool) -> Paged { + Paged::new(Box::new(self.clone()), page_size, include_deleted, unique) + } + + // deleted_at doesn't mean the actual time that the user deleted it, + // but the time that the system marks it as deleted + async fn delete(&self, mut h: History) -> Result<()> { + let now = OffsetDateTime::now_utc(); + h.command = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(32) + .map(char::from) + .collect(); // overwrite with random string + h.deleted_at = Some(now); // delete it + + self.update(&h).await?; // save it + + Ok(()) + } + + async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()> { + let mut tx = self.pool.begin().await?; + + for id in ids { + Self::delete_row_raw(&mut tx, id.clone()).await?; + } + + tx.commit().await?; + + Ok(()) + } + + async fn stats(&self, h: &History) -> Result { + // We select the previous in the session by time + let mut prev = SqlBuilder::select_from("history"); + prev.field("*") + .and_where("timestamp < ?1") + .and_where("session = ?2") + .order_by("timestamp", true) + .limit(1); + + let mut next = SqlBuilder::select_from("history"); + next.field("*") + .and_where("timestamp > ?1") + .and_where("session = ?2") + .order_by("timestamp", false) + .limit(1); + + let mut total = SqlBuilder::select_from("history"); + total.field("count(1)").and_where("command = ?1"); + + let mut average = SqlBuilder::select_from("history"); + average.field("avg(duration)").and_where("command = ?1"); + + let mut exits = SqlBuilder::select_from("history"); + exits + .fields(&["exit", "count(1) as count"]) + .and_where("command = ?1") + .group_by("exit"); + + // rewrite the following with sqlbuilder + let mut day_of_week = SqlBuilder::select_from("history"); + day_of_week + .fields(&[ + "strftime('%w', ROUND(timestamp / 1000000000), 'unixepoch') AS day_of_week", + "count(1) as count", + ]) + .and_where("command = ?1") + .group_by("day_of_week"); + + // Intentionally format the string with 01 hardcoded. We want the average runtime for the + // _entire month_, but will later parse it as a datetime for sorting + // Sqlite has no datetime so we cannot do it there, and otherwise sorting will just be a + // string sort, which won't be correct. + let mut duration_over_time = SqlBuilder::select_from("history"); + duration_over_time + .fields(&[ + "strftime('01-%m-%Y', ROUND(timestamp / 1000000000), 'unixepoch') AS month_year", + "avg(duration) as duration", + ]) + .and_where("command = ?1") + .group_by("month_year") + .having("duration > 0"); + + let prev = prev.sql().expect("issue in stats previous query"); + let next = next.sql().expect("issue in stats next query"); + let total = total.sql().expect("issue in stats average query"); + let average = average.sql().expect("issue in stats previous query"); + let exits = exits.sql().expect("issue in stats exits query"); + let day_of_week = day_of_week.sql().expect("issue in stats day of week query"); + let duration_over_time = duration_over_time + .sql() + .expect("issue in stats duration over time query"); + + let prev = sqlx::query(&prev) + .bind(h.timestamp.unix_timestamp_nanos() as i64) + .bind(&h.session) + .map(Self::query_history) + .fetch_optional(&self.pool) + .await?; + + let next = sqlx::query(&next) + .bind(h.timestamp.unix_timestamp_nanos() as i64) + .bind(&h.session) + .map(Self::query_history) + .fetch_optional(&self.pool) + .await?; + + let total: (i64,) = sqlx::query_as(&total) + .bind(&h.command) + .fetch_one(&self.pool) + .await?; + + let average: (f64,) = sqlx::query_as(&average) + .bind(&h.command) + .fetch_one(&self.pool) + .await?; + + let exits: Vec<(i64, i64)> = sqlx::query_as(&exits) + .bind(&h.command) + .fetch_all(&self.pool) + .await?; + + let day_of_week: Vec<(String, i64)> = sqlx::query_as(&day_of_week) + .bind(&h.command) + .fetch_all(&self.pool) + .await?; + + let duration_over_time: Vec<(String, f64)> = sqlx::query_as(&duration_over_time) + .bind(&h.command) + .fetch_all(&self.pool) + .await?; + + let duration_over_time = duration_over_time + .iter() + .map(|f| (f.0.clone(), f.1.round() as i64)) + .collect(); + + Ok(HistoryStats { + next, + previous: prev, + total: total.0 as u64, + average_duration: average.0 as u64, + exits, + day_of_week, + duration_over_time, + }) + } + + async fn get_dups(&self, before: i64, dupkeep: u32) -> Result> { + let res = sqlx::query( + "SELECT * FROM ( + SELECT *, ROW_NUMBER() + OVER (PARTITION BY command, cwd, hostname ORDER BY timestamp DESC) + AS rn + FROM history + ) sub + WHERE rn > ?1 and timestamp < ?2; + ", + ) + .bind(dupkeep) + .bind(before) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + fn clone_boxed(&self) -> Box { + Box::new(self.clone()) + } +} + +pub struct Paged { + database: Box, + page_size: usize, + last_id: Option, + include_deleted: bool, + unique: bool, +} + +impl Paged { + pub fn new( + database: Box, + page_size: usize, + include_deleted: bool, + unique: bool, + ) -> Self { + Self { + database, + page_size, + last_id: None, + include_deleted, + unique, + } + } + + pub async fn next(&mut self) -> Result>> { + let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted()); + + query.field("*").order_desc("id"); + + if !self.include_deleted { + query.and_where_is_null("deleted_at"); + } + + if self.unique { + // We want to deduplicate on command, but the user can search via cwd, hostname, and session. + // Without those fields, filter modes won't work right. With those fields, we get duplicates. + // This must be handled upstream. + query + .group_by("command, cwd, hostname, session") + .having("max(timestamp)"); + } + + query.limit(self.page_size); + + if let Some(last_id) = &self.last_id { + query.and_where_lt("id", quote(last_id)); + } + + let query = query.sql().expect("bug in list query. please report"); + let res = self.database.query_history(&query).await?; + + if res.is_empty() { + Ok(None) + } else { + self.last_id = Some(res.last().unwrap().id.0.clone()); + Ok(Some(res)) + } + } +} + +trait SqlBuilderExt { + fn fuzzy_condition( + &mut self, + field: S, + mask: T, + inverse: bool, + glob: bool, + is_or: bool, + ) -> &mut Self; +} + +impl SqlBuilderExt for SqlBuilder { + /// adapted from the sql-builder *like functions + fn fuzzy_condition( + &mut self, + field: S, + mask: T, + inverse: bool, + glob: bool, + is_or: bool, + ) -> &mut Self { + let mut cond = field.to_string(); + if inverse { + cond.push_str(" NOT"); + } + if glob { + cond.push_str(" GLOB '"); + } else { + cond.push_str(" LIKE '"); + } + cond.push_str(&esc(mask.to_string())); + cond.push('\''); + if is_or { + self.or_where(cond) + } else { + self.and_where(cond) + } + } +} + +#[cfg(test)] +mod test { + use crate::atuin_client::settings::test_local_timeout; + + use super::*; + use std::time::{Duration, Instant}; + + async fn assert_search_eq( + db: &impl Database, + mode: SearchMode, + filter_mode: FilterMode, + query: &str, + expected: usize, + ) -> Result> { + let context = Context { + hostname: "test:host".to_string(), + session: "beepboopiamasession".to_string(), + cwd: "/home/ellie".to_string(), + host_id: "test-host".to_string(), + git_root: None, + }; + + let results = db + .search( + mode, + filter_mode, + &context, + query, + OptFilters { + ..Default::default() + }, + ) + .await?; + + assert_eq!( + results.len(), + expected, + "query \"{}\", commands: {:?}", + query, + results.iter().map(|a| &a.command).collect::>() + ); + Ok(results) + } + + async fn assert_search_commands( + db: &impl Database, + mode: SearchMode, + filter_mode: FilterMode, + query: &str, + expected_commands: Vec<&str>, + ) { + let results = assert_search_eq(db, mode, filter_mode, query, expected_commands.len()) + .await + .unwrap(); + let commands: Vec<&str> = results.iter().map(|a| a.command.as_str()).collect(); + assert_eq!(commands, expected_commands); + } + + async fn new_history_item(db: &mut impl Database, cmd: &str) -> Result<()> { + let mut captured: History = History::capture() + .timestamp(OffsetDateTime::now_utc()) + .command(cmd) + .cwd("/home/ellie") + .build() + .into(); + + captured.exit = 0; + captured.duration = 1; + captured.session = "beep boop".to_string(); + captured.hostname = "booop".to_string(); + + db.save(&captured).await + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_search_prefix() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + new_history_item(&mut db, "ls /home/ellie").await.unwrap(); + + assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "/home", 0) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls ", 0) + .await + .unwrap(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_search_fulltext() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + new_history_item(&mut db, "ls /home/ellie").await.unwrap(); + + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls ho", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "hm", 0) + .await + .unwrap(); + + // regex + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r/^ls ", 1) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r/ls / ie$", + 1, + ) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r/ls / !ie", + 0, + ) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "meow r/ls/", + 0, + ) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r//hom/", 1) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r//home//", + 1, + ) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r//home///", + 0, + ) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home.*e", 0) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r/home.*e", + 1, + ) + .await + .unwrap(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_search_fuzzy() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + new_history_item(&mut db, "ls /home/ellie").await.unwrap(); + new_history_item(&mut db, "ls /home/frank").await.unwrap(); + new_history_item(&mut db, "cd /home/Ellie").await.unwrap(); + new_history_item(&mut db, "/home/ellie/.bin/rustup") + .await + .unwrap(); + + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls /", 3) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls/", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "l/h/", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e", 3) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/hmoe/", 0) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie/home", 0) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "lsellie", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, " ", 4) + .await + .unwrap(); + + // single term operators + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "'ls", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie$", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!^ls", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie$", 2) + .await + .unwrap(); + + // multiple terms + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls !ellie", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls !e$", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "home !^ls", 2) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::Fuzzy, + FilterMode::Global, + "'frank | 'rustup", + 2, + ) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::Fuzzy, + FilterMode::Global, + "'frank | 'rustup 'ls", + 1, + ) + .await + .unwrap(); + + // case matching + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "Ellie", 1) + .await + .unwrap(); + + // regex + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/^ls ", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/[Ee]llie", 3) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e r/^ls ", 1) + .await + .unwrap(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_search_reordered_fuzzy() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + // test ordering of results: we should choose the first, even though it happened longer ago. + + new_history_item(&mut db, "curl").await.unwrap(); + new_history_item(&mut db, "corburl").await.unwrap(); + + // if fuzzy reordering is on, it should come back in a more sensible order + assert_search_commands( + &db, + SearchMode::Fuzzy, + FilterMode::Global, + "curl", + vec!["curl", "corburl"], + ) + .await; + + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "xxxx", 0) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "", 2) + .await + .unwrap(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_paged_basic() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + + // Add 5 history items + for i in 0..5 { + new_history_item(&mut db, &format!("command{}", i)) + .await + .unwrap(); + } + + // Create a paged iterator with page_size of 2 + let mut paged = db.all_paged(2, false, false); + + // First page should have 2 items + let page1 = paged.next().await.unwrap(); + assert!(page1.is_some()); + assert_eq!(page1.unwrap().len(), 2); + + // Second page should have 2 items + let page2 = paged.next().await.unwrap(); + assert!(page2.is_some()); + assert_eq!(page2.unwrap().len(), 2); + + // Third page should have 1 item + let page3 = paged.next().await.unwrap(); + assert!(page3.is_some()); + assert_eq!(page3.unwrap().len(), 1); + + // Fourth page should be None (exhausted) + let page4 = paged.next().await.unwrap(); + assert!(page4.is_none()); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_paged_empty() { + let db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + + // Create a paged iterator on empty database + let mut paged = db.all_paged(10, false, false); + + // Should return None immediately + let page = paged.next().await.unwrap(); + assert!(page.is_none()); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_paged_unique() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + + // Add duplicate commands + new_history_item(&mut db, "duplicate").await.unwrap(); + new_history_item(&mut db, "duplicate").await.unwrap(); + new_history_item(&mut db, "unique1").await.unwrap(); + new_history_item(&mut db, "unique2").await.unwrap(); + + // Without unique flag - should get all 4 + let mut paged = db.all_paged(10, false, false); + let page = paged.next().await.unwrap().unwrap(); + assert_eq!(page.len(), 4); + + // With unique flag - should get 3 (duplicates collapsed) + let mut paged_unique = db.all_paged(10, false, true); + let page_unique = paged_unique.next().await.unwrap().unwrap(); + assert_eq!(page_unique.len(), 3); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_paged_include_deleted() { + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + + // Add items + new_history_item(&mut db, "keep1").await.unwrap(); + new_history_item(&mut db, "keep2").await.unwrap(); + new_history_item(&mut db, "delete_me").await.unwrap(); + + // Delete one item + let all = db + .list( + &[], + &Context { + hostname: "".to_string(), + session: "".to_string(), + cwd: "".to_string(), + host_id: "".to_string(), + git_root: None, + }, + None, + false, + false, + ) + .await + .unwrap(); + + let to_delete = all + .iter() + .find(|h| h.command == "delete_me") + .unwrap() + .clone(); + db.delete(to_delete).await.unwrap(); + + // Without include_deleted - should get 2 + let mut paged = db.all_paged(10, false, false); + let page = paged.next().await.unwrap().unwrap(); + assert_eq!(page.len(), 2); + + // With include_deleted - should get 3 + let mut paged_deleted = db.all_paged(10, true, false); + let page_deleted = paged_deleted.next().await.unwrap().unwrap(); + assert_eq!(page_deleted.len(), 3); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_search_bench_dupes() { + let context = Context { + hostname: "test:host".to_string(), + session: "beepboopiamasession".to_string(), + cwd: "/home/ellie".to_string(), + host_id: "test-host".to_string(), + git_root: None, + }; + + let mut db = Sqlite::new("sqlite::memory:", test_local_timeout()) + .await + .unwrap(); + for _i in 1..10000 { + new_history_item(&mut db, "i am a duplicated command") + .await + .unwrap(); + } + let start = Instant::now(); + let _results = db + .search( + SearchMode::Fuzzy, + FilterMode::Global, + &context, + "", + OptFilters { + ..Default::default() + }, + ) + .await + .unwrap(); + let duration = start.elapsed(); + + assert!(duration < Duration::from_secs(15)); + } +} + +pub struct QueryTokenizer<'a> { + query: &'a str, + last_pos: usize, +} + +pub enum QueryToken<'a> { + Match(&'a str, bool), + MatchStart(&'a str, bool), + MatchEnd(&'a str, bool), + MatchFull(&'a str, bool), + Or, + Regex(&'a str), +} + +impl<'a> QueryToken<'a> { + pub fn has_uppercase(&self) -> bool { + match self { + Self::Match(term, _) + | Self::MatchStart(term, _) + | Self::MatchEnd(term, _) + | Self::MatchFull(term, _) => term.contains(char::is_uppercase), + _ => false, + } + } + + pub fn is_inverse(&self) -> bool { + match self { + Self::Match(_, inv) + | Self::MatchStart(_, inv) + | Self::MatchEnd(_, inv) + | Self::MatchFull(_, inv) => *inv, + _ => false, + } + } +} + +impl<'a> QueryTokenizer<'a> { + pub fn new(query: &'a str) -> Self { + Self { query, last_pos: 0 } + } +} + +impl<'a> Iterator for QueryTokenizer<'a> { + type Item = QueryToken<'a>; + fn next(&mut self) -> Option { + let remaining = &self.query[self.last_pos..]; + if remaining.is_empty() { + return None; + } + + if let Some(remaining) = remaining.strip_prefix("r/") { + let (regex, next_pos) = if let Some(end) = remaining.find("/ ") { + (&remaining[..end], self.last_pos + 2 + end + 2) + } else if let Some(remaining) = remaining.strip_suffix('/') { + (remaining, self.query.len()) + } else { + (remaining, self.query.len()) + }; + self.last_pos = next_pos; + Some(QueryToken::Regex(regex)) + } else { + let (mut part, next_pos) = if let Some(sp) = remaining.find(' ') { + (&remaining[..sp], self.last_pos + sp + 1) + } else { + (remaining, self.query.len()) + }; + self.last_pos = next_pos; + + if part == "|" { + return Some(QueryToken::Or); + } + + let mut is_inverse = false; + if let Some(s) = part.strip_prefix('!') { + part = s; + is_inverse = true; + } + let token = if let Some(s) = part.strip_prefix('^') { + QueryToken::MatchStart(s, is_inverse) + } else if let Some(s) = part.strip_suffix('$') { + QueryToken::MatchEnd(s, is_inverse) + } else if let Some(s) = part.strip_prefix('\'') { + QueryToken::MatchFull(s, is_inverse) + } else { + QueryToken::Match(part, is_inverse) + }; + Some(token) + } + } +} diff --git a/crates/turtle/src/atuin_client/distro.rs b/crates/turtle/src/atuin_client/distro.rs new file mode 100644 index 00000000..dead8355 --- /dev/null +++ b/crates/turtle/src/atuin_client/distro.rs @@ -0,0 +1,89 @@ +use std::process::Command; + +/// Detect the Linux distribution from the system, +/// using system-specific release files and falling +/// back to lsb_release. +pub fn detect_linux_distribution() -> String { + detect_from_os_release() + .or_else(detect_from_debian_version) + .or_else(detect_from_centos_release) + .or_else(detect_from_redhat_release) + .or_else(detect_from_fedora_release) + .or_else(detect_from_arch_release) + .or_else(detect_from_alpine_release) + .or_else(detect_from_suse_release) + .or_else(detect_from_lsb_release) + .unwrap_or_else(|| "Unknown".to_string()) +} + +fn detect_from_os_release() -> Option { + let content = std::fs::read_to_string("/etc/os-release").ok()?; + + content + .lines() + .find(|l| l.starts_with("PRETTY_NAME=")) + .and_then(|l| l.split_once('=').map(|s| s.1)) + .map(|s| s.trim_matches('"').to_string()) +} + +fn detect_from_debian_version() -> Option { + std::fs::read_to_string("/etc/debian_version") + .ok() + .map(|v| format!("Debian {}", v.trim())) +} + +fn detect_from_centos_release() -> Option { + std::fs::read_to_string("/etc/centos-release") + .ok() + .map(|v| v.trim().to_string()) +} + +fn detect_from_redhat_release() -> Option { + std::fs::read_to_string("/etc/redhat-release") + .ok() + .map(|v| v.trim().to_string()) +} + +fn detect_from_fedora_release() -> Option { + std::fs::read_to_string("/etc/fedora-release") + .ok() + .map(|v| v.trim().to_string()) +} + +fn detect_from_arch_release() -> Option { + std::fs::read_to_string("/etc/arch-release") + .ok() + .filter(|v| !v.trim().is_empty()) + .map(|_| "Arch Linux".to_string()) +} + +fn detect_from_alpine_release() -> Option { + std::fs::read_to_string("/etc/alpine-release") + .ok() + .map(|v| format!("Alpine {}", v.trim())) +} + +fn detect_from_suse_release() -> Option { + std::fs::read_to_string("/etc/SuSE-release") + .ok() + .and_then(|content| content.lines().next().map(|l| l.trim().to_string())) +} + +fn detect_from_lsb_release() -> Option { + let output = Command::new("lsb_release").arg("-a").output().ok()?; + + if !output.status.success() { + return None; + } + + let output = String::from_utf8(output.stdout).ok()?; + linux_distro_from_lsb_release(&output) +} + +fn linux_distro_from_lsb_release(output: &str) -> Option { + output + .lines() + .find(|line| line.starts_with("Description:")) + .and_then(|line| line.split_once(':').map(|s| s.1)) + .map(|s| s.trim().to_string()) +} diff --git a/crates/turtle/src/atuin_client/encryption.rs b/crates/turtle/src/atuin_client/encryption.rs new file mode 100644 index 00000000..20a0cd90 --- /dev/null +++ b/crates/turtle/src/atuin_client/encryption.rs @@ -0,0 +1,440 @@ +// The general idea is that we NEVER send cleartext history to the server +// This way the odds of anything private ending up where it should not are +// very low +// The server authenticates via the usual username and password. This has +// nothing to do with the encryption, and is purely authentication! The client +// generates its own secret key, and encrypts all shell history with libsodium's +// secretbox. The data is then sent to the server, where it is stored. All +// clients must share the secret in order to be able to sync, as it is needed +// to decrypt + +use std::{io::prelude::*, path::PathBuf}; + +use base64::prelude::{BASE64_STANDARD, Engine}; +pub use crypto_secretbox::Key; +use crypto_secretbox::{ + AeadCore, AeadInPlace, KeyInit, XSalsa20Poly1305, + aead::{Nonce, OsRng}, +}; +use eyre::{Context, Result, bail, ensure, eyre}; +use fs_err as fs; +use rmp::{Marker, decode::Bytes}; +use serde::{Deserialize, Serialize}; +use time::{OffsetDateTime, format_description::well_known::Rfc3339, macros::format_description}; + +use crate::atuin_client::{history::History, settings::Settings}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct EncryptedHistory { + pub ciphertext: Vec, + pub nonce: Nonce, +} + +pub fn generate_encoded_key() -> Result<(Key, String)> { + let key = XSalsa20Poly1305::generate_key(&mut OsRng); + let encoded = encode_key(&key)?; + + Ok((key, encoded)) +} + +pub fn new_key(settings: &Settings) -> Result { + let path = settings.key_path.as_str(); + let path = PathBuf::from(path); + + if path.exists() { + bail!("key already exists! cannot overwrite"); + } + + let (key, encoded) = generate_encoded_key()?; + + let mut file = fs::File::create(path)?; + file.write_all(encoded.as_bytes())?; + + Ok(key) +} + +// Loads the secret key, will create + save if it doesn't exist +pub fn load_key(settings: &Settings) -> Result { + let path = settings.key_path.as_str(); + + let key = if PathBuf::from(path).exists() { + let key = fs_err::read_to_string(path)?; + decode_key(key)? + } else { + new_key(settings)? + }; + + Ok(key) +} + +pub fn encode_key(key: &Key) -> Result { + let mut buf = vec![]; + rmp::encode::write_array_len(&mut buf, key.len() as u32) + .wrap_err("could not encode key to message pack")?; + for b in key { + rmp::encode::write_uint(&mut buf, *b as u64) + .wrap_err("could not encode key to message pack")?; + } + let buf = BASE64_STANDARD.encode(buf); + + Ok(buf) +} + +pub fn decode_key(key: String) -> Result { + use rmp::decode; + + let buf = BASE64_STANDARD + .decode(key.trim_end()) + .wrap_err("encryption key is not a valid base64 encoding")?; + + // old code wrote the key as a fixed length array of 32 bytes + // new code writes the key with a length prefix + match <[u8; 32]>::try_from(&*buf) { + Ok(key) => Ok(key.into()), + Err(_) => { + let mut bytes = rmp::decode::Bytes::new(&buf); + + match Marker::from_u8(buf[0]) { + Marker::Bin8 => { + let len = decode::read_bin_len(&mut bytes).map_err(|err| eyre!("{err:?}"))?; + ensure!(len == 32, "encryption key is not the correct size"); + let key = <[u8; 32]>::try_from(bytes.remaining_slice()) + .context("could not decode encryption key")?; + Ok(key.into()) + } + Marker::Array16 => { + let len = decode::read_array_len(&mut bytes).map_err(|err| eyre!("{err:?}"))?; + ensure!(len == 32, "encryption key is not the correct size"); + + let mut key = Key::default(); + for i in &mut key { + *i = rmp::decode::read_int(&mut bytes).map_err(|err| eyre!("{err:?}"))?; + } + Ok(key) + } + _ => bail!("could not decode encryption key"), + } + } + } +} + +pub fn encrypt(history: &History, key: &Key) -> Result { + // serialize with msgpack + let mut buf = encode(history)?; + + let nonce = XSalsa20Poly1305::generate_nonce(&mut OsRng); + XSalsa20Poly1305::new(key) + .encrypt_in_place(&nonce, &[], &mut buf) + .map_err(|_| eyre!("could not encrypt"))?; + + Ok(EncryptedHistory { + ciphertext: buf, + nonce, + }) +} + +pub fn decrypt(mut encrypted_history: EncryptedHistory, key: &Key) -> Result { + XSalsa20Poly1305::new(key) + .decrypt_in_place( + &encrypted_history.nonce, + &[], + &mut encrypted_history.ciphertext, + ) + .map_err(|_| eyre!("could not decrypt history"))?; + let plaintext = encrypted_history.ciphertext; + + let history = decode(&plaintext)?; + + Ok(history) +} + +fn format_rfc3339(ts: OffsetDateTime) -> Result { + // horrible hack. chrono AutoSI limits to 0, 3, 6, or 9 decimal places for nanoseconds. + // time does not have this functionality. + static PARTIAL_RFC3339_0: &[time::format_description::FormatItem<'static>] = + format_description!("[year]-[month]-[day]T[hour]:[minute]:[second]Z"); + static PARTIAL_RFC3339_3: &[time::format_description::FormatItem<'static>] = + format_description!("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"); + static PARTIAL_RFC3339_6: &[time::format_description::FormatItem<'static>] = + format_description!("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:6]Z"); + static PARTIAL_RFC3339_9: &[time::format_description::FormatItem<'static>] = + format_description!("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:9]Z"); + + let fmt = match ts.nanosecond() { + 0 => PARTIAL_RFC3339_0, + ns if ns % 1_000_000 == 0 => PARTIAL_RFC3339_3, + ns if ns % 1_000 == 0 => PARTIAL_RFC3339_6, + _ => PARTIAL_RFC3339_9, + }; + + Ok(ts.format(fmt)?) +} + +fn encode(h: &History) -> Result> { + use rmp::encode; + + let mut output = vec![]; + // INFO: ensure this is updated when adding new fields + encode::write_array_len(&mut output, 9)?; + + encode::write_str(&mut output, &h.id.0)?; + encode::write_str(&mut output, &(format_rfc3339(h.timestamp)?))?; + encode::write_sint(&mut output, h.duration)?; + encode::write_sint(&mut output, h.exit)?; + encode::write_str(&mut output, &h.command)?; + encode::write_str(&mut output, &h.cwd)?; + encode::write_str(&mut output, &h.session)?; + encode::write_str(&mut output, &h.hostname)?; + match h.deleted_at { + Some(d) => encode::write_str(&mut output, &format_rfc3339(d)?)?, + None => encode::write_nil(&mut output)?, + } + + Ok(output) +} + +fn decode(bytes: &[u8]) -> Result { + use rmp::decode::{self, DecodeStringError}; + + let mut bytes = Bytes::new(bytes); + + let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; + if nfields < 8 { + bail!("malformed decrypted history") + } + if nfields > 9 { + bail!("cannot decrypt history from a newer version of atuin"); + } + + let bytes = bytes.remaining_slice(); + let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (timestamp, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + let mut bytes = Bytes::new(bytes); + let duration = decode::read_int(&mut bytes).map_err(error_report)?; + let exit = decode::read_int(&mut bytes).map_err(error_report)?; + + let bytes = bytes.remaining_slice(); + let (command, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (cwd, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (session, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (hostname, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + // if we have more fields, try and get the deleted_at + let mut deleted_at = None; + let mut bytes = bytes; + if nfields > 8 { + bytes = match decode::read_str_from_slice(bytes) { + Ok((d, b)) => { + deleted_at = Some(d); + b + } + // we accept null here + Err(DecodeStringError::TypeMismatch(Marker::Null)) => { + // consume the null marker + let mut c = Bytes::new(bytes); + decode::read_nil(&mut c).map_err(error_report)?; + c.remaining_slice() + } + Err(err) => return Err(error_report(err)), + }; + } + + if !bytes.is_empty() { + bail!("trailing bytes in encoded history. malformed") + } + + Ok(History { + id: id.to_owned().into(), + timestamp: OffsetDateTime::parse(timestamp, &Rfc3339)?, + duration, + exit, + command: command.to_owned(), + cwd: cwd.to_owned(), + session: session.to_owned(), + hostname: hostname.to_owned(), + author: History::author_from_hostname(hostname), + intent: None, + deleted_at: deleted_at + .map(|t| OffsetDateTime::parse(t, &Rfc3339)) + .transpose()?, + }) +} + +fn error_report(err: E) -> eyre::Report { + eyre!("{err:?}") +} + +#[cfg(test)] +mod test { + use crypto_secretbox::{KeyInit, XSalsa20Poly1305, aead::OsRng}; + use pretty_assertions::assert_eq; + use time::{OffsetDateTime, macros::datetime}; + + use crate::history::History; + + use super::{decode, decrypt, encode, encrypt}; + + #[test] + fn test_encrypt_decrypt() { + let key1 = XSalsa20Poly1305::generate_key(&mut OsRng); + let key2 = XSalsa20Poly1305::generate_key(&mut OsRng); + + let history = History::from_db() + .id("1".into()) + .timestamp(OffsetDateTime::now_utc()) + .command("ls".into()) + .cwd("/home/ellie".into()) + .exit(0) + .duration(1) + .session("beep boop".into()) + .hostname("booop".into()) + .author("booop".into()) + .intent(None) + .deleted_at(None) + .build() + .into(); + + let e1 = encrypt(&history, &key1).unwrap(); + let e2 = encrypt(&history, &key2).unwrap(); + + assert_ne!(e1.ciphertext, e2.ciphertext); + assert_ne!(e1.nonce, e2.nonce); + + // test decryption works + // this should pass + match decrypt(e1, &key1) { + Err(e) => panic!("failed to decrypt, got {e}"), + Ok(h) => assert_eq!(h, history), + }; + + // this should err + let _ = decrypt(e2, &key1).expect_err("expected an error decrypting with invalid key"); + } + + #[test] + fn test_decode() { + let bytes = [ + 0x99, 0xD9, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55, 53, 51, 56, + 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 187, 50, 48, 50, 51, 45, + 48, 53, 45, 50, 56, 84, 49, 56, 58, 51, 53, 58, 52, 48, 46, 54, 51, 51, 56, 55, 50, 90, + 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, 97, 116, 117, 115, 217, 42, + 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, + 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, 47, 99, 111, 100, 101, 47, 97, + 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, 51, 48, 54, 102, 50, 55, 52, 52, + 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, 102, 57, 52, 53, 55, 187, 102, + 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, 99, 111, 110, 114, 97, 100, 46, + 108, 117, 100, 103, 97, 116, 101, 192, + ]; + let history = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + author: "conrad.ludgate".to_owned(), + intent: None, + deleted_at: None, + }; + + let h = decode(&bytes).unwrap(); + assert_eq!(history, h); + + let b = encode(&h).unwrap(); + assert_eq!(&bytes, &*b); + } + + #[test] + fn test_decode_deleted() { + let history = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + author: "conrad.ludgate".to_owned(), + intent: None, + deleted_at: Some(datetime!(2023-05-28 18:35:40.633872 +00:00)), + }; + + let b = encode(&history).unwrap(); + let h = decode(&b).unwrap(); + assert_eq!(history, h); + } + + #[test] + fn test_decode_old() { + let bytes = [ + 0x98, 0xD9, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55, 53, 51, 56, + 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 187, 50, 48, 50, 51, 45, + 48, 53, 45, 50, 56, 84, 49, 56, 58, 51, 53, 58, 52, 48, 46, 54, 51, 51, 56, 55, 50, 90, + 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, 97, 116, 117, 115, 217, 42, + 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, + 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, 47, 99, 111, 100, 101, 47, 97, + 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, 51, 48, 54, 102, 50, 55, 52, 52, + 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, 102, 57, 52, 53, 55, 187, 102, + 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, 99, 111, 110, 114, 97, 100, 46, + 108, 117, 100, 103, 97, 116, 101, + ]; + let history = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + author: "conrad.ludgate".to_owned(), + intent: None, + deleted_at: None, + }; + + let h = decode(&bytes).unwrap(); + assert_eq!(history, h); + } + + #[test] + fn key_encodings() { + use super::{Key, decode_key, encode_key}; + + // a history of our key encodings. + // v11.0.0 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== + // v12.0.0 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== + // v13.0.0 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== + // v13.0.1 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== + // v14.0.0 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== + // v14.0.1 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== + // c7d89c1 3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q== (https://github.com/ellie/atuin/pull/805) + // b53ca35 3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q== (https://github.com/ellie/atuin/pull/974) + // v15.0.0 3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q== + // b8b57c8 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== (https://github.com/ellie/atuin/pull/1057) + // 8c94d79 3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q== (https://github.com/ellie/atuin/pull/1089) + + let key = Key::from([ + 27, 91, 42, 91, 210, 107, 9, 216, 170, 190, 242, 62, 6, 84, 69, 148, 148, 53, 251, 117, + 226, 167, 173, 52, 82, 34, 138, 110, 169, 124, 92, 229, + ]); + + assert_eq!( + encode_key(&key).unwrap(), + "3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q==" + ); + + // key encodings we have to support + let valid_encodings = [ + "xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q==", + "3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q==", + ]; + + for k in valid_encodings { + assert_eq!(decode_key(k.to_owned()).expect(k), key); + } + } +} diff --git a/crates/turtle/src/atuin_client/history.rs b/crates/turtle/src/atuin_client/history.rs new file mode 100644 index 00000000..cef65115 --- /dev/null +++ b/crates/turtle/src/atuin_client/history.rs @@ -0,0 +1,756 @@ +use core::fmt::Formatter; +use rmp::decode::DecodeStringError; +use rmp::decode::ValueReadError; +use rmp::{Marker, decode::Bytes}; +use std::env; +use std::fmt::Display; + +use crate::atuin_common::record::DecryptedData; +use crate::atuin_common::utils::uuid_v7; + +use eyre::{Result, bail, eyre}; + +use crate::atuin_client::secrets::SECRET_PATTERNS_RE; +use crate::atuin_client::settings::Settings; +use crate::atuin_client::utils::get_host_user; +use time::OffsetDateTime; + +mod builder; +pub mod store; + +/// Known AI agent author values. Used to expand `$all-agent` and `$all-user` filters. +pub const KNOWN_AGENTS: &[&str] = &["claude-code", "codex", "copilot", "pi"]; +pub const AUTHOR_FILTER_ALL_USER: &str = "$all-user"; +pub const AUTHOR_FILTER_ALL_AGENT: &str = "$all-agent"; + +pub fn is_known_agent(author: &str) -> bool { + KNOWN_AGENTS.contains(&author) +} + +pub fn author_matches_filters(author: &str, filters: &[String]) -> bool { + filters.is_empty() + || filters.iter().any(|filter| match filter.as_str() { + AUTHOR_FILTER_ALL_USER => !is_known_agent(author), + AUTHOR_FILTER_ALL_AGENT => is_known_agent(author), + literal => author == literal, + }) +} + +pub(crate) const HISTORY_VERSION_V0: &str = "v0"; +pub(crate) const HISTORY_VERSION_V1: &str = "v1"; +const HISTORY_RECORD_VERSION_V0: u16 = 0; +const HISTORY_RECORD_VERSION_V1: u16 = 1; +pub(crate) const HISTORY_VERSION: &str = HISTORY_VERSION_V1; +pub const HISTORY_TAG: &str = "history"; +const HISTORY_AUTHOR_ENV: &str = "ATUIN_HISTORY_AUTHOR"; +const HISTORY_INTENT_ENV: &str = "ATUIN_HISTORY_INTENT"; + +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +pub struct HistoryId(pub String); + +impl Display for HistoryId { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From for HistoryId { + fn from(s: String) -> Self { + Self(s) + } +} + +/// Client-side history entry. +/// +/// Client stores data unencrypted, and only encrypts it before sending to the server. +/// +/// To create a new history entry, use one of the builders: +/// - [`History::import()`] to import an entry from the shell history file +/// - [`History::capture()`] to capture an entry via hook +/// - [`History::from_db()`] to create an instance from the database entry +// +// ## Implementation Notes +// +// New fields must be added to `History::{serialize,deserialize}` in a backwards +// compatible way (sensible defaults and careful `nfields` handling). +#[derive(Debug, Clone, PartialEq, Eq, sqlx::FromRow)] +pub struct History { + /// A client-generated ID, used to identify the entry when syncing. + /// + /// Stored as `client_id` in the database. + pub id: HistoryId, + /// When the command was run. + pub timestamp: OffsetDateTime, + /// How long the command took to run. + pub duration: i64, + /// The exit code of the command. + pub exit: i64, + /// The command that was run. + pub command: String, + /// The current working directory when the command was run. + pub cwd: String, + /// The session ID, associated with a terminal session. + pub session: String, + /// The hostname of the machine the command was run on. + pub hostname: String, + /// Who wrote this command (human user or automation/agent identity). + pub author: String, + /// Optional rationale for why the command was executed. + pub intent: Option, + /// Timestamp, which is set when the entry is deleted, allowing a soft delete. + pub deleted_at: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, sqlx::FromRow)] +pub struct HistoryStats { + /// The command that was ran after this one in the session + pub next: Option, + /// + /// The command that was ran before this one in the session + pub previous: Option, + + /// How many times has this command been ran? + pub total: u64, + + pub average_duration: u64, + + pub exits: Vec<(i64, i64)>, + + pub day_of_week: Vec<(String, i64)>, + + pub duration_over_time: Vec<(String, i64)>, +} + +impl History { + pub(crate) fn author_from_hostname(hostname: &str) -> String { + hostname + .split_once(':') + .map_or_else(|| hostname.to_owned(), |(_, user)| user.to_owned()) + } + + fn normalize_optional_field(field: Option) -> Option { + field.and_then(|value| { + let trimmed = value.trim(); + if trimmed.is_empty() { + None + } else { + Some(trimmed.to_owned()) + } + }) + } + + #[expect(clippy::too_many_arguments)] + fn new( + timestamp: OffsetDateTime, + command: String, + cwd: String, + exit: i64, + duration: i64, + session: Option, + hostname: Option, + author: Option, + intent: Option, + deleted_at: Option, + ) -> Self { + let session = session + .or_else(|| env::var("ATUIN_SESSION").ok()) + .unwrap_or_else(|| uuid_v7().as_simple().to_string()); + let hostname = hostname.unwrap_or_else(get_host_user); + let author = Self::normalize_optional_field(author) + .or_else(|| Self::normalize_optional_field(env::var(HISTORY_AUTHOR_ENV).ok())) + .unwrap_or_else(|| Self::author_from_hostname(hostname.as_str())); + let intent = Self::normalize_optional_field(intent) + .or_else(|| Self::normalize_optional_field(env::var(HISTORY_INTENT_ENV).ok())); + + Self { + id: uuid_v7().as_simple().to_string().into(), + timestamp, + command, + cwd, + exit, + duration, + session, + hostname, + author, + intent, + deleted_at, + } + } + + pub fn serialize(&self) -> Result { + // This is pretty much the same as what we used for the old history, with one difference - + // it uses integers for timestamps rather than a string format. + + use rmp::encode; + + let mut output = vec![]; + + // write the version + encode::write_u16(&mut output, HISTORY_RECORD_VERSION_V1)?; + let include_intent = self.intent.is_some(); + encode::write_array_len(&mut output, 10 + u32::from(include_intent))?; + + encode::write_str(&mut output, &self.id.0)?; + encode::write_u64(&mut output, self.timestamp.unix_timestamp_nanos() as u64)?; + encode::write_sint(&mut output, self.duration)?; + encode::write_sint(&mut output, self.exit)?; + encode::write_str(&mut output, &self.command)?; + encode::write_str(&mut output, &self.cwd)?; + encode::write_str(&mut output, &self.session)?; + encode::write_str(&mut output, &self.hostname)?; + + match self.deleted_at { + Some(d) => encode::write_u64(&mut output, d.unix_timestamp_nanos() as u64)?, + None => encode::write_nil(&mut output)?, + } + + encode::write_str(&mut output, self.author.as_str())?; + if let Some(intent) = &self.intent { + encode::write_str(&mut output, intent.as_str())?; + } + + Ok(DecryptedData(output)) + } + + fn read_optional_string(bytes: &[u8]) -> Result<(Option, &[u8])> { + use rmp::decode; + + fn error_report(err: E) -> eyre::Report { + eyre!("{err:?}") + } + + match decode::read_str_from_slice(bytes) { + Ok((value, bytes)) => Ok((Some(value.to_owned()), bytes)), + Err(DecodeStringError::TypeMismatch(Marker::Null)) => { + let mut cursor = Bytes::new(bytes); + decode::read_nil(&mut cursor).map_err(error_report)?; + + Ok((None, cursor.remaining_slice())) + } + Err(err) => Err(error_report(err)), + } + } + + fn deserialize_v0(bytes: &[u8]) -> Result { + use rmp::decode; + + fn error_report(err: E) -> eyre::Report { + eyre!("{err:?}") + } + + let mut bytes = Bytes::new(bytes); + + let version = decode::read_u16(&mut bytes).map_err(error_report)?; + + if version != HISTORY_RECORD_VERSION_V0 { + bail!("expected decoding v0 record, found v{version}"); + } + + let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; + + if nfields != 9 { + bail!("cannot decrypt history from a different version of Atuin"); + } + + let bytes = bytes.remaining_slice(); + let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + let mut bytes = Bytes::new(bytes); + let timestamp = decode::read_u64(&mut bytes).map_err(error_report)?; + let duration = decode::read_int(&mut bytes).map_err(error_report)?; + let exit = decode::read_int(&mut bytes).map_err(error_report)?; + + let bytes = bytes.remaining_slice(); + let (command, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (cwd, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (session, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (hostname, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + let mut bytes = Bytes::new(bytes); + + let (deleted_at, bytes) = match decode::read_u64(&mut bytes) { + Ok(unix) => (Some(unix), bytes.remaining_slice()), + // we accept null here + Err(ValueReadError::TypeMismatch(Marker::Null)) => (None, bytes.remaining_slice()), + Err(err) => return Err(error_report(err)), + }; + if !bytes.is_empty() { + bail!("trailing bytes in encoded history. malformed") + } + + Ok(History { + id: id.to_owned().into(), + timestamp: OffsetDateTime::from_unix_timestamp_nanos(timestamp as i128)?, + duration, + exit, + command: command.to_owned(), + cwd: cwd.to_owned(), + session: session.to_owned(), + hostname: hostname.to_owned(), + author: Self::author_from_hostname(hostname), + intent: None, + deleted_at: deleted_at + .map(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128)) + .transpose()?, + }) + } + + fn deserialize_v1(bytes: &[u8]) -> Result { + use rmp::decode; + + fn error_report(err: E) -> eyre::Report { + eyre!("{err:?}") + } + + let mut bytes = Bytes::new(bytes); + + let version = decode::read_u16(&mut bytes).map_err(error_report)?; + + if version != HISTORY_RECORD_VERSION_V1 { + bail!("expected decoding v1 record, found v{version}"); + } + + let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; + + if !(10..=11).contains(&nfields) { + bail!("cannot decrypt history from a different version of Atuin"); + } + + let bytes = bytes.remaining_slice(); + let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + let mut bytes = Bytes::new(bytes); + let timestamp = decode::read_u64(&mut bytes).map_err(error_report)?; + let duration = decode::read_int(&mut bytes).map_err(error_report)?; + let exit = decode::read_int(&mut bytes).map_err(error_report)?; + + let bytes = bytes.remaining_slice(); + let (command, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (cwd, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (session, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (hostname, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + let mut bytes = Bytes::new(bytes); + + let (deleted_at, bytes) = match decode::read_u64(&mut bytes) { + Ok(unix) => (Some(unix), bytes.remaining_slice()), + // we accept null here + Err(ValueReadError::TypeMismatch(Marker::Null)) => (None, bytes.remaining_slice()), + Err(err) => return Err(error_report(err)), + }; + let (author, bytes) = Self::read_optional_string(bytes)?; + let (intent, bytes) = if nfields > 10 { + Self::read_optional_string(bytes)? + } else { + (None, bytes) + }; + + if !bytes.is_empty() { + bail!("trailing bytes in encoded history. malformed") + } + + Ok(History { + id: id.to_owned().into(), + timestamp: OffsetDateTime::from_unix_timestamp_nanos(timestamp as i128)?, + duration, + exit, + command: command.to_owned(), + cwd: cwd.to_owned(), + session: session.to_owned(), + hostname: hostname.to_owned(), + author: author.unwrap_or_else(|| Self::author_from_hostname(hostname)), + intent, + deleted_at: deleted_at + .map(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128)) + .transpose()?, + }) + } + + pub fn deserialize(bytes: &[u8], version: &str) -> Result { + match version { + HISTORY_VERSION_V0 => Self::deserialize_v0(bytes), + HISTORY_VERSION_V1 => Self::deserialize_v1(bytes), + + _ => bail!("unknown version {version:?}"), + } + } + + /// Builder for a history entry that is imported from shell history. + /// + /// The only two required fields are `timestamp` and `command`. + /// + /// ## Examples + /// ``` + /// use crate::atuin_client::history::History; + /// + /// let history: History = History::import() + /// .timestamp(time::OffsetDateTime::now_utc()) + /// .command("ls -la") + /// .build() + /// .into(); + /// ``` + /// + /// If shell history contains more information, it can be added to the builder: + /// ``` + /// use crate::atuin_client::history::History; + /// + /// let history: History = History::import() + /// .timestamp(time::OffsetDateTime::now_utc()) + /// .command("ls -la") + /// .cwd("/home/user") + /// .exit(0) + /// .duration(100) + /// .build() + /// .into(); + /// ``` + /// + /// Unknown command or command without timestamp cannot be imported, which + /// is forced at compile time: + /// + /// ```compile_fail + /// use crate::atuin_client::history::History; + /// + /// // this will not compile because timestamp is missing + /// let history: History = History::import() + /// .command("ls -la") + /// .build() + /// .into(); + /// ``` + pub fn import() -> builder::HistoryImportedBuilder { + builder::HistoryImported::builder() + } + + /// Builder for a history entry that is captured via hook. + /// + /// This builder is used only at the `start` step of the hook, + /// so it doesn't have any fields which are known only after + /// the command is finished, such as `exit` or `duration`. + /// + /// ## Examples + /// ```rust + /// use crate::atuin_client::history::History; + /// + /// let history: History = History::capture() + /// .timestamp(time::OffsetDateTime::now_utc()) + /// .command("ls -la") + /// .cwd("/home/user") + /// .build() + /// .into(); + /// ``` + /// + /// Command without any required info cannot be captured, which is forced at compile time: + /// + /// ```compile_fail + /// use crate::atuin_client::history::History; + /// + /// // this will not compile because `cwd` is missing + /// let history: History = History::capture() + /// .timestamp(time::OffsetDateTime::now_utc()) + /// .command("ls -la") + /// .build() + /// .into(); + /// ``` + pub fn capture() -> builder::HistoryCapturedBuilder { + builder::HistoryCaptured::builder() + } + + /// Builder for a history entry that is captured via hook, and sent to the daemon. + /// + /// This builder is used only at the `start` step of the hook, + /// so it doesn't have any fields which are known only after + /// the command is finished, such as `exit` or `duration`. + /// + /// It does, however, include information that can usually be inferred. + /// + /// This is because the daemon we are sending a request to lacks the context of the command + /// + /// ## Examples + /// ```rust + /// use crate::atuin_client::history::History; + /// + /// let history: History = History::daemon() + /// .timestamp(time::OffsetDateTime::now_utc()) + /// .command("ls -la") + /// .cwd("/home/user") + /// .session("018deb6e8287781f9973ef40e0fde76b") + /// .hostname("computer:ellie") + /// .build() + /// .into(); + /// ``` + /// + /// Command without any required info cannot be captured, which is forced at compile time: + /// + /// ```compile_fail + /// use crate::atuin_client::history::History; + /// + /// // this will not compile because `hostname` is missing + /// let history: History = History::daemon() + /// .timestamp(time::OffsetDateTime::now_utc()) + /// .command("ls -la") + /// .cwd("/home/user") + /// .session("018deb6e8287781f9973ef40e0fde76b") + /// .build() + /// .into(); + /// ``` + pub fn daemon() -> builder::HistoryDaemonCaptureBuilder { + builder::HistoryDaemonCapture::builder() + } + + /// Builder for a history entry that is imported from the database. + /// + /// All fields are required, as they are all present in the database. + /// + /// ```compile_fail + /// use crate::atuin_client::history::History; + /// + /// // this will not compile because `id` field is missing + /// let history: History = History::from_db() + /// .timestamp(time::OffsetDateTime::now_utc()) + /// .command("ls -la".to_string()) + /// .cwd("/home/user".to_string()) + /// .exit(0) + /// .duration(100) + /// .session("somesession".to_string()) + /// .hostname("localhost".to_string()) + /// .author("user".to_string()) + /// .intent(None) + /// .deleted_at(None) + /// .build() + /// .into(); + /// ``` + pub fn from_db() -> builder::HistoryFromDbBuilder { + builder::HistoryFromDb::builder() + } + + pub fn success(&self) -> bool { + self.exit == 0 || self.duration == -1 + } + + pub fn should_save(&self, settings: &Settings) -> bool { + !(self.command.starts_with(' ') + || self.command.is_empty() + || settings.history_filter.is_match(&self.command) + || settings.cwd_filter.is_match(&self.cwd) + || (settings.secrets_filter && SECRET_PATTERNS_RE.is_match(&self.command))) + } +} + +#[cfg(test)] +mod tests { + use regex::RegexSet; + use time::macros::datetime; + + use crate::{ + history::{AUTHOR_FILTER_ALL_AGENT, AUTHOR_FILTER_ALL_USER, HISTORY_VERSION}, + settings::Settings, + }; + + use super::{History, author_matches_filters, is_known_agent}; + + // Test that we don't save history where necessary + #[test] + fn privacy_test() { + let settings = Settings { + cwd_filter: RegexSet::new(["^/supasecret"]).unwrap(), + history_filter: RegexSet::new(["^psql"]).unwrap(), + ..Settings::utc() + }; + + let normal_command: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("echo foo") + .cwd("/") + .build() + .into(); + + let with_space: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command(" echo bar") + .cwd("/") + .build() + .into(); + + let empty: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("") + .cwd("/") + .build() + .into(); + + let stripe_key: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("curl foo.com/bar?key=sk_test_1234567890abcdefghijklmnop") + .cwd("/") + .build() + .into(); + + let secret_dir: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("echo ohno") + .cwd("/supasecret") + .build() + .into(); + + let with_psql: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("psql") + .cwd("/supasecret") + .build() + .into(); + + assert!(normal_command.should_save(&settings)); + assert!(!with_space.should_save(&settings)); + assert!(!empty.should_save(&settings)); + assert!(!stripe_key.should_save(&settings)); + assert!(!secret_dir.should_save(&settings)); + assert!(!with_psql.should_save(&settings)); + } + + #[test] + fn known_agents_include_pi() { + assert!(is_known_agent("pi")); + assert!(author_matches_filters( + "pi", + &[AUTHOR_FILTER_ALL_AGENT.to_string()] + )); + assert!(!author_matches_filters( + "pi", + &[AUTHOR_FILTER_ALL_USER.to_string()] + )); + } + + #[test] + fn disable_secrets() { + let settings = Settings { + secrets_filter: false, + ..Settings::utc() + }; + + let stripe_key: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("curl foo.com/bar?key=sk_test_1234567890abcdefghijklmnop") + .cwd("/") + .build() + .into(); + + assert!(stripe_key.should_save(&settings)); + } + + #[test] + fn test_serialize_deserialize() { + let history = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + author: "conrad.ludgate".to_owned(), + intent: None, + deleted_at: None, + }; + + let serialized = history.serialize().expect("failed to serialize history"); + assert_eq!( + &serialized.0[0..3], + [205, 0, 1], + "should encode as history v1" + ); + + let deserialized = History::deserialize(&serialized.0, HISTORY_VERSION) + .expect("failed to deserialize history"); + assert_eq!(history, deserialized); + } + + #[test] + fn test_serialize_deserialize_deleted() { + let history = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + author: "conrad.ludgate".to_owned(), + intent: None, + deleted_at: Some(datetime!(2023-11-19 20:18 +00:00)), + }; + + let serialized = history.serialize().expect("failed to serialize history"); + + let deserialized = History::deserialize(&serialized.0, HISTORY_VERSION) + .expect("failed to deserialize history"); + + assert_eq!(history, deserialized); + } + + #[test] + fn test_serialize_deserialize_with_author_and_intent() { + let history = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + author: "claude".to_owned(), + intent: Some("check repository status".to_owned()), + deleted_at: None, + }; + + let serialized = history.serialize().expect("failed to serialize history"); + let deserialized = History::deserialize(&serialized.0, HISTORY_VERSION) + .expect("failed to deserialize history"); + + assert_eq!(history, deserialized); + } + + #[test] + fn test_serialize_deserialize_version() { + // v0 + let bytes_v0 = [ + 205, 0, 0, 153, 217, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55, + 53, 51, 56, 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 207, 23, 99, + 98, 117, 24, 210, 246, 128, 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, + 97, 116, 117, 115, 217, 42, 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, + 46, 108, 117, 100, 103, 97, 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, + 47, 99, 111, 100, 101, 47, 97, 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, + 51, 48, 54, 102, 50, 55, 52, 52, 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, + 102, 57, 52, 53, 55, 187, 102, 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, + 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, 116, 101, 192, + ]; + + let deserialized = History::deserialize(&bytes_v0, "v0"); + assert!(deserialized.is_ok()); + + let deserialized = History::deserialize(&bytes_v0, HISTORY_VERSION); + assert!(deserialized.is_err()); + + let current = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + author: "conrad.ludgate".to_owned(), + intent: None, + deleted_at: None, + }; + + let bytes_v1 = current.serialize().expect("failed to serialize history"); + let deserialized = History::deserialize(&bytes_v1.0, HISTORY_VERSION); + assert!(deserialized.is_ok()); + + let deserialized = History::deserialize(&bytes_v1.0, "v0"); + assert!(deserialized.is_err()); + } +} diff --git a/crates/turtle/src/atuin_client/history/builder.rs b/crates/turtle/src/atuin_client/history/builder.rs new file mode 100644 index 00000000..72a505fd --- /dev/null +++ b/crates/turtle/src/atuin_client/history/builder.rs @@ -0,0 +1,154 @@ +use typed_builder::TypedBuilder; + +use super::History; + +/// Builder for a history entry that is imported from shell history. +/// +/// The only two required fields are `timestamp` and `command`. +#[derive(Debug, Clone, TypedBuilder)] +pub struct HistoryImported { + timestamp: time::OffsetDateTime, + #[builder(setter(into))] + command: String, + #[builder(default = "unknown".into(), setter(into))] + cwd: String, + #[builder(default = -1)] + exit: i64, + #[builder(default = -1)] + duration: i64, + #[builder(default, setter(strip_option, into))] + session: Option, + #[builder(default, setter(strip_option, into))] + hostname: Option, + #[builder(default, setter(strip_option, into))] + author: Option, + #[builder(default, setter(strip_option, into))] + intent: Option, +} + +impl From for History { + fn from(imported: HistoryImported) -> Self { + History::new( + imported.timestamp, + imported.command, + imported.cwd, + imported.exit, + imported.duration, + imported.session, + imported.hostname, + imported.author, + imported.intent, + None, + ) + } +} + +/// Builder for a history entry that is captured via hook. +/// +/// This builder is used only at the `start` step of the hook, +/// so it doesn't have any fields which are known only after +/// the command is finished, such as `exit` or `duration`. +#[derive(Debug, Clone, TypedBuilder)] +pub struct HistoryCaptured { + timestamp: time::OffsetDateTime, + #[builder(setter(into))] + command: String, + #[builder(setter(into))] + cwd: String, + #[builder(default, setter(strip_option, into))] + author: Option, + #[builder(default, setter(strip_option, into))] + intent: Option, +} + +impl From for History { + fn from(captured: HistoryCaptured) -> Self { + History::new( + captured.timestamp, + captured.command, + captured.cwd, + -1, + -1, + None, + None, + captured.author, + captured.intent, + None, + ) + } +} + +/// Builder for a history entry that is loaded from the database. +/// +/// All fields are required, as they are all present in the database. +#[derive(Debug, Clone, TypedBuilder)] +pub struct HistoryFromDb { + id: String, + timestamp: time::OffsetDateTime, + command: String, + cwd: String, + exit: i64, + duration: i64, + session: String, + hostname: String, + author: String, + intent: Option, + deleted_at: Option, +} + +impl From for History { + fn from(from_db: HistoryFromDb) -> Self { + History { + id: from_db.id.into(), + timestamp: from_db.timestamp, + exit: from_db.exit, + command: from_db.command, + cwd: from_db.cwd, + duration: from_db.duration, + session: from_db.session, + hostname: from_db.hostname, + author: from_db.author, + intent: from_db.intent, + deleted_at: from_db.deleted_at, + } + } +} + +/// Builder for a history entry that is captured via hook and sent to the daemon +/// +/// This builder is similar to Capture, but we just require more information up front. +/// For the old setup, we could just rely on History::new to read some of the missing +/// data. This is no longer the case. +#[derive(Debug, Clone, TypedBuilder)] +pub struct HistoryDaemonCapture { + timestamp: time::OffsetDateTime, + #[builder(setter(into))] + command: String, + #[builder(setter(into))] + cwd: String, + #[builder(setter(into))] + session: String, + #[builder(setter(into))] + hostname: String, + #[builder(default, setter(strip_option, into))] + author: Option, + #[builder(default, setter(strip_option, into))] + intent: Option, +} + +impl From for History { + fn from(captured: HistoryDaemonCapture) -> Self { + History::new( + captured.timestamp, + captured.command, + captured.cwd, + -1, + -1, + Some(captured.session), + Some(captured.hostname), + captured.author, + captured.intent, + None, + ) + } +} diff --git a/crates/turtle/src/atuin_client/history/store.rs b/crates/turtle/src/atuin_client/history/store.rs new file mode 100644 index 00000000..66d9db47 --- /dev/null +++ b/crates/turtle/src/atuin_client/history/store.rs @@ -0,0 +1,435 @@ +use std::{collections::HashSet, fmt::Write, time::Duration}; + +use eyre::{Result, bail, eyre}; +use indicatif::{ProgressBar, ProgressState, ProgressStyle}; +use rmp::decode::Bytes; +use tracing::debug; + +use crate::atuin_client::{ + database::{Database, current_context}, + record::{encryption::PASETO_V4, sqlite_store::SqliteStore, store::Store}, +}; +use crate::atuin_common::record::{DecryptedData, Host, HostId, Record, RecordId, RecordIdx}; + +use super::{HISTORY_TAG, HISTORY_VERSION, HISTORY_VERSION_V0, History, HistoryId}; + +#[derive(Debug, Clone)] +pub struct HistoryStore { + pub store: SqliteStore, + pub host_id: HostId, + pub encryption_key: [u8; 32], +} + +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum HistoryRecord { + Create(History), // Create a history record + Delete(HistoryId), // Delete a history record, identified by ID +} + +impl HistoryRecord { + /// Serialize a history record, returning DecryptedData + /// The record will be of a certain type + /// We map those like so: + /// + /// HistoryRecord::Create -> 0 + /// HistoryRecord::Delete-> 1 + /// + /// This numeric identifier is then written as the first byte to the buffer. For history, we + /// append the serialized history right afterwards, to avoid having to handle serialization + /// twice. + /// + /// Deletion simply refers to the history by ID + pub fn serialize(&self) -> Result { + // probably don't actually need to use rmp here, but if we ever need to extend it, it's a + // nice wrapper around raw byte stuff + use rmp::encode; + + let mut output = vec![]; + + match self { + HistoryRecord::Create(history) => { + // 0 -> a history create + encode::write_u8(&mut output, 0)?; + + let bytes = history.serialize()?; + + encode::write_bin(&mut output, &bytes.0)?; + } + HistoryRecord::Delete(id) => { + // 1 -> a history delete + encode::write_u8(&mut output, 1)?; + encode::write_str(&mut output, id.0.as_str())?; + } + }; + + Ok(DecryptedData(output)) + } + + pub fn deserialize(bytes: &DecryptedData, version: &str) -> Result { + use rmp::decode; + + fn error_report(err: E) -> eyre::Report { + eyre!("{err:?}") + } + + let mut bytes = Bytes::new(&bytes.0); + + let record_type = decode::read_u8(&mut bytes).map_err(error_report)?; + + match record_type { + // 0 -> HistoryRecord::Create + 0 => { + // not super useful to us atm, but perhaps in the future + // written by write_bin above + let _ = decode::read_bin_len(&mut bytes).map_err(error_report)?; + + let record = History::deserialize(bytes.remaining_slice(), version)?; + + Ok(HistoryRecord::Create(record)) + } + + // 1 -> HistoryRecord::Delete + 1 => { + let bytes = bytes.remaining_slice(); + let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + if !bytes.is_empty() { + bail!( + "trailing bytes decoding HistoryRecord::Delete - malformed? got {bytes:?}" + ); + } + + Ok(HistoryRecord::Delete(id.to_string().into())) + } + + n => { + bail!("unknown HistoryRecord type {n}") + } + } + } +} + +impl HistoryStore { + pub fn new(store: SqliteStore, host_id: HostId, encryption_key: [u8; 32]) -> Self { + HistoryStore { + store, + host_id, + encryption_key, + } + } + + async fn push_record(&self, record: HistoryRecord) -> Result<(RecordId, RecordIdx)> { + let bytes = record.serialize()?; + let idx = self + .store + .last(self.host_id, HISTORY_TAG) + .await? + .map_or(0, |p| p.idx + 1); + + let record = Record::builder() + .host(Host::new(self.host_id)) + .version(HISTORY_VERSION.to_string()) + .tag(HISTORY_TAG.to_string()) + .idx(idx) + .data(bytes) + .build(); + + let id = record.id; + + self.store + .push(&record.encrypt::(&self.encryption_key)) + .await?; + + Ok((id, idx)) + } + + async fn push_batch(&self, records: impl Iterator) -> Result<()> { + let mut ret = Vec::new(); + + let idx = self + .store + .last(self.host_id, HISTORY_TAG) + .await? + .map_or(0, |p| p.idx + 1); + + // Could probably _also_ do this as an iterator, but let's see how this is for now. + // optimizing for minimal sqlite transactions, this code can be optimised later + for (n, record) in records.enumerate() { + let bytes = record.serialize()?; + + let record = Record::builder() + .host(Host::new(self.host_id)) + .version(HISTORY_VERSION.to_string()) + .tag(HISTORY_TAG.to_string()) + .idx(idx + n as u64) + .data(bytes) + .build(); + + let record = record.encrypt::(&self.encryption_key); + + ret.push(record); + } + + self.store.push_batch(ret.iter()).await?; + + Ok(()) + } + + pub async fn delete(&self, id: HistoryId) -> Result<(RecordId, RecordIdx)> { + let record = HistoryRecord::Delete(id); + + self.push_record(record).await + } + + /// Delete a batch of history entries via the record store. + /// Returns the record IDs so the caller can run incremental_build when ready. + pub async fn delete_entries( + &self, + entries: impl IntoIterator, + ) -> Result> { + let mut record_ids = Vec::new(); + for entry in entries { + let (id, _) = self.delete(entry.id).await?; + record_ids.push(id); + } + Ok(record_ids) + } + + pub async fn push(&self, history: History) -> Result<(RecordId, RecordIdx)> { + // TODO(ellie): move the history store to its own file + // it's tiny rn so fine as is + let record = HistoryRecord::Create(history); + + self.push_record(record).await + } + + pub async fn history(&self) -> Result> { + // Atm this loads all history into memory + // Not ideal as that is potentially quite a lot, although history will be small. + let records = self.store.all_tagged(HISTORY_TAG).await?; + let mut ret = Vec::with_capacity(records.len()); + + for record in records.into_iter() { + let hist = match record.version.as_str() { + HISTORY_VERSION_V0 | HISTORY_VERSION => { + let version = record.version.clone(); + let decrypted = record.decrypt::(&self.encryption_key)?; + + HistoryRecord::deserialize(&decrypted.data, version.as_str()) + } + version => bail!("unknown history version {version:?}"), + }?; + + ret.push(hist); + } + + Ok(ret) + } + + pub async fn build(&self, database: &dyn Database) -> Result<()> { + // I'd like to change how we rebuild and not couple this with the database, but need to + // consider the structure more deeply. This will be easy to change. + + // TODO(ellie): page or iterate this + let history = self.history().await?; + + // In theory we could flatten this here + // The current issue is that the database may have history in it already, from the old sync + // This didn't actually delete old history + // If we're sure we have a DB only maintained by the new store, we can flatten + // create/delete before we even get to sqlite + let mut creates = Vec::new(); + let mut deletes = Vec::new(); + + for i in history { + match i { + HistoryRecord::Create(h) => { + creates.push(h); + } + HistoryRecord::Delete(id) => { + deletes.push(id); + } + } + } + + database.save_bulk(&creates).await?; + database.delete_rows(&deletes).await?; + + Ok(()) + } + + pub async fn incremental_build(&self, database: &dyn Database, ids: &[RecordId]) -> Result<()> { + for id in ids { + let record = self.store.get(*id).await; + + let record = match record { + Ok(record) => record, + _ => { + continue; + } + }; + + if record.tag != HISTORY_TAG { + continue; + } + + let version = record.version.clone(); + let decrypted = record.decrypt::(&self.encryption_key)?; + let record = match version.as_str() { + HISTORY_VERSION_V0 | HISTORY_VERSION => { + HistoryRecord::deserialize(&decrypted.data, version.as_str())? + } + version => bail!("unknown history version {version:?}"), + }; + + match record { + HistoryRecord::Create(h) => { + // TODO: benchmark CPU time/memory tradeoff of batch commit vs one at a time + database.save(&h).await?; + } + HistoryRecord::Delete(id) => { + database.delete_rows(&[id]).await?; + } + } + } + + Ok(()) + } + + /// Get a list of history IDs that exist in the store + /// Note: This currently involves loading all history into memory. This is not going to be a + /// large amount in absolute terms, but do not all it in a hot loop. + pub async fn history_ids(&self) -> Result> { + let history = self.history().await?; + + let ret = HashSet::from_iter(history.iter().map(|h| match h { + HistoryRecord::Create(h) => h.id.clone(), + HistoryRecord::Delete(id) => id.clone(), + })); + + Ok(ret) + } + + pub async fn init_store(&self, db: &impl Database) -> Result<()> { + let pb = ProgressBar::new_spinner(); + pb.set_style( + ProgressStyle::with_template("{spinner:.blue} {msg}") + .unwrap() + .with_key("eta", |state: &ProgressState, w: &mut dyn Write| { + write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap() + }) + .progress_chars("#>-"), + ); + pb.enable_steady_tick(Duration::from_millis(500)); + + pb.set_message("Fetching history from old database"); + + let context = current_context().await?; + let history = db.list(&[], &context, None, false, true).await?; + + pb.set_message("Fetching history already in store"); + let store_ids = self.history_ids().await?; + + pb.set_message("Converting old history to new store"); + let mut records = Vec::new(); + + for i in history { + debug!("loaded {}", i.id); + + if store_ids.contains(&i.id) { + debug!("skipping {} - already exists", i.id); + continue; + } + + if i.deleted_at.is_some() { + records.push(HistoryRecord::Delete(i.id)); + } else { + records.push(HistoryRecord::Create(i)); + } + } + + pb.set_message("Writing to db"); + + if !records.is_empty() { + self.push_batch(records.into_iter()).await?; + } + + pb.finish_with_message("Import complete"); + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use crate::atuin_common::record::DecryptedData; + use time::macros::datetime; + + use crate::atuin_client::history::{HISTORY_VERSION, store::HistoryRecord}; + + use super::History; + + #[test] + fn test_serialize_deserialize_create() { + let bytes = [ + 204, 0, 196, 147, 205, 0, 1, 154, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 56, 49, + 55, 53, 55, 99, 100, 50, 97, 101, 101, 54, 53, 99, 100, 55, 56, 54, 49, 102, 57, 99, + 56, 49, 207, 23, 166, 251, 212, 181, 82, 0, 0, 100, 0, 162, 108, 115, 217, 41, 47, 85, + 115, 101, 114, 115, 47, 101, 108, 108, 105, 101, 47, 115, 114, 99, 47, 103, 105, 116, + 104, 117, 98, 46, 99, 111, 109, 47, 97, 116, 117, 105, 110, 115, 104, 47, 97, 116, 117, + 105, 110, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 97, 100, 56, 57, 55, 53, 57, 55, + 56, 53, 50, 53, 50, 55, 97, 51, 49, 99, 57, 57, 56, 48, 53, 57, 170, 98, 111, 111, 112, + 58, 101, 108, 108, 105, 101, 192, 165, 101, 108, 108, 105, 101, + ]; + + let history = History { + id: "018cd4fe81757cd2aee65cd7861f9c81".to_owned().into(), + timestamp: datetime!(2024-01-04 00:00:00.000000 +00:00), + duration: 100, + exit: 0, + command: "ls".to_owned(), + cwd: "/Users/ellie/src/github.com/atuinsh/atuin".to_owned(), + session: "018cd4fead897597852527a31c998059".to_owned(), + hostname: "boop:ellie".to_owned(), + author: "ellie".to_owned(), + intent: None, + deleted_at: None, + }; + + let record = HistoryRecord::Create(history); + + let serialized = record.serialize().expect("failed to serialize history"); + assert_eq!(serialized.0, bytes); + + let deserialized = HistoryRecord::deserialize(&serialized, HISTORY_VERSION) + .expect("failed to deserialize HistoryRecord"); + assert_eq!(deserialized, record); + + // check the snapshot too + let deserialized = + HistoryRecord::deserialize(&DecryptedData(Vec::from(bytes)), HISTORY_VERSION) + .expect("failed to deserialize HistoryRecord"); + assert_eq!(deserialized, record); + } + + #[test] + fn test_serialize_deserialize_delete() { + let bytes = [ + 204, 1, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 56, 49, 55, 53, 55, 99, 100, 50, + 97, 101, 101, 54, 53, 99, 100, 55, 56, 54, 49, 102, 57, 99, 56, 49, + ]; + let record = HistoryRecord::Delete("018cd4fe81757cd2aee65cd7861f9c81".to_string().into()); + + let serialized = record.serialize().expect("failed to serialize history"); + assert_eq!(serialized.0, bytes); + + let deserialized = HistoryRecord::deserialize(&serialized, HISTORY_VERSION) + .expect("failed to deserialize HistoryRecord"); + assert_eq!(deserialized, record); + + let deserialized = + HistoryRecord::deserialize(&DecryptedData(Vec::from(bytes)), HISTORY_VERSION) + .expect("failed to deserialize HistoryRecord"); + assert_eq!(deserialized, record); + } +} diff --git a/crates/turtle/src/atuin_client/import/bash.rs b/crates/turtle/src/atuin_client/import/bash.rs new file mode 100644 index 00000000..d92fdfa0 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/bash.rs @@ -0,0 +1,221 @@ +use std::{path::PathBuf, str}; + +use async_trait::async_trait; +use directories::UserDirs; +use eyre::{Result, eyre}; +use itertools::Itertools; +use time::{Duration, OffsetDateTime}; +use tracing::warn; + +use super::{Importer, Loader, get_histfile_path, unix_byte_lines}; +use crate::atuin_client::history::History; +use crate::atuin_client::import::read_to_end; + +#[derive(Debug)] +pub struct Bash { + bytes: Vec, +} + +fn default_histpath() -> Result { + let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; + let home_dir = user_dirs.home_dir(); + + Ok(home_dir.join(".bash_history")) +} + +#[async_trait] +impl Importer for Bash { + const NAME: &'static str = "bash"; + + async fn new() -> Result { + let bytes = read_to_end(get_histfile_path(default_histpath)?)?; + Ok(Self { bytes }) + } + + async fn entries(&mut self) -> Result { + let count = unix_byte_lines(&self.bytes) + .map(LineType::from) + .filter(|line| matches!(line, LineType::Command(_))) + .count(); + Ok(count) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + let lines = unix_byte_lines(&self.bytes) + .map(LineType::from) + .filter(|line| !matches!(line, LineType::NotUtf8)) // invalid utf8 are ignored + .collect_vec(); + + let (commands_before_first_timestamp, first_timestamp) = lines + .iter() + .enumerate() + .find_map(|(i, line)| match line { + LineType::Timestamp(t) => Some((i, *t)), + _ => None, + }) + // if no known timestamps, use now as base + .unwrap_or((lines.len(), OffsetDateTime::now_utc())); + + // if no timestamp is recorded, then use this increment to set an arbitrary timestamp + // to preserve ordering + // this increment is deliberately very small to prevent particularly fast fingers + // causing ordering issues; it also helps in handling the "here document" syntax, + // where several lines are recorded in succession without individual timestamps + let timestamp_increment = Duration::milliseconds(1); + + // make sure there is a minimum amount of time before the first known timestamp + // to fit all commands, given the default increment + let mut next_timestamp = + first_timestamp - timestamp_increment * commands_before_first_timestamp as i32; + + for line in lines.into_iter() { + match line { + LineType::NotUtf8 => unreachable!(), // already filtered + LineType::Empty => {} // do nothing + LineType::Timestamp(t) => { + if t < next_timestamp { + warn!( + "Time reversal detected in Bash history! Commands may be ordered incorrectly." + ); + } + next_timestamp = t; + } + LineType::Command(c) => { + let imported = History::import().timestamp(next_timestamp).command(c); + + h.push(imported.build().into()).await?; + next_timestamp += timestamp_increment; + } + } + } + + Ok(()) + } +} + +#[derive(Debug, Clone)] +enum LineType<'a> { + NotUtf8, + /// Can happen when using the "here document" syntax. + Empty, + /// A timestamp line start with a '#', followed immediately by an integer + /// that represents seconds since UNIX epoch. + Timestamp(OffsetDateTime), + /// Anything else. + Command(&'a str), +} +impl<'a> From<&'a [u8]> for LineType<'a> { + fn from(bytes: &'a [u8]) -> Self { + let Ok(line) = str::from_utf8(bytes) else { + return LineType::NotUtf8; + }; + if line.is_empty() { + return LineType::Empty; + } + + match try_parse_line_as_timestamp(line) { + Some(time) => LineType::Timestamp(time), + None => LineType::Command(line), + } + } +} + +fn try_parse_line_as_timestamp(line: &str) -> Option { + let seconds = line.strip_prefix('#')?.parse().ok()?; + OffsetDateTime::from_unix_timestamp(seconds).ok() +} + +#[cfg(test)] +mod test { + use std::cmp::Ordering; + + use itertools::{Itertools, assert_equal}; + + use crate::atuin_client::import::{Importer, tests::TestLoader}; + + use super::Bash; + + #[tokio::test] + async fn parse_no_timestamps() { + let bytes = r"cargo install atuin +cargo update +cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷ +" + .as_bytes() + .to_owned(); + + let mut bash = Bash { bytes }; + assert_eq!(bash.entries().await.unwrap(), 3); + + let mut loader = TestLoader::default(); + bash.load(&mut loader).await.unwrap(); + + assert_equal( + loader.buf.iter().map(|h| h.command.as_str()), + [ + "cargo install atuin", + "cargo update", + "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", + ], + ); + assert!(is_strictly_sorted(loader.buf.iter().map(|h| h.timestamp))) + } + + #[tokio::test] + async fn parse_with_timestamps() { + let bytes = b"#1672918999 +git reset +#1672919006 +git clean -dxf +#1672919020 +cd ../ +" + .to_vec(); + + let mut bash = Bash { bytes }; + assert_eq!(bash.entries().await.unwrap(), 3); + + let mut loader = TestLoader::default(); + bash.load(&mut loader).await.unwrap(); + + assert_equal( + loader.buf.iter().map(|h| h.command.as_str()), + ["git reset", "git clean -dxf", "cd ../"], + ); + assert_equal( + loader.buf.iter().map(|h| h.timestamp.unix_timestamp()), + [1_672_918_999, 1_672_919_006, 1_672_919_020], + ) + } + + #[tokio::test] + async fn parse_with_partial_timestamps() { + let bytes = b"git reset +#1672919006 +git clean -dxf +cd ../ +" + .to_vec(); + + let mut bash = Bash { bytes }; + assert_eq!(bash.entries().await.unwrap(), 3); + + let mut loader = TestLoader::default(); + bash.load(&mut loader).await.unwrap(); + + assert_equal( + loader.buf.iter().map(|h| h.command.as_str()), + ["git reset", "git clean -dxf", "cd ../"], + ); + assert!(is_strictly_sorted(loader.buf.iter().map(|h| h.timestamp))) + } + + fn is_strictly_sorted(iter: impl IntoIterator) -> bool + where + T: Clone + PartialOrd, + { + iter.into_iter() + .tuple_windows() + .all(|(a, b)| matches!(a.partial_cmp(&b), Some(Ordering::Less))) + } +} diff --git a/crates/turtle/src/atuin_client/import/fish.rs b/crates/turtle/src/atuin_client/import/fish.rs new file mode 100644 index 00000000..1375bdd6 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/fish.rs @@ -0,0 +1,179 @@ +// import old shell history! +// automatically hoover up all that we can find + +use std::path::PathBuf; + +use async_trait::async_trait; +use directories::BaseDirs; +use eyre::{Result, eyre}; +use time::OffsetDateTime; + +use super::{Importer, Loader, unix_byte_lines}; +use crate::atuin_client::history::History; +use crate::atuin_client::import::read_to_end; + +#[derive(Debug)] +pub struct Fish { + bytes: Vec, +} + +/// see https://fishshell.com/docs/current/interactive.html#searchable-command-history +fn default_histpath() -> Result { + let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; + let data = std::env::var("XDG_DATA_HOME").map_or_else( + |_| base.home_dir().join(".local").join("share"), + PathBuf::from, + ); + + // fish supports multiple history sessions + // If `fish_history` var is missing, or set to `default`, use `fish` as the session + let session = std::env::var("fish_history").unwrap_or_else(|_| String::from("fish")); + let session = if session == "default" { + String::from("fish") + } else { + session + }; + + let mut histpath = data.join("fish"); + histpath.push(format!("{session}_history")); + + if histpath.exists() { + Ok(histpath) + } else { + Err(eyre!("Could not find history file.")) + } +} + +#[async_trait] +impl Importer for Fish { + const NAME: &'static str = "fish"; + + async fn new() -> Result { + let bytes = read_to_end(default_histpath()?)?; + Ok(Self { bytes }) + } + + async fn entries(&mut self) -> Result { + Ok(super::count_lines(&self.bytes)) + } + + async fn load(self, loader: &mut impl Loader) -> Result<()> { + let now = OffsetDateTime::now_utc(); + let mut time: Option = None; + let mut cmd: Option = None; + + for b in unix_byte_lines(&self.bytes) { + let s = match std::str::from_utf8(b) { + Ok(s) => s, + Err(_) => continue, // we can skip past things like invalid utf8 + }; + + if let Some(c) = s.strip_prefix("- cmd: ") { + // first, we must deal with the prev cmd + if let Some(cmd) = cmd.take() { + let time = time.unwrap_or(now); + let entry = History::import().timestamp(time).command(cmd); + + loader.push(entry.build().into()).await?; + } + + // using raw strings to avoid needing escaping. + // replaces double backslashes with single backslashes + let c = c.replace(r"\\", r"\"); + // replaces escaped newlines + let c = c.replace(r"\n", "\n"); + // TODO: any other escape characters? + + cmd = Some(c); + } else if let Some(t) = s.strip_prefix(" when: ") { + // if t is not an int, just ignore this line + if let Ok(t) = t.parse::() { + time = Some(OffsetDateTime::from_unix_timestamp(t)?); + } + } else { + // ... ignore paths lines + } + } + + // we might have a trailing cmd + if let Some(cmd) = cmd.take() { + let time = time.unwrap_or(now); + let entry = History::import().timestamp(time).command(cmd); + + loader.push(entry.build().into()).await?; + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + + use crate::import::{Importer, tests::TestLoader}; + + use super::Fish; + + #[tokio::test] + async fn parse_complex() { + // complicated input with varying contents and escaped strings. + let bytes = r#"- cmd: history --help + when: 1639162832 +- cmd: cat ~/.bash_history + when: 1639162851 + paths: + - ~/.bash_history +- cmd: ls ~/.local/share/fish/fish_history + when: 1639162890 + paths: + - ~/.local/share/fish/fish_history +- cmd: cat ~/.local/share/fish/fish_history + when: 1639162893 + paths: + - ~/.local/share/fish/fish_history +ERROR +- CORRUPTED: ENTRY + CONTINUE: + - AS + - NORMAL +- cmd: echo "foo" \\\n'bar' baz + when: 1639162933 +- cmd: cat ~/.local/share/fish/fish_history + when: 1639162939 + paths: + - ~/.local/share/fish/fish_history +- cmd: echo "\\"" \\\\ "\\\\" + when: 1639163063 +- cmd: cat ~/.local/share/fish/fish_history + when: 1639163066 + paths: + - ~/.local/share/fish/fish_history +"# + .as_bytes() + .to_owned(); + + let fish = Fish { bytes }; + + let mut loader = TestLoader::default(); + fish.load(&mut loader).await.unwrap(); + let mut history = loader.buf.into_iter(); + + // simple wrapper for fish history entry + macro_rules! fishtory { + ($timestamp:expr_2021, $command:expr_2021) => { + let h = history.next().expect("missing entry in history"); + assert_eq!(h.command.as_str(), $command); + assert_eq!(h.timestamp.unix_timestamp(), $timestamp); + }; + } + + fishtory!(1639162832, "history --help"); + fishtory!(1639162851, "cat ~/.bash_history"); + fishtory!(1639162890, "ls ~/.local/share/fish/fish_history"); + fishtory!(1639162893, "cat ~/.local/share/fish/fish_history"); + fishtory!(1639162933, "echo \"foo\" \\\n'bar' baz"); + fishtory!(1639162939, "cat ~/.local/share/fish/fish_history"); + fishtory!(1639163063, r#"echo "\"" \\ "\\""#); + fishtory!(1639163066, "cat ~/.local/share/fish/fish_history"); + } +} diff --git a/crates/turtle/src/atuin_client/import/mod.rs b/crates/turtle/src/atuin_client/import/mod.rs new file mode 100644 index 00000000..7726ead7 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/mod.rs @@ -0,0 +1,140 @@ +use std::fs::File; +use std::io::Read; +use std::path::PathBuf; + +use async_trait::async_trait; +use eyre::{Result, bail}; +use memchr::Memchr; + +use crate::atuin_client::history::History; + +pub mod bash; +pub mod fish; +pub mod nu; +pub mod nu_histdb; +pub mod powershell; +pub mod replxx; +pub mod resh; +pub mod xonsh; +pub mod xonsh_sqlite; +pub mod zsh; +pub mod zsh_histdb; + +#[async_trait] +pub trait Importer: Sized { + const NAME: &'static str; + async fn new() -> Result; + async fn entries(&mut self) -> Result; + async fn load(self, loader: &mut impl Loader) -> Result<()>; +} + +#[async_trait] +pub trait Loader: Sync + Send { + async fn push(&mut self, hist: History) -> eyre::Result<()>; +} + +fn unix_byte_lines(input: &[u8]) -> impl Iterator { + UnixByteLines { + iter: memchr::memchr_iter(b'\n', input), + bytes: input, + i: 0, + } +} + +struct UnixByteLines<'a> { + iter: Memchr<'a>, + bytes: &'a [u8], + i: usize, +} + +impl<'a> Iterator for UnixByteLines<'a> { + type Item = &'a [u8]; + + fn next(&mut self) -> Option { + let j = self.iter.next()?; + let out = &self.bytes[self.i..j]; + self.i = j + 1; + Some(out) + } + + fn count(self) -> usize + where + Self: Sized, + { + self.iter.count() + } +} + +fn count_lines(input: &[u8]) -> usize { + unix_byte_lines(input).count() +} + +fn get_histpath(def: D) -> Result +where + D: FnOnce() -> Result, +{ + if let Ok(p) = std::env::var("HISTFILE") { + Ok(PathBuf::from(p)) + } else { + def() + } +} + +fn get_histfile_path(def: D) -> Result +where + D: FnOnce() -> Result, +{ + get_histpath(def).and_then(is_file) +} + +fn get_histdir_path(def: D) -> Result +where + D: FnOnce() -> Result, +{ + get_histpath(def).and_then(is_dir) +} + +fn read_to_end(path: PathBuf) -> Result> { + let mut bytes = Vec::new(); + let mut f = File::open(path)?; + f.read_to_end(&mut bytes)?; + Ok(bytes) +} +fn is_file(p: PathBuf) -> Result { + if p.is_file() { + Ok(p) + } else { + bail!( + "Could not find history file {:?}. Try setting and exporting $HISTFILE", + p + ) + } +} +fn is_dir(p: PathBuf) -> Result { + if p.is_dir() { + Ok(p) + } else { + bail!( + "Could not find history directory {:?}. Try setting and exporting $HISTFILE", + p + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Default)] + pub struct TestLoader { + pub buf: Vec, + } + + #[async_trait] + impl Loader for TestLoader { + async fn push(&mut self, hist: History) -> Result<()> { + self.buf.push(hist); + Ok(()) + } + } +} diff --git a/crates/turtle/src/atuin_client/import/nu.rs b/crates/turtle/src/atuin_client/import/nu.rs new file mode 100644 index 00000000..c93789b8 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/nu.rs @@ -0,0 +1,67 @@ +// import old shell history! +// automatically hoover up all that we can find + +use std::path::PathBuf; + +use async_trait::async_trait; +use directories::BaseDirs; +use eyre::{Result, eyre}; +use time::OffsetDateTime; + +use super::{Importer, Loader, unix_byte_lines}; +use crate::atuin_client::history::History; +use crate::atuin_client::import::read_to_end; + +#[derive(Debug)] +pub struct Nu { + bytes: Vec, +} + +fn get_histpath() -> Result { + let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; + let config_dir = base.config_dir().join("nushell"); + + let histpath = config_dir.join("history.txt"); + if histpath.exists() { + Ok(histpath) + } else { + Err(eyre!("Could not find history file.")) + } +} + +#[async_trait] +impl Importer for Nu { + const NAME: &'static str = "nu"; + + async fn new() -> Result { + let bytes = read_to_end(get_histpath()?)?; + Ok(Self { bytes }) + } + + async fn entries(&mut self) -> Result { + Ok(super::count_lines(&self.bytes)) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + let now = OffsetDateTime::now_utc(); + + let mut counter = 0; + for b in unix_byte_lines(&self.bytes) { + let s = match std::str::from_utf8(b) { + Ok(s) => s, + Err(_) => continue, // we can skip past things like invalid utf8 + }; + + let cmd: String = s.replace("<\\n>", "\n"); + + let offset = time::Duration::nanoseconds(counter); + counter += 1; + + let entry = History::import().timestamp(now - offset).command(cmd); + + h.push(entry.build().into()).await?; + } + + Ok(()) + } +} diff --git a/crates/turtle/src/atuin_client/import/nu_histdb.rs b/crates/turtle/src/atuin_client/import/nu_histdb.rs new file mode 100644 index 00000000..7de18369 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/nu_histdb.rs @@ -0,0 +1,113 @@ +// import old shell history! +// automatically hoover up all that we can find + +use std::path::PathBuf; + +use async_trait::async_trait; +use directories::BaseDirs; +use eyre::{Result, eyre}; +use sqlx::{Pool, sqlite::SqlitePool}; +use time::{Duration, OffsetDateTime}; + +use super::Importer; +use crate::atuin_client::history::History; +use crate::atuin_client::import::Loader; + +#[derive(sqlx::FromRow, Debug)] +pub struct HistDbEntry { + pub id: i64, + pub command_line: Vec, + pub start_timestamp: i64, + pub session_id: i64, + pub hostname: Vec, + pub cwd: Vec, + pub duration_ms: i64, + pub exit_status: i64, + pub more_info: Vec, +} + +impl From for History { + fn from(histdb_item: HistDbEntry) -> Self { + let ts_secs = histdb_item.start_timestamp / 1000; + let ts_ns = (histdb_item.start_timestamp % 1000) * 1_000_000; + let imported = History::import() + .timestamp( + OffsetDateTime::from_unix_timestamp(ts_secs).unwrap() + + Duration::nanoseconds(ts_ns), + ) + .command(String::from_utf8(histdb_item.command_line).unwrap()) + .cwd(String::from_utf8(histdb_item.cwd).unwrap()) + .exit(histdb_item.exit_status) + .duration(histdb_item.duration_ms) + .session(format!("{:x}", histdb_item.session_id)) + .hostname(String::from_utf8(histdb_item.hostname).unwrap()); + + imported.build().into() + } +} + +#[derive(Debug)] +pub struct NuHistDb { + histdb: Vec, +} + +/// Read db at given file, return vector of entries. +async fn hist_from_db(dbpath: PathBuf) -> Result> { + let pool = SqlitePool::connect(dbpath.to_str().unwrap()).await?; + hist_from_db_conn(pool).await +} + +async fn hist_from_db_conn(pool: Pool) -> Result> { + let query = r#" + SELECT + id, command_line, start_timestamp, session_id, hostname, cwd, duration_ms, exit_status, + more_info + FROM history + ORDER BY start_timestamp + "#; + let histdb_vec: Vec = sqlx::query_as::<_, HistDbEntry>(query) + .fetch_all(&pool) + .await?; + Ok(histdb_vec) +} + +impl NuHistDb { + pub fn histpath() -> Result { + let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; + let config_dir = base.config_dir().join("nushell"); + + let histdb_path = config_dir.join("history.sqlite3"); + if histdb_path.exists() { + Ok(histdb_path) + } else { + Err(eyre!("Could not find history file.")) + } + } +} + +#[async_trait] +impl Importer for NuHistDb { + // Not sure how this is used + const NAME: &'static str = "nu_histdb"; + + /// Creates a new NuHistDb and populates the history based on the pre-populated data + /// structure. + async fn new() -> Result { + let dbpath = NuHistDb::histpath()?; + let histdb_entry_vec = hist_from_db(dbpath).await?; + Ok(Self { + histdb: histdb_entry_vec, + }) + } + + async fn entries(&mut self) -> Result { + Ok(self.histdb.len()) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + for i in self.histdb { + h.push(i.into()).await?; + } + Ok(()) + } +} diff --git a/crates/turtle/src/atuin_client/import/powershell.rs b/crates/turtle/src/atuin_client/import/powershell.rs new file mode 100644 index 00000000..8adcc850 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/powershell.rs @@ -0,0 +1,202 @@ +use async_trait::async_trait; +use directories::BaseDirs; +use eyre::{Result, eyre}; +use std::path::PathBuf; +use time::{Duration, OffsetDateTime}; + +use super::{Importer, Loader, count_lines, unix_byte_lines}; +use crate::atuin_client::history::History; +use crate::atuin_client::import::read_to_end; + +#[derive(Debug)] +pub struct PowerShell { + bytes: Vec, + line_count: Option, +} + +fn get_history_path() -> Result { + let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; + + // The command line history in PowerShell is maintained by the PSReadLine module: + // https://learn.microsoft.com/en-us/powershell/module/psreadline/about/about_psreadline#command-history + // + // > PSReadLine maintains a history file containing all the commands and data you've entered from the command line. + // > The history files are a file named `$($Host.Name)_history.txt`. + // > On Windows systems the history file is stored at `$Env:APPDATA\Microsoft\Windows\PowerShell\PSReadLine`. + // > On non-Windows systems, the history files are stored at `$Env:XDG_DATA_HOME/powershell/PSReadLine` + // > or `$Env:HOME/.local/share/powershell/PSReadLine`. + + let dir = if cfg!(windows) { + base.data_dir() + .join("Microsoft") + .join("Windows") + .join("PowerShell") + .join("PSReadLine") + } else { + std::env::var("XDG_DATA_HOME") + .map_or_else( + |_| base.home_dir().join(".local").join("share"), + PathBuf::from, + ) + .join("powershell") + .join("PSReadLine") + }; + + // The history is stored in a file named `$($Host.Name)_history.txt`. + // For the default console host shipped by Microsoft,`$Host.Name` is `ConsoleHost`: + // https://learn.microsoft.com/en-us/dotnet/api/system.management.automation.host.pshost.name#remarks + + let file = dir.join("ConsoleHost_history.txt"); + + if file.is_file() { + Ok(file) + } else { + Err(eyre!("Could not find history file: {}", file.display())) + } +} + +#[async_trait] +impl Importer for PowerShell { + const NAME: &'static str = "PowerShell"; + + async fn new() -> Result { + let bytes = read_to_end(get_history_path()?)?; + Ok(Self { + bytes, + line_count: None, + }) + } + + async fn entries(&mut self) -> Result { + // Commands can be split over multiple lines, + // but this is only used for a progress bar, and multi-line commands + // should be quite rare, so this is not an issue in practice. + if self.line_count.is_none() { + self.line_count = Some(count_lines(&self.bytes)); + } + Ok(self.line_count.unwrap()) + } + + async fn load(mut self, h: &mut impl Loader) -> Result<()> { + let line_count = self.entries().await?; + let start = OffsetDateTime::now_utc() - Duration::milliseconds(line_count as i64); + + let mut counter = 0; + let mut iter = unix_byte_lines(&self.bytes); + + while let Some(s) = iter.next() { + let Ok(s) = read_line(s) else { + continue; // We can skip past things like invalid utf8 + }; + + let mut cmd = s.to_string(); + + // Multi-line commands end with a backtick, append the following lines. + while cmd.ends_with('`') { + cmd.pop(); + + let Some(next) = iter.next() else { + break; + }; + let Ok(next) = read_line(next) else { + break; + }; + + cmd.push('\n'); + cmd.push_str(next); + } + + if cmd.is_empty() { + continue; + } + + let offset = Duration::milliseconds(counter); + counter += 1; + + let entry = History::import().timestamp(start + offset).command(cmd); + h.push(entry.build().into()).await?; + } + + Ok(()) + } +} + +fn read_line(s: &[u8]) -> Result<&str> { + let s = str::from_utf8(s)?; + + // History is stored in CRLF on Windows, normalize the input to LF on all platforms. + let s = s.strip_suffix('\r').unwrap_or(s); + + Ok(s) +} + +#[cfg(test)] +mod test { + use super::*; + use crate::import::tests::TestLoader; + use itertools::assert_equal; + + const INPUT: &str = r#"cargo install atuin +cargo update +echo "first line` +second line` +` +last line" +echo foo + +echo bar +echo baz +"#; + + const EXPECTED: &[&str] = &[ + "cargo install atuin", + "cargo update", + "echo \"first line\nsecond line\n\nlast line\"", + "echo foo", + "echo bar", + "echo baz", + ]; + + #[tokio::test] + async fn test_import() { + let loader = import(INPUT).await; + + let actual = loader.buf.iter().map(|h| h.command.clone()); + let expected = EXPECTED.iter().map(|s| s.to_string()); + + assert_equal(actual, expected); + } + + #[tokio::test] + async fn test_crlf() { + let input = INPUT.replace("\n", "\r\n"); + let loader = import(input.as_str()).await; + + let actual = loader.buf.iter().map(|h| h.command.clone()); + let expected = EXPECTED.iter().map(|s| s.to_string()); + + assert_equal(actual, expected); + } + + #[tokio::test] + async fn test_timestamps() { + let loader = import(INPUT).await; + + let mut prev = loader.buf.first().unwrap().timestamp; + for current in loader.buf.iter().skip(1).map(|h| h.timestamp) { + assert!(current > prev); + prev = current; + } + } + + async fn import(input: &str) -> TestLoader { + let powershell = PowerShell { + bytes: input.as_bytes().to_vec(), + line_count: None, + }; + + let mut loader = TestLoader::default(); + powershell.load(&mut loader).await.unwrap(); + loader + } +} diff --git a/crates/turtle/src/atuin_client/import/replxx.rs b/crates/turtle/src/atuin_client/import/replxx.rs new file mode 100644 index 00000000..42f84df5 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/replxx.rs @@ -0,0 +1,137 @@ +use std::{path::PathBuf, str}; + +use async_trait::async_trait; +use directories::UserDirs; +use eyre::{Result, eyre}; +use time::{OffsetDateTime, PrimitiveDateTime, macros::format_description}; + +use super::{Importer, Loader, get_histfile_path, unix_byte_lines}; +use crate::atuin_client::history::History; +use crate::atuin_client::import::read_to_end; + +#[derive(Debug)] +pub struct Replxx { + bytes: Vec, +} + +fn default_histpath() -> Result { + let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; + let home_dir = user_dirs.home_dir(); + + // There is no default histfile for replxx. + // Here we try a couple of common names. + let mut candidates = ["replxx_history.txt", ".histfile"].iter(); + loop { + match candidates.next() { + Some(candidate) => { + let histpath = home_dir.join(candidate); + if histpath.exists() { + break Ok(histpath); + } + } + None => { + break Err(eyre!( + "Could not find history file. Try setting and exporting $HISTFILE" + )); + } + } + } +} + +#[async_trait] +impl Importer for Replxx { + const NAME: &'static str = "replxx"; + + async fn new() -> Result { + let bytes = read_to_end(get_histfile_path(default_histpath)?)?; + Ok(Self { bytes }) + } + + async fn entries(&mut self) -> Result { + Ok(super::count_lines(&self.bytes) / 2) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + let mut timestamp = OffsetDateTime::UNIX_EPOCH; + + for b in unix_byte_lines(&self.bytes) { + let s = std::str::from_utf8(b)?; + match try_parse_line_as_timestamp(s) { + Some(t) => timestamp = t, + None => { + // replxx uses ETB character (0x17) as line breaker + let cmd = s.replace('\u{0017}', "\n"); + let imported = History::import().timestamp(timestamp).command(cmd); + + h.push(imported.build().into()).await?; + } + } + } + + Ok(()) + } +} + +fn try_parse_line_as_timestamp(line: &str) -> Option { + // replxx history date time format: ### yyyy-mm-dd hh:mm:ss.xxx + let date_time_str = line.strip_prefix("### ")?; + let format = + format_description!("[year]-[month]-[day] [hour]:[minute]:[second].[subsecond digits:3]"); + + let primitive_date_time = PrimitiveDateTime::parse(date_time_str, format).ok()?; + // There is no safe way to get local time offset. + // For simplicity let's just assume UTC. + Some(primitive_date_time.assume_utc()) +} + +#[cfg(test)] +mod test { + + use crate::import::{Importer, tests::TestLoader}; + + use super::Replxx; + + #[tokio::test] + async fn parse_complex() { + let bytes = r#"### 2024-02-10 22:16:28.302 +select * from remote('127.0.0.1:20222', view(select 1)) +### 2024-02-10 22:16:36.919 +select * from numbers(10) +### 2024-02-10 22:16:41.710 +select * from system.numbers +### 2024-02-10 22:19:28.655 +select 1 +### 2024-02-22 11:15:33.046 +CREATE TABLE test( stamp DateTime('UTC'))ENGINE = MergeTreePARTITION BY toDate(stamp)order by tuple() as select toDateTime('2020-01-01')+number*60 from numbers(80000); +"# + .as_bytes() + .to_owned(); + + let replxx = Replxx { bytes }; + + let mut loader = TestLoader::default(); + replxx.load(&mut loader).await.unwrap(); + let mut history = loader.buf.into_iter(); + + // simple wrapper for replxx history entry + macro_rules! history { + ($timestamp:expr_2021, $command:expr_2021) => { + let h = history.next().expect("missing entry in history"); + assert_eq!(h.command.as_str(), $command); + assert_eq!(h.timestamp.unix_timestamp(), $timestamp); + }; + } + + history!( + 1707603388, + "select * from remote('127.0.0.1:20222', view(select 1))" + ); + history!(1707603396, "select * from numbers(10)"); + history!(1707603401, "select * from system.numbers"); + history!(1707603568, "select 1"); + history!( + 1708600533, + "CREATE TABLE test\n( stamp DateTime('UTC'))\nENGINE = MergeTree\nPARTITION BY toDate(stamp)\norder by tuple() as select toDateTime('2020-01-01')+number*60 from numbers(80000);" + ); + } +} diff --git a/crates/turtle/src/atuin_client/import/resh.rs b/crates/turtle/src/atuin_client/import/resh.rs new file mode 100644 index 00000000..c5980c44 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/resh.rs @@ -0,0 +1,140 @@ +use std::path::PathBuf; + +use async_trait::async_trait; +use directories::UserDirs; +use eyre::{Result, eyre}; +use serde::Deserialize; + +use crate::atuin_common::utils::uuid_v7; +use time::OffsetDateTime; + +use super::{Importer, Loader, get_histfile_path, unix_byte_lines}; +use crate::atuin_client::history::History; +use crate::atuin_client::import::read_to_end; + +#[derive(Deserialize, Debug)] +#[serde(rename_all = "camelCase")] +pub struct ReshEntry { + pub cmd_line: String, + pub exit_code: i64, + pub shell: String, + pub uname: String, + pub session_id: String, + pub home: String, + pub lang: String, + pub lc_all: String, + pub login: String, + pub pwd: String, + pub pwd_after: String, + pub shell_env: String, + pub term: String, + pub real_pwd: String, + pub real_pwd_after: String, + pub pid: i64, + pub session_pid: i64, + pub host: String, + pub hosttype: String, + pub ostype: String, + pub machtype: String, + pub shlvl: i64, + pub timezone_before: String, + pub timezone_after: String, + pub realtime_before: f64, + pub realtime_after: f64, + pub realtime_before_local: f64, + pub realtime_after_local: f64, + pub realtime_duration: f64, + pub realtime_since_session_start: f64, + pub realtime_since_boot: f64, + pub git_dir: String, + pub git_real_dir: String, + pub git_origin_remote: String, + pub git_dir_after: String, + pub git_real_dir_after: String, + pub git_origin_remote_after: String, + pub machine_id: String, + pub os_release_id: String, + pub os_release_version_id: String, + pub os_release_id_like: String, + pub os_release_name: String, + pub os_release_pretty_name: String, + pub resh_uuid: String, + pub resh_version: String, + pub resh_revision: String, + pub parts_merged: bool, + pub recalled: bool, + pub recall_last_cmd_line: String, + pub cols: String, + pub lines: String, +} + +#[derive(Debug)] +pub struct Resh { + bytes: Vec, +} + +fn default_histpath() -> Result { + let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; + let home_dir = user_dirs.home_dir(); + + Ok(home_dir.join(".resh_history.json")) +} + +#[async_trait] +impl Importer for Resh { + const NAME: &'static str = "resh"; + + async fn new() -> Result { + let bytes = read_to_end(get_histfile_path(default_histpath)?)?; + Ok(Self { bytes }) + } + + async fn entries(&mut self) -> Result { + Ok(super::count_lines(&self.bytes)) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + for b in unix_byte_lines(&self.bytes) { + let s = match std::str::from_utf8(b) { + Ok(s) => s, + Err(_) => continue, // we can skip past things like invalid utf8 + }; + let entry = match serde_json::from_str::(s) { + Ok(e) => e, + Err(_) => continue, // skip invalid json :shrug: + }; + + #[expect(clippy::cast_possible_truncation)] + #[expect(clippy::cast_sign_loss)] + let timestamp = { + let secs = entry.realtime_before.floor() as i64; + let nanosecs = (entry.realtime_before.fract() * 1_000_000_000_f64).round() as i64; + OffsetDateTime::from_unix_timestamp(secs)? + time::Duration::nanoseconds(nanosecs) + }; + #[expect(clippy::cast_possible_truncation)] + #[expect(clippy::cast_sign_loss)] + let duration = { + let secs = entry.realtime_after.floor() as i64; + let nanosecs = (entry.realtime_after.fract() * 1_000_000_000_f64).round() as i64; + let base = OffsetDateTime::from_unix_timestamp(secs)? + + time::Duration::nanoseconds(nanosecs); + let difference = base - timestamp; + difference.whole_nanoseconds() as i64 + }; + + let imported = History::import() + .command(entry.cmd_line) + .timestamp(timestamp) + .duration(duration) + .exit(entry.exit_code) + .cwd(entry.pwd) + .hostname(entry.host) + // CHECK: should we add uuid here? It's not set in the other importers + .session(uuid_v7().as_simple().to_string()); + + h.push(imported.build().into()).await?; + } + + Ok(()) + } +} diff --git a/crates/turtle/src/atuin_client/import/xonsh.rs b/crates/turtle/src/atuin_client/import/xonsh.rs new file mode 100644 index 00000000..a7217826 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/xonsh.rs @@ -0,0 +1,234 @@ +use std::env; +use std::fs::{self, File}; +use std::path::{Path, PathBuf}; + +use async_trait::async_trait; +use directories::BaseDirs; +use eyre::{Result, eyre}; +use serde::Deserialize; +use time::OffsetDateTime; +use uuid::Uuid; +use uuid::timestamp::{Timestamp, context::NoContext}; + +use super::{Importer, Loader, get_histdir_path}; +use crate::atuin_client::history::History; +use crate::atuin_client::utils::get_host_user; + +// Note: both HistoryFile and HistoryData have other keys present in the JSON, we don't +// care about them so we leave them unspecified so as to avoid deserializing unnecessarily. +#[derive(Debug, Deserialize)] +struct HistoryFile { + data: HistoryData, +} + +#[derive(Debug, Deserialize)] +struct HistoryData { + sessionid: String, + cmds: Vec, +} + +#[derive(Debug, Deserialize)] +struct HistoryCmd { + cwd: String, + inp: String, + rtn: Option, + ts: (f64, f64), +} + +#[derive(Debug)] +pub struct Xonsh { + // history is stored as a bunch of json files, one per session + sessions: Vec, + hostname: String, +} + +fn xonsh_hist_dir(xonsh_data_dir: Option) -> Result { + // if running within xonsh, this will be available + if let Some(d) = xonsh_data_dir { + let mut path = PathBuf::from(d); + path.push("history_json"); + return Ok(path); + } + + // otherwise, fall back to default + let base = BaseDirs::new().ok_or_else(|| eyre!("Could not determine home directory"))?; + + let hist_dir = base.data_dir().join("xonsh/history_json"); + if hist_dir.exists() || cfg!(test) { + Ok(hist_dir) + } else { + Err(eyre!("Could not find xonsh history files")) + } +} + +fn load_sessions(hist_dir: &Path) -> Result> { + let mut sessions = vec![]; + for entry in fs::read_dir(hist_dir)? { + let p = entry?.path(); + let ext = p.extension().and_then(|e| e.to_str()); + if p.is_file() + && ext == Some("json") + && let Some(data) = load_session(&p)? + { + sessions.push(data); + } + } + Ok(sessions) +} + +fn load_session(path: &Path) -> Result> { + let file = File::open(path)?; + // empty files are not valid json, so we can't deserialize them + if file.metadata()?.len() == 0 { + return Ok(None); + } + + let mut hist_file: HistoryFile = serde_json::from_reader(file)?; + + // if there are commands in this session, replace the existing UUIDv4 + // with a UUIDv7 generated from the timestamp of the first command + if let Some(cmd) = hist_file.data.cmds.first() { + let seconds = cmd.ts.0.trunc() as u64; + let nanos = (cmd.ts.0.fract() * 1_000_000_000_f64) as u32; + let ts = Timestamp::from_unix(NoContext, seconds, nanos); + hist_file.data.sessionid = Uuid::new_v7(ts).to_string(); + } + Ok(Some(hist_file.data)) +} + +#[async_trait] +impl Importer for Xonsh { + const NAME: &'static str = "xonsh"; + + async fn new() -> Result { + // wrap xonsh-specific path resolver in general one so that it respects $HISTPATH + let xonsh_data_dir = env::var("XONSH_DATA_DIR").ok(); + let hist_dir = get_histdir_path(|| xonsh_hist_dir(xonsh_data_dir))?; + let sessions = load_sessions(&hist_dir)?; + let hostname = get_host_user(); + Ok(Xonsh { sessions, hostname }) + } + + async fn entries(&mut self) -> Result { + let total = self.sessions.iter().map(|s| s.cmds.len()).sum(); + Ok(total) + } + + async fn load(self, loader: &mut impl Loader) -> Result<()> { + for session in self.sessions { + for cmd in session.cmds { + let (start, end) = cmd.ts; + let ts_nanos = (start * 1_000_000_000_f64) as i128; + let timestamp = OffsetDateTime::from_unix_timestamp_nanos(ts_nanos)?; + + let duration = (end - start) * 1_000_000_000_f64; + + match cmd.rtn { + Some(exit) => { + let entry = History::import() + .timestamp(timestamp) + .duration(duration.trunc() as i64) + .exit(exit) + .command(cmd.inp.trim()) + .cwd(cmd.cwd) + .session(session.sessionid.clone()) + .hostname(self.hostname.clone()); + loader.push(entry.build().into()).await?; + } + None => { + let entry = History::import() + .timestamp(timestamp) + .duration(duration.trunc() as i64) + .command(cmd.inp.trim()) + .cwd(cmd.cwd) + .session(session.sessionid.clone()) + .hostname(self.hostname.clone()); + loader.push(entry.build().into()).await?; + } + } + } + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use time::macros::datetime; + + use super::*; + + use crate::history::History; + use crate::import::tests::TestLoader; + + #[test] + fn test_hist_dir_xonsh() { + let hist_dir = xonsh_hist_dir(Some("/home/user/xonsh_data".to_string())).unwrap(); + assert_eq!( + hist_dir, + PathBuf::from("/home/user/xonsh_data/history_json") + ); + } + + #[tokio::test] + async fn test_import() { + let dir = PathBuf::from("tests/data/xonsh"); + let sessions = load_sessions(&dir).unwrap(); + let hostname = "box:user".to_string(); + let xonsh = Xonsh { sessions, hostname }; + + let mut loader = TestLoader::default(); + xonsh.load(&mut loader).await.unwrap(); + // order in buf will depend on filenames, so sort by timestamp for consistency + loader.buf.sort_by_key(|h| h.timestamp); + for (actual, expected) in loader.buf.iter().zip(expected_hist_entries().iter()) { + assert_eq!(actual.timestamp, expected.timestamp); + assert_eq!(actual.command, expected.command); + assert_eq!(actual.cwd, expected.cwd); + assert_eq!(actual.exit, expected.exit); + assert_eq!(actual.duration, expected.duration); + assert_eq!(actual.hostname, expected.hostname); + } + } + + fn expected_hist_entries() -> [History; 4] { + [ + History::import() + .timestamp(datetime!(2024-02-6 04:17:59.478272256 +00:00:00)) + .command("echo hello world!".to_string()) + .cwd("/home/user/Documents/code/atuin".to_string()) + .exit(0) + .duration(4651069) + .hostname("box:user".to_string()) + .build() + .into(), + History::import() + .timestamp(datetime!(2024-02-06 04:18:01.70632832 +00:00:00)) + .command("ls -l".to_string()) + .cwd("/home/user/Documents/code/atuin".to_string()) + .exit(0) + .duration(21288633) + .hostname("box:user".to_string()) + .build() + .into(), + History::import() + .timestamp(datetime!(2024-02-06 17:41:31.142515968 +00:00:00)) + .command("false".to_string()) + .cwd("/home/user/Documents/code/atuin/atuin-client".to_string()) + .exit(1) + .duration(10269403) + .hostname("box:user".to_string()) + .build() + .into(), + History::import() + .timestamp(datetime!(2024-02-06 17:41:32.271584 +00:00:00)) + .command("exit".to_string()) + .cwd("/home/user/Documents/code/atuin/atuin-client".to_string()) + .exit(0) + .duration(4259347) + .hostname("box:user".to_string()) + .build() + .into(), + ] + } +} diff --git a/crates/turtle/src/atuin_client/import/xonsh_sqlite.rs b/crates/turtle/src/atuin_client/import/xonsh_sqlite.rs new file mode 100644 index 00000000..ceedf7e9 --- /dev/null +++ b/crates/turtle/src/atuin_client/import/xonsh_sqlite.rs @@ -0,0 +1,217 @@ +use std::env; +use std::path::PathBuf; + +use async_trait::async_trait; +use directories::BaseDirs; +use eyre::{Result, eyre}; +use futures::TryStreamExt; +use sqlx::{FromRow, Row, sqlite::SqlitePool}; +use time::OffsetDateTime; +use uuid::Uuid; +use uuid::timestamp::{Timestamp, context::NoContext}; + +use super::{Importer, Loader, get_histfile_path}; +use crate::atuin_client::history::History; +use crate::atuin_client::utils::get_host_user; + +#[derive(Debug, FromRow)] +struct HistDbEntry { + inp: String, + rtn: Option, + tsb: f64, + tse: f64, + cwd: String, + session_start: f64, +} + +impl HistDbEntry { + fn into_hist_with_hostname(self, hostname: String) -> History { + let ts_nanos = (self.tsb * 1_000_000_000_f64) as i128; + let timestamp = OffsetDateTime::from_unix_timestamp_nanos(ts_nanos).unwrap(); + + let session_ts_seconds = self.session_start.trunc() as u64; + let session_ts_nanos = (self.session_start.fract() * 1_000_000_000_f64) as u32; + let session_ts = Timestamp::from_unix(NoContext, session_ts_seconds, session_ts_nanos); + let session_id = Uuid::new_v7(session_ts).to_string(); + let duration = (self.tse - self.tsb) * 1_000_000_000_f64; + + if let Some(exit) = self.rtn { + let imported = History::import() + .timestamp(timestamp) + .duration(duration.trunc() as i64) + .exit(exit) + .command(self.inp) + .cwd(self.cwd) + .session(session_id) + .hostname(hostname); + imported.build().into() + } else { + let imported = History::import() + .timestamp(timestamp) + .duration(duration.trunc() as i64) + .command(self.inp) + .cwd(self.cwd) + .session(session_id) + .hostname(hostname); + imported.build().into() + } + } +} + +fn xonsh_db_path(xonsh_data_dir: Option) -> Result { + // if running within xonsh, this will be available + if let Some(d) = xonsh_data_dir { + let mut path = PathBuf::from(d); + path.push("xonsh-history.sqlite"); + return Ok(path); + } + + // otherwise, fall back to default + let base = BaseDirs::new().ok_or_else(|| eyre!("Could not determine home directory"))?; + + let hist_file = base.data_dir().join("xonsh/xonsh-history.sqlite"); + if hist_file.exists() || cfg!(test) { + Ok(hist_file) + } else { + Err(eyre!( + "Could not find xonsh history db at: {}", + hist_file.to_string_lossy() + )) + } +} + +#[derive(Debug)] +pub struct XonshSqlite { + pool: SqlitePool, + hostname: String, +} + +#[async_trait] +impl Importer for XonshSqlite { + const NAME: &'static str = "xonsh_sqlite"; + + async fn new() -> Result { + // wrap xonsh-specific path resolver in general one so that it respects $HISTPATH + let xonsh_data_dir = env::var("XONSH_DATA_DIR").ok(); + let db_path = get_histfile_path(|| xonsh_db_path(xonsh_data_dir))?; + let connection_str = db_path.to_str().ok_or_else(|| { + eyre!( + "Invalid path for SQLite database: {}", + db_path.to_string_lossy() + ) + })?; + + let pool = SqlitePool::connect(connection_str).await?; + let hostname = get_host_user(); + Ok(XonshSqlite { pool, hostname }) + } + + async fn entries(&mut self) -> Result { + let query = "SELECT COUNT(*) FROM xonsh_history"; + let row = sqlx::query(query).fetch_one(&self.pool).await?; + let count: u32 = row.get(0); + Ok(count as usize) + } + + async fn load(self, loader: &mut impl Loader) -> Result<()> { + let query = r#" + SELECT inp, rtn, tsb, tse, cwd, + MIN(tsb) OVER (PARTITION BY sessionid) AS session_start + FROM xonsh_history + ORDER BY rowid + "#; + + let mut entries = sqlx::query_as::<_, HistDbEntry>(query).fetch(&self.pool); + + let mut count = 0; + while let Some(entry) = entries.try_next().await? { + let hist = entry.into_hist_with_hostname(self.hostname.clone()); + loader.push(hist).await?; + count += 1; + } + + println!("Loaded: {count}"); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use time::macros::datetime; + + use super::*; + + use crate::history::History; + use crate::import::tests::TestLoader; + + #[test] + fn test_db_path_xonsh() { + let db_path = xonsh_db_path(Some("/home/user/xonsh_data".to_string())).unwrap(); + assert_eq!( + db_path, + PathBuf::from("/home/user/xonsh_data/xonsh-history.sqlite") + ); + } + + #[tokio::test] + async fn test_import() { + let connection_str = "tests/data/xonsh-history.sqlite"; + let xonsh_sqlite = XonshSqlite { + pool: SqlitePool::connect(connection_str).await.unwrap(), + hostname: "box:user".to_string(), + }; + + let mut loader = TestLoader::default(); + xonsh_sqlite.load(&mut loader).await.unwrap(); + + for (actual, expected) in loader.buf.iter().zip(expected_hist_entries().iter()) { + assert_eq!(actual.timestamp, expected.timestamp); + assert_eq!(actual.command, expected.command); + assert_eq!(actual.cwd, expected.cwd); + assert_eq!(actual.exit, expected.exit); + assert_eq!(actual.duration, expected.duration); + assert_eq!(actual.hostname, expected.hostname); + } + } + + fn expected_hist_entries() -> [History; 4] { + [ + History::import() + .timestamp(datetime!(2024-02-6 17:56:21.130956288 +00:00:00)) + .command("echo hello world!".to_string()) + .cwd("/home/user/Documents/code/atuin".to_string()) + .exit(0) + .duration(2628564) + .hostname("box:user".to_string()) + .build() + .into(), + History::import() + .timestamp(datetime!(2024-02-06 17:56:28.190406144 +00:00:00)) + .command("ls -l".to_string()) + .cwd("/home/user/Documents/code/atuin".to_string()) + .exit(0) + .duration(9371519) + .hostname("box:user".to_string()) + .build() + .into(), + History::import() + .timestamp(datetime!(2024-02-06 17:56:46.989020928 +00:00:00)) + .command("false".to_string()) + .cwd("/home/user/Documents/code/atuin".to_string()) + .exit(1) + .duration(17337560) + .hostname("box:user".to_string()) + .build() + .into(), + History::import() + .timestamp(datetime!(2024-02-06 17:56:48.218384128 +00:00:00)) + .command("exit".to_string()) + .cwd("/home/user/Documents/code/atuin".to_string()) + .exit(0) + .duration(4599094) + .hostname("box:user".to_string()) + .build() + .into(), + ] + } +} diff --git a/crates/turtle/src/atuin_client/import/zsh.rs b/crates/turtle/src/atuin_client/import/zsh.rs new file mode 100644 index 00000000..e1fd813a --- /dev/null +++ b/crates/turtle/src/atuin_client/import/zsh.rs @@ -0,0 +1,230 @@ +// import old shell history! +// automatically hoover up all that we can find + +use std::borrow::Cow; +use std::path::PathBuf; + +use async_trait::async_trait; +use directories::UserDirs; +use eyre::{Result, eyre}; +use time::OffsetDateTime; + +use super::{Importer, Loader, get_histfile_path, unix_byte_lines}; +use crate::atuin_client::history::History; +use crate::atuin_client::import::read_to_end; + +#[derive(Debug)] +pub struct Zsh { + bytes: Vec, +} + +fn default_histpath() -> Result { + // oh-my-zsh sets HISTFILE=~/.zhistory + // zsh has no default value for this var, but uses ~/.zhistory. + // zsh-newuser-install propose as default .histfile https://github.com/zsh-users/zsh/blob/master/Functions/Newuser/zsh-newuser-install#L794 + // we could maybe be smarter about this in the future :) + let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; + let home_dir = user_dirs.home_dir(); + + let mut candidates = [".zhistory", ".zsh_history", ".histfile"].iter(); + loop { + match candidates.next() { + Some(candidate) => { + let histpath = home_dir.join(candidate); + if histpath.exists() { + break Ok(histpath); + } + } + None => { + break Err(eyre!( + "Could not find history file. Try setting and exporting $HISTFILE" + )); + } + } + } +} + +#[async_trait] +impl Importer for Zsh { + const NAME: &'static str = "zsh"; + + async fn new() -> Result { + let bytes = read_to_end(get_histfile_path(default_histpath)?)?; + Ok(Self { bytes }) + } + + async fn entries(&mut self) -> Result { + Ok(super::count_lines(&self.bytes)) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + let now = OffsetDateTime::now_utc(); + let mut line = String::new(); + + let mut counter = 0; + for b in unix_byte_lines(&self.bytes) { + let s = match unmetafy(b) { + Some(s) => s, + _ => continue, // we can skip past things like invalid utf8 + }; + + if let Some(s) = s.strip_suffix('\\') { + line.push_str(s); + line.push('\n'); + } else { + line.push_str(&s); + let command = std::mem::take(&mut line); + + if let Some(command) = command.strip_prefix(": ") { + counter += 1; + h.push(parse_extended(command, counter)).await?; + } else { + let offset = time::Duration::seconds(counter); + counter += 1; + + let imported = History::import() + // preserve ordering + .timestamp(now - offset) + .command(command.trim_end().to_string()); + + h.push(imported.build().into()).await?; + } + } + } + + Ok(()) + } +} + +fn parse_extended(line: &str, counter: i64) -> History { + let (time, duration) = line.split_once(':').unwrap(); + let (duration, command) = duration.split_once(';').unwrap(); + + let time = time + .parse::() + .ok() + .and_then(|t| OffsetDateTime::from_unix_timestamp(t).ok()) + .unwrap_or_else(OffsetDateTime::now_utc) + + time::Duration::milliseconds(counter); + + // use nanos, because why the hell not? we won't display them. + let duration = duration.parse::().map_or(-1, |t| t * 1_000_000_000); + + let imported = History::import() + .timestamp(time) + .command(command.trim_end().to_string()) + .duration(duration); + + imported.build().into() +} + +fn unmetafy(line: &[u8]) -> Option> { + if line.contains(&0x83) { + let mut s = Vec::with_capacity(line.len()); + let mut is_meta = false; + for ch in line { + if *ch == 0x83 { + is_meta = true; + } else if is_meta { + is_meta = false; + s.push(*ch ^ 32); + } else { + s.push(*ch) + } + } + String::from_utf8(s).ok().map(Cow::Owned) + } else { + std::str::from_utf8(line).ok().map(Cow::Borrowed) + } +} + +#[cfg(test)] +mod test { + use itertools::assert_equal; + + use crate::import::tests::TestLoader; + + use super::*; + + #[test] + fn test_parse_extended_simple() { + let parsed = parse_extended("1613322469:0;cargo install atuin", 0); + + assert_eq!(parsed.command, "cargo install atuin"); + assert_eq!(parsed.duration, 0); + assert_eq!( + parsed.timestamp, + OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() + ); + + let parsed = parse_extended("1613322469:10;cargo install atuin;cargo update", 0); + + assert_eq!(parsed.command, "cargo install atuin;cargo update"); + assert_eq!(parsed.duration, 10_000_000_000); + assert_eq!( + parsed.timestamp, + OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() + ); + + let parsed = parse_extended("1613322469:10;cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", 0); + + assert_eq!(parsed.command, "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷"); + assert_eq!(parsed.duration, 10_000_000_000); + assert_eq!( + parsed.timestamp, + OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() + ); + + let parsed = parse_extended("1613322469:10;cargo install \\n atuin\n", 0); + + assert_eq!(parsed.command, "cargo install \\n atuin"); + assert_eq!(parsed.duration, 10_000_000_000); + assert_eq!( + parsed.timestamp, + OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() + ); + } + + #[tokio::test] + async fn test_parse_file() { + let bytes = r": 1613322469:0;cargo install atuin +: 1613322469:10;cargo install atuin; \\ +cargo update +: 1613322469:10;cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷ +" + .as_bytes() + .to_owned(); + + let mut zsh = Zsh { bytes }; + assert_eq!(zsh.entries().await.unwrap(), 4); + + let mut loader = TestLoader::default(); + zsh.load(&mut loader).await.unwrap(); + + assert_equal( + loader.buf.iter().map(|h| h.command.as_str()), + [ + "cargo install atuin", + "cargo install atuin; \\\ncargo update", + "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", + ], + ); + } + + #[tokio::test] + async fn test_parse_metafied() { + let bytes = + b"echo \xe4\xbd\x83\x80\xe5\xa5\xbd\nls ~/\xe9\x83\xbf\xb3\xe4\xb9\x83\xb0\n".to_vec(); + + let mut zsh = Zsh { bytes }; + assert_eq!(zsh.entries().await.unwrap(), 2); + + let mut loader = TestLoader::default(); + zsh.load(&mut loader).await.unwrap(); + + assert_equal( + loader.buf.iter().map(|h| h.command.as_str()), + ["echo 你好", "ls ~/音乐"], + ); + } +} diff --git a/crates/turtle/src/atuin_client/import/zsh_histdb.rs b/crates/turtle/src/atuin_client/import/zsh_histdb.rs new file mode 100644 index 00000000..f61bb74f --- /dev/null +++ b/crates/turtle/src/atuin_client/import/zsh_histdb.rs @@ -0,0 +1,249 @@ +// import old shell history from zsh-histdb! +// automatically hoover up all that we can find + +// As far as i can tell there are no version numbers in the histdb sqlite DB, so we're going based +// on the schema from 2022-05-01 +// +// I have run into some histories that will not import b/c of non UTF-8 characters. +// + +// +// An Example sqlite query for hsitdb data: +// +//id|session|command_id|place_id|exit_status|start_time|duration|id|argv|id|host|dir +// +// +// select +// history.id, +// history.start_time, +// places.host, +// places.dir, +// commands.argv +// from history +// left join commands on history.command_id = commands.id +// left join places on history.place_id = places.id ; +// +// CREATE TABLE history (id integer primary key autoincrement, +// session int, +// command_id int references commands (id), +// place_id int references places (id), +// exit_status int, +// start_time int, +// duration int); +// + +use std::collections::HashMap; +use std::path::{Path, PathBuf}; + +use async_trait::async_trait; +use crate::atuin_common::utils::uuid_v7; +use directories::UserDirs; +use eyre::{Result, eyre}; +use sqlx::{Pool, sqlite::SqlitePool}; +use time::PrimitiveDateTime; + +use super::Importer; +use crate::atuin_client::history::History; +use crate::atuin_client::import::Loader; +use crate::atuin_client::utils::{get_hostname, get_username}; + +#[derive(sqlx::FromRow, Debug)] +pub struct HistDbEntryCount { + pub count: usize, +} + +#[derive(sqlx::FromRow, Debug)] +pub struct HistDbEntry { + pub id: i64, + pub start_time: PrimitiveDateTime, + pub host: Vec, + pub dir: Vec, + pub argv: Vec, + pub duration: i64, + pub exit_status: i64, + pub session: i64, +} + +#[derive(Debug)] +pub struct ZshHistDb { + histdb: Vec, + username: String, +} + +/// Read db at given file, return vector of entries. +async fn hist_from_db(dbpath: PathBuf) -> Result> { + let pool = SqlitePool::connect(dbpath.to_str().unwrap()).await?; + hist_from_db_conn(pool).await +} + +async fn hist_from_db_conn(pool: Pool) -> Result> { + let query = r#" + SELECT + history.id, history.start_time, history.duration, places.host, places.dir, + commands.argv, history.exit_status, history.session + FROM history + LEFT JOIN commands ON history.command_id = commands.id + LEFT JOIN places ON history.place_id = places.id + ORDER BY history.start_time + "#; + let histdb_vec: Vec = sqlx::query_as::<_, HistDbEntry>(query) + .fetch_all(&pool) + .await?; + Ok(histdb_vec) +} + +impl ZshHistDb { + pub fn histpath_candidate() -> PathBuf { + // By default histdb database is `${HOME}/.histdb/zsh-history.db` + // This can be modified by ${HISTDB_FILE} + // + // if [[ -z ${HISTDB_FILE} ]]; then + // typeset -g HISTDB_FILE="${HOME}/.histdb/zsh-history.db" + let user_dirs = UserDirs::new().unwrap(); // should catch error here? + let home_dir = user_dirs.home_dir(); + std::env::var("HISTDB_FILE") + .as_ref() + .map(|x| Path::new(x).to_path_buf()) + .unwrap_or_else(|_err| home_dir.join(".histdb/zsh-history.db")) + } + pub fn histpath() -> Result { + let histdb_path = ZshHistDb::histpath_candidate(); + if histdb_path.exists() { + Ok(histdb_path) + } else { + Err(eyre!( + "Could not find history file. Try setting $HISTDB_FILE" + )) + } + } +} + +#[async_trait] +impl Importer for ZshHistDb { + // Not sure how this is used + const NAME: &'static str = "zsh_histdb"; + + /// Creates a new ZshHistDb and populates the history based on the pre-populated data + /// structure. + async fn new() -> Result { + let dbpath = ZshHistDb::histpath()?; + let histdb_entry_vec = hist_from_db(dbpath).await?; + Ok(Self { + histdb: histdb_entry_vec, + username: get_username(), + }) + } + + async fn entries(&mut self) -> Result { + Ok(self.histdb.len()) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + let mut session_map = HashMap::new(); + for entry in self.histdb { + let command = match std::str::from_utf8(&entry.argv) { + Ok(s) => s.trim_end(), + Err(_) => continue, // we can skip past things like invalid utf8 + }; + let cwd = match std::str::from_utf8(&entry.dir) { + Ok(s) => s.trim_end(), + Err(_) => continue, // we can skip past things like invalid utf8 + }; + let hostname = format!( + "{}:{}", + String::from_utf8(entry.host).unwrap_or_else(|_e| get_hostname()), + self.username + ); + let session = session_map.entry(entry.session).or_insert_with(uuid_v7); + + let imported = History::import() + .timestamp(entry.start_time.assume_utc()) + .command(command) + .cwd(cwd) + .duration(entry.duration * 1_000_000_000) + .exit(entry.exit_status) + .session(session.as_simple().to_string()) + .hostname(hostname) + .build(); + h.push(imported.into()).await?; + } + Ok(()) + } +} + +#[cfg(test)] +mod test { + + use super::*; + use sqlx::sqlite::SqlitePoolOptions; + use std::env; + #[tokio::test(flavor = "multi_thread")] + #[expect(unsafe_code)] + async fn test_env_vars() { + let test_env_db = "nonstd-zsh-history.db"; + let key = "HISTDB_FILE"; + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::set_var(key, test_env_db) }; + + // test the env got set + assert_eq!(env::var(key).unwrap(), test_env_db.to_string()); + + // test histdb returns the proper db from previous step + let histdb_path = ZshHistDb::histpath_candidate(); + assert_eq!(histdb_path.to_str().unwrap(), test_env_db); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_import() { + let pool: SqlitePool = SqlitePoolOptions::new() + .min_connections(2) + .connect(":memory:") + .await + .unwrap(); + + // sql dump directly from a test database. + let db_sql = r#" + PRAGMA foreign_keys=OFF; + BEGIN TRANSACTION; + CREATE TABLE commands (id integer primary key autoincrement, argv text, unique(argv) on conflict ignore); + INSERT INTO commands VALUES(1,'pwd'); + INSERT INTO commands VALUES(2,'curl google.com'); + INSERT INTO commands VALUES(3,'bash'); + CREATE TABLE places (id integer primary key autoincrement, host text, dir text, unique(host, dir) on conflict ignore); + INSERT INTO places VALUES(1,'mbp16.local','/home/noyez'); + CREATE TABLE history (id integer primary key autoincrement, + session int, + command_id int references commands (id), + place_id int references places (id), + exit_status int, + start_time int, + duration int); + INSERT INTO history VALUES(1,0,1,1,0,1651497918,1); + INSERT INTO history VALUES(2,0,2,1,0,1651497923,1); + INSERT INTO history VALUES(3,0,3,1,NULL,1651497930,NULL); + DELETE FROM sqlite_sequence; + INSERT INTO sqlite_sequence VALUES('commands',3); + INSERT INTO sqlite_sequence VALUES('places',3); + INSERT INTO sqlite_sequence VALUES('history',3); + CREATE INDEX hist_time on history(start_time); + CREATE INDEX place_dir on places(dir); + CREATE INDEX place_host on places(host); + CREATE INDEX history_command_place on history(command_id, place_id); + COMMIT; "#; + + sqlx::query(db_sql).execute(&pool).await.unwrap(); + + // test histdb iterator + let histdb_vec = hist_from_db_conn(pool).await.unwrap(); + let histdb = ZshHistDb { + histdb: histdb_vec, + username: get_username(), + }; + + println!("h: {:#?}", histdb.histdb); + println!("counter: {:?}", histdb.histdb.len()); + for i in histdb.histdb { + println!("{i:?}"); + } + } +} diff --git a/crates/turtle/src/atuin_client/login.rs b/crates/turtle/src/atuin_client/login.rs new file mode 100644 index 00000000..ca4e16fe --- /dev/null +++ b/crates/turtle/src/atuin_client/login.rs @@ -0,0 +1,68 @@ +use std::path::PathBuf; + +use crate::atuin_common::api::LoginRequest; +use eyre::{Context, Result, bail}; +use tokio::fs::File; +use tokio::io::AsyncWriteExt; + +use crate::atuin_client::{ + api_client, + encryption::{decode_key, load_key}, + record::{sqlite_store::SqliteStore, store::Store}, + settings::Settings, +}; + +pub async fn login( + settings: &Settings, + store: &SqliteStore, + username: String, + password: String, + key: String, +) -> Result { + let key_path = settings.key_path.as_str(); + let key_path = PathBuf::from(key_path); + + if !key_path.exists() { + if decode_key(key.clone()).is_err() { + bail!("the specified key was invalid"); + } + + let mut file = File::create(&key_path).await?; + file.write_all(key.as_bytes()).await?; + } else { + // we now know that the user has logged in specifying a key, AND that the key path + // exists + + // 1. check if the saved key and the provided key match. if so, nothing to do. + // 2. if not, re-encrypt the local history and overwrite the key + let current_key: [u8; 32] = load_key(settings)?.into(); + + let encoded = key.clone(); // gonna want to save it in a bit + let new_key: [u8; 32] = decode_key(key) + .context("could not decode provided key - is not valid base64")? + .into(); + + if new_key != current_key { + println!("\nRe-encrypting local store with new key"); + + store.re_encrypt(¤t_key, &new_key).await?; + + println!("Writing new key"); + let mut file = File::create(&key_path).await?; + file.write_all(encoded.as_bytes()).await?; + } + } + + let session = api_client::login( + settings.sync_address.as_str(), + LoginRequest { username, password }, + ) + .await?; + + Settings::meta_store() + .await? + .save_session(&session.session) + .await?; + + Ok(session.session) +} diff --git a/crates/turtle/src/atuin_client/logout.rs b/crates/turtle/src/atuin_client/logout.rs new file mode 100644 index 00000000..343934b9 --- /dev/null +++ b/crates/turtle/src/atuin_client/logout.rs @@ -0,0 +1,16 @@ +use eyre::Result; + +use crate::atuin_client::settings::Settings; + +pub async fn logout() -> Result<()> { + let meta = Settings::meta_store().await?; + + if meta.logged_in().await? { + meta.delete_session().await?; + println!("You have logged out!"); + } else { + println!("You are not logged in"); + } + + Ok(()) +} diff --git a/crates/turtle/src/atuin_client/meta.rs b/crates/turtle/src/atuin_client/meta.rs new file mode 100644 index 00000000..1eea7061 --- /dev/null +++ b/crates/turtle/src/atuin_client/meta.rs @@ -0,0 +1,366 @@ +use std::path::Path; +use std::str::FromStr; +use std::time::Duration; + +use crate::atuin_common::record::HostId; +use eyre::{Result, eyre}; +use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions}; +use time::{OffsetDateTime, format_description::well_known::Rfc3339}; +use tokio::sync::OnceCell; +use tracing::{debug, warn}; +use uuid::Uuid; + +// Filenames for the legacy plain-text files that we migrate from. +const LEGACY_HOST_ID_FILENAME: &str = "host_id"; +const LEGACY_LAST_SYNC_FILENAME: &str = "last_sync_time"; +const LEGACY_LAST_VERSION_CHECK_FILENAME: &str = "last_version_check_time"; +const LEGACY_LATEST_VERSION_FILENAME: &str = "latest_version"; +const LEGACY_SESSION_FILENAME: &str = "session"; + +const KEY_HOST_ID: &str = "host_id"; +const KEY_LAST_SYNC: &str = "last_sync_time"; +const KEY_LAST_VERSION_CHECK: &str = "last_version_check_time"; +const KEY_LATEST_VERSION: &str = "latest_version"; +const KEY_SESSION: &str = "session"; +const KEY_FILES_MIGRATED: &str = "files_migrated"; + +pub struct MetaStore { + pool: SqlitePool, + cached_host_id: OnceCell, +} + +impl MetaStore { + pub async fn new(path: impl AsRef, timeout: f64) -> Result { + let path = path.as_ref(); + let path_str = path + .as_os_str() + .to_str() + .ok_or_else(|| eyre!("meta database path is not valid UTF-8: {path:?}"))?; + debug!("opening meta sqlite database at {path:?}"); + + let is_memory = path_str.contains(":memory:"); + + if !is_memory + && !path.exists() + && let Some(dir) = path.parent() + { + fs_err::create_dir_all(dir)?; + } + + // Use DELETE journal mode instead of WAL. This is a small, infrequently- + // written KV store — WAL's concurrency benefits aren't needed, and DELETE + // mode avoids creating auxiliary -wal/-shm files that complicate + // permission handling. + let opts = SqliteConnectOptions::from_str(path_str)? + .journal_mode(SqliteJournalMode::Delete) + .optimize_on_close(true, None) + .create_if_missing(true); + + let pool = SqlitePoolOptions::new() + .acquire_timeout(Duration::from_secs_f64(timeout)) + .connect_with(opts) + .await?; + + sqlx::migrate!("./meta-migrations").run(&pool).await?; + + // Session tokens are stored in this database, so restrict permissions. + #[cfg(unix)] + if !is_memory { + use std::os::unix::fs::PermissionsExt; + std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o600))?; + } + + let store = Self { + pool, + cached_host_id: OnceCell::const_new(), + }; + + if !is_memory { + store.migrate_files().await?; + } + + Ok(store) + } + + // Generic key-value operations + + pub async fn get(&self, key: &str) -> Result> { + let row: Option<(String,)> = sqlx::query_as("SELECT value FROM meta WHERE key = ?1") + .bind(key) + .fetch_optional(&self.pool) + .await?; + + Ok(row.map(|r| r.0)) + } + + pub async fn set(&self, key: &str, value: &str) -> Result<()> { + sqlx::query( + "INSERT INTO meta (key, value, updated_at) VALUES (?1, ?2, strftime('%s', 'now')) + ON CONFLICT(key) DO UPDATE SET value = ?2, updated_at = strftime('%s', 'now')", + ) + .bind(key) + .bind(value) + .execute(&self.pool) + .await?; + + Ok(()) + } + + pub async fn delete(&self, key: &str) -> Result<()> { + sqlx::query("DELETE FROM meta WHERE key = ?1") + .bind(key) + .execute(&self.pool) + .await?; + + Ok(()) + } + + // Typed accessors + + pub async fn host_id(&self) -> Result { + self.cached_host_id + .get_or_try_init(|| async { + if let Some(id) = self.get(KEY_HOST_ID).await? { + let parsed = Uuid::from_str(id.as_str()) + .map_err(|e| eyre!("failed to parse host ID: {e}"))?; + return Ok(HostId(parsed)); + } + + let uuid = crate::atuin_common::utils::uuid_v7(); + self.set(KEY_HOST_ID, uuid.as_simple().to_string().as_ref()) + .await?; + + Ok(HostId(uuid)) + }) + .await + .copied() + } + + pub async fn last_sync(&self) -> Result { + match self.get(KEY_LAST_SYNC).await? { + Some(v) => Ok(OffsetDateTime::parse(v.as_str(), &Rfc3339)?), + None => Ok(OffsetDateTime::UNIX_EPOCH), + } + } + + pub async fn save_sync_time(&self) -> Result<()> { + self.set( + KEY_LAST_SYNC, + OffsetDateTime::now_utc().format(&Rfc3339)?.as_str(), + ) + .await + } + + pub async fn last_version_check(&self) -> Result { + match self.get(KEY_LAST_VERSION_CHECK).await? { + Some(v) => Ok(OffsetDateTime::parse(v.as_str(), &Rfc3339)?), + None => Ok(OffsetDateTime::UNIX_EPOCH), + } + } + + pub async fn save_version_check_time(&self) -> Result<()> { + self.set( + KEY_LAST_VERSION_CHECK, + OffsetDateTime::now_utc().format(&Rfc3339)?.as_str(), + ) + .await + } + + pub async fn latest_version(&self) -> Result> { + self.get(KEY_LATEST_VERSION).await + } + + pub async fn save_latest_version(&self, version: &str) -> Result<()> { + self.set(KEY_LATEST_VERSION, version).await + } + + pub async fn session_token(&self) -> Result> { + self.get(KEY_SESSION).await + } + + pub async fn save_session(&self, token: &str) -> Result<()> { + self.set(KEY_SESSION, token).await + } + + pub async fn delete_session(&self) -> Result<()> { + self.delete(KEY_SESSION).await + } + + pub async fn logged_in(&self) -> Result { + Ok(self.session_token().await?.is_some()) + } + + // File migration: on first open, migrate old plain-text files into the database. + // Old files are left in place for safe downgrades. + + async fn migrate_files(&self) -> Result<()> { + if self.get(KEY_FILES_MIGRATED).await?.is_some() { + return Ok(()); + } + + let data_dir = crate::atuin_client::settings::Settings::effective_data_dir(); + + // host_id — validate as UUID + let host_id_path = data_dir.join(LEGACY_HOST_ID_FILENAME); + if host_id_path.exists() + && let Ok(value) = fs_err::read_to_string(&host_id_path) + { + let value = value.trim(); + if !value.is_empty() { + if Uuid::from_str(value).is_ok() { + self.set(KEY_HOST_ID, value).await?; + } else { + warn!("skipping migration of host_id: invalid UUID {value:?}"); + } + } + } + + // last_sync_time — validate as RFC3339 + let sync_path = data_dir.join(LEGACY_LAST_SYNC_FILENAME); + if sync_path.exists() + && let Ok(value) = fs_err::read_to_string(&sync_path) + { + let value = value.trim(); + if !value.is_empty() { + if OffsetDateTime::parse(value, &Rfc3339).is_ok() { + self.set(KEY_LAST_SYNC, value).await?; + } else { + warn!("skipping migration of last_sync_time: invalid RFC3339 {value:?}"); + } + } + } + + // last_version_check_time — validate as RFC3339 + let version_check_path = data_dir.join(LEGACY_LAST_VERSION_CHECK_FILENAME); + if version_check_path.exists() + && let Ok(value) = fs_err::read_to_string(&version_check_path) + { + let value = value.trim(); + if !value.is_empty() { + if OffsetDateTime::parse(value, &Rfc3339).is_ok() { + self.set(KEY_LAST_VERSION_CHECK, value).await?; + } else { + warn!( + "skipping migration of last_version_check_time: invalid RFC3339 {value:?}" + ); + } + } + } + + // latest_version — no strict validation, just non-empty + let latest_version_path = data_dir.join(LEGACY_LATEST_VERSION_FILENAME); + if latest_version_path.exists() + && let Ok(value) = fs_err::read_to_string(&latest_version_path) + { + let value = value.trim(); + if !value.is_empty() { + self.set(KEY_LATEST_VERSION, value).await?; + } + } + + // session token — no strict validation, just non-empty + let session_path = data_dir.join(LEGACY_SESSION_FILENAME); + if session_path.exists() + && let Ok(value) = fs_err::read_to_string(&session_path) + { + let value = value.trim(); + if !value.is_empty() { + self.set(KEY_SESSION, value).await?; + } + } + + self.set(KEY_FILES_MIGRATED, "true").await?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + async fn new_test_store() -> MetaStore { + MetaStore::new("sqlite::memory:", 2.0).await.unwrap() + } + + #[tokio::test] + async fn test_get_set_delete() { + let store = new_test_store().await; + + assert_eq!(store.get("foo").await.unwrap(), None); + + store.set("foo", "bar").await.unwrap(); + assert_eq!(store.get("foo").await.unwrap(), Some("bar".to_string())); + + store.set("foo", "baz").await.unwrap(); + assert_eq!(store.get("foo").await.unwrap(), Some("baz".to_string())); + + store.delete("foo").await.unwrap(); + assert_eq!(store.get("foo").await.unwrap(), None); + } + + #[tokio::test] + async fn test_host_id_generation_and_stability() { + let store = new_test_store().await; + + let id1 = store.host_id().await.unwrap(); + let id2 = store.host_id().await.unwrap(); + + assert_eq!(id1, id2, "host_id should be stable across calls"); + } + + #[tokio::test] + async fn test_sync_time() { + let store = new_test_store().await; + + let t = store.last_sync().await.unwrap(); + assert_eq!(t, OffsetDateTime::UNIX_EPOCH); + + store.save_sync_time().await.unwrap(); + let t = store.last_sync().await.unwrap(); + assert!(t > OffsetDateTime::UNIX_EPOCH); + } + + #[tokio::test] + async fn test_version_check_time() { + let store = new_test_store().await; + + let t = store.last_version_check().await.unwrap(); + assert_eq!(t, OffsetDateTime::UNIX_EPOCH); + + store.save_version_check_time().await.unwrap(); + let t = store.last_version_check().await.unwrap(); + assert!(t > OffsetDateTime::UNIX_EPOCH); + } + + #[tokio::test] + async fn test_session_crud() { + let store = new_test_store().await; + + assert!(!store.logged_in().await.unwrap()); + assert_eq!(store.session_token().await.unwrap(), None); + + store.save_session("tok123").await.unwrap(); + assert!(store.logged_in().await.unwrap()); + assert_eq!( + store.session_token().await.unwrap(), + Some("tok123".to_string()) + ); + + store.delete_session().await.unwrap(); + assert!(!store.logged_in().await.unwrap()); + } + + #[tokio::test] + async fn test_latest_version() { + let store = new_test_store().await; + + assert_eq!(store.latest_version().await.unwrap(), None); + + store.save_latest_version("1.2.3").await.unwrap(); + assert_eq!( + store.latest_version().await.unwrap(), + Some("1.2.3".to_string()) + ); + } +} diff --git a/crates/turtle/src/atuin_client/mod.rs b/crates/turtle/src/atuin_client/mod.rs new file mode 100644 index 00000000..7f07f2e2 --- /dev/null +++ b/crates/turtle/src/atuin_client/mod.rs @@ -0,0 +1,26 @@ +#[cfg(feature = "sync")] +pub mod api_client; +#[cfg(feature = "sync")] +pub mod auth; +#[cfg(feature = "sync")] +pub mod login; +#[cfg(feature = "sync")] +pub mod register; +#[cfg(feature = "sync")] +pub mod sync; + +pub mod database; +pub mod distro; +pub mod encryption; +pub mod history; +pub mod import; +pub mod logout; +pub mod meta; +pub mod ordering; +pub mod plugin; +pub mod record; +pub mod secrets; +pub mod settings; +pub mod theme; + +mod utils; diff --git a/crates/turtle/src/atuin_client/ordering.rs b/crates/turtle/src/atuin_client/ordering.rs new file mode 100644 index 00000000..4e5ec84c --- /dev/null +++ b/crates/turtle/src/atuin_client/ordering.rs @@ -0,0 +1,32 @@ +use minspan::minspan; + +use super::{history::History, settings::SearchMode}; + +pub fn reorder_fuzzy(mode: SearchMode, query: &str, res: Vec) -> Vec { + match mode { + SearchMode::Fuzzy => reorder(query, |x| &x.command, res), + _ => res, + } +} + +fn reorder(query: &str, f: F, res: Vec) -> Vec +where + F: Fn(&A) -> &String, + A: Clone, +{ + let mut r = res.clone(); + let qvec = &query.chars().collect(); + r.sort_by_cached_key(|h| { + // TODO for fzf search we should sum up scores for each matched term + let (from, to) = match minspan::span(qvec, &(f(h).chars().collect())) { + Some(x) => x, + // this is a little unfortunate: when we are asked to match a query that is found nowhere, + // we don't want to return a None, as the comparison behaviour would put the worst matches + // at the front. therefore, we'll return a set of indices that are one larger than the longest + // possible legitimate match. This is meaningless except as a comparison. + None => (0, res.len()), + }; + 1 + to - from + }); + r +} diff --git a/crates/turtle/src/atuin_client/plugin.rs b/crates/turtle/src/atuin_client/plugin.rs new file mode 100644 index 00000000..6f351bf1 --- /dev/null +++ b/crates/turtle/src/atuin_client/plugin.rs @@ -0,0 +1,150 @@ +use std::collections::HashMap; + +#[derive(Debug, Clone)] +pub struct OfficialPlugin { + pub name: String, + pub description: String, + pub install_message: String, +} + +impl OfficialPlugin { + pub fn new(name: &str, description: &str, install_message: &str) -> Self { + Self { + name: name.to_string(), + description: description.to_string(), + install_message: install_message.to_string(), + } + } +} + +pub struct OfficialPluginRegistry { + plugins: HashMap, +} + +impl OfficialPluginRegistry { + pub fn new() -> Self { + let mut registry = Self { + plugins: HashMap::new(), + }; + + // Register official plugins + registry.register_official_plugins(); + + registry + } + + fn register_official_plugins(&mut self) { + // atuin-update plugin + self.plugins.insert( + "update".to_string(), + OfficialPlugin::new( + "update", + "Update atuin to the latest version", + "The 'atuin update' command is provided by the atuin-update plugin.\n\ + It is only installed if you used the install script\n \ + If you used a package manager (brew, apt, etc), please continue to use it for updates", + ), + ); + } + + pub fn get_plugin(&self, name: &str) -> Option<&OfficialPlugin> { + self.plugins.get(name) + } + + pub fn is_official_plugin(&self, name: &str) -> bool { + self.plugins.contains_key(name) + } + + pub fn get_install_message(&self, name: &str) -> Option<&str> { + self.plugins + .get(name) + .map(|plugin| plugin.install_message.as_str()) + } +} + +impl Default for OfficialPluginRegistry { + fn default() -> Self { + Self::new() + } +} + +pub struct PluginContext { + #[cfg(windows)] + _update_on_windows: Option, +} + +impl PluginContext { + pub fn new(_subcommand: &str) -> Self { + PluginContext { + #[cfg(windows)] + _update_on_windows: (_subcommand == "update").then(UpdateOnWindowsContext::new), + } + } +} + +impl Drop for PluginContext { + fn drop(&mut self) {} +} + +#[cfg(windows)] +struct UpdateOnWindowsContext { + initial_exe: Option, +} + +#[cfg(windows)] +impl UpdateOnWindowsContext { + const OLD_FILE_NAME: &'static str = "atuin.old"; + + pub fn new() -> Self { + // Windows doesn't let you overwrite a running exe, but it lets you rename it, + // so make some room for atuin-update to install the new version. + let initial_exe = std::env::current_exe().ok().and_then(|exe| { + std::fs::rename(&exe, exe.with_file_name(Self::OLD_FILE_NAME)).ok()?; + Some(exe) + }); + + Self { initial_exe } + } +} + +#[cfg(windows)] +impl Drop for UpdateOnWindowsContext { + fn drop(&mut self) { + if let Some(exe) = &self.initial_exe + && !exe.exists() + { + // The update failed, roll back the current exe to its initial name. + std::fs::rename(exe.with_file_name(Self::OLD_FILE_NAME), exe).unwrap_or_else(|e| { + eprintln!("Failed to roll back the update, you may need to reinstall Atuin: {e}"); + }); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_registry_creation() { + let registry = OfficialPluginRegistry::new(); + assert!(registry.is_official_plugin("update")); + assert!(!registry.is_official_plugin("nonexistent")); + } + + #[test] + fn test_get_plugin() { + let registry = OfficialPluginRegistry::new(); + let plugin = registry.get_plugin("update"); + assert!(plugin.is_some()); + assert_eq!(plugin.unwrap().name, "update"); + } + + #[test] + fn test_get_install_message() { + let registry = OfficialPluginRegistry::new(); + let message = registry.get_install_message("update"); + assert!(message.is_some()); + assert!(message.unwrap().contains("atuin-update")); + } +} diff --git a/crates/turtle/src/atuin_client/record/encryption.rs b/crates/turtle/src/atuin_client/record/encryption.rs new file mode 100644 index 00000000..22dcdec3 --- /dev/null +++ b/crates/turtle/src/atuin_client/record/encryption.rs @@ -0,0 +1,373 @@ +use crate::atuin_common::record::{ + AdditionalData, DecryptedData, EncryptedData, Encryption, HostId, RecordId, RecordIdx, +}; +use base64::{Engine, engine::general_purpose}; +use eyre::{Context, Result, ensure}; +use rusty_paserk::{Key, KeyId, Local, PieWrappedKey}; +use rusty_paseto::core::{ + ImplicitAssertion, Key as DataKey, Local as LocalPurpose, Paseto, PasetoNonce, Payload, V4, +}; +use serde::{Deserialize, Serialize}; + +/// Use PASETO V4 Local encryption using the additional data as an implicit assertion. +#[expect(non_camel_case_types)] +pub struct PASETO_V4; + +/* +Why do we use a random content-encryption key? +Originally I was planning on using a derived key for encryption based on additional data. +This would be a lot more secure than using the master key directly. + +However, there's an established norm of using a random key. This scheme might be otherwise known as +- client-side encryption +- envelope encryption +- key wrapping + +A HSM (Hardware Security Module) provider, eg: AWS, Azure, GCP, or even a physical device like a YubiKey +will have some keys that they keep to themselves. These keys never leave their physical hardware. +If they never leave the hardware, then encrypting large amounts of data means giving them the data and waiting. +This is not a practical solution. Instead, generate a unique key for your data, encrypt that using your HSM +and then store that with your data. + +See + - + - + - + - + - + +Why would we care? In the past we have received some requests for company solutions. If in future we can configure a +KMS service with little effort, then that would solve a lot of issues for their security team. + +Even for personal use, if a user is not comfortable with sharing keys between hosts, +GCP HSM costs $1/month and $0.03 per 10,000 key operations. Assuming an active user runs +1000 atuin records a day, that would only cost them $1 and 10 cent a month. + +Additionally, key rotations are much simpler using this scheme. Rotating a key is as simple as re-encrypting the CEK, and not the message contents. +This makes it very fast to rotate a key in bulk. + +For future reference, with asymmetric encryption, you can encrypt the CEK without the HSM's involvement, but decrypting +will need the HSM. This allows the encryption path to still be extremely fast (no network calls) but downloads/decryption +that happens in the background can make the network calls to the HSM +*/ + +impl Encryption for PASETO_V4 { + fn re_encrypt( + mut data: EncryptedData, + _ad: AdditionalData, + old_key: &[u8; 32], + new_key: &[u8; 32], + ) -> Result { + let cek = Self::decrypt_cek(data.content_encryption_key, old_key)?; + data.content_encryption_key = Self::encrypt_cek(cek, new_key); + Ok(data) + } + + fn encrypt(data: DecryptedData, ad: AdditionalData, key: &[u8; 32]) -> EncryptedData { + // generate a random key for this entry + // aka content-encryption-key (CEK) + let random_key = Key::::new_os_random(); + + // encode the implicit assertions + let assertions = Assertions::from(ad).encode(); + + // build the payload and encrypt the token + let payload = serde_json::to_string(&AtuinPayload { + data: general_purpose::URL_SAFE_NO_PAD.encode(data.0), + }) + .expect("json encoding can't fail"); + let nonce = DataKey::<32>::try_new_random().expect("could not source from random"); + let nonce = PasetoNonce::::from(&nonce); + + let token = Paseto::::builder() + .set_payload(Payload::from(payload.as_str())) + .set_implicit_assertion(ImplicitAssertion::from(assertions.as_str())) + .try_encrypt(&random_key.into(), &nonce) + .expect("error encrypting atuin data"); + + EncryptedData { + data: token, + content_encryption_key: Self::encrypt_cek(random_key, key), + } + } + + fn decrypt(data: EncryptedData, ad: AdditionalData, key: &[u8; 32]) -> Result { + let token = data.data; + let cek = Self::decrypt_cek(data.content_encryption_key, key)?; + + // encode the implicit assertions + let assertions = Assertions::from(ad).encode(); + + // decrypt the payload with the footer and implicit assertions + let payload = Paseto::::try_decrypt( + &token, + &cek.into(), + None, + ImplicitAssertion::from(&*assertions), + ) + .context("could not decrypt entry")?; + + let payload: AtuinPayload = serde_json::from_str(&payload)?; + let data = general_purpose::URL_SAFE_NO_PAD.decode(payload.data)?; + Ok(DecryptedData(data)) + } +} + +impl PASETO_V4 { + fn decrypt_cek(wrapped_cek: String, key: &[u8; 32]) -> Result> { + let wrapping_key = Key::::from_bytes(*key); + + // let wrapping_key = PasetoSymmetricKey::from(Key::from(key)); + + let AtuinFooter { kid, wpk } = serde_json::from_str(&wrapped_cek) + .context("wrapped cek did not contain the correct contents")?; + + // check that the wrapping key matches the required key to decrypt. + // In future, we could support multiple keys and use this key to + // look up the key rather than only allow one key. + // For now though we will only support the one key and key rotation will + // have to be a hard reset + let current_kid = wrapping_key.to_id(); + + ensure!( + current_kid == kid, + "attempting to decrypt with incorrect key. currently using {current_kid}, expecting {kid}" + ); + + // decrypt the random key + Ok(wpk.unwrap_key(&wrapping_key)?) + } + + fn encrypt_cek(cek: Key, key: &[u8; 32]) -> String { + // aka key-encryption-key (KEK) + let wrapping_key = Key::::from_bytes(*key); + + // wrap the random key so we can decrypt it later + let wrapped_cek = AtuinFooter { + wpk: cek.wrap_pie(&wrapping_key), + kid: wrapping_key.to_id(), + }; + serde_json::to_string(&wrapped_cek).expect("could not serialize wrapped cek") + } +} + +#[derive(Serialize, Deserialize)] +struct AtuinPayload { + data: String, +} + +#[derive(Serialize, Deserialize)] +/// Well-known footer claims for decrypting. This is not encrypted but is stored in the record. +/// +struct AtuinFooter { + /// Wrapped key + wpk: PieWrappedKey, + /// ID of the key which was used to wrap + kid: KeyId, +} + +/// Used in the implicit assertions. This is not encrypted and not stored in the data blob. +// This cannot be changed, otherwise it breaks the authenticated encryption. +#[derive(Debug, Copy, Clone, Serialize)] +struct Assertions<'a> { + id: &'a RecordId, + idx: &'a RecordIdx, + version: &'a str, + tag: &'a str, + host: &'a HostId, +} + +impl<'a> From> for Assertions<'a> { + fn from(ad: AdditionalData<'a>) -> Self { + Self { + id: ad.id, + version: ad.version, + tag: ad.tag, + host: ad.host, + idx: ad.idx, + } + } +} + +impl Assertions<'_> { + fn encode(&self) -> String { + serde_json::to_string(self).expect("could not serialize implicit assertions") + } +} + +#[cfg(test)] +mod tests { + use crate::atuin_common::{ + record::{Host, Record}, + utils::uuid_v7, + }; + + use super::*; + + #[test] + fn round_trip() { + let key = Key::::new_os_random(); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + version: "v0", + tag: "kv", + host: &HostId(uuid_v7()), + idx: &0, + }; + + let data = DecryptedData(vec![1, 2, 3, 4]); + + let encrypted = PASETO_V4::encrypt(data.clone(), ad, &key.to_bytes()); + let decrypted = PASETO_V4::decrypt(encrypted, ad, &key.to_bytes()).unwrap(); + assert_eq!(decrypted, data); + } + + #[test] + fn same_entry_different_output() { + let key = Key::::new_os_random(); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + version: "v0", + tag: "kv", + host: &HostId(uuid_v7()), + idx: &0, + }; + + let data = DecryptedData(vec![1, 2, 3, 4]); + + let encrypted = PASETO_V4::encrypt(data.clone(), ad, &key.to_bytes()); + let encrypted2 = PASETO_V4::encrypt(data, ad, &key.to_bytes()); + + assert_ne!( + encrypted.data, encrypted2.data, + "re-encrypting the same contents should have different output due to key randomization" + ); + } + + #[test] + fn cannot_decrypt_different_key() { + let key = Key::::new_os_random(); + let fake_key = Key::::new_os_random(); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + version: "v0", + tag: "kv", + host: &HostId(uuid_v7()), + idx: &0, + }; + + let data = DecryptedData(vec![1, 2, 3, 4]); + + let encrypted = PASETO_V4::encrypt(data, ad, &key.to_bytes()); + let _ = PASETO_V4::decrypt(encrypted, ad, &fake_key.to_bytes()).unwrap_err(); + } + + #[test] + fn cannot_decrypt_different_id() { + let key = Key::::new_os_random(); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + version: "v0", + tag: "kv", + host: &HostId(uuid_v7()), + idx: &0, + }; + + let data = DecryptedData(vec![1, 2, 3, 4]); + + let encrypted = PASETO_V4::encrypt(data, ad, &key.to_bytes()); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + ..ad + }; + let _ = PASETO_V4::decrypt(encrypted, ad, &key.to_bytes()).unwrap_err(); + } + + #[test] + fn re_encrypt_round_trip() { + let key1 = Key::::new_os_random(); + let key2 = Key::::new_os_random(); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + version: "v0", + tag: "kv", + host: &HostId(uuid_v7()), + idx: &0, + }; + + let data = DecryptedData(vec![1, 2, 3, 4]); + + let encrypted1 = PASETO_V4::encrypt(data.clone(), ad, &key1.to_bytes()); + let encrypted2 = + PASETO_V4::re_encrypt(encrypted1.clone(), ad, &key1.to_bytes(), &key2.to_bytes()) + .unwrap(); + + // we only re-encrypt the content keys + assert_eq!(encrypted1.data, encrypted2.data); + assert_ne!( + encrypted1.content_encryption_key, + encrypted2.content_encryption_key + ); + + let decrypted = PASETO_V4::decrypt(encrypted2, ad, &key2.to_bytes()).unwrap(); + + assert_eq!(decrypted, data); + } + + #[test] + fn full_record_round_trip() { + let key = [0x55; 32]; + let record = Record::builder() + .id(RecordId(uuid_v7())) + .version("v0".to_owned()) + .tag("kv".to_owned()) + .host(Host::new(HostId(uuid_v7()))) + .timestamp(1687244806000000) + .data(DecryptedData(vec![1, 2, 3, 4])) + .idx(0) + .build(); + + let encrypted = record.encrypt::(&key); + + assert!(!encrypted.data.data.is_empty()); + assert!(!encrypted.data.content_encryption_key.is_empty()); + + let decrypted = encrypted.decrypt::(&key).unwrap(); + + assert_eq!(decrypted.data.0, [1, 2, 3, 4]); + } + + #[test] + fn full_record_round_trip_fail() { + let key = [0x55; 32]; + let record = Record::builder() + .id(RecordId(uuid_v7())) + .version("v0".to_owned()) + .tag("kv".to_owned()) + .host(Host::new(HostId(uuid_v7()))) + .timestamp(1687244806000000) + .data(DecryptedData(vec![1, 2, 3, 4])) + .idx(0) + .build(); + + let encrypted = record.encrypt::(&key); + + let mut enc1 = encrypted.clone(); + enc1.host = Host::new(HostId(uuid_v7())); + let _ = enc1 + .decrypt::(&key) + .expect_err("tampering with the host should result in auth failure"); + + let mut enc2 = encrypted; + enc2.id = RecordId(uuid_v7()); + let _ = enc2 + .decrypt::(&key) + .expect_err("tampering with the id should result in auth failure"); + } +} diff --git a/crates/turtle/src/atuin_client/record/mod.rs b/crates/turtle/src/atuin_client/record/mod.rs new file mode 100644 index 00000000..c40fd395 --- /dev/null +++ b/crates/turtle/src/atuin_client/record/mod.rs @@ -0,0 +1,6 @@ +pub mod encryption; +pub mod sqlite_store; +pub mod store; + +#[cfg(feature = "sync")] +pub mod sync; diff --git a/crates/turtle/src/atuin_client/record/sqlite_store.rs b/crates/turtle/src/atuin_client/record/sqlite_store.rs new file mode 100644 index 00000000..5fab999d --- /dev/null +++ b/crates/turtle/src/atuin_client/record/sqlite_store.rs @@ -0,0 +1,643 @@ +// Here we are using sqlite as a pretty dumb store, and will not be running any complex queries. +// Multiple stores of multiple types are all stored in one chonky table (for now), and we just index +// by tag/host + +use std::str::FromStr; +use std::{path::Path, time::Duration}; + +use async_trait::async_trait; +use eyre::{Result, eyre}; +use fs_err as fs; + +use sqlx::{ + Row, + sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow}, +}; +use tracing::debug; + +use crate::atuin_common::record::{ + EncryptedData, Host, HostId, Record, RecordId, RecordIdx, RecordStatus, +}; +use crate::atuin_common::utils; +use uuid::Uuid; + +use super::encryption::PASETO_V4; +use super::store::Store; + +#[derive(Debug, Clone)] +pub struct SqliteStore { + pool: SqlitePool, +} + +impl SqliteStore { + pub async fn new(path: impl AsRef, timeout: f64) -> Result { + let path = path.as_ref(); + + debug!("opening sqlite database at {path:?}"); + + if utils::broken_symlink(path) { + eprintln!( + "Atuin: Sqlite db path ({path:?}) is a broken symlink. Unable to read or create replacement." + ); + std::process::exit(1); + } + + if !path.exists() + && let Some(dir) = path.parent() + { + fs::create_dir_all(dir)?; + } + + let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())? + .journal_mode(SqliteJournalMode::Wal) + .foreign_keys(true) + .create_if_missing(true); + + let pool = SqlitePoolOptions::new() + .acquire_timeout(Duration::from_secs_f64(timeout)) + .connect_with(opts) + .await?; + + Self::setup_db(&pool).await?; + + Ok(Self { pool }) + } + + async fn setup_db(pool: &SqlitePool) -> Result<()> { + debug!("running sqlite database setup"); + + sqlx::migrate!("./record-migrations").run(pool).await?; + + Ok(()) + } + + async fn save_raw( + tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, + r: &Record, + ) -> Result<()> { + // In sqlite, we are "limited" to i64. But that is still fine, until 2262. + sqlx::query( + "insert or ignore into store(id, idx, host, tag, timestamp, version, data, cek) + values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", + ) + .bind(r.id.0.as_hyphenated().to_string()) + .bind(r.idx as i64) + .bind(r.host.id.0.as_hyphenated().to_string()) + .bind(r.tag.as_str()) + .bind(r.timestamp as i64) + .bind(r.version.as_str()) + .bind(r.data.data.as_str()) + .bind(r.data.content_encryption_key.as_str()) + .execute(&mut **tx) + .await?; + + Ok(()) + } + + fn query_row(row: SqliteRow) -> Record { + let idx: i64 = row.get("idx"); + let timestamp: i64 = row.get("timestamp"); + + // tbh at this point things are pretty fucked so just panic + let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB"); + let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB"); + + Record { + id: RecordId(id), + idx: idx as u64, + host: Host::new(HostId(host)), + timestamp: timestamp as u64, + tag: row.get("tag"), + version: row.get("version"), + data: EncryptedData { + data: row.get("data"), + content_encryption_key: row.get("cek"), + }, + } + } + + async fn load_all(&self) -> Result>> { + let res = sqlx::query("select * from store ") + .map(Self::query_row) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } +} + +#[async_trait] +impl Store for SqliteStore { + async fn push_batch( + &self, + records: impl Iterator> + Send + Sync, + ) -> Result<()> { + let mut tx = self.pool.begin().await?; + + for record in records { + Self::save_raw(&mut tx, record).await?; + } + + tx.commit().await?; + + Ok(()) + } + + async fn get(&self, id: RecordId) -> Result> { + let res = sqlx::query("select * from store where store.id = ?1") + .bind(id.0.as_hyphenated().to_string()) + .map(Self::query_row) + .fetch_one(&self.pool) + .await?; + + Ok(res) + } + + async fn delete(&self, id: RecordId) -> Result<()> { + sqlx::query("delete from store where id = ?1") + .bind(id.0.as_hyphenated().to_string()) + .execute(&self.pool) + .await?; + + Ok(()) + } + + async fn delete_all(&self) -> Result<()> { + sqlx::query("delete from store").execute(&self.pool).await?; + + Ok(()) + } + + async fn last(&self, host: HostId, tag: &str) -> Result>> { + let res = + sqlx::query("select * from store where host=?1 and tag=?2 order by idx desc limit 1") + .bind(host.0.as_hyphenated().to_string()) + .bind(tag) + .map(Self::query_row) + .fetch_one(&self.pool) + .await; + + match res { + Err(sqlx::Error::RowNotFound) => Ok(None), + Err(e) => Err(eyre!("an error occurred: {}", e)), + Ok(record) => Ok(Some(record)), + } + } + + async fn first(&self, host: HostId, tag: &str) -> Result>> { + self.idx(host, tag, 0).await + } + + async fn len_all(&self) -> Result { + let res: Result<(i64,), sqlx::Error> = sqlx::query_as("select count(*) from store") + .fetch_one(&self.pool) + .await; + match res { + Err(e) => Err(eyre!("failed to fetch local store len: {}", e)), + Ok(v) => Ok(v.0 as u64), + } + } + + async fn len_tag(&self, tag: &str) -> Result { + let res: Result<(i64,), sqlx::Error> = + sqlx::query_as("select count(*) from store where tag=?1") + .bind(tag) + .fetch_one(&self.pool) + .await; + match res { + Err(e) => Err(eyre!("failed to fetch local store len: {}", e)), + Ok(v) => Ok(v.0 as u64), + } + } + + async fn len(&self, host: HostId, tag: &str) -> Result { + let last = self.last(host, tag).await?; + + if let Some(last) = last { + return Ok(last.idx + 1); + } + + return Ok(0); + } + + async fn next( + &self, + host: HostId, + tag: &str, + idx: RecordIdx, + limit: u64, + ) -> Result>> { + let res = sqlx::query( + "select * from store where idx >= ?1 and host = ?2 and tag = ?3 order by idx asc limit ?4", + ) + .bind(idx as i64) + .bind(host.0.as_hyphenated().to_string()) + .bind(tag) + .bind(limit as i64) + .map(Self::query_row) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn idx( + &self, + host: HostId, + tag: &str, + idx: RecordIdx, + ) -> Result>> { + let res = sqlx::query("select * from store where idx = ?1 and host = ?2 and tag = ?3") + .bind(idx as i64) + .bind(host.0.as_hyphenated().to_string()) + .bind(tag) + .map(Self::query_row) + .fetch_one(&self.pool) + .await; + + match res { + Err(sqlx::Error::RowNotFound) => Ok(None), + Err(e) => Err(eyre!("an error occurred: {}", e)), + Ok(v) => Ok(Some(v)), + } + } + + async fn status(&self) -> Result { + let mut status = RecordStatus::new(); + + let res: Result, sqlx::Error> = + sqlx::query_as("select host, tag, max(idx) from store group by host, tag") + .fetch_all(&self.pool) + .await; + + let res = match res { + Err(e) => return Err(eyre!("failed to fetch local store status: {}", e)), + Ok(v) => v, + }; + + for i in res { + let host = HostId( + Uuid::from_str(i.0.as_str()).expect("failed to parse uuid for local store status"), + ); + + status.set_raw(host, i.1, i.2 as u64); + } + + Ok(status) + } + + async fn all_tagged(&self, tag: &str) -> Result>> { + let res = sqlx::query("select * from store where tag = ?1 order by timestamp asc") + .bind(tag) + .map(Self::query_row) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + /// Reencrypt every single item in this store with a new key + /// Be careful - this may mess with sync. + async fn re_encrypt(&self, old_key: &[u8; 32], new_key: &[u8; 32]) -> Result<()> { + // Load all the records + // In memory like some of the other code here + // This will never be called in a hot loop, and only under the following circumstances + // 1. The user has logged into a new account, with a new key. They are unlikely to have a + // lot of data + // 2. The user has encountered some sort of issue, and runs a maintenance command that + // invokes this + let all = self.load_all().await?; + + let re_encrypted = all + .into_iter() + .map(|record| record.re_encrypt::(old_key, new_key)) + .collect::>>()?; + + // next up, we delete all the old data and reinsert the new stuff + // do it in one transaction, so if anything fails we rollback OK + + let mut tx = self.pool.begin().await?; + + let res = sqlx::query("delete from store").execute(&mut *tx).await?; + + let rows = res.rows_affected(); + debug!("deleted {rows} rows"); + + // don't call push_batch, as it will start its own transaction + // call the underlying save_raw + + for record in re_encrypted { + Self::save_raw(&mut tx, &record).await?; + } + + tx.commit().await?; + + Ok(()) + } + + /// Verify that every record in this store can be decrypted with the current key + /// Someday maybe also check each tag/record can be deserialized, but not for now. + async fn verify(&self, key: &[u8; 32]) -> Result<()> { + let all = self.load_all().await?; + + all.into_iter() + .map(|record| record.decrypt::(key)) + .collect::>>()?; + + Ok(()) + } + + /// Verify that every record in this store can be decrypted with the current key + /// Someday maybe also check each tag/record can be deserialized, but not for now. + async fn purge(&self, key: &[u8; 32]) -> Result<()> { + let all = self.load_all().await?; + + for record in all.iter() { + match record.clone().decrypt::(key) { + Ok(_) => continue, + Err(_) => { + println!( + "Failed to decrypt {}, deleting", + record.id.0.as_hyphenated() + ); + + self.delete(record.id).await?; + } + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use crate::atuin_common::{ + record::{DecryptedData, EncryptedData, Host, HostId, Record}, + utils::uuid_v7, + }; + + use crate::{ + encryption::generate_encoded_key, + record::{encryption::PASETO_V4, store::Store}, + settings::test_local_timeout, + }; + + use super::SqliteStore; + + fn test_record() -> Record { + Record::builder() + .host(Host::new(HostId(atuin_common::utils::uuid_v7()))) + .version("v1".into()) + .tag(atuin_common::utils::uuid_v7().simple().to_string()) + .data(EncryptedData { + data: "1234".into(), + content_encryption_key: "1234".into(), + }) + .idx(0) + .build() + } + + #[tokio::test] + async fn create_db() { + let db = SqliteStore::new(":memory:", test_local_timeout()).await; + + assert!( + db.is_ok(), + "db could not be created, {:?}", + db.err().unwrap() + ); + } + + #[tokio::test] + async fn push_record() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + let record = test_record(); + + db.push(&record).await.expect("failed to insert record"); + } + + #[tokio::test] + async fn get_record() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let new_record = db.get(record.id).await.expect("failed to fetch record"); + + assert_eq!(record, new_record, "records are not equal"); + } + + #[tokio::test] + async fn last() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let last = db + .last(record.host.id, record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!( + last.unwrap().id, + record.id, + "expected to get back the same record that was inserted" + ); + } + + #[tokio::test] + async fn first() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let first = db + .first(record.host.id, record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!( + first.unwrap().id, + record.id, + "expected to get back the same record that was inserted" + ); + } + + #[tokio::test] + async fn len() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let len = db + .len(record.host.id, record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!(len, 1, "expected length of 1 after insert"); + } + + #[tokio::test] + async fn len_tag() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let len = db + .len_tag(record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!(len, 1, "expected length of 1 after insert"); + } + + #[tokio::test] + async fn len_different_tags() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + + // these have different tags, so the len should be the same + // we model multiple stores within one database + // new store = new tag = independent length + let first = test_record(); + let second = test_record(); + + db.push(&first).await.unwrap(); + db.push(&second).await.unwrap(); + + let first_len = db.len(first.host.id, first.tag.as_str()).await.unwrap(); + let second_len = db.len(second.host.id, second.tag.as_str()).await.unwrap(); + + assert_eq!(first_len, 1, "expected length of 1 after insert"); + assert_eq!(second_len, 1, "expected length of 1 after insert"); + } + + #[tokio::test] + async fn append_a_bunch() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + + let mut tail = test_record(); + db.push(&tail).await.expect("failed to push record"); + + for _ in 1..100 { + tail = tail.append(vec![1, 2, 3, 4]).encrypt::(&[0; 32]); + db.push(&tail).await.unwrap(); + } + + assert_eq!( + db.len(tail.host.id, tail.tag.as_str()).await.unwrap(), + 100, + "failed to insert 100 records" + ); + + assert_eq!( + db.len_tag(tail.tag.as_str()).await.unwrap(), + 100, + "failed to insert 100 records" + ); + } + + #[tokio::test] + async fn append_a_big_bunch() { + let db = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + + let mut records: Vec> = Vec::with_capacity(10000); + + let mut tail = test_record(); + records.push(tail.clone()); + + for _ in 1..10000 { + tail = tail.append(vec![1, 2, 3]).encrypt::(&[0; 32]); + records.push(tail.clone()); + } + + db.push_batch(records.iter()).await.unwrap(); + + assert_eq!( + db.len(tail.host.id, tail.tag.as_str()).await.unwrap(), + 10000, + "failed to insert 10k records" + ); + } + + #[tokio::test] + async fn re_encrypt() { + let store = SqliteStore::new(":memory:", test_local_timeout()) + .await + .unwrap(); + let (key, _) = generate_encoded_key().unwrap(); + let data = vec![0u8, 1u8, 2u8, 3u8]; + let host_id = HostId(uuid_v7()); + + for i in 0..10 { + let record = Record::builder() + .host(Host::new(host_id)) + .version(String::from("test")) + .tag(String::from("test")) + .idx(i) + .data(DecryptedData(data.clone())) + .build(); + + let record = record.encrypt::(&key.into()); + store + .push(&record) + .await + .expect("failed to push encrypted record"); + } + + // first, check that we can decrypt the data with the current key + let all = store.all_tagged("test").await.unwrap(); + + assert_eq!(all.len(), 10, "failed to fetch all records"); + + for record in all { + let decrypted = record.decrypt::(&key.into()).unwrap(); + assert_eq!(decrypted.data.0, data); + } + + // reencrypt the store, then check if + // 1) it cannot be decrypted with the old key + // 2) it can be decrypted with the new key + + let (new_key, _) = generate_encoded_key().unwrap(); + store + .re_encrypt(&key.into(), &new_key.into()) + .await + .expect("failed to re-encrypt store"); + + let all = store.all_tagged("test").await.unwrap(); + + for record in all.iter() { + let decrypted = record.clone().decrypt::(&key.into()); + assert!( + decrypted.is_err(), + "did not get error decrypting with old key after re-encrypt" + ) + } + + for record in all { + let decrypted = record.decrypt::(&new_key.into()).unwrap(); + assert_eq!(decrypted.data.0, data); + } + + assert_eq!(store.len(host_id, "test").await.unwrap(), 10); + } +} diff --git a/crates/turtle/src/atuin_client/record/store.rs b/crates/turtle/src/atuin_client/record/store.rs new file mode 100644 index 00000000..f99085d0 --- /dev/null +++ b/crates/turtle/src/atuin_client/record/store.rs @@ -0,0 +1,60 @@ +use async_trait::async_trait; +use eyre::Result; + +use crate::atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIdx, RecordStatus}; + +/// A record store stores records +/// In more detail - we tend to need to process this into _another_ format to actually query it. +/// As is, the record store is intended as the source of truth for arbitrary data, which could +/// be shell history, kvs, etc. +#[async_trait] +pub trait Store { + // Push a record + async fn push(&self, record: &Record) -> Result<()> { + self.push_batch(std::iter::once(record)).await + } + + // Push a batch of records, all in one transaction + async fn push_batch( + &self, + records: impl Iterator> + Send + Sync, + ) -> Result<()>; + + async fn get(&self, id: RecordId) -> Result>; + + async fn delete(&self, id: RecordId) -> Result<()>; + async fn delete_all(&self) -> Result<()>; + + async fn len_all(&self) -> Result; + async fn len(&self, host: HostId, tag: &str) -> Result; + async fn len_tag(&self, tag: &str) -> Result; + + async fn last(&self, host: HostId, tag: &str) -> Result>>; + async fn first(&self, host: HostId, tag: &str) -> Result>>; + + async fn re_encrypt(&self, old_key: &[u8; 32], new_key: &[u8; 32]) -> Result<()>; + async fn verify(&self, key: &[u8; 32]) -> Result<()>; + async fn purge(&self, key: &[u8; 32]) -> Result<()>; + + /// Get the next `limit` records, after and including the given index + async fn next( + &self, + host: HostId, + tag: &str, + idx: RecordIdx, + limit: u64, + ) -> Result>>; + + /// Get the first record for a given host and tag + async fn idx( + &self, + host: HostId, + tag: &str, + idx: RecordIdx, + ) -> Result>>; + + async fn status(&self) -> Result; + + /// Get all records for a given tag + async fn all_tagged(&self, tag: &str) -> Result>>; +} diff --git a/crates/turtle/src/atuin_client/record/sync.rs b/crates/turtle/src/atuin_client/record/sync.rs new file mode 100644 index 00000000..f831570b --- /dev/null +++ b/crates/turtle/src/atuin_client/record/sync.rs @@ -0,0 +1,664 @@ +// do a sync :O +use std::{cmp::Ordering, fmt::Write}; + +use eyre::Result; +use thiserror::Error; +use tracing::error; + +use super::{encryption::PASETO_V4, store::Store}; +use crate::atuin_client::{api_client::Client, settings::Settings}; + +use crate::atuin_common::record::{Diff, HostId, RecordId, RecordIdx, RecordStatus}; +use indicatif::{ProgressBar, ProgressState, ProgressStyle}; + +#[derive(Error, Debug)] +pub enum SyncError { + #[error("the local store is ahead of the remote, but for another host. has remote lost data?")] + LocalAheadOtherHost, + + #[error("an issue with the local database occurred: {msg:?}")] + LocalStoreError { msg: String }, + + #[error("something has gone wrong with the sync logic: {msg:?}")] + SyncLogicError { msg: String }, + + #[error("operational error: {msg:?}")] + OperationalError { msg: String }, + + #[error("a request to the sync server failed: {msg:?}")] + RemoteRequestError { msg: String }, + + #[error( + "the encryption key on this machine does not match the data on the server. \ + this usually means a new machine was set up without copying the existing key. \ + to fix: run `atuin key` on a machine that already syncs correctly, then run \ + `atuin store rekey ` on this machine with the value from the other machine" + )] + WrongKey, +} + +#[derive(Debug, Eq, PartialEq)] +pub enum Operation { + // Either upload or download until the states matches the below + Upload { + local: RecordIdx, + remote: Option, + host: HostId, + tag: String, + }, + Download { + local: Option, + remote: RecordIdx, + host: HostId, + tag: String, + }, + Noop { + host: HostId, + tag: String, + }, +} + +pub async fn build_client(settings: &Settings) -> Result, SyncError> { + Client::new( + &settings.sync_address, + settings + .sync_auth_token() + .await + .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?, + settings.network_connect_timeout, + settings.network_timeout, + ) + .map_err(|e| SyncError::OperationalError { msg: e.to_string() }) +} + +pub async fn diff( + client: &Client<'_>, + store: &impl Store, +) -> Result<(Vec, RecordStatus), SyncError> { + let local_index = store + .status() + .await + .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?; + + let remote_index = client + .record_status() + .await + .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?; + + let diff = local_index.diff(&remote_index); + + Ok((diff, remote_index)) +} + +// Take a diff, along with a local store, and resolve it into a set of operations. +// With the store as context, we can determine if a tail exists locally or not and therefore if it needs uploading or download. +// In theory this could be done as a part of the diffing stage, but it's easier to reason +// about and test this way +pub async fn operations( + diffs: Vec, + _store: &impl Store, +) -> Result, SyncError> { + let mut operations = Vec::with_capacity(diffs.len()); + + for diff in diffs { + let op = match (diff.local, diff.remote) { + // We both have it! Could be either. Compare. + (Some(local), Some(remote)) => match local.cmp(&remote) { + Ordering::Equal => Operation::Noop { + host: diff.host, + tag: diff.tag, + }, + Ordering::Greater => Operation::Upload { + local, + remote: Some(remote), + host: diff.host, + tag: diff.tag, + }, + Ordering::Less => Operation::Download { + local: Some(local), + remote, + host: diff.host, + tag: diff.tag, + }, + }, + + // Remote has it, we don't. Gotta be download + (None, Some(remote)) => Operation::Download { + local: None, + remote, + host: diff.host, + tag: diff.tag, + }, + + // We have it, remote doesn't. Gotta be upload. + (Some(local), None) => Operation::Upload { + local, + remote: None, + host: diff.host, + tag: diff.tag, + }, + + // something is pretty fucked. + (None, None) => { + return Err(SyncError::SyncLogicError { + msg: String::from( + "diff has nothing for local or remote - (host, tag) does not exist", + ), + }); + } + }; + + operations.push(op); + } + + // sort them - purely so we have a stable testing order, and can rely on + // same input = same output + // We can sort by ID so long as we continue to use UUIDv7 or something + // with the same properties + + operations.sort_by_key(|op| match op { + Operation::Noop { host, tag } => (0, *host, tag.clone()), + + Operation::Upload { host, tag, .. } => (1, *host, tag.clone()), + + Operation::Download { host, tag, .. } => (2, *host, tag.clone()), + }); + + Ok(operations) +} + +async fn sync_upload( + store: &impl Store, + client: &Client<'_>, + host: HostId, + tag: String, + local: RecordIdx, + remote: Option, + page_size: u64, +) -> Result { + let remote = remote.unwrap_or(0); + let expected = local - remote; + let mut progress = 0; + + let pb = ProgressBar::new(expected); + pb.set_style(ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {human_pos}/{human_len} ({eta})") + .unwrap() + .with_key("eta", |state: &ProgressState, w: &mut dyn Write| write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()) + .progress_chars("#>-")); + + println!( + "Uploading {} records to {}/{}", + expected, + host.0.as_simple(), + tag + ); + + loop { + let page = store + .next(host, tag.as_str(), remote + progress, page_size) + .await + .map_err(|e| { + error!("failed to read upload page: {e:?}"); + + SyncError::LocalStoreError { msg: e.to_string() } + })?; + + if page.is_empty() { + break; + } + + client.post_records(&page).await.map_err(|e| { + error!("failed to post records: {e:?}"); + + SyncError::RemoteRequestError { msg: e.to_string() } + })?; + + progress += page.len() as u64; + pb.set_position(progress); + + if progress >= expected { + break; + } + } + + pb.finish_with_message("Uploaded records"); + + Ok(progress as i64) +} + +async fn sync_download( + store: &impl Store, + client: &Client<'_>, + host: HostId, + tag: String, + local: Option, + remote: RecordIdx, + page_size: u64, +) -> Result, SyncError> { + let local = local.unwrap_or(0); + let expected = remote - local; + let mut progress = 0; + let mut ret = Vec::new(); + + println!( + "Downloading {} records from {}/{}", + expected, + host.0.as_simple(), + tag + ); + + let pb = ProgressBar::new(expected); + pb.set_style(ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {human_pos}/{human_len} ({eta})") + .unwrap() + .with_key("eta", |state: &ProgressState, w: &mut dyn Write| write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()) + .progress_chars("#>-")); + + loop { + let page = client + .next_records(host, tag.clone(), local + progress, page_size) + .await + .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?; + + if page.is_empty() { + break; + } + + store + .push_batch(page.iter()) + .await + .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?; + + ret.extend(page.iter().map(|f| f.id)); + + progress += page.len() as u64; + pb.set_position(progress); + + if progress >= expected { + break; + } + } + + pb.finish_with_message("Downloaded records"); + + Ok(ret) +} + +pub async fn sync_remote( + client: &Client<'_>, + operations: Vec, + local_store: &impl Store, + page_size: u64, +) -> Result<(i64, Vec), SyncError> { + let mut uploaded = 0; + let mut downloaded = Vec::new(); + + // this can totally run in parallel, but lets get it working first + for i in operations { + match i { + Operation::Upload { + host, + tag, + local, + remote, + } => { + uploaded += + sync_upload(local_store, client, host, tag, local, remote, page_size).await? + } + + Operation::Download { + host, + tag, + local, + remote, + } => { + let mut d = + sync_download(local_store, client, host, tag, local, remote, page_size).await?; + downloaded.append(&mut d) + } + + Operation::Noop { .. } => continue, + } + } + + Ok((uploaded, downloaded)) +} + +pub async fn check_encryption_key( + client: &Client<'_>, + remote_index: &RecordStatus, + encryption_key: &[u8; 32], +) -> Result<(), SyncError> { + let sample = remote_index + .hosts + .iter() + .flat_map(|(host, tags)| tags.keys().map(move |tag| (*host, tag.clone()))) + .next(); + + let Some((host, tag)) = sample else { + return Ok(()); + }; + + let records = client + .next_records(host, tag, 0, 1) + .await + .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?; + + let Some(record) = records.into_iter().next() else { + return Ok(()); + }; + + record + .decrypt::(encryption_key) + .map_err(|_| SyncError::WrongKey)?; + + Ok(()) +} + +pub async fn sync( + settings: &Settings, + store: &impl Store, + encryption_key: &[u8; 32], +) -> Result<(i64, Vec), SyncError> { + let client = build_client(settings).await?; + let (diff, remote_index) = diff(&client, store).await?; + + // Bail before mutating either side if the local key can't read the remote. + check_encryption_key(&client, &remote_index, encryption_key).await?; + + let operations = operations(diff, store).await?; + let (uploaded, downloaded) = sync_remote(&client, operations, store, 100).await?; + + Ok((uploaded, downloaded)) +} + +#[cfg(test)] +mod tests { + use crate::atuin_common::record::{Diff, EncryptedData, HostId, Record}; + use pretty_assertions::assert_eq; + + use crate::atuin_client::{ + record::{ + encryption::PASETO_V4, + sqlite_store::SqliteStore, + store::Store, + sync::{self, Operation}, + }, + settings::test_local_timeout, + }; + + fn test_record() -> Record { + Record::builder() + .host(crate::atuin_common::record::Host::new(HostId( + crate::atuin_common::utils::uuid_v7(), + ))) + .version("v1".into()) + .tag(crate::atuin_common::utils::uuid_v7().simple().to_string()) + .data(EncryptedData { + data: String::new(), + content_encryption_key: String::new(), + }) + .idx(0) + .build() + } + + // Take a list of local records, and a list of remote records. + // Return the local database, and a diff of local/remote, ready to build + // ops + async fn build_test_diff( + local_records: Vec>, + remote_records: Vec>, + ) -> (SqliteStore, Vec) { + let local_store = SqliteStore::new(":memory:", test_local_timeout()) + .await + .expect("failed to open in memory sqlite"); + let remote_store = SqliteStore::new(":memory:", test_local_timeout()) + .await + .expect("failed to open in memory sqlite"); // "remote" + + for i in local_records { + local_store.push(&i).await.unwrap(); + } + + for i in remote_records { + remote_store.push(&i).await.unwrap(); + } + + let local_index = local_store.status().await.unwrap(); + let remote_index = remote_store.status().await.unwrap(); + + let diff = local_index.diff(&remote_index); + + (local_store, diff) + } + + #[tokio::test] + async fn test_basic_diff() { + // a diff where local is ahead of remote. nothing else. + + let record = test_record(); + let (store, diff) = build_test_diff(vec![record.clone()], vec![]).await; + + assert_eq!(diff.len(), 1); + + let operations = sync::operations(diff, &store).await.unwrap(); + + assert_eq!(operations.len(), 1); + + assert_eq!( + operations[0], + Operation::Upload { + host: record.host.id, + tag: record.tag, + local: record.idx, + remote: None, + } + ); + } + + #[tokio::test] + async fn build_two_way_diff() { + // a diff where local is ahead of remote for one, and remote for + // another. One upload, one download + + let shared_record = test_record(); + let remote_ahead = test_record(); + + let local_ahead = shared_record + .append(vec![1, 2, 3]) + .encrypt::(&[0; 32]); + + assert_eq!(local_ahead.idx, 1); + + let local = vec![shared_record.clone(), local_ahead.clone()]; // local knows about the already synced, and something newer in the same store + let remote = vec![shared_record.clone(), remote_ahead.clone()]; // remote knows about the already-synced, and one new record in a new store + + let (store, diff) = build_test_diff(local, remote).await; + let operations = sync::operations(diff, &store).await.unwrap(); + + assert_eq!(operations.len(), 2); + + assert_eq!( + operations, + vec![ + // Or in otherwords, local is ahead by one + Operation::Upload { + host: local_ahead.host.id, + tag: local_ahead.tag, + local: 1, + remote: Some(0), + }, + // Or in other words, remote knows of a record in an entirely new store (tag) + Operation::Download { + host: remote_ahead.host.id, + tag: remote_ahead.tag, + local: None, + remote: 0, + }, + ] + ); + } + + #[tokio::test] + async fn build_complex_diff() { + // One shared, ahead but known only by remote + // One known only by local + // One known only by remote + + let shared_record = test_record(); + let local_only = test_record(); + + let local_only_20 = test_record(); + let local_only_21 = local_only_20 + .append(vec![1, 2, 3]) + .encrypt::(&[0; 32]); + let local_only_22 = local_only_21 + .append(vec![1, 2, 3]) + .encrypt::(&[0; 32]); + let local_only_23 = local_only_22 + .append(vec![1, 2, 3]) + .encrypt::(&[0; 32]); + + let remote_only = test_record(); + + let remote_only_20 = test_record(); + let remote_only_21 = remote_only_20 + .append(vec![2, 3, 2]) + .encrypt::(&[0; 32]); + let remote_only_22 = remote_only_21 + .append(vec![2, 3, 2]) + .encrypt::(&[0; 32]); + let remote_only_23 = remote_only_22 + .append(vec![2, 3, 2]) + .encrypt::(&[0; 32]); + let remote_only_24 = remote_only_23 + .append(vec![2, 3, 2]) + .encrypt::(&[0; 32]); + + let second_shared = test_record(); + let second_shared_remote_ahead = second_shared + .append(vec![1, 2, 3]) + .encrypt::(&[0; 32]); + let second_shared_remote_ahead2 = second_shared_remote_ahead + .append(vec![1, 2, 3]) + .encrypt::(&[0; 32]); + + let third_shared = test_record(); + let third_shared_local_ahead = third_shared + .append(vec![1, 2, 3]) + .encrypt::(&[0; 32]); + let third_shared_local_ahead2 = third_shared_local_ahead + .append(vec![1, 2, 3]) + .encrypt::(&[0; 32]); + + let fourth_shared = test_record(); + let fourth_shared_remote_ahead = fourth_shared + .append(vec![1, 2, 3]) + .encrypt::(&[0; 32]); + let fourth_shared_remote_ahead2 = fourth_shared_remote_ahead + .append(vec![1, 2, 3]) + .encrypt::(&[0; 32]); + + let local = vec![ + shared_record.clone(), + second_shared.clone(), + third_shared.clone(), + fourth_shared.clone(), + fourth_shared_remote_ahead.clone(), + // single store, only local has it + local_only.clone(), + // bigger store, also only known by local + local_only_20.clone(), + local_only_21.clone(), + local_only_22.clone(), + local_only_23.clone(), + // another shared store, but local is ahead on this one + third_shared_local_ahead.clone(), + third_shared_local_ahead2.clone(), + ]; + + let remote = vec![ + remote_only.clone(), + remote_only_20.clone(), + remote_only_21.clone(), + remote_only_22.clone(), + remote_only_23.clone(), + remote_only_24.clone(), + shared_record.clone(), + second_shared.clone(), + third_shared.clone(), + second_shared_remote_ahead.clone(), + second_shared_remote_ahead2.clone(), + fourth_shared.clone(), + fourth_shared_remote_ahead.clone(), + fourth_shared_remote_ahead2.clone(), + ]; // remote knows about the already-synced, and one new record in a new store + + let (store, diff) = build_test_diff(local, remote).await; + let operations = sync::operations(diff, &store).await.unwrap(); + + assert_eq!(operations.len(), 7); + + let mut result_ops = vec![ + // We started with a shared record, but the remote knows of two newer records in the + // same store + Operation::Download { + local: Some(0), + remote: 2, + host: second_shared_remote_ahead.host.id, + tag: second_shared_remote_ahead.tag, + }, + // We have a shared record, local knows of the first two but not the last + Operation::Download { + local: Some(1), + remote: 2, + host: fourth_shared_remote_ahead2.host.id, + tag: fourth_shared_remote_ahead2.tag, + }, + // Remote knows of a store with a single record that local does not have + Operation::Download { + local: None, + remote: 0, + host: remote_only.host.id, + tag: remote_only.tag, + }, + // Remote knows of a store with a bunch of records that local does not have + Operation::Download { + local: None, + remote: 4, + host: remote_only_20.host.id, + tag: remote_only_20.tag, + }, + // Local knows of a record in a store that remote does not have + Operation::Upload { + local: 0, + remote: None, + host: local_only.host.id, + tag: local_only.tag, + }, + // Local knows of 4 records in a store that remote does not have + Operation::Upload { + local: 3, + remote: None, + host: local_only_20.host.id, + tag: local_only_20.tag, + }, + // Local knows of 2 more records in a shared store that remote only has one of + Operation::Upload { + local: 2, + remote: Some(0), + host: third_shared.host.id, + tag: third_shared.tag, + }, + ]; + + result_ops.sort_by_key(|op| match op { + Operation::Noop { host, tag } => (0, *host, tag.clone()), + + Operation::Upload { host, tag, .. } => (1, *host, tag.clone()), + + Operation::Download { host, tag, .. } => (2, *host, tag.clone()), + }); + + assert_eq!(result_ops, operations); + } +} diff --git a/crates/turtle/src/atuin_client/register.rs b/crates/turtle/src/atuin_client/register.rs new file mode 100644 index 00000000..4b14c233 --- /dev/null +++ b/crates/turtle/src/atuin_client/register.rs @@ -0,0 +1,20 @@ +use eyre::Result; + +use crate::atuin_client::{api_client, settings::Settings}; + +pub async fn register_classic( + settings: &Settings, + username: String, + email: String, + password: String, +) -> Result { + let session = + api_client::register(settings.sync_address.as_str(), &username, &email, &password).await?; + + let meta = Settings::meta_store().await?; + meta.save_session(&session.session).await?; + + let _key = crate::atuin_client::encryption::load_key(settings)?; + + Ok(session.session) +} diff --git a/crates/turtle/src/atuin_client/secrets.rs b/crates/turtle/src/atuin_client/secrets.rs new file mode 100644 index 00000000..e8a6ab62 --- /dev/null +++ b/crates/turtle/src/atuin_client/secrets.rs @@ -0,0 +1,194 @@ +// This file will probably trigger a lot of scanners. Sorry. + +use regex::RegexSet; +use std::sync::LazyLock; + +pub enum TestValue<'a> { + Single(&'a str), + Multiple(&'a [&'a str]), +} + +/// A list of `(name, regex, test)`, where `test` should match against `regex`. +pub static SECRET_PATTERNS: &[(&str, &str, TestValue)] = &[ + ( + "AWS Access Key ID", + "A[KS]IA[0-9A-Z]{16}", + TestValue::Single("AKIAIOSFODNN7EXAMPLE"), + ), + ( + "AWS Secret Access Key env var", + "AWS_SECRET_ACCESS_KEY", + TestValue::Single("AWS_SECRET_ACCESS_KEY=KEYDATA"), + ), + ( + "AWS Session Token env var", + "AWS_SESSION_TOKEN", + TestValue::Single("AWS_SESSION_TOKEN=KEYDATA"), + ), + ( + "Microsoft Azure secret access key env var", + "AZURE_.*_KEY", + TestValue::Single("export AZURE_STORAGE_ACCOUNT_KEY=KEYDATA"), + ), + ( + "Google cloud platform key env var", + "GOOGLE_SERVICE_ACCOUNT_KEY", + TestValue::Single("export GOOGLE_SERVICE_ACCOUNT_KEY=KEYDATA"), + ), + ( + "Atuin login", + r"atuin\s+login", + TestValue::Single( + "atuin login -u mycoolusername -p mycoolpassword -k \"lots of random words\"", + ), + ), + ( + "GitHub PAT (old)", + "ghp_[a-zA-Z0-9]{36}", + TestValue::Single("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"), // legit, I expired it + ), + ( + "GitHub PAT (new)", + "gh1_[A-Za-z0-9]{21}_[A-Za-z0-9]{59}|github_pat_[0-9][A-Za-z0-9]{21}_[A-Za-z0-9]{59}", + TestValue::Multiple(&[ + "gh1_1234567890abcdefghijk_1234567890abcdefghijklmnopqrstuvwxyz1234567890abcdefghijklm", + "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", // also legit, also expired + ]), + ), + ( + "GitHub OAuth Access Token", + "gho_[A-Za-z0-9]{36}", + TestValue::Single("gho_1234567890abcdefghijklmnopqrstuvwx000"), // not a real token + ), + ( + "GitHub OAuth Access Token (user)", + "ghu_[A-Za-z0-9]{36}", + TestValue::Single("ghu_1234567890abcdefghijklmnopqrstuvwx000"), // not a real token + ), + ( + "GitHub App Installation Access Token", + "ghs_[A-Za-z0-9._-]{36,}", + TestValue::Multiple(&[ + "ghs_1234567890abcdefghijklmnopqrstuvwx000", // not a real token + "ghs_abc-def.ghi_jklMNOP0123456789qrstuv-wxyzABCD", // new token format, fake data + ]), + ), + ( + "GitHub Refresh Token", + "ghr_[A-Za-z0-9]{76}", + TestValue::Single( + "ghr_1234567890abcdefghijklmnopqrstuvwx1234567890abcdefghijklmnopqrstuvwx1234567890abcdefghijklmnopqrstuvwx", + ), // not a real token + ), + ( + "GitHub App Installation Access Token v1", + "v1\\.[0-9A-Fa-f]{40}", + TestValue::Single("v1.1234567890abcdef1234567890abcdef12345678"), // not a real token + ), + ( + "GitLab PAT", + "glpat-[a-zA-Z0-9_]{20}", + TestValue::Single("glpat-RkE_BG5p_bbjML21WSfy"), + ), + ( + "Slack OAuth v2 bot", + "xoxb-[0-9]{11}-[0-9]{11}-[0-9a-zA-Z]{24}", + TestValue::Single("xoxb-17653672481-19874698323-pdFZKVeTuE8sk7oOcBrzbqgy"), + ), + ( + "Slack OAuth v2 user token", + "xoxp-[0-9]{11}-[0-9]{11}-[0-9a-zA-Z]{24}", + TestValue::Single("xoxp-17653672481-19874698323-pdFZKVeTuE8sk7oOcBrzbqgy"), + ), + ( + "Slack webhook", + "T[a-zA-Z0-9_]{8}/B[a-zA-Z0-9_]{8}/[a-zA-Z0-9_]{24}", + TestValue::Single( + "https://hooks.slack.com/services/T00000000/B00000000/XXXXXXXXXXXXXXXXXXXXXXXX", + ), + ), + ( + "Stripe test key", + "sk_test_[0-9a-zA-Z]{24}", + TestValue::Single("sk_test_1234567890abcdefghijklmnop"), + ), + ( + "Stripe live key", + "sk_live_[0-9a-zA-Z]{24}", + TestValue::Single("sk_live_1234567890abcdefghijklmnop"), + ), + ( + "Netlify authentication token", + "nf[pcoub]_[0-9a-zA-Z]{36}", + TestValue::Single("nfp_nBh7BdJxUwyaBBwFzpyD29MMFT6pZ9wq5634"), + ), + ( + "npm token", + "npm_[A-Za-z0-9]{36}", + TestValue::Single("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"), + ), + ( + "Pulumi personal access token", + "pul-[0-9a-f]{40}", + TestValue::Single("pul-683c2770662c51d960d72ec27613be7653c5cb26"), + ), +]; + +/// The `regex` expressions from [`SECRET_PATTERNS`] compiled into a `RegexSet`. +pub static SECRET_PATTERNS_RE: LazyLock = LazyLock::new(|| { + let exprs = SECRET_PATTERNS.iter().map(|f| f.1); + RegexSet::new(exprs).expect("Failed to build secrets regex") +}); + +#[cfg(test)] +mod tests { + use regex::Regex; + + use crate::secrets::{SECRET_PATTERNS, TestValue}; + + #[test] + fn test_secrets() { + for (name, regex, test) in SECRET_PATTERNS { + let re = + Regex::new(regex).unwrap_or_else(|_| panic!("Failed to compile regex for {name}")); + + match test { + TestValue::Single(test) => { + assert!(re.is_match(test), "{name} test failed!"); + } + TestValue::Multiple(tests) => { + for test_str in tests.iter() { + assert!( + re.is_match(test_str), + "{name} test with value \"{test_str}\" failed!" + ); + } + } + } + } + } + + #[test] + fn test_secrets_embedded() { + for (name, regex, test) in SECRET_PATTERNS { + let re = + Regex::new(regex).unwrap_or_else(|_| panic!("Failed to compile regex for {name}")); + + match test { + TestValue::Single(test) => { + let embedded = format!("some random text {test} some more random text"); + assert!(re.is_match(&embedded), "{name} embedded test failed!"); + } + TestValue::Multiple(tests) => { + for test_str in tests.iter() { + let embedded = format!("some random text {test_str} some more random text"); + assert!( + re.is_match(&embedded), + "{name} embedded test with value \"{test_str}\" failed!" + ); + } + } + } + } + } +} diff --git a/crates/turtle/src/atuin_client/settings.rs b/crates/turtle/src/atuin_client/settings.rs new file mode 100644 index 00000000..b0ffc7c1 --- /dev/null +++ b/crates/turtle/src/atuin_client/settings.rs @@ -0,0 +1,1851 @@ +use std::{collections::HashMap, fmt, io::prelude::*, path::PathBuf, str::FromStr, sync::OnceLock}; +use tokio::sync::OnceCell; + +use crate::atuin_common::record::HostId; +use crate::atuin_common::utils; +use clap::ValueEnum; +use config::{ + Config, ConfigBuilder, Environment, File as ConfigFile, FileFormat, builder::DefaultState, +}; +use eyre::{Context, Error, Result, bail, eyre}; +use fs_err::{File, create_dir_all}; +use humantime::parse_duration; +use regex::RegexSet; +use serde::{Deserialize, Serialize}; +use serde_with::DeserializeFromStr; +use time::{OffsetDateTime, UtcOffset, format_description::FormatItem, macros::format_description}; + +pub const HISTORY_PAGE_SIZE: i64 = 100; + +static DATA_DIR: OnceLock = OnceLock::new(); +static META_CONFIG: OnceLock<(String, f64)> = OnceLock::new(); +static META_STORE: OnceCell = OnceCell::const_new(); + +pub(crate) mod meta; +pub mod watcher; + +/// Default sync address for Atuin's hosted service +pub const DEFAULT_SYNC_ADDRESS: &str = "https://api.atuin.sh"; + +#[derive(Clone, Debug, Deserialize, Copy, ValueEnum, PartialEq, Serialize)] +pub enum SearchMode { + #[serde(rename = "prefix")] + Prefix, + + #[serde(rename = "fulltext")] + #[clap(aliases = &["fulltext"])] + FullText, + + #[serde(rename = "fuzzy")] + Fuzzy, + + #[serde(rename = "skim")] + Skim, + + #[serde(rename = "daemon-fuzzy")] + #[clap(aliases = &["daemon-fuzzy"])] + DaemonFuzzy, +} + +impl SearchMode { + pub fn as_str(&self) -> &'static str { + match self { + SearchMode::Prefix => "PREFIX", + SearchMode::FullText => "FULLTXT", + SearchMode::Fuzzy => "FUZZY", + SearchMode::Skim => "SKIM", + SearchMode::DaemonFuzzy => "DAEMON", + } + } + pub fn next(&self, settings: &Settings) -> Self { + match self { + SearchMode::Prefix => SearchMode::FullText, + // if the user is using skim, we go to skim + SearchMode::FullText if settings.search_mode == SearchMode::Skim => SearchMode::Skim, + // if the user is using daemon-fuzzy, we go to daemon-fuzzy + SearchMode::FullText if settings.search_mode == SearchMode::DaemonFuzzy => { + SearchMode::DaemonFuzzy + } + // otherwise fuzzy. + SearchMode::FullText => SearchMode::Fuzzy, + SearchMode::Fuzzy | SearchMode::Skim | SearchMode::DaemonFuzzy => SearchMode::Prefix, + } + } +} + +#[derive(Clone, Debug, Deserialize, Copy, PartialEq, Eq, ValueEnum, Serialize)] +pub enum FilterMode { + #[serde(rename = "global")] + Global = 0, + + #[serde(rename = "host")] + Host = 1, + + #[serde(rename = "session")] + Session = 2, + + #[serde(rename = "directory")] + Directory = 3, + + #[serde(rename = "workspace")] + Workspace = 4, + + #[serde(rename = "session-preload")] + SessionPreload = 5, +} + +impl FilterMode { + pub fn as_str(&self) -> &'static str { + match self { + FilterMode::Global => "GLOBAL", + FilterMode::Host => "HOST", + FilterMode::Session => "SESSION", + FilterMode::Directory => "DIRECTORY", + FilterMode::Workspace => "WORKSPACE", + FilterMode::SessionPreload => "SESSION+", + } + } +} + +#[derive(Clone, Debug, Deserialize, Copy, Serialize)] +pub enum ExitMode { + #[serde(rename = "return-original")] + ReturnOriginal, + + #[serde(rename = "return-query")] + ReturnQuery, +} + +// FIXME: Can use upstream Dialect enum if https://github.com/stevedonovan/chrono-english/pull/16 is merged +// FIXME: Above PR was merged, but dependency was changed to interim (fork of chrono-english) in the ... interim +#[derive(Clone, Debug, Deserialize, Copy, Serialize)] +pub enum Dialect { + #[serde(rename = "us")] + Us, + + #[serde(rename = "uk")] + Uk, +} + +impl From for interim::Dialect { + fn from(d: Dialect) -> interim::Dialect { + match d { + Dialect::Uk => interim::Dialect::Uk, + Dialect::Us => interim::Dialect::Us, + } + } +} + +/// Type wrapper around `time::UtcOffset` to support a wider variety of timezone formats. +/// +/// Note that the parsing of this struct needs to be done before starting any +/// multithreaded runtime, otherwise it will fail on most Unix systems. +/// +/// See: +#[derive(Clone, Copy, Debug, Eq, PartialEq, DeserializeFromStr, Serialize)] +pub struct Timezone(pub UtcOffset); +impl fmt::Display for Timezone { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} +/// format: <+|->[:[:]] +static OFFSET_FMT: &[FormatItem<'_>] = format_description!( + "[offset_hour sign:mandatory padding:none][optional [:[offset_minute padding:none][optional [:[offset_second padding:none]]]]]" +); +impl FromStr for Timezone { + type Err = Error; + + fn from_str(s: &str) -> Result { + // local timezone + if matches!(s.to_lowercase().as_str(), "l" | "local") { + // There have been some timezone issues, related to errors fetching it on some + // platforms + // Rather than fail to start, fallback to UTC. The user should still be able to specify + // their timezone manually in the config file. + let offset = UtcOffset::current_local_offset().unwrap_or(UtcOffset::UTC); + return Ok(Self(offset)); + } + + if matches!(s.to_lowercase().as_str(), "0" | "utc") { + let offset = UtcOffset::UTC; + return Ok(Self(offset)); + } + + // offset from UTC + if let Ok(offset) = UtcOffset::parse(s, OFFSET_FMT) { + return Ok(Self(offset)); + } + + // IDEA: Currently named timezones are not supported, because the well-known crate + // for this is `chrono_tz`, which is not really interoperable with the datetime crate + // that we currently use - `time`. If ever we migrate to using `chrono`, this would + // be a good feature to add. + + bail!(r#""{s}" is not a valid timezone spec"#) + } +} + +#[derive(Clone, Debug, Deserialize, Copy, Serialize)] +pub enum Style { + #[serde(rename = "auto")] + Auto, + + #[serde(rename = "full")] + Full, + + #[serde(rename = "compact")] + Compact, +} + +#[derive(Clone, Debug, Deserialize, Copy, Serialize)] +pub enum WordJumpMode { + #[serde(rename = "emacs")] + Emacs, + + #[serde(rename = "subl")] + Subl, +} + +#[derive(Clone, Debug, Deserialize, Copy, PartialEq, Eq, ValueEnum, Serialize)] +pub enum KeymapMode { + #[serde(rename = "emacs")] + Emacs, + + #[serde(rename = "vim-normal")] + VimNormal, + + #[serde(rename = "vim-insert")] + VimInsert, + + #[serde(rename = "auto")] + Auto, +} + +impl KeymapMode { + pub fn as_str(&self) -> &'static str { + match self { + KeymapMode::Emacs => "EMACS", + KeymapMode::VimNormal => "VIMNORMAL", + KeymapMode::VimInsert => "VIMINSERT", + KeymapMode::Auto => "AUTO", + } + } +} + +// We want to translate the config to crossterm::cursor::SetCursorStyle, but +// the original type does not implement trait serde::Deserialize unfortunately. +// It seems impossible to implement Deserialize for external types when it is +// used in HashMap (https://stackoverflow.com/questions/67142663). We instead +// define an adapter type. +#[derive(Clone, Debug, Deserialize, Copy, PartialEq, Eq, ValueEnum, Serialize)] +pub enum CursorStyle { + #[serde(rename = "default")] + DefaultUserShape, + + #[serde(rename = "blink-block")] + BlinkingBlock, + + #[serde(rename = "steady-block")] + SteadyBlock, + + #[serde(rename = "blink-underline")] + BlinkingUnderScore, + + #[serde(rename = "steady-underline")] + SteadyUnderScore, + + #[serde(rename = "blink-bar")] + BlinkingBar, + + #[serde(rename = "steady-bar")] + SteadyBar, +} + +impl CursorStyle { + pub fn as_str(&self) -> &'static str { + match self { + CursorStyle::DefaultUserShape => "DEFAULT", + CursorStyle::BlinkingBlock => "BLINKBLOCK", + CursorStyle::SteadyBlock => "STEADYBLOCK", + CursorStyle::BlinkingUnderScore => "BLINKUNDERLINE", + CursorStyle::SteadyUnderScore => "STEADYUNDERLINE", + CursorStyle::BlinkingBar => "BLINKBAR", + CursorStyle::SteadyBar => "STEADYBAR", + } + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Stats { + #[serde(default = "Stats::common_prefix_default")] + pub common_prefix: Vec, // sudo, etc. commands we want to strip off + #[serde(default = "Stats::common_subcommands_default")] + pub common_subcommands: Vec, // kubectl, commands we should consider subcommands for + #[serde(default = "Stats::ignored_commands_default")] + pub ignored_commands: Vec, // cd, ls, etc. commands we want to completely hide from stats +} + +impl Stats { + fn common_prefix_default() -> Vec { + vec!["sudo", "doas"].into_iter().map(String::from).collect() + } + + fn common_subcommands_default() -> Vec { + vec![ + "apt", + "cargo", + "composer", + "dnf", + "docker", + "dotnet", + "git", + "go", + "ip", + "jj", + "kubectl", + "nix", + "nmcli", + "npm", + "pecl", + "pnpm", + "podman", + "port", + "systemctl", + "tmux", + "yarn", + ] + .into_iter() + .map(String::from) + .collect() + } + + fn ignored_commands_default() -> Vec { + vec![] + } +} + +impl Default for Stats { + fn default() -> Self { + Self { + common_prefix: Self::common_prefix_default(), + common_subcommands: Self::common_subcommands_default(), + ignored_commands: Self::ignored_commands_default(), + } + } +} + +/// Sync protocol type for authentication. +/// +/// This setting is primarily for development/testing. When not explicitly set, +/// the protocol is inferred from the sync_address: +/// - Default sync address (api.atuin.sh) → Hub protocol +/// - Custom sync address → Legacy protocol +/// +/// Set explicitly to "hub" to use Hub authentication with a custom sync_address +/// (useful for local development against a Hub instance). +#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Eq, Serialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum SyncProtocol { + /// Use legacy CLI authentication (Token from CLI register/login) + #[default] + Legacy, +} + +/// Resolved authentication state for sync operations. +/// +/// Determined at runtime by examining which tokens are available and what +/// server the client is configured to talk to. Operations use this to pick +/// the right auth header and endpoint style. +#[cfg(feature = "sync")] +#[derive(Debug, Clone)] +pub enum SyncAuth { + /// Self-hosted Rust server. Uses `Authorization: Token ` and + /// legacy endpoints. + Legacy { token: String }, + + /// Not authenticated at all. Contains an actionable user-facing message. + NotLoggedIn { reason: String }, +} + +#[cfg(feature = "sync")] +impl SyncAuth { + /// Convert into the auth token type used by the API client. + /// + /// Returns an error with an actionable message for `NotLoggedIn`. + pub fn into_auth_token(self) -> Result { + use crate::atuin_client::api_client::AuthToken; + match self { + SyncAuth::Legacy { token } => Ok(AuthToken::Token(token)), + SyncAuth::NotLoggedIn { reason } => Err(eyre!(reason)), + } + } +} + +#[derive(Clone, Debug, Deserialize, Default, Serialize)] +pub struct Keys { + pub scroll_exits: bool, + pub exit_past_line_start: bool, + pub accept_past_line_end: bool, + pub accept_past_line_start: bool, + pub accept_with_backspace: bool, + pub prefix: String, +} + +impl Keys { + /// The standard default values for all `[keys]` options. + /// These match the config defaults set in `builder_with_data_dir()`. + pub fn standard_defaults() -> Self { + Keys { + scroll_exits: true, + exit_past_line_start: true, + accept_past_line_end: true, + accept_past_line_start: false, + accept_with_backspace: false, + prefix: "a".to_string(), + } + } + + /// Returns true if any value differs from the standard defaults. + pub fn has_non_default_values(&self) -> bool { + let d = Self::standard_defaults(); + self.scroll_exits != d.scroll_exits + || self.exit_past_line_start != d.exit_past_line_start + || self.accept_past_line_end != d.accept_past_line_end + || self.accept_past_line_start != d.accept_past_line_start + || self.accept_with_backspace != d.accept_with_backspace + || self.prefix != d.prefix + } +} + +/// A single rule within a conditional keybinding config. +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct KeyRuleConfig { + /// Optional condition expression (e.g. "cursor-at-start", "input-empty && no-results"). + /// If absent, the rule always matches. + #[serde(default)] + pub when: Option, + /// The action to perform (e.g. "exit", "cursor-left", "accept"). + pub action: String, +} + +/// A keybinding config value: either a simple action string or an ordered list of conditional rules. +#[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(untagged)] +pub enum KeyBindingConfig { + /// Simple unconditional binding: `"ctrl-c" = "return-original"` + Simple(String), + /// Conditional binding: `"left" = [{ when = "cursor-at-start", action = "exit" }, { action = "cursor-left" }]` + Rules(Vec), +} + +/// User-facing keymap configuration. Each mode maps key strings to bindings. +/// Keys present here override the defaults for that key; unmentioned keys keep defaults. +#[derive(Clone, Debug, Deserialize, Serialize, Default)] +pub struct KeymapConfig { + #[serde(default)] + pub emacs: HashMap, + #[serde(default, rename = "vim-normal")] + pub vim_normal: HashMap, + #[serde(default, rename = "vim-insert")] + pub vim_insert: HashMap, + #[serde(default)] + pub inspector: HashMap, + #[serde(default)] + pub prefix: HashMap, +} + +impl KeymapConfig { + /// Returns true if no keybinding overrides are configured in any mode. + pub fn is_empty(&self) -> bool { + self.emacs.is_empty() + && self.vim_normal.is_empty() + && self.vim_insert.is_empty() + && self.inspector.is_empty() + && self.prefix.is_empty() + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Preview { + pub strategy: PreviewStrategy, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Theme { + /// Name of desired theme ("default" for base) + pub name: String, + + /// Whether any available additional theme debug should be shown + pub debug: Option, + + /// How many levels of parenthood will be traversed if needed + pub max_depth: Option, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Daemon { + /// Use the daemon to sync + /// If enabled, history hooks are routed through the daemon. + #[serde(alias = "enable")] + pub enabled: bool, + + /// Automatically start and manage a local daemon when needed. + pub autostart: bool, + + /// The daemon will handle sync on an interval. How often to sync, in seconds. + pub sync_frequency: u64, + + /// The path to the unix socket used by the daemon + pub socket_path: String, + + /// Path to the daemon pidfile used for process coordination. + pub pidfile_path: String, + + /// Use a socket passed via systemd's socket activation protocol, instead of the path + pub systemd_socket: bool, + + /// The port that should be used for TCP on non unix systems + pub tcp_port: u64, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Search { + /// The list of enabled filter modes, in order of priority. + pub filters: Vec, + + /// The recency score multiplier for the search index (default: 1.0). + /// Values < 1.0 reduce weight, > 1.0 increase weight, 0.0 disables. + pub recency_score_multiplier: f64, + + /// The frequency score multiplier for the search index (default: 1.0). + /// Values < 1.0 reduce weight, > 1.0 increase weight, 0.0 disables. + pub frequency_score_multiplier: f64, + + /// The overall frecency score multiplier for the search index (default: 1.0). + /// Applied after combining recency and frequency scores. + pub frecency_score_multiplier: f64, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Tmux { + /// Enable using atuin with tmux popup (tmux >= 3.2) + pub enabled: bool, + + /// Width of the tmux popup (percentage) + pub width: String, + + /// Height of the tmux popup (percentage) + pub height: String, +} + +/// Log level for file logging. Maps to tracing's LevelFilter. +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Deserialize, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum LogLevel { + Trace, + Debug, + #[default] + Info, + Warn, + Error, +} + +impl LogLevel { + /// Convert to a tracing directive string for use with EnvFilter. + pub fn as_directive(&self) -> &'static str { + match self { + LogLevel::Trace => "trace", + LogLevel::Debug => "debug", + LogLevel::Info => "info", + LogLevel::Warn => "warn", + LogLevel::Error => "error", + } + } +} + +/// Configuration for a specific log type (search or daemon). +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +pub struct LogConfig { + /// Log file name (relative to dir) or absolute path. + pub file: String, + + /// Override global enabled setting for this log type. + pub enabled: Option, + + /// Override global level setting for this log type. + pub level: Option, + + /// Override global retention days setting for this log type. + pub retention: Option, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Logs { + /// Enable file logging globally. Defaults to true. + #[serde(default = "Logs::default_enabled")] + pub enabled: bool, + + /// Directory for log files. Defaults to ~/.atuin/logs + pub dir: String, + + /// Default log level for file logging. Defaults to "info". + /// Note: ATUIN_LOG environment variable overrides this. + #[serde(default)] + pub level: LogLevel, + + /// Default retention days for log files. Defaults to 4. + #[serde(default = "Logs::default_retention")] + pub retention: u64, + + /// Search log settings + #[serde(default)] + pub search: LogConfig, + + /// Daemon log settings + #[serde(default)] + pub daemon: LogConfig, + + /// AI log settings + #[serde(default)] + pub ai: LogConfig, +} + +#[derive(Default, Clone, Debug, Deserialize, Serialize)] +pub struct Ai { + /// Whether or not the AI features are enabled. + pub enabled: Option, + + /// The address of the Atuin AI endpoint. Used for AI features like command generation. + /// Only necessary for custom AI endpoints. + pub endpoint: Option, + + /// The API token for the Atuin AI endpoint. Used for AI features like command generation. + /// Only necessary for custom AI endpoints. + pub api_token: Option, + + /// Path to the AI sessions database. + pub db_path: String, + + /// The maximum time in minutes that an AI session can be automatically resumed. + pub session_continue_minutes: i64, + + /// Deprecated: use opening.send_cwd instead. Kept for backwards compatibility. + #[serde(default)] + pub send_cwd: Option, + + /// Configuration for what context is sent in the opening AI request. + #[serde(default)] + pub opening: AiOpening, + + /// Tool capability flags. + #[serde(default)] + pub capabilities: AiCapabilities, +} + +#[derive(Default, Clone, Debug, Deserialize, Serialize)] +pub struct AiCapabilities { + /// Whether the AI can request to search Atuin history. `None` = unset (defaults to enabled, and the ai will ask for permission). + pub enable_history_search: Option, + /// Whether the AI can request to view the stored output, if any, for Atuin history entries. `None` = unset (defaults to enabled, and the ai will ask for permission). + pub enable_history_output: Option, + /// Whether the AI can request to read and write files. `None` = unset (defaults to enabled, and the ai will ask for permission). + pub enable_file_tools: Option, + /// Whether the AI can request to execute bash commands. `None` = unset (defaults to enabled, and the ai will ask for permission). + pub enable_command_execution: Option, +} + +#[derive(Default, Clone, Debug, Deserialize, Serialize)] +pub struct AiOpening { + /// Whether or not to send the current working directory to the AI endpoint. + pub send_cwd: Option, + + /// Whether or not to send the last command as context in the opening AI request. + pub send_last_command: Option, +} + +impl Default for Preview { + fn default() -> Self { + Self { + strategy: PreviewStrategy::Auto, + } + } +} + +impl Default for Theme { + fn default() -> Self { + Self { + name: "".to_string(), + debug: None::, + max_depth: Some(10), + } + } +} + +impl Default for Daemon { + fn default() -> Self { + Self { + enabled: false, + autostart: false, + sync_frequency: 300, + socket_path: "".to_string(), + pidfile_path: "".to_string(), + systemd_socket: false, + tcp_port: 8889, + } + } +} + +impl Default for Logs { + fn default() -> Self { + Self { + enabled: true, + dir: "".to_string(), + level: LogLevel::default(), + retention: Self::default_retention(), + search: LogConfig { + file: "search.log".to_string(), + ..Default::default() + }, + daemon: LogConfig { + file: "daemon.log".to_string(), + ..Default::default() + }, + ai: LogConfig { + file: "ai.log".to_string(), + ..Default::default() + }, + } + } +} + +impl Logs { + fn default_enabled() -> bool { + true + } + + fn default_retention() -> u64 { + 4 + } + + /// Returns whether search logging is enabled. + /// Uses search-specific setting if set, otherwise falls back to global. + pub fn search_enabled(&self) -> bool { + self.search.enabled.unwrap_or(self.enabled) + } + + /// Returns whether daemon logging is enabled. + /// Uses daemon-specific setting if set, otherwise falls back to global. + pub fn daemon_enabled(&self) -> bool { + self.daemon.enabled.unwrap_or(self.enabled) + } + + /// Returns whether AI logging is enabled. + /// Uses AI-specific setting if set, otherwise falls back to global. + pub fn ai_enabled(&self) -> bool { + self.ai.enabled.unwrap_or(self.enabled) + } + + /// Returns the log level for search logging. + /// Uses search-specific setting if set, otherwise falls back to global. + pub fn search_level(&self) -> LogLevel { + self.search.level.unwrap_or(self.level) + } + + /// Returns the log level for daemon logging. + /// Uses daemon-specific setting if set, otherwise falls back to global. + pub fn daemon_level(&self) -> LogLevel { + self.daemon.level.unwrap_or(self.level) + } + + /// Returns the log level for AI logging. + /// Uses AI-specific setting if set, otherwise falls back to global. + pub fn ai_level(&self) -> LogLevel { + self.ai.level.unwrap_or(self.level) + } + + /// Returns the retention days for search logging. + /// Uses search-specific setting if set, otherwise falls back to global. + pub fn search_retention(&self) -> u64 { + self.search.retention.unwrap_or(self.retention) + } + + /// Returns the retention days for daemon logging. + /// Uses daemon-specific setting if set, otherwise falls back to global. + pub fn daemon_retention(&self) -> u64 { + self.daemon.retention.unwrap_or(self.retention) + } + + /// Returns the retention days for AI logging. + /// Uses AI-specific setting if set, otherwise falls back to global. + pub fn ai_retention(&self) -> u64 { + self.ai.retention.unwrap_or(self.retention) + } + + /// Returns the full path for the search log file. + pub fn search_path(&self) -> PathBuf { + let path = PathBuf::from(&self.search.file); + PathBuf::from(&self.dir).join(path) + } + + /// Returns the full path for the daemon log file. + pub fn daemon_path(&self) -> PathBuf { + let path = PathBuf::from(&self.daemon.file); + PathBuf::from(&self.dir).join(path) + } + + /// Returns the full path for the AI log file. + pub fn ai_path(&self) -> PathBuf { + let path = PathBuf::from(&self.ai.file); + PathBuf::from(&self.dir).join(path) + } +} + +impl Default for Search { + fn default() -> Self { + Self { + filters: vec![ + FilterMode::Global, + FilterMode::Host, + FilterMode::Session, + FilterMode::SessionPreload, + FilterMode::Workspace, + FilterMode::Directory, + ], + + recency_score_multiplier: 1.0, + frequency_score_multiplier: 1.0, + frecency_score_multiplier: 1.0, + } + } +} + +impl Default for Tmux { + fn default() -> Self { + Self { + enabled: false, + width: "80%".to_string(), + height: "60%".to_string(), + } + } +} + +// The preview height strategy also takes max_preview_height into account. +#[derive(Clone, Debug, Deserialize, Copy, PartialEq, Eq, ValueEnum, Serialize)] +pub enum PreviewStrategy { + // Preview height is calculated for the length of the selected command. + #[serde(rename = "auto")] + Auto, + + // Preview height is calculated for the length of the longest command stored in the history. + #[serde(rename = "static")] + Static, + + // max_preview_height is used as fixed height. + #[serde(rename = "fixed")] + Fixed, +} + +/// Column types available for the interactive search UI. +#[derive(Clone, Copy, Debug, Deserialize, PartialEq, Eq, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum UiColumnType { + /// Command execution duration (e.g., "123ms") + Duration, + /// Relative time since execution (e.g., "59s ago") + Time, + /// Absolute timestamp (e.g., "2025-01-22 14:35") + Datetime, + /// Working directory + Directory, + /// Hostname + Host, + /// Username + User, + /// Exit code + Exit, + /// The command itself (should be last, expands to fill) + Command, +} + +impl UiColumnType { + /// Returns the default width for this column type (in characters). + /// The Command column returns 0 as it expands to fill remaining space. + pub fn default_width(&self) -> u16 { + match self { + UiColumnType::Duration => 5, // "814ms" + UiColumnType::Time => 9, // "459ms ago" + UiColumnType::Datetime => 16, // "2025-01-22 14:35" + UiColumnType::Directory => 20, + UiColumnType::Host => 15, + UiColumnType::User => 10, + UiColumnType::Exit => { + if cfg!(windows) { + 11 // 32-bit integer on Windows: "-1978335212" + } else { + 3 // Usually a byte on Unix + } + } + UiColumnType::Command => 0, // Expands to fill + } + } +} + +/// A column configuration with type and optional custom width. +/// Can be specified as just a string (uses default width) or as an object with type and width. +#[derive(Clone, Debug, Serialize)] +pub struct UiColumn { + pub column_type: UiColumnType, + pub width: u16, + /// If true, this column expands to fill remaining space. Only one column should expand. + pub expand: bool, +} + +impl UiColumn { + pub fn new(column_type: UiColumnType) -> Self { + Self { + width: column_type.default_width(), + expand: column_type == UiColumnType::Command, + column_type, + } + } + + pub fn with_width(column_type: UiColumnType, width: u16) -> Self { + Self { + column_type, + width, + expand: column_type == UiColumnType::Command, + } + } +} + +// Custom deserialize to handle both string and object formats: +// "duration" or { type = "duration", width = 8, expand = true } +impl<'de> serde::Deserialize<'de> for UiColumn { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::de::{self, MapAccess, Visitor}; + + struct UiColumnVisitor; + + impl<'de> Visitor<'de> for UiColumnVisitor { + type Value = UiColumn; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str( + "a column type string or an object with 'type' and optional 'width'/'expand'", + ) + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + let column_type: UiColumnType = + serde::Deserialize::deserialize(serde::de::value::StrDeserializer::new(value))?; + Ok(UiColumn::new(column_type)) + } + + fn visit_map(self, mut map: M) -> Result + where + M: MapAccess<'de>, + { + let mut column_type: Option = None; + let mut width: Option = None; + let mut expand: Option = None; + + while let Some(key) = map.next_key::()? { + match key.as_str() { + "type" => { + column_type = Some(map.next_value()?); + } + "width" => { + width = Some(map.next_value()?); + } + "expand" => { + expand = Some(map.next_value()?); + } + _ => { + let _: serde::de::IgnoredAny = map.next_value()?; + } + } + } + + let column_type = column_type.ok_or_else(|| de::Error::missing_field("type"))?; + let width = width.unwrap_or_else(|| column_type.default_width()); + let expand = expand.unwrap_or(column_type == UiColumnType::Command); + Ok(UiColumn { + column_type, + width, + expand, + }) + } + } + + deserializer.deserialize_any(UiColumnVisitor) + } +} + +/// UI-specific settings for the interactive search. +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Ui { + /// Columns to display in interactive search, from left to right. + /// The indicator column (" > ") is always shown first implicitly. + /// The "command" column should be last as it expands to fill remaining space. + /// Can be simple strings or objects with type and width. + #[serde(default = "Ui::default_columns")] + pub columns: Vec, +} + +impl Ui { + fn default_columns() -> Vec { + vec![ + UiColumn::new(UiColumnType::Duration), + UiColumn::new(UiColumnType::Time), + UiColumn::new(UiColumnType::Command), + ] + } + + /// Validate the UI configuration. + /// Returns an error if more than one column has expand = true. + pub fn validate(&self) -> Result<()> { + let expand_count = self.columns.iter().filter(|c| c.expand).count(); + if expand_count > 1 { + bail!( + "Only one column can have expand = true, but {} columns are set to expand", + expand_count + ); + } + Ok(()) + } +} + +impl Default for Ui { + fn default() -> Self { + Self { + columns: Self::default_columns(), + } + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Settings { + pub data_dir: Option, + pub dialect: Dialect, + pub timezone: Timezone, + pub style: Style, + pub auto_sync: bool, + + /// The sync address for atuin. + pub sync_address: String, + + #[serde(default)] + pub sync_protocol: SyncProtocol, + + pub sync_frequency: String, + pub db_path: String, + pub record_store_path: String, + pub key_path: String, + pub search_mode: SearchMode, + pub filter_mode: Option, + pub filter_mode_shell_up_key_binding: Option, + pub search_mode_shell_up_key_binding: Option, + pub shell_up_key_binding: bool, + pub inline_height: u16, + pub inline_height_shell_up_key_binding: Option, + pub invert: bool, + pub show_preview: bool, + pub max_preview_height: u16, + pub show_help: bool, + pub show_tabs: bool, + pub show_numeric_shortcuts: bool, + pub auto_hide_height: u16, + pub exit_mode: ExitMode, + pub keymap_mode: KeymapMode, + pub keymap_mode_shell: KeymapMode, + pub keymap_cursor: HashMap, + pub word_jump_mode: WordJumpMode, + pub word_chars: String, + pub scroll_context_lines: usize, + pub history_format: String, + pub strip_trailing_whitespace: bool, + pub prefers_reduced_motion: bool, + pub store_failed: bool, + pub no_mouse: bool, + + #[serde(with = "serde_regex", default = "RegexSet::empty", skip_serializing)] + pub history_filter: RegexSet, + + #[serde(with = "serde_regex", default = "RegexSet::empty", skip_serializing)] + pub cwd_filter: RegexSet, + + pub secrets_filter: bool, + pub workspaces: bool, + pub ctrl_n_shortcuts: bool, + + pub network_connect_timeout: u64, + pub network_timeout: u64, + pub local_timeout: f64, + pub enter_accept: bool, + pub smart_sort: bool, + pub command_chaining: bool, + + #[serde(default)] + pub stats: Stats, + + #[serde(default)] + pub keys: Keys, + + #[serde(default)] + pub keymap: KeymapConfig, + + #[serde(default)] + pub preview: Preview, + + #[serde(default)] + pub daemon: Daemon, + + #[serde(default)] + pub search: Search, + + #[serde(default)] + pub theme: Theme, + + #[serde(default)] + pub ui: Ui, + + #[serde(default)] + pub tmux: Tmux, + + #[serde(default)] + pub logs: Logs, + + #[serde(default)] + pub meta: meta::Settings, +} + +impl Settings { + pub fn utc() -> Self { + Self::builder() + .expect("Could not build default") + .set_override("timezone", "0") + .expect("failed to override timezone with UTC") + .build() + .expect("Could not build config") + .try_deserialize() + .expect("Could not deserialize config") + } + + pub(crate) fn effective_data_dir() -> PathBuf { + DATA_DIR + .get() + .cloned() + .unwrap_or_else(crate::atuin_common::utils::data_dir) + } + + // -- Meta store: lazily initialized on first access -- + + pub async fn meta_store() -> Result<&'static crate::atuin_client::meta::MetaStore> { + META_STORE + .get_or_try_init(|| async { + let (db_path, timeout) = META_CONFIG.get().ok_or_else(|| { + eyre!("meta store config not set — Settings::new() has not been called") + })?; + crate::atuin_client::meta::MetaStore::new(db_path, *timeout).await + }) + .await + } + + pub async fn host_id() -> Result { + Self::meta_store().await?.host_id().await + } + + pub async fn last_sync() -> Result { + Self::meta_store().await?.last_sync().await + } + + pub async fn save_sync_time() -> Result<()> { + Self::meta_store().await?.save_sync_time().await + } + + pub async fn last_version_check() -> Result { + Self::meta_store().await?.last_version_check().await + } + + pub async fn save_version_check_time() -> Result<()> { + Self::meta_store().await?.save_version_check_time().await + } + + pub async fn should_sync(&self) -> Result { + if !self.auto_sync || !Self::meta_store().await?.logged_in().await? { + return Ok(false); + } + + if self.sync_frequency == "0" { + return Ok(true); + } + + match parse_duration(self.sync_frequency.as_str()) { + Ok(d) => { + let d = time::Duration::try_from(d)?; + Ok(OffsetDateTime::now_utc() - Settings::last_sync().await? >= d) + } + Err(e) => Err(eyre!("failed to check sync: {}", e)), + } + } + + pub async fn logged_in(&self) -> Result { + Self::meta_store().await?.logged_in().await + } + + pub async fn session_token(&self) -> Result { + match Self::meta_store().await?.session_token().await? { + Some(token) => Ok(token), + None => Err(eyre!("Tried to load session; not logged in")), + } + } + + /// Examines the configured sync target and available tokens to determine + /// the correct auth strategy. Also performs cleanup of mis-stored tokens + /// (e.g. a CLI token incorrectly saved in the Hub session slot). + #[cfg(feature = "sync")] + pub async fn resolve_sync_auth(&self) -> SyncAuth { + let meta = match Self::meta_store().await { + Ok(m) => m, + Err(e) => { + return SyncAuth::NotLoggedIn { + reason: format!("Failed to open meta store: {e}"), + }; + } + }; + + // Self-hosted / legacy server + match meta.session_token().await { + Ok(Some(token)) => SyncAuth::Legacy { token }, + _ => SyncAuth::NotLoggedIn { + reason: "Not logged in. Run 'atuin login' to authenticate \ + with your sync server." + .into(), + }, + } + } + + /// Returns the appropriate auth token for sync operations. + /// + /// Delegates to [`resolve_sync_auth`] and converts the result to an + /// `AuthToken`. Callers that need to distinguish between auth states + /// (e.g. to show different UI) should call `resolve_sync_auth` directly. + #[cfg(feature = "sync")] + pub async fn sync_auth_token(&self) -> Result { + self.resolve_sync_auth().await.into_auth_token() + } + + pub fn default_filter_mode(&self, git_root: bool) -> FilterMode { + self.filter_mode + .filter(|x| self.search.filters.contains(x)) + .or_else(|| { + self.search + .filters + .iter() + .find(|x| match (x, git_root, self.workspaces) { + (FilterMode::Workspace, true, true) => true, + (FilterMode::Workspace, _, _) => false, + (_, _, _) => true, + }) + .copied() + }) + .unwrap_or(FilterMode::Global) + } + + pub fn builder() -> Result> { + Self::builder_with_data_dir(&crate::atuin_common::utils::data_dir()) + } + + fn builder_with_data_dir(data_dir: &std::path::Path) -> Result> { + let db_path = data_dir.join("history.db"); + let record_store_path = data_dir.join("records.db"); + let kv_path = data_dir.join("kv.db"); + let scripts_path = data_dir.join("scripts.db"); + let ai_sessions_path = data_dir.join("ai_sessions.db"); + let socket_path = crate::atuin_common::utils::runtime_dir().join("atuin.sock"); + let pidfile_path = data_dir.join("atuin-daemon.pid"); + let logs_dir = crate::atuin_common::utils::logs_dir(); + + let key_path = data_dir.join("key"); + let meta_path = data_dir.join("meta.db"); + + Ok(Config::builder() + .set_default("history_format", "{time}\t{command}\t{duration}")? + .set_default("db_path", db_path.to_str())? + .set_default("record_store_path", record_store_path.to_str())? + .set_default("key_path", key_path.to_str())? + .set_default("dialect", "us")? + .set_default("timezone", "local")? + .set_default("auto_sync", true)? + .set_default("sync_address", "https://api.atuin.sh")? + .set_default("sync_frequency", "5m")? + .set_default("search_mode", "fuzzy")? + .set_default("filter_mode", None::)? + .set_default("style", "compact")? + .set_default("inline_height", 40)? + .set_default("show_preview", true)? + .set_default("preview.strategy", "auto")? + .set_default("max_preview_height", 4)? + .set_default("show_help", true)? + .set_default("show_tabs", true)? + .set_default("show_numeric_shortcuts", true)? + .set_default("auto_hide_height", 8)? + .set_default("invert", false)? + .set_default("exit_mode", "return-original")? + .set_default("word_jump_mode", "emacs")? + .set_default( + "word_chars", + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789", + )? + .set_default("scroll_context_lines", 1)? + .set_default("shell_up_key_binding", false)? + .set_default("workspaces", false)? + .set_default("ctrl_n_shortcuts", false)? + .set_default("secrets_filter", true)? + .set_default("strip_trailing_whitespace", true)? + .set_default("network_connect_timeout", 5)? + .set_default("network_timeout", 30)? + .set_default("local_timeout", 2.0)? + // enter_accept defaults to false here, but true in the default config file. The dissonance is + // intentional! + // Existing users will get the default "False", so we don't mess with any potential + // muscle memory. + // New users will get the new default, that is more similar to what they are used to. + .set_default("enter_accept", false)? + .set_default("keys.scroll_exits", true)? + .set_default("keys.accept_past_line_end", true)? + .set_default("keys.exit_past_line_start", true)? + .set_default("keys.accept_past_line_start", false)? + .set_default("keys.accept_with_backspace", false)? + .set_default("keys.prefix", "a")? + .set_default("keymap_mode", "emacs")? + .set_default("keymap_mode_shell", "auto")? + .set_default("keymap_cursor", HashMap::::new())? + .set_default("smart_sort", false)? + .set_default("command_chaining", false)? + .set_default("store_failed", true)? + .set_default("daemon.sync_frequency", 300)? + .set_default("daemon.enabled", false)? + .set_default("daemon.autostart", false)? + .set_default("daemon.socket_path", socket_path.to_str())? + .set_default("daemon.pidfile_path", pidfile_path.to_str())? + .set_default("daemon.systemd_socket", false)? + .set_default("daemon.tcp_port", 8889)? + .set_default("logs.enabled", true)? + .set_default("logs.dir", logs_dir.to_str())? + .set_default("logs.level", "info")? + .set_default("logs.search.file", "search.log")? + .set_default("logs.daemon.file", "daemon.log")? + .set_default("logs.ai.file", "ai.log")? + .set_default("kv.db_path", kv_path.to_str())? + .set_default("scripts.db_path", scripts_path.to_str())? + .set_default("search.recency_score_multiplier", 1.0)? + .set_default("search.frequency_score_multiplier", 1.0)? + .set_default("search.frecency_score_multiplier", 1.0)? + .set_default("meta.db_path", meta_path.to_str())? + .set_default("ai.db_path", ai_sessions_path.to_str())? + .set_default("ai.session_continue_minutes", 60)? + .set_default("ai.send_cwd", false)? + .set_default("ai.opening.send_cwd", false)? + .set_default("ai.opening.send_last_command", false)? + .set_default( + "search.filters", + vec![ + "global", + "host", + "session", + "workspace", + "directory", + "session-preload", + ], + )? + .set_default("theme.name", "default")? + .set_default("theme.debug", None::)? + .set_default("tmux.enabled", false)? + .set_default("tmux.width", "80%")? + .set_default("tmux.height", "60%")? + .set_default( + "prefers_reduced_motion", + std::env::var("NO_MOTION") + .ok() + .map(|_| config::Value::new(None, config::ValueKind::Boolean(true))) + .unwrap_or_else(|| config::Value::new(None, config::ValueKind::Boolean(false))), + )? + .set_default("no_mouse", false)? + .add_source( + Environment::with_prefix("atuin") + .prefix_separator("_") + .separator("__"), + )) + } + + pub fn get_config_path() -> Result { + let config_dir = crate::atuin_common::utils::config_dir(); + + create_dir_all(&config_dir) + .wrap_err_with(|| format!("could not create dir {config_dir:?}"))?; + + let mut config_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { + PathBuf::from(p) + } else { + let mut config_file = PathBuf::new(); + config_file.push(config_dir); + config_file + }; + + config_file.push("config.toml"); + + Ok(config_file) + } + + /// Build a merged `Config` from defaults, config file, and environment. + /// + /// This resolves `data_dir`, initializes the data directory on disk, + /// and layers defaults → config file → env overrides. Both `new()` and + /// `get_config_value()` use this so the resolution logic lives in one place. + fn build_config() -> Result { + let config_file = Self::get_config_path()?; + + // extract data_dir first so we can use it as the base for other path defaults + let effective_data_dir = if config_file.exists() { + #[derive(Deserialize, Default)] + struct DataDirOnly { + data_dir: Option, + } + + let config_file_str = config_file + .to_str() + .ok_or_else(|| eyre!("config file path is not valid UTF-8"))?; + + let partial_config = Config::builder() + .add_source(ConfigFile::new(config_file_str, FileFormat::Toml)) + .add_source( + Environment::with_prefix("atuin") + .prefix_separator("_") + .separator("__"), + ) + .build() + .ok(); + + let custom_data_dir = partial_config + .and_then(|c| c.try_deserialize::().ok()) + .and_then(|d| d.data_dir); + + match custom_data_dir { + Some(dir) => { + let expanded = shellexpand::full(&dir) + .map_err(|e| eyre!("failed to expand data_dir path: {}", e))?; + PathBuf::from(expanded.as_ref()) + } + None => crate::atuin_common::utils::data_dir(), + } + } else { + crate::atuin_common::utils::data_dir() + }; + + DATA_DIR.set(effective_data_dir.clone()).ok(); + + create_dir_all(&effective_data_dir) + .wrap_err_with(|| format!("could not create dir {effective_data_dir:?}"))?; + + let mut config_builder = Self::builder_with_data_dir(&effective_data_dir)?; + + config_builder = if config_file.exists() { + let config_file_str = config_file + .to_str() + .ok_or_else(|| eyre!("config file path is not valid UTF-8"))?; + config_builder.add_source(ConfigFile::new(config_file_str, FileFormat::Toml)) + } else { + let mut file = File::create(config_file).wrap_err("could not create config file")?; + + let config = config_builder.build_cloned()?; + // TODO(@bpeetz): Not so sure about this <2026-06-10> + file.write_all(config.cache.to_string().as_bytes()) + .wrap_err("could not write default config file")?; + + config_builder + }; + + // all paths should be expanded + let built = config_builder.build_cloned()?; + config_builder = [ + "db_path", + "record_store_path", + "key_path", + "daemon.socket_path", + "daemon.pidfile_path", + "logs.dir", + "logs.search.file", + "logs.daemon.file", + ] + .iter() + .map(|key| (key, built.get_string(key).unwrap_or_default())) + .filter_map(|(key, value)| match Self::expand_path(value) { + Ok(expanded) => Some((key, expanded)), + Err(e) => { + log::warn!("failed to expand path for {key}: {e}"); + None + } + }) + .fold(config_builder, |builder, (key, value)| { + builder + .set_override(key, value) + .unwrap_or_else(|_| panic!("failed to set absolute path override for {key}")) + }); + + config_builder.build().map_err(Into::into) + } + + /// Look up a single config value by dotted key (e.g. `"daemon.sync_frequency"`). + /// + /// Returns the effective value after merging defaults, config file, and + /// environment — without the side-effects of full `Settings` construction + /// (meta store init, path expansion, etc.). + pub fn get_config_value(key: &str) -> Result { + let config = Self::build_config()?; + let value: config::Value = config + .get(key) + .map_err(|e| eyre!("failed to get config value '{}': {}", key, e))?; + Ok(Self::format_resolved_value(&value, key)) + } + + fn format_resolved_value(value: &config::Value, prefix: &str) -> String { + use config::ValueKind; + + match &value.kind { + ValueKind::Nil => String::new(), + ValueKind::Boolean(b) => b.to_string(), + ValueKind::I64(i) => i.to_string(), + ValueKind::I128(i) => i.to_string(), + ValueKind::U64(u) => u.to_string(), + ValueKind::U128(u) => u.to_string(), + ValueKind::Float(f) => f.to_string(), + ValueKind::String(s) => s.clone(), + ValueKind::Array(arr) => { + let items: Vec = arr + .iter() + .map(|v| Self::format_resolved_value(v, "")) + .collect(); + format!("[{}]", items.join(", ")) + } + ValueKind::Table(map) => { + let mut lines = Vec::new(); + let mut keys: Vec<_> = map.keys().collect(); + keys.sort(); + + for k in keys { + let v = &map[k]; + let full_key = if prefix.is_empty() { + k.clone() + } else { + format!("{}.{}", prefix, k) + }; + + match &v.kind { + ValueKind::Table(_) => { + lines.push(Self::format_resolved_value(v, &full_key)); + } + _ => { + lines.push(format!( + "{} = {}", + full_key, + Self::format_resolved_value(v, "") + )); + } + } + } + + lines.join("\n") + } + } + } + + pub fn new() -> Result { + let config = Self::build_config()?; + let settings: Settings = config + .try_deserialize() + .map_err(|e| eyre!("failed to deserialize: {}", e))?; + + // Validate UI settings + settings.ui.validate()?; + + // Register meta store config for lazy initialization on first access + META_CONFIG + .set((settings.meta.db_path.clone(), settings.local_timeout)) + .ok(); + + Ok(settings) + } + + fn expand_path(path: String) -> Result { + shellexpand::full(&path) + .map(|p| p.to_string()) + .map_err(|e| eyre!("failed to expand path: {}", e)) + } + + pub fn paths_ok(&self) -> bool { + let paths = [ + &self.db_path, + &self.record_store_path, + &self.key_path, + &self.meta.db_path, + ]; + paths.iter().all(|p| !utils::broken_symlink(p)) + } +} + +impl Default for Settings { + fn default() -> Self { + // if this panics something is very wrong, as the default config + // does not build or deserialize into the settings struct + Self::builder() + .expect("Could not build default") + .build() + .expect("Could not build config") + .try_deserialize() + .expect("Could not deserialize config") + } +} + +/// Initialize the meta store configuration for testing. +/// +/// This should only be used in tests. It allows tests to bypass the normal +/// Settings::new() flow while still being able to use Settings::host_id() +/// and other meta store dependent functions. +/// +/// # Safety +/// This function is not thread-safe with concurrent calls to Settings::new() +/// or other meta store initialization. Only call from tests. +#[doc(hidden)] +pub fn init_meta_config_for_testing(meta_db_path: impl Into, local_timeout: f64) { + META_CONFIG.set((meta_db_path.into(), local_timeout)).ok(); +} + +#[cfg(test)] +pub(crate) fn test_local_timeout() -> f64 { + std::env::var("ATUIN_TEST_LOCAL_TIMEOUT") + .ok() + .and_then(|x| x.parse().ok()) + // this hardcoded value should be replaced by a simple way to get the + // default local_timeout of Settings if possible + .unwrap_or(2.0) +} + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use eyre::Result; + + use super::Timezone; + + #[test] + fn can_parse_offset_timezone_spec() -> Result<()> { + assert_eq!(Timezone::from_str("+02")?.0.as_hms(), (2, 0, 0)); + assert_eq!(Timezone::from_str("-04")?.0.as_hms(), (-4, 0, 0)); + assert_eq!(Timezone::from_str("+05:30")?.0.as_hms(), (5, 30, 0)); + assert_eq!(Timezone::from_str("-09:30")?.0.as_hms(), (-9, -30, 0)); + + // single digit hours are allowed + assert_eq!(Timezone::from_str("+2")?.0.as_hms(), (2, 0, 0)); + assert_eq!(Timezone::from_str("-4")?.0.as_hms(), (-4, 0, 0)); + assert_eq!(Timezone::from_str("+5:30")?.0.as_hms(), (5, 30, 0)); + assert_eq!(Timezone::from_str("-9:30")?.0.as_hms(), (-9, -30, 0)); + + // fully qualified form + assert_eq!(Timezone::from_str("+09:30:00")?.0.as_hms(), (9, 30, 0)); + assert_eq!(Timezone::from_str("-09:30:00")?.0.as_hms(), (-9, -30, 0)); + + // these offsets don't really exist but are supported anyway + assert_eq!(Timezone::from_str("+0:5")?.0.as_hms(), (0, 5, 0)); + assert_eq!(Timezone::from_str("-0:5")?.0.as_hms(), (0, -5, 0)); + assert_eq!(Timezone::from_str("+01:23:45")?.0.as_hms(), (1, 23, 45)); + assert_eq!(Timezone::from_str("-01:23:45")?.0.as_hms(), (-1, -23, -45)); + + // require a leading sign for clarity + assert!(Timezone::from_str("5").is_err()); + assert!(Timezone::from_str("10:30").is_err()); + + Ok(()) + } + + #[test] + fn can_choose_workspace_filters_when_in_git_context() -> Result<()> { + let mut settings = super::Settings::default(); + settings.search.filters = vec![ + super::FilterMode::Workspace, + super::FilterMode::Host, + super::FilterMode::Directory, + super::FilterMode::Session, + super::FilterMode::Global, + ]; + settings.workspaces = true; + + assert_eq!( + settings.default_filter_mode(true), + super::FilterMode::Workspace, + ); + + Ok(()) + } + + #[test] + fn wont_choose_workspace_filters_when_not_in_git_context() -> Result<()> { + let mut settings = super::Settings::default(); + settings.search.filters = vec![ + super::FilterMode::Workspace, + super::FilterMode::Host, + super::FilterMode::Directory, + super::FilterMode::Session, + super::FilterMode::Global, + ]; + settings.workspaces = true; + + assert_eq!(settings.default_filter_mode(false), super::FilterMode::Host,); + + Ok(()) + } + + #[test] + fn wont_choose_workspace_filters_when_workspaces_disabled() -> Result<()> { + let mut settings = super::Settings::default(); + settings.search.filters = vec![ + super::FilterMode::Workspace, + super::FilterMode::Host, + super::FilterMode::Directory, + super::FilterMode::Session, + super::FilterMode::Global, + ]; + settings.workspaces = false; + + assert_eq!(settings.default_filter_mode(true), super::FilterMode::Host,); + + Ok(()) + } + + #[test] + fn builder_with_data_dir_uses_custom_paths() -> Result<()> { + use std::path::PathBuf; + + let custom_dir = PathBuf::from("/custom/data/dir"); + let builder = super::Settings::builder_with_data_dir(&custom_dir)?; + let config = builder.build()?; + + let db_path: String = config.get("db_path")?; + let key_path: String = config.get("key_path")?; + let record_store_path: String = config.get("record_store_path")?; + let kv_db_path: String = config.get("kv.db_path")?; + let scripts_db_path: String = config.get("scripts.db_path")?; + let meta_db_path: String = config.get("meta.db_path")?; + let daemon_socket_path: String = config.get("daemon.socket_path")?; + let daemon_pidfile_path: String = config.get("daemon.pidfile_path")?; + let daemon_autostart: bool = config.get("daemon.autostart")?; + + assert_eq!(db_path, custom_dir.join("history.db").to_str().unwrap()); + assert_eq!(key_path, custom_dir.join("key").to_str().unwrap()); + assert_eq!( + record_store_path, + custom_dir.join("records.db").to_str().unwrap() + ); + assert_eq!(kv_db_path, custom_dir.join("kv.db").to_str().unwrap()); + assert_eq!( + scripts_db_path, + custom_dir.join("scripts.db").to_str().unwrap() + ); + assert_eq!(meta_db_path, custom_dir.join("meta.db").to_str().unwrap()); + assert_eq!( + daemon_socket_path, + crate::atuin_common::utils::runtime_dir() + .join("atuin.sock") + .to_str() + .unwrap() + ); + assert_eq!( + daemon_pidfile_path, + custom_dir.join("atuin-daemon.pid").to_str().unwrap() + ); + assert!(!daemon_autostart); + + Ok(()) + } + + #[test] + fn effective_data_dir_returns_default_when_not_set() { + let effective = super::Settings::effective_data_dir(); + let default = crate::atuin_common::utils::data_dir(); + + assert!(effective.to_str().is_some()); + assert!(effective.ends_with("atuin") || effective == default); + } + + #[test] + fn keymap_config_deserializes_simple_binding() { + let json = r#"{"emacs": {"ctrl-c": "exit"}}"#; + let config: super::KeymapConfig = serde_json::from_str(json).unwrap(); + assert_eq!(config.emacs.len(), 1); + match &config.emacs["ctrl-c"] { + super::KeyBindingConfig::Simple(s) => assert_eq!(s, "exit"), + _ => panic!("expected Simple variant"), + } + } + + #[test] + fn keymap_config_deserializes_conditional_binding() { + let json = r#"{ + "emacs": { + "left": [ + {"when": "cursor-at-start", "action": "exit"}, + {"action": "cursor-left"} + ] + } + }"#; + let config: super::KeymapConfig = serde_json::from_str(json).unwrap(); + match &config.emacs["left"] { + super::KeyBindingConfig::Rules(rules) => { + assert_eq!(rules.len(), 2); + assert_eq!(rules[0].when.as_deref(), Some("cursor-at-start")); + assert_eq!(rules[0].action, "exit"); + assert!(rules[1].when.is_none()); + assert_eq!(rules[1].action, "cursor-left"); + } + _ => panic!("expected Rules variant"), + } + } + + #[test] + fn keymap_config_deserializes_vim_normal() { + let json = r#"{"vim-normal": {"j": "select-next", "k": "select-previous"}}"#; + let config: super::KeymapConfig = serde_json::from_str(json).unwrap(); + assert_eq!(config.vim_normal.len(), 2); + assert!(config.emacs.is_empty()); + } + + #[test] + fn keymap_config_is_empty_when_default() { + let config = super::KeymapConfig::default(); + assert!(config.is_empty()); + } + + #[test] + fn keymap_config_mixed_modes() { + let json = r#"{ + "emacs": {"ctrl-c": "exit"}, + "vim-normal": {"q": "exit"}, + "inspector": {"d": "delete"} + }"#; + let config: super::KeymapConfig = serde_json::from_str(json).unwrap(); + assert!(!config.is_empty()); + assert_eq!(config.emacs.len(), 1); + assert_eq!(config.vim_normal.len(), 1); + assert_eq!(config.inspector.len(), 1); + assert!(config.vim_insert.is_empty()); + assert!(config.prefix.is_empty()); + } +} diff --git a/crates/turtle/src/atuin_client/settings/meta.rs b/crates/turtle/src/atuin_client/settings/meta.rs new file mode 100644 index 00000000..450d0432 --- /dev/null +++ b/crates/turtle/src/atuin_client/settings/meta.rs @@ -0,0 +1,17 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Settings { + pub db_path: String, +} + +impl Default for Settings { + fn default() -> Self { + let dir = crate::atuin_common::utils::data_dir(); + let path = dir.join("meta.db"); + + Self { + db_path: path.to_string_lossy().to_string(), + } + } +} diff --git a/crates/turtle/src/atuin_client/settings/watcher.rs b/crates/turtle/src/atuin_client/settings/watcher.rs new file mode 100644 index 00000000..7548573e --- /dev/null +++ b/crates/turtle/src/atuin_client/settings/watcher.rs @@ -0,0 +1,256 @@ +//! Config file watching for automatic settings reload. +//! +//! This module provides a `SettingsWatcher` that monitors the config file +//! for changes and broadcasts updated settings via a `tokio::sync::watch` channel. +//! +//! # Example +//! +//! ```no_run +//! use crate::atuin_client::settings::watcher::global_settings_watcher; +//! +//! async fn example() -> eyre::Result<()> { +//! let watcher = global_settings_watcher()?; +//! let mut rx = watcher.subscribe(); +//! +//! // React to settings changes +//! while rx.changed().await.is_ok() { +//! let settings = rx.borrow(); +//! println!("Settings updated!"); +//! } +//! Ok(()) +//! } +//! ``` + +use std::{ + path::PathBuf, + sync::{Arc, OnceLock}, + time::Duration, +}; + +use eyre::{Result, WrapErr}; +use log::{debug, error, info, warn}; +use notify::{ + Config as NotifyConfig, RecommendedWatcher, RecursiveMode, Watcher, + event::{EventKind, ModifyKind}, +}; +use tokio::sync::watch; + +use super::Settings; + +/// Global singleton for the settings watcher. +static SETTINGS_WATCHER: OnceLock> = OnceLock::new(); + +/// Get the global settings watcher singleton. +/// +/// Initializes the watcher on first call. Subsequent calls return the same instance. +/// The watcher monitors the config file for changes and broadcasts updates. +pub fn global_settings_watcher() -> Result<&'static SettingsWatcher> { + let result = SETTINGS_WATCHER.get_or_init(|| SettingsWatcher::new().map_err(|e| e.to_string())); + + match result { + Ok(watcher) => Ok(watcher), + Err(e) => Err(eyre::eyre!("{}", e)), + } +} + +/// Watches the config file for changes and broadcasts updated settings. +/// +/// Uses `notify` for cross-platform file watching and `tokio::sync::watch` +/// for efficient broadcast to multiple subscribers. +pub struct SettingsWatcher { + /// Receiver for settings updates. Clone this to subscribe. + rx: watch::Receiver>, + /// Keeps the file watcher alive for the lifetime of this struct. + _watcher: RecommendedWatcher, +} + +impl SettingsWatcher { + /// Create a new settings watcher. + /// + /// Loads initial settings and starts watching the config file for changes. + /// Changes are debounced (500ms) to avoid multiple reloads during saves. + pub fn new() -> Result { + let initial_settings = Arc::new(Settings::new()?); + let (tx, rx) = watch::channel(initial_settings); + + let config_path = Self::config_path(); + info!("starting config file watcher: {:?}", config_path); + + let watcher = Self::create_watcher(tx, config_path)?; + + Ok(Self { + rx, + _watcher: watcher, + }) + } + + /// Subscribe to settings updates. + /// + /// Returns a receiver that will be notified when settings change. + /// Use `changed().await` to wait for the next update, then `borrow()` + /// to access the current settings. + pub fn subscribe(&self) -> watch::Receiver> { + self.rx.clone() + } + + /// Get the current settings without subscribing to updates. + pub fn current(&self) -> Arc { + self.rx.borrow().clone() + } + + /// Get the config file path. + fn config_path() -> PathBuf { + let config_dir = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { + PathBuf::from(p) + } else { + crate::atuin_common::utils::config_dir() + }; + config_dir.join("config.toml") + } + + /// Create the file watcher with debouncing. + fn create_watcher( + tx: watch::Sender>, + config_path: PathBuf, + ) -> Result { + // Channel for debouncing file events + let (debounce_tx, debounce_rx) = std::sync::mpsc::channel::<()>(); + + // Spawn debounce thread + let config_path_clone = config_path.clone(); + std::thread::spawn(move || { + Self::debounce_loop(debounce_rx, tx, config_path_clone); + }); + + // Clone config_path for use in the watcher callback + let config_path_for_watcher = config_path.clone(); + + // Canonicalize config path for reliable comparison on macOS + // (handles symlinks like /var -> /private/var) + let canonical_config_path = config_path_for_watcher + .canonicalize() + .unwrap_or_else(|_| config_path_for_watcher.clone()); + + // Create file watcher + let mut watcher = RecommendedWatcher::new( + move |res: Result| { + match res { + Ok(event) => { + // Defensive: if paths is empty, we can't filter, so assume + // it might be our config file and trigger a reload to be safe + if event.paths.is_empty() { + warn!( + "config watcher: event has no paths, triggering reload to be safe" + ); + let _ = debounce_tx.send(()); + return; + } + + // Only react to events for our specific config file + // (filter out editor temp files, backups, etc.) + let is_config_file = event.paths.iter().any(|path| { + // Canonicalize for reliable comparison (handles macOS symlinks) + let canonical_event_path = + path.canonicalize().unwrap_or_else(|_| path.clone()); + + // Check if this event is for our config file + // (either exact match or the file was renamed to our config) + canonical_event_path == canonical_config_path + || path.file_name() == config_path_for_watcher.file_name() + }); + + if !is_config_file { + return; + } + + // Only react to modify events (content changes) or creates + if matches!( + event.kind, + EventKind::Modify(ModifyKind::Data(_) | ModifyKind::Any) + | EventKind::Create(_) + ) { + debug!("config file event detected: {:?}", event); + // Send to debounce channel (ignore send errors - receiver might be gone) + let _ = debounce_tx.send(()); + } + } + Err(e) => { + error!("file watcher error: {}", e); + } + } + }, + NotifyConfig::default(), + ) + .wrap_err("failed to create file watcher")?; + + // Watch the config file's parent directory (some editors create new files) + let watch_path = config_path.parent().unwrap_or(&config_path); + + // Defensive: ensure watch path exists before trying to watch + if !watch_path.exists() { + warn!( + "config directory does not exist, creating it: {:?}", + watch_path + ); + std::fs::create_dir_all(watch_path) + .wrap_err_with(|| format!("failed to create config directory: {:?}", watch_path))?; + } + + watcher + .watch(watch_path, RecursiveMode::NonRecursive) + .wrap_err_with(|| format!("failed to watch config directory: {:?}", watch_path))?; + + info!("config file watcher initialized for: {:?}", watch_path); + Ok(watcher) + } + + /// Debounce loop that batches file events and reloads settings. + fn debounce_loop( + rx: std::sync::mpsc::Receiver<()>, + tx: watch::Sender>, + config_path: PathBuf, + ) { + const DEBOUNCE_DURATION: Duration = Duration::from_millis(500); + + loop { + // Wait for first event + if rx.recv().is_err() { + // Channel closed, watcher was dropped + debug!("config watcher debounce loop exiting"); + return; + } + + // Drain any additional events within debounce window + while rx.recv_timeout(DEBOUNCE_DURATION).is_ok() { + // Keep draining + } + + // Defensive: check if config file exists before reloading + // (handles case where file was deleted - we'll get notified when it's recreated) + if !config_path.exists() { + debug!( + "config file does not exist, skipping reload: {:?}", + config_path + ); + continue; + } + + // Now reload settings + info!("config file changed, reloading settings: {:?}", config_path); + match Settings::new() { + Ok(settings) => { + if tx.send(Arc::new(settings)).is_err() { + // All receivers dropped + debug!("all settings subscribers dropped, exiting"); + return; + } + info!("settings reloaded successfully"); + } + Err(e) => { + warn!("failed to reload settings: {}", e); + // Keep the old settings, don't broadcast the error + } + } + } + } +} diff --git a/crates/turtle/src/atuin_client/sync.rs b/crates/turtle/src/atuin_client/sync.rs new file mode 100644 index 00000000..361b6636 --- /dev/null +++ b/crates/turtle/src/atuin_client/sync.rs @@ -0,0 +1,214 @@ +use std::collections::HashSet; +use std::iter::FromIterator; + +use eyre::Result; +use tracing::{debug, info}; + +use crate::atuin_common::api::AddHistoryRequest; +use crypto_secretbox::Key; +use time::OffsetDateTime; + +use crate::atuin_client::{ + api_client, + database::Database, + encryption::{decrypt, encrypt, load_key}, + settings::Settings, +}; + +pub fn hash_str(string: &str) -> String { + use sha2::{Digest, Sha256}; + let mut hasher = Sha256::new(); + hasher.update(string.as_bytes()); + hex::encode(hasher.finalize()) +} + +// Currently sync is kinda naive, and basically just pages backwards through +// history. This means newly added stuff shows up properly! We also just use +// the total count in each database to indicate whether a sync is needed. +// I think this could be massively improved! If we had a way of easily +// indicating count per time period (hour, day, week, year, etc) then we can +// easily pinpoint where we are missing data and what needs downloading. Start +// with year, then find the week, then the day, then the hour, then download it +// all! The current naive approach will do for now. + +// Check if remote has things we don't, and if so, download them. +// Returns (num downloaded, total local) +async fn sync_download( + key: &Key, + force: bool, + client: &api_client::Client<'_>, + db: &impl Database, +) -> Result<(i64, i64)> { + debug!("starting sync download"); + + let remote_status = client.status().await?; + let remote_count = remote_status.count; + + // useful to ensure we don't even save something that hasn't yet been synced + deleted + let remote_deleted = + HashSet::<&str>::from_iter(remote_status.deleted.iter().map(String::as_str)); + + let initial_local = db.history_count(true).await?; + let mut local_count = initial_local; + + let mut last_sync = if force { + OffsetDateTime::UNIX_EPOCH + } else { + Settings::last_sync().await? + }; + + let mut last_timestamp = OffsetDateTime::UNIX_EPOCH; + + let host = if force { Some(String::from("")) } else { None }; + + while remote_count > local_count { + let page = client + .get_history(last_sync, last_timestamp, host.clone()) + .await?; + + let history: Vec<_> = page + .history + .iter() + // TODO: handle deletion earlier in this chain + .map(|h| serde_json::from_str(h).expect("invalid base64")) + .map(|h| decrypt(h, key).expect("failed to decrypt history! check your key")) + .map(|mut h| { + if remote_deleted.contains(h.id.0.as_str()) { + h.deleted_at = Some(time::OffsetDateTime::now_utc()); + h.command = String::from(""); + } + + h + }) + .collect(); + + db.save_bulk(&history).await?; + + local_count = db.history_count(true).await?; + let remote_page_size = std::cmp::max(remote_status.page_size, 0) as usize; + + if history.len() < remote_page_size { + break; + } + + let page_last = history + .last() + .expect("could not get last element of page") + .timestamp; + + // in the case of a small sync frequency, it's possible for history to + // be "lost" between syncs. In this case we need to rewind the sync + // timestamps + if page_last == last_timestamp { + last_timestamp = OffsetDateTime::UNIX_EPOCH; + last_sync -= time::Duration::hours(1); + } else { + last_timestamp = page_last; + } + } + + for i in remote_status.deleted { + // we will update the stored history to have this data + // pretty much everything can be nullified + match db.load(i.as_str()).await? { + Some(h) => { + db.delete(h).await?; + } + _ => { + info!( + "could not delete history with id {}, not found locally", + i.as_str() + ); + } + } + } + + Ok((local_count - initial_local, local_count)) +} + +// Check if we have things remote doesn't, and if so, upload them +async fn sync_upload( + key: &Key, + _force: bool, + client: &api_client::Client<'_>, + db: &impl Database, +) -> Result<()> { + debug!("starting sync upload"); + + let remote_status = client.status().await?; + let remote_deleted: HashSet = HashSet::from_iter(remote_status.deleted.clone()); + + let initial_remote_count = client.count().await?; + let mut remote_count = initial_remote_count; + + let local_count = db.history_count(true).await?; + + debug!("remote has {remote_count}, we have {local_count}"); + + // first just try the most recent set + let mut cursor = OffsetDateTime::now_utc(); + + while local_count > remote_count { + let last = db.before(cursor, remote_status.page_size).await?; + let mut buffer = Vec::new(); + + if last.is_empty() { + break; + } + + for i in last { + let data = encrypt(&i, key)?; + let data = serde_json::to_string(&data)?; + + let add_hist = AddHistoryRequest { + id: i.id.to_string(), + timestamp: i.timestamp, + data, + hostname: hash_str(&i.hostname), + }; + + buffer.push(add_hist); + } + + // anything left over outside of the 100 block size + client.post_history(&buffer).await?; + cursor = buffer.last().unwrap().timestamp; + remote_count = client.count().await?; + + debug!("upload cursor: {cursor:?}"); + } + + let deleted = db.deleted().await?; + + for i in deleted { + if remote_deleted.contains(&i.id.to_string()) { + continue; + } + + info!("deleting {} on remote", i.id); + client.delete_history(i).await?; + } + + Ok(()) +} + +pub async fn sync(settings: &Settings, force: bool, db: &impl Database) -> Result<()> { + let client = api_client::Client::new( + &settings.sync_address, + settings.sync_auth_token().await?, + settings.network_connect_timeout, + settings.network_timeout, + )?; + + Settings::save_sync_time().await?; + + let key = load_key(settings)?; // encryption key + + sync_upload(&key, force, &client, db).await?; + + let download = sync_download(&key, force, &client, db).await?; + + debug!("sync downloaded {}", download.0); + + Ok(()) +} diff --git a/crates/turtle/src/atuin_client/theme.rs b/crates/turtle/src/atuin_client/theme.rs new file mode 100644 index 00000000..1d9c0b9e --- /dev/null +++ b/crates/turtle/src/atuin_client/theme.rs @@ -0,0 +1,831 @@ +use config::{Config, File as ConfigFile, FileFormat}; +use log; +use palette::named; +use serde::{Deserialize, Serialize}; +use serde_json; +use std::collections::HashMap; +use std::error; +use std::io::{Error, ErrorKind}; +use std::path::PathBuf; +use std::sync::LazyLock; +use strum_macros; + +static DEFAULT_MAX_DEPTH: u8 = 10; + +// Collection of settable "meanings" that can have colors set. +// NOTE: You can add a new meaning here without breaking backwards compatibility but please: +// - update the atuin/docs repository, which has a list of available meanings +// - add a fallback in the MEANING_FALLBACKS below, so that themes which do not have it +// get a sensible fallback (see Title as an example) +#[derive( + Serialize, Deserialize, Copy, Clone, Hash, Debug, Eq, PartialEq, strum_macros::Display, +)] +#[strum(serialize_all = "camel_case")] +pub enum Meaning { + AlertInfo, + AlertWarn, + AlertError, + Annotation, + Base, + Guidance, + Important, + Title, + Muted, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ThemeConfig { + // Definition of the theme + pub theme: ThemeDefinitionConfigBlock, + + // Colors + pub colors: HashMap, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ThemeDefinitionConfigBlock { + /// Name of theme ("default" for base) + pub name: String, + + /// Whether any theme should be treated as a parent _if available_ + pub parent: Option, +} + +use crossterm::style::{Attribute, Attributes, Color, ContentStyle}; + +// For now, a theme is loaded as a mapping of meanings to colors, but it may be desirable to +// expand that in the future to general styles, so we populate a Meaning->ContentStyle hashmap. +pub struct Theme { + pub name: String, + pub parent: Option, + pub styles: HashMap, +} + +// Themes have a number of convenience functions for the most commonly used meanings. +// The general purpose `as_style` routine gives back a style, but for ease-of-use and to keep +// theme-related boilerplate minimal, the convenience functions give a color. +impl Theme { + // This is the base "default" color, for general text + pub fn get_base(&self) -> ContentStyle { + self.styles[&Meaning::Base] + } + + pub fn get_info(&self) -> ContentStyle { + self.get_alert(log::Level::Info) + } + + pub fn get_warning(&self) -> ContentStyle { + self.get_alert(log::Level::Warn) + } + + pub fn get_error(&self) -> ContentStyle { + self.get_alert(log::Level::Error) + } + + // The alert meanings may be chosen by the Level enum, rather than the methods above + // or the full Meaning enum, to simplify programmatic selection of a log-level. + pub fn get_alert(&self, severity: log::Level) -> ContentStyle { + self.styles[ALERT_TYPES.get(&severity).unwrap()] + } + + pub fn new( + name: String, + parent: Option, + styles: HashMap, + ) -> Theme { + Theme { + name, + parent, + styles, + } + } + + pub fn closest_meaning<'a>(&self, meaning: &'a Meaning) -> &'a Meaning { + if self.styles.contains_key(meaning) { + meaning + } else if MEANING_FALLBACKS.contains_key(meaning) { + self.closest_meaning(&MEANING_FALLBACKS[meaning]) + } else { + &Meaning::Base + } + } + + // General access - if you have a meaning, this will give you a (crossterm) style + pub fn as_style(&self, meaning: Meaning) -> ContentStyle { + self.styles[self.closest_meaning(&meaning)] + } + + // Turns a map of meanings to colornames into a theme + // If theme-debug is on, then we will print any colornames that we cannot load, + // but we do not have this on in general, as it could print unfiltered text to the terminal + // from a theme TOML file. However, it will always return a theme, falling back to + // defaults on error, so that a TOML file does not break loading + pub fn from_foreground_colors( + name: String, + parent: Option<&Theme>, + foreground_colors: HashMap, + debug: bool, + ) -> Theme { + let styles: HashMap = foreground_colors + .iter() + .map(|(name, color)| { + ( + *name, + StyleFactory::from_fg_string(color).unwrap_or_else(|err| { + if debug { + log::warn!("Tried to load string as a color unsuccessfully: ({name}={color}) {err}"); + } + ContentStyle::default() + }), + ) + }) + .collect(); + Theme::from_map(name, parent, &styles) + } + + // Boil down a meaning-color hashmap into a theme, by taking the defaults + // for any unknown colors + fn from_map( + name: String, + parent: Option<&Theme>, + overrides: &HashMap, + ) -> Theme { + let styles = match parent { + Some(theme) => Box::new(theme.styles.clone()), + None => Box::new(DEFAULT_THEME.styles.clone()), + } + .iter() + .map(|(name, color)| match overrides.get(name) { + Some(value) => (*name, *value), + None => (*name, *color), + }) + .collect(); + Theme::new(name, parent.map(|p| p.name.clone()), styles) + } +} + +// Use palette to get a color from a string name, if possible +fn from_string(name: &str) -> Result { + if name.is_empty() { + return Err("Empty string".into()); + } + let first_char = name.chars().next().unwrap(); + match first_char { + '#' => { + let hexcode = &name[1..]; + let vec: Vec = hexcode + .chars() + .collect::>() + .chunks(2) + .map(|pair| u8::from_str_radix(pair.iter().collect::().as_str(), 16)) + .filter_map(|n| n.ok()) + .collect(); + if vec.len() != 3 { + return Err("Could not parse 3 hex values from string".into()); + } + Ok(Color::Rgb { + r: vec[0], + g: vec[1], + b: vec[2], + }) + } + '@' => { + // For full flexibility, we need to use serde_json, given + // crossterm's approach. + serde_json::from_str::(format!("\"{}\"", &name[1..]).as_str()) + .map_err(|_| format!("Could not convert color name {name} to Crossterm color")) + } + _ => { + let srgb = named::from_str(name).ok_or("No such color in palette")?; + Ok(Color::Rgb { + r: srgb.red, + g: srgb.green, + b: srgb.blue, + }) + } + } +} + +pub struct StyleFactory {} + +impl StyleFactory { + fn from_fg_string(name: &str) -> Result { + match from_string(name) { + Ok(color) => Ok(Self::from_fg_color(color)), + Err(err) => Err(err), + } + } + + // For succinctness, if we are confident that the name will be known, + // this routine is available to keep the code readable + fn known_fg_string(name: &str) -> ContentStyle { + Self::from_fg_string(name).unwrap() + } + + fn from_fg_color(color: Color) -> ContentStyle { + ContentStyle { + foreground_color: Some(color), + ..ContentStyle::default() + } + } + + fn from_fg_color_and_attributes(color: Color, attributes: Attributes) -> ContentStyle { + ContentStyle { + foreground_color: Some(color), + attributes, + ..ContentStyle::default() + } + } +} + +// Built-in themes. Rather than having extra files added before any theming +// is available, this gives a couple of basic options, demonstrating the use +// of themes: autumn and marine +static ALERT_TYPES: LazyLock> = LazyLock::new(|| { + HashMap::from([ + (log::Level::Info, Meaning::AlertInfo), + (log::Level::Warn, Meaning::AlertWarn), + (log::Level::Error, Meaning::AlertError), + ]) +}); + +static MEANING_FALLBACKS: LazyLock> = LazyLock::new(|| { + HashMap::from([ + (Meaning::Guidance, Meaning::AlertInfo), + (Meaning::Annotation, Meaning::AlertInfo), + (Meaning::Title, Meaning::Important), + ]) +}); + +static DEFAULT_THEME: LazyLock = LazyLock::new(|| { + Theme::new( + "default".to_string(), + None, + HashMap::from([ + ( + Meaning::AlertError, + StyleFactory::from_fg_color(Color::DarkRed), + ), + ( + Meaning::AlertWarn, + StyleFactory::from_fg_color(Color::DarkYellow), + ), + ( + Meaning::AlertInfo, + StyleFactory::from_fg_color(Color::DarkGreen), + ), + ( + Meaning::Annotation, + StyleFactory::from_fg_color(Color::DarkGrey), + ), + ( + Meaning::Guidance, + StyleFactory::from_fg_color(Color::DarkBlue), + ), + ( + Meaning::Important, + StyleFactory::from_fg_color_and_attributes( + Color::White, + Attributes::from(Attribute::Bold), + ), + ), + (Meaning::Muted, StyleFactory::from_fg_color(Color::Grey)), + (Meaning::Base, ContentStyle::default()), + ]), + ) +}); + +static BUILTIN_THEMES: LazyLock> = LazyLock::new(|| { + HashMap::from([ + ("default", HashMap::new()), + ( + "(none)", + HashMap::from([ + (Meaning::AlertError, ContentStyle::default()), + (Meaning::AlertWarn, ContentStyle::default()), + (Meaning::AlertInfo, ContentStyle::default()), + (Meaning::Annotation, ContentStyle::default()), + (Meaning::Guidance, ContentStyle::default()), + (Meaning::Important, ContentStyle::default()), + (Meaning::Muted, ContentStyle::default()), + (Meaning::Base, ContentStyle::default()), + ]), + ), + ( + "autumn", + HashMap::from([ + ( + Meaning::AlertError, + StyleFactory::known_fg_string("saddlebrown"), + ), + ( + Meaning::AlertWarn, + StyleFactory::known_fg_string("darkorange"), + ), + (Meaning::AlertInfo, StyleFactory::known_fg_string("gold")), + ( + Meaning::Annotation, + StyleFactory::from_fg_color(Color::DarkGrey), + ), + (Meaning::Guidance, StyleFactory::known_fg_string("brown")), + ]), + ), + ( + "marine", + HashMap::from([ + ( + Meaning::AlertError, + StyleFactory::known_fg_string("yellowgreen"), + ), + (Meaning::AlertWarn, StyleFactory::known_fg_string("cyan")), + ( + Meaning::AlertInfo, + StyleFactory::known_fg_string("turquoise"), + ), + ( + Meaning::Annotation, + StyleFactory::known_fg_string("steelblue"), + ), + ( + Meaning::Base, + StyleFactory::known_fg_string("lightsteelblue"), + ), + (Meaning::Guidance, StyleFactory::known_fg_string("teal")), + ]), + ), + ]) + .iter() + .map(|(name, theme)| (*name, Theme::from_map(name.to_string(), None, theme))) + .collect() +}); + +// To avoid themes being repeatedly loaded, we store them in a theme manager +pub struct ThemeManager { + loaded_themes: HashMap, + debug: bool, + override_theme_dir: Option, +} + +// Theme-loading logic +impl ThemeManager { + pub fn new(debug: Option, theme_dir: Option) -> Self { + Self { + loaded_themes: HashMap::new(), + debug: debug.unwrap_or(false), + override_theme_dir: match theme_dir { + Some(theme_dir) => Some(theme_dir), + None => std::env::var("ATUIN_THEME_DIR").ok(), + }, + } + } + + // Try to load a theme from a `{name}.toml` file in the theme directory. If an override is set + // for the theme dir (via ATUIN_THEME_DIR env) we should load the theme from there + pub fn load_theme_from_file( + &mut self, + name: &str, + max_depth: u8, + ) -> Result<&Theme, Box> { + let mut theme_file = if let Some(p) = &self.override_theme_dir { + if p.is_empty() { + return Err(Box::new(Error::new( + ErrorKind::NotFound, + "Empty theme directory override and could not find theme elsewhere", + ))); + } + PathBuf::from(p) + } else { + let config_dir = crate::atuin_common::utils::config_dir(); + let mut theme_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { + PathBuf::from(p) + } else { + let mut theme_file = PathBuf::new(); + theme_file.push(config_dir); + theme_file + }; + theme_file.push("themes"); + theme_file + }; + + let theme_toml = format!("{name}.toml"); + theme_file.push(theme_toml); + + let mut config_builder = Config::builder(); + + config_builder = config_builder.add_source(ConfigFile::new( + theme_file.to_str().unwrap(), + FileFormat::Toml, + )); + + let config = config_builder.build()?; + self.load_theme_from_config(name, config, max_depth) + } + + pub fn load_theme_from_config( + &mut self, + name: &str, + config: Config, + max_depth: u8, + ) -> Result<&Theme, Box> { + let debug = self.debug; + let theme_config: ThemeConfig = match config.try_deserialize() { + Ok(tc) => tc, + Err(e) => { + return Err(Box::new(Error::new( + ErrorKind::InvalidInput, + format!( + "Failed to deserialize theme: {}", + if debug { + e.to_string() + } else { + "set theme debug on for more info".to_string() + } + ), + ))); + } + }; + let colors: HashMap = theme_config.colors; + let parent: Option<&Theme> = match theme_config.theme.parent { + Some(parent_name) => { + if max_depth == 0 { + return Err(Box::new(Error::new( + ErrorKind::InvalidInput, + "Parent requested but we hit the recursion limit", + ))); + } + Some(self.load_theme(parent_name.as_str(), Some(max_depth - 1))) + } + None => Some(self.load_theme("default", Some(max_depth - 1))), + }; + + if debug && name != theme_config.theme.name { + log::warn!( + "Your theme config name is not the name of your loaded theme {} != {}", + name, + theme_config.theme.name + ); + } + + let theme = Theme::from_foreground_colors(theme_config.theme.name, parent, colors, debug); + let name = name.to_string(); + self.loaded_themes.insert(name.clone(), theme); + let theme = self.loaded_themes.get(&name).unwrap(); + Ok(theme) + } + + // Check if the requested theme is loaded and, if not, then attempt to get it + // from the builtins or, if not there, from file + pub fn load_theme(&mut self, name: &str, max_depth: Option) -> &Theme { + if self.loaded_themes.contains_key(name) { + return self.loaded_themes.get(name).unwrap(); + } + let built_ins = &BUILTIN_THEMES; + match built_ins.get(name) { + Some(theme) => theme, + None => match self.load_theme_from_file(name, max_depth.unwrap_or(DEFAULT_MAX_DEPTH)) { + Ok(theme) => theme, + Err(err) => { + log::warn!("Could not load theme {name}: {err}"); + built_ins.get("(none)").unwrap() + } + }, + } + } +} + +#[cfg(test)] +mod theme_tests { + use super::*; + + #[test] + fn test_can_load_builtin_theme() { + let mut manager = ThemeManager::new(Some(false), Some("".to_string())); + let theme = manager.load_theme("autumn", None); + assert_eq!( + theme.as_style(Meaning::Guidance).foreground_color, + from_string("brown").ok() + ); + } + + #[test] + fn test_can_create_theme() { + let mut manager = ThemeManager::new(Some(false), Some("".to_string())); + let mytheme = Theme::new( + "mytheme".to_string(), + None, + HashMap::from([( + Meaning::AlertError, + StyleFactory::known_fg_string("yellowgreen"), + )]), + ); + manager.loaded_themes.insert("mytheme".to_string(), mytheme); + let theme = manager.load_theme("mytheme", None); + assert_eq!( + theme.as_style(Meaning::AlertError).foreground_color, + from_string("yellowgreen").ok() + ); + } + + #[test] + fn test_can_fallback_when_meaning_missing() { + let mut manager = ThemeManager::new(Some(false), Some("".to_string())); + + // We use title as an example of a meaning that is not defined + // even in the base theme. + assert!(!DEFAULT_THEME.styles.contains_key(&Meaning::Title)); + + let config = Config::builder() + .add_source(ConfigFile::from_str( + " + [theme] + name = \"title_theme\" + + [colors] + Guidance = \"white\" + AlertInfo = \"zomp\" + ", + FileFormat::Toml, + )) + .build() + .unwrap(); + let theme = manager + .load_theme_from_config("config_theme", config, 1) + .unwrap(); + + // Correctly picks overridden color. + assert_eq!( + theme.as_style(Meaning::Guidance).foreground_color, + from_string("white").ok() + ); + + // Does not fall back to any color. + assert_eq!(theme.as_style(Meaning::AlertInfo).foreground_color, None); + + // Even for the base. + assert_eq!(theme.as_style(Meaning::Base).foreground_color, None); + + // Falls back to red as meaning missing from theme, so picks base default. + assert_eq!( + theme.as_style(Meaning::AlertError).foreground_color, + Some(Color::DarkRed) + ); + + // Falls back to Important as Title not available. + assert_eq!( + theme.as_style(Meaning::Title).foreground_color, + theme.as_style(Meaning::Important).foreground_color, + ); + + let title_config = Config::builder() + .add_source(ConfigFile::from_str( + " + [theme] + name = \"title_theme\" + + [colors] + Title = \"white\" + AlertInfo = \"zomp\" + ", + FileFormat::Toml, + )) + .build() + .unwrap(); + let title_theme = manager + .load_theme_from_config("title_theme", title_config, 1) + .unwrap(); + + assert_eq!( + title_theme.as_style(Meaning::Title).foreground_color, + Some(Color::White) + ); + } + + #[test] + fn test_no_fallbacks_are_circular() { + let mytheme = Theme::new("mytheme".to_string(), None, HashMap::from([])); + MEANING_FALLBACKS + .iter() + .for_each(|pair| assert_eq!(mytheme.closest_meaning(pair.0), &Meaning::Base)) + } + + #[test] + fn test_can_get_colors_via_convenience_functions() { + let mut manager = ThemeManager::new(Some(true), Some("".to_string())); + let theme = manager.load_theme("default", None); + assert_eq!(theme.get_error().foreground_color.unwrap(), Color::DarkRed); + assert_eq!( + theme.get_warning().foreground_color.unwrap(), + Color::DarkYellow + ); + assert_eq!(theme.get_info().foreground_color.unwrap(), Color::DarkGreen); + assert_eq!(theme.get_base().foreground_color, None); + assert_eq!( + theme.get_alert(log::Level::Error).foreground_color.unwrap(), + Color::DarkRed + ) + } + + #[test] + fn test_can_use_parent_theme_for_fallbacks() { + testing_logger::setup(); + + let mut manager = ThemeManager::new(Some(false), Some("".to_string())); + + // First, we introduce a base theme + let solarized = Config::builder() + .add_source(ConfigFile::from_str( + " + [theme] + name = \"solarized\" + + [colors] + Guidance = \"white\" + AlertInfo = \"pink\" + ", + FileFormat::Toml, + )) + .build() + .unwrap(); + let solarized_theme = manager + .load_theme_from_config("solarized", solarized, 1) + .unwrap(); + + assert_eq!( + solarized_theme + .as_style(Meaning::AlertInfo) + .foreground_color, + from_string("pink").ok() + ); + + // Then we introduce a derived theme + let unsolarized = Config::builder() + .add_source(ConfigFile::from_str( + " + [theme] + name = \"unsolarized\" + parent = \"solarized\" + + [colors] + AlertInfo = \"red\" + ", + FileFormat::Toml, + )) + .build() + .unwrap(); + let unsolarized_theme = manager + .load_theme_from_config("unsolarized", unsolarized, 1) + .unwrap(); + + // It will take its own values + assert_eq!( + unsolarized_theme + .as_style(Meaning::AlertInfo) + .foreground_color, + from_string("red").ok() + ); + + // ...or fall back to the parent + assert_eq!( + unsolarized_theme + .as_style(Meaning::Guidance) + .foreground_color, + from_string("white").ok() + ); + + testing_logger::validate(|captured_logs| assert_eq!(captured_logs.len(), 0)); + + // If the parent is not found, we end up with the no theme colors or styling + // as this is considered a (soft) error state. + let nunsolarized = Config::builder() + .add_source(ConfigFile::from_str( + " + [theme] + name = \"nunsolarized\" + parent = \"nonsolarized\" + + [colors] + AlertInfo = \"red\" + ", + FileFormat::Toml, + )) + .build() + .unwrap(); + let nunsolarized_theme = manager + .load_theme_from_config("nunsolarized", nunsolarized, 1) + .unwrap(); + + assert_eq!( + nunsolarized_theme + .as_style(Meaning::Guidance) + .foreground_color, + None + ); + + testing_logger::validate(|captured_logs| { + assert_eq!(captured_logs.len(), 1); + assert_eq!( + captured_logs[0].body, + "Could not load theme nonsolarized: Empty theme directory override and could not find theme elsewhere" + ); + assert_eq!(captured_logs[0].level, log::Level::Warn) + }); + } + + #[test] + fn test_can_debug_theme() { + testing_logger::setup(); + [true, false].iter().for_each(|debug| { + let mut manager = ThemeManager::new(Some(*debug), Some("".to_string())); + let config = Config::builder() + .add_source(ConfigFile::from_str( + " + [theme] + name = \"mytheme\" + + [colors] + Guidance = \"white\" + AlertInfo = \"xinetic\" + ", + FileFormat::Toml, + )) + .build() + .unwrap(); + manager + .load_theme_from_config("config_theme", config, 1) + .unwrap(); + testing_logger::validate(|captured_logs| { + if *debug { + assert_eq!(captured_logs.len(), 2); + assert_eq!( + captured_logs[0].body, + "Your theme config name is not the name of your loaded theme config_theme != mytheme" + ); + assert_eq!(captured_logs[0].level, log::Level::Warn); + assert_eq!( + captured_logs[1].body, + "Tried to load string as a color unsuccessfully: (AlertInfo=xinetic) No such color in palette" + ); + assert_eq!(captured_logs[1].level, log::Level::Warn) + } else { + assert_eq!(captured_logs.len(), 0) + } + }) + }) + } + + #[test] + fn test_can_parse_color_strings_correctly() { + assert_eq!( + from_string("brown").unwrap(), + Color::Rgb { + r: 165, + g: 42, + b: 42 + } + ); + + assert_eq!(from_string(""), Err("Empty string".into())); + + ["manatee", "caput mortuum", "123456"] + .iter() + .for_each(|inp| { + assert_eq!(from_string(inp), Err("No such color in palette".into())); + }); + + assert_eq!( + from_string("#ff1122").unwrap(), + Color::Rgb { + r: 255, + g: 17, + b: 34 + } + ); + ["#1122", "#ffaa112", "#brown"].iter().for_each(|inp| { + assert_eq!( + from_string(inp), + Err("Could not parse 3 hex values from string".into()) + ); + }); + + assert_eq!(from_string("@dark_grey").unwrap(), Color::DarkGrey); + assert_eq!( + from_string("@rgb_(255,255,255)").unwrap(), + Color::Rgb { + r: 255, + g: 255, + b: 255 + } + ); + assert_eq!(from_string("@ansi_(255)").unwrap(), Color::AnsiValue(255)); + ["@", "@DarkGray", "@Dark 4ay", "@ansi(256)"] + .iter() + .for_each(|inp| { + assert_eq!( + from_string(inp), + Err(format!( + "Could not convert color name {inp} to Crossterm color" + )) + ); + }); + } +} diff --git a/crates/turtle/src/atuin_client/utils.rs b/crates/turtle/src/atuin_client/utils.rs new file mode 100644 index 00000000..35d7db26 --- /dev/null +++ b/crates/turtle/src/atuin_client/utils.rs @@ -0,0 +1,14 @@ +pub(crate) fn get_hostname() -> String { + std::env::var("ATUIN_HOST_NAME") + .unwrap_or_else(|_| whoami::hostname().unwrap_or_else(|_| "unknown-host".to_string())) +} + +pub(crate) fn get_username() -> String { + std::env::var("ATUIN_HOST_USER") + .unwrap_or_else(|_| whoami::username().unwrap_or_else(|_| "unknown-user".to_string())) +} + +/// Returns a pair of the hostname and username, separated by a colon. +pub(crate) fn get_host_user() -> String { + format!("{}:{}", get_hostname(), get_username()) +} diff --git a/crates/turtle/src/atuin_common/api.rs b/crates/turtle/src/atuin_common/api.rs new file mode 100644 index 00000000..1a9f348c --- /dev/null +++ b/crates/turtle/src/atuin_common/api.rs @@ -0,0 +1,144 @@ +use semver::Version; +use serde::{Deserialize, Serialize}; +use std::borrow::Cow; +use std::sync::LazyLock; +use time::OffsetDateTime; + +// the usage of X- has been deprecated for quite along time, it turns out +pub static ATUIN_HEADER_VERSION: &str = "Atuin-Version"; +pub static ATUIN_CARGO_VERSION: &str = env!("CARGO_PKG_VERSION"); + +pub static ATUIN_VERSION: LazyLock = + LazyLock::new(|| Version::parse(ATUIN_CARGO_VERSION).expect("failed to parse self semver")); + +#[derive(Debug, Serialize, Deserialize)] +pub struct UserResponse { + pub username: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct RegisterRequest { + pub email: String, + pub username: String, + pub password: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct RegisterResponse { + pub session: String, + /// Auth type: "hub" for Hub API tokens, "cli" for legacy CLI session tokens. + /// Old servers that don't return this field will deserialize as None. + #[serde(default)] + pub auth: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct DeleteUserResponse {} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ChangePasswordRequest { + pub current_password: String, + pub new_password: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ChangePasswordResponse {} + +#[derive(Debug, Serialize, Deserialize)] +pub struct LoginRequest { + pub username: String, + pub password: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct LoginResponse { + pub session: String, + /// Auth type: "hub" for Hub API tokens, "cli" for legacy CLI session tokens. + /// Old servers that don't return this field will deserialize as None. + #[serde(default)] + pub auth: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct AddHistoryRequest { + pub id: String, + #[serde(with = "time::serde::rfc3339")] + pub timestamp: OffsetDateTime, + pub data: String, + pub hostname: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CountResponse { + pub count: i64, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SyncHistoryRequest { + #[serde(with = "time::serde::rfc3339")] + pub sync_ts: OffsetDateTime, + #[serde(with = "time::serde::rfc3339")] + pub history_ts: OffsetDateTime, + pub host: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SyncHistoryResponse { + pub history: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ErrorResponse<'a> { + pub reason: Cow<'a, str>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct IndexResponse { + pub homage: String, + pub version: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct StatusResponse { + pub count: i64, + pub username: String, + pub deleted: Vec, + + // These could/should also go on the index of the server + // However, we do not request the server index as a part of normal sync + // I'd rather slightly increase the size of this response, than add an extra HTTP request + pub page_size: i64, // max page size supported by the server + pub version: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct DeleteHistoryRequest { + pub client_id: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct MessageResponse { + pub message: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct MeResponse { + pub username: String, +} + +// Hub CLI authentication types + +/// Response from POST /auth/cli/code - generates a code for CLI auth +#[derive(Debug, Serialize, Deserialize)] +pub struct CliCodeResponse { + pub code: String, +} + +/// Response from GET /auth/cli/verify?code= - polls for authorization +#[derive(Debug, Serialize, Deserialize)] +pub struct CliVerifyResponse { + /// Session token, present only when authorization is complete + pub token: Option, + pub success: Option, + pub error: Option, +} diff --git a/crates/turtle/src/atuin_common/calendar.rs b/crates/turtle/src/atuin_common/calendar.rs new file mode 100644 index 00000000..d3b1d921 --- /dev/null +++ b/crates/turtle/src/atuin_common/calendar.rs @@ -0,0 +1,16 @@ +// Calendar data +use serde::{Serialize, Deserialize}; + +pub enum TimePeriod { + YEAR, + MONTH, + DAY, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TimePeriodInfo { + pub count: u64, + + // TODO: Use this for merkle tree magic + pub hash: String, +} diff --git a/crates/turtle/src/atuin_common/mod.rs b/crates/turtle/src/atuin_common/mod.rs new file mode 100644 index 00000000..d886520d --- /dev/null +++ b/crates/turtle/src/atuin_common/mod.rs @@ -0,0 +1,58 @@ +/// Defines a new UUID type wrapper +macro_rules! new_uuid { + ($name:ident) => { + #[derive( + Debug, + Copy, + Clone, + PartialEq, + Eq, + Hash, + PartialOrd, + Ord, + serde::Serialize, + serde::Deserialize, + )] + #[serde(transparent)] + pub struct $name(pub Uuid); + + impl sqlx::Type for $name + where + Uuid: sqlx::Type, + { + fn type_info() -> ::TypeInfo { + Uuid::type_info() + } + } + + impl<'r, DB: sqlx::Database> sqlx::Decode<'r, DB> for $name + where + Uuid: sqlx::Decode<'r, DB>, + { + fn decode( + value: DB::ValueRef<'r>, + ) -> std::result::Result { + Uuid::decode(value).map(Self) + } + } + + impl<'q, DB: sqlx::Database> sqlx::Encode<'q, DB> for $name + where + Uuid: sqlx::Encode<'q, DB>, + { + fn encode_by_ref( + &self, + buf: &mut DB::ArgumentBuffer<'q>, + ) -> Result> + { + self.0.encode_by_ref(buf) + } + } + }; +} + +pub mod api; +pub mod record; +pub mod shell; +pub mod tls; +pub mod utils; diff --git a/crates/turtle/src/atuin_common/record.rs b/crates/turtle/src/atuin_common/record.rs new file mode 100644 index 00000000..05c29338 --- /dev/null +++ b/crates/turtle/src/atuin_common/record.rs @@ -0,0 +1,426 @@ +use std::collections::HashMap; + +use eyre::Result; +use serde::{Deserialize, Serialize}; +use typed_builder::TypedBuilder; +use uuid::Uuid; + +#[derive(Clone, Debug, PartialEq)] +pub struct DecryptedData(pub Vec); + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct EncryptedData { + pub data: String, + pub content_encryption_key: String, +} + +#[derive(Debug, PartialEq, PartialOrd, Ord, Eq)] +pub struct Diff { + pub host: HostId, + pub tag: String, + pub local: Option, + pub remote: Option, +} + +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] +pub struct Host { + pub id: HostId, + pub name: String, +} + +impl Host { + pub fn new(id: HostId) -> Self { + Host { + id, + name: String::new(), + } + } +} + +new_uuid!(RecordId); +new_uuid!(HostId); + +pub type RecordIdx = u64; + +/// A single record stored inside of our local database +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TypedBuilder)] +pub struct Record { + /// a unique ID + #[builder(default = RecordId(crate::atuin_common::utils::uuid_v7()))] + pub id: RecordId, + + /// The integer record ID. This is only unique per (host, tag). + pub idx: RecordIdx, + + /// The unique ID of the host. + // TODO(ellie): Optimize the storage here. We use a bunch of IDs, and currently store + // as strings. I would rather avoid normalization, so store as UUID binary instead of + // encoding to a string and wasting much more storage. + pub host: Host, + + /// The creation time in nanoseconds since unix epoch + #[builder(default = time::OffsetDateTime::now_utc().unix_timestamp_nanos() as u64)] + pub timestamp: u64, + + /// The version the data in the entry conforms to + // However we want to track versions for this tag, eg v2 + pub version: String, + + /// The type of data we are storing here. Eg, "history" + pub tag: String, + + /// Some data. This can be anything you wish to store. Use the tag field to know how to handle it. + pub data: Data, +} + +/// Extra data from the record that should be encoded in the data +#[derive(Debug, Copy, Clone)] +pub struct AdditionalData<'a> { + pub id: &'a RecordId, + pub idx: &'a u64, + pub version: &'a str, + pub tag: &'a str, + pub host: &'a HostId, +} + +impl Record { + pub fn append(&self, data: Vec) -> Record { + Record::builder() + .host(self.host.clone()) + .version(self.version.clone()) + .idx(self.idx + 1) + .tag(self.tag.clone()) + .data(DecryptedData(data)) + .build() + } +} + +/// An index representing the current state of the record stores +/// This can be both remote, or local, and compared in either direction +#[derive(Debug, Serialize, Deserialize)] +pub struct RecordStatus { + // A map of host -> tag -> max(idx) + pub hosts: HashMap>, +} + +impl Default for RecordStatus { + fn default() -> Self { + Self::new() + } +} + +impl Extend<(HostId, String, RecordIdx)> for RecordStatus { + fn extend>(&mut self, iter: T) { + for (host, tag, tail_idx) in iter { + self.set_raw(host, tag, tail_idx); + } + } +} + +impl RecordStatus { + pub fn new() -> RecordStatus { + RecordStatus { + hosts: HashMap::new(), + } + } + + /// Insert a new tail record into the store + pub fn set(&mut self, tail: Record) { + self.set_raw(tail.host.id, tail.tag, tail.idx) + } + + pub fn set_raw(&mut self, host: HostId, tag: String, tail_id: RecordIdx) { + self.hosts.entry(host).or_default().insert(tag, tail_id); + } + + pub fn get(&self, host: HostId, tag: String) -> Option { + self.hosts.get(&host).and_then(|v| v.get(&tag)).cloned() + } + + /// Diff this index with another, likely remote index. + /// The two diffs can then be reconciled, and the optimal change set calculated + /// Returns a tuple, with (host, tag, Option(OTHER)) + /// OTHER is set to the value of the idx on the other machine. If it is greater than our index, + /// then we need to do some downloading. If it is smaller, then we need to do some uploading + /// Note that we cannot upload if we are not the owner of the record store - hosts can only + /// write to their own store. + pub fn diff(&self, other: &Self) -> Vec { + let mut ret = Vec::new(); + + // First, we check if other has everything that self has + for (host, tag_map) in self.hosts.iter() { + for (tag, idx) in tag_map.iter() { + match other.get(*host, tag.clone()) { + // The other store is all up to date! No diff. + Some(t) if t.eq(idx) => continue, + + // The other store does exist, and it is either ahead or behind us. A diff regardless + Some(t) => ret.push(Diff { + host: *host, + tag: tag.clone(), + local: Some(*idx), + remote: Some(t), + }), + + // The other store does not exist :O + None => ret.push(Diff { + host: *host, + tag: tag.clone(), + local: Some(*idx), + remote: None, + }), + }; + } + } + + // At this point, there is a single case we have not yet considered. + // If the other store knows of a tag that we are not yet aware of, then the diff will be missed + + // account for that! + for (host, tag_map) in other.hosts.iter() { + for (tag, idx) in tag_map.iter() { + match self.get(*host, tag.clone()) { + // If we have this host/tag combo, the comparison and diff will have already happened above + Some(_) => continue, + + None => ret.push(Diff { + host: *host, + tag: tag.clone(), + remote: Some(*idx), + local: None, + }), + }; + } + } + + // Stability is a nice property to have + ret.sort(); + ret + } +} + +pub trait Encryption { + fn re_encrypt( + data: EncryptedData, + ad: AdditionalData, + old_key: &[u8; 32], + new_key: &[u8; 32], + ) -> Result { + let data = Self::decrypt(data, ad, old_key)?; + Ok(Self::encrypt(data, ad, new_key)) + } + fn encrypt(data: DecryptedData, ad: AdditionalData, key: &[u8; 32]) -> EncryptedData; + fn decrypt(data: EncryptedData, ad: AdditionalData, key: &[u8; 32]) -> Result; +} + +impl Record { + pub fn encrypt(self, key: &[u8; 32]) -> Record { + let ad = AdditionalData { + id: &self.id, + version: &self.version, + tag: &self.tag, + host: &self.host.id, + idx: &self.idx, + }; + Record { + data: E::encrypt(self.data, ad, key), + id: self.id, + host: self.host, + idx: self.idx, + timestamp: self.timestamp, + version: self.version, + tag: self.tag, + } + } +} + +impl Record { + pub fn decrypt(self, key: &[u8; 32]) -> Result> { + let ad = AdditionalData { + id: &self.id, + version: &self.version, + tag: &self.tag, + host: &self.host.id, + idx: &self.idx, + }; + Ok(Record { + data: E::decrypt(self.data, ad, key)?, + id: self.id, + host: self.host, + idx: self.idx, + timestamp: self.timestamp, + version: self.version, + tag: self.tag, + }) + } + + pub fn re_encrypt( + self, + old_key: &[u8; 32], + new_key: &[u8; 32], + ) -> Result> { + let ad = AdditionalData { + id: &self.id, + version: &self.version, + tag: &self.tag, + host: &self.host.id, + idx: &self.idx, + }; + Ok(Record { + data: E::re_encrypt(self.data, ad, old_key, new_key)?, + id: self.id, + host: self.host, + idx: self.idx, + timestamp: self.timestamp, + version: self.version, + tag: self.tag, + }) + } +} + +#[cfg(test)] +mod tests { + use crate::atuin_common::record::{Host, HostId}; + + use super::{DecryptedData, Diff, Record, RecordStatus}; + use pretty_assertions::assert_eq; + + fn test_record() -> Record { + Record::builder() + .host(Host::new(HostId(crate::atuin_common::utils::uuid_v7()))) + .version("v1".into()) + .tag(crate::atuin_common::utils::uuid_v7().simple().to_string()) + .data(DecryptedData(vec![0, 1, 2, 3])) + .idx(0) + .build() + } + + #[test] + fn record_index() { + let mut index = RecordStatus::new(); + let record = test_record(); + + index.set(record.clone()); + + let tail = index.get(record.host.id, record.tag); + + assert_eq!( + record.idx, + tail.expect("tail not in store"), + "tail in store did not match" + ); + } + + #[test] + fn record_index_overwrite() { + let mut index = RecordStatus::new(); + let record = test_record(); + let child = record.append(vec![1, 2, 3]); + + index.set(record.clone()); + index.set(child.clone()); + + let tail = index.get(record.host.id, record.tag); + + assert_eq!( + child.idx, + tail.expect("tail not in store"), + "tail in store did not match" + ); + } + + #[test] + fn record_index_no_diff() { + // Here, they both have the same version and should have no diff + + let mut index1 = RecordStatus::new(); + let mut index2 = RecordStatus::new(); + + let record1 = test_record(); + + index1.set(record1.clone()); + index2.set(record1); + + let diff = index1.diff(&index2); + + assert_eq!(0, diff.len(), "expected empty diff"); + } + + #[test] + fn record_index_single_diff() { + // Here, they both have the same stores, but one is ahead by a single record + + let mut index1 = RecordStatus::new(); + let mut index2 = RecordStatus::new(); + + let record1 = test_record(); + let record2 = record1.append(vec![1, 2, 3]); + + index1.set(record1); + index2.set(record2.clone()); + + let diff = index1.diff(&index2); + + assert_eq!(1, diff.len(), "expected single diff"); + assert_eq!( + diff[0], + Diff { + host: record2.host.id, + tag: record2.tag, + remote: Some(1), + local: Some(0) + } + ); + } + + #[test] + fn record_index_multi_diff() { + // A much more complex case, with a bunch more checks + let mut index1 = RecordStatus::new(); + let mut index2 = RecordStatus::new(); + + let store1record1 = test_record(); + let store1record2 = store1record1.append(vec![1, 2, 3]); + + let store2record1 = test_record(); + let store2record2 = store2record1.append(vec![1, 2, 3]); + + let store3record1 = test_record(); + + let store4record1 = test_record(); + + // index1 only knows about the first two entries of the first two stores + index1.set(store1record1); + index1.set(store2record1); + + // index2 is fully up to date with the first two stores, and knows of a third + index2.set(store1record2); + index2.set(store2record2); + index2.set(store3record1); + + // index1 knows of a 4th store + index1.set(store4record1); + + let diff1 = index1.diff(&index2); + let diff2 = index2.diff(&index1); + + // both diffs the same length + assert_eq!(4, diff1.len()); + assert_eq!(4, diff2.len()); + + dbg!(&diff1, &diff2); + + // both diffs should be ALMOST the same. They will agree on which hosts and tags + // require updating, but the "other" value will not be the same. + let smol_diff_1: Vec<(HostId, String)> = + diff1.iter().map(|v| (v.host, v.tag.clone())).collect(); + let smol_diff_2: Vec<(HostId, String)> = + diff1.iter().map(|v| (v.host, v.tag.clone())).collect(); + + assert_eq!(smol_diff_1, smol_diff_2); + + // diffing with yourself = no diff + assert_eq!(index1.diff(&index1).len(), 0); + assert_eq!(index2.diff(&index2).len(), 0); + } +} diff --git a/crates/turtle/src/atuin_common/shell.rs b/crates/turtle/src/atuin_common/shell.rs new file mode 100644 index 00000000..7f9a7b8f --- /dev/null +++ b/crates/turtle/src/atuin_common/shell.rs @@ -0,0 +1,183 @@ +use std::{ffi::OsStr, path::Path, process::Command}; + +use serde::Serialize; +use sysinfo::{Process, System, get_current_pid}; +use thiserror::Error; + +#[derive(PartialEq)] +pub enum Shell { + Sh, + Bash, + Fish, + Zsh, + Xonsh, + Nu, + Powershell, + + Unknown, +} + +impl std::fmt::Display for Shell { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let shell = match self { + Shell::Bash => "bash", + Shell::Fish => "fish", + Shell::Zsh => "zsh", + Shell::Nu => "nu", + Shell::Xonsh => "xonsh", + Shell::Sh => "sh", + Shell::Powershell => "powershell", + + Shell::Unknown => "unknown", + }; + + write!(f, "{shell}") + } +} + +#[derive(Debug, Error, Serialize)] +pub enum ShellError { + #[error("shell not supported")] + NotSupported, + + #[error("failed to execute shell command: {0}")] + ExecError(String), +} + +impl Shell { + pub fn current() -> Shell { + let sys = System::new_all(); + + let process = sys + .process(get_current_pid().expect("Failed to get current PID")) + .expect("Process with current pid does not exist"); + + let parent = sys + .process(process.parent().expect("Atuin running with no parent!")) + .expect("Process with parent pid does not exist"); + + let shell = parent.name().trim().to_lowercase(); + let shell = shell.strip_prefix('-').unwrap_or(&shell); + + Shell::from_string(shell.to_string()) + } + + pub fn from_env() -> Shell { + std::env::var("ATUIN_SHELL").map_or(Shell::Unknown, |shell| { + Shell::from_string(shell.trim().to_lowercase()) + }) + } + + pub fn config_file(&self) -> Option { + let mut path = if let Some(base) = directories::BaseDirs::new() { + base.home_dir().to_owned() + } else { + return None; + }; + + // TODO: handle all shells + match self { + Shell::Bash => path.push(".bashrc"), + Shell::Zsh => path.push(".zshrc"), + Shell::Fish => path.push(".config/fish/config.fish"), + + _ => return None, + }; + + Some(path) + } + + /// Best-effort attempt to determine the default shell + /// This implementation will be different across different platforms + /// Caller should ensure to handle Shell::Unknown correctly + pub fn default_shell() -> Result { + let sys = System::name().unwrap_or("".to_string()).to_lowercase(); + + // TODO: Support Linux + // I'm pretty sure we can use /etc/passwd there, though there will probably be some issues + let path = if sys.contains("darwin") { + // This works in my testing so far + Shell::Sh.run_interactive([ + "dscl localhost -read \"/Local/Default/Users/$USER\" shell | awk '{print $2}'", + ])? + } else if cfg!(windows) { + return Ok(Shell::Powershell); + } else { + Shell::Sh.run_interactive(["getent passwd $LOGNAME | cut -d: -f7"])? + }; + + let path = Path::new(path.trim()); + let shell = path.file_name(); + + if shell.is_none() { + return Err(ShellError::NotSupported); + } + + Ok(Shell::from_string( + shell.unwrap().to_string_lossy().to_string(), + )) + } + + pub fn from_string(name: String) -> Shell { + match name.as_str() { + "bash" => Shell::Bash, + "fish" => Shell::Fish, + "zsh" => Shell::Zsh, + "xonsh" => Shell::Xonsh, + "nu" => Shell::Nu, + "sh" => Shell::Sh, + "powershell" => Shell::Powershell, + + _ => Shell::Unknown, + } + } + + /// Returns true if the shell is posix-like + /// Note that while fish is not posix compliant, it behaves well enough for our current + /// featureset that this does not matter. + pub fn is_posixish(&self) -> bool { + matches!(self, Shell::Bash | Shell::Fish | Shell::Zsh) + } + + pub fn run_interactive(&self, args: I) -> Result + where + I: IntoIterator, + S: AsRef, + { + let shell = self.to_string(); + let output = if self == &Self::Powershell { + Command::new(shell) + .args(args) + .output() + .map_err(|e| ShellError::ExecError(e.to_string()))? + } else { + Command::new(shell) + .arg("-ic") + .args(args) + .output() + .map_err(|e| ShellError::ExecError(e.to_string()))? + }; + + Ok(String::from_utf8(output.stdout).unwrap()) + } +} + +pub fn shell_name(parent: Option<&Process>) -> String { + let sys = System::new_all(); + + let parent = if let Some(parent) = parent { + parent + } else { + let process = sys + .process(get_current_pid().expect("Failed to get current PID")) + .expect("Process with current pid does not exist"); + + sys.process(process.parent().expect("Atuin running with no parent!")) + .expect("Process with parent pid does not exist") + }; + + let shell = parent.name().trim().to_lowercase(); + let shell = shell.strip_prefix('-').unwrap_or(&shell); + + shell.to_string() +} diff --git a/crates/turtle/src/atuin_common/tls.rs b/crates/turtle/src/atuin_common/tls.rs new file mode 100644 index 00000000..e8c840e0 --- /dev/null +++ b/crates/turtle/src/atuin_common/tls.rs @@ -0,0 +1,15 @@ +use std::sync::Once; + +static INIT: Once = Once::new(); + +/// Ensure the rustls crypto provider (ring) is installed. +/// +/// Must be called before creating any reqwest clients. Safe to call +/// multiple times — only the first call installs the provider. +pub fn ensure_crypto_provider() { + INIT.call_once(|| { + rustls::crypto::ring::default_provider() + .install_default() + .expect("Failed to install rustls crypto provider"); + }); +} diff --git a/crates/turtle/src/atuin_common/utils.rs b/crates/turtle/src/atuin_common/utils.rs new file mode 100644 index 00000000..d7382fb2 --- /dev/null +++ b/crates/turtle/src/atuin_common/utils.rs @@ -0,0 +1,383 @@ +use std::borrow::Cow; +use std::env; +use std::path::{Path, PathBuf}; + +use eyre::{Result, eyre}; + +use base64::prelude::{BASE64_URL_SAFE_NO_PAD, Engine}; +use getrandom::getrandom; +use uuid::Uuid; + +/// Generate N random bytes, using a cryptographically secure source +pub fn crypto_random_bytes() -> [u8; N] { + // rand say they are in principle safe for crypto purposes, but that it is perhaps a better + // idea to use getrandom for things such as passwords. + let mut ret = [0u8; N]; + + getrandom(&mut ret).expect("Failed to generate random bytes!"); + + ret +} + +/// Generate N random bytes using a cryptographically secure source, return encoded as a string +pub fn crypto_random_string() -> String { + let bytes = crypto_random_bytes::(); + + // We only use this to create a random string, and won't be reversing it to find the original + // data - no padding is OK there. It may be in URLs. + BASE64_URL_SAFE_NO_PAD.encode(bytes) +} + +pub fn uuid_v7() -> Uuid { + Uuid::now_v7() +} + +pub fn uuid_v4() -> String { + Uuid::new_v4().as_simple().to_string() +} + +pub fn has_git_dir(path: &str) -> bool { + let mut gitdir = PathBuf::from(path); + gitdir.push(".git"); + + gitdir.exists() +} + +// in a git worktree, .git is a file containing "gitdir: " pointing +// to the main repo's .git/worktrees/ directory. follow the pointer +// back to the main repo root so all worktrees share a workspace. +fn resolve_git_worktree(path: &Path) -> Option { + let git_path = path.join(".git"); + + if !git_path.is_file() { + return None; + } + + let contents = std::fs::read_to_string(&git_path).ok()?; + let gitdir_str = contents.strip_prefix("gitdir: ")?.trim(); + + let gitdir = PathBuf::from(gitdir_str); + let gitdir = if gitdir.is_absolute() { + gitdir + } else { + path.join(gitdir_str) + }; + + // walk up from e.g. /repo/.git/worktrees/feature to find /repo + let mut candidate = gitdir.as_path(); + while let Some(parent) = candidate.parent() { + if parent.join(".git").is_dir() { + return Some(parent.to_path_buf()); + } + candidate = parent; + } + + None +} + +// detect if any parent dir has a git repo in it +// I really don't want to bring in libgit for something simple like this +// If we start to do anything more advanced, then perhaps +pub fn in_git_repo(path: &str) -> Option { + let mut gitdir = PathBuf::from(path); + + while gitdir.parent().is_some() && !has_git_dir(gitdir.to_str().unwrap()) { + gitdir.pop(); + } + + // No parent? then we hit root, finding no git + if gitdir.parent().is_some() { + // if .git is a file (worktree), resolve to the main repo root + if let Some(main_repo) = resolve_git_worktree(&gitdir) { + return Some(main_repo); + } + return Some(gitdir); + } + + None +} + +// TODO: more reliable, more tested +// I don't want to use ProjectDirs, it puts config in awkward places on +// mac. Data too. Seems to be more intended for GUI apps. + +pub fn home_dir() -> PathBuf { + directories::BaseDirs::new() + .map(|d| d.home_dir().to_path_buf()) + .expect("could not determine home directory") +} + +pub fn config_dir() -> PathBuf { + let config_dir = + std::env::var("XDG_CONFIG_HOME").map_or_else(|_| home_dir().join(".config"), PathBuf::from); + config_dir.join("atuin") +} + +pub fn data_dir() -> PathBuf { + let data_dir = std::env::var("XDG_DATA_HOME") + .map_or_else(|_| home_dir().join(".local").join("share"), PathBuf::from); + + data_dir.join("atuin") +} + +pub fn runtime_dir() -> PathBuf { + std::env::var("XDG_RUNTIME_DIR").map_or_else(|_| data_dir(), PathBuf::from) +} + +pub fn logs_dir() -> PathBuf { + home_dir().join(".atuin").join("logs") +} + +pub fn dotfiles_cache_dir() -> PathBuf { + // In most cases, this will be ~/.local/share/atuin/dotfiles/cache + let data_dir = std::env::var("XDG_DATA_HOME") + .map_or_else(|_| home_dir().join(".local").join("share"), PathBuf::from); + + data_dir.join("atuin").join("dotfiles").join("cache") +} + +pub fn get_current_dir() -> String { + // Prefer PWD environment variable over cwd if available to better support symbolic links + match env::var("PWD") { + Ok(v) => v, + Err(_) => match env::current_dir() { + Ok(dir) => dir.display().to_string(), + Err(_) => String::from(""), + }, + } +} + +pub fn broken_symlink>(path: P) -> bool { + let path = path.into(); + path.is_symlink() && !path.exists() +} + +/// Extension trait for anything that can behave like a string to make it easy to escape control +/// characters. +/// +/// Intended to help prevent control characters being printed and interpreted by the terminal when +/// printing history as well as to ensure the commands that appear in the interactive search +/// reflect the actual command run rather than just the printable characters. +pub trait Escapable: AsRef { + fn escape_control(&self) -> Cow<'_, str> { + if !self.as_ref().contains(|c: char| c.is_ascii_control()) { + self.as_ref().into() + } else { + let mut remaining = self.as_ref(); + // Not a perfect way to reserve space but should reduce the allocations + let mut buf = String::with_capacity(remaining.len()); + while let Some(i) = remaining.find(|c: char| c.is_ascii_control()) { + // safe to index with `..i`, `i` and `i+1..` as part[i] is a single byte ascii char + buf.push_str(&remaining[..i]); + buf.push('^'); + buf.push(match remaining.as_bytes()[i] { + 0x7F => '?', + code => char::from_u32(u32::from(code) + 64).unwrap(), + }); + remaining = &remaining[i + 1..]; + } + buf.push_str(remaining); + buf.into() + } + } +} + +pub fn unquote(s: &str) -> Result { + if s.chars().count() < 2 { + return Err(eyre!("not enough chars")); + } + + let quote = s.chars().next().unwrap(); + + // not quoted, do nothing + if quote != '"' && quote != '\'' && quote != '`' { + return Ok(s.to_string()); + } + + if s.chars().last().unwrap() != quote { + return Err(eyre!("unexpected eof, quotes do not match")); + } + + // removes quote characters + // the sanity checks performed above ensure that the quotes will be ASCII and this will not + // panic + let s = &s[1..s.len() - 1]; + + Ok(s.to_string()) +} + +impl> Escapable for T {} + +#[expect(unsafe_code)] +#[cfg(test)] +mod tests { + use pretty_assertions::assert_ne; + + use super::*; + + use std::collections::HashSet; + + #[cfg(not(windows))] + #[test] + fn test_dirs() { + // these tests need to be run sequentially to prevent race condition + test_config_dir_xdg(); + test_config_dir(); + test_data_dir_xdg(); + test_data_dir(); + } + + #[cfg(not(windows))] + fn test_config_dir_xdg() { + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::remove_var("HOME") }; + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::set_var("XDG_CONFIG_HOME", "/home/user/custom_config") }; + assert_eq!( + config_dir(), + PathBuf::from("/home/user/custom_config/atuin") + ); + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::remove_var("XDG_CONFIG_HOME") }; + } + + #[cfg(not(windows))] + fn test_config_dir() { + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::set_var("HOME", "/home/user") }; + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::remove_var("XDG_CONFIG_HOME") }; + + assert_eq!(config_dir(), PathBuf::from("/home/user/.config/atuin")); + + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::remove_var("HOME") }; + } + + #[cfg(not(windows))] + fn test_data_dir_xdg() { + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::remove_var("HOME") }; + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::set_var("XDG_DATA_HOME", "/home/user/custom_data") }; + assert_eq!(data_dir(), PathBuf::from("/home/user/custom_data/atuin")); + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::remove_var("XDG_DATA_HOME") }; + } + + #[cfg(not(windows))] + fn test_data_dir() { + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::set_var("HOME", "/home/user") }; + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::remove_var("XDG_DATA_HOME") }; + assert_eq!(data_dir(), PathBuf::from("/home/user/.local/share/atuin")); + // TODO: Audit that the environment access only happens in single-threaded code. + unsafe { env::remove_var("HOME") }; + } + + #[test] + fn uuid_is_unique() { + let how_many: usize = 1000000; + + // for peace of mind + let mut uuids: HashSet = HashSet::with_capacity(how_many); + + // there will be many in the same millisecond + for _ in 0..how_many { + let uuid = uuid_v7(); + uuids.insert(uuid); + } + + assert_eq!(uuids.len(), how_many); + } + + #[test] + fn escape_control_characters() { + use super::Escapable; + // CSI colour sequence + assert_eq!("\x1b[31mfoo".escape_control(), "^[[31mfoo"); + + // Tabs count as control chars + assert_eq!("foo\tbar".escape_control(), "foo^Ibar"); + + // space is in control char range but should be excluded + assert_eq!("two words".escape_control(), "two words"); + + // unicode multi-byte characters + let s = "🐢\x1b[32m🦀"; + assert_eq!(s.escape_control(), s.replace("\x1b", "^[")); + } + + #[test] + fn escape_no_control_characters() { + use super::Escapable as _; + assert!(matches!( + "no control characters".escape_control(), + Cow::Borrowed(_) + )); + assert!(matches!( + "with \x1b[31mcontrol\x1b[0m characters".escape_control(), + Cow::Owned(_) + )); + } + + #[cfg(not(windows))] + #[test] + fn in_git_repo_regular() { + // regular git repo should resolve to the directory containing .git + let tmp = std::env::temp_dir().join("atuin-test-regular-git"); + let _ = std::fs::remove_dir_all(&tmp); + let subdir = tmp.join("src").join("deep"); + std::fs::create_dir_all(&subdir).unwrap(); + std::fs::create_dir_all(tmp.join(".git")).unwrap(); + + let result = in_git_repo(subdir.to_str().unwrap()); + assert_eq!(result, Some(tmp.clone())); + + std::fs::remove_dir_all(&tmp).unwrap(); + } + + #[cfg(not(windows))] + #[test] + fn in_git_repo_worktree_resolves_to_main_repo() { + // worktree .git is a file pointing back to the main repo — + // in_git_repo should follow it so all worktrees share a workspace + let tmp = std::env::temp_dir().join("atuin-test-worktree-git"); + let _ = std::fs::remove_dir_all(&tmp); + + // main repo at tmp/main with a real .git directory + let main_repo = tmp.join("main"); + let worktree_git_dir = main_repo.join(".git").join("worktrees").join("feature"); + std::fs::create_dir_all(&worktree_git_dir).unwrap(); + + // worktree at tmp/worktree with a .git file + let worktree = tmp.join("worktree"); + let worktree_subdir = worktree.join("src"); + std::fs::create_dir_all(&worktree_subdir).unwrap(); + std::fs::write( + worktree.join(".git"), + format!("gitdir: {}", worktree_git_dir.to_str().unwrap()), + ) + .unwrap(); + + // should resolve to the main repo root, not the worktree root + let result = in_git_repo(worktree_subdir.to_str().unwrap()); + assert_eq!(result, Some(main_repo.clone())); + + std::fs::remove_dir_all(&tmp).unwrap(); + } + + #[test] + fn dumb_random_test() { + // Obviously not a test of randomness, but make sure we haven't made some + // catastrophic error + + assert_ne!(crypto_random_string::<1>(), crypto_random_string::<1>()); + assert_ne!(crypto_random_string::<2>(), crypto_random_string::<2>()); + assert_ne!(crypto_random_string::<4>(), crypto_random_string::<4>()); + assert_ne!(crypto_random_string::<8>(), crypto_random_string::<8>()); + assert_ne!(crypto_random_string::<16>(), crypto_random_string::<16>()); + assert_ne!(crypto_random_string::<32>(), crypto_random_string::<32>()); + } +} diff --git a/crates/turtle/src/atuin_daemon/client.rs b/crates/turtle/src/atuin_daemon/client.rs new file mode 100644 index 00000000..45ef19e9 --- /dev/null +++ b/crates/turtle/src/atuin_daemon/client.rs @@ -0,0 +1,418 @@ +use crate::atuin_client::database::Context; +use crate::atuin_client::settings::{FilterMode, Settings}; +use eyre::{Context as EyreContext, Result}; +use tonic::Code; +use tonic::transport::{Channel, Endpoint, Uri}; +use tower::service_fn; + +use hyper_util::rt::TokioIo; + +#[cfg(unix)] +use tokio::net::UnixStream; + +use crate::atuin_client::history::History; +use tracing::{Level, instrument, span}; + +use crate::atuin_daemon::control::HistoryRebuiltEvent; +use crate::atuin_daemon::control::{ + ForceSyncEvent, HistoryDeletedEvent, HistoryPrunedEvent, SendEventRequest, + SettingsReloadedEvent, ShutdownEvent, control_client::ControlClient as ControlServiceClient, +}; +use crate::atuin_daemon::events::DaemonEvent; +use crate::atuin_daemon::history::{ + EndHistoryReply, EndHistoryRequest, ShutdownRequest, StartHistoryReply, StartHistoryRequest, + StatusReply, StatusRequest, TailHistoryReply, TailHistoryRequest, + history_client::HistoryClient as HistoryServiceClient, +}; +use crate::atuin_daemon::search::{ + FilterMode as RpcFilterMode, SearchContext as RpcSearchContext, SearchRequest, SearchResponse, + search_client::SearchClient as SearchServiceClient, +}; +use crate::atuin_daemon::semantic::{ + CommandCapture, CommandOutputReply, CommandOutputRequest, OutputRange, RecordCommandsReply, + semantic_client::SemanticClient as SemanticServiceClient, +}; + +pub struct HistoryClient { + client: HistoryServiceClient, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum DaemonClientErrorKind { + Connect, + Unavailable, + Unimplemented, + Other, +} + +#[must_use] +pub fn classify_error(error: &eyre::Report) -> DaemonClientErrorKind { + for cause in error.chain() { + if cause.downcast_ref::().is_some() { + return DaemonClientErrorKind::Connect; + } + + if let Some(status) = cause.downcast_ref::() { + return match status.code() { + Code::Unavailable => DaemonClientErrorKind::Unavailable, + Code::Unimplemented => DaemonClientErrorKind::Unimplemented, + _ => DaemonClientErrorKind::Other, + }; + } + } + + DaemonClientErrorKind::Other +} + +// Wrap the grpc client +impl HistoryClient { + #[cfg(unix)] + pub async fn new(path: String) -> Result { + use eyre::Context; + + let log_path = path.clone(); + let channel = Endpoint::try_from("http://atuin_local_daemon:0")? + .connect_with_connector(service_fn(move |_: Uri| { + let path = path.clone(); + + async move { + Ok::<_, std::io::Error>(TokioIo::new(UnixStream::connect(path.clone()).await?)) + } + })) + .await + .wrap_err_with(|| { + format!( + "failed to connect to local atuin daemon at {}. Is it running?", + &log_path + ) + })?; + + let client = HistoryServiceClient::new(channel); + + Ok(HistoryClient { client }) + } + + pub async fn start_history(&mut self, h: History) -> Result { + let req = StartHistoryRequest { + command: h.command, + cwd: h.cwd, + hostname: h.hostname, + session: h.session, + timestamp: h.timestamp.unix_timestamp_nanos() as u64, + author: h.author, + intent: h.intent.unwrap_or_default(), + }; + + Ok(self.client.start_history(req).await?.into_inner()) + } + + pub async fn end_history( + &mut self, + id: String, + duration: u64, + exit: i64, + ) -> Result { + let req = EndHistoryRequest { id, duration, exit }; + + Ok(self.client.end_history(req).await?.into_inner()) + } + + pub async fn status(&mut self) -> Result { + Ok(self.client.status(StatusRequest {}).await?.into_inner()) + } + + pub async fn tail_history(&mut self) -> Result> { + Ok(self + .client + .tail_history(TailHistoryRequest {}) + .await? + .into_inner()) + } + + pub async fn shutdown(&mut self) -> Result { + let resp = self.client.shutdown(ShutdownRequest {}).await?.into_inner(); + Ok(resp.accepted) + } +} + +pub struct SearchClient { + client: SearchServiceClient, +} + +impl SearchClient { + #[cfg(unix)] + pub async fn new(path: String) -> Result { + let log_path = path.clone(); + let channel = Endpoint::try_from("http://atuin_local_daemon:0")? + .connect_with_connector(service_fn(move |_: Uri| { + let path = path.clone(); + + async move { + Ok::<_, std::io::Error>(TokioIo::new(UnixStream::connect(path.clone()).await?)) + } + })) + .await + .wrap_err_with(|| { + format!( + "failed to connect to local atuin daemon at {}. Is it running?", + &log_path + ) + })?; + + let client = SearchServiceClient::new(channel); + + Ok(SearchClient { client }) + } + + #[instrument(skip_all, level = Level::TRACE, name = "daemon_client_search", fields(query = %query, query_id = query_id))] + pub async fn search( + &mut self, + query: String, + query_id: u64, + filter_mode: FilterMode, + context: Option, + ) -> Result> { + let request = SearchRequest { + query, + query_id, + filter_mode: RpcFilterMode::from(filter_mode).into(), + context: context.map(RpcSearchContext::from), + }; + let request_stream = tokio_stream::once(request); + let response = span!(Level::TRACE, "daemon_client_search.request") + .in_scope(async || self.client.search(request_stream).await) + .await?; + + Ok(response.into_inner()) + } +} + +impl From for RpcFilterMode { + fn from(filter_mode: FilterMode) -> Self { + match filter_mode { + FilterMode::Global => RpcFilterMode::Global, + FilterMode::Host => RpcFilterMode::Host, + FilterMode::Session => RpcFilterMode::Session, + FilterMode::Directory => RpcFilterMode::Directory, + FilterMode::Workspace => RpcFilterMode::Workspace, + FilterMode::SessionPreload => RpcFilterMode::SessionPreload, + } + } +} + +impl From for RpcSearchContext { + fn from(context: Context) -> Self { + RpcSearchContext { + session_id: context.session, + cwd: context.cwd, + hostname: context.hostname, + host_id: context.host_id, + git_root: context + .git_root + .map(|path| path.to_string_lossy().to_string()), + } + } +} + +pub struct SemanticClient { + client: SemanticServiceClient, +} + +impl SemanticClient { + #[cfg(unix)] + pub async fn new(path: String) -> Result { + let log_path = path.clone(); + let channel = Endpoint::try_from("http://atuin_local_daemon:0")? + .connect_with_connector(service_fn(move |_: Uri| { + let path = path.clone(); + + async move { + Ok::<_, std::io::Error>(TokioIo::new(UnixStream::connect(path.clone()).await?)) + } + })) + .await + .wrap_err_with(|| { + format!( + "failed to connect to local atuin daemon at {}. Is it running?", + &log_path + ) + })?; + + let client = SemanticServiceClient::new(channel); + + Ok(SemanticClient { client }) + } + + #[cfg(unix)] + pub async fn from_settings(settings: &Settings) -> Result { + Self::new(settings.daemon.socket_path.clone()).await + } + + pub async fn record_commands( + &mut self, + captures: Vec, + ) -> Result { + let stream = tokio_stream::iter(captures); + Ok(self.client.record_commands(stream).await?.into_inner()) + } + + pub async fn command_output( + &mut self, + history_id: String, + ranges: Vec<(i64, i64)>, + ) -> Result { + let request = CommandOutputRequest { + history_id, + ranges: ranges + .into_iter() + .map(|(start, end)| OutputRange { start, end }) + .collect(), + }; + + Ok(self.client.command_output(request).await?.into_inner()) + } +} + +// ============================================================================ +// Control Client +// ============================================================================ + +/// Client for the Control gRPC service. +/// +/// Used to inject events into a running daemon from external processes. +pub struct ControlClient { + client: ControlServiceClient, +} + +impl ControlClient { + /// Connect to the daemon's control service. + #[cfg(unix)] + pub async fn new(path: String) -> Result { + let log_path = path.clone(); + let channel = Endpoint::try_from("http://atuin_local_daemon:0")? + .connect_with_connector(service_fn(move |_: Uri| { + let path = path.clone(); + + async move { + Ok::<_, std::io::Error>(TokioIo::new(UnixStream::connect(path.clone()).await?)) + } + })) + .await + .wrap_err_with(|| { + format!( + "failed to connect to local atuin daemon at {}. Is it running?", + &log_path + ) + })?; + + let client = ControlServiceClient::new(channel); + + Ok(ControlClient { client }) + } + + /// Connect using settings. + #[cfg(unix)] + pub async fn from_settings(settings: &Settings) -> Result { + Self::new(settings.daemon.socket_path.clone()).await + } + + /// Send an event to the daemon. + pub async fn send_event(&mut self, event: DaemonEvent) -> Result<()> { + let proto_event = daemon_event_to_proto(event); + let request = SendEventRequest { + event: Some(proto_event), + }; + self.client.send_event(request).await?; + Ok(()) + } +} + +/// Convert a daemon event to its proto representation. +fn daemon_event_to_proto( + event: DaemonEvent, +) -> crate::atuin_daemon::control::send_event_request::Event { + use crate::atuin_daemon::control::send_event_request::Event; + + match event { + DaemonEvent::HistoryPruned => Event::HistoryPruned(HistoryPrunedEvent {}), + DaemonEvent::HistoryRebuilt => Event::HistoryRebuilt(HistoryRebuiltEvent {}), + DaemonEvent::HistoryDeleted { ids } => Event::HistoryDeleted(HistoryDeletedEvent { + ids: ids.into_iter().map(|id| id.0).collect(), + }), + DaemonEvent::ForceSync => Event::ForceSync(ForceSyncEvent {}), + DaemonEvent::SettingsReloaded => Event::SettingsReloaded(SettingsReloadedEvent {}), + DaemonEvent::ShutdownRequested => Event::Shutdown(ShutdownEvent {}), + // These events are internal and not sent via the control service + DaemonEvent::HistoryStarted(_) + | DaemonEvent::HistoryEnded(_) + | DaemonEvent::RecordsAdded(_) + | DaemonEvent::SyncCompleted { .. } + | DaemonEvent::SyncFailed { .. } => { + // Use shutdown as a fallback, though this shouldn't happen + tracing::warn!("attempted to send internal event via control service"); + Event::Shutdown(ShutdownEvent {}) + } + } +} + +// ============================================================================ +// Convenience Functions +// ============================================================================ + +/// Emit an event to the daemon. +/// +/// This is a fire-and-forget helper for sending events to the daemon from +/// external processes like CLI commands. If the daemon isn't running, this +/// will silently succeed (returns Ok). +/// +/// # Example +/// +/// ```ignore +/// // After pruning history +/// emit_event(DaemonEvent::HistoryPruned).await?; +/// +/// // After deleting specific history items +/// emit_event(DaemonEvent::HistoryDeleted { ids: vec![...] }).await?; +/// +/// // Request immediate sync +/// emit_event(DaemonEvent::ForceSync).await?; +/// ``` +pub async fn emit_event(event: DaemonEvent) -> Result<()> { + emit_event_with_settings(event, None).await +} + +/// Emit an event to the daemon with explicit settings. +/// +/// If settings are not provided, they will be loaded from the default location. +/// If the daemon isn't running, this will silently succeed. +pub async fn emit_event_with_settings( + event: DaemonEvent, + settings: Option<&Settings>, +) -> Result<()> { + // Load settings if not provided + let owned_settings; + let settings = match settings { + Some(s) => s, + None => { + owned_settings = Settings::new()?; + &owned_settings + } + }; + + // Try to connect - if daemon isn't running, that's fine + let mut client = match ControlClient::from_settings(settings).await { + Ok(c) => c, + Err(e) => { + tracing::debug!(?e, "daemon not running, skipping event emission"); + return Ok(()); + } + }; + + // Send the event + if let Err(e) = client.send_event(event).await { + tracing::debug!(?e, "failed to send event to daemon"); + // Don't fail - this is fire-and-forget + } + + Ok(()) +} diff --git a/crates/turtle/src/atuin_daemon/components/history.rs b/crates/turtle/src/atuin_daemon/components/history.rs new file mode 100644 index 00000000..95d34b69 --- /dev/null +++ b/crates/turtle/src/atuin_daemon/components/history.rs @@ -0,0 +1,327 @@ +//! History component. +//! +//! Handles command history lifecycle (start/end) and provides the History gRPC service. + +use std::{pin::Pin, sync::Arc}; + +use crate::atuin_client::{ + database::Database, + history::{History, HistoryId, store::HistoryStore}, + settings::Settings, +}; +use dashmap::DashMap; +use eyre::Result; +use time::OffsetDateTime; +use tokio_stream::Stream; +use tonic::{Request, Response, Status}; +use tracing::{Level, instrument}; + +use crate::atuin_daemon::{ + daemon::{Component, DaemonHandle}, + events::DaemonEvent, + history::{ + EndHistoryReply, EndHistoryRequest, HistoryEntry, HistoryEventKind, ShutdownReply, + ShutdownRequest, StartHistoryReply, StartHistoryRequest, StatusReply, StatusRequest, + TailHistoryReply, TailHistoryRequest, + history_server::{History as HistorySvc, HistoryServer}, + }, +}; + +const DAEMON_PROTOCOL_VERSION: u32 = 1; + +/// History component - manages command history lifecycle. +/// +/// This component: +/// - Tracks currently running commands (stored in memory) +/// - Saves completed commands to the database and record store +/// - Emits history events for other components (e.g., search indexing) +/// - Provides the History gRPC service +pub struct HistoryComponent { + inner: Arc, +} + +struct HistoryComponentInner { + /// Commands currently running (not yet completed). + running: DashMap, + + /// Handle to the daemon (set during start). + handle: tokio::sync::RwLock>, + + /// History store for pushing records (set during start). + history_store: tokio::sync::RwLock>, +} + +impl HistoryComponent { + /// Create a new history component. + pub fn new() -> Self { + Self { + inner: Arc::new(HistoryComponentInner { + running: DashMap::new(), + handle: tokio::sync::RwLock::new(None), + history_store: tokio::sync::RwLock::new(None), + }), + } + } + + /// Get the gRPC service for this component. + /// + /// This returns a tonic service that can be added to a gRPC server. + pub fn grpc_service(&self) -> HistoryServer { + HistoryServer::new(HistoryGrpcService { + inner: self.inner.clone(), + }) + } +} + +impl Default for HistoryComponent { + fn default() -> Self { + Self::new() + } +} + +#[tonic::async_trait] +impl Component for HistoryComponent { + fn name(&self) -> &'static str { + "history" + } + + async fn start(&mut self, handle: DaemonHandle) -> Result<()> { + // Create the history store + let host_id = Settings::host_id().await?; + let history_store = + HistoryStore::new(handle.store().clone(), host_id, *handle.encryption_key()); + + *self.inner.history_store.write().await = Some(history_store); + *self.inner.handle.write().await = Some(handle); + + tracing::info!("history component started"); + Ok(()) + } + + async fn handle_event(&mut self, _event: &DaemonEvent) -> Result<()> { + // History component produces events but doesn't need to react to them + Ok(()) + } + + async fn stop(&mut self) -> Result<()> { + tracing::info!("history component stopped"); + Ok(()) + } +} + +/// The gRPC service implementation. +/// +/// This is a thin wrapper that delegates to the component's shared state. +pub struct HistoryGrpcService { + inner: Arc, +} + +fn history_to_tail_reply(kind: HistoryEventKind, history: History) -> TailHistoryReply { + TailHistoryReply { + kind: kind as i32, + history: Some(HistoryEntry { + timestamp: history.timestamp.unix_timestamp_nanos() as u64, + id: history.id.0, + command: history.command, + cwd: history.cwd, + session: history.session, + hostname: history.hostname, + author: history.author, + intent: history.intent.unwrap_or_default(), + exit: history.exit, + duration: history.duration, + }), + } +} + +#[tonic::async_trait] +impl HistorySvc for HistoryGrpcService { + type TailHistoryStream = Pin> + Send>>; + + #[instrument(skip_all, level = Level::INFO)] + async fn start_history( + &self, + request: Request, + ) -> Result, Status> { + let req = request.into_inner(); + + let timestamp = + OffsetDateTime::from_unix_timestamp_nanos(req.timestamp as i128).map_err(|_| { + Status::invalid_argument( + "failed to parse timestamp as unix time (expected nanos since epoch)", + ) + })?; + + let h: History = History::daemon() + .timestamp(timestamp) + .command(req.command) + .cwd(req.cwd) + .session(req.session) + .hostname(req.hostname) + .author(req.author) + .intent(req.intent) + .build() + .into(); + + // Emit the event + if let Some(handle) = self.inner.handle.read().await.as_ref() { + handle.emit(DaemonEvent::HistoryStarted(h.clone())); + } + + let id = h.id.clone(); + tracing::info!(id = id.to_string(), "start history"); + self.inner.running.insert(id.clone(), h); + + let reply = StartHistoryReply { + id: id.to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + protocol: DAEMON_PROTOCOL_VERSION, + }; + + Ok(Response::new(reply)) + } + + #[instrument(skip_all, level = Level::INFO)] + async fn end_history( + &self, + request: Request, + ) -> Result, Status> { + let req = request.into_inner(); + let id = HistoryId(req.id); + + if let Some((_, mut history)) = self.inner.running.remove(&id) { + history.exit = req.exit; + history.duration = match req.duration { + 0 => i64::try_from( + (OffsetDateTime::now_utc() - history.timestamp).whole_nanoseconds(), + ) + .expect("failed to convert calculated duration to i64"), + value => i64::try_from(value).expect("failed to get i64 duration"), + }; + + // Get the handle and store to save the history + let handle_guard = self.inner.handle.read().await; + let handle = handle_guard + .as_ref() + .ok_or_else(|| Status::internal("component not initialized"))?; + + let store_guard = self.inner.history_store.read().await; + let history_store = store_guard + .as_ref() + .ok_or_else(|| Status::internal("component not initialized"))?; + + // Save to database + handle + .history_db() + .save(&history) + .await + .map_err(|e| Status::internal(format!("failed to write to db: {e:?}")))?; + + tracing::info!( + id = id.0.to_string(), + duration = history.duration, + "end history" + ); + + // Push to record store + let (record_id, idx) = history_store + .push(history.clone()) + .await + .map_err(|e| Status::internal(format!("failed to push record to store: {e:?}")))?; + + // Emit the event + handle.emit(DaemonEvent::HistoryEnded(history)); + + let reply = EndHistoryReply { + id: record_id.0.to_string(), + idx, + version: env!("CARGO_PKG_VERSION").to_string(), + protocol: DAEMON_PROTOCOL_VERSION, + }; + + return Ok(Response::new(reply)); + } + + Err(Status::not_found(format!( + "could not find history with id: {id}" + ))) + } + + #[instrument(skip_all, level = Level::INFO)] + async fn tail_history( + &self, + _request: Request, + ) -> Result, Status> { + let handle_guard = self.inner.handle.read().await; + let handle = handle_guard + .as_ref() + .cloned() + .ok_or_else(|| Status::internal("component not initialized"))?; + + let mut rx = handle.subscribe(); + let (tx, out_rx) = tokio::sync::mpsc::channel::>(128); + + tokio::spawn(async move { + loop { + let event = match rx.recv().await { + Ok(event) => event, + Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => { + let _ = tx + .send(Err(Status::resource_exhausted(format!( + "tail stream lagged behind and dropped {skipped} events" + )))) + .await; + break; + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => break, + }; + + let reply = match event { + DaemonEvent::HistoryStarted(history) => { + Some(history_to_tail_reply(HistoryEventKind::Started, history)) + } + DaemonEvent::HistoryEnded(history) => { + Some(history_to_tail_reply(HistoryEventKind::Ended, history)) + } + _ => None, + }; + + if let Some(reply) = reply + && tx.send(Ok(reply)).await.is_err() + { + break; + } + } + }); + + let stream = tokio_stream::wrappers::ReceiverStream::new(out_rx); + Ok(Response::new(Box::pin(stream))) + } + + #[instrument(skip_all, level = Level::INFO)] + async fn status( + &self, + _request: Request, + ) -> Result, Status> { + let reply = StatusReply { + healthy: true, + version: env!("CARGO_PKG_VERSION").to_string(), + pid: std::process::id(), + protocol: DAEMON_PROTOCOL_VERSION, + }; + + Ok(Response::new(reply)) + } + + #[instrument(skip_all, level = Level::INFO)] + async fn shutdown( + &self, + _request: Request, + ) -> Result, Status> { + // Use the daemon handle to request shutdown + if let Some(handle) = self.inner.handle.read().await.as_ref() { + handle.shutdown(); + } + Ok(Response::new(ShutdownReply { accepted: true })) + } +} diff --git a/crates/turtle/src/atuin_daemon/components/mod.rs b/crates/turtle/src/atuin_daemon/components/mod.rs new file mode 100644 index 00000000..447e31df --- /dev/null +++ b/crates/turtle/src/atuin_daemon/components/mod.rs @@ -0,0 +1,25 @@ +//! Daemon components. +//! +//! Components are the building blocks of the daemon. Each component handles +//! a specific domain and can: +//! +//! - Expose gRPC services +//! - React to events +//! - Spawn background tasks +//! +//! Available components: +//! +//! - [`history::HistoryComponent`]: Command history lifecycle management +//! - [`search::SearchComponent`]: Fuzzy search over history +//! - [`semantic::SemanticComponent`]: In-memory semantic command captures +//! - [`sync::SyncComponent`]: Cloud sync + +pub mod history; +pub mod search; +pub mod semantic; +pub mod sync; + +pub use history::HistoryComponent; +pub use search::SearchComponent; +pub use semantic::SemanticComponent; +pub use sync::SyncComponent; diff --git a/crates/turtle/src/atuin_daemon/components/search.rs b/crates/turtle/src/atuin_daemon/components/search.rs new file mode 100644 index 00000000..85191cff --- /dev/null +++ b/crates/turtle/src/atuin_daemon/components/search.rs @@ -0,0 +1,413 @@ +//! Search component. +//! +//! Provides fuzzy search over command history using the Nucleo search library +//! with frecency-based ranking and dynamic filtering. + +use std::{pin::Pin, sync::Arc}; + +use crate::atuin_client::database::Database; +use eyre::Result; +use tokio::sync::RwLock; +use tokio_stream::Stream; +use tonic::{Request, Response, Status, Streaming}; +use tracing::{Level, debug, info, instrument, span, trace}; +use uuid::Uuid; + +use crate::atuin_daemon::{ + daemon::{Component, DaemonHandle}, + events::DaemonEvent, + search::{ + FilterMode, IndexFilterMode, QueryContext, SearchIndex, SearchRequest, SearchResponse, + search_server::{Search as SearchSvc, SearchServer}, + }, +}; + +const PAGE_SIZE: usize = 5000; +const RESULTS_LIMIT: u32 = 200; +/// How often to rebuild the frecency map (in seconds). +const FRECENCY_REFRESH_INTERVAL_SECS: u64 = 60; + +/// Search component - provides fuzzy search over command history. +/// +/// This component: +/// - Maintains a deduplicated search index with frecency ranking +/// - Loads history from the database on startup +/// - Updates the index when history events occur +/// - Provides the Search gRPC service +pub struct SearchComponent { + index: Arc>, + handle: tokio::sync::RwLock>, + loader_handle: Option>, + frecency_handle: Option>, +} + +impl SearchComponent { + /// Create a new search component. + pub fn new() -> Self { + Self { + index: Arc::new(RwLock::new(SearchIndex::new())), + handle: tokio::sync::RwLock::new(None), + loader_handle: None, + frecency_handle: None, + } + } + + /// Get the gRPC service for this component. + pub fn grpc_service(&self) -> SearchServer { + SearchServer::new(SearchGrpcService { + index: self.index.clone(), + }) + } + + /// Rebuild the entire search index from the database. + async fn rebuild_index(&self) -> Result<()> { + let handle_guard = self.handle.read().await; + let handle = handle_guard + .as_ref() + .ok_or_else(|| eyre::eyre!("component not initialized"))?; + + info!("Rebuilding search index from database"); + + // Create a new index + let new_index = SearchIndex::new(); + + // Load all history into the new index + let db = handle.history_db().clone(); + let mut pager = db.all_paged(PAGE_SIZE, false, true); + loop { + match pager.next().await { + Ok(Some(histories)) => { + info!( + "Loading {} history entries into search index", + histories.len() + ); + new_index.add_histories(&histories); + } + Ok(None) => break, + Err(e) => { + tracing::error!("Failed to load history during rebuild: {}", e); + break; + } + } + } + + info!( + "Search index rebuild complete; {} unique commands", + new_index.command_count() + ); + + // Replace the old index with the new one + *self.index.write().await = new_index; + Ok(()) + } +} + +impl Default for SearchComponent { + fn default() -> Self { + Self::new() + } +} + +#[tonic::async_trait] +impl Component for SearchComponent { + fn name(&self) -> &'static str { + "search" + } + + async fn start(&mut self, handle: DaemonHandle) -> Result<()> { + *self.handle.write().await = Some(handle.clone()); + + // Spawn background task to load history into index + let index = self.index.clone(); + let db = handle.history_db().clone(); + let handle_for_loader = handle.clone(); + + self.loader_handle = Some(tokio::spawn(async move { + info!( + "Loading history into search index; page size = {}", + PAGE_SIZE + ); + let mut pager = db.all_paged(PAGE_SIZE, false, true); + loop { + match pager.next().await { + Ok(Some(histories)) => { + info!( + "Loading {} history entries into search index", + histories.len() + ); + index.read().await.add_histories(&histories); + } + Ok(None) => { + info!( + "Initial history load complete; {} unique commands indexed", + index.read().await.command_count() + ); + // Build initial frecency map with current settings + let settings = handle_for_loader.settings().await; + index.read().await.rebuild_frecency(&settings.search).await; + info!("Initial frecency map built"); + break; + } + Err(e) => { + tracing::error!("Failed to load history: {}", e); + break; + } + } + } + })); + + // Spawn background task to periodically refresh frecency + let index_for_frecency = self.index.clone(); + let handle_for_frecency = handle.clone(); + self.frecency_handle = Some(tokio::spawn(async move { + let mut interval = tokio::time::interval(std::time::Duration::from_secs( + FRECENCY_REFRESH_INTERVAL_SECS, + )); + loop { + interval.tick().await; + trace!("Refreshing frecency map"); + let settings = handle_for_frecency.settings().await; + index_for_frecency + .read() + .await + .rebuild_frecency(&settings.search) + .await; + } + })); + + tracing::info!("search component started"); + Ok(()) + } + + async fn handle_event(&mut self, event: &DaemonEvent) -> Result<()> { + match event { + DaemonEvent::RecordsAdded(records) => { + debug!( + count = records.len(), + "Processing added records for search index" + ); + + let handle_guard = self.handle.read().await; + if let Some(handle) = handle_guard.as_ref() { + let histories: Vec<_> = handle + .history_db() + .query_history( + format!( + "select * from history where id in ({})", + records + .iter() + .map(|record| record.0.to_string()) + .collect::>() + .join(",") + ) + .as_str(), + ) + .await + .unwrap_or_default(); + + span!(Level::TRACE, "inject_records", count = histories.len()) + .in_scope(async || { + self.index.read().await.add_histories(&histories); + }) + .await; + } + } + DaemonEvent::HistoryStarted(history) => { + debug!(id = %history.id, command = %history.command, "History started (no index action)"); + } + DaemonEvent::HistoryEnded(history) => { + span!(Level::TRACE, "inject_history_ended") + .in_scope(async || { + self.index.read().await.add_history(history); + }) + .await; + } + DaemonEvent::HistoryPruned | DaemonEvent::HistoryRebuilt => { + info!("History store pruned or rebuilt, rebuilding search index"); + if let Err(e) = self.rebuild_index().await { + tracing::error!("Failed to rebuild search index: {}", e); + } + } + DaemonEvent::HistoryDeleted { ids } => { + info!( + count = ids.len(), + "History deleted, rebuilding search index" + ); + // For now, just rebuild the entire index. A more efficient implementation + // would remove specific items from the index. + if let Err(e) = self.rebuild_index().await { + tracing::error!("Failed to rebuild search index: {}", e); + } + } + DaemonEvent::SettingsReloaded => { + info!("Settings reloaded, rebuilding frecency map with new multipliers"); + let handle_guard = self.handle.read().await; + if let Some(handle) = handle_guard.as_ref() { + let settings = handle.settings().await; + self.index + .read() + .await + .rebuild_frecency(&settings.search) + .await; + } + } + // Events we don't care about + DaemonEvent::SyncCompleted { .. } + | DaemonEvent::SyncFailed { .. } + | DaemonEvent::ForceSync + | DaemonEvent::ShutdownRequested => {} + } + Ok(()) + } + + async fn stop(&mut self) -> Result<()> { + if let Some(handle) = self.loader_handle.take() { + handle.abort(); + } + if let Some(handle) = self.frecency_handle.take() { + handle.abort(); + } + tracing::info!("search component stopped"); + Ok(()) + } +} + +/// The gRPC service implementation. +pub struct SearchGrpcService { + index: Arc>, +} + +#[tonic::async_trait] +impl SearchSvc for SearchGrpcService { + type SearchStream = Pin> + Send>>; + + #[instrument(skip_all, level = Level::TRACE, name = "search_rpc")] + async fn search( + &self, + request: Request>, + ) -> Result, Status> { + let mut in_stream = request.into_inner(); + let index = self.index.clone(); + + // Create output channel + let (tx, rx) = tokio::sync::mpsc::channel::>(128); + + // Spawn task to handle incoming requests and send responses + tokio::spawn(async move { + while let Some(req) = in_stream.message().await.transpose() { + match req { + Ok(search_req) => { + let query = search_req.query; + let query_id = search_req.query_id; + let filter_mode: FilterMode = search_req + .filter_mode + .try_into() + .unwrap_or(FilterMode::Global); + let proto_context = search_req.context; + + debug!( + "search request: query = {}, query_id = {}, filter_mode = {}, context = {:?}", + query, + query_id, + filter_mode.as_str_name(), + proto_context + ); + + // Convert proto FilterMode + context to IndexFilterMode + let index_filter = convert_filter_mode(filter_mode, &proto_context); + + // Build QueryContext from proto context + let query_context = proto_context + .map(|ctx| QueryContext { + cwd: Some(with_trailing_slash(&ctx.cwd)), + git_root: ctx.git_root.map(|s| with_trailing_slash(&s)), + hostname: Some(ctx.hostname), + session_id: Some(ctx.session_id), + }) + .unwrap_or_default(); + + // Perform the search + let history_ids = + span!(Level::TRACE, "daemon_search_query", %query, query_id) + .in_scope(|| async { + let index = index.read().await; + index + .search(&query, index_filter, &query_context, RESULTS_LIMIT) + .await + }) + .await; + + // Convert history IDs to bytes + let ids: Vec> = history_ids + .iter() + .filter_map(|id| { + Uuid::parse_str(id) + .ok() + .map(|uuid| uuid.as_bytes().to_vec()) + }) + .collect(); + + if tx.send(Ok(SearchResponse { query_id, ids })).await.is_err() { + break; // Client disconnected + } + } + Err(e) => { + let _ = tx.send(Err(e)).await; + break; + } + } + } + }); + + // Convert receiver to stream + let out_stream = tokio_stream::wrappers::ReceiverStream::new(rx); + Ok(Response::new(Box::pin(out_stream))) + } +} + +/// Convert proto FilterMode and context to IndexFilterMode. +fn convert_filter_mode( + mode: FilterMode, + context: &Option, +) -> IndexFilterMode { + match (mode, context) { + (FilterMode::Global, _) => IndexFilterMode::Global, + (FilterMode::Directory, Some(ctx)) => { + IndexFilterMode::Directory(with_trailing_slash(&ctx.cwd)) + } + (FilterMode::Workspace, Some(ctx)) => { + if let Some(ref git_root) = ctx.git_root { + IndexFilterMode::Workspace(with_trailing_slash(git_root)) + } else { + // Fall back to directory if no git root + IndexFilterMode::Directory(with_trailing_slash(&ctx.cwd)) + } + } + (FilterMode::Host, Some(ctx)) => IndexFilterMode::Host(ctx.hostname.clone()), + (FilterMode::Session, Some(ctx)) => IndexFilterMode::Session(ctx.session_id.clone()), + (FilterMode::SessionPreload, Some(ctx)) => { + // SessionPreload is similar to Session - filter by session + IndexFilterMode::Session(ctx.session_id.clone()) + } + // If no context provided, fall back to global + _ => IndexFilterMode::Global, + } +} + +#[cfg(windows)] +pub fn with_trailing_slash(s: &str) -> String { + if s.ends_with('\\') { + s.to_string() + } else { + format!("{}\\", s) + } +} + +#[cfg(not(windows))] +pub fn with_trailing_slash(s: &str) -> String { + if s.ends_with('/') { + s.to_string() + } else { + format!("{}/", s) + } +} diff --git a/crates/turtle/src/atuin_daemon/components/semantic.rs b/crates/turtle/src/atuin_daemon/components/semantic.rs new file mode 100644 index 00000000..a42fd5cb --- /dev/null +++ b/crates/turtle/src/atuin_daemon/components/semantic.rs @@ -0,0 +1,903 @@ +//! Semantic command capture component. +//! +//! This is a prototype in-memory store for completed command captures emitted +//! by atuin-pty-proxy. It keeps recent captures per Atuin session and indexes +//! them by history ID for AI tool lookup. + +use std::collections::{HashMap, VecDeque}; +use std::fmt::{Display, Formatter}; +use std::sync::Arc; + +use crate::atuin_client::history::{History, HistoryId}; +use eyre::Result; +use tokio::sync::Mutex; +use tonic::{Request, Response, Status, Streaming}; +use tracing::{Level, instrument}; + +use crate::atuin_daemon::{ + daemon::{Component, DaemonHandle}, + events::DaemonEvent, + semantic::{ + CommandCapture, CommandOutputReply, CommandOutputRequest, OutputLine, RecordCommandsReply, + semantic_server::{Semantic as SemanticSvc, SemanticServer}, + }, +}; + +const MAX_SESSIONS: usize = 20; +const MAX_COMMANDS_PER_SESSION: usize = 128; +const MAX_BYTES_PER_SESSION: usize = 32 * 1024 * 1024; +const MAX_PENDING_HISTORIES: usize = 128; + +/// Stores completed command captures and associates them with history events. +pub struct SemanticComponent { + inner: Arc, +} + +struct SemanticComponentInner { + state: Mutex, +} + +#[derive(Default)] +struct SemanticState { + sessions: HashMap, + session_lru: VecDeque, + history_index: HashMap, + pending_histories: VecDeque, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct SessionId(String); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +struct CaptureId(u64); + +#[derive(Debug, Clone, PartialEq, Eq)] +struct CaptureRef { + session_id: SessionId, + capture_id: CaptureId, +} + +#[derive(Default)] +struct SessionCaptures { + next_id: u64, + records: VecDeque, + output_bytes: usize, +} + +struct StoredCapture { + id: CaptureId, + history_id: HistoryId, + output_bytes: usize, + record: SemanticCommandRecord, +} + +struct EvictedCapture { + history_id: HistoryId, + capture_id: CaptureId, +} + +#[derive(Debug, Clone)] +struct SemanticCommandRecord { + capture: CommandCapture, + history: Option, +} + +impl SemanticComponent { + pub fn new() -> Self { + Self { + inner: Arc::new(SemanticComponentInner { + state: Mutex::new(SemanticState::default()), + }), + } + } + + pub fn grpc_service(&self) -> SemanticServer { + SemanticServer::new(SemanticGrpcService { + inner: self.inner.clone(), + }) + } +} + +impl Default for SemanticComponent { + fn default() -> Self { + Self::new() + } +} + +#[tonic::async_trait] +impl Component for SemanticComponent { + fn name(&self) -> &'static str { + "semantic" + } + + async fn start(&mut self, _handle: DaemonHandle) -> Result<()> { + tracing::info!("semantic component started"); + Ok(()) + } + + async fn handle_event(&mut self, event: &DaemonEvent) -> Result<()> { + if let DaemonEvent::HistoryEnded(history) = event { + self.inner.record_history(history.clone()).await; + } + + Ok(()) + } + + async fn stop(&mut self) -> Result<()> { + let state = self.inner.state.lock().await; + tracing::info!( + sessions = state.sessions.len(), + records = state.record_count(), + indexed_histories = state.history_index.len(), + pending_histories = state.pending_histories.len(), + "semantic component stopped" + ); + Ok(()) + } +} + +impl SemanticComponentInner { + async fn record_capture(&self, capture: CommandCapture) -> bool { + let mut state = self.state.lock().await; + state.record_capture(capture) + } + + async fn record_history(&self, history: History) { + let mut state = self.state.lock().await; + state.record_history(history); + } + + async fn command_output(&self, request: &CommandOutputRequest) -> CommandOutputReply { + let mut state = self.state.lock().await; + state.command_output(request) + } +} + +impl SemanticState { + fn record_capture(&mut self, mut capture: CommandCapture) -> bool { + let Some(history_id) = history_id_from_str(capture.history_id.as_deref()) else { + tracing::debug!( + command_bytes = capture.command.len(), + prompt_bytes = capture.prompt.len(), + output_bytes = capture.output.len(), + output_truncated = capture.output_truncated, + "dropping semantic command capture without history id" + ); + return false; + }; + + let history = take_pending_history(&mut self.pending_histories, &history_id); + let Some(session_id) = capture + .session_id + .as_deref() + .and_then(|session_id| SessionId::try_from(session_id).ok()) + .or_else(|| { + history + .as_ref() + .and_then(|history| SessionId::try_from(history.session.as_str()).ok()) + }) + else { + tracing::debug!( + history_id = %history_id, + command_bytes = capture.command.len(), + prompt_bytes = capture.prompt.len(), + output_bytes = capture.output.len(), + output_truncated = capture.output_truncated, + "dropping semantic command capture without session id" + ); + return false; + }; + + capture.history_id = Some(history_id.to_string()); + capture.session_id = Some(session_id.to_string()); + if capture.output_observed_bytes == 0 { + capture.output_observed_bytes = capture.output.len() as u64; + } + + let record = SemanticCommandRecord { capture, history }; + log_record(&record, "recorded semantic command capture"); + self.push_record(session_id, history_id, record); + true + } + + fn record_history(&mut self, history: History) { + let history_id = history.id.clone(); + + if let Some(capture_ref) = self.history_index.get(&history_id).cloned() { + if let Some(stored) = self.stored_capture_mut(&capture_ref) { + stored.record.history = Some(history); + log_record( + &stored.record, + "associated semantic command capture with history", + ); + return; + } + + self.history_index.remove(&history_id); + } + + tracing::debug!( + id = %history.id, + command_bytes = history.command.len(), + "history ended before semantic capture arrived" + ); + push_pending_history(&mut self.pending_histories, history); + } + + fn command_output(&mut self, request: &CommandOutputRequest) -> CommandOutputReply { + let Some(history_id) = history_id_from_str(Some(&request.history_id)) else { + return command_output_not_found(); + }; + let Some(capture_ref) = self.history_index.get(&history_id).cloned() else { + return command_output_not_found(); + }; + + let Some(reply) = self.command_output_for_ref(&capture_ref, &request.ranges) else { + self.history_index.remove(&history_id); + return command_output_not_found(); + }; + + self.touch_session(&capture_ref.session_id); + reply + } + + fn command_output_for_ref( + &self, + capture_ref: &CaptureRef, + ranges: &[crate::atuin_daemon::semantic::OutputRange], + ) -> Option { + let stored = self + .sessions + .get(&capture_ref.session_id)? + .stored_capture(capture_ref.capture_id)?; + let output = &stored.record.capture.output; + let output_observed_bytes = stored + .record + .capture + .output_observed_bytes + .max(output.len() as u64); + + Some(CommandOutputReply { + found: true, + output: String::new(), + total_bytes: output.len() as u64, + total_lines: output.lines().count() as u64, + lines: select_output_ranges(output, ranges), + output_truncated: stored.record.capture.output_truncated, + output_observed_bytes, + }) + } + + fn push_record( + &mut self, + session_id: SessionId, + history_id: HistoryId, + record: SemanticCommandRecord, + ) { + self.touch_session(&session_id); + + let (capture_id, evicted) = { + let session = self.sessions.entry(session_id.clone()).or_default(); + session.push(history_id.clone(), record) + }; + + let capture_ref = CaptureRef { + session_id: session_id.clone(), + capture_id, + }; + self.history_index.insert(history_id, capture_ref); + + for evicted in evicted { + self.remove_history_index_if_matches( + &session_id, + &evicted.history_id, + evicted.capture_id, + ); + } + + self.expire_lru_sessions(); + } + + fn touch_session(&mut self, session_id: &SessionId) { + if let Some(index) = self.session_lru.iter().position(|id| id == session_id) { + self.session_lru.remove(index); + } + self.session_lru.push_back(session_id.clone()); + } + + fn expire_lru_sessions(&mut self) { + while self.session_lru.len() > MAX_SESSIONS { + let Some(session_id) = self.session_lru.pop_front() else { + break; + }; + let Some(session) = self.sessions.remove(&session_id) else { + continue; + }; + + for stored in session.records { + self.remove_history_index_if_matches(&session_id, &stored.history_id, stored.id); + } + } + } + + fn remove_history_index_if_matches( + &mut self, + session_id: &SessionId, + history_id: &HistoryId, + capture_id: CaptureId, + ) { + if self + .history_index + .get(history_id) + .is_some_and(|capture_ref| { + &capture_ref.session_id == session_id && capture_ref.capture_id == capture_id + }) + { + self.history_index.remove(history_id); + } + } + + fn stored_capture_mut(&mut self, capture_ref: &CaptureRef) -> Option<&mut StoredCapture> { + self.sessions + .get_mut(&capture_ref.session_id)? + .stored_capture_mut(capture_ref.capture_id) + } + + fn record_count(&self) -> usize { + self.sessions + .values() + .map(|session| session.records.len()) + .sum() + } +} + +impl SessionCaptures { + fn push( + &mut self, + history_id: HistoryId, + record: SemanticCommandRecord, + ) -> (CaptureId, Vec) { + self.push_with_limits( + history_id, + record, + MAX_COMMANDS_PER_SESSION, + MAX_BYTES_PER_SESSION, + ) + } + + fn push_with_limits( + &mut self, + history_id: HistoryId, + record: SemanticCommandRecord, + max_commands: usize, + max_output_bytes: usize, + ) -> (CaptureId, Vec) { + let capture_id = CaptureId(self.next_id); + self.next_id = self.next_id.saturating_add(1); + let output_bytes = record.capture.output.len(); + self.output_bytes = self.output_bytes.saturating_add(output_bytes); + self.records.push_back(StoredCapture { + id: capture_id, + history_id, + output_bytes, + record, + }); + + ( + capture_id, + self.evict_to_limits(max_commands, max_output_bytes), + ) + } + + fn evict_to_limits( + &mut self, + max_commands: usize, + max_output_bytes: usize, + ) -> Vec { + let mut evicted = Vec::new(); + while self.records.len() > max_commands || self.output_bytes > max_output_bytes { + let Some(record) = self.records.pop_front() else { + break; + }; + self.output_bytes = self.output_bytes.saturating_sub(record.output_bytes); + evicted.push(EvictedCapture { + history_id: record.history_id, + capture_id: record.id, + }); + } + evicted + } + + fn stored_capture(&self, capture_id: CaptureId) -> Option<&StoredCapture> { + self.records.iter().find(|record| record.id == capture_id) + } + + fn stored_capture_mut(&mut self, capture_id: CaptureId) -> Option<&mut StoredCapture> { + self.records + .iter_mut() + .find(|record| record.id == capture_id) + } +} + +impl TryFrom<&str> for SessionId { + type Error = (); + + fn try_from(value: &str) -> std::result::Result { + let value = value.trim(); + if value.is_empty() { + return Err(()); + } + + Ok(Self(value.to_string())) + } +} + +impl TryFrom for SessionId { + type Error = (); + + fn try_from(value: String) -> std::result::Result { + Self::try_from(value.as_str()) + } +} + +impl AsRef for SessionId { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl Display for SessionId { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.0) + } +} + +pub struct SemanticGrpcService { + inner: Arc, +} + +#[tonic::async_trait] +impl SemanticSvc for SemanticGrpcService { + #[instrument(skip_all, level = Level::INFO)] + async fn record_commands( + &self, + request: Request>, + ) -> Result, Status> { + let mut stream = request.into_inner(); + let mut accepted = 0_u64; + + while let Some(capture) = stream.message().await? { + if self.inner.record_capture(capture).await { + accepted += 1; + } + } + + Ok(Response::new(RecordCommandsReply { accepted })) + } + + #[instrument(skip_all, level = Level::INFO)] + async fn command_output( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + if request.history_id.trim().is_empty() { + return Err(Status::invalid_argument("history_id is required")); + } + + Ok(Response::new(self.inner.command_output(&request).await)) + } +} + +fn history_id_from_str(value: Option<&str>) -> Option { + let value = value?.trim(); + (!value.is_empty()).then(|| HistoryId(value.to_string())) +} + +fn take_pending_history( + histories: &mut VecDeque, + history_id: &HistoryId, +) -> Option { + let index = histories + .iter() + .position(|history| &history.id == history_id)?; + histories.remove(index) +} + +fn push_pending_history(histories: &mut VecDeque, history: History) { + if let Some(index) = histories + .iter() + .position(|pending| pending.id == history.id) + { + histories.remove(index); + } + + histories.push_back(history); + trim_front(histories, MAX_PENDING_HISTORIES); +} + +fn trim_front(records: &mut VecDeque, max_len: usize) { + while records.len() > max_len { + records.pop_front(); + } +} + +fn command_output_not_found() -> CommandOutputReply { + CommandOutputReply { + found: false, + output: String::new(), + total_bytes: 0, + total_lines: 0, + lines: Vec::new(), + output_truncated: false, + output_observed_bytes: 0, + } +} + +fn select_output_ranges( + output: &str, + ranges: &[crate::atuin_daemon::semantic::OutputRange], +) -> Vec { + let lines: Vec<&str> = output.lines().collect(); + if lines.is_empty() { + return Vec::new(); + } + + let ranges = if ranges.is_empty() { + vec![crate::atuin_daemon::semantic::OutputRange { start: 0, end: 999 }] + } else { + ranges.to_vec() + }; + + let mut ranges = ranges + .into_iter() + .filter_map(|range| normalize_line_range(range.start, range.end, lines.len())) + .collect::>(); + ranges.sort_unstable_by_key(|(start, _)| *start); + + let mut merged: Vec<(usize, usize)> = Vec::new(); + for (start, end) in ranges { + match merged.last_mut() { + Some((_, merged_end)) if start <= merged_end.saturating_add(1) => { + *merged_end = (*merged_end).max(end); + } + _ => merged.push((start, end)), + } + } + + merged + .into_iter() + .flat_map(|(start, end)| { + lines[start..=end] + .iter() + .enumerate() + .map(move |(offset, line)| OutputLine { + line_number: (start + offset + 1) as u64, + content: (*line).to_string(), + }) + }) + .collect() +} + +fn normalize_line_range(start: i64, end: i64, line_count: usize) -> Option<(usize, usize)> { + let line_count = i64::try_from(line_count).ok()?; + let start = if start < 0 { line_count + start } else { start }; + let end = if end < 0 { line_count + end } else { end }; + + if end < 0 || start >= line_count { + return None; + } + + let start = start.max(0); + let end = end.min(line_count - 1); + + (start <= end).then_some((start as usize, end as usize)) +} + +fn log_record(record: &SemanticCommandRecord, message: &'static str) { + let history_id = record.capture.history_id.as_deref().unwrap_or(""); + let associated_history_id = record + .history + .as_ref() + .map(|history| history.id.to_string()); + let exit = record.history.as_ref().map(|history| history.exit); + let duration = record.history.as_ref().map(|history| history.duration); + let author = record + .history + .as_ref() + .map(|history| history.author.as_str()); + let session_id = record.capture.session_id.as_deref(); + + tracing::debug!( + history_id = %history_id, + associated_history_id = ?associated_history_id, + session_id = ?session_id, + command_bytes = record.capture.command.len(), + prompt_bytes = record.capture.prompt.len(), + output_bytes = record.capture.output.len(), + output_truncated = record.capture.output_truncated, + output_observed_bytes = record.capture.output_observed_bytes, + capture_exit_code = ?record.capture.exit_code, + history_exit = ?exit, + duration = ?duration, + author = ?author, + "{message}" + ); +} + +#[cfg(test)] +mod tests { + use super::*; + use time::OffsetDateTime; + + fn history(id: &str, session: &str, command: &str) -> History { + History { + id: HistoryId(id.to_string()), + timestamp: OffsetDateTime::UNIX_EPOCH, + duration: 0, + exit: 0, + command: command.to_string(), + cwd: String::new(), + session: session.to_string(), + hostname: String::new(), + author: String::new(), + intent: None, + deleted_at: None, + } + } + + fn capture(history_id: Option<&str>, session_id: Option<&str>, output: &str) -> CommandCapture { + CommandCapture { + prompt: String::new(), + command: String::new(), + output: output.to_string(), + exit_code: None, + history_id: history_id.map(str::to_string), + session_id: session_id.map(str::to_string), + output_truncated: false, + output_observed_bytes: output.len() as u64, + } + } + + fn command_output(state: &mut SemanticState, history_id: &str) -> CommandOutputReply { + state.command_output(&CommandOutputRequest { + history_id: history_id.to_string(), + ranges: Vec::new(), + }) + } + + fn output_line(line_number: u64, content: &str) -> OutputLine { + OutputLine { + line_number, + content: content.to_string(), + } + } + + #[test] + fn drops_capture_without_history_id() { + let mut state = SemanticState::default(); + + assert!(!state.record_capture(capture(None, Some("session-1"), "output"))); + assert!(!command_output(&mut state, "id-1").found); + assert_eq!(state.record_count(), 0); + } + + #[test] + fn stores_capture_by_session_and_history_id() { + let mut state = SemanticState::default(); + + assert!(state.record_capture(capture(Some("id-1"), Some("session-1"), "output"))); + + let reply = command_output(&mut state, "id-1"); + assert!(reply.found); + assert_eq!(reply.total_bytes, 6); + assert_eq!(reply.output_observed_bytes, 6); + assert_eq!(reply.lines, vec![output_line(1, "output")]); + } + + #[test] + fn uses_pending_history_session_when_capture_session_is_missing() { + let mut state = SemanticState::default(); + + state.record_history(history("id-1", "session-from-history", "cargo test")); + assert!(state.record_capture(capture(Some("id-1"), None, "output"))); + + assert!( + state + .sessions + .contains_key(&SessionId("session-from-history".to_string())) + ); + assert!(command_output(&mut state, "id-1").found); + } + + #[test] + fn associates_history_by_id_after_capture_arrives() { + let mut state = SemanticState::default(); + + assert!(state.record_capture(capture(Some("id-1"), Some("session-1"), "output"))); + state.record_history(history("id-1", "session-1", "different command")); + + let capture_ref = state + .history_index + .get(&HistoryId("id-1".to_string())) + .unwrap(); + let stored = state + .sessions + .get(&capture_ref.session_id) + .unwrap() + .stored_capture(capture_ref.capture_id) + .unwrap(); + assert!(stored.record.history.is_some()); + } + + #[test] + fn evicts_oldest_command_when_session_ring_is_full() { + let mut state = SemanticState::default(); + + for index in 0..=MAX_COMMANDS_PER_SESSION { + assert!(state.record_capture(capture( + Some(&format!("id-{index}")), + Some("session-1"), + "output", + ))); + } + + assert!(!command_output(&mut state, "id-0").found); + assert!(command_output(&mut state, &format!("id-{MAX_COMMANDS_PER_SESSION}")).found); + assert_eq!(state.record_count(), MAX_COMMANDS_PER_SESSION); + } + + #[test] + fn evicts_oldest_session_after_lru_limit() { + let mut state = SemanticState::default(); + + for index in 0..MAX_SESSIONS { + assert!(state.record_capture(capture( + Some(&format!("id-{index}")), + Some(&format!("session-{index}")), + "output", + ))); + } + assert!(command_output(&mut state, "id-0").found); + + assert!(state.record_capture(capture(Some("new-id"), Some("new-session"), "output",))); + + assert!(command_output(&mut state, "id-0").found); + assert!(!command_output(&mut state, "id-1").found); + assert!(command_output(&mut state, "new-id").found); + assert_eq!(state.sessions.len(), MAX_SESSIONS); + } + + #[test] + fn evicts_by_session_byte_limit() { + let mut session = SessionCaptures::default(); + let first_output = "x".repeat(10); + let second_output = "y"; + let (_, evicted_first) = session.push_with_limits( + HistoryId("first".to_string()), + SemanticCommandRecord { + capture: capture(Some("first"), Some("session-1"), &first_output), + history: None, + }, + MAX_COMMANDS_PER_SESSION, + 10, + ); + assert!(evicted_first.is_empty()); + + let (_, evicted_second) = session.push_with_limits( + HistoryId("second".to_string()), + SemanticCommandRecord { + capture: capture(Some("second"), Some("session-1"), second_output), + history: None, + }, + MAX_COMMANDS_PER_SESSION, + 10, + ); + + assert_eq!(evicted_second.len(), 1); + assert_eq!(evicted_second[0].history_id, HistoryId("first".to_string())); + assert_eq!(session.records.len(), 1); + assert_eq!(session.output_bytes, 1); + } + + #[test] + fn command_output_reports_truncation_metadata() { + let mut state = SemanticState::default(); + let mut capture = capture(Some("id-1"), Some("session-1"), "partial"); + capture.output_truncated = true; + capture.output_observed_bytes = 1024; + + assert!(state.record_capture(capture)); + + let reply = command_output(&mut state, "id-1"); + assert!(reply.output_truncated); + assert_eq!(reply.total_bytes, 7); + assert_eq!(reply.output_observed_bytes, 1024); + } + + #[test] + fn output_ranges_are_line_based_inclusive_and_support_negative_offsets() { + let output = "zero\none\ntwo\nthree\nfour"; + let ranges = vec![ + crate::atuin_daemon::semantic::OutputRange { start: 1, end: 2 }, + crate::atuin_daemon::semantic::OutputRange { start: -2, end: -1 }, + ]; + + assert_eq!( + select_output_ranges(output, &ranges), + vec![ + output_line(2, "one"), + output_line(3, "two"), + output_line(4, "three"), + output_line(5, "four"), + ] + ); + } + + #[test] + fn output_ranges_merge_overlaps_and_adjacent_ranges() { + let output = (0..100) + .map(|n| format!("line {n}")) + .collect::>() + .join("\n"); + let ranges = vec![ + crate::atuin_daemon::semantic::OutputRange { start: 0, end: 100 }, + crate::atuin_daemon::semantic::OutputRange { + start: -100, + end: -1, + }, + ]; + + let selected = select_output_ranges(&output, &ranges); + + assert_eq!(selected.len(), 100); + assert_eq!(selected.first(), Some(&output_line(1, "line 0"))); + assert_eq!(selected.last(), Some(&output_line(100, "line 99"))); + } + + #[test] + fn output_ranges_can_leave_gaps_for_client_formatting() { + let output = "zero\none\ntwo\nthree\nfour"; + let ranges = vec![ + crate::atuin_daemon::semantic::OutputRange { start: 0, end: 1 }, + crate::atuin_daemon::semantic::OutputRange { start: 4, end: 4 }, + ]; + + assert_eq!( + select_output_ranges(output, &ranges), + vec![ + output_line(1, "zero"), + output_line(2, "one"), + output_line(5, "four"), + ] + ); + } + + #[test] + fn empty_output_ranges_default_to_first_thousand_lines() { + let output = (0..1001) + .map(|n| format!("line {n}")) + .collect::>() + .join("\n"); + + let selected = select_output_ranges(&output, &[]); + + assert_eq!(selected.len(), 1000); + assert_eq!(selected.first(), Some(&output_line(1, "line 0"))); + assert_eq!(selected.last(), Some(&output_line(1000, "line 999"))); + } + + #[test] + fn output_ranges_skip_ranges_fully_outside_output() { + let output = "zero\none\ntwo"; + let ranges = vec![ + crate::atuin_daemon::semantic::OutputRange { start: 10, end: 20 }, + crate::atuin_daemon::semantic::OutputRange { + start: -20, + end: -10, + }, + ]; + + assert_eq!(select_output_ranges(output, &ranges), Vec::new()); + } +} diff --git a/crates/turtle/src/atuin_daemon/components/sync.rs b/crates/turtle/src/atuin_daemon/components/sync.rs new file mode 100644 index 00000000..c76fb71b --- /dev/null +++ b/crates/turtle/src/atuin_daemon/components/sync.rs @@ -0,0 +1,279 @@ +//! Sync component. +//! +//! Handles periodic synchronization with the Atuin cloud server. + +use std::time::Duration; + +use eyre::Result; +use rand::Rng; +use tokio::sync::mpsc; +use tokio::time::{self, MissedTickBehavior}; + +use crate::atuin_client::{history::store::HistoryStore, record::sync, settings::Settings}; + +use crate::atuin_daemon::{ + daemon::{Component, DaemonHandle}, + events::DaemonEvent, +}; + +/// Commands that can be sent to the sync task. +enum SyncCommand { + /// Trigger an immediate sync. + ForceSync, + /// Stop the sync loop. + Stop, +} + +/// Sync state - tracks whether we're in normal operation or retrying after failure. +#[derive(Clone, Copy, PartialEq, Eq)] +enum SyncState { + /// Normal operation. Periodic syncs only run if auto_sync is enabled. + Idle, + /// Retrying after a sync failure. Retries continue regardless of auto_sync + /// until the sync succeeds. + Retrying, +} + +/// Sync component - handles periodic cloud synchronization. +/// +/// This component: +/// - Runs a background sync loop on a configurable interval +/// - Implements exponential backoff on sync failures +/// - Responds to ForceSync events for immediate sync +/// - Emits SyncCompleted/SyncFailed events +pub struct SyncComponent { + task_handle: Option>, + command_tx: Option>, +} + +impl SyncComponent { + /// Create a new sync component. + pub fn new() -> Self { + Self { + task_handle: None, + command_tx: None, + } + } +} + +impl Default for SyncComponent { + fn default() -> Self { + Self::new() + } +} + +#[tonic::async_trait] +impl Component for SyncComponent { + fn name(&self) -> &'static str { + "sync" + } + + async fn start(&mut self, handle: DaemonHandle) -> Result<()> { + let (cmd_tx, cmd_rx) = mpsc::channel(16); + self.command_tx = Some(cmd_tx); + + // Spawn the sync loop with its own copy of the handle + self.task_handle = Some(tokio::spawn(sync_loop(handle, cmd_rx))); + + tracing::info!("sync component started"); + Ok(()) + } + + async fn handle_event(&mut self, event: &DaemonEvent) -> Result<()> { + if let DaemonEvent::ForceSync = event { + tracing::info!("force sync requested"); + if let Some(tx) = &self.command_tx { + let _ = tx.send(SyncCommand::ForceSync).await; + } + } + Ok(()) + } + + async fn stop(&mut self) -> Result<()> { + if let Some(tx) = &self.command_tx { + let _ = tx.send(SyncCommand::Stop).await; + } + if let Some(handle) = self.task_handle.take() { + // Give the task a moment to shut down gracefully + let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await; + } + tracing::info!("sync component stopped"); + Ok(()) + } +} + +/// The main sync loop. +/// +/// This runs in a spawned task and handles periodic sync as well as +/// force sync requests. +async fn sync_loop(handle: DaemonHandle, mut cmd_rx: mpsc::Receiver) { + tracing::info!("sync loop starting"); + + // Clone settings since we need them across await points + let settings = handle.settings().await.clone(); + let host_id = match Settings::host_id().await { + Ok(id) => id, + Err(e) => { + tracing::error!("failed to get host id, sync disabled: {e}"); + return; + } + }; + + // Create the stores we need + let encryption_key = *handle.encryption_key(); + let history_store = HistoryStore::new(handle.store().clone(), host_id, encryption_key); + + // Don't backoff by more than 30 mins (with a random jitter of up to 1 min) + let max_interval: f64 = 60.0 * 30.0 + rand::thread_rng().gen_range(0.0..60.0); + + let mut ticker = time::interval(time::Duration::from_secs(settings.daemon.sync_frequency)); + + // IMPORTANT: without this, if we miss ticks because a sync takes ages or is otherwise delayed, + // we may end up running a lot of syncs in a hot loop. + ticker.set_missed_tick_behavior(MissedTickBehavior::Skip); + + let mut sync_state = SyncState::Idle; + + loop { + tokio::select! { + _ = ticker.tick() => { + let settings = handle.settings().await; + + // Skip periodic ticks if auto_sync is disabled AND we're not retrying + // a previous failure. Retries must continue regardless of auto_sync. + if !settings.auto_sync && sync_state == SyncState::Idle { + tracing::debug!("auto_sync disabled, skipping periodic sync tick"); + continue; + } + + sync_state = do_sync_tick( + &handle, + &history_store, + &mut ticker, + max_interval, + &settings, + ).await; + } + cmd = cmd_rx.recv() => { + match cmd { + Some(SyncCommand::ForceSync) => { + tracing::info!("executing force sync"); + let settings = handle.settings().await; + sync_state = do_sync_tick( + &handle, + &history_store, + &mut ticker, + max_interval, + &settings, + ).await; + } + Some(SyncCommand::Stop) | None => { + tracing::info!("sync loop stopping"); + break; + } + } + } + } + } +} + +/// Execute a single sync tick. +/// +/// Returns the new sync state: `Idle` on success, `Retrying` on failure. +async fn do_sync_tick( + handle: &DaemonHandle, + history_store: &HistoryStore, + ticker: &mut time::Interval, + max_interval: f64, + settings: &Settings, +) -> SyncState { + tracing::info!("sync tick"); + + // Check if logged in + let logged_in = match settings.logged_in().await { + Ok(v) => v, + Err(e) => { + tracing::warn!("failed to check login status, skipping sync tick: {e}"); + return SyncState::Idle; + } + }; + + if !logged_in { + tracing::debug!("not logged in, skipping sync tick"); + return SyncState::Idle; + } + + // Perform the sync + let res = sync::sync(settings, handle.store(), handle.encryption_key()).await; + + match res { + Err(e) => { + tracing::error!("sync tick failed with {e}"); + + // Emit failure event + handle.emit(DaemonEvent::SyncFailed { + error: e.to_string(), + }); + + // Exponential backoff + let mut rng = rand::thread_rng(); + let mut new_interval = ticker.period().as_secs_f64() * rng.gen_range(2.0..2.2); + + if new_interval > max_interval { + new_interval = max_interval; + } + + *ticker = time::interval_at( + tokio::time::Instant::now() + Duration::from_secs(new_interval as u64), + time::Duration::from_secs(new_interval as u64), + ); + ticker.reset_after(time::Duration::from_secs(new_interval as u64)); + ticker.set_missed_tick_behavior(MissedTickBehavior::Skip); + + tracing::error!("backing off, next sync tick in {new_interval}"); + + SyncState::Retrying + } + Ok((uploaded_count, downloaded_records)) => { + tracing::info!( + uploaded = uploaded_count, + downloaded = downloaded_records.len(), + "sync complete" + ); + + // Build history from downloaded records + if let Err(e) = history_store + .incremental_build(handle.history_db(), &downloaded_records) + .await + { + tracing::error!("failed to build history from downloaded records: {e}"); + } + + // Emit the records added event (for search indexing) + handle.emit(DaemonEvent::RecordsAdded(downloaded_records.clone())); + + // Emit sync completed event + handle.emit(DaemonEvent::SyncCompleted { + uploaded: uploaded_count as usize, + downloaded: downloaded_records.len(), + }); + + // Reset backoff on success + if ticker.period().as_secs() != settings.daemon.sync_frequency { + *ticker = time::interval_at( + tokio::time::Instant::now() + + Duration::from_secs(settings.daemon.sync_frequency), + time::Duration::from_secs(settings.daemon.sync_frequency), + ); + ticker.set_missed_tick_behavior(MissedTickBehavior::Skip); + } + + // Store sync time + if let Err(e) = Settings::save_sync_time().await { + tracing::error!("failed to save sync time: {e}"); + } + + SyncState::Idle + } + } +} diff --git a/crates/turtle/src/atuin_daemon/control/mod.rs b/crates/turtle/src/atuin_daemon/control/mod.rs new file mode 100644 index 00000000..afb29c57 --- /dev/null +++ b/crates/turtle/src/atuin_daemon/control/mod.rs @@ -0,0 +1,12 @@ +//! Control module for external event injection. +//! +//! This module provides the gRPC service that allows external processes +//! (like CLI commands) to inject events into the daemon's event bus. + +mod service; + +// Include the generated proto code +tonic::include_proto!("control"); + +// Re-export the service +pub use service::ControlService; diff --git a/crates/turtle/src/atuin_daemon/control/service.rs b/crates/turtle/src/atuin_daemon/control/service.rs new file mode 100644 index 00000000..cb2ff74e --- /dev/null +++ b/crates/turtle/src/atuin_daemon/control/service.rs @@ -0,0 +1,71 @@ +//! Control service implementation. +//! +//! This gRPC service allows external processes (like CLI commands) to inject +//! events into the daemon's event bus. + +use crate::atuin_client::history::HistoryId; +use tonic::{Request, Response, Status}; +use tracing::{Level, info, instrument}; + +use super::{ + SendEventRequest, SendEventResponse, + control_server::{Control, ControlServer}, + send_event_request::Event, +}; +use crate::atuin_daemon::{daemon::DaemonHandle, events::DaemonEvent}; + +/// The Control gRPC service. +/// +/// This service is used by external processes to inject events into the daemon. +/// It's not a component - it's part of the daemon's core infrastructure. +pub struct ControlService { + handle: DaemonHandle, +} + +impl ControlService { + /// Create a new control service with the given daemon handle. + pub fn new(handle: DaemonHandle) -> Self { + Self { handle } + } + + /// Get a tonic server for this service. + pub fn into_server(self) -> ControlServer { + ControlServer::new(self) + } +} + +#[tonic::async_trait] +impl Control for ControlService { + #[instrument(skip_all, level = Level::INFO, name = "control_send_event")] + async fn send_event( + &self, + request: Request, + ) -> Result, Status> { + let req = request.into_inner(); + + let event = req + .event + .ok_or_else(|| Status::invalid_argument("event is required"))?; + + let daemon_event = proto_event_to_daemon_event(event)?; + + info!(?daemon_event, "received control event"); + self.handle.emit(daemon_event); + + Ok(Response::new(SendEventResponse {})) + } +} + +/// Convert a proto event to a daemon event. +fn proto_event_to_daemon_event(event: Event) -> Result { + match event { + Event::HistoryPruned(_) => Ok(DaemonEvent::HistoryPruned), + Event::HistoryRebuilt(_) => Ok(DaemonEvent::HistoryRebuilt), + Event::HistoryDeleted(e) => Ok(DaemonEvent::HistoryDeleted { + ids: e.ids.into_iter().map(HistoryId).collect(), + }), + Event::ForceSync(_) => Ok(DaemonEvent::ForceSync), + Event::SettingsReloaded(_) => Ok(DaemonEvent::SettingsReloaded), + Event::Shutdown(_) => Ok(DaemonEvent::ShutdownRequested), + } +} diff --git a/crates/turtle/src/atuin_daemon/daemon.rs b/crates/turtle/src/atuin_daemon/daemon.rs new file mode 100644 index 00000000..77c0d8a5 --- /dev/null +++ b/crates/turtle/src/atuin_daemon/daemon.rs @@ -0,0 +1,458 @@ +//! Core daemon infrastructure. +//! +//! This module provides the foundational types for building the atuin daemon: +//! +//! - [`DaemonState`]: Shared state owned by the daemon +//! - [`DaemonHandle`]: A lightweight, cloneable handle for accessing daemon state +//! - [`Component`]: A trait for implementing daemon components +//! - [`Daemon`]: The main daemon orchestrator +//! - [`DaemonBuilder`]: Builder for constructing and configuring the daemon + +use std::sync::Arc; + +use crate::atuin_client::{ + database::Sqlite as HistoryDatabase, encryption, record::sqlite_store::SqliteStore, + settings::Settings, +}; +use eyre::{Context, Result}; +use tokio::sync::{RwLock, broadcast}; + +use crate::atuin_daemon::events::DaemonEvent; + +// ============================================================================ +// DaemonState +// ============================================================================ + +/// Shared state owned by the daemon. +/// +/// This contains all the resources that components and services need access to. +/// The state is wrapped in an `Arc` and accessed via [`DaemonHandle`]. +pub struct DaemonState { + // Event bus + event_tx: broadcast::Sender, + + // Configuration (mutable - can be reloaded) + settings: RwLock, + + // Encryption key (immutable - derived at startup) + encryption_key: [u8; 32], + + // Database handles + history_db: HistoryDatabase, + store: SqliteStore, +} + +// ============================================================================ +// DaemonHandle +// ============================================================================ + +/// A lightweight handle to the daemon's shared state. +/// +/// This is the primary way for components, gRPC services, and spawned tasks to +/// interact with the daemon. It provides access to: +/// +/// - Event emission and subscription +/// - Configuration (settings, encryption key) +/// - Database handles +/// +/// The handle is cheaply cloneable (wraps an `Arc`) and can be freely passed +/// around to any code that needs daemon access. +/// +/// # Example +/// +/// ```ignore +/// // Emit an event +/// handle.emit(DaemonEvent::HistoryPruned); +/// +/// // Access settings +/// let settings = handle.settings().await; +/// let sync_freq = settings.daemon.sync_frequency; +/// +/// // Access database +/// let history = handle.history_db().load(id).await?; +/// ``` +#[derive(Clone)] +pub struct DaemonHandle { + state: Arc, +} + +impl DaemonHandle { + // ---- Events ---- + + /// Emit an event to the daemon's event bus. + /// + /// This is fire-and-forget - if no receivers are listening (which shouldn't + /// happen in normal operation), the event is dropped silently. + pub fn emit(&self, event: DaemonEvent) { + if let Err(e) = self.state.event_tx.send(event) { + tracing::warn!("failed to emit event (no receivers?): {e}"); + } + } + + /// Subscribe to the event bus. + /// + /// Returns a receiver that will receive all events emitted after this call. + /// Useful for components that need to listen for events outside of the + /// normal `handle_event` callback flow. + pub fn subscribe(&self) -> broadcast::Receiver { + self.state.event_tx.subscribe() + } + + /// Request graceful shutdown of the daemon. + pub fn shutdown(&self) { + self.emit(DaemonEvent::ShutdownRequested); + } + + // ---- Configuration ---- + + /// Get the current settings. + /// + /// This acquires a read lock on the settings. For most use cases, clone + /// the settings if you need to hold onto them. + pub async fn settings(&self) -> tokio::sync::RwLockReadGuard<'_, Settings> { + self.state.settings.read().await + } + + /// Reload settings from disk and emit a SettingsReloaded event. + /// + /// Components listening for `SettingsReloaded` can then re-read settings + /// via `handle.settings()` to pick up the changes. + pub async fn reload_settings(&self) -> Result<()> { + let new_settings = Settings::new()?; + self.apply_settings(new_settings).await; + Ok(()) + } + + /// Apply already-loaded settings and emit a SettingsReloaded event. + /// + /// Use this when settings have already been loaded (e.g., from a file watcher) + /// to avoid parsing the config file twice. + pub async fn apply_settings(&self, settings: Settings) { + *self.state.settings.write().await = settings; + self.emit(DaemonEvent::SettingsReloaded); + tracing::info!("settings applied"); + } + + /// Get the encryption key. + pub fn encryption_key(&self) -> &[u8; 32] { + &self.state.encryption_key + } + + // ---- Database ---- + + /// Get a reference to the history database. + pub fn history_db(&self) -> &HistoryDatabase { + &self.state.history_db + } + + /// Get a reference to the record store. + pub fn store(&self) -> &SqliteStore { + &self.state.store + } +} + +impl std::fmt::Debug for DaemonHandle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("DaemonHandle").finish_non_exhaustive() + } +} + +// ============================================================================ +// Component Trait +// ============================================================================ + +/// A daemon component that handles a specific domain. +/// +/// Components are the building blocks of the daemon. Each component: +/// +/// - Has a unique name for logging and debugging +/// - Can optionally expose gRPC services +/// - Receives a [`DaemonHandle`] on startup for accessing daemon resources +/// - Handles events from the event bus +/// - Performs cleanup on shutdown +/// +/// # Lifecycle +/// +/// 1. **Construction**: Component is created (usually via `new()`) +/// 2. **Start**: `start()` is called with a [`DaemonHandle`] +/// 3. **Running**: `handle_event()` is called for each event on the bus +/// 4. **Shutdown**: `stop()` is called for cleanup +/// +/// # Example +/// +/// ```ignore +/// pub struct MyComponent { +/// handle: Option, +/// } +/// +/// #[async_trait] +/// impl Component for MyComponent { +/// fn name(&self) -> &'static str { "my-component" } +/// +/// async fn start(&mut self, handle: DaemonHandle) -> Result<()> { +/// self.handle = Some(handle); +/// Ok(()) +/// } +/// +/// async fn handle_event(&mut self, event: &DaemonEvent) -> Result<()> { +/// match event { +/// DaemonEvent::SomeEvent => { +/// // Handle the event +/// if let Some(handle) = &self.handle { +/// handle.emit(DaemonEvent::ResponseEvent); +/// } +/// } +/// _ => {} +/// } +/// Ok(()) +/// } +/// +/// async fn stop(&mut self) -> Result<()> { +/// Ok(()) +/// } +/// } +/// ``` +#[tonic::async_trait] +pub trait Component: Send + Sync { + /// Human-readable name for logging and debugging. + fn name(&self) -> &'static str; + + /// Called once at startup. + /// + /// Store the handle if you need to emit events or access daemon resources + /// later. The handle is cheaply cloneable, so feel free to clone it for + /// spawned tasks. + async fn start(&mut self, handle: DaemonHandle) -> Result<()>; + + /// Handle an incoming event. + /// + /// Called for every event on the bus. To emit new events in response, + /// use the handle stored during `start()`. Events emitted here will be + /// processed in subsequent event loop iterations. + async fn handle_event(&mut self, event: &DaemonEvent) -> Result<()>; + + /// Called on graceful shutdown. + /// + /// Use this to clean up resources, abort spawned tasks, etc. + async fn stop(&mut self) -> Result<()>; +} + +// ============================================================================ +// Daemon +// ============================================================================ + +/// The main daemon orchestrator. +/// +/// The daemon manages components, runs the event loop, and coordinates startup +/// and shutdown. It is constructed via [`DaemonBuilder`]. +/// +/// # Event Loop +/// +/// The daemon runs a simple event loop: +/// +/// 1. Wait for an event on the bus +/// 2. Dispatch the event to all components (in registration order) +/// 3. Components may emit new events in response +/// 4. Repeat until `ShutdownRequested` is received +/// +/// Events emitted during handling are queued and processed in subsequent +/// iterations, ensuring the loop eventually drains. +pub struct Daemon { + components: Vec>, + handle: DaemonHandle, +} + +impl Daemon { + /// Create a new daemon builder. + pub fn builder(settings: Settings) -> DaemonBuilder { + DaemonBuilder::new(settings) + } + + /// Get a clone of the daemon handle. + /// + /// The handle can be used to emit events, access settings, etc. + pub fn handle(&self) -> DaemonHandle { + self.handle.clone() + } + + /// Start all components. + /// + /// This must be called before `run_event_loop()`. It initializes all + /// registered components with the daemon handle. + pub async fn start_components(&mut self) -> Result<()> { + for component in &mut self.components { + tracing::info!(component = component.name(), "starting component"); + component + .start(self.handle.clone()) + .await + .with_context(|| format!("failed to start component: {}", component.name()))?; + } + Ok(()) + } + + /// Run the daemon event loop. + /// + /// This processes events until a ShutdownRequested event is received. + /// Components must be started first via `start_components()`. + pub async fn run_event_loop(&mut self) -> Result<()> { + let mut event_rx = self.handle.subscribe(); + loop { + match event_rx.recv().await { + Ok(DaemonEvent::ShutdownRequested) => { + tracing::info!("shutdown requested, stopping daemon"); + break; + } + Ok(event) => { + tracing::debug!(?event, "processing event"); + self.dispatch_event(&event).await; + } + Err(broadcast::error::RecvError::Lagged(n)) => { + tracing::warn!( + skipped = n, + "event receiver lagged, some events were dropped" + ); + } + Err(broadcast::error::RecvError::Closed) => { + tracing::info!("event bus closed, stopping daemon"); + break; + } + } + } + Ok(()) + } + + /// Stop all components. + /// + /// This performs graceful shutdown of all components. + pub async fn stop_components(&mut self) { + for component in &mut self.components { + tracing::info!(component = component.name(), "stopping component"); + if let Err(e) = component.stop().await { + tracing::error!( + component = component.name(), + error = ?e, + "error stopping component" + ); + } + } + tracing::info!("all components stopped"); + } + + /// Run the daemon. + /// + /// This is a convenience method that starts components, runs the event loop, + /// and handles shutdown. It does not return until the daemon is shut down. + pub async fn run(mut self) -> Result<()> { + self.start_components().await?; + self.run_event_loop().await?; + self.stop_components().await; + tracing::info!("daemon stopped"); + Ok(()) + } + + async fn dispatch_event(&mut self, event: &DaemonEvent) { + for component in &mut self.components { + if let Err(e) = component.handle_event(event).await { + tracing::error!( + component = component.name(), + error = ?e, + "error handling event" + ); + } + } + } +} + +// ============================================================================ +// DaemonBuilder +// ============================================================================ + +/// Builder for constructing a [`Daemon`]. +/// +/// # Example +/// +/// ```ignore +/// let daemon = Daemon::builder(settings) +/// .store(store) +/// .history_db(history_db) +/// .component(HistoryComponent::new()) +/// .component(SearchComponent::new()) +/// .component(SyncComponent::new()) +/// .build() +/// .await?; +/// +/// daemon.run().await?; +/// ``` +pub struct DaemonBuilder { + settings: Settings, + store: Option, + history_db: Option, + components: Vec>, +} + +impl DaemonBuilder { + /// Create a new daemon builder with the given settings. + pub fn new(settings: Settings) -> Self { + Self { + settings, + store: None, + history_db: None, + components: Vec::new(), + } + } + + /// Set the record store. + pub fn store(mut self, store: SqliteStore) -> Self { + self.store = Some(store); + self + } + + /// Set the history database. + pub fn history_db(mut self, db: HistoryDatabase) -> Self { + self.history_db = Some(db); + self + } + + /// Register a component. + /// + /// Components are started in registration order and stopped in reverse order. + pub fn component(mut self, component: impl Component + 'static) -> Self { + self.components.push(Box::new(component)); + self + } + + /// Build the daemon. + /// + /// This loads the encryption key and creates the daemon state. + pub async fn build(self) -> Result { + let store = self.store.ok_or_else(|| eyre::eyre!("store is required"))?; + let history_db = self + .history_db + .ok_or_else(|| eyre::eyre!("history_db is required"))?; + + // Load encryption key + let encryption_key: [u8; 32] = encryption::load_key(&self.settings) + .context("could not load encryption key")? + .into(); + + // Create the event bus + let (event_tx, _) = broadcast::channel(64); + + // Create the shared state + let state = Arc::new(DaemonState { + event_tx, + settings: RwLock::new(self.settings), + encryption_key, + history_db, + store, + }); + + // Create the handle (just a reference to the state) + let handle = DaemonHandle { state }; + + Ok(Daemon { + components: self.components, + handle, + }) + } +} diff --git a/crates/turtle/src/atuin_daemon/events.rs b/crates/turtle/src/atuin_daemon/events.rs new file mode 100644 index 00000000..9a398925 --- /dev/null +++ b/crates/turtle/src/atuin_daemon/events.rs @@ -0,0 +1,74 @@ +//! Daemon events. +//! +//! Events are the primary communication mechanism within the daemon. +//! Components emit events to notify others of state changes, and handle +//! events to react to changes elsewhere in the system. +//! +//! External processes (like CLI commands) can also inject events via the +//! Control gRPC service. + +use crate::atuin_client::history::{History, HistoryId}; +use crate::atuin_common::record::RecordId; + +/// Events that flow through the daemon's event bus. +/// +/// Events are broadcast to all components. Each component decides which +/// events it cares about in its `handle_event` implementation. +#[derive(Debug, Clone)] +pub enum DaemonEvent { + // ---- History lifecycle ---- + /// A command has started running. + HistoryStarted(History), + + /// A command has finished running. + HistoryEnded(History), + + // ---- Sync ---- + /// Records were synced from the server. + /// + /// The search component uses this to update its index with new history. + RecordsAdded(Vec), + + /// Sync completed successfully. + SyncCompleted { + /// Number of records uploaded. + uploaded: usize, + /// Number of records downloaded. + downloaded: usize, + }, + + /// Sync failed. + SyncFailed { + /// Error message describing what went wrong. + error: String, + }, + + /// Request an immediate sync (external trigger). + ForceSync, + + // ---- External commands ---- + /// History was pruned - search index needs a full rebuild. + /// + /// Emitted when the user runs `atuin history prune` or similar. + HistoryPruned, + + /// History was rebuilt - search index needs a full rebuild. + /// + /// Emitted when the user runs `atuin store rebuild history` or similar. + HistoryRebuilt, + + /// Specific history items were deleted. + /// + /// The search component should remove these from its index. + HistoryDeleted { + /// IDs of the deleted history entries. + ids: Vec, + }, + + /// Settings have changed, components should reload if needed. + SettingsReloaded, + + // ---- Lifecycle ---- + /// Request graceful shutdown of the daemon. + ShutdownRequested, +} diff --git a/crates/turtle/src/atuin_daemon/history/mod.rs b/crates/turtle/src/atuin_daemon/history/mod.rs new file mode 100644 index 00000000..b71853df --- /dev/null +++ b/crates/turtle/src/atuin_daemon/history/mod.rs @@ -0,0 +1,6 @@ +//! History module for the daemon gRPC history service. +//! +//! This module contains the proto-generated types for the history gRPC service. + +// Include the generated proto code +tonic::include_proto!("history"); diff --git a/crates/turtle/src/atuin_daemon/mod.rs b/crates/turtle/src/atuin_daemon/mod.rs new file mode 100644 index 00000000..b05eb95c --- /dev/null +++ b/crates/turtle/src/atuin_daemon/mod.rs @@ -0,0 +1,128 @@ +use crate::atuin_client::database::Sqlite as HistoryDatabase; +use crate::atuin_client::record::sqlite_store::SqliteStore; +use crate::atuin_client::settings::{Settings, watcher::global_settings_watcher}; +use eyre::Result; + +pub mod client; +pub mod components; +pub mod control; +pub mod daemon; +pub mod events; +pub mod history; +pub mod search; +pub mod semantic; +pub mod server; + +// Re-export core daemon types for convenience +pub use daemon::{Component, Daemon, DaemonBuilder, DaemonHandle}; +pub use events::DaemonEvent; + +// Re-export components +pub use components::{HistoryComponent, SearchComponent, SemanticComponent, SyncComponent}; + +// Re-export client helpers +pub use client::{ControlClient, SemanticClient, emit_event, emit_event_with_settings}; + +/// Boot the daemon using the new component-based architecture. +/// +/// This creates a daemon with the standard components (history, search, sync), +/// starts the gRPC server with their services, and runs the event loop. +pub async fn boot( + settings: Settings, + store: SqliteStore, + history_db: HistoryDatabase, +) -> Result<()> { + // Create the components + let history_component = HistoryComponent::new(); + let search_component = SearchComponent::new(); + let semantic_component = SemanticComponent::new(); + let sync_component = SyncComponent::new(); + + // Get the gRPC services before moving components into the daemon + // (The services share state with the components via Arc) + let history_service = history_component.grpc_service(); + let search_service = search_component.grpc_service(); + let semantic_service = semantic_component.grpc_service(); + + // Build the daemon + let mut daemon = Daemon::builder(settings.clone()) + .store(store) + .history_db(history_db) + .component(history_component) + .component(search_component) + .component(semantic_component) + .component(sync_component) + .build() + .await?; + + // Get a handle for the control service and gRPC server shutdown + let handle = daemon.handle(); + + // Create the control service + let control_service = control::ControlService::new(handle.clone()); + + // Start all components first (so gRPC services can work) + daemon.start_components().await?; + + // Spawn config file watcher to reload settings on changes + if let Ok(watcher) = global_settings_watcher() { + let mut settings_rx = watcher.subscribe(); + let watcher_handle = handle.clone(); + tokio::spawn(async move { + tracing::info!("config file watcher started"); + while settings_rx.changed().await.is_ok() { + // Use the already-loaded settings from the watcher + // (avoids parsing the config file twice) + let new_settings = (*settings_rx.borrow()).clone(); + watcher_handle.apply_settings((*new_settings).clone()).await; + } + tracing::debug!("config file watcher stopped"); + }); + } else { + tracing::warn!( + "failed to start config file watcher; settings changes will require daemon restart" + ); + } + + // Spawn signal handler to emit ShutdownRequested on Ctrl+C/SIGTERM + let signal_handle = handle.clone(); + tokio::spawn(async move { + shutdown_signal().await; + tracing::info!("received shutdown signal"); + signal_handle.shutdown(); + }); + + // Start the gRPC server in the background + server::run_grpc_server( + settings, + history_service, + search_service, + semantic_service, + control_service.into_server(), + handle, + ) + .await?; + + // Run the daemon event loop + daemon.run_event_loop().await?; + + // Stop all components on shutdown + daemon.stop_components().await; + + tracing::info!("daemon shut down complete"); + Ok(()) +} + +/// Wait for a shutdown signal (Ctrl+C or SIGTERM). +#[cfg(unix)] +async fn shutdown_signal() { + let mut term = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) + .expect("failed to register sigterm handler"); + let mut int = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt()) + .expect("failed to register sigint handler"); + + tokio::select! { + _ = term.recv() => {}, + _ = int.recv() => {}, + } +} diff --git a/crates/turtle/src/atuin_daemon/search/index.rs b/crates/turtle/src/atuin_daemon/search/index.rs new file mode 100644 index 00000000..df627e1b --- /dev/null +++ b/crates/turtle/src/atuin_daemon/search/index.rs @@ -0,0 +1,684 @@ +//! Search index with frecency-based ranking. +//! +//! This module provides a deduplicated search index where each unique command +//! is stored once, with metadata about all its invocations. This enables: +//! +//! - Efficient fuzzy matching (fewer items to match) +//! - Frecency-based ranking (frequency + recency) +//! - Dynamic filtering by directory, host, session, etc. + +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; + +use crate::atuin_client::settings::Search; +use crate::{ + atuin_client::history::{History, is_known_agent}, + atuin_daemon::components::search::with_trailing_slash, +}; +use atuin_nucleo::{Injector, Nucleo, pattern}; +use dashmap::DashMap; +use lasso::{Spur, ThreadedRodeo}; +use time::OffsetDateTime; +use tokio::sync::RwLock; +use tracing::{Level, instrument}; +use uuid::Uuid; + +/// Parse a UUID string into a 16-byte array. +/// Returns None if the string is not a valid UUID. +fn parse_uuid_bytes(s: &str) -> Option<[u8; 16]> { + Uuid::parse_str(s).ok().map(|u| *u.as_bytes()) +} + +/// Format a 16-byte array as a UUID string. +fn format_uuid_bytes(bytes: &[u8; 16]) -> String { + Uuid::from_bytes(*bytes).to_string() +} + +/// Pre-computed frecency data for O(1) lookup. +#[derive(Debug, Clone, Default)] +pub struct FrecencyData { + /// Total number of times this command was used. + pub count: u32, + /// Most recent usage timestamp (unix seconds). + pub last_used: i64, +} + +impl FrecencyData { + /// Record a new usage of this command. + pub fn record_use(&mut self, timestamp: i64) { + self.count += 1; + if timestamp > self.last_used { + self.last_used = timestamp; + } + } + + /// Compute frecency score based on count and recency. + /// + /// Uses a decay function where more recent commands score higher. + /// The formula balances frequency (how often) with recency (how recent). + /// + /// Multipliers allow tuning the relative weights: + /// - `recency_mul`: Multiplier for recency score (default: 1.0) + /// - `frequency_mul`: Multiplier for frequency score (default: 1.0) + /// + /// A multiplier of 0.0 disables that component, 1.0 is unchanged, 2.0 doubles weight. + /// Values like 0.5 reduce weight by half, 1.5 increases by 50%, etc. + #[instrument(level = tracing::Level::TRACE, name = "index_frecency_compute")] + pub fn compute(&self, now: i64, recency_mul: f64, frequency_mul: f64) -> u32 { + if self.count == 0 { + return 0; + } + + // Time-based decay: score decreases as time passes + let age_seconds = (now - self.last_used).max(0) as u64; + let age_hours = age_seconds / 3600; + + // Decay factor: recent commands get higher scores + // - Last hour: multiplier ~1.0 + // - Last day: multiplier ~0.5 + // - Last week: multiplier ~0.1 + // - Older: multiplier approaches 0 + let recency_score: f64 = match age_hours { + 0 => 100.0, + 1..=6 => 90.0, + 7..=24 => 70.0, + 25..=72 => 50.0, + 73..=168 => 30.0, + 169..=720 => 15.0, + _ => 5.0, + }; + + // Frequency boost: more uses = higher score (with diminishing returns) + let frequency_score = ((self.count as f64).ln() * 20.0).min(100.0); + + // Apply multipliers and combine scores, then round to u32 + ((recency_score * recency_mul) + (frequency_score * frequency_mul)).round() as u32 + } +} + +/// Data for a unique command. +pub struct CommandData { + /// History ID of the most recent invocation (16-byte UUID). + most_recent_id: [u8; 16], + /// Timestamp of the most recent invocation. + most_recent_timestamp: i64, + /// Pre-computed global frecency. + pub global_frecency: FrecencyData, + + // Pre-computed indexes for O(1) filter lookups + // Using HashSet instead of DashSet since CommandData lives inside DashMap (already synchronized) + /// All directories where this command has been run (interned keys). + directories: HashSet, + /// All hostnames where this command has been run (interned keys). + hosts: HashSet, + /// All sessions where this command has been run (as 16-byte UUIDs). + sessions: HashSet<[u8; 16]>, +} + +impl CommandData { + /// Create a new CommandData from a history entry. + /// Returns None if the history entry has invalid UUIDs. + pub fn new(history: &History, interner: &ThreadedRodeo) -> Option { + let history_id = parse_uuid_bytes(&history.id.0)?; + let session = parse_uuid_bytes(&history.session)?; + let timestamp = history.timestamp.unix_timestamp(); + + let dir_key = interner.get_or_intern(with_trailing_slash(&history.cwd)); + let host_key = interner.get_or_intern(&history.hostname); + + let mut directories = HashSet::new(); + directories.insert(dir_key); + + let mut hosts = HashSet::new(); + hosts.insert(host_key); + + let mut sessions = HashSet::new(); + sessions.insert(session); + + let mut global_frecency = FrecencyData::default(); + global_frecency.record_use(timestamp); + + Some(Self { + most_recent_id: history_id, + most_recent_timestamp: timestamp, + global_frecency, + directories, + hosts, + sessions, + }) + } + + /// Add an invocation from a history entry. + /// Returns false if the history entry has invalid UUIDs. + pub fn add_invocation(&mut self, history: &History, interner: &ThreadedRodeo) -> bool { + let Some(history_id) = parse_uuid_bytes(&history.id.0) else { + return false; + }; + let Some(session) = parse_uuid_bytes(&history.session) else { + return false; + }; + + let timestamp = history.timestamp.unix_timestamp(); + + // Update global frecency + self.global_frecency.record_use(timestamp); + + // Update pre-computed indexes for O(1) filter lookups + let dir_key = interner.get_or_intern(with_trailing_slash(&history.cwd)); + self.directories.insert(dir_key); + self.hosts.insert(interner.get_or_intern(&history.hostname)); + self.sessions.insert(session); + + // Update most recent if this invocation is newer + if timestamp > self.most_recent_timestamp { + self.most_recent_id = history_id; + self.most_recent_timestamp = timestamp; + } + + true + } + + /// Get the most recent history ID for this command. + pub fn most_recent_id(&self) -> String { + format_uuid_bytes(&self.most_recent_id) + } + + /// Check if any invocation matches a directory filter (exact match). + /// O(1) lookup using pre-computed index. + pub fn has_invocation_in_dir(&self, dir: &str, interner: &ThreadedRodeo) -> bool { + interner + .get(dir) + .is_some_and(|spur| self.directories.contains(&spur)) + } + + /// Check if any invocation matches a directory prefix (workspace/git root). + /// O(n) where n = number of unique directories for this command. + pub fn has_invocation_in_workspace(&self, prefix: &str, interner: &ThreadedRodeo) -> bool { + self.directories + .iter() + .any(|&spur| interner.resolve(&spur).starts_with(prefix)) + } + + /// Check if any invocation matches a hostname. + /// O(1) lookup using pre-computed index. + pub fn has_invocation_on_host(&self, hostname: &str, interner: &ThreadedRodeo) -> bool { + interner + .get(hostname) + .is_some_and(|spur| self.hosts.contains(&spur)) + } + + /// Check if any invocation matches a session. + /// O(1) lookup using pre-computed index. + pub fn has_invocation_in_session(&self, session: &str) -> bool { + parse_uuid_bytes(session).is_some_and(|bytes| self.sessions.contains(&bytes)) + } +} + +/// Filter mode for search queries. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum IndexFilterMode { + /// No filtering - search all commands. + Global, + /// Filter to commands run in a specific directory. + Directory(String), + /// Filter to commands run in a workspace (directory prefix). + Workspace(String), + /// Filter to commands run on a specific host. + Host(String), + /// Filter to commands run in a specific session. + Session(String), +} + +/// Context for search queries. +#[derive(Debug, Clone, Default)] +pub struct QueryContext { + pub cwd: Option, + pub git_root: Option, + pub hostname: Option, + pub session_id: Option, +} + +/// Shareable frecency map: command -> frecency score. +/// Wrapped in Arc for zero-copy sharing with scorer callbacks. +type FrecencyMap = Arc, u32>>; + +/// A deduplicated search index with frecency-based ranking. +/// +/// Commands are stored by their text, with metadata about all invocations. +/// Nucleo handles fuzzy matching, while frecency is computed via scorer callback. +/// +/// Global frecency is precomputed by a background task and used for scoring. +/// If frecency data is not available, search still works but without frecency ranking; +/// although this should never happen due to precomputing the frecency map. +pub struct SearchIndex { + /// Map from command text to command data. + /// Using DashMap for concurrent read/write access, wrapped in Arc for sharing with scorer. + /// Keys are Arc to enable zero-copy sharing with frecency_map. + commands: Arc, CommandData>>, + /// Nucleo fuzzy matcher - items are command strings. + nucleo: RwLock>, + /// Injector for adding new commands to Nucleo. + injector: Injector, + /// Precomputed global frecency map. Updated by background task. + frecency_map: RwLock>, + /// String interner for deduplicating cwd, hostname, and directory paths. + interner: Arc, +} + +impl SearchIndex { + /// Create a new empty search index. + pub fn new() -> Self { + let nucleo_config = atuin_nucleo::Config::DEFAULT; + // Single column for command text + let nucleo = Nucleo::::new(nucleo_config, Arc::new(|| {}), None, 1); + let injector = nucleo.injector(); + + Self { + commands: Arc::new(DashMap::new()), + nucleo: RwLock::new(nucleo), + injector, + frecency_map: RwLock::new(None), + interner: Arc::new(ThreadedRodeo::new()), + } + } + + /// Add a history entry to the index. + /// + /// If the command already exists, updates its invocation data. + /// If it's a new command, adds it to both the map and Nucleo. + pub fn add_history(&self, history: &History) { + if is_known_agent(&history.author) { + return; + } + + let command = history.command.as_str(); + + // DashMap with Arc keys can be looked up with &str via Borrow trait + if let Some(mut entry) = self.commands.get_mut(command) { + // Existing command - just update invocations + entry.add_invocation(history, &self.interner); + } else { + // New command - create Arc once and share it + let Some(data) = CommandData::new(history, &self.interner) else { + return; // Invalid UUIDs, skip this entry + }; + let command_arc: Arc = command.into(); + self.commands.insert(Arc::clone(&command_arc), data); + // Nucleo still needs String (unavoidable copy for fuzzy matching) + self.injector.push(command_arc.to_string(), |cmd, cols| { + cols[0] = cmd.clone().into(); + }); + } + // Note: frecency_map is rebuilt by background task, not invalidated here + } + + /// Add multiple history entries to the index. + pub fn add_histories(&self, histories: &[History]) { + for history in histories { + self.add_history(history); + } + } + + /// Get the number of unique commands in the index. + pub fn command_count(&self) -> usize { + self.commands.len() + } + + /// Get the number of items in Nucleo (should match command_count). + pub async fn nucleo_item_count(&self) -> u32 { + self.nucleo.read().await.snapshot().item_count() + } + + /// Search for commands matching a query. + /// + /// Returns a list of history IDs (most recent invocation per command). + /// Uses precomputed global frecency for scoring if available. + #[instrument(skip_all, level = tracing::Level::TRACE, name = "index_search", fields(query = %query))] + pub async fn search( + &self, + query: &str, + filter_mode: IndexFilterMode, + _context: &QueryContext, + limit: u32, + ) -> Vec { + let mut nucleo = self.nucleo.write().await; + + // Get precomputed frecency map (may be None if not yet computed) + let frecency_map = self.frecency_map.read().await.clone(); + + // Build filter based on mode + let filter = self.build_filter(&filter_mode); + nucleo.set_filter(filter); + + // Build scorer from precomputed frecency (or None if not available) + let scorer = Self::build_scorer(frecency_map); + nucleo.set_scorer(scorer); + + // Update pattern + nucleo.pattern.reparse( + 0, + query, + pattern::CaseMatching::Smart, + pattern::Normalization::Smart, + false, + ); + + tracing::span!(Level::TRACE, "index_search_tick").in_scope(|| { + // Tick until complete + while nucleo.tick(10).running {} + }); + + // Collect results + let snapshot = nucleo.snapshot(); + let matched_count = snapshot.matched_item_count().min(limit); + + tracing::span!(Level::TRACE, "index_search_results").in_scope(|| { + snapshot + .matched_items(..matched_count) + .filter_map(|item| { + let cmd = item.data; + // DashMap, _>::get accepts &str via Borrow trait + self.commands + .get(cmd.as_str()) + .map(|data| data.most_recent_id()) + }) + .collect() + }) + } + + /// Rebuild the global frecency map. + /// + /// This should be called by a background task periodically. + /// The map is used for scoring search results. + /// + /// Uses multipliers from search settings: + /// - `recency_score_multiplier`: Weight for recency component + /// - `frequency_score_multiplier`: Weight for frequency component + /// - `frecency_score_multiplier`: Overall multiplier for final score + #[instrument(skip_all, level = tracing::Level::DEBUG, name = "rebuild_frecency")] + pub async fn rebuild_frecency(&self, search_settings: &Search) { + let now = OffsetDateTime::now_utc().unix_timestamp(); + let mut frecency_map: HashMap, u32> = HashMap::new(); + + // Clamp multipliers to non-negative values to prevent broken frecency ranking + // (negative values would produce unexpected results when cast to u32) + let recency_mul = search_settings.recency_score_multiplier.max(0.0); + let frequency_mul = search_settings.frequency_score_multiplier.max(0.0); + let frecency_mul = search_settings.frecency_score_multiplier.max(0.0); + + for entry in self.commands.iter() { + let frecency = entry + .global_frecency + .compute(now, recency_mul, frequency_mul); + // Apply overall frecency multiplier and round to u32 + let frecency = (frecency as f64 * frecency_mul).round() as u32; + // Arc::clone is cheap - just increments reference count + frecency_map.insert(Arc::clone(entry.key()), frecency); + } + + *self.frecency_map.write().await = Some(Arc::new(frecency_map)); + } + + /// Build filter predicate for the given mode. + fn build_filter(&self, mode: &IndexFilterMode) -> Option> { + // For Global mode, no filter needed + if matches!(mode, IndexFilterMode::Global) { + return None; + } + + // Pre-compute which commands pass the filter + // Use HashSet for the short-lived filter (simpler than Arc lookup) + let passing_commands: Arc> = { + let mut set = HashSet::new(); + for entry in self.commands.iter() { + let passes = match mode { + IndexFilterMode::Global => unreachable!(), + IndexFilterMode::Directory(dir) => { + entry.has_invocation_in_dir(dir, &self.interner) + } + IndexFilterMode::Workspace(prefix) => { + entry.has_invocation_in_workspace(prefix, &self.interner) + } + IndexFilterMode::Host(hostname) => { + entry.has_invocation_on_host(hostname, &self.interner) + } + IndexFilterMode::Session(session) => entry.has_invocation_in_session(session), + }; + if passes { + // Convert Arc to String for filter lookup + set.insert(entry.key().to_string()); + } + } + Arc::new(set) + }; + + Some(Arc::new(move |cmd: &String| passing_commands.contains(cmd))) + } + + /// Build scorer from precomputed frecency map. + /// + /// Returns None if frecency map is not available (search still works, just without frecency ranking). + fn build_scorer(frecency_map: Option) -> Option> { + let map = frecency_map?; + Some(Arc::new(move |cmd: &String, fuzzy_score: u32| { + // HashMap, _>::get accepts &str via Borrow trait + let frecency = map.get(cmd.as_str()).copied().unwrap_or(0); + fuzzy_score + frecency + })) + } +} + +impl Default for SearchIndex { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use time::macros::datetime; + + fn make_history(command: &str, cwd: &str, timestamp: OffsetDateTime) -> History { + History::import() + .timestamp(timestamp) + .command(command) + .cwd(cwd) + .build() + .into() + } + + #[test] + fn frecency_data_compute() { + let now = 1_000_000i64; + + // Recent command (with default multipliers of 1.0) + let recent = FrecencyData { + count: 5, + last_used: now - 60, // 1 minute ago + }; + assert!(recent.compute(now, 1.0, 1.0) > 100); // High score + + // Old command + let old = FrecencyData { + count: 5, + last_used: now - 86400 * 30, // 30 days ago + }; + assert!(old.compute(now, 1.0, 1.0) < recent.compute(now, 1.0, 1.0)); + + // Frequently used old command + let frequent_old = FrecencyData { + count: 100, + last_used: now - 86400 * 7, // 1 week ago + }; + // Should still have decent score due to frequency + assert!(frequent_old.compute(now, 1.0, 1.0) > 50); + } + + #[test] + fn frecency_data_compute_with_multipliers() { + let now = 1_000_000_i64; + + let data = FrecencyData { + count: 5, + last_used: now - 60, // 1 minute ago (recency_score = 100) + }; + + // Default multipliers (1.0, 1.0) + let default_score = data.compute(now, 1.0, 1.0); + + // Double recency weight + let double_recency = data.compute(now, 2.0, 1.0); + assert!(double_recency > default_score); + + // Double frequency weight + let double_frequency = data.compute(now, 1.0, 2.0); + assert!(double_frequency > default_score); + + // Zero out recency (only frequency counts) + let no_recency = data.compute(now, 0.0, 1.0); + assert!(no_recency < default_score); + + // Zero out frequency (only recency counts) + let no_frequency = data.compute(now, 1.0, 0.0); + assert!(no_frequency < default_score); + + // Zero both (should be zero) + let no_score = data.compute(now, 0.0, 0.0); + assert_eq!(no_score, 0); + + // Fractional multipliers + let half_recency = data.compute(now, 0.5, 1.0); + assert!(half_recency < default_score); + assert!(half_recency > no_recency); + + // 1.5x multiplier + let boost_recency = data.compute(now, 1.5, 1.0); + assert!(boost_recency > default_score); + assert!(boost_recency < double_recency); + } + + #[test] + fn command_data_add_invocation() { + let interner = ThreadedRodeo::new(); + + let (dir1, dir2) = if cfg!(windows) { + ("C:\\Users\\User\\project", "C:\\Users\\User\\other") + } else { + ("/home/user/project", "/home/user/other") + }; + + let history1 = make_history("git status", dir1, datetime!(2024-01-01 10:00 UTC)); + let history2 = make_history("git status", dir2, datetime!(2024-01-01 12:00 UTC)); + + let mut data = CommandData::new(&history1, &interner).unwrap(); + assert_eq!(data.global_frecency.count, 1); + let id1 = data.most_recent_id(); + + data.add_invocation(&history2, &interner); + assert_eq!(data.global_frecency.count, 2); + + // Most recent ID should update to history2 (newer timestamp) + let id2 = data.most_recent_id(); + assert_ne!(id1, id2); + } + + #[test] + fn command_data_filters() { + let interner = ThreadedRodeo::new(); + + let (dir1, dir2) = if cfg!(windows) { + ("C:\\Users\\User\\project", "C:\\Users\\User\\other") + } else { + ("/home/user/project", "/home/user/other") + }; + + let h1 = make_history("git status", dir1, datetime!(2024-01-01 10:00 UTC)); + let h2 = make_history("git status", dir2, datetime!(2024-01-01 12:00 UTC)); + + let mut data = CommandData::new(&h1, &interner).unwrap(); + data.add_invocation(&h2, &interner); + + let (check1, check2, check3) = if cfg!(windows) { + ( + with_trailing_slash("C:\\Users\\User\\project"), + with_trailing_slash("C:\\Users\\User\\other"), + with_trailing_slash("C:\\Users\\User\\missing"), + ) + } else { + ( + with_trailing_slash("/home/user/project"), + with_trailing_slash("/home/user/other"), + with_trailing_slash("/home/user/missing"), + ) + }; + + assert!(data.has_invocation_in_dir(&check1, &interner)); + assert!(data.has_invocation_in_dir(&check2, &interner)); + assert!(!data.has_invocation_in_dir(&check3, &interner)); + + let (check1, check2, check3) = if cfg!(windows) { + ( + with_trailing_slash("C:\\Users\\User"), + with_trailing_slash("C:\\Users"), + with_trailing_slash("C:\\Users\\User\\var"), + ) + } else { + ( + with_trailing_slash("/home/user"), + with_trailing_slash("/home"), + with_trailing_slash("/var"), + ) + }; + + assert!(data.has_invocation_in_workspace(&check1, &interner)); + assert!(data.has_invocation_in_workspace(&check2, &interner)); + assert!(!data.has_invocation_in_workspace(&check3, &interner)); + } + + #[tokio::test] + async fn search_index_add_and_search() { + let index = SearchIndex::new(); + + let h1 = make_history( + "git status", + "/home/user/project", + datetime!(2024-01-01 10:00 UTC), + ); + let h2 = make_history( + "git commit -m 'test'", + "/home/user/project", + datetime!(2024-01-01 10:05 UTC), + ); + let h3 = make_history( + "ls -la", + "/home/user/other", + datetime!(2024-01-01 10:10 UTC), + ); + + index.add_history(&h1); + index.add_history(&h2); + index.add_history(&h3); + + assert_eq!(index.command_count(), 3); + + // Search for "git" - should match 2 commands + let results = index + .search("git", IndexFilterMode::Global, &QueryContext::default(), 10) + .await; + assert_eq!(results.len(), 2); + + // Search with directory filter + let results = index + .search( + "", + IndexFilterMode::Directory(with_trailing_slash("/home/user/project")), + &QueryContext::default(), + 10, + ) + .await; + assert_eq!(results.len(), 2); // git status and git commit + } +} diff --git a/crates/turtle/src/atuin_daemon/search/mod.rs b/crates/turtle/src/atuin_daemon/search/mod.rs new file mode 100644 index 00000000..4d261956 --- /dev/null +++ b/crates/turtle/src/atuin_daemon/search/mod.rs @@ -0,0 +1,11 @@ +//! Search module for the daemon gRPC search service. +//! +//! This module provides fuzzy search over command history using Nucleo. + +mod index; + +// Include the generated proto code +tonic::include_proto!("search"); + +// Re-export the service and index +pub use index::{IndexFilterMode, QueryContext, SearchIndex}; diff --git a/crates/turtle/src/atuin_daemon/semantic/mod.rs b/crates/turtle/src/atuin_daemon/semantic/mod.rs new file mode 100644 index 00000000..c3511676 --- /dev/null +++ b/crates/turtle/src/atuin_daemon/semantic/mod.rs @@ -0,0 +1,3 @@ +//! Semantic command capture gRPC service types. + +tonic::include_proto!("semantic"); diff --git a/crates/turtle/src/atuin_daemon/server.rs b/crates/turtle/src/atuin_daemon/server.rs new file mode 100644 index 00000000..23b04342 --- /dev/null +++ b/crates/turtle/src/atuin_daemon/server.rs @@ -0,0 +1,115 @@ +use eyre::Result; + +use crate::atuin_daemon::components::history::HistoryGrpcService; +use crate::atuin_daemon::components::search::SearchGrpcService; +use crate::atuin_daemon::components::semantic::SemanticGrpcService; +use crate::atuin_daemon::control::{ControlService, control_server::ControlServer}; +use crate::atuin_daemon::daemon::DaemonHandle; +use crate::atuin_daemon::history::history_server::HistoryServer; +use crate::atuin_daemon::search::search_server::SearchServer; +use crate::atuin_daemon::semantic::semantic_server::SemanticServer; + +use crate::atuin_client::settings::Settings; + +/// Run the gRPC server with the given services. +/// +/// This starts the gRPC server in the background and returns immediately. +/// The server will shut down when a ShutdownRequested event is received. +#[cfg(unix)] +pub async fn run_grpc_server( + settings: Settings, + history_service: HistoryServer, + search_service: SearchServer, + semantic_service: SemanticServer, + control_service: ControlServer, + handle: DaemonHandle, +) -> Result<()> { + use tokio::net::UnixListener; + use tokio_stream::wrappers::UnixListenerStream; + + let socket_path = settings.daemon.socket_path.clone(); + + let (uds, cleanup) = if cfg!(target_os = "linux") && settings.daemon.systemd_socket { + #[cfg(target_os = "linux")] + { + use eyre::{OptionExt, WrapErr}; + use std::os::unix::net::SocketAddr; + use std::path::PathBuf; + tracing::info!("getting systemd socket"); + let listener = listenfd::ListenFd::from_env() + .take_unix_listener(0)? + .ok_or_eyre("missing systemd socket")?; + listener.set_nonblocking(true)?; + let actual_path: Result = listener + .local_addr() + .context("getting systemd socket's path") + .and_then(|addr: SocketAddr| { + addr.as_pathname() + .ok_or_eyre("systemd socket missing path") + .map(|path: &std::path::Path| path.to_owned()) + }); + match actual_path { + Ok(actual_path) => { + tracing::info!("listening on systemd socket: {actual_path:?}"); + if actual_path != std::path::Path::new(&socket_path) { + tracing::warn!( + "systemd socket is not at configured client path: {socket_path:?}" + ); + } + } + Err(err) => { + tracing::warn!( + "could not detect systemd socket path, ensure that it's at the configured path: {socket_path:?}, error: {err:?}" + ); + } + } + (UnixListener::from_std(listener)?, false) + } + } else { + tracing::info!("listening on unix socket {socket_path:?}"); + (UnixListener::bind(socket_path.clone())?, true) + }; + + let uds_stream = UnixListenerStream::new(uds); + + // Create shutdown signal from daemon handle + let shutdown_signal = async move { + let mut rx = handle.subscribe(); + loop { + use crate::atuin_daemon::DaemonEvent; + + match rx.recv().await { + Ok(DaemonEvent::ShutdownRequested) => break, + Ok(_) => continue, + Err(_) => break, // Channel closed + } + } + if cleanup { + eprintln!("Removing socket..."); + if let Err(e) = std::fs::remove_file(&socket_path) + && e.kind() != std::io::ErrorKind::NotFound + { + eprintln!("failed to remove socket: {e}"); + } + } + eprintln!("Shutting down gRPC server..."); + }; + + // Spawn the server in the background + tokio::spawn(async move { + use tonic::transport::Server; + + if let Err(e) = Server::builder() + .add_service(history_service) + .add_service(search_service) + .add_service(semantic_service) + .add_service(control_service) + .serve_with_incoming_shutdown(uds_stream, shutdown_signal) + .await + { + tracing::error!("gRPC server error: {e}"); + } + }); + + Ok(()) +} diff --git a/crates/turtle/src/atuin_history/mod.rs b/crates/turtle/src/atuin_history/mod.rs new file mode 100644 index 00000000..e7b33916 --- /dev/null +++ b/crates/turtle/src/atuin_history/mod.rs @@ -0,0 +1,2 @@ +pub mod sort; +pub mod stats; diff --git a/crates/turtle/src/atuin_history/sort.rs b/crates/turtle/src/atuin_history/sort.rs new file mode 100644 index 00000000..b162c810 --- /dev/null +++ b/crates/turtle/src/atuin_history/sort.rs @@ -0,0 +1,46 @@ +use crate::atuin_client::history::History; + +type ScoredHistory = (f64, History); + +// Fuzzy search already comes sorted by minspan +// This sorting should be applicable to all search modes, and solve the more "obvious" issues +// first. +// Later on, we can pass in context and do some boosts there too. +pub fn sort(query: &str, input: Vec) -> Vec { + // This can totally be extended. We need to be _careful_ that it's not slow. + // We also need to balance sorting db-side with sorting here. SQLite can do a lot, + // but some things are just much easier/more doable in Rust. + + let mut scored = input + .into_iter() + .map(|h| { + // If history is _prefixed_ with the query, score it more highly + let score = if h.command.starts_with(query) { + 2.0 + } else if h.command.contains(query) { + 1.75 + } else { + 1.0 + }; + + // calculate how long ago the history was, in seconds + let now = time::OffsetDateTime::now_utc().unix_timestamp(); + let time = h.timestamp.unix_timestamp(); + let diff = std::cmp::max(1, now - time); // no /0 please + + // prefer newer history, but not hugely so as to offset the other scoring + // the numbers will get super small over time, but I don't want time to overpower other + // scoring + #[expect(clippy::cast_precision_loss)] + let time_score = 1.0 + (1.0 / diff as f64); + let score = score * time_score; + + (score, h) + }) + .collect::>(); + + scored.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap().reverse()); + + // Remove the scores and return the history + scored.into_iter().map(|(_, h)| h).collect::>() +} diff --git a/crates/turtle/src/atuin_history/stats.rs b/crates/turtle/src/atuin_history/stats.rs new file mode 100644 index 00000000..e47d6c8e --- /dev/null +++ b/crates/turtle/src/atuin_history/stats.rs @@ -0,0 +1,548 @@ +use std::collections::{HashMap, HashSet}; + +use crossterm::style::{Color, ResetColor, SetAttribute, SetForegroundColor}; +use serde::{Deserialize, Serialize}; +use unicode_segmentation::UnicodeSegmentation; + +use crate::atuin_client::{history::History, settings::Settings, theme::Meaning, theme::Theme}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Stats { + pub total_commands: usize, + pub unique_commands: usize, + pub top: Vec<(Vec, usize)>, +} + +fn first_non_whitespace(s: &str) -> Option { + s.char_indices() + // find the first non whitespace char + .find(|(_, c)| !c.is_ascii_whitespace()) + // return the index of that char + .map(|(i, _)| i) +} + +fn first_whitespace(s: &str) -> usize { + s.char_indices() + // find the first whitespace char + .find(|(_, c)| c.is_ascii_whitespace()) + // return the index of that char, (or the max length of the string) + .map_or(s.len(), |(i, _)| i) +} + +fn interesting_command<'a>(settings: &Settings, mut command: &'a str) -> &'a str { + // Sort by length so that we match the longest prefix first + let mut common_prefix = settings.stats.common_prefix.clone(); + common_prefix.sort_by_key(|b| std::cmp::Reverse(b.len())); + + // Trim off the common prefix, if it exists + for p in &common_prefix { + if command.starts_with(p) { + let i = p.len(); + let prefix = &command[..i]; + command = command[i..].trim_start(); + if command.is_empty() { + // no commands following, just use the prefix + return prefix; + } + break; + } + } + + // Sort the common_subcommands by length so that we match the longest subcommand first + let mut common_subcommands = settings.stats.common_subcommands.clone(); + common_subcommands.sort_by_key(|b| std::cmp::Reverse(b.len())); + + // Check for a common subcommand + for p in &common_subcommands { + if command.starts_with(p) { + // if the subcommand is the same length as the command, then we just use the subcommand + if p.len() == command.len() { + return command; + } + // otherwise we need to use the subcommand + the next word + let non_whitespace = first_non_whitespace(&command[p.len()..]).unwrap_or(0); + let j = + p.len() + non_whitespace + first_whitespace(&command[p.len() + non_whitespace..]); + return &command[..j]; + } + } + // Return the first word if there is no subcommand + &command[..first_whitespace(command)] +} + +fn split_at_pipe(command: &str) -> Vec<&str> { + let mut result = vec![]; + let mut quoted = false; + let mut start = 0; + let mut graphemes = UnicodeSegmentation::grapheme_indices(command, true); + + while let Some((i, c)) = graphemes.next() { + let current = i; + match c { + "\"" if command[start..current] != *"\"" => { + quoted = !quoted; + } + "'" if command[start..current] != *"'" => { + quoted = !quoted; + } + "\\" if graphemes.next().is_some() => {} + "|" if !quoted => { + if current > start && command[start..].starts_with('|') { + start += 1; + } + result.push(&command[start..current]); + start = current; + } + _ => {} + } + } + if command[start..].starts_with('|') { + start += 1; + } + result.push(&command[start..]); + result +} + +fn strip_leading_env_vars(command: &str) -> &str { + // fast path: no equals sign, no environment variable + if !command.contains('=') { + return command; + } + + let mut in_token = false; + let mut token_start_pos = 0; + let mut in_single_quotes = false; + let mut in_double_quotes = false; + let mut escape_next = false; + let mut has_equals_outside_quotes = false; + + for (i, g) in UnicodeSegmentation::grapheme_indices(command, true) { + if escape_next { + escape_next = false; + continue; + } + + if !in_token { + token_start_pos = i; + } + + match g { + "\\" => { + escape_next = true; + in_token = true; + } + "'" if !in_double_quotes => { + in_single_quotes = !in_single_quotes; + in_token = true; + } + "\"" if !in_single_quotes => { + in_double_quotes = !in_double_quotes; + in_token = true; + } + "=" if !in_single_quotes && !in_double_quotes => { + has_equals_outside_quotes = true; + in_token = true; + } + " " | "\t" if !in_single_quotes && !in_double_quotes => { + if in_token { + if !has_equals_outside_quotes { + // if we're not in an env var, we can break early + break; + } + in_token = false; + has_equals_outside_quotes = false; + } + } + _ => { + in_token = true; + } + } + } + + command[token_start_pos..].trim() +} + +pub fn pretty_print(stats: Stats, ngram_size: usize, theme: &Theme) { + let max = stats.top.iter().map(|x| x.1).max().unwrap(); + let num_pad = max.ilog10() as usize + 1; + + // Find the length of the longest command name for each column + let column_widths = stats + .top + .iter() + .map(|(commands, _)| commands.iter().map(|c| c.len()).collect::>()) + .fold(vec![0; ngram_size], |acc, item| { + acc.iter() + .zip(item.iter()) + .map(|(a, i)| *std::cmp::max(a, i)) + .collect() + }); + + for (command, count) in stats.top { + let gray = SetForegroundColor(match theme.as_style(Meaning::Muted).foreground_color { + Some(color) => color, + None => Color::Grey, + }); + let bold = SetAttribute(crossterm::style::Attribute::Bold); + + let in_ten = 10 * count / max; + + print!("["); + print!( + "{}", + SetForegroundColor(match theme.get_error().foreground_color { + Some(color) => color, + None => Color::Red, + }) + ); + + for i in 0..in_ten { + if i == 2 { + print!( + "{}", + SetForegroundColor(match theme.get_warning().foreground_color { + Some(color) => color, + None => Color::Yellow, + }) + ); + } + + if i == 5 { + print!( + "{}", + SetForegroundColor(match theme.get_info().foreground_color { + Some(color) => color, + None => Color::Green, + }) + ); + } + + print!("▮"); + } + + for _ in in_ten..10 { + print!(" "); + } + + let formatted_command = command + .iter() + .zip(column_widths.iter()) + .map(|(cmd, width)| format!("{cmd:width$}")) + .collect::>() + .join(" | "); + + println!( + "{ResetColor}] {gray}{count:num_pad$}{ResetColor} {bold}{formatted_command}{ResetColor}" + ); + } + println!("Total commands: {}", stats.total_commands); + println!("Unique commands: {}", stats.unique_commands); +} + +pub fn compute( + settings: &Settings, + history: &[History], + count: usize, + ngram_size: usize, +) -> Option { + let mut commands = HashSet::<&str>::with_capacity(history.len()); + let mut total_unignored = 0; + let mut prefixes = HashMap::, usize>::with_capacity(history.len()); + + for i in history { + // just in case it somehow has a leading tab or space or something (legacy atuin didn't ignore space prefixes) + let command = strip_leading_env_vars(i.command.trim()); + let prefix = interesting_command(settings, command); + + if settings.stats.ignored_commands.iter().any(|c| c == prefix) { + continue; + } + + total_unignored += 1; + commands.insert(command); + + split_at_pipe(command) + .iter() + .map(|l| { + let command = l.trim(); + commands.insert(command); + command + }) + .collect::>() + .windows(ngram_size) + .for_each(|w| { + *prefixes + .entry(w.iter().map(|c| interesting_command(settings, c)).collect()) + .or_default() += 1; + }); + } + + let unique = commands.len(); + let mut top = prefixes.into_iter().collect::>(); + + top.sort_unstable_by_key(|x| std::cmp::Reverse(x.1)); + top.truncate(count); + + if top.is_empty() { + return None; + } + + Some(Stats { + unique_commands: unique, + total_commands: total_unignored, + top: top + .into_iter() + .map(|t| (t.0.into_iter().map(|s| s.to_string()).collect(), t.1)) + .collect(), + }) +} + +#[cfg(test)] +mod tests { + use crate::atuin_client::history::History; + use crate::atuin_client::settings::Settings; + use time::OffsetDateTime; + + use super::compute; + use super::{interesting_command, split_at_pipe, strip_leading_env_vars}; + + #[test] + fn ignored_env_vars() { + let settings = Settings::utc(); + + let history: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("FOO='BAR=🚀' echo foo") + .cwd("/") + .build() + .into(); + + let stats = compute(&settings, &[history], 10, 1).expect("failed to compute stats"); + assert_eq!(stats.top.first().unwrap().0, vec!["echo"]); + } + + #[test] + fn ignored_commands() { + let mut settings = Settings::utc(); + settings.stats.ignored_commands.push("cd".to_string()); + + let history = [ + History::import() + .timestamp(OffsetDateTime::now_utc()) + .command("cd foo") + .build() + .into(), + History::import() + .timestamp(OffsetDateTime::now_utc()) + .command("cargo build stuff") + .build() + .into(), + ]; + + let stats = compute(&settings, &history, 10, 1).expect("failed to compute stats"); + assert_eq!(stats.total_commands, 1); + assert_eq!(stats.unique_commands, 1); + } + + #[test] + fn interesting_commands() { + let settings = Settings::utc(); + + assert_eq!(interesting_command(&settings, "cargo"), "cargo"); + assert_eq!( + interesting_command(&settings, "cargo build foo bar"), + "cargo build" + ); + assert_eq!( + interesting_command(&settings, "sudo cargo build foo bar"), + "cargo build" + ); + assert_eq!(interesting_command(&settings, "sudo"), "sudo"); + } + + // Test with spaces in the common_prefix + #[test] + fn interesting_commands_spaces() { + let mut settings = Settings::utc(); + settings.stats.common_prefix.push("sudo test".to_string()); + + assert_eq!(interesting_command(&settings, "sudo test"), "sudo test"); + assert_eq!(interesting_command(&settings, "sudo test "), "sudo test"); + assert_eq!(interesting_command(&settings, "sudo test foo bar"), "foo"); + assert_eq!( + interesting_command(&settings, "sudo test foo bar"), + "foo" + ); + + // Works with a common_subcommand as well + assert_eq!( + interesting_command(&settings, "sudo test cargo build foo bar"), + "cargo build" + ); + + // We still match on just the sudo prefix + assert_eq!(interesting_command(&settings, "sudo"), "sudo"); + assert_eq!(interesting_command(&settings, "sudo foo"), "foo"); + } + + // Test with spaces in the common_subcommand + #[test] + fn interesting_commands_spaces_subcommand() { + let mut settings = Settings::utc(); + settings + .stats + .common_subcommands + .push("cargo build".to_string()); + + assert_eq!(interesting_command(&settings, "cargo build"), "cargo build"); + assert_eq!( + interesting_command(&settings, "cargo build "), + "cargo build" + ); + assert_eq!( + interesting_command(&settings, "cargo build foo bar"), + "cargo build foo" + ); + + // Works with a common_prefix as well + assert_eq!( + interesting_command(&settings, "sudo cargo build foo bar"), + "cargo build foo" + ); + + // We still match on just cargo as a subcommand + assert_eq!(interesting_command(&settings, "cargo"), "cargo"); + assert_eq!(interesting_command(&settings, "cargo foo"), "cargo foo"); + } + + // Test with spaces in the common_prefix and common_subcommand + #[test] + fn interesting_commands_spaces_both() { + let mut settings = Settings::utc(); + settings.stats.common_prefix.push("sudo test".to_string()); + settings + .stats + .common_subcommands + .push("cargo build".to_string()); + + assert_eq!( + interesting_command(&settings, "sudo test cargo build"), + "cargo build" + ); + assert_eq!( + interesting_command(&settings, "sudo test cargo build"), + "cargo build" + ); + assert_eq!( + interesting_command(&settings, "sudo test cargo build "), + "cargo build" + ); + assert_eq!( + interesting_command(&settings, "sudo test cargo build foo bar"), + "cargo build foo" + ); + } + + #[test] + fn split_simple() { + assert_eq!(split_at_pipe("fd | rg"), ["fd ", " rg"]); + } + + #[test] + fn split_multi() { + assert_eq!( + split_at_pipe("kubectl | jq | rg"), + ["kubectl ", " jq ", " rg"] + ); + } + + #[test] + fn split_simple_quoted() { + assert_eq!( + split_at_pipe("foo | bar 'baz {} | quux' | xyzzy"), + ["foo ", " bar 'baz {} | quux' ", " xyzzy"] + ); + } + + #[test] + fn split_multi_quoted() { + assert_eq!( + split_at_pipe("foo | bar 'baz \"{}\" | quux' | xyzzy"), + ["foo ", " bar 'baz \"{}\" | quux' ", " xyzzy"] + ); + } + + #[test] + fn escaped_pipes() { + assert_eq!( + split_at_pipe("foo | bar baz \\| quux"), + ["foo ", " bar baz \\| quux"] + ); + } + + #[test] + fn emoji() { + assert_eq!( + split_at_pipe("git commit -m \"🚀\""), + ["git commit -m \"🚀\""] + ); + } + + #[test] + fn starts_with_pipe() { + assert_eq!( + split_at_pipe("| sed 's/[0-9a-f]//g'"), + ["", " sed 's/[0-9a-f]//g'"] + ); + } + + #[test] + fn starts_with_spaces_and_pipe() { + assert_eq!( + split_at_pipe(" | sed 's/[0-9a-f]//g'"), + [" ", " sed 's/[0-9a-f]//g'"] + ); + } + + #[test] + fn strip_leading_env_vars_simple() { + assert_eq!( + strip_leading_env_vars("FOO=bar BAZ=quux echo foo"), + "echo foo" + ); + } + + #[test] + fn strip_leading_env_vars_quoted_single() { + assert_eq!(strip_leading_env_vars("FOO='BAR=baz' echo foo"), "echo foo"); + } + + #[test] + fn strip_leading_env_vars_quoted_double() { + assert_eq!( + strip_leading_env_vars("FOO=\"BAR=baz\" echo foo"), + "echo foo" + ); + } + + #[test] + fn strip_leading_env_vars_quoted_single_and_double() { + assert_eq!( + strip_leading_env_vars("FOO='BAR=\"baz\"' echo foo \"BAR=quux\""), + "echo foo \"BAR=quux\"" + ); + } + + #[test] + fn strip_leading_env_vars_emojis() { + assert_eq!( + strip_leading_env_vars("FOO='BAR=🚀' echo foo \"BAR=quux\" foo"), + "echo foo \"BAR=quux\" foo" + ); + } + + #[test] + fn strip_leading_env_vars_name_same_as_command() { + assert_eq!(strip_leading_env_vars("FOO='bar' bar baz"), "bar baz"); + } +} diff --git a/crates/turtle/src/atuin_pty_proxy/capture.rs b/crates/turtle/src/atuin_pty_proxy/capture.rs new file mode 100644 index 00000000..97ac9b8f --- /dev/null +++ b/crates/turtle/src/atuin_pty_proxy/capture.rs @@ -0,0 +1,467 @@ +use std::sync::Arc; +use std::sync::atomic::{AtomicU16, Ordering}; + +use crate::atuin_pty_proxy::osc133::{Event, Params, Parser, Zone}; + +const HISTORY_ID_PARAM: &str = "history_id"; +const SESSION_ID_PARAM: &str = "session_id"; +const MAX_OUTPUT_CAPTURE_BYTES: usize = 1024 * 1024; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CommandCapture { + pub prompt: String, + pub command: String, + pub output: String, + pub exit_code: Option, + pub history_id: Option, + pub session_id: Option, + pub output_truncated: bool, + pub output_observed_bytes: u64, +} + +pub type CommandCaptureSink = Box; + +#[derive(Default)] +struct CaptureBuffers { + prompt: Vec, + command: Vec, + output: Vec, + output_observed_bytes: u64, + output_truncated: bool, + exit_code: Option, + history_id: Option, + session_id: Option, +} + +pub(crate) struct CommandCaptureTracker { + parser: Parser, + zone: Zone, + buffers: CaptureBuffers, + cols: Arc, +} + +impl CommandCaptureTracker { + pub(crate) fn new(cols: Arc) -> Self { + Self { + parser: Parser::new(), + zone: Zone::Unknown, + buffers: CaptureBuffers::default(), + cols, + } + } + + pub(crate) fn push(&mut self, data: &[u8], mut on_capture: impl FnMut(CommandCapture)) { + let mut events = Vec::new(); + self.parser + .push_located(data, |located| events.push(located)); + + let mut start = 0; + for located in events { + let marker_start = located.start_offset.min(data.len()).max(start); + let offset = located.offset.min(data.len()); + self.append(&data[start..marker_start]); + self.handle_event(located.event, &located.params, &mut on_capture); + self.zone = located.zone; + start = offset; + } + + let append_end = self + .parser + .incomplete_osc_sequence_start() + .map_or(data.len(), |sequence_start| { + sequence_start.min(data.len()).max(start) + }); + if start < append_end { + self.append(&data[start..append_end]); + } + } + + fn append(&mut self, data: &[u8]) { + match self.zone { + Zone::Prompt => self.buffers.prompt.extend_from_slice(data), + Zone::Input => self.buffers.command.extend_from_slice(data), + Zone::Output => self.append_output(data), + Zone::Unknown => {} + } + } + + fn append_output(&mut self, data: &[u8]) { + self.buffers.output_observed_bytes = self + .buffers + .output_observed_bytes + .saturating_add(data.len() as u64); + + if self.buffers.output_truncated { + return; + } + + let remaining = MAX_OUTPUT_CAPTURE_BYTES.saturating_sub(self.buffers.output.len()); + let retained = data.len().min(remaining); + self.buffers.output_truncated = retained < data.len(); + + if retained > 0 { + self.buffers.output.extend_from_slice(&data[..retained]); + } + } + + fn handle_event( + &mut self, + event: Event, + params: &Params, + on_capture: &mut impl FnMut(CommandCapture), + ) { + match event { + Event::PromptStart => { + if self.zone != Zone::Prompt { + self.buffers = CaptureBuffers::default(); + } + } + Event::CommandStart | Event::CommandExecuted => {} + Event::CommandFinished { exit_code } => { + let Some(history_id) = params.get(HISTORY_ID_PARAM).map(str::to_owned) else { + return; + }; + + if exit_code.is_some() || self.buffers.exit_code.is_none() { + self.buffers.exit_code = exit_code; + } + self.buffers.history_id = Some(history_id); + self.buffers.session_id = params.get(SESSION_ID_PARAM).map(str::to_owned); + + if let Some(capture) = self.finish_capture() { + on_capture(capture); + } + } + } + } + + fn finish_capture(&mut self) -> Option { + let buffers = std::mem::take(&mut self.buffers); + let cols = self.cols.load(Ordering::Relaxed).max(1); + let prompt = render_plain_text(&buffers.prompt, cols); + let command = render_plain_text(&buffers.command, cols) + .trim_matches(|c| c == '\r' || c == '\n') + .to_string(); + let output = render_plain_text(&buffers.output, cols); + let output_truncated = buffers.output_truncated; + let output_observed_bytes = buffers.output_observed_bytes; + let exit_code = buffers.exit_code; + let history_id = buffers.history_id; + let session_id = buffers.session_id; + + if command.is_empty() && output.is_empty() { + return None; + } + + Some(CommandCapture { + prompt, + command, + output, + exit_code, + history_id, + session_id, + output_truncated, + output_observed_bytes, + }) + } +} + +const CLEAN_TEXT_MAX_ROWS: usize = 10_000; + +fn render_plain_text(bytes: &[u8], cols: u16) -> String { + if bytes.is_empty() { + return String::new(); + } + + let cols = cols.max(1); + let mut parser = vt100::Parser::new(estimated_rows(bytes, cols), cols, 0); + parser.process(bytes); + normalize_screen_contents(&parser.screen().contents()) +} + +fn normalize_screen_contents(contents: &str) -> String { + let mut lines = contents.lines().map(str::trim_end).collect::>(); + while lines.last().is_some_and(|line| line.is_empty()) { + lines.pop(); + } + lines.join("\n") +} + +fn estimated_rows(bytes: &[u8], cols: u16) -> u16 { + let newline_rows = bytes.iter().filter(|byte| **byte == b'\n').count() + 1; + let wrapped_rows = bytes.len() / cols as usize; + newline_rows + .saturating_add(wrapped_rows) + .saturating_add(1) + .clamp(1, CLEAN_TEXT_MAX_ROWS) as u16 +} + +#[cfg(test)] +mod tests { + use super::*; + + fn tracker(cols: u16) -> CommandCaptureTracker { + CommandCaptureTracker::new(Arc::new(AtomicU16::new(cols))) + } + + fn assert_no_terminal_controls(text: &str) { + assert!( + !text + .chars() + .any(|ch| ch.is_control() && ch != '\n' && ch != '\t'), + "text still contains terminal controls: {text:?}" + ); + } + + #[test] + fn command_text_collapses_terminal_echo_edits() { + assert_eq!(render_plain_text(b"e\x08echo hi", 80), "echo hi"); + assert_eq!( + render_plain_text( + b"e\x08echo\x08 \x08\x08 \x08\x08\x08e \x08\x08 \x08e\x08echo hi", + 80 + ), + "echo hi" + ); + assert_eq!(render_plain_text(b"echo hi", 80), "echo hi"); + } + + #[test] + fn text_cleaning_strips_ansi_and_terminal_controls() { + let text = render_plain_text( + b"\x1b[32mhi\x1b[0m\r\n% \r \r", + 80, + ); + + assert_eq!(text, "hi"); + assert_no_terminal_controls(&text); + } + + #[test] + fn text_cleaning_preserves_valid_utf8_after_backspace() { + let text = render_plain_text("🦀x\x08 \x08 crab".as_bytes(), 80); + + assert_eq!(text, "🦀 crab"); + assert_no_terminal_controls(&text); + } + + #[test] + fn command_text_replays_backspaces() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + let input = + b"\x1b]133;A\x07$ \x1b]133;B\x07e\x08echo hi\r\n\x1b]133;C\x07hi\r\n\x1b]133;D;0;history_id=hist;session_id=sess\x07\x1b]133;A\x07$ "; + tracker.push(input, |capture| captures.push(capture)); + + assert_eq!(captures.len(), 1); + assert_eq!(captures[0].command, "echo hi"); + assert_eq!(captures[0].output, "hi"); + assert_no_terminal_controls(&captures[0].command); + assert_no_terminal_controls(&captures[0].output); + } + + #[test] + fn captures_complete_command() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;A\x07$ \x1b]133;B\x07echo hi\r\n\x1b]133;C\x07hi\r\n\x1b]133;D;0;history_id=hist;session_id=sess\x07\x1b]133;A\x07$ ", + |capture| captures.push(capture), + ); + + assert_eq!( + captures, + vec![CommandCapture { + prompt: "$".to_string(), + command: "echo hi".to_string(), + output: "hi".to_string(), + exit_code: Some(0), + history_id: Some("hist".to_string()), + session_id: Some("sess".to_string()), + output_truncated: false, + output_observed_bytes: 4, + }] + ); + } + + #[test] + fn strips_ansi_and_split_markers() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push(b"\x1b]133;A\x07\x1b[32m%\x1b[0m ", |_| {}); + tracker.push(b"\x1b]133;B\x07ls\x1b]133;C", |_| {}); + tracker.push( + b"\x07\x1b[31mfile\x1b[0m\r\n\x1b]133;D;1;history_id=hist;session_id=sess\x07\x1b]133;A\x07% ", + |capture| { + captures.push(capture); + }, + ); + + assert_eq!( + captures, + vec![CommandCapture { + prompt: "%".to_string(), + command: "ls".to_string(), + output: "file".to_string(), + exit_code: Some(1), + history_id: Some("hist".to_string()), + session_id: Some("sess".to_string()), + output_truncated: false, + output_observed_bytes: 15, + }] + ); + } + + #[test] + fn duplicate_prompt_start_does_not_reset_prompt_capture() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;A\x07$ \x1b]133;A\x07continued \x1b]133;B\x07echo hi\r\n\x1b]133;C\x07hi\r\n\x1b]133;D;0;history_id=hist;session_id=sess\x07\x1b]133;A\x07$ ", + |capture| captures.push(capture), + ); + + assert_eq!(captures.len(), 1); + assert_eq!(captures[0].prompt, "$ continued"); + assert_eq!(captures[0].command, "echo hi"); + assert_eq!(captures[0].output, "hi"); + } + + #[test] + fn bare_finish_without_metadata_is_ignored() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push(b"\x1b]133;C\x07line one\r\n\x1b]133;D;0\x07", |capture| { + captures.push(capture); + }); + + tracker.push(b"\x1b]133;A\x07$ ", |capture| captures.push(capture)); + + assert!(captures.is_empty()); + } + + #[test] + fn bare_finish_before_metadata_in_same_push_ignored() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;C\x07line one\r\n\x1b]133;D;1\x07\x1b]133;D;0;history_id=018f;session_id=abcd\x07", + |capture| captures.push(capture), + ); + + assert_eq!(captures.len(), 1); + assert_eq!(captures[0].output, "line one"); + assert_eq!(captures[0].exit_code, Some(0)); + assert_eq!(captures[0].history_id.as_deref(), Some("018f")); + assert_eq!(captures[0].session_id.as_deref(), Some("abcd")); + } + + #[test] + fn metadata_arriving_after_bare_finish_across_pushes() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push(b"\x1b]133;C\x07line one\r\n\x1b]133;D;0\x07", |capture| { + captures.push(capture); + }); + tracker.push(b"\x1b]133;D;0;history_id=018f", |capture| { + captures.push(capture) + }); + + assert!(captures.is_empty()); + + tracker.push(b";session_id=abcd\x07", |capture| captures.push(capture)); + + assert_eq!(captures.len(), 1); + assert_eq!(captures[0].output, "line one"); + assert_eq!(captures[0].exit_code, Some(0)); + assert_eq!(captures[0].history_id.as_deref(), Some("018f")); + assert_eq!(captures[0].session_id.as_deref(), Some("abcd")); + } + + #[test] + fn split_finish_marker_is_not_counted_as_output() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;C\x07line one\r\n\x1b]133;D;0;history_id=018f", + |capture| { + captures.push(capture); + }, + ); + assert!(captures.is_empty()); + + tracker.push(b";session_id=abcd\x07", |capture| captures.push(capture)); + + assert_eq!(captures.len(), 1); + assert_eq!(captures[0].output, "line one"); + assert_eq!(captures[0].output_observed_bytes, 10); + } + + #[test] + fn captures_output_with_history_metadata_from_d_marker() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;C\x07line one\r\n\x1b]133;D;0;history_id=018f;session_id=abcd\x07", + |capture| captures.push(capture), + ); + + assert_eq!( + captures, + vec![CommandCapture { + prompt: String::new(), + command: String::new(), + output: "line one".to_string(), + exit_code: Some(0), + history_id: Some("018f".to_string()), + session_id: Some("abcd".to_string()), + output_truncated: false, + output_observed_bytes: 10, + }] + ); + } + + #[test] + fn output_capture_is_capped_and_reports_observed_bytes() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + let mut input = b"\x1b]133;C\x07".to_vec(); + input.extend(std::iter::repeat_n(b'x', MAX_OUTPUT_CAPTURE_BYTES + 10)); + input.extend_from_slice(b"\x1b]133;D;0;history_id=big;session_id=session-1\x07"); + + tracker.push(&input, |capture| captures.push(capture)); + + assert_eq!(captures.len(), 1); + assert!(captures[0].output_truncated); + assert_eq!( + captures[0].output_observed_bytes, + (MAX_OUTPUT_CAPTURE_BYTES + 10) as u64 + ); + } + + #[test] + fn resets_buffers_between_c_d_only_captures() { + let mut tracker = tracker(80); + let mut captures = Vec::new(); + + tracker.push( + b"\x1b]133;C\x07first\r\n\x1b]133;D;0;history_id=one\x07\x1b]133;C\x07second\r\n\x1b]133;D;1;history_id=two\x07", + |capture| captures.push(capture), + ); + + assert_eq!(captures.len(), 2); + assert_eq!(captures[0].output, "first"); + assert_eq!(captures[0].history_id.as_deref(), Some("one")); + assert_eq!(captures[1].output, "second"); + assert_eq!(captures[1].history_id.as_deref(), Some("two")); + } +} diff --git a/crates/turtle/src/atuin_pty_proxy/debug.rs b/crates/turtle/src/atuin_pty_proxy/debug.rs new file mode 100644 index 00000000..bf311281 --- /dev/null +++ b/crates/turtle/src/atuin_pty_proxy/debug.rs @@ -0,0 +1,53 @@ +use crate::atuin_pty_proxy::osc133::{Event, Parser}; + +pub(crate) const RESET: &[u8] = b"\x1b[0m"; + +pub(crate) struct Osc133DebugHighlighter { + parser: Parser, +} + +impl Osc133DebugHighlighter { + pub(crate) fn new() -> Self { + Self { + parser: Parser::new(), + } + } + + pub(crate) fn render(&mut self, data: &[u8]) -> Vec { + let mut events = Vec::new(); + self.parser + .push_located(data, |located| events.push(located)); + + if events.is_empty() { + return data.to_vec(); + } + + let mut rendered = Vec::with_capacity(data.len() + (events.len() * 64)); + let mut start = 0; + + for located in events { + let offset = located.offset.min(data.len()); + if offset > start { + rendered.extend_from_slice(&data[start..offset]); + } + + rendered.extend_from_slice(event_label(&located.event)); + rendered.extend_from_slice(RESET); + start = offset; + } + + rendered.extend_from_slice(&data[start..]); + rendered + } +} + +fn event_label(event: &Event) -> &'static [u8] { + match event { + Event::PromptStart => b"\x1b[1;37;45m[OSC133:A prompt]\x1b[0m", + Event::CommandStart => b"\x1b[1;30;43m[OSC133:B input]\x1b[0m", + Event::CommandExecuted => b"\x1b[1;30;46m[OSC133:C output]\x1b[0m", + Event::CommandFinished { exit_code: Some(0) } => b"\x1b[1;37;42m[OSC133:D exit=0]\x1b[0m", + Event::CommandFinished { exit_code: Some(_) } => b"\x1b[1;37;41m[OSC133:D exit!=0]\x1b[0m", + Event::CommandFinished { exit_code: None } => b"\x1b[1;37;44m[OSC133:D exit=?]\x1b[0m", + } +} diff --git a/crates/turtle/src/atuin_pty_proxy/mod.rs b/crates/turtle/src/atuin_pty_proxy/mod.rs new file mode 100644 index 00000000..612943fa --- /dev/null +++ b/crates/turtle/src/atuin_pty_proxy/mod.rs @@ -0,0 +1,17 @@ +#[cfg(unix)] +mod capture; +#[cfg(unix)] +mod debug; +#[cfg(unix)] +mod osc133; +#[cfg(unix)] +mod pty_proxy; +#[cfg(unix)] +mod runtime; +#[cfg(unix)] +mod screen; + +#[cfg(unix)] +pub use capture::{CommandCapture, CommandCaptureSink}; +#[cfg(unix)] +pub use pty_proxy::PtyProxy; diff --git a/crates/turtle/src/atuin_pty_proxy/osc133.rs b/crates/turtle/src/atuin_pty_proxy/osc133.rs new file mode 100644 index 00000000..5b70f0aa --- /dev/null +++ b/crates/turtle/src/atuin_pty_proxy/osc133.rs @@ -0,0 +1,900 @@ +//! Streaming parser for OSC 133 (FinalTerm semantic prompt) escape sequences. +//! +//! OSC 133 marks four regions of a shell interaction: +//! +//! | Marker | Meaning | +//! |--------|--------------------------------------| +//! | A | Prompt is about to be printed | +//! | B | Prompt ended — command input begins | +//! | C | Command submitted — output begins | +//! | D[;n] | Command finished with exit code *n* | +//! +//! The wire format is `ESC ] 133 ; [; ] ST` where ST is BEL +//! (0x07), ESC \ (0x1B 0x5C), or C1 ST (0x9C). +//! +//! # Design goals +//! +//! * **Transparent** — the parser observes the byte stream without modifying it; +//! the caller remains responsible for forwarding bytes to their destination. +//! * **Bounded** — OSC parameter buffering is capped so malformed output cannot +//! grow memory without limit. +//! * **Non-blocking** — [`Parser::push`] processes whatever bytes are available +//! and returns immediately. +//! * **Extensible** — marker parameters are preserved so Atuin-specific metadata +//! can ride alongside standard OSC 133 markers. + +/// Events emitted when an OSC 133 marker is detected. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Event { + /// `ESC ] 133 ; A ST` — the shell is about to display its prompt. + PromptStart, + /// `ESC ] 133 ; B ST` — the prompt has ended; the user may type a command. + CommandStart, + /// `ESC ] 133 ; C ST` — the command has been submitted for execution. + CommandExecuted, + /// `ESC ] 133 ; D [; ] ST` — command output is complete. + CommandFinished { + /// The exit code reported after the `;`, if present and valid. + exit_code: Option, + }, +} + +/// Parameters attached to an OSC 133 marker. +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub struct Params { + items: Vec, +} + +impl Params { + /// Iterate over all marker parameters in order. + #[cfg(test)] + #[inline] + pub fn iter(&self) -> impl Iterator { + self.items.iter() + } + + /// Return the value for the first `key=value` parameter with this key. + #[inline] + pub fn get(&self, key: &str) -> Option<&str> { + self.items.iter().find_map(|item| match item { + Param::KeyValue { + key: item_key, + value, + } if item_key == key => Some(value.as_str()), + Param::Value(_) | Param::KeyValue { .. } => None, + }) + } +} + +/// A single OSC 133 marker parameter. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Param { + /// A positional parameter without an equals sign. + Value(String), + /// A `key=value` parameter. + KeyValue { key: String, value: String }, +} + +/// An OSC 133 event with its position in the most recent input chunk. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LocatedEvent { + /// The OSC 133 event that was parsed. + pub event: Event, + /// Offset where this marker starts in the current chunk. + /// + /// If a marker started in an earlier [`Parser::push_located`] call, this is + /// `0` in the chunk that completed the marker. + pub start_offset: usize, + /// Offset immediately after this marker's terminator in the current chunk. + /// + /// If a marker spans multiple [`Parser::push_located`] calls, this is still + /// the offset in the chunk that completed the marker. + pub offset: usize, + /// The semantic zone after applying this event. + pub zone: Zone, + /// Metadata parameters attached to this marker. + pub params: Params, +} + +/// The current semantic zone as determined by the most recent OSC 133 marker. +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] +#[expect(dead_code)] +pub enum Zone { + /// No marker seen yet, or after a `D` marker (between commands). + #[default] + Unknown, + /// Between `A` and `B` — the shell is rendering its prompt. + Prompt, + /// Between `B` and `C` — the user is editing a command line. + Input, + /// Between `C` and `D` — command output is being produced. + Output, +} + +// --------------------------------------------------------------------------- +// Internal constants +// --------------------------------------------------------------------------- + +const ESC: u8 = 0x1B; +const BEL: u8 = 0x07; +const C1_ST: u8 = 0x9C; +const BACKSLASH: u8 = b'\\'; +const RIGHT_BRACKET: u8 = b']'; + +/// Maximum bytes we'll buffer for the OSC parameter string. This is large enough +/// for Atuin metadata such as history/session IDs while still bounding malformed +/// OSC sequences. +const PARAM_BUF_CAP: usize = 512; + +// --------------------------------------------------------------------------- +// State machine +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum State { + /// Normal pass-through. + Ground, + /// Saw ESC (0x1B). + Esc, + /// Inside an OSC sequence (`ESC ]`), accumulating parameter bytes. + OscParam, + /// Inside an OSC sequence, saw ESC — next byte decides if this is `ESC \` + /// (string terminator) or something else. + OscEsc, +} + +/// A streaming, zero-allocation parser for OSC 133 escape sequences. +/// +/// Feed arbitrary byte slices into [`Parser::push`]. The parser detects +/// OSC 133 markers and reports [`Event`]s through a caller-supplied callback +/// without modifying the data. It can sit transparently between a PTY reader +/// and stdout. +pub struct Parser { + state: State, + zone: Zone, + sequence_start: Option, + param_buf: [u8; PARAM_BUF_CAP], + param_len: usize, +} + +impl Default for Parser { + fn default() -> Self { + Self::new() + } +} + +impl Parser { + /// Create a new parser in the initial (ground / unknown-zone) state. + #[inline] + pub fn new() -> Self { + Self { + state: State::Ground, + zone: Zone::Unknown, + sequence_start: None, + param_buf: [0u8; PARAM_BUF_CAP], + param_len: 0, + } + } + + /// The current semantic zone based on markers seen so far. + #[inline] + #[expect(dead_code)] + pub fn zone(&self) -> Zone { + self.zone + } + + /// Start offset of an incomplete OSC sequence in the most recent chunk. + #[inline] + pub(crate) fn incomplete_osc_sequence_start(&self) -> Option { + matches!(self.state, State::OscParam | State::OscEsc) + .then(|| self.sequence_start.unwrap_or(0)) + } + + /// Process a chunk of bytes, calling `on_event` for every OSC 133 marker + /// found. + /// + /// All bytes in `data` should still be forwarded to the terminal by the + /// caller — this method only *observes* the stream. + #[cfg(test)] + #[inline] + pub fn push(&mut self, data: &[u8], mut on_event: impl FnMut(Event)) { + self.push_located(data, |located| on_event(located.event)); + } + + /// Process a chunk of bytes, calling `on_event` for every OSC 133 marker + /// found with its byte offset in this chunk. + /// + /// The offset points to the first byte after the marker terminator, making + /// it suitable for callers that need to split the original chunk at marker + /// boundaries. + #[inline] + pub fn push_located(&mut self, data: &[u8], mut on_event: impl FnMut(LocatedEvent)) { + self.sequence_start = (self.state != State::Ground).then_some(0); + + for (offset, &byte) in data.iter().enumerate() { + match self.state { + State::Ground => { + if byte == ESC { + self.state = State::Esc; + self.sequence_start = Some(offset); + } + } + State::Esc => { + if byte == RIGHT_BRACKET { + self.state = State::OscParam; + self.param_len = 0; + } else { + self.state = State::Ground; + self.sequence_start = None; + } + } + State::OscParam => { + if byte == BEL || byte == C1_ST { + self.dispatch(offset + 1, &mut on_event); + self.state = State::Ground; + self.sequence_start = None; + } else if byte == ESC { + self.state = State::OscEsc; + } else if self.param_len < PARAM_BUF_CAP { + self.param_buf[self.param_len] = byte; + self.param_len += 1; + } + // If param_len == PARAM_BUF_CAP we silently stop + // accumulating — dispatch will ignore non-133 sequences. + } + State::OscEsc => { + if byte == BACKSLASH { + self.dispatch(offset + 1, &mut on_event); + } + // Whether we got a valid ST or not, return to ground. + // (A new ESC ] would restart accumulation via the Ground + // -> Esc -> OscParam path on the *next* byte.) + self.state = State::Ground; + self.sequence_start = None; + } + } + } + } + + /// Inspect the accumulated parameter buffer. If it holds an OSC 133 + /// payload, emit the corresponding [`Event`] and update the zone. + #[inline] + fn dispatch(&mut self, offset: usize, on_event: &mut impl FnMut(LocatedEvent)) { + let payload = &self.param_buf[..self.param_len]; + + if payload.len() < 5 || &payload[..4] != b"133;" { + return; + } + + if payload.len() > 5 && payload[5] != b';' { + return; + } + + let metadata = payload.get(6..).unwrap_or_default(); + let cmd = payload[4]; + let (event, params) = match cmd { + b'A' => { + self.zone = Zone::Prompt; + (Event::PromptStart, parse_params(metadata)) + } + b'B' => { + self.zone = Zone::Input; + (Event::CommandStart, parse_params(metadata)) + } + b'C' => { + self.zone = Zone::Output; + (Event::CommandExecuted, parse_params(metadata)) + } + b'D' => { + let (exit_code, params) = parse_command_finished_params(metadata); + self.zone = Zone::Unknown; + (Event::CommandFinished { exit_code }, params) + } + _ => return, + }; + + on_event(LocatedEvent { + event, + start_offset: self.sequence_start.unwrap_or(0), + offset, + zone: self.zone, + params, + }); + } +} + +fn parse_command_finished_params(metadata: &[u8]) -> (Option, Params) { + if metadata.is_empty() { + return (None, Params::default()); + } + + let Some(separator) = metadata.iter().position(|byte| *byte == b';') else { + return parse_exit_code(metadata).map_or_else( + || (None, parse_params(metadata)), + |exit_code| (Some(exit_code), Params::default()), + ); + }; + + let (first, rest) = metadata.split_at(separator); + let rest = &rest[1..]; + + parse_exit_code(first).map_or_else( + || (None, parse_params(metadata)), + |exit_code| (Some(exit_code), parse_params(rest)), + ) +} + +fn parse_exit_code(code: &[u8]) -> Option { + if code.is_empty() { + return None; + } + + std::str::from_utf8(code) + .ok() + .and_then(|code| code.parse::().ok()) +} + +fn parse_params(metadata: &[u8]) -> Params { + let items = metadata + .split(|byte| *byte == b';') + .filter(|part| !part.is_empty()) + .map(parse_param) + .collect(); + + Params { items } +} + +fn parse_param(param: &[u8]) -> Param { + let param = String::from_utf8_lossy(param); + + if let Some((key, value)) = param.split_once('=') { + return Param::KeyValue { + key: key.to_string(), + value: value.to_string(), + }; + } + + Param::Value(param.into_owned()) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + /// Collect all events from a single `push` call. + fn parse_events(data: &[u8]) -> Vec { + let mut parser = Parser::new(); + let mut events = Vec::new(); + parser.push(data, |e| events.push(e)); + events + } + + // -- Basic event detection ------------------------------------------------ + + #[test] + fn detect_prompt_start_bel() { + let data = b"\x1b]133;A\x07"; + assert_eq!(parse_events(data), vec![Event::PromptStart]); + } + + #[test] + fn detect_prompt_start_st() { + let data = b"\x1b]133;A\x1b\\"; + assert_eq!(parse_events(data), vec![Event::PromptStart]); + } + + #[test] + fn detect_command_start_bel() { + let data = b"\x1b]133;B\x07"; + assert_eq!(parse_events(data), vec![Event::CommandStart]); + } + + #[test] + fn detect_command_start_st() { + let data = b"\x1b]133;B\x1b\\"; + assert_eq!(parse_events(data), vec![Event::CommandStart]); + } + + #[test] + fn detect_command_executed_bel() { + let data = b"\x1b]133;C\x07"; + assert_eq!(parse_events(data), vec![Event::CommandExecuted]); + } + + #[test] + fn detect_command_executed_st() { + let data = b"\x1b]133;C\x1b\\"; + assert_eq!(parse_events(data), vec![Event::CommandExecuted]); + } + + #[test] + fn detect_command_finished_no_exit_code() { + let data = b"\x1b]133;D\x07"; + assert_eq!( + parse_events(data), + vec![Event::CommandFinished { exit_code: None }] + ); + } + + #[test] + fn detect_command_finished_exit_zero() { + let data = b"\x1b]133;D;0\x07"; + assert_eq!( + parse_events(data), + vec![Event::CommandFinished { exit_code: Some(0) }] + ); + } + + #[test] + fn detect_command_finished_exit_nonzero() { + let data = b"\x1b]133;D;127\x07"; + assert_eq!( + parse_events(data), + vec![Event::CommandFinished { + exit_code: Some(127) + }] + ); + } + + #[test] + fn detect_command_finished_negative_exit_code() { + let data = b"\x1b]133;D;-1\x07"; + assert_eq!( + parse_events(data), + vec![Event::CommandFinished { + exit_code: Some(-1) + }] + ); + } + + #[test] + fn detect_command_finished_exit_code_st() { + let data = b"\x1b]133;D;42\x1b\\"; + assert_eq!( + parse_events(data), + vec![Event::CommandFinished { + exit_code: Some(42) + }] + ); + } + + #[test] + fn invalid_exit_code_yields_none() { + let data = b"\x1b]133;D;abc\x07"; + assert_eq!( + parse_events(data), + vec![Event::CommandFinished { exit_code: None }] + ); + } + + // -- Zone tracking -------------------------------------------------------- + + #[test] + fn zone_starts_unknown() { + let parser = Parser::new(); + assert_eq!(parser.zone(), Zone::Unknown); + } + + #[test] + fn full_zone_cycle() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push(b"\x1b]133;A\x07", |e| events.push(e)); + assert_eq!(parser.zone(), Zone::Prompt); + + parser.push(b"\x1b]133;B\x07", |e| events.push(e)); + assert_eq!(parser.zone(), Zone::Input); + + parser.push(b"\x1b]133;C\x07", |e| events.push(e)); + assert_eq!(parser.zone(), Zone::Output); + + parser.push(b"\x1b]133;D;0\x07", |e| events.push(e)); + assert_eq!(parser.zone(), Zone::Unknown); + + assert_eq!( + events, + vec![ + Event::PromptStart, + Event::CommandStart, + Event::CommandExecuted, + Event::CommandFinished { exit_code: Some(0) }, + ] + ); + } + + // -- Multiple events in one push ------------------------------------------ + + #[test] + fn multiple_events_single_push() { + let data = b"\x1b]133;A\x07$ \x1b]133;B\x07ls\n\x1b]133;C\x07file.txt\n\x1b]133;D;0\x07"; + let events = parse_events(data); + assert_eq!( + events, + vec![ + Event::PromptStart, + Event::CommandStart, + Event::CommandExecuted, + Event::CommandFinished { exit_code: Some(0) }, + ] + ); + } + + // -- Split across push boundaries ----------------------------------------- + + #[test] + fn split_esc_and_bracket() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push(b"\x1b", |e| events.push(e)); + assert!(events.is_empty()); + + parser.push(b"]133;A\x07", |e| events.push(e)); + assert_eq!(events, vec![Event::PromptStart]); + } + + #[test] + fn split_mid_param() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push(b"\x1b]13", |e| events.push(e)); + assert!(events.is_empty()); + + parser.push(b"3;D;42\x07", |e| events.push(e)); + assert_eq!( + events, + vec![Event::CommandFinished { + exit_code: Some(42) + }] + ); + } + + #[test] + fn split_before_terminator() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push(b"\x1b]133;B", |e| events.push(e)); + assert!(events.is_empty()); + + parser.push(b"\x07", |e| events.push(e)); + assert_eq!(events, vec![Event::CommandStart]); + } + + #[test] + fn split_esc_backslash_terminator() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push(b"\x1b]133;C\x1b", |e| events.push(e)); + assert!(events.is_empty()); + + parser.push(b"\\", |e| events.push(e)); + assert_eq!(events, vec![Event::CommandExecuted]); + } + + // -- Interleaved normal text ---------------------------------------------- + + #[test] + fn normal_text_before_and_after() { + let data = b"hello world\x1b]133;A\x07prompt text\x1b]133;B\x07command"; + let events = parse_events(data); + assert_eq!(events, vec![Event::PromptStart, Event::CommandStart]); + } + + // -- Non-133 OSC sequences (should be ignored) ---------------------------- + + #[test] + fn non_133_osc_ignored() { + let data = b"\x1b]0;window title\x07\x1b]133;A\x07"; + let events = parse_events(data); + assert_eq!(events, vec![Event::PromptStart]); + } + + #[test] + fn osc_7_ignored() { + let data = b"\x1b]7;file:///home/user\x07"; + assert!(parse_events(data).is_empty()); + } + + // -- Unknown command letter ----------------------------------------------- + + #[test] + fn unknown_command_ignored() { + let data = b"\x1b]133;Z\x07"; + assert!(parse_events(data).is_empty()); + } + + #[test] + fn marker_with_unexpected_trailing_bytes_ignored() { + let data = b"\x1b]133;ABC\x07"; + assert!(parse_events(data).is_empty()); + } + + // -- Malformed sequences -------------------------------------------------- + + #[test] + fn esc_followed_by_non_bracket() { + let data = b"\x1b[31m\x1b]133;A\x07"; + let events = parse_events(data); + assert_eq!(events, vec![Event::PromptStart]); + } + + #[test] + fn lone_esc_at_end_of_chunk() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push(b"\x1b", |e| events.push(e)); + assert!(events.is_empty()); + + // Feed non-bracket to abort the escape, then a real sequence. + parser.push(b"x\x1b]133;A\x07", |e| events.push(e)); + assert_eq!(events, vec![Event::PromptStart]); + } + + #[test] + fn truncated_133_prefix() { + // "13" followed by terminator — not "133;" so no event. + let data = b"\x1b]13\x07"; + assert!(parse_events(data).is_empty()); + } + + #[test] + fn empty_osc() { + let data = b"\x1b]\x07"; + assert!(parse_events(data).is_empty()); + } + + // -- Buffer overflow (very long non-133 OSC) ------------------------------ + + #[test] + fn very_long_osc_does_not_panic() { + let mut data = Vec::new(); + data.extend_from_slice(b"\x1b]"); + data.extend(std::iter::repeat_n(b'x', 1000)); + data.push(BEL); + // Should not panic and should produce no event. + assert!(parse_events(&data).is_empty()); + } + + // -- Empty input ---------------------------------------------------------- + + #[test] + fn empty_input() { + assert!(parse_events(b"").is_empty()); + } + + #[test] + fn only_normal_text() { + let data = b"just some regular terminal output\r\n"; + assert!(parse_events(data).is_empty()); + } + + // -- Repeated prompts (empty command) ------------------------------------ + + #[test] + fn repeated_prompt_cycle() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + // User hits enter on an empty prompt twice. + let data = b"\x1b]133;A\x07$ \x1b]133;B\x07\x1b]133;D\x07\x1b]133;A\x07$ \x1b]133;B\x07"; + parser.push(data, |e| events.push(e)); + + assert_eq!( + events, + vec![ + Event::PromptStart, + Event::CommandStart, + Event::CommandFinished { exit_code: None }, + Event::PromptStart, + Event::CommandStart, + ] + ); + assert_eq!(parser.zone(), Zone::Input); + } + + // -- Byte-at-a-time feeding ----------------------------------------------- + + #[test] + fn byte_at_a_time() { + let data = b"\x1b]133;D;99\x07"; + let mut parser = Parser::new(); + let mut events = Vec::new(); + + for &byte in data { + parser.push(&[byte], |e| events.push(e)); + } + + assert_eq!( + events, + vec![Event::CommandFinished { + exit_code: Some(99) + }] + ); + } + + // -- Mixed terminators ---------------------------------------------------- + + #[test] + fn mixed_bel_and_st_terminators() { + let data = b"\x1b]133;A\x07\x1b]133;B\x1b\\\x1b]133;C\x07\x1b]133;D;1\x1b\\"; + let events = parse_events(data); + assert_eq!( + events, + vec![ + Event::PromptStart, + Event::CommandStart, + Event::CommandExecuted, + Event::CommandFinished { exit_code: Some(1) }, + ] + ); + } + + #[test] + fn detects_c1_st_terminator() { + let data = b"\x1b]133;A\x9c"; + assert_eq!(parse_events(data), vec![Event::PromptStart]); + } + + // -- Located event offsets ------------------------------------------------ + + #[test] + fn located_event_reports_offset_after_marker() { + let data = b"before\x1b]133;A\x07prompt"; + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push_located(data, |e| events.push(e)); + + assert_eq!( + events, + vec![LocatedEvent { + event: Event::PromptStart, + start_offset: b"before".len(), + offset: b"before\x1b]133;A\x07".len(), + zone: Zone::Prompt, + params: Params::default(), + }] + ); + } + + #[test] + fn located_event_offset_is_relative_to_completing_chunk() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push_located(b"\x1b]133;", |e| events.push(e)); + parser.push_located(b"D;42\x07after", |e| events.push(e)); + + assert_eq!( + events, + vec![LocatedEvent { + event: Event::CommandFinished { + exit_code: Some(42) + }, + start_offset: 0, + offset: b"D;42\x07".len(), + zone: Zone::Unknown, + params: Params::default(), + }] + ); + } + + #[test] + fn located_event_preserves_metadata_params() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push_located( + b"\x1b]133;D;127;history_id=018f;session_id=abcd;flag\x07", + |event| events.push(event), + ); + + assert_eq!(events.len(), 1); + let event = &events[0]; + assert_eq!( + event.event, + Event::CommandFinished { + exit_code: Some(127) + } + ); + assert_eq!(event.params.get("history_id"), Some("018f")); + assert_eq!(event.params.get("session_id"), Some("abcd")); + assert!( + event + .params + .iter() + .any(|param| param == &Param::Value("flag".to_string())) + ); + } + + #[test] + fn command_finished_metadata_without_exit_code_is_preserved() { + let mut parser = Parser::new(); + let mut events = Vec::new(); + + parser.push_located(b"\x1b]133;D;history_id=018f;session_id=abcd\x07", |event| { + events.push(event); + }); + + assert_eq!(events.len(), 1); + let event = &events[0]; + assert_eq!(event.event, Event::CommandFinished { exit_code: None }); + assert_eq!(event.params.get("history_id"), Some("018f")); + assert_eq!(event.params.get("session_id"), Some("abcd")); + } + + // -- Default trait -------------------------------------------------------- + + #[test] + fn parser_default() { + let parser = Parser::default(); + assert_eq!(parser.zone(), Zone::Unknown); + } + + #[test] + fn zone_default() { + assert_eq!(Zone::default(), Zone::Unknown); + } + + // -- D with empty exit code field ----------------------------------------- + + #[test] + fn d_with_semicolon_but_empty_code() { + // "133;D;" — semicolon present but no digits. + let data = b"\x1b]133;D;\x07"; + assert_eq!( + parse_events(data), + vec![Event::CommandFinished { exit_code: None }] + ); + } + + // -- Consecutive OSC sequences without gap -------------------------------- + + #[test] + fn back_to_back_osc_no_gap() { + let data = b"\x1b]133;A\x07\x1b]133;B\x07"; + let events = parse_events(data); + assert_eq!(events, vec![Event::PromptStart, Event::CommandStart]); + } + + // -- CSI sequences interleaved (should not confuse parser) ---------------- + + #[test] + fn csi_sequences_ignored() { + // CSI (ESC [) color codes mixed with OSC 133. + let data = b"\x1b[32m\x1b]133;A\x07\x1b[0m$ \x1b]133;B\x07"; + let events = parse_events(data); + assert_eq!(events, vec![Event::PromptStart, Event::CommandStart]); + } + + // -- Large exit codes ----------------------------------------------------- + + #[test] + fn large_exit_code() { + let data = b"\x1b]133;D;2147483647\x07"; + assert_eq!( + parse_events(data), + vec![Event::CommandFinished { + exit_code: Some(i32::MAX) + }] + ); + } + + #[test] + fn overflow_exit_code_yields_none() { + let data = b"\x1b]133;D;9999999999999\x07"; + assert_eq!( + parse_events(data), + vec![Event::CommandFinished { exit_code: None }] + ); + } +} diff --git a/crates/turtle/src/atuin_pty_proxy/pty_proxy.rs b/crates/turtle/src/atuin_pty_proxy/pty_proxy.rs new file mode 100644 index 00000000..8dde6f53 --- /dev/null +++ b/crates/turtle/src/atuin_pty_proxy/pty_proxy.rs @@ -0,0 +1,231 @@ +use clap::{Args, Subcommand, ValueEnum}; + +use crate::atuin_pty_proxy::{CommandCaptureSink, runtime}; + +#[derive(Args, Debug)] +pub struct PtyProxy { + /// Highlight OSC 133 prompt, input, output, and exit-code regions + #[arg(long)] + debug_osc133: bool, + + #[command(subcommand)] + cmd: Option, +} + +#[derive(Subcommand, Debug)] +pub enum Cmd { + /// Print shell code to initialize atuin pty-proxy on shell startup + Init(Init), +} + +#[derive(Args, Debug)] +pub struct Init { + /// Shell to generate init for. If omitted, attempt auto-detection + #[arg(value_enum)] + shell: Option, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq, ValueEnum)] +#[value(rename_all = "lower")] +#[expect(clippy::enum_variant_names, clippy::doc_markdown)] +enum Shell { + /// Zsh setup + Zsh, + /// Bash setup + Bash, + /// Fish setup + Fish, + /// Nu setup + Nu, +} + +pub(crate) struct RuntimeOptions { + pub(crate) debug_osc133: bool, + pub(crate) command_capture_sink: Option, +} + +impl RuntimeOptions { + fn new(debug_osc133: bool, command_capture_sink: Option) -> Self { + Self { + debug_osc133: debug_osc133 || env_flag("ATUIN_PTY_PROXY_DEBUG"), + command_capture_sink, + } + } +} + +impl PtyProxy { + pub fn run(self, command_capture_sink: Option) { + match self.cmd { + Some(Cmd::Init(init)) => { + if let Err(err) = init.run() { + eprintln!("atuin pty-proxy: {err}"); + std::process::exit(1); + } + } + None => runtime::main(RuntimeOptions::new(self.debug_osc133, command_capture_sink)), + } + } +} + +impl Init { + fn run(self) -> Result<(), String> { + let shell = detect_shell(self.shell)?; + let script = render_init(shell); + print!("{script}"); + Ok(()) + } +} + +fn detect_shell(cli_shell: Option) -> Result { + if let Some(shell) = cli_shell { + return Ok(shell); + } + + if let Ok(shell) = std::env::var("ATUIN_SHELL") + && let Some(shell) = shell_from_name(&shell) + { + return Ok(shell); + } + + if let Ok(shell) = std::env::var("SHELL") + && let Some(shell) = shell_from_name(&shell) + { + return Ok(shell); + } + + Err( + "could not detect a supported shell. Please specify one explicitly: bash, zsh, fish, or nu" + .to_string(), + ) +} + +fn shell_from_name(name: &str) -> Option { + let shell = name + .trim() + .rsplit('/') + .next() + .unwrap_or(name) + .trim_start_matches('-') + .to_ascii_lowercase(); + + match shell.as_str() { + "bash" => Some(Shell::Bash), + "zsh" => Some(Shell::Zsh), + "fish" => Some(Shell::Fish), + "nu" => Some(Shell::Nu), + _ => None, + } +} + +fn env_flag(name: &str) -> bool { + std::env::var(name).is_ok_and(|value| { + matches!( + value.trim().to_ascii_lowercase().as_str(), + "1" | "true" | "yes" | "on" + ) + }) +} + +fn render_init(shell: Shell) -> &'static str { + match shell { + Shell::Bash | Shell::Zsh => { + r#"if [[ "$-" == *i* ]] && [[ -t 0 ]] && [[ -t 1 ]]; then + _atuin_pty_proxy_tmux_current="${TMUX:-}" + _atuin_pty_proxy_tmux_previous="${ATUIN_PTY_PROXY_TMUX:-}" + + if [[ -z "${ATUIN_PTY_PROXY_ACTIVE:-}" ]] || [[ "$_atuin_pty_proxy_tmux_current" != "$_atuin_pty_proxy_tmux_previous" ]]; then + export ATUIN_PTY_PROXY_ACTIVE=1 + export ATUIN_PTY_PROXY_TMUX="$_atuin_pty_proxy_tmux_current" + exec atuin pty-proxy + fi + + unset _atuin_pty_proxy_tmux_current _atuin_pty_proxy_tmux_previous +fi +"# + } + Shell::Fish => { + r#"if status is-interactive; and test -t 0; and test -t 1 + set -l _atuin_pty_proxy_tmux_current "" + if set -q TMUX + set _atuin_pty_proxy_tmux_current "$TMUX" + end + + set -l _atuin_pty_proxy_tmux_previous "" + if set -q ATUIN_PTY_PROXY_TMUX + set _atuin_pty_proxy_tmux_previous "$ATUIN_PTY_PROXY_TMUX" + end + + if not set -q ATUIN_PTY_PROXY_ACTIVE + set -gx ATUIN_PTY_PROXY_ACTIVE 1 + set -gx ATUIN_PTY_PROXY_TMUX "$_atuin_pty_proxy_tmux_current" + exec atuin pty-proxy + else if test "$_atuin_pty_proxy_tmux_current" != "$_atuin_pty_proxy_tmux_previous" + set -gx ATUIN_PTY_PROXY_ACTIVE 1 + set -gx ATUIN_PTY_PROXY_TMUX "$_atuin_pty_proxy_tmux_current" + exec atuin pty-proxy + end +end +"# + } + // Nushell cannot dynamically source the output of `atuin init nu`, + // so we only output the pty-proxy preamble here. Users must also set up + // `atuin init nu` separately. + Shell::Nu => { + r#"if (is-terminal --stdin) and (is-terminal --stdout) { + let tmux_current = ($env.TMUX? | default "") + let tmux_previous = ($env.ATUIN_PTY_PROXY_TMUX? | default "") + + if (($env.ATUIN_PTY_PROXY_ACTIVE? | default "") | is-empty) or ($tmux_current != $tmux_previous) { + $env.ATUIN_PTY_PROXY_ACTIVE = "1" + $env.ATUIN_PTY_PROXY_TMUX = $tmux_current + exec atuin pty-proxy + } +} +"# + } + } +} + +#[cfg(test)] +mod tests { + use super::{Shell, render_init, shell_from_name}; + + #[test] + fn shell_from_name_handles_paths() { + assert_eq!(shell_from_name("/bin/zsh"), Some(Shell::Zsh)); + assert_eq!(shell_from_name("/usr/local/bin/bash"), Some(Shell::Bash)); + assert_eq!(shell_from_name("fish"), Some(Shell::Fish)); + assert_eq!(shell_from_name("nu"), Some(Shell::Nu)); + } + + #[test] + fn posix_init_uses_exec_and_tmux_guard() { + let script = render_init(Shell::Bash); + assert!(script.contains("exec atuin pty-proxy")); + assert!(script.contains("ATUIN_PTY_PROXY_TMUX")); + assert!(!script.contains("eval \"$(atuin init bash)\"")); + } + + #[test] + fn posix_init_has_no_double_braces() { + let script = render_init(Shell::Bash); + assert!(!script.contains("${{"), "double braces in bash init script"); + } + + #[test] + fn fish_init_uses_source() { + let script = render_init(Shell::Fish); + assert!(script.contains("exec atuin pty-proxy")); + assert!(!script.contains("atuin init fish | source")); + } + + #[test] + fn nu_init_uses_exec_and_tty_guard() { + let script = render_init(Shell::Nu); + assert!(script.contains("exec atuin pty-proxy")); + assert!(script.contains("ATUIN_PTY_PROXY_TMUX")); + assert!(script.contains("is-terminal --stdin")); + assert!(script.contains("is-terminal --stdout")); + assert!(script.contains("ATUIN_PTY_PROXY_ACTIVE")); + } +} diff --git a/crates/turtle/src/atuin_pty_proxy/runtime.rs b/crates/turtle/src/atuin_pty_proxy/runtime.rs new file mode 100644 index 00000000..37c77eef --- /dev/null +++ b/crates/turtle/src/atuin_pty_proxy/runtime.rs @@ -0,0 +1,184 @@ +use std::io::{Read, Write}; +use std::sync::Arc; +use std::sync::atomic::{AtomicU16, Ordering}; +use std::sync::mpsc; + +use crossterm::terminal; +use portable_pty::{CommandBuilder, PtySize, native_pty_system}; + +use crate::atuin_pty_proxy::capture::CommandCaptureTracker; +use crate::atuin_pty_proxy::debug::{Osc133DebugHighlighter, RESET}; +use crate::atuin_pty_proxy::pty_proxy::RuntimeOptions; +use crate::atuin_pty_proxy::screen::{self, Msg}; + +pub(crate) fn main(options: RuntimeOptions) { + if let Err(e) = run(options) { + let _ = terminal::disable_raw_mode(); + eprintln!("atuin pty-proxy: {e:#}"); + std::process::exit(1); + } +} + +fn run(options: RuntimeOptions) -> eyre::Result<()> { + let (cols, rows) = terminal::size()?; + + let pty_system = native_pty_system(); + let pair = pty_system + .openpty(PtySize { + rows, + cols, + pixel_width: 0, + pixel_height: 0, + }) + .map_err(|e| eyre::eyre!("{e:#}"))?; + + let sock_path = screen::socket_path(); + let _ = std::fs::remove_file(&sock_path); + + let mut cmd = CommandBuilder::new_default_prog(); + cmd.cwd(std::env::current_dir()?); + cmd.env("ATUIN_PTY_PROXY_SOCKET", sock_path.as_os_str()); + cmd.env("ATUIN_PTY_PROXY_ACTIVE", "1"); + + let mut child = pair + .slave + .spawn_command(cmd) + .map_err(|e| eyre::eyre!("{e:#}"))?; + + drop(pair.slave); + + let mut pty_reader = pair + .master + .try_clone_reader() + .map_err(|e| eyre::eyre!("{e:#}"))?; + let mut pty_writer = pair + .master + .take_writer() + .map_err(|e| eyre::eyre!("{e:#}"))?; + + let (msg_tx, msg_rx) = mpsc::sync_channel::(64); + let current_cols = Arc::new(AtomicU16::new(cols.max(1))); + + screen::spawn_parser_thread(rows, cols, msg_rx); + screen::spawn_socket_server(sock_path.clone(), msg_tx.clone()); + spawn_resize_handler(pair.master, msg_tx.clone(), current_cols.clone())?; + + terminal::enable_raw_mode()?; + + let stdout_thread = std::thread::spawn(move || { + let mut stdout = std::io::stdout(); + let mut highlighter = options.debug_osc133.then(Osc133DebugHighlighter::new); + let mut capture_tracker = options + .command_capture_sink + .as_ref() + .map(|_| CommandCaptureTracker::new(current_cols)); + let mut buf = [0u8; 8192]; + + loop { + match pty_reader.read(&mut buf) { + Ok(0) | Err(_) => break, + Ok(n) => { + if let (Some(tracker), Some(sink)) = ( + capture_tracker.as_mut(), + options.command_capture_sink.as_ref(), + ) { + tracker.push(&buf[..n], sink); + } + + if let Some(highlighter) = highlighter.as_mut() { + let rendered = highlighter.render(&buf[..n]); + let _ = msg_tx.try_send(Msg::Data(rendered.clone())); + + if stdout.write_all(&rendered).is_err() { + break; + } + } else { + let _ = msg_tx.try_send(Msg::Data(buf[..n].to_vec())); + + if stdout.write_all(&buf[..n]).is_err() { + break; + } + } + let _ = stdout.flush(); + } + } + } + + if highlighter.is_some() { + let _ = stdout.write_all(RESET); + let _ = stdout.flush(); + } + }); + + std::thread::spawn(move || { + let mut stdin = std::io::stdin(); + let mut buf = [0u8; 8192]; + loop { + match stdin.read(&mut buf) { + Ok(0) | Err(_) => break, + Ok(n) => { + if pty_writer.write_all(&buf[..n]).is_err() { + break; + } + } + } + } + }); + + let status = child.wait()?; + let _ = stdout_thread.join(); + + let _ = terminal::disable_raw_mode(); + let _ = std::fs::remove_file(&sock_path); + + std::process::exit(process_exit_code(status.exit_code())); +} + +fn spawn_resize_handler( + master: Box, + resize_tx: mpsc::SyncSender, + current_cols: Arc, +) -> eyre::Result<()> { + use signal_hook::consts::SIGWINCH; + use signal_hook::iterator::Signals; + + let mut signals = Signals::new([SIGWINCH])?; + + std::thread::spawn(move || { + for _ in signals.forever() { + if let Ok((cols, rows)) = terminal::size() { + current_cols.store(cols.max(1), Ordering::Relaxed); + let _ = master.resize(PtySize { + rows, + cols, + pixel_width: 0, + pixel_height: 0, + }); + let _ = resize_tx.try_send(Msg::Resize { rows, cols }); + } + } + }); + + Ok(()) +} + +fn process_exit_code(code: u32) -> i32 { + i32::try_from(code).unwrap_or(1) +} + +#[cfg(test)] +mod tests { + use super::process_exit_code; + + #[test] + fn process_exit_code_preserves_valid_values() { + assert_eq!(process_exit_code(0), 0); + assert_eq!(process_exit_code(127), 127); + assert_eq!(process_exit_code(i32::MAX as u32), i32::MAX); + } + + #[test] + fn process_exit_code_defaults_when_out_of_range() { + assert_eq!(process_exit_code(i32::MAX as u32 + 1), 1); + } +} diff --git a/crates/turtle/src/atuin_pty_proxy/screen.rs b/crates/turtle/src/atuin_pty_proxy/screen.rs new file mode 100644 index 00000000..5b892e21 --- /dev/null +++ b/crates/turtle/src/atuin_pty_proxy/screen.rs @@ -0,0 +1,104 @@ +use std::io::Write; +use std::os::unix::net::UnixListener; +use std::path::PathBuf; +use std::sync::mpsc::{self, Receiver, SyncSender}; + +pub(crate) enum Msg { + Data(Vec), + Resize { rows: u16, cols: u16 }, + ScreenRequest(mpsc::Sender>), +} + +pub(crate) fn socket_path() -> PathBuf { + let dir = std::env::temp_dir(); + dir.join(format!("atuin-pty-proxy-{}.sock", std::process::id())) +} + +pub(crate) fn spawn_parser_thread(rows: u16, cols: u16, msg_rx: Receiver) { + std::thread::spawn(move || { + let mut parser = vt100::Parser::new(rows, cols, 0); + + loop { + let first = match msg_rx.recv() { + Ok(msg) => msg, + Err(_) => break, + }; + + handle_parser_msg(&mut parser, first); + + while let Ok(msg) = msg_rx.try_recv() { + handle_parser_msg(&mut parser, msg); + } + } + }); +} + +pub(crate) fn spawn_socket_server(sock_path: PathBuf, screen_tx: SyncSender) { + std::thread::spawn(move || { + let listener = match UnixListener::bind(&sock_path) { + Ok(l) => l, + Err(e) => { + eprintln!("atuin pty-proxy: failed to bind socket: {e}"); + return; + } + }; + + for stream in listener.incoming() { + let mut stream = match stream { + Ok(s) => s, + Err(_) => break, + }; + + let (reply_tx, reply_rx) = mpsc::channel(); + if screen_tx.send(Msg::ScreenRequest(reply_tx)).is_err() { + break; + } + if let Ok(data) = reply_rx.recv() { + let _ = stream.write_all(&data); + let _ = stream.flush(); + } + } + }); +} + +/// Wire format written to the Unix socket: +/// +/// ```text +/// [rows: u16 BE][cols: u16 BE][cursor_row: u16 BE][cursor_col: u16 BE] +/// [row_0_len: u32 BE][row_0_bytes...] +/// [row_1_len: u32 BE][row_1_bytes...] +/// ... +/// ``` +/// +/// Each row's bytes come from `screen.rows_formatted(0, cols)` and contain +/// pre-built ANSI escape sequences. The client can write them directly to +/// stdout without needing its own vt100 parser. +fn encode_screen(parser: &vt100::Parser) -> Vec { + let screen = parser.screen(); + let (rows, cols) = screen.size(); + let (cursor_row, cursor_col) = screen.cursor_position(); + + let mut buf: Vec = Vec::with_capacity(256 + (rows as usize * cols as usize)); + buf.extend_from_slice(&rows.to_be_bytes()); + buf.extend_from_slice(&cols.to_be_bytes()); + buf.extend_from_slice(&cursor_row.to_be_bytes()); + buf.extend_from_slice(&cursor_col.to_be_bytes()); + + for row_bytes in screen.rows_formatted(0, cols) { + let len = row_bytes.len() as u32; + buf.extend_from_slice(&len.to_be_bytes()); + buf.extend_from_slice(&row_bytes); + } + + buf +} + +fn handle_parser_msg(parser: &mut vt100::Parser, msg: Msg) { + match msg { + Msg::Data(data) => parser.process(&data), + Msg::Resize { rows, cols } => parser.screen_mut().set_size(rows, cols), + Msg::ScreenRequest(reply_tx) => { + let _ = reply_tx.send(encode_screen(parser)); + } + } +} diff --git a/crates/turtle/src/atuin_server/handlers/health.rs b/crates/turtle/src/atuin_server/handlers/health.rs new file mode 100644 index 00000000..aebd1e8f --- /dev/null +++ b/crates/turtle/src/atuin_server/handlers/health.rs @@ -0,0 +1,15 @@ +use axum::{Json, http, response::IntoResponse}; + +use serde::Serialize; + +#[derive(Serialize)] +pub struct HealthResponse { + pub status: &'static str, +} + +pub async fn health_check() -> impl IntoResponse { + ( + http::StatusCode::OK, + Json(HealthResponse { status: "healthy" }), + ) +} diff --git a/crates/turtle/src/atuin_server/handlers/history.rs b/crates/turtle/src/atuin_server/handlers/history.rs new file mode 100644 index 00000000..7f09161b --- /dev/null +++ b/crates/turtle/src/atuin_server/handlers/history.rs @@ -0,0 +1,237 @@ +use std::{collections::HashMap, convert::TryFrom}; + +use axum::{ + Json, + extract::{Path, Query, State}, + http::{HeaderMap, StatusCode}, +}; +use metrics::counter; +use time::{Month, UtcOffset}; +use tracing::{debug, error, instrument}; + +use super::{ErrorResponse, ErrorResponseStatus, RespExt}; +use crate::atuin_server::{ + router::{AppState, UserAuth}, + utils::client_version_min, +}; +use crate::atuin_server_database::{ + Database, + calendar::{TimePeriod, TimePeriodInfo}, + models::NewHistory, +}; + +use crate::atuin_common::api::*; + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn count( + UserAuth(user): UserAuth, + state: State>, +) -> Result, ErrorResponseStatus<'static>> { + let db = &state.0.database; + match db.count_history_cached(&user).await { + // By default read out the cached value + Ok(count) => Ok(Json(CountResponse { count })), + + // If that fails, fallback on a full COUNT. Cache is built on a POST + // only + Err(_) => match db.count_history(&user).await { + Ok(count) => Ok(Json(CountResponse { count })), + Err(_) => Err(ErrorResponse::reply("failed to query history count") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)), + }, + } +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn list( + req: Query, + UserAuth(user): UserAuth, + headers: HeaderMap, + state: State>, +) -> Result, ErrorResponseStatus<'static>> { + let db = &state.0.database; + + let agent = headers + .get("user-agent") + .map_or("", |v| v.to_str().unwrap_or("")); + + let variable_page_size = client_version_min(agent, ">=15.0.0").unwrap_or(false); + + let page_size = if variable_page_size { + state.settings.page_size + } else { + 100 + }; + + if req.sync_ts.unix_timestamp_nanos() < 0 || req.history_ts.unix_timestamp_nanos() < 0 { + error!("client asked for history from < epoch 0"); + counter!("atuin_history_epoch_before_zero").increment(1); + + return Err( + ErrorResponse::reply("asked for history from before epoch 0") + .with_status(StatusCode::BAD_REQUEST), + ); + } + + let history = db + .list_history(&user, req.sync_ts, req.history_ts, &req.host, page_size) + .await; + + if let Err(e) = history { + error!("failed to load history: {}", e); + return Err(ErrorResponse::reply("failed to load history") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + + let history: Vec = history + .unwrap() + .iter() + .map(|i| i.data.to_string()) + .collect(); + + debug!( + "loaded {} items of history for user {}", + history.len(), + user.id + ); + + counter!("atuin_history_returned").increment(history.len() as u64); + + Ok(Json(SyncHistoryResponse { history })) +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn delete( + UserAuth(user): UserAuth, + state: State>, + Json(req): Json, +) -> Result, ErrorResponseStatus<'static>> { + let db = &state.0.database; + + // user_id is the ID of the history, as set by the user (the server has its own ID) + let deleted = db.delete_history(&user, req.client_id).await; + + if let Err(e) = deleted { + error!("failed to delete history: {}", e); + return Err(ErrorResponse::reply("failed to delete history") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + + Ok(Json(MessageResponse { + message: String::from("deleted OK"), + })) +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn add( + UserAuth(user): UserAuth, + state: State>, + Json(req): Json>, +) -> Result<(), ErrorResponseStatus<'static>> { + let State(AppState { database, settings }) = state; + + debug!("request to add {} history items", req.len()); + counter!("atuin_history_uploaded").increment(req.len() as u64); + + let mut history: Vec = req + .into_iter() + .map(|h| NewHistory { + client_id: h.id, + user_id: user.id, + hostname: h.hostname, + timestamp: h.timestamp, + data: h.data, + }) + .collect(); + + history.retain(|h| { + // keep if within limit, or limit is 0 (unlimited) + let keep = h.data.len() <= settings.max_history_length || settings.max_history_length == 0; + + // Don't return an error here. We want to insert as much of the + // history list as we can, so log the error and continue going. + if !keep { + counter!("atuin_history_too_long").increment(1); + + tracing::warn!( + "history too long, got length {}, max {}", + h.data.len(), + settings.max_history_length + ); + } + + keep + }); + + if let Err(e) = database.add_history(&history).await { + error!("failed to add history: {}", e); + + return Err(ErrorResponse::reply("failed to add history") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + }; + + Ok(()) +} + +#[derive(serde::Deserialize, Debug)] +pub struct CalendarQuery { + #[serde(default = "serde_calendar::zero")] + year: i32, + #[serde(default = "serde_calendar::one")] + month: u8, + + #[serde(default = "serde_calendar::utc")] + tz: UtcOffset, +} + +mod serde_calendar { + use time::UtcOffset; + + pub fn zero() -> i32 { + 0 + } + + pub fn one() -> u8 { + 1 + } + + pub fn utc() -> UtcOffset { + UtcOffset::UTC + } +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn calendar( + Path(focus): Path, + Query(params): Query, + UserAuth(user): UserAuth, + state: State>, +) -> Result>, ErrorResponseStatus<'static>> { + let focus = focus.as_str(); + + let year = params.year; + let month = Month::try_from(params.month).map_err(|e| ErrorResponseStatus { + error: ErrorResponse { + reason: e.to_string().into(), + }, + status: StatusCode::BAD_REQUEST, + })?; + + let period = match focus { + "year" => TimePeriod::Year, + "month" => TimePeriod::Month { year }, + "day" => TimePeriod::Day { year, month }, + _ => { + return Err(ErrorResponse::reply("invalid focus: use year/month/day") + .with_status(StatusCode::BAD_REQUEST)); + } + }; + + let db = &state.0.database; + let focus = db.calendar(&user, period, params.tz).await.map_err(|_| { + ErrorResponse::reply("failed to query calendar") + .with_status(StatusCode::INTERNAL_SERVER_ERROR) + })?; + + Ok(Json(focus)) +} diff --git a/crates/turtle/src/atuin_server/handlers/mod.rs b/crates/turtle/src/atuin_server/handlers/mod.rs new file mode 100644 index 00000000..7722d03e --- /dev/null +++ b/crates/turtle/src/atuin_server/handlers/mod.rs @@ -0,0 +1,60 @@ +use crate::atuin_common::api::{ErrorResponse, IndexResponse}; +use crate::atuin_server_database::Database; +use axum::{Json, extract::State, http, response::IntoResponse}; + +use crate::atuin_server::router::AppState; + +pub mod health; +pub mod history; +pub mod record; +pub mod status; +pub mod user; +pub mod v0; + +const VERSION: &str = env!("CARGO_PKG_VERSION"); + +pub async fn index(state: State>) -> Json { + let homage = r#""Through the fathomless deeps of space swims the star turtle Great A'Tuin, bearing on its back the four giant elephants who carry on their shoulders the mass of the Discworld." -- Sir Terry Pratchett"#; + + let version = state + .settings + .fake_version + .clone() + .unwrap_or(VERSION.to_string()); + + Json(IndexResponse { + homage: homage.to_string(), + version, + }) +} + +impl IntoResponse for ErrorResponseStatus<'_> { + fn into_response(self) -> axum::response::Response { + (self.status, Json(self.error)).into_response() + } +} + +pub struct ErrorResponseStatus<'a> { + pub error: ErrorResponse<'a>, + pub status: http::StatusCode, +} + +pub trait RespExt<'a> { + fn with_status(self, status: http::StatusCode) -> ErrorResponseStatus<'a>; + fn reply(reason: &'a str) -> Self; +} + +impl<'a> RespExt<'a> for ErrorResponse<'a> { + fn with_status(self, status: http::StatusCode) -> ErrorResponseStatus<'a> { + ErrorResponseStatus { + error: self, + status, + } + } + + fn reply(reason: &'a str) -> ErrorResponse<'a> { + Self { + reason: reason.into(), + } + } +} diff --git a/crates/turtle/src/atuin_server/handlers/record.rs b/crates/turtle/src/atuin_server/handlers/record.rs new file mode 100644 index 00000000..63325606 --- /dev/null +++ b/crates/turtle/src/atuin_server/handlers/record.rs @@ -0,0 +1,42 @@ +use axum::{Json, http::StatusCode, response::IntoResponse}; +use serde_json::json; +use tracing::instrument; + +use super::{ErrorResponse, ErrorResponseStatus, RespExt}; +use crate::atuin_server::router::UserAuth; + +use crate::atuin_common::record::{EncryptedData, Record}; + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn post(UserAuth(user): UserAuth) -> Result<(), ErrorResponseStatus<'static>> { + // anyone who has actually used the old record store (a very small number) will see this error + // upon trying to sync. + // 1. The status endpoint will say that the server has nothing + // 2. The client will try to upload local records + // 3. Sync will fail with this error + + // If the client has no local records, they will see the empty index and do nothing. For the + // vast majority of users, this is the case. + return Err( + ErrorResponse::reply("record store deprecated; please upgrade") + .with_status(StatusCode::BAD_REQUEST), + ); +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn index(UserAuth(user): UserAuth) -> axum::response::Response { + let ret = json!({ + "hosts": {} + }); + + ret.to_string().into_response() +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn next( + UserAuth(user): UserAuth, +) -> Result>>, ErrorResponseStatus<'static>> { + let records = Vec::new(); + + Ok(Json(records)) +} diff --git a/crates/turtle/src/atuin_server/handlers/status.rs b/crates/turtle/src/atuin_server/handlers/status.rs new file mode 100644 index 00000000..0cf2ca1e --- /dev/null +++ b/crates/turtle/src/atuin_server/handlers/status.rs @@ -0,0 +1,45 @@ +use axum::{Json, extract::State, http::StatusCode}; +use tracing::instrument; + +use super::{ErrorResponse, ErrorResponseStatus, RespExt}; +use crate::atuin_server::router::{AppState, UserAuth}; +use crate::atuin_server_database::Database; + +use crate::atuin_common::api::*; + +const VERSION: &str = env!("CARGO_PKG_VERSION"); + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn status( + UserAuth(user): UserAuth, + state: State>, +) -> Result, ErrorResponseStatus<'static>> { + let db = &state.0.database; + + let deleted = db.deleted_history(&user).await.unwrap_or(vec![]); + + let count = match db.count_history_cached(&user).await { + // By default read out the cached value + Ok(count) => count, + + // If that fails, fallback on a full COUNT. Cache is built on a POST + // only + Err(_) => match db.count_history(&user).await { + Ok(count) => count, + Err(_) => { + return Err(ErrorResponse::reply("failed to query history count") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + }, + }; + + tracing::debug!(user = user.username, "requested sync status"); + + Ok(Json(StatusResponse { + count, + deleted, + username: user.username, + version: VERSION.to_string(), + page_size: state.settings.page_size, + })) +} diff --git a/crates/turtle/src/atuin_server/handlers/user.rs b/crates/turtle/src/atuin_server/handlers/user.rs new file mode 100644 index 00000000..01b72202 --- /dev/null +++ b/crates/turtle/src/atuin_server/handlers/user.rs @@ -0,0 +1,269 @@ +use std::borrow::Borrow; +use std::collections::HashMap; +use std::time::Duration; + +use argon2::{ + Algorithm, Argon2, Params, PasswordHash, PasswordHasher, PasswordVerifier, Version, + password_hash::SaltString, +}; +use axum::{ + Json, + extract::{Path, State}, + http::StatusCode, +}; +use metrics::counter; + +use rand::rngs::OsRng; +use tracing::{debug, error, info, instrument}; + +use crate::atuin_common::tls::ensure_crypto_provider; + +use super::{ErrorResponse, ErrorResponseStatus, RespExt}; +use crate::atuin_server::router::{AppState, UserAuth}; +use crate::atuin_server_database::{ + Database, DbError, + models::{NewSession, NewUser}, +}; + +use reqwest::header::CONTENT_TYPE; + +use crate::atuin_common::{api::*, utils::crypto_random_string}; + +pub fn verify_str(hash: &str, password: &str) -> bool { + let arg2 = Argon2::new(Algorithm::Argon2id, Version::V0x13, Params::default()); + let Ok(hash) = PasswordHash::new(hash) else { + return false; + }; + arg2.verify_password(password.as_bytes(), &hash).is_ok() +} + +// Try to send a Discord webhook once - if it fails, we don't retry. "At most once", and best effort. +// Don't return the status because if this fails, we don't really care. +async fn send_register_hook(url: &str, username: String, registered: String) { + ensure_crypto_provider(); + let hook = HashMap::from([ + ("username", username), + ("content", format!("{registered} has just signed up!")), + ]); + + let client = reqwest::Client::new(); + + let resp = client + .post(url) + .timeout(Duration::new(5, 0)) + .header(CONTENT_TYPE, "application/json") + .json(&hook) + .send() + .await; + + match resp { + Ok(_) => info!("register webhook sent ok!"), + Err(e) => error!("failed to send register webhook: {}", e), + } +} + +#[instrument(skip_all, fields(user.username = username.as_str()))] +pub async fn get( + Path(username): Path, + state: State>, +) -> Result, ErrorResponseStatus<'static>> { + let db = &state.0.database; + let user = match db.get_user(username.as_ref()).await { + Ok(user) => user, + Err(DbError::NotFound) => { + debug!("user not found: {}", username); + return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); + } + Err(DbError::Other(err)) => { + error!("database error: {}", err); + return Err(ErrorResponse::reply("database error") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + }; + + Ok(Json(UserResponse { + username: user.username, + })) +} + +#[instrument(skip_all)] +pub async fn register( + state: State>, + Json(register): Json, +) -> Result, ErrorResponseStatus<'static>> { + if !state.settings.open_registration { + return Err( + ErrorResponse::reply("this server is not open for registrations") + .with_status(StatusCode::BAD_REQUEST), + ); + } + + for c in register.username.chars() { + match c { + 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' => {} + _ => { + return Err(ErrorResponse::reply( + "Only alphanumeric and hyphens (-) are allowed in usernames", + ) + .with_status(StatusCode::BAD_REQUEST)); + } + } + } + + let hashed = hash_secret(®ister.password); + + let new_user = NewUser { + email: register.email.clone(), + username: register.username.clone(), + password: hashed, + }; + + let db = &state.0.database; + let user_id = match db.add_user(&new_user).await { + Ok(id) => id, + Err(e) => { + error!("failed to add user: {}", e); + return Err( + ErrorResponse::reply("failed to add user").with_status(StatusCode::BAD_REQUEST) + ); + } + }; + + // 24 bytes encoded as base64 + let token = crypto_random_string::<24>(); + + let new_session = NewSession { + user_id, + token: (&token).into(), + }; + + if let Some(url) = &state.settings.register_webhook_url { + // Could probs be run on another thread, but it's ok atm + send_register_hook( + url, + state.settings.register_webhook_username.clone(), + register.username, + ) + .await; + } + + counter!("atuin_users_registered").increment(1); + + match db.add_session(&new_session).await { + Ok(_) => Ok(Json(RegisterResponse { + session: token, + auth: Some("cli".into()), + })), + Err(e) => { + error!("failed to add session: {}", e); + Err(ErrorResponse::reply("failed to register user") + .with_status(StatusCode::BAD_REQUEST)) + } + } +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn delete( + UserAuth(user): UserAuth, + state: State>, +) -> Result, ErrorResponseStatus<'static>> { + debug!("request to delete user {}", user.id); + + let db = &state.0.database; + if let Err(e) = db.delete_user(&user).await { + error!("failed to delete user: {}", e); + + return Err(ErrorResponse::reply("failed to delete user") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + }; + + counter!("atuin_users_deleted").increment(1); + + Ok(Json(DeleteUserResponse {})) +} + +#[instrument(skip_all, fields(user.id = user.id, change_password))] +pub async fn change_password( + UserAuth(mut user): UserAuth, + state: State>, + Json(change_password): Json, +) -> Result, ErrorResponseStatus<'static>> { + let db = &state.0.database; + + let verified = verify_str( + user.password.as_str(), + change_password.current_password.borrow(), + ); + if !verified { + return Err( + ErrorResponse::reply("password is not correct").with_status(StatusCode::UNAUTHORIZED) + ); + } + + let hashed = hash_secret(&change_password.new_password); + user.password = hashed; + + if let Err(e) = db.update_user_password(&user).await { + error!("failed to change user password: {}", e); + + return Err(ErrorResponse::reply("failed to change user password") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + }; + Ok(Json(ChangePasswordResponse {})) +} + +#[instrument(skip_all, fields(user.username = login.username.as_str()))] +pub async fn login( + state: State>, + login: Json, +) -> Result, ErrorResponseStatus<'static>> { + let db = &state.0.database; + let user = match db.get_user(login.username.borrow()).await { + Ok(u) => u, + Err(DbError::NotFound) => { + return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); + } + Err(DbError::Other(e)) => { + error!("failed to get user {}: {}", login.username.clone(), e); + + return Err(ErrorResponse::reply("database error") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + }; + + let session = match db.get_user_session(&user).await { + Ok(u) => u, + Err(DbError::NotFound) => { + debug!("user session not found for user id={}", user.id); + return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); + } + Err(DbError::Other(err)) => { + error!("database error for user {}: {}", login.username, err); + return Err(ErrorResponse::reply("database error") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + }; + + let verified = verify_str(user.password.as_str(), login.password.borrow()); + + if !verified { + debug!(user = user.username, "login failed"); + return Err( + ErrorResponse::reply("password is not correct").with_status(StatusCode::UNAUTHORIZED) + ); + } + + debug!(user = user.username, "login success"); + + Ok(Json(LoginResponse { + session: session.token, + auth: Some("cli".into()), + })) +} + +fn hash_secret(password: &str) -> String { + let arg2 = Argon2::new(Algorithm::Argon2id, Version::V0x13, Params::default()); + let salt = SaltString::generate(&mut OsRng); + let hash = arg2.hash_password(password.as_bytes(), &salt).unwrap(); + hash.to_string() +} diff --git a/crates/turtle/src/atuin_server/handlers/v0/me.rs b/crates/turtle/src/atuin_server/handlers/v0/me.rs new file mode 100644 index 00000000..a1e2db46 --- /dev/null +++ b/crates/turtle/src/atuin_server/handlers/v0/me.rs @@ -0,0 +1,16 @@ +use axum::Json; +use tracing::instrument; + +use crate::atuin_server::handlers::ErrorResponseStatus; +use crate::atuin_server::router::UserAuth; + +use crate::atuin_common::api::*; + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn get( + UserAuth(user): UserAuth, +) -> Result, ErrorResponseStatus<'static>> { + Ok(Json(MeResponse { + username: user.username, + })) +} diff --git a/crates/turtle/src/atuin_server/handlers/v0/mod.rs b/crates/turtle/src/atuin_server/handlers/v0/mod.rs new file mode 100644 index 00000000..d6f880f2 --- /dev/null +++ b/crates/turtle/src/atuin_server/handlers/v0/mod.rs @@ -0,0 +1,3 @@ +pub(crate) mod me; +pub(crate) mod record; +pub(crate) mod store; diff --git a/crates/turtle/src/atuin_server/handlers/v0/record.rs b/crates/turtle/src/atuin_server/handlers/v0/record.rs new file mode 100644 index 00000000..9b147a52 --- /dev/null +++ b/crates/turtle/src/atuin_server/handlers/v0/record.rs @@ -0,0 +1,114 @@ +use axum::{Json, extract::Query, extract::State, http::StatusCode}; +use metrics::counter; +use serde::Deserialize; +use tracing::{error, instrument}; + +use crate::atuin_server::{ + handlers::{ErrorResponse, ErrorResponseStatus, RespExt}, + router::{AppState, UserAuth}, +}; +use crate::atuin_server_database::Database; + +use crate::atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}; + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn post( + UserAuth(user): UserAuth, + state: State>, + Json(records): Json>>, +) -> Result<(), ErrorResponseStatus<'static>> { + let State(AppState { database, settings }) = state; + + tracing::debug!( + count = records.len(), + user = user.username, + "request to add records" + ); + + counter!("atuin_record_uploaded").increment(records.len() as u64); + + let keep = records + .iter() + .all(|r| r.data.data.len() <= settings.max_record_size || settings.max_record_size == 0); + + if !keep { + counter!("atuin_record_too_large").increment(1); + + return Err( + ErrorResponse::reply("could not add records; record too large") + .with_status(StatusCode::BAD_REQUEST), + ); + } + + if let Err(e) = database.add_records(&user, &records).await { + error!("failed to add record: {}", e); + + return Err(ErrorResponse::reply("failed to add record") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + }; + + Ok(()) +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn index( + UserAuth(user): UserAuth, + state: State>, +) -> Result, ErrorResponseStatus<'static>> { + let State(AppState { + database, + settings: _, + }) = state; + + let record_index = match database.status(&user).await { + Ok(index) => index, + Err(e) => { + error!("failed to get record index: {}", e); + + return Err(ErrorResponse::reply("failed to calculate record index") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + }; + + tracing::debug!(user = user.username, "record index request"); + + Ok(Json(record_index)) +} + +#[derive(Deserialize)] +pub struct NextParams { + host: HostId, + tag: String, + start: Option, + count: u64, +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn next( + params: Query, + UserAuth(user): UserAuth, + state: State>, +) -> Result>>, ErrorResponseStatus<'static>> { + let State(AppState { + database, + settings: _, + }) = state; + let params = params.0; + + let records = match database + .next_records(&user, params.host, params.tag, params.start, params.count) + .await + { + Ok(records) => records, + Err(e) => { + error!("failed to get record index: {}", e); + + return Err(ErrorResponse::reply("failed to calculate record index") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + }; + + counter!("atuin_record_downloaded").increment(records.len() as u64); + + Ok(Json(records)) +} diff --git a/crates/turtle/src/atuin_server/handlers/v0/store.rs b/crates/turtle/src/atuin_server/handlers/v0/store.rs new file mode 100644 index 00000000..cd184546 --- /dev/null +++ b/crates/turtle/src/atuin_server/handlers/v0/store.rs @@ -0,0 +1,37 @@ +use axum::{extract::Query, extract::State, http::StatusCode}; +use metrics::counter; +use serde::Deserialize; +use tracing::{error, instrument}; + +use crate::atuin_server::{ + handlers::{ErrorResponse, ErrorResponseStatus, RespExt}, + router::{AppState, UserAuth}, +}; +use crate::atuin_server_database::Database; + +#[derive(Deserialize)] +pub struct DeleteParams {} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn delete( + _params: Query, + UserAuth(user): UserAuth, + state: State>, +) -> Result<(), ErrorResponseStatus<'static>> { + let State(AppState { + database, + settings: _, + }) = state; + + if let Err(e) = database.delete_store(&user).await { + counter!("atuin_store_delete_failed").increment(1); + error!("failed to delete store {e:?}"); + + return Err(ErrorResponse::reply("failed to delete store") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + + counter!("atuin_store_deleted").increment(1); + + Ok(()) +} diff --git a/crates/turtle/src/atuin_server/metrics.rs b/crates/turtle/src/atuin_server/metrics.rs new file mode 100644 index 00000000..ebd0dd2d --- /dev/null +++ b/crates/turtle/src/atuin_server/metrics.rs @@ -0,0 +1,55 @@ +use std::time::Instant; + +use axum::{ + extract::{MatchedPath, Request}, + middleware::Next, + response::IntoResponse, +}; +use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; + +pub fn setup_metrics_recorder() -> PrometheusHandle { + const EXPONENTIAL_SECONDS: &[f64] = &[ + 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, + ]; + + PrometheusBuilder::new() + .set_buckets_for_metric( + Matcher::Full("http_requests_duration_seconds".to_string()), + EXPONENTIAL_SECONDS, + ) + .unwrap() + .install_recorder() + .unwrap() +} + +/// Middleware to record some common HTTP metrics +/// Generic over B to allow for arbitrary body types (eg Vec, Streams, a deserialized thing, etc) +/// Someday tower-http might provide a metrics middleware: https://github.com/tower-rs/tower-http/issues/57 +pub async fn track_metrics(req: Request, next: Next) -> impl IntoResponse { + let start = Instant::now(); + + let path = match req.extensions().get::() { + Some(matched_path) => matched_path.as_str().to_owned(), + _ => req.uri().path().to_owned(), + }; + + let method = req.method().clone(); + + // Run the rest of the request handling first, so we can measure it and get response + // codes. + let response = next.run(req).await; + + let latency = start.elapsed().as_secs_f64(); + let status = response.status().as_u16().to_string(); + + let labels = [ + ("method", method.to_string()), + ("path", path), + ("status", status), + ]; + + metrics::counter!("http_requests_total", &labels).increment(1); + metrics::histogram!("http_requests_duration_seconds", &labels).record(latency); + + response +} diff --git a/crates/turtle/src/atuin_server/mod.rs b/crates/turtle/src/atuin_server/mod.rs new file mode 100644 index 00000000..bd0f2168 --- /dev/null +++ b/crates/turtle/src/atuin_server/mod.rs @@ -0,0 +1,86 @@ +use std::future::Future; +use std::net::SocketAddr; + +use crate::atuin_server_database::Database; +use axum::{Router, serve}; +use eyre::{Context, Result}; + +mod handlers; +mod metrics; +mod router; +mod utils; + +pub use settings::Settings; + +pub mod settings; + +use tokio::net::TcpListener; +use tokio::signal; + +#[cfg(target_family = "unix")] +async fn shutdown_signal() { + let mut term = signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to register signal handler"); + let mut interrupt = signal::unix::signal(signal::unix::SignalKind::interrupt()) + .expect("failed to register signal handler"); + + tokio::select! { + _ = term.recv() => {}, + _ = interrupt.recv() => {}, + }; + eprintln!("Shutting down gracefully..."); +} + +pub async fn launch(settings: Settings, addr: SocketAddr) -> Result<()> { + launch_with_tcp_listener::( + settings, + TcpListener::bind(addr) + .await + .context("could not connect to socket")?, + shutdown_signal(), + ) + .await +} + +pub async fn launch_with_tcp_listener( + settings: Settings, + listener: TcpListener, + shutdown: impl Future + Send + 'static, +) -> Result<()> { + let r = make_router::(settings).await?; + + serve(listener, r.into_make_service()) + .with_graceful_shutdown(shutdown) + .await?; + + Ok(()) +} + +// The separate listener means it's much easier to ensure metrics are not accidentally exposed to +// the public. +pub async fn launch_metrics_server(host: String, port: u16) -> Result<()> { + let listener = TcpListener::bind((host, port)) + .await + .context("failed to bind metrics tcp")?; + + let recorder_handle = metrics::setup_metrics_recorder(); + + let router = Router::new().route( + "/metrics", + axum::routing::get(move || std::future::ready(recorder_handle.render())), + ); + + serve(listener, router.into_make_service()) + .with_graceful_shutdown(shutdown_signal()) + .await?; + + Ok(()) +} + +async fn make_router(settings: Settings) -> Result { + let db = Db::new(&settings.db_settings) + .await + .wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?; + let r = router::router(db, settings); + Ok(r) +} diff --git a/crates/turtle/src/atuin_server/router.rs b/crates/turtle/src/atuin_server/router.rs new file mode 100644 index 00000000..11a16148 --- /dev/null +++ b/crates/turtle/src/atuin_server/router.rs @@ -0,0 +1,155 @@ +use crate::atuin_common::api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ErrorResponse}; +use axum::{ + Router, + extract::{FromRequestParts, Request}, + http::{self, request::Parts}, + middleware::Next, + response::{IntoResponse, Response}, + routing::{delete, get, patch, post}, +}; +use eyre::Result; +use tower::ServiceBuilder; +use tower_http::trace::TraceLayer; + +use super::handlers; +use crate::atuin_server::{ + handlers::{ErrorResponseStatus, RespExt}, + metrics, + settings::Settings, +}; +use crate::atuin_server_database::{Database, DbError, models::User}; + +pub struct UserAuth(pub User); + +impl FromRequestParts> for UserAuth +where + DB: Database, +{ + type Rejection = ErrorResponseStatus<'static>; + + async fn from_request_parts( + req: &mut Parts, + state: &AppState, + ) -> Result { + let auth_header = req + .headers + .get(http::header::AUTHORIZATION) + .ok_or_else(|| { + ErrorResponse::reply("missing authorization header") + .with_status(http::StatusCode::BAD_REQUEST) + })?; + let auth_header = auth_header.to_str().map_err(|_| { + ErrorResponse::reply("invalid authorization header encoding") + .with_status(http::StatusCode::BAD_REQUEST) + })?; + let (typ, token) = auth_header.split_once(' ').ok_or_else(|| { + ErrorResponse::reply("invalid authorization header encoding") + .with_status(http::StatusCode::BAD_REQUEST) + })?; + + if typ != "Token" { + return Err( + ErrorResponse::reply("invalid authorization header encoding") + .with_status(http::StatusCode::BAD_REQUEST), + ); + } + + let user = state + .database + .get_session_user(token) + .await + .map_err(|e| match e { + DbError::NotFound => ErrorResponse::reply("session not found") + .with_status(http::StatusCode::FORBIDDEN), + DbError::Other(e) => { + tracing::error!(error = ?e, "could not query user session"); + ErrorResponse::reply("could not query user session") + .with_status(http::StatusCode::INTERNAL_SERVER_ERROR) + } + })?; + + Ok(UserAuth(user)) + } +} + +async fn teapot() -> impl IntoResponse { + // This used to return 418: 🫖 + // Much as it was fun, it wasn't as useful or informative as it should be + (http::StatusCode::NOT_FOUND, "404 not found") +} + +async fn clacks_overhead(request: Request, next: Next) -> Response { + let mut response = next.run(request).await; + + let gnu_terry_value = "GNU Terry Pratchett, Kris Nova"; + let gnu_terry_header = "X-Clacks-Overhead"; + + response + .headers_mut() + .insert(gnu_terry_header, gnu_terry_value.parse().unwrap()); + response +} + +/// Ensure that we only try and sync with clients on the same major version +async fn semver(request: Request, next: Next) -> Response { + let mut response = next.run(request).await; + response + .headers_mut() + .insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse().unwrap()); + + response +} + +#[derive(Clone)] +pub struct AppState { + pub database: DB, + pub settings: Settings, +} + +pub fn router(database: DB, settings: Settings) -> Router { + let mut routes = Router::new() + .route("/", get(handlers::index)) + .route("/healthz", get(handlers::health::health_check)); + + // Sync v1 routes - can be disabled in favor of record-based sync + if settings.sync_v1_enabled { + routes = routes + .route("/sync/count", get(handlers::history::count)) + .route("/sync/history", get(handlers::history::list)) + .route("/sync/calendar/{focus}", get(handlers::history::calendar)) + .route("/sync/status", get(handlers::status::status)) + .route("/history", post(handlers::history::add)) + .route("/history", delete(handlers::history::delete)); + } + + let routes = routes + .route("/user/{username}", get(handlers::user::get)) + .route("/account", delete(handlers::user::delete)) + .route("/account/password", patch(handlers::user::change_password)) + .route("/register", post(handlers::user::register)) + .route("/login", post(handlers::user::login)) + .route("/record", post(handlers::record::post)) + .route("/record", get(handlers::record::index)) + .route("/record/next", get(handlers::record::next)) + .route("/api/v0/me", get(handlers::v0::me::get)) + .route("/api/v0/record", post(handlers::v0::record::post)) + .route("/api/v0/record", get(handlers::v0::record::index)) + .route("/api/v0/record/next", get(handlers::v0::record::next)) + .route("/api/v0/store", delete(handlers::v0::store::delete)); + + let path = settings.path.as_str(); + if path.is_empty() { + routes + } else { + Router::new().nest(path, routes) + } + .fallback(teapot) + .with_state(AppState { database, settings }) + .layer( + ServiceBuilder::new() + .layer(axum::middleware::from_fn(clacks_overhead)) + .layer(TraceLayer::new_for_http()) + .layer(axum::middleware::from_fn(metrics::track_metrics)) + .layer(axum::middleware::from_fn(semver)), + ) +} diff --git a/crates/turtle/src/atuin_server/settings.rs b/crates/turtle/src/atuin_server/settings.rs new file mode 100644 index 00000000..f6650af0 --- /dev/null +++ b/crates/turtle/src/atuin_server/settings.rs @@ -0,0 +1,110 @@ +use std::{io::prelude::*, path::PathBuf}; + +use crate::atuin_server_database::DbSettings; +use config::{Config, Environment, File as ConfigFile, FileFormat}; +use eyre::{Result, eyre}; +use fs_err::{File, create_dir_all}; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Metrics { + #[serde(alias = "enabled")] + pub enable: bool, + pub host: String, + pub port: u16, +} + +impl Default for Metrics { + fn default() -> Self { + Self { + enable: false, + host: String::from("127.0.0.1"), + port: 9001, + } + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Settings { + pub host: String, + pub port: u16, + pub path: String, + pub open_registration: bool, + pub max_history_length: usize, + pub max_record_size: usize, + pub page_size: i64, + pub register_webhook_url: Option, + pub register_webhook_username: String, + pub metrics: Metrics, + + /// Enable legacy sync v1 routes (history-based sync) + /// Set to false to use only the newer record-based sync + pub sync_v1_enabled: bool, + + /// Advertise a version that is not what we are _actually_ running + /// Many clients compare their version with api.atuin.sh, and if they differ, notify the user + /// that an update is available. + /// Now that we take beta releases, we should be able to advertise a different version to avoid + /// notifying users when the server runs something that is not a stable release. + pub fake_version: Option, + + #[serde(flatten)] + pub db_settings: DbSettings, +} + +impl Settings { + pub fn new() -> Result { + let mut config_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { + PathBuf::from(p) + } else { + let mut config_file = PathBuf::new(); + let config_dir = crate::atuin_common::utils::config_dir(); + config_file.push(config_dir); + config_file + }; + + config_file.push("server.toml"); + + // create the config file if it does not exist + let mut config_builder = Config::builder() + .set_default("host", "127.0.0.1")? + .set_default("port", 8888)? + .set_default("open_registration", false)? + .set_default("max_history_length", 8192)? + .set_default("max_record_size", 1024 * 1024 * 1024)? // pretty chonky + .set_default("path", "")? + .set_default("register_webhook_username", "")? + .set_default("page_size", 1100)? + .set_default("metrics.enable", false)? + .set_default("metrics.host", "127.0.0.1")? + .set_default("metrics.port", 9001)? + .set_default("sync_v1_enabled", true)? + .add_source( + Environment::with_prefix("atuin") + .prefix_separator("_") + .separator("__"), + ); + + let config = if config_file.exists() { + config_builder + .add_source(ConfigFile::new( + config_file.to_str().unwrap(), + FileFormat::Toml, + )) + .build()? + } else { + create_dir_all(config_file.parent().unwrap())?; + let mut file = File::create(config_file)?; + + let config = config_builder.build()?; + // TODO(@bpeetz): I'm quiet unsure, if this will work <2026-06-10> + file.write_all(config.cache.to_string().as_bytes())?; + + config + }; + + config + .try_deserialize() + .map_err(|e| eyre!("failed to deserialize: {}", e)) + } +} diff --git a/crates/turtle/src/atuin_server/utils.rs b/crates/turtle/src/atuin_server/utils.rs new file mode 100644 index 00000000..12e9ac1b --- /dev/null +++ b/crates/turtle/src/atuin_server/utils.rs @@ -0,0 +1,15 @@ +use eyre::Result; +use semver::{Version, VersionReq}; + +pub fn client_version_min(user_agent: &str, req: &str) -> Result { + if user_agent.is_empty() { + return Ok(false); + } + + let version = user_agent.replace("atuin/", ""); + + let req = VersionReq::parse(req)?; + let version = Version::parse(version.as_str())?; + + Ok(req.matches(&version)) +} diff --git a/crates/turtle/src/atuin_server_database/calendar.rs b/crates/turtle/src/atuin_server_database/calendar.rs new file mode 100644 index 00000000..2229667b --- /dev/null +++ b/crates/turtle/src/atuin_server_database/calendar.rs @@ -0,0 +1,18 @@ +// Calendar data + +use serde::{Deserialize, Serialize}; +use time::Month; + +pub enum TimePeriod { + Year, + Month { year: i32 }, + Day { year: i32, month: Month }, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TimePeriodInfo { + pub count: u64, + + // TODO: Use this for merkle tree magic + pub hash: String, +} diff --git a/crates/turtle/src/atuin_server_database/mod.rs b/crates/turtle/src/atuin_server_database/mod.rs new file mode 100644 index 00000000..91077b84 --- /dev/null +++ b/crates/turtle/src/atuin_server_database/mod.rs @@ -0,0 +1,266 @@ +pub mod calendar; +pub mod models; + +use std::{ + collections::HashMap, + fmt::{Debug, Display}, + ops::Range, +}; + +use self::{ + calendar::{TimePeriod, TimePeriodInfo}, + models::{History, NewHistory, NewSession, NewUser, Session, User}, +}; +use async_trait::async_trait; +use crate::atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}; +use serde::{Deserialize, Serialize}; +use time::{Date, Duration, Month, OffsetDateTime, PrimitiveDateTime, Time, UtcOffset}; +use tracing::instrument; + +#[derive(Debug)] +pub enum DbError { + NotFound, + Other(eyre::Report), +} + +impl Display for DbError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{self:?}") + } +} + +impl From for DbError { + fn from(error: time::error::ComponentRange) -> Self { + DbError::Other(error.into()) + } +} + +impl From for DbError { + fn from(error: time::error::Error) -> Self { + DbError::Other(error.into()) + } +} + +impl From for DbError { + fn from(error: sqlx::Error) -> Self { + match error { + sqlx::Error::RowNotFound => DbError::NotFound, + error => DbError::Other(error.into()), + } + } +} + +impl std::error::Error for DbError {} + +pub type DbResult = Result; + +#[derive(Debug, PartialEq)] +pub enum DbType { + Postgres, + Sqlite, + Unknown, +} + +#[derive(Clone, Deserialize, Serialize)] +pub struct DbSettings { + pub db_uri: String, + /// Optional URI for read replicas. If set, read-only queries will use this connection. + pub read_db_uri: Option, +} + +impl DbSettings { + pub fn db_type(&self) -> DbType { + if self.db_uri.starts_with("postgres://") || self.db_uri.starts_with("postgresql://") { + DbType::Postgres + } else if self.db_uri.starts_with("sqlite://") { + DbType::Sqlite + } else { + DbType::Unknown + } + } +} + +fn redact_db_uri(uri: &str) -> String { + url::Url::parse(uri) + .map(|mut url| { + let _ = url.set_password(Some("****")); + url.to_string() + }) + .unwrap_or_else(|_| uri.to_string()) +} + +// Do our best to redact passwords so they're not logged in the event of an error. +impl Debug for DbSettings { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.db_type() == DbType::Postgres { + let redacted_uri = redact_db_uri(&self.db_uri); + let redacted_read_uri = self.read_db_uri.as_ref().map(|uri| redact_db_uri(uri)); + f.debug_struct("DbSettings") + .field("db_uri", &redacted_uri) + .field("read_db_uri", &redacted_read_uri) + .finish() + } else { + f.debug_struct("DbSettings") + .field("db_uri", &self.db_uri) + .field("read_db_uri", &self.read_db_uri) + .finish() + } + } +} + +#[async_trait] +pub trait Database: Sized + Clone + Send + Sync + 'static { + async fn new(settings: &DbSettings) -> DbResult; + + async fn get_session(&self, token: &str) -> DbResult; + async fn get_session_user(&self, token: &str) -> DbResult; + async fn add_session(&self, session: &NewSession) -> DbResult<()>; + + async fn get_user(&self, username: &str) -> DbResult; + async fn get_user_session(&self, u: &User) -> DbResult; + async fn add_user(&self, user: &NewUser) -> DbResult; + + async fn update_user_password(&self, u: &User) -> DbResult<()>; + + async fn count_history(&self, user: &User) -> DbResult; + async fn count_history_cached(&self, user: &User) -> DbResult; + + async fn delete_user(&self, u: &User) -> DbResult<()>; + async fn delete_history(&self, user: &User, id: String) -> DbResult<()>; + async fn deleted_history(&self, user: &User) -> DbResult>; + async fn delete_store(&self, user: &User) -> DbResult<()>; + + async fn add_records(&self, user: &User, record: &[Record]) -> DbResult<()>; + async fn next_records( + &self, + user: &User, + host: HostId, + tag: String, + start: Option, + count: u64, + ) -> DbResult>>; + + // Return the tail record ID for each store, so (HostID, Tag, TailRecordID) + async fn status(&self, user: &User) -> DbResult; + + async fn count_history_range(&self, user: &User, range: Range) + -> DbResult; + + async fn list_history( + &self, + user: &User, + created_after: OffsetDateTime, + since: OffsetDateTime, + host: &str, + page_size: i64, + ) -> DbResult>; + + async fn add_history(&self, history: &[NewHistory]) -> DbResult<()>; + + async fn oldest_history(&self, user: &User) -> DbResult; + + #[instrument(skip_all)] + async fn calendar( + &self, + user: &User, + period: TimePeriod, + tz: UtcOffset, + ) -> DbResult> { + let mut ret = HashMap::new(); + let iter: Box)>> + Send> = match period { + TimePeriod::Year => { + // First we need to work out how far back to calculate. Get the + // oldest history item + let oldest = self + .oldest_history(user) + .await? + .timestamp + .to_offset(tz) + .year(); + let current_year = OffsetDateTime::now_utc().to_offset(tz).year(); + + // All the years we need to get data for + // The upper bound is exclusive, so include current +1 + let years = oldest..current_year + 1; + + Box::new(years.map(|year| { + let start = Date::from_calendar_date(year, time::Month::January, 1)?; + let end = Date::from_calendar_date(year + 1, time::Month::January, 1)?; + + Ok((year as u64, start..end)) + })) + } + + TimePeriod::Month { year } => { + let months = + std::iter::successors(Some(Month::January), |m| Some(m.next())).take(12); + + Box::new(months.map(move |month| { + let start = Date::from_calendar_date(year, month, 1)?; + let days = start.month().length(year); + let end = start + Duration::days(days as i64); + + Ok((month as u64, start..end)) + })) + } + + TimePeriod::Day { year, month } => { + let days = 1..month.length(year); + Box::new(days.map(move |day| { + let start = Date::from_calendar_date(year, month, day)?; + let end = start + .next_day() + .ok_or_else(|| DbError::Other(eyre::eyre!("no next day?")))?; + + Ok((day as u64, start..end)) + })) + } + }; + + for x in iter { + let (index, range) = x?; + + let start = range.start.with_time(Time::MIDNIGHT).assume_offset(tz); + let end = range.end.with_time(Time::MIDNIGHT).assume_offset(tz); + + let count = self.count_history_range(user, start..end).await?; + + ret.insert( + index, + TimePeriodInfo { + count: count as u64, + hash: "".to_string(), + }, + ); + } + + Ok(ret) + } +} + +pub fn into_utc(x: OffsetDateTime) -> PrimitiveDateTime { + let x = x.to_offset(UtcOffset::UTC); + PrimitiveDateTime::new(x.date(), x.time()) +} + +#[cfg(test)] +mod tests { + use time::macros::datetime; + + use crate::into_utc; + + #[test] + fn utc() { + let dt = datetime!(2023-09-26 15:11:02 +05:30); + assert_eq!(into_utc(dt), datetime!(2023-09-26 09:41:02)); + assert_eq!(into_utc(dt).assume_utc(), dt); + + let dt = datetime!(2023-09-26 15:11:02 -07:00); + assert_eq!(into_utc(dt), datetime!(2023-09-26 22:11:02)); + assert_eq!(into_utc(dt).assume_utc(), dt); + + let dt = datetime!(2023-09-26 15:11:02 +00:00); + assert_eq!(into_utc(dt), datetime!(2023-09-26 15:11:02)); + assert_eq!(into_utc(dt).assume_utc(), dt); + } +} diff --git a/crates/turtle/src/atuin_server_database/models.rs b/crates/turtle/src/atuin_server_database/models.rs new file mode 100644 index 00000000..b71a9bc9 --- /dev/null +++ b/crates/turtle/src/atuin_server_database/models.rs @@ -0,0 +1,52 @@ +use time::OffsetDateTime; + +pub struct History { + pub id: i64, + pub client_id: String, // a client generated ID + pub user_id: i64, + pub hostname: String, + pub timestamp: OffsetDateTime, + + /// All the data we have about this command, encrypted. + /// + /// Currently this is an encrypted msgpack object, but this may change in the future. + pub data: String, + + pub created_at: OffsetDateTime, +} + +pub struct NewHistory { + pub client_id: String, + pub user_id: i64, + pub hostname: String, + pub timestamp: OffsetDateTime, + + /// All the data we have about this command, encrypted. + /// + /// Currently this is an encrypted msgpack object, but this may change in the future. + pub data: String, +} + +pub struct User { + pub id: i64, + pub username: String, + pub email: String, + pub password: String, +} + +pub struct Session { + pub id: i64, + pub user_id: i64, + pub token: String, +} + +pub struct NewUser { + pub username: String, + pub email: String, + pub password: String, +} + +pub struct NewSession { + pub user_id: i64, + pub token: String, +} diff --git a/crates/turtle/src/atuin_server_postgres/mod.rs b/crates/turtle/src/atuin_server_postgres/mod.rs new file mode 100644 index 00000000..f506cf25 --- /dev/null +++ b/crates/turtle/src/atuin_server_postgres/mod.rs @@ -0,0 +1,583 @@ +use std::collections::HashMap; +use std::ops::Range; + +use rand::Rng; + +use crate::atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}; +use crate::atuin_server_database::models::{ + History, NewHistory, NewSession, NewUser, Session, User, +}; +use crate::atuin_server_database::{Database, DbError, DbResult, DbSettings, into_utc}; +use async_trait::async_trait; +use futures_util::TryStreamExt; +use sqlx::Row; +use sqlx::postgres::PgPoolOptions; + +use time::OffsetDateTime; +use tracing::instrument; +use uuid::Uuid; +use wrappers::{DbHistory, DbRecord, DbSession, DbUser}; + +mod wrappers; + +const MIN_PG_VERSION: u32 = 14; + +#[derive(Clone)] +pub struct Postgres { + pool: sqlx::Pool, + /// Optional read replica pool for read-only queries + read_pool: Option>, +} + +impl Postgres { + /// Returns the appropriate pool for read operations. + /// Uses read_pool if available, otherwise falls back to the primary pool. + fn read_pool(&self) -> &sqlx::Pool { + self.read_pool.as_ref().unwrap_or(&self.pool) + } +} + +#[async_trait] +impl Database for Postgres { + async fn new(settings: &DbSettings) -> DbResult { + let pool = PgPoolOptions::new() + .max_connections(100) + .connect(settings.db_uri.as_str()) + .await?; + + // Call server_version_num to get the DB server's major version number + // The call returns None for servers older than 8.x. + let pg_major_version: u32 = + pool.acquire() + .await? + .server_version_num() + .ok_or(DbError::Other(eyre::Report::msg( + "could not get PostgreSQL version", + )))? + / 10000; + + if pg_major_version < MIN_PG_VERSION { + return Err(DbError::Other(eyre::Report::msg(format!( + "unsupported PostgreSQL version {pg_major_version}, minimum required is {MIN_PG_VERSION}" + )))); + } + + sqlx::migrate!("./migrations") + .run(&pool) + .await + .map_err(|error| DbError::Other(error.into()))?; + + // Create read replica pool if configured + let read_pool = if let Some(read_db_uri) = &settings.read_db_uri { + tracing::info!("Connecting to read replica database"); + let read_pool = PgPoolOptions::new() + .max_connections(100) + .connect(read_db_uri.as_str()) + .await?; + + // Verify the read replica is also a supported PostgreSQL version + let read_pg_major_version: u32 = read_pool + .acquire() + .await? + .server_version_num() + .ok_or(DbError::Other(eyre::Report::msg( + "could not get PostgreSQL version from read replica", + )))? + / 10000; + + if read_pg_major_version < MIN_PG_VERSION { + return Err(DbError::Other(eyre::Report::msg(format!( + "unsupported PostgreSQL version {read_pg_major_version} on read replica, minimum required is {MIN_PG_VERSION}" + )))); + } + + Some(read_pool) + } else { + None + }; + + Ok(Self { pool, read_pool }) + } + + #[instrument(skip_all)] + async fn get_session(&self, token: &str) -> DbResult { + sqlx::query_as("select id, user_id, token from sessions where token = $1") + .bind(token) + .fetch_one(self.read_pool()) + .await + .map_err(Into::into) + .map(|DbSession(session)| session) + } + + #[instrument(skip_all)] + async fn get_user(&self, username: &str) -> DbResult { + sqlx::query_as("select id, username, email, password from users where username = $1") + .bind(username) + .fetch_one(self.read_pool()) + .await + .map_err(Into::into) + .map(|DbUser(user)| user) + } + + #[instrument(skip_all)] + async fn get_session_user(&self, token: &str) -> DbResult { + sqlx::query_as( + "select users.id, users.username, users.email, users.password from users + inner join sessions + on users.id = sessions.user_id + and sessions.token = $1", + ) + .bind(token) + .fetch_one(self.read_pool()) + .await + .map_err(Into::into) + .map(|DbUser(user)| user) + } + + #[instrument(skip_all)] + async fn count_history(&self, user: &User) -> DbResult { + // The cache is new, and the user might not yet have a cache value. + // They will have one as soon as they post up some new history, but handle that + // edge case. + + let res: (i64,) = sqlx::query_as( + "select count(1) from history + where user_id = $1", + ) + .bind(user.id) + .fetch_one(self.read_pool()) + .await?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn count_history_cached(&self, user: &User) -> DbResult { + let res: (i32,) = sqlx::query_as( + "select total from total_history_count_user + where user_id = $1", + ) + .bind(user.id) + .fetch_one(self.read_pool()) + .await?; + + Ok(res.0 as i64) + } + + async fn delete_store(&self, user: &User) -> DbResult<()> { + let mut tx = self.pool.begin().await?; + + sqlx::query( + "delete from store + where user_id = $1", + ) + .bind(user.id) + .execute(&mut *tx) + .await?; + + sqlx::query( + "delete from store_idx_cache + where user_id = $1", + ) + .bind(user.id) + .execute(&mut *tx) + .await?; + + tx.commit().await?; + + Ok(()) + } + + async fn delete_history(&self, user: &User, id: String) -> DbResult<()> { + sqlx::query( + "update history + set deleted_at = $3 + where user_id = $1 + and client_id = $2 + and deleted_at is null", // don't just keep setting it + ) + .bind(user.id) + .bind(id) + .bind(OffsetDateTime::now_utc()) + .fetch_all(&self.pool) + .await?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn deleted_history(&self, user: &User) -> DbResult> { + // The cache is new, and the user might not yet have a cache value. + // They will have one as soon as they post up some new history, but handle that + // edge case. + + let res = sqlx::query( + "select client_id from history + where user_id = $1 + and deleted_at is not null", + ) + .bind(user.id) + .fetch_all(self.read_pool()) + .await?; + + let res = res + .iter() + .map(|row| row.get::("client_id")) + .collect(); + + Ok(res) + } + + #[instrument(skip_all)] + async fn count_history_range( + &self, + user: &User, + range: Range, + ) -> DbResult { + let res: (i64,) = sqlx::query_as( + "select count(1) from history + where user_id = $1 + and timestamp >= $2::date + and timestamp < $3::date", + ) + .bind(user.id) + .bind(into_utc(range.start)) + .bind(into_utc(range.end)) + .fetch_one(self.read_pool()) + .await?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn list_history( + &self, + user: &User, + created_after: OffsetDateTime, + since: OffsetDateTime, + host: &str, + page_size: i64, + ) -> DbResult> { + let res = sqlx::query_as( + "select id, client_id, user_id, hostname, timestamp, data, created_at from history + where user_id = $1 + and hostname != $2 + and created_at >= $3 + and timestamp >= $4 + order by timestamp asc + limit $5", + ) + .bind(user.id) + .bind(host) + .bind(into_utc(created_after)) + .bind(into_utc(since)) + .bind(page_size) + .fetch(self.read_pool()) + .map_ok(|DbHistory(h)| h) + .try_collect() + .await?; + + Ok(res) + } + + #[instrument(skip_all)] + async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> { + let mut tx = self.pool.begin().await?; + + for i in history { + let client_id: &str = &i.client_id; + let hostname: &str = &i.hostname; + let data: &str = &i.data; + + sqlx::query( + "insert into history + (client_id, user_id, hostname, timestamp, data) + values ($1, $2, $3, $4, $5) + on conflict do nothing + ", + ) + .bind(client_id) + .bind(i.user_id) + .bind(hostname) + .bind(i.timestamp) + .bind(data) + .execute(&mut *tx) + .await?; + } + + tx.commit().await?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn delete_user(&self, u: &User) -> DbResult<()> { + sqlx::query("delete from sessions where user_id = $1") + .bind(u.id) + .execute(&self.pool) + .await?; + + sqlx::query("delete from history where user_id = $1") + .bind(u.id) + .execute(&self.pool) + .await?; + + sqlx::query("delete from store where user_id = $1") + .bind(u.id) + .execute(&self.pool) + .await?; + + sqlx::query("delete from total_history_count_user where user_id = $1") + .bind(u.id) + .execute(&self.pool) + .await?; + + sqlx::query("delete from users where id = $1") + .bind(u.id) + .execute(&self.pool) + .await?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn update_user_password(&self, user: &User) -> DbResult<()> { + sqlx::query( + "update users + set password = $1 + where id = $2", + ) + .bind(&user.password) + .bind(user.id) + .execute(&self.pool) + .await?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn add_user(&self, user: &NewUser) -> DbResult { + let email: &str = &user.email; + let username: &str = &user.username; + let password: &str = &user.password; + + let res: (i64,) = sqlx::query_as( + "insert into users + (username, email, password) + values($1, $2, $3) + returning id", + ) + .bind(username) + .bind(email) + .bind(password) + .fetch_one(&self.pool) + .await?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn add_session(&self, session: &NewSession) -> DbResult<()> { + let token: &str = &session.token; + + sqlx::query( + "insert into sessions + (user_id, token) + values($1, $2)", + ) + .bind(session.user_id) + .bind(token) + .execute(&self.pool) + .await?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn get_user_session(&self, u: &User) -> DbResult { + sqlx::query_as("select id, user_id, token from sessions where user_id = $1") + .bind(u.id) + .fetch_one(self.read_pool()) + .await + .map_err(Into::into) + .map(|DbSession(session)| session) + } + + #[instrument(skip_all)] + async fn oldest_history(&self, user: &User) -> DbResult { + sqlx::query_as( + "select id, client_id, user_id, hostname, timestamp, data, created_at from history + where user_id = $1 + order by timestamp asc + limit 1", + ) + .bind(user.id) + .fetch_one(self.read_pool()) + .await + .map_err(Into::into) + .map(|DbHistory(h)| h) + } + + #[instrument(skip_all)] + async fn add_records(&self, user: &User, records: &[Record]) -> DbResult<()> { + let mut tx = self.pool.begin().await?; + + // We won't have uploaded this data if it wasn't the max. Therefore, we can deduce the max + // idx without having to make further database queries. Doing the query on this small + // amount of data should be much, much faster. + // + // Worst case, say we get this wrong. We end up caching data that isn't actually the max + // idx, so clients upload again. The cache logic can be verified with a sql query anyway :) + + let mut heads = HashMap::<(HostId, &str), u64>::new(); + + for i in records { + let id = crate::atuin_common::utils::uuid_v7(); + + let result = sqlx::query( + "insert into store + (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id) + values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + on conflict do nothing + ", + ) + .bind(id) + .bind(i.id) + .bind(i.host.id) + .bind(i.idx as i64) + .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time + .bind(&i.version) + .bind(&i.tag) + .bind(&i.data.data) + .bind(&i.data.content_encryption_key) + .bind(user.id) + .execute(&mut *tx) + .await?; + + // Only update heads if we actually inserted the record + if result.rows_affected() > 0 { + heads + .entry((i.host.id, &i.tag)) + .and_modify(|e| { + if i.idx > *e { + *e = i.idx + } + }) + .or_insert(i.idx); + } + } + + // we've built the map of heads for this push, so commit it to the database + for ((host, tag), idx) in heads { + sqlx::query( + "insert into store_idx_cache + (user_id, host, tag, idx) + values ($1, $2, $3, $4) + on conflict(user_id, host, tag) do update set idx = greatest(store_idx_cache.idx, $4) + ", + ) + .bind(user.id) + .bind(host) + .bind(tag) + .bind(idx as i64) + .execute(&mut *tx) + .await + ?; + } + + tx.commit().await?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn next_records( + &self, + user: &User, + host: HostId, + tag: String, + start: Option, + count: u64, + ) -> DbResult>> { + tracing::debug!("{:?} - {:?} - {:?}", host, tag, start); + let start = start.unwrap_or(0); + + let records: Result, DbError> = sqlx::query_as( + "select client_id, host, idx, timestamp, version, tag, data, cek from store + where user_id = $1 + and tag = $2 + and host = $3 + and idx >= $4 + order by idx asc + limit $5", + ) + .bind(user.id) + .bind(tag.clone()) + .bind(host) + .bind(start as i64) + .bind(count as i64) + .fetch_all(self.read_pool()) + .await + .map_err(Into::into); + + let ret = match records { + Ok(records) => { + let records: Vec> = records + .into_iter() + .map(|f| { + let record: Record = f.into(); + record + }) + .collect(); + + records + } + Err(DbError::NotFound) => { + tracing::debug!("no records found in store: {:?}/{}", host, tag); + return Ok(vec![]); + } + Err(e) => return Err(e), + }; + + Ok(ret) + } + + async fn status(&self, user: &User) -> DbResult { + const STATUS_SQL: &str = + "select host, tag, max(idx) from store where user_id = $1 group by host, tag"; + + // If IDX_CACHE_ROLLOUT is set, then we + // 1. Read the value of the var, use it as a % chance of using the cache + // 2. If we use the cache, just read from the cache table + // 3. If we don't use the cache, read from the store table + // IDX_CACHE_ROLLOUT should be between 0 and 100. + + let idx_cache_rollout = std::env::var("IDX_CACHE_ROLLOUT").unwrap_or("0".to_string()); + let idx_cache_rollout = idx_cache_rollout.parse::().unwrap_or(0.0); + let use_idx_cache = rand::thread_rng().gen_bool(idx_cache_rollout / 100.0); + + let mut res: Vec<(Uuid, String, i64)> = if use_idx_cache { + tracing::debug!("using idx cache for user {}", user.id); + sqlx::query_as("select host, tag, idx from store_idx_cache where user_id = $1") + .bind(user.id) + .fetch_all(self.read_pool()) + .await? + } else { + tracing::debug!("using aggregate query for user {}", user.id); + sqlx::query_as(STATUS_SQL) + .bind(user.id) + .fetch_all(self.read_pool()) + .await? + }; + + res.sort(); + + let mut status = RecordStatus::new(); + + for i in res.iter() { + status.set_raw(HostId(i.0), i.1.clone(), i.2 as u64); + } + + Ok(status) + } +} diff --git a/crates/turtle/src/atuin_server_postgres/wrappers.rs b/crates/turtle/src/atuin_server_postgres/wrappers.rs new file mode 100644 index 00000000..214b255d --- /dev/null +++ b/crates/turtle/src/atuin_server_postgres/wrappers.rs @@ -0,0 +1,77 @@ +use ::sqlx::{FromRow, Result}; +use crate::atuin_common::record::{EncryptedData, Host, Record}; +use crate::atuin_server_database::models::{History, Session, User}; +use sqlx::{Row, postgres::PgRow}; +use time::PrimitiveDateTime; + +pub struct DbUser(pub User); +pub struct DbSession(pub Session); +pub struct DbHistory(pub History); +pub struct DbRecord(pub Record); + +impl<'a> FromRow<'a, PgRow> for DbUser { + fn from_row(row: &'a PgRow) -> Result { + Ok(Self(User { + id: row.try_get("id")?, + username: row.try_get("username")?, + email: row.try_get("email")?, + password: row.try_get("password")?, + })) + } +} + +impl<'a> ::sqlx::FromRow<'a, PgRow> for DbSession { + fn from_row(row: &'a PgRow) -> ::sqlx::Result { + Ok(Self(Session { + id: row.try_get("id")?, + user_id: row.try_get("user_id")?, + token: row.try_get("token")?, + })) + } +} + +impl<'a> ::sqlx::FromRow<'a, PgRow> for DbHistory { + fn from_row(row: &'a PgRow) -> ::sqlx::Result { + Ok(Self(History { + id: row.try_get("id")?, + client_id: row.try_get("client_id")?, + user_id: row.try_get("user_id")?, + hostname: row.try_get("hostname")?, + timestamp: row + .try_get::("timestamp")? + .assume_utc(), + data: row.try_get("data")?, + created_at: row + .try_get::("created_at")? + .assume_utc(), + })) + } +} + +impl<'a> ::sqlx::FromRow<'a, PgRow> for DbRecord { + fn from_row(row: &'a PgRow) -> ::sqlx::Result { + let timestamp: i64 = row.try_get("timestamp")?; + let idx: i64 = row.try_get("idx")?; + + let data = EncryptedData { + data: row.try_get("data")?, + content_encryption_key: row.try_get("cek")?, + }; + + Ok(Self(Record { + id: row.try_get("client_id")?, + host: Host::new(row.try_get("host")?), + idx: idx as u64, + timestamp: timestamp as u64, + version: row.try_get("version")?, + tag: row.try_get("tag")?, + data, + })) + } +} + +impl From for Record { + fn from(other: DbRecord) -> Record { + Record { ..other.0 } + } +} diff --git a/crates/turtle/src/atuin_server_sqlite/mod.rs b/crates/turtle/src/atuin_server_sqlite/mod.rs new file mode 100644 index 00000000..3470a2f1 --- /dev/null +++ b/crates/turtle/src/atuin_server_sqlite/mod.rs @@ -0,0 +1,430 @@ +use std::str::FromStr; + +use crate::atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}; +use crate::atuin_server_database::{ + Database, DbError, DbResult, DbSettings, into_utc, + models::{History, NewHistory, NewSession, NewUser, Session, User}, +}; +use async_trait::async_trait; +use futures_util::TryStreamExt; +use sqlx::{ + Row, + sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions}, + types::Uuid, +}; +use tracing::instrument; +use wrappers::{DbHistory, DbRecord, DbSession, DbUser}; + +mod wrappers; + +#[derive(Clone)] +pub struct Sqlite { + pool: sqlx::Pool, +} + +#[async_trait] +impl Database for Sqlite { + async fn new(settings: &DbSettings) -> DbResult { + let opts = SqliteConnectOptions::from_str(&settings.db_uri)? + .journal_mode(SqliteJournalMode::Wal) + .create_if_missing(true); + + let pool = SqlitePoolOptions::new().connect_with(opts).await?; + + sqlx::migrate!("./migrations") + .run(&pool) + .await + .map_err(|error| DbError::Other(error.into()))?; + + Ok(Self { pool }) + } + + #[instrument(skip_all)] + async fn get_session(&self, token: &str) -> DbResult { + sqlx::query_as("select id, user_id, token from sessions where token = $1") + .bind(token) + .fetch_one(&self.pool) + .await + .map_err(Into::into) + .map(|DbSession(session)| session) + } + + #[instrument(skip_all)] + async fn get_session_user(&self, token: &str) -> DbResult { + sqlx::query_as( + "select users.id, users.username, users.email, users.password from users + inner join sessions + on users.id = sessions.user_id + and sessions.token = $1", + ) + .bind(token) + .fetch_one(&self.pool) + .await + .map_err(Into::into) + .map(|DbUser(user)| user) + } + + #[instrument(skip_all)] + async fn add_session(&self, session: &NewSession) -> DbResult<()> { + let token: &str = &session.token; + + sqlx::query( + "insert into sessions + (user_id, token) + values($1, $2)", + ) + .bind(session.user_id) + .bind(token) + .execute(&self.pool) + .await?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn get_user(&self, username: &str) -> DbResult { + sqlx::query_as("select id, username, email, password from users where username = $1") + .bind(username) + .fetch_one(&self.pool) + .await + .map_err(Into::into) + .map(|DbUser(user)| user) + } + + #[instrument(skip_all)] + async fn get_user_session(&self, u: &User) -> DbResult { + sqlx::query_as("select id, user_id, token from sessions where user_id = $1") + .bind(u.id) + .fetch_one(&self.pool) + .await + .map_err(Into::into) + .map(|DbSession(session)| session) + } + + #[instrument(skip_all)] + async fn add_user(&self, user: &NewUser) -> DbResult { + let email: &str = &user.email; + let username: &str = &user.username; + let password: &str = &user.password; + + let res: (i64,) = sqlx::query_as( + "insert into users + (username, email, password) + values($1, $2, $3) + returning id", + ) + .bind(username) + .bind(email) + .bind(password) + .fetch_one(&self.pool) + .await?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn update_user_password(&self, user: &User) -> DbResult<()> { + sqlx::query( + "update users + set password = $1 + where id = $2", + ) + .bind(&user.password) + .bind(user.id) + .execute(&self.pool) + .await?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn count_history(&self, user: &User) -> DbResult { + // The cache is new, and the user might not yet have a cache value. + // They will have one as soon as they post up some new history, but handle that + // edge case. + + let res: (i64,) = sqlx::query_as( + "select count(1) from history + where user_id = $1", + ) + .bind(user.id) + .fetch_one(&self.pool) + .await?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn count_history_cached(&self, _user: &User) -> DbResult { + Err(DbError::NotFound) + } + + #[instrument(skip_all)] + async fn delete_user(&self, u: &User) -> DbResult<()> { + sqlx::query("delete from sessions where user_id = $1") + .bind(u.id) + .execute(&self.pool) + .await?; + + sqlx::query("delete from users where id = $1") + .bind(u.id) + .execute(&self.pool) + .await?; + + sqlx::query("delete from history where user_id = $1") + .bind(u.id) + .execute(&self.pool) + .await?; + + Ok(()) + } + + async fn delete_history(&self, user: &User, id: String) -> DbResult<()> { + sqlx::query( + "update history + set deleted_at = $3 + where user_id = $1 + and client_id = $2 + and deleted_at is null", // don't just keep setting it + ) + .bind(user.id) + .bind(id) + .bind(time::OffsetDateTime::now_utc()) + .fetch_all(&self.pool) + .await?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn deleted_history(&self, user: &User) -> DbResult> { + // The cache is new, and the user might not yet have a cache value. + // They will have one as soon as they post up some new history, but handle that + // edge case. + + let res = sqlx::query( + "select client_id from history + where user_id = $1 + and deleted_at is not null", + ) + .bind(user.id) + .fetch_all(&self.pool) + .await?; + + let res = res.iter().map(|row| row.get("client_id")).collect(); + + Ok(res) + } + + async fn delete_store(&self, user: &User) -> DbResult<()> { + sqlx::query( + "delete from store + where user_id = $1", + ) + .bind(user.id) + .execute(&self.pool) + .await?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn add_records(&self, user: &User, records: &[Record]) -> DbResult<()> { + let mut tx = self.pool.begin().await?; + + for i in records { + let id = crate::atuin_common::utils::uuid_v7(); + + sqlx::query( + "insert into store + (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id) + values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + on conflict do nothing + ", + ) + .bind(id) + .bind(i.id) + .bind(i.host.id) + .bind(i.idx as i64) + .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time + .bind(&i.version) + .bind(&i.tag) + .bind(&i.data.data) + .bind(&i.data.content_encryption_key) + .bind(user.id) + .execute(&mut *tx) + .await?; + } + + tx.commit().await?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn next_records( + &self, + user: &User, + host: HostId, + tag: String, + start: Option, + count: u64, + ) -> DbResult>> { + tracing::debug!("{:?} - {:?} - {:?}", host, tag, start); + let start = start.unwrap_or(0); + + let records: Result, DbError> = sqlx::query_as( + "select client_id, host, idx, timestamp, version, tag, data, cek from store + where user_id = $1 + and tag = $2 + and host = $3 + and idx >= $4 + order by idx asc + limit $5", + ) + .bind(user.id) + .bind(tag.clone()) + .bind(host) + .bind(start as i64) + .bind(count as i64) + .fetch_all(&self.pool) + .await + .map_err(Into::into); + + let ret = match records { + Ok(records) => { + let records: Vec> = records + .into_iter() + .map(|f| { + let record: Record = f.into(); + record + }) + .collect(); + + records + } + Err(DbError::NotFound) => { + tracing::debug!("no records found in store: {:?}/{}", host, tag); + return Ok(vec![]); + } + Err(e) => return Err(e), + }; + + Ok(ret) + } + + async fn status(&self, user: &User) -> DbResult { + const STATUS_SQL: &str = + "select host, tag, max(idx) from store where user_id = $1 group by host, tag"; + + let res: Vec<(Uuid, String, i64)> = sqlx::query_as(STATUS_SQL) + .bind(user.id) + .fetch_all(&self.pool) + .await?; + + let mut status = RecordStatus::new(); + + for i in res { + status.set_raw(HostId(i.0), i.1, i.2 as u64); + } + + Ok(status) + } + + #[instrument(skip_all)] + async fn count_history_range( + &self, + user: &User, + range: std::ops::Range, + ) -> DbResult { + let res: (i64,) = sqlx::query_as( + "select count(1) from history + where user_id = $1 + and timestamp >= $2::date + and timestamp < $3::date", + ) + .bind(user.id) + .bind(into_utc(range.start)) + .bind(into_utc(range.end)) + .fetch_one(&self.pool) + .await?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn list_history( + &self, + user: &User, + created_after: time::OffsetDateTime, + since: time::OffsetDateTime, + host: &str, + page_size: i64, + ) -> DbResult> { + let res = sqlx::query_as( + "select id, client_id, user_id, hostname, timestamp, data, created_at from history + where user_id = $1 + and hostname != $2 + and created_at >= $3 + and timestamp >= $4 + order by timestamp asc + limit $5", + ) + .bind(user.id) + .bind(host) + .bind(into_utc(created_after)) + .bind(into_utc(since)) + .bind(page_size) + .fetch(&self.pool) + .map_ok(|DbHistory(h)| h) + .try_collect() + .await?; + + Ok(res) + } + + #[instrument(skip_all)] + async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> { + let mut tx = self.pool.begin().await?; + + for i in history { + let client_id: &str = &i.client_id; + let hostname: &str = &i.hostname; + let data: &str = &i.data; + + sqlx::query( + "insert into history + (client_id, user_id, hostname, timestamp, data) + values ($1, $2, $3, $4, $5) + on conflict do nothing + ", + ) + .bind(client_id) + .bind(i.user_id) + .bind(hostname) + .bind(i.timestamp) + .bind(data) + .execute(&mut *tx) + .await?; + } + + tx.commit().await?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn oldest_history(&self, user: &User) -> DbResult { + sqlx::query_as( + "select id, client_id, user_id, hostname, timestamp, data, created_at from history + where user_id = $1 + order by timestamp asc + limit 1", + ) + .bind(user.id) + .fetch_one(&self.pool) + .await + .map_err(Into::into) + .map(|DbHistory(h)| h) + } +} diff --git a/crates/turtle/src/atuin_server_sqlite/wrappers.rs b/crates/turtle/src/atuin_server_sqlite/wrappers.rs new file mode 100644 index 00000000..5aa7a982 --- /dev/null +++ b/crates/turtle/src/atuin_server_sqlite/wrappers.rs @@ -0,0 +1,72 @@ +use ::sqlx::{FromRow, Result}; +use crate::atuin_common::record::{EncryptedData, Host, Record}; +use crate::atuin_server_database::models::{History, Session, User}; +use sqlx::{Row, sqlite::SqliteRow}; + +pub struct DbUser(pub User); +pub struct DbSession(pub Session); +pub struct DbHistory(pub History); +pub struct DbRecord(pub Record); + +impl<'a> FromRow<'a, SqliteRow> for DbUser { + fn from_row(row: &'a SqliteRow) -> Result { + Ok(Self(User { + id: row.try_get("id")?, + username: row.try_get("username")?, + email: row.try_get("email")?, + password: row.try_get("password")?, + })) + } +} + +impl<'a> ::sqlx::FromRow<'a, SqliteRow> for DbSession { + fn from_row(row: &'a SqliteRow) -> ::sqlx::Result { + Ok(Self(Session { + id: row.try_get("id")?, + user_id: row.try_get("user_id")?, + token: row.try_get("token")?, + })) + } +} + +impl<'a> ::sqlx::FromRow<'a, SqliteRow> for DbHistory { + fn from_row(row: &'a SqliteRow) -> ::sqlx::Result { + Ok(Self(History { + id: row.try_get("id")?, + client_id: row.try_get("client_id")?, + user_id: row.try_get("user_id")?, + hostname: row.try_get("hostname")?, + timestamp: row.try_get("timestamp")?, + data: row.try_get("data")?, + created_at: row.try_get("created_at")?, + })) + } +} + +impl<'a> ::sqlx::FromRow<'a, SqliteRow> for DbRecord { + fn from_row(row: &'a SqliteRow) -> ::sqlx::Result { + let idx: i64 = row.try_get("idx")?; + let timestamp: i64 = row.try_get("timestamp")?; + + let data = EncryptedData { + data: row.try_get("data")?, + content_encryption_key: row.try_get("cek")?, + }; + + Ok(Self(Record { + id: row.try_get("client_id")?, + host: Host::new(row.try_get("host")?), + idx: idx as u64, + timestamp: timestamp as u64, + version: row.try_get("version")?, + tag: row.try_get("tag")?, + data, + })) + } +} + +impl From for Record { + fn from(other: DbRecord) -> Record { + Record { ..other.0 } + } +} diff --git a/crates/turtle/src/command/CONTRIBUTORS b/crates/turtle/src/command/CONTRIBUTORS new file mode 120000 index 00000000..1ca4115a --- /dev/null +++ b/crates/turtle/src/command/CONTRIBUTORS @@ -0,0 +1 @@ +../../../../CONTRIBUTORS \ No newline at end of file diff --git a/crates/turtle/src/command/client.rs b/crates/turtle/src/command/client.rs new file mode 100644 index 00000000..20d85303 --- /dev/null +++ b/crates/turtle/src/command/client.rs @@ -0,0 +1,371 @@ +use std::fs::{self, OpenOptions}; +use std::path::{Path, PathBuf}; + +use clap::Subcommand; +use eyre::{Result, WrapErr}; + +use crate::atuin_client::{ + database::Sqlite, record::sqlite_store::SqliteStore, settings::Settings, theme, +}; +use tracing_appender::rolling::{RollingFileAppender, Rotation}; +use tracing_subscriber::{ + Layer, filter::EnvFilter, filter::LevelFilter, fmt, fmt::format::FmtSpan, prelude::*, +}; + +fn cleanup_old_logs(log_dir: &Path, prefix: &str, retention_days: u64) { + let cutoff = std::time::SystemTime::now() + - std::time::Duration::from_secs(retention_days * 24 * 60 * 60); + + let Ok(entries) = fs::read_dir(log_dir) else { + return; + }; + + for entry in entries.flatten() { + let path = entry.path(); + let Some(name) = path.file_name().and_then(|n| n.to_str()) else { + continue; + }; + + // Match files like "search.log.2024-02-23" or "daemon.log.2024-02-23" + if !name.starts_with(prefix) || name == prefix { + continue; + } + + if let Ok(metadata) = entry.metadata() + && let Ok(modified) = metadata.modified() + && modified < cutoff + { + let _ = fs::remove_file(&path); + } + } +} + +#[cfg(feature = "sync")] +mod sync; + +#[cfg(feature = "sync")] +mod account; + +#[cfg(feature = "daemon")] +mod daemon; + +mod config; +mod default_config; +mod doctor; +mod history; +mod import; +mod info; +mod init; +mod search; +mod server; +mod setup; +mod stats; +mod store; +mod wrapped; + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Setup Atuin features + #[command()] + Setup, + + /// Manipulate shell history + #[command(subcommand)] + History(history::Cmd), + + /// Import shell history from file + #[command(subcommand)] + Import(import::Cmd), + + /// Calculate statistics for your history + Stats(stats::Cmd), + + /// Interactive history search + Search(search::Cmd), + + #[cfg(feature = "sync")] + #[command(flatten)] + Sync(sync::Cmd), + + /// Manage the atuin server + #[command(subcommand)] + Server(server::Cmd), + + /// Manage your sync account + #[cfg(feature = "sync")] + Account(account::Cmd), + + /// Manage the atuin data store + #[command(subcommand)] + Store(store::Cmd), + + /// Print Atuin's shell init script + #[command()] + Init(init::Cmd), + + /// Information about dotfiles locations and ENV vars + #[command()] + Info, + + /// Run the doctor to check for common issues + #[command()] + Doctor, + + #[command()] + Wrapped { year: Option }, + + /// *Experimental* Manage the background daemon + #[cfg(feature = "daemon")] + #[command()] + Daemon(daemon::Cmd), + + /// Print the default atuin configuration (config.toml) + #[command()] + DefaultConfig, + + #[command(subcommand)] + Config(config::Cmd), +} + +impl Cmd { + pub fn run(self) -> Result<()> { + // Daemonize before creating the async runtime – fork() inside a live + // tokio runtime corrupts its internal state. + #[cfg(all(unix, feature = "daemon"))] + if let Self::Daemon(ref cmd) = self + && cmd.should_daemonize() + { + daemon::daemonize_current_process()?; + } + + let mut runtime = tokio::runtime::Builder::new_current_thread(); + + let runtime = runtime.enable_all().build().unwrap(); + + // For non-history commands, we want to initialize logging and the theme manager before + // doing anything else. History commands are performance-sensitive and run before and after + // every shell command, so we want to skip any unnecessary initialization for them. + let settings = Settings::new().wrap_err("could not load client settings")?; + let theme_manager = theme::ThemeManager::new(settings.theme.debug, None); + let res = runtime.block_on(self.run_inner(settings, theme_manager)); + + runtime.shutdown_timeout(std::time::Duration::from_millis(50)); + + res + } + + #[expect(clippy::too_many_lines, clippy::future_not_send)] + async fn run_inner( + self, + mut settings: Settings, + mut theme_manager: theme::ThemeManager, + ) -> Result<()> { + // ATUIN_LOG env var overrides config file level settings + let env_log_set = std::env::var("ATUIN_LOG").is_ok(); + + // Base filter from env var (or empty if not set) + let base_filter = + EnvFilter::from_env("ATUIN_LOG").add_directive("sqlx_sqlite::regexp=off".parse()?); + + let is_interactive_search = matches!(&self, Self::Search(cmd) if cmd.is_interactive()); + // Use file-based logging for interactive search (TUI mode) + let use_search_logging = is_interactive_search && settings.logs.search_enabled(); + + // Use file-based logging for daemon + #[cfg(feature = "daemon")] + let use_daemon_logging = matches!(&self, Self::Daemon(_)) && settings.logs.daemon_enabled(); + + #[cfg(not(feature = "daemon"))] + let use_daemon_logging = false; + + // Check if daemon should also log to console + #[cfg(feature = "daemon")] + let daemon_show_logs = matches!(&self, Self::Daemon(cmd) if cmd.show_logs()); + + #[cfg(not(feature = "daemon"))] + let daemon_show_logs = false; + + // Set up span timing JSON logs if ATUIN_SPAN is set + let span_path = std::env::var("ATUIN_SPAN").ok().map(|p| { + if p.is_empty() { + "atuin-spans.json".to_string() + } else { + p + } + }); + + // Helper to create span timing layer + macro_rules! make_span_layer { + ($path:expr) => {{ + let span_file = OpenOptions::new() + .create(true) + .truncate(true) + .write(true) + .open($path)?; + Some( + fmt::layer() + .json() + .with_writer(span_file) + .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE) + .with_filter(LevelFilter::TRACE), + ) + }}; + } + + // Build the subscriber with all configured layers + if use_search_logging { + let search_filename = settings.logs.search.file.clone(); + let log_dir = PathBuf::from(&settings.logs.dir); + fs::create_dir_all(&log_dir)?; + + // Clean up old log files + cleanup_old_logs(&log_dir, &search_filename, settings.logs.search_retention()); + + let file_appender = + RollingFileAppender::new(Rotation::DAILY, &log_dir, &search_filename); + + // Use config level unless ATUIN_LOG is set + let filter = if env_log_set { + base_filter + } else { + EnvFilter::default() + .add_directive(settings.logs.search_level().as_directive().parse()?) + .add_directive("sqlx_sqlite::regexp=off".parse()?) + }; + + let base = tracing_subscriber::registry().with( + fmt::layer() + .with_writer(file_appender) + .with_ansi(false) + .with_filter(filter), + ); + + match &span_path { + Some(sp) => { + base.with(make_span_layer!(sp)).init(); + } + None => { + base.init(); + } + } + } else if use_daemon_logging { + let daemon_filename = settings.logs.daemon.file.clone(); + let log_dir = PathBuf::from(&settings.logs.dir); + fs::create_dir_all(&log_dir)?; + + // Clean up old log files + cleanup_old_logs(&log_dir, &daemon_filename, settings.logs.daemon_retention()); + + let file_appender = + RollingFileAppender::new(Rotation::DAILY, &log_dir, &daemon_filename); + + // Use config level unless ATUIN_LOG is set + let file_filter = if env_log_set { + base_filter + } else { + EnvFilter::default() + .add_directive(settings.logs.daemon_level().as_directive().parse()?) + .add_directive("sqlx_sqlite::regexp=off".parse()?) + }; + + let file_layer = fmt::layer() + .with_writer(file_appender) + .with_ansi(false) + .with_filter(file_filter); + + // Optionally add console layer for --show-logs + if daemon_show_logs { + let console_filter = EnvFilter::from_env("ATUIN_LOG") + .add_directive("sqlx_sqlite::regexp=off".parse()?); + + let console_layer = fmt::layer().with_filter(console_filter); + + let base = tracing_subscriber::registry() + .with(file_layer) + .with(console_layer); + + match &span_path { + Some(sp) => { + base.with(make_span_layer!(sp)).init(); + } + None => { + base.init(); + } + } + } else { + let base = tracing_subscriber::registry().with(file_layer); + + match &span_path { + Some(sp) => { + base.with(make_span_layer!(sp)).init(); + } + None => { + base.init(); + } + } + } + } + + tracing::trace!(command = ?self, "client command"); + + // Skip initializing any databases for history + // This is a pretty hot path, as it runs before and after every single command the user + // runs + match self { + Self::History(history) => return history.run(&settings).await, + Self::Init(init) => { + init.run(&settings); + return Ok(()); + } + Self::Doctor => return doctor::run(&settings).await, + Self::Config(config) => return config.run(&settings).await, + _ => {} + } + + let db_path = PathBuf::from(settings.db_path.as_str()); + let record_store_path = PathBuf::from(settings.record_store_path.as_str()); + + let db = Sqlite::new(db_path, settings.local_timeout).await?; + let sqlite_store = SqliteStore::new(record_store_path, settings.local_timeout).await?; + + let theme_name = settings.theme.name.clone(); + let theme = theme_manager.load_theme(theme_name.as_str(), settings.theme.max_depth); + + match self { + Self::Setup => setup::run(&settings).await, + Self::Import(import) => import.run(&db).await, + Self::Stats(stats) => stats.run(&db, &settings, theme).await, + Self::Search(search) => search.run(db, &mut settings, sqlite_store, theme).await, + + #[cfg(feature = "sync")] + Self::Sync(sync) => sync.run(settings, &db, sqlite_store).await, + + #[cfg(feature = "sync")] + Self::Account(account) => account.run(settings, sqlite_store).await, + + Self::Store(store) => store.run(&settings, &db, sqlite_store).await, + + Self::Server(server) => server.run().await, + + Self::Info => { + info::run(&settings); + Ok(()) + } + + Self::DefaultConfig => { + default_config::run(); + Ok(()) + } + + Self::Wrapped { year } => wrapped::run(year, &db, &settings, theme).await, + + #[cfg(feature = "daemon")] + Self::Daemon(cmd) => cmd.run(settings, sqlite_store, db).await, + + Self::History(_) | Self::Init(_) | Self::Doctor | Self::Config(_) => { + unreachable!() + } + } + } +} diff --git a/crates/turtle/src/command/client/account.rs b/crates/turtle/src/command/client/account.rs new file mode 100644 index 00000000..898f1ac4 --- /dev/null +++ b/crates/turtle/src/command/client/account.rs @@ -0,0 +1,47 @@ +use clap::{Args, Subcommand}; +use eyre::Result; + +use crate::atuin_client::record::sqlite_store::SqliteStore; +use crate::atuin_client::settings::Settings; + +pub mod change_password; +pub mod delete; +pub mod login; +pub mod logout; +pub mod register; + +#[derive(Args, Debug)] +pub struct Cmd { + #[command(subcommand)] + command: Commands, +} + +#[derive(Subcommand, Debug)] +pub enum Commands { + /// Login to the configured server + Login(login::Cmd), + + /// Register a new account + Register(register::Cmd), + + /// Log out + Logout, + + /// Delete your account, and all synced data + Delete(delete::Cmd), + + /// Change your password + ChangePassword(change_password::Cmd), +} + +impl Cmd { + pub async fn run(self, settings: Settings, store: SqliteStore) -> Result<()> { + match self.command { + Commands::Login(l) => l.run(&settings, &store).await, + Commands::Register(r) => r.run(&settings).await, + Commands::Logout => logout::run().await, + Commands::Delete(d) => d.run(&settings).await, + Commands::ChangePassword(c) => c.run(&settings).await, + } + } +} diff --git a/crates/turtle/src/command/client/account/change_password.rs b/crates/turtle/src/command/client/account/change_password.rs new file mode 100644 index 00000000..6112b0df --- /dev/null +++ b/crates/turtle/src/command/client/account/change_password.rs @@ -0,0 +1,67 @@ +use clap::Parser; +use eyre::{Result, bail}; + +use crate::atuin_client::{ + auth::{self, MutateResponse}, + settings::Settings, +}; +use rpassword::prompt_password; + +#[derive(Parser, Debug)] +pub struct Cmd { + #[clap(long, short)] + pub current_password: Option, + + #[clap(long, short)] + pub new_password: Option, + + /// The two-factor authentication code for your account, if any + #[clap(long, short)] + pub totp_code: Option, +} + +impl Cmd { + pub async fn run(&self, settings: &Settings) -> Result<()> { + if !settings.logged_in().await? { + bail!("You are not logged in"); + } + + let client = auth::auth_client(settings).await; + + let current_password = self.current_password.clone().unwrap_or_else(|| { + prompt_password("Please enter the current password: ") + .expect("Failed to read from input") + }); + + if current_password.is_empty() { + bail!("please provide the current password"); + } + + let new_password = self.new_password.clone().unwrap_or_else(|| { + prompt_password("Please enter the new password: ").expect("Failed to read from input") + }); + + if new_password.is_empty() { + bail!("please provide a new password"); + } + + let mut totp_code = self.totp_code.clone(); + + loop { + let response = client + .change_password(¤t_password, &new_password, totp_code.as_deref()) + .await?; + + match response { + MutateResponse::Success => break, + MutateResponse::TwoFactorRequired => { + totp_code = Some(super::login::or_user_input(None, "two-factor code")); + } + } + } + + println!("Account password successfully changed!"); + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/account/delete.rs b/crates/turtle/src/command/client/account/delete.rs new file mode 100644 index 00000000..bcb40bc3 --- /dev/null +++ b/crates/turtle/src/command/client/account/delete.rs @@ -0,0 +1,57 @@ +use crate::atuin_client::{ + auth::{self, MutateResponse}, + settings::Settings, +}; +use clap::Parser; +use eyre::{Result, bail}; + +use super::login::{or_user_input, read_user_password}; + +#[derive(Parser, Debug)] +pub struct Cmd { + #[clap(long, short)] + pub password: Option, + + /// The two-factor authentication code for your account, if any + #[clap(long, short)] + pub totp_code: Option, +} + +impl Cmd { + pub async fn run(&self, settings: &Settings) -> Result<()> { + if !settings.logged_in().await? { + bail!("You are not logged in"); + } + + let client = auth::auth_client(settings).await; + + let password = self.password.clone().unwrap_or_else(read_user_password); + + if password.is_empty() { + bail!("please provide your password"); + } + + let mut totp_code = self.totp_code.clone(); + + loop { + let response = client + .delete_account(&password, totp_code.as_deref()) + .await?; + + match response { + MutateResponse::Success => break, + MutateResponse::TwoFactorRequired => { + totp_code = Some(or_user_input(None, "two-factor code")); + } + } + } + + // Clean up sessions from meta store + let meta = Settings::meta_store().await?; + meta.delete_session().await?; + + println!("Your account is deleted"); + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/account/login.rs b/crates/turtle/src/command/client/account/login.rs new file mode 100644 index 00000000..0c5b66f5 --- /dev/null +++ b/crates/turtle/src/command/client/account/login.rs @@ -0,0 +1,206 @@ +use std::{io, path::PathBuf}; + +use clap::Parser; +use eyre::{Context, Result, bail}; +use tokio::{fs::File, io::AsyncWriteExt}; + +use crate::atuin_client::{ + auth::{self, AuthResponse}, + encryption::{decode_key, load_key}, + record::sqlite_store::SqliteStore, + record::store::Store, + record::sync::{self, SyncError}, + settings::{Settings, SyncAuth}, +}; +use rpassword::prompt_password; + +#[derive(Parser, Debug)] +pub struct Cmd { + #[clap(long, short)] + pub username: Option, + + #[clap(long, short)] + pub password: Option, + + /// The encryption key for your account + #[clap(long, short)] + pub key: Option, + + /// The two-factor authentication code for your account, if any + #[clap(long, short)] + pub totp_code: Option, + + #[clap(long, hide = true)] + pub from_registration: bool, +} + +fn get_input() -> Result { + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + Ok(input.trim_end_matches(&['\r', '\n'][..]).to_string()) +} + +impl Cmd { + pub async fn run(&self, settings: &Settings, store: &SqliteStore) -> Result<()> { + match settings.resolve_sync_auth().await { + SyncAuth::Legacy { .. } => { + println!("You are logged in to your sync server."); + println!("Run 'atuin logout' to log out."); + return Ok(()); + } + SyncAuth::NotLoggedIn { .. } => {} + } + + self.run_legacy_login(settings, store).await?; + + verify_key_against_remote(settings).await + } + + /// Legacy login: always prompt for username/password interactively + /// (or accept them via flags). + async fn run_legacy_login(&self, settings: &Settings, store: &SqliteStore) -> Result<()> { + let username = or_user_input(self.username.clone(), "username"); + let password = self.password.clone().unwrap_or_else(read_user_password); + + self.prompt_and_store_key(settings, store).await?; + + let client = auth::auth_client(settings).await; + let response = client.login(&username, &password).await?; + + match response { + AuthResponse::Success { session, .. } => { + Settings::meta_store().await?.save_session(&session).await?; + } + AuthResponse::TwoFactorRequired => { + // Legacy server doesn't support 2FA, so this shouldn't happen. + bail!("unexpected two-factor requirement from legacy server"); + } + } + + println!("Logged in!"); + Ok(()) + } + + async fn prompt_and_store_key(&self, settings: &Settings, store: &SqliteStore) -> Result<()> { + let key_path = settings.key_path.as_str(); + let key_path = PathBuf::from(key_path); + + println!("IMPORTANT"); + println!( + "If you are already logged in on another machine, you must ensure that the key you use here is the same as the key you used there." + ); + println!("You can find your key by running 'atuin key' on the other machine."); + println!("Do not share this key with anyone."); + println!("\nRead more here: https://docs.atuin.sh/guide/sync/#login \n"); + + let key = or_user_input( + self.key.clone(), + "encryption key [blank to use existing key file]", + ); + + if key.is_empty() { + if key_path.exists() { + let bytes = fs_err::read_to_string(&key_path).context(format!( + "Existing key file at '{}' could not be read", + key_path.to_string_lossy() + ))?; + if decode_key(bytes).is_err() { + bail!(format!( + "The key in existing key file at '{}' is invalid", + key_path.to_string_lossy() + )); + } + } else { + panic!( + "No key provided and no existing key file found. Please use 'atuin key' on your other machine, or recover your key from a backup" + ) + } + } else if !key_path.exists() { + if decode_key(key.clone()).is_err() { + bail!("The specified key is invalid"); + } + + let mut file = File::create(&key_path).await?; + file.write_all(key.as_bytes()).await?; + } else { + // we now know that the user has logged in specifying a key, AND that the key path + // exists + + // 1. check if the saved key and the provided key match. if so, nothing to do. + // 2. if not, re-encrypt the local history and overwrite the key + let current_key: [u8; 32] = load_key(settings)?.into(); + + let encoded = key.clone(); // gonna want to save it in a bit + let new_key: [u8; 32] = decode_key(key) + .context("Could not decode provided key; is not valid base64-encoded key")? + .into(); + + if new_key != current_key { + println!("\nRe-encrypting local store with new key"); + + store.re_encrypt(¤t_key, &new_key).await?; + + println!("Writing new key"); + let mut file = File::create(&key_path).await?; + file.write_all(encoded.as_bytes()).await?; + } + } + + Ok(()) + } +} + +async fn verify_key_against_remote(settings: &Settings) -> Result<()> { + let key: [u8; 32] = load_key(settings) + .context("could not load encryption key for verification")? + .into(); + + let client = sync::build_client(settings).await?; + let remote_index = match client.record_status().await { + Ok(idx) => idx, + Err(e) => { + tracing::warn!("could not fetch remote status to verify key: {e}"); + return Ok(()); + } + }; + + match sync::check_encryption_key(&client, &remote_index, &key).await { + Ok(()) => Ok(()), + Err(SyncError::WrongKey) => { + // Roll back the saved session so the user is not left in a + // half-authenticated state with a key that can't read the data. + if let Ok(meta) = Settings::meta_store().await { + let _ = meta.delete_session().await; + } + crate::print_error::print_error( + "Wrong encryption key", + "The encryption key on this machine does not match the data on the server. \ + You have been logged out.\n\n\ + To fix this, find your existing key by running `atuin key` on a machine that \ + already syncs successfully, then run `atuin login` again here with that key.", + ); + std::process::exit(1); + } + Err(e) => { + // Non-key error (e.g. transient network issue). Don't fail the + // login — the user is authenticated and can sync later when the + // network recovers. + tracing::warn!("could not verify encryption key against remote: {e}"); + Ok(()) + } + } +} + +pub(super) fn or_user_input(value: Option, name: &'static str) -> String { + value.unwrap_or_else(|| read_user_input(name)) +} + +pub(super) fn read_user_password() -> String { + let password = prompt_password("Please enter password: "); + password.expect("Failed to read from input") +} + +fn read_user_input(name: &'static str) -> String { + eprint!("Please enter {name}: "); + get_input().expect("Failed to read from input") +} diff --git a/crates/turtle/src/command/client/account/logout.rs b/crates/turtle/src/command/client/account/logout.rs new file mode 100644 index 00000000..6150a52b --- /dev/null +++ b/crates/turtle/src/command/client/account/logout.rs @@ -0,0 +1,5 @@ +use eyre::Result; + +pub async fn run() -> Result<()> { + crate::atuin_client::logout::logout().await +} diff --git a/crates/turtle/src/command/client/account/register.rs b/crates/turtle/src/command/client/account/register.rs new file mode 100644 index 00000000..548c2739 --- /dev/null +++ b/crates/turtle/src/command/client/account/register.rs @@ -0,0 +1,67 @@ +use clap::Parser; +use eyre::{Result, bail}; + +use super::login::or_user_input; +use crate::atuin_client::settings::{Settings, SyncAuth}; + +#[derive(Parser, Debug)] +pub struct Cmd { + #[clap(long, short)] + pub username: Option, + + #[clap(long, short)] + pub password: Option, + + #[clap(long, short)] + pub email: Option, +} + +impl Cmd { + pub async fn run(&self, settings: &Settings) -> Result<()> { + match settings.resolve_sync_auth().await { + SyncAuth::Legacy { .. } => { + println!("You are already logged in."); + println!("Run 'atuin logout' to log out."); + return Ok(()); + } + + SyncAuth::NotLoggedIn { .. } => {} + } + + // Legacy registration flow + println!("Registering for an Atuin Sync account"); + + let username = or_user_input(self.username.clone(), "username"); + let email = or_user_input(self.email.clone(), "email"); + let password = self + .password + .clone() + .unwrap_or_else(super::login::read_user_password); + + if password.is_empty() { + bail!("please provide a password"); + } + + let session = crate::atuin_client::api_client::register( + settings.sync_address.as_str(), + &username, + &email, + &password, + ) + .await?; + + let meta = Settings::meta_store().await?; + meta.save_session(&session.session).await?; + + let _key = crate::atuin_client::encryption::load_key(settings)?; + + println!( + "Registration successful! Please make a note of your key (run 'atuin key') and keep it safe." + ); + println!( + "You will need it to log in on other devices, and we cannot help recover it if you lose it." + ); + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/config.rs b/crates/turtle/src/command/client/config.rs new file mode 100644 index 00000000..1597a8d6 --- /dev/null +++ b/crates/turtle/src/command/client/config.rs @@ -0,0 +1,352 @@ +use crate::atuin_client::settings::Settings; +use clap::{Args, Subcommand, ValueEnum}; +use eyre::Result; +use toml_edit::{Document, DocumentMut, Item, Table, TableLike, Value}; + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Get a configuration value from your config.toml file + /// or after defaults and overrides are applied + #[command()] + Get(GetCmd), + + /// Set a configuration value in your config.toml file + #[command()] + Set(SetCmd), + + /// Print all configuration values from your config.toml file + /// in TOML format + /// + /// If a key is provided, only print the value of that key and all its children + #[command()] + Print(PrintCmd), +} + +impl Cmd { + pub async fn run(self, settings: &Settings) -> Result<()> { + match self { + Self::Get(get) => get.run(settings).await, + Self::Set(set) => set.run(settings).await, + Self::Print(print) => print.run(settings).await, + } + } +} + +/// Get a configuration value from your config.toml file, +/// or optionally the effective value after defaults and overrides are applied. +#[derive(Args, Debug)] +pub struct GetCmd { + /// The configuration key to get + pub key: String, + + /// Print the value after defaults and overrides are applied + #[arg(long, short)] + pub resolved: bool, + + /// Print both the config file value and the resolved value + #[arg(long, short)] + pub verbose: bool, +} + +impl GetCmd { + pub async fn run(&self, _settings: &Settings) -> Result<()> { + let key = self.key.trim(); + if key.is_empty() || key.contains(char::is_whitespace) { + eyre::bail!("Config key must be non-empty and must not contain whitespace"); + } + + if self.verbose { + println!("Config file:"); + self.print_current_value(key, " ").await?; + println!("\nResolved:"); + Self::print_effective_value(key, " "); + return Ok(()); + } + + if self.resolved { + Self::print_effective_value(key, ""); + } else { + self.print_current_value(key, "").await?; + } + + Ok(()) + } + + async fn print_current_value(&self, key: &str, prefix: &str) -> Result<()> { + let config_file = Settings::get_config_path()?; + let config_str = tokio::fs::read_to_string(&config_file).await?; + let doc = config_str.parse::>()?; + + let current = get_deep_key(&doc, key); + + match current { + Some(item) if item.is_table() || item.is_inline_table() => { + let table = item + .as_table_like() + .expect("is_table()/is_inline_table() but no table"); + println!("{prefix}[{key}]"); + dump_table(table, prefix, &mut vec![key.to_string()])?; + } + Some(item) => { + let val = item.to_string(); + let val = val.trim().trim_matches('"'); + println!("{prefix}{val}"); + } + None => { + println!("{prefix}(not set in config file)"); + } + } + + Ok(()) + } + + fn print_effective_value(key: &str, prefix: &str) { + match Settings::get_config_value(key) { + Ok(value) => { + for line in value.lines() { + println!("{prefix}{line}"); + } + } + Err(_) => { + println!("{prefix}(unknown key)"); + } + } + } +} + +#[derive(Args, Debug)] +pub struct SetCmd { + /// The configuration key to set + pub key: String, + + /// The value to set + pub value: String, + + /// Store value as an explicit type + #[arg(long = "type", short, value_enum, default_value_t = ValueType::Auto, value_name = "TYPE")] + pub the_type: ValueType, +} + +#[derive(ValueEnum, Debug, Clone, PartialEq, Eq)] +pub enum ValueType { + /// Automatically determine the type of the value + Auto, + /// Store value as a string + String, + /// Store value as a boolean + Boolean, + /// Store value as an integer + Integer, + /// Store the value as a float + Float, +} + +impl SetCmd { + pub async fn run(self, _settings: &Settings) -> Result<()> { + let key = self.key.trim(); + if key.is_empty() || key.contains(char::is_whitespace) { + eyre::bail!("Config key must be non-empty and must not contain whitespace"); + } + + let config_file = Settings::get_config_path()?; + let config_str = tokio::fs::read_to_string(&config_file).await?; + let mut doc: DocumentMut = config_str.parse()?; + + // When using auto type detection, try to match the existing value's type + // so we don't accidentally change e.g. "300" (string) to 300 (integer) + let existing_type = detect_existing_type(&doc, key); + let value = self.parse_value(existing_type.as_ref())?; + set_deep_key(&mut doc, key, value)?; + + tokio::fs::write(&config_file, doc.to_string()).await?; + + Ok(()) + } + + fn parse_value(&self, existing_type: Option<&ValueType>) -> Result { + let raw = &self.value; + + // Explicit --type takes priority, then existing value type, then auto-detect + let effective_type = if self.the_type != ValueType::Auto { + &self.the_type + } else if let Some(existing) = existing_type { + existing + } else { + &ValueType::Auto + }; + + match effective_type { + ValueType::String => Ok(Value::from(raw.as_str())), + ValueType::Boolean => { + let b: bool = raw + .parse() + .map_err(|_| eyre::eyre!("invalid boolean value: {raw}"))?; + Ok(Value::from(b)) + } + ValueType::Integer => { + let i: i64 = raw + .parse() + .map_err(|_| eyre::eyre!("invalid integer value: {raw}"))?; + Ok(Value::from(i)) + } + ValueType::Float => { + let f: f64 = raw + .parse() + .map_err(|_| eyre::eyre!("invalid float value: {raw}"))?; + Ok(Value::from(f)) + } + ValueType::Auto => { + if raw == "true" || raw == "false" { + return Ok(Value::from(raw == "true")); + } + if let Ok(i) = raw.parse::() { + return Ok(Value::from(i)); + } + if let Ok(f) = raw.parse::() { + return Ok(Value::from(f)); + } + Ok(Value::from(raw.as_str())) + } + } + } +} + +#[derive(Args, Debug)] +pub struct PrintCmd { + /// Print the value of a specific key and all its children + pub key: Option, +} + +impl PrintCmd { + pub async fn run(&self, _settings: &Settings) -> Result<()> { + let config_file = Settings::get_config_path()?; + let config_str = tokio::fs::read_to_string(&config_file).await?; + let doc = config_str.parse::>()?; + + if let Some(key) = &self.key { + let current = get_deep_key(&doc, key); + + if let Some(current) = current { + if current.is_table() || current.is_inline_table() { + println!("[{key}]"); + dump_table( + current + .as_table_like() + .expect("is_table()/is_inline_table() but no table"), + "", + &mut vec![key.clone()], + )?; + } else { + println!("{}", current.to_string().trim().trim_matches('"')); + } + } else { + println!("key not found"); + } + } else { + dump_table(doc.as_table(), "", &mut Vec::new())?; + } + + Ok(()) + } +} + +fn dump_table(table: &dyn TableLike, prefix: &str, stack: &mut Vec) -> Result<()> { + for (key, value) in table.iter() { + if value.is_table() || value.is_inline_table() { + stack.push(key.to_string()); + + let table = value + .as_table_like() + .expect("is_table()/is_inline_table() but no table"); + + println!("\n{}[{}]", prefix, stack.join(".")); + + dump_table(table, prefix, stack)?; + + stack.pop(); + } else { + println!("{prefix}{key} = {value}"); + } + } + + Ok(()) +} + +fn get_deep_key<'doc>(doc: &'doc Document, key: &str) -> Option<&'doc Item> { + let parts = key.split('.'); + let mut current: Option<&Item> = Some(doc.as_item()); + + for part in parts { + current = current + .and_then(|item| item.as_table_like()) + .and_then(|table| table.get(part)); + } + + current +} + +/// Detect the TOML type of an existing key in the document, so `set` with auto +/// type detection preserves the original type rather than guessing from the value string. +fn detect_existing_type(doc: &DocumentMut, key: &str) -> Option { + let parts: Vec<&str> = key.split('.').collect(); + let mut current: &dyn TableLike = doc.as_table(); + + for &part in &parts[..parts.len().saturating_sub(1)] { + current = current.get(part)?.as_table_like()?; + } + + let last = parts.last()?; + let v = current.get(last)?.as_value()?; + + if v.is_str() { + Some(ValueType::String) + } else if v.is_bool() { + Some(ValueType::Boolean) + } else if v.is_integer() { + Some(ValueType::Integer) + } else if v.is_float() { + Some(ValueType::Float) + } else { + None + } +} + +fn set_deep_key(doc: &mut DocumentMut, key: &str, value: Value) -> Result<()> { + let parts: Vec<&str> = key.split('.').collect(); + + if parts.is_empty() { + eyre::bail!("empty config key"); + } + + let mut current: &mut dyn TableLike = doc.as_table_mut(); + + // Navigate/create intermediate tables + for &part in &parts[..parts.len() - 1] { + if !current.contains_key(part) { + current.insert(part, Item::Table(Table::new())); + } + current = current + .get_mut(part) + .expect("just inserted or already exists") + .as_table_like_mut() + .ok_or_else(|| eyre::eyre!("'{}' exists but is not a table", part))?; + } + + let last = *parts.last().unwrap(); + + // Don't silently overwrite a table with a scalar value + if let Some(existing) = current.get(last) + && (existing.is_table() || existing.is_inline_table()) + { + eyre::bail!( + "'{}' is a table; use a dotted key like '{}.key' to set a value within it", + key, + key + ); + } + + current.insert(last, Item::Value(value)); + + Ok(()) +} diff --git a/crates/turtle/src/command/client/daemon.rs b/crates/turtle/src/command/client/daemon.rs new file mode 100644 index 00000000..2ee9b759 --- /dev/null +++ b/crates/turtle/src/command/client/daemon.rs @@ -0,0 +1,769 @@ +use std::fs::{self, File, OpenOptions}; +use std::io::{ErrorKind, Write}; +#[cfg(unix)] +use std::os::unix::net::UnixStream as StdUnixStream; +use std::path::{Path, PathBuf}; +use std::process::{Command, Stdio}; +use std::time::{Duration, Instant}; + +use crate::atuin_client::{ + database::Sqlite, history::History, record::sqlite_store::SqliteStore, settings::Settings, +}; +use crate::atuin_daemon::DaemonEvent; +use crate::atuin_daemon::client::{ + ControlClient, DaemonClientErrorKind, HistoryClient, classify_error, +}; +use clap::Subcommand; +#[cfg(unix)] +use daemonize::Daemonize; +use eyre::{Result, WrapErr, bail, eyre}; +use fs4::fs_std::FileExt; +use tokio::time::sleep; + +#[derive(clap::Args, Debug)] +pub struct Cmd { + /// Internal flag for daemonization + #[arg(long, hide = true)] + daemonize: bool, + + /// Also write daemon logs to the console (useful for debugging) + #[arg(long)] + show_logs: bool, + + #[command(subcommand)] + subcmd: Option, +} + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum SubCmd { + /// Start the daemon server + Start { + #[arg(long, hide = true)] + daemonize: bool, + + /// Also write daemon logs to the console (useful for debugging) + #[arg(long)] + show_logs: bool, + + /// Force start: kill existing daemon process and reset the socket + #[arg(long)] + force: bool, + }, + + /// Show the daemon's current status + Status, + + /// Stop the daemon gracefully + Stop, + + /// Restart the daemon (stop, then start in background) + Restart, +} + +impl Cmd { + /// Returns `true` when the process should daemonize before creating the + /// async runtime or opening any database connections. + #[cfg(unix)] + pub fn should_daemonize(&self) -> bool { + match &self.subcmd { + Some(SubCmd::Start { daemonize, .. }) => *daemonize, + None => self.daemonize, + _ => false, + } + } + + /// Returns `true` when logs should also be written to the console. + pub fn show_logs(&self) -> bool { + match &self.subcmd { + Some(SubCmd::Start { show_logs, .. }) => *show_logs, + None => self.show_logs, + _ => false, + } + } + + pub async fn run( + self, + settings: Settings, + store: SqliteStore, + history_db: Sqlite, + ) -> Result<()> { + match self.subcmd { + None => { + eprintln!("Warning: `atuin daemon` is deprecated, use `atuin daemon start`"); + run(settings, store, history_db, false).await + } + Some(SubCmd::Start { force, .. }) => run(settings, store, history_db, force).await, + Some(SubCmd::Status) => status_cmd(&settings).await, + Some(SubCmd::Stop) => stop_cmd(&settings).await, + Some(SubCmd::Restart) => restart_cmd(&settings).await, + } + } +} + +const DAEMON_VERSION: &str = env!("CARGO_PKG_VERSION"); +const DAEMON_PROTOCOL_VERSION: u32 = 1; +const STARTUP_POLL: Duration = Duration::from_millis(40); +const LOCK_POLL: Duration = Duration::from_millis(20); +const LEGACY_DAEMON_RESTART_MESSAGE: &str = "legacy daemon detected; restart daemon manually"; + +struct PidfileGuard { + file: File, +} + +impl PidfileGuard { + fn acquire(path: &Path) -> Result { + let mut file = open_lock_file(path)?; + + if !file.try_lock_exclusive()? { + bail!( + "daemon already running (pidfile lock busy at {})", + path.display() + ); + } + + file.set_len(0) + .wrap_err_with(|| format!("could not truncate daemon pidfile {}", path.display()))?; + writeln!(file, "{}", std::process::id()) + .and_then(|()| writeln!(file, "{DAEMON_VERSION}")) + .wrap_err_with(|| format!("could not write daemon pidfile {}", path.display()))?; + + Ok(Self { file }) + } +} + +impl Drop for PidfileGuard { + fn drop(&mut self) { + let _ = self.file.unlock(); + } +} + +enum Probe { + Ready(HistoryClient), + NeedsRestart(String), + Unreachable(eyre::Report), +} + +fn daemon_matches_expected(version: &str, protocol: u32) -> bool { + version == DAEMON_VERSION && protocol == DAEMON_PROTOCOL_VERSION +} + +fn daemon_mismatch_message(version: &str, protocol: u32) -> String { + if protocol == DAEMON_PROTOCOL_VERSION { + format!("daemon is out of date: expected {DAEMON_VERSION}, got {version}") + } else { + format!("daemon protocol mismatch: expected {DAEMON_PROTOCOL_VERSION}, got {protocol}") + } +} + +fn is_legacy_daemon_error(err: &eyre::Report) -> bool { + matches!(classify_error(err), DaemonClientErrorKind::Unimplemented) +} + +fn should_retry_after_error(err: &eyre::Report) -> bool { + matches!( + classify_error(err), + DaemonClientErrorKind::Connect + | DaemonClientErrorKind::Unavailable + | DaemonClientErrorKind::Unimplemented + ) +} + +fn daemon_startup_lock_path(pidfile_path: &Path) -> PathBuf { + let mut os = pidfile_path.as_os_str().to_os_string(); + os.push(".startup.lock"); + PathBuf::from(os) +} + +fn open_lock_file(path: &Path) -> Result { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent) + .wrap_err_with(|| format!("could not create lock directory {}", parent.display()))?; + } + + OpenOptions::new() + .read(true) + .write(true) + .create(true) + .truncate(false) + .open(path) + .wrap_err_with(|| format!("could not open lock file {}", path.display())) +} + +async fn wait_for_lock(path: &Path, timeout: Duration) -> Result { + let file = open_lock_file(path)?; + let start = Instant::now(); + + loop { + match file.try_lock_exclusive() { + Ok(true) => return Ok(file), + Ok(false) => { + if start.elapsed() >= timeout { + bail!("timed out waiting for lock at {}", path.display()); + } + + sleep(LOCK_POLL).await; + } + Err(err) => { + return Err(eyre!("could not lock {}: {err}", path.display())); + } + } + } +} + +async fn wait_for_pidfile_available(path: &Path, timeout: Duration) -> Result<()> { + let file = wait_for_lock(path, timeout).await?; + file.unlock() + .wrap_err_with(|| format!("failed to unlock {}", path.display()))?; + Ok(()) +} + +async fn connect_client(settings: &Settings) -> Result { + HistoryClient::new( + #[cfg(unix)] + settings.daemon.socket_path.clone(), + ) + .await +} + +async fn probe(settings: &Settings) -> Probe { + let mut client = match connect_client(settings).await { + Ok(client) => client, + Err(err) => return Probe::Unreachable(err), + }; + + match client.status().await { + Ok(status) => { + if daemon_matches_expected(&status.version, status.protocol) { + Probe::Ready(client) + } else { + Probe::NeedsRestart(daemon_mismatch_message(&status.version, status.protocol)) + } + } + Err(err) => Probe::Unreachable(err), + } +} + +async fn request_shutdown(settings: &Settings) { + if let Ok(mut client) = connect_client(settings).await { + let _ = client.shutdown().await; + } +} + +fn spawn_daemon_process() -> Result<()> { + let exe = std::env::current_exe().wrap_err("could not locate atuin executable")?; + + let mut cmd = Command::new(exe); + cmd.arg("daemon") + .arg("start") + .stdin(Stdio::null()) + .stdout(Stdio::null()) + .stderr(Stdio::null()); + + #[cfg(unix)] + cmd.arg("--daemonize"); + + cmd.spawn().wrap_err("failed to spawn daemon process")?; + + Ok(()) +} + +fn startup_timeout(settings: &Settings) -> Duration { + Duration::from_secs_f64(settings.local_timeout.max(0.5) + 2.0) +} + +#[cfg(unix)] +fn remove_stale_socket_if_present(settings: &Settings) -> Result<()> { + if settings.daemon.systemd_socket { + return Ok(()); + } + + let socket_path = Path::new(&settings.daemon.socket_path); + if !socket_path.exists() { + return Ok(()); + } + + match StdUnixStream::connect(socket_path) { + Ok(stream) => { + drop(stream); + Ok(()) + } + Err(err) if err.kind() == ErrorKind::ConnectionRefused => { + fs::remove_file(socket_path).wrap_err_with(|| { + format!( + "failed to remove stale daemon socket {}", + socket_path.display() + ) + })?; + Ok(()) + } + Err(err) if err.kind() == ErrorKind::NotFound => Ok(()), + Err(_) => Ok(()), + } +} + +async fn wait_until_ready(settings: &Settings, timeout: Duration) -> Result { + let start = Instant::now(); + let mut last_error = eyre!("daemon did not become ready"); + + loop { + match probe(settings).await { + Probe::Ready(client) => return Ok(client), + Probe::NeedsRestart(reason) => { + last_error = eyre!(reason); + } + Probe::Unreachable(err) => { + if is_legacy_daemon_error(&err) { + return Err(err.wrap_err(LEGACY_DAEMON_RESTART_MESSAGE)); + } + last_error = err; + } + } + + if start.elapsed() >= timeout { + return Err(last_error.wrap_err(format!( + "timed out waiting for daemon startup after {}ms", + timeout.as_millis() + ))); + } + + sleep(STARTUP_POLL).await; + } +} + +#[expect(clippy::unnecessary_wraps)] +fn ensure_autostart_supported(settings: &Settings) -> Result<()> { + #[cfg(unix)] + if settings.daemon.systemd_socket { + bail!( + "daemon autostart is incompatible with `daemon.systemd_socket = true`; use systemd to manage the daemon" + ); + } + + Ok(()) +} + +/// Ensure the daemon is running, starting it if necessary. +/// +/// If the daemon is already running and up-to-date, this is a no-op. +/// If it is not running or needs a restart, this will spawn a new daemon +/// process and wait for it to become ready. +/// +/// Returns an error if the daemon could not be started. +pub async fn ensure_daemon_running(settings: &Settings) -> Result<()> { + ensure_autostart_supported(settings)?; + + let timeout = startup_timeout(settings); + let pidfile_path = PathBuf::from(&settings.daemon.pidfile_path); + let startup_lock_path = daemon_startup_lock_path(&pidfile_path); + let startup_lock = wait_for_lock(&startup_lock_path, timeout).await?; + + match probe(settings).await { + Probe::Ready(_) => { + drop(startup_lock); + return Ok(()); + } + Probe::NeedsRestart(_) => { + request_shutdown(settings).await; + } + Probe::Unreachable(err) => { + if is_legacy_daemon_error(&err) { + return Err(err.wrap_err(LEGACY_DAEMON_RESTART_MESSAGE)); + } + } + } + + // This prevents rapid-fire hook invocations from racing daemon restart. + wait_for_pidfile_available(&pidfile_path, timeout).await?; + + #[cfg(unix)] + remove_stale_socket_if_present(settings)?; + + spawn_daemon_process()?; + let _ = wait_until_ready(settings, timeout).await?; + + drop(startup_lock); + Ok(()) +} + +async fn restart_daemon(settings: &Settings) -> Result { + ensure_daemon_running(settings).await?; + connect_client(settings).await +} + +fn ensure_reply_compatible(settings: &Settings, version: &str, protocol: u32) -> Result<()> { + if daemon_matches_expected(version, protocol) { + return Ok(()); + } + + let message = daemon_mismatch_message(version, protocol); + if settings.daemon.autostart { + bail!("{message}"); + } + + bail!("{message}. Enable `daemon.autostart = true` or restart the daemon manually"); +} + +pub async fn start_history(settings: &Settings, history: History) -> Result { + match async { + connect_client(settings) + .await? + .start_history(history.clone()) + .await + } + .await + { + Ok(resp) => { + if daemon_matches_expected(&resp.version, resp.protocol) { + return Ok(resp.id); + } + + if !settings.daemon.autostart { + return Err(eyre!( + "{}. Enable `daemon.autostart = true` or restart the daemon manually", + daemon_mismatch_message(&resp.version, resp.protocol) + )); + } + } + Err(err) if !settings.daemon.autostart => return Err(err), + Err(err) if !should_retry_after_error(&err) => return Err(err), + Err(_) => {} + } + + let resp = restart_daemon(settings) + .await? + .start_history(history) + .await?; + ensure_reply_compatible(settings, &resp.version, resp.protocol)?; + Ok(resp.id) +} + +pub async fn end_history(settings: &Settings, id: String, duration: u64, exit: i64) -> Result<()> { + match async { + connect_client(settings) + .await? + .end_history(id.clone(), duration, exit) + .await + } + .await + { + Ok(resp) => { + if daemon_matches_expected(&resp.version, resp.protocol) { + return Ok(()); + } + + if !settings.daemon.autostart { + return Err(eyre!( + "{}. Enable `daemon.autostart = true` or restart the daemon manually", + daemon_mismatch_message(&resp.version, resp.protocol) + )); + } + + // End succeeded on the running daemon, so avoid replaying it. + // We only restart to make subsequent hook calls target the expected version. + let _ = restart_daemon(settings).await; + return Ok(()); + } + Err(err) if !settings.daemon.autostart => return Err(err), + Err(err) if !should_retry_after_error(&err) => return Err(err), + Err(_) => {} + } + + let resp = restart_daemon(settings) + .await? + .end_history(id, duration, exit) + .await?; + ensure_reply_compatible(settings, &resp.version, resp.protocol)?; + Ok(()) +} + +/// Emit a daemon event, auto-starting the daemon if it is not running. +/// +/// If the daemon is not reachable and `daemon.autostart` is enabled, this +/// will start the daemon and retry the event. If the daemon cannot be +/// started or the retry fails, a warning is printed to stderr. +pub async fn emit_event(settings: &Settings, event: DaemonEvent) { + // Try to connect and send + match ControlClient::from_settings(settings).await { + Ok(mut client) => { + if let Err(e) = client.send_event(event).await { + tracing::debug!(?e, "failed to send event to daemon"); + } + return; + } + Err(e) if !settings.daemon.autostart || !should_retry_after_error(&e) => { + tracing::debug!(?e, "daemon not available, skipping event emission"); + return; + } + Err(_) => {} + } + + // Auto-start the daemon and retry + if let Err(e) = ensure_daemon_running(settings).await { + eprintln!("Could not start daemon: {e}"); + return; + } + + match ControlClient::from_settings(settings).await { + Ok(mut client) => { + if let Err(e) = client.send_event(event).await { + eprintln!("Daemon started but failed to send event: {e}"); + } + } + Err(e) => { + eprintln!("Daemon started but failed to connect: {e}"); + } + } +} + +pub async fn tail_client(settings: &Settings) -> Result { + match probe(settings).await { + Probe::Ready(client) => return Ok(client), + Probe::NeedsRestart(reason) if !settings.daemon.autostart => { + bail!("{reason}. Enable `daemon.autostart = true` or restart the daemon manually"); + } + Probe::Unreachable(err) if is_legacy_daemon_error(&err) => { + return Err(err.wrap_err(LEGACY_DAEMON_RESTART_MESSAGE)); + } + Probe::Unreachable(err) if !settings.daemon.autostart => return Err(err), + Probe::Unreachable(err) if !should_retry_after_error(&err) => return Err(err), + Probe::NeedsRestart(_) | Probe::Unreachable(_) => {} + } + + restart_daemon(settings).await +} + +async fn status_cmd(settings: &Settings) -> Result<()> { + match probe(settings).await { + Probe::Ready(mut client) => { + let status = client.status().await?; + println!("Daemon running"); + println!(" PID: {}", status.pid); + println!(" Version: {}", status.version); + println!(" Protocol: {}", status.protocol); + println!(" Healthy: {}", status.healthy); + #[cfg(unix)] + println!(" Socket: {}", settings.daemon.socket_path); + } + Probe::NeedsRestart(reason) => { + println!("Daemon running (needs restart)"); + println!(" Reason: {reason}"); + } + Probe::Unreachable(_) => { + println!("Daemon is not running"); + } + } + + Ok(()) +} + +async fn stop_cmd(settings: &Settings) -> Result<()> { + let Ok(mut client) = connect_client(settings).await else { + println!("Daemon is not running"); + return Ok(()); + }; + + match client.shutdown().await { + Ok(true) => { + println!("Shutdown requested"); + + let pidfile_path = PathBuf::from(&settings.daemon.pidfile_path); + let timeout = Duration::from_secs(5); + match wait_for_pidfile_available(&pidfile_path, timeout).await { + Ok(()) => println!("Daemon stopped"), + Err(_) => println!("Daemon may still be shutting down"), + } + + Ok(()) + } + Ok(false) => bail!("Daemon rejected shutdown request"), + Err(err) => Err(err.wrap_err("Failed to send shutdown request")), + } +} + +async fn restart_cmd(settings: &Settings) -> Result<()> { + // Stop if running + match probe(settings).await { + Probe::Ready(_) | Probe::NeedsRestart(_) => { + request_shutdown(settings).await; + println!("Stopping daemon..."); + + let pidfile_path = PathBuf::from(&settings.daemon.pidfile_path); + let timeout = Duration::from_secs(5); + wait_for_pidfile_available(&pidfile_path, timeout) + .await + .wrap_err("Timed out waiting for old daemon to stop")?; + } + Probe::Unreachable(_) => { + println!("No daemon running"); + } + } + + #[cfg(unix)] + remove_stale_socket_if_present(settings)?; + + spawn_daemon_process()?; + println!("Starting daemon..."); + + let timeout = startup_timeout(settings); + let status = wait_until_ready(settings, timeout).await?.status().await?; + + println!("Daemon restarted"); + println!(" PID: {}", status.pid); + println!(" Version: {}", status.version); + + Ok(()) +} + +/// Daemonize the current process. Must be called before creating the tokio +/// runtime or opening database connections, since `fork()` inside an async +/// runtime corrupts its internal state. +#[cfg(unix)] +pub fn daemonize_current_process() -> Result<()> { + let cwd = + std::env::current_dir().wrap_err("could not determine current directory for daemon")?; + + Daemonize::new() + .working_directory(cwd) + .start() + .wrap_err("failed to daemonize process")?; + + Ok(()) +} + +async fn run( + settings: Settings, + store: SqliteStore, + history_db: Sqlite, + force: bool, +) -> Result<()> { + if force { + force_cleanup(&settings); + } + + let pidfile_path = PathBuf::from(&settings.daemon.pidfile_path); + let _pidfile_guard = PidfileGuard::acquire(&pidfile_path)?; + + crate::atuin_daemon::boot(settings, store, history_db).await?; + + Ok(()) +} + +/// Force cleanup: kill existing daemon process and remove socket. +fn force_cleanup(settings: &Settings) { + let pidfile_path = Path::new(&settings.daemon.pidfile_path); + + // Read and kill the existing process if pidfile exists + if pidfile_path.exists() { + if let Ok(contents) = fs::read_to_string(pidfile_path) + && let Some(pid_str) = contents.lines().next() + && let Ok(pid) = pid_str.parse::() + { + kill_process(pid); + // Give it a moment to release resources + std::thread::sleep(Duration::from_millis(100)); + } + + // Remove the pidfile + if let Err(e) = fs::remove_file(pidfile_path) + && e.kind() != ErrorKind::NotFound + { + tracing::warn!("failed to remove pidfile: {e}"); + } + } + + // Remove the socket file + #[cfg(unix)] + { + let socket_path = Path::new(&settings.daemon.socket_path); + if socket_path.exists() + && let Err(e) = fs::remove_file(socket_path) + && e.kind() != ErrorKind::NotFound + { + tracing::warn!("failed to remove socket: {e}"); + } + } +} + +/// Kill a process by PID. +#[cfg(unix)] +fn kill_process(pid: u32) { + // Use kill command to send SIGTERM for graceful shutdown + let _ = Command::new("kill") + .args(["-TERM", &pid.to_string()]) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .status(); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_version_matches() { + assert!(daemon_matches_expected( + DAEMON_VERSION, + DAEMON_PROTOCOL_VERSION + )); + } + + #[test] + fn test_version_mismatch() { + assert!(!daemon_matches_expected("0.0.0", DAEMON_PROTOCOL_VERSION)); + assert!(!daemon_matches_expected(DAEMON_VERSION, 999)); + assert!(!daemon_matches_expected("0.0.0", 999)); + } + + #[test] + fn test_mismatch_message_version() { + let msg = daemon_mismatch_message("0.0.0", DAEMON_PROTOCOL_VERSION); + assert!(msg.contains("out of date"), "got: {msg}"); + assert!(msg.contains("0.0.0")); + assert!(msg.contains(DAEMON_VERSION)); + } + + #[test] + fn test_mismatch_message_protocol() { + let msg = daemon_mismatch_message(DAEMON_VERSION, 999); + assert!(msg.contains("protocol mismatch"), "got: {msg}"); + } + + #[test] + fn test_startup_lock_path() { + let pidfile = Path::new("/tmp/atuin-daemon.pid"); + let lock = daemon_startup_lock_path(pidfile); + assert_eq!(lock, PathBuf::from("/tmp/atuin-daemon.pid.startup.lock")); + } + + #[test] + fn test_pidfile_guard_acquire_and_drop() { + let tmp = tempfile::tempdir().unwrap(); + let pidfile = tmp.path().join("daemon.pid"); + + { + let _guard = PidfileGuard::acquire(&pidfile).unwrap(); + // Guard holds an exclusive lock — on Windows other handles cannot + // read the file, so we verify contents after the guard is dropped. + } + + let contents = std::fs::read_to_string(&pidfile).unwrap(); + let lines: Vec<&str> = contents.lines().collect(); + assert_eq!(lines.len(), 2); + assert_eq!(lines[0], std::process::id().to_string()); + assert_eq!(lines[1], DAEMON_VERSION); + + // After guard is dropped, lock should be released — acquiring again must succeed. + let _guard2 = PidfileGuard::acquire(&pidfile).unwrap(); + } + + #[test] + fn test_pidfile_guard_prevents_double_acquire() { + let tmp = tempfile::tempdir().unwrap(); + let pidfile = tmp.path().join("daemon.pid"); + + let _guard = PidfileGuard::acquire(&pidfile).unwrap(); + let result = PidfileGuard::acquire(&pidfile); + assert!(result.is_err()); + } +} diff --git a/crates/turtle/src/command/client/default_config.rs b/crates/turtle/src/command/client/default_config.rs new file mode 100644 index 00000000..e8cc15f9 --- /dev/null +++ b/crates/turtle/src/command/client/default_config.rs @@ -0,0 +1,4 @@ +pub fn run() { + // TODO(@bpeetz): Re-add the default settings option back (Settings::example_config()) <2026-06-11> + println!("TODO"); +} diff --git a/crates/turtle/src/command/client/doctor.rs b/crates/turtle/src/command/client/doctor.rs new file mode 100644 index 00000000..09fa6e77 --- /dev/null +++ b/crates/turtle/src/command/client/doctor.rs @@ -0,0 +1,412 @@ +use std::process::Command; +use std::{env, str::FromStr}; + +use crate::atuin_client::database::Sqlite; +use crate::atuin_client::settings::Settings; +use crate::atuin_common::shell::{Shell, shell_name}; +use crate::atuin_common::utils; +use colored::Colorize; +use eyre::Result; +use serde::Serialize; + +use sysinfo::{Disks, System, get_current_pid}; + +#[derive(Debug, Serialize)] +struct ShellInfo { + pub name: String, + + // best-effort, not supported on all OSes + pub default: String, + + // Detect some shell plugins that the user has installed. + // I'm just going to start with preexec/blesh + pub plugins: Vec, + + // The preexec framework used in the current session, if Atuin is loaded. + pub preexec: Option, +} + +impl ShellInfo { + // HACK ALERT! + // Many of the shell vars we need to detect are not exported :( + // So, we're going to run a interactive session and directly check the + // variable. There's a chance this won't work, so it should not be fatal. + // + // Every shell we support handles `shell -ic 'command'` + fn shellvar_exists(shell: &str, var: &str) -> bool { + let cmd = Command::new(shell) + .args([ + "-ic", + format!("[ -z ${var} ] || echo ATUIN_DOCTOR_ENV_FOUND").as_str(), + ]) + .output() + .map_or(String::new(), |v| { + let out = v.stdout; + String::from_utf8(out).unwrap_or_default() + }); + + cmd.contains("ATUIN_DOCTOR_ENV_FOUND") + } + + fn detect_preexec_framework(shell: &str) -> Option { + if env::var("ATUIN_SESSION").ok().is_none() { + None + } else if shell.starts_with("bash") || shell == "sh" { + env::var("ATUIN_PREEXEC_BACKEND") + .ok() + .filter(|value| !value.is_empty()) + .and_then(|atuin_preexec_backend| { + atuin_preexec_backend.rfind(':').and_then(|pos_colon| { + u32::from_str(&atuin_preexec_backend[..pos_colon]) + .ok() + .is_some_and(|preexec_shlvl| { + env::var("SHLVL") + .ok() + .and_then(|shlvl| u32::from_str(&shlvl).ok()) + .is_some_and(|shlvl| shlvl == preexec_shlvl) + }) + .then(|| atuin_preexec_backend[pos_colon + 1..].to_string()) + }) + }) + } else { + Some("built-in".to_string()) + } + } + + fn validate_plugin_blesh( + _shell: &str, + shell_process: &sysinfo::Process, + ble_session_id: &str, + ) -> Option { + ble_session_id + .split('/') + .nth(1) + .and_then(|field| u32::from_str(field).ok()) + .filter(|&blesh_pid| blesh_pid == shell_process.pid().as_u32()) + .map(|_| "blesh".to_string()) + } + + pub fn plugins(shell: &str, shell_process: &sysinfo::Process) -> Vec { + // consider a different detection approach if there are plugins + // that don't set shell vars + + enum PluginShellType { + Any, + Bash, + + // Note: these are currently unused + #[expect(dead_code)] + Zsh, + #[expect(dead_code)] + Fish, + #[expect(dead_code)] + Nushell, + #[expect(dead_code)] + Xonsh, + } + + enum PluginProbeType { + EnvironmentVariable(&'static str), + InteractiveShellVariable(&'static str), + } + + type PluginValidator = fn(&str, &sysinfo::Process, &str) -> Option; + + let plugin_list: [( + &str, + PluginShellType, + PluginProbeType, + Option, + ); 3] = [ + ( + "atuin", + PluginShellType::Any, + PluginProbeType::EnvironmentVariable("ATUIN_SESSION"), + None, + ), + ( + "blesh", + PluginShellType::Bash, + PluginProbeType::EnvironmentVariable("BLE_SESSION_ID"), + Some(Self::validate_plugin_blesh), + ), + ( + "bash-preexec", + PluginShellType::Bash, + PluginProbeType::InteractiveShellVariable("bash_preexec_imported"), + None, + ), + ]; + + plugin_list + .into_iter() + .filter(|(_, shell_type, _, _)| match shell_type { + PluginShellType::Any => true, + PluginShellType::Bash => shell.starts_with("bash") || shell == "sh", + PluginShellType::Zsh => shell.starts_with("zsh"), + PluginShellType::Fish => shell.starts_with("fish"), + PluginShellType::Nushell => shell.starts_with("nu"), + PluginShellType::Xonsh => shell.starts_with("xonsh"), + }) + .filter_map(|(plugin, _, probe_type, validator)| -> Option { + match probe_type { + PluginProbeType::EnvironmentVariable(env) => { + env::var(env).ok().filter(|value| !value.is_empty()) + } + PluginProbeType::InteractiveShellVariable(shellvar) => { + ShellInfo::shellvar_exists(shell, shellvar).then_some(String::default()) + } + } + .and_then(|value| { + validator.map_or_else( + || Some(plugin.to_string()), + |validator| validator(shell, shell_process, &value), + ) + }) + }) + .collect() + } + + pub fn new() -> Self { + // TODO: rework to use crate::atuin_common::Shell + + let sys = System::new_all(); + + let process = sys + .process(get_current_pid().expect("Failed to get current PID")) + .expect("Process with current pid does not exist"); + + let parent = sys + .process(process.parent().expect("Atuin running with no parent!")) + .expect("Process with parent pid does not exist"); + + let name = shell_name(Some(parent)); + + let plugins = ShellInfo::plugins(name.as_str(), parent); + + let default = Shell::default_shell().unwrap_or(Shell::Unknown).to_string(); + + let preexec = Self::detect_preexec_framework(name.as_str()); + + Self { + name, + default, + plugins, + preexec, + } + } +} + +#[derive(Debug, Serialize)] +struct DiskInfo { + pub name: String, + pub filesystem: String, +} + +#[derive(Debug, Serialize)] +struct SystemInfo { + pub os: String, + + pub arch: String, + + pub version: String, + pub disks: Vec, +} + +impl SystemInfo { + pub fn new() -> Self { + let disks = Disks::new_with_refreshed_list(); + let disks = disks + .list() + .iter() + .map(|d| DiskInfo { + name: d.name().to_os_string().into_string().unwrap(), + filesystem: d.file_system().to_os_string().into_string().unwrap(), + }) + .collect(); + + Self { + os: System::name().unwrap_or_else(|| "unknown".to_string()), + arch: System::cpu_arch().unwrap_or_else(|| "unknown".to_string()), + version: System::os_version().unwrap_or_else(|| "unknown".to_string()), + disks, + } + } +} + +#[derive(Debug, Serialize)] +struct SyncInfo { + pub auth_state: String, + pub auto_sync: bool, + + pub last_sync: String, +} + +impl SyncInfo { + pub async fn new(settings: &Settings) -> Self { + // Build auth state description from raw token state without calling + // resolve_sync_auth(), which has side effects (token migration cleanup) + // that a diagnostic command should not trigger. + let meta = Settings::meta_store().await.ok(); + let has_cli_token = match &meta { + Some(m) => m.session_token().await.ok().flatten().is_some(), + None => false, + }; + + let auth_state = if has_cli_token { + "Self-hosted (authenticated)".into() + } else { + "Not authenticated".into() + }; + + Self { + auth_state, + auto_sync: settings.auto_sync, + last_sync: Settings::last_sync() + .await + .map_or_else(|_| "no last sync".to_string(), |v| v.to_string()), + } + } +} + +#[derive(Debug)] +struct SettingPaths { + db: String, + record_store: String, + key: String, +} + +impl SettingPaths { + pub fn new(settings: &Settings) -> Self { + Self { + db: settings.db_path.clone(), + record_store: settings.record_store_path.clone(), + key: settings.key_path.clone(), + } + } + + pub fn verify(&self) { + let paths = vec![ + ("ATUIN_DB_PATH", &self.db), + ("ATUIN_RECORD_STORE", &self.record_store), + ("ATUIN_KEY", &self.key), + ]; + + for (path_env_var, path) in paths { + if utils::broken_symlink(path) { + eprintln!( + "{path} (${path_env_var}) is a broken symlink. This may cause issues with Atuin." + ); + } + } + } +} + +#[derive(Debug, Serialize)] +struct AtuinInfo { + pub version: String, + pub commit: String, + + /// Whether the main Atuin sync server is in use + /// I'm just calling it Atuin Cloud for lack of a better name atm + pub sync: Option, + + pub sqlite_version: String, + + #[serde(skip)] // probably unnecessary to expose this + pub setting_paths: SettingPaths, +} + +impl AtuinInfo { + pub async fn new(settings: &Settings) -> Self { + let logged_in = settings.logged_in().await.unwrap_or(false); + + let sync = if logged_in { + Some(SyncInfo::new(settings).await) + } else { + None + }; + + let sqlite_version = match Sqlite::new("sqlite::memory:", 0.1).await { + Ok(db) => db + .sqlite_version() + .await + .unwrap_or_else(|_| "unknown".to_string()), + Err(_) => "error".to_string(), + }; + + Self { + version: crate::VERSION.to_string(), + commit: crate::SHA.to_string(), + sync, + sqlite_version, + setting_paths: SettingPaths::new(settings), + } + } +} + +#[derive(Debug, Serialize)] +struct DoctorDump { + pub atuin: AtuinInfo, + pub shell: ShellInfo, + pub system: SystemInfo, +} + +impl DoctorDump { + pub async fn new(settings: &Settings) -> Self { + Self { + atuin: AtuinInfo::new(settings).await, + shell: ShellInfo::new(), + system: SystemInfo::new(), + } + } +} + +fn checks(info: &DoctorDump) { + println!(); // spacing + // + let zfs_error = "[Filesystem] ZFS is known to have some issues with SQLite. Atuin uses SQLite heavily. If you are having poor performance, there are some workarounds here: https://github.com/atuinsh/atuin/issues/952".bold().red(); + let bash_plugin_error = "[Shell] If you are using Bash, Atuin requires that either bash-preexec or ble.sh (>= 0.4) be installed. An older ble.sh may not be detected. so ignore this if you have ble.sh >= 0.4 set up! Read more here: https://docs.atuin.sh/guide/installation/#bash".bold().red(); + let blesh_integration_error = "[Shell] Atuin and ble.sh seem to be loaded in the session, but the integration does not seem to be working. Please check the setup in .bashrc.".bold().red(); + + // ZFS: https://github.com/atuinsh/atuin/issues/952 + if info.system.disks.iter().any(|d| d.filesystem == "zfs") { + println!("{zfs_error}"); + } + + info.atuin.setting_paths.verify(); + + // Shell + if info.shell.name == "bash" { + if !info + .shell + .plugins + .iter() + .any(|p| p == "blesh" || p == "bash-preexec") + { + println!("{bash_plugin_error}"); + } + + if info.shell.plugins.iter().any(|plugin| plugin == "atuin") + && info.shell.plugins.iter().any(|plugin| plugin == "blesh") + && info.shell.preexec.as_ref().is_some_and(|val| val == "none") + { + println!("{blesh_integration_error}"); + } + } +} + +pub async fn run(settings: &Settings) -> Result<()> { + println!("{}", "Atuin Doctor".bold()); + println!("Checking for diagnostics"); + let dump = DoctorDump::new(settings).await; + + checks(&dump); + + let dump = serde_json::to_string_pretty(&dump)?; + + println!("\nPlease include the output below with any bug reports or issues\n"); + println!("{dump}"); + + Ok(()) +} diff --git a/crates/turtle/src/command/client/history.rs b/crates/turtle/src/command/client/history.rs new file mode 100644 index 00000000..0c61392c --- /dev/null +++ b/crates/turtle/src/command/client/history.rs @@ -0,0 +1,1340 @@ +use std::{ + fmt::{self, Display}, + io::{self, IsTerminal, Write}, + path::PathBuf, + time::Duration, +}; + +use crate::atuin_common::utils::{self, Escapable as _}; +use clap::Subcommand; +use eyre::{Context, Result, bail}; +use runtime_format::{FormatKey, FormatKeyError, ParseSegment, ParsedFmt}; + +#[cfg(feature = "daemon")] +use super::daemon as daemon_cmd; +#[cfg(feature = "daemon")] +use colored::Colorize; +#[cfg(feature = "daemon")] +use serde::Serialize; + +#[cfg(feature = "daemon")] +use crate::atuin_daemon::history::{HistoryEventKind, TailHistoryReply}; + +use crate::atuin_client::{ + database::{Database, Sqlite, current_context}, + encryption, + history::{History, store::HistoryStore}, + record::sqlite_store::SqliteStore, + settings::{ + FilterMode::{Directory, Global, Session}, + Settings, Timezone, + }, +}; + +#[cfg(feature = "sync")] +use crate::atuin_client::record; + +use log::{debug, warn}; +use time::{OffsetDateTime, macros::format_description}; + +#[cfg(feature = "daemon")] +use super::daemon; +use super::search::format_duration_into; + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Begins a new command in the history + Start { + /// Collects the command from the `ATUIN_COMMAND_LINE` environment variable, + /// which does not need escaping and is more compatible between OS and shells + #[arg(long = "command-from-env", hide = true)] + cmd_env: bool, + + /// Author of this command, eg `ellie`, `claude`, or `copilot` + #[arg(long)] + author: Option, + + /// Optional intent/rationale for running this command + #[arg(long)] + intent: Option, + + command: Vec, + }, + + /// Finishes a new command in the history (adds time, exit code) + End { + id: String, + #[arg(long, short)] + exit: i64, + #[arg(long, short)] + duration: Option, + }, + + /// Stream history events from the daemon as they are received + Tail, + + /// List all items in history + List { + #[arg(long, short)] + cwd: bool, + + #[arg(long, short)] + session: bool, + + #[arg(long)] + human: bool, + + /// Show only the text of the command + #[arg(long)] + cmd_only: bool, + + /// Terminate the output with a null, for better multiline support + #[arg(long)] + print0: bool, + + #[arg(long, short, default_value = "true")] + // accept no value + #[arg(num_args(0..=1), default_missing_value("true"))] + // accept a value + #[arg(action = clap::ArgAction::Set)] + reverse: bool, + + /// Display the command time in another timezone other than the configured default. + /// + /// This option takes one of the following kinds of values: + /// - the special value "local" (or "l") which refers to the system time zone + /// - an offset from UTC (e.g. "+9", "-2:30") + #[arg(long, visible_alias = "tz")] + timezone: Option, + + /// Available variables: {command}, {directory}, {duration}, {user}, {host}, {author}, {intent}, {exit}, {time}, {session}, and {uuid} + /// Example: --format "{time} - [{duration}] - {directory}$\t{command}" + #[arg(long, short)] + format: Option, + }, + + /// Get the last command ran + Last { + #[arg(long)] + human: bool, + + /// Show only the text of the command + #[arg(long)] + cmd_only: bool, + + /// Display the command time in another timezone other than the configured default. + /// + /// This option takes one of the following kinds of values: + /// - the special value "local" (or "l") which refers to the system time zone + /// - an offset from UTC (e.g. "+9", "-2:30") + #[arg(long, visible_alias = "tz")] + timezone: Option, + + /// Available variables: {command}, {directory}, {duration}, {user}, {host}, {author}, {intent}, {time}, {session}, {uuid} and {relativetime}. + /// Example: --format "{time} - [{duration}] - {directory}$\t{command}" + #[arg(long, short)] + format: Option, + }, + + InitStore, + + /// Delete history entries matching the configured exclusion filters + Prune { + /// List matching history lines without performing the actual deletion. + #[arg(short = 'n', long)] + dry_run: bool, + }, + + /// Delete duplicate history entries (that have the same command, cwd and hostname) + Dedup { + /// List matching history lines without performing the actual deletion. + #[arg(short = 'n', long)] + dry_run: bool, + + /// Only delete results added before this date + #[arg(long, short)] + before: String, + + /// How many recent duplicates to keep + #[arg(long)] + dupkeep: u32, + }, +} + +#[derive(Clone, Copy, Debug)] +pub enum ListMode { + Human, + CmdOnly, + Regular, +} + +impl ListMode { + pub const fn from_flags(human: bool, cmd_only: bool) -> Self { + if human { + ListMode::Human + } else if cmd_only { + ListMode::CmdOnly + } else { + ListMode::Regular + } + } +} + +#[expect(clippy::cast_sign_loss)] +pub fn print_list( + h: &[History], + list_mode: ListMode, + format: Option<&str>, + print0: bool, + reverse: bool, + tz: Timezone, +) { + let w = std::io::stdout(); + let mut w = w.lock(); + + let fmt_str = match list_mode { + ListMode::Human => format + .unwrap_or("{time} · {duration}\t{command}") + .replace("\\t", "\t"), + ListMode::Regular => format + .unwrap_or("{time}\t{command}\t{duration}") + .replace("\\t", "\t"), + // not used + ListMode::CmdOnly => String::new(), + }; + + let parsed_fmt = match list_mode { + ListMode::Human | ListMode::Regular => parse_fmt(&fmt_str), + ListMode::CmdOnly => std::iter::once(ParseSegment::Key("command")).collect(), + }; + + let iterator = if reverse { + Box::new(h.iter().rev()) as Box> + } else { + Box::new(h.iter()) as Box> + }; + + let entry_terminator = if print0 { "\0" } else { "\n" }; + let flush_each_line = print0; + + for history in iterator { + let fh = FmtHistory { + history, + cmd_format: CmdFormat::for_output(&w), + tz: &tz, + }; + let args = parsed_fmt.with_args(&fh); + + // Check for formatting errors before attempting to write + if let Err(err) = args.status() { + eprintln!("ERROR: history output failed with: {err}"); + std::process::exit(1); + } + + let write_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + write!(w, "{args}{entry_terminator}") + })); + + match write_result { + Ok(Ok(())) => { + // Write succeeded + } + Ok(Err(err)) => { + if err.kind() != io::ErrorKind::BrokenPipe { + eprintln!("ERROR: Failed to write history output: {err}"); + std::process::exit(1); + } + } + Err(_) => { + eprintln!("ERROR: Format string caused a formatting error."); + eprintln!( + "This may be due to an unsupported format string containing special characters." + ); + eprintln!( + "Please check your format string syntax and ensure literal braces are properly escaped." + ); + std::process::exit(1); + } + } + if flush_each_line { + check_for_write_errors(w.flush()); + } + } + + if !flush_each_line { + check_for_write_errors(w.flush()); + } +} + +fn check_for_write_errors(write: Result<(), io::Error>) { + if let Err(err) = write { + // Ignore broken pipe (issue #626) + if err.kind() != io::ErrorKind::BrokenPipe { + eprintln!("ERROR: History output failed with the following error: {err}"); + std::process::exit(1); + } + } +} + +/// Type wrapper around `History` with formatting settings. +#[derive(Clone, Copy, Debug)] +struct FmtHistory<'a> { + history: &'a History, + cmd_format: CmdFormat, + tz: &'a Timezone, +} + +#[derive(Clone, Copy, Debug)] +enum CmdFormat { + Literal, + Escaped, +} +impl CmdFormat { + fn for_output(out: &O) -> Self { + if out.is_terminal() { + Self::Escaped + } else { + Self::Literal + } + } +} + +static TIME_FMT: &[time::format_description::FormatItem<'static>] = + format_description!("[year]-[month]-[day] [hour repr:24]:[minute]:[second]"); + +/// defines how to format the history +impl FormatKey for FmtHistory<'_> { + #[expect(clippy::cast_sign_loss)] + fn fmt(&self, key: &str, f: &mut fmt::Formatter<'_>) -> Result<(), FormatKeyError> { + match key { + "command" => match self.cmd_format { + CmdFormat::Literal => f.write_str(self.history.command.trim()), + CmdFormat::Escaped => f.write_str(&self.history.command.trim().escape_control()), + }?, + "directory" => f.write_str(self.history.cwd.trim())?, + "exit" => f.write_str(&self.history.exit.to_string())?, + "duration" => { + let dur = Duration::from_nanos(std::cmp::max(self.history.duration, 0) as u64); + format_duration_into(dur, f)?; + } + "time" => { + self.history + .timestamp + .to_offset(self.tz.0) + .format(TIME_FMT) + .map_err(|_| fmt::Error)? + .fmt(f)?; + } + "relativetime" => { + let since = OffsetDateTime::now_utc() - self.history.timestamp; + let d = Duration::try_from(since).unwrap_or_default(); + format_duration_into(d, f)?; + } + "host" => f.write_str( + self.history + .hostname + .split_once(':') + .map_or(&self.history.hostname, |(host, _)| host), + )?, + "author" => f.write_str(&self.history.author)?, + "intent" => f.write_str(self.history.intent.as_deref().unwrap_or_default())?, + "user" => f.write_str( + self.history + .hostname + .split_once(':') + .map_or("", |(_, user)| user), + )?, + "session" => f.write_str(&self.history.session)?, + "uuid" => f.write_str(&self.history.id.0)?, + _ => return Err(FormatKeyError::UnknownKey), + } + Ok(()) + } +} + +fn parse_fmt(format: &str) -> ParsedFmt<'_> { + match ParsedFmt::new(format) { + Ok(fmt) => fmt, + Err(err) => { + eprintln!("ERROR: History formatting failed with the following error: {err}"); + + if format.contains('"') && (format.contains(":{") || format.contains(",{")) { + eprintln!("It looks like you're trying to create JSON output."); + eprintln!("For JSON, you need to escape literal braces by doubling them:"); + eprintln!("Example: '{{\"command\":\"{{command}}\",\"time\":\"{{time}}\"}}'"); + } else { + eprintln!( + "If your formatting string contains literal curly braces, you need to escape them by doubling:" + ); + eprintln!("Use {{{{ for literal {{ and }}}} for literal }}"); + } + std::process::exit(1) + } + } +} + +fn apply_start_metadata(history: &mut History, author: Option<&str>, intent: Option<&str>) { + if let Some(author) = author.map(str::trim).filter(|author| !author.is_empty()) { + author.clone_into(&mut history.author); + } + + if let Some(intent) = intent.map(str::trim).filter(|intent| !intent.is_empty()) { + history.intent = Some(intent.to_owned()); + } else if intent.is_some() { + history.intent = None; + } +} + +fn normalize_command_for_storage<'a>(command: &'a str, settings: &Settings) -> &'a str { + if !settings.strip_trailing_whitespace { + return command; + } + + let trimmed = command.trim_end_matches([' ', '\t']); + if trimmed.len() == command.len() { + return command; + } + + let trailing_backslashes = trimmed + .as_bytes() + .iter() + .rev() + .take_while(|&&byte| byte == b'\\') + .count(); + + if trailing_backslashes % 2 == 1 { + command + } else { + trimmed + } +} + +async fn handle_start( + db: &impl Database, + settings: &Settings, + command: &str, + author: Option<&str>, + intent: Option<&str>, +) -> Result> { + // It's better for atuin to silently fail here and attempt to + // store whatever is ran, than to throw an error to the terminal + let cwd = utils::get_current_dir(); + let command = normalize_command_for_storage(command, settings); + + let mut h: History = History::capture() + .timestamp(OffsetDateTime::now_utc()) + .command(command) + .cwd(cwd) + .build() + .into(); + apply_start_metadata(&mut h, author, intent); + + if !h.should_save(settings) { + return Ok(None); + } + + let id = h.id.0.clone(); + + // Silently ignore database errors to avoid breaking the shell + // This is important when disk is full or database is locked + if let Err(e) = db.save(&h).await { + debug!("failed to save history: {e}"); + } + + Ok(Some(id)) +} + +#[cfg(feature = "daemon")] +async fn handle_daemon_start( + settings: &Settings, + command: &str, + author: Option<&str>, + intent: Option<&str>, +) -> Result> { + // It's better for atuin to silently fail here and attempt to + // store whatever is ran, than to throw an error to the terminal + let cwd = utils::get_current_dir(); + let command = normalize_command_for_storage(command, settings); + + let mut h: History = History::capture() + .timestamp(OffsetDateTime::now_utc()) + .command(command) + .cwd(cwd) + .build() + .into(); + apply_start_metadata(&mut h, author, intent); + + if !h.should_save(settings) { + return Ok(None); + } + + // Attempt to start history via daemon, but silently ignore errors + // to avoid breaking the shell when the daemon is unavailable or disk is full + let resp = match daemon::start_history(settings, h.clone()).await { + Ok(id) => id, + Err(e) => { + debug!("failed to start history via daemon: {e}"); + h.id.0.clone() + } + }; + + Ok(Some(resp)) +} + +#[expect(unused_variables)] +async fn handle_end( + db: &impl Database, + store: SqliteStore, + history_store: HistoryStore, + settings: &Settings, + id: &str, + exit: i64, + duration: Option, +) -> Result<()> { + if id.trim() == "" { + return Ok(()); + } + + let Some(mut h) = db.load(id).await? else { + warn!("history entry is missing"); + return Ok(()); + }; + + if h.duration > 0 { + debug!("cannot end history - already has duration"); + + // returning OK as this can occur if someone Ctrl-c a prompt + return Ok(()); + } + + if !settings.store_failed && exit > 0 { + debug!("history has non-zero exit code, and store_failed is false"); + + // the history has already been inserted half complete. remove it + db.delete(h).await?; + + return Ok(()); + } + + h.exit = exit; + h.duration = match duration { + Some(value) => i64::try_from(value).context("command took over 292 years")?, + None => i64::try_from((OffsetDateTime::now_utc() - h.timestamp).whole_nanoseconds()) + .context("command took over 292 years")?, + }; + + db.update(&h).await?; + history_store.push(h).await?; + + if settings.should_sync().await? { + let (_, downloaded) = + record::sync::sync(settings, &store, &history_store.encryption_key).await?; + Settings::save_sync_time().await?; + + crate::sync::build(settings, &store, db, Some(&downloaded)).await?; + } else { + debug!("sync disabled! not syncing"); + } + + Ok(()) +} + +#[cfg(feature = "daemon")] +async fn handle_daemon_end( + settings: &Settings, + id: &str, + exit: i64, + duration: Option, +) -> Result<()> { + daemon::end_history(settings, id.to_string(), duration.unwrap_or(0), exit).await?; + + Ok(()) +} + +pub(super) async fn start_history_entry( + settings: &Settings, + command: &str, + author: Option<&str>, + intent: Option<&str>, +) -> Result> { + #[cfg(feature = "daemon")] + if settings.daemon.enabled { + return handle_daemon_start(settings, command, author, intent).await; + } + + let db_path = PathBuf::from(settings.db_path.as_str()); + let db = Sqlite::new(db_path, settings.local_timeout).await?; + handle_start(&db, settings, command, author, intent).await +} + +pub(super) async fn end_history_entry( + settings: &Settings, + id: &str, + exit: i64, + duration: Option, +) -> Result<()> { + #[cfg(feature = "daemon")] + if settings.daemon.enabled { + return handle_daemon_end(settings, id, exit, duration).await; + } + + let db_path = PathBuf::from(settings.db_path.as_str()); + let record_store_path = PathBuf::from(settings.record_store_path.as_str()); + + let db = Sqlite::new(db_path, settings.local_timeout).await?; + let store = SqliteStore::new(record_store_path, settings.local_timeout).await?; + + let encryption_key: [u8; 32] = encryption::load_key(settings) + .context("could not load encryption key")? + .into(); + let host_id = Settings::host_id().await?; + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + + handle_end(&db, store, history_store, settings, id, exit, duration).await +} + +#[cfg(feature = "daemon")] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum TailKind { + Started, + Ended, +} + +#[cfg(feature = "daemon")] +#[derive(Clone, Debug, Eq, PartialEq)] +struct TailEvent { + kind: TailKind, + history: History, +} + +#[cfg(feature = "daemon")] +#[derive(Serialize)] +struct TailJsonEvent<'a> { + event: &'static str, + history: TailJsonHistory<'a>, +} + +#[cfg(feature = "daemon")] +#[derive(Serialize)] +struct TailJsonHistory<'a> { + id: &'a str, + timestamp: String, + timestamp_unix_ns: u64, + command: &'a str, + cwd: &'a str, + session: &'a str, + hostname: &'a str, + host: &'a str, + user: &'a str, + author: &'a str, + #[serde(skip_serializing_if = "Option::is_none")] + intent: Option<&'a str>, + #[serde(skip_serializing_if = "Option::is_none")] + exit: Option, + #[serde(skip_serializing_if = "Option::is_none")] + duration_ns: Option, + #[serde(skip_serializing_if = "Option::is_none")] + duration: Option, + #[serde(skip_serializing_if = "Option::is_none")] + success: Option, + #[serde(skip_serializing_if = "Option::is_none")] + finished_at: Option, +} + +#[cfg(feature = "daemon")] +impl TailEvent { + fn from_proto(reply: TailHistoryReply) -> Result { + let history = reply + .history + .ok_or_else(|| eyre::eyre!("daemon sent a history tail event without history"))?; + let timestamp = OffsetDateTime::from_unix_timestamp_nanos(i128::from(history.timestamp)) + .context("invalid daemon history timestamp")?; + let kind = match HistoryEventKind::try_from(reply.kind) + .unwrap_or(HistoryEventKind::Unspecified) + { + HistoryEventKind::Started => TailKind::Started, + HistoryEventKind::Ended => TailKind::Ended, + HistoryEventKind::Unspecified => bail!("daemon sent an unspecified history tail event"), + }; + + Ok(Self { + kind, + history: History { + id: history.id.into(), + timestamp, + duration: history.duration, + exit: history.exit, + command: history.command, + cwd: history.cwd, + session: history.session, + hostname: history.hostname, + author: history.author, + intent: normalize_optional_field(&history.intent), + deleted_at: None, + }, + }) + } + + fn render(&self, tty: bool, tz: Timezone) -> Result { + if tty { + Ok(self.render_pretty(tz)) + } else { + let mut json = self.render_json(tz)?; + json.push('\n'); + Ok(json) + } + } + + fn render_json(&self, tz: Timezone) -> Result { + let payload = TailJsonEvent { + event: self.kind.as_str(), + history: TailJsonHistory { + id: &self.history.id.0, + timestamp: format_history_time(self.history.timestamp, tz)?, + timestamp_unix_ns: u64::try_from(self.history.timestamp.unix_timestamp_nanos()) + .context("history timestamp predates unix epoch")?, + command: &self.history.command, + cwd: &self.history.cwd, + session: &self.history.session, + hostname: &self.history.hostname, + host: self.host(), + user: self.user(), + author: &self.history.author, + intent: self.history.intent.as_deref(), + exit: self.exit_value(), + duration_ns: self.duration_value(), + duration: self.duration_value().map(format_duration_ns), + success: self.success_value(), + finished_at: self + .finished_at() + .map(|time| format_history_time(time, tz)) + .transpose()?, + }, + }; + + Ok(serde_json::to_string(&payload)?) + } + + fn render_pretty(&self, tz: Timezone) -> String { + let mut out = String::new(); + let border = match self.kind { + TailKind::Started => "-".repeat(72).bright_blue().to_string(), + TailKind::Ended if self.history.exit == 0 => "-".repeat(72).bright_green().to_string(), + TailKind::Ended => "-".repeat(72).bright_red().to_string(), + }; + + out.push_str(&border); + out.push('\n'); + + let command = self.history.command.trim(); + let escaped_command = command.escape_control(); + let mut command_lines = escaped_command.lines(); + let header = format!( + "{} {}", + self.kind.badge(self.history.exit), + command_lines.next().unwrap_or_default().bold() + ); + out.push_str(&header); + out.push('\n'); + + for line in command_lines { + out.push_str(" "); + out.push_str(line); + out.push('\n'); + } + + push_pretty_field( + &mut out, + "start", + &format_history_time(self.history.timestamp, tz) + .unwrap_or_else(|_| "invalid".to_owned()), + ); + push_pretty_field(&mut out, "history", &self.history.id.0); + push_pretty_field(&mut out, "session", &self.history.session); + push_pretty_field(&mut out, "exit", &self.exit_display()); + push_pretty_field(&mut out, "duration", &self.duration_display()); + + out.push('\n'); + + push_pretty_field(&mut out, "cwd", &self.history.cwd); + push_pretty_field(&mut out, "hostname", &self.history.hostname); + push_pretty_field(&mut out, "host", self.host()); + push_pretty_field(&mut out, "user", self.user()); + push_pretty_field(&mut out, "author", &self.history.author); + + if let Some(intent) = self.history.intent.as_deref() { + push_pretty_field(&mut out, "intent", intent); + } + + if let Some(finished) = self.finished_at() { + let finished = + format_history_time(finished, tz).unwrap_or_else(|_| "invalid".to_owned()); + push_pretty_field(&mut out, "finished", &finished); + } + + out.push_str(&border); + out.push_str("\n\n"); + out + } + + fn host(&self) -> &str { + self.history + .hostname + .split_once(':') + .map_or(self.history.hostname.as_str(), |(host, _)| host) + } + + fn user(&self) -> &str { + self.history + .hostname + .split_once(':') + .map_or("", |(_, user)| user) + } + + fn exit_value(&self) -> Option { + matches!(self.kind, TailKind::Ended).then_some(self.history.exit) + } + + fn duration_value(&self) -> Option { + matches!(self.kind, TailKind::Ended).then_some(self.history.duration) + } + + fn success_value(&self) -> Option { + matches!(self.kind, TailKind::Ended).then_some(self.history.exit == 0) + } + + fn finished_at(&self) -> Option { + self.duration_value() + .filter(|duration| *duration >= 0) + .map(time::Duration::nanoseconds) + .and_then(|duration| self.history.timestamp.checked_add(duration)) + } + + fn exit_display(&self) -> String { + match self.exit_value() { + Some(0) => "0 (success)".bright_green().to_string(), + Some(code) => format!("{code} (failure)").bright_red().to_string(), + None => "pending".bright_yellow().to_string(), + } + } + + fn duration_display(&self) -> String { + match self.duration_value() { + Some(duration) if duration >= 0 => format_duration_ns(duration), + Some(_) => "unknown".bright_yellow().to_string(), + None => "running".bright_yellow().to_string(), + } + } +} + +#[cfg(feature = "daemon")] +impl TailKind { + const fn as_str(self) -> &'static str { + match self { + Self::Started => "started", + Self::Ended => "ended", + } + } + + fn badge(self, exit: i64) -> colored::ColoredString { + match self { + Self::Started => "STARTED".bold().bright_blue(), + Self::Ended if exit == 0 => "ENDED".bold().bright_green(), + Self::Ended => "ENDED".bold().bright_red(), + } + } +} + +#[cfg(feature = "daemon")] +fn format_history_time(timestamp: OffsetDateTime, tz: Timezone) -> Result { + Ok(timestamp.to_offset(tz.0).format(TIME_FMT)?) +} + +#[cfg(feature = "daemon")] +fn format_duration_ns(duration_ns: i64) -> String { + struct F(Duration); + impl Display for F { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + format_duration_into(self.0, f) + } + } + + F(Duration::from_nanos(duration_ns.max(0).cast_unsigned())).to_string() +} + +#[cfg(feature = "daemon")] +fn push_pretty_field(out: &mut String, label: &str, value: &str) { + out.push_str(" "); + let label = format!("{label}:"); + out.push_str(&label.bright_cyan().bold().to_string()); + if label.len() < 10 { + out.push_str(&" ".repeat(10 - label.len())); + } + + let mut lines = value.lines(); + if let Some(first) = lines.next() { + out.push_str(first); + } + out.push('\n'); + + for line in lines { + out.push_str(" "); + out.push_str(line); + out.push('\n'); + } +} + +#[cfg(feature = "daemon")] +fn normalize_optional_field(value: &str) -> Option { + let trimmed = value.trim(); + if trimmed.is_empty() { + None + } else { + Some(trimmed.to_owned()) + } +} + +impl Cmd { + #[cfg(feature = "daemon")] + async fn handle_tail(settings: &Settings) -> Result<()> { + let tty = std::io::stdout().is_terminal(); + let mut client = daemon::tail_client(settings).await?; + let mut stream = client.tail_history().await?; + let stdout = std::io::stdout(); + + while let Some(reply) = stream.message().await? { + let event = TailEvent::from_proto(reply)?; + let rendered = event.render(tty, settings.timezone)?; + let mut out = stdout.lock(); + + match out.write_all(rendered.as_bytes()) { + Ok(()) => out.flush()?, + Err(err) if err.kind() == io::ErrorKind::BrokenPipe => break, + Err(err) => return Err(err.into()), + } + } + + Ok(()) + } + + #[expect(clippy::too_many_lines, clippy::cast_possible_truncation)] + #[expect(clippy::too_many_arguments)] + #[expect(clippy::fn_params_excessive_bools)] + async fn handle_list( + db: &impl Database, + settings: &Settings, + context: crate::atuin_client::database::Context, + session: bool, + cwd: bool, + mode: ListMode, + format: Option, + include_deleted: bool, + print0: bool, + reverse: bool, + tz: Timezone, + ) -> Result<()> { + let filters = match (session, cwd) { + (true, true) => [Session, Directory], + (true, false) => [Session, Global], + (false, true) => [Global, Directory], + (false, false) => [ + settings.default_filter_mode(context.git_root.is_some()), + Global, + ], + }; + + let history = db + .list(&filters, &context, None, false, include_deleted) + .await?; + + print_list( + &history, + mode, + match format { + None => Some(settings.history_format.as_str()), + _ => format.as_deref(), + }, + print0, + reverse, + tz, + ); + + Ok(()) + } + + async fn handle_prune( + db: &impl Database, + settings: &Settings, + store: SqliteStore, + context: crate::atuin_client::database::Context, + dry_run: bool, + ) -> Result<()> { + // Grab all executed commands and filter them using History::should_save. + // We could iterate or paginate here if memory usage becomes an issue. + let matches: Vec = db + .list(&[Global], &context, None, false, false) + .await? + .into_iter() + .filter(|h| !h.should_save(settings)) + .collect(); + + match matches.len() { + 0 => { + println!("No entries to prune."); + return Ok(()); + } + 1 => println!("Found 1 entry to prune."), + n => println!("Found {n} entries to prune."), + } + + if dry_run { + print_list( + &matches, + ListMode::Human, + Some(settings.history_format.as_str()), + false, + false, + settings.timezone, + ); + } else { + let encryption_key: [u8; 32] = encryption::load_key(settings) + .context("could not load encryption key")? + .into(); + let host_id = Settings::host_id().await?; + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + + for entry in matches { + eprintln!("deleting {}", entry.id); + let (id, _) = history_store.delete(entry.id.clone()).await?; + history_store.incremental_build(db, &[id]).await?; + } + + #[cfg(feature = "daemon")] + daemon_cmd::emit_event(settings, crate::atuin_daemon::DaemonEvent::HistoryPruned).await; + } + Ok(()) + } + + async fn handle_dedup( + db: &impl Database, + settings: &Settings, + store: SqliteStore, + before: i64, + dupkeep: u32, + dry_run: bool, + ) -> Result<()> { + if dupkeep == 0 { + eprintln!( + "\"--dupkeep 0\" would keep 0 copies of duplicate commands and thus delete all of them! Use \"atuin search --delete ...\" if you really want that." + ); + std::process::exit(1); + } + + let matches: Vec = db.get_dups(before, dupkeep).await?; + + match matches.len() { + 0 => { + println!("No duplicates to delete."); + return Ok(()); + } + 1 => println!("Found 1 duplicate to delete."), + n => println!("Found {n} duplicates to delete."), + } + + if dry_run { + print_list( + &matches, + ListMode::Human, + Some(settings.history_format.as_str()), + false, + false, + settings.timezone, + ); + } else { + let encryption_key: [u8; 32] = encryption::load_key(settings) + .context("could not load encryption key")? + .into(); + let host_id = Settings::host_id().await?; + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + + #[cfg(feature = "daemon")] + let ids = matches.iter().map(|h| h.id.clone()).collect::>(); + + for entry in matches { + eprintln!("deleting {}", entry.id); + let (id, _) = history_store.delete(entry.id).await?; + history_store.incremental_build(db, &[id]).await?; + } + + #[cfg(feature = "daemon")] + daemon_cmd::emit_event( + settings, + crate::atuin_daemon::DaemonEvent::HistoryDeleted { ids }, + ) + .await; + } + Ok(()) + } + + #[expect(clippy::too_many_lines)] + pub async fn run(self, settings: &Settings) -> Result<()> { + match self { + Self::Start { + cmd_env, + author, + intent, + command, + } => { + let command = if cmd_env { + std::env::var("ATUIN_COMMAND_LINE").unwrap_or_default() + } else { + command.join(" ") + }; + + if let Some(id) = + start_history_entry(settings, &command, author.as_deref(), intent.as_deref()) + .await? + { + println!("{id}"); + } + + Ok(()) + } + Self::End { id, exit, duration } => { + end_history_entry(settings, &id, exit, duration).await + } + Self::Tail => { + #[cfg(feature = "daemon")] + { + return Self::handle_tail(settings).await; + } + + #[cfg(not(feature = "daemon"))] + bail!("`atuin history tail` requires Atuin to be built with the `daemon` feature"); + } + cmd => { + let context = current_context().await?; + + let db_path = PathBuf::from(settings.db_path.as_str()); + let record_store_path = PathBuf::from(settings.record_store_path.as_str()); + + let db = Sqlite::new(db_path, settings.local_timeout).await?; + let store = SqliteStore::new(record_store_path, settings.local_timeout).await?; + + let encryption_key: [u8; 32] = encryption::load_key(settings) + .context("could not load encryption key")? + .into(); + + let host_id = Settings::host_id().await?; + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + + match cmd { + Self::List { + session, + cwd, + human, + cmd_only, + print0, + reverse, + timezone, + format, + } => { + let mode = ListMode::from_flags(human, cmd_only); + let tz = timezone.unwrap_or(settings.timezone); + Self::handle_list( + &db, settings, context, session, cwd, mode, format, false, print0, + reverse, tz, + ) + .await + } + + Self::Last { + human, + cmd_only, + timezone, + format, + } => { + let last = db.last().await?; + let last = last.as_slice(); + let tz = timezone.unwrap_or(settings.timezone); + print_list( + last, + ListMode::from_flags(human, cmd_only), + match format { + None => Some(settings.history_format.as_str()), + _ => format.as_deref(), + }, + false, + true, + tz, + ); + + Ok(()) + } + + Self::InitStore => history_store.init_store(&db).await, + + Self::Prune { dry_run } => { + Self::handle_prune(&db, settings, store, context, dry_run).await + } + + Self::Dedup { + dry_run, + before, + dupkeep, + } => { + let before = i64::try_from( + interim::parse_date_string( + before.as_str(), + OffsetDateTime::now_utc(), + interim::Dialect::Uk, + )? + .unix_timestamp_nanos(), + )?; + Self::handle_dedup(&db, settings, store, before, dupkeep, dry_run).await + } + + Self::Start { .. } | Self::End { .. } | Self::Tail => unreachable!(), + } + } + } + } +} + +#[cfg(test)] +mod tests { + #[cfg(feature = "daemon")] + use time::macros::datetime; + + use super::*; + + #[test] + fn normalize_command_strips_trailing_spaces_and_tabs() { + let settings = Settings::utc(); + + assert!(settings.strip_trailing_whitespace); + assert_eq!(normalize_command_for_storage("ls \t", &settings), "ls"); + } + + #[test] + fn normalize_command_preserves_escaped_trailing_space() { + let settings = Settings::utc(); + + assert_eq!( + normalize_command_for_storage("printf foo\\ ", &settings), + "printf foo\\ " + ); + assert_eq!( + normalize_command_for_storage("printf foo\\\\ ", &settings), + "printf foo\\\\" + ); + } + + #[tokio::test] + async fn handle_start_saves_trimmed_command() { + let db = Sqlite::new("sqlite::memory:", 2.0).await.unwrap(); + let settings = Settings::utc(); + + handle_start(&db, &settings, "ls \t", None, None) + .await + .unwrap(); + + let history = db + .before(OffsetDateTime::now_utc() + time::Duration::SECOND, 1) + .await + .unwrap() + .pop() + .unwrap(); + assert_eq!(history.command, "ls"); + } + + #[tokio::test] + async fn handle_start_can_keep_trailing_whitespace() { + let db = Sqlite::new("sqlite::memory:", 2.0).await.unwrap(); + let settings = Settings { + strip_trailing_whitespace: false, + ..Settings::utc() + }; + + handle_start(&db, &settings, "ls \t", None, None) + .await + .unwrap(); + + let history = db + .before(OffsetDateTime::now_utc() + time::Duration::SECOND, 1) + .await + .unwrap() + .pop() + .unwrap(); + assert_eq!(history.command, "ls \t"); + } + + #[test] + fn test_format_string_no_panic() { + // Don't panic but provide helpful output (issue #2776) + let malformed_json = r#"{"command":"{command}","key":"value"}"#; + + let result = std::panic::catch_unwind(|| parse_fmt(malformed_json)); + + assert!(result.is_ok()); + } + + #[test] + fn test_valid_formats_still_work() { + assert!(std::panic::catch_unwind(|| parse_fmt("{command}")).is_ok()); + assert!(std::panic::catch_unwind(|| parse_fmt("{time} - {command}")).is_ok()); + } + + #[cfg(feature = "daemon")] + fn sample_tail_event(kind: TailKind) -> TailEvent { + TailEvent { + kind, + history: History { + id: "history-id".to_owned().into(), + timestamp: datetime!(2026-04-09 17:18:19 UTC), + duration: 12_345_678, + exit: 0, + command: "git status".to_owned(), + cwd: "/tmp/repo".to_owned(), + session: "session-id".to_owned(), + hostname: "host:ellie".to_owned(), + author: "claude".to_owned(), + intent: Some("inspect repository state".to_owned()), + deleted_at: None, + }, + } + } + + #[cfg(feature = "daemon")] + #[test] + fn test_tail_json_output_contains_history_fields() { + let json = sample_tail_event(TailKind::Ended) + .render(false, Timezone(time::UtcOffset::UTC)) + .unwrap(); + let value: serde_json::Value = serde_json::from_str(&json).unwrap(); + + assert_eq!(value["event"], "ended"); + assert_eq!(value["history"]["id"], "history-id"); + assert_eq!(value["history"]["duration_ns"], 12_345_678); + assert_eq!(value["history"]["success"], true); + assert!(value.get("record").is_none()); + } + + #[cfg(feature = "daemon")] + #[test] + fn test_tail_pretty_output_shows_pending_fields_for_started_events() { + let rendered = sample_tail_event(TailKind::Started) + .render(true, Timezone(time::UtcOffset::UTC)) + .unwrap(); + let plain = regex::Regex::new(r"\x1b\[[0-9;]*m") + .unwrap() + .replace_all(&rendered, ""); + + assert!(plain.contains("STARTED git status")); + assert!(plain.contains("exit:")); + assert!(plain.contains("pending")); + assert!(plain.contains("duration:")); + assert!(plain.contains("running")); + } +} diff --git a/crates/turtle/src/command/client/import.rs b/crates/turtle/src/command/client/import.rs new file mode 100644 index 00000000..363e6405 --- /dev/null +++ b/crates/turtle/src/command/client/import.rs @@ -0,0 +1,186 @@ +use std::env; + +use async_trait::async_trait; +use clap::Parser; +use eyre::Result; +use indicatif::ProgressBar; + +use crate::atuin_client::{ + database::Database, + history::History, + import::{ + Importer, Loader, bash::Bash, fish::Fish, nu::Nu, nu_histdb::NuHistDb, + powershell::PowerShell, replxx::Replxx, resh::Resh, xonsh::Xonsh, + xonsh_sqlite::XonshSqlite, zsh::Zsh, zsh_histdb::ZshHistDb, + }, +}; + +#[derive(Parser, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Import history for the current shell + Auto, + + /// Import history from the zsh history file + Zsh, + /// Import history from the zsh history file + ZshHistDb, + /// Import history from the bash history file + Bash, + /// Import history from the replxx history file + Replxx, + /// Import history from the resh history file + Resh, + /// Import history from the fish history file + Fish, + /// Import history from the nu history file + Nu, + /// Import history from the nu history file + NuHistDb, + /// Import history from xonsh json files + Xonsh, + /// Import history from xonsh sqlite db + XonshSqlite, + /// Import history from the powershell history file + Powershell, +} + +const BATCH_SIZE: usize = 100; + +impl Cmd { + #[expect(clippy::cognitive_complexity)] + pub async fn run(&self, db: &DB) -> Result<()> { + println!(" Atuin "); + println!("======================"); + println!(" \u{1f30d} "); + println!(" \u{1f418}\u{1f418}\u{1f418}\u{1f418} "); + println!(" \u{1f422} "); + println!("======================"); + println!("Importing history..."); + + match self { + Self::Auto => { + if cfg!(windows) { + return if env::var("PSModulePath").is_ok() { + println!("Detected PowerShell"); + import::(db).await + } else { + println!("Could not detect the current shell."); + println!("Please run atuin import ."); + println!("To view a list of shells, run atuin import."); + Ok(()) + }; + } + + // $XONSH_HISTORY_BACKEND isn't always set, but $XONSH_HISTORY_FILE is + let xonsh_histfile = + env::var("XONSH_HISTORY_FILE").unwrap_or_else(|_| String::new()); + let shell = env::var("SHELL").unwrap_or_else(|_| String::from("NO_SHELL")); + + if xonsh_histfile.to_lowercase().ends_with(".json") { + println!("Detected Xonsh"); + import::(db).await + } else if xonsh_histfile.to_lowercase().ends_with(".sqlite") { + println!("Detected Xonsh (SQLite backend)"); + import::(db).await + } else if shell.ends_with("/zsh") { + if ZshHistDb::histpath().is_ok() { + println!( + "Detected Zsh-HistDb, using :{}", + ZshHistDb::histpath().unwrap().to_str().unwrap() + ); + import::(db).await + } else { + println!("Detected ZSH"); + import::(db).await + } + } else if shell.ends_with("/fish") { + println!("Detected Fish"); + import::(db).await + } else if shell.ends_with("/bash") { + println!("Detected Bash"); + import::(db).await + } else if shell.ends_with("/nu") { + if NuHistDb::histpath().is_ok() { + println!( + "Detected Nu-HistDb, using :{}", + NuHistDb::histpath().unwrap().to_str().unwrap() + ); + import::(db).await + } else { + println!("Detected Nushell"); + import::(db).await + } + } else if shell.ends_with("/pwsh") { + println!("Detected PowerShell"); + import::(db).await + } else { + println!("cannot import {shell} history"); + Ok(()) + } + } + + Self::Zsh => import::(db).await, + Self::ZshHistDb => import::(db).await, + Self::Bash => import::(db).await, + Self::Replxx => import::(db).await, + Self::Resh => import::(db).await, + Self::Fish => import::(db).await, + Self::Nu => import::(db).await, + Self::NuHistDb => import::(db).await, + Self::Xonsh => import::(db).await, + Self::XonshSqlite => import::(db).await, + Self::Powershell => import::(db).await, + } + } +} + +pub struct HistoryImporter<'db, DB: Database> { + pb: ProgressBar, + buf: Vec, + db: &'db DB, +} + +impl<'db, DB: Database> HistoryImporter<'db, DB> { + fn new(db: &'db DB, len: usize) -> Self { + Self { + pb: ProgressBar::new(len as u64), + buf: Vec::with_capacity(BATCH_SIZE), + db, + } + } + + async fn flush(self) -> Result<()> { + if !self.buf.is_empty() { + self.db.save_bulk(&self.buf).await?; + } + self.pb.finish(); + Ok(()) + } +} + +#[async_trait] +impl Loader for HistoryImporter<'_, DB> { + async fn push(&mut self, hist: History) -> Result<()> { + self.pb.inc(1); + self.buf.push(hist); + if self.buf.len() == self.buf.capacity() { + self.db.save_bulk(&self.buf).await?; + self.buf.clear(); + } + Ok(()) + } +} + +async fn import(db: &DB) -> Result<()> { + println!("Importing history from {}", I::NAME); + + let mut importer = I::new().await?; + let len = importer.entries().await.unwrap(); + let mut loader = HistoryImporter::new(db, len); + importer.load(&mut loader).await?; + loader.flush().await?; + + println!("Import complete!"); + Ok(()) +} diff --git a/crates/turtle/src/command/client/info.rs b/crates/turtle/src/command/client/info.rs new file mode 100644 index 00000000..ee24c419 --- /dev/null +++ b/crates/turtle/src/command/client/info.rs @@ -0,0 +1,31 @@ +use crate::atuin_client::settings::Settings; + +use crate::{SHA, VERSION}; + +pub fn run(settings: &Settings) { + let config = crate::atuin_common::utils::config_dir(); + let mut config_file = config.clone(); + config_file.push("config.toml"); + let mut sever_config = config; + sever_config.push("server.toml"); + + let config_paths = format!( + "Config files:\nclient config: {:?}\nserver config: {:?}\nclient db path: {:?}\nkey path: {:?}\nmeta db path: {:?}", + config_file.to_string_lossy(), + sever_config.to_string_lossy(), + settings.db_path, + settings.key_path, + settings.meta.db_path + ); + + let env_vars = format!( + "Env Vars:\nATUIN_CONFIG_DIR = {:?}", + std::env::var("ATUIN_CONFIG_DIR").unwrap_or_else(|_| "None".into()) + ); + + let general_info = format!("Version info:\nversion: {VERSION}\ncommit: {SHA}"); + + let print_out = format!("{config_paths}\n\n{env_vars}\n\n{general_info}"); + + println!("{print_out}"); +} diff --git a/crates/turtle/src/command/client/init.rs b/crates/turtle/src/command/client/init.rs new file mode 100644 index 00000000..bf9747bb --- /dev/null +++ b/crates/turtle/src/command/client/init.rs @@ -0,0 +1,127 @@ +use crate::atuin_client::settings::{Settings, Tmux}; +use clap::{Parser, ValueEnum}; + +mod bash; +mod fish; +mod powershell; +mod xonsh; +mod zsh; + +#[derive(Parser, Debug)] +pub struct Cmd { + shell: Shell, + + /// Disable the binding of CTRL-R to atuin + #[clap(long)] + disable_ctrl_r: bool, + + /// Disable the binding of the Up Arrow key to atuin + #[clap(long)] + disable_up_arrow: bool, + + /// Disable the binding of ? to Atuin AI + #[clap(long)] + disable_ai: bool, +} + +#[derive(Clone, Copy, ValueEnum, Debug)] +#[value(rename_all = "lower")] +#[expect(clippy::enum_variant_names, clippy::doc_markdown)] +pub enum Shell { + /// Zsh setup + Zsh, + /// Bash setup + Bash, + /// Fish setup + Fish, + /// Nu setup + Nu, + /// Xonsh setup + Xonsh, + /// PowerShell setup + PowerShell, +} + +impl Cmd { + fn init_nu(&self, _tmux: &Tmux) { + let full = include_str!("../../shell/atuin.nu"); + + // TODO: tmux popup for Nu + println!("{full}"); + + if std::env::var("ATUIN_NOBIND").is_err() { + const BIND_CTRL_R: &str = r"$env.config = ( + $env.config | upsert keybindings ( + $env.config.keybindings + | append { + name: atuin + modifier: control + keycode: char_r + mode: [emacs, vi_normal, vi_insert] + event: { send: executehostcommand cmd: (_atuin_search_cmd) } + } + ) +)"; + const BIND_UP_ARROW: &str = r" +$env.config = ( + $env.config | upsert keybindings ( + $env.config.keybindings + | append { + name: atuin + modifier: none + keycode: up + mode: [emacs, vi_normal, vi_insert] + event: { + until: [ + {send: menuup} + {send: executehostcommand cmd: (_atuin_search_cmd '--shell-up-key-binding') } + ] + } + } + ) +) +"; + if !self.disable_ctrl_r { + println!("{BIND_CTRL_R}"); + } + if !self.disable_up_arrow { + println!("{BIND_UP_ARROW}"); + } + } + } + + fn static_init(&self, settings: &Settings) { + let tmux = &settings.tmux; + + match self.shell { + Shell::Zsh => { + zsh::init_static(self.disable_up_arrow, self.disable_ctrl_r, tmux); + } + Shell::Bash => { + bash::init_static(self.disable_up_arrow, self.disable_ctrl_r, tmux); + } + Shell::Fish => { + fish::init_static(self.disable_up_arrow, self.disable_ctrl_r, tmux); + } + Shell::Nu => { + self.init_nu(tmux); + } + Shell::Xonsh => { + xonsh::init_static(self.disable_up_arrow, self.disable_ctrl_r, tmux); + } + Shell::PowerShell => { + powershell::init_static(self.disable_up_arrow, self.disable_ctrl_r, tmux); + } + } + } + + pub fn run(self, settings: &Settings) { + if !settings.paths_ok() { + eprintln!( + "Atuin settings paths are broken. Disabling atuin shell hooks. Run `atuin doctor` to diagnose." + ); + } + + self.static_init(settings); + } +} diff --git a/crates/turtle/src/command/client/init/bash.rs b/crates/turtle/src/command/client/init/bash.rs new file mode 100644 index 00000000..fd17e37e --- /dev/null +++ b/crates/turtle/src/command/client/init/bash.rs @@ -0,0 +1,25 @@ +use crate::atuin_client::settings::Tmux; + +fn print_tmux_config(tmux: &Tmux) { + if tmux.enabled { + println!("export ATUIN_TMUX_POPUP_WIDTH='{}'", tmux.width); + println!("export ATUIN_TMUX_POPUP_HEIGHT='{}'", tmux.height); + } else { + println!("export ATUIN_TMUX_POPUP=false"); + } +} + +pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, tmux: &Tmux) { + let base = include_str!("../../../shell/atuin.bash"); + + let (bind_ctrl_r, bind_up_arrow) = if std::env::var("ATUIN_NOBIND").is_ok() { + (false, false) + } else { + (!disable_ctrl_r, !disable_up_arrow) + }; + + print_tmux_config(tmux); + println!("__atuin_bind_ctrl_r={bind_ctrl_r}"); + println!("__atuin_bind_up_arrow={bind_up_arrow}"); + println!("{base}"); +} diff --git a/crates/turtle/src/command/client/init/fish.rs b/crates/turtle/src/command/client/init/fish.rs new file mode 100644 index 00000000..8a046bfa --- /dev/null +++ b/crates/turtle/src/command/client/init/fish.rs @@ -0,0 +1,86 @@ +use crate::atuin_client::settings::Tmux; + +fn print_tmux_config(tmux: &Tmux) { + if tmux.enabled { + println!("set -gx ATUIN_TMUX_POPUP_WIDTH '{}'", tmux.width); + println!("set -gx ATUIN_TMUX_POPUP_HEIGHT '{}'", tmux.height); + } else { + println!("set -gx ATUIN_TMUX_POPUP false"); + } +} + +fn print_bindings( + indent: &str, + disable_up_arrow: bool, + disable_ctrl_r: bool, + bind_ctrl_r: &str, + bind_up_arrow: &str, + bind_ctrl_r_ins: &str, + bind_up_arrow_ins: &str, +) { + if !disable_ctrl_r { + println!("{indent}{bind_ctrl_r}"); + } + if !disable_up_arrow { + println!("{indent}{bind_up_arrow}"); + } + + println!("{indent}if bind -M insert >/dev/null 2>&1"); + if !disable_ctrl_r { + println!("{indent}{indent}{bind_ctrl_r_ins}"); + } + if !disable_up_arrow { + println!("{indent}{indent}{bind_up_arrow_ins}"); + } + println!("{indent}end"); +} + +pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, tmux: &Tmux) { + let indent = " ".repeat(4); + + let base = include_str!("../../../shell/atuin.fish"); + + print_tmux_config(tmux); + println!("{base}"); + + if std::env::var("ATUIN_NOBIND").is_err() { + println!("if string match -q '4.*' $version"); + + // In fish 4.0 and above the option bind -k doesn't exist anymore, + // instead we can use key names and modifiers directly. + print_bindings( + &indent, + disable_up_arrow, + disable_ctrl_r, + "bind ctrl-r _atuin_search", + "bind up _atuin_bind_up", + "bind -M insert ctrl-r _atuin_search", + "bind -M insert up _atuin_bind_up", + ); + + println!("else"); + + // We keep these for compatibility with fish 3.x + print_bindings( + &indent, + disable_up_arrow, + disable_ctrl_r, + r"bind \cr _atuin_search", + &[ + r"bind -k up _atuin_bind_up", + r"bind \eOA _atuin_bind_up", + r"bind \e\[A _atuin_bind_up", + ] + .join("; "), + r"bind -M insert \cr _atuin_search", + &[ + r"bind -M insert -k up _atuin_bind_up", + r"bind -M insert \eOA _atuin_bind_up", + r"bind -M insert \e\[A _atuin_bind_up", + ] + .join("; "), + ); + + println!("end"); + } +} diff --git a/crates/turtle/src/command/client/init/powershell.rs b/crates/turtle/src/command/client/init/powershell.rs new file mode 100644 index 00000000..10c0c461 --- /dev/null +++ b/crates/turtle/src/command/client/init/powershell.rs @@ -0,0 +1,23 @@ +use crate::atuin_client::settings::Tmux; + +pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, _tmux: &Tmux) { + let base = include_str!("../../../shell/atuin.ps1"); + + let (bind_ctrl_r, bind_up_arrow) = if std::env::var("ATUIN_NOBIND").is_ok() { + (false, false) + } else { + (!disable_ctrl_r, !disable_up_arrow) + }; + + // TODO: tmux popup for Powershell + println!("{base}"); + println!( + "Enable-AtuinSearchKeys -CtrlR {} -UpArrow {}", + ps_bool(bind_ctrl_r), + ps_bool(bind_up_arrow) + ); +} + +fn ps_bool(value: bool) -> &'static str { + if value { "$true" } else { "$false" } +} diff --git a/crates/turtle/src/command/client/init/xonsh.rs b/crates/turtle/src/command/client/init/xonsh.rs new file mode 100644 index 00000000..a17d85d8 --- /dev/null +++ b/crates/turtle/src/command/client/init/xonsh.rs @@ -0,0 +1,22 @@ +use crate::atuin_client::settings::Tmux; + +pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, _tmux: &Tmux) { + let base = include_str!("../../../shell/atuin.xsh"); + + let (bind_ctrl_r, bind_up_arrow) = if std::env::var("ATUIN_NOBIND").is_ok() { + (false, false) + } else { + (!disable_ctrl_r, !disable_up_arrow) + }; + + // TODO: tmux popup for xonsh + println!( + "_ATUIN_BIND_CTRL_R={}", + if bind_ctrl_r { "True" } else { "False" } + ); + println!( + "_ATUIN_BIND_UP_ARROW={}", + if bind_up_arrow { "True" } else { "False" } + ); + println!("{base}"); +} diff --git a/crates/turtle/src/command/client/init/zsh.rs b/crates/turtle/src/command/client/init/zsh.rs new file mode 100644 index 00000000..38c3086b --- /dev/null +++ b/crates/turtle/src/command/client/init/zsh.rs @@ -0,0 +1,38 @@ +use crate::atuin_client::settings::Tmux; + +fn print_tmux_config(tmux: &Tmux) { + if tmux.enabled { + println!("export ATUIN_TMUX_POPUP_WIDTH='{}'", tmux.width); + println!("export ATUIN_TMUX_POPUP_HEIGHT='{}'", tmux.height); + } else { + println!("export ATUIN_TMUX_POPUP=false"); + } +} + +pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool, tmux: &Tmux) { + let base = include_str!("../../../shell/atuin.zsh"); + + print_tmux_config(tmux); + println!("{base}"); + + if std::env::var("ATUIN_NOBIND").is_err() { + const BIND_CTRL_R: &str = r"bindkey -M emacs '^r' atuin-search +bindkey -M viins '^r' atuin-search-viins +bindkey -M vicmd '/' atuin-search"; + + const BIND_UP_ARROW: &str = r"bindkey -M emacs '^[[A' atuin-up-search +bindkey -M vicmd '^[[A' atuin-up-search-vicmd +bindkey -M viins '^[[A' atuin-up-search-viins +bindkey -M emacs '^[OA' atuin-up-search +bindkey -M vicmd '^[OA' atuin-up-search-vicmd +bindkey -M viins '^[OA' atuin-up-search-viins +bindkey -M vicmd 'k' atuin-up-search-vicmd"; + + if !disable_ctrl_r { + println!("{BIND_CTRL_R}"); + } + if !disable_up_arrow { + println!("{BIND_UP_ARROW}"); + } + } +} diff --git a/crates/turtle/src/command/client/search.rs b/crates/turtle/src/command/client/search.rs new file mode 100644 index 00000000..4a2114d5 --- /dev/null +++ b/crates/turtle/src/command/client/search.rs @@ -0,0 +1,375 @@ +use std::fs::File; +use std::io::{IsTerminal as _, Write, stderr, stdout}; + +use crate::atuin_common::utils::{self, Escapable as _}; +use clap::Parser; +use eyre::Result; + +use crate::atuin_client::{ + database::Database, + database::{OptFilters, current_context}, + encryption, + history::{History, store::HistoryStore}, + record::sqlite_store::SqliteStore, + settings::{FilterMode, KeymapMode, SearchMode, Settings, Timezone}, + theme::Theme, +}; + +use super::history::ListMode; + +mod cursor; +mod duration; +mod engines; +mod history_list; +mod inspector; +mod interactive; +pub mod keybindings; + +pub use duration::format_duration_into; + +#[expect(clippy::struct_excessive_bools, clippy::struct_field_names)] +#[derive(Parser, Debug)] +pub struct Cmd { + /// Filter search result by directory + #[arg(long, short)] + cwd: Option, + + /// Exclude directory from results + #[arg(long = "exclude-cwd")] + exclude_cwd: Option, + + /// Filter search result by exit code + #[arg(long, short)] + exit: Option, + + /// Exclude results with this exit code + #[arg(long = "exclude-exit")] + exclude_exit: Option, + + /// Only include results added before this date + #[arg(long, short)] + before: Option, + + /// Only include results after this date + #[arg(long)] + after: Option, + + /// How many entries to return at most + #[arg(long)] + limit: Option, + + /// Offset from the start of the results + #[arg(long)] + offset: Option, + + /// Open interactive search UI + #[arg(long, short)] + interactive: bool, + + /// Allow overriding filter mode over config + #[arg(long = "filter-mode")] + filter_mode: Option, + + /// Allow overriding search mode over config + #[arg(long = "search-mode")] + search_mode: Option, + + /// Marker argument used to inform atuin that it was invoked from a shell up-key binding (hidden from help to avoid confusion) + #[arg(long = "shell-up-key-binding", hide = true)] + shell_up_key_binding: bool, + + /// Notify the keymap at the shell's side + #[arg(long = "keymap-mode", default_value = "auto")] + keymap_mode: KeymapMode, + + /// Use human-readable formatting for time + #[arg(long)] + human: bool, + + #[arg(allow_hyphen_values = true)] + query: Option>, + + /// Show only the text of the command + #[arg(long)] + cmd_only: bool, + + /// Terminate the output with a null, for better multiline handling + #[arg(long)] + print0: bool, + + /// Delete anything matching this query. Will not print out the match + #[arg(long)] + delete: bool, + + /// Delete EVERYTHING! + #[arg(long)] + delete_it_all: bool, + + /// Reverse the order of results, oldest first + #[arg(long, short)] + reverse: bool, + + /// Display the command time in another timezone other than the configured default. + /// + /// This option takes one of the following kinds of values: + /// - the special value "local" (or "l") which refers to the system time zone + /// - an offset from UTC (e.g. "+9", "-2:30") + #[arg(long, visible_alias = "tz")] + #[arg(allow_hyphen_values = true)] + // Clippy warns about `Option>`, but we suppress it because we need + // this distinction for proper argument handling. + #[expect(clippy::option_option)] + timezone: Option>, + + /// Available variables: {command}, {directory}, {duration}, {user}, {host}, {time}, {exit} and + /// {relativetime}. + /// Example: --format "{time} - [{duration}] - {directory}$\t{command}" + #[arg(long, short)] + format: Option, + + /// Set the maximum number of lines Atuin's interface should take up. + #[arg(long = "inline-height")] + inline_height: Option, + + /// Filter by author. Supports $all-user (non-agents), $all-agent, or literal names. + /// Can be specified multiple times. + #[arg(long)] + author: Option>, + + /// Include duplicate commands in the output (non-interactive only) + #[arg(long)] + include_duplicates: bool, + + /// File name to write the result to (hidden from help as this is meant to be used from a script) + #[arg(long = "result-file", hide = true)] + result_file: Option, +} + +impl Cmd { + /// Returns true if this search command will run in interactive (TUI) mode + pub fn is_interactive(&self) -> bool { + self.interactive + } + + // clippy: please write this instead + // clippy: now it has too many lines + // me: I'll do it later OKAY + #[expect(clippy::too_many_lines)] + pub async fn run( + self, + db: impl Database, + settings: &mut Settings, + store: SqliteStore, + theme: &Theme, + ) -> Result<()> { + let query = self.query.unwrap_or_else(|| { + std::env::var("ATUIN_QUERY").map_or_else( + |_| vec![], + |query| { + query + .split(' ') + .map(std::string::ToString::to_string) + .collect() + }, + ) + }); + + if (self.delete_it_all || self.delete) && self.limit.is_some() { + // Because of how deletion is implemented, it will always delete all matches + // and disregard the limit option. It is also not clear what deletion with a + // limit would even mean. Deleting the LIMIT most recent entries that match + // the search query would make sense, but that wouldn't match what's displayed + // when running the equivalent search, but deleting those entries that are + // displayed with the search would leave any duplicates of those lines which may + // or may not have been intended to be deleted. + eprintln!("\"--limit\" is not compatible with deletion."); + return Ok(()); + } + + if self.delete && query.is_empty() { + eprintln!( + "Please specify a query to match the items you wish to delete. If you wish to delete all history, pass --delete-it-all" + ); + return Ok(()); + } + + if self.delete_it_all && !query.is_empty() { + eprintln!( + "--delete-it-all will delete ALL of your history! It does not require a query." + ); + return Ok(()); + } + + if let Some(search_mode) = self.search_mode { + settings.search_mode = search_mode; + } + if let Some(filter_mode) = self.filter_mode { + settings.filter_mode = Some(filter_mode); + } + if let Some(inline_height) = self.inline_height { + settings.inline_height = inline_height; + } + + settings.shell_up_key_binding = self.shell_up_key_binding; + + // `keymap_mode` specified in config.toml overrides the `--keymap-mode` + // option specified in the keybindings. + settings.keymap_mode = match settings.keymap_mode { + KeymapMode::Auto => self.keymap_mode, + value => value, + }; + settings.keymap_mode_shell = self.keymap_mode; + + let encryption_key: [u8; 32] = encryption::load_key(settings)?.into(); + + let host_id = Settings::host_id().await?; + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + + if self.interactive { + let item = interactive::history(&query, settings, db, &history_store, theme).await?; + + if let Some(result_file) = self.result_file { + let mut file = File::create(result_file)?; + write!(file, "{item}")?; + } else if !stdout().is_terminal() { + // stdout is not a terminal - likely command substitution like VAR=$(atuin search -i) + // Write to stdout so it gets captured. This requires some care on Windows, as the current + // console code page or `[Console]::OutputEncoding` on PowerShell may be different from UTF-8. + println!("{item}"); + } else if stderr().is_terminal() { + eprintln!("{}", item.escape_control()); + } else { + eprintln!("{item}"); + } + } else { + let opt_filter = OptFilters { + exit: self.exit, + exclude_exit: self.exclude_exit, + cwd: self.cwd, + exclude_cwd: self.exclude_cwd, + before: self.before, + after: self.after, + limit: self.limit, + offset: self.offset, + reverse: self.reverse, + include_duplicates: self.include_duplicates, + authors: self.author.clone().unwrap_or_default(), + }; + + let mut entries = + run_non_interactive(settings, opt_filter.clone(), &query, &db).await?; + + if entries.is_empty() { + std::process::exit(1) + } + + // if we aren't deleting, print it all + if self.delete || self.delete_it_all { + // delete it + // it only took me _years_ to add this + // sorry + while !entries.is_empty() { + for entry in &entries { + eprintln!("deleting {}", entry.id); + } + + let ids = history_store.delete_entries(entries).await?; + history_store.incremental_build(&db, &ids).await?; + + entries = + run_non_interactive(settings, opt_filter.clone(), &query, &db).await?; + } + } else { + let format = match self.format { + None => Some(settings.history_format.as_str()), + _ => self.format.as_deref(), + }; + let tz = match self.timezone { + Some(Some(tz)) => tz, // User provided a value + Some(None) | None => settings.timezone, // No value was provided + }; + + super::history::print_list( + &entries, + ListMode::from_flags(self.human, self.cmd_only), + format, + self.print0, + true, + tz, + ); + } + } + Ok(()) + } +} + +// This is supposed to more-or-less mirror the command line version, so ofc +// it is going to have a lot of args +#[expect(clippy::too_many_arguments, clippy::cast_possible_truncation)] +async fn run_non_interactive( + settings: &Settings, + filter_options: OptFilters, + query: &[String], + db: &impl Database, +) -> Result> { + let dir = if filter_options.cwd.as_deref() == Some(".") { + Some(utils::get_current_dir()) + } else { + filter_options.cwd + }; + + let context = current_context().await?; + + let opt_filter = OptFilters { + cwd: dir.clone(), + ..filter_options + }; + + let filter_mode = settings.default_filter_mode(context.git_root.is_some()); + + let results = db + .search( + settings.search_mode, + filter_mode, + &context, + query.join(" ").as_str(), + opt_filter, + ) + .await?; + + Ok(results) +} + +#[cfg(test)] +mod tests { + use super::Cmd; + use clap::Parser; + + #[test] + fn search_for_triple_dash() { + // Issue #3028: searching for `---` should not be treated as a CLI flag + let cmd = Cmd::try_parse_from(["search", "---"]); + assert!(cmd.is_ok(), "Failed to parse '---' as a query: {cmd:?}"); + let cmd = cmd.unwrap(); + assert_eq!(cmd.query, Some(vec!["---".to_string()])); + } + + #[test] + fn search_for_double_dash_value() { + // Searching for strings starting with -- should also work + let cmd = Cmd::try_parse_from(["search", "--", "--foo"]); + assert!(cmd.is_ok()); + let cmd = cmd.unwrap(); + assert_eq!(cmd.query, Some(vec!["--foo".to_string()])); + } + + #[test] + fn search_author_cli_flag() { + let cmd = + Cmd::try_parse_from(["search", "--author", "codex", "--author", "ellie"]).unwrap(); + assert_eq!( + cmd.author, + Some(vec!["codex".to_string(), "ellie".to_string()]) + ); + } +} diff --git a/crates/turtle/src/command/client/search/cursor.rs b/crates/turtle/src/command/client/search/cursor.rs new file mode 100644 index 00000000..84f94082 --- /dev/null +++ b/crates/turtle/src/command/client/search/cursor.rs @@ -0,0 +1,405 @@ +use crate::atuin_client::settings::WordJumpMode; + +pub struct Cursor { + source: String, + index: usize, +} + +impl From for Cursor { + fn from(source: String) -> Self { + Self { source, index: 0 } + } +} + +pub struct WordJumper<'a> { + word_chars: &'a str, + word_jump_mode: WordJumpMode, +} + +impl WordJumper<'_> { + fn is_word_boundary(&self, c: char, next_c: char) -> bool { + (c.is_whitespace() && !next_c.is_whitespace()) + || (!c.is_whitespace() && next_c.is_whitespace()) + || (self.word_chars.contains(c) && !self.word_chars.contains(next_c)) + || (!self.word_chars.contains(c) && self.word_chars.contains(next_c)) + } + + fn emacs_get_next_word_pos(&self, source: &str, index: usize) -> usize { + let index = (index + 1..source.len().saturating_sub(1)) + .find(|&i| self.word_chars.contains(source.chars().nth(i).unwrap())) + .unwrap_or(source.len()); + (index + 1..source.len().saturating_sub(1)) + .find(|&i| !self.word_chars.contains(source.chars().nth(i).unwrap())) + .unwrap_or(source.len()) + } + + fn emacs_get_prev_word_pos(&self, source: &str, index: usize) -> usize { + let index = (1..index) + .rev() + .find(|&i| self.word_chars.contains(source.chars().nth(i).unwrap())) + .unwrap_or(0); + (1..index) + .rev() + .find(|&i| !self.word_chars.contains(source.chars().nth(i).unwrap())) + .map_or(0, |i| i + 1) + } + + fn subl_get_next_word_pos(&self, source: &str, index: usize) -> usize { + let index = (index..source.len().saturating_sub(1)).find(|&i| { + self.is_word_boundary( + source.chars().nth(i).unwrap(), + source.chars().nth(i + 1).unwrap(), + ) + }); + if index.is_none() { + return source.len(); + } + (index.unwrap() + 1..source.len()) + .find(|&i| !source.chars().nth(i).unwrap().is_whitespace()) + .unwrap_or(source.len()) + } + + fn subl_get_prev_word_pos(&self, source: &str, index: usize) -> usize { + let index = (1..index) + .rev() + .find(|&i| !source.chars().nth(i).unwrap().is_whitespace()); + if index.is_none() { + return 0; + } + (1..index.unwrap()) + .rev() + .find(|&i| { + self.is_word_boundary( + source.chars().nth(i - 1).unwrap(), + source.chars().nth(i).unwrap(), + ) + }) + .unwrap_or(0) + } + + fn get_next_word_pos(&self, source: &str, index: usize) -> usize { + match self.word_jump_mode { + WordJumpMode::Emacs => self.emacs_get_next_word_pos(source, index), + WordJumpMode::Subl => self.subl_get_next_word_pos(source, index), + } + } + + fn get_prev_word_pos(&self, source: &str, index: usize) -> usize { + match self.word_jump_mode { + WordJumpMode::Emacs => self.emacs_get_prev_word_pos(source, index), + WordJumpMode::Subl => self.subl_get_prev_word_pos(source, index), + } + } +} + +impl Cursor { + pub fn as_str(&self) -> &str { + self.source.as_str() + } + + pub fn into_inner(self) -> String { + self.source + } + + /// Returns the string before the cursor + pub fn substring(&self) -> &str { + &self.source[..self.index] + } + + /// Returns the currently selected [`char`] + pub fn char(&self) -> Option { + self.source[self.index..].chars().next() + } + + pub fn right(&mut self) { + if self.index < self.source.len() { + loop { + self.index += 1; + if self.source.is_char_boundary(self.index) { + break; + } + } + } + } + + pub fn left(&mut self) -> bool { + if self.index > 0 { + loop { + self.index -= 1; + if self.source.is_char_boundary(self.index) { + break true; + } + } + } else { + false + } + } + + pub fn next_word(&mut self, word_chars: &str, word_jump_mode: WordJumpMode) { + let word_jumper = WordJumper { + word_chars, + word_jump_mode, + }; + self.index = word_jumper.get_next_word_pos(&self.source, self.index); + } + + pub fn prev_word(&mut self, word_chars: &str, word_jump_mode: WordJumpMode) { + let word_jumper = WordJumper { + word_chars, + word_jump_mode, + }; + self.index = word_jumper.get_prev_word_pos(&self.source, self.index); + } + + /// Move cursor to the end of the current/next word (vim `e` motion). + /// + /// If cursor is in the middle of a word, moves to the end of that word. + /// If cursor is at the end of a word (or on whitespace), moves to the + /// end of the next word. + pub fn word_end(&mut self, word_chars: &str) { + let len = self.source.len(); + if self.index >= len { + return; + } + + let chars: Vec = self.source.chars().collect(); + let mut char_idx = self.source[..self.index].chars().count(); + + if char_idx >= chars.len() { + return; + } + + let current = chars[char_idx]; + + // Check if we're at a word boundary (end of current word or on whitespace) + let at_word_boundary = current.is_whitespace() || char_idx + 1 >= chars.len() || { + let next = chars[char_idx + 1]; + next.is_whitespace() || (word_chars.contains(current) != word_chars.contains(next)) + }; + + // If at word boundary, advance past it and skip whitespace to find next word + if at_word_boundary { + char_idx += 1; + while char_idx < chars.len() && chars[char_idx].is_whitespace() { + char_idx += 1; + } + } + + // If we've gone past end, go to end of string + if char_idx >= chars.len() { + self.index = len; + return; + } + + // Find end of word: advance until next char is whitespace or different word type + let in_word_chars = word_chars.contains(chars[char_idx]); + while char_idx < chars.len() { + let next_idx = char_idx + 1; + if next_idx >= chars.len() { + // At last char, move past it + char_idx = next_idx; + break; + } + let next_c = chars[next_idx]; + if next_c.is_whitespace() || (word_chars.contains(next_c) != in_word_chars) { + // Next char is start of new word/whitespace, so current char is end + char_idx = next_idx; + break; + } + char_idx += 1; + } + + // Convert char index back to byte index + self.index = chars.iter().take(char_idx).map(|c| c.len_utf8()).sum(); + } + + pub fn insert(&mut self, c: char) { + self.source.insert(self.index, c); + self.index += c.len_utf8(); + } + + pub fn remove(&mut self) -> Option { + if self.index < self.source.len() { + Some(self.source.remove(self.index)) + } else { + None + } + } + + pub fn remove_next_word(&mut self, word_chars: &str, word_jump_mode: WordJumpMode) { + let word_jumper = WordJumper { + word_chars, + word_jump_mode, + }; + let next_index = word_jumper.get_next_word_pos(&self.source, self.index); + self.source.replace_range(self.index..next_index, ""); + } + + pub fn remove_prev_word(&mut self, word_chars: &str, word_jump_mode: WordJumpMode) { + let word_jumper = WordJumper { + word_chars, + word_jump_mode, + }; + let next_index = word_jumper.get_prev_word_pos(&self.source, self.index); + self.source.replace_range(next_index..self.index, ""); + self.index = next_index; + } + + pub fn back(&mut self) -> Option { + if self.left() { self.remove() } else { None } + } + + pub fn clear(&mut self) { + self.source.clear(); + self.index = 0; + } + + pub fn clear_to_start(&mut self) { + self.source.replace_range(..self.index, ""); + self.index = 0; + } + + pub fn clear_to_end(&mut self) { + self.source.replace_range(self.index.., ""); + self.index = self.source.len(); + } + + pub fn end(&mut self) { + self.index = self.source.len(); + } + + pub fn start(&mut self) { + self.index = 0; + } + + pub fn position(&self) -> usize { + self.index + } +} + +#[cfg(test)] +mod cursor_tests { + use super::Cursor; + use super::*; + + static EMACS_WORD_JUMPER: WordJumper = WordJumper { + word_chars: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789", + word_jump_mode: WordJumpMode::Emacs, + }; + + static SUBL_WORD_JUMPER: WordJumper = WordJumper { + word_chars: "./\\()\"'-:,.;<>~!@#$%^&*|+=[]{}`~?", + word_jump_mode: WordJumpMode::Subl, + }; + + #[test] + fn right() { + // ö is 2 bytes + let mut c = Cursor::from(String::from("öaöböcödöeöfö")); + let indices = [0, 2, 3, 5, 6, 8, 9, 11, 12, 14, 15, 17, 18, 20, 20, 20, 20]; + for i in indices { + assert_eq!(c.index, i); + c.right(); + } + } + + #[test] + fn left() { + // ö is 2 bytes + let mut c = Cursor::from(String::from("öaöböcödöeöfö")); + c.end(); + let indices = [20, 18, 17, 15, 14, 12, 11, 9, 8, 6, 5, 3, 2, 0, 0, 0, 0]; + for i in indices { + assert_eq!(c.index, i); + c.left(); + } + } + + #[test] + fn test_emacs_get_next_word_pos() { + let s = String::from(" aaa ((()))bbb ((())) "); + let indices = [(0, 6), (3, 6), (7, 18), (19, 30)]; + for (i_src, i_dest) in indices { + assert_eq!(EMACS_WORD_JUMPER.get_next_word_pos(&s, i_src), i_dest); + } + assert_eq!(EMACS_WORD_JUMPER.get_next_word_pos("", 0), 0); + } + + #[test] + fn test_emacs_get_prev_word_pos() { + let s = String::from(" aaa ((()))bbb ((())) "); + let indices = [(30, 15), (29, 15), (15, 3), (3, 0)]; + for (i_src, i_dest) in indices { + assert_eq!(EMACS_WORD_JUMPER.get_prev_word_pos(&s, i_src), i_dest); + } + assert_eq!(EMACS_WORD_JUMPER.get_prev_word_pos("", 0), 0); + } + + #[test] + fn test_subl_get_next_word_pos() { + let s = String::from(" aaa ((()))bbb ((())) "); + let indices = [(0, 3), (1, 3), (3, 9), (9, 15), (15, 21), (21, 30)]; + for (i_src, i_dest) in indices { + assert_eq!(SUBL_WORD_JUMPER.get_next_word_pos(&s, i_src), i_dest); + } + assert_eq!(SUBL_WORD_JUMPER.get_next_word_pos("", 0), 0); + } + + #[test] + fn test_subl_get_prev_word_pos() { + let s = String::from(" aaa ((()))bbb ((())) "); + let indices = [(30, 21), (21, 15), (15, 9), (9, 3), (3, 0)]; + for (i_src, i_dest) in indices { + assert_eq!(SUBL_WORD_JUMPER.get_prev_word_pos(&s, i_src), i_dest); + } + assert_eq!(SUBL_WORD_JUMPER.get_prev_word_pos("", 0), 0); + } + + #[test] + fn pop() { + let mut s = String::from("öaöböcödöeöfö"); + let mut c = Cursor::from(s.clone()); + c.end(); + while !s.is_empty() { + let c1 = s.pop(); + let c2 = c.back(); + assert_eq!(c1, c2); + assert_eq!(s.as_str(), c.substring()); + } + let c1 = s.pop(); + let c2 = c.back(); + assert_eq!(c1, c2); + } + + #[test] + fn back() { + let mut c = Cursor::from(String::from("öaöböcödöeöfö")); + // move to ^ + for _ in 0..4 { + c.right(); + } + assert_eq!(c.substring(), "öaöb"); + assert_eq!(c.back(), Some('b')); + assert_eq!(c.back(), Some('ö')); + assert_eq!(c.back(), Some('a')); + assert_eq!(c.back(), Some('ö')); + assert_eq!(c.back(), None); + assert_eq!(c.as_str(), "öcödöeöfö"); + } + + #[test] + fn insert() { + let mut c = Cursor::from(String::from("öaöböcödöeöfö")); + // move to ^ + for _ in 0..4 { + c.right(); + } + assert_eq!(c.substring(), "öaöb"); + c.insert('ö'); + c.insert('g'); + c.insert('ö'); + c.insert('h'); + assert_eq!(c.substring(), "öaöbögöh"); + assert_eq!(c.as_str(), "öaöbögöhöcödöeöfö"); + } +} diff --git a/crates/turtle/src/command/client/search/duration.rs b/crates/turtle/src/command/client/search/duration.rs new file mode 100644 index 00000000..54856c87 --- /dev/null +++ b/crates/turtle/src/command/client/search/duration.rs @@ -0,0 +1,65 @@ +use core::fmt; +use std::{ops::ControlFlow, time::Duration}; + +#[expect(clippy::module_name_repetitions)] +pub fn format_duration_into(dur: Duration, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn item(unit: &'static str, value: u64) -> ControlFlow<(&'static str, u64)> { + if value > 0 { + ControlFlow::Break((unit, value)) + } else { + ControlFlow::Continue(()) + } + } + + // impl taken and modified from + // https://github.com/tailhook/humantime/blob/master/src/duration.rs#L295-L331 + // Copyright (c) 2016 The humantime Developers + fn fmt(f: Duration) -> ControlFlow<(&'static str, u64), ()> { + let secs = f.as_secs(); + let nanos = f.subsec_nanos(); + + let years = secs / 31_557_600; // 365.25d + let year_days = secs % 31_557_600; + let months = year_days / 2_630_016; // 30.44d + let month_days = year_days % 2_630_016; + let days = month_days / 86400; + let day_secs = month_days % 86400; + let hours = day_secs / 3600; + let minutes = day_secs % 3600 / 60; + let seconds = day_secs % 60; + + let millis = nanos / 1_000_000; + let micros = nanos / 1_000; + + // a difference from our impl than the original is that + // we only care about the most-significant segment of the duration. + // If the item call returns `Break`, then the `?` will early-return. + // This allows for a very consise impl + item("y", years)?; + item("mo", months)?; + item("d", days)?; + item("h", hours)?; + item("m", minutes)?; + item("s", seconds)?; + item("ms", u64::from(millis))?; + item("us", u64::from(micros))?; + item("ns", u64::from(nanos))?; + ControlFlow::Continue(()) + } + + match fmt(dur) { + ControlFlow::Break((unit, value)) => write!(f, "{value}{unit}"), + ControlFlow::Continue(()) => write!(f, "0s"), + } +} + +#[expect(clippy::module_name_repetitions)] +pub fn format_duration(f: Duration) -> String { + struct F(Duration); + impl fmt::Display for F { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + format_duration_into(self.0, f) + } + } + F(f).to_string() +} diff --git a/crates/turtle/src/command/client/search/engines.rs b/crates/turtle/src/command/client/search/engines.rs new file mode 100644 index 00000000..0f92b4c7 --- /dev/null +++ b/crates/turtle/src/command/client/search/engines.rs @@ -0,0 +1,95 @@ +use async_trait::async_trait; +use crate::atuin_client::{ + database::{Context, Database, OptFilters}, + history::{AUTHOR_FILTER_ALL_USER, History, HistoryId}, + settings::{FilterMode, SearchMode, Settings}, +}; +use eyre::Result; + +use super::cursor::Cursor; + +#[cfg(feature = "daemon")] +pub mod daemon; +pub mod db; +pub mod skim; + +#[expect(unused)] // settings is only used if daemon feature is enabled +pub fn engine(search_mode: SearchMode, settings: &Settings) -> Box { + match search_mode { + SearchMode::Skim => Box::new(skim::Search::new()) as Box<_>, + #[cfg(feature = "daemon")] + SearchMode::DaemonFuzzy => Box::new(daemon::Search::new(settings)) as Box<_>, + #[cfg(not(feature = "daemon"))] + SearchMode::DaemonFuzzy => { + // Fall back to fuzzy mode if daemon feature is not enabled + Box::new(db::Search(SearchMode::Fuzzy)) as Box<_> + } + mode => Box::new(db::Search(mode)) as Box<_>, + } +} + +pub struct SearchState { + pub input: Cursor, + pub filter_mode: FilterMode, + pub context: Context, + pub custom_context: Option, +} + +impl SearchState { + pub(crate) fn rotate_filter_mode(&mut self, settings: &Settings, offset: isize) { + let mut i = settings + .search + .filters + .iter() + .position(|&m| m == self.filter_mode) + .unwrap_or_default(); + for _ in 0..settings.search.filters.len() { + i = (i.wrapping_add_signed(offset)) % settings.search.filters.len(); + let mode = settings.search.filters[i]; + if self.filter_mode_available(mode, settings) { + self.filter_mode = mode; + break; + } + } + } + + fn filter_mode_available(&self, mode: FilterMode, settings: &Settings) -> bool { + match mode { + FilterMode::Global | FilterMode::SessionPreload => self.custom_context.is_none(), + FilterMode::Workspace => settings.workspaces && self.context.git_root.is_some(), + _ => true, + } + } +} + +#[async_trait] +pub trait SearchEngine: Send + Sync + 'static { + async fn full_query( + &mut self, + state: &SearchState, + db: &mut dyn Database, + ) -> Result>; + + async fn query(&mut self, state: &SearchState, db: &mut dyn Database) -> Result> { + if state.input.as_str().is_empty() { + Ok(db + .search( + SearchMode::FullText, + state.filter_mode, + &state.context, + "", + OptFilters { + limit: Some(200), + authors: vec![AUTHOR_FILTER_ALL_USER.to_string()], + ..Default::default() + }, + ) + .await? + .into_iter() + .collect::>()) + } else { + self.full_query(state, db).await + } + } + fn get_highlight_indices(&self, command: &str, search_input: &str) -> Vec; +} diff --git a/crates/turtle/src/command/client/search/engines/daemon.rs b/crates/turtle/src/command/client/search/engines/daemon.rs new file mode 100644 index 00000000..b1299c02 --- /dev/null +++ b/crates/turtle/src/command/client/search/engines/daemon.rs @@ -0,0 +1,242 @@ +use crate::atuin_client::{ + database::{Database, OptFilters}, + history::{AUTHOR_FILTER_ALL_USER, History}, + settings::{SearchMode, Settings}, +}; +use crate::atuin_daemon::client::{DaemonClientErrorKind, SearchClient, classify_error}; +use async_trait::async_trait; +use atuin_nucleo_matcher::{ + Config, Matcher, Utf32Str, + pattern::{CaseMatching, Normalization, Pattern}, +}; +use eyre::Result; +use tracing::{Level, debug, instrument, span}; +use uuid::Uuid; + +use super::{SearchEngine, SearchState}; +use crate::command::client::daemon; + +pub struct Search { + client: Option, + query_id: u64, + settings: Settings, + #[cfg(unix)] + socket_path: String, +} + +impl Search { + pub fn new(settings: &Settings) -> Self { + Search { + client: None, + query_id: 0, + settings: settings.clone(), + #[cfg(unix)] + socket_path: settings.daemon.socket_path.clone(), + } + } + + #[instrument(skip_all, level = Level::TRACE, name = "get_daemon_client")] + async fn get_client(&mut self) -> Result<&mut SearchClient> { + if self.client.is_none() { + self.connect().await?; + } + Ok(self.client.as_mut().unwrap()) + } + + async fn connect(&mut self) -> Result<()> { + #[cfg(unix)] + let client = SearchClient::new(self.socket_path.clone()).await?; + + self.client = Some(client); + Ok(()) + } + + fn should_retry(err: &eyre::Report) -> bool { + matches!( + classify_error(err), + DaemonClientErrorKind::Connect + | DaemonClientErrorKind::Unavailable + | DaemonClientErrorKind::Unimplemented + ) + } + + fn next_query_id(&mut self) -> u64 { + self.query_id += 1; + self.query_id + } + + /// Check if query contains regex pattern (r/.../) + /// Nucleo doesn't support regex, so we fall back to database search + fn contains_regex_pattern(query: &str) -> bool { + query.starts_with("r/") || query.contains(" r/") + } + + #[instrument(skip_all, level = Level::TRACE, name = "daemon_db_fallback")] + async fn fallback_to_db_search( + &self, + state: &SearchState, + db: &dyn Database, + ) -> Result> { + let results = db + .search( + SearchMode::FullText, + state.filter_mode, + &state.context, + state.input.as_str(), + OptFilters { + limit: Some(200), + authors: vec![AUTHOR_FILTER_ALL_USER.to_string()], + ..Default::default() + }, + ) + .await + .map_or(Vec::new(), |r| r.into_iter().collect()); + Ok(results) + } + + #[instrument(skip_all, level = Level::TRACE, name = "hydrate_from_db", fields(count = ids.len()))] + async fn hydrate_from_db(&self, db: &dyn Database, ids: &[String]) -> Result> { + let placeholders: Vec = ids.iter().map(|id| format!("'{id}'")).collect(); + let sql_query = format!( + "SELECT * FROM history WHERE id IN ({}) ORDER BY timestamp DESC", + placeholders.join(",") + ); + Ok(db.query_history(&sql_query).await?) + } +} + +#[async_trait] +impl SearchEngine for Search { + #[instrument(skip_all, level = Level::TRACE, name = "daemon_search", fields(query = %state.input.as_str()))] + async fn full_query( + &mut self, + state: &SearchState, + db: &mut dyn Database, + ) -> Result> { + let query = state.input.as_str().to_string(); + + // Fall back to database for regex queries (Nucleo doesn't support regex) + if Self::contains_regex_pattern(&query) { + debug!(query = %query, "[daemon-client] regex detected, falling back to db"); + return self.fallback_to_db_search(state, db).await; + } + + let query_id = self.next_query_id(); + + let span = + span!(Level::TRACE, "daemon_search.req_resp", query = %query, query_id = query_id); + + // Try to connect and search; if it fails with a retriable error, + // auto-start the daemon and retry once. + let first_attempt = async { + let client = self.get_client().await?; + client + .search( + query.clone(), + query_id, + state.filter_mode, + Some(state.context.clone()), + ) + .await + } + .await; + + let mut stream = match first_attempt { + Ok(stream) => stream, + Err(err) if self.settings.daemon.autostart && Self::should_retry(&err) => { + debug!("daemon not available, attempting auto-start"); + self.client = None; + + daemon::ensure_daemon_running(&self.settings).await?; + + let client = self.get_client().await?; + client + .search( + query.clone(), + query_id, + state.filter_mode, + Some(state.context.clone()), + ) + .await? + } + Err(err) => return Err(err), + }; + + let mut ids = Vec::with_capacity(200); + span!(Level::TRACE, "daemon_search.resp") + .in_scope(async || { + while let Ok(Some(response)) = stream.message().await { + let span2 = span!( + Level::TRACE, + "daemon_search.resp.item", + query_id = response.query_id + ); + let _span2 = span2.enter(); + // Only process if the query_id matches (prevents stale responses) + if response.query_id == query_id { + let uuids = response + .ids + .iter() + .map(|id| { + let bytes: [u8; 16] = + id.as_slice().try_into().expect("id should be 16 bytes"); + Uuid::from_bytes(bytes).as_simple().to_string() + }) + .collect::>(); + ids.extend(uuids); + } + drop(_span2); + drop(span2); + } + }) + .await; + drop(span); + + if ids.is_empty() { + debug!(query = %query, results = 0, "[daemon-client] empty results"); + return Ok(Vec::new()); + } + + // // Hydrate from local database + let results = self.hydrate_from_db(db, &ids).await?; + + // // Reorder results to match the order from the daemon (which is ranked by relevance) + let ordered_results = span!(Level::TRACE, "reorder_results").in_scope(|| { + let mut ordered_results = Vec::with_capacity(results.len()); + for id in &ids { + if let Some(history) = results.iter().find(|h| h.id.0 == *id) { + ordered_results.push(history.clone()); + } + } + ordered_results + }); + + debug!( + query = %query, + results = results.len(), + "[daemon-client]" + ); + + Ok(ordered_results) + } + + #[instrument(skip_all, level = Level::TRACE, name = "daemon_highlight")] + fn get_highlight_indices(&self, command: &str, search_input: &str) -> Vec { + // Use fulltext highlighting for regex queries + if Self::contains_regex_pattern(search_input) { + return super::db::get_highlight_indices_fulltext(command, search_input); + } + + let mut matcher = Matcher::new(Config::DEFAULT); + let pattern = Pattern::parse(search_input, CaseMatching::Smart, Normalization::Smart); + + let mut indices: Vec = Vec::new(); + let mut haystack_buf = Vec::new(); + + let haystack = Utf32Str::new(command, &mut haystack_buf); + pattern.indices(haystack, &mut matcher, &mut indices); + + // Convert u32 indices to usize + indices.into_iter().map(|i| i as usize).collect() + } +} diff --git a/crates/turtle/src/command/client/search/engines/db.rs b/crates/turtle/src/command/client/search/engines/db.rs new file mode 100644 index 00000000..2765faf5 --- /dev/null +++ b/crates/turtle/src/command/client/search/engines/db.rs @@ -0,0 +1,110 @@ +use super::{SearchEngine, SearchState}; +use async_trait::async_trait; +use crate::atuin_client::{ + database::Database, + database::OptFilters, + database::{QueryToken, QueryTokenizer}, + history::{AUTHOR_FILTER_ALL_USER, History}, + settings::SearchMode, +}; +use eyre::Result; +use norm::Metric; +use norm::fzf::{FzfParser, FzfV2}; +use std::ops::Range; +use tracing::{Level, instrument}; + +pub struct Search(pub SearchMode); + +#[async_trait] +impl SearchEngine for Search { + #[instrument(skip_all, level = Level::TRACE, name = "db_search", fields(mode = ?self.0, query = %state.input.as_str()))] + async fn full_query( + &mut self, + state: &SearchState, + db: &mut dyn Database, + ) -> Result> { + let results = db + .search( + self.0, + state.filter_mode, + &state.context, + state.input.as_str(), + OptFilters { + limit: Some(200), + authors: vec![AUTHOR_FILTER_ALL_USER.to_string()], + ..Default::default() + }, + ) + .await + // ignore errors as it may be caused by incomplete regex + .map_or(Vec::new(), |r| r.into_iter().collect()); + Ok(results) + } + + #[instrument(skip_all, level = Level::TRACE, name = "db_highlight")] + fn get_highlight_indices(&self, command: &str, search_input: &str) -> Vec { + if self.0 == SearchMode::Prefix { + return vec![]; + } else if self.0 == SearchMode::FullText { + return get_highlight_indices_fulltext(command, search_input); + } + let mut fzf = FzfV2::new(); + let mut parser = FzfParser::new(); + let query = parser.parse(search_input); + let mut ranges: Vec> = Vec::new(); + let _ = fzf.distance_and_ranges(query, command, &mut ranges); + + // convert ranges to all indices + ranges.into_iter().flatten().collect() + } +} + +#[instrument(skip_all, level = Level::TRACE, name = "db_highlight_fulltext")] +pub fn get_highlight_indices_fulltext(command: &str, search_input: &str) -> Vec { + let mut ranges = vec![]; + let lower_command = command.to_ascii_lowercase(); + + for token in QueryTokenizer::new(search_input) { + let matchee = if token.has_uppercase() { + command + } else { + &lower_command + }; + + if token.is_inverse() { + continue; + } + + match token { + QueryToken::Or => {} + QueryToken::Regex(r) => { + if let Ok(re) = regex::Regex::new(r) { + for m in re.find_iter(command) { + ranges.push(m.range()); + } + } + } + QueryToken::MatchStart(term, _) => { + if matchee.starts_with(term) { + ranges.push(0..term.len()); + } + } + QueryToken::MatchEnd(term, _) => { + if matchee.ends_with(term) { + let l = matchee.len(); + ranges.push((l - term.len())..l); + } + } + QueryToken::Match(term, _) | QueryToken::MatchFull(term, _) => { + for (idx, m) in matchee.match_indices(term) { + ranges.push(idx..(idx + m.len())); + } + } + } + } + + let mut ret: Vec<_> = ranges.into_iter().flatten().collect(); + ret.sort_unstable(); + ret.dedup(); + ret +} diff --git a/crates/turtle/src/command/client/search/engines/skim.rs b/crates/turtle/src/command/client/search/engines/skim.rs new file mode 100644 index 00000000..96a6574d --- /dev/null +++ b/crates/turtle/src/command/client/search/engines/skim.rs @@ -0,0 +1,229 @@ +use std::path::Path; + +use async_trait::async_trait; +use crate::atuin_client::{ + database::Database, + history::{History, is_known_agent}, + settings::FilterMode, +}; +use eyre::Result; +use fuzzy_matcher::{FuzzyMatcher, skim::SkimMatcherV2}; +use itertools::Itertools; +use time::OffsetDateTime; +use tokio::task::yield_now; +use tracing::{Level, instrument, warn}; +use uuid; + +use super::{SearchEngine, SearchState}; + +pub struct Search { + all_history: Vec<(History, i32)>, + engine: SkimMatcherV2, +} + +impl Search { + pub fn new() -> Self { + Search { + all_history: vec![], + engine: SkimMatcherV2::default(), + } + } +} + +#[async_trait] +impl SearchEngine for Search { + #[instrument(skip_all, level = Level::TRACE, name = "skim_search", fields(query = %state.input.as_str()))] + async fn full_query( + &mut self, + state: &SearchState, + db: &mut dyn Database, + ) -> Result> { + if self.all_history.is_empty() { + self.all_history = load_all_history(db).await; + } + + Ok(fuzzy_search(&self.engine, state, &self.all_history).await) + } + + #[instrument(skip_all, level = Level::TRACE, name = "skim_highlight")] + fn get_highlight_indices(&self, command: &str, search_input: &str) -> Vec { + let (_, indices) = self + .engine + .fuzzy_indices(command, search_input) + .unwrap_or_default(); + indices + } +} + +#[instrument(skip_all, level = Level::TRACE, name = "load_all_history")] +async fn load_all_history(db: &dyn Database) -> Vec<(History, i32)> { + db.all_with_count().await.unwrap() +} + +#[expect(clippy::too_many_lines)] +#[instrument(skip_all, level = Level::TRACE, name = "fuzzy_match", fields(history_count = all_history.len()))] +async fn fuzzy_search( + engine: &SkimMatcherV2, + state: &SearchState, + all_history: &[(History, i32)], +) -> Vec { + let mut set = Vec::with_capacity(200); + let mut ranks = Vec::with_capacity(200); + let query = state.input.as_str(); + let now = OffsetDateTime::now_utc(); + + for (i, (history, count)) in all_history.iter().enumerate() { + if i % 256 == 0 { + yield_now().await; + } + if is_known_agent(&history.author) { + continue; + } + let context = &state.context; + let git_root = context + .git_root + .as_ref() + .and_then(|git_root| git_root.to_str()) + .unwrap_or(&context.cwd); + match state.filter_mode { + FilterMode::Global => {} + // we aggregate host by ',' separating them + FilterMode::Host + if history + .hostname + .split(',') + .contains(&context.hostname.as_str()) => {} + // we aggregate session by concattenating them. + // sessions are 32 byte simple uuid formats + FilterMode::Session + if history + .session + .as_bytes() + .chunks(32) + .contains(&context.session.as_bytes()) => {} + // SessionPreload: include current session + global history from before session start + FilterMode::SessionPreload => { + let is_current_session = { + history + .session + .as_bytes() + .chunks(32) + .any(|chunk| chunk == context.session.as_bytes()) + }; + + if !is_current_session { + let Ok(uuid) = uuid::Uuid::parse_str(&context.session) else { + warn!("failed to parse session id '{}'", context.session); + continue; + }; + let Some(timestamp) = uuid.get_timestamp() else { + warn!( + "failed to get timestamp from uuid '{}'", + uuid.as_hyphenated() + ); + continue; + }; + let (seconds, nanos) = timestamp.to_unix(); + let Ok(session_start) = time::OffsetDateTime::from_unix_timestamp_nanos( + i128::from(seconds) * 1_000_000_000 + i128::from(nanos), + ) else { + warn!( + "failed to create OffsetDateTime from second: {seconds}, nanosecond: {nanos}" + ); + continue; + }; + + if history.timestamp >= session_start { + continue; + } + } + } + // we aggregate directory by ':' separating them + FilterMode::Directory if history.cwd.split(':').contains(&context.cwd.as_str()) => {} + FilterMode::Workspace if history.cwd.split(':').contains(&git_root) => {} + _ => continue, + } + #[expect(clippy::cast_lossless, clippy::cast_precision_loss)] + if let Some((score, indices)) = engine.fuzzy_indices(&history.command, query) { + let begin = indices.first().copied().unwrap_or_default(); + + let mut duration = (now - history.timestamp).as_seconds_f64().log2(); + if !duration.is_finite() || duration <= 1.0 { + duration = 1.0; + } + // these + X.0 just make the log result a bit smoother. + // log is very spiky towards 1-4, but I want a gradual decay. + // eg: + // log2(4) = 2, log2(5) = 2.3 (16% increase) + // log2(8) = 3, log2(9) = 3.16 (5% increase) + // log2(16) = 4, log2(17) = 4.08 (2% increase) + let count = (*count as f64 + 8.0).log2(); + let begin = (begin as f64 + 16.0).log2(); + let path = path_dist(history.cwd.as_ref(), state.context.cwd.as_ref()); + let path = (path as f64 + 8.0).log2(); + + // reduce longer durations, raise higher counts, raise matches close to the start + let score = (-score as f64) * count / path / duration / begin; + + 'insert: { + // algorithm: + // 1. find either the position that this command ranks + // 2. find the same command positioned better than our rank. + for i in 0..set.len() { + // do we out score the current position? + if ranks[i] > score { + ranks.insert(i, score); + set.insert(i, history.clone()); + let mut j = i + 1; + while j < set.len() { + // remove duplicates that have a worse score + if set[j].command == history.command { + ranks.remove(j); + set.remove(j); + + // break this while loop because there won't be any other + // duplicates. + break; + } + j += 1; + } + + // keep it limited + if ranks.len() > 200 { + ranks.pop(); + set.pop(); + } + + break 'insert; + } + // don't continue if this command has a better score already + if set[i].command == history.command { + break 'insert; + } + } + + if set.len() < 200 { + ranks.push(score); + set.push(history.clone()); + } + } + } + } + + set +} + +fn path_dist(a: &Path, b: &Path) -> usize { + let mut a: Vec<_> = a.components().collect(); + let b: Vec<_> = b.components().collect(); + + let mut dist = 0; + + // pop a until there's a common ancestor + while !b.starts_with(&a) { + dist += 1; + a.pop(); + } + + b.len() - a.len() + dist +} diff --git a/crates/turtle/src/command/client/search/history_list.rs b/crates/turtle/src/command/client/search/history_list.rs new file mode 100644 index 00000000..4c83d7eb --- /dev/null +++ b/crates/turtle/src/command/client/search/history_list.rs @@ -0,0 +1,429 @@ +use std::time::Duration; + +use super::duration::format_duration; +use super::engines::SearchEngine; +use crate::atuin_client::{ + history::History, + settings::{UiColumn, UiColumnType}, + theme::{Meaning, Theme}, +}; +use crate::atuin_common::utils::Escapable as _; +use itertools::Itertools; +use ratatui::{ + backend::FromCrossterm, + buffer::Buffer, + crossterm::style, + layout::Rect, + style::{Modifier, Style}, + widgets::{Block, StatefulWidget, Widget}, +}; +use time::OffsetDateTime; + +pub struct HistoryHighlighter<'a> { + pub engine: &'a dyn SearchEngine, + pub search_input: &'a str, +} + +impl HistoryHighlighter<'_> { + pub fn get_highlight_indices(&self, command: &str) -> Vec { + self.engine + .get_highlight_indices(command, self.search_input) + } +} + +pub struct HistoryList<'a> { + history: &'a [History], + block: Option>, + inverted: bool, + /// Apply an alternative highlighting to the selected row + alternate_highlight: bool, + now: &'a dyn Fn() -> OffsetDateTime, + indicator: &'a str, + theme: &'a Theme, + history_highlighter: HistoryHighlighter<'a>, + show_numeric_shortcuts: bool, + /// Columns to display (in order, after the indicator) + columns: &'a [UiColumn], +} + +#[derive(Default)] +pub struct ListState { + offset: usize, + selected: usize, + max_entries: usize, +} + +impl ListState { + pub fn selected(&self) -> usize { + self.selected + } + + pub fn max_entries(&self) -> usize { + self.max_entries + } + + pub fn offset(&self) -> usize { + self.offset + } + + pub fn select(&mut self, index: usize) { + self.selected = index; + } +} + +impl StatefulWidget for HistoryList<'_> { + type State = ListState; + + fn render(mut self, area: Rect, buf: &mut Buffer, state: &mut Self::State) { + let list_area = self.block.take().map_or(area, |b| { + let inner_area = b.inner(area); + b.render(area, buf); + inner_area + }); + + if list_area.width < 1 || list_area.height < 1 || self.history.is_empty() { + return; + } + let list_height = list_area.height as usize; + + let (start, end) = self.get_items_bounds(state.selected, state.offset, list_height); + state.offset = start; + state.max_entries = end - start; + + let mut s = DrawState { + buf, + list_area, + x: 0, + y: 0, + state, + inverted: self.inverted, + alternate_highlight: self.alternate_highlight, + now: &self.now, + indicator: self.indicator, + theme: self.theme, + history_highlighter: self.history_highlighter, + show_numeric_shortcuts: self.show_numeric_shortcuts, + columns: self.columns, + }; + + for item in self.history.iter().skip(state.offset).take(end - start) { + s.render_row(item); + + // reset line + s.y += 1; + s.x = 0; + } + } +} + +impl<'a> HistoryList<'a> { + #[expect(clippy::too_many_arguments)] + pub fn new( + history: &'a [History], + inverted: bool, + alternate_highlight: bool, + now: &'a dyn Fn() -> OffsetDateTime, + indicator: &'a str, + theme: &'a Theme, + history_highlighter: HistoryHighlighter<'a>, + show_numeric_shortcuts: bool, + columns: &'a [UiColumn], + ) -> Self { + Self { + history, + block: None, + inverted, + alternate_highlight, + now, + indicator, + theme, + history_highlighter, + show_numeric_shortcuts, + columns, + } + } + + pub fn block(mut self, block: Block<'a>) -> Self { + self.block = Some(block); + self + } + + fn get_items_bounds(&self, selected: usize, offset: usize, height: usize) -> (usize, usize) { + let offset = offset.min(self.history.len().saturating_sub(1)); + + let max_scroll_space = height.min(10).min(self.history.len() - selected); + if offset + height < selected + max_scroll_space { + let end = selected + max_scroll_space; + (end - height, end) + } else if selected < offset { + (selected, selected + height) + } else { + (offset, offset + height) + } + } +} + +struct DrawState<'a> { + buf: &'a mut Buffer, + list_area: Rect, + x: u16, + y: u16, + state: &'a ListState, + inverted: bool, + alternate_highlight: bool, + now: &'a dyn Fn() -> OffsetDateTime, + indicator: &'a str, + theme: &'a Theme, + history_highlighter: HistoryHighlighter<'a>, + show_numeric_shortcuts: bool, + columns: &'a [UiColumn], +} + +// these encode the slices of `" > "`, `" {n} "`, or `" "` in a compact form. +// Yes, this is a hack, but it makes me feel happy +static SLICES: &str = " > 1 2 3 4 5 6 7 8 9 "; + +impl DrawState<'_> { + /// Render a complete row for a history item based on configured columns. + fn render_row(&mut self, h: &History) { + // Always render the indicator first (width 3) + self.index(); + + // Calculate the width for the expanding column + // Fixed columns use their configured width + 1 (trailing space) + let indicator_width: u16 = 3; + let fixed_width: u16 = self + .columns + .iter() + .filter(|c| !c.expand) + .map(|c| c.width + 1) + .sum(); + let expand_width = self + .list_area + .width + .saturating_sub(indicator_width + fixed_width); + + let style = self.theme.as_style(Meaning::Base); + // Render each configured column + for (idx, column) in self.columns.iter().enumerate() { + if idx != 0 { + self.draw(" ", Style::from_crossterm(style)); + } + let width = if column.expand { + expand_width + } else { + column.width + }; + match column.column_type { + UiColumnType::Duration => self.duration(h, width), + UiColumnType::Time => self.time(h, width), + UiColumnType::Datetime => self.datetime(h, width), + UiColumnType::Directory => self.directory(h, width), + UiColumnType::Host => self.host(h, width), + UiColumnType::User => self.user(h, width), + UiColumnType::Exit => self.exit_code(h, width), + UiColumnType::Command => self.command(h), + } + } + } + + fn index(&mut self) { + if !self.show_numeric_shortcuts { + let i = self.y as usize + self.state.offset; + let is_selected = i == self.state.selected(); + let prompt: &str = if is_selected { self.indicator } else { " " }; + self.draw(prompt, Style::default()); + return; + } + + // these encode the slices of `" > "`, `" {n} "`, or `" "` in a compact form. + // Yes, this is a hack, but it makes me feel happy + + let i = self.y as usize + self.state.offset; + let i = i.checked_sub(self.state.selected); + let i = i.unwrap_or(10).min(10) * 2; + let prompt: &str = if i == 0 { + self.indicator + } else { + &SLICES[i..i + 3] + }; + self.draw(prompt, Style::default()); + } + + fn duration(&mut self, h: &History, width: u16) { + let style = self.theme.as_style(if h.success() { + Meaning::AlertInfo + } else { + Meaning::AlertError + }); + let duration = Duration::from_nanos(u64::try_from(h.duration).unwrap_or(0)); + let formatted = format_duration(duration); + let w = width as usize; + // Right-align duration within its column width, plus trailing space + let display = format!("{formatted:>w$}"); + self.draw(&display, Style::from_crossterm(style)); + } + + fn time(&mut self, h: &History, width: u16) { + let style = self.theme.as_style(Meaning::Guidance); + + // Account for the chance that h.timestamp is "in the future" + // This would mean that "since" is negative, and the unwrap here + // would fail. + // If the timestamp would otherwise be in the future, display + // the time since as 0. + let since = (self.now)() - h.timestamp; + let time = format_duration(since.try_into().unwrap_or_default()); + + // Format as "Xs ago" right-aligned within column width + let w = width as usize; + let time_str = format!("{time} ago"); + + let display = format!("{time_str:>w$}"); + self.draw(&display, Style::from_crossterm(style)); + } + + fn command(&mut self, h: &History) { + let mut style = self.theme.as_style(Meaning::Base); + let mut row_highlighted = false; + if !self.alternate_highlight && (self.y as usize + self.state.offset == self.state.selected) + { + row_highlighted = true; + // if not applying alternative highlighting to the whole row, color the command + style = self.theme.as_style(Meaning::AlertError); + style.attributes.set(style::Attribute::Bold); + } + + let highlight_indices = self.history_highlighter.get_highlight_indices( + h.command + .escape_control() + .split_ascii_whitespace() + .join(" ") + .as_str(), + ); + + let mut pos = 0; + for section in h.command.escape_control().split_ascii_whitespace() { + if pos != 0 { + self.draw(" ", Style::from_crossterm(style)); + } + for ch in section.chars() { + if self.x > self.list_area.width { + // Avoid attempting to draw a command section beyond the width + // of the list + return; + } + let mut style = style; + if highlight_indices.contains(&pos) { + if row_highlighted { + // if the row is highlighted bold is not enough as the whole row is bold + // change the color too + style = self.theme.as_style(Meaning::AlertWarn); + } + style.attributes.set(style::Attribute::Bold); + } + let s = ch.to_string(); + self.draw(&s, Style::from_crossterm(style)); + pos += s.len(); + } + pos += 1; + } + } + + /// Render the absolute datetime column (e.g., "2025-01-22 14:35") + fn datetime(&mut self, h: &History, width: u16) { + let style = self.theme.as_style(Meaning::Annotation); + // Format: YYYY-MM-DD HH:MM + let formatted = h + .timestamp + .format( + &time::format_description::parse("[year]-[month]-[day] [hour]:[minute]") + .expect("valid format"), + ) + .unwrap_or_else(|_| "????-??-?? ??:??".to_string()); + let w = width as usize; + let display = format!("{formatted:w$}"); + self.draw(&display, Style::from_crossterm(style)); + } + + /// Render the directory column (working directory, truncated) + fn directory(&mut self, h: &History, width: u16) { + let style = self.theme.as_style(Meaning::Annotation); + let w = width as usize; + let cwd = &h.cwd; + let char_count = cwd.chars().count(); + // Truncate from the left with "..." if too long, plus trailing space + // Use character count for comparison and skip for UTF-8 safety + let display = if char_count > w && w >= 4 { + let truncated: String = cwd.chars().skip(char_count - (w - 3)).collect(); + format!("...{truncated}") + } else { + format!("{cwd:w$}") + }; + self.draw(&display, Style::from_crossterm(style)); + } + + /// Render the host column (just the hostname) + fn host(&mut self, h: &History, width: u16) { + let style = self.theme.as_style(Meaning::Annotation); + let w = width as usize; + // Database stores hostname as "hostname:username" + let host = h.hostname.split(':').next().unwrap_or(&h.hostname); + let char_count = host.chars().count(); + // Use character count for comparison and take for UTF-8 safety + let display = if char_count > w && w >= 4 { + let truncated: String = host.chars().take(w.saturating_sub(4)).collect(); + format!("{truncated}...") + } else { + format!("{host:w$}") + }; + self.draw(&display, Style::from_crossterm(style)); + } + + /// Render the user column + fn user(&mut self, h: &History, width: u16) { + let style = self.theme.as_style(Meaning::Annotation); + let w = width as usize; + // Database stores hostname as "hostname:username" + let user = h.hostname.split(':').nth(1).unwrap_or(""); + let char_count = user.chars().count(); + // Use character count for comparison and take for UTF-8 safety + let display = if char_count > w && w >= 4 { + let truncated: String = user.chars().take(w.saturating_sub(4)).collect(); + format!("{truncated}...") + } else { + format!("{user:w$}") + }; + self.draw(&display, Style::from_crossterm(style)); + } + + /// Render the exit code column + fn exit_code(&mut self, h: &History, width: u16) { + let style = if h.success() { + self.theme.as_style(Meaning::AlertInfo) + } else { + self.theme.as_style(Meaning::AlertError) + }; + let w = width as usize; + let display = format!("{:>w$}", h.exit); + self.draw(&display, Style::from_crossterm(style)); + } + + fn draw(&mut self, s: &str, mut style: Style) { + let cx = self.list_area.left() + self.x; + + let cy = if self.inverted { + self.list_area.top() + self.y + } else { + self.list_area.bottom() - self.y - 1 + }; + + if self.alternate_highlight && (self.y as usize + self.state.offset == self.state.selected) + { + style = style.add_modifier(Modifier::REVERSED); + } + + let w = (self.list_area.width - self.x) as usize; + self.x += self.buf.set_stringn(cx, cy, s, w, style).0 - cx; + } +} diff --git a/crates/turtle/src/command/client/search/inspector.rs b/crates/turtle/src/command/client/search/inspector.rs new file mode 100644 index 00000000..1ebc4383 --- /dev/null +++ b/crates/turtle/src/command/client/search/inspector.rs @@ -0,0 +1,421 @@ +use std::time::Duration; +use time::macros::format_description; + +use crate::atuin_client::{ + history::{History, HistoryStats}, + settings::{Settings, Timezone}, +}; +use ratatui::{ + Frame, + backend::FromCrossterm, + layout::Rect, + prelude::{Constraint, Direction, Layout}, + style::Style, + text::{Span, Text}, + widgets::{Bar, BarChart, BarGroup, Block, Borders, Padding, Paragraph, Row, Table}, +}; + +use super::duration::format_duration; + +use super::super::theme::{Meaning, Theme}; +use super::interactive::{Compactness, to_compactness}; + +#[expect(clippy::cast_sign_loss)] +fn u64_or_zero(num: i64) -> u64 { + if num < 0 { 0 } else { num as u64 } +} + +pub fn draw_commands( + f: &mut Frame<'_>, + parent: Rect, + history: &History, + stats: &HistoryStats, + compact: bool, + theme: &Theme, +) { + let commands = Layout::default() + .direction(if compact { + Direction::Vertical + } else { + Direction::Horizontal + }) + .constraints(if compact { + [ + Constraint::Length(1), + Constraint::Length(1), + Constraint::Min(0), + ] + } else { + [ + Constraint::Ratio(1, 4), + Constraint::Ratio(1, 2), + Constraint::Ratio(1, 4), + ] + }) + .split(parent); + + let command = Paragraph::new(Text::from(Span::styled( + history.command.clone(), + Style::from_crossterm(theme.as_style(Meaning::Important)), + ))) + .block(if compact { + Block::new() + .borders(Borders::NONE) + .style(Style::from_crossterm(theme.as_style(Meaning::Base))) + } else { + Block::new() + .borders(Borders::ALL) + .style(Style::from_crossterm(theme.as_style(Meaning::Base))) + .title("Command") + .padding(Padding::horizontal(1)) + }); + + let previous = Paragraph::new( + stats + .previous + .clone() + .map_or_else(|| "[No previous command]".to_string(), |prev| prev.command), + ) + .block(if compact { + Block::new() + .borders(Borders::NONE) + .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))) + } else { + Block::new() + .borders(Borders::ALL) + .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))) + .title("Previous command") + .padding(Padding::horizontal(1)) + }); + + // Add [] around blank text, as when this is shown in a list + // compacted, it makes it more obviously control text. + let next = Paragraph::new( + stats + .next + .clone() + .map_or_else(|| "[No next command]".to_string(), |next| next.command), + ) + .block(if compact { + Block::new() + .borders(Borders::NONE) + .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))) + } else { + Block::new() + .borders(Borders::ALL) + .title("Next command") + .padding(Padding::horizontal(1)) + .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))) + }); + + f.render_widget(previous, commands[0]); + f.render_widget(command, commands[1]); + f.render_widget(next, commands[2]); +} + +pub fn draw_stats_table( + f: &mut Frame<'_>, + parent: Rect, + history: &History, + tz: Timezone, + stats: &HistoryStats, + theme: &Theme, +) { + let duration = Duration::from_nanos(u64_or_zero(history.duration)); + let avg_duration = Duration::from_nanos(stats.average_duration); + let (host, user) = history.hostname.split_once(':').unwrap_or(("", "")); + + let rows = [ + Row::new(vec!["Host".to_string(), host.to_string()]), + Row::new(vec!["User".to_string(), user.to_string()]), + Row::new(vec![ + "Time".to_string(), + history.timestamp.to_offset(tz.0).to_string(), + ]), + Row::new(vec!["Duration".to_string(), format_duration(duration)]), + Row::new(vec![ + "Avg duration".to_string(), + format_duration(avg_duration), + ]), + Row::new(vec!["Exit".to_string(), history.exit.to_string()]), + Row::new(vec!["Directory".to_string(), history.cwd.clone()]), + Row::new(vec!["Session".to_string(), history.session.clone()]), + Row::new(vec!["Total runs".to_string(), stats.total.to_string()]), + ]; + + let widths = [Constraint::Ratio(1, 5), Constraint::Ratio(4, 5)]; + + let table = Table::new(rows, widths).column_spacing(1).block( + Block::default() + .title("Command stats") + .borders(Borders::ALL) + .style(Style::from_crossterm(theme.as_style(Meaning::Base))) + .padding(Padding::vertical(1)), + ); + + f.render_widget(table, parent); +} + +fn num_to_day(num: &str) -> String { + match num { + "0" => "Sunday".to_string(), + "1" => "Monday".to_string(), + "2" => "Tuesday".to_string(), + "3" => "Wednesday".to_string(), + "4" => "Thursday".to_string(), + "5" => "Friday".to_string(), + "6" => "Saturday".to_string(), + _ => "Invalid day".to_string(), + } +} + +fn sort_duration_over_time(durations: &[(String, i64)]) -> Vec<(String, i64)> { + let format = format_description!("[day]-[month]-[year]"); + let output = format_description!("[month]/[year repr:last_two]"); + + let mut durations: Vec<(time::Date, i64)> = durations + .iter() + .map(|d| { + ( + time::Date::parse(d.0.as_str(), &format).expect("invalid date string from sqlite"), + d.1, + ) + }) + .collect(); + + durations.sort_by_key(|a| a.0); + + durations + .iter() + .map(|(date, duration)| { + ( + date.format(output).expect("failed to format sqlite date"), + *duration, + ) + }) + .collect() +} + +fn draw_stats_charts(f: &mut Frame<'_>, parent: Rect, stats: &HistoryStats, theme: &Theme) { + let exits: Vec = stats + .exits + .iter() + .map(|(exit, count)| { + Bar::default() + .label(exit.to_string()) + .value(u64_or_zero(*count)) + }) + .collect(); + + let exits = BarChart::default() + .block( + Block::default() + .title("Exit code distribution") + .style(Style::from_crossterm(theme.as_style(Meaning::Base))) + .borders(Borders::ALL), + ) + .bar_width(3) + .bar_gap(1) + .bar_style(Style::default()) + .value_style(Style::default()) + .label_style(Style::default()) + .data(BarGroup::default().bars(&exits)); + + let day_of_week: Vec = stats + .day_of_week + .iter() + .map(|(day, count)| { + Bar::default() + .label(num_to_day(day.as_str())) + .value(u64_or_zero(*count)) + }) + .collect(); + + let day_of_week = BarChart::default() + .block( + Block::default() + .title("Runs per day") + .style(Style::from_crossterm(theme.as_style(Meaning::Base))) + .borders(Borders::ALL), + ) + .bar_width(3) + .bar_gap(1) + .bar_style(Style::default()) + .value_style(Style::default()) + .label_style(Style::default()) + .data(BarGroup::default().bars(&day_of_week)); + + let duration_over_time = sort_duration_over_time(&stats.duration_over_time); + let duration_over_time: Vec = duration_over_time + .iter() + .map(|(date, duration)| { + let d = Duration::from_nanos(u64_or_zero(*duration)); + Bar::default() + .label(date.clone()) + .value(u64_or_zero(*duration)) + .text_value(format_duration(d)) + }) + .collect(); + + let duration_over_time = BarChart::default() + .block( + Block::default() + .title("Duration over time") + .style(Style::from_crossterm(theme.as_style(Meaning::Base))) + .borders(Borders::ALL), + ) + .bar_width(5) + .bar_gap(1) + .bar_style(Style::default()) + .value_style(Style::default()) + .label_style(Style::default()) + .data(BarGroup::default().bars(&duration_over_time)); + + let layout = Layout::default() + .direction(Direction::Vertical) + .constraints([ + Constraint::Ratio(1, 3), + Constraint::Ratio(1, 3), + Constraint::Ratio(1, 3), + ]) + .split(parent); + + f.render_widget(exits, layout[0]); + f.render_widget(day_of_week, layout[1]); + f.render_widget(duration_over_time, layout[2]); +} + +pub fn draw( + f: &mut Frame<'_>, + chunk: Rect, + history: &History, + stats: &HistoryStats, + settings: &Settings, + theme: &Theme, + tz: Timezone, +) { + let compactness = to_compactness(f, settings); + + match compactness { + Compactness::Ultracompact => draw_ultracompact(f, chunk, history, stats, theme), + _ => draw_full(f, chunk, history, stats, theme, tz), + } +} + +pub fn draw_ultracompact( + f: &mut Frame<'_>, + chunk: Rect, + history: &History, + stats: &HistoryStats, + theme: &Theme, +) { + draw_commands(f, chunk, history, stats, true, theme); +} + +pub fn draw_full( + f: &mut Frame<'_>, + chunk: Rect, + history: &History, + stats: &HistoryStats, + theme: &Theme, + tz: Timezone, +) { + let vert_layout = Layout::default() + .direction(Direction::Vertical) + .constraints([Constraint::Ratio(1, 5), Constraint::Ratio(4, 5)]) + .split(chunk); + + let stats_layout = Layout::default() + .direction(Direction::Horizontal) + .constraints([Constraint::Ratio(1, 3), Constraint::Ratio(2, 3)]) + .split(vert_layout[1]); + + draw_commands(f, vert_layout[0], history, stats, false, theme); + draw_stats_table(f, stats_layout[0], history, tz, stats, theme); + draw_stats_charts(f, stats_layout[1], stats, theme); +} + +#[cfg(test)] +mod tests { + use super::draw_ultracompact; + use crate::atuin_client::{ + history::{History, HistoryId, HistoryStats}, + theme::ThemeManager, + }; + use ratatui::{backend::TestBackend, prelude::*}; + use time::OffsetDateTime; + + fn mock_history_stats() -> (History, HistoryStats) { + let history = History { + id: HistoryId::from("test1".to_string()), + timestamp: OffsetDateTime::now_utc(), + duration: 3, + exit: 0, + command: "/bin/cmd".to_string(), + cwd: "/toot".to_string(), + session: "sesh1".to_string(), + hostname: "hostn".to_string(), + author: "hostn".to_string(), + intent: None, + deleted_at: None, + }; + let next = History { + id: HistoryId::from("test2".to_string()), + timestamp: OffsetDateTime::now_utc(), + duration: 2, + exit: 0, + command: "/bin/cmd -os".to_string(), + cwd: "/toot".to_string(), + session: "sesh1".to_string(), + hostname: "hostn".to_string(), + author: "hostn".to_string(), + intent: None, + deleted_at: None, + }; + let prev = History { + id: HistoryId::from("test3".to_string()), + timestamp: OffsetDateTime::now_utc(), + duration: 1, + exit: 0, + command: "/bin/cmd -a".to_string(), + cwd: "/toot".to_string(), + session: "sesh1".to_string(), + hostname: "hostn".to_string(), + author: "hostn".to_string(), + intent: None, + deleted_at: None, + }; + let stats = HistoryStats { + next: Some(next.clone()), + previous: Some(prev.clone()), + total: 2, + average_duration: 3, + exits: Vec::new(), + day_of_week: Vec::new(), + duration_over_time: Vec::new(), + }; + (history, stats) + } + + #[test] + fn test_output_looks_correct_for_ultracompact() { + let backend = TestBackend::new(22, 5); + let mut terminal = Terminal::new(backend).expect("Could not create terminal"); + let chunk = Rect::new(0, 0, 22, 5); + let (history, stats) = mock_history_stats(); + let prev = stats.previous.clone().unwrap(); + let next = stats.next.clone().unwrap(); + + let mut manager = ThemeManager::new(Some(true), Some("".to_string())); + let theme = manager.load_theme("(none)", None); + let _ = terminal.draw(|f| draw_ultracompact(f, chunk, &history, &stats, &theme)); + let mut lines = [" "; 5].map(|l| Line::from(l)); + for (n, entry) in [prev, history, next].iter().enumerate() { + let mut l = lines[n].to_string(); + l.replace_range(0..entry.command.len(), &entry.command); + lines[n] = Line::from(l); + } + + terminal.backend().assert_buffer_lines(lines); + } +} diff --git a/crates/turtle/src/command/client/search/interactive.rs b/crates/turtle/src/command/client/search/interactive.rs new file mode 100644 index 00000000..a3d2cb79 --- /dev/null +++ b/crates/turtle/src/command/client/search/interactive.rs @@ -0,0 +1,3041 @@ +use std::{ + io::{IsTerminal, Write, stdout}, + time::Duration, +}; + +#[cfg(unix)] +use std::io::Read as _; + +use crate::atuin_common::{shell::Shell, utils::Escapable as _}; +use eyre::Result; +use time::OffsetDateTime; +use unicode_width::{UnicodeWidthChar, UnicodeWidthStr}; + +use super::{ + cursor::Cursor, + engines::{SearchEngine, SearchState}, + history_list::{HistoryList, ListState}, +}; +use crate::atuin_client::{ + database::{Context, Database, current_context}, + history::{History, HistoryId, HistoryStats, store::HistoryStore}, + settings::{ + CursorStyle, ExitMode, FilterMode, KeymapMode, PreviewStrategy, SearchMode, Settings, + UiColumn, + }, +}; + +use crate::command::client::search::history_list::HistoryHighlighter; +use crate::command::client::search::keybindings::KeymapSet; +use crate::command::client::theme::{Meaning, Theme}; +use crate::{VERSION, command::client::search::engines}; + +use ratatui::{ + Frame, Terminal, TerminalOptions, Viewport, + backend::{CrosstermBackend, FromCrossterm}, + crossterm::{ + cursor::SetCursorStyle, + event::{self, Event, KeyEvent, MouseEvent}, + execute, queue, terminal, + }, + layout::{Alignment, Constraint, Direction, Layout}, + prelude::*, + style::{Modifier, Style}, + text::{Line, Span, Text}, + widgets::{Block, BorderType, Borders, Clear, Padding, Paragraph, Tabs}, +}; + +#[cfg(not(target_os = "windows"))] +use ratatui::crossterm::event::{ + KeyboardEnhancementFlags, PopKeyboardEnhancementFlags, PushKeyboardEnhancementFlags, +}; + +const TAB_TITLES: [&str; 2] = ["Search", "Inspect"]; + +pub enum InputAction { + Accept(usize), + AcceptInspecting, + Copy(usize), + Delete(usize), + DeleteAllMatching(usize), + ReturnOriginal, + ReturnQuery, + Continue, + Redraw, + SwitchContext(Option), +} + +#[derive(Clone)] +pub struct InspectingState { + current: Option, + next: Option, + previous: Option, +} + +impl InspectingState { + pub fn move_to_previous(&mut self) { + let previous = self.previous.clone(); + self.reset(); + self.current = previous; + } + + pub fn move_to_next(&mut self) { + let next = self.next.clone(); + self.reset(); + self.current = next; + } + + pub fn reset(&mut self) { + self.current = None; + self.next = None; + self.previous = None; + } +} + +pub fn to_compactness(f: &Frame, settings: &Settings) -> Compactness { + if match settings.style { + crate::atuin_client::settings::Style::Auto => f.area().height < 14, + crate::atuin_client::settings::Style::Compact => true, + crate::atuin_client::settings::Style::Full => false, + } { + if settings.auto_hide_height != 0 && f.area().height <= settings.auto_hide_height { + Compactness::Ultracompact + } else { + Compactness::Compact + } + } else { + Compactness::Full + } +} + +#[expect(clippy::struct_field_names)] +#[expect(clippy::struct_excessive_bools)] +pub struct State { + history_count: i64, + results_state: ListState, + switched_search_mode: bool, + search_mode: SearchMode, + results_len: usize, + accept: bool, + keymap_mode: KeymapMode, + prefix: bool, + current_cursor: Option, + tab_index: usize, + pending_vim_key: Option, + original_input_empty: bool, + + pub inspecting_state: InspectingState, + + keymaps: KeymapSet, + search: SearchState, + engine: Box, + now: Box OffsetDateTime + Send>, +} + +#[derive(Clone, Copy)] +pub enum Compactness { + Ultracompact, + Compact, + Full, +} + +#[derive(Clone, Copy)] +struct StyleState { + compactness: Compactness, + invert: bool, + inner_width: usize, +} + +impl State { + async fn query_results( + &mut self, + db: &mut dyn Database, + smart_sort: bool, + ) -> Result> { + let results = self.engine.query(&self.search, db).await?; + + self.inspecting_state = InspectingState { + current: None, + next: None, + previous: None, + }; + self.results_state.select(0); + self.results_len = results.len(); + + if smart_sort { + Ok(crate::atuin_history::sort::sort( + self.search.input.as_str(), + results, + )) + } else { + Ok(results) + } + } + + fn handle_input(&mut self, settings: &Settings, input: &Event) -> InputAction { + match input { + Event::Key(k) => self.handle_key_input(settings, k), + Event::Mouse(m) => self.handle_mouse_input(*m, settings.invert), + Event::Paste(d) => self.handle_paste_input(d), + _ => InputAction::Continue, + } + } + + fn handle_mouse_input(&mut self, input: MouseEvent, inverted: bool) -> InputAction { + match (input.kind, inverted) { + (event::MouseEventKind::ScrollDown, false) + | (event::MouseEventKind::ScrollUp, true) => { + self.scroll_down(1); + } + (event::MouseEventKind::ScrollDown, true) + | (event::MouseEventKind::ScrollUp, false) => { + self.scroll_up(1); + } + _ => {} + } + InputAction::Continue + } + + fn handle_paste_input(&mut self, input: &str) -> InputAction { + for i in input.chars() { + self.search.input.insert(i); + } + InputAction::Continue + } + + fn cast_cursor_style(style: CursorStyle) -> SetCursorStyle { + match style { + CursorStyle::DefaultUserShape => SetCursorStyle::DefaultUserShape, + CursorStyle::BlinkingBlock => SetCursorStyle::BlinkingBlock, + CursorStyle::SteadyBlock => SetCursorStyle::SteadyBlock, + CursorStyle::BlinkingUnderScore => SetCursorStyle::BlinkingUnderScore, + CursorStyle::SteadyUnderScore => SetCursorStyle::SteadyUnderScore, + CursorStyle::BlinkingBar => SetCursorStyle::BlinkingBar, + CursorStyle::SteadyBar => SetCursorStyle::SteadyBar, + } + } + + fn set_keymap_cursor(&mut self, settings: &Settings, keymap_name: &str) { + let cursor_style = if keymap_name == "__clear__" { + None + } else { + settings.keymap_cursor.get(keymap_name).copied() + } + .or_else(|| self.current_cursor.map(|_| CursorStyle::DefaultUserShape)); + + if cursor_style != self.current_cursor + && let Some(style) = cursor_style + { + self.current_cursor = cursor_style; + let _ = execute!(stdout(), Self::cast_cursor_style(style)); + } + } + + pub fn initialize_keymap_cursor(&mut self, settings: &Settings) { + match self.keymap_mode { + KeymapMode::Emacs => self.set_keymap_cursor(settings, "emacs"), + KeymapMode::VimNormal => self.set_keymap_cursor(settings, "vim_normal"), + KeymapMode::VimInsert => self.set_keymap_cursor(settings, "vim_insert"), + KeymapMode::Auto => {} + } + } + + pub fn finalize_keymap_cursor(&mut self, settings: &Settings) { + match settings.keymap_mode_shell { + KeymapMode::Emacs => self.set_keymap_cursor(settings, "emacs"), + KeymapMode::VimNormal => self.set_keymap_cursor(settings, "vim_normal"), + KeymapMode::VimInsert => self.set_keymap_cursor(settings, "vim_insert"), + KeymapMode::Auto => self.set_keymap_cursor(settings, "__clear__"), + } + } + + fn handle_key_exit(settings: &Settings) -> InputAction { + match settings.exit_mode { + ExitMode::ReturnOriginal => InputAction::ReturnOriginal, + ExitMode::ReturnQuery => InputAction::ReturnQuery, + } + } + + /// Select the keymap for the current mode (ignoring prefix). + fn mode_keymap(&self) -> &super::keybindings::Keymap { + if self.tab_index == 1 { + &self.keymaps.inspector + } else { + match self.keymap_mode { + KeymapMode::Emacs | KeymapMode::Auto => &self.keymaps.emacs, + KeymapMode::VimNormal => &self.keymaps.vim_normal, + KeymapMode::VimInsert => &self.keymaps.vim_insert, + } + } + } + + /// Whether the current mode supports character insertion on unmatched keys. + fn is_insert_mode(&self) -> bool { + matches!( + self.keymap_mode, + KeymapMode::Emacs | KeymapMode::Auto | KeymapMode::VimInsert + ) + } + + fn handle_key_input(&mut self, settings: &Settings, input: &KeyEvent) -> InputAction { + use super::keybindings::Action; + use super::keybindings::EvalContext; + use super::keybindings::key::{KeyCodeValue, KeyInput, SingleKey}; + + // Skip release events + if input.kind == event::KeyEventKind::Release { + return InputAction::Continue; + } + + // Reset switched_search_mode at start of each key event + self.switched_search_mode = false; + + // Build evaluation context from current state + let ctx = EvalContext { + cursor_position: self.search.input.position(), + input_width: UnicodeWidthStr::width(self.search.input.as_str()), + input_byte_len: self.search.input.as_str().len(), + selected_index: self.results_state.selected(), + results_len: self.results_len, + original_input_empty: self.original_input_empty, + has_context: self.search.custom_context.is_some(), + }; + + // Convert KeyEvent to SingleKey + let Some(single) = SingleKey::from_event(input) else { + return InputAction::Continue; + }; + + // --- Phase 1: Resolve (take pending key first, then immutable borrows) --- + + // Take pending key before any immutable borrows of self + let pending = self.pending_vim_key.take(); + + // If in prefix mode, try prefix keymap first (single keys only) + let prefix_action = if self.prefix { + let ki = KeyInput::Single(single.clone()); + self.keymaps.prefix.resolve(&ki, &ctx) + } else { + None + }; + + // The if-let/else-if chain here is clearer than map_or_else with nested closures. + #[expect(clippy::option_if_let_else)] + let (action, new_pending) = if prefix_action.is_some() { + (prefix_action, None) + } else { + // Use mode keymap (handles both single and multi-key sequences) + let keymap = self.mode_keymap(); + + if let Some(pending_char) = pending { + // We have a pending key from a previous press (e.g., first 'g' of 'gg') + let pending_single = SingleKey { + code: KeyCodeValue::Char(pending_char), + ctrl: false, + alt: false, + shift: false, + super_key: false, + }; + let seq = KeyInput::Sequence(vec![pending_single, single.clone()]); + let action = keymap + .resolve(&seq, &ctx) + .or_else(|| keymap.resolve(&KeyInput::Single(single.clone()), &ctx)); + (action, None) + } else if keymap.has_sequence_starting_with(&single) + && matches!(single.code, KeyCodeValue::Char(_)) + && !single.ctrl + && !single.alt + { + // This key starts a multi-key sequence; wait for next key + let KeyCodeValue::Char(c) = single.code else { + unreachable!() + }; + (Some(Action::Noop), Some(c)) + } else { + ( + keymap.resolve(&KeyInput::Single(single.clone()), &ctx), + None, + ) + } + }; + + // --- Phase 2: Apply mutations --- + self.pending_vim_key = new_pending; + + // Reset prefix (before execute, so EnterPrefixMode can re-set it) + self.prefix = false; + + if let Some(action) = action { + self.execute_action(&action, settings) + } else { + // No action matched. In insert-capable modes, insert the character. + if self.is_insert_mode() && !single.ctrl && !single.alt { + match single.code { + KeyCodeValue::Char(c) => { + self.search.input.insert(c); + } + KeyCodeValue::Space => { + self.search.input.insert(' '); + } + _ => {} + } + } + InputAction::Continue + } + } + + fn scroll_down(&mut self, scroll_len: usize) { + let i = self.results_state.selected().saturating_sub(scroll_len); + self.inspecting_state.reset(); + self.results_state.select(i); + } + + fn scroll_up(&mut self, scroll_len: usize) { + let i = self.results_state.selected() + scroll_len; + self.results_state + .select(i.min(self.results_len.saturating_sub(1))); + self.inspecting_state.reset(); + } + + /// Execute a resolved action, performing all side effects and returning the + /// appropriate `InputAction` for the event loop. + /// + /// This is the "do it" half of the resolve+execute pipeline. The resolver + /// decides *what* to do (which `Action`), and this function carries it out. + /// + /// Invert handling: scroll actions (`SelectNext`, `ScrollPageDown`, etc.) account + /// for `settings.invert` so that keybindings are always in "visual" terms — + /// users never need to think about invert in their keybinding config. + #[expect(clippy::too_many_lines)] + pub(crate) fn execute_action( + &mut self, + action: &super::keybindings::Action, + settings: &Settings, + ) -> InputAction { + use crate::command::client::search::keybindings::Action; + + match action { + // -- Cursor movement -- + Action::CursorLeft => { + self.search.input.left(); + InputAction::Continue + } + Action::CursorRight => { + self.search.input.right(); + InputAction::Continue + } + Action::CursorWordLeft => { + self.search + .input + .prev_word(&settings.word_chars, settings.word_jump_mode); + InputAction::Continue + } + Action::CursorWordRight => { + self.search + .input + .next_word(&settings.word_chars, settings.word_jump_mode); + InputAction::Continue + } + Action::CursorWordEnd => { + self.search.input.word_end(&settings.word_chars); + InputAction::Continue + } + Action::CursorStart => { + self.search.input.start(); + InputAction::Continue + } + Action::CursorEnd => { + self.search.input.end(); + InputAction::Continue + } + + // -- Editing -- + Action::DeleteCharBefore => { + self.search.input.back(); + InputAction::Continue + } + Action::DeleteCharAfter => { + self.search.input.remove(); + InputAction::Continue + } + Action::DeleteWordBefore => { + self.search + .input + .remove_prev_word(&settings.word_chars, settings.word_jump_mode); + InputAction::Continue + } + Action::DeleteWordAfter => { + self.search + .input + .remove_next_word(&settings.word_chars, settings.word_jump_mode); + InputAction::Continue + } + Action::DeleteToWordBoundary => { + // ctrl-w: remove trailing whitespace, then delete to word boundary + while matches!(self.search.input.back(), Some(c) if c.is_whitespace()) {} + while self.search.input.left() { + if self.search.input.char().unwrap().is_whitespace() { + self.search.input.right(); + break; + } + self.search.input.remove(); + } + InputAction::Continue + } + Action::ClearLine => { + self.search.input.clear(); + InputAction::Continue + } + Action::ClearToStart => { + self.search.input.clear_to_start(); + InputAction::Continue + } + Action::ClearToEnd => { + self.search.input.clear_to_end(); + InputAction::Continue + } + + // -- List navigation (invert-aware) -- + Action::SelectNext => { + if settings.invert { + self.scroll_up(1); + } else { + self.scroll_down(1); + } + InputAction::Continue + } + Action::SelectPrevious => { + if settings.invert { + self.scroll_down(1); + } else { + self.scroll_up(1); + } + InputAction::Continue + } + // -- Page/half-page scroll (invert-aware) -- + Action::ScrollHalfPageUp => { + let scroll_len = self + .results_state + .max_entries() + .saturating_sub(settings.scroll_context_lines) + / 2; + if settings.invert { + self.scroll_down(scroll_len); + } else { + self.scroll_up(scroll_len); + } + InputAction::Continue + } + Action::ScrollHalfPageDown => { + let scroll_len = self + .results_state + .max_entries() + .saturating_sub(settings.scroll_context_lines) + / 2; + if settings.invert { + self.scroll_up(scroll_len); + } else { + self.scroll_down(scroll_len); + } + InputAction::Continue + } + Action::ScrollPageUp => { + let scroll_len = self + .results_state + .max_entries() + .saturating_sub(settings.scroll_context_lines); + if settings.invert { + self.scroll_down(scroll_len); + } else { + self.scroll_up(scroll_len); + } + InputAction::Continue + } + Action::ScrollPageDown => { + let scroll_len = self + .results_state + .max_entries() + .saturating_sub(settings.scroll_context_lines); + if settings.invert { + self.scroll_up(scroll_len); + } else { + self.scroll_down(scroll_len); + } + InputAction::Continue + } + + // -- Absolute jumps (invert-aware) -- + Action::ScrollToTop => { + // Visual top of history + if settings.invert { + self.results_state.select(0); + } else { + let last_idx = self.results_len.saturating_sub(1); + self.results_state.select(last_idx); + } + self.inspecting_state.reset(); + InputAction::Continue + } + Action::ScrollToBottom => { + // Visual bottom of history + if settings.invert { + let last_idx = self.results_len.saturating_sub(1); + self.results_state.select(last_idx); + } else { + self.results_state.select(0); + } + self.inspecting_state.reset(); + InputAction::Continue + } + Action::ScrollToScreenTop => { + // H — jump to top of visible screen + let top = self.results_state.offset(); + let visible = self.results_state.max_entries().min(self.results_len); + let bottom = top + visible.saturating_sub(1); + self.results_state + .select(bottom.min(self.results_len.saturating_sub(1))); + self.inspecting_state.reset(); + InputAction::Continue + } + Action::ScrollToScreenMiddle => { + // M — jump to middle of visible screen + let top = self.results_state.offset(); + let visible = self.results_state.max_entries().min(self.results_len); + let middle = top + visible / 2; + self.results_state + .select(middle.min(self.results_len.saturating_sub(1))); + self.inspecting_state.reset(); + InputAction::Continue + } + Action::ScrollToScreenBottom => { + // L — jump to bottom of visible screen + let top_visible = self.results_state.offset(); + self.results_state.select(top_visible); + self.inspecting_state.reset(); + InputAction::Continue + } + + // -- Commands -- + Action::Accept => { + if self.tab_index == 1 { + return InputAction::AcceptInspecting; + } + self.accept = true; + InputAction::Accept(self.results_state.selected()) + } + Action::AcceptNth(n) => { + self.accept = true; + InputAction::Accept(self.results_state.selected() + *n as usize) + } + Action::ReturnSelection => { + if self.tab_index == 1 { + return InputAction::AcceptInspecting; + } + InputAction::Accept(self.results_state.selected()) + } + Action::ReturnSelectionNth(n) => { + InputAction::Accept(self.results_state.selected() + *n as usize) + } + Action::Copy => InputAction::Copy(self.results_state.selected()), + Action::Delete => InputAction::Delete(self.results_state.selected()), + Action::DeleteAll => InputAction::DeleteAllMatching(self.results_state.selected()), + Action::ReturnOriginal => InputAction::ReturnOriginal, + Action::ReturnQuery => InputAction::ReturnQuery, + Action::Exit => Self::handle_key_exit(settings), + Action::Redraw => InputAction::Redraw, + Action::CycleFilterMode => { + self.search.rotate_filter_mode(settings, 1); + InputAction::Continue + } + Action::CycleSearchMode => { + self.switched_search_mode = true; + self.search_mode = self.search_mode.next(settings); + self.engine = engines::engine(self.search_mode, settings); + InputAction::Continue + } + Action::SwitchContext => { + InputAction::SwitchContext(Some(self.results_state.selected())) + } + Action::ClearContext => InputAction::SwitchContext(None), + Action::ToggleTab => { + self.tab_index = (self.tab_index + 1) % TAB_TITLES.len(); + InputAction::Continue + } + + // -- Mode changes -- + Action::VimEnterNormal => { + self.set_keymap_cursor(settings, "vim_normal"); + self.keymap_mode = KeymapMode::VimNormal; + InputAction::Continue + } + Action::VimEnterInsert => { + self.set_keymap_cursor(settings, "vim_insert"); + self.keymap_mode = KeymapMode::VimInsert; + InputAction::Continue + } + Action::VimEnterInsertAfter => { + self.search.input.right(); + self.set_keymap_cursor(settings, "vim_insert"); + self.keymap_mode = KeymapMode::VimInsert; + InputAction::Continue + } + Action::VimEnterInsertAtStart => { + self.search.input.start(); + self.set_keymap_cursor(settings, "vim_insert"); + self.keymap_mode = KeymapMode::VimInsert; + InputAction::Continue + } + Action::VimEnterInsertAtEnd => { + self.search.input.end(); + self.set_keymap_cursor(settings, "vim_insert"); + self.keymap_mode = KeymapMode::VimInsert; + InputAction::Continue + } + Action::VimSearchInsert => { + self.search.input.clear(); + self.set_keymap_cursor(settings, "vim_insert"); + self.keymap_mode = KeymapMode::VimInsert; + InputAction::Continue + } + Action::VimChangeToEnd => { + self.search.input.clear_to_end(); + self.set_keymap_cursor(settings, "vim_insert"); + self.keymap_mode = KeymapMode::VimInsert; + InputAction::Continue + } + Action::EnterPrefixMode => { + self.prefix = true; + InputAction::Continue + } + + // -- Inspector -- + Action::InspectPrevious => { + self.inspecting_state.move_to_previous(); + InputAction::Redraw + } + Action::InspectNext => { + self.inspecting_state.move_to_next(); + InputAction::Redraw + } + + // -- Special -- + Action::Noop => InputAction::Continue, + } + } + + #[expect(clippy::cast_possible_truncation)] + #[expect(clippy::bool_to_int_with_if)] + fn calc_preview_height( + settings: &Settings, + results: &[History], + selected: usize, + tab_index: usize, + compactness: Compactness, + border_size: u16, + preview_width: u16, + ) -> u16 { + if settings.show_preview + && settings.preview.strategy == PreviewStrategy::Auto + && tab_index == 0 + && !results.is_empty() + { + let length_current_cmd = results[selected].command.len() as u16; + // calculate the number of newlines in the command + let num_newlines = results[selected] + .command + .chars() + .filter(|&c| c == '\n') + .count() as u16; + if num_newlines > 0 { + std::cmp::min( + settings.max_preview_height, + results[selected] + .command + .split('\n') + .map(|line| { + (line.len() as u16 + preview_width - 1 - border_size) + / (preview_width - border_size) + }) + .sum(), + ) + border_size * 2 + } + // The '- 19' takes the characters before the command (duration and time) into account + else if length_current_cmd > preview_width - 19 { + std::cmp::min( + settings.max_preview_height, + (length_current_cmd + preview_width - 1 - border_size) + / (preview_width - border_size), + ) + border_size * 2 + } else { + 1 + } + } else if settings.show_preview + && settings.preview.strategy == PreviewStrategy::Static + && tab_index == 0 + { + let longest_command = results + .iter() + .max_by(|h1, h2| h1.command.len().cmp(&h2.command.len())); + longest_command.map_or(0, |v| { + std::cmp::min( + settings.max_preview_height, + v.command + .split('\n') + .map(|line| { + (line.len() as u16 + preview_width - 1 - border_size) + / (preview_width - border_size) + }) + .sum(), + ) + }) + border_size * 2 + } else if settings.show_preview && settings.preview.strategy == PreviewStrategy::Fixed { + settings.max_preview_height + border_size * 2 + } else if !matches!(compactness, Compactness::Full) || tab_index == 1 { + 0 + } else { + 1 + } + } + + #[expect(clippy::bool_to_int_with_if)] + #[expect(clippy::too_many_lines)] + #[expect(clippy::too_many_arguments)] + fn draw( + &mut self, + f: &mut Frame, + results: &[History], + stats: Option, + inspecting: Option<&History>, + settings: &Settings, + theme: &Theme, + popup_mode: bool, + ) { + let area = f.area(); + if popup_mode { + f.render_widget(Clear, area); + } + self.draw_inner(f, area, results, stats, inspecting, settings, theme); + } + + #[expect(clippy::too_many_arguments)] + #[expect(clippy::too_many_lines)] + #[expect(clippy::bool_to_int_with_if)] + fn draw_inner( + &mut self, + f: &mut Frame, + area: Rect, + results: &[History], + stats: Option, + inspecting: Option<&History>, + settings: &Settings, + theme: &Theme, + ) { + let compactness = to_compactness(f, settings); + let invert = settings.invert; + let border_size = match compactness { + Compactness::Full => 1, + _ => 0, + }; + let preview_width = area.width.saturating_sub(2); + let preview_height = Self::calc_preview_height( + settings, + results, + self.results_state.selected(), + self.tab_index, + compactness, + border_size, + preview_width, + ); + let show_help = + settings.show_help && (matches!(compactness, Compactness::Full) || area.height > 1); + // This is an OR, as it seems more likely for someone to wish to override + // tabs unexpectedly being missed, than unexpectedly present. + let show_tabs = settings.show_tabs && !matches!(compactness, Compactness::Ultracompact); + let chunks = Layout::default() + .direction(Direction::Vertical) + .margin(0) + .horizontal_margin(1) + .constraints::<&[Constraint]>( + if invert { + [ + Constraint::Length(1 + border_size), // input + Constraint::Min(1), // results list + Constraint::Length(preview_height), // preview + Constraint::Length(if show_tabs { 1 } else { 0 }), // tabs + Constraint::Length(if show_help { 1 } else { 0 }), // header (sic) + ] + } else { + match compactness { + Compactness::Ultracompact => [ + Constraint::Length(if show_help { 1 } else { 0 }), // header + Constraint::Length(0), // tabs + Constraint::Min(1), // results list + Constraint::Length(0), + Constraint::Length(0), + ], + _ => [ + Constraint::Length(if show_help { 1 } else { 0 }), // header + Constraint::Length(if show_tabs { 1 } else { 0 }), // tabs + Constraint::Min(1), // results list + Constraint::Length(1 + border_size), // input + Constraint::Length(preview_height), // preview + ], + } + } + .as_ref(), + ) + .split(area); + + let input_chunk = if invert { chunks[0] } else { chunks[3] }; + let results_list_chunk = if invert { chunks[1] } else { chunks[2] }; + let preview_chunk = if invert { chunks[2] } else { chunks[4] }; + let tabs_chunk = if invert { chunks[3] } else { chunks[1] }; + let header_chunk = if invert { chunks[4] } else { chunks[0] }; + + // TODO: this should be split so that we have one interactive search container that is + // EITHER a search box or an inspector. But I'm not doing that now, way too much atm. + // also allocate less 🙈 + let titles: Vec<_> = TAB_TITLES.iter().copied().map(Line::from).collect(); + + if show_tabs { + let tabs = Tabs::new(titles) + .block(Block::default().borders(Borders::NONE)) + .select(self.tab_index) + .style(Style::default()) + .highlight_style(Style::from_crossterm(theme.as_style(Meaning::Important))); + + f.render_widget(tabs, tabs_chunk); + } + + let style = StyleState { + compactness, + invert, + inner_width: input_chunk.width.into(), + }; + + let header_chunks = Layout::default() + .direction(Direction::Horizontal) + .constraints::<&[Constraint]>( + [ + Constraint::Ratio(1, 5), + Constraint::Ratio(3, 5), + Constraint::Ratio(1, 5), + ] + .as_ref(), + ) + .split(header_chunk); + + let title = Self::build_title(theme); + f.render_widget(title, header_chunks[0]); + + let help = self.build_help(settings, theme); + f.render_widget(help, header_chunks[1]); + + let stats_tab = self.build_stats(theme); + f.render_widget(stats_tab, header_chunks[2]); + + let indicator: String = match compactness { + Compactness::Ultracompact => { + if self.switched_search_mode { + format!("S{}>", self.search_mode.as_str().chars().next().unwrap()) + } else if self.search.custom_context.is_some() { + format!( + "C{}>", + self.search.filter_mode.as_str().chars().next().unwrap() + ) + } else { + format!( + "{}> ", + self.search.filter_mode.as_str().chars().next().unwrap() + ) + } + } + _ => " > ".to_string(), + }; + + match self.tab_index { + 0 => { + let history_highlighter = HistoryHighlighter { + engine: self.engine.as_ref(), + search_input: self.search.input.as_str(), + }; + let results_list = Self::build_results_list( + style, + results, + self.keymap_mode, + &self.now, + indicator.as_str(), + theme, + history_highlighter, + settings.show_numeric_shortcuts, + &settings.ui.columns, + ); + f.render_stateful_widget(results_list, results_list_chunk, &mut self.results_state); + } + + 1 => { + if results.is_empty() { + let message = Paragraph::new("Nothing to inspect") + .block( + Block::new() + .title(Line::from(" Info ".to_string())) + .title_alignment(Alignment::Center) + .borders(Borders::ALL) + .padding(Padding::vertical(2)), + ) + .alignment(Alignment::Center); + f.render_widget(message, results_list_chunk); + } else { + let inspecting = match inspecting { + Some(inspecting) => inspecting, + None => &results[self.results_state.selected()], + }; + super::inspector::draw( + f, + results_list_chunk, + inspecting, + &stats.expect("Drawing inspector, but no stats"), + settings, + theme, + settings.timezone, + ); + } + + // HACK: I'm following up with abstracting this into the UI container, with a + // sub-widget for search + for inspector + let feedback = Paragraph::new( + "The inspector is new - please give feedback (good, or bad) at https://forum.atuin.sh", + ); + f.render_widget(feedback, input_chunk); + + return; + } + + _ => { + panic!("invalid tab index"); + } + } + + if !matches!(compactness, Compactness::Ultracompact) { + let preview_width = match compactness { + Compactness::Full => preview_width - 2, + _ => preview_width, + }; + let preview = self.build_preview( + results, + compactness, + preview_width, + preview_chunk.width.into(), + theme, + ); + #[expect(clippy::cast_possible_truncation)] + let prefix_width = settings + .ui + .columns + .iter() + .take_while(|col| !col.expand) + .map(|col| col.width + 1) + .sum::() + + " > ".len() as u16; + #[expect(clippy::cast_possible_truncation)] + let min_prefix_width = "[ SRCH: FULLTXT ] ".len() as u16; + self.draw_preview( + f, + style, + input_chunk, + compactness, + preview_chunk, + preview, + std::cmp::max(prefix_width, min_prefix_width), + ); + } + } + + #[expect(clippy::cast_possible_truncation, clippy::too_many_arguments)] + fn draw_preview( + &self, + f: &mut Frame, + style: StyleState, + input_chunk: Rect, + compactness: Compactness, + preview_chunk: Rect, + preview: Paragraph, + prefix_width: u16, + ) { + let input = self.build_input(style, prefix_width); + f.render_widget(input, input_chunk); + + f.render_widget(preview, preview_chunk); + + let extra_width = UnicodeWidthStr::width(self.search.input.substring()); + + let cursor_offset = match compactness { + Compactness::Full => 1, + _ => 0, + }; + f.set_cursor_position(( + // Put cursor past the end of the input text + input_chunk.x + extra_width as u16 + prefix_width + cursor_offset, + input_chunk.y + cursor_offset, + )); + } + + fn build_title(theme: &Theme) -> Paragraph<'_> { + let title = { + let style: Style = Style::from_crossterm(theme.as_style(Meaning::Base)); + Paragraph::new(Text::from(Span::styled( + format!("Atuin v{VERSION}"), + style.add_modifier(Modifier::BOLD), + ))) + }; + title.alignment(Alignment::Left) + } + + #[expect(clippy::unused_self)] + fn build_help(&self, settings: &Settings, theme: &Theme) -> Paragraph<'_> { + match self.tab_index { + // search + 0 => Paragraph::new(Text::from(Line::from(vec![ + Span::styled("", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(": exit"), + Span::raw(", "), + Span::styled("", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(": edit"), + Span::raw(", "), + Span::styled("", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(if settings.enter_accept { + ": run" + } else { + ": edit" + }), + Span::raw(", "), + Span::styled("", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(": inspect"), + ]))), + + 1 => Paragraph::new(Text::from(Line::from(vec![ + Span::styled("", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(": exit"), + Span::raw(", "), + Span::styled("", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(": search"), + Span::raw(", "), + Span::styled("", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(": delete"), + ]))), + + _ => unreachable!("invalid tab index"), + } + .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))) + .alignment(Alignment::Center) + } + + fn build_stats(&self, theme: &Theme) -> Paragraph<'_> { + Paragraph::new(Text::from(Span::raw(format!( + "history count: {}", + self.history_count, + )))) + .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))) + .alignment(Alignment::Right) + } + + #[expect(clippy::too_many_arguments)] + fn build_results_list<'a>( + style: StyleState, + results: &'a [History], + keymap_mode: KeymapMode, + now: &'a dyn Fn() -> OffsetDateTime, + indicator: &'a str, + theme: &'a Theme, + history_highlighter: HistoryHighlighter<'a>, + show_numeric_shortcuts: bool, + columns: &'a [UiColumn], + ) -> HistoryList<'a> { + let results_list = HistoryList::new( + results, + style.invert, + keymap_mode == KeymapMode::VimNormal, + now, + indicator, + theme, + history_highlighter, + show_numeric_shortcuts, + columns, + ); + + match style.compactness { + Compactness::Full => { + if style.invert { + results_list.block( + Block::default() + .borders(Borders::LEFT | Borders::RIGHT) + .border_type(BorderType::Rounded) + .title(format!("{:─>width$}", "", width = style.inner_width - 2)), + ) + } else { + results_list.block( + Block::default() + .borders(Borders::TOP | Borders::LEFT | Borders::RIGHT) + .border_type(BorderType::Rounded), + ) + } + } + _ => results_list, + } + } + + fn build_input(&self, style: StyleState, prefix_width: u16) -> Paragraph<'_> { + let (pref, mode) = if self.switched_search_mode { + (" SRCH:", self.search_mode.as_str()) + } else if self.search.custom_context.is_some() { + (" CTX:", self.search.filter_mode.as_str()) + } else { + ("", self.search.filter_mode.as_str()) + }; + // 3: surrounding "[" "] " + let mode_width = usize::from(prefix_width) - pref.len() - 3; + // sanity check to ensure we don't exceed the layout limits + debug_assert!(mode_width >= mode.len(), "mode name '{mode}' is too long!"); + let input = format!("[{pref}{mode:^mode_width$}] {}", self.search.input.as_str()); + let input = Paragraph::new(input); + match style.compactness { + Compactness::Full => { + if style.invert { + input.block( + Block::default() + .borders(Borders::LEFT | Borders::RIGHT | Borders::TOP) + .border_type(BorderType::Rounded), + ) + } else { + input.block( + Block::default() + .borders(Borders::LEFT | Borders::RIGHT) + .border_type(BorderType::Rounded) + .title(format!("{:─>width$}", "", width = style.inner_width - 2)), + ) + } + } + _ => input, + } + } + + fn build_preview( + &self, + results: &[History], + compactness: Compactness, + preview_width: u16, + chunk_width: usize, + theme: &Theme, + ) -> Paragraph<'_> { + let selected = self.results_state.selected(); + let command = if results.is_empty() { + String::new() + } else { + let s = &results[selected].command; + let mut lines = Vec::new(); + for line in s.split('\n') { + let line = line.escape_control(); + let mut width = 0; + let mut start = 0; + for (idx, ch) in line.char_indices() { + let w = ch.width().unwrap_or(0); // None for control chars which should not happen + if width + w > preview_width.into() { + lines.push(line[start..idx].to_owned()); + start = idx; + width = w; + } else { + width += w; + } + } + if width != 0 { + lines.push(line[start..].to_owned()); + } + } + lines.join("\n") + }; + + match compactness { + Compactness::Full => Paragraph::new(command).block( + Block::default() + .borders(Borders::BOTTOM | Borders::LEFT | Borders::RIGHT) + .border_type(BorderType::Rounded) + .title(format!("{:─>width$}", "", width = chunk_width - 2)), + ), + _ => Paragraph::new(command) + .style(Style::from_crossterm(theme.as_style(Meaning::Annotation))), + } + } +} + +/// The writer used for terminal output - either stdout or /dev/tty +enum TerminalWriter { + Stdout(std::io::Stdout), + #[cfg(unix)] + Tty(std::fs::File), +} + +impl TerminalWriter { + fn new() -> std::io::Result { + let stdout = stdout(); + if stdout.is_terminal() { + return Ok(TerminalWriter::Stdout(stdout)); + } + + // If stdout is not a terminal (e.g., captured by command substitution), + // fall back to /dev/tty so the TUI can still render. + // This allows usage like: VAR=$(atuin search -i) + #[cfg(unix)] + { + Ok(TerminalWriter::Tty( + std::fs::File::options() + .read(true) + .write(true) + .open("/dev/tty")?, + )) + } + } +} + +impl Write for TerminalWriter { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + match self { + TerminalWriter::Stdout(stdout) => stdout.write(buf), + #[cfg(unix)] + TerminalWriter::Tty(file) => file.write(buf), + } + } + + fn flush(&mut self) -> std::io::Result<()> { + match self { + TerminalWriter::Stdout(stdout) => stdout.flush(), + #[cfg(unix)] + TerminalWriter::Tty(file) => file.flush(), + } + } +} + +/// Screen state captured from atuin pty-proxy's screen server. +#[cfg(unix)] +struct SavedScreen { + #[expect(dead_code)] + rows: u16, + #[expect(dead_code)] + cols: u16, + cursor_row: u16, + cursor_col: u16, + /// Pre-formatted ANSI bytes for each screen row, ready to write to stdout. + rows_data: Vec>, +} + +/// Connect to atuin pty-proxy's Unix socket and fetch the current screen state. +/// +/// The wire format is: +/// ```text +/// [rows: u16 BE][cols: u16 BE][cursor_row: u16 BE][cursor_col: u16 BE] +/// [row_0_len: u32 BE][row_0_bytes...] +/// [row_1_len: u32 BE][row_1_bytes...] +/// ... +/// ``` +#[cfg(unix)] +fn fetch_screen_state(socket_path: &str) -> Option { + use std::os::unix::net::UnixStream; + + let mut stream = UnixStream::connect(socket_path).ok()?; + stream.set_read_timeout(Some(Duration::from_secs(2))).ok()?; + + let mut data = Vec::new(); + stream.read_to_end(&mut data).ok()?; + + if data.len() < 8 { + return None; + } + + let rows = u16::from_be_bytes([data[0], data[1]]); + let cols = u16::from_be_bytes([data[2], data[3]]); + let cursor_row = u16::from_be_bytes([data[4], data[5]]); + let cursor_col = u16::from_be_bytes([data[6], data[7]]); + + // Parse length-prefixed rows + let mut rows_data = Vec::with_capacity(rows as usize); + let mut offset = 8; + while offset + 4 <= data.len() { + let row_len = u32::from_be_bytes([ + data[offset], + data[offset + 1], + data[offset + 2], + data[offset + 3], + ]) as usize; + offset += 4; + if offset + row_len > data.len() { + break; + } + rows_data.push(data[offset..offset + row_len].to_vec()); + offset += row_len; + } + + Some(SavedScreen { + rows, + cols, + cursor_row, + cursor_col, + rows_data, + }) +} + +/// Restore the screen area that was covered by the popup. +/// +/// Writes the pre-formatted per-row ANSI bytes received from atuin pty-proxy +/// directly to stdout, which correctly handles wide characters, colors, and +/// all text attributes without needing a client-side vt100 parser. +#[cfg(unix)] +fn restore_popup_area(saved: &SavedScreen, popup_rect: Rect, scroll_offset: u16) { + use ratatui::crossterm::cursor::MoveTo; + + let mut stdout = stdout(); + + for dy in 0..popup_rect.height { + let target_row = popup_rect.y + dy; + let source_row = (target_row + scroll_offset) as usize; + + // Clear only the popup region. The server-side rows_formatted() skips + // default cells (spaces with default attributes) using cursor jumps, so + // any popup content at those positions would remain if not cleared + // beforehand. We write `popup_rect.width` spaces instead of + // ClearType::CurrentLine so that only the popup area is cleared, not + // the entire terminal line. + let _ = execute!( + stdout, + MoveTo(popup_rect.x, target_row), + ratatui::crossterm::style::SetAttribute(ratatui::crossterm::style::Attribute::Reset), + ); + let _ = write!(stdout, "{:width$}", "", width = popup_rect.width as usize); + let _ = execute!(stdout, MoveTo(popup_rect.x, target_row)); + + if let Some(row_bytes) = saved.rows_data.get(source_row) { + let _ = stdout.write_all(row_bytes); + } + } + + let _ = execute!( + stdout, + MoveTo( + saved.cursor_col, + saved.cursor_row.saturating_sub(scroll_offset) + ) + ); + let _ = stdout.flush(); +} + +struct Stdout { + writer: TerminalWriter, + inline_mode: bool, + no_mouse: bool, +} + +impl Stdout { + pub fn new(inline_mode: bool, no_mouse: bool) -> std::io::Result { + terminal::enable_raw_mode()?; + + let mut writer = TerminalWriter::new()?; + + if !inline_mode { + execute!(writer, terminal::EnterAlternateScreen)?; + } + + if !no_mouse { + execute!(writer, event::EnableMouseCapture)?; + } + + execute!(writer, event::EnableBracketedPaste)?; + + #[cfg(not(target_os = "windows"))] + execute!( + writer, + PushKeyboardEnhancementFlags( + KeyboardEnhancementFlags::DISAMBIGUATE_ESCAPE_CODES + | KeyboardEnhancementFlags::REPORT_ALL_KEYS_AS_ESCAPE_CODES + | KeyboardEnhancementFlags::REPORT_ALTERNATE_KEYS + ), + )?; + + Ok(Self { + writer, + inline_mode, + no_mouse, + }) + } +} + +impl Drop for Stdout { + fn drop(&mut self) { + #[cfg(not(target_os = "windows"))] + if let Err(e) = execute!(self.writer, PopKeyboardEnhancementFlags) { + tracing::error!(?e, "Failed to pop keyboard enhancement flags"); + } + + if !self.inline_mode + && let Err(e) = execute!(self.writer, terminal::LeaveAlternateScreen) + { + tracing::error!(?e, "Failed to leave alt screen mode"); + } + + if !self.no_mouse + && let Err(e) = execute!(self.writer, event::DisableMouseCapture) + { + tracing::error!(?e, "Failed to disable mouse capture"); + } + + if let Err(e) = execute!(self.writer, event::DisableBracketedPaste) { + tracing::error!(?e, "Failed to disable bracketed paste"); + } + + if let Err(e) = terminal::disable_raw_mode() { + tracing::error!(?e, "Failed to disable raw mode"); + } + } +} + +impl Write for Stdout { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.writer.write(buf) + } + + fn flush(&mut self) -> std::io::Result<()> { + self.writer.flush() + } +} + +// this is a big blob of horrible! clean it up! +/// Compute the popup position and any scroll offset needed to make room. +/// +/// Given the cursor row, terminal dimensions, and desired popup height, +/// returns `(popup_rect, scroll_offset)` where `scroll_offset` is the number +/// of lines the caller should scroll the terminal up before rendering. +/// +/// This function performs no I/O — it is a pure computation. +#[cfg(unix)] +fn compute_popup_placement( + cursor_row: u16, + term_rows: u16, + term_cols: u16, + inline_height: u16, +) -> (Rect, u16) { + let popup_w = term_cols; + let popup_h = inline_height.min(term_rows); + let space_below = term_rows.saturating_sub(cursor_row); + + let (popup_y, scroll) = if popup_h <= space_below { + // Fits below cursor + (cursor_row, 0u16) + } else if cursor_row >= term_rows / 2 { + // Bottom half — render above cursor (overlay on existing text) + (cursor_row.saturating_sub(popup_h), 0u16) + } else { + // Top half, not enough space — scroll terminal to make room + let scroll = popup_h.saturating_sub(space_below); + let popup_y = cursor_row.saturating_sub(scroll); + (popup_y, scroll) + }; + + (Rect::new(0, popup_y, popup_w, popup_h), scroll) +} + +// for now, it works. But it'd be great if it were more easily readable, and +// modular. I'd like to add some more stats and stuff at some point +#[expect( + clippy::cast_possible_truncation, + clippy::too_many_lines, + clippy::cognitive_complexity +)] +pub async fn history( + query: &[String], + settings: &Settings, + mut db: impl Database, + history_store: &HistoryStore, + theme: &Theme, +) -> Result { + let inline_height = if settings.shell_up_key_binding { + settings + .inline_height_shell_up_key_binding + .unwrap_or(settings.inline_height) + } else { + settings.inline_height + }; + + // Use fullscreen mode if the inline height doesn't fit in the terminal, + // this will preserve the scroll position upon exit. + // Also force fullscreen when stdout isn't a terminal (e.g., command substitution + // like VAR=$(atuin search -i)). In that case, we need to use /dev/tty for the TUI and force + // fullscreen mode (inline mode won't work as it requires cursor position queries + // that don't work when stdout is captured). + let inline_height = if !stdout().is_terminal() { + 0 + } else if let Ok(size) = terminal::size() + && inline_height >= size.1 + { + 0 + } else { + inline_height + }; + + // Popup mode: if running under atuin pty-proxy and inline mode is requested, + // fetch the screen state and render as a centered overlay. + #[cfg(unix)] + let (saved_screen, popup_rect, popup_scroll_offset) = { + let socket_path = std::env::var("ATUIN_PTY_PROXY_SOCKET") + .or_else(|_| std::env::var("ATUIN_HEX_SOCKET")) + .ok(); + if let Some(ref path) = socket_path + && inline_height > 0 + { + let saved = fetch_screen_state(path); + if let Some(ref s) = saved { + let (term_cols, term_rows) = terminal::size().unwrap_or((s.cols, s.rows)); + let (popup_rect, scroll) = + compute_popup_placement(s.cursor_row, term_rows, term_cols, inline_height); + + // Scroll terminal content up to make room if needed + if scroll > 0 { + use ratatui::crossterm::cursor::MoveTo; + let mut stdout = stdout(); + let _ = execute!(stdout, MoveTo(0, term_rows - 1)); + for _ in 0..scroll { + let _ = writeln!(stdout); + } + let _ = stdout.flush(); + } + + (saved, popup_rect, scroll) + } else { + (None, Rect::default(), 0u16) + } + } else { + (None, Rect::default(), 0u16) + } + }; + + let popup_mode = saved_screen.is_some(); + + let stdout = Stdout::new(inline_height > 0, settings.no_mouse)?; + + // In popup mode, clear the popup region on the physical terminal before + // ratatui takes over. Ratatui's diff-based rendering compares against an + // initially-empty buffer, so cells that remain "empty" (spaces with default + // style) won't be written — leaving underlying terminal text visible. + // By pre-clearing with spaces, those cells are already correct on screen. + if popup_mode { + use ratatui::crossterm::cursor::MoveTo; + let mut raw_stdout = std::io::stdout(); + // Queue all commands without flushing so the terminal receives them + // as a single write — no intermediate cursor positions are visible. + let _ = queue!( + raw_stdout, + ratatui::crossterm::style::SetAttribute(ratatui::crossterm::style::Attribute::Reset) + ); + for row in popup_rect.y..popup_rect.y.saturating_add(popup_rect.height) { + let _ = queue!(raw_stdout, MoveTo(popup_rect.x, row)); + let _ = write!( + raw_stdout, + "{:width$}", + "", + width = popup_rect.width as usize + ); + } + let _ = raw_stdout.flush(); + } + + let backend = CrosstermBackend::new(stdout); + let mut terminal = Terminal::with_options( + backend, + TerminalOptions { + viewport: if popup_mode { + Viewport::Fixed(popup_rect) + } else if inline_height > 0 { + Viewport::Inline(inline_height) + } else { + Viewport::Fullscreen + }, + }, + )?; + + let original_query = query.join(" "); + + // Check if this is a command chaining scenario + let is_command_chaining = if settings.command_chaining { + let trimmed = original_query.trim_end(); + trimmed.ends_with("&&") || trimmed.ends_with('|') + } else { + false + }; + + // For command chaining, start with empty input to allow searching for new commands + let search_input = if is_command_chaining { + String::new() + } else { + original_query.clone() + }; + + let mut input = Cursor::from(search_input); + // Put the cursor at the end of the query by default + input.end(); + + let initial_context = current_context().await?; + + let history_count = db.history_count(false).await?; + let search_mode = if settings.shell_up_key_binding { + settings + .search_mode_shell_up_key_binding + .unwrap_or(settings.search_mode) + } else { + settings.search_mode + }; + let default_filter_mode = settings + .filter_mode_shell_up_key_binding + .filter(|_| settings.shell_up_key_binding) + .unwrap_or_else(|| settings.default_filter_mode(initial_context.git_root.is_some())); + let mut app = State { + history_count, + results_state: ListState::default(), + switched_search_mode: false, + search_mode, + tab_index: 0, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::from_settings(settings), + search: SearchState { + input, + filter_mode: default_filter_mode, + context: initial_context.clone(), + custom_context: None, + }, + engine: engines::engine(search_mode, settings), + results_len: 0, + accept: false, + keymap_mode: match settings.keymap_mode { + KeymapMode::Auto => KeymapMode::Emacs, + value => value, + }, + current_cursor: None, + now: if settings.prefers_reduced_motion { + let now = OffsetDateTime::now_utc(); + Box::new(move || now) + } else { + Box::new(OffsetDateTime::now_utc) + }, + prefix: false, + pending_vim_key: None, + original_input_empty: original_query.is_empty(), + }; + + app.initialize_keymap_cursor(settings); + + let mut results = app.query_results(&mut db, settings.smart_sort).await?; + + if inline_height > 0 && !popup_mode { + terminal.clear()?; + } + + let mut stats: Option = None; + let mut inspecting: Option = None; + let accept; + let result = 'render: loop { + terminal.draw(|f| { + app.draw( + f, + &results, + stats.clone(), + inspecting.as_ref(), + settings, + theme, + popup_mode, + ); + })?; + + let initial_input = app.search.input.as_str().to_owned(); + let initial_filter_mode = app.search.filter_mode; + let initial_search_mode = app.search_mode; + let initial_custom_context = app.search.custom_context.clone(); + + let event_ready = tokio::task::spawn_blocking(|| event::poll(Duration::from_millis(250))); + + tokio::select! { + event_ready = event_ready => { + if event_ready?? { + loop { + match app.handle_input(settings, &event::read()?) { + InputAction::Continue => {}, + InputAction::Delete(index) => { + if results.is_empty() { + break; + } + app.results_len -= 1; + let selected = app.results_state.selected(); + if selected == app.results_len { + app.inspecting_state.reset(); + app.results_state.select(selected - 1); + } + + let entry = results.remove(index); + + let ids = history_store.delete_entries([entry]).await?; + history_store.incremental_build(&db, &ids).await?; + + app.tab_index = 0; + }, + InputAction::DeleteAllMatching(index) => { + if results.is_empty() { + break; + } + + let command = results[index].command.clone(); + + // Remove matching entries from the visible results + results.retain(|e| e.command != command); + + // Query the DB for ALL entries with this command and delete them + let all_matching = db.query_history( + &format!( + "select * from history where command = '{}' and deleted_at is null", + command.replace('\'', "''") + ) + ).await?; + + let ids = history_store.delete_entries(all_matching).await?; + history_store.incremental_build(&db, &ids).await?; + + app.results_len = results.len(); + app.results_state = ListState::default(); + app.inspecting_state.reset(); + app.tab_index = 0; + }, + InputAction::SwitchContext(index) => { + if let Some(index) = index && let Some(entry) = results.get(index) { + app.search.custom_context = Some(entry.id.clone()); + app.search.context = Context::from_history(entry); + app.search.filter_mode = FilterMode::Session; + app.search.input = Cursor::from(String::new()); + app.results_state = ListState::default(); + } else { + app.search.custom_context = None; + app.search.context = initial_context.clone(); + app.search.filter_mode = default_filter_mode; + } + }, + InputAction::Redraw => { + if !popup_mode { + terminal.clear()?; + } + terminal.draw(|f| { + app.draw(f, &results, stats.clone(), inspecting.as_ref(), settings, theme, popup_mode); + })?; + }, + r => { + accept = app.accept; + break 'render r; + }, + } + if !event::poll(Duration::ZERO)? { + break; + } + } + } + } + } + + if initial_input != app.search.input.as_str() + || initial_filter_mode != app.search.filter_mode + || initial_search_mode != app.search_mode + || initial_custom_context != app.search.custom_context + { + results = app.query_results(&mut db, settings.smart_sort).await?; + } + + // In custom context mode, when no filter is applied, highlight the entry which was used + // to enter the context when changing modes. This helps to find your way around. + if app.search.custom_context.is_some() + && app.search.input.as_str().is_empty() + && (initial_custom_context != app.search.custom_context + || initial_filter_mode != app.search.filter_mode) + && let Some(history_id) = app.search.custom_context.clone() + && let Some(pos) = results.iter().position(|entry| entry.id == history_id) + { + app.results_state.select(pos); + } + + let inspecting_id = app.inspecting_state.clone().current; + // If inspecting ID is not the current inspecting History, update it. + match inspecting_id { + Some(inspecting_id) => { + if inspecting.is_none() || inspecting_id != inspecting.clone().unwrap().id { + inspecting = db.load(inspecting_id.0.as_str()).await?; + } + } + _ => { + inspecting = None; + } + } + + stats = if app.tab_index == 0 { + None + } else if !results.is_empty() { + // If we have stats, then we can indicate next available IDs. This avoids passing + // around a database object, or a full stats object. + let selected = match inspecting.clone() { + Some(insp) => insp, + None => results[app.results_state.selected()].clone(), + }; + let stats = db.stats(&selected).await?; + app.inspecting_state.current = Some(selected.id); + app.inspecting_state.previous = match stats.previous.clone() { + Some(p) => Some(p.id), + _ => None, + }; + app.inspecting_state.next = match stats.next.clone() { + Some(p) => Some(p.id), + _ => None, + }; + Some(stats) + } else { + None + }; + }; + + app.finalize_keymap_cursor(settings); + + if popup_mode { + // In popup mode, restore the screen area that was covered by the popup. + // This must happen before Stdout is dropped (which disables raw mode). + #[cfg(unix)] + if let Some(ref saved) = saved_screen { + restore_popup_area(saved, popup_rect, popup_scroll_offset); + } + } else if inline_height > 0 { + terminal.clear()?; + } + + let accept = accept + && matches!( + Shell::from_env(), + Shell::Zsh | Shell::Fish | Shell::Bash | Shell::Xonsh | Shell::Nu | Shell::Powershell + ); + + let accept_prefix = "__atuin_accept__:"; + + match result { + InputAction::AcceptInspecting => { + match inspecting { + Some(result) => { + let mut command = result.command; + + if accept { + command = String::from(accept_prefix) + &command; + } + + // index is in bounds so we return that entry + Ok(command) + } + None => Ok(String::new()), + } + } + InputAction::Accept(index) if index < results.len() => { + let mut command = results.swap_remove(index).command; + + if is_command_chaining { + command = format!("{} {}", original_query.trim_end(), command); + } else if accept { + command = String::from(accept_prefix) + &command; + } + + // index is in bounds so we return that entry + Ok(command) + } + InputAction::ReturnOriginal => Ok(String::new()), + InputAction::Copy(index) => { + let cmd = results.swap_remove(index).command; + set_clipboard(cmd); + Ok(String::new()) + } + InputAction::ReturnQuery | InputAction::Accept(_) => { + // Either: + // * index == RETURN_QUERY, in which case we should return the input + // * out of bounds -> usually implies no selected entry so we return the input + Ok(app.search.input.into_inner()) + } + InputAction::Continue + | InputAction::Redraw + | InputAction::Delete(_) + | InputAction::DeleteAllMatching(_) + | InputAction::SwitchContext(_) => { + unreachable!("should have been handled!") + } + } +} + +// cli-clipboard only works on Windows, Mac, and Linux. + +#[cfg(all( + feature = "clipboard", + any(target_os = "windows", target_os = "macos", target_os = "linux") +))] +fn set_clipboard(s: String) { + let mut ctx = arboard::Clipboard::new().unwrap(); + ctx.set_text(s).unwrap(); + // Use the clipboard context to make sure it is saved + ctx.get_text().unwrap(); +} + +#[cfg(not(all( + feature = "clipboard", + any(target_os = "windows", target_os = "macos", target_os = "linux") +)))] +fn set_clipboard(_s: String) {} + +#[cfg(test)] +mod tests { + use crate::atuin_client::database::Context; + use crate::atuin_client::history::History; + use crate::atuin_client::settings::{ + FilterMode, KeymapMode, Preview, PreviewStrategy, SearchMode, Settings, + }; + use time::OffsetDateTime; + + use crate::command::client::search::engines::{self, SearchState}; + use crate::command::client::search::history_list::ListState; + + use super::{Compactness, InspectingState, KeymapSet, State}; + + #[test] + #[expect(clippy::too_many_lines)] + fn calc_preview_height_test() { + let settings_preview_auto = Settings { + preview: Preview { + strategy: PreviewStrategy::Auto, + }, + show_preview: true, + ..Settings::utc() + }; + + let settings_preview_auto_h2 = Settings { + preview: Preview { + strategy: PreviewStrategy::Auto, + }, + show_preview: true, + max_preview_height: 2, + ..Settings::utc() + }; + + let settings_preview_h4 = Settings { + preview: Preview { + strategy: PreviewStrategy::Static, + }, + show_preview: true, + max_preview_height: 4, + ..Settings::utc() + }; + + let settings_preview_fixed = Settings { + preview: Preview { + strategy: PreviewStrategy::Fixed, + }, + show_preview: true, + max_preview_height: 15, + ..Settings::utc() + }; + + let cmd_60: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("for i in $(seq -w 10); do echo \"item number $i - abcd\"; done") + .cwd("/") + .build() + .into(); + + let cmd_124: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("echo 'Aurea prima sata est aetas, quae vindice nullo, sponte sua, sine lege fidem rectumque colebat. Poena metusque aberant'") + .cwd("/") + .build() + .into(); + + let cmd_200: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("CREATE USER atuin WITH ENCRYPTED PASSWORD 'supersecretpassword'; CREATE DATABASE atuin WITH OWNER = atuin; \\c atuin; REVOKE ALL PRIVILEGES ON SCHEMA public FROM PUBLIC; echo 'All done. 200 characters'") + .cwd("/") + .build() + .into(); + + let results: Vec = vec![cmd_60, cmd_124, cmd_200]; + + // the selected command does not require a preview + let no_preview = State::calc_preview_height( + &settings_preview_auto, + &results, + 0_usize, + 0_usize, + Compactness::Full, + 1, + 80, + ); + // the selected command requires 2 lines + let preview_h2 = State::calc_preview_height( + &settings_preview_auto, + &results, + 1_usize, + 0_usize, + Compactness::Full, + 1, + 80, + ); + // the selected command requires 3 lines + let preview_h3 = State::calc_preview_height( + &settings_preview_auto, + &results, + 2_usize, + 0_usize, + Compactness::Full, + 1, + 80, + ); + // the selected command requires a preview of 1 line (happens when the command is between preview_width-19 and preview_width) + let preview_one_line = State::calc_preview_height( + &settings_preview_auto, + &results, + 0_usize, + 0_usize, + Compactness::Full, + 1, + 66, + ); + // the selected command requires 3 lines, but we have a max preview height limit of 2 + let preview_limit_at_2 = State::calc_preview_height( + &settings_preview_auto_h2, + &results, + 2_usize, + 0_usize, + Compactness::Full, + 1, + 80, + ); + // the longest command requires 3 lines + let preview_static_h3 = State::calc_preview_height( + &settings_preview_h4, + &results, + 1_usize, + 0_usize, + Compactness::Full, + 1, + 80, + ); + // the longest command requires 10 lines, but we have a max preview height limit of 4 + let preview_static_limit_at_4 = State::calc_preview_height( + &settings_preview_h4, + &results, + 1_usize, + 0_usize, + Compactness::Full, + 1, + 20, + ); + // the longest command requires 10 lines, but we have a max preview height of 15 and a fixed preview strategy + let settings_preview_fixed = State::calc_preview_height( + &settings_preview_fixed, + &results, + 1_usize, + 0_usize, + Compactness::Full, + 1, + 20, + ); + + assert_eq!(no_preview, 1); + // 1 * 2 is the space for the border + let border_space = 2; + assert_eq!(preview_h2, 2 + border_space); + assert_eq!(preview_h3, 3 + border_space); + assert_eq!(preview_one_line, 1 + border_space); + assert_eq!(preview_limit_at_2, 2 + border_space); + assert_eq!(preview_static_h3, 3 + border_space); + assert_eq!(preview_static_limit_at_4, 4 + border_space); + assert_eq!(settings_preview_fixed, 15 + border_space); + } + + // Test when there's no results, scrolling up or down doesn't underflow + #[test] + fn state_scroll_up_underflow() { + let settings = Settings::utc(); + let mut state = State { + history_count: 0, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len: 0, + accept: false, + keymap_mode: KeymapMode::Auto, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::defaults(&settings), + search: SearchState { + input: String::new().into(), + filter_mode: FilterMode::Directory, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + + state.scroll_up(1); + state.scroll_down(1); + } + + #[test] + fn test_accept_keybindings() { + use crate::atuin_client::settings::Keys; + use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + + let mut settings = Settings::utc(); + settings.keys = Keys { + scroll_exits: true, + exit_past_line_start: false, + accept_past_line_end: true, + accept_past_line_start: false, + accept_with_backspace: false, + prefix: "a".to_string(), + }; + + let mut state = State { + history_count: 1, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len: 1, + accept: false, + keymap_mode: KeymapMode::Emacs, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::defaults(&settings), + search: SearchState { + input: String::new().into(), + filter_mode: FilterMode::Global, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + + let tab_event = KeyEvent::new(KeyCode::Tab, KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &tab_event); + assert!( + matches!(result, super::InputAction::Accept(_)), + "Tab should always accept" + ); + + // Test left arrow with accept_past_line_start disabled (should continue) + let left_event = KeyEvent::new(KeyCode::Left, KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &left_event); + assert!( + matches!(result, super::InputAction::Continue), + "Left arrow should continue when disabled" + ); + + // Test left arrow with accept_past_line_start enabled (should accept at start of line) + settings.keys.accept_past_line_start = true; + state.keymaps = KeymapSet::defaults(&settings); + let result = state.handle_key_input(&settings, &left_event); + assert!( + matches!(result, super::InputAction::Accept(_)), + "Left arrow should accept at start of line when enabled" + ); + settings.keys.accept_past_line_start = false; + state.keymaps = KeymapSet::defaults(&settings); + + let backspace_event = KeyEvent::new(KeyCode::Backspace, KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &backspace_event); + assert!( + matches!(result, super::InputAction::Continue), + "Backspace should continue when disabled" + ); + + settings.keys.accept_with_backspace = true; + state.keymaps = KeymapSet::defaults(&settings); + let result = state.handle_key_input(&settings, &backspace_event); + assert!( + matches!(result, super::InputAction::Accept(_)), + "Backspace should accept at start of line when enabled" + ); + + state.search.input.insert('t'); + state.search.input.insert('e'); + state.search.input.insert('s'); + state.search.input.insert('t'); + state.search.input.end(); + + let right_event = KeyEvent::new(KeyCode::Right, KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &right_event); + assert!( + matches!(result, super::InputAction::Accept(_)), + "Right arrow should accept at end of line when enabled" + ); + + settings.keys.accept_past_line_start = true; + state.keymaps = KeymapSet::defaults(&settings); + let left_event = KeyEvent::new(KeyCode::Left, KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &left_event); + assert!( + matches!(result, super::InputAction::Continue), + "Left arrow should continue and end of line, even when enabled" + ); + settings.keys.accept_past_line_start = false; + state.keymaps = KeymapSet::defaults(&settings); + + settings.keys.accept_with_backspace = true; + state.keymaps = KeymapSet::defaults(&settings); + let backspace_event = KeyEvent::new(KeyCode::Backspace, KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &backspace_event); + assert!( + matches!(result, super::InputAction::Continue), + "Backspace should continue at end of line, even when enabled" + ); + settings.keys.accept_with_backspace = false; + state.keymaps = KeymapSet::defaults(&settings); + } + + #[test] + fn test_vim_gg_multikey_sequence() { + use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + + let settings = Settings::utc(); + + let mut state = State { + history_count: 100, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len: 100, + accept: false, + keymap_mode: KeymapMode::VimNormal, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::defaults(&settings), + search: SearchState { + input: String::new().into(), + filter_mode: FilterMode::Global, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + + // Start in the middle of the list + state.results_state.select(50); + + // First 'g' should set pending state + let g_event = KeyEvent::new(KeyCode::Char('g'), KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &g_event); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.pending_vim_key, Some('g')); + assert_eq!(state.results_state.selected(), 50); // Position unchanged + + // Second 'g' should jump to end (visual top in non-inverted mode) + let result = state.handle_key_input(&settings, &g_event); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.pending_vim_key, None); + assert_eq!(state.results_state.selected(), 99); // Jumped to last index (visual top) + } + + #[test] + fn test_vim_g_key_clears_on_other_input() { + use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + + let settings = Settings::utc(); + + let mut state = State { + history_count: 100, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len: 100, + accept: false, + keymap_mode: KeymapMode::VimNormal, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::defaults(&settings), + search: SearchState { + input: String::new().into(), + filter_mode: FilterMode::Global, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + + state.results_state.select(50); + + // Press 'g' to set pending state + let g_event = KeyEvent::new(KeyCode::Char('g'), KeyModifiers::NONE); + state.handle_key_input(&settings, &g_event); + assert_eq!(state.pending_vim_key, Some('g')); + + // Press 'j' - should clear pending state + let j_event = KeyEvent::new(KeyCode::Char('j'), KeyModifiers::NONE); + state.handle_key_input(&settings, &j_event); + assert_eq!(state.pending_vim_key, None); + } + + #[test] + fn test_vim_big_g_jump_to_bottom() { + use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + + let settings = Settings::utc(); + + let mut state = State { + history_count: 100, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len: 100, + accept: false, + keymap_mode: KeymapMode::VimNormal, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::defaults(&settings), + search: SearchState { + input: String::new().into(), + filter_mode: FilterMode::Global, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + + state.results_state.select(50); + + // 'G' should jump to visual bottom (index 0 in non-inverted mode) + let big_g_event = KeyEvent::new(KeyCode::Char('G'), KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &big_g_event); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.results_state.selected(), 0); + } + + #[test] + fn test_vim_ctrl_u_d_half_page_scroll() { + use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + + let settings = Settings::utc(); + + let mut state = State { + history_count: 100, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len: 100, + accept: false, + keymap_mode: KeymapMode::VimNormal, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::defaults(&settings), + search: SearchState { + input: String::new().into(), + filter_mode: FilterMode::Global, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + + state.results_state.select(50); + + // Ctrl+d should return Continue and clear pending key + // (scroll amount depends on max_entries which is 0 in tests) + state.pending_vim_key = Some('g'); + let ctrl_d_event = KeyEvent::new(KeyCode::Char('d'), KeyModifiers::CONTROL); + let result = state.handle_key_input(&settings, &ctrl_d_event); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.pending_vim_key, None); + + // Ctrl+u should return Continue and clear pending key + state.pending_vim_key = Some('g'); + let ctrl_u_event = KeyEvent::new(KeyCode::Char('u'), KeyModifiers::CONTROL); + let result = state.handle_key_input(&settings, &ctrl_u_event); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.pending_vim_key, None); + } + + #[test] + fn test_vim_ctrl_f_b_full_page_scroll() { + use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + + let settings = Settings::utc(); + + let mut state = State { + history_count: 100, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len: 100, + accept: false, + keymap_mode: KeymapMode::VimNormal, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::defaults(&settings), + search: SearchState { + input: String::new().into(), + filter_mode: FilterMode::Global, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + + state.results_state.select(50); + + // Ctrl+f should return Continue and clear pending key + // (scroll amount depends on max_entries which is 0 in tests) + state.pending_vim_key = Some('g'); + let ctrl_f_event = KeyEvent::new(KeyCode::Char('f'), KeyModifiers::CONTROL); + let result = state.handle_key_input(&settings, &ctrl_f_event); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.pending_vim_key, None); + + // Ctrl+b should return Continue and clear pending key + state.pending_vim_key = Some('g'); + let ctrl_b_event = KeyEvent::new(KeyCode::Char('b'), KeyModifiers::CONTROL); + let result = state.handle_key_input(&settings, &ctrl_b_event); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.pending_vim_key, None); + } + + // ----------------------------------------------------------------------- + // Executor tests (execute_action) + // ----------------------------------------------------------------------- + + /// Helper to build a State for executor tests. + fn make_executor_state(results_len: usize, selected: usize) -> State { + let settings = Settings::utc(); + let mut state = State { + history_count: results_len as i64, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len, + accept: false, + keymap_mode: KeymapMode::Emacs, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::defaults(&settings), + search: SearchState { + input: String::new().into(), + filter_mode: FilterMode::Global, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + state.results_state.select(selected); + state + } + + #[test] + fn execute_select_next_no_invert() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 50); + let settings = Settings::utc(); + let result = state.execute_action(&Action::SelectNext, &settings); + assert!(matches!(result, super::InputAction::Continue)); + // Non-inverted: SelectNext = scroll_down = selected - 1 + assert_eq!(state.results_state.selected(), 49); + } + + #[test] + fn execute_select_next_with_invert() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 50); + let mut settings = Settings::utc(); + settings.invert = true; + let result = state.execute_action(&Action::SelectNext, &settings); + assert!(matches!(result, super::InputAction::Continue)); + // Inverted: SelectNext = scroll_up = selected + 1 + assert_eq!(state.results_state.selected(), 51); + } + + #[test] + fn execute_select_previous_no_invert() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 50); + let settings = Settings::utc(); + let result = state.execute_action(&Action::SelectPrevious, &settings); + assert!(matches!(result, super::InputAction::Continue)); + // Non-inverted: SelectPrevious = scroll_up = selected + 1 + assert_eq!(state.results_state.selected(), 51); + } + + #[test] + fn execute_vim_enter_normal() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + let settings = Settings::utc(); + let result = state.execute_action(&Action::VimEnterNormal, &settings); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.keymap_mode, KeymapMode::VimNormal); + } + + #[test] + fn execute_vim_enter_insert() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + state.keymap_mode = KeymapMode::VimNormal; + let settings = Settings::utc(); + let result = state.execute_action(&Action::VimEnterInsert, &settings); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.keymap_mode, KeymapMode::VimInsert); + } + + #[test] + fn execute_accept_sets_accept_flag() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 5); + let mut settings = Settings::utc(); + settings.enter_accept = true; + let result = state.execute_action(&Action::Accept, &settings); + assert!(matches!(result, super::InputAction::Accept(5))); + assert!(state.accept); + } + + #[test] + fn execute_return_selection_does_not_set_accept() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 5); + let settings = Settings::utc(); + let result = state.execute_action(&Action::ReturnSelection, &settings); + assert!(matches!(result, super::InputAction::Accept(5))); + assert!(!state.accept); + } + + #[test] + fn execute_accept_nth() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 5); + let settings = Settings::utc(); + let result = state.execute_action(&Action::AcceptNth(3), &settings); + assert!(matches!(result, super::InputAction::Accept(8))); + } + + #[test] + fn execute_scroll_to_top_no_invert() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 50); + let settings = Settings::utc(); + let result = state.execute_action(&Action::ScrollToTop, &settings); + assert!(matches!(result, super::InputAction::Continue)); + // Non-inverted: visual top = highest index + assert_eq!(state.results_state.selected(), 99); + } + + #[test] + fn execute_scroll_to_top_with_invert() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 50); + let mut settings = Settings::utc(); + settings.invert = true; + let result = state.execute_action(&Action::ScrollToTop, &settings); + assert!(matches!(result, super::InputAction::Continue)); + // Inverted: visual top = index 0 + assert_eq!(state.results_state.selected(), 0); + } + + #[test] + fn execute_scroll_to_bottom_no_invert() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 50); + let settings = Settings::utc(); + let result = state.execute_action(&Action::ScrollToBottom, &settings); + assert!(matches!(result, super::InputAction::Continue)); + // Non-inverted: visual bottom = index 0 + assert_eq!(state.results_state.selected(), 0); + } + + #[test] + fn execute_toggle_tab() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + let settings = Settings::utc(); + assert_eq!(state.tab_index, 0); + state.execute_action(&Action::ToggleTab, &settings); + assert_eq!(state.tab_index, 1); + state.execute_action(&Action::ToggleTab, &settings); + assert_eq!(state.tab_index, 0); + } + + #[test] + fn execute_enter_prefix_mode() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + let settings = Settings::utc(); + assert!(!state.prefix); + state.execute_action(&Action::EnterPrefixMode, &settings); + assert!(state.prefix); + } + + #[test] + fn execute_exit_returns_based_on_exit_mode() { + use crate::atuin_client::settings::ExitMode; + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + let mut settings = Settings::utc(); + + settings.exit_mode = ExitMode::ReturnOriginal; + let result = state.execute_action(&Action::Exit, &settings); + assert!(matches!(result, super::InputAction::ReturnOriginal)); + + settings.exit_mode = ExitMode::ReturnQuery; + let result = state.execute_action(&Action::Exit, &settings); + assert!(matches!(result, super::InputAction::ReturnQuery)); + } + + #[test] + fn execute_return_original() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + let settings = Settings::utc(); + let result = state.execute_action(&Action::ReturnOriginal, &settings); + assert!(matches!(result, super::InputAction::ReturnOriginal)); + } + + #[test] + fn execute_copy() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 7); + let settings = Settings::utc(); + let result = state.execute_action(&Action::Copy, &settings); + assert!(matches!(result, super::InputAction::Copy(7))); + } + + #[test] + fn execute_delete() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 7); + let settings = Settings::utc(); + let result = state.execute_action(&Action::Delete, &settings); + assert!(matches!(result, super::InputAction::Delete(7))); + } + + #[test] + fn execute_switch_context() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 7); + let settings = Settings::utc(); + let result = state.execute_action(&Action::SwitchContext, &settings); + assert!(matches!(result, super::InputAction::SwitchContext(Some(7)))); + } + + #[test] + fn execute_clear_context() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 7); + let settings = Settings::utc(); + let result = state.execute_action(&Action::ClearContext, &settings); + assert!(matches!(result, super::InputAction::SwitchContext(None))); + } + + #[test] + fn execute_noop() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 50); + let settings = Settings::utc(); + let result = state.execute_action(&Action::Noop, &settings); + assert!(matches!(result, super::InputAction::Continue)); + assert_eq!(state.results_state.selected(), 50); + } + + #[test] + fn execute_accept_in_inspector_tab() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 5); + state.tab_index = 1; + let settings = Settings::utc(); + let result = state.execute_action(&Action::Accept, &settings); + assert!(matches!(result, super::InputAction::AcceptInspecting)); + } + + #[test] + fn execute_cycle_search_mode() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + let settings = Settings::utc(); + let original_mode = state.search_mode; + let result = state.execute_action(&Action::CycleSearchMode, &settings); + assert!(matches!(result, super::InputAction::Continue)); + assert!(state.switched_search_mode); + assert_ne!(state.search_mode, original_mode); + } + + #[test] + fn execute_vim_search_insert() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + state.search.input.insert('h'); + state.search.input.insert('i'); + state.keymap_mode = KeymapMode::VimNormal; + let settings = Settings::utc(); + let result = state.execute_action(&Action::VimSearchInsert, &settings); + assert!(matches!(result, super::InputAction::Continue)); + // Should clear input and switch to insert mode + assert_eq!(state.search.input.as_str(), ""); + assert_eq!(state.keymap_mode, KeymapMode::VimInsert); + } + + #[test] + fn execute_cursor_movement() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + let settings = Settings::utc(); + + // Insert some text + state.search.input.insert('h'); + state.search.input.insert('e'); + state.search.input.insert('l'); + state.search.input.insert('l'); + state.search.input.insert('o'); + // cursor is at end (position 5) + + // CursorLeft + state.execute_action(&Action::CursorLeft, &settings); + assert_eq!(state.search.input.position(), 4); + + // CursorStart + state.execute_action(&Action::CursorStart, &settings); + assert_eq!(state.search.input.position(), 0); + + // CursorEnd + state.execute_action(&Action::CursorEnd, &settings); + assert_eq!(state.search.input.position(), 5); + + // CursorRight at end does nothing + state.execute_action(&Action::CursorRight, &settings); + assert_eq!(state.search.input.position(), 5); + } + + #[test] + fn execute_editing() { + use crate::command::client::search::keybindings::Action; + + let mut state = make_executor_state(100, 0); + let settings = Settings::utc(); + + // Insert "hello" + state.search.input.insert('h'); + state.search.input.insert('e'); + state.search.input.insert('l'); + state.search.input.insert('l'); + state.search.input.insert('o'); + + // DeleteCharBefore (backspace) + state.execute_action(&Action::DeleteCharBefore, &settings); + assert_eq!(state.search.input.as_str(), "hell"); + + // ClearLine + state.execute_action(&Action::ClearLine, &settings); + assert_eq!(state.search.input.as_str(), ""); + } + + #[test] + fn keymap_config_return_query() { + use crate::atuin_client::settings::KeyBindingConfig; + use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + use std::collections::HashMap; + + let mut settings = Settings::utc(); + // Configure tab to return-query + settings.keymap.emacs = HashMap::from([( + "tab".to_string(), + KeyBindingConfig::Simple("return-query".to_string()), + )]); + + let mut state = State { + history_count: 100, + results_state: ListState::default(), + switched_search_mode: false, + search_mode: SearchMode::Fuzzy, + results_len: 100, + accept: false, + keymap_mode: KeymapMode::Emacs, + prefix: false, + current_cursor: None, + tab_index: 0, + pending_vim_key: None, + original_input_empty: false, + inspecting_state: InspectingState { + current: None, + next: None, + previous: None, + }, + keymaps: KeymapSet::from_settings(&settings), + search: SearchState { + input: "test query".to_string().into(), + filter_mode: FilterMode::Global, + context: Context { + session: String::new(), + cwd: String::new(), + hostname: String::new(), + host_id: String::new(), + git_root: None, + }, + custom_context: None, + }, + engine: engines::engine(SearchMode::Fuzzy, &settings), + now: Box::new(OffsetDateTime::now_utc), + }; + + let tab_event = KeyEvent::new(KeyCode::Tab, KeyModifiers::NONE); + let result = state.handle_key_input(&settings, &tab_event); + assert!( + matches!(result, super::InputAction::ReturnQuery), + "Tab configured as return-query should return InputAction::ReturnQuery" + ); + } +} diff --git a/crates/turtle/src/command/client/search/keybindings/actions.rs b/crates/turtle/src/command/client/search/keybindings/actions.rs new file mode 100644 index 00000000..ff2ef7de --- /dev/null +++ b/crates/turtle/src/command/client/search/keybindings/actions.rs @@ -0,0 +1,322 @@ +use std::fmt; + +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +/// All possible actions that can be triggered by a keybinding. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Action { + // Cursor movement + CursorLeft, + CursorRight, + CursorWordLeft, + CursorWordRight, + CursorWordEnd, + CursorStart, + CursorEnd, + + // Editing + DeleteCharBefore, + DeleteCharAfter, + DeleteWordBefore, + DeleteWordAfter, + DeleteToWordBoundary, + ClearLine, + ClearToStart, + ClearToEnd, + + // List navigation + SelectNext, + SelectPrevious, + ScrollHalfPageUp, + ScrollHalfPageDown, + ScrollPageUp, + ScrollPageDown, + ScrollToTop, + ScrollToBottom, + ScrollToScreenTop, + ScrollToScreenMiddle, + ScrollToScreenBottom, + + // Commands — accept selection and execute immediately + Accept, + AcceptNth(u8), + // Commands — return selection to command line without executing + ReturnSelection, + ReturnSelectionNth(u8), + // Commands — other + Copy, + Delete, + DeleteAll, + ReturnOriginal, + ReturnQuery, + Exit, + Redraw, + CycleFilterMode, + CycleSearchMode, + SwitchContext, + ClearContext, + ToggleTab, + + // Mode changes + VimEnterNormal, + VimEnterInsert, + VimEnterInsertAfter, + VimEnterInsertAtStart, + VimEnterInsertAtEnd, + VimSearchInsert, + VimChangeToEnd, + EnterPrefixMode, + + // Inspector + InspectPrevious, + InspectNext, + + // Special + Noop, +} + +impl Action { + /// Convert from a kebab-case string. + pub fn from_str(s: &str) -> Result { + // Handle accept-N and return-selection-N patterns + if let Some(rest) = s.strip_prefix("accept-") + && let Ok(n) = rest.parse::() + && (1..=9).contains(&n) + { + return Ok(Action::AcceptNth(n)); + } + if let Some(rest) = s.strip_prefix("return-selection-") + && let Ok(n) = rest.parse::() + && (1..=9).contains(&n) + { + return Ok(Action::ReturnSelectionNth(n)); + } + + match s { + "cursor-left" => Ok(Action::CursorLeft), + "cursor-right" => Ok(Action::CursorRight), + "cursor-word-left" => Ok(Action::CursorWordLeft), + "cursor-word-right" => Ok(Action::CursorWordRight), + "cursor-word-end" => Ok(Action::CursorWordEnd), + "cursor-start" => Ok(Action::CursorStart), + "cursor-end" => Ok(Action::CursorEnd), + + "delete-char-before" => Ok(Action::DeleteCharBefore), + "delete-char-after" => Ok(Action::DeleteCharAfter), + "delete-word-before" => Ok(Action::DeleteWordBefore), + "delete-word-after" => Ok(Action::DeleteWordAfter), + "delete-to-word-boundary" => Ok(Action::DeleteToWordBoundary), + "clear-line" => Ok(Action::ClearLine), + "clear-to-start" => Ok(Action::ClearToStart), + "clear-to-end" => Ok(Action::ClearToEnd), + + "select-next" => Ok(Action::SelectNext), + "select-previous" => Ok(Action::SelectPrevious), + "scroll-half-page-up" => Ok(Action::ScrollHalfPageUp), + "scroll-half-page-down" => Ok(Action::ScrollHalfPageDown), + "scroll-page-up" => Ok(Action::ScrollPageUp), + "scroll-page-down" => Ok(Action::ScrollPageDown), + "scroll-to-top" => Ok(Action::ScrollToTop), + "scroll-to-bottom" => Ok(Action::ScrollToBottom), + "scroll-to-screen-top" => Ok(Action::ScrollToScreenTop), + "scroll-to-screen-middle" => Ok(Action::ScrollToScreenMiddle), + "scroll-to-screen-bottom" => Ok(Action::ScrollToScreenBottom), + + "accept" => Ok(Action::Accept), + "return-selection" => Ok(Action::ReturnSelection), + "copy" => Ok(Action::Copy), + "delete" => Ok(Action::Delete), + "delete-all" => Ok(Action::DeleteAll), + "return-original" => Ok(Action::ReturnOriginal), + "return-query" => Ok(Action::ReturnQuery), + "exit" => Ok(Action::Exit), + "redraw" => Ok(Action::Redraw), + "cycle-filter-mode" => Ok(Action::CycleFilterMode), + "cycle-search-mode" => Ok(Action::CycleSearchMode), + "switch-context" => Ok(Action::SwitchContext), + "clear-context" => Ok(Action::ClearContext), + "toggle-tab" => Ok(Action::ToggleTab), + + "vim-enter-normal" => Ok(Action::VimEnterNormal), + "vim-enter-insert" => Ok(Action::VimEnterInsert), + "vim-enter-insert-after" => Ok(Action::VimEnterInsertAfter), + "vim-enter-insert-at-start" => Ok(Action::VimEnterInsertAtStart), + "vim-enter-insert-at-end" => Ok(Action::VimEnterInsertAtEnd), + "vim-search-insert" => Ok(Action::VimSearchInsert), + "vim-change-to-end" => Ok(Action::VimChangeToEnd), + "enter-prefix-mode" => Ok(Action::EnterPrefixMode), + + "inspect-previous" => Ok(Action::InspectPrevious), + "inspect-next" => Ok(Action::InspectNext), + + "noop" => Ok(Action::Noop), + + _ => Err(format!("unknown action: {s}")), + } + } + + /// Convert to a kebab-case string. + pub fn as_str(&self) -> String { + match self { + Action::CursorLeft => "cursor-left".to_string(), + Action::CursorRight => "cursor-right".to_string(), + Action::CursorWordLeft => "cursor-word-left".to_string(), + Action::CursorWordRight => "cursor-word-right".to_string(), + Action::CursorWordEnd => "cursor-word-end".to_string(), + Action::CursorStart => "cursor-start".to_string(), + Action::CursorEnd => "cursor-end".to_string(), + + Action::DeleteCharBefore => "delete-char-before".to_string(), + Action::DeleteCharAfter => "delete-char-after".to_string(), + Action::DeleteWordBefore => "delete-word-before".to_string(), + Action::DeleteWordAfter => "delete-word-after".to_string(), + Action::DeleteToWordBoundary => "delete-to-word-boundary".to_string(), + Action::ClearLine => "clear-line".to_string(), + Action::ClearToStart => "clear-to-start".to_string(), + Action::ClearToEnd => "clear-to-end".to_string(), + + Action::SelectNext => "select-next".to_string(), + Action::SelectPrevious => "select-previous".to_string(), + Action::ScrollHalfPageUp => "scroll-half-page-up".to_string(), + Action::ScrollHalfPageDown => "scroll-half-page-down".to_string(), + Action::ScrollPageUp => "scroll-page-up".to_string(), + Action::ScrollPageDown => "scroll-page-down".to_string(), + Action::ScrollToTop => "scroll-to-top".to_string(), + Action::ScrollToBottom => "scroll-to-bottom".to_string(), + Action::ScrollToScreenTop => "scroll-to-screen-top".to_string(), + Action::ScrollToScreenMiddle => "scroll-to-screen-middle".to_string(), + Action::ScrollToScreenBottom => "scroll-to-screen-bottom".to_string(), + + Action::Accept => "accept".to_string(), + Action::AcceptNth(n) => format!("accept-{n}"), + Action::ReturnSelection => "return-selection".to_string(), + Action::ReturnSelectionNth(n) => format!("return-selection-{n}"), + Action::Copy => "copy".to_string(), + Action::Delete => "delete".to_string(), + Action::DeleteAll => "delete-all".to_string(), + Action::ReturnOriginal => "return-original".to_string(), + Action::ReturnQuery => "return-query".to_string(), + Action::Exit => "exit".to_string(), + Action::Redraw => "redraw".to_string(), + Action::CycleFilterMode => "cycle-filter-mode".to_string(), + Action::CycleSearchMode => "cycle-search-mode".to_string(), + Action::SwitchContext => "switch-context".to_string(), + Action::ClearContext => "clear-context".to_string(), + Action::ToggleTab => "toggle-tab".to_string(), + + Action::VimEnterNormal => "vim-enter-normal".to_string(), + Action::VimEnterInsert => "vim-enter-insert".to_string(), + Action::VimEnterInsertAfter => "vim-enter-insert-after".to_string(), + Action::VimEnterInsertAtStart => "vim-enter-insert-at-start".to_string(), + Action::VimEnterInsertAtEnd => "vim-enter-insert-at-end".to_string(), + Action::VimSearchInsert => "vim-search-insert".to_string(), + Action::VimChangeToEnd => "vim-change-to-end".to_string(), + Action::EnterPrefixMode => "enter-prefix-mode".to_string(), + + Action::InspectPrevious => "inspect-previous".to_string(), + Action::InspectNext => "inspect-next".to_string(), + + Action::Noop => "noop".to_string(), + } + } +} + +impl fmt::Display for Action { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.as_str()) + } +} + +impl Serialize for Action { + fn serialize(&self, serializer: S) -> Result { + serializer.serialize_str(&self.as_str()) + } +} + +impl<'de> Deserialize<'de> for Action { + fn deserialize>(deserializer: D) -> Result { + let s = String::deserialize(deserializer)?; + Action::from_str(&s).map_err(serde::de::Error::custom) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_basic_actions() { + assert_eq!(Action::from_str("cursor-left").unwrap(), Action::CursorLeft); + assert_eq!(Action::from_str("accept").unwrap(), Action::Accept); + assert_eq!(Action::from_str("exit").unwrap(), Action::Exit); + assert_eq!(Action::from_str("noop").unwrap(), Action::Noop); + assert_eq!( + Action::from_str("vim-enter-normal").unwrap(), + Action::VimEnterNormal + ); + } + + #[test] + fn parse_accept_nth() { + assert_eq!(Action::from_str("accept-1").unwrap(), Action::AcceptNth(1)); + assert_eq!(Action::from_str("accept-9").unwrap(), Action::AcceptNth(9)); + } + + #[test] + fn parse_return_selection() { + assert_eq!( + Action::from_str("return-selection").unwrap(), + Action::ReturnSelection + ); + assert_eq!( + Action::from_str("return-selection-1").unwrap(), + Action::ReturnSelectionNth(1) + ); + assert_eq!( + Action::from_str("return-selection-9").unwrap(), + Action::ReturnSelectionNth(9) + ); + } + + #[test] + fn parse_unknown_action() { + assert!(Action::from_str("unknown-action").is_err()); + assert!(Action::from_str("accept-0").is_err()); + assert!(Action::from_str("accept-10").is_err()); + assert!(Action::from_str("return-selection-0").is_err()); + assert!(Action::from_str("return-selection-10").is_err()); + } + + #[test] + fn round_trip() { + let actions = vec![ + Action::CursorLeft, + Action::Accept, + Action::AcceptNth(5), + Action::ReturnSelection, + Action::ReturnSelectionNth(3), + Action::VimSearchInsert, + Action::ScrollToScreenMiddle, + ]; + for action in actions { + let s = action.as_str(); + let parsed = Action::from_str(&s).unwrap(); + assert_eq!(action, parsed); + } + } + + #[test] + fn serde_round_trip() { + let action = Action::CursorLeft; + let json = serde_json::to_string(&action).unwrap(); + assert_eq!(json, "\"cursor-left\""); + let parsed: Action = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed, Action::CursorLeft); + + let action = Action::AcceptNth(3); + let json = serde_json::to_string(&action).unwrap(); + assert_eq!(json, "\"accept-3\""); + let parsed: Action = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed, Action::AcceptNth(3)); + } +} diff --git a/crates/turtle/src/command/client/search/keybindings/conditions.rs b/crates/turtle/src/command/client/search/keybindings/conditions.rs new file mode 100644 index 00000000..055ae905 --- /dev/null +++ b/crates/turtle/src/command/client/search/keybindings/conditions.rs @@ -0,0 +1,801 @@ +use std::fmt; + +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +/// Atomic (leaf) conditions that can be evaluated against state. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ConditionAtom { + CursorAtStart, + CursorAtEnd, + InputEmpty, + OriginalInputEmpty, + ListAtEnd, + ListAtStart, + NoResults, + HasResults, + HasContext, +} + +/// Boolean expression tree over condition atoms. +/// +/// Supports negation, conjunction, and disjunction with standard precedence: +/// `!` binds tightest, then `&&`, then `||`. +/// +/// Examples of valid expression strings: +/// - `"cursor-at-start"` (bare atom) +/// - `"!no-results"` (negation) +/// - `"cursor-at-start && input-empty"` (conjunction) +/// - `"list-at-start || no-results"` (disjunction) +/// - `"(cursor-at-start && !input-empty) || no-results"` (grouping) +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ConditionExpr { + Atom(ConditionAtom), + Not(Box), + And(Box, Box), + Or(Box, Box), +} + +/// Context needed to evaluate conditions. This is a pure snapshot of state — +/// no references to mutable data. +pub struct EvalContext { + /// Current cursor position (unicode width units). + pub cursor_position: usize, + /// Width of the input string in unicode width units. + pub input_width: usize, + /// Byte length of the input string. + pub input_byte_len: usize, + /// Currently selected index in the results list. + pub selected_index: usize, + /// Total number of results. + pub results_len: usize, + /// Whether the original input (query passed to the TUI) was empty. + pub original_input_empty: bool, + /// Whether we use a search context of a command from the history. + pub has_context: bool, +} + +// --------------------------------------------------------------------------- +// ConditionAtom +// --------------------------------------------------------------------------- + +impl ConditionAtom { + /// Evaluate this atom against the given context. + pub fn evaluate(&self, ctx: &EvalContext) -> bool { + match self { + ConditionAtom::CursorAtStart => ctx.cursor_position == 0, + ConditionAtom::CursorAtEnd => ctx.cursor_position == ctx.input_width, + ConditionAtom::InputEmpty => ctx.input_byte_len == 0, + ConditionAtom::OriginalInputEmpty => ctx.original_input_empty, + ConditionAtom::ListAtEnd => { + ctx.results_len == 0 || ctx.selected_index >= ctx.results_len.saturating_sub(1) + } + ConditionAtom::ListAtStart => ctx.results_len == 0 || ctx.selected_index == 0, + ConditionAtom::NoResults => ctx.results_len == 0, + ConditionAtom::HasResults => ctx.results_len > 0, + ConditionAtom::HasContext => ctx.has_context, + } + } + + /// Parse from a kebab-case string. + pub fn from_str(s: &str) -> Result { + match s { + "cursor-at-start" => Ok(ConditionAtom::CursorAtStart), + "cursor-at-end" => Ok(ConditionAtom::CursorAtEnd), + "input-empty" => Ok(ConditionAtom::InputEmpty), + "original-input-empty" => Ok(ConditionAtom::OriginalInputEmpty), + "list-at-end" => Ok(ConditionAtom::ListAtEnd), + "list-at-start" => Ok(ConditionAtom::ListAtStart), + "no-results" => Ok(ConditionAtom::NoResults), + "has-results" => Ok(ConditionAtom::HasResults), + "has-context" => Ok(ConditionAtom::HasContext), + _ => Err(format!("unknown condition: {s}")), + } + } + + /// Convert to a kebab-case string. + pub fn as_str(&self) -> &'static str { + match self { + ConditionAtom::CursorAtStart => "cursor-at-start", + ConditionAtom::CursorAtEnd => "cursor-at-end", + ConditionAtom::InputEmpty => "input-empty", + ConditionAtom::OriginalInputEmpty => "original-input-empty", + ConditionAtom::ListAtEnd => "list-at-end", + ConditionAtom::ListAtStart => "list-at-start", + ConditionAtom::NoResults => "no-results", + ConditionAtom::HasResults => "has-results", + ConditionAtom::HasContext => "has-context", + } + } +} + +impl fmt::Display for ConditionAtom { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.as_str()) + } +} + +// --------------------------------------------------------------------------- +// ConditionExpr — evaluation +// --------------------------------------------------------------------------- + +impl ConditionExpr { + /// Evaluate this expression against the given context. + pub fn evaluate(&self, ctx: &EvalContext) -> bool { + match self { + ConditionExpr::Atom(atom) => atom.evaluate(ctx), + ConditionExpr::Not(inner) => !inner.evaluate(ctx), + ConditionExpr::And(lhs, rhs) => lhs.evaluate(ctx) && rhs.evaluate(ctx), + ConditionExpr::Or(lhs, rhs) => lhs.evaluate(ctx) || rhs.evaluate(ctx), + } + } +} + +// --------------------------------------------------------------------------- +// ConditionExpr — ergonomic builders +// --------------------------------------------------------------------------- + +impl From for ConditionExpr { + fn from(atom: ConditionAtom) -> Self { + ConditionExpr::Atom(atom) + } +} + +#[expect(dead_code)] +impl ConditionExpr { + /// Negate this expression: `!self`. + pub fn not(self) -> Self { + ConditionExpr::Not(Box::new(self)) + } + + /// Conjoin with another expression: `self && other`. + pub fn and(self, other: ConditionExpr) -> Self { + ConditionExpr::And(Box::new(self), Box::new(other)) + } + + /// Disjoin with another expression: `self || other`. + pub fn or(self, other: ConditionExpr) -> Self { + ConditionExpr::Or(Box::new(self), Box::new(other)) + } +} + +// --------------------------------------------------------------------------- +// ConditionExpr — parser +// --------------------------------------------------------------------------- + +/// Recursive descent parser for boolean condition expressions. +/// +/// Grammar (standard boolean precedence): +/// ```text +/// expr = or_expr +/// or_expr = and_expr ("||" and_expr)* +/// and_expr = unary ("&&" unary)* +/// unary = "!" unary | primary +/// primary = atom | "(" expr ")" +/// atom = [a-z][a-z0-9-]* +/// ``` +struct ExprParser<'a> { + input: &'a str, + pos: usize, +} + +impl<'a> ExprParser<'a> { + fn new(input: &'a str) -> Self { + Self { input, pos: 0 } + } + + fn skip_whitespace(&mut self) { + while self.pos < self.input.len() && self.input.as_bytes()[self.pos].is_ascii_whitespace() { + self.pos += 1; + } + } + + fn starts_with(&mut self, s: &str) -> bool { + self.skip_whitespace(); + self.input[self.pos..].starts_with(s) + } + + fn consume(&mut self, s: &str) -> bool { + self.skip_whitespace(); + if self.input[self.pos..].starts_with(s) { + self.pos += s.len(); + true + } else { + false + } + } + + /// Parse a full expression, expecting to consume all input. + fn parse(mut self) -> Result { + let expr = self.parse_or()?; + self.skip_whitespace(); + if self.pos < self.input.len() { + return Err(format!( + "unexpected input at position {}: {:?}", + self.pos, + &self.input[self.pos..] + )); + } + Ok(expr) + } + + /// `or_expr` = `and_expr` ("||" `and_expr`)* + fn parse_or(&mut self) -> Result { + let mut left = self.parse_and()?; + while self.starts_with("||") { + self.consume("||"); + let right = self.parse_and()?; + left = ConditionExpr::Or(Box::new(left), Box::new(right)); + } + Ok(left) + } + + /// `and_expr` = unary ("&&" unary)* + fn parse_and(&mut self) -> Result { + let mut left = self.parse_unary()?; + while self.starts_with("&&") { + self.consume("&&"); + let right = self.parse_unary()?; + left = ConditionExpr::And(Box::new(left), Box::new(right)); + } + Ok(left) + } + + /// unary = "!" unary | primary + fn parse_unary(&mut self) -> Result { + if self.consume("!") { + let inner = self.parse_unary()?; + Ok(ConditionExpr::Not(Box::new(inner))) + } else { + self.parse_primary() + } + } + + /// primary = "(" expr ")" | atom + fn parse_primary(&mut self) -> Result { + if self.consume("(") { + let expr = self.parse_or()?; + if !self.consume(")") { + return Err(format!("expected ')' at position {}", self.pos)); + } + Ok(expr) + } else { + self.parse_atom() + } + } + + /// atom = [a-z][a-z0-9-]* + fn parse_atom(&mut self) -> Result { + self.skip_whitespace(); + let start = self.pos; + while self.pos < self.input.len() { + let b = self.input.as_bytes()[self.pos]; + if b.is_ascii_lowercase() || b.is_ascii_digit() || b == b'-' { + self.pos += 1; + } else { + break; + } + } + if self.pos == start { + return Err(format!("expected condition name at position {}", self.pos)); + } + let name = &self.input[start..self.pos]; + let atom = ConditionAtom::from_str(name)?; + Ok(ConditionExpr::Atom(atom)) + } +} + +impl ConditionExpr { + /// Parse a condition expression from a string. + pub fn parse(s: &str) -> Result { + let parser = ExprParser::new(s); + parser.parse() + } +} + +// --------------------------------------------------------------------------- +// ConditionExpr — Display +// --------------------------------------------------------------------------- + +/// Precedence levels for minimal-parentheses display. +#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +enum Prec { + Or = 0, + And = 1, + Not = 2, + Atom = 3, +} + +impl ConditionExpr { + fn prec(&self) -> Prec { + match self { + ConditionExpr::Or(..) => Prec::Or, + ConditionExpr::And(..) => Prec::And, + ConditionExpr::Not(..) => Prec::Not, + ConditionExpr::Atom(..) => Prec::Atom, + } + } + + fn fmt_with_prec(&self, f: &mut fmt::Formatter<'_>, parent_prec: Prec) -> fmt::Result { + let needs_parens = self.prec() < parent_prec; + if needs_parens { + write!(f, "(")?; + } + match self { + ConditionExpr::Atom(atom) => write!(f, "{atom}")?, + ConditionExpr::Not(inner) => { + write!(f, "!")?; + inner.fmt_with_prec(f, Prec::Not)?; + } + ConditionExpr::And(lhs, rhs) => { + lhs.fmt_with_prec(f, Prec::And)?; + write!(f, " && ")?; + rhs.fmt_with_prec(f, Prec::And)?; + } + ConditionExpr::Or(lhs, rhs) => { + lhs.fmt_with_prec(f, Prec::Or)?; + write!(f, " || ")?; + rhs.fmt_with_prec(f, Prec::Or)?; + } + } + if needs_parens { + write!(f, ")")?; + } + Ok(()) + } +} + +impl fmt::Display for ConditionExpr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.fmt_with_prec(f, Prec::Or) + } +} + +// --------------------------------------------------------------------------- +// Serde +// --------------------------------------------------------------------------- + +impl Serialize for ConditionExpr { + fn serialize(&self, serializer: S) -> Result { + serializer.serialize_str(&self.to_string()) + } +} + +impl<'de> Deserialize<'de> for ConditionExpr { + fn deserialize>(deserializer: D) -> Result { + let s = String::deserialize(deserializer)?; + ConditionExpr::parse(&s).map_err(serde::de::Error::custom) + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn ctx( + cursor: usize, + width: usize, + byte_len: usize, + selected: usize, + len: usize, + ) -> EvalContext { + ctx_with_original(cursor, width, byte_len, selected, len, false) + } + + fn ctx_with_original( + cursor: usize, + width: usize, + byte_len: usize, + selected: usize, + len: usize, + original_input_empty: bool, + ) -> EvalContext { + EvalContext { + cursor_position: cursor, + input_width: width, + input_byte_len: byte_len, + selected_index: selected, + results_len: len, + original_input_empty, + has_context: false, + } + } + + // -- Atom evaluation (carried over from Phase 0) -- + + #[test] + fn atom_cursor_at_start() { + assert!(ConditionAtom::CursorAtStart.evaluate(&ctx(0, 5, 5, 0, 10))); + assert!(!ConditionAtom::CursorAtStart.evaluate(&ctx(3, 5, 5, 0, 10))); + } + + #[test] + fn atom_cursor_at_end() { + assert!(ConditionAtom::CursorAtEnd.evaluate(&ctx(5, 5, 5, 0, 10))); + assert!(!ConditionAtom::CursorAtEnd.evaluate(&ctx(3, 5, 5, 0, 10))); + assert!(ConditionAtom::CursorAtEnd.evaluate(&ctx(0, 0, 0, 0, 10))); + } + + #[test] + fn atom_input_empty() { + assert!(ConditionAtom::InputEmpty.evaluate(&ctx(0, 0, 0, 0, 10))); + assert!(!ConditionAtom::InputEmpty.evaluate(&ctx(0, 5, 5, 0, 10))); + } + + #[test] + fn atom_original_input_empty() { + // original_input_empty = true + assert!( + ConditionAtom::OriginalInputEmpty.evaluate(&ctx_with_original(0, 0, 0, 0, 10, true)) + ); + // original_input_empty = false + assert!( + !ConditionAtom::OriginalInputEmpty.evaluate(&ctx_with_original(0, 0, 0, 0, 10, false)) + ); + // original_input_empty is independent of current input state + assert!( + ConditionAtom::OriginalInputEmpty.evaluate(&ctx_with_original(0, 5, 5, 0, 10, true)) + ); + } + + #[test] + fn atom_list_at_end() { + assert!(ConditionAtom::ListAtEnd.evaluate(&ctx(0, 0, 0, 99, 100))); + assert!(!ConditionAtom::ListAtEnd.evaluate(&ctx(0, 0, 0, 50, 100))); + assert!(ConditionAtom::ListAtEnd.evaluate(&ctx(0, 0, 0, 0, 0))); + } + + #[test] + fn atom_list_at_start() { + assert!(ConditionAtom::ListAtStart.evaluate(&ctx(0, 0, 0, 0, 100))); + assert!(!ConditionAtom::ListAtStart.evaluate(&ctx(0, 0, 0, 50, 100))); + assert!(ConditionAtom::ListAtStart.evaluate(&ctx(0, 0, 0, 0, 0))); + } + + #[test] + fn atom_no_results_and_has_results() { + assert!(ConditionAtom::NoResults.evaluate(&ctx(0, 0, 0, 0, 0))); + assert!(!ConditionAtom::NoResults.evaluate(&ctx(0, 0, 0, 0, 5))); + assert!(ConditionAtom::HasResults.evaluate(&ctx(0, 0, 0, 0, 5))); + assert!(!ConditionAtom::HasResults.evaluate(&ctx(0, 0, 0, 0, 0))); + } + + #[test] + fn atom_has_context() { + let mut context = ctx(0, 0, 0, 0, 0); + assert!(!ConditionAtom::HasContext.evaluate(&context)); + context.has_context = true; + assert!(ConditionAtom::HasContext.evaluate(&context)); + } + + #[test] + fn atom_parse_round_trip() { + let conditions = [ + "cursor-at-start", + "cursor-at-end", + "input-empty", + "original-input-empty", + "list-at-end", + "list-at-start", + "no-results", + "has-results", + ]; + for s in conditions { + let c = ConditionAtom::from_str(s).unwrap(); + assert_eq!(c.as_str(), s); + } + } + + #[test] + fn atom_parse_unknown() { + assert!(ConditionAtom::from_str("unknown-condition").is_err()); + } + + // -- Parser tests -- + + #[test] + fn parse_bare_atom() { + let expr = ConditionExpr::parse("cursor-at-start").unwrap(); + assert_eq!(expr, ConditionExpr::Atom(ConditionAtom::CursorAtStart)); + } + + #[test] + fn parse_negation() { + let expr = ConditionExpr::parse("!no-results").unwrap(); + assert_eq!( + expr, + ConditionExpr::Not(Box::new(ConditionExpr::Atom(ConditionAtom::NoResults))) + ); + } + + #[test] + fn parse_double_negation() { + let expr = ConditionExpr::parse("!!no-results").unwrap(); + assert_eq!( + expr, + ConditionExpr::Not(Box::new(ConditionExpr::Not(Box::new(ConditionExpr::Atom( + ConditionAtom::NoResults + ))))) + ); + } + + #[test] + fn parse_and() { + let expr = ConditionExpr::parse("cursor-at-start && input-empty").unwrap(); + assert_eq!( + expr, + ConditionExpr::And( + Box::new(ConditionExpr::Atom(ConditionAtom::CursorAtStart)), + Box::new(ConditionExpr::Atom(ConditionAtom::InputEmpty)), + ) + ); + } + + #[test] + fn parse_or() { + let expr = ConditionExpr::parse("list-at-start || no-results").unwrap(); + assert_eq!( + expr, + ConditionExpr::Or( + Box::new(ConditionExpr::Atom(ConditionAtom::ListAtStart)), + Box::new(ConditionExpr::Atom(ConditionAtom::NoResults)), + ) + ); + } + + #[test] + fn parse_precedence_and_binds_tighter_than_or() { + // "a || b && c" should parse as "a || (b && c)" + let expr = ConditionExpr::parse("cursor-at-start || input-empty && no-results").unwrap(); + assert_eq!( + expr, + ConditionExpr::Or( + Box::new(ConditionExpr::Atom(ConditionAtom::CursorAtStart)), + Box::new(ConditionExpr::And( + Box::new(ConditionExpr::Atom(ConditionAtom::InputEmpty)), + Box::new(ConditionExpr::Atom(ConditionAtom::NoResults)), + )), + ) + ); + } + + #[test] + fn parse_parens_override_precedence() { + // "(a || b) && c" + let expr = ConditionExpr::parse("(cursor-at-start || input-empty) && no-results").unwrap(); + assert_eq!( + expr, + ConditionExpr::And( + Box::new(ConditionExpr::Or( + Box::new(ConditionExpr::Atom(ConditionAtom::CursorAtStart)), + Box::new(ConditionExpr::Atom(ConditionAtom::InputEmpty)), + )), + Box::new(ConditionExpr::Atom(ConditionAtom::NoResults)), + ) + ); + } + + #[test] + fn parse_complex_nested() { + // "(a && !b) || c" + let expr = ConditionExpr::parse("(cursor-at-start && !input-empty) || no-results").unwrap(); + assert_eq!( + expr, + ConditionExpr::Or( + Box::new(ConditionExpr::And( + Box::new(ConditionExpr::Atom(ConditionAtom::CursorAtStart)), + Box::new(ConditionExpr::Not(Box::new(ConditionExpr::Atom( + ConditionAtom::InputEmpty + )))), + )), + Box::new(ConditionExpr::Atom(ConditionAtom::NoResults)), + ) + ); + } + + #[test] + fn parse_whitespace_tolerance() { + let a = ConditionExpr::parse("cursor-at-start||input-empty").unwrap(); + let b = ConditionExpr::parse("cursor-at-start || input-empty").unwrap(); + let c = ConditionExpr::parse(" cursor-at-start || input-empty ").unwrap(); + assert_eq!(a, b); + assert_eq!(b, c); + } + + #[test] + fn parse_error_unknown_atom() { + assert!(ConditionExpr::parse("unknown-thing").is_err()); + } + + #[test] + fn parse_error_trailing_input() { + assert!(ConditionExpr::parse("cursor-at-start blah").is_err()); + } + + #[test] + fn parse_error_unmatched_paren() { + assert!(ConditionExpr::parse("(cursor-at-start").is_err()); + } + + #[test] + fn parse_error_empty() { + assert!(ConditionExpr::parse("").is_err()); + } + + // -- Expression evaluation -- + + #[test] + fn eval_not() { + let expr = ConditionExpr::parse("!no-results").unwrap(); + // Has results → !no-results is true + assert!(expr.evaluate(&ctx(0, 0, 0, 0, 5))); + // No results → !no-results is false + assert!(!expr.evaluate(&ctx(0, 0, 0, 0, 0))); + } + + #[test] + fn eval_and() { + let expr = ConditionExpr::parse("cursor-at-start && input-empty").unwrap(); + // Both true + assert!(expr.evaluate(&ctx(0, 0, 0, 0, 10))); + // First true, second false (non-empty input) + assert!(!expr.evaluate(&ctx(0, 5, 5, 0, 10))); + // First false (cursor not at start) + assert!(!expr.evaluate(&ctx(3, 5, 5, 0, 10))); + } + + #[test] + fn eval_or() { + let expr = ConditionExpr::parse("list-at-start || no-results").unwrap(); + // list at bottom (selected=0) + assert!(expr.evaluate(&ctx(0, 0, 0, 0, 10))); + // no results + assert!(expr.evaluate(&ctx(0, 0, 0, 0, 0))); + // neither + assert!(!expr.evaluate(&ctx(0, 0, 0, 5, 10))); + } + + #[test] + fn eval_complex_nested() { + // (cursor-at-start && !input-empty) || no-results + let expr = ConditionExpr::parse("(cursor-at-start && !input-empty) || no-results").unwrap(); + + // cursor at start, input not empty → true (left branch) + assert!(expr.evaluate(&ctx(0, 5, 5, 0, 10))); + // no results → true (right branch) + assert!(expr.evaluate(&ctx(3, 5, 5, 0, 0))); + // cursor not at start, has results → false + assert!(!expr.evaluate(&ctx(3, 5, 5, 0, 10))); + // cursor at start, input empty → false (left: && fails; right: has results) + assert!(!expr.evaluate(&ctx(0, 0, 0, 0, 10))); + } + + // -- Display -- + + #[test] + fn display_atom() { + let expr = ConditionExpr::Atom(ConditionAtom::CursorAtStart); + assert_eq!(expr.to_string(), "cursor-at-start"); + } + + #[test] + fn display_not() { + let expr = ConditionExpr::Atom(ConditionAtom::NoResults).not(); + assert_eq!(expr.to_string(), "!no-results"); + } + + #[test] + fn display_and() { + let expr = ConditionExpr::Atom(ConditionAtom::CursorAtStart) + .and(ConditionExpr::Atom(ConditionAtom::InputEmpty)); + assert_eq!(expr.to_string(), "cursor-at-start && input-empty"); + } + + #[test] + fn display_or() { + let expr = ConditionExpr::Atom(ConditionAtom::ListAtStart) + .or(ConditionExpr::Atom(ConditionAtom::NoResults)); + assert_eq!(expr.to_string(), "list-at-start || no-results"); + } + + #[test] + fn display_parens_when_needed() { + // (a || b) && c — the Or inside And needs parens + let expr = ConditionExpr::Atom(ConditionAtom::CursorAtStart) + .or(ConditionExpr::Atom(ConditionAtom::InputEmpty)) + .and(ConditionExpr::Atom(ConditionAtom::NoResults)); + assert_eq!( + expr.to_string(), + "(cursor-at-start || input-empty) && no-results" + ); + } + + #[test] + fn display_no_parens_when_not_needed() { + // a || b && c — no parens needed (and binds tighter) + let inner_and = ConditionExpr::Atom(ConditionAtom::InputEmpty) + .and(ConditionExpr::Atom(ConditionAtom::NoResults)); + let expr = ConditionExpr::Atom(ConditionAtom::CursorAtStart).or(inner_and); + assert_eq!( + expr.to_string(), + "cursor-at-start || input-empty && no-results" + ); + } + + // -- Display round-trip -- + + #[test] + fn display_round_trip() { + let cases = [ + "cursor-at-start", + "!no-results", + "cursor-at-start && input-empty", + "list-at-start || no-results", + "(cursor-at-start || input-empty) && no-results", + "(cursor-at-start && !input-empty) || no-results", + ]; + for s in cases { + let expr = ConditionExpr::parse(s).unwrap(); + let displayed = expr.to_string(); + let reparsed = ConditionExpr::parse(&displayed).unwrap(); + assert_eq!(expr, reparsed, "round-trip failed for: {s}"); + } + } + + // -- Serde -- + + #[test] + fn serde_simple_atom() { + let expr = ConditionExpr::Atom(ConditionAtom::CursorAtStart); + let json = serde_json::to_string(&expr).unwrap(); + assert_eq!(json, "\"cursor-at-start\""); + let parsed: ConditionExpr = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed, expr); + } + + #[test] + fn serde_compound_expression() { + let json = "\"cursor-at-start && !input-empty\""; + let parsed: ConditionExpr = serde_json::from_str(json).unwrap(); + let expected = ConditionExpr::And( + Box::new(ConditionExpr::Atom(ConditionAtom::CursorAtStart)), + Box::new(ConditionExpr::Not(Box::new(ConditionExpr::Atom( + ConditionAtom::InputEmpty, + )))), + ); + assert_eq!(parsed, expected); + } + + #[test] + fn serde_round_trip() { + let expr = ConditionExpr::parse("(cursor-at-start && !input-empty) || no-results").unwrap(); + let json = serde_json::to_string(&expr).unwrap(); + let parsed: ConditionExpr = serde_json::from_str(&json).unwrap(); + assert_eq!(expr, parsed); + } + + // -- From -- + + #[test] + fn from_atom_into_expr() { + let expr: ConditionExpr = ConditionAtom::CursorAtStart.into(); + assert_eq!(expr, ConditionExpr::Atom(ConditionAtom::CursorAtStart)); + } + + // -- Builder helpers -- + + #[test] + fn builder_chain() { + let expr = ConditionExpr::from(ConditionAtom::CursorAtStart) + .and(ConditionExpr::from(ConditionAtom::InputEmpty).not()) + .or(ConditionExpr::from(ConditionAtom::NoResults)); + // And binds tighter than Or, so no parens needed around the And + assert_eq!( + expr.to_string(), + "cursor-at-start && !input-empty || no-results" + ); + } +} diff --git a/crates/turtle/src/command/client/search/keybindings/defaults.rs b/crates/turtle/src/command/client/search/keybindings/defaults.rs new file mode 100644 index 00000000..c8401e37 --- /dev/null +++ b/crates/turtle/src/command/client/search/keybindings/defaults.rs @@ -0,0 +1,1286 @@ +use std::collections::HashMap; + +use crate::atuin_client::settings::{KeyBindingConfig, Settings}; +use tracing::warn; + +use super::actions::Action; +use super::conditions::{ConditionAtom, ConditionExpr}; +use super::key::KeyInput; +use super::keymap::{KeyBinding, KeyRule, Keymap}; + +/// Helper to bind a scroll key with optional exit behavior. +/// +/// When `scroll_exits` is true AND the key scrolls toward index 0 (the newest +/// entry), we add a conditional rule: at `ListAtStart` → `Exit`, otherwise → +/// the scroll action. +/// +/// Whether a key scrolls toward index 0 depends on the `invert` setting: +/// - Non-inverted: "down" / "j" move toward index 0, "up" / "k" move away +/// - Inverted: "up" / "k" move toward index 0, "down" / "j" move away +/// +/// If `toward_index_zero` is false, or `scroll_exits` is false, we just bind +/// the key to the plain scroll action (no exit). +fn bind_scroll_key( + km: &mut Keymap, + key_str: &str, + action: Action, + toward_index_zero: bool, + scroll_exits: bool, +) { + let k = key(key_str); + if scroll_exits && toward_index_zero { + km.bind_conditional( + k, + vec![ + KeyRule::when(ConditionAtom::ListAtStart, Action::Exit), + KeyRule::always(action), + ], + ); + } else { + km.bind(k, action); + } +} + +/// Helper to parse a key string, panicking on invalid keys (these are all +/// compile-time-known strings). +fn key(s: &str) -> KeyInput { + KeyInput::parse(s).unwrap_or_else(|e| panic!("invalid default key {s:?}: {e}")) +} + +/// All five keymaps bundled together. +#[derive(Debug, Clone)] +pub struct KeymapSet { + pub emacs: Keymap, + pub vim_normal: Keymap, + pub vim_insert: Keymap, + pub inspector: Keymap, + pub prefix: Keymap, +} + +// --------------------------------------------------------------------------- +// Common bindings shared across search-tab keymaps +// --------------------------------------------------------------------------- + +/// Add the bindings that are common to all search-tab keymaps: +/// ctrl-c, ctrl-g, ctrl-o, and tab. +/// +/// Note: `esc`/`ctrl-[` are NOT included here because their behavior differs +/// between emacs (exit), vim-normal (exit), and vim-insert (enter normal mode). +fn add_common_bindings(km: &mut Keymap) { + km.bind(key("ctrl-c"), Action::ReturnOriginal); + km.bind(key("ctrl-g"), Action::ReturnOriginal); + km.bind(key("ctrl-o"), Action::ToggleTab); + + // Tab: always returns selection without executing (unlike Enter which respects enter_accept) + km.bind(key("tab"), Action::ReturnSelection); +} + +/// Returns `Accept` or `ReturnSelection` based on the `enter_accept` setting. +fn accept_action(settings: &Settings) -> Action { + if settings.enter_accept { + Action::Accept + } else { + Action::ReturnSelection + } +} + +// --------------------------------------------------------------------------- +// Emacs keymap (also base for vim-insert) +// --------------------------------------------------------------------------- + +/// Build the default emacs keymap. This encodes the behavior from +/// `handle_key_input` common section + `handle_search_input` shared section. +/// +/// The `settings` parameter is used for: +/// - `keys.prefix` — which ctrl-key enters prefix mode +/// - `keys.scroll_exits`, `invert` — scroll-at-boundary exit behavior +/// - `keys.accept_past_line_end` — right arrow at end of line accepts +/// - `keys.exit_past_line_start` — left arrow at start of line exits +/// - `keys.accept_past_line_start` — left arrow at start accepts (overrides exit) +/// - `keys.accept_with_backspace` — backspace at start of line accepts +/// - `ctrl_n_shortcuts` — whether alt or ctrl is used for numeric shortcuts +// Keymap builder that enumerates every default binding; not worth splitting. +#[expect(clippy::too_many_lines)] +pub fn default_emacs_keymap(settings: &Settings) -> Keymap { + let mut km = Keymap::new(); + add_common_bindings(&mut km); + + let accept = accept_action(settings); + + // esc / ctrl-[ → exit + km.bind(key("esc"), Action::Exit); + km.bind(key("ctrl-["), Action::Exit); + + // Prefix key: ctrl- → enter prefix mode + let prefix_char = settings.keys.prefix.chars().next().unwrap_or('a'); + km.bind(key(&format!("ctrl-{prefix_char}")), Action::EnterPrefixMode); + + // --- Accept / navigation edge behaviors (from [keys] settings) --- + + // right: behavior at end of line + if settings.keys.accept_past_line_end { + km.bind_conditional( + key("right"), + vec![ + KeyRule::when(ConditionAtom::CursorAtEnd, Action::ReturnSelection), + KeyRule::always(Action::CursorRight), + ], + ); + } else { + km.bind(key("right"), Action::CursorRight); + } + + // left: behavior at start of line + // accept_past_line_start takes precedence over exit_past_line_start + if settings.keys.accept_past_line_start { + km.bind_conditional( + key("left"), + vec![ + KeyRule::when(ConditionAtom::CursorAtStart, Action::ReturnSelection), + KeyRule::always(Action::CursorLeft), + ], + ); + } else if settings.keys.exit_past_line_start { + km.bind_conditional( + key("left"), + vec![ + KeyRule::when(ConditionAtom::CursorAtStart, Action::Exit), + KeyRule::always(Action::CursorLeft), + ], + ); + } else { + km.bind(key("left"), Action::CursorLeft); + } + + // down/up: scroll with optional exit at boundary. + // Non-inverted: down moves toward index 0 (can exit); up moves away (no exit). + // Inverted: up moves toward index 0 (can exit); down moves away (no exit). + let scroll_exits = settings.keys.scroll_exits; + let invert = settings.invert; + bind_scroll_key(&mut km, "down", Action::SelectNext, !invert, scroll_exits); + bind_scroll_key(&mut km, "up", Action::SelectPrevious, invert, scroll_exits); + + // backspace: behavior at start of line + if settings.keys.accept_with_backspace { + km.bind_conditional( + key("backspace"), + vec![ + KeyRule::when(ConditionAtom::CursorAtStart, Action::ReturnSelection), + KeyRule::always(Action::DeleteCharBefore), + ], + ); + } else { + km.bind(key("backspace"), Action::DeleteCharBefore); + } + + // --- Accept --- + km.bind(key("enter"), accept.clone()); + km.bind(key("ctrl-m"), accept); + + // --- Copy --- + km.bind(key("ctrl-y"), Action::Copy); + + // --- Numeric shortcuts (alt-1..9 by default, ctrl-1..9 if ctrl_n_shortcuts) --- + // These return the selection without executing, regardless of enter_accept. + let num_mod = if settings.ctrl_n_shortcuts { + "ctrl" + } else { + "alt" + }; + for n in 1..=9u8 { + km.bind( + key(&format!("{num_mod}-{n}")), + Action::ReturnSelectionNth(n), + ); + } + + // --- Cursor movement --- + km.bind(key("ctrl-left"), Action::CursorWordLeft); + km.bind(key("alt-b"), Action::CursorWordLeft); + km.bind(key("ctrl-b"), Action::CursorLeft); + km.bind(key("ctrl-right"), Action::CursorWordRight); + km.bind(key("alt-f"), Action::CursorWordRight); + km.bind(key("ctrl-f"), Action::CursorRight); + km.bind(key("home"), Action::CursorStart); + // ctrl-a → CursorStart only if prefix char is NOT 'a' + // (otherwise ctrl-a is already bound to EnterPrefixMode above) + if prefix_char != 'a' { + km.bind(key("ctrl-a"), Action::CursorStart); + } + km.bind(key("ctrl-e"), Action::CursorEnd); + km.bind(key("end"), Action::CursorEnd); + + // --- Editing --- + km.bind(key("ctrl-backspace"), Action::DeleteWordBefore); + km.bind(key("ctrl-h"), Action::DeleteCharBefore); + km.bind(key("ctrl-?"), Action::DeleteCharBefore); + km.bind(key("ctrl-delete"), Action::DeleteWordAfter); + km.bind(key("delete"), Action::DeleteCharAfter); + // ctrl-d: if input empty → return original, otherwise delete char + km.bind_conditional( + key("ctrl-d"), + vec![ + KeyRule::when(ConditionAtom::InputEmpty, Action::ReturnOriginal), + KeyRule::always(Action::DeleteCharAfter), + ], + ); + km.bind(key("ctrl-w"), Action::DeleteToWordBoundary); + km.bind(key("ctrl-u"), Action::ClearLine); + + // --- Search mode --- + km.bind(key("ctrl-r"), Action::CycleFilterMode); + km.bind(key("ctrl-s"), Action::CycleSearchMode); + + // --- Scroll (no exit) --- + km.bind(key("ctrl-n"), Action::SelectNext); + km.bind(key("ctrl-j"), Action::SelectNext); + km.bind(key("ctrl-p"), Action::SelectPrevious); + km.bind(key("ctrl-k"), Action::SelectPrevious); + + // --- Redraw --- + km.bind(key("ctrl-l"), Action::Redraw); + + // --- Page scroll --- + km.bind(key("pagedown"), Action::ScrollPageDown); + km.bind(key("pageup"), Action::ScrollPageUp); + + km +} + +// --------------------------------------------------------------------------- +// Vim Normal keymap +// --------------------------------------------------------------------------- + +/// Build the default vim-normal keymap. +pub fn default_vim_normal_keymap(settings: &Settings) -> Keymap { + let mut km = Keymap::new(); + add_common_bindings(&mut km); + + // esc / ctrl-[ → exit (vim-normal exits, unlike vim-insert) + km.bind(key("esc"), Action::Exit); + km.bind(key("ctrl-["), Action::Exit); + + // Prefix key + let prefix_char = settings.keys.prefix.chars().next().unwrap_or('a'); + km.bind(key(&format!("ctrl-{prefix_char}")), Action::EnterPrefixMode); + + // --- Vim navigation --- + // j/k: scroll with optional exit at boundary. + let scroll_exits = settings.keys.scroll_exits; + let invert = settings.invert; + bind_scroll_key(&mut km, "j", Action::SelectNext, !invert, scroll_exits); + bind_scroll_key(&mut km, "k", Action::SelectPrevious, invert, scroll_exits); + km.bind(key("h"), Action::CursorLeft); + km.bind(key("l"), Action::CursorRight); + + // --- Vim cursor movement --- + km.bind(key("0"), Action::CursorStart); + km.bind(key("$"), Action::CursorEnd); + km.bind(key("w"), Action::CursorWordRight); + km.bind(key("b"), Action::CursorWordLeft); + km.bind(key("e"), Action::CursorWordEnd); + + // --- Vim editing --- + km.bind(key("x"), Action::DeleteCharAfter); + km.bind(key("d d"), Action::ClearLine); + km.bind(key("D"), Action::ClearToEnd); + km.bind(key("C"), Action::VimChangeToEnd); + + // --- Mode switching --- + km.bind(key("?"), Action::VimSearchInsert); + km.bind(key("/"), Action::VimSearchInsert); + km.bind(key("a"), Action::VimEnterInsertAfter); + km.bind(key("A"), Action::VimEnterInsertAtEnd); + km.bind(key("i"), Action::VimEnterInsert); + km.bind(key("I"), Action::VimEnterInsertAtStart); + + // --- Numeric shortcuts (return selection without executing) --- + for n in 1..=9u8 { + km.bind(key(&n.to_string()), Action::ReturnSelectionNth(n)); + } + + // --- Half/full page scroll --- + km.bind(key("ctrl-u"), Action::ScrollHalfPageUp); + km.bind(key("ctrl-d"), Action::ScrollHalfPageDown); + km.bind(key("ctrl-b"), Action::ScrollPageUp); + km.bind(key("ctrl-f"), Action::ScrollPageDown); + + // --- Jump --- + km.bind(key("G"), Action::ScrollToBottom); + km.bind(key("g g"), Action::ScrollToTop); + km.bind(key("H"), Action::ScrollToScreenTop); + km.bind(key("M"), Action::ScrollToScreenMiddle); + km.bind(key("L"), Action::ScrollToScreenBottom); + + // --- Arrow keys (same as emacs for convenience) --- + bind_scroll_key(&mut km, "down", Action::SelectNext, !invert, scroll_exits); + bind_scroll_key(&mut km, "up", Action::SelectPrevious, invert, scroll_exits); + + // --- Page scroll --- + km.bind(key("pagedown"), Action::ScrollPageDown); + km.bind(key("pageup"), Action::ScrollPageUp); + + // --- Accept --- + let accept = accept_action(settings); + km.bind(key("enter"), accept); + + km +} + +// --------------------------------------------------------------------------- +// Vim Insert keymap +// --------------------------------------------------------------------------- + +/// Build the default vim-insert keymap. This clones the emacs keymap and +/// overlays vim-insert-specific bindings (esc → enter normal mode). +pub fn default_vim_insert_keymap(settings: &Settings) -> Keymap { + let mut km = default_emacs_keymap(settings); + + // Override esc and ctrl-[ to enter normal mode instead of exiting + km.bind(key("esc"), Action::VimEnterNormal); + km.bind(key("ctrl-["), Action::VimEnterNormal); + + km +} + +// --------------------------------------------------------------------------- +// Inspector keymap +// --------------------------------------------------------------------------- + +/// Build the default inspector keymap (tab index 1). +/// +/// The inspector shows details about the selected history item and has no +/// text input, so we build a minimal keymap with only inspector-relevant +/// bindings. We respect the user's `keymap_mode` to provide vim-style j/k +/// navigation for vim users. +pub fn default_inspector_keymap(settings: &Settings) -> Keymap { + use crate::atuin_client::settings::KeymapMode; + + let mut km = Keymap::new(); + + // Common bindings (same as search tab) + km.bind(key("ctrl-c"), Action::ReturnOriginal); + km.bind(key("ctrl-g"), Action::ReturnOriginal); + km.bind(key("esc"), Action::Exit); + km.bind(key("ctrl-["), Action::Exit); + km.bind(key("tab"), Action::ReturnSelection); + km.bind(key("ctrl-o"), Action::ToggleTab); + + // Accept behavior respects enter_accept setting + let accept = if settings.enter_accept { + Action::Accept + } else { + Action::ReturnSelection + }; + km.bind(key("enter"), accept); + + // Inspector-specific: delete history entry + km.bind(key("ctrl-d"), Action::Delete); + + // Inspector navigation + km.bind(key("up"), Action::InspectPrevious); + km.bind(key("down"), Action::InspectNext); + km.bind(key("pageup"), Action::InspectPrevious); + km.bind(key("pagedown"), Action::InspectNext); + + // For vim users, add j/k navigation + if matches!( + settings.keymap_mode, + KeymapMode::VimNormal | KeymapMode::VimInsert + ) { + km.bind(key("j"), Action::InspectNext); + km.bind(key("k"), Action::InspectPrevious); + } + + km +} + +// --------------------------------------------------------------------------- +// Prefix keymap +// --------------------------------------------------------------------------- + +/// Build the default prefix keymap (active after ctrl-a prefix). +pub fn default_prefix_keymap() -> Keymap { + let mut km = Keymap::new(); + + km.bind(key("d"), Action::Delete); + km.bind(key("D"), Action::DeleteAll); + km.bind(key("a"), Action::CursorStart); + km.bind_conditional( + key("c"), + vec![ + KeyRule::when(ConditionAtom::HasContext, Action::ClearContext), + KeyRule::always(Action::SwitchContext), + ], + ); + + km +} + +// --------------------------------------------------------------------------- +// KeymapSet construction +// --------------------------------------------------------------------------- + +// --------------------------------------------------------------------------- +// Config → Keymap conversion +// --------------------------------------------------------------------------- + +/// Convert a `KeyBindingConfig` (from TOML) into a `KeyBinding`. +/// Returns `Err` if an action name or condition expression is invalid. +fn parse_binding_config(config: &KeyBindingConfig) -> Result { + match config { + KeyBindingConfig::Simple(action_str) => { + let action = Action::from_str(action_str)?; + Ok(KeyBinding::simple(action)) + } + KeyBindingConfig::Rules(rules) => { + let mut parsed_rules = Vec::with_capacity(rules.len()); + for rule_cfg in rules { + let action = Action::from_str(&rule_cfg.action)?; + let rule = match &rule_cfg.when { + None => KeyRule::always(action), + Some(cond_str) => { + let cond = ConditionExpr::parse(cond_str)?; + KeyRule::when(cond, action) + } + }; + parsed_rules.push(rule); + } + Ok(KeyBinding::conditional(parsed_rules)) + } + } +} + +/// Apply a map of key-string → binding-config overrides to a keymap. +/// Per-key override replaces the entire rule list for that key. +/// Invalid keys or action names are logged and skipped. +fn apply_config_to_keymap(keymap: &mut Keymap, overrides: &HashMap) { + for (key_str, binding_cfg) in overrides { + let key = match KeyInput::parse(key_str) { + Ok(k) => k, + Err(e) => { + warn!("invalid key in keymap config: {key_str:?}: {e}"); + continue; + } + }; + match parse_binding_config(binding_cfg) { + Ok(binding) => { + keymap.bindings.insert(key, binding); + } + Err(e) => { + warn!("invalid binding for {key_str:?} in keymap config: {e}"); + } + } + } +} + +impl KeymapSet { + /// Build the complete set of default keymaps from settings. + pub fn defaults(settings: &Settings) -> Self { + KeymapSet { + emacs: default_emacs_keymap(settings), + vim_normal: default_vim_normal_keymap(settings), + vim_insert: default_vim_insert_keymap(settings), + inspector: default_inspector_keymap(settings), + prefix: default_prefix_keymap(), + } + } + + /// Build keymaps from settings, applying any user `[keymap]` overrides. + /// + /// Precedence rules: + /// - If `[keymap]` has any entries, `[keys]` is **ignored entirely**. + /// Defaults are built with standard `[keys]` values, then `[keymap]` + /// overrides are applied per-key. + /// - If `[keymap]` is empty/absent, `[keys]` customizes the defaults + /// (current behavior for backward compatibility). + pub fn from_settings(settings: &Settings) -> Self { + use crate::atuin_client::settings::Keys; + + if settings.keymap.is_empty() { + // No [keymap] section → use [keys] to customize defaults + Self::defaults(settings) + } else { + // [keymap] present → ignore [keys], use standard defaults as base + let mut base_settings = settings.clone(); + base_settings.keys = Keys::standard_defaults(); + let mut set = Self::defaults(&base_settings); + set.apply_config(settings); + set + } + } + + /// Apply user keymap config overrides to all modes. + fn apply_config(&mut self, settings: &Settings) { + let config = &settings.keymap; + apply_config_to_keymap(&mut self.emacs, &config.emacs); + apply_config_to_keymap(&mut self.vim_normal, &config.vim_normal); + apply_config_to_keymap(&mut self.vim_insert, &config.vim_insert); + apply_config_to_keymap(&mut self.inspector, &config.inspector); + apply_config_to_keymap(&mut self.prefix, &config.prefix); + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use crate::command::client::search::keybindings::conditions::EvalContext; + + fn make_ctx(cursor: usize, width: usize, selected: usize, len: usize) -> EvalContext { + EvalContext { + cursor_position: cursor, + input_width: width, + input_byte_len: width, + selected_index: selected, + results_len: len, + original_input_empty: false, + has_context: false, + } + } + + fn default_settings() -> Settings { + Settings::utc() + } + + // -- Emacs keymap tests -- + + #[test] + fn emacs_ctrl_c_returns_original() { + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + km.resolve(&key("ctrl-c"), &ctx), + Some(Action::ReturnOriginal) + ); + } + + #[test] + fn emacs_esc_exits() { + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("esc"), &ctx), Some(Action::Exit)); + } + + #[test] + fn emacs_tab_returns_selection() { + // enter_accept=false in test defaults → ReturnSelection + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("tab"), &ctx), Some(Action::ReturnSelection)); + } + + #[test] + fn emacs_enter_returns_selection() { + // enter_accept=false in test defaults → ReturnSelection + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + km.resolve(&key("enter"), &ctx), + Some(Action::ReturnSelection) + ); + } + + #[test] + fn emacs_enter_accept_true_uses_accept() { + let mut settings = default_settings(); + settings.enter_accept = true; + let km = default_emacs_keymap(&settings); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("enter"), &ctx), Some(Action::Accept)); + assert_eq!(km.resolve(&key("tab"), &ctx), Some(Action::ReturnSelection)); + } + + #[test] + fn emacs_right_at_end_returns_selection() { + let km = default_emacs_keymap(&default_settings()); + // cursor at end of "hello" (width 5) + let ctx = make_ctx(5, 5, 0, 10); + assert_eq!( + km.resolve(&key("right"), &ctx), + Some(Action::ReturnSelection) + ); + } + + #[test] + fn emacs_right_not_at_end_moves() { + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(2, 5, 0, 10); + assert_eq!(km.resolve(&key("right"), &ctx), Some(Action::CursorRight)); + } + + #[test] + fn emacs_left_at_start_exits() { + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(0, 5, 0, 10); + assert_eq!(km.resolve(&key("left"), &ctx), Some(Action::Exit)); + } + + #[test] + fn emacs_left_not_at_start_moves() { + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(3, 5, 0, 10); + assert_eq!(km.resolve(&key("left"), &ctx), Some(Action::CursorLeft)); + } + + #[test] + fn emacs_down_at_start_exits() { + let km = default_emacs_keymap(&default_settings()); + // selected=0 → ListAtStart → Exit + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("down"), &ctx), Some(Action::Exit)); + } + + #[test] + fn emacs_down_not_at_start_selects_next() { + let km = default_emacs_keymap(&default_settings()); + // selected=5 → not at start → SelectNext + let ctx = make_ctx(0, 0, 5, 10); + assert_eq!(km.resolve(&key("down"), &ctx), Some(Action::SelectNext)); + } + + #[test] + fn emacs_up_selects_previous() { + let km = default_emacs_keymap(&default_settings()); + // Non-inverted: up never exits (moves away from index 0) + let ctx = make_ctx(0, 0, 5, 10); + assert_eq!(km.resolve(&key("up"), &ctx), Some(Action::SelectPrevious)); + } + + #[test] + fn emacs_ctrl_d_empty_returns_original() { + let km = default_emacs_keymap(&default_settings()); + // input empty (byte_len = 0) + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + km.resolve(&key("ctrl-d"), &ctx), + Some(Action::ReturnOriginal) + ); + } + + #[test] + fn emacs_ctrl_d_nonempty_deletes() { + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(2, 5, 0, 10); + assert_eq!( + km.resolve(&key("ctrl-d"), &ctx), + Some(Action::DeleteCharAfter) + ); + } + + #[test] + fn emacs_ctrl_n_selects_next_no_exit_condition() { + let km = default_emacs_keymap(&default_settings()); + // at start, but ctrl-n should NOT exit (no exit condition bound) + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("ctrl-n"), &ctx), Some(Action::SelectNext)); + } + + #[test] + fn emacs_prefix_key_enters_prefix() { + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + km.resolve(&key("ctrl-a"), &ctx), + Some(Action::EnterPrefixMode) + ); + } + + #[test] + fn emacs_home_cursor_start() { + let km = default_emacs_keymap(&default_settings()); + let ctx = make_ctx(5, 10, 0, 10); + assert_eq!(km.resolve(&key("home"), &ctx), Some(Action::CursorStart)); + } + + // -- Vim Normal keymap tests -- + + #[test] + fn vim_normal_j_at_start_exits() { + let km = default_vim_normal_keymap(&default_settings()); + // selected=0 → ListAtStart → Exit (non-inverted: j moves toward index 0) + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("j"), &ctx), Some(Action::Exit)); + } + + #[test] + fn vim_normal_j_not_at_start_selects_next() { + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 5, 10); + assert_eq!(km.resolve(&key("j"), &ctx), Some(Action::SelectNext)); + } + + #[test] + fn vim_normal_k_selects_previous() { + let km = default_vim_normal_keymap(&default_settings()); + // Non-inverted: k never exits (moves away from index 0) + let ctx = make_ctx(0, 0, 5, 10); + assert_eq!(km.resolve(&key("k"), &ctx), Some(Action::SelectPrevious)); + } + + #[test] + fn vim_normal_i_enters_insert() { + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("i"), &ctx), Some(Action::VimEnterInsert)); + } + + #[test] + fn vim_normal_slash_search_insert() { + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("/"), &ctx), Some(Action::VimSearchInsert)); + } + + #[test] + fn vim_normal_gg_scroll_to_top() { + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 50, 100); + assert_eq!(km.resolve(&key("g g"), &ctx), Some(Action::ScrollToTop)); + } + + #[test] + fn vim_normal_big_g_scroll_to_bottom() { + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 50, 100); + assert_eq!(km.resolve(&key("G"), &ctx), Some(Action::ScrollToBottom)); + } + + #[test] + fn vim_normal_numeric_returns_selection() { + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + km.resolve(&key("3"), &ctx), + Some(Action::ReturnSelectionNth(3)) + ); + } + + #[test] + fn vim_normal_ctrl_u_half_page_up() { + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 50, 100); + assert_eq!( + km.resolve(&key("ctrl-u"), &ctx), + Some(Action::ScrollHalfPageUp) + ); + } + + #[test] + fn vim_normal_screen_jumps() { + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 50, 100); + assert_eq!(km.resolve(&key("H"), &ctx), Some(Action::ScrollToScreenTop)); + assert_eq!( + km.resolve(&key("M"), &ctx), + Some(Action::ScrollToScreenMiddle) + ); + assert_eq!( + km.resolve(&key("L"), &ctx), + Some(Action::ScrollToScreenBottom) + ); + } + + #[test] + fn vim_normal_enter_returns_selection() { + // enter_accept=false in test defaults → ReturnSelection + let km = default_vim_normal_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + km.resolve(&key("enter"), &ctx), + Some(Action::ReturnSelection) + ); + } + + #[test] + fn vim_normal_enter_accept_true_uses_accept() { + let mut settings = default_settings(); + settings.enter_accept = true; + let km = default_vim_normal_keymap(&settings); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("enter"), &ctx), Some(Action::Accept)); + } + + // -- Vim Insert keymap tests -- + + #[test] + fn vim_insert_inherits_emacs_enter() { + let km = default_vim_insert_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + // enter_accept=false → ReturnSelection + assert_eq!( + km.resolve(&key("enter"), &ctx), + Some(Action::ReturnSelection) + ); + } + + #[test] + fn vim_insert_esc_enters_normal() { + let km = default_vim_insert_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("esc"), &ctx), Some(Action::VimEnterNormal)); + } + + #[test] + fn vim_insert_ctrl_bracket_enters_normal() { + let km = default_vim_insert_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + km.resolve(&key("ctrl-["), &ctx), + Some(Action::VimEnterNormal) + ); + } + + #[test] + fn vim_insert_inherits_emacs_ctrl_d() { + let km = default_vim_insert_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + // input empty → return original + assert_eq!( + km.resolve(&key("ctrl-d"), &ctx), + Some(Action::ReturnOriginal) + ); + } + + // -- Inspector keymap tests -- + + #[test] + fn inspector_ctrl_d_deletes() { + let km = default_inspector_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("ctrl-d"), &ctx), Some(Action::Delete)); + } + + #[test] + fn inspector_up_inspects_previous() { + let km = default_inspector_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("up"), &ctx), Some(Action::InspectPrevious)); + } + + #[test] + fn inspector_down_inspects_next() { + let km = default_inspector_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("down"), &ctx), Some(Action::InspectNext)); + } + + #[test] + fn inspector_esc_exits() { + let km = default_inspector_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("esc"), &ctx), Some(Action::Exit)); + } + + #[test] + fn inspector_tab_returns_selection() { + // enter_accept=false → ReturnSelection + let km = default_inspector_keymap(&default_settings()); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("tab"), &ctx), Some(Action::ReturnSelection)); + } + + // -- Prefix keymap tests -- + + #[test] + fn prefix_d_deletes() { + let km = default_prefix_keymap(); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("d"), &ctx), Some(Action::Delete)); + } + + #[test] + fn prefix_a_cursor_start() { + let km = default_prefix_keymap(); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("a"), &ctx), Some(Action::CursorStart)); + } + + #[test] + fn prefix_unknown_key_returns_none() { + let km = default_prefix_keymap(); + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(km.resolve(&key("x"), &ctx), None); + } + + // -- KeymapSet tests -- + + #[test] + fn keymap_set_defaults_builds() { + let settings = default_settings(); + let set = KeymapSet::defaults(&settings); + let ctx = make_ctx(0, 0, 0, 10); + + // Sanity check each keymap has bindings + assert!(set.emacs.resolve(&key("ctrl-c"), &ctx).is_some()); + assert!(set.vim_normal.resolve(&key("ctrl-c"), &ctx).is_some()); + assert!(set.vim_insert.resolve(&key("ctrl-c"), &ctx).is_some()); + assert!(set.inspector.resolve(&key("ctrl-c"), &ctx).is_some()); + assert!(set.prefix.resolve(&key("d"), &ctx).is_some()); + } + + // -- Settings-dependent behavior -- + + #[test] + fn custom_prefix_char() { + let mut settings = default_settings(); + settings.keys.prefix = "x".to_string(); + let km = default_emacs_keymap(&settings); + let ctx = make_ctx(0, 0, 0, 10); + + // ctrl-x should be prefix mode + assert_eq!( + km.resolve(&key("ctrl-x"), &ctx), + Some(Action::EnterPrefixMode) + ); + // ctrl-a should now be CursorStart (not prefix) + assert_eq!(km.resolve(&key("ctrl-a"), &ctx), Some(Action::CursorStart)); + } + + #[test] + fn ctrl_n_shortcuts_changes_numeric_modifier() { + let mut settings = default_settings(); + settings.ctrl_n_shortcuts = true; + let km = default_emacs_keymap(&settings); + let ctx = make_ctx(0, 0, 0, 10); + + // ctrl-1 should work + assert_eq!( + km.resolve(&key("ctrl-1"), &ctx), + Some(Action::ReturnSelectionNth(1)) + ); + // alt-1 should NOT be bound + assert_eq!(km.resolve(&key("alt-1"), &ctx), None); + } + + #[test] + fn default_alt_numeric_shortcuts() { + let settings = default_settings(); + let km = default_emacs_keymap(&settings); + let ctx = make_ctx(0, 0, 0, 10); + + // alt-1 should work by default + assert_eq!( + km.resolve(&key("alt-1"), &ctx), + Some(Action::ReturnSelectionNth(1)) + ); + } + + // ----------------------------------------------------------------------- + // Config parsing and merging tests + // ----------------------------------------------------------------------- + + #[test] + fn parse_simple_binding_config() { + use crate::atuin_client::settings::KeyBindingConfig; + let cfg = KeyBindingConfig::Simple("accept".to_string()); + let binding = super::parse_binding_config(&cfg).unwrap(); + assert_eq!(binding.rules.len(), 1); + assert!(binding.rules[0].condition.is_none()); + assert_eq!(binding.rules[0].action, Action::Accept); + } + + #[test] + fn parse_conditional_binding_config() { + use crate::atuin_client::settings::{KeyBindingConfig, KeyRuleConfig}; + let cfg = KeyBindingConfig::Rules(vec![ + KeyRuleConfig { + when: Some("cursor-at-start".to_string()), + action: "exit".to_string(), + }, + KeyRuleConfig { + when: None, + action: "cursor-left".to_string(), + }, + ]); + let binding = super::parse_binding_config(&cfg).unwrap(); + assert_eq!(binding.rules.len(), 2); + assert!(binding.rules[0].condition.is_some()); + assert_eq!(binding.rules[0].action, Action::Exit); + assert!(binding.rules[1].condition.is_none()); + assert_eq!(binding.rules[1].action, Action::CursorLeft); + } + + #[test] + fn parse_binding_config_invalid_action() { + use crate::atuin_client::settings::KeyBindingConfig; + let cfg = KeyBindingConfig::Simple("not-a-real-action".to_string()); + assert!(super::parse_binding_config(&cfg).is_err()); + } + + #[test] + fn parse_binding_config_invalid_condition() { + use crate::atuin_client::settings::{KeyBindingConfig, KeyRuleConfig}; + let cfg = KeyBindingConfig::Rules(vec![KeyRuleConfig { + when: Some("not-a-real-condition".to_string()), + action: "exit".to_string(), + }]); + assert!(super::parse_binding_config(&cfg).is_err()); + } + + #[test] + fn config_override_replaces_key() { + use crate::atuin_client::settings::KeyBindingConfig; + use std::collections::HashMap; + + let mut settings = default_settings(); + let set = KeymapSet::defaults(&settings); + + // Default: ctrl-c → ReturnOriginal + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + set.emacs.resolve(&key("ctrl-c"), &ctx), + Some(Action::ReturnOriginal) + ); + + // Override ctrl-c → Exit via config + settings.keymap.emacs = HashMap::from([( + "ctrl-c".to_string(), + KeyBindingConfig::Simple("exit".to_string()), + )]); + + let set = KeymapSet::from_settings(&settings); + assert_eq!(set.emacs.resolve(&key("ctrl-c"), &ctx), Some(Action::Exit)); + } + + #[test] + fn config_override_preserves_unoverridden_keys() { + use crate::atuin_client::settings::KeyBindingConfig; + use std::collections::HashMap; + + let mut settings = default_settings(); + // Override only ctrl-c; enter should keep its default + settings.keymap.emacs = HashMap::from([( + "ctrl-c".to_string(), + KeyBindingConfig::Simple("exit".to_string()), + )]); + + let set = KeymapSet::from_settings(&settings); + let ctx = make_ctx(0, 0, 0, 10); + + // ctrl-c overridden + assert_eq!(set.emacs.resolve(&key("ctrl-c"), &ctx), Some(Action::Exit)); + // enter still has default (enter_accept=false → ReturnSelection) + assert_eq!( + set.emacs.resolve(&key("enter"), &ctx), + Some(Action::ReturnSelection) + ); + } + + #[test] + fn config_conditional_override() { + use crate::atuin_client::settings::{KeyBindingConfig, KeyRuleConfig}; + use std::collections::HashMap; + + let mut settings = default_settings(); + // Override "up" with a custom conditional + settings.keymap.emacs = HashMap::from([( + "up".to_string(), + KeyBindingConfig::Rules(vec![ + KeyRuleConfig { + when: Some("no-results".to_string()), + action: "exit".to_string(), + }, + KeyRuleConfig { + when: None, + action: "select-previous".to_string(), + }, + ]), + )]); + + let set = KeymapSet::from_settings(&settings); + + // With no results → exit + let ctx = make_ctx(0, 0, 0, 0); + assert_eq!(set.emacs.resolve(&key("up"), &ctx), Some(Action::Exit)); + + // With results → select-previous + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!( + set.emacs.resolve(&key("up"), &ctx), + Some(Action::SelectPrevious) + ); + } + + #[test] + fn from_settings_with_empty_config_equals_defaults() { + let settings = default_settings(); + let defaults = KeymapSet::defaults(&settings); + let from_settings = KeymapSet::from_settings(&settings); + + // Verify a sample of keys produce the same results + let ctx = make_ctx(0, 0, 0, 10); + let test_keys = [ + "ctrl-c", "enter", "esc", "tab", "up", "down", "left", "right", + ]; + for k in &test_keys { + assert_eq!( + defaults.emacs.resolve(&key(k), &ctx), + from_settings.emacs.resolve(&key(k), &ctx), + "mismatch for emacs key {k}" + ); + } + } + + // ----------------------------------------------------------------------- + // Phase 5: [keys] vs [keymap] backward compatibility + // ----------------------------------------------------------------------- + + #[test] + fn keymap_overrides_ignore_keys_section() { + use crate::atuin_client::settings::KeyBindingConfig; + + // Set up: [keys] disables scroll_exits, but [keymap] is present + let mut settings = default_settings(); + settings.keys.scroll_exits = false; + + // Without [keymap], scroll_exits=false means no exit condition on down + let set_legacy = KeymapSet::defaults(&settings); + // At list-at-start (selected=0), down should still be SelectNext (no exit) + let ctx_at_boundary = make_ctx(0, 0, 0, 10); + assert_eq!( + set_legacy.emacs.resolve(&key("down"), &ctx_at_boundary), + Some(Action::SelectNext), + "legacy: down at boundary should be SelectNext with scroll_exits=false" + ); + + // With [keymap] present (even just one override), [keys] is ignored + // so the standard defaults (scroll_exits=true) apply + settings.keymap.emacs = HashMap::from([( + "ctrl-c".to_string(), + KeyBindingConfig::Simple("exit".to_string()), + )]); + let set_keymap = KeymapSet::from_settings(&settings); + + // Not at boundary (selected=5): should SelectNext normally + let ctx_not_at_boundary = make_ctx(0, 0, 5, 10); + assert_eq!( + set_keymap.emacs.resolve(&key("down"), &ctx_not_at_boundary), + Some(Action::SelectNext), + "keymap: down not at boundary should SelectNext" + ); + // At list-at-start (selected=0): should Exit (standard scroll_exits=true) + assert_eq!( + set_keymap.emacs.resolve(&key("down"), &ctx_at_boundary), + Some(Action::Exit), + "keymap: down at boundary should Exit (standard defaults restored)" + ); + } + + #[test] + fn keymap_present_resets_to_standard_keys_defaults() { + use crate::atuin_client::settings::KeyBindingConfig; + + let mut settings = default_settings(); + // Disable all [keys] behaviors + settings.keys.exit_past_line_start = false; + settings.keys.accept_past_line_end = false; + + // Without [keymap], left should be plain CursorLeft + let set_legacy = KeymapSet::defaults(&settings); + let ctx_at_start = make_ctx(0, 5, 0, 10); + assert_eq!( + set_legacy.emacs.resolve(&key("left"), &ctx_at_start), + Some(Action::CursorLeft), + "legacy: left should be plain CursorLeft without exit_past_line_start" + ); + + // Add a [keymap] entry (for a different key) + settings.keymap.emacs = HashMap::from([( + "ctrl-c".to_string(), + KeyBindingConfig::Simple("exit".to_string()), + )]); + let set_keymap = KeymapSet::from_settings(&settings); + + // Now left should use standard defaults (exit_past_line_start=true) + // At cursor start → Exit + assert_eq!( + set_keymap.emacs.resolve(&key("left"), &ctx_at_start), + Some(Action::Exit), + "keymap: left at cursor start should exit (standard defaults)" + ); + + // Right at cursor end should return selection (standard defaults: accept_past_line_end=true, enter_accept=false) + let ctx_at_end = make_ctx(5, 5, 0, 10); + assert_eq!( + set_keymap.emacs.resolve(&key("right"), &ctx_at_end), + Some(Action::ReturnSelection), + "keymap: right at cursor end should return selection (standard defaults)" + ); + } + + #[test] + fn keys_has_non_default_values_detection() { + use crate::atuin_client::settings::Keys; + + let standard = Keys::standard_defaults(); + assert!(!standard.has_non_default_values()); + + let mut modified = Keys::standard_defaults(); + modified.scroll_exits = false; + assert!(modified.has_non_default_values()); + + let mut modified = Keys::standard_defaults(); + modified.prefix = "x".to_string(); + assert!(modified.has_non_default_values()); + } + + #[test] + fn original_input_empty_condition_in_config() { + use crate::atuin_client::settings::{KeyBindingConfig, KeyRuleConfig}; + use std::collections::HashMap; + + let mut settings = default_settings(); + // Configure esc to: if original-input-empty -> return-query, else return-original + settings.keymap.emacs = HashMap::from([( + "esc".to_string(), + KeyBindingConfig::Rules(vec![ + KeyRuleConfig { + when: Some("original-input-empty".to_string()), + action: "return-query".to_string(), + }, + KeyRuleConfig { + when: None, + action: "return-original".to_string(), + }, + ]), + )]); + + let set = KeymapSet::from_settings(&settings); + + // When original input was empty, should return-query + let ctx_original_empty = EvalContext { + cursor_position: 0, + input_width: 5, + input_byte_len: 5, + selected_index: 0, + results_len: 10, + original_input_empty: true, + has_context: false, + }; + assert_eq!( + set.emacs.resolve(&key("esc"), &ctx_original_empty), + Some(Action::ReturnQuery), + "esc with original_input_empty=true should return-query" + ); + + // When original input was not empty, should return-original + let ctx_original_not_empty = EvalContext { + cursor_position: 0, + input_width: 5, + input_byte_len: 5, + selected_index: 0, + results_len: 10, + original_input_empty: false, + has_context: false, + }; + assert_eq!( + set.emacs.resolve(&key("esc"), &ctx_original_not_empty), + Some(Action::ReturnOriginal), + "esc with original_input_empty=false should return-original" + ); + } +} diff --git a/crates/turtle/src/command/client/search/keybindings/key.rs b/crates/turtle/src/command/client/search/keybindings/key.rs new file mode 100644 index 00000000..c2eb31c6 --- /dev/null +++ b/crates/turtle/src/command/client/search/keybindings/key.rs @@ -0,0 +1,629 @@ +use std::fmt; + +use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers, MediaKeyCode}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +/// A single key press with modifiers (e.g. `ctrl-c`, `alt-f`, `enter`). +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[expect(clippy::struct_excessive_bools)] +pub struct SingleKey { + pub code: KeyCodeValue, + pub ctrl: bool, + pub alt: bool, + pub shift: bool, + pub super_key: bool, +} + +/// The key code portion of a key press. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum KeyCodeValue { + Char(char), + Enter, + Esc, + Tab, + Backspace, + Delete, + Insert, + Up, + Down, + Left, + Right, + Home, + End, + PageUp, + PageDown, + Space, + F(u8), + Media(MediaKeyCode), +} + +/// A key input that may be a single key or a multi-key sequence (e.g. `g g`). +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum KeyInput { + Single(SingleKey), + Sequence(Vec), +} + +impl SingleKey { + /// Convert a crossterm `KeyEvent` into a `SingleKey`. + pub fn from_event(event: &KeyEvent) -> Option { + let ctrl = event.modifiers.contains(KeyModifiers::CONTROL); + let alt = event.modifiers.contains(KeyModifiers::ALT); + let shift = event.modifiers.contains(KeyModifiers::SHIFT); + let super_key = event.modifiers.contains(KeyModifiers::SUPER); + + let code = match event.code { + KeyCode::Char(' ') => KeyCodeValue::Space, + KeyCode::Char(c) => { + // If shift is the only modifier and it's an uppercase letter, + // we store the uppercase char directly and clear the shift flag + // since the case already encodes it. + if shift && !ctrl && !alt && !super_key && c.is_ascii_uppercase() { + return Some(SingleKey { + code: KeyCodeValue::Char(c), + ctrl: false, + alt: false, + shift: false, + super_key: false, + }); + } + KeyCodeValue::Char(c) + } + KeyCode::Enter => KeyCodeValue::Enter, + KeyCode::Esc => KeyCodeValue::Esc, + KeyCode::Tab => KeyCodeValue::Tab, + // BackTab is sent by many terminals for Shift+Tab + KeyCode::BackTab => { + return Some(SingleKey { + code: KeyCodeValue::Tab, + ctrl, + alt, + shift: true, + super_key, + }); + } + KeyCode::Backspace => KeyCodeValue::Backspace, + KeyCode::Delete => KeyCodeValue::Delete, + KeyCode::Insert => KeyCodeValue::Insert, + KeyCode::Up => KeyCodeValue::Up, + KeyCode::Down => KeyCodeValue::Down, + KeyCode::Left => KeyCodeValue::Left, + KeyCode::Right => KeyCodeValue::Right, + KeyCode::Home => KeyCodeValue::Home, + KeyCode::End => KeyCodeValue::End, + KeyCode::PageUp => KeyCodeValue::PageUp, + KeyCode::PageDown => KeyCodeValue::PageDown, + KeyCode::F(n) => KeyCodeValue::F(n), + KeyCode::Media(m) => KeyCodeValue::Media(m), + _ => return None, + }; + + Some(SingleKey { + code, + ctrl, + alt, + shift: if matches!(code, KeyCodeValue::Char(_)) { + false + } else { + shift + }, + super_key, + }) + } + + /// Parse a key string like `"ctrl-c"`, `"alt-f"`, `"enter"`, `"G"`. + pub fn parse(s: &str) -> Result { + let s = s.trim(); + let parts: Vec<&str> = s.split('-').collect(); + + let mut ctrl = false; + let mut alt = false; + let mut shift = false; + let mut super_key = false; + + // All parts except the last are modifiers + for &part in &parts[..parts.len() - 1] { + match part.to_lowercase().as_str() { + "ctrl" => ctrl = true, + "alt" => alt = true, + "shift" => shift = true, + "super" | "cmd" | "win" => super_key = true, + _ => return Err(format!("unknown modifier: {part}")), + } + } + + let key_part = parts[parts.len() - 1]; + let code = match key_part.to_lowercase().as_str() { + "enter" | "return" => KeyCodeValue::Enter, + "esc" | "escape" => KeyCodeValue::Esc, + "tab" => KeyCodeValue::Tab, + "backspace" => KeyCodeValue::Backspace, + "delete" | "del" => KeyCodeValue::Delete, + "insert" | "ins" => KeyCodeValue::Insert, + "up" => KeyCodeValue::Up, + "down" => KeyCodeValue::Down, + "left" => KeyCodeValue::Left, + "right" => KeyCodeValue::Right, + "home" => KeyCodeValue::Home, + "end" => KeyCodeValue::End, + "pageup" => KeyCodeValue::PageUp, + "pagedown" => KeyCodeValue::PageDown, + "space" => KeyCodeValue::Space, + s if s.starts_with('f') && s.len() > 1 => { + // Parse function keys like "f1", "f12" + if let Ok(n) = s[1..].parse::() { + if (1..=24).contains(&n) { + KeyCodeValue::F(n) + } else { + return Err(format!("function key out of range: {key_part}")); + } + } else { + return Err(format!("unknown key: {key_part}")); + } + } + "[" => KeyCodeValue::Char('['), + "]" => KeyCodeValue::Char(']'), + "?" => KeyCodeValue::Char('?'), + "/" => KeyCodeValue::Char('/'), + "$" => KeyCodeValue::Char('$'), + // Media keys (no dashes - the parser splits on dash for modifiers) + "play" => KeyCodeValue::Media(MediaKeyCode::Play), + "pause" => KeyCodeValue::Media(MediaKeyCode::Pause), + "playpause" => KeyCodeValue::Media(MediaKeyCode::PlayPause), + "stop" => KeyCodeValue::Media(MediaKeyCode::Stop), + "fastforward" => KeyCodeValue::Media(MediaKeyCode::FastForward), + "rewind" => KeyCodeValue::Media(MediaKeyCode::Rewind), + "tracknext" => KeyCodeValue::Media(MediaKeyCode::TrackNext), + "trackprevious" => KeyCodeValue::Media(MediaKeyCode::TrackPrevious), + "record" => KeyCodeValue::Media(MediaKeyCode::Record), + "lowervolume" => KeyCodeValue::Media(MediaKeyCode::LowerVolume), + "raisevolume" => KeyCodeValue::Media(MediaKeyCode::RaiseVolume), + "mutevolume" | "mute" => KeyCodeValue::Media(MediaKeyCode::MuteVolume), + _ => { + let chars: Vec = key_part.chars().collect(); + if chars.len() == 1 { + let c = chars[0]; + // An uppercase letter implies shift (unless shift already specified) + if c.is_ascii_uppercase() && !ctrl && !alt && !super_key { + return Ok(SingleKey { + code: KeyCodeValue::Char(c), + ctrl: false, + alt: false, + shift: false, + super_key: false, + }); + } + KeyCodeValue::Char(c) + } else { + return Err(format!("unknown key: {key_part}")); + } + } + }; + + Ok(SingleKey { + code, + ctrl, + alt, + shift, + super_key, + }) + } +} + +impl fmt::Display for SingleKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.super_key { + write!(f, "super-")?; + } + if self.ctrl { + write!(f, "ctrl-")?; + } + if self.alt { + write!(f, "alt-")?; + } + if self.shift { + write!(f, "shift-")?; + } + match &self.code { + KeyCodeValue::Char(c) => write!(f, "{c}"), + KeyCodeValue::Enter => write!(f, "enter"), + KeyCodeValue::Esc => write!(f, "esc"), + KeyCodeValue::Tab => write!(f, "tab"), + KeyCodeValue::Backspace => write!(f, "backspace"), + KeyCodeValue::Delete => write!(f, "delete"), + KeyCodeValue::Insert => write!(f, "insert"), + KeyCodeValue::Up => write!(f, "up"), + KeyCodeValue::Down => write!(f, "down"), + KeyCodeValue::Left => write!(f, "left"), + KeyCodeValue::Right => write!(f, "right"), + KeyCodeValue::Home => write!(f, "home"), + KeyCodeValue::End => write!(f, "end"), + KeyCodeValue::PageUp => write!(f, "pageup"), + KeyCodeValue::PageDown => write!(f, "pagedown"), + KeyCodeValue::Space => write!(f, "space"), + KeyCodeValue::F(n) => write!(f, "f{n}"), + KeyCodeValue::Media(m) => match m { + MediaKeyCode::Play => write!(f, "play"), + MediaKeyCode::Pause => write!(f, "media-pause"), + MediaKeyCode::PlayPause => write!(f, "playpause"), + MediaKeyCode::Stop => write!(f, "stop"), + MediaKeyCode::FastForward => write!(f, "fastforward"), + MediaKeyCode::Rewind => write!(f, "rewind"), + MediaKeyCode::TrackNext => write!(f, "tracknext"), + MediaKeyCode::TrackPrevious => write!(f, "trackprevious"), + MediaKeyCode::Record => write!(f, "record"), + MediaKeyCode::LowerVolume => write!(f, "lowervolume"), + MediaKeyCode::RaiseVolume => write!(f, "raisevolume"), + MediaKeyCode::MuteVolume => write!(f, "mutevolume"), + MediaKeyCode::Reverse => write!(f, "reverse"), + }, + } + } +} + +impl KeyInput { + /// Parse a key input string. Supports multi-key sequences separated by spaces + /// (e.g. `"g g"`). + pub fn parse(s: &str) -> Result { + let s = s.trim(); + // Check for space-separated multi-key sequences + // But don't split "space" or modifier combos like "ctrl-a" + let parts: Vec<&str> = s.split_whitespace().collect(); + if parts.len() > 1 { + let keys: Result, String> = + parts.iter().map(|p| SingleKey::parse(p)).collect(); + Ok(KeyInput::Sequence(keys?)) + } else { + Ok(KeyInput::Single(SingleKey::parse(s)?)) + } + } +} + +impl fmt::Display for KeyInput { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + KeyInput::Single(k) => write!(f, "{k}"), + KeyInput::Sequence(keys) => { + for (i, k) in keys.iter().enumerate() { + if i > 0 { + write!(f, " ")?; + } + write!(f, "{k}")?; + } + Ok(()) + } + } + } +} + +impl Serialize for KeyInput { + fn serialize(&self, serializer: S) -> Result { + serializer.serialize_str(&self.to_string()) + } +} + +impl<'de> Deserialize<'de> for KeyInput { + fn deserialize>(deserializer: D) -> Result { + let s = String::deserialize(deserializer)?; + KeyInput::parse(&s).map_err(serde::de::Error::custom) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + + #[test] + fn parse_simple_keys() { + let k = SingleKey::parse("a").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('a')); + assert!(!k.ctrl && !k.alt && !k.shift); + + let k = SingleKey::parse("enter").unwrap(); + assert_eq!(k.code, KeyCodeValue::Enter); + + let k = SingleKey::parse("esc").unwrap(); + assert_eq!(k.code, KeyCodeValue::Esc); + + let k = SingleKey::parse("tab").unwrap(); + assert_eq!(k.code, KeyCodeValue::Tab); + + let k = SingleKey::parse("space").unwrap(); + assert_eq!(k.code, KeyCodeValue::Space); + } + + #[test] + fn parse_modifiers() { + let k = SingleKey::parse("ctrl-c").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('c')); + assert!(k.ctrl); + assert!(!k.alt); + + let k = SingleKey::parse("alt-f").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('f')); + assert!(k.alt); + assert!(!k.ctrl); + + let k = SingleKey::parse("ctrl-alt-x").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('x')); + assert!(k.ctrl && k.alt); + } + + #[test] + fn parse_uppercase_implies_no_shift_flag() { + let k = SingleKey::parse("G").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('G')); + assert!(!k.shift); + assert!(!k.ctrl); + } + + #[test] + fn parse_special_chars() { + let k = SingleKey::parse("ctrl-[").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('[')); + assert!(k.ctrl); + + let k = SingleKey::parse("?").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('?')); + + let k = SingleKey::parse("/").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('/')); + } + + #[test] + fn parse_multi_key_sequence() { + let ki = KeyInput::parse("g g").unwrap(); + match ki { + KeyInput::Sequence(keys) => { + assert_eq!(keys.len(), 2); + assert_eq!(keys[0].code, KeyCodeValue::Char('g')); + assert_eq!(keys[1].code, KeyCodeValue::Char('g')); + } + _ => panic!("expected sequence"), + } + } + + #[test] + fn display_round_trip() { + let cases = ["ctrl-c", "alt-f", "enter", "G", "tab", "pageup"]; + for s in cases { + let k = KeyInput::parse(s).unwrap(); + let display = k.to_string(); + let k2 = KeyInput::parse(&display).unwrap(); + assert_eq!(k, k2, "round-trip failed for {s}"); + } + + let ki = KeyInput::parse("g g").unwrap(); + assert_eq!(ki.to_string(), "g g"); + } + + #[test] + fn from_event_basic() { + let event = KeyEvent::new(KeyCode::Char('c'), KeyModifiers::CONTROL); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('c')); + assert!(k.ctrl); + assert!(!k.alt); + + let event = KeyEvent::new(KeyCode::Enter, KeyModifiers::NONE); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::Enter); + } + + #[test] + fn from_event_uppercase() { + // Crossterm sends uppercase chars with SHIFT modifier + let event = KeyEvent::new(KeyCode::Char('G'), KeyModifiers::SHIFT); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('G')); + // shift flag should be cleared since the case encodes it + assert!(!k.shift); + } + + #[test] + fn from_event_matches_parsed() { + // Verify that from_event and parse produce the same SingleKey + let event = KeyEvent::new(KeyCode::Char('c'), KeyModifiers::CONTROL); + let from_event = SingleKey::from_event(&event).unwrap(); + let parsed = SingleKey::parse("ctrl-c").unwrap(); + assert_eq!(from_event, parsed); + + let event = KeyEvent::new(KeyCode::Char('G'), KeyModifiers::SHIFT); + let from_event = SingleKey::from_event(&event).unwrap(); + let parsed = SingleKey::parse("G").unwrap(); + assert_eq!(from_event, parsed); + } + + #[test] + fn parse_super_modifier() { + let k = SingleKey::parse("super-a").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('a')); + assert!(k.super_key); + assert!(!k.ctrl && !k.alt && !k.shift); + + // "cmd" is an alias for "super" + let k2 = SingleKey::parse("cmd-a").unwrap(); + assert_eq!(k, k2); + + // "win" is an alias for "super" + let k3 = SingleKey::parse("win-a").unwrap(); + assert_eq!(k, k3); + } + + #[test] + fn parse_super_with_other_modifiers() { + let k = SingleKey::parse("super-ctrl-c").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('c')); + assert!(k.super_key && k.ctrl); + assert!(!k.alt && !k.shift); + } + + #[test] + fn display_super_modifier() { + let k = SingleKey::parse("super-a").unwrap(); + assert_eq!(k.to_string(), "super-a"); + + let k = SingleKey::parse("super-ctrl-x").unwrap(); + assert_eq!(k.to_string(), "super-ctrl-x"); + } + + #[test] + fn display_round_trip_super() { + let k = KeyInput::parse("super-a").unwrap(); + let display = k.to_string(); + let k2 = KeyInput::parse(&display).unwrap(); + assert_eq!(k, k2, "round-trip failed for super-a"); + } + + #[test] + fn from_event_super() { + let event = KeyEvent::new(KeyCode::Char('a'), KeyModifiers::SUPER); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('a')); + assert!(k.super_key); + assert!(!k.ctrl && !k.alt && !k.shift); + } + + #[test] + fn from_event_super_matches_parsed() { + let event = KeyEvent::new(KeyCode::Char('a'), KeyModifiers::SUPER); + let from_event = SingleKey::from_event(&event).unwrap(); + let parsed = SingleKey::parse("super-a").unwrap(); + assert_eq!(from_event, parsed); + } + + #[test] + fn super_uppercase_preserves_super() { + // super-G should keep the super flag (unlike bare "G" which clears shift) + let k = SingleKey::parse("super-G").unwrap(); + assert_eq!(k.code, KeyCodeValue::Char('G')); + assert!(k.super_key); + } + + #[test] + fn parse_errors() { + assert!(SingleKey::parse("ctrl-alt-shift-xxx").is_err()); + assert!(SingleKey::parse("foobar-a").is_err()); + } + + #[test] + fn parse_function_keys() { + let k = SingleKey::parse("f1").unwrap(); + assert_eq!(k.code, KeyCodeValue::F(1)); + assert!(!k.ctrl && !k.alt && !k.shift); + + let k = SingleKey::parse("F12").unwrap(); + assert_eq!(k.code, KeyCodeValue::F(12)); + + let k = SingleKey::parse("ctrl-f5").unwrap(); + assert_eq!(k.code, KeyCodeValue::F(5)); + assert!(k.ctrl); + + // F24 is valid (some keyboards have extended function keys) + let k = SingleKey::parse("f24").unwrap(); + assert_eq!(k.code, KeyCodeValue::F(24)); + + // F0 and F25+ are invalid + assert!(SingleKey::parse("f0").is_err()); + assert!(SingleKey::parse("f25").is_err()); + } + + #[test] + fn from_event_function_keys() { + let event = KeyEvent::new(KeyCode::F(1), KeyModifiers::NONE); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::F(1)); + + let event = KeyEvent::new(KeyCode::F(12), KeyModifiers::CONTROL); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::F(12)); + assert!(k.ctrl); + } + + #[test] + fn display_function_keys() { + let k = SingleKey::parse("f1").unwrap(); + assert_eq!(k.to_string(), "f1"); + + let k = SingleKey::parse("ctrl-f12").unwrap(); + assert_eq!(k.to_string(), "ctrl-f12"); + } + + #[test] + fn function_key_round_trip() { + let cases = ["f1", "f12", "ctrl-f5", "alt-f10"]; + for s in cases { + let k = KeyInput::parse(s).unwrap(); + let display = k.to_string(); + let k2 = KeyInput::parse(&display).unwrap(); + assert_eq!(k, k2, "round-trip failed for {s}"); + } + } + + #[test] + fn from_event_function_key_matches_parsed() { + let event = KeyEvent::new(KeyCode::F(12), KeyModifiers::NONE); + let from_event = SingleKey::from_event(&event).unwrap(); + let parsed = SingleKey::parse("f12").unwrap(); + assert_eq!(from_event, parsed); + } + + #[test] + fn from_event_backtab_becomes_shift_tab() { + // Many terminals send BackTab for Shift+Tab + let event = KeyEvent::new(KeyCode::BackTab, KeyModifiers::NONE); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::Tab); + assert!(k.shift); + assert!(!k.ctrl && !k.alt); + } + + #[test] + fn from_event_backtab_matches_parsed_shift_tab() { + let event = KeyEvent::new(KeyCode::BackTab, KeyModifiers::NONE); + let from_event = SingleKey::from_event(&event).unwrap(); + let parsed = SingleKey::parse("shift-tab").unwrap(); + assert_eq!(from_event, parsed); + } + + #[test] + fn from_event_backtab_with_ctrl() { + // BackTab with ctrl modifier + let event = KeyEvent::new(KeyCode::BackTab, KeyModifiers::CONTROL); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::Tab); + assert!(k.shift); + assert!(k.ctrl); + } + + #[test] + fn parse_insert_key() { + let k = SingleKey::parse("insert").unwrap(); + assert_eq!(k.code, KeyCodeValue::Insert); + assert!(!k.ctrl && !k.alt && !k.shift); + + let k = SingleKey::parse("ins").unwrap(); + assert_eq!(k.code, KeyCodeValue::Insert); + + let k = SingleKey::parse("ctrl-insert").unwrap(); + assert_eq!(k.code, KeyCodeValue::Insert); + assert!(k.ctrl); + } + + #[test] + fn from_event_insert_key() { + let event = KeyEvent::new(KeyCode::Insert, KeyModifiers::NONE); + let k = SingleKey::from_event(&event).unwrap(); + assert_eq!(k.code, KeyCodeValue::Insert); + } + + #[test] + fn insert_key_round_trip() { + let k = KeyInput::parse("insert").unwrap(); + let display = k.to_string(); + assert_eq!(display, "insert"); + let k2 = KeyInput::parse(&display).unwrap(); + assert_eq!(k, k2); + } +} diff --git a/crates/turtle/src/command/client/search/keybindings/keymap.rs b/crates/turtle/src/command/client/search/keybindings/keymap.rs new file mode 100644 index 00000000..0d362863 --- /dev/null +++ b/crates/turtle/src/command/client/search/keybindings/keymap.rs @@ -0,0 +1,233 @@ +use std::collections::HashMap; + +use super::actions::Action; +use super::conditions::{ConditionExpr, EvalContext}; +use super::key::{KeyInput, SingleKey}; + +/// A single rule within a keybinding: an optional condition and an action. +/// If the condition is `None`, the rule always matches. +#[derive(Debug, Clone)] +pub struct KeyRule { + pub condition: Option, + pub action: Action, +} + +/// A keybinding is an ordered list of rules. The first rule whose condition +/// matches (or has no condition) wins. +#[derive(Debug, Clone)] +pub struct KeyBinding { + pub rules: Vec, +} + +/// A keymap is a collection of keybindings indexed by key input. +#[derive(Debug, Clone)] +pub struct Keymap { + pub bindings: HashMap, +} + +impl KeyRule { + /// Create an unconditional rule. + pub fn always(action: Action) -> Self { + KeyRule { + condition: None, + action, + } + } + + /// Create a conditional rule. Accepts any type convertible to `ConditionExpr`, + /// including bare `ConditionAtom` values. + pub fn when(condition: impl Into, action: Action) -> Self { + KeyRule { + condition: Some(condition.into()), + action, + } + } +} + +impl KeyBinding { + /// Create a simple (unconditional) binding. + pub fn simple(action: Action) -> Self { + KeyBinding { + rules: vec![KeyRule::always(action)], + } + } + + /// Create a conditional binding from a list of rules. + pub fn conditional(rules: Vec) -> Self { + KeyBinding { rules } + } +} + +impl Keymap { + /// Create an empty keymap. + pub fn new() -> Self { + Keymap { + bindings: HashMap::new(), + } + } + + /// Bind a key input to a simple (unconditional) action. + pub fn bind(&mut self, key: KeyInput, action: Action) { + self.bindings.insert(key, KeyBinding::simple(action)); + } + + /// Bind a key input to a conditional set of rules. + pub fn bind_conditional(&mut self, key: KeyInput, rules: Vec) { + self.bindings.insert(key, KeyBinding::conditional(rules)); + } + + /// Resolve a key input to an action given the current evaluation context. + /// Returns `None` if the key has no binding or no rule's condition matches. + pub fn resolve(&self, key: &KeyInput, ctx: &EvalContext) -> Option { + let binding = self.bindings.get(key)?; + for rule in &binding.rules { + match &rule.condition { + None => return Some(rule.action.clone()), + Some(cond) if cond.evaluate(ctx) => return Some(rule.action.clone()), + Some(_) => {} + } + } + None + } + + /// Check if any binding starts with the given single key as the first key + /// of a multi-key sequence. Used to detect pending multi-key sequences. + pub fn has_sequence_starting_with(&self, prefix: &SingleKey) -> bool { + self.bindings.keys().any(|ki| match ki { + KeyInput::Sequence(keys) => keys.first() == Some(prefix), + KeyInput::Single(_) => false, + }) + } + + /// Merge another keymap into this one. Keys from `other` override keys in `self`. + #[expect(dead_code)] + pub fn merge(&mut self, other: &Keymap) { + for (key, binding) in &other.bindings { + self.bindings.insert(key.clone(), binding.clone()); + } + } +} + +impl Default for Keymap { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::super::conditions::ConditionAtom; + use super::*; + + fn make_ctx(cursor: usize, width: usize, selected: usize, len: usize) -> EvalContext { + EvalContext { + cursor_position: cursor, + input_width: width, + input_byte_len: width, + selected_index: selected, + results_len: len, + original_input_empty: false, + has_context: false, + } + } + + #[test] + fn simple_binding_resolves() { + let mut keymap = Keymap::new(); + let key = KeyInput::parse("ctrl-c").unwrap(); + keymap.bind(key.clone(), Action::ReturnOriginal); + + let ctx = make_ctx(0, 0, 0, 10); + assert_eq!(keymap.resolve(&key, &ctx), Some(Action::ReturnOriginal)); + } + + #[test] + fn conditional_first_match_wins() { + let mut keymap = Keymap::new(); + let key = KeyInput::parse("left").unwrap(); + keymap.bind_conditional( + key.clone(), + vec![ + KeyRule::when(ConditionAtom::CursorAtStart, Action::Exit), + KeyRule::always(Action::CursorLeft), + ], + ); + + // Cursor at start → Exit + let ctx = make_ctx(0, 5, 0, 10); + assert_eq!(keymap.resolve(&key, &ctx), Some(Action::Exit)); + + // Cursor not at start → CursorLeft + let ctx = make_ctx(3, 5, 0, 10); + assert_eq!(keymap.resolve(&key, &ctx), Some(Action::CursorLeft)); + } + + #[test] + fn no_match_returns_none() { + let keymap = Keymap::new(); + let key = KeyInput::parse("ctrl-c").unwrap(); + let ctx = make_ctx(0, 0, 0, 0); + assert_eq!(keymap.resolve(&key, &ctx), None); + } + + #[test] + fn conditional_no_condition_matches_returns_none() { + let mut keymap = Keymap::new(); + let key = KeyInput::parse("left").unwrap(); + // Only one rule with a condition that won't match + keymap.bind_conditional( + key.clone(), + vec![KeyRule::when(ConditionAtom::CursorAtStart, Action::Exit)], + ); + + // Cursor not at start → no match + let ctx = make_ctx(3, 5, 0, 10); + assert_eq!(keymap.resolve(&key, &ctx), None); + } + + #[test] + fn has_sequence_starting_with() { + let mut keymap = Keymap::new(); + let seq = KeyInput::parse("g g").unwrap(); + keymap.bind(seq, Action::ScrollToTop); + + let g = SingleKey::parse("g").unwrap(); + assert!(keymap.has_sequence_starting_with(&g)); + + let h = SingleKey::parse("h").unwrap(); + assert!(!keymap.has_sequence_starting_with(&h)); + } + + #[test] + fn merge_overrides() { + let mut base = Keymap::new(); + let key = KeyInput::parse("ctrl-c").unwrap(); + base.bind(key.clone(), Action::ReturnOriginal); + + let mut overlay = Keymap::new(); + overlay.bind(key.clone(), Action::Exit); + + base.merge(&overlay); + + let ctx = make_ctx(0, 0, 0, 0); + assert_eq!(base.resolve(&key, &ctx), Some(Action::Exit)); + } + + #[test] + fn merge_preserves_unoverridden() { + let mut base = Keymap::new(); + let key1 = KeyInput::parse("ctrl-c").unwrap(); + let key2 = KeyInput::parse("ctrl-d").unwrap(); + base.bind(key1.clone(), Action::ReturnOriginal); + base.bind(key2.clone(), Action::DeleteCharAfter); + + let mut overlay = Keymap::new(); + overlay.bind(key1.clone(), Action::Exit); + + base.merge(&overlay); + + let ctx = make_ctx(0, 0, 0, 0); + assert_eq!(base.resolve(&key1, &ctx), Some(Action::Exit)); + assert_eq!(base.resolve(&key2, &ctx), Some(Action::DeleteCharAfter)); + } +} diff --git a/crates/turtle/src/command/client/search/keybindings/mod.rs b/crates/turtle/src/command/client/search/keybindings/mod.rs new file mode 100644 index 00000000..3b6eb2b2 --- /dev/null +++ b/crates/turtle/src/command/client/search/keybindings/mod.rs @@ -0,0 +1,14 @@ +pub mod actions; +pub mod conditions; +pub mod defaults; +pub mod key; +pub mod keymap; + +pub use actions::Action; +#[expect(unused_imports)] +pub use conditions::{ConditionAtom, ConditionExpr, EvalContext}; +pub use defaults::KeymapSet; +#[expect(unused_imports)] +pub use key::{KeyCodeValue, KeyInput, SingleKey}; +#[expect(unused_imports)] +pub use keymap::{KeyBinding, KeyRule, Keymap}; diff --git a/crates/turtle/src/command/client/server.rs b/crates/turtle/src/command/client/server.rs new file mode 100644 index 00000000..7de27551 --- /dev/null +++ b/crates/turtle/src/command/client/server.rs @@ -0,0 +1,61 @@ +use std::net::SocketAddr; + +use crate::atuin_server::{Settings, launch, launch_metrics_server}; +use crate::atuin_server_database::DbType; +use crate::atuin_server_postgres::Postgres; +use crate::atuin_server_sqlite::Sqlite; + +use clap::Subcommand; +use eyre::{Context, Result, eyre}; + +#[derive(Subcommand, Clone, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Start the server + Start { + /// The host address to bind + #[clap(long)] + host: Option, + + /// The port to bind + #[clap(long, short)] + port: Option, + }, + + /// Print server example configuration + DefaultConfig, +} + +impl Cmd { + #[expect(clippy::too_many_lines)] + pub async fn run(self) -> Result<()> { + match self { + Cmd::Start { host, port } => { + let settings = Settings::new().wrap_err("could not load server settings")?; + let host = host.as_ref().unwrap_or(&settings.host).clone(); + let port = port.unwrap_or(settings.port); + let addr = SocketAddr::new(host.parse()?, port); + + if settings.metrics.enable { + tokio::spawn(launch_metrics_server( + settings.metrics.host.clone(), + settings.metrics.port, + )); + } + + match settings.db_settings.db_type() { + DbType::Postgres => launch::(settings, addr).await, + DbType::Sqlite => launch::(settings, addr).await, + DbType::Unknown => { + Err(eyre!("db_uri must start with postgres:// or sqlite://")) + } + } + } + Cmd::DefaultConfig => { + // TODO(@bpeetz): Add this back <2026-06-11> + println!("TODO"); + Ok(()) + } + } + } +} diff --git a/crates/turtle/src/command/client/setup.rs b/crates/turtle/src/command/client/setup.rs new file mode 100644 index 00000000..b32ceb97 --- /dev/null +++ b/crates/turtle/src/command/client/setup.rs @@ -0,0 +1,81 @@ +use crate::atuin_client::settings::Settings; + +use colored::Colorize; +use eyre::Result; +use std::io::{self, Write}; +use toml_edit::{DocumentMut, value}; + +pub async fn run(_settings: &Settings) -> Result<()> { + let enable_ai = prompt( + "Atuin AI", + "This will enable command generation and other AI features via the question mark key", + Some( + "By default, Atuin AI only has access to the name and version of your operating system and shell - your shell history is not sent to the AI.", + ), + )?; + + let enable_daemon = prompt( + "Atuin Daemon", + "This will enable improved search and history sync using a persistent background process", + None, + )?; + + let config_file = Settings::get_config_path()?; + let config_str = tokio::fs::read_to_string(&config_file).await?; + let mut doc = config_str.parse::()?; + + let mut changed = false; + if enable_ai { + changed = true; + if !doc.contains_key("ai") { + doc["ai"] = toml_edit::table(); + } + doc["ai"]["enabled"] = value(true); + } + + if enable_daemon { + changed = true; + if !doc.contains_key("daemon") { + doc["daemon"] = toml_edit::table(); + } + doc["daemon"]["enabled"] = value(true); + doc["daemon"]["autostart"] = value(true); + doc["search_mode"] = value("daemon-fuzzy"); + } + + if changed { + tokio::fs::write(config_file, doc.to_string()).await?; + + println!( + "{check} Settings updated successfully", + check = "✓".bold().bright_green() + ); + } else { + println!( + "{check} No settings changed", + check = "✓".bold().bright_green() + ); + } + + Ok(()) +} + +pub fn prompt(feature: &str, description: &str, note: Option<&str>) -> Result { + println!( + "> Enable {feature}?", + feature = feature.bold().bright_blue() + ); + if let Some(note) = note { + println!(" {description}"); + print!(" {note} {q} ", q = "[Y/n]".bold()); + } else { + print!(" {description} {q} ", q = "[Y/n]".bold()); + } + + io::stdout().flush().ok(); + + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + let answer = input.trim().to_lowercase(); + Ok(answer.is_empty() || answer == "y" || answer == "yes") +} diff --git a/crates/turtle/src/command/client/stats.rs b/crates/turtle/src/command/client/stats.rs new file mode 100644 index 00000000..fc10e949 --- /dev/null +++ b/crates/turtle/src/command/client/stats.rs @@ -0,0 +1,85 @@ +use clap::Parser; +use eyre::Result; +use interim::parse_date_string; +use time::{Duration, OffsetDateTime, Time}; + +use crate::atuin_client::{ + database::{Database, current_context}, + settings::Settings, + theme::Theme, +}; + +use crate::atuin_history::stats::{compute, pretty_print}; + +fn parse_ngram_size(s: &str) -> Result { + let value = s + .parse::() + .map_err(|_| format!("'{s}' is not a valid window size"))?; + + if value == 0 { + return Err("ngram window size must be at least 1".to_string()); + } + + Ok(value) +} + +#[derive(Parser, Debug)] +#[command(infer_subcommands = true)] +pub struct Cmd { + /// Compute statistics for the specified period, leave blank for statistics since the beginning. See [this](https://docs.atuin.sh/reference/stats/) for more details. + period: Vec, + + /// How many top commands to list + #[arg(long, short, default_value = "10")] + count: usize, + + /// The number of consecutive commands to consider + #[arg(long, short, default_value = "1", value_parser = parse_ngram_size)] + ngram_size: usize, +} + +impl Cmd { + pub async fn run(&self, db: &impl Database, settings: &Settings, theme: &Theme) -> Result<()> { + let context = current_context().await?; + let words = if self.period.is_empty() { + String::from("all") + } else { + self.period.join(" ") + }; + + let now = OffsetDateTime::now_utc().to_offset(settings.timezone.0); + let last_night = now.replace_time(Time::MIDNIGHT); + + let history = if words.as_str() == "all" { + db.list(&[], &context, None, false, false).await? + } else if words.trim() == "today" { + let start = last_night; + let end = start + Duration::days(1); + db.range(start, end).await? + } else if words.trim() == "month" { + let end = last_night; + let start = end - Duration::days(31); + db.range(start, end).await? + } else if words.trim() == "week" { + let end = last_night; + let start = end - Duration::days(7); + db.range(start, end).await? + } else if words.trim() == "year" { + let end = last_night; + let start = end - Duration::days(365); + db.range(start, end).await? + } else { + let start = parse_date_string(&words, now, settings.dialect.into())?; + let end = start + Duration::days(1); + db.range(start, end).await? + }; + + let stats = compute(settings, &history, self.count, self.ngram_size); + + if let Some(stats) = stats { + pretty_print(stats, self.ngram_size, theme); + } + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/store.rs b/crates/turtle/src/command/client/store.rs new file mode 100644 index 00000000..dfa3b66c --- /dev/null +++ b/crates/turtle/src/command/client/store.rs @@ -0,0 +1,120 @@ +use clap::Subcommand; +use eyre::Result; + +use crate::atuin_client::{ + database::Database, + record::{sqlite_store::SqliteStore, store::Store}, + settings::Settings, +}; +use itertools::Itertools; +use time::{OffsetDateTime, UtcOffset}; + +#[cfg(feature = "sync")] +mod push; + +#[cfg(feature = "sync")] +mod pull; + +mod purge; +mod rebuild; +mod rekey; +mod verify; + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Print the current status of the record store + Status, + + /// Rebuild a store (eg atuin store rebuild history) + Rebuild(rebuild::Rebuild), + + /// Re-encrypt the store with a new key (potential for data loss!) + Rekey(rekey::Rekey), + + /// Delete all records in the store that cannot be decrypted with the current key + Purge(purge::Purge), + + /// Verify that all records in the store can be decrypted with the current key + Verify(verify::Verify), + + /// Push all records to the remote sync server (one way sync) + #[cfg(feature = "sync")] + Push(push::Push), + + /// Pull records from the remote sync server (one way sync) + #[cfg(feature = "sync")] + Pull(pull::Pull), +} + +impl Cmd { + pub async fn run( + &self, + settings: &Settings, + database: &dyn Database, + store: SqliteStore, + ) -> Result<()> { + match self { + Self::Status => self.status(store).await, + Self::Rebuild(rebuild) => rebuild.run(settings, store, database).await, + Self::Rekey(rekey) => rekey.run(settings, store).await, + Self::Verify(verify) => verify.run(settings, store).await, + Self::Purge(purge) => purge.run(settings, store).await, + + #[cfg(feature = "sync")] + Self::Push(push) => push.run(settings, store).await, + + #[cfg(feature = "sync")] + Self::Pull(pull) => pull.run(settings, store, database).await, + } + } + + pub async fn status(&self, store: SqliteStore) -> Result<()> { + let host_id = Settings::host_id().await?; + let offset = UtcOffset::current_local_offset().unwrap_or(UtcOffset::UTC); + + let status = store.status().await?; + + // TODO: should probs build some data structure and then pretty-print it or smth + for (host, st) in status.hosts.iter().sorted_by_key(|(h, _)| *h) { + let host_string = if host == &host_id { + format!("host: {} <- CURRENT HOST", host.0.as_hyphenated()) + } else { + format!("host: {}", host.0.as_hyphenated()) + }; + + println!("{host_string}"); + + for (tag, idx) in st.iter().sorted_by_key(|(tag, _)| *tag) { + println!("\tstore: {tag}"); + + let first = store.first(*host, tag).await?; + let last = store.last(*host, tag).await?; + + println!("\t\tidx: {idx}"); + + if let Some(first) = first { + println!("\t\tfirst: {}", first.id.0.as_hyphenated()); + + let time = + OffsetDateTime::from_unix_timestamp_nanos(i128::from(first.timestamp))? + .to_offset(offset); + println!("\t\t\tcreated: {time}"); + } + + if let Some(last) = last { + println!("\t\tlast: {}", last.id.0.as_hyphenated()); + + let time = + OffsetDateTime::from_unix_timestamp_nanos(i128::from(last.timestamp))? + .to_offset(offset); + println!("\t\t\tcreated: {time}"); + } + } + + println!(); + } + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/store/pull.rs b/crates/turtle/src/command/client/store/pull.rs new file mode 100644 index 00000000..c9c9c379 --- /dev/null +++ b/crates/turtle/src/command/client/store/pull.rs @@ -0,0 +1,94 @@ +use clap::Args; +use eyre::Result; + +use crate::atuin_client::{ + database::Database, + encryption::load_key, + record::store::Store, + record::sync::Operation, + record::{sqlite_store::SqliteStore, sync}, + settings::Settings, +}; + +#[derive(Args, Debug)] +pub struct Pull { + /// The tag to push (eg, 'history'). Defaults to all tags + #[arg(long, short)] + pub tag: Option, + + /// Force push records + /// This will first wipe the local store, and then download all records from the remote + #[arg(long, default_value = "false")] + pub force: bool, + + /// Page Size + /// How many records to download at once. Defaults to 100 + #[arg(long, default_value = "100")] + pub page: u64, +} + +impl Pull { + pub async fn run( + &self, + settings: &Settings, + store: SqliteStore, + db: &dyn Database, + ) -> Result<()> { + if self.force { + println!("Forcing local overwrite!"); + println!("Clearing local store"); + + store.delete_all().await?; + } + + // We can actually just use the existing diff/etc to push + // 1. Diff + // 2. Get operations + // 3. Filter operations by + // a) are they a download op? + // b) are they for the host/tag we are pushing here? + let client = sync::build_client(settings).await?; + let (diff, remote_index) = sync::diff(&client, &store).await?; + + // Skip on --force: local was already wiped above, mismatch is the user's call. + if !self.force { + let key: [u8; 32] = load_key(settings)?.into(); + sync::check_encryption_key(&client, &remote_index, &key) + .await + .map_err(crate::print_error::format_sync_error)?; + } + + let operations = sync::operations(diff, &store).await?; + + let operations = operations + .into_iter() + .filter(|op| match op { + // No noops or downloads thx + Operation::Noop { .. } | Operation::Upload { .. } => false, + + // pull, so yes plz to downloads! + Operation::Download { tag, .. } => { + if self.force { + return true; + } + + if let Some(t) = self.tag.clone() + && t != *tag + { + return false; + } + + true + } + }) + .collect(); + + let (_, downloaded) = sync::sync_remote(&client, operations, &store, self.page).await?; + + println!("Downloaded {} records", downloaded.len()); + + crate::sync::build(settings, &store, db, Some(&downloaded)).await?; + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/store/purge.rs b/crates/turtle/src/command/client/store/purge.rs new file mode 100644 index 00000000..f7996c4b --- /dev/null +++ b/crates/turtle/src/command/client/store/purge.rs @@ -0,0 +1,26 @@ +use clap::Args; +use eyre::Result; + +use crate::atuin_client::{ + encryption::load_key, + record::{sqlite_store::SqliteStore, store::Store}, + settings::Settings, +}; + +#[derive(Args, Debug)] +pub struct Purge {} + +impl Purge { + pub async fn run(&self, settings: &Settings, store: SqliteStore) -> Result<()> { + println!("Purging local records that cannot be decrypted"); + + let key = load_key(settings)?; + + match store.purge(&key.into()).await { + Ok(()) => println!("Local store purge completed OK"), + Err(e) => println!("Failed to purge local store: {e:?}"), + } + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/store/push.rs b/crates/turtle/src/command/client/store/push.rs new file mode 100644 index 00000000..724dfbef --- /dev/null +++ b/crates/turtle/src/command/client/store/push.rs @@ -0,0 +1,112 @@ +use crate::atuin_common::record::HostId; +use clap::Args; +use eyre::Result; +use uuid::Uuid; + +use crate::atuin_client::{ + api_client::Client, + encryption::load_key, + record::sync::Operation, + record::{sqlite_store::SqliteStore, sync}, + settings::Settings, +}; + +#[derive(Args, Debug)] +pub struct Push { + /// The tag to push (eg, 'history'). Defaults to all tags + #[arg(long, short)] + pub tag: Option, + + /// The host to push, in the form of a UUID host ID. Defaults to the current host. + #[arg(long)] + pub host: Option, + + /// Force push records + /// This will override both host and tag, to be all hosts and all tags. First clear the remote store, then upload all of the + /// local store + #[arg(long, default_value = "false")] + pub force: bool, + + /// Page Size + /// How many records to upload at once. Defaults to 100 + #[arg(long, default_value = "100")] + pub page: u64, +} + +impl Push { + pub async fn run(&self, settings: &Settings, store: SqliteStore) -> Result<()> { + let host_id = Settings::host_id().await?; + + if self.force { + println!("Forcing remote store overwrite!"); + println!("Clearing remote store"); + + let client = Client::new( + &settings.sync_address, + settings.sync_auth_token().await?, + settings.network_connect_timeout, + settings.network_timeout * 10, // we may be deleting a lot of data... so up the + // timeout + ) + .expect("failed to create client"); + + client.delete_store().await?; + } + + // We can actually just use the existing diff/etc to push + // 1. Diff + // 2. Get operations + // 3. Filter operations by + // a) are they an upload op? + // b) are they for the host/tag we are pushing here? + let client = sync::build_client(settings).await?; + let (diff, remote_index) = sync::diff(&client, &store).await?; + + // Skip on --force: that path intentionally replaces remote with local. + if !self.force { + let key: [u8; 32] = load_key(settings)?.into(); + sync::check_encryption_key(&client, &remote_index, &key) + .await + .map_err(crate::print_error::format_sync_error)?; + } + + let operations = sync::operations(diff, &store).await?; + + let operations = operations + .into_iter() + .filter(|op| match op { + // No noops or downloads thx + Operation::Noop { .. } | Operation::Download { .. } => false, + + // push, so yes plz to uploads! + Operation::Upload { host, tag, .. } => { + if self.force { + return true; + } + + if let Some(h) = self.host { + if HostId(h) != *host { + return false; + } + } else if *host != host_id { + return false; + } + + if let Some(t) = self.tag.clone() + && t != *tag + { + return false; + } + + true + } + }) + .collect(); + + let (uploaded, _) = sync::sync_remote(&client, operations, &store, self.page).await?; + + println!("Uploaded {uploaded} records"); + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/store/rebuild.rs b/crates/turtle/src/command/client/store/rebuild.rs new file mode 100644 index 00000000..80e201c2 --- /dev/null +++ b/crates/turtle/src/command/client/store/rebuild.rs @@ -0,0 +1,58 @@ +use clap::Args; +use eyre::{Result, bail}; + +#[cfg(feature = "daemon")] +use crate::command::client::daemon as daemon_cmd; + +use crate::atuin_client::{ + database::Database, encryption, history::store::HistoryStore, + record::sqlite_store::SqliteStore, settings::Settings, +}; + +#[derive(Args, Debug)] +pub struct Rebuild { + pub tag: String, +} + +impl Rebuild { + pub async fn run( + &self, + settings: &Settings, + store: SqliteStore, + database: &dyn Database, + ) -> Result<()> { + // keep it as a string and not an enum atm + // would be super cool to build this dynamically in the future + // eg register handles for rebuilding various tags without having to make this part of the + // binary big + match self.tag.as_str() { + "history" => { + self.rebuild_history(settings, store.clone(), database) + .await?; + } + + tag => bail!("unknown tag: {tag}"), + } + + Ok(()) + } + + async fn rebuild_history( + &self, + settings: &Settings, + store: SqliteStore, + database: &dyn Database, + ) -> Result<()> { + let encryption_key: [u8; 32] = encryption::load_key(settings)?.into(); + + let host_id = Settings::host_id().await?; + let history_store = HistoryStore::new(store, host_id, encryption_key); + + history_store.build(database).await?; + + #[cfg(feature = "daemon")] + daemon_cmd::emit_event(settings, crate::atuin_daemon::DaemonEvent::HistoryRebuilt).await; + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/store/rekey.rs b/crates/turtle/src/command/client/store/rekey.rs new file mode 100644 index 00000000..e63be447 --- /dev/null +++ b/crates/turtle/src/command/client/store/rekey.rs @@ -0,0 +1,41 @@ +use clap::Args; +use eyre::Result; +use tokio::{fs::File, io::AsyncWriteExt}; + +use crate::atuin_client::{ + encryption::{decode_key, generate_encoded_key, load_key}, + record::sqlite_store::SqliteStore, + record::store::Store, + settings::Settings, +}; + +#[derive(Args, Debug)] +pub struct Rekey { + /// The new key to use for encryption. Omit for a randomly-generated key + key: Option, +} + +impl Rekey { + pub async fn run(&self, settings: &Settings, store: SqliteStore) -> Result<()> { + let key = if let Some(key) = self.key.clone() { + println!("Re-encrypting store with specified key"); + + key + } else { + println!("Re-encrypting store with freshly-generated key"); + let (_, encoded) = generate_encoded_key()?; + encoded + }; + + let current_key: [u8; 32] = load_key(settings)?.into(); + let new_key: [u8; 32] = decode_key(key.clone())?.into(); + + store.re_encrypt(¤t_key, &new_key).await?; + + println!("Store rewritten. Saving new key"); + let mut file = File::create(settings.key_path.clone()).await?; + file.write_all(key.as_bytes()).await?; + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/store/verify.rs b/crates/turtle/src/command/client/store/verify.rs new file mode 100644 index 00000000..5aa1dc70 --- /dev/null +++ b/crates/turtle/src/command/client/store/verify.rs @@ -0,0 +1,26 @@ +use clap::Args; +use eyre::Result; + +use crate::atuin_client::{ + encryption::load_key, + record::{sqlite_store::SqliteStore, store::Store}, + settings::Settings, +}; + +#[derive(Args, Debug)] +pub struct Verify {} + +impl Verify { + pub async fn run(&self, settings: &Settings, store: SqliteStore) -> Result<()> { + println!("Verifying local store can be decrypted with the current key"); + + let key = load_key(settings)?; + + match store.verify(&key.into()).await { + Ok(()) => println!("Local store encryption verified OK"), + Err(e) => println!("Failed to verify local store encryption: {e:?}"), + } + + Ok(()) + } +} diff --git a/crates/turtle/src/command/client/sync.rs b/crates/turtle/src/command/client/sync.rs new file mode 100644 index 00000000..a4839b5f --- /dev/null +++ b/crates/turtle/src/command/client/sync.rs @@ -0,0 +1,120 @@ +use clap::Subcommand; +use eyre::{Result, WrapErr}; + +use crate::atuin_client::{ + database::Database, + encryption, + history::store::HistoryStore, + record::{sqlite_store::SqliteStore, store::Store, sync}, + settings::Settings, +}; + +mod status; + +use crate::command::client::account; + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Sync with the configured server + Sync { + /// Force re-download everything + #[arg(long, short)] + force: bool, + }, + + /// Login to the configured server + Login(account::login::Cmd), + + /// Log out + Logout, + + /// Register with the configured server + Register(account::register::Cmd), + + /// Print the encryption key for transfer to another machine + Key {}, + + /// Display the sync status + Status, +} + +impl Cmd { + pub async fn run( + self, + settings: Settings, + db: &impl Database, + store: SqliteStore, + ) -> Result<()> { + match self { + Self::Sync { force } => run(&settings, force, db, store).await, + Self::Login(l) => l.run(&settings, &store).await, + Self::Logout => account::logout::run().await, + Self::Register(r) => r.run(&settings).await, + Self::Status => status::run(&settings).await, + Self::Key {} => { + use crate::atuin_client::encryption::{encode_key, load_key}; + let key = load_key(&settings).wrap_err("could not load encryption key")?; + + let encode = encode_key(&key).wrap_err("could not encode encryption key")?; + println!("{encode}"); + + Ok(()) + } + } + } +} + +async fn run( + settings: &Settings, + force: bool, + db: &impl Database, + store: SqliteStore, +) -> Result<()> { + let encryption_key: [u8; 32] = encryption::load_key(settings) + .context("could not load encryption key")? + .into(); + + let host_id = Settings::host_id().await?; + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + + let (uploaded, downloaded) = sync::sync(settings, &store, &encryption_key) + .await + .map_err(crate::print_error::format_sync_error)?; + + crate::sync::build(settings, &store, db, Some(&downloaded)).await?; + + println!("{uploaded}/{} up/down to record store", downloaded.len()); + + let history_length = db.history_count(true).await?; + let store_history_length = store.len_tag("history").await?; + + #[expect(clippy::cast_sign_loss)] + if history_length as u64 > store_history_length { + println!("{history_length} in history index, but {store_history_length} in history store"); + println!("Running automatic history store init..."); + + // Internally we use the global filter mode, so this context is ignored. + // don't recurse or loop here. + history_store.init_store(db).await?; + + println!("Re-running sync due to new records locally"); + + // we'll want to run sync once more, as there will now be stuff to upload + let (uploaded, downloaded) = sync::sync(settings, &store, &encryption_key) + .await + .map_err(crate::print_error::format_sync_error)?; + + crate::sync::build(settings, &store, db, Some(&downloaded)).await?; + + println!("{uploaded}/{} up/down to record store", downloaded.len()); + } + + println!( + "Sync complete! {} items in history database, force: {}", + db.history_count(true).await?, + force + ); + + Ok(()) +} diff --git a/crates/turtle/src/command/client/sync/status.rs b/crates/turtle/src/command/client/sync/status.rs new file mode 100644 index 00000000..00088b59 --- /dev/null +++ b/crates/turtle/src/command/client/sync/status.rs @@ -0,0 +1,37 @@ +use crate::{SHA, VERSION}; +use crate::atuin_client::{api_client, settings::Settings}; +use colored::Colorize; +use eyre::{Result, bail}; + +pub async fn run(settings: &Settings) -> Result<()> { + if !settings.logged_in().await? { + bail!("You are not logged in to a sync server - cannot show sync status"); + } + + let client = api_client::Client::new( + &settings.sync_address, + settings.sync_auth_token().await?, + settings.network_connect_timeout, + settings.network_timeout, + )?; + + let me = client.me().await?; + let last_sync = Settings::last_sync().await?; + + println!("Atuin v{VERSION} - Build rev {SHA}\n"); + + println!("{}", "[Local]".green()); + + if settings.auto_sync { + println!("Sync frequency: {}", settings.sync_frequency); + println!("Last sync: {}", last_sync.to_offset(settings.timezone.0)); + } + + if settings.auto_sync { + println!("{}", "[Remote]".green()); + println!("Address: {}", settings.sync_address); + println!("Username: {}", me.username); + } + + Ok(()) +} diff --git a/crates/turtle/src/command/client/wrapped.rs b/crates/turtle/src/command/client/wrapped.rs new file mode 100644 index 00000000..694157c2 --- /dev/null +++ b/crates/turtle/src/command/client/wrapped.rs @@ -0,0 +1,326 @@ +use crossterm::style::{ResetColor, SetAttribute}; +use eyre::Result; +use std::collections::{HashMap, HashSet}; +use time::{Date, Duration, Month, OffsetDateTime, Time}; + +use crate::atuin_client::{database::Database, settings::Settings, theme::Theme}; + +use crate::atuin_history::stats::{Stats, compute}; + +#[derive(Debug)] +struct WrappedStats { + nav_commands: usize, + pkg_commands: usize, + error_rate: f64, + first_half_commands: Vec<(String, usize)>, + second_half_commands: Vec<(String, usize)>, + git_percentage: f64, + busiest_hour: Option<(String, usize)>, +} + +impl WrappedStats { + #[expect(clippy::too_many_lines, clippy::cast_precision_loss)] + fn new( + settings: &Settings, + stats: &Stats, + history: &[crate::atuin_client::history::History], + ) -> Self { + let nav_commands = stats + .top + .iter() + .filter(|(cmd, _)| { + let cmd = &cmd[0]; + cmd == "cd" || cmd == "ls" || cmd == "pwd" || cmd == "pushd" || cmd == "popd" + }) + .map(|(_, count)| count) + .sum(); + + let pkg_managers = [ + "cargo", + "npm", + "pnpm", + "yarn", + "pip", + "pip3", + "pipenv", + "poetry", + "pipx", + "uv", + "brew", + "apt", + "apt-get", + "apk", + "pacman", + "yay", + "paru", + "yum", + "dnf", + "dnf5", + "rpm", + "rpm-ostree", + "zypper", + "pkg", + "chocolatey", + "choco", + "scoop", + "winget", + "gem", + "bundle", + "shards", + "composer", + "gradle", + "maven", + "mvn", + "go get", + "nuget", + "dotnet", + "mix", + "hex", + "rebar3", + "nix", + "nix-env", + "cabal", + "opam", + ]; + + let pkg_commands = history + .iter() + .filter(|h| { + let cmd = h.command.clone(); + pkg_managers.iter().any(|pm| cmd.starts_with(pm)) + }) + .count(); + + // Error analysis + let mut command_errors: HashMap = HashMap::new(); // (total_uses, errors) + let midyear = history[0].timestamp + Duration::days(182); // Split year in half + + let mut first_half_commands: HashMap = HashMap::new(); + let mut second_half_commands: HashMap = HashMap::new(); + let mut hours: HashMap = HashMap::new(); + + for entry in history { + let cmd = entry + .command + .split_whitespace() + .next() + .unwrap_or("") + .to_string(); + let (total, errors) = command_errors.entry(cmd.clone()).or_insert((0, 0)); + *total += 1; + if entry.exit != 0 { + *errors += 1; + } + + // Track command evolution + if entry.timestamp < midyear { + *first_half_commands.entry(cmd.clone()).or_default() += 1; + } else { + *second_half_commands.entry(cmd).or_default() += 1; + } + + // Track hourly distribution + let local_time = entry + .timestamp + .to_offset(time::UtcOffset::current_local_offset().unwrap_or(settings.timezone.0)); + let hour = format!("{:02}:00", local_time.time().hour()); + *hours.entry(hour).or_default() += 1; + } + + let total_errors: usize = command_errors.values().map(|(_, errors)| errors).sum(); + let total_commands: usize = command_errors.values().map(|(total, _)| total).sum(); + let error_rate = total_errors as f64 / total_commands as f64; + + // Process command evolution data + let mut first_half: Vec<_> = first_half_commands.into_iter().collect(); + let mut second_half: Vec<_> = second_half_commands.into_iter().collect(); + first_half.sort_by_key(|(_, count)| std::cmp::Reverse(*count)); + second_half.sort_by_key(|(_, count)| std::cmp::Reverse(*count)); + first_half.truncate(5); + second_half.truncate(5); + + // Calculate git percentage + let git_commands: usize = stats + .top + .iter() + .filter(|(cmd, _)| cmd[0].starts_with("git")) + .map(|(_, count)| count) + .sum(); + let git_percentage = git_commands as f64 / stats.total_commands as f64; + + // Find busiest hour + let busiest_hour = hours.into_iter().max_by_key(|(_, count)| *count); + + Self { + nav_commands, + pkg_commands, + error_rate, + first_half_commands: first_half, + second_half_commands: second_half, + git_percentage, + busiest_hour, + } + } +} + +pub fn print_wrapped_header(year: i32) { + let reset = ResetColor; + let bold = SetAttribute(crossterm::style::Attribute::Bold); + + println!("{bold}╭────────────────────────────────────╮{reset}"); + println!("{bold}│ ATUIN WRAPPED {year} │{reset}"); + println!("{bold}│ Your Year in Shell History │{reset}"); + println!("{bold}╰────────────────────────────────────╯{reset}"); + println!(); +} + +#[expect(clippy::cast_precision_loss)] +fn print_fun_facts(wrapped_stats: &WrappedStats, stats: &Stats, year: i32) { + let reset = ResetColor; + let bold = SetAttribute(crossterm::style::Attribute::Bold); + + if wrapped_stats.git_percentage > 0.05 { + println!( + "{bold}🌟 You're a Git Power User!{reset} {bold}{:.1}%{reset} of your commands were Git operations\n", + wrapped_stats.git_percentage * 100.0 + ); + } + // Navigation patterns + let nav_percentage = wrapped_stats.nav_commands as f64 / stats.total_commands as f64 * 100.0; + if nav_percentage > 0.05 { + println!( + "{bold}🚀 You're a Navigator!{reset} {bold}{nav_percentage:.1}%{reset} of your time was spent navigating directories\n", + ); + } + + // Command vocabulary + println!( + "{bold}📚 Command Vocabulary{reset}: You know {bold}{}{reset} unique commands\n", + stats.unique_commands + ); + + // Package management + println!( + "{bold}📦 Package Management{reset}: You ran {bold}{}{reset} package-related commands\n", + wrapped_stats.pkg_commands + ); + + // Error patterns + let error_percentage = wrapped_stats.error_rate * 100.0; + println!( + "{bold}🚨 Error Analysis{reset}: Your commands failed {bold}{error_percentage:.1}%{reset} of the time\n", + ); + + // Command evolution + println!("🔍 Command Evolution:"); + + // print stats for each half and compare + println!(" {bold}Top Commands{reset} in the first half of {year}:"); + for (cmd, count) in wrapped_stats.first_half_commands.iter().take(3) { + println!(" {bold}{cmd}{reset} ({count} times)"); + } + + println!(" {bold}Top Commands{reset} in the second half of {year}:"); + for (cmd, count) in wrapped_stats.second_half_commands.iter().take(3) { + println!(" {bold}{cmd}{reset} ({count} times)"); + } + + // Find new favorite commands (in top 5 of second half but not in first half) + let first_half_set: HashSet<_> = wrapped_stats + .first_half_commands + .iter() + .map(|(cmd, _)| cmd) + .collect(); + let new_favorites: Vec<_> = wrapped_stats + .second_half_commands + .iter() + .filter(|(cmd, _)| !first_half_set.contains(cmd)) + .take(2) + .collect(); + + if !new_favorites.is_empty() { + println!(" {bold}New favorites{reset} in the second half:"); + for (cmd, count) in new_favorites { + println!(" {bold}{cmd}{reset} ({count} times)"); + } + } + + // Time patterns + if let Some((hour, count)) = &wrapped_stats.busiest_hour { + println!("\n🕘 Most Productive Hour: {bold}{hour}{reset} ({count} commands)"); + + // Night owl or early bird + let hour_num = hour + .split(':') + .next() + .unwrap_or("0") + .parse::() + .unwrap_or(0); + if hour_num >= 22 || hour_num <= 4 { + println!(" You're quite the night owl! 🦉"); + } else if (5..=7).contains(&hour_num) { + println!(" Early bird gets the worm! 🐦"); + } + } + + println!(); +} + +pub async fn run( + year: Option, + db: &impl Database, + settings: &Settings, + theme: &Theme, +) -> Result<()> { + let now = OffsetDateTime::now_utc().to_offset(settings.timezone.0); + let month = now.month(); + + // If we're in December, then wrapped is for the current year. If not, it's for the previous year + let year = year.unwrap_or_else(|| { + if month == Month::December { + now.year() + } else { + now.year() - 1 + } + }); + + let start = OffsetDateTime::new_in_offset( + Date::from_calendar_date(year, Month::January, 1).unwrap(), + Time::MIDNIGHT, + now.offset(), + ); + let end = OffsetDateTime::new_in_offset( + Date::from_calendar_date(year, Month::December, 31).unwrap(), + Time::MIDNIGHT + Duration::days(1) - Duration::nanoseconds(1), + now.offset(), + ); + + let history = db.range(start, end).await?; + if history.is_empty() { + println!( + "Your history for {year} is empty!\nMaybe 'atuin import' could help you import your previous history 🪄" + ); + return Ok(()); + } + + // Compute overall stats using existing functionality + let stats = compute(settings, &history, 10, 1).expect("Failed to compute stats"); + let wrapped_stats = WrappedStats::new(settings, &stats, &history); + + // Print wrapped format + print_wrapped_header(year); + + println!("🎉 In {year}, you typed {} commands!", stats.total_commands); + println!( + " That's ~{} commands every day\n", + stats.total_commands / 365 + ); + + println!("Your Top Commands:"); + crate::atuin_history::stats::pretty_print(stats.clone(), 1, theme); + println!(); + + print_fun_facts(&wrapped_stats, &stats, year); + + Ok(()) +} diff --git a/crates/turtle/src/command/contributors.rs b/crates/turtle/src/command/contributors.rs new file mode 100644 index 00000000..452fd335 --- /dev/null +++ b/crates/turtle/src/command/contributors.rs @@ -0,0 +1,5 @@ +static CONTRIBUTORS: &str = include_str!("CONTRIBUTORS"); + +pub fn run() { + println!("\n{CONTRIBUTORS}"); +} diff --git a/crates/turtle/src/command/external.rs b/crates/turtle/src/command/external.rs new file mode 100644 index 00000000..e1f0cddd --- /dev/null +++ b/crates/turtle/src/command/external.rs @@ -0,0 +1,102 @@ +use std::fmt::Write as _; +use std::process::Command; +use std::{io, process}; + +#[cfg(feature = "client")] +use crate::atuin_client::plugin::{OfficialPluginRegistry, PluginContext}; +use clap::CommandFactory; +use clap::builder::{StyledStr, Styles}; +use eyre::Result; + +use crate::Atuin; + +pub fn run(args: &[String]) -> Result<()> { + let subcommand = &args[0]; + let bin = format!("atuin-{subcommand}"); + let mut cmd = Command::new(&bin); + cmd.args(&args[1..]); + + #[cfg(feature = "client")] + let context = PluginContext::new(subcommand); + + let spawn_result = match cmd.spawn() { + Ok(child) => Ok(child), + Err(e) => match e.kind() { + io::ErrorKind::NotFound => { + let output = render_not_found(subcommand, &bin); + Err(output) + } + _ => Err(e.to_string().into()), + }, + }; + + match spawn_result { + Ok(mut child) => { + let status = child.wait()?; + if status.success() { + Ok(()) + } else { + #[cfg(feature = "client")] + drop(context); + + process::exit(status.code().unwrap_or(1)); + } + } + Err(e) => { + eprintln!("{}", e.ansi()); + + #[cfg(feature = "client")] + drop(context); + + process::exit(1); + } + } +} + +fn render_not_found(subcommand: &str, bin: &str) -> StyledStr { + let mut output = StyledStr::new(); + let styles = Styles::styled(); + + let error = styles.get_error(); + let invalid = styles.get_invalid(); + let literal = styles.get_literal(); + + #[cfg(feature = "client")] + { + let registry = OfficialPluginRegistry::new(); + + // Check if this is an official plugin + if let Some(install_message) = registry.get_install_message(subcommand) { + let _ = write!(output, "{error}error:{error:#} "); + let _ = write!( + output, + "'{invalid}{subcommand}{invalid:#}' is an official atuin plugin, but it's not installed" + ); + let _ = write!(output, "\n\n"); + let _ = write!(output, "{install_message}"); + return output; + } + } + + let mut atuin_cmd = Atuin::command(); + let usage = atuin_cmd.render_usage(); + + let _ = write!(output, "{error}error:{error:#} "); + let _ = write!( + output, + "unrecognized subcommand '{invalid}{subcommand}{invalid:#}' " + ); + let _ = write!( + output, + "and no executable named '{invalid}{bin}{invalid:#}' found in your PATH" + ); + let _ = write!(output, "\n\n"); + let _ = write!(output, "{usage}"); + let _ = write!(output, "\n\n"); + let _ = write!( + output, + "For more information, try '{literal}--help{literal:#}'." + ); + + output +} diff --git a/crates/turtle/src/command/gen_completions.rs b/crates/turtle/src/command/gen_completions.rs new file mode 100644 index 00000000..10d4f689 --- /dev/null +++ b/crates/turtle/src/command/gen_completions.rs @@ -0,0 +1,84 @@ +use clap::{CommandFactory, Parser, ValueEnum}; +use clap_complete::{Generator, Shell, generate, generate_to}; +use clap_complete_nushell::Nushell; +use eyre::Result; + +// clap put nushell completions into a separate package due to the maintainers +// being a little less committed to support them. +// This means we have to do a tiny bit of legwork to combine these completions +// into one command. +#[derive(Debug, Clone, ValueEnum)] +#[value(rename_all = "lower")] +pub enum GenShell { + Bash, + Elvish, + Fish, + Nushell, + PowerShell, + Zsh, +} + +impl Generator for GenShell { + fn file_name(&self, name: &str) -> String { + match self { + // clap_complete + Self::Bash => Shell::Bash.file_name(name), + Self::Elvish => Shell::Elvish.file_name(name), + Self::Fish => Shell::Fish.file_name(name), + Self::PowerShell => Shell::PowerShell.file_name(name), + Self::Zsh => Shell::Zsh.file_name(name), + + // clap_complete_nushell + Self::Nushell => Nushell.file_name(name), + } + } + + fn generate(&self, cmd: &clap::Command, buf: &mut dyn std::io::prelude::Write) { + match self { + // clap_complete + Self::Bash => Shell::Bash.generate(cmd, buf), + Self::Elvish => Shell::Elvish.generate(cmd, buf), + Self::Fish => Shell::Fish.generate(cmd, buf), + Self::PowerShell => Shell::PowerShell.generate(cmd, buf), + Self::Zsh => Shell::Zsh.generate(cmd, buf), + + // clap_complete_nushell + Self::Nushell => Nushell.generate(cmd, buf), + } + } +} + +#[derive(Debug, Parser)] +pub struct Cmd { + /// Set the shell for generating completions + #[arg(long, short)] + shell: GenShell, + + /// Set the output directory + #[arg(long, short)] + out_dir: Option, +} + +impl Cmd { + pub fn run(self) -> Result<()> { + let Cmd { shell, out_dir } = self; + + let mut cli = crate::Atuin::command(); + + match out_dir { + Some(out_dir) => { + generate_to(shell, &mut cli, env!("CARGO_PKG_NAME"), &out_dir)?; + } + None => { + generate( + shell, + &mut cli, + env!("CARGO_PKG_NAME"), + &mut std::io::stdout(), + ); + } + } + + Ok(()) + } +} diff --git a/crates/turtle/src/command/mod.rs b/crates/turtle/src/command/mod.rs new file mode 100644 index 00000000..e58bfe72 --- /dev/null +++ b/crates/turtle/src/command/mod.rs @@ -0,0 +1,156 @@ +use clap::Subcommand; +use eyre::Result; + +#[cfg(not(windows))] +use rustix::{fs::Mode, process::umask}; + +#[cfg(feature = "client")] +mod client; + +mod contributors; + +mod gen_completions; + +mod external; + +#[derive(Subcommand)] +#[command(infer_subcommands = true)] +#[expect(clippy::large_enum_variant)] +pub enum AtuinCmd { + #[cfg(feature = "client")] + #[command(flatten)] + Client(client::Cmd), + + /// PTY proxy for atuin + #[cfg(feature = "pty-proxy")] + #[command(alias = "hex")] + PtyProxy(crate::atuin_pty_proxy::PtyProxy), + + /// Generate a UUID + Uuid, + + Contributors, + + /// Generate shell completions + GenCompletions(gen_completions::Cmd), + + #[command(external_subcommand)] + External(Vec), +} + +impl AtuinCmd { + pub fn run(self) -> Result<()> { + #[cfg(not(windows))] + { + // set umask before we potentially open/create files + // or in other words, 077. Do not allow any access to any other user + let mode = Mode::RWXG | Mode::RWXO; + umask(mode); + } + + match self { + #[cfg(feature = "client")] + Self::Client(client) => client.run(), + + #[cfg(feature = "pty-proxy")] + Self::PtyProxy(proxy) => { + run_pty_proxy(proxy); + Ok(()) + } + + Self::Contributors => { + contributors::run(); + Ok(()) + } + Self::Uuid => { + println!("{}", crate::atuin_common::utils::uuid_v7().as_simple()); + Ok(()) + } + Self::GenCompletions(gen_completions) => gen_completions.run(), + Self::External(args) => external::run(&args), + } + } +} + +#[cfg(all(feature = "pty-proxy", unix))] +fn run_pty_proxy(proxy: crate::atuin_pty_proxy::PtyProxy) { + #[cfg(feature = "daemon")] + proxy.run(semantic_command_capture_sink()); + + #[cfg(not(feature = "daemon"))] + proxy.run(None); +} + +#[cfg(all(feature = "daemon", feature = "pty-proxy", unix))] +fn semantic_command_capture_sink() -> Option { + use std::sync::mpsc; + use std::time::Duration; + + if is_truthy_env("ATUIN_TERMINAL") { + return None; + } + + let settings = crate::atuin_client::settings::Settings::new().ok()?; + let (tx, rx) = mpsc::sync_channel::(128); + + std::thread::spawn(move || { + let Ok(runtime) = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + else { + return; + }; + + while let Ok(first) = rx.recv() { + let mut batch = vec![first]; + + while batch.len() < 64 { + match rx.recv_timeout(Duration::from_millis(25)) { + Ok(capture) => batch.push(capture), + Err(mpsc::RecvTimeoutError::Timeout | mpsc::RecvTimeoutError::Disconnected) => { + break; + } + } + } + + runtime.block_on(send_semantic_command_captures(&settings, batch)); + } + }); + + Some(Box::new(move |capture| { + let _ = tx.try_send(capture); + })) +} + +#[cfg(all(feature = "daemon", feature = "pty-proxy", unix))] +#[inline] +fn is_truthy_env(name: &str) -> bool { + std::env::var(name) + .ok() + .as_ref() + .is_some_and(|value| !value.trim().is_empty() && value.trim() != "false") +} + +#[cfg(all(feature = "daemon", feature = "pty-proxy", unix))] +async fn send_semantic_command_captures( + settings: &crate::atuin_client::settings::Settings, + batch: Vec, +) { + let captures = batch + .into_iter() + .map(|capture| crate::atuin_daemon::semantic::CommandCapture { + prompt: capture.prompt, + command: capture.command, + output: capture.output, + exit_code: capture.exit_code, + history_id: capture.history_id, + session_id: capture.session_id, + output_truncated: capture.output_truncated, + output_observed_bytes: capture.output_observed_bytes, + }) + .collect(); + + if let Ok(mut client) = crate::atuin_daemon::SemanticClient::from_settings(settings).await { + let _ = client.record_commands(captures).await; + } +} diff --git a/crates/turtle/src/main.rs b/crates/turtle/src/main.rs new file mode 100644 index 00000000..e5b80ee8 --- /dev/null +++ b/crates/turtle/src/main.rs @@ -0,0 +1,73 @@ +#![warn(clippy::pedantic, clippy::nursery)] +#![allow(clippy::use_self, clippy::missing_const_for_fn)] // not 100% reliable +// #![deny(unsafe_code)] +#![forbid(unsafe_code)] + +use clap::Parser; +use clap::builder::Styles; +use clap::builder::styling::{AnsiColor, Effects}; +use eyre::Result; + +use command::AtuinCmd; + +mod command; + +mod atuin_client; +mod atuin_common; +mod atuin_daemon; +mod atuin_history; +mod atuin_pty_proxy; +mod atuin_server; +mod atuin_server_database; +mod atuin_server_postgres; +mod atuin_server_sqlite; + +#[cfg(feature = "sync")] +mod print_error; +#[cfg(feature = "sync")] +mod sync; + +const VERSION: &str = env!("CARGO_PKG_VERSION"); +const SHA: &str = env!("GIT_HASH"); + +const LONG_VERSION: &str = concat!(env!("CARGO_PKG_VERSION"), " (", env!("GIT_HASH"), ")"); + +static HELP_TEMPLATE: &str = "\ +{before-help}{name} {version} +{author} +{about} + +{usage-heading} + {usage} + +{all-args}{after-help}"; + +const STYLES: Styles = Styles::styled() + .header(AnsiColor::Yellow.on_default().effects(Effects::BOLD)) + .usage(AnsiColor::Green.on_default().effects(Effects::BOLD)) + .literal(AnsiColor::Green.on_default().effects(Effects::BOLD)) + .placeholder(AnsiColor::Green.on_default()); + +/// Magical shell history +#[derive(Parser)] +#[command( + author = "Ellie Huxtable ", + version = VERSION, + long_version = LONG_VERSION, + help_template(HELP_TEMPLATE), + styles = STYLES, +)] +struct Atuin { + #[command(subcommand)] + atuin: AtuinCmd, +} + +impl Atuin { + fn run(self) -> Result<()> { + self.atuin.run() + } +} + +fn main() -> Result<()> { + Atuin::parse().run() +} diff --git a/crates/turtle/src/print_error.rs b/crates/turtle/src/print_error.rs new file mode 100644 index 00000000..4d4724bc --- /dev/null +++ b/crates/turtle/src/print_error.rs @@ -0,0 +1,123 @@ +use std::io::IsTerminal; + +use crate::atuin_client::record::sync::SyncError; +use colored::Colorize; +use crossterm::terminal; + +/// Print a prominent error to stderr. Colored and box-bordered when stderr is +/// a TTY, plain "Error: ..." header otherwise. The description is word-wrapped +/// to the terminal width (capped at 100 columns) so the message stays readable. +pub fn print_error(title: &str, description: &str) { + let is_tty = std::io::stderr().is_terminal(); + let width = if is_tty { + terminal::size().map_or(80, |(w, _)| w as usize) + } else { + 80 + } + .min(100); + + eprintln!(); + if is_tty { + let bar = "━".repeat(width).red().bold().to_string(); + eprintln!("{bar}"); + eprintln!(" {} {}", "✗".red().bold(), title.red().bold()); + eprintln!("{bar}"); + } else { + eprintln!("Error: {title}"); + eprintln!("{}", "-".repeat(width)); + } + eprintln!(); + + for line in wrap_text(description, width.saturating_sub(2)) { + eprintln!(" {line}"); + } + eprintln!(); +} + +/// Convert a `SyncError` into an `eyre::Report`, exiting on `WrongKey` after +/// painting the prominent banner. +pub fn format_sync_error(e: SyncError) -> eyre::Report { + if matches!(e, SyncError::WrongKey) { + print_error( + "Wrong encryption key", + "Your local encryption key cannot decrypt the data on the server. \ + This usually means another machine wrote records with a different key.\n\n\ + To fix this, find the correct key by running `atuin key` on a machine that \ + already syncs successfully, then run `atuin store rekey ` here.", + ); + std::process::exit(1); + } + e.into() +} + +fn wrap_text(text: &str, width: usize) -> Vec { + let mut out = Vec::new(); + for paragraph in text.split('\n') { + let mut line = String::new(); + let mut line_len = 0; + for word in paragraph.split_whitespace() { + let word_len = word.chars().count(); + if !line.is_empty() && line_len + 1 + word_len > width { + out.push(std::mem::take(&mut line)); + line_len = 0; + } + if !line.is_empty() { + line.push(' '); + line_len += 1; + } + line.push_str(word); + line_len += word_len; + } + // Push every paragraph's final line (even empty) so `\n\n` in the + // input becomes a blank line in the output. + out.push(line); + } + while out.first().is_some_and(String::is_empty) { + out.remove(0); + } + while out.last().is_some_and(String::is_empty) { + out.pop(); + } + out +} + +#[cfg(test)] +mod tests { + use super::wrap_text; + + #[test] + fn wraps_long_text() { + let lines = wrap_text("the quick brown fox jumps over the lazy dog", 20); + for line in &lines { + assert!(line.chars().count() <= 20, "line too long: {line:?}"); + } + assert_eq!( + lines.join(" "), + "the quick brown fox jumps over the lazy dog" + ); + } + + #[test] + fn preserves_explicit_newlines() { + let lines = wrap_text("first line\nsecond line", 80); + assert_eq!(lines, vec!["first line", "second line"]); + } + + #[test] + fn handles_word_longer_than_width() { + let lines = wrap_text("short superlongword more", 5); + assert_eq!(lines, vec!["short", "superlongword", "more"]); + } + + #[test] + fn preserves_blank_lines_between_paragraphs() { + let lines = wrap_text("first paragraph\n\nsecond paragraph", 80); + assert_eq!(lines, vec!["first paragraph", "", "second paragraph"]); + } + + #[test] + fn trims_leading_and_trailing_blank_lines() { + let lines = wrap_text("\n\nbody\n\n", 80); + assert_eq!(lines, vec!["body"]); + } +} diff --git a/crates/turtle/src/shell/.gitattributes b/crates/turtle/src/shell/.gitattributes new file mode 100644 index 00000000..fae8897c --- /dev/null +++ b/crates/turtle/src/shell/.gitattributes @@ -0,0 +1 @@ +* eol=lf diff --git a/crates/turtle/src/shell/atuin.bash b/crates/turtle/src/shell/atuin.bash new file mode 100644 index 00000000..8b540bd7 --- /dev/null +++ b/crates/turtle/src/shell/atuin.bash @@ -0,0 +1,725 @@ +# Include guard +if [[ ${__atuin_initialized-} == true ]]; then + false +elif [[ $- != *i* ]]; then + # Enable only in interactive shells + false +elif ((BASH_VERSINFO[0] < 3 || BASH_VERSINFO[0] == 3 && BASH_VERSINFO[1] < 1)); then + # Require bash >= 3.1 + [[ -t 2 ]] && printf 'atuin: requires bash >= 3.1 for the integration.\n' >&2 + false +else # (include guard) beginning of main content +#------------------------------------------------------------------------------ +__atuin_initialized=true + +if [[ -z "${ATUIN_SESSION:-}" || "${ATUIN_SHLVL:-}" != "$SHLVL" ]]; then + ATUIN_SESSION=$(atuin uuid) + export ATUIN_SESSION + export ATUIN_SHLVL=$SHLVL +fi +ATUIN_STTY=$(stty -g) +ATUIN_HISTORY_ID="" + +__atuin_osc133_command_executed() { + [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]] || return + [[ -n "${ATUIN_HISTORY_ID:-}" && "$ATUIN_HISTORY_ID" != "__bash_preexec_failure__" ]] || return + + printf '\033]133;C\a' +} + +__atuin_osc133_command_finished() { + [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]] || return + [[ -n "${ATUIN_HISTORY_ID:-}" && "$ATUIN_HISTORY_ID" != "__bash_preexec_failure__" ]] || return + + printf '\033]133;D;%s;history_id=%s;session_id=%s\a' "$1" "$ATUIN_HISTORY_ID" "${ATUIN_SESSION:-}" +} + +__atuin_osc133_prompt_start=$'\001\033]133;A;cl=line\a\002' +__atuin_osc133_prompt_end=$'\001\033]133;B\a\002' + +__atuin_osc133_wrap_prompt() { + local __atuin_prompt="${PS1-}" + __atuin_prompt="${__atuin_prompt//$__atuin_osc133_prompt_start/}" + __atuin_prompt="${__atuin_prompt//$__atuin_osc133_prompt_end/}" + + if [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]]; then + PS1="${__atuin_osc133_prompt_start}${__atuin_prompt}${__atuin_osc133_prompt_end}" + else + PS1="$__atuin_prompt" + fi +} + +export ATUIN_PREEXEC_BACKEND=$SHLVL:none +__atuin_update_preexec_backend() { + if [[ ${BLE_ATTACHED-} ]]; then + ATUIN_PREEXEC_BACKEND=$SHLVL:blesh-${BLE_VERSION-} + elif [[ ${bash_preexec_imported-} ]]; then + ATUIN_PREEXEC_BACKEND=$SHLVL:bash-preexec + elif [[ ${__bp_imported-} ]]; then + ATUIN_PREEXEC_BACKEND="$SHLVL:bash-preexec (old)" + else + ATUIN_PREEXEC_BACKEND=$SHLVL:unknown + fi +} + +__atuin_preexec() { + # Workaround for old versions of bash-preexec + if [[ ! ${BLE_ATTACHED-} ]]; then + # In older versions of bash-preexec, the preexec hook may be called + # even for the commands run by keybindings. There is no general and + # robust way to detect the command for keybindings, but at least we + # want to exclude Atuin's keybindings. When the preexec hook is called + # for a keybinding, the preexec hook for the user command will not + # fire, so we instead set a fake ATUIN_HISTORY_ID here to notify + # __atuin_precmd of this failure. + if [[ $BASH_COMMAND != "$1" ]]; then + case $BASH_COMMAND in + '__atuin_history'* | '__atuin_widget_run'* | '__atuin_bash42_dispatch'*) + ATUIN_HISTORY_ID=__bash_preexec_failure__ + return 0 ;; + esac + fi + fi + + # Note: We update ATUIN_PREEXEC_BACKEND on every preexec because blesh's + # attaching state can dynamically change. + __atuin_update_preexec_backend + + local id + id=$(atuin history start -- "$1" 2>/dev/null) + export ATUIN_HISTORY_ID=$id + [[ -n ${__atuin_skip_osc133:-} ]] || __atuin_osc133_command_executed + __atuin_preexec_time=${EPOCHREALTIME-} +} + +__atuin_precmd() { + local EXIT=$? __atuin_precmd_time=${EPOCHREALTIME-} + + __atuin_osc133_wrap_prompt + + [[ ! $ATUIN_HISTORY_ID ]] && return + + # If the previous preexec hook failed, we manually call __atuin_preexec + local __atuin_skip_osc133="" + if [[ $ATUIN_HISTORY_ID == __bash_preexec_failure__ ]]; then + # This is the command extraction code taken from bash-preexec + local previous_command + previous_command=$( + export LC_ALL=C HISTTIMEFORMAT='' + builtin history 1 | sed '1 s/^ *[0-9][0-9]*[* ] //' + ) + __atuin_skip_osc133=1 + __atuin_preexec "$previous_command" + fi + + local duration="" + # shellcheck disable=SC2154,SC2309 + if [[ ${BLE_ATTACHED-} && ${_ble_exec_time_ata-} ]]; then + # With ble.sh, we utilize the shell variable `_ble_exec_time_ata` + # recorded by ble.sh. It is more accurate than the measurements by + # Atuin, which includes the spawn cost of Atuin. ble.sh uses the + # special shell variable `EPOCHREALTIME` in bash >= 5.0 with the + # microsecond resolution, or the builtin `time` in bash < 5.0 with the + # millisecond resolution. + duration=${_ble_exec_time_ata}000 + elif ((BASH_VERSINFO[0] >= 5)); then + # We calculate the high-resolution duration based on EPOCHREALTIME + # (bash >= 5.0) recorded by precmd/preexec, though it might not be as + # accurate as `_ble_exec_time_ata` provided by ble.sh because it + # includes the extra time of the precmd/preexec handling. Since Bash + # does not offer floating-point arithmetic, we remove the non-digit + # characters and perform the integral arithmetic. The fraction part of + # EPOCHREALTIME is fixed to have 6 digits in Bash. We remove all the + # non-digit characters because the decimal point is not necessarily a + # period depending on the locale. + duration=$((${__atuin_precmd_time//[!0-9]} - ${__atuin_preexec_time//[!0-9]})) + if ((duration >= 0)); then + duration=${duration}000 + else + duration="" # clear the result on overflow + fi + fi + + [[ -n ${__atuin_skip_osc133:-} ]] || __atuin_osc133_command_finished "$EXIT" + (ATUIN_LOG=error atuin history end --exit "$EXIT" ${duration:+"--duration=$duration"} -- "$ATUIN_HISTORY_ID" &) >/dev/null 2>&1 + export ATUIN_HISTORY_ID="" +} + +__atuin_set_ret_value() { + return ${1:+"$1"} +} + +#------------------------------------------------------------------------------ +# section: __atuin_accept_line +# +# The function "__atuin_accept_line" is kept for backward compatibility of the +# direct use of __atuin_history in keybindings by users. + +# The shell function `__atuin_evaluate_prompt` evaluates prompt sequences in +# $PS1. We switch the implementation of the shell function +# `__atuin_evaluate_prompt` based on the Bash version because the expansion +# ${PS1@P} is only available in bash >= 4.4. +if ((BASH_VERSINFO[0] >= 5 || BASH_VERSINFO[0] == 4 && BASH_VERSINFO[1] >= 4)); then + __atuin_evaluate_prompt() { + __atuin_set_ret_value "${__bp_last_ret_value-}" "${__bp_last_argument_prev_command-}" + __atuin_prompt=${PS1@P} + + # Note: Strip the control characters ^A (\001) and ^B (\002), which + # Bash internally uses to enclose the escape sequences. They are + # produced by '\[' and '\]', respectively, in $PS1 and used to tell + # Bash that the strings inbetween do not contribute to the prompt + # width. After the prompt width calculation, Bash strips those control + # characters before outputting it to the terminal. We here strip these + # characters following Bash's behavior. + __atuin_prompt=${__atuin_prompt//[$'\001\002']} + + # Count the number of newlines contained in $__atuin_prompt + __atuin_prompt_offset=${__atuin_prompt//[!$'\n']} + __atuin_prompt_offset=${#__atuin_prompt_offset} + } +else + __atuin_evaluate_prompt() { + __atuin_prompt='$ ' + __atuin_prompt_offset=0 + } +fi + +# The shell function `__atuin_clear_prompt N` outputs terminal control +# sequences to clear the contents of the current and N previous lines. After +# clearing, the cursor is placed at the beginning of the N-th previous line. +__atuin_clear_prompt_cache=() +__atuin_clear_prompt() { + local offset=$1 + if [[ ! ${__atuin_clear_prompt_cache[offset]+set} ]]; then + if [[ ! ${__atuin_clear_prompt_cache[0]+set} ]]; then + __atuin_clear_prompt_cache[0]=$'\r'$(tput el 2>/dev/null || tput ce 2>/dev/null) + fi + if ((offset > 0)); then + __atuin_clear_prompt_cache[offset]=${__atuin_clear_prompt_cache[0]}$( + tput cuu "$offset" 2>/dev/null || tput UP "$offset" 2>/dev/null + tput dl "$offset" 2>/dev/null || tput DL "$offset" 2>/dev/null + tput il "$offset" 2>/dev/null || tput AL "$offset" 2>/dev/null + ) + fi + fi + printf '%s' "${__atuin_clear_prompt_cache[offset]}" +} + +__atuin_accept_line() { + local __atuin_command=$1 + + # Reprint the prompt, accounting for multiple lines + local __atuin_prompt __atuin_prompt_offset + __atuin_evaluate_prompt + __atuin_clear_prompt "$__atuin_prompt_offset" + printf '%s\n' "$__atuin_prompt$__atuin_command" + + # Add it to the bash history + history -s "$__atuin_command" + + # Assuming bash-preexec + # Invoke every function in the preexec array + local __atuin_preexec_function + local __atuin_preexec_function_ret_value + local __atuin_preexec_ret_value=0 + for __atuin_preexec_function in "${preexec_functions[@]:-}"; do + if type -t "$__atuin_preexec_function" 1>/dev/null; then + __atuin_set_ret_value "${__bp_last_ret_value:-}" + "$__atuin_preexec_function" "$__atuin_command" + __atuin_preexec_function_ret_value=$? + if [[ $__atuin_preexec_function_ret_value != 0 ]]; then + __atuin_preexec_ret_value=$__atuin_preexec_function_ret_value + fi + fi + done + + # If extdebug is turned on and any preexec function returns non-zero + # exit status, we do not run the user command. + if ! { shopt -q extdebug && ((__atuin_preexec_ret_value)); }; then + # Note: When a child Bash session is started by enter_accept, if the + # environment variable READLINE_POINT is present, bash-preexec in the + # child session does not fire preexec at all because it considers we + # are inside Atuin's keybinding of the current session. To avoid + # propagating the environment variable to the child session, we remove + # the export attribute of READLINE_LINE and READLINE_POINT. + export -n READLINE_LINE READLINE_POINT + + # Juggle the terminal settings so that the command can be interacted + # with + local __atuin_stty_backup + __atuin_stty_backup=$(stty -g) + stty "$ATUIN_STTY" + + # Execute the command. Note: We need to record $? and $_ after the + # user command within the same call of "eval" because $_ is otherwise + # overwritten by the last argument of "eval". + __atuin_set_ret_value "${__bp_last_ret_value-}" "${__bp_last_argument_prev_command-}" + eval -- "$__atuin_command"$'\n__bp_last_ret_value=$? __bp_last_argument_prev_command=$_' + + stty "$__atuin_stty_backup" + fi + + # Execute preprompt commands + local __atuin_prompt_command + for __atuin_prompt_command in "${PROMPT_COMMAND[@]}"; do + __atuin_set_ret_value "${__bp_last_ret_value-}" "${__bp_last_argument_prev_command-}" + eval -- "$__atuin_prompt_command" + done + # Bash will redraw only the line with the prompt after we finish, + # so to work for a multiline prompt we need to print it ourselves, + # then go to the beginning of the last line. + __atuin_evaluate_prompt + printf '%s' "$__atuin_prompt" + __atuin_clear_prompt 0 +} + +#------------------------------------------------------------------------------ + +# Check if tmux popup is available (tmux >= 3.2) +__atuin_tmux_popup_check() { + [[ -n "${TMUX-}" ]] || return 1 + [[ "${ATUIN_TMUX_POPUP:-true}" != "false" ]] || return 1 + + # https://github.com/tmux/tmux/wiki/FAQ#how-often-is-tmux-released-what-is-the-version-number-scheme + local tmux_version + tmux_version=$(tmux -V 2>/dev/null | sed -n 's/^[^0-9]*\([0-9][0-9]*\.[0-9][0-9]*\).*/\1/p') # Could have used grep... + [[ -z "$tmux_version" ]] && return 1 + + local m1 m2 + m1=${tmux_version%%.*} + m2=${tmux_version#*.} + m2=${m2%%.*} + [[ "$m1" =~ ^[0-9]+$ ]] || return 1 + [[ "$m2" =~ ^[0-9]+$ ]] || m2=0 + (( m1 > 3 || (m1 == 3 && m2 >= 2) )) +} + +# Use global variable to fix scope issues with traps +__atuin_popup_tmpdir="" +__atuin_tmux_popup_cleanup() { + [[ -n "$__atuin_popup_tmpdir" && -d "$__atuin_popup_tmpdir" ]] && command rm -rf "$__atuin_popup_tmpdir" + __atuin_popup_tmpdir="" +} + +__atuin_search_cmd() { + local -a search_args=("$@") + + if __atuin_tmux_popup_check; then + __atuin_popup_tmpdir=$(mktemp -d) || return 1 + local result_file="$__atuin_popup_tmpdir/result" + + trap '__atuin_tmux_popup_cleanup' EXIT HUP INT TERM + + local escaped_query escaped_args + escaped_query=$(printf '%s' "$READLINE_LINE" | sed "s/'/'\\\\''/g") + escaped_args="" + for arg in "${search_args[@]}"; do + escaped_args+=" '$(printf '%s' "$arg" | sed "s/'/'\\\\''/g")'" + done + + # In the popup, atuin goes to terminal, stderr goes to file + local cdir popup_width popup_height + cdir=$(pwd) + popup_width="${ATUIN_TMUX_POPUP_WIDTH:-80%}" # Keep default value anyways + popup_height="${ATUIN_TMUX_POPUP_HEIGHT:-60%}" + tmux display-popup -d "$cdir" -w "$popup_width" -h "$popup_height" -E -E -- \ + sh -c "PATH='$PATH' ATUIN_SESSION='$ATUIN_SESSION' ATUIN_SHELL=bash ATUIN_LOG=error ATUIN_QUERY='$escaped_query' atuin search $escaped_args -i 2>'$result_file'" + + if [[ -f "$result_file" ]]; then + cat "$result_file" + fi + + __atuin_tmux_popup_cleanup + trap - EXIT HUP INT TERM + else + ATUIN_SHELL=bash ATUIN_LOG=error ATUIN_QUERY=$READLINE_LINE atuin search "${search_args[@]}" -i 3>&1 1>&2 2>&3 3>&- + fi +} + +__atuin_history() { + # Default action of the up key: When this function is called with the first + # argument `--shell-up-key-binding`, we perform Atuin's history search only + # when the up key is supposed to cause the history movement in the original + # binding. We do this only for ble.sh because the up key always invokes + # the history movement in the plain Bash. + if [[ ${BLE_ATTACHED-} && ${1-} == --shell-up-key-binding ]]; then + # When the current cursor position is not in the first line, the up key + # should move the cursor to the previous line. While the selection is + # performed, the up key should not start the history search. + # shellcheck disable=SC2154 # Note: these variables are set by ble.sh + if [[ ${_ble_edit_str::_ble_edit_ind} == *$'\n'* || $_ble_edit_mark_active ]]; then + ble/widget/@nomarked backward-line + local status=$? + READLINE_LINE=$_ble_edit_str + READLINE_POINT=$_ble_edit_ind + READLINE_MARK=$_ble_edit_mark + return "$status" + fi + fi + + # READLINE_LINE and READLINE_POINT are only supported by bash >= 4.0 or + # ble.sh. When it is not supported, we clear them to suppress strange + # behaviors. + [[ ${BLE_ATTACHED-} ]] || ((BASH_VERSINFO[0] >= 4)) || + READLINE_LINE="" READLINE_POINT=0 + + local __atuin_output + if ! __atuin_output=$(__atuin_search_cmd "$@"); then + [[ $__atuin_output ]] && printf '%s\n' "$__atuin_output" >&2 + return 1 + fi + + # We do nothing when the search is canceled. + [[ $__atuin_output ]] || return 0 + + if [[ $__atuin_output == __atuin_accept__:* ]]; then + __atuin_output=${__atuin_output#__atuin_accept__:} + + if [[ ${BLE_ATTACHED-} ]]; then + ble-edit/content/reset-and-check-dirty "$__atuin_output" + ble/widget/accept-line + READLINE_LINE="" + elif [[ ${__atuin_macro_chain_keymap-} ]]; then + READLINE_LINE=$__atuin_output + bind -m "$__atuin_macro_chain_keymap" '"'"$__atuin_macro_chain"'": '"$__atuin_macro_accept_line" + else + __atuin_accept_line "$__atuin_output" + READLINE_LINE="" + fi + + READLINE_POINT=${#READLINE_LINE} + else + READLINE_LINE=$__atuin_output + READLINE_POINT=${#READLINE_LINE} + if [[ ! ${BLE_ATTACHED-} ]] && ((BASH_VERSINFO[0] < 4)) && [[ ${__atuin_macro_chain_keymap-} ]]; then + bind -m "$__atuin_macro_chain_keymap" '"'"$__atuin_macro_chain"'": '"$__atuin_macro_insert_line" + fi + fi +} + +__atuin_initialize_blesh() { + # shellcheck disable=SC2154 + [[ ${BLE_VERSION-} ]] && ((_ble_version >= 400)) || return 0 + + ble-import contrib/integration/bash-preexec + + # Define and register an autosuggestion source for ble.sh's auto-complete. + # If you'd like to overwrite this, define the same name of shell function + # after the $(atuin init bash) line in your .bashrc. If you do not need + # the auto-complete source by Atuin, please add the following code to + # remove the entry after the $(atuin init bash) line in your .bashrc: + # + # ble/util/import/eval-after-load core-complete ' + # ble/array#remove _ble_complete_auto_source atuin-history' + # + function ble/complete/auto-complete/source:atuin-history { + local suggestion + suggestion=$(ATUIN_QUERY="$_ble_edit_str" atuin search --cmd-only --limit 1 --search-mode prefix 2>/dev/null) + [[ $suggestion == "$_ble_edit_str"?* ]] || return 1 + ble/complete/auto-complete/enter h 0 "${suggestion:${#_ble_edit_str}}" '' "$suggestion" + } + ble/util/import/eval-after-load core-complete ' + ble/array#unshift _ble_complete_auto_source atuin-history' + + # @env BLE_SESSION_ID: `atuin doctor` references the environment variable + # BLE_SESSION_ID. We explicitly export the variable because it was not + # exported in older versions of ble.sh. + [[ ${BLE_SESSION_ID-} ]] && export BLE_SESSION_ID +} +__atuin_initialize_blesh +BLE_ONLOAD+=(__atuin_initialize_blesh) +precmd_functions+=(__atuin_precmd) +preexec_functions+=(__atuin_preexec) + +#------------------------------------------------------------------------------ +# section: atuin-bind + +__atuin_widget=() + +__atuin_widget_save() { + local data=$1 + for REPLY in "${!__atuin_widget[@]}"; do + if [[ ${__atuin_widget[REPLY]} == "$data" ]]; then + return 0 + fi + done + # shellcheck disable=SC2154 + REPLY=${#__atuin_widget[*]} + __atuin_widget[REPLY]=$data +} + +__atuin_widget_run() { + local data=${__atuin_widget[$1]} + local keymap=${data%%:*} widget=${data#*:} + local __atuin_macro_chain_keymap=$keymap + bind -m "$keymap" '"'"$__atuin_macro_chain"'": ""' + builtin eval -- "$widget" +} + +# To realize the enter_accept feature in a robust way, we need to call the +# readline bindable function `accept-line'. However, there is no way to call +# `accept-line' from the shell script. To call the bindable function +# `accept-line', we may utilize string macros of readline. When we bind KEYSEQ +# to a WIDGET that wants to conditionally call `accept-line' at the end, we +# perform two-step dispatching: +# +# 1. [KEYSEQ -> IKEYSEQ1 IKEYSEQ2]---We first translate KEYSEQ to two +# intermediate key sequences IKEYSEQ1 and IKEYSEQ2 using string macros. For +# example, when we bind `__atuin_history` to \C-r, this step can be set up by +# `bind '"\C-r": "IKEYSEQ1IKEYSEQ2"'`. +# +# 2. [IKEYSEQ1 -> WIDGET]---Then, IKEYSEQ1 is bound to the WIDGET, and the +# binding of IKEYSEQ2 is dynamically determined by WIDGET. For example, when +# we bind `__atuin_history` to \C-r, this step can be set up by `bind -x +# '"IKEYSEQ1": WIDGET'`. +# +# 3. [IKEYSEQ2 -> accept-line] or [IKEYSEQ2 -> ""]---To request the execution +# of `accept-line', WIDGET can change the binding of IKEYSEQ2 by running +# `bind '"IKEYSEQ2": accept-line''. Otherwise, WIDGET can change the binding +# of IKEYSEQ2 to no-op by running `bind '"IKEYSEQ2": ""'`. +# +# For the choice of the intermediate key sequences, we want to choose key +# sequences that are unlikely to conflict with others. In addition, we want to +# avoid a key sequence containing \e because keymap "vi-insert" stops +# processing key sequences containing \e in older versions of Bash. We have +# used \e[0;A (a variant of the [up] key with modifier ) in Atuin 3.10.0 +# for intermediate key sequences, but this contains \e and caused a problem. +# Instead, we use \C-x\C-_A\a, which starts with \C-x\C-_ (an unlikely +# two-byte combination) and A (represents the initial letter of Atuin), +# followed by the payload and the terminator \a (BEL, \C-g). + +__atuin_macro_chain='\C-x\C-_A0\a' +for __atuin_keymap in emacs vi-insert vi-command; do + bind -m "$__atuin_keymap" "\"$__atuin_macro_chain\": \"\"" +done +unset -v __atuin_keymap + +if ((BASH_VERSINFO[0] >= 5 || BASH_VERSINFO[0] == 4 && BASH_VERSINFO[1] >= 3)); then + # In Bash >= 4.3 + + __atuin_macro_accept_line=accept-line + + __atuin_bind_impl() { + local keymap=$1 keyseq=$2 command=$3 + + # Note: In Bash <= 5.0, the table for `bind -x` from the keyseq to the + # command is shared by all the keymaps (emacs, vi-insert, and + # vi-command), so one cannot safely bind different command strings to + # the same keyseq in different keymaps. Therefore, the command string + # and the keyseq need to be globally in one-to-one correspondence in + # all the keymaps. + local REPLY + __atuin_widget_save "$keymap:$command" + local widget=$REPLY + local ikeyseq1='\C-x\C-_A'$((1 + widget))'\a' + local ikeyseq2=$__atuin_macro_chain + + if ((BASH_VERSINFO[0] == 5 && BASH_VERSINFO[1] == 1)); then + # Workaround for Bash 5.1: Bash 5.1 has a bug that overwriting an + # existing "bind -x" keybinding breaks other existing "bind -x" + # keybindings [1,2]. To work around the problem, we explicitly + # unbind an existing keybinding before overwriting it. + # + # [1] https://lists.gnu.org/archive/html/bug-bash/2021-04/msg00135.html + # [2] https://github.com/atuinsh/atuin/issues/962#issuecomment-3451132291 + bind -m "$keymap" -r "$keyseq" + fi + + bind -m "$keymap" "\"$keyseq\": \"$ikeyseq1$ikeyseq2\"" + bind -m "$keymap" -x "\"$ikeyseq1\": __atuin_widget_run $widget" + } + + __atuin_bind_blesh_onload() { + # In ble.sh, we need to enable unrecognized CSI sequences like \e[0;0A, + # which are discarded by ble.sh by default. Note: In Bash <= 4.2, we + # do not need to unset "decode_error_cseq_discard" because \e[0;A is + # used only for the macro chaining (which is unused by ble.sh) in Bash + # <= 4.2. + bleopt decode_error_cseq_discard= + } + if [[ ${BLE_VERSION-} ]]; then + __atuin_bind_blesh_onload + fi + BLE_ONLOAD+=(__atuin_bind_blesh_onload) +else + # In Bash <= 4.2, "bind -x" cannot bind a shell command to a keyseq having + # more than two bytes, so we need to work with only two-byte sequences. + # + # However, the number of available combinations of two-byte sequences is + # limited. To minimize the number of key sequences used by Atuin, instead + # of specifying a widget by its own intermediate sequence, we specify a + # widget by a fixed-length sequence of multiple two-byte sequences. More + # specifically, instead of IKEYSEQ1, we use IKS1 IKS2 IKS3 [IKS4 IKS5] + # IKSX, where IKS1..IKS5 just stores its information to a global variable, + # and IKSX collects all the information and determine and call the actual + # widget based on the stored information. Each of IKn (n=1..5) is one of + # the two reserved sequences, $__atuin_bash42_code0 and + # $__atuin_bash42_code1. IKSX is fixed to be $__atuin_bash42_code2. + # + # For the choices of the special key sequences, we consider \C-xQ, \C-xR, + # and \C-xS. In the emacs editing mode of Bash, \C-x is used as a prefix + # key, i.e., it is used for the beginning key of the keybindings with + # multiple keys, so \C-x is unlikely to be used for a single-key binding by + # the user. Also, \C-x is not used in the vi editing mode by default. The + # combinations \C-xQ..\C-xS are also unlikely be used because we need to + # switch the modifier keys from Control to Shift to input these sequences, + # and these are not easy to input. + __atuin_bash42_code0='\C-xQ' + __atuin_bash42_code1='\C-xR' + __atuin_bash42_code2='\C-xS' + + __atuin_bash42_encode() { + REPLY= + local n=$1 min_width=${2-} + while + if ((n % 2 == 0)); then + REPLY=$__atuin_bash42_code0$REPLY + else + REPLY=$__atuin_bash42_code1$REPLY + fi + (((n /= 2) || ${#REPLY} / ${#__atuin_bash42_code0} < min_width)) + do :; done + } + + __atuin_bash42_bind() { + local __atuin_keymap + for __atuin_keymap in emacs vi-insert vi-command; do + bind -m "$__atuin_keymap" -x '"'"$__atuin_bash42_code0"'": __atuin_bash42_dispatch_selector+=0' + bind -m "$__atuin_keymap" -x '"'"$__atuin_bash42_code1"'": __atuin_bash42_dispatch_selector+=1' + bind -m "$__atuin_keymap" -x '"'"$__atuin_bash42_code2"'": __atuin_bash42_dispatch' + done + } + __atuin_bash42_bind + # In Bash <= 4.2, there is no way to read users' "bind -x" settings, so we + # need to explicitly perform "bind -x" when ble.sh is loaded. + BLE_ONLOAD+=(__atuin_bash42_bind) + + if ((BASH_VERSINFO[0] >= 4)); then + __atuin_macro_accept_line=accept-line + else + # Note: We rewrite the command line and invoke `accept-line'. In + # bash <= 3.2, there is no way to rewrite the command line from the + # shell script, so we rewrite it using a macro and + # `shell-expand-line'. + # + # Note: Concerning the key sequences to invoke bindable functions + # such as "\C-x\C-_A1\a", another option is to use + # "\exbegginning-of-line\r", etc. to make it consistent with bash + # >= 5.3. However, an older Bash configuration can still conflict + # on [M-x]. The conflict is more likely than \C-x\C-_A1\a. + for __atuin_keymap in emacs vi-insert vi-command; do + bind -m "$__atuin_keymap" '"\C-x\C-_A1\a": beginning-of-line' + bind -m "$__atuin_keymap" '"\C-x\C-_A2\a": kill-line' + # shellcheck disable=SC2016 + bind -m "$__atuin_keymap" '"\C-x\C-_A3\a": "$READLINE_LINE"' + bind -m "$__atuin_keymap" '"\C-x\C-_A4\a": shell-expand-line' + bind -m "$__atuin_keymap" '"\C-x\C-_A5\a": accept-line' + bind -m "$__atuin_keymap" '"\C-x\C-_A6\a": end-of-line' + done + unset -v __atuin_keymap + + bind -m vi-command '"\C-x\C-_A7\a": vi-insertion-mode' + bind -m vi-insert '"\C-x\C-_A7\a": vi-movement-mode' + + # "\C-x\C-_A10\a": Replace the command line with READLINE_LINE. When we are + # in the vi-command keymap, we go to vi-insert, input + # "$READLINE_LINE", and come back to vi-command. + bind -m emacs '"\C-x\C-_A10\a": "\C-x\C-_A1\a\C-x\C-_A2\a\C-x\C-_A3\a\C-x\C-_A4\a"' + bind -m vi-insert '"\C-x\C-_A10\a": "\C-x\C-_A1\a\C-x\C-_A2\a\C-x\C-_A3\a\C-x\C-_A4\a"' + bind -m vi-command '"\C-x\C-_A10\a": "\C-x\C-_A1\a\C-x\C-_A2\a\C-x\C-_A7\a\C-x\C-_A3\a\C-x\C-_A7\a\C-x\C-_A4\a"' + + __atuin_macro_accept_line='"\C-x\C-_A10\a\C-x\C-_A5\a"' + __atuin_macro_insert_line='"\C-x\C-_A10\a\C-x\C-_A6\a"' + fi + + __atuin_bash42_dispatch_selector= + + __atuin_bash42_dispatch() { + local s=$__atuin_bash42_dispatch_selector + __atuin_bash42_dispatch_selector= + __atuin_widget_run "$((2#0$s))" + } + + __atuin_bind_impl() { + local keymap=$1 keyseq=$2 command=$3 + + __atuin_widget_save "$keymap:$command" + __atuin_bash42_encode "$REPLY" + local macro=$REPLY$__atuin_bash42_code2$__atuin_macro_chain + + bind -m "$keymap" "\"$keyseq\": \"$macro\"" + } +fi + +atuin-bind() { + local keymap= + local OPTIND=1 OPTARG="" OPTERR=0 flag + while getopts ':m:' flag "$@"; do + case $flag in + m) keymap=$OPTARG ;; + *) + printf '%s\n' "atuin-bind: unrecognized option '-$flag'" >&2 + return 2 + ;; + esac + done + shift "$((OPTIND - 1))" + + if (($# != 2)); then + printf '%s\n' 'usage: atuin-bind [-m keymap] keyseq widget' >&2 + return 2 + fi + + local keyseq=$1 + [[ $keymap ]] || keymap=$(bind -v | awk '$2 == "keymap" { print $3 }') + case $keymap in + emacs-meta) keymap=emacs keyseq='\e'$keyseq ;; + emacs-ctlx) keymap=emacs keyseq='\C-x'$keyseq ;; + emacs*) keymap=emacs ;; + vi-insert) ;; + vi*) keymap=vi-command ;; + *) + printf '%s\n' "atuin-bind: unknown keymap '$keymap'" >&2 + return 2 ;; + esac + + local command=$2 widget=${2%%[[:blank:]]*} + case $widget in + atuin-search) command=${2/#"$widget"/__atuin_history} ;; + atuin-search-emacs) command=${2/#"$widget"/__atuin_history --keymap-mode=emacs} ;; + atuin-search-viins) command=${2/#"$widget"/__atuin_history --keymap-mode=vim-insert} ;; + atuin-search-vicmd) command=${2/#"$widget"/__atuin_history --keymap-mode=vim-normal} ;; + atuin-up-search) command=${2/#"$widget"/__atuin_history --shell-up-key-binding} ;; + atuin-up-search-emacs) command=${2/#"$widget"/__atuin_history --shell-up-key-binding --keymap-mode=emacs} ;; + atuin-up-search-viins) command=${2/#"$widget"/__atuin_history --shell-up-key-binding --keymap-mode=vim-insert} ;; + atuin-up-search-vicmd) command=${2/#"$widget"/__atuin_history --shell-up-key-binding --keymap-mode=vim-normal} ;; + esac + + __atuin_bind_impl "$keymap" "$keyseq" "$command" +} + +#------------------------------------------------------------------------------ + +# shellcheck disable=SC2154 +if [[ $__atuin_bind_ctrl_r == true ]]; then + # Note: We do not overwrite [C-r] in the vi-command keymap because we do + # not want to overwrite "redo", which is already bound to [C-r] in the + # vi_nmap keymap in ble.sh. + atuin-bind -m emacs '\C-r' atuin-search-emacs + atuin-bind -m vi-insert '\C-r' atuin-search-viins + atuin-bind -m vi-command '/' atuin-search-emacs +fi + +# shellcheck disable=SC2154 +if [[ $__atuin_bind_up_arrow == true ]]; then + atuin-bind -m emacs '\e[A' atuin-up-search-emacs + atuin-bind -m emacs '\eOA' atuin-up-search-emacs + atuin-bind -m vi-insert '\e[A' atuin-up-search-viins + atuin-bind -m vi-insert '\eOA' atuin-up-search-viins + atuin-bind -m vi-command '\e[A' atuin-up-search-vicmd + atuin-bind -m vi-command '\eOA' atuin-up-search-vicmd + atuin-bind -m vi-command 'k' atuin-up-search-vicmd +fi + +#------------------------------------------------------------------------------ +fi # (include guard) end of main content diff --git a/crates/turtle/src/shell/atuin.fish b/crates/turtle/src/shell/atuin.fish new file mode 100644 index 00000000..15b33451 --- /dev/null +++ b/crates/turtle/src/shell/atuin.fish @@ -0,0 +1,178 @@ +if not set -q ATUIN_SESSION; or test "$ATUIN_SHLVL" != "$SHLVL" + set -gx ATUIN_SESSION (atuin uuid) + set -gx ATUIN_SHLVL $SHLVL +end +set --erase ATUIN_HISTORY_ID + +function _atuin_osc133_command_executed + set -q ATUIN_PTY_PROXY_ACTIVE; or return + test -n "$ATUIN_HISTORY_ID"; or return + + printf '\033]133;C\a' +end + +function _atuin_osc133_command_finished --argument-names exit_code + set -q ATUIN_PTY_PROXY_ACTIVE; or return + test -n "$ATUIN_HISTORY_ID"; or return + + printf '\033]133;D;%s;history_id=%s;session_id=%s\a' "$exit_code" "$ATUIN_HISTORY_ID" "$ATUIN_SESSION" +end + +function _atuin_preexec --on-event fish_preexec + if not test -n "$fish_private_mode" + set -g ATUIN_HISTORY_ID (atuin history start -- "$argv[1]" 2>/dev/null) + _atuin_osc133_command_executed + end +end + +function _atuin_postexec --on-event fish_postexec + set -l s $status + + if test -n "$ATUIN_HISTORY_ID" + _atuin_osc133_command_finished $s + ATUIN_LOG=error atuin history end --exit $s -- $ATUIN_HISTORY_ID &>/dev/null & + disown + end + + set --erase ATUIN_HISTORY_ID +end + +# Check if tmux popup is available (tmux >= 3.2) +function _atuin_tmux_popup_check + if not test -n "$TMUX" + echo 0 + return + end + + if test "$ATUIN_TMUX_POPUP" = false + echo 0 + return + end + + set -l tmux_version (tmux -V 2>/dev/null | string match -r '\d+\.\d+') + if not test -n "$tmux_version" + echo 0 + return + end + + set -l parts (string split '.' $tmux_version) + set -l m1 $parts[1] + set -l m2 0 + if test (count $parts) -ge 2 + set m2 $parts[2] + end + + if not string match -rq '^[0-9]+$' -- "$m1" + echo 0 + return + end + + if not string match -rq '^[0-9]+$' -- "$m2" + set m2 0 + end + + if test "$m1" -gt 3 2>/dev/null; or begin + test "$m1" -eq 3 2>/dev/null; and test "$m2" -ge 2 2>/dev/null + end + echo 1 + else + echo 0 + end +end + +function _atuin_search + set -l keymap_mode + switch $fish_key_bindings + case fish_vi_key_bindings fish_hybrid_key_bindings + switch $fish_bind_mode + case default + set keymap_mode vim-normal + case insert + set keymap_mode vim-insert + end + case '*' + set keymap_mode emacs + end + + set -l use_tmux_popup (_atuin_tmux_popup_check) + + set -l ATUIN_H + set -l ATUIN_STATUS 0 + if test "$use_tmux_popup" -eq 1 + set -l tmpdir (mktemp -d) + if not test -d "$tmpdir" + # if mktemp got errors + set ATUIN_H (ATUIN_SHELL=fish ATUIN_LOG=error ATUIN_QUERY=(commandline -b) atuin search --keymap-mode=$keymap_mode $argv -i 3>&1 1>&2 2>&3 3>&- | string collect) + set ATUIN_STATUS $pipestatus[1] + else + set -l result_file "$tmpdir/result" + + set -l query (commandline -b | string replace -a "'" "'\\''") + set -l escaped_args "" + for arg in $argv + set escaped_args "$escaped_args '"(string replace -a "'" "'\\''" -- $arg)"'" + end + + # In the popup, atuin goes to terminal, stderr goes to file + set -l cdir (pwd) + # Keep default value anyways + set -l popup_width (test -n "$ATUIN_TMUX_POPUP_WIDTH" && echo "$ATUIN_TMUX_POPUP_WIDTH" || echo "80%") + set -l popup_height (test -n "$ATUIN_TMUX_POPUP_HEIGHT" && echo "$ATUIN_TMUX_POPUP_HEIGHT" || echo "60%") + tmux display-popup -d "$cdir" -w "$popup_width" -h "$popup_height" -E -E -- \ + sh -c "PATH='$PATH' ATUIN_SESSION='$ATUIN_SESSION' ATUIN_SHELL=fish ATUIN_LOG=error ATUIN_QUERY='$query' atuin search --keymap-mode=$keymap_mode$escaped_args -i 2>'$result_file'" + set ATUIN_STATUS $status + + if test -f "$result_file" + set ATUIN_H (cat "$result_file" | string collect) + end + + command rm -rf "$tmpdir" + end + else + # In fish 3.4 and above we can use `"$(some command)"` to keep multiple lines separate; + # but to support fish 3.3 we need to use `(some command | string collect)`. + # https://fishshell.com/docs/current/relnotes.html#id24 (fish 3.4 "Notable improvements and fixes") + set ATUIN_H (ATUIN_SHELL=fish ATUIN_LOG=error ATUIN_QUERY=(commandline -b) atuin search --keymap-mode=$keymap_mode $argv -i 3>&1 1>&2 2>&3 3>&- | string collect) + set ATUIN_STATUS $pipestatus[1] + end + + if test "$ATUIN_STATUS" -ne 0 + test -n "$ATUIN_H"; and printf '%s\n' "$ATUIN_H" >&2 + commandline -f repaint + return "$ATUIN_STATUS" + end + + set ATUIN_H (string trim -- $ATUIN_H | string collect) # trim whitespace + + if test -n "$ATUIN_H" + if string match --quiet '__atuin_accept__:*' "$ATUIN_H" + set -l ATUIN_HIST (string replace "__atuin_accept__:" "" -- "$ATUIN_H" | string collect) + commandline -r "$ATUIN_HIST" + commandline -f repaint + commandline -f execute + return + else + commandline -r "$ATUIN_H" + end + end + + commandline -f repaint +end + +function _atuin_bind_up + # Fallback to fish's builtin up-or-search if we're in search or paging mode + if commandline --search-mode; or commandline --paging-mode + up-or-search + return + end + + # Only invoke atuin if we're on the top line of the command + set -l lineno (commandline --line) + + switch $lineno + case 1 + _atuin_search --shell-up-key-binding + case '*' + up-or-search + end +end diff --git a/crates/turtle/src/shell/atuin.nu b/crates/turtle/src/shell/atuin.nu new file mode 100644 index 00000000..d37457e4 --- /dev/null +++ b/crates/turtle/src/shell/atuin.nu @@ -0,0 +1,121 @@ +# Source this in your ~/.config/nushell/config.nu +# minimum supported version = 0.93.0 +module compat { + export def --wrapped "random uuid -v 7" [...rest] { atuin uuid } +} +use (if not ( + (version).major > 0 or + (version).minor >= 103 +) { "compat" }) * + +if 'ATUIN_SESSION' not-in $env or ('ATUIN_SHLVL' not-in $env) or ($env.ATUIN_SHLVL != ($env.SHLVL? | default "")) { + $env.ATUIN_SESSION = (random uuid -v 7 | str replace -a "-" "") + $env.ATUIN_SHLVL = ($env.SHLVL? | default "") +} +hide-env -i ATUIN_HISTORY_ID + +def _atuin_osc133_command_executed [] { + if 'ATUIN_PTY_PROXY_ACTIVE' not-in $env { + return + } + if 'ATUIN_HISTORY_ID' not-in $env or ($env.ATUIN_HISTORY_ID | is-empty) { + return + } + + print -n $"(char esc)]133;C(char bel)" +} + +def _atuin_osc133_command_finished [exit_code: int] { + if 'ATUIN_PTY_PROXY_ACTIVE' not-in $env { + return + } + if 'ATUIN_HISTORY_ID' not-in $env or ($env.ATUIN_HISTORY_ID | is-empty) { + return + } + + print -n $"(char esc)]133;D;($exit_code);history_id=($env.ATUIN_HISTORY_ID);session_id=($env.ATUIN_SESSION)(char bel)" +} + +# Magic token to make sure we don't record commands run by keybindings +let ATUIN_KEYBINDING_TOKEN = $"# (random uuid)" + +let _atuin_pre_execution = {|| + if ($nu | get history-enabled?) == false { + return + } + let cmd = (commandline) + if ($cmd | is-empty) { + return + } + if not ($cmd | str starts-with $ATUIN_KEYBINDING_TOKEN) { + $env.ATUIN_HISTORY_ID = (atuin history start -- $cmd | complete | get stdout | str trim) + _atuin_osc133_command_executed + } +} + +let _atuin_pre_prompt = {|| + let last_exit = $env.LAST_EXIT_CODE + if 'ATUIN_HISTORY_ID' not-in $env { + return + } + _atuin_osc133_command_finished $last_exit + with-env { ATUIN_LOG: error } { + if (version).minor >= 104 or (version).major > 0 { + job spawn { + ^atuin history end $'--exit=($env.LAST_EXIT_CODE)' -- $env.ATUIN_HISTORY_ID | complete + } | ignore + } else { + do { atuin history end $'--exit=($last_exit)' -- $env.ATUIN_HISTORY_ID } | complete + } + + } + hide-env ATUIN_HISTORY_ID +} + +def _atuin_search_cmd [...flags: string] { + if (version).minor >= 106 or (version).major > 0 { + [ + $ATUIN_KEYBINDING_TOKEN, + ([ + `with-env { ATUIN_LOG: error, ATUIN_QUERY: (commandline), ATUIN_SHELL: nu } {`, + ([ + 'let output = (run-external atuin search', + ($flags | append [--interactive] | each {|e| $'"($e)"'}), + 'e>| str trim)', + ] | flatten | str join ' '), + 'if ($output | str starts-with "__atuin_accept__:") {', + 'commandline edit --accept ($output | str replace "__atuin_accept__:" "")', + '} else {', + 'commandline edit $output', + '}', + `}`, + ] | flatten | str join "\n"), + ] + } else { + [ + $ATUIN_KEYBINDING_TOKEN, + ([ + `with-env { ATUIN_LOG: error, ATUIN_QUERY: (commandline) } {`, + 'commandline edit', + '(run-external atuin search', + ($flags | append [--interactive] | each {|e| $'"($e)"'}), + ' e>| str trim)', + `}`, + ] | flatten | str join ' '), + ] + } | str join "\n" +} + +$env.config = ($env | default {} config).config +$env.config = ($env.config | default {} hooks) +$env.config = ( + $env.config | upsert hooks ( + $env.config.hooks + | upsert pre_execution ( + $env.config.hooks | get pre_execution? | default [] | append $_atuin_pre_execution) + | upsert pre_prompt ( + $env.config.hooks | get pre_prompt? | default [] | append $_atuin_pre_prompt) + ) +) + +$env.config = ($env.config | default [] keybindings) diff --git a/crates/turtle/src/shell/atuin.ps1 b/crates/turtle/src/shell/atuin.ps1 new file mode 100644 index 00000000..431ee2c3 --- /dev/null +++ b/crates/turtle/src/shell/atuin.ps1 @@ -0,0 +1,240 @@ +# Atuin PowerShell module +# +# This should support PowerShell 5.1 (which is shipped with Windows) and later versions, on Windows and Linux. +# +# Usage: atuin init powershell | Out-String | Invoke-Expression +# +# Settings: +# - $env:ATUIN_POWERSHELL_PROMPT_OFFSET - Number of lines to offset the prompt position after exiting search. +# This is useful when using a multi-line prompt: e.g. set this to -1 when using a 2-line prompt. +# It is initialized from the current prompt line count if not set when the first Atuin search is performed. + +if (Get-Module Atuin -ErrorAction Ignore) { + if ($PSVersionTable.PSVersion.Major -ge 7) { + Write-Warning "The Atuin module is already loaded, replacing it." + Remove-Module Atuin + } else { + Write-Warning "The Atuin module is already loaded, skipping." + return + } +} + +if (!(Get-Command atuin -ErrorAction Ignore)) { + Write-Error "The 'atuin' executable needs to be available in the PATH." + return +} + +if (!(Get-Module PSReadLine -ErrorAction Ignore)) { + Write-Error "Atuin requires the PSReadLine module to be installed." + return +} + +New-Module -Name Atuin -ScriptBlock { + if (-not $env:ATUIN_SESSION -or $env:ATUIN_PID -ne $PID) { + $env:ATUIN_SESSION = atuin uuid + $env:ATUIN_PID = $PID + } + + $script:atuinHistoryId = $null + $script:previousPSConsoleHostReadLine = $Function:PSConsoleHostReadLine + + # The ReadLine overloads changed with breaking changes over time, make sure the one we expect is available. + $script:hasExpectedReadLineOverload = ([Microsoft.PowerShell.PSConsoleReadLine]::ReadLine).OverloadDefinitions.Contains("static string ReadLine(runspace runspace, System.Management.Automation.EngineIntrinsics engineIntrinsics, System.Threading.CancellationToken cancellationToken, System.Nullable[bool] lastRunStatus)") + + function Get-CommandLine { + $commandLine = "" + [Microsoft.PowerShell.PSConsoleReadLine]::GetBufferState([ref]$commandLine, [ref]$null) + return $commandLine + } + + function Set-CommandLine { + param([string]$Text) + + $commandLine = Get-CommandLine + [Microsoft.PowerShell.PSConsoleReadLine]::Replace(0, $commandLine.Length, $Text) + } + + # This function name is called by PSReadLine to read the next command line to execute. + # We replace it with a custom implementation which adds Atuin support. + function PSConsoleHostReadLine { + ## 1. Collect the exit code of the previous command. + + # This needs to be done as the first thing because any script run will flush $?. + $lastRunStatus = $? + + # Exit statuses are maintained separately for native and PowerShell commands, this needs to be taken into account. + $lastNativeExitCode = $global:LASTEXITCODE + $exitCode = if ($lastRunStatus) { 0 } elseif ($lastNativeExitCode) { $lastNativeExitCode } else { 1 } + + ## 2. Report the status of the previous command to Atuin (atuin history end). + + if ($script:atuinHistoryId) { + try { + # The duration is not recorded in old PowerShell versions, let Atuin handle it. $null arguments are ignored. + $duration = (Get-History -Count 1).Duration.Ticks * 100 + $durationArg = if ($duration) { "--duration=$duration" } else { $null } + + # Fire and forget the atuin history end command to avoid blocking the shell during a potential sync. + $process = New-Object System.Diagnostics.Process + $process.StartInfo.FileName = "atuin" + $process.StartInfo.Arguments = "history end --exit=$exitCode $durationArg -- $script:atuinHistoryId" + $process.StartInfo.UseShellExecute = $false + $process.StartInfo.CreateNoWindow = $true + $process.StartInfo.RedirectStandardInput = $true + $process.StartInfo.RedirectStandardOutput = $true + $process.StartInfo.RedirectStandardError = $true + $process.Start() | Out-Null + $process.StandardInput.Close() + $process.BeginOutputReadLine() + $process.BeginErrorReadLine() + } + catch { + # Ignore errors to avoid breaking the shell. + # An error would occur if the user removes atuin from the PATH, for instance. + } + finally { + $script:atuinHistoryId = $null + } + } + + ## 3. Read the next command line to execute. + + # PSConsoleHostReadLine implementation from PSReadLine, adjusted to support old versions. + Microsoft.PowerShell.Core\Set-StrictMode -Off + + $line = if ($script:hasExpectedReadLineOverload) { + # When the overload we expect is available, we can pass $lastRunStatus to it. + [Microsoft.PowerShell.PSConsoleReadLine]::ReadLine($Host.Runspace, $ExecutionContext, [System.Threading.CancellationToken]::None, $lastRunStatus) + } else { + # Either PSReadLine is older than v2.2.0-beta3, or maybe newer than we expect, so use the function from PSReadLine as-is. + & $script:previousPSConsoleHostReadLine + } + + ## 4. Report the next command line to Atuin (atuin history start). + + # PowerShell doesn't handle double quotes in native command line arguments the same way depending on its version, + # and the value of $PSNativeCommandArgumentPassing - see the about_Parsing help page which explains the breaking changes. + # This makes it unreliable, so we go through an environment variable, which should always be consistent across versions. + try { + $env:ATUIN_COMMAND_LINE = $line + $script:atuinHistoryId = atuin history start --command-from-env + } + catch { + # Ignore errors to avoid breaking the shell, see above. + } + finally { + $env:ATUIN_COMMAND_LINE = $null + } + + $global:LASTEXITCODE = $lastNativeExitCode + return $line + } + + function Invoke-AtuinSearch { + param([string]$ExtraArgs = "") + + $previousOutputEncoding = [System.Console]::OutputEncoding + $resultFile = New-TemporaryFile + $suggestion = "" + $errorOutput = "" + + try { + [System.Console]::OutputEncoding = [System.Text.Encoding]::UTF8 + + # Start-Process does some crazy stuff, just use the Process class directly to have more control. + $process = New-Object System.Diagnostics.Process + $process.StartInfo.FileName = "atuin" + $process.StartInfo.Arguments = "search -i --result-file ""$($resultFile.FullName)"" $ExtraArgs" + $process.StartInfo.UseShellExecute = $false + $process.StartInfo.RedirectStandardError = $true + $process.StartInfo.StandardErrorEncoding = [System.Text.Encoding]::UTF8 + $process.StartInfo.EnvironmentVariables["ATUIN_SHELL"] = "powershell" + $process.StartInfo.EnvironmentVariables["ATUIN_QUERY"] = Get-CommandLine + # PowerShell's Set-Location (cd) doesn't update the process-level working directory, set it explicitly + $process.StartInfo.WorkingDirectory = (Get-Location -PSProvider FileSystem).ProviderPath + + try { + $process.Start() | Out-Null + + # A single stream is redirected, so we can read it synchronously, but we have to start reading it + # before waiting for the process to exit, otherwise the buffer could fill up and cause a deadlock. + $errorOutput = $process.StandardError.ReadToEnd().Trim() + $process.WaitForExit() + + $suggestion = (Get-Content -LiteralPath $resultFile.FullName -Raw -Encoding UTF8 | Out-String).Trim() + } + catch { + $errorOutput = $_ + } + + if ($errorOutput) { + Write-Host -ForegroundColor Red "Atuin error:" + Write-Host -ForegroundColor DarkRed $errorOutput + } + + # If no shell prompt offset is set, initialize it from the current prompt line count. + if ($null -eq $env:ATUIN_POWERSHELL_PROMPT_OFFSET) { + try { + $promptLines = (& $Function:prompt | Out-String | Measure-Object -Line).Lines + $env:ATUIN_POWERSHELL_PROMPT_OFFSET = -1 * ($promptLines - 1) + } + catch { + $env:ATUIN_POWERSHELL_PROMPT_OFFSET = 0 + } + } + + # PSReadLine maintains its own cursor position, which will no longer be valid if Atuin scrolls the display in inline mode. + # Fortunately, InvokePrompt can receive a new Y position and reset the internal state. + $y = $Host.UI.RawUI.CursorPosition.Y + [int]$env:ATUIN_POWERSHELL_PROMPT_OFFSET + $y = [System.Math]::Max([System.Math]::Min($y, [System.Console]::BufferHeight - 1), 0) + [Microsoft.PowerShell.PSConsoleReadLine]::InvokePrompt($null, $y) + + if ($suggestion -eq "") { + # The previous input was already rendered by InvokePrompt + return + } + + $acceptPrefix = "__atuin_accept__:" + + if ( $suggestion.StartsWith($acceptPrefix)) { + Set-CommandLine $suggestion.Substring($acceptPrefix.Length) + [Microsoft.PowerShell.PSConsoleReadLine]::AcceptLine() + } else { + Set-CommandLine $suggestion + } + } + finally { + [System.Console]::OutputEncoding = $previousOutputEncoding + $resultFile.Delete() + } + } + + function Enable-AtuinSearchKeys { + param([bool]$CtrlR = $true, [bool]$UpArrow = $true) + + if ($CtrlR) { + Set-PSReadLineKeyHandler -Chord "Ctrl+r" -BriefDescription "Runs Atuin search" -ScriptBlock { + Invoke-AtuinSearch + } + } + + if ($UpArrow) { + Set-PSReadLineKeyHandler -Chord "UpArrow" -BriefDescription "Runs Atuin search" -ScriptBlock { + $line = Get-CommandLine + + if (!$line.Contains("`n")) { + Invoke-AtuinSearch -ExtraArgs "--shell-up-key-binding" + } else { + [Microsoft.PowerShell.PSConsoleReadLine]::PreviousLine() + } + } + } + } + + $ExecutionContext.SessionState.Module.OnRemove += { + $env:ATUIN_SESSION = $null + $Function:PSConsoleHostReadLine = $script:previousPSConsoleHostReadLine + } + + Export-ModuleMember -Function @("Enable-AtuinSearchKeys", "PSConsoleHostReadLine") +} | Import-Module -Global diff --git a/crates/turtle/src/shell/atuin.xsh b/crates/turtle/src/shell/atuin.xsh new file mode 100644 index 00000000..a0283402 --- /dev/null +++ b/crates/turtle/src/shell/atuin.xsh @@ -0,0 +1,86 @@ +import os +import subprocess + +from prompt_toolkit.application.current import get_app +from prompt_toolkit.filters import Condition +from prompt_toolkit.keys import Keys + + +if "ATUIN_SESSION" not in ${...} or ${...}.get("ATUIN_SHLVL", "") != ${...}.get("SHLVL", ""): + $ATUIN_SESSION=$(atuin uuid).rstrip('\n') + $ATUIN_SHLVL = ${...}.get("SHLVL", "") + +@events.on_precommand +def _atuin_precommand(cmd: str): + cmd = cmd.rstrip("\n") + try: + $ATUIN_HISTORY_ID = $(atuin history start -- @(cmd) 2>@(os.devnull)).rstrip("\n") + except: + $ATUIN_HISTORY_ID = "" + + +@events.on_postcommand +def _atuin_postcommand(cmd: str, rtn: int, out, ts): + if "ATUIN_HISTORY_ID" not in ${...}: + return + + duration = ts[1] - ts[0] + # Duration is float representing seconds, but atuin expects integer of nanoseconds + nanos = round(duration * 10 ** 9) + with ${...}.swap(ATUIN_LOG="error"): + # This causes the entire .xonshrc to be re-executed, which is incredibly slow + # This happens when using a subshell and using output redirection at the same time + # For more details, see https://github.com/xonsh/xonsh/issues/5224 + # (atuin history end --exit @(rtn) -- $ATUIN_HISTORY_ID &) > /dev/null 2>&1 + atuin history end --exit @(rtn) --duration @(nanos) -- $ATUIN_HISTORY_ID > @(os.devnull) 2>&1 + del $ATUIN_HISTORY_ID + + +def _search(event, extra_args: list[str]): + buffer = event.current_buffer + cmd = ["atuin", "search", "--interactive", *extra_args] + # We need to explicitly pass in xonsh env, in case user has set XDG_HOME or something else that matters + env = ${...}.detype() + env["ATUIN_SHELL"] = "xonsh" + env["ATUIN_QUERY"] = buffer.text + + p = subprocess.run(cmd, stderr=subprocess.PIPE, encoding="utf-8", env=env) + result = p.stderr.rstrip("\n") + # redraw prompt - necessary if atuin is configured to run inline, rather than fullscreen + event.cli.renderer.erase() + + if not result: + return + + buffer.reset() + if result.startswith("__atuin_accept__:"): + buffer.insert_text(result[17:]) + buffer.validate_and_handle() + else: + buffer.insert_text(result) + + +@events.on_ptk_create +def _custom_keybindings(bindings, **kw): + if _ATUIN_BIND_CTRL_R: + @bindings.add(Keys.ControlR) + def r_search(event): + _search(event, extra_args=[]) + + if _ATUIN_BIND_UP_ARROW: + @Condition + def should_search(): + buffer = get_app().current_buffer + # disable keybind when there is an active completion, so + # that up arrow can be used to navigate completion menu + if buffer.complete_state is not None: + return False + # similarly, disable when buffer text contains multiple lines + if '\n' in buffer.text: + return False + + return True + + @bindings.add(Keys.Up, filter=should_search) + def up_search(event): + _search(event, extra_args=["--shell-up-key-binding"]) diff --git a/crates/turtle/src/shell/atuin.zsh b/crates/turtle/src/shell/atuin.zsh new file mode 100644 index 00000000..7a7375aa --- /dev/null +++ b/crates/turtle/src/shell/atuin.zsh @@ -0,0 +1,221 @@ +# shellcheck disable=SC2034,SC2153,SC2086,SC2155 + +# Above line is because shellcheck doesn't support zsh, per +# https://github.com/koalaman/shellcheck/wiki/SC1071, and the ignore: param in +# ludeeus/action-shellcheck only supports _directories_, not _files_. So +# instead, we manually add any error the shellcheck step finds in the file to +# the above line ... + +# Source this in your ~/.zshrc +autoload -U add-zsh-hook + +zmodload zsh/datetime 2>/dev/null + +# If zsh-autosuggestions is installed, configure it to use Atuin's search. If +# you'd like to override this, then add your config after the $(atuin init zsh) +# in your .zshrc +_zsh_autosuggest_strategy_atuin() { + # silence errors, since we don't want to spam the terminal prompt while typing. + suggestion=$(ATUIN_QUERY="$1" atuin search --cmd-only --limit 1 --search-mode prefix 2>/dev/null) +} + +if [ -n "${ZSH_AUTOSUGGEST_STRATEGY:-}" ]; then + ZSH_AUTOSUGGEST_STRATEGY=("atuin" "${ZSH_AUTOSUGGEST_STRATEGY[@]}") +else + ZSH_AUTOSUGGEST_STRATEGY=("atuin") +fi + +if [[ -z "${ATUIN_SESSION:-}" || "${ATUIN_SHLVL:-}" != "$SHLVL" ]]; then + export ATUIN_SESSION=$(atuin uuid) + export ATUIN_SHLVL=$SHLVL +fi +ATUIN_HISTORY_ID="" + +__atuin_osc133_command_executed() { + [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]] || return + [[ -n "${ATUIN_HISTORY_ID:-}" ]] || return + + printf '\033]133;C\a' +} + +__atuin_osc133_command_finished() { + [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]] || return + [[ -n "${ATUIN_HISTORY_ID:-}" ]] || return + + printf '\033]133;D;%s;history_id=%s;session_id=%s\a' "$1" "$ATUIN_HISTORY_ID" "${ATUIN_SESSION:-}" +} + +__atuin_osc133_prompt_start=$'%{\033]133;A;cl=line\a%}' +__atuin_osc133_prompt_end=$'%{\033]133;B\a%}' + +__atuin_osc133_wrap_prompt() { + local __atuin_prompt="${PROMPT-}" + local __atuin_rprompt="${RPROMPT-}" + + __atuin_prompt="${__atuin_prompt//$__atuin_osc133_prompt_start/}" + __atuin_prompt="${__atuin_prompt//$__atuin_osc133_prompt_end/}" + __atuin_rprompt="${__atuin_rprompt//$__atuin_osc133_prompt_start/}" + __atuin_rprompt="${__atuin_rprompt//$__atuin_osc133_prompt_end/}" + + if [[ -n "${ATUIN_PTY_PROXY_ACTIVE:-}" ]]; then + PROMPT="${__atuin_osc133_prompt_start}${__atuin_prompt}" + RPROMPT="${__atuin_rprompt}${__atuin_osc133_prompt_end}" + else + PROMPT="$__atuin_prompt" + RPROMPT="$__atuin_rprompt" + fi +} + +_atuin_preexec() { + local id + id=$(atuin history start -- "$1" 2>/dev/null) + export ATUIN_HISTORY_ID="$id" + __atuin_osc133_command_executed + __atuin_preexec_time=${EPOCHREALTIME-} +} + +_atuin_precmd() { + local EXIT="$?" __atuin_precmd_time=${EPOCHREALTIME-} + + __atuin_osc133_wrap_prompt + + [[ -z "${ATUIN_HISTORY_ID:-}" ]] && return + + local duration="" + if [[ -n $__atuin_preexec_time && -n $__atuin_precmd_time ]]; then + printf -v duration %.0f $(((__atuin_precmd_time - __atuin_preexec_time) * 1000000000)) + fi + + __atuin_osc133_command_finished "$EXIT" + (ATUIN_LOG=error atuin history end --exit $EXIT ${duration:+--duration=$duration} -- $ATUIN_HISTORY_ID &) >/dev/null 2>&1 + export ATUIN_HISTORY_ID="" +} + +# Check if tmux popup is available (tmux >= 3.2) +__atuin_tmux_popup_check() { + [[ -n "${TMUX-}" ]] || return 1 + [[ "${ATUIN_TMUX_POPUP:-true}" != "false" ]] || return 1 + + # https://github.com/tmux/tmux/wiki/FAQ#how-often-is-tmux-released-what-is-the-version-number-scheme + local tmux_version + tmux_version=$(tmux -V 2>/dev/null | sed -n 's/^[^0-9]*\([0-9][0-9]*\.[0-9][0-9]*\).*/\1/p') # Could have used grep... + [[ -z "$tmux_version" ]] && return 1 + + local m1 m2 + m1=${tmux_version%%.*} + m2=${tmux_version#*.} + m2=${m2%%.*} + [[ "$m1" =~ ^[0-9]+$ ]] || return 1 + [[ "$m2" =~ ^[0-9]+$ ]] || m2=0 + (( m1 > 3 || (m1 == 3 && m2 >= 2) )) +} + +# Use global variable to fix scope issues with traps +__atuin_popup_tmpdir="" +__atuin_tmux_popup_cleanup() { + [[ -n "$__atuin_popup_tmpdir" && -d "$__atuin_popup_tmpdir" ]] && command rm -rf "$__atuin_popup_tmpdir" + __atuin_popup_tmpdir="" +} + +__atuin_search_cmd() { + local -a search_args=("$@") + + if __atuin_tmux_popup_check; then + __atuin_popup_tmpdir=$(mktemp -d) || return 1 + local result_file="$__atuin_popup_tmpdir/result" + + trap '__atuin_tmux_popup_cleanup' EXIT HUP INT TERM + + local escaped_query escaped_args + escaped_query=$(printf '%s' "$BUFFER" | sed "s/'/'\\\\''/g") + escaped_args="" + for arg in "${search_args[@]}"; do + escaped_args+=" '$(printf '%s' "$arg" | sed "s/'/'\\\\''/g")'" + done + + # In the popup, atuin goes to terminal, stderr goes to file + local cdir popup_width popup_height + cdir=$(pwd) + popup_width="${ATUIN_TMUX_POPUP_WIDTH:-80%}" # Keep default value anyways + popup_height="${ATUIN_TMUX_POPUP_HEIGHT:-60%}" + tmux display-popup -d "$cdir" -w "$popup_width" -h "$popup_height" -E -E -- \ + sh -c "PATH='$PATH' ATUIN_SESSION='$ATUIN_SESSION' ATUIN_SHELL=zsh ATUIN_LOG=error ATUIN_QUERY='$escaped_query' atuin search $escaped_args -i 2>'$result_file'" + + if [[ -f "$result_file" ]]; then + cat "$result_file" + fi + + __atuin_tmux_popup_cleanup + trap - EXIT HUP INT TERM + else + ATUIN_SHELL=zsh ATUIN_LOG=error ATUIN_QUERY=$BUFFER atuin search "${search_args[@]}" -i 3>&1 1>&2 2>&3 3>&- + fi +} + +_atuin_search() { + emulate -L zsh + zle -I + + # swap stderr and stdout, so that the tui stuff works + # TODO: not this + local output __atuin_status + # shellcheck disable=SC2048 + output=$(__atuin_search_cmd $*) + __atuin_status=$? + + zle reset-prompt + # re-enable bracketed paste + # shellcheck disable=SC2154 + echo -n ${zle_bracketed_paste[1]} >/dev/tty + + if (( __atuin_status != 0 )); then + [[ -n $output ]] && print -r -- "$output" >/dev/tty + return $__atuin_status + fi + + if [[ -n $output ]]; then + RBUFFER="" + LBUFFER=$output + + if [[ $LBUFFER == __atuin_accept__:* ]] + then + LBUFFER=${LBUFFER#__atuin_accept__:} + zle accept-line + fi + fi +} +_atuin_search_vicmd() { + _atuin_search --keymap-mode=vim-normal +} +_atuin_search_viins() { + _atuin_search --keymap-mode=vim-insert +} + +_atuin_up_search() { + # Only trigger if the buffer is a single line + if [[ ! $BUFFER == *$'\n'* ]]; then + _atuin_search --shell-up-key-binding "$@" + else + zle up-line + fi +} +_atuin_up_search_vicmd() { + _atuin_up_search --keymap-mode=vim-normal +} +_atuin_up_search_viins() { + _atuin_up_search --keymap-mode=vim-insert +} + +add-zsh-hook preexec _atuin_preexec +add-zsh-hook precmd _atuin_precmd + +zle -N atuin-search _atuin_search +zle -N atuin-search-vicmd _atuin_search_vicmd +zle -N atuin-search-viins _atuin_search_viins +zle -N atuin-up-search _atuin_up_search +zle -N atuin-up-search-vicmd _atuin_up_search_vicmd +zle -N atuin-up-search-viins _atuin_up_search_viins + +# These are compatibility widget names for "atuin <= 17.2.1" users. +zle -N _atuin_search_widget _atuin_search +zle -N _atuin_up_search_widget _atuin_up_search diff --git a/crates/turtle/src/sync.rs b/crates/turtle/src/sync.rs new file mode 100644 index 00000000..56aef615 --- /dev/null +++ b/crates/turtle/src/sync.rs @@ -0,0 +1,34 @@ +use eyre::{Context, Result}; + +use crate::atuin_client::{ + database::Database, history::store::HistoryStore, record::sqlite_store::SqliteStore, + settings::Settings, +}; +use crate::atuin_common::record::RecordId; + +// This is the only crate that ties together all other crates. +// Therefore, it's the only crate where functions tying together all stores can live + +/// Rebuild all stores after a sync +/// Note: for history, this only does an _incremental_ sync. Hence the need to specify downloaded +/// records. +pub async fn build( + settings: &Settings, + store: &SqliteStore, + db: &dyn Database, + downloaded: Option<&[RecordId]>, +) -> Result<()> { + let encryption_key: [u8; 32] = crate::atuin_client::encryption::load_key(settings) + .context("could not load encryption key")? + .into(); + + let host_id = Settings::host_id().await?; + + let downloaded = downloaded.unwrap_or(&[]); + + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + + history_store.incremental_build(db, downloaded).await?; + + Ok(()) +} -- cgit v1.3.1