diff options
Diffstat (limited to 'crates')
155 files changed, 19783 insertions, 0 deletions
diff --git a/crates/atuin-client/Cargo.toml b/crates/atuin-client/Cargo.toml new file mode 100644 index 00000000..c8ca74ae --- /dev/null +++ b/crates/atuin-client/Cargo.toml @@ -0,0 +1,73 @@ +[package] +name = "atuin-client" +edition = "2021" +description = "client library for atuin" + +rust-version = { workspace = true } +version = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[features] +default = ["sync"] +sync = ["urlencoding", "reqwest", "sha2", "hex"] +check-update = [] + +[dependencies] +atuin-common = { path = "../atuin-common", version = "18.2.0" } + +log = { workspace = true } +base64 = { workspace = true } +time = { workspace = true, features = ["macros", "formatting"] } +clap = { workspace = true } +eyre = { workspace = true } +directories = { workspace = true } +uuid = { workspace = true } +whoami = { workspace = true } +interim = { workspace = true } +config = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +parse_duration = "2.1.1" +async-trait = { workspace = true } +itertools = { workspace = true } +rand = { workspace = true } +shellexpand = "3" +sqlx = { workspace = true, features = ["sqlite", "regexp"] } +minspan = "0.1.1" +regex = "1.10.4" +serde_regex = "1.1.0" +fs-err = { workspace = true } +sql-builder = "3" +memchr = "2.5" +rmp = { version = "0.8.11" } +typed-builder = { workspace = true } +tokio = { workspace = true } +semver = { workspace = true } +thiserror = { workspace = true } +futures = "0.3" +crypto_secretbox = "0.1.1" +generic-array = { version = "0.14", features = ["serde"] } +serde_with = "3.5.1" + +# encryption +rusty_paseto = { version = "0.6.0", default-features = false } +rusty_paserk = { version = "0.3.0", default-features = false, features = [ + "v4", + "serde", +] } + +# sync 
+urlencoding = { version = "2.1.0", optional = true } +reqwest = { workspace = true, optional = true } +hex = { version = "0.4", optional = true } +sha2 = { version = "0.10", optional = true } +indicatif = "0.17.7" + +[dev-dependencies] +tokio = { version = "1", features = ["full"] } +pretty_assertions = { workspace = true } diff --git a/crates/atuin-client/config.toml b/crates/atuin-client/config.toml new file mode 100644 index 00000000..415fd441 --- /dev/null +++ b/crates/atuin-client/config.toml @@ -0,0 +1,210 @@ +## where to store your database, default is your system data directory +## linux/mac: ~/.local/share/atuin/history.db +## windows: %USERPROFILE%/.local/share/atuin/history.db +# db_path = "~/.history.db" + +## where to store your encryption key, default is your system data directory +## linux/mac: ~/.local/share/atuin/key +## windows: %USERPROFILE%/.local/share/atuin/key +# key_path = "~/.key" + +## where to store your auth session token, default is your system data directory +## linux/mac: ~/.local/share/atuin/session +## windows: %USERPROFILE%/.local/share/atuin/session +# session_path = "~/.session" + +## date format used, either "us" or "uk" +# dialect = "us" + +## default timezone to use when displaying time +## either "l", "local" to use the system's current local timezone, or an offset +## from UTC in the format of "<+|->H[H][:M[M][:S[S]]]" +## for example: "+9", "-05", "+03:30", "-01:23:45", etc. +# timezone = "local" + +## enable or disable automatic sync +# auto_sync = true + +## enable or disable automatic update checks +# update_check = true + +## address of the sync server +# sync_address = "https://api.atuin.sh" + +## how often to sync history. 
note that this is only triggered when a command +## is run, so sync intervals may well be longer +## set it to 0 to sync after every command +# sync_frequency = "10m" + +## which search mode to use +## possible values: prefix, fulltext, fuzzy, skim +# search_mode = "fuzzy" + +## which filter mode to use +## possible values: global, host, session, directory +# filter_mode = "global" + +## With workspace filtering enabled, Atuin will filter for commands executed +## in any directory within a git repository tree (default: false) +# workspaces = false + +## which filter mode to use when atuin is invoked from a shell up-key binding +## the accepted values are identical to those of "filter_mode" +## leave unspecified to use the same mode set in "filter_mode" +# filter_mode_shell_up_key_binding = "global" + +## which search mode to use when atuin is invoked from a shell up-key binding +## the accepted values are identical to those of "search_mode" +## leave unspecified to use the same mode set in "search_mode" +# search_mode_shell_up_key_binding = "fuzzy" + +## which style to use +## possible values: auto, full, compact +# style = "auto" + +## the maximum number of lines the interface should take up +## set it to 0 to always go full screen +# inline_height = 0 + +## Invert the UI - put the search bar at the top (default: false) +# invert = false + +## enable or disable showing a preview of the selected command +## useful when the command is longer than the terminal width and is cut off +# show_preview = false + +## enable or disable automatic preview. It shows a preview if the command is +## longer than the width of the terminal. It respects max_preview_height. 
+## (default: true) +# show_preview_auto = true + +## what to do when the escape key is pressed when searching +## possible values: return-original, return-query +# exit_mode = "return-original" + +## possible values: emacs, subl +# word_jump_mode = "emacs" + +## characters that count as a part of a word +# word_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + +## number of context lines to show when scrolling by pages +# scroll_context_lines = 1 + +## use ctrl instead of alt as the shortcut modifier key for numerical UI shortcuts +## alt-0 .. alt-9 +# ctrl_n_shortcuts = false + +## default history list format - can also be specified with the --format arg +# history_format = "{time}\t{command}\t{duration}" + +## prevent commands matching any of these regexes from being written to history. +## Note that these regular expressions are unanchored, i.e. if they don't start +## with ^ or end with $, they'll match anywhere in the command. +## For details on the supported regular expression syntax, see +## https://docs.rs/regex/latest/regex/#syntax +# history_filter = [ +# "^secret-cmd", +# "^innocuous-cmd .*--secret=.+", +# ] + +## prevent commands run with cwd matching any of these regexes from being written +## to history. Note that these regular expressions are unanchored, i.e. if they don't +## start with ^ or end with $, they'll match anywhere in CWD. +## For details on the supported regular expression syntax, see +## https://docs.rs/regex/latest/regex/#syntax +# cwd_filter = [ +# "^/very/secret/area", +# ] + +## Configure the maximum height of the preview to show. +## Useful when you have long scripts in your history that you want to distinguish +## by more than the first few lines. +# max_preview_height = 4 + +## Configure whether or not to show the help row, which includes the current Atuin +## version (and whether an update is available), a keymap hint, and the total +## amount of commands in your history. 
+# show_help = true + +## Configure whether or not to show tabs for search and inspect +# show_tabs = true + +## Defaults to true. This matches history against a set of default regexes, and will not save a command if one of them matches. Defaults include: +## 1. AWS key id +## 2. GitHub PAT (old and new) +## 3. Slack OAuth tokens (bot, user) +## 4. Slack webhooks +## 5. Stripe live/test keys +# secrets_filter = true + +## Defaults to true. If enabled, upon hitting enter Atuin will immediately execute the command. Press tab to return to the shell and edit. +## This applies to new installs. Old installs will keep the old behaviour unless configured otherwise. +enter_accept = true + +## Defaults to "emacs". This specifies the keymap on the startup of `atuin +## search`. If this is set to "auto", the startup keymap mode in the Atuin +## search is automatically selected based on the shell's keymap where the +## keybinding is defined. If this is set to "emacs", "vim-insert", or +## "vim-normal", the startup keymap mode in the Atuin search is forced to be +## the specified one. +# keymap_mode = "auto" + +## Cursor style in each keymap mode. If specified, the cursor style is changed +## on entering the corresponding keymap mode. Available values are "default" and +## "{blink,steady}-{block,underline,bar}". +# keymap_cursor = { emacs = "blink-block", vim_insert = "blink-block", vim_normal = "steady-block" } + +# network_connect_timeout = 5 +# network_timeout = 5 + +## Timeout (in seconds) for acquiring a local database connection (sqlite) +# local_timeout = 5 + +## Set this to true and Atuin will minimize motion in the UI - timers will not update live, etc. +## Alternatively, set env NO_MOTION=true +# prefers_reduced_motion = false + +[stats] +## Set commands where we should consider the subcommand for statistics. 
Eg, kubectl get vs just kubectl +# common_subcommands = [ +# "apt", +# "cargo", +# "composer", +# "dnf", +# "docker", +# "git", +# "go", +# "ip", +# "kubectl", +# "nix", +# "nmcli", +# "npm", +# "pecl", +# "pnpm", +# "podman", +# "port", +# "systemctl", +# "tmux", +# "yarn", +# ] + +## Set commands that should be totally stripped and ignored from stats +# common_prefix = ["sudo"] + +## Set commands that will be completely ignored from stats +# ignored_commands = [ +# "cd", +# "ls", +# "vi" +# ] + +[keys] +# Defaults to true. If disabled, using the up/down key won't exit the TUI when scrolled past the first/last entry. +# scroll_exits = false + +[sync] +# Enable sync v2 by default +# This ensures that sync v2 is enabled for new installs only +# In a later release it will become the default across the board +records = true diff --git a/crates/atuin-client/migrations/20210422143411_create_history.sql b/crates/atuin-client/migrations/20210422143411_create_history.sql new file mode 100644 index 00000000..1f3f8686 --- /dev/null +++ b/crates/atuin-client/migrations/20210422143411_create_history.sql @@ -0,0 +1,16 @@ +-- Add migration script here +create table if not exists history ( + id text primary key, + timestamp integer not null, + duration integer not null, + exit integer not null, + command text not null, + cwd text not null, + session text not null, + hostname text not null, + + unique(timestamp, cwd, command) +); + +create index if not exists idx_history_timestamp on history(timestamp); +create index if not exists idx_history_command on history(command); diff --git a/crates/atuin-client/migrations/20220505083406_create-events.sql b/crates/atuin-client/migrations/20220505083406_create-events.sql new file mode 100644 index 00000000..f6cafeba --- /dev/null +++ b/crates/atuin-client/migrations/20220505083406_create-events.sql @@ -0,0 +1,11 @@ +create table if not exists events ( + id text primary key, + timestamp integer not null, + hostname text not null, + 
event_type text not null, + + history_id text not null +); + +-- Ensure there is only ever one of each event type per history item +create unique index history_event_idx ON events(event_type, history_id); diff --git a/crates/atuin-client/migrations/20220806155627_interactive_search_index.sql b/crates/atuin-client/migrations/20220806155627_interactive_search_index.sql new file mode 100644 index 00000000..b5770e62 --- /dev/null +++ b/crates/atuin-client/migrations/20220806155627_interactive_search_index.sql @@ -0,0 +1,6 @@ +-- Interactive search filters by command then by the max(timestamp) for that +-- command. Create an index that covers those +create index if not exists idx_history_command_timestamp on history( + command, + timestamp +); diff --git a/crates/atuin-client/migrations/20230315220114_drop-events.sql b/crates/atuin-client/migrations/20230315220114_drop-events.sql new file mode 100644 index 00000000..fe3cae17 --- /dev/null +++ b/crates/atuin-client/migrations/20230315220114_drop-events.sql @@ -0,0 +1,2 @@ +-- Add migration script here +drop table events; diff --git a/crates/atuin-client/migrations/20230319185725_deleted_at.sql b/crates/atuin-client/migrations/20230319185725_deleted_at.sql new file mode 100644 index 00000000..6c422abc --- /dev/null +++ b/crates/atuin-client/migrations/20230319185725_deleted_at.sql @@ -0,0 +1,2 @@ +-- Add migration script here +alter table history add column deleted_at integer; diff --git a/crates/atuin-client/record-migrations/20230531212437_create-records.sql b/crates/atuin-client/record-migrations/20230531212437_create-records.sql new file mode 100644 index 00000000..4f4b304a --- /dev/null +++ b/crates/atuin-client/record-migrations/20230531212437_create-records.sql @@ -0,0 +1,16 @@ +-- Add migration script here +create table if not exists records ( + id text primary key, + parent text unique, -- null if this is the first one + host text not null, + + timestamp integer not null, + tag text not null, + version text not 
null, + data blob not null, + cek blob not null +); + +create index host_idx on records (host); +create index tag_idx on records (tag); +create index host_tag_idx on records (host, tag); diff --git a/crates/atuin-client/record-migrations/20231127090831_create-store.sql b/crates/atuin-client/record-migrations/20231127090831_create-store.sql new file mode 100644 index 00000000..53d78860 --- /dev/null +++ b/crates/atuin-client/record-migrations/20231127090831_create-store.sql @@ -0,0 +1,15 @@ +-- Add migration script here +create table if not exists store ( + id text primary key, -- globally unique ID + + idx integer, -- incrementing integer ID unique per (host, tag) + host text not null, -- references the host row + tag text not null, + + timestamp integer not null, + version text not null, + data blob not null, + cek blob not null +); + +create unique index record_uniq ON store(host, tag, idx); diff --git a/crates/atuin-client/src/api_client.rs b/crates/atuin-client/src/api_client.rs new file mode 100644 index 00000000..f31a796e --- /dev/null +++ b/crates/atuin-client/src/api_client.rs @@ -0,0 +1,415 @@ +use std::collections::HashMap; +use std::env; +use std::time::Duration; + +use eyre::{bail, Result}; +use reqwest::{ + header::{HeaderMap, AUTHORIZATION, USER_AGENT}, + Response, StatusCode, Url, +}; + +use atuin_common::{ + api::{ + AddHistoryRequest, ChangePasswordRequest, CountResponse, DeleteHistoryRequest, + ErrorResponse, LoginRequest, LoginResponse, MeResponse, RegisterResponse, StatusResponse, + SyncHistoryResponse, + }, + record::RecordStatus, +}; +use atuin_common::{ + api::{ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION, ATUIN_VERSION}, + record::{EncryptedData, HostId, Record, RecordIdx}, +}; + +use semver::Version; +use time::format_description::well_known::Rfc3339; +use time::OffsetDateTime; + +use crate::{history::History, sync::hash_str, utils::get_host_user}; + +static APP_USER_AGENT: &str = concat!("atuin/", env!("CARGO_PKG_VERSION"),); + +pub struct 
Client<'a> { + sync_addr: &'a str, + client: reqwest::Client, +} + +pub async fn register( + address: &str, + username: &str, + email: &str, + password: &str, +) -> Result<RegisterResponse> { + let mut map = HashMap::new(); + map.insert("username", username); + map.insert("email", email); + map.insert("password", password); + + let url = format!("{address}/user/{username}"); + let resp = reqwest::get(url).await?; + + if resp.status().is_success() { + bail!("username already in use"); + } + + let url = format!("{address}/register"); + let client = reqwest::Client::new(); + let resp = client + .post(url) + .header(USER_AGENT, APP_USER_AGENT) + .header(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION) + .json(&map) + .send() + .await?; + + if !ensure_version(&resp)? { + bail!("could not register user due to version mismatch"); + } + + if !resp.status().is_success() { + let error = resp.json::<ErrorResponse>().await?; + bail!("failed to register user: {}", error.reason); + } + + let session = resp.json::<RegisterResponse>().await?; + Ok(session) +} + +pub async fn login(address: &str, req: LoginRequest) -> Result<LoginResponse> { + let url = format!("{address}/login"); + let client = reqwest::Client::new(); + + let resp = client + .post(url) + .header(USER_AGENT, APP_USER_AGENT) + .json(&req) + .send() + .await?; + + if !ensure_version(&resp)? 
{ + bail!("could not login due to version mismatch"); + } + + if resp.status() != reqwest::StatusCode::OK { + let error = resp.json::<ErrorResponse>().await?; + bail!("invalid login details: {}", error.reason); + } + + let session = resp.json::<LoginResponse>().await?; + Ok(session) +} + +#[cfg(feature = "check-update")] +pub async fn latest_version() -> Result<Version> { + use atuin_common::api::IndexResponse; + + let url = "https://api.atuin.sh"; + let client = reqwest::Client::new(); + + let resp = client + .get(url) + .header(USER_AGENT, APP_USER_AGENT) + .send() + .await?; + + if resp.status() != reqwest::StatusCode::OK { + let error = resp.json::<ErrorResponse>().await?; + bail!("failed to check latest version: {}", error.reason); + } + + let index = resp.json::<IndexResponse>().await?; + let version = Version::parse(index.version.as_str())?; + + Ok(version) +} + +pub fn ensure_version(response: &Response) -> Result<bool> { + let version = response.headers().get(ATUIN_HEADER_VERSION); + + let version = if let Some(version) = version { + match version.to_str() { + Ok(v) => Version::parse(v), + Err(e) => bail!("failed to parse server version: {:?}", e), + } + } else { + // if there is no version header, then the newest this server can possibly be is 17.1.0 + Version::parse("17.1.0") + }?; + + // If the client is newer than the server + if version.major < ATUIN_VERSION.major { + println!("Atuin version mismatch! 
In order to successfully sync, the server needs to run a newer version of Atuin"); + println!("Client: {}", ATUIN_CARGO_VERSION); + println!("Server: {}", version); + + return Ok(false); + } + + Ok(true) +} + +async fn handle_resp_error(resp: Response) -> Result<Response> { + let status = resp.status(); + + if status == StatusCode::SERVICE_UNAVAILABLE { + bail!( + "Service unavailable: check https://status.atuin.sh (or get in touch with your host)" + ); + } + + if !status.is_success() { + if let Ok(error) = resp.json::<ErrorResponse>().await { + let reason = error.reason; + + if status.is_client_error() { + bail!("Could not fetch history, client error {status}: {reason}.") + } + + bail!("There was an error with the atuin sync service, server error {status}: {reason}.\nIf the problem persists, contact the host") + } + + bail!("There was an error with the atuin sync service: Status {status:?}.\nIf the problem persists, contact the host") + } + + Ok(resp) +} + +impl<'a> Client<'a> { + pub fn new( + sync_addr: &'a str, + session_token: &str, + connect_timeout: u64, + timeout: u64, + ) -> Result<Self> { + let mut headers = HeaderMap::new(); + headers.insert(AUTHORIZATION, format!("Token {session_token}").parse()?); + + // used for semver server check + headers.insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse()?); + + Ok(Client { + sync_addr, + client: reqwest::Client::builder() + .user_agent(APP_USER_AGENT) + .default_headers(headers) + .connect_timeout(Duration::new(connect_timeout, 0)) + .timeout(Duration::new(timeout, 0)) + .build()?, + }) + } + + pub async fn count(&self) -> Result<i64> { + let url = format!("{}/sync/count", self.sync_addr); + let url = Url::parse(url.as_str())?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + if !ensure_version(&resp)? 
{ + bail!("could not sync due to version mismatch"); + } + + if resp.status() != StatusCode::OK { + bail!("failed to get count (are you logged in?)"); + } + + let count = resp.json::<CountResponse>().await?; + + Ok(count.count) + } + + pub async fn status(&self) -> Result<StatusResponse> { + let url = format!("{}/sync/status", self.sync_addr); + let url = Url::parse(url.as_str())?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + if !ensure_version(&resp)? { + bail!("could not sync due to version mismatch"); + } + + let status = resp.json::<StatusResponse>().await?; + + Ok(status) + } + + pub async fn me(&self) -> Result<MeResponse> { + let url = format!("{}/api/v0/me", self.sync_addr); + let url = Url::parse(url.as_str())?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + let status = resp.json::<MeResponse>().await?; + + Ok(status) + } + + pub async fn get_history( + &self, + sync_ts: OffsetDateTime, + history_ts: OffsetDateTime, + host: Option<String>, + ) -> Result<SyncHistoryResponse> { + let host = host.unwrap_or_else(|| hash_str(&get_host_user())); + + let url = format!( + "{}/sync/history?sync_ts={}&history_ts={}&host={}", + self.sync_addr, + urlencoding::encode(sync_ts.format(&Rfc3339)?.as_str()), + urlencoding::encode(history_ts.format(&Rfc3339)?.as_str()), + host, + ); + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + let history = resp.json::<SyncHistoryResponse>().await?; + Ok(history) + } + + pub async fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> { + let url = format!("{}/history", self.sync_addr); + let url = Url::parse(url.as_str())?; + + let resp = self.client.post(url).json(history).send().await?; + handle_resp_error(resp).await?; + + Ok(()) + } + + pub async fn delete_history(&self, h: History) -> Result<()> { + let url = format!("{}/history", self.sync_addr); + let url = 
Url::parse(url.as_str())?; + + let resp = self + .client + .delete(url) + .json(&DeleteHistoryRequest { + client_id: h.id.to_string(), + }) + .send() + .await?; + + handle_resp_error(resp).await?; + + Ok(()) + } + + pub async fn delete_store(&self) -> Result<()> { + let url = format!("{}/api/v0/store", self.sync_addr); + let url = Url::parse(url.as_str())?; + + let resp = self.client.delete(url).send().await?; + + handle_resp_error(resp).await?; + + Ok(()) + } + + pub async fn post_records(&self, records: &[Record<EncryptedData>]) -> Result<()> { + let url = format!("{}/api/v0/record", self.sync_addr); + let url = Url::parse(url.as_str())?; + + debug!("uploading {} records to {url}", records.len()); + + let resp = self.client.post(url).json(records).send().await?; + handle_resp_error(resp).await?; + + Ok(()) + } + + pub async fn next_records( + &self, + host: HostId, + tag: String, + start: RecordIdx, + count: u64, + ) -> Result<Vec<Record<EncryptedData>>> { + debug!( + "fetching record/s from host {}/{}/{}", + host.0.to_string(), + tag, + start + ); + + let url = format!( + "{}/api/v0/record/next?host={}&tag={}&count={}&start={}", + self.sync_addr, host.0, tag, count, start + ); + + let url = Url::parse(url.as_str())?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + let records = resp.json::<Vec<Record<EncryptedData>>>().await?; + + Ok(records) + } + + pub async fn record_status(&self) -> Result<RecordStatus> { + let url = format!("{}/api/v0/record", self.sync_addr); + let url = Url::parse(url.as_str())?; + + let resp = self.client.get(url).send().await?; + let resp = handle_resp_error(resp).await?; + + if !ensure_version(&resp)? 
{ + bail!("could not sync records due to version mismatch"); + } + + let index = resp.json().await?; + + debug!("got remote index {:?}", index); + + Ok(index) + } + + pub async fn delete(&self) -> Result<()> { + let url = format!("{}/account", self.sync_addr); + let url = Url::parse(url.as_str())?; + + let resp = self.client.delete(url).send().await?; + + if resp.status() == 403 { + bail!("invalid login details"); + } else if resp.status() == 200 { + Ok(()) + } else { + bail!("Unknown error"); + } + } + + /// Change the account password by sending a `PATCH` to `{sync_addr}/account/password`. + /// + /// # Errors + /// + /// Fails if the URL cannot be parsed, if the request cannot be sent, or if the server + /// answers with anything other than 200: 401 is reported as a wrong current password, + /// 403 as invalid login details, and any other status as an unknown error. + pub async fn change_password( + &self, + current_password: String, + new_password: String, + ) -> Result<()> { + let url = format!("{}/account/password", self.sync_addr); + let url = Url::parse(url.as_str())?; + + let resp = self + .client + .patch(url) + .json(&ChangePasswordRequest { + current_password, + new_password, + }) + .send() + .await?; + + // Record just the status through the logger rather than printing the whole response to stderr. + debug!("change password response status: {}", resp.status()); + + if resp.status() == 401 { + bail!("current password is incorrect") + } else if resp.status() == 403 { + bail!("invalid login details"); + } else if resp.status() == 200 { + Ok(()) + } else { + bail!("Unknown error"); + } + } +} diff --git a/crates/atuin-client/src/database.rs b/crates/atuin-client/src/database.rs new file mode 100644 index 00000000..7faa3802 --- /dev/null +++ b/crates/atuin-client/src/database.rs @@ -0,0 +1,1128 @@ +use std::{ + borrow::Cow, + env, + path::{Path, PathBuf}, + str::FromStr, + time::Duration, +}; + +use async_trait::async_trait; +use atuin_common::utils; +use fs_err as fs; +use itertools::Itertools; +use rand::{distributions::Alphanumeric, Rng}; +use sql_builder::{bind::Bind, esc, quote, SqlBuilder, SqlName}; +use sqlx::{ + sqlite::{ + SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow, + SqliteSynchronous, + }, + Result, Row, +}; +use time::OffsetDateTime; + +use crate::{ + history::{HistoryId, HistoryStats}, + utils::get_host_user, +}; + +use super::{ + history::History, + ordering, + settings::{FilterMode, SearchMode, Settings}, +}; + 
+pub struct Context { + pub session: String, + pub cwd: String, + pub hostname: String, + pub host_id: String, + pub git_root: Option<PathBuf>, +} + +#[derive(Default, Clone)] +pub struct OptFilters { + pub exit: Option<i64>, + pub exclude_exit: Option<i64>, + pub cwd: Option<String>, + pub exclude_cwd: Option<String>, + pub before: Option<String>, + pub after: Option<String>, + pub limit: Option<i64>, + pub offset: Option<i64>, + pub reverse: bool, +} + +pub fn current_context() -> Context { + let Ok(session) = env::var("ATUIN_SESSION") else { + eprintln!("ERROR: Failed to find $ATUIN_SESSION in the environment. Check that you have correctly set up your shell."); + std::process::exit(1); + }; + let hostname = get_host_user(); + let cwd = utils::get_current_dir(); + let host_id = Settings::host_id().expect("failed to load host ID"); + let git_root = utils::in_git_repo(cwd.as_str()); + + Context { + session, + hostname, + cwd, + git_root, + host_id: host_id.0.as_simple().to_string(), + } +} + +#[async_trait] +pub trait Database: Send + Sync + 'static { + async fn save(&self, h: &History) -> Result<()>; + async fn save_bulk(&self, h: &[History]) -> Result<()>; + + async fn load(&self, id: &str) -> Result<Option<History>>; + async fn list( + &self, + filters: &[FilterMode], + context: &Context, + max: Option<usize>, + unique: bool, + include_deleted: bool, + ) -> Result<Vec<History>>; + async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>>; + + async fn update(&self, h: &History) -> Result<()>; + async fn history_count(&self, include_deleted: bool) -> Result<i64>; + + async fn last(&self) -> Result<Option<History>>; + async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>>; + + async fn delete(&self, h: History) -> Result<()>; + async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()>; + async fn deleted(&self) -> Result<Vec<History>>; + + // Yes I know, it's a lot. 
+ // Could maybe break it down to a searchparams struct or smth but that feels a little... pointless. + // Been debating maybe a DSL for search? eg "before:time limit:1 the query" + #[allow(clippy::too_many_arguments)] + async fn search( + &self, + search_mode: SearchMode, + filter: FilterMode, + context: &Context, + query: &str, + filter_options: OptFilters, + ) -> Result<Vec<History>>; + + async fn query_history(&self, query: &str) -> Result<Vec<History>>; + + async fn all_with_count(&self) -> Result<Vec<(History, i32)>>; + + async fn stats(&self, h: &History) -> Result<HistoryStats>; +} + +// Intended for use on a developer machine and not a sync server. +// TODO: implement IntoIterator +pub struct Sqlite { + pub pool: SqlitePool, +} + +impl Sqlite { + /// Open the SQLite history database at `path` and run the bundled migrations. + /// + /// If `path` does not exist yet, its parent directory is created first. `timeout` is the + /// connection-pool acquire timeout in seconds. + /// + /// # Errors + /// + /// Fails if the parent directory cannot be created, if `path` is not valid UTF-8, or if + /// connecting to the database or running the migrations fails. + pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> { + let path = path.as_ref(); + debug!("opening sqlite database at {:?}", path); + + let create = !path.exists(); + if create { + if let Some(dir) = path.parent() { + fs::create_dir_all(dir)?; + } + } + + // `from_str` needs a `&str`; report a non-UTF-8 path as an error instead of panicking. + let path_str = path.to_str().ok_or_else(|| { + sqlx::Error::Configuration("database path is not valid UTF-8".into()) + })?; + + let opts = SqliteConnectOptions::from_str(path_str)? 
+ .journal_mode(SqliteJournalMode::Wal) + .optimize_on_close(true, None) + .synchronous(SqliteSynchronous::Normal) + .with_regexp() + .create_if_missing(true); + + let pool = SqlitePoolOptions::new() + .acquire_timeout(Duration::from_secs_f64(timeout)) + .connect_with(opts) + .await?; + + Self::setup_db(&pool).await?; + + Ok(Self { pool }) + } + + async fn setup_db(pool: &SqlitePool) -> Result<()> { + debug!("running sqlite database setup"); + + sqlx::migrate!("./migrations").run(pool).await?; + + Ok(()) + } + + async fn save_raw(tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, h: &History) -> Result<()> { + sqlx::query( + "insert or ignore into history(id, timestamp, duration, exit, command, cwd, session, hostname, deleted_at) + values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)", + ) + .bind(h.id.0.as_str()) + .bind(h.timestamp.unix_timestamp_nanos() as i64) + .bind(h.duration) + .bind(h.exit) + .bind(h.command.as_str()) + .bind(h.cwd.as_str()) + .bind(h.session.as_str()) + .bind(h.hostname.as_str()) + .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64)) + .execute(&mut **tx) + .await?; + + Ok(()) + } + + async fn delete_row_raw( + tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, + id: HistoryId, + ) -> Result<()> { + sqlx::query("delete from history where id = ?1") + .bind(id.0.as_str()) + .execute(&mut **tx) + .await?; + + Ok(()) + } + + fn query_history(row: SqliteRow) -> History { + let deleted_at: Option<i64> = row.get("deleted_at"); + + History::from_db() + .id(row.get("id")) + .timestamp( + OffsetDateTime::from_unix_timestamp_nanos(row.get::<i64, _>("timestamp") as i128) + .unwrap(), + ) + .duration(row.get("duration")) + .exit(row.get("exit")) + .command(row.get("command")) + .cwd(row.get("cwd")) + .session(row.get("session")) + .hostname(row.get("hostname")) + .deleted_at( + deleted_at.and_then(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128).ok()), + ) + .build() + .into() + } +} + +#[async_trait] +impl Database for Sqlite { + async fn save(&self, h: 
&History) -> Result<()> { + debug!("saving history to sqlite"); + let mut tx = self.pool.begin().await?; + Self::save_raw(&mut tx, h).await?; + tx.commit().await?; + + Ok(()) + } + + async fn save_bulk(&self, h: &[History]) -> Result<()> { + debug!("saving history to sqlite"); + + let mut tx = self.pool.begin().await?; + + for i in h { + Self::save_raw(&mut tx, i).await?; + } + + tx.commit().await?; + + Ok(()) + } + + async fn load(&self, id: &str) -> Result<Option<History>> { + debug!("loading history item {}", id); + + let res = sqlx::query("select * from history where id = ?1") + .bind(id) + .map(Self::query_history) + .fetch_optional(&self.pool) + .await?; + + Ok(res) + } + + async fn update(&self, h: &History) -> Result<()> { + debug!("updating sqlite history"); + + sqlx::query( + "update history + set timestamp = ?2, duration = ?3, exit = ?4, command = ?5, cwd = ?6, session = ?7, hostname = ?8, deleted_at = ?9 + where id = ?1", + ) + .bind(h.id.0.as_str()) + .bind(h.timestamp.unix_timestamp_nanos() as i64) + .bind(h.duration) + .bind(h.exit) + .bind(h.command.as_str()) + .bind(h.cwd.as_str()) + .bind(h.session.as_str()) + .bind(h.hostname.as_str()) + .bind(h.deleted_at.map(|t|t.unix_timestamp_nanos() as i64)) + .execute(&self.pool) + .await?; + + Ok(()) + } + + // make a unique list, that only shows the *newest* version of things + async fn list( + &self, + filters: &[FilterMode], + context: &Context, + max: Option<usize>, + unique: bool, + include_deleted: bool, + ) -> Result<Vec<History>> { + debug!("listing history"); + + let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted()); + query.field("*").order_desc("timestamp"); + if !include_deleted { + query.and_where_is_null("deleted_at"); + } + + let git_root = if let Some(git_root) = context.git_root.clone() { + git_root.to_str().unwrap_or("/").to_string() + } else { + context.cwd.clone() + }; + + for filter in filters { + match filter { + FilterMode::Global => &mut query, + 
FilterMode::Host => query.and_where_eq("hostname", quote(&context.hostname)), + FilterMode::Session => query.and_where_eq("session", quote(&context.session)), + FilterMode::Directory => query.and_where_eq("cwd", quote(&context.cwd)), + FilterMode::Workspace => query.and_where_like_left("cwd", &git_root), + }; + } + + if unique { + query.group_by("command").having("max(timestamp)"); + } + + if let Some(max) = max { + query.limit(max); + } + + let query = query.sql().expect("bug in list query. please report"); + + let res = sqlx::query(&query) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn range(&self, from: OffsetDateTime, to: OffsetDateTime) -> Result<Vec<History>> { + debug!("listing history from {:?} to {:?}", from, to); + + let res = sqlx::query( + "select * from history where timestamp >= ?1 and timestamp <= ?2 order by timestamp asc", + ) + .bind(from.unix_timestamp_nanos() as i64) + .bind(to.unix_timestamp_nanos() as i64) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn last(&self) -> Result<Option<History>> { + let res = sqlx::query( + "select * from history where duration >= 0 order by timestamp desc limit 1", + ) + .map(Self::query_history) + .fetch_optional(&self.pool) + .await?; + + Ok(res) + } + + async fn before(&self, timestamp: OffsetDateTime, count: i64) -> Result<Vec<History>> { + let res = sqlx::query( + "select * from history where timestamp < ?1 order by timestamp desc limit ?2", + ) + .bind(timestamp.unix_timestamp_nanos() as i64) + .bind(count) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn deleted(&self) -> Result<Vec<History>> { + let res = sqlx::query("select * from history where deleted_at is not null") + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn history_count(&self, include_deleted: bool) -> Result<i64> { + let query = if include_deleted { + "select 
count(1) from history" + } else { + "select count(1) from history where deleted_at is null" + }; + + let res: (i64,) = sqlx::query_as(query).fetch_one(&self.pool).await?; + Ok(res.0) + } + + async fn search( + &self, + search_mode: SearchMode, + filter: FilterMode, + context: &Context, + query: &str, + filter_options: OptFilters, + ) -> Result<Vec<History>> { + let mut sql = SqlBuilder::select_from("history"); + + sql.group_by("command").having("max(timestamp)"); + + if let Some(limit) = filter_options.limit { + sql.limit(limit); + } + + if let Some(offset) = filter_options.offset { + sql.offset(offset); + } + + if filter_options.reverse { + sql.order_asc("timestamp"); + } else { + sql.order_desc("timestamp"); + } + + let git_root = if let Some(git_root) = context.git_root.clone() { + git_root.to_str().unwrap_or("/").to_string() + } else { + context.cwd.clone() + }; + + match filter { + FilterMode::Global => &mut sql, + FilterMode::Host => { + sql.and_where_eq("lower(hostname)", quote(context.hostname.to_lowercase())) + } + FilterMode::Session => sql.and_where_eq("session", quote(&context.session)), + FilterMode::Directory => sql.and_where_eq("cwd", quote(&context.cwd)), + FilterMode::Workspace => sql.and_where_like_left("cwd", git_root), + }; + + let orig_query = query; + + let mut regexes = Vec::new(); + match search_mode { + SearchMode::Prefix => sql.and_where_like_left("command", query.replace('*', "%")), + _ => { + let mut is_or = false; + let mut regex = None; + for part in query.split_inclusive(' ') { + let query_part: Cow<str> = match (&mut regex, part.starts_with("r/")) { + (None, false) => { + if part.trim_end().is_empty() { + continue; + } + Cow::Owned(part.trim_end().replace('*', "%")) // allow wildcard char + } + (None, true) => { + if part[2..].trim_end().ends_with('/') { + let end_pos = part.trim_end().len() - 1; + regexes.push(String::from(&part[2..end_pos])); + } else { + regex = Some(String::from(&part[2..])); + } + continue; + } + (Some(r), _) => 
{ + if part.trim_end().ends_with('/') { + let end_pos = part.trim_end().len() - 1; + r.push_str(&part.trim_end()[..end_pos]); + regexes.push(regex.take().unwrap()); + } else { + r.push_str(part); + } + continue; + } + }; + + // TODO smart case mode could be made configurable like in fzf + let (is_glob, glob) = if query_part.contains(char::is_uppercase) { + (true, "*") + } else { + (false, "%") + }; + + let (is_inverse, query_part) = match query_part.strip_prefix('!') { + Some(stripped) => (true, Cow::Borrowed(stripped)), + None => (false, query_part), + }; + + #[allow(clippy::if_same_then_else)] + let param = if query_part == "|" { + if !is_or { + is_or = true; + continue; + } else { + format!("{glob}|{glob}") + } + } else if let Some(term) = query_part.strip_prefix('^') { + format!("{term}{glob}") + } else if let Some(term) = query_part.strip_suffix('$') { + format!("{glob}{term}") + } else if let Some(term) = query_part.strip_prefix('\'') { + format!("{glob}{term}{glob}") + } else if is_inverse { + format!("{glob}{query_part}{glob}") + } else if search_mode == SearchMode::FullText { + format!("{glob}{query_part}{glob}") + } else { + query_part.split("").join(glob) + }; + + sql.fuzzy_condition("command", param, is_inverse, is_glob, is_or); + is_or = false; + } + if let Some(r) = regex { + regexes.push(r); + } + + &mut sql + } + }; + + for regex in regexes { + sql.and_where("command regexp ?".bind(®ex)); + } + + filter_options + .exit + .map(|exit| sql.and_where_eq("exit", exit)); + + filter_options + .exclude_exit + .map(|exclude_exit| sql.and_where_ne("exit", exclude_exit)); + + filter_options + .cwd + .map(|cwd| sql.and_where_eq("cwd", quote(cwd))); + + filter_options + .exclude_cwd + .map(|exclude_cwd| sql.and_where_ne("cwd", quote(exclude_cwd))); + + filter_options.before.map(|before| { + interim::parse_date_string( + before.as_str(), + OffsetDateTime::now_utc(), + interim::Dialect::Uk, + ) + .map(|before| { + sql.and_where_lt("timestamp", 
quote(before.unix_timestamp_nanos() as i64)) + }) + }); + + filter_options.after.map(|after| { + interim::parse_date_string( + after.as_str(), + OffsetDateTime::now_utc(), + interim::Dialect::Uk, + ) + .map(|after| sql.and_where_gt("timestamp", quote(after.unix_timestamp_nanos() as i64))) + }); + + sql.and_where_is_null("deleted_at"); + + let query = sql.sql().expect("bug in search query. please report"); + + let res = sqlx::query(&query) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(ordering::reorder_fuzzy(search_mode, orig_query, res)) + } + + async fn query_history(&self, query: &str) -> Result<Vec<History>> { + let res = sqlx::query(query) + .map(Self::query_history) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + async fn all_with_count(&self) -> Result<Vec<(History, i32)>> { + debug!("listing history"); + + let mut query = SqlBuilder::select_from(SqlName::new("history").alias("h").baquoted()); + + query + .fields(&[ + "id", + "max(timestamp) as timestamp", + "max(duration) as duration", + "exit", + "command", + "deleted_at", + "group_concat(cwd, ':') as cwd", + "group_concat(session) as session", + "group_concat(hostname, ',') as hostname", + "count(*) as count", + ]) + .group_by("command") + .group_by("exit") + .and_where("deleted_at is null") + .order_desc("timestamp"); + + let query = query.sql().expect("bug in list query. 
please report"); + + let res = sqlx::query(&query) + .map(|row: SqliteRow| { + let count: i32 = row.get("count"); + (Self::query_history(row), count) + }) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + // deleted_at doesn't mean the actual time that the user deleted it, + // but the time that the system marks it as deleted + async fn delete(&self, mut h: History) -> Result<()> { + let now = OffsetDateTime::now_utc(); + h.command = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(32) + .map(char::from) + .collect(); // overwrite with random string + h.deleted_at = Some(now); // delete it + + self.update(&h).await?; // save it + + Ok(()) + } + + async fn delete_rows(&self, ids: &[HistoryId]) -> Result<()> { + let mut tx = self.pool.begin().await?; + + for id in ids { + Self::delete_row_raw(&mut tx, id.clone()).await?; + } + + tx.commit().await?; + + Ok(()) + } + + async fn stats(&self, h: &History) -> Result<HistoryStats> { + // We select the previous in the session by time + let mut prev = SqlBuilder::select_from("history"); + prev.field("*") + .and_where("timestamp < ?1") + .and_where("session = ?2") + .order_by("timestamp", true) + .limit(1); + + let mut next = SqlBuilder::select_from("history"); + next.field("*") + .and_where("timestamp > ?1") + .and_where("session = ?2") + .order_by("timestamp", false) + .limit(1); + + let mut total = SqlBuilder::select_from("history"); + total.field("count(1)").and_where("command = ?1"); + + let mut average = SqlBuilder::select_from("history"); + average.field("avg(duration)").and_where("command = ?1"); + + let mut exits = SqlBuilder::select_from("history"); + exits + .fields(&["exit", "count(1) as count"]) + .and_where("command = ?1") + .group_by("exit"); + + // rewrite the following with sqlbuilder + let mut day_of_week = SqlBuilder::select_from("history"); + day_of_week + .fields(&[ + "strftime('%w', ROUND(timestamp / 1000000000), 'unixepoch') AS day_of_week", + "count(1) as count", + ]) + 
.and_where("command = ?1") + .group_by("day_of_week"); + + // Intentionally format the string with 01 hardcoded. We want the average runtime for the + // _entire month_, but will later parse it as a datetime for sorting + // Sqlite has no datetime so we cannot do it there, and otherwise sorting will just be a + // string sort, which won't be correct. + let mut duration_over_time = SqlBuilder::select_from("history"); + duration_over_time + .fields(&[ + "strftime('01-%m-%Y', ROUND(timestamp / 1000000000), 'unixepoch') AS month_year", + "avg(duration) as duration", + ]) + .and_where("command = ?1") + .group_by("month_year") + .having("duration > 0"); + + let prev = prev.sql().expect("issue in stats previous query"); + let next = next.sql().expect("issue in stats next query"); + let total = total.sql().expect("issue in stats average query"); + let average = average.sql().expect("issue in stats previous query"); + let exits = exits.sql().expect("issue in stats exits query"); + let day_of_week = day_of_week.sql().expect("issue in stats day of week query"); + let duration_over_time = duration_over_time + .sql() + .expect("issue in stats duration over time query"); + + let prev = sqlx::query(&prev) + .bind(h.timestamp.unix_timestamp_nanos() as i64) + .bind(&h.session) + .map(Self::query_history) + .fetch_optional(&self.pool) + .await?; + + let next = sqlx::query(&next) + .bind(h.timestamp.unix_timestamp_nanos() as i64) + .bind(&h.session) + .map(Self::query_history) + .fetch_optional(&self.pool) + .await?; + + let total: (i64,) = sqlx::query_as(&total) + .bind(&h.command) + .fetch_one(&self.pool) + .await?; + + let average: (f64,) = sqlx::query_as(&average) + .bind(&h.command) + .fetch_one(&self.pool) + .await?; + + let exits: Vec<(i64, i64)> = sqlx::query_as(&exits) + .bind(&h.command) + .fetch_all(&self.pool) + .await?; + + let day_of_week: Vec<(String, i64)> = sqlx::query_as(&day_of_week) + .bind(&h.command) + .fetch_all(&self.pool) + .await?; + + let 
duration_over_time: Vec<(String, f64)> = sqlx::query_as(&duration_over_time) + .bind(&h.command) + .fetch_all(&self.pool) + .await?; + + let duration_over_time = duration_over_time + .iter() + .map(|f| (f.0.clone(), f.1.round() as i64)) + .collect(); + + Ok(HistoryStats { + next, + previous: prev, + total: total.0 as u64, + average_duration: average.0 as u64, + exits, + day_of_week, + duration_over_time, + }) + } +} + +#[cfg(test)] +mod test { + use super::*; + use std::time::{Duration, Instant}; + + async fn assert_search_eq<'a>( + db: &impl Database, + mode: SearchMode, + filter_mode: FilterMode, + query: &str, + expected: usize, + ) -> Result<Vec<History>> { + let context = Context { + hostname: "test:host".to_string(), + session: "beepboopiamasession".to_string(), + cwd: "/home/ellie".to_string(), + host_id: "test-host".to_string(), + git_root: None, + }; + + let results = db + .search( + mode, + filter_mode, + &context, + query, + OptFilters { + ..Default::default() + }, + ) + .await?; + + assert_eq!( + results.len(), + expected, + "query \"{}\", commands: {:?}", + query, + results.iter().map(|a| &a.command).collect::<Vec<&String>>() + ); + Ok(results) + } + + async fn assert_search_commands( + db: &impl Database, + mode: SearchMode, + filter_mode: FilterMode, + query: &str, + expected_commands: Vec<&str>, + ) { + let results = assert_search_eq(db, mode, filter_mode, query, expected_commands.len()) + .await + .unwrap(); + let commands: Vec<&str> = results.iter().map(|a| a.command.as_str()).collect(); + assert_eq!(commands, expected_commands); + } + + async fn new_history_item(db: &mut impl Database, cmd: &str) -> Result<()> { + let mut captured: History = History::capture() + .timestamp(OffsetDateTime::now_utc()) + .command(cmd) + .cwd("/home/ellie") + .build() + .into(); + + captured.exit = 0; + captured.duration = 1; + captured.session = "beep boop".to_string(); + captured.hostname = "booop".to_string(); + + db.save(&captured).await + } + + 
#[tokio::test(flavor = "multi_thread")] + async fn test_search_prefix() { + let mut db = Sqlite::new("sqlite::memory:", 0.1).await.unwrap(); + new_history_item(&mut db, "ls /home/ellie").await.unwrap(); + + assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "/home", 0) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Prefix, FilterMode::Global, "ls ", 0) + .await + .unwrap(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_search_fulltext() { + let mut db = Sqlite::new("sqlite::memory:", 0.1).await.unwrap(); + new_history_item(&mut db, "ls /home/ellie").await.unwrap(); + + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "ls ho", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "hm", 0) + .await + .unwrap(); + + // regex + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r/^ls ", 1) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r/ls / ie$", + 1, + ) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r/ls / !ie", + 0, + ) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "meow r/ls/", + 0, + ) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "r//hom/", 1) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r//home//", + 1, + ) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r//home///", + 0, + ) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::FullText, FilterMode::Global, "/home.*e", 0) 
+ .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::FullText, + FilterMode::Global, + "r/home.*e", + 1, + ) + .await + .unwrap(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_search_fuzzy() { + let mut db = Sqlite::new("sqlite::memory:", 0.1).await.unwrap(); + new_history_item(&mut db, "ls /home/ellie").await.unwrap(); + new_history_item(&mut db, "ls /home/frank").await.unwrap(); + new_history_item(&mut db, "cd /home/Ellie").await.unwrap(); + new_history_item(&mut db, "/home/ellie/.bin/rustup") + .await + .unwrap(); + + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls /", 3) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ls/", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "l/h/", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e", 3) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/hmoe/", 0) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie/home", 0) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "lsellie", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, " ", 4) + .await + .unwrap(); + + // single term operators + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "'ls", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "ellie$", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!^ls", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "!ellie$", 2) + .await + .unwrap(); + + // multiple terms + assert_search_eq(&db, 
SearchMode::Fuzzy, FilterMode::Global, "ls !ellie", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "^ls !e$", 1) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "home !^ls", 2) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::Fuzzy, + FilterMode::Global, + "'frank | 'rustup", + 2, + ) + .await + .unwrap(); + assert_search_eq( + &db, + SearchMode::Fuzzy, + FilterMode::Global, + "'frank | 'rustup 'ls", + 1, + ) + .await + .unwrap(); + + // case matching + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "Ellie", 1) + .await + .unwrap(); + + // regex + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/^ls ", 2) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "r/[Ee]llie", 3) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "/h/e r/^ls ", 1) + .await + .unwrap(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_search_reordered_fuzzy() { + let mut db = Sqlite::new("sqlite::memory:", 0.1).await.unwrap(); + // test ordering of results: we should choose the first, even though it happened longer ago. 
+ + new_history_item(&mut db, "curl").await.unwrap(); + new_history_item(&mut db, "corburl").await.unwrap(); + + // if fuzzy reordering is on, it should come back in a more sensible order + assert_search_commands( + &db, + SearchMode::Fuzzy, + FilterMode::Global, + "curl", + vec!["curl", "corburl"], + ) + .await; + + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "xxxx", 0) + .await + .unwrap(); + assert_search_eq(&db, SearchMode::Fuzzy, FilterMode::Global, "", 2) + .await + .unwrap(); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_search_bench_dupes() { + let context = Context { + hostname: "test:host".to_string(), + session: "beepboopiamasession".to_string(), + cwd: "/home/ellie".to_string(), + host_id: "test-host".to_string(), + git_root: None, + }; + + let mut db = Sqlite::new("sqlite::memory:", 0.1).await.unwrap(); + for _i in 1..10000 { + new_history_item(&mut db, "i am a duplicated command") + .await + .unwrap(); + } + let start = Instant::now(); + let _results = db + .search( + SearchMode::Fuzzy, + FilterMode::Global, + &context, + "", + OptFilters { + ..Default::default() + }, + ) + .await + .unwrap(); + let duration = start.elapsed(); + + assert!(duration < Duration::from_secs(15)); + } +} + +trait SqlBuilderExt { + fn fuzzy_condition<S: ToString, T: ToString>( + &mut self, + field: S, + mask: T, + inverse: bool, + glob: bool, + is_or: bool, + ) -> &mut Self; +} + +impl SqlBuilderExt for SqlBuilder { + /// adapted from the sql-builder *like functions + fn fuzzy_condition<S: ToString, T: ToString>( + &mut self, + field: S, + mask: T, + inverse: bool, + glob: bool, + is_or: bool, + ) -> &mut Self { + let mut cond = field.to_string(); + if inverse { + cond.push_str(" NOT"); + } + if glob { + cond.push_str(" GLOB '"); + } else { + cond.push_str(" LIKE '"); + } + cond.push_str(&esc(mask.to_string())); + cond.push('\''); + if is_or { + self.or_where(cond) + } else { + self.and_where(cond) + } + } +} diff --git 
a/crates/atuin-client/src/encryption.rs b/crates/atuin-client/src/encryption.rs new file mode 100644 index 00000000..50aacc24 --- /dev/null +++ b/crates/atuin-client/src/encryption.rs @@ -0,0 +1,430 @@ +// The general idea is that we NEVER send cleartext history to the server +// This way the odds of anything private ending up where it should not are +// very low +// The server authenticates via the usual username and password. This has +// nothing to do with the encryption, and is purely authentication! The client +// generates its own secret key, and encrypts all shell history with libsodium's +// secretbox. The data is then sent to the server, where it is stored. All +// clients must share the secret in order to be able to sync, as it is needed +// to decrypt + +use std::{io::prelude::*, path::PathBuf}; + +use base64::prelude::{Engine, BASE64_STANDARD}; +pub use crypto_secretbox::Key; +use crypto_secretbox::{ + aead::{Nonce, OsRng}, + AeadCore, AeadInPlace, KeyInit, XSalsa20Poly1305, +}; +use eyre::{bail, ensure, eyre, Context, Result}; +use fs_err as fs; +use rmp::{decode::Bytes, Marker}; +use serde::{Deserialize, Serialize}; +use time::{format_description::well_known::Rfc3339, macros::format_description, OffsetDateTime}; + +use crate::{history::History, settings::Settings}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct EncryptedHistory { + pub ciphertext: Vec<u8>, + pub nonce: Nonce<XSalsa20Poly1305>, +} + +pub fn generate_encoded_key() -> Result<(Key, String)> { + let key = XSalsa20Poly1305::generate_key(&mut OsRng); + let encoded = encode_key(&key)?; + + Ok((key, encoded)) +} + +pub fn new_key(settings: &Settings) -> Result<Key> { + let path = settings.key_path.as_str(); + let path = PathBuf::from(path); + + if path.exists() { + bail!("key already exists! 
cannot overwrite"); + } + + let (key, encoded) = generate_encoded_key()?; + + let mut file = fs::File::create(path)?; + file.write_all(encoded.as_bytes())?; + + Ok(key) +} + +// Loads the secret key, will create + save if it doesn't exist +pub fn load_key(settings: &Settings) -> Result<Key> { + let path = settings.key_path.as_str(); + + let key = if PathBuf::from(path).exists() { + let key = fs_err::read_to_string(path)?; + decode_key(key)? + } else { + new_key(settings)? + }; + + Ok(key) +} + +pub fn encode_key(key: &Key) -> Result<String> { + let mut buf = vec![]; + rmp::encode::write_array_len(&mut buf, key.len() as u32) + .wrap_err("could not encode key to message pack")?; + for b in key { + rmp::encode::write_uint(&mut buf, *b as u64) + .wrap_err("could not encode key to message pack")?; + } + let buf = BASE64_STANDARD.encode(buf); + + Ok(buf) +} + +pub fn decode_key(key: String) -> Result<Key> { + use rmp::decode; + + let buf = BASE64_STANDARD + .decode(key.trim_end()) + .wrap_err("encryption key is not a valid base64 encoding")?; + + // old code wrote the key as a fixed length array of 32 bytes + // new code writes the key with a length prefix + match <[u8; 32]>::try_from(&*buf) { + Ok(key) => Ok(key.into()), + Err(_) => { + let mut bytes = rmp::decode::Bytes::new(&buf); + + match Marker::from_u8(buf[0]) { + Marker::Bin8 => { + let len = decode::read_bin_len(&mut bytes).map_err(|err| eyre!("{err:?}"))?; + ensure!(len == 32, "encryption key is not the correct size"); + let key = <[u8; 32]>::try_from(bytes.remaining_slice()) + .context("could not decode encryption key")?; + Ok(key.into()) + } + Marker::Array16 => { + let len = decode::read_array_len(&mut bytes).map_err(|err| eyre!("{err:?}"))?; + ensure!(len == 32, "encryption key is not the correct size"); + + let mut key = Key::default(); + for i in &mut key { + *i = rmp::decode::read_int(&mut bytes).map_err(|err| eyre!("{err:?}"))?; + } + Ok(key) + } + _ => bail!("could not decode encryption key"), + } + } 
+ } +} + +pub fn encrypt(history: &History, key: &Key) -> Result<EncryptedHistory> { + // serialize with msgpack + let mut buf = encode(history)?; + + let nonce = XSalsa20Poly1305::generate_nonce(&mut OsRng); + XSalsa20Poly1305::new(key) + .encrypt_in_place(&nonce, &[], &mut buf) + .map_err(|_| eyre!("could not encrypt"))?; + + Ok(EncryptedHistory { + ciphertext: buf, + nonce, + }) +} + +pub fn decrypt(mut encrypted_history: EncryptedHistory, key: &Key) -> Result<History> { + XSalsa20Poly1305::new(key) + .decrypt_in_place( + &encrypted_history.nonce, + &[], + &mut encrypted_history.ciphertext, + ) + .map_err(|_| eyre!("could not encrypt"))?; + let plaintext = encrypted_history.ciphertext; + + let history = decode(&plaintext)?; + + Ok(history) +} + +fn format_rfc3339(ts: OffsetDateTime) -> Result<String> { + // horrible hack. chrono AutoSI limits to 0, 3, 6, or 9 decimal places for nanoseconds. + // time does not have this functionality. + static PARTIAL_RFC3339_0: &[time::format_description::FormatItem<'static>] = + format_description!("[year]-[month]-[day]T[hour]:[minute]:[second]Z"); + static PARTIAL_RFC3339_3: &[time::format_description::FormatItem<'static>] = + format_description!("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"); + static PARTIAL_RFC3339_6: &[time::format_description::FormatItem<'static>] = + format_description!("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:6]Z"); + static PARTIAL_RFC3339_9: &[time::format_description::FormatItem<'static>] = + format_description!("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:9]Z"); + + let fmt = match ts.nanosecond() { + 0 => PARTIAL_RFC3339_0, + ns if ns % 1_000_000 == 0 => PARTIAL_RFC3339_3, + ns if ns % 1_000 == 0 => PARTIAL_RFC3339_6, + _ => PARTIAL_RFC3339_9, + }; + + Ok(ts.format(fmt)?) 
+} + +fn encode(h: &History) -> Result<Vec<u8>> { + use rmp::encode; + + let mut output = vec![]; + // INFO: ensure this is updated when adding new fields + encode::write_array_len(&mut output, 9)?; + + encode::write_str(&mut output, &h.id.0)?; + encode::write_str(&mut output, &(format_rfc3339(h.timestamp)?))?; + encode::write_sint(&mut output, h.duration)?; + encode::write_sint(&mut output, h.exit)?; + encode::write_str(&mut output, &h.command)?; + encode::write_str(&mut output, &h.cwd)?; + encode::write_str(&mut output, &h.session)?; + encode::write_str(&mut output, &h.hostname)?; + match h.deleted_at { + Some(d) => encode::write_str(&mut output, &format_rfc3339(d)?)?, + None => encode::write_nil(&mut output)?, + } + + Ok(output) +} + +fn decode(bytes: &[u8]) -> Result<History> { + use rmp::decode::{self, DecodeStringError}; + + let mut bytes = Bytes::new(bytes); + + let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; + if nfields < 8 { + bail!("malformed decrypted history") + } + if nfields > 9 { + bail!("cannot decrypt history from a newer version of atuin"); + } + + let bytes = bytes.remaining_slice(); + let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (timestamp, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + let mut bytes = Bytes::new(bytes); + let duration = decode::read_int(&mut bytes).map_err(error_report)?; + let exit = decode::read_int(&mut bytes).map_err(error_report)?; + + let bytes = bytes.remaining_slice(); + let (command, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (cwd, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (session, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (hostname, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + // if we have more fields, try and get the deleted_at + let mut deleted_at = None; + let mut bytes = bytes; + if nfields > 8 { + bytes 
= match decode::read_str_from_slice(bytes) { + Ok((d, b)) => { + deleted_at = Some(d); + b + } + // we accept null here + Err(DecodeStringError::TypeMismatch(Marker::Null)) => { + // consume the null marker + let mut c = Bytes::new(bytes); + decode::read_nil(&mut c).map_err(error_report)?; + c.remaining_slice() + } + Err(err) => return Err(error_report(err)), + }; + } + + if !bytes.is_empty() { + bail!("trailing bytes in encoded history. malformed") + } + + Ok(History { + id: id.to_owned().into(), + timestamp: OffsetDateTime::parse(timestamp, &Rfc3339)?, + duration, + exit, + command: command.to_owned(), + cwd: cwd.to_owned(), + session: session.to_owned(), + hostname: hostname.to_owned(), + deleted_at: deleted_at + .map(|t| OffsetDateTime::parse(t, &Rfc3339)) + .transpose()?, + }) +} + +fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { + eyre!("{err:?}") +} + +#[cfg(test)] +mod test { + use crypto_secretbox::{aead::OsRng, KeyInit, XSalsa20Poly1305}; + use pretty_assertions::assert_eq; + use time::{macros::datetime, OffsetDateTime}; + + use crate::history::History; + + use super::{decode, decrypt, encode, encrypt}; + + #[test] + fn test_encrypt_decrypt() { + let key1 = XSalsa20Poly1305::generate_key(&mut OsRng); + let key2 = XSalsa20Poly1305::generate_key(&mut OsRng); + + let history = History::from_db() + .id("1".into()) + .timestamp(OffsetDateTime::now_utc()) + .command("ls".into()) + .cwd("/home/ellie".into()) + .exit(0) + .duration(1) + .session("beep boop".into()) + .hostname("booop".into()) + .deleted_at(None) + .build() + .into(); + + let e1 = encrypt(&history, &key1).unwrap(); + let e2 = encrypt(&history, &key2).unwrap(); + + assert_ne!(e1.ciphertext, e2.ciphertext); + assert_ne!(e1.nonce, e2.nonce); + + // test decryption works + // this should pass + match decrypt(e1, &key1) { + Err(e) => panic!("failed to decrypt, got {}", e), + Ok(h) => assert_eq!(h, history), + }; + + // this should err + let _ = decrypt(e2, &key1).expect_err("expected an 
error decrypting with invalid key"); + } + + #[test] + fn test_decode() { + let bytes = [ + 0x99, 0xD9, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55, 53, 51, 56, + 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 187, 50, 48, 50, 51, 45, + 48, 53, 45, 50, 56, 84, 49, 56, 58, 51, 53, 58, 52, 48, 46, 54, 51, 51, 56, 55, 50, 90, + 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, 97, 116, 117, 115, 217, 42, + 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, + 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, 47, 99, 111, 100, 101, 47, 97, + 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, 51, 48, 54, 102, 50, 55, 52, 52, + 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, 102, 57, 52, 53, 55, 187, 102, + 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, 99, 111, 110, 114, 97, 100, 46, + 108, 117, 100, 103, 97, 116, 101, 192, + ]; + let history = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + deleted_at: None, + }; + + let h = decode(&bytes).unwrap(); + assert_eq!(history, h); + + let b = encode(&h).unwrap(); + assert_eq!(&bytes, &*b); + } + + #[test] + fn test_decode_deleted() { + let history = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + deleted_at: Some(datetime!(2023-05-28 18:35:40.633872 +00:00)), + }; + + let b = 
encode(&history).unwrap(); + let h = decode(&b).unwrap(); + assert_eq!(history, h); + } + + #[test] + fn test_decode_old() { + let bytes = [ + 0x98, 0xD9, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55, 53, 51, 56, + 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 187, 50, 48, 50, 51, 45, + 48, 53, 45, 50, 56, 84, 49, 56, 58, 51, 53, 58, 52, 48, 46, 54, 51, 51, 56, 55, 50, 90, + 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, 97, 116, 117, 115, 217, 42, + 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, + 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, 47, 99, 111, 100, 101, 47, 97, + 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, 51, 48, 54, 102, 50, 55, 52, 52, + 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, 102, 57, 52, 53, 55, 187, 102, + 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, 99, 111, 110, 114, 97, 100, 46, + 108, 117, 100, 103, 97, 116, 101, + ]; + let history = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + deleted_at: None, + }; + + let h = decode(&bytes).unwrap(); + assert_eq!(history, h); + } + + #[test] + fn key_encodings() { + use super::{decode_key, encode_key, Key}; + + // a history of our key encodings. 
+ // v11.0.0 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== + // v12.0.0 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== + // v13.0.0 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== + // v13.0.1 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== + // v14.0.0 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== + // v14.0.1 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== + // c7d89c1 3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q== (https://github.com/ellie/atuin/pull/805) + // b53ca35 3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q== (https://github.com/ellie/atuin/pull/974) + // v15.0.0 3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q== + // b8b57c8 xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q== (https://github.com/ellie/atuin/pull/1057) + // 8c94d79 3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q== (https://github.com/ellie/atuin/pull/1089) + + let key = Key::from([ + 27, 91, 42, 91, 210, 107, 9, 216, 170, 190, 242, 62, 6, 84, 69, 148, 148, 53, 251, 117, + 226, 167, 173, 52, 82, 34, 138, 110, 169, 124, 92, 229, + ]); + + assert_eq!( + encode_key(&key).unwrap(), + "3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q==" + ); + + // key encodings we have to support + let valid_encodings = [ + "xCAbWypb0msJ2Kq+8j4GVEWUlDX7deKnrTRSIopuqXxc5Q==", + "3AAgG1sqW8zSawnM2MyqzL7M8j4GVEXMlMyUNcz7dczizKfMrTRSIsyKbsypfFzM5Q==", + ]; + + for k in valid_encodings { + assert_eq!(decode_key(k.to_owned()).expect(k), key); + } + } +} diff --git a/crates/atuin-client/src/history.rs b/crates/atuin-client/src/history.rs new file mode 100644 index 00000000..1b590e88 --- /dev/null +++ b/crates/atuin-client/src/history.rs @@ -0,0 +1,517 @@ +use core::fmt::Formatter; +use rmp::decode::ValueReadError; +use rmp::{decode::Bytes, Marker}; +use std::env; +use std::fmt::Display; + +use atuin_common::record::DecryptedData; +use atuin_common::utils::uuid_v7; + +use eyre::{bail, eyre, Result}; +use 
regex::RegexSet; + +use crate::utils::get_host_user; +use crate::{secrets::SECRET_PATTERNS, settings::Settings}; +use time::OffsetDateTime; + +mod builder; +pub mod store; + +const HISTORY_VERSION: &str = "v0"; +pub const HISTORY_TAG: &str = "history"; + +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +pub struct HistoryId(pub String); + +impl Display for HistoryId { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From<String> for HistoryId { + fn from(s: String) -> Self { + Self(s) + } +} + +/// Client-side history entry. +/// +/// Client stores data unencrypted, and only encrypts it before sending to the server. +/// +/// To create a new history entry, use one of the builders: +/// - [`History::import()`] to import an entry from the shell history file +/// - [`History::capture()`] to capture an entry via hook +/// - [`History::from_db()`] to create an instance from the database entry +// +// ## Implementation Notes +// +// New fields must be added to `encryption::{encode, decode}` in a backwards +// compatible way. (e.g. sensible defaults and updating the nfields parameter) +#[derive(Debug, Clone, PartialEq, Eq, sqlx::FromRow)] +pub struct History { + /// A client-generated ID, used to identify the entry when syncing. + /// + /// Stored as `client_id` in the database. + pub id: HistoryId, + /// When the command was run. + pub timestamp: OffsetDateTime, + /// How long the command took to run. + pub duration: i64, + /// The exit code of the command. + pub exit: i64, + /// The command that was run. + pub command: String, + /// The current working directory when the command was run. + pub cwd: String, + /// The session ID, associated with a terminal session. + pub session: String, + /// The hostname of the machine the command was run on. + pub hostname: String, + /// Timestamp, which is set when the entry is deleted, allowing a soft delete. 
+ pub deleted_at: Option<OffsetDateTime>, +} + +#[derive(Debug, Clone, PartialEq, Eq, sqlx::FromRow)] +pub struct HistoryStats { + /// The command that was run after this one in the session + pub next: Option<History>, + /// + /// The command that was run before this one in the session + pub previous: Option<History>, + + /// How many times has this command been run? + pub total: u64, + + pub average_duration: u64, + + pub exits: Vec<(i64, i64)>, + + pub day_of_week: Vec<(String, i64)>, + + pub duration_over_time: Vec<(String, i64)>, +} + +impl History { + #[allow(clippy::too_many_arguments)] + fn new( + timestamp: OffsetDateTime, + command: String, + cwd: String, + exit: i64, + duration: i64, + session: Option<String>, + hostname: Option<String>, + deleted_at: Option<OffsetDateTime>, + ) -> Self { + let session = session + .or_else(|| env::var("ATUIN_SESSION").ok()) + .unwrap_or_else(|| uuid_v7().as_simple().to_string()); + let hostname = hostname.unwrap_or_else(get_host_user); + + Self { + id: uuid_v7().as_simple().to_string().into(), + timestamp, + command, + cwd, + exit, + duration, + session, + hostname, + deleted_at, + } + } + + pub fn serialize(&self) -> Result<DecryptedData> { + // This is pretty much the same as what we used for the old history, with one difference - + // it uses integers for timestamps rather than a string format. 
+ + use rmp::encode; + + let mut output = vec![]; + + // write the version + encode::write_u16(&mut output, 0)?; + // INFO: ensure this is updated when adding new fields + encode::write_array_len(&mut output, 9)?; + + encode::write_str(&mut output, &self.id.0)?; + encode::write_u64(&mut output, self.timestamp.unix_timestamp_nanos() as u64)?; + encode::write_sint(&mut output, self.duration)?; + encode::write_sint(&mut output, self.exit)?; + encode::write_str(&mut output, &self.command)?; + encode::write_str(&mut output, &self.cwd)?; + encode::write_str(&mut output, &self.session)?; + encode::write_str(&mut output, &self.hostname)?; + + match self.deleted_at { + Some(d) => encode::write_u64(&mut output, d.unix_timestamp_nanos() as u64)?, + None => encode::write_nil(&mut output)?, + } + + Ok(DecryptedData(output)) + } + + fn deserialize_v0(bytes: &[u8]) -> Result<History> { + use rmp::decode; + + fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { + eyre!("{err:?}") + } + + let mut bytes = Bytes::new(bytes); + + let version = decode::read_u16(&mut bytes).map_err(error_report)?; + + if version != 0 { + bail!("expected decoding v0 record, found v{version}"); + } + + let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; + + if nfields != 9 { + bail!("cannot decrypt history from a different version of Atuin"); + } + + let bytes = bytes.remaining_slice(); + let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + let mut bytes = Bytes::new(bytes); + let timestamp = decode::read_u64(&mut bytes).map_err(error_report)?; + let duration = decode::read_int(&mut bytes).map_err(error_report)?; + let exit = decode::read_int(&mut bytes).map_err(error_report)?; + + let bytes = bytes.remaining_slice(); + let (command, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (cwd, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (session, bytes) = 
decode::read_str_from_slice(bytes).map_err(error_report)?; + let (hostname, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + // if we have more fields, try and get the deleted_at + let mut bytes = Bytes::new(bytes); + + let (deleted_at, bytes) = match decode::read_u64(&mut bytes) { + Ok(unix) => (Some(unix), bytes.remaining_slice()), + // we accept null here + Err(ValueReadError::TypeMismatch(Marker::Null)) => (None, bytes.remaining_slice()), + Err(err) => return Err(error_report(err)), + }; + + if !bytes.is_empty() { + bail!("trailing bytes in encoded history. malformed") + } + + Ok(History { + id: id.to_owned().into(), + timestamp: OffsetDateTime::from_unix_timestamp_nanos(timestamp as i128)?, + duration, + exit, + command: command.to_owned(), + cwd: cwd.to_owned(), + session: session.to_owned(), + hostname: hostname.to_owned(), + deleted_at: deleted_at + .map(|t| OffsetDateTime::from_unix_timestamp_nanos(t as i128)) + .transpose()?, + }) + } + + pub fn deserialize(bytes: &[u8], version: &str) -> Result<History> { + match version { + HISTORY_VERSION => Self::deserialize_v0(bytes), + + _ => bail!("unknown version {version:?}"), + } + } + + /// Builder for a history entry that is imported from shell history. + /// + /// The only two required fields are `timestamp` and `command`. 
+ /// + /// ## Examples + /// ``` + /// use atuin_client::history::History; + /// + /// let history: History = History::import() + /// .timestamp(time::OffsetDateTime::now_utc()) + /// .command("ls -la") + /// .build() + /// .into(); + /// ``` + /// + /// If shell history contains more information, it can be added to the builder: + /// ``` + /// use atuin_client::history::History; + /// + /// let history: History = History::import() + /// .timestamp(time::OffsetDateTime::now_utc()) + /// .command("ls -la") + /// .cwd("/home/user") + /// .exit(0) + /// .duration(100) + /// .build() + /// .into(); + /// ``` + /// + /// Unknown command or command without timestamp cannot be imported, which + /// is forced at compile time: + /// + /// ```compile_fail + /// use atuin_client::history::History; + /// + /// // this will not compile because timestamp is missing + /// let history: History = History::import() + /// .command("ls -la") + /// .build() + /// .into(); + /// ``` + pub fn import() -> builder::HistoryImportedBuilder { + builder::HistoryImported::builder() + } + + /// Builder for a history entry that is captured via hook. + /// + /// This builder is used only at the `start` step of the hook, + /// so it doesn't have any fields which are known only after + /// the command is finished, such as `exit` or `duration`. 
+ /// + /// ## Examples + /// ```rust + /// use atuin_client::history::History; + /// + /// let history: History = History::capture() + /// .timestamp(time::OffsetDateTime::now_utc()) + /// .command("ls -la") + /// .cwd("/home/user") + /// .build() + /// .into(); + /// ``` + /// + /// Command without any required info cannot be captured, which is forced at compile time: + /// + /// ```compile_fail + /// use atuin_client::history::History; + /// + /// // this will not compile because `cwd` is missing + /// let history: History = History::capture() + /// .timestamp(time::OffsetDateTime::now_utc()) + /// .command("ls -la") + /// .build() + /// .into(); + /// ``` + pub fn capture() -> builder::HistoryCapturedBuilder { + builder::HistoryCaptured::builder() + } + + /// Builder for a history entry that is imported from the database. + /// + /// All fields are required, as they are all present in the database. + /// + /// ```compile_fail + /// use atuin_client::history::History; + /// + /// // this will not compile because `id` field is missing + /// let history: History = History::from_db() + /// .timestamp(time::OffsetDateTime::now_utc()) + /// .command("ls -la".to_string()) + /// .cwd("/home/user".to_string()) + /// .exit(0) + /// .duration(100) + /// .session("somesession".to_string()) + /// .hostname("localhost".to_string()) + /// .deleted_at(None) + /// .build() + /// .into(); + /// ``` + pub fn from_db() -> builder::HistoryFromDbBuilder { + builder::HistoryFromDb::builder() + } + + pub fn success(&self) -> bool { + self.exit == 0 || self.duration == -1 + } + + pub fn should_save(&self, settings: &Settings) -> bool { + let secret_regex = SECRET_PATTERNS.iter().map(|f| f.1); + let secret_regex = RegexSet::new(secret_regex).expect("Failed to build secrets regex"); + + !(self.command.starts_with(' ') + || settings.history_filter.is_match(&self.command) + || settings.cwd_filter.is_match(&self.cwd) + || (secret_regex.is_match(&self.command)) && settings.secrets_filter) + 
} +} + +#[cfg(test)] +mod tests { + use regex::RegexSet; + use time::macros::datetime; + + use crate::{history::HISTORY_VERSION, settings::Settings}; + + use super::History; + + // Test that we don't save history where necessary + #[test] + fn privacy_test() { + let settings = Settings { + cwd_filter: RegexSet::new(["^/supasecret"]).unwrap(), + history_filter: RegexSet::new(["^psql"]).unwrap(), + ..Settings::utc() + }; + + let normal_command: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("echo foo") + .cwd("/") + .build() + .into(); + + let with_space: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command(" echo bar") + .cwd("/") + .build() + .into(); + + let stripe_key: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("curl foo.com/bar?key=sk_test_1234567890abcdefghijklmnop") + .cwd("/") + .build() + .into(); + + let secret_dir: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("echo ohno") + .cwd("/supasecret") + .build() + .into(); + + let with_psql: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("psql") + .cwd("/supasecret") + .build() + .into(); + + assert!(normal_command.should_save(&settings)); + assert!(!with_space.should_save(&settings)); + assert!(!stripe_key.should_save(&settings)); + assert!(!secret_dir.should_save(&settings)); + assert!(!with_psql.should_save(&settings)); + } + + #[test] + fn disable_secrets() { + let settings = Settings { + secrets_filter: false, + ..Settings::utc() + }; + + let stripe_key: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("curl foo.com/bar?key=sk_test_1234567890abcdefghijklmnop") + .cwd("/") + .build() + .into(); + + assert!(stripe_key.should_save(&settings)); + } + + #[test] + fn test_serialize_deserialize() { + let bytes = [ + 205, 0, 0, 153, 217, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 
100, 52, 55, + 53, 51, 56, 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 207, 23, 99, + 98, 117, 24, 210, 246, 128, 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, + 97, 116, 117, 115, 217, 42, 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, + 46, 108, 117, 100, 103, 97, 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, + 47, 99, 111, 100, 101, 47, 97, 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, + 51, 48, 54, 102, 50, 55, 52, 52, 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, + 102, 57, 52, 53, 55, 187, 102, 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, + 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, 116, 101, 192, + ]; + + let history = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: "fvfg936c0kpf:conrad.ludgate".to_owned(), + deleted_at: None, + }; + + let serialized = history.serialize().expect("failed to serialize history"); + assert_eq!(serialized.0, bytes); + + let deserialized = History::deserialize(&serialized.0, HISTORY_VERSION) + .expect("failed to deserialize history"); + assert_eq!(history, deserialized); + + // test the snapshot too + let deserialized = + History::deserialize(&bytes, HISTORY_VERSION).expect("failed to deserialize history"); + assert_eq!(history, deserialized); + } + + #[test] + fn test_serialize_deserialize_deleted() { + let history = History { + id: "66d16cbee7cd47538e5c5b8b44e9006e".to_owned().into(), + timestamp: datetime!(2023-05-28 18:35:40.633872 +00:00), + duration: 49206000, + exit: 0, + command: "git status".to_owned(), + cwd: "/Users/conrad.ludgate/Documents/code/atuin".to_owned(), + session: "b97d9a306f274473a203d2eba41f9457".to_owned(), + hostname: 
"fvfg936c0kpf:conrad.ludgate".to_owned(), + deleted_at: Some(datetime!(2023-11-19 20:18 +00:00)), + }; + + let serialized = history.serialize().expect("failed to serialize history"); + + let deserialized = History::deserialize(&serialized.0, HISTORY_VERSION) + .expect("failed to deserialize history"); + + assert_eq!(history, deserialized); + } + + #[test] + fn test_serialize_deserialize_version() { + // v0 + let bytes_v0 = [ + 205, 0, 0, 153, 217, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55, + 53, 51, 56, 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 207, 23, 99, + 98, 117, 24, 210, 246, 128, 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, + 97, 116, 117, 115, 217, 42, 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, + 46, 108, 117, 100, 103, 97, 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, + 47, 99, 111, 100, 101, 47, 97, 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, + 51, 48, 54, 102, 50, 55, 52, 52, 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, + 102, 57, 52, 53, 55, 187, 102, 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, + 99, 111, 110, 114, 97, 100, 46, 108, 117, 100, 103, 97, 116, 101, 192, + ]; + + // some other version + let bytes_v1 = [ + 205, 1, 0, 153, 217, 32, 54, 54, 100, 49, 54, 99, 98, 101, 101, 55, 99, 100, 52, 55, + 53, 51, 56, 101, 53, 99, 53, 98, 56, 98, 52, 52, 101, 57, 48, 48, 54, 101, 207, 23, 99, + 98, 117, 24, 210, 246, 128, 206, 2, 238, 210, 240, 0, 170, 103, 105, 116, 32, 115, 116, + 97, 116, 117, 115, 217, 42, 47, 85, 115, 101, 114, 115, 47, 99, 111, 110, 114, 97, 100, + 46, 108, 117, 100, 103, 97, 116, 101, 47, 68, 111, 99, 117, 109, 101, 110, 116, 115, + 47, 99, 111, 100, 101, 47, 97, 116, 117, 105, 110, 217, 32, 98, 57, 55, 100, 57, 97, + 51, 48, 54, 102, 50, 55, 52, 52, 55, 51, 97, 50, 48, 51, 100, 50, 101, 98, 97, 52, 49, + 102, 57, 52, 53, 55, 187, 102, 118, 102, 103, 57, 51, 54, 99, 48, 107, 112, 102, 58, + 99, 111, 110, 114, 97, 100, 46, 
108, 117, 100, 103, 97, 116, 101, 192, + ]; + + let deserialized = History::deserialize(&bytes_v0, HISTORY_VERSION); + assert!(deserialized.is_ok()); + + let deserialized = History::deserialize(&bytes_v1, HISTORY_VERSION); + assert!(deserialized.is_err()); + } +} diff --git a/crates/atuin-client/src/history/builder.rs b/crates/atuin-client/src/history/builder.rs new file mode 100644 index 00000000..4e69cf66 --- /dev/null +++ b/crates/atuin-client/src/history/builder.rs @@ -0,0 +1,99 @@ +use typed_builder::TypedBuilder; + +use super::History; + +/// Builder for a history entry that is imported from shell history. +/// +/// The only two required fields are `timestamp` and `command`. +#[derive(Debug, Clone, TypedBuilder)] +pub struct HistoryImported { + timestamp: time::OffsetDateTime, + #[builder(setter(into))] + command: String, + #[builder(default = "unknown".into(), setter(into))] + cwd: String, + #[builder(default = -1)] + exit: i64, + #[builder(default = -1)] + duration: i64, + #[builder(default, setter(strip_option, into))] + session: Option<String>, + #[builder(default, setter(strip_option, into))] + hostname: Option<String>, +} + +impl From<HistoryImported> for History { + fn from(imported: HistoryImported) -> Self { + History::new( + imported.timestamp, + imported.command, + imported.cwd, + imported.exit, + imported.duration, + imported.session, + imported.hostname, + None, + ) + } +} + +/// Builder for a history entry that is captured via hook. +/// +/// This builder is used only at the `start` step of the hook, +/// so it doesn't have any fields which are known only after +/// the command is finished, such as `exit` or `duration`. 
+#[derive(Debug, Clone, TypedBuilder)] +pub struct HistoryCaptured { + timestamp: time::OffsetDateTime, + #[builder(setter(into))] + command: String, + #[builder(setter(into))] + cwd: String, +} + +impl From<HistoryCaptured> for History { + fn from(captured: HistoryCaptured) -> Self { + History::new( + captured.timestamp, + captured.command, + captured.cwd, + -1, + -1, + None, + None, + None, + ) + } +} + +/// Builder for a history entry that is loaded from the database. +/// +/// All fields are required, as they are all present in the database. +#[derive(Debug, Clone, TypedBuilder)] +pub struct HistoryFromDb { + id: String, + timestamp: time::OffsetDateTime, + command: String, + cwd: String, + exit: i64, + duration: i64, + session: String, + hostname: String, + deleted_at: Option<time::OffsetDateTime>, +} + +impl From<HistoryFromDb> for History { + fn from(from_db: HistoryFromDb) -> Self { + History { + id: from_db.id.into(), + timestamp: from_db.timestamp, + exit: from_db.exit, + command: from_db.command, + cwd: from_db.cwd, + duration: from_db.duration, + session: from_db.session, + hostname: from_db.hostname, + deleted_at: from_db.deleted_at, + } + } +} diff --git a/crates/atuin-client/src/history/store.rs b/crates/atuin-client/src/history/store.rs new file mode 100644 index 00000000..fe2b7b92 --- /dev/null +++ b/crates/atuin-client/src/history/store.rs @@ -0,0 +1,410 @@ +use std::{collections::HashSet, fmt::Write, time::Duration}; + +use eyre::{bail, eyre, Result}; +use indicatif::{ProgressBar, ProgressState, ProgressStyle}; +use rmp::decode::Bytes; + +use crate::{ + database::{current_context, Database}, + record::{encryption::PASETO_V4, sqlite_store::SqliteStore, store::Store}, +}; +use atuin_common::record::{DecryptedData, Host, HostId, Record, RecordId, RecordIdx}; + +use super::{History, HistoryId, HISTORY_TAG, HISTORY_VERSION}; + +#[derive(Debug)] +pub struct HistoryStore { + pub store: SqliteStore, + pub host_id: HostId, + pub encryption_key: [u8; 32], 
+} + +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum HistoryRecord { + Create(History), // Create a history record + Delete(HistoryId), // Delete a history record, identified by ID +} + +impl HistoryRecord { + /// Serialize a history record, returning DecryptedData + /// The record will be of a certain type + /// We map those like so: + /// + /// HistoryRecord::Create -> 0 + /// HistoryRecord::Delete -> 1 + /// + /// This numeric identifier is written at the start of the buffer as a MessagePack u8 (a marker + /// byte followed by the value). For history, we append the serialized history right afterwards + /// as a bin, to avoid having to handle serialization twice. + /// + /// Deletion simply refers to the history by ID + pub fn serialize(&self) -> Result<DecryptedData> { + // probably don't actually need to use rmp here, but if we ever need to extend it, it's a + // nice wrapper around raw byte stuff + use rmp::encode; + + let mut output = vec![]; + + match self { + HistoryRecord::Create(history) => { + // 0 -> a history create + encode::write_u8(&mut output, 0)?; + + let bytes = history.serialize()?; + + encode::write_bin(&mut output, &bytes.0)?; + } + HistoryRecord::Delete(id) => { + // 1 -> a history delete + encode::write_u8(&mut output, 1)?; + encode::write_str(&mut output, id.0.as_str())?; + } + }; + + Ok(DecryptedData(output)) + } + + pub fn deserialize(bytes: &DecryptedData, version: &str) -> Result<Self> { + use rmp::decode; + + fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { + eyre!("{err:?}") + } + + let mut bytes = Bytes::new(&bytes.0); + + let record_type = decode::read_u8(&mut bytes).map_err(error_report)?; + + match record_type { + // 0 -> HistoryRecord::Create + 0 => { + // not super useful to us atm, but perhaps in the future + // written by write_bin above + let _ = decode::read_bin_len(&mut bytes).map_err(error_report)?; + + let record = History::deserialize(bytes.remaining_slice(), version)?; + + Ok(HistoryRecord::Create(record)) + } + + // 1 -> HistoryRecord::Delete + 1 => { + let 
bytes = bytes.remaining_slice(); + let (id, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + if !bytes.is_empty() { + bail!( + "trailing bytes decoding HistoryRecord::Delete - malformed? got {bytes:?}" + ); + } + + Ok(HistoryRecord::Delete(id.to_string().into())) + } + + n => { + bail!("unknown HistoryRecord type {n}") + } + } + } +} + +impl HistoryStore { + pub fn new(store: SqliteStore, host_id: HostId, encryption_key: [u8; 32]) -> Self { + HistoryStore { + store, + host_id, + encryption_key, + } + } + + async fn push_record(&self, record: HistoryRecord) -> Result<(RecordId, RecordIdx)> { + let bytes = record.serialize()?; + let idx = self + .store + .last(self.host_id, HISTORY_TAG) + .await? + .map_or(0, |p| p.idx + 1); + + let record = Record::builder() + .host(Host::new(self.host_id)) + .version(HISTORY_VERSION.to_string()) + .tag(HISTORY_TAG.to_string()) + .idx(idx) + .data(bytes) + .build(); + + let id = record.id; + + self.store + .push(&record.encrypt::<PASETO_V4>(&self.encryption_key)) + .await?; + + Ok((id, idx)) + } + + async fn push_batch(&self, records: impl Iterator<Item = HistoryRecord>) -> Result<()> { + let mut ret = Vec::new(); + + let idx = self + .store + .last(self.host_id, HISTORY_TAG) + .await? + .map_or(0, |p| p.idx + 1); + + // Could probably _also_ do this as an iterator, but let's see how this is for now. 
+ // optimizing for minimal sqlite transactions, this code can be optimised later + for (n, record) in records.enumerate() { + let bytes = record.serialize()?; + + let record = Record::builder() + .host(Host::new(self.host_id)) + .version(HISTORY_VERSION.to_string()) + .tag(HISTORY_TAG.to_string()) + .idx(idx + n as u64) + .data(bytes) + .build(); + + let record = record.encrypt::<PASETO_V4>(&self.encryption_key); + + ret.push(record); + } + + self.store.push_batch(ret.iter()).await?; + + Ok(()) + } + + pub async fn delete(&self, id: HistoryId) -> Result<(RecordId, RecordIdx)> { + let record = HistoryRecord::Delete(id); + + self.push_record(record).await + } + + pub async fn push(&self, history: History) -> Result<(RecordId, RecordIdx)> { + // TODO(ellie): move the history store to its own file + // it's tiny rn so fine as is + let record = HistoryRecord::Create(history); + + self.push_record(record).await + } + + pub async fn history(&self) -> Result<Vec<HistoryRecord>> { + // Atm this loads all history into memory + // Not ideal as that is potentially quite a lot, although history will be small. + let records = self.store.all_tagged(HISTORY_TAG).await?; + let mut ret = Vec::with_capacity(records.len()); + + for record in records.into_iter() { + let hist = match record.version.as_str() { + HISTORY_VERSION => { + let decrypted = record.decrypt::<PASETO_V4>(&self.encryption_key)?; + + HistoryRecord::deserialize(&decrypted.data, HISTORY_VERSION) + } + version => bail!("unknown history version {version:?}"), + }?; + + ret.push(hist); + } + + Ok(ret) + } + + pub async fn build(&self, database: &dyn Database) -> Result<()> { + // I'd like to change how we rebuild and not couple this with the database, but need to + // consider the structure more deeply. This will be easy to change. 
+ + // TODO(ellie): page or iterate this + let history = self.history().await?; + + // In theory we could flatten this here + // The current issue is that the database may have history in it already, from the old sync + // This didn't actually delete old history + // If we're sure we have a DB only maintained by the new store, we can flatten + // create/delete before we even get to sqlite + let mut creates = Vec::new(); + let mut deletes = Vec::new(); + + for i in history { + match i { + HistoryRecord::Create(h) => { + creates.push(h); + } + HistoryRecord::Delete(id) => { + deletes.push(id); + } + } + } + + database.save_bulk(&creates).await?; + database.delete_rows(&deletes).await?; + + Ok(()) + } + + pub async fn incremental_build(&self, database: &dyn Database, ids: &[RecordId]) -> Result<()> { + for id in ids { + let record = self.store.get(*id).await; + + let record = if let Ok(record) = record { + record + } else { + continue; + }; + + if record.tag != HISTORY_TAG { + continue; + } + + let decrypted = record.decrypt::<PASETO_V4>(&self.encryption_key)?; + let record = HistoryRecord::deserialize(&decrypted.data, HISTORY_VERSION)?; + + match record { + HistoryRecord::Create(h) => { + // TODO: benchmark CPU time/memory tradeoff of batch commit vs one at a time + database.save(&h).await?; + } + HistoryRecord::Delete(id) => { + database.delete_rows(&[id]).await?; + } + } + } + + Ok(()) + } + + /// Get a list of history IDs that exist in the store + /// Note: This currently involves loading all history into memory. This is not going to be a + /// large amount in absolute terms, but do not call it in a hot loop. 
+ pub async fn history_ids(&self) -> Result<HashSet<HistoryId>> { + let history = self.history().await?; + + let ret = HashSet::from_iter(history.iter().map(|h| match h { + HistoryRecord::Create(h) => h.id.clone(), + HistoryRecord::Delete(id) => id.clone(), + })); + + Ok(ret) + } + + pub async fn init_store(&self, db: &impl Database) -> Result<()> { + let pb = ProgressBar::new_spinner(); + pb.set_style( + ProgressStyle::with_template("{spinner:.blue} {msg}") + .unwrap() + .with_key("eta", |state: &ProgressState, w: &mut dyn Write| { + write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap() + }) + .progress_chars("#>-"), + ); + pb.enable_steady_tick(Duration::from_millis(500)); + + pb.set_message("Fetching history from old database"); + + let context = current_context(); + let history = db.list(&[], &context, None, false, true).await?; + + pb.set_message("Fetching history already in store"); + let store_ids = self.history_ids().await?; + + pb.set_message("Converting old history to new store"); + let mut records = Vec::new(); + + for i in history { + debug!("loaded {}", i.id); + + if store_ids.contains(&i.id) { + debug!("skipping {} - already exists", i.id); + continue; + } + + if i.deleted_at.is_some() { + records.push(HistoryRecord::Delete(i.id)); + } else { + records.push(HistoryRecord::Create(i)); + } + } + + pb.set_message("Writing to db"); + + if !records.is_empty() { + self.push_batch(records.into_iter()).await?; + } + + pb.finish_with_message("Import complete"); + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use atuin_common::record::DecryptedData; + use time::macros::datetime; + + use crate::history::{store::HistoryRecord, HISTORY_VERSION}; + + use super::History; + + #[test] + fn test_serialize_deserialize_create() { + let bytes = [ + 204, 0, 196, 141, 205, 0, 0, 153, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 56, 49, + 55, 53, 55, 99, 100, 50, 97, 101, 101, 54, 53, 99, 100, 55, 56, 54, 49, 102, 57, 99, + 56, 49, 207, 23, 166, 251, 212, 181, 82, 0, 0, 
100, 0, 162, 108, 115, 217, 41, 47, 85, + 115, 101, 114, 115, 47, 101, 108, 108, 105, 101, 47, 115, 114, 99, 47, 103, 105, 116, + 104, 117, 98, 46, 99, 111, 109, 47, 97, 116, 117, 105, 110, 115, 104, 47, 97, 116, 117, + 105, 110, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 97, 100, 56, 57, 55, 53, 57, 55, + 56, 53, 50, 53, 50, 55, 97, 51, 49, 99, 57, 57, 56, 48, 53, 57, 170, 98, 111, 111, 112, + 58, 101, 108, 108, 105, 101, 192, + ]; + + let history = History { + id: "018cd4fe81757cd2aee65cd7861f9c81".to_owned().into(), + timestamp: datetime!(2024-01-04 00:00:00.000000 +00:00), + duration: 100, + exit: 0, + command: "ls".to_owned(), + cwd: "/Users/ellie/src/github.com/atuinsh/atuin".to_owned(), + session: "018cd4fead897597852527a31c998059".to_owned(), + hostname: "boop:ellie".to_owned(), + deleted_at: None, + }; + + let record = HistoryRecord::Create(history); + + let serialized = record.serialize().expect("failed to serialize history"); + assert_eq!(serialized.0, bytes); + + let deserialized = HistoryRecord::deserialize(&serialized, HISTORY_VERSION) + .expect("failed to deserialize HistoryRecord"); + assert_eq!(deserialized, record); + + // check the snapshot too + let deserialized = + HistoryRecord::deserialize(&DecryptedData(Vec::from(bytes)), HISTORY_VERSION) + .expect("failed to deserialize HistoryRecord"); + assert_eq!(deserialized, record); + } + + #[test] + fn test_serialize_deserialize_delete() { + let bytes = [ + 204, 1, 217, 32, 48, 49, 56, 99, 100, 52, 102, 101, 56, 49, 55, 53, 55, 99, 100, 50, + 97, 101, 101, 54, 53, 99, 100, 55, 56, 54, 49, 102, 57, 99, 56, 49, + ]; + let record = HistoryRecord::Delete("018cd4fe81757cd2aee65cd7861f9c81".to_string().into()); + + let serialized = record.serialize().expect("failed to serialize history"); + assert_eq!(serialized.0, bytes); + + let deserialized = HistoryRecord::deserialize(&serialized, HISTORY_VERSION) + .expect("failed to deserialize HistoryRecord"); + assert_eq!(deserialized, record); + + let 
deserialized = + HistoryRecord::deserialize(&DecryptedData(Vec::from(bytes)), HISTORY_VERSION) + .expect("failed to deserialize HistoryRecord"); + assert_eq!(deserialized, record); + } +} diff --git a/crates/atuin-client/src/import/bash.rs b/crates/atuin-client/src/import/bash.rs new file mode 100644 index 00000000..ade1f751 --- /dev/null +++ b/crates/atuin-client/src/import/bash.rs @@ -0,0 +1,218 @@ +use std::{path::PathBuf, str}; + +use async_trait::async_trait; +use directories::UserDirs; +use eyre::{eyre, Result}; +use itertools::Itertools; +use time::{Duration, OffsetDateTime}; + +use super::{get_histpath, unix_byte_lines, Importer, Loader}; +use crate::history::History; +use crate::import::read_to_end; + +#[derive(Debug)] +pub struct Bash { + bytes: Vec<u8>, +} + +fn default_histpath() -> Result<PathBuf> { + let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; + let home_dir = user_dirs.home_dir(); + + Ok(home_dir.join(".bash_history")) +} + +#[async_trait] +impl Importer for Bash { + const NAME: &'static str = "bash"; + + async fn new() -> Result<Self> { + let bytes = read_to_end(get_histpath(default_histpath)?)?; + Ok(Self { bytes }) + } + + async fn entries(&mut self) -> Result<usize> { + let count = unix_byte_lines(&self.bytes) + .map(LineType::from) + .filter(|line| matches!(line, LineType::Command(_))) + .count(); + Ok(count) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + let lines = unix_byte_lines(&self.bytes) + .map(LineType::from) + .filter(|line| !matches!(line, LineType::NotUtf8)) // invalid utf8 are ignored + .collect_vec(); + + let (commands_before_first_timestamp, first_timestamp) = lines + .iter() + .enumerate() + .find_map(|(i, line)| match line { + LineType::Timestamp(t) => Some((i, *t)), + _ => None, + }) + // if no known timestamps, use now as base + .unwrap_or((lines.len(), OffsetDateTime::now_utc())); + + // if no timestamp is recorded, then use this increment to set an arbitrary 
timestamp + // to preserve ordering + // this increment is deliberately very small to prevent particularly fast fingers + // causing ordering issues; it also helps in handling the "here document" syntax, + // where several lines are recorded in succession without individual timestamps + let timestamp_increment = Duration::milliseconds(1); + + // make sure there is a minimum amount of time before the first known timestamp + // to fit all commands, given the default increment + let mut next_timestamp = + first_timestamp - timestamp_increment * commands_before_first_timestamp as i32; + + for line in lines.into_iter() { + match line { + LineType::NotUtf8 => unreachable!(), // already filtered + LineType::Empty => {} // do nothing + LineType::Timestamp(t) => { + if t < next_timestamp { + warn!("Time reversal detected in Bash history! Commands may be ordered incorrectly."); + } + next_timestamp = t; + } + LineType::Command(c) => { + let imported = History::import().timestamp(next_timestamp).command(c); + + h.push(imported.build().into()).await?; + next_timestamp += timestamp_increment; + } + } + } + + Ok(()) + } +} + +#[derive(Debug, Clone)] +enum LineType<'a> { + NotUtf8, + /// Can happen when using the "here document" syntax. + Empty, + /// A timestamp line start with a '#', followed immediately by an integer + /// that represents seconds since UNIX epoch. + Timestamp(OffsetDateTime), + /// Anything else. 
+ Command(&'a str), +} +impl<'a> From<&'a [u8]> for LineType<'a> { + fn from(bytes: &'a [u8]) -> Self { + let Ok(line) = str::from_utf8(bytes) else { + return LineType::NotUtf8; + }; + if line.is_empty() { + return LineType::Empty; + } + let parsed = match try_parse_line_as_timestamp(line) { + Some(time) => LineType::Timestamp(time), + None => LineType::Command(line), + }; + parsed + } +} + +fn try_parse_line_as_timestamp(line: &str) -> Option<OffsetDateTime> { + let seconds = line.strip_prefix('#')?.parse().ok()?; + OffsetDateTime::from_unix_timestamp(seconds).ok() +} + +#[cfg(test)] +mod test { + use std::cmp::Ordering; + + use itertools::{assert_equal, Itertools}; + + use crate::import::{tests::TestLoader, Importer}; + + use super::Bash; + + #[tokio::test] + async fn parse_no_timestamps() { + let bytes = r"cargo install atuin +cargo update +cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷ +" + .as_bytes() + .to_owned(); + + let mut bash = Bash { bytes }; + assert_eq!(bash.entries().await.unwrap(), 3); + + let mut loader = TestLoader::default(); + bash.load(&mut loader).await.unwrap(); + + assert_equal( + loader.buf.iter().map(|h| h.command.as_str()), + [ + "cargo install atuin", + "cargo update", + "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", + ], + ); + assert!(is_strictly_sorted(loader.buf.iter().map(|h| h.timestamp))) + } + + #[tokio::test] + async fn parse_with_timestamps() { + let bytes = b"#1672918999 +git reset +#1672919006 +git clean -dxf +#1672919020 +cd ../ +" + .to_vec(); + + let mut bash = Bash { bytes }; + assert_eq!(bash.entries().await.unwrap(), 3); + + let mut loader = TestLoader::default(); + bash.load(&mut loader).await.unwrap(); + + assert_equal( + loader.buf.iter().map(|h| h.command.as_str()), + ["git reset", "git clean -dxf", "cd ../"], + ); + assert_equal( + loader.buf.iter().map(|h| h.timestamp.unix_timestamp()), + [1672918999, 1672919006, 1672919020], + ) + } + + #[tokio::test] + async fn parse_with_partial_timestamps() { + let bytes = b"git reset 
+#1672919006 +git clean -dxf +cd ../ +" + .to_vec(); + + let mut bash = Bash { bytes }; + assert_eq!(bash.entries().await.unwrap(), 3); + + let mut loader = TestLoader::default(); + bash.load(&mut loader).await.unwrap(); + + assert_equal( + loader.buf.iter().map(|h| h.command.as_str()), + ["git reset", "git clean -dxf", "cd ../"], + ); + assert!(is_strictly_sorted(loader.buf.iter().map(|h| h.timestamp))) + } + + fn is_strictly_sorted<T>(iter: impl IntoIterator<Item = T>) -> bool + where + T: Clone + PartialOrd, + { + iter.into_iter() + .tuple_windows() + .all(|(a, b)| matches!(a.partial_cmp(&b), Some(Ordering::Less))) + } +} diff --git a/crates/atuin-client/src/import/fish.rs b/crates/atuin-client/src/import/fish.rs new file mode 100644 index 00000000..714b2d01 --- /dev/null +++ b/crates/atuin-client/src/import/fish.rs @@ -0,0 +1,179 @@ +// import old shell history! +// automatically hoover up all that we can find + +use std::path::PathBuf; + +use async_trait::async_trait; +use directories::BaseDirs; +use eyre::{eyre, Result}; +use time::OffsetDateTime; + +use super::{unix_byte_lines, Importer, Loader}; +use crate::history::History; +use crate::import::read_to_end; + +#[derive(Debug)] +pub struct Fish { + bytes: Vec<u8>, +} + +/// see https://fishshell.com/docs/current/interactive.html#searchable-command-history +fn default_histpath() -> Result<PathBuf> { + let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; + let data = std::env::var("XDG_DATA_HOME").map_or_else( + |_| base.home_dir().join(".local").join("share"), + PathBuf::from, + ); + + // fish supports multiple history sessions + // If `fish_history` var is missing, or set to `default`, use `fish` as the session + let session = std::env::var("fish_history").unwrap_or_else(|_| String::from("fish")); + let session = if session == "default" { + String::from("fish") + } else { + session + }; + + let mut histpath = data.join("fish"); + histpath.push(format!("{session}_history")); 
+ + if histpath.exists() { + Ok(histpath) + } else { + Err(eyre!("Could not find history file.")) + } +} + +#[async_trait] +impl Importer for Fish { + const NAME: &'static str = "fish"; + + async fn new() -> Result<Self> { + let bytes = read_to_end(default_histpath()?)?; + Ok(Self { bytes }) + } + + async fn entries(&mut self) -> Result<usize> { + Ok(super::count_lines(&self.bytes)) + } + + async fn load(self, loader: &mut impl Loader) -> Result<()> { + let now = OffsetDateTime::now_utc(); + let mut time: Option<OffsetDateTime> = None; + let mut cmd: Option<String> = None; + + for b in unix_byte_lines(&self.bytes) { + let s = match std::str::from_utf8(b) { + Ok(s) => s, + Err(_) => continue, // we can skip past things like invalid utf8 + }; + + if let Some(c) = s.strip_prefix("- cmd: ") { + // first, we must deal with the prev cmd + if let Some(cmd) = cmd.take() { + let time = time.unwrap_or(now); + let entry = History::import().timestamp(time).command(cmd); + + loader.push(entry.build().into()).await?; + } + + // using raw strings to avoid needing escaping. + // replaces double backslashes with single backslashes + let c = c.replace(r"\\", r"\"); + // replaces escaped newlines + let c = c.replace(r"\n", "\n"); + // TODO: any other escape characters? + + cmd = Some(c); + } else if let Some(t) = s.strip_prefix(" when: ") { + // if t is not an int, just ignore this line + if let Ok(t) = t.parse::<i64>() { + time = Some(OffsetDateTime::from_unix_timestamp(t)?); + } + } else { + // ... ignore paths lines + } + } + + // we might have a trailing cmd + if let Some(cmd) = cmd.take() { + let time = time.unwrap_or(now); + let entry = History::import().timestamp(time).command(cmd); + + loader.push(entry.build().into()).await?; + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + + use crate::import::{tests::TestLoader, Importer}; + + use super::Fish; + + #[tokio::test] + async fn parse_complex() { + // complicated input with varying contents and escaped strings. 
+ let bytes = r#"- cmd: history --help + when: 1639162832 +- cmd: cat ~/.bash_history + when: 1639162851 + paths: + - ~/.bash_history +- cmd: ls ~/.local/share/fish/fish_history + when: 1639162890 + paths: + - ~/.local/share/fish/fish_history +- cmd: cat ~/.local/share/fish/fish_history + when: 1639162893 + paths: + - ~/.local/share/fish/fish_history +ERROR +- CORRUPTED: ENTRY + CONTINUE: + - AS + - NORMAL +- cmd: echo "foo" \\\n'bar' baz + when: 1639162933 +- cmd: cat ~/.local/share/fish/fish_history + when: 1639162939 + paths: + - ~/.local/share/fish/fish_history +- cmd: echo "\\"" \\\\ "\\\\" + when: 1639163063 +- cmd: cat ~/.local/share/fish/fish_history + when: 1639163066 + paths: + - ~/.local/share/fish/fish_history +"# + .as_bytes() + .to_owned(); + + let fish = Fish { bytes }; + + let mut loader = TestLoader::default(); + fish.load(&mut loader).await.unwrap(); + let mut history = loader.buf.into_iter(); + + // simple wrapper for fish history entry + macro_rules! fishtory { + ($timestamp:expr, $command:expr) => { + let h = history.next().expect("missing entry in history"); + assert_eq!(h.command.as_str(), $command); + assert_eq!(h.timestamp.unix_timestamp(), $timestamp); + }; + } + + fishtory!(1639162832, "history --help"); + fishtory!(1639162851, "cat ~/.bash_history"); + fishtory!(1639162890, "ls ~/.local/share/fish/fish_history"); + fishtory!(1639162893, "cat ~/.local/share/fish/fish_history"); + fishtory!(1639162933, "echo \"foo\" \\\n'bar' baz"); + fishtory!(1639162939, "cat ~/.local/share/fish/fish_history"); + fishtory!(1639163063, r#"echo "\"" \\ "\\""#); + fishtory!(1639163066, "cat ~/.local/share/fish/fish_history"); + } +} diff --git a/crates/atuin-client/src/import/mod.rs b/crates/atuin-client/src/import/mod.rs new file mode 100644 index 00000000..c9d8c798 --- /dev/null +++ b/crates/atuin-client/src/import/mod.rs @@ -0,0 +1,111 @@ +use std::fs::File; +use std::io::Read; +use std::path::PathBuf; + +use async_trait::async_trait; +use eyre::{bail, 
Result}; +use memchr::Memchr; + +use crate::history::History; + +pub mod bash; +pub mod fish; +pub mod nu; +pub mod nu_histdb; +pub mod resh; +pub mod xonsh; +pub mod xonsh_sqlite; +pub mod zsh; +pub mod zsh_histdb; + +#[async_trait] +pub trait Importer: Sized { + const NAME: &'static str; + async fn new() -> Result<Self>; + async fn entries(&mut self) -> Result<usize>; + async fn load(self, loader: &mut impl Loader) -> Result<()>; +} + +#[async_trait] +pub trait Loader: Sync + Send { + async fn push(&mut self, hist: History) -> eyre::Result<()>; +} + +fn unix_byte_lines(input: &[u8]) -> impl Iterator<Item = &[u8]> { + UnixByteLines { + iter: memchr::memchr_iter(b'\n', input), + bytes: input, + i: 0, + } +} + +struct UnixByteLines<'a> { + iter: Memchr<'a>, + bytes: &'a [u8], + i: usize, +} + +impl<'a> Iterator for UnixByteLines<'a> { + type Item = &'a [u8]; + + fn next(&mut self) -> Option<Self::Item> { + let j = self.iter.next()?; + let out = &self.bytes[self.i..j]; + self.i = j + 1; + Some(out) + } + + fn count(self) -> usize + where + Self: Sized, + { + self.iter.count() + } +} + +fn count_lines(input: &[u8]) -> usize { + unix_byte_lines(input).count() +} + +fn get_histpath<D>(def: D) -> Result<PathBuf> +where + D: FnOnce() -> Result<PathBuf>, +{ + if let Ok(p) = std::env::var("HISTFILE") { + is_file(PathBuf::from(p)) + } else { + is_file(def()?) + } +} + +fn read_to_end(path: PathBuf) -> Result<Vec<u8>> { + let mut bytes = Vec::new(); + let mut f = File::open(path)?; + f.read_to_end(&mut bytes)?; + Ok(bytes) +} +fn is_file(p: PathBuf) -> Result<PathBuf> { + if p.is_file() { + Ok(p) + } else { + bail!("Could not find history file {:?}. 
Try setting $HISTFILE", p) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Default)] + pub struct TestLoader { + pub buf: Vec<History>, + } + + #[async_trait] + impl Loader for TestLoader { + async fn push(&mut self, hist: History) -> Result<()> { + self.buf.push(hist); + Ok(()) + } + } +} diff --git a/crates/atuin-client/src/import/nu.rs b/crates/atuin-client/src/import/nu.rs new file mode 100644 index 00000000..a45d83c5 --- /dev/null +++ b/crates/atuin-client/src/import/nu.rs @@ -0,0 +1,67 @@ +// import old shell history! +// automatically hoover up all that we can find + +use std::path::PathBuf; + +use async_trait::async_trait; +use directories::BaseDirs; +use eyre::{eyre, Result}; +use time::OffsetDateTime; + +use super::{unix_byte_lines, Importer, Loader}; +use crate::history::History; +use crate::import::read_to_end; + +#[derive(Debug)] +pub struct Nu { + bytes: Vec<u8>, +} + +fn get_histpath() -> Result<PathBuf> { + let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; + let config_dir = base.config_dir().join("nushell"); + + let histpath = config_dir.join("history.txt"); + if histpath.exists() { + Ok(histpath) + } else { + Err(eyre!("Could not find history file.")) + } +} + +#[async_trait] +impl Importer for Nu { + const NAME: &'static str = "nu"; + + async fn new() -> Result<Self> { + let bytes = read_to_end(get_histpath()?)?; + Ok(Self { bytes }) + } + + async fn entries(&mut self) -> Result<usize> { + Ok(super::count_lines(&self.bytes)) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + let now = OffsetDateTime::now_utc(); + + let mut counter = 0; + for b in unix_byte_lines(&self.bytes) { + let s = match std::str::from_utf8(b) { + Ok(s) => s, + Err(_) => continue, // we can skip past things like invalid utf8 + }; + + let cmd: String = s.replace("<\\n>", "\n"); + + let offset = time::Duration::nanoseconds(counter); + counter += 1; + + let entry = History::import().timestamp(now - 
offset).command(cmd); + + h.push(entry.build().into()).await?; + } + + Ok(()) + } +} diff --git a/crates/atuin-client/src/import/nu_histdb.rs b/crates/atuin-client/src/import/nu_histdb.rs new file mode 100644 index 00000000..f0e8e95c --- /dev/null +++ b/crates/atuin-client/src/import/nu_histdb.rs @@ -0,0 +1,113 @@ +// import old shell history! +// automatically hoover up all that we can find + +use std::path::PathBuf; + +use async_trait::async_trait; +use directories::BaseDirs; +use eyre::{eyre, Result}; +use sqlx::{sqlite::SqlitePool, Pool}; +use time::{Duration, OffsetDateTime}; + +use super::Importer; +use crate::history::History; +use crate::import::Loader; + +#[derive(sqlx::FromRow, Debug)] +pub struct HistDbEntry { + pub id: i64, + pub command_line: Vec<u8>, + pub start_timestamp: i64, + pub session_id: i64, + pub hostname: Vec<u8>, + pub cwd: Vec<u8>, + pub duration_ms: i64, + pub exit_status: i64, + pub more_info: Vec<u8>, +} + +impl From<HistDbEntry> for History { + fn from(histdb_item: HistDbEntry) -> Self { + let ts_secs = histdb_item.start_timestamp / 1000; + let ts_ns = (histdb_item.start_timestamp % 1000) * 1_000_000; + let imported = History::import() + .timestamp( + OffsetDateTime::from_unix_timestamp(ts_secs).unwrap() + + Duration::nanoseconds(ts_ns), + ) + .command(String::from_utf8(histdb_item.command_line).unwrap()) + .cwd(String::from_utf8(histdb_item.cwd).unwrap()) + .exit(histdb_item.exit_status) + .duration(histdb_item.duration_ms) + .session(format!("{:x}", histdb_item.session_id)) + .hostname(String::from_utf8(histdb_item.hostname).unwrap()); + + imported.build().into() + } +} + +#[derive(Debug)] +pub struct NuHistDb { + histdb: Vec<HistDbEntry>, +} + +/// Read db at given file, return vector of entries. 
+async fn hist_from_db(dbpath: PathBuf) -> Result<Vec<HistDbEntry>> { + let pool = SqlitePool::connect(dbpath.to_str().unwrap()).await?; + hist_from_db_conn(pool).await +} + +async fn hist_from_db_conn(pool: Pool<sqlx::Sqlite>) -> Result<Vec<HistDbEntry>> { + let query = r#" + SELECT + id, command_line, start_timestamp, session_id, hostname, cwd, duration_ms, exit_status, + more_info + FROM history + ORDER BY start_timestamp + "#; + let histdb_vec: Vec<HistDbEntry> = sqlx::query_as::<_, HistDbEntry>(query) + .fetch_all(&pool) + .await?; + Ok(histdb_vec) +} + +impl NuHistDb { + pub fn histpath() -> Result<PathBuf> { + let base = BaseDirs::new().ok_or_else(|| eyre!("could not determine data directory"))?; + let config_dir = base.config_dir().join("nushell"); + + let histdb_path = config_dir.join("history.sqlite3"); + if histdb_path.exists() { + Ok(histdb_path) + } else { + Err(eyre!("Could not find history file.")) + } + } +} + +#[async_trait] +impl Importer for NuHistDb { + // Not sure how this is used + const NAME: &'static str = "nu_histdb"; + + /// Creates a new NuHistDb and populates the history based on the pre-populated data + /// structure. 
+ async fn new() -> Result<Self> { + let dbpath = NuHistDb::histpath()?; + let histdb_entry_vec = hist_from_db(dbpath).await?; + Ok(Self { + histdb: histdb_entry_vec, + }) + } + + async fn entries(&mut self) -> Result<usize> { + Ok(self.histdb.len()) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + for i in self.histdb { + h.push(i.into()).await?; + } + Ok(()) + } +} diff --git a/crates/atuin-client/src/import/resh.rs b/crates/atuin-client/src/import/resh.rs new file mode 100644 index 00000000..396d11fd --- /dev/null +++ b/crates/atuin-client/src/import/resh.rs @@ -0,0 +1,140 @@ +use std::path::PathBuf; + +use async_trait::async_trait; +use directories::UserDirs; +use eyre::{eyre, Result}; +use serde::Deserialize; + +use atuin_common::utils::uuid_v7; +use time::OffsetDateTime; + +use super::{get_histpath, unix_byte_lines, Importer, Loader}; +use crate::history::History; +use crate::import::read_to_end; + +#[derive(Deserialize, Debug)] +#[serde(rename_all = "camelCase")] +pub struct ReshEntry { + pub cmd_line: String, + pub exit_code: i64, + pub shell: String, + pub uname: String, + pub session_id: String, + pub home: String, + pub lang: String, + pub lc_all: String, + pub login: String, + pub pwd: String, + pub pwd_after: String, + pub shell_env: String, + pub term: String, + pub real_pwd: String, + pub real_pwd_after: String, + pub pid: i64, + pub session_pid: i64, + pub host: String, + pub hosttype: String, + pub ostype: String, + pub machtype: String, + pub shlvl: i64, + pub timezone_before: String, + pub timezone_after: String, + pub realtime_before: f64, + pub realtime_after: f64, + pub realtime_before_local: f64, + pub realtime_after_local: f64, + pub realtime_duration: f64, + pub realtime_since_session_start: f64, + pub realtime_since_boot: f64, + pub git_dir: String, + pub git_real_dir: String, + pub git_origin_remote: String, + pub git_dir_after: String, + pub git_real_dir_after: String, + pub git_origin_remote_after: String, + pub 
machine_id: String, + pub os_release_id: String, + pub os_release_version_id: String, + pub os_release_id_like: String, + pub os_release_name: String, + pub os_release_pretty_name: String, + pub resh_uuid: String, + pub resh_version: String, + pub resh_revision: String, + pub parts_merged: bool, + pub recalled: bool, + pub recall_last_cmd_line: String, + pub cols: String, + pub lines: String, +} + +#[derive(Debug)] +pub struct Resh { + bytes: Vec<u8>, +} + +fn default_histpath() -> Result<PathBuf> { + let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; + let home_dir = user_dirs.home_dir(); + + Ok(home_dir.join(".resh_history.json")) +} + +#[async_trait] +impl Importer for Resh { + const NAME: &'static str = "resh"; + + async fn new() -> Result<Self> { + let bytes = read_to_end(get_histpath(default_histpath)?)?; + Ok(Self { bytes }) + } + + async fn entries(&mut self) -> Result<usize> { + Ok(super::count_lines(&self.bytes)) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + for b in unix_byte_lines(&self.bytes) { + let s = match std::str::from_utf8(b) { + Ok(s) => s, + Err(_) => continue, // we can skip past things like invalid utf8 + }; + let entry = match serde_json::from_str::<ReshEntry>(s) { + Ok(e) => e, + Err(_) => continue, // skip invalid json :shrug: + }; + + #[allow(clippy::cast_possible_truncation)] + #[allow(clippy::cast_sign_loss)] + let timestamp = { + let secs = entry.realtime_before.floor() as i64; + let nanosecs = (entry.realtime_before.fract() * 1_000_000_000_f64).round() as i64; + OffsetDateTime::from_unix_timestamp(secs)? + time::Duration::nanoseconds(nanosecs) + }; + #[allow(clippy::cast_possible_truncation)] + #[allow(clippy::cast_sign_loss)] + let duration = { + let secs = entry.realtime_after.floor() as i64; + let nanosecs = (entry.realtime_after.fract() * 1_000_000_000_f64).round() as i64; + let base = OffsetDateTime::from_unix_timestamp(secs)? 
+ + time::Duration::nanoseconds(nanosecs); + let difference = base - timestamp; + difference.whole_nanoseconds() as i64 + }; + + let imported = History::import() + .command(entry.cmd_line) + .timestamp(timestamp) + .duration(duration) + .exit(entry.exit_code) + .cwd(entry.pwd) + .hostname(entry.host) + // CHECK: should we add uuid here? It's not set in the other importers + .session(uuid_v7().as_simple().to_string()); + + h.push(imported.build().into()).await?; + } + + Ok(()) + } +} diff --git a/crates/atuin-client/src/import/xonsh.rs b/crates/atuin-client/src/import/xonsh.rs new file mode 100644 index 00000000..19ce4cf6 --- /dev/null +++ b/crates/atuin-client/src/import/xonsh.rs @@ -0,0 +1,233 @@ +use std::env; +use std::fs::{self, File}; +use std::path::{Path, PathBuf}; + +use async_trait::async_trait; +use directories::BaseDirs; +use eyre::{eyre, Result}; +use serde::Deserialize; +use time::OffsetDateTime; +use uuid::timestamp::{context::NoContext, Timestamp}; +use uuid::Uuid; + +use super::{get_histpath, Importer, Loader}; +use crate::history::History; +use crate::utils::get_host_user; + +// Note: both HistoryFile and HistoryData have other keys present in the JSON, we don't +// care about them so we leave them unspecified so as to avoid deserializing unnecessarily. 
+#[derive(Debug, Deserialize)] +struct HistoryFile { + data: HistoryData, +} + +#[derive(Debug, Deserialize)] +struct HistoryData { + sessionid: String, + cmds: Vec<HistoryCmd>, +} + +#[derive(Debug, Deserialize)] +struct HistoryCmd { + cwd: String, + inp: String, + rtn: Option<i64>, + ts: (f64, f64), +} + +#[derive(Debug)] +pub struct Xonsh { + // history is stored as a bunch of json files, one per session + sessions: Vec<HistoryData>, + hostname: String, +} + +fn xonsh_hist_dir(xonsh_data_dir: Option<String>) -> Result<PathBuf> { + // if running within xonsh, this will be available + if let Some(d) = xonsh_data_dir { + let mut path = PathBuf::from(d); + path.push("history_json"); + return Ok(path); + } + + // otherwise, fall back to default + let base = BaseDirs::new().ok_or_else(|| eyre!("Could not determine home directory"))?; + + let hist_dir = base.data_dir().join("xonsh/history_json"); + if hist_dir.exists() || cfg!(test) { + Ok(hist_dir) + } else { + Err(eyre!("Could not find xonsh history files")) + } +} + +fn load_sessions(hist_dir: &Path) -> Result<Vec<HistoryData>> { + let mut sessions = vec![]; + for entry in fs::read_dir(hist_dir)? { + let p = entry?.path(); + let ext = p.extension().and_then(|e| e.to_str()); + if p.is_file() && ext == Some("json") { + if let Some(data) = load_session(&p)? 
{ + sessions.push(data); + } + } + } + Ok(sessions) +} + +fn load_session(path: &Path) -> Result<Option<HistoryData>> { + let file = File::open(path)?; + // empty files are not valid json, so we can't deserialize them + if file.metadata()?.len() == 0 { + return Ok(None); + } + + let mut hist_file: HistoryFile = serde_json::from_reader(file)?; + + // if there are commands in this session, replace the existing UUIDv4 + // with a UUIDv7 generated from the timestamp of the first command + if let Some(cmd) = hist_file.data.cmds.first() { + let seconds = cmd.ts.0.trunc() as u64; + let nanos = (cmd.ts.0.fract() * 1_000_000_000_f64) as u32; + let ts = Timestamp::from_unix(NoContext, seconds, nanos); + hist_file.data.sessionid = Uuid::new_v7(ts).to_string(); + } + Ok(Some(hist_file.data)) +} + +#[async_trait] +impl Importer for Xonsh { + const NAME: &'static str = "xonsh"; + + async fn new() -> Result<Self> { + // wrap xonsh-specific path resolver in general one so that it respects $HISTPATH + let xonsh_data_dir = env::var("XONSH_DATA_DIR").ok(); + let hist_dir = get_histpath(|| xonsh_hist_dir(xonsh_data_dir))?; + let sessions = load_sessions(&hist_dir)?; + let hostname = get_host_user(); + Ok(Xonsh { sessions, hostname }) + } + + async fn entries(&mut self) -> Result<usize> { + let total = self.sessions.iter().map(|s| s.cmds.len()).sum(); + Ok(total) + } + + async fn load(self, loader: &mut impl Loader) -> Result<()> { + for session in self.sessions { + for cmd in session.cmds { + let (start, end) = cmd.ts; + let ts_nanos = (start * 1_000_000_000_f64) as i128; + let timestamp = OffsetDateTime::from_unix_timestamp_nanos(ts_nanos)?; + + let duration = (end - start) * 1_000_000_000_f64; + + match cmd.rtn { + Some(exit) => { + let entry = History::import() + .timestamp(timestamp) + .duration(duration.trunc() as i64) + .exit(exit) + .command(cmd.inp.trim()) + .cwd(cmd.cwd) + .session(session.sessionid.clone()) + .hostname(self.hostname.clone()); + 
loader.push(entry.build().into()).await?; + } + None => { + let entry = History::import() + .timestamp(timestamp) + .duration(duration.trunc() as i64) + .command(cmd.inp.trim()) + .cwd(cmd.cwd) + .session(session.sessionid.clone()) + .hostname(self.hostname.clone()); + loader.push(entry.build().into()).await?; + } + } + } + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use time::macros::datetime; + + use super::*; + + use crate::history::History; + use crate::import::tests::TestLoader; + + #[test] + fn test_hist_dir_xonsh() { + let hist_dir = xonsh_hist_dir(Some("/home/user/xonsh_data".to_string())).unwrap(); + assert_eq!( + hist_dir, + PathBuf::from("/home/user/xonsh_data/history_json") + ); + } + + #[tokio::test] + async fn test_import() { + let dir = PathBuf::from("tests/data/xonsh"); + let sessions = load_sessions(&dir).unwrap(); + let hostname = "box:user".to_string(); + let xonsh = Xonsh { sessions, hostname }; + + let mut loader = TestLoader::default(); + xonsh.load(&mut loader).await.unwrap(); + // order in buf will depend on filenames, so sort by timestamp for consistency + loader.buf.sort_by_key(|h| h.timestamp); + for (actual, expected) in loader.buf.iter().zip(expected_hist_entries().iter()) { + assert_eq!(actual.timestamp, expected.timestamp); + assert_eq!(actual.command, expected.command); + assert_eq!(actual.cwd, expected.cwd); + assert_eq!(actual.exit, expected.exit); + assert_eq!(actual.duration, expected.duration); + assert_eq!(actual.hostname, expected.hostname); + } + } + + fn expected_hist_entries() -> [History; 4] { + [ + History::import() + .timestamp(datetime!(2024-02-6 04:17:59.478272256 +00:00:00)) + .command("echo hello world!".to_string()) + .cwd("/home/user/Documents/code/atuin".to_string()) + .exit(0) + .duration(4651069) + .hostname("box:user".to_string()) + .build() + .into(), + History::import() + .timestamp(datetime!(2024-02-06 04:18:01.70632832 +00:00:00)) + .command("ls -l".to_string()) + 
.cwd("/home/user/Documents/code/atuin".to_string()) + .exit(0) + .duration(21288633) + .hostname("box:user".to_string()) + .build() + .into(), + History::import() + .timestamp(datetime!(2024-02-06 17:41:31.142515968 +00:00:00)) + .command("false".to_string()) + .cwd("/home/user/Documents/code/atuin/atuin-client".to_string()) + .exit(1) + .duration(10269403) + .hostname("box:user".to_string()) + .build() + .into(), + History::import() + .timestamp(datetime!(2024-02-06 17:41:32.271584 +00:00:00)) + .command("exit".to_string()) + .cwd("/home/user/Documents/code/atuin/atuin-client".to_string()) + .exit(0) + .duration(4259347) + .hostname("box:user".to_string()) + .build() + .into(), + ] + } +} diff --git a/crates/atuin-client/src/import/xonsh_sqlite.rs b/crates/atuin-client/src/import/xonsh_sqlite.rs new file mode 100644 index 00000000..2817dc63 --- /dev/null +++ b/crates/atuin-client/src/import/xonsh_sqlite.rs @@ -0,0 +1,217 @@ +use std::env; +use std::path::PathBuf; + +use async_trait::async_trait; +use directories::BaseDirs; +use eyre::{eyre, Result}; +use futures::TryStreamExt; +use sqlx::{sqlite::SqlitePool, FromRow, Row}; +use time::OffsetDateTime; +use uuid::timestamp::{context::NoContext, Timestamp}; +use uuid::Uuid; + +use super::{get_histpath, Importer, Loader}; +use crate::history::History; +use crate::utils::get_host_user; + +#[derive(Debug, FromRow)] +struct HistDbEntry { + inp: String, + rtn: Option<i64>, + tsb: f64, + tse: f64, + cwd: String, + session_start: f64, +} + +impl HistDbEntry { + fn into_hist_with_hostname(self, hostname: String) -> History { + let ts_nanos = (self.tsb * 1_000_000_000_f64) as i128; + let timestamp = OffsetDateTime::from_unix_timestamp_nanos(ts_nanos).unwrap(); + + let session_ts_seconds = self.session_start.trunc() as u64; + let session_ts_nanos = (self.session_start.fract() * 1_000_000_000_f64) as u32; + let session_ts = Timestamp::from_unix(NoContext, session_ts_seconds, session_ts_nanos); + let session_id = 
Uuid::new_v7(session_ts).to_string(); + let duration = (self.tse - self.tsb) * 1_000_000_000_f64; + + if let Some(exit) = self.rtn { + let imported = History::import() + .timestamp(timestamp) + .duration(duration.trunc() as i64) + .exit(exit) + .command(self.inp) + .cwd(self.cwd) + .session(session_id) + .hostname(hostname); + imported.build().into() + } else { + let imported = History::import() + .timestamp(timestamp) + .duration(duration.trunc() as i64) + .command(self.inp) + .cwd(self.cwd) + .session(session_id) + .hostname(hostname); + imported.build().into() + } + } +} + +fn xonsh_db_path(xonsh_data_dir: Option<String>) -> Result<PathBuf> { + // if running within xonsh, this will be available + if let Some(d) = xonsh_data_dir { + let mut path = PathBuf::from(d); + path.push("xonsh-history.sqlite"); + return Ok(path); + } + + // otherwise, fall back to default + let base = BaseDirs::new().ok_or_else(|| eyre!("Could not determine home directory"))?; + + let hist_file = base.data_dir().join("xonsh/xonsh-history.sqlite"); + if hist_file.exists() || cfg!(test) { + Ok(hist_file) + } else { + Err(eyre!( + "Could not find xonsh history db at: {}", + hist_file.to_string_lossy() + )) + } +} + +#[derive(Debug)] +pub struct XonshSqlite { + pool: SqlitePool, + hostname: String, +} + +#[async_trait] +impl Importer for XonshSqlite { + const NAME: &'static str = "xonsh_sqlite"; + + async fn new() -> Result<Self> { + // wrap xonsh-specific path resolver in general one so that it respects $HISTPATH + let xonsh_data_dir = env::var("XONSH_DATA_DIR").ok(); + let db_path = get_histpath(|| xonsh_db_path(xonsh_data_dir))?; + let connection_str = db_path.to_str().ok_or_else(|| { + eyre!( + "Invalid path for SQLite database: {}", + db_path.to_string_lossy() + ) + })?; + + let pool = SqlitePool::connect(connection_str).await?; + let hostname = get_host_user(); + Ok(XonshSqlite { pool, hostname }) + } + + async fn entries(&mut self) -> Result<usize> { + let query = "SELECT COUNT(*) FROM 
xonsh_history"; + let row = sqlx::query(query).fetch_one(&self.pool).await?; + let count: u32 = row.get(0); + Ok(count as usize) + } + + async fn load(self, loader: &mut impl Loader) -> Result<()> { + let query = r#" + SELECT inp, rtn, tsb, tse, cwd, + MIN(tsb) OVER (PARTITION BY sessionid) AS session_start + FROM xonsh_history + ORDER BY rowid + "#; + + let mut entries = sqlx::query_as::<_, HistDbEntry>(query).fetch(&self.pool); + + let mut count = 0; + while let Some(entry) = entries.try_next().await? { + let hist = entry.into_hist_with_hostname(self.hostname.clone()); + loader.push(hist).await?; + count += 1; + } + + println!("Loaded: {count}"); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use time::macros::datetime; + + use super::*; + + use crate::history::History; + use crate::import::tests::TestLoader; + + #[test] + fn test_db_path_xonsh() { + let db_path = xonsh_db_path(Some("/home/user/xonsh_data".to_string())).unwrap(); + assert_eq!( + db_path, + PathBuf::from("/home/user/xonsh_data/xonsh-history.sqlite") + ); + } + + #[tokio::test] + async fn test_import() { + let connection_str = "tests/data/xonsh-history.sqlite"; + let xonsh_sqlite = XonshSqlite { + pool: SqlitePool::connect(connection_str).await.unwrap(), + hostname: "box:user".to_string(), + }; + + let mut loader = TestLoader::default(); + xonsh_sqlite.load(&mut loader).await.unwrap(); + + for (actual, expected) in loader.buf.iter().zip(expected_hist_entries().iter()) { + assert_eq!(actual.timestamp, expected.timestamp); + assert_eq!(actual.command, expected.command); + assert_eq!(actual.cwd, expected.cwd); + assert_eq!(actual.exit, expected.exit); + assert_eq!(actual.duration, expected.duration); + assert_eq!(actual.hostname, expected.hostname); + } + } + + fn expected_hist_entries() -> [History; 4] { + [ + History::import() + .timestamp(datetime!(2024-02-6 17:56:21.130956288 +00:00:00)) + .command("echo hello world!".to_string()) + .cwd("/home/user/Documents/code/atuin".to_string()) + .exit(0) + 
.duration(2628564) + .hostname("box:user".to_string()) + .build() + .into(), + History::import() + .timestamp(datetime!(2024-02-06 17:56:28.190406144 +00:00:00)) + .command("ls -l".to_string()) + .cwd("/home/user/Documents/code/atuin".to_string()) + .exit(0) + .duration(9371519) + .hostname("box:user".to_string()) + .build() + .into(), + History::import() + .timestamp(datetime!(2024-02-06 17:56:46.989020928 +00:00:00)) + .command("false".to_string()) + .cwd("/home/user/Documents/code/atuin".to_string()) + .exit(1) + .duration(17337560) + .hostname("box:user".to_string()) + .build() + .into(), + History::import() + .timestamp(datetime!(2024-02-06 17:56:48.218384128 +00:00:00)) + .command("exit".to_string()) + .cwd("/home/user/Documents/code/atuin".to_string()) + .exit(0) + .duration(4599094) + .hostname("box:user".to_string()) + .build() + .into(), + ] + } +} diff --git a/crates/atuin-client/src/import/zsh.rs b/crates/atuin-client/src/import/zsh.rs new file mode 100644 index 00000000..5bc8fc16 --- /dev/null +++ b/crates/atuin-client/src/import/zsh.rs @@ -0,0 +1,229 @@ +// import old shell history! +// automatically hoover up all that we can find + +use std::borrow::Cow; +use std::path::PathBuf; + +use async_trait::async_trait; +use directories::UserDirs; +use eyre::{eyre, Result}; +use time::OffsetDateTime; + +use super::{get_histpath, unix_byte_lines, Importer, Loader}; +use crate::history::History; +use crate::import::read_to_end; + +#[derive(Debug)] +pub struct Zsh { + bytes: Vec<u8>, +} + +fn default_histpath() -> Result<PathBuf> { + // oh-my-zsh sets HISTFILE=~/.zhistory + // zsh has no default value for this var, but uses ~/.zhistory. 
+ // we could maybe be smarter about this in the future :) + let user_dirs = UserDirs::new().ok_or_else(|| eyre!("could not find user directories"))?; + let home_dir = user_dirs.home_dir(); + + let mut candidates = [".zhistory", ".zsh_history"].iter(); + loop { + match candidates.next() { + Some(candidate) => { + let histpath = home_dir.join(candidate); + if histpath.exists() { + break Ok(histpath); + } + } + None => { + break Err(eyre!( + "Could not find history file. Try setting and exporting $HISTFILE" + )) + } + } + } +} + +#[async_trait] +impl Importer for Zsh { + const NAME: &'static str = "zsh"; + + async fn new() -> Result<Self> { + let bytes = read_to_end(get_histpath(default_histpath)?)?; + Ok(Self { bytes }) + } + + async fn entries(&mut self) -> Result<usize> { + Ok(super::count_lines(&self.bytes)) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + let now = OffsetDateTime::now_utc(); + let mut line = String::new(); + + let mut counter = 0; + for b in unix_byte_lines(&self.bytes) { + let s = match unmetafy(b) { + Some(s) => s, + _ => continue, // we can skip past things like invalid utf8 + }; + + if let Some(s) = s.strip_suffix('\\') { + line.push_str(s); + line.push_str("\\\n"); + } else { + line.push_str(&s); + let command = std::mem::take(&mut line); + + if let Some(command) = command.strip_prefix(": ") { + counter += 1; + h.push(parse_extended(command, counter)).await?; + } else { + let offset = time::Duration::seconds(counter); + counter += 1; + + let imported = History::import() + // preserve ordering + .timestamp(now - offset) + .command(command.trim_end().to_string()); + + h.push(imported.build().into()).await?; + } + } + } + + Ok(()) + } +} + +fn parse_extended(line: &str, counter: i64) -> History { + let (time, duration) = line.split_once(':').unwrap(); + let (duration, command) = duration.split_once(';').unwrap(); + + let time = time + .parse::<i64>() + .ok() + .and_then(|t| OffsetDateTime::from_unix_timestamp(t).ok()) + 
.unwrap_or_else(OffsetDateTime::now_utc) + + time::Duration::milliseconds(counter); + + // use nanos, because why the hell not? we won't display them. + let duration = duration.parse::<i64>().map_or(-1, |t| t * 1_000_000_000); + + let imported = History::import() + .timestamp(time) + .command(command.trim_end().to_string()) + .duration(duration); + + imported.build().into() +} + +fn unmetafy(line: &[u8]) -> Option<Cow<str>> { + if line.contains(&0x83) { + let mut s = Vec::with_capacity(line.len()); + let mut is_meta = false; + for ch in line { + if *ch == 0x83 { + is_meta = true; + } else if is_meta { + is_meta = false; + s.push(*ch ^ 32); + } else { + s.push(*ch) + } + } + String::from_utf8(s).ok().map(Cow::Owned) + } else { + std::str::from_utf8(line).ok().map(Cow::Borrowed) + } +} + +#[cfg(test)] +mod test { + use itertools::assert_equal; + + use crate::import::tests::TestLoader; + + use super::*; + + #[test] + fn test_parse_extended_simple() { + let parsed = parse_extended("1613322469:0;cargo install atuin", 0); + + assert_eq!(parsed.command, "cargo install atuin"); + assert_eq!(parsed.duration, 0); + assert_eq!( + parsed.timestamp, + OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() + ); + + let parsed = parse_extended("1613322469:10;cargo install atuin;cargo update", 0); + + assert_eq!(parsed.command, "cargo install atuin;cargo update"); + assert_eq!(parsed.duration, 10_000_000_000); + assert_eq!( + parsed.timestamp, + OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() + ); + + let parsed = parse_extended("1613322469:10;cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", 0); + + assert_eq!(parsed.command, "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷"); + assert_eq!(parsed.duration, 10_000_000_000); + assert_eq!( + parsed.timestamp, + OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() + ); + + let parsed = parse_extended("1613322469:10;cargo install \\n atuin\n", 0); + + assert_eq!(parsed.command, "cargo install \\n atuin"); + assert_eq!(parsed.duration, 
10_000_000_000); + assert_eq!( + parsed.timestamp, + OffsetDateTime::from_unix_timestamp(1_613_322_469).unwrap() + ); + } + + #[tokio::test] + async fn test_parse_file() { + let bytes = r": 1613322469:0;cargo install atuin +: 1613322469:10;cargo install atuin; \ +cargo update +: 1613322469:10;cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷ +" + .as_bytes() + .to_owned(); + + let mut zsh = Zsh { bytes }; + assert_eq!(zsh.entries().await.unwrap(), 4); + + let mut loader = TestLoader::default(); + zsh.load(&mut loader).await.unwrap(); + + assert_equal( + loader.buf.iter().map(|h| h.command.as_str()), + [ + "cargo install atuin", + "cargo install atuin; \\\ncargo update", + "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", + ], + ); + } + + #[tokio::test] + async fn test_parse_metafied() { + let bytes = + b"echo \xe4\xbd\x83\x80\xe5\xa5\xbd\nls ~/\xe9\x83\xbf\xb3\xe4\xb9\x83\xb0\n".to_vec(); + + let mut zsh = Zsh { bytes }; + assert_eq!(zsh.entries().await.unwrap(), 2); + + let mut loader = TestLoader::default(); + zsh.load(&mut loader).await.unwrap(); + + assert_equal( + loader.buf.iter().map(|h| h.command.as_str()), + ["echo 你好", "ls ~/音乐"], + ); + } +} diff --git a/crates/atuin-client/src/import/zsh_histdb.rs b/crates/atuin-client/src/import/zsh_histdb.rs new file mode 100644 index 00000000..eb72baa3 --- /dev/null +++ b/crates/atuin-client/src/import/zsh_histdb.rs @@ -0,0 +1,247 @@ +// import old shell history from zsh-histdb! +// automatically hoover up all that we can find + +// As far as i can tell there are no version numbers in the histdb sqlite DB, so we're going based +// on the schema from 2022-05-01 +// +// I have run into some histories that will not import b/c of non UTF-8 characters. 
+// + +// +// An Example sqlite query for histdb data: +// +//id|session|command_id|place_id|exit_status|start_time|duration|id|argv|id|host|dir +// +// +// select +// history.id, +// history.start_time, +// places.host, +// places.dir, +// commands.argv +// from history +// left join commands on history.command_id = commands.id +// left join places on history.place_id = places.id ; +// +// CREATE TABLE history (id integer primary key autoincrement, +// session int, +// command_id int references commands (id), +// place_id int references places (id), +// exit_status int, +// start_time int, +// duration int); +// + +use std::collections::HashMap; +use std::path::{Path, PathBuf}; + +use async_trait::async_trait; +use atuin_common::utils::uuid_v7; +use directories::UserDirs; +use eyre::{eyre, Result}; +use sqlx::{sqlite::SqlitePool, Pool}; +use time::PrimitiveDateTime; + +use super::Importer; +use crate::history::History; +use crate::import::Loader; +use crate::utils::{get_hostname, get_username}; + +#[derive(sqlx::FromRow, Debug)] +pub struct HistDbEntryCount { + pub count: usize, +} + +#[derive(sqlx::FromRow, Debug)] +pub struct HistDbEntry { + pub id: i64, + pub start_time: PrimitiveDateTime, + pub host: Vec<u8>, + pub dir: Vec<u8>, + pub argv: Vec<u8>, + pub duration: i64, + pub exit_status: i64, + pub session: i64, +} + +#[derive(Debug)] +pub struct ZshHistDb { + histdb: Vec<HistDbEntry>, + username: String, +} + +/// Read db at given file, return vector of entries. 
+async fn hist_from_db(dbpath: PathBuf) -> Result<Vec<HistDbEntry>> { + let pool = SqlitePool::connect(dbpath.to_str().unwrap()).await?; + hist_from_db_conn(pool).await +} + +async fn hist_from_db_conn(pool: Pool<sqlx::Sqlite>) -> Result<Vec<HistDbEntry>> { + let query = r#" + SELECT + history.id, history.start_time, history.duration, places.host, places.dir, + commands.argv, history.exit_status, history.session + FROM history + LEFT JOIN commands ON history.command_id = commands.id + LEFT JOIN places ON history.place_id = places.id + ORDER BY history.start_time + "#; + let histdb_vec: Vec<HistDbEntry> = sqlx::query_as::<_, HistDbEntry>(query) + .fetch_all(&pool) + .await?; + Ok(histdb_vec) +} + +impl ZshHistDb { + pub fn histpath_candidate() -> PathBuf { + // By default histdb database is `${HOME}/.histdb/zsh-history.db` + // This can be modified by ${HISTDB_FILE} + // + // if [[ -z ${HISTDB_FILE} ]]; then + // typeset -g HISTDB_FILE="${HOME}/.histdb/zsh-history.db" + let user_dirs = UserDirs::new().unwrap(); // should catch error here? + let home_dir = user_dirs.home_dir(); + std::env::var("HISTDB_FILE") + .as_ref() + .map(|x| Path::new(x).to_path_buf()) + .unwrap_or_else(|_err| home_dir.join(".histdb/zsh-history.db")) + } + pub fn histpath() -> Result<PathBuf> { + let histdb_path = ZshHistDb::histpath_candidate(); + if histdb_path.exists() { + Ok(histdb_path) + } else { + Err(eyre!( + "Could not find history file. Try setting $HISTDB_FILE" + )) + } + } +} + +#[async_trait] +impl Importer for ZshHistDb { + // Not sure how this is used + const NAME: &'static str = "zsh_histdb"; + + /// Creates a new ZshHistDb and populates the history based on the pre-populated data + /// structure. 
+ async fn new() -> Result<Self> { + let dbpath = ZshHistDb::histpath()?; + let histdb_entry_vec = hist_from_db(dbpath).await?; + Ok(Self { + histdb: histdb_entry_vec, + username: get_username(), + }) + } + + async fn entries(&mut self) -> Result<usize> { + Ok(self.histdb.len()) + } + + async fn load(self, h: &mut impl Loader) -> Result<()> { + let mut session_map = HashMap::new(); + for entry in self.histdb { + let command = match std::str::from_utf8(&entry.argv) { + Ok(s) => s.trim_end(), + Err(_) => continue, // we can skip past things like invalid utf8 + }; + let cwd = match std::str::from_utf8(&entry.dir) { + Ok(s) => s.trim_end(), + Err(_) => continue, // we can skip past things like invalid utf8 + }; + let hostname = format!( + "{}:{}", + String::from_utf8(entry.host).unwrap_or_else(|_e| get_hostname()), + self.username + ); + let session = session_map.entry(entry.session).or_insert_with(uuid_v7); + + let imported = History::import() + .timestamp(entry.start_time.assume_utc()) + .command(command) + .cwd(cwd) + .duration(entry.duration * 1_000_000_000) + .exit(entry.exit_status) + .session(session.as_simple().to_string()) + .hostname(hostname) + .build(); + h.push(imported.into()).await?; + } + Ok(()) + } +} + +#[cfg(test)] +mod test { + + use super::*; + use sqlx::sqlite::SqlitePoolOptions; + use std::env; + #[tokio::test(flavor = "multi_thread")] + async fn test_env_vars() { + let test_env_db = "nonstd-zsh-history.db"; + let key = "HISTDB_FILE"; + env::set_var(key, test_env_db); + + // test the env got set + assert_eq!(env::var(key).unwrap(), test_env_db.to_string()); + + // test histdb returns the proper db from previous step + let histdb_path = ZshHistDb::histpath_candidate(); + assert_eq!(histdb_path.to_str().unwrap(), test_env_db); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_import() { + let pool: SqlitePool = SqlitePoolOptions::new() + .min_connections(2) + .connect(":memory:") + .await + .unwrap(); + + // sql dump directly from a 
test database. + let db_sql = r#" + PRAGMA foreign_keys=OFF; + BEGIN TRANSACTION; + CREATE TABLE commands (id integer primary key autoincrement, argv text, unique(argv) on conflict ignore); + INSERT INTO commands VALUES(1,'pwd'); + INSERT INTO commands VALUES(2,'curl google.com'); + INSERT INTO commands VALUES(3,'bash'); + CREATE TABLE places (id integer primary key autoincrement, host text, dir text, unique(host, dir) on conflict ignore); + INSERT INTO places VALUES(1,'mbp16.local','/home/noyez'); + CREATE TABLE history (id integer primary key autoincrement, + session int, + command_id int references commands (id), + place_id int references places (id), + exit_status int, + start_time int, + duration int); + INSERT INTO history VALUES(1,0,1,1,0,1651497918,1); + INSERT INTO history VALUES(2,0,2,1,0,1651497923,1); + INSERT INTO history VALUES(3,0,3,1,NULL,1651497930,NULL); + DELETE FROM sqlite_sequence; + INSERT INTO sqlite_sequence VALUES('commands',3); + INSERT INTO sqlite_sequence VALUES('places',3); + INSERT INTO sqlite_sequence VALUES('history',3); + CREATE INDEX hist_time on history(start_time); + CREATE INDEX place_dir on places(dir); + CREATE INDEX place_host on places(host); + CREATE INDEX history_command_place on history(command_id, place_id); + COMMIT; "#; + + sqlx::query(db_sql).execute(&pool).await.unwrap(); + + // test histdb iterator + let histdb_vec = hist_from_db_conn(pool).await.unwrap(); + let histdb = ZshHistDb { + histdb: histdb_vec, + username: get_username(), + }; + + println!("h: {:#?}", histdb.histdb); + println!("counter: {:?}", histdb.histdb.len()); + for i in histdb.histdb { + println!("{i:?}"); + } + } +} diff --git a/crates/atuin-client/src/kv.rs b/crates/atuin-client/src/kv.rs new file mode 100644 index 00000000..fb26cadc --- /dev/null +++ b/crates/atuin-client/src/kv.rs @@ -0,0 +1,265 @@ +use std::collections::BTreeMap; + +use atuin_common::record::{DecryptedData, Host, HostId}; +use eyre::{bail, ensure, eyre, Result}; +use 
serde::Deserialize; + +use crate::record::encryption::PASETO_V4; +use crate::record::store::Store; + +const KV_VERSION: &str = "v0"; +const KV_TAG: &str = "kv"; +const KV_VAL_MAX_LEN: usize = 100 * 1024; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct KvRecord { + pub namespace: String, + pub key: String, + pub value: String, +} + +impl KvRecord { + pub fn serialize(&self) -> Result<DecryptedData> { + use rmp::encode; + + let mut output = vec![]; + + // INFO: ensure this is updated when adding new fields + encode::write_array_len(&mut output, 3)?; + + encode::write_str(&mut output, &self.namespace)?; + encode::write_str(&mut output, &self.key)?; + encode::write_str(&mut output, &self.value)?; + + Ok(DecryptedData(output)) + } + + pub fn deserialize(data: &DecryptedData, version: &str) -> Result<Self> { + use rmp::decode; + + fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { + eyre!("{err:?}") + } + + match version { + KV_VERSION => { + let mut bytes = decode::Bytes::new(&data.0); + + let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; + ensure!(nfields == 3, "too many entries in v0 kv record"); + + let bytes = bytes.remaining_slice(); + + let (namespace, bytes) = + decode::read_str_from_slice(bytes).map_err(error_report)?; + let (key, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + let (value, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?; + + if !bytes.is_empty() { + bail!("trailing bytes in encoded kvrecord. 
malformed") + } + + Ok(KvRecord { + namespace: namespace.to_owned(), + key: key.to_owned(), + value: value.to_owned(), + }) + } + _ => { + bail!("unknown version {version:?}") + } + } + } +} + +#[derive(Debug, Clone, Deserialize)] +pub struct KvStore; + +impl Default for KvStore { + fn default() -> Self { + Self::new() + } +} + +impl KvStore { + // will want to init the actual kv store when that is done + pub fn new() -> KvStore { + KvStore {} + } + + pub async fn set( + &self, + store: &(impl Store + Send + Sync), + encryption_key: &[u8; 32], + host_id: HostId, + namespace: &str, + key: &str, + value: &str, + ) -> Result<()> { + if value.len() > KV_VAL_MAX_LEN { + return Err(eyre!( + "kv value too large: max len {} bytes", + KV_VAL_MAX_LEN + )); + } + + let record = KvRecord { + namespace: namespace.to_string(), + key: key.to_string(), + value: value.to_string(), + }; + + let bytes = record.serialize()?; + + let idx = store + .last(host_id, KV_TAG) + .await? + .map_or(0, |entry| entry.idx + 1); + + let record = atuin_common::record::Record::builder() + .host(Host::new(host_id)) + .version(KV_VERSION.to_string()) + .tag(KV_TAG.to_string()) + .idx(idx) + .data(bytes) + .build(); + + store + .push(&record.encrypt::<PASETO_V4>(encryption_key)) + .await?; + + Ok(()) + } + + // TODO: setup an actual kv store, rebuild func, and do not pass the main store in here as + // well. + pub async fn get( + &self, + store: &impl Store, + encryption_key: &[u8; 32], + namespace: &str, + key: &str, + ) -> Result<Option<KvRecord>> { + // TODO: don't rebuild every time... 
+ let map = self.build_kv(store, encryption_key).await?; + + let res = map.get(namespace); + + if let Some(ns) = res { + let value = ns.get(key); + + Ok(value.cloned()) + } else { + Ok(None) + } + } + + // Build a kv map out of the linked list kv store + // Map is Namespace -> Key -> Value + // TODO(ellie): "cache" this into a real kv structure, which we can + // use as a write-through cache to avoid constant rebuilds. + pub async fn build_kv( + &self, + store: &impl Store, + encryption_key: &[u8; 32], + ) -> Result<BTreeMap<String, BTreeMap<String, KvRecord>>> { + let mut map = BTreeMap::new(); + + // TODO: maybe don't load the entire tag into memory to build the kv + // we can be smart about it and only load values since the last build + // or, iterate/paginate + let tagged = store.all_tagged(KV_TAG).await?; + + // iterate through all tags and play each KV record at a time + // this is "last write wins" + // probably good enough for now, but revisit in future + for record in tagged { + let decrypted = match record.version.as_str() { + KV_VERSION => record.decrypt::<PASETO_V4>(encryption_key)?, + version => bail!("unknown version {version:?}"), + }; + + let kv = KvRecord::deserialize(&decrypted.data, KV_VERSION)?; + + let ns = map + .entry(kv.namespace.clone()) + .or_insert_with(BTreeMap::new); + + ns.insert(kv.key.clone(), kv); + } + + Ok(map) + } +} + +#[cfg(test)] +mod tests { + use crypto_secretbox::{KeyInit, XSalsa20Poly1305}; + use rand::rngs::OsRng; + + use crate::record::sqlite_store::{test_sqlite_store_timeout, SqliteStore}; + + use super::{KvRecord, KvStore, KV_VERSION}; + + #[test] + fn encode_decode() { + let kv = KvRecord { + namespace: "foo".to_owned(), + key: "bar".to_owned(), + value: "baz".to_owned(), + }; + let snapshot = [ + 0x93, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xa3, b'b', b'a', b'z', + ]; + + let encoded = kv.serialize().unwrap(); + let decoded = KvRecord::deserialize(&encoded, KV_VERSION).unwrap(); + + assert_eq!(encoded.0, 
&snapshot); + assert_eq!(decoded, kv); + } + + #[tokio::test] + async fn build_kv() { + let mut store = SqliteStore::new(":memory:", test_sqlite_store_timeout()) + .await + .unwrap(); + let kv = KvStore::new(); + let key: [u8; 32] = XSalsa20Poly1305::generate_key(&mut OsRng).into(); + let host_id = atuin_common::record::HostId(atuin_common::utils::uuid_v7()); + + kv.set(&mut store, &key, host_id, "test-kv", "foo", "bar") + .await + .unwrap(); + + kv.set(&mut store, &key, host_id, "test-kv", "1", "2") + .await + .unwrap(); + + let map = kv.build_kv(&store, &key).await.unwrap(); + + assert_eq!( + *map.get("test-kv") + .expect("map namespace not set") + .get("foo") + .expect("map key not set"), + KvRecord { + namespace: String::from("test-kv"), + key: String::from("foo"), + value: String::from("bar") + } + ); + + assert_eq!( + *map.get("test-kv") + .expect("map namespace not set") + .get("1") + .expect("map key not set"), + KvRecord { + namespace: String::from("test-kv"), + key: String::from("1"), + value: String::from("2") + } + ); + } +} diff --git a/crates/atuin-client/src/lib.rs b/crates/atuin-client/src/lib.rs new file mode 100644 index 00000000..66258af3 --- /dev/null +++ b/crates/atuin-client/src/lib.rs @@ -0,0 +1,21 @@ +#![forbid(unsafe_code)] + +#[macro_use] +extern crate log; + +#[cfg(feature = "sync")] +pub mod api_client; +#[cfg(feature = "sync")] +pub mod sync; + +pub mod database; +pub mod encryption; +pub mod history; +pub mod import; +pub mod kv; +pub mod ordering; +pub mod record; +pub mod secrets; +pub mod settings; + +mod utils; diff --git a/crates/atuin-client/src/ordering.rs b/crates/atuin-client/src/ordering.rs new file mode 100644 index 00000000..4e5ec84c --- /dev/null +++ b/crates/atuin-client/src/ordering.rs @@ -0,0 +1,32 @@ +use minspan::minspan; + +use super::{history::History, settings::SearchMode}; + +pub fn reorder_fuzzy(mode: SearchMode, query: &str, res: Vec<History>) -> Vec<History> { + match mode { + SearchMode::Fuzzy => 
reorder(query, |x| &x.command, res), + _ => res, + } +} + +fn reorder<F, A>(query: &str, f: F, res: Vec<A>) -> Vec<A> +where + F: Fn(&A) -> &String, + A: Clone, +{ + let mut r = res.clone(); + let qvec = &query.chars().collect(); + r.sort_by_cached_key(|h| { + // TODO for fzf search we should sum up scores for each matched term + let (from, to) = match minspan::span(qvec, &(f(h).chars().collect())) { + Some(x) => x, + // this is a little unfortunate: when we are asked to match a query that is found nowhere, + // we don't want to return a None, as the comparison behaviour would put the worst matches + // at the front. therefore, we'll return a set of indices that are one larger than the longest + // possible legitimate match. This is meaningless except as a comparison. + None => (0, res.len()), + }; + 1 + to - from + }); + r +} diff --git a/crates/atuin-client/src/record/encryption.rs b/crates/atuin-client/src/record/encryption.rs new file mode 100644 index 00000000..3ad3be66 --- /dev/null +++ b/crates/atuin-client/src/record/encryption.rs @@ -0,0 +1,373 @@ +use atuin_common::record::{ + AdditionalData, DecryptedData, EncryptedData, Encryption, HostId, RecordId, RecordIdx, +}; +use base64::{engine::general_purpose, Engine}; +use eyre::{ensure, Context, Result}; +use rusty_paserk::{Key, KeyId, Local, PieWrappedKey}; +use rusty_paseto::core::{ + ImplicitAssertion, Key as DataKey, Local as LocalPurpose, Paseto, PasetoNonce, Payload, V4, +}; +use serde::{Deserialize, Serialize}; + +/// Use PASETO V4 Local encryption using the additional data as an implicit assertion. +#[allow(non_camel_case_types)] +pub struct PASETO_V4; + +/* +Why do we use a random content-encryption key? +Originally I was planning on using a derived key for encryption based on additional data. +This would be a lot more secure than using the master key directly. + +However, there's an established norm of using a random key. 
This scheme might be otherwise known as +- client-side encryption +- envelope encryption +- key wrapping + +An HSM (Hardware Security Module) provider, e.g. AWS, Azure, GCP, or even a physical device like a YubiKey +will have some keys that they keep to themselves. These keys never leave their physical hardware. +If they never leave the hardware, then encrypting large amounts of data means giving them the data and waiting. +This is not a practical solution. Instead, generate a unique key for your data, encrypt that using your HSM +and then store that with your data. + +See + - <https://docs.aws.amazon.com/wellarchitected/latest/financial-services-industry-lens/use-envelope-encryption-with-customer-master-keys.html> + - <https://cloud.google.com/kms/docs/envelope-encryption> + - <https://learn.microsoft.com/en-us/azure/storage/blobs/client-side-encryption?tabs=dotnet#encryption-and-decryption-via-the-envelope-technique> + - <https://www.yubico.com/gb/product/yubihsm-2-fips/> + - <https://cheatsheetseries.owasp.org/cheatsheets/Cryptographic_Storage_Cheat_Sheet.html#encrypting-stored-keys> + +Why would we care? In the past we have received some requests for company solutions. If in future we can configure a +KMS service with little effort, then that would solve a lot of issues for their security team. + +Even for personal use, if a user is not comfortable with sharing keys between hosts, +GCP HSM costs $1/month and $0.03 per 10,000 key operations. Assuming an active user runs +1000 atuin records a day, that would only cost them $1 and 10 cents a month. + +Additionally, key rotations are much simpler using this scheme. Rotating a key is as simple as re-encrypting the CEK, and not the message contents. +This makes it very fast to rotate a key in bulk. + +For future reference, with asymmetric encryption, you can encrypt the CEK without the HSM's involvement, but decrypting +will need the HSM. 
This allows the encryption path to still be extremely fast (no network calls) but downloads/decryption +that happens in the background can make the network calls to the HSM +*/ + +impl Encryption for PASETO_V4 { + fn re_encrypt( + mut data: EncryptedData, + _ad: AdditionalData, + old_key: &[u8; 32], + new_key: &[u8; 32], + ) -> Result<EncryptedData> { + let cek = Self::decrypt_cek(data.content_encryption_key, old_key)?; + data.content_encryption_key = Self::encrypt_cek(cek, new_key); + Ok(data) + } + + fn encrypt(data: DecryptedData, ad: AdditionalData, key: &[u8; 32]) -> EncryptedData { + // generate a random key for this entry + // aka content-encryption-key (CEK) + let random_key = Key::<V4, Local>::new_os_random(); + + // encode the implicit assertions + let assertions = Assertions::from(ad).encode(); + + // build the payload and encrypt the token + let payload = serde_json::to_string(&AtuinPayload { + data: general_purpose::URL_SAFE_NO_PAD.encode(data.0), + }) + .expect("json encoding can't fail"); + let nonce = DataKey::<32>::try_new_random().expect("could not source from random"); + let nonce = PasetoNonce::<V4, LocalPurpose>::from(&nonce); + + let token = Paseto::<V4, LocalPurpose>::builder() + .set_payload(Payload::from(payload.as_str())) + .set_implicit_assertion(ImplicitAssertion::from(assertions.as_str())) + .try_encrypt(&random_key.into(), &nonce) + .expect("error encrypting atuin data"); + + EncryptedData { + data: token, + content_encryption_key: Self::encrypt_cek(random_key, key), + } + } + + fn decrypt(data: EncryptedData, ad: AdditionalData, key: &[u8; 32]) -> Result<DecryptedData> { + let token = data.data; + let cek = Self::decrypt_cek(data.content_encryption_key, key)?; + + // encode the implicit assertions + let assertions = Assertions::from(ad).encode(); + + // decrypt the payload with the footer and implicit assertions + let payload = Paseto::<V4, LocalPurpose>::try_decrypt( + &token, + &cek.into(), + None, + 
ImplicitAssertion::from(&*assertions), + ) + .context("could not decrypt entry")?; + + let payload: AtuinPayload = serde_json::from_str(&payload)?; + let data = general_purpose::URL_SAFE_NO_PAD.decode(payload.data)?; + Ok(DecryptedData(data)) + } +} + +impl PASETO_V4 { + fn decrypt_cek(wrapped_cek: String, key: &[u8; 32]) -> Result<Key<V4, Local>> { + let wrapping_key = Key::<V4, Local>::from_bytes(*key); + + // let wrapping_key = PasetoSymmetricKey::from(Key::from(key)); + + let AtuinFooter { kid, wpk } = serde_json::from_str(&wrapped_cek) + .context("wrapped cek did not contain the correct contents")?; + + // check that the wrapping key matches the required key to decrypt. + // In future, we could support multiple keys and use this key to + // look up the key rather than only allow one key. + // For now though we will only support the one key and key rotation will + // have to be a hard reset + let current_kid = wrapping_key.to_id(); + + ensure!( + current_kid == kid, + "attempting to decrypt with incorrect key. currently using {current_kid}, expecting {kid}" + ); + + // decrypt the random key + Ok(wpk.unwrap_key(&wrapping_key)?) + } + + fn encrypt_cek(cek: Key<V4, Local>, key: &[u8; 32]) -> String { + // aka key-encryption-key (KEK) + let wrapping_key = Key::<V4, Local>::from_bytes(*key); + + // wrap the random key so we can decrypt it later + let wrapped_cek = AtuinFooter { + wpk: cek.wrap_pie(&wrapping_key), + kid: wrapping_key.to_id(), + }; + serde_json::to_string(&wrapped_cek).expect("could not serialize wrapped cek") + } +} + +#[derive(Serialize, Deserialize)] +struct AtuinPayload { + data: String, +} + +#[derive(Serialize, Deserialize)] +/// Well-known footer claims for decrypting. This is not encrypted but is stored in the record. 
+/// <https://github.com/paseto-standard/paseto-spec/blob/master/docs/02-Implementation-Guide/04-Claims.md#optional-footer-claims> +struct AtuinFooter { + /// Wrapped key + wpk: PieWrappedKey<V4, Local>, + /// ID of the key which was used to wrap + kid: KeyId<V4, Local>, +} + +/// Used in the implicit assertions. This is not encrypted and not stored in the data blob. +// This cannot be changed, otherwise it breaks the authenticated encryption. +#[derive(Debug, Copy, Clone, Serialize)] +struct Assertions<'a> { + id: &'a RecordId, + idx: &'a RecordIdx, + version: &'a str, + tag: &'a str, + host: &'a HostId, +} + +impl<'a> From<AdditionalData<'a>> for Assertions<'a> { + fn from(ad: AdditionalData<'a>) -> Self { + Self { + id: ad.id, + version: ad.version, + tag: ad.tag, + host: ad.host, + idx: ad.idx, + } + } +} + +impl Assertions<'_> { + fn encode(&self) -> String { + serde_json::to_string(self).expect("could not serialize implicit assertions") + } +} + +#[cfg(test)] +mod tests { + use atuin_common::{ + record::{Host, Record}, + utils::uuid_v7, + }; + + use super::*; + + #[test] + fn round_trip() { + let key = Key::<V4, Local>::new_os_random(); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + version: "v0", + tag: "kv", + host: &HostId(uuid_v7()), + idx: &0, + }; + + let data = DecryptedData(vec![1, 2, 3, 4]); + + let encrypted = PASETO_V4::encrypt(data.clone(), ad, &key.to_bytes()); + let decrypted = PASETO_V4::decrypt(encrypted, ad, &key.to_bytes()).unwrap(); + assert_eq!(decrypted, data); + } + + #[test] + fn same_entry_different_output() { + let key = Key::<V4, Local>::new_os_random(); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + version: "v0", + tag: "kv", + host: &HostId(uuid_v7()), + idx: &0, + }; + + let data = DecryptedData(vec![1, 2, 3, 4]); + + let encrypted = PASETO_V4::encrypt(data.clone(), ad, &key.to_bytes()); + let encrypted2 = PASETO_V4::encrypt(data, ad, &key.to_bytes()); + + assert_ne!( + encrypted.data, encrypted2.data, + 
"re-encrypting the same contents should have different output due to key randomization" + ); + } + + #[test] + fn cannot_decrypt_different_key() { + let key = Key::<V4, Local>::new_os_random(); + let fake_key = Key::<V4, Local>::new_os_random(); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + version: "v0", + tag: "kv", + host: &HostId(uuid_v7()), + idx: &0, + }; + + let data = DecryptedData(vec![1, 2, 3, 4]); + + let encrypted = PASETO_V4::encrypt(data, ad, &key.to_bytes()); + let _ = PASETO_V4::decrypt(encrypted, ad, &fake_key.to_bytes()).unwrap_err(); + } + + #[test] + fn cannot_decrypt_different_id() { + let key = Key::<V4, Local>::new_os_random(); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + version: "v0", + tag: "kv", + host: &HostId(uuid_v7()), + idx: &0, + }; + + let data = DecryptedData(vec![1, 2, 3, 4]); + + let encrypted = PASETO_V4::encrypt(data, ad, &key.to_bytes()); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + ..ad + }; + let _ = PASETO_V4::decrypt(encrypted, ad, &key.to_bytes()).unwrap_err(); + } + + #[test] + fn re_encrypt_round_trip() { + let key1 = Key::<V4, Local>::new_os_random(); + let key2 = Key::<V4, Local>::new_os_random(); + + let ad = AdditionalData { + id: &RecordId(uuid_v7()), + version: "v0", + tag: "kv", + host: &HostId(uuid_v7()), + idx: &0, + }; + + let data = DecryptedData(vec![1, 2, 3, 4]); + + let encrypted1 = PASETO_V4::encrypt(data.clone(), ad, &key1.to_bytes()); + let encrypted2 = + PASETO_V4::re_encrypt(encrypted1.clone(), ad, &key1.to_bytes(), &key2.to_bytes()) + .unwrap(); + + // we only re-encrypt the content keys + assert_eq!(encrypted1.data, encrypted2.data); + assert_ne!( + encrypted1.content_encryption_key, + encrypted2.content_encryption_key + ); + + let decrypted = PASETO_V4::decrypt(encrypted2, ad, &key2.to_bytes()).unwrap(); + + assert_eq!(decrypted, data); + } + + #[test] + fn full_record_round_trip() { + let key = [0x55; 32]; + let record = Record::builder() + 
.id(RecordId(uuid_v7())) + .version("v0".to_owned()) + .tag("kv".to_owned()) + .host(Host::new(HostId(uuid_v7()))) + .timestamp(1687244806000000) + .data(DecryptedData(vec![1, 2, 3, 4])) + .idx(0) + .build(); + + let encrypted = record.encrypt::<PASETO_V4>(&key); + + assert!(!encrypted.data.data.is_empty()); + assert!(!encrypted.data.content_encryption_key.is_empty()); + + let decrypted = encrypted.decrypt::<PASETO_V4>(&key).unwrap(); + + assert_eq!(decrypted.data.0, [1, 2, 3, 4]); + } + + #[test] + fn full_record_round_trip_fail() { + let key = [0x55; 32]; + let record = Record::builder() + .id(RecordId(uuid_v7())) + .version("v0".to_owned()) + .tag("kv".to_owned()) + .host(Host::new(HostId(uuid_v7()))) + .timestamp(1687244806000000) + .data(DecryptedData(vec![1, 2, 3, 4])) + .idx(0) + .build(); + + let encrypted = record.encrypt::<PASETO_V4>(&key); + + let mut enc1 = encrypted.clone(); + enc1.host = Host::new(HostId(uuid_v7())); + let _ = enc1 + .decrypt::<PASETO_V4>(&key) + .expect_err("tampering with the host should result in auth failure"); + + let mut enc2 = encrypted; + enc2.id = RecordId(uuid_v7()); + let _ = enc2 + .decrypt::<PASETO_V4>(&key) + .expect_err("tampering with the id should result in auth failure"); + } +} diff --git a/crates/atuin-client/src/record/mod.rs b/crates/atuin-client/src/record/mod.rs new file mode 100644 index 00000000..c40fd395 --- /dev/null +++ b/crates/atuin-client/src/record/mod.rs @@ -0,0 +1,6 @@ +pub mod encryption; +pub mod sqlite_store; +pub mod store; + +#[cfg(feature = "sync")] +pub mod sync; diff --git a/crates/atuin-client/src/record/sqlite_store.rs b/crates/atuin-client/src/record/sqlite_store.rs new file mode 100644 index 00000000..31de311b --- /dev/null +++ b/crates/atuin-client/src/record/sqlite_store.rs @@ -0,0 +1,641 @@ +// Here we are using sqlite as a pretty dumb store, and will not be running any complex queries. 
+// Multiple stores of multiple types are all stored in one chonky table (for now), and we just index +// by tag/host + +use std::str::FromStr; +use std::{path::Path, time::Duration}; + +use async_trait::async_trait; +use eyre::{eyre, Result}; +use fs_err as fs; + +use sqlx::{ + sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteRow}, + Row, +}; + +use atuin_common::record::{ + EncryptedData, Host, HostId, Record, RecordId, RecordIdx, RecordStatus, +}; +use uuid::Uuid; + +use super::encryption::PASETO_V4; +use super::store::Store; + +#[derive(Debug, Clone)] +pub struct SqliteStore { + pool: SqlitePool, +} + +impl SqliteStore { + pub async fn new(path: impl AsRef<Path>, timeout: f64) -> Result<Self> { + let path = path.as_ref(); + + debug!("opening sqlite database at {:?}", path); + + let create = !path.exists(); + if create { + if let Some(dir) = path.parent() { + fs::create_dir_all(dir)?; + } + } + + let opts = SqliteConnectOptions::from_str(path.as_os_str().to_str().unwrap())? + .journal_mode(SqliteJournalMode::Wal) + .foreign_keys(true) + .create_if_missing(true); + + let pool = SqlitePoolOptions::new() + .acquire_timeout(Duration::from_secs_f64(timeout)) + .connect_with(opts) + .await?; + + Self::setup_db(&pool).await?; + + Ok(Self { pool }) + } + + async fn setup_db(pool: &SqlitePool) -> Result<()> { + debug!("running sqlite database setup"); + + sqlx::migrate!("./record-migrations").run(pool).await?; + + Ok(()) + } + + async fn save_raw( + tx: &mut sqlx::Transaction<'_, sqlx::Sqlite>, + r: &Record<EncryptedData>, + ) -> Result<()> { + // In sqlite, we are "limited" to i64. But that is still fine, until 2262. 
+ sqlx::query( + "insert or ignore into store(id, idx, host, tag, timestamp, version, data, cek) + values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", + ) + .bind(r.id.0.as_hyphenated().to_string()) + .bind(r.idx as i64) + .bind(r.host.id.0.as_hyphenated().to_string()) + .bind(r.tag.as_str()) + .bind(r.timestamp as i64) + .bind(r.version.as_str()) + .bind(r.data.data.as_str()) + .bind(r.data.content_encryption_key.as_str()) + .execute(&mut **tx) + .await?; + + Ok(()) + } + + fn query_row(row: SqliteRow) -> Record<EncryptedData> { + let idx: i64 = row.get("idx"); + let timestamp: i64 = row.get("timestamp"); + + // tbh at this point things are pretty fucked so just panic + let id = Uuid::from_str(row.get("id")).expect("invalid id UUID format in sqlite DB"); + let host = Uuid::from_str(row.get("host")).expect("invalid host UUID format in sqlite DB"); + + Record { + id: RecordId(id), + idx: idx as u64, + host: Host::new(HostId(host)), + timestamp: timestamp as u64, + tag: row.get("tag"), + version: row.get("version"), + data: EncryptedData { + data: row.get("data"), + content_encryption_key: row.get("cek"), + }, + } + } + + async fn load_all(&self) -> Result<Vec<Record<EncryptedData>>> { + let res = sqlx::query("select * from store ") + .map(Self::query_row) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } +} + +#[async_trait] +impl Store for SqliteStore { + async fn push_batch( + &self, + records: impl Iterator<Item = &Record<EncryptedData>> + Send + Sync, + ) -> Result<()> { + let mut tx = self.pool.begin().await?; + + for record in records { + Self::save_raw(&mut tx, record).await?; + } + + tx.commit().await?; + + Ok(()) + } + + async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>> { + let res = sqlx::query("select * from store where store.id = ?1") + .bind(id.0.as_hyphenated().to_string()) + .map(Self::query_row) + .fetch_one(&self.pool) + .await?; + + Ok(res) + } + + async fn delete(&self, id: RecordId) -> Result<()> { + sqlx::query("delete from store where 
id = ?1") + .bind(id.0.as_hyphenated().to_string()) + .execute(&self.pool) + .await?; + + Ok(()) + } + + async fn delete_all(&self) -> Result<()> { + sqlx::query("delete from store").execute(&self.pool).await?; + + Ok(()) + } + + async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> { + let res = + sqlx::query("select * from store where host=?1 and tag=?2 order by idx desc limit 1") + .bind(host.0.as_hyphenated().to_string()) + .bind(tag) + .map(Self::query_row) + .fetch_one(&self.pool) + .await; + + match res { + Err(sqlx::Error::RowNotFound) => Ok(None), + Err(e) => Err(eyre!("an error occurred: {}", e)), + Ok(record) => Ok(Some(record)), + } + } + + async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>> { + self.idx(host, tag, 0).await + } + + async fn len_all(&self) -> Result<u64> { + let res: Result<(i64,), sqlx::Error> = sqlx::query_as("select count(*) from store") + .fetch_one(&self.pool) + .await; + match res { + Err(e) => Err(eyre!("failed to fetch local store len: {}", e)), + Ok(v) => Ok(v.0 as u64), + } + } + + async fn len_tag(&self, tag: &str) -> Result<u64> { + let res: Result<(i64,), sqlx::Error> = + sqlx::query_as("select count(*) from store where tag=?1") + .bind(tag) + .fetch_one(&self.pool) + .await; + match res { + Err(e) => Err(eyre!("failed to fetch local store len: {}", e)), + Ok(v) => Ok(v.0 as u64), + } + } + + async fn len(&self, host: HostId, tag: &str) -> Result<u64> { + let last = self.last(host, tag).await?; + + if let Some(last) = last { + return Ok(last.idx + 1); + } + + return Ok(0); + } + + async fn next( + &self, + host: HostId, + tag: &str, + idx: RecordIdx, + limit: u64, + ) -> Result<Vec<Record<EncryptedData>>> { + let res = + sqlx::query("select * from store where idx >= ?1 and host = ?2 and tag = ?3 limit ?4") + .bind(idx as i64) + .bind(host.0.as_hyphenated().to_string()) + .bind(tag) + .bind(limit as i64) + .map(Self::query_row) + .fetch_all(&self.pool) + 
.await?; + + Ok(res) + } + + async fn idx( + &self, + host: HostId, + tag: &str, + idx: RecordIdx, + ) -> Result<Option<Record<EncryptedData>>> { + let res = sqlx::query("select * from store where idx = ?1 and host = ?2 and tag = ?3") + .bind(idx as i64) + .bind(host.0.as_hyphenated().to_string()) + .bind(tag) + .map(Self::query_row) + .fetch_one(&self.pool) + .await; + + match res { + Err(sqlx::Error::RowNotFound) => Ok(None), + Err(e) => Err(eyre!("an error occurred: {}", e)), + Ok(v) => Ok(Some(v)), + } + } + + async fn status(&self) -> Result<RecordStatus> { + let mut status = RecordStatus::new(); + + let res: Result<Vec<(String, String, i64)>, sqlx::Error> = + sqlx::query_as("select host, tag, max(idx) from store group by host, tag") + .fetch_all(&self.pool) + .await; + + let res = match res { + Err(e) => return Err(eyre!("failed to fetch local store status: {}", e)), + Ok(v) => v, + }; + + for i in res { + let host = HostId( + Uuid::from_str(i.0.as_str()).expect("failed to parse uuid for local store status"), + ); + + status.set_raw(host, i.1, i.2 as u64); + } + + Ok(status) + } + + async fn all_tagged(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>> { + let res = sqlx::query("select * from store where tag = ?1 order by timestamp asc") + .bind(tag) + .map(Self::query_row) + .fetch_all(&self.pool) + .await?; + + Ok(res) + } + + /// Reencrypt every single item in this store with a new key + /// Be careful - this may mess with sync. + async fn re_encrypt(&self, old_key: &[u8; 32], new_key: &[u8; 32]) -> Result<()> { + // Load all the records + // In memory like some of the other code here + // This will never be called in a hot loop, and only under the following circumstances + // 1. The user has logged into a new account, with a new key. They are unlikely to have a + // lot of data + // 2. 
The user has encountered some sort of issue, and runs a maintenance command that + // invokes this + let all = self.load_all().await?; + + let re_encrypted = all + .into_iter() + .map(|record| record.re_encrypt::<PASETO_V4>(old_key, new_key)) + .collect::<Result<Vec<_>>>()?; + + // next up, we delete all the old data and reinsert the new stuff + // do it in one transaction, so if anything fails we rollback OK + + let mut tx = self.pool.begin().await?; + + let res = sqlx::query("delete from store").execute(&mut *tx).await?; + + let rows = res.rows_affected(); + debug!("deleted {rows} rows"); + + // don't call push_batch, as it will start its own transaction + // call the underlying save_raw + + for record in re_encrypted { + Self::save_raw(&mut tx, &record).await?; + } + + tx.commit().await?; + + Ok(()) + } + + /// Verify that every record in this store can be decrypted with the current key + /// Someday maybe also check each tag/record can be deserialized, but not for now. + async fn verify(&self, key: &[u8; 32]) -> Result<()> { + let all = self.load_all().await?; + + all.into_iter() + .map(|record| record.decrypt::<PASETO_V4>(key)) + .collect::<Result<Vec<_>>>()?; + + Ok(()) + } + + /// Delete every record in this store that cannot be decrypted with the given key + /// Records that decrypt successfully are left untouched.
+ async fn purge(&self, key: &[u8; 32]) -> Result<()> { + let all = self.load_all().await?; + + for record in all.iter() { + match record.clone().decrypt::<PASETO_V4>(key) { + Ok(_) => continue, + Err(_) => { + println!( + "Failed to decrypt {}, deleting", + record.id.0.as_hyphenated() + ); + + self.delete(record.id).await?; + } + } + } + + Ok(()) + } +} + +#[cfg(test)] +pub(crate) fn test_sqlite_store_timeout() -> f64 { + std::env::var("ATUIN_TEST_SQLITE_STORE_TIMEOUT") + .ok() + .and_then(|x| x.parse().ok()) + .unwrap_or(0.1) +} + +#[cfg(test)] +mod tests { + use atuin_common::{ + record::{DecryptedData, EncryptedData, Host, HostId, Record}, + utils::uuid_v7, + }; + + use crate::{ + encryption::generate_encoded_key, + record::{encryption::PASETO_V4, store::Store}, + }; + + use super::{test_sqlite_store_timeout, SqliteStore}; + + fn test_record() -> Record<EncryptedData> { + Record::builder() + .host(Host::new(HostId(atuin_common::utils::uuid_v7()))) + .version("v1".into()) + .tag(atuin_common::utils::uuid_v7().simple().to_string()) + .data(EncryptedData { + data: "1234".into(), + content_encryption_key: "1234".into(), + }) + .idx(0) + .build() + } + + #[tokio::test] + async fn create_db() { + let db = SqliteStore::new(":memory:", test_sqlite_store_timeout()).await; + + assert!( + db.is_ok(), + "db could not be created, {:?}", + db.err().unwrap() + ); + } + + #[tokio::test] + async fn push_record() { + let db = SqliteStore::new(":memory:", test_sqlite_store_timeout()) + .await + .unwrap(); + let record = test_record(); + + db.push(&record).await.expect("failed to insert record"); + } + + #[tokio::test] + async fn get_record() { + let db = SqliteStore::new(":memory:", test_sqlite_store_timeout()) + .await + .unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let new_record = db.get(record.id).await.expect("failed to fetch record"); + + assert_eq!(record, new_record, "records are not equal"); + } + + #[tokio::test] + async fn last() { + 
let db = SqliteStore::new(":memory:", test_sqlite_store_timeout()) + .await + .unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let last = db + .last(record.host.id, record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!( + last.unwrap().id, + record.id, + "expected to get back the same record that was inserted" + ); + } + + #[tokio::test] + async fn first() { + let db = SqliteStore::new(":memory:", test_sqlite_store_timeout()) + .await + .unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let first = db + .first(record.host.id, record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!( + first.unwrap().id, + record.id, + "expected to get back the same record that was inserted" + ); + } + + #[tokio::test] + async fn len() { + let db = SqliteStore::new(":memory:", test_sqlite_store_timeout()) + .await + .unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let len = db + .len(record.host.id, record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!(len, 1, "expected length of 1 after insert"); + } + + #[tokio::test] + async fn len_tag() { + let db = SqliteStore::new(":memory:", test_sqlite_store_timeout()) + .await + .unwrap(); + let record = test_record(); + db.push(&record).await.unwrap(); + + let len = db + .len_tag(record.tag.as_str()) + .await + .expect("failed to get store len"); + + assert_eq!(len, 1, "expected length of 1 after insert"); + } + + #[tokio::test] + async fn len_different_tags() { + let db = SqliteStore::new(":memory:", test_sqlite_store_timeout()) + .await + .unwrap(); + + // these have different tags, so the len should be the same + // we model multiple stores within one database + // new store = new tag = independent length + let first = test_record(); + let second = test_record(); + + db.push(&first).await.unwrap(); + db.push(&second).await.unwrap(); + + let first_len = 
db.len(first.host.id, first.tag.as_str()).await.unwrap(); + let second_len = db.len(second.host.id, second.tag.as_str()).await.unwrap(); + + assert_eq!(first_len, 1, "expected length of 1 after insert"); + assert_eq!(second_len, 1, "expected length of 1 after insert"); + } + + #[tokio::test] + async fn append_a_bunch() { + let db = SqliteStore::new(":memory:", test_sqlite_store_timeout()) + .await + .unwrap(); + + let mut tail = test_record(); + db.push(&tail).await.expect("failed to push record"); + + for _ in 1..100 { + tail = tail.append(vec![1, 2, 3, 4]).encrypt::<PASETO_V4>(&[0; 32]); + db.push(&tail).await.unwrap(); + } + + assert_eq!( + db.len(tail.host.id, tail.tag.as_str()).await.unwrap(), + 100, + "failed to insert 100 records" + ); + + assert_eq!( + db.len_tag(tail.tag.as_str()).await.unwrap(), + 100, + "failed to insert 100 records" + ); + } + + #[tokio::test] + async fn append_a_big_bunch() { + let db = SqliteStore::new(":memory:", test_sqlite_store_timeout()) + .await + .unwrap(); + + let mut records: Vec<Record<EncryptedData>> = Vec::with_capacity(10000); + + let mut tail = test_record(); + records.push(tail.clone()); + + for _ in 1..10000 { + tail = tail.append(vec![1, 2, 3]).encrypt::<PASETO_V4>(&[0; 32]); + records.push(tail.clone()); + } + + db.push_batch(records.iter()).await.unwrap(); + + assert_eq!( + db.len(tail.host.id, tail.tag.as_str()).await.unwrap(), + 10000, + "failed to insert 10k records" + ); + } + + #[tokio::test] + async fn re_encrypt() { + let store = SqliteStore::new(":memory:", test_sqlite_store_timeout()) + .await + .unwrap(); + let (key, _) = generate_encoded_key().unwrap(); + let data = vec![0u8, 1u8, 2u8, 3u8]; + let host_id = HostId(uuid_v7()); + + for i in 0..10 { + let record = Record::builder() + .host(Host::new(host_id)) + .version(String::from("test")) + .tag(String::from("test")) + .idx(i) + .data(DecryptedData(data.clone())) + .build(); + + let record = record.encrypt::<PASETO_V4>(&key.into()); + store + 
.push(&record) + .await + .expect("failed to push encrypted record"); + } + + // first, check that we can decrypt the data with the current key + let all = store.all_tagged("test").await.unwrap(); + + assert_eq!(all.len(), 10, "failed to fetch all records"); + + for record in all { + let decrypted = record.decrypt::<PASETO_V4>(&key.into()).unwrap(); + assert_eq!(decrypted.data.0, data); + } + + // reencrypt the store, then check if + // 1) it cannot be decrypted with the old key + // 2) it can be decrypted with the new key + + let (new_key, _) = generate_encoded_key().unwrap(); + store + .re_encrypt(&key.into(), &new_key.into()) + .await + .expect("failed to re-encrypt store"); + + let all = store.all_tagged("test").await.unwrap(); + + for record in all.iter() { + let decrypted = record.clone().decrypt::<PASETO_V4>(&key.into()); + assert!( + decrypted.is_err(), + "did not get error decrypting with old key after re-encrypt" + ) + } + + for record in all { + let decrypted = record.decrypt::<PASETO_V4>(&new_key.into()).unwrap(); + assert_eq!(decrypted.data.0, data); + } + + assert_eq!(store.len(host_id, "test").await.unwrap(), 10); + } +} diff --git a/crates/atuin-client/src/record/store.rs b/crates/atuin-client/src/record/store.rs new file mode 100644 index 00000000..49ca4968 --- /dev/null +++ b/crates/atuin-client/src/record/store.rs @@ -0,0 +1,60 @@ +use async_trait::async_trait; +use eyre::Result; + +use atuin_common::record::{EncryptedData, HostId, Record, RecordId, RecordIdx, RecordStatus}; + +/// A record store stores records +/// In more detail - we tend to need to process this into _another_ format to actually query it. +/// As is, the record store is intended as the source of truth for arbitrary data, which could +/// be shell history, kvs, etc. 
+#[async_trait] +pub trait Store { + // Push a record + async fn push(&self, record: &Record<EncryptedData>) -> Result<()> { + self.push_batch(std::iter::once(record)).await + } + + // Push a batch of records, all in one transaction + async fn push_batch( + &self, + records: impl Iterator<Item = &Record<EncryptedData>> + Send + Sync, + ) -> Result<()>; + + async fn get(&self, id: RecordId) -> Result<Record<EncryptedData>>; + + async fn delete(&self, id: RecordId) -> Result<()>; + async fn delete_all(&self) -> Result<()>; + + async fn len_all(&self) -> Result<u64>; + async fn len(&self, host: HostId, tag: &str) -> Result<u64>; + async fn len_tag(&self, tag: &str) -> Result<u64>; + + async fn last(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>; + async fn first(&self, host: HostId, tag: &str) -> Result<Option<Record<EncryptedData>>>; + + async fn re_encrypt(&self, old_key: &[u8; 32], new_key: &[u8; 32]) -> Result<()>; + async fn verify(&self, key: &[u8; 32]) -> Result<()>; + async fn purge(&self, key: &[u8; 32]) -> Result<()>; + + /// Get the next `limit` records, after and including the given index + async fn next( + &self, + host: HostId, + tag: &str, + idx: RecordIdx, + limit: u64, + ) -> Result<Vec<Record<EncryptedData>>>; + + /// Get the record at index `idx` for a given host and tag, if one exists + async fn idx( + &self, + host: HostId, + tag: &str, + idx: RecordIdx, + ) -> Result<Option<Record<EncryptedData>>>; + + async fn status(&self) -> Result<RecordStatus>; + + /// Get all records for a given tag + async fn all_tagged(&self, tag: &str) -> Result<Vec<Record<EncryptedData>>>; +} diff --git a/crates/atuin-client/src/record/sync.rs b/crates/atuin-client/src/record/sync.rs new file mode 100644 index 00000000..234c6442 --- /dev/null +++ b/crates/atuin-client/src/record/sync.rs @@ -0,0 +1,607 @@ +// do a sync :O +use std::{cmp::Ordering, fmt::Write}; + +use eyre::Result; +use thiserror::Error; + +use super::store::Store; +use crate::{api_client::Client,
settings::Settings}; + +use atuin_common::record::{Diff, HostId, RecordId, RecordIdx, RecordStatus}; +use indicatif::{ProgressBar, ProgressState, ProgressStyle}; + +#[derive(Error, Debug)] +pub enum SyncError { + #[error("the local store is ahead of the remote, but for another host. has remote lost data?")] + LocalAheadOtherHost, + + #[error("an issue with the local database occurred: {msg:?}")] + LocalStoreError { msg: String }, + + #[error("something has gone wrong with the sync logic: {msg:?}")] + SyncLogicError { msg: String }, + + #[error("operational error: {msg:?}")] + OperationalError { msg: String }, + + #[error("a request to the sync server failed: {msg:?}")] + RemoteRequestError { msg: String }, +} + +#[derive(Debug, Eq, PartialEq)] +pub enum Operation { + // Either upload or download until the state matches the below + Upload { + local: RecordIdx, + remote: Option<RecordIdx>, + host: HostId, + tag: String, + }, + Download { + local: Option<RecordIdx>, + remote: RecordIdx, + host: HostId, + tag: String, + }, + Noop { + host: HostId, + tag: String, + }, +} + +pub async fn diff( + settings: &Settings, + store: &impl Store, +) -> Result<(Vec<Diff>, RecordStatus), SyncError> { + let client = Client::new( + &settings.sync_address, + &settings.session_token, + settings.network_connect_timeout, + settings.network_timeout, + ) + .map_err(|e| SyncError::OperationalError { msg: e.to_string() })?; + + let local_index = store + .status() + .await + .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?; + + let remote_index = client + .record_status() + .await + .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?; + + let diff = local_index.diff(&remote_index); + + Ok((diff, remote_index)) +} + +// Take a diff, along with a local store, and resolve it into a set of operations. +// With the store as context, we can determine if a tail exists locally or not and therefore if it needs uploading or downloading.
+// In theory this could be done as a part of the diffing stage, but it's easier to reason +// about and test this way +pub async fn operations( + diffs: Vec<Diff>, + _store: &impl Store, +) -> Result<Vec<Operation>, SyncError> { + let mut operations = Vec::with_capacity(diffs.len()); + + for diff in diffs { + let op = match (diff.local, diff.remote) { + // We both have it! Could be either. Compare. + (Some(local), Some(remote)) => match local.cmp(&remote) { + Ordering::Equal => Operation::Noop { + host: diff.host, + tag: diff.tag, + }, + Ordering::Greater => Operation::Upload { + local, + remote: Some(remote), + host: diff.host, + tag: diff.tag, + }, + Ordering::Less => Operation::Download { + local: Some(local), + remote, + host: diff.host, + tag: diff.tag, + }, + }, + + // Remote has it, we don't. Gotta be download + (None, Some(remote)) => Operation::Download { + local: None, + remote, + host: diff.host, + tag: diff.tag, + }, + + // We have it, remote doesn't. Gotta be upload. + (Some(local), None) => Operation::Upload { + local, + remote: None, + host: diff.host, + tag: diff.tag, + }, + + // something is pretty fucked. + (None, None) => { + return Err(SyncError::SyncLogicError { + msg: String::from( + "diff has nothing for local or remote - (host, tag) does not exist", + ), + }) + } + }; + + operations.push(op); + } + + // sort them - purely so we have a stable testing order, and can rely on + // same input = same output + // We can sort by ID so long as we continue to use UUIDv7 or something + // with the same properties + + operations.sort_by_key(|op| match op { + Operation::Noop { host, tag } => (0, *host, tag.clone()), + + Operation::Upload { host, tag, .. } => (1, *host, tag.clone()), + + Operation::Download { host, tag, .. 
} => (2, *host, tag.clone()), + }); + + Ok(operations) +} + +async fn sync_upload( + store: &impl Store, + client: &Client<'_>, + host: HostId, + tag: String, + local: RecordIdx, + remote: Option<RecordIdx>, +) -> Result<i64, SyncError> { + let remote = remote.unwrap_or(0); + let expected = local - remote; + let upload_page_size = 100; + let mut progress = 0; + + let pb = ProgressBar::new(expected); + pb.set_style(ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {human_pos}/{human_len} ({eta})") + .unwrap() + .with_key("eta", |state: &ProgressState, w: &mut dyn Write| write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()) + .progress_chars("#>-")); + + println!( + "Uploading {} records to {}/{}", + expected, + host.0.as_simple(), + tag + ); + + // preload with the first entry if remote does not know of this store + loop { + let page = store + .next(host, tag.as_str(), remote + progress, upload_page_size) + .await + .map_err(|e| { + error!("failed to read upload page: {e:?}"); + + SyncError::LocalStoreError { msg: e.to_string() } + })?; + + client.post_records(&page).await.map_err(|e| { + error!("failed to post records: {e:?}"); + + SyncError::RemoteRequestError { msg: e.to_string() } + })?; + + pb.set_position(progress); + progress += page.len() as u64; + + if progress >= expected { + break; + } + } + + pb.finish_with_message("Uploaded records"); + + Ok(progress as i64) +} + +async fn sync_download( + store: &impl Store, + client: &Client<'_>, + host: HostId, + tag: String, + local: Option<RecordIdx>, + remote: RecordIdx, +) -> Result<Vec<RecordId>, SyncError> { + let local = local.unwrap_or(0); + let expected = remote - local; + let download_page_size = 100; + let mut progress = 0; + let mut ret = Vec::new(); + + println!( + "Downloading {} records from {}/{}", + expected, + host.0.as_simple(), + tag + ); + + let pb = ProgressBar::new(expected); + pb.set_style(ProgressStyle::with_template("{spinner:.green} 
[{elapsed_precise}] [{wide_bar:.cyan/blue}] {human_pos}/{human_len} ({eta})") + .unwrap() + .with_key("eta", |state: &ProgressState, w: &mut dyn Write| write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()) + .progress_chars("#>-")); + + // preload with the first entry if remote does not know of this store + loop { + let page = client + .next_records(host, tag.clone(), local + progress, download_page_size) + .await + .map_err(|e| SyncError::RemoteRequestError { msg: e.to_string() })?; + + store + .push_batch(page.iter()) + .await + .map_err(|e| SyncError::LocalStoreError { msg: e.to_string() })?; + + ret.extend(page.iter().map(|f| f.id)); + + pb.set_position(progress); + progress += page.len() as u64; + + if progress >= expected { + break; + } + } + + pb.finish_with_message("Downloaded records"); + + Ok(ret) +} + +pub async fn sync_remote( + operations: Vec<Operation>, + local_store: &impl Store, + settings: &Settings, +) -> Result<(i64, Vec<RecordId>), SyncError> { + let client = Client::new( + &settings.sync_address, + &settings.session_token, + settings.network_connect_timeout, + settings.network_timeout, + ) + .expect("failed to create client"); + + let mut uploaded = 0; + let mut downloaded = Vec::new(); + + // this can totally run in parallel, but lets get it working first + for i in operations { + match i { + Operation::Upload { + host, + tag, + local, + remote, + } => uploaded += sync_upload(local_store, &client, host, tag, local, remote).await?, + + Operation::Download { + host, + tag, + local, + remote, + } => { + let mut d = sync_download(local_store, &client, host, tag, local, remote).await?; + downloaded.append(&mut d) + } + + Operation::Noop { .. 
} => continue, + } + } + + Ok((uploaded, downloaded)) +} + +pub async fn sync( + settings: &Settings, + store: &impl Store, +) -> Result<(i64, Vec<RecordId>), SyncError> { + let (diff, _) = diff(settings, store).await?; + let operations = operations(diff, store).await?; + let (uploaded, downloaded) = sync_remote(operations, store, settings).await?; + + Ok((uploaded, downloaded)) +} + +#[cfg(test)] +mod tests { + use atuin_common::record::{Diff, EncryptedData, HostId, Record}; + use pretty_assertions::assert_eq; + + use crate::record::{ + encryption::PASETO_V4, + sqlite_store::{test_sqlite_store_timeout, SqliteStore}, + store::Store, + sync::{self, Operation}, + }; + + fn test_record() -> Record<EncryptedData> { + Record::builder() + .host(atuin_common::record::Host::new(HostId( + atuin_common::utils::uuid_v7(), + ))) + .version("v1".into()) + .tag(atuin_common::utils::uuid_v7().simple().to_string()) + .data(EncryptedData { + data: String::new(), + content_encryption_key: String::new(), + }) + .idx(0) + .build() + } + + // Take a list of local records, and a list of remote records. 
+ // Return the local database, and a diff of local/remote, ready to build + // ops + async fn build_test_diff( + local_records: Vec<Record<EncryptedData>>, + remote_records: Vec<Record<EncryptedData>>, + ) -> (SqliteStore, Vec<Diff>) { + let local_store = SqliteStore::new(":memory:", test_sqlite_store_timeout()) + .await + .expect("failed to open in memory sqlite"); + let remote_store = SqliteStore::new(":memory:", test_sqlite_store_timeout()) + .await + .expect("failed to open in memory sqlite"); // "remote" + + for i in local_records { + local_store.push(&i).await.unwrap(); + } + + for i in remote_records { + remote_store.push(&i).await.unwrap(); + } + + let local_index = local_store.status().await.unwrap(); + let remote_index = remote_store.status().await.unwrap(); + + let diff = local_index.diff(&remote_index); + + (local_store, diff) + } + + #[tokio::test] + async fn test_basic_diff() { + // a diff where local is ahead of remote. nothing else. + + let record = test_record(); + let (store, diff) = build_test_diff(vec![record.clone()], vec![]).await; + + assert_eq!(diff.len(), 1); + + let operations = sync::operations(diff, &store).await.unwrap(); + + assert_eq!(operations.len(), 1); + + assert_eq!( + operations[0], + Operation::Upload { + host: record.host.id, + tag: record.tag, + local: record.idx, + remote: None, + } + ); + } + + #[tokio::test] + async fn build_two_way_diff() { + // a diff where local is ahead of remote for one, and remote for + // another. 
One upload, one download + + let shared_record = test_record(); + let remote_ahead = test_record(); + + let local_ahead = shared_record + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + + assert_eq!(local_ahead.idx, 1); + + let local = vec![shared_record.clone(), local_ahead.clone()]; // local knows about the already synced, and something newer in the same store + let remote = vec![shared_record.clone(), remote_ahead.clone()]; // remote knows about the already-synced, and one new record in a new store + + let (store, diff) = build_test_diff(local, remote).await; + let operations = sync::operations(diff, &store).await.unwrap(); + + assert_eq!(operations.len(), 2); + + assert_eq!( + operations, + vec![ + // Or in otherwords, local is ahead by one + Operation::Upload { + host: local_ahead.host.id, + tag: local_ahead.tag, + local: 1, + remote: Some(0), + }, + // Or in other words, remote knows of a record in an entirely new store (tag) + Operation::Download { + host: remote_ahead.host.id, + tag: remote_ahead.tag, + local: None, + remote: 0, + }, + ] + ); + } + + #[tokio::test] + async fn build_complex_diff() { + // One shared, ahead but known only by remote + // One known only by local + // One known only by remote + + let shared_record = test_record(); + let local_only = test_record(); + + let local_only_20 = test_record(); + let local_only_21 = local_only_20 + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + let local_only_22 = local_only_21 + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + let local_only_23 = local_only_22 + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + + let remote_only = test_record(); + + let remote_only_20 = test_record(); + let remote_only_21 = remote_only_20 + .append(vec![2, 3, 2]) + .encrypt::<PASETO_V4>(&[0; 32]); + let remote_only_22 = remote_only_21 + .append(vec![2, 3, 2]) + .encrypt::<PASETO_V4>(&[0; 32]); + let remote_only_23 = remote_only_22 + .append(vec![2, 3, 2]) + 
.encrypt::<PASETO_V4>(&[0; 32]); + let remote_only_24 = remote_only_23 + .append(vec![2, 3, 2]) + .encrypt::<PASETO_V4>(&[0; 32]); + + let second_shared = test_record(); + let second_shared_remote_ahead = second_shared + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + let second_shared_remote_ahead2 = second_shared_remote_ahead + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + + let third_shared = test_record(); + let third_shared_local_ahead = third_shared + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + let third_shared_local_ahead2 = third_shared_local_ahead + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + + let fourth_shared = test_record(); + let fourth_shared_remote_ahead = fourth_shared + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + let fourth_shared_remote_ahead2 = fourth_shared_remote_ahead + .append(vec![1, 2, 3]) + .encrypt::<PASETO_V4>(&[0; 32]); + + let local = vec![ + shared_record.clone(), + second_shared.clone(), + third_shared.clone(), + fourth_shared.clone(), + fourth_shared_remote_ahead.clone(), + // single store, only local has it + local_only.clone(), + // bigger store, also only known by local + local_only_20.clone(), + local_only_21.clone(), + local_only_22.clone(), + local_only_23.clone(), + // another shared store, but local is ahead on this one + third_shared_local_ahead.clone(), + third_shared_local_ahead2.clone(), + ]; + + let remote = vec![ + remote_only.clone(), + remote_only_20.clone(), + remote_only_21.clone(), + remote_only_22.clone(), + remote_only_23.clone(), + remote_only_24.clone(), + shared_record.clone(), + second_shared.clone(), + third_shared.clone(), + second_shared_remote_ahead.clone(), + second_shared_remote_ahead2.clone(), + fourth_shared.clone(), + fourth_shared_remote_ahead.clone(), + fourth_shared_remote_ahead2.clone(), + ]; // remote knows about the already-synced, and one new record in a new store + + let (store, diff) = build_test_diff(local, 
remote).await; + let operations = sync::operations(diff, &store).await.unwrap(); + + assert_eq!(operations.len(), 7); + + let mut result_ops = vec![ + // We started with a shared record, but the remote knows of two newer records in the + // same store + Operation::Download { + local: Some(0), + remote: 2, + host: second_shared_remote_ahead.host.id, + tag: second_shared_remote_ahead.tag, + }, + // We have a shared record, local knows of the first two but not the last + Operation::Download { + local: Some(1), + remote: 2, + host: fourth_shared_remote_ahead2.host.id, + tag: fourth_shared_remote_ahead2.tag, + }, + // Remote knows of a store with a single record that local does not have + Operation::Download { + local: None, + remote: 0, + host: remote_only.host.id, + tag: remote_only.tag, + }, + // Remote knows of a store with a bunch of records that local does not have + Operation::Download { + local: None, + remote: 4, + host: remote_only_20.host.id, + tag: remote_only_20.tag, + }, + // Local knows of a record in a store that remote does not have + Operation::Upload { + local: 0, + remote: None, + host: local_only.host.id, + tag: local_only.tag, + }, + // Local knows of 4 records in a store that remote does not have + Operation::Upload { + local: 3, + remote: None, + host: local_only_20.host.id, + tag: local_only_20.tag, + }, + // Local knows of 2 more records in a shared store that remote only has one of + Operation::Upload { + local: 2, + remote: Some(0), + host: third_shared.host.id, + tag: third_shared.tag, + }, + ]; + + result_ops.sort_by_key(|op| match op { + Operation::Noop { host, tag } => (0, *host, tag.clone()), + + Operation::Upload { host, tag, .. } => (1, *host, tag.clone()), + + Operation::Download { host, tag, .. 
#[cfg(test)]
mod tests {
    use regex::Regex;

    use crate::secrets::SECRET_PATTERNS;

    /// Every entry in `SECRET_PATTERNS` must compile as a regex and match its own sample.
    #[test]
    fn test_secrets() {
        for (name, regex, test) in SECRET_PATTERNS {
            // Surface the underlying regex error: without it a broken pattern is hard to debug.
            let re = Regex::new(regex)
                .unwrap_or_else(|err| panic!("Failed to compile regex for {name}: {err}"));

            assert!(
                re.is_match(test),
                "{name} test failed! pattern {regex:?} did not match sample {test:?}"
            );
        }
    }
}
a/crates/atuin-client/src/settings.rs b/crates/atuin-client/src/settings.rs new file mode 100644 index 00000000..daf8fe34 --- /dev/null +++ b/crates/atuin-client/src/settings.rs @@ -0,0 +1,784 @@ +use std::{ + collections::HashMap, + convert::TryFrom, + fmt, + io::prelude::*, + path::{Path, PathBuf}, + str::FromStr, +}; + +use atuin_common::record::HostId; +use clap::ValueEnum; +use config::{ + builder::DefaultState, Config, ConfigBuilder, Environment, File as ConfigFile, FileFormat, +}; +use eyre::{bail, eyre, Context, Error, Result}; +use fs_err::{create_dir_all, File}; +use parse_duration::parse; +use regex::RegexSet; +use semver::Version; +use serde::Deserialize; +use serde_with::DeserializeFromStr; +use time::{ + format_description::{well_known::Rfc3339, FormatItem}, + macros::format_description, + OffsetDateTime, UtcOffset, +}; +use uuid::Uuid; + +pub const HISTORY_PAGE_SIZE: i64 = 100; +pub const LAST_SYNC_FILENAME: &str = "last_sync_time"; +pub const LAST_VERSION_CHECK_FILENAME: &str = "last_version_check_time"; +pub const LATEST_VERSION_FILENAME: &str = "latest_version"; +pub const HOST_ID_FILENAME: &str = "host_id"; +static EXAMPLE_CONFIG: &str = include_str!("../config.toml"); + +mod dotfiles; + +#[derive(Clone, Debug, Deserialize, Copy, ValueEnum, PartialEq)] +pub enum SearchMode { + #[serde(rename = "prefix")] + Prefix, + + #[serde(rename = "fulltext")] + #[clap(aliases = &["fulltext"])] + FullText, + + #[serde(rename = "fuzzy")] + Fuzzy, + + #[serde(rename = "skim")] + Skim, +} + +impl SearchMode { + pub fn as_str(&self) -> &'static str { + match self { + SearchMode::Prefix => "PREFIX", + SearchMode::FullText => "FULLTXT", + SearchMode::Fuzzy => "FUZZY", + SearchMode::Skim => "SKIM", + } + } + pub fn next(&self, settings: &Settings) -> Self { + match self { + SearchMode::Prefix => SearchMode::FullText, + // if the user is using skim, we go to skim + SearchMode::FullText if settings.search_mode == SearchMode::Skim => SearchMode::Skim, + // otherwise 
fuzzy. + SearchMode::FullText => SearchMode::Fuzzy, + SearchMode::Fuzzy | SearchMode::Skim => SearchMode::Prefix, + } + } +} + +#[derive(Clone, Debug, Deserialize, Copy, PartialEq, Eq, ValueEnum)] +pub enum FilterMode { + #[serde(rename = "global")] + Global = 0, + + #[serde(rename = "host")] + Host = 1, + + #[serde(rename = "session")] + Session = 2, + + #[serde(rename = "directory")] + Directory = 3, + + #[serde(rename = "workspace")] + Workspace = 4, +} + +impl FilterMode { + pub fn as_str(&self) -> &'static str { + match self { + FilterMode::Global => "GLOBAL", + FilterMode::Host => "HOST", + FilterMode::Session => "SESSION", + FilterMode::Directory => "DIRECTORY", + FilterMode::Workspace => "WORKSPACE", + } + } +} + +#[derive(Clone, Debug, Deserialize, Copy)] +pub enum ExitMode { + #[serde(rename = "return-original")] + ReturnOriginal, + + #[serde(rename = "return-query")] + ReturnQuery, +} + +// FIXME: Can use upstream Dialect enum if https://github.com/stevedonovan/chrono-english/pull/16 is merged +// FIXME: Above PR was merged, but dependency was changed to interim (fork of chrono-english) in the ... interim +#[derive(Clone, Debug, Deserialize, Copy)] +pub enum Dialect { + #[serde(rename = "us")] + Us, + + #[serde(rename = "uk")] + Uk, +} + +impl From<Dialect> for interim::Dialect { + fn from(d: Dialect) -> interim::Dialect { + match d { + Dialect::Uk => interim::Dialect::Uk, + Dialect::Us => interim::Dialect::Us, + } + } +} + +/// Type wrapper around `time::UtcOffset` to support a wider variety of timezone formats. +/// +/// Note that the parsing of this struct needs to be done before starting any +/// multithreaded runtime, otherwise it will fail on most Unix systems. 
+/// +/// See: https://github.com/atuinsh/atuin/pull/1517#discussion_r1447516426 +#[derive(Clone, Copy, Debug, Eq, PartialEq, DeserializeFromStr)] +pub struct Timezone(pub UtcOffset); +impl fmt::Display for Timezone { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} +/// format: <+|-><hour>[:<minute>[:<second>]] +static OFFSET_FMT: &[FormatItem<'_>] = + format_description!("[offset_hour sign:mandatory padding:none][optional [:[offset_minute padding:none][optional [:[offset_second padding:none]]]]]"); +impl FromStr for Timezone { + type Err = Error; + + fn from_str(s: &str) -> Result<Self> { + // local timezone + if matches!(s.to_lowercase().as_str(), "l" | "local") { + // There have been some timezone issues, related to errors fetching it on some + // platforms + // Rather than fail to start, fallback to UTC. The user should still be able to specify + // their timezone manually in the config file. + let offset = UtcOffset::current_local_offset().unwrap_or(UtcOffset::UTC); + return Ok(Self(offset)); + } + + if matches!(s.to_lowercase().as_str(), "0" | "utc") { + let offset = UtcOffset::UTC; + return Ok(Self(offset)); + } + + // offset from UTC + if let Ok(offset) = UtcOffset::parse(s, OFFSET_FMT) { + return Ok(Self(offset)); + } + + // IDEA: Currently named timezones are not supported, because the well-known crate + // for this is `chrono_tz`, which is not really interoperable with the datetime crate + // that we currently use - `time`. If ever we migrate to using `chrono`, this would + // be a good feature to add. 
+ + bail!(r#""{s}" is not a valid timezone spec"#) + } +} + +#[derive(Clone, Debug, Deserialize, Copy)] +pub enum Style { + #[serde(rename = "auto")] + Auto, + + #[serde(rename = "full")] + Full, + + #[serde(rename = "compact")] + Compact, +} + +#[derive(Clone, Debug, Deserialize, Copy)] +pub enum WordJumpMode { + #[serde(rename = "emacs")] + Emacs, + + #[serde(rename = "subl")] + Subl, +} + +#[derive(Clone, Debug, Deserialize, Copy, PartialEq, Eq, ValueEnum)] +pub enum KeymapMode { + #[serde(rename = "emacs")] + Emacs, + + #[serde(rename = "vim-normal")] + VimNormal, + + #[serde(rename = "vim-insert")] + VimInsert, + + #[serde(rename = "auto")] + Auto, +} + +impl KeymapMode { + pub fn as_str(&self) -> &'static str { + match self { + KeymapMode::Emacs => "EMACS", + KeymapMode::VimNormal => "VIMNORMAL", + KeymapMode::VimInsert => "VIMINSERT", + KeymapMode::Auto => "AUTO", + } + } +} + +// We want to translate the config to crossterm::cursor::SetCursorStyle, but +// the original type does not implement trait serde::Deserialize unfortunately. +// It seems impossible to implement Deserialize for external types when it is +// used in HashMap (https://stackoverflow.com/questions/67142663). We instead +// define an adapter type. 
+#[derive(Clone, Debug, Deserialize, Copy, PartialEq, Eq, ValueEnum)] +pub enum CursorStyle { + #[serde(rename = "default")] + DefaultUserShape, + + #[serde(rename = "blink-block")] + BlinkingBlock, + + #[serde(rename = "steady-block")] + SteadyBlock, + + #[serde(rename = "blink-underline")] + BlinkingUnderScore, + + #[serde(rename = "steady-underline")] + SteadyUnderScore, + + #[serde(rename = "blink-bar")] + BlinkingBar, + + #[serde(rename = "steady-bar")] + SteadyBar, +} + +impl CursorStyle { + pub fn as_str(&self) -> &'static str { + match self { + CursorStyle::DefaultUserShape => "DEFAULT", + CursorStyle::BlinkingBlock => "BLINKBLOCK", + CursorStyle::SteadyBlock => "STEADYBLOCK", + CursorStyle::BlinkingUnderScore => "BLINKUNDERLINE", + CursorStyle::SteadyUnderScore => "STEADYUNDERLINE", + CursorStyle::BlinkingBar => "BLINKBAR", + CursorStyle::SteadyBar => "STEADYBAR", + } + } +} + +#[derive(Clone, Debug, Deserialize)] +pub struct Stats { + #[serde(default = "Stats::common_prefix_default")] + pub common_prefix: Vec<String>, // sudo, etc. commands we want to strip off + #[serde(default = "Stats::common_subcommands_default")] + pub common_subcommands: Vec<String>, // kubectl, commands we should consider subcommands for + #[serde(default = "Stats::ignored_commands_default")] + pub ignored_commands: Vec<String>, // cd, ls, etc. 
commands we want to completely hide from stats +} + +impl Stats { + fn common_prefix_default() -> Vec<String> { + vec!["sudo", "doas"].into_iter().map(String::from).collect() + } + + fn common_subcommands_default() -> Vec<String> { + vec![ + "apt", + "cargo", + "composer", + "dnf", + "docker", + "git", + "go", + "ip", + "kubectl", + "nix", + "nmcli", + "npm", + "pecl", + "pnpm", + "podman", + "port", + "systemctl", + "tmux", + "yarn", + ] + .into_iter() + .map(String::from) + .collect() + } + + fn ignored_commands_default() -> Vec<String> { + vec![] + } +} + +impl Default for Stats { + fn default() -> Self { + Self { + common_prefix: Self::common_prefix_default(), + common_subcommands: Self::common_subcommands_default(), + ignored_commands: Self::ignored_commands_default(), + } + } +} + +#[derive(Clone, Debug, Deserialize, Default)] +pub struct Sync { + pub records: bool, +} + +#[derive(Clone, Debug, Deserialize, Default)] +pub struct Keys { + pub scroll_exits: bool, +} + +#[derive(Clone, Debug, Deserialize)] +pub struct Settings { + pub dialect: Dialect, + pub timezone: Timezone, + pub style: Style, + pub auto_sync: bool, + pub update_check: bool, + pub sync_address: String, + pub sync_frequency: String, + pub db_path: String, + pub record_store_path: String, + pub key_path: String, + pub session_path: String, + pub search_mode: SearchMode, + pub filter_mode: FilterMode, + pub filter_mode_shell_up_key_binding: Option<FilterMode>, + pub search_mode_shell_up_key_binding: Option<SearchMode>, + pub shell_up_key_binding: bool, + pub inline_height: u16, + pub invert: bool, + pub show_preview: bool, + pub show_preview_auto: bool, + pub max_preview_height: u16, + pub show_help: bool, + pub show_tabs: bool, + pub exit_mode: ExitMode, + pub keymap_mode: KeymapMode, + pub keymap_mode_shell: KeymapMode, + pub keymap_cursor: HashMap<String, CursorStyle>, + pub word_jump_mode: WordJumpMode, + pub word_chars: String, + pub scroll_context_lines: usize, + pub history_format: 
String, + pub prefers_reduced_motion: bool, + pub store_failed: bool, + + #[serde(with = "serde_regex", default = "RegexSet::empty")] + pub history_filter: RegexSet, + + #[serde(with = "serde_regex", default = "RegexSet::empty")] + pub cwd_filter: RegexSet, + + pub secrets_filter: bool, + pub workspaces: bool, + pub ctrl_n_shortcuts: bool, + + pub network_connect_timeout: u64, + pub network_timeout: u64, + pub local_timeout: f64, + pub enter_accept: bool, + pub smart_sort: bool, + + #[serde(default)] + pub stats: Stats, + + #[serde(default)] + pub sync: Sync, + + #[serde(default)] + pub keys: Keys, + + #[serde(default)] + pub dotfiles: dotfiles::Settings, + + // This is automatically loaded when settings is created. Do not set in + // config! Keep secrets and settings apart. + #[serde(skip)] + pub session_token: String, +} + +impl Settings { + pub fn utc() -> Self { + Self::builder() + .expect("Could not build default") + .set_override("timezone", "0") + .expect("failed to override timezone with UTC") + .build() + .expect("Could not build config") + .try_deserialize() + .expect("Could not deserialize config") + } + + fn save_to_data_dir(filename: &str, value: &str) -> Result<()> { + let data_dir = atuin_common::utils::data_dir(); + let data_dir = data_dir.as_path(); + + let path = data_dir.join(filename); + + fs_err::write(path, value)?; + + Ok(()) + } + + fn read_from_data_dir(filename: &str) -> Option<String> { + let data_dir = atuin_common::utils::data_dir(); + let data_dir = data_dir.as_path(); + + let path = data_dir.join(filename); + + if !path.exists() { + return None; + } + + let value = fs_err::read_to_string(path); + + value.ok() + } + + fn save_current_time(filename: &str) -> Result<()> { + Settings::save_to_data_dir( + filename, + OffsetDateTime::now_utc().format(&Rfc3339)?.as_str(), + )?; + + Ok(()) + } + + fn load_time_from_file(filename: &str) -> Result<OffsetDateTime> { + let value = Settings::read_from_data_dir(filename); + + match value { + 
Some(v) => Ok(OffsetDateTime::parse(v.as_str(), &Rfc3339)?), + None => Ok(OffsetDateTime::UNIX_EPOCH), + } + } + + pub fn save_sync_time() -> Result<()> { + Settings::save_current_time(LAST_SYNC_FILENAME) + } + + pub fn save_version_check_time() -> Result<()> { + Settings::save_current_time(LAST_VERSION_CHECK_FILENAME) + } + + pub fn last_sync() -> Result<OffsetDateTime> { + Settings::load_time_from_file(LAST_SYNC_FILENAME) + } + + pub fn last_version_check() -> Result<OffsetDateTime> { + Settings::load_time_from_file(LAST_VERSION_CHECK_FILENAME) + } + + pub fn host_id() -> Option<HostId> { + let id = Settings::read_from_data_dir(HOST_ID_FILENAME); + + if let Some(id) = id { + let parsed = + Uuid::from_str(id.as_str()).expect("failed to parse host ID from local directory"); + return Some(HostId(parsed)); + } + + let uuid = atuin_common::utils::uuid_v7(); + + Settings::save_to_data_dir(HOST_ID_FILENAME, uuid.as_simple().to_string().as_ref()) + .expect("Could not write host ID to data dir"); + + Some(HostId(uuid)) + } + + pub fn should_sync(&self) -> Result<bool> { + if !self.auto_sync || !PathBuf::from(self.session_path.as_str()).exists() { + return Ok(false); + } + + match parse(self.sync_frequency.as_str()) { + Ok(d) => { + let d = time::Duration::try_from(d).unwrap(); + Ok(OffsetDateTime::now_utc() - Settings::last_sync()? >= d) + } + Err(e) => Err(eyre!("failed to check sync: {}", e)), + } + } + + #[cfg(feature = "check-update")] + fn needs_update_check(&self) -> Result<bool> { + let last_check = Settings::last_version_check()?; + let diff = OffsetDateTime::now_utc() - last_check; + + // Check a max of once per hour + Ok(diff.whole_hours() >= 1) + } + + #[cfg(feature = "check-update")] + async fn latest_version(&self) -> Result<Version> { + // Default to the current version, and if that doesn't parse, a version so high it's unlikely to ever + // suggest upgrading. 
+ let current = + Version::parse(env!("CARGO_PKG_VERSION")).unwrap_or(Version::new(100000, 0, 0)); + + if !self.needs_update_check()? { + // Worst case, we don't want Atuin to fail to start because something funky is going on with + // version checking. + let version = tokio::task::spawn_blocking(|| { + Settings::read_from_data_dir(LATEST_VERSION_FILENAME) + }) + .await + .expect("file task panicked"); + + let version = match version { + Some(v) => Version::parse(&v).unwrap_or(current), + None => current, + }; + + return Ok(version); + } + + #[cfg(feature = "sync")] + let latest = crate::api_client::latest_version().await.unwrap_or(current); + + #[cfg(not(feature = "sync"))] + let latest = current; + + let latest_encoded = latest.to_string(); + tokio::task::spawn_blocking(move || { + Settings::save_version_check_time()?; + Settings::save_to_data_dir(LATEST_VERSION_FILENAME, &latest_encoded)?; + Ok::<(), eyre::Report>(()) + }) + .await + .expect("file task panicked")?; + + Ok(latest) + } + + // Return Some(latest version) if an update is needed. Otherwise, none. 
+ #[cfg(feature = "check-update")] + pub async fn needs_update(&self) -> Option<Version> { + if !self.update_check { + return None; + } + + let current = + Version::parse(env!("CARGO_PKG_VERSION")).unwrap_or(Version::new(100000, 0, 0)); + + let latest = self.latest_version().await; + + if latest.is_err() { + return None; + } + + let latest = latest.unwrap(); + + if latest > current { + return Some(latest); + } + + None + } + + #[cfg(not(feature = "check-update"))] + pub async fn needs_update(&self) -> Option<Version> { + None + } + + pub fn builder() -> Result<ConfigBuilder<DefaultState>> { + let data_dir = atuin_common::utils::data_dir(); + let db_path = data_dir.join("history.db"); + let record_store_path = data_dir.join("records.db"); + + let key_path = data_dir.join("key"); + let session_path = data_dir.join("session"); + + Ok(Config::builder() + .set_default("history_format", "{time}\t{command}\t{duration}")? + .set_default("db_path", db_path.to_str())? + .set_default("record_store_path", record_store_path.to_str())? + .set_default("key_path", key_path.to_str())? + .set_default("session_path", session_path.to_str())? + .set_default("dialect", "us")? + .set_default("timezone", "local")? + .set_default("auto_sync", true)? + .set_default("update_check", cfg!(feature = "check-update"))? + .set_default("sync_address", "https://api.atuin.sh")? + .set_default("sync_frequency", "10m")? + .set_default("search_mode", "fuzzy")? + .set_default("filter_mode", "global")? + .set_default("style", "auto")? + .set_default("inline_height", 0)? + .set_default("show_preview", false)? + .set_default("show_preview_auto", true)? + .set_default("max_preview_height", 4)? + .set_default("show_help", true)? + .set_default("show_tabs", true)? + .set_default("invert", false)? + .set_default("exit_mode", "return-original")? + .set_default("word_jump_mode", "emacs")? + .set_default( + "word_chars", + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789", + )? 
+ .set_default("scroll_context_lines", 1)? + .set_default("shell_up_key_binding", false)? + .set_default("session_token", "")? + .set_default("workspaces", false)? + .set_default("ctrl_n_shortcuts", false)? + .set_default("secrets_filter", true)? + .set_default("network_connect_timeout", 5)? + .set_default("network_timeout", 30)? + .set_default("local_timeout", 2.0)? + // enter_accept defaults to false here, but true in the default config file. The dissonance is + // intentional! + // Existing users will get the default "False", so we don't mess with any potential + // muscle memory. + // New users will get the new default, that is more similar to what they are used to. + .set_default("enter_accept", false)? + .set_default("sync.records", false)? + .set_default("keys.scroll_exits", true)? + .set_default("keymap_mode", "emacs")? + .set_default("keymap_mode_shell", "auto")? + .set_default("keymap_cursor", HashMap::<String, String>::new())? + .set_default("smart_sort", false)? + .set_default("store_failed", true)? + .set_default( + "prefers_reduced_motion", + std::env::var("NO_MOTION") + .ok() + .map(|_| config::Value::new(None, config::ValueKind::Boolean(true))) + .unwrap_or_else(|| config::Value::new(None, config::ValueKind::Boolean(false))), + )? 
+ .add_source( + Environment::with_prefix("atuin") + .prefix_separator("_") + .separator("__"), + )) + } + + pub fn new() -> Result<Self> { + let config_dir = atuin_common::utils::config_dir(); + let data_dir = atuin_common::utils::data_dir(); + + create_dir_all(&config_dir) + .wrap_err_with(|| format!("could not create dir {config_dir:?}"))?; + + create_dir_all(&data_dir).wrap_err_with(|| format!("could not create dir {data_dir:?}"))?; + + let mut config_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { + PathBuf::from(p) + } else { + let mut config_file = PathBuf::new(); + config_file.push(config_dir); + config_file + }; + + config_file.push("config.toml"); + + let mut config_builder = Self::builder()?; + + config_builder = if config_file.exists() { + config_builder.add_source(ConfigFile::new( + config_file.to_str().unwrap(), + FileFormat::Toml, + )) + } else { + let mut file = File::create(config_file).wrap_err("could not create config file")?; + file.write_all(EXAMPLE_CONFIG.as_bytes()) + .wrap_err("could not write default config file")?; + + config_builder + }; + + let config = config_builder.build()?; + let mut settings: Settings = config + .try_deserialize() + .map_err(|e| eyre!("failed to deserialize: {}", e))?; + + // all paths should be expanded + let db_path = settings.db_path; + let db_path = shellexpand::full(&db_path)?; + settings.db_path = db_path.to_string(); + + let key_path = settings.key_path; + let key_path = shellexpand::full(&key_path)?; + settings.key_path = key_path.to_string(); + + let session_path = settings.session_path; + let session_path = shellexpand::full(&session_path)?; + settings.session_path = session_path.to_string(); + + // Finally, set the auth token + if Path::new(session_path.to_string().as_str()).exists() { + let token = fs_err::read_to_string(session_path.to_string())?; + settings.session_token = token.trim().to_string(); + } else { + settings.session_token = String::from("not logged in"); + } + + Ok(settings) + } + 
+ pub fn example_config() -> &'static str { + EXAMPLE_CONFIG + } +} + +impl Default for Settings { + fn default() -> Self { + // if this panics something is very wrong, as the default config + // does not build or deserialize into the settings struct + Self::builder() + .expect("Could not build default") + .build() + .expect("Could not build config") + .try_deserialize() + .expect("Could not deserialize config") + } +} + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use eyre::Result; + + use super::Timezone; + + #[test] + fn can_parse_offset_timezone_spec() -> Result<()> { + assert_eq!(Timezone::from_str("+02")?.0.as_hms(), (2, 0, 0)); + assert_eq!(Timezone::from_str("-04")?.0.as_hms(), (-4, 0, 0)); + assert_eq!(Timezone::from_str("+05:30")?.0.as_hms(), (5, 30, 0)); + assert_eq!(Timezone::from_str("-09:30")?.0.as_hms(), (-9, -30, 0)); + + // single digit hours are allowed + assert_eq!(Timezone::from_str("+2")?.0.as_hms(), (2, 0, 0)); + assert_eq!(Timezone::from_str("-4")?.0.as_hms(), (-4, 0, 0)); + assert_eq!(Timezone::from_str("+5:30")?.0.as_hms(), (5, 30, 0)); + assert_eq!(Timezone::from_str("-9:30")?.0.as_hms(), (-9, -30, 0)); + + // fully qualified form + assert_eq!(Timezone::from_str("+09:30:00")?.0.as_hms(), (9, 30, 0)); + assert_eq!(Timezone::from_str("-09:30:00")?.0.as_hms(), (-9, -30, 0)); + + // these offsets don't really exist but are supported anyway + assert_eq!(Timezone::from_str("+0:5")?.0.as_hms(), (0, 5, 0)); + assert_eq!(Timezone::from_str("-0:5")?.0.as_hms(), (0, -5, 0)); + assert_eq!(Timezone::from_str("+01:23:45")?.0.as_hms(), (1, 23, 45)); + assert_eq!(Timezone::from_str("-01:23:45")?.0.as_hms(), (-1, -23, -45)); + + // require a leading sign for clarity + assert!(Timezone::from_str("5").is_err()); + assert!(Timezone::from_str("10:30").is_err()); + + Ok(()) + } +} diff --git a/crates/atuin-client/src/settings/dotfiles.rs b/crates/atuin-client/src/settings/dotfiles.rs new file mode 100644 index 00000000..dd852781 --- /dev/null +++ 
b/crates/atuin-client/src/settings/dotfiles.rs @@ -0,0 +1,6 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize, Clone, Default)] +pub struct Settings { + pub enabled: bool, +} diff --git a/crates/atuin-client/src/sync.rs b/crates/atuin-client/src/sync.rs new file mode 100644 index 00000000..1f0d3dd8 --- /dev/null +++ b/crates/atuin-client/src/sync.rs @@ -0,0 +1,210 @@ +use std::collections::HashSet; +use std::convert::TryInto; +use std::iter::FromIterator; + +use eyre::Result; + +use atuin_common::api::AddHistoryRequest; +use crypto_secretbox::Key; +use time::OffsetDateTime; + +use crate::{ + api_client, + database::Database, + encryption::{decrypt, encrypt, load_key}, + settings::Settings, +}; + +pub fn hash_str(string: &str) -> String { + use sha2::{Digest, Sha256}; + let mut hasher = Sha256::new(); + hasher.update(string.as_bytes()); + hex::encode(hasher.finalize()) +} + +// Currently sync is kinda naive, and basically just pages backwards through +// history. This means newly added stuff shows up properly! We also just use +// the total count in each database to indicate whether a sync is needed. +// I think this could be massively improved! If we had a way of easily +// indicating count per time period (hour, day, week, year, etc) then we can +// easily pinpoint where we are missing data and what needs downloading. Start +// with year, then find the week, then the day, then the hour, then download it +// all! The current naive approach will do for now. + +// Check if remote has things we don't, and if so, download them. 
+// Returns (num downloaded, total local) +async fn sync_download( + key: &Key, + force: bool, + client: &api_client::Client<'_>, + db: &(impl Database + Send), +) -> Result<(i64, i64)> { + debug!("starting sync download"); + + let remote_status = client.status().await?; + let remote_count = remote_status.count; + + // useful to ensure we don't even save something that hasn't yet been synced + deleted + let remote_deleted = + HashSet::<&str>::from_iter(remote_status.deleted.iter().map(String::as_str)); + + let initial_local = db.history_count(true).await?; + let mut local_count = initial_local; + + let mut last_sync = if force { + OffsetDateTime::UNIX_EPOCH + } else { + Settings::last_sync()? + }; + + let mut last_timestamp = OffsetDateTime::UNIX_EPOCH; + + let host = if force { Some(String::from("")) } else { None }; + + while remote_count > local_count { + let page = client + .get_history(last_sync, last_timestamp, host.clone()) + .await?; + + let history: Vec<_> = page + .history + .iter() + // TODO: handle deletion earlier in this chain + .map(|h| serde_json::from_str(h).expect("invalid base64")) + .map(|h| decrypt(h, key).expect("failed to decrypt history! check your key")) + .map(|mut h| { + if remote_deleted.contains(h.id.0.as_str()) { + h.deleted_at = Some(time::OffsetDateTime::now_utc()); + h.command = String::from(""); + } + + h + }) + .collect(); + + db.save_bulk(&history).await?; + + local_count = db.history_count(true).await?; + + if history.len() < remote_status.page_size.try_into().unwrap() { + break; + } + + let page_last = history + .last() + .expect("could not get last element of page") + .timestamp; + + // in the case of a small sync frequency, it's possible for history to + // be "lost" between syncs. 
In this case we need to rewind the sync + // timestamps + if page_last == last_timestamp { + last_timestamp = OffsetDateTime::UNIX_EPOCH; + last_sync -= time::Duration::hours(1); + } else { + last_timestamp = page_last; + } + } + + for i in remote_status.deleted { + // we will update the stored history to have this data + // pretty much everything can be nullified + if let Some(h) = db.load(i.as_str()).await? { + db.delete(h).await?; + } else { + info!( + "could not delete history with id {}, not found locally", + i.as_str() + ); + } + } + + Ok((local_count - initial_local, local_count)) +} + +// Check if we have things remote doesn't, and if so, upload them +async fn sync_upload( + key: &Key, + _force: bool, + client: &api_client::Client<'_>, + db: &(impl Database + Send), +) -> Result<()> { + debug!("starting sync upload"); + + let remote_status = client.status().await?; + let remote_deleted: HashSet<String> = HashSet::from_iter(remote_status.deleted.clone()); + + let initial_remote_count = client.count().await?; + let mut remote_count = initial_remote_count; + + let local_count = db.history_count(true).await?; + + debug!("remote has {}, we have {}", remote_count, local_count); + + // first just try the most recent set + let mut cursor = OffsetDateTime::now_utc(); + + while local_count > remote_count { + let last = db.before(cursor, remote_status.page_size).await?; + let mut buffer = Vec::new(); + + if last.is_empty() { + break; + } + + for i in last { + let data = encrypt(&i, key)?; + let data = serde_json::to_string(&data)?; + + let add_hist = AddHistoryRequest { + id: i.id.to_string(), + timestamp: i.timestamp, + data, + hostname: hash_str(&i.hostname), + }; + + buffer.push(add_hist); + } + + // anything left over outside of the 100 block size + client.post_history(&buffer).await?; + cursor = buffer.last().unwrap().timestamp; + remote_count = client.count().await?; + + debug!("upload cursor: {:?}", cursor); + } + + let deleted = db.deleted().await?; + + for i 
in deleted { + if remote_deleted.contains(&i.id.to_string()) { + continue; + } + + info!("deleting {} on remote", i.id); + client.delete_history(i).await?; + } + + Ok(()) +} + +pub async fn sync(settings: &Settings, force: bool, db: &(impl Database + Send)) -> Result<()> { + let client = api_client::Client::new( + &settings.sync_address, + &settings.session_token, + settings.network_connect_timeout, + settings.network_timeout, + )?; + + Settings::save_sync_time()?; + + let key = load_key(settings)?; // encryption key + + sync_upload(&key, force, &client, db).await?; + + let download = sync_download(&key, force, &client, db).await?; + + debug!("sync downloaded {}", download.0); + + Ok(()) +} diff --git a/crates/atuin-client/src/utils.rs b/crates/atuin-client/src/utils.rs new file mode 100644 index 00000000..a7c6eab0 --- /dev/null +++ b/crates/atuin-client/src/utils.rs @@ -0,0 +1,14 @@ +pub(crate) fn get_hostname() -> String { + std::env::var("ATUIN_HOST_NAME").unwrap_or_else(|_| { + whoami::fallible::hostname().unwrap_or_else(|_| "unknown-host".to_string()) + }) +} + +pub(crate) fn get_username() -> String { + std::env::var("ATUIN_HOST_USER").unwrap_or_else(|_| whoami::username()) +} + +/// Returns a pair of the hostname and username, separated by a colon. 
+pub(crate) fn get_host_user() -> String { + format!("{}:{}", get_hostname(), get_username()) +} diff --git a/crates/atuin-client/tests/data/xonsh-history.sqlite b/crates/atuin-client/tests/data/xonsh-history.sqlite Binary files differnew file mode 100644 index 00000000..744fcf86 --- /dev/null +++ b/crates/atuin-client/tests/data/xonsh-history.sqlite diff --git a/crates/atuin-client/tests/data/xonsh/xonsh-82eafbf5-9f43-489a-80d2-61c7dc6ef542.json b/crates/atuin-client/tests/data/xonsh/xonsh-82eafbf5-9f43-489a-80d2-61c7dc6ef542.json new file mode 100644 index 00000000..339a09f1 --- /dev/null +++ b/crates/atuin-client/tests/data/xonsh/xonsh-82eafbf5-9f43-489a-80d2-61c7dc6ef542.json @@ -0,0 +1,12 @@ +{"locs": [ 69, 3371, 3451, 3978], + "index": {"offsets":{"__total__":0,"cmds":[{"__total__":10,"cwd":18,"inp":78,"rtn":96,"ts":[106,125,105]},{"__total__":149,"cwd":157,"inp":217,"rtn":234,"ts":[244,263,243]},9],"env":{"ATUIN_SESSION":314,"BASH_COMPLETIONS":370,"COLORTERM":433,"DBUS_SESSION_BUS_ADDRESS":474,"DESKTOP_SESSION":529,"DISPLAY":550,"GDMSESSION":570,"GIO_LAUNCHED_DESKTOP_FILE":609,"GIO_LAUNCHED_DESKTOP_FILE_PID":704,"GJS_DEBUG_OUTPUT":734,"GJS_DEBUG_TOPICS":764,"GNOME_DESKTOP_SESSION_ID":811,"GNOME_SETUP_DISPLAY":856,"GNOME_SHELL_SESSION_MODE":890,"GTK_MODULES":915,"HOME":942,"IM_CONFIG_PHASE":976,"INVOCATION_ID":998,"JOURNAL_STREAM":1052,"LANG":1071,"LOGNAME":1097,"MANAGERPID":1118,"MOZ_ENABLE_WAYLAND":1148,"PATH":1161,"PWD":1736,"PYENV_DIR":1802,"PYENV_HOOK_PATH":1874,"PYENV_ROOT":2048,"PYENV_SHELL":2086,"PYENV_VERSION":2111,"QT_ACCESSIBILITY":2141,"QT_IM_MODULE":2162,"SESSION_MANAGER":2189,"SHELL":2279,"SHLVL":2303,"SSH_AGENT_LAUNCHER":2330,"SSH_AUTH_SOCK":2364,"SSL_CERT_DIR":2415,"SSL_CERT_FILE":2458,"SYSTEMD_EXEC_PID":2525,"TERM":2541,"TERM_PROGRAM":2575,"TERM_PROGRAM_VERSION":2610,"THREAD_SUBPROCS":2657,"USER":2670,"USERNAME":2689,"WAYLAND_DISPLAY":2715,"WEZTERM_CONFIG_DIR":2750,"WEZTERM_CONFIG_FILE":2806,"WEZTERM_EXECUTABLE":2874,"WEZTERM_EXECUTABLE_DIR":2
927,"WEZTERM_PANE":2957,"WEZTERM_UNIX_SOCKET":2986,"XAUTHORITY":3047,"XDG_CONFIG_DIRS":3116,"XDG_CURRENT_DESKTOP":3176,"XDG_DATA_DIRS":3209,"XDG_MENU_PREFIX":3316,"XDG_RUNTIME_DIR":3345,"XDG_SESSION_CLASS":3387,"XDG_SESSION_DESKTOP":3418,"XDG_SESSION_TYPE":3448,"XMODIFIERS":3473,"XONSHRC":3496,"XONSHRC_DIR":3594,"XONSH_CAPTURE_ALWAYS":3674,"XONSH_CONFIG_DIR":3698,"XONSH_DATA_DIR":3747,"XONSH_INTERACTIVE":3805,"XONSH_LOGIN":3825,"XONSH_VERSION":3847,"__total__":296},"locked":3869,"sessionid":3889,"ts":[3936,3956,3935]},"sizes":{"__total__":3978,"cmds":[{"__total__":137,"cwd":51,"inp":9,"rtn":1,"ts":[17,18,40]},{"__total__":136,"cwd":51,"inp":8,"rtn":1,"ts":[17,18,40]},278],"env":{"ATUIN_SESSION":34,"BASH_COMPLETIONS":48,"COLORTERM":11,"DBUS_SESSION_BUS_ADDRESS":34,"DESKTOP_SESSION":8,"DISPLAY":4,"GDMSESSION":8,"GIO_LAUNCHED_DESKTOP_FILE":60,"GIO_LAUNCHED_DESKTOP_FILE_PID":8,"GJS_DEBUG_OUTPUT":8,"GJS_DEBUG_TOPICS":17,"GNOME_DESKTOP_SESSION_ID":20,"GNOME_SETUP_DISPLAY":4,"GNOME_SHELL_SESSION_MODE":8,"GTK_MODULES":17,"HOME":13,"IM_CONFIG_PHASE":3,"INVOCATION_ID":34,"JOURNAL_STREAM":9,"LANG":13,"LOGNAME":5,"MANAGERPID":6,"MOZ_ENABLE_WAYLAND":3,"PATH":566,"PWD":51,"PYENV_DIR":51,"PYENV_HOOK_PATH":158,"PYENV_ROOT":21,"PYENV_SHELL":6,"PYENV_VERSION":8,"QT_ACCESSIBILITY":3,"QT_IM_MODULE":6,"SESSION_MANAGER":79,"SHELL":13,"SHLVL":3,"SSH_AGENT_LAUNCHER":15,"SSH_AUTH_SOCK":33,"SSL_CERT_DIR":24,"SSL_CERT_FILE":45,"SYSTEMD_EXEC_PID":6,"TERM":16,"TERM_PROGRAM":9,"TERM_PROGRAM_VERSION":26,"THREAD_SUBPROCS":3,"USER":5,"USERNAME":5,"WAYLAND_DISPLAY":11,"WEZTERM_CONFIG_DIR":31,"WEZTERM_CONFIG_FILE":44,"WEZTERM_EXECUTABLE":25,"WEZTERM_EXECUTABLE_DIR":12,"WEZTERM_PANE":4,"WEZTERM_UNIX_SOCKET":45,"XAUTHORITY":48,"XDG_CONFIG_DIRS":35,"XDG_CURRENT_DESKTOP":14,"XDG_DATA_DIRS":86,"XDG_MENU_PREFIX":8,"XDG_RUNTIME_DIR":19,"XDG_SESSION_CLASS":6,"XDG_SESSION_DESKTOP":8,"XDG_SESSION_TYPE":9,"XMODIFIERS":10,"XONSHRC":81,"XONSHRC_DIR":54,"XONSH_CAPTURE_ALWAYS":2,"XONSH_CONFIG_DIR":29,"XONSH_DATA_DI
R":35,"XONSH_INTERACTIVE":3,"XONSH_LOGIN":3,"XONSH_VERSION":8,"__total__":3561},"locked":5,"sessionid":38,"ts":[18,18,41]}}, + "data": {"cmds": [{"cwd": "\/home\/user\/Documents\/code\/atuin\/atuin-client", "inp": "false\n", "rtn": 1, "ts": [1707241291.142516, 1707241291.1527853] +} +, {"cwd": "\/home\/user\/Documents\/code\/atuin\/atuin-client", "inp": "exit\n", "rtn": 0, "ts": [1707241292.271584, 1707241292.2758434] +} +] +, "env": {"ATUIN_SESSION": "018d7f82ad167dc4888ca0bf294d2bfd", "BASH_COMPLETIONS": "\/usr\/share\/bash-completion\/bash_completion", "COLORTERM": "truecolor", "DBUS_SESSION_BUS_ADDRESS": "unix:path=\/run\/user\/1000\/bus", "DESKTOP_SESSION": "ubuntu", "DISPLAY": ":0", "GDMSESSION": "ubuntu", "GIO_LAUNCHED_DESKTOP_FILE": "\/usr\/share\/applications\/org.wezfurlong.wezterm.desktop", "GIO_LAUNCHED_DESKTOP_FILE_PID": "196859", "GJS_DEBUG_OUTPUT": "stderr", "GJS_DEBUG_TOPICS": "JS ERROR;JS LOG", "GNOME_DESKTOP_SESSION_ID": "this-is-deprecated", "GNOME_SETUP_DISPLAY": ":1", "GNOME_SHELL_SESSION_MODE": "ubuntu", "GTK_MODULES": "gail:atk-bridge", "HOME": "\/home\/user", "IM_CONFIG_PHASE": "1", "INVOCATION_ID": "4f121e7ad56c41a6b84aa3cbe1ad61fa", "JOURNAL_STREAM": "8:37187", "LANG": "en_US.UTF-8", "LOGNAME": "user", "MANAGERPID": "2118", "MOZ_ENABLE_WAYLAND": "1", "PATH": "\/home\/user\/.pyenv\/versions\/3.12.0\/bin:\/home\/user\/.pyenv\/libexec:\/home\/user\/.pyenv\/plugins\/python-build\/bin:\/home\/user\/.pyenv\/plugins\/pyenv-virtualenv\/bin:\/home\/user\/.pyenv\/plugins\/pyenv-update\/bin:\/home\/user\/.pyenv\/plugins\/pyenv-doctor\/bin:\/home\/user\/.cargo\/bin:\/home\/user\/.pyenv\/shims:\/home\/user\/.pyenv\/bin:\/home\/user\/bin:\/home\/user\/bin:\/usr\/local\/sbin:\/usr\/local\/bin:\/usr\/sbin:\/usr\/bin:\/sbin:\/bin:\/usr\/games:\/usr\/local\/games:\/snap\/bin:\/snap\/bin:\/home\/user\/.local\/share\/JetBrains\/Toolbox\/scripts", "PWD": "\/home\/user\/Documents\/code\/atuin\/atuin-client", "PYENV_DIR": 
"\/home\/user\/Documents\/code\/atuin\/atuin-client", "PYENV_HOOK_PATH": "\/home\/user\/.pyenv\/pyenv.d:\/usr\/local\/etc\/pyenv.d:\/etc\/pyenv.d:\/usr\/lib\/pyenv\/hooks:\/home\/user\/.pyenv\/plugins\/pyenv-virtualenv\/etc\/pyenv.d", "PYENV_ROOT": "\/home\/user\/.pyenv", "PYENV_SHELL": "bash", "PYENV_VERSION": "3.12.0", "QT_ACCESSIBILITY": "1", "QT_IM_MODULE": "ibus", "SESSION_MANAGER": "local\/box:@\/tmp\/.ICE-unix\/2452,unix\/box:\/tmp\/.ICE-unix\/2452", "SHELL": "\/bin\/bash", "SHLVL": "1", "SSH_AGENT_LAUNCHER": "gnome-keyring", "SSH_AUTH_SOCK": "\/run\/user\/1000\/keyring\/ssh", "SSL_CERT_DIR": "\/usr\/lib\/ssl\/certs", "SSL_CERT_FILE": "\/usr\/lib\/ssl\/certs\/ca-certificates.crt", "SYSTEMD_EXEC_PID": "2470", "TERM": "xterm-256color", "TERM_PROGRAM": "WezTerm", "TERM_PROGRAM_VERSION": "20240127-113634-bbcac864", "THREAD_SUBPROCS": "1", "USER": "user", "USERNAME": "user", "WAYLAND_DISPLAY": "wayland-0", "WEZTERM_CONFIG_DIR": "\/home\/user\/.config\/wezterm", "WEZTERM_CONFIG_FILE": "\/home\/user\/.config\/wezterm\/wezterm.lua", "WEZTERM_EXECUTABLE": "\/usr\/bin\/wezterm-gui", "WEZTERM_EXECUTABLE_DIR": "\/usr\/bin", "WEZTERM_PANE": "41", "WEZTERM_UNIX_SOCKET": "\/run\/user\/1000\/wezterm\/gui-sock-196859", "XAUTHORITY": "\/run\/user\/1000\/.mutter-Xwaylandauth.T986H2", "XDG_CONFIG_DIRS": "\/etc\/xdg\/xdg-ubuntu:\/etc\/xdg", "XDG_CURRENT_DESKTOP": "ubuntu:GNOME", "XDG_DATA_DIRS": "\/usr\/share\/ubuntu:\/usr\/local\/share\/:\/usr\/share\/:\/var\/lib\/snapd\/desktop", "XDG_MENU_PREFIX": "gnome-", "XDG_RUNTIME_DIR": "\/run\/user\/1000", "XDG_SESSION_CLASS": "user", "XDG_SESSION_DESKTOP": "ubuntu", "XDG_SESSION_TYPE": "wayland", "XMODIFIERS": "@im=ibus", "XONSHRC": "\/etc\/xonsh\/xonshrc:\/home\/user\/.config\/xonsh\/rc.xsh:\/home\/user\/.xonshrc", "XONSHRC_DIR": "\/etc\/xonsh\/rc.d:\/home\/user\/.config\/xonsh\/rc.d", "XONSH_CAPTURE_ALWAYS": "", "XONSH_CONFIG_DIR": "\/home\/user\/.config\/xonsh", "XONSH_DATA_DIR": "\/home\/user\/.local\/share\/xonsh", 
"XONSH_INTERACTIVE": "1", "XONSH_LOGIN": "1", "XONSH_VERSION": "0.14.2"} +, "locked": false, "sessionid": "82eafbf5-9f43-489a-80d2-61c7dc6ef542", "ts": [1707241286.9361255, 1707241292.3081477] +} + +} diff --git a/crates/atuin-client/tests/data/xonsh/xonsh-de16af90-9148-4461-8df3-5b5659c6420d.json b/crates/atuin-client/tests/data/xonsh/xonsh-de16af90-9148-4461-8df3-5b5659c6420d.json new file mode 100644 index 00000000..72694f04 --- /dev/null +++ b/crates/atuin-client/tests/data/xonsh/xonsh-de16af90-9148-4461-8df3-5b5659c6420d.json @@ -0,0 +1,12 @@ +{"locs": [ 69, 3372, 3452, 3936], + "index": {"offsets":{"__total__":0,"cmds":[{"__total__":10,"cwd":18,"inp":64,"rtn":94,"ts":[104,124,103]},{"__total__":148,"cwd":156,"inp":202,"rtn":220,"ts":[230,250,229]},9],"env":{"ATUIN_SESSION":300,"BASH_COMPLETIONS":356,"COLORTERM":419,"DBUS_SESSION_BUS_ADDRESS":460,"DESKTOP_SESSION":515,"DISPLAY":536,"GDMSESSION":556,"GIO_LAUNCHED_DESKTOP_FILE":595,"GIO_LAUNCHED_DESKTOP_FILE_PID":690,"GJS_DEBUG_OUTPUT":720,"GJS_DEBUG_TOPICS":750,"GNOME_DESKTOP_SESSION_ID":797,"GNOME_SETUP_DISPLAY":842,"GNOME_SHELL_SESSION_MODE":876,"GTK_MODULES":901,"HOME":928,"IM_CONFIG_PHASE":962,"INVOCATION_ID":984,"JOURNAL_STREAM":1038,"LANG":1057,"LOGNAME":1083,"MANAGERPID":1104,"MOZ_ENABLE_WAYLAND":1134,"PATH":1147,"PWD":1722,"PYENV_DIR":1774,"PYENV_HOOK_PATH":1832,"PYENV_ROOT":2006,"PYENV_SHELL":2044,"PYENV_VERSION":2069,"QT_ACCESSIBILITY":2099,"QT_IM_MODULE":2120,"SESSION_MANAGER":2147,"SHELL":2237,"SHLVL":2261,"SSH_AGENT_LAUNCHER":2288,"SSH_AUTH_SOCK":2322,"SSL_CERT_DIR":2373,"SSL_CERT_FILE":2416,"SYSTEMD_EXEC_PID":2483,"TERM":2499,"TERM_PROGRAM":2533,"TERM_PROGRAM_VERSION":2568,"THREAD_SUBPROCS":2615,"USER":2628,"USERNAME":2647,"WAYLAND_DISPLAY":2673,"WEZTERM_CONFIG_DIR":2708,"WEZTERM_CONFIG_FILE":2764,"WEZTERM_EXECUTABLE":2832,"WEZTERM_EXECUTABLE_DIR":2885,"WEZTERM_PANE":2915,"WEZTERM_UNIX_SOCKET":2944,"XAUTHORITY":3005,"XDG_CONFIG_DIRS":3074,"XDG_CURRENT_DESKTOP":3134,"XDG_DATA_DIRS":3167,"XDG_MENU_PR
EFIX":3274,"XDG_RUNTIME_DIR":3303,"XDG_SESSION_CLASS":3345,"XDG_SESSION_DESKTOP":3376,"XDG_SESSION_TYPE":3406,"XMODIFIERS":3431,"XONSHRC":3454,"XONSHRC_DIR":3552,"XONSH_CAPTURE_ALWAYS":3632,"XONSH_CONFIG_DIR":3656,"XONSH_DATA_DIR":3705,"XONSH_INTERACTIVE":3763,"XONSH_LOGIN":3783,"XONSH_VERSION":3805,"__total__":282},"locked":3827,"sessionid":3847,"ts":[3894,3914,3893]},"sizes":{"__total__":3936,"cmds":[{"__total__":136,"cwd":37,"inp":21,"rtn":1,"ts":[18,18,41]},{"__total__":123,"cwd":37,"inp":9,"rtn":1,"ts":[18,17,40]},264],"env":{"ATUIN_SESSION":34,"BASH_COMPLETIONS":48,"COLORTERM":11,"DBUS_SESSION_BUS_ADDRESS":34,"DESKTOP_SESSION":8,"DISPLAY":4,"GDMSESSION":8,"GIO_LAUNCHED_DESKTOP_FILE":60,"GIO_LAUNCHED_DESKTOP_FILE_PID":8,"GJS_DEBUG_OUTPUT":8,"GJS_DEBUG_TOPICS":17,"GNOME_DESKTOP_SESSION_ID":20,"GNOME_SETUP_DISPLAY":4,"GNOME_SHELL_SESSION_MODE":8,"GTK_MODULES":17,"HOME":13,"IM_CONFIG_PHASE":3,"INVOCATION_ID":34,"JOURNAL_STREAM":9,"LANG":13,"LOGNAME":5,"MANAGERPID":6,"MOZ_ENABLE_WAYLAND":3,"PATH":566,"PWD":37,"PYENV_DIR":37,"PYENV_HOOK_PATH":158,"PYENV_ROOT":21,"PYENV_SHELL":6,"PYENV_VERSION":8,"QT_ACCESSIBILITY":3,"QT_IM_MODULE":6,"SESSION_MANAGER":79,"SHELL":13,"SHLVL":3,"SSH_AGENT_LAUNCHER":15,"SSH_AUTH_SOCK":33,"SSL_CERT_DIR":24,"SSL_CERT_FILE":45,"SYSTEMD_EXEC_PID":6,"TERM":16,"TERM_PROGRAM":9,"TERM_PROGRAM_VERSION":26,"THREAD_SUBPROCS":3,"USER":5,"USERNAME":5,"WAYLAND_DISPLAY":11,"WEZTERM_CONFIG_DIR":31,"WEZTERM_CONFIG_FILE":44,"WEZTERM_EXECUTABLE":25,"WEZTERM_EXECUTABLE_DIR":12,"WEZTERM_PANE":4,"WEZTERM_UNIX_SOCKET":45,"XAUTHORITY":48,"XDG_CONFIG_DIRS":35,"XDG_CURRENT_DESKTOP":14,"XDG_DATA_DIRS":86,"XDG_MENU_PREFIX":8,"XDG_RUNTIME_DIR":19,"XDG_SESSION_CLASS":6,"XDG_SESSION_DESKTOP":8,"XDG_SESSION_TYPE":9,"XMODIFIERS":10,"XONSHRC":81,"XONSHRC_DIR":54,"XONSH_CAPTURE_ALWAYS":2,"XONSH_CONFIG_DIR":29,"XONSH_DATA_DIR":35,"XONSH_INTERACTIVE":3,"XONSH_LOGIN":3,"XONSH_VERSION":8,"__total__":3533},"locked":5,"sessionid":38,"ts":[18,18,41]}}, + "data": {"cmds": 
[{"cwd": "\/home\/user\/Documents\/code\/atuin", "inp": "echo hello world!\n", "rtn": 0, "ts": [1707193079.4782722, 1707193079.4829233] +} +, {"cwd": "\/home\/user\/Documents\/code\/atuin", "inp": "ls -l\n", "rtn": 0, "ts": [1707193081.7063284, 1707193081.727617] +} +] +, "env": {"ATUIN_SESSION": "018d7ca2e953742e9826012f30115040", "BASH_COMPLETIONS": "\/usr\/share\/bash-completion\/bash_completion", "COLORTERM": "truecolor", "DBUS_SESSION_BUS_ADDRESS": "unix:path=\/run\/user\/1000\/bus", "DESKTOP_SESSION": "ubuntu", "DISPLAY": ":0", "GDMSESSION": "ubuntu", "GIO_LAUNCHED_DESKTOP_FILE": "\/usr\/share\/applications\/org.wezfurlong.wezterm.desktop", "GIO_LAUNCHED_DESKTOP_FILE_PID": "196859", "GJS_DEBUG_OUTPUT": "stderr", "GJS_DEBUG_TOPICS": "JS ERROR;JS LOG", "GNOME_DESKTOP_SESSION_ID": "this-is-deprecated", "GNOME_SETUP_DISPLAY": ":1", "GNOME_SHELL_SESSION_MODE": "ubuntu", "GTK_MODULES": "gail:atk-bridge", "HOME": "\/home\/user", "IM_CONFIG_PHASE": "1", "INVOCATION_ID": "4f121e7ad56c41a6b84aa3cbe1ad61fa", "JOURNAL_STREAM": "8:37187", "LANG": "en_US.UTF-8", "LOGNAME": "user", "MANAGERPID": "2118", "MOZ_ENABLE_WAYLAND": "1", "PATH": "\/home\/user\/.pyenv\/versions\/3.12.0\/bin:\/home\/user\/.pyenv\/libexec:\/home\/user\/.pyenv\/plugins\/python-build\/bin:\/home\/user\/.pyenv\/plugins\/pyenv-virtualenv\/bin:\/home\/user\/.pyenv\/plugins\/pyenv-update\/bin:\/home\/user\/.pyenv\/plugins\/pyenv-doctor\/bin:\/home\/user\/.cargo\/bin:\/home\/user\/.pyenv\/shims:\/home\/user\/.pyenv\/bin:\/home\/user\/bin:\/home\/user\/bin:\/usr\/local\/sbin:\/usr\/local\/bin:\/usr\/sbin:\/usr\/bin:\/sbin:\/bin:\/usr\/games:\/usr\/local\/games:\/snap\/bin:\/snap\/bin:\/home\/user\/.local\/share\/JetBrains\/Toolbox\/scripts", "PWD": "\/home\/user\/Documents\/code\/atuin", "PYENV_DIR": "\/home\/user\/Documents\/code\/atuin", "PYENV_HOOK_PATH": 
"\/home\/user\/.pyenv\/pyenv.d:\/usr\/local\/etc\/pyenv.d:\/etc\/pyenv.d:\/usr\/lib\/pyenv\/hooks:\/home\/user\/.pyenv\/plugins\/pyenv-virtualenv\/etc\/pyenv.d", "PYENV_ROOT": "\/home\/user\/.pyenv", "PYENV_SHELL": "bash", "PYENV_VERSION": "3.12.0", "QT_ACCESSIBILITY": "1", "QT_IM_MODULE": "ibus", "SESSION_MANAGER": "local\/box:@\/tmp\/.ICE-unix\/2452,unix\/box:\/tmp\/.ICE-unix\/2452", "SHELL": "\/bin\/bash", "SHLVL": "1", "SSH_AGENT_LAUNCHER": "gnome-keyring", "SSH_AUTH_SOCK": "\/run\/user\/1000\/keyring\/ssh", "SSL_CERT_DIR": "\/usr\/lib\/ssl\/certs", "SSL_CERT_FILE": "\/usr\/lib\/ssl\/certs\/ca-certificates.crt", "SYSTEMD_EXEC_PID": "2470", "TERM": "xterm-256color", "TERM_PROGRAM": "WezTerm", "TERM_PROGRAM_VERSION": "20240127-113634-bbcac864", "THREAD_SUBPROCS": "1", "USER": "user", "USERNAME": "user", "WAYLAND_DISPLAY": "wayland-0", "WEZTERM_CONFIG_DIR": "\/home\/user\/.config\/wezterm", "WEZTERM_CONFIG_FILE": "\/home\/user\/.config\/wezterm\/wezterm.lua", "WEZTERM_EXECUTABLE": "\/usr\/bin\/wezterm-gui", "WEZTERM_EXECUTABLE_DIR": "\/usr\/bin", "WEZTERM_PANE": "38", "WEZTERM_UNIX_SOCKET": "\/run\/user\/1000\/wezterm\/gui-sock-196859", "XAUTHORITY": "\/run\/user\/1000\/.mutter-Xwaylandauth.T986H2", "XDG_CONFIG_DIRS": "\/etc\/xdg\/xdg-ubuntu:\/etc\/xdg", "XDG_CURRENT_DESKTOP": "ubuntu:GNOME", "XDG_DATA_DIRS": "\/usr\/share\/ubuntu:\/usr\/local\/share\/:\/usr\/share\/:\/var\/lib\/snapd\/desktop", "XDG_MENU_PREFIX": "gnome-", "XDG_RUNTIME_DIR": "\/run\/user\/1000", "XDG_SESSION_CLASS": "user", "XDG_SESSION_DESKTOP": "ubuntu", "XDG_SESSION_TYPE": "wayland", "XMODIFIERS": "@im=ibus", "XONSHRC": "\/etc\/xonsh\/xonshrc:\/home\/user\/.config\/xonsh\/rc.xsh:\/home\/user\/.xonshrc", "XONSHRC_DIR": "\/etc\/xonsh\/rc.d:\/home\/user\/.config\/xonsh\/rc.d", "XONSH_CAPTURE_ALWAYS": "", "XONSH_CONFIG_DIR": "\/home\/user\/.config\/xonsh", "XONSH_DATA_DIR": "\/home\/user\/.local\/share\/xonsh", "XONSH_INTERACTIVE": "1", "XONSH_LOGIN": "1", "XONSH_VERSION": "0.14.2"} +, "locked": 
false, "sessionid": "de16af90-9148-4461-8df3-5b5659c6420d", "ts": [1707193067.8615997, 1707193089.2513068] +} + +} diff --git a/crates/atuin-common/Cargo.toml b/crates/atuin-common/Cargo.toml new file mode 100644 index 00000000..85e41ef6 --- /dev/null +++ b/crates/atuin-common/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "atuin-common" +edition = "2021" +description = "common library for atuin" + +rust-version = { workspace = true } +version = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +time = { workspace = true } +serde = { workspace = true } +uuid = { workspace = true } +rand = { workspace = true } +typed-builder = { workspace = true } +eyre = { workspace = true } +sqlx = { workspace = true } +semver = { workspace = true } +thiserror = { workspace = true } +sysinfo = "0.30.7" + +lazy_static = "1.4.0" + +[dev-dependencies] +pretty_assertions = { workspace = true } diff --git a/crates/atuin-common/src/api.rs b/crates/atuin-common/src/api.rs new file mode 100644 index 00000000..99b57cec --- /dev/null +++ b/crates/atuin-common/src/api.rs @@ -0,0 +1,122 @@ +use lazy_static::lazy_static; +use semver::Version; +use serde::{Deserialize, Serialize}; +use std::borrow::Cow; +use time::OffsetDateTime; + +// the usage of X- has been deprecated for quite along time, it turns out +pub static ATUIN_HEADER_VERSION: &str = "Atuin-Version"; +pub static ATUIN_CARGO_VERSION: &str = env!("CARGO_PKG_VERSION"); + +lazy_static! 
{ + pub static ref ATUIN_VERSION: Version = + Version::parse(ATUIN_CARGO_VERSION).expect("failed to parse self semver"); +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct UserResponse { + pub username: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct RegisterRequest { + pub email: String, + pub username: String, + pub password: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct RegisterResponse { + pub session: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct DeleteUserResponse {} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ChangePasswordRequest { + pub current_password: String, + pub new_password: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ChangePasswordResponse {} + +#[derive(Debug, Serialize, Deserialize)] +pub struct LoginRequest { + pub username: String, + pub password: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct LoginResponse { + pub session: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct AddHistoryRequest { + pub id: String, + #[serde(with = "time::serde::rfc3339")] + pub timestamp: OffsetDateTime, + pub data: String, + pub hostname: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CountResponse { + pub count: i64, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SyncHistoryRequest { + #[serde(with = "time::serde::rfc3339")] + pub sync_ts: OffsetDateTime, + #[serde(with = "time::serde::rfc3339")] + pub history_ts: OffsetDateTime, + pub host: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SyncHistoryResponse { + pub history: Vec<String>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ErrorResponse<'a> { + pub reason: Cow<'a, str>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct IndexResponse { + pub homage: String, + pub version: String, + pub total_history: i64, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct StatusResponse { + pub count: 
i64, + pub username: String, + pub deleted: Vec<String>, + + // These could/should also go on the index of the server + // However, we do not request the server index as a part of normal sync + // I'd rather slightly increase the size of this response, than add an extra HTTP request + pub page_size: i64, // max page size supported by the server + pub version: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct DeleteHistoryRequest { + pub client_id: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct MessageResponse { + pub message: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct MeResponse { + pub username: String, +} diff --git a/crates/atuin-common/src/calendar.rs b/crates/atuin-common/src/calendar.rs new file mode 100644 index 00000000..d3b1d921 --- /dev/null +++ b/crates/atuin-common/src/calendar.rs @@ -0,0 +1,16 @@ +// Calendar data +use serde::{Serialize, Deserialize}; + +pub enum TimePeriod { + YEAR, + MONTH, + DAY, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TimePeriodInfo { + pub count: u64, + + // TODO: Use this for merkle tree magic + pub hash: String, +} diff --git a/crates/atuin-common/src/lib.rs b/crates/atuin-common/src/lib.rs new file mode 100644 index 00000000..2d848f6f --- /dev/null +++ b/crates/atuin-common/src/lib.rs @@ -0,0 +1,58 @@ +#![forbid(unsafe_code)] + +/// Defines a new UUID type wrapper +macro_rules! 
new_uuid { + ($name:ident) => { + #[derive( + Debug, + Copy, + Clone, + PartialEq, + Eq, + Hash, + PartialOrd, + Ord, + serde::Serialize, + serde::Deserialize, + )] + #[serde(transparent)] + pub struct $name(pub Uuid); + + impl<DB: sqlx::Database> sqlx::Type<DB> for $name + where + Uuid: sqlx::Type<DB>, + { + fn type_info() -> <DB as sqlx::Database>::TypeInfo { + Uuid::type_info() + } + } + + impl<'r, DB: sqlx::Database> sqlx::Decode<'r, DB> for $name + where + Uuid: sqlx::Decode<'r, DB>, + { + fn decode( + value: <DB as sqlx::database::HasValueRef<'r>>::ValueRef, + ) -> std::result::Result<Self, sqlx::error::BoxDynError> { + Uuid::decode(value).map(Self) + } + } + + impl<'q, DB: sqlx::Database> sqlx::Encode<'q, DB> for $name + where + Uuid: sqlx::Encode<'q, DB>, + { + fn encode_by_ref( + &self, + buf: &mut <DB as sqlx::database::HasArguments<'q>>::ArgumentBuffer, + ) -> sqlx::encode::IsNull { + self.0.encode_by_ref(buf) + } + } + }; +} + +pub mod api; +pub mod record; +pub mod shell; +pub mod utils; diff --git a/crates/atuin-common/src/record.rs b/crates/atuin-common/src/record.rs new file mode 100644 index 00000000..e6ce2647 --- /dev/null +++ b/crates/atuin-common/src/record.rs @@ -0,0 +1,426 @@ +use std::collections::HashMap; + +use eyre::Result; +use serde::{Deserialize, Serialize}; +use typed_builder::TypedBuilder; +use uuid::Uuid; + +#[derive(Clone, Debug, PartialEq)] +pub struct DecryptedData(pub Vec<u8>); + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct EncryptedData { + pub data: String, + pub content_encryption_key: String, +} + +#[derive(Debug, PartialEq, PartialOrd, Ord, Eq)] +pub struct Diff { + pub host: HostId, + pub tag: String, + pub local: Option<RecordIdx>, + pub remote: Option<RecordIdx>, +} + +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] +pub struct Host { + pub id: HostId, + pub name: String, +} + +impl Host { + pub fn new(id: HostId) -> Self { + Host { + id, + name: String::new(), + } + } +} + 
+new_uuid!(RecordId); +new_uuid!(HostId); + +pub type RecordIdx = u64; + +/// A single record stored inside of our local database +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TypedBuilder)] +pub struct Record<Data> { + /// a unique ID + #[builder(default = RecordId(crate::utils::uuid_v7()))] + pub id: RecordId, + + /// The integer record ID. This is only unique per (host, tag). + pub idx: RecordIdx, + + /// The unique ID of the host. + // TODO(ellie): Optimize the storage here. We use a bunch of IDs, and currently store + // as strings. I would rather avoid normalization, so store as UUID binary instead of + // encoding to a string and wasting much more storage. + pub host: Host, + + /// The creation time in nanoseconds since unix epoch + #[builder(default = time::OffsetDateTime::now_utc().unix_timestamp_nanos() as u64)] + pub timestamp: u64, + + /// The version the data in the entry conforms to + // However we want to track versions for this tag, eg v2 + pub version: String, + + /// The type of data we are storing here. Eg, "history" + pub tag: String, + + /// Some data. This can be anything you wish to store. Use the tag field to know how to handle it. 
+ pub data: Data, +} + +/// Extra data from the record that should be encoded in the data +#[derive(Debug, Copy, Clone)] +pub struct AdditionalData<'a> { + pub id: &'a RecordId, + pub idx: &'a u64, + pub version: &'a str, + pub tag: &'a str, + pub host: &'a HostId, +} + +impl<Data> Record<Data> { + pub fn append(&self, data: Vec<u8>) -> Record<DecryptedData> { + Record::builder() + .host(self.host.clone()) + .version(self.version.clone()) + .idx(self.idx + 1) + .tag(self.tag.clone()) + .data(DecryptedData(data)) + .build() + } +} + +/// An index representing the current state of the record stores +/// This can be both remote, or local, and compared in either direction +#[derive(Debug, Serialize, Deserialize)] +pub struct RecordStatus { + // A map of host -> tag -> max(idx) + pub hosts: HashMap<HostId, HashMap<String, RecordIdx>>, +} + +impl Default for RecordStatus { + fn default() -> Self { + Self::new() + } +} + +impl Extend<(HostId, String, RecordIdx)> for RecordStatus { + fn extend<T: IntoIterator<Item = (HostId, String, RecordIdx)>>(&mut self, iter: T) { + for (host, tag, tail_idx) in iter { + self.set_raw(host, tag, tail_idx); + } + } +} + +impl RecordStatus { + pub fn new() -> RecordStatus { + RecordStatus { + hosts: HashMap::new(), + } + } + + /// Insert a new tail record into the store + pub fn set(&mut self, tail: Record<DecryptedData>) { + self.set_raw(tail.host.id, tail.tag, tail.idx) + } + + pub fn set_raw(&mut self, host: HostId, tag: String, tail_id: RecordIdx) { + self.hosts.entry(host).or_default().insert(tag, tail_id); + } + + pub fn get(&self, host: HostId, tag: String) -> Option<RecordIdx> { + self.hosts.get(&host).and_then(|v| v.get(&tag)).cloned() + } + + /// Diff this index with another, likely remote index. + /// The two diffs can then be reconciled, and the optimal change set calculated + /// Returns a tuple, with (host, tag, Option(OTHER)) + /// OTHER is set to the value of the idx on the other machine. 
If it is greater than our index, + /// then we need to do some downloading. If it is smaller, then we need to do some uploading + /// Note that we cannot upload if we are not the owner of the record store - hosts can only + /// write to their own store. + pub fn diff(&self, other: &Self) -> Vec<Diff> { + let mut ret = Vec::new(); + + // First, we check if other has everything that self has + for (host, tag_map) in self.hosts.iter() { + for (tag, idx) in tag_map.iter() { + match other.get(*host, tag.clone()) { + // The other store is all up to date! No diff. + Some(t) if t.eq(idx) => continue, + + // The other store does exist, and it is either ahead or behind us. A diff regardless + Some(t) => ret.push(Diff { + host: *host, + tag: tag.clone(), + local: Some(*idx), + remote: Some(t), + }), + + // The other store does not exist :O + None => ret.push(Diff { + host: *host, + tag: tag.clone(), + local: Some(*idx), + remote: None, + }), + }; + } + } + + // At this point, there is a single case we have not yet considered. + // If the other store knows of a tag that we are not yet aware of, then the diff will be missed + + // account for that! 
+ for (host, tag_map) in other.hosts.iter() { + for (tag, idx) in tag_map.iter() { + match self.get(*host, tag.clone()) { + // If we have this host/tag combo, the comparison and diff will have already happened above + Some(_) => continue, + + None => ret.push(Diff { + host: *host, + tag: tag.clone(), + remote: Some(*idx), + local: None, + }), + }; + } + } + + // Stability is a nice property to have + ret.sort(); + ret + } +} + +pub trait Encryption { + fn re_encrypt( + data: EncryptedData, + ad: AdditionalData, + old_key: &[u8; 32], + new_key: &[u8; 32], + ) -> Result<EncryptedData> { + let data = Self::decrypt(data, ad, old_key)?; + Ok(Self::encrypt(data, ad, new_key)) + } + fn encrypt(data: DecryptedData, ad: AdditionalData, key: &[u8; 32]) -> EncryptedData; + fn decrypt(data: EncryptedData, ad: AdditionalData, key: &[u8; 32]) -> Result<DecryptedData>; +} + +impl Record<DecryptedData> { + pub fn encrypt<E: Encryption>(self, key: &[u8; 32]) -> Record<EncryptedData> { + let ad = AdditionalData { + id: &self.id, + version: &self.version, + tag: &self.tag, + host: &self.host.id, + idx: &self.idx, + }; + Record { + data: E::encrypt(self.data, ad, key), + id: self.id, + host: self.host, + idx: self.idx, + timestamp: self.timestamp, + version: self.version, + tag: self.tag, + } + } +} + +impl Record<EncryptedData> { + pub fn decrypt<E: Encryption>(self, key: &[u8; 32]) -> Result<Record<DecryptedData>> { + let ad = AdditionalData { + id: &self.id, + version: &self.version, + tag: &self.tag, + host: &self.host.id, + idx: &self.idx, + }; + Ok(Record { + data: E::decrypt(self.data, ad, key)?, + id: self.id, + host: self.host, + idx: self.idx, + timestamp: self.timestamp, + version: self.version, + tag: self.tag, + }) + } + + pub fn re_encrypt<E: Encryption>( + self, + old_key: &[u8; 32], + new_key: &[u8; 32], + ) -> Result<Record<EncryptedData>> { + let ad = AdditionalData { + id: &self.id, + version: &self.version, + tag: &self.tag, + host: &self.host.id, + idx: &self.idx, 
+ }; + Ok(Record { + data: E::re_encrypt(self.data, ad, old_key, new_key)?, + id: self.id, + host: self.host, + idx: self.idx, + timestamp: self.timestamp, + version: self.version, + tag: self.tag, + }) + } +} + +#[cfg(test)] +mod tests { + use crate::record::{Host, HostId}; + + use super::{DecryptedData, Diff, Record, RecordStatus}; + use pretty_assertions::assert_eq; + + fn test_record() -> Record<DecryptedData> { + Record::builder() + .host(Host::new(HostId(crate::utils::uuid_v7()))) + .version("v1".into()) + .tag(crate::utils::uuid_v7().simple().to_string()) + .data(DecryptedData(vec![0, 1, 2, 3])) + .idx(0) + .build() + } + + #[test] + fn record_index() { + let mut index = RecordStatus::new(); + let record = test_record(); + + index.set(record.clone()); + + let tail = index.get(record.host.id, record.tag); + + assert_eq!( + record.idx, + tail.expect("tail not in store"), + "tail in store did not match" + ); + } + + #[test] + fn record_index_overwrite() { + let mut index = RecordStatus::new(); + let record = test_record(); + let child = record.append(vec![1, 2, 3]); + + index.set(record.clone()); + index.set(child.clone()); + + let tail = index.get(record.host.id, record.tag); + + assert_eq!( + child.idx, + tail.expect("tail not in store"), + "tail in store did not match" + ); + } + + #[test] + fn record_index_no_diff() { + // Here, they both have the same version and should have no diff + + let mut index1 = RecordStatus::new(); + let mut index2 = RecordStatus::new(); + + let record1 = test_record(); + + index1.set(record1.clone()); + index2.set(record1); + + let diff = index1.diff(&index2); + + assert_eq!(0, diff.len(), "expected empty diff"); + } + + #[test] + fn record_index_single_diff() { + // Here, they both have the same stores, but one is ahead by a single record + + let mut index1 = RecordStatus::new(); + let mut index2 = RecordStatus::new(); + + let record1 = test_record(); + let record2 = record1.append(vec![1, 2, 3]); + + index1.set(record1); + 
index2.set(record2.clone()); + + let diff = index1.diff(&index2); + + assert_eq!(1, diff.len(), "expected single diff"); + assert_eq!( + diff[0], + Diff { + host: record2.host.id, + tag: record2.tag, + remote: Some(1), + local: Some(0) + } + ); + } + + #[test] + fn record_index_multi_diff() { + // A much more complex case, with a bunch more checks + let mut index1 = RecordStatus::new(); + let mut index2 = RecordStatus::new(); + + let store1record1 = test_record(); + let store1record2 = store1record1.append(vec![1, 2, 3]); + + let store2record1 = test_record(); + let store2record2 = store2record1.append(vec![1, 2, 3]); + + let store3record1 = test_record(); + + let store4record1 = test_record(); + + // index1 only knows about the first two entries of the first two stores + index1.set(store1record1); + index1.set(store2record1); + + // index2 is fully up to date with the first two stores, and knows of a third + index2.set(store1record2); + index2.set(store2record2); + index2.set(store3record1); + + // index1 knows of a 4th store + index1.set(store4record1); + + let diff1 = index1.diff(&index2); + let diff2 = index2.diff(&index1); + + // both diffs the same length + assert_eq!(4, diff1.len()); + assert_eq!(4, diff2.len()); + + dbg!(&diff1, &diff2); + + // both diffs should be ALMOST the same. They will agree on which hosts and tags + // require updating, but the "other" value will not be the same. 
+ let smol_diff_1: Vec<(HostId, String)> = + diff1.iter().map(|v| (v.host, v.tag.clone())).collect(); + let smol_diff_2: Vec<(HostId, String)> = + diff1.iter().map(|v| (v.host, v.tag.clone())).collect(); + + assert_eq!(smol_diff_1, smol_diff_2); + + // diffing with yourself = no diff + assert_eq!(index1.diff(&index1).len(), 0); + assert_eq!(index2.diff(&index2).len(), 0); + } +} diff --git a/crates/atuin-common/src/shell.rs b/crates/atuin-common/src/shell.rs new file mode 100644 index 00000000..42e32f72 --- /dev/null +++ b/crates/atuin-common/src/shell.rs @@ -0,0 +1,147 @@ +use std::{ffi::OsStr, path::Path, process::Command}; + +use serde::Serialize; +use sysinfo::{get_current_pid, Process, System}; +use thiserror::Error; + +pub enum Shell { + Sh, + Bash, + Fish, + Zsh, + Xonsh, + Nu, + + Unknown, +} + +impl std::fmt::Display for Shell { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let shell = match self { + Shell::Bash => "bash", + Shell::Fish => "fish", + Shell::Zsh => "zsh", + Shell::Nu => "nu", + Shell::Xonsh => "xonsh", + Shell::Sh => "sh", + + Shell::Unknown => "unknown", + }; + + write!(f, "{}", shell) + } +} + +#[derive(Debug, Error, Serialize)] +pub enum ShellError { + #[error("shell not supported")] + NotSupported, + + #[error("failed to execute shell command: {0}")] + ExecError(String), +} + +impl Shell { + pub fn current() -> Shell { + let sys = System::new_all(); + + let process = sys + .process(get_current_pid().expect("Failed to get current PID")) + .expect("Process with current pid does not exist"); + + let parent = sys + .process(process.parent().expect("Atuin running with no parent!")) + .expect("Process with parent pid does not exist"); + + let shell = parent.name().trim().to_lowercase(); + let shell = shell.strip_prefix('-').unwrap_or(&shell); + + Shell::from_string(shell.to_string()) + } + + /// Best-effort attempt to determine the default shell + /// This implementation will be different across different platforms + 
/// Caller should ensure to handle Shell::Unknown correctly + pub fn default_shell() -> Result<Shell, ShellError> { + let sys = System::name().unwrap_or("".to_string()).to_lowercase(); + + // TODO: Support Linux + // I'm pretty sure we can use /etc/passwd there, though there will probably be some issues + if sys.contains("darwin") { + // This works in my testing so far + let path = Shell::Sh.run_interactive([ + "dscl localhost -read \"/Local/Default/Users/$USER\" shell | awk '{print $2}'", + ])?; + + let path = Path::new(path.trim()); + + let shell = path.file_name(); + + if shell.is_none() { + return Err(ShellError::NotSupported); + } + + Ok(Shell::from_string( + shell.unwrap().to_string_lossy().to_string(), + )) + } else { + Err(ShellError::NotSupported) + } + } + + pub fn from_string(name: String) -> Shell { + match name.as_str() { + "bash" => Shell::Bash, + "fish" => Shell::Fish, + "zsh" => Shell::Zsh, + "xonsh" => Shell::Xonsh, + "nu" => Shell::Nu, + "sh" => Shell::Sh, + + _ => Shell::Unknown, + } + } + + /// Returns true if the shell is posix-like + /// Note that while fish is not posix compliant, it behaves well enough for our current + /// featureset that this does not matter. 
+ pub fn is_posixish(&self) -> bool { + matches!(self, Shell::Bash | Shell::Fish | Shell::Zsh) + } + + pub fn run_interactive<I, S>(&self, args: I) -> Result<String, ShellError> + where + I: IntoIterator<Item = S>, + S: AsRef<OsStr>, + { + let shell = self.to_string(); + + let output = Command::new(shell) + .arg("-ic") + .args(args) + .output() + .map_err(|e| ShellError::ExecError(e.to_string()))?; + + Ok(String::from_utf8(output.stdout).unwrap()) + } +} + +pub fn shell_name(parent: Option<&Process>) -> String { + let sys = System::new_all(); + + let parent = if let Some(parent) = parent { + parent + } else { + let process = sys + .process(get_current_pid().expect("Failed to get current PID")) + .expect("Process with current pid does not exist"); + + sys.process(process.parent().expect("Atuin running with no parent!")) + .expect("Process with parent pid does not exist") + }; + + let shell = parent.name().trim().to_lowercase(); + let shell = shell.strip_prefix('-').unwrap_or(&shell); + + shell.to_string() +} diff --git a/crates/atuin-common/src/utils.rs b/crates/atuin-common/src/utils.rs new file mode 100644 index 00000000..7c533663 --- /dev/null +++ b/crates/atuin-common/src/utils.rs @@ -0,0 +1,265 @@ +use std::borrow::Cow; +use std::env; +use std::path::PathBuf; + +use rand::RngCore; +use uuid::Uuid; + +pub fn random_bytes<const N: usize>() -> [u8; N] { + let mut ret = [0u8; N]; + + rand::thread_rng().fill_bytes(&mut ret); + + ret +} + +pub fn uuid_v7() -> Uuid { + Uuid::now_v7() +} + +pub fn uuid_v4() -> String { + Uuid::new_v4().as_simple().to_string() +} + +pub fn has_git_dir(path: &str) -> bool { + let mut gitdir = PathBuf::from(path); + gitdir.push(".git"); + + gitdir.exists() +} + +// detect if any parent dir has a git repo in it +// I really don't want to bring in libgit for something simple like this +// If we start to do anything more advanced, then perhaps +pub fn in_git_repo(path: &str) -> Option<PathBuf> { + let mut gitdir = PathBuf::from(path); + + 
while gitdir.parent().is_some() && !has_git_dir(gitdir.to_str().unwrap()) { + gitdir.pop(); + } + + // No parent? then we hit root, finding no git + if gitdir.parent().is_some() { + return Some(gitdir); + } + + None +} + +// TODO: more reliable, more tested +// I don't want to use ProjectDirs, it puts config in awkward places on +// mac. Data too. Seems to be more intended for GUI apps. + +#[cfg(not(target_os = "windows"))] +pub fn home_dir() -> PathBuf { + let home = std::env::var("HOME").expect("$HOME not found"); + PathBuf::from(home) +} + +#[cfg(target_os = "windows")] +pub fn home_dir() -> PathBuf { + let home = std::env::var("USERPROFILE").expect("%userprofile% not found"); + PathBuf::from(home) +} + +pub fn config_dir() -> PathBuf { + let config_dir = + std::env::var("XDG_CONFIG_HOME").map_or_else(|_| home_dir().join(".config"), PathBuf::from); + config_dir.join("atuin") +} + +pub fn data_dir() -> PathBuf { + let data_dir = std::env::var("XDG_DATA_HOME") + .map_or_else(|_| home_dir().join(".local").join("share"), PathBuf::from); + + data_dir.join("atuin") +} + +pub fn dotfiles_cache_dir() -> PathBuf { + // In most cases, this will be ~/.local/share/atuin/dotfiles/cache + let data_dir = std::env::var("XDG_DATA_HOME") + .map_or_else(|_| home_dir().join(".local").join("share"), PathBuf::from); + + data_dir.join("atuin").join("dotfiles").join("cache") +} + +pub fn get_current_dir() -> String { + // Prefer PWD environment variable over cwd if available to better support symbolic links + match env::var("PWD") { + Ok(v) => v, + Err(_) => match env::current_dir() { + Ok(dir) => dir.display().to_string(), + Err(_) => String::from(""), + }, + } +} + +pub fn is_zsh() -> bool { + // only set on zsh + env::var("ATUIN_SHELL_ZSH").is_ok() +} + +pub fn is_fish() -> bool { + // only set on fish + env::var("ATUIN_SHELL_FISH").is_ok() +} + +pub fn is_bash() -> bool { + // only set on bash + env::var("ATUIN_SHELL_BASH").is_ok() +} + +pub fn is_xonsh() -> bool { + // only set on 
xonsh + env::var("ATUIN_SHELL_XONSH").is_ok() +} + +/// Extension trait for anything that can behave like a string to make it easy to escape control +/// characters. +/// +/// Intended to help prevent control characters being printed and interpreted by the terminal when +/// printing history as well as to ensure the commands that appear in the interactive search +/// reflect the actual command run rather than just the printable characters. +pub trait Escapable: AsRef<str> { + fn escape_control(&self) -> Cow<str> { + if !self.as_ref().contains(|c: char| c.is_ascii_control()) { + self.as_ref().into() + } else { + let mut remaining = self.as_ref(); + // Not a perfect way to reserve space but should reduce the allocations + let mut buf = String::with_capacity(remaining.as_bytes().len()); + while let Some(i) = remaining.find(|c: char| c.is_ascii_control()) { + // safe to index with `..i`, `i` and `i+1..` as part[i] is a single byte ascii char + buf.push_str(&remaining[..i]); + buf.push('^'); + buf.push(match remaining.as_bytes()[i] { + 0x7F => '?', + code => char::from_u32(u32::from(code) + 64).unwrap(), + }); + remaining = &remaining[i + 1..]; + } + buf.push_str(remaining); + buf.into() + } + } +} + +impl<T: AsRef<str>> Escapable for T {} + +#[cfg(test)] +mod tests { + use time::Month; + + use super::*; + use std::env; + + use std::collections::HashSet; + + #[cfg(not(windows))] + #[test] + fn test_dirs() { + // these tests need to be run sequentially to prevent race condition + test_config_dir_xdg(); + test_config_dir(); + test_data_dir_xdg(); + test_data_dir(); + } + + fn test_config_dir_xdg() { + env::remove_var("HOME"); + env::set_var("XDG_CONFIG_HOME", "/home/user/custom_config"); + assert_eq!( + config_dir(), + PathBuf::from("/home/user/custom_config/atuin") + ); + env::remove_var("XDG_CONFIG_HOME"); + } + + fn test_config_dir() { + env::set_var("HOME", "/home/user"); + env::remove_var("XDG_CONFIG_HOME"); + + assert_eq!(config_dir(), 
PathBuf::from("/home/user/.config/atuin")); + + env::remove_var("HOME"); + } + + fn test_data_dir_xdg() { + env::remove_var("HOME"); + env::set_var("XDG_DATA_HOME", "/home/user/custom_data"); + assert_eq!(data_dir(), PathBuf::from("/home/user/custom_data/atuin")); + env::remove_var("XDG_DATA_HOME"); + } + + fn test_data_dir() { + env::set_var("HOME", "/home/user"); + env::remove_var("XDG_DATA_HOME"); + assert_eq!(data_dir(), PathBuf::from("/home/user/.local/share/atuin")); + env::remove_var("HOME"); + } + + #[test] + fn days_from_month() { + assert_eq!(time::util::days_in_year_month(2023, Month::January), 31); + assert_eq!(time::util::days_in_year_month(2023, Month::February), 28); + assert_eq!(time::util::days_in_year_month(2023, Month::March), 31); + assert_eq!(time::util::days_in_year_month(2023, Month::April), 30); + assert_eq!(time::util::days_in_year_month(2023, Month::May), 31); + assert_eq!(time::util::days_in_year_month(2023, Month::June), 30); + assert_eq!(time::util::days_in_year_month(2023, Month::July), 31); + assert_eq!(time::util::days_in_year_month(2023, Month::August), 31); + assert_eq!(time::util::days_in_year_month(2023, Month::September), 30); + assert_eq!(time::util::days_in_year_month(2023, Month::October), 31); + assert_eq!(time::util::days_in_year_month(2023, Month::November), 30); + assert_eq!(time::util::days_in_year_month(2023, Month::December), 31); + + // leap years + assert_eq!(time::util::days_in_year_month(2024, Month::February), 29); + } + + #[test] + fn uuid_is_unique() { + let how_many: usize = 1000000; + + // for peace of mind + let mut uuids: HashSet<Uuid> = HashSet::with_capacity(how_many); + + // there will be many in the same millisecond + for _ in 0..how_many { + let uuid = uuid_v7(); + uuids.insert(uuid); + } + + assert_eq!(uuids.len(), how_many); + } + + #[test] + fn escape_control_characters() { + use super::Escapable; + // CSI colour sequence + assert_eq!("\x1b[31mfoo".escape_control(), "^[[31mfoo"); + + // Tabs count as 
control chars + assert_eq!("foo\tbar".escape_control(), "foo^Ibar"); + + // space is in control char range but should be excluded + assert_eq!("two words".escape_control(), "two words"); + + // unicode multi-byte characters + let s = "🐢\x1b[32m🦀"; + assert_eq!(s.escape_control(), s.replace("\x1b", "^[")); + } + + #[test] + fn escape_no_control_characters() { + use super::Escapable as _; + assert!(matches!( + "no control characters".escape_control(), + Cow::Borrowed(_) + )); + assert!(matches!( + "with \x1b[31mcontrol\x1b[0m characters".escape_control(), + Cow::Owned(_) + )); + } +} diff --git a/crates/atuin-dotfiles/Cargo.toml b/crates/atuin-dotfiles/Cargo.toml new file mode 100644 index 00000000..1bd16223 --- /dev/null +++ b/crates/atuin-dotfiles/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "atuin-dotfiles" +description = "The dotfiles crate for Atuin" +edition = "2021" +version = "0.2.0" # intentionally not the same as the rest + +authors.workspace = true +rust-version.workspace = true +license.workspace = true +homepage.workspace = true +repository.workspace = true +readme.workspace = true + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +atuin-common = { path = "../atuin-common", version = "18.2.0" } +atuin-client = { path = "../atuin-client", version = "18.2.0" } + +eyre = { workspace = true } +tokio = { workspace = true } +rmp = { version = "0.8.11" } +rand = { workspace = true } +serde = { workspace = true } +crypto_secretbox = "0.1.1" diff --git a/crates/atuin-dotfiles/src/lib.rs b/crates/atuin-dotfiles/src/lib.rs new file mode 100644 index 00000000..74daf8ef --- /dev/null +++ b/crates/atuin-dotfiles/src/lib.rs @@ -0,0 +1,2 @@ +pub mod shell; +pub mod store; diff --git a/crates/atuin-dotfiles/src/shell.rs b/crates/atuin-dotfiles/src/shell.rs new file mode 100644 index 00000000..7912bc34 --- /dev/null +++ b/crates/atuin-dotfiles/src/shell.rs @@ -0,0 +1,100 @@ +use eyre::Result; +use 
serde::Serialize; + +use atuin_common::shell::{Shell, ShellError}; + +use crate::store::AliasStore; + +pub mod bash; +pub mod fish; +pub mod xonsh; +pub mod zsh; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct Alias { + pub name: String, + pub value: String, +} + +pub fn parse_alias(line: &str) -> Alias { + let mut parts = line.split('='); + + let name = parts.next().unwrap().to_string(); + let remaining = parts.collect::<Vec<&str>>().join("=").to_string(); + + Alias { + name, + value: remaining, + } +} + +pub fn existing_aliases(shell: Option<Shell>) -> Result<Vec<Alias>, ShellError> { + let shell = if let Some(shell) = shell { + shell + } else { + Shell::current() + }; + + // this only supports posix-y shells atm + if !shell.is_posixish() { + return Err(ShellError::NotSupported); + } + + // This will return a list of aliases, each on its own line + // They will be in the form foo=bar + let aliases = shell.run_interactive(["alias"])?; + let aliases: Vec<Alias> = aliases.lines().map(parse_alias).collect(); + + Ok(aliases) +} + +/// Import aliases from the current shell +/// This will not import aliases already in the store +/// Returns aliases that were set +pub async fn import_aliases(store: AliasStore) -> Result<Vec<Alias>> { + let shell_aliases = existing_aliases(None)?; + let store_aliases = store.aliases().await?; + + let mut res = Vec::new(); + + for alias in shell_aliases { + // O(n), but n is small, and imports infrequent + // can always make a map + if store_aliases.contains(&alias) { + continue; + } + + res.push(alias.clone()); + store.set(&alias.name, &alias.value).await?; + } + + Ok(res) +} + +#[cfg(test)] +mod tests { + #[test] + fn test_parse_simple_alias() { + let alias = super::parse_alias("foo=bar"); + assert_eq!(alias.name, "foo"); + assert_eq!(alias.value, "bar"); + } + + #[test] + fn test_parse_quoted_alias() { + let alias = super::parse_alias("emacs='TERM=xterm-24bits emacs -nw'"); + assert_eq!(alias.name, "emacs"); + 
assert_eq!(alias.value, "'TERM=xterm-24bits emacs -nw'"); + + let git_alias = super::parse_alias("gwip='git add -A; git rm $(git ls-files --deleted) 2> /dev/null; git commit --no-verify --no-gpg-sign --message \"--wip-- [skip ci]\"'"); + assert_eq!(git_alias.name, "gwip"); + assert_eq!(git_alias.value, "'git add -A; git rm $(git ls-files --deleted) 2> /dev/null; git commit --no-verify --no-gpg-sign --message \"--wip-- [skip ci]\"'"); + } + + #[test] + fn test_parse_quoted_alias_equals() { + let alias = super::parse_alias("emacs='TERM=xterm-24bits emacs -nw --foo=bar'"); + assert_eq!(alias.name, "emacs"); + assert_eq!(alias.value, "'TERM=xterm-24bits emacs -nw --foo=bar'"); + } +} diff --git a/crates/atuin-dotfiles/src/shell/bash.rs b/crates/atuin-dotfiles/src/shell/bash.rs new file mode 100644 index 00000000..5bdd7dce --- /dev/null +++ b/crates/atuin-dotfiles/src/shell/bash.rs @@ -0,0 +1,39 @@ +use std::path::PathBuf; + +use crate::store::AliasStore; + +async fn cached_aliases(path: PathBuf, store: &AliasStore) -> String { + match tokio::fs::read_to_string(path).await { + Ok(aliases) => aliases, + Err(r) => { + // we failed to read the file for some reason, but the file does exist + // fallback to generating new aliases on the fly + + store.posix().await.unwrap_or_else(|e| { + format!("echo 'Atuin: failed to read and generate aliases: \n{r}\n{e}'",) + }) + } + } +} + +/// Return bash dotfile config +/// +/// Do not return an error. We should not prevent the shell from starting. +/// +/// In the worst case, Atuin should not function but the shell should start correctly. 
+/// +/// While currently this only returns aliases, it will be extended to also return other synced dotfiles +pub async fn config(store: &AliasStore) -> String { + // First try to read the cached config + let aliases = atuin_common::utils::dotfiles_cache_dir().join("aliases.bash"); + + if aliases.exists() { + return cached_aliases(aliases, store).await; + } + + if let Err(e) = store.build().await { + return format!("echo 'Atuin: failed to generate aliases: {}'", e); + } + + cached_aliases(aliases, store).await +} diff --git a/crates/atuin-dotfiles/src/shell/fish.rs b/crates/atuin-dotfiles/src/shell/fish.rs new file mode 100644 index 00000000..bf4e1a3b --- /dev/null +++ b/crates/atuin-dotfiles/src/shell/fish.rs @@ -0,0 +1,40 @@ +// Configuration for fish +use std::path::PathBuf; + +use crate::store::AliasStore; + +async fn cached_aliases(path: PathBuf, store: &AliasStore) -> String { + match tokio::fs::read_to_string(path).await { + Ok(aliases) => aliases, + Err(r) => { + // we failed to read the file for some reason, but the file does exist + // fallback to generating new aliases on the fly + + store.posix().await.unwrap_or_else(|e| { + format!("echo 'Atuin: failed to read and generate aliases: \n{r}\n{e}'",) + }) + } + } +} + +/// Return fish dotfile config +/// +/// Do not return an error. We should not prevent the shell from starting. +/// +/// In the worst case, Atuin should not function but the shell should start correctly. 
+/// +/// While currently this only returns aliases, it will be extended to also return other synced dotfiles +pub async fn config(store: &AliasStore) -> String { + // First try to read the cached config + let aliases = atuin_common::utils::dotfiles_cache_dir().join("aliases.fish"); + + if aliases.exists() { + return cached_aliases(aliases, store).await; + } + + if let Err(e) = store.build().await { + return format!("echo 'Atuin: failed to generate aliases: {}'", e); + } + + cached_aliases(aliases, store).await +} diff --git a/crates/atuin-dotfiles/src/shell/xonsh.rs b/crates/atuin-dotfiles/src/shell/xonsh.rs new file mode 100644 index 00000000..383df4ec --- /dev/null +++ b/crates/atuin-dotfiles/src/shell/xonsh.rs @@ -0,0 +1,39 @@ +use std::path::PathBuf; + +use crate::store::AliasStore; + +async fn cached_aliases(path: PathBuf, store: &AliasStore) -> String { + match tokio::fs::read_to_string(path).await { + Ok(aliases) => aliases, + Err(r) => { + // we failed to read the file for some reason, but the file does exist + // fallback to generating new aliases on the fly + + store.xonsh().await.unwrap_or_else(|e| { + format!("echo 'Atuin: failed to read and generate aliases: \n{r}\n{e}'",) + }) + } + } +} + +/// Return xonsh dotfile config +/// +/// Do not return an error. We should not prevent the shell from starting. +/// +/// In the worst case, Atuin should not function but the shell should start correctly. 
+/// +/// While currently this only returns aliases, it will be extended to also return other synced dotfiles +pub async fn config(store: &AliasStore) -> String { + // First try to read the cached config + let aliases = atuin_common::utils::dotfiles_cache_dir().join("aliases.xsh"); + + if aliases.exists() { + return cached_aliases(aliases, store).await; + } + + if let Err(e) = store.build().await { + return format!("echo 'Atuin: failed to generate aliases: {}'", e); + } + + cached_aliases(aliases, store).await +} diff --git a/crates/atuin-dotfiles/src/shell/zsh.rs b/crates/atuin-dotfiles/src/shell/zsh.rs new file mode 100644 index 00000000..d863b261 --- /dev/null +++ b/crates/atuin-dotfiles/src/shell/zsh.rs @@ -0,0 +1,39 @@ +use std::path::PathBuf; + +use crate::store::AliasStore; + +async fn cached_aliases(path: PathBuf, store: &AliasStore) -> String { + match tokio::fs::read_to_string(path).await { + Ok(aliases) => aliases, + Err(r) => { + // we failed to read the file for some reason, but the file does exist + // fallback to generating new aliases on the fly + + store.posix().await.unwrap_or_else(|e| { + format!("echo 'Atuin: failed to read and generate aliases: \n{r}\n{e}'",) + }) + } + } +} + +/// Return zsh dotfile config +/// +/// Do not return an error. We should not prevent the shell from starting. +/// +/// In the worst case, Atuin should not function but the shell should start correctly. 
+/// +/// While currently this only returns aliases, it will be extended to also return other synced dotfiles +pub async fn config(store: &AliasStore) -> String { + // First try to read the cached config + let aliases = atuin_common::utils::dotfiles_cache_dir().join("aliases.zsh"); + + if aliases.exists() { + return cached_aliases(aliases, store).await; + } + + if let Err(e) = store.build().await { + return format!("echo 'Atuin: failed to generate aliases: {}'", e); + } + + cached_aliases(aliases, store).await +} diff --git a/crates/atuin-dotfiles/src/store.rs b/crates/atuin-dotfiles/src/store.rs new file mode 100644 index 00000000..425a5e1e --- /dev/null +++ b/crates/atuin-dotfiles/src/store.rs @@ -0,0 +1,364 @@ +use std::collections::BTreeMap; + +use atuin_client::record::sqlite_store::SqliteStore; +// Sync aliases +// This will be noticeable similar to the kv store, though I expect the two shall diverge +// While we will support a range of shell config, I'd rather have a larger number of small records +// + stores, rather than one mega config store. +use atuin_common::record::{DecryptedData, Host, HostId}; +use eyre::{bail, ensure, eyre, Result}; + +use atuin_client::record::encryption::PASETO_V4; +use atuin_client::record::store::Store; + +use crate::shell::Alias; + +const CONFIG_SHELL_ALIAS_VERSION: &str = "v0"; +const CONFIG_SHELL_ALIAS_TAG: &str = "config-shell-alias"; +const CONFIG_SHELL_ALIAS_FIELD_MAX_LEN: usize = 20000; // 20kb max total len, way more than should be needed. 
+ +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AliasRecord { + Create(Alias), // create a full record + Delete(String), // delete by name +} + +impl AliasRecord { + pub fn serialize(&self) -> Result<DecryptedData> { + use rmp::encode; + + let mut output = vec![]; + + match self { + AliasRecord::Create(alias) => { + encode::write_u8(&mut output, 0)?; // create + encode::write_array_len(&mut output, 2)?; // 2 fields + + encode::write_str(&mut output, alias.name.as_str())?; + encode::write_str(&mut output, alias.value.as_str())?; + } + AliasRecord::Delete(name) => { + encode::write_u8(&mut output, 1)?; // delete + encode::write_array_len(&mut output, 1)?; // 1 field + + encode::write_str(&mut output, name.as_str())?; + } + } + + Ok(DecryptedData(output)) + } + + pub fn deserialize(data: &DecryptedData, version: &str) -> Result<Self> { + use rmp::decode; + + fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report { + eyre!("{err:?}") + } + + match version { + CONFIG_SHELL_ALIAS_VERSION => { + let mut bytes = decode::Bytes::new(&data.0); + + let record_type = decode::read_u8(&mut bytes).map_err(error_report)?; + + match record_type { + // create + 0 => { + let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; + ensure!( + nfields == 2, + "too many entries in v0 shell alias create record" + ); + + let bytes = bytes.remaining_slice(); + + let (key, bytes) = + decode::read_str_from_slice(bytes).map_err(error_report)?; + let (value, bytes) = + decode::read_str_from_slice(bytes).map_err(error_report)?; + + if !bytes.is_empty() { + bail!("trailing bytes in encoded shell alias record. 
malformed") + } + + Ok(AliasRecord::Create(Alias { + name: key.to_owned(), + value: value.to_owned(), + })) + } + + // delete + 1 => { + let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?; + ensure!( + nfields == 1, + "too many entries in v0 shell alias delete record" + ); + + let bytes = bytes.remaining_slice(); + + let (key, bytes) = + decode::read_str_from_slice(bytes).map_err(error_report)?; + + if !bytes.is_empty() { + bail!("trailing bytes in encoded shell alias record. malformed") + } + + Ok(AliasRecord::Delete(key.to_owned())) + } + + n => { + bail!("unknown AliasRecord type {n}") + } + } + } + _ => { + bail!("unknown version {version:?}") + } + } + } +} + +#[derive(Debug, Clone)] +pub struct AliasStore { + pub store: SqliteStore, + pub host_id: HostId, + pub encryption_key: [u8; 32], +} + +impl AliasStore { + // will want to init the actual kv store when that is done + pub fn new(store: SqliteStore, host_id: HostId, encryption_key: [u8; 32]) -> AliasStore { + AliasStore { + store, + host_id, + encryption_key, + } + } + + pub async fn posix(&self) -> Result<String> { + let aliases = self.aliases().await?; + + let mut config = String::new(); + + for alias in aliases { + config.push_str(&format!("alias {}='{}'\n", alias.name, alias.value)); + } + + Ok(config) + } + + pub async fn xonsh(&self) -> Result<String> { + let aliases = self.aliases().await?; + + let mut config = String::new(); + + for alias in aliases { + config.push_str(&format!("aliases['{}'] ='{}'\n", alias.name, alias.value)); + } + + Ok(config) + } + + pub async fn build(&self) -> Result<()> { + let dir = atuin_common::utils::dotfiles_cache_dir(); + tokio::fs::create_dir_all(dir.clone()).await?; + + // Build for all supported shells + let posix = self.posix().await?; + let xonsh = self.xonsh().await?; + + // All the same contents, maybe optimize in the future or perhaps there will be quirks + // per-shell + // I'd prefer separation atm + let zsh = dir.join("aliases.zsh"); + 
let bash = dir.join("aliases.bash"); + let fish = dir.join("aliases.fish"); + let xsh = dir.join("aliases.xsh"); + + tokio::fs::write(zsh, &posix).await?; + tokio::fs::write(bash, &posix).await?; + tokio::fs::write(fish, &posix).await?; + tokio::fs::write(xsh, &xonsh).await?; + + Ok(()) + } + + pub async fn set(&self, name: &str, value: &str) -> Result<()> { + if name.len() + value.len() > CONFIG_SHELL_ALIAS_FIELD_MAX_LEN { + return Err(eyre!( + "alias record too large: max len {} bytes", + CONFIG_SHELL_ALIAS_FIELD_MAX_LEN + )); + } + + let record = AliasRecord::Create(Alias { + name: name.to_string(), + value: value.to_string(), + }); + + let bytes = record.serialize()?; + + let idx = self + .store + .last(self.host_id, CONFIG_SHELL_ALIAS_TAG) + .await? + .map_or(0, |entry| entry.idx + 1); + + let record = atuin_common::record::Record::builder() + .host(Host::new(self.host_id)) + .version(CONFIG_SHELL_ALIAS_VERSION.to_string()) + .tag(CONFIG_SHELL_ALIAS_TAG.to_string()) + .idx(idx) + .data(bytes) + .build(); + + self.store + .push(&record.encrypt::<PASETO_V4>(&self.encryption_key)) + .await?; + + // set mutates shell config, so build again + self.build().await?; + + Ok(()) + } + + pub async fn delete(&self, name: &str) -> Result<()> { + if name.len() > CONFIG_SHELL_ALIAS_FIELD_MAX_LEN { + return Err(eyre!( + "alias record too large: max len {} bytes", + CONFIG_SHELL_ALIAS_FIELD_MAX_LEN + )); + } + + let record = AliasRecord::Delete(name.to_string()); + + let bytes = record.serialize()?; + + let idx = self + .store + .last(self.host_id, CONFIG_SHELL_ALIAS_TAG) + .await? 
+ .map_or(0, |entry| entry.idx + 1); + + let record = atuin_common::record::Record::builder() + .host(Host::new(self.host_id)) + .version(CONFIG_SHELL_ALIAS_VERSION.to_string()) + .tag(CONFIG_SHELL_ALIAS_TAG.to_string()) + .idx(idx) + .data(bytes) + .build(); + + self.store + .push(&record.encrypt::<PASETO_V4>(&self.encryption_key)) + .await?; + + // delete mutates shell config, so build again + self.build().await?; + + Ok(()) + } + + pub async fn aliases(&self) -> Result<Vec<Alias>> { + let mut build = BTreeMap::new(); + + // this is sorted, oldest to newest + let tagged = self.store.all_tagged(CONFIG_SHELL_ALIAS_TAG).await?; + + for record in tagged { + let version = record.version.clone(); + + let decrypted = match version.as_str() { + CONFIG_SHELL_ALIAS_VERSION => record.decrypt::<PASETO_V4>(&self.encryption_key)?, + version => bail!("unknown version {version:?}"), + }; + + let ar = AliasRecord::deserialize(&decrypted.data, version.as_str())?; + + match ar { + AliasRecord::Create(a) => { + build.insert(a.name.clone(), a); + } + AliasRecord::Delete(d) => { + build.remove(&d); + } + } + } + + Ok(build.into_values().collect()) + } +} + +#[cfg(test)] +pub(crate) fn test_sqlite_store_timeout() -> f64 { + std::env::var("ATUIN_TEST_SQLITE_STORE_TIMEOUT") + .ok() + .and_then(|x| x.parse().ok()) + .unwrap_or(0.1) +} + +#[cfg(test)] +mod tests { + use rand::rngs::OsRng; + + use atuin_client::record::sqlite_store::SqliteStore; + + use crate::shell::Alias; + + use super::{test_sqlite_store_timeout, AliasRecord, AliasStore, CONFIG_SHELL_ALIAS_VERSION}; + use crypto_secretbox::{KeyInit, XSalsa20Poly1305}; + + #[test] + fn encode_decode() { + let record = Alias { + name: "k".to_owned(), + value: "kubectl".to_owned(), + }; + let record = AliasRecord::Create(record); + + let snapshot = [204, 0, 146, 161, 107, 167, 107, 117, 98, 101, 99, 116, 108]; + + let encoded = record.serialize().unwrap(); + let decoded = AliasRecord::deserialize(&encoded, 
CONFIG_SHELL_ALIAS_VERSION).unwrap(); + + assert_eq!(encoded.0, &snapshot); + assert_eq!(decoded, record); + } + + #[tokio::test] + async fn build_aliases() { + let store = SqliteStore::new(":memory:", test_sqlite_store_timeout()) + .await + .unwrap(); + let key: [u8; 32] = XSalsa20Poly1305::generate_key(&mut OsRng).into(); + let host_id = atuin_common::record::HostId(atuin_common::utils::uuid_v7()); + + let alias = AliasStore::new(store, host_id, key); + + alias.set("k", "kubectl").await.unwrap(); + + alias.set("gp", "git push").await.unwrap(); + + let mut aliases = alias.aliases().await.unwrap(); + + aliases.sort_by_key(|a| a.name.clone()); + + assert_eq!(aliases.len(), 2); + + assert_eq!( + aliases[0], + Alias { + name: String::from("gp"), + value: String::from("git push") + } + ); + + assert_eq!( + aliases[1], + Alias { + name: String::from("k"), + value: String::from("kubectl") + } + ); + } +} diff --git a/crates/atuin-server-database/Cargo.toml b/crates/atuin-server-database/Cargo.toml new file mode 100644 index 00000000..ffd29b8d --- /dev/null +++ b/crates/atuin-server-database/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "atuin-server-database" +edition = "2021" +description = "server database library for atuin" + +version = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } + +[dependencies] +atuin-common = { path = "../atuin-common", version = "18.2.0" } + +tracing = "0.1" +time = { workspace = true } +eyre = { workspace = true } +uuid = { workspace = true } +serde = { workspace = true } +async-trait = { workspace = true } diff --git a/crates/atuin-server-database/src/calendar.rs b/crates/atuin-server-database/src/calendar.rs new file mode 100644 index 00000000..2229667b --- /dev/null +++ b/crates/atuin-server-database/src/calendar.rs @@ -0,0 +1,18 @@ +// Calendar data + +use serde::{Deserialize, Serialize}; +use time::Month; + +pub enum TimePeriod { + 
Year, + Month { year: i32 }, + Day { year: i32, month: Month }, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TimePeriodInfo { + pub count: u64, + + // TODO: Use this for merkle tree magic + pub hash: String, +} diff --git a/crates/atuin-server-database/src/lib.rs b/crates/atuin-server-database/src/lib.rs new file mode 100644 index 00000000..d2c16b3d --- /dev/null +++ b/crates/atuin-server-database/src/lib.rs @@ -0,0 +1,173 @@ +#![forbid(unsafe_code)] + +pub mod calendar; +pub mod models; + +use std::{ + collections::HashMap, + fmt::{Debug, Display}, + ops::Range, +}; + +use self::{ + calendar::{TimePeriod, TimePeriodInfo}, + models::{History, NewHistory, NewSession, NewUser, Session, User}, +}; +use async_trait::async_trait; +use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}; +use serde::{de::DeserializeOwned, Serialize}; +use time::{Date, Duration, Month, OffsetDateTime, Time, UtcOffset}; +use tracing::instrument; + +#[derive(Debug)] +pub enum DbError { + NotFound, + Other(eyre::Report), +} + +impl Display for DbError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{self:?}") + } +} + +impl<T: std::error::Error + Into<time::error::Error>> From<T> for DbError { + fn from(value: T) -> Self { + DbError::Other(value.into().into()) + } +} + +impl std::error::Error for DbError {} + +pub type DbResult<T> = Result<T, DbError>; + +#[async_trait] +pub trait Database: Sized + Clone + Send + Sync + 'static { + type Settings: Debug + Clone + DeserializeOwned + Serialize + Send + Sync + 'static; + async fn new(settings: &Self::Settings) -> DbResult<Self>; + + async fn get_session(&self, token: &str) -> DbResult<Session>; + async fn get_session_user(&self, token: &str) -> DbResult<User>; + async fn add_session(&self, session: &NewSession) -> DbResult<()>; + + async fn get_user(&self, username: &str) -> DbResult<User>; + async fn get_user_session(&self, u: &User) -> DbResult<Session>; + async fn 
add_user(&self, user: &NewUser) -> DbResult<i64>; + async fn update_user_password(&self, u: &User) -> DbResult<()>; + + async fn total_history(&self) -> DbResult<i64>; + async fn count_history(&self, user: &User) -> DbResult<i64>; + async fn count_history_cached(&self, user: &User) -> DbResult<i64>; + + async fn delete_user(&self, u: &User) -> DbResult<()>; + async fn delete_history(&self, user: &User, id: String) -> DbResult<()>; + async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>>; + async fn delete_store(&self, user: &User) -> DbResult<()>; + + async fn add_records(&self, user: &User, record: &[Record<EncryptedData>]) -> DbResult<()>; + async fn next_records( + &self, + user: &User, + host: HostId, + tag: String, + start: Option<RecordIdx>, + count: u64, + ) -> DbResult<Vec<Record<EncryptedData>>>; + + // Return the tail record ID for each store, so (HostID, Tag, TailRecordID) + async fn status(&self, user: &User) -> DbResult<RecordStatus>; + + async fn count_history_range(&self, user: &User, range: Range<OffsetDateTime>) + -> DbResult<i64>; + + async fn list_history( + &self, + user: &User, + created_after: OffsetDateTime, + since: OffsetDateTime, + host: &str, + page_size: i64, + ) -> DbResult<Vec<History>>; + + async fn add_history(&self, history: &[NewHistory]) -> DbResult<()>; + + async fn oldest_history(&self, user: &User) -> DbResult<History>; + + #[instrument(skip_all)] + async fn calendar( + &self, + user: &User, + period: TimePeriod, + tz: UtcOffset, + ) -> DbResult<HashMap<u64, TimePeriodInfo>> { + let mut ret = HashMap::new(); + let iter: Box<dyn Iterator<Item = DbResult<(u64, Range<Date>)>> + Send> = match period { + TimePeriod::Year => { + // First we need to work out how far back to calculate. Get the + // oldest history item + let oldest = self + .oldest_history(user) + .await? 
+ .timestamp + .to_offset(tz) + .year(); + let current_year = OffsetDateTime::now_utc().to_offset(tz).year(); + + // All the years we need to get data for + // The upper bound is exclusive, so include current +1 + let years = oldest..current_year + 1; + + Box::new(years.map(|year| { + let start = Date::from_calendar_date(year, time::Month::January, 1)?; + let end = Date::from_calendar_date(year + 1, time::Month::January, 1)?; + + Ok((year as u64, start..end)) + })) + } + + TimePeriod::Month { year } => { + let months = + std::iter::successors(Some(Month::January), |m| Some(m.next())).take(12); + + Box::new(months.map(move |month| { + let start = Date::from_calendar_date(year, month, 1)?; + let days = time::util::days_in_year_month(year, month); + let end = start + Duration::days(days as i64); + + Ok((month as u64, start..end)) + })) + } + + TimePeriod::Day { year, month } => { + let days = 1..time::util::days_in_year_month(year, month); + Box::new(days.map(move |day| { + let start = Date::from_calendar_date(year, month, day)?; + let end = start + .next_day() + .ok_or_else(|| DbError::Other(eyre::eyre!("no next day?")))?; + + Ok((day as u64, start..end)) + })) + } + }; + + for x in iter { + let (index, range) = x?; + + let start = range.start.with_time(Time::MIDNIGHT).assume_offset(tz); + let end = range.end.with_time(Time::MIDNIGHT).assume_offset(tz); + + let count = self.count_history_range(user, start..end).await?; + + ret.insert( + index, + TimePeriodInfo { + count: count as u64, + hash: "".to_string(), + }, + ); + } + + Ok(ret) + } +} diff --git a/crates/atuin-server-database/src/models.rs b/crates/atuin-server-database/src/models.rs new file mode 100644 index 00000000..b71a9bc9 --- /dev/null +++ b/crates/atuin-server-database/src/models.rs @@ -0,0 +1,52 @@ +use time::OffsetDateTime; + +pub struct History { + pub id: i64, + pub client_id: String, // a client generated ID + pub user_id: i64, + pub hostname: String, + pub timestamp: OffsetDateTime, + + /// All 
the data we have about this command, encrypted. + /// + /// Currently this is an encrypted msgpack object, but this may change in the future. + pub data: String, + + pub created_at: OffsetDateTime, +} + +pub struct NewHistory { + pub client_id: String, + pub user_id: i64, + pub hostname: String, + pub timestamp: OffsetDateTime, + + /// All the data we have about this command, encrypted. + /// + /// Currently this is an encrypted msgpack object, but this may change in the future. + pub data: String, +} + +pub struct User { + pub id: i64, + pub username: String, + pub email: String, + pub password: String, +} + +pub struct Session { + pub id: i64, + pub user_id: i64, + pub token: String, +} + +pub struct NewUser { + pub username: String, + pub email: String, + pub password: String, +} + +pub struct NewSession { + pub user_id: i64, + pub token: String, +} diff --git a/crates/atuin-server-postgres/Cargo.toml b/crates/atuin-server-postgres/Cargo.toml new file mode 100644 index 00000000..647d934a --- /dev/null +++ b/crates/atuin-server-postgres/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "atuin-server-postgres" +edition = "2021" +description = "server postgres database library for atuin" + +version = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } + +[dependencies] +atuin-common = { path = "../atuin-common", version = "18.2.0" } +atuin-server-database = { path = "../atuin-server-database", version = "18.2.0" } + +eyre = { workspace = true } +tracing = "0.1" +time = { workspace = true } +serde = { workspace = true } +sqlx = { workspace = true } +async-trait = { workspace = true } +uuid = { workspace = true } +futures-util = "0.3" diff --git a/crates/atuin-server-postgres/build.rs b/crates/atuin-server-postgres/build.rs new file mode 100644 index 00000000..d5068697 --- /dev/null +++ b/crates/atuin-server-postgres/build.rs @@ -0,0 +1,5 @@ +// generated by `sqlx 
migrate build-script` +fn main() { + // trigger recompilation when a new migration is added + println!("cargo:rerun-if-changed=migrations"); +} diff --git a/crates/atuin-server-postgres/migrations/20210425153745_create_history.sql b/crates/atuin-server-postgres/migrations/20210425153745_create_history.sql new file mode 100644 index 00000000..2c2d17b0 --- /dev/null +++ b/crates/atuin-server-postgres/migrations/20210425153745_create_history.sql @@ -0,0 +1,11 @@ +create table history ( + id bigserial primary key, + client_id text not null unique, -- the client-generated ID + user_id bigserial not null, -- allow multiple users + hostname text not null, -- a unique identifier from the client (can be hashed, random, whatever) + timestamp timestamp not null, -- one of the few non-encrypted metadatas + + data varchar(8192) not null, -- store the actual history data, encrypted. I don't wanna know! + + created_at timestamp not null default current_timestamp +); diff --git a/crates/atuin-server-postgres/migrations/20210425153757_create_users.sql b/crates/atuin-server-postgres/migrations/20210425153757_create_users.sql new file mode 100644 index 00000000..a25dcced --- /dev/null +++ b/crates/atuin-server-postgres/migrations/20210425153757_create_users.sql @@ -0,0 +1,10 @@ +create table users ( + id bigserial primary key, -- also store our own ID + username varchar(32) not null unique, -- being able to contact users is useful + email varchar(128) not null unique, -- being able to contact users is useful + password varchar(128) not null unique +); + +-- the prior index is case sensitive :( +CREATE UNIQUE INDEX email_unique_idx on users (LOWER(email)); +CREATE UNIQUE INDEX username_unique_idx on users (LOWER(username)); diff --git a/crates/atuin-server-postgres/migrations/20210425153800_create_sessions.sql b/crates/atuin-server-postgres/migrations/20210425153800_create_sessions.sql new file mode 100644 index 00000000..c2fb6559 --- /dev/null +++ 
b/crates/atuin-server-postgres/migrations/20210425153800_create_sessions.sql @@ -0,0 +1,6 @@ +-- Add migration script here +create table sessions ( + id bigserial primary key, + user_id bigserial, + token varchar(128) unique not null +); diff --git a/crates/atuin-server-postgres/migrations/20220419082412_add_count_trigger.sql b/crates/atuin-server-postgres/migrations/20220419082412_add_count_trigger.sql new file mode 100644 index 00000000..dd1afa88 --- /dev/null +++ b/crates/atuin-server-postgres/migrations/20220419082412_add_count_trigger.sql @@ -0,0 +1,51 @@ +-- Prior to this, the count endpoint was super naive and just ran COUNT(1). +-- This is slow asf. Now that we have an amount of actual traffic, +-- stop doing that! +-- This basically maintains a count, so we can read ONE row, instead of ALL the +-- rows. Much better. +-- Future optimisation could use some sort of cache so we don't even need to hit +-- postgres at all. + +create table total_history_count_user( + id bigserial primary key, + user_id bigserial, + total integer -- try and avoid using keywords - hence total, not count +); + +create or replace function user_history_count() +returns trigger as +$func$ +begin + if (TG_OP='INSERT') then + update total_history_count_user set total = total + 1 where user_id = new.user_id; + + if not found then + insert into total_history_count_user(user_id, total) + values ( + new.user_id, + (select count(1) from history where user_id = new.user_id) + ); + end if; + + elsif (TG_OP='DELETE') then + update total_history_count_user set total = total - 1 where user_id = new.user_id; + + if not found then + insert into total_history_count_user(user_id, total) + values ( + new.user_id, + (select count(1) from history where user_id = new.user_id) + ); + end if; + end if; + + return NEW; -- this is actually ignored for an after trigger, but oh well +end; +$func$ +language plpgsql volatile -- pldfplplpflh +cost 100; -- default value + +create trigger tg_user_history_count + 
after insert or delete on history + for each row + execute procedure user_history_count(); diff --git a/crates/atuin-server-postgres/migrations/20220421073605_fix_count_trigger_delete.sql b/crates/atuin-server-postgres/migrations/20220421073605_fix_count_trigger_delete.sql new file mode 100644 index 00000000..6198f300 --- /dev/null +++ b/crates/atuin-server-postgres/migrations/20220421073605_fix_count_trigger_delete.sql @@ -0,0 +1,35 @@ +-- the old version of this function used NEW in the delete part when it should +-- use OLD + +create or replace function user_history_count() +returns trigger as +$func$ +begin + if (TG_OP='INSERT') then + update total_history_count_user set total = total + 1 where user_id = new.user_id; + + if not found then + insert into total_history_count_user(user_id, total) + values ( + new.user_id, + (select count(1) from history where user_id = new.user_id) + ); + end if; + + elsif (TG_OP='DELETE') then + update total_history_count_user set total = total - 1 where user_id = old.user_id; + + if not found then + insert into total_history_count_user(user_id, total) + values ( + old.user_id, + (select count(1) from history where user_id = old.user_id) + ); + end if; + end if; + + return NEW; -- this is actually ignored for an after trigger, but oh well +end; +$func$ +language plpgsql volatile -- pldfplplpflh +cost 100; -- default value diff --git a/crates/atuin-server-postgres/migrations/20220421174016_larger-commands.sql b/crates/atuin-server-postgres/migrations/20220421174016_larger-commands.sql new file mode 100644 index 00000000..0ac43433 --- /dev/null +++ b/crates/atuin-server-postgres/migrations/20220421174016_larger-commands.sql @@ -0,0 +1,3 @@ +-- Make it 4x larger. Most commands are less than this, but as it's base64 +-- SOME are more than 8192. Should be enough for now. 
+ALTER TABLE history ALTER COLUMN data TYPE varchar(32768); diff --git a/crates/atuin-server-postgres/migrations/20220426172813_user-created-at.sql b/crates/atuin-server-postgres/migrations/20220426172813_user-created-at.sql new file mode 100644 index 00000000..a9138194 --- /dev/null +++ b/crates/atuin-server-postgres/migrations/20220426172813_user-created-at.sql @@ -0,0 +1 @@ +alter table users add column created_at timestamp not null default now(); diff --git a/crates/atuin-server-postgres/migrations/20220505082442_create-events.sql b/crates/atuin-server-postgres/migrations/20220505082442_create-events.sql new file mode 100644 index 00000000..57e16ec7 --- /dev/null +++ b/crates/atuin-server-postgres/migrations/20220505082442_create-events.sql @@ -0,0 +1,14 @@ +create type event_type as enum ('create', 'delete'); + +create table events ( + id bigserial primary key, + client_id text not null unique, -- the client-generated ID + user_id bigserial not null, -- allow multiple users + hostname text not null, -- a unique identifier from the client (can be hashed, random, whatever) + timestamp timestamp not null, -- one of the few non-encrypted metadatas + + event_type event_type, + data text not null, -- store the actual history data, encrypted. I don't wanna know! 
+ + created_at timestamp not null default current_timestamp +); diff --git a/crates/atuin-server-postgres/migrations/20220610074049_history-length.sql b/crates/atuin-server-postgres/migrations/20220610074049_history-length.sql new file mode 100644 index 00000000..b1c23016 --- /dev/null +++ b/crates/atuin-server-postgres/migrations/20220610074049_history-length.sql @@ -0,0 +1,2 @@ +-- Add migration script here +alter table history alter column data type text; diff --git a/crates/atuin-server-postgres/migrations/20230315220537_drop-events.sql b/crates/atuin-server-postgres/migrations/20230315220537_drop-events.sql new file mode 100644 index 00000000..fe3cae17 --- /dev/null +++ b/crates/atuin-server-postgres/migrations/20230315220537_drop-events.sql @@ -0,0 +1,2 @@ +-- Add migration script here +drop table events; diff --git a/crates/atuin-server-postgres/migrations/20230315224203_create-deleted.sql b/crates/atuin-server-postgres/migrations/20230315224203_create-deleted.sql new file mode 100644 index 00000000..9a9e6263 --- /dev/null +++ b/crates/atuin-server-postgres/migrations/20230315224203_create-deleted.sql @@ -0,0 +1,5 @@ +-- Add migration script here +alter table history add column if not exists deleted_at timestamp; + +-- queries will all be selecting the ids of history for a user, that has been deleted +create index if not exists history_deleted_index on history(client_id, user_id, deleted_at); diff --git a/crates/atuin-server-postgres/migrations/20230515221038_trigger-delete-only.sql b/crates/atuin-server-postgres/migrations/20230515221038_trigger-delete-only.sql new file mode 100644 index 00000000..3d0bba52 --- /dev/null +++ b/crates/atuin-server-postgres/migrations/20230515221038_trigger-delete-only.sql @@ -0,0 +1,30 @@ +-- We do not need to run the trigger on deletes, as the only time we are deleting history is when the user +-- has already been deleted +-- This actually slows down deleting all the history a good bit! 
+ +create or replace function user_history_count() +returns trigger as +$func$ +begin + if (TG_OP='INSERT') then + update total_history_count_user set total = total + 1 where user_id = new.user_id; + + if not found then + insert into total_history_count_user(user_id, total) + values ( + new.user_id, + (select count(1) from history where user_id = new.user_id) + ); + end if; + end if; + + return NEW; -- this is actually ignored for an after trigger, but oh well +end; +$func$ +language plpgsql volatile -- pldfplplpflh +cost 100; -- default value + +create or replace trigger tg_user_history_count + after insert on history + for each row + execute procedure user_history_count(); diff --git a/crates/atuin-server-postgres/migrations/20230623070418_records.sql b/crates/atuin-server-postgres/migrations/20230623070418_records.sql new file mode 100644 index 00000000..22437595 --- /dev/null +++ b/crates/atuin-server-postgres/migrations/20230623070418_records.sql @@ -0,0 +1,15 @@ +-- Add migration script here +create table records ( + id uuid primary key, -- remember to use uuidv7 for happy indices <3 + client_id uuid not null, -- I am too uncomfortable with the idea of a client-generated primary key + host uuid not null, -- a unique identifier for the host + parent uuid default null, -- the ID of the parent record, bearing in mind this is a linked list + timestamp bigint not null, -- not a timestamp type, as those do not have nanosecond precision + version text not null, + tag text not null, -- what is this? history, kv, whatever. Remember clients get a log per tag per host + data text not null, -- store the actual history data, encrypted. I don't wanna know! 
+ cek text not null, + + user_id bigint not null, -- allow multiple users + created_at timestamp not null default current_timestamp +); diff --git a/crates/atuin-server-postgres/migrations/20231202170508_create-store.sql b/crates/atuin-server-postgres/migrations/20231202170508_create-store.sql new file mode 100644 index 00000000..ffb57966 --- /dev/null +++ b/crates/atuin-server-postgres/migrations/20231202170508_create-store.sql @@ -0,0 +1,15 @@ +-- Add migration script here +create table store ( + id uuid primary key, -- remember to use uuidv7 for happy indices <3 + client_id uuid not null, -- I am too uncomfortable with the idea of a client-generated primary key, even though it's fine mathematically + host uuid not null, -- a unique identifier for the host + idx bigint not null, -- the index of the record in this store, identified by (host, tag) + timestamp bigint not null, -- not a timestamp type, as those do not have nanosecond precision + version text not null, + tag text not null, -- what is this? history, kv, whatever. Remember clients get a log per tag per host + data text not null, -- store the actual history data, encrypted. I don't wanna know! 
+ cek text not null, + + user_id bigint not null, -- allow multiple users + created_at timestamp not null default current_timestamp +); diff --git a/crates/atuin-server-postgres/migrations/20231203124112_create-store-idx.sql b/crates/atuin-server-postgres/migrations/20231203124112_create-store-idx.sql new file mode 100644 index 00000000..56d67145 --- /dev/null +++ b/crates/atuin-server-postgres/migrations/20231203124112_create-store-idx.sql @@ -0,0 +1,2 @@ +-- Add migration script here +create unique index record_uniq ON store(user_id, host, tag, idx); diff --git a/crates/atuin-server-postgres/migrations/20240108124837_drop-some-defaults.sql b/crates/atuin-server-postgres/migrations/20240108124837_drop-some-defaults.sql new file mode 100644 index 00000000..ad2af5a1 --- /dev/null +++ b/crates/atuin-server-postgres/migrations/20240108124837_drop-some-defaults.sql @@ -0,0 +1,4 @@ +-- Add migration script here +alter table history alter column user_id drop default; +alter table sessions alter column user_id drop default; +alter table total_history_count_user alter column user_id drop default; diff --git a/crates/atuin-server-postgres/src/lib.rs b/crates/atuin-server-postgres/src/lib.rs new file mode 100644 index 00000000..6dc56fe4 --- /dev/null +++ b/crates/atuin-server-postgres/src/lib.rs @@ -0,0 +1,538 @@ +use std::ops::Range; + +use async_trait::async_trait; +use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}; +use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User}; +use atuin_server_database::{Database, DbError, DbResult}; +use futures_util::TryStreamExt; +use serde::{Deserialize, Serialize}; +use sqlx::postgres::PgPoolOptions; +use sqlx::Row; + +use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset}; +use tracing::instrument; +use uuid::Uuid; +use wrappers::{DbHistory, DbRecord, DbSession, DbUser}; + +mod wrappers; + +const MIN_PG_VERSION: u32 = 14; + +#[derive(Clone)] +pub struct Postgres { + 
pool: sqlx::Pool<sqlx::postgres::Postgres>, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct PostgresSettings { + pub db_uri: String, +} + +fn fix_error(error: sqlx::Error) -> DbError { + match error { + sqlx::Error::RowNotFound => DbError::NotFound, + error => DbError::Other(error.into()), + } +} + +#[async_trait] +impl Database for Postgres { + type Settings = PostgresSettings; + async fn new(settings: &PostgresSettings) -> DbResult<Self> { + let pool = PgPoolOptions::new() + .max_connections(100) + .connect(settings.db_uri.as_str()) + .await + .map_err(fix_error)?; + + // Call server_version_num to get the DB server's major version number + // The call returns None for servers older than 8.x. + let pg_major_version: u32 = pool + .acquire() + .await + .map_err(fix_error)? + .server_version_num() + .ok_or(DbError::Other(eyre::Report::msg( + "could not get PostgreSQL version", + )))? + / 10000; + + if pg_major_version < MIN_PG_VERSION { + return Err(DbError::Other(eyre::Report::msg(format!( + "unsupported PostgreSQL version {}, minimum required is {}", + pg_major_version, MIN_PG_VERSION + )))); + } + + sqlx::migrate!("./migrations") + .run(&pool) + .await + .map_err(|error| DbError::Other(error.into()))?; + + Ok(Self { pool }) + } + + #[instrument(skip_all)] + async fn get_session(&self, token: &str) -> DbResult<Session> { + sqlx::query_as("select id, user_id, token from sessions where token = $1") + .bind(token) + .fetch_one(&self.pool) + .await + .map_err(fix_error) + .map(|DbSession(session)| session) + } + + #[instrument(skip_all)] + async fn get_user(&self, username: &str) -> DbResult<User> { + sqlx::query_as("select id, username, email, password from users where username = $1") + .bind(username) + .fetch_one(&self.pool) + .await + .map_err(fix_error) + .map(|DbUser(user)| user) + } + + #[instrument(skip_all)] + async fn get_session_user(&self, token: &str) -> DbResult<User> { + sqlx::query_as( + "select users.id, users.username, users.email, 
users.password from users + inner join sessions + on users.id = sessions.user_id + and sessions.token = $1", + ) + .bind(token) + .fetch_one(&self.pool) + .await + .map_err(fix_error) + .map(|DbUser(user)| user) + } + + #[instrument(skip_all)] + async fn count_history(&self, user: &User) -> DbResult<i64> { + // The cache is new, and the user might not yet have a cache value. + // They will have one as soon as they post up some new history, but handle that + // edge case. + + let res: (i64,) = sqlx::query_as( + "select count(1) from history + where user_id = $1", + ) + .bind(user.id) + .fetch_one(&self.pool) + .await + .map_err(fix_error)?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn total_history(&self) -> DbResult<i64> { + // The cache is new, and the user might not yet have a cache value. + // They will have one as soon as they post up some new history, but handle that + // edge case. + + let res: (i64,) = sqlx::query_as("select sum(total) from total_history_count_user") + .fetch_optional(&self.pool) + .await + .map_err(fix_error)? 
+ .unwrap_or((0,)); + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn count_history_cached(&self, user: &User) -> DbResult<i64> { + let res: (i32,) = sqlx::query_as( + "select total from total_history_count_user + where user_id = $1", + ) + .bind(user.id) + .fetch_one(&self.pool) + .await + .map_err(fix_error)?; + + Ok(res.0 as i64) + } + + async fn delete_store(&self, user: &User) -> DbResult<()> { + sqlx::query( + "delete from store + where user_id = $1", + ) + .bind(user.id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + Ok(()) + } + + async fn delete_history(&self, user: &User, id: String) -> DbResult<()> { + sqlx::query( + "update history + set deleted_at = $3 + where user_id = $1 + and client_id = $2 + and deleted_at is null", // don't just keep setting it + ) + .bind(user.id) + .bind(id) + .bind(OffsetDateTime::now_utc()) + .fetch_all(&self.pool) + .await + .map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> { + // The cache is new, and the user might not yet have a cache value. + // They will have one as soon as they post up some new history, but handle that + // edge case. 
+ + let res = sqlx::query( + "select client_id from history + where user_id = $1 + and deleted_at is not null", + ) + .bind(user.id) + .fetch_all(&self.pool) + .await + .map_err(fix_error)?; + + let res = res + .iter() + .map(|row| row.get::<String, _>("client_id")) + .collect(); + + Ok(res) + } + + #[instrument(skip_all)] + async fn count_history_range( + &self, + user: &User, + range: Range<OffsetDateTime>, + ) -> DbResult<i64> { + let res: (i64,) = sqlx::query_as( + "select count(1) from history + where user_id = $1 + and timestamp >= $2::date + and timestamp < $3::date", + ) + .bind(user.id) + .bind(into_utc(range.start)) + .bind(into_utc(range.end)) + .fetch_one(&self.pool) + .await + .map_err(fix_error)?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn list_history( + &self, + user: &User, + created_after: OffsetDateTime, + since: OffsetDateTime, + host: &str, + page_size: i64, + ) -> DbResult<Vec<History>> { + let res = sqlx::query_as( + "select id, client_id, user_id, hostname, timestamp, data, created_at from history + where user_id = $1 + and hostname != $2 + and created_at >= $3 + and timestamp >= $4 + order by timestamp asc + limit $5", + ) + .bind(user.id) + .bind(host) + .bind(into_utc(created_after)) + .bind(into_utc(since)) + .bind(page_size) + .fetch(&self.pool) + .map_ok(|DbHistory(h)| h) + .try_collect() + .await + .map_err(fix_error)?; + + Ok(res) + } + + #[instrument(skip_all)] + async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> { + let mut tx = self.pool.begin().await.map_err(fix_error)?; + + for i in history { + let client_id: &str = &i.client_id; + let hostname: &str = &i.hostname; + let data: &str = &i.data; + + sqlx::query( + "insert into history + (client_id, user_id, hostname, timestamp, data) + values ($1, $2, $3, $4, $5) + on conflict do nothing + ", + ) + .bind(client_id) + .bind(i.user_id) + .bind(hostname) + .bind(i.timestamp) + .bind(data) + .execute(&mut *tx) + .await + .map_err(fix_error)?; + } + + 
tx.commit().await.map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn delete_user(&self, u: &User) -> DbResult<()> { + sqlx::query("delete from sessions where user_id = $1") + .bind(u.id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + sqlx::query("delete from users where id = $1") + .bind(u.id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + sqlx::query("delete from history where user_id = $1") + .bind(u.id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + sqlx::query("delete from total_history_count_user where user_id = $1") + .bind(u.id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn update_user_password(&self, user: &User) -> DbResult<()> { + sqlx::query( + "update users + set password = $1 + where id = $2", + ) + .bind(&user.password) + .bind(user.id) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn add_user(&self, user: &NewUser) -> DbResult<i64> { + let email: &str = &user.email; + let username: &str = &user.username; + let password: &str = &user.password; + + let res: (i64,) = sqlx::query_as( + "insert into users + (username, email, password) + values($1, $2, $3) + returning id", + ) + .bind(username) + .bind(email) + .bind(password) + .fetch_one(&self.pool) + .await + .map_err(fix_error)?; + + Ok(res.0) + } + + #[instrument(skip_all)] + async fn add_session(&self, session: &NewSession) -> DbResult<()> { + let token: &str = &session.token; + + sqlx::query( + "insert into sessions + (user_id, token) + values($1, $2)", + ) + .bind(session.user_id) + .bind(token) + .execute(&self.pool) + .await + .map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn get_user_session(&self, u: &User) -> DbResult<Session> { + sqlx::query_as("select id, user_id, token from sessions where user_id = $1") + .bind(u.id) + .fetch_one(&self.pool) + .await + .map_err(fix_error) + 
.map(|DbSession(session)| session) + } + + #[instrument(skip_all)] + async fn oldest_history(&self, user: &User) -> DbResult<History> { + sqlx::query_as( + "select id, client_id, user_id, hostname, timestamp, data, created_at from history + where user_id = $1 + order by timestamp asc + limit 1", + ) + .bind(user.id) + .fetch_one(&self.pool) + .await + .map_err(fix_error) + .map(|DbHistory(h)| h) + } + + #[instrument(skip_all)] + async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> { + let mut tx = self.pool.begin().await.map_err(fix_error)?; + + for i in records { + let id = atuin_common::utils::uuid_v7(); + + sqlx::query( + "insert into store + (id, client_id, host, idx, timestamp, version, tag, data, cek, user_id) + values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + on conflict do nothing + ", + ) + .bind(id) + .bind(i.id) + .bind(i.host.id) + .bind(i.idx as i64) + .bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time + .bind(&i.version) + .bind(&i.tag) + .bind(&i.data.data) + .bind(&i.data.content_encryption_key) + .bind(user.id) + .execute(&mut *tx) + .await + .map_err(fix_error)?; + } + + tx.commit().await.map_err(fix_error)?; + + Ok(()) + } + + #[instrument(skip_all)] + async fn next_records( + &self, + user: &User, + host: HostId, + tag: String, + start: Option<RecordIdx>, + count: u64, + ) -> DbResult<Vec<Record<EncryptedData>>> { + tracing::debug!("{:?} - {:?} - {:?}", host, tag, start); + let start = start.unwrap_or(0); + + let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as( + "select client_id, host, idx, timestamp, version, tag, data, cek from store + where user_id = $1 + and tag = $2 + and host = $3 + and idx >= $4 + order by idx asc + limit $5", + ) + .bind(user.id) + .bind(tag.clone()) + .bind(host) + .bind(start as i64) + .bind(count as i64) + .fetch_all(&self.pool) + .await + .map_err(fix_error); + + let ret = match records { + Ok(records) => { + let records: 
Vec<Record<EncryptedData>> = records + .into_iter() + .map(|f| { + let record: Record<EncryptedData> = f.into(); + record + }) + .collect(); + + records + } + Err(DbError::NotFound) => { + tracing::debug!("no records found in store: {:?}/{}", host, tag); + return Ok(vec![]); + } + Err(e) => return Err(e), + }; + + Ok(ret) + } + + async fn status(&self, user: &User) -> DbResult<RecordStatus> { + const STATUS_SQL: &str = + "select host, tag, max(idx) from store where user_id = $1 group by host, tag"; + + let res: Vec<(Uuid, String, i64)> = sqlx::query_as(STATUS_SQL) + .bind(user.id) + .fetch_all(&self.pool) + .await + .map_err(fix_error)?; + + let mut status = RecordStatus::new(); + + for i in res { + status.set_raw(HostId(i.0), i.1, i.2 as u64); + } + + Ok(status) + } +} + +fn into_utc(x: OffsetDateTime) -> PrimitiveDateTime { + let x = x.to_offset(UtcOffset::UTC); + PrimitiveDateTime::new(x.date(), x.time()) +} + +#[cfg(test)] +mod tests { + use time::macros::datetime; + + use crate::into_utc; + + #[test] + fn utc() { + let dt = datetime!(2023-09-26 15:11:02 +05:30); + assert_eq!(into_utc(dt), datetime!(2023-09-26 09:41:02)); + assert_eq!(into_utc(dt).assume_utc(), dt); + + let dt = datetime!(2023-09-26 15:11:02 -07:00); + assert_eq!(into_utc(dt), datetime!(2023-09-26 22:11:02)); + assert_eq!(into_utc(dt).assume_utc(), dt); + + let dt = datetime!(2023-09-26 15:11:02 +00:00); + assert_eq!(into_utc(dt), datetime!(2023-09-26 15:11:02)); + assert_eq!(into_utc(dt).assume_utc(), dt); + } +} diff --git a/crates/atuin-server-postgres/src/wrappers.rs b/crates/atuin-server-postgres/src/wrappers.rs new file mode 100644 index 00000000..3ccf9c19 --- /dev/null +++ b/crates/atuin-server-postgres/src/wrappers.rs @@ -0,0 +1,77 @@ +use ::sqlx::{FromRow, Result}; +use atuin_common::record::{EncryptedData, Host, Record}; +use atuin_server_database::models::{History, Session, User}; +use sqlx::{postgres::PgRow, Row}; +use time::PrimitiveDateTime; + +pub struct DbUser(pub User); +pub 
struct DbSession(pub Session); +pub struct DbHistory(pub History); +pub struct DbRecord(pub Record<EncryptedData>); + +impl<'a> FromRow<'a, PgRow> for DbUser { + fn from_row(row: &'a PgRow) -> Result<Self> { + Ok(Self(User { + id: row.try_get("id")?, + username: row.try_get("username")?, + email: row.try_get("email")?, + password: row.try_get("password")?, + })) + } +} + +impl<'a> ::sqlx::FromRow<'a, PgRow> for DbSession { + fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> { + Ok(Self(Session { + id: row.try_get("id")?, + user_id: row.try_get("user_id")?, + token: row.try_get("token")?, + })) + } +} + +impl<'a> ::sqlx::FromRow<'a, PgRow> for DbHistory { + fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> { + Ok(Self(History { + id: row.try_get("id")?, + client_id: row.try_get("client_id")?, + user_id: row.try_get("user_id")?, + hostname: row.try_get("hostname")?, + timestamp: row + .try_get::<PrimitiveDateTime, _>("timestamp")? + .assume_utc(), + data: row.try_get("data")?, + created_at: row + .try_get::<PrimitiveDateTime, _>("created_at")? 
+ .assume_utc(), + })) + } +} + +impl<'a> ::sqlx::FromRow<'a, PgRow> for DbRecord { + fn from_row(row: &'a PgRow) -> ::sqlx::Result<Self> { + let timestamp: i64 = row.try_get("timestamp")?; + let idx: i64 = row.try_get("idx")?; + + let data = EncryptedData { + data: row.try_get("data")?, + content_encryption_key: row.try_get("cek")?, + }; + + Ok(Self(Record { + id: row.try_get("client_id")?, + host: Host::new(row.try_get("host")?), + idx: idx as u64, + timestamp: timestamp as u64, + version: row.try_get("version")?, + tag: row.try_get("tag")?, + data, + })) + } +} + +impl From<DbRecord> for Record<EncryptedData> { + fn from(other: DbRecord) -> Record<EncryptedData> { + Record { ..other.0 } + } +} diff --git a/crates/atuin-server/Cargo.toml b/crates/atuin-server/Cargo.toml new file mode 100644 index 00000000..a6b8a9f6 --- /dev/null +++ b/crates/atuin-server/Cargo.toml @@ -0,0 +1,39 @@ +[package] +name = "atuin-server" +edition = "2021" +description = "server library for atuin" + +rust-version = { workspace = true } +version = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } + +[dependencies] +atuin-common = { path = "../atuin-common", version = "18.2.0" } +atuin-server-database = { path = "../atuin-server-database", version = "18.2.0" } + +tracing = "0.1" +time = { workspace = true } +eyre = { workspace = true } +uuid = { workspace = true } +config = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +base64 = { workspace = true } +rand = { workspace = true } +tokio = { workspace = true } +async-trait = { workspace = true } +axum = "0.7.4" +axum-server = { version = "0.6.0", features = ["tls-rustls"] } +fs-err = { workspace = true } +tower = "0.4" +tower-http = { version = "0.5.1", features = ["trace"] } +reqwest = { workspace = true } +rustls = "0.21" +rustls-pemfile = "2.1" +argon2 = "0.5.3" +semver = { workspace = true } 
+metrics-exporter-prometheus = "0.12.1" +metrics = "0.21.1" diff --git a/crates/atuin-server/server.toml b/crates/atuin-server/server.toml new file mode 100644 index 00000000..946769c9 --- /dev/null +++ b/crates/atuin-server/server.toml @@ -0,0 +1,34 @@ +## host to bind, can also be passed via CLI args +# host = "127.0.0.1" + +## port to bind, can also be passed via CLI args +# port = 8888 + +## whether to allow anyone to register an account +# open_registration = false + +## URI for postgres (using development creds here) +# db_uri="postgres://username:password@localhost/atuin" + +## Maximum size for one history entry +# max_history_length = 8192 + +## Maximum size for one record entry +## 1024 * 1024 * 1024 +# max_record_size = 1073741824 + +## Webhook to be called when user registers on the servers +# register_webhook_username = "" + +## Default page size for requests +# page_size = 1100 + +# [metrics] +# enable = false +# host = "127.0.0.1" +# port = 9001 + +# [tls] +# enable = false +# cert_path = "" +# pkey_path = "" diff --git a/crates/atuin-server/src/handlers/history.rs b/crates/atuin-server/src/handlers/history.rs new file mode 100644 index 00000000..05bbe740 --- /dev/null +++ b/crates/atuin-server/src/handlers/history.rs @@ -0,0 +1,237 @@ +use std::{collections::HashMap, convert::TryFrom}; + +use axum::{ + extract::{Path, Query, State}, + http::{HeaderMap, StatusCode}, + Json, +}; +use metrics::counter; +use time::{Month, UtcOffset}; +use tracing::{debug, error, instrument}; + +use super::{ErrorResponse, ErrorResponseStatus, RespExt}; +use crate::{ + router::{AppState, UserAuth}, + utils::client_version_min, +}; +use atuin_server_database::{ + calendar::{TimePeriod, TimePeriodInfo}, + models::NewHistory, + Database, +}; + +use atuin_common::api::*; + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn count<DB: Database>( + UserAuth(user): UserAuth, + state: State<AppState<DB>>, +) -> Result<Json<CountResponse>, ErrorResponseStatus<'static>> 
{ + let db = &state.0.database; + match db.count_history_cached(&user).await { + // By default read out the cached value + Ok(count) => Ok(Json(CountResponse { count })), + + // If that fails, fallback on a full COUNT. Cache is built on a POST + // only + Err(_) => match db.count_history(&user).await { + Ok(count) => Ok(Json(CountResponse { count })), + Err(_) => Err(ErrorResponse::reply("failed to query history count") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)), + }, + } +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn list<DB: Database>( + req: Query<SyncHistoryRequest>, + UserAuth(user): UserAuth, + headers: HeaderMap, + state: State<AppState<DB>>, +) -> Result<Json<SyncHistoryResponse>, ErrorResponseStatus<'static>> { + let db = &state.0.database; + + let agent = headers + .get("user-agent") + .map_or("", |v| v.to_str().unwrap_or("")); + + let variable_page_size = client_version_min(agent, ">=15.0.0").unwrap_or(false); + + let page_size = if variable_page_size { + state.settings.page_size + } else { + 100 + }; + + if req.sync_ts.unix_timestamp_nanos() < 0 || req.history_ts.unix_timestamp_nanos() < 0 { + error!("client asked for history from < epoch 0"); + counter!("atuin_history_epoch_before_zero", 1); + + return Err( + ErrorResponse::reply("asked for history from before epoch 0") + .with_status(StatusCode::BAD_REQUEST), + ); + } + + let history = db + .list_history(&user, req.sync_ts, req.history_ts, &req.host, page_size) + .await; + + if let Err(e) = history { + error!("failed to load history: {}", e); + return Err(ErrorResponse::reply("failed to load history") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + + let history: Vec<String> = history + .unwrap() + .iter() + .map(|i| i.data.to_string()) + .collect(); + + debug!( + "loaded {} items of history for user {}", + history.len(), + user.id + ); + + counter!("atuin_history_returned", history.len() as u64); + + Ok(Json(SyncHistoryResponse { history })) +} + 
+#[instrument(skip_all, fields(user.id = user.id))] +pub async fn delete<DB: Database>( + UserAuth(user): UserAuth, + state: State<AppState<DB>>, + Json(req): Json<DeleteHistoryRequest>, +) -> Result<Json<MessageResponse>, ErrorResponseStatus<'static>> { + let db = &state.0.database; + + // client_id is the ID of the history, as set by the user (the server has its own ID) + let deleted = db.delete_history(&user, req.client_id).await; + + if let Err(e) = deleted { + error!("failed to delete history: {}", e); + return Err(ErrorResponse::reply("failed to delete history") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + + Ok(Json(MessageResponse { + message: String::from("deleted OK"), + })) +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn add<DB: Database>( + UserAuth(user): UserAuth, + state: State<AppState<DB>>, + Json(req): Json<Vec<AddHistoryRequest>>, +) -> Result<(), ErrorResponseStatus<'static>> { + let State(AppState { database, settings }) = state; + + debug!("request to add {} history items", req.len()); + counter!("atuin_history_uploaded", req.len() as u64); + + let mut history: Vec<NewHistory> = req + .into_iter() + .map(|h| NewHistory { + client_id: h.id, + user_id: user.id, + hostname: h.hostname, + timestamp: h.timestamp, + data: h.data, + }) + .collect(); + + history.retain(|h| { + // keep if within limit, or limit is 0 (unlimited) + let keep = h.data.len() <= settings.max_history_length || settings.max_history_length == 0; + + // Don't return an error here. We want to insert as much of the + // history list as we can, so log the error and continue going. 
+ if !keep { + counter!("atuin_history_too_long", 1); + + tracing::warn!( + "history too long, got length {}, max {}", + h.data.len(), + settings.max_history_length + ); + } + + keep + }); + + if let Err(e) = database.add_history(&history).await { + error!("failed to add history: {}", e); + + return Err(ErrorResponse::reply("failed to add history") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + }; + + Ok(()) +} + +#[derive(serde::Deserialize, Debug)] +pub struct CalendarQuery { + #[serde(default = "serde_calendar::zero")] + year: i32, + #[serde(default = "serde_calendar::one")] + month: u8, + + #[serde(default = "serde_calendar::utc")] + tz: UtcOffset, +} + +mod serde_calendar { + use time::UtcOffset; + + pub fn zero() -> i32 { + 0 + } + + pub fn one() -> u8 { + 1 + } + + pub fn utc() -> UtcOffset { + UtcOffset::UTC + } +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn calendar<DB: Database>( + Path(focus): Path<String>, + Query(params): Query<CalendarQuery>, + UserAuth(user): UserAuth, + state: State<AppState<DB>>, +) -> Result<Json<HashMap<u64, TimePeriodInfo>>, ErrorResponseStatus<'static>> { + let focus = focus.as_str(); + + let year = params.year; + let month = Month::try_from(params.month).map_err(|e| ErrorResponseStatus { + error: ErrorResponse { + reason: e.to_string().into(), + }, + status: StatusCode::BAD_REQUEST, + })?; + + let period = match focus { + "year" => TimePeriod::Year, + "month" => TimePeriod::Month { year }, + "day" => TimePeriod::Day { year, month }, + _ => { + return Err(ErrorResponse::reply("invalid focus: use year/month/day") + .with_status(StatusCode::BAD_REQUEST)) + } + }; + + let db = &state.0.database; + let focus = db.calendar(&user, period, params.tz).await.map_err(|_| { + ErrorResponse::reply("failed to query calendar") + .with_status(StatusCode::INTERNAL_SERVER_ERROR) + })?; + + Ok(Json(focus)) +} diff --git a/crates/atuin-server/src/handlers/mod.rs b/crates/atuin-server/src/handlers/mod.rs new file 
mode 100644 index 00000000..50f82336 --- /dev/null +++ b/crates/atuin-server/src/handlers/mod.rs @@ -0,0 +1,58 @@ +use atuin_common::api::{ErrorResponse, IndexResponse}; +use atuin_server_database::Database; +use axum::{extract::State, http, response::IntoResponse, Json}; + +use crate::router::AppState; + +pub mod history; +pub mod record; +pub mod status; +pub mod user; +pub mod v0; + +const VERSION: &str = env!("CARGO_PKG_VERSION"); + +pub async fn index<DB: Database>(state: State<AppState<DB>>) -> Json<IndexResponse> { + let homage = r#""Through the fathomless deeps of space swims the star turtle Great A'Tuin, bearing on its back the four giant elephants who carry on their shoulders the mass of the Discworld." -- Sir Terry Pratchett"#; + + // Error with a -1 response + // It's super unlikely this will happen + let count = state.database.total_history().await.unwrap_or(-1); + + Json(IndexResponse { + homage: homage.to_string(), + version: VERSION.to_string(), + total_history: count, + }) +} + +impl<'a> IntoResponse for ErrorResponseStatus<'a> { + fn into_response(self) -> axum::response::Response { + (self.status, Json(self.error)).into_response() + } +} + +pub struct ErrorResponseStatus<'a> { + pub error: ErrorResponse<'a>, + pub status: http::StatusCode, +} + +pub trait RespExt<'a> { + fn with_status(self, status: http::StatusCode) -> ErrorResponseStatus<'a>; + fn reply(reason: &'a str) -> Self; +} + +impl<'a> RespExt<'a> for ErrorResponse<'a> { + fn with_status(self, status: http::StatusCode) -> ErrorResponseStatus<'a> { + ErrorResponseStatus { + error: self, + status, + } + } + + fn reply(reason: &'a str) -> ErrorResponse { + Self { + reason: reason.into(), + } + } +} diff --git a/crates/atuin-server/src/handlers/record.rs b/crates/atuin-server/src/handlers/record.rs new file mode 100644 index 00000000..bf454949 --- /dev/null +++ b/crates/atuin-server/src/handlers/record.rs @@ -0,0 +1,45 @@ +use axum::{http::StatusCode, response::IntoResponse, Json}; +use 
serde_json::json; +use tracing::instrument; + +use super::{ErrorResponse, ErrorResponseStatus, RespExt}; +use crate::router::UserAuth; +use atuin_server_database::Database; + +use atuin_common::record::{EncryptedData, Record}; + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn post<DB: Database>( + UserAuth(user): UserAuth, +) -> Result<(), ErrorResponseStatus<'static>> { + // anyone who has actually used the old record store (a very small number) will see this error + // upon trying to sync. + // 1. The status endpoint will say that the server has nothing + // 2. The client will try to upload local records + // 3. Sync will fail with this error + + // If the client has no local records, they will see the empty index and do nothing. For the + // vast majority of users, this is the case. + return Err( + ErrorResponse::reply("record store deprecated; please upgrade") + .with_status(StatusCode::BAD_REQUEST), + ); +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn index<DB: Database>(UserAuth(user): UserAuth) -> axum::response::Response { + let ret = json!({ + "hosts": {} + }); + + ret.to_string().into_response() +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn next( + UserAuth(user): UserAuth, +) -> Result<Json<Vec<Record<EncryptedData>>>, ErrorResponseStatus<'static>> { + let records = Vec::new(); + + Ok(Json(records)) +} diff --git a/crates/atuin-server/src/handlers/status.rs b/crates/atuin-server/src/handlers/status.rs new file mode 100644 index 00000000..3c22232c --- /dev/null +++ b/crates/atuin-server/src/handlers/status.rs @@ -0,0 +1,43 @@ +use axum::{extract::State, http::StatusCode, Json}; +use tracing::instrument; + +use super::{ErrorResponse, ErrorResponseStatus, RespExt}; +use crate::router::{AppState, UserAuth}; +use atuin_server_database::Database; + +use atuin_common::api::*; + +const VERSION: &str = env!("CARGO_PKG_VERSION"); + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn 
status<DB: Database>( + UserAuth(user): UserAuth, + state: State<AppState<DB>>, +) -> Result<Json<StatusResponse>, ErrorResponseStatus<'static>> { + let db = &state.0.database; + + let deleted = db.deleted_history(&user).await.unwrap_or(vec![]); + + let count = match db.count_history_cached(&user).await { + // By default read out the cached value + Ok(count) => count, + + // If that fails, fallback on a full COUNT. Cache is built on a POST + // only + Err(_) => match db.count_history(&user).await { + Ok(count) => count, + Err(_) => { + return Err(ErrorResponse::reply("failed to query history count") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)) + } + }, + }; + + Ok(Json(StatusResponse { + count, + deleted, + username: user.username, + version: VERSION.to_string(), + page_size: state.settings.page_size, + })) +} diff --git a/crates/atuin-server/src/handlers/user.rs b/crates/atuin-server/src/handlers/user.rs new file mode 100644 index 00000000..e5651fe2 --- /dev/null +++ b/crates/atuin-server/src/handlers/user.rs @@ -0,0 +1,258 @@ +use std::borrow::Borrow; +use std::collections::HashMap; +use std::time::Duration; + +use argon2::{ + password_hash::SaltString, Algorithm, Argon2, Params, PasswordHash, PasswordHasher, + PasswordVerifier, Version, +}; +use axum::{ + extract::{Path, State}, + http::StatusCode, + Json, +}; +use metrics::counter; +use rand::rngs::OsRng; +use tracing::{debug, error, info, instrument}; +use uuid::Uuid; + +use super::{ErrorResponse, ErrorResponseStatus, RespExt}; +use crate::router::{AppState, UserAuth}; +use atuin_server_database::{ + models::{NewSession, NewUser}, + Database, DbError, +}; + +use reqwest::header::CONTENT_TYPE; + +use atuin_common::api::*; + +pub fn verify_str(hash: &str, password: &str) -> bool { + let arg2 = Argon2::new(Algorithm::Argon2id, Version::V0x13, Params::default()); + let Ok(hash) = PasswordHash::new(hash) else { + return false; + }; + arg2.verify_password(password.as_bytes(), &hash).is_ok() +} + +// Try to 
send a Discord webhook once - if it fails, we don't retry. "At most once", and best effort. +// Don't return the status because if this fails, we don't really care. +async fn send_register_hook(url: &str, username: String, registered: String) { + let hook = HashMap::from([ + ("username", username), + ("content", format!("{registered} has just signed up!")), + ]); + + let client = reqwest::Client::new(); + + let resp = client + .post(url) + .timeout(Duration::new(5, 0)) + .header(CONTENT_TYPE, "application/json") + .json(&hook) + .send() + .await; + + match resp { + Ok(_) => info!("register webhook sent ok!"), + Err(e) => error!("failed to send register webhook: {}", e), + } +} + +#[instrument(skip_all, fields(user.username = username.as_str()))] +pub async fn get<DB: Database>( + Path(username): Path<String>, + state: State<AppState<DB>>, +) -> Result<Json<UserResponse>, ErrorResponseStatus<'static>> { + let db = &state.0.database; + let user = match db.get_user(username.as_ref()).await { + Ok(user) => user, + Err(DbError::NotFound) => { + debug!("user not found: {}", username); + return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); + } + Err(DbError::Other(err)) => { + error!("database error: {}", err); + return Err(ErrorResponse::reply("database error") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + }; + + Ok(Json(UserResponse { + username: user.username, + })) +} + +#[instrument(skip_all)] +pub async fn register<DB: Database>( + state: State<AppState<DB>>, + Json(register): Json<RegisterRequest>, +) -> Result<Json<RegisterResponse>, ErrorResponseStatus<'static>> { + if !state.settings.open_registration { + return Err( + ErrorResponse::reply("this server is not open for registrations") + .with_status(StatusCode::BAD_REQUEST), + ); + } + + for c in register.username.chars() { + match c { + 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' => {} + _ => { + return Err(ErrorResponse::reply( + "Only alphanumeric and hyphens (-) are 
allowed in usernames", + ) + .with_status(StatusCode::BAD_REQUEST)) + } + } + } + + let hashed = hash_secret(&register.password); + + let new_user = NewUser { + email: register.email.clone(), + username: register.username.clone(), + password: hashed, + }; + + let db = &state.0.database; + let user_id = match db.add_user(&new_user).await { + Ok(id) => id, + Err(e) => { + error!("failed to add user: {}", e); + return Err( + ErrorResponse::reply("failed to add user").with_status(StatusCode::BAD_REQUEST) + ); + } + }; + + let token = Uuid::new_v4().as_simple().to_string(); + + let new_session = NewSession { + user_id, + token: (&token).into(), + }; + + if let Some(url) = &state.settings.register_webhook_url { + // Could probs be run on another thread, but it's ok atm + send_register_hook( + url, + state.settings.register_webhook_username.clone(), + register.username, + ) + .await; + } + + counter!("atuin_users_registered", 1); + + match db.add_session(&new_session).await { + Ok(_) => Ok(Json(RegisterResponse { session: token })), + Err(e) => { + error!("failed to add session: {}", e); + Err(ErrorResponse::reply("failed to register user") + .with_status(StatusCode::BAD_REQUEST)) + } + } +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn delete<DB: Database>( + UserAuth(user): UserAuth, + state: State<AppState<DB>>, +) -> Result<Json<DeleteUserResponse>, ErrorResponseStatus<'static>> { + debug!("request to delete user {}", user.id); + + let db = &state.0.database; + if let Err(e) = db.delete_user(&user).await { + error!("failed to delete user: {}", e); + + return Err(ErrorResponse::reply("failed to delete user") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + }; + + counter!("atuin_users_deleted", 1); + + Ok(Json(DeleteUserResponse {})) +} + +#[instrument(skip_all, fields(user.id = user.id, change_password))] +pub async fn change_password<DB: Database>( + UserAuth(mut user): UserAuth, + state: State<AppState<DB>>, + Json(change_password): 
Json<ChangePasswordRequest>, +) -> Result<Json<ChangePasswordResponse>, ErrorResponseStatus<'static>> { + let db = &state.0.database; + + let verified = verify_str( + user.password.as_str(), + change_password.current_password.borrow(), + ); + if !verified { + return Err( + ErrorResponse::reply("password is not correct").with_status(StatusCode::UNAUTHORIZED) + ); + } + + let hashed = hash_secret(&change_password.new_password); + user.password = hashed; + + if let Err(e) = db.update_user_password(&user).await { + error!("failed to change user password: {}", e); + + return Err(ErrorResponse::reply("failed to change user password") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + }; + Ok(Json(ChangePasswordResponse {})) +} + +#[instrument(skip_all, fields(user.username = login.username.as_str()))] +pub async fn login<DB: Database>( + state: State<AppState<DB>>, + login: Json<LoginRequest>, +) -> Result<Json<LoginResponse>, ErrorResponseStatus<'static>> { + let db = &state.0.database; + let user = match db.get_user(login.username.borrow()).await { + Ok(u) => u, + Err(DbError::NotFound) => { + return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); + } + Err(DbError::Other(e)) => { + error!("failed to get user {}: {}", login.username.clone(), e); + + return Err(ErrorResponse::reply("database error") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + }; + + let session = match db.get_user_session(&user).await { + Ok(u) => u, + Err(DbError::NotFound) => { + debug!("user session not found for user id={}", user.id); + return Err(ErrorResponse::reply("user not found").with_status(StatusCode::NOT_FOUND)); + } + Err(DbError::Other(err)) => { + error!("database error for user {}: {}", login.username, err); + return Err(ErrorResponse::reply("database error") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + }; + + let verified = verify_str(user.password.as_str(), login.password.borrow()); + + if !verified { + return Err( + 
ErrorResponse::reply("password is not correct").with_status(StatusCode::UNAUTHORIZED) + ); + } + + Ok(Json(LoginResponse { + session: session.token, + })) +} + +fn hash_secret(password: &str) -> String { + let arg2 = Argon2::new(Algorithm::Argon2id, Version::V0x13, Params::default()); + let salt = SaltString::generate(&mut OsRng); + let hash = arg2.hash_password(password.as_bytes(), &salt).unwrap(); + hash.to_string() +} diff --git a/crates/atuin-server/src/handlers/v0/me.rs b/crates/atuin-server/src/handlers/v0/me.rs new file mode 100644 index 00000000..7960b479 --- /dev/null +++ b/crates/atuin-server/src/handlers/v0/me.rs @@ -0,0 +1,16 @@ +use axum::Json; +use tracing::instrument; + +use crate::handlers::ErrorResponseStatus; +use crate::router::UserAuth; + +use atuin_common::api::*; + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn get( + UserAuth(user): UserAuth, +) -> Result<Json<MeResponse>, ErrorResponseStatus<'static>> { + Ok(Json(MeResponse { + username: user.username, + })) +} diff --git a/crates/atuin-server/src/handlers/v0/mod.rs b/crates/atuin-server/src/handlers/v0/mod.rs new file mode 100644 index 00000000..d6f880f2 --- /dev/null +++ b/crates/atuin-server/src/handlers/v0/mod.rs @@ -0,0 +1,3 @@ +pub(crate) mod me; +pub(crate) mod record; +pub(crate) mod store; diff --git a/crates/atuin-server/src/handlers/v0/record.rs b/crates/atuin-server/src/handlers/v0/record.rs new file mode 100644 index 00000000..321c34c2 --- /dev/null +++ b/crates/atuin-server/src/handlers/v0/record.rs @@ -0,0 +1,112 @@ +use axum::{extract::Query, extract::State, http::StatusCode, Json}; +use metrics::counter; +use serde::Deserialize; +use tracing::{error, instrument}; + +use crate::{ + handlers::{ErrorResponse, ErrorResponseStatus, RespExt}, + router::{AppState, UserAuth}, +}; +use atuin_server_database::Database; + +use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus}; + +#[instrument(skip_all, fields(user.id = user.id))] +pub 
async fn post<DB: Database>( + UserAuth(user): UserAuth, + state: State<AppState<DB>>, + Json(records): Json<Vec<Record<EncryptedData>>>, +) -> Result<(), ErrorResponseStatus<'static>> { + let State(AppState { database, settings }) = state; + + tracing::debug!( + count = records.len(), + user = user.username, + "request to add records" + ); + + counter!("atuin_record_uploaded", records.len() as u64); + + let keep = records + .iter() + .all(|r| r.data.data.len() <= settings.max_record_size || settings.max_record_size == 0); + + if !keep { + counter!("atuin_record_too_large", 1); + + return Err( + ErrorResponse::reply("could not add records; record too large") + .with_status(StatusCode::BAD_REQUEST), + ); + } + + if let Err(e) = database.add_records(&user, &records).await { + error!("failed to add record: {}", e); + + return Err(ErrorResponse::reply("failed to add record") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + }; + + Ok(()) +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn index<DB: Database>( + UserAuth(user): UserAuth, + state: State<AppState<DB>>, +) -> Result<Json<RecordStatus>, ErrorResponseStatus<'static>> { + let State(AppState { + database, + settings: _, + }) = state; + + let record_index = match database.status(&user).await { + Ok(index) => index, + Err(e) => { + error!("failed to get record index: {}", e); + + return Err(ErrorResponse::reply("failed to calculate record index") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + }; + + Ok(Json(record_index)) +} + +#[derive(Deserialize)] +pub struct NextParams { + host: HostId, + tag: String, + start: Option<RecordIdx>, + count: u64, +} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn next<DB: Database>( + params: Query<NextParams>, + UserAuth(user): UserAuth, + state: State<AppState<DB>>, +) -> Result<Json<Vec<Record<EncryptedData>>>, ErrorResponseStatus<'static>> { + let State(AppState { + database, + settings: _, + }) = state; + let params = 
params.0; + + let records = match database + .next_records(&user, params.host, params.tag, params.start, params.count) + .await + { + Ok(records) => records, + Err(e) => { + error!("failed to get record index: {}", e); + + return Err(ErrorResponse::reply("failed to calculate record index") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + }; + + counter!("atuin_record_downloaded", records.len() as u64); + + Ok(Json(records)) +} diff --git a/crates/atuin-server/src/handlers/v0/store.rs b/crates/atuin-server/src/handlers/v0/store.rs new file mode 100644 index 00000000..941f2487 --- /dev/null +++ b/crates/atuin-server/src/handlers/v0/store.rs @@ -0,0 +1,37 @@ +use axum::{extract::Query, extract::State, http::StatusCode}; +use metrics::counter; +use serde::Deserialize; +use tracing::{error, instrument}; + +use crate::{ + handlers::{ErrorResponse, ErrorResponseStatus, RespExt}, + router::{AppState, UserAuth}, +}; +use atuin_server_database::Database; + +#[derive(Deserialize)] +pub struct DeleteParams {} + +#[instrument(skip_all, fields(user.id = user.id))] +pub async fn delete<DB: Database>( + _params: Query<DeleteParams>, + UserAuth(user): UserAuth, + state: State<AppState<DB>>, +) -> Result<(), ErrorResponseStatus<'static>> { + let State(AppState { + database, + settings: _, + }) = state; + + if let Err(e) = database.delete_store(&user).await { + counter!("atuin_store_delete_failed", 1); + error!("failed to delete store {e:?}"); + + return Err(ErrorResponse::reply("failed to delete store") + .with_status(StatusCode::INTERNAL_SERVER_ERROR)); + } + + counter!("atuin_store_deleted", 1); + + Ok(()) +} diff --git a/crates/atuin-server/src/lib.rs b/crates/atuin-server/src/lib.rs new file mode 100644 index 00000000..a0c104dc --- /dev/null +++ b/crates/atuin-server/src/lib.rs @@ -0,0 +1,144 @@ +#![forbid(unsafe_code)] + +use std::future::Future; +use std::net::SocketAddr; +use std::sync::Arc; + +use atuin_server_database::Database; +use axum::{serve, Router}; +use 
axum_server::Handle; +use eyre::{Context, Result}; + +mod handlers; +mod metrics; +mod router; +mod utils; + +use rustls::ServerConfig; +pub use settings::example_config; +pub use settings::Settings; + +pub mod settings; + +use tokio::net::TcpListener; +use tokio::signal; + +#[cfg(target_family = "unix")] +async fn shutdown_signal() { + let mut term = signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to register signal handler"); + let mut interrupt = signal::unix::signal(signal::unix::SignalKind::interrupt()) + .expect("failed to register signal handler"); + + tokio::select! { + _ = term.recv() => {}, + _ = interrupt.recv() => {}, + }; + eprintln!("Shutting down gracefully..."); +} + +#[cfg(target_family = "windows")] +async fn shutdown_signal() { + signal::windows::ctrl_c() + .expect("failed to register signal handler") + .recv() + .await; + eprintln!("Shutting down gracefully..."); +} + +pub async fn launch<Db: Database>( + settings: Settings<Db::Settings>, + addr: SocketAddr, +) -> Result<()> { + if settings.tls.enable { + launch_with_tls::<Db>(settings, addr, shutdown_signal()).await + } else { + launch_with_tcp_listener::<Db>( + settings, + TcpListener::bind(addr) + .await + .context("could not connect to socket")?, + shutdown_signal(), + ) + .await + } +} + +pub async fn launch_with_tcp_listener<Db: Database>( + settings: Settings<Db::Settings>, + listener: TcpListener, + shutdown: impl Future<Output = ()> + Send + 'static, +) -> Result<()> { + let r = make_router::<Db>(settings).await?; + + serve(listener, r.into_make_service()) + .with_graceful_shutdown(shutdown) + .await?; + + Ok(()) +} + +async fn launch_with_tls<Db: Database>( + settings: Settings<Db::Settings>, + addr: SocketAddr, + shutdown: impl Future<Output = ()>, +) -> Result<()> { + let certificates = settings.tls.certificates()?; + let pkey = settings.tls.private_key()?; + + let server_config = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + 
.with_single_cert(certificates, pkey)?; + + let server_config = Arc::new(server_config); + let rustls_config = axum_server::tls_rustls::RustlsConfig::from_config(server_config); + + let r = make_router::<Db>(settings).await?; + + let handle = Handle::new(); + + let server = axum_server::bind_rustls(addr, rustls_config) + .handle(handle.clone()) + .serve(r.into_make_service()); + + tokio::select! { + _ = server => {} + _ = shutdown => { + handle.graceful_shutdown(None); + } + } + + Ok(()) +} + +// The separate listener means it's much easier to ensure metrics are not accidentally exposed to +// the public. +pub async fn launch_metrics_server(host: String, port: u16) -> Result<()> { + let listener = TcpListener::bind((host, port)) + .await + .context("failed to bind metrics tcp")?; + + let recorder_handle = metrics::setup_metrics_recorder(); + + let router = Router::new().route( + "/metrics", + axum::routing::get(move || std::future::ready(recorder_handle.render())), + ); + + serve(listener, router.into_make_service()) + .with_graceful_shutdown(shutdown_signal()) + .await?; + + Ok(()) +} + +async fn make_router<Db: Database>( + settings: Settings<<Db as Database>::Settings>, +) -> Result<Router, eyre::Error> { + let db = Db::new(&settings.db_settings) + .await + .wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?; + let r = router::router(db, settings); + Ok(r) +} diff --git a/crates/atuin-server/src/metrics.rs b/crates/atuin-server/src/metrics.rs new file mode 100644 index 00000000..0a7ac6bd --- /dev/null +++ b/crates/atuin-server/src/metrics.rs @@ -0,0 +1,56 @@ +use std::time::Instant; + +use axum::{ + extract::{MatchedPath, Request}, + middleware::Next, + response::IntoResponse, +}; +use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; + +pub fn setup_metrics_recorder() -> PrometheusHandle { + const EXPONENTIAL_SECONDS: &[f64] = &[ + 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, + ]; + + 
PrometheusBuilder::new() + .set_buckets_for_metric( + Matcher::Full("http_requests_duration_seconds".to_string()), + EXPONENTIAL_SECONDS, + ) + .unwrap() + .install_recorder() + .unwrap() +} + +/// Middleware to record some common HTTP metrics +/// Generic over B to allow for arbitrary body types (eg Vec<u8>, Streams, a deserialized thing, etc) +/// Someday tower-http might provide a metrics middleware: https://github.com/tower-rs/tower-http/issues/57 +pub async fn track_metrics(req: Request, next: Next) -> impl IntoResponse { + let start = Instant::now(); + + let path = if let Some(matched_path) = req.extensions().get::<MatchedPath>() { + matched_path.as_str().to_owned() + } else { + req.uri().path().to_owned() + }; + + let method = req.method().clone(); + + // Run the rest of the request handling first, so we can measure it and get response + // codes. + let response = next.run(req).await; + + let latency = start.elapsed().as_secs_f64(); + let status = response.status().as_u16().to_string(); + + let labels = [ + ("method", method.to_string()), + ("path", path), + ("status", status), + ]; + + metrics::increment_counter!("http_requests_total", &labels); + metrics::histogram!("http_requests_duration_seconds", latency, &labels); + + response +} diff --git a/crates/atuin-server/src/router.rs b/crates/atuin-server/src/router.rs new file mode 100644 index 00000000..96dff2bd --- /dev/null +++ b/crates/atuin-server/src/router.rs @@ -0,0 +1,149 @@ +use async_trait::async_trait; +use atuin_common::api::{ErrorResponse, ATUIN_CARGO_VERSION, ATUIN_HEADER_VERSION}; +use axum::{ + extract::{FromRequestParts, Request}, + http::{self, request::Parts}, + middleware::Next, + response::{IntoResponse, Response}, + routing::{delete, get, patch, post}, + Router, +}; +use eyre::Result; +use tower::ServiceBuilder; +use tower_http::trace::TraceLayer; + +use super::handlers; +use crate::{ + handlers::{ErrorResponseStatus, RespExt}, + metrics, + settings::Settings, +}; +use 
atuin_server_database::{models::User, Database, DbError}; + +pub struct UserAuth(pub User); + +#[async_trait] +impl<DB: Send + Sync> FromRequestParts<AppState<DB>> for UserAuth +where + DB: Database, +{ + type Rejection = ErrorResponseStatus<'static>; + + async fn from_request_parts( + req: &mut Parts, + state: &AppState<DB>, + ) -> Result<Self, Self::Rejection> { + let auth_header = req + .headers + .get(http::header::AUTHORIZATION) + .ok_or_else(|| { + ErrorResponse::reply("missing authorization header") + .with_status(http::StatusCode::BAD_REQUEST) + })?; + let auth_header = auth_header.to_str().map_err(|_| { + ErrorResponse::reply("invalid authorization header encoding") + .with_status(http::StatusCode::BAD_REQUEST) + })?; + let (typ, token) = auth_header.split_once(' ').ok_or_else(|| { + ErrorResponse::reply("invalid authorization header encoding") + .with_status(http::StatusCode::BAD_REQUEST) + })?; + + if typ != "Token" { + return Err( + ErrorResponse::reply("invalid authorization header encoding") + .with_status(http::StatusCode::BAD_REQUEST), + ); + } + + let user = state + .database + .get_session_user(token) + .await + .map_err(|e| match e { + DbError::NotFound => ErrorResponse::reply("session not found") + .with_status(http::StatusCode::FORBIDDEN), + DbError::Other(e) => { + tracing::error!(error = ?e, "could not query user session"); + ErrorResponse::reply("could not query user session") + .with_status(http::StatusCode::INTERNAL_SERVER_ERROR) + } + })?; + + Ok(UserAuth(user)) + } +} + +async fn teapot() -> impl IntoResponse { + // This used to return 418: 🫖 + // Much as it was fun, it wasn't as useful or informative as it should be + (http::StatusCode::NOT_FOUND, "404 not found") +} + +async fn clacks_overhead(request: Request, next: Next) -> Response { + let mut response = next.run(request).await; + + let gnu_terry_value = "GNU Terry Pratchett, Kris Nova"; + let gnu_terry_header = "X-Clacks-Overhead"; + + response + .headers_mut() + 
.insert(gnu_terry_header, gnu_terry_value.parse().unwrap()); + response +} + +/// Ensure that we only try and sync with clients on the same major version +async fn semver(request: Request, next: Next) -> Response { + let mut response = next.run(request).await; + response + .headers_mut() + .insert(ATUIN_HEADER_VERSION, ATUIN_CARGO_VERSION.parse().unwrap()); + + response +} + +#[derive(Clone)] +pub struct AppState<DB: Database> { + pub database: DB, + pub settings: Settings<DB::Settings>, +} + +pub fn router<DB: Database>(database: DB, settings: Settings<DB::Settings>) -> Router { + let routes = Router::new() + .route("/", get(handlers::index)) + .route("/sync/count", get(handlers::history::count)) + .route("/sync/history", get(handlers::history::list)) + .route("/sync/calendar/:focus", get(handlers::history::calendar)) + .route("/sync/status", get(handlers::status::status)) + .route("/history", post(handlers::history::add)) + .route("/history", delete(handlers::history::delete)) + .route("/user/:username", get(handlers::user::get)) + .route("/account", delete(handlers::user::delete)) + .route("/account/password", patch(handlers::user::change_password)) + .route("/register", post(handlers::user::register)) + .route("/login", post(handlers::user::login)) + .route("/record", post(handlers::record::post::<DB>)) + .route("/record", get(handlers::record::index::<DB>)) + .route("/record/next", get(handlers::record::next)) + .route("/api/v0/me", get(handlers::v0::me::get)) + .route("/api/v0/record", post(handlers::v0::record::post)) + .route("/api/v0/record", get(handlers::v0::record::index)) + .route("/api/v0/record/next", get(handlers::v0::record::next)) + .route("/api/v0/store", delete(handlers::v0::store::delete)); + + let path = settings.path.as_str(); + if path.is_empty() { + routes + } else { + Router::new().nest(path, routes) + } + .fallback(teapot) + .with_state(AppState { database, settings }) + .layer( + ServiceBuilder::new() + 
.layer(axum::middleware::from_fn(clacks_overhead)) + .layer(TraceLayer::new_for_http()) + .layer(axum::middleware::from_fn(metrics::track_metrics)) + .layer(axum::middleware::from_fn(semver)), + ) +} diff --git a/crates/atuin-server/src/settings.rs b/crates/atuin-server/src/settings.rs new file mode 100644 index 00000000..2d00df36 --- /dev/null +++ b/crates/atuin-server/src/settings.rs @@ -0,0 +1,151 @@ +use std::{io::prelude::*, path::PathBuf}; + +use config::{Config, Environment, File as ConfigFile, FileFormat}; +use eyre::{bail, eyre, Context, Result}; +use fs_err::{create_dir_all, File}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; + +static EXAMPLE_CONFIG: &str = include_str!("../server.toml"); + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Metrics { + pub enable: bool, + pub host: String, + pub port: u16, +} + +impl Default for Metrics { + fn default() -> Self { + Self { + enable: false, + host: String::from("127.0.0.1"), + port: 9001, + } + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Settings<DbSettings> { + pub host: String, + pub port: u16, + pub path: String, + pub open_registration: bool, + pub max_history_length: usize, + pub max_record_size: usize, + pub page_size: i64, + pub register_webhook_url: Option<String>, + pub register_webhook_username: String, + pub metrics: Metrics, + pub tls: Tls, + + #[serde(flatten)] + pub db_settings: DbSettings, +} + +impl<DbSettings: DeserializeOwned> Settings<DbSettings> { + pub fn new() -> Result<Self> { + let mut config_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") { + PathBuf::from(p) + } else { + let mut config_file = PathBuf::new(); + let config_dir = atuin_common::utils::config_dir(); + config_file.push(config_dir); + config_file + }; + + config_file.push("server.toml"); + + // create the config file if it does not exist + let mut config_builder = Config::builder() + .set_default("host", "127.0.0.1")? + .set_default("port", 8888)? 
+ .set_default("open_registration", false)? + .set_default("max_history_length", 8192)? + .set_default("max_record_size", 1024 * 1024 * 1024)? // pretty chonky + .set_default("path", "")? + .set_default("register_webhook_username", "")? + .set_default("page_size", 1100)? + .set_default("metrics.enable", false)? + .set_default("metrics.host", "127.0.0.1")? + .set_default("metrics.port", 9001)? + .set_default("tls.enable", false)? + .set_default("tls.cert_path", "")? + .set_default("tls.pkey_path", "")? + .add_source( + Environment::with_prefix("atuin") + .prefix_separator("_") + .separator("__"), + ); + + config_builder = if config_file.exists() { + config_builder.add_source(ConfigFile::new( + config_file.to_str().unwrap(), + FileFormat::Toml, + )) + } else { + create_dir_all(config_file.parent().unwrap())?; + let mut file = File::create(config_file)?; + file.write_all(EXAMPLE_CONFIG.as_bytes())?; + + config_builder + }; + + let config = config_builder.build()?; + + config + .try_deserialize() + .map_err(|e| eyre!("failed to deserialize: {}", e)) + } +} + +pub fn example_config() -> &'static str { + EXAMPLE_CONFIG +} + +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +pub struct Tls { + pub enable: bool, + + pub cert_path: PathBuf, + pub pkey_path: PathBuf, +} + +impl Tls { + pub fn certificates(&self) -> Result<Vec<rustls::Certificate>> { + let cert_file = std::fs::File::open(&self.cert_path) + .with_context(|| format!("tls.cert_path {:?} is missing", self.cert_path))?; + let mut reader = std::io::BufReader::new(cert_file); + let certs: Vec<_> = rustls_pemfile::certs(&mut reader) + .map(|c| c.map(|c| rustls::Certificate(c.to_vec()))) + .collect::<Result<Vec<_>, _>>() + .with_context(|| format!("tls.cert_path {:?} is invalid", self.cert_path))?; + + if certs.is_empty() { + bail!( + "tls.cert_path {:?} must have at least one certificate", + self.cert_path + ); + } + + Ok(certs) + } + + pub fn private_key(&self) -> Result<rustls::PrivateKey> { + let pkey_file 
= std::fs::File::open(&self.pkey_path) + .with_context(|| format!("tls.pkey_path {:?} is missing", self.pkey_path))?; + let mut reader = std::io::BufReader::new(pkey_file); + let keys = rustls_pemfile::pkcs8_private_keys(&mut reader) + .map(|c| c.map(|c| rustls::PrivateKey(c.secret_pkcs8_der().to_vec()))) + .collect::<Result<Vec<_>, _>>() + .with_context(|| format!("tls.pkey_path {:?} is not PKCS8-encoded", self.pkey_path))?; + + if keys.is_empty() { + bail!( + "tls.pkey_path {:?} must have at least one private key", + self.pkey_path + ); + } + + Ok(keys[0].clone()) + } +} diff --git a/crates/atuin-server/src/utils.rs b/crates/atuin-server/src/utils.rs new file mode 100644 index 00000000..12e9ac1b --- /dev/null +++ b/crates/atuin-server/src/utils.rs @@ -0,0 +1,15 @@ +use eyre::Result; +use semver::{Version, VersionReq}; + +pub fn client_version_min(user_agent: &str, req: &str) -> Result<bool> { + if user_agent.is_empty() { + return Ok(false); + } + + let version = user_agent.replace("atuin/", ""); + + let req = VersionReq::parse(req)?; + let version = Version::parse(version.as_str())?; + + Ok(req.matches(&version)) +} diff --git a/crates/atuin/Cargo.toml b/crates/atuin/Cargo.toml new file mode 100644 index 00000000..9c112d73 --- /dev/null +++ b/crates/atuin/Cargo.toml @@ -0,0 +1,95 @@ +[package] +name = "atuin" +edition = "2021" +description = "atuin - magical shell history" +readme = "./README.md" + +rust-version = { workspace = true } +version = { workspace = true } +authors = { workspace = true } +license = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } + +[package.metadata.binstall] +pkg-url = "{ repo }/releases/download/v{ version }/{ name }-v{ version }-{ target }.tar.gz" +bin-dir = "{ name }-v{ version }-{ target }/{ bin }{ binary-ext }" +pkg-fmt = "tgz" + +[package.metadata.deb] +maintainer = "Ellie Huxtable <ellie@elliehuxtable.com>" +copyright = "2021, Ellie Huxtable <ellie@elliehuxtable.com>" +license-file = 
["LICENSE"] +depends = "$auto" +section = "utility" + +[package.metadata.rpm] +package = "atuin" + +[package.metadata.rpm.cargo] +buildflags = ["--release"] + +[package.metadata.rpm.targets] +atuin = { path = "/usr/bin/atuin" } + +[features] +default = ["client", "sync", "server", "clipboard", "check-update"] +client = ["atuin-client"] +sync = ["atuin-client/sync"] +server = ["atuin-server", "atuin-server-postgres", "tracing-subscriber"] +clipboard = ["cli-clipboard"] +check-update = ["atuin-client/check-update"] + +[dependencies] +atuin-server-postgres = { path = "../atuin-server-postgres", version = "18.2.0", optional = true } +atuin-server = { path = "../atuin-server", version = "18.2.0", optional = true } +atuin-client = { path = "../atuin-client", version = "18.2.0", optional = true, default-features = false } +atuin-common = { path = "../atuin-common", version = "18.2.0" } +atuin-dotfiles = { path = "../atuin-dotfiles", version = "0.2.0" } + +log = { workspace = true } +env_logger = "0.11.2" +time = { workspace = true } +eyre = { workspace = true } +directories = { workspace = true } +indicatif = "0.17.5" +serde = { workspace = true } +serde_json = { workspace = true } +crossterm = { version = "0.27", features = ["use-dev-tty"] } +unicode-width = "0.1" +itertools = { workspace = true } +tokio = { workspace = true } +async-trait = { workspace = true } +interim = { workspace = true } +base64 = { workspace = true } +clap = { workspace = true } +clap_complete = "4.5.1" +clap_complete_nushell = "4.5.1" +fs-err = { workspace = true } +whoami = { workspace = true } +rpassword = "7.0" +semver = { workspace = true } +rustix = { workspace = true } +runtime-format = "0.1.3" +tiny-bip39 = "1" +futures-util = "0.3" +fuzzy-matcher = "0.3.7" +colored = "2.0.4" +ratatui = "0.25" +tracing = "0.1" +uuid = { workspace = true } +unicode-segmentation = "1.11.0" +serde_yaml = "0.9.32" +sysinfo = "0.30.7" + +[target.'cfg(any(target_os = "windows", target_os = "macos", target_os = 
"linux"))'.dependencies] +cli-clipboard = { version = "0.4.0", optional = true } + +[dependencies.tracing-subscriber] +version = "0.3" +default-features = false +features = ["ansi", "fmt", "registry", "env-filter"] +optional = true + +[dev-dependencies] +tracing-tree = "0.3" diff --git a/crates/atuin/LICENSE b/crates/atuin/LICENSE new file mode 100644 index 00000000..7dfc9b58 --- /dev/null +++ b/crates/atuin/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Ellie Huxtable + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/crates/atuin/README.md b/crates/atuin/README.md new file mode 120000 index 00000000..32d46ee8 --- /dev/null +++ b/crates/atuin/README.md @@ -0,0 +1 @@ +../README.md
\ No newline at end of file diff --git a/crates/atuin/build.rs b/crates/atuin/build.rs new file mode 100644 index 00000000..f24cf1bf --- /dev/null +++ b/crates/atuin/build.rs @@ -0,0 +1,11 @@ +use std::process::Command; +fn main() { + let output = Command::new("git").args(["rev-parse", "HEAD"]).output(); + + let sha = match output { + Ok(sha) => String::from_utf8(sha.stdout).unwrap(), + Err(_) => String::from("NO_GIT"), + }; + + println!("cargo:rustc-env=GIT_HASH={}", sha); +} diff --git a/crates/atuin/src/command/CONTRIBUTORS b/crates/atuin/src/command/CONTRIBUTORS new file mode 120000 index 00000000..1ca4115a --- /dev/null +++ b/crates/atuin/src/command/CONTRIBUTORS @@ -0,0 +1 @@ +../../../../CONTRIBUTORS
\ No newline at end of file diff --git a/crates/atuin/src/command/client.rs b/crates/atuin/src/command/client.rs new file mode 100644 index 00000000..23040695 --- /dev/null +++ b/crates/atuin/src/command/client.rs @@ -0,0 +1,144 @@ +use std::path::PathBuf; + +use clap::Subcommand; +use eyre::{Result, WrapErr}; + +use atuin_client::{database::Sqlite, record::sqlite_store::SqliteStore, settings::Settings}; +use env_logger::Builder; + +#[cfg(feature = "sync")] +mod sync; + +#[cfg(feature = "sync")] +mod account; + +mod default_config; +mod doctor; +mod dotfiles; +mod history; +mod import; +mod info; +mod init; +mod kv; +mod search; +mod stats; +mod store; + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Manipulate shell history + #[command(subcommand)] + History(history::Cmd), + + /// Import shell history from file + #[command(subcommand)] + Import(import::Cmd), + + /// Calculate statistics for your history + Stats(stats::Cmd), + + /// Interactive history search + Search(search::Cmd), + + #[cfg(feature = "sync")] + #[command(flatten)] + Sync(sync::Cmd), + + /// Manage your sync account + #[cfg(feature = "sync")] + Account(account::Cmd), + + /// Get or set small key-value pairs + #[command(subcommand)] + Kv(kv::Cmd), + + /// Manage the atuin data store + #[command(subcommand)] + Store(store::Cmd), + + /// Manage your dotfiles with Atuin + #[command(subcommand)] + Dotfiles(dotfiles::Cmd), + + /// Print Atuin's shell init script + #[command()] + Init(init::Cmd), + + /// Information about dotfiles locations and ENV vars + #[command()] + Info, + + /// Run the doctor to check for common issues + #[command()] + Doctor, + + /// Print example configuration + #[command()] + DefaultConfig, +} + +impl Cmd { + pub fn run(self) -> Result<()> { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + let settings = Settings::new().wrap_err("could not load client settings")?; + let res = 
runtime.block_on(self.run_inner(settings)); + + runtime.shutdown_timeout(std::time::Duration::from_millis(50)); + + res + } + + async fn run_inner(self, mut settings: Settings) -> Result<()> { + Builder::new() + .filter_level(log::LevelFilter::Off) + .filter_module("sqlx_sqlite::regexp", log::LevelFilter::Off) + .parse_env("ATUIN_LOG") + .init(); + + tracing::trace!(command = ?self, "client command"); + + let db_path = PathBuf::from(settings.db_path.as_str()); + let record_store_path = PathBuf::from(settings.record_store_path.as_str()); + + let db = Sqlite::new(db_path, settings.local_timeout).await?; + let sqlite_store = SqliteStore::new(record_store_path, settings.local_timeout).await?; + + match self { + Self::History(history) => history.run(&settings, &db, sqlite_store).await, + Self::Import(import) => import.run(&db).await, + Self::Stats(stats) => stats.run(&db, &settings).await, + Self::Search(search) => search.run(db, &mut settings, sqlite_store).await, + + #[cfg(feature = "sync")] + Self::Sync(sync) => sync.run(settings, &db, sqlite_store).await, + + #[cfg(feature = "sync")] + Self::Account(account) => account.run(settings, sqlite_store).await, + + Self::Kv(kv) => kv.run(&settings, &sqlite_store).await, + + Self::Store(store) => store.run(&settings, &db, sqlite_store).await, + + Self::Dotfiles(dotfiles) => dotfiles.run(&settings, sqlite_store).await, + + Self::Init(init) => init.run(&settings).await, + + Self::Info => { + info::run(&settings); + Ok(()) + } + + Self::Doctor => doctor::run(&settings), + + Self::DefaultConfig => { + default_config::run(); + Ok(()) + } + } + } +} diff --git a/crates/atuin/src/command/client/account.rs b/crates/atuin/src/command/client/account.rs new file mode 100644 index 00000000..e31e6208 --- /dev/null +++ b/crates/atuin/src/command/client/account.rs @@ -0,0 +1,47 @@ +use clap::{Args, Subcommand}; +use eyre::Result; + +use atuin_client::record::sqlite_store::SqliteStore; +use atuin_client::settings::Settings; + +pub mod 
change_password; +pub mod delete; +pub mod login; +pub mod logout; +pub mod register; + +#[derive(Args, Debug)] +pub struct Cmd { + #[command(subcommand)] + command: Commands, +} + +#[derive(Subcommand, Debug)] +pub enum Commands { + /// Login to the configured server + Login(login::Cmd), + + // Register a new account + Register(register::Cmd), + + /// Log out + Logout, + + /// Delete your account, and all synced data + Delete, + + /// Change your password + ChangePassword(change_password::Cmd), +} + +impl Cmd { + pub async fn run(self, settings: Settings, store: SqliteStore) -> Result<()> { + match self.command { + Commands::Login(l) => l.run(&settings, &store).await, + Commands::Register(r) => r.run(&settings).await, + Commands::Logout => logout::run(&settings), + Commands::Delete => delete::run(&settings).await, + Commands::ChangePassword(c) => c.run(&settings).await, + } + } +} diff --git a/crates/atuin/src/command/client/account/change_password.rs b/crates/atuin/src/command/client/account/change_password.rs new file mode 100644 index 00000000..3b5ad6f5 --- /dev/null +++ b/crates/atuin/src/command/client/account/change_password.rs @@ -0,0 +1,57 @@ +use clap::Parser; +use eyre::{bail, Result}; + +use atuin_client::{api_client, settings::Settings}; +use rpassword::prompt_password; + +#[derive(Parser, Debug)] +pub struct Cmd { + #[clap(long, short)] + pub current_password: Option<String>, + + #[clap(long, short)] + pub new_password: Option<String>, +} + +impl Cmd { + pub async fn run(self, settings: &Settings) -> Result<()> { + run(settings, &self.current_password, &self.new_password).await + } +} + +pub async fn run( + settings: &Settings, + current_password: &Option<String>, + new_password: &Option<String>, +) -> Result<()> { + let client = api_client::Client::new( + &settings.sync_address, + &settings.session_token, + settings.network_connect_timeout, + settings.network_timeout, + )?; + + let current_password = current_password.clone().unwrap_or_else(|| { + 
prompt_password("Please enter the current password: ").expect("Failed to read from input") + }); + + if current_password.is_empty() { + bail!("please provide the current password"); + } + + let new_password = new_password.clone().unwrap_or_else(|| { + prompt_password("Please enter the new password: ").expect("Failed to read from input") + }); + + if new_password.is_empty() { + bail!("please provide a new password"); + } + + client + .change_password(current_password, new_password) + .await?; + + println!("Account password successfully changed!"); + + Ok(()) +} diff --git a/crates/atuin/src/command/client/account/delete.rs b/crates/atuin/src/command/client/account/delete.rs new file mode 100644 index 00000000..3591c6f3 --- /dev/null +++ b/crates/atuin/src/command/client/account/delete.rs @@ -0,0 +1,30 @@ +use atuin_client::{api_client, settings::Settings}; +use eyre::{bail, Result}; +use std::fs::remove_file; +use std::path::PathBuf; + +pub async fn run(settings: &Settings) -> Result<()> { + let session_path = settings.session_path.as_str(); + + if !PathBuf::from(session_path).exists() { + bail!("You are not logged in"); + } + + let client = api_client::Client::new( + &settings.sync_address, + &settings.session_token, + settings.network_connect_timeout, + settings.network_timeout, + )?; + + client.delete().await?; + + // Fixes stale session+key when account is deleted via CLI. 
+ if PathBuf::from(session_path).exists() { + remove_file(PathBuf::from(session_path))?; + } + + println!("Your account is deleted"); + + Ok(()) +} diff --git a/crates/atuin/src/command/client/account/login.rs b/crates/atuin/src/command/client/account/login.rs new file mode 100644 index 00000000..9cd53399 --- /dev/null +++ b/crates/atuin/src/command/client/account/login.rs @@ -0,0 +1,177 @@ +use std::{io, path::PathBuf}; + +use clap::Parser; +use eyre::{bail, Context, Result}; +use tokio::{fs::File, io::AsyncWriteExt}; + +use atuin_client::{ + api_client, + encryption::{decode_key, encode_key, load_key, new_key, Key}, + record::sqlite_store::SqliteStore, + record::store::Store, + settings::Settings, +}; +use atuin_common::api::LoginRequest; +use rpassword::prompt_password; + +#[derive(Parser, Debug)] +pub struct Cmd { + #[clap(long, short)] + pub username: Option<String>, + + #[clap(long, short)] + pub password: Option<String>, + + /// The encryption key for your account + #[clap(long, short)] + pub key: Option<String>, +} + +fn get_input() -> Result<String> { + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + Ok(input.trim_end_matches(&['\r', '\n'][..]).to_string()) +} + +impl Cmd { + pub async fn run(&self, settings: &Settings, store: &SqliteStore) -> Result<()> { + let session_path = settings.session_path.as_str(); + + if PathBuf::from(session_path).exists() { + println!( + "You are already logged in! 
Please run 'atuin logout' if you wish to login again" + ); + + return Ok(()); + } + + let username = or_user_input(&self.username, "username"); + let password = self.password.clone().unwrap_or_else(read_user_password); + + let key_path = settings.key_path.as_str(); + let key_path = PathBuf::from(key_path); + + let key = or_user_input(&self.key, "encryption key [blank to use existing key file]"); + + // if provided, the key may be EITHER base64, or a bip mnemonic + // try to normalize on base64 + let key = if key.is_empty() { + key + } else { + // try parse the key as a mnemonic... + match bip39::Mnemonic::from_phrase(&key, bip39::Language::English) { + Ok(mnemonic) => encode_key(Key::from_slice(mnemonic.entropy()))?, + Err(err) => { + if let Some(err) = err.downcast_ref::<bip39::ErrorKind>() { + match err { + // assume they copied in the base64 key + bip39::ErrorKind::InvalidWord => key, + bip39::ErrorKind::InvalidChecksum => { + bail!("key mnemonic was not valid") + } + bip39::ErrorKind::InvalidKeysize(_) + | bip39::ErrorKind::InvalidWordLength(_) + | bip39::ErrorKind::InvalidEntropyLength(_, _) => { + bail!("key was not the correct length") + } + } + } else { + // unknown error. 
assume they copied the base64 key + key + } + } + } + }; + + // I've simplified this a little, but it could really do with a refactor + // Annoyingly, it's also very important to get it correct + if key.is_empty() { + if key_path.exists() { + let bytes = fs_err::read_to_string(key_path) + .context("existing key file couldn't be read")?; + if decode_key(bytes).is_err() { + bail!("the key in existing key file was invalid"); + } + } else { + println!("No key file exists, creating a new"); + let _key = new_key(settings)?; + } + } else if !key_path.exists() { + if decode_key(key.clone()).is_err() { + bail!("the specified key was invalid"); + } + + let mut file = File::create(key_path).await?; + file.write_all(key.as_bytes()).await?; + } else { + // we now know that the user has logged in specifying a key, AND that the key path + // exists + + // 1. check if the saved key and the provided key match. if so, nothing to do. + // 2. if not, re-encrypt the local history and overwrite the key + let current_key: [u8; 32] = load_key(settings)?.into(); + + let encoded = key.clone(); // gonna want to save it in a bit + let new_key: [u8; 32] = decode_key(key) + .context("could not decode provided key - is not valid base64")? 
+ .into(); + + if new_key != current_key { + println!("\nRe-encrypting local store with new key"); + + store.re_encrypt(¤t_key, &new_key).await?; + + println!("Writing new key"); + let mut file = File::create(key_path).await?; + file.write_all(encoded.as_bytes()).await?; + } + } + + let session = api_client::login( + settings.sync_address.as_str(), + LoginRequest { username, password }, + ) + .await?; + + let session_path = settings.session_path.as_str(); + let mut file = File::create(session_path).await?; + file.write_all(session.session.as_bytes()).await?; + + println!("Logged in!"); + + Ok(()) + } +} + +pub(super) fn or_user_input(value: &'_ Option<String>, name: &'static str) -> String { + value.clone().unwrap_or_else(|| read_user_input(name)) +} + +pub(super) fn read_user_password() -> String { + let password = prompt_password("Please enter password: "); + password.expect("Failed to read from input") +} + +fn read_user_input(name: &'static str) -> String { + eprint!("Please enter {name}: "); + get_input().expect("Failed to read from input") +} + +#[cfg(test)] +mod tests { + use atuin_client::encryption::Key; + + #[test] + fn mnemonic_round_trip() { + let key = Key::from([ + 3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 9, 3, 2, 3, 8, 4, 6, 2, 6, 4, 3, 3, 8, 3, 2, + 7, 9, 5, + ]); + let phrase = bip39::Mnemonic::from_entropy(&key, bip39::Language::English) + .unwrap() + .into_phrase(); + let mnemonic = bip39::Mnemonic::from_phrase(&phrase, bip39::Language::English).unwrap(); + assert_eq!(mnemonic.entropy(), key.as_slice()); + assert_eq!(phrase, "adapt amused able anxiety mother adapt beef gaze amount else seat alcohol cage lottery avoid scare alcohol cactus school avoid coral adjust catch pink"); + } +} diff --git a/crates/atuin/src/command/client/account/logout.rs b/crates/atuin/src/command/client/account/logout.rs new file mode 100644 index 00000000..90b49d6d --- /dev/null +++ b/crates/atuin/src/command/client/account/logout.rs @@ -0,0 +1,19 @@ +use 
std::path::PathBuf; + +use eyre::{Context, Result}; +use fs_err::remove_file; + +use atuin_client::settings::Settings; + +pub fn run(settings: &Settings) -> Result<()> { + let session_path = settings.session_path.as_str(); + + if PathBuf::from(session_path).exists() { + remove_file(session_path).context("Failed to remove session file")?; + println!("You have logged out!"); + } else { + println!("You are not logged in"); + } + + Ok(()) +} diff --git a/crates/atuin/src/command/client/account/register.rs b/crates/atuin/src/command/client/account/register.rs new file mode 100644 index 00000000..96b7d7d6 --- /dev/null +++ b/crates/atuin/src/command/client/account/register.rs @@ -0,0 +1,55 @@ +use clap::Parser; +use eyre::{bail, Result}; +use tokio::{fs::File, io::AsyncWriteExt}; + +use atuin_client::{api_client, settings::Settings}; + +#[derive(Parser, Debug)] +pub struct Cmd { + #[clap(long, short)] + pub username: Option<String>, + + #[clap(long, short)] + pub password: Option<String>, + + #[clap(long, short)] + pub email: Option<String>, +} + +impl Cmd { + pub async fn run(self, settings: &Settings) -> Result<()> { + run(settings, &self.username, &self.email, &self.password).await + } +} + +pub async fn run( + settings: &Settings, + username: &Option<String>, + email: &Option<String>, + password: &Option<String>, +) -> Result<()> { + use super::login::or_user_input; + println!("Registering for an Atuin Sync account"); + + let username = or_user_input(username, "username"); + let email = or_user_input(email, "email"); + + let password = password + .clone() + .unwrap_or_else(super::login::read_user_password); + + if password.is_empty() { + bail!("please provide a password"); + } + + let session = + api_client::register(settings.sync_address.as_str(), &username, &email, &password).await?; + + let path = settings.session_path.as_str(); + let mut file = File::create(path).await?; + file.write_all(session.session.as_bytes()).await?; + + let _key = 
atuin_client::encryption::load_key(settings)?; + + Ok(()) +} diff --git a/crates/atuin/src/command/client/default_config.rs b/crates/atuin/src/command/client/default_config.rs new file mode 100644 index 00000000..f51e45c2 --- /dev/null +++ b/crates/atuin/src/command/client/default_config.rs @@ -0,0 +1,5 @@ +use atuin_client::settings::Settings; + +pub fn run() { + println!("{}", Settings::example_config()); +} diff --git a/crates/atuin/src/command/client/doctor.rs b/crates/atuin/src/command/client/doctor.rs new file mode 100644 index 00000000..48659ed1 --- /dev/null +++ b/crates/atuin/src/command/client/doctor.rs @@ -0,0 +1,346 @@ +use std::process::Command; +use std::{env, path::PathBuf, str::FromStr}; + +use atuin_client::settings::Settings; +use atuin_common::shell::{shell_name, Shell}; +use colored::Colorize; +use eyre::Result; +use serde::{Deserialize, Serialize}; + +use sysinfo::{get_current_pid, Disks, System}; + +#[derive(Debug, Serialize, Deserialize)] +struct ShellInfo { + pub name: String, + + // best-effort, not supported on all OSes + pub default: String, + + // Detect some shell plugins that the user has installed. + // I'm just going to start with preexec/blesh + pub plugins: Vec<String>, + + // The preexec framework used in the current session, if Atuin is loaded. + pub preexec: Option<String>, +} + +impl ShellInfo { + // HACK ALERT! + // Many of the shell vars we need to detect are not exported :( + // So, we're going to run a interactive session and directly check the + // variable. There's a chance this won't work, so it should not be fatal. 
/// Work out which preexec framework Atuin is using in the current shell session.
///
/// Returns `None` when `ATUIN_SESSION` is not set. For shells other than Bash (names
/// starting with `bash`, or exactly `sh`) the answer is always `"built-in"`. For Bash,
/// `ATUIN_PREEXEC_BACKEND` is parsed as `<shlvl>:<backend>` (split at the last colon);
/// the backend name is returned only when `<shlvl>` equals the current `SHLVL`, so a
/// value left behind by a different shell level is ignored.
fn detect_preexec_framework(shell: &str) -> Option<String> {
    if env::var("ATUIN_SESSION").is_err() {
        return None;
    }

    if !(shell.starts_with("bash") || shell == "sh") {
        return Some("built-in".to_string());
    }

    let backend = env::var("ATUIN_PREEXEC_BACKEND")
        .ok()
        .filter(|value| !value.is_empty())?;

    // Everything before the last colon is the shell level, the rest is the backend name.
    let (recorded_shlvl, framework) = backend.rsplit_once(':')?;
    let recorded_shlvl = u32::from_str(recorded_shlvl).ok()?;

    let current_shlvl = env::var("SHLVL")
        .ok()
        .and_then(|shlvl| u32::from_str(&shlvl).ok())?;

    (current_shlvl == recorded_shlvl).then(|| framework.to_string())
}
Option<String>; + + let plugin_list: [( + &str, + PluginShellType, + PluginProbeType, + Option<PluginValidator>, + ); 3] = [ + ( + "atuin", + PluginShellType::Any, + PluginProbeType::EnvironmentVariable("ATUIN_SESSION"), + None, + ), + ( + "blesh", + PluginShellType::Bash, + PluginProbeType::EnvironmentVariable("BLE_SESSION_ID"), + Some(Self::validate_plugin_blesh), + ), + ( + "bash-preexec", + PluginShellType::Bash, + PluginProbeType::InteractiveShellVariable("bash_preexec_imported"), + None, + ), + ]; + + plugin_list + .into_iter() + .filter(|(_, shell_type, _, _)| match shell_type { + PluginShellType::Any => true, + PluginShellType::Bash => shell.starts_with("bash") || shell == "sh", + PluginShellType::Zsh => shell.starts_with("zsh"), + PluginShellType::Fish => shell.starts_with("fish"), + PluginShellType::Nushell => shell.starts_with("nu"), + PluginShellType::Xonsh => shell.starts_with("xonsh"), + }) + .filter_map(|(plugin, _, probe_type, validator)| -> Option<String> { + match probe_type { + PluginProbeType::EnvironmentVariable(env) => { + env::var(env).ok().filter(|value| !value.is_empty()) + } + PluginProbeType::InteractiveShellVariable(shellvar) => { + ShellInfo::shellvar_exists(shell, shellvar).then_some(String::default()) + } + } + .and_then(|value| { + validator.map_or_else( + || Some(plugin.to_string()), + |validator| validator(shell, shell_process, &value), + ) + }) + }) + .collect() + } + + pub fn new() -> Self { + // TODO: rework to use atuin_common::Shell + + let sys = System::new_all(); + + let process = sys + .process(get_current_pid().expect("Failed to get current PID")) + .expect("Process with current pid does not exist"); + + let parent = sys + .process(process.parent().expect("Atuin running with no parent!")) + .expect("Process with parent pid does not exist"); + + let name = shell_name(Some(parent)); + + let plugins = ShellInfo::plugins(name.as_str(), parent); + + let default = Shell::default_shell().unwrap_or(Shell::Unknown).to_string(); + + 
let preexec = Self::detect_preexec_framework(name.as_str()); + + Self { + name, + default, + plugins, + preexec, + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +struct DiskInfo { + pub name: String, + pub filesystem: String, +} + +#[derive(Debug, Serialize, Deserialize)] +struct SystemInfo { + pub os: String, + + pub arch: String, + + pub version: String, + pub disks: Vec<DiskInfo>, +} + +impl SystemInfo { + pub fn new() -> Self { + let disks = Disks::new_with_refreshed_list(); + let disks = disks + .list() + .iter() + .map(|d| DiskInfo { + name: d.name().to_os_string().into_string().unwrap(), + filesystem: d.file_system().to_os_string().into_string().unwrap(), + }) + .collect(); + + Self { + os: System::name().unwrap_or_else(|| "unknown".to_string()), + arch: System::cpu_arch().unwrap_or_else(|| "unknown".to_string()), + version: System::os_version().unwrap_or_else(|| "unknown".to_string()), + disks, + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +struct SyncInfo { + /// Whether the main Atuin sync server is in use + /// I'm just calling it Atuin Cloud for lack of a better name atm + pub cloud: bool, + pub records: bool, + pub auto_sync: bool, + + pub last_sync: String, +} + +impl SyncInfo { + pub fn new(settings: &Settings) -> Self { + Self { + cloud: settings.sync_address == "https://api.atuin.sh", + auto_sync: settings.auto_sync, + records: settings.sync.records, + last_sync: Settings::last_sync().map_or("no last sync".to_string(), |v| v.to_string()), + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +struct AtuinInfo { + pub version: String, + + /// Whether the main Atuin sync server is in use + /// I'm just calling it Atuin Cloud for lack of a better name atm + pub sync: Option<SyncInfo>, +} + +impl AtuinInfo { + pub fn new(settings: &Settings) -> Self { + let session_path = settings.session_path.as_str(); + let logged_in = PathBuf::from(session_path).exists(); + + let sync = if logged_in { + Some(SyncInfo::new(settings)) + } else { + 
None + }; + + Self { + version: crate::VERSION.to_string(), + sync, + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +struct DoctorDump { + pub atuin: AtuinInfo, + pub shell: ShellInfo, + pub system: SystemInfo, +} + +impl DoctorDump { + pub fn new(settings: &Settings) -> Self { + Self { + atuin: AtuinInfo::new(settings), + shell: ShellInfo::new(), + system: SystemInfo::new(), + } + } +} + +fn checks(info: &DoctorDump) { + println!(); // spacing + // + let zfs_error = "[Filesystem] ZFS is known to have some issues with SQLite. Atuin uses SQLite heavily. If you are having poor performance, there are some workarounds here: https://github.com/atuinsh/atuin/issues/952".bold().red(); + let bash_plugin_error = "[Shell] If you are using Bash, Atuin requires that either bash-preexec or ble.sh be installed. An older ble.sh may not be detected. so ignore this if you have it set up! Read more here: https://docs.atuin.sh/guide/installation/#bash".bold().red(); + let blesh_loading_order_error = "[Shell] Atuin seems to be loaded before ble.sh is sourced. 
In .bashrc, make sure to initialize Atuin after sourcing ble.sh.".bold().red(); + + // ZFS: https://github.com/atuinsh/atuin/issues/952 + if info.system.disks.iter().any(|d| d.filesystem == "zfs") { + println!("{zfs_error}"); + } + + // Shell + if info.shell.name == "bash" { + if !info + .shell + .plugins + .iter() + .any(|p| p == "blesh" || p == "bash-preexec") + { + println!("{bash_plugin_error}"); + } + + if info.shell.plugins.iter().any(|plugin| plugin == "atuin") + && info.shell.plugins.iter().any(|plugin| plugin == "blesh") + && info.shell.preexec.as_ref().is_some_and(|val| val == "none") + { + println!("{blesh_loading_order_error}"); + } + } +} + +pub fn run(settings: &Settings) -> Result<()> { + println!("{}", "Atuin Doctor".bold()); + println!("Checking for diagnostics"); + let dump = DoctorDump::new(settings); + + checks(&dump); + + let dump = serde_yaml::to_string(&dump)?; + + println!("\nPlease include the output below with any bug reports or issues\n"); + println!("{dump}"); + + Ok(()) +} diff --git a/crates/atuin/src/command/client/dotfiles.rs b/crates/atuin/src/command/client/dotfiles.rs new file mode 100644 index 00000000..291c794d --- /dev/null +++ b/crates/atuin/src/command/client/dotfiles.rs @@ -0,0 +1,22 @@ +use clap::Subcommand; +use eyre::Result; + +use atuin_client::{record::sqlite_store::SqliteStore, settings::Settings}; + +mod alias; + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Manage shell aliases with Atuin + #[command(subcommand)] + Alias(alias::Cmd), +} + +impl Cmd { + pub async fn run(self, settings: &Settings, store: SqliteStore) -> Result<()> { + match self { + Self::Alias(cmd) => cmd.run(settings, store).await, + } + } +} diff --git a/crates/atuin/src/command/client/dotfiles/alias.rs b/crates/atuin/src/command/client/dotfiles/alias.rs new file mode 100644 index 00000000..6456a8b0 --- /dev/null +++ b/crates/atuin/src/command/client/dotfiles/alias.rs @@ -0,0 +1,95 @@ +use 
clap::Subcommand; +use eyre::{Context, Result}; + +use atuin_client::{encryption, record::sqlite_store::SqliteStore, settings::Settings}; + +use atuin_dotfiles::{shell::Alias, store::AliasStore}; + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Set an alias + Set { name: String, value: String }, + + /// Delete an alias + Delete { name: String }, + + /// List all aliases + List, + + /// Import aliases set in the current shell + Import, +} + +impl Cmd { + async fn set(&self, store: AliasStore, name: String, value: String) -> Result<()> { + let aliases = store.aliases().await?; + let found: Vec<Alias> = aliases.into_iter().filter(|a| a.name == name).collect(); + + if found.is_empty() { + println!("Aliasing '{name}={value}'."); + } else { + println!( + "Overwriting alias '{name}={}' with '{name}={value}'.", + found[0].value + ); + } + + store.set(&name, &value).await?; + + Ok(()) + } + + async fn list(&self, store: AliasStore) -> Result<()> { + let aliases = store.aliases().await?; + + for i in aliases { + println!("{}={}", i.name, i.value); + } + + Ok(()) + } + + async fn delete(&self, store: AliasStore, name: String) -> Result<()> { + let mut aliases = store.aliases().await?.into_iter(); + if let Some(alias) = aliases.find(|alias| alias.name == name) { + println!("Deleting '{name}={}'.", alias.value); + store.delete(&name).await?; + } else { + eprintln!("Cannot delete '{name}': Alias not set."); + }; + Ok(()) + } + + async fn import(&self, store: AliasStore) -> Result<()> { + let aliases = atuin_dotfiles::shell::import_aliases(store).await?; + + for i in aliases { + println!("Importing {}={}", i.name, i.value); + } + + Ok(()) + } + + pub async fn run(&self, settings: &Settings, store: SqliteStore) -> Result<()> { + if !settings.dotfiles.enabled { + eprintln!("Dotfiles are not enabled. 
Add\n\n[dotfiles]\nenabled = true\n\nto your configuration file to enable them.\n"); + eprintln!("The default configuration file is located at ~/.config/atuin/config.toml."); + return Ok(()); + } + + let encryption_key: [u8; 32] = encryption::load_key(settings) + .context("could not load encryption key")? + .into(); + let host_id = Settings::host_id().expect("failed to get host_id"); + + let alias_store = AliasStore::new(store, host_id, encryption_key); + + match self { + Self::Set { name, value } => self.set(alias_store, name.clone(), value.clone()).await, + Self::Delete { name } => self.delete(alias_store, name.clone()).await, + Self::List => self.list(alias_store).await, + Self::Import => self.import(alias_store).await, + } + } +} diff --git a/crates/atuin/src/command/client/history.rs b/crates/atuin/src/command/client/history.rs new file mode 100644 index 00000000..e6774816 --- /dev/null +++ b/crates/atuin/src/command/client/history.rs @@ -0,0 +1,556 @@ +use std::{ + fmt::{self, Display}, + io::{self, IsTerminal, Write}, + time::Duration, +}; + +use atuin_common::utils::{self, Escapable as _}; +use clap::Subcommand; +use eyre::{Context, Result}; +use runtime_format::{FormatKey, FormatKeyError, ParseSegment, ParsedFmt}; + +use atuin_client::{ + database::{current_context, Database}, + encryption, + history::{store::HistoryStore, History}, + record::sqlite_store::SqliteStore, + settings::{ + FilterMode::{Directory, Global, Session}, + Settings, Timezone, + }, +}; + +#[cfg(feature = "sync")] +use atuin_client::{record, sync}; + +use log::{debug, warn}; +use time::{macros::format_description, OffsetDateTime}; + +use super::search::format_duration_into; + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Begins a new command in the history + Start { + command: Vec<String>, + }, + + /// Finishes a new command in the history (adds time, exit code) + End { + id: String, + #[arg(long, short)] + exit: i64, + #[arg(long, short)] + 
duration: Option<u64>, + }, + + /// List all items in history + List { + #[arg(long, short)] + cwd: bool, + + #[arg(long, short)] + session: bool, + + #[arg(long)] + human: bool, + + /// Show only the text of the command + #[arg(long)] + cmd_only: bool, + + /// Terminate the output with a null, for better multiline support + #[arg(long)] + print0: bool, + + #[arg(long, short, default_value = "true")] + // accept no value + #[arg(num_args(0..=1), default_missing_value("true"))] + // accept a value + #[arg(action = clap::ArgAction::Set)] + reverse: bool, + + /// Display the command time in another timezone other than the configured default. + /// + /// This option takes one of the following kinds of values: + /// - the special value "local" (or "l") which refers to the system time zone + /// - an offset from UTC (e.g. "+9", "-2:30") + #[arg(long, visible_alias = "tz")] + timezone: Option<Timezone>, + + /// Available variables: {command}, {directory}, {duration}, {user}, {host}, {exit} and {time}. + /// Example: --format "{time} - [{duration}] - {directory}$\t{command}" + #[arg(long, short)] + format: Option<String>, + }, + + /// Get the last command ran + Last { + #[arg(long)] + human: bool, + + /// Show only the text of the command + #[arg(long)] + cmd_only: bool, + + /// Display the command time in another timezone other than the configured default. + /// + /// This option takes one of the following kinds of values: + /// - the special value "local" (or "l") which refers to the system time zone + /// - an offset from UTC (e.g. "+9", "-2:30") + #[arg(long, visible_alias = "tz")] + timezone: Option<Timezone>, + + /// Available variables: {command}, {directory}, {duration}, {user}, {host} and {time}. 
/// Layout used when printing history entries.
#[derive(Clone, Copy, Debug)]
pub enum ListMode {
    /// Human-oriented layout.
    Human,
    /// Only the command text.
    CmdOnly,
    /// Default layout.
    Regular,
}

impl ListMode {
    /// Pick the mode from the `--human` and `--cmd-only` flags.
    ///
    /// `human` takes precedence over `cmd_only`; with neither set the mode is `Regular`.
    pub const fn from_flags(human: bool, cmd_only: bool) -> Self {
        match (human, cmd_only) {
            (true, _) => Self::Human,
            (false, true) => Self::CmdOnly,
            (false, false) => Self::Regular,
        }
    }
}
/// Exit the process with status 1 when writing history output failed.
///
/// A broken pipe is not treated as a failure: the function simply returns, as it does
/// for a successful write.
fn check_for_write_errors(write: Result<(), io::Error>) {
    match write {
        Ok(()) => {}
        // Ignore broken pipe (issue #626)
        Err(err) if err.kind() == io::ErrorKind::BrokenPipe => {}
        Err(err) => {
            eprintln!("ERROR: History output failed with the following error: {err}");
            std::process::exit(1);
        }
    }
}
+ .fmt(f)?; + } + "relativetime" => { + let since = OffsetDateTime::now_utc() - self.history.timestamp; + let d = Duration::try_from(since).unwrap_or_default(); + format_duration_into(d, f)?; + } + "host" => f.write_str( + self.history + .hostname + .split_once(':') + .map_or(&self.history.hostname, |(host, _)| host), + )?, + "user" => f.write_str( + self.history + .hostname + .split_once(':') + .map_or("", |(_, user)| user), + )?, + _ => return Err(FormatKeyError::UnknownKey), + } + Ok(()) + } +} + +fn parse_fmt(format: &str) -> ParsedFmt { + match ParsedFmt::new(format) { + Ok(fmt) => fmt, + Err(err) => { + eprintln!("ERROR: History formatting failed with the following error: {err}"); + println!("If your formatting string contains curly braces (eg: {{var}}) you need to escape them this way: {{{{var}}."); + std::process::exit(1) + } + } +} + +impl Cmd { + #[allow(clippy::too_many_lines, clippy::cast_possible_truncation)] + async fn handle_start( + db: &impl Database, + settings: &Settings, + command: &[String], + ) -> Result<()> { + let command = command.join(" "); + + // It's better for atuin to silently fail here and attempt to + // store whatever is ran, than to throw an error to the terminal + let cwd = utils::get_current_dir(); + + let h: History = History::capture() + .timestamp(OffsetDateTime::now_utc()) + .command(command) + .cwd(cwd) + .build() + .into(); + + if !h.should_save(settings) { + return Ok(()); + } + + // print the ID + // we use this as the key for calling end + println!("{}", h.id); + db.save(&h).await?; + + Ok(()) + } + + #[allow(unused_variables)] + async fn handle_end( + db: &impl Database, + store: SqliteStore, + history_store: HistoryStore, + settings: &Settings, + id: &str, + exit: i64, + duration: Option<u64>, + ) -> Result<()> { + if id.trim() == "" { + return Ok(()); + } + + let Some(mut h) = db.load(id).await? 
else { + warn!("history entry is missing"); + return Ok(()); + }; + + if h.duration > 0 { + debug!("cannot end history - already has duration"); + + // returning OK as this can occur if someone Ctrl-c a prompt + return Ok(()); + } + + if !settings.store_failed && h.exit != 0 { + debug!("history has non-zero exit code, and store_failed is false"); + + // the history has already been inserted half complete. remove it + db.delete(h).await?; + + return Ok(()); + } + + h.exit = exit; + h.duration = match duration { + Some(value) => i64::try_from(value).context("command took over 292 years")?, + None => i64::try_from((OffsetDateTime::now_utc() - h.timestamp).whole_nanoseconds()) + .context("command took over 292 years")?, + }; + + db.update(&h).await?; + history_store.push(h).await?; + + if settings.should_sync()? { + #[cfg(feature = "sync")] + { + if settings.sync.records { + let (_, downloaded) = record::sync::sync(settings, &store).await?; + Settings::save_sync_time()?; + + crate::sync::build(settings, &store, db, Some(&downloaded)).await?; + } else { + debug!("running periodic background sync"); + sync::sync(settings, false, db).await?; + } + } + #[cfg(not(feature = "sync"))] + debug!("not compiled with sync support"); + } else { + debug!("sync disabled! 
not syncing"); + } + + Ok(()) + } + + #[allow(clippy::too_many_arguments)] + #[allow(clippy::fn_params_excessive_bools)] + async fn handle_list( + db: &impl Database, + settings: &Settings, + context: atuin_client::database::Context, + session: bool, + cwd: bool, + mode: ListMode, + format: Option<String>, + include_deleted: bool, + print0: bool, + reverse: bool, + tz: Timezone, + ) -> Result<()> { + let filters = match (session, cwd) { + (true, true) => [Session, Directory], + (true, false) => [Session, Global], + (false, true) => [Global, Directory], + (false, false) => [settings.filter_mode, Global], + }; + + let history = db + .list(&filters, &context, None, false, include_deleted) + .await?; + + print_list( + &history, + mode, + match format { + None => Some(settings.history_format.as_str()), + _ => format.as_deref(), + }, + print0, + reverse, + tz, + ); + + Ok(()) + } + + async fn handle_prune( + db: &impl Database, + settings: &Settings, + store: SqliteStore, + context: atuin_client::database::Context, + dry_run: bool, + ) -> Result<()> { + // Grab all executed commands and filter them using History::should_save. + // We could iterate or paginate here if memory usage becomes an issue. + let matches: Vec<History> = db + .list(&[Global], &context, None, false, false) + .await? + .into_iter() + .filter(|h| !h.should_save(settings)) + .collect(); + + match matches.len() { + 0 => { + println!("No entries to prune."); + return Ok(()); + } + 1 => println!("Found 1 entry to prune."), + n => println!("Found {n} entries to prune."), + } + + if dry_run { + print_list( + &matches, + ListMode::Human, + Some(settings.history_format.as_str()), + false, + false, + settings.timezone, + ); + } else { + let encryption_key: [u8; 32] = encryption::load_key(settings) + .context("could not load encryption key")? 
+ .into(); + let host_id = Settings::host_id().expect("failed to get host_id"); + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + + for entry in matches { + eprintln!("deleting {}", entry.id); + if settings.sync.records { + let (id, _) = history_store.delete(entry.id.clone()).await?; + history_store.incremental_build(db, &[id]).await?; + } else { + db.delete(entry.clone()).await?; + } + } + } + Ok(()) + } + + pub async fn run( + self, + settings: &Settings, + db: &impl Database, + store: SqliteStore, + ) -> Result<()> { + let context = current_context(); + + let encryption_key: [u8; 32] = encryption::load_key(settings) + .context("could not load encryption key")? + .into(); + + let host_id = Settings::host_id().expect("failed to get host_id"); + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + + match self { + Self::Start { command } => Self::handle_start(db, settings, &command).await, + Self::End { id, exit, duration } => { + Self::handle_end(db, store, history_store, settings, &id, exit, duration).await + } + Self::List { + session, + cwd, + human, + cmd_only, + print0, + reverse, + timezone, + format, + } => { + let mode = ListMode::from_flags(human, cmd_only); + let tz = timezone.unwrap_or(settings.timezone); + Self::handle_list( + db, settings, context, session, cwd, mode, format, false, print0, reverse, tz, + ) + .await + } + + Self::Last { + human, + cmd_only, + timezone, + format, + } => { + let last = db.last().await?; + let last = last.as_ref().map(std::slice::from_ref).unwrap_or_default(); + let tz = timezone.unwrap_or(settings.timezone); + print_list( + last, + ListMode::from_flags(human, cmd_only), + match format { + None => Some(settings.history_format.as_str()), + _ => format.as_deref(), + }, + false, + true, + tz, + ); + + Ok(()) + } + + Self::InitStore => history_store.init_store(db).await, + + Self::Prune { dry_run } => { + Self::handle_prune(db, settings, store, context, 
dry_run).await + } + } + } +} diff --git a/crates/atuin/src/command/client/import.rs b/crates/atuin/src/command/client/import.rs new file mode 100644 index 00000000..35595b9b --- /dev/null +++ b/crates/atuin/src/command/client/import.rs @@ -0,0 +1,168 @@ +use std::env; + +use async_trait::async_trait; +use clap::Parser; +use eyre::Result; +use indicatif::ProgressBar; + +use atuin_client::{ + database::Database, + history::History, + import::{ + bash::Bash, fish::Fish, nu::Nu, nu_histdb::NuHistDb, resh::Resh, xonsh::Xonsh, + xonsh_sqlite::XonshSqlite, zsh::Zsh, zsh_histdb::ZshHistDb, Importer, Loader, + }, +}; + +#[derive(Parser, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Import history for the current shell + Auto, + + /// Import history from the zsh history file + Zsh, + /// Import history from the zsh history file + ZshHistDb, + /// Import history from the bash history file + Bash, + /// Import history from the resh history file + Resh, + /// Import history from the fish history file + Fish, + /// Import history from the nu history file + Nu, + /// Import history from the nu history file + NuHistDb, + /// Import history from xonsh json files + Xonsh, + /// Import history from xonsh sqlite db + XonshSqlite, +} + +const BATCH_SIZE: usize = 100; + +impl Cmd { + pub async fn run<DB: Database>(&self, db: &DB) -> Result<()> { + println!(" Atuin "); + println!("======================"); + println!(" \u{1f30d} "); + println!(" \u{1f418}\u{1f418}\u{1f418}\u{1f418} "); + println!(" \u{1f422} "); + println!("======================"); + println!("Importing history..."); + + match self { + Self::Auto => { + if cfg!(windows) { + println!("This feature does not work on windows. Please run atuin import <SHELL>. 
To view a list of shells, run atuin import."); + return Ok(()); + } + + // $XONSH_HISTORY_BACKEND isn't always set, but $XONSH_HISTORY_FILE is + let xonsh_histfile = + env::var("XONSH_HISTORY_FILE").unwrap_or_else(|_| String::new()); + let shell = env::var("SHELL").unwrap_or_else(|_| String::from("NO_SHELL")); + + if xonsh_histfile.to_lowercase().ends_with(".json") { + println!("Detected Xonsh",); + import::<Xonsh, DB>(db).await + } else if xonsh_histfile.to_lowercase().ends_with(".sqlite") { + println!("Detected Xonsh (SQLite backend)"); + import::<XonshSqlite, DB>(db).await + } else if shell.ends_with("/zsh") { + if ZshHistDb::histpath().is_ok() { + println!( + "Detected Zsh-HistDb, using :{}", + ZshHistDb::histpath().unwrap().to_str().unwrap() + ); + import::<ZshHistDb, DB>(db).await + } else { + println!("Detected ZSH"); + import::<Zsh, DB>(db).await + } + } else if shell.ends_with("/fish") { + println!("Detected Fish"); + import::<Fish, DB>(db).await + } else if shell.ends_with("/bash") { + println!("Detected Bash"); + import::<Bash, DB>(db).await + } else if shell.ends_with("/nu") { + if NuHistDb::histpath().is_ok() { + println!( + "Detected Nu-HistDb, using :{}", + NuHistDb::histpath().unwrap().to_str().unwrap() + ); + import::<NuHistDb, DB>(db).await + } else { + println!("Detected Nushell"); + import::<Nu, DB>(db).await + } + } else { + println!("cannot import {shell} history"); + Ok(()) + } + } + + Self::Zsh => import::<Zsh, DB>(db).await, + Self::ZshHistDb => import::<ZshHistDb, DB>(db).await, + Self::Bash => import::<Bash, DB>(db).await, + Self::Resh => import::<Resh, DB>(db).await, + Self::Fish => import::<Fish, DB>(db).await, + Self::Nu => import::<Nu, DB>(db).await, + Self::NuHistDb => import::<NuHistDb, DB>(db).await, + Self::Xonsh => import::<Xonsh, DB>(db).await, + Self::XonshSqlite => import::<XonshSqlite, DB>(db).await, + } + } +} + +pub struct HistoryImporter<'db, DB: Database> { + pb: ProgressBar, + buf: Vec<History>, + db: &'db DB, +} + 
+impl<'db, DB: Database> HistoryImporter<'db, DB> { + fn new(db: &'db DB, len: usize) -> Self { + Self { + pb: ProgressBar::new(len as u64), + buf: Vec::with_capacity(BATCH_SIZE), + db, + } + } + + async fn flush(self) -> Result<()> { + if !self.buf.is_empty() { + self.db.save_bulk(&self.buf).await?; + } + self.pb.finish(); + Ok(()) + } +} + +#[async_trait] +impl<'db, DB: Database> Loader for HistoryImporter<'db, DB> { + async fn push(&mut self, hist: History) -> Result<()> { + self.pb.inc(1); + self.buf.push(hist); + if self.buf.len() == self.buf.capacity() { + self.db.save_bulk(&self.buf).await?; + self.buf.clear(); + } + Ok(()) + } +} + +async fn import<I: Importer + Send, DB: Database>(db: &DB) -> Result<()> { + println!("Importing history from {}", I::NAME); + + let mut importer = I::new().await?; + let len = importer.entries().await.unwrap(); + let mut loader = HistoryImporter::new(db, len); + importer.load(&mut loader).await?; + loader.flush().await?; + + println!("Import complete!"); + Ok(()) +} diff --git a/crates/atuin/src/command/client/info.rs b/crates/atuin/src/command/client/info.rs new file mode 100644 index 00000000..60ba1fe6 --- /dev/null +++ b/crates/atuin/src/command/client/info.rs @@ -0,0 +1,31 @@ +use atuin_client::settings::Settings;
+
+use crate::VERSION;
+
+/// Prints diagnostic information about this atuin installation to stdout.
+///
+/// The report consists of three sections separated by a blank line:
+///
+/// 1. The client (`config.toml`) and server (`server.toml`) config file
+///    locations, both resolved inside the directory returned by
+///    `atuin_common::utils::config_dir()`, followed by the `db_path`,
+///    `key_path` and `session_path` fields of `settings`.
+/// 2. The value of the `ATUIN_CONFIG_DIR` environment variable, falling back
+///    to the text `None` when the variable cannot be read.
+/// 3. The crate `VERSION`.
+pub fn run(settings: &Settings) {
+    let config_dir = atuin_common::utils::config_dir();
+    let client_config = config_dir.join("config.toml");
+    let server_config = config_dir.join("server.toml");
+
+    // Values are rendered with `{:?}` (Debug formatting) throughout.
+    let sections = [
+        format!(
+            "Config files:\nclient config: {:?}\nserver config: {:?}\nclient db path: {:?}\nkey path: {:?}\nsession path: {:?}",
+            client_config.to_string_lossy(),
+            server_config.to_string_lossy(),
+            settings.db_path,
+            settings.key_path,
+            settings.session_path
+        ),
+        format!(
+            "Env Vars:\nATUIN_CONFIG_DIR = {:?}",
+            std::env::var("ATUIN_CONFIG_DIR").unwrap_or_else(|_| String::from("None"))
+        ),
+        format!("Version info:\nversion: {VERSION}"),
+    ];
+
+    // Emit the whole report with a single `println!` call.
+    println!("{}", sections.join("\n\n"));
+}
diff --git a/crates/atuin/src/command/client/init.rs b/crates/atuin/src/command/client/init.rs new file mode 100644 index 00000000..bfda75be --- /dev/null +++ b/crates/atuin/src/command/client/init.rs @@ -0,0 +1,145 @@ +use std::path::PathBuf; + +use atuin_client::{encryption, record::sqlite_store::SqliteStore, settings::Settings}; +use atuin_dotfiles::store::AliasStore; +use clap::{Parser, ValueEnum}; +use eyre::{Result, WrapErr}; + +mod bash; +mod fish; +mod xonsh; +mod zsh; + +#[derive(Parser, Debug)] +pub struct Cmd { + shell: Shell, + + /// Disable the binding of CTRL-R to atuin + #[clap(long)] + disable_ctrl_r: bool, + + /// Disable the binding of the Up Arrow key to atuin + #[clap(long)] + disable_up_arrow: bool, +} + +#[derive(Clone, Copy, ValueEnum, Debug)] +pub enum Shell { + /// Zsh setup + Zsh, + /// Bash setup + Bash, + /// Fish setup + Fish, + /// Nu setup + Nu, + /// Xonsh setup + Xonsh, +} + +impl Cmd { + fn init_nu(&self) { + let full = include_str!("../../shell/atuin.nu"); + println!("{full}"); + + if std::env::var("ATUIN_NOBIND").is_err() { + const BIND_CTRL_R: &str = r"$env.config = ( + $env.config | upsert keybindings ( + $env.config.keybindings + | append { + name: atuin + modifier: control + keycode: char_r + mode: [emacs, vi_normal, vi_insert] + event: { send: executehostcommand cmd: (_atuin_search_cmd) } + } + ) +)"; + const BIND_UP_ARROW: &str = r" +$env.config = ( + $env.config | upsert keybindings ( + $env.config.keybindings + | append { + name: atuin + modifier: none + keycode: up + mode: [emacs, vi_normal, vi_insert] + event: { + until: [ + {send: menuup} + {send: executehostcommand cmd: (_atuin_search_cmd '--shell-up-key-binding') } + ] + } + } + ) +) +"; + if !self.disable_ctrl_r { + println!("{BIND_CTRL_R}"); + } + if !self.disable_up_arrow { + println!("{BIND_UP_ARROW}"); + } + } + } + + fn static_init(&self) { + match self.shell { + Shell::Zsh => { + zsh::init_static(self.disable_up_arrow, self.disable_ctrl_r); + } + Shell::Bash 
=> { + bash::init_static(self.disable_up_arrow, self.disable_ctrl_r); + } + Shell::Fish => { + fish::init_static(self.disable_up_arrow, self.disable_ctrl_r); + } + Shell::Nu => { + self.init_nu(); + } + Shell::Xonsh => { + xonsh::init_static(self.disable_up_arrow, self.disable_ctrl_r); + } + }; + } + + async fn dotfiles_init(&self, settings: &Settings) -> Result<()> { + let record_store_path = PathBuf::from(settings.record_store_path.as_str()); + let sqlite_store = SqliteStore::new(record_store_path, settings.local_timeout).await?; + + let encryption_key: [u8; 32] = encryption::load_key(settings) + .context("could not load encryption key")? + .into(); + let host_id = Settings::host_id().expect("failed to get host_id"); + + let alias_store = AliasStore::new(sqlite_store, host_id, encryption_key); + + match self.shell { + Shell::Zsh => { + zsh::init(alias_store, self.disable_up_arrow, self.disable_ctrl_r).await?; + } + Shell::Bash => { + bash::init(alias_store, self.disable_up_arrow, self.disable_ctrl_r).await?; + } + Shell::Fish => { + fish::init(alias_store, self.disable_up_arrow, self.disable_ctrl_r).await?; + } + Shell::Nu => self.init_nu(), + Shell::Xonsh => { + xonsh::init(alias_store, self.disable_up_arrow, self.disable_ctrl_r).await?; + } + } + + Ok(()) + } + + pub async fn run(self, settings: &Settings) -> Result<()> { + if settings.dotfiles.enabled { + self.dotfiles_init(settings).await?; + } else { + self.static_init(); + } + + Ok(()) + } +} diff --git a/crates/atuin/src/command/client/init/bash.rs b/crates/atuin/src/command/client/init/bash.rs new file mode 100644 index 00000000..6e7f14e7 --- /dev/null +++ b/crates/atuin/src/command/client/init/bash.rs @@ -0,0 +1,26 @@ +use atuin_dotfiles::store::AliasStore; +use eyre::Result; + +pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool) { + let base = include_str!("../../../shell/atuin.bash"); + + let (bind_ctrl_r, bind_up_arrow) = if std::env::var("ATUIN_NOBIND").is_ok() { + (false, false) + } 
else { + (!disable_ctrl_r, !disable_up_arrow) + }; + + println!("__atuin_bind_ctrl_r={bind_ctrl_r}"); + println!("__atuin_bind_up_arrow={bind_up_arrow}"); + println!("{base}"); +} + +pub async fn init(store: AliasStore, disable_up_arrow: bool, disable_ctrl_r: bool) -> Result<()> { + init_static(disable_up_arrow, disable_ctrl_r); + + let aliases = atuin_dotfiles::shell::bash::config(&store).await; + + println!("{aliases}"); + + Ok(()) +} diff --git a/crates/atuin/src/command/client/init/fish.rs b/crates/atuin/src/command/client/init/fish.rs new file mode 100644 index 00000000..4ec74952 --- /dev/null +++ b/crates/atuin/src/command/client/init/fish.rs @@ -0,0 +1,45 @@ +use atuin_dotfiles::store::AliasStore; +use eyre::Result; + +pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool) { + let base = include_str!("../../../shell/atuin.fish"); + + println!("{base}"); + + if std::env::var("ATUIN_NOBIND").is_err() { + const BIND_CTRL_R: &str = r"bind \cr _atuin_search"; + const BIND_UP_ARROW: &str = r"bind -k up _atuin_bind_up +bind \eOA _atuin_bind_up +bind \e\[A _atuin_bind_up"; + const BIND_CTRL_R_INS: &str = r"bind -M insert \cr _atuin_search"; + const BIND_UP_ARROW_INS: &str = r"bind -M insert -k up _atuin_bind_up +bind -M insert \eOA _atuin_bind_up +bind -M insert \e\[A _atuin_bind_up"; + + if !disable_ctrl_r { + println!("{BIND_CTRL_R}"); + } + if !disable_up_arrow { + println!("{BIND_UP_ARROW}"); + } + + println!("if bind -M insert > /dev/null 2>&1"); + if !disable_ctrl_r { + println!("{BIND_CTRL_R_INS}"); + } + if !disable_up_arrow { + println!("{BIND_UP_ARROW_INS}"); + } + println!("end"); + } +} + +pub async fn init(store: AliasStore, disable_up_arrow: bool, disable_ctrl_r: bool) -> Result<()> { + init_static(disable_up_arrow, disable_ctrl_r); + + let aliases = atuin_dotfiles::shell::fish::config(&store).await; + + println!("{aliases}"); + + Ok(()) +} diff --git a/crates/atuin/src/command/client/init/xonsh.rs 
b/crates/atuin/src/command/client/init/xonsh.rs new file mode 100644 index 00000000..cfe64f7e --- /dev/null +++ b/crates/atuin/src/command/client/init/xonsh.rs @@ -0,0 +1,31 @@ +use atuin_dotfiles::store::AliasStore; +use eyre::Result; + +pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool) { + let base = include_str!("../../../shell/atuin.xsh"); + + let (bind_ctrl_r, bind_up_arrow) = if std::env::var("ATUIN_NOBIND").is_ok() { + (false, false) + } else { + (!disable_ctrl_r, !disable_up_arrow) + }; + println!( + "_ATUIN_BIND_CTRL_R={}", + if bind_ctrl_r { "True" } else { "False" } + ); + println!( + "_ATUIN_BIND_UP_ARROW={}", + if bind_up_arrow { "True" } else { "False" } + ); + println!("{base}"); +} + +pub async fn init(store: AliasStore, disable_up_arrow: bool, disable_ctrl_r: bool) -> Result<()> { + init_static(disable_up_arrow, disable_ctrl_r); + + let aliases = atuin_dotfiles::shell::xonsh::config(&store).await; + + println!("{aliases}"); + + Ok(()) +} diff --git a/crates/atuin/src/command/client/init/zsh.rs b/crates/atuin/src/command/client/init/zsh.rs new file mode 100644 index 00000000..2341e203 --- /dev/null +++ b/crates/atuin/src/command/client/init/zsh.rs @@ -0,0 +1,39 @@ +use atuin_dotfiles::store::AliasStore; +use eyre::Result; + +pub fn init_static(disable_up_arrow: bool, disable_ctrl_r: bool) { + let base = include_str!("../../../shell/atuin.zsh"); + + println!("{base}"); + + if std::env::var("ATUIN_NOBIND").is_err() { + const BIND_CTRL_R: &str = r"bindkey -M emacs '^r' atuin-search +bindkey -M viins '^r' atuin-search-viins +bindkey -M vicmd '/' atuin-search"; + + const BIND_UP_ARROW: &str = r"bindkey -M emacs '^[[A' atuin-up-search +bindkey -M vicmd '^[[A' atuin-up-search-vicmd +bindkey -M viins '^[[A' atuin-up-search-viins +bindkey -M emacs '^[OA' atuin-up-search +bindkey -M vicmd '^[OA' atuin-up-search-vicmd +bindkey -M viins '^[OA' atuin-up-search-viins +bindkey -M vicmd 'k' atuin-up-search-vicmd"; + + if !disable_ctrl_r { + 
println!("{BIND_CTRL_R}"); + } + if !disable_up_arrow { + println!("{BIND_UP_ARROW}"); + } + } +} + +pub async fn init(store: AliasStore, disable_up_arrow: bool, disable_ctrl_r: bool) -> Result<()> { + init_static(disable_up_arrow, disable_ctrl_r); + + let aliases = atuin_dotfiles::shell::zsh::config(&store).await; + + println!("{aliases}"); + + Ok(()) +} diff --git a/crates/atuin/src/command/client/kv.rs b/crates/atuin/src/command/client/kv.rs new file mode 100644 index 00000000..b97f31b7 --- /dev/null +++ b/crates/atuin/src/command/client/kv.rs @@ -0,0 +1,96 @@ +use clap::Subcommand; +use eyre::{Context, Result}; + +use atuin_client::{encryption, kv::KvStore, record::store::Store, settings::Settings}; + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + // atuin kv set foo bar bar + Set { + #[arg(long, short)] + key: String, + + #[arg(long, short, default_value = "default")] + namespace: String, + + value: String, + }, + + // atuin kv get foo => bar baz + Get { + key: String, + + #[arg(long, short, default_value = "default")] + namespace: String, + }, + + List { + #[arg(long, short, default_value = "default")] + namespace: String, + + #[arg(long, short)] + all_namespaces: bool, + }, +} + +impl Cmd { + pub async fn run(&self, settings: &Settings, store: &(impl Store + Send + Sync)) -> Result<()> { + let kv_store = KvStore::new(); + + let encryption_key: [u8; 32] = encryption::load_key(settings) + .context("could not load encryption key")? 
+ .into(); + + let host_id = Settings::host_id().expect("failed to get host_id"); + + match self { + Self::Set { + key, + value, + namespace, + } => { + kv_store + .set(store, &encryption_key, host_id, namespace, key, value) + .await + } + + Self::Get { key, namespace } => { + let val = kv_store.get(store, &encryption_key, namespace, key).await?; + + if let Some(kv) = val { + println!("{}", kv.value); + } + + Ok(()) + } + + Self::List { + namespace, + all_namespaces, + } => { + // TODO: don't rebuild this every time lol + let map = kv_store.build_kv(store, &encryption_key).await?; + + // slower, but sorting is probably useful + if *all_namespaces { + for (ns, kv) in &map { + for k in kv.keys() { + println!("{ns}.{k}"); + } + } + } else { + let ns = map.get(namespace); + + if let Some(ns) = ns { + for k in ns.keys() { + println!("{k}"); + } + } + } + + Ok(()) + } + } + } +} diff --git a/crates/atuin/src/command/client/search.rs b/crates/atuin/src/command/client/search.rs new file mode 100644 index 00000000..f645d26b --- /dev/null +++ b/crates/atuin/src/command/client/search.rs @@ -0,0 +1,307 @@ +use std::io::{stderr, IsTerminal as _}; + +use atuin_common::utils::{self, Escapable as _}; +use clap::Parser; +use eyre::Result; + +use atuin_client::{ + database::Database, + database::{current_context, OptFilters}, + encryption, + history::{store::HistoryStore, History}, + record::sqlite_store::SqliteStore, + settings::{FilterMode, KeymapMode, SearchMode, Settings, Timezone}, +}; + +use super::history::ListMode; + +mod cursor; +mod duration; +mod engines; +mod history_list; +mod inspector; +mod interactive; +mod sort; + +pub use duration::format_duration_into; + +#[allow(clippy::struct_excessive_bools, clippy::struct_field_names)] +#[derive(Parser, Debug)] +pub struct Cmd { + /// Filter search result by directory + #[arg(long, short)] + cwd: Option<String>, + + /// Exclude directory from results + #[arg(long = "exclude-cwd")] + exclude_cwd: Option<String>, + + /// Filter 
search result by exit code + #[arg(long, short)] + exit: Option<i64>, + + /// Exclude results with this exit code + #[arg(long = "exclude-exit")] + exclude_exit: Option<i64>, + + /// Only include results added before this date + #[arg(long, short)] + before: Option<String>, + + /// Only include results after this date + #[arg(long)] + after: Option<String>, + + /// How many entries to return at most + #[arg(long)] + limit: Option<i64>, + + /// Offset from the start of the results + #[arg(long)] + offset: Option<i64>, + + /// Open interactive search UI + #[arg(long, short)] + interactive: bool, + + /// Allow overriding filter mode over config + #[arg(long = "filter-mode")] + filter_mode: Option<FilterMode>, + + /// Allow overriding search mode over config + #[arg(long = "search-mode")] + search_mode: Option<SearchMode>, + + /// Marker argument used to inform atuin that it was invoked from a shell up-key binding (hidden from help to avoid confusion) + #[arg(long = "shell-up-key-binding", hide = true)] + shell_up_key_binding: bool, + + /// Notify the keymap at the shell's side + #[arg(long = "keymap-mode", default_value = "auto")] + keymap_mode: KeymapMode, + + /// Use human-readable formatting for time + #[arg(long)] + human: bool, + + query: Option<Vec<String>>, + + /// Show only the text of the command + #[arg(long)] + cmd_only: bool, + + /// Delete anything matching this query. Will not print out the match + #[arg(long)] + delete: bool, + + /// Delete EVERYTHING! + #[arg(long)] + delete_it_all: bool, + + /// Reverse the order of results, oldest first + #[arg(long, short)] + reverse: bool, + + /// Display the command time in another timezone other than the configured default. + /// + /// This option takes one of the following kinds of values: + /// - the special value "local" (or "l") which refers to the system time zone + /// - an offset from UTC (e.g. 
"+9", "-2:30") + #[arg(long, visible_alias = "tz")] + timezone: Option<Timezone>, + + /// Available variables: {command}, {directory}, {duration}, {user}, {host}, {time}, {exit} and + /// {relativetime}. + /// Example: --format "{time} - [{duration}] - {directory}$\t{command}" + #[arg(long, short)] + format: Option<String>, + + /// Set the maximum number of lines Atuin's interface should take up. + #[arg(long = "inline-height")] + inline_height: Option<u16>, +} + +impl Cmd { + // clippy: please write this instead + // clippy: now it has too many lines + // me: I'll do it later OKAY + #[allow(clippy::too_many_lines)] + pub async fn run( + self, + db: impl Database, + settings: &mut Settings, + store: SqliteStore, + ) -> Result<()> { + let query = self.query.map_or_else( + || { + std::env::var("ATUIN_QUERY").map_or_else( + |_| vec![], + |query| { + query + .split(' ') + .map(std::string::ToString::to_string) + .collect() + }, + ) + }, + |query| query, + ); + + if (self.delete_it_all || self.delete) && self.limit.is_some() { + // Because of how deletion is implemented, it will always delete all matches + // and disregard the limit option. It is also not clear what deletion with a + // limit would even mean. Deleting the LIMIT most recent entries that match + // the search query would make sense, but that wouldn't match what's displayed + // when running the equivalent search, but deleting those entries that are + // displayed with the search would leave any duplicates of those lines which may + // or may not have been intended to be deleted. + println!("\"--limit\" is not compatible with deletion."); + return Ok(()); + } + + if self.delete && query.is_empty() { + println!("Please specify a query to match the items you wish to delete. If you wish to delete all history, pass --delete-it-all"); + return Ok(()); + } + + if self.delete_it_all && !query.is_empty() { + println!( + "--delete-it-all will delete ALL of your history! It does not require a query." 
+ ); + return Ok(()); + } + + if self.search_mode.is_some() { + settings.search_mode = self.search_mode.unwrap(); + } + if self.filter_mode.is_some() { + settings.filter_mode = self.filter_mode.unwrap(); + } + if self.inline_height.is_some() { + settings.inline_height = self.inline_height.unwrap(); + } + + settings.shell_up_key_binding = self.shell_up_key_binding; + + // `keymap_mode` specified in config.toml overrides the `--keymap-mode` + // option specified in the keybindings. + settings.keymap_mode = match settings.keymap_mode { + KeymapMode::Auto => self.keymap_mode, + value => value, + }; + settings.keymap_mode_shell = self.keymap_mode; + + let encryption_key: [u8; 32] = encryption::load_key(settings)?.into(); + + let host_id = Settings::host_id().expect("failed to get host_id"); + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + + if self.interactive { + let item = interactive::history(&query, settings, db, &history_store).await?; + if stderr().is_terminal() { + eprintln!("{}", item.escape_control()); + } else { + eprintln!("{item}"); + } + } else { + let opt_filter = OptFilters { + exit: self.exit, + exclude_exit: self.exclude_exit, + cwd: self.cwd, + exclude_cwd: self.exclude_cwd, + before: self.before, + after: self.after, + limit: self.limit, + offset: self.offset, + reverse: self.reverse, + }; + + let mut entries = + run_non_interactive(settings, opt_filter.clone(), &query, &db).await?; + + if entries.is_empty() { + std::process::exit(1) + } + + // if we aren't deleting, print it all + if self.delete || self.delete_it_all { + // delete it + // it only took me _years_ to add this + // sorry + while !entries.is_empty() { + for entry in &entries { + eprintln!("deleting {}", entry.id); + + if settings.sync.records { + let (id, _) = history_store.delete(entry.id.clone()).await?; + history_store.incremental_build(&db, &[id]).await?; + } else { + db.delete(entry.clone()).await?; + } + } + + entries = + 
run_non_interactive(settings, opt_filter.clone(), &query, &db).await?; + } + } else { + let format = match self.format { + None => Some(settings.history_format.as_str()), + _ => self.format.as_deref(), + }; + let tz = self.timezone.unwrap_or(settings.timezone); + + super::history::print_list( + &entries, + ListMode::from_flags(self.human, self.cmd_only), + format, + false, + true, + tz, + ); + } + }; + Ok(()) + } +} + +// This is supposed to more-or-less mirror the command line version, so ofc +// it is going to have a lot of args +#[allow(clippy::too_many_arguments, clippy::cast_possible_truncation)] +async fn run_non_interactive( + settings: &Settings, + filter_options: OptFilters, + query: &[String], + db: &impl Database, +) -> Result<Vec<History>> { + let dir = if filter_options.cwd.as_deref() == Some(".") { + Some(utils::get_current_dir()) + } else { + filter_options.cwd + }; + + let context = current_context(); + + let opt_filter = OptFilters { + cwd: dir.clone(), + ..filter_options + }; + + let dir = dir.unwrap_or_else(|| "/".to_string()); + let filter_mode = if settings.workspaces && utils::has_git_dir(dir.as_str()) { + FilterMode::Workspace + } else { + settings.filter_mode + }; + + let results = db + .search( + settings.search_mode, + filter_mode, + &context, + query.join(" ").as_str(), + opt_filter, + ) + .await?; + + Ok(results) +} diff --git a/crates/atuin/src/command/client/search/cursor.rs b/crates/atuin/src/command/client/search/cursor.rs new file mode 100644 index 00000000..2bce4f37 --- /dev/null +++ b/crates/atuin/src/command/client/search/cursor.rs @@ -0,0 +1,333 @@ +use atuin_client::settings::WordJumpMode; + +pub struct Cursor { + source: String, + index: usize, +} + +impl From<String> for Cursor { + fn from(source: String) -> Self { + Self { source, index: 0 } + } +} + +pub struct WordJumper<'a> { + word_chars: &'a str, + word_jump_mode: WordJumpMode, +} + +impl WordJumper<'_> { + fn is_word_boundary(&self, c: char, next_c: char) -> bool { + 
(c.is_whitespace() && !next_c.is_whitespace()) + || (!c.is_whitespace() && next_c.is_whitespace()) + || (self.word_chars.contains(c) && !self.word_chars.contains(next_c)) + || (!self.word_chars.contains(c) && self.word_chars.contains(next_c)) + } + + fn emacs_get_next_word_pos(&self, source: &str, index: usize) -> usize { + let index = (index + 1..source.len().saturating_sub(1)) + .find(|&i| self.word_chars.contains(source.chars().nth(i).unwrap())) + .unwrap_or(source.len()); + (index + 1..source.len().saturating_sub(1)) + .find(|&i| !self.word_chars.contains(source.chars().nth(i).unwrap())) + .unwrap_or(source.len()) + } + + fn emacs_get_prev_word_pos(&self, source: &str, index: usize) -> usize { + let index = (1..index) + .rev() + .find(|&i| self.word_chars.contains(source.chars().nth(i).unwrap())) + .unwrap_or(0); + (1..index) + .rev() + .find(|&i| !self.word_chars.contains(source.chars().nth(i).unwrap())) + .map_or(0, |i| i + 1) + } + + fn subl_get_next_word_pos(&self, source: &str, index: usize) -> usize { + let index = (index..source.len().saturating_sub(1)).find(|&i| { + self.is_word_boundary( + source.chars().nth(i).unwrap(), + source.chars().nth(i + 1).unwrap(), + ) + }); + if index.is_none() { + return source.len(); + } + (index.unwrap() + 1..source.len()) + .find(|&i| !source.chars().nth(i).unwrap().is_whitespace()) + .unwrap_or(source.len()) + } + + fn subl_get_prev_word_pos(&self, source: &str, index: usize) -> usize { + let index = (1..index) + .rev() + .find(|&i| !source.chars().nth(i).unwrap().is_whitespace()); + if index.is_none() { + return 0; + } + (1..index.unwrap()) + .rev() + .find(|&i| { + self.is_word_boundary( + source.chars().nth(i - 1).unwrap(), + source.chars().nth(i).unwrap(), + ) + }) + .unwrap_or(0) + } + + fn get_next_word_pos(&self, source: &str, index: usize) -> usize { + match self.word_jump_mode { + WordJumpMode::Emacs => self.emacs_get_next_word_pos(source, index), + WordJumpMode::Subl => self.subl_get_next_word_pos(source, 
index), + } + } + + fn get_prev_word_pos(&self, source: &str, index: usize) -> usize { + match self.word_jump_mode { + WordJumpMode::Emacs => self.emacs_get_prev_word_pos(source, index), + WordJumpMode::Subl => self.subl_get_prev_word_pos(source, index), + } + } +} + +impl Cursor { + pub fn as_str(&self) -> &str { + self.source.as_str() + } + + pub fn into_inner(self) -> String { + self.source + } + + /// Returns the string before the cursor + pub fn substring(&self) -> &str { + &self.source[..self.index] + } + + /// Returns the currently selected [`char`] + pub fn char(&self) -> Option<char> { + self.source[self.index..].chars().next() + } + + pub fn right(&mut self) { + if self.index < self.source.len() { + loop { + self.index += 1; + if self.source.is_char_boundary(self.index) { + break; + } + } + } + } + + pub fn left(&mut self) -> bool { + if self.index > 0 { + loop { + self.index -= 1; + if self.source.is_char_boundary(self.index) { + break true; + } + } + } else { + false + } + } + + pub fn next_word(&mut self, word_chars: &str, word_jump_mode: WordJumpMode) { + let word_jumper = WordJumper { + word_chars, + word_jump_mode, + }; + self.index = word_jumper.get_next_word_pos(&self.source, self.index); + } + + pub fn prev_word(&mut self, word_chars: &str, word_jump_mode: WordJumpMode) { + let word_jumper = WordJumper { + word_chars, + word_jump_mode, + }; + self.index = word_jumper.get_prev_word_pos(&self.source, self.index); + } + + pub fn insert(&mut self, c: char) { + self.source.insert(self.index, c); + self.index += c.len_utf8(); + } + + pub fn remove(&mut self) -> Option<char> { + if self.index < self.source.len() { + Some(self.source.remove(self.index)) + } else { + None + } + } + + pub fn remove_next_word(&mut self, word_chars: &str, word_jump_mode: WordJumpMode) { + let word_jumper = WordJumper { + word_chars, + word_jump_mode, + }; + let next_index = word_jumper.get_next_word_pos(&self.source, self.index); + 
self.source.replace_range(self.index..next_index, ""); + } + + pub fn remove_prev_word(&mut self, word_chars: &str, word_jump_mode: WordJumpMode) { + let word_jumper = WordJumper { + word_chars, + word_jump_mode, + }; + let next_index = word_jumper.get_prev_word_pos(&self.source, self.index); + self.source.replace_range(next_index..self.index, ""); + self.index = next_index; + } + + pub fn back(&mut self) -> Option<char> { + if self.left() { + self.remove() + } else { + None + } + } + + pub fn clear(&mut self) { + self.source.clear(); + self.index = 0; + } + + pub fn end(&mut self) { + self.index = self.source.len(); + } + + pub fn start(&mut self) { + self.index = 0; + } +} + +#[cfg(test)] +mod cursor_tests { + use super::Cursor; + use super::*; + + static EMACS_WORD_JUMPER: WordJumper = WordJumper { + word_chars: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789", + word_jump_mode: WordJumpMode::Emacs, + }; + + static SUBL_WORD_JUMPER: WordJumper = WordJumper { + word_chars: "./\\()\"'-:,.;<>~!@#$%^&*|+=[]{}`~?", + word_jump_mode: WordJumpMode::Subl, + }; + + #[test] + fn right() { + // ö is 2 bytes + let mut c = Cursor::from(String::from("öaöböcödöeöfö")); + let indices = [0, 2, 3, 5, 6, 8, 9, 11, 12, 14, 15, 17, 18, 20, 20, 20, 20]; + for i in indices { + assert_eq!(c.index, i); + c.right(); + } + } + + #[test] + fn left() { + // ö is 2 bytes + let mut c = Cursor::from(String::from("öaöböcödöeöfö")); + c.end(); + let indices = [20, 18, 17, 15, 14, 12, 11, 9, 8, 6, 5, 3, 2, 0, 0, 0, 0]; + for i in indices { + assert_eq!(c.index, i); + c.left(); + } + } + + #[test] + fn test_emacs_get_next_word_pos() { + let s = String::from(" aaa ((()))bbb ((())) "); + let indices = [(0, 6), (3, 6), (7, 18), (19, 30)]; + for (i_src, i_dest) in indices { + assert_eq!(EMACS_WORD_JUMPER.get_next_word_pos(&s, i_src), i_dest); + } + assert_eq!(EMACS_WORD_JUMPER.get_next_word_pos("", 0), 0); + } + + #[test] + fn test_emacs_get_prev_word_pos() { + let s = String::from(" 
aaa ((()))bbb ((())) "); + let indices = [(30, 15), (29, 15), (15, 3), (3, 0)]; + for (i_src, i_dest) in indices { + assert_eq!(EMACS_WORD_JUMPER.get_prev_word_pos(&s, i_src), i_dest); + } + assert_eq!(EMACS_WORD_JUMPER.get_prev_word_pos("", 0), 0); + } + + #[test] + fn test_subl_get_next_word_pos() { + let s = String::from(" aaa ((()))bbb ((())) "); + let indices = [(0, 3), (1, 3), (3, 9), (9, 15), (15, 21), (21, 30)]; + for (i_src, i_dest) in indices { + assert_eq!(SUBL_WORD_JUMPER.get_next_word_pos(&s, i_src), i_dest); + } + assert_eq!(SUBL_WORD_JUMPER.get_next_word_pos("", 0), 0); + } + + #[test] + fn test_subl_get_prev_word_pos() { + let s = String::from(" aaa ((()))bbb ((())) "); + let indices = [(30, 21), (21, 15), (15, 9), (9, 3), (3, 0)]; + for (i_src, i_dest) in indices { + assert_eq!(SUBL_WORD_JUMPER.get_prev_word_pos(&s, i_src), i_dest); + } + assert_eq!(SUBL_WORD_JUMPER.get_prev_word_pos("", 0), 0); + } + + #[test] + fn pop() { + let mut s = String::from("öaöböcödöeöfö"); + let mut c = Cursor::from(s.clone()); + c.end(); + while !s.is_empty() { + let c1 = s.pop(); + let c2 = c.back(); + assert_eq!(c1, c2); + assert_eq!(s.as_str(), c.substring()); + } + let c1 = s.pop(); + let c2 = c.back(); + assert_eq!(c1, c2); + } + + #[test] + fn back() { + let mut c = Cursor::from(String::from("öaöböcödöeöfö")); + // move to ^ + for _ in 0..4 { + c.right(); + } + assert_eq!(c.substring(), "öaöb"); + assert_eq!(c.back(), Some('b')); + assert_eq!(c.back(), Some('ö')); + assert_eq!(c.back(), Some('a')); + assert_eq!(c.back(), Some('ö')); + assert_eq!(c.back(), None); + assert_eq!(c.as_str(), "öcödöeöfö"); + } + + #[test] + fn insert() { + let mut c = Cursor::from(String::from("öaöböcödöeöfö")); + // move to ^ + for _ in 0..4 { + c.right(); + } + assert_eq!(c.substring(), "öaöb"); + c.insert('ö'); + c.insert('g'); + c.insert('ö'); + c.insert('h'); + assert_eq!(c.substring(), "öaöbögöh"); + assert_eq!(c.as_str(), "öaöbögöhöcödöeöfö"); + } +} diff --git 
/// Writes `dur` to `f` as a single `<value><unit>` token, keeping only the
/// most-significant non-zero unit (`y`, `mo`, `d`, `h`, `m`, `s`, `ms`, `us`,
/// `ns`). A zero duration is written as `0s`.
#[allow(clippy::module_name_repetitions)]
pub fn format_duration_into(dur: Duration, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    // Unit breakdown taken and modified from
    // https://github.com/tailhook/humantime/blob/master/src/duration.rs#L295-L331
    // Copyright (c) 2016 The humantime Developers
    //
    // Unlike the original, only the most-significant segment is reported:
    // the segments are listed largest-first and the scan breaks out on the
    // first one that is non-zero, which keeps the implementation concise.
    fn most_significant(dur: Duration) -> ControlFlow<(&'static str, u64)> {
        const YEAR: u64 = 31_557_600; // 365.25d
        const MONTH: u64 = 2_630_016; // 30.44d
        const DAY: u64 = 86400;
        const HOUR: u64 = 3600;
        const MINUTE: u64 = 60;

        let secs = dur.as_secs();
        let nanos = dur.subsec_nanos();

        // remainders left over after peeling off each larger unit
        let within_year = secs % YEAR;
        let within_month = within_year % MONTH;
        let within_day = within_month % DAY;

        let segments = [
            ("y", secs / YEAR),
            ("mo", within_year / MONTH),
            ("d", within_month / DAY),
            ("h", within_day / HOUR),
            ("m", within_day % HOUR / MINUTE),
            ("s", within_day % MINUTE),
            ("ms", u64::from(nanos / 1_000_000)),
            ("us", u64::from(nanos / 1_000)),
            ("ns", u64::from(nanos)),
        ];

        segments.iter().copied().try_for_each(|(unit, value)| {
            if value > 0 {
                ControlFlow::Break((unit, value))
            } else {
                ControlFlow::Continue(())
            }
        })
    }

    match most_significant(dur) {
        ControlFlow::Break((unit, value)) => write!(f, "{value}{unit}"),
        ControlFlow::Continue(()) => write!(f, "0s"),
    }
}

/// Returns the text [`format_duration_into`] would write for `f`, as an owned
/// `String`.
#[allow(clippy::module_name_repetitions)]
pub fn format_duration(f: Duration) -> String {
    // Adapter that lets `to_string` drive `format_duration_into`.
    struct Compact(Duration);

    impl fmt::Display for Compact {
        fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
            format_duration_into(self.0, out)
        }
    }

    Compact(f).to_string()
}
+ .into_iter() + .collect::<Vec<_>>()) + } else { + self.full_query(state, db).await + } + } +} diff --git a/crates/atuin/src/command/client/search/engines/db.rs b/crates/atuin/src/command/client/search/engines/db.rs new file mode 100644 index 00000000..e638f9d9 --- /dev/null +++ b/crates/atuin/src/command/client/search/engines/db.rs @@ -0,0 +1,33 @@ +use async_trait::async_trait; +use atuin_client::{ + database::Database, database::OptFilters, history::History, settings::SearchMode, +}; +use eyre::Result; + +use super::{SearchEngine, SearchState}; + +pub struct Search(pub SearchMode); + +#[async_trait] +impl SearchEngine for Search { + async fn full_query( + &mut self, + state: &SearchState, + db: &mut dyn Database, + ) -> Result<Vec<History>> { + Ok(db + .search( + self.0, + state.filter_mode, + &state.context, + state.input.as_str(), + OptFilters { + limit: Some(200), + ..Default::default() + }, + ) + .await + // ignore errors as it may be caused by incomplete regex + .map_or(Vec::new(), |r| r.into_iter().collect())) + } +} diff --git a/crates/atuin/src/command/client/search/engines/skim.rs b/crates/atuin/src/command/client/search/engines/skim.rs new file mode 100644 index 00000000..d2baa63b --- /dev/null +++ b/crates/atuin/src/command/client/search/engines/skim.rs @@ -0,0 +1,166 @@ +use std::path::Path; + +use async_trait::async_trait; +use atuin_client::{database::Database, history::History, settings::FilterMode}; +use eyre::Result; +use fuzzy_matcher::{skim::SkimMatcherV2, FuzzyMatcher}; +use itertools::Itertools; +use time::OffsetDateTime; +use tokio::task::yield_now; + +use super::{SearchEngine, SearchState}; + +pub struct Search { + all_history: Vec<(History, i32)>, + engine: SkimMatcherV2, +} + +impl Search { + pub fn new() -> Self { + Search { + all_history: vec![], + engine: SkimMatcherV2::default(), + } + } +} + +#[async_trait] +impl SearchEngine for Search { + async fn full_query( + &mut self, + state: &SearchState, + db: &mut dyn Database, + ) -> 
Result<Vec<History>> { + if self.all_history.is_empty() { + self.all_history = db.all_with_count().await.unwrap(); + } + + Ok(fuzzy_search(&self.engine, state, &self.all_history).await) + } +} + +async fn fuzzy_search( + engine: &SkimMatcherV2, + state: &SearchState, + all_history: &[(History, i32)], +) -> Vec<History> { + let mut set = Vec::with_capacity(200); + let mut ranks = Vec::with_capacity(200); + let query = state.input.as_str(); + let now = OffsetDateTime::now_utc(); + + for (i, (history, count)) in all_history.iter().enumerate() { + if i % 256 == 0 { + yield_now().await; + } + let context = &state.context; + let git_root = context + .git_root + .as_ref() + .and_then(|git_root| git_root.to_str()) + .unwrap_or(&context.cwd); + match state.filter_mode { + FilterMode::Global => {} + // we aggregate host by ',' separating them + FilterMode::Host + if history + .hostname + .split(',') + .contains(&context.hostname.as_str()) => {} + // we aggregate session by concattenating them. + // sessions are 32 byte simple uuid formats + FilterMode::Session + if history + .session + .as_bytes() + .chunks(32) + .contains(&context.session.as_bytes()) => {} + // we aggregate directory by ':' separating them + FilterMode::Directory if history.cwd.split(':').contains(&context.cwd.as_str()) => {} + FilterMode::Workspace if history.cwd.split(':').contains(&git_root) => {} + _ => continue, + } + #[allow(clippy::cast_lossless, clippy::cast_precision_loss)] + if let Some((score, indices)) = engine.fuzzy_indices(&history.command, query) { + let begin = indices.first().copied().unwrap_or_default(); + + let mut duration = (now - history.timestamp).as_seconds_f64().log2(); + if !duration.is_finite() || duration <= 1.0 { + duration = 1.0; + } + // these + X.0 just make the log result a bit smoother. + // log is very spiky towards 1-4, but I want a gradual decay. 
+ // eg: + // log2(4) = 2, log2(5) = 2.3 (16% increase) + // log2(8) = 3, log2(9) = 3.16 (5% increase) + // log2(16) = 4, log2(17) = 4.08 (2% increase) + let count = (*count as f64 + 8.0).log2(); + let begin = (begin as f64 + 16.0).log2(); + let path = path_dist(history.cwd.as_ref(), state.context.cwd.as_ref()); + let path = (path as f64 + 8.0).log2(); + + // reduce longer durations, raise higher counts, raise matches close to the start + let score = (-score as f64) * count / path / duration / begin; + + 'insert: { + // algorithm: + // 1. find either the position that this command ranks + // 2. find the same command positioned better than our rank. + for i in 0..set.len() { + // do we out score the current position? + if ranks[i] > score { + ranks.insert(i, score); + set.insert(i, history.clone()); + let mut j = i + 1; + while j < set.len() { + // remove duplicates that have a worse score + if set[j].command == history.command { + ranks.remove(j); + set.remove(j); + + // break this while loop because there won't be any other + // duplicates. 
/// Number of path components separating `a` from `b`: the components of `a`
/// that lie below their common ancestor plus the components of `b` that lie
/// below it. Identical paths have distance 0.
fn path_dist(a: &Path, b: &Path) -> usize {
    // length of the common ancestor, in components
    let shared = a
        .components()
        .zip(b.components())
        .take_while(|(x, y)| x == y)
        .count();

    let climb = a.components().count() - shared;
    let descend = b.components().count() - shared;

    descend + climb
}
+ let inner_area = b.inner(area); + b.render(area, buf); + inner_area + }); + + if list_area.width < 1 || list_area.height < 1 || self.history.is_empty() { + return; + } + let list_height = list_area.height as usize; + + let (start, end) = self.get_items_bounds(state.selected, state.offset, list_height); + state.offset = start; + state.max_entries = end - start; + + let mut s = DrawState { + buf, + list_area, + x: 0, + y: 0, + state, + inverted: self.inverted, + alternate_highlight: self.alternate_highlight, + now: &self.now, + }; + + for item in self.history.iter().skip(state.offset).take(end - start) { + s.index(); + s.duration(item); + s.time(item); + s.command(item); + + // reset line + s.y += 1; + s.x = 0; + } + } +} + +impl<'a> HistoryList<'a> { + pub fn new( + history: &'a [History], + inverted: bool, + alternate_highlight: bool, + now: &'a dyn Fn() -> OffsetDateTime, + ) -> Self { + Self { + history, + block: None, + inverted, + alternate_highlight, + now, + } + } + + pub fn block(mut self, block: Block<'a>) -> Self { + self.block = Some(block); + self + } + + fn get_items_bounds(&self, selected: usize, offset: usize, height: usize) -> (usize, usize) { + let offset = offset.min(self.history.len().saturating_sub(1)); + + let max_scroll_space = height.min(10).min(self.history.len() - selected); + if offset + height < selected + max_scroll_space { + let end = selected + max_scroll_space; + (end - height, end) + } else if selected < offset { + (selected, selected + height) + } else { + (offset, offset + height) + } + } +} + +struct DrawState<'a> { + buf: &'a mut Buffer, + list_area: Rect, + x: u16, + y: u16, + state: &'a ListState, + inverted: bool, + alternate_highlight: bool, + now: &'a dyn Fn() -> OffsetDateTime, +} + +// longest line prefix I could come up with +#[allow(clippy::cast_possible_truncation)] // we know that this is <65536 length +pub const PREFIX_LENGTH: u16 = " > 123ms 59s ago".len() as u16; +static SPACES: &str = " "; +static _ASSERT: () = 
assert!(SPACES.len() == PREFIX_LENGTH as usize); + +impl DrawState<'_> { + fn index(&mut self) { + // these encode the slices of `" > "`, `" {n} "`, or `" "` in a compact form. + // Yes, this is a hack, but it makes me feel happy + static SLICES: &str = " > 1 2 3 4 5 6 7 8 9 "; + + let i = self.y as usize + self.state.offset; + let i = i.checked_sub(self.state.selected); + let i = i.unwrap_or(10).min(10) * 2; + self.draw(&SLICES[i..i + 3], Style::default()); + } + + fn duration(&mut self, h: &History) { + let status = Style::default().fg(if h.success() { + Color::Green + } else { + Color::Red + }); + let duration = Duration::from_nanos(u64::try_from(h.duration).unwrap_or(0)); + self.draw(&format_duration(duration), status); + } + + #[allow(clippy::cast_possible_truncation)] // we know that time.len() will be <6 + fn time(&mut self, h: &History) { + let style = Style::default().fg(Color::Blue); + + // Account for the chance that h.timestamp is "in the future" + // This would mean that "since" is negative, and the unwrap here + // would fail. + // If the timestamp would otherwise be in the future, display + // the time since as 0. + let since = (self.now)() - h.timestamp; + let time = format_duration(since.try_into().unwrap_or_default()); + + // pad the time a little bit before we write. 
this aligns things nicely + // skip padding if for some reason it is already too long to align nicely + let padding = + usize::from(PREFIX_LENGTH).saturating_sub(usize::from(self.x) + 4 + time.len()); + self.draw(&SPACES[..padding], Style::default()); + + self.draw(&time, style); + self.draw(" ago", style); + } + + fn command(&mut self, h: &History) { + let mut style = Style::default(); + if !self.alternate_highlight && (self.y as usize + self.state.offset == self.state.selected) + { + // if not applying alternative highlighting to the whole row, color the command + style = style.fg(Color::Red).add_modifier(Modifier::BOLD); + } + + for section in h.command.escape_control().split_ascii_whitespace() { + self.draw(" ", style); + if self.x > self.list_area.width { + // Avoid attempting to draw a command section beyond the width + // of the list + return; + } + self.draw(section, style); + } + } + + fn draw(&mut self, s: &str, mut style: Style) { + let cx = self.list_area.left() + self.x; + + let cy = if self.inverted { + self.list_area.top() + self.y + } else { + self.list_area.bottom() - self.y - 1 + }; + + if self.alternate_highlight && (self.y as usize + self.state.offset == self.state.selected) + { + style = style.add_modifier(Modifier::REVERSED); + } + + let w = (self.list_area.width - self.x) as usize; + self.x += self.buf.set_stringn(cx, cy, s, w, style).0 - cx; + } +} diff --git a/crates/atuin/src/command/client/search/inspector.rs b/crates/atuin/src/command/client/search/inspector.rs new file mode 100644 index 00000000..060b4df6 --- /dev/null +++ b/crates/atuin/src/command/client/search/inspector.rs @@ -0,0 +1,259 @@ +use std::time::Duration; +use time::macros::format_description; + +use atuin_client::{ + history::{History, HistoryStats}, + settings::Settings, +}; +use crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; +use ratatui::{ + layout::Rect, + prelude::{Constraint, Direction, Layout}, + style::Style, + widgets::{Bar, BarChart, BarGroup, Block, 
/// Converts `num` to `u64`, mapping every negative value to `0`.
///
/// Uses a checked conversion rather than an `as` cast, so no sign-loss lint
/// suppression is needed.
fn u64_or_zero(num: i64) -> u64 {
    u64::try_from(num).unwrap_or(0)
}
/// Maps a day-of-week number given as text (`"0"` = Sunday through `"6"` =
/// Saturday) to the English day name. Any other input yields `"Invalid day"`.
fn num_to_day(num: &str) -> String {
    const DAYS: [(&str, &str); 7] = [
        ("0", "Sunday"),
        ("1", "Monday"),
        ("2", "Tuesday"),
        ("3", "Wednesday"),
        ("4", "Thursday"),
        ("5", "Friday"),
        ("6", "Saturday"),
    ];

    DAYS.iter()
        .find(|(key, _)| *key == num)
        .map_or("Invalid day", |&(_, name)| name)
        .to_string()
}
.label(num_to_day(day.as_str()).into()) + .value(u64_or_zero(*count)) + }) + .collect(); + + let day_of_week = BarChart::default() + .block(Block::default().title("Runs per day").borders(Borders::ALL)) + .bar_width(3) + .bar_gap(1) + .bar_style(Style::default()) + .value_style(Style::default()) + .label_style(Style::default()) + .data(BarGroup::default().bars(&day_of_week)); + + let duration_over_time = sort_duration_over_time(&stats.duration_over_time); + let duration_over_time: Vec<Bar> = duration_over_time + .iter() + .map(|(date, duration)| { + let d = Duration::from_nanos(u64_or_zero(*duration)); + Bar::default() + .label(date.clone().into()) + .value(u64_or_zero(*duration)) + .text_value(format_duration(d)) + }) + .collect(); + + let duration_over_time = BarChart::default() + .block( + Block::default() + .title("Duration over time") + .borders(Borders::ALL), + ) + .bar_width(5) + .bar_gap(1) + .bar_style(Style::default()) + .value_style(Style::default()) + .label_style(Style::default()) + .data(BarGroup::default().bars(&duration_over_time)); + + let layout = Layout::default() + .direction(Direction::Vertical) + .constraints([ + Constraint::Ratio(1, 3), + Constraint::Ratio(1, 3), + Constraint::Ratio(1, 3), + ]) + .split(parent); + + f.render_widget(exits, layout[0]); + f.render_widget(day_of_week, layout[1]); + f.render_widget(duration_over_time, layout[2]); +} + +pub fn draw(f: &mut Frame<'_>, chunk: Rect, history: &History, stats: &HistoryStats) { + let vert_layout = Layout::default() + .direction(Direction::Vertical) + .constraints([Constraint::Ratio(1, 5), Constraint::Ratio(4, 5)]) + .split(chunk); + + let stats_layout = Layout::default() + .direction(Direction::Horizontal) + .constraints([Constraint::Ratio(1, 3), Constraint::Ratio(2, 3)]) + .split(vert_layout[1]); + + draw_commands(f, vert_layout[0], history, stats); + draw_stats_table(f, stats_layout[0], history, stats); + draw_stats_charts(f, stats_layout[1], stats); +} + +// I'm going to break this out 
more, but just starting to move things around before changing +// structure and making it nicer. +pub fn input( + _state: &mut State, + _settings: &Settings, + selected: usize, + input: &KeyEvent, +) -> InputAction { + let ctrl = input.modifiers.contains(KeyModifiers::CONTROL); + + match input.code { + KeyCode::Char('d') if ctrl => InputAction::Delete(selected), + _ => InputAction::Continue, + } +} diff --git a/crates/atuin/src/command/client/search/interactive.rs b/crates/atuin/src/command/client/search/interactive.rs new file mode 100644 index 00000000..7a3a834b --- /dev/null +++ b/crates/atuin/src/command/client/search/interactive.rs @@ -0,0 +1,1309 @@ +use std::{ + io::{stdout, Write}, + time::Duration, +}; + +use atuin_common::utils::{self, Escapable as _}; +use crossterm::{ + cursor::SetCursorStyle, + event::{ + self, DisableMouseCapture, EnableMouseCapture, Event, KeyCode, KeyEvent, KeyModifiers, + KeyboardEnhancementFlags, MouseEvent, PopKeyboardEnhancementFlags, + PushKeyboardEnhancementFlags, + }, + execute, terminal, +}; +use eyre::Result; +use futures_util::FutureExt; +use semver::Version; +use time::OffsetDateTime; +use unicode_width::UnicodeWidthStr; + +use atuin_client::{ + database::{current_context, Database}, + history::{store::HistoryStore, History, HistoryStats}, + settings::{CursorStyle, ExitMode, FilterMode, KeymapMode, SearchMode, Settings}, +}; + +use super::{ + cursor::Cursor, + engines::{SearchEngine, SearchState}, + history_list::{HistoryList, ListState, PREFIX_LENGTH}, + sort, +}; + +use crate::{command::client::search::engines, VERSION}; + +use ratatui::{ + backend::CrosstermBackend, + layout::{Alignment, Constraint, Direction, Layout}, + prelude::*, + style::{Color, Modifier, Style}, + text::{Line, Span, Text}, + widgets::{block::Title, Block, BorderType, Borders, Padding, Paragraph, Tabs}, + Frame, Terminal, TerminalOptions, Viewport, +}; + +const TAB_TITLES: [&str; 2] = ["Search", "Inspect"]; + +pub enum InputAction { + 
Accept(usize), + Copy(usize), + Delete(usize), + ReturnOriginal, + ReturnQuery, + Continue, + Redraw, +} + +#[allow(clippy::struct_field_names)] +pub struct State { + history_count: i64, + update_needed: Option<Version>, + results_state: ListState, + switched_search_mode: bool, + search_mode: SearchMode, + results_len: usize, + accept: bool, + keymap_mode: KeymapMode, + prefix: bool, + current_cursor: Option<CursorStyle>, + tab_index: usize, + + search: SearchState, + engine: Box<dyn SearchEngine>, + now: Box<dyn Fn() -> OffsetDateTime + Send>, +} + +#[derive(Clone, Copy)] +struct StyleState { + compact: bool, + invert: bool, + inner_width: usize, +} + +impl State { + async fn query_results( + &mut self, + db: &mut dyn Database, + smart_sort: bool, + ) -> Result<Vec<History>> { + let results = self.engine.query(&self.search, db).await?; + + self.results_state.select(0); + self.results_len = results.len(); + + if smart_sort { + Ok(sort::sort(self.search.input.as_str(), results)) + } else { + Ok(results) + } + } + + fn handle_input<W>( + &mut self, + settings: &Settings, + input: &Event, + w: &mut W, + ) -> Result<InputAction> + where + W: Write, + { + execute!(w, EnableMouseCapture)?; + let r = match input { + Event::Key(k) => self.handle_key_input(settings, k), + Event::Mouse(m) => self.handle_mouse_input(*m), + Event::Paste(d) => self.handle_paste_input(d), + _ => InputAction::Continue, + }; + execute!(w, DisableMouseCapture)?; + Ok(r) + } + + fn handle_mouse_input(&mut self, input: MouseEvent) -> InputAction { + match input.kind { + event::MouseEventKind::ScrollDown => { + self.scroll_down(1); + } + event::MouseEventKind::ScrollUp => { + self.scroll_up(1); + } + _ => {} + } + InputAction::Continue + } + + fn handle_paste_input(&mut self, input: &str) -> InputAction { + for i in input.chars() { + self.search.input.insert(i); + } + InputAction::Continue + } + + fn cast_cursor_style(style: CursorStyle) -> SetCursorStyle { + match style { + 
CursorStyle::DefaultUserShape => SetCursorStyle::DefaultUserShape, + CursorStyle::BlinkingBlock => SetCursorStyle::BlinkingBlock, + CursorStyle::SteadyBlock => SetCursorStyle::SteadyBlock, + CursorStyle::BlinkingUnderScore => SetCursorStyle::BlinkingUnderScore, + CursorStyle::SteadyUnderScore => SetCursorStyle::SteadyUnderScore, + CursorStyle::BlinkingBar => SetCursorStyle::BlinkingBar, + CursorStyle::SteadyBar => SetCursorStyle::SteadyBar, + } + } + + fn set_keymap_cursor(&mut self, settings: &Settings, keymap_name: &str) { + let cursor_style = if keymap_name == "__clear__" { + None + } else { + settings.keymap_cursor.get(keymap_name).copied() + } + .or_else(|| self.current_cursor.map(|_| CursorStyle::DefaultUserShape)); + + if cursor_style != self.current_cursor { + if let Some(style) = cursor_style { + self.current_cursor = cursor_style; + let _ = execute!(stdout(), Self::cast_cursor_style(style)); + } + } + } + + pub fn initialize_keymap_cursor(&mut self, settings: &Settings) { + match self.keymap_mode { + KeymapMode::Emacs => self.set_keymap_cursor(settings, "emacs"), + KeymapMode::VimNormal => self.set_keymap_cursor(settings, "vim_normal"), + KeymapMode::VimInsert => self.set_keymap_cursor(settings, "vim_insert"), + KeymapMode::Auto => {} + } + } + + pub fn finalize_keymap_cursor(&mut self, settings: &Settings) { + match settings.keymap_mode_shell { + KeymapMode::Emacs => self.set_keymap_cursor(settings, "emacs"), + KeymapMode::VimNormal => self.set_keymap_cursor(settings, "vim_normal"), + KeymapMode::VimInsert => self.set_keymap_cursor(settings, "vim_insert"), + KeymapMode::Auto => self.set_keymap_cursor(settings, "__clear__"), + } + } + + fn handle_key_exit(settings: &Settings) -> InputAction { + match settings.exit_mode { + ExitMode::ReturnOriginal => InputAction::ReturnOriginal, + ExitMode::ReturnQuery => InputAction::ReturnQuery, + } + } + + fn handle_key_input(&mut self, settings: &Settings, input: &KeyEvent) -> InputAction { + if input.kind == 
event::KeyEventKind::Release { + return InputAction::Continue; + } + + let ctrl = input.modifiers.contains(KeyModifiers::CONTROL); + let esc_allow_exit = !(self.tab_index == 0 && self.keymap_mode == KeymapMode::VimInsert); + + // support ctrl-a prefix, like screen or tmux + if ctrl && input.code == KeyCode::Char('a') { + self.prefix = true; + return InputAction::Continue; + } + + // core input handling, common for all tabs + let common: Option<InputAction> = match input.code { + KeyCode::Char('c' | 'g') if ctrl => Some(InputAction::ReturnOriginal), + KeyCode::Esc if esc_allow_exit => Some(Self::handle_key_exit(settings)), + KeyCode::Char('[') if ctrl && esc_allow_exit => Some(Self::handle_key_exit(settings)), + KeyCode::Tab => Some(InputAction::Accept(self.results_state.selected())), + KeyCode::Char('o') if ctrl => { + self.tab_index = (self.tab_index + 1) % TAB_TITLES.len(); + + Some(InputAction::Continue) + } + + _ => None, + }; + + if let Some(ret) = common { + self.prefix = false; + + return ret; + } + + // handle tab-specific input + let action = match self.tab_index { + 0 => self.handle_search_input(settings, input), + + 1 => super::inspector::input(self, settings, self.results_state.selected(), input), + + _ => panic!("invalid tab index on input"), + }; + + self.prefix = false; + + action + } + + fn handle_search_scroll_one_line( + &mut self, + settings: &Settings, + enable_exit: bool, + is_down: bool, + ) -> InputAction { + if is_down { + if settings.keys.scroll_exits && enable_exit && self.results_state.selected() == 0 { + return Self::handle_key_exit(settings); + } + self.scroll_down(1); + } else { + self.scroll_up(1); + } + InputAction::Continue + } + + fn handle_search_up(&mut self, settings: &Settings, enable_exit: bool) -> InputAction { + self.handle_search_scroll_one_line(settings, enable_exit, settings.invert) + } + + fn handle_search_down(&mut self, settings: &Settings, enable_exit: bool) -> InputAction { + 
self.handle_search_scroll_one_line(settings, enable_exit, !settings.invert) + } + + fn handle_search_accept(&mut self, settings: &Settings) -> InputAction { + if settings.enter_accept { + self.accept = true; + } + InputAction::Accept(self.results_state.selected()) + } + + #[allow(clippy::too_many_lines)] + #[allow(clippy::cognitive_complexity)] + fn handle_search_input(&mut self, settings: &Settings, input: &KeyEvent) -> InputAction { + let ctrl = input.modifiers.contains(KeyModifiers::CONTROL); + let alt = input.modifiers.contains(KeyModifiers::ALT); + + // Use Ctrl-n instead of Alt-n? + let modfr = if settings.ctrl_n_shortcuts { ctrl } else { alt }; + + // reset the state, will be set to true later if user really did change it + self.switched_search_mode = false; + + // first up handle prefix mappings. these take precedence over all others + // eg, if a user types ctrl-a d, delete the history + if self.prefix { + // It'll be expanded. + #[allow(clippy::single_match)] + match input.code { + KeyCode::Char('d') => { + return InputAction::Delete(self.results_state.selected()); + } + KeyCode::Char('a') => { + self.search.input.start(); + return InputAction::Continue; + } + _ => {} + } + } + + // handle keymap specific keybindings. 
+ match self.keymap_mode { + KeymapMode::VimNormal => match input.code { + KeyCode::Char('/') if !ctrl => { + self.search.input.clear(); + self.set_keymap_cursor(settings, "vim_insert"); + self.keymap_mode = KeymapMode::VimInsert; + return InputAction::Continue; + } + KeyCode::Char('?') if !ctrl => { + self.search.input.clear(); + self.set_keymap_cursor(settings, "vim_insert"); + self.keymap_mode = KeymapMode::VimInsert; + return InputAction::Continue; + } + KeyCode::Char('j') if !ctrl => { + return self.handle_search_down(settings, true); + } + KeyCode::Char('k') if !ctrl => { + return self.handle_search_up(settings, true); + } + KeyCode::Char('h') if !ctrl => { + self.search.input.left(); + return InputAction::Continue; + } + KeyCode::Char('l') if !ctrl => { + self.search.input.right(); + return InputAction::Continue; + } + KeyCode::Char('a') if !ctrl => { + self.search.input.right(); + self.set_keymap_cursor(settings, "vim_insert"); + self.keymap_mode = KeymapMode::VimInsert; + return InputAction::Continue; + } + KeyCode::Char('A') if !ctrl => { + self.search.input.end(); + self.set_keymap_cursor(settings, "vim_insert"); + self.keymap_mode = KeymapMode::VimInsert; + return InputAction::Continue; + } + KeyCode::Char('i') if !ctrl => { + self.set_keymap_cursor(settings, "vim_insert"); + self.keymap_mode = KeymapMode::VimInsert; + return InputAction::Continue; + } + KeyCode::Char('I') if !ctrl => { + self.search.input.start(); + self.set_keymap_cursor(settings, "vim_insert"); + self.keymap_mode = KeymapMode::VimInsert; + return InputAction::Continue; + } + KeyCode::Char(_) if !ctrl => { + return InputAction::Continue; + } + _ => {} + }, + KeymapMode::VimInsert => { + if input.code == KeyCode::Esc || (ctrl && input.code == KeyCode::Char('[')) { + self.set_keymap_cursor(settings, "vim_normal"); + self.keymap_mode = KeymapMode::VimNormal; + return InputAction::Continue; + } + } + _ => {} + } + + match input.code { + KeyCode::Enter => return 
self.handle_search_accept(settings), + KeyCode::Char('m') if ctrl => return self.handle_search_accept(settings), + KeyCode::Char('y') if ctrl => { + return InputAction::Copy(self.results_state.selected()); + } + KeyCode::Char(c @ '1'..='9') if modfr => { + return c.to_digit(10).map_or(InputAction::Continue, |c| { + InputAction::Accept(self.results_state.selected() + c as usize) + }) + } + KeyCode::Left if ctrl => self + .search + .input + .prev_word(&settings.word_chars, settings.word_jump_mode), + KeyCode::Char('b') if alt => self + .search + .input + .prev_word(&settings.word_chars, settings.word_jump_mode), + KeyCode::Left => { + self.search.input.left(); + } + KeyCode::Char('b') if ctrl => { + self.search.input.left(); + } + KeyCode::Right if ctrl => self + .search + .input + .next_word(&settings.word_chars, settings.word_jump_mode), + KeyCode::Char('f') if alt => self + .search + .input + .next_word(&settings.word_chars, settings.word_jump_mode), + KeyCode::Right => self.search.input.right(), + KeyCode::Char('f') if ctrl => self.search.input.right(), + KeyCode::Home => self.search.input.start(), + KeyCode::Char('e') if ctrl => self.search.input.end(), + KeyCode::End => self.search.input.end(), + KeyCode::Backspace if ctrl => self + .search + .input + .remove_prev_word(&settings.word_chars, settings.word_jump_mode), + KeyCode::Backspace => { + self.search.input.back(); + } + KeyCode::Char('h' | '?') if ctrl => { + // Depending on the terminal, [Backspace] can be transmitted as + // \x08 or \x7F. Also, [Ctrl+Backspace] can be transmitted as + // \x08 or \x7F or \x1F. On the other hand, [Ctrl+h] and + // [Ctrl+?] are also transmitted as \x08 or \x7F by the + // terminals. + // + // The crossterm library translates \x08 and \x7F to C-h and + // Backspace, respectively. With the extended keyboard + // protocol enabled, crossterm can faithfully translate + // [Ctrl+h] and [Ctrl+?] to C-h and C-?. There is no perfect + // solution, but we treat C-h and C-? 
the same as backspace to + // suppress quirks as much as possible. + self.search.input.back(); + } + KeyCode::Delete if ctrl => self + .search + .input + .remove_next_word(&settings.word_chars, settings.word_jump_mode), + KeyCode::Delete => { + self.search.input.remove(); + } + KeyCode::Char('d') if ctrl => { + if self.search.input.as_str().is_empty() { + return InputAction::ReturnOriginal; + } + self.search.input.remove(); + } + KeyCode::Char('w') if ctrl => { + // remove the first batch of whitespace + while matches!(self.search.input.back(), Some(c) if c.is_whitespace()) {} + while self.search.input.left() { + if self.search.input.char().unwrap().is_whitespace() { + self.search.input.right(); // found whitespace, go back right + break; + } + self.search.input.remove(); + } + } + KeyCode::Char('u') if ctrl => self.search.input.clear(), + KeyCode::Char('r') if ctrl => { + let filter_modes = if settings.workspaces && self.search.context.git_root.is_some() + { + vec![ + FilterMode::Global, + FilterMode::Host, + FilterMode::Session, + FilterMode::Directory, + FilterMode::Workspace, + ] + } else { + vec![ + FilterMode::Global, + FilterMode::Host, + FilterMode::Session, + FilterMode::Directory, + ] + }; + + let i = self.search.filter_mode as usize; + let i = (i + 1) % filter_modes.len(); + self.search.filter_mode = filter_modes[i]; + } + KeyCode::Char('s') if ctrl => { + self.switched_search_mode = true; + self.search_mode = self.search_mode.next(settings); + self.engine = engines::engine(self.search_mode); + } + KeyCode::Down => { + return self.handle_search_down(settings, true); + } + KeyCode::Up => { + return self.handle_search_up(settings, true); + } + KeyCode::Char('n' | 'j') if ctrl => { + return self.handle_search_down(settings, false); + } + KeyCode::Char('p' | 'k') if ctrl => { + return self.handle_search_up(settings, false); + } + KeyCode::Char('l') if ctrl => { + return InputAction::Redraw; + } + KeyCode::Char(c) => { + self.search.input.insert(c); + } + 
KeyCode::PageDown if !settings.invert => { + let scroll_len = self.results_state.max_entries() - settings.scroll_context_lines; + self.scroll_down(scroll_len); + } + KeyCode::PageDown if settings.invert => { + let scroll_len = self.results_state.max_entries() - settings.scroll_context_lines; + self.scroll_up(scroll_len); + } + KeyCode::PageUp if !settings.invert => { + let scroll_len = self.results_state.max_entries() - settings.scroll_context_lines; + self.scroll_up(scroll_len); + } + KeyCode::PageUp if settings.invert => { + let scroll_len = self.results_state.max_entries() - settings.scroll_context_lines; + self.scroll_down(scroll_len); + } + _ => {} + }; + + InputAction::Continue + } + + fn scroll_down(&mut self, scroll_len: usize) { + let i = self.results_state.selected().saturating_sub(scroll_len); + self.results_state.select(i); + } + + fn scroll_up(&mut self, scroll_len: usize) { + let i = self.results_state.selected() + scroll_len; + self.results_state.select(i.min(self.results_len - 1)); + } + + #[allow(clippy::cast_possible_truncation)] + #[allow(clippy::bool_to_int_with_if)] + fn calc_preview_height( + settings: &Settings, + results: &[History], + selected: usize, + tab_index: usize, + compact: bool, + border_size: u16, + preview_width: u16, + ) -> u16 { + if settings.show_preview_auto && tab_index == 0 && !results.is_empty() { + let length_current_cmd = results[selected].command.len() as u16; + // The '- 19' takes the characters before the command (duration and time) into account + if length_current_cmd > preview_width - 19 { + std::cmp::min( + settings.max_preview_height, + (length_current_cmd + preview_width - 1 - border_size) + / (preview_width - border_size), + ) + border_size * 2 + } else { + 1 + } + } else if settings.show_preview && !settings.show_preview_auto && tab_index == 0 { + let longest_command = results + .iter() + .max_by(|h1, h2| h1.command.len().cmp(&h2.command.len())); + longest_command.map_or(0, |v| { + std::cmp::min( + 
settings.max_preview_height, + v.command + .split('\n') + .map(|line| { + (line.len() as u16 + preview_width - 1 - border_size) + / (preview_width - border_size) + }) + .sum(), + ) + }) + border_size * 2 + } else if compact || tab_index == 1 { + 0 + } else { + 1 + } + } + + #[allow(clippy::cast_possible_truncation)] + #[allow(clippy::bool_to_int_with_if)] + #[allow(clippy::too_many_lines)] + fn draw( + &mut self, + f: &mut Frame, + results: &[History], + stats: Option<HistoryStats>, + settings: &Settings, + ) { + let compact = match settings.style { + atuin_client::settings::Style::Auto => f.size().height < 14, + atuin_client::settings::Style::Compact => true, + atuin_client::settings::Style::Full => false, + }; + let invert = settings.invert; + let border_size = if compact { 0 } else { 1 }; + let preview_width = f.size().width - 2; + let preview_height = Self::calc_preview_height( + settings, + results, + self.results_state.selected(), + self.tab_index, + compact, + border_size, + preview_width, + ); + let show_help = settings.show_help && (!compact || f.size().height > 1); + let show_tabs = settings.show_tabs; + let chunks = Layout::default() + .direction(Direction::Vertical) + .margin(0) + .horizontal_margin(1) + .constraints( + if invert { + [ + Constraint::Length(1 + border_size), // input + Constraint::Min(1), // results list + Constraint::Length(preview_height), // preview + Constraint::Length(if show_tabs { 1 } else { 0 }), // tabs + Constraint::Length(if show_help { 1 } else { 0 }), // header (sic) + ] + } else { + [ + Constraint::Length(if show_help { 1 } else { 0 }), // header + Constraint::Length(if show_tabs { 1 } else { 0 }), // tabs + Constraint::Min(1), // results list + Constraint::Length(1 + border_size), // input + Constraint::Length(preview_height), // preview + ] + } + .as_ref(), + ) + .split(f.size()); + + let input_chunk = if invert { chunks[0] } else { chunks[3] }; + let results_list_chunk = if invert { chunks[1] } else { chunks[2] }; + let 
preview_chunk = if invert { chunks[2] } else { chunks[4] }; + let tabs_chunk = if invert { chunks[3] } else { chunks[1] }; + let header_chunk = if invert { chunks[4] } else { chunks[0] }; + + // TODO: this should be split so that we have one interactive search container that is + // EITHER a search box or an inspector. But I'm not doing that now, way too much atm. + // also allocate less 🙈 + let titles = TAB_TITLES.iter().copied().map(Line::from).collect(); + + let tabs = Tabs::new(titles) + .block(Block::default().borders(Borders::NONE)) + .select(self.tab_index) + .style(Style::default()) + .highlight_style(Style::default().bold().white().on_black()); + + f.render_widget(tabs, tabs_chunk); + + let style = StyleState { + compact, + invert, + inner_width: input_chunk.width.into(), + }; + + let header_chunks = Layout::default() + .direction(Direction::Horizontal) + .constraints( + [ + Constraint::Ratio(1, 5), + Constraint::Ratio(3, 5), + Constraint::Ratio(1, 5), + ] + .as_ref(), + ) + .split(header_chunk); + + let title = self.build_title(); + f.render_widget(title, header_chunks[0]); + + let help = self.build_help(); + f.render_widget(help, header_chunks[1]); + + let stats_tab = self.build_stats(); + f.render_widget(stats_tab, header_chunks[2]); + + match self.tab_index { + 0 => { + let results_list = + Self::build_results_list(style, results, self.keymap_mode, &self.now); + f.render_stateful_widget(results_list, results_list_chunk, &mut self.results_state); + } + + 1 => { + if results.is_empty() { + let message = Paragraph::new("Nothing to inspect") + .block( + Block::new() + .title( + Title::from(" Info ".to_string()).alignment(Alignment::Center), + ) + .borders(Borders::ALL) + .padding(Padding::vertical(2)), + ) + .alignment(Alignment::Center); + f.render_widget(message, results_list_chunk); + } else { + super::inspector::draw( + f, + results_list_chunk, + &results[self.results_state.selected()], + &stats.expect("Drawing inspector, but no stats"), + ); + } + + 
// HACK: I'm following up with abstracting this into the UI container, with a + // sub-widget for search + for inspector + let feedback = Paragraph::new("The inspector is new - please give feedback (good, or bad) at https://forum.atuin.sh"); + f.render_widget(feedback, input_chunk); + + return; + } + + _ => { + panic!("invalid tab index"); + } + } + + let input = self.build_input(style); + f.render_widget(input, input_chunk); + + let preview_width = if compact { + preview_width + } else { + preview_width - 2 + }; + let preview = + self.build_preview(results, compact, preview_width, preview_chunk.width.into()); + f.render_widget(preview, preview_chunk); + + let extra_width = UnicodeWidthStr::width(self.search.input.substring()); + + let cursor_offset = if compact { 0 } else { 1 }; + f.set_cursor( + // Put cursor past the end of the input text + input_chunk.x + extra_width as u16 + PREFIX_LENGTH + 1 + cursor_offset, + input_chunk.y + cursor_offset, + ); + } + + fn build_title(&mut self) -> Paragraph { + let title = if self.update_needed.is_some() { + Paragraph::new(Text::from(Span::styled( + format!("Atuin v{VERSION} - UPGRADE"), + Style::default().add_modifier(Modifier::BOLD).fg(Color::Red), + ))) + } else { + Paragraph::new(Text::from(Span::styled( + format!("Atuin v{VERSION}"), + Style::default().add_modifier(Modifier::BOLD), + ))) + }; + title.alignment(Alignment::Left) + } + + #[allow(clippy::unused_self)] + fn build_help(&self) -> Paragraph { + match self.tab_index { + // search + 0 => Paragraph::new(Text::from(Line::from(vec![ + Span::styled("<esc>", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(": exit"), + Span::raw(", "), + Span::styled("<tab>", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(": edit"), + Span::raw(", "), + Span::styled("<enter>", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(": run"), + Span::raw(", "), + Span::styled("<ctrl-o>", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(": 
inspect"), + ]))), + + 1 => Paragraph::new(Text::from(Line::from(vec![ + Span::styled("<esc>", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(": exit"), + Span::raw(", "), + Span::styled("<ctrl-o>", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(": search"), + Span::raw(", "), + Span::styled("<ctrl-d>", Style::default().add_modifier(Modifier::BOLD)), + Span::raw(": delete"), + ]))), + + _ => unreachable!("invalid tab index"), + } + .style(Style::default().fg(Color::DarkGray)) + .alignment(Alignment::Center) + } + + fn build_stats(&mut self) -> Paragraph { + let stats = Paragraph::new(Text::from(Span::raw(format!( + "history count: {}", + self.history_count, + )))) + .style(Style::default().fg(Color::DarkGray)) + .alignment(Alignment::Right); + stats + } + + fn build_results_list<'a>( + style: StyleState, + results: &'a [History], + keymap_mode: KeymapMode, + now: &'a dyn Fn() -> OffsetDateTime, + ) -> HistoryList<'a> { + let results_list = HistoryList::new( + results, + style.invert, + keymap_mode == KeymapMode::VimNormal, + now, + ); + + if style.compact { + results_list + } else if style.invert { + results_list.block( + Block::default() + .borders(Borders::LEFT | Borders::RIGHT) + .border_type(BorderType::Rounded) + .title(format!("{:─>width$}", "", width = style.inner_width - 2)), + ) + } else { + results_list.block( + Block::default() + .borders(Borders::TOP | Borders::LEFT | Borders::RIGHT) + .border_type(BorderType::Rounded), + ) + } + } + + fn build_input(&mut self, style: StyleState) -> Paragraph { + /// Max width of the UI box showing current mode + const MAX_WIDTH: usize = 14; + let (pref, mode) = if self.switched_search_mode { + (" SRCH:", self.search_mode.as_str()) + } else { + ("", self.search.filter_mode.as_str()) + }; + let mode_width = MAX_WIDTH - pref.len(); + // sanity check to ensure we don't exceed the layout limits + debug_assert!(mode_width >= mode.len(), "mode name '{mode}' is too long!"); + let input = 
format!("[{pref}{mode:^mode_width$}] {}", self.search.input.as_str(),); + let input = Paragraph::new(input); + if style.compact { + input + } else if style.invert { + input.block( + Block::default() + .borders(Borders::LEFT | Borders::RIGHT | Borders::TOP) + .border_type(BorderType::Rounded), + ) + } else { + input.block( + Block::default() + .borders(Borders::LEFT | Borders::RIGHT) + .border_type(BorderType::Rounded) + .title(format!("{:─>width$}", "", width = style.inner_width - 2)), + ) + } + } + + fn build_preview( + &mut self, + results: &[History], + compact: bool, + preview_width: u16, + chunk_width: usize, + ) -> Paragraph { + let selected = self.results_state.selected(); + let command = if results.is_empty() { + String::new() + } else { + use itertools::Itertools as _; + let s = &results[selected].command; + s.split('\n') + .flat_map(|line| { + line.char_indices() + .step_by(preview_width.into()) + .map(|(i, _)| i) + .chain(Some(line.len())) + .tuple_windows() + .map(|(a, b)| (&line[a..b]).escape_control().to_string()) + }) + .join("\n") + }; + let preview = if compact { + Paragraph::new(command).style(Style::default().fg(Color::DarkGray)) + } else { + Paragraph::new(command).block( + Block::default() + .borders(Borders::BOTTOM | Borders::LEFT | Borders::RIGHT) + .border_type(BorderType::Rounded) + .title(format!("{:─>width$}", "", width = chunk_width - 2)), + ) + }; + preview + } +} + +struct Stdout { + stdout: std::io::Stdout, + inline_mode: bool, +} + +impl Stdout { + pub fn new(inline_mode: bool) -> std::io::Result<Self> { + terminal::enable_raw_mode()?; + let mut stdout = stdout(); + + if !inline_mode { + execute!(stdout, terminal::EnterAlternateScreen)?; + } + + execute!( + stdout, + event::EnableMouseCapture, + event::EnableBracketedPaste, + )?; + + #[cfg(not(target_os = "windows"))] + execute!( + stdout, + PushKeyboardEnhancementFlags( + KeyboardEnhancementFlags::DISAMBIGUATE_ESCAPE_CODES + | 
KeyboardEnhancementFlags::REPORT_ALL_KEYS_AS_ESCAPE_CODES + | KeyboardEnhancementFlags::REPORT_ALTERNATE_KEYS + ), + )?; + + Ok(Self { + stdout, + inline_mode, + }) + } +} + +impl Drop for Stdout { + fn drop(&mut self) { + #[cfg(not(target_os = "windows"))] + execute!(self.stdout, PopKeyboardEnhancementFlags).unwrap(); + + if !self.inline_mode { + execute!(self.stdout, terminal::LeaveAlternateScreen).unwrap(); + } + execute!( + self.stdout, + event::DisableMouseCapture, + event::DisableBracketedPaste, + ) + .unwrap(); + + terminal::disable_raw_mode().unwrap(); + } +} + +impl Write for Stdout { + fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { + self.stdout.write(buf) + } + + fn flush(&mut self) -> std::io::Result<()> { + self.stdout.flush() + } +} + +// this is a big blob of horrible! clean it up! +// for now, it works. But it'd be great if it were more easily readable, and +// modular. I'd like to add some more stats and stuff at some point +#[allow(clippy::cast_possible_truncation, clippy::too_many_lines)] +pub async fn history( + query: &[String], + settings: &Settings, + mut db: impl Database, + history_store: &HistoryStore, +) -> Result<String> { + let stdout = Stdout::new(settings.inline_height > 0)?; + let backend = CrosstermBackend::new(stdout); + let mut terminal = Terminal::with_options( + backend, + TerminalOptions { + viewport: if settings.inline_height > 0 { + Viewport::Inline(settings.inline_height) + } else { + Viewport::Fullscreen + }, + }, + )?; + + let mut input = Cursor::from(query.join(" ")); + // Put the cursor at the end of the query by default + input.end(); + + let settings2 = settings.clone(); + let update_needed = tokio::spawn(async move { settings2.needs_update().await }).fuse(); + tokio::pin!(update_needed); + + let context = current_context(); + + let history_count = db.history_count(false).await?; + let search_mode = if settings.shell_up_key_binding { + settings + .search_mode_shell_up_key_binding + 
.unwrap_or(settings.search_mode) + } else { + settings.search_mode + }; + let mut app = State { + history_count, + results_state: ListState::default(), + update_needed: None, + switched_search_mode: false, + search_mode, + tab_index: 0, + search: SearchState { + input, + filter_mode: if settings.workspaces && context.git_root.is_some() { + FilterMode::Workspace + } else if settings.shell_up_key_binding { + settings + .filter_mode_shell_up_key_binding + .unwrap_or(settings.filter_mode) + } else { + settings.filter_mode + }, + context, + }, + engine: engines::engine(search_mode), + results_len: 0, + accept: false, + keymap_mode: match settings.keymap_mode { + KeymapMode::Auto => KeymapMode::Emacs, + value => value, + }, + current_cursor: None, + now: if settings.prefers_reduced_motion { + let now = OffsetDateTime::now_utc(); + Box::new(move || now) + } else { + Box::new(OffsetDateTime::now_utc) + }, + prefix: false, + }; + + app.initialize_keymap_cursor(settings); + + let mut results = app.query_results(&mut db, settings.smart_sort).await?; + + let mut stats: Option<HistoryStats> = None; + let accept; + let result = 'render: loop { + terminal.draw(|f| app.draw(f, &results, stats.clone(), settings))?; + + let initial_input = app.search.input.as_str().to_owned(); + let initial_filter_mode = app.search.filter_mode; + let initial_search_mode = app.search_mode; + + let event_ready = tokio::task::spawn_blocking(|| event::poll(Duration::from_millis(250))); + + tokio::select! { + event_ready = event_ready => { + if event_ready?? { + loop { + match app.handle_input(settings, &event::read()?, &mut std::io::stdout())? 
{ + InputAction::Continue => {}, + InputAction::Delete(index) => { + app.results_len -= 1; + let selected = app.results_state.selected(); + if selected == app.results_len { + app.results_state.select(selected - 1); + } + + let entry = results.remove(index); + + if settings.sync.records { + let (id, _) = history_store.delete(entry.id).await?; + history_store.incremental_build(&db, &[id]).await?; + } else { + db.delete(entry.clone()).await?; + } + + app.tab_index = 0; + }, + InputAction::Redraw => { + terminal.clear()?; + terminal.draw(|f| app.draw(f, &results, stats.clone(), settings))?; + }, + r => { + accept = app.accept; + break 'render r; + }, + } + if !event::poll(Duration::ZERO)? { + break; + } + } + } + } + update_needed = &mut update_needed => { + app.update_needed = update_needed?; + } + } + + if initial_input != app.search.input.as_str() + || initial_filter_mode != app.search.filter_mode + || initial_search_mode != app.search_mode + { + results = app.query_results(&mut db, settings.smart_sort).await?; + } + + stats = if app.tab_index == 0 { + None + } else if !results.is_empty() { + let selected = results[app.results_state.selected()].clone(); + Some(db.stats(&selected).await?) 
+ } else { + None + }; + }; + + app.finalize_keymap_cursor(settings); + + if settings.inline_height > 0 { + terminal.clear()?; + } + + match result { + InputAction::Accept(index) if index < results.len() => { + let mut command = results.swap_remove(index).command; + if accept + && (utils::is_zsh() || utils::is_fish() || utils::is_bash() || utils::is_xonsh()) + { + command = String::from("__atuin_accept__:") + &command; + } + + // index is in bounds so we return that entry + Ok(command) + } + InputAction::ReturnOriginal => Ok(String::new()), + InputAction::Copy(index) => { + let cmd = results.swap_remove(index).command; + set_clipboard(cmd); + Ok(String::new()) + } + InputAction::ReturnQuery | InputAction::Accept(_) => { + // Either: + // * index == RETURN_QUERY, in which case we should return the input + // * out of bounds -> usually implies no selected entry so we return the input + Ok(app.search.input.into_inner()) + } + InputAction::Continue | InputAction::Redraw | InputAction::Delete(_) => { + unreachable!("should have been handled!") + } + } +} + +// cli-clipboard only works on Windows, Mac, and Linux. 
+ +#[cfg(all( + feature = "clipboard", + any(target_os = "windows", target_os = "macos", target_os = "linux") +))] +fn set_clipboard(s: String) { + cli_clipboard::set_contents(s).unwrap(); +} + +#[cfg(not(all( + feature = "clipboard", + any(target_os = "windows", target_os = "macos", target_os = "linux") +)))] +fn set_clipboard(_s: String) {} + +#[cfg(test)] +mod tests { + use atuin_client::history::History; + use atuin_client::settings::Settings; + + use super::State; + + #[test] + fn calc_preview_height_test() { + let settings_preview_auto = Settings { + show_preview_auto: true, + ..Settings::utc() + }; + + let settings_preview_auto_h2 = Settings { + show_preview_auto: true, + max_preview_height: 2, + ..Settings::utc() + }; + + let settings_preview_h4 = Settings { + show_preview_auto: false, + show_preview: true, + max_preview_height: 4, + ..Settings::utc() + }; + + let cmd_60: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("for i in $(seq -w 10); do echo \"item number $i - abcd\"; done") + .cwd("/") + .build() + .into(); + + let cmd_124: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("echo 'Aurea prima sata est aetas, quae vindice nullo, sponte sua, sine lege fidem rectumque colebat. Poena metusque aberant'") + .cwd("/") + .build() + .into(); + + let cmd_200: History = History::capture() + .timestamp(time::OffsetDateTime::now_utc()) + .command("CREATE USER atuin WITH ENCRYPTED PASSWORD 'supersecretpassword'; CREATE DATABASE atuin WITH OWNER = atuin; \\c atuin; REVOKE ALL PRIVILEGES ON SCHEMA public FROM PUBLIC; echo 'All done. 
200 characters'") + .cwd("/") + .build() + .into(); + + let results: Vec<History> = vec![cmd_60, cmd_124, cmd_200]; + + // the selected command does not require a preview + let no_preview = State::calc_preview_height( + &settings_preview_auto, + &results, + 0 as usize, + 0 as usize, + false, + 1, + 80, + ); + // the selected command requires 2 lines + let preview_h2 = State::calc_preview_height( + &settings_preview_auto, + &results, + 1 as usize, + 0 as usize, + false, + 1, + 80, + ); + // the selected command requires 3 lines + let preview_h3 = State::calc_preview_height( + &settings_preview_auto, + &results, + 2 as usize, + 0 as usize, + false, + 1, + 80, + ); + // the selected command requires a preview of 1 line (happens when the command is between preview_width-19 and preview_width) + let preview_one_line = State::calc_preview_height( + &settings_preview_auto, + &results, + 0 as usize, + 0 as usize, + false, + 1, + 66, + ); + // the selected command requires 3 lines, but we have a max preview height limit of 2 + let preview_limit_at_2 = State::calc_preview_height( + &settings_preview_auto_h2, + &results, + 2 as usize, + 0 as usize, + false, + 1, + 80, + ); + // the longest command requires 3 lines + let preview_static_h3 = State::calc_preview_height( + &settings_preview_h4, + &results, + 1 as usize, + 0 as usize, + false, + 1, + 80, + ); + // the longest command requires 10 lines, but we have a max preview height limit of 4 + let preview_static_limit_at_4 = State::calc_preview_height( + &settings_preview_h4, + &results, + 1 as usize, + 0 as usize, + false, + 1, + 20, + ); + + assert_eq!(no_preview, 1); + // 1*2 is the space for the border + assert_eq!(preview_h2, 2 + 1 * 2); + assert_eq!(preview_h3, 3 + 1 * 2); + assert_eq!(preview_one_line, 1 + 1 * 2); + assert_eq!(preview_limit_at_2, 2 + 1 * 2); + assert_eq!(preview_static_h3, 3 + 1 * 2); + assert_eq!(preview_static_limit_at_4, 4 + 1 * 2); + } +} diff --git a/crates/atuin/src/command/client/search/sort.rs 
b/crates/atuin/src/command/client/search/sort.rs new file mode 100644 index 00000000..4465a142 --- /dev/null +++ b/crates/atuin/src/command/client/search/sort.rs @@ -0,0 +1,46 @@ +use atuin_client::history::History; + +type ScoredHistory = (f64, History); + +// Fuzzy search already comes sorted by minspan +// This sorting should be applicable to all search modes, and solve the more "obvious" issues +// first. +// Later on, we can pass in context and do some boosts there too. +pub fn sort(query: &str, input: Vec<History>) -> Vec<History> { + // This can totally be extended. We need to be _careful_ that it's not slow. + // We also need to balance sorting db-side with sorting here. SQLite can do a lot, + // but some things are just much easier/more doable in Rust. + + let mut scored = input + .into_iter() + .map(|h| { + // If history is _prefixed_ with the query, score it more highly + let score = if h.command.starts_with(query) { + 2.0 + } else if h.command.contains(query) { + 1.75 + } else { + 1.0 + }; + + // calculate how long ago the history was, in seconds + let now = time::OffsetDateTime::now_utc().unix_timestamp(); + let time = h.timestamp.unix_timestamp(); + let diff = std::cmp::max(1, now - time); // no /0 please + + // prefer newer history, but not hugely so as to offset the other scoring + // the numbers will get super small over time, but I don't want time to overpower other + // scoring + #[allow(clippy::cast_precision_loss)] + let time_score = 1.0 + (1.0 / diff as f64); + let score = score * time_score; + + (score, h) + }) + .collect::<Vec<ScoredHistory>>(); + + scored.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap().reverse()); + + // Remove the scores and return the history + scored.into_iter().map(|(_, h)| h).collect::<Vec<History>>() +} diff --git a/crates/atuin/src/command/client/stats.rs b/crates/atuin/src/command/client/stats.rs new file mode 100644 index 00000000..7f2e7aa8 --- /dev/null +++ b/crates/atuin/src/command/client/stats.rs @@ -0,0 +1,437 
@@ +use std::collections::{HashMap, HashSet}; + +use clap::Parser; +use crossterm::style::{Color, ResetColor, SetAttribute, SetForegroundColor}; +use eyre::Result; +use interim::parse_date_string; + +use atuin_client::{ + database::{current_context, Database}, + history::History, + settings::Settings, +}; +use time::{Duration, OffsetDateTime, Time}; +use unicode_segmentation::UnicodeSegmentation; + +#[derive(Parser, Debug)] +#[command(infer_subcommands = true)] +pub struct Cmd { + /// Compute statistics for the specified period, leave blank for statistics since the beginning. See https://docs.atuin.sh/reference/stats/ for more details. + period: Vec<String>, + + /// How many top commands to list + #[arg(long, short, default_value = "10")] + count: usize, + + /// The number of consecutive commands to consider + #[arg(long, short, default_value = "1")] + ngram_size: usize, +} + +fn split_at_pipe(command: &str) -> Vec<&str> { + let mut result = vec![]; + let mut quoted = false; + let mut start = 0; + let mut graphemes = UnicodeSegmentation::grapheme_indices(command, true); + + while let Some((i, c)) = graphemes.next() { + let current = i; + match c { + "\"" => { + if command[start..current] != *"\"" { + quoted = !quoted; + } + } + "'" => { + if command[start..current] != *"'" { + quoted = !quoted; + } + } + "\\" => if graphemes.next().is_some() {}, + "|" => { + if !quoted { + if command[start..].starts_with('|') { + start += 1; + } + result.push(&command[start..current]); + start = current; + } + } + _ => {} + } + } + if command[start..].starts_with('|') { + start += 1; + } + result.push(&command[start..]); + result +} + +fn compute_stats( + settings: &Settings, + history: &[History], + count: usize, + ngram_size: usize, +) -> (usize, usize) { + let mut commands = HashSet::<&str>::with_capacity(history.len()); + let mut total_unignored = 0; + let mut prefixes = HashMap::<Vec<&str>, usize>::with_capacity(history.len()); + for i in history { + // just in case it somehow 
has a leading tab or space or something (legacy atuin didn't ignore space prefixes) + let command = i.command.trim(); + let prefix = interesting_command(settings, command); + + if settings.stats.ignored_commands.iter().any(|c| c == prefix) { + continue; + } + + total_unignored += 1; + commands.insert(command); + + split_at_pipe(i.command.trim()) + .iter() + .map(|l| { + let command = l.trim(); + commands.insert(command); + command + }) + .collect::<Vec<_>>() + .windows(ngram_size) + .for_each(|w| { + *prefixes + .entry(w.iter().map(|c| interesting_command(settings, c)).collect()) + .or_default() += 1; + }); + } + + let unique = commands.len(); + let mut top = prefixes.into_iter().collect::<Vec<_>>(); + top.sort_unstable_by_key(|x| std::cmp::Reverse(x.1)); + top.truncate(count); + if top.is_empty() { + println!("No commands found"); + return (0, 0); + } + + let max = top.iter().map(|x| x.1).max().unwrap(); + let num_pad = max.ilog10() as usize + 1; + + // Find the length of the longest command name for each column + let column_widths = top + .iter() + .map(|(commands, _)| commands.iter().map(|c| c.len()).collect::<Vec<usize>>()) + .fold(vec![0; ngram_size], |acc, item| { + acc.iter() + .zip(item.iter()) + .map(|(a, i)| *std::cmp::max(a, i)) + .collect() + }); + + for (command, count) in top { + let gray = SetForegroundColor(Color::Grey); + let bold = SetAttribute(crossterm::style::Attribute::Bold); + + let in_ten = 10 * count / max; + print!("["); + print!("{}", SetForegroundColor(Color::Red)); + for i in 0..in_ten { + if i == 2 { + print!("{}", SetForegroundColor(Color::Yellow)); + } + if i == 5 { + print!("{}", SetForegroundColor(Color::Green)); + } + print!("▮"); + } + for _ in in_ten..10 { + print!(" "); + } + + let formatted_command = command + .iter() + .zip(column_widths.iter()) + .map(|(cmd, width)| format!("{cmd:width$}")) + .collect::<Vec<_>>() + .join(" | "); + + println!("{ResetColor}] {gray}{count:num_pad$}{ResetColor} 
{bold}{formatted_command}{ResetColor}"); + } + println!("Total commands: {total_unignored}"); + println!("Unique commands: {unique}"); + + (total_unignored, unique) +} + +impl Cmd { + pub async fn run(&self, db: &impl Database, settings: &Settings) -> Result<()> { + let context = current_context(); + let words = if self.period.is_empty() { + String::from("all") + } else { + self.period.join(" ") + }; + + let now = OffsetDateTime::now_utc().to_offset(settings.timezone.0); + let last_night = now.replace_time(Time::MIDNIGHT); + + let history = if words.as_str() == "all" { + db.list(&[], &context, None, false, false).await? + } else if words.trim() == "today" { + let start = last_night; + let end = start + Duration::days(1); + db.range(start, end).await? + } else if words.trim() == "month" { + let end = last_night; + let start = end - Duration::days(31); + db.range(start, end).await? + } else if words.trim() == "week" { + let end = last_night; + let start = end - Duration::days(7); + db.range(start, end).await? + } else if words.trim() == "year" { + let end = last_night; + let start = end - Duration::days(365); + db.range(start, end).await? + } else { + let start = parse_date_string(&words, now, settings.dialect.into())?; + let end = start + Duration::days(1); + db.range(start, end).await? 
+ }; + compute_stats(settings, &history, self.count, self.ngram_size); + Ok(()) + } +} + +fn first_non_whitespace(s: &str) -> Option<usize> { + s.char_indices() + // find the first non whitespace char + .find(|(_, c)| !c.is_ascii_whitespace()) + // return the index of that char + .map(|(i, _)| i) +} + +fn first_whitespace(s: &str) -> usize { + s.char_indices() + // find the first whitespace char + .find(|(_, c)| c.is_ascii_whitespace()) + // return the index of that char, (or the max length of the string) + .map_or(s.len(), |(i, _)| i) +} + +fn interesting_command<'a>(settings: &Settings, mut command: &'a str) -> &'a str { + // Sort by length so that we match the longest prefix first + let mut common_prefix = settings.stats.common_prefix.clone(); + common_prefix.sort_by_key(|b| std::cmp::Reverse(b.len())); + + // Trim off the common prefix, if it exists + for p in &common_prefix { + if command.starts_with(p) { + let i = p.len(); + let prefix = &command[..i]; + command = command[i..].trim_start(); + if command.is_empty() { + // no commands following, just use the prefix + return prefix; + } + break; + } + } + + // Sort the common_subcommands by length so that we match the longest subcommand first + let mut common_subcommands = settings.stats.common_subcommands.clone(); + common_subcommands.sort_by_key(|b| std::cmp::Reverse(b.len())); + + // Check for a common subcommand + for p in &common_subcommands { + if command.starts_with(p) { + // if the subcommand is the same length as the command, then we just use the subcommand + if p.len() == command.len() { + return command; + } + // otherwise we need to use the subcommand + the next word + let non_whitespace = first_non_whitespace(&command[p.len()..]).unwrap_or(0); + let j = + p.len() + non_whitespace + first_whitespace(&command[p.len() + non_whitespace..]); + return &command[..j]; + } + } + // Return the first word if there is no subcommand + &command[..first_whitespace(command)] +} + +#[cfg(test)] +mod tests { + use 
atuin_client::history::History; + use atuin_client::settings::Settings; + use time::OffsetDateTime; + + use super::compute_stats; + use super::{interesting_command, split_at_pipe}; + + #[test] + fn ignored_commands() { + let mut settings = Settings::utc(); + settings.stats.ignored_commands.push("cd".to_string()); + + let history = [ + History::import() + .timestamp(OffsetDateTime::now_utc()) + .command("cd foo") + .build() + .into(), + History::import() + .timestamp(OffsetDateTime::now_utc()) + .command("cargo build stuff") + .build() + .into(), + ]; + + let (total, unique) = compute_stats(&settings, &history, 10, 1); + assert_eq!(total, 1); + assert_eq!(unique, 1); + } + + #[test] + fn interesting_commands() { + let settings = Settings::utc(); + + assert_eq!(interesting_command(&settings, "cargo"), "cargo"); + assert_eq!( + interesting_command(&settings, "cargo build foo bar"), + "cargo build" + ); + assert_eq!( + interesting_command(&settings, "sudo cargo build foo bar"), + "cargo build" + ); + assert_eq!(interesting_command(&settings, "sudo"), "sudo"); + } + + // Test with spaces in the common_prefix + #[test] + fn interesting_commands_spaces() { + let mut settings = Settings::utc(); + settings.stats.common_prefix.push("sudo test".to_string()); + + assert_eq!(interesting_command(&settings, "sudo test"), "sudo test"); + assert_eq!(interesting_command(&settings, "sudo test "), "sudo test"); + assert_eq!(interesting_command(&settings, "sudo test foo bar"), "foo"); + assert_eq!( + interesting_command(&settings, "sudo test foo bar"), + "foo" + ); + + // Works with a common_subcommand as well + assert_eq!( + interesting_command(&settings, "sudo test cargo build foo bar"), + "cargo build" + ); + + // We still match on just the sudo prefix + assert_eq!(interesting_command(&settings, "sudo"), "sudo"); + assert_eq!(interesting_command(&settings, "sudo foo"), "foo"); + } + + // Test with spaces in the common_subcommand + #[test] + fn 
interesting_commands_spaces_subcommand() { + let mut settings = Settings::utc(); + settings + .stats + .common_subcommands + .push("cargo build".to_string()); + + assert_eq!(interesting_command(&settings, "cargo build"), "cargo build"); + assert_eq!( + interesting_command(&settings, "cargo build "), + "cargo build" + ); + assert_eq!( + interesting_command(&settings, "cargo build foo bar"), + "cargo build foo" + ); + + // Works with a common_prefix as well + assert_eq!( + interesting_command(&settings, "sudo cargo build foo bar"), + "cargo build foo" + ); + + // We still match on just cargo as a subcommand + assert_eq!(interesting_command(&settings, "cargo"), "cargo"); + assert_eq!(interesting_command(&settings, "cargo foo"), "cargo foo"); + } + + // Test with spaces in the common_prefix and common_subcommand + #[test] + fn interesting_commands_spaces_both() { + let mut settings = Settings::utc(); + settings.stats.common_prefix.push("sudo test".to_string()); + settings + .stats + .common_subcommands + .push("cargo build".to_string()); + + assert_eq!( + interesting_command(&settings, "sudo test cargo build"), + "cargo build" + ); + assert_eq!( + interesting_command(&settings, "sudo test cargo build"), + "cargo build" + ); + assert_eq!( + interesting_command(&settings, "sudo test cargo build "), + "cargo build" + ); + assert_eq!( + interesting_command(&settings, "sudo test cargo build foo bar"), + "cargo build foo" + ); + } + + #[test] + fn split_simple() { + assert_eq!(split_at_pipe("fd | rg"), ["fd ", " rg"]); + } + + #[test] + fn split_multi() { + assert_eq!( + split_at_pipe("kubectl | jq | rg"), + ["kubectl ", " jq ", " rg"] + ); + } + + #[test] + fn split_simple_quoted() { + assert_eq!( + split_at_pipe("foo | bar 'baz {} | quux' | xyzzy"), + ["foo ", " bar 'baz {} | quux' ", " xyzzy"] + ); + } + + #[test] + fn split_multi_quoted() { + assert_eq!( + split_at_pipe("foo | bar 'baz \"{}\" | quux' | xyzzy"), + ["foo ", " bar 'baz \"{}\" | quux' ", " xyzzy"] + ); + } + 
+ #[test] + fn escaped_pipes() { + assert_eq!( + split_at_pipe("foo | bar baz \\| quux"), + ["foo ", " bar baz \\| quux"] + ); + } + + #[test] + fn emoji() { + assert_eq!( + split_at_pipe("git commit -m \"🚀\""), + ["git commit -m \"🚀\""] + ); + } +} diff --git a/crates/atuin/src/command/client/store.rs b/crates/atuin/src/command/client/store.rs new file mode 100644 index 00000000..8e53954d --- /dev/null +++ b/crates/atuin/src/command/client/store.rs @@ -0,0 +1,105 @@ +use clap::Subcommand; +use eyre::Result; + +use atuin_client::{ + database::Database, + record::{sqlite_store::SqliteStore, store::Store}, + settings::Settings, +}; +use time::OffsetDateTime; + +#[cfg(feature = "sync")] +mod push; + +#[cfg(feature = "sync")] +mod pull; + +mod purge; +mod rebuild; +mod rekey; +mod verify; + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + Status, + Rebuild(rebuild::Rebuild), + Rekey(rekey::Rekey), + Purge(purge::Purge), + Verify(verify::Verify), + + #[cfg(feature = "sync")] + Push(push::Push), + + #[cfg(feature = "sync")] + Pull(pull::Pull), +} + +impl Cmd { + pub async fn run( + &self, + settings: &Settings, + database: &dyn Database, + store: SqliteStore, + ) -> Result<()> { + match self { + Self::Status => self.status(store).await, + Self::Rebuild(rebuild) => rebuild.run(settings, store, database).await, + Self::Rekey(rekey) => rekey.run(settings, store).await, + Self::Verify(verify) => verify.run(settings, store).await, + Self::Purge(purge) => purge.run(settings, store).await, + + #[cfg(feature = "sync")] + Self::Push(push) => push.run(settings, store).await, + + #[cfg(feature = "sync")] + Self::Pull(pull) => pull.run(settings, store, database).await, + } + } + + pub async fn status(&self, store: SqliteStore) -> Result<()> { + let host_id = Settings::host_id().expect("failed to get host_id"); + + let status = store.status().await?; + + // TODO: should probs build some data structure and then pretty-print it or smth + for (host, 
st) in &status.hosts { + let host_string = if host == &host_id { + format!("host: {} <- CURRENT HOST", host.0.as_hyphenated()) + } else { + format!("host: {}", host.0.as_hyphenated()) + }; + + println!("{host_string}"); + + for (tag, idx) in st { + println!("\tstore: {tag}"); + + let first = store.first(*host, tag).await?; + let last = store.last(*host, tag).await?; + + println!("\t\tidx: {idx}"); + + if let Some(first) = first { + println!("\t\tfirst: {}", first.id.0.as_hyphenated()); + + let time = + OffsetDateTime::from_unix_timestamp_nanos(i128::from(first.timestamp))?; + println!("\t\t\tcreated: {time}"); + } + + if let Some(last) = last { + println!("\t\tlast: {}", last.id.0.as_hyphenated()); + + let time = + OffsetDateTime::from_unix_timestamp_nanos(i128::from(last.timestamp))?; + println!("\t\t\tcreated: {time}"); + } + } + + println!(); + } + + Ok(()) + } +} diff --git a/crates/atuin/src/command/client/store/pull.rs b/crates/atuin/src/command/client/store/pull.rs new file mode 100644 index 00000000..36450fbf --- /dev/null +++ b/crates/atuin/src/command/client/store/pull.rs @@ -0,0 +1,78 @@ +use clap::Args; +use eyre::Result; + +use atuin_client::{ + database::Database, + record::store::Store, + record::sync::Operation, + record::{sqlite_store::SqliteStore, sync}, + settings::Settings, +}; + +#[derive(Args, Debug)] +pub struct Pull { + /// The tag to push (eg, 'history'). Defaults to all tags + #[arg(long, short)] + pub tag: Option<String>, + + /// Force push records + /// This will first wipe the local store, and then download all records from the remote + #[arg(long, default_value = "false")] + pub force: bool, +} + +impl Pull { + pub async fn run( + &self, + settings: &Settings, + store: SqliteStore, + db: &dyn Database, + ) -> Result<()> { + if self.force { + println!("Forcing local overwrite!"); + println!("Clearing local store"); + + store.delete_all().await?; + } + + // We can actually just use the existing diff/etc to push + // 1. Diff + // 2. 
Get operations + // 3. Filter operations by + // a) are they a download op? + // b) are they for the host/tag we are pushing here? + let (diff, _) = sync::diff(settings, &store).await?; + let operations = sync::operations(diff, &store).await?; + + let operations = operations + .into_iter() + .filter(|op| match op { + // No noops or downloads thx + Operation::Noop { .. } | Operation::Upload { .. } => false, + + // pull, so yes plz to downloads! + Operation::Download { tag, .. } => { + if self.force { + return true; + } + + if let Some(t) = self.tag.clone() { + if t != *tag { + return false; + } + } + + true + } + }) + .collect(); + + let (_, downloaded) = sync::sync_remote(operations, &store, settings).await?; + + println!("Downloaded {} records", downloaded.len()); + + crate::sync::build(settings, &store, db, Some(&downloaded)).await?; + + Ok(()) + } +} diff --git a/crates/atuin/src/command/client/store/purge.rs b/crates/atuin/src/command/client/store/purge.rs new file mode 100644 index 00000000..ad2369ce --- /dev/null +++ b/crates/atuin/src/command/client/store/purge.rs @@ -0,0 +1,26 @@ +use clap::Args; +use eyre::Result; + +use atuin_client::{ + encryption::load_key, + record::{sqlite_store::SqliteStore, store::Store}, + settings::Settings, +}; + +#[derive(Args, Debug)] +pub struct Purge {} + +impl Purge { + pub async fn run(&self, settings: &Settings, store: SqliteStore) -> Result<()> { + println!("Purging local records that cannot be decrypted"); + + let key = load_key(settings)?; + + match store.purge(&key.into()).await { + Ok(()) => println!("Local store purge completed OK"), + Err(e) => println!("Failed to purge local store: {e:?}"), + } + + Ok(()) + } +} diff --git a/crates/atuin/src/command/client/store/push.rs b/crates/atuin/src/command/client/store/push.rs new file mode 100644 index 00000000..17a72f2a --- /dev/null +++ b/crates/atuin/src/command/client/store/push.rs @@ -0,0 +1,96 @@ +use atuin_common::record::HostId; +use clap::Args; +use eyre::Result; 
+use uuid::Uuid; + +use atuin_client::{ + api_client::Client, + record::sync::Operation, + record::{sqlite_store::SqliteStore, sync}, + settings::Settings, +}; + +#[derive(Args, Debug)] +pub struct Push { + /// The tag to push (eg, 'history'). Defaults to all tags + #[arg(long, short)] + pub tag: Option<String>, + + /// The host to push, in the form of a UUID host ID. Defaults to the current host. + #[arg(long)] + pub host: Option<Uuid>, + + /// Force push records + /// This will override both host and tag, to be all hosts and all tags. First clear the remote store, then upload all of the + /// local store + #[arg(long, default_value = "false")] + pub force: bool, +} + +impl Push { + pub async fn run(&self, settings: &Settings, store: SqliteStore) -> Result<()> { + let host_id = Settings::host_id().expect("failed to get host_id"); + + if self.force { + println!("Forcing remote store overwrite!"); + println!("Clearing remote store"); + + let client = Client::new( + &settings.sync_address, + &settings.session_token, + settings.network_connect_timeout, + settings.network_timeout * 10, // we may be deleting a lot of data... so up the + // timeout + ) + .expect("failed to create client"); + + client.delete_store().await?; + } + + // We can actually just use the existing diff/etc to push + // 1. Diff + // 2. Get operations + // 3. Filter operations by + // a) are they an upload op? + // b) are they for the host/tag we are pushing here? + let (diff, _) = sync::diff(settings, &store).await?; + let operations = sync::operations(diff, &store).await?; + + let operations = operations + .into_iter() + .filter(|op| match op { + // No noops or downloads thx + Operation::Noop { .. } | Operation::Download { .. } => false, + + // push, so yes plz to uploads! + Operation::Upload { host, tag, .. 
} => { + if self.force { + return true; + } + + if let Some(h) = self.host { + if HostId(h) != *host { + return false; + } + } else if *host != host_id { + return false; + } + + if let Some(t) = self.tag.clone() { + if t != *tag { + return false; + } + } + + true + } + }) + .collect(); + + let (uploaded, _) = sync::sync_remote(operations, &store, settings).await?; + + println!("Uploaded {uploaded} records"); + + Ok(()) + } +} diff --git a/crates/atuin/src/command/client/store/rebuild.rs b/crates/atuin/src/command/client/store/rebuild.rs new file mode 100644 index 00000000..f99d3247 --- /dev/null +++ b/crates/atuin/src/command/client/store/rebuild.rs @@ -0,0 +1,68 @@ +use atuin_dotfiles::store::AliasStore; +use clap::Args; +use eyre::{bail, Result}; + +use atuin_client::{ + database::Database, encryption, history::store::HistoryStore, + record::sqlite_store::SqliteStore, settings::Settings, +}; + +#[derive(Args, Debug)] +pub struct Rebuild { + pub tag: String, +} + +impl Rebuild { + pub async fn run( + &self, + settings: &Settings, + store: SqliteStore, + database: &dyn Database, + ) -> Result<()> { + // keep it as a string and not an enum atm + // would be super cool to build this dynamically in the future + // eg register handles for rebuilding various tags without having to make this part of the + // binary big + match self.tag.as_str() { + "history" => { + self.rebuild_history(settings, store.clone(), database) + .await?; + } + + "dotfiles" => { + self.rebuild_dotfiles(settings, store.clone()).await?; + } + + tag => bail!("unknown tag: {tag}"), + } + + Ok(()) + } + + async fn rebuild_history( + &self, + settings: &Settings, + store: SqliteStore, + database: &dyn Database, + ) -> Result<()> { + let encryption_key: [u8; 32] = encryption::load_key(settings)?.into(); + + let host_id = Settings::host_id().expect("failed to get host_id"); + let history_store = HistoryStore::new(store, host_id, encryption_key); + + history_store.build(database).await?; + + Ok(()) + } + 
+ async fn rebuild_dotfiles(&self, settings: &Settings, store: SqliteStore) -> Result<()> { + let encryption_key: [u8; 32] = encryption::load_key(settings)?.into(); + + let host_id = Settings::host_id().expect("failed to get host_id"); + let alias_store = AliasStore::new(store, host_id, encryption_key); + + alias_store.build().await?; + + Ok(()) + } +} diff --git a/crates/atuin/src/command/client/store/rekey.rs b/crates/atuin/src/command/client/store/rekey.rs new file mode 100644 index 00000000..3e079a5a --- /dev/null +++ b/crates/atuin/src/command/client/store/rekey.rs @@ -0,0 +1,64 @@ +use clap::Args; +use eyre::{bail, Result}; +use tokio::{fs::File, io::AsyncWriteExt}; + +use atuin_client::{ + encryption::{decode_key, encode_key, generate_encoded_key, load_key, Key}, + record::sqlite_store::SqliteStore, + record::store::Store, + settings::Settings, +}; + +#[derive(Args, Debug)] +pub struct Rekey { + /// The new key to use for encryption. Omit for a randomly-generated key + key: Option<String>, +} + +impl Rekey { + pub async fn run(&self, settings: &Settings, store: SqliteStore) -> Result<()> { + let key = if let Some(key) = self.key.clone() { + println!("Re-encrypting store with specified key"); + + let key = match bip39::Mnemonic::from_phrase(&key, bip39::Language::English) { + Ok(mnemonic) => encode_key(Key::from_slice(mnemonic.entropy()))?, + Err(err) => { + if let Some(err) = err.downcast_ref::<bip39::ErrorKind>() { + match err { + // assume they copied in the base64 key + bip39::ErrorKind::InvalidWord => key, + bip39::ErrorKind::InvalidChecksum => { + bail!("key mnemonic was not valid") + } + bip39::ErrorKind::InvalidKeysize(_) + | bip39::ErrorKind::InvalidWordLength(_) + | bip39::ErrorKind::InvalidEntropyLength(_, _) => { + bail!("key was not the correct length") + } + } + } else { + // unknown error. 
assume they copied the base64 key + key + } + } + }; + + key + } else { + println!("Re-encrypting store with freshly-generated key"); + let (_, encoded) = generate_encoded_key()?; + encoded + }; + + let current_key: [u8; 32] = load_key(settings)?.into(); + let new_key: [u8; 32] = decode_key(key.clone())?.into(); + + store.re_encrypt(¤t_key, &new_key).await?; + + println!("Store rewritten. Saving new key"); + let mut file = File::create(settings.key_path.clone()).await?; + file.write_all(key.as_bytes()).await?; + + Ok(()) + } +} diff --git a/crates/atuin/src/command/client/store/verify.rs b/crates/atuin/src/command/client/store/verify.rs new file mode 100644 index 00000000..84bec96a --- /dev/null +++ b/crates/atuin/src/command/client/store/verify.rs @@ -0,0 +1,26 @@ +use clap::Args; +use eyre::Result; + +use atuin_client::{ + encryption::load_key, + record::{sqlite_store::SqliteStore, store::Store}, + settings::Settings, +}; + +#[derive(Args, Debug)] +pub struct Verify {} + +impl Verify { + pub async fn run(&self, settings: &Settings, store: SqliteStore) -> Result<()> { + println!("Verifying local store can be decrypted with the current key"); + + let key = load_key(settings)?; + + match store.verify(&key.into()).await { + Ok(()) => println!("Local store encryption verified OK"), + Err(e) => println!("Failed to verify local store encryption: {e:?}"), + } + + Ok(()) + } +} diff --git a/crates/atuin/src/command/client/sync.rs b/crates/atuin/src/command/client/sync.rs new file mode 100644 index 00000000..be1bf6d2 --- /dev/null +++ b/crates/atuin/src/command/client/sync.rs @@ -0,0 +1,131 @@ +use clap::Subcommand; +use eyre::{Result, WrapErr}; + +use atuin_client::{ + database::Database, + encryption, + history::store::HistoryStore, + record::{sqlite_store::SqliteStore, store::Store, sync}, + settings::Settings, +}; + +mod status; + +use crate::command::client::account; + +#[derive(Subcommand, Debug)] +#[command(infer_subcommands = true)] +pub enum Cmd { + /// Sync with 
the configured server + Sync { + /// Force re-download everything + #[arg(long, short)] + force: bool, + }, + + /// Login to the configured server + Login(account::login::Cmd), + + /// Log out + Logout, + + /// Register with the configured server + Register(account::register::Cmd), + + /// Print the encryption key for transfer to another machine + Key { + /// Switch to base64 output of the key + #[arg(long)] + base64: bool, + }, + + /// Display the sync status + Status, +} + +impl Cmd { + pub async fn run( + self, + settings: Settings, + db: &impl Database, + store: SqliteStore, + ) -> Result<()> { + match self { + Self::Sync { force } => run(&settings, force, db, store).await, + Self::Login(l) => l.run(&settings, &store).await, + Self::Logout => account::logout::run(&settings), + Self::Register(r) => r.run(&settings).await, + Self::Status => status::run(&settings, db).await, + Self::Key { base64 } => { + use atuin_client::encryption::{encode_key, load_key}; + let key = load_key(&settings).wrap_err("could not load encryption key")?; + + if base64 { + let encode = encode_key(&key).wrap_err("could not encode encryption key")?; + println!("{encode}"); + } else { + let mnemonic = bip39::Mnemonic::from_entropy(&key, bip39::Language::English) + .map_err(|_| eyre::eyre!("invalid key"))?; + println!("{mnemonic}"); + } + Ok(()) + } + } + } +} + +async fn run( + settings: &Settings, + force: bool, + db: &impl Database, + store: SqliteStore, +) -> Result<()> { + if settings.sync.records { + let encryption_key: [u8; 32] = encryption::load_key(settings) + .context("could not load encryption key")? 
+ .into(); + + let host_id = Settings::host_id().expect("failed to get host_id"); + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + + let (uploaded, downloaded) = sync::sync(settings, &store).await?; + + crate::sync::build(settings, &store, db, Some(&downloaded)).await?; + + println!("{uploaded}/{} up/down to record store", downloaded.len()); + + let history_length = db.history_count(true).await?; + let store_history_length = store.len_tag("history").await?; + + #[allow(clippy::cast_sign_loss)] + if history_length as u64 > store_history_length { + println!( + "{history_length} in history index, but {store_history_length} in history store" + ); + println!("Running automatic history store init..."); + + // Internally we use the global filter mode, so this context is ignored. + // don't recurse or loop here. + history_store.init_store(db).await?; + + println!("Re-running sync due to new records locally"); + + // we'll want to run sync once more, as there will now be stuff to upload + let (uploaded, downloaded) = sync::sync(settings, &store).await?; + + crate::sync::build(settings, &store, db, Some(&downloaded)).await?; + + println!("{uploaded}/{} up/down to record store", downloaded.len()); + } + } else { + atuin_client::sync::sync(settings, force, db).await?; + } + + println!( + "Sync complete! 
{} items in history database, force: {}", + db.history_count(true).await?, + force + ); + + Ok(()) +} diff --git a/crates/atuin/src/command/client/sync/status.rs b/crates/atuin/src/command/client/sync/status.rs new file mode 100644 index 00000000..29a1e113 --- /dev/null +++ b/crates/atuin/src/command/client/sync/status.rs @@ -0,0 +1,51 @@ +use std::path::PathBuf; + +use crate::{SHA, VERSION}; +use atuin_client::{api_client, database::Database, settings::Settings}; +use colored::Colorize; +use eyre::Result; + +pub async fn run(settings: &Settings, db: &impl Database) -> Result<()> { + let session_path = settings.session_path.as_str(); + + if !PathBuf::from(session_path).exists() { + println!("You are not logged in to a sync server - cannot show sync status"); + + return Ok(()); + } + + let client = api_client::Client::new( + &settings.sync_address, + &settings.session_token, + settings.network_connect_timeout, + settings.network_timeout, + )?; + + let status = client.status().await?; + let last_sync = Settings::last_sync()?; + + println!("Atuin v{VERSION} - Build rev {SHA}\n"); + + println!("{}", "[Local]".green()); + + if settings.auto_sync { + println!("Sync frequency: {}", settings.sync_frequency); + println!("Last sync: {last_sync}"); + } + + if !settings.sync.records { + let local_count = db.history_count(false).await?; + let deleted_count = db.history_count(true).await? 
- local_count; + + println!("History count: {local_count}"); + println!("Deleted history count: {deleted_count}\n"); + } + + if settings.auto_sync { + println!("{}", "[Remote]".green()); + println!("Address: {}", settings.sync_address); + println!("Username: {}", status.username); + } + + Ok(()) +} diff --git a/crates/atuin/src/command/contributors.rs b/crates/atuin/src/command/contributors.rs new file mode 100644 index 00000000..452fd335 --- /dev/null +++ b/crates/atuin/src/command/contributors.rs @@ -0,0 +1,5 @@ +static CONTRIBUTORS: &str = include_str!("CONTRIBUTORS"); + +pub fn run() { + println!("\n{CONTRIBUTORS}"); +} diff --git a/crates/atuin/src/command/gen_completions.rs b/crates/atuin/src/command/gen_completions.rs new file mode 100644 index 00000000..2872a58a --- /dev/null +++ b/crates/atuin/src/command/gen_completions.rs @@ -0,0 +1,84 @@ +use clap::{CommandFactory, Parser, ValueEnum}; +use clap_complete::{generate, generate_to, Generator, Shell}; +use clap_complete_nushell::Nushell; +use eyre::Result; + +// clap put nushell completions into a separate package due to the maintainers +// being a little less committed to support them. +// This means we have to do a tiny bit of legwork to combine these completions +// into one command. 
+#[derive(Debug, Clone, ValueEnum)] +#[value(rename_all = "lower")] +pub enum GenShell { + Bash, + Elvish, + Fish, + Nushell, + PowerShell, + Zsh, +} + +impl Generator for GenShell { + fn file_name(&self, name: &str) -> String { + match self { + // clap_complete + Self::Bash => Shell::Bash.file_name(name), + Self::Elvish => Shell::Elvish.file_name(name), + Self::Fish => Shell::Fish.file_name(name), + Self::PowerShell => Shell::PowerShell.file_name(name), + Self::Zsh => Shell::Zsh.file_name(name), + + // clap_complete_nushell + Self::Nushell => Nushell.file_name(name), + } + } + + fn generate(&self, cmd: &clap::Command, buf: &mut dyn std::io::prelude::Write) { + match self { + // clap_complete + Self::Bash => Shell::Bash.generate(cmd, buf), + Self::Elvish => Shell::Elvish.generate(cmd, buf), + Self::Fish => Shell::Fish.generate(cmd, buf), + Self::PowerShell => Shell::PowerShell.generate(cmd, buf), + Self::Zsh => Shell::Zsh.generate(cmd, buf), + + // clap_complete_nushell + Self::Nushell => Nushell.generate(cmd, buf), + } + } +} + +#[derive(Debug, Parser)] +pub struct Cmd { + /// Set the shell for generating completions + #[arg(long, short)] + shell: GenShell, + + /// Set the output directory + #[arg(long, short)] + out_dir: Option<String>, +} + +impl Cmd { + pub fn run(self) -> Result<()> { + let Cmd { shell, out_dir } = self; + + let mut cli = crate::Atuin::command(); + + match out_dir { + Some(out_dir) => { + generate_to(shell, &mut cli, env!("CARGO_PKG_NAME"), &out_dir)?; + } + None => { + generate( + shell, + &mut cli, + env!("CARGO_PKG_NAME"), + &mut std::io::stdout(), + ); + } + } + + Ok(()) + } +} diff --git a/crates/atuin/src/command/mod.rs b/crates/atuin/src/command/mod.rs new file mode 100644 index 00000000..09df430e --- /dev/null +++ b/crates/atuin/src/command/mod.rs @@ -0,0 +1,65 @@ +use clap::Subcommand; +use eyre::Result; + +#[cfg(not(windows))] +use rustix::{fs::Mode, process::umask}; + +#[cfg(feature = "client")] +mod client; + +#[cfg(feature = 
"server")] +mod server; + +mod contributors; + +mod gen_completions; + +#[derive(Subcommand)] +#[command(infer_subcommands = true)] +pub enum AtuinCmd { + #[cfg(feature = "client")] + #[command(flatten)] + Client(client::Cmd), + + /// Start an atuin server + #[cfg(feature = "server")] + #[command(subcommand)] + Server(server::Cmd), + + /// Generate a UUID + Uuid, + + Contributors, + + /// Generate shell completions + GenCompletions(gen_completions::Cmd), +} + +impl AtuinCmd { + pub fn run(self) -> Result<()> { + #[cfg(not(windows))] + { + // set umask before we potentially open/create files + // or in other words, 077. Do not allow any access to any other user + let mode = Mode::RWXG | Mode::RWXO; + umask(mode); + } + + match self { + #[cfg(feature = "client")] + Self::Client(client) => client.run(), + + #[cfg(feature = "server")] + Self::Server(server) => server.run(), + Self::Contributors => { + contributors::run(); + Ok(()) + } + Self::Uuid => { + println!("{}", atuin_common::utils::uuid_v7().as_simple()); + Ok(()) + } + Self::GenCompletions(gen_completions) => gen_completions.run(), + } + } +} diff --git a/crates/atuin/src/command/server.rs b/crates/atuin/src/command/server.rs new file mode 100644 index 00000000..d45d6ef8 --- /dev/null +++ b/crates/atuin/src/command/server.rs @@ -0,0 +1,61 @@ +use std::net::SocketAddr; + +use atuin_server_postgres::Postgres; +use tracing_subscriber::{fmt, prelude::*, EnvFilter}; + +use clap::Parser; +use eyre::{Context, Result}; + +use atuin_server::{example_config, launch, launch_metrics_server, Settings}; + +#[derive(Parser, Debug)] +#[clap(infer_subcommands = true)] +pub enum Cmd { + /// Start the server + Start { + /// The host address to bind + #[clap(long)] + host: Option<String>, + + /// The port to bind + #[clap(long, short)] + port: Option<u16>, + }, + + /// Print server example configuration + DefaultConfig, +} + +impl Cmd { + #[tokio::main] + pub async fn run(self) -> Result<()> { + tracing_subscriber::registry() + 
.with(fmt::layer()) + .with(EnvFilter::from_default_env()) + .init(); + + tracing::trace!(command = ?self, "server command"); + + match self { + Self::Start { host, port } => { + let settings = Settings::new().wrap_err("could not load server settings")?; + let host = host.as_ref().unwrap_or(&settings.host).clone(); + let port = port.unwrap_or(settings.port); + let addr = SocketAddr::new(host.parse()?, port); + + if settings.metrics.enable { + tokio::spawn(launch_metrics_server( + settings.metrics.host.clone(), + settings.metrics.port, + )); + } + + launch::<Postgres>(settings, addr).await + } + Self::DefaultConfig => { + println!("{}", example_config()); + Ok(()) + } + } + } +} diff --git a/crates/atuin/src/main.rs b/crates/atuin/src/main.rs new file mode 100644 index 00000000..16a80b10 --- /dev/null +++ b/crates/atuin/src/main.rs @@ -0,0 +1,47 @@ +#![warn(clippy::pedantic, clippy::nursery)] +#![allow(clippy::use_self, clippy::missing_const_for_fn)] // not 100% reliable + +use clap::Parser; +use eyre::Result; + +use command::AtuinCmd; + +mod command; + +#[cfg(feature = "sync")] +mod sync; + +const VERSION: &str = env!("CARGO_PKG_VERSION"); +const SHA: &str = env!("GIT_HASH"); + +static HELP_TEMPLATE: &str = "\ +{before-help}{name} {version} +{author} +{about} + +{usage-heading} + {usage} + +{all-args}{after-help}"; + +/// Magical shell history +#[derive(Parser)] +#[command( + author = "Ellie Huxtable <e@elm.sh>", + version = VERSION, + help_template(HELP_TEMPLATE), +)] +struct Atuin { + #[command(subcommand)] + atuin: AtuinCmd, +} + +impl Atuin { + fn run(self) -> Result<()> { + self.atuin.run() + } +} + +fn main() -> Result<()> { + Atuin::parse().run() +} diff --git a/crates/atuin/src/shell/atuin.bash b/crates/atuin/src/shell/atuin.bash new file mode 100644 index 00000000..8eda0a6f --- /dev/null +++ b/crates/atuin/src/shell/atuin.bash @@ -0,0 +1,342 @@ +# Include guard +if [[ ${__atuin_initialized-} == true ]]; then + false +elif [[ $- != *i* ]]; then + # Enable 
only in interactive shells + false +elif ((BASH_VERSINFO[0] < 3 || BASH_VERSINFO[0] == 3 && BASH_VERSINFO[1] < 1)); then + # Require bash >= 3.1 + [[ -t 2 ]] && printf 'atuin: requires bash >= 3.1 for the integration.\n' >&2 + false +else # (include guard) beginning of main content +#------------------------------------------------------------------------------ +__atuin_initialized=true + +ATUIN_SESSION=$(atuin uuid) +ATUIN_STTY=$(stty -g) +export ATUIN_SESSION +ATUIN_HISTORY_ID="" + +export ATUIN_PREEXEC_BACKEND=$SHLVL:none +__atuin_update_preexec_backend() { + if [[ ${BLE_ATTACHED-} ]]; then + ATUIN_PREEXEC_BACKEND=$SHLVL:blesh-${BLE_VERSION-} + elif [[ ${bash_preexec_imported-} ]]; then + ATUIN_PREEXEC_BACKEND=$SHLVL:bash-preexec + elif [[ ${__bp_imported-} ]]; then + ATUIN_PREEXEC_BACKEND="$SHLVL:bash-preexec (old)" + else + ATUIN_PREEXEC_BACKEND=$SHLVL:unknown + fi +} + +__atuin_preexec() { + # Workaround for old versions of bash-preexec + if [[ ! ${BLE_ATTACHED-} ]]; then + # In older versions of bash-preexec, the preexec hook may be called + # even for the commands run by keybindings. There is no general and + # robust way to detect the command for keybindings, but at least we + # want to exclude Atuin's keybindings. When the preexec hook is called + # for a keybinding, the preexec hook for the user command will not + # fire, so we instead set a fake ATUIN_HISTORY_ID here to notify + # __atuin_precmd of this failure. + if [[ $BASH_COMMAND == '__atuin_history'* && $BASH_COMMAND != "$1" ]]; then + ATUIN_HISTORY_ID=__bash_preexec_failure__ + return 0 + fi + fi + + # Note: We update ATUIN_PREEXEC_BACKEND on every preexec because blesh's + # attaching state can dynamically change. + __atuin_update_preexec_backend + + local id + id=$(atuin history start -- "$1") + export ATUIN_HISTORY_ID=$id + __atuin_preexec_time=${EPOCHREALTIME-} +} + +__atuin_precmd() { + local EXIT=$? __atuin_precmd_time=${EPOCHREALTIME-} + + [[ ! 
$ATUIN_HISTORY_ID ]] && return + + # If the previous preexec hook failed, we manually call __atuin_preexec + if [[ $ATUIN_HISTORY_ID == __bash_preexec_failure__ ]]; then + # This is the command extraction code taken from bash-preexec + local previous_command + previous_command=$( + export LC_ALL=C HISTTIMEFORMAT='' + builtin history 1 | sed '1 s/^ *[0-9][0-9]*[* ] //' + ) + __atuin_preexec "$previous_command" + fi + + local duration="" + # shellcheck disable=SC2154,SC2309 + if [[ ${BLE_ATTACHED-} && ${_ble_exec_time_ata-} ]]; then + # With ble.sh, we utilize the shell variable `_ble_exec_time_ata` + # recorded by ble.sh. It is more accurate than the measurements by + # Atuin, which includes the spawn cost of Atuin. ble.sh uses the + # special shell variable `EPOCHREALTIME` in bash >= 5.0 with the + # microsecond resolution, or the builtin `time` in bash < 5.0 with the + # millisecond resolution. + duration=${_ble_exec_time_ata}000 + elif ((BASH_VERSINFO[0] >= 5)); then + # We calculate the high-resolution duration based on EPOCHREALTIME + # (bash >= 5.0) recorded by precmd/preexec, though it might not be as + # accurate as `_ble_exec_time_ata` provided by ble.sh because it + # includes the extra time of the precmd/preexec handling. Since Bash + # does not offer floating-point arithmetic, we remove the non-digit + # characters and perform the integral arithmetic. The fraction part of + # EPOCHREALTIME is fixed to have 6 digits in Bash. We remove all the + # non-digit characters because the decimal point is not necessarily a + # period depending on the locale. 
+ duration=$((${__atuin_precmd_time//[!0-9]} - ${__atuin_preexec_time//[!0-9]})) + if ((duration >= 0)); then + duration=${duration}000 + else + duration="" # clear the result on overflow + fi + fi + + (ATUIN_LOG=error atuin history end --exit "$EXIT" ${duration:+"--duration=$duration"} -- "$ATUIN_HISTORY_ID" &) >/dev/null 2>&1 + export ATUIN_HISTORY_ID="" +} + +__atuin_set_ret_value() { + return ${1:+"$1"} +} + +# The shell function `__atuin_evaluate_prompt` evaluates prompt sequences in +# $PS1. We switch the implementation of the shell function +# `__atuin_evaluate_prompt` based on the Bash version because the expansion +# ${PS1@P} is only available in bash >= 4.4. +if ((BASH_VERSINFO[0] >= 5 || BASH_VERSINFO[0] == 4 && BASH_VERSINFO[1] >= 4)); then + __atuin_evaluate_prompt() { + __atuin_set_ret_value "${__bp_last_ret_value-}" "${__bp_last_argument_prev_command-}" + __atuin_prompt=${PS1@P} + + # Note: Strip the control characters ^A (\001) and ^B (\002), which + # Bash internally uses to enclose the escape sequences. They are + # produced by '\[' and '\]', respectively, in $PS1 and used to tell + # Bash that the strings inbetween do not contribute to the prompt + # width. After the prompt width calculation, Bash strips those control + # characters before outputting it to the terminal. We here strip these + # characters following Bash's behavior. + __atuin_prompt=${__atuin_prompt//[$'\001\002']} + + # Count the number of newlines contained in $__atuin_prompt + __atuin_prompt_offset=${__atuin_prompt//[!$'\n']} + __atuin_prompt_offset=${#__atuin_prompt_offset} + } +else + __atuin_evaluate_prompt() { + __atuin_prompt='$ ' + __atuin_prompt_offset=0 + } +fi + +# The shell function `__atuin_clear_prompt N` outputs terminal control +# sequences to clear the contents of the current and N previous lines. After +# clearing, the cursor is placed at the beginning of the N-th previous line. +__atuin_clear_prompt_cache=() +__atuin_clear_prompt() { + local offset=$1 + if [[ ! 
${__atuin_clear_prompt_cache[offset]+set} ]]; then + if [[ ! ${__atuin_clear_prompt_cache[0]+set} ]]; then + __atuin_clear_prompt_cache[0]=$'\r'$(tput el 2>/dev/null || tput ce 2>/dev/null) + fi + if ((offset > 0)); then + __atuin_clear_prompt_cache[offset]=${__atuin_clear_prompt_cache[0]}$( + tput cuu "$offset" 2>/dev/null || tput UP "$offset" 2>/dev/null + tput dl "$offset" 2>/dev/null || tput DL "$offset" 2>/dev/null + tput il "$offset" 2>/dev/null || tput AL "$offset" 2>/dev/null + ) + fi + fi + printf '%s' "${__atuin_clear_prompt_cache[offset]}" +} + +__atuin_accept_line() { + local __atuin_command=$1 + + # Reprint the prompt, accounting for multiple lines + local __atuin_prompt __atuin_prompt_offset + __atuin_evaluate_prompt + __atuin_clear_prompt "$__atuin_prompt_offset" + printf '%s\n' "$__atuin_prompt$__atuin_command" + + # Add it to the bash history + history -s "$__atuin_command" + + # Assuming bash-preexec + # Invoke every function in the preexec array + local __atuin_preexec_function + local __atuin_preexec_function_ret_value + local __atuin_preexec_ret_value=0 + for __atuin_preexec_function in "${preexec_functions[@]:-}"; do + if type -t "$__atuin_preexec_function" 1>/dev/null; then + __atuin_set_ret_value "${__bp_last_ret_value:-}" + "$__atuin_preexec_function" "$__atuin_command" + __atuin_preexec_function_ret_value=$? + if [[ $__atuin_preexec_function_ret_value != 0 ]]; then + __atuin_preexec_ret_value=$__atuin_preexec_function_ret_value + fi + fi + done + + # If extdebug is turned on and any preexec function returns non-zero + # exit status, we do not run the user command. + if ! { shopt -q extdebug && ((__atuin_preexec_ret_value)); }; then + # Juggle the terminal settings so that the command can be interacted + # with + local __atuin_stty_backup + __atuin_stty_backup=$(stty -g) + stty "$ATUIN_STTY" + + # Execute the command. Note: We need to record $? 
and $_ after the + # user command within the same call of "eval" because $_ is otherwise + # overwritten by the last argument of "eval". + __atuin_set_ret_value "${__bp_last_ret_value-}" "${__bp_last_argument_prev_command-}" + eval -- "$__atuin_command"$'\n__bp_last_ret_value=$? __bp_last_argument_prev_command=$_' + + stty "$__atuin_stty_backup" + fi + + # Execute preprompt commands + local __atuin_prompt_command + for __atuin_prompt_command in "${PROMPT_COMMAND[@]}"; do + __atuin_set_ret_value "${__bp_last_ret_value-}" "${__bp_last_argument_prev_command-}" + eval -- "$__atuin_prompt_command" + done + # Bash will redraw only the line with the prompt after we finish, + # so to work for a multiline prompt we need to print it ourselves, + # then go to the beginning of the last line. + __atuin_evaluate_prompt + printf '%s' "$__atuin_prompt" + __atuin_clear_prompt 0 +} + +__atuin_history() { + # Default action of the up key: When this function is called with the first + # argument `--shell-up-key-binding`, we perform Atuin's history search only + # when the up key is supposed to cause the history movement in the original + # binding. We do this only for ble.sh because the up key always invokes + # the history movement in the plain Bash. + if [[ ${BLE_ATTACHED-} && ${1-} == --shell-up-key-binding ]]; then + # When the current cursor position is not in the first line, the up key + # should move the cursor to the previous line. While the selection is + # performed, the up key should not start the history search. + # shellcheck disable=SC2154 # Note: these variables are set by ble.sh + if [[ ${_ble_edit_str::_ble_edit_ind} == *$'\n'* || $_ble_edit_mark_active ]]; then + ble/widget/@nomarked backward-line + local status=$? + READLINE_LINE=$_ble_edit_str + READLINE_POINT=$_ble_edit_ind + READLINE_MARK=$_ble_edit_mark + return "$status" + fi + fi + + # READLINE_LINE and READLINE_POINT are only supported by bash >= 4.0 or + # ble.sh. 
When it is not supported, we localize them to suppress strange + # behaviors. + [[ ${BLE_ATTACHED-} ]] || ((BASH_VERSINFO[0] >= 4)) || + local READLINE_LINE="" READLINE_POINT=0 + + local __atuin_output + __atuin_output=$(ATUIN_SHELL_BASH=t ATUIN_LOG=error ATUIN_QUERY="$READLINE_LINE" atuin search "$@" -i 3>&1 1>&2 2>&3) + + # We do nothing when the search is canceled. + [[ $__atuin_output ]] || return 0 + + if [[ $__atuin_output == __atuin_accept__:* ]]; then + __atuin_output=${__atuin_output#__atuin_accept__:} + + if [[ ${BLE_ATTACHED-} ]]; then + ble-edit/content/reset-and-check-dirty "$__atuin_output" + ble/widget/accept-line + else + __atuin_accept_line "$__atuin_output" + fi + + READLINE_LINE="" + READLINE_POINT=${#READLINE_LINE} + else + READLINE_LINE=$__atuin_output + READLINE_POINT=${#READLINE_LINE} + fi +} + +# shellcheck disable=SC2154 +if [[ ${BLE_VERSION-} ]] && ((_ble_version >= 400)); then + ble-import contrib/integration/bash-preexec + + # Define and register an autosuggestion source for ble.sh's auto-complete. + # If you'd like to overwrite this, define the same name of shell function + # after the $(atuin init bash) line in your .bashrc. If you do not need + # the auto-complete source by atuin, please add the following code to + # remove the entry after the $(atuin init bash) line in your .bashrc: + # + # ble/util/import/eval-after-load core-complete ' + # ble/array#remove _ble_complete_auto_source atuin-history' + # + function ble/complete/auto-complete/source:atuin-history { + local suggestion + suggestion=$(ATUIN_QUERY="$_ble_edit_str" atuin search --cmd-only --limit 1 --search-mode prefix) + [[ $suggestion == "$_ble_edit_str"?* ]] || return 1 + ble/complete/auto-complete/enter h 0 "${suggestion:${#_ble_edit_str}}" '' "$suggestion" + } + ble/util/import/eval-after-load core-complete ' + ble/array#unshift _ble_complete_auto_source atuin-history' + + # @env BLE_SESSION_ID: `atuin doctor` references the environment variable + # BLE_SESSION_ID. 
We explicitly export the variable because it was not + # exported in older versions of ble.sh. + [[ ${BLE_SESSION_ID-} ]] && export BLE_SESSION_ID +fi +precmd_functions+=(__atuin_precmd) +preexec_functions+=(__atuin_preexec) + +# shellcheck disable=SC2154 +if [[ $__atuin_bind_ctrl_r == true ]]; then + # Note: We do not overwrite [C-r] in the vi-command keymap for Bash because + # we do not want to overwrite "redo", which is already bound to [C-r] in + # the vi_nmap keymap in ble.sh. + bind -m emacs -x '"\C-r": __atuin_history --keymap-mode=emacs' + bind -m vi-insert -x '"\C-r": __atuin_history --keymap-mode=vim-insert' + bind -m vi-command -x '"/": __atuin_history --keymap-mode=emacs' +fi + +# shellcheck disable=SC2154 +if [[ $__atuin_bind_up_arrow == true ]]; then + if ((BASH_VERSINFO[0] > 4 || BASH_VERSINFO[0] == 4 && BASH_VERSINFO[1] >= 3)); then + bind -m emacs -x '"\e[A": __atuin_history --shell-up-key-binding --keymap-mode=emacs' + bind -m emacs -x '"\eOA": __atuin_history --shell-up-key-binding --keymap-mode=emacs' + bind -m vi-insert -x '"\e[A": __atuin_history --shell-up-key-binding --keymap-mode=vim-insert' + bind -m vi-insert -x '"\eOA": __atuin_history --shell-up-key-binding --keymap-mode=vim-insert' + bind -m vi-command -x '"\e[A": __atuin_history --shell-up-key-binding --keymap-mode=vim-normal' + bind -m vi-command -x '"\eOA": __atuin_history --shell-up-key-binding --keymap-mode=vim-normal' + bind -m vi-command -x '"k": __atuin_history --shell-up-key-binding --keymap-mode=vim-normal' + else + # In bash < 4.3, "bind -x" cannot bind a shell command to a keyseq + # having more than two bytes. To work around this, we first translate + # the keyseqs to the two-byte sequence \C-x\C-p (which is not used by + # default) using string macros and run the shell command through the + # keybinding to \C-x\C-p. 
+ bind -m emacs -x '"\C-x\C-p": __atuin_history --shell-up-key-binding --keymap-mode=emacs' + bind -m emacs '"\e[A": "\C-x\C-p"' + bind -m emacs '"\eOA": "\C-x\C-p"' + bind -m vi-insert -x '"\C-x\C-p": __atuin_history --shell-up-key-binding --keymap-mode=vim-insert' + bind -m vi-insert -x '"\e[A": "\C-x\C-p"' + bind -m vi-insert -x '"\eOA": "\C-x\C-p"' + bind -m vi-command -x '"\C-x\C-p": __atuin_history --shell-up-key-binding --keymap-mode=vim-normal' + bind -m vi-command -x '"\e[A": "\C-x\C-p"' + bind -m vi-command -x '"\eOA": "\C-x\C-p"' + bind -m vi-command -x '"k": "\C-x\C-p"' + fi +fi + +#------------------------------------------------------------------------------ +fi # (include guard) end of main content diff --git a/crates/atuin/src/shell/atuin.fish b/crates/atuin/src/shell/atuin.fish new file mode 100644 index 00000000..6ef1e2d2 --- /dev/null +++ b/crates/atuin/src/shell/atuin.fish @@ -0,0 +1,71 @@ +set -gx ATUIN_SESSION (atuin uuid) +set --erase ATUIN_HISTORY_ID + +function _atuin_preexec --on-event fish_preexec + if not test -n "$fish_private_mode" + set -g ATUIN_HISTORY_ID (atuin history start -- "$argv[1]") + end +end + +function _atuin_postexec --on-event fish_postexec + set -l s $status + + if test -n "$ATUIN_HISTORY_ID" + ATUIN_LOG=error atuin history end --exit $s -- $ATUIN_HISTORY_ID &>/dev/null & + disown + end + + set --erase ATUIN_HISTORY_ID +end + +function _atuin_search + set -l keymap_mode + switch $fish_key_bindings + case fish_vi_key_bindings + switch $fish_bind_mode + case default + set keymap_mode vim-normal + case insert + set keymap_mode vim-insert + end + case '*' + set keymap_mode emacs + end + + # In fish 3.4 and above we can use `"$(some command)"` to keep multiple lines separate; + # but to support fish 3.3 we need to use `(some command | string collect)`. 
+ # https://fishshell.com/docs/current/relnotes.html#id24 (fish 3.4 "Notable improvements and fixes") + set -l ATUIN_H (ATUIN_SHELL_FISH=t ATUIN_LOG=error ATUIN_QUERY=(commandline -b) atuin search --keymap-mode=$keymap_mode $argv -i 3>&1 1>&2 2>&3 | string collect) + + if test -n "$ATUIN_H" + if string match --quiet '__atuin_accept__:*' "$ATUIN_H" + set -l ATUIN_HIST (string replace "__atuin_accept__:" "" -- "$ATUIN_H" | string collect) + commandline -r "$ATUIN_HIST" + commandline -f repaint + commandline -f execute + return + else + commandline -r "$ATUIN_H" + end + end + + commandline -f repaint +end + +function _atuin_bind_up + # Fallback to fish's builtin up-or-search if we're in search or paging mode + if commandline --search-mode; or commandline --paging-mode + up-or-search + return + end + + # Only invoke atuin if we're on the top line of the command + set -l lineno (commandline --line) + + switch $lineno + case 1 + _atuin_search --shell-up-key-binding + case '*' + up-or-search + end +end diff --git a/crates/atuin/src/shell/atuin.nu b/crates/atuin/src/shell/atuin.nu new file mode 100644 index 00000000..102c6dbe --- /dev/null +++ b/crates/atuin/src/shell/atuin.nu @@ -0,0 +1,60 @@ +# Source this in your ~/.config/nushell/config.nu +$env.ATUIN_SESSION = (atuin uuid) +hide-env -i ATUIN_HISTORY_ID + +# Magic token to make sure we don't record commands run by keybindings +let ATUIN_KEYBINDING_TOKEN = $"# (random uuid)" + +let _atuin_pre_execution = {|| + if ($nu | get -i history-enabled) == false { + return + } + let cmd = (commandline) + if ($cmd | is-empty) { + return + } + if not ($cmd | str starts-with $ATUIN_KEYBINDING_TOKEN) { + $env.ATUIN_HISTORY_ID = (atuin history start -- $cmd) + } +} + +let _atuin_pre_prompt = {|| + let last_exit = $env.LAST_EXIT_CODE + if 'ATUIN_HISTORY_ID' not-in $env { + return + } + with-env { ATUIN_LOG: error } { + do { atuin history end $'--exit=($last_exit)' -- $env.ATUIN_HISTORY_ID } | complete + + } + hide-env ATUIN_HISTORY_ID 
+} + +def _atuin_search_cmd [...flags: string] { + let nu_version = ($env.NU_VERSION | split row '.' | each { || into int }) + [ + $ATUIN_KEYBINDING_TOKEN, + ([ + `with-env { ATUIN_LOG: error, ATUIN_QUERY: (commandline) } {`, + (if $nu_version.0 <= 0 and $nu_version.1 <= 90 { 'commandline' } else { 'commandline edit' }), + (if $nu_version.1 >= 92 { '(run-external atuin search' } else { '(run-external --redirect-stderr atuin search' }), + ($flags | append [--interactive] | each {|e| $'"($e)"'}), + (if $nu_version.1 >= 92 { ' e>| str trim)' } else {' | complete | $in.stderr | str substring ..-1)'}), + `}`, + ] | flatten | str join ' '), + ] | str join "\n" +} + +$env.config = ($env | default {} config).config +$env.config = ($env.config | default {} hooks) +$env.config = ( + $env.config | upsert hooks ( + $env.config.hooks + | upsert pre_execution ( + $env.config.hooks | get -i pre_execution | default [] | append $_atuin_pre_execution) + | upsert pre_prompt ( + $env.config.hooks | get -i pre_prompt | default [] | append $_atuin_pre_prompt) + ) +) + +$env.config = ($env.config | default [] keybindings) diff --git a/crates/atuin/src/shell/atuin.xsh b/crates/atuin/src/shell/atuin.xsh new file mode 100644 index 00000000..d504c627 --- /dev/null +++ b/crates/atuin/src/shell/atuin.xsh @@ -0,0 +1,80 @@ +import subprocess + +from prompt_toolkit.application.current import get_app +from prompt_toolkit.filters import Condition +from prompt_toolkit.keys import Keys + + +$ATUIN_SESSION=$(atuin uuid).rstrip('\n') + +@events.on_precommand +def _atuin_precommand(cmd: str): + cmd = cmd.rstrip("\n") + $ATUIN_HISTORY_ID = $(atuin history start -- @(cmd)).rstrip("\n") + + +@events.on_postcommand +def _atuin_postcommand(cmd: str, rtn: int, out, ts): + if "ATUIN_HISTORY_ID" not in ${...}: + return + + duration = ts[1] - ts[0] + # Duration is float representing seconds, but atuin expects integer of nanoseconds + nanos = round(duration * 10 ** 9) + with ${...}.swap(ATUIN_LOG="error"): + # 
This causes the entire .xonshrc to be re-executed, which is incredibly slow + # This happens when using a subshell and using output redirection at the same time + # For more details, see https://github.com/xonsh/xonsh/issues/5224 + # (atuin history end --exit @(rtn) -- $ATUIN_HISTORY_ID &) > /dev/null 2>&1 + atuin history end --exit @(rtn) --duration @(nanos) -- $ATUIN_HISTORY_ID > /dev/null 2>&1 + del $ATUIN_HISTORY_ID + + +def _search(event, extra_args: list[str]): + buffer = event.current_buffer + cmd = ["atuin", "search", "--interactive", *extra_args] + # We need to explicitly pass in xonsh env, in case user has set XDG_HOME or something else that matters + env = ${...}.detype() + env["ATUIN_SHELL_XONSH"] = "t" + env["ATUIN_QUERY"] = buffer.text + + p = subprocess.run(cmd, stderr=subprocess.PIPE, encoding="utf-8", env=env) + result = p.stderr.rstrip("\n") + # redraw prompt - necessary if atuin is configured to run inline, rather than fullscreen + event.cli.renderer.erase() + + if not result: + return + + buffer.reset() + if result.startswith("__atuin_accept__:"): + buffer.insert_text(result[17:]) + buffer.validate_and_handle() + else: + buffer.insert_text(result) + + +@events.on_ptk_create +def _custom_keybindings(bindings, **kw): + if _ATUIN_BIND_CTRL_R: + @bindings.add(Keys.ControlR) + def r_search(event): + _search(event, extra_args=[]) + + if _ATUIN_BIND_UP_ARROW: + @Condition + def should_search(): + buffer = get_app().current_buffer + # disable keybind when there is an active completion, so + # that up arrow can be used to navigate completion menu + if buffer.complete_state is not None: + return False + # similarly, disable when buffer text contains multiple lines + if '\n' in buffer.text: + return False + + return True + + @bindings.add(Keys.Up, filter=should_search) + def up_search(event): + _search(event, extra_args=["--shell-up-key-binding"]) diff --git a/crates/atuin/src/shell/atuin.zsh b/crates/atuin/src/shell/atuin.zsh new file mode 100644 index 
00000000..d580f704 --- /dev/null +++ b/crates/atuin/src/shell/atuin.zsh @@ -0,0 +1,108 @@ +# shellcheck disable=SC2034,SC2153,SC2086,SC2155 + +# Above line is because shellcheck doesn't support zsh, per +# https://github.com/koalaman/shellcheck/wiki/SC1071, and the ignore: param in +# ludeeus/action-shellcheck only supports _directories_, not _files_. So +# instead, we manually add any error the shellcheck step finds in the file to +# the above line ... + +# Source this in your ~/.zshrc +autoload -U add-zsh-hook + +zmodload zsh/datetime 2>/dev/null + +# If zsh-autosuggestions is installed, configure it to use Atuin's search. If +# you'd like to override this, then add your config after the $(atuin init zsh) +# in your .zshrc +_zsh_autosuggest_strategy_atuin() { + suggestion=$(ATUIN_QUERY="$1" atuin search --cmd-only --limit 1 --search-mode prefix) +} + +if [ -n "${ZSH_AUTOSUGGEST_STRATEGY:-}" ]; then + ZSH_AUTOSUGGEST_STRATEGY=("atuin" "${ZSH_AUTOSUGGEST_STRATEGY[@]}") +else + ZSH_AUTOSUGGEST_STRATEGY=("atuin") +fi + +export ATUIN_SESSION=$(atuin uuid) +ATUIN_HISTORY_ID="" + +_atuin_preexec() { + local id + id=$(atuin history start -- "$1") + export ATUIN_HISTORY_ID="$id" + __atuin_preexec_time=${EPOCHREALTIME-} +} + +_atuin_precmd() { + local EXIT="$?" 
__atuin_precmd_time=${EPOCHREALTIME-} + + [[ -z "${ATUIN_HISTORY_ID:-}" ]] && return + + local duration="" + if [[ -n $__atuin_preexec_time && -n $__atuin_precmd_time ]]; then + printf -v duration %.0f $(((__atuin_precmd_time - __atuin_preexec_time) * 1000000000)) + fi + + (ATUIN_LOG=error atuin history end --exit $EXIT ${duration:+--duration=$duration} -- $ATUIN_HISTORY_ID &) >/dev/null 2>&1 + export ATUIN_HISTORY_ID="" +} + +_atuin_search() { + emulate -L zsh + zle -I + + # swap stderr and stdout, so that the tui stuff works + # TODO: not this + local output + # shellcheck disable=SC2048 + output=$(ATUIN_SHELL_ZSH=t ATUIN_LOG=error ATUIN_QUERY=$BUFFER atuin search $* -i 3>&1 1>&2 2>&3) + + zle reset-prompt + + if [[ -n $output ]]; then + RBUFFER="" + LBUFFER=$output + + if [[ $LBUFFER == __atuin_accept__:* ]] + then + LBUFFER=${LBUFFER#__atuin_accept__:} + zle accept-line + fi + fi +} +_atuin_search_vicmd() { + _atuin_search --keymap-mode=vim-normal +} +_atuin_search_viins() { + _atuin_search --keymap-mode=vim-insert +} + +_atuin_up_search() { + # Only trigger if the buffer is a single line + if [[ ! $BUFFER == *$'\n'* ]]; then + _atuin_search --shell-up-key-binding "$@" + else + zle up-line + fi +} +_atuin_up_search_vicmd() { + _atuin_up_search --keymap-mode=vim-normal +} +_atuin_up_search_viins() { + _atuin_up_search --keymap-mode=vim-insert +} + +add-zsh-hook preexec _atuin_preexec +add-zsh-hook precmd _atuin_precmd + +zle -N atuin-search _atuin_search +zle -N atuin-search-vicmd _atuin_search_vicmd +zle -N atuin-search-viins _atuin_search_viins +zle -N atuin-up-search _atuin_up_search +zle -N atuin-up-search-vicmd _atuin_up_search_vicmd +zle -N atuin-up-search-viins _atuin_up_search_viins + +# These are compatibility widget names for "atuin <= 17.2.1" users. 
+zle -N _atuin_search_widget _atuin_search +zle -N _atuin_up_search_widget _atuin_up_search diff --git a/crates/atuin/src/sync.rs b/crates/atuin/src/sync.rs new file mode 100644 index 00000000..894a4aaa --- /dev/null +++ b/crates/atuin/src/sync.rs @@ -0,0 +1,37 @@ +use atuin_dotfiles::store::AliasStore; +use eyre::{Context, Result}; + +use atuin_client::{ + database::Database, history::store::HistoryStore, record::sqlite_store::SqliteStore, + settings::Settings, +}; +use atuin_common::record::RecordId; + +/// This is the only crate that ties together all other crates. +/// Therefore, it's the only crate where functions tying together all stores can live + +/// Rebuild all stores after a sync +/// Note: for history, this only does an _incremental_ sync. Hence the need to specify downloaded +/// records. +pub async fn build( + settings: &Settings, + store: &SqliteStore, + db: &dyn Database, + downloaded: Option<&[RecordId]>, +) -> Result<()> { + let encryption_key: [u8; 32] = atuin_client::encryption::load_key(settings) + .context("could not load encryption key")? 
+ .into(); + + let host_id = Settings::host_id().expect("failed to get host_id"); + + let downloaded = downloaded.unwrap_or(&[]); + + let history_store = HistoryStore::new(store.clone(), host_id, encryption_key); + let alias_store = AliasStore::new(store.clone(), host_id, encryption_key); + + history_store.incremental_build(db, downloaded).await?; + alias_store.build().await?; + + Ok(()) +} diff --git a/crates/atuin/tests/common/mod.rs b/crates/atuin/tests/common/mod.rs new file mode 100644 index 00000000..65679244 --- /dev/null +++ b/crates/atuin/tests/common/mod.rs @@ -0,0 +1,100 @@ +use std::{env, time::Duration}; + +use atuin_client::api_client; +use atuin_common::utils::uuid_v7; +use atuin_server::{launch_with_tcp_listener, Settings as ServerSettings}; +use atuin_server_postgres::{Postgres, PostgresSettings}; +use futures_util::TryFutureExt; +use tokio::{net::TcpListener, sync::oneshot, task::JoinHandle}; +use tracing::{dispatcher, Dispatch}; +use tracing_subscriber::{layer::SubscriberExt, EnvFilter}; + +pub async fn start_server(path: &str) -> (String, oneshot::Sender<()>, JoinHandle<()>) { + let formatting_layer = tracing_tree::HierarchicalLayer::default() + .with_writer(tracing_subscriber::fmt::TestWriter::new()) + .with_indent_lines(true) + .with_ansi(true) + .with_targets(true) + .with_indent_amount(2); + + let dispatch: Dispatch = tracing_subscriber::registry() + .with(formatting_layer) + .with(EnvFilter::new("atuin_server=debug,atuin_client=debug,info")) + .into(); + + let db_uri = env::var("ATUIN_DB_URI") + .unwrap_or_else(|_| "postgres://atuin:pass@localhost:5432/atuin".to_owned()); + + let server_settings = ServerSettings { + host: "127.0.0.1".to_owned(), + port: 0, + path: path.to_owned(), + open_registration: true, + max_history_length: 8192, + max_record_size: 1024 * 1024 * 1024, + page_size: 1100, + register_webhook_url: None, + register_webhook_username: String::new(), + db_settings: PostgresSettings { db_uri }, + metrics: 
atuin_server::settings::Metrics::default(), + tls: atuin_server::settings::Tls::default(), + }; + + let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel(); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let server = tokio::spawn(async move { + let _tracing_guard = dispatcher::set_default(&dispatch); + + if let Err(e) = launch_with_tcp_listener::<Postgres>( + server_settings, + listener, + shutdown_rx.unwrap_or_else(|_| ()), + ) + .await + { + tracing::error!(error=?e, "server error"); + panic!("error running server: {e:?}"); + } + }); + + // let the server come online + tokio::time::sleep(Duration::from_millis(200)).await; + + (format!("http://{addr}{path}"), shutdown_tx, server) +} + +pub async fn register_inner<'a>( + address: &'a str, + username: &str, + password: &str, +) -> api_client::Client<'a> { + let email = format!("{}@example.com", uuid_v7().as_simple()); + + // registration works + let registration_response = api_client::register(address, username, &email, password) + .await + .unwrap(); + + api_client::Client::new(address, ®istration_response.session, 5, 30).unwrap() +} + +#[allow(dead_code)] +pub async fn login(address: &str, username: String, password: String) -> api_client::Client<'_> { + // registration works + let login_respose = api_client::login( + address, + atuin_common::api::LoginRequest { username, password }, + ) + .await + .unwrap(); + + api_client::Client::new(address, &login_respose.session, 5, 30).unwrap() +} + +#[allow(dead_code)] +pub async fn register(address: &str) -> api_client::Client<'_> { + let username = uuid_v7().as_simple().to_string(); + let password = uuid_v7().as_simple().to_string(); + register_inner(address, &username, &password).await +} diff --git a/crates/atuin/tests/sync.rs b/crates/atuin/tests/sync.rs new file mode 100644 index 00000000..7e25d1c2 --- /dev/null +++ b/crates/atuin/tests/sync.rs @@ -0,0 +1,45 @@ +use 
atuin_common::{api::AddHistoryRequest, utils::uuid_v7}; +use time::OffsetDateTime; + +mod common; + +#[tokio::test] +async fn sync() { + let path = format!("/{}", uuid_v7().as_simple()); + let (address, shutdown, server) = common::start_server(&path).await; + + let client = common::register(&address).await; + let hostname = uuid_v7().as_simple().to_string(); + let now = OffsetDateTime::now_utc(); + + let data1 = uuid_v7().as_simple().to_string(); + let data2 = uuid_v7().as_simple().to_string(); + + client + .post_history(&[ + AddHistoryRequest { + id: uuid_v7().as_simple().to_string(), + timestamp: now, + data: data1.clone(), + hostname: hostname.clone(), + }, + AddHistoryRequest { + id: uuid_v7().as_simple().to_string(), + timestamp: now, + data: data2.clone(), + hostname: hostname.clone(), + }, + ]) + .await + .unwrap(); + + let history = client + .get_history(OffsetDateTime::UNIX_EPOCH, OffsetDateTime::UNIX_EPOCH, None) + .await + .unwrap(); + + assert_eq!(history.history, vec![data1, data2]); + + shutdown.send(()).unwrap(); + server.await.unwrap(); +} diff --git a/crates/atuin/tests/users.rs b/crates/atuin/tests/users.rs new file mode 100644 index 00000000..95fb533b --- /dev/null +++ b/crates/atuin/tests/users.rs @@ -0,0 +1,121 @@ +use atuin_common::utils::uuid_v7; + +mod common; + +#[tokio::test] +async fn registration() { + let path = format!("/{}", uuid_v7().as_simple()); + let (address, shutdown, server) = common::start_server(&path).await; + dbg!(&address); + + // -- REGISTRATION -- + + let username = uuid_v7().as_simple().to_string(); + let password = uuid_v7().as_simple().to_string(); + let client = common::register_inner(&address, &username, &password).await; + + // the session token works + let status = client.status().await.unwrap(); + assert_eq!(status.username, username); + + // -- LOGIN -- + + let client = common::login(&address, username.clone(), password).await; + + // the session token works + let status = client.status().await.unwrap(); + 
assert_eq!(status.username, username); + + shutdown.send(()).unwrap(); + server.await.unwrap(); +} + +#[tokio::test] +async fn change_password() { + let path = format!("/{}", uuid_v7().as_simple()); + let (address, shutdown, server) = common::start_server(&path).await; + + // -- REGISTRATION -- + + let username = uuid_v7().as_simple().to_string(); + let password = uuid_v7().as_simple().to_string(); + let client = common::register_inner(&address, &username, &password).await; + + // the session token works + let status = client.status().await.unwrap(); + assert_eq!(status.username, username); + + // -- PASSWORD CHANGE -- + + let current_password = password; + let new_password = uuid_v7().as_simple().to_string(); + let result = client + .change_password(current_password, new_password.clone()) + .await; + + // the password change request succeeded + assert!(result.is_ok()); + + // -- LOGIN -- + + let client = common::login(&address, username.clone(), new_password).await; + + // login with new password yields a working token + let status = client.status().await.unwrap(); + assert_eq!(status.username, username); + + shutdown.send(()).unwrap(); + server.await.unwrap(); +} + +#[tokio::test] +async fn multi_user_test() { + let path = format!("/{}", uuid_v7().as_simple()); + let (address, shutdown, server) = common::start_server(&path).await; + dbg!(&address); + + // -- REGISTRATION -- + + let user_one = uuid_v7().as_simple().to_string(); + let password_one = uuid_v7().as_simple().to_string(); + let client_one = common::register_inner(&address, &user_one, &password_one).await; + + // the session token works + let status = client_one.status().await.unwrap(); + assert_eq!(status.username, user_one); + + let user_two = uuid_v7().as_simple().to_string(); + let password_two = uuid_v7().as_simple().to_string(); + let client_two = common::register_inner(&address, &user_two, &password_two).await; + + // the session token works + let status = client_two.status().await.unwrap(); + 
assert_eq!(status.username, user_two); + + // check that we can change user one's password, and _this does not affect user two_ + + let current_password = password_one; + let new_password = uuid_v7().as_simple().to_string(); + let result = client_one + .change_password(current_password, new_password.clone()) + .await; + + // the password change request succeeded + assert!(result.is_ok()); + + // -- LOGIN -- + + let client_one = common::login(&address, user_one.clone(), new_password).await; + let client_two = common::login(&address, user_two.clone(), password_two).await; + + // login with new password yields a working token + let status = client_one.status().await.unwrap(); + assert_eq!(status.username, user_one); + assert_ne!(status.username, user_two); + + let status = client_two.status().await.unwrap(); + assert_eq!(status.username, user_two); + + shutdown.send(()).unwrap(); + server.await.unwrap(); +} |
