diff options
| author | Ellie Huxtable <e@elm.sh> | 2021-04-13 19:14:07 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-04-13 19:14:07 +0100 |
| commit | 5751463942cc91f1f1ffaf6e2ac633d7a0085f25 (patch) | |
| tree | f7b5b9a4702c4c3ef29aa60d36612f61ffeae052 /src/local | |
| parent | Update config (diff) | |
| download | atuin-5751463942cc91f1f1ffaf6e2ac633d7a0085f25.zip | |
Add history sync, resolves #13 (#31)
* Add encryption
* Add login and register command
* Add count endpoint
* Write initial sync push
* Add single sync command
Confirmed working for one client only
* Automatically sync on a configurable frequency
* Add key command, key arg to login
* Only load session if it exists
* Use sync and history timestamps for download
* Bind other key code
It seems some systems send this key code for the up arrow. I'm not sure why,
and it's not an easy one to search for.
* Simplify upload
* Try and fix download sync loop
* Change sync order to avoid uploading what we just downloaded
* Multiline import fix
* Fix time parsing
* Fix importing history with no time
* Add hostname to sync
* Use hostname to filter sync
* Fixes
* Add binding
* Stuff from yesterday
* Set cursor modes
* Make clippy happy
* Bump version
Diffstat (limited to 'src/local')
| -rw-r--r-- | src/local/api_client.rs | 94 | ||||
| -rw-r--r-- | src/local/database.rs | 55 | ||||
| -rw-r--r-- | src/local/encryption.rs | 108 | ||||
| -rw-r--r-- | src/local/history.rs | 11 | ||||
| -rw-r--r-- | src/local/import.rs | 116 | ||||
| -rw-r--r-- | src/local/mod.rs | 3 | ||||
| -rw-r--r-- | src/local/sync.rs | 135 |
7 files changed, 481 insertions, 41 deletions
diff --git a/src/local/api_client.rs b/src/local/api_client.rs new file mode 100644 index 00000000..434c07ba --- /dev/null +++ b/src/local/api_client.rs @@ -0,0 +1,94 @@ +use chrono::Utc; +use eyre::Result; +use reqwest::header::AUTHORIZATION; + +use crate::api::{AddHistoryRequest, CountResponse, ListHistoryResponse}; +use crate::local::encryption::{decrypt, load_key}; +use crate::local::history::History; +use crate::settings::Settings; +use crate::utils::hash_str; + +pub struct Client<'a> { + settings: &'a Settings, +} + +impl<'a> Client<'a> { + pub const fn new(settings: &'a Settings) -> Self { + Client { settings } + } + + pub fn count(&self) -> Result<i64> { + let url = format!("{}/sync/count", self.settings.local.sync_address); + let client = reqwest::blocking::Client::new(); + + let resp = client + .get(url) + .header( + AUTHORIZATION, + format!("Token {}", self.settings.local.session_token), + ) + .send()?; + + let count = resp.json::<CountResponse>()?; + + Ok(count.count) + } + + pub fn get_history( + &self, + sync_ts: chrono::DateTime<Utc>, + history_ts: chrono::DateTime<Utc>, + host: Option<String>, + ) -> Result<Vec<History>> { + let key = load_key(self.settings)?; + + let host = match host { + None => hash_str(&format!("{}:{}", whoami::hostname(), whoami::username())), + Some(h) => h, + }; + + // this allows for syncing between users on the same machine + let url = format!( + "{}/sync/history?sync_ts={}&history_ts={}&host={}", + self.settings.local.sync_address, + sync_ts.to_rfc3339(), + history_ts.to_rfc3339(), + host, + ); + let client = reqwest::blocking::Client::new(); + + let resp = client + .get(url) + .header( + AUTHORIZATION, + format!("Token {}", self.settings.local.session_token), + ) + .send()?; + + let history = resp.json::<ListHistoryResponse>()?; + let history = history + .history + .iter() + .map(|h| serde_json::from_str(h).expect("invalid base64")) + .map(|h| decrypt(&h, &key).expect("failed to decrypt history! 
check your key")) + .collect(); + + Ok(history) + } + + pub fn post_history(&self, history: &[AddHistoryRequest]) -> Result<()> { + let client = reqwest::blocking::Client::new(); + + let url = format!("{}/history", self.settings.local.sync_address); + client + .post(url) + .json(history) + .header( + AUTHORIZATION, + format!("Token {}", self.settings.local.session_token), + ) + .send()?; + + Ok(()) + } +} diff --git a/src/local/database.rs b/src/local/database.rs index ad7078e5..977f11cc 100644 --- a/src/local/database.rs +++ b/src/local/database.rs @@ -1,3 +1,4 @@ +use chrono::prelude::*; use chrono::Utc; use std::path::Path; @@ -21,6 +22,10 @@ pub trait Database { fn update(&self, h: &History) -> Result<()>; fn history_count(&self) -> Result<i64>; + fn first(&self) -> Result<History>; + fn last(&self) -> Result<History>; + fn before(&self, timestamp: chrono::DateTime<Utc>, count: i64) -> Result<Vec<History>>; + fn prefix_search(&self, query: &str) -> Result<Vec<History>>; } @@ -44,9 +49,7 @@ impl Sqlite { let conn = Connection::open(path)?; - if create { - Self::setup_db(&conn)?; - } + Self::setup_db(&conn)?; Ok(Self { conn }) } @@ -70,6 +73,14 @@ impl Sqlite { [], )?; + conn.execute( + "create table if not exists history_encrypted ( + id text primary key, + data blob not null + )", + [], + )?; + Ok(()) } @@ -87,7 +98,7 @@ impl Sqlite { ) values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", params![ h.id, - h.timestamp, + h.timestamp.timestamp_nanos(), h.duration, h.exit, h.command, @@ -146,7 +157,7 @@ impl Database for Sqlite { "update history set timestamp = ?2, duration = ?3, exit = ?4, command = ?5, cwd = ?6, session = ?7, hostname = ?8 where id = ?1", - params![h.id, h.timestamp, h.duration, h.exit, h.command, h.cwd, h.session, h.hostname], + params![h.id, h.timestamp.timestamp_nanos(), h.duration, h.exit, h.command, h.cwd, h.session, h.hostname], )?; Ok(()) @@ -183,6 +194,38 @@ impl Database for Sqlite { Ok(history_iter.filter_map(Result::ok).collect()) } + fn 
first(&self) -> Result<History> { + let mut stmt = self + .conn + .prepare("SELECT * FROM history order by timestamp asc limit 1")?; + + let history = stmt.query_row(params![], |row| history_from_sqlite_row(None, row))?; + + Ok(history) + } + + fn last(&self) -> Result<History> { + let mut stmt = self + .conn + .prepare("SELECT * FROM history order by timestamp desc limit 1")?; + + let history = stmt.query_row(params![], |row| history_from_sqlite_row(None, row))?; + + Ok(history) + } + + fn before(&self, timestamp: chrono::DateTime<Utc>, count: i64) -> Result<Vec<History>> { + let mut stmt = self.conn.prepare( + "SELECT * FROM history where timestamp <= ? order by timestamp desc limit ?", + )?; + + let history_iter = stmt.query_map(params![timestamp.timestamp_nanos(), count], |row| { + history_from_sqlite_row(None, row) + })?; + + Ok(history_iter.filter_map(Result::ok).collect()) + } + fn query(&self, query: &str, params: impl Params) -> Result<Vec<History>> { let mut stmt = self.conn.prepare(query)?; @@ -218,7 +261,7 @@ fn history_from_sqlite_row( Ok(History { id, - timestamp: row.get(1)?, + timestamp: Utc.timestamp_nanos(row.get(1)?), duration: row.get(2)?, exit: row.get(3)?, command: row.get(4)?, diff --git a/src/local/encryption.rs b/src/local/encryption.rs new file mode 100644 index 00000000..3c1699e3 --- /dev/null +++ b/src/local/encryption.rs @@ -0,0 +1,108 @@ +// The general idea is that we NEVER send cleartext history to the server +// This way the odds of anything private ending up where it should not are +// very low +// The server authenticates via the usual username and password. This has +// nothing to do with the encryption, and is purely authentication! The client +// generates its own secret key, and encrypts all shell history with libsodium's +// secretbox. The data is then sent to the server, where it is stored. 
All +// clients must share the secret in order to be able to sync, as it is needed +// to decrypt + +use std::fs::File; +use std::io::prelude::*; +use std::path::PathBuf; + +use eyre::{eyre, Result}; +use sodiumoxide::crypto::secretbox; + +use crate::local::history::History; +use crate::settings::Settings; + +#[derive(Debug, Serialize, Deserialize)] +pub struct EncryptedHistory { + pub ciphertext: Vec<u8>, + pub nonce: secretbox::Nonce, +} + +// Loads the secret key, will create + save if it doesn't exist +pub fn load_key(settings: &Settings) -> Result<secretbox::Key> { + let path = settings.local.key_path.as_str(); + + if PathBuf::from(path).exists() { + let bytes = std::fs::read(path)?; + let key: secretbox::Key = rmp_serde::from_read_ref(&bytes)?; + Ok(key) + } else { + let key = secretbox::gen_key(); + let buf = rmp_serde::to_vec(&key)?; + + let mut file = File::create(path)?; + file.write_all(&buf)?; + + Ok(key) + } +} + +pub fn encrypt(history: &History, key: &secretbox::Key) -> Result<EncryptedHistory> { + // serialize with msgpack + let buf = rmp_serde::to_vec(history)?; + + let nonce = secretbox::gen_nonce(); + + let ciphertext = secretbox::seal(&buf, &nonce, key); + + Ok(EncryptedHistory { ciphertext, nonce }) +} + +pub fn decrypt(encrypted_history: &EncryptedHistory, key: &secretbox::Key) -> Result<History> { + let plaintext = secretbox::open(&encrypted_history.ciphertext, &encrypted_history.nonce, key) + .map_err(|_| eyre!("failed to open secretbox - invalid key?"))?; + + let history = rmp_serde::from_read_ref(&plaintext)?; + + Ok(history) +} + +#[cfg(test)] +mod test { + use sodiumoxide::crypto::secretbox; + + use crate::local::history::History; + + use super::{decrypt, encrypt}; + + #[test] + fn test_encrypt_decrypt() { + let key1 = secretbox::gen_key(); + let key2 = secretbox::gen_key(); + + let history = History::new( + chrono::Utc::now(), + "ls".to_string(), + "/home/ellie".to_string(), + 0, + 1, + Some("beep boop".to_string()), + 
Some("booop".to_string()), + ); + + let e1 = encrypt(&history, &key1).unwrap(); + let e2 = encrypt(&history, &key2).unwrap(); + + assert_ne!(e1.ciphertext, e2.ciphertext); + assert_ne!(e1.nonce, e2.nonce); + + // test decryption works + // this should pass + match decrypt(&e1, &key1) { + Err(e) => assert!(false, "failed to decrypt, got {}", e), + Ok(h) => assert_eq!(h, history), + }; + + // this should err + match decrypt(&e2, &key1) { + Ok(_) => assert!(false, "expected an error decrypting with invalid key"), + Err(_) => {} + }; + } +} diff --git a/src/local/history.rs b/src/local/history.rs index 0ca112bd..1712f8b9 100644 --- a/src/local/history.rs +++ b/src/local/history.rs @@ -1,12 +1,15 @@ use std::env; use std::hash::{Hash, Hasher}; +use chrono::Utc; + use crate::command::uuid_v4; -#[derive(Debug, Clone)] +// Any new fields MUST be Optional<>! +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct History { pub id: String, - pub timestamp: i64, + pub timestamp: chrono::DateTime<Utc>, pub duration: i64, pub exit: i64, pub command: String, @@ -17,7 +20,7 @@ pub struct History { impl History { pub fn new( - timestamp: i64, + timestamp: chrono::DateTime<Utc>, command: String, cwd: String, exit: i64, @@ -29,7 +32,7 @@ impl History { .or_else(|| env::var("ATUIN_SESSION").ok()) .unwrap_or_else(uuid_v4); let hostname = - hostname.unwrap_or_else(|| hostname::get().unwrap().to_str().unwrap().to_string()); + hostname.unwrap_or_else(|| format!("{}:{}", whoami::hostname(), whoami::username())); Self { id: uuid_v4(), diff --git a/src/local/import.rs b/src/local/import.rs index 9bf79c72..d0f679c9 100644 --- a/src/local/import.rs +++ b/src/local/import.rs @@ -4,7 +4,9 @@ use std::io::{BufRead, BufReader, Seek, SeekFrom}; use std::{fs::File, path::Path}; -use eyre::{Result, WrapErr}; +use chrono::prelude::*; +use chrono::Utc; +use eyre::{eyre, Result}; use super::history::History; @@ -13,6 +15,7 @@ pub struct Zsh { file: BufReader<File>, pub loc: u64, + pub counter: i64, 
} // this could probably be sped up @@ -32,19 +35,23 @@ impl Zsh { Ok(Self { file: buf, loc: loc as u64, + counter: 0, }) } } -fn parse_extended(line: &str) -> History { +fn parse_extended(line: &str, counter: i64) -> History { let line = line.replacen(": ", "", 2); let (time, duration) = line.split_once(':').unwrap(); let (duration, command) = duration.split_once(';').unwrap(); - let time = time.parse::<i64>().map_or_else( - |_| chrono::Utc::now().timestamp_nanos(), - |t| t * 1_000_000_000, - ); + let time = time + .parse::<i64>() + .unwrap_or_else(|_| chrono::Utc::now().timestamp()); + + let offset = chrono::Duration::milliseconds(counter); + let time = Utc.timestamp(time, 0); + let time = time + offset; let duration = duration.parse::<i64>().map_or(-1, |t| t * 1_000_000_000); @@ -60,6 +67,18 @@ fn parse_extended(line: &str) -> History { ) } +impl Zsh { + fn read_line(&mut self) -> Option<Result<String>> { + let mut line = String::new(); + + match self.file.read_line(&mut line) { + Ok(0) => None, + Ok(_) => Some(Ok(line)), + Err(e) => Some(Err(eyre!("failed to read line: {}", e))), // we can skip past things like invalid utf8 + } + } +} + impl Iterator for Zsh { type Item = Result<History>; @@ -68,54 +87,89 @@ impl Iterator for Zsh { // These lines begin with : // So, if the line begins with :, parse it. Otherwise it's just // the command - let mut line = String::new(); + let line = self.read_line()?; - match self.file.read_line(&mut line) { - Ok(0) => None, - Ok(_) => { - let extended = line.starts_with(':'); + if let Err(e) = line { + return Some(Err(e)); // :( + } - if extended { - Some(Ok(parse_extended(line.as_str()))) - } else { - Some(Ok(History::new( - chrono::Utc::now().timestamp_nanos(), // what else? 
:/ - line.trim_end().to_string(), - String::from("unknown"), - -1, - -1, - None, - None, - ))) - } + let mut line = line.unwrap(); + + while line.ends_with("\\\n") { + let next_line = self.read_line()?; + + if next_line.is_err() { + // There's a chance that the last line of a command has invalid + // characters, the only safe thing to do is break :/ + // usually just invalid utf8 or smth + // however, we really need to avoid missing history, so it's + // better to have some items that should have been part of + // something else, than to miss things. So break. + break; } - Err(e) => Some(Err(e).wrap_err("failed to parse line")), + + line.push_str(next_line.unwrap().as_str()); + } + + // We have to handle the case where a line has escaped newlines. + // Keep reading until we have a non-escaped newline + + let extended = line.starts_with(':'); + + if extended { + self.counter += 1; + Some(Ok(parse_extended(line.as_str(), self.counter))) + } else { + let time = chrono::Utc::now(); + let offset = chrono::Duration::seconds(self.counter); + let time = time - offset; + + self.counter += 1; + + Some(Ok(History::new( + time, + line.trim_end().to_string(), + String::from("unknown"), + -1, + -1, + None, + None, + ))) } } } #[cfg(test)] mod test { + use chrono::prelude::*; + use chrono::Utc; + use super::parse_extended; #[test] fn test_parse_extended_simple() { - let parsed = parse_extended(": 1613322469:0;cargo install atuin"); + let parsed = parse_extended(": 1613322469:0;cargo install atuin", 0); assert_eq!(parsed.command, "cargo install atuin"); assert_eq!(parsed.duration, 0); - assert_eq!(parsed.timestamp, 1_613_322_469_000_000_000); + assert_eq!(parsed.timestamp, Utc.timestamp(1_613_322_469, 0)); - let parsed = parse_extended(": 1613322469:10;cargo install atuin;cargo update"); + let parsed = parse_extended(": 1613322469:10;cargo install atuin;cargo update", 0); assert_eq!(parsed.command, "cargo install atuin;cargo update"); assert_eq!(parsed.duration, 10_000_000_000); - 
assert_eq!(parsed.timestamp, 1_613_322_469_000_000_000); + assert_eq!(parsed.timestamp, Utc.timestamp(1_613_322_469, 0)); - let parsed = parse_extended(": 1613322469:10;cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷"); + let parsed = parse_extended(": 1613322469:10;cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷", 0); assert_eq!(parsed.command, "cargo :b̷i̶t̴r̵o̴t̴ ̵i̷s̴ ̷r̶e̵a̸l̷"); assert_eq!(parsed.duration, 10_000_000_000); - assert_eq!(parsed.timestamp, 1_613_322_469_000_000_000); + assert_eq!(parsed.timestamp, Utc.timestamp(1_613_322_469, 0)); + + let parsed = parse_extended(": 1613322469:10;cargo install \\n atuin\n", 0); + + assert_eq!(parsed.command, "cargo install \\n atuin"); + assert_eq!(parsed.duration, 10_000_000_000); + assert_eq!(parsed.timestamp, Utc.timestamp(1_613_322_469, 0)); } } diff --git a/src/local/mod.rs b/src/local/mod.rs index a11ee213..9fe31292 100644 --- a/src/local/mod.rs +++ b/src/local/mod.rs @@ -1,3 +1,6 @@ +pub mod api_client; pub mod database; +pub mod encryption; pub mod history; pub mod import; +pub mod sync; diff --git a/src/local/sync.rs b/src/local/sync.rs new file mode 100644 index 00000000..c22d2f27 --- /dev/null +++ b/src/local/sync.rs @@ -0,0 +1,135 @@ +use std::convert::TryInto; + +use chrono::prelude::*; +use eyre::Result; + +use crate::local::api_client; +use crate::local::database::Database; +use crate::local::encryption::{encrypt, load_key}; +use crate::settings::{Local, Settings, HISTORY_PAGE_SIZE}; +use crate::{api::AddHistoryRequest, utils::hash_str}; + +// Currently sync is kinda naive, and basically just pages backwards through +// history. This means newly added stuff shows up properly! We also just use +// the total count in each database to indicate whether a sync is needed. +// I think this could be massively improved! If we had a way of easily +// indicating count per time period (hour, day, week, year, etc) then we can +// easily pinpoint where we are missing data and what needs downloading. 
Start +// with year, then find the week, then the day, then the hour, then download it +// all! The current naive approach will do for now. + +// Check if remote has things we don't, and if so, download them. +// Returns (num downloaded, total local) +fn sync_download( + force: bool, + client: &api_client::Client, + db: &mut impl Database, +) -> Result<(i64, i64)> { + let remote_count = client.count()?; + + let initial_local = db.history_count()?; + let mut local_count = initial_local; + + let mut last_sync = if force { + Utc.timestamp_millis(0) + } else { + Local::last_sync()? + }; + + let mut last_timestamp = Utc.timestamp_millis(0); + + let host = if force { Some(String::from("")) } else { None }; + + while remote_count > local_count { + let page = client.get_history(last_sync, last_timestamp, host.clone())?; + + if page.len() < HISTORY_PAGE_SIZE.try_into().unwrap() { + break; + } + + db.save_bulk(&page)?; + + local_count = db.history_count()?; + + let page_last = page + .last() + .expect("could not get last element of page") + .timestamp; + + // in the case of a small sync frequency, it's possible for history to + // be "lost" between syncs. 
In this case we need to rewind the sync + // timestamps + if page_last == last_timestamp { + last_timestamp = Utc.timestamp_millis(0); + last_sync = last_sync - chrono::Duration::hours(1); + } else { + last_timestamp = page_last; + } + } + + Ok((local_count - initial_local, local_count)) +} + +// Check if we have things remote doesn't, and if so, upload them +fn sync_upload( + settings: &Settings, + _force: bool, + client: &api_client::Client, + db: &mut impl Database, +) -> Result<()> { + let initial_remote_count = client.count()?; + let mut remote_count = initial_remote_count; + + let local_count = db.history_count()?; + + let key = load_key(settings)?; // encryption key + + // first just try the most recent set + + let mut cursor = Utc::now(); + + while local_count > remote_count { + let last = db.before(cursor, HISTORY_PAGE_SIZE)?; + let mut buffer = Vec::<AddHistoryRequest>::new(); + + if last.is_empty() { + break; + } + + for i in last { + let data = encrypt(&i, &key)?; + let data = serde_json::to_string(&data)?; + + let add_hist = AddHistoryRequest { + id: i.id, + timestamp: i.timestamp, + data, + hostname: hash_str(i.hostname.as_str()), + }; + + buffer.push(add_hist); + } + + // anything left over outside of the 100 block size + client.post_history(&buffer)?; + cursor = buffer.last().unwrap().timestamp; + + remote_count = client.count()?; + } + + Ok(()) +} + +pub fn sync(settings: &Settings, force: bool, db: &mut impl Database) -> Result<()> { + let client = api_client::Client::new(settings); + + sync_upload(settings, force, &client, db)?; + + let download = sync_download(force, &client, db)?; + + debug!("sync downloaded {}", download.0); + + Local::save_sync_time()?; + + Ok(()) +} |
